diff --git a/nemo_retriever/src/nemo_retriever/model/local/__init__.py b/nemo_retriever/src/nemo_retriever/model/local/__init__.py
index af068fa7d..106c01b5c 100644
--- a/nemo_retriever/src/nemo_retriever/model/local/__init__.py
+++ b/nemo_retriever/src/nemo_retriever/model/local/__init__.py
@@ -14,6 +14,7 @@ __all__ = [
     "NemotronPageElementsV3",
     "NemotronOCRV1",
+    "NemotronOCRV2",
     "NemotronTableStructureV1",
     "NemotronGraphicElementsV1",
     "NemotronParseV12",
@@ -32,6 +33,10 @@ def __getattr__(name: str):
         from .nemotron_ocr_v1 import NemotronOCRV1
 
         return NemotronOCRV1
+    if name == "NemotronOCRV2":
+        from .nemotron_ocr_v2 import NemotronOCRV2
+
+        return NemotronOCRV2
     if name == "NemotronTableStructureV1":
         from .nemotron_table_structure_v1 import NemotronTableStructureV1
diff --git a/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v2.py b/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v2.py
new file mode 100644
index 000000000..0c0d393d5
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v2.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional, Tuple, Union  # noqa: F401
+
+import base64
+import io
+import os
+from pathlib import Path  # noqa: F401
+
+import numpy as np
+import torch
+from nemo_retriever.utils.hf_cache import configure_global_hf_cache_base
+from ..model import BaseModel, RunMode
+
+from PIL import Image
+
+
+class NemotronOCRV2(BaseModel):
+    """
+    Nemotron OCR v2 model for optical character recognition.
+
+    End-to-end OCR model that integrates:
+    - Text detector for region localization (RegNetX-8GF backbone)
+    - Text recognizer for transcription (pre-norm Transformer)
+    - Relational model for layout and reading-order analysis
+
+    Supports both the English-only (v2_english) and multilingual
+    (v2_multilingual: EN, ZH, JA, KO, RU) variants.
+    """
+
+    def __init__(
+        self,
+        model_dir: Optional[str] = None,
+    ) -> None:
+        super().__init__()
+        configure_global_hf_cache_base()
+        from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2 as _NemotronOCRV2  # local-only import
+
+        if model_dir:
+            self._model = _NemotronOCRV2(model_dir=model_dir)
+        else:
+            self._model = _NemotronOCRV2()
+        self._enable_trt = os.getenv("RETRIEVER_ENABLE_TORCH_TRT", "").strip().lower() in {"1", "true", "yes", "on"}
+        if self._enable_trt and self._model is not None:
+            self._maybe_compile_submodules()
+
+    def _maybe_compile_submodules(self) -> None:
+        """
+        Best-effort TensorRT compilation of internal nn.Modules.
+
+        Any failure falls back to eager PyTorch without breaking initialization.
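+
+        Opt-in via environment variable (a minimal sketch; the flag is read
+        once in ``__init__``)::
+
+            import os
+
+            os.environ["RETRIEVER_ENABLE_TORCH_TRT"] = "1"
+            model = NemotronOCRV2()  # best-effort: stays eager if torch_tensorrt is unavailable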
+ """ + try: + import torch_tensorrt # type: ignore + except Exception: + return + + if self._model is None: + return + + detector = getattr(self._model, "detector", None) + if not isinstance(detector, torch.nn.Module): + return + + try: + trt_input = torch_tensorrt.Input((1, 3, 1024, 1024), dtype=torch.float16) + except TypeError: + trt_input = torch_tensorrt.Input(shape=(1, 3, 1024, 1024), dtype=torch.float16) + + compile_kwargs: Dict[str, Any] = { + "inputs": [trt_input], + "enabled_precisions": {torch.float16}, + } + if hasattr(torch_tensorrt, "compile"): + for k in ("torch_executed_ops", "torch_executed_modules"): + if k == "torch_executed_ops": + compile_kwargs[k] = {"torchvision::nms"} + elif k == "torch_executed_modules": + compile_kwargs[k] = set() + try: + self._model.detector = torch_tensorrt.compile(detector, **compile_kwargs) + except Exception: + return + + def preprocess(self, tensor: torch.Tensor) -> torch.Tensor: + """Preprocess the input tensor.""" + return tensor + + @staticmethod + def _tensor_to_png_b64(img: torch.Tensor) -> str: + """ + Convert a CHW/BCHW tensor into a base64-encoded PNG. + + Accepts: + - CHW (3,H,W) or (1,H,W) + Returns: + - base64 string (no data: prefix) + """ + if not isinstance(img, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(img)}") + if img.ndim != 3: + raise ValueError(f"Expected CHW tensor, got shape {tuple(img.shape)}") + + x = img.detach() + if x.device.type != "cpu": + x = x.cpu() + + if x.dtype.is_floating_point: + maxv = float(x.max().item()) if x.numel() else 1.0 + if maxv <= 1.5: + x = x * 255.0 + x = x.clamp(0, 255).to(dtype=torch.uint8) + else: + x = x.clamp(0, 255).to(dtype=torch.uint8) + + c, h, w = int(x.shape[0]), int(x.shape[1]), int(x.shape[2]) # noqa: F841 + if c == 1: + arr = x.squeeze(0).numpy() + pil = Image.fromarray(arr, mode="L").convert("RGB") + elif c == 3: + arr = x.permute(1, 2, 0).contiguous().numpy() + pil = Image.fromarray(arr, mode="RGB") + else: + raise ValueError(f"Expected 1 or 3 channels, got {c}") + + buf = io.BytesIO() + pil.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("utf-8") + + @staticmethod + def _extract_text(obj: Any) -> str: + if obj is None: + return "" + if isinstance(obj, str): + return obj.strip() + if isinstance(obj, dict): + for k in ("text", "output_text", "generated_text", "ocr_text"): + v = obj.get(k) + if isinstance(v, str) and v.strip(): + return v.strip() + if "words" in obj and isinstance(obj["words"], list): + parts: List[str] = [] + for w in obj["words"]: + if isinstance(w, dict) and isinstance(w.get("text"), str): + parts.append(w["text"]) + if parts: + return " ".join(parts).strip() + return str(obj).strip() + + def invoke( + self, + input_data: Union[torch.Tensor, str, bytes, np.ndarray, io.BytesIO], + merge_level: str = "paragraph", + ) -> Any: + """ + Invoke OCR locally. 
+
+        Supports:
+        - file path (str) **only if it exists**
+        - base64 (str/bytes); a str is treated as base64 unless it is an existing file path
+        - NumPy array (HWC)
+        - io.BytesIO
+        - torch.Tensor (CHW/BCHW): converted to a base64 PNG internally for compatibility
+        """
+        if self._model is None:
+            raise RuntimeError("Local OCR model was not initialized.")
+
+        if isinstance(input_data, torch.Tensor):
+            if input_data.ndim == 4:
+                out: List[Any] = []
+                for i in range(int(input_data.shape[0])):
+                    b64 = self._tensor_to_png_b64(input_data[i])
+                    out.extend(self._model(b64.encode("utf-8"), merge_level=merge_level))
+                return out
+            if input_data.ndim == 3:
+                b64 = self._tensor_to_png_b64(input_data)
+                return self._model(b64.encode("utf-8"), merge_level=merge_level)
+            raise ValueError(f"Unsupported torch tensor shape for OCR: {tuple(input_data.shape)}")
+
+        if isinstance(input_data, str):
+            return self._model(input_data.encode("utf-8"), merge_level=merge_level)
+
+        return self._model(input_data, merge_level=merge_level)
+
+    @property
+    def model_name(self) -> str:
+        """Human-readable model name."""
+        return "Nemotron OCR v2"
+
+    @property
+    def model_type(self) -> str:
+        """Model category/type."""
+        return "ocr"
+
+    @property
+    def model_runmode(self) -> RunMode:
+        """Execution mode: local, NIM, or build-endpoint."""
+        return "local"
+
+    @property
+    def input(self) -> Any:
+        return {
+            "type": "image",
+            "format": "RGB",
+            "supported_formats": ["PNG", "JPEG"],
+            "data_types": ["float32", "uint8"],
+            "dimensions": "variable (H x W)",
+            "batch_support": True,
+            "value_range": {"float32": "[0, 1]", "uint8": "[0, 255] (auto-converted)"},
+            "aggregation_levels": ["word", "sentence", "paragraph"],
+            "description": "Document or scene image in RGB format with automatic multi-scale resizing",
+        }
+
+    @property
+    def output(self) -> Any:
+        return {
+            "type": "ocr_results",
+            "format": "structured",
+            "structure": {
+                "boxes": "List[List[List[float]]] - quadrilateral bounding box coordinates [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]",  # noqa: E501
+                "texts": "List[str] - recognized text strings",
+                "confidences": "List[float] - confidence scores per detection",
+            },
+            "properties": {
+                "reading_order": True,
+                "layout_analysis": True,
+                "multi_line_support": True,
+                "multi_block_support": True,
+            },
+            "description": "Structured OCR results with bounding boxes, recognized text, and confidence scores",
+        }
+
+    @property
+    def input_batch_size(self) -> int:
+        """Maximum or default input batch size."""
+        return 8
diff --git a/nemo_retriever/src/nemo_retriever/ocr/ocr.py b/nemo_retriever/src/nemo_retriever/ocr/ocr.py
index db814eaf7..782c334ac 100644
--- a/nemo_retriever/src/nemo_retriever/ocr/ocr.py
+++ b/nemo_retriever/src/nemo_retriever/ocr/ocr.py
@@ -874,6 +874,145 @@ def postprocess(self, data: Any, **kwargs: Any) -> Any:
         return data
 
 
+# ---------------------------------------------------------------------------
+# OCR v2 Actors
+# ---------------------------------------------------------------------------
+
+
+class OCRV2Actor(AbstractOperator, GPUOperator):
+    """
+    Ray-friendly callable that initializes Nemotron OCR v2 once per actor.
+
+    Identical interface to :class:`OCRActor`, but loads the v2 model
+    (multilingual, higher throughput). The v2 model supports English,
+    Chinese (Simplified & Traditional), Japanese, Korean, and Russian.
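+
+    Passing ``ocr_invoke_url`` (or ``invoke_url``) at construction switches the
+    actor to remote inference and skips loading the local model. A minimal
+    sketch (endpoint borrowed from ``OCRV2CPUActor.DEFAULT_INVOKE_URL``)::
+
+        actor = OCRV2Actor(
+            invoke_url="https://ai.api.nvidia.com/v1/cv/nvidia/nemotron-ocr-v2",
+        )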
+
+    Usage with Ray Data::
+
+        ds = ds.map_batches(
+            OCRV2Actor,
+            batch_size=16, batch_format="pandas", num_cpus=4, num_gpus=1,
+            compute=ray.data.ActorPoolStrategy(size=8),
+            fn_constructor_kwargs={
+                "extract_tables": True,
+                "extract_charts": True,
+                "extract_infographics": False,
+            },
+        )
+    """
+
+    def __init__(self, **ocr_kwargs: Any) -> None:
+        super().__init__(**ocr_kwargs)
+        import warnings
+
+        if Image is not None:
+            warnings.filterwarnings("ignore", category=Image.DecompressionBombWarning)
+
+        self.ocr_kwargs = dict(ocr_kwargs)
+        invoke_url = str(self.ocr_kwargs.get("ocr_invoke_url") or self.ocr_kwargs.get("invoke_url") or "").strip()
+        if invoke_url and "invoke_url" not in self.ocr_kwargs:
+            self.ocr_kwargs["invoke_url"] = invoke_url
+
+        self.ocr_kwargs["extract_text"] = bool(self.ocr_kwargs.get("extract_text", False))
+        self.ocr_kwargs["extract_tables"] = bool(self.ocr_kwargs.get("extract_tables", False))
+        self.ocr_kwargs["extract_charts"] = bool(self.ocr_kwargs.get("extract_charts", False))
+        self.ocr_kwargs["extract_infographics"] = bool(self.ocr_kwargs.get("extract_infographics", False))
+        self.ocr_kwargs["use_graphic_elements"] = bool(self.ocr_kwargs.get("use_graphic_elements", False))
+        self.ocr_kwargs["request_timeout_s"] = float(self.ocr_kwargs.get("request_timeout_s", 120.0))
+        self.ocr_kwargs["inference_batch_size"] = int(self.ocr_kwargs.get("inference_batch_size", 8))
+
+        self._remote_retry = RemoteRetryParams(
+            remote_max_pool_workers=int(self.ocr_kwargs.get("remote_max_pool_workers", 16)),
+            remote_max_retries=int(self.ocr_kwargs.get("remote_max_retries", 10)),
+            remote_max_429_retries=int(self.ocr_kwargs.get("remote_max_429_retries", 5)),
+        )
+        if invoke_url:
+            self._model = None
+        else:
+            from nemo_retriever.model.local import NemotronOCRV2
+
+            self._model = NemotronOCRV2()
+
+    def preprocess(self, data: Any, **kwargs: Any) -> Any:
+        return data
+
+    def process(self, data: Any, **kwargs: Any) -> Any:
+        return ocr_page_elements(
+            data,
+            model=self._model,
+            remote_retry=self._remote_retry,
+            **self.ocr_kwargs,
+            **kwargs,
+        )
+
+    def postprocess(self, data: Any, **kwargs: Any) -> Any:
+        return data
+
+    def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
+        try:
+            return self.run(batch_df, **override_kwargs)
+        except BaseException as e:
+            if isinstance(batch_df, pd.DataFrame):
+                out = batch_df.copy()
+                payload = _error_payload(stage="actor_call", exc=e)
+                n = len(out.index)
+                out["table"] = [[] for _ in range(n)]
+                out["chart"] = [[] for _ in range(n)]
+                out["infographic"] = [[] for _ in range(n)]
+                out["ocr_v1"] = [payload for _ in range(n)]
+                return out
+            return [{"ocr_v1": _error_payload(stage="actor_call", exc=e)}]
+
+
+class OCRV2CPUActor(AbstractOperator, CPUOperator):
+    """CPU-only variant of :class:`OCRV2Actor`.
+
+    Defaults to the build.nvidia.com endpoint for ``nemotron-ocr-v2``.
+    No local GPU model is loaded.
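+
+    A minimal sketch (kwargs mirror :class:`OCRV2Actor`; requests go to
+    ``DEFAULT_INVOKE_URL`` unless ``invoke_url`` is overridden)::
+
+        actor = OCRV2CPUActor(extract_tables=True, extract_charts=True)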
+ """ + + DEFAULT_INVOKE_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemotron-ocr-v2" + + def __init__(self, **ocr_kwargs: Any) -> None: + super().__init__(**ocr_kwargs) + self.ocr_kwargs = dict(ocr_kwargs) + invoke_url = str( + self.ocr_kwargs.get("ocr_invoke_url") or self.ocr_kwargs.get("invoke_url") or self.DEFAULT_INVOKE_URL + ).strip() + if "invoke_url" not in self.ocr_kwargs: + self.ocr_kwargs["invoke_url"] = invoke_url + + self.ocr_kwargs["extract_text"] = bool(self.ocr_kwargs.get("extract_text", False)) + self.ocr_kwargs["extract_tables"] = bool(self.ocr_kwargs.get("extract_tables", False)) + self.ocr_kwargs["extract_charts"] = bool(self.ocr_kwargs.get("extract_charts", False)) + self.ocr_kwargs["extract_infographics"] = bool(self.ocr_kwargs.get("extract_infographics", False)) + self.ocr_kwargs["use_graphic_elements"] = bool(self.ocr_kwargs.get("use_graphic_elements", False)) + self.ocr_kwargs["request_timeout_s"] = float(self.ocr_kwargs.get("request_timeout_s", 120.0)) + self.ocr_kwargs["inference_batch_size"] = int(self.ocr_kwargs.get("inference_batch_size", 8)) + + self._remote_retry = RemoteRetryParams( + remote_max_pool_workers=int(self.ocr_kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(self.ocr_kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(self.ocr_kwargs.get("remote_max_429_retries", 5)), + ) + self._model = None + + def preprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + def process(self, data: Any, **kwargs: Any) -> Any: + return ocr_page_elements( + data, + model=self._model, + remote_retry=self._remote_retry, + **self.ocr_kwargs, + **kwargs, + ) + + def postprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + # --------------------------------------------------------------------------- # Nemotron Parse v1.2 # ---------------------------------------------------------------------------