import argparse
import os
import tempfile
import light_side as ls
import torch
def parse_arguments():
    """Build the command-line interface for the ONNX export and parse argv.

    Returns:
        argparse.Namespace with the attributes ``model_name``, ``version``,
        ``target``, ``quantize`` and ``opset_version``. ``--target`` names a
        directory; ``main`` appends the ``.onnx`` file name to it.
    """
    parser = argparse.ArgumentParser()

    # Which pretrained model to export: architecture plus optional version.
    parser.add_argument(
        "--model_name",
        type=str,
        default=ls.available_models()[0],
        choices=ls.available_models(),
        help="Model architecture",
    )
    parser.add_argument("--version", type=str, help="Model version")

    # Where and how to write the exported graph.
    parser.add_argument("--target", "-t", type=str, help="Target path to save the model")
    parser.add_argument("--quantize", "-q", action="store_true", help="Quantize the model")
    parser.add_argument("--opset-version", type=int, default=11, help="Onnx opset version")

    namespace = parser.parse_args()
    return namespace
def _resolve_version(model_name, requested_version):
    """Return the model version to export.

    Args:
        model_name: light_side model architecture name.
        requested_version: Version from the CLI; a falsy value selects the
            latest version reported by ``ls.get_model_latest_version``.

    Raises:
        ValueError: If ``requested_version`` is given but is not one of the
            versions returned by ``ls.get_model_versions``.
    """
    # pylint: disable=no-member
    if not requested_version:
        return ls.get_model_latest_version(model_name)
    available = ls.get_model_versions(model_name)
    if requested_version not in available:
        raise ValueError(
            f"model version {requested_version} not available for model {model_name}, available versions: {available}"
        )
    return requested_version


def _export_onnx(model, path, input_sample, opset_version):
    """Write ``model`` to ``path`` as ONNX with dynamic batch and spatial axes."""
    # The sample is built as (batch, channels, H, W) -- PyTorch's NCHW layout --
    # so dim 2 is the height and dim 3 the width.
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }
    model.to_onnx(
        path,
        input_sample=input_sample,
        opset_version=opset_version,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        export_params=True,
    )


def main(args):
    """Export a pretrained light_side Enhancer to ONNX, optionally quantized.

    The model is written to ``<target>/<model_name>.onnx`` (or
    ``<model_name>_quantize.onnx`` with ``--quantize``). Without ``--target``
    the directory is ``<ls model dir>/<model_name>/v<version>``. The target
    directory is created if it does not exist.

    Args:
        args: Parsed arguments as returned by ``parse_arguments()``.

    Raises:
        ValueError: If ``args.version`` is not an available version of
            ``args.model_name``.
        AssertionError: If ``--quantize`` is requested but
            ``onnxruntime.quantization`` cannot be imported.
    """
    # pylint: disable=no-member
    version = _resolve_version(args.model_name, args.version)

    quantize_qat = None
    if args.quantize:
        # Check the optional dependency before the (potentially slow) model load.
        # NOTE(review): quantize_qat is deprecated/removed in newer onnxruntime
        # releases -- confirm the installed version still provides it.
        try:
            from onnxruntime.quantization import quantize_qat
        except ImportError as err:
            raise AssertionError("run `pip install onnxruntime`") from err

    model = ls.Enhancer.from_pretrained(args.model_name, version=version)
    model.eval()

    target_path = args.target or os.path.join(
        ls.core._get_model_dir(),  # pylint: disable=protected-access
        args.model_name,
        f"v{version}",
    )
    print(f"Target Path: {target_path}")
    # Neither a user-supplied --target nor the versioned model dir is
    # guaranteed to exist yet.
    os.makedirs(target_path, exist_ok=True)

    input_sample = torch.rand(1, 3, model.input_size, model.input_size)

    if args.quantize:
        target_model_path = os.path.join(target_path, f"{args.model_name}_quantize.onnx")
        # Export the float graph to a scratch file, then quantize it into place.
        # A temporary *directory* is used because a NamedTemporaryFile cannot be
        # reopened by name while it is still open on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            float_model_path = os.path.join(tmp_dir, f"{args.model_name}.onnx")
            _export_onnx(model, float_model_path, input_sample, args.opset_version)
            quantize_qat(float_model_path, target_model_path)
    else:
        target_model_path = os.path.join(target_path, f"{args.model_name}.onnx")
        _export_onnx(model, target_model_path, input_sample, args.opset_version)
    print("Model saved")
if __name__ == "__main__":
    # Script entry point: parse the CLI and hand the namespace straight to main.
    main(parse_arguments())