Skip to content

Commit b2f6d7f

Browse files
committed
test_continuous_observable: Add multi-agent test cases
1 parent 835ad79 commit b2f6d7f

File tree

1 file changed

+274
-0
lines changed

1 file changed

+274
-0
lines changed

tests/test_continuous_observables.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,277 @@ def check_death():
604604
model.simulator.schedule_event_absolute(check_continued, 100.0)
605605
model.simulator.schedule_event_absolute(check_death, 300.0)
606606
model.simulator.run_until(300.0)
607+
608+
609+
def test_continuous_observable_multiple_agents_independent_values():
610+
"""Test that multiple agents maintain independent continuous values."""
611+
612+
class MyAgent(Agent, HasObservables):
613+
energy = ContinuousObservable(
614+
initial_value=100.0,
615+
rate_func=lambda value, elapsed, agent: -agent.metabolic_rate,
616+
)
617+
618+
def __init__(self, model, metabolic_rate):
619+
super().__init__(model)
620+
self.metabolic_rate = metabolic_rate
621+
self.energy = 100.0
622+
623+
model = SimpleModel()
624+
625+
# Create agents with different metabolic rates
626+
agent1 = MyAgent(model, metabolic_rate=1.0)
627+
agent2 = MyAgent(model, metabolic_rate=2.0)
628+
agent3 = MyAgent(model, metabolic_rate=0.5)
629+
630+
def check_values():
631+
# Each agent should deplete at their own rate
632+
assert agent1.energy == 90.0 # 100 - (1.0 * 10)
633+
assert agent2.energy == 80.0 # 100 - (2.0 * 10)
634+
assert agent3.energy == 95.0 # 100 - (0.5 * 10)
635+
636+
model.simulator.schedule_event_absolute(check_values, 10.0)
637+
model.simulator.run_until(10.0)
638+
639+
640+
def test_continuous_observable_multiple_agents_independent_thresholds():
641+
"""Test that different agents can have different thresholds."""
642+
643+
class MyAgent(Agent, HasObservables):
644+
energy = ContinuousObservable(
645+
initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0
646+
)
647+
648+
def __init__(self, model, name):
649+
super().__init__(model)
650+
self.name = name
651+
self.energy = 100.0
652+
self.threshold_crossed = False
653+
654+
def on_threshold(self, signal):
655+
if signal.direction == "down":
656+
self.threshold_crossed = True
657+
658+
model = SimpleModel()
659+
660+
# Create agents with different thresholds
661+
agent1 = MyAgent(model, "agent1")
662+
agent1.add_threshold("energy", 75.0, agent1.on_threshold)
663+
664+
agent2 = MyAgent(model, "agent2")
665+
agent2.add_threshold("energy", 25.0, agent2.on_threshold)
666+
667+
agent3 = MyAgent(model, "agent3")
668+
agent3.add_threshold("energy", 50.0, agent3.on_threshold)
669+
670+
def check_at_30():
671+
# At t=30, all agents at energy=70
672+
_ = agent1.energy
673+
_ = agent2.energy
674+
_ = agent3.energy
675+
676+
# Only agent1 should have crossed their threshold (75)
677+
assert agent1.threshold_crossed
678+
assert not agent2.threshold_crossed # Hasn't reached 25 yet
679+
assert not agent3.threshold_crossed # Hasn't reached 50 yet
680+
681+
def check_at_55():
682+
# At t=55, all agents at energy=45
683+
_ = agent1.energy
684+
_ = agent2.energy
685+
_ = agent3.energy
686+
687+
# agent1 and agent3 should have crossed
688+
assert agent1.threshold_crossed
689+
assert not agent2.threshold_crossed # Still hasn't reached 25
690+
assert agent3.threshold_crossed # Crossed 50
691+
692+
def check_at_80():
693+
# At t=80, all agents at energy=20
694+
_ = agent1.energy
695+
_ = agent2.energy
696+
_ = agent3.energy
697+
698+
# All should have crossed now
699+
assert agent1.threshold_crossed
700+
assert agent2.threshold_crossed # Finally crossed 25
701+
assert agent3.threshold_crossed
702+
703+
model.simulator.schedule_event_absolute(check_at_30, 30.0)
704+
model.simulator.schedule_event_absolute(check_at_55, 55.0)
705+
model.simulator.schedule_event_absolute(check_at_80, 80.0)
706+
model.simulator.run_until(80.0)
707+
708+
709+
def test_continuous_observable_multiple_agents_same_threshold_different_callbacks():
710+
"""Test that multiple agents can watch the same threshold value with different callbacks."""
711+
712+
class MyAgent(Agent, HasObservables):
713+
energy = ContinuousObservable(
714+
initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0
715+
)
716+
717+
def __init__(self, model, name):
718+
super().__init__(model)
719+
self.name = name
720+
self.energy = 100.0
721+
self.crossed_count = 0
722+
723+
def on_threshold(self, signal):
724+
if signal.direction == "down":
725+
self.crossed_count += 1
726+
727+
model = SimpleModel()
728+
729+
# Create multiple agents, all watching threshold at 50
730+
agents = [MyAgent(model, f"agent{i}") for i in range(5)]
731+
732+
for agent in agents:
733+
agent.add_threshold("energy", 50.0, agent.on_threshold)
734+
735+
def check_crossings():
736+
# Access all agents' energy
737+
for agent in agents:
738+
_ = agent.energy
739+
740+
# Each should have crossed independently
741+
for agent in agents:
742+
assert agent.crossed_count == 1
743+
744+
model.simulator.schedule_event_absolute(check_crossings, 60.0)
745+
model.simulator.run_until(60.0)
746+
747+
748+
def test_continuous_observable_agents_with_different_initial_values():
749+
"""Test agents starting with different energy values."""
750+
751+
class MyAgent(Agent, HasObservables):
752+
energy = ContinuousObservable(
753+
initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0
754+
)
755+
756+
def __init__(self, model, initial_energy):
757+
super().__init__(model)
758+
self.energy = initial_energy
759+
760+
model = SimpleModel()
761+
762+
# Create agents with different starting energies
763+
agent1 = MyAgent(model, initial_energy=100.0)
764+
agent2 = MyAgent(model, initial_energy=50.0)
765+
agent3 = MyAgent(model, initial_energy=150.0)
766+
767+
def check_values():
768+
# Each should deplete from their starting value
769+
assert agent1.energy == 90.0 # 100 - 10
770+
assert agent2.energy == 40.0 # 50 - 10
771+
assert agent3.energy == 140.0 # 150 - 10
772+
773+
model.simulator.schedule_event_absolute(check_values, 10.0)
774+
model.simulator.run_until(10.0)
775+
776+
777+
def test_continuous_observable_agent_interactions():
778+
"""Test agents affecting each other's continuous observables."""
779+
780+
class Predator(Agent, HasObservables):
781+
energy = ContinuousObservable(
782+
initial_value=50.0, rate_func=lambda value, elapsed, agent: -0.5
783+
)
784+
785+
def __init__(self, model):
786+
super().__init__(model)
787+
self.energy = 50.0
788+
self.kills = 0
789+
790+
def eat(self, prey):
791+
"""Eat prey and gain energy."""
792+
self.energy += 20
793+
self.kills += 1
794+
prey.die()
795+
796+
class Prey(Agent, HasObservables):
797+
energy = ContinuousObservable(
798+
initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0
799+
)
800+
801+
def __init__(self, model):
802+
super().__init__(model)
803+
self.energy = 100.0
804+
self.alive = True
805+
806+
def die(self):
807+
self.alive = False
808+
809+
model = SimpleModel()
810+
811+
predator = Predator(model)
812+
prey1 = Prey(model)
813+
prey2 = Prey(model)
814+
815+
def predator_hunts():
816+
# Predator energy should have depleted
817+
assert predator.energy == 45.0 # 50 - (0.5 * 10)
818+
819+
# Predator eats prey1
820+
predator.eat(prey1)
821+
822+
# Predator gains energy
823+
assert predator.energy == 65.0 # 45 + 20
824+
assert not prey1.alive
825+
assert prey2.alive
826+
827+
def check_final():
828+
# Predator continues depleting from boosted energy
829+
assert predator.energy == 60.0 # 65 - (0.5 * 10)
830+
831+
# prey2 continues depleting
832+
assert prey2.energy == 80.0 # 100 - (1.0 * 20)
833+
assert prey2.alive
834+
835+
model.simulator.schedule_event_absolute(predator_hunts, 10.0)
836+
model.simulator.schedule_event_absolute(check_final, 20.0)
837+
model.simulator.run_until(20.0)
838+
839+
840+
def test_continuous_observable_batch_creation_with_thresholds():
841+
"""Test batch agent creation where each agent has instance-specific thresholds."""
842+
843+
class MyAgent(Agent, HasObservables):
844+
energy = ContinuousObservable(
845+
initial_value=100.0, rate_func=lambda value, elapsed, agent: -1.0
846+
)
847+
848+
def __init__(self, model, critical_threshold):
849+
super().__init__(model)
850+
self.energy = 100.0
851+
self.critical_threshold = critical_threshold
852+
self.critical = False
853+
854+
# Each agent watches their own critical threshold
855+
self.add_threshold("energy", critical_threshold, self.on_critical)
856+
857+
def on_critical(self, signal):
858+
if signal.direction == "down":
859+
self.critical = True
860+
861+
model = SimpleModel()
862+
863+
# Create 10 agents with different critical thresholds
864+
thresholds = [90.0, 80.0, 70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0, 5.0]
865+
agents = [MyAgent(model, threshold) for threshold in thresholds]
866+
867+
def check_at_45():
868+
# At t=45, all agents at energy=55
869+
for agent in agents:
870+
_ = agent.energy # Trigger recalculation
871+
872+
# Agents with thresholds > 55 should be critical
873+
for agent, threshold in zip(agents, thresholds):
874+
if threshold > 55:
875+
assert agent.critical
876+
else:
877+
assert not agent.critical
878+
879+
model.simulator.schedule_event_absolute(check_at_45, 45.0)
880+
model.simulator.run_until(45.0)

0 commit comments

Comments
 (0)