forked from MAIF/shapash
-
Notifications
You must be signed in to change notification settings - Fork 0
/
smart_explainer.py
1289 lines (1209 loc) · 57.1 KB
/
smart_explainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Smart explainer module
"""
import logging
import copy
import tempfile
import shutil
import numpy as np
import pandas as pd
from shapash.webapp.smart_app import SmartApp
from shapash.backend import BaseBackend, get_backend_cls_from_name
from shapash.utils.io import save_pickle
from shapash.utils.io import load_pickle
from shapash.utils.transform import inverse_transform, apply_postprocessing, handle_categorical_missing
from shapash.utils.utils import get_host_name
from shapash.utils.threading import CustomThread
from shapash.utils.check import check_model, check_label_dict, check_y, check_postprocessing, check_features_name, check_additional_data
from shapash.backend.shap_backend import get_shap_interaction_values
from shapash.manipulation.select_lines import keep_right_contributions
from shapash.report import check_report_requirements
from shapash.manipulation.summarize import create_grouped_features_values
from .smart_plotter import SmartPlotter
import shapash.explainer.smart_predictor
from shapash.utils.model import predict_proba, predict, predict_error
from shapash.utils.explanation_metrics import find_neighbors, shap_neighbors, get_min_nb_features, get_distance
from shapash.style.style_utils import colors_loading, select_palette
logging.basicConfig(level=logging.INFO)
class SmartExplainer:
"""
The SmartExplainer class is the main object of the Shapash library.
It allows the Data Scientists to perform many operations to make the
results more understandable :
linking encoders, models, predictions, label dict and datasets.
SmartExplainer users have several methods which are described below.
Parameters
----------
model : model object
model used to consistency check. model object can also be used by some method to compute
predict and predict_proba values
backend : str or shapash.backend object (default: 'shap')
Select which computation method to use in order to compute contributions
and feature importance. Possible values are 'shap' or 'lime'. Default is 'shap'.
It is also possible to pass a backend class inherited from shapash.backend.BaseBackend.
preprocessing : category_encoders, ColumnTransformer, list, dict, optional (default: None)
--> Differents types of preprocessing are available:
- A single category_encoders (OrdinalEncoder/OnehotEncoder/BaseNEncoder/BinaryEncoder/TargetEncoder)
- A single ColumnTransformer with scikit-learn encoding or category_encoders transformers
- A list with multiple category_encoders with optional (dict, list of dict)
- A list with a single ColumnTransformer with optional (dict, list of dict)
- A dict
- A list of dict
postprocessing : dict, optional (default: None)
Dictionnary of postprocessing modifications to apply in x_init dataframe.
Dictionnary with feature names as keys (or number, or well labels referencing to features names),
which modifies dataset features by features.
--> Different types of postprocessing are available, but the syntax is this one:
One key by features, 5 different types of modifications:
>>> {
'feature1': {'type': 'prefix', 'rule': 'age: '},
'feature2': {'type': 'suffix', 'rule': '$/week '},
'feature3': {'type': 'transcoding', 'rule': {'code1': 'single', 'code2': 'married'}},
'feature4': {'type': 'regex', 'rule': {'in': 'AND', 'out': ' & '}},
'feature5': {'type': 'case', 'rule': 'lower'}
}
Only one transformation by feature is possible.
features_groups : dict, optional (default: None)
Dictionnary containing features that should be grouped together. This option allows
to compute and display the contributions and importance of this group of features.
Features that are grouped together will still be displayed in the webapp when clicking
on a group.
>>> {
‘feature_group_1’ : ['feature3', 'feature7', 'feature24'],
‘feature_group_2’ : ['feature1', 'feature12'],
}
features_dict: dict
Dictionary mapping technical feature names to domain names.
label_dict: dict
Dictionary mapping integer labels to domain names (classification - target values).
title_story: str (default: None)
The default title is empty. You can specify a custom title
which can be used the webapp, or other methods
palette_name : str
Name of the palette used for the colors of the report (refer to style folder).
colors_dic : dict
dictionary containing every palette of colors. You can use this parameter to change
any color of the graphs.
**kwargs : dict
Keyword parameters to be passed to the backend.
Attributes
----------
data: dict
Data dictionary has 3 entries. Each key returns a pd.DataFrame (regression) or a list of pd.DataFrame
(classification - The length of the lists is equivalent to the number of labels).
All pd.DataFrame have the same shape (n_samples, n_features).
For the regression case, data that should be regarded as a single array
of size (n_samples, n_features, 3).
data['contrib_sorted']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
Contains local contributions of the prediction set, with common line index.
Columns are 'contrib_1', 'contrib_2', ... and contains the top contributions
for each line from left to right. In multi-class problems, this is a list of
contributions, one for each class.
data['var_dict']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
Must contain only ints. It gives, for each line, the list of most import features
regarding the local decomposition. In order to save space, columns are denoted by
integers, the conversion being done with the columns_dict member. In multi-class
problems, this is a list of dataframes, one for each class.
data['x_sorted']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
It gives, for each line, the list of most important features values regarding the local
decomposition. These values can only be understood with respect to data['var_dict']
x_encoded: pandas.DataFrame
preprocessed dataset used by the model to perform the prediction.
x_init: pandas.DataFrame
x_encoded dataset with inverse transformation with eventual postprocessing modifications.
x_contrib_plot: pandas.DataFrame
x_encoded dataset with inverse transformation, without postprocessing used for contribution_plot.
y_pred: pandas.DataFrame
User-specified prediction values.
contributions: pandas.DataFrame (regression) or list (classification)
local contributions aggregated if the preprocessing part requires it (e.g. one-hot encoding).
features_dict: dict
Dictionary mapping technical feature names to domain names.
inv_features_dict: dict
Inverse features_dict mapping.
label_dict: dict
Dictionary mapping integer labels to domain names (classification - target values).
inv_label_dict: dict
Inverse label_dict mapping.
columns_dict: dict
Dictionary mapping integer column number to technical feature names.
plot: object
Helper object containing all plotting functions (Bridge pattern).
model: model object
model used to check the different values of target estimate predict proba
features_desc: dict
Dictionary that references the numbers of feature values in the x_init
features_imp: pandas.Series (regression) or list (classification)
Features importance values
local_neighbors: dict
Dictionary of values to be displayed on the local_neighbors plot.
The key is "norm_shap (normalized contributions values of instance and neighbors)
features_stability: dict
Dictionary of arrays to be displayed on the stability plot.
The keys are "amplitude" (average contributions values for selected instances) and
"stability" (stability metric across neighborhood)
preprocessing : category_encoders, ColumnTransformer, list or dict
The processing apply to the original data.
postprocessing : dict
Dictionnary of postprocessing modifications to apply in x_init dataframe.
y_target : pandas.Series or pandas.DataFrame, optional (default: None)
Target values
Example
--------
>>> xpl = SmartExplainer(model, features_dict=featd,label_dict=labeld)
>>> xpl.compile(x=x_encoded, y_target=y)
>>> xpl.plot.features_importance()
"""
def __init__(
    self,
    model,
    backend='shap',
    preprocessing=None,
    postprocessing=None,
    features_groups=None,
    features_dict=None,
    label_dict=None,
    title_story: str = None,
    palette_name=None,
    colors_dict=None,
    **kwargs
):
    # Validate the user-supplied mapping dictionaries before any state is set.
    if features_dict is not None and not isinstance(features_dict, dict):
        raise ValueError(
            """
            features_dict must be a dict
            """
        )
    if label_dict is not None and isinstance(label_dict, dict) is False:
        raise ValueError(
            """
            label_dict must be a dict
            """
        )
    self.model = model
    # Resolve the backend: either instantiate it from its registered name,
    # or accept a ready-made BaseBackend instance.
    if isinstance(backend, str):
        backend_cls = get_backend_cls_from_name(backend)
        self.backend = backend_cls(
            model=self.model, preprocessing=preprocessing, **kwargs)
    elif isinstance(backend, BaseBackend):
        self.backend = backend
        # A preprocessing passed here only wins if the backend has none.
        if backend.preprocessing is None and preprocessing is not None:
            self.backend.preprocessing = preprocessing
    else:
        raise NotImplementedError(f'Unknown backend : {backend}')
    # The backend is the single source of truth for preprocessing.
    self.preprocessing = self.backend.preprocessing
    self.features_dict = dict() if features_dict is None else copy.deepcopy(features_dict)
    self.label_dict = label_dict
    self.plot = SmartPlotter(self)
    self.title_story = title_story if title_story is not None else ''
    self.palette_name = palette_name if palette_name else 'default'
    # Start from the named palette, then let user overrides take precedence.
    self.colors_dict = copy.deepcopy(
        select_palette(colors_loading(), self.palette_name))
    if colors_dict is not None:
        self.colors_dict.update(colors_dict)
    self.plot.define_style_attributes(colors_dict=self.colors_dict)
    # Consistency check: _case (regression/classification) and model classes.
    self._case, self._classes = check_model(self.model)
    self.postprocessing = postprocessing
    self.check_label_dict()
    if self.label_dict:
        # NOTE(review): inv_label_dict is only created when label_dict is
        # provided; code reading it elsewhere should guard accordingly.
        self.inv_label_dict = {v: k for k, v in self.label_dict.items()}
    self.features_groups = features_groups
    # Attributes populated later by compile() and the plotting helpers.
    self.local_neighbors = None
    self.features_stability = None
    self.features_compacity = None
    self.contributions = None
    self.explain_data = None
    self.features_imp = None
def compile(self,
            x,
            contributions=None,
            y_pred=None,
            y_target=None,
            additional_data=None,
            additional_features_dict=None):
    """
    The compile method is the first step to understand model and
    prediction. It performs the sorting of contributions, the reverse
    preprocessing steps and performs all the calculations necessary for
    a quick display of plots and efficient display of summary of
    explanation. This step can last a few moments with large datasets.

    Parameters
    ----------
    x : pandas.DataFrame
        Prediction set.
        IMPORTANT: this should be the raw prediction set,
        whose values are seen by the end user.
        x is a preprocessed dataset: Shapash can apply the model to it
    contributions : pandas.DataFrame, np.ndarray or list
        single or multiple contributions (multi-class) to handle.
        if pandas.DataFrame, the index and columns should be shared with
        the prediction set. if np.ndarray, index and columns will be
        generated according to x dataset
    y_pred : pandas.Series or pandas.DataFrame, optional (default: None)
        Prediction values (1 column only).
        The index must be identical to the index of x_init.
        This is an interesting parameter for more explicit outputs.
        Shapash lets users define their own predict,
        as they may wish to set their own threshold (classification)
    y_target : pandas.Series or pandas.DataFrame, optional (default: None)
        Target values (1 column only).
        The index must be identical to the index of x_init.
        This is an interesting parameter for outputs on prediction
    additional_data : pandas.DataFrame, optional (default: None)
        Additional dataset of features outside the model.
        The index must be identical to the index of x_init.
        This is an interesting parameter for visualisation and filtering
        in Shapash SmartApp.
    additional_features_dict : dict
        Dictionary mapping technical feature names to domain names for additional data.

    Example
    --------
    >>> xpl.compile(x=x_test)
    """
    self.x_encoded = handle_categorical_missing(x)
    # Recover the human-readable dataset by reversing the preprocessing.
    x_init = inverse_transform(self.x_encoded, self.preprocessing)
    self.x_init = handle_categorical_missing(x_init)
    self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
    self.y_target = check_y(self.x_init, y_target, y_name="y_target")
    self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
    # Contributions either come from the backend or are user-supplied.
    self._get_contributions_from_backend_or_user(x, contributions)
    self.check_contributions()
    self.columns_dict = {i: col for i, col in enumerate(self.x_init.columns)}
    self.check_features_dict()
    self.inv_features_dict = {v: k for k, v in self.features_dict.items()}
    self._apply_all_postprocessing_modifications()
    # Pre-rank contributions per row for fast plotting and summaries.
    self.data = self.state.assign_contributions(
        self.state.rank_contributions(
            self.contributions,
            self.x_init
        )
    )
    # Cardinality of each feature, used by downstream plots.
    self.features_desc = dict(self.x_init.nunique())
    if self.features_groups is not None:
        self._compile_features_groups(self.features_groups)
    self.additional_features_dict = dict() if additional_features_dict is None else self._compile_additional_features_dict(additional_features_dict)
    self.additional_data = self._compile_additional_data(additional_data)
def _get_contributions_from_backend_or_user(self,
                                            x,
                                            contributions):
    """
    Set self.contributions and self.explain_data.

    If the user did not pass contributions, the backend computes them;
    otherwise the user-provided values are formatted and aggregated by
    the backend so they align with the prediction set.
    """
    # Computing contributions using backend
    if contributions is None:
        self.explain_data = self.backend.run_explainer(x=x)
        self.contributions = self.backend.get_local_contributions(
            x=x,
            explain_data=self.explain_data)
    else:
        self.explain_data = contributions
        self.contributions = self.backend.format_and_aggregate_local_contributions(
            x=x,
            contributions=contributions,
        )
    # The backend's state object provides the ranking/masking operations
    # used throughout this class (see compile() and filter()).
    self.state = self.backend.state
def _apply_all_postprocessing_modifications(self):
    """
    Resolve postprocessing keys to real feature names, validate the
    rules, and apply the modifications to x_init.
    """
    postprocessing = self.modify_postprocessing(self.postprocessing)
    check_postprocessing(self.x_init, postprocessing)
    # True when a prefix/suffix rule will turn a numeric column into strings.
    self.postprocessing_modifications = self.check_postprocessing_modif_strings(postprocessing)
    self.postprocessing = postprocessing
    if self.postprocessing_modifications:
        # Keep an unmodified copy for contribution plots before x_init
        # columns get string-converted by postprocessing.
        self.x_contrib_plot = copy.deepcopy(self.x_init)
    self.x_init = self.apply_postprocessing(postprocessing)
def _compile_features_groups(self,
                             features_groups):
    """
    Performs required computations for groups of features.

    Builds grouped contributions, updates the feature dictionaries with
    group names, and derives the grouped datasets used by plots/webapp.
    """
    if self.backend.support_groups is False:
        raise AssertionError(
            f'Selected backend ({self.backend.name}) '
            f'does not support groups of features.'
        )
    # Compute contributions for groups of features
    self.contributions_groups = self.state.compute_grouped_contributions(
        self.contributions, features_groups)
    # Reset grouped importance; presumably recomputed on demand — TODO confirm.
    self.features_imp_groups = None
    # Update features dict with groups names
    self._update_features_dict_with_groups(features_groups=features_groups)
    # Compute t-sne projections for groups of features
    self.x_init_groups = create_grouped_features_values(
        x_init=self.x_init, x_encoded=self.x_encoded,
        preprocessing=self.preprocessing,
        features_groups=self.features_groups,
        features_dict=self.features_dict,
        how='dict_of_values')
    # Compute data attribute for groups of features
    self.data_groups = self.state.assign_contributions(
        self.state.rank_contributions(
            self.contributions_groups,
            self.x_init_groups
        )
    )
    self.columns_dict_groups = {
        i: col for i, col in enumerate(self.x_init_groups.columns)}
def _compile_additional_features_dict(self, additional_features_dict):
"""
Performs required computations for additional features dict.
"""
if not isinstance(additional_features_dict, dict):
raise ValueError(
"""
additional_features_dict must be a dict
"""
)
additional_features_dict = {f"_{key}": f"_{value}" for key, value in additional_features_dict.items()}
return additional_features_dict
def _compile_additional_data(self, additional_data):
    """
    Validate and prepare the additional (non-model) dataset.

    Columns are prefixed with '_' to avoid clashes with model features,
    and additional_features_dict is completed with identity entries for
    unmapped columns. Returns None when no additional data is given.
    """
    if additional_data is not None:
        check_additional_data(self.x_init, additional_data)
        # A feature documented in features_dict but absent from the model
        # columns belongs to the additional data: move its mapping over.
        for feature in additional_data.columns:
            if feature in self.features_dict.keys() and feature not in self.columns_dict.values():
                self.additional_features_dict[f"_{feature}"] = f"_{self.features_dict[feature]}"
                del self.features_dict[feature]
        additional_data = additional_data.add_prefix("_")
        # Identity mapping for any additional column not already mapped.
        for feature in set(list(additional_data.columns)) - set(self.additional_features_dict):
            self.additional_features_dict[feature] = feature
    return additional_data
def define_style(self,
                 palette_name=None,
                 colors_dict=None):
    """
    Change the colors used by the plots after construction.

    Parameters
    ----------
    palette_name : str, optional
        Name of a palette registered in the style folder.
    colors_dict : dict, optional
        Specific color overrides applied on top of the palette.

    Raises
    ------
    ValueError
        If neither parameter is provided.
    """
    if palette_name is None and colors_dict is None:
        raise ValueError("At least one of palette_name or colors_dict parameters must be defined")
    chosen_palette = palette_name or self.palette_name
    refreshed_colors = copy.deepcopy(
        select_palette(colors_loading(), chosen_palette))
    if colors_dict is not None:
        refreshed_colors.update(colors_dict)
    self.colors_dict.update(refreshed_colors)
    self.plot.define_style_attributes(colors_dict=self.colors_dict)
def add(self,
        y_pred=None,
        y_target=None,
        label_dict=None,
        features_dict=None,
        title_story: str = None,
        additional_data=None,
        additional_features_dict=None):
    """
    add method allows the user to add a label_dict, features_dict
    or y_pred without compiling again (and it can last a few moments).
    y_pred can be used in the plot to color scatter.
    y_pred is needed in the to_pandas method.
    label_dict and features_dict displays allow to display clearer results.

    Parameters
    ----------
    y_pred : pandas.Series, optional (default: None)
        Prediction values (1 column only).
        The index must be identical to the index of x_init.
    label_dict: dict, optional (default: None)
        Dictionary mapping integer labels to domain names.
    features_dict: dict, optional (default: None)
        Dictionary mapping technical feature names to domain names.
    title_story: str (default: None)
        The default title is empty. You can specify a custom title
        which can be used the webapp, or other methods
    y_target : pandas.Series or pandas.DataFrame, optional (default: None)
        Target values (1 column only).
        The index must be identical to the index of x_init.
        This is an interesting parameter for outputs on prediction
    additional_data : pandas.DataFrame, optional (default: None)
        Additional dataset of features outside the model.
        The index must be identical to the index of x_init.
        This is an interesting parameter for visualisation and filtering
        in Shapash SmartApp.
    additional_features_dict : dict
        Dictionary mapping technical feature names to domain names for additional data.
    """
    if y_pred is not None:
        self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
        # Refresh the prediction error as soon as both y_pred and y_target exist.
        if hasattr(self, 'y_target'):
            self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
    if y_target is not None:
        self.y_target = check_y(self.x_init, y_target, y_name="y_target")
        if hasattr(self, 'y_pred'):
            self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
    if label_dict is not None:
        if isinstance(label_dict, dict) is False:
            raise ValueError(
                """
                label_dict must be a dict
                """
            )
        self.label_dict = label_dict
        self.check_label_dict()
        self.inv_label_dict = {v: k for k, v in self.label_dict.items()}
    if features_dict is not None:
        if isinstance(features_dict, dict) is False:
            raise ValueError(
                """
                features_dict must be a dict
                """
            )
        self.features_dict = features_dict
        self.check_features_dict()
        self.inv_features_dict = {v: k for k, v in self.features_dict.items()}
    if title_story is not None:
        self.title_story = title_story
    if additional_features_dict is not None:
        self.additional_features_dict = self._compile_additional_features_dict(additional_features_dict)
    if additional_data is not None:
        self.additional_data = self._compile_additional_data(additional_data)
def get_interaction_values(self, n_samples_max=None, selection=None):
    """
    Compute shap interaction values for each row of x_encoded.

    This function is only available for explainer of type TreeExplainer (used for tree based models).
    Please refer to the official tree shap paper for more information : https://arxiv.org/pdf/1802.03888.pdf

    Parameters
    ----------
    n_samples_max : int, optional
        Limit the number of points for which we compute the interactions.
    selection : list, optional
        Contains list of index, subset of the input DataFrame that we want to plot

    Returns
    -------
    np.ndarray
        Shap interaction values for each sample as an array of shape (# samples x # features x # features).
    """
    x = copy.deepcopy(self.x_encoded)
    if selection:
        x = x.loc[selection]
    # Cache: reuse the previous result if it was computed on the same subset.
    if hasattr(self, 'x_interaction'):
        if self.x_interaction.equals(x[:n_samples_max]):
            return self.interaction_values
    self.x_interaction = x[:n_samples_max]
    self.interaction_values = get_shap_interaction_values(self.x_interaction, self.backend.explainer)
    return self.interaction_values
def check_postprocessing_modif_strings(self, postprocessing=None):
"""
Check if any modification of postprocessing will convert numeric values into strings values.
If so, return True, otherwise False.
Parameters
----------
postprocessing: dict
Dict of postprocessing modifications to apply.
Returns
-------
modif: bool
Boolean which is True if any numerical variable will be converted into string.
"""
modif = False
if postprocessing is not None:
for key in postprocessing.keys():
dict_postprocess = postprocessing[key]
if dict_postprocess['type'] in {'prefix', 'suffix'} \
and pd.api.types.is_numeric_dtype(self.x_init[key]):
modif = True
return modif
def modify_postprocessing(self, postprocessing=None):
"""
Modifies postprocessing parameter, to change only keys, with features name,
in case of parameters are not real feature names (with columns_dict,
or inv_features_dict).
Parameters
----------
postprocessing : Dict
Dictionnary of postprocessing to modify.
Returns
-------
Dict
Modified dictionnary, with same values but keys directly referencing to feature names.
"""
if postprocessing:
new_dic = dict()
for key in postprocessing.keys():
if key in self.features_dict:
new_dic[key] = postprocessing[key]
elif key in self.columns_dict.keys():
new_dic[self.columns_dict[key]] = postprocessing[key]
elif key in self.inv_features_dict:
new_dic[self.inv_features_dict[key]] = postprocessing[key]
else:
raise ValueError(f"Feature name '{key}' not found in the dataset.")
return new_dic
def apply_postprocessing(self, postprocessing=None):
    """
    Apply postprocessing modifications to x_init, if any.

    Parameters
    ----------
    postprocessing : dict
        Dictionary of postprocessing modifications to apply in x_init.

    Returns
    -------
    pandas.DataFrame
        x_init unchanged when postprocessing is empty, the modified
        dataframe otherwise.
    """
    if not postprocessing:
        return self.x_init
    # Delegates to the module-level helper of the same name.
    return apply_postprocessing(self.x_init, postprocessing)
def check_label_dict(self):
    """
    Validate label_dict against the model classes.

    A no-op for regression; for classification, delegates to the
    module-level check_label_dict helper.
    """
    if self._case == "regression":
        return None
    return check_label_dict(self.label_dict, self._case, self._classes)
def check_features_dict(self):
"""
Check the features_dict and add the necessary keys if all the
input X columns are not present
"""
for feature in (set(list(self.columns_dict.values())) - set(list(self.features_dict))):
self.features_dict[feature] = feature
def _update_features_dict_with_groups(self, features_groups):
"""
Add groups into features dict and inv_features_dict if not present.
"""
for group_name in features_groups.keys():
self.features_desc[group_name] = 1000
if group_name not in self.features_dict.keys():
self.features_dict[group_name] = group_name
self.inv_features_dict[group_name] = group_name
def check_contributions(self):
    """
    Check if contributions and prediction set match in terms of shape and index.

    Raises
    ------
    ValueError
        If the contributions do not align with x_init (same number of
        lines, same number and order of columns).
    """
    if not self.state.check_contributions(self.contributions, self.x_init):
        raise ValueError(
            """
            Prediction set and contributions should have exactly the same number of lines
            and number of columns. the order of the columns must be the same
            Please check x, contributions and preprocessing arguments.
            """
        )
def check_label_name(self, label, origin=None):
"""
Convert a string label in integer. If the label is already
an integer nothing is done. In all other cases an error is raised.
Parameters
----------
label: int or string
Integer (id) or string (business names)
origin: None, 'num', 'code', 'value' (default: None)
Kind of the label used in parameter
Returns
-------
tuple
label num, label code (class of the mode), label value
"""
if origin is None:
if label in self._classes:
origin = 'code'
elif self.label_dict is not None and label in self.label_dict.values():
origin = 'value'
elif isinstance(label, int) and label in range(-1, len(self._classes)):
origin = 'num'
try:
if origin == 'num':
label_num = label
label_code = self._classes[label]
label_value = self.label_dict[label_code] if self.label_dict else label_code
elif origin == 'code':
label_code = label
label_num = self._classes.index(label)
label_value = self.label_dict[label_code] if self.label_dict else label_code
elif origin == 'value':
label_code = self.inv_label_dict[label]
label_num = self._classes.index(label_code)
label_value = label
else:
raise ValueError
except ValueError:
raise Exception({"message": "Origin must be 'num', 'code' or 'value'."})
except Exception:
raise Exception({"message": f"Label ({label}) not found for origin ({origin})"})
return label_num, label_code, label_value
def check_features_name(self, features, use_groups=False):
    """
    Convert a list of feature names (string) or feature ids into feature ids.

    Feature names can be part of columns_dict or features_dict.

    Parameters
    ----------
    features : List
        List of ints (columns ids) or of strings (business names)
    use_groups : bool
        Whether or not features parameter includes groups of features

    Returns
    -------
    list of ints
        Columns ids compatible with var_dict
    """
    # NOTE: the original tests `use_groups is False` (not truthiness);
    # this exact semantics is preserved.
    if use_groups is False:
        selected_columns = self.columns_dict
    else:
        selected_columns = self.columns_dict_groups
    return check_features_name(selected_columns, self.features_dict, features)
def check_attributes(self, attribute):
"""
Check that explainer has the attribute precised
Parameters
----------
attribute: string
the label of the attribute to test
Returns
-------
Object content of the attribute specified from SmartExplainer instance
"""
if not hasattr(self, attribute):
raise ValueError(
"""
attribute {0} isn't an attribute of the explainer precised.
""".format(attribute))
return self.__dict__[attribute]
def filter(
    self,
    features_to_hide=None,
    threshold=None,
    positive=None,
    max_contrib=None,
    display_groups=None
):
    """
    The filter method is an important method which allows to summarize the local explainability
    by using the user defined parameters which correspond to its use case.
    Filter method is used with the local_plot method of Smartplotter to see the concrete result of this summary
    with a local contribution barchart.
    Please, watch the local_plot tutorial to see how these two methods are combined with a concrete example.

    Parameters
    ----------
    features_to_hide : list, optional (default: None)
        List of strings, containing features to hide.
    threshold : float, optional (default: None)
        Absolute threshold below which any contribution is hidden.
    positive: bool, optional (default: None)
        If True, hide negative values. False, hide positive values
        If None, hide nothing.
    max_contrib : int, optional (default: None)
        Maximum number of contributions to show.
    display_groups : bool (default: None)
        Whether or not to display groups of features. This option is
        only useful if groups of features are declared when compiling
        SmartExplainer object.
    """
    # Groups are displayed by default whenever they were declared at compile time.
    display_groups = True if (display_groups is not False and self.features_groups is not None) else False
    if display_groups:
        data = self.data_groups
    else:
        data = self.data
    # Start from an all-True mask, then AND in each user constraint.
    mask = [self.state.init_mask(data['contrib_sorted'], True)]
    if features_to_hide:
        mask.append(
            self.state.hide_contributions(
                data['var_dict'],
                features_list=self.check_features_name(features_to_hide, use_groups=display_groups)
            )
        )
    if threshold:
        mask.append(
            self.state.cap_contributions(
                data['contrib_sorted'],
                threshold=threshold
            )
        )
    if positive is not None:
        mask.append(
            self.state.sign_contributions(
                data['contrib_sorted'],
                positive=positive
            )
        )
    self.mask = self.state.combine_masks(mask)
    # max_contrib is applied after combination so it caps the surviving entries.
    if max_contrib:
        self.mask = self.state.cutoff_contributions(self.mask, max_contrib=max_contrib)
    # Aggregate of the contributions removed by the mask.
    self.masked_contributions = self.state.compute_masked_contributions(
        data['contrib_sorted'],
        self.mask
    )
    # Remember the parameters so later calls (e.g. to_pandas) can reuse them.
    self.mask_params = {
        'features_to_hide': features_to_hide,
        'threshold': threshold,
        'positive': positive,
        'max_contrib': max_contrib
    }
def save(self, path):
    """
    Persist this SmartExplainer object to disk as a pickle file.

    Saving is useful because it avoids recompiling the explainer
    when you want to display results later.

    Parameters
    ----------
    path : str
        File path to store the pickle file

    Example
    --------
    >>> xpl.save('path_to_pkl/xpl.pkl')
    """
    # The running dash app cannot (and should not) be pickled, so it is
    # cleared before serialization.
    if hasattr(self, 'smartapp'):
        self.smartapp = None
    save_pickle(self, path)
@classmethod
def load(cls, path):
    """
    Restore a pickled SmartExplainer from disk.

    To use this method you must first declare your SmartExplainer object.
    Watch the following example.

    Parameters
    ----------
    path : str
        File path of the pickle file.

    Example
    --------
    >>> xpl = SmartExplainer.load('path_to_pkl/xpl.pkl')
    """
    loaded = load_pickle(path)
    # Guard clause: refuse anything that is not a SmartExplainer.
    if not isinstance(loaded, SmartExplainer):
        raise ValueError(
            "File is not a SmartExplainer object"
        )
    # Rebuild through the constructor, then copy every pickled attribute over.
    instance = cls(model=loaded.model)
    instance.__dict__.update(loaded.__dict__)
    return instance
def predict_proba(self):
    """
    Compute the probability values for each x_encoded row and store
    them in the proba_values attribute.
    """
    # Delegates to the module-level predict_proba helper.
    model, encoded_rows = self.model, self.x_encoded
    self.proba_values = predict_proba(model, encoded_rows, self._classes)
def predict(self):
    """
    Compute the model output for each x_encoded row and store it in the
    y_pred attribute; when a target is available, also store the
    prediction error in the prediction_error attribute.
    """
    predictions = predict(self.model, self.x_encoded)
    self.y_pred = predictions
    # Prediction error is only meaningful when a ground-truth target exists.
    if hasattr(self, 'y_target'):
        self.prediction_error = predict_error(self.y_target, predictions, self._case)
def to_pandas(
        self,
        features_to_hide=None,
        threshold=None,
        positive=None,
        max_contrib=None,
        proba=False,
        use_groups=None
):
    """
    Export the summary of local explainability as a pandas DataFrame.

    This method proposes a set of parameters to summarize the explainability
    of each point. If the user does not specify any, the to_pandas method
    uses the parameters specified during the last execution of the filter
    method.

    In the classification case, to_pandas summarizes the explainability
    which corresponds to the predicted values specified by the user (with
    compile or add method). The proba parameter displays the corresponding
    predict proba value for each point.

    In the classification case, there are two ways to use this method:
    - Provide a real prediction set to explain
    - Focus on a constant target value and look at the proba and
      explainability corresponding to each point
      (in that case, specify a constant pd.Series with add or compile method)

    Examples are presented in the tutorial local_plot (please check the
    tutorial part of this doc).

    Parameters
    ----------
    features_to_hide : list, optional (default: None)
        List of strings, containing features to hide.
    threshold : float, optional (default: None)
        Absolute threshold below which any contribution is hidden.
    positive: bool, optional (default: None)
        If True, hide negative values. Hide positive values otherwise. If None, hide nothing.
    max_contrib : int, optional (default: 5)
        Number of contributions to show in the pandas df
    proba : bool, optional (default: False)
        adding proba in output df
    use_groups : bool (optional)
        Whether or not to use groups of features contributions (only available if features_groups
        parameter was not empty when calling compile method).

    Returns
    -------
    pandas.DataFrame
        - selected explanation of each row for classification case

    Examples
    --------
    >>> summary_df = xpl.to_pandas(max_contrib=2,proba=True)
    >>> summary_df
        pred    proba    feature_1    value_1    contribution_1    feature_2    value_2    contribution_2
    0    0    0.756416    Sex    1.0    0.322308    Pclass    3.0    0.155069
    1    3    0.628911    Sex    2.0    0.585475    Pclass    1.0    0.370504
    2    0    0.543308    Sex    2.0    -0.486667    Pclass    3.0    0.255072
    """
    # Groups are used unless explicitly disabled or never declared.
    use_groups = use_groups is not False and self.features_groups is not None
    data = self.data_groups if use_groups else self.data
    # Classification: y_pred is needed
    if self.y_pred is None:
        raise ValueError(
            "You have to specify y_pred argument. Please use add() or compile() method"
        )
    # Reuse the mask from the last filter() call only when no summary
    # parameter is given AND the existing mask matches the shape of the
    # current data (shapes can differ when groups of features were used
    # in one call but not in the other).
    if all(var is None for var in [features_to_hide, threshold, positive, max_contrib]) \
            and hasattr(self, 'mask_params') \
            and (
                (isinstance(data['contrib_sorted'], pd.DataFrame)
                 and len(data["contrib_sorted"].columns) == len(self.mask.columns))
                or
                (isinstance(data['contrib_sorted'], list)
                 and len(data["contrib_sorted"][0].columns) == len(self.mask[0].columns))
            ):
        print('to_pandas params: ' + str(self.mask_params))
    else:
        self.filter(features_to_hide=features_to_hide,
                    threshold=threshold,
                    positive=positive,
                    max_contrib=max_contrib,
                    display_groups=use_groups)
    if use_groups:
        columns_dict = dict(enumerate(self.x_init_groups.columns))
    else:
        columns_dict = self.columns_dict
    # Summarize information
    data['summary'] = self.state.summarize(
        data['contrib_sorted'],
        data['var_dict'],
        data['x_sorted'],
        self.mask,
        columns_dict,
        self.features_dict
    )
    # Matching with y_pred
    if proba:
        # Fixed: was `self.predict_proba() if proba else None`, a no-op
        # conditional expression inside an `if proba:` branch.
        self.predict_proba()
        proba_values = self.proba_values
    else:
        proba_values = None
    y_pred, summary = keep_right_contributions(self.y_pred, data['summary'],
                                               self._case, self._classes,
                                               self.label_dict, proba_values)
    return pd.concat([y_pred, summary], axis=1)
def compute_features_import(self, force=False):
    """
    Compute a relative features importance: the sum of absolute values
    of the contributions for each feature, in base 100.

    Results are stored in the features_imp attribute (and
    features_imp_groups when groups of features were declared).

    Parameters
    ----------
    force: bool (default: False)
        True to force the computation even if features importance
        has already been calculated.
    """
    # Bug fix: `force` was documented but ignored — the importance was
    # recomputed on every call. Honor the documented contract: skip the
    # computation when a result already exists, unless forced.
    if getattr(self, 'features_imp', None) is None or force:
        self.features_imp = self.backend.get_global_features_importance(
            contributions=self.contributions,
            explain_data=self.explain_data,
            subset=None
        )
    if self.features_groups is not None and (self.features_imp_groups is None or force):
        self.features_imp_groups = self.state.compute_features_import(self.contributions_groups)
def compute_features_stability(self, selection):
    """
    Compute features stability metrics for a selection of instances,
    used by the `local_neighbors_plot` and `local_stability_plot` methods.

    - If selection is a single instance, store the (normalized)
      contribution values of the instance and its neighbors in the
      `local_neighbors` attribute.
    - If selection represents multiple instances, store the average
      (normalized) contribution values of instances and neighbors
      (=amplitude) and the variability of those values in the
      neighborhood (=variability) in the `features_stability` attribute.

    Parameters
    ----------
    selection: list
        Indices of rows to be displayed on the stability plot
    """
    # Only regression and binary classification are supported here.
    if (self._case == "classification") and (len(self._classes) > 2):
        raise AssertionError("Multi-class classification is not supported")
    neighbors_per_instance = find_neighbors(selection, self.x_encoded, self.model, self._case)
    if len(selection) == 1:
        # Single instance: keep only the normalized contributions of the
        # instance and its neighbors.
        norm_shap, _, _ = shap_neighbors(
            neighbors_per_instance[0], self.x_encoded, self.contributions, self._case
        )
        self.local_neighbors = {"norm_shap": norm_shap}
    else:
        n_instances = len(selection)
        n_features = self.x_init.shape[1]
        amplitude = np.zeros((n_instances, n_features))
        variability = np.zeros((n_instances, n_features))
        # One explanation per instance (plus its neighbors).
        for row in range(n_instances):
            _, var_row, amp_row = shap_neighbors(
                neighbors_per_instance[row], self.x_encoded, self.contributions, self._case
            )
            variability[row, :] = var_row
            amplitude[row, :] = amp_row
        self.features_stability = {"variability": variability, "amplitude": amplitude}
def compute_features_compacity(self, selection, distance, nb_features):
"""
For a selection of instances, compute features compacity metrics used in method `compacity_plot`.
The method returns :
* the minimum number of features needed for a given approximation level
* conversely, the approximation reached with a given number of features