Skip to content

Commit 47118de

Browse files
committed
feat: Implement DataCollectorWithoutNone
1 parent 6f08b07 commit 47118de

File tree

3 files changed

+92
-7
lines changed

3 files changed

+92
-7
lines changed

mesa/datacollection.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,13 @@ class DataCollector:
5151
one and stores the results.
5252
"""
5353

54-
def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
54+
def __init__(
55+
self,
56+
model_reporters=None,
57+
agent_reporters=None,
58+
tables=None,
59+
exclude_none_values=False,
60+
):
5561
"""Instantiate a DataCollector with lists of model and agent reporters.
5662
Both model_reporters and agent_reporters accept a dictionary mapping a
5763
variable name to either an attribute name, or a method.
@@ -74,6 +80,8 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
7480
model_reporters: Dictionary of reporter names and attributes/funcs
7581
agent_reporters: Dictionary of reporter names and attributes/funcs.
7682
tables: Dictionary of table names to lists of column names.
83+
exclude_non_values: Boolean of whether to drop records which values
84+
are None in the final result.
7785
7886
Notes:
7987
If you want to pickle your model you must not use lambda functions.
@@ -97,6 +105,7 @@ class attributes of a model
97105
self.model_vars = {}
98106
self._agent_records = {}
99107
self.tables = {}
108+
self.exclude_none_values = exclude_none_values
100109

101110
if model_reporters is not None:
102111
for name, reporter in model_reporters.items():
@@ -151,7 +160,23 @@ def _new_table(self, table_name, table_columns):
151160
def _record_agents(self, model):
152161
"""Record agents data in a mapping of functions and agents."""
153162
rep_funcs = self.agent_reporters.values()
163+
if self.exclude_none_values:
164+
# Drop records which values are None.
165+
166+
def get_reports(agent):
167+
_prefix = (agent.model.schedule.steps, agent.unique_id)
168+
reports = (rep(agent) for rep in rep_funcs)
169+
reports_without_none = tuple(r for r in reports if r is not None)
170+
if len(reports_without_none) == 0:
171+
return None
172+
return _prefix + reports_without_none
173+
174+
agent_records = (get_reports(agent) for agent in model.schedule.agents)
175+
agent_records_without_none = (r for r in agent_records if r is not None)
176+
return agent_records_without_none
177+
154178
if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
179+
# This branch is for performance optimization purpose.
155180
prefix = ["model.schedule.steps", "unique_id"]
156181
attributes = [func.attribute_name for func in rep_funcs]
157182
get_reports = attrgetter(*prefix + attributes)

mesa/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ def reset_randomizer(self, seed: int | None = None) -> None:
6666
self._seed = seed
6767

6868
def initialize_data_collector(
69-
self, model_reporters=None, agent_reporters=None, tables=None
69+
self,
70+
model_reporters=None,
71+
agent_reporters=None,
72+
tables=None,
73+
exclude_none_values=False,
7074
) -> None:
7175
if not hasattr(self, "schedule") or self.schedule is None:
7276
raise RuntimeError(
@@ -80,6 +84,7 @@ def initialize_data_collector(
8084
model_reporters=model_reporters,
8185
agent_reporters=agent_reporters,
8286
tables=tables,
87+
exclude_none_values=exclude_none_values,
8388
)
8489
# Collect data for the first time during initialization.
8590
self.datacollector.collect(self)

tests/test_datacollector.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,35 @@ def write_final_values(self):
3232
self.model.datacollector.add_table_row("Final_Values", row)
3333

3434

35+
class DifferentMockAgent(MockAgent):
36+
# We define a different MockAgent to test for attributes that are present
37+
# only in 1 type of agent, but not the other.
38+
def __init__(self, unique_id, model, val=0):
39+
super().__init__(unique_id, model, val=val)
40+
self.val3 = val + 42
41+
42+
3543
class MockModel(Model):
3644
"""
3745
Minimalistic model for testing purposes.
3846
"""
3947

4048
schedule = BaseScheduler(None)
4149

42-
def __init__(self):
50+
def __init__(self, test_exclude_none_values=False):
4351
self.schedule = BaseScheduler(self)
4452
self.model_val = 100
4553

46-
for i in range(10):
47-
a = MockAgent(i, self, val=i)
48-
self.schedule.add(a)
54+
self.n = 10
55+
for i in range(self.n):
56+
self.schedule.add(MockAgent(i, self, val=i))
57+
if test_exclude_none_values:
58+
self.schedule.add(DifferentMockAgent(self.n + i, self, val=i))
59+
if test_exclude_none_values:
60+
# Only DifferentMockAgent has val3.
61+
agent_reporters = {"value": lambda a: a.val, "value3": "val3"}
62+
else:
63+
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
4964
self.initialize_data_collector(
5065
{
5166
"total_agents": lambda m: m.schedule.get_agent_count(),
@@ -54,8 +69,9 @@ def __init__(self):
5469
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
5570
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
5671
},
57-
{"value": lambda a: a.val, "value2": "val2"},
72+
agent_reporters,
5873
{"Final_Values": ["agent_id", "final_value"]},
74+
exclude_none_values=test_exclude_none_values,
5975
)
6076

6177
def test_model_calc_comp(self, input1, input2):
@@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self):
195211
)
196212

197213

214+
class TestDataCollectorExcludeNone(unittest.TestCase):
215+
def setUp(self):
216+
"""
217+
Create the model and run it a set number of steps.
218+
"""
219+
self.model = MockModel(test_exclude_none_values=True)
220+
for i in range(7):
221+
if i == 4:
222+
self.model.schedule.remove(self.model.schedule._agents[3])
223+
self.model.step()
224+
225+
def test_agent_records(self):
226+
"""
227+
Test agent-level variable collection.
228+
"""
229+
data_collector = self.model.datacollector
230+
agent_table = data_collector.get_agent_vars_dataframe()
231+
232+
assert len(data_collector._agent_records) == 8
233+
for step, records in data_collector._agent_records.items():
234+
if step < 5:
235+
assert len(records) == 20
236+
else:
237+
assert len(records) == 19
238+
239+
for values in records:
240+
agent_id = values[1]
241+
if agent_id < self.model.n:
242+
assert len(values) == 3
243+
else:
244+
# Agents with agent_id >= self.model.n are
245+
# DifferentMockAgent, which additionally contains val3.
246+
assert len(values) == 4
247+
248+
assert "value" in list(agent_table.columns)
249+
assert "value2" not in list(agent_table.columns)
250+
assert "value3" in list(agent_table.columns)
251+
252+
198253
if __name__ == "__main__":
199254
unittest.main()

0 commit comments

Comments
 (0)