Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Highlight the closest target with a different color in predictions #13018

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
eaab232
Add a function to find the nearest target
1260635600 May 16, 2024
d458a53
Add a function to find the nearest target
1260635600 May 16, 2024
33238cf
Add a function to find the closest target
1260635600 May 16, 2024
c962fca
Auto-format by https://ultralytics.com/actions
UltralyticsAssistant May 16, 2024
929cd2f
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 18, 2024
cba2a46
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 24, 2024
b35181f
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 28, 2024
5e723a3
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 29, 2024
992f8b1
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 29, 2024
68d4b54
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant May 30, 2024
dd569ed
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 8, 2024
a2fb804
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 8, 2024
87c43f0
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 9, 2024
6e5520a
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 16, 2024
fe0e916
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 16, 2024
5ea7bd4
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 17, 2024
5fdb2d5
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 19, 2024
ab5148a
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 19, 2024
44d7532
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 20, 2024
6bffc46
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 20, 2024
dc50ec0
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 20, 2024
e64165f
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 20, 2024
2fca818
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 22, 2024
6dc5254
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 30, 2024
34c11d8
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 30, 2024
2041ea2
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jun 30, 2024
8e5a840
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jul 5, 2024
a83cd91
Merge branch 'master' into add-find-nearest-target
UltralyticsAssistant Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import sys
from pathlib import Path

import numpy as np
import torch

FILE = Path(__file__).resolve()
Expand Down Expand Up @@ -130,30 +131,20 @@ def run(

# Run inference
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if model.xml and im.shape[0] > 1:
ims = torch.chunk(im, im.shape[0], 0)

# Inference
with dt[1]:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
if model.xml and im.shape[0] > 1:
pred = None
for image in ims:
if pred is None:
pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)
else:
pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)
pred = [pred, None]
else:
pred = model(im, augment=augment, visualize=visualize)
pred = model(im, augment=augment, visualize=visualize)

# NMS
with dt[2]:
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand All @@ -166,7 +157,6 @@ def run(

# Create or append to the CSV file
def write_to_csv(image_name, prediction, confidence):
"""Writes prediction data for an image to a CSV file, appending if the file exists."""
data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}
with open(csv_path, mode="a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=data.keys())
Expand All @@ -190,10 +180,37 @@ def write_to_csv(image_name, prediction, confidence):
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))

# Calculate the center of the image
img_center = np.array([im0.shape[1] // 2, im0.shape[0] // 2])
min_distance = None
closest_box = None

if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

# Calculate the center of the image
img_center = np.array([im0.shape[1] // 2, im0.shape[0] // 2])

# Calculate centers of all detection boxes and find the closest one to the image center
centers = np.array(
[
[(xyxy[0].cpu() + xyxy[2].cpu()) / 2, (xyxy[1].cpu() + xyxy[3].cpu()) / 2]
for *xyxy, _, _ in reversed(det)
]
)
distances = np.linalg.norm(centers - img_center, axis=1)
closest_idx = np.argmin(distances)

# Draw boxes, marking the closest one in green
for j, (*xyxy, conf, cls) in enumerate(reversed(det)):
color = (0, 255, 0) if j == closest_idx else colors(int(cls), True)
annotator.box_label(xyxy, f"{names[int(cls)]} {conf:.2f}", color=color)

# Rescale boxes from img_size to im0 size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

# Print results
for c in det[:, 5].unique():
n = (det[:, 5] == c).sum() # detections per class
Expand Down Expand Up @@ -265,7 +282,6 @@ def write_to_csv(image_name, prediction, confidence):


def parse_opt():
"""Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations."""
parser = argparse.ArgumentParser()
parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")
parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")
Expand Down Expand Up @@ -302,11 +318,13 @@ def parse_opt():


def main(opt):
"""Executes YOLOv5 model inference with given options, checking requirements before running the model."""
check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
run(**vars(opt))


# python detect.py --weights runs/train/exp10/weights/best.pt --source project/test
# python detect.py --weights runs/train/exp10/weights/best.pt --source project/test

if __name__ == "__main__":
opt = parse_opt()
main(opt)