Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 57 additions & 111 deletions notebooks/01_observations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,65 +16,10 @@
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pufferlib.ocean.drive import binding\n",
"import pufferlib.viz\n",
"from notebooks.notebook_utils import (\n",
" COEF_NAMES,\n",
" EGO_LABELS,\n",
" make_drive_env,\n",
" notebook_dims,\n",
" random_actions,\n",
" zero_actions,\n",
")\n",
"\n",
"# --- Environment configuration ---\n",
"NUM_AGENTS = 64\n",
"SIMULATION_MODE = \"gigaflow\"\n",
"DYNAMICS_MODEL = \"jerk\"\n",
"ACTION_TYPE = \"discrete\"\n",
"DT = 0.1\n",
"SCENARIO_LENGTH = 512\n",
"RESAMPLE_FREQUENCY = 0\n",
"REWARD_CONDITIONING = True\n",
"REWARD_RANDOMIZATION = False\n",
"TARGET_TYPE = \"static\"\n",
"COLLISION_BEHAVIOR = 1\n",
"OFFROAD_BEHAVIOR = 1\n",
"SEED = 42\n",
"\n",
"# --- Observation dimensions ---\n",
"MAX_PARTNERS = 16\n",
"MAX_LANES = 32\n",
"MAX_BOUNDS = 32\n",
"MAX_TRAFFIC = 4\n",
"\n",
"env, obs, info = make_drive_env(\n",
" num_agents=NUM_AGENTS,\n",
" min_agents_per_env=NUM_AGENTS,\n",
" max_agents_per_env=NUM_AGENTS,\n",
" simulation_mode=SIMULATION_MODE,\n",
" dynamics_model=DYNAMICS_MODEL,\n",
" action_type=ACTION_TYPE,\n",
" dt=DT,\n",
" scenario_length=SCENARIO_LENGTH,\n",
" resample_frequency=RESAMPLE_FREQUENCY,\n",
" reward_conditioning=REWARD_CONDITIONING,\n",
" reward_randomization=REWARD_RANDOMIZATION,\n",
" target_type=TARGET_TYPE,\n",
" collision_behavior=COLLISION_BEHAVIOR,\n",
" offroad_behavior=OFFROAD_BEHAVIOR,\n",
" obs_slots_lane_n=MAX_LANES,\n",
" obs_slots_boundary_n=MAX_BOUNDS,\n",
" obs_slots_partners_n=MAX_PARTNERS,\n",
" obs_slots_traffic_controls_n=MAX_TRAFFIC,\n",
" seed=SEED,\n",
")\n",
"globals().update(notebook_dims(env))\n",
"from pufferlib.viz import plot_observation, plot_simulator_state, unpack_obs\n",
"from notebooks.notebook_utils import COEF_NAMES, make_drive_env, zero_actions\n",
"\n",
"print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n",
"print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n",
"print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n",
"print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")"
"env, obs, info = make_drive_env()"
]
},
{
Expand Down Expand Up @@ -127,15 +72,15 @@
"metadata": {},
"outputs": [],
"source": [
"ego, target, partners, lanes, boundaries, traffic = pufferlib.viz.unpack_obs(\n",
"ego, target, partners, lanes, boundaries, traffic = unpack_obs(\n",
" obs[:1],\n",
" target_type=TARGET_TYPE,\n",
" reward_conditioning=REWARD_CONDITIONING,\n",
" target_type=env.target_type,\n",
" reward_conditioning=env.reward_conditioning,\n",
" num_target_waypoints=env.num_target_waypoints,\n",
" max_partners=MAX_PARTNERS,\n",
" max_lane_segments=MAX_LANES,\n",
" max_boundary_segments=MAX_BOUNDS,\n",
" obs_slots_traffic_controls_n=MAX_TRAFFIC,\n",
" obs_slots_partners_n=env.obs_slots_partners_n,\n",
" obs_slots_lane_n=env.obs_slots_lane_kept,\n",
" obs_slots_boundary_n=env.obs_slots_boundary_kept,\n",
" obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n",
")\n",
"print(f\"ego: {ego.shape} = {ego}\")\n",
"print(f\"target: {target.shape}\")\n",
Expand Down Expand Up @@ -177,37 +122,47 @@
"idx = 0\n",
"\n",
"# Ego\n",
"ego_manual = o[idx : idx + EGO_DIM]\n",
"idx += EGO_DIM\n",
"ego_manual = o[idx : idx + env.ego_features]\n",
"idx += env.ego_features\n",
"assert np.allclose(ego_manual, ego), f\"ego mismatch: {ego_manual} vs {ego}\"\n",
"\n",
"# Reward conditioning coefs\n",
"coefs_manual = o[idx : idx + NUM_COEFS]\n",
"idx += NUM_COEFS\n",
"coefs_manual = o[idx : idx + env.num_reward_coefs]\n",
"idx += env.num_reward_coefs\n",
"\n",
"# Target\n",
"target_manual = o[idx : idx + MAX_TARGET * TARGET_F].reshape(MAX_TARGET, TARGET_F)\n",
"idx += MAX_TARGET * TARGET_F\n",
"target_manual = o[idx : idx + env.num_target_waypoints * env.target_features].reshape(\n",
" env.num_target_waypoints, env.target_features\n",
")\n",
"idx += env.num_target_waypoints * env.target_features\n",
"assert np.allclose(target_manual, target), \"target mismatch\"\n",
"\n",
"# Partners\n",
"partners_manual = o[idx : idx + MAX_PARTNERS * PARTNER_F].reshape(MAX_PARTNERS, PARTNER_F)\n",
"idx += MAX_PARTNERS * PARTNER_F\n",
"partners_manual = o[idx : idx + env.obs_slots_partners_n * env.partner_features].reshape(\n",
" env.obs_slots_partners_n, env.partner_features\n",
")\n",
"idx += env.obs_slots_partners_n * env.partner_features\n",
"assert np.allclose(partners_manual, partners), \"partners mismatch\"\n",
"\n",
"# Lanes\n",
"lanes_manual = o[idx : idx + MAX_LANES * ROAD_F].reshape(MAX_LANES, ROAD_F)\n",
"idx += MAX_LANES * ROAD_F\n",
"lanes_manual = o[idx : idx + env.obs_slots_lane_kept * env.road_features].reshape(\n",
" env.obs_slots_lane_kept, env.road_features\n",
")\n",
"idx += env.obs_slots_lane_kept * env.road_features\n",
"assert np.allclose(lanes_manual, lanes), \"lanes mismatch\"\n",
"\n",
"# Boundaries\n",
"bounds_manual = o[idx : idx + MAX_BOUNDS * ROAD_F].reshape(MAX_BOUNDS, ROAD_F)\n",
"idx += MAX_BOUNDS * ROAD_F\n",
"bounds_manual = o[idx : idx + env.obs_slots_boundary_kept * env.road_features].reshape(\n",
" env.obs_slots_boundary_kept, env.road_features\n",
")\n",
"idx += env.obs_slots_boundary_kept * env.road_features\n",
"assert np.allclose(bounds_manual, boundaries), \"boundaries mismatch\"\n",
"\n",
"# Traffic\n",
"traffic_manual = o[idx : idx + MAX_TRAFFIC * TRAFFIC_CONTROL_F].reshape(MAX_TRAFFIC, TRAFFIC_CONTROL_F)\n",
"idx += MAX_TRAFFIC * TRAFFIC_CONTROL_F\n",
"traffic_manual = o[idx : idx + env.obs_slots_traffic_controls_n * env.traffic_control_features].reshape(\n",
" env.obs_slots_traffic_controls_n, env.traffic_control_features\n",
")\n",
"idx += env.obs_slots_traffic_controls_n * env.traffic_control_features\n",
"assert np.allclose(traffic_manual, traffic), \"traffic mismatch\"\n",
"\n",
"assert idx == obs.shape[1], f\"obs size mismatch: used {idx}, total {obs.shape[1]}\"\n",
Expand All @@ -227,9 +182,9 @@
"metadata": {},
"outputs": [],
"source": [
"coefs = obs[0, EGO_DIM : EGO_DIM + NUM_COEFS]\n",
"coefs = obs[0, env.ego_features : env.ego_features + env.num_reward_coefs]\n",
"fig, ax = plt.subplots(figsize=(12, 4))\n",
"bars = ax.bar(range(NUM_COEFS), coefs, tick_label=COEF_NAMES)\n",
"bars = ax.bar(range(env.num_reward_coefs), coefs, tick_label=COEF_NAMES)\n",
"ax.set_ylabel(\"Normalized coef value\")\n",
"ax.set_title(\"Reward conditioning coefficients (agent 0)\")\n",
"plt.xticks(rotation=45, ha=\"right\")\n",
Expand All @@ -239,7 +194,7 @@
"plt.show()\n",
"\n",
"# Compare across agents\n",
"all_coefs = obs[:, EGO_DIM : EGO_DIM + NUM_COEFS]\n",
"all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs]\n",
"print(\"Coef stats across agents:\")\n",
"for i, name in enumerate(COEF_NAMES):\n",
" c = all_coefs[:, i]\n",
Expand All @@ -262,13 +217,13 @@
"partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"speed\"]\n",
"active_mask = ~np.all(partners == 0, axis=1)\n",
"n_active = active_mask.sum()\n",
"print(f\"Active partners: {n_active}/{MAX_PARTNERS}\")\n",
"print(f\"Active partners: {n_active}/{env.obs_slots_partners_n}\")\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Heatmap\n",
"im = axes[0].imshow(partners, aspect=\"auto\", cmap=\"RdBu_r\", vmin=-1, vmax=1)\n",
"axes[0].set_xticks(range(PARTNER_F))\n",
"axes[0].set_xticks(range(env.partner_features))\n",
"axes[0].set_xticklabels(partner_labels, rotation=45, ha=\"right\")\n",
"axes[0].set_ylabel(\"Partner index\")\n",
"axes[0].set_title(f\"Partner obs heatmap ({n_active} active)\")\n",
Expand Down Expand Up @@ -308,7 +263,9 @@
"\n",
"lane_active = ~np.all(lanes == 0, axis=1)\n",
"bound_active = ~np.all(boundaries == 0, axis=1)\n",
"print(f\"Active lanes: {lane_active.sum()}/{MAX_LANES}, boundaries: {bound_active.sum()}/{MAX_BOUNDS}\")\n",
"print(\n",
" f\"Active lanes: {lane_active.sum()}/{env.obs_slots_lane_kept}, boundaries: {bound_active.sum()}/{env.obs_slots_boundary_kept}\"\n",
")\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 10))\n",
"\n",
Expand Down Expand Up @@ -367,15 +324,15 @@
"metadata": {},
"outputs": [],
"source": [
"img = pufferlib.viz.plot_observation(\n",
"img = plot_observation(\n",
" obs[:1],\n",
" target_type=TARGET_TYPE,\n",
" reward_conditioning=True,\n",
" target_type=env.target_type,\n",
" reward_conditioning=env.reward_conditioning,\n",
" num_target_waypoints=env.num_target_waypoints,\n",
" max_partners=MAX_PARTNERS,\n",
" max_lane_segments=MAX_LANES,\n",
" max_boundary_segments=MAX_BOUNDS,\n",
" obs_slots_traffic_controls_n=MAX_TRAFFIC,\n",
" obs_slots_partners_n=env.obs_slots_partners_n,\n",
" obs_slots_lane_n=env.obs_slots_lane_kept,\n",
" obs_slots_boundary_n=env.obs_slots_boundary_kept,\n",
" obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n",
")\n",
"fig, ax = plt.subplots(figsize=(10, 10))\n",
"ax.imshow(img)\n",
Expand Down Expand Up @@ -405,7 +362,7 @@
"else:\n",
" scenario = scenarios\n",
"\n",
"img = pufferlib.viz.plot_simulator_state(scenario, timestep=0)\n",
"img = plot_simulator_state(scenario)\n",
"fig, ax = plt.subplots(figsize=(12, 12))\n",
"ax.imshow(img)\n",
"ax.axis(\"off\")\n",
Expand All @@ -428,23 +385,12 @@
"outputs": [],
"source": [
"N_STEPS = 20\n",
"ego_labels = [\n",
" \"speed\",\n",
" \"width\",\n",
" \"length\",\n",
" \"steering\",\n",
" \"a_long\",\n",
" \"a_lat\",\n",
" \"lane_center_dist\",\n",
" \"lane_heading_cos\",\n",
" \"speed_limit\",\n",
"]\n",
"ego_history = np.zeros((N_STEPS, EGO_DIM))\n",
"ego_history = np.zeros((N_STEPS, env.ego_features))\n",
"\n",
"for t in range(N_STEPS):\n",
" actions = zero_actions(env)\n",
" obs_t, _, _, _, _ = env.step(actions)\n",
" ego_history[t] = obs_t[0, :EGO_DIM]\n",
" ego_history[t] = obs_t[0, : env.ego_features]\n",
"\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 8))\n",
"# Speed\n",
Expand Down Expand Up @@ -486,16 +432,16 @@
"speeds = obs[:, 0] # speed is at index 0\n",
"\n",
"# Target waypoints start after ego + reward coefs\n",
"target_start = EGO_DIM + NUM_COEFS\n",
"target_start = env.ego_features + env.num_reward_coefs\n",
"# Each target waypoint has TARGET_F features; first two are rel_x, rel_y\n",
"first_target_x = obs[:, target_start]\n",
"first_target_y = obs[:, target_start + 1]\n",
"target_dists = np.sqrt(first_target_x**2 + first_target_y**2)\n",
"\n",
"# Count active partners per agent\n",
"partner_start = EGO_DIM + NUM_COEFS + TARGET_DIM\n",
"partner_end = partner_start + MAX_PARTNERS * PARTNER_F\n",
"all_partners = obs[:, partner_start:partner_end].reshape(-1, MAX_PARTNERS, PARTNER_F)\n",
"partner_start = env.ego_features + env.num_reward_coefs + env.num_target_waypoints * env.target_features\n",
"partner_end = partner_start + env.obs_slots_partners_n * env.partner_features\n",
"all_partners = obs[:, partner_start:partner_end].reshape(-1, env.obs_slots_partners_n, env.partner_features)\n",
"partner_counts = (~np.all(all_partners == 0, axis=2)).sum(axis=1)\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
Expand All @@ -507,7 +453,7 @@
"axes[1].set_title(\"Distance to first target waypoint\")\n",
"axes[1].set_xlabel(\"distance\")\n",
"\n",
"axes[2].hist(partner_counts, bins=range(MAX_PARTNERS + 2), edgecolor=\"black\", alpha=0.7, color=\"green\")\n",
"axes[2].hist(partner_counts, bins=range(env.obs_slots_partners_n + 2), edgecolor=\"black\", alpha=0.7, color=\"green\")\n",
"axes[2].set_title(\"Active partners per agent\")\n",
"axes[2].set_xlabel(\"count\")\n",
"plt.tight_layout()\n",
Expand Down
Loading
Loading