-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
59 lines (45 loc) · 1.63 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import json
import numpy as np
import os
import pickle
import tensorflow as tf
import network
class Evaler:
    """Restore a trained ``network.Network`` from a save directory and run inference.

    ``save_path`` must contain ``config.json`` (its ``'model'`` section is
    passed to ``init_inference``) and a TensorFlow checkpoint saved under the
    prefix ``model``.

    The instance owns a ``tf.Session`` on a private ``tf.Graph``. Call
    :meth:`close` when done, or use the instance as a context manager::

        with Evaler(path) as ev:
            labels = ev.predict(batch)
    """

    def __init__(self, save_path, batch_size=1):
        """Build the inference graph in a fresh graph/session and restore weights.

        Args:
            save_path: Directory holding ``config.json`` and the ``model``
                checkpoint.
            batch_size: Written into ``config['model']['batch_size']`` before
                the inference graph is built.
        """
        config_file = os.path.join(save_path, "config.json")
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(config_file, 'r', encoding='utf-8') as fid:
            config = json.load(fid)
        config['model']['batch_size'] = batch_size

        self.graph = tf.Graph()
        self.session = sess = tf.Session(graph=self.graph)
        try:
            with self.graph.as_default():
                # Constructed inside the graph context so that any TF ops the
                # constructor might create go into self.graph rather than the
                # global default graph.
                self.model = network.Network()
                self.model.init_inference(config['model'])
                tf.global_variables_initializer().run(session=sess)
                saver = tf.train.Saver(tf.global_variables())
                saver.restore(sess, os.path.join(save_path, "model"))
        except Exception:
            # Don't leak the session if graph construction or restore fails.
            sess.close()
            raise

    def probs(self, inputs):
        """Evaluate the model's ``probs`` tensor on ``inputs``.

        ``inputs`` is converted to a feed dict by ``self.model.feed_dict``.
        Returns whatever ``session.run`` yields for ``model.probs``.
        """
        model = self.model
        (probs,) = self.session.run([model.probs], model.feed_dict(inputs))
        return probs

    def predict(self, inputs):
        """Return ``argmax`` over axis 1 of :meth:`probs` for ``inputs``."""
        # NOTE(review): axis=1 assumes probs is (batch, num_classes) — confirm
        # against network.Network's output shape.
        return np.argmax(self.probs(inputs), axis=1)

    def close(self):
        """Release the ``tf.Session`` (and its resources) owned by this instance."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
def predict_record(record_id, model_path):
    """Predict the class label for a single record.

    Args:
        record_id: Identifier handed to the pickled loader's
            ``load_preprocess`` to obtain the model input.
        model_path: Directory holding the saved model (read by ``Evaler``)
            and ``loader.pkl``.

    Returns:
        ``ldr.int_to_class`` applied to the predicted class index.
    """
    evaler = Evaler(model_path)
    try:
        ldr_path = os.path.join(model_path, "loader.pkl")
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # only point model_path at directories you trust.
        with open(ldr_path, 'rb') as fid:
            ldr = pickle.load(fid)
        inputs = ldr.load_preprocess(record_id)
        outputs = evaler.predict([inputs])
    finally:
        # Evaler opens a tf.Session; release it even if loading or
        # prediction fails, so repeated calls don't accumulate sessions.
        evaler.session.close()
    return ldr.int_to_class(outputs[0])
def main(argv=None):
    """Command-line entry point: print the predicted class for one record.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, preserving the original CLI behaviour while
            allowing programmatic calls.
    """
    parser = argparse.ArgumentParser(description="Evaluator Script")
    parser.add_argument(
        "model_path",
        help="Directory containing config.json, the 'model' checkpoint "
             "and loader.pkl.",
    )
    parser.add_argument(
        "record",
        help="Record identifier passed to the loader's load_preprocess.",
    )
    args = parser.parse_args(argv)
    prediction = predict_record(args.record, args.model_path)
    print(prediction)


if __name__ == "__main__":
    main()