decoding with prior
awni committed Apr 4, 2017
1 parent 3ac9fb9 · commit e3f4ba2
Showing 3 changed files with 31 additions and 11 deletions.
evaler.py: 30 changes (22 additions, 8 deletions)
@@ -9,7 +9,9 @@
 
 class Evaler:
 
-    def __init__(self, save_path, is_verbose, batch_size=1):
+    def __init__(self, save_path, is_verbose=False,
+                 batch_size=1, class_counts=None,
+                 smooth=350): # TODO, awni, setup way to x-val smoothing param
         config_file = os.path.join(save_path, "config.json")
 
         with open(config_file, 'r') as fid:
Expand All @@ -25,36 +27,48 @@ def __init__(self, save_path, is_verbose, batch_size=1):
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, os.path.join(save_path, "best_model.epoch"))

if class_counts is not None:
counts = np.array(class_counts)[None, :]
total = np.sum(counts) + counts.shape[1]
self.prior = (counts + smooth) / total
else:
self.prior = None

def probs(self, inputs):
model = self.model
probs, = self.session.run([model.probs], model.feed_dict(inputs))
if self.prior is not None:
probs /= self.prior
return probs

def predict(self, inputs):
probs = self.probs(inputs)
return np.argmax(probs, axis=1)

def predict_record(record_id, model_path):
evaler = Evaler(model_path)

def predict_record(record_id, model_path, prior=False):
ldr_path = os.path.join(model_path, "loader.pkl")
with open(ldr_path, 'rb') as fid:
ldr = pickle.load(fid)

if prior:
evaler = Evaler(model_path, class_counts=ldr.class_counts)
else:
evaler = Evaler(model_path)

inputs = ldr.load_preprocess(record_id)
outputs = evaler.predict([inputs])
return ldr.int_to_class(outputs[0])

def main():
parser = argparse.ArgumentParser(description="Evaluater Script")
parser.add_argument("-v", "--verbose",
default = False, action = "store_true")
parser.add_argument("-m", "--model_path")
parser.add_argument("-r", "--record")
parser.add_argument("-p", "--prior", action="store_true",
help="Decode with prior")

args = parser.parse_args()
prediction = predict_record(args.record, args.model_path)
logger.info(prediction)
prediction = predict_record(args.record, args.model_path,
prior=args.prior)

if __name__ == "__main__":
main()
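The substance of the commit is a decode-time prior correction: when class_counts is supplied, the Evaler builds an additively smoothed estimate of the training-set label prior and divides the network's output probabilities by it before the argmax in predict, which shifts decisions toward rarer classes. Below is a minimal, self-contained sketch of that arithmetic; the helper names and the example counts are illustrative, not part of the repository, and smooth=350 is simply the hard-coded default that the TODO above notes has not yet been cross-validated.

import numpy as np

def smoothed_prior(class_counts, smooth=350):
    # Additive smoothing of the training-label frequencies, mirroring the
    # Evaler constructor: (count + smooth) / (total count + num classes).
    counts = np.array(class_counts, dtype=float)[None, :]
    total = np.sum(counts) + counts.shape[1]
    return (counts + smooth) / total

def decode_with_prior(probs, prior):
    # Divide the posteriors by the prior, then take the argmax. The ratio is
    # only compared across classes, so it does not need to be renormalized.
    return np.argmax(probs / prior, axis=1)

# Hypothetical counts for a four-class problem: the raw argmax would pick
# class 0, but after dividing by the prior the rare class 3 wins.
counts = [5000, 1200, 800, 300]
posteriors = np.array([[0.50, 0.20, 0.15, 0.15]])
print(decode_with_prior(posteriors, smoothed_prior(counts)))  # -> [3]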
loader.py: 5 changes (4 additions, 1 deletion)
@@ -58,6 +58,7 @@ def __init__(self, data_path, batch_size,
         classes = sorted([c for c, _ in label_counter.most_common()])
         self._int_to_class = dict(zip(range(len(classes)), classes))
         self._class_to_int = {c : i for i, c in self._int_to_class.items()}
+        self.class_counts = [label_counter[c] for c in classes]
 
         self._train = self.batches(self._train)
         self._val = self.batches(self._val)
@@ -143,7 +144,8 @@ def __getstate__(self):
         return (self.mean,
                 self.std,
                 self._int_to_class,
-                self._class_to_int)
+                self._class_to_int,
+                self.class_counts)
 
     def __setstate__(self, state):
         """
@@ -153,6 +155,7 @@ def __setstate__(self, state):
         self.std = state[1]
         self._int_to_class = state[2]
         self._class_to_int = state[3]
+        self.class_counts = state[4]
 
 def load_all_data(data_path, val_frac):
     """
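The loader change exists to support the above: class_counts is computed from the label counter and threaded through __getstate__/__setstate__, so evaler.py can read the training-label counts back out of the pickled loader.pkl without touching the data again. A small self-contained sketch of that round trip follows; TinyLoader and the counts are illustrative stand-ins, not repository code.

import pickle

class TinyLoader:
    # Stand-in mirroring Loader's pickling contract: the new class_counts
    # field rides along in the state tuple and is restored on load.
    def __init__(self, class_counts):
        self.class_counts = class_counts

    def __getstate__(self):
        return (self.class_counts,)

    def __setstate__(self, state):
        self.class_counts = state[0]

ldr = pickle.loads(pickle.dumps(TinyLoader([5000, 1200, 800, 300])))
print(ldr.class_counts)  # -> [5000, 1200, 800, 300]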
score.py: 7 changes (5 additions, 2 deletions)
@@ -23,8 +23,6 @@ def print_scores(labels, predictions, classes):
     logger.info("Macro Average F1: {:.3f}".format(macro_scores[2]))
 
 def load_model(model_path, is_verbose, batch_size):
-    evl = evaler.Evaler(model_path, is_verbose,
-                        batch_size=batch_size)
 
     # TODO, (awni), would be good to simplify loading and
     # not rely on random seed for validation set.
@@ -35,6 +33,11 @@ def load_model(model_path, is_verbose, batch_size):
     ldr = loader.Loader(data_conf['path'],
                         batch_size,
                         seed=data_conf['seed'])
+
+    evl = evaler.Evaler(model_path, is_verbose,
+                        batch_size=batch_size,
+                        class_counts=ldr.class_counts)
+
     return evl, ldr
 
 def eval_all(ldr, evl):