Skip to content

Commit

Permalink
flipping and scaling data augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Apr 12, 2017
1 parent dc0dad3 commit 196e52c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
3 changes: 2 additions & 1 deletion configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

"data" : {
"path" : "/deep/group/med/alivecor/training2017",
"seed" : 2016
"seed" : 2016,
"augment" : true
},

"optimizer" : {
Expand Down
15 changes: 13 additions & 2 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ class Loader:
"""

def __init__(self, data_path, batch_size,
val_frac=0.2, seed=None):
val_frac=0.2, seed=None,
augment=False):
"""
:param data_path: path to the training and validation files
:param batch_size: size of the minibatches to train on
:param val_frac: fraction of the dataset to use for validation
(held out by record)
:param seed: seed the rng for shuffling data
:param augment: set to true to augment the training data
"""
if not os.path.exists(data_path):
msg = "Non-existent data path: {}".format(data_path)
Expand All @@ -44,6 +46,7 @@ def __init__(self, data_path, batch_size,
random.seed(seed)

self.batch_size = batch_size
self.augment = augment

self._train, self._val = load_all_data(data_path, val_frac)
logger.debug("Training set has " + str(len(self._train)) + " samples")
Expand Down Expand Up @@ -122,7 +125,10 @@ def output_dim(self):
@property
def train(self):
""" Returns the raw training set. """
return self._train
for ecgs, labels in self._train:
if self.augment:
ecgs = [transform(ecg) for ecg in ecgs]
yield (ecgs, labels)

@property
def val(self):
Expand Down Expand Up @@ -157,6 +163,11 @@ def __setstate__(self, state):
self._class_to_int = state[3]
self.class_counts = state[4]

def transform(ecg):
scale = random.uniform(0.1, 5.0)
flip = random.choice([-1.0, 1.0])
return ecg * flip * scale

def load_all_data(data_path, val_frac):
"""
Returns tuple of training and validation sets. Each set
Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def run_epoch(model, data_loader, session, summarizer):
for batch in data_loader.train:
ops = [model.train_op, model.avg_loss,
model.avg_acc, model.it, summary_op]

res = session.run(ops, feed_dict=model.feed_dict(*batch))
_, loss, acc, it, summary = res
summarizer.add_summary(summary, global_step=it)
Expand Down Expand Up @@ -89,8 +90,9 @@ def main(argv=None):
random.seed(config['seed'])
epochs = config['optimizer']['epochs']
data_loader = loader.Loader(config['data']['path'],
config['model']['batch_size'],
seed=config['data']['seed'])
config['model']['batch_size'],
seed=config['data']['seed'],
augment=config['data'].get('augment', False))

model = network.Network(is_verbose)

Expand Down

0 comments on commit 196e52c

Please sign in to comment.