Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented label in cluster estimation #231

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions modules/cluster_estimation/cluster_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ class ClusterEstimation:
__WEIGHT_DROP_THRESHOLD = 0.1
__MAX_COVARIANCE_THRESHOLD = 10

@staticmethod
def check_create_arguments(
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
) -> bool:
"""
Checks if a valid cluster estimation object can be constructed.

See `ClusterEstimation` for parameter descriptions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this to:

See `create()` for parameter descriptions.

Return: Whether the arguments are valid.

"""
if min_activation_threshold < max_num_components:
return False

if min_new_points_to_run < 0:
return False

if max_num_components < 1:
return False

if random_state < 0:
return False

return True

@classmethod
def create(
cls,
Expand Down Expand Up @@ -85,16 +111,11 @@ def create(

RETURNS: The ClusterEstimation object if all conditions pass, otherwise False, None
"""
if min_activation_threshold < max_num_components:
return False, None

if min_new_points_to_run < 0:
return False, None

if max_num_components < 1:
return False, None
is_valid_arguments = ClusterEstimation.check_create_arguments(
min_activation_threshold, min_new_points_to_run, max_num_components, random_state
)

if random_state < 0:
if not is_valid_arguments:
return False, None

return True, ClusterEstimation(
Expand Down Expand Up @@ -211,21 +232,20 @@ def run(
model_output = self.__filter_by_covariances(model_output)

# Create output list of remaining valid clusters
detections_in_world = []
objects_in_world = []
for cluster in model_output:
result, landing_pad = object_in_world.ObjectInWorld.create(
cluster[0][0],
cluster[0][1],
cluster[2],
cluster[0][0], cluster[0][1], cluster[2]
)

if result:
detections_in_world.append(landing_pad)
else:
self.__logger.warning("Failed to create ObjectInWorld object")
if not result:
self.__logger.error("Failed to create ObjectInWorld object")
return False, None

objects_in_world.append(landing_pad)

self.__logger.info(detections_in_world)
return True, detections_in_world
self.__logger.info(objects_in_world)
return True, objects_in_world

def __decide_to_run(self, run_override: bool) -> bool:
"""
Expand Down
162 changes: 162 additions & 0 deletions modules/cluster_estimation/cluster_estimation_by_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
Cluster estimation by label.
"""

from . import cluster_estimation
from .. import detection_in_world
from .. import object_in_world
from ..common.modules.logger import logger


class ClusterEstimationByLabel:
"""
Cluster estimation filtered on label.

ATTRIBUTES
----------
min_activation_threshold: int
Minimum total data points before model runs. Must be at least max_num_components.

min_new_points_to_run: int
Minimum number of new data points that must be collected before running model.

max_num_components: int
Max number of real landing pads. Must be at least 1.

random_state: int
Seed for randomizer, to get consistent results.

local_logger: Logger
For logging error and debug messages.

METHODS
-------
run()
Cluster estimation filtered by label.
Comment on lines +14 to +35
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this.

"""

__create_key = object()

@classmethod
def create(
cls,
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
local_logger: logger.Logger,
) -> "tuple[True, ClusterEstimationByLabel] | tuple[False, None]":
"""
See `ClusterEstimation` for parameter descriptions.

Return: Success, cluster estimation by label object.
"""

is_valid_arguments = cluster_estimation.ClusterEstimation.check_create_arguments(
min_activation_threshold, min_new_points_to_run, max_num_components, random_state
)

if not is_valid_arguments:
return False, None

return True, ClusterEstimationByLabel(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check ClusterEstimation's restrictions. Either apply them again here or by invoking creating cluster estimation

cls.__create_key,
min_activation_threshold,
min_new_points_to_run,
max_num_components,
random_state,
local_logger,
)

def __init__(
self,
class_private_create_key: object,
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
local_logger: logger.Logger,
) -> None:
"""
Private constructor, use create() method.
"""
assert (
class_private_create_key is ClusterEstimationByLabel.__create_key
), "Use create() method"

# Construction arguments for `ClusterEstimation`
self.__min_activation_threshold = min_activation_threshold
self.__min_new_points_to_run = min_new_points_to_run
self.__max_num_components = max_num_components
self.__random_state = random_state
self.__local_logger = local_logger

# Cluster model corresponding to each label
# Each cluster estimation object stores the detections given to in its __all_points bucket across runs
self.__label_to_cluster_estimation_model: dict[
int, cluster_estimation.ClusterEstimation
] = {}

def run(
self,
input_detections: list[detection_in_world.DetectionInWorld],
run_override: bool,
) -> tuple[True, dict[int, list[object_in_world.ObjectInWorld]]] | tuple[False, None]:
"""
See `ClusterEstimation` for parameter descriptions.

RETURNS
-------
model_ran: bool
True if ClusterEstimation object successfully ran its estimation model, False otherwise.

labels_to_objects: dict[int, list[object_in_world.ObjectInWorld] or None.
Dictionary where the key is a label and the value is a list of all cluster detections with that label.
ObjectInWorld objects don't have a label property, but they are sorted into label categories in the dictionary.
Comment on lines +108 to +115
Copy link
Collaborator

@Xierumeng Xierumeng Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify:

Return: Success, labels and their associated objects.

"""
label_to_detections: dict[int, list[detection_in_world.DetectionInWorld]] = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment saying sorting detections by label


# Filtering detections by label
for detection in input_detections:
if not detection.label in label_to_detections:
label_to_detections[detection.label] = []

label_to_detections[detection.label].append(detection)

labels_to_objects: dict[int, list[object_in_world.ObjectInWorld]] = {}

for label, detections in label_to_detections.items():
# Create cluster estimation for label if it doesn't exist
if not label in self.__label_to_cluster_estimation_model:
result, cluster_model = cluster_estimation.ClusterEstimation.create(
self.__min_activation_threshold,
self.__min_new_points_to_run,
self.__max_num_components,
self.__random_state,
self.__local_logger,
)
if not result:
self.__local_logger.error(
f"Failed to create cluster estimation for label {label}"
)
return False, None

self.__label_to_cluster_estimation_model[label] = cluster_model

# Runs cluster estimation for specific label
result, clusters = self.__label_to_cluster_estimation_model[label].run(
detections,
run_override,
)

if not result:
self.__local_logger.error(
f"Failed to run cluster estimation model for label {label}"
)
return False, None

if not label in labels_to_objects:
labels_to_objects[label] = []
labels_to_objects[label] += clusters
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add empty line above this line.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the desired behaviour? The cluster estimation objects already hold a record of all points, so they will always generate an updated version of the cluster centres. I think this should be an unconditional assignment instead.


return True, labels_to_objects
19 changes: 1 addition & 18 deletions modules/cluster_estimation/cluster_estimation_worker.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this worker to use the new ClusterEstimationByLabel class.

Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,9 @@ def cluster_estimation_worker(

PARAMETERS
----------
Comment on lines 27 to 28
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove.

min_activation_threshold: int
Minimum total data points before model runs.

min_new_points_to_run: int
Minimum number of new data points that must be collected before running model.
See `ClusterEstimation` for parameter descriptions.

max_num_components: int
Max number of real landing pads.

random_state: int
Seed for randomizer, to get consistent results.

input_queue: queue_proxy_wrapper.QueuePRoxyWrapper
Data queue.

output_queue: queue_proxy_wrapper.QueuePRoxyWrapper
Data queue.

worker_controller: worker_controller.WorkerController
How the main process communicates to this worker process.
"""
worker_name = pathlib.Path(__file__).stem
process_id = os.getpid()
Expand Down
7 changes: 6 additions & 1 deletion modules/object_in_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ def create(
if spherical_variance < 0.0:
return False, None

return True, ObjectInWorld(cls.__create_key, location_x, location_y, spherical_variance)
return True, ObjectInWorld(
cls.__create_key,
location_x,
location_y,
spherical_variance,
)

def __init__(
self,
Expand Down
Loading