update README and folder structure
ycq091044 committed Apr 2, 2022
1 parent eb46826 commit 6b0cbcb
Showing 20 changed files with 209 additions and 112 deletions.
103 changes: 80 additions & 23 deletions README.md
Expand Up @@ -10,26 +10,24 @@ IJCAI2021 - MICRON - Medication Change Prediction
Artificial Intelligence, {IJCAI} 2021},
year = {2021}
}
```
### (Reference) Dependency
python 3.7, scipy 1.1.0, pandas 0.25.3, torch 1.4.0, numpy 1.16.5, dill

### Reproductive code folder structure
- data/
- !!! ``refer to`` https://github.com/ycq091044/SafeDrug ``for more information. The preparation files here are a subset from`` https://github.com/ycq091044/SafeDrug ``and the preprocessing file is a little bit different.``
- mapping files that collected from external sources
- drug-atc.csv: this is a CID-ATC file, which gives the mapping from CID code to detailed ATC code (we should truncate later)
- drug-DDI.csv: this is a large file, which can be downloaded from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
- ndc2atc_level4.csv: this is a NDC-RXCUI-ATC5 file, which gives the mapping information
- ndc2rxnorm_mapping.txt: rxnorm to RXCUI file
- other files that generated from mapping files and MIMIC dataset (we attach these files here, user could use our provided scripts to generate)
- data_final.pkl: intermediate result
- ddi_A_final.pkl: ddi matrix
- ehr_adj_final.pkl: used in GAMENet baseline (refer to https://github.com/sjy1203/GAMENet)
- (important) records_final.pkl: 100 patient visit-level record samples. Under the MIMIC dataset policy, we are not allowed to distribute the datasets. Practitioners can go to https://physionet.org/content/mimiciii/1.4/, request access to the MIMIC-III dataset, and then run our processing script to get the complete preprocessed dataset file.
### Folder Specification
- data
- **processing.py**: our data preprocessing file.
- Input (extracted from external resources)
- PRESCRIPTIONS.csv: the prescription file from MIMIC-III raw dataset
- DIAGNOSES_ICD.csv: the diagnosis file from MIMIC-III raw dataset
- PROCEDURES_ICD.csv: the procedure file from MIMIC-III raw dataset
- RXCUI2atc4.csv: this is an NDC-RXCUI-ATC4 mapping file, of which we only need the RXCUI to ATC4 mapping. This file is obtained from https://github.com/sjy1203/GAMENet, where it is named ndc2atc_level4.csv.
- drug-atc.csv: this is a CID-ATC file, which gives the mapping from CID code to detailed ATC code (we will use the prefix of the ATC code later for aggregation). This file is obtained from https://github.com/sjy1203/GAMENet.
- rxnorm2RXCUI.txt: rxnorm to RXCUI mapping file. This file is obtained from https://github.com/sjy1203/GAMENet, where it is named ndc2rxnorm_mapping.txt.
- drug-DDI.csv: this is a large file containing the drug DDI information, coded by CID. The file can be downloaded from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
- Output
- ddi_A_final.pkl: ddi adjacency matrix
- ddi_matrix_H.pkl: H mask structure (This file is created by **ddi_mask_H.py**)
- ehr_adj_final.pkl: used in GAMENet baseline (if two drugs appear in one set, then they are connected)
- records_final.pkl: The final diagnosis-procedure-medication EHR records of each patient, used for train/val/test split.
- voc_final.pkl: diag/pro/med index to code dictionary
- dataset processing scripts
- preprocessing.py: is used to process the MIMIC original dataset
- src/
- MICRON.py: our model
- baselines:
Expand All @@ -43,7 +41,59 @@ python 3.7, scipy 1.1.0, pandas 0.25.3, torch 1.4.0, numpy 1.16.5, dill
- util.py
- layer.py
### Data Processing
> Dataset statistics can be found below
```
#patients 6350
#clinical events 15032
#diagnosis 1958
#med 151
#procedure 1430
#avg of diagnoses 10.5089143161256
#avg of medicines 11.865886109632783
#avg of procedures 3.8436668440659925
#avg of vists 2.367244094488189
#max of diagnoses 128
#max of medicines 68
#max of procedures 50
#max of visit 29
```
### Step 1: Package Dependency
- install the following package
```python
pip install scikit-learn dill dnc
```
Note that torch setup may vary according to GPU hardware. Generally, run the following
```python
pip install torch
```
If you are using an RTX 3090, please use the following command instead, which installs the CUDA 11.1 build of torch.
```python
python3 -m pip install --user torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
```

- install other packages if necessary
```python
pip install [xxx] # any other required package; the version can usually be left unspecified
```

Here is a list of reference versions for all packages

```shell
pandas: 1.3.0
dill: 0.3.4
torch: 1.8.0+cu111
rdkit: 2021.03.4
scikit-learn: 0.24.2
numpy: 1.21.1
```

Let us know about any package dependency issues. Please pay special attention to pandas: some users report that a newer version of pandas raises an error when loading files with dill.
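
As a quick sanity check against the reference versions listed above, the following sketch reports which versions are installed in the current environment. It relies only on the standard-library `importlib.metadata` module (Python 3.8+). The helper name `installed_versions` is hypothetical and not part of this repository, and the distribution names (for example `rdkit`) are assumed to match how the packages were installed.

```python
from importlib import metadata

# Reference versions copied from the list in this README.
REFERENCE_VERSIONS = {
    "pandas": "1.3.0",
    "dill": "0.3.4",
    "torch": "1.8.0+cu111",
    "rdkit": "2021.03.4",
    "scikit-learn": "0.24.2",
    "numpy": "1.21.1",
}


def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in installed_versions(REFERENCE_VERSIONS).items():
        print(f"{name}: installed={version}, reference={REFERENCE_VERSIONS[name]}")
```

A mismatch is not necessarily a problem; the list is a reference, not a hard requirement.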


### Step 2: Data Processing

- Go to https://physionet.org/content/mimiciii/1.4/ to download the MIMIC-III dataset (You may need to get the certificate)

Expand All @@ -61,17 +111,24 @@ python 3.7, scipy 1.1.0, pandas 0.25.3, torch 1.4.0, numpy 1.16.5, dill
gzip -d DIAGNOSES_ICD.csv.gz # diagnosis information
```

- change the path in processing.py and processing the data to get a complete records_final.pkl
- download the DDI file and move it to the data folder
download https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
```python
mv drug-DDI.csv ./data
```

- process the data to get the complete records_final.pkl

```python
cd ./data
vim processing.py

# line 294~296
# line 323-325
# med_file = './physionet.org/files/mimiciii/1.4/PRESCRIPTIONS.csv'
# diag_file = './physionet.org/files/mimiciii/1.4/DIAGNOSES_ICD.csv'
# procedure_file = './physionet.org/files/mimiciii/1.4/PROCEDURES_ICD.csv'

python preprocessing.py
python processing.py
```
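
Before running the script, it can help to confirm that the three MIMIC-III files are where the edited paths point. The sketch below is only an illustration: `missing_mimic_files` is a hypothetical helper, not part of the repository, and the directory passed in is whatever folder you configured in processing.py.

```python
from pathlib import Path

# File names taken from the README step above.
MIMIC_FILES = ("PRESCRIPTIONS.csv", "DIAGNOSES_ICD.csv", "PROCEDURES_ICD.csv")


def missing_mimic_files(mimic_dir):
    """Return the names of the required MIMIC-III CSV files not found in mimic_dir."""
    base = Path(mimic_dir)
    return [name for name in MIMIC_FILES if not (base / name).is_file()]


if __name__ == "__main__":
    # Example path from the README; replace it with your own MIMIC folder.
    missing = missing_mimic_files("./physionet.org/files/mimiciii/1.4")
    print("missing files:", missing or "none")
```

Note that processing.py, as changed in this commit, reads the DDI file from `./input/drug-DDI.csv` relative to the data folder, so the same kind of check may be worth applying to that path.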

### Run the code
Binary file removed data/data_final.pkl
Binary file not shown.
Binary file removed data/ehr_adj_final.pkl
Binary file not shown.
Binary file removed data/idx2drug.pkl
Binary file not shown.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file not shown.
Binary file added data/output/ehr_adj_final.pkl
Binary file not shown.
Binary file added data/output/records_final.pkl
Binary file not shown.
Binary file added data/output/voc_final.pkl
Binary file not shown.
131 changes: 82 additions & 49 deletions data/preprocessing.py → data/processing.py
100644 → 100755
@@ -1,9 +1,9 @@
from xml.dom.pulldom import ErrorHandler
import pandas as pd
import dill
import numpy as np
from collections import defaultdict


##### process medications #####
# load med data
def med_process(med_file):
Expand All @@ -29,23 +29,23 @@ def med_process(med_file):
return med_pd

# medication mapping
def ndc2atc4(med_pd):
with open(ndc_rxnorm_file, 'r') as f:
ndc2rxnorm = eval(f.read())
med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
def codeMapping2atc4(med_pd):
with open(rxnorm2RXCUI_file, 'r') as f:
rxnorm2RXCUI = eval(f.read())
med_pd['RXCUI'] = med_pd['NDC'].map(rxnorm2RXCUI)
med_pd.dropna(inplace=True)

rxnorm2atc = pd.read_csv(ndc2atc_file)
rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])
rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
rxnorm2atc4 = pd.read_csv(RXCUI2atc4_file)
rxnorm2atc4 = rxnorm2atc4.drop(columns=['YEAR','MONTH','NDC'])
rxnorm2atc4.drop_duplicates(subset=['RXCUI'], inplace=True)
med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)

med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
med_pd = med_pd.reset_index(drop=True)
med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
med_pd = med_pd.merge(rxnorm2atc4, on=['RXCUI'])
med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
med_pd = med_pd.rename(columns={'ATC4':'NDC'})
med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
med_pd['ATC4'] = med_pd['ATC4'].map(lambda x: x[:4])
med_pd = med_pd.rename(columns={'ATC4':'ATC3'})
med_pd = med_pd.drop_duplicates()
med_pd = med_pd.reset_index(drop=True)
return med_pd
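
The mapping file above is read with `eval`, which will execute any code the file happens to contain. If the file holds nothing but a Python dict literal (an assumption about its format, suggested by the result being passed straight to `Series.map`), `ast.literal_eval` is a safer alternative. A minimal sketch, with `load_literal_mapping` as a hypothetical helper and made-up example keys:

```python
import ast


def load_literal_mapping(text):
    """Parse a Python dict literal without executing arbitrary code."""
    mapping = ast.literal_eval(text)
    if not isinstance(mapping, dict):
        raise ValueError("expected a dict literal")
    return mapping


# Toy input; the keys and values are invented for illustration only.
example = "{'00001': '111', '00002': ''}"
print(load_literal_mapping(example))
```

In the function above this would amount to replacing `eval(f.read())` with a call of this kind; whether the real file parses cleanly as a literal would need to be checked.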
Expand All @@ -59,8 +59,8 @@ def process_visit_lg2(med_pd):

# most common medications
def filter_300_most_med(med_pd):
med_count = med_pd.groupby(by=['NDC']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]
med_count = med_pd.groupby(by=['ATC3']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
med_pd = med_pd[med_pd['ATC3'].isin(med_count.loc[:299, 'ATC3'])]

return med_pd.reset_index(drop=True)
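
`filter_300_most_med` keeps only rows whose code is among the most frequent ones. Because `.loc` slicing is label-based and includes the end label, `med_count.loc[:299, ...]` selects 300 rows after the index reset. The same idea can be sketched in plain Python with `collections.Counter`; `keep_top_k` is a hypothetical name, and tie-breaking between equally frequent codes may differ from the pandas version.

```python
from collections import Counter


def keep_top_k(codes, k):
    """Keep only the entries whose code is among the k most frequent codes."""
    top = {code for code, _ in Counter(codes).most_common(k)}
    return [code for code in codes if code in top]


# Toy ATC-like codes, invented for illustration.
codes = ["A01A", "B01A", "A01A", "C03C", "B01A", "A01A"]
print(keep_top_k(codes, 2))
```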

Expand Down Expand Up @@ -117,15 +117,14 @@ def combine_process(med_pd, diag_pd, pro_pd):

# flatten and merge
diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()
med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()

med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['ATC3'].unique().reset_index()
pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})
med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
med_pd['ATC3'] = med_pd['ATC3'].map(lambda x: list(x))
pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
# data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
data['NDC_Len'] = data['NDC'].map(lambda x: len(x))
data['ATC3_num'] = data['ATC3'].map(lambda x: len(x))

return data

Expand All @@ -134,7 +133,7 @@ def statistics(data):
print('#clinical events ', len(data))

diag = data['ICD9_CODE'].values
med = data['NDC'].values
med = data['ATC3'].values
pro = data['PRO_CODE'].values

unique_diag = set([j for i in diag for j in list(i)])
Expand All @@ -155,7 +154,7 @@ def statistics(data):
visit_cnt += 1
cnt += 1
x.extend(list(row['ICD9_CODE']))
y.extend(list(row['NDC']))
y.extend(list(row['ATC3']))
z.extend(list(row['PRO_CODE']))
x, y, z = set(x), set(y), set(z)
avg_diag += len(x)
Expand Down Expand Up @@ -201,10 +200,10 @@ def create_str_token_mapping(df):

for index, row in df.iterrows():
diag_voc.add_sentence(row['ICD9_CODE'])
med_voc.add_sentence(row['NDC'])
med_voc.add_sentence(row['ATC3'])
pro_voc.add_sentence(row['PRO_CODE'])

dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc}, file=open('voc_final.pkl','wb'))
dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc}, file=open(vocabulary_file,'wb'))
return diag_voc, med_voc, pro_voc

# create final records
Expand All @@ -217,14 +216,12 @@ def create_patient_record(df, diag_voc, med_voc, pro_voc):
admission = []
admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
admission.append([med_voc.word2idx[i] for i in row['NDC']])
admission.append([med_voc.word2idx[i] for i in row['ATC3']])
patient.append(admission)
records.append(patient)
dill.dump(obj=records, file=open('records_final.pkl', 'wb'))
dill.dump(obj=records, file=open(ehr_sequence_file, 'wb'))
return records
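
The records written by `create_patient_record` are nested lists: one list per patient, one list per visit, and within each visit three index lists in the order diagnoses, procedures, medications. The sketch below builds a toy structure in that layout and derives a few summary numbers; `summarize` and the toy values are illustrative only and are not the script's own `statistics` function.

```python
# Toy data in the layout: records -> patients -> visits -> [diag, pro, med]
records = [
    [[[0, 1], [0], [0, 1, 2]], [[1, 2], [1], [1, 3]]],           # patient with 2 visits
    [[[3], [2, 3], [0]], [[0], [0], [2]], [[4], [1], [3, 4]]],   # patient with 3 visits
]


def summarize(records):
    """Count patients and visits, and average visits and medications."""
    n_patients = len(records)
    n_visits = sum(len(patient) for patient in records)
    n_meds = sum(len(visit[2]) for patient in records for visit in patient)
    return {
        "patients": n_patients,
        "clinical_events": n_visits,
        "avg_visits": n_visits / n_patients,
        "avg_medicines": n_meds / n_visits,
    }


print(summarize(records))
```

For the real output, the file would be loaded with `dill.load`, mirroring the `dill.dump` call above. As a consistency check, the README figures of 15032 clinical events over 6350 patients give the reported average of about 2.367 visits per patient.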



# get ddi matrix
def get_ddi_matrix(records, med_voc, ddi_file):

Expand All @@ -236,7 +233,7 @@ def get_ddi_matrix(records, med_voc, ddi_file):
for item in med_unique_word:
atc3_atc4_dic[item[:4]].add(item)

with open(cid_atc, 'r') as f:
with open(cid2atc6_file, 'r') as f:
for line in f:
line_ls = line[:-1].split(',')
cid = line_ls[0]
Expand All @@ -248,7 +245,8 @@ def get_ddi_matrix(records, med_voc, ddi_file):
# ddi load
ddi_df = pd.read_csv(ddi_file)
# filter severe side effects
ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name'])\
.size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
ddi_most_pd = ddi_most_pd.iloc[-TOPK:,:]
# ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
Expand All @@ -265,7 +263,7 @@ def get_ddi_matrix(records, med_voc, ddi_file):
continue
ehr_adj[med_i, med_j] = 1
ehr_adj[med_j, med_i] = 1
dill.dump(ehr_adj, open('ehr_adj_final.pkl', 'wb'))
dill.dump(ehr_adj, open(ehr_adjacency_file, 'wb'))

# ddi adj
ddi_adj = np.zeros((med_voc_size,med_voc_size))
Expand All @@ -284,36 +282,66 @@ def get_ddi_matrix(records, med_voc, ddi_file):
if med_voc.word2idx[i] != med_voc.word2idx[j]:
ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
dill.dump(ddi_adj, open('ddi_A_final.pkl', 'wb'))
dill.dump(ddi_adj, open(ddi_adjacency_file, 'wb'))

return ddi_adj
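
The EHR adjacency matrix built in `get_ddi_matrix` connects two medications whenever they appear in the same visit, and every edge is written in both directions so the matrix stays symmetric. A minimal sketch of that co-occurrence rule, using plain nested lists in place of the NumPy array so it runs without dependencies; `cooccurrence_adjacency` is a hypothetical name and the input follows the records layout produced by `create_patient_record`.

```python
from itertools import combinations


def cooccurrence_adjacency(records, n_meds):
    """Return an n_meds x n_meds 0/1 matrix linking medications that share a visit."""
    adj = [[0] * n_meds for _ in range(n_meds)]
    for patient in records:
        for visit in patient:
            meds = sorted(set(visit[2]))  # medication indices are the third list
            for i, j in combinations(meds, 2):
                adj[i][j] = 1
                adj[j][i] = 1  # keep the matrix symmetric
    return adj


# One patient, one visit, medications 0 and 2 prescribed together.
toy_records = [[[[0], [0], [0, 2]]]]
for row in cooccurrence_adjacency(toy_records, 3):
    print(row)
```

The diagonal stays zero because `combinations` never pairs an index with itself.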

def get_ddi_mask(atc42SMLES, med_voc):

# ATC3_List[22] = {0}
# ATC3_List[25] = {0}
# ATC3_List[27] = {0}
fraction = []
for k, v in med_voc.idx2word.items():
tempF = set()
for SMILES in atc42SMLES[v]:
try:
m = BRICS.BRICSDecompose(Chem.MolFromSmiles(SMILES))
for frac in m:
tempF.add(frac)
except:
pass
fraction.append(tempF)
fracSet = []
for i in fraction:
fracSet += i
fracSet = list(set(fracSet)) # set of all segments

ddi_matrix = np.zeros((len(med_voc.idx2word), len(fracSet)))
for i, fracList in enumerate(fraction):
for frac in fracList:
ddi_matrix[i, fracSet.index(frac)] = 1
return ddi_matrix
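
`get_ddi_mask` turns a list of fragment sets into a 0/1 matrix with one row per medication and one column per fragment. Two details are worth noting: `fracSet.index(frac)` is a linear search on every lookup, and the order of `list(set(...))` for strings can change between interpreter runs, so column order is not guaranteed to be reproducible. The sketch below shows the same multi-hot construction with a sorted vocabulary and a dictionary lookup. `build_multi_hot` is a hypothetical helper, plain lists stand in for the NumPy array, and the toy fragment strings are invented. The function above also uses `BRICS` and `Chem`, which presumably come from rdkit (`from rdkit import Chem`, `from rdkit.Chem import BRICS`); those imports are not visible in this hunk.

```python
def build_multi_hot(fragment_sets):
    """Return (vocabulary, matrix) where matrix[i][j] is 1 if item i has fragment j."""
    vocabulary = sorted(set().union(*fragment_sets))  # sorted for a stable column order
    column = {frag: idx for idx, frag in enumerate(vocabulary)}
    matrix = [[0] * len(vocabulary) for _ in fragment_sets]
    for row, frags in enumerate(fragment_sets):
        for frag in frags:
            matrix[row][column[frag]] = 1
    return vocabulary, matrix


# Toy fragment sets; the second item has no fragments at all.
vocab, mask = build_multi_hot([{"a", "b"}, set(), {"b", "c"}])
print(vocab)
for row in mask:
    print(row)
```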


if __name__ == '__main__':

# files can be downloaded from https://mimic.physionet.org/gettingstarted/dbsetup/
# please change into your own MIMIC folder
med_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/PRESCRIPTIONS.csv'
diag_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/DIAGNOSES_ICD.csv'
procedure_file = '/srv/local/data/physionet.org/files/mimiciii/1.4/PROCEDURES_ICD.csv'

med_structure_file = 'idx2drug.pkl'

# drug code mapping files
ndc2atc_file = 'ndc2atc_level4.csv'
cid_atc = 'drug-atc.csv'
ndc_rxnorm_file = 'ndc2rxnorm_mapping.txt'

# ddi information
ddi_file = 'drug-DDI.csv'
cid_atc = 'drug-atc.csv'

# input auxiliary files
med_structure_file = './output/atc32SMILES.pkl'
RXCUI2atc4_file = './input/RXCUI2atc4.csv'
cid2atc6_file = './input/drug-atc.csv'
rxnorm2RXCUI_file = './input/rxnorm2RXCUI.txt'
ddi_file = './input/drug-DDI.csv'

# output files
ddi_adjacency_file = "./output/ddi_A_final.pkl"
ehr_adjacency_file = "./output/ehr_adj_final.pkl"
ehr_sequence_file = "./output/records_final.pkl"
vocabulary_file = "./output/voc_final.pkl"
ddi_mask_H_file = "./output/ddi_mask_H.pkl"

# for med
med_pd = med_process(med_file)
med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)
med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner')
med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner').reset_index(drop=True)

med_pd = ndc2atc4(med_pd)
NDCList = dill.load(open(med_structure_file, 'rb'))
med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
med_pd = codeMapping2atc4(med_pd)
med_pd = filter_300_most_med(med_pd)

print ('complete medication processing')
Expand All @@ -325,18 +353,23 @@ def get_ddi_matrix(records, med_voc, ddi_file):

# for procedure
pro_pd = procedure_process(procedure_file)
pro_pd = filter_1000_most_pro(pro_pd)
# pro_pd = filter_1000_most_pro(pro_pd)

print ('complete procedure processing')

# combine
data = combine_process(med_pd, diag_pd, pro_pd)
statistics(data)
data.to_pickle('data_final.pkl')

print ('complete combining')

# ddi_matrix
# create vocab
diag_voc, med_voc, pro_voc = create_str_token_mapping(data)
print ("obtain voc")

# create ehr sequence data
records = create_patient_record(data, diag_voc, med_voc, pro_voc)
ddi_adj = get_ddi_matrix(records, med_voc, ddi_file)
print ("obtain ehr sequence data")

# create ddi adj matrix
ddi_adj = get_ddi_matrix(records, med_voc, ddi_file)
print ("obtain ddi adj matrix")
Binary file removed data/records_final.pkl
Binary file not shown.
Binary file removed data/voc_final.pkl
Binary file not shown.