# TensorFlow Lite Runtime

import argparse
import os

import imageio as imageio
import numpy as np
from PIL import Image

# Deliberately placed between the other imports and the TensorFlow import so
# the variable is already in the environment when TensorFlow is loaded.
# NOTE(review): "3" is presumably meant to silence TensorFlow's native log
# output — confirm against the TensorFlow version in use.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf


def parse_arguments():
    """
    Build the command line interface and parse ``sys.argv``.

    Both options are mandatory strings:
        --model-path / -m : Model path
        --source / -s     : Path to the image file

    Returns: Namespace with ``model_path`` and ``source`` attributes
    """
    parser = argparse.ArgumentParser()

    # (option strings, help text) for every required string option,
    # registered in the order they should appear in the usage message.
    required_options = (
        (("--model-path", "-m"), "Model path"),
        (("--source", "-s"), "Path to the image file"),
    )
    for flags, description in required_options:
        parser.add_argument(*flags, type=str, required=True, help=description)

    return parser.parse_args()


def main(args):
    """
    Run a TFLite model on a single image and display the result.

    The image at ``args.source`` is reduced to three channels, resized to
    224x224 and turned into a channel-first float32 batch of one. The batch
    is fed to the interpreter loaded from ``args.model_path``; the first
    output tensor is clipped to [0, 1], scaled to 8-bit RGB and shown.

    Args:
        args : Parsed arguments; must provide ``source`` and ``model_path``
    """
    pixels = imageio.imread(args.source)
    if pixels.ndim == 2:
        # Grayscale files decode to a 2-D array; replicate the single channel
        # so the channel slicing and transpose below do not fail.
        pixels = np.stack((pixels,) * 3, axis=-1)
    # Keep the first three channels only (drops an alpha channel if present)
    pixels = pixels[:, :, :3]

    resized = Image.fromarray(pixels).resize((224, 224))
    # Convert to an ndarray explicitly: a PIL Image has its own, unrelated
    # ``transpose`` method, so it should not be handed to np.transpose directly.
    # HWC -> CHW, then add a leading batch axis.
    # NOTE(review): pixel values are passed in the 0-255 range while the
    # output below is interpreted as 0-1 — confirm this matches the model.
    chw = np.asarray(resized).transpose(2, 0, 1)
    batch = np.expand_dims(chw, 0).astype(np.float32)

    # Load the TFLite model and allocate tensors
    interpreter = tf.lite.Interpreter(model_path=args.model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.set_tensor(input_details[0]["index"], batch)

    interpreter.invoke()

    # get_tensor() returns a copy of the tensor data
    # use tensor() in order to get a pointer to the tensor
    preds = interpreter.get_tensor(output_details[0]["index"])
    # Drop size-1 axes, then go back from channel-first to channel-last
    x = np.squeeze(preds).transpose(1, 2, 0)

    enh_img = Image.fromarray(
        (np.clip(x, 0, 1) * 255).astype(np.uint8),
        mode="RGB",
    )
    enh_img.show()


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main()
    main(parse_arguments())