Skip to content

Commit d8b63f7

Browse files
committed
integrated review changes
1 parent d8b072f commit d8b63f7

File tree

4 files changed

+30
-23
lines changed

4 files changed

+30
-23
lines changed

main_2024.py

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def main() -> int:
346346
MIN_NEW_POINTS_TO_RUN,
347347
MAX_NUM_COMPONENTS,
348348
RANDOM_STATE,
349+
0,
349350
),
350351
input_queues=[geolocation_to_cluster_estimation_queue],
351352
output_queues=[cluster_estimation_to_communications_queue],

modules/cluster_estimation/cluster_estimation_by_label.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from .. import detection_in_world
88
from .. import object_in_world
9-
from ..cluster_estimation import cluster_estimation
109
from ..common.modules.logger import logger
10+
from . import cluster_estimation
1111

1212

1313
class ClusterEstimationByLabel:
@@ -19,11 +19,14 @@ class ClusterEstimationByLabel:
1919
ATTRIBUTES
2020
----------
2121
min_activation_threshold: int
22-
Minimum total data points before model runs.
22+
Minimum total data points before model runs. Must be at least max_num_components.
2323
2424
min_new_points_to_run: int
2525
Minimum number of new data points that must be collected before running model.
2626
27+
max_num_components: int
28+
Max number of real landing pads. Must be at least 1.
29+
2730
random_state: int
2831
Seed for randomizer, to get consistent results.
2932
@@ -47,6 +50,7 @@ def create(
4750
cls,
4851
min_activation_threshold: int,
4952
min_new_points_to_run: int,
53+
max_num_components: int,
5054
random_state: int,
5155
local_logger: logger.Logger,
5256
) -> "tuple[bool, ClusterEstimationByLabel | None]":
@@ -55,13 +59,23 @@ def create(
5559
"""
5660

5761
# At least 1 point for model to fit
58-
if min_activation_threshold < 1:
62+
if min_activation_threshold < max_num_components:
63+
return False, None
64+
65+
if min_new_points_to_run < 0:
66+
return False, None
67+
68+
if max_num_components < 1:
69+
return False, None
70+
71+
if random_state < 0:
5972
return False, None
6073

6174
return True, ClusterEstimationByLabel(
6275
cls.__create_key,
6376
min_activation_threshold,
6477
min_new_points_to_run,
78+
max_num_components,
6579
random_state,
6680
local_logger,
6781
)
@@ -71,6 +85,7 @@ def __init__(
7185
class_private_create_key: object,
7286
min_activation_threshold: int,
7387
min_new_points_to_run: int,
88+
max_num_components: int,
7489
random_state: int,
7590
local_logger: logger.Logger,
7691
) -> None:
@@ -84,10 +99,12 @@ def __init__(
8499
# Requirements to decide to run
85100
self.__min_activation_threshold = min_activation_threshold
86101
self.__min_new_points_to_run = min_new_points_to_run
102+
self.__max_num_components = max_num_components
87103
self.__random_state = random_state
88104
self.__local_logger = local_logger
89105

90-
# cluster model corresponding to each label
106+
# Cluster model corresponding to each label
107+
# Each cluster estimation object stores the detections given to in its __all_points bucket across runs
91108
self.__label_to_cluster_estimation_model: dict[
92109
int, cluster_estimation.ClusterEstimation
93110
] = {}
@@ -120,17 +137,20 @@ def run(
120137
Dictionary where the key is a label and the value is a list of all cluster detections with that label
121138
"""
122139
label_to_detections: dict[int, list[detection_in_world.DetectionInWorld]] = {}
140+
# Sorting detections by label
123141
for detection in input_detections:
124142
if not detection.label in label_to_detections:
125143
label_to_detections[detection.label] = []
126144
label_to_detections[detection.label].append(detection)
127145

128146
labels_to_object_clusters: dict[int, list[object_in_world.ObjectInWorld]] = {}
129147
for label, detections in label_to_detections.items():
148+
# create cluster estimation for label if it doesn't exist
130149
if not label in self.__label_to_cluster_estimation_model:
131150
result, cluster_model = cluster_estimation.ClusterEstimation.create(
132151
self.__min_activation_threshold,
133152
self.__min_new_points_to_run,
153+
self.__max_num_components,
134154
self.__random_state,
135155
self.__local_logger,
136156
label,
@@ -141,6 +161,7 @@ def run(
141161
)
142162
return False, None
143163
self.__label_to_cluster_estimation_model[label] = cluster_model
164+
# runs cluster estimation for specific label
144165
result, clusters = self.__label_to_cluster_estimation_model[label].run(
145166
detections,
146167
run_override,

modules/cluster_estimation/cluster_estimation_worker.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def cluster_estimation_worker(
1919
min_new_points_to_run: int,
2020
max_num_components: int,
2121
random_state: int,
22+
label: int,
2223
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
2324
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
2425
controller: worker_controller.WorkerController,
@@ -95,6 +96,7 @@ def cluster_estimation_worker(
9596
max_num_components,
9697
random_state,
9798
local_logger,
99+
label,
98100
)
99101
if not result:
100102
local_logger.error("Worker failed to create class object", True)

tests/unit/test_cluster_detection.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,12 @@ def cluster_model() -> cluster_estimation.ClusterEstimation: # type: ignore
3434
assert test_logger is not None
3535

3636
result, model = cluster_estimation.ClusterEstimation.create(
37-
MIN_TOTAL_POINTS_THRESHOLD, MIN_NEW_POINTS_TO_RUN, RNG_SEED, test_logger, 0
38-
)
39-
assert result
40-
assert model is not None
41-
42-
yield model # type: ignore
43-
44-
45-
@pytest.fixture()
46-
def cluster_model_by_label() -> cluster_estimation_by_label.ClusterEstimationByLabel: # type: ignore
47-
"""
48-
Cluster estimation by label object.
49-
"""
50-
result, test_logger = logger.Logger.create("test_logger", False)
51-
assert result
52-
assert test_logger is not None
53-
54-
result, model = cluster_estimation_by_label.ClusterEstimationByLabel.create(
5537
MIN_TOTAL_POINTS_THRESHOLD,
5638
MIN_NEW_POINTS_TO_RUN,
5739
MAX_NUM_COMPONENTS,
5840
RNG_SEED,
5941
test_logger,
60-
0
42+
0,
6143
)
6244
assert result
6345
assert model is not None
@@ -77,6 +59,7 @@ def cluster_model_by_label() -> cluster_estimation_by_label.ClusterEstimationByL
7759
result, model = cluster_estimation_by_label.ClusterEstimationByLabel.create(
7860
MIN_TOTAL_POINTS_THRESHOLD,
7961
MIN_NEW_POINTS_TO_RUN,
62+
MAX_NUM_COMPONENTS,
8063
RNG_SEED,
8164
test_logger,
8265
)

0 commit comments

Comments
 (0)