Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
PIP_NO_CACHE_DIR: 1
run: |
sudo apt-get update && sudo apt-get install -y build-essential cmake
python -m pip install -U pip pytest jupytext nbclient ipykernel ipywidgets
python -m pip install -U pip pytest jupytext nbclient ipykernel
pip install -e . --no-cache-dir
python setup.py build_ext --inplace --force

Expand Down
4 changes: 3 additions & 1 deletion notebooks/01_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
idx += env.obs_slots_traffic_controls_n * env.traffic_control_features
assert np.allclose(traffic_manual, traffic), "traffic mismatch"

idx += 4 # appended slot-count features at end of obs
assert idx == obs.shape[1], f"obs size mismatch: used {idx}, total {obs.shape[1]}"
print(f"All slices match. Total features used: {idx}")

Expand Down Expand Up @@ -175,7 +176,8 @@
"width",
"heading_cos",
"heading_sin",
"speed",
"rel_vx",
"rel_vy",
"seconds_stopped",
]
active_mask = ~np.all(partners == 0, axis=1)
Expand Down
4 changes: 2 additions & 2 deletions notebooks/04_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
f"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]"
)

cond_dim = backbone.conditioning_dim
cond_dim = backbone.target_dim
if cond_dim > 0:
cond_obs = x[:, slide_idx : slide_idx + cond_dim]
slide_idx += cond_dim
Expand Down Expand Up @@ -143,7 +143,7 @@

if cond_dim > 0:
with torch.no_grad():
cond_enc = backbone.conditioning_encoder(cond_obs)
cond_enc = backbone.target_encoder(cond_obs)
print(
f"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]"
)
Expand Down
48 changes: 27 additions & 21 deletions notebooks/05_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON):
# - **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit
# - **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints
# - **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint
# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed, seconds_stopped
# - **Partners** (MAX_PARTNERS x 10): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, rel_vx, rel_vy, seconds_stopped
# - **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin
# - **Boundaries** (MAX_BOUNDS x 7): same as lanes
# - **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state
Expand Down Expand Up @@ -351,11 +351,12 @@ def layer_stats(name, arr):
"width",
"heading_cos",
"heading_sin",
"speed",
"rel_vx",
"rel_vy",
"seconds_stopped",
]
for p in range(min(int(n_visible), 5)):
vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(len(partner_labels)))
vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(env.partner_features))
print(f" [{p}] {vals}")
if n_visible > 5:
print(f" ... ({n_visible - 5} more)")
Expand Down Expand Up @@ -631,7 +632,7 @@ def unpack_all_timesteps(bufs, agent_idx):
for i in range(partners.shape[0]):
if np.allclose(partners[i], 0):
continue
rx, ry, rz, w, l, hc, hs, vx, vy = partners[i]
rx, ry, rz, w, l, hc, hs, vx, vy, _ = partners[i]
heading = np.arctan2(hs, hc)
rect = Rectangle((-l / 2, -w / 2), l, w, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9)
rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData)
Expand Down Expand Up @@ -773,7 +774,8 @@ def unpack_all_timesteps(bufs, agent_idx):
"width",
"heading_cos",
"heading_sin",
"speed",
"rel_vx",
"rel_vy",
"seconds_stopped",
]
obs_slots_partners_n = env.obs_slots_partners_n
Expand All @@ -789,17 +791,17 @@ def unpack_all_timesteps(bufs, agent_idx):

all_partners = buf_stoch["obs"][:, :, _p_start:_p_end].reshape(
-1, obs_slots_partners_n, pf
) # (H*N, obs_slots_partners_n, pf)
) # (H*N, obs_slots_partners_n, 10)
# Mask: partner is visible if any feature != 0
visible_mask = np.any(all_partners != 0, axis=2) # (H*N, 16)
visible_partners = all_partners[visible_mask] # (K, pf) — all visible partner observations
visible_partners = all_partners[visible_mask] # (K, 10) — all visible partner observations

print(
f"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} "
f"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)"
)

fig, axes = plt.subplots(3, 3, figsize=(21, 10))
fig, axes = plt.subplots(3, 4, figsize=(21, 11))
axes = axes.flatten()

for i, label in enumerate(partner_labels):
Expand All @@ -811,12 +813,16 @@ def unpack_all_timesteps(bufs, agent_idx):
axes[i].tick_params(labelsize=7)

# rel_x vs rel_y scatter in the first panel after the per-feature histograms
axes[8].scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange")
axes[8].set_xlabel("rel_x")
axes[8].set_ylabel("rel_y")
axes[8].set_title("Partner positions (ego frame)")
axes[8].set_aspect("equal")
axes[8].grid(True, alpha=0.3)
pos_ax = axes[len(partner_labels)]
pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange")
pos_ax.set_xlabel("rel_x")
pos_ax.set_ylabel("rel_y")
pos_ax.set_title("Partner positions (ego frame)")
pos_ax.set_aspect("equal")
pos_ax.grid(True, alpha=0.3)

for ax in axes[len(partner_labels) + 1 :]:
ax.axis("off")

fig.suptitle("Partner features: all visible, full rollout", fontsize=13)
plt.tight_layout()
Expand Down Expand Up @@ -1303,7 +1309,7 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
# %% [markdown]
# ## Encoder analysis — what the policy encodes
#
# Each obs layer has its own encoder projecting raw features → `input_size` embedding:
# Each obs layer has its own encoder projecting raw features → an embedding of that encoder's output width:
# - **ego** and **conditioning** (reward coefs + target): single vector, no pooling.
# - **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed.
#
Expand Down Expand Up @@ -1343,8 +1349,8 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
True,
)
)
if bb.conditioning_dim > 0:
enc_inventory.append(("conditioning", bb.conditioning_encoder, bb.conditioning_dim, 1, False))
if bb.target_dim > 0:
enc_inventory.append(("conditioning", bb.target_encoder, bb.target_dim, 1, False))

enc_names = [n for n, *_ in enc_inventory]
set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set]
Expand All @@ -1354,10 +1360,10 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
for name, mod, rin, nslots, is_set in enc_inventory:
nparam = sum(p.numel() for p in mod.parameters())
print(
f"{name:>13s} | {rin:>6d} | {bb.input_size:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}"
f"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}"
)
print(
f"\nBackbone input = {len(enc_inventory)} x {bb.input_size} = {len(enc_inventory) * bb.input_size} -> backbone -> {bb.out_dim}"
f"\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}"
)

# Capture pre-pool encoder outputs via forward hooks
Expand All @@ -1383,7 +1389,7 @@ def fn(m, i, o):
lane_dim = bb.obs_slots_lane_kept * bb.road_features_count
boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count
traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count
_s = ego_dim + bb.conditioning_dim
_s = ego_dim + bb.target_dim
sl = {}
sl["partner"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count)
_s += partner_dim
Expand Down Expand Up @@ -1413,7 +1419,7 @@ def fn(m, i, o):
masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf)
vm = (~pad[name]).any(dim=1)
valid_sample[name] = vm
winners[name] = masked.max(dim=1).indices # (B, input_size): winning slot per dim
winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim
pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values))

for name in ("ego", "conditioning"):
Expand Down
Loading
Loading