Skip to content

Commit

Permalink
Rename threshold to conf_threshold for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
SpecLad committed Nov 14, 2024
1 parent 79c09ee commit f92da49
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 34 deletions.
15 changes: 9 additions & 6 deletions changelog.d/20241112_201034_roman_aa_threshold.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
### Added

- \[SDK, CLI\] Added a `threshold` parameter to `cvat_sdk.auto_annotation.annotate_task`,
which is passed as-is to the AA function object via the context. The CLI
equivalent is `auto-annotate --threshold`. This makes it easier to write
and use AA functions that support object filtering based on confidence
levels. Updated the builtin functions in `cvat_sdk.auto_annotation.functions`
to support filtering via this parameter
- \[SDK, CLI\] Added a `conf_threshold` parameter to
`cvat_sdk.auto_annotation.annotate_task`, which is passed as-is to the AA
function object via the context. The CLI equivalent is `auto-annotate
--conf-threshold`. This makes it easier to write and use AA functions that
support object filtering based on confidence levels
(<https://github.com/cvat-ai/cvat/pull/8688>)

- \[SDK\] Built-in auto-annotation functions now support object filtering by
confidence level
(<https://github.com/cvat-ai/cvat/pull/8688>)
6 changes: 3 additions & 3 deletions cvat-cli/src/cvat_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
)

parser.add_argument(
"--threshold",
"--conf-threshold",
type=parse_threshold,
help="Confidence threshold for filtering detections",
default=None,
Expand All @@ -486,7 +486,7 @@ def execute(
function_parameters: dict[str, Any],
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float],
conf_threshold: Optional[float],
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
Expand All @@ -511,5 +511,5 @@ def execute(
pbar=DeferredTqdmProgressReporter(),
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
threshold=threshold,
conf_threshold=conf_threshold,
)
14 changes: 7 additions & 7 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
threshold: Optional[float] = None
conf_threshold: Optional[float] = None


def annotate_task(
Expand All @@ -234,7 +234,7 @@ def annotate_task(
pbar: Optional[ProgressReporter] = None,
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float] = None,
conf_threshold: Optional[float] = None,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
Expand Down Expand Up @@ -267,15 +267,15 @@ def annotate_task(
If it's set to true, then such labels are allowed, and any annotations returned by the
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
The threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the threshold attribute of the context object.
The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the conf_threshold attribute of the context object.
"""

if pbar is None:
pbar = NullProgressReporter()

if threshold is not None and not 0 <= threshold <= 1:
raise ValueError("threshold must be None or a number between 0 and 1")
if conf_threshold is not None and not 0 <= conf_threshold <= 1:
raise ValueError("conf_threshold must be None or a number between 0 and 1")

dataset = TaskDataset(client, task_id, load_annotations=False)

Expand All @@ -293,7 +293,7 @@ def annotate_task(
with pbar.task(total=len(dataset.samples), unit="samples"):
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
_DetectionFunctionContextImpl(frame_name=sample.frame_name, threshold=threshold),
_DetectionFunctionContextImpl(frame_name=sample.frame_name, conf_threshold=conf_threshold),
sample.media.load_image(),
)
mapper.validate_and_remap(frame_shapes, sample.frame_index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
if score >= conf_threshold
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
Expand All @@ -57,7 +57,7 @@ def detect(
for keypoints, label, score in zip(
result["keypoints"], result["labels"], result["scores"]
)
if score >= threshold
if score >= conf_threshold
]


Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def frame_name(self) -> str:

@property
@abc.abstractmethod
def threshold(self) -> Optional[float]:
def conf_threshold(self) -> Optional[float]:
"""
The confidence threshold that the function should use for filtering
detections.
Expand Down
8 changes: 4 additions & 4 deletions site/content/en/docs/api_sdk/sdk/auto-annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TorchvisionDetectionFunction:
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
# determine the threshold for filtering results
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0

# convert the input into a form the model can understand
transformed_image = [self._transforms(image)]
Expand All @@ -85,7 +85,7 @@ class TorchvisionDetectionFunction:
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
if score >= conf_threshold
]

# log into the CVAT server
Expand Down Expand Up @@ -122,7 +122,7 @@ that these objects must follow.
The following fields are available:

- `frame_name` (`str`). The file name of the frame on the CVAT server.
- `threshold` (`float | None`). The confidence threshold that the function
- `conf_threshold` (`float | None`). The confidence threshold that the function
should use to filter objects. If `None`, the function may apply a default
threshold at its discretion.

Expand Down Expand Up @@ -206,7 +206,7 @@ and any shapes referring to them will be dropped.
Same logic applies to sub-label IDs.

It's possible to pass a custom confidence threshold to the function via the
`threshold` parameter.
`conf_threshold` parameter.

`annotate_task` will raise a `BadFunctionError` exception
if it detects that the function violated the AA function protocol.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
return [
cvataa.rectangle(0, [context.threshold, 1, 1, 1]),
cvataa.rectangle(0, [context.conf_threshold, 1, 1, 1]),
]
4 changes: 2 additions & 2 deletions tests/python/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def test_auto_annotate_with_threshold(self, fxt_new_task: Task):
self.run_cli(
"auto-annotate",
str(fxt_new_task.id),
f"--function-module={__package__}.threshold_function",
"--threshold=0.75",
f"--function-module={__package__}.conf_threshold_function",
"--conf-threshold=0.75",
)

annotations = fxt_new_task.get_annotations()
Expand Down
12 changes: 6 additions & 6 deletions tests/python/sdk/test_auto_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
assert shapes[i].points == [5, 6, 7, 8]
assert shapes[i].rotation == 10

def test_threshold(self):
def test_conf_threshold(self):
spec = cvataa.DetectionFunctionSpec(labels=[])

received_threshold = None
Expand All @@ -278,14 +278,14 @@ def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
nonlocal received_threshold
received_threshold = context.threshold
received_threshold = context.conf_threshold
return []

cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
threshold=0.75,
conf_threshold=0.75,
)

assert received_threshold == 0.75
Expand All @@ -304,7 +304,7 @@ def detect(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
threshold=bad_threshold,
conf_threshold=bad_threshold,
)

def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None:
Expand Down Expand Up @@ -713,7 +713,7 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
threshold=0.75,
conf_threshold=0.75,
)

annotations = self.task.get_annotations()
Expand All @@ -733,7 +733,7 @@ def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
threshold=0.75,
conf_threshold=0.75,
)

annotations = self.task.get_annotations()
Expand Down

0 comments on commit f92da49

Please sign in to comment.