
Commit 84ee579

Refactor metrics per filter (#250)
* Refactor metrics per filter
* Apply suggestions from code review
  Co-authored-by: Lindsay Brin <[email protected]>
* Add profiling details and dependency
* Change according to review
* Change threading fct

Co-authored-by: Lindsay Brin <[email protected]>
1 parent e9dcea0 commit 84ee579

14 files changed: +284, -127 lines

.gitignore (+1)

@@ -9,6 +9,7 @@ cache
 *.swp
 !poetry.lock
 !yarn.lock
+tests/logs.txt
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

azimuth/modules/base_classes/aggregation_module.py (+3, -43)

@@ -3,14 +3,13 @@
 # in the root directory of this source tree.
 import time
 from abc import ABC
-from typing import List, Optional, cast
+from typing import List, Optional
 
 from datasets import Dataset
 
 from azimuth.modules.base_classes import ConfigScope, ExpirableMixin, Module
-from azimuth.types import DatasetColumn, DatasetSplitName, ModuleOptions, ModuleResponse
-from azimuth.types.outcomes import OutcomeName
-from azimuth.utils.filtering import filter_dataset_split
+from azimuth.types import DatasetSplitName, ModuleOptions, ModuleResponse
+from azimuth.utils.dataset_operations import filter_dataset_split
 
 
 class AggregationModule(Module[ConfigScope], ABC):

@@ -62,42 +61,3 @@ def get_dataset_split(self, name: DatasetSplitName = None) -> Dataset:
             config=self.config,
             without_postprocessing=self.mod_options.without_postprocessing,
         )
-
-    def _get_predictions_from_ds(self) -> List[int]:
-        """Get predicted classes according to the module options (with or without postprocessing).
-
-        Returns: List of Predictions
-        """
-        ds = self.get_dataset_split()
-        if self.mod_options.without_postprocessing:
-            return cast(List[int], [preds[0] for preds in ds[DatasetColumn.model_predictions]])
-        else:
-            return cast(List[int], ds[DatasetColumn.postprocessed_prediction])
-
-    def _get_confidences_from_ds(self) -> List[List[float]]:
-        """Get confidences according to the module options (with or without postprocessing).
-
-        Notes: Confidences are sorted according to their values (not the class id).
-
-        Returns: List of Confidences
-        """
-        ds = self.get_dataset_split()
-        confidences = (
-            ds[DatasetColumn.model_confidences]
-            if self.mod_options.without_postprocessing
-            else ds[DatasetColumn.postprocessed_confidences]
-        )
-        return cast(List[List[float]], confidences)
-
-    def _get_outcomes_from_ds(self) -> List[OutcomeName]:
-        """Get outcomes according to the module options (with or without postprocessing).
-
-        Returns: List of Outcomes
-        """
-        ds = self.get_dataset_split()
-        outcomes = (
-            ds[DatasetColumn.model_outcome]
-            if self.mod_options.without_postprocessing
-            else ds[DatasetColumn.postprocessed_outcome]
-        )
-        return cast(List[OutcomeName], outcomes)
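The three removed helpers are replaced by module-level functions in the new azimuth/utils/dataset_operations module, which take the dataset split and a without_postprocessing flag explicitly instead of reading self.mod_options. That file is not part of this view, so the sketch below is only an approximation reconstructed from the removed method bodies above and from the call sites added in the files further down:

# Approximate shape of the new helpers in azimuth/utils/dataset_operations.py
# (file not shown in this diff); names and signatures inferred from call sites.
from typing import List, cast

from datasets import Dataset

from azimuth.types import DatasetColumn
from azimuth.types.outcomes import OutcomeName


def get_predictions_from_ds(ds: Dataset, without_postprocessing: bool = False) -> List[int]:
    """Predicted classes, with or without pipeline postprocessing."""
    if without_postprocessing:
        return cast(List[int], [preds[0] for preds in ds[DatasetColumn.model_predictions]])
    return cast(List[int], ds[DatasetColumn.postprocessed_prediction])


def get_outcomes_from_ds(ds: Dataset, without_postprocessing: bool = False) -> List[OutcomeName]:
    """Outcomes, with or without pipeline postprocessing."""
    column = (
        DatasetColumn.model_outcome
        if without_postprocessing
        else DatasetColumn.postprocessed_outcome
    )
    return cast(List[OutcomeName], ds[column])

# get_confidences_from_ds and filter_dataset_split presumably follow the same pattern,
# with filter_dataset_split moved here from azimuth.utils.filtering.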

azimuth/modules/model_performance/confidence_binning.py (+46, -11)

@@ -16,6 +16,10 @@
     ConfidenceHistogramResponse,
 )
 from azimuth.types.outcomes import ALL_OUTCOMES, OutcomeName
+from azimuth.utils.dataset_operations import (
+    get_confidences_from_ds,
+    get_outcomes_from_ds,
+)
 from azimuth.utils.validation import assert_not_none
 
 CONFIDENCE_BINS_COUNT = 20

@@ -24,23 +28,38 @@
 class ConfidenceHistogramModule(FilterableModule[ModelContractConfig]):
     """Return a confidence histogram of the predictions."""
 
-    def get_outcome_mask(self, outcome: OutcomeName) -> List[bool]:
-        return [utterance_outcome == outcome for utterance_outcome in self._get_outcomes_from_ds()]
-
-    def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]:  # type: ignore
-        """Compute the confidence histogram with CONFIDENCE_BINS_COUNT bins on the dataset split.
+    @staticmethod
+    def get_outcome_mask(
+        ds, outcome: OutcomeName, without_postprocessing: bool = False
+    ) -> List[bool]:
+        return [
+            utterance_outcome == outcome
+            for utterance_outcome in get_outcomes_from_ds(ds, without_postprocessing)
+        ]
+
+    @classmethod
+    def get_bins(
+        cls, ds: Dataset, without_postprocessing: bool = False
+    ) -> List[ConfidenceBinDetails]:
+        """Compute the bins on the specified dataset split.
+
+        Note: This lives outside of `compute_on_dataset_split()` so that it can be called without
+            going through calling the module and filtering the dataset.
+
+        Args:
+            ds: Dataset Split on which to compute bins
+            without_postprocessing: Whether to use outcomes and confidences without pipeline
+                postprocessing
 
         Returns:
             List of the confidence bins with their confidence and the outcome count.
-
         """
-        bins = np.linspace(0, 1, CONFIDENCE_BINS_COUNT + 1)
 
-        ds: Dataset = assert_not_none(self.get_dataset_split())
+        bins = np.linspace(0, 1, CONFIDENCE_BINS_COUNT + 1)
 
         if len(ds) > 0:
             # Get the bin index for each prediction.
-            confidences = np.max(self._get_confidences_from_ds(), axis=1)
+            confidences = np.max(get_confidences_from_ds(ds, without_postprocessing), axis=1)
             bin_indices = np.floor(confidences * CONFIDENCE_BINS_COUNT)
 
             # Create the records. We drop the last bin as it's the maximum.

@@ -50,7 +69,7 @@ def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]:  # type
                 outcome_count = defaultdict(int)
                 for outcome in ALL_OUTCOMES:
                     outcome_count[outcome] = np.logical_and(
-                        bin_mask, self.get_outcome_mask(outcome)
+                        bin_mask, cls.get_outcome_mask(ds, outcome, without_postprocessing)
                     ).sum()
                 mean_conf = (
                     0 if bin_mask.sum() == 0 else np.nan_to_num(confidences[bin_mask].mean())

@@ -75,7 +94,23 @@ def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]:  # type
                 for bin_index, bin_min_value in enumerate(bins[:-1])
             ]
 
-        return [ConfidenceHistogramResponse(bins=result, confidence_threshold=self.get_threshold())]
+        return result
+
+    def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]:  # type: ignore
+        """Compute the confidence histogram with CONFIDENCE_BINS_COUNT bins on the dataset split.
+
+        Returns:
+            Confidence bins and threshold.
+
+        """
+        ds: Dataset = assert_not_none(self.get_dataset_split())
+
+        return [
+            ConfidenceHistogramResponse(
+                bins=self.get_bins(ds, self.mod_options.without_postprocessing),
+                confidence_threshold=self.get_threshold(),
+            )
+        ]
 
 
 class ConfidenceBinIndexModule(DatasetResultModule[ModelContractConfig]):
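Exposing get_bins as a classmethod lets callers bin an already-filtered split without instantiating and dispatching the module, which is exactly what MetricsModule relies on in metrics.py below. A minimal usage sketch, assuming ds is any (possibly filtered) Dataset split with prediction columns:

# Usage sketch: bin a pre-filtered split directly, then derive ECE as metrics.py does.
from azimuth.utils.ml.ece import compute_ece_from_bins

bins = ConfidenceHistogramModule.get_bins(ds, without_postprocessing=False)
ece, acc, expected = compute_ece_from_bins(bins)
count_per_bin = [sum(b.outcome_count.values()) for b in bins]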

azimuth/modules/model_performance/confusion_matrix.py (+2, -1)

@@ -13,6 +13,7 @@
 from azimuth.config import ModelContractConfig
 from azimuth.modules.base_classes import FilterableModule
 from azimuth.types.model_performance import ConfusionMatrixResponse
+from azimuth.utils.dataset_operations import get_predictions_from_ds
 from azimuth.utils.validation import assert_not_none
 
 MIN_CONFUSION_CUTHILL_MCKEE = 0.1

@@ -35,7 +36,7 @@ def compute_on_dataset_split(self) -> List[ConfusionMatrixResponse]:  # type: ig
         """
         ds: Dataset = assert_not_none(self.get_dataset_split())
         predictions, labels = (
-            self._get_predictions_from_ds(),
+            get_predictions_from_ds(ds, self.mod_options.without_postprocessing),
             ds[self.config.columns.label],
         )
         ds_mng = self.get_dataset_split_manager()
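For intuition, the predictions/labels pair extracted here is the raw input to the confusion matrix; a minimal stand-in using scikit-learn (the module's actual matrix computation is outside this diff, so this is only an illustration, not the project's implementation):

# Illustration only; the real computation in this module is not shown in the diff.
from sklearn.metrics import confusion_matrix

class_ids = list(range(len(ds_mng.get_class_names())))  # get_class_names() as used in metrics.py
cm = confusion_matrix(labels, predictions, labels=class_ids)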

azimuth/modules/model_performance/metrics.py (+76, -57)

@@ -33,6 +33,12 @@
     SmartTag,
     SmartTagFamily,
 )
+from azimuth.utils.dataset_operations import (
+    filter_dataset_split,
+    get_confidences_from_ds,
+    get_outcomes_from_ds,
+    get_predictions_from_ds,
+)
 from azimuth.utils.ml.ece import compute_ece_from_bins
 from azimuth.utils.ml.model_performance import sorted_by_utterance_count_with_last
 from azimuth.utils.validation import assert_not_none

@@ -58,65 +64,78 @@ def first_value(di: Optional[Dict]) -> Optional[float]:
 class MetricsModule(FilterableModule[ModelContractConfig]):
     """Computes different metrics on each dataset split."""
 
-    def compute_on_dataset_split(self) -> List[MetricsModuleResponse]:  # type: ignore
-        ds: Dataset = assert_not_none(self.get_dataset_split())
-        indices = self.get_indices()
-        if len(indices) == 0:
+    def compute_metrics(self, ds: Dataset) -> List[MetricsModuleResponse]:
+        """Compute all metrics on the specified dataset split.
+
+        Note: This lives outside of `compute_on_dataset_split()` so that it can be called without
+            going through calling the module and filtering the dataset.
+
+        Args:
+            ds: Dataset Split for which to compute metrics.
+
+        Returns:
+            MetricsModuleResponse with all metrics.
+        """
+        if len(ds) == 0:
             # Nothing to do, we return an empty response.
             return [BASE_RESPONSE]
-
-        utterance_count = len(indices)
-        outcome_count = Counter(self._get_outcomes_from_ds())
-        outcome_count.update({outcome: 0 for outcome in ALL_OUTCOMES})
-
-        # Compute ECE
-        conf_hist_mod = ConfidenceHistogramModule(
-            dataset_split_name=self.dataset_split_name,
-            config=self.config,
-            mod_options=self.mod_options,
-        )
-        bins = conf_hist_mod.compute_on_dataset_split()[0].bins
-        ece, acc, expected = compute_ece_from_bins(bins)
-        count_per_bin = [sum(b.outcome_count.values()) for b in bins]
-
-        metric_values = {}
-        dm = self.get_dataset_split_manager()
-        for metric_name, metric_obj_def in self.config.metrics.items():
-            met: Metric = self.artifact_manager.get_metric(
-                self.config,
-                metric_name,
-                label_list=dm.get_class_names(),
-                rejection_class_idx=dm.rejection_class_idx,
-                force_kwargs=True,  # Set True here as load_metrics has **kwargs.
+        else:
+            utterance_count = len(ds)
+            outcome_count = Counter(
+                get_outcomes_from_ds(ds, self.mod_options.without_postprocessing)
             )
-            accept_probabilities = "probabilities" in inspect.signature(met._compute).parameters
-            extra_kwargs = (
-                dict(probabilities=self.make_probabilities()) if accept_probabilities else {}
-            )
-            extra_kwargs.update(metric_obj_def.additional_kwargs)
-            with warnings.catch_warnings():
-                # Ignore warnings such as
-                # UndefinedMetricWarning: Precision is ill-defined and being set to 0.0
-                warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-                metric_values[metric_name] = assert_not_none(
-                    first_value(
-                        met.compute(
-                            predictions=self._get_predictions_from_ds(),
-                            references=ds[self.config.columns.label],
-                            **extra_kwargs,
+            outcome_count.update({outcome: 0 for outcome in ALL_OUTCOMES})
+
+            # Compute ECE
+            bins = ConfidenceHistogramModule.get_bins(ds, self.mod_options.without_postprocessing)
+            ece, acc, expected = compute_ece_from_bins(bins)
+            count_per_bin = [sum(b.outcome_count.values()) for b in bins]
+
+            metric_values = {}
+            dm = self.get_dataset_split_manager()
+            for metric_name, metric_obj_def in self.config.metrics.items():
+                met: Metric = self.artifact_manager.get_metric(
+                    self.config,
+                    metric_name,
+                    label_list=dm.get_class_names(),
+                    rejection_class_idx=dm.rejection_class_idx,
+                    force_kwargs=True,  # Set True here as load_metrics has **kwargs.
+                )
+                accept_probabilities = "probabilities" in inspect.signature(met._compute).parameters
+                extra_kwargs = (
+                    dict(probabilities=self.make_probabilities()) if accept_probabilities else {}
+                )
+                extra_kwargs.update(metric_obj_def.additional_kwargs)
+                with warnings.catch_warnings():
+                    # Ignore warnings such as
+                    # UndefinedMetricWarning: Precision is ill-defined and being set to 0.0
+                    warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+                    metric_values[metric_name] = assert_not_none(
+                        first_value(
+                            met.compute(
+                                predictions=get_predictions_from_ds(
+                                    ds, self.mod_options.without_postprocessing
+                                ),
+                                references=ds[self.config.columns.label],
+                                **extra_kwargs,
+                            )
                         )
                     )
-                )
+
+            return [
+                MetricsModuleResponse(
+                    outcome_count=outcome_count,
+                    ece=ece,
+                    ece_plot_args=(acc, expected, ece, count_per_bin),
+                    utterance_count=utterance_count,
+                    custom_metrics=metric_values,
                 )
-            )
+            ]
 
-        return [
-            MetricsModuleResponse(
-                outcome_count=outcome_count,
-                ece=ece,
-                ece_plot_args=(acc, expected, ece, count_per_bin),
-                utterance_count=utterance_count,
-                custom_metrics=metric_values,
-            )
-        ]
+    def compute_on_dataset_split(self) -> List[MetricsModuleResponse]:  # type: ignore
+        """Computes different metrics according to the specified module options."""
+        ds: Dataset = assert_not_none(self.get_dataset_split())
+        return self.compute_metrics(ds)
 
     @staticmethod
     def module_to_api_response(res: List[MetricsModuleResponse]) -> List[MetricsAPIResponse]:

@@ -150,7 +169,7 @@ def make_probabilities(self) -> np.ndarray:
         probs = np.zeros([len(ds), num_classes])
         for idx, (confidences, predictions) in enumerate(
             zip(
-                self._get_confidences_from_ds(),
+                get_confidences_from_ds(ds, self.mod_options.without_postprocessing),
                 ds[DatasetColumn.model_predictions],
             )
         ):

@@ -173,14 +192,14 @@ def get_metrics_for_filter(
         Returns:
             Metrics for all provided filters.
         """
+        ds = self.get_dataset_split()
         accumulator = []
         for filter_value, filters in filters_dict.items():
-            met_module = MetricsModule(
+            ds_filtered = filter_dataset_split(ds, filters, config=self.config)
+            metric = MetricsModule(
                 dataset_split_name=self.dataset_split_name,
                 config=self.config,
-                mod_options=self.mod_options.copy(update={"filters": filters}),
-            )
-            metric = met_module.compute_on_dataset_split()[0]
+            ).compute_metrics(ds_filtered)[0]
             accumulator.append(MetricsPerFilterValue(**metric.dict(), filter_value=filter_value))
         return accumulator
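The net effect of this file's changes is that per-filter metrics no longer round-trip through a fully configured child module per filter: get_metrics_for_filter filters the split once with filter_dataset_split and hands the result straight to compute_metrics. A hypothetical standalone use of the same path, assuming a loaded config, a DatasetFilters instance filters, and that DatasetSplitName exposes an eval member (none of these are shown in the diff):

# Hypothetical direct call; names other than those appearing in the diff are assumptions.
module = MetricsModule(dataset_split_name=DatasetSplitName.eval, config=config)
slice_ds = filter_dataset_split(module.get_dataset_split(), filters, config=config)
response = module.compute_metrics(slice_ds)[0]
print(response.utterance_count, response.ece, response.custom_metrics)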

azimuth/modules/model_performance/outcome_count.py (+10, -3)

@@ -23,6 +23,7 @@
     SMART_TAGS_FAMILY_MAPPING,
     SmartTag,
 )
+from azimuth.utils.dataset_operations import get_outcomes_from_ds
 from azimuth.utils.ml.model_performance import (
     sorted_by_utterance_count,
     sorted_by_utterance_count_with_last,

@@ -50,7 +51,9 @@ def get_outcome_count_per_class(
         """
         outcome_count_per_class: Dict[Tuple[str, OutcomeName], int] = defaultdict(int)
 
-        for utterance_class, outcome in zip(ds[dataset_column], self._get_outcomes_from_ds()):
+        for utterance_class, outcome in zip(
+            ds[dataset_column], get_outcomes_from_ds(ds, self.mod_options.without_postprocessing)
+        ):
             outcome_count_per_class[(dm.get_class_names()[utterance_class], outcome)] += 1
 
         return sorted_by_utterance_count_with_last(

@@ -77,7 +80,9 @@ def get_outcome_count_per_tag(
         all_tags = dm.get_tags(
             indices=assert_is_list(ds[DatasetColumn.row_idx]), table_key=self._get_table_key()
         )
-        for utterance_tags, outcome in zip(all_tags, self._get_outcomes_from_ds()):
+        for utterance_tags, outcome in zip(
+            all_tags, get_outcomes_from_ds(ds, self.mod_options.without_postprocessing)
+        ):
             no_tag = True
             for filter_, tagged in utterance_tags.items():
                 if tagged and filter_ in filters[:-1]:

@@ -100,7 +105,9 @@ def get_outcome_count_per_outcome(self, ds: Dataset) -> List[OutcomeCountPerFilt
             List of Outcome Count for each outcome.
 
         """
-        outcome_count = defaultdict(int, Counter(self._get_outcomes_from_ds()))
+        outcome_count = defaultdict(
+            int, Counter(get_outcomes_from_ds(ds, self.mod_options.without_postprocessing))
+        )
         empty_outcome_count = {outcome: 0 for outcome in OutcomeName}
 
         metrics = [
