Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Apr 13, 2017
2 parents 4d246c0 + c197f93 commit 8a8f198
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 6 deletions.
63 changes: 63 additions & 0 deletions configs/test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"seed" : 4337,

"data" : {
"path" : "/deep/group/med/alivecor/training2017",
"seed" : 2016,
"augment" : true,
"random_noise": true,
"random_samples": 200
},

"optimizer" : {
"name" : "momentum",
"epochs" : 50,
"learning_rate" : 1e-3,
"momentum" : 0.95,
"decay_rate" : 1.0,
"decay_steps" : 2000
},

"model" : {
"dropout" : 0.0,
"batch_size" : 32,
"conv_layers" : [
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
},
{ "filter_size" : 64,
"num_filters" : 32,
"stride" : 2
}
]
},

"io" : {
"output_save_path" : "/deep/group/sudnya/test_noise-200"
}
}
4 changes: 3 additions & 1 deletion configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"data" : {
"path" : "/deep/group/med/alivecor/training2017",
"seed" : 2016,
"augment" : true
"augment" : true,
"random_noise": true,
"random_samples": 200
},

"optimizer" : {
Expand Down
27 changes: 23 additions & 4 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Loader:

def __init__(self, data_path, batch_size,
val_frac=0.2, seed=None,
augment=False):
augment=False, random_noise=False, random_samples=200):
"""
:param data_path: path to the training and validation files
:param batch_size: size of the minibatches to train on
Expand All @@ -48,7 +48,7 @@ def __init__(self, data_path, batch_size,
self.batch_size = batch_size
self.augment = augment

self._train, self._val = load_all_data(data_path, val_frac)
self._train, self._val = load_all_data(data_path, val_frac, random_noise, random_samples)
logger.debug("Training set has " + str(len(self._train)) + " samples")
logger.debug("Validation set has " + str(len(self._val)) + " samples")

Expand Down Expand Up @@ -177,7 +177,20 @@ def transform(ecg):

return ecg * flip * scale

def load_all_data(data_path, val_frac):

def add_random_noise_samples(sample_count):
retVal = []
for i in range(0, sample_count):
length_window = random.randint(3000, 18000)
logger.debug("Random window length " + str(length_window))
retVal.append((np.random.randint(low=-100, high=100, size=(length_window), dtype=np.int16), 'N'))
logger.debug("First entry of random samples" + str(retVal[0]))

return retVal



def load_all_data(data_path, val_frac, train_noise, random_samples):
"""
Returns tuple of training and validation sets. Each set
will contain a list of pairs of raw ecg and the
Expand All @@ -193,12 +206,18 @@ def load_all_data(data_path, val_frac):
for record, label in records:
ecg_file = os.path.join(data_path, record + ".mat")
ecg = load_ecg_mat(ecg_file)
#logger.info( "ecg " + str(ecg.size))
all_records.append((ecg, label))

# Shuffle before train/val split
random.shuffle(all_records)
cut = int(len(all_records) * val_frac)
train, val = all_records[cut:], all_records[:cut]
logger.info(len(train))
logger.info(train[0])
if train_noise:
logger.info("Adding random noise samples to training set")
t = add_random_noise_samples(random_samples)
return train, val

def load_ecg_mat(ecg_file):
Expand Down Expand Up @@ -226,7 +245,7 @@ def main():
logging.basicConfig(level=logging.INFO)

random.seed(2016)
ldr = Loader(data_path, batch_size)
ldr = Loader(data_path, batch_size, random_noise=True, random_samples=200)
logger.info("Length of training set {}".format(len(list(ldr.train))))
logger.info("Length of validation set {}".format(len(ldr.val)))
logger.info("Output dimension {}".format(ldr.output_dim))
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def main(argv=None):
data_loader = loader.Loader(config['data']['path'],
config['model']['batch_size'],
seed=config['data']['seed'],
augment=config['data'].get('augment', False))
augment=config['data'].get('augment', False),
random_noise=config['data'].get('random_noise', False),
random_samples=config['data'].get('random_samples', 0))

model = network.Network(is_verbose)

Expand Down

0 comments on commit 8a8f198

Please sign in to comment.