From 8aca43278a65c436611bd7825a1fd74cd37c0d4d Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Fri, 14 Nov 2025 11:41:09 +0100 Subject: [PATCH 1/5] refactor: move WebVTT data model from docling Signed-off-by: Cesar Berrospi Ramis --- docling_core/types/doc/webvtt.py | 416 +++++++++++++++++++++++++ test/data/webvtt/webvtt_example_01.vtt | 42 +++ test/data/webvtt/webvtt_example_02.vtt | 15 + test/data/webvtt/webvtt_example_03.vtt | 57 ++++ test/test_webvtt.py | 199 ++++++++++++ 5 files changed, 729 insertions(+) create mode 100644 docling_core/types/doc/webvtt.py create mode 100644 test/data/webvtt/webvtt_example_01.vtt create mode 100644 test/data/webvtt/webvtt_example_02.vtt create mode 100644 test/data/webvtt/webvtt_example_03.vtt create mode 100644 test/test_webvtt.py diff --git a/docling_core/types/doc/webvtt.py b/docling_core/types/doc/webvtt.py new file mode 100644 index 00000000..eccae4a6 --- /dev/null +++ b/docling_core/types/doc/webvtt.py @@ -0,0 +1,416 @@ +"""Models for the Docling's adoption of Web Video Text Tracks format.""" + +import logging +import re +from typing import Annotated, ClassVar, Literal, Optional, Union, cast + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic.types import StringConstraints +from typing_extensions import Self, override + +_log = logging.getLogger(__name__) + + +class _WebVTTTimestamp(BaseModel): + """Model representing a WebVTT timestamp. + + A WebVTT timestamp is always interpreted relative to the current playback position + of the media data that the WebVTT file is to be synchronized with. + """ + + model_config = ConfigDict(regex_engine="python-re") + + raw: Annotated[ + str, + Field( + description="A representation of the WebVTT Timestamp as a single string" + ), + ] + + _pattern: ClassVar[re.Pattern] = re.compile( + r"^(?:(\d{2,}):)?([0-5]\d):([0-5]\d)\.(\d{3})$" + ) + _hours: int + _minutes: int + _seconds: int + _millis: int + + @model_validator(mode="after") + def validate_raw(self) -> Self: + m = self._pattern.match(self.raw) + if not m: + raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}") + self._hours = int(m.group(1)) if m.group(1) else 0 + self._minutes = int(m.group(2)) + self._seconds = int(m.group(3)) + self._millis = int(m.group(4)) + + if self._minutes < 0 or self._minutes > 59: + raise ValueError("Minutes must be between 0 and 59") + if self._seconds < 0 or self._seconds > 59: + raise ValueError("Seconds must be between 0 and 59") + + return self + + @property + def seconds(self) -> float: + """A representation of the WebVTT Timestamp in seconds.""" + return ( + self._hours * 3600 + + self._minutes * 60 + + self._seconds + + self._millis / 1000.0 + ) + + @override + def __str__(self) -> str: + return self.raw + + +_WebVTTCueIdentifier = Annotated[ + str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$") +] + + +class _WebVTTCueTimings(BaseModel): + """Model representating WebVTT cue timings.""" + + start: Annotated[ + _WebVTTTimestamp, Field(description="Start time offset of the cue") + ] + end: Annotated[_WebVTTTimestamp, Field(description="End time offset of the cue")] + + @model_validator(mode="after") + def check_order(self) -> Self: + if self.start and self.end: + if self.end.seconds <= self.start.seconds: + raise ValueError("End timestamp must be greater than start timestamp") + return self + + @override + def __str__(self): + return f"{self.start} --> {self.end}" + + +class _WebVTTCueTextSpan(BaseModel): + """Model representing a WebVTT cue text span.""" + + text: str + span_type: Literal["text"] = "text" + + @field_validator("text", mode="after") + @classmethod + def validate_text(cls, value: str) -> str: + if any(ch in value for ch in {"\n", "\r", "&", "<"}): + raise ValueError("Cue text span contains invalid characters") + if len(value) == 0: + raise ValueError("Cue text span cannot be empty") + return value + + @override + def __str__(self): + return self.text + + +class _WebVTTCueVoiceSpan(BaseModel): + """Model representing a WebVTT cue voice span.""" + + annotation: Annotated[ + str, + Field( + description=( + "Cue span start tag annotation text representing the name of thevoice" + ) + ), + ] + classes: Annotated[ + list[str], + Field(description="List of classes representing the cue span's significance"), + ] = [] + components: Annotated[ + list["_WebVTTCueComponent"], + Field(description="The components representing the cue internal text"), + ] = [] + span_type: Literal["v"] = "v" + + @field_validator("annotation", mode="after") + @classmethod + def validate_annotation(cls, value: str) -> str: + if any(ch in value for ch in {"\n", "\r", "&", ">"}): + raise ValueError( + "Cue span start tag annotation contains invalid characters" + ) + if not value: + raise ValueError("Cue text span cannot be empty") + return value + + @field_validator("classes", mode="after") + @classmethod + def validate_classes(cls, value: list[str]) -> list[str]: + for item in value: + if any(ch in item for ch in {"\t", "\n", "\r", " ", "&", "<", ">", "."}): + raise ValueError( + "A cue span start tag class contains invalid characters" + ) + if not item: + raise ValueError("Cue span start tag classes cannot be empty") + return value + + @override + def __str__(self): + tag = f"v.{'.'.join(self.classes)}" if self.classes else "v" + inner = "".join(str(span) for span in self.components) + return f"<{tag} {self.annotation}>{inner}" + + +class _WebVTTCueClassSpan(BaseModel): + span_type: Literal["c"] = "c" + components: list["_WebVTTCueComponent"] + + @override + def __str__(self): + inner = "".join(str(span) for span in self.components) + return f"{inner}" + + +class _WebVTTCueItalicSpan(BaseModel): + span_type: Literal["i"] = "i" + components: list["_WebVTTCueComponent"] + + @override + def __str__(self): + inner = "".join(str(span) for span in self.components) + return f"{inner}" + + +class _WebVTTCueBoldSpan(BaseModel): + span_type: Literal["b"] = "b" + components: list["_WebVTTCueComponent"] + + @override + def __str__(self): + inner = "".join(str(span) for span in self.components) + return f"{inner}" + + +class _WebVTTCueUnderlineSpan(BaseModel): + span_type: Literal["u"] = "u" + components: list["_WebVTTCueComponent"] + + @override + def __str__(self): + inner = "".join(str(span) for span in self.components) + return f"{inner}" + + +_WebVTTCueComponent = Annotated[ + Union[ + _WebVTTCueTextSpan, + _WebVTTCueClassSpan, + _WebVTTCueItalicSpan, + _WebVTTCueBoldSpan, + _WebVTTCueUnderlineSpan, + _WebVTTCueVoiceSpan, + ], + Field(discriminator="span_type", description="The WebVTT cue component"), +] + + +class _WebVTTCueBlock(BaseModel): + """Model representing a WebVTT cue block. + + The optional WebVTT cue settings list is not supported. + The cue payload is limited to the following spans: text, class, italic, bold, + underline, and voice. + """ + + model_config = ConfigDict(regex_engine="python-re") + + identifier: Optional[_WebVTTCueIdentifier] = Field( + None, description="The WebVTT cue identifier" + ) + timings: Annotated[_WebVTTCueTimings, Field(description="The WebVTT cue timings")] + payload: Annotated[list[_WebVTTCueComponent], Field(description="The cue payload")] + + _pattern_block: ClassVar[re.Pattern] = re.compile( + r"<(/?)(i|b|c|u|v(?:\.[^\t\n\r &<>.]+)*)(?:\s+([^>]*))?>" + ) + _pattern_voice_tag: ClassVar[re.Pattern] = re.compile( + r"^\.[^\t\n\r &<>]+)?" # zero or more classes + r"[ \t]+(?P[^\n\r&>]+)>" # required space and annotation + ) + + @field_validator("payload", mode="after") + @classmethod + def validate_payload(cls, payload): + for voice in payload: + if "-->" in str(voice): + raise ValueError("Cue payload must not contain '-->'") + return payload + + @classmethod + def parse(cls, raw: str) -> "_WebVTTCueBlock": + lines = raw.strip().splitlines() + if not lines: + raise ValueError("Cue block must have at least one line") + identifier: Optional[_WebVTTCueIdentifier] = None + timing_line = lines[0] + if "-->" not in timing_line and len(lines) > 1: + identifier = timing_line + timing_line = lines[1] + cue_lines = lines[2:] + else: + cue_lines = lines[1:] + + if "-->" not in timing_line: + raise ValueError("Cue block must contain WebVTT cue timings") + + start, end = [t.strip() for t in timing_line.split("-->")] + end = re.split(" |\t", end)[0] # ignore the cue settings list + timings: _WebVTTCueTimings = _WebVTTCueTimings( + start=_WebVTTTimestamp(raw=start), end=_WebVTTTimestamp(raw=end) + ) + cue_text = " ".join(cue_lines).strip() + if cue_text.startswith("" not in cue_text: + # adding close tag for cue voice spans without end tag + cue_text += "" + + stack: list[list[_WebVTTCueComponent]] = [[]] + tag_stack: list[Union[str, tuple]] = [] + + pos = 0 + matches = list(cls._pattern_block.finditer(cue_text)) + i = 0 + while i < len(matches): + match = matches[i] + if match.start() > pos: + stack[-1].append(_WebVTTCueTextSpan(text=cue_text[pos : match.start()])) + tag = match.group(0) + + if tag.startswith(("", "", "", "")): + tag_type = tag[1:2] + tag_stack.append(tag_type) + stack.append([]) + elif tag == "": + children = stack.pop() + stack[-1].append(_WebVTTCueItalicSpan(components=children)) + tag_stack.pop() + elif tag == "": + children = stack.pop() + stack[-1].append(_WebVTTCueBoldSpan(components=children)) + tag_stack.pop() + elif tag == "": + children = stack.pop() + stack[-1].append(_WebVTTCueUnderlineSpan(components=children)) + tag_stack.pop() + elif tag == "": + children = stack.pop() + stack[-1].append(_WebVTTCueClassSpan(components=children)) + tag_stack.pop() + elif tag.startswith("")) + else: + parts.append(str(span)) + + return "".join(parts) + + +class _WebVTTFile(BaseModel): + """A model representing a WebVTT file.""" + + cue_blocks: list[_WebVTTCueBlock] + + @staticmethod + def verify_signature(content: str) -> bool: + if not content: + return False + elif len(content) == 6: + return content == "WEBVTT" + elif len(content) > 6 and content.startswith("WEBVTT"): + return content[6] in (" ", "\t", "\n") + else: + return False + + @classmethod + def parse(cls, raw: str) -> "_WebVTTFile": + # Normalize newlines to LF + raw = raw.replace("\r\n", "\n").replace("\r", "\n") + + # Check WebVTT signature + if not cls.verify_signature(raw): + raise ValueError("Invalid WebVTT file signature") + + # Strip "WEBVTT" header line + lines = raw.split("\n", 1) + body = lines[1] if len(lines) > 1 else "" + + # Remove NOTE/STYLE/REGION blocks + body = re.sub(r"^(NOTE[^\n]*\n(?:.+\n)*?)\n", "", body, flags=re.MULTILINE) + body = re.sub(r"^(STYLE|REGION)(?:.+\n)*?\n", "", body, flags=re.MULTILINE) + + # Split into cue blocks + raw_blocks = re.split(r"\n\s*\n", body.strip()) + cues: list[_WebVTTCueBlock] = [] + for block in raw_blocks: + try: + cues.append(_WebVTTCueBlock.parse(block)) + except ValueError as e: + _log.warning(f"Failed to parse cue block:\n{block}\n{e}") + + return cls(cue_blocks=cues) + + def __iter__(self): + return iter(self.cue_blocks) + + def __getitem__(self, idx): + return self.cue_blocks[idx] + + def __len__(self): + return len(self.cue_blocks) diff --git a/test/data/webvtt/webvtt_example_01.vtt b/test/data/webvtt/webvtt_example_01.vtt new file mode 100644 index 00000000..333ca4a8 --- /dev/null +++ b/test/data/webvtt/webvtt_example_01.vtt @@ -0,0 +1,42 @@ +WEBVTT + +NOTE Copyright © 2019 World Wide Web Consortium. https://www.w3.org/TR/webvtt1/ + +00:11.000 --> 00:13.000 +We are in New York City + +00:13.000 --> 00:16.000 +We’re actually at the Lucern Hotel, just down the street + +00:16.000 --> 00:18.000 +from the American Museum of Natural History + +00:18.000 --> 00:20.000 +And with me is Neil deGrasse Tyson + +00:20.000 --> 00:22.000 +Astrophysicist, Director of the Hayden Planetarium + +00:22.000 --> 00:24.000 +at the AMNH. + +00:24.000 --> 00:26.000 +Thank you for walking down here. + +00:27.000 --> 00:30.000 +And I want to do a follow-up on the last conversation we did. + +00:30.000 --> 00:31.500 align:right size:50% +When we e-mailed— + +00:30.500 --> 00:32.500 align:left size:50% +Didn’t we talk about enough in that conversation? + +00:32.000 --> 00:35.500 align:right size:50% +No! No no no no; 'cos 'cos obviously 'cos + +00:32.500 --> 00:33.500 align:left size:50% +Laughs + +00:35.500 --> 00:38.000 +You know I’m so excited my glasses are falling off here. diff --git a/test/data/webvtt/webvtt_example_02.vtt b/test/data/webvtt/webvtt_example_02.vtt new file mode 100644 index 00000000..1152a1e8 --- /dev/null +++ b/test/data/webvtt/webvtt_example_02.vtt @@ -0,0 +1,15 @@ +WEBVTT + +NOTE Copyright © 2019 World Wide Web Consortium. https://www.w3.org/TR/webvtt1/ + +00:00.000 --> 00:02.000 +It’s a blue apple tree! + +00:02.000 --> 00:04.000 +No way! + +00:04.000 --> 00:06.000 +Hee! laughter + +00:06.000 --> 00:08.000 +That’s awesome! \ No newline at end of file diff --git a/test/data/webvtt/webvtt_example_03.vtt b/test/data/webvtt/webvtt_example_03.vtt new file mode 100644 index 00000000..a4dc1291 --- /dev/null +++ b/test/data/webvtt/webvtt_example_03.vtt @@ -0,0 +1,57 @@ +WEBVTT + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0 +00:00:04.963 --> 00:00:08.571 +OK, +I think now we should be recording + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-1 +00:00:08.571 --> 00:00:09.403 +properly. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0 +00:00:10.683 --> 00:00:11.563 +Good. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/17-0 +00:00:13.363 --> 00:00:13.803 +Yeah. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/78-0 +00:00:49.603 --> 00:00:53.363 +I was also thinking. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/113-0 +00:00:54.963 --> 00:01:02.072 +Would be maybe good to create items, + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/113-1 +00:01:02.072 --> 00:01:06.811 +some metadata, +some options that can be specific. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/150-0 +00:01:10.243 --> 00:01:13.014 +Yeah, +I mean I think you went even more than + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/119-0 +00:01:10.563 --> 00:01:12.643 +But we preserved the atoms. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/150-1 +00:01:13.014 --> 00:01:15.907 +than me. +I just opened the format. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/197-1 +00:01:50.222 --> 00:01:51.643 +give it a try, yeah. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/200-0 +00:01:52.043 --> 00:01:55.043 +Okay, talk to you later. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/202-0 +00:01:54.603 --> 00:01:55.283 +See you. \ No newline at end of file diff --git a/test/test_webvtt.py b/test/test_webvtt.py new file mode 100644 index 00000000..75f5dfc1 --- /dev/null +++ b/test/test_webvtt.py @@ -0,0 +1,199 @@ +# Assisted by watsonx Code Assistant + + +import pytest +from pydantic import ValidationError + +from docling_core.types.doc.webvtt import ( + _WebVTTCueItalicSpan, + _WebVTTCueTextSpan, + _WebVTTCueTimings, + _WebVTTCueVoiceSpan, + _WebVTTFile, + _WebVTTTimestamp, +) + +from .test_data_gen_flag import GEN_TEST_DATA + +GENERATE = GEN_TEST_DATA + + +def test_vtt_cue_commponents(): + """Test WebVTT components.""" + valid_timestamps = [ + "00:01:02.345", + "12:34:56.789", + "02:34.567", + "00:00:00.000", + ] + valid_total_seconds = [ + 1 * 60 + 2.345, + 12 * 3600 + 34 * 60 + 56.789, + 2 * 60 + 34.567, + 0.0, + ] + for idx, ts in enumerate(valid_timestamps): + model = _WebVTTTimestamp(raw=ts) + assert model.seconds == valid_total_seconds[idx] + + """Test invalid WebVTT timestamps.""" + invalid_timestamps = [ + "00:60:02.345", # minutes > 59 + "00:01:60.345", # seconds > 59 + "00:01:02.1000", # milliseconds > 999 + "01:02:03", # missing milliseconds + "01:02", # missing milliseconds + ":01:02.345", # extra : for missing hours + "abc:01:02.345", # invalid format + ] + for ts in invalid_timestamps: + with pytest.raises(ValidationError): + _WebVTTTimestamp(raw=ts) + + """Test the timestamp __str__ method.""" + model = _WebVTTTimestamp(raw="00:01:02.345") + assert str(model) == "00:01:02.345" + + """Test valid cue timings.""" + start = _WebVTTTimestamp(raw="00:10.005") + end = _WebVTTTimestamp(raw="00:14.007") + cue_timings = _WebVTTCueTimings(start=start, end=end) + assert cue_timings.start == start + assert cue_timings.end == end + assert str(cue_timings) == "00:10.005 --> 00:14.007" + + """Test invalid cue timings with end timestamp before start.""" + start = _WebVTTTimestamp(raw="00:10.700") + end = _WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + _WebVTTCueTimings(start=start, end=end) + assert "End timestamp must be greater than start timestamp" in str(excinfo.value) + + """Test invalid cue timings with missing end.""" + start = _WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + _WebVTTCueTimings(start=start) + assert "Field required" in str(excinfo.value) + + """Test invalid cue timings with missing start.""" + end = _WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + _WebVTTCueTimings(end=end) + assert "Field required" in str(excinfo.value) + + """Test with valid text.""" + valid_text = "This is a valid cue text span." + span = _WebVTTCueTextSpan(text=valid_text) + assert span.text == valid_text + assert str(span) == valid_text + + """Test with text containing newline characters.""" + invalid_text = "This cue text span\ncontains a newline." + with pytest.raises(ValidationError): + _WebVTTCueTextSpan(text=invalid_text) + + """Test with text containing ampersand.""" + invalid_text = "This cue text span contains &." + with pytest.raises(ValidationError): + _WebVTTCueTextSpan(text=invalid_text) + + """Test with text containing less-than sign.""" + invalid_text = "This cue text span contains <." + with pytest.raises(ValidationError): + _WebVTTCueTextSpan(text=invalid_text) + + """Test with empty text.""" + with pytest.raises(ValidationError): + _WebVTTCueTextSpan(text="") + + """Test that annotation validation works correctly.""" + valid_annotation = "valid-annotation" + invalid_annotation = "invalid\nannotation" + with pytest.raises(ValidationError): + _WebVTTCueVoiceSpan(annotation=invalid_annotation) + assert _WebVTTCueVoiceSpan(annotation=valid_annotation) + + """Test that classes validation works correctly.""" + annotation = "speaker name" + valid_classes = ["class1", "class2"] + invalid_classes = ["class\nwith\nnewlines", ""] + with pytest.raises(ValidationError): + _WebVTTCueVoiceSpan(annotation=annotation, classes=invalid_classes) + assert _WebVTTCueVoiceSpan(annotation=annotation, classes=valid_classes) + + """Test that components validation works correctly.""" + annotation = "speaker name" + valid_components = [_WebVTTCueTextSpan(text="random text")] + invalid_components = [123, "not a component"] + with pytest.raises(ValidationError): + _WebVTTCueVoiceSpan(annotation=annotation, components=invalid_components) + assert _WebVTTCueVoiceSpan(annotation=annotation, components=valid_components) + + """Test valid cue voice spans.""" + cue_span = _WebVTTCueVoiceSpan( + annotation="speaker", + classes=["loud", "clear"], + components=[_WebVTTCueTextSpan(text="random text")], + ) + + expected_str = "random text" + assert str(cue_span) == expected_str + + cue_span = _WebVTTCueVoiceSpan( + annotation="speaker", + components=[_WebVTTCueTextSpan(text="random text")], + ) + expected_str = "random text" + assert str(cue_span) == expected_str + + +def test_webvtt_file(): + """Test WebVTT files.""" + with open("./test/data/webvtt/webvtt_example_01.vtt", encoding="utf-8") as f: + content = f.read() + vtt = _WebVTTFile.parse(content) + assert len(vtt) == 13 + block = vtt.cue_blocks[11] + assert str(block.timings) == "00:32.500 --> 00:33.500" + assert len(block.payload) == 1 + cue_span = block.payload[0] + assert isinstance(cue_span, _WebVTTCueVoiceSpan) + assert cue_span.annotation == "Neil deGrasse Tyson" + assert not cue_span.classes + assert len(cue_span.components) == 1 + comp = cue_span.components[0] + assert isinstance(comp, _WebVTTCueItalicSpan) + assert len(comp.components) == 1 + comp2 = comp.components[0] + assert isinstance(comp2, _WebVTTCueTextSpan) + assert comp2.text == "Laughs" + + with open("./test/data/webvtt/webvtt_example_02.vtt", encoding="utf-8") as f: + content = f.read() + vtt = _WebVTTFile.parse(content) + assert len(vtt) == 4 + reverse = ( + "WEBVTT\n\nNOTE Copyright © 2019 World Wide Web Consortium. " + "https://www.w3.org/TR/webvtt1/\n\n" + ) + reverse += "\n\n".join([str(block) for block in vtt.cue_blocks]) + assert content == reverse + + with open("./test/data/webvtt/webvtt_example_03.vtt", encoding="utf-8") as f: + content = f.read() + vtt = _WebVTTFile.parse(content) + assert len(vtt) == 13 + for block in vtt: + assert block.identifier + block = vtt.cue_blocks[0] + assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0" + assert str(block.timings) == "00:00:04.963 --> 00:00:08.571" + assert len(block.payload) == 1 + assert isinstance(block.payload[0], _WebVTTCueVoiceSpan) + block = vtt.cue_blocks[2] + assert isinstance(cue_span, _WebVTTCueVoiceSpan) + assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0" + assert str(block.timings) == "00:00:10.683 --> 00:00:11.563" + assert len(block.payload) == 1 + assert isinstance(block.payload[0], _WebVTTCueTextSpan) + assert block.payload[0].text == "Good." From 8ea1fafc19b95f40b3f02b47a63d6c730abb7649 Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Fri, 14 Nov 2025 14:53:05 +0100 Subject: [PATCH 2/5] fix(webvtt): deal with HTML entities in cue text spans Signed-off-by: Cesar Berrospi Ramis --- docling_core/types/doc/webvtt.py | 15 ++++++++++++++- test/test_webvtt.py | 6 ++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docling_core/types/doc/webvtt.py b/docling_core/types/doc/webvtt.py index eccae4a6..d7cabdc3 100644 --- a/docling_core/types/doc/webvtt.py +++ b/docling_core/types/doc/webvtt.py @@ -98,10 +98,23 @@ class _WebVTTCueTextSpan(BaseModel): text: str span_type: Literal["text"] = "text" + _valid_entities: ClassVar[set] = {"amp", "lt", "gt", "lrm", "rlm", "nbsp"} + _entity_pattern: ClassVar[re.Pattern] = re.compile(r"&([a-zA-Z0-9]+);") + @field_validator("text", mode="after") @classmethod def validate_text(cls, value: str) -> str: - if any(ch in value for ch in {"\n", "\r", "&", "<"}): + for match in cls._entity_pattern.finditer(value): + entity = match.group(1) + if entity not in cls._valid_entities: + raise ValueError( + f"Cue text span contains an invalid HTML entity: &{entity};" + ) + if "&" in re.sub(cls._entity_pattern, "", value): + raise ValueError( + "Found '&' not part of a valid entity in the cue text span" + ) + if any(ch in value for ch in {"\n", "\r", "<"}): raise ValueError("Cue text span contains invalid characters") if len(value) == 0: raise ValueError("Cue text span cannot be empty") diff --git a/test/test_webvtt.py b/test/test_webvtt.py index 75f5dfc1..ea4f2889 100644 --- a/test/test_webvtt.py +++ b/test/test_webvtt.py @@ -96,6 +96,12 @@ def test_vtt_cue_commponents(): invalid_text = "This cue text span contains &." with pytest.raises(ValidationError): _WebVTTCueTextSpan(text=invalid_text) + invalid_text = "An invalid &foo; entity" + with pytest.raises(ValidationError): + _WebVTTCueTextSpan(text=invalid_text) + valid_text = "My favorite book is Pride & Prejudice" + span = _WebVTTCueTextSpan(text=valid_text) + assert span.text == valid_text """Test with text containing less-than sign.""" invalid_text = "This cue text span contains <." From eba67a2616599e43114ae422686aadb33be69b25 Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Mon, 17 Nov 2025 03:32:05 +0100 Subject: [PATCH 3/5] refactor(webvtt): support more WebVTT models Signed-off-by: Cesar Berrospi Ramis --- docling_core/types/doc/webvtt.py | 367 +++++++++++++++++++------------ test/test_webvtt.py | 137 +++++++++--- 2 files changed, 332 insertions(+), 172 deletions(-) diff --git a/docling_core/types/doc/webvtt.py b/docling_core/types/doc/webvtt.py index d7cabdc3..6d60a2d8 100644 --- a/docling_core/types/doc/webvtt.py +++ b/docling_core/types/doc/webvtt.py @@ -2,7 +2,8 @@ import logging import re -from typing import Annotated, ClassVar, Literal, Optional, Union, cast +from enum import Enum +from typing import Annotated, ClassVar, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from pydantic.types import StringConstraints @@ -11,8 +12,24 @@ _log = logging.getLogger(__name__) +_VALID_ENTITIES: set = {"amp", "lt", "gt", "lrm", "rlm", "nbsp"} +_ENTITY_PATTERN: re.Pattern = re.compile(r"&([a-zA-Z0-9]+);") +_START_TAG_NAMES = Literal["c", "b", "i", "u", "v", "lang"] + + +class _WebVTTLineTerminator(str, Enum): + CRLF = "\r\n" + LF = "\n" + CR = "\r" + + +_WebVTTCueIdentifier = Annotated[ + str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$") +] + + class _WebVTTTimestamp(BaseModel): - """Model representing a WebVTT timestamp. + """WebVTT timestamp. A WebVTT timestamp is always interpreted relative to the current playback position of the media data that the WebVTT file is to be synchronized with. @@ -67,13 +84,8 @@ def __str__(self) -> str: return self.raw -_WebVTTCueIdentifier = Annotated[ - str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$") -] - - class _WebVTTCueTimings(BaseModel): - """Model representating WebVTT cue timings.""" + """WebVTT cue timings.""" start: Annotated[ _WebVTTTimestamp, Field(description="Start time offset of the cue") @@ -93,31 +105,27 @@ def __str__(self): class _WebVTTCueTextSpan(BaseModel): - """Model representing a WebVTT cue text span.""" + """WebVTT cue text span.""" - text: str - span_type: Literal["text"] = "text" - - _valid_entities: ClassVar[set] = {"amp", "lt", "gt", "lrm", "rlm", "nbsp"} - _entity_pattern: ClassVar[re.Pattern] = re.compile(r"&([a-zA-Z0-9]+);") + kind: Literal["text"] = "text" + text: Annotated[str, Field(description="The cue text.")] @field_validator("text", mode="after") @classmethod - def validate_text(cls, value: str) -> str: - for match in cls._entity_pattern.finditer(value): + def is_valid_text(cls, value: str) -> str: + for match in _ENTITY_PATTERN.finditer(value): entity = match.group(1) - if entity not in cls._valid_entities: + if entity not in _VALID_ENTITIES: raise ValueError( - f"Cue text span contains an invalid HTML entity: &{entity};" + f"Cue text contains an invalid HTML entity: &{entity};" ) - if "&" in re.sub(cls._entity_pattern, "", value): - raise ValueError( - "Found '&' not part of a valid entity in the cue text span" - ) + if "&" in re.sub(_ENTITY_PATTERN, "", value): + raise ValueError("Found '&' not part of a valid entity in the cue text") if any(ch in value for ch in {"\n", "\r", "<"}): - raise ValueError("Cue text span contains invalid characters") + raise ValueError("Cue text contains invalid characters") if len(value) == 0: - raise ValueError("Cue text span cannot be empty") + raise ValueError("Cue text cannot be empty") + return value @override @@ -125,37 +133,48 @@ def __str__(self): return self.text -class _WebVTTCueVoiceSpan(BaseModel): - """Model representing a WebVTT cue voice span.""" +class _WebVTTCueComponentWithTerminator(BaseModel): + """WebVTT caption or subtitle cue component optionally with a line terminator.""" - annotation: Annotated[ - str, + component: "_WebVTTCueComponent" + terminator: Optional[_WebVTTLineTerminator] = None + + @override + def __str__(self): + return f"{self.component}{self.terminator.value if self.terminator else ''}" + + +class _WebVTTCueInternalText(BaseModel): + """WebVTT cue internal text.""" + + terminator: Optional[_WebVTTLineTerminator] = None + components: Annotated[ + list[_WebVTTCueComponentWithTerminator], Field( description=( - "Cue span start tag annotation text representing the name of thevoice" + "WebVTT caption or subtitle cue components representing the " + "cue internal text" ) ), - ] + ] = [] + + @override + def __str__(self): + cue_str = ( + f"{self.terminator.value if self.terminator else ''}" + f"{''.join(str(span) for span in self.components)}" + ) + return cue_str + + +class _WebVTTCueSpanStartTag(BaseModel): + """WebVTT cue span start tag.""" + + name: Annotated[_START_TAG_NAMES, Field(description="The tag name")] classes: Annotated[ list[str], Field(description="List of classes representing the cue span's significance"), ] = [] - components: Annotated[ - list["_WebVTTCueComponent"], - Field(description="The components representing the cue internal text"), - ] = [] - span_type: Literal["v"] = "v" - - @field_validator("annotation", mode="after") - @classmethod - def validate_annotation(cls, value: str) -> str: - if any(ch in value for ch in {"\n", "\r", "&", ">"}): - raise ValueError( - "Cue span start tag annotation contains invalid characters" - ) - if not value: - raise ValueError("Cue text span cannot be empty") - return value @field_validator("classes", mode="after") @classmethod @@ -169,51 +188,113 @@ def validate_classes(cls, value: list[str]) -> list[str]: raise ValueError("Cue span start tag classes cannot be empty") return value + def _get_name_with_classes(self) -> str: + return f"{self.name}.{'.'.join(self.classes)}" if self.classes else self.name + @override def __str__(self): - tag = f"v.{'.'.join(self.classes)}" if self.classes else "v" - inner = "".join(str(span) for span in self.components) - return f"<{tag} {self.annotation}>{inner}" + return f"<{self._get_name_with_classes()}>" -class _WebVTTCueClassSpan(BaseModel): - span_type: Literal["c"] = "c" - components: list["_WebVTTCueComponent"] +class _WebVTTCueSpanStartTagAnnotated(_WebVTTCueSpanStartTag): + """WebVTT cue span start tag requiring an annotation.""" - @override - def __str__(self): - inner = "".join(str(span) for span in self.components) - return f"{inner}" + annotation: Annotated[str, Field(description="Cue span start tag annotation")] + @field_validator("annotation", mode="after") + @classmethod + def is_valid_annotation(cls, value: str) -> str: + for match in _ENTITY_PATTERN.finditer(value): + entity = match.group(1) + if entity not in _VALID_ENTITIES: + raise ValueError( + f"Annotation contains an invalid HTML entity: &{entity};" + ) + if "&" in re.sub(_ENTITY_PATTERN, "", value): + raise ValueError("Found '&' not part of a valid entity in annotation") + if any(ch in value for ch in {"\n", "\r", ">"}): + raise ValueError("Annotation contains invalid characters") + if len(value) == 0: + raise ValueError("Annotation cannot be empty") -class _WebVTTCueItalicSpan(BaseModel): - span_type: Literal["i"] = "i" - components: list["_WebVTTCueComponent"] + return value @override def __str__(self): - inner = "".join(str(span) for span in self.components) - return f"{inner}" + return f"<{self._get_name_with_classes()} {self.annotation}>" -class _WebVTTCueBoldSpan(BaseModel): - span_type: Literal["b"] = "b" - components: list["_WebVTTCueComponent"] +class _WebVTTCueComponentBase(BaseModel): + """WebVTT caption or subtitle cue component. - @override - def __str__(self): - inner = "".join(str(span) for span in self.components) - return f"{inner}" + All the WebVTT caption or subtitle cue components are represented by this class + except the WebVTT cue text span, which requires different definitions. + """ + kind: Literal["c", "b", "i", "u", "v", "lang"] + start_tag: _WebVTTCueSpanStartTag + internal_text: _WebVTTCueInternalText -class _WebVTTCueUnderlineSpan(BaseModel): - span_type: Literal["u"] = "u" - components: list["_WebVTTCueComponent"] + @model_validator(mode="after") + def check_tag_names_match(self) -> Self: + if self.kind != self.start_tag.name: + raise ValueError("The tag name of this cue component should be {self.kind}") + return self @override def __str__(self): - inner = "".join(str(span) for span in self.components) - return f"{inner}" + return f"{self.start_tag}{self.internal_text}" + + +class _WebVTTCueVoiceSpan(_WebVTTCueComponentBase): + """WebVTT cue voice span associated with a specific voice.""" + + kind: Literal["v"] = "v" + start_tag: _WebVTTCueSpanStartTagAnnotated + + +class _WebVTTCueClassSpan(_WebVTTCueComponentBase): + """WebVTT cue class span. + + It represents a span of text and it is used to annotate parts of the cue with + applicable classes without implying further meaning (such as italics or bold). + """ + + kind: Literal["c"] = "c" + start_tag: _WebVTTCueSpanStartTag = _WebVTTCueSpanStartTag(name="c") + + +class _WebVTTCueItalicSpan(_WebVTTCueComponentBase): + """WebVTT cue italic span representing a span of italic text.""" + + kind: Literal["i"] = "i" + start_tag: _WebVTTCueSpanStartTag = _WebVTTCueSpanStartTag(name="i") + + +class _WebVTTCueBoldSpan(_WebVTTCueComponentBase): + """WebVTT cue bold span representing a span of bold text.""" + + kind: Literal["b"] = "b" + start_tag: _WebVTTCueSpanStartTag = _WebVTTCueSpanStartTag(name="b") + + +class _WebVTTCueUnderlineSpan(_WebVTTCueComponentBase): + """WebVTT cue underline span representing a span of underline text.""" + + kind: Literal["u"] = "u" + start_tag: _WebVTTCueSpanStartTag = _WebVTTCueSpanStartTag(name="u") + + +class _WebVTTCueLanguageSpan(_WebVTTCueComponentBase): + """WebVTT cue language span. + + It represents a span of text and it is used to annotate parts of the cue where the + applicable language might be different than the surrounding text's, without + implying further meaning (such as italics or bold). + """ + + kind: Literal["lang"] = "lang" + start_tag: _WebVTTCueSpanStartTagAnnotated _WebVTTCueComponent = Annotated[ @@ -224,8 +305,12 @@ def __str__(self): _WebVTTCueBoldSpan, _WebVTTCueUnderlineSpan, _WebVTTCueVoiceSpan, + _WebVTTCueLanguageSpan, ], - Field(discriminator="span_type", description="The WebVTT cue component"), + Field( + discriminator="kind", + description="The type of WebVTT caption or subtitle cue component.", + ), ] @@ -243,14 +328,17 @@ class _WebVTTCueBlock(BaseModel): None, description="The WebVTT cue identifier" ) timings: Annotated[_WebVTTCueTimings, Field(description="The WebVTT cue timings")] - payload: Annotated[list[_WebVTTCueComponent], Field(description="The cue payload")] + payload: Annotated[ + list[_WebVTTCueComponentWithTerminator], + Field(description="The WebVTT caption or subtitle cue text"), + ] - _pattern_block: ClassVar[re.Pattern] = re.compile( - r"<(/?)(i|b|c|u|v(?:\.[^\t\n\r &<>.]+)*)(?:\s+([^>]*))?>" - ) - _pattern_voice_tag: ClassVar[re.Pattern] = re.compile( - r"^\.[^\t\n\r &<>]+)?" # zero or more classes - r"[ \t]+(?P[^\n\r&>]+)>" # required space and annotation + # pattern of a WebVTT cue span start/end tag + _pattern_tag: ClassVar[re.Pattern] = re.compile( + r"<(?P/?)" + r"(?Pi|b|c|u|v|lang)" + r"(?P(?:\.[^\t\n\r &<>.]+)*)" + r"(?:[ \t](?P[^\n\r&>]*))?>" ) @field_validator("payload", mode="after") @@ -284,74 +372,77 @@ def parse(cls, raw: str) -> "_WebVTTCueBlock": start=_WebVTTTimestamp(raw=start), end=_WebVTTTimestamp(raw=end) ) cue_text = " ".join(cue_lines).strip() - if cue_text.startswith("" not in cue_text: - # adding close tag for cue voice spans without end tag - cue_text += "" + # adding close tag for cue spans without end tag + for omm in {"v"}: + if cue_text.startswith(f"<{omm}") and f"" not in cue_text: + cue_text += f"" + break - stack: list[list[_WebVTTCueComponent]] = [[]] - tag_stack: list[Union[str, tuple]] = [] + stack: list[list[_WebVTTCueComponentWithTerminator]] = [[]] + tag_stack: list[dict] = [] pos = 0 - matches = list(cls._pattern_block.finditer(cue_text)) + matches = list(cls._pattern_tag.finditer(cue_text)) i = 0 while i < len(matches): match = matches[i] if match.start() > pos: - stack[-1].append(_WebVTTCueTextSpan(text=cue_text[pos : match.start()])) - tag = match.group(0) - - if tag.startswith(("", "", "", "")): - tag_type = tag[1:2] - tag_stack.append(tag_type) - stack.append([]) - elif tag == "": - children = stack.pop() - stack[-1].append(_WebVTTCueItalicSpan(components=children)) - tag_stack.pop() - elif tag == "": - children = stack.pop() - stack[-1].append(_WebVTTCueBoldSpan(components=children)) - tag_stack.pop() - elif tag == "": - children = stack.pop() - stack[-1].append(_WebVTTCueUnderlineSpan(components=children)) - tag_stack.pop() - elif tag == "": - children = stack.pop() - stack[-1].append(_WebVTTCueClassSpan(components=children)) - tag_stack.pop() - elif tag.startswith("")) else: parts.append(str(span)) - return "".join(parts) + return "".join(parts) + "\n" class _WebVTTFile(BaseModel): diff --git a/test/test_webvtt.py b/test/test_webvtt.py index ea4f2889..b4d408cb 100644 --- a/test/test_webvtt.py +++ b/test/test_webvtt.py @@ -1,11 +1,20 @@ -# Assisted by watsonx Code Assistant +"""Test the data model for WebVTT files. +Assisted by watsonx Code Assistant. +Examples extracted from https://www.w3.org/TR/webvtt1/ +Copyright © 2019 World Wide Web Consortium. +""" import pytest from pydantic import ValidationError from docling_core.types.doc.webvtt import ( + _WebVTTCueBlock, + _WebVTTCueComponentWithTerminator, + _WebVTTCueInternalText, _WebVTTCueItalicSpan, + _WebVTTCueLanguageSpan, + _WebVTTCueSpanStartTagAnnotated, _WebVTTCueTextSpan, _WebVTTCueTimings, _WebVTTCueVoiceSpan, @@ -18,7 +27,7 @@ GENERATE = GEN_TEST_DATA -def test_vtt_cue_commponents(): +def test_vtt_cue_commponents() -> None: """Test WebVTT components.""" valid_timestamps = [ "00:01:02.345", @@ -72,13 +81,13 @@ def test_vtt_cue_commponents(): """Test invalid cue timings with missing end.""" start = _WebVTTTimestamp(raw="00:10.500") with pytest.raises(ValidationError) as excinfo: - _WebVTTCueTimings(start=start) + _WebVTTCueTimings(start=start) # type: ignore[call-arg] assert "Field required" in str(excinfo.value) """Test invalid cue timings with missing start.""" end = _WebVTTTimestamp(raw="00:10.500") with pytest.raises(ValidationError) as excinfo: - _WebVTTCueTimings(end=end) + _WebVTTCueTimings(end=end) # type: ignore[call-arg] assert "Field required" in str(excinfo.value) """Test with valid text.""" @@ -116,44 +125,105 @@ def test_vtt_cue_commponents(): valid_annotation = "valid-annotation" invalid_annotation = "invalid\nannotation" with pytest.raises(ValidationError): - _WebVTTCueVoiceSpan(annotation=invalid_annotation) - assert _WebVTTCueVoiceSpan(annotation=valid_annotation) + _WebVTTCueSpanStartTagAnnotated(name="v", annotation=invalid_annotation) + assert _WebVTTCueSpanStartTagAnnotated(name="v", annotation=valid_annotation) """Test that classes validation works correctly.""" annotation = "speaker name" valid_classes = ["class1", "class2"] invalid_classes = ["class\nwith\nnewlines", ""] with pytest.raises(ValidationError): - _WebVTTCueVoiceSpan(annotation=annotation, classes=invalid_classes) - assert _WebVTTCueVoiceSpan(annotation=annotation, classes=valid_classes) + _WebVTTCueSpanStartTagAnnotated( + name="v", annotation=annotation, classes=invalid_classes + ) + assert _WebVTTCueSpanStartTagAnnotated( + name="v", annotation=annotation, classes=valid_classes + ) """Test that components validation works correctly.""" annotation = "speaker name" - valid_components = [_WebVTTCueTextSpan(text="random text")] + valid_components = [ + _WebVTTCueComponentWithTerminator( + component=_WebVTTCueTextSpan(text="random text") + ) + ] invalid_components = [123, "not a component"] with pytest.raises(ValidationError): - _WebVTTCueVoiceSpan(annotation=annotation, components=invalid_components) - assert _WebVTTCueVoiceSpan(annotation=annotation, components=valid_components) + _WebVTTCueInternalText(components=invalid_components) + assert _WebVTTCueInternalText(components=valid_components) """Test valid cue voice spans.""" cue_span = _WebVTTCueVoiceSpan( - annotation="speaker", - classes=["loud", "clear"], - components=[_WebVTTCueTextSpan(text="random text")], + start_tag=_WebVTTCueSpanStartTagAnnotated( + name="v", annotation="speaker", classes=["loud", "clear"] + ), + internal_text=_WebVTTCueInternalText( + components=[ + _WebVTTCueComponentWithTerminator( + component=_WebVTTCueTextSpan(text="random text") + ) + ] + ), ) - expected_str = "random text" assert str(cue_span) == expected_str cue_span = _WebVTTCueVoiceSpan( - annotation="speaker", - components=[_WebVTTCueTextSpan(text="random text")], + start_tag=_WebVTTCueSpanStartTagAnnotated(name="v", annotation="speaker"), + internal_text=_WebVTTCueInternalText( + components=[ + _WebVTTCueComponentWithTerminator( + component=_WebVTTCueTextSpan(text="random text") + ) + ] + ), ) expected_str = "random text" assert str(cue_span) == expected_str -def test_webvtt_file(): +def test_webvttcueblock_parse() -> None: + """Test the method parse of _WebVTTCueBlock class.""" + raw: str = ( + "04:02.500 --> 04:05.000\n" "J’ai commencé le basket à l'âge de 13, 14 ans\n" + ) + block: _WebVTTCueBlock = _WebVTTCueBlock.parse(raw) + assert str(block.timings) == "04:02.500 --> 04:05.000" + assert len(block.payload) == 1 + assert isinstance(block.payload[0], _WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[0].component, _WebVTTCueTextSpan) + assert ( + block.payload[0].component.text + == "J’ai commencé le basket à l'âge de 13, 14 ans" + ) + assert raw == str(block) + + raw = ( + "04:05.001 --> 04:07.800\n" + "Sur les playground, ici à Montpellier\n" + ) + block = _WebVTTCueBlock.parse(raw) + assert str(block.timings) == "04:05.001 --> 04:07.800" + assert len(block.payload) == 3 + assert isinstance(block.payload[0], _WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[0].component, _WebVTTCueTextSpan) + assert block.payload[0].component.text == "Sur les " + assert isinstance(block.payload[1], _WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[1].component, _WebVTTCueItalicSpan) + assert len(block.payload[1].component.internal_text.components) == 1 + lang_span = block.payload[1].component.internal_text.components[0].component + assert isinstance(lang_span, _WebVTTCueLanguageSpan) + assert isinstance( + lang_span.internal_text.components[0].component, _WebVTTCueTextSpan + ) + assert lang_span.internal_text.components[0].component.text == "playground" + assert isinstance(block.payload[2], _WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[2].component, _WebVTTCueTextSpan) + assert block.payload[2].component.text == ", ici à Montpellier" + assert raw == str(block) + + +def test_webvtt_file() -> None: """Test WebVTT files.""" with open("./test/data/webvtt/webvtt_example_01.vtt", encoding="utf-8") as f: content = f.read() @@ -163,16 +233,16 @@ def test_webvtt_file(): assert str(block.timings) == "00:32.500 --> 00:33.500" assert len(block.payload) == 1 cue_span = block.payload[0] - assert isinstance(cue_span, _WebVTTCueVoiceSpan) - assert cue_span.annotation == "Neil deGrasse Tyson" - assert not cue_span.classes - assert len(cue_span.components) == 1 - comp = cue_span.components[0] - assert isinstance(comp, _WebVTTCueItalicSpan) - assert len(comp.components) == 1 - comp2 = comp.components[0] - assert isinstance(comp2, _WebVTTCueTextSpan) - assert comp2.text == "Laughs" + assert isinstance(cue_span.component, _WebVTTCueVoiceSpan) + assert cue_span.component.start_tag.annotation == "Neil deGrasse Tyson" + assert not cue_span.component.start_tag.classes + assert len(cue_span.component.internal_text.components) == 1 + comp = cue_span.component.internal_text.components[0] + assert isinstance(comp.component, _WebVTTCueItalicSpan) + assert len(comp.component.internal_text.components) == 1 + comp2 = comp.component.internal_text.components[0] + assert isinstance(comp2.component, _WebVTTCueTextSpan) + assert comp2.component.text == "Laughs" with open("./test/data/webvtt/webvtt_example_02.vtt", encoding="utf-8") as f: content = f.read() @@ -182,8 +252,8 @@ def test_webvtt_file(): "WEBVTT\n\nNOTE Copyright © 2019 World Wide Web Consortium. " "https://www.w3.org/TR/webvtt1/\n\n" ) - reverse += "\n\n".join([str(block) for block in vtt.cue_blocks]) - assert content == reverse + reverse += "\n".join([str(block) for block in vtt.cue_blocks]) + assert content == reverse.rstrip() with open("./test/data/webvtt/webvtt_example_03.vtt", encoding="utf-8") as f: content = f.read() @@ -195,11 +265,10 @@ def test_webvtt_file(): assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0" assert str(block.timings) == "00:00:04.963 --> 00:00:08.571" assert len(block.payload) == 1 - assert isinstance(block.payload[0], _WebVTTCueVoiceSpan) + assert isinstance(block.payload[0].component, _WebVTTCueVoiceSpan) block = vtt.cue_blocks[2] - assert isinstance(cue_span, _WebVTTCueVoiceSpan) assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0" assert str(block.timings) == "00:00:10.683 --> 00:00:11.563" assert len(block.payload) == 1 - assert isinstance(block.payload[0], _WebVTTCueTextSpan) - assert block.payload[0].text == "Good." + assert isinstance(block.payload[0].component, _WebVTTCueTextSpan) + assert block.payload[0].component.text == "Good." From f1d59c71bd0a20da36afebf54c58948792c3bc85 Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Thu, 27 Nov 2025 18:58:35 +0100 Subject: [PATCH 4/5] refactor(DoclingDocument): create a new provenance model for media file types Signed-off-by: Cesar Berrospi Ramis --- docling_core/transforms/serializer/azure.py | 21 +- docling_core/transforms/serializer/common.py | 25 +- docling_core/transforms/serializer/doctags.py | 14 +- .../visualizer/key_value_visualizer.py | 17 +- .../visualizer/layout_visualizer.py | 18 +- .../visualizer/reading_order_visualizer.py | 5 +- .../transforms/visualizer/table_visualizer.py | 11 +- docling_core/types/doc/__init__.py | 1 + docling_core/types/doc/document.py | 207 +++++++++++----- docling_core/utils/legacy.py | 8 +- docs/DoclingDocument.json | 229 ++++++++++++++++-- 11 files changed, 451 insertions(+), 105 deletions(-) diff --git a/docling_core/transforms/serializer/azure.py b/docling_core/transforms/serializer/azure.py index 674f90b8..d4522a62 100644 --- a/docling_core/transforms/serializer/azure.py +++ b/docling_core/transforms/serializer/azure.py @@ -44,9 +44,10 @@ DocSerializer, create_ser_result, ) -from docling_core.types.doc.base import CoordOrigin -from docling_core.types.doc.document import ( +from docling_core.types.doc import ( + CoordOrigin, DocItem, + DocItemLabel, DoclingDocument, FormItem, InlineGroup, @@ -54,12 +55,12 @@ ListGroup, NodeItem, PictureItem, + ProvenanceItem, RefItem, RichTableCell, TableItem, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel def _bbox_to_polygon_coords( @@ -78,7 +79,7 @@ def _bbox_to_polygon_for_item( doc: DoclingDocument, item: DocItem ) -> Optional[list[float]]: """Compute a TOPLEFT-origin polygon for the first provenance of the item.""" - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return None prov = item.prov[0] @@ -189,7 +190,7 @@ def serialize( # Lists may be represented either as TextItem(ListItem) or via groups; # we treat any TextItem as a paragraph-like entry. - if item.prov: + if item.prov and isinstance(item.prov[0], ProvenanceItem): prov = item.prov[0] page_no = prov.page_no polygon = _bbox_to_polygon_for_item(doc, item) @@ -241,7 +242,7 @@ def serialize( ) -> SerializationResult: assert isinstance(doc_serializer, AzureDocSerializer) - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return create_ser_result() prov = item.prov[0] @@ -322,7 +323,7 @@ def serialize( ) -> SerializationResult: assert isinstance(doc_serializer, AzureDocSerializer) - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return create_ser_result() prov = item.prov[0] @@ -340,7 +341,11 @@ def serialize( for foot_ref in item.footnotes: if isinstance(foot_ref, RefItem): tgt = foot_ref.resolve(doc) - if isinstance(tgt, TextItem) and tgt.prov: + if ( + isinstance(tgt, TextItem) + and tgt.prov + and isinstance(tgt.prov[0], ProvenanceItem) + ): f_poly = _bbox_to_polygon_for_item(doc, tgt) if f_poly is not None: foots.append( diff --git a/docling_core/transforms/serializer/common.py b/docling_core/transforms/serializer/common.py index 4930e839..cb5ed095 100644 --- a/docling_core/transforms/serializer/common.py +++ b/docling_core/transforms/serializer/common.py @@ -34,11 +34,11 @@ SerializationResult, Span, ) -from docling_core.types.doc.document import ( - DOCUMENT_TOKENS_EXPORT_LABELS, +from docling_core.types.doc import ( ContentLayer, DescriptionAnnotation, DocItem, + DocItemLabel, DoclingDocument, FloatingItem, Formatting, @@ -51,12 +51,13 @@ PictureDataType, PictureItem, PictureMoleculeData, + ProvenanceItem, Script, TableAnnotationType, TableItem, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel +from docling_core.types.doc.document import DOCUMENT_TOKENS_EXPORT_LABELS _DEFAULT_LABELS = DOCUMENT_TOKENS_EXPORT_LABELS _DEFAULT_LAYERS = {cl for cl in ContentLayer} @@ -110,7 +111,11 @@ def _iterate_items( add_page_breaks=add_page_breaks, visited=my_visited, ): - if isinstance(it, DocItem) and it.prov: + if ( + isinstance(it, DocItem) + and it.prov + and isinstance(it.prov[0], ProvenanceItem) + ): page_no = it.prov[0].page_no if prev_page_nr is not None and page_no > prev_page_nr: yield _PageBreakNode( @@ -119,7 +124,11 @@ def _iterate_items( next_page=page_no, ), lvl break - elif isinstance(item, DocItem) and item.prov: + elif ( + isinstance(item, DocItem) + and item.prov + and isinstance(item.prov[0], ProvenanceItem) + ): page_no = item.prov[0].page_no if prev_page_nr is None or page_no > prev_page_nr: if prev_page_nr is not None: # close previous range @@ -288,7 +297,10 @@ def get_excluded_refs(self, **kwargs: Any) -> set[str]: params.pages is not None and ( (not item.prov) - or item.prov[0].page_no not in params.pages + or ( + isinstance(item.prov[0], ProvenanceItem) + and item.prov[0].page_no not in params.pages + ) ) ) ) @@ -635,6 +647,7 @@ def _get_applicable_pages(self) -> Optional[list[int]]: if ( isinstance(item, DocItem) and item.prov + and isinstance(item.prov[0], ProvenanceItem) and ( self.params.pages is None or item.prov[0].page_no in self.params.pages diff --git a/docling_core/transforms/serializer/doctags.py b/docling_core/transforms/serializer/doctags.py index 807b7750..a15999b2 100644 --- a/docling_core/transforms/serializer/doctags.py +++ b/docling_core/transforms/serializer/doctags.py @@ -26,11 +26,13 @@ _should_use_legacy_annotations, create_ser_result, ) -from docling_core.types.doc.base import BoundingBox from docling_core.types.doc.document import ( + BoundingBox, CodeItem, DocItem, + DocItemLabel, DoclingDocument, + DocumentToken, FloatingItem, FormItem, GroupItem, @@ -40,6 +42,7 @@ ListItem, NodeItem, PictureClassificationData, + PictureClassificationLabel, PictureItem, PictureMoleculeData, PictureTabularChartData, @@ -47,10 +50,9 @@ SectionHeaderItem, TableData, TableItem, + TableToken, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel, PictureClassificationLabel -from docling_core.types.doc.tokens import DocumentToken, TableToken def _wrap(text: str, wrap_tag: str) -> str: @@ -360,7 +362,7 @@ def serialize( results: list[SerializationResult] = [] page_no = 1 - if len(item.prov) > 0: + if len(item.prov) > 0 and isinstance(item.prov[0], ProvenanceItem): page_no = item.prov[0].page_no if params.add_location: @@ -380,7 +382,7 @@ def serialize( for cell in item.graph.cells: cell_txt = "" - if cell.prov is not None: + if cell.prov is not None and isinstance(cell.prov, ProvenanceItem): if len(doc.pages.keys()): page_w, page_h = doc.pages[page_no].size.as_tuple() cell_txt += DocumentToken.get_location( @@ -492,7 +494,7 @@ def _get_inline_location_tags( doc_items: list[DocItem] = [] for it, _ in doc.iterate_items(root=item): if isinstance(it, DocItem): - for prov in it.prov: + for prov in (im for im in it.prov if isinstance(im, ProvenanceItem)): boxes.append(prov.bbox) doc_items.append(it) if prov is None: diff --git a/docling_core/transforms/visualizer/key_value_visualizer.py b/docling_core/transforms/visualizer/key_value_visualizer.py index b0198455..1ef12654 100644 --- a/docling_core/transforms/visualizer/key_value_visualizer.py +++ b/docling_core/transforms/visualizer/key_value_visualizer.py @@ -16,8 +16,13 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc.document import ContentLayer, DoclingDocument -from docling_core.types.doc.labels import GraphCellLabel, GraphLinkLabel +from docling_core.types.doc import ( + ContentLayer, + DoclingDocument, + GraphCellLabel, + GraphLinkLabel, + ProvenanceItem, +) # --------------------------------------------------------------------------- # Helper functions / constants @@ -78,7 +83,11 @@ def _draw_key_value_layer( # First draw cells (rectangles + optional labels) # ------------------------------------------------------------------ for cell in cell_dict.values(): - if cell.prov is None or cell.prov.page_no != page_no: + if ( + cell.prov is None + or not isinstance(cell.prov, ProvenanceItem) + or cell.prov.page_no != page_no + ): continue # skip cells not on this page or without bbox tl_bbox = cell.prov.bbox.to_top_left_origin( @@ -127,6 +136,8 @@ def _draw_key_value_layer( if ( src_cell.prov is None or tgt_cell.prov is None + or not isinstance(src_cell.prov, ProvenanceItem) + or not isinstance(tgt_cell.prov, ProvenanceItem) or src_cell.prov.page_no != page_no or tgt_cell.prov.page_no != page_no ): diff --git a/docling_core/transforms/visualizer/layout_visualizer.py b/docling_core/transforms/visualizer/layout_visualizer.py index 886ad8b4..8478a198 100644 --- a/docling_core/transforms/visualizer/layout_visualizer.py +++ b/docling_core/transforms/visualizer/layout_visualizer.py @@ -10,10 +10,16 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc import DocItemLabel -from docling_core.types.doc.base import CoordOrigin -from docling_core.types.doc.document import ContentLayer, DocItem, DoclingDocument -from docling_core.types.doc.page import BoundingRectangle, TextCell +from docling_core.types.doc import ( + BoundingRectangle, + ContentLayer, + CoordOrigin, + DocItem, + DocItemLabel, + DoclingDocument, + ProvenanceItem, + TextCell, +) class _TLBoundingRectangle(BoundingRectangle): @@ -157,7 +163,9 @@ def _draw_doc_layout( if len(elem.prov) == 0: continue # Skip elements without provenances - for prov in elem.prov: + for prov in ( + item for item in elem.prov if isinstance(item, ProvenanceItem) + ): page_nr = prov.page_no if page_nr in my_images: diff --git a/docling_core/transforms/visualizer/reading_order_visualizer.py b/docling_core/transforms/visualizer/reading_order_visualizer.py index c012f22b..0e2aa6a1 100644 --- a/docling_core/transforms/visualizer/reading_order_visualizer.py +++ b/docling_core/transforms/visualizer/reading_order_visualizer.py @@ -15,6 +15,7 @@ DocItem, DoclingDocument, PictureItem, + ProvenanceItem, ) @@ -139,7 +140,9 @@ def _draw_doc_reading_order( if len(elem.prov) == 0: continue # Skip elements without provenances - for prov in elem.prov: + for prov in ( + item for item in elem.prov if isinstance(item, ProvenanceItem) + ): page_no = prov.page_no image = my_images.get(page_no) diff --git a/docling_core/transforms/visualizer/table_visualizer.py b/docling_core/transforms/visualizer/table_visualizer.py index 0a722959..c173f33f 100644 --- a/docling_core/transforms/visualizer/table_visualizer.py +++ b/docling_core/transforms/visualizer/table_visualizer.py @@ -10,7 +10,12 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc.document import ContentLayer, DoclingDocument, TableItem +from docling_core.types.doc import ( + ContentLayer, + DoclingDocument, + ProvenanceItem, + TableItem, +) _log = logging.getLogger(__name__) @@ -171,12 +176,12 @@ def _draw_doc_tables( image = deepcopy(pil_img) my_images[page_nr] = image - for idx, (elem, _) in enumerate( + for _, (elem, _) in enumerate( doc.iterate_items(included_content_layers=included_content_layers) ): if not isinstance(elem, TableItem): continue - if len(elem.prov) == 0: + if len(elem.prov) == 0 or not isinstance(elem.prov[0], ProvenanceItem): continue # Skip elements without provenances if len(elem.prov) == 1: diff --git a/docling_core/types/doc/__init__.py b/docling_core/types/doc/__init__.py index 3c699f89..f0e0e92d 100644 --- a/docling_core/types/doc/__init__.py +++ b/docling_core/types/doc/__init__.py @@ -61,6 +61,7 @@ Script, SectionHeaderItem, SummaryMetaField, + TableAnnotationType, TableCell, TableData, TableItem, diff --git a/docling_core/types/doc/document.py b/docling_core/types/doc/document.py index 08d1a045..d9b77282 100644 --- a/docling_core/types/doc/document.py +++ b/docling_core/types/doc/document.py @@ -35,9 +35,11 @@ AnyUrl, BaseModel, ConfigDict, + Discriminator, Field, FieldSerializationInfo, StringConstraints, + Tag, computed_field, field_serializer, field_validator, @@ -66,6 +68,7 @@ ) from docling_core.types.doc.tokens import DocumentToken, TableToken from docling_core.types.doc.utils import parse_otsl_table_content, relative_path +from docling_core.types.doc.webvtt import _WebVTTTimestamp _logger = logging.getLogger(__name__) @@ -1206,11 +1209,85 @@ def from_multipage_doctags_and_images( class ProvenanceItem(BaseModel): - """ProvenanceItem.""" + """Provenance information for elements extracted from a textual document. - page_no: int - bbox: BoundingBox - charspan: Tuple[int, int] + A `ProvenanceItem` object acts as a lightweight pointer back into the original + document for an extracted element. It applies to documents with an explicity + or implicit layout, such as PDF, HTML, docx, or pptx. + """ + + page_no: Annotated[int, Field(description="Page number")] + bbox: Annotated[BoundingBox, Field(description="Bounding box")] + charspan: Annotated[ + tuple[int, int], Field(description="Character span (0-indexed)") + ] + + +class ProvenanceTrack(BaseModel): + """Provenance information for elements extracted from media assets. + + A `ProvenanceTrack` instance describes a cue in a text track associated with a + media element (audio, video, subtitles, screen recordings, ...). + """ + + start_time: Annotated[ + _WebVTTTimestamp, + Field( + examples=["00.11.000", "00:00:06.500", "01:28:34.300"], + description="Start time offset of the track cue", + ), + ] + end_time: Annotated[ + _WebVTTTimestamp, + Field( + examples=["00.12.000", "00:00:08.200", "01:29:30.100"], + description="End time offset of the track cue", + ), + ] + identifier: Optional[str] = Field( + None, + examples=["test", "123", "b72d946"], + description="An identifier of the cue", + ) + voice: Optional[str] = Field( + None, + examples=["Mary", "Fred", "Name Surname"], + description="The cue voice (speaker)", + ) + language: Optional[str] = Field( + None, + examples=["en", "en-GB", "fr-CA"], + description="Language of the cue in BCP 47 language tag format", + ) + classes: Optional[list[str]] = Field( + None, + min_length=1, + examples=["first", "loud", "yellow"], + description="Classes for describing the cue significance", + ) + + +def get_provenance_discriminator_value(v: Any) -> str: + """Callable discriminator for provenance instances. + + Args: + v: Either dict or model input. + + Returns: + A string discriminator of provenance instances. + """ + fields = {"bbox", "page_no", "charspan"} + if isinstance(v, dict): + return "item" if any(f in v for f in fields) else "track" + return "item" if any(hasattr(v, f) for f in fields) else "track" + + +ProvenanceType = Annotated[ + Union[ + Annotated[ProvenanceItem, Tag("item")], Annotated[ProvenanceTrack, Tag("track")] + ], + Discriminator(get_provenance_discriminator_value), +] class ContentLayer(str, Enum): @@ -1534,7 +1611,7 @@ class DocItem( """DocItem.""" label: DocItemLabel - prov: List[ProvenanceItem] = [] + prov: List[ProvenanceType] = [] def get_location_tokens( self, @@ -1549,7 +1626,7 @@ def get_location_tokens( return "" location = "" - for prov in self.prov: + for prov in (item for item in self.prov if isinstance(item, ProvenanceItem)): page_w, page_h = doc.pages[prov.page_no].size.as_tuple() loc_str = DocumentToken.get_location( @@ -1573,10 +1650,13 @@ def get_image( if a valid image of the page containing this DocItem is not available in doc. """ - if not len(self.prov): + if not self.prov or prov_index >= len(self.prov): + return None + prov = self.prov[prov_index] + if not isinstance(prov, ProvenanceItem): return None - page = doc.pages.get(self.prov[prov_index].page_no) + page = doc.pages.get(prov.page_no) if page is None or page.size is None or page.image is None: return None @@ -1584,9 +1664,9 @@ def get_image( if not page_image: return None crop_bbox = ( - self.prov[prov_index] - .bbox.to_top_left_origin(page_height=page.size.height) - .scale_to_size(old_size=page.size, new_size=page.image.size) + prov.bbox.to_top_left_origin(page_height=page.size.height).scale_to_size( + old_size=page.size, new_size=page.image.size + ) # .scaled(scale=page_image.height / page.size.height) ) return page_image.crop(crop_bbox.as_tuple()) @@ -2278,7 +2358,7 @@ def export_to_otsl( return "" page_no = 0 - if len(self.prov) > 0: + if len(self.prov) > 0 and isinstance(self.prov[0], ProvenanceItem): page_no = self.prov[0].page_no for i in range(nrows): @@ -2410,7 +2490,7 @@ class GraphCell(BaseModel): text: str # sanitized text orig: str # text as seen on document - prov: Optional[ProvenanceItem] = None + prov: Optional[ProvenanceType] = None # in case you have a text, table or picture item item_ref: Optional[RefItem] = None @@ -3101,7 +3181,7 @@ def add_list_item( enumerated: bool = False, marker: Optional[str] = None, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3112,7 +3192,7 @@ def add_list_item( :param label: str: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ @@ -3153,7 +3233,7 @@ def add_text( label: DocItemLabel, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3164,7 +3244,7 @@ def add_text( :param label: str: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ @@ -3259,7 +3339,7 @@ def add_table( self, data: TableData, caption: Optional[Union[TextItem, RefItem]] = None, # This is not cool yet. - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, label: DocItemLabel = DocItemLabel.TABLE, content_layer: Optional[ContentLayer] = None, @@ -3269,7 +3349,7 @@ def add_table( :param data: TableData: :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) :param label: DocItemLabel: (Default value = DocItemLabel.TABLE) @@ -3305,7 +3385,7 @@ def add_picture( annotations: Optional[List[PictureDataType]] = None, image: Optional[ImageRef] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, ): @@ -3314,7 +3394,7 @@ def add_picture( :param data: Optional[List[PictureData]]: (Default value = None) :param caption: Optional[Union[TextItem: :param RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3346,7 +3426,7 @@ def add_title( self, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3357,7 +3437,7 @@ def add_title( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3392,7 +3472,7 @@ def add_code( code_language: Optional[CodeLanguageLabel] = None, orig: Optional[str] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3405,7 +3485,7 @@ def add_code( :param orig: Optional[str]: (Default value = None) :param caption: Optional[Union[TextItem: :param RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3442,7 +3522,7 @@ def add_formula( self, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3453,7 +3533,7 @@ def add_formula( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3487,7 +3567,7 @@ def add_heading( text: str, orig: Optional[str] = None, level: LevelNumber = 1, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3499,7 +3579,7 @@ def add_heading( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3532,13 +3612,13 @@ def add_heading( def add_key_values( self, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, ): """add_key_values. :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3563,13 +3643,13 @@ def add_key_values( def add_form( self, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, ): """add_form. :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3766,7 +3846,7 @@ def insert_list_item( enumerated: bool = False, marker: Optional[str] = None, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -3779,7 +3859,7 @@ def insert_list_item( :param enumerated: bool: (Default value = False) :param marker: Optional[str]: (Default value = None) :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -3840,7 +3920,7 @@ def insert_text( label: DocItemLabel, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -3852,7 +3932,7 @@ def insert_text( :param label: DocItemLabel: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -3952,7 +4032,7 @@ def insert_table( sibling: NodeItem, data: TableData, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, label: DocItemLabel = DocItemLabel.TABLE, content_layer: Optional[ContentLayer] = None, annotations: Optional[list[TableAnnotationType]] = None, @@ -3963,7 +4043,7 @@ def insert_table( :param sibling: NodeItem: :param data: TableData: :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param label: DocItemLabel: (Default value = DocItemLabel.TABLE) :param content_layer: Optional[ContentLayer]: (Default value = None) :param annotations: Optional[List[TableAnnotationType]]: (Default value = None) @@ -4000,7 +4080,7 @@ def insert_picture( annotations: Optional[List[PictureDataType]] = None, image: Optional[ImageRef] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, after: bool = True, ) -> PictureItem: @@ -4010,7 +4090,7 @@ def insert_picture( :param annotations: Optional[List[PictureDataType]]: (Default value = None) :param image: Optional[ImageRef]: (Default value = None) :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param after: bool: (Default value = True) @@ -4044,7 +4124,7 @@ def insert_title( sibling: NodeItem, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4055,7 +4135,7 @@ def insert_title( :param sibling: NodeItem: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4095,7 +4175,7 @@ def insert_code( code_language: Optional[CodeLanguageLabel] = None, orig: Optional[str] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4108,7 +4188,7 @@ def insert_code( :param code_language: Optional[str]: (Default value = None) :param orig: Optional[str]: (Default value = None) :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4150,7 +4230,7 @@ def insert_formula( sibling: NodeItem, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4161,7 +4241,7 @@ def insert_formula( :param sibling: NodeItem: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4200,7 +4280,7 @@ def insert_heading( text: str, orig: Optional[str] = None, level: LevelNumber = 1, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4212,7 +4292,7 @@ def insert_heading( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4250,14 +4330,14 @@ def insert_key_values( self, sibling: NodeItem, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, after: bool = True, ) -> KeyValueItem: """Creates a new KeyValueItem item and inserts it into the document. :param sibling: NodeItem: :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param after: bool: (Default value = True) :returns: KeyValueItem: The newly created KeyValueItem item. @@ -4279,14 +4359,14 @@ def insert_form( self, sibling: NodeItem, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, after: bool = True, ) -> FormItem: """Creates a new FormItem item and inserts it into the document. :param sibling: NodeItem: :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param after: bool: (Default value = True) :returns: FormItem: The newly created FormItem item. @@ -4654,7 +4734,11 @@ def _iterate_items_with_stack( not isinstance(root, DocItem) or ( page_nrs is None - or any(prov.page_no in page_nrs for prov in root.prov) + or any( + prov.page_no in page_nrs + for prov in root.prov + if isinstance(prov, ProvenanceItem) + ) ) ) and root.content_layer in my_layers @@ -4764,7 +4848,7 @@ def _with_pictures_refs( image_dir.mkdir(parents=True, exist_ok=True) if image_dir.is_dir(): - for item, level in result.iterate_items(page_no=page_no, with_groups=False): + for item, _ in result.iterate_items(page_no=page_no, with_groups=False): if isinstance(item, PictureItem): img = item.get_image(doc=self) if img is not None: @@ -4784,12 +4868,15 @@ def _with_pictures_refs( else: obj_path = loc_path - if item.image is None: + if item.image is None and isinstance( + item.prov[0], ProvenanceItem + ): scale = img.size[0] / item.prov[0].bbox.width item.image = ImageRef.from_pil( image=img, dpi=round(72 * scale) ) - item.image.uri = Path(obj_path) + elif item.image is not None: + item.image.uri = Path(obj_path) # if item.image._pil is not None: # item.image._pil.close() @@ -6268,7 +6355,11 @@ def index( if isinstance(new_item, DocItem): # update page numbers # NOTE other prov sources (e.g. GraphCell) currently not covered - for prov in new_item.prov: + for prov in ( + item + for item in new_item.prov + if isinstance(item, ProvenanceItem) + ): prov.page_no += page_delta if item.parent: diff --git a/docling_core/utils/legacy.py b/docling_core/utils/legacy.py index 6f8fdf99..b3b21364 100644 --- a/docling_core/utils/legacy.py +++ b/docling_core/utils/legacy.py @@ -7,20 +7,23 @@ from docling_core.types.doc import ( BoundingBox, + ContentLayer, CoordOrigin, DocItem, DocItemLabel, DoclingDocument, DocumentOrigin, + GroupItem, + ListItem, PictureItem, ProvenanceItem, SectionHeaderItem, Size, TableCell, + TableData, TableItem, TextItem, ) -from docling_core.types.doc.document import ContentLayer, GroupItem, ListItem, TableData from docling_core.types.legacy_doc.base import ( BaseCell, BaseText, @@ -164,6 +167,7 @@ def docling_document_to_legacy(doc: DoclingDocument, fallback_filaname: str = "f span=[0, len(item.text)], ) for p in item.prov + if isinstance(p, ProvenanceItem) ] main_text.append( BaseText( @@ -287,6 +291,7 @@ def _make_spans(cell: TableCell, table_item: TableItem): span=[0, 0], ) for p in item.prov + if isinstance(p, ProvenanceItem) ], ) ) @@ -314,6 +319,7 @@ def _make_spans(cell: TableCell, table_item: TableItem): span=[0, len(caption)], ) for p in item.prov + if isinstance(p, ProvenanceItem) ], obj_type=doc_item_label_to_legacy_type(item.label), text=caption, diff --git a/docs/DoclingDocument.json b/docs/DoclingDocument.json index 365a62bf..d732f0f8 100644 --- a/docs/DoclingDocument.json +++ b/docs/DoclingDocument.json @@ -233,7 +233,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -606,7 +613,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -740,7 +754,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -812,13 +833,21 @@ "prov": { "anyOf": [ { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, { "type": "null" } ], - "default": null + "default": null, + "title": "Prov" }, "item_ref": { "anyOf": [ @@ -1137,7 +1166,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -1301,7 +1337,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -1669,7 +1712,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2054,16 +2104,19 @@ "type": "object" }, "ProvenanceItem": { - "description": "ProvenanceItem.", + "description": "Provenance information for elements extracted from a textual document.\n\nA `ProvenanceItem` object acts as a lightweight pointer back into the original\ndocument for an extracted element. It applies to documents with an explicity\nor implicit layout, such as PDF, HTML, docx, or pptx.", "properties": { "page_no": { + "description": "Page number", "title": "Page No", "type": "integer" }, "bbox": { - "$ref": "#/$defs/BoundingBox" + "$ref": "#/$defs/BoundingBox", + "description": "Bounding box" }, "charspan": { + "description": "Character span (0-indexed)", "maxItems": 2, "minItems": 2, "prefixItems": [ @@ -2086,6 +2139,111 @@ "title": "ProvenanceItem", "type": "object" }, + "ProvenanceTrack": { + "description": "Provenance information for elements extracted from media assets.\n\nA `ProvenanceTrack` instance describes a cue in a text track associated with a\nmedia element (audio, video, subtitles, screen recordings, ...).", + "properties": { + "start_time": { + "$ref": "#/$defs/_WebVTTTimestamp", + "description": "Start time offset of the track cue", + "examples": [ + "00.11.000", + "00:00:06.500", + "01:28:34.300" + ] + }, + "end_time": { + "$ref": "#/$defs/_WebVTTTimestamp", + "description": "End time offset of the track cue", + "examples": [ + "00.12.000", + "00:00:08.200", + "01:29:30.100" + ] + }, + "identifier": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "An identifier of the cue", + "examples": [ + "test", + "123", + "b72d946" + ], + "title": "Identifier" + }, + "voice": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "The cue voice (speaker)", + "examples": [ + "Mary", + "Fred", + "Name Surname" + ], + "title": "Voice" + }, + "language": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Language of the cue in BCP 47 language tag format", + "examples": [ + "en", + "en-GB", + "fr-CA" + ], + "title": "Language" + }, + "classes": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Classes for describing the cue significance", + "examples": [ + "first", + "loud", + "yellow" + ], + "title": "Classes" + } + }, + "required": [ + "start_time", + "end_time" + ], + "title": "ProvenanceTrack", + "type": "object" + }, "RefItem": { "description": "RefItem.", "properties": { @@ -2242,7 +2400,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2529,7 +2694,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2726,7 +2898,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2830,7 +3009,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2880,6 +3066,21 @@ ], "title": "TitleItem", "type": "object" + }, + "_WebVTTTimestamp": { + "description": "WebVTT timestamp.\n\nA WebVTT timestamp is always interpreted relative to the current playback position\nof the media data that the WebVTT file is to be synchronized with.", + "properties": { + "raw": { + "description": "A representation of the WebVTT Timestamp as a single string", + "title": "Raw", + "type": "string" + } + }, + "required": [ + "raw" + ], + "title": "_WebVTTTimestamp", + "type": "object" } }, "description": "DoclingDocument.", From c2efd15823813b9392c8d0d52aa232b52636cf9a Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Thu, 4 Dec 2025 14:49:53 +0100 Subject: [PATCH 5/5] refactor(webvtt): make WebVTTTimestamp public Since WebVTTTimestamp is used in DoclingDocument, the class should be public. Strengthen validation of cue language start tag annotation. Signed-off-by: Cesar Berrospi Ramis --- docling_core/types/doc/document.py | 6 ++-- docling_core/types/doc/webvtt.py | 52 +++++++++++++++++++++++------- docs/DoclingDocument.json | 10 +++--- test/test_webvtt.py | 31 ++++++++++++------ 4 files changed, 69 insertions(+), 30 deletions(-) diff --git a/docling_core/types/doc/document.py b/docling_core/types/doc/document.py index d9b77282..649d8036 100644 --- a/docling_core/types/doc/document.py +++ b/docling_core/types/doc/document.py @@ -68,7 +68,7 @@ ) from docling_core.types.doc.tokens import DocumentToken, TableToken from docling_core.types.doc.utils import parse_otsl_table_content, relative_path -from docling_core.types.doc.webvtt import _WebVTTTimestamp +from docling_core.types.doc.webvtt import WebVTTTimestamp _logger = logging.getLogger(__name__) @@ -1231,14 +1231,14 @@ class ProvenanceTrack(BaseModel): """ start_time: Annotated[ - _WebVTTTimestamp, + WebVTTTimestamp, Field( examples=["00.11.000", "00:00:06.500", "01:28:34.300"], description="Start time offset of the track cue", ), ] end_time: Annotated[ - _WebVTTTimestamp, + WebVTTTimestamp, Field( examples=["00.12.000", "00:00:08.200", "01:29:30.100"], description="End time offset of the track cue", diff --git a/docling_core/types/doc/webvtt.py b/docling_core/types/doc/webvtt.py index 6d60a2d8..12c18137 100644 --- a/docling_core/types/doc/webvtt.py +++ b/docling_core/types/doc/webvtt.py @@ -28,9 +28,18 @@ class _WebVTTLineTerminator(str, Enum): ] -class _WebVTTTimestamp(BaseModel): +class WebVTTTimestamp(BaseModel): """WebVTT timestamp. + The timestamp is a string consisting of the following components in the given order: + + - hours (optional, required if non-zero): two or more digits + - minutes: two digits between 0 and 59 + - a colon character (:) + - seconds: two digits between 0 and 59 + - a full stop character (.) + - thousandths of a second: three digits + A WebVTT timestamp is always interpreted relative to the current playback position of the media data that the WebVTT file is to be synchronized with. """ @@ -54,6 +63,7 @@ class _WebVTTTimestamp(BaseModel): @model_validator(mode="after") def validate_raw(self) -> Self: + """Validate the WebVTT timestamp as a string.""" m = self._pattern.match(self.raw) if not m: raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}") @@ -81,16 +91,15 @@ def seconds(self) -> float: @override def __str__(self) -> str: + """Return a string representation of a WebVTT timestamp.""" return self.raw class _WebVTTCueTimings(BaseModel): """WebVTT cue timings.""" - start: Annotated[ - _WebVTTTimestamp, Field(description="Start time offset of the cue") - ] - end: Annotated[_WebVTTTimestamp, Field(description="End time offset of the cue")] + start: Annotated[WebVTTTimestamp, Field(description="Start time offset of the cue")] + end: Annotated[WebVTTTimestamp, Field(description="End time offset of the cue")] @model_validator(mode="after") def check_order(self) -> Self: @@ -224,6 +233,21 @@ def __str__(self): return f"<{self._get_name_with_classes()} {self.annotation}>" +class _WebVTTCueLanguageSpanStartTag(_WebVTTCueSpanStartTagAnnotated): + _bcp47_regex = re.compile(r"^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,8})*$", re.IGNORECASE) + + name: Literal["lang"] = Field("lang", description="The tag name") + annotation: Annotated[ + str, + Field( + pattern=_bcp47_regex.pattern, + min_length=2, + max_length=99, + description="Cue language span start tag annotation", + ), + ] + + class _WebVTTCueComponentBase(BaseModel): """WebVTT caption or subtitle cue component. @@ -294,7 +318,7 @@ class _WebVTTCueLanguageSpan(_WebVTTCueComponentBase): """ kind: Literal["lang"] = "lang" - start_tag: _WebVTTCueSpanStartTagAnnotated + start_tag: _WebVTTCueLanguageSpanStartTag _WebVTTCueComponent = Annotated[ @@ -369,7 +393,7 @@ def parse(cls, raw: str) -> "_WebVTTCueBlock": start, end = [t.strip() for t in timing_line.split("-->")] end = re.split(" |\t", end)[0] # ignore the cue settings list timings: _WebVTTCueTimings = _WebVTTCueTimings( - start=_WebVTTTimestamp(raw=start), end=_WebVTTTimestamp(raw=end) + start=WebVTTTimestamp(raw=start), end=WebVTTTimestamp(raw=end) ) cue_text = " ".join(cue_lines).strip() # adding close tag for cue spans without end tag @@ -409,13 +433,17 @@ def parse(cls, raw: str) -> "_WebVTTCueBlock": classes: list[str] = [] if class_string: classes = [c for c in class_string.split(".") if c] - st = ( - _WebVTTCueSpanStartTagAnnotated( + st: _WebVTTCueSpanStartTag + if annotation and ct == "lang": + st = _WebVTTCueLanguageSpanStartTag( name=ct, classes=classes, annotation=annotation.strip() ) - if annotation - else _WebVTTCueSpanStartTag(name=ct, classes=classes) - ) + elif annotation: + st = _WebVTTCueSpanStartTagAnnotated( + name=ct, classes=classes, annotation=annotation.strip() + ) + else: + st = _WebVTTCueSpanStartTag(name=ct, classes=classes) it = _WebVTTCueInternalText(components=children) cp: _WebVTTCueComponent if ct == "c": diff --git a/docs/DoclingDocument.json b/docs/DoclingDocument.json index d732f0f8..9d96d5d6 100644 --- a/docs/DoclingDocument.json +++ b/docs/DoclingDocument.json @@ -2143,7 +2143,7 @@ "description": "Provenance information for elements extracted from media assets.\n\nA `ProvenanceTrack` instance describes a cue in a text track associated with a\nmedia element (audio, video, subtitles, screen recordings, ...).", "properties": { "start_time": { - "$ref": "#/$defs/_WebVTTTimestamp", + "$ref": "#/$defs/WebVTTTimestamp", "description": "Start time offset of the track cue", "examples": [ "00.11.000", @@ -2152,7 +2152,7 @@ ] }, "end_time": { - "$ref": "#/$defs/_WebVTTTimestamp", + "$ref": "#/$defs/WebVTTTimestamp", "description": "End time offset of the track cue", "examples": [ "00.12.000", @@ -3067,8 +3067,8 @@ "title": "TitleItem", "type": "object" }, - "_WebVTTTimestamp": { - "description": "WebVTT timestamp.\n\nA WebVTT timestamp is always interpreted relative to the current playback position\nof the media data that the WebVTT file is to be synchronized with.", + "WebVTTTimestamp": { + "description": "WebVTT timestamp.\n\nThe timestamp is a string consisting of the following components in the given order:\n\n- hours (optional, required if non-zero): two or more digits\n- minutes: two digits between 0 and 59\n- a colon character (:)\n- seconds: two digits between 0 and 59\n- a full stop character (.)\n- thousandths of a second: three digits\n\nA WebVTT timestamp is always interpreted relative to the current playback position\nof the media data that the WebVTT file is to be synchronized with.", "properties": { "raw": { "description": "A representation of the WebVTT Timestamp as a single string", @@ -3079,7 +3079,7 @@ "required": [ "raw" ], - "title": "_WebVTTTimestamp", + "title": "WebVTTTimestamp", "type": "object" } }, diff --git a/test/test_webvtt.py b/test/test_webvtt.py index b4d408cb..2295c682 100644 --- a/test/test_webvtt.py +++ b/test/test_webvtt.py @@ -9,17 +9,18 @@ from pydantic import ValidationError from docling_core.types.doc.webvtt import ( + WebVTTTimestamp, _WebVTTCueBlock, _WebVTTCueComponentWithTerminator, _WebVTTCueInternalText, _WebVTTCueItalicSpan, _WebVTTCueLanguageSpan, + _WebVTTCueLanguageSpanStartTag, _WebVTTCueSpanStartTagAnnotated, _WebVTTCueTextSpan, _WebVTTCueTimings, _WebVTTCueVoiceSpan, _WebVTTFile, - _WebVTTTimestamp, ) from .test_data_gen_flag import GEN_TEST_DATA @@ -42,7 +43,7 @@ def test_vtt_cue_commponents() -> None: 0.0, ] for idx, ts in enumerate(valid_timestamps): - model = _WebVTTTimestamp(raw=ts) + model = WebVTTTimestamp(raw=ts) assert model.seconds == valid_total_seconds[idx] """Test invalid WebVTT timestamps.""" @@ -57,35 +58,35 @@ def test_vtt_cue_commponents() -> None: ] for ts in invalid_timestamps: with pytest.raises(ValidationError): - _WebVTTTimestamp(raw=ts) + WebVTTTimestamp(raw=ts) """Test the timestamp __str__ method.""" - model = _WebVTTTimestamp(raw="00:01:02.345") + model = WebVTTTimestamp(raw="00:01:02.345") assert str(model) == "00:01:02.345" """Test valid cue timings.""" - start = _WebVTTTimestamp(raw="00:10.005") - end = _WebVTTTimestamp(raw="00:14.007") + start = WebVTTTimestamp(raw="00:10.005") + end = WebVTTTimestamp(raw="00:14.007") cue_timings = _WebVTTCueTimings(start=start, end=end) assert cue_timings.start == start assert cue_timings.end == end assert str(cue_timings) == "00:10.005 --> 00:14.007" """Test invalid cue timings with end timestamp before start.""" - start = _WebVTTTimestamp(raw="00:10.700") - end = _WebVTTTimestamp(raw="00:10.500") + start = WebVTTTimestamp(raw="00:10.700") + end = WebVTTTimestamp(raw="00:10.500") with pytest.raises(ValidationError) as excinfo: _WebVTTCueTimings(start=start, end=end) assert "End timestamp must be greater than start timestamp" in str(excinfo.value) """Test invalid cue timings with missing end.""" - start = _WebVTTTimestamp(raw="00:10.500") + start = WebVTTTimestamp(raw="00:10.500") with pytest.raises(ValidationError) as excinfo: _WebVTTCueTimings(start=start) # type: ignore[call-arg] assert "Field required" in str(excinfo.value) """Test invalid cue timings with missing start.""" - end = _WebVTTTimestamp(raw="00:10.500") + end = WebVTTTimestamp(raw="00:10.500") with pytest.raises(ValidationError) as excinfo: _WebVTTCueTimings(end=end) # type: ignore[call-arg] assert "Field required" in str(excinfo.value) @@ -272,3 +273,13 @@ def test_webvtt_file() -> None: assert len(block.payload) == 1 assert isinstance(block.payload[0].component, _WebVTTCueTextSpan) assert block.payload[0].component.text == "Good." + + +def test_webvtt_cue_language_span_start_tag(): + _WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en"}') + _WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en-US"}') + _WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "zh-Hant"}') + with pytest.raises(ValidationError, match="should match pattern"): + _WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en_US"}') + with pytest.raises(ValidationError, match="should match pattern"): + _WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "123-de"}')