Skip to content

Commit

Permalink
Wp2 (#88)
Browse files Browse the repository at this point in the history
* updates

* README

* README

* README
  • Loading branch information
PotosnakW committed Jun 28, 2022
1 parent 3ed86ce commit fc9087e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 50 deletions.
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ Phenotyping and Knowledge Discovery

`auton_survival.phenotyping` allows extraction of latent clusters or subgroups
of patients that demonstrate similar outcomes. In the context of this package,
we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows:
we refer to this task as **phenotyping**. `auton_survival.phenotyping` provides
the following phenotyping utilities:

- **Intersectional Phenotyping**: Recovers groups, or phenotypes, of individuals
over exhaustive combinations of user-specified categorical and numerical features.
Expand Down Expand Up @@ -226,6 +227,8 @@ response to a specific intervention. Relies on the specially designed
`auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects` latent variable model.

```python
from auton_survival.models.cmhe DeepCoxMixturesHeterogenousEffects

# Instantiate the CMHE model
model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers)

Expand All @@ -248,6 +251,13 @@ model = SurvivalVirtualTwins(horizon=365)
phenotypes = model.fit_predict(features, outcomes.time, outcomes.event, interventions)
```

DAG representations of the unsupervised, supervised, and counterfactual probabilitic
phenotypers in auton-survival are shown in the below figure. *X* represents the
covariates, *T* the time-to-event and *Z* is the phenotype to be inferred.

<p align="center"><img src="https://ndownloader.figshare.com/files/36056648" width=60%></p>


<a id="evaluation"></a>

Evaluation and Reporting
Expand Down Expand Up @@ -277,9 +287,14 @@ score = survival_regression_metric(metric='brs', outcomes_train,
```

- **Treatment Effect**: Used to compare treatment arms by computing the difference in the following metrics for treatment and control groups:
- **Time at Risk** (TaR)
- **Risk at Time**
- **Restricted Mean Survival Time** (RMST)
- **Time at Risk (TaR)** (left)
- **Risk at Time** (center)
- **Restricted Mean Survival Time (RMST)** (right)

<p align="center">
<img src="https://ndownloader.figshare.com/files/36056507" width=30%>
<img src="https://ndownloader.figshare.com/files/36056534" width=30%>
<img src="https://ndownloader.figshare.com/files/36056546" width=30%></p>

```python
from auton_survival.metrics import survival_diff_metric
Expand Down
41 changes: 21 additions & 20 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@

class SurvivalRegressionCV:
"""Universal interface to train Survival Analysis models in a cross-
validation or nested cross-validation fashion.
validation fashion.
Each of the model is trained in a CV fashion over the user specified
hyperparameter grid. The best model(s) in terms of user-specified metric
is selected.
The model is trained in a CV fashion over the user-specified
hyperparameter grid. Model hyperparameters are selected based on the
user-specified metric.
Parameters
-----------
Expand All @@ -65,9 +65,6 @@ class SurvivalRegressionCV:
num_folds : int, default=5
The number of folds.
Ignored if folds is specified.
num_nested_folds : int, default=None
The number of folds to use for nested cross-validation.
If None, then regular (unnested) CV is performed.
random_seed : int, default=0
Controls reproducibility of results.
hyperparam_grid : dict
Expand All @@ -92,12 +89,11 @@ class SurvivalRegressionCV:
"""

def __init__(self, model='dcph', folds=None, num_folds=5,
num_nested_folds=None, random_seed=0, hyperparam_grid={}):
random_seed=0, hyperparam_grid={}):

self.model = model
self.folds = folds
self.num_folds = num_folds
self.num_nested_folds = num_nested_folds
self.random_seed = random_seed
self.hyperparam_grid = list(ParameterGrid(hyperparam_grid))

Expand All @@ -116,7 +112,7 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
horizon : int or float or list
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Expand All @@ -125,12 +121,12 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
Returns
-----------
Trained survival regression model(s).
"""


assert horizons is not None, "Horizons must be specified."
if isinstance(horizons, (int, float)):
Expand All @@ -156,10 +152,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
assert max(horizons) < time_max, "Horizons exceeds max time range."
assert min(horizons) > time_min, "Horizons exceeds min time range."

# if self.horizon is None:
# assert (self.metric == 'ibs'), "Horizon must be specified for the selected metric"
# self.horizon = time_max

hyper_param_scores = []
for i, hyper_param in enumerate(self.hyperparam_grid):
print("At hyper-param", hyper_param)
Expand Down Expand Up @@ -189,7 +181,6 @@ def fit(self, features, outcomes, horizons, metric='ibs'):
**best_hyper_param).fit(features, outcomes)
return model


def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed):

"""Get cross-validation fold value for each sample.
Expand Down Expand Up @@ -288,7 +279,6 @@ class CounterfactualSurvivalRegressionCV:
model : str
A string that determines the choice of the surival analysis model.
Survival model choices include:
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
Expand Down Expand Up @@ -341,10 +331,10 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)

def fit(self, features, outcomes, interventions, metric):
def fit(self, features, outcomes, interventions, horizons, metric):

r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
r"""Fits the Survival Regression Model to the data in a cross-
validation fashion.
Parameters
-----------
Expand All @@ -359,6 +349,15 @@ def fit(self, features, outcomes, interventions, metric):
interventions: pandas.Series
A pandas series containing the treatment status of each subject.
\( a_i = 1 \) if the subject is `treated`, else is considered control.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
Returns
-----------
Expand All @@ -369,9 +368,11 @@ def fit(self, features, outcomes, interventions, metric):

treated_model = self.treated_experiment.fit(features.loc[interventions==1],
outcomes.loc[interventions==1],
horizons=horizons,
metric=metric)
control_model = self.control_experiment.fit(features.loc[interventions!=1],
outcomes.loc[interventions!=1],
horizons=horizons,
metric=metric)

return CounterfactualSurvivalModel(treated_model, control_model)
15 changes: 7 additions & 8 deletions auton_survival/phenotyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,7 @@ def __init__(self,

self.random_seed = random_seed

def fit(self, features, outcomes, interventions, metric,
horizon):
def fit(self, features, outcomes, interventions, horizons, metric):

"""Fit a counterfactual model and regress the difference of the estimated
counterfactual Restricted Mean Survival Time using a Random Forest regressor.
Expand All @@ -495,16 +494,15 @@ def fit(self, features, outcomes, interventions, metric,
interventions : np.array
Boolean numpy array of treatment indicators. True means individual
was assigned a specific treatment.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
horizon : np.float
The event horizon at which to compute the counterfacutal RMST for
regression.
Returns
-----------
Expand All @@ -515,12 +513,13 @@ def fit(self, features, outcomes, interventions, metric,
cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method,
hyperparam_grid=self.cf_hyperparams)

self.cf_model = cf_model.fit(features, outcomes, interventions, metric)
self.cf_model = cf_model.fit(features, outcomes, interventions,
horizons, metric)

times = np.unique(outcomes.time.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features,
times.tolist())

horizon = max(horizons)
ite_estimates = cf_predictions[1] - cf_predictions[0]
ite_estimates = [estimate[times < horizon] for estimate in ite_estimates]
times = times[times < horizon]
Expand Down Expand Up @@ -558,7 +557,7 @@ def predict_proba(self, features):
"""

phenotype_preds= self.pheno_model.predict(features)
phenotype_preds = self.pheno_model.predict(features)
preds_surv_greater = (phenotype_preds - phenotype_preds.min()) / (phenotype_preds.max() - phenotype_preds.min())
preds_surv_less = 1 - preds_surv_greater
preds = np.array([[preds_surv_less[i], preds_surv_greater[i]]
Expand Down
21 changes: 3 additions & 18 deletions examples/CV Survival Regression on SUPPORT Dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"horizons = [0.25, 0.5, 0.75]\n",
"times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()"
"times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()"
]
},
{
Expand All @@ -67,7 +66,7 @@
" 'layers' : [[100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
"model = experiment.fit(x, outcomes, metric='ctd')"
"model = experiment.fit(x, outcomes, times, metric='brs')"
]
},
{
Expand All @@ -80,13 +79,6 @@
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -122,7 +114,7 @@
"for fold in set(experiment.folds):\n",
" print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], \n",
" out_survival[experiment.folds==fold], \n",
" times=times))\n"
" times=times))"
]
},
{
Expand All @@ -136,13 +128,6 @@
" print(time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit fc9087e

Please sign in to comment.