Skip to content

Commit

Permalink
Fix onnx prediction (#202)
Browse files Browse the repository at this point in the history
* fix image size

* fix model path

* fix test

* style

* fix size

* keep tuple

* mypy
  • Loading branch information
MateoLostanlen authored May 29, 2024
1 parent 45b19e0 commit c18de6b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 65 deletions.
48 changes: 1 addition & 47 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import cv2 # type: ignore[import-untyped]
import numpy as np
from tqdm import tqdm # type: ignore[import-untyped]

__all__ = ["letterbox", "nms", "xywh2xyxy", "DownloadProgressBar"]
__all__ = ["nms", "xywh2xyxy", "DownloadProgressBar"]


def xywh2xyxy(x: np.ndarray):
Expand All @@ -20,51 +19,6 @@ def xywh2xyxy(x: np.ndarray):
return y


def letterbox(
    im: np.ndarray,
    new_shape: tuple = (640, 640),
    color: tuple = (0, 0, 0),
    auto: bool = False,
    stride: int = 32,
):
    """Resize an image to fit a target size, then pad it with a constant color.

    The image is scaled by a single ratio (the smaller of the height and width
    ratios), so its aspect ratio is kept. The remaining space is filled with
    ``color``, split between the two opposite sides.

    Args:
        im (np.ndarray): Input image; it is copied into a new array first.
        new_shape (tuple, optional): Target (height, width). A single int is
            also accepted and used for both. Defaults to (640, 640).
        color (tuple, optional): Fill value for the padded border.
            Defaults to (0, 0, 0).
        auto (bool, optional): When True, only pad by the remainder of the
            padding modulo ``stride``. Defaults to False.
        stride (int, optional): Modulus used when ``auto`` is True.
            Defaults to 32.

    Returns:
        tuple: ``(padded_image, (left, top))`` where ``padded_image`` is a
        3-channel ``uint8`` array and ``(left, top)`` is the padding added on
        the left and top sides, in pixels.
    """
    image = np.array(im)
    dims = image.shape[:2]  # (height, width) of the input
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # One scale factor for both axes: the tighter of the two ratios
    scale = min(new_shape[0] / dims[0], new_shape[1] / dims[1])
    src_h, src_w = dims[0], dims[1]

    # Size of the image after scaling, before any padding is added
    resized_w = int(round(src_w * scale))
    resized_h = int(round(src_h * scale))

    # Total padding required along each axis
    pad_w = new_shape[1] - resized_w
    pad_h = new_shape[0] - resized_h
    if auto:
        pad_w = np.mod(pad_w, stride)
        pad_h = np.mod(pad_h, stride)

    # Half of the padding goes on each side
    half_w = pad_w / 2
    half_h = pad_h / 2

    # Only call the resize when the size actually changes
    if (src_w, src_h) != (resized_w, resized_h):
        image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)

    # The -0.1 / +0.1 offsets send the extra pixel of an odd padding to the
    # bottom / right side
    top = int(round(half_h - 0.1))
    bottom = int(round(half_h + 0.1))
    left = int(round(half_w - 0.1))
    right = int(round(half_w + 0.1))

    # Build the colored canvas and paste the image into it
    height, width = image.shape[:2]
    canvas = np.zeros((height + top + bottom, width + left + right, 3)) + color
    canvas[top : top + height, left : left + width, :] = image

    return canvas.astype("uint8"), (left, top)


def box_iou(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7):
"""
Calculate intersection-over-union (IoU) of boxes.
Expand Down
31 changes: 17 additions & 14 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from typing import Optional, Tuple
from urllib.request import urlretrieve

import cv2 # type: ignore[import-untyped]
import numpy as np
import onnxruntime
from PIL import Image

from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy
from .utils import DownloadProgressBar, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx"
MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/model.onnx"


class Classifier:
Expand All @@ -29,7 +30,7 @@ class Classifier:
model_path: model path
"""

def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (1024, 1024)) -> None:
def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size: int = 1024) -> None:
if model_path is None:
model_path = "data/model.onnx"

Expand All @@ -41,9 +42,9 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl
print("Model downloaded!")

self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size
self.base_img_size = base_img_size

def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int, int]]:
def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference
Args:
Expand All @@ -55,15 +56,20 @@ def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int,
- Padding information as a tuple of integers (pad_height, pad_width).
"""

np_img, pad = letterbox(np.array(pil_img), self.img_size) # Applies letterbox resize with padding
np_img = cv2.resize(np.array(pil_img), new_img_size, interpolation=cv2.INTER_LINEAR)
np_img = np.expand_dims(np_img.astype("float"), axis=0) # Add batch dimension
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # Convert from BHWC to BCHW format
np_img = np_img.astype("float32") / 255 # Normalize to [0, 1]

return np_img, pad
return np_img

def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray:
np_img, pad = self.preprocess_image(pil_img)

w, h = pil_img.size
ratio = self.base_img_size / max(w, h)
new_img_size = [int(ratio * w), int(ratio * h)]
new_img_size = [x - x % 32 for x in new_img_size] # size need to be a multiple of 32 to fit the model
np_img = self.preprocess_image(pil_img, new_img_size)

# ONNX inference
y = self.ort_session.run(["output0"], {"images": np_img})[0][0]
Expand All @@ -78,12 +84,9 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] =

# Normalize preds
if len(y) > 0:
# Remove padding
left_pad, top_pad = pad
y[:, :4:2] -= left_pad
y[:, 1:4:2] -= top_pad
y[:, :4:2] /= self.img_size[1] - 2 * left_pad
y[:, 1:4:2] /= self.img_size[0] - 2 * top_pad
# Normalize Output
y[:, :4:2] /= new_img_size[0]
y[:, 1:4:2] /= new_img_size[1]
else:
y = np.zeros((0, 5)) # normalize output

Expand Down
7 changes: 3 additions & 4 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ def test_classifier(mock_wildfire_image):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out, pad = model.preprocess_image(mock_wildfire_image)
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 1024, 1024)
assert isinstance(pad, tuple)
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 640))
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

Expand Down

0 comments on commit c18de6b

Please sign in to comment.