From 509802d8c6321fde6910492d86dfe4d65ca11234 Mon Sep 17 00:00:00 2001
From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com>
Date: Tue, 27 Apr 2021 23:57:47 -0700
Subject: [PATCH] Reduce number inferences in dice random (#127)

* Redice number of inference calls DiceRandom

Signed-off-by: gaugup <gaugup@microsoft.com>

* Added comment

Signed-off-by: gaugup <gaugup@microsoft.com>

* corrected typo to model_predictions

Co-authored-by: Amit Sharma <amit_sharma@live.com>
---
 dice_ml/explainer_interfaces/dice_random.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py
index 6fd65a5f..5fe72e79 100644
--- a/dice_ml/explainer_interfaces/dice_random.py
+++ b/dice_ml/explainer_interfaces/dice_random.py
@@ -63,14 +63,18 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range,  d
         else: # compute the new ranges based on user input
             self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range)
 
+        # Do predictions once on the query_instance and reuse across to reduce the number
+        # inferences.
+        model_predictions = self.predict_fn(query_instance)
+
         # number of output nodes of ML model
         self.num_output_nodes = None
         if self.model.model_type == "classifier":
-            self.num_output_nodes = self.predict_fn(query_instance).shape[1]
+            self.num_output_nodes = model_predictions.shape[1]
 
         # query_instance need no transformation for generating CFs using random sampling.
         # find the predicted value of query_instance
-        test_pred = self.predict_fn(query_instance)[0]
+        test_pred = model_predictions[0]
         if self.model.model_type == 'classifier':
             self.target_cf_class = self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)
         elif self.model.model_type == 'regressor':