Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Amster authored and Adam Amster committed Mar 9, 2019
0 parents commit 66f095c
Show file tree
Hide file tree
Showing 10 changed files with 615 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data
*.Rproj
*.RData
167 changes: 167 additions & 0 deletions emoryhealthcare-project.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
library(modules)
library(data.table)
library(magrittr)
library(ggplot2)
library(dplyr)
library(caret)

preprocessing = import('preprocessing/preprocessing')
metrics = import('metrics')

# Seed the RNG first so the row shuffle and the train/test partition below
# are reproducible from run to run.
set.seed(1234)

# Load and preprocess every two-year survey cycle listed here.
data <- preprocessing$read_and_preprocess_data(
  years = c(
    "2015-2016", "2013-2014", "2011-2012", "2009-2010",
    "2007-2008", "2005-2006", "2003-2004", "2001-2002", "1999-2000"
  )
)

# Notes on the demographic variables:
# - DMDEDUC2 and DMDMARTL only apply to respondents aged 20 and over.
# - RIDRETH3 includes a Non-Hispanic Asian code that is not available before
#   2011 (originally noted here as code 5 -- TODO confirm against the codebook).
# - Open question: exclude respondents under 20, since they tend to be healthy
#   and some variables are not populated for them?

# Shuffle the rows, then hold out 20% as a test set. The partition is drawn on
# the outcome column; `list = FALSE` makes caret return row indices as a matrix
# rather than a list.
data <- data[sample(nrow(data))]
train_idx <- createDataPartition(
  data$questionnaire_overnight_hospital_patient_in_last_year,
  p = .8,
  list = FALSE
)

train <- data[train_idx]
test <- data[-train_idx]

# train_upsampled = upSample(
# x = train[, .SD, .SDcols = !names(train) == 'questionnaire_overnight_hospital_patient_in_last_year'],
# y = train$questionnaire_overnight_hospital_patient_in_last_year,
# yname = 'questionnaire_overnight_hospital_patient_in_last_year'
# )

# Resampling setup shared by every model below: 5-fold cross-validation with
# class probabilities enabled, scored with caret's two-class summary, and a
# random search over tuning parameters.
fit_control <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  search = "random"
)

# model_xgbTree = train(
# questionnaire_overnight_hospital_patient_in_last_year ~ .,
# data = train,
# method = "xgbTree",
# metric = 'ROC',
# trControl = fit_control
# )
# model_lasso = train(
# questionnaire_overnight_hospital_patient_in_last_year ~ .,
# data = train,
# method = "glmnet",
# family = "binomial",
# metric = 'ROC',
# trControl = fit_control
# )

# Turn on caret's "smote" sampling option for every model trained below.
fit_control$sampling <- "smote"

# Boosted trees, with tuning parameters selected on the ROC metric.
model_xgbTree_smote <- train(
  questionnaire_overnight_hospital_patient_in_last_year ~ .,
  data = train,
  method = "xgbTree",
  metric = "ROC",
  trControl = fit_control
)

# Penalised logistic regression via glmnet, selected on the same metric.
model_lasso_smote <- train(
  questionnaire_overnight_hospital_patient_in_last_year ~ .,
  data = train,
  method = "glmnet",
  family = "binomial",
  metric = "ROC",
  trControl = fit_control
)

# Collect the resampling results of both models so they can be compared.
resamps <- resamples(
  list(
    xgbTree_smote = model_xgbTree_smote,
    lasso_smote = model_lasso_smote
    # xgbTree = model_xgbTree,
    # logReg = model_logReg
  )
)
summary(resamps)

# Plot cross validation metrics
ggplot(model_xgbTree_smote) +
  labs(title = "xgboost")
ggplot(model_lasso_smote) +
  labs(title = "Lasso logistic regression")

## Test performance

# Area under the ROC curve of a fitted caret model on the held-out `test` set.
#
# model: a fitted caret `train` object with class probabilities enabled.
# Returns a single numeric AUC value.
test_auc <- function(model) {
  observed <- test$questionnaire_overnight_hospital_patient_in_last_year

  # Score each row with the first probability column. caret is expected to
  # order these columns by the outcome's factor levels, so this is the
  # probability of levels(observed)[1] -- TODO confirm against caret docs.
  first_level_prob <- predict(model, type = "prob", newdata = test)[[1]]

  scored <- ROCR::prediction(
    first_level_prob,
    observed,
    # label.ordering is c(negative, positive). Without it ROCR picks the
    # positive class itself, which may not be the class the score column
    # refers to and would then report 1 - AUC. Reversing the levels makes the
    # first level the positive class, matching the score column.
    label.ordering = rev(levels(observed))
  )

  # `y.values` is a list with one entry per prediction run; unwrap the number.
  ROCR::performance(scored, "auc")@y.values[[1]]
}

test_auc(model_xgbTree_smote)

test_auc(model_lasso_smote)


# Observed outcome and hard class predictions on the test set, computed once
# and reused by the diagnostics below.
test_outcome <- test$questionnaire_overnight_hospital_patient_in_last_year
xgb_test_class <- predict(model_xgbTree_smote, newdata = test)
lasso_test_class <- predict(model_lasso_smote, newdata = test)

# Lift plot: compares both models using the first class-probability column.
lift_results <- data.frame(
  questionnaire_overnight_hospital_patient_in_last_year = test_outcome,
  xgbTree_smote = predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]],
  lasso_smote = predict(model_lasso_smote, type = "prob", newdata = test)[[1]]
)
lift_obj <- lift(
  questionnaire_overnight_hospital_patient_in_last_year ~ xgbTree_smote + lasso_smote,
  data = lift_results
)
ggplot(lift_obj)

# Confusion matrix: one fourfold display per model.
xgb_confusion <- confusionMatrix(xgb_test_class, test_outcome)
fourfoldplot(xgb_confusion$table)

lasso_confusion <- confusionMatrix(lasso_test_class, test_outcome)
fourfoldplot(lasso_confusion$table)

# Precision
precision(xgb_test_class, test_outcome)
precision(lasso_test_class, test_outcome)

# Variable importances

# Horizontal bar chart of the `n` most important variables.
#
# imp:   importance table as returned in `varImp(model)$importance`; variable
#        names are its row names and the scores are in column `Overall`.
# title: plot title.
# n:     maximum number of variables to show.
# Returns the ggplot object (auto-printed when called at top level).
plot_top_importance <- function(imp, title, n = 10) {
  # Always rank explicitly so the result does not depend on the order in which
  # a particular model happens to report its variables.
  ranked <- data.table(
    variable = rownames(imp),
    importance = imp$Overall
  )[order(-importance)]

  # head() rather than [1:n]: indexing past the end would add all-NA rows when
  # the model has fewer than `n` variables.
  ggplot(head(ranked, n), aes(reorder(variable, importance), importance)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    labs(title = title, x = "variable")
}

var_imp <- varImp(model_xgbTree_smote)$importance
plot_top_importance(var_imp, "xgboost")

var_imp <- varImp(model_lasso_smote)$importance
plot_top_importance(var_imp, "lasso")

# Looking at individual predictions: pull out the test rows with the highest
# and the lowest predicted probability (first class column) from xgboost.
pred <- predict(model_xgbTree_smote, type = "prob", newdata = test)[[1]]

# Row positions ordered from highest to lowest predicted probability.
ranking <- sort(pred, decreasing = TRUE, index.return = TRUE)$ix

high <- ranking[1]
high_obs <- test[high]
high_pred <- pred[high]

low <- ranking[length(pred)]
low_obs <- test[low]
low_pred <- pred[low]
4 changes: 4 additions & 0 deletions metrics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# F1 score in the (data, lev, model) shape used for caret summary functions.
#
# data:  data frame with predicted classes in `pred` and observed classes in
#        `obs`.
# lev:   class levels; the first one is treated as the positive class.
# model: unused; part of the expected signature.
# Returns a named numeric vector c(F1 = <score>).
f1 <- function(data, lev = NULL, model = NULL) {
  # Nothing in this file loads the package that provides F1_Score, so call it
  # through its namespace instead of relying on the caller's search path.
  f1_val <- MLmetrics::F1_Score(
    y_pred = data$pred,
    y_true = data$obs,
    positive = lev[1]
  )
  c(F1 = f1_val)
}
123 changes: 123 additions & 0 deletions preprocessing/demographics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import_package('data.table', attach = T)
library(dplyr)
library(stringr)

preprocessing_util = import('preprocessing/util')

# Source column (`old`) -> output column (`new`) pairs handed to the shared
# preprocessing helper. Both race codings map to the same output column;
# presumably only one of them is present in any given cycle -- confirm against
# preprocessing/util.
DEMOGRAPHICS_COLS <- list(
  list(old = "RIAGENDR", new = "demographics_gender"),
  list(old = "RIDAGEYR", new = "demographics_age"),
  list(old = "RIDRETH3", new = "demographics_race"),
  list(old = "RIDRETH1", new = "demographics_race"),
  # list(old = "DMDEDUC2", new = "demographics_education"),
  # list(old = "DMDMARTL", new = "demographics_maritalStatus"),
  list(old = "INDHHIN2", new = "demographics_householdIncome")
)

######
# Recoding
######
# Translate numeric race/ethnicity codes into a factor of labels.
# Codes 5 and 7 share one label; any other value (including NA) becomes NA.
recode_race <- function(code) {
  # Position i holds the label for code i.
  race_labels <- c(
    "Mexican American",
    "Other Hispanic",
    "Non-Hispanic White",
    "Non-Hispanic Black",
    "Other Race - Including Multi-Racial",
    "Non-Hispanic Asian",
    "Other Race - Including Multi-Racial"
  )
  factor(race_labels[match(code, 1:7)])
}

# Translate numeric gender codes (1, 2) into a factor of labels.
# Any other value (including NA) becomes NA.
recode_gender <- function(code) {
  gender_labels <- c("Male", "Female")
  factor(gender_labels[match(code, 1:2)])
}

# Translate numeric education codes (1-5) into a factor of labels.
# Any other value (including NA) becomes NA.
#
# The function name keeps its original spelling so existing callers still work.
recode_eduction <- function(code) {
  # Position i holds the label for code i. Labels carry no leading whitespace,
  # so the factor levels are clean and sort consistently.
  education_labels <- c(
    "Less than 9th grade",
    "9-11th grade (Includes 12th grade with no diploma)",
    "High school graduate/GED or equivalent",
    "Some college or AA degree",
    "College graduate or above"
  )
  factor(education_labels[match(code, 1:5)])
}

# Translate numeric marital-status codes (1-6) into a factor of labels.
# Any other value (including NA) becomes NA.
recode_maritalStatus <- function(code) {
  # Position i holds the label for code i.
  marital_labels <- c(
    "Married",
    "Widowed",
    "Divorced",
    "Separated",
    "Never married",
    "Living with partner"
  )
  factor(marital_labels[match(code, 1:6)])
}

# Convert household-income bracket codes into dollar amounts.
#
# Codes 1-10 and 14 denote closed income brackets; each such element is
# replaced by a uniform random draw from its bracket. Code 15 is the open top
# bracket and maps to exactly 100000. NA and every other code become NA.
#
# code: vector of bracket codes.
# Returns a double vector of the same length as `code`, including for empty or
# all-NA input. Random draws are taken in input order, one per bracketed
# element, so results are reproducible under set.seed().
convert_householdIncome_to_numeric <- function(code) {
  bracket_code <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14)
  bracket_min <- c(
    0, 5000, 10000, 15000, 20000, 25000,
    35000, 45000, 55000, 65000, 75000
  )
  # Code 4 ends at 19999 like every other bracket, so it no longer overlaps
  # the start of code 5.
  bracket_max <- c(
    4999, 9999, 14999, 19999, 24999, 34999,
    44999, 54999, 64999, 74999, 99999
  )
  top_code <- 15
  top_value <- 100000

  income <- rep(NA_real_, length(code))

  # Which bracket (if any) each element falls into; NA for everything else.
  bracket <- match(code, bracket_code)
  bracketed <- which(!is.na(bracket))
  income[bracketed] <- runif(
    length(bracketed),
    bracket_min[bracket[bracketed]],
    bracket_max[bracket[bracketed]]
  )

  # which() drops NA comparisons, so missing codes stay NA.
  income[which(code == top_code)] <- top_value

  income
}

# Read and recode the demographics file of every requested survey cycle.
#
# years: character vector of cycle directory names (e.g. '2015-2016').
# Returns one data.table with the rows of all cycles stacked, restricted to
# respondents aged 20 and over. Stops with an error if a cycle directory
# contains no demographics file.
read_and_preprocess_data <- function(years) {
  # Load, recode and filter a single cycle.
  read_year <- function(year) {
    base_dir <- str_interp('~/emoryhealthcare-project/data/${year}/Demographics')
    files <- list.files(base_dir)

    # The pattern is a regular expression, so `_*` means "zero or more
    # underscores": any file name containing "DEMO" matches. The first match
    # is used.
    file <- files[str_detect(files, 'DEMO_*')][1]
    if (is.na(file)) {
      # Fail here with a clear message instead of passing a path ending in
      # "/NA" to the reader.
      stop("No demographics file found in ", base_dir, call. = FALSE)
    }

    path <- str_interp('${base_dir}/${file}')
    data_ <- preprocessing_util$read_and_preprocess_data(path, DEMOGRAPHICS_COLS)
    data_[, `:=`(
      demographics_race = recode_race(demographics_race),
      demographics_gender = recode_gender(demographics_gender),
      # demographics_education = recode_eduction(demographics_education),
      # demographics_maritalStatus = recode_maritalStatus(demographics_maritalStatus),
      demographics_householdIncome = convert_householdIncome_to_numeric(demographics_householdIncome)
    )]

    # Remove children
    data_[demographics_age >= 20]
  }

  # lapply keeps the cycles in the order given and avoids growing a list
  # inside a loop.
  rbindlist(lapply(years, read_year), use.names = TRUE)
}
59 changes: 59 additions & 0 deletions preprocessing/dietary.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import_package('data.table', attach = T)
import_package('dplyr', attach = T)
import_package('stringr', attach = T)

preprocessing_util = import('preprocessing/util')

# Source column -> output column mapping for the dietary totals file, written
# as `SOURCE = "output"` pairs.
DIETARY_COLUMN_NAMES <- c(
  DR1TKCAL = "dietary_kcal",
  DR1TPROT = "dietary_protein",
  DR1TCARB = "dietary_carbs",
  DR1TSUGR = "dietary_sugars",
  DR1TFIBE = "dietary_fiber",
  DR1TSFAT = "dietary_satfat",
  DR1TMFAT = "dietary_monounsatfat",
  DR1TPFAT = "dietary_polyunsatfat",
  DR1TCHOL = "dietary_cholesterol",
  DR1TATOC = "dietary_vitamine",
  DR1TRET = "dietary_retinol",
  DR1TVARA = "dietary_vitamina",
  DR1TLYCO = "dietary_lycopene",
  DR1TVB1 = "dietary_vitaminb1",
  DR1TVB2 = "dietary_vitaminb2",
  DR1TNIAC = "dietary_niacin",
  DR1TVB6 = "dietary_vitaminb6",
  DR1TFOLA = "dietary_folate",
  DR1TCHL = "dietary_choline",
  DR1TVB12 = "dietary_b12",
  DR1TVC = "dietary_vitaminc",
  DR1TVD = "dietary_vitamind",
  DR1TVK = "dietary_vitamink",
  DR1TCALC = "dietary_calcium",
  DR1TPHOS = "dietary_phosphorous",
  DR1TMAGN = "dietary_magnesium",
  DR1TIRON = "dietary_iron",
  DR1TZINC = "dietary_zinc",
  DR1TSODI = "dietary_sodium",
  DR1TPOTA = "dietary_potassium",
  DR1TCAFF = "dietary_caffeine",
  DR1TALCO = "dietary_alcohol",
  DR1_320Z = "dietary_water"
)

# The same mapping in the list(old = ..., new = ...) shape handed to the
# shared preprocessing helper: an unnamed list with one entry per column.
COLS <- unname(Map(
  function(old, new) list(old = old, new = new),
  names(DIETARY_COLUMN_NAMES),
  unname(DIETARY_COLUMN_NAMES)
))

# Read the dietary totals file of every requested survey cycle.
#
# years: character vector of cycle directory names (e.g. '2015-2016').
# Returns one data.table with the rows of all cycles stacked. Stops with an
# error if a cycle directory contains no dietary totals file.
read_and_preprocess_data <- function(years) {
  # Load a single cycle.
  read_year <- function(year) {
    base_dir <- str_interp('~/emoryhealthcare-project/data/${year}/Dietary')
    files <- list.files(base_dir)

    # The pattern is a regular expression, so `_*` means "zero or more
    # underscores": any file name containing "DR1TOT" matches. The first match
    # is used.
    file_name <- files[str_detect(files, 'DR1TOT_*')][1]
    if (is.na(file_name)) {
      # Fail here with a clear message instead of passing a path ending in
      # "/NA" to the reader.
      stop("No dietary totals file found in ", base_dir, call. = FALSE)
    }

    path <- str_interp('${base_dir}/${file_name}')
    preprocessing_util$read_and_preprocess_data(path, COLS)
  }

  # lapply keeps the cycles in the order given and avoids growing a list
  # inside a loop.
  rbindlist(lapply(years, read_year), use.names = TRUE)
}
Loading

0 comments on commit 66f095c

Please sign in to comment.