Skip to content

Commit

Permalink
added userconfig exception for totalcfs<1 (#125)
Browse files Browse the repository at this point in the history
* added userconfig exception for totalcfs<1

* added unit test for explainer base
  • Loading branch information
amit-sharma authored Apr 23, 2021
1 parent f95c777 commit e69f7ca
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dice_ml/explainer_interfaces/dice_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d
m, 'min %02d' % s, 'sec')
else:
if self.total_cfs_found == 0 :
print('No Counterfactuals found for the given configuation, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec')
print('No Counterfactuals found for the given configuration, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Only %d (required %d) Diverse Counterfactuals found for the given configuration, perhaps try with different parameters...' % (self.total_cfs_found, self.total_CFs), '; total time taken: %02d' % m, 'min %02d' % s, 'sec')

Expand Down
3 changes: 2 additions & 1 deletion dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def generate_counterfactuals(self, query_instances, total_CFs,
:returns: A CounterfactualExplanations object that contains the list of
counterfactual examples per query_instance as one of its attributes.
"""

if total_CFs <= 0:
raise UserConfigValidationException("The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
cf_examples_arr = []
query_instances_list = []
if isinstance(query_instances, pd.DataFrame):
Expand Down
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,40 @@
import dice_ml
from dice_ml.utils import helpers

@pytest.fixture
def binary_classification_exp_object(method="random"):
backend = 'sklearn'
dataset = helpers.load_custom_testing_dataset_binary()
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_binary()
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
exp = dice_ml.Dice(d, m, method=method)
return exp


@pytest.fixture
def multi_classification_exp_object(method="random"):
backend = 'sklearn'
dataset = helpers.load_custom_testing_dataset_multiclass()
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_multiclass()
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
exp = dice_ml.Dice(d, m, method=method)
return exp


@pytest.fixture
def regression_exp_object(method="random"):
backend = 'sklearn'
dataset = helpers.load_custom_testing_dataset_regression()
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_regression()
m = dice_ml.Model(model_path=ML_modelpath, backend=backend, model_type='regressor')
exp = dice_ml.Dice(d, m, method=method)
return exp



@pytest.fixture
def public_data_object():
"""
Expand Down
37 changes: 37 additions & 0 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
from dice_ml.utils.exception import UserConfigValidationException


class TestExplainerBaseBinaryClassification:

@pytest.mark.parametrize("desired_class, binary_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['binary_classification_exp_object'])
def test_zero_totalcfs(self, desired_class, binary_classification_exp_object, sample_custom_query_1):
exp = binary_classification_exp_object # explainer object
with pytest.raises(UserConfigValidationException):
exp.generate_counterfactuals(
query_instances=[sample_custom_query_1],
total_CFs=0,
desired_class=desired_class)

class TestExplainerBaseMultiClassClassification:

@pytest.mark.parametrize("desired_class, multi_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['multi_classification_exp_object'])
def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sample_custom_query_1):
exp = multi_classification_exp_object # explainer object
with pytest.raises(UserConfigValidationException):
exp.generate_counterfactuals(
query_instances=[sample_custom_query_1],
total_CFs=0,
desired_class=desired_class)


class TestExplainerBaseRegression:

@pytest.mark.parametrize("desired_class, regression_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['regression_exp_object'])
def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1):
exp = regression_exp_object # explainer object
with pytest.raises(UserConfigValidationException):
exp.generate_counterfactuals(
query_instances=[sample_custom_query_1],
total_CFs=0,
desired_class=desired_class)

0 comments on commit e69f7ca

Please sign in to comment.