-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Adam Amster
authored and
Adam Amster
committed
Mar 9, 2019
0 parents
commit 66f095c
Showing
10 changed files
with
615 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
data | ||
*.Rproj | ||
*.RData |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
library(modules) | ||
library(data.table) | ||
library(magrittr) | ||
library(ggplot2) | ||
library(dplyr) | ||
library(caret) | ||
|
||
preprocessing = import('preprocessing/preprocessing') | ||
metrics = import('metrics') | ||
|
||
set.seed(1234) | ||
|
||
data = preprocessing$read_and_preprocess_data(years = c( | ||
'2015-2016', '2013-2014', '2011-2012', '2009-2010', | ||
'2007-2008', '2005-2006', '2003-2004', '2001-2002', '1999-2000' | ||
)) | ||
|
||
# Note, DMDEDUC2 and DMDMARTL only apply to those ages 20+ | ||
# Note, RIDRETH3 includes non-hispanic asian code 5 which is not included prior to 2011 | ||
# Thought, exclude those < 20 yrs old since they tend to be healthy and some variables are not given values for them? | ||
|
||
# create a testing and training set | ||
data = data[sample(nrow(data))] | ||
train_idx = createDataPartition( | ||
data$questionnaire_overnight_hospital_patient_in_last_year, | ||
p = .8, | ||
list = F | ||
) | ||
|
||
train = data[train_idx] | ||
test = data[-train_idx] | ||
|
||
# train_upsampled = upSample( | ||
# x = train[, .SD, .SDcols = !names(train) == 'questionnaire_overnight_hospital_patient_in_last_year'], | ||
# y = train$questionnaire_overnight_hospital_patient_in_last_year, | ||
# yname = 'questionnaire_overnight_hospital_patient_in_last_year' | ||
# ) | ||
|
||
fit_control = trainControl(method = "cv", | ||
number = 5, | ||
classProbs = T, | ||
summaryFunction = twoClassSummary, | ||
search = 'random' | ||
) | ||
|
||
# model_xgbTree = train( | ||
# questionnaire_overnight_hospital_patient_in_last_year ~ ., | ||
# data = train, | ||
# method = "xgbTree", | ||
# metric = 'ROC', | ||
# trControl = fit_control | ||
# ) | ||
# model_lasso = train( | ||
# questionnaire_overnight_hospital_patient_in_last_year ~ ., | ||
# data = train, | ||
# method = "glmnet", | ||
# family = "binomial", | ||
# metric = 'ROC', | ||
# trControl = fit_control | ||
# ) | ||
|
||
fit_control$sampling = "smote" | ||
|
||
model_xgbTree_smote = train( | ||
questionnaire_overnight_hospital_patient_in_last_year ~ ., | ||
data = train, | ||
method = "xgbTree", | ||
metric = 'ROC', | ||
trControl = fit_control | ||
) | ||
model_lasso_smote = train( | ||
questionnaire_overnight_hospital_patient_in_last_year ~ ., | ||
data = train, | ||
method = "glmnet", | ||
family = "binomial", | ||
metric = 'ROC', | ||
trControl = fit_control | ||
) | ||
resamps = resamples( | ||
list( | ||
xgbTree_smote = model_xgbTree_smote, | ||
lasso_smote = model_lasso_smote | ||
# xgbTree = model_xgbTree, | ||
# logReg = model_logReg | ||
) | ||
) | ||
summary(resamps) | ||
|
||
# Plot cross validation metrics | ||
ggplot(model_xgbTree_smote) + labs(title = 'xgboost') | ||
ggplot(model_lasso_smote) + labs(title = 'Lasso logistic regression') | ||
|
||
## Test performance | ||
ROCR::performance( | ||
ROCR::prediction(predict(model_xgbTree_smote, type = "prob", newdata = test)[1], test$questionnaire_overnight_hospital_patient_in_last_year), | ||
'auc' | ||
)@y.values[1] | ||
|
||
ROCR::performance( | ||
ROCR::prediction(predict(model_lasso_smote, type = "prob", newdata = test)[1], test$questionnaire_overnight_hospital_patient_in_last_year), | ||
'auc' | ||
)@y.values[1] | ||
|
||
|
||
# Lift plot | ||
lift_results <- data.frame( | ||
questionnaire_overnight_hospital_patient_in_last_year = test$questionnaire_overnight_hospital_patient_in_last_year, | ||
xgbTree_smote = predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]], | ||
lasso_smote = predict(model_lasso_smote, type = "prob", newdata = test)[[1]] | ||
) | ||
lift_obj = lift(questionnaire_overnight_hospital_patient_in_last_year ~ xgbTree_smote + lasso_smote, data = lift_results) | ||
ggplot(lift_obj) | ||
|
||
# Confusion matrix | ||
fourfoldplot( | ||
confusionMatrix( | ||
predict(model_xgbTree_smote, newdata = test), test$questionnaire_overnight_hospital_patient_in_last_year | ||
)$table | ||
) | ||
|
||
fourfoldplot( | ||
confusionMatrix( | ||
predict(model_lasso_smote, newdata = test), test$questionnaire_overnight_hospital_patient_in_last_year | ||
)$table | ||
) | ||
|
||
# Precision | ||
precision( | ||
predict(model_xgbTree_smote, newdata = test), test$questionnaire_overnight_hospital_patient_in_last_year | ||
) | ||
precision( | ||
predict(model_lasso_smote, newdata = test), test$questionnaire_overnight_hospital_patient_in_last_year | ||
) | ||
|
||
# Variable importances | ||
var_imp = varImp(model_xgbTree_smote)$importance | ||
ggplot( | ||
data.table( | ||
variable = rownames(var_imp), | ||
importance = var_imp$Overall | ||
)[1:10], aes(reorder(variable, importance), importance) | ||
) + | ||
geom_bar(stat = "identity") + | ||
coord_flip() + | ||
labs(title = 'xgboost', x = 'variable') | ||
|
||
var_imp = varImp(model_lasso_smote)$importance | ||
ggplot( | ||
data.table( | ||
variable = rownames(var_imp), | ||
importance = var_imp$Overall | ||
)[order(-importance)][1:10], aes(reorder(variable, importance), importance) | ||
) + | ||
geom_bar(stat = "identity") + | ||
coord_flip() + | ||
labs(title = 'lasso', x = 'variable') | ||
|
||
# Looking at individual predictions | ||
pred = predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]] | ||
|
||
high = sort(pred, decreasing = T, index.return=T)$ix[1] | ||
high_obs = test[high] | ||
high_pred = pred[high] | ||
|
||
low = sort(pred, decreasing = T, index.return=T)$ix[length(pred)] | ||
low_obs = test[low] | ||
low_pred = pred[low] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
f1 <- function(data, lev = NULL, model = NULL) { | ||
f1_val <- F1_Score(y_pred = data$pred, y_true = data$obs, positive = lev[1]) | ||
c(F1 = f1_val) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import_package('data.table', attach = T) | ||
library(dplyr) | ||
library(stringr) | ||
|
||
preprocessing_util = import('preprocessing/util') | ||
|
||
DEMOGRAPHICS_COLS = list( | ||
list(old = 'RIAGENDR', new = 'demographics_gender'), | ||
list(old = 'RIDAGEYR', new = 'demographics_age'), | ||
list(old = 'RIDRETH3', new = 'demographics_race'), | ||
list(old = 'RIDRETH1', new = 'demographics_race'), | ||
# list(old = 'DMDEDUC2', new = 'demographics_education'), | ||
# list(old = 'DMDMARTL', new = 'demographics_maritalStatus'), | ||
list(old = 'INDHHIN2', new = 'demographics_householdIncome') | ||
) | ||
|
||
###### | ||
# Recoding | ||
###### | ||
recode_race = function(code) { | ||
x = case_when( | ||
code == 1 ~ 'Mexican American', | ||
code == 2 ~ 'Other Hispanic', | ||
code == 3 ~ 'Non-Hispanic White', | ||
code == 4 ~ 'Non-Hispanic Black', | ||
code == 5 | code == 7 ~ 'Other Race - Including Multi-Racial', | ||
code == 6 ~ 'Non-Hispanic Asian' | ||
) | ||
return (factor(x)) | ||
} | ||
|
||
recode_gender = function(code) { | ||
x = case_when(code == 1 ~ 'Male', | ||
code == 2 ~ 'Female') | ||
return (factor(x)) | ||
} | ||
|
||
recode_eduction = function(code) { | ||
x = case_when( | ||
code == 1 ~ 'Less than 9th grade', | ||
code == 2 ~ ' 9-11th grade (Includes 12th grade with no diploma)', | ||
code == 3 ~ ' High school graduate/GED or equivalent', | ||
code == 4 ~ 'Some college or AA degree', | ||
code == 5 ~ 'College graduate or above' | ||
) | ||
return (factor(x)) | ||
} | ||
|
||
recode_maritalStatus = function(code) { | ||
x = case_when( | ||
code == 1 ~ 'Married', | ||
code == 2 ~ 'Widowed', | ||
code == 3 ~ 'Divorced', | ||
code == 4 ~ 'Separated', | ||
code == 5 ~ 'Never married', | ||
code == 6 ~ 'Living with partner' | ||
) | ||
return (factor(x)) | ||
} | ||
|
||
convert_householdIncome_to_numeric = function(code) { | ||
to_numeric = function(code) { | ||
if (is.na(code)) { | ||
x = NA | ||
} else if (code == 1) { | ||
x = runif(1, 0, 4999) | ||
} else if (code == 2) { | ||
x = runif(1, 5000, 9999) | ||
} else if (code == 3) { | ||
x = runif(1, 10000, 14999) | ||
} else if (code == 4) { | ||
x = runif(1, 15000, 20000) | ||
} else if (code == 5) { | ||
x = runif(1, 20000, 24999) | ||
} else if (code == 6) { | ||
x = runif(1, 25000, 34999) | ||
} else if (code == 7) { | ||
x = runif(1, 35000, 44999) | ||
} else if (code == 8) { | ||
x = runif(1, 45000, 54999) | ||
} else if (code == 9) { | ||
x = runif(1, 55000, 64999) | ||
} else if (code == 10) { | ||
x = runif(1, 65000, 74999) | ||
} else if (code == 14) { | ||
x = runif(1, 75000, 99999) | ||
} else if (code == 15) { | ||
x = 100000 | ||
} else { | ||
x = NA | ||
} | ||
return (x) | ||
} | ||
|
||
return (sapply(code, to_numeric)) | ||
} | ||
|
||
read_and_preprocess_data = function(years) { | ||
data = list() | ||
|
||
for (year in years) { | ||
base_dir = str_interp('~/emoryhealthcare-project/data/${year}/Demographics') | ||
files = list.files(base_dir) | ||
file = files[str_detect(files, 'DEMO_*')][1] | ||
path = str_interp('${base_dir}/${file}') | ||
data_ = preprocessing_util$read_and_preprocess_data(path, DEMOGRAPHICS_COLS) | ||
data_[, `:=`( | ||
demographics_race = recode_race(demographics_race), | ||
demographics_gender = recode_gender(demographics_gender), | ||
# demographics_education = recode_eduction(demographics_education), | ||
# demographics_maritalStatus = recode_maritalStatus(demographics_maritalStatus), | ||
demographics_householdIncome = convert_householdIncome_to_numeric(demographics_householdIncome) | ||
)] | ||
|
||
# Remove children | ||
data_ = data_[demographics_age >= 20] | ||
|
||
data = append(data, list(data_)) | ||
} | ||
|
||
data = rbindlist(data, use.names = T) | ||
return (data) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import_package('data.table', attach = T) | ||
import_package('dplyr', attach = T) | ||
import_package('stringr', attach = T) | ||
|
||
preprocessing_util = import('preprocessing/util') | ||
|
||
COLS = list( | ||
list(old = 'DR1TKCAL', new = 'dietary_kcal'), | ||
list(old = 'DR1TPROT', new = 'dietary_protein'), | ||
list(old = 'DR1TCARB', new = 'dietary_carbs'), | ||
list(old = 'DR1TSUGR', new = 'dietary_sugars'), | ||
list(old = 'DR1TFIBE', new = 'dietary_fiber'), | ||
list(old = 'DR1TSFAT', new = 'dietary_satfat'), | ||
list(old = 'DR1TMFAT', new = 'dietary_monounsatfat'), | ||
list(old = 'DR1TPFAT', new = 'dietary_polyunsatfat'), | ||
list(old = 'DR1TCHOL', new = 'dietary_cholesterol'), | ||
list(old = 'DR1TATOC', new = 'dietary_vitamine'), | ||
list(old = 'DR1TRET', new = 'dietary_retinol'), | ||
list(old = 'DR1TVARA', new = 'dietary_vitamina'), | ||
list(old = 'DR1TLYCO', new = 'dietary_lycopene'), | ||
list(old = 'DR1TVB1', new = 'dietary_vitaminb1'), | ||
list(old = 'DR1TVB2', new = 'dietary_vitaminb2'), | ||
list(old = 'DR1TNIAC', new = 'dietary_niacin'), | ||
list(old = 'DR1TVB6', new = 'dietary_vitaminb6'), | ||
list(old = 'DR1TFOLA', new = 'dietary_folate'), | ||
list(old = 'DR1TCHL', new = 'dietary_choline'), | ||
list(old = 'DR1TVB12', new = 'dietary_b12'), | ||
list(old = 'DR1TVC', new = 'dietary_vitaminc'), | ||
list(old = 'DR1TVD', new = 'dietary_vitamind'), | ||
list(old = 'DR1TVK', new = 'dietary_vitamink'), | ||
list(old = 'DR1TCALC', new = 'dietary_calcium'), | ||
list(old = 'DR1TPHOS', new = 'dietary_phosphorous'), | ||
list(old = 'DR1TMAGN', new = 'dietary_magnesium'), | ||
list(old = 'DR1TIRON', new = 'dietary_iron'), | ||
list(old = 'DR1TZINC', new = 'dietary_zinc'), | ||
list(old = 'DR1TSODI', new = 'dietary_sodium'), | ||
list(old = 'DR1TPOTA', new = 'dietary_potassium'), | ||
list(old = 'DR1TCAFF', new = 'dietary_caffeine'), | ||
list(old = 'DR1TALCO', new = 'dietary_alcohol'), | ||
list(old = 'DR1_320Z', new = 'dietary_water') | ||
) | ||
|
||
read_and_preprocess_data = function(years) { | ||
data = list() | ||
|
||
for (year in years) { | ||
base_dir = str_interp('~/emoryhealthcare-project/data/${year}/Dietary') | ||
files = list.files(base_dir) | ||
|
||
file_name = files[str_detect(files, 'DR1TOT_*')][1] | ||
path = str_interp('${base_dir}/${file_name}') | ||
data_ = preprocessing_util$read_and_preprocess_data(path, COLS) | ||
|
||
data = append(data, list(data_)) | ||
} | ||
|
||
data = rbindlist(data, use.names = T) | ||
return (data) | ||
} |
Oops, something went wrong.