Skip to content

Commit aa2e39b

Browse files
committed
add batch prediction for layout model
Signed-off-by: SteffiVanHees <[email protected]>
1 parent f438c60 commit aa2e39b

File tree

1 file changed

+96
-1
lines changed

1 file changed

+96
-1
lines changed

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import os
77
from collections.abc import Iterable
8-
from typing import Set, Union
8+
from typing import Set, Union, List
99

1010
import numpy as np
1111
import torch
@@ -131,10 +131,18 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
131131
page_img = orig_img.convert("RGB")
132132
elif isinstance(orig_img, np.ndarray):
133133
page_img = Image.fromarray(orig_img).convert("RGB")
134+
elif isinstance(orig_img, List):
135+
if isinstance(orig_img[0], Image.Image):
136+
page_img = [img.convert("RGB") for img in orig_img]
137+
elif isinstance(orig_img[0], np.ndarray):
138+
page_img = [Image.fromarray(img).convert("RGB") for img in orig_img]
139+
else:
140+
raise TypeError("Not supported input image format")
134141
else:
135142
raise TypeError("Not supported input image format")
136143

137144
resize = {"height": self._image_size, "width": self._image_size}
145+
138146
inputs = self._image_processor(
139147
images=page_img,
140148
return_tensors="pt",
@@ -175,3 +183,90 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
175183
"label": label_str,
176184
"confidence": score,
177185
}
186+
187+
188+
@torch.inference_mode()
def predict_batch(
    self, orig_img: List[Union[Image.Image, np.ndarray]]
) -> Iterable[Iterable[dict]]:
    """
    Predict bounding boxes for a batch of page images.
    The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
    [left, top, right, bottom]

    Parameters
    ----------
    orig_img: List of images to be predicted, each one a PIL Image object or a numpy array.
        An empty batch yields nothing.

    Yields
    ------
    One iterable per page, in input order. Each iterable produces the bounding boxes of
    that page as dicts with the keys: "label", "confidence", "l", "t", "r", "b"

    Raises
    ------
    TypeError when any input image is not supported. Because this method is a generator,
    the error surfaces when iteration starts, not when the method is called.
    """
    # Nothing to predict: avoid indexing into an empty batch
    if len(orig_img) == 0:
        return

    # Convert image format. Every element is validated, not only the first one,
    # so that a mixed or invalid batch fails with the documented TypeError.
    page_img = []
    for img in orig_img:
        if isinstance(img, Image.Image):
            page_img.append(img.convert("RGB"))
        elif isinstance(img, np.ndarray):
            page_img.append(Image.fromarray(img).convert("RGB"))
        else:
            raise TypeError("Not supported input image format")

    resize = {"height": self._image_size, "width": self._image_size}
    inputs = self._image_processor(
        images=page_img,
        return_tensors="pt",
        size=resize,
    ).to(self._device)

    # PIL reports size as (width, height); reverse it to (height, width)
    target_sizes = torch.tensor([img.size[::-1] for img in page_img])

    outputs = self._model(**inputs)

    results = self._image_processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=self._threshold,
    )

    # Results are paired with the pages positionally, as in the original indexing
    for img, result in zip(page_img, results):
        w, h = img.size
        yield self.postprocess_result(result, w, h)
235+
236+
def postprocess_result(self, result: dict, w: int, h: int) -> Iterable[dict]:
    """
    Postprocess the result of the layout prediction for a single page.

    Parameters
    ----------
    result: The result of the layout prediction. It is read through the keys
        "scores", "labels" and "boxes", which are iterated in lockstep.
    w: The width of the image.
    h: The height of the image.

    Yields
    ------
    Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b".
    Coordinates are floats clamped to [0, w] horizontally and [0, h] vertically.
    Predictions whose label is in the blacklisted classes are skipped.
    """
    # Use float bounds so a clamped coordinate has the same type (float) as an
    # unclamped one; with int bounds a clipped value would come back as an int.
    max_x = float(w)
    max_y = float(h)

    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        # The class map is keyed one above the raw label id
        label_id = int(label_id.item()) + 1
        label_str = self._classes_map[label_id]

        # Filter out blacklisted classes
        if label_str in self._black_classes:
            continue

        coords = [float(coord.item()) for coord in box]
        yield {
            "l": min(max_x, max(0.0, coords[0])),
            "t": min(max_y, max(0.0, coords[1])),
            "r": min(max_x, max(0.0, coords[2])),
            "b": min(max_y, max(0.0, coords[3])),
            "label": label_str,
            "confidence": float(score.item()),
        }

0 commit comments

Comments
 (0)