Skip to content

Commit

Permalink
Add function to return SHAP information.
Browse files Browse the repository at this point in the history
  • Loading branch information
jlevy44 committed Apr 25, 2019
1 parent 5216ef5 commit 775f8ed
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 15 deletions.
2 changes: 2 additions & 0 deletions example_scripts/GSE87571_age_estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ CUDA_VISIBLE_DEVICES=0 methylnet-interpret produce_shapley_data -mth gradient -s

Extract spreadsheet of top overall CpGs:
```
methylnet-interpret return_shap_values -c all -hist -s interpretations/shapley_explanations/shapley_binned.p -o interpretations/shap_results/ &
methylnet-interpret return_shap_values -c all -hist -abs -o interpretations/abs_shap_results/ -s interpretations/shapley_explanations/shapley_binned.p &
```

Expand Down
8 changes: 6 additions & 2 deletions example_scripts/GSE87571_cell_deconvolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ CUDA_VISIBLE_DEVICES=0 methylnet-interpret produce_shapley_data -mth gradient -s

Extract spreadsheet of top overall CpGs:
```
methylnet-interpret return_shap_values -c all -hist -o interpretations/shap_results/ &
methylnet-interpret return_shap_values -c all -hist -abs -o interpretations/abs_shap_results/ &
```

Plot bar chart of top CpGs:
Expand Down Expand Up @@ -121,7 +122,7 @@ MethylNet Commands:
* python model_interpretability.py shapley_jaccard -c all -i -s ./interpretations/shapley_explanations/shapley_data_by_methylation/hyper_shapley_data.p -o ./interpretations/shapley_explanations/top_cpgs_jaccard/hyper/ && python model_interpretability.py order_results_by_col -c Age -i ./interpretations/shapley_explanations/top_cpgs_jaccard/hyper/all_jaccard.csv -o ./interpretations/shapley_explanations/top_cpgs_jaccard/hyper/all_jaccard.sorted.csv
* pymethyl-visualize plot_heatmap -c -m similarity -fs .4 -i ./interpretations/shapley_explanations/top_cpgs_jaccard/hypo/all_jaccard.sorted.csv -o ./interpretations/shapley_explanations/top_cpgs_jaccard/hypo/all_hypo_jaccard.png
* pymethyl-visualize plot_heatmap -c -m similarity -fs .4 -i ./interpretations/shapley_explanations/top_cpgs_jaccard/hyper/all_jaccard.sorted.csv -o ./interpretations/shapley_explanations/top_cpgs_jaccard/hyper/all_hyper_jaccard.png
* python model_interpretability.py bin_regression_shaps -c Age -n 16
* python model_interpretability.py bin_regression_shaps -c Age -n 8
* python model_interpretability.py shapley_jaccard -c all -s ./interpretations/shapley_explanations/shapley_binned.p -o ./interpretations/shapley_explanations/top_cpgs_jaccard/ -ov
* python model_interpretability.py order_results_by_col -c Age -t null -i ./interpretations/shapley_explanations/top_cpgs_jaccard/all_jaccard.csv -o ./interpretations/shapley_explanations/top_cpgs_jaccard/all_jaccard.sorted.csv &
* pymethyl-utils counts -i train_val_test_sets/test_methyl_array_shap_binned.pkl -k Age_binned
Expand Down Expand Up @@ -168,6 +169,7 @@ python model_interpretability.py interpret_biology -ov -c all -s interpretations

# To-do: search for missing CpGs; do the same for the other studies
# check overlap with different blood types
REDO!!!
(13.92,22.0] top cpgs overlap with 0.0% of hannum cpgs
(22.0,30.0] top cpgs overlap with 0.0% of hannum cpgs
(30.0,38.0] top cpgs overlap with 1.45% of hannum cpgs
Expand All @@ -181,6 +183,8 @@ This cohort was around this age distribution...

Maybe look at the Horvath and epiTOC age distributions for this cohort.

REDO!!!!

(54.0,62.0] shared cpgs: 41/41.0
(70.0,78.0] shared cpgs: 55/55.0
(62.0,70.0] shared cpgs: 54/54.0
Expand Down
3 changes: 2 additions & 1 deletion example_scripts/TCGA.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ CUDA_VISIBLE_DEVICES=0 methylnet-interpret produce_shapley_data -mth gradient -s

Extract spreadsheet of top overall CpGs:
```
methylnet-interpret return_shap_values -c all -hist &
methylnet-interpret return_shap_values -c all -hist -abs -o interpretations/abs_shap_results/ &
```

Plot bar chart of top CpGs:
Expand Down
31 changes: 19 additions & 12 deletions methylnet/interpretation_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ def return_shapley_data_by_methylation_status(self, methyl_array):
methylation_shapley_data_dict['hypo'].top_cpgs['by_class'][class_name]['overall']=hypo_df[['cpg','shapley_value']]
return methylation_shapley_data_dict

def make_shap_scores_abs(self):
    """Replace every class's SHAP values with their absolute values, in place.

    Iterates over the classes reported by ``self.list_classes()`` and
    overwrites the matching entry of
    ``self.shapley_data.shapley_values['by_class']`` with the result of
    calling ``.abs()`` on it. Nothing is returned.
    """
    for label in self.list_classes():
        per_class = self.shapley_data.shapley_values['by_class']
        per_class[label] = per_class[label].abs()

def return_binned_shapley_data(self, original_class_name, outcome_col, add_top_negative=False):
"""Converts existing shap data based on continuous variable predictions into categorical variable.
Expand Down Expand Up @@ -271,7 +275,7 @@ def return_cpg_set(shap_df):
cpg_exclusion_sets[class_name]=set(reduce(lambda x,y:x.union(y),[cpg_set for class_name_query,cpg_set in cpg_sets.items() if class_name_query != class_name]))
return cpg_sets, cpg_exclusion_sets

def extract_class(self, class_name, class_intersect=False, get_shap_values=False):
    """Extract the top CpGs, or the mean SHAP values, for a class.

    Parameters
    ----------
    class_name
        Key of the class in the ``'by_class'`` dictionaries of the SHAP data.
    class_intersect : bool
        If True, keep only the CpGs present in the top-CpG table of every
        individual of the class and sum their SHAP values across individuals.
        Ignored when ``get_shap_values`` is True.
    get_shap_values : bool
        If True, return the column-wise mean of the class's stored SHAP
        values instead of a top-CpG table.

    Returns
    -------
    DataFrame or Series
        CpGs and SHAP values (columns ``cpg`` and ``shapley_value``), or the
        result of ``.mean(axis=0)`` on the class's SHAP values when
        ``get_shap_values`` is True.
    """
    # The flag is the ``get_shap_values`` parameter; the previous code tested
    # an undefined name here and raised NameError on every call.
    if get_shap_values:
        return self.shapley_data.shapley_values['by_class'][class_name].mean(axis=0)

    class_top_cpgs = self.shapley_data.top_cpgs['by_class'][class_name]
    if not class_intersect:
        return class_top_cpgs['overall']

    # Index each individual's table by CpG so the inner join keeps only the
    # CpGs shared by all individuals.
    shap_dfs = [
        individual_df.set_index('cpg')
        for individual_df in class_top_cpgs['by_individual'].values()
    ]
    # pd.concat needs the flat list of frames, not a list wrapped in a list.
    shared = pd.concat(shap_dfs, axis=1, join='inner')
    # Build a fresh two-column frame: assigning onto ``shared`` would hit the
    # duplicated 'shapley_value' column labels produced by the join.
    return pd.DataFrame(
        {
            'cpg': np.array(list(shared.index)),
            'shapley_value': shared.values.sum(axis=1),
        },
        index=shared.index,
    )[['cpg', 'shapley_value']]

def extract_individual(self, individual, get_shap_values=False):
    """Extract the top CpGs, or the stored SHAP values, for one individual.

    Parameters
    ----------
    individual
        Key of the individual in ``self.indiv2class``.
    get_shap_values : bool
        If True, return the individual's entry (``.loc[individual]``) from
        the SHAP values of its class instead of its top-CpG table.

    Returns
    -------
    tuple
        Class name of the individual, followed by either its top-CpG table
        or its SHAP values.
    """
    class_name = self.indiv2class[individual]
    if get_shap_values:
        values = self.shapley_data.shapley_values['by_class'][class_name].loc[individual]
    else:
        values = self.shapley_data.top_cpgs['by_class'][class_name]['by_individual'][individual]
    return class_name, values

def regenerate_individual_shap_values(self, n_top_cpgs, abs_val=False, neg_val=False):
"""Use original SHAP scores to make nested dictionary of top CpGs based on shapley score, can do this for ABS SHAP or Negative SHAP scores as well.
Expand Down
52 changes: 52 additions & 0 deletions methylnet/model_interpretability.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,58 @@ def regenerate_top_cpgs(shapley_data,n_top_features,output_pkl, abs_val, neg_val
shapley_data_explorer=ShapleyDataExplorer(shapley_data)
shapley_data_explorer.regenerate_individual_shap_values(n_top_cpgs=n_top_features,abs_val=abs_val, neg_val=neg_val).to_pickle(output_pkl)

@interpret.command()
@click.option('-s', '--shapley_data', default='./interpretations/shapley_explanations/shapley_data.p', help='Pickle containing top CpGs.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--output_dir', default='./interpretations/shap_outputs/', help='Output directory for output plots.', type=click.Path(exists=False), show_default=True)
@click.option('-i', '--individuals', default=[''], multiple=True, help='Individuals to evaluate.', show_default=True)
@click.option('-c', '--classes', default=[''], multiple=True, help='Classes to evaluate.', show_default=True)
@click.option('-hist', '--output_histogram', is_flag=True, help='Whether to output a histogram for each class/individual of their SHAP scores.')
@click.option('-abs', '--absolute', is_flag=True, help='Use sums of absolute values in making computations.')
def return_shap_values(shapley_data, output_dir, individuals, classes, output_histogram, absolute):
    """Return a matrix of SHAP values: classes/individuals are columns, CpGs are rows.

    The matrix is written to returned_shap_values.csv in the output
    directory. Optionally plots, for each class/individual, the distribution
    of its SHAP values and a bar chart of its ten largest |SHAP| CpGs.
    """
    os.makedirs(output_dir, exist_ok=True)
    shapley_data = ShapleyData.from_pickle(shapley_data)
    shapley_data_explorer = ShapleyDataExplorer(shapley_data)
    # The option defaults are [''], so drop empty entries to get real selections.
    individuals = list(filter(None, individuals))
    classes = list(filter(None, classes))
    if absolute:
        shapley_data_explorer.make_shap_scores_abs()
    # 'all' expands to every known class/individual. list() guards the
    # `classes + individuals` concatenation below in case the explorer
    # returns a non-list sequence.
    if classes and classes[0] == 'all':
        classes = list(shapley_data_explorer.list_classes())
    if individuals and individuals[0] == 'all':
        individuals = list(shapley_data_explorer.list_individuals(return_list=True))
    if not (classes or individuals):
        # pd.concat would otherwise fail with an opaque "No objects to concatenate".
        raise click.UsageError('Specify at least one class (-c) or individual (-i).')
    concat_list = [
        shapley_data_explorer.extract_class(class_name, get_shap_values=True)
        for class_name in classes
    ]
    # extract_individual returns (class_name, values); only the values belong
    # in the matrix.
    concat_list.extend(
        shapley_data_explorer.extract_individual(individual, get_shap_values=True)[1]
        for individual in individuals
    )
    df = pd.concat(concat_list, axis=1, keys=classes + individuals)
    df.to_csv(join(output_dir, 'returned_shap_values.csv'))
    if output_histogram:
        # Plotting libraries are imported lazily so CSV-only runs do not need them.
        import matplotlib.pyplot as plt
        import seaborn as sns
        sns.set()
        for entity in classes + individuals:
            shap_scores = df[entity]
            plt.figure()
            sns.distplot(shap_scores)
            plt.xlabel('SHAP value')
            plt.ylabel('Frequency')
            plt.savefig(join(output_dir, '{}_shap_values.png'.format(entity)), dpi=300)
            # Close each figure once saved; otherwise they pile up in memory
            # across the loop.
            plt.close()
            top_10_shaps = shap_scores.abs().sort_values(ascending=False).iloc[:10]
            top_10_shaps = pd.DataFrame({'cpgs': top_10_shaps.index, '|SHAP|': top_10_shaps.values})
            plt.figure()
            # Keyword x/y: positional data arguments are rejected by newer seaborn.
            ax = sns.barplot(x='|SHAP|', y='cpgs', orient='h', data=top_10_shaps)
            ax.tick_params(labelsize=4)
            plt.savefig(join(output_dir, '{}_top_shap_values.png'.format(entity)), dpi=300)
            plt.close()
    if not absolute:
        print("All saved values reflect absolute values of sums, not sums of absolute values, if CpG SHAP scores are opposite signs across individuals, this will reduce the score of the resulting SHAP estimate.")




@interpret.command()
@click.option('-s', '--shapley_data', default='./interpretations/shapley_explanations/shapley_data.p', help='Pickle containing top CpGs.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--output_dir', default='./interpretations/output_plots/', help='Output directory for output plots.', type=click.Path(exists=False), show_default=True)
Expand Down

0 comments on commit 775f8ed

Please sign in to comment.