-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDeepExplainer.py
More file actions
366 lines (306 loc) · 18 KB
/
Copy pathDeepExplainer.py
File metadata and controls
366 lines (306 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 5 15:24:06 2025
@author: renyuanfang
"""
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator,TransformerMixin
from MySurgeryRisk_XAI.DeepExplainerDataset import DeepExplainerDataset
from MySurgeryRisk_XAI.Global_feature_importance.deep_shap_global import deep_shap_get_global_feature_importance
from MySurgeryRisk_XAI.Local_feature_importance.deep_shap_method import deep_shap_get_local_feature_importance
from MySurgeryRisk_XAI.Local_feature_importance.deep_lime_method import deep_lime_get_local_feature_importance
from MySurgeryRisk_XAI.Contrastive_explaination.deep_dice_counterfactual import deep_dice_get_counterfactual
from MySurgeryRisk_XAI.Model_card.model_card_generator import ModelCardGenerator
from MySurgeryRisk_XAI.Similar_instances.deep_principle_feature_matching import deep_principle_feature_matching_similar_patients
from MySurgeryRisk_XAI.Similar_instances.deep_feature_importance_cosine_similar_patients import deep_feature_importance_based_cosine_similar_patients
from MySurgeryRisk_XAI.Similar_instances.deep_feature_importance_distance_similar_patients import deep_feature_importance_based_distance_similar_patients
import os
import torch
import torch.nn as nn
from typing import Callable, List, Union
class DummyTransformer(BaseEstimator, TransformerMixin):
    """Decode integer-encoded categorical columns back to their original levels.

    LIME requires categorical variables to be integer encoded. This transformer
    is meant to run first in front of the model's own preprocessing. It maps the
    integer codes back to the original levels and restores each column's
    original dtype.
    """

    def __init__(self, inverse_mapping, feature_names):
        """
        Parameters
        ----------
        inverse_mapping : dict
            One entry per categorical feature:
            ``{feature: {'mapping': {int code: original level}, 'datatype': original dtype}}``.
        feature_names : list or np.ndarray
            Feature names of the original dataset. They are used as column names
            when ``transform`` receives a numpy array.
        """
        self.inverse_mapping = inverse_mapping
        self.feature_names = feature_names

    def fit(self, X, y=None):
        """No-op kept for scikit-learn transformer API compatibility; returns ``self``."""
        return self

    def transform(self, X):
        """
        Convert the integer-coded categorical columns of ``X`` back to their original format.

        Parameters
        ----------
        X : np.ndarray or pandas.DataFrame
            Input data to decode. It is copied first, so the caller's object is not modified.

        Returns
        -------
        pandas.DataFrame
            The decoded data. Each column listed in ``inverse_mapping`` is mapped through its
            code->level dictionary and cast to its stored dtype.
        """
        decoded = X.copy()
        if isinstance(decoded, np.ndarray):
            # Raw arrays carry no column names; restore them so columns can be addressed by feature.
            decoded = pd.DataFrame(decoded, columns=self.feature_names)
        for column, spec in self.inverse_mapping.items():
            target_dtype = spec['datatype']
            code_to_level = spec['mapping']
            decoded[column] = decoded[column].map(code_to_level).astype(target_dtype)
        return decoded
class DeepExplainer:
    """Explainability and interpretability interfaces for a PyTorch model.

    Pairs a PyTorch model with a ``DeepExplainerDataset`` and exposes:

    * ``interpretability``        - global feature importance (SHAP).
    * ``explainability_why``      - local feature importance (SHAP or LIME).
    * ``explainability_whynot``   - contrastive (counterfactual) explanations.
    * ``explainability_whatif``   - counterfactuals restricted to actionable features.
    * ``explainability_how``      - HTML model card generation.
    * ``explainability_whatelse`` - similar-patient retrieval.
    """

    def __init__(self, model, deepExplainerDataset):
        """
        Create a DeepExplainer instance.

        Parameters
        ----------
        model : torch.nn.Module
            The AI model. A wrapper whose ``module_`` attribute is a ``torch.nn.Module``
            also passes the type check. The model output must already be a
            probability in [0, 1] (e.g. wrap the network and apply a sigmoid).
        deepExplainerDataset : DeepExplainerDataset
            The dataset used for the development and validation of the model.

        Raises
        ------
        TypeError
            If ``model`` is not a PyTorch model, or ``deepExplainerDataset`` is not a
            ``DeepExplainerDataset``.
        AssertionError
            If any model output on the training tensor falls outside [0, 1].

        Notes
        -----
        The constructor moves the model to CUDA when available (CPU otherwise) and
        switches it to eval mode. It then runs one forward pass over the whole
        training tensor to validate the output range.
        """
        # Validate inputs before setting any attributes.
        if not self._is_pytorch_model(model):
            raise TypeError(
                'The model type is not supported. This explainer currently only '
                'supports PyTorch (torch.nn.Module) models.'
            )
        if not isinstance(deepExplainerDataset, DeepExplainerDataset):
            actual_type = type(deepExplainerDataset).__name__
            raise TypeError(
                f"Invalid type for 'deepExplainerDataset'. Expected an instance of DeepExplainerDataset, "
                f"but received an object of type '{actual_type}'."
            )
        self.model = model
        self.dataset = deepExplainerDataset
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Explicit raise instead of ``assert`` so the check still runs under ``python -O``.
        # AssertionError is kept so that existing callers' exception handlers still match.
        if not self._outputs_probabilities():
            raise AssertionError(
                "Error: The output of the model is not a probability. Use a wrapper around your model and apply the sigmoid function."
            )

    @staticmethod
    def _is_pytorch_model(model_obj):
        """Return True for a ``torch.nn.Module`` or a wrapper exposing one as ``module_``."""
        if isinstance(model_obj, nn.Module):
            return True
        # Common wrapper pattern (e.g. skorch) keeps the network in ``module_``.
        return hasattr(model_obj, 'module_') and isinstance(model_obj.module_, nn.Module)

    def _outputs_probabilities(self):
        """Return True if every model output on the training tensor lies in [0, 1].

        Side effect: moves the model to ``self.device`` and puts it in eval mode.
        NaN outputs fail both comparisons, so they count as non-probabilities.
        """
        self.model.to(self.device).eval()
        with torch.no_grad():
            data, _, _, _ = self.dataset.get_tensor_data()
            predictions = self.model(data.to(self.device))
        return bool(torch.all((predictions >= 0.0) & (predictions <= 1.0)))

    def _to_dataframe(self, test_X):
        """Normalize explanation input to a pandas DataFrame with one row per sample.

        A pandas Series becomes a single-row DataFrame. A 1-D numpy array is treated
        as a single sample; numpy input is wrapped in a DataFrame whose columns are
        ``self.dataset.feature_names``. Any other input (e.g. a DataFrame) is
        returned unchanged.
        """
        if isinstance(test_X, pd.Series):
            return test_X.to_frame().T
        if isinstance(test_X, np.ndarray):
            if test_X.ndim == 1:
                test_X = test_X.reshape((1, -1))
            return pd.DataFrame(test_X, columns=self.dataset.feature_names)
        return test_X

    def interpretability(self, feature_names=None, plot=True, num_features_display=10,
                         **kwargs):
        """
        Compute global feature importance with SHAP over the training tensor.

        Parameters
        ----------
        feature_names : list, tuple or np.ndarray, optional
            Feature names of the transformed data. Required when the features fed to
            the model differ from the original feature names. Transformed features
            that are levels of the same categorical variable should use the same
            name, so the overall effect of that variable can be processed.
            When omitted or empty, ``self.dataset.feature_names`` is used.
        plot : bool
            Whether to plot the figure, default is True.
        num_features_display : int
            If plotting the figure, the number of important features shown.
        **kwargs
            Forwarded to ``deep_shap_get_global_feature_importance``.

        Returns
        -------
        The result of ``deep_shap_get_global_feature_importance``, documented as a
        dataframe of SHAP values with shape (Number of samples, Number of features).
        """
        # Accept array-likes. ``if feature_names:`` on a numpy array used to raise
        # "truth value of an array is ambiguous".
        if isinstance(feature_names, (tuple, np.ndarray)):
            feature_names = list(feature_names)
        if not (isinstance(feature_names, list) and len(feature_names) > 0):
            feature_names = self.dataset.feature_names
        train_x_tensor, _, _, _ = self.dataset.get_tensor_data()
        train_x_tensor = train_x_tensor.to(self.device)
        # explainability_whynot/whatif leave the model on CPU. Move it back so the
        # model and the data share a device.
        self.model.to(self.device)
        return deep_shap_get_global_feature_importance(self.model, train_x_tensor, feature_names=feature_names,
                                                       plot=plot, num_features_display=num_features_display, **kwargs)

    def explainability_why(self, test_X, method='shap', feature_names=None, plot=True, num_features_display=10, figure_type='bar', **kwargs):
        """
        Explain individual predictions with local feature importance.

        Parameters
        ----------
        test_X : pandas.Series, pandas.DataFrame, numpy.ndarray
            Sample(s) to explain, in the original (untransformed) feature space.
        method : str
            'shap' or 'lime'.
        feature_names : list, optional
            Transformed-feature names forwarded to the SHAP method. Ignored by LIME,
            which always uses ``self.dataset.feature_names``.
        plot : bool
            Whether to plot the figure, default is True.
        num_features_display : int
            If plotting the figure, the number of important features shown.
        figure_type : str
            NOTE(review): currently unused; it is not forwarded to either method.
        **kwargs
            Forwarded to the selected local-importance function.

        Raises
        ------
        AssertionError
            If ``method`` is not 'shap' or 'lime'.
        """
        if method not in ('shap', 'lime'):
            raise AssertionError(f"Error: Invalid method: {method}. Allowed methods are shap and lime.")
        test_X = self._to_dataframe(test_X)
        if method == 'shap':
            transformed_tensor = self.dataset.transform_data(test_X)
            train_x_tensor, _, _, _ = self.dataset.get_tensor_data()
            train_x_tensor = train_x_tensor.to(self.device)
            transformed_tensor = transformed_tensor.to(self.device)
            self.model.to(self.device)
            return deep_shap_get_local_feature_importance(self.model, train_x_tensor, sample=transformed_tensor, raw_sample=test_X, feature_names=feature_names,
                                                          plot=plot, num_features_display=num_features_display, **kwargs)
        return self._lime_local_importance(test_X, plot, num_features_display, **kwargs)

    def _lime_local_importance(self, test_X, plot, num_features_display, **kwargs):
        """Run LIME on ``test_X``, integer-encoding categorical variables first as LIME requires."""
        # Copy both frames: the encoding below rewrites columns in place.
        test_X = test_X.copy()
        train_X = self.dataset.train_x.copy()
        level_map = self.dataset.categorical_variable_level_map
        if len(level_map) > 0:
            inverse_mapping = self._encode_categoricals_for_lime(train_X, test_X, level_map)
            new_step = DummyTransformer(inverse_mapping, self.dataset.feature_names)
            # Decode LIME's integer codes back to original levels before the dataset's own transform.
            precessor = lambda x: self.dataset.transform_data(new_step.transform(x))
        else:
            precessor = self.dataset.transform_data
        # Keep model and data on the same device (whynot/whatif leave the model on CPU).
        self.model.to(self.device)
        return deep_lime_get_local_feature_importance(self.model, test_X, train_X, precessor=precessor, device=self.device, categorical_variable_level_map=level_map,
                                                      feature_names=self.dataset.feature_names, plot=plot, num_features_display=num_features_display, **kwargs)

    def _encode_categoricals_for_lime(self, train_X, test_X, level_map):
        """Integer-encode the categorical columns of ``train_X`` and ``test_X`` in place.

        Known levels from ``level_map`` get codes 0..k-1. Levels that appear only in
        ``test_X`` get the next codes. Returns the inverse mapping consumed by
        ``DummyTransformer``:
        ``{feature: {'datatype': original dtype, 'mapping': {code: level}}}``.
        """
        inverse_mapping = {}
        for feature_index, level_array in level_map.items():
            feature = self.dataset.feature_names[feature_index]
            levels = list(level_array)
            for lev in test_X[feature].unique().tolist():
                if lev not in levels:
                    levels.append(lev)
            mapping = {lev: i for i, lev in enumerate(levels)}
            # Capture the dtype before the column is overwritten with integer codes.
            inverse_mapping[feature] = {
                'datatype': train_X[feature].dtype,
                'mapping': {i: lev for i, lev in enumerate(levels)},
            }
            train_X[feature] = train_X[feature].map(mapping).astype(int)
            test_X[feature] = test_X[feature].map(mapping).astype(int)
        return inverse_mapping

    def explainability_whynot(self, test_X, stopping_threshold=0.5, plot=True, **kwargs):
        """
        Generate contrastive explanations: what would need to change to flip the outcome.

        Parameters
        ----------
        test_X : pandas.Series, pandas.DataFrame, numpy.ndarray
            Test data for explainability.
        stopping_threshold : float in range (0, 1)
            Decision boundary for flipping the outcome.
        plot : bool
            Whether to plot the figure, default is True.

        Returns
        -------
        List
            A list of contrastive examples whose length equals the number of test
            samples. Each element is a dataframe containing the contrastive samples.

        Notes
        -----
        Moves the model to CPU; other methods of this class move it back to
        ``self.device`` before use.
        """
        test_X = self._to_dataframe(test_X)
        self.model.to("cpu")
        return deep_dice_get_counterfactual(self.model, test_X, precessor=self.dataset.transform_data, categorical_features=self.dataset.categorical_variables,
                                            feature_names=self.dataset.feature_names, outcome_name=self.dataset.outcome_name,
                                            features=self.dataset.features, stopping_threshold=stopping_threshold, plot=plot, **kwargs)

    def explainability_whatif(self, test_X, features_to_vary='all', permitted_range=None, stopping_threshold=0.5, plot=True, **kwargs):
        """
        Generate recommendations through contrastive explanations, focused on actionable features.

        Parameters
        ----------
        test_X : pandas.Series, pandas.DataFrame, numpy.ndarray
            Test data for explainability.
        features_to_vary : list or str 'all'
            Features that are permitted to change. Default is 'all'.
        permitted_range : dict of list
            Range each feature may take, keyed by feature name. For numeric features,
            give [min, max]; for categorical features, give the permitted levels.
        stopping_threshold : float in range (0, 1)
            Decision boundary for flipping the outcome.
        plot : bool
            Whether to plot the figure, default is True.

        Returns
        -------
        List
            A list of contrastive examples whose length equals the number of test
            samples. Each element is a dataframe containing the contrastive samples.

        Notes
        -----
        Moves the model to CPU; other methods of this class move it back to
        ``self.device`` before use.
        """
        test_X = self._to_dataframe(test_X)
        self.model.to("cpu")
        return deep_dice_get_counterfactual(self.model, test_X, precessor=self.dataset.transform_data, categorical_features=self.dataset.categorical_variables,
                                            feature_names=self.dataset.feature_names, outcome_name=self.dataset.outcome_name,
                                            features=self.dataset.features, features_to_vary=features_to_vary, permitted_range=permitted_range,
                                            stopping_threshold=stopping_threshold, plot=plot, **kwargs)

    @staticmethod
    def explainability_how(content_json_path, output_path):
        """
        Generate an HTML model card from a JSON content file.

        Parameters
        ----------
        content_json_path : str
            JSON path of the model card content.
        output_path : str
            Output path of the HTML model card.

        Returns
        -------
        None. If ``content_json_path`` does not exist, an error message is printed
        and nothing is written.
        """
        if not os.path.exists(content_json_path):
            print(f'Error: file {content_json_path} does not exist.')
            return
        generator = ModelCardGenerator(content_json_path)
        generator.save_html(output_path)

    def explainability_whatelse(self, test_X, method='feature_matching', n_similar=100, plot=True, matching_func: Callable[[np.ndarray, np.ndarray], bool]=None,
                                embedding_fun: Callable[[Union[torch.Tensor, List[torch.Tensor]], nn.Module], torch.Tensor]=None):
        """
        Identify training patients similar to the test sample(s).

        Supports feature-importance-based similarity (cosine similarity and euclidean
        distance) and principal feature matching. Provides the model's predictions
        and the actual event ratios for the similar patients.

        Parameters
        ----------
        test_X : pandas.Series, pandas.DataFrame, numpy.ndarray
            Test data for explainability.
        method : str
            One of 'cosine_similarity', 'euclidean_distince_similarity', 'feature_matching'.
        n_similar : int
            Number of similar patients to identify. Not used by 'feature_matching'.
        plot : bool
            Whether to plot the figure, default is True.
        matching_func : Callable[[np.ndarray, np.ndarray], bool], optional
            Matching function for the principal feature matching method.
        embedding_fun : Callable[[Tensor | list[Tensor], nn.Module], Tensor], optional
            Embedding function forwarded to the cosine / euclidean methods.

        Returns
        -------
        pandas.DataFrame
            Similarity scores.
        Dict
            The model's predictions and actual event ratios for similar patients.

        Raises
        ------
        AssertionError
            If ``method`` is not one of the supported names.
        """
        if method not in ('cosine_similarity', 'euclidean_distince_similarity', 'feature_matching'):
            raise AssertionError(f"Error: Invalid method: {method}. Allowed methods are cosine_similarity, euclidean_distince_similarity, feature_matching.")
        test_X = self._to_dataframe(test_X)
        # Keep model and data on the same device (whynot/whatif leave the model on CPU).
        self.model.to(self.device)
        # NOTE(review): these helpers receive ``transform_func`` while the other methods
        # use ``transform_data`` -- confirm this difference is intended.
        if method == 'feature_matching':
            return deep_principle_feature_matching_similar_patients(self.model, test_X, self.dataset.train_x, self.dataset.train_y,
                                                                    precessor=self.dataset.transform_func, device=self.device, matching_func=matching_func, plot=plot)
        if method == 'cosine_similarity':
            return deep_feature_importance_based_cosine_similar_patients(self.model, test_X, self.dataset.train_x, self.dataset.train_y,
                                                                         precessor=self.dataset.transform_func, device=self.device, embedding_fun=embedding_fun, n_similar=n_similar, plot=plot)
        return deep_feature_importance_based_distance_similar_patients(self.model, test_X, self.dataset.train_x, self.dataset.train_y,
                                                                       precessor=self.dataset.transform_func, device=self.device, embedding_fun=embedding_fun, n_similar=n_similar, plot=plot)