Skip to content

Commit

Permalink
feat: SHAP support for symmetric CatBoost models (#1619)
Browse files Browse the repository at this point in the history
* chore: remove redundant SHAP checks

* fix: convert_model is a staticmethod

* fix: Create GBTDAALRegressor if necessary

* Add SHAP kwargs to GBTDAALRegressor

* fix SHAP support for catboost symmetric trees

* Skip SHAP calculation for a month at a time

* Fix is_classification property

* Update expected warning message in test

* Add GBTDAALClassifier fallback

* fix version check

* Disable CatBoost SHAP checks for Jan
  • Loading branch information
ahuber21 authored Jan 9, 2024
1 parent 0e137ab commit 7db9b0f
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 146 deletions.
12 changes: 10 additions & 2 deletions daal4py/mb/model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,16 @@ def predict_proba(self, X):


def convert_model(model):
gbm = GBTDAALModel()
gbm._convert_model(model)
try:
gbm = GBTDAALModel()
gbm._convert_model(model)
except TypeError as err:
if "Only GBTDAALRegressor can be created" in str(err):
gbm = d4p.sklearn.ensemble.GBTDAALRegressor.convert_model(model)
elif "Only GBTDAALClassifier can be created" in str(err):
gbm = d4p.sklearn.ensemble.GBTDAALClassifier.convert_model(model)
else:
raise

gbm._is_regression = isinstance(gbm.daal_model_, d4p.gbt_regression_model)

Expand Down
6 changes: 4 additions & 2 deletions daal4py/sklearn/ensemble/GBTDAAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def predict_log_proba(self, X):

return proba

@staticmethod
def convert_model(model):
gbm = GBTDAALClassifier()
gbm._convert_model(model)
Expand Down Expand Up @@ -292,7 +293,7 @@ def fit(self, X, y):
return self

@run_with_n_jobs
def predict(self, X):
def predict(self, X, pred_contribs=False, pred_interactions=False):
# Input validation
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
Expand All @@ -303,8 +304,9 @@ def predict(self, X):
check_is_fitted(self, ["n_features_in_"])

fptype = getFPType(X)
return self._predict_regression(X, fptype)
return self._predict_regression(X, fptype, pred_contribs, pred_interactions)

@staticmethod
def convert_model(model):
gbm = GBTDAALRegressor()
gbm._convert_model(model)
Expand Down
Loading

0 comments on commit 7db9b0f

Please sign in to comment.