From 509802d8c6321fde6910492d86dfe4d65ca11234 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Tue, 27 Apr 2021 23:57:47 -0700 Subject: [PATCH] Reduce number inferences in dice random (#127) * Redice number of inference calls DiceRandom Signed-off-by: gaugup <gaugup@microsoft.com> * Added comment Signed-off-by: gaugup <gaugup@microsoft.com> * corrected typo to model_predictions Co-authored-by: Amit Sharma <amit_sharma@live.com> --- dice_ml/explainer_interfaces/dice_random.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 6fd65a5f..5fe72e79 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -63,14 +63,18 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d else: # compute the new ranges based on user input self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range) + # Do predictions once on the query_instance and reuse across to reduce the number + # inferences. + model_predictions = self.predict_fn(query_instance) + # number of output nodes of ML model self.num_output_nodes = None if self.model.model_type == "classifier": - self.num_output_nodes = self.predict_fn(query_instance).shape[1] + self.num_output_nodes = model_predictions.shape[1] # query_instance need no transformation for generating CFs using random sampling. # find the predicted value of query_instance - test_pred = self.predict_fn(query_instance)[0] + test_pred = model_predictions[0] if self.model.model_type == 'classifier': self.target_cf_class = self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes) elif self.model.model_type == 'regressor':