    SmartTag,
    SmartTagFamily,
)
+from azimuth.utils.dataset_operations import (
+    filter_dataset_split,
+    get_confidences_from_ds,
+    get_outcomes_from_ds,
+    get_predictions_from_ds,
+)
from azimuth.utils.ml.ece import compute_ece_from_bins
from azimuth.utils.ml.model_performance import sorted_by_utterance_count_with_last
from azimuth.utils.validation import assert_not_none
@@ -58,65 +64,78 @@ def first_value(di: Optional[Dict]) -> Optional[float]:
class MetricsModule(FilterableModule[ModelContractConfig]):
    """Computes different metrics on each dataset split."""

-    def compute_on_dataset_split(self) -> List[MetricsModuleResponse]:  # type: ignore
-        ds: Dataset = assert_not_none(self.get_dataset_split())
-        indices = self.get_indices()
-        if len(indices) == 0:
+    def compute_metrics(self, ds: Dataset) -> List[MetricsModuleResponse]:
+        """Compute all metrics on the specified dataset split.
+
+        Note: This lives outside of `compute_on_dataset_split()` so that it can be called
+            directly, without going through the module call and dataset filtering.
+
+        Args:
+            ds: Dataset Split for which to compute metrics.
+
+        Returns:
+            MetricsModuleResponse with all metrics.
+        """
+        if len(ds) == 0:
            # Nothing to do, we return an empty response.
            return [BASE_RESPONSE]
-
-        utterance_count = len(indices)
-        outcome_count = Counter(self._get_outcomes_from_ds())
-        outcome_count.update({outcome: 0 for outcome in ALL_OUTCOMES})
-
-        # Compute ECE
-        conf_hist_mod = ConfidenceHistogramModule(
-            dataset_split_name=self.dataset_split_name,
-            config=self.config,
-            mod_options=self.mod_options,
-        )
-        bins = conf_hist_mod.compute_on_dataset_split()[0].bins
-        ece, acc, expected = compute_ece_from_bins(bins)
-        count_per_bin = [sum(b.outcome_count.values()) for b in bins]
-
-        metric_values = {}
-        dm = self.get_dataset_split_manager()
-        for metric_name, metric_obj_def in self.config.metrics.items():
-            met: Metric = self.artifact_manager.get_metric(
-                self.config,
-                metric_name,
-                label_list=dm.get_class_names(),
-                rejection_class_idx=dm.rejection_class_idx,
-                force_kwargs=True,  # Set True here as load_metrics has **kwargs.
+        else:
+            utterance_count = len(ds)
+            outcome_count = Counter(
+                get_outcomes_from_ds(ds, self.mod_options.without_postprocessing)
            )
-            accept_probabilities = "probabilities" in inspect.signature(met._compute).parameters
-            extra_kwargs = (
-                dict(probabilities=self.make_probabilities()) if accept_probabilities else {}
-            )
-            extra_kwargs.update(metric_obj_def.additional_kwargs)
-            with warnings.catch_warnings():
-                # Ignore warnings such as
-                # UndefinedMetricWarning: Precision is ill-defined and being set to 0.0
-                warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-                metric_values[metric_name] = assert_not_none(
-                    first_value(
-                        met.compute(
-                            predictions=self._get_predictions_from_ds(),
-                            references=ds[self.config.columns.label],
-                            **extra_kwargs,
+            outcome_count.update({outcome: 0 for outcome in ALL_OUTCOMES})
+
+            # Compute ECE
+            bins = ConfidenceHistogramModule.get_bins(ds, self.mod_options.without_postprocessing)
+            ece, acc, expected = compute_ece_from_bins(bins)
+            count_per_bin = [sum(b.outcome_count.values()) for b in bins]
+
+            metric_values = {}
+            dm = self.get_dataset_split_manager()
+            for metric_name, metric_obj_def in self.config.metrics.items():
+                met: Metric = self.artifact_manager.get_metric(
+                    self.config,
+                    metric_name,
+                    label_list=dm.get_class_names(),
+                    rejection_class_idx=dm.rejection_class_idx,
+                    force_kwargs=True,  # Set True here as load_metrics has **kwargs.
+                )
+                accept_probabilities = "probabilities" in inspect.signature(met._compute).parameters
+                extra_kwargs = (
+                    dict(probabilities=self.make_probabilities()) if accept_probabilities else {}
+                )
+                extra_kwargs.update(metric_obj_def.additional_kwargs)
+                with warnings.catch_warnings():
+                    # Ignore warnings such as
+                    # UndefinedMetricWarning: Precision is ill-defined and being set to 0.0
+                    warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+                    metric_values[metric_name] = assert_not_none(
+                        first_value(
+                            met.compute(
+                                predictions=get_predictions_from_ds(
+                                    ds, self.mod_options.without_postprocessing
+                                ),
+                                references=ds[self.config.columns.label],
+                                **extra_kwargs,
+                            )
                        )
                    )
+
+            return [
+                MetricsModuleResponse(
+                    outcome_count=outcome_count,
+                    ece=ece,
+                    ece_plot_args=(acc, expected, ece, count_per_bin),
+                    utterance_count=utterance_count,
+                    custom_metrics=metric_values,
                )
+            ]

-        return [
-            MetricsModuleResponse(
-                outcome_count=outcome_count,
-                ece=ece,
-                ece_plot_args=(acc, expected, ece, count_per_bin),
-                utterance_count=utterance_count,
-                custom_metrics=metric_values,
-            )
-        ]
+    def compute_on_dataset_split(self) -> List[MetricsModuleResponse]:  # type: ignore
+        """Computes different metrics according to the specified module options."""
+        ds: Dataset = assert_not_none(self.get_dataset_split())
+        return self.compute_metrics(ds)

    @staticmethod
    def module_to_api_response(res: List[MetricsModuleResponse]) -> List[MetricsAPIResponse]:
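The point of this refactor, as the new docstring notes, is that metrics can now be computed on an arbitrary, already-filtered `Dataset` without routing filters through `mod_options` and re-running the module. A minimal sketch of that calling pattern, using only names visible in this diff; the `config`, `dataset_split_name`, and `filters` objects are assumed to exist on the caller's side and are illustrative:

```python
# Illustrative sketch only, not part of this commit. Assumes `config`,
# `dataset_split_name`, and `filters` are already available in the caller.
from azimuth.utils.dataset_operations import filter_dataset_split

module = MetricsModule(dataset_split_name=dataset_split_name, config=config)

# Filter the split up front, then compute metrics directly on the filtered Dataset,
# instead of passing the filters through mod_options and calling the module.
ds = module.get_dataset_split()
ds_filtered = filter_dataset_split(ds, filters, config=config)
response = module.compute_metrics(ds_filtered)[0]

print(response.utterance_count, response.ece)
```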
@@ -150,7 +169,7 @@ def make_probabilities(self) -> np.ndarray:
        probs = np.zeros([len(ds), num_classes])
        for idx, (confidences, predictions) in enumerate(
            zip(
-                self._get_confidences_from_ds(),
+                get_confidences_from_ds(ds, self.mod_options.without_postprocessing),
                ds[DatasetColumn.model_predictions],
            )
        ):
@@ -173,14 +192,14 @@ def get_metrics_for_filter(
        Returns:
            Metrics for all provided filters.
        """
+        ds = self.get_dataset_split()
        accumulator = []
        for filter_value, filters in filters_dict.items():
-            met_module = MetricsModule(
+            ds_filtered = filter_dataset_split(ds, filters, config=self.config)
+            metric = MetricsModule(
                dataset_split_name=self.dataset_split_name,
                config=self.config,
-                mod_options=self.mod_options.copy(update={"filters": filters}),
-            )
-            metric = met_module.compute_on_dataset_split()[0]
+            ).compute_metrics(ds_filtered)[0]
            accumulator.append(MetricsPerFilterValue(**metric.dict(), filter_value=filter_value))
        return accumulator
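One behaviour worth noting in the new `compute_metrics`: an empty (for example, fully filtered-out) split short-circuits to `[BASE_RESPONSE]` instead of running the ECE and metric loop. A hypothetical pytest-style check of that path, assuming fixtures `metrics_module` and `empty_dataset_split` exist (they are not part of this commit):

```python
# Hypothetical test sketch; fixture names are illustrative, not from the repo.
def test_compute_metrics_on_empty_split(metrics_module, empty_dataset_split):
    # A split with no rows should return the base (empty) response without
    # computing ECE or any custom metric.
    assert metrics_module.compute_metrics(empty_dataset_split) == [BASE_RESPONSE]
```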