Skip to content

Commit 3fff35e

Browse files
committed
update
1 parent fc040a1 commit 3fff35e

6 files changed

+123
-73
lines changed

03_onnx_cpu_inference.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
import numpy as np
55
import onnxruntime as ort
6+
import torch
67
from PIL import Image
78

9+
from imagenet_classes import IMAGENET2012_CLASSES
10+
811
img = Image.open(
912
urlopen(
1013
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
@@ -17,8 +20,8 @@ def transforms_numpy(image: Image.Image):
1720
image = image.resize((448, 448), Image.BICUBIC)
1821
img_numpy = np.array(image).astype(np.float32) / 255.0
1922
img_numpy = img_numpy.transpose(2, 0, 1)
20-
mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
21-
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
23+
mean = np.array([0.4815, 0.4578, 0.4082]).reshape(-1, 1, 1)
24+
std = np.array([0.2686, 0.2613, 0.2758]).reshape(-1, 1, 1)
2225
img_numpy = (img_numpy - mean) / std
2326
img_numpy = np.expand_dims(img_numpy, axis=0)
2427
img_numpy = img_numpy.astype(np.float32)
@@ -36,6 +39,17 @@ def transforms_numpy(image: Image.Image):
3639
# Run inference
3740
output = session.run([output_name], {input_name: transforms_numpy(img)})[0]
3841

42+
# Check the output
43+
output = torch.from_numpy(output)
44+
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
45+
46+
im_classes = list(IMAGENET2012_CLASSES.values())
47+
class_names = [im_classes[i] for i in top5_class_indices[0]]
48+
49+
# Print class names and probabilities
50+
for name, prob in zip(class_names, top5_probabilities[0]):
51+
print(f"{name}: {prob:.2f}%")
52+
3953
# Run benchmark
4054
num_images = 10
4155
start = time.perf_counter()

04_onnx_cuda_inference.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import cupy as cp
55
import numpy as np
66
import onnxruntime as ort
7+
import torch
78
from PIL import Image
89

10+
from imagenet_classes import IMAGENET2012_CLASSES
11+
912
img = Image.open(
1013
urlopen(
1114
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
@@ -18,8 +21,8 @@ def transforms_numpy(image: Image.Image):
1821
image = image.resize((448, 448), Image.BICUBIC)
1922
img_numpy = np.array(image).astype(np.float32) / 255.0
2023
img_numpy = img_numpy.transpose(2, 0, 1)
21-
mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
22-
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
24+
mean = np.array([0.4815, 0.4578, 0.4082]).reshape(-1, 1, 1)
25+
std = np.array([0.2686, 0.2613, 0.2758]).reshape(-1, 1, 1)
2326
img_numpy = (img_numpy - mean) / std
2427
img_numpy = np.expand_dims(img_numpy, axis=0)
2528
img_numpy = img_numpy.astype(np.float32)
@@ -36,8 +39,8 @@ def transforms_cupy(image: Image.Image):
3639
img_cupy = img_cupy.transpose(2, 0, 1)
3740

3841
# Apply mean and std normalization
39-
mean = cp.array([0.485, 0.456, 0.406], dtype=cp.float32).reshape(-1, 1, 1)
40-
std = cp.array([0.229, 0.224, 0.225], dtype=cp.float32).reshape(-1, 1, 1)
42+
mean = cp.array([0.4815, 0.4578, 0.4082], dtype=cp.float32).reshape(-1, 1, 1)
43+
std = cp.array([0.2686, 0.2613, 0.2758], dtype=cp.float32).reshape(-1, 1, 1)
4144
img_cupy = (img_cupy - mean) / std
4245

4346
# Add batch dimension
@@ -57,6 +60,18 @@ def transforms_cupy(image: Image.Image):
5760
# Run inference
5861
output = session.run([output_name], {input_name: transforms_numpy(img)})[0]
5962

63+
64+
# Check the output
65+
output = torch.from_numpy(output)
66+
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
67+
68+
im_classes = list(IMAGENET2012_CLASSES.values())
69+
class_names = [im_classes[i] for i in top5_class_indices[0]]
70+
71+
# Print class names and probabilities
72+
for name, prob in zip(class_names, top5_probabilities[0]):
73+
print(f"{name}: {prob:.2f}%")
74+
6075
# Run benchmark numpy
6176
num_images = 100
6277
start = time.perf_counter()

05_onnx_trt_inference.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def transforms_numpy(image: Image.Image):
1818
image = image.resize((448, 448), Image.BICUBIC)
1919
img_numpy = np.array(image).astype(np.float32) / 255.0
2020
img_numpy = img_numpy.transpose(2, 0, 1)
21-
mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
22-
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
21+
mean = np.array([0.4815, 0.4578, 0.4082]).reshape(-1, 1, 1)
22+
std = np.array([0.2686, 0.2613, 0.2758]).reshape(-1, 1, 1)
2323
img_numpy = (img_numpy - mean) / std
2424
img_numpy = np.expand_dims(img_numpy, axis=0)
2525
img_numpy = img_numpy.astype(np.float32)
@@ -36,8 +36,8 @@ def transforms_cupy(image: Image.Image):
3636
img_cupy = img_cupy.transpose(2, 0, 1)
3737

3838
# Apply mean and std normalization
39-
mean = cp.array([0.485, 0.456, 0.406], dtype=cp.float32).reshape(-1, 1, 1)
40-
std = cp.array([0.229, 0.224, 0.225], dtype=cp.float32).reshape(-1, 1, 1)
39+
mean = cp.array([0.4815, 0.4578, 0.4082], dtype=cp.float32).reshape(-1, 1, 1)
40+
std = cp.array([0.2686, 0.2613, 0.2758], dtype=cp.float32).reshape(-1, 1, 1)
4141
img_cupy = (img_cupy - mean) / std
4242

4343
# Add batch dimension

06_export_preprocessing_onnx.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ class Preprocess(nn.Module):
1010
def __init__(self, input_shape: List[int]):
1111
super(Preprocess, self).__init__()
1212
self.input_shape = tuple(input_shape)
13-
self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
14-
self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
13+
self.mean = torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1)
14+
self.std = torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1)
1515

1616
def forward(self, x: torch.Tensor):
1717
x = torch.nn.functional.interpolate(

08_inference_merged_model.py

-35
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,6 @@
1616
)
1717

1818

19-
def transforms_numpy(image: Image.Image):
20-
image = image.convert("RGB")
21-
image = image.resize((448, 448), Image.BICUBIC)
22-
img_numpy = np.array(image).astype(np.float32) / 255.0
23-
img_numpy = img_numpy.transpose(2, 0, 1)
24-
25-
mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
26-
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
27-
img_numpy = (img_numpy - mean) / std
28-
img_numpy = np.expand_dims(img_numpy, axis=0)
29-
img_numpy = img_numpy.astype(np.float32)
30-
31-
return img_numpy
32-
33-
34-
def transforms_cupy(image: Image.Image):
35-
# Convert image to RGB and resize
36-
image = image.convert("RGB")
37-
image = image.resize((448, 448), Image.BICUBIC)
38-
39-
# Convert to CuPy array and normalize
40-
img_cupy = cp.array(image, dtype=cp.float32) / 255.0
41-
img_cupy = img_cupy.transpose(2, 0, 1)
42-
43-
# Apply mean and std normalization
44-
mean = cp.array([0.485, 0.456, 0.406], dtype=cp.float32).reshape(-1, 1, 1)
45-
std = cp.array([0.229, 0.224, 0.225], dtype=cp.float32).reshape(-1, 1, 1)
46-
img_cupy = (img_cupy - mean) / std
47-
48-
# Add batch dimension
49-
img_cupy = cp.expand_dims(img_cupy, axis=0)
50-
51-
return img_cupy
52-
53-
5419
def read_image(image: Image.Image):
5520
image = image.convert("RGB")
5621
img_numpy = np.array(image).astype(np.float32)

09_video_inference.py

+82-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import time
21
import argparse
2+
import time
33
from pathlib import Path
44

55
import cv2
@@ -10,14 +10,18 @@
1010

1111
from imagenet_classes import IMAGENET2012_CLASSES
1212

13+
1314
def parse_arguments():
1415
parser = argparse.ArgumentParser(description="Video inference with TensorRT")
1516
parser.add_argument("--output_video", type=str, help="Path to output video file")
1617
parser.add_argument("--input_video", type=str, help="Path to input video file")
1718
parser.add_argument("--webcam", action="store_true", help="Use webcam as input")
18-
parser.add_argument("--live", action="store_true", help="View video live during inference")
19+
parser.add_argument(
20+
"--live", action="store_true", help="View video live during inference"
21+
)
1922
return parser.parse_args()
2023

24+
2125
def get_ort_session(model_path):
2226
providers = [
2327
(
@@ -38,6 +42,7 @@ def get_ort_session(model_path):
3842
]
3943
return ort.InferenceSession(model_path, providers=providers)
4044

45+
4146
def preprocess_frame(frame):
4247
# Use cv2 for resizing instead of PIL for better performance
4348
resized = cv2.resize(frame, (448, 448), interpolation=cv2.INTER_LINEAR)
@@ -46,46 +51,93 @@ def preprocess_frame(frame):
4651
img_numpy = np.expand_dims(img_numpy, axis=0)
4752
return img_numpy
4853

54+
4955
def get_top_predictions(output, top_k=5):
5056
output = torch.from_numpy(output)
5157
probabilities, class_indices = torch.topk(output.softmax(dim=1) * 100, k=top_k)
5258
im_classes = list(IMAGENET2012_CLASSES.values())
5359
class_names = [im_classes[i] for i in class_indices[0]]
5460
return list(zip(class_names, probabilities[0].tolist()))
5561

62+
5663
def draw_predictions(frame, predictions, fps):
57-
# Draw FPS in the top right corner
58-
cv2.putText(frame, f"FPS: {fps:.2f}", (frame.shape[1] - 150, 30),
59-
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
60-
64+
# Draw FPS in the top right corner with dark blue background
65+
fps_text = f"FPS: {fps:.2f}"
66+
(text_width, text_height), _ = cv2.getTextSize(
67+
fps_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
68+
)
69+
text_offset_x = frame.shape[1] - text_width - 10
70+
text_offset_y = 30
71+
box_coords = (
72+
(text_offset_x - 5, text_offset_y + 5),
73+
(text_offset_x + text_width + 5, text_offset_y - text_height - 5),
74+
)
75+
cv2.rectangle(
76+
frame, box_coords[0], box_coords[1], (139, 0, 0), cv2.FILLED
77+
) # Dark blue background
78+
cv2.putText(
79+
frame,
80+
fps_text,
81+
(text_offset_x, text_offset_y),
82+
cv2.FONT_HERSHEY_SIMPLEX,
83+
0.7,
84+
(255, 255, 255), # White text
85+
2,
86+
)
87+
6188
# Draw predictions
6289
for i, (name, prob) in enumerate(predictions):
6390
text = f"{name}: {prob:.2f}%"
64-
cv2.putText(frame, text, (10, 30 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
65-
66-
# Draw model name at the bottom of the frame
91+
cv2.putText(
92+
frame,
93+
text,
94+
(10, 30 + i * 30),
95+
cv2.FONT_HERSHEY_SIMPLEX,
96+
0.7,
97+
(0, 255, 0),
98+
2,
99+
)
100+
101+
# Draw model name at the bottom of the frame with red background
67102
model_name = "Model: eva02_large_patch14_448"
68-
text_size = cv2.getTextSize(model_name, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
69-
text_x = (frame.shape[1] - text_size[0]) // 2
103+
(text_width, text_height), _ = cv2.getTextSize(
104+
model_name, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
105+
)
106+
text_x = (frame.shape[1] - text_width) // 2
70107
text_y = frame.shape[0] - 20
71-
cv2.putText(frame, model_name, (text_x, text_y),
72-
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
108+
box_coords = (
109+
(text_x - 5, text_y + 5),
110+
(text_x + text_width + 5, text_y - text_height - 5),
111+
)
112+
cv2.rectangle(
113+
frame, box_coords[0], box_coords[1], (0, 0, 255), cv2.FILLED
114+
) # Red background
115+
cv2.putText(
116+
frame,
117+
model_name,
118+
(text_x, text_y),
119+
cv2.FONT_HERSHEY_SIMPLEX,
120+
0.7,
121+
(255, 255, 255), # White text
122+
2,
123+
)
73124

74125
return frame
75126

127+
76128
def process_video(input_path, output_path, session, live_view=False, use_webcam=False):
77129
if use_webcam:
78130
cap = cv2.VideoCapture(0)
79131
else:
80132
cap = cv2.VideoCapture(input_path)
81-
133+
82134
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
83135
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
84136
fps = int(cap.get(cv2.CAP_PROP_FPS))
85-
137+
86138
out = None
87139
if output_path:
88-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
140+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
89141
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
90142

91143
input_name = session.get_inputs()[0].name
@@ -101,29 +153,31 @@ def process_video(input_path, output_path, session, live_view=False, use_webcam=
101153
break
102154

103155
start_time = time.time()
104-
156+
105157
preprocessed = preprocess_frame(frame)
106158
output = session.run([output_name], {input_name: preprocessed})
107159
predictions = get_top_predictions(output[0])
108-
160+
109161
end_time = time.time()
110162
frame_time = end_time - start_time
111163
current_fps = 1 / frame_time
112-
164+
113165
frame_with_predictions = draw_predictions(frame, predictions, current_fps)
114-
166+
115167
if out:
116168
out.write(frame_with_predictions)
117-
169+
118170
if live_view:
119-
cv2.imshow('Inference', frame_with_predictions)
120-
if cv2.waitKey(1) & 0xFF == ord('q'):
171+
cv2.imshow("Inference", frame_with_predictions)
172+
if cv2.waitKey(1) & 0xFF == ord("q"):
121173
break
122174

123175
total_time += frame_time
124176
frame_count += 1
125177

126-
print(f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}")
178+
print(
179+
f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}"
180+
)
127181

128182
cap.release()
129183
if out:
@@ -134,10 +188,11 @@ def process_video(input_path, output_path, session, live_view=False, use_webcam=
134188
print(f"Average processing time per frame: {avg_time:.3f}s")
135189
print(f"Average FPS: {1/avg_time:.2f}")
136190

191+
137192
def main():
138193
args = parse_arguments()
139194
session = get_ort_session("merged_model_compose.onnx")
140-
195+
141196
if args.webcam:
142197
process_video(None, args.output_video, session, args.live, use_webcam=True)
143198
elif args.input_video:
@@ -146,5 +201,6 @@ def main():
146201
print("Error: Please specify either --input_video or --webcam")
147202
return
148203

204+
149205
if __name__ == "__main__":
150-
main()
206+
main()

0 commit comments

Comments
 (0)