diff --git a/modules/detect_target/detect_target_factory.py b/modules/detect_target/detect_target_factory.py index 376456cb..32350984 100644 --- a/modules/detect_target/detect_target_factory.py +++ b/modules/detect_target/detect_target_factory.py @@ -30,7 +30,10 @@ def create_detect_target( local_logger: logger.Logger, ) -> tuple[bool, base_detect_target.BaseDetectTarget | None]: """ - Construct detect target class at runtime. + Factory function to create a detection target object. + + Return: + Success, detect target object. """ match detect_target_option: case DetectTargetOption.ML_ULTRALYTICS: diff --git a/modules/detect_target/detect_target_ultralytics.py b/modules/detect_target/detect_target_ultralytics.py index d26e762c..0df1351b 100644 --- a/modules/detect_target/detect_target_ultralytics.py +++ b/modules/detect_target/detect_target_ultralytics.py @@ -5,6 +5,7 @@ import time import cv2 +import torch import ultralytics from . import base_detect_target @@ -18,6 +19,8 @@ class DetectTargetUltralyticsConfig: Configuration for DetectTargetUltralytics. """ + CPU_DEVICE = "cpu" + def __init__( self, device: "str | int", @@ -55,13 +58,19 @@ def __init__( show_annotations: Display annotated images. save_name: filename prefix for logging detections and annotated images. """ + self.__local_logger = local_logger self.__device = config.device - self.__enable_half_precision = not self.__device == "cpu" + if ( + self.__device != DetectTargetUltralyticsConfig.CPU_DEVICE + and not torch.cuda.is_available() + ): + self.__local_logger.warning("CUDA not available. Falling back to CPU.") + self.__device = DetectTargetUltralyticsConfig.CPU_DEVICE + self.__enable_half_precision = self.__device != DetectTargetUltralyticsConfig.CPU_DEVICE self.__model = ultralytics.YOLO(config.model_path) if config.override_full: self.__enable_half_precision = False self.__counter = 0 - self.__local_logger = local_logger self.__show_annotations = show_annotations self.__filename_prefix = "" if save_name != "": @@ -127,11 +136,16 @@ def run( filename = self.__filename_prefix + str(self.__counter) # Annotated image - cv2.imwrite(filename + ".png", image_annotated) # type: ignore + cv2.imwrite(filename + ".png", image_annotated) self.__counter += 1 if self.__show_annotations: - cv2.imshow("Annotated", image_annotated) # type: ignore + if image_annotated is not None: + # Display the annotated image in a named window + cv2.imshow("Annotated", image_annotated) + cv2.waitKey(1) # Short delay to process GUI events + else: + self.__local_logger.warning("Annotated image is invalid.") return True, detections