Approximate Survival Shapley #21

Merged: 1 commit into MI2DataLab:main on Oct 27, 2023


PierrickPochelu
Contributor

import numpy as np
import pandas as pd
from survshap import SurvivalModelExplainer, ModelSurvSHAP
import time

np.random.seed(42)
nb_features=8
nb_events=150

np_X=np.random.rand(nb_events, nb_features)
np_time=np.random.rand(nb_events, 1)
np_is_living=np_X[:,0] < np_time[:,0]

# '<f8' (float64) instead of '<f16': float128 is not available on every platform (e.g. Windows)
y=np.empty(nb_events, dtype=[('event', '?'), ('time', '<f8')])
y['event']=np_is_living.reshape(-1)
y['time']=np_time.reshape(-1)
X=pd.DataFrame(np_X,columns=['f'+str(i) for i in range(1,nb_features+1)])

from sksurv.ensemble import RandomSurvivalForest
rsf=RandomSurvivalForest(random_state=42)
st=time.time()
rsf.fit(X,y)
print(f"score:{rsf.score(X,y)} fit time:{time.time()-st}")
print(f"predict: {rsf.predict(X)}")


exp_rsf=SurvivalModelExplainer(rsf,X,y)
ms_rsf=ModelSurvSHAP(random_state=42)                         # <-------- EXACT SURVIVAL SHAP
#ms_rsf=ModelSurvSHAP(random_state=42,max_shap_value_inputs=20) # <-------- APPROXIMATE SURVIVAL SHAP
st=time.time()
ms_rsf.fit(exp_rsf)
print(f"Interpretation time:{time.time()-st}")

# These pandas display settings apply only inside the with block.
with pd.option_context('display.max_rows', None,
                       'display.max_columns', None,
                       'display.precision', 3,
                       ):
    ms_rsf.get_mean_abs_shap_values()
    print(ms_rsf.result)
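
Why the exact version is slow: exact Shapley values require evaluating the model on every subset of features, so the cost grows as 2^nb_features per explained observation. The sketch below is a generic illustration on a toy value function, not the survshap implementation; `value`, `exact_shapley` and `sampled_shapley` are hypothetical names, and permutation sampling is only one possible approximation (how `max_shap_value_inputs` limits the work internally is not shown here).

```python
import itertools
import math
import random


def value(coalition, weights):
    """Toy value function: additive weights plus an interaction between features 0 and 1."""
    total = sum(weights[i] for i in coalition)
    if 0 in coalition and 1 in coalition:
        total += 2.0
    return total


def exact_shapley(weights):
    """Exact Shapley values: enumerate every subset of the other features."""
    n = len(weights)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight of a coalition of this size
            coef = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                s = set(subset)
                phi[i] += coef * (value(s | {i}, weights) - value(s, weights))
    return phi


def sampled_shapley(weights, n_permutations, seed=0):
    """Monte Carlo estimate: average marginal contributions over random feature orderings."""
    rng = random.Random(seed)
    n = len(weights)
    phi = [0.0] * n
    order = list(range(n))
    for _ in range(n_permutations):
        rng.shuffle(order)
        seen = set()
        prev = value(seen, weights)
        for i in order:
            seen.add(i)
            cur = value(seen, weights)
            phi[i] += cur - prev
            prev = cur
    return [p / n_permutations for p in phi]


weights = [0.9, 0.1, 0.1, 0.05, 0.05, 0.04, 0.03, 0.03]  # 8 toy features
exact = exact_shapley(weights)
approx = sampled_shapley(weights, n_permutations=2000)
for i, (e, a) in enumerate(zip(exact, approx)):
    print(f"feature {i}: exact={e:.3f} sampled={a:.3f}")
```

In this toy setting the sampled estimate stays close to the exact values while the number of model evaluations is chosen by the caller instead of growing exponentially with the number of features.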

OUTPUT:

THE EXACT VERSION TAKES 363 seconds:

100%|██████████| 150/150 [06:03<00:00, 2.42s/it]
Interpretation time:363.3058955669403
  variable_name  variable_value    B  aggregated_change  index
0            f1           0.466  0.0              0.901   74.5
2            f3           0.455  0.0              0.105   74.5
6            f7           0.501  0.0              0.094   74.5
5            f6           0.508  0.0              0.056   74.5
3            f4           0.544  0.0              0.051   74.5
4            f5           0.509  0.0              0.036   74.5
7            f8           0.509  0.0              0.035   74.5
1            f2           0.502  0.0              0.027   74.5

THE APPROXIMATE VERSION TAKES 32 seconds:

100%|██████████| 150/150 [00:31<00:00, 4.70it/s]
Interpretation time:32.05052995681763
  variable_name  variable_value    B  aggregated_change  index
0            f1           0.466  0.0              0.891   74.5
2            f3           0.455  0.0              0.102   74.5
6            f7           0.501  0.0              0.086   74.5
5            f6           0.508  0.0              0.055   74.5
3            f4           0.544  0.0              0.050   74.5
4            f5           0.509  0.0              0.034   74.5
7            f8           0.509  0.0              0.033   74.5
1            f2           0.502  0.0              0.026   74.5

CONCLUSION: THE APPROXIMATE VERSION IS ROUGHLY 11 TIMES FASTER (32 SECONDS VERSUS 363 SECONDS), MORE MEMORY EFFICIENT, AND GIVES ABOUT THE SAME RESULT
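
The two outputs above can be compared with a few lines of Python. This is only a sketch that copies the `aggregated_change` values and timings reported in this comment; the variable names are illustrative and not part of survshap.

```python
# aggregated_change per feature, copied from the two outputs above
exact_values = {"f1": 0.901, "f3": 0.105, "f7": 0.094, "f6": 0.056,
                "f4": 0.051, "f5": 0.036, "f8": 0.035, "f2": 0.027}
approx_values = {"f1": 0.891, "f3": 0.102, "f7": 0.086, "f6": 0.055,
                 "f4": 0.050, "f5": 0.034, "f8": 0.033, "f2": 0.026}

# interpretation times in seconds, copied from the two runs
exact_seconds = 363.3058955669403
approx_seconds = 32.05052995681763

speedup = exact_seconds / approx_seconds
max_abs_diff = max(abs(exact_values[k] - approx_values[k]) for k in exact_values)

# feature ranking by decreasing importance
exact_ranking = sorted(exact_values, key=exact_values.get, reverse=True)
approx_ranking = sorted(approx_values, key=approx_values.get, reverse=True)
same_ranking = exact_ranking == approx_ranking

print(f"speedup: {speedup:.1f}x")
print(f"largest absolute difference: {max_abs_diff:.3f}")
print(f"same feature ranking: {same_ranking}")
```

On these numbers the feature ranking is identical and the largest gap is on f1.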

@krzyzinskim krzyzinskim mentioned this pull request Jun 29, 2023
@krzyzinskim
Collaborator

Thanks again for your contribution! 😄
I will be testing different ways of approximating SurvSHAP(t) and will also take your code into consideration.

@krzyzinskim krzyzinskim merged commit 6385c7b into MI2DataLab:main Oct 27, 2023
@PierrickPochelu
Contributor Author

I've tested the new wheel and this new feature works. 👌
