From 52de3d55e2964f1c97bc16631518dd483ee4daad Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 18 Jan 2023 20:21:25 +0100 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20fix=20default=20Col?= =?UTF-8?q?orPalette?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/draw/color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/draw/color.py b/supervision/draw/color.py index 0e0c198da..3f44a3e6f 100644 --- a/supervision/draw/color.py +++ b/supervision/draw/color.py @@ -94,7 +94,7 @@ def blue(cls) -> Color: @dataclass class ColorPalette: - colors: List[Color] = field(default_factory=lambda: DEFAULT_COLOR_PALETTE) + colors: List[Color] = field(default_factory=lambda: ColorPalette.from_hex(DEFAULT_COLOR_PALETTE)) @classmethod def from_hex(cls, color_hex_list: List[str]): From dda91a7a80d4552c9aa1ec1734769510f5588940 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 18 Jan 2023 20:21:25 +0100 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20fix=20default=20Col?= =?UTF-8?q?orPalette?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/draw/color.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/supervision/draw/color.py b/supervision/draw/color.py index 0e0c198da..7f3c6d3b1 100644 --- a/supervision/draw/color.py +++ b/supervision/draw/color.py @@ -94,7 +94,9 @@ def blue(cls) -> Color: @dataclass class ColorPalette: - colors: List[Color] = field(default_factory=lambda: DEFAULT_COLOR_PALETTE) + colors: List[Color] = field( + default_factory=lambda: [Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE] + ) @classmethod def from_hex(cls, color_hex_list: List[str]): From bad6ff43a5f0c33030e10406101e15cee5f64dd1 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 19 Jan 2023 00:05:13 +0100 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20first=20iteration=20of?= =?UTF-8?q?=20Detections=20and=20BoxAnnotator=20classes=20added?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/draw/color.py | 4 +- supervision/{commons => geometry}/__init__.py | 0 .../{commons => geometry}/dataclasses.py | 0 .../commons => supervision/tools}/__init__.py | 0 supervision/tools/detections.py | 199 ++++++++++++++++++ test/geometry/__init__.py | 0 .../{commons => geometry}/test_dataclasses.py | 2 +- 7 files changed, 203 insertions(+), 2 deletions(-) rename supervision/{commons => geometry}/__init__.py (100%) rename supervision/{commons => geometry}/dataclasses.py (100%) rename {test/commons => supervision/tools}/__init__.py (100%) create mode 100644 supervision/tools/detections.py create mode 100644 test/geometry/__init__.py rename test/{commons => geometry}/test_dataclasses.py (96%) diff --git a/supervision/draw/color.py b/supervision/draw/color.py index 7f3c6d3b1..842dd2be1 100644 --- a/supervision/draw/color.py +++ b/supervision/draw/color.py @@ -95,7 +95,9 @@ def blue(cls) -> Color: @dataclass class ColorPalette: colors: List[Color] = field( - default_factory=lambda: [Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE] + default_factory=lambda: [ + Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE + ] ) @classmethod diff --git a/supervision/commons/__init__.py b/supervision/geometry/__init__.py similarity index 100% rename from supervision/commons/__init__.py rename to supervision/geometry/__init__.py diff --git a/supervision/commons/dataclasses.py b/supervision/geometry/dataclasses.py similarity index 100% rename from supervision/commons/dataclasses.py rename to supervision/geometry/dataclasses.py diff --git a/test/commons/__init__.py b/supervision/tools/__init__.py similarity index 100% rename from test/commons/__init__.py rename to supervision/tools/__init__.py diff --git a/supervision/tools/detections.py b/supervision/tools/detections.py new file mode 100644 index 000000000..e143150a1 --- /dev/null +++ b/supervision/tools/detections.py @@ -0,0 +1,199 @@ +from typing import List, Optional, Union + +import cv2 +import numpy as np + +from supervision.draw.color import Color, ColorPalette + + +class Detections: + def __init__( + self, + xyxy: np.ndarray, + confidence: np.ndarray, + class_id: np.ndarray, + tracker_id: Optional[np.ndarray] = None, + ): + """ + Data class containing information about the detections in a video frame. + + :param xyxy: np.ndarray : An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2] + :param confidence: np.ndarray : An array of shape (n,) containing the confidence scores of the detections. + :param class_id: np.ndarray : An array of shape (n,) containing the class ids of the detections. + :param tracker_id: Optional[np.ndarray] : An array of shape (n,) containing the tracker ids of the detections. + """ + self.xyxy: np.ndarray = xyxy + self.confidence: np.ndarray = confidence + self.class_id: np.ndarray = class_id + self.tracker_id: Optional[np.ndarray] = tracker_id + + n = len(self.xyxy) + validators = [ + (isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)), + (isinstance(self.confidence, np.ndarray) and self.confidence.shape == (n,)), + (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)), + self.tracker_id is None + or ( + isinstance(self.tracker_id, np.ndarray) + and self.tracker_id.shape == (n,) + ), + ] + if not all(validators): + raise ValueError( + "xyxy must be 2d np.ndarray with (n, 4) shape, " + "confidence must be 1d np.ndarray with (n,) shape, " + "class_id must be 1d np.ndarray with (n,) shape, " + "tracker_id must be None or 1d np.ndarray with (n,) shape" + ) + + def __len__(self): + """ + Returns the number of detections in the Detections object. + """ + return len(self.xyxy) + + def __iter__(self): + """ + Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection. + """ + for i in range(len(self.xyxy)): + yield ( + self.xyxy[i], + self.confidence[i], + self.class_id[i], + self.tracker_id[i] if self.tracker_id is not None else None, + ) + + @classmethod + def from_yolov5(cls, yolov5_output: np.ndarray): + """ + Creates a Detections instance from a YOLOv5 output tensor + + :param yolov5_output: np.ndarray : The output tensor from YOLOv5 + :return: Detections : A Detections instance representing the detections in the frame + + Example: + detections = Detections.from_yolov5(yolov5_output) + """ + xyxy = yolov5_output[:, :4] + confidence = yolov5_output[:, 4] + class_id = yolov5_output[:, 5].astype(int) + return cls(xyxy, confidence, class_id) + + def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]: + """ + Filter the detections by applying a mask + + :param mask: np.ndarray : A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections + :param inplace: bool : If True, the original data will be modified and self will be returned. + :return: Optional[np.ndarray] : A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise. + """ + if inplace: + self.xyxy = self.xyxy[mask] + self.confidence = self.confidence[mask] + self.class_id = self.class_id[mask] + self.tracker_id = ( + self.tracker_id[mask] if self.tracker_id is not None else None + ) + return self + else: + return Detections( + xyxy=self.xyxy[mask], + confidence=self.confidence[mask], + class_id=self.class_id[mask], + tracker_id=self.tracker_id[mask] + if self.tracker_id is not None + else None, + ) + + +class BoxAnnotator: + def __init__( + self, + color: Union[Color, ColorPalette], + thickness: int = 2, + text_color: Color = Color.black(), + text_scale: float = 0.5, + text_thickness: int = 1, + text_padding: int = 10, + ): + """ + A class for drawing bounding boxes on an image using detections provided. + + :param color: Union[Color, ColorPalette] : The color to draw the bounding box, can be a single color or a color palette + :param thickness: int : The thickness of the bounding box lines, default is 2 + :param text_color: Color : The color of the text on the bounding box, default is white + :param text_scale: float : The scale of the text on the bounding box, default is 0.5 + :param text_thickness: int : The thickness of the text on the bounding box, default is 1 + :param text_padding: int : The padding around the text on the bounding box, default is 5 + """ + self.color: Union[Color, ColorPalette] = color + self.thickness: int = thickness + self.text_color: Color = text_color + self.text_scale: float = text_scale + self.text_thickness: int = text_thickness + self.text_padding: int = text_padding + + def annotate( + self, + frame: np.ndarray, + detections: Detections, + labels: Optional[List[str]] = None, + ) -> np.ndarray: + """ + Draws bounding boxes on the frame using the detections provided. + + :param frame: np.ndarray : The image on which the bounding boxes will be drawn + :param detections: Detections : The detections for which the bounding boxes will be drawn + :param labels: Optional[List[str]] : An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label. + :return: np.ndarray : The image with the bounding boxes drawn on it + """ + font = cv2.FONT_HERSHEY_SIMPLEX + for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections): + color = ( + self.color.by_idx(class_id) + if isinstance(self.color, ColorPalette) + else self.color + ) + + x1, y1, x2, y2 = xyxy.astype(int) + cv2.rectangle(frame, (x1, y1), (x2, y2), color.as_bgr(), self.thickness) + + text = ( + f"{confidence:0.2f}" + if (labels is None or len(detections) != len(labels)) + else labels[i] + ) + + text_size = cv2.getTextSize( + text, font, self.text_scale, self.text_thickness + )[0] + text_width, text_height = text_size + + text_x = x1 + self.text_padding + text_y = y1 - self.text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * self.text_padding - text_height + + text_background_x2 = x1 + 2 * self.text_padding + text_width + text_background_y2 = y1 + + cv2.rectangle( + frame, + (text_background_x1, text_background_y1), + (text_background_x2, text_background_y2), + color.as_bgr(), + cv2.FILLED, + ) + cv2.putText( + frame, + text, + (text_x, text_y), + font, + self.text_scale, + self.text_color.as_rgb(), + self.text_thickness, + cv2.LINE_AA, + ) + return frame diff --git a/test/geometry/__init__.py b/test/geometry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/commons/test_dataclasses.py b/test/geometry/test_dataclasses.py similarity index 96% rename from test/commons/test_dataclasses.py rename to test/geometry/test_dataclasses.py index ebb88fc1d..5e8298773 100644 --- a/test/commons/test_dataclasses.py +++ b/test/geometry/test_dataclasses.py @@ -1,6 +1,6 @@ import pytest -from supervision.commons.dataclasses import Vector, Point +from supervision.geometry.dataclasses import Vector, Point @pytest.mark.parametrize( From 981b1e87225bfbeb0093f759b834bfa51308f0c6 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 19 Jan 2023 00:05:13 +0100 Subject: [PATCH 4/5] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20first=20iteration=20of?= =?UTF-8?q?=20Detections=20and=20BoxAnnotator=20classes=20added?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/draw/color.py | 4 +- supervision/{commons => geometry}/__init__.py | 0 .../{commons => geometry}/dataclasses.py | 0 .../commons => supervision/tools}/__init__.py | 0 supervision/tools/detections.py | 202 ++++++++++++++++++ test/geometry/__init__.py | 0 .../{commons => geometry}/test_dataclasses.py | 2 +- 7 files changed, 206 insertions(+), 2 deletions(-) rename supervision/{commons => geometry}/__init__.py (100%) rename supervision/{commons => geometry}/dataclasses.py (100%) rename {test/commons => supervision/tools}/__init__.py (100%) create mode 100644 supervision/tools/detections.py create mode 100644 test/geometry/__init__.py rename test/{commons => geometry}/test_dataclasses.py (96%) diff --git a/supervision/draw/color.py b/supervision/draw/color.py index 7f3c6d3b1..842dd2be1 100644 --- a/supervision/draw/color.py +++ b/supervision/draw/color.py @@ -95,7 +95,9 @@ def blue(cls) -> Color: @dataclass class ColorPalette: colors: List[Color] = field( - default_factory=lambda: [Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE] + default_factory=lambda: [ + Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE + ] ) @classmethod diff --git a/supervision/commons/__init__.py b/supervision/geometry/__init__.py similarity index 100% rename from supervision/commons/__init__.py rename to supervision/geometry/__init__.py diff --git a/supervision/commons/dataclasses.py b/supervision/geometry/dataclasses.py similarity index 100% rename from supervision/commons/dataclasses.py rename to supervision/geometry/dataclasses.py diff --git a/test/commons/__init__.py b/supervision/tools/__init__.py similarity index 100% rename from test/commons/__init__.py rename to supervision/tools/__init__.py diff --git a/supervision/tools/detections.py b/supervision/tools/detections.py new file mode 100644 index 000000000..a5b67d6e0 --- /dev/null +++ b/supervision/tools/detections.py @@ -0,0 +1,202 @@ +from typing import List, Optional, Union + +import cv2 +import numpy as np + +from supervision.draw.color import Color, ColorPalette + + +class Detections: + def __init__( + self, + xyxy: np.ndarray, + confidence: np.ndarray, + class_id: np.ndarray, + tracker_id: Optional[np.ndarray] = None, + ): + """ + Data class containing information about the detections in a video frame. + + :param xyxy: np.ndarray : An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2] + :param confidence: np.ndarray : An array of shape (n,) containing the confidence scores of the detections. + :param class_id: np.ndarray : An array of shape (n,) containing the class ids of the detections. + :param tracker_id: Optional[np.ndarray] : An array of shape (n,) containing the tracker ids of the detections. + """ + self.xyxy: np.ndarray = xyxy + self.confidence: np.ndarray = confidence + self.class_id: np.ndarray = class_id + self.tracker_id: Optional[np.ndarray] = tracker_id + + n = len(self.xyxy) + validators = [ + (isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)), + (isinstance(self.confidence, np.ndarray) and self.confidence.shape == (n,)), + (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)), + self.tracker_id is None + or ( + isinstance(self.tracker_id, np.ndarray) + and self.tracker_id.shape == (n,) + ), + ] + if not all(validators): + raise ValueError( + "xyxy must be 2d np.ndarray with (n, 4) shape, " + "confidence must be 1d np.ndarray with (n,) shape, " + "class_id must be 1d np.ndarray with (n,) shape, " + "tracker_id must be None or 1d np.ndarray with (n,) shape" + ) + + def __len__(self): + """ + Returns the number of detections in the Detections object. + """ + return len(self.xyxy) + + def __iter__(self): + """ + Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection. + """ + for i in range(len(self.xyxy)): + yield ( + self.xyxy[i], + self.confidence[i], + self.class_id[i], + self.tracker_id[i] if self.tracker_id is not None else None, + ) + + @classmethod + def from_yolov5(cls, yolov5_output: np.ndarray): + """ + Creates a Detections instance from a YOLOv5 output tensor + + :param yolov5_output: np.ndarray : The output tensor from YOLOv5 + :return: Detections : A Detections instance representing the detections in the frame + + Example: + detections = Detections.from_yolov5(yolov5_output) + """ + xyxy = yolov5_output[:, :4] + confidence = yolov5_output[:, 4] + class_id = yolov5_output[:, 5].astype(int) + return cls(xyxy, confidence, class_id) + + def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]: + """ + Filter the detections by applying a mask + + :param mask: np.ndarray : A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections + :param inplace: bool : If True, the original data will be modified and self will be returned. + :return: Optional[np.ndarray] : A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise. + """ + if inplace: + self.xyxy = self.xyxy[mask] + self.confidence = self.confidence[mask] + self.class_id = self.class_id[mask] + self.tracker_id = ( + self.tracker_id[mask] if self.tracker_id is not None else None + ) + return self + else: + return Detections( + xyxy=self.xyxy[mask], + confidence=self.confidence[mask], + class_id=self.class_id[mask], + tracker_id=self.tracker_id[mask] + if self.tracker_id is not None + else None, + ) + + +class BoxAnnotator: + def __init__( + self, + color: Union[Color, ColorPalette], + thickness: int = 2, + text_color: Color = Color.black(), + text_scale: float = 0.5, + text_thickness: int = 1, + text_padding: int = 10, + ): + """ + A class for drawing bounding boxes on an image using detections provided. + + :param color: Union[Color, ColorPalette] : The color to draw the bounding box, can be a single color or a color palette + :param thickness: int : The thickness of the bounding box lines, default is 2 + :param text_color: Color : The color of the text on the bounding box, default is white + :param text_scale: float : The scale of the text on the bounding box, default is 0.5 + :param text_thickness: int : The thickness of the text on the bounding box, default is 1 + :param text_padding: int : The padding around the text on the bounding box, default is 5 + """ + self.color: Union[Color, ColorPalette] = color + self.thickness: int = thickness + self.text_color: Color = text_color + self.text_scale: float = text_scale + self.text_thickness: int = text_thickness + self.text_padding: int = text_padding + + def annotate( + self, + frame: np.ndarray, + detections: Detections, + labels: Optional[List[str]] = None, + ) -> np.ndarray: + """ + Draws bounding boxes on the frame using the detections provided. + + :param frame: np.ndarray : The image on which the bounding boxes will be drawn + :param detections: Detections : The detections for which the bounding boxes will be drawn + :param labels: Optional[List[str]] : An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label. + :return: np.ndarray : The image with the bounding boxes drawn on it + """ + font = cv2.FONT_HERSHEY_SIMPLEX + for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections): + color = ( + self.color.by_idx(class_id) + if isinstance(self.color, ColorPalette) + else self.color + ) + text = ( + f"{confidence:0.2f}" + if (labels is None or len(detections) != len(labels)) + else labels[i] + ) + + x1, y1, x2, y2 = xyxy.astype(int) + text_width, text_height = cv2.getTextSize( + text=text, fontFace=font, fontScale=self.text_scale, thickness=self.text_thickness + )[0] + + text_x = x1 + self.text_padding + text_y = y1 - self.text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * self.text_padding - text_height + + text_background_x2 = x1 + 2 * self.text_padding + text_width + text_background_y2 = y1 + + cv2.rectangle( + image=frame, + start_point=(x1, y1), + end_point=(x2, y2), + color=color.as_bgr(), + thickness=self.thickness + ) + cv2.rectangle( + image=frame, + start_point=(text_background_x1, text_background_y1), + end_point=(text_background_x2, text_background_y2), + color=color.as_bgr(), + thickness=cv2.FILLED, + ) + cv2.putText( + image=frame, + text=text, + org=(text_x, text_y), + font=font, + fontScale=self.text_scale, + color=self.text_color.as_rgb(), + thickness=self.text_thickness, + lineType=cv2.LINE_AA, + ) + return frame diff --git a/test/geometry/__init__.py b/test/geometry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/commons/test_dataclasses.py b/test/geometry/test_dataclasses.py similarity index 96% rename from test/commons/test_dataclasses.py rename to test/geometry/test_dataclasses.py index ebb88fc1d..5e8298773 100644 --- a/test/commons/test_dataclasses.py +++ b/test/geometry/test_dataclasses.py @@ -1,6 +1,6 @@ import pytest -from supervision.commons.dataclasses import Vector, Point +from supervision.geometry.dataclasses import Vector, Point @pytest.mark.parametrize( From 93f1dd5511e5ba6fb28ffeea0ff84066a334e43b Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 19 Jan 2023 01:17:51 +0100 Subject: [PATCH 5/5] style and check_code_quality --- supervision/tools/detections.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/supervision/tools/detections.py b/supervision/tools/detections.py index f2a3ad5f4..2cf548679 100644 --- a/supervision/tools/detections.py +++ b/supervision/tools/detections.py @@ -163,7 +163,10 @@ def annotate( x1, y1, x2, y2 = xyxy.astype(int) text_width, text_height = cv2.getTextSize( - text=text, fontFace=font, fontScale=self.text_scale, thickness=self.text_thickness + text=text, + fontFace=font, + fontScale=self.text_scale, + thickness=self.text_thickness, )[0] text_x = x1 + self.text_padding @@ -180,7 +183,7 @@ def annotate( pt1=(x1, y1), pt2=(x2, y2), color=color.as_bgr(), - thickness=self.thickness + thickness=self.thickness, ) cv2.rectangle( img=frame,