@@ -5,7 +5,7 @@
 import logging
 import os
 from collections.abc import Iterable
-from typing import Set, Union
+from typing import List, Set, Union
 
 import numpy as np
 import torch
@@ -131,10 +131,18 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
             page_img = orig_img.convert("RGB")
         elif isinstance(orig_img, np.ndarray):
             page_img = Image.fromarray(orig_img).convert("RGB")
+        elif isinstance(orig_img, list):
+            if isinstance(orig_img[0], Image.Image):
+                page_img = [img.convert("RGB") for img in orig_img]
+            elif isinstance(orig_img[0], np.ndarray):
+                page_img = [Image.fromarray(img).convert("RGB") for img in orig_img]
+            else:
+                raise TypeError("Not supported input image format")
         else:
             raise TypeError("Not supported input image format")
 
         resize = {"height": self._image_size, "width": self._image_size}
+
         inputs = self._image_processor(
             images=page_img,
             return_tensors="pt",
@@ -175,3 +183,90 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
                 "label": label_str,
                 "confidence": score,
             }
+
+
+    @torch.inference_mode()
+    def predict_batch(self, orig_img: List[Union[Image.Image, np.ndarray]]) -> Iterable[Iterable[dict]]:
+        """
+        Predict bounding boxes for a batch of page images.
+        The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
+        [left, top, right, bottom]
+
+        Parameters
+        ----------
+        orig_img: List of images to be predicted, each as a PIL Image object or numpy array.
+
+        Yields
+        ------
+        Per page, an iterable of bounding boxes, each a dict with the keys: "label", "confidence", "l", "t", "r", "b"
+
+        Raises
+        ------
+        TypeError when the input image format is not supported
+        """
+        # Convert image format
+        if isinstance(orig_img[0], Image.Image):
+            page_img = [img.convert("RGB") for img in orig_img]
+        elif isinstance(orig_img[0], np.ndarray):
+            page_img = [Image.fromarray(img).convert("RGB") for img in orig_img]
+        else:
+            raise TypeError("Not supported input image format")
+
+        resize = {"height": self._image_size, "width": self._image_size}
+        inputs = self._image_processor(
+            images=page_img,
+            return_tensors="pt",
+            size=resize,
+        ).to(self._device)
+
+        target_sizes = torch.tensor([img.size[::-1] for img in page_img])  # PIL size is (w, h); post-processing expects (h, w)
+
+        outputs = self._model(**inputs)
+
+        results = self._image_processor.post_process_object_detection(
+            outputs,
+            target_sizes=target_sizes,
+            threshold=self._threshold,
+        )
+
+        for batch_item_idx, result in enumerate(results):
+            w, h = page_img[batch_item_idx].size
+            yield self.postprocess_result(result, w, h)
+
+    def postprocess_result(self, result: dict, w: int, h: int) -> Iterable[dict]:
+        """
+        Postprocess the object-detection result for a single page.
+
+        Parameters
+        ----------
+        result: The raw object-detection result for one page.
+        w: The width of the image.
+        h: The height of the image.
+
+        Yields
+        ------
+        Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
+        """
+        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
+            score = float(score.item())
+
+            label_id = int(label_id.item()) + 1  # Advance the label_id by one to align with the classes map
+            label_str = self._classes_map[label_id]
+
+            # Filter out blacklisted classes
+            if label_str in self._black_classes:
+                continue
+
+            bbox_float = [float(b.item()) for b in box]  # [l, t, r, b], clipped below to the image boundaries
+            l = min(w, max(0, bbox_float[0]))
+            t = min(h, max(0, bbox_float[1]))
+            r = min(w, max(0, bbox_float[2]))
+            b = min(h, max(0, bbox_float[3]))
+            yield {
+                "l": l,
+                "t": t,
+                "r": r,
+                "b": b,
+                "label": label_str,
+                "confidence": score,
+            }
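
A minimal usage sketch of the new batch API follows, assuming the enclosing class is instantiated as LayoutPredictor(...); the class name, constructor arguments, and file names below are illustrative assumptions, not part of this diff.

# Hypothetical usage of predict_batch. LayoutPredictor and its constructor
# arguments are assumed for illustration, not taken from this diff.
from PIL import Image

predictor = LayoutPredictor(artifact_path="./model_artifacts", device="cpu")

pages = [Image.open("page_1.png"), Image.open("page_2.png")]

# predict_batch yields one iterable of bounding-box dicts per input page.
for page_idx, boxes in enumerate(predictor.predict_batch(pages)):
    for box in boxes:
        print(page_idx, box["label"], box["confidence"], box["l"], box["t"], box["r"], box["b"])

Note that predict_batch, unlike predict, infers the input format from the first list element only, so all elements are assumed to share that type; a mixed list will fail during conversion.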