Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions .github/workflows/update-smoke-golden.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
name: Update Smoke Golden

# Regenerate tests/smoke_tests/data/drive_smoke_golden.json inside the pinned
# QEMU/Haswell image (the only place it is bit-reproducible) and commit it back.
#
# Triggers: a manual run (workflow_dispatch), or a push to an emerge/** or ev/**
# branch whose head commit message contains "[update-golden]", e.g.:
#   git commit --allow-empty -m "regen smoke golden [update-golden]" && git push

on:
  workflow_dispatch:
  push:
    branches:
      - 'emerge/**'
      - 'ev/**'

# The job pushes a commit back to the triggering branch, so it needs write access.
permissions:
  contents: write

jobs:
  update-golden:
    name: Regenerate smoke golden
    # Manual runs always proceed; push-triggered runs are opt-in via the commit message.
    if: github.event_name == 'workflow_dispatch' || contains(github.event.head_commit.message, '[update-golden]')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name }}

      - name: Build smoke image
        run: docker build -f tests/smoke_tests/Dockerfile -t pufferdrive-smoke .

      - name: Regenerate golden
        # The data directory is bind-mounted so the file written inside the
        # container lands in the checked-out workspace.
        run: |
          docker run --rm -e SMOKE_UPDATE_GOLDEN=1 \
            -v "$PWD/tests/smoke_tests/data:/app/tests/smoke_tests/data" \
            pufferdrive-smoke tests/smoke_tests/test_drive_train.py
        timeout-minutes: 30

      - name: Upload golden artifact
        uses: actions/upload-artifact@v4
        with:
          name: drive_smoke_golden
          path: tests/smoke_tests/data/drive_smoke_golden.json

      - name: Commit regenerated golden
        env:
          # Hand the branch name to the script through the environment rather
          # than expanding a workflow expression inside the script text, so a
          # branch name containing shell metacharacters cannot inject commands.
          TARGET_BRANCH: ${{ github.ref_name }}
        run: |
          golden=tests/smoke_tests/data/drive_smoke_golden.json
          # The container writes the golden as root via the bind mount; reclaim
          # it so git can stage/commit cleanly.
          sudo chown "$(id -u):$(id -g)" "$golden"
          # json.dump writes no trailing newline; add one so the committed file
          # satisfies pre-commit's end-of-file-fixer.
          if [ -n "$(tail -c1 "$golden")" ]; then
            printf '\n' >> "$golden"
          fi
          if git diff --quiet -- "$golden"; then
            echo "Golden unchanged; nothing to commit."
            exit 0
          fi
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git add -f "$golden"  # data/ is gitignored; file is tracked
          # Generic message: this workflow is reused for every golden refresh,
          # not just the change that introduced it.
          git commit -m "Regenerate smoke golden"
          git push origin "HEAD:${TARGET_BRANCH}"
12 changes: 11 additions & 1 deletion notebooks/01_observations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,17 @@
"metadata": {},
"outputs": [],
"source": [
"partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"speed\"]\n",
"partner_labels = [\n",
" \"rel_x\",\n",
" \"rel_y\",\n",
" \"rel_z\",\n",
" \"length\",\n",
" \"width\",\n",
" \"heading_cos\",\n",
" \"heading_sin\",\n",
" \"speed\",\n",
" \"seconds_stopped\",\n",
"]\n",
"active_mask = ~np.all(partners == 0, axis=1)\n",
"n_active = active_mask.sum()\n",
"print(f\"Active partners: {n_active}/{env.obs_slots_partners_n}\")\n",
Expand Down
32 changes: 26 additions & 6 deletions notebooks/05_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"- **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit\n",
"- **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints\n",
"- **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint\n",
"- **Partners** (MAX_PARTNERS x 8): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed\n",
"- **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed, seconds_stopped\n",
"- **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin\n",
"- **Boundaries** (MAX_BOUNDS x 7): same as lanes\n",
"- **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state"
Expand Down Expand Up @@ -390,9 +390,19 @@
"# --- Partner summary ---\n",
"n_visible = np.sum(np.any(partners != 0, axis=1))\n",
"print(f\"\\n--- Partners: {n_visible}/{partners.shape[0]} visible ---\")\n",
"partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"width\", \"length\", \"heading_cos\", \"heading_sin\", \"speed\"]\n",
"partner_labels = [\n",
" \"rel_x\",\n",
" \"rel_y\",\n",
" \"rel_z\",\n",
" \"length\",\n",
" \"width\",\n",
" \"heading_cos\",\n",
" \"heading_sin\",\n",
" \"speed\",\n",
" \"seconds_stopped\",\n",
"]\n",
"for p in range(min(int(n_visible), 5)):\n",
" vals = \", \".join(f\"{partner_labels[j]}={partners[p, j]:.3f}\" for j in range(8))\n",
" vals = \", \".join(f\"{partner_labels[j]}={partners[p, j]:.3f}\" for j in range(len(partner_labels)))\n",
" print(f\" [{p}] {vals}\")\n",
"if n_visible > 5:\n",
" print(f\" ... ({n_visible - 5} more)\")\n",
Expand Down Expand Up @@ -847,7 +857,17 @@
"outputs": [],
"source": [
"# Partner per-feature distributions (pooled over all agents + timesteps, visible only)\n",
"partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"width\", \"length\", \"heading_cos\", \"heading_sin\", \"speed\"]\n",
"partner_labels = [\n",
" \"rel_x\",\n",
" \"rel_y\",\n",
" \"rel_z\",\n",
" \"length\",\n",
" \"width\",\n",
" \"heading_cos\",\n",
" \"heading_sin\",\n",
" \"speed\",\n",
" \"seconds_stopped\",\n",
"]\n",
"obs_slots_partners_n = env.obs_slots_partners_n\n",
"pf = env.partner_features\n",
"\n",
Expand All @@ -861,10 +881,10 @@
"\n",
"all_partners = buf_stoch[\"obs\"][:, :, _p_start:_p_end].reshape(\n",
" -1, obs_slots_partners_n, pf\n",
") # (H*N, obs_slots_partners_n, 8)\n",
") # (H*N, obs_slots_partners_n, pf)\n",
"# Mask: partner is visible if any feature != 0\n",
"visible_mask = np.any(all_partners != 0, axis=2) # (H*N, obs_slots_partners_n)\n",
"visible_partners = all_partners[visible_mask] # (K, 8) — all visible partner observations\n",
"visible_partners = all_partners[visible_mask] # (K, pf) — all visible partner observations\n",
"\n",
"print(\n",
" f\"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} \"\n",
Expand Down
23 changes: 17 additions & 6 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
// Observation feature counts
#define EGO_FEATURES 10
#define ROAD_FEATURES 7
#define PARTNER_FEATURES 8
#define PARTNER_FEATURES 9
#define TRAFFIC_CONTROL_FEATURES 7
#define PADDED_OBSERVATION_VALUE -0.001f
#define STATIC_TARGET_FEATURES 3
Expand Down Expand Up @@ -4441,11 +4441,6 @@ static void compute_rewards(Drive *env, int i) {
env->rewards[i] += speed_reward;
env->logs[i].avg_speed_per_agent += agent->sim_speed;
agent->distance_since_spawn += agent->sim_speed * env->dt;
if (agent->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) {
agent->seconds_stopped += env->dt;
} else {
agent->seconds_stopped = 0.0f;
}
env->logs[i].episode_return += speed_reward;
env->logs[i].reward_overspeed += speed_reward;

Expand Down Expand Up @@ -4663,6 +4658,8 @@ static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs,
obs[obs_idx++] = rel_heading_x;
obs[obs_idx++] = rel_heading_y;
obs[obs_idx++] = other->sim_speed_signed / MAX_SPEED;
// TODO(hack): partner seconds_stopped is a temporary feature; remove later.
obs[obs_idx++] = fminf(1.0f, other->seconds_stopped / MAX_STOPPED_SECONDS);
partners_written++;
}

Expand Down Expand Up @@ -5291,6 +5288,20 @@ void c_step(Drive *env) {
// move_expert(env, env->actions, agent_idx);
}

// Update stopped-duration for every agent (active + replayed/static), not
// just policy-controlled ones, so the partner seconds_stopped observation is
// populated even in control_sdc_only mode where only the ego is active.
for (int j = 0; j < env->num_agents; j++) {
int agent_idx = (j < env->active_agent_count) ? env->active_agent_indices[j]
: env->static_agent_indices[j - env->active_agent_count];
Agent *agent = &env->agents[agent_idx];
if (agent->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) {
agent->seconds_stopped += env->dt;
} else {
agent->seconds_stopped = 0.0f;
}
}

// -> 2. Compute metrics and rewards
for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
Expand Down
5 changes: 3 additions & 2 deletions pufferlib/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,7 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"):
"num_target_waypoints": int(env_cfg["num_target_waypoints"]),
"reward_conditioning": bool(env_cfg["reward_conditioning"]),
"obs_slots_partners_n": int(env_cfg["obs_slots_partners_n"]),
"partner_features": int(binding.PARTNER_FEATURES),
"lane_count": int(lane_count),
"boundary_count": int(boundary_count),
"traffic_obs_count": int(env_cfg["obs_slots_traffic_controls_n"]),
Expand Down Expand Up @@ -1203,14 +1204,14 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"):
let p = base, ego = obs.subarray(p, p+10); p += 10;
if (H.reward_conditioning) p += 17;
const targetStart = p; p += H.num_target_waypoints * H.target_features;
const partnersStart = p; p += H.obs_slots_partners_n * 8;
const partnersStart = p; p += H.obs_slots_partners_n * H.partner_features;
const lanesStart = p; p += H.lane_count * 7;
const boundsStart = p; p += H.boundary_count * 7;
const trafficStart = p;
const rot = (x,y) => [-y,x];
const zero = (off,n) => { for(let i=0;i<n;i++) if(obs[off+i] !== 0) return false; return true; };
const roads = (start,count,poolName) => { const out=[]; for(let i=0;i<count;i++){ const o=start+i*7; if(zero(o,7)) continue; let xy=rot(obs[o],obs[o+1]), cs=rot(obs[o+5],obs[o+6]); out.push([xy[0],xy[1],obs[o+3]*H.scales.road_length_to_position,obs[o+4]*H.scales.road_width_to_position,cs[0],cs[1],poolAt(poolName,frame,slot,i)]); } return out; };
const partners = []; for(let i=0;i<H.obs_slots_partners_n;i++){ const o=partnersStart+i*8; if(zero(o,8)) continue; let xy=rot(obs[o],obs[o+1]), h=Math.atan2(obs[o+6],obs[o+5]); h = ((h + Math.PI/2 + Math.PI) % (2*Math.PI)) - Math.PI; partners.push({x:xy[0],y:xy[1],l:obs[o+3]*H.scales.veh_len_to_position,w:obs[o+4]*H.scales.veh_width_to_position,h:h,s:obs[o+7],pool:poolAt("pool_partner",frame,slot,i)}); }
const partners = []; for(let i=0;i<H.obs_slots_partners_n;i++){ const o=partnersStart+i*H.partner_features; if(zero(o,H.partner_features)) continue; let xy=rot(obs[o],obs[o+1]), h=Math.atan2(obs[o+6],obs[o+5]); h = ((h + Math.PI/2 + Math.PI) % (2*Math.PI)) - Math.PI; partners.push({x:xy[0],y:xy[1],l:obs[o+3]*H.scales.veh_len_to_position,w:obs[o+4]*H.scales.veh_width_to_position,h:h,s:obs[o+7],pool:poolAt("pool_partner",frame,slot,i)}); }
const gps = []; for(let i=0;i<H.num_target_waypoints;i++){ const o=targetStart+i*H.target_features; if(zero(o,H.target_features)) continue; let scale=H.target_type === "static" ? H.scales.goal_to_position : 1, xy=rot(obs[o]*scale, obs[o+1]*scale); gps.push(xy); }
const controls = []; for(let i=0;i<H.traffic_obs_count;i++){ const o=trafficStart+i*7; if(zero(o,7)) continue; let a=rot(obs[o],obs[o+1]), b=rot(obs[o+2],obs[o+3]); controls.push({type:obs[o+5], state:obs[o+6], x1:a[0], y1:a[1], x2:b[0], y2:b[1], pool:poolAt("pool_traffic",frame,slot,i)}); }
return {ego:{s:ego[0],w:ego[1]*H.scales.veh_width_to_position,l:ego[2]*H.scales.veh_len_to_position,st:ego[3],al:ego[4],alat:ego[5]}, partners, lanes:roads(lanesStart,H.lane_count,"pool_lane"), bounds:roads(boundsStart,H.boundary_count,"pool_boundary"), gps, traffic_controls:controls};
Expand Down
57 changes: 36 additions & 21 deletions tests/smoke_tests/data/drive_smoke_golden.json
Original file line number Diff line number Diff line change
@@ -1,33 +1,48 @@
{
"env": {
"avg_distance_per_infraction": 13.20275001525879,
"avg_speed_per_agent": 1.3609784364700317,
"collision_rate": 0.0125,
"comfort_violation_count": 0.7306405305862427,
"dnf_rate": 0.55,
"episode_length": 43.6,
"episode_return": -2.2548463344573975,
"lane_center_rate": 0.6961542010307312,
"avg_distance_per_infraction": 11.850770235061646,
"avg_speed_per_agent": 1.3407511115074158,
"collision_rate": 0.0625,
"comfort_violation_count": 0.7286259929339091,
"dnf_rate": 0.5520833333333334,
"episode_length": 39.0,
"episode_return": -2.2371936639149985,
"lane_center_rate": 0.7488949696222941,
"n": 16.0,
"num_goals_reached": 0.0,
"offroad_rate": 0.3625,
"red_light_violation_rate": 0.075,
"obs/max": 4.0,
"obs/mean": 0.143798382836394,
"obs/min": -1.2310400689020753,
"offroad_rate": 0.3541666666666667,
"red_light_violation_rate": 0.03125,
"reward_components/ade": 0.0,
"reward_components/collision": -0.10453646133343379,
"reward_components/comfort": -1.432812213897705,
"reward_components/goal": 0.0,
"reward_components/lane_align": -0.12008568147818248,
"reward_components/lane_center": -0.0017940285421597462,
"reward_components/offroad": -0.53125,
"reward_components/overspeed": 0.0,
"reward_components/red_light": -0.03125,
"reward_components/reverse": -0.015380206052213907,
"reward_components/timestep": -8.554685700801201e-05,
"reward_components/velocity": 0.0,
"score": 0.0,
"velocity_progress_sum": 0.0
},
"losses": {
"approx_kl": 0.0004938042755903942,
"approx_kl": 0.00031589248516995994,
"clipfrac": 0.0,
"ema_max": 1.191539317369461,
"entropy": 2.4832939420427596,
"explained_variance": 0.18538302183151245,
"filter_threshold": 0.01191539317369461,
"filtered_fraction": 0.023398328690807824,
"kept_fraction": 0.9766016713091922,
"masked_fraction": 0.12353515625,
"old_approx_kl": 0.000515544121818883,
"policy_loss": -0.0005634417258469122,
"value_loss": 0.08380150688546044
"ema_max": 1.1163338720798492,
"entropy": 2.483056511197771,
"explained_variance": 0.16990125179290771,
"filter_threshold": 0.011163338720798492,
"filtered_fraction": 0.03224013340744858,
"kept_fraction": 0.9677598665925514,
"masked_fraction": 0.12158203125,
"old_approx_kl": 0.0005578249381090115,
"policy_loss": -0.001066769240424037,
"value_loss": 0.06768968754581042
},
"meta": {
"bptt_horizon": 64,
Expand Down
Loading