diff --git a/.dockerignore b/.dockerignore index 3a92c72017..a896e4ff2d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,3 +7,42 @@ build *.egg-info experiments wandb +.neptune +.pytest_cache +.ruff_cache + +benchmark*/ +!pufferlib/ocean/benchmark/ +!pufferlib/ocean/benchmark/** +runs*/ +weights/ +checkpoints/ +/data/ +!/tests/smoke_tests/data/ +!/tests/smoke_tests/data/** +/artifacts/ +external/ + +pufferlib/resources/drive/binaries/*/ +!pufferlib/resources/drive/binaries/carla/ +!pufferlib/resources/drive/binaries/carla/** +!pufferlib/resources/drive/binaries/dense/ +!pufferlib/resources/drive/binaries/dense/** +!pufferlib/resources/drive/binaries/lateral/ +!pufferlib/resources/drive/binaries/lateral/** +!pufferlib/resources/drive/binaries/longitudinal/ +!pufferlib/resources/drive/binaries/longitudinal/** +!pufferlib/resources/drive/binaries/nuplan/ +!pufferlib/resources/drive/binaries/nuplan/** +!pufferlib/resources/drive/binaries/obstacles/ +!pufferlib/resources/drive/binaries/obstacles/** +!pufferlib/resources/drive/binaries/vru/ +!pufferlib/resources/drive/binaries/vru/** + +pufferlib/resources/drive/output*.gif +pufferlib/resources/drive/pufferdrive_*.gif +*.mp4 +*.mov +*.webm +*.avi +*.mkv diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d239b9a4a..5927bf2f08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,7 +66,7 @@ jobs: PIP_NO_CACHE_DIR: 1 run: | sudo apt-get update && sudo apt-get install -y build-essential cmake - python -m pip install -U pip pytest jupytext nbclient ipykernel ipywidgets + python -m pip install -U pip pytest jupytext nbclient ipykernel pip install -e . --no-cache-dir python setup.py build_ext --inplace --force diff --git a/.gitignore b/.gitignore index be0cbb22bd..baa6caa7bf 100644 --- a/.gitignore +++ b/.gitignore @@ -169,6 +169,7 @@ pufferlib/ocean/impulse_wars/benchmark/ data/ pufferlib/resources/drive/binaries/*/ !pufferlib/resources/drive/binaries/carla/ +!pufferlib/resources/drive/binaries/carla_lhs/ !pufferlib/resources/drive/binaries/carla/** # Re-ignore .DS_Store inside carla binaries pufferlib/resources/drive/binaries/carla/.DS_Store diff --git a/notebooks/01_observations.py b/notebooks/01_observations.py index 7a2f99d60b..6677e62fb3 100644 --- a/notebooks/01_observations.py +++ b/notebooks/01_observations.py @@ -138,6 +138,7 @@ idx += env.obs_slots_traffic_controls_n * env.traffic_control_features assert np.allclose(traffic_manual, traffic), "traffic mismatch" +idx += 4 # appended slot-count features at end of obs assert idx == obs.shape[1], f"obs size mismatch: used {idx}, total {obs.shape[1]}" print(f"All slices match. Total features used: {idx}") @@ -175,7 +176,7 @@ "width", "heading_cos", "heading_sin", - "speed", + "sim_speed_signed", "seconds_stopped", ] active_mask = ~np.all(partners == 0, axis=1) diff --git a/notebooks/04_training.py b/notebooks/04_training.py index 7a1493a24d..c1378bd2b1 100644 --- a/notebooks/04_training.py +++ b/notebooks/04_training.py @@ -109,7 +109,7 @@ f"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]" ) -cond_dim = backbone.conditioning_dim +cond_dim = backbone.target_dim if cond_dim > 0: cond_obs = x[:, slide_idx : slide_idx + cond_dim] slide_idx += cond_dim @@ -143,7 +143,7 @@ if cond_dim > 0: with torch.no_grad(): - cond_enc = backbone.conditioning_encoder(cond_obs) + cond_enc = backbone.target_encoder(cond_obs) print( f"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]" ) diff --git a/notebooks/05_inference.py b/notebooks/05_inference.py index f074c28c90..26d6bc2795 100644 --- a/notebooks/05_inference.py +++ b/notebooks/05_inference.py @@ -259,7 +259,7 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): # - **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit # - **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints # - **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint -# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed, seconds_stopped +# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, sim_speed_signed, seconds_stopped # - **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin # - **Boundaries** (MAX_BOUNDS x 7): same as lanes # - **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state @@ -351,11 +351,11 @@ def layer_stats(name, arr): "width", "heading_cos", "heading_sin", - "speed", + "sim_speed_signed", "seconds_stopped", ] for p in range(min(int(n_visible), 5)): - vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(len(partner_labels))) + vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(env.partner_features)) print(f" [{p}] {vals}") if n_visible > 5: print(f" ... ({n_visible - 5} more)") @@ -631,12 +631,14 @@ def unpack_all_timesteps(bufs, agent_idx): for i in range(partners.shape[0]): if np.allclose(partners[i], 0): continue - rx, ry, rz, w, l, hc, hs, vx, vy = partners[i] + rx, ry, rz, length, width, hc, hs, speed, _ = partners[i] heading = np.arctan2(hs, hc) - rect = Rectangle((-l / 2, -w / 2), l, w, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9) + rect = Rectangle( + (-length / 2, -width / 2), length, width, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9 + ) rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData) ax.add_patch(rect) - ax.annotate(f"{vx:.2f}, {vy:.2f}", (rx, ry), fontsize=7, ha="center", color="darkred", zorder=12) + ax.annotate(f"{speed:.2f}", (rx, ry), fontsize=7, ha="center", color="darkred", zorder=12) part_mask = np.any(partners != 0, axis=1) if part_mask.any(): ax.scatter( @@ -773,7 +775,7 @@ def unpack_all_timesteps(bufs, agent_idx): "width", "heading_cos", "heading_sin", - "speed", + "sim_speed_signed", "seconds_stopped", ] obs_slots_partners_n = env.obs_slots_partners_n @@ -799,7 +801,7 @@ def unpack_all_timesteps(bufs, agent_idx): f"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)" ) -fig, axes = plt.subplots(3, 3, figsize=(21, 10)) +fig, axes = plt.subplots(3, 4, figsize=(21, 11)) axes = axes.flatten() for i, label in enumerate(partner_labels): @@ -811,12 +813,16 @@ def unpack_all_timesteps(bufs, agent_idx): axes[i].tick_params(labelsize=7) # rel_x vs rel_y scatter in last panel -axes[8].scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange") -axes[8].set_xlabel("rel_x") -axes[8].set_ylabel("rel_y") -axes[8].set_title("Partner positions (ego frame)") -axes[8].set_aspect("equal") -axes[8].grid(True, alpha=0.3) +pos_ax = axes[len(partner_labels)] +pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange") +pos_ax.set_xlabel("rel_x") +pos_ax.set_ylabel("rel_y") +pos_ax.set_title("Partner positions (ego frame)") +pos_ax.set_aspect("equal") +pos_ax.grid(True, alpha=0.3) + +for ax in axes[len(partner_labels) + 1 :]: + ax.axis("off") fig.suptitle("Partner features: all visible, full rollout", fontsize=13) plt.tight_layout() @@ -1303,7 +1309,7 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam): # %% [markdown] # ## Encoder analysis — what the policy encodes # -# Each obs layer has its own encoder projecting raw features → `input_size` embedding: +# Each obs layer has its own encoder projecting raw features → embedding width: # - **ego** and **conditioning** (reward coefs + target): single vector, no pooling. # - **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed. # @@ -1343,8 +1349,8 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam): True, ) ) -if bb.conditioning_dim > 0: - enc_inventory.append(("conditioning", bb.conditioning_encoder, bb.conditioning_dim, 1, False)) +if bb.target_dim > 0: + enc_inventory.append(("conditioning", bb.target_encoder, bb.target_dim, 1, False)) enc_names = [n for n, *_ in enc_inventory] set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set] @@ -1354,10 +1360,10 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam): for name, mod, rin, nslots, is_set in enc_inventory: nparam = sum(p.numel() for p in mod.parameters()) print( - f"{name:>13s} | {rin:>6d} | {bb.input_size:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}" + f"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}" ) print( - f"\nBackbone input = {len(enc_inventory)} x {bb.input_size} = {len(enc_inventory) * bb.input_size} -> backbone -> {bb.out_dim}" + f"\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}" ) # Capture pre-pool encoder outputs via forward hooks @@ -1383,7 +1389,7 @@ def fn(m, i, o): lane_dim = bb.obs_slots_lane_kept * bb.road_features_count boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count -_s = ego_dim + bb.conditioning_dim +_s = ego_dim + bb.target_dim sl = {} sl["partner"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count) _s += partner_dim @@ -1413,7 +1419,7 @@ def fn(m, i, o): masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf) vm = (~pad[name]).any(dim=1) valid_sample[name] = vm - winners[name] = masked.max(dim=1).indices # (B, input_size): winning slot per dim + winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values)) for name in ("ego", "conditioning"): diff --git a/notebooks/06_architecture.py b/notebooks/06_architecture.py index f8308a10be..ddcd8eac41 100644 --- a/notebooks/06_architecture.py +++ b/notebooks/06_architecture.py @@ -34,35 +34,46 @@ ACTOR_NUM_LAYERS = 3 CRITIC_HIDDEN_SIZE = 64 CRITIC_NUM_LAYERS = 2 -SPLIT_NETWORK = False -ENCODER_GIGAFLOW = True -DROPOUT = 0.0 +SHARED_NETWORK = True +ENCODER_ACTIVATION = "tanh" +ENCODER_LAYER_NORM = True +BACKBONE_ACTIVATION = "gelu" +BACKBONE_LAYER_NORM = False +MASK_PADDED_FEATURES = False env, obs, info = make_drive_env() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") policy = DrivePolicy( env, - input_size=INPUT_SIZE, + ego_input_size=INPUT_SIZE, + partner_input_size=INPUT_SIZE, + lane_input_size=INPUT_SIZE, + boundary_input_size=INPUT_SIZE, + traffic_control_input_size=INPUT_SIZE, + target_input_size=INPUT_SIZE, backbone_hidden_size=BACKBONE_HIDDEN_SIZE, backbone_num_layers=BACKBONE_NUM_LAYERS, actor_hidden_size=ACTOR_HIDDEN_SIZE, actor_num_layers=ACTOR_NUM_LAYERS, critic_hidden_size=CRITIC_HIDDEN_SIZE, critic_num_layers=CRITIC_NUM_LAYERS, - split_network=SPLIT_NETWORK, - encoder_gigaflow=ENCODER_GIGAFLOW, - dropout=DROPOUT, + encoder_activation=ENCODER_ACTIVATION, + encoder_layer_norm=ENCODER_LAYER_NORM, + backbone_activation=BACKBONE_ACTIVATION, + backbone_layer_norm=BACKBONE_LAYER_NORM, + shared_network=SHARED_NETWORK, + mask_padded_features=MASK_PADDED_FEATURES, ).to(device) print(f"Device: {device}") print(f"Obs dim: {obs.shape[1]}") print(f"Action dim: {policy.atn_dim}") -print(f"Split network: {SPLIT_NETWORK}") +print(f"Shared network: {SHARED_NETWORK}") print(f"Backbone: {BACKBONE_HIDDEN_SIZE} x {BACKBONE_NUM_LAYERS}L") print(f"Actor: {ACTOR_HIDDEN_SIZE} x {ACTOR_NUM_LAYERS}L") print(f"Critic: {CRITIC_HIDDEN_SIZE} x {CRITIC_NUM_LAYERS}L") -print(f"Encoder gigaflow: {ENCODER_GIGAFLOW}, Dropout: {DROPOUT}") +print(f"Encoder: {ENCODER_ACTIVATION}, LayerNorm: {ENCODER_LAYER_NORM}") # %% [markdown] # ## Model Summary (torchinfo) @@ -76,22 +87,21 @@ # %% backbone = policy.actor_backbone -cond_dim = backbone.conditioning_dim +cond_dim = backbone.target_dim -# Collect encoder info — encoder_gigaflow adds Tanh+Dropout between LN and second Linear -# ego, partner, conditioning use encoder_gigaflow; lane, boundary, traffic_ctrl use dropout +# Collect encoder info encoders = [ - ("ego", env.ego_features, 1, "direct", ENCODER_GIGAFLOW), - ("conditioning", cond_dim, 1, "direct", ENCODER_GIGAFLOW) if cond_dim > 0 else None, - ("partner", env.partner_features, env.obs_slots_partners_n, "max-pool", ENCODER_GIGAFLOW), - ("lane", env.road_features, env.obs_slots_lane_kept, "max-pool", ENCODER_GIGAFLOW), - ("boundary", env.road_features, env.obs_slots_boundary_kept, "max-pool", ENCODER_GIGAFLOW), + ("ego", env.ego_features, 1, "direct", INPUT_SIZE), + ("conditioning", cond_dim, 1, "direct", INPUT_SIZE) if cond_dim > 0 else None, + ("partner", env.partner_features, env.obs_slots_partners_n, "max-pool", INPUT_SIZE), + ("lane", env.road_features, env.obs_slots_lane_kept, "max-pool", INPUT_SIZE), + ("boundary", env.road_features, env.obs_slots_boundary_kept, "max-pool", INPUT_SIZE), ( "traffic_ctrl", env.traffic_control_features - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES, env.obs_slots_traffic_controls_n, "max-pool (onehot)", - ENCODER_GIGAFLOW, + INPUT_SIZE, ), ] encoders = [e for e in encoders if e is not None] @@ -106,36 +116,25 @@ colors = plt.cm.Set2(np.linspace(0, 1, n_enc)) # Draw encoders -for i, ((name, in_f, n_obj, agg, gigaflow), y, c) in enumerate(zip(encoders, y_positions, colors)): +for i, ((name, in_f, n_obj, agg, out_size), y, c) in enumerate(zip(encoders, y_positions, colors)): # Input box label = f"{name}\n{n_obj}x{in_f}" if n_obj > 1 else f"{name}\n{in_f}" ax.add_patch(plt.Rectangle((0.2, y - 0.3), 1.6, 0.6, facecolor=c, edgecolor="black", lw=1.2, alpha=0.8)) ax.text(1.0, y, label, ha="center", va="center", fontsize=8, fontweight="bold") - # Encoder box — show gigaflow arch vs standard + # Encoder box ax.add_patch(plt.Rectangle((2.5, y - 0.25), 2.0, 0.5, facecolor="lightyellow", edgecolor="black", lw=1)) - ax.text(3.5, y + 0.05, f"Linear({in_f},{INPUT_SIZE})", ha="center", va="center", fontsize=7) - if gigaflow: - ax.text( - 3.5, - y - 0.12, - f"LN+Tanh+Drop+Linear({INPUT_SIZE},{INPUT_SIZE})", - ha="center", - va="center", - fontsize=5.5, - color="darkgreen", - ) - else: - drop_str = f"+Drop({DROPOUT})" if DROPOUT > 0 and name not in ("ego", "partner", "conditioning") else "" - ax.text( - 3.5, - y - 0.12, - f"LN{drop_str}+Linear({INPUT_SIZE},{INPUT_SIZE})", - ha="center", - va="center", - fontsize=6, - color="gray", - ) + ax.text(3.5, y + 0.05, f"Linear({in_f},{out_size})", ha="center", va="center", fontsize=7) + ln_label = "LN+" if ENCODER_LAYER_NORM else "" + ax.text( + 3.5, + y - 0.12, + f"{ln_label}{ENCODER_ACTIVATION}+Linear({out_size},{out_size})", + ha="center", + va="center", + fontsize=6, + color="gray", + ) # Aggregation if n_obj > 1: @@ -177,14 +176,13 @@ ax.annotate("", xy=(9.0, 6.0), xytext=(8.8, 5.3), arrowprops=dict(arrowstyle="->", lw=1.2)) ax.annotate("", xy=(9.0, 4.0), xytext=(8.8, 4.7), arrowprops=dict(arrowstyle="->", lw=1.2)) -split_label = "SPLIT" if SPLIT_NETWORK else "SHARED" +split_label = "SHARED" if SHARED_NETWORK else "SPLIT" ax.text(8.9, 4.55, split_label, ha="center", va="center", fontsize=7, color="red", fontweight="bold") -gigaflow_label = "GIGAFLOW" if ENCODER_GIGAFLOW else "STANDARD" ax.text( 5.0, 0.3, - f"Encoder mode: {gigaflow_label} | Dropout: {DROPOUT}", + f"Encoder: {ENCODER_ACTIVATION} | LayerNorm: {ENCODER_LAYER_NORM}", ha="center", va="center", fontsize=8, @@ -193,7 +191,7 @@ ) ax.set_title( - f"DrivePolicy Architecture (input_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})", + f"DrivePolicy Architecture (encoder_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})", fontsize=12, fontweight="bold", ) @@ -218,8 +216,8 @@ def count_params(module): "partner_encoder": backbone.partner_encoder, "traffic_ctrl_encoder": backbone.traffic_control_encoder, } -if backbone.conditioning_dim > 0: - components["conditioning_encoder"] = backbone.conditioning_encoder +if backbone.target_dim > 0: + components["target_encoder"] = backbone.target_encoder components["backbone_mlp"] = backbone.backbone components["actor_head"] = policy.actor_head components["critic_head"] = policy.critic_head @@ -233,7 +231,7 @@ def count_params(module): print(f"{n:>25s} | {c:>10,d} | {c / total:>5.1%}") print("-" * 48) print(f"{'TOTAL':>25s} | {total:>10,d}") -if SPLIT_NETWORK: +if not SHARED_NETWORK: critic_bb = count_params(policy.critic_backbone) print(f"{'+ critic_backbone':>25s} | {critic_bb:>10,d}") print(f"{'GRAND TOTAL':>25s} | {total + critic_bb:>10,d}") @@ -256,7 +254,7 @@ def count_params(module): backbone = policy.actor_backbone slide_idx = env.ego_features -cond_dim = backbone.conditioning_dim +cond_dim = backbone.target_dim partner_dim = env.obs_slots_partners_n * env.partner_features lane_dim = env.obs_slots_lane_kept * env.road_features boundary_dim = env.obs_slots_boundary_kept * env.road_features @@ -299,7 +297,7 @@ def count_params(module): print(f" ego_encoder: {ego_obs.shape} -> {ego_enc.shape}") if cond_dim > 0: - cond_enc = backbone.conditioning_encoder(cond_obs) + cond_enc = backbone.target_encoder(cond_obs) print(f" cond_encoder: {cond_obs.shape} -> {cond_enc.shape}") p_reshaped = partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features) @@ -385,7 +383,7 @@ def count_params(module): activations["ego"] = backbone.ego_encoder(obs_tensor[:, : env.ego_features]) if cond_dim > 0: - activations["conditioning"] = backbone.conditioning_encoder(obs_tensor[:, slide : slide + cond_dim]) + activations["conditioning"] = backbone.target_encoder(obs_tensor[:, slide : slide + cond_dim]) slide += cond_dim p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, env.obs_slots_partners_n, env.partner_features) @@ -470,12 +468,12 @@ def count_params(module): # %% configs = [ - {"name": "tiny", "input_size": 32, "backbone_hidden_size": 64}, - {"name": "small", "input_size": 64, "backbone_hidden_size": 128}, - {"name": "medium", "input_size": 128, "backbone_hidden_size": 256, "backbone_num_layers": 2}, + {"name": "tiny", "encoder_size": 32, "backbone_hidden_size": 64}, + {"name": "small", "encoder_size": 64, "backbone_hidden_size": 128}, + {"name": "medium", "encoder_size": 128, "backbone_hidden_size": 256, "backbone_num_layers": 2}, { "name": "large", - "input_size": 128, + "encoder_size": 128, "backbone_hidden_size": 512, "backbone_num_layers": 2, "actor_num_layers": 2, @@ -485,7 +483,7 @@ def count_params(module): }, { "name": "xlarge", - "input_size": 256, + "encoder_size": 256, "backbone_hidden_size": 1024, "backbone_num_layers": 3, "actor_num_layers": 2, @@ -493,32 +491,51 @@ def count_params(module): "critic_num_layers": 2, "critic_hidden_size": 512, }, - {"name": "small+giga", "input_size": 64, "backbone_hidden_size": 128, "encoder_gigaflow": True, "dropout": 0.1}, + {"name": "small+tanh", "encoder_size": 64, "backbone_hidden_size": 128, "encoder_activation": "tanh"}, { - "name": "medium+giga", - "input_size": 128, + "name": "medium+tanh", + "encoder_size": 128, "backbone_hidden_size": 256, "backbone_num_layers": 2, - "encoder_gigaflow": True, - "dropout": 0.1, + "encoder_activation": "tanh", }, ] POLICY_DEFAULTS = { + "ego_input_size": 64, + "partner_input_size": 64, + "lane_input_size": 64, + "boundary_input_size": 64, + "traffic_control_input_size": 64, + "target_input_size": 64, "backbone_num_layers": 1, "actor_hidden_size": 128, "actor_num_layers": 0, "critic_hidden_size": 128, "critic_num_layers": 0, - "encoder_gigaflow": False, - "dropout": 0.0, - "split_network": False, + "encoder_activation": "relu", + "encoder_layer_norm": True, + "backbone_activation": "gelu", + "backbone_layer_norm": False, + "shared_network": True, + "mask_padded_features": False, } results = [] for cfg in configs: - name = cfg.pop("name") - full_cfg = {**POLICY_DEFAULTS, **cfg} + name = cfg["name"] + encoder_size = cfg.get("encoder_size", POLICY_DEFAULTS["ego_input_size"]) + full_cfg = {**POLICY_DEFAULTS, **{k: v for k, v in cfg.items() if k not in ("name", "encoder_size")}} + full_cfg.update( + { + "ego_input_size": encoder_size, + "partner_input_size": encoder_size, + "lane_input_size": encoder_size, + "boundary_input_size": encoder_size, + "traffic_control_input_size": encoder_size, + "target_input_size": encoder_size, + } + ) p = DrivePolicy(env, **full_cfg).to(device) n_params = sum(pp.numel() for pp in p.parameters()) @@ -532,17 +549,16 @@ def count_params(module): torch.cuda.synchronize() ms_per_fwd = (time.time() - t0) / 100 * 1000 - results.append({"name": name, "params": n_params, "ms/fwd": ms_per_fwd, **cfg}) - cfg["name"] = name # restore + results.append({"name": name, "encoder_size": encoder_size, "params": n_params, "ms/fwd": ms_per_fwd, **full_cfg}) del p print( - f"{'Config':>12s} | {'input':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'giga':>5s} | {'Params':>10s} | {'ms/fwd':>8s}" + f"{'Config':>12s} | {'enc':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'enc_act':>7s} | {'Params':>10s} | {'ms/fwd':>8s}" ) print("-" * 105) for r in results: print( - f"{r['name']:>12s} | {r['input_size']:>5d} | {r.get('backbone_hidden_size', 1024):>5d} | {r.get('backbone_num_layers', 1):>4d} | {r.get('actor_hidden_size', 1024):>5d} | {r.get('actor_num_layers', 1):>5d} | {r.get('critic_hidden_size', 1024):>5d} | {r.get('critic_num_layers', 1):>5d} | {str(r.get('encoder_gigaflow', False)):>5s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms" + f"{r['name']:>12s} | {r['encoder_size']:>5d} | {r['backbone_hidden_size']:>5d} | {r['backbone_num_layers']:>4d} | {r['actor_hidden_size']:>5d} | {r['actor_num_layers']:>5d} | {r['critic_hidden_size']:>5d} | {r['critic_num_layers']:>5d} | {r['encoder_activation']:>7s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms" ) fig, axes = plt.subplots(1, 2, figsize=(14, 4)) @@ -550,11 +566,11 @@ def count_params(module): params = [r["params"] for r in results] times = [r["ms/fwd"] for r in results] -bar_colors = ["coral" if r.get("encoder_gigaflow") else "steelblue" for r in results] +bar_colors = ["coral" if r["encoder_activation"] == "tanh" else "steelblue" for r in results] axes[0].bar(names, params, color=bar_colors, edgecolor="black") axes[0].set_ylabel("Parameters") -axes[0].set_title("Parameter Count (orange=gigaflow)") +axes[0].set_title("Parameter Count (orange=tanh encoder)") axes[0].tick_params(axis="x", rotation=30) for i, v in enumerate(params): axes[0].text(i, v, f"{v:,}", ha="center", va="bottom", fontsize=7) diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index df3c8dc063..815cc58ee2 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -92,16 +92,24 @@ } DEFAULT_POLICY_KWARGS = { - "input_size": 64, + "ego_input_size": 64, + "partner_input_size": 64, + "lane_input_size": 64, + "boundary_input_size": 64, + "traffic_control_input_size": 64, + "target_input_size": 64, "backbone_hidden_size": 128, "backbone_num_layers": 1, "actor_hidden_size": 128, "actor_num_layers": 0, "critic_hidden_size": 128, "critic_num_layers": 0, - "encoder_gigaflow": True, - "dropout": 0.0, - "split_network": False, + "encoder_activation": "tanh", + "encoder_layer_norm": True, + "backbone_activation": "gelu", + "backbone_layer_norm": False, + "shared_network": True, + "mask_padded_features": False, } diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 6b01fa5e97..b6b1efc9f8 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -11,13 +11,23 @@ num_workers = auto batch_size = auto [policy] -; Encoder layer -input_size = 64 -encoder_gigaflow = True -dropout = 0.0 +; Encoder layer (per-encoder embedding width) +ego_input_size = 64 +partner_input_size = 64 +lane_input_size = 64 +boundary_input_size = 64 +traffic_control_input_size = 64 +target_input_size = 64 +; Encoder activation - options: "relu", "tanh", "gelu" +encoder_activation = "relu" +encoder_layer_norm = True +mask_padded_features = False ; Shared backbone layer backbone_hidden_size = 512 backbone_num_layers = 4 +; Backbone activation - options: "relu", "tanh", "gelu" +backbone_activation = "gelu" +backbone_layer_norm = False ; Actor head layer actor_hidden_size = 512 actor_num_layers = 0 @@ -25,7 +35,7 @@ actor_num_layers = 0 critic_hidden_size = 512 critic_num_layers = 0 ; Dual or shared actor-critic backbone -split_network = False +shared_network = True [rnn] input_size = 512 @@ -111,7 +121,7 @@ reward_lane_align = 0.025 reward_vel_align = 1.0 reward_lane_center = 0.0038 reward_center_bias = 0.0 -reward_velocity = 0.1 +reward_velocity = 0.0025 reward_reverse = 0.005 reward_timestep = 0.000025 reward_overspeed = 0.05 @@ -183,7 +193,6 @@ adam_beta2 = 0.999 adam_eps = 1e-8 vtrace_c_clip = 1 vtrace_rho_clip = 1 -ppo_granularity = auto adv_sampling_prio_alpha = 0.8499999999999999 adv_sampling_prio_beta0 = 0.8499999999999999 adv_filter_ewma_beta = 0.25 diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 486d6280ce..9ef76fab7a 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -128,9 +128,9 @@ #define ROAD_FEATURES 7 #define PARTNER_FEATURES 9 #define TRAFFIC_CONTROL_FEATURES 7 -#define PADDED_OBSERVATION_VALUE -0.001f #define STATIC_TARGET_FEATURES 3 #define DYNAMIC_TARGET_FEATURES 5 +#define OBS_SLOT_NUM_TYPES 4 // GIGAFLOW specific #define MAX_ROUTE_LENGTH 64 @@ -3897,29 +3897,8 @@ static int compute_observation_size(Drive *env) { : 0; return EGO_FEATURES + PARTNER_FEATURES * env->obs_slots_partners_n + ROAD_FEATURES * (env->obs_slots_lane_kept + env->obs_slots_boundary_kept) - + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + env->reward_conditioning * NUM_REWARD_COEFS - + env->num_target_waypoints * target_features; -} - -// Fill `rows` x `features` observation slots with the padding sentinel. -static inline void fill_padded_observation_rows(float *obs, int rows, int features) { - for (int r = 0; r < rows; r++) { - for (int c = 0; c < features; c++) { - obs[r * features + c] = PADDED_OBSERVATION_VALUE; - } - } -} - -// Pad `rows` traffic-control slots with the sentinel; type/state columns set to NONE/UNKNOWN. -static inline void fill_padded_traffic_control_rows(float *obs, int rows) { - for (int r = 0; r < rows; r++) { - int base = r * TRAFFIC_CONTROL_FEATURES; - for (int c = 0; c < TRAFFIC_CONTROL_FEATURES - 2; c++) { - obs[base + c] = PADDED_OBSERVATION_VALUE; - } - obs[base + TRAFFIC_CONTROL_FEATURES - 2] = TRAFFIC_CONTROL_TYPE_NONE; - obs[base + TRAFFIC_CONTROL_FEATURES - 1] = TRAFFIC_CONTROL_STATE_UNKNOWN; - } + + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + OBS_SLOT_NUM_TYPES + + env->reward_conditioning * NUM_REWARD_COEFS + env->num_target_waypoints * target_features; } void allocate(Drive *env) { @@ -4665,7 +4644,7 @@ static int write_reward_target_obs(Drive *env, Agent *ego, float *obs, int obs_i static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, int obs_idx, int *partner_count) { if (ego->is_blind_partner && random_uniform(0.0f, 1.0f) < env->partner_blindness_trigger_prob) { int partner_obs_stride = env->obs_slots_partners_n * PARTNER_FEATURES; - fill_padded_observation_rows(&obs[obs_idx], env->obs_slots_partners_n, PARTNER_FEATURES); + memset(&obs[obs_idx], 0, partner_obs_stride * sizeof(float)); *partner_count = 0; return obs_idx + partner_obs_stride; } @@ -4723,10 +4702,9 @@ static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, for (int j = 0; j < partners_to_write; j++) { Agent *other = &env->agents[nearby_agents[j].index]; - float rel_x, rel_y, rel_heading_x, rel_heading_y, rel_vx, rel_vy; + float rel_x, rel_y, rel_heading_x, rel_heading_y; project_vector_to_ego_frame(ego, nearby_agents[j].dx, nearby_agents[j].dy, &rel_x, &rel_y); project_vector_to_ego_frame(ego, other->cos_heading, other->sin_heading, &rel_heading_x, &rel_heading_y); - project_vector_to_ego_frame(ego, other->sim_vx, other->sim_vy, &rel_vx, &rel_vy); obs[obs_idx++] = rel_x / env->obs_norm_xy_offset_m; obs[obs_idx++] = rel_y / env->obs_norm_xy_offset_m; obs[obs_idx++] = nearby_agents[j].dz / Z_BUFFER; @@ -4741,7 +4719,6 @@ static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, } *partner_count = partners_written; - fill_padded_observation_rows(&obs[obs_idx], env->obs_slots_partners_n - partners_written, PARTNER_FEATURES); return obs_idx + (env->obs_slots_partners_n - partners_written) * PARTNER_FEATURES; } @@ -4840,28 +4817,28 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * subsample_road_observation_rows(lanes_buffer, lanes_found, lanes_to_copy); subsample_road_observation_rows(boundaries_buffer, boundaries_found, boundaries_to_copy); memcpy(&obs[lane_obs_idx], lanes_buffer, lanes_to_copy * ROAD_FEATURES * sizeof(float)); - fill_padded_observation_rows( + memset( &obs[lane_obs_idx + lanes_to_copy * ROAD_FEATURES], - env->obs_slots_lane_kept - lanes_to_copy, - ROAD_FEATURES); + 0, + (env->obs_slots_lane_kept - lanes_to_copy) * ROAD_FEATURES * sizeof(float)); memcpy(&obs[boundary_obs_idx], boundaries_buffer, boundaries_to_copy * ROAD_FEATURES * sizeof(float)); - fill_padded_observation_rows( + memset( &obs[boundary_obs_idx + boundaries_to_copy * ROAD_FEATURES], - env->obs_slots_boundary_kept - boundaries_to_copy, - ROAD_FEATURES); + 0, + (env->obs_slots_boundary_kept - boundaries_to_copy) * ROAD_FEATURES * sizeof(float)); return obs_idx; } *lane_count = lanes_found; *boundary_count = boundaries_found; - fill_padded_observation_rows( + memset( &obs[lane_obs_idx + lanes_found * ROAD_FEATURES], - env->obs_slots_lane_kept - lanes_found, - ROAD_FEATURES); - fill_padded_observation_rows( + 0, + (env->obs_slots_lane_kept - lanes_found) * ROAD_FEATURES * sizeof(float)); + memset( &obs[boundary_obs_idx + boundaries_found * ROAD_FEATURES], - env->obs_slots_boundary_kept - boundaries_found, - ROAD_FEATURES); + 0, + (env->obs_slots_boundary_kept - boundaries_found) * ROAD_FEATURES * sizeof(float)); return obs_idx; } @@ -4934,7 +4911,6 @@ static int write_traffic_control_obs(Drive *env, Agent *ego, float *obs, int obs } *traffic_control_count = controls_written; - fill_padded_traffic_control_rows(&obs[obs_idx], env->obs_slots_traffic_controls_n - controls_written); return obs_idx + (env->obs_slots_traffic_controls_n - controls_written) * TRAFFIC_CONTROL_FEATURES; } @@ -4957,6 +4933,10 @@ static void compute_observations(Drive *env) { obs_idx = write_partner_obs(env, ego, i, obs, obs_idx, &partner_count); obs_idx = write_road_obs(env, ego, obs, obs_idx, &lane_count, &boundary_count); obs_idx = write_traffic_control_obs(env, ego, obs, obs_idx, &traffic_control_count); + obs[obs_idx++] = (float) lane_count; + obs[obs_idx++] = (float) boundary_count; + obs[obs_idx++] = (float) partner_count; + obs[obs_idx++] = (float) traffic_control_count; assert(obs_idx == obs_per_agent); } } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 61da6c8d10..f7a017ead2 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -85,7 +85,7 @@ def __init__( reward_conditioning=False, reward_randomization=False, compute_eval_metrics=True, - split_network=False, + shared_network=True, obs_slots_lane_n=32, obs_slots_boundary_n=32, obs_slots_partners_n=16, @@ -117,7 +117,7 @@ def __init__( self.reward_conditioning = reward_conditioning self.reward_randomization = reward_randomization self.compute_eval_metrics = compute_eval_metrics - self.split_network = split_network + self.shared_network = shared_network self.render_mode = render_mode self.num_maps = num_maps self.report_interval = report_interval @@ -215,6 +215,7 @@ def __init__( self.partner_features = binding.PARTNER_FEATURES self.road_features = binding.ROAD_FEATURES self.traffic_control_features = binding.TRAFFIC_CONTROL_FEATURES + self.obs_slot_num_types = binding.OBS_SLOT_NUM_TYPES self.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 # Target features based on target_type @@ -222,16 +223,17 @@ def __init__( self.target_features = binding.STATIC_TARGET_FEATURES else: self.target_features = binding.DYNAMIC_TARGET_FEATURES - self.target_dim = self.num_target_waypoints * self.target_features + self.goal_dim = self.num_target_waypoints * self.target_features self.num_obs = ( self.ego_features + self.num_reward_coefs - + self.target_dim + + self.goal_dim + self.obs_slots_partners_n * self.partner_features + self.obs_slots_lane_kept * self.road_features + self.obs_slots_boundary_kept * self.road_features + self.obs_slots_traffic_controls_n * self.traffic_control_features + + self.obs_slot_num_types ) self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32) diff --git a/pufferlib/ocean/drive/render.h b/pufferlib/ocean/drive/render.h index 2cb1a54eca..440b8c3599 100644 --- a/pufferlib/ocean/drive/render.h +++ b/pufferlib/ocean/drive/render.h @@ -782,7 +782,7 @@ void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int las // Draw position of other agents float x = agent_obs[obs_idx] * env->obs_norm_xy_offset_m; float y = agent_obs[obs_idx + 1] * env->obs_norm_xy_offset_m; - float z = agent_obs[obs_idx + 2] * env->obs_norm_xy_offset_m; + float z = agent_obs[obs_idx + 2] * Z_BUFFER; if (lasers && mode == 0) { DrawLine3D((Vector3) {0, 0, 0}, (Vector3) {x, y, z + 1}, ORANGE); } @@ -921,7 +921,7 @@ void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int las // For road segments, draw line between start and end points float x_middle = agent_obs[entity_idx] * env->obs_norm_xy_offset_m; float y_middle = agent_obs[entity_idx + 1] * env->obs_norm_xy_offset_m; - float z_middle = agent_obs[entity_idx + 2] * env->obs_norm_xy_offset_m; + float z_middle = agent_obs[entity_idx + 2] * Z_BUFFER; float rel_angle_x = (agent_obs[entity_idx + 5]); float rel_angle_y = (agent_obs[entity_idx + 6]); float rel_angle = atan2f(rel_angle_y, rel_angle_x); @@ -975,7 +975,7 @@ void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int las } float x_middle = agent_obs[entity_idx] * env->obs_norm_xy_offset_m; float y_middle = agent_obs[entity_idx + 1] * env->obs_norm_xy_offset_m; - float z_middle = agent_obs[entity_idx + 2] * env->obs_norm_xy_offset_m; + float z_middle = agent_obs[entity_idx + 2] * Z_BUFFER; float rel_angle_x = agent_obs[entity_idx + 5]; float rel_angle_y = agent_obs[entity_idx + 6]; float rel_angle = atan2f(rel_angle_y, rel_angle_x); diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 6b7d164ede..860d37c662 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1380,6 +1380,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "ROAD_FEATURES", ROAD_FEATURES); PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_FEATURES", TRAFFIC_CONTROL_FEATURES); + PyModule_AddIntConstant(m, "OBS_SLOT_NUM_TYPES", OBS_SLOT_NUM_TYPES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_TYPES", NUM_TRAFFIC_CONTROL_TYPES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_STATES", NUM_TRAFFIC_CONTROL_STATES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_TYPE_NONE", TRAFFIC_CONTROL_TYPE_NONE); diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index c0e4f06cb7..3e33d32420 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -11,6 +11,8 @@ Recurrent = pufferlib.models.LSTMWrapper +ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "gelu": nn.GELU} + class DriveBackbone(nn.Module): """ @@ -19,34 +21,55 @@ class DriveBackbone(nn.Module): - Split Actor/Critic (configurable) """ - def _create_encoder(self, in_features, input_size, encoder_gigaflow, dropout=0.0): - if encoder_gigaflow: - return nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), - nn.LayerNorm(input_size), - nn.Tanh(), - nn.Dropout(dropout), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) - else: - return nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), - nn.LayerNorm(input_size), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) + def _create_encoder(self, in_features, out_size): + layers = [pufferlib.pytorch.layer_init(nn.Linear(in_features, out_size))] + if self.encoder_layer_norm: + layers.append(nn.LayerNorm(out_size)) + layers.append(self.encoder_act_cls()) + layers.append(pufferlib.pytorch.layer_init(nn.Linear(out_size, out_size))) + return nn.Sequential(*layers) + + def _encode_and_pool(self, objects, valid_counts, encoder, out_size): + if not self.mask_padded_features: + return encoder(objects).max(dim=1).values + + valid_mask = torch.arange(objects.shape[1], device=objects.device) < valid_counts.unsqueeze(1) + encoded_objects = objects.new_full( + (objects.shape[0], objects.shape[1], out_size), + torch.finfo(objects.dtype).min, + ) + encoded_objects[valid_mask] = encoder(objects[valid_mask]) + pooled = encoded_objects.amax(dim=1) + return torch.where(valid_counts.unsqueeze(1) == 0, encoded_objects.new_zeros(()), pooled) def __init__( self, env, - input_size, + ego_input_size, + partner_input_size, + lane_input_size, + boundary_input_size, + traffic_control_input_size, + target_input_size, backbone_hidden_size, backbone_num_layers, ego_dim, - encoder_gigaflow, - dropout, + encoder_activation, + encoder_layer_norm, + backbone_activation, + backbone_layer_norm, + mask_padded_features, ): super().__init__() - self.input_size = input_size + self.encoder_act_cls = ACTIVATIONS[encoder_activation] + self.encoder_layer_norm = encoder_layer_norm + self.ego_dim = ego_dim + self.ego_input_size = ego_input_size + self.partner_input_size = partner_input_size + self.lane_input_size = lane_input_size + self.boundary_input_size = boundary_input_size + self.traffic_control_input_size = traffic_control_input_size + self.target_input_size = target_input_size # Observation dimensions from environment config self.obs_slots_partners_n = env.obs_slots_partners_n @@ -64,55 +87,47 @@ def __init__( + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES ) + self.obs_slot_num_types = binding.OBS_SLOT_NUM_TYPES + self.mask_padded_features = mask_padded_features # Conditioning size (reward coefficients + target info) - self.conditioning_dim = env.num_reward_coefs + env.target_dim - - num_feature_sets = 1 + self.target_dim = env.num_reward_coefs + env.goal_dim # 1. observations Encoders - # Each encoder projects raw features into a common input_size embedding space - self.ego_encoder = self._create_encoder(ego_dim, input_size, encoder_gigaflow) + # Each encoder projects raw features into its own embedding space + self.ego_encoder = self._create_encoder(ego_dim, ego_input_size) + encoders_out = ego_input_size if self.obs_slots_lane_kept > 0: - self.lane_encoder = self._create_encoder( - self.road_features_count, - input_size, - encoder_gigaflow, - dropout=dropout, - ) - num_feature_sets += 1 + self.lane_encoder = self._create_encoder(self.road_features_count, lane_input_size) + encoders_out += lane_input_size if self.obs_slots_boundary_kept > 0: - self.boundary_encoder = self._create_encoder( - self.road_features_count, - input_size, - encoder_gigaflow, - dropout=dropout, - ) - num_feature_sets += 1 + self.boundary_encoder = self._create_encoder(self.road_features_count, boundary_input_size) + encoders_out += boundary_input_size if self.obs_slots_partners_n > 0: - self.partner_encoder = self._create_encoder(self.partner_features_count, input_size, encoder_gigaflow) - num_feature_sets += 1 + self.partner_encoder = self._create_encoder(self.partner_features_count, partner_input_size) + encoders_out += partner_input_size if self.obs_slots_traffic_controls_n > 0: self.traffic_control_encoder = self._create_encoder( - self.traffic_control_features_after_onehot, - input_size, - encoder_gigaflow, + self.traffic_control_features_after_onehot, traffic_control_input_size ) - num_feature_sets += 1 - if self.conditioning_dim > 0: - self.conditioning_encoder = self._create_encoder(self.conditioning_dim, input_size, encoder_gigaflow) - num_feature_sets += 1 + encoders_out += traffic_control_input_size + if self.target_dim > 0: + self.target_encoder = self._create_encoder(self.target_dim, target_input_size) + encoders_out += target_input_size # 2. Main Backbone MLP + backbone_act_cls = ACTIVATIONS[backbone_activation] backbone_layers = [] - bb_in = num_feature_sets * input_size + bb_in = encoders_out for _ in range(backbone_num_layers): - backbone_layers.append(nn.GELU()) + backbone_layers.append(backbone_act_cls()) backbone_layers.append(pufferlib.pytorch.layer_init(nn.Linear(bb_in, backbone_hidden_size))) + if backbone_layer_norm: + backbone_layers.append(nn.LayerNorm(backbone_hidden_size)) bb_in = backbone_hidden_size - # Add final GELU before heads - backbone_layers.append(nn.GELU()) + # Add final activation before heads + backbone_layers.append(backbone_act_cls()) self.backbone = nn.Sequential(*backbone_layers) - self.out_dim = backbone_hidden_size if backbone_num_layers > 0 else num_feature_sets * input_size + self.out_dim = backbone_hidden_size if backbone_num_layers > 0 else encoders_out def forward(self, observations, ego_dim): # Extract and slice observations from the flat buffer @@ -124,8 +139,8 @@ def forward(self, observations, ego_dim): slide_idx = ego_dim ego_observations = observations[:, :slide_idx] - conditioning_observations = observations[:, slide_idx : slide_idx + self.conditioning_dim] - slide_idx += self.conditioning_dim + target_observations = observations[:, slide_idx : slide_idx + self.target_dim] + slide_idx += self.target_dim partner_observations = observations[:, slide_idx : slide_idx + partner_dim] slide_idx += partner_dim @@ -137,6 +152,20 @@ def forward(self, observations, ego_dim): slide_idx += boundary_dim traffic_control_observations = observations[:, slide_idx : slide_idx + traffic_control_dim] + count_observations = observations[ + :, slide_idx + traffic_control_dim : slide_idx + traffic_control_dim + self.obs_slot_num_types + ] + lane_counts, boundary_counts, partner_counts, traffic_control_counts = [ + count_observations[:, i].long().clamp_(0, capacity) + for i, capacity in enumerate( + ( + self.obs_slots_lane_kept, + self.obs_slots_boundary_kept, + self.obs_slots_partners_n, + self.obs_slots_traffic_controls_n, + ) + ) + ] # Encode Ego State ego_features = self.ego_encoder(ego_observations) @@ -146,17 +175,27 @@ def forward(self, observations, ego_dim): # Encode Lanes and Boundaries separately if self.obs_slots_lane_kept > 0: lane_objects = lane_observations.view(-1, self.obs_slots_lane_kept, self.road_features_count) - lane_features = self.lane_encoder(lane_objects).max(dim=1).values + lane_features = self._encode_and_pool(lane_objects, lane_counts, self.lane_encoder, self.lane_input_size) feature_list.append(lane_features) if self.obs_slots_boundary_kept > 0: boundary_objects = boundary_observations.view(-1, self.obs_slots_boundary_kept, self.road_features_count) - boundary_features = self.boundary_encoder(boundary_objects).max(dim=1).values + boundary_features = self._encode_and_pool( + boundary_objects, + boundary_counts, + self.boundary_encoder, + self.boundary_input_size, + ) feature_list.append(boundary_features) # Encode Partners if self.obs_slots_partners_n > 0: partner_objects = partner_observations.view(-1, self.obs_slots_partners_n, self.partner_features_count) - partner_features = self.partner_encoder(partner_objects).max(dim=1).values + partner_features = self._encode_and_pool( + partner_objects, + partner_counts, + self.partner_encoder, + self.partner_input_size, + ) feature_list.append(partner_features) # Encode Traffic Controls @@ -179,13 +218,18 @@ def forward(self, observations, ego_dim): [traffic_control_continuous, traffic_control_type_onehot, traffic_control_state_onehot], dim=2, ) - traffic_control_features = self.traffic_control_encoder(traffic_control_objects).max(dim=1).values + traffic_control_features = self._encode_and_pool( + traffic_control_objects, + traffic_control_counts, + self.traffic_control_encoder, + self.traffic_control_input_size, + ) feature_list.append(traffic_control_features) # Add optional features if enabled - if self.conditioning_dim > 0: - conditioning_features = self.conditioning_encoder(conditioning_observations) - feature_list.append(conditioning_features) + if self.target_dim > 0: + target_features = self.target_encoder(target_observations) + feature_list.append(target_features) # Concatenate all features and pass through main backbone concat_features = torch.cat(feature_list, dim=1) @@ -197,7 +241,7 @@ def pool_slot_counts(self, observations, ego_dim): boundary_dim = self.obs_slots_boundary_kept * self.road_features_count traffic_control_dim = self.obs_slots_traffic_controls_n * self.traffic_control_features_count - slide_idx = ego_dim + self.conditioning_dim + slide_idx = ego_dim + self.target_dim partner_observations = observations[:, slide_idx : slide_idx + partner_dim] slide_idx += partner_dim lane_observations = observations[:, slide_idx : slide_idx + lane_dim] @@ -263,43 +307,59 @@ class Drive(nn.Module): def __init__( self, env, - input_size: int, + ego_input_size: int, + partner_input_size: int, + lane_input_size: int, + boundary_input_size: int, + traffic_control_input_size: int, + target_input_size: int, backbone_hidden_size: int, backbone_num_layers: int, actor_hidden_size: int, actor_num_layers: int, critic_hidden_size: int, critic_num_layers: int, - encoder_gigaflow: bool, - dropout: int, - split_network: bool, + encoder_activation: str, + encoder_layer_norm: bool, + backbone_activation: str, + backbone_layer_norm: bool, + shared_network: bool, + mask_padded_features: bool, ): super().__init__() # Configuration flags from policy kwargs - self.split_network = split_network + self.shared_network = shared_network self.ego_dim = env.ego_features # Prepare arguments for the Backbone backbone_args = { "env": env, - "input_size": input_size, + "ego_input_size": ego_input_size, + "partner_input_size": partner_input_size, + "lane_input_size": lane_input_size, + "boundary_input_size": boundary_input_size, + "traffic_control_input_size": traffic_control_input_size, + "target_input_size": target_input_size, "backbone_hidden_size": backbone_hidden_size, "backbone_num_layers": backbone_num_layers, "ego_dim": self.ego_dim, - "encoder_gigaflow": encoder_gigaflow, - "dropout": dropout, + "encoder_activation": encoder_activation, + "encoder_layer_norm": encoder_layer_norm, + "backbone_activation": backbone_activation, + "backbone_layer_norm": backbone_layer_norm, + "mask_padded_features": mask_padded_features, } # Instantiate backbones self.actor_backbone = DriveBackbone(**backbone_args) - # If split_network is True, create a separate backbone for the critic. - # Otherwise, share the same backbone for both. - if self.split_network: - self.critic_backbone = DriveBackbone(**backbone_args) - else: + # If using shared network, critic backbone is the same as actor backbone. + # Otherwise, create a separate critic backbone with the same architecture. + if self.shared_network: self.critic_backbone = self.actor_backbone + else: + self.critic_backbone = DriveBackbone(**backbone_args) # Setup action and value heads self.is_continuous = isinstance(env.single_action_space, pufferlib.spaces.Box) @@ -337,10 +397,10 @@ def forward(self, observations, state=None): actor_hidden = self.actor_backbone(observations, self.ego_dim) # Forward pass for critic (may use separate backbone) - if self.split_network: - critic_hidden = self.critic_backbone(observations, self.ego_dim) - else: + if self.shared_network: critic_hidden = actor_hidden + else: + critic_hidden = self.critic_backbone(observations, self.ego_dim) # Compute actions if self.is_continuous: @@ -367,7 +427,7 @@ def pool_slot_counts(self, observations, state=None): # Required for PufferLib recurrent wrappers def encode_observations(self, observations, state=None): - assert not self.split_network, "LSTM wrapper doesn't support split_network=True" + assert self.shared_network, "LSTM wrapper requires shared_network=True" return self.actor_backbone(observations, self.ego_dim) def decode_actions(self, hidden): diff --git a/setup.py b/setup.py index 7504107dd2..52b5283cdc 100644 --- a/setup.py +++ b/setup.py @@ -277,6 +277,7 @@ def run(self): "tensorboard", "jupytext", "torchinfo", + "ipywidgets", ] setup( diff --git a/tests/smoke_tests/data/drive_rollout_golden.json b/tests/smoke_tests/data/drive_rollout_golden.json index 51a15eaf99..0c5a06b948 100644 --- a/tests/smoke_tests/data/drive_rollout_golden.json +++ b/tests/smoke_tests/data/drive_rollout_golden.json @@ -6,7 +6,7 @@ "comfort_violation_count": 0.7257332022373493, "dnf_rate": 0.5625, "episode_length": 46.42307692307692, - "episode_return": -2.4921847994510946, + "episode_return": -2.4922777001674357, "lane_center_rate": 0.7063253728243021, "n": 16.0, "num_goals_reached": 0.0, @@ -23,7 +23,7 @@ "reward_components/red_light": -0.038461538461538464, "reward_components/reverse": -0.01809134094331127, "reward_components/timestep": -0.0001000300175152146, - "reward_components/velocity": 9.528447229128618e-05, + "reward_components/velocity": 2.3821118632510593e-06, "score": 0.0, "velocity_progress_sum": 0.00018323936428015048 }, diff --git a/tests/smoke_tests/data/drive_smoke_golden.json b/tests/smoke_tests/data/drive_smoke_golden.json index ca414e7c34..e8b4778943 100644 --- a/tests/smoke_tests/data/drive_smoke_golden.json +++ b/tests/smoke_tests/data/drive_smoke_golden.json @@ -1,48 +1,48 @@ { "env": { - "avg_distance_per_infraction": 11.850770235061646, - "avg_speed_per_agent": 1.3407511115074158, - "collision_rate": 0.0625, - "comfort_violation_count": 0.7286259929339091, - "dnf_rate": 0.5520833333333334, - "episode_length": 39.0, - "episode_return": -2.2371936639149985, - "lane_center_rate": 0.7488949696222941, + "avg_distance_per_infraction": 15.107745885848999, + "avg_speed_per_agent": 1.4028075337409973, + "collision_rate": 0.0, + "comfort_violation_count": 0.7450608313083649, + "dnf_rate": 0.5625, + "episode_length": 47.0, + "episode_return": -2.568606436252594, + "lane_center_rate": 0.7080668658018112, "n": 16.0, "num_goals_reached": 0.0, - "obs/max": 4.0, - "obs/mean": 0.143798382836394, - "obs/min": -1.2310400689020753, - "offroad_rate": 0.3541666666666667, + "obs/max": 80.0, + "obs/mean": 0.23133239336311817, + "obs/min": -1.2116065127775073, + "offroad_rate": 0.40625, "red_light_violation_rate": 0.03125, "reward_components/ade": 0.0, - "reward_components/collision": -0.10453646133343379, - "reward_components/comfort": -1.432812213897705, + "reward_components/collision": 0.0, + "reward_components/comfort": -1.7539056837558746, "reward_components/goal": 0.0, - "reward_components/lane_align": -0.12008568147818248, - "reward_components/lane_center": -0.0017940285421597462, - "reward_components/offroad": -0.53125, + "reward_components/lane_align": -0.1518757250159979, + "reward_components/lane_center": -0.0031759651901666075, + "reward_components/offroad": -0.609375, "reward_components/overspeed": 0.0, "reward_components/red_light": -0.03125, - "reward_components/reverse": -0.015380206052213907, - "reward_components/timestep": -8.554685700801201e-05, + "reward_components/reverse": -0.01892186817713082, + "reward_components/timestep": -0.0001030468392855255, "reward_components/velocity": 0.0, "score": 0.0, "velocity_progress_sum": 0.0 }, "losses": { - "approx_kl": 0.00031589248516995994, - "clipfrac": 0.0, - "ema_max": 1.1163338720798492, - "entropy": 2.483056511197771, - "explained_variance": 0.16990125179290771, - "filter_threshold": 0.011163338720798492, - "filtered_fraction": 0.03224013340744858, - "kept_fraction": 0.9677598665925514, - "masked_fraction": 0.12158203125, - "old_approx_kl": 0.0005578249381090115, - "policy_loss": -0.001066769240424037, - "value_loss": 0.06768968754581042 + "approx_kl": 0.008834710278149163, + "clipfrac": 0.12224070089203971, + "ema_max": 1.4194585680961609, + "entropy": 2.4655654770987376, + "explained_variance": 0.21746909618377686, + "filter_threshold": 0.01419458568096161, + "filtered_fraction": 0.02824551873981529, + "kept_fraction": 0.9717544812601847, + "masked_fraction": 0.10107421875, + "old_approx_kl": 0.00783040001988411, + "policy_loss": -0.00791241747460195, + "value_loss": 0.3491141710962568 }, "meta": { "bptt_horizon": 64, diff --git a/tests/smoke_tests/test_drive_train.py b/tests/smoke_tests/test_drive_train.py index 038e974211..86942c1020 100644 --- a/tests/smoke_tests/test_drive_train.py +++ b/tests/smoke_tests/test_drive_train.py @@ -148,13 +148,18 @@ def _build_config(): _set_existing( args["policy"], { - "input_size": 32, - "backbone_hidden_size": 32, - "actor_hidden_size": 32, - "critic_hidden_size": 32, + "ego_input_size": 32, + "partner_input_size": 128, + "lane_input_size": 64, + "boundary_input_size": 64, + "traffic_control_input_size": 16, + "target_input_size": 8, + "backbone_hidden_size": 256, + "actor_hidden_size": 128, + "critic_hidden_size": 64, }, ) - _set_existing(args["rnn"], {"input_size": 32, "hidden_size": 32}) + _set_existing(args["rnn"], {"input_size": 256, "hidden_size": 256}) args["wandb"] = False args["neptune"] = False diff --git a/tests/unit_tests/test_drive_backbone.py b/tests/unit_tests/test_drive_backbone.py new file mode 100644 index 0000000000..a22031e84d --- /dev/null +++ b/tests/unit_tests/test_drive_backbone.py @@ -0,0 +1,48 @@ +import torch + +from pufferlib.ocean.torch import DriveBackbone + + +def test_encode_and_pool_masks_padded_objects(): + backbone = object.__new__(DriveBackbone) + backbone.mask_padded_features = True + + objects = torch.tensor( + [ + [[1.0, 10.0], [2.0, 3.0], [100.0, 100.0], [200.0, 200.0]], + [[300.0, 300.0], [400.0, 400.0], [500.0, 500.0], [600.0, 600.0]], + [[4.0, 0.0], [-3.0, 7.0], [5.0, 1.0], [999.0, 999.0]], + ] + ) + valid_counts = torch.tensor([2, 0, 3]) + encoded_inputs = [] + + def encoder(x): + encoded_inputs.append(x) + return x + + pooled = backbone._encode_and_pool(objects, valid_counts, encoder, 2) + + assert len(encoded_inputs) == 1 + torch.testing.assert_close( + encoded_inputs[0], + torch.tensor( + [ + [1.0, 10.0], + [2.0, 3.0], + [4.0, 0.0], + [-3.0, 7.0], + [5.0, 1.0], + ] + ), + ) + torch.testing.assert_close( + pooled, + torch.tensor( + [ + [2.0, 10.0], + [0.0, 0.0], + [5.0, 7.0], + ] + ), + )