Skip to content

Commit 3cf1b76

Browse files
authored
Allow AgentSet.do() to take Callable function (#2219)
This PR enhances `AgentSet.do` to take a callable or str. Currently, AgentSet.do takes a `str` which maps to a method on the agents in the set. This PR makes it possible to also use a `Callable` instead. This callable will be called with the `agent` as the first argument. ⚠️ Breaking change ⚠️ A small breaking change is introduced here: the `method_name` parameter is renamed to `method`. For models that use this as a keyword argument this is a breaking change, and need to replace `do(method_name="something")` with `do(method="something")`.
1 parent 90628fa commit 3cf1b76

File tree

2 files changed

+89
-12
lines changed

2 files changed

+89
-12
lines changed

mesa/agent.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -217,26 +217,37 @@ def _update(self, agents: Iterable[Agent]):
217217
return self
218218

219219
def do(
220-
self, method_name: str, *args, return_results: bool = False, **kwargs
220+
self, method: str | Callable, *args, return_results: bool = False, **kwargs
221221
) -> AgentSet | list[Any]:
222222
"""
223-
Invoke a method on each agent in the AgentSet.
223+
Invoke a method or function on each agent in the AgentSet.
224224
225225
Args:
226-
method_name (str): The name of the method to call on each agent.
226+
method (str, callable): the callable to do on each agents
227+
228+
* in case of str, the name of the method to call on each agent.
229+
* in case of callable, the function to be called with each agent as first argument
230+
227231
return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
228-
*args: Variable length argument list passed to the method being called.
229-
**kwargs: Arbitrary keyword arguments passed to the method being called.
232+
*args: Variable length argument list passed to the callable being called.
233+
**kwargs: Arbitrary keyword arguments passed to the callable being called.
230234
231235
Returns:
232-
AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself.
236+
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
233237
"""
234238
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
235-
res = [
236-
getattr(agent, method_name)(*args, **kwargs)
237-
for agentref in self._agents.keyrefs()
238-
if (agent := agentref()) is not None
239-
]
239+
if isinstance(method, str):
240+
res = [
241+
getattr(agent, method)(*args, **kwargs)
242+
for agentref in self._agents.keyrefs()
243+
if (agent := agentref()) is not None
244+
]
245+
else:
246+
res = [
247+
method(agent, *args, **kwargs)
248+
for agentref in self._agents.keyrefs()
249+
if (agent := agentref()) is not None
250+
]
240251

241252
return res if return_results else self
242253

tests/test_agent.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_agentset_get_item():
176176
_ = agentset[20]
177177

178178

179-
def test_agentset_do_method():
179+
def test_agentset_do_str():
180180
model = Model()
181181
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
182182
agentset = AgentSet(agents, model)
@@ -210,6 +210,72 @@ def test_agentset_do_method():
210210
assert len(agentset) == 0
211211

212212

213+
def test_agentset_do_callable():
214+
model = Model()
215+
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
216+
agentset = AgentSet(agents, model)
217+
218+
# Test callable with non-existent function
219+
with pytest.raises(AttributeError):
220+
agentset.do(lambda agent: agent.non_existing_method())
221+
222+
# tests for addition and removal in do using callables
223+
# do iterates, so no error should be raised to change size while iterating
224+
# related to issue #1595
225+
226+
# setup for lambda function tests
227+
n = 10
228+
model = Model()
229+
agents = [TestAgentDo(model.next_id(), model) for _ in range(n)]
230+
agentset = AgentSet(agents, model)
231+
for agent in agents:
232+
agent.agent_set = agentset
233+
234+
# Lambda for addition
235+
agentset.do(lambda agent: agent.do_add())
236+
assert len(agentset) == 2 * n
237+
238+
# setup again for lambda function tests
239+
model = Model()
240+
agents = [TestAgentDo(model.next_id(), model) for _ in range(10)]
241+
agentset = AgentSet(agents, model)
242+
for agent in agents:
243+
agent.agent_set = agentset
244+
245+
# Lambda for removal
246+
agentset.do(lambda agent: agent.do_remove())
247+
assert len(agentset) == 0
248+
249+
# setup for actual function tests
250+
def add_function(agent):
251+
agent.do_add()
252+
253+
def remove_function(agent):
254+
agent.do_remove()
255+
256+
# setup again for actual function tests
257+
model = Model()
258+
agents = [TestAgentDo(model.next_id(), model) for _ in range(n)]
259+
agentset = AgentSet(agents, model)
260+
for agent in agents:
261+
agent.agent_set = agentset
262+
263+
# Actual function for addition
264+
agentset.do(add_function)
265+
assert len(agentset) == 2 * n
266+
267+
# setup again for actual function tests
268+
model = Model()
269+
agents = [TestAgentDo(model.next_id(), model) for _ in range(10)]
270+
agentset = AgentSet(agents, model)
271+
for agent in agents:
272+
agent.agent_set = agentset
273+
274+
# Actual function for removal
275+
agentset.do(remove_function)
276+
assert len(agentset) == 0
277+
278+
213279
def test_agentset_get_attribute():
214280
model = Model()
215281
agents = [TestAgent(model.next_id(), model) for _ in range(10)]

0 commit comments

Comments
 (0)