-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert.py
22 lines (17 loc) · 747 Bytes
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf
import argparse
from tensorflow.python.tools import optimize_for_inference_lib
def _parse_args(args=None):
    """Parse the command-line options for the conversion.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``graph`` (input .pb path), ``output``
        (destination .pb path), ``input_names`` and ``output_names``
        (lists of graph node names).
    """
    parser = argparse.ArgumentParser(
        description='Optimize a frozen TensorFlow GraphDef for inference.')
    parser.add_argument('--graph', help='.pb graph path',
                        default='resnet_v2_101_299_frozen.pb')
    parser.add_argument('--output', help='path for the optimized .pb graph',
                        default='resnet_v2_101_299_opt.pb')
    parser.add_argument('--input_names', nargs='+', default=['input'],
                        help='name(s) of the graph input node(s)')
    parser.add_argument('--output_names', nargs='+', default=['output'],
                        help='name(s) of the graph output node(s)')
    return parser.parse_args(args)


def _load_graph_def(path):
    """Read ``path`` and parse it as a serialized ``GraphDef``.

    Read or parse errors (e.g. a missing file) propagate to the caller
    unchanged rather than being swallowed and retried.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def _save_graph_def(graph_def, path):
    """Serialize ``graph_def`` and write it to ``path`` (overwriting)."""
    with tf.io.gfile.GFile(path, 'wb') as f:
        f.write(graph_def.SerializeToString())


def main(args=None):
    """Load a frozen graph, run ``optimize_for_inference`` on it, save it.

    Args:
        args: Optional argument list (defaults to ``sys.argv[1:]``).
    """
    opts = _parse_args(args)
    graph_def = _load_graph_def(opts.graph)
    # The 4th argument is the placeholder dtype enum used for the input
    # node(s); the graph's inputs are treated as float32.
    graph_def = optimize_for_inference_lib.optimize_for_inference(
        graph_def, opts.input_names, opts.output_names,
        tf.float32.as_datatype_enum)
    _save_graph_def(graph_def, opts.output)


if __name__ == '__main__':
    main()