6
6
7
7
from .. import detection_in_world
8
8
from .. import object_in_world
9
- from ..cluster_estimation import cluster_estimation
10
9
from ..common .modules .logger import logger
10
+ from . import cluster_estimation
11
11
12
12
13
13
class ClusterEstimationByLabel :
@@ -19,11 +19,14 @@ class ClusterEstimationByLabel:
19
19
ATTRIBUTES
20
20
----------
21
21
min_activation_threshold: int
22
- Minimum total data points before model runs.
22
+ Minimum total data points before model runs. Must be at least max_num_components.
23
23
24
24
min_new_points_to_run: int
25
25
Minimum number of new data points that must be collected before running model.
26
26
27
+ max_num_components: int
28
+ Max number of real landing pads. Must be at least 1.
29
+
27
30
random_state: int
28
31
Seed for randomizer, to get consistent results.
29
32
@@ -47,6 +50,7 @@ def create(
47
50
cls ,
48
51
min_activation_threshold : int ,
49
52
min_new_points_to_run : int ,
53
+ max_num_components : int ,
50
54
random_state : int ,
51
55
local_logger : logger .Logger ,
52
56
) -> "tuple[bool, ClusterEstimationByLabel | None]" :
@@ -55,13 +59,23 @@ def create(
55
59
"""
56
60
57
61
# At least 1 point for model to fit
58
- if min_activation_threshold < 1 :
62
+ if min_activation_threshold < max_num_components :
63
+ return False , None
64
+
65
+ if min_new_points_to_run < 0 :
66
+ return False , None
67
+
68
+ if max_num_components < 1 :
69
+ return False , None
70
+
71
+ if random_state < 0 :
59
72
return False , None
60
73
61
74
return True , ClusterEstimationByLabel (
62
75
cls .__create_key ,
63
76
min_activation_threshold ,
64
77
min_new_points_to_run ,
78
+ max_num_components ,
65
79
random_state ,
66
80
local_logger ,
67
81
)
@@ -71,6 +85,7 @@ def __init__(
71
85
class_private_create_key : object ,
72
86
min_activation_threshold : int ,
73
87
min_new_points_to_run : int ,
88
+ max_num_components : int ,
74
89
random_state : int ,
75
90
local_logger : logger .Logger ,
76
91
) -> None :
@@ -84,10 +99,12 @@ def __init__(
84
99
# Requirements to decide to run
85
100
self .__min_activation_threshold = min_activation_threshold
86
101
self .__min_new_points_to_run = min_new_points_to_run
102
+ self .__max_num_components = max_num_components
87
103
self .__random_state = random_state
88
104
self .__local_logger = local_logger
89
105
90
- # cluster model corresponding to each label
106
+ # Cluster model corresponding to each label
107
+ # Each cluster estimation object stores the detections given to in its __all_points bucket across runs
91
108
self .__label_to_cluster_estimation_model : dict [
92
109
int , cluster_estimation .ClusterEstimation
93
110
] = {}
@@ -120,17 +137,20 @@ def run(
120
137
Dictionary where the key is a label and the value is a list of all cluster detections with that label
121
138
"""
122
139
label_to_detections : dict [int , list [detection_in_world .DetectionInWorld ]] = {}
140
+ # Sorting detections by label
123
141
for detection in input_detections :
124
142
if not detection .label in label_to_detections :
125
143
label_to_detections [detection .label ] = []
126
144
label_to_detections [detection .label ].append (detection )
127
145
128
146
labels_to_object_clusters : dict [int , list [object_in_world .ObjectInWorld ]] = {}
129
147
for label , detections in label_to_detections .items ():
148
+ # create cluster estimation for label if it doesn't exist
130
149
if not label in self .__label_to_cluster_estimation_model :
131
150
result , cluster_model = cluster_estimation .ClusterEstimation .create (
132
151
self .__min_activation_threshold ,
133
152
self .__min_new_points_to_run ,
153
+ self .__max_num_components ,
134
154
self .__random_state ,
135
155
self .__local_logger ,
136
156
label ,
@@ -141,6 +161,7 @@ def run(
141
161
)
142
162
return False , None
143
163
self .__label_to_cluster_estimation_model [label ] = cluster_model
164
+ # runs cluster estimation for specific label
144
165
result , clusters = self .__label_to_cluster_estimation_model [label ].run (
145
166
detections ,
146
167
run_override ,
0 commit comments