Commit

clarify dependency and add util.py
ycq091044 committed Sep 5, 2021
1 parent f4c1ac2 commit b3034e9
Showing 3 changed files with 261 additions and 9 deletions.
11 changes: 5 additions & 6 deletions README.md
@@ -16,19 +16,18 @@ python 3.7, scipy 1.1.0, pandas 0.25.3, torch 1.4.0, numpy 1.16.5, dill

 ### Reproducible code folder structure
 - data/
-  - refer to https://github.com/ycq091044/SafeDrug for more information. The preparation files here are a subset of https://github.com/ycq091044/SafeDrug, and the preprocessing file is slightly different.
+  - !!! ``refer to https://github.com/ycq091044/SafeDrug for more information. The preparation files here are a subset of https://github.com/ycq091044/SafeDrug, and the preprocessing file is slightly different.``
   - mapping files collected from external sources
-    - drug-atc.csv: drug to ATC code mapping file
+    - drug-atc.csv: this is a CID-ATC file, which gives the mapping from CID code to detailed ATC code (we will truncate it later)
     - drug-DDI.csv: this is a large file; it can be downloaded from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
-    - ndc2atc_level4.csv: NDC code to ATC-4 code mapping file
-    - ndc2rxnorm_mapping.txt: NDC to RxNorm mapping file
     - idx2drug.pkl: drug ID to drug SMILES string dict
+    - ndc2atc_level4.csv: this is an NDC-RXCUI-ATC5 file, which gives the mapping information
+    - ndc2rxnorm_mapping.txt: NDC to RXCUI mapping file
   - other files generated from the mapping files and the MIMIC dataset (we attach them here; users can regenerate them with our provided scripts)
     - data_final.pkl: intermediate result
     - ddi_A_final.pkl: DDI matrix
     - ehr_adj_final.pkl: used in the GAMENet baseline (refer to https://github.com/sjy1203/GAMENet)
     - (important) records_final.pkl: 100 patient visit-level record samples. Under the MIMIC dataset policy, we are not allowed to distribute the dataset. Practitioners can go to https://physionet.org/content/mimiciii/1.4/, request access to MIMIC-III, and then run our processing script to get the complete preprocessed dataset file.
-    - voc_final.pkl: diag/prod/med dictionary
+    - voc_final.pkl: diag/prod/med index-to-code dictionary
   - dataset processing scripts
     - preprocessing.py: used to process the original MIMIC dataset
 - src/
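Below is a minimal loading sketch (not part of this commit) for the pickled files listed above; it assumes only that they live under data/ and are dill pickles, as the README indicates:

    import dill

    # paths assumed from the folder layout above
    records = dill.load(open('data/records_final.pkl', 'rb'))  # visit-level records
    voc = dill.load(open('data/voc_final.pkl', 'rb'))          # index-to-code dictionaries
    ddi_A = dill.load(open('data/ddi_A_final.pkl', 'rb'))      # DDI matrix
    print(len(records), type(voc), type(ddi_A))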
6 changes: 3 additions & 3 deletions data/preprocessing.py
@@ -291,9 +291,9 @@ def get_ddi_matrix(records, med_voc, ddi_file):
 if __name__ == '__main__':
 
     # files can be downloaded from https://mimic.physionet.org/gettingstarted/dbsetup/
-    med_file = 'xxx/PRESCRIPTIONS.csv'
-    diag_file = 'xxx/DIAGNOSES_ICD.csv'
-    procedure_file = 'xxx/PROCEDURES_ICD.csv'
+    med_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/PRESCRIPTIONS.csv'
+    diag_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/DIAGNOSES_ICD.csv'
+    procedure_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/PROCEDURES_ICD.csv'
 
     med_structure_file = 'idx2drug.pkl'
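As a hedged sketch (not part of the commit), the same paths could point at any local MIMIC-III download; mimic_root below is a hypothetical placeholder, and only standard MIMIC-III file names are assumed:

    import pandas as pd

    mimic_root = '/path/to/physionet.org/files/mimiciii/1.4'  # hypothetical local root
    med_pd = pd.read_csv(mimic_root + '/PRESCRIPTIONS.csv', dtype={'NDC': 'category'})
    print(med_pd.shape)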
253 changes: 253 additions & 0 deletions src/util.py
@@ -0,0 +1,253 @@
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import sys
import warnings
import dill
from collections import Counter
from collections import defaultdict
import torch
warnings.filterwarnings('ignore')

def get_n_params(model):
    # total number of trainable parameter entries in a torch model
    pp = 0
    for p in list(model.parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp
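
# Hedged usage sketch (illustrative, not part of the original file): counting
# the parameters of a toy torch layer.
def _demo_get_n_params():
    model = torch.nn.Linear(10, 5)
    print(get_n_params(model))  # 55 = 10*5 weights + 5 biases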

# use the same metric from DMNC
def llprint(message):
    sys.stdout.write(message)
    sys.stdout.flush()

def transform_split(X, Y):
    # 2/3 train, 1/6 eval, 1/6 test with fixed random seeds
    x_train, x_eval, y_train, y_eval = train_test_split(X, Y, train_size=2/3, random_state=1203)
    x_eval, x_test, y_eval, y_test = train_test_split(x_eval, y_eval, test_size=0.5, random_state=1203)
    return x_train, x_eval, x_test, y_train, y_eval, y_test
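
# Hedged usage sketch (illustrative only): with 12 samples, transform_split
# should yield an 8/2/2 train/eval/test partition.
def _demo_transform_split():
    X = np.arange(12).reshape(12, 1)
    Y = np.arange(12)
    x_train, x_eval, x_test, y_train, y_eval, y_test = transform_split(X, Y)
    print(len(x_train), len(x_eval), len(x_test))  # 8 2 2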

def sequence_output_process(output_logits, filter_token):
    # greedy decoding over step-wise logits: at each step, take the highest-scoring
    # label not yet emitted; stop at the first filter token (e.g. the END token)
    pind = np.argsort(output_logits, axis=-1)[:, ::-1]

    out_list = []
    break_flag = False
    for i in range(len(pind)):
        if break_flag:
            break
        for j in range(pind.shape[1]):
            label = pind[i][j]
            if label in filter_token:
                break_flag = True
                break
            if label not in out_list:
                out_list.append(label)
                break
    # re-rank the emitted labels by their logit at the step they were emitted
    y_pred_prob_tmp = []
    for idx, item in enumerate(out_list):
        y_pred_prob_tmp.append(output_logits[idx, item])
    sorted_predict = [x for _, x in sorted(zip(y_pred_prob_tmp, out_list), reverse=True)]
    return out_list, sorted_predict
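
# Hedged usage sketch (illustrative only): two decoding steps over a 4-token
# vocabulary where token 3 acts as the END/filter token.
def _demo_sequence_output_process():
    logits = np.array([[0.1, 0.7, 0.1, 0.1],
                       [0.6, 0.2, 0.1, 0.1]])
    out_list, sorted_predict = sequence_output_process(logits, filter_token=[3])
    print(out_list, sorted_predict)  # greedy picks token 1, then token 0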


def sequence_metric(y_gt, y_pred, y_prob, y_label):
    # metrics for sequence (step-wise decoded) predictions; y_label holds the
    # decoded label-index lists, y_pred/y_prob the multi-hot and probability matrices
    def average_prc(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if (average_prc[idx] + average_recall[idx]) == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score

    def jaccard(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_pred_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob_label, k):
        precision = 0
        for i in range(len(y_gt)):
            TP = 0
            for j in y_prob_label[i][:k]:
                if y_gt[i, j] == 1:
                    TP += 1
            precision += TP / k
        return precision / len(y_gt)

    try:
        auc = roc_auc(y_gt, y_prob)
    except ValueError:
        auc = 0
    p_1 = precision_at_k(y_gt, y_label, k=1)
    p_3 = precision_at_k(y_gt, y_label, k=3)
    p_5 = precision_at_k(y_gt, y_label, k=5)
    f1 = f1(y_gt, y_pred)
    prauc = precision_auc(y_gt, y_prob)
    ja = jaccard(y_gt, y_label)
    avg_prc = average_prc(y_gt, y_label)
    avg_recall = average_recall(y_gt, y_label)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)
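
# Hedged usage sketch (illustrative only): one visit with 4 candidate drugs;
# y_label holds decoded label indices, as produced by sequence_output_process.
def _demo_sequence_metric():
    y_gt = np.array([[1, 0, 1, 0]])
    y_pred = np.array([[1, 1, 0, 0]])
    y_prob = np.array([[0.9, 0.6, 0.4, 0.1]])
    y_label = [np.array([0, 1])]
    print(sequence_metric(y_gt, y_pred, y_prob, y_label))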


def multi_label_metric(y_gt, y_pred, y_prob):
    # metrics for one-shot multi-label predictions: y_pred is a multi-hot matrix,
    # y_prob the corresponding probability matrix

    def jaccard(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def average_prc(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if average_prc[idx] + average_recall[idx] == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob, k=3):
        precision = 0
        sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k]
        for i in range(len(y_gt)):
            TP = 0
            for j in range(len(sort_index[i])):
                if y_gt[i, sort_index[i, j]] == 1:
                    TP += 1
            precision += TP / len(sort_index[i])
        return precision / len(y_gt)

    # roc_auc (per-visit macro AUC; undefined when a visit has only one class)
    try:
        auc = roc_auc(y_gt, y_prob)
    except ValueError:
        auc = 0
    # precision@k
    p_1 = precision_at_k(y_gt, y_prob, k=1)
    p_3 = precision_at_k(y_gt, y_prob, k=3)
    p_5 = precision_at_k(y_gt, y_prob, k=5)
    # macro f1
    f1 = f1(y_gt, y_pred)
    # precision-recall AUC
    prauc = precision_auc(y_gt, y_prob)
    # jaccard
    ja = jaccard(y_gt, y_pred)
    # precision, recall, f1
    avg_prc = average_prc(y_gt, y_pred)
    avg_recall = average_recall(y_gt, y_pred)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)
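
# Hedged usage sketch (illustrative only): two visits over a 4-drug space;
# each row of y_gt contains both classes so the per-visit ROC-AUC is defined.
def _demo_multi_label_metric():
    y_gt = np.array([[1, 0, 1, 0], [0, 1, 0, 1]])
    y_pred = np.array([[1, 0, 0, 0], [0, 1, 1, 1]])
    y_prob = np.array([[0.9, 0.2, 0.4, 0.1], [0.1, 0.8, 0.6, 0.7]])
    ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(y_gt, y_pred, y_prob)
    print(ja, prauc, avg_p, avg_r, avg_f1)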

def ddi_rate_score(record, path='../data/ddi_A_final.pkl'):
    # ddi rate: fraction of prescribed drug pairs (within one admission)
    # that appear in the DDI adjacency matrix
    ddi_A = dill.load(open(path, 'rb'))
    all_cnt = 0
    dd_cnt = 0
    for patient in record:
        for adm in patient:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt
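
# Hedged usage sketch (illustrative only): a toy 3-drug DDI matrix where drugs
# 0 and 1 interact; one patient with one admission taking all three drugs gives
# 1 interacting pair out of 3 candidate pairs. The pickle path is hypothetical.
def _demo_ddi_rate_score(tmp_path='ddi_demo.pkl'):
    ddi_A = np.zeros((3, 3))
    ddi_A[0, 1] = ddi_A[1, 0] = 1
    dill.dump(ddi_A, open(tmp_path, 'wb'))
    record = [[[0, 1, 2]]]  # patients -> admissions -> med codes
    print(ddi_rate_score(record, path=tmp_path))  # 1/3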
