-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e069ea5
commit 2da83dc
Showing
47 changed files
with
30,053 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
## Install | ||
|
||
Clone the repository | ||
|
||
``` | ||
git clone git@github.com:awni/ecg.git | ||
``` | ||
|
||
If you don't have `virtualenv`, install it with | ||
|
||
``` | ||
pip install virtualenv | ||
``` | ||
|
||
Make and activate a new Python 2.7 environment | ||
|
||
``` | ||
virtualenv -p python2.7 ecg_env | ||
source ecg_env/bin/activate | ||
``` | ||
|
||
Install the requirements (this may take a few minutes). | ||
|
||
For CPU only support run | ||
``` | ||
./setup.sh | ||
``` | ||
|
||
To install with GPU support run | ||
``` | ||
env TF=gpu ./setup.sh | ||
``` | ||
|
||
## Training | ||
|
||
In the repo root directory (`ecg`) make a new directory called `saved`. | ||
|
||
``` | ||
mkdir saved | ||
``` | ||
|
||
To train a model use the following command, replacing `path_to_config.json` | ||
with an actual config: | ||
|
||
``` | ||
python ecg/train.py path_to_config.json | ||
``` | ||
|
||
Note that after each epoch the model is saved in | ||
`ecg/saved/<experiment_id>/<timestamp>/<model_id>.hdf5`. | ||
|
||
For an actual example of how to run this code on a real dataset, you can follow | ||
the instructions in the cinc17 [README](examples/cinc17/README.md). This will | ||
walk through downloading the Physionet 2017 challenge dataset and training and | ||
evaluating a model. | ||
|
||
## Testing | ||
|
||
After training the model for a few epochs, you can make predictions with: | ||
|
||
``` | ||
python ecg/predict.py <dataset>.json <model>.hdf5 | ||
``` | ||
|
||
replacing `<dataset>` with an actual path to the dataset and `<model>` with the | ||
path to the model. | ||
|
||
## Citation and Reference | ||
|
||
This work is published in the following paper in *Nature Medicine* | ||
|
||
[Cardiologist-level arrhythmia detection and classification in ambulatory electrocardiograms using a deep neural network](https://www.nature.com/articles/s41591-018-0268-3) | ||
|
||
If you find this codebase useful for your research please cite: | ||
|
||
``` | ||
@article{hannun2019cardiologist, | ||
title={Cardiologist-level arrhythmia detection and classification in ambulatory electrocardiograms using a deep neural network}, | ||
author={Hannun, Awni Y and Rajpurkar, Pranav and Haghpanahi, Masoumeh and Tison, Geoffrey H and Bourn, Codie and Turakhia, Mintu P and Ng, Andrew Y}, | ||
journal={Nature Medicine}, | ||
volume={25}, | ||
number={1}, | ||
pages={65}, | ||
year={2019}, | ||
publisher={Nature Publishing Group} | ||
} | ||
``` | ||
|
||
|
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import keras | ||
import numpy as np | ||
import os | ||
import random | ||
import scipy.io as sio | ||
import tqdm | ||
|
||
# Block length in samples: load_ecg truncates every record to a multiple of it.
STEP = 256 | ||
|
||
def data_generator(batch_size, preproc, x, y):
    """Yield preprocessed mini-batches forever, grouping examples by length.

    Examples are sorted by the size of their input along axis 0 and then cut
    into consecutive batches, so each batch holds similarly sized inputs.
    Trailing examples that do not fill a whole batch are dropped. The batch
    order is shuffled once up front and then repeated on every pass.

    Args:
        batch_size: Number of examples per batch.
        preproc: Object exposing ``process(x, y)``; it is called on each batch.
        x: Sequence of inputs; each item must expose ``.shape``.
        y: Sequence of labels aligned with ``x``.

    Yields:
        Whatever ``preproc.process`` returns for the current batch.

    Raises:
        ValueError: If not even one full batch can be formed. Without this
            check the endless loop below would spin forever without ever
            yielding anything.
    """
    num_examples = len(x)
    # The sort is stable and keyed only on length, so the arrays themselves
    # are never compared against each other.
    examples = sorted(zip(x, y), key=lambda example: example[0].shape[0])
    end = num_examples - batch_size + 1
    batches = [examples[i:i + batch_size]
               for i in range(0, end, batch_size)]
    if not batches:
        raise ValueError(
            "cannot form a batch of size {} from {} examples".format(
                batch_size, num_examples))
    random.shuffle(batches)
    while True:
        for batch in batches:
            batch_x, batch_y = zip(*batch)
            yield preproc.process(batch_x, batch_y)
|
||
class Preproc:
    """Normalizes ECG inputs and one-hot encodes their label sequences.

    The normalization statistics and the label vocabulary are both computed
    from the data passed to the constructor.
    """

    def __init__(self, ecg, labels):
        """Compute mean/std over ``ecg`` and index the labels seen in ``labels``.

        Args:
            ecg: Sequence of ECG arrays, forwarded to ``compute_mean_std``.
            labels: Sequence of label sequences, one per record.
        """
        self.mean, self.std = compute_mean_std(ecg)
        unique_labels = set()
        for sequence in labels:
            unique_labels.update(sequence)
        self.classes = sorted(unique_labels)
        self.int_to_class = dict(enumerate(self.classes))
        self.class_to_int = {
            name: index for index, name in enumerate(self.classes)}

    def process(self, x, y):
        """Return the ``(process_x(x), process_y(y))`` pair."""
        inputs = self.process_x(x)
        targets = self.process_y(y)
        return inputs, targets

    def process_x(self, x):
        """Pad, standardize, and append a trailing axis to the inputs."""
        padded = pad(x)
        normalized = (padded - self.mean) / self.std
        # Add a trailing axis of size one after the two padded dimensions.
        return normalized[:, :, None]

    def process_y(self, y):
        """Map labels to class indices, pad them, and one-hot encode."""
        # TODO, awni, fix hack pad with noise for cinc
        # The padding value is the hard-coded class index 3.
        encoded = [[self.class_to_int[c] for c in sequence] for sequence in y]
        padded = pad(encoded, val=3, dtype=np.int32)
        return keras.utils.np_utils.to_categorical(
            padded, num_classes=len(self.classes))
|
||
def pad(x, val=0, dtype=np.float32):
    """Stack variable-length sequences into one right-padded 2-D array.

    Args:
        x: Sequence of sequences; each inner item must support ``len``.
        val: Fill value used beyond the end of each sequence.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``(len(x), longest sequence length)``.
    """
    longest = max(map(len, x))
    out = np.full((len(x), longest), val, dtype=dtype)
    for row_index, row in enumerate(x):
        out[row_index, :len(row)] = row
    return out
|
||
def compute_mean_std(x):
    """Return the float32 mean and standard deviation over all arrays in ``x``.

    The arrays are concatenated with ``np.hstack`` first, so the statistics
    are taken over every sample of every array at once.
    """
    flat = np.hstack(x)
    mean = np.mean(flat).astype(np.float32)
    std = np.std(flat).astype(np.float32)
    return mean, std
|
||
def load_dataset(data_json):
    """Load every record listed in a JSON-lines dataset file.

    Each line of ``data_json`` is parsed as one JSON object with an ``'ecg'``
    entry (passed to ``load_ecg``) and a ``'labels'`` entry.

    Returns:
        A ``(ecgs, labels)`` tuple of parallel lists.
    """
    with open(data_json, 'r') as fid:
        records = [json.loads(line) for line in fid]

    ecgs = []
    labels = []
    # tqdm wraps the iteration over the parsed records.
    for record in tqdm.tqdm(records):
        labels.append(record['labels'])
        ecgs.append(load_ecg(record['ecg']))
    return ecgs, labels
|
||
def load_ecg(record):
    """Load one ECG record and truncate it to a whole number of STEP samples.

    The loader is chosen from the file extension:
      * ``.npy`` is read with ``np.load``;
      * ``.mat`` is read with ``scipy.io.loadmat`` and its ``'val'`` entry
        is squeezed;
      * anything else is assumed to hold raw 16-bit integers.

    Args:
        record: Path to the record file.

    Returns:
        The loaded array, cut to the largest multiple of ``STEP`` that fits.
    """
    extension = os.path.splitext(record)[1]
    if extension == ".npy":
        ecg = np.load(record)
    elif extension == ".mat":
        ecg = sio.loadmat(record)['val'].squeeze()
    else:  # Assumes binary 16 bit integers
        # Raw samples must be read in binary mode; a text-mode handle can
        # translate or reject bytes depending on platform and Python version.
        with open(record, 'rb') as fid:
            ecg = np.fromfile(fid, dtype=np.int16)

    trunc_samp = STEP * (len(ecg) // STEP)
    return ecg[:trunc_samp]
|
||
if __name__ == "__main__":
    # Smoke test: load the cinc17 training set from the user's home directory
    # and print the shapes of the first generated batch.
    #
    # The home directory comes from os.path.expanduser rather than pathlib,
    # so importing this module needs nothing beyond what is already imported
    # at the top of the file (pathlib is not in the Python 2.7 stdlib).
    data_json = os.path.join(
        os.path.expanduser("~"), "ecg", "examples", "cinc17", "train.json")
    train = load_dataset(data_json)
    preproc = Preproc(*train)
    gen = data_generator(32, preproc, *train)
    for x, y in gen:
        print(x.shape, y.shape)
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from keras import backend as K | ||
|
||
def _bn_relu(layer, dropout=0, **params):
    """Apply batch normalization, the configured activation, and optional dropout.

    Args:
        layer: Input Keras tensor.
        dropout: Dropout rate applied after the activation; values ``<= 0``
            disable dropout.
        **params: Must contain ``"conv_activation"``, the activation passed
            to ``keras.layers.Activation``.

    Returns:
        The transformed tensor.
    """
    from keras.layers import BatchNormalization
    from keras.layers import Activation
    layer = BatchNormalization()(layer)
    layer = Activation(params["conv_activation"])(layer)

    if dropout > 0:
        from keras.layers import Dropout
        # Use the rate the caller asked for instead of re-reading
        # params["conv_dropout"], so the argument is honoured and the key is
        # not required here.
        layer = Dropout(dropout)(layer)

    return layer
|
||
def add_conv_weight(
        layer,
        filter_length,
        num_filters,
        subsample_length=1,
        **params):
    """Append a 'same'-padded ``Conv1D`` to ``layer`` and return its output.

    Args:
        layer: Input Keras tensor.
        filter_length: Kernel size of the convolution.
        num_filters: Number of output filters.
        subsample_length: Stride of the convolution.
        **params: Must contain ``"conv_init"``, used as the kernel initializer.
    """
    from keras.layers import Conv1D
    convolution = Conv1D(
        filters=num_filters,
        kernel_size=filter_length,
        strides=subsample_length,
        padding='same',
        kernel_initializer=params["conv_init"])
    return convolution(layer)
|
||
|
||
def add_conv_layers(layer, **params):
    """Stack one conv + batch-norm/activation stage per configured stride.

    One stage is added for every entry of ``params["conv_subsample_lengths"]``;
    all stages use ``params["conv_filter_length"]`` and
    ``params["conv_num_filters_start"]``.
    """
    for stride in params["conv_subsample_lengths"]:
        convolved = add_conv_weight(
            layer,
            params["conv_filter_length"],
            params["conv_num_filters_start"],
            subsample_length=stride,
            **params)
        layer = _bn_relu(convolved, **params)
    return layer
|
||
def resnet_block(
        layer,
        num_filters,
        subsample_length,
        block_index,
        **params):
    """Build one residual block and return the sum of both branches.

    The shortcut branch max-pools ``layer`` by ``subsample_length``. When
    ``block_index`` is a positive multiple of
    ``params["conv_increase_channels_at"]`` the shortcut is also concatenated
    with zeros along axis 2, doubling that axis. The main branch stacks
    ``params["conv_num_skip"]`` convolutions; only the first one uses
    ``subsample_length`` as its stride.
    """
    from keras.layers import Add
    from keras.layers import MaxPooling1D
    from keras.layers.core import Lambda

    def zeropad(x):
        # Double axis 2 by appending an all-zero copy of the tensor.
        return K.concatenate([x, K.zeros_like(x)], axis=2)

    def zeropad_output_shape(input_shape):
        shape = list(input_shape)
        assert len(shape) == 3
        shape[2] *= 2
        return tuple(shape)

    shortcut = MaxPooling1D(pool_size=subsample_length)(layer)
    zero_pad = (block_index % params["conv_increase_channels_at"]) == 0 \
        and block_index > 0
    # Strict identity check: only a genuine Python ``True`` triggers padding.
    if zero_pad is True:
        shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)

    for i in range(params["conv_num_skip"]):
        # The very first convolution of block 0 gets no pre-activation.
        if block_index != 0 or i != 0:
            rate = params["conv_dropout"] if i > 0 else 0
            layer = _bn_relu(layer, dropout=rate, **params)
        stride = subsample_length if i == 0 else 1
        layer = add_conv_weight(
            layer,
            params["conv_filter_length"],
            num_filters,
            stride,
            **params)
    return Add()([shortcut, layer])
|
||
def get_num_filters_at_index(index, num_start_filters, **params):
    """Return the filter count for the block at ``index``.

    The count starts at ``num_start_filters`` and doubles once for every
    ``params["conv_increase_channels_at"]`` blocks.
    """
    doublings = int(index / params["conv_increase_channels_at"])
    return (2 ** doublings) * num_start_filters
|
||
def add_resnet_layers(layer, **params):
    """Build the residual trunk: a stem conv, the residual blocks, a final BN/activation.

    One residual block is added per entry of
    ``params["conv_subsample_lengths"]``; the filter count of each block comes
    from ``get_num_filters_at_index``.
    """
    stem = add_conv_weight(
        layer,
        params["conv_filter_length"],
        params["conv_num_filters_start"],
        subsample_length=1,
        **params)
    layer = _bn_relu(stem, **params)

    for block_index, stride in enumerate(params["conv_subsample_lengths"]):
        block_filters = get_num_filters_at_index(
            block_index, params["conv_num_filters_start"], **params)
        layer = resnet_block(
            layer,
            block_filters,
            stride,
            block_index,
            **params)

    return _bn_relu(layer, **params)
|
||
def add_output_layer(layer, **params):
    """Append a time-distributed dense layer followed by a softmax.

    The dense layer has ``params["num_categories"]`` units.
    """
    from keras.layers.core import Dense, Activation
    from keras.layers.wrappers import TimeDistributed
    per_step_dense = TimeDistributed(Dense(params["num_categories"]))
    scores = per_step_dense(layer)
    return Activation('softmax')(scores)
|
||
def add_compile(model, **params):
    """Compile ``model`` in place with Adam and categorical cross-entropy.

    Args:
        model: Keras model to compile.
        **params: Must contain ``"learning_rate"``; ``"clipnorm"`` is optional
            and defaults to 1.
    """
    from keras.optimizers import Adam
    clip_norm = params.get("clipnorm", 1)
    optimizer = Adam(lr=params["learning_rate"], clipnorm=clip_norm)

    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'])
|
||
def build_network(**params):
    """Assemble the network described by ``params`` and return it.

    ``params['input_shape']`` sets the input shape. The trunk is a plain
    convolution stack when ``params['is_regular_conv']`` is truthy and the
    residual stack otherwise. The model is compiled unless
    ``params['compile']`` is present and falsy.
    """
    from keras.models import Model
    from keras.layers import Input
    inputs = Input(
        shape=params['input_shape'],
        dtype='float32',
        name='inputs')

    use_plain_conv = params.get('is_regular_conv', False)
    trunk_builder = add_conv_layers if use_plain_conv else add_resnet_layers
    features = trunk_builder(inputs, **params)

    output = add_output_layer(features, **params)
    model = Model(inputs=[inputs], outputs=[output])
    if params.get("compile", True):
        add_compile(model, **params)
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import numpy as np | ||
import keras | ||
import os | ||
|
||
import load | ||
import util | ||
|
||
def predict(data_json, model_path):
    """Run a saved model over a dataset and return its predictions.

    Args:
        data_json: Path to the dataset file, passed to ``load.load_dataset``.
        model_path: Path to the saved Keras model; ``util.load`` is called on
            the directory that contains it to obtain the preprocessor.

    Returns:
        The output of ``model.predict`` on the preprocessed inputs.
    """
    model_dir = os.path.dirname(model_path)
    preproc = util.load(model_dir)
    dataset = load.load_dataset(data_json)
    # The processed labels are not needed for prediction.
    inputs, _ = preproc.process(*dataset)

    model = keras.models.load_model(model_path)
    return model.predict(inputs, verbose=1)
|
||
if __name__ == '__main__':
    # Command-line entry point: parse the two positional paths and run
    # prediction on them.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("data_json", help="path to data json")
    arg_parser.add_argument("model_path", help="path to model")
    cli_args = arg_parser.parse_args()
    probs = predict(cli_args.data_json, cli_args.model_path)
Oops, something went wrong.