Skip to content

Commit

Permalink
adds cluster model/additional logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Amster committed Apr 19, 2019
1 parent 0abf168 commit 7a8c2ea
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 22 deletions.
6 changes: 5 additions & 1 deletion clustering/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def calc_ratio(cluster, cluster_idx):
fig.savefig('feature_p_cluster_to_pop.png', bbox_inches='tight')

if __name__ == '__main__':
np.random.seed(0)

#df = pd.read_csv('../data/data.csv')
#df = pd.read_csv('../data/data_drug_names.csv')
df = pd.read_csv('../data/data_questionnaire_for_clustering.csv')
Expand Down Expand Up @@ -127,4 +129,6 @@ def calc_ratio(cluster, cluster_idx):
plot_ratio_feature_in_cluster_to_pop_matrix(df, nonzero_features)



# write results to disk
pd.DataFrame.from_records(zip(df.cluster, df.SEQN), columns=['CLUSTER', 'SEQN']).to_csv('seqn_cluster_map.csv',
index=False)
4 changes: 3 additions & 1 deletion preprocessing/demographics.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,6 @@ read_and_preprocess_data = function(years) {
return (data)
}

read_and_preprocess_data(c('2015-2016'))
# read_and_preprocess_data( years = c(
# '2015-2016'
# ))
7 changes: 6 additions & 1 deletion preprocessing/examination.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ read_and_preprocess_data = function(years) {

data = rbindlist(data, use.names = T)
return (data)
}
}

# read_and_preprocess_data( years = c(
# '2015-2016', '2013-2014', '2011-2012', '2009-2010',
# '2007-2008', '2005-2006', '2003-2004', '2001-2002', '1999-2000'
# ))
49 changes: 49 additions & 0 deletions preprocessing/laboratory.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
library(modules)

import_package('data.table', attach = T)
import_package('dplyr', attach = T)
import_package('stringr', attach = T)
import_package('magrittr', attach = T)

preprocessing_util = import('preprocessing/util')

# Column map for the standard biochemistry file. Each entry pairs a source
# column code (`old`) with the `lab_`-prefixed name used for it here (`new`).
# It is passed as the second argument to
# preprocessing_util$read_and_preprocess_data().
# NOTE(review): presumably the util keeps SEQN plus the `old` columns and
# renames them to `new` - confirm in preprocessing/util.
STANDARD_BIOCHEMISTRY_COLS = list(
list(old = 'LBXSCR', new = 'lab_creatinine'),
list(old = 'LBXSTR', new = 'lab_triglycerides'),
list(old = 'LBXSCH', new = 'lab_cholesterol'),
list(old = 'LBXSIR', new = 'lab_iron')
)

# Column map for the glycohemoglobin file, in the same `old`/`new` shape as
# STANDARD_BIOCHEMISTRY_COLS: source column 'LBXGH' becomes 'lab_a1c'.
GLYCO_COLS = list(
list(old = 'LBXGH', new = 'lab_a1c')
)

# Read and combine the laboratory tables for each survey cycle.
#
# For every entry in `years` this lists
# ~/emoryhealthcare/data/<year>/Laboratory, picks the first file matching the
# standard biochemistry pattern and the first file matching the
# glycohemoglobin pattern, loads each through
# preprocessing_util$read_and_preprocess_data(), and merges the two on SEQN
# keeping rows that appear in either table (all = TRUE). The per-cycle tables
# are then row-bound.
#
# Args:
#   years: character vector of survey-cycle directory names, e.g. '2015-2016'.
#
# Returns:
#   The rbindlist() of the per-cycle merged tables; columns are matched by name
#   and columns absent from a cycle are filled with NA (fill = TRUE).
#
# Raises:
#   An error naming the directory and pattern when a cycle directory contains
#   no file matching one of the patterns.
read_and_preprocess_data = function(years) {
  # File names differ between cycles, so each table is located by pattern.
  load_table = function(base_dir, files, pattern, cols) {
    matches = files[str_detect(files, pattern)]
    if (length(matches) == 0) {
      # Without this guard the NA produced by `[1]` was interpolated into the
      # path ('<base_dir>/NA') and the failure surfaced later, far from its cause.
      stop("No file matching '", pattern, "' found in ", base_dir, call. = FALSE)
    }
    file_name = matches[1]
    path = str_interp('${base_dir}/${file_name}')
    preprocessing_util$read_and_preprocess_data(path, cols)
  }

  data = lapply(years, function(year) {
    base_dir = str_interp('~/emoryhealthcare/data/${year}/Laboratory')
    files = list.files(base_dir)

    bio = load_table(
      base_dir, files,
      'BIOPRO_\\w{1}.csv|L40_\\w{1}.csv|LAB18.csv',
      STANDARD_BIOCHEMISTRY_COLS
    )
    glyco = load_table(
      base_dir, files,
      'GHB_\\w{1}.csv|L10_\\w{1}.csv|LAB10.csv',
      GLYCO_COLS
    )

    merge(bio, glyco, all = TRUE, by = 'SEQN')
  })

  data = rbindlist(data, use.names = TRUE, fill = TRUE)

  return (data)
}

# Example invocation. Kept commented out because this module is loaded with
# import('preprocessing/laboratory'); a live call here would read every survey
# cycle from disk on each import and then discard the result.
# read_and_preprocess_data( years = c(
# '2015-2016', '2013-2014', '2011-2012', '2009-2010',
# '2007-2008', '2005-2006', '2003-2004', '2001-2002', '1999-2000'
# ))
5 changes: 4 additions & 1 deletion preprocessing/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ demo_preprocessing = import('preprocessing/demographics')
questionnaire_preprocessing = import('preprocessing/questionnaire')
examination_preprocessing = import('preprocessing/examination')
dietary_preprocessing = import('preprocessing/dietary')
lab_preprocessing = import('preprocessing/laboratory')

# Build the combined table for the given survey cycles.
#
# Loads the demographics, questionnaire, examination and laboratory tables via
# their preprocessing modules and joins them one after another on the SEQN
# column.
#
# Args:
#   years: survey cycles, forwarded unchanged to each module's
#     read_and_preprocess_data(years = ...).
#
# Returns:
#   The merged table. merge() is called with its default all = FALSE, so only
#   SEQN values present in every input table are kept.
read_and_preprocess_data = function(years) {
  demographics = demo_preprocessing$read_and_preprocess_data(years = years)
  questionnaire = questionnaire_preprocessing$read_and_preprocess_data(years = years)
  examination = examination_preprocessing$read_and_preprocess_data(years = years)
  #dietary = dietary_preprocessing$read_and_preprocess_data(years = years)
  lab = lab_preprocessing$read_and_preprocess_data(years = years)

  # The chain must stay one unbroken pipe: a step without a trailing %>% ends
  # the expression and the tables after it never reach `data`.
  # `by` is the join-column argument of merge(); `on` is not one of its
  # arguments, so the key is named explicitly here.
  # NOTE(review): assumes SEQN is the only column the inputs share - confirm.
  data = demographics %>%
    merge(questionnaire, by = 'SEQN') %>%
    merge(examination, by = 'SEQN') %>%
    merge(lab, by = 'SEQN')

  return(data)
}
33 changes: 26 additions & 7 deletions preprocessing/questionnaire.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,26 @@ HOSPITAL_UTILIZATION_COLS = list(
list(old = 'HUD070', new = 'questionnaire_overnight_hospital_patient_in_last_year')
)

# Column map for the smoking questionnaire file: source column 'SMQ040' is
# exposed as 'smoke_cigarettes'. Same `old`/`new` entry shape as the other
# *_COLS lists in this file.
SMOKING_COLS = list(
list(old = 'SMQ040', new = 'smoke_cigarettes')
)

######
# Recoding
######
# Recode the raw overnight-hospital-patient answer into a Yes/No factor.
#
# Args:
#   code: vector of raw answer codes.
#
# Returns:
#   A factor the same length as `code` with levels c('Yes', 'No'): 'Yes' where
#   code is 1, 'No' for code 2 and for every other value, including NA.
recode_hospital_patient_last_year = function(code) {
  # Start from the catch-all 'No' and flip only the explicit 'Yes' code.
  # %in% never yields NA, so missing answers stay 'No'.
  x = rep('No', length(code))
  x[code %in% 1] = 'Yes'
  return (factor(x, levels = c('Yes', 'No')))
}

# Recode the raw smoking answer into a Yes/No factor.
#
# Args:
#   code: vector of raw answer codes.
#
# Returns:
#   A factor the same length as `code` with levels c('Yes', 'No'): 'Yes' where
#   code is 1, 'No' for every other value, including NA.
recode_smoking = function(code) {
  # FALSE/TRUE + 1L indexes into c('No', 'Yes'); %in% is FALSE for NA.
  answer = c('No', 'Yes')[(code %in% 1) + 1L]
  factor(answer, levels = c('Yes', 'No'))
}
Expand All @@ -39,9 +52,10 @@ preprocess_med_category = function(med, med_info, hospital_util) {
med_[, drug_category := paste0('drug_category_', drug_category)]
med_ = med_[, .(drug_count = .N), by = .(SEQN, drug_category)]
med_[, total_drug_count := sum(drug_count), by = .(SEQN)]
med_[, drug_count := ifelse(drug_count > 0, 1, 0)]

seqn_no_meds = setdiff(hospital_util$SEQN, med_$SEQN)
no_meds = data.table(seqn = seqn_no_meds, drug_category=NA, drug_count=0, total_drug_count=0)
no_meds = data.table(SEQN = seqn_no_meds, drug_category=NA, drug_count=0, total_drug_count=0)
med_ = rbindlist(list(med_, no_meds))

med_wide = dcast(med_, SEQN ~ drug_category, fill=0, value.var = 'drug_count')
Expand Down Expand Up @@ -105,9 +119,15 @@ read_and_preprocess_data = function(years, use_drug_category = T) {
path = str_interp('${base_dir}/${file_name}')
med_info = fread(path)
med = preprocess_meds(med, med_info, hospital_util, use_drug_category = use_drug_category)

file_name = files[str_detect(files, 'SMQ_*')][1]
path = str_interp('${base_dir}/${file_name}')
smoking = preprocessing_util$read_and_preprocess_data(path, SMOKING_COLS)
smoking[, smoke_cigarettes := recode_smoking(smoke_cigarettes)]

data_ = hospital_util %>%
merge(med, all = T, by = 'SEQN')
merge(med, all = T, by = 'SEQN') %>%
merge(smoking, all = T, by = 'SEQN')

data = append(data, list(data_))
}
Expand All @@ -128,9 +148,8 @@ read_and_preprocess_data = function(years, use_drug_category = T) {
# med_info = fread('~/emoryhealthcare/data/2015-2016/Questionnaire/RXQ_DRUG.csv')
# hospital_util = fread('~/emoryhealthcare/data/2015-2016/Questionnaire/HUQ_I.csv')
# preprocess_meds(med, med_info, hospital_util, use_drug_category = F)
# read_and_preprocess_data(years = c(
# '2015-2016', '2013-2014', '2011-2012', '2009-2010',
# '2007-2008', '2005-2006', '2003-2004', '2001-2002', '1999-2000'
# read_and_preprocess_data( years = c(
# '2015-2016'
# ))

get_data_for_clustering = function() {
Expand All @@ -153,4 +172,4 @@ get_data_for_clustering = function() {
fwrite(data, 'data/data_questionnaire_for_clustering.csv')
}

get_data_for_clustering()
# get_data_for_clustering()
59 changes: 48 additions & 11 deletions risk_factor_clf.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ library(magrittr)
library(ggplot2)
library(dplyr)
library(caret)
library(gridExtra)

preprocessing = import('preprocessing/preprocessing')
metrics = import('metrics')
Expand Down Expand Up @@ -97,8 +98,9 @@ resamps = resamples(
summary(resamps)

# Plot cross validation metrics
ggplot(model_xgbTree_smote) + labs(title = 'xgboost')
ggplot(model_lasso_smote) + labs(title = 'Lasso logistic regression')
g1 = ggplot(model_xgbTree_smote) + labs(title = 'xgboost')
g2 = ggplot(model_lasso_smote) + labs(title = 'Lasso logistic regression')
grid.arrange(g1, g2, nrow=1)

## Test performance
ROCR::performance(
Expand Down Expand Up @@ -166,12 +168,47 @@ ggplot(
labs(title = 'lasso', x = 'variable')

# Looking at individual predictions
pred = predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]]

high = sort(pred, decreasing = T, index.return=T)$ix[1]
high_obs = test[high]
high_pred = pred[high]

low = sort(pred, decreasing = T, index.return=T)$ix[length(pred)]
low_obs = test[low]
low_pred = pred[low]
# pred = predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]]
#
# high = sort(pred, decreasing = T, index.return=T)$ix[1]
# high_obs = test[high]
# high_pred = pred[high]
#
# low = sort(pred, decreasing = T, index.return=T)$ix[length(pred)]
# low_obs = test[low]
# low_pred = pred[low]

# LIME local approximation
# X = train[, .SD, .SDcols=names(train)[names(train) != "questionnaire_overnight_hospital_patient_in_last_year"]]
#
# pred = predict(model_xgbTree_smote, type="prob")[,1]
# test_idx = c(which.min(pred), which.max(pred))
# X_train = X[-test_idx]
# X_test = X[test_idx]
#
# explainer = lime::lime(X_train, model_xgbTree_smote)
# explanation = lime::explain(X_test, explainer, labels="Yes", n_features=10, feature_select = "lasso_path")
#
# g = lime::plot_features(explanation, ncol=1)

# Exploratory plots of individual columns of `train` against
# questionnaire_overnight_hospital_patient_in_last_year.
# Every statement assigns to the same `g`, so each one overwrites the previous
# plot object and nothing is drawn until `g` is evaluated or printed.

# Boxplots of total drug count, household income and age per answer group.
g = ggplot(train, aes(questionnaire_overnight_hospital_patient_in_last_year, total_drug_count)) + geom_boxplot()
g = ggplot(train, aes(questionnaire_overnight_hospital_patient_in_last_year, demographics_householdIncome)) + geom_boxplot()
g = ggplot(train, aes(questionnaire_overnight_hospital_patient_in_last_year, demographics_age)) + geom_boxplot()
# Fraction of rows with demographics_gender == "Female" per answer group.
g = ggplot(train[,
.(female_frac = mean(demographics_gender == "Female")),
by = .(questionnaire_overnight_hospital_patient_in_last_year)],
aes(questionnaire_overnight_hospital_patient_in_last_year, female_frac)) +
geom_bar(stat = 'identity') +
labs(title = 'Fraction female', y = 'Fraction')
# Same fraction, faceted by age bin. copy() keeps the := that adds age_bin from
# modifying `train` in place; cut(demographics_age, 4) splits age into 4 intervals.
g = ggplot(copy(train)[, age_bin := cut(demographics_age, 4)][,
.(female_frac = mean(demographics_gender == "Female")),
by = .(questionnaire_overnight_hospital_patient_in_last_year, age_bin)],
aes(questionnaire_overnight_hospital_patient_in_last_year, female_frac)) +
geom_bar(stat = 'identity') +
facet_wrap(~age_bin) +
labs(title = 'Fraction female by age', y = 'Fraction')
# g = ggplot(copy(train)[, `:=`(age_bin = cut(demographics_age, 4), income_bin = cut(demographics_householdIncome, 4))],
# aes(questionnaire_overnight_hospital_patient_in_last_year, examination_diastolic_blood_pressure)) +
# geom_boxplot() +
# facet_grid(age_bin~income_bin) +
# labs(title = 'Diastolic blood pressure by age')
Loading

0 comments on commit 7a8c2ea

Please sign in to comment.