diff --git a/notebooks/01_observations.ipynb b/notebooks/01_observations.ipynb index a3e2fbe309..a72fcaa6f3 100644 --- a/notebooks/01_observations.ipynb +++ b/notebooks/01_observations.ipynb @@ -16,65 +16,10 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive import binding\n", - "import pufferlib.viz\n", - "from notebooks.notebook_utils import (\n", - " COEF_NAMES,\n", - " EGO_LABELS,\n", - " make_drive_env,\n", - " notebook_dims,\n", - " random_actions,\n", - " zero_actions,\n", - ")\n", - "\n", - "# --- Environment configuration ---\n", - "NUM_AGENTS = 64\n", - "SIMULATION_MODE = \"gigaflow\"\n", - "DYNAMICS_MODEL = \"jerk\"\n", - "ACTION_TYPE = \"discrete\"\n", - "DT = 0.1\n", - "SCENARIO_LENGTH = 512\n", - "RESAMPLE_FREQUENCY = 0\n", - "REWARD_CONDITIONING = True\n", - "REWARD_RANDOMIZATION = False\n", - "TARGET_TYPE = \"static\"\n", - "COLLISION_BEHAVIOR = 1\n", - "OFFROAD_BEHAVIOR = 1\n", - "SEED = 42\n", - "\n", - "# --- Observation dimensions ---\n", - "MAX_PARTNERS = 16\n", - "MAX_LANES = 32\n", - "MAX_BOUNDS = 32\n", - "MAX_TRAFFIC = 4\n", - "\n", - "env, obs, info = make_drive_env(\n", - " num_agents=NUM_AGENTS,\n", - " min_agents_per_env=NUM_AGENTS,\n", - " max_agents_per_env=NUM_AGENTS,\n", - " simulation_mode=SIMULATION_MODE,\n", - " dynamics_model=DYNAMICS_MODEL,\n", - " action_type=ACTION_TYPE,\n", - " dt=DT,\n", - " scenario_length=SCENARIO_LENGTH,\n", - " resample_frequency=RESAMPLE_FREQUENCY,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", - " reward_randomization=REWARD_RANDOMIZATION,\n", - " target_type=TARGET_TYPE,\n", - " collision_behavior=COLLISION_BEHAVIOR,\n", - " offroad_behavior=OFFROAD_BEHAVIOR,\n", - " obs_slots_lane_n=MAX_LANES,\n", - " obs_slots_boundary_n=MAX_BOUNDS,\n", - " obs_slots_partners_n=MAX_PARTNERS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", - " seed=SEED,\n", - ")\n", - "globals().update(notebook_dims(env))\n", + "from pufferlib.viz import plot_observation, plot_simulator_state, unpack_obs\n", + "from notebooks.notebook_utils import COEF_NAMES, make_drive_env, zero_actions\n", "\n", - "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", - "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", - "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", - "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" + "env, obs, info = make_drive_env()" ] }, { @@ -127,15 +72,15 @@ "metadata": {}, "outputs": [], "source": [ - "ego, target, partners, lanes, boundaries, traffic = pufferlib.viz.unpack_obs(\n", + "ego, target, partners, lanes, boundaries, traffic = unpack_obs(\n", " obs[:1],\n", - " target_type=TARGET_TYPE,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", + " target_type=env.target_type,\n", + " reward_conditioning=env.reward_conditioning,\n", " num_target_waypoints=env.num_target_waypoints,\n", - " max_partners=MAX_PARTNERS,\n", - " max_lane_segments=MAX_LANES,\n", - " max_boundary_segments=MAX_BOUNDS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", + " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", ")\n", "print(f\"ego: {ego.shape} = {ego}\")\n", "print(f\"target: {target.shape}\")\n", @@ -177,37 +122,47 @@ "idx = 0\n", "\n", "# Ego\n", - "ego_manual = o[idx : idx + EGO_DIM]\n", - "idx += EGO_DIM\n", + "ego_manual = o[idx : idx + env.ego_features]\n", + "idx += env.ego_features\n", "assert np.allclose(ego_manual, ego), f\"ego mismatch: {ego_manual} vs {ego}\"\n", "\n", "# Reward conditioning coefs\n", - "coefs_manual = o[idx : idx + NUM_COEFS]\n", - "idx += NUM_COEFS\n", + "coefs_manual = o[idx : idx + env.num_reward_coefs]\n", + "idx += env.num_reward_coefs\n", "\n", "# Target\n", - "target_manual = o[idx : idx + MAX_TARGET * TARGET_F].reshape(MAX_TARGET, TARGET_F)\n", - "idx += MAX_TARGET * TARGET_F\n", + "target_manual = o[idx : idx + env.num_target_waypoints * env.target_features].reshape(\n", + " env.num_target_waypoints, env.target_features\n", + ")\n", + "idx += env.num_target_waypoints * env.target_features\n", "assert np.allclose(target_manual, target), \"target mismatch\"\n", "\n", "# Partners\n", - "partners_manual = o[idx : idx + MAX_PARTNERS * PARTNER_F].reshape(MAX_PARTNERS, PARTNER_F)\n", - "idx += MAX_PARTNERS * PARTNER_F\n", + "partners_manual = o[idx : idx + env.obs_slots_partners_n * env.partner_features].reshape(\n", + " env.obs_slots_partners_n, env.partner_features\n", + ")\n", + "idx += env.obs_slots_partners_n * env.partner_features\n", "assert np.allclose(partners_manual, partners), \"partners mismatch\"\n", "\n", "# Lanes\n", - "lanes_manual = o[idx : idx + MAX_LANES * ROAD_F].reshape(MAX_LANES, ROAD_F)\n", - "idx += MAX_LANES * ROAD_F\n", + "lanes_manual = o[idx : idx + env.obs_slots_lane_kept * env.road_features].reshape(\n", + " env.obs_slots_lane_kept, env.road_features\n", + ")\n", + "idx += env.obs_slots_lane_kept * env.road_features\n", "assert np.allclose(lanes_manual, lanes), \"lanes mismatch\"\n", "\n", "# Boundaries\n", - "bounds_manual = o[idx : idx + MAX_BOUNDS * ROAD_F].reshape(MAX_BOUNDS, ROAD_F)\n", - "idx += MAX_BOUNDS * ROAD_F\n", + "bounds_manual = o[idx : idx + env.obs_slots_boundary_kept * env.road_features].reshape(\n", + " env.obs_slots_boundary_kept, env.road_features\n", + ")\n", + "idx += env.obs_slots_boundary_kept * env.road_features\n", "assert np.allclose(bounds_manual, boundaries), \"boundaries mismatch\"\n", "\n", "# Traffic\n", - "traffic_manual = o[idx : idx + MAX_TRAFFIC * TRAFFIC_CONTROL_F].reshape(MAX_TRAFFIC, TRAFFIC_CONTROL_F)\n", - "idx += MAX_TRAFFIC * TRAFFIC_CONTROL_F\n", + "traffic_manual = o[idx : idx + env.obs_slots_traffic_controls_n * env.traffic_control_features].reshape(\n", + " env.obs_slots_traffic_controls_n, env.traffic_control_features\n", + ")\n", + "idx += env.obs_slots_traffic_controls_n * env.traffic_control_features\n", "assert np.allclose(traffic_manual, traffic), \"traffic mismatch\"\n", "\n", "assert idx == obs.shape[1], f\"obs size mismatch: used {idx}, total {obs.shape[1]}\"\n", @@ -227,9 +182,9 @@ "metadata": {}, "outputs": [], "source": [ - "coefs = obs[0, EGO_DIM : EGO_DIM + NUM_COEFS]\n", + "coefs = obs[0, env.ego_features : env.ego_features + env.num_reward_coefs]\n", "fig, ax = plt.subplots(figsize=(12, 4))\n", - "bars = ax.bar(range(NUM_COEFS), coefs, tick_label=COEF_NAMES)\n", + "bars = ax.bar(range(env.num_reward_coefs), coefs, tick_label=COEF_NAMES)\n", "ax.set_ylabel(\"Normalized coef value\")\n", "ax.set_title(\"Reward conditioning coefficients (agent 0)\")\n", "plt.xticks(rotation=45, ha=\"right\")\n", @@ -239,7 +194,7 @@ "plt.show()\n", "\n", "# Compare across agents\n", - "all_coefs = obs[:, EGO_DIM : EGO_DIM + NUM_COEFS]\n", + "all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs]\n", "print(\"Coef stats across agents:\")\n", "for i, name in enumerate(COEF_NAMES):\n", " c = all_coefs[:, i]\n", @@ -262,13 +217,13 @@ "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"speed\"]\n", "active_mask = ~np.all(partners == 0, axis=1)\n", "n_active = active_mask.sum()\n", - "print(f\"Active partners: {n_active}/{MAX_PARTNERS}\")\n", + "print(f\"Active partners: {n_active}/{env.obs_slots_partners_n}\")\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "# Heatmap\n", "im = axes[0].imshow(partners, aspect=\"auto\", cmap=\"RdBu_r\", vmin=-1, vmax=1)\n", - "axes[0].set_xticks(range(PARTNER_F))\n", + "axes[0].set_xticks(range(env.partner_features))\n", "axes[0].set_xticklabels(partner_labels, rotation=45, ha=\"right\")\n", "axes[0].set_ylabel(\"Partner index\")\n", "axes[0].set_title(f\"Partner obs heatmap ({n_active} active)\")\n", @@ -308,7 +263,9 @@ "\n", "lane_active = ~np.all(lanes == 0, axis=1)\n", "bound_active = ~np.all(boundaries == 0, axis=1)\n", - "print(f\"Active lanes: {lane_active.sum()}/{MAX_LANES}, boundaries: {bound_active.sum()}/{MAX_BOUNDS}\")\n", + "print(\n", + " f\"Active lanes: {lane_active.sum()}/{env.obs_slots_lane_kept}, boundaries: {bound_active.sum()}/{env.obs_slots_boundary_kept}\"\n", + ")\n", "\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", "\n", @@ -367,15 +324,15 @@ "metadata": {}, "outputs": [], "source": [ - "img = pufferlib.viz.plot_observation(\n", + "img = plot_observation(\n", " obs[:1],\n", - " target_type=TARGET_TYPE,\n", - " reward_conditioning=True,\n", + " target_type=env.target_type,\n", + " reward_conditioning=env.reward_conditioning,\n", " num_target_waypoints=env.num_target_waypoints,\n", - " max_partners=MAX_PARTNERS,\n", - " max_lane_segments=MAX_LANES,\n", - " max_boundary_segments=MAX_BOUNDS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", + " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", ")\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", "ax.imshow(img)\n", @@ -405,7 +362,7 @@ "else:\n", " scenario = scenarios\n", "\n", - "img = pufferlib.viz.plot_simulator_state(scenario, timestep=0)\n", + "img = plot_simulator_state(scenario)\n", "fig, ax = plt.subplots(figsize=(12, 12))\n", "ax.imshow(img)\n", "ax.axis(\"off\")\n", @@ -428,23 +385,12 @@ "outputs": [], "source": [ "N_STEPS = 20\n", - "ego_labels = [\n", - " \"speed\",\n", - " \"width\",\n", - " \"length\",\n", - " \"steering\",\n", - " \"a_long\",\n", - " \"a_lat\",\n", - " \"lane_center_dist\",\n", - " \"lane_heading_cos\",\n", - " \"speed_limit\",\n", - "]\n", - "ego_history = np.zeros((N_STEPS, EGO_DIM))\n", + "ego_history = np.zeros((N_STEPS, env.ego_features))\n", "\n", "for t in range(N_STEPS):\n", " actions = zero_actions(env)\n", " obs_t, _, _, _, _ = env.step(actions)\n", - " ego_history[t] = obs_t[0, :EGO_DIM]\n", + " ego_history[t] = obs_t[0, : env.ego_features]\n", "\n", "fig, axes = plt.subplots(2, 2, figsize=(14, 8))\n", "# Speed\n", @@ -486,16 +432,16 @@ "speeds = obs[:, 0] # speed is at index 0\n", "\n", "# Target waypoints start after ego + reward coefs\n", - "target_start = EGO_DIM + NUM_COEFS\n", + "target_start = env.ego_features + env.num_reward_coefs\n", "# Each target waypoint has TARGET_F features; first two are rel_x, rel_y\n", "first_target_x = obs[:, target_start]\n", "first_target_y = obs[:, target_start + 1]\n", "target_dists = np.sqrt(first_target_x**2 + first_target_y**2)\n", "\n", "# Count active partners per agent\n", - "partner_start = EGO_DIM + NUM_COEFS + TARGET_DIM\n", - "partner_end = partner_start + MAX_PARTNERS * PARTNER_F\n", - "all_partners = obs[:, partner_start:partner_end].reshape(-1, MAX_PARTNERS, PARTNER_F)\n", + "partner_start = env.ego_features + env.num_reward_coefs + env.num_target_waypoints * env.target_features\n", + "partner_end = partner_start + env.obs_slots_partners_n * env.partner_features\n", + "all_partners = obs[:, partner_start:partner_end].reshape(-1, env.obs_slots_partners_n, env.partner_features)\n", "partner_counts = (~np.all(all_partners == 0, axis=2)).sum(axis=1)\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n", @@ -507,7 +453,7 @@ "axes[1].set_title(\"Distance to first target waypoint\")\n", "axes[1].set_xlabel(\"distance\")\n", "\n", - "axes[2].hist(partner_counts, bins=range(MAX_PARTNERS + 2), edgecolor=\"black\", alpha=0.7, color=\"green\")\n", + "axes[2].hist(partner_counts, bins=range(env.obs_slots_partners_n + 2), edgecolor=\"black\", alpha=0.7, color=\"green\")\n", "axes[2].set_title(\"Active partners per agent\")\n", "axes[2].set_xlabel(\"count\")\n", "plt.tight_layout()\n", diff --git a/notebooks/02_rewards.ipynb b/notebooks/02_rewards.ipynb index 41a8bc6b06..63372f8f9d 100644 --- a/notebooks/02_rewards.ipynb +++ b/notebooks/02_rewards.ipynb @@ -16,65 +16,22 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive import binding\n", - "import pufferlib.viz\n", - "from notebooks.notebook_utils import (\n", - " COEF_NAMES,\n", - " EGO_LABELS,\n", - " make_drive_env,\n", - " notebook_dims,\n", - " random_actions,\n", - " zero_actions,\n", - ")\n", - "\n", - "# --- Environment configuration ---\n", - "NUM_AGENTS = 64\n", - "SIMULATION_MODE = \"gigaflow\"\n", - "DYNAMICS_MODEL = \"jerk\"\n", - "ACTION_TYPE = \"discrete\"\n", - "DT = 0.1\n", - "SCENARIO_LENGTH = 512\n", - "RESAMPLE_FREQUENCY = 0\n", - "REWARD_CONDITIONING = True\n", - "REWARD_RANDOMIZATION = False\n", - "TARGET_TYPE = \"static\"\n", - "COLLISION_BEHAVIOR = 1\n", - "OFFROAD_BEHAVIOR = 1\n", - "SEED = 42\n", + "from notebooks.notebook_utils import COEF_NAMES, make_drive_env, random_actions, zero_actions\n", "\n", - "# --- Observation dimensions ---\n", - "MAX_PARTNERS = 16\n", - "MAX_LANES = 32\n", - "MAX_BOUNDS = 32\n", - "MAX_TRAFFIC = 10\n", + "env, obs, info = make_drive_env()\n", "\n", - "env, obs, info = make_drive_env(\n", - " num_agents=NUM_AGENTS,\n", - " min_agents_per_env=NUM_AGENTS,\n", - " max_agents_per_env=NUM_AGENTS,\n", - " simulation_mode=SIMULATION_MODE,\n", - " dynamics_model=DYNAMICS_MODEL,\n", - " action_type=ACTION_TYPE,\n", - " dt=DT,\n", - " scenario_length=SCENARIO_LENGTH,\n", - " resample_frequency=RESAMPLE_FREQUENCY,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", - " reward_randomization=REWARD_RANDOMIZATION,\n", - " target_type=TARGET_TYPE,\n", - " collision_behavior=COLLISION_BEHAVIOR,\n", - " offroad_behavior=OFFROAD_BEHAVIOR,\n", - " obs_slots_lane_n=MAX_LANES,\n", - " obs_slots_boundary_n=MAX_BOUNDS,\n", - " obs_slots_partners_n=MAX_PARTNERS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", - " seed=SEED,\n", + "print(\n", + " f\"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}\"\n", ")\n", - "globals().update(notebook_dims(env))\n", - "\n", - "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", - "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", - "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", - "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" + "print(\n", + " f\"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}\"\n", + ")\n", + "print(\n", + " f\"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}\"\n", + ")\n", + "print(\n", + " f\"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}\"\n", + ")" ] }, { @@ -122,8 +79,8 @@ "outputs": [], "source": [ "N_STEPS = 100\n", - "rewards_history = np.zeros((N_STEPS, N))\n", - "terms_history = np.zeros((N_STEPS, N))\n", + "rewards_history = np.zeros((N_STEPS, env.num_agents))\n", + "terms_history = np.zeros((N_STEPS, env.num_agents))\n", "\n", "for t in range(N_STEPS):\n", " actions = random_actions(env)\n", @@ -145,7 +102,7 @@ "plt.colorbar(im, ax=axes[1])\n", "\n", "cum_returns = rewards_history.cumsum(axis=0)\n", - "for i in range(min(8, N)):\n", + "for i in range(min(8, env.num_agents)):\n", " axes[2].plot(cum_returns[:, i], alpha=0.6, label=f\"agent {i}\")\n", "axes[2].set_xlabel(\"Step\")\n", "axes[2].set_ylabel(\"Cumulative return\")\n", @@ -171,7 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "all_coefs = obs[:, EGO_DIM : EGO_DIM + NUM_COEFS]\n", + "all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs]\n", "print(f\"Reward coefs shape: {all_coefs.shape}\")\n", "print()\n", "print(f\"{'Coef':>15s} | {'mean':>8s} {'std':>8s} {'min':>8s} {'max':>8s}\")\n", @@ -201,7 +158,7 @@ "for t in range(N_STEPS):\n", " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", - " for i in range(N):\n", + " for i in range(env.num_agents):\n", " if term[i]:\n", " term_steps.append(t)\n", " term_rewards.append(rew[i])\n", @@ -252,9 +209,9 @@ " prev_obs = obs.copy()\n", " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", - " for i in range(N):\n", + " for i in range(env.num_agents):\n", " if rew[i] >= 0.5:\n", - " target_start = EGO_DIM + NUM_COEFS\n", + " target_start = env.ego_features + env.num_reward_coefs\n", " goal_dist = np.sqrt(prev_obs[i, target_start] ** 2 + prev_obs[i, target_start + 1] ** 2)\n", " goal_events.append((t, i, rew[i], goal_dist))\n", "\n", @@ -329,10 +286,10 @@ "STEPS_PER_ACTION = 20\n", "action_rewards = {}\n", "\n", - "for a in range(N_ACTIONS):\n", + "for a in range(env.single_action_space.nvec[0]):\n", " rews = []\n", " for _ in range(STEPS_PER_ACTION):\n", - " actions = np.full(ACT_SHAPE, a, dtype=np.int64)\n", + " actions = np.full((env.num_agents, len(env.single_action_space.nvec)), a, dtype=np.int64)\n", " obs, rew, term, trunc, info = env.step(actions)\n", " rews.append(rew.mean())\n", " action_rewards[a] = np.mean(rews)\n", diff --git a/notebooks/03_metrics.ipynb b/notebooks/03_metrics.ipynb index 30b893c781..d2e5084a66 100644 --- a/notebooks/03_metrics.ipynb +++ b/notebooks/03_metrics.ipynb @@ -17,64 +17,23 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from pufferlib.ocean.drive import binding\n", - "import pufferlib.viz\n", - "from notebooks.notebook_utils import (\n", - " COEF_NAMES,\n", - " EGO_LABELS,\n", - " make_drive_env,\n", - " notebook_dims,\n", - " random_actions,\n", - " zero_actions,\n", - ")\n", + "from notebooks.notebook_utils import make_drive_env, random_actions\n", "\n", - "# --- Environment configuration ---\n", - "NUM_AGENTS = 64\n", - "SIMULATION_MODE = \"gigaflow\"\n", - "DYNAMICS_MODEL = \"jerk\"\n", - "ACTION_TYPE = \"discrete\"\n", - "DT = 0.1\n", - "SCENARIO_LENGTH = 512\n", - "RESAMPLE_FREQUENCY = 0\n", - "REWARD_CONDITIONING = True\n", - "REWARD_RANDOMIZATION = False\n", - "TARGET_TYPE = \"static\"\n", - "COLLISION_BEHAVIOR = 1\n", - "OFFROAD_BEHAVIOR = 1\n", - "SEED = 42\n", "\n", - "# --- Observation dimensions ---\n", - "MAX_PARTNERS = 16\n", - "MAX_LANES = 32\n", - "MAX_BOUNDS = 32\n", - "MAX_TRAFFIC = 10\n", + "env, obs, info = make_drive_env()\n", "\n", - "env, obs, info = make_drive_env(\n", - " num_agents=NUM_AGENTS,\n", - " min_agents_per_env=NUM_AGENTS,\n", - " max_agents_per_env=NUM_AGENTS,\n", - " simulation_mode=SIMULATION_MODE,\n", - " dynamics_model=DYNAMICS_MODEL,\n", - " action_type=ACTION_TYPE,\n", - " dt=DT,\n", - " scenario_length=SCENARIO_LENGTH,\n", - " resample_frequency=RESAMPLE_FREQUENCY,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", - " reward_randomization=REWARD_RANDOMIZATION,\n", - " target_type=TARGET_TYPE,\n", - " collision_behavior=COLLISION_BEHAVIOR,\n", - " offroad_behavior=OFFROAD_BEHAVIOR,\n", - " obs_slots_lane_n=MAX_LANES,\n", - " obs_slots_boundary_n=MAX_BOUNDS,\n", - " obs_slots_partners_n=MAX_PARTNERS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", - " seed=SEED,\n", + "print(\n", + " f\"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}\"\n", ")\n", - "globals().update(notebook_dims(env))\n", - "\n", - "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", - "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", - "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", - "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" + "print(\n", + " f\"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}\"\n", + ")\n", + "print(\n", + " f\"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}\"\n", + ")\n", + "print(\n", + " f\"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}\"\n", + ")" ] }, { @@ -94,7 +53,7 @@ " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", "\n", - "log = binding.vec_log(env.c_envs, N)\n", + "log = binding.vec_log(env.c_envs, env.num_agents)\n", "print(f\"vec_log type: {type(log)}\")\n", "if log:\n", " print(f\"Keys: {sorted(log.keys())}\")\n", @@ -119,9 +78,9 @@ "source": [ "N_STEPS = 512\n", "all_logs = []\n", - "all_rewards = np.zeros((N_STEPS, N))\n", - "all_terms = np.zeros((N_STEPS, N))\n", - "all_truncs = np.zeros((N_STEPS, N))\n", + "all_rewards = np.zeros((N_STEPS, env.num_agents))\n", + "all_terms = np.zeros((N_STEPS, env.num_agents))\n", + "all_truncs = np.zeros((N_STEPS, env.num_agents))\n", "\n", "for t in range(N_STEPS):\n", " actions = random_actions(env)\n", @@ -210,7 +169,7 @@ "outputs": [], "source": [ "TRACK_STEPS = 100\n", - "TRACK_AGENTS = min(5, N)\n", + "TRACK_AGENTS = min(5, env.num_agents)\n", "xy_history = np.zeros((TRACK_STEPS, TRACK_AGENTS, 2))\n", "\n", "for t in range(TRACK_STEPS):\n", @@ -288,13 +247,13 @@ "outputs": [], "source": [ "episode_lengths = []\n", - "agent_step_count = np.zeros(N)\n", + "agent_step_count = np.zeros(env.num_agents)\n", "active_counts = []\n", "\n", "for t in range(N_STEPS):\n", - " active = (~np.all(all_rewards[: t + 1] == 0, axis=0) if t > 0 else np.ones(N, dtype=bool)).sum()\n", + " active = (~np.all(all_rewards[: t + 1] == 0, axis=0) if t > 0 else np.ones(env.num_agents, dtype=bool)).sum()\n", " active_counts.append(active)\n", - " for i in range(N):\n", + " for i in range(env.num_agents):\n", " agent_step_count[i] += 1\n", " if all_terms[t, i] or all_truncs[t, i]:\n", " episode_lengths.append(agent_step_count[i])\n", @@ -333,7 +292,7 @@ "if all_logs and \"score\" in all_logs[0]:\n", " scores = [log[\"score\"] for log in all_logs if \"score\" in log]\n", " log_steps = [log[\"_step\"] for log in all_logs if \"score\" in log]\n", - " cum_rew_at_log = [all_rewards[: t + 1].sum() / N for t in log_steps]\n", + " cum_rew_at_log = [all_rewards[: t + 1].sum() / env.num_agents for t in log_steps]\n", "\n", " fig, ax = plt.subplots(figsize=(8, 6))\n", " ax.scatter(cum_rew_at_log, scores, alpha=0.5)\n", diff --git a/notebooks/04_training.ipynb b/notebooks/04_training.ipynb index da78c7c702..7a79a8e0ae 100644 --- a/notebooks/04_training.ipynb +++ b/notebooks/04_training.ipynb @@ -18,55 +18,14 @@ "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn.functional as F\n", - "from pufferlib.ocean.drive import binding\n", - "from notebooks.notebook_utils import make_drive_env, make_drive_policy, notebook_dims, random_actions, zero_actions\n", - "\n", - "# --- Environment configuration ---\n", - "NUM_AGENTS = 64\n", - "SIMULATION_MODE = \"gigaflow\"\n", - "DYNAMICS_MODEL = \"jerk\"\n", - "ACTION_TYPE = \"discrete\"\n", - "DT = 0.1\n", - "SCENARIO_LENGTH = 512\n", - "RESAMPLE_FREQUENCY = 0\n", - "REWARD_CONDITIONING = True\n", - "REWARD_RANDOMIZATION = False\n", - "TARGET_TYPE = \"static\"\n", - "COLLISION_BEHAVIOR = 1\n", - "OFFROAD_BEHAVIOR = 1\n", - "SEED = 42\n", - "MAX_PARTNERS = 16\n", - "MAX_LANES = 32\n", - "MAX_BOUNDS = 32\n", - "MAX_TRAFFIC = 10\n", - "\n", - "env, obs, info = make_drive_env(\n", - " num_agents=NUM_AGENTS,\n", - " min_agents_per_env=NUM_AGENTS,\n", - " max_agents_per_env=NUM_AGENTS,\n", - " simulation_mode=SIMULATION_MODE,\n", - " dynamics_model=DYNAMICS_MODEL,\n", - " action_type=ACTION_TYPE,\n", - " dt=DT,\n", - " scenario_length=SCENARIO_LENGTH,\n", - " resample_frequency=RESAMPLE_FREQUENCY,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", - " reward_randomization=REWARD_RANDOMIZATION,\n", - " target_type=TARGET_TYPE,\n", - " collision_behavior=COLLISION_BEHAVIOR,\n", - " offroad_behavior=OFFROAD_BEHAVIOR,\n", - " obs_slots_lane_n=MAX_LANES,\n", - " obs_slots_boundary_n=MAX_BOUNDS,\n", - " obs_slots_partners_n=MAX_PARTNERS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", - " seed=SEED,\n", - ")\n", - "globals().update(notebook_dims(env))\n", + "from notebooks.notebook_utils import make_drive_env, make_drive_policy, zero_actions\n", + "\n", + "env, obs, info = make_drive_env()\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = make_drive_policy(env, device)\n", "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", - "print(f\"Action dim: {policy.atn_dim}, act_shape: {ACT_SHAPE}\")" + "print(f\"Action dim: {policy.atn_dim}, act_shape: {(env.num_agents, len(env.single_action_space.nvec))}\")" ] }, { @@ -82,7 +41,7 @@ "metadata": {}, "outputs": [], "source": [ - "# CHECKPOINT_PATH = '/home/o-vcharrau/Workspace/PufferDrive-Valeo/runs/big_test_7/models/model_puffer_drive_000520.pt'\n", + "# CHECKPOINT_PATH = ''\n", "# state_dict = torch.load(CHECKPOINT_PATH, map_location=device)\n", "# state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n", "# print('Checkpoint loaded')" @@ -180,7 +139,7 @@ "source": [ "x = obs_tensor\n", "backbone = policy.actor_backbone\n", - "slide_idx = EGO_DIM\n", + "slide_idx = env.ego_features\n", "\n", "ego_obs = x[:, :slide_idx]\n", "print(\n", @@ -193,9 +152,9 @@ " slide_idx += cond_dim\n", " print(f\"cond_obs: shape={cond_obs.shape}, NaN={torch.isnan(cond_obs).sum().item()}\")\n", "\n", - "partner_dim = MAX_PARTNERS * PARTNER_F\n", - "lane_dim = MAX_LANES * ROAD_F\n", - "boundary_dim = MAX_BOUNDS * ROAD_F\n", + "partner_dim = env.obs_slots_partners_n * env.partner_features\n", + "lane_dim = env.obs_slots_lane_kept * env.road_features\n", + "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", "\n", "partner_obs = x[:, slide_idx : slide_idx + partner_dim]\n", "slide_idx += partner_dim\n", @@ -206,9 +165,13 @@ "\n", "with torch.no_grad():\n", " ego_enc = backbone.ego_encoder(ego_obs)\n", - " partner_enc, _ = backbone.partner_encoder(partner_obs.view(-1, MAX_PARTNERS, PARTNER_F)).max(dim=1)\n", - " lane_enc, _ = backbone.lane_encoder(lane_obs.view(-1, MAX_LANES, ROAD_F)).max(dim=1)\n", - " bound_enc, _ = backbone.boundary_encoder(boundary_obs.view(-1, MAX_BOUNDS, ROAD_F)).max(dim=1)\n", + " partner_enc, _ = backbone.partner_encoder(partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)).max(\n", + " dim=1\n", + " )\n", + " lane_enc, _ = backbone.lane_encoder(lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features)).max(dim=1)\n", + " bound_enc, _ = backbone.boundary_encoder(boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features)).max(\n", + " dim=1\n", + " )\n", "\n", "for name, enc in [(\"ego\", ego_enc), (\"partner\", partner_enc), (\"lane\", lane_enc), (\"boundary\", bound_enc)]:\n", " print(\n", @@ -241,10 +204,10 @@ "\n", "action_logits_list, value = policy(obs_tensor)\n", "\n", - "fake_actions = torch.randint(0, N_ACTIONS, (N,), device=device)\n", - "fake_advantages = torch.randn(N, device=device)\n", - "fake_returns = torch.randn(N, device=device)\n", - "fake_old_logprobs = torch.randn(N, device=device)\n", + "fake_actions = torch.randint(0, env.single_action_space.nvec[0], (env.num_agents,), device=device)\n", + "fake_advantages = torch.randn(env.num_agents, device=device)\n", + "fake_returns = torch.randn(env.num_agents, device=device)\n", + "fake_old_logprobs = torch.randn(env.num_agents, device=device)\n", "\n", "logits = action_logits_list[0]\n", "dist = torch.distributions.Categorical(logits=logits)\n", @@ -322,12 +285,12 @@ "HORIZON = 128\n", "obs_dim = obs.shape[1]\n", "\n", - "obs_buf = np.zeros((HORIZON, N, obs_dim), dtype=np.float32)\n", - "act_buf = np.zeros((HORIZON, N), dtype=np.int64)\n", - "rew_buf = np.zeros((HORIZON, N), dtype=np.float32)\n", - "val_buf = np.zeros((HORIZON, N), dtype=np.float32)\n", - "logp_buf = np.zeros((HORIZON, N), dtype=np.float32)\n", - "done_buf = np.zeros((HORIZON, N), dtype=np.float32)\n", + "obs_buf = np.zeros((HORIZON, env.num_agents, obs_dim), dtype=np.float32)\n", + "act_buf = np.zeros((HORIZON, env.num_agents), dtype=np.int64)\n", + "rew_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", + "val_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", + "logp_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", + "done_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", "\n", "policy.eval()\n", "for t in range(HORIZON):\n", @@ -344,7 +307,7 @@ " logp_buf[t] = logp.cpu().numpy()\n", "\n", " # Reshape (N,) -> (N, 1) for env.step with MultiDiscrete\n", - " env_actions = act.cpu().numpy().reshape(ACT_SHAPE)\n", + " env_actions = act.cpu().numpy().reshape(env.num_agents, len(env.single_action_space.nvec))\n", " obs, rew, term, trunc, info = env.step(env_actions)\n", " rew_buf[t] = rew\n", " done_buf[t] = term | trunc\n", @@ -372,7 +335,7 @@ "gamma, lam = 0.98, 0.95\n", "advantages = np.zeros_like(rew_buf)\n", "\n", - "last_gae = np.zeros(N)\n", + "last_gae = np.zeros(env.num_agents)\n", "for t in reversed(range(HORIZON - 1)):\n", " next_non_terminal = 1.0 - done_buf[t + 1]\n", " delta = rew_buf[t + 1] + gamma * val_buf[t + 1] * next_non_terminal - val_buf[t]\n", @@ -456,7 +419,7 @@ "\n", "print(f\"\\npg_loss: {pg_loss.item():.6f}\")\n", "print(f\"v_loss: {v_loss.item():.6f}\")\n", - "print(f\"entropy: {entropy_loss.item():.6f} (max={np.log(N_ACTIONS):.4f})\")\n", + "print(f\"entropy: {entropy_loss.item():.6f} (max={np.log(env.single_action_space.nvec[0]):.4f})\")\n", "print(f\"total: {(pg_loss + 0.5 * v_loss - 0.01 * entropy_loss).item():.6f}\")" ] }, diff --git a/notebooks/05_inference.ipynb b/notebooks/05_inference.ipynb index c92d5b66fe..c1f99ec009 100644 --- a/notebooks/05_inference.ipynb +++ b/notebooks/05_inference.ipynb @@ -22,19 +22,11 @@ "import torch.nn.functional as F\n", "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy, Recurrent\n", - "import pufferlib.pytorch\n", - "from notebooks.notebook_utils import (\n", - " COEF_NAMES,\n", - " EGO_LABELS,\n", - " MAP_DIR,\n", - " load_notebook_config,\n", - " make_rnn_state,\n", - " notebook_dims,\n", - " zero_actions,\n", - ")\n", + "from pufferlib.ocean.torch import Drive as DrivePolicy\n", + "from pufferlib.pytorch import sample_logits\n", + "from notebooks.notebook_utils import COEF_NAMES, EGO_LABELS, MAP_DIR, load_notebook_config, zero_actions\n", "\n", - "CHECKPOINT_PATH = \"/home/o-vcharrau/Workspace/PufferDrive-Valeo/runs/tomate/models/model_puffer_drive_013100.pt\"\n", + "CHECKPOINT_PATH = \"../weights/tomate/models/model_puffer_drive_013100.pt\"\n", "ENV_NAME = \"puffer_drive\"\n", "\n", "config = load_notebook_config(CHECKPOINT_PATH, ENV_NAME)\n", @@ -43,16 +35,17 @@ "config[\"env\"][\"eval_mode\"] = 1\n", "config[\"env\"][\"map_dir\"] = MAP_DIR\n", "\n", + "config[\"env\"][\"obs_slots_boundary_n\"] = 80\n", + "config[\"env\"][\"obs_slots_lane_n\"] = 80\n", + "config[\"env\"][\"obs_dropout_lane\"] = 0.0\n", + "config[\"env\"][\"obs_dropout_boundary\"] = 0.0\n", + "\n", "env = Drive(**config[\"env\"])\n", "obs, info = env.reset(seed=42)\n", "N = env.num_agents\n", - "globals().update(notebook_dims(env))\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = DrivePolicy(env, **config[\"policy\"]).to(device)\n", - "use_rnn = config[\"train\"].get(\"use_rnn\", False)\n", - "if use_rnn:\n", - " policy = Recurrent(env, policy, **config[\"rnn\"]).to(device)\n", "\n", "if CHECKPOINT_PATH:\n", " sd = torch.load(CHECKPOINT_PATH, map_location=device)\n", @@ -60,10 +53,8 @@ " policy.load_state_dict(sd)\n", " print(f\"Loaded checkpoint: {CHECKPOINT_PATH}\")\n", "\n", - "inner_policy = policy.policy if use_rnn else policy\n", - "is_continuous = inner_policy.is_continuous\n", + "is_continuous = policy.is_continuous\n", "ACT_SHAPE = (N, len(env.single_action_space.nvec)) if not is_continuous else (N, env.single_action_space.shape[0])\n", - "state = make_rnn_state(policy, N, device) if use_rnn else None\n", "\n", "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", "print(f\"Obs shape: {obs.shape}, Action space: {env.single_action_space}\")\n", @@ -93,14 +84,11 @@ "policy.eval()\n", "\n", "with torch.no_grad():\n", - " if use_rnn:\n", - " logits_list, value = policy.forward_eval(obs_tensor, state)\n", - " else:\n", - " logits_list, value = policy(obs_tensor)\n", + " logits_list, value = policy(obs_tensor)\n", "\n", "# Sample actions\n", - "action, logprob, ent = pufferlib.pytorch.sample_logits(logits_list)\n", - "action_det, _, _ = pufferlib.pytorch.sample_logits(logits_list, deterministic=True)\n", + "action, logprob, ent = sample_logits(logits_list)\n", + "action_det, _, _ = sample_logits(logits_list, deterministic=True)\n", "\n", "print(f\"Value: mean={value.mean():.4f}, std={value.std():.4f}, range=[{value.min():.4f}, {value.max():.4f}]\")\n", "print(f\"Entropy: mean={ent.mean():.4f}, std={ent.std():.4f}\")\n", @@ -159,7 +147,6 @@ "def run_rollout(env, policy, deterministic=False, horizon=HORIZON):\n", " obs, _ = env.reset(seed=42)\n", " N = env.num_agents\n", - " st = make_rnn_state(policy, N, device) if use_rnn else None\n", "\n", " buffers = {\n", " \"obs\": np.zeros((horizon, N, obs_dim), dtype=np.float32),\n", @@ -173,18 +160,13 @@ " \"positions_x\": np.zeros((horizon, N), dtype=np.float32),\n", " \"positions_y\": np.zeros((horizon, N), dtype=np.float32),\n", " }\n", - " if use_rnn:\n", - " buffers[\"lstm_h_norm\"] = np.zeros((horizon, N), dtype=np.float32)\n", "\n", " policy.eval()\n", " for t in range(horizon):\n", " obs_t = torch.FloatTensor(obs).to(device)\n", " with torch.no_grad():\n", - " if use_rnn:\n", - " logits_list, val = policy.forward_eval(obs_t, st)\n", - " else:\n", - " logits_list, val = policy(obs_t)\n", - " act, logp, entr = pufferlib.pytorch.sample_logits(logits_list, deterministic=deterministic)\n", + " logits_list, val = policy(obs_t)\n", + " act, logp, entr = sample_logits(logits_list, deterministic=deterministic)\n", "\n", " buffers[\"obs\"][t] = obs\n", " buffers[\"actions\"][t] = act.cpu().numpy().reshape(N) if act.dim() > 1 else act.cpu().numpy()\n", @@ -192,9 +174,6 @@ " buffers[\"logprobs\"][t] = logp.cpu().numpy()\n", " buffers[\"entropy\"][t] = entr.cpu().numpy()\n", "\n", - " if use_rnn:\n", - " buffers[\"lstm_h_norm\"][t] = st[\"lstm_h\"].norm(dim=-1).cpu().numpy()\n", - "\n", " # Get positions\n", " gstate = env.get_global_agent_state()\n", " buffers[\"positions_x\"][t] = gstate[\"x\"]\n", @@ -249,9 +228,9 @@ " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", " obs_norm_goal_offset_m=env.obs_norm_goal_offset_m,\n", " obs_norm_xy_offset_m=env.obs_norm_xy_offset_m,\n", @@ -282,9 +261,9 @@ " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", " )\n", " ego_features_over_time.append(ego)\n", @@ -342,9 +321,9 @@ " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", ")\n", "\n", @@ -449,10 +428,10 @@ "cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0\n", "tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", "tgt_dim = n_tgt_wp * tgt_feat\n", - "partner_dim = env.obs_slots_partners_n * PARTNER_F\n", - "lane_dim = env.obs_slots_lane_n * ROAD_F\n", - "boundary_dim = env.obs_slots_boundary_n * ROAD_F\n", - "traffic_dim = env.obs_slots_traffic_controls_n * TRAFFIC_CONTROL_F\n", + "partner_dim = env.obs_slots_partners_n * env.partner_features\n", + "lane_dim = env.obs_slots_lane_kept * env.road_features\n", + "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", + "traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features\n", "\n", "# Slice indices\n", "idx = 0\n", @@ -535,9 +514,9 @@ " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", " )\n", " egos.append(ego)\n", @@ -611,9 +590,9 @@ " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", " )\n", " dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2)\n", @@ -641,13 +620,12 @@ "sample_obs = buf_stoch[\"obs\"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", "ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs(\n", " sample_obs,\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", - " max_partners=env.obs_slots_partners_n,\n", - " max_lane_segments=env.obs_slots_lane_n,\n", - " max_boundary_segments=env.obs_slots_boundary_n,\n", + " obs_slots_partners_n=env.obs_slots_partners_n,\n", + " obs_slots_lane_n=env.obs_slots_lane_kept,\n", + " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", ")\n", "\n", @@ -870,8 +848,8 @@ "source": [ "# Partner per-feature distributions (pooled over all agents + timesteps, visible only)\n", "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"width\", \"length\", \"heading_cos\", \"heading_sin\", \"speed\"]\n", - "max_partners = env.obs_slots_partners_n\n", - "pf = PARTNER_F\n", + "obs_slots_partners_n = env.obs_slots_partners_n\n", + "pf = env.partner_features\n", "\n", "# Compute slices\n", "_ego_d = binding.EGO_FEATURES\n", @@ -879,16 +857,18 @@ "_tgt_f = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", "_tgt_d = n_tgt_wp * _tgt_f\n", "_p_start = _ego_d + _cond_d + _tgt_d\n", - "_p_end = _p_start + max_partners * pf\n", + "_p_end = _p_start + obs_slots_partners_n * pf\n", "\n", - "all_partners = buf_stoch[\"obs\"][:, :, _p_start:_p_end].reshape(-1, max_partners, pf) # (H*N, max_partners, 8)\n", + "all_partners = buf_stoch[\"obs\"][:, :, _p_start:_p_end].reshape(\n", + " -1, obs_slots_partners_n, pf\n", + ") # (H*N, obs_slots_partners_n, 8)\n", "# Mask: partner is visible if any feature != 0\n", "visible_mask = np.any(all_partners != 0, axis=2) # (H*N, 16)\n", "visible_partners = all_partners[visible_mask] # (K, 8) — all visible partner observations\n", "\n", "print(\n", - " f\"Total partner obs: {all_partners.shape[0] * max_partners}, visible: {len(visible_partners)} \"\n", - " f\"({100 * len(visible_partners) / (all_partners.shape[0] * max_partners):.1f}%)\"\n", + " f\"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} \"\n", + " f\"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)\"\n", ")\n", "\n", "fig, axes = plt.subplots(3, 3, figsize=(21, 10))\n", @@ -917,7 +897,7 @@ "# Partner count distribution across (timestep, agent)\n", "partner_counts = visible_mask.sum(axis=1) # (H*N,)\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", - "axes[0].hist(partner_counts, bins=range(max_partners + 2), edgecolor=\"black\", alpha=0.7, color=\"darkorange\")\n", + "axes[0].hist(partner_counts, bins=range(obs_slots_partners_n + 2), edgecolor=\"black\", alpha=0.7, color=\"darkorange\")\n", "axes[0].set_xlabel(\"Visible partners\")\n", "axes[0].set_ylabel(\"Count\")\n", "axes[0].set_title(\"Partner count distribution (per agent per step)\")\n", @@ -951,9 +931,9 @@ "source": [ "# Road per-feature distributions (lanes + boundaries)\n", "road_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"seg_length\", \"seg_width\", \"dir_cos\", \"dir_sin\"]\n", - "rf = ROAD_F\n", - "max_lanes = env.obs_slots_lane_n\n", - "max_bounds = env.obs_slots_boundary_n\n", + "rf = env.road_features\n", + "max_lanes = env.obs_slots_lane_kept\n", + "max_bounds = env.obs_slots_boundary_kept\n", "\n", "_l_start = _p_end\n", "_l_end = _l_start + max_lanes * rf\n", @@ -1077,15 +1057,15 @@ "# Sparsity heatmap: fraction of nonzero per layer, per agent, over time\n", "layer_names = [\"partners\", \"lanes\", \"boundaries\"]\n", "layer_slices = [\n", - " (_p_start, _p_end, env.obs_slots_partners_n, PARTNER_F),\n", - " (_l_start, _l_end, env.obs_slots_lane_n, ROAD_F),\n", - " (_b_start, _b_end, env.obs_slots_boundary_n, ROAD_F),\n", + " (_p_start, _p_end, env.obs_slots_partners_n, env.partner_features),\n", + " (_l_start, _l_end, env.obs_slots_lane_kept, env.road_features),\n", + " (_b_start, _b_end, env.obs_slots_boundary_kept, env.road_features),\n", "]\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n", "for ax, name, (s, e, n_obj, n_feat) in zip(axes, layer_names, layer_slices):\n", " # (H, N) -> fraction of visible objects per (timestep, agent)\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, N, n_obj, n_feat)\n", + " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", " occupancy = np.any(raw != 0, axis=3).sum(axis=2) / n_obj # (H, N)\n", " im = ax.imshow(occupancy.T, aspect=\"auto\", cmap=\"YlOrRd\", interpolation=\"nearest\", vmin=0, vmax=1)\n", " ax.set_xlabel(\"Step\")\n", @@ -1102,7 +1082,7 @@ "\n", "# Mean across agents\n", "for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices):\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, N, n_obj, n_feat)\n", + " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", " occ_mean = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=1) # (H,)\n", " axes[0].plot(occ_mean, label=name, alpha=0.8)\n", "axes[0].set_xlabel(\"Step\")\n", @@ -1113,7 +1093,7 @@ "\n", "# Mean across timesteps (per agent)\n", "for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices):\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, N, n_obj, n_feat)\n", + " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", " occ_per_agent = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=0) # (N,)\n", " axes[1].bar(range(N), occ_per_agent, alpha=0.5, label=name)\n", "axes[1].set_xlabel(\"Agent\")\n", @@ -1167,13 +1147,7 @@ "for t in range(HORIZON):\n", " obs_t = torch.FloatTensor(buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]).to(device)\n", " with torch.no_grad():\n", - " if use_rnn:\n", - " # Can't replay RNN states here easily, use feedforward approximation\n", - " inner = policy.policy if hasattr(policy, \"policy\") else policy\n", - " h = inner.encode_observations(obs_t)\n", - " logits_list, _ = inner.decode_actions(h)\n", - " else:\n", - " logits_list, _ = policy(obs_t)\n", + " logits_list, _ = policy(obs_t)\n", " logits = logits_list[0] if isinstance(logits_list, (list, tuple)) else logits_list\n", " action_probs_time[t] = F.softmax(logits, dim=-1).cpu().numpy().flatten()\n", "\n", @@ -1485,58 +1459,300 @@ }, { "cell_type": "markdown", - "id": "cell-18", + "id": "ea90af09", + "metadata": {}, + "source": [ + "## Encoder analysis — what the policy encodes\n", + "\n", + "Each obs layer has its own encoder projecting raw features → `input_size` embedding:\n", + "- **ego** and **conditioning** (reward coefs + target): single vector, no pooling.\n", + "- **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed.\n", + "\n", + "The max-pool means each embedding dim is \"won\" by exactly one slot (object). Below we inspect:\n", + "1. Encoder inventory (in/out dims, params).\n", + "2. **What survives the max-pool**: which slot wins per dim, per-dim winner entropy (slot-specialized vs. spread), and where the dominant objects sit in ego frame.\n", + "3. **Embedding space**: per-encoder contribution (L2 norm), active/dead dims, silence rate." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7eb06a8d", "metadata": {}, + "outputs": [], "source": [ - "## LSTM hidden state analysis" + "# ── Setup: capture per-encoder embeddings + reconstruct the max-pool ──\n", + "bb = policy.actor_backbone\n", + "ego_dim = policy.ego_dim\n", + "PAD = -1.0 # PADDED_OBSERVATION_VALUE\n", + "\n", + "# Flat batch of observations from the stochastic rollout\n", + "obs_flat = buf_stoch[\"obs\"].reshape(-1, obs_dim)\n", + "rng = np.random.default_rng(0)\n", + "sel = rng.choice(obs_flat.shape[0], size=min(4096, obs_flat.shape[0]), replace=False)\n", + "obs_batch = torch.FloatTensor(obs_flat[sel]).to(device)\n", + "B = obs_batch.shape[0]\n", + "\n", + "# Encoder inventory: (name, module, raw_in_features, n_slots, is_set)\n", + "enc_inventory = [(\"ego\", bb.ego_encoder, ego_dim, 1, False)]\n", + "if bb.obs_slots_lane_kept > 0:\n", + " enc_inventory.append((\"lane\", bb.lane_encoder, bb.road_features_count, bb.obs_slots_lane_kept, True))\n", + "if bb.obs_slots_boundary_kept > 0:\n", + " enc_inventory.append((\"boundary\", bb.boundary_encoder, bb.road_features_count, bb.obs_slots_boundary_kept, True))\n", + "if bb.obs_slots_partners_n > 0:\n", + " enc_inventory.append((\"partner\", bb.partner_encoder, bb.partner_features_count, bb.obs_slots_partners_n, True))\n", + "if bb.obs_slots_traffic_controls_n > 0:\n", + " enc_inventory.append(\n", + " (\n", + " \"traffic\",\n", + " bb.traffic_control_encoder,\n", + " bb.traffic_control_features_after_onehot,\n", + " bb.obs_slots_traffic_controls_n,\n", + " True,\n", + " )\n", + " )\n", + "if bb.conditioning_dim > 0:\n", + " enc_inventory.append((\"conditioning\", bb.conditioning_encoder, bb.conditioning_dim, 1, False))\n", + "\n", + "enc_names = [n for n, *_ in enc_inventory]\n", + "set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set]\n", + "\n", + "print(f\"{'encoder':>13s} | {'raw_in':>6s} | {'emb_out':>7s} | {'slots':>5s} | {'pooled':>6s} | {'params':>9s}\")\n", + "print(\"-\" * 66)\n", + "for name, mod, rin, nslots, is_set in enc_inventory:\n", + " nparam = sum(p.numel() for p in mod.parameters())\n", + " print(\n", + " f\"{name:>13s} | {rin:>6d} | {bb.input_size:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}\"\n", + " )\n", + "print(\n", + " f\"\\nBackbone input = {len(enc_inventory)} x {bb.input_size} = {len(enc_inventory) * bb.input_size} -> backbone -> {bb.out_dim}\"\n", + ")\n", + "\n", + "# Capture pre-pool encoder outputs via forward hooks\n", + "captured = {}\n", + "\n", + "\n", + "def _hook(name):\n", + " def fn(m, i, o):\n", + " captured[name] = o.detach()\n", + "\n", + " return fn\n", + "\n", + "\n", + "handles = [mod.register_forward_hook(_hook(name)) for name, mod, *_ in enc_inventory]\n", + "policy.eval()\n", + "with torch.no_grad():\n", + " policy(obs_batch)\n", + "for h in handles:\n", + " h.remove()\n", + "\n", + "# Reconstruct slot slices (same order as DriveBackbone.forward) + pad masks\n", + "partner_dim = bb.obs_slots_partners_n * bb.partner_features_count\n", + "lane_dim = bb.obs_slots_lane_kept * bb.road_features_count\n", + "boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count\n", + "traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count\n", + "_s = ego_dim + bb.conditioning_dim\n", + "sl = {}\n", + "sl[\"partner\"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count)\n", + "_s += partner_dim\n", + "sl[\"lane\"] = (_s, _s + lane_dim, bb.obs_slots_lane_kept, bb.road_features_count)\n", + "_s += lane_dim\n", + "sl[\"boundary\"] = (_s, _s + boundary_dim, bb.obs_slots_boundary_kept, bb.road_features_count)\n", + "_s += boundary_dim\n", + "sl[\"traffic\"] = (_s, _s + traffic_dim, bb.obs_slots_traffic_controls_n, bb.traffic_control_features_count)\n", + "_s += traffic_dim\n", + "\n", + "raw, pad, pooled, winners, valid_sample = {}, {}, {}, {}, {}\n", + "for name in set_encs:\n", + " s, e, ns, nf = sl[name]\n", + " obj = obs_batch[:, s:e].view(B, ns, nf)\n", + " raw[name] = obj\n", + " if name == \"traffic\":\n", + " cont = obj[:, :, : bb.traffic_control_continuous_features]\n", + " typ = obj[:, :, bb.traffic_control_continuous_features]\n", + " st = obj[:, :, bb.traffic_control_continuous_features + 1]\n", + " pad[name] = (\n", + " (cont == PAD).all(dim=2)\n", + " & (typ == binding.TRAFFIC_CONTROL_TYPE_NONE)\n", + " & (st == binding.TRAFFIC_CONTROL_STATE_UNKNOWN)\n", + " )\n", + " else:\n", + " pad[name] = (obj == PAD).all(dim=2)\n", + " masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf)\n", + " vm = (~pad[name]).any(dim=1)\n", + " valid_sample[name] = vm\n", + " winners[name] = masked.max(dim=1).indices # (B, input_size): winning slot per dim\n", + " pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values))\n", + "\n", + "for name in (\"ego\", \"conditioning\"):\n", + " if name in enc_names:\n", + " pooled[name] = captured[name]\n", + "\n", + "print(\"\\nCaptured embeddings for:\", enc_names)" ] }, { "cell_type": "code", "execution_count": null, - "id": "cell-19", + "id": "01446b9c", "metadata": {}, "outputs": [], "source": [ - "if use_rnn and \"lstm_h_norm\" in buf_stoch:\n", - " h_norm = buf_stoch[\"lstm_h_norm\"]\n", - "\n", - " fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", - "\n", - " # Hidden state norm over time\n", - " axes[0].plot(h_norm.mean(axis=1), label=\"mean\", color=\"blue\")\n", - " axes[0].fill_between(\n", - " range(HORIZON),\n", - " h_norm.mean(axis=1) - h_norm.std(axis=1),\n", - " h_norm.mean(axis=1) + h_norm.std(axis=1),\n", - " alpha=0.2,\n", - " color=\"blue\",\n", + "# ── What survives the max-pool: winning slots, specialization, spatial ──\n", + "n = len(set_encs)\n", + "fig, axes = plt.subplots(n, 3, figsize=(18, 4.2 * n))\n", + "if n == 1:\n", + " axes = axes[None, :]\n", + "\n", + "print(f\"{'encoder':>9s} | {'valid%':>6s} | {'mean active slots/dim':>21s} | {'%slot-specialized dims':>22s}\")\n", + "print(\"-\" * 70)\n", + "for r, name in enumerate(set_encs):\n", + " s, e, ns, nf = sl[name]\n", + " vm = valid_sample[name]\n", + " w = winners[name][vm] # (Bv, D)\n", + " D = w.shape[1]\n", + "\n", + " # (1) which slot wins, pooled over all dims+samples\n", + " slot_counts = torch.bincount(w.reshape(-1), minlength=ns).float().cpu().numpy()\n", + " slot_counts = slot_counts / max(slot_counts.sum(), 1)\n", + " axes[r, 0].bar(range(ns), slot_counts, color=\"teal\", alpha=0.85, edgecolor=\"black\")\n", + " axes[r, 0].set_title(f\"{name}: max-pool winner by slot\")\n", + " axes[r, 0].set_xlabel(\"slot index (0 = first/closest)\")\n", + " axes[r, 0].set_ylabel(\"frac of dims won\")\n", + "\n", + " # (2) per-dim winner entropy: slot-specialized (0) vs spread across slots (1)\n", + " onehot = F.one_hot(w, num_classes=ns).float() # (Bv, D, ns)\n", + " p = onehot.mean(dim=0) # (D, ns) winner distribution per dim\n", + " ent = (-(p * (p + 1e-9).log()).sum(dim=1) / np.log(ns)).cpu().numpy()\n", + " axes[r, 1].hist(ent, bins=30, color=\"indianred\", alpha=0.85, edgecolor=\"black\")\n", + " axes[r, 1].set_title(f\"{name}: per-dim winner entropy\")\n", + " axes[r, 1].set_xlabel(\"0 = slot-specialized → 1 = spread\")\n", + " axes[r, 1].set_xlim(0, 1)\n", + "\n", + " # (3) ego-frame position of the dominant object (mode winning slot per sample)\n", + " dom = torch.mode(w, dim=1).values # (Bv,)\n", + " rel = raw[name][vm]\n", + " dom_xy = rel[torch.arange(rel.shape[0]), dom][:, :2].cpu().numpy()\n", + " axes[r, 2].scatter(dom_xy[:, 0], dom_xy[:, 1], s=3, alpha=0.15, color=\"navy\")\n", + " axes[r, 2].scatter(0, 0, marker=\"*\", s=200, color=\"red\", zorder=5, label=\"ego\")\n", + " axes[r, 2].set_title(f\"{name}: dominant object position (ego frame)\")\n", + " axes[r, 2].set_xlabel(\"rel_x\")\n", + " axes[r, 2].set_ylabel(\"rel_y\")\n", + " axes[r, 2].set_aspect(\"equal\")\n", + " axes[r, 2].legend(fontsize=8)\n", + "\n", + " active_per_dim = np.exp(ent * np.log(ns)).mean()\n", + " print(\n", + " f\"{name:>9s} | {100 * vm.float().mean().item():>5.1f}% | {active_per_dim:>21.2f} | {100 * (ent < 0.2).mean():>21.1f}%\"\n", " )\n", - " axes[0].set_xlabel(\"Step\")\n", - " axes[0].set_ylabel(\"||h||\")\n", - " axes[0].set_title(\"LSTM hidden state norm over time\")\n", - " axes[0].grid(True, alpha=0.3)\n", - "\n", - " # Histogram at different timesteps\n", - " for t, color in [(0, \"blue\"), (HORIZON // 4, \"green\"), (HORIZON // 2, \"orange\"), (HORIZON - 1, \"red\")]:\n", - " axes[1].hist(h_norm[t], bins=30, alpha=0.4, label=f\"t={t}\", color=color, edgecolor=\"black\")\n", - " axes[1].set_xlabel(\"||h||\")\n", - " axes[1].set_ylabel(\"Count\")\n", - " axes[1].set_title(\"Hidden norm distribution at different timesteps\")\n", - " axes[1].legend()\n", - "\n", - " # Correlation: hidden norm vs value\n", - " axes[2].scatter(h_norm.flatten(), buf_stoch[\"values\"].flatten(), alpha=0.1, s=3)\n", - " corr = np.corrcoef(h_norm.flatten(), buf_stoch[\"values\"].flatten())[0, 1]\n", - " axes[2].set_xlabel(\"||h||\")\n", - " axes[2].set_ylabel(\"Value\")\n", - " axes[2].set_title(f\"Hidden norm vs Value (corr={corr:.3f})\")\n", - " axes[2].grid(True, alpha=0.3)\n", "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "\n", + "# ── H1/H2/H3 check: boundary max-pool winner distance vs slot index ──\n", + "if \"boundary\" in set_encs:\n", + " vm = valid_sample[\"boundary\"]\n", + " w = winners[\"boundary\"][vm] # (Bv, D) winning slot per dim\n", + " rb = raw[\"boundary\"][vm] # (Bv, ns, nf) raw segments\n", + " nsb = rb.shape[1]\n", + " reldist = torch.hypot(rb[:, :, 0], rb[:, :, 1]) # (Bv, ns) normalized ego-frame dist\n", + " valid_seg = ~pad[\"boundary\"][vm] # (Bv, ns) slots holding a real segment\n", + "\n", + " win_reldist = torch.gather(reldist, 1, w) # (Bv, D) dist of each winning segment\n", + " slot_flat = w.reshape(-1)\n", + " wdist_flat = win_reldist.reshape(-1)\n", + " total_wins = slot_flat.numel()\n", + "\n", + " print(\"\\n=== Boundary winner distance vs slot index (H1/H2/H3) ===\")\n", + " print(f\"{'slot':>4s} | {'#wins':>8s} | {'win%':>6s} | {'rel_dist winners':>16s} | {'rel_dist occupied':>17s}\")\n", + " print(\"-\" * 64)\n", + " for s in range(nsb):\n", + " wm = slot_flat == s\n", + " nwin = int(wm.sum())\n", + " wmean = wdist_flat[wm].mean().item() if nwin > 0 else float(\"nan\")\n", + " occ = valid_seg[:, s]\n", + " omean = reldist[occ, s].mean().item() if int(occ.sum()) > 0 else float(\"nan\")\n", + " print(f\"{s:>4d} | {nwin:>8d} | {100 * nwin / total_wins:>5.1f}% | {wmean:>16.4f} | {omean:>17.4f}\")\n", + "\n", + " win_mean = wdist_flat.mean().item()\n", + " seg_mean = reldist[valid_seg].mean().item()\n", + " print(f\"\\nMean rel_dist of WINNING segments : {win_mean:.4f}\")\n", + " print(f\"Mean rel_dist of ALL valid segments: {seg_mean:.4f}\")\n", + " verdict = \"FARTHER than avg (H3 supported)\" if win_mean > seg_mean else \"nearer than avg\"\n", + " print(f\"-> winners are {verdict} by {win_mean - seg_mean:+.4f} (normalized units)\")\n", + "\n", + " fig, ax = plt.subplots(1, 2, figsize=(14, 4))\n", + " occ_means = [\n", + " reldist[valid_seg[:, s], s].mean().item() if int(valid_seg[:, s].sum()) > 0 else np.nan for s in range(nsb)\n", + " ]\n", + " win_means = [\n", + " wdist_flat[slot_flat == s].mean().item() if int((slot_flat == s).sum()) > 0 else np.nan for s in range(nsb)\n", + " ]\n", + " ax[0].plot(range(nsb), occ_means, \"o-\", label=\"occupied (any segment in slot)\")\n", + " ax[0].plot(range(nsb), win_means, \"s-\", label=\"winners only\")\n", + " ax[0].axhline(seg_mean, color=\"gray\", ls=\"--\", label=\"global valid mean\")\n", + " ax[0].set_xlabel(\"slot index\")\n", + " ax[0].set_ylabel(\"mean rel_dist (normalized)\")\n", + " ax[0].set_title(\"Boundary rel_dist vs slot index\\n(flat = not distance-sorted -> H1/H2)\")\n", + " ax[0].legend(fontsize=8)\n", + " ax[1].hist(reldist[valid_seg].cpu().numpy(), bins=50, alpha=0.6, density=True, label=\"all valid segs\", color=\"gray\")\n", + " ax[1].hist(wdist_flat.cpu().numpy(), bins=50, alpha=0.6, density=True, label=\"winners\", color=\"crimson\")\n", + " ax[1].axvline(seg_mean, color=\"gray\", ls=\"--\")\n", + " ax[1].axvline(win_mean, color=\"crimson\", ls=\"--\")\n", + " ax[1].set_xlabel(\"rel_dist (normalized)\")\n", + " ax[1].set_title(\"Winner vs all-segment distance (H3)\")\n", + " ax[1].legend(fontsize=8)\n", " plt.tight_layout()\n", - " plt.show()\n", - "else:\n", - " print(\"No LSTM — skipping hidden state analysis\")" + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44d421ad", + "metadata": {}, + "outputs": [], + "source": [ + "# ── Embedding space: per-encoder contribution, active/dead dims, silence ──\n", + "fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n", + "\n", + "# (1) Mean L2 norm of each pooled embedding = relative weight in the concat fed to backbone\n", + "norms = [pooled[n].norm(dim=1).mean().item() for n in enc_names]\n", + "axes[0].bar(enc_names, norms, color=\"slateblue\", edgecolor=\"black\")\n", + "axes[0].set_title(\"Mean L2 norm of pooled embedding\\n(relative contribution to backbone input)\")\n", + "axes[0].tick_params(axis=\"x\", rotation=45)\n", + "axes[0].grid(True, axis=\"y\", alpha=0.3)\n", + "\n", + "# (2) Mean |activation| per embedding dim, per encoder\n", + "M = np.stack([pooled[n].abs().mean(0).cpu().numpy() for n in enc_names])\n", + "im = axes[1].imshow(M, aspect=\"auto\", cmap=\"magma\")\n", + "axes[1].set_yticks(range(len(enc_names)))\n", + "axes[1].set_yticklabels(enc_names)\n", + "axes[1].set_xlabel(\"embedding dim\")\n", + "axes[1].set_title(\"Mean |activation| per embedding dim\")\n", + "plt.colorbar(im, ax=axes[1])\n", + "\n", + "# (3) Dead dims (std<1e-4) — capacity the encoder never uses\n", + "dead = [(pooled[n].std(0) < 1e-4).float().mean().item() for n in enc_names]\n", + "axes[2].bar(enc_names, dead, color=\"gray\", edgecolor=\"black\")\n", + "axes[2].set_title(\"Fraction of dead embedding dims (std < 1e-4)\")\n", + "axes[2].tick_params(axis=\"x\", rotation=45)\n", + "axes[2].set_ylim(0, 1)\n", + "axes[2].grid(True, axis=\"y\", alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "print(f\"{'encoder':>13s} | {'mean|act|':>9s} | {'emb L2':>7s} | {'dead dims':>9s} | {'silence (fully padded)':>22s}\")\n", + "print(\"-\" * 80)\n", + "for name in enc_names:\n", + " silence = (1 - valid_sample[name].float().mean().item()) if name in valid_sample else 0.0\n", + " deadf = (pooled[name].std(0) < 1e-4).float().mean().item()\n", + " print(\n", + " f\"{name:>13s} | {pooled[name].abs().mean().item():>9.4f} | {pooled[name].norm(dim=1).mean().item():>7.3f} | \"\n", + " f\"{100 * deadf:>7.1f}% | {100 * silence:>21.1f}%\"\n", + " )" ] } ], diff --git a/notebooks/06_architecture.ipynb b/notebooks/06_architecture.ipynb index 72c53c004a..8df1731aba 100644 --- a/notebooks/06_architecture.ipynb +++ b/notebooks/06_architecture.ipynb @@ -21,26 +21,7 @@ "from torchinfo import summary\n", "from pufferlib.ocean.drive import binding\n", "from pufferlib.ocean.torch import Drive as DrivePolicy\n", - "from notebooks.notebook_utils import make_drive_env, notebook_dims, zero_actions\n", - "\n", - "# --- Environment configuration ---\n", - "NUM_AGENTS = 64\n", - "SIMULATION_MODE = \"gigaflow\"\n", - "DYNAMICS_MODEL = \"jerk\"\n", - "ACTION_TYPE = \"discrete\"\n", - "DT = 0.1\n", - "SCENARIO_LENGTH = 512\n", - "RESAMPLE_FREQUENCY = 0\n", - "REWARD_CONDITIONING = True\n", - "REWARD_RANDOMIZATION = False\n", - "TARGET_TYPE = \"static\"\n", - "COLLISION_BEHAVIOR = 1\n", - "OFFROAD_BEHAVIOR = 1\n", - "SEED = 42\n", - "MAX_PARTNERS = 20\n", - "MAX_LANES = 100\n", - "MAX_BOUNDS = 50\n", - "MAX_TRAFFIC = 4\n", + "from notebooks.notebook_utils import make_drive_env, zero_actions\n", "\n", "# --- Policy architecture ---\n", "INPUT_SIZE = 64\n", @@ -54,28 +35,7 @@ "ENCODER_GIGAFLOW = True\n", "DROPOUT = 0.0\n", "\n", - "env, obs, info = make_drive_env(\n", - " num_agents=NUM_AGENTS,\n", - " min_agents_per_env=NUM_AGENTS,\n", - " max_agents_per_env=NUM_AGENTS,\n", - " simulation_mode=SIMULATION_MODE,\n", - " dynamics_model=DYNAMICS_MODEL,\n", - " action_type=ACTION_TYPE,\n", - " dt=DT,\n", - " scenario_length=SCENARIO_LENGTH,\n", - " resample_frequency=RESAMPLE_FREQUENCY,\n", - " reward_conditioning=REWARD_CONDITIONING,\n", - " reward_randomization=REWARD_RANDOMIZATION,\n", - " target_type=TARGET_TYPE,\n", - " collision_behavior=COLLISION_BEHAVIOR,\n", - " offroad_behavior=OFFROAD_BEHAVIOR,\n", - " obs_slots_lane_n=MAX_LANES,\n", - " obs_slots_boundary_n=MAX_BOUNDS,\n", - " obs_slots_partners_n=MAX_PARTNERS,\n", - " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", - " seed=SEED,\n", - ")\n", - "globals().update(notebook_dims(env))\n", + "env, obs, info = make_drive_env()\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = DrivePolicy(\n", @@ -138,15 +98,15 @@ "# Collect encoder info — encoder_gigaflow adds Tanh+Dropout between LN and second Linear\n", "# ego, partner, conditioning use encoder_gigaflow; lane, boundary, traffic_ctrl use dropout\n", "encoders = [\n", - " (\"ego\", EGO_DIM, 1, \"direct\", ENCODER_GIGAFLOW),\n", + " (\"ego\", env.ego_features, 1, \"direct\", ENCODER_GIGAFLOW),\n", " (\"conditioning\", cond_dim, 1, \"direct\", ENCODER_GIGAFLOW) if cond_dim > 0 else None,\n", - " (\"partner\", PARTNER_F, MAX_PARTNERS, \"max-pool\", ENCODER_GIGAFLOW),\n", - " (\"lane\", ROAD_F, MAX_LANES, \"max-pool\", ENCODER_GIGAFLOW),\n", - " (\"boundary\", ROAD_F, MAX_BOUNDS, \"max-pool\", ENCODER_GIGAFLOW),\n", + " (\"partner\", env.partner_features, env.obs_slots_partners_n, \"max-pool\", ENCODER_GIGAFLOW),\n", + " (\"lane\", env.road_features, env.obs_slots_lane_kept, \"max-pool\", ENCODER_GIGAFLOW),\n", + " (\"boundary\", env.road_features, env.obs_slots_boundary_kept, \"max-pool\", ENCODER_GIGAFLOW),\n", " (\n", " \"traffic_ctrl\",\n", - " TRAFFIC_CONTROL_F - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES,\n", - " MAX_TRAFFIC,\n", + " env.traffic_control_features - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES,\n", + " env.obs_slots_traffic_controls_n,\n", " \"max-pool (onehot)\",\n", " ENCODER_GIGAFLOW,\n", " ),\n", @@ -330,12 +290,12 @@ "x = obs_tensor\n", "backbone = policy.actor_backbone\n", "\n", - "slide_idx = EGO_DIM\n", + "slide_idx = env.ego_features\n", "cond_dim = backbone.conditioning_dim\n", - "partner_dim = MAX_PARTNERS * PARTNER_F\n", - "lane_dim = MAX_LANES * ROAD_F\n", - "boundary_dim = MAX_BOUNDS * ROAD_F\n", - "traffic_dim = MAX_TRAFFIC * TRAFFIC_CONTROL_F\n", + "partner_dim = env.obs_slots_partners_n * env.partner_features\n", + "lane_dim = env.obs_slots_lane_kept * env.road_features\n", + "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", + "traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features\n", "\n", "# Slicing\n", "ego_obs = x[:, :slide_idx]\n", @@ -377,22 +337,22 @@ " cond_enc = backbone.conditioning_encoder(cond_obs)\n", " print(f\" cond_encoder: {cond_obs.shape} -> {cond_enc.shape}\")\n", "\n", - " p_reshaped = partner_obs.view(-1, MAX_PARTNERS, PARTNER_F)\n", + " p_reshaped = partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)\n", " p_enc, _ = backbone.partner_encoder(p_reshaped).max(dim=1)\n", " print(f\" partner_encoder: {partner_obs.shape} -> view {p_reshaped.shape} -> encode -> max-pool -> {p_enc.shape}\")\n", "\n", - " l_reshaped = lane_obs.view(-1, MAX_LANES, ROAD_F)\n", + " l_reshaped = lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features)\n", " l_enc, _ = backbone.lane_encoder(l_reshaped).max(dim=1)\n", " print(f\" lane_encoder: {lane_obs.shape} -> view {l_reshaped.shape} -> encode -> max-pool -> {l_enc.shape}\")\n", "\n", - " b_reshaped = boundary_obs.view(-1, MAX_BOUNDS, ROAD_F)\n", + " b_reshaped = boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features)\n", " b_enc, _ = backbone.boundary_encoder(b_reshaped).max(dim=1)\n", " print(f\" bound_encoder: {boundary_obs.shape} -> view {b_reshaped.shape} -> encode -> max-pool -> {b_enc.shape}\")\n", "\n", - " t_reshaped = traffic_obs.view(-1, MAX_TRAFFIC, TRAFFIC_CONTROL_F)\n", - " t_cont = t_reshaped[:, :, : TRAFFIC_CONTROL_F - 2]\n", - " t_type = t_reshaped[:, :, TRAFFIC_CONTROL_F - 2]\n", - " t_state = t_reshaped[:, :, TRAFFIC_CONTROL_F - 1]\n", + " t_reshaped = traffic_obs.view(-1, env.obs_slots_traffic_controls_n, env.traffic_control_features)\n", + " t_cont = t_reshaped[:, :, : env.traffic_control_features - 2]\n", + " t_type = t_reshaped[:, :, env.traffic_control_features - 2]\n", + " t_state = t_reshaped[:, :, env.traffic_control_features - 1]\n", " t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float()\n", " t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float()\n", " t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2)\n", @@ -470,35 +430,37 @@ "source": [ "policy.eval()\n", "with torch.no_grad():\n", - " hidden = policy.actor_backbone(obs_tensor, EGO_DIM)\n", + " hidden = policy.actor_backbone(obs_tensor, env.ego_features)\n", " action_logits, value = policy.decode_actions(hidden)\n", "\n", "# Collect per-encoder activations\n", "activations = {}\n", "with torch.no_grad():\n", - " slide = EGO_DIM\n", - " activations[\"ego\"] = backbone.ego_encoder(obs_tensor[:, :EGO_DIM])\n", + " slide = env.ego_features\n", + " activations[\"ego\"] = backbone.ego_encoder(obs_tensor[:, : env.ego_features])\n", "\n", " if cond_dim > 0:\n", " activations[\"conditioning\"] = backbone.conditioning_encoder(obs_tensor[:, slide : slide + cond_dim])\n", " slide += cond_dim\n", "\n", - " p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, MAX_PARTNERS, PARTNER_F)\n", + " p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, env.obs_slots_partners_n, env.partner_features)\n", " activations[\"partner\"], _ = backbone.partner_encoder(p_obs).max(dim=1)\n", " slide += partner_dim\n", "\n", - " l_obs = obs_tensor[:, slide : slide + lane_dim].view(-1, MAX_LANES, ROAD_F)\n", + " l_obs = obs_tensor[:, slide : slide + lane_dim].view(-1, env.obs_slots_lane_kept, env.road_features)\n", " activations[\"lane\"], _ = backbone.lane_encoder(l_obs).max(dim=1)\n", " slide += lane_dim\n", "\n", - " b_obs = obs_tensor[:, slide : slide + boundary_dim].view(-1, MAX_BOUNDS, ROAD_F)\n", + " b_obs = obs_tensor[:, slide : slide + boundary_dim].view(-1, env.obs_slots_boundary_kept, env.road_features)\n", " activations[\"boundary\"], _ = backbone.boundary_encoder(b_obs).max(dim=1)\n", " slide += boundary_dim\n", "\n", - " t_obs = obs_tensor[:, slide : slide + traffic_dim].view(-1, MAX_TRAFFIC, TRAFFIC_CONTROL_F)\n", - " t_cont = t_obs[:, :, : TRAFFIC_CONTROL_F - 2]\n", - " t_type = t_obs[:, :, TRAFFIC_CONTROL_F - 2]\n", - " t_state = t_obs[:, :, TRAFFIC_CONTROL_F - 1]\n", + " t_obs = obs_tensor[:, slide : slide + traffic_dim].view(\n", + " -1, env.obs_slots_traffic_controls_n, env.traffic_control_features\n", + " )\n", + " t_cont = t_obs[:, :, : env.traffic_control_features - 2]\n", + " t_type = t_obs[:, :, env.traffic_control_features - 2]\n", + " t_state = t_obs[:, :, env.traffic_control_features - 1]\n", " t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float()\n", " t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float()\n", " t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2)\n", @@ -674,7 +636,7 @@ "\n", "axes[1].bar(names, times, color=bar_colors, edgecolor=\"black\")\n", "axes[1].set_ylabel(\"ms / forward\")\n", - "axes[1].set_title(f\"Forward Pass Latency ({NUM_AGENTS} agents)\")\n", + "axes[1].set_title(f\"Forward Pass Latency ({env.num_agents} agents)\")\n", "axes[1].tick_params(axis=\"x\", rotation=30)\n", "for i, v in enumerate(times):\n", " axes[1].text(i, v, f\"{v:.2f}\", ha=\"center\", va=\"bottom\", fontsize=7)\n", @@ -705,18 +667,18 @@ " all_obs.append(o)\n", "stacked = np.concatenate(all_obs, axis=0)\n", "\n", - "slide = EGO_DIM\n", - "segments = [(\"ego\", 0, EGO_DIM, 1, EGO_DIM)]\n", + "slide = env.ego_features\n", + "segments = [(\"ego\", 0, env.ego_features, 1, env.ego_features)]\n", "if cond_dim > 0:\n", " segments.append((\"conditioning\", slide, slide + cond_dim, 1, cond_dim))\n", " slide += cond_dim\n", - "segments.append((\"partners\", slide, slide + partner_dim, MAX_PARTNERS, PARTNER_F))\n", + "segments.append((\"partners\", slide, slide + partner_dim, env.obs_slots_partners_n, env.partner_features))\n", "slide += partner_dim\n", - "segments.append((\"lanes\", slide, slide + lane_dim, MAX_LANES, ROAD_F))\n", + "segments.append((\"lanes\", slide, slide + lane_dim, env.obs_slots_lane_kept, env.road_features))\n", "slide += lane_dim\n", - "segments.append((\"boundaries\", slide, slide + boundary_dim, MAX_BOUNDS, ROAD_F))\n", + "segments.append((\"boundaries\", slide, slide + boundary_dim, env.obs_slots_boundary_kept, env.road_features))\n", "slide += boundary_dim\n", - "segments.append((\"traffic\", slide, slide + traffic_dim, MAX_TRAFFIC, TRAFFIC_CONTROL_F))\n", + "segments.append((\"traffic\", slide, slide + traffic_dim, env.obs_slots_traffic_controls_n, env.traffic_control_features))\n", "\n", "print(f\"{'Segment':>15s} | {'Slots':>5s} | {'Features':>8s} | {'Fill %':>7s} | {'Mean':>8s} | {'Std':>8s}\")\n", "print(\"-\" * 65)\n", @@ -762,7 +724,7 @@ "source": [ "# Jacobian-based sensitivity: d(hidden) / d(obs) magnitude\n", "sample = obs_tensor[:1].clone().requires_grad_(True)\n", - "hidden = policy.actor_backbone(sample, EGO_DIM)\n", + "hidden = policy.actor_backbone(sample, env.ego_features)\n", "# Sum hidden to scalar for backward\n", "hidden.sum().backward()\n", "sensitivity = sample.grad.abs().squeeze().cpu().numpy()\n", @@ -775,9 +737,9 @@ "axes[0].set_title(\"Input Feature Sensitivity (|d hidden / d obs|)\")\n", "\n", "# Mark segments\n", - "seg_boundaries = [0, EGO_DIM]\n", + "seg_boundaries = [0, env.ego_features]\n", "seg_labels = [\"ego\"]\n", - "s = EGO_DIM\n", + "s = env.ego_features\n", "if cond_dim > 0:\n", " s += cond_dim\n", " seg_boundaries.append(s)\n", diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index 444745a2d4..02164349f7 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -5,8 +5,6 @@ import numpy as np import yaml -import torch - from pufferlib.ocean.drive.drive import Drive from pufferlib.ocean.drive import binding from pufferlib.ocean.torch import Drive as DrivePolicy @@ -66,10 +64,23 @@ "map_dir": MAP_DIR, "collision_behavior": 1, "offroad_behavior": 1, - "obs_slots_lane_n": 32, - "obs_slots_boundary_n": 32, + "obs_slots_lane_n": 80, + "obs_slots_boundary_n": 80, "obs_slots_partners_n": 16, - "obs_slots_traffic_controls_n": 10, + "obs_slots_traffic_controls_n": 4, + "obs_dropout_lane": 0.0, + "obs_dropout_boundary": 0.0, + "obs_norm_goal_offset_m": 120.0, + "obs_norm_xy_offset_m": 120.0, + "obs_norm_veh_length_m": 15.0, + "obs_norm_veh_width_m": 10.0, + "obs_norm_road_seg_length_m": 10.0, + "obs_norm_road_seg_width_m": 5.0, + "obs_range_road_front_m": 120.0, + "obs_range_road_behind_m": 20.0, + "obs_range_road_side_m": 30.0, + "obs_range_partner_m": 100.0, + "obs_range_traffic_control_m": 100.0, "seed": 42, } @@ -98,27 +109,6 @@ def make_drive_env(**overrides): return env, obs, info -def notebook_dims(env): - return { - "EGO_DIM": env.ego_features, - "NUM_COEFS": binding.NUM_REWARD_COEFS, - "PARTNER_F": env.partner_features, - "ROAD_F": env.road_features, - "TRAFFIC_CONTROL_F": env.traffic_control_features, - "NUM_TRAFFIC_CONTROL_TYPES": binding.NUM_TRAFFIC_CONTROL_TYPES, - "MAX_PARTNERS": env.obs_slots_partners_n, - "MAX_LANES": env.obs_slots_lane_kept, - "MAX_BOUNDS": env.obs_slots_boundary_kept, - "MAX_TRAFFIC": env.obs_slots_traffic_controls_n, - "MAX_TARGET": env.num_target_waypoints, - "TARGET_F": env.target_features, - "TARGET_DIM": env.target_dim, - "N_ACTIONS": int(env.single_action_space.nvec[0]) if hasattr(env.single_action_space, "nvec") else 1, - "N": env.num_agents, - "ACT_SHAPE": action_shape(env), - } - - def action_shape(env): if hasattr(env.single_action_space, "nvec"): return (env.num_agents, len(env.single_action_space.nvec)) @@ -154,12 +144,4 @@ def load_notebook_config(checkpoint_path=None, env_name="puffer_drive"): if section in ycfg and isinstance(ycfg[section], dict): config[section].update(ycfg[section]) - config["train"]["use_rnn"] = config.get("rnn_name") is not None return config - - -def make_rnn_state(policy, n, device): - return { - "lstm_h": torch.zeros(n, policy.hidden_size, device=device), - "lstm_c": torch.zeros(n, policy.hidden_size, device=device), - } diff --git a/pufferlib/viz.py b/pufferlib/viz.py index 5656d7d392..3609984101 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -517,9 +517,9 @@ def unpack_obs( target_type: str = "static", reward_conditioning: bool = False, num_target_waypoints: int = 5, - max_partners: int = 16, - max_lane_segments: int = 16, - max_boundary_segments: int = 16, + obs_slots_partners_n: int = 16, + obs_slots_lane_n: int = 16, + obs_slots_boundary_n: int = 16, obs_slots_traffic_controls_n: int = 16, obs_dropout_lane: float = 0.0, obs_dropout_boundary: float = 0.0, @@ -536,6 +536,9 @@ def unpack_obs( if obs_flat.ndim == 1: obs_flat = obs_flat[None, :] + if isinstance(target_type, int): + target_type = "static" if target_type == binding.TARGET_STATIC else "dynamic" + ego_dim = binding.EGO_FEATURES # Partner obs @@ -544,8 +547,8 @@ def unpack_obs( road_feature_size = binding.ROAD_FEATURES # Traffic control obs traffic_control_feature_size = binding.TRAFFIC_CONTROL_FEATURES - lane_segment_count = compute_effective_road_obs_count(max_lane_segments, obs_dropout_lane) - boundary_segment_count = compute_effective_road_obs_count(max_boundary_segments, obs_dropout_boundary) + lane_segment_count = compute_effective_road_obs_count(obs_slots_lane_n, obs_dropout_lane) + boundary_segment_count = compute_effective_road_obs_count(obs_slots_boundary_n, obs_dropout_boundary) # Target obs target_features = binding.STATIC_TARGET_FEATURES if target_type == "static" else binding.DYNAMIC_TARGET_FEATURES @@ -564,9 +567,9 @@ def unpack_obs( # Extract partners partners_start = target_end - partners_end = partners_start + max_partners * partner_feature_size + partners_end = partners_start + obs_slots_partners_n * partner_feature_size partners_obs = obs_flat[:, partners_start:partners_end] - partners_obs = partners_obs.reshape(-1, max_partners, partner_feature_size) + partners_obs = partners_obs.reshape(-1, obs_slots_partners_n, partner_feature_size) # Extract lane elements lane_start = partners_end @@ -606,9 +609,9 @@ def plot_observation( target_type="static", reward_conditioning=False, num_target_waypoints=10, - max_partners=16, - max_lane_segments=32, - max_boundary_segments=32, + obs_slots_partners_n=16, + obs_slots_lane_n=32, + obs_slots_boundary_n=32, obs_slots_traffic_controls_n=4, obs_dropout_lane=0.0, obs_dropout_boundary=0.0, @@ -626,6 +629,9 @@ def plot_observation( obs: flattened observation tensor target_type: 0 for goal only, 1 for waypoints only, 2 for both """ + if isinstance(target_type, int): + target_type = "static" if target_type == binding.TARGET_STATIC else "dynamic" + fig, ax = plt.subplots(figsize=(20, 20)) ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs = unpack_obs( @@ -633,9 +639,9 @@ def plot_observation( target_type=target_type, reward_conditioning=reward_conditioning, num_target_waypoints=num_target_waypoints, - max_partners=max_partners, - max_lane_segments=max_lane_segments, - max_boundary_segments=max_boundary_segments, + obs_slots_partners_n=obs_slots_partners_n, + obs_slots_lane_n=obs_slots_lane_n, + obs_slots_boundary_n=obs_slots_boundary_n, obs_slots_traffic_controls_n=obs_slots_traffic_controls_n, obs_dropout_lane=obs_dropout_lane, obs_dropout_boundary=obs_dropout_boundary, @@ -947,7 +953,7 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): "dynamics_model": env_cfg.get("dynamics_model", "classic"), "num_target_waypoints": int(env_cfg["num_target_waypoints"]), "reward_conditioning": bool(env_cfg["reward_conditioning"]), - "max_partners": int(env_cfg["obs_slots_partners_n"]), + "obs_slots_partners_n": int(env_cfg["obs_slots_partners_n"]), "lane_count": int(lane_count), "boundary_count": int(boundary_count), "traffic_obs_count": int(env_cfg["obs_slots_traffic_controls_n"]), @@ -1197,14 +1203,14 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): let p = base, ego = obs.subarray(p, p+10); p += 10; if (H.reward_conditioning) p += 17; const targetStart = p; p += H.num_target_waypoints * H.target_features; - const partnersStart = p; p += H.max_partners * 8; + const partnersStart = p; p += H.obs_slots_partners_n * 8; const lanesStart = p; p += H.lane_count * 7; const boundsStart = p; p += H.boundary_count * 7; const trafficStart = p; const rot = (x,y) => [-y,x]; const zero = (off,n) => { for(let i=0;i { const out=[]; for(let i=0;i