-
Notifications
You must be signed in to change notification settings - Fork 297
/
test_cli.py
79 lines (70 loc) · 2.2 KB
/
test_cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import inspect
import os
import subprocess
import unittest
import onnx
from onnx.backend.test.runner import Runner
from onnx.backend.test.case.model import TestCase
from onnx_tf.backend import TensorflowBackend
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common.legacy import legacy_onnx_pre_ver
_ONNX_MODELS = [(
"mobilenetv2-1.0",
"https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.tar.gz"
)]
class TestCli(unittest.TestCase):
@staticmethod
def prepare_model(model_name, url):
if legacy_onnx_pre_ver(1, 5, 0):
prepare_model_data = Runner._prepare_model_data
else:
prepare_model_data = Runner.prepare_model_data
if IS_PYTHON3:
params = list(
inspect.signature(prepare_model_data).parameters.keys())
else:
params = inspect.getargspec(prepare_model_data).args
runner_class = Runner
if params[0] == "self":
runner_class = Runner(TensorflowBackend)
if legacy_onnx_pre_ver(1, 5, 0):
prepare_model_data = runner_class._prepare_model_data
else:
prepare_model_data = runner_class.prepare_model_data
if legacy_onnx_pre_ver(1, 4, 0):
tc = TestCase(
name="test_{}".format(model_name),
model_name=model_name,
url=url,
model_dir=None,
model=None,
data_sets=None,
kind='real')
else:
tc = TestCase(
name="test_{}".format(model_name),
model_name=model_name,
url=url,
model_dir=None,
model=None,
data_sets=None,
kind='real',
rtol=1e-3,
atol=1e-7)
return prepare_model_data(model_test=tc)
def test_convert_to_tf(self):
if legacy_onnx_pre_ver(1, 2, 1):
raise unittest.SkipTest(
"The current version of ONNX uses dead model link.")
for model_name, url in _ONNX_MODELS:
model_dir = self.prepare_model(model_name, url)
subprocess.check_call([
"onnx-tf",
"convert",
"-i",
os.path.join(model_dir, '{}.onnx'.format(model_name)),
"-o",
os.path.join(model_dir, '{}.pb'.format(model_name)),
])
if __name__ == '__main__':
unittest.main()