@@ -18,6 +18,7 @@ class DetectTargetUltralyticsConfig:
18
18
"""
19
19
Configuration for DetectTargetUltralytics.
20
20
"""
21
+ CPU_DEVICE = "cpu"
21
22
22
23
def __init__ (
23
24
self ,
@@ -57,7 +58,10 @@ def __init__(
57
58
save_name: filename prefix for logging detections and annotated images.
58
59
"""
59
60
self .__device = config .device
60
- self .__enable_half_precision = self .__device != "cpu"
61
+ if self .__device != DetectTargetUltralyticsConfig .CPU_DEVICE and not torch .cuda .is_available ():
62
+ self .__local_logger .warning ("CUDA not available. Falling back to CPU." )
63
+ self .__device = DetectTargetUltralyticsConfig .CPU_DEVICE
64
+ self .__enable_half_precision = self .__device != DetectTargetUltralyticsConfig .CPU_DEVICE
61
65
self .__model = ultralytics .YOLO (config .model_path )
62
66
if config .override_full :
63
67
self .__enable_half_precision = False
@@ -68,9 +72,7 @@ def __init__(
68
72
if save_name != "" :
69
73
self .__filename_prefix = save_name + "_" + str (int (time .time ())) + "_"
70
74
71
- if self .__device != "cpu" and not torch .cuda .is_available ():
72
- self .__local_logger .warning ("CUDA not available. Falling back to CPU." )
73
- self .__device = "cpu"
75
+
74
76
75
77
def run (
76
78
self , data : image_and_time .ImageAndTime
@@ -132,7 +134,7 @@ def run(
132
134
filename = self .__filename_prefix + str (self .__counter )
133
135
134
136
# Annotated image
135
- cv2 .imwrite (filename + ".png" , image_annotated ) # type: ignore
137
+ cv2 .imwrite (filename + ".png" , image_annotated )
136
138
137
139
self .__counter += 1
138
140
0 commit comments