-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert.py
35 lines (31 loc) · 1.4 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Convert a CAIN PyTorch checkpoint into a TensorRT FP16 engine.

Pipeline:
  1. Load the checkpoint into ``cain.cain.CAIN`` and export it to ONNX
     (``cain-temp.onnx``) with dynamic height/width axes on the input.
  2. Simplify the graph with ``onnxsim`` (``cain-sim.onnx``).
  3. Build a TensorRT engine with ``trtexec`` at the requested resolution
     and write it to ``--output``.

Example:
    python3 convert.py --input cain.pth --output cain.engine --height 256 --width 448
"""
import argparse
import subprocess
import sys

import torch

# Intermediate files, written to the current working directory.
TEMP_ONNX = "cain-temp.onnx"
SIM_ONNX = "cain-sim.onnx"


def positive_int(text):
    """argparse ``type`` callable that accepts only integers greater than zero."""
    value = int(text)  # a ValueError here is reported by argparse as an invalid value
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def parse_args(argv=None):
    """Parse command-line options; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Convert a CAIN PyTorch checkpoint into a TensorRT FP16 engine."
    )
    parser.add_argument("--input", required=True, type=str,
                        help="input model (file containing a CAIN state_dict)")
    parser.add_argument("--output", required=True, type=str,
                        help="output TensorRT engine path")
    parser.add_argument("--height", required=True, type=positive_int,
                        help="input frame height in pixels")
    parser.add_argument("--width", required=True, type=positive_int,
                        help="input frame width in pixels")
    return parser.parse_args(argv)


def export_onnx(checkpoint, height, width, onnx_path=TEMP_ONNX):
    """Load ``checkpoint`` into ``CAIN(3)`` and export it to ONNX at ``onnx_path``.

    The model is local to this function, so its memory is released on return,
    before the external tools are launched.
    """
    # Project import kept lazy (as in the original script) so argument errors
    # and --help do not pay for importing the model package.
    from cain.cain import CAIN

    model = CAIN(3)
    state_dict = torch.load(checkpoint, map_location="cpu")
    # strict=False is kept for compatibility, but report mismatches instead of
    # discarding them: a wrapped or prefixed checkpoint would otherwise load
    # nothing and silently produce an engine with untrained weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"warning: {len(incompatible.missing_keys)} model entries were not found "
              f"in {checkpoint} and keep their freshly constructed values",
              file=sys.stderr)
    if incompatible.unexpected_keys:
        print(f"warning: {len(incompatible.unexpected_keys)} entries in {checkpoint} "
              "were ignored (no matching model entry)", file=sys.stderr)
    model.eval()

    # Dummy input of shape (1, 6, H, W). Six channels presumably means two RGB
    # frames stacked on the channel axis -- TODO confirm against CAIN.forward.
    dummy_input = torch.rand((1, 6, height, width))
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,          # embed trained weights in the ONNX file
            opset_version=16,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            # Axes 2 and 3 of the input are exported as dynamic.
            # NOTE(review): the output is given no dynamic axes, so its shape is
            # fixed by the trace -- confirm this is intended.
            dynamic_axes={"input": {2: "height", 3: "width"}},
        )


def onnxsim_command(src=TEMP_ONNX, dst=SIM_ONNX):
    """Return the argv that simplifies ``src`` into ``dst`` with onnxsim.

    ``sys.executable`` is used instead of a hard-coded ``python3`` so onnxsim
    runs in the same interpreter/virtualenv as this script.
    """
    return [sys.executable, "-m", "onnxsim", src, dst]


def trtexec_command(engine_path, height, width, onnx_path=SIM_ONNX):
    """Return the argv that builds an FP16 TensorRT engine from ``onnx_path``.

    Built as a list (no shell), so paths containing spaces or shell
    metacharacters are passed through verbatim.
    """
    return [
        "trtexec",
        f"--onnx={onnx_path}",
        f"--optShapes=input:1x6x{height}x{width}",
        "--fp16",
        f"--saveEngine={engine_path}",
    ]


def main(argv=None):
    """Run the full checkpoint -> ONNX -> onnxsim -> TensorRT pipeline.

    Exits with a non-zero status and a message if a tool is missing or any
    step fails. Later steps are not run, so a stale ``cain-sim.onnx`` from an
    earlier run can never be turned into an engine.
    """
    args = parse_args(argv)
    export_onnx(args.input, args.height, args.width)
    try:
        subprocess.run(onnxsim_command(), check=True)
        subprocess.run(trtexec_command(args.output, args.height, args.width), check=True)
    except FileNotFoundError as err:
        sys.exit(f"error: required tool not found: {err.filename}")
    except subprocess.CalledProcessError as err:
        sys.exit(f"error: command failed with exit code {err.returncode}: {' '.join(err.cmd)}")


if __name__ == "__main__":
    main()