-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implemented label in cluster estimation #231
base: main
Are you sure you want to change the base?
Changes from all commits
3205913
dabc8cb
a101d17
571cc13
61dcfa1
9c70b00
0b89706
f5c4b90
f687063
c0ea3e0
50ef6ff
42c9fbe
6c8e724
b702a76
34224fe
6c02dc8
2106bbe
1692e16
769ba6e
8618361
d3c0c89
d12aa34
b8155cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
""" | ||
Cluster estimation by label. | ||
""" | ||
|
||
from . import cluster_estimation | ||
from .. import detection_in_world | ||
from .. import object_in_world | ||
from ..common.modules.logger import logger | ||
|
||
|
||
class ClusterEstimationByLabel: | ||
""" | ||
Cluster estimation filtered on label. | ||
|
||
ATTRIBUTES | ||
---------- | ||
min_activation_threshold: int | ||
Minimum total data points before model runs. Must be at least max_num_components. | ||
|
||
min_new_points_to_run: int | ||
Minimum number of new data points that must be collected before running model. | ||
|
||
max_num_components: int | ||
Max number of real landing pads. Must be at least 1. | ||
|
||
random_state: int | ||
Seed for randomizer, to get consistent results. | ||
|
||
local_logger: Logger | ||
For logging error and debug messages. | ||
|
||
METHODS | ||
------- | ||
run() | ||
Cluster estimation filtered by label. | ||
Comment on lines
+14
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this. |
||
""" | ||
|
||
__create_key = object() | ||
|
||
@classmethod | ||
def create( | ||
cls, | ||
min_activation_threshold: int, | ||
min_new_points_to_run: int, | ||
max_num_components: int, | ||
random_state: int, | ||
Aleksa-M marked this conversation as resolved.
Show resolved
Hide resolved
|
||
local_logger: logger.Logger, | ||
) -> "tuple[True, ClusterEstimationByLabel] | tuple[False, None]": | ||
""" | ||
See `ClusterEstimation` for parameter descriptions. | ||
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Return: Success, cluster estimation by label object. | ||
""" | ||
|
||
is_valid_arguments = cluster_estimation.ClusterEstimation.check_create_arguments( | ||
min_activation_threshold, min_new_points_to_run, max_num_components, random_state | ||
) | ||
|
||
if not is_valid_arguments: | ||
return False, None | ||
|
||
return True, ClusterEstimationByLabel( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check |
||
cls.__create_key, | ||
min_activation_threshold, | ||
min_new_points_to_run, | ||
max_num_components, | ||
random_state, | ||
local_logger, | ||
) | ||
|
||
def __init__( | ||
self, | ||
class_private_create_key: object, | ||
min_activation_threshold: int, | ||
min_new_points_to_run: int, | ||
max_num_components: int, | ||
random_state: int, | ||
local_logger: logger.Logger, | ||
) -> None: | ||
""" | ||
Private constructor, use create() method. | ||
|
||
class_private_create_key: Must be the class's own private create key, otherwise the assertion fails. | ||
All other arguments are stored unchanged and forwarded to `ClusterEstimation.create()` in run(). | ||
""" | ||
# Identity check against the class-private key, so direct construction without the key fails | ||
assert ( | ||
class_private_create_key is ClusterEstimationByLabel.__create_key | ||
), "Use create() method" | ||
|
||
# Construction arguments for `ClusterEstimation` | ||
# Kept so that run() can create a `ClusterEstimation` for a label the first time that label is seen | ||
self.__min_activation_threshold = min_activation_threshold | ||
self.__min_new_points_to_run = min_new_points_to_run | ||
self.__max_num_components = max_num_components | ||
self.__random_state = random_state | ||
self.__local_logger = local_logger | ||
|
||
# Cluster model corresponding to each label (key is the label, starts empty) | ||
# Each cluster estimation object stores the detections given to it in its __all_points bucket across runs | ||
self.__label_to_cluster_estimation_model: dict[ | ||
int, cluster_estimation.ClusterEstimation | ||
] = {} | ||
|
||
def run( | ||
self, | ||
input_detections: list[detection_in_world.DetectionInWorld], | ||
run_override: bool, | ||
) -> tuple[True, dict[int, list[object_in_world.ObjectInWorld]]] | tuple[False, None]: | ||
""" | ||
See `ClusterEstimation` for parameter descriptions. | ||
|
||
RETURNS | ||
------- | ||
model_ran: bool | ||
True if ClusterEstimation object successfully ran its estimation model, False otherwise. | ||
|
||
labels_to_objects: dict[int, list[object_in_world.ObjectInWorld]] or None. | ||
Dictionary where the key is a label and the value is a list of all cluster detections with that label. | ||
ObjectInWorld objects don't have a label property, but they are sorted into label categories in the dictionary. | ||
Comment on lines
+108
to
+115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify:
|
||
""" | ||
label_to_detections: dict[int, list[detection_in_world.DetectionInWorld]] = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comment saying sorting detections by label |
||
|
||
# Filtering detections by label | ||
for detection in input_detections: | ||
if not detection.label in label_to_detections: | ||
label_to_detections[detection.label] = [] | ||
|
||
label_to_detections[detection.label].append(detection) | ||
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
labels_to_objects: dict[int, list[object_in_world.ObjectInWorld]] = {} | ||
|
||
for label, detections in label_to_detections.items(): | ||
# Create cluster estimation for label if it doesn't exist | ||
if not label in self.__label_to_cluster_estimation_model: | ||
Aleksa-M marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result, cluster_model = cluster_estimation.ClusterEstimation.create( | ||
self.__min_activation_threshold, | ||
self.__min_new_points_to_run, | ||
self.__max_num_components, | ||
self.__random_state, | ||
self.__local_logger, | ||
) | ||
if not result: | ||
self.__local_logger.error( | ||
f"Failed to create cluster estimation for label {label}" | ||
) | ||
return False, None | ||
|
||
self.__label_to_cluster_estimation_model[label] = cluster_model | ||
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Runs cluster estimation for specific label | ||
result, clusters = self.__label_to_cluster_estimation_model[label].run( | ||
Aleksa-M marked this conversation as resolved.
Show resolved
Hide resolved
|
||
detections, | ||
run_override, | ||
) | ||
|
||
if not result: | ||
self.__local_logger.error( | ||
f"Failed to run cluster estimation model for label {label}" | ||
) | ||
return False, None | ||
|
||
if not label in labels_to_objects: | ||
labels_to_objects[label] = [] | ||
labels_to_objects[label] += clusters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add empty line above this line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the desired behaviour? The cluster estimation objects already hold a record of all points, so they will always generate an updated version of the cluster centres. I think this should be an unconditional assignment instead. |
||
|
||
return True, labels_to_objects |
Aleksa-M marked this conversation as resolved.
Show resolved
Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change this worker to use the new |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,26 +26,9 @@ def cluster_estimation_worker( | |
|
||
PARAMETERS | ||
---------- | ||
Comment on lines
27
to
28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove. |
||
min_activation_threshold: int | ||
Minimum total data points before model runs. | ||
|
||
min_new_points_to_run: int | ||
Minimum number of new data points that must be collected before running model. | ||
See `ClusterEstimation` for parameter descriptions. | ||
|
||
max_num_components: int | ||
Max number of real landing pads. | ||
|
||
random_state: int | ||
Seed for randomizer, to get consistent results. | ||
|
||
input_queue: queue_proxy_wrapper.QueueProxyWrapper | ||
Data queue. | ||
|
||
output_queue: queue_proxy_wrapper.QueueProxyWrapper | ||
Data queue. | ||
|
||
worker_controller: worker_controller.WorkerController | ||
How the main process communicates to this worker process. | ||
""" | ||
worker_name = pathlib.Path(__file__).stem | ||
process_id = os.getpid() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change this to: