Skip to content

Commit 05d1767

Browse files
committed
Fix DataCollector tests for new agent reporter types
- Move agent_reporters specification into initialize_data_collector() - Add back the tables argument (accidentally deleted in previous commit) - Use parentheses to parse step and agent_id from agent records dataframe, since those are the multi-index key - Update expected values for new agent reporter types - Update length values of new agent table and vars (both increase by 2 due to 2 new agent reporter columns)
1 parent 92ef8bc commit 05d1767

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tests/test_datacollector.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def write_final_values(self):
3434
row = {"agent_id": self.unique_id, "final_value": self.val}
3535
self.model.datacollector.add_table_row("Final_Values", row)
3636

37+
3738
def agent_function_with_params(agent, multiplier, offset):
3839
return (agent.val * multiplier) + offset
3940

41+
4042
class DifferentMockAgent(MockAgent):
4143
# We define a different MockAgent to test for attributes that are present
4244
# only in 1 type of agent, but not the other.
@@ -59,7 +61,6 @@ def __init__(self):
5961
self.n = 10
6062
for i in range(self.n):
6163
self.schedule.add(MockAgent(i, self, val=i))
62-
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
6364
self.initialize_data_collector(
6465
model_reporters={
6566
"total_agents": lambda m: m.schedule.get_agent_count(),
@@ -72,9 +73,11 @@ def __init__(self):
7273
"value": lambda a: a.val,
7374
"value2": "val2",
7475
"double_value": MockAgent.double_val,
75-
"value_with_params": [agent_function_with_params, [2, 3]]
76-
}
76+
"value_with_params": [agent_function_with_params, [2, 3]],
77+
},
78+
tables={"Final_Values": ["agent_id", "final_value"]},
7779
)
80+
7881
def test_model_calc_comp(self, input1, input2):
7982
if input2 > 0:
8083
return (self.model_val * input1) / input2
@@ -144,13 +147,13 @@ def test_agent_records(self):
144147
assert "value_with_params" in list(agent_table.columns)
145148

146149
# Check the double_value column
147-
for step, agent_id, value in agent_table["double_value"].items():
148-
expected_value = agent_id * 2
150+
for (step, agent_id), value in agent_table["double_value"].items():
151+
expected_value = (step + agent_id) * 2
149152
self.assertEqual(value, expected_value)
150153

151154
# Check the value_with_params column
152-
for step, agent_id, value in agent_table["value_with_params"].items():
153-
expected_value = (agent_id * 2) + 3
155+
for (step, agent_id), value in agent_table["value_with_params"].items():
156+
expected_value = ((step + agent_id) * 2) + 3
154157
self.assertEqual(value, expected_value)
155158

156159
assert len(data_collector._agent_records) == 8
@@ -161,7 +164,7 @@ def test_agent_records(self):
161164
assert len(records) == 9
162165

163166
for values in records:
164-
assert len(values) == 4
167+
assert len(values) == 6
165168

166169
assert "value" in list(agent_table.columns)
167170
assert "value2" in list(agent_table.columns)
@@ -196,7 +199,7 @@ def test_exports(self):
196199
agent_vars = data_collector.get_agent_vars_dataframe()
197200
table_df = data_collector.get_table_dataframe("Final_Values")
198201
assert model_vars.shape == (8, 5)
199-
assert agent_vars.shape == (77, 2)
202+
assert agent_vars.shape == (77, 4)
200203
assert table_df.shape == (9, 2)
201204

202205
with self.assertRaises(Exception):

0 commit comments

Comments
 (0)