diff --git a/.gitignore b/.gitignore index bce3533..2a7afe6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ MANIFEST .Python env/ +venv/ bin/ build/ develop-eggs/ @@ -61,4 +62,4 @@ doc/.ipynb_checkpoints # PyCharm .idea/ -.mypy_cache/ \ No newline at end of file +.mypy_cache/ diff --git a/pyannote/core/annotation.py b/pyannote/core/annotation.py index fc2d17a..84fec6f 100755 --- a/pyannote/core/annotation.py +++ b/pyannote/core/annotation.py @@ -111,6 +111,7 @@ from collections import defaultdict from typing import ( Hashable, + Literal, Optional, Dict, Union, @@ -122,6 +123,8 @@ Iterator, Text, TYPE_CHECKING, + NamedTuple, + overload, ) import numpy as np @@ -139,7 +142,17 @@ from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode if TYPE_CHECKING: - import pandas as pd + import pandas as pd # type: ignore + + +class SegmentTrack(NamedTuple): + segment: Segment + track: TrackName + +class SegmentTrackLabel(NamedTuple): + segment: Segment + track: TrackName + label: Label class Annotation: @@ -187,7 +200,7 @@ def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None): self._labelNeedsUpdate: Dict[Label, bool] = {} # timeline meant to store all annotated segments - self._timeline: Timeline = None + self._timeline: Optional[Timeline] = None self._timelineNeedsUpdate: bool = True @property @@ -259,9 +272,16 @@ def itersegments(self): """ return iter(self._tracks) + @overload + def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[SegmentTrack]: ... + @overload + def itertracks(self, yield_label: Literal[True]) -> Iterator[SegmentTrackLabel]: ... + @overload + def itertracks(self, yield_label: bool) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: ... + def itertracks( self, yield_label: bool = False - ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]: + ) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: """Iterate over tracks (in chronological order) Parameters @@ -286,9 +306,9 @@ def itertracks( tracks.items(), key=lambda tl: (str(tl[0]), str(tl[1])) ): if yield_label: - yield segment, track, lbl + yield SegmentTrackLabel(segment, track, lbl) else: - yield segment, track + yield SegmentTrack(segment, track) def _updateTimeline(self): self._timeline = Timeline(segments=self._tracks, uri=self.uri) @@ -317,9 +337,14 @@ def get_timeline(self, copy: bool = True) -> Timeline: """ if self._timelineNeedsUpdate: self._updateTimeline() + + timeline_ = self._timeline + if timeline_ is None: + timeline_ = Timeline(uri=self.uri) + if copy: - return self._timeline.copy() - return self._timeline + return timeline_.copy() + return timeline_ def __eq__(self, other: "Annotation"): """Equality @@ -556,6 +581,9 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation else: raise NotImplementedError("unsupported mode: '%s'" % mode) + else: + raise TypeError("unsupported support type: '%s'" % type(support)) + def extrude( self, removed: Support, mode: CropMode = "intersection" ) -> "Annotation": @@ -1178,7 +1206,7 @@ def argmax(self, support: Optional[Support] = None) -> Optional[Label]: key=lambda x: x[1], )[0] - def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation": + def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable[int]] = "string") -> "Annotation": """Rename all tracks Parameters @@ -1215,13 +1243,17 @@ def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation": renamed = self.__class__(uri=self.uri, modality=self.modality) if generator == "string": - generator = string_generator() + generator_ = string_generator() elif generator == "int": - generator = int_generator() + generator_ = int_generator() + elif isinstance(generator, Iterable): + generator_ = iter(generator) + else: + raise ValueError("generator must be 'string', 'int', or iterable") # TODO speed things up by working directly with annotation internals for s, _, label in self.itertracks(yield_label=True): - renamed[s, next(generator)] = label + renamed[s, next(generator_)] = label return renamed def rename_labels( @@ -1439,11 +1471,11 @@ def discretize( duration: Optional[float] = None, ): """Discretize - + Parameters ---------- support : Segment, optional - Part of annotation to discretize. + Part of annotation to discretize. Defaults to annotation full extent. resolution : float or SlidingWindow, optional Defaults to 10ms frames. diff --git a/pyannote/core/feature.py b/pyannote/core/feature.py index ca95142..b39f035 100755 --- a/pyannote/core/feature.py +++ b/pyannote/core/feature.py @@ -36,11 +36,11 @@ """ import numbers import warnings -from typing import Tuple, Optional, Union, Iterator, List, Text +from typing import Tuple, Optional, Union, Iterator, List import numpy as np -from pyannote.core.utils.types import Alignment +from pyannote.core.utils.types import Alignment, Label from .segment import Segment from .segment import SlidingWindow from .timeline import Timeline @@ -58,7 +58,7 @@ class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin): """ def __init__( - self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None + self, data: np.ndarray, sliding_window: SlidingWindow, labels: Optional[List[Label]] = None ): self.sliding_window: SlidingWindow = sliding_window self.data = data @@ -106,7 +106,7 @@ def __next__(self) -> Tuple[Segment, np.ndarray]: self.__i += 1 try: return self.sliding_window[self.__i], self.data[self.__i] - except IndexError as e: + except IndexError: raise StopIteration() def next(self): diff --git a/pyannote/core/timeline.py b/pyannote/core/timeline.py index 1af9151..6a88281 100755 --- a/pyannote/core/timeline.py +++ b/pyannote/core/timeline.py @@ -92,6 +92,8 @@ from typing import (Optional, Iterable, List, Union, Callable, TextIO, Tuple, TYPE_CHECKING, Iterator, Dict, Text) +# sortedcontainers does not support type hinting +# introduces some type: ignore comments below. from sortedcontainers import SortedList from . import PYANNOTE_SEGMENT @@ -104,7 +106,7 @@ # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: from .annotation import Annotation - import pandas as pd + import pandas as pd # type: ignore # ===================================================================== @@ -141,13 +143,13 @@ def from_df(cls, df: 'pd.DataFrame', uri: Optional[str] = None) -> 'Timeline': def __init__(self, segments: Optional[Iterable[Segment]] = None, - uri: str = None): + uri: Optional[str] = None): if segments is None: segments = () # set of segments (used for checking inclusion) # Store only non-empty Segments. - segments_set = set([segment for segment in segments if segment]) + segments_set = set(segment for segment in segments if segment) self.segments_set_ = segments_set @@ -159,7 +161,7 @@ def __init__(self, self.segments_boundaries_ = SortedList(boundaries) # path to (or any identifier of) segmented resource - self.uri: str = uri + self.uri: Optional[str] = uri def __len__(self): """Number of segments @@ -200,7 +202,7 @@ def __getitem__(self, k: int) -> Segment: >>> first_segment = timeline[0] >>> penultimate_segment = timeline[-2] """ - return self.segments_list_[k] + return self.segments_list_[k] # type: ignore def __eq__(self, other: 'Timeline'): """Equality @@ -320,7 +322,7 @@ def discard(self, segment: Segment) -> 'Timeline': def __ior__(self, timeline: 'Timeline') -> 'Timeline': return self.update(timeline) - def update(self, timeline: Segment) -> 'Timeline': + def update(self, timeline: 'Timeline') -> 'Timeline': """Add every segments of an existing timeline (in place) Parameters @@ -465,6 +467,20 @@ def crop_iter(self, else: yield mapped_to + def crop_iter_no_mapping(self, + support: Support, + mode: CropMode = 'intersection') \ + -> Iterator[Segment]: + """Typed version of crop_iter without mapping""" + return self.crop_iter(support=support, mode=mode, returns_mapping=False) # type: ignore + + def crop_iter_with_mapping(self, + support: Support, + mode: CropMode = 'intersection') \ + -> Iterator[Tuple[Segment, Segment]]: + """Typed version of crop_iter with mapping""" + return self.crop_iter(support=support, mode=mode, returns_mapping=True) # type: ignore + def crop(self, support: Support, mode: CropMode = 'intersection', @@ -516,16 +532,29 @@ def crop(self, if mode == 'intersection' and returns_mapping: segments, mapping = [], {} - for segment, mapped_to in self.crop_iter(support, - mode='intersection', - returns_mapping=True): + for segment, mapped_to in self.crop_iter_with_mapping(support, + mode='intersection'): segments.append(mapped_to) mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment] return Timeline(segments=segments, uri=self.uri), mapping - return Timeline(segments=self.crop_iter(support, mode=mode), + return Timeline(segments=self.crop_iter_no_mapping(support, mode=mode), uri=self.uri) + def crop_with_mapping(self, + support: Support, + mode: CropMode = 'intersection') \ + -> Tuple['Timeline', Dict[Segment, List[Segment]]]: + """Typed version of crop with mapping""" + return self.crop(support=support, mode=mode, returns_mapping=True) # type: ignore + + def crop_no_mapping(self, + support: Support, + mode: CropMode = 'intersection') \ + -> 'Timeline': + """Typed version of crop without mapping""" + return self.crop(support=support, mode=mode, returns_mapping=False) # type: ignore + def overlapping(self, t: float) -> List[Segment]: """Get list of segments overlapping `t` @@ -622,7 +651,7 @@ def extrude(self, mode = "strict" elif mode == "strict": mode = "loose" - return self.crop(truncating_support, mode=mode) + return self.crop_no_mapping(truncating_support, mode=mode) def __str__(self): """Human-readable representation @@ -705,7 +734,7 @@ def empty(self) -> 'Timeline': def covers(self, other: 'Timeline') -> bool: """Check whether other timeline is fully covered by the timeline - + Parameter --------- other : Timeline @@ -718,16 +747,16 @@ def covers(self, other: 'Timeline') -> bool: one segment of "other" is not fully covered by timeline """ - # compute gaps within "other" extent - # this is where we should look for possible faulty segments + # compute gaps within "other" extent + # this is where we should look for possible faulty segments gaps = self.gaps(support=other.extent()) - # if at least one gap intersects with a segment from "other", + # if at least one gap intersects with a segment from "other", # "self" does not cover "other" entirely --> return False for _ in gaps.co_iter(other): return False - # if no gap intersects with a segment from "other", + # if no gap intersects with a segment from "other", # "self" covers "other" entirely --> return True return True @@ -790,8 +819,8 @@ def extent(self) -> Segment: """ if self.segments_set_: segments_boundaries_ = self.segments_boundaries_ - start = segments_boundaries_[0] - end = segments_boundaries_[-1] + start: float = segments_boundaries_[0] # type: ignore + end: float = segments_boundaries_[-1] # type: ignore return Segment(start=start, end=end) return Segment(start=0.0, end=0.0) @@ -818,7 +847,7 @@ def support_iter(self, collar: float = 0.0) -> Iterator[Segment]: # Initialize new support segment # as very first segment of the timeline - new_segment = self.segments_list_[0] + new_segment: Segment = self.segments_list_[0] # type: ignore for segment in self: @@ -918,7 +947,7 @@ def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]: end = support.start # support on the intersection of timeline and provided segment - for segment in self.crop(support, mode='intersection').support(): + for segment in self.crop_no_mapping(support, mode='intersection').support(): # add gap between each pair of consecutive segments # if there is no gap, segment is empty, therefore not added @@ -1015,8 +1044,6 @@ def segmentation(self) -> 'Timeline': # becomes # |-|--|-| |-|---|--| |--|----|--| - # start with an empty copy - timeline = Timeline(uri=self.uri) if len(timestamps) == 0: return Timeline(uri=self.uri) @@ -1034,7 +1061,7 @@ def segmentation(self) -> 'Timeline': return Timeline(segments=segments, uri=self.uri) def to_annotation(self, - generator: Union[str, Iterable[Label], None, None] = 'string', + generator: Union[str, Iterable[Label], None] = 'string', modality: Optional[str] = None) \ -> 'Annotation': """Turn timeline into an annotation @@ -1056,15 +1083,22 @@ def to_annotation(self, from .annotation import Annotation annotation = Annotation(uri=self.uri, modality=modality) - if generator == 'string': + if generator == 'string' or generator is None: from .utils.generators import string_generator - generator = string_generator() + generator_ = string_generator() elif generator == 'int': from .utils.generators import int_generator - generator = int_generator() + generator_ = int_generator() + elif isinstance(generator, Iterable): + generator_ = iter(generator) + else: + msg = ("`generator` must be one of 'string', 'int', or an iterable " + "(got {generator}).") + raise ValueError(msg.format(generator=generator)) + for segment in self: - annotation[segment] = next(generator) + annotation[segment] = next(generator_) return annotation diff --git a/pyannote/core/utils/types.py b/pyannote/core/utils/types.py index 0105a16..2751bec 100644 --- a/pyannote/core/utils/types.py +++ b/pyannote/core/utils/types.py @@ -1,7 +1,14 @@ -from typing import Hashable, Union, Tuple, Iterator +from typing import Hashable, Union, Tuple, Iterator, TYPE_CHECKING from typing_extensions import Literal +if TYPE_CHECKING: + from pyannote.core.segment import Segment + from pyannote.core.timeline import Timeline + from pyannote.core.feature import SlidingWindowFeature + from pyannote.core.annotation import Annotation + + Label = Hashable Support = Union['Segment', 'Timeline'] LabelGeneratorMode = Literal['int', 'string']