diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 00a27e9e..6fd65a5f 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -163,7 +163,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d m, 'min %02d' % s, 'sec') else: if self.total_cfs_found == 0 : - print('No Counterfactuals found for the given configuation, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec') + print('No Counterfactuals found for the given configuration, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Only %d (required %d) Diverse Counterfactuals found for the given configuration, perhaps try with different parameters...' % (self.total_cfs_found, self.total_CFs), '; total time taken: %02d' % m, 'min %02d' % s, 'sec') diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 67d1e886..aaeb2408 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -68,7 +68,8 @@ def generate_counterfactuals(self, query_instances, total_CFs, :returns: A CounterfactualExplanations object that contains the list of counterfactual examples per query_instance as one of its attributes. """ - + if total_CFs <= 0: + raise UserConfigValidationException("The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") cf_examples_arr = [] query_instances_list = [] if isinstance(query_instances, pd.DataFrame): diff --git a/tests/conftest.py b/tests/conftest.py index 72f21c93..16c01d85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,40 @@ import dice_ml from dice_ml.utils import helpers +@pytest.fixture +def binary_classification_exp_object(method="random"): + backend = 'sklearn' + dataset = helpers.load_custom_testing_dataset_binary() + d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome') + ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_binary() + m = dice_ml.Model(model_path=ML_modelpath, backend=backend) + exp = dice_ml.Dice(d, m, method=method) + return exp + + +@pytest.fixture +def multi_classification_exp_object(method="random"): + backend = 'sklearn' + dataset = helpers.load_custom_testing_dataset_multiclass() + d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome') + ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_multiclass() + m = dice_ml.Model(model_path=ML_modelpath, backend=backend) + exp = dice_ml.Dice(d, m, method=method) + return exp + + +@pytest.fixture +def regression_exp_object(method="random"): + backend = 'sklearn' + dataset = helpers.load_custom_testing_dataset_regression() + d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome') + ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_regression() + m = dice_ml.Model(model_path=ML_modelpath, backend=backend, model_type='regressor') + exp = dice_ml.Dice(d, m, method=method) + return exp + + + @pytest.fixture def public_data_object(): """ diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py new file mode 100644 index 00000000..2ae05b2f --- /dev/null +++ b/tests/test_dice_interface/test_explainer_base.py @@ -0,0 +1,37 @@ +import pytest +from dice_ml.utils.exception import UserConfigValidationException + + +class TestExplainerBaseBinaryClassification: + + @pytest.mark.parametrize("desired_class, binary_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['binary_classification_exp_object']) + def test_zero_totalcfs(self, desired_class, binary_classification_exp_object, sample_custom_query_1): + exp = binary_classification_exp_object # explainer object + with pytest.raises(UserConfigValidationException): + exp.generate_counterfactuals( + query_instances=[sample_custom_query_1], + total_CFs=0, + desired_class=desired_class) + +class TestExplainerBaseMultiClassClassification: + + @pytest.mark.parametrize("desired_class, multi_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['multi_classification_exp_object']) + def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sample_custom_query_1): + exp = multi_classification_exp_object # explainer object + with pytest.raises(UserConfigValidationException): + exp.generate_counterfactuals( + query_instances=[sample_custom_query_1], + total_CFs=0, + desired_class=desired_class) + + +class TestExplainerBaseRegression: + + @pytest.mark.parametrize("desired_class, regression_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['regression_exp_object']) + def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1): + exp = regression_exp_object # explainer object + with pytest.raises(UserConfigValidationException): + exp.generate_counterfactuals( + query_instances=[sample_custom_query_1], + total_CFs=0, + desired_class=desired_class)