Skip to content

Commit

Permalink
load and evaluate a single record
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Mar 20, 2017
1 parent 800d18e commit 21e85b4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
49 changes: 20 additions & 29 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import numpy as np
import os
import pickle
import tensorflow as tf

import loader
import network

tf.flags.DEFINE_string("save_path", None,
"Path to saved model.")
FLAGS = tf.flags.FLAGS

class Evaler:

def __init__(self, save_path, batch_size=1):
Expand All @@ -39,31 +33,28 @@ def probs(self, inputs):

def predict(self, inputs):
    """Return the predicted class index for each example in *inputs*.

    Computes class probabilities with ``self.probs`` and takes the
    argmax over axis 1 — the commit changed this from ``axis=2``,
    which assumes probs has shape (batch, num_classes); TODO confirm
    against the network's actual output shape.
    """
    # This span contained both the pre-commit (axis=2) and post-commit
    # (axis=1) return lines; keep only the post-commit behavior.
    probs = self.probs(inputs)
    return np.argmax(probs, axis=1)

def main(argv=None):
assert FLAGS.save_path is not None, \
"Must provide the path to a model directory."
def predict_record(record_id, model_path):
    """Predict the class label of a single ECG record.

    Parameters
    ----------
    record_id : str
        Record identifier/path; the loader appends ".mat" to it.
    model_path : str
        Directory containing the saved model and ``loader.pkl``.

    Returns
    -------
    The class label produced by the loader's ``int_to_class``.
    """
    # The diff interleaved deleted pre-commit lines (old main's config
    # loading and batch loop) here; this is the clean post-commit body.
    evaler = Evaler(model_path)

    # Restore the preprocessing state pickled at training time.
    # NOTE(review): pickle.load is unsafe on untrusted input — only
    # load loader.pkl from trusted model directories.
    ldr_path = os.path.join(model_path, "loader.pkl")
    with open(ldr_path, 'rb') as fid:
        ldr = pickle.load(fid)

    inputs = ldr.load_preprocess(record_id)
    outputs = evaler.predict([inputs])
    return ldr.int_to_class(outputs[0])

evaler = Evaler(FLAGS.save_path, batch_size=batch_size)

corr = 0.0
total = 0
for inputs, labels in data_loader.batches(data_loader.val):
probs = evaler.probs(inputs)
predictions = np.vstack(predictions)
corr += np.sum(predictions == np.vstack(labels))
total += predictions.size
print("Number {}, Accuracy {:.3f}".format(total, corr / total))
def main():
    """Command-line entry point: predict the class of one record and print it."""
    cli = argparse.ArgumentParser(description="Evaluater Script")
    cli.add_argument("model_path")
    cli.add_argument("record")

    parsed = cli.parse_args()
    print(predict_record(parsed.record, parsed.model_path))

# Script entry point. The diff interleaved the pre-commit call
# (tf.app.run()) with the post-commit one; keep only main().
if __name__ == "__main__":
    main()

17 changes: 14 additions & 3 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def compute_mean_std(self):
Estimates the mean and std over the training set.
"""
all_dat = np.hstack(w for w, _ in self._train)
self.mean = np.mean(all_dat)
self.std = np.std(all_dat)
self.mean = np.mean(all_dat, dtype=np.float32)
self.std = np.std(all_dat, dtype=np.float32)

@property
def output_dim(self):
Expand All @@ -112,6 +112,14 @@ def val(self):
""" Returns the raw validation set. """
return self._val

def load_preprocess(self, record_id):
    """Load the ECG for *record_id* (".mat" is appended) and normalize it."""
    return self.normalize(load_ecg_mat(record_id + ".mat"))

def int_to_class(self, label_int):
    """Map an integer label back to its class label string."""
    classes = self._int_to_class
    return classes[label_int]

def __getstate__(self):
"""
For pickling.
Expand Down Expand Up @@ -145,7 +153,7 @@ def load_all_data(data_path, val_frac):
# Load raw ecg
for record, label in records:
ecg_file = os.path.join(data_path, record + ".mat")
ecg = sio.loadmat(ecg_file)['val'].squeeze()
ecg = load_ecg_mat(ecg_file)
all_records.append((ecg, label))

# Shuffle before train/val split
Expand All @@ -154,6 +162,9 @@ def load_all_data(data_path, val_frac):
train, val = all_records[cut:], all_records[:cut]
return train, val

def load_ecg_mat(ecg_file):
    """Read a MATLAB .mat file and return its 'val' array with
    singleton dimensions squeezed out."""
    mat_contents = sio.loadmat(ecg_file)
    return mat_contents['val'].squeeze()

def main():
parser = argparse.ArgumentParser(description="Data Loader")
parser.add_argument("-v", "--verbose",
Expand Down

0 comments on commit 21e85b4

Please sign in to comment.