diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 6fd65a5f..5fe72e79 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -63,14 +63,18 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d else: # compute the new ranges based on user input self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range) + # Do predictions once on the query_instance and reuse across to reduce the number + # inferences. + model_predictions = self.predict_fn(query_instance) + # number of output nodes of ML model self.num_output_nodes = None if self.model.model_type == "classifier": - self.num_output_nodes = self.predict_fn(query_instance).shape[1] + self.num_output_nodes = model_predictions.shape[1] # query_instance need no transformation for generating CFs using random sampling. # find the predicted value of query_instance - test_pred = self.predict_fn(query_instance)[0] + test_pred = model_predictions[0] if self.model.model_type == 'classifier': self.target_cf_class = self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes) elif self.model.model_type == 'regressor':