diff --git a/doc/source/pyplots/tutorial.py b/doc/source/pyplots/tutorial.py index da11463..1d9ffd5 100644 --- a/doc/source/pyplots/tutorial.py +++ b/doc/source/pyplots/tutorial.py @@ -1,6 +1,5 @@ - notebook.width = 10 -plt.rcParams['figure.figsize'] = (notebook.width, 3) +plt.rcParams["figure.figsize"] = (notebook.width, 3) # only display [0, 20] timerange notebook.crop = Segment(0, 40) @@ -8,22 +7,22 @@ # plot reference plt.subplot(211) reference = Annotation() -reference[Segment(0, 10)] = 'A' -reference[Segment(12, 20)] = 'B' -reference[Segment(24, 27)] = 'A' -reference[Segment(30, 40)] = 'C' +reference[Segment(0, 10)] = "A" +reference[Segment(12, 20)] = "B" +reference[Segment(24, 27)] = "A" +reference[Segment(30, 40)] = "C" notebook.plot_annotation(reference, legend=True, time=False) -plt.gca().text(0.6, 0.15, 'reference', fontsize=16) +plt.gca().text(0.6, 0.15, "reference", fontsize=16) # plot hypothesis plt.subplot(212) hypothesis = Annotation() -hypothesis[Segment(2, 13)] = 'a' -hypothesis[Segment(13, 14)] = 'd' -hypothesis[Segment(14, 20)] = 'b' -hypothesis[Segment(22, 38)] = 'c' -hypothesis[Segment(38, 40)] = 'd' +hypothesis[Segment(2, 13)] = "a" +hypothesis[Segment(13, 14)] = "d" +hypothesis[Segment(14, 20)] = "b" +hypothesis[Segment(22, 38)] = "c" +hypothesis[Segment(38, 40)] = "d" notebook.plot_annotation(hypothesis, legend=True, time=True) -plt.gca().text(0.6, 0.15, 'hypothesis', fontsize=16) +plt.gca().text(0.6, 0.15, "hypothesis", fontsize=16) plt.show() diff --git a/src/pyannote/metrics/__init__.py b/src/pyannote/metrics/__init__.py index 9b3b650..5a23398 100644 --- a/src/pyannote/metrics/__init__.py +++ b/src/pyannote/metrics/__init__.py @@ -29,6 +29,7 @@ from .base import f_measure import importlib.metadata + __version__ = importlib.metadata.version("pyannote-metrics") __all__ = ["f_measure"] diff --git a/src/pyannote/metrics/base.py b/src/pyannote/metrics/base.py index 106f80c..0d5b2aa 100755 --- a/src/pyannote/metrics/base.py +++ b/src/pyannote/metrics/base.py @@ -50,14 +50,14 @@ class BaseMetric: def metric_name(cls) -> str: raise NotImplementedError( cls.__name__ + " is missing a 'metric_name' class method. " - "It should return the name of the metric as string." + "It should return the name of the metric as string." ) @classmethod def metric_components(cls) -> MetricComponents: raise NotImplementedError( cls.__name__ + " is missing a 'metric_components' class method. " - "It should return the list of names of metric components." + "It should return the list of names of metric components." ) def __init__(self, **kwargs): @@ -84,9 +84,14 @@ def name(self): # TODO: use joblib/locky to allow parallel processing? # TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...) - def __call__(self, reference: Union[Timeline, Annotation], - hypothesis: Union[Timeline, Annotation], - detailed: bool = False, uri: Optional[str] = None, **kwargs): + def __call__( + self, + reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + detailed: bool = False, + uri: Optional[str] = None, + **kwargs, + ): """Compute metric value and accumulate components Parameters @@ -247,10 +252,12 @@ def __iter__(self): for uri, component in self.results_: yield uri, component - def compute_components(self, - reference: Union[Timeline, Annotation], - hypothesis: Union[Timeline, Annotation], - **kwargs) -> Details: + def compute_components( + self, + reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + **kwargs, + ) -> Details: """Compute metric components Parameters @@ -269,8 +276,8 @@ def compute_components(self, """ raise NotImplementedError( self.__class__.__name__ + " is missing a 'compute_components' method." - "It should return a dictionary where keys are component names " - "and values are component values." + "It should return a dictionary where keys are component names " + "and values are component values." ) def compute_metric(self, components: Details): @@ -289,12 +296,13 @@ def compute_metric(self, components: Details): """ raise NotImplementedError( self.__class__.__name__ + " is missing a 'compute_metric' method. " - "It should return the actual value of the metric based " - "on the precomputed component dictionary given as input." + "It should return the actual value of the metric based " + "on the precomputed component dictionary given as input." ) - def confidence_interval(self, alpha: float = 0.9) \ - -> Tuple[float, Tuple[float, float]]: + def confidence_interval( + self, alpha: float = 0.9 + ) -> Tuple[float, Tuple[float, float]]: """Compute confidence interval on accumulated metric values Parameters @@ -319,13 +327,17 @@ def confidence_interval(self, alpha: float = 0.9) \ values = [r[self.metric_name_] for _, r in self.results_] if len(values) == 0: - raise ValueError("Please evaluate a bunch of files before computing confidence interval.") - + raise ValueError( + "Please evaluate a bunch of files before computing confidence interval." + ) + elif len(values) == 1: - warnings.warn("Cannot compute a reliable confidence interval out of just one file.") + warnings.warn( + "Cannot compute a reliable confidence interval out of just one file." + ) center = lower = upper = values[0] return center, (lower, upper) - + else: return scipy.stats.bayes_mvs(values, alpha=alpha)[0] diff --git a/src/pyannote/metrics/cli.py b/src/pyannote/metrics/cli.py index e04bbde..b480b85 100644 --- a/src/pyannote/metrics/cli.py +++ b/src/pyannote/metrics/cli.py @@ -97,10 +97,12 @@ import numpy as np import pandas as pd + # command line parsing from docopt import docopt from pyannote.core import Annotation from pyannote.core import Timeline + # evaluation protocols from pyannote.database import get_protocol from pyannote.database.util import get_annotated diff --git a/src/pyannote/metrics/detection.py b/src/pyannote/metrics/detection.py index 3b1929b..f13ba63 100755 --- a/src/pyannote/metrics/detection.py +++ b/src/pyannote/metrics/detection.py @@ -34,10 +34,10 @@ from .types import Details, MetricComponents from .utils import UEMSupportMixin -DER_NAME = 'detection error rate' -DER_TOTAL = 'total' -DER_FALSE_ALARM = 'false alarm' -DER_MISS = 'miss' +DER_NAME = "detection error rate" +DER_TOTAL = "total" +DER_FALSE_ALARM = "false alarm" +DER_MISS = "miss" class DetectionErrorRate(UEMSupportMixin, BaseMetric): @@ -77,16 +77,22 @@ def __init__(self, collar: float = 0.0, skip_overlap: bool = False, **kwargs): self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) reference = reference.get_timeline(copy=False).support() hypothesis = hypothesis.get_timeline(copy=False).support() @@ -94,11 +100,11 @@ def compute_components(self, reference_ = reference.gaps(support=uem) hypothesis_ = hypothesis.gaps(support=uem) - false_positive = 0. + false_positive = 0.0 for r_, h in reference_.co_iter(hypothesis): false_positive += (r_ & h).duration - false_negative = 0. + false_negative = 0.0 for r, h_ in reference.co_iter(hypothesis_): false_negative += (r & h_).duration @@ -110,22 +116,22 @@ def compute_components(self, return detail def compute_metric(self, detail: Details) -> float: - error = 1. * (detail[DER_FALSE_ALARM] + detail[DER_MISS]) - total = 1. * detail[DER_TOTAL] - if total == 0.: + error = 1.0 * (detail[DER_FALSE_ALARM] + detail[DER_MISS]) + total = 1.0 * detail[DER_TOTAL] + if total == 0.0: if error == 0: - return 0. + return 0.0 else: - return 1. + return 1.0 else: return error / total -ACCURACY_NAME = 'detection accuracy' -ACCURACY_TRUE_POSITIVE = 'true positive' -ACCURACY_TRUE_NEGATIVE = 'true negative' -ACCURACY_FALSE_POSITIVE = 'false positive' -ACCURACY_FALSE_NEGATIVE = 'false negative' +ACCURACY_NAME = "detection accuracy" +ACCURACY_TRUE_POSITIVE = "true positive" +ACCURACY_TRUE_NEGATIVE = "true negative" +ACCURACY_FALSE_POSITIVE = "false positive" +ACCURACY_FALSE_NEGATIVE = "false negative" class DetectionAccuracy(DetectionErrorRate): @@ -158,19 +164,29 @@ def metric_name(cls): @classmethod def metric_components(cls) -> MetricComponents: - return [ACCURACY_TRUE_POSITIVE, ACCURACY_TRUE_NEGATIVE, - ACCURACY_FALSE_POSITIVE, ACCURACY_FALSE_NEGATIVE] - - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + return [ + ACCURACY_TRUE_POSITIVE, + ACCURACY_TRUE_NEGATIVE, + ACCURACY_FALSE_POSITIVE, + ACCURACY_FALSE_NEGATIVE, + ] + + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) reference = reference.get_timeline(copy=False).support() hypothesis = hypothesis.get_timeline(copy=False).support() @@ -178,19 +194,19 @@ def compute_components(self, reference_ = reference.gaps(support=uem) hypothesis_ = hypothesis.gaps(support=uem) - true_positive = 0. + true_positive = 0.0 for r, h in reference.co_iter(hypothesis): true_positive += (r & h).duration - true_negative = 0. + true_negative = 0.0 for r_, h_ in reference_.co_iter(hypothesis_): true_negative += (r_ & h_).duration - false_positive = 0. + false_positive = 0.0 for r_, h in reference_.co_iter(hypothesis): false_positive += (r_ & h).duration - false_negative = 0. + false_negative = 0.0 for r, h_ in reference.co_iter(hypothesis_): false_negative += (r & h_).duration @@ -203,22 +219,25 @@ def compute_components(self, return detail def compute_metric(self, detail: Details) -> float: - numerator = 1. * (detail[ACCURACY_TRUE_NEGATIVE] + - detail[ACCURACY_TRUE_POSITIVE]) - denominator = 1. * (detail[ACCURACY_TRUE_NEGATIVE] + - detail[ACCURACY_TRUE_POSITIVE] + - detail[ACCURACY_FALSE_NEGATIVE] + - detail[ACCURACY_FALSE_POSITIVE]) - - if denominator == 0.: - return 1. + numerator = 1.0 * ( + detail[ACCURACY_TRUE_NEGATIVE] + detail[ACCURACY_TRUE_POSITIVE] + ) + denominator = 1.0 * ( + detail[ACCURACY_TRUE_NEGATIVE] + + detail[ACCURACY_TRUE_POSITIVE] + + detail[ACCURACY_FALSE_NEGATIVE] + + detail[ACCURACY_FALSE_POSITIVE] + ) + + if denominator == 0.0: + return 1.0 else: return numerator / denominator -PRECISION_NAME = 'detection precision' -PRECISION_RETRIEVED = 'retrieved' -PRECISION_RELEVANT_RETRIEVED = 'relevant retrieved' +PRECISION_NAME = "detection precision" +PRECISION_RETRIEVED = "retrieved" +PRECISION_RELEVANT_RETRIEVED = "relevant retrieved" class DetectionPrecision(DetectionErrorRate): @@ -252,27 +271,33 @@ def metric_name(cls): def metric_components(cls) -> MetricComponents: return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED] - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) reference = reference.get_timeline(copy=False).support() hypothesis = hypothesis.get_timeline(copy=False).support() reference_ = reference.gaps(support=uem) - true_positive = 0. + true_positive = 0.0 for r, h in reference.co_iter(hypothesis): true_positive += (r & h).duration - false_positive = 0. + false_positive = 0.0 for r_, h in reference_.co_iter(hypothesis): false_positive += (r_ & h).duration @@ -283,17 +308,17 @@ def compute_components(self, return detail def compute_metric(self, detail: Details) -> float: - relevant_retrieved = 1. * detail[PRECISION_RELEVANT_RETRIEVED] - retrieved = 1. * detail[PRECISION_RETRIEVED] - if retrieved == 0.: - return 1. + relevant_retrieved = 1.0 * detail[PRECISION_RELEVANT_RETRIEVED] + retrieved = 1.0 * detail[PRECISION_RETRIEVED] + if retrieved == 0.0: + return 1.0 else: return relevant_retrieved / retrieved -RECALL_NAME = 'detection recall' -RECALL_RELEVANT = 'relevant' -RECALL_RELEVANT_RETRIEVED = 'relevant retrieved' +RECALL_NAME = "detection recall" +RECALL_RELEVANT = "relevant" +RECALL_RELEVANT_RETRIEVED = "relevant retrieved" class DetectionRecall(DetectionErrorRate): @@ -327,27 +352,33 @@ def metric_name(cls): def metric_components(cls) -> MetricComponents: return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED] - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) reference = reference.get_timeline(copy=False).support() hypothesis = hypothesis.get_timeline(copy=False).support() hypothesis_ = hypothesis.gaps(support=uem) - true_positive = 0. + true_positive = 0.0 for r, h in reference.co_iter(hypothesis): true_positive += (r & h).duration - false_negative = 0. + false_negative = 0.0 for r, h_ in reference.co_iter(hypothesis_): false_negative += (r & h_).duration @@ -358,21 +389,21 @@ def compute_components(self, return detail def compute_metric(self, detail: Details) -> float: - relevant_retrieved = 1. * detail[RECALL_RELEVANT_RETRIEVED] - relevant = 1. * detail[RECALL_RELEVANT] - if relevant == 0.: + relevant_retrieved = 1.0 * detail[RECALL_RELEVANT_RETRIEVED] + relevant = 1.0 * detail[RECALL_RELEVANT] + if relevant == 0.0: if relevant_retrieved == 0: - return 1. + return 1.0 else: - return 0. + return 0.0 else: return relevant_retrieved / relevant -DFS_NAME = 'F[precision|recall]' -DFS_PRECISION_RETRIEVED = 'retrieved' -DFS_RECALL_RELEVANT = 'relevant' -DFS_RELEVANT_RETRIEVED = 'relevant retrieved' +DFS_NAME = "F[precision|recall]" +DFS_PRECISION_RETRIEVED = "retrieved" +DFS_RECALL_RELEVANT = "relevant" +DFS_RELEVANT_RETRIEVED = "relevant retrieved" class DetectionPrecisionRecallFMeasure(UEMSupportMixin, BaseMetric): @@ -407,23 +438,34 @@ def metric_name(cls): def metric_components(cls): return [DFS_PRECISION_RETRIEVED, DFS_RECALL_RELEVANT, DFS_RELEVANT_RETRIEVED] - def __init__(self, collar: float = 0.0, skip_overlap: bool = False, - beta: float = 1., **kwargs): + def __init__( + self, + collar: float = 0.0, + skip_overlap: bool = False, + beta: float = 1.0, + **kwargs, + ): super(DetectionPrecisionRecallFMeasure, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.beta = beta - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) reference = reference.get_timeline(copy=False).support() hypothesis = hypothesis.get_timeline(copy=False).support() @@ -434,21 +476,23 @@ def compute_components(self, # Better to recompute everything from scratch instead of calling the # DetectionPrecision & DetectionRecall classes (we skip one of the loop # that computes the amount of true positives). - true_positive = 0. + true_positive = 0.0 for r, h in reference.co_iter(hypothesis): true_positive += (r & h).duration - false_positive = 0. + false_positive = 0.0 for r_, h in reference_.co_iter(hypothesis): false_positive += (r_ & h).duration - false_negative = 0. + false_negative = 0.0 for r, h_ in reference.co_iter(hypothesis_): false_negative += (r & h_).duration - detail = {DFS_PRECISION_RETRIEVED: true_positive + false_positive, - DFS_RECALL_RELEVANT: true_positive + false_negative, - DFS_RELEVANT_RETRIEVED: true_positive} + detail = { + DFS_PRECISION_RETRIEVED: true_positive + false_positive, + DFS_RECALL_RELEVANT: true_positive + false_negative, + DFS_RELEVANT_RETRIEVED: true_positive, + } return detail @@ -456,8 +500,9 @@ def compute_metric(self, detail: Details) -> float: _, _, value = self.compute_metrics(detail=detail) return value - def compute_metrics(self, detail: Optional[Details] = None) \ - -> Tuple[float, float, float]: + def compute_metrics( + self, detail: Optional[Details] = None + ) -> Tuple[float, float, float]: detail = self.accumulated_ if detail is None else detail precision_retrieved = detail[DFS_PRECISION_RETRIEVED] @@ -465,28 +510,28 @@ def compute_metrics(self, detail: Optional[Details] = None) \ relevant_retrieved = detail[DFS_RELEVANT_RETRIEVED] # Special cases : precision - if precision_retrieved == 0.: + if precision_retrieved == 0.0: precision = 1 else: precision = relevant_retrieved / precision_retrieved # Special cases : recall - if recall_relevant == 0.: + if recall_relevant == 0.0: if relevant_retrieved == 0: - recall = 1. + recall = 1.0 else: - recall = 0. + recall = 0.0 else: recall = relevant_retrieved / recall_relevant return precision, recall, f_measure(precision, recall, beta=self.beta) -DCF_NAME = 'detection cost function' -DCF_POS_TOTAL = 'positive class total' # Total duration of positive class. -DCF_NEG_TOTAL = 'negative class total' # Total duration of negative class. -DCF_FALSE_ALARM = 'false alarm' # Total duration of false alarms. -DCF_MISS = 'miss' # Total duration of misses. +DCF_NAME = "detection cost function" +DCF_POS_TOTAL = "positive class total" # Total duration of positive class. +DCF_NEG_TOTAL = "negative class total" # Total duration of negative class. +DCF_FALSE_ALARM = "false alarm" # Total duration of false alarms. +DCF_MISS = "miss" # Total duration of misses. class DetectionCostFunction(UEMSupportMixin, BaseMetric): @@ -530,8 +575,9 @@ class DetectionCostFunction(UEMSupportMixin, BaseMetric): "OpenSAT19 Evaluation Plan v2." https://www.nist.gov/system/files/documents/2018/11/05/opensat19_evaluation_plan_v2_11-5-18.pdf """ - def __init__(self, collar=0.0, skip_overlap=False, fa_weight=0.25, - miss_weight=0.75, **kwargs): + def __init__( + self, collar=0.0, skip_overlap=False, fa_weight=0.25, miss_weight=0.75, **kwargs + ): super(DetectionCostFunction, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap @@ -546,16 +592,22 @@ def metric_name(cls): def metric_components(cls) -> MetricComponents: return [DCF_POS_TOTAL, DCF_NEG_TOTAL, DCF_FALSE_ALARM, DCF_MISS] - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: reference, hypothesis, uem = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_uem=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_uem=True, + ) # Obtain timelines corresponding to positive class. reference = reference.get_timeline(copy=False).support() @@ -583,7 +635,8 @@ def compute_components(self, DCF_POS_TOTAL: pos_dur, DCF_NEG_TOTAL: neg_dur, DCF_MISS: miss_dur, - DCF_FALSE_ALARM: fa_dur} + DCF_FALSE_ALARM: fa_dur, + } return components diff --git a/src/pyannote/metrics/diarization.py b/src/pyannote/metrics/diarization.py index aadb416..effc2e3 100755 --- a/src/pyannote/metrics/diarization.py +++ b/src/pyannote/metrics/diarization.py @@ -41,12 +41,24 @@ from .types import Details, MetricComponents from .utils import UEMSupportMixin +from .matcher import ( + LabelMatcher, + MATCH_TOTAL, + MATCH_CORRECT, + MATCH_CONFUSION, + MATCH_MISSED_DETECTION, + MATCH_FALSE_ALARM, +) + if TYPE_CHECKING: pass # TODO: can't we put these as class attributes? DER_NAME = "diarization error rate" +OVLDER_PREFIX_OVL = "ovl" +OVLDER_PREFIX_NONOVL = "nonovl" + class DiarizationErrorRate(IdentificationErrorRate): """Diarization error rate @@ -646,3 +658,96 @@ def compute_components( return super(DiarizationCompleteness, self).compute_components( hypothesis, reference, uem=uem, **kwargs ) + + +class OverlappedDiarizationErrorRate(BaseMetric): + """Diarization error rate with details for overlap and non-overlap errors. + Error components will be prefixed with 'ovl' or 'nonovl' (e.g. 'ovl false alarm') + + Parameters + ---------- + collar : float, optional + Duration (in seconds) of collars removed from evaluation around + boundaries of reference segments. + """ + + OVDER_NAME = "diarization error rate" + + def __init__(self, collar: float = 0.0): + super().__init__() + + self.ier_ovl = IdentificationErrorRate(collar=collar, skip_overlap=False) + self.ier_nonovl = IdentificationErrorRate(collar=collar, skip_overlap=False) + + @classmethod + def metric_components(cls) -> MetricComponents: + comps = [] + for ovl in [OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL]: + for comp in [ + MATCH_TOTAL, + MATCH_CORRECT, + MATCH_CONFUSION, + MATCH_MISSED_DETECTION, + MATCH_FALSE_ALARM, + ]: + comps.append(f"{ovl} {comp}") + return comps + + @classmethod + def metric_name(cls) -> str: + return cls.OVDER_NAME + + def compute_components( + self, reference: Annotation, hypothesis: Annotation, uem: Timeline | None = None + ) -> Details: + + # map 'hypothesis' labels to 'reference' labels + mapping: dict = DiarizationErrorRate().optimal_mapping( + reference, hypothesis, uem=uem + ) + hypothesis = hypothesis.rename_labels(mapping) + + # split uem into non-overlapping and overlapping regions + overlap: Timeline = reference.get_overlap() + if uem is None: + uem = ( + reference.support() + .get_timeline() + .union(hypothesis.support().get_timeline()) + ) + nonovl_regions: Timeline = uem.extrude(overlap) + ovl_regions: Timeline = uem.crop(overlap) + + # update internal metrics for (non-)overlapping errors + comps_nonovl = self.ier_nonovl.compute_components( + reference, hypothesis, uem=nonovl_regions + ) + comps_ovl = self.ier_ovl.compute_components( + reference, hypothesis, uem=ovl_regions + ) + + components = {} + components.update( + {f"{OVLDER_PREFIX_NONOVL} {k}": v for k, v in comps_nonovl.items()} + ) + components.update({f"{OVLDER_PREFIX_OVL} {k}": v for k, v in comps_ovl.items()}) + + return components + + def compute_metric(self, detail: Details) -> float: + numerator = 0.0 + denominator = 0.0 + for ovl in [OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL]: + numerator += ( + detail[f"{ovl} {MATCH_FALSE_ALARM}"] + + detail[f"{ovl} {MATCH_MISSED_DETECTION}"] + + detail[f"{ovl} {MATCH_CONFUSION}"] + ) + denominator += detail[f"{ovl} {MATCH_TOTAL}"] + if denominator == 0.0: + if numerator == 0: + return 0.0 + else: + return 1.0 + else: + return numerator / denominator diff --git a/src/pyannote/metrics/errors/identification.py b/src/pyannote/metrics/errors/identification.py index 5370b02..03562fe 100755 --- a/src/pyannote/metrics/errors/identification.py +++ b/src/pyannote/metrics/errors/identification.py @@ -34,19 +34,23 @@ from ..identification import UEMSupportMixin from ..matcher import LabelMatcher -from ..matcher import MATCH_CORRECT, MATCH_CONFUSION, \ - MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM +from ..matcher import ( + MATCH_CORRECT, + MATCH_CONFUSION, + MATCH_MISSED_DETECTION, + MATCH_FALSE_ALARM, +) if TYPE_CHECKING: from xarray import DataArray -REFERENCE_TOTAL = 'reference' -HYPOTHESIS_TOTAL = 'hypothesis' +REFERENCE_TOTAL = "reference" +HYPOTHESIS_TOTAL = "hypothesis" -REGRESSION = 'regression' -IMPROVEMENT = 'improvement' -BOTH_CORRECT = 'both_correct' -BOTH_INCORRECT = 'both_incorrect' +REGRESSION = "regression" +IMPROVEMENT = "improvement" +BOTH_CORRECT = "both_correct" +BOTH_INCORRECT = "both_incorrect" class IdentificationErrorAnalysis(UEMSupportMixin): @@ -62,18 +66,20 @@ class IdentificationErrorAnalysis(UEMSupportMixin): Defaults to False (i.e. keep overlap regions). """ - def __init__(self, collar: float = 0., skip_overlap: bool = False): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False): super().__init__() self.matcher = LabelMatcher() self.collar = collar self.skip_overlap = skip_overlap - def difference(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - uemified: bool = False): + def difference( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + uemified: bool = False, + ): """Get error analysis as `Annotation` Labels are (status, reference_label, hypothesis_label) tuples. @@ -95,9 +101,13 @@ def difference(self, """ R, H, common_timeline = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_timeline=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_timeline=True, + ) errors = Annotation(uri=reference.uri, modality=reference.modality) @@ -121,8 +131,7 @@ def difference(self, errors[segment, track] = (MATCH_CONFUSION, r, h) for r in details[MATCH_MISSED_DETECTION]: - track = errors.new_track(segment, - prefix=MATCH_MISSED_DETECTION) + track = errors.new_track(segment, prefix=MATCH_MISSED_DETECTION) errors[segment, track] = (MATCH_MISSED_DETECTION, r, None) for h in details[MATCH_FALSE_ALARM]: @@ -140,24 +149,29 @@ def _match_errors(self, before, after): return (b_ref == a_ref) * (1 + (b_type == a_type) + (b_hyp == a_hyp)) # TODO : return type - def regression(self, - reference: Annotation, - before: Annotation, - after: Annotation, - uem: Optional[Timeline] = None, - uemified: bool = False): + def regression( + self, + reference: Annotation, + before: Annotation, + after: Annotation, + uem: Optional[Timeline] = None, + uemified: bool = False, + ): _, before, errors_before = self.difference( - reference, before, uem=uem, uemified=True) + reference, before, uem=uem, uemified=True + ) reference, after, errors_after = self.difference( - reference, after, uem=uem, uemified=True) + reference, after, uem=uem, uemified=True + ) behaviors = Annotation(uri=reference.uri, modality=reference.modality) # common (up-sampled) timeline common_timeline = errors_after.get_timeline().union( - errors_before.get_timeline()) + errors_before.get_timeline() + ) common_timeline = common_timeline.segmentation() # align 'before' errors on common timeline @@ -183,50 +197,60 @@ def regression(self, for i1, i2 in zip(*linear_sum_assignment(-match)): if i1 >= n1: - track = behaviors.new_track(segment, - candidate=REGRESSION, - prefix=REGRESSION) - behaviors[segment, track] = ( - REGRESSION, None, new_errors[i2]) + track = behaviors.new_track( + segment, candidate=REGRESSION, prefix=REGRESSION + ) + behaviors[segment, track] = (REGRESSION, None, new_errors[i2]) elif i2 >= n2: - track = behaviors.new_track(segment, - candidate=IMPROVEMENT, - prefix=IMPROVEMENT) - behaviors[segment, track] = ( - IMPROVEMENT, old_errors[i1], None) + track = behaviors.new_track( + segment, candidate=IMPROVEMENT, prefix=IMPROVEMENT + ) + behaviors[segment, track] = (IMPROVEMENT, old_errors[i1], None) elif old_errors[i1][0] == MATCH_CORRECT: if new_errors[i2][0] == MATCH_CORRECT: - track = behaviors.new_track(segment, - candidate=BOTH_CORRECT, - prefix=BOTH_CORRECT) + track = behaviors.new_track( + segment, candidate=BOTH_CORRECT, prefix=BOTH_CORRECT + ) behaviors[segment, track] = ( - BOTH_CORRECT, old_errors[i1], new_errors[i2]) + BOTH_CORRECT, + old_errors[i1], + new_errors[i2], + ) else: - track = behaviors.new_track(segment, - candidate=REGRESSION, - prefix=REGRESSION) + track = behaviors.new_track( + segment, candidate=REGRESSION, prefix=REGRESSION + ) behaviors[segment, track] = ( - REGRESSION, old_errors[i1], new_errors[i2]) + REGRESSION, + old_errors[i1], + new_errors[i2], + ) else: if new_errors[i2][0] == MATCH_CORRECT: - track = behaviors.new_track(segment, - candidate=IMPROVEMENT, - prefix=IMPROVEMENT) + track = behaviors.new_track( + segment, candidate=IMPROVEMENT, prefix=IMPROVEMENT + ) behaviors[segment, track] = ( - IMPROVEMENT, old_errors[i1], new_errors[i2]) + IMPROVEMENT, + old_errors[i1], + new_errors[i2], + ) else: - track = behaviors.new_track(segment, - candidate=BOTH_INCORRECT, - prefix=BOTH_INCORRECT) + track = behaviors.new_track( + segment, candidate=BOTH_INCORRECT, prefix=BOTH_INCORRECT + ) behaviors[segment, track] = ( - BOTH_INCORRECT, old_errors[i1], new_errors[i2]) + BOTH_INCORRECT, + old_errors[i1], + new_errors[i2], + ) behaviors = behaviors.support() @@ -235,13 +259,16 @@ def regression(self, else: return behaviors - def matrix(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None) -> 'DataArray': + def matrix( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + ) -> "DataArray": reference, hypothesis, errors = self.difference( - reference, hypothesis, uem=uem, uemified=True) + reference, hypothesis, uem=uem, uemified=True + ) chart = errors.chart() @@ -262,15 +289,17 @@ def matrix(self, # append false alarm labels as last 'reference' labels # (make sure to mark them as such) - rLabels = rLabels + [(MATCH_FALSE_ALARM, hLabel) - for hLabel in falseAlarmLabels] + rLabels = rLabels + [(MATCH_FALSE_ALARM, hLabel) for hLabel in falseAlarmLabels] # prepend duration columns before the detailed confusion matrix hLabels = [ - REFERENCE_TOTAL, HYPOTHESIS_TOTAL, - MATCH_CORRECT, MATCH_CONFUSION, - MATCH_FALSE_ALARM, MATCH_MISSED_DETECTION - ] + hLabels + REFERENCE_TOTAL, + HYPOTHESIS_TOTAL, + MATCH_CORRECT, + MATCH_CONFUSION, + MATCH_FALSE_ALARM, + MATCH_MISSED_DETECTION, + ] + hLabels # initialize empty matrix @@ -285,7 +314,8 @@ def matrix(self, matrix = DataArray( np.zeros((len(rLabels), len(hLabels))), - coords=[('reference', rLabels), ('hypothesis', hLabels)]) + coords=[("reference", rLabels), ("hypothesis", hLabels)], + ) # loop on chart for (status, rLabel, hLabel), duration in chart: @@ -309,7 +339,9 @@ def matrix(self, if status == MATCH_FALSE_ALARM: # hLabel is also a reference label if hLabel in falseAlarmLabels: - matrix.loc[(MATCH_FALSE_ALARM, hLabel), MATCH_FALSE_ALARM] += duration + matrix.loc[ + (MATCH_FALSE_ALARM, hLabel), MATCH_FALSE_ALARM + ] += duration else: matrix.loc[hLabel, MATCH_FALSE_ALARM] += duration @@ -320,7 +352,7 @@ def matrix(self, for rLabel in rLabels: if isinstance(rLabel, tuple) and rLabel[0] == MATCH_FALSE_ALARM: - r = 0. + r = 0.0 h = hypothesis.label_duration(rLabel[1]) else: r = reference.label_duration(rLabel) diff --git a/src/pyannote/metrics/errors/segmentation.py b/src/pyannote/metrics/errors/segmentation.py index 34c9c22..4a6be60 100644 --- a/src/pyannote/metrics/errors/segmentation.py +++ b/src/pyannote/metrics/errors/segmentation.py @@ -35,8 +35,11 @@ class SegmentationErrorAnalysis: def __init__(self): super().__init__() - def __call__(self, reference: Union[Timeline, Annotation], - hypothesis: Union[Timeline, Annotation]) -> Annotation: + def __call__( + self, + reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + ) -> Annotation: if isinstance(reference, Annotation): reference = reference.get_timeline() @@ -100,10 +103,10 @@ def __call__(self, reference: Union[Timeline, Annotation], status = Annotation(uri=reference.uri) for segment in frontier: - status[segment, '_'] = 'shift' + status[segment, "_"] = "shift" for segment in only_over: - status[segment, '_'] = 'over-segmentation' + status[segment, "_"] = "over-segmentation" for segment in only_under: - status[segment, '_'] = 'under-segmentation' + status[segment, "_"] = "under-segmentation" return status.support() diff --git a/src/pyannote/metrics/identification.py b/src/pyannote/metrics/identification.py index 0ffddc5..3f40da1 100755 --- a/src/pyannote/metrics/identification.py +++ b/src/pyannote/metrics/identification.py @@ -32,9 +32,14 @@ from .base import BaseMetric from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED -from .matcher import LabelMatcher, \ - MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \ - MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM +from .matcher import ( + LabelMatcher, + MATCH_TOTAL, + MATCH_CORRECT, + MATCH_CONFUSION, + MATCH_MISSED_DETECTION, + MATCH_FALSE_ALARM, +) from .types import MetricComponents, Details from .utils import UEMSupportMixin @@ -44,7 +49,7 @@ IER_CONFUSION = MATCH_CONFUSION IER_FALSE_ALARM = MATCH_FALSE_ALARM IER_MISS = MATCH_MISSED_DETECTION -IER_NAME = 'identification error rate' +IER_NAME = "identification error rate" class IdentificationErrorRate(UEMSupportMixin, BaseMetric): @@ -78,19 +83,17 @@ def metric_name(cls) -> str: @classmethod def metric_components(cls) -> MetricComponents: - return [ - IER_TOTAL, - IER_CORRECT, - IER_FALSE_ALARM, IER_MISS, - IER_CONFUSION] - - def __init__(self, - confusion: float = 1., - miss: float = 1., - false_alarm: float = 1., - collar: float = 0., - skip_overlap: bool = False, - **kwargs): + return [IER_TOTAL, IER_CORRECT, IER_FALSE_ALARM, IER_MISS, IER_CONFUSION] + + def __init__( + self, + confusion: float = 1.0, + miss: float = 1.0, + false_alarm: float = 1.0, + collar: float = 0.0, + skip_overlap: bool = False, + **kwargs, + ): super().__init__(**kwargs) self.matcher_ = LabelMatcher() @@ -100,13 +103,15 @@ def __init__(self, self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - collar: Optional[float] = None, - skip_overlap: Optional[float] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + collar: Optional[float] = None, + skip_overlap: Optional[float] = None, + **kwargs, + ) -> Details: """ Parameters @@ -131,9 +136,13 @@ def compute_components(self, skip_overlap = self.skip_overlap R, H, common_timeline = self.uemify( - reference, hypothesis, uem=uem, - collar=collar, skip_overlap=skip_overlap, - returns_timeline=True) + reference, + hypothesis, + uem=uem, + collar=collar, + skip_overlap=skip_overlap, + returns_timeline=True, + ) # loop on all segments for segment in common_timeline: @@ -158,17 +167,17 @@ def compute_components(self, def compute_metric(self, detail: Details) -> float: - numerator = 1. * ( - self.confusion * detail[IER_CONFUSION] + - self.false_alarm * detail[IER_FALSE_ALARM] + - self.miss * detail[IER_MISS] + numerator = 1.0 * ( + self.confusion * detail[IER_CONFUSION] + + self.false_alarm * detail[IER_FALSE_ALARM] + + self.miss * detail[IER_MISS] ) - denominator = 1. * detail[IER_TOTAL] - if denominator == 0.: + denominator = 1.0 * detail[IER_TOTAL] + if denominator == 0.0: if numerator == 0: - return 0. + return 0.0 else: - return 1. + return 1.0 else: return numerator / denominator @@ -186,23 +195,29 @@ class IdentificationPrecision(UEMSupportMixin, Precision): Defaults to False (i.e. keep overlap regions). """ - def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, **kwargs): super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_timeline=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_timeline=True, + ) # loop on all segments for segment in common_timeline: @@ -218,8 +233,7 @@ def compute_components(self, counts, _ = self.matcher_(r, h) detail[PRECISION_RETRIEVED] += duration * len(h) - detail[PRECISION_RELEVANT_RETRIEVED] += \ - duration * counts[IER_CORRECT] + detail[PRECISION_RELEVANT_RETRIEVED] += duration * counts[IER_CORRECT] return detail @@ -237,23 +251,29 @@ class IdentificationRecall(UEMSupportMixin, Recall): Defaults to False (i.e. keep overlap regions). """ - def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, **kwargs): super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() - def compute_components(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - **kwargs) -> Details: + def compute_components( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs, + ) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( - reference, hypothesis, uem=uem, - collar=self.collar, skip_overlap=self.skip_overlap, - returns_timeline=True) + reference, + hypothesis, + uem=uem, + collar=self.collar, + skip_overlap=self.skip_overlap, + returns_timeline=True, + ) # loop on all segments for segment in common_timeline: diff --git a/src/pyannote/metrics/matcher.py b/src/pyannote/metrics/matcher.py index 13e9bb1..c1540d4 100644 --- a/src/pyannote/metrics/matcher.py +++ b/src/pyannote/metrics/matcher.py @@ -34,11 +34,11 @@ if TYPE_CHECKING: from pyannote.core.utils.types import Label -MATCH_CORRECT = 'correct' -MATCH_CONFUSION = 'confusion' -MATCH_MISSED_DETECTION = 'missed detection' -MATCH_FALSE_ALARM = 'false alarm' -MATCH_TOTAL = 'total' +MATCH_CORRECT = "correct" +MATCH_CONFUSION = "confusion" +MATCH_MISSED_DETECTION = "missed detection" +MATCH_FALSE_ALARM = "false alarm" +MATCH_TOTAL = "total" class LabelMatcher: @@ -50,7 +50,7 @@ class LabelMatcher: otherwise. """ - def match(self, rlabel: 'Label', hlabel: 'Label') -> bool: + def match(self, rlabel: "Label", hlabel: "Label") -> bool: """ Parameters ---------- @@ -68,9 +68,9 @@ def match(self, rlabel: 'Label', hlabel: 'Label') -> bool: # Two IDs match if they are equal to each other return rlabel == hlabel - def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \ - -> Tuple[Dict[str, int], - Dict[str, List['Label']]]: + def __call__( + self, rlabels: Iterable["Label"], hlabels: Iterable["Label"] + ) -> Tuple[Dict[str, int], Dict[str, List["Label"]]]: """ Parameters @@ -91,14 +91,14 @@ def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \ MATCH_CONFUSION: 0, MATCH_MISSED_DETECTION: 0, MATCH_FALSE_ALARM: 0, - MATCH_TOTAL: 0 + MATCH_TOTAL: 0, } details = { MATCH_CORRECT: [], MATCH_CONFUSION: [], MATCH_MISSED_DETECTION: [], - MATCH_FALSE_ALARM: [] + MATCH_FALSE_ALARM: [], } # this is to make sure rlabels and hlabels are lists # as we will access them later by index @@ -156,7 +156,7 @@ def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \ class HungarianMapper: - def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: + def __call__(self, A: Annotation, B: Annotation) -> Dict["Label", "Label"]: mapping = {} cooccurrence = A * B @@ -171,7 +171,7 @@ def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: class GreedyMapper: - def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: + def __call__(self, A: Annotation, B: Annotation) -> Dict["Label", "Label"]: mapping = {} cooccurrence = A * B @@ -183,8 +183,8 @@ def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: if cooccurrence[a, b] > 0: mapping[a_labels[a]] = b_labels[b] - cooccurrence[a, :] = 0. - cooccurrence[:, b] = 0. + cooccurrence[a, :] = 0.0 + cooccurrence[:, b] = 0.0 continue break diff --git a/src/pyannote/metrics/segmentation.py b/src/pyannote/metrics/segmentation.py index 200d6e2..8d25444 100755 --- a/src/pyannote/metrics/segmentation.py +++ b/src/pyannote/metrics/segmentation.py @@ -39,22 +39,22 @@ from .utils import UEMSupportMixin #  TODO: can't we put these as class attributes? -PURITY_NAME = 'segmentation purity' -COVERAGE_NAME = 'segmentation coverage' -PURITY_COVERAGE_NAME = 'segmentation F[purity|coverage]' -PTY_CVG_TOTAL = 'total duration' -PTY_CVG_INTER = 'intersection duration' +PURITY_NAME = "segmentation purity" +COVERAGE_NAME = "segmentation coverage" +PURITY_COVERAGE_NAME = "segmentation F[purity|coverage]" +PTY_CVG_TOTAL = "total duration" +PTY_CVG_INTER = "intersection duration" -PTY_TOTAL = 'pty total duration' -PTY_INTER = 'pty intersection duration' -CVG_TOTAL = 'cvg total duration' -CVG_INTER = 'cvg intersection duration' +PTY_TOTAL = "pty total duration" +PTY_INTER = "pty intersection duration" +CVG_TOTAL = "cvg total duration" +CVG_INTER = "cvg intersection duration" -PRECISION_NAME = 'segmentation precision' -RECALL_NAME = 'segmentation recall' +PRECISION_NAME = "segmentation precision" +RECALL_NAME = "segmentation recall" -PR_BOUNDARIES = 'number of boundaries' -PR_MATCHES = 'number of matches' +PR_BOUNDARIES = "number of boundaries" +PR_MATCHES = "number of matches" class SegmentationCoverage(BaseMetric): @@ -72,9 +72,7 @@ def __init__(self, tolerance: float = 0.500, **kwargs): super().__init__(**kwargs) self.tolerance = tolerance - def _partition(self, - timeline: Timeline, - coverage: Timeline) -> Annotation: + def _partition(self, timeline: Timeline, coverage: Timeline) -> Annotation: # boundaries (as set of timestamps) boundaries = set([]) @@ -86,16 +84,16 @@ def _partition(self, partition = Annotation() for start, end in pairwise(sorted(boundaries)): segment = Segment(start, end) - partition[segment] = '_' + partition[segment] = "_" - return partition.crop(coverage, mode='intersection').relabel_tracks() + return partition.crop(coverage, mode="intersection").relabel_tracks() - def _preprocess(self, reference: Annotation, - hypothesis: Union[Annotation, Timeline]) \ - -> Tuple[Annotation, Annotation]: + def _preprocess( + self, reference: Annotation, hypothesis: Union[Annotation, Timeline] + ) -> Tuple[Annotation, Annotation]: if not isinstance(reference, Annotation): - raise TypeError('reference must be an instance of `Annotation`') + raise TypeError("reference must be an instance of `Annotation`") if isinstance(hypothesis, Annotation): hypothesis: Timeline = hypothesis.get_timeline() @@ -138,8 +136,9 @@ def metric_name(cls): def metric_components(cls) -> MetricComponents: return [PTY_CVG_TOTAL, PTY_CVG_INTER] - def compute_components(self, reference: Annotation, - hypothesis: Union[Annotation, Timeline], **kwargs): + def compute_components( + self, reference: Annotation, hypothesis: Union[Annotation, Timeline], **kwargs + ): reference, hypothesis = self._preprocess(reference, hypothesis) return self._process(reference, hypothesis) @@ -163,9 +162,9 @@ def metric_name(cls) -> str: return PURITY_NAME # TODO : Use type from parent class - def compute_components(self, reference: Annotation, - hypothesis: Union[Annotation, Timeline], - **kwargs) -> Details: + def compute_components( + self, reference: Annotation, hypothesis: Union[Annotation, Timeline], **kwargs + ) -> Details: reference, hypothesis = self._preprocess(reference, hypothesis) return self._process(hypothesis, reference) @@ -194,11 +193,14 @@ class SegmentationPurityCoverageFMeasure(SegmentationCoverage): """ def __init__(self, tolerance=0.500, beta=1, **kwargs): - super(SegmentationPurityCoverageFMeasure, self).__init__(tolerance=tolerance, **kwargs) + super(SegmentationPurityCoverageFMeasure, self).__init__( + tolerance=tolerance, **kwargs + ) self.beta = beta - def _process(self, reference: Annotation, - hypothesis: Union[Annotation, Timeline]) -> Details: + def _process( + self, reference: Annotation, hypothesis: Union[Annotation, Timeline] + ) -> Details: reference, hypothesis = self._preprocess(reference, hypothesis) detail = self.init_components() @@ -214,26 +216,27 @@ def _process(self, reference: Annotation, return detail - def compute_components(self, reference: Annotation, - hypothesis: Union[Annotation, Timeline], - **kwargs) -> Details: + def compute_components( + self, reference: Annotation, hypothesis: Union[Annotation, Timeline], **kwargs + ) -> Details: return self._process(reference, hypothesis) def compute_metric(self, detail: Details) -> float: _, _, value = self.compute_metrics(detail=detail) return value - def compute_metrics(self, detail: Optional[Details] = None) \ - -> Tuple[float, float, float]: + def compute_metrics( + self, detail: Optional[Details] = None + ) -> Tuple[float, float, float]: detail = self.accumulated_ if detail is None else detail - purity = \ - 1. if detail[PTY_TOTAL] == 0. \ - else detail[PTY_INTER] / detail[PTY_TOTAL] + purity = ( + 1.0 if detail[PTY_TOTAL] == 0.0 else detail[PTY_INTER] / detail[PTY_TOTAL] + ) - coverage = \ - 1. if detail[CVG_TOTAL] == 0. \ - else detail[CVG_INTER] / detail[CVG_TOTAL] + coverage = ( + 1.0 if detail[CVG_TOTAL] == 0.0 else detail[CVG_INTER] / detail[CVG_TOTAL] + ) return purity, coverage, f_measure(purity, coverage, beta=self.beta) @@ -281,15 +284,17 @@ def metric_name(cls): def metric_components(cls): return [PR_MATCHES, PR_BOUNDARIES] - def __init__(self, tolerance=0., **kwargs): + def __init__(self, tolerance=0.0, **kwargs): super().__init__(**kwargs) self.tolerance = tolerance - def compute_components(self, - reference: Union[Annotation, Timeline], - hypothesis: Union[Annotation, Timeline], - **kwargs) -> Details: + def compute_components( + self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + **kwargs, + ) -> Details: # extract timeline if needed if isinstance(reference, Annotation): @@ -300,7 +305,7 @@ def compute_components(self, detail = self.init_components() # number of matches so far... - n_matches = 0. # make sure it is a float (for later ratio) + n_matches = 0.0 # make sure it is a float (for later ratio) # number of boundaries in reference and hypothesis N = len(reference) - 1 @@ -311,7 +316,7 @@ def compute_components(self, # corner case (no boundary in hypothesis or in reference) if M == 0 or N == 0: - detail[PR_MATCHES] = 0. + detail[PR_MATCHES] = 0.0 return detail # reference and hypothesis boundaries @@ -357,11 +362,11 @@ def compute_metric(self, detail: Details) -> float: numerator = detail[PR_MATCHES] denominator = detail[PR_BOUNDARIES] - if denominator == 0.: + if denominator == 0.0: if numerator == 0: - return 1. + return 1.0 else: - raise ValueError('') + raise ValueError("") else: return numerator / denominator @@ -397,8 +402,10 @@ class SegmentationRecall(SegmentationPrecision): def metric_name(cls): return RECALL_NAME - def compute_components(self, reference: Union[Annotation, Timeline], - hypothesis: Union[Annotation, Timeline], - **kwargs) -> Details: - return super(SegmentationRecall, self).compute_components( - hypothesis, reference) + def compute_components( + self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + **kwargs, + ) -> Details: + return super(SegmentationRecall, self).compute_components(hypothesis, reference) diff --git a/src/pyannote/metrics/spotting.py b/src/pyannote/metrics/spotting.py index b1c3e47..d0b847f 100644 --- a/src/pyannote/metrics/spotting.py +++ b/src/pyannote/metrics/spotting.py @@ -38,9 +38,9 @@ from .types import MetricComponents, Details SPOTTING_TARGET = "target" -SPOTTING_SPK_LATENCY = 'speaker_latency' -SPOTTING_SPK_SCORE = 'spk_score' -SPOTTING_ABS_LATENCY = 'absolute_latency' +SPOTTING_SPK_LATENCY = "speaker_latency" +SPOTTING_SPK_SCORE = "spk_score" +SPOTTING_ABS_LATENCY = "absolute_latency" SPOTTING_ABS_SCORE = "abs_score" SPOTTING_SCORE = "score" @@ -78,19 +78,20 @@ def metric_name(cls) -> str: return "Low-latency speaker spotting" def metric_components(self) -> Dict[str, float]: - return {'target': 0.} + return {"target": 0.0} - def __init__(self, - thresholds: Optional[ArrayLike] = None, - latencies: Optional[ArrayLike] = None): + def __init__( + self, + thresholds: Optional[ArrayLike] = None, + latencies: Optional[ArrayLike] = None, + ): super().__init__() if thresholds is None and latencies is None: latencies = [1, 5, 10, 30, 60] if thresholds is not None and latencies is not None: - raise ValueError( - 'One must choose between fixed and variable latency.') + raise ValueError("One must choose between fixed and variable latency.") if thresholds is not None: self.thresholds = np.sort(thresholds) @@ -103,8 +104,9 @@ def __init__(self, def compute_metric(self, detail: MetricComponents): return None - def _fixed_latency(self, reference: Timeline, - timestamps: List[float], scores: List[float]) -> Details: + def _fixed_latency( + self, reference: Timeline, timestamps: List[float], scores: List[float] + ) -> Details: if not reference: target_trial = False @@ -123,8 +125,9 @@ def _fixed_latency(self, reference: Timeline, abs_score = [] # index of speech turn when given latency is reached - for i, latency in zip(np.searchsorted(total, self.latencies), - self.latencies): + for i, latency in zip( + np.searchsorted(total, self.latencies), self.latencies + ): # maximum score in timerange [0, t] # where t is when latency is reached @@ -161,9 +164,13 @@ def _fixed_latency(self, reference: Timeline, SPOTTING_ABS_SCORE: abs_score, } - def _variable_latency(self, reference: Union[Timeline, Annotation], - timestamps: List[float], scores: List[float], - **kwargs) -> Details: + def _variable_latency( + self, + reference: Union[Timeline, Annotation], + timestamps: List[float], + scores: List[float], + **kwargs, + ) -> Details: # pre-compute latencies speaker_latency = np.nan * np.ones((len(timestamps), 1)) @@ -181,15 +188,19 @@ def _variable_latency(self, reference: Union[Timeline, Annotation], # for every threshold, compute when (if ever) alarm is triggered maxcum = (np.maximum.accumulate(scores)).reshape((-1, 1)) triggered = maxcum > self.thresholds - indices = np.array([np.searchsorted(triggered[:, i], True) - for i, _ in enumerate(self.thresholds)]) + indices = np.array( + [ + np.searchsorted(triggered[:, i], True) + for i, _ in enumerate(self.thresholds) + ] + ) if reference: target_trial = True - absolute_latency = np.take(absolute_latency, indices, mode='clip') - speaker_latency = np.take(speaker_latency, indices, mode='clip') + absolute_latency = np.take(absolute_latency, indices, mode="clip") + speaker_latency = np.take(speaker_latency, indices, mode="clip") # is alarm triggered at all? positive = triggered[-1, :] @@ -213,13 +224,15 @@ def _variable_latency(self, reference: Union[Timeline, Annotation], SPOTTING_TARGET: target_trial, SPOTTING_ABS_LATENCY: absolute_latency, SPOTTING_SPK_LATENCY: speaker_latency, - SPOTTING_SCORE: np.max(scores) + SPOTTING_SCORE: np.max(scores), } - def compute_components(self, reference: Union[Timeline, Annotation], - hypothesis: Union[SlidingWindowFeature, - Iterable[Tuple[float, float]]], - **kwargs) -> Details: + def compute_components( + self, + reference: Union[Timeline, Annotation], + hypothesis: Union[SlidingWindowFeature, Iterable[Tuple[float, float]]], + **kwargs, + ) -> Details: """ Parameters @@ -240,22 +253,26 @@ def compute_components(self, reference: Union[Timeline, Annotation], @property def absolute_latency(self): - latencies = [trial[SPOTTING_ABS_LATENCY] for _, trial in self - if trial[SPOTTING_TARGET]] + latencies = [ + trial[SPOTTING_ABS_LATENCY] for _, trial in self if trial[SPOTTING_TARGET] + ] return np.nanmean(latencies, axis=0) @property def speaker_latency(self): - latencies = [trial[SPOTTING_SPK_LATENCY] for _, trial in self - if trial[SPOTTING_TARGET]] + latencies = [ + trial[SPOTTING_SPK_LATENCY] for _, trial in self if trial[SPOTTING_TARGET] + ] return np.nanmean(latencies, axis=0) # TODO : figure out return type - def det_curve(self, - cost_miss: float = 100, - cost_fa: float = 1, - prior_target: float = 0.01, - return_latency: bool = False): + def det_curve( + self, + cost_miss: float = 100, + cost_fa: float = 1, + prior_target: float = 0.01, + return_latency: bool = False, + ): """DET curve Parameters @@ -293,20 +310,26 @@ def det_curve(self, scores = np.array([trial[SPOTTING_SCORE] for _, trial in self]) fpr, fnr, thresholds, eer = det_curve(y_true, scores, distances=False) fpr, fnr, thresholds = fpr[::-1], fnr[::-1], thresholds[::-1] - cdet = cost_miss * fnr * prior_target + \ - cost_fa * fpr * (1. - prior_target) + cdet = cost_miss * fnr * prior_target + cost_fa * fpr * (1.0 - prior_target) if return_latency: # needed to align the thresholds used in the DET curve # with (self.)thresholds used to compute latencies. - indices = np.searchsorted(thresholds, self.thresholds, side='left') - - thresholds = np.take(thresholds, indices, mode='clip') - fpr = np.take(fpr, indices, mode='clip') - fnr = np.take(fnr, indices, mode='clip') - cdet = np.take(cdet, indices, mode='clip') - return thresholds, fpr, fnr, eer, cdet, \ - self.speaker_latency, self.absolute_latency + indices = np.searchsorted(thresholds, self.thresholds, side="left") + + thresholds = np.take(thresholds, indices, mode="clip") + fpr = np.take(fpr, indices, mode="clip") + fnr = np.take(fnr, indices, mode="clip") + cdet = np.take(cdet, indices, mode="clip") + return ( + thresholds, + fpr, + fnr, + eer, + cdet, + self.speaker_latency, + self.absolute_latency, + ) else: return thresholds, fpr, fnr, eer, cdet @@ -318,17 +341,18 @@ def det_curve(self, abs_scores = np.array([trial[SPOTTING_ABS_SCORE] for _, trial in self]) result = {} - for key, scores in {'speaker': spk_scores, - 'absolute': abs_scores}.items(): + for key, scores in {"speaker": spk_scores, "absolute": abs_scores}.items(): result[key] = {} for i, latency in enumerate(self.latencies): - fpr, fnr, theta, eer = det_curve(y_true, scores[:, i], - distances=False) + fpr, fnr, theta, eer = det_curve( + y_true, scores[:, i], distances=False + ) fpr, fnr, theta = fpr[::-1], fnr[::-1], theta[::-1] - cdet = cost_miss * fnr * prior_target + \ - cost_fa * fpr * (1. - prior_target) + cdet = cost_miss * fnr * prior_target + cost_fa * fpr * ( + 1.0 - prior_target + ) result[key][latency] = theta, fpr, fnr, eer, cdet return result diff --git a/src/pyannote/metrics/types.py b/src/pyannote/metrics/types.py index a51dc4d..71ca85f 100644 --- a/src/pyannote/metrics/types.py +++ b/src/pyannote/metrics/types.py @@ -5,4 +5,4 @@ MetricComponent = str CalibrationMethod = Literal["isotonic", "sigmoid"] MetricComponents = List[MetricComponent] -Details = Dict[MetricComponent, float] \ No newline at end of file +Details = Dict[MetricComponent, float] diff --git a/src/pyannote/metrics/utils.py b/src/pyannote/metrics/utils.py index 4df3789..0d2fc33 100644 --- a/src/pyannote/metrics/utils.py +++ b/src/pyannote/metrics/utils.py @@ -35,11 +35,13 @@ class UEMSupportMixin: """Provides 'uemify' method with optional (à la NIST) collar""" - def extrude(self, - uem: Timeline, - reference: Annotation, - collar: float = 0.0, - skip_overlap: bool = False) -> Timeline: + def extrude( + self, + uem: Timeline, + reference: Annotation, + collar: float = 0.0, + skip_overlap: bool = False, + ) -> Timeline: """Extrude reference boundary collars from uem reference |----| |--------------| |-------------| @@ -65,22 +67,22 @@ def extrude(self, extruded_uem : Timeline """ - if collar == 0. and not skip_overlap: + if collar == 0.0 and not skip_overlap: return uem collars, overlap_regions = [], [] # build list of collars if needed - if collar > 0.: + if collar > 0.0: # iterate over all segments in reference for segment in reference.itersegments(): # add collar centered on start time t = segment.start - collars.append(Segment(t - .5 * collar, t + .5 * collar)) + collars.append(Segment(t - 0.5 * collar, t + 0.5 * collar)) # add collar centered on end time t = segment.end - collars.append(Segment(t - .5 * collar, t + .5 * collar)) + collars.append(Segment(t - 0.5 * collar, t + 0.5 * collar)) # build list of overlap regions if needed if skip_overlap: @@ -95,8 +97,9 @@ def extrude(self, return Timeline(segments=segments).support().gaps(support=uem) - def common_timeline(self, reference: Annotation, hypothesis: Annotation) \ - -> Timeline: + def common_timeline( + self, reference: Annotation, hypothesis: Annotation + ) -> Timeline: """Return timeline common to both reference and hypothesis reference |--------| |------------| |---------| |----| @@ -144,19 +147,20 @@ def project(self, annotation: Annotation, timeline: Timeline) -> Annotation: projection[segment, track] = annotation[segment_, track_] return projection - def uemify(self, - reference: Annotation, - hypothesis: Annotation, - uem: Optional[Timeline] = None, - collar: float = 0., - skip_overlap: bool = False, - returns_uem: bool = False, - returns_timeline: bool = False) \ - -> Union[ - Tuple[Annotation, Annotation], - Tuple[Annotation, Annotation, Timeline], - Tuple[Annotation, Annotation, Timeline, Timeline], - ]: + def uemify( + self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + collar: float = 0.0, + skip_overlap: bool = False, + returns_uem: bool = False, + returns_timeline: bool = False, + ) -> Union[ + Tuple[Annotation, Annotation], + Tuple[Annotation, Annotation, Timeline], + Tuple[Annotation, Annotation, Timeline, Timeline], + ]: """Crop 'reference' and 'hypothesis' to 'uem' support Parameters @@ -195,19 +199,18 @@ def uemify(self, r_extent = reference.get_timeline().extent() h_extent = hypothesis.get_timeline().extent() extent = r_extent | h_extent - uem = Timeline(segments=[extent] if extent else [], - uri=reference.uri) + uem = Timeline(segments=[extent] if extent else [], uri=reference.uri) warnings.warn( "'uem' was approximated by the union of 'reference' " - "and 'hypothesis' extents.") + "and 'hypothesis' extents." + ) # extrude collars (and overlap regions) from uem - uem = self.extrude(uem, reference, collar=collar, - skip_overlap=skip_overlap) + uem = self.extrude(uem, reference, collar=collar, skip_overlap=skip_overlap) # extrude regions outside of uem - reference = reference.crop(uem, mode='intersection') - hypothesis = hypothesis.crop(uem, mode='intersection') + reference = reference.crop(uem, mode="intersection") + hypothesis = hypothesis.crop(uem, mode="intersection") # project reference and hypothesis on common timeline if returns_timeline: diff --git a/tests/test_detection.py b/tests/test_detection.py index 40e22db..0f78ba3 100644 --- a/tests/test_detection.py +++ b/tests/test_detection.py @@ -53,24 +53,25 @@ # Time 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 # UEM |--------------------------------------| + @pytest.fixture def reference(): reference = Annotation() - reference[Segment(0, 5)] = 'A' - reference[Segment(6, 10)] = 'B' - reference[Segment(12, 14)] = 'A' - reference[Segment(15, 20)] = 'C' + reference[Segment(0, 5)] = "A" + reference[Segment(6, 10)] = "B" + reference[Segment(12, 14)] = "A" + reference[Segment(15, 20)] = "C" return reference @pytest.fixture def hypothesis(): hypothesis = Annotation() - hypothesis[Segment(1, 7)] = 'A' - hypothesis[Segment(7, 9)] = 'D' - hypothesis[Segment(7, 10)] = 'B' - hypothesis[Segment(11, 17)] = 'C' - hypothesis[Segment(18, 20)] = 'D' + hypothesis[Segment(1, 7)] = "A" + hypothesis[Segment(7, 9)] = "D" + hypothesis[Segment(7, 10)] = "B" + hypothesis[Segment(11, 17)] = "C" + hypothesis[Segment(18, 20)] = "D" return hypothesis @@ -89,16 +90,16 @@ def test_detailed(reference, hypothesis): detectionErrorRate = DetectionErrorRate() details = detectionErrorRate(reference, hypothesis, detailed=True) - rate = details['detection error rate'] + rate = details["detection error rate"] npt.assert_almost_equal(rate, 0.3125, decimal=7) - false_alarm = details['false alarm'] + false_alarm = details["false alarm"] npt.assert_almost_equal(false_alarm, 3.0, decimal=7) - missed_detection = details['miss'] + missed_detection = details["miss"] npt.assert_almost_equal(missed_detection, 2.0, decimal=7) - total = details['total'] + total = details["total"] npt.assert_almost_equal(total, 16.0, decimal=7) @@ -141,7 +142,7 @@ def test_decision_cost_function(reference, hypothesis, uem): npt.assert_almost_equal(actual, expected, decimal=7) # UEM. - expected = 1/6. + expected = 1 / 6.0 dcf = DetectionCostFunction(fa_weight=0.25, miss_weight=0.75) actual = dcf(reference, hypothesis, uem=uem) npt.assert_almost_equal(actual, expected, decimal=7) diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 78971e7..1b6530e 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,3 +1,4 @@ +from calendar import c import pytest import pyannote.core @@ -7,9 +8,17 @@ from pyannote.metrics.diarization import DiarizationErrorRate from pyannote.metrics.diarization import DiarizationPurity from pyannote.metrics.diarization import DiarizationCoverage +from pyannote.metrics.diarization import ( + OverlappedDiarizationErrorRate, + OVLDER_PREFIX_NONOVL, + OVLDER_PREFIX_OVL, +) import numpy.testing as npt +from pyannote.metrics.matcher import MATCH_TOTAL +from pyannote.metrics.types import Details + @pytest.fixture def reference(): @@ -42,6 +51,53 @@ def hypothesis(): return hypothesis +@pytest.fixture +def hypothesis_overlap(): + hypothesis = Annotation() + hypothesis[Segment(2, 13)] = "a" + hypothesis[Segment(10, 14)] = "d" + hypothesis[Segment(14, 24)] = "b" + hypothesis[Segment(22, 38)] = "c" + hypothesis[Segment(36, 40)] = "d" + return hypothesis + + +def test_ovl_der(reference_with_overlap, hypothesis_overlap: Annotation): + der_ovl = OverlappedDiarizationErrorRate() + der_regular = DiarizationErrorRate() + + error_rate_ovl = der_ovl(reference_with_overlap, hypothesis_overlap) + error_rate_regular = der_regular(reference_with_overlap, hypothesis_overlap) + + npt.assert_almost_equal(error_rate_ovl, error_rate_regular, decimal=7) + + +def test_ovl_der_components(reference_with_overlap, hypothesis_overlap): + for collar in [0.0, 0.1, 0.5]: + der_ovl = OverlappedDiarizationErrorRate(collar=collar) + der_regular = DiarizationErrorRate(collar=collar) + + comp_ovl: Details = der_ovl( + reference_with_overlap, hypothesis_overlap, detailed=True + ) + comp_regular: Details = der_regular( + reference_with_overlap, hypothesis_overlap, detailed=True + ) + + print(comp_ovl) + print(comp_regular) + + # test that for each component, the sum of non-overlapped and overlapped components is equal to the regular component + # eg check that ovl confusion+nonovl confusion = confusion + for component in der_regular.metric_components(): + ovl_compsum = comp_ovl["nonovl " + component] + comp_ovl["ovl " + component] + reg_compsum = comp_regular[component] + npt.assert_almost_equal(ovl_compsum, reg_compsum, decimal=7) + # check there is overlapped and nonoverlapped speech + assert comp_ovl[f"{OVLDER_PREFIX_NONOVL} {MATCH_TOTAL}"] > 0.0 + assert comp_ovl[f"{OVLDER_PREFIX_OVL} {MATCH_TOTAL}"] > 0.0 + + def test_error_rate(reference, hypothesis): diarizationErrorRate = DiarizationErrorRate() error_rate = diarizationErrorRate(reference, hypothesis) diff --git a/tests/test_identification.py b/tests/test_identification.py index 635eabc..7f4dd1c 100644 --- a/tests/test_identification.py +++ b/tests/test_identification.py @@ -14,21 +14,21 @@ @pytest.fixture def reference(): reference = Annotation() - reference[Segment(0, 10)] = 'A' - reference[Segment(12, 20)] = 'B' - reference[Segment(24, 27)] = 'A' - reference[Segment(30, 40)] = 'C' + reference[Segment(0, 10)] = "A" + reference[Segment(12, 20)] = "B" + reference[Segment(24, 27)] = "A" + reference[Segment(30, 40)] = "C" return reference @pytest.fixture def hypothesis(): hypothesis = Annotation() - hypothesis[Segment(2, 13)] = 'A' - hypothesis[Segment(13, 14)] = 'D' - hypothesis[Segment(14, 20)] = 'B' - hypothesis[Segment(22, 38)] = 'C' - hypothesis[Segment(38, 40)] = 'D' + hypothesis[Segment(2, 13)] = "A" + hypothesis[Segment(13, 14)] = "D" + hypothesis[Segment(14, 20)] = "B" + hypothesis[Segment(22, 38)] = "C" + hypothesis[Segment(38, 40)] = "D" return hypothesis @@ -42,22 +42,22 @@ def test_detailed(reference, hypothesis): identificationErrorRate = IdentificationErrorRate() details = identificationErrorRate(reference, hypothesis, detailed=True) - confusion = details['confusion'] + confusion = details["confusion"] npt.assert_almost_equal(confusion, 7.0, decimal=7) - correct = details['correct'] + correct = details["correct"] npt.assert_almost_equal(correct, 22.0, decimal=7) - rate = details['identification error rate'] + rate = details["identification error rate"] npt.assert_almost_equal(rate, 0.5161290322580645, decimal=7) - false_alarm = details['false alarm'] + false_alarm = details["false alarm"] npt.assert_almost_equal(false_alarm, 7.0, decimal=7) - missed_detection = details['missed detection'] + missed_detection = details["missed detection"] npt.assert_almost_equal(missed_detection, 2.0, decimal=7) - total = details['total'] + total = details["total"] npt.assert_almost_equal(total, 31.0, decimal=7)