-import time
import argparse
+import time
from pathlib import Path

import cv2
@@ -10,14 +10,18 @@

from imagenet_classes import IMAGENET2012_CLASSES

+
def parse_arguments():
    parser = argparse.ArgumentParser(description="Video inference with TensorRT")
    parser.add_argument("--output_video", type=str, help="Path to output video file")
    parser.add_argument("--input_video", type=str, help="Path to input video file")
    parser.add_argument("--webcam", action="store_true", help="Use webcam as input")
-    parser.add_argument("--live", action="store_true", help="View video live during inference")
+    parser.add_argument(
+        "--live", action="store_true", help="View video live during inference"
+    )
    return parser.parse_args()

+
def get_ort_session(model_path):
    providers = [
        (
@@ -38,6 +42,7 @@ def get_ort_session(model_path):
    ]
    return ort.InferenceSession(model_path, providers=providers)

+
def preprocess_frame(frame):
    # Use cv2 for resizing instead of PIL for better performance
    resized = cv2.resize(frame, (448, 448), interpolation=cv2.INTER_LINEAR)
@@ -46,46 +51,93 @@ def preprocess_frame(frame):
    img_numpy = np.expand_dims(img_numpy, axis=0)
    return img_numpy

+
def get_top_predictions(output, top_k=5):
    output = torch.from_numpy(output)
    probabilities, class_indices = torch.topk(output.softmax(dim=1) * 100, k=top_k)
    im_classes = list(IMAGENET2012_CLASSES.values())
    class_names = [im_classes[i] for i in class_indices[0]]
    return list(zip(class_names, probabilities[0].tolist()))

+
def draw_predictions(frame, predictions, fps):
-    # Draw FPS in the top right corner
-    cv2.putText(frame, f"FPS: {fps:.2f}", (frame.shape[1] - 150, 30),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
-
+    # Draw FPS in the top right corner with dark blue background
+    fps_text = f"FPS: {fps:.2f}"
+    (text_width, text_height), _ = cv2.getTextSize(
+        fps_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
+    )
+    text_offset_x = frame.shape[1] - text_width - 10
+    text_offset_y = 30
+    box_coords = (
+        (text_offset_x - 5, text_offset_y + 5),
+        (text_offset_x + text_width + 5, text_offset_y - text_height - 5),
+    )
+    cv2.rectangle(
+        frame, box_coords[0], box_coords[1], (139, 0, 0), cv2.FILLED
+    )  # Dark blue background
+    cv2.putText(
+        frame,
+        fps_text,
+        (text_offset_x, text_offset_y),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        0.7,
+        (255, 255, 255),  # White text
+        2,
+    )
+
    # Draw predictions
    for i, (name, prob) in enumerate(predictions):
        text = f"{name}: {prob:.2f}%"
-        cv2.putText(frame, text, (10, 30 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
-
-    # Draw model name at the bottom of the frame
+        cv2.putText(
+            frame,
+            text,
+            (10, 30 + i * 30),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.7,
+            (0, 255, 0),
+            2,
+        )
+
+    # Draw model name at the bottom of the frame with red background
    model_name = "Model: eva02_large_patch14_448"
-    text_size = cv2.getTextSize(model_name, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
-    text_x = (frame.shape[1] - text_size[0]) // 2
+    (text_width, text_height), _ = cv2.getTextSize(
+        model_name, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
+    )
+    text_x = (frame.shape[1] - text_width) // 2
    text_y = frame.shape[0] - 20
-    cv2.putText(frame, model_name, (text_x, text_y),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+    box_coords = (
+        (text_x - 5, text_y + 5),
+        (text_x + text_width + 5, text_y - text_height - 5),
+    )
+    cv2.rectangle(
+        frame, box_coords[0], box_coords[1], (0, 0, 255), cv2.FILLED
+    )  # Red background
+    cv2.putText(
+        frame,
+        model_name,
+        (text_x, text_y),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        0.7,
+        (255, 255, 255),  # White text
+        2,
+    )

    return frame

+
def process_video(input_path, output_path, session, live_view=False, use_webcam=False):
    if use_webcam:
        cap = cv2.VideoCapture(0)
    else:
        cap = cv2.VideoCapture(input_path)
-
+
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
-
+
    out = None
    if output_path:
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    input_name = session.get_inputs()[0].name
@@ -101,29 +153,31 @@ def process_video(input_path, output_path, session, live_view=False, use_webcam=
            break

        start_time = time.time()
-
+
        preprocessed = preprocess_frame(frame)
        output = session.run([output_name], {input_name: preprocessed})
        predictions = get_top_predictions(output[0])
-
+
        end_time = time.time()
        frame_time = end_time - start_time
        current_fps = 1 / frame_time
-
+
        frame_with_predictions = draw_predictions(frame, predictions, current_fps)
-
+
        if out:
            out.write(frame_with_predictions)
-
+
        if live_view:
-            cv2.imshow('Inference', frame_with_predictions)
-            if cv2.waitKey(1) & 0xFF == ord('q'):
+            cv2.imshow("Inference", frame_with_predictions)
+            if cv2.waitKey(1) & 0xFF == ord("q"):
                break

        total_time += frame_time
        frame_count += 1

-        print(f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}")
+        print(
+            f"Processed frame {frame_count}, Time: {frame_time:.3f}s, FPS: {current_fps:.2f}"
+        )

    cap.release()
    if out:
@@ -134,10 +188,11 @@ def process_video(input_path, output_path, session, live_view=False, use_webcam=
    print(f"Average processing time per frame: {avg_time:.3f}s")
    print(f"Average FPS: {1 / avg_time:.2f}")

+
def main():
    args = parse_arguments()
    session = get_ort_session("merged_model_compose.onnx")
-
+
    if args.webcam:
        process_video(None, args.output_video, session, args.live, use_webcam=True)
    elif args.input_video:
@@ -146,5 +201,6 @@ def main():
        print("Error: Please specify either --input_video or --webcam")
        return

+
if __name__ == "__main__":
-    main()
+    main()
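
For a quick sanity check of the reformatted pipeline, the helper functions above can be exercised on a single image instead of a full video. The sketch below is illustrative only: it assumes the script is saved as video_inference.py (the filename is not shown in this diff), that merged_model_compose.onnx sits in the working directory as in main(), and that sample.jpg is any local test image.

# Minimal smoke test: push one frame through the same preprocess -> ONNX Runtime
# -> top-k pipeline that process_video() uses. Filenames below are assumptions.
import cv2

from video_inference import get_ort_session, get_top_predictions, preprocess_frame

session = get_ort_session("merged_model_compose.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

frame = cv2.imread("sample.jpg")  # any BGR image stands in for a video frame
batch = preprocess_frame(frame)  # 448x448 resize + batch dimension, as defined above
scores = session.run([output_name], {input_name: batch})[0]

for class_name, prob in get_top_predictions(scores):
    print(f"{class_name}: {prob:.2f}%")

The end-to-end equivalent, assuming the same filename, would be python video_inference.py --input_video input.mp4 --output_video output.mp4 --live, or --webcam in place of --input_video.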