This repository implements a framework for studying coordination, alignment, and robustness in multi-agent systems through:
- Active Inference & Expected Free Energy (EFE)
- Recursive Theory of Mind (ToM) planning
- Empathy-weighted decision-making
- JAX-accelerated computation (roughly 20-90x speedup, depending on the function)
- Hierarchical zone-based planning for complex layouts
The central research goal is to test whether alignment emerges naturally when agents attempt to preserve each other's future option sets — and whether asymmetric empathy enables coordination in constrained environments.
- Quick Start
- Running Experiments
- Architecture Overview
- Code Structure
- Key Concepts
- Hierarchical Planning
- Understanding the Results
- JAX Acceleration
- Future Roadmap
- Citation
# Create environment
conda create -n alignment python=3.10
conda activate alignment
# Install dependencies
pip install -r requirements.txt
# Install JAX (recommended for 20-100x speedup)
pip install jax # CPU version
# OR for GPU: pip install "jax[cuda12]"

python tests/smoke_test_tom.py

Expected output:
- ✅ TOM imports
- ✅ LavaModel / LavaAgent creation
- ✅ LavaV2Env reset + step
- ✅ Collision detection (cell + edge)
# Quick test on narrow corridor (18 experiments, ~3 minutes)
python scripts/run_empathy_sweep.py --layouts narrow --max-steps 10 --seeds 1

The primary script is scripts/run_empathy_sweep.py. It tests how different empathy configurations affect coordination.
# Run on a single layout
python scripts/run_empathy_sweep.py --layouts narrow
# Run on multiple layouts
python scripts/run_empathy_sweep.py --layouts narrow bottleneck wide
# Run all layouts (takes longer)
python scripts/run_empathy_sweep.py

| Option | Description | Default |
|---|---|---|
| `--layouts` | Layouts to test (space-separated) | All layouts |
| `--mode` | `symmetric`, `asymmetric`, or `both` | `both` |
| `--max-steps` | Max timesteps per episode | 15 |
| `--horizon` | Planning horizon | 4 |
| `--seeds` | Number of random seeds | 1 |
| `--hierarchical` | Use hierarchical planner (faster for bottlenecks) | False |
| `--verbose` | Print every timestep | False |
# 1. Quick test - narrow corridor, asymmetric empathy
python scripts/run_empathy_sweep.py --layouts narrow --mode asymmetric --max-steps 10 --seeds 1
# 2. Full sweep on bottleneck layouts (uses hierarchical planner)
python scripts/run_empathy_sweep.py --layouts vertical_bottleneck symmetric_bottleneck --hierarchical
# 3. Compare symmetric vs asymmetric empathy
python scripts/run_empathy_sweep.py --layouts bottleneck --mode both
# 4. Debug a specific case
python scripts/run_empathy_sweep.py --layouts narrow --mode asymmetric --verbose

| Layout | Description | Difficulty |
|---|---|---|
| `wide` | 6x3 open corridor | Easy |
| `narrow` | 6x3 single-file corridor | Hard |
| `bottleneck` | Wide with central chokepoint | Medium |
| `vertical_bottleneck` | Vertical with central chokepoint | Medium |
| `symmetric_bottleneck` | Equal-sized zones around chokepoint | Medium |
| `crossed_goals` | Goals require path crossing | Hard |
| `double_bottleneck` | Two sequential chokepoints | Hard |
| `passing_bay` | Narrow with one passing spot | Medium |
| `risk_reward` | Safe long path vs risky short path | Medium |
| `t_junction` | T-shaped intersection | Hard |
| `asymmetric_detour` | One agent must detour | Medium |
# Test asymmetric empathy scenarios
python scripts/test_asymmetric_empathy.py
# Single-agent demo
python scripts/run_lava_si.py
# Two-agent empathy demo
python scripts/run_lava_empathy.py
# Diagnose ToM behavior
python scripts/diagnose_tom.py

┌─────────────────────────────────────────────────────────────┐
│ EmpathicLavaPlanner │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. RECURSIVE ToM: Predict other agent's action │
│ ┌──────────────────────────────────────────────────┐ │
│ │ predict_other_action_recursive_jax() │ │
│ │ - depth=2: I think that you think that I... │ │
│ │ - horizon=3: Multi-step lookahead │ │
│ │ - Uses JAX JIT for 20x speedup │ │
│ └──────────────────────────────────────────────────┘ │
│ ↓ │
│ 2. EMPATHIC EFE: Compute G_social for all policies │
│ ┌──────────────────────────────────────────────────┐ │
│ │ compute_empathic_G_jax() │ │
│ │ - G_social = G_self + α * G_other │ │
│ │ - Collision detection (cell + edge) │ │
│ │ - vmap over 125-625 policies │ │
│ └──────────────────────────────────────────────────┘ │
│ ↓ │
│ 3. ACTION SELECTION: Softmax over G_social │
│ ┌──────────────────────────────────────────────────┐ │
│ │ action = argmin(G_social) │ │
│ │ OR sample from q(π) ∝ exp(-γ * G_social) │ │
│ └──────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
The system detects two types of collisions:
- Cell collision: Both agents end up in the same cell
  - Detected via the `A_cell_collision` observation matrix
  - Penalty in `C_cell_collision`
- Edge collision (swap): Agents try to pass through each other
  - Agent i moves A→B while agent j moves B→A
  - Detected via swap probability computation
  - Same penalty as cell collision
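
The two collision types can be illustrated with a deterministic check on grid positions. This is a minimal sketch, not the repository's implementation: the actual planner works with probabilistic beliefs (observation matrices and swap probabilities), and the function names below are hypothetical.

```python
from typing import Tuple

Pos = Tuple[int, int]  # (x, y) grid cell


def cell_collision(next_i: Pos, next_j: Pos) -> bool:
    """Both agents end the step in the same cell."""
    return next_i == next_j


def edge_collision(pos_i: Pos, next_i: Pos, pos_j: Pos, next_j: Pos) -> bool:
    """Agents swap cells: i moves A->B while j moves B->A."""
    return pos_i != pos_j and next_i == pos_j and next_j == pos_i


def any_collision(pos_i: Pos, next_i: Pos, pos_j: Pos, next_j: Pos) -> bool:
    """Either collision type; both carry the same penalty."""
    return cell_collision(next_i, next_j) or edge_collision(
        pos_i, next_i, pos_j, next_j
    )


# Head-on swap in a corridor: no shared cell afterwards, but still a collision
print(any_collision((1, 1), (2, 1), (2, 1), (1, 1)))
```

Checking only the final cells would miss the swap case, which is why the edge check needs the positions before and after the step.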
Alignment-experiments/
│
├── tom/ # Core library
│ ├── models/
│ │ └── model_lava.py # LavaModel: A, B, C, D matrices
│ │
│ ├── envs/
│ │ ├── env_lava_v2.py # Multi-agent environment
│ │ └── env_lava_variants.py # Layout definitions
│ │
│ └── planning/
│ ├── si_empathy_lava.py # EmpathicLavaPlanner (main class)
│ ├── jax_si_empathy_lava.py # JAX-accelerated functions
│ └── jax_hierarchical_planner.py # Hierarchical zone planner
│
├── scripts/ # Runnable experiments
│ ├── run_empathy_sweep.py # Main experiment sweep
│ ├── test_asymmetric_empathy.py # ToM verification tests
│ ├── run_lava_empathy.py # Two-agent demo
│ └── diagnose_tom.py # Debug ToM predictions
│
├── tests/ # Test suite
│ ├── smoke_test_tom.py # Quick sanity check
│ ├── test_jax_planner.py # JAX correctness tests
│ └── run_all_tests.py # Full test suite
│
├── results/ # Experiment outputs
│ ├── empathy_sweep_*.csv # Sweep results
│ └── figs/ # Generated plots
│
└── HIERARCHICAL_PLANNER_ROADMAP.md # Future development plans
| File | Purpose |
|---|---|
| `si_empathy_lava.py` | Main `EmpathicLavaPlanner` class. Orchestrates ToM + empathy |
| `jax_si_empathy_lava.py` | JAX-accelerated ToM functions (`predict_other_action_recursive_jax`) |
| `run_empathy_sweep.py` | Runs experiments across layouts and empathy configurations |
| `test_asymmetric_empathy.py` | Validates ToM produces correct predictions |
Agents recursively model each other's beliefs and actions:
Depth 0: "What will j do?" → Assume j stays in place
Depth 1: "What will j do, given j thinks I'll stay?" → Better prediction
Depth 2: "What will j do, given j thinks I think j stays?" → Even better
The TOM_DEPTH = 2 and TOM_HORIZON = 3 parameters control recursion depth and lookahead.
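
The depth ladder above can be sketched on a toy 1-D corridor. This is an illustrative simplification under stated assumptions, not the code in `jax_si_empathy_lava.py`: it uses a one-step greedy best response instead of multi-step EFE, ignores edge collisions, and `best_response` and `predict_action` are hypothetical names.

```python
ACTIONS = (-1, 0, 1)  # LEFT, STAY, RIGHT on a 1-D corridor


def best_response(pos, goal, other_next):
    """Greedy action: get closer to the goal, avoid the other's next cell."""
    def cost(action):
        nxt = pos + action
        penalty = 100 if nxt == other_next else 0  # toy collision penalty
        return abs(goal - nxt) + penalty
    return min(ACTIONS, key=cost)


def predict_action(depth, agent, other):
    """Predict `agent`'s action using `depth` levels of recursion.

    `agent` and `other` are (position, goal) pairs.
    Depth 0 assumes the agent simply stays in place.
    """
    if depth == 0:
        return 0
    # What does `agent` believe `other` will do (one level shallower)?
    other_action = predict_action(depth - 1, other, agent)
    return best_response(agent[0], agent[1], other[0] + other_action)


i = (0, 3)  # agent i at cell 0, goal at cell 3
j = (2, 0)  # agent j at cell 2, goal at cell 0

for depth in range(3):
    print(depth, predict_action(depth, j, i))
```

In this toy case depth 1 predicts that j steps left (it assumes i stays), while depth 2 predicts that j waits, because j now expects i to step into the cell between them.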
The empathy parameter α ∈ [0, 1] determines how much an agent weighs the other's utility:
G_social(π) = G_self(π) + α * G_other(π)
| α value | Behavior |
|---|---|
| α = 0 | Purely selfish - only cares about own goals |
| α = 0.5 | Partially empathic - other's expected free energy counts half as much as own |
| α = 1.0 | Fully empathic - other's utility as important as own |
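
A small numeric sketch shows how α changes the selected policy. The G values below are made up for illustration (three hypothetical policies: rush, wait, detour), and the helper names are not part of the repository.

```python
import math


def social_G(G_self, G_other, alpha):
    """G_social(pi) = G_self(pi) + alpha * G_other(pi), per policy."""
    return [gs + alpha * go for gs, go in zip(G_self, G_other)]


def policy_posterior(G, gamma=1.0):
    """q(pi) proportional to exp(-gamma * G), computed stably."""
    logits = [-gamma * g for g in G]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    z = sum(weights)
    return [w / z for w in weights]


# Hypothetical expected free energies (lower is better)
policies = ["rush", "wait", "detour"]
G_self = [1.0, 3.0, 2.0]   # rushing is best for me
G_other = [6.0, 1.0, 2.5]  # my rushing is worst for the other agent

for alpha in (0.0, 0.5, 1.0):
    G = social_G(G_self, G_other, alpha)
    best = min(range(len(G)), key=G.__getitem__)  # argmin over policies
    print(alpha, G, policies[best], policy_posterior(G))
```

With these numbers the selfish agent (α = 0) rushes, the fully empathic agent (α = 1) waits, and α = 0.5 lands on the detour.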
The key insight: when agents have different empathy levels, coordination emerges:
| Agent i (α_i) | Agent j (α_j) | Outcome |
|---|---|---|
| 0.0 (selfish) | 0.0 (selfish) | Both rush → Collision |
| 0.0 (selfish) | 1.0 (empathic) | i rushes, j yields → Success |
| 1.0 (empathic) | 0.0 (selfish) | i yields, j rushes → Success |
| 1.0 (empathic) | 1.0 (empathic) | Both yield → Paralysis (deadlock) |
Each action is evaluated by its expected free energy:
G(a) = -pragmatic - epistemic
= -E[utility(observations)] - info_gain(about_world)
Components:
- Pragmatic: Goal-seeking, collision avoidance
- Epistemic: Information gain (exploration)
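
For a discrete model, the two terms can be computed as below. This is a sketch of the standard active-inference decomposition, assuming a likelihood matrix `A[o][s] = p(o | s)` and log-preferences over observations. The repository's `LavaModel` may organize these quantities differently, and `expected_free_energy` is a hypothetical name.

```python
import math


def entropy(p):
    """Shannon entropy in nats, skipping zero-probability entries."""
    return -sum(x * math.log(x) for x in p if x > 0)


def expected_free_energy(q_s, A, log_C):
    """G = -pragmatic - epistemic for one predicted state distribution.

    q_s:   predicted state distribution under the action, q(s)
    A:     likelihood, A[o][s] = p(o | s)
    log_C: log-preferences over observations (higher = more preferred)
    """
    n_obs, n_states = len(A), len(q_s)
    # Predicted observation distribution q(o) = sum_s p(o|s) q(s)
    q_o = [sum(A[o][s] * q_s[s] for s in range(n_states)) for o in range(n_obs)]
    # Pragmatic value: expected log-preference of predicted observations
    pragmatic = sum(q_o[o] * log_C[o] for o in range(n_obs))
    # Epistemic value: mutual information between states and observations
    ambiguity = sum(
        q_s[s] * entropy([A[o][s] for o in range(n_obs)]) for s in range(n_states)
    )
    epistemic = entropy(q_o) - ambiguity
    return -pragmatic - epistemic


q_s = [0.5, 0.5]
log_C = [0.0, -2.0]
A_informative = [[1.0, 0.0], [0.0, 1.0]]    # observations reveal the state
A_uninformative = [[0.5, 0.5], [0.5, 0.5]]  # observations reveal nothing

G_informative = expected_free_energy(q_s, A_informative, log_C)
G_uninformative = expected_free_energy(q_s, A_uninformative, log_C)
print(G_informative, G_uninformative)
```

Both cases have the same pragmatic value, so the informative mapping gets the lower G purely from its information gain.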
For complex layouts with bottlenecks, the hierarchical planner decomposes planning into two levels:
┌─────────────────────────────────────────────────────────────┐
│ HierarchicalEmpathicPlannerJax │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. HIGH-LEVEL: Zone transition planning │
│ ┌──────────────────────────────────────────────────┐ │
│ │ high_level_plan_jax() │ │
│ │ - State: (my_zone, other_zone) │ │
│ │ - Actions: STAY, FORWARD, BACK │ │
│ │ - Empathy at zone level (yielding bottleneck) │ │
│ └──────────────────────────────────────────────────┘ │
│ ↓ │
│ 2. LOW-LEVEL: Within-zone navigation │
│ ┌──────────────────────────────────────────────────┐ │
│ │ low_level_plan_multistep_jax() │ │
│ │ - Subgoal: exit point or final goal │ │
│ │ - Multi-step ToM (depth=2, horizon=3) │ │
│ │ - Smart subgoal switching at boundaries │ │
│ └──────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
| Approach | Policies | Memory |
|---|---|---|
| Flat H=7 | 5^7 = 78,125 | OOM |
| Hierarchical | 3^3 × 5^3 = 3,375 | OK |
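
The counts in the table follow from simple exponentiation, assuming 5 primitive actions, 3 zone-level actions (STAY, FORWARD, BACK), and a horizon of 3 at each level. The function names are illustrative.

```python
def flat_policy_count(n_actions: int, horizon: int) -> int:
    """Number of action sequences for a flat planner."""
    return n_actions ** horizon


def hierarchical_policy_count(
    n_zone_actions: int, zone_horizon: int, n_actions: int, local_horizon: int
) -> int:
    """Zone-level sequences times within-zone sequences, as in the table."""
    return n_zone_actions ** zone_horizon * n_actions ** local_horizon


print(flat_policy_count(5, 7))                # 78125
print(hierarchical_policy_count(3, 3, 5, 3))  # 3375
print(flat_policy_count(5, 3), flat_policy_count(5, 4))  # 125 625
```

The last line reproduces the 125-625 policy range handled by the flat planner at horizons 3 and 4.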
# Enable hierarchical planning
python scripts/run_empathy_sweep.py --layouts risk_reward --hierarchical
# Test asymmetric empathy with hierarchical planner
python scripts/test_asymmetric_empathy.py --layout risk_reward

The hierarchical planner has zone definitions for:
- `vertical_bottleneck` - Vertical corridor with central chokepoint
- `symmetric_bottleneck` - Equal-sized zones around chokepoint
- `narrow` - Single-file corridor (3 zones)
- `risk_reward` - Safe long path vs risky short path (3 zones)
On risk_reward layout with asymmetric empathy (α_i=1.0, α_j=0.0):
Step 4: i@(3,1) -> STAY (empathic yields at bottleneck)
Step 5: i@(3,1) -> STAY (continues yielding)
...
Step 9: j@(0,0) -> DOWN (selfish passes through)
Step 10: i@(3,1) -> UP (empathic resumes after j clears)
...
Step 14: Both reach goals -> SUCCESS!
Results are saved to results/empathy_sweep_YYYYMMDD_HHMMSS.csv:
| Column | Description |
|---|---|
| `layout` | Environment layout name |
| `start_config` | Starting configuration (A, B, C, D) |
| `alpha_i`, `alpha_j` | Empathy parameters |
| `both_success` | Both agents reached goals without collision |
| `cell_collision` | Agents ended up in same cell |
| `edge_collision` | Agents tried to swap positions |
| `paralysis` | Agents got stuck (both yielding) |
| `steps` | Number of timesteps |
| `trajectory_i`, `trajectory_j` | Position sequences |
- Success rate: Both agents reach goals without collision
- Collision rate: Agents crashed into each other
- Paralysis rate: Both agents got stuck yielding to each other
import glob

import pandas as pd

# Load all sweep results (read_csv does not expand wildcards itself)
paths = sorted(glob.glob("results/empathy_sweep_*.csv"))
df = pd.concat((pd.read_csv(p) for p in paths), ignore_index=True)

# Success rate by empathy configuration
success = df.groupby(['alpha_i', 'alpha_j'])['both_success'].mean()
print(success.unstack())

# Which layouts have highest collision rate?
collision_by_layout = df.groupby('layout')['cell_collision'].mean()
print(collision_by_layout.sort_values(ascending=False))

Look for yielding behavior in trajectories:
# Agent yields if they stay in place while other passes
trajectory_j: (4,1) → (4,1) → (4,1) → (3,1) → (2,1) → goal
↑ stayed ↑ stayed ↑ started moving
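
Yielding can also be detected programmatically. The sketch below assumes a trajectory has already been parsed into a list of `(x, y)` tuples. How `trajectory_i` and `trajectory_j` are serialized in the CSV is not specified here, and the helper names are hypothetical.

```python
def stay_steps(trajectory):
    """Indices t where the agent did not move between step t and t+1."""
    return [
        t for t in range(len(trajectory) - 1) if trajectory[t] == trajectory[t + 1]
    ]


def initial_wait(trajectory):
    """Number of consecutive steps the agent stayed put before first moving."""
    waited = 0
    for current, following in zip(trajectory, trajectory[1:]):
        if current != following:
            break
        waited += 1
    return waited


# The example trajectory above, up to the last listed cell
trajectory_j = [(4, 1), (4, 1), (4, 1), (3, 1), (2, 1)]
print(stay_steps(trajectory_j))    # [0, 1]
print(initial_wait(trajectory_j))  # 2
```

A long initial wait by one agent while the other advances is the signature of yielding. If both agents show long waits, the episode is more likely paralysis.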
JAX provides roughly 20-90x speedup for planning computations through JIT compilation, depending on the function:
| Function | NumPy | JAX (cached) | Speedup |
|---|---|---|---|
| `predict_other_action_recursive` | ~0.5s | ~0.025s | 20x |
| `compute_empathic_G` (125 policies) | ~45s | ~0.5s | 90x |
| Hierarchical planner (multi-step ToM) | ~1.0s | ~0.013s | 86x |
| JAX vs NumPy (ToM prediction) | ~0.12s | ~0.004s | 30x |
JAX is enabled by default when available:
from tom.planning.si_empathy_lava import EmpathicLavaPlanner
# Automatically uses JAX (20-100x faster)
planner = EmpathicLavaPlanner(agent_i, agent_j, alpha=0.5)
# Disable JAX for debugging
planner = EmpathicLavaPlanner(agent_i, agent_j, alpha=0.5, use_jax=False)

JAX compiles functions on first call (JIT). Expect:
- First call: ~1s (compilation)
- Subsequent calls: ~0.025s (cached)
See HIERARCHICAL_PLANNER_ROADMAP.md for detailed plans. Key upcoming features:
Measure how robust a trajectory is:
- Empowerment: How many future options remain available
- Returnability: Probability of reaching safe states
- Outcome overlap: Similarity of predicted futures between agents
F(π) = λ_E * Empowerment(π) + λ_R * Returnability(π) + λ_O * Overlap(π)
Bias agents toward flexible (robust) trajectories:
p(π) ∝ exp(κ * [F_i(π) + β * F_j(π)])
Combined objective:
J_i(π) = G_i + α*G_j - (κ/γ) * [F_i + β*F_j]
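
The combined objective can be checked numerically. The sketch assumes the flexibility prior multiplies the usual softmax posterior, q(π) ∝ p(π) · exp(-γ · G_social(π)), which is equivalent to a softmax over -γ · J_i(π). All values are made up and the function names are hypothetical.

```python
import math


def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]


def posterior_from_prior(G_social, F_joint, gamma, kappa):
    """q(pi) proportional to exp(kappa * F) * exp(-gamma * G_social)."""
    return normalize(
        [math.exp(kappa * f) * math.exp(-gamma * g) for g, f in zip(G_social, F_joint)]
    )


def posterior_from_J(G_social, F_joint, gamma, kappa):
    """q(pi) proportional to exp(-gamma * J), with J = G_social - (kappa/gamma) * F."""
    J = [g - (kappa / gamma) * f for g, f in zip(G_social, F_joint)]
    return normalize([math.exp(-gamma * j) for j in J])


alpha, beta, gamma, kappa = 0.5, 0.5, 2.0, 1.0
G_i, G_j = [1.0, 3.0, 2.0], [6.0, 1.0, 2.5]  # hypothetical EFE values
F_i, F_j = [0.2, 1.0, 0.6], [0.1, 0.8, 0.9]  # hypothetical flexibility scores

G_social = [gi + alpha * gj for gi, gj in zip(G_i, G_j)]
F_joint = [fi + beta * fj for fi, fj in zip(F_i, F_j)]

q_prior = posterior_from_prior(G_social, F_joint, gamma, kappa)
q_J = posterior_from_J(G_social, F_joint, gamma, kappa)
print(q_prior)
print(q_J)
```

Both routes give the same distribution, which is where the κ/γ factor in J_i comes from.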
Replace hard-coded collision penalties with learned beliefs:
- Track P(collision | zone_i, zone_j)
- Update beliefs based on observed collisions
- High-level planner uses inferred probabilities
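
One possible way to track these beliefs is a Beta-Bernoulli count per zone pair. This is only a sketch of the idea. The roadmap does not commit to a specific update rule, and the class name, method names, and uniform prior below are assumptions.

```python
from collections import defaultdict


class ZoneCollisionBelief:
    """Tracks P(collision | zone_i, zone_j) with Beta pseudo-counts."""

    def __init__(self, prior_collisions=1.0, prior_safe=1.0):
        # Each zone pair starts from the same prior: [collisions, safe steps]
        self.counts = defaultdict(lambda: [prior_collisions, prior_safe])

    def update(self, zone_i, zone_j, collided):
        """Record one observed step with the agents in (zone_i, zone_j)."""
        self.counts[(zone_i, zone_j)][0 if collided else 1] += 1.0

    def p_collision(self, zone_i, zone_j):
        """Posterior mean of the collision probability for this zone pair."""
        collisions, safe = self.counts[(zone_i, zone_j)]
        return collisions / (collisions + safe)


belief = ZoneCollisionBelief()
for collided in (True, True, True, False):
    belief.update(1, 1, collided)  # both agents in the bottleneck zone

print(belief.p_collision(1, 1))  # raised by the observed collisions
print(belief.p_collision(0, 2))  # still at the prior
```

A high-level planner could then weight the zone-level collision penalty by `p_collision` instead of using a fixed constant.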
If you use this codebase, please cite:
@software{albarracin2025_empathic_tom,
title={Multi-Agent Theory of Mind with Empathy in Active Inference},
author={Mahault Albarracin},
year={2025},
url={https://github.com/mahault/Alignment-experiments}
}

Issues & discussions: https://github.com/mahault/Alignment-experiments/issues