diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml index 84cff17f2..b7d633955 100644 --- a/baselines/ppo/config/ppo_base_puffer.yaml +++ b/baselines/ppo/config/ppo_base_puffer.yaml @@ -15,7 +15,7 @@ environment: # Overrides default environment configs (see pygpudrive/env/config. road_map_obs: true partner_obs: true norm_obs: true - add_reference_path: false + add_reference_pos_xy: false remove_non_vehicles: false # If false, all agents are included (vehicles, pedestrians, cyclists) lidar_obs: false # NOTE: Setting this to true currently turns of the other observation types reward_type: "weighted_combination" # Options: "weighted_combination", "reward_conditioned" diff --git a/baselines/ppo/config/ppo_waypoint.yaml b/baselines/ppo/config/ppo_guided_autonomy.yaml similarity index 68% rename from baselines/ppo/config/ppo_waypoint.yaml rename to baselines/ppo/config/ppo_guided_autonomy.yaml index 24b95d655..e3395d25d 100644 --- a/baselines/ppo/config/ppo_waypoint.yaml +++ b/baselines/ppo/config/ppo_guided_autonomy.yaml @@ -2,59 +2,64 @@ mode: "train" use_rnn: false eval_model_path: null baseline: false -data_dir: data/processed/wosac/validation_json_1 +data_dir: data/processed/wosac/validation_json_100 continue_training: false model_cpt: null environment: # Overrides default environment configs (see pygpudrive/env/config.py) name: "gpudrive" num_worlds: 100 # Number of parallel environments - k_unique_scenes: 1 # Number of unique scenes to sample from + k_unique_scenes: 100 # Number of unique scenes to sample from max_controlled_agents: 64 # Maximum number of agents controlled by the model. 
Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp ego_state: true road_map_obs: true partner_obs: true norm_obs: true + add_previous_action: true + + # Guidance through expert suggestions + guidance: true # If true, the agent will be guided by expert suggestions + guidance_mode: "log_replay" # Options: "log_replay", "vbd_amortized", "vbd_online" + add_reference_pos_xy: true # If true, a reference path is added to the ego observation + add_reference_speed: true # If true, the reference speeds are added to the ego observation + add_reference_heading: true # If true, the reference headings are added to the ego observation + prob_reference_dropout: 0.0 # Value between 0 and 1, probability of a reference point to be zeroed out + + # Reward function + reward_type: "guided_autonomy" + collision_weight: -0.2 + off_road_weight: -0.2 + guidance_pos_xy_weight: 0.01 + guidance_speed_weight: 0.01 + guidance_heading_weight: 0.01 + smoothness_weight: 0.001 + + init_mode: womd_tracks_to_predict + dynamics_model: "classic" remove_non_vehicles: false collision_behavior: "ignore" goal_behavior: "ignore" - reward_type: "follow_waypoints" - waypoint_distance_scale: 0.01 - speed_distance_scale: 0.01 - jerk_smoothness_scale: 0.001 - - init_mode: all_non_trivial #womd_tracks_to_predict - dynamics_model: "classic" polyline_reduction_threshold: 0.1 # Rate at which to sample points from the polyline (0 is use all closest points, 1 maximum sparsity), needs to be balanced with kMaxAgentMapObservationsCount sampling_seed: 42 # If given, the set of scenes to sample from will be deterministic, if None, the set of scenes will be random obs_radius: 50.0 # Visibility radius of the agents action_space_steer_disc: 15 action_space_accel_disc: 11 + max_steer_angle: -1.57 # pi/2 = 1.57, pi/3 = 1.05; NOTE(review): value is negative although this comment and the EnvConfig default use +1.57 - confirm the sign is intended + max_accel_value: 4.0 init_steps: 0 # Warmup steps goal_achieved_weight: 0.0 collision_weight: -0.2 off_road_weight: -0.2 - # Versatile Behavior Diffusion (VBD) - use_vbd: false - 
init_steps: 0 - vbd_trajectory_weight: 0.1 # Importance of distance to the vbd trajectories in the reward function - vbd_in_obs: false - - # Planning guidance - add_reference_path: true # If true, a reference path is added to the ego observation - add_reference_speed: true # If true, the reference speed (scalar) is added to the ego observation - prob_reference_dropout: 0.0 # Value between 0 and 1, probability of a reference point to be zeroed out - wandb: entity: "" project: "humanlike" - group: "debug" + group: "wosac_scale_100_base" mode: "online" # Options: online, offline, disabled tags: ["ppo", "ff"] train: - exp_id: waypoint_rs # Set dynamically in the script if needed + exp_id: guidance_log_replay # Set dynamically in the script if needed seed: 42 cpu_offload: false device: "cuda" # Dynamically set to cuda if available, else cpu @@ -66,7 +71,7 @@ train: resample_scenes: false resample_dataset_size: 500 # Number of unique scenes to sample from resample_interval: 2_000_000 - sample_with_replacement: true + sample_with_replacement: false shuffle_dataset: true file_prefix: "" @@ -102,18 +107,18 @@ train: num_parameters: 0 # Total trainable parameters, to be filled at runtime # # # Checkpointing # # # - checkpoint_interval: 250 # Save policy every k iterations + checkpoint_interval: 50 # Save policy every k iterations checkpoint_path: "./runs" # # # Rendering # # # - render: true # Determines whether to render the environment (note: will slow down training) + render: false # Determines whether to render the environment (note: will slow down training) render_3d: false # Render simulator state in 3d or 2d - render_interval: 200 # Render every k iterations + render_interval: 300 # Render every k iterations render_k_scenarios: 1 # Number of scenarios to render render_format: "mp4" # Options: gif, mp4 render_fps: 20 # Frames per second zoom_radius: 100 - plot_waypoints: true + plot_guidance_pos_xy: true vec: backend: "native" # Only native is currently supported diff 
--git a/baselines/ppo/config/ppo_population.yaml b/baselines/ppo/config/ppo_population.yaml index 4ff05427c..5d02aa13c 100644 --- a/baselines/ppo/config/ppo_population.yaml +++ b/baselines/ppo/config/ppo_population.yaml @@ -17,7 +17,7 @@ environment: # Overrides default environment configs (see pygpudrive/env/config. norm_obs: true remove_non_vehicles: false # If false, all agents are included (vehicles, pedestrians, cyclists) lidar_obs: false # NOTE: Setting this to true currently turns of the other observation types - reward_type: "reward_conditioned" # Options: "weighted_combination", "reward_conditioned", "follow_waypoints" + reward_type: "reward_conditioned" # Options: "weighted_combination", "reward_conditioned", "guided_autonomy" collision_weight: -0.75 off_road_weight: -0.75 goal_achieved_weight: 1.0 @@ -110,7 +110,7 @@ train: render_format: "mp4" # Options: gif, mp4 render_fps: 20 # Frames per second zoom_radius: 100 - plot_waypoints: true + plot_guidance_pos_xy: true vec: backend: "native" # Only native is currently supported diff --git a/baselines/ppo/ppo_waypoint.py b/baselines/ppo/ppo_guided_autonomy.py similarity index 94% rename from baselines/ppo/ppo_waypoint.py rename to baselines/ppo/ppo_guided_autonomy.py index 621055635..8f6c2d8c9 100644 --- a/baselines/ppo/ppo_waypoint.py +++ b/baselines/ppo/ppo_guided_autonomy.py @@ -106,7 +106,7 @@ def init_wandb(args, name, id=None, resume=True): def run( config_path: Annotated[ str, typer.Argument(help="The path to the default configuration file") - ] = "baselines/ppo/config/ppo_waypoint.yaml", + ] = "baselines/ppo/config/ppo_guided_autonomy.yaml", *, # fmt: off # Environment options @@ -115,10 +115,10 @@ def run( k_unique_scenes: Annotated[Optional[int], typer.Option(help="The number of unique scenes to sample")] = None, collision_weight: Annotated[Optional[float], typer.Option(help="The weight for collision penalty")] = None, off_road_weight: Annotated[Optional[float], typer.Option(help="The weight for 
off-road penalty")] = None, - goal_achieved_weight: Annotated[Optional[float], typer.Option(help="The weight for goal-achieved reward")] = None, - waypoint_distance_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, - speed_distance_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, - jerk_smoothness_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, + guidance_pos_xy_weight: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, + guidance_speed_weight: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, + guidance_heading_weight: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, + smoothness_weight: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None, dist_to_goal_threshold: Annotated[Optional[float], typer.Option(help="The distance threshold for goal-achieved")] = None, randomize_rewards: Annotated[Optional[int], typer.Option(help="If reward_type == reward_conditioned, choose the condition_mode; 0 or 1")] = 0, sampling_seed: Annotated[Optional[int], typer.Option(help="The seed for sampling scenes")] = None, @@ -130,7 +130,7 @@ def run( vbd_trajectory_weight: Annotated[Optional[float], typer.Option(help="Weight for VBD trajectory deviation penalty")] = 0.1, vbd_in_obs: Annotated[Optional[bool], typer.Option(help="Include VBD predictions in the observation")] = False, init_steps: Annotated[Optional[int], typer.Option(help="Environment warmup steps")] = 0, - + # Train options seed: Annotated[Optional[int], typer.Option(help="The seed for training")] = None, learning_rate: Annotated[Optional[float], typer.Option(help="The learning rate for training")] = None, @@ -174,10 +174,10 @@ def run( "k_unique_scenes": k_unique_scenes, "collision_weight": collision_weight, "off_road_weight": off_road_weight, - "goal_achieved_weight": 
goal_achieved_weight, - "waypoint_distance_scale": waypoint_distance_scale, - "jerk_smoothness_scale": jerk_smoothness_scale, - "speed_distance_scale": speed_distance_scale, + "guidance_pos_xy_weight": guidance_pos_xy_weight, + "smoothness_weight": smoothness_weight, + "guidance_speed_weight": guidance_speed_weight, + "guidance_heading_weight": guidance_heading_weight, "dist_to_goal_threshold": dist_to_goal_threshold, "sampling_seed": sampling_seed, "obs_radius": obs_radius, diff --git a/checkpoints/model_guidance_log_replay__S_1__04_26_09_02_20_677_000833.pt b/checkpoints/model_guidance_log_replay__S_1__04_26_09_02_20_677_000833.pt new file mode 100644 index 000000000..6854e7c96 Binary files /dev/null and b/checkpoints/model_guidance_log_replay__S_1__04_26_09_02_20_677_000833.pt differ diff --git a/checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt b/checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt new file mode 100644 index 000000000..5431fd1c0 Binary files /dev/null and b/checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt differ diff --git a/checkpoints/model_waypoint_rs__S_1__04_23_19_37_26_618_003500.pt b/checkpoints/model_waypoint_rs__S_1__04_23_19_37_26_618_003500.pt deleted file mode 100644 index 138500377..000000000 Binary files a/checkpoints/model_waypoint_rs__S_1__04_23_19_37_26_618_003500.pt and /dev/null differ diff --git a/examples/eval/README.md b/examples/eval/README.md index 78c321e8e..407f53789 100644 --- a/examples/eval/README.md +++ b/examples/eval/README.md @@ -2,9 +2,13 @@ ## Requirements -Prerequisite +Prerequisite to run the eval +``` +pip install --no-deps waymo-open-dataset-tf-2-12-0==1.6.6 +``` + +Requirement to process the data ``` -pip install --no-deps waymo-open-dataset-tf-2-12-0==1.6.4 pip install --no-deps git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax ``` @@ -24,5 +28,5 @@ python examples/eval/extract_dataset.py --data_dir data/raw 
--save_dir data/proc ## Evaluation Run eval with ``` -python wosac_eval.py +python run_wosac_eval.py ``` diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py index c08fa9e36..a939cc8aa 100644 --- a/examples/eval/run_wosac_eval.py +++ b/examples/eval/run_wosac_eval.py @@ -3,20 +3,28 @@ import os import sys import mediapy +import logging +import numpy as np +from time import perf_counter from tqdm import tqdm +from pathlib import Path from gpudrive.env.config import EnvConfig from gpudrive.env.env_torch import GPUDriveTorchEnv from gpudrive.env.dataset import SceneDataLoader from gpudrive.datatypes.observation import GlobalEgoState from gpudrive.utils.checkpoint import load_agent +from gpudrive.visualize.utils import img_from_fig # WOSAC sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from waymo_open_dataset.protos import sim_agents_submission_pb2 -# from eval.wosac_eval import WOSACMetrics from eval.wosac_eval_origin import WOSACMetrics +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("WOSAC evaluation") + def get_state(env): """Obtain raw agent states.""" @@ -38,18 +46,34 @@ def get_state(env): def rollout( - env, - sim_agent, - init_steps, - num_envs, - max_agents, - device, + env: GPUDriveTorchEnv, + sim_agent: torch.nn.Module, + init_steps: int, + num_envs: int, + max_agents: int, + device: str, + render_simulator_states: bool = False, + render_agent_pov: bool = False, + render_every_n_steps: int = 5, + save_videos: bool = True, + video_dir: str = "videos", + video_format: str = "gif", ): """Rollout agent in the environment and return the scenario rollouts.""" + # Storage + env_ids = list(range(num_envs)) + simulator_state_frames = {env_id: [] for env_id in range(num_envs)} + agent_observation_frames = {env_id: [] for env_id in range(num_envs)} + + start_env_rollout = perf_counter() + + control_mask = env.cont_agent_mask - next_obs = env.reset() + next_obs = 
env.reset(control_mask) - scenario_ids = list(env.get_scenario_ids().values()) + # Get scenario ids + scenario_ids_dict = env.get_scenario_ids() + scenario_ids = list(scenario_ids_dict.values()) pos_x_list = [] pos_y_list = [] @@ -57,14 +81,12 @@ def rollout( heading_list = [] done_list = [env.get_dones()] - control_mask = env.cont_agent_mask - pos_x, pos_y, pos_z, heading, _ = get_state(env) for time_step in range(env.episode_len - init_steps): # Predict actions - action, _, _, _ = sim_agent(next_obs[control_mask]) + action, _, _, _ = sim_agent(next_obs) action_template = torch.zeros( (num_envs, max_agents), dtype=torch.int64, device=device @@ -74,7 +96,30 @@ def rollout( # Step env.step_dynamics(action_template) - next_obs = env.get_obs() + # Render + if render_simulator_states and time_step % render_every_n_steps == 0: + sim_states = env.vis.plot_simulator_state( + env_indices=env_ids, + zoom_radius=100, + time_steps=[time_step] * len(env_ids), + plot_guidance_pos_xy=True, + ) + for idx in range(num_envs): + simulator_state_frames[idx].append( + img_from_fig(sim_states[idx]) + ) + + if render_agent_pov and time_step % render_every_n_steps == 0: + agent_obs = env.vis.plot_agent_observation( + env_idx=0, + agent_idx=0, + figsize=(10, 10), + trajectory=env.reference_path[0, :, :].to("cpu"), + ) + agent_observation_frames[idx].append(img_from_fig(agent_obs)) + + # Get next observation + next_obs = env.get_obs(control_mask) done = env.get_dones() pos_x, pos_y, pos_z, heading, id = get_state(env) @@ -86,6 +131,29 @@ def rollout( done_list.append(done) _ = done_list.pop() + if save_videos: + for idx in range(num_envs): + scenario_id = scenario_ids_dict[idx] + if ( + render_simulator_states + and len(simulator_state_frames[idx]) > 0 + ): + mediapy.write_video( + f"{video_dir}/sim_state_env_{idx}_{scenario_id}.{video_format}", + np.array(simulator_state_frames[idx]), + fps=8, + codec=video_format, + ) + + if render_agent_pov and len(agent_observation_frames[0]) > 0: + 
scenario_id = scenario_ids_dict[0] + mediapy.write_video( + f"{video_dir}/agent_0_{scenario_id}.{video_format}", + np.array(agent_observation_frames[0]), + fps=8, + codec=video_format, + ) + # Generate Scenario pos_x_stack = torch.stack(pos_x_list, dim=-1).cpu().numpy() pos_y_stack = torch.stack(pos_y_list, dim=-1).cpu().numpy() @@ -95,6 +163,12 @@ def rollout( id = id.cpu().numpy() control_mask = control_mask.cpu().numpy() + logging.info( + f"Policy rollout took: {perf_counter() - start_env_rollout:.2f} s ({len(env.data_batch)} scenarios)." + ) + + start_ground_truth_ext = perf_counter() + scenario_rollouts = [] scenario_rollout_masks = [] for i, scenario_id in enumerate(scenario_ids): @@ -129,6 +203,10 @@ def rollout( ) ) + logging.info( + f"Ground truth extraction took: {perf_counter() - start_ground_truth_ext:.2f} s ({len(env.data_batch)} scenarios)." + ) + return scenario_ids, scenario_rollouts, scenario_rollout_masks @@ -136,15 +214,16 @@ def rollout( # Settings MAX_AGENTS = 64 - NUM_ENVS = 1 - DEVICE = "cpu" - NUM_BATCHES = 1 - NUM_ROLLOUTS = 10 + NUM_ENVS = 3 + DEVICE = "cuda" # where to run the env rollouts + NUM_ROLLOUTS_PER_BATCH = 1 + NUM_DATA_BATCHES = 2 INIT_STEPS = 10 DATASET_SIZE = 100 + RENDER = True - DATA_JSON = "data/processed/wosac/validation_json_1" - DATA_TFRECORD = "data/processed/wosac/validation_tfrecord_1" + DATA_JSON = "data/processed/wosac/validation_json_3" + DATA_TFRECORD = "data/processed/wosac/validation_tfrecord_3" # Create data loader val_loader = SceneDataLoader( @@ -152,58 +231,52 @@ def rollout( batch_size=NUM_ENVS, dataset_size=DATASET_SIZE, sample_with_replacement=True, + shuffle=True, file_prefix="", ) # Load agent agent = load_agent( - path_to_cpt="checkpoints/model_waypoint_rs__S_1__04_23_19_37_26_618_003500.pt", + # path_to_cpt="checkpoints/model_guidance_log_replay__S_1__04_26_09_02_20_677_000833.pt", + path_to_cpt="checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt", + ).to(DEVICE) + + # Override 
default environment settings to match those the agent was trained with + default_config = EnvConfig() + config_dict = { + field.name: getattr(agent.config, field.name) + for field in dataclasses.fields(EnvConfig) + if hasattr(agent.config, field.name) + and getattr(agent.config, field.name) + != getattr(default_config, field.name) + } + + # Add fixed overrides specific to WOSAC evaluation + fixed_overrides = { + "init_steps": INIT_STEPS, + } + + logging.info( + f"initializing env with init_mode = {config_dict['init_mode']}" ) - # Obtain config directly from the agent checkpoint - config = agent.config - - # Configs env_config = dataclasses.replace( - EnvConfig(), - ego_state=config.ego_state, - road_map_obs=config.road_map_obs, - partner_obs=config.partner_obs, - reward_type=config.reward_type, - norm_obs=config.norm_obs, - dynamics_model=config.dynamics_model, - collision_behavior=config.collision_behavior, - polyline_reduction_threshold=config.polyline_reduction_threshold, - obs_radius=config.obs_radius, - steer_actions=torch.round( - torch.linspace( - -torch.pi / 3, torch.pi / 3, config.action_space_steer_disc - ), - decimals=3, - ), - accel_actions=torch.round( - torch.linspace(-4.0, 4.0, config.action_space_accel_disc), - decimals=3, - ), - remove_non_vehicles=config.remove_non_vehicles, - init_mode="womd_tracks_to_predict", - init_steps=INIT_STEPS, - goal_behavior="stop", - add_reference_path=config.add_reference_path, - add_reference_speed=config.add_reference_speed, + default_config, **config_dict, **fixed_overrides ) + # Make environment env = GPUDriveTorchEnv( config=env_config, data_loader=val_loader, max_cont_agents=MAX_AGENTS, device=DEVICE, ) + wosac_metrics = WOSACMetrics() - for _ in tqdm(range(NUM_BATCHES)): - for _ in range(NUM_ROLLOUTS): - # try: + for _ in tqdm(range(NUM_DATA_BATCHES)): + for _ in range(NUM_ROLLOUTS_PER_BATCH): + scenario_ids, scenario_rollouts, scenario_rollout_masks = rollout( env=env, sim_agent=agent, @@ -211,11 +284,10 @@ 
def rollout( num_envs=NUM_ENVS, max_agents=MAX_AGENTS, device=DEVICE, + render_simulator_states=RENDER, + render_agent_pov=RENDER, + save_videos=RENDER, ) - # except Exception as e: - # print(f"Error during rollout: {e}") - # continue - tf_record_paths = [ os.path.join(DATA_TFRECORD, f"{scenario_id}.tfrecords") for scenario_id in scenario_ids @@ -225,11 +297,11 @@ def rollout( scenario_rollouts, # scenario_rollout_masks=scenario_rollout_masks ) - try: - env.swap_data_batch() - except Exception as e: - break + # Swap batch of scenarios + env.swap_data_batch() + + # Aggregate results results = wosac_metrics.compute() for key, value in results.items(): diff --git a/examples/eval/wosac_eval_origin.py b/examples/eval/wosac_eval_origin.py index 943b4e8da..9bc0f57f8 100644 --- a/examples/eval/wosac_eval_origin.py +++ b/examples/eval/wosac_eval_origin.py @@ -1,15 +1,16 @@ import itertools import multiprocessing as mp import os +import datetime +import pandas as pd +import numpy as np from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional import tensorflow as tf import waymo_open_dataset.wdl_limited.sim_agents_metrics.metrics as wosac_metrics from waymo_open_dataset.utils.sim_agents import submission_specs -# import eval.wosac_metrics.metrics as wosac_metrics - from google.protobuf import text_format from torch import Tensor, tensor from torchmetrics import Metric @@ -25,11 +26,20 @@ class WOSACMetrics(Metric): validation metrics based on ground truth trajectory, using waymo_open_dataset api """ - def __init__(self, challenge_type = None, prefix: str = "", ego_only: bool = False) -> None: + def __init__( + self, + challenge_type=None, + prefix: str = "", + ego_only: bool = False, + baselines_df: Optional[pd.DataFrame] = None, + save_table_with_baselines: bool = True, + ) -> None: super().__init__() self.is_mp_init = False self.prefix = prefix self.ego_only = ego_only + self.save_table_with_baselines = save_table_with_baselines + if 
challenge_type is None: self.challenge_type = submission_specs.ChallengeType.SIM_AGENTS else: @@ -37,6 +47,14 @@ def __init__(self, challenge_type = None, prefix: str = "", ego_only: bool = Fal self.wosac_config = load_metrics_config(self.challenge_type) + # Initialize baseline df if not provided + if baselines_df is None: + self.baselines_df = self._create_baselines_df() + else: + self.baselines_df = baselines_df + + self.timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H") + self.field_names = [ "metametric", "average_displacement_error", @@ -56,15 +74,111 @@ def __init__(self, challenge_type = None, prefix: str = "", ego_only: bool = Fal ] for k in self.field_names: self.add_state(k, default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("scenario_counter", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state( + "scenario_counter", default=tensor(0.0), dist_reduce_fx="sum" + ) tf.config.set_visible_devices([], "GPU") + def _create_baselines_df(self) -> pd.DataFrame: + """Create a DataFrame with baseline metrics based on the provided table.""" + columns = [ + "AGENT POLICY", + "REPLAN RATE (Hz)", + "LINEAR SPEED (↑)", + "LINEAR ACCEL. (↑)", + "ANG. SPEED (↑)", + "ANG. ACCEL. (↑)", + "DIST. TO OBJ. (↑)", + "COLLISION (↑)", + "TTC (↑)", + "DIST. 
TO ROAD EDGE (↑)", + "OFFROAD (↑)", + "COMPOSITE METRIC (↑)", + "ADE (↓)", + "MINADE (↓)", + "COLLISION RATE (↓)", + "OFFROAD RATE (↓)", + ] + + data = [ + # Logged Oracle values + [ + "Logged oracle", + "-", + 0.476, + 0.478, + 0.578, + 0.694, + 0.476, + 1.000, + 0.883, + 0.715, + 1.000, + 0.819, + 0.000, + 0.000, + 0.028, + 0.111, + ], + # Versatile Behavior Diffusion + [ + "VBD", + 0.125, + 0.359, + 0.366, + 0.420, + 0.522, + 0.368, + 0.934, + 0.815, + 0.651, + 0.879, + 0.720, + 2.257, + 1.474, + 0.036, + 0.152, + ], + ] + + # Other rows from the table can be added as needed + data.extend( + [ + [ + "Random agent", + 10, + 0.002, + 0.116, + 0.014, + 0.034, + 0.000, + 0.000, + 0.735, + 0.148, + 0.191, + 0.144, + 50.739, + 50.706, + 1.000, + 0.613, + ], + ] + ) + + return pd.DataFrame(data, columns=columns) + @staticmethod def _compute_scenario_metrics( - config, scenario_file, scenario_rollout, ego_only, scenario_rollouts_mask=None + config, + scenario_file, + scenario_rollout, + ego_only, + scenario_rollouts_mask=None, ) -> sim_agents_metrics_pb2.SimAgentMetrics: scenario = scenario_pb2.Scenario() - for data in tf.data.TFRecordDataset([scenario_file], compression_type=""): + for data in tf.data.TFRecordDataset( + [scenario_file], compression_type="" + ): scenario.ParseFromString(bytes(data.numpy())) break if ego_only: @@ -74,32 +188,43 @@ def _compute_scenario_metrics( scenario.tracks[i].states[t].valid = False while len(scenario.tracks_to_predict) > 1: scenario.tracks_to_predict.pop() - scenario.tracks_to_predict[0].track_index = scenario.sdc_track_index + scenario.tracks_to_predict[ + 0 + ].track_index = scenario.sdc_track_index return wosac_metrics.compute_scenario_metrics_for_bundle( - config, scenario, scenario_rollout,# scenario_rollouts_mask=scenario_rollouts_mask + config, + scenario, + scenario_rollout, # scenario_rollouts_mask=scenario_rollouts_mask ) def update( self, scenario_files: List[str], scenario_rollouts: 
List[sim_agents_submission_pb2.ScenarioRollouts], - scenario_rollout_masks = None, + scenario_rollout_masks=None, ) -> None: if scenario_rollout_masks is None: scenario_rollout_masks = [None] * len(scenario_rollouts) pool_scenario_metrics = [] - for _scenario, _scenario_rollout, _scenario_mask in zip(scenario_files, scenario_rollouts, scenario_rollout_masks): + for _scenario, _scenario_rollout, _scenario_mask in zip( + scenario_files, scenario_rollouts, scenario_rollout_masks + ): try: pool_scenario_metrics.append( self._compute_scenario_metrics( - self.wosac_config, _scenario, _scenario_rollout, self.ego_only, + self.wosac_config, + _scenario, + _scenario_rollout, + self.ego_only, # scenario_rollouts_mask=_scenario_mask ) ) except Exception as e: - print(f"Error processing scenario {_scenario_rollout.scenario_id}") + print( + f"Error processing scenario {_scenario_rollout.scenario_id}" + ) print(e) for scenario_metrics in pool_scenario_metrics: @@ -108,11 +233,15 @@ def update( self.average_displacement_error += ( scenario_metrics.average_displacement_error ) - self.linear_speed_likelihood += scenario_metrics.linear_speed_likelihood + self.linear_speed_likelihood += ( + scenario_metrics.linear_speed_likelihood + ) self.linear_acceleration_likelihood += ( scenario_metrics.linear_acceleration_likelihood ) - self.angular_speed_likelihood += scenario_metrics.angular_speed_likelihood + self.angular_speed_likelihood += ( + scenario_metrics.angular_speed_likelihood + ) self.angular_acceleration_likelihood += ( scenario_metrics.angular_acceleration_likelihood ) @@ -137,8 +266,12 @@ def update( self.traffic_light_violation_likelihood += ( scenario_metrics.traffic_light_violation_likelihood ) - self.simulated_collision_rate += scenario_metrics.simulated_collision_rate - self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate + self.simulated_collision_rate += ( + scenario_metrics.simulated_collision_rate + ) + self.simulated_offroad_rate += ( + 
scenario_metrics.simulated_offroad_rate + ) def compute(self) -> Dict[str, Tensor]: metrics_dict = {} @@ -163,22 +296,76 @@ def compute(self) -> Dict[str, Tensor]: for k in self.field_names: out_dict[f"{self.prefix}/wosac_likelihood/{k}"] = metrics_dict[k] + if self.save_table_with_baselines: + self._add_current_method_to_baselines(metrics_dict, final_metrics) + + table_name = f"wosac_table_{self.timestamp}.csv" + self.baselines_df.to_csv(table_name, index=False) + print(f"Saved df to {table_name}") + return out_dict + def _add_current_method_to_baselines(self, metrics_dict, final_metrics): + """Add the current method's metrics to the baselines dataframe.""" + new_row = { + "AGENT POLICY": "Guided self-play (ours)", + "REPLAN RATE (Hz)": 10, + "LINEAR SPEED (↑)": metrics_dict["linear_speed_likelihood"].item(), + "LINEAR ACCEL. (↑)": metrics_dict[ + "linear_acceleration_likelihood" + ].item(), + "ANG. SPEED (↑)": metrics_dict["angular_speed_likelihood"].item(), + "ANG. ACCEL. (↑)": metrics_dict[ + "angular_acceleration_likelihood" + ].item(), + "DIST. TO OBJ. (↑)": metrics_dict[ + "distance_to_nearest_object_likelihood" + ].item(), + "COLLISION (↑)": metrics_dict[ + "collision_indication_likelihood" + ].item(), + "TTC (↑)": metrics_dict["time_to_collision_likelihood"].item(), + "DIST. 
TO ROAD EDGE (↑)": metrics_dict[ + "distance_to_road_edge_likelihood" + ].item(), + "OFFROAD (↑)": metrics_dict[ + "offroad_indication_likelihood" + ].item(), + "COMPOSITE METRIC (↑)": final_metrics.realism_meta_metric, + "ADE (↓)": metrics_dict["average_displacement_error"].item(), + "MINADE (↓)": metrics_dict[ + "min_average_displacement_error" + ].item(), + "COLLISION RATE (↓)": metrics_dict[ + "simulated_collision_rate" + ].item(), + "OFFROAD RATE (↓)": metrics_dict["simulated_offroad_rate"].item(), + } + + self.baselines_df = pd.concat( + [self.baselines_df, pd.DataFrame([new_row])], ignore_index=True + ) + + # Optional: Sort the dataframe by a performance metric, such as COMPOSITE METRIC + self.baselines_df = self.baselines_df.sort_values( + by="COMPOSITE METRIC (↑)", ascending=False + ) + def load_metrics_config( challenge_type: submission_specs.ChallengeType, ) -> sim_agents_metrics_pb2.SimAgentMetricsConfig: """Loads the `SimAgentMetricsConfig` used for the challenge.""" import waymo_open_dataset + pyglib_resource = waymo_open_dataset.__path__[0] if challenge_type == submission_specs.ChallengeType.SIM_AGENTS: - config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_sim_agents_config.textproto' # pylint: disable=line-too-long + config_path = f"{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_sim_agents_config.textproto" # pylint: disable=line-too-long elif challenge_type == submission_specs.ChallengeType.SCENARIO_GEN: - config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_scenario_gen_config.textproto' # pylint: disable=line-too-long + config_path = f"{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_scenario_gen_config.textproto" # pylint: disable=line-too-long else: - raise ValueError(f'Unsupported {challenge_type=}') - with open(config_path, 'r') as f: + raise ValueError(f"Unsupported {challenge_type=}") + with open(config_path, "r") as f: config = 
sim_agents_metrics_pb2.SimAgentMetricsConfig() text_format.Parse(f.read(), config) - return config \ No newline at end of file + return config diff --git a/examples/experimental/notebooks/debug.ipynb b/examples/experimental/notebooks/debug.ipynb index 2be729d0a..f4fd484c2 100644 --- a/examples/experimental/notebooks/debug.ipynb +++ b/examples/experimental/notebooks/debug.ipynb @@ -31,13 +31,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env_config = EnvConfig(\n", " dynamics_model=\"classic\",\n", - " reward_type=\"follow_waypoints\",\n", + " reward_type=\"guided_autonomy\",\n", " add_reference_path=True,\n", ")\n", "render_config = RenderConfig()\n", diff --git a/gpudrive/datatypes/trajectory.py b/gpudrive/datatypes/trajectory.py index 92541571b..99fbf72b8 100644 --- a/gpudrive/datatypes/trajectory.py +++ b/gpudrive/datatypes/trajectory.py @@ -112,7 +112,12 @@ def comp_reference_speed(self): """Returns the average speed of the trajectory.""" return torch.sqrt( self.vel_xy[:, :, :, 0] ** 2 + self.vel_xy[:, :, :, 1] ** 2 - ) + ).unsqueeze(-1) + + @property + def length(self): + """Returns the length of the trajectory.""" + return self.pos_xy.shape[2] @dataclass @@ -128,12 +133,16 @@ class VBDTrajectory: def __init__(self, vbd_traj_tensor: torch.Tensor): """Initializes the VBD trajectory with a tensor.""" - self.pos_x = vbd_traj_tensor[:, :, :, 0] - self.pos_y = vbd_traj_tensor[:, :, :, 1] - self.pos_xy = torch.stack([self.pos_x, self.pos_y], dim=3) - self.yaw = vbd_traj_tensor[:, :, :, 2] - self.vel_x = vbd_traj_tensor[:, :, :, 3] - self.vel_y = vbd_traj_tensor[:, :, :, 4] + self.pos_x = vbd_traj_tensor[:, :, :, 0].unsqueeze(-1) + self.pos_y = vbd_traj_tensor[:, :, :, 1].unsqueeze(-1) + self.pos_xy = vbd_traj_tensor[:, :, :, :2] + self.yaw = vbd_traj_tensor[:, :, :, 2].unsqueeze(-1) + self.vel_x = vbd_traj_tensor[:, :, :, 3].unsqueeze(-1) + self.vel_y = vbd_traj_tensor[:, :, :, 
4].unsqueeze(-1) + self.vel_xy = vbd_traj_tensor[:, :, :, 3:5] + self.ref_speed = self.comp_reference_speed() + # Assumption: All timesteps are valid + self.valids = torch.ones_like(self.pos_x, dtype=torch.int32) @classmethod def from_tensor( @@ -148,6 +157,15 @@ def from_tensor( elif backend == "jax": raise NotImplementedError("JAX backend not implemented yet.") + def comp_reference_speed(self): + """Returns the average speed of the trajectory.""" + return torch.sqrt(self.vel_x**2 + self.vel_y**2) + + @property + def length(self): + """Returns the length of the trajectory.""" + return self.pos_xy.shape[2] + # def restore_mean(self, mean_x, mean_y): # """Reapplies the mean to revert back to the original coordinates.""" # # Reshape for broadcasting diff --git a/gpudrive/env/base_env.py b/gpudrive/env/base_env.py index 7c5f6e273..7c5afc16b 100755 --- a/gpudrive/env/base_env.py +++ b/gpudrive/env/base_env.py @@ -62,7 +62,7 @@ def _set_reward_params(self): if ( self.config.reward_type == "sparse_on_goal_achieved" or self.config.reward_type == "weighted_combination" - or self.config.reward_type == "follow_waypoints" + or self.config.reward_type == "guided_autonomy" or self.config.reward_type == "reward_conditioned" ): reward_params.rewardType = ( diff --git a/gpudrive/env/config.py b/gpudrive/env/config.py index 186aca52c..b1fb1ddf0 100755 --- a/gpudrive/env/config.py +++ b/gpudrive/env/config.py @@ -30,6 +30,16 @@ class EnvConfig: partner_obs: bool = True # Include partner vehicle info in observations bev_obs: bool = False # Include rasterized Bird's Eye View observations centered on ego vehicle norm_obs: bool = True # Normalize observations + add_previous_action: bool = False # Previous action time agent has taken + + # Guidance settings; these are used to direct the agent's behavior and + # will be included in the observations if set to True + guidance: bool = True + guidance_mode: str = "log_replay" # Options: "log_replay", "vbd_amortized", "vbd_online", 
"goals_only" + # Ways to guide the agent + add_reference_pos_xy: bool = True # (x, y) position time series + add_reference_speed: bool = False # speed time series + add_reference_heading: bool = False # heading time series # Maximum number of controlled agents in the scene max_controlled_agents: int = madrona_gpudrive.kMaxAgentCount @@ -51,14 +61,18 @@ class EnvConfig: dynamics_model: str = ( "classic" # Options: "classic", "bicycle", "delta_local", or "state" ) - + # Action space settings (if discretized) # Classic or Invertible Bicycle dynamics model + action_space_steer_disc: int = 13 + action_space_accel_disc: int = 7 + max_steer_angle: float = 1.57 # in radians: pi/2 = 1.57, pi/3 = 1.05 + max_accel_value: float = 4.0 steer_actions: torch.Tensor = torch.round( - torch.linspace(-torch.pi / 3, torch.pi / 3, 13), decimals=3 + torch.linspace(-max_steer_angle, max_steer_angle, action_space_steer_disc), decimals=3 ) accel_actions: torch.Tensor = torch.round( - torch.linspace(-4.0, 4.0, 7), decimals=3 + torch.linspace(-max_accel_value, max_accel_value, action_space_accel_disc), decimals=3 ) head_tilt_actions: torch.Tensor = torch.Tensor([0]) @@ -94,11 +108,6 @@ class EnvConfig: # Goal behavior settings goal_behavior: str = "ignore" # Options: "stop", "ignore", "remove" - # Reference points settings - add_reference_speed: bool = ( - False # Include reference speed in observations - ) - add_reference_path: bool = False prob_reference_dropout: float = ( 0.0 # Probability of dropping reference points ) @@ -106,15 +115,20 @@ class EnvConfig: # Reward settings reward_type: str = "sparse_on_goal_achieved" - # Alternatively, "weighted_combination", "follow_waypoints", "distance_to_vdb_trajs", "reward_conditioned" + # Alternatively, "weighted_combination", "guided_autonomy", "reward_conditioned" - # If reward_type is "follow_waypoints", the following parameters are used - waypoint_sample_interval: int = 1 # Interval for sampling waypoints - waypoint_distance_scale: float = ( - 
0.01 # Importance of distance to waypoints + # If reward_type is "guided_autonomy", the following parameters are used + guidance_sample_interval: int = 1 + guidance_pos_xy_weight: float = ( + 0.01 # Importance of matching suggested positions + ) + guidance_speed_weight: float = ( + 0.01 # Importance of matching suggested speeds + ) + guidance_heading_weight: float = ( + 0.0 # Importance of matching suggested headings ) - speed_distance_scale: float = 0.0 - jerk_smoothness_scale: float = 0.0 + smoothness_weight: float = 0.0 # If reward_type is "reward_conditioned", the following parameters are used # Weights for the reward components diff --git a/gpudrive/env/constants.py b/gpudrive/env/constants.py index f457cf223..3b71fd766 100644 --- a/gpudrive/env/constants.py +++ b/gpudrive/env/constants.py @@ -1,4 +1,5 @@ import numpy as np +import madrona_gpudrive """Predefined constants for the environment.""" @@ -22,12 +23,14 @@ MAX_ROAD_SCALE = 100 # Feature shape constants -EGO_FEAT_DIM = 6 # Ego state base fields +EGO_FEAT_DIM = 4 # Ego state base fields PARTNER_FEAT_DIM = 6 ROAD_GRAPH_FEAT_DIM = 13 +ROAD_GRAPH_TOP_K = madrona_gpudrive.kMaxAgentMapObservationsCount + # Dataset constants -LOG_TRAJECTORY_LEN = 91 +LOG_TRAJECTORY_LENGTH = 91 # BEV observation constants BEV_RASTERIZATION_RESOLUTION = 200 diff --git a/gpudrive/env/env_puffer.py b/gpudrive/env/env_puffer.py index 630b26b19..8bd9ec3d3 100644 --- a/gpudrive/env/env_puffer.py +++ b/gpudrive/env/env_puffer.py @@ -168,19 +168,25 @@ def __init__( dynamics_model="classic", action_space_steer_disc=13, action_space_accel_disc=7, + max_steer_angle=1.57, + max_accel_value=4.0, ego_state=True, road_map_obs=True, partner_obs=True, norm_obs=True, lidar_obs=False, bev_obs=False, - add_reference_path=False, - add_reference_speed=False, + add_previous_action=False, + guidance=True, + add_reference_pos_xy=True, + add_reference_speed=True, + add_reference_heading=False, prob_reference_dropout=0.0, 
reward_type="weighted_combination", - waypoint_distance_scale=0.05, - speed_distance_scale=0.0, - jerk_smoothness_scale=0.0, + guidance_pos_xy_weight=0.01, + guidance_speed_weight=0.0, + guidance_heading_weight=0.0, + smoothness_weight=0.0, condition_mode="random", collision_behavior="ignore", goal_behavior="remove", @@ -204,7 +210,7 @@ def __init__( render_format="mp4", render_fps=15, zoom_radius=50, - plot_waypoints=False, + plot_guidance_pos_xy=False, buf=None, **kwargs, ): @@ -231,6 +237,12 @@ def __init__( self.init_mode = init_mode self.reward_type = reward_type + # Expert guidance + self.guidance = guidance + self.add_reference_pos_xy = add_reference_pos_xy + self.add_reference_speed = add_reference_speed + self.add_reference_heading = add_reference_heading + self.render = render self.render_interval = render_interval self.render_k_scenarios = render_k_scenarios @@ -238,7 +250,7 @@ def __init__( self.render_format = render_format self.render_fps = render_fps self.zoom_radius = zoom_radius - self.plot_waypoints = plot_waypoints + self.plot_guidance_pos_xy = plot_guidance_pos_xy self.track_realism_metrics = track_realism_metrics self.track_n_worlds = track_n_worlds @@ -259,14 +271,18 @@ def __init__( road_map_obs=road_map_obs, partner_obs=partner_obs, reward_type=reward_type, - waypoint_distance_scale=waypoint_distance_scale, - speed_distance_scale=speed_distance_scale, - jerk_smoothness_scale=jerk_smoothness_scale, + guidance_pos_xy_weight=guidance_pos_xy_weight, + guidance_speed_weight=guidance_speed_weight, + guidance_heading_weight=guidance_heading_weight, + smoothness_weight=smoothness_weight, condition_mode=condition_mode, norm_obs=norm_obs, bev_obs=bev_obs, - add_reference_path=add_reference_path, + add_previous_action=add_previous_action, + guidance=guidance, + add_reference_pos_xy=add_reference_pos_xy, add_reference_speed=add_reference_speed, + add_reference_heading=add_reference_heading, prob_reference_dropout=prob_reference_dropout, 
dynamics_model=dynamics_model, collision_behavior=collision_behavior, @@ -278,15 +294,10 @@ def __init__( lidar_obs=lidar_obs, disable_classic_obs=True if lidar_obs else False, obs_radius=obs_radius, - steer_actions=torch.round( - torch.linspace( - -torch.pi / 3, torch.pi / 3, action_space_steer_disc - ), - decimals=3, - ), - accel_actions=torch.round( - torch.linspace(-4.0, 4.0, action_space_accel_disc), decimals=3 - ), + max_steer_angle=max_steer_angle, + max_accel_value=max_accel_value, + action_space_steer_disc=action_space_steer_disc, + action_space_accel_disc=action_space_accel_disc, use_vbd=use_vbd, vbd_trajectory_weight=vbd_trajectory_weight, ) @@ -433,10 +444,10 @@ def step(self, action): reward_controlled = reward[self.controlled_agent_mask] # Store human-like and internal rewards separately - if self.reward_type == "follow_waypoints": + if self.reward_type == "guided_autonomy": self.human_like_rewards[ self.live_agent_mask - ] += self.env.distance_penalty[self.live_agent_mask] + ] += self.env.guidance_error[self.live_agent_mask] self.internal_rewards[ self.live_agent_mask ] += self.env.base_rewards[self.live_agent_mask] @@ -648,7 +659,7 @@ def render_env(self): env_indices=envs_to_render, time_steps=time_steps, zoom_radius=self.zoom_radius, - plot_waypoints=self.plot_waypoints, + plot_guidance_pos_xy=self.plot_guidance_pos_xy, ) agent_obs = self.env.vis.plot_agent_observation( @@ -775,7 +786,7 @@ def compute_realism_metrics(self, done_worlds): ) # [batch, time, 1] valid_mask = ( - self.env.log_trajectory.valids[done_worlds] + self.env.reference_trajectory.valids[done_worlds] .detach() .cpu() .numpy()[control_mask] @@ -785,14 +796,14 @@ def compute_realism_metrics(self, done_worlds): # Take human logs (ground-truth) # Shape: [worlds, max_cont_agents, time, 2] -> [batch, time, 2] ref_pos_xy_np = ( - self.env.log_trajectory.pos_xy[done_worlds] + self.env.reference_trajectory.pos_xy[done_worlds] .detach() .cpu() .numpy()[control_mask] ) # Shape: [worlds, 
max_cont_agents, time, 1] -> [batch, time, 1] ref_headings_np = ( - self.env.log_trajectory.yaw[done_worlds] + self.env.reference_trajectory.yaw[done_worlds] .detach() .cpu() .numpy()[control_mask] diff --git a/gpudrive/env/env_torch.py b/gpudrive/env/env_torch.py index 5db6d0d68..60cec0fef 100755 --- a/gpudrive/env/env_torch.py +++ b/gpudrive/env/env_torch.py @@ -12,8 +12,8 @@ from gpudrive.env.base_env import GPUDriveGymEnv from gpudrive.datatypes.trajectory import ( LogTrajectory, - to_local_frame, VBDTrajectory, + to_local_frame, ) from gpudrive.datatypes.roadgraph import ( LocalRoadGraphPoints, @@ -79,6 +79,8 @@ def __init__( # Initialize simulator self.sim = self._initialize_simulator(params, self.data_batch) + self.init_steps = self.config.init_steps + # Controlled agents setup self.cont_agent_mask = self.get_controlled_agents_mask() self.max_agent_count = self.cont_agent_mask.shape[1] @@ -86,24 +88,14 @@ def __init__( self.cont_agent_mask.sum().item() ) - self.log_trajectory = LogTrajectory.from_tensor( - self.sim.expert_trajectory_tensor(), - self.num_worlds, - self.max_agent_count, - backend=self.backend, - device=self.device, - ) + self.setup_guidance() + self.episode_len = self.config.episode_len - self.reference_path_length = self.log_trajectory.pos_xy.shape[2] self.step_in_world = ( self.episode_len - self.sim.steps_remaining_tensor().to_torch() ) - # Now initialize reward weights tensor if using reward_conditioned reward type - if ( - hasattr(self.config, "reward_type") - and self.config.reward_type == "reward_conditioned" - ): + if self.config.reward_type == "reward_conditioned": # Use default condition_mode from config or fall back to "random" condition_mode = getattr(self.config, "condition_mode", "random") self.agent_type = getattr(self.config, "agent_type", None) @@ -115,9 +107,6 @@ def __init__( (self.num_worlds, self.max_cont_agents, 3), device=self.device ) - # Initialize VBD model if used - self._initialize_vbd() - # Setup action and 
observation spaces self.observation_space = Box( low=-1.0, @@ -140,13 +129,40 @@ def __init__( self.vis = MatplotlibVisualizer( sim_object=self.sim, controlled_agent_mask=self.cont_agent_mask, + reference_trajectory=self.reference_trajectory, goal_radius=self.config.dist_to_goal_threshold, - backend=self.backend, num_worlds=self.num_worlds, render_config=self.render_config, env_config=self.config, ) + def setup_guidance(self): + """Configure the reference trajectory based on the guidance mode.""" + self.guidance_mode = self.config.guidance_mode + + if self.guidance_mode == "vbd_amortized": + trajectory_tensor = self.sim.vbd_trajectory_tensor() + self.reference_trajectory = VBDTrajectory.from_tensor( + trajectory_tensor, self.backend, self.device + ) + elif self.guidance_mode == "vbd_online": + # TODO: Add support for 'vbd_online' mode + raise ValueError( + f"Unsupported guidance mode: {self.guidance_mode}." + ) + else: # Default option is "log_replay" + trajectory_tensor = self.sim.expert_trajectory_tensor() + self.reference_trajectory = LogTrajectory.from_tensor( + trajectory_tensor, + self.num_worlds, + self.max_agent_count, + self.backend, + self.device, + ) + + # Length of the guidance trajectory + self.reference_traj_len = self.reference_trajectory.length + def _initialize_vbd(self): """ Initialize the Versatile Behavior Diffusion (VBD) model and related @@ -161,6 +177,8 @@ def _initialize_vbd(self): if self.use_vbd: self._load_vbd_trajectories() else: + self.init_steps = self.config.init_steps + self.vbd_trajectories = None def _generate_sample_batch(self, init_steps=10): @@ -425,15 +443,9 @@ def reset( return self.get_obs(mask) - def get_dones(self, world_time_steps=None): + def get_dones(self): """ Returns tensor indicating which agents have terminated. - - Args: - world_time_steps: Optional tensor [num_worlds] with current timestep per world. - - Returns: - torch.Tensor: Boolean tensor [num_worlds, num_agents] where True indicates done. 
""" terminal = ( self.sim.done_tensor() @@ -442,28 +454,12 @@ def get_dones(self, world_time_steps=None): .squeeze(dim=2) .to(torch.float) ) - - if ( - world_time_steps is not None - and self.config.reward_type == "follow_waypoints" - and self.config.waypoint_distance_scale > 0.0 - ): - # Find last valid timestep for each agent, this is the ground-truth episode length - agent_episode_length = 90 - torch.argmax( - self.log_trajectory.valids.squeeze(-1).flip(2), dim=2 - ) - - expanded_time_steps = world_time_steps.unsqueeze(1).expand_as( - agent_episode_length - ) - return terminal.bool() & ( - expanded_time_steps >= agent_episode_length - ) - - else: - return terminal.bool() + return terminal.bool() def get_infos(self): + """ + Returns the info tensor for the current step. + """ return Info.from_tensor( self.sim.info_tensor(), backend=self.backend, @@ -473,7 +469,7 @@ def get_infos(self): def get_rewards( self, collision_weight=-0.5, - goal_achieved_weight=0.0, + goal_achieved_weight=1.0, off_road_weight=-0.5, ): """Obtain the rewards for the current step.""" @@ -527,95 +523,97 @@ def get_rewards( return weighted_rewards - elif self.config.reward_type == "distance_to_vdb_trajs": - # Reward based on distance to VBD predicted trajectories - # (i.e. 
the deviation from the predicted trajectory) - weighted_rewards = ( - collision_weight * collided - + goal_achieved_weight * goal_achieved - + off_road_weight * off_road - ) + elif self.config.reward_type == "guided_autonomy": - agent_states = GlobalEgoState.from_tensor( - self.sim.absolute_self_observation_tensor(), - self.backend, - self.device, + self.base_rewards = ( + collision_weight * collided + off_road_weight * off_road ) - agent_pos = torch.stack( - [agent_states.pos_x, agent_states.pos_y], dim=-1 - ) + step_in_world = self.step_in_world[:, 0, :].squeeze(-1) - # Extract VBD positions at current time steps for each world - vbd_pos = [] - for i in range(self.num_worlds): - current_time = ( - self.world_time_steps[i].item() - self.init_steps - ) - # Make sure we don't exceed trajectory length - current_time = min( - current_time, self.vbd_trajectories.shape[2] - 1 - ) - vbd_pos.append(self.vbd_trajectories[i, :, current_time, :2]) - vbd_pos_tensor = torch.stack(vbd_pos) + # Assumption: All worlds are at the same time step + # Check if we still have referene trajectory points, if not + # we set the guidance errors to zero. 
+ if step_in_world[0] < self.reference_traj_len: + batch_indices = torch.arange(step_in_world.shape[0]) - # Compute euclidean distance between agent and logs - dist_to_vbd = torch.norm(vbd_pos_tensor - agent_pos, dim=-1) + # Guidance + suggested_pos_xy = self.reference_trajectory.pos_xy[ + batch_indices, :, step_in_world, : + ] - # Add reward based on inverse distance to logs - weighted_rewards += self.vbd_trajectory_weight * torch.exp( - -dist_to_vbd - ) + suggested_speed = self.reference_trajectory.ref_speed[ + batch_indices, :, step_in_world + ].squeeze(-1) - return weighted_rewards + suggested_heading = self.reference_trajectory.yaw[ + batch_indices, :, step_in_world + ].squeeze(-1) - elif self.config.reward_type == "follow_waypoints": - # Reward based on minimizing distance to time-aligned waypoints plus penalty for collision/off-road - self.base_rewards = ( - goal_achieved_weight * goal_achieved - + collision_weight * collided - + off_road_weight * off_road - ) + is_valid = ( + self.reference_trajectory.valids[ + batch_indices, :, step_in_world + ] + .squeeze(-1) + .bool() + ) - # Extract waypoints (ground truth) at time t - step_in_world = self.step_in_world[:, 0, :].squeeze(-1) - batch_indices = torch.arange(step_in_world.shape[0]) - gt_agent_pos = self.log_trajectory.pos_xy[ - batch_indices, :, step_in_world, : - ] + # Get actual agent positions + agent_states = GlobalEgoState.from_tensor( + self.sim.absolute_self_observation_tensor(), + self.backend, + self.device, + ) - gt_agent_speed = self.log_trajectory.ref_speed[ - batch_indices, :, step_in_world - ] - valid_mask = ( - self.log_trajectory.valids[batch_indices, :, step_in_world] - .squeeze(-1) - .bool() - ) + actual_agent_pos_xy = agent_states.pos_xy - # Get actual agent positions - agent_state = GlobalEgoState.from_tensor( - self.sim.absolute_self_observation_tensor(), - self.backend, - self.device, - ) + actual_agent_speed = ( + self.sim.self_observation_tensor().to_torch()[:, :, 0] + ) - 
actual_agent_speed = self.sim.self_observation_tensor().to_torch()[ - :, :, 0 - ] + actual_agent_heading = agent_states.rotation_angle - actual_agent_pos = torch.stack( - [agent_state.pos_x, agent_state.pos_y], dim=-1 - ) + # Compute distances + guidance_pos_error = torch.norm( + suggested_pos_xy - actual_agent_pos_xy, dim=-1 + ) + guidance_speed_error = ( + suggested_speed - actual_agent_speed + ) ** 2 + guidance_heading_error = ( + suggested_heading - actual_agent_heading + ) ** 2 + + self.guidance_error = ( + -self.config.guidance_pos_xy_weight + * torch.log(guidance_pos_error + 1.0) + - self.config.guidance_speed_weight + * torch.log(guidance_speed_error + 1.0) + - self.config.guidance_heading_weight + * torch.log(guidance_heading_error + 1.0) + ) - speed_error = (gt_agent_speed - actual_agent_speed) ** 2 + # Zero-out guidance errors for invalid time steps, that is, + # those that were not observed at the current time step + self.guidance_error[~is_valid] = 0.0 + + # Reduce guidance density + if self.config.guidance_sample_interval > 1: + waypoint_mask = ( + ( + step_in_world + % self.config.guidance_sample_interval + == 0 + ) + .float() + .unsqueeze(1) + ) + self.guidance_error = self.guidance_error * waypoint_mask - # Compute euclidean distance between agent and waypoints - dist_to_waypoints = torch.norm( - gt_agent_pos - actual_agent_pos, dim=-1 - ) + else: + self.guidance_error = torch.zeros_like(self.base_rewards) - # Penalty for jerky movements + # Encourage smooth driving if hasattr(self, "action_diff"): acceleration_jerk = ( self.action_diff[:, :, 0] ** 2 @@ -625,37 +623,16 @@ def get_rewards( ) # Second action component is steering self.smoothness_penalty = -( - self.config.jerk_smoothness_scale * acceleration_jerk - + self.config.jerk_smoothness_scale * steering_jerk + self.config.smoothness_weight * acceleration_jerk + + self.config.smoothness_weight * steering_jerk ) else: self.smoothness_penalty = torch.zeros_like(self.base_rewards) - 
self.distance_penalty = ( - -self.config.waypoint_distance_scale - * torch.log(dist_to_waypoints + 1.0) - - self.config.speed_distance_scale - * torch.log(speed_error + 1.0) - ) - - # Zero-out distance penalty for invalid time steps, that is, - # The reference positions have not been observed at every time step - # if not observed, we set the distance penalty to 0 - self.distance_penalty[~valid_mask] = 0.0 + self.guidance_error += self.smoothness_penalty - self.distance_penalty += self.smoothness_penalty - - # Apply waypoint mask only if sampling interval is greater than 1 - if self.config.waypoint_sample_interval > 1: - waypoint_mask = ( - (step_in_world % self.config.waypoint_sample_interval == 0) - .float() - .unsqueeze(1) - ) - self.distance_penalty = self.distance_penalty * waypoint_mask - - # Combine base rewards with distance penalty - rewards = self.base_rewards + self.distance_penalty + # Combine base rewards with guidance error + rewards = self.base_rewards + self.guidance_error return rewards @@ -707,7 +684,11 @@ def _apply_actions(self, actions): self.action_value_tensor.clone() ) - if self.config.dynamics_model == "state" and self.previous_action_value_tensor.shape != self.action_value_tensor.shape: + if ( + self.config.dynamics_model == "state" + and self.previous_action_value_tensor.shape + != self.action_value_tensor.shape + ): self.previous_action_value_tensor = ( self.action_value_tensor.clone() ) @@ -831,58 +812,57 @@ def _set_continuous_action_space(self) -> None: ) return action_space - def _get_ego_state(self, mask=None) -> torch.Tensor: - """Get the ego state.""" + def _get_guidance(self, mask=None) -> torch.Tensor: + """Receive (expert) suggestions from pre-trained model or logs.""" - if not self.config.ego_state: - return torch.Tensor().to(self.device) - - ego_state = LocalEgoState.from_tensor( - self_obs_tensor=self.sim.self_observation_tensor(), - backend=self.backend, - device=self.device, - mask=mask, - ) + if not self.config.guidance: 
+ return torch.zeros(0, device=self.device) - if self.config.norm_obs: - ego_state.normalize() - - base_fields = [ - ego_state.speed.unsqueeze(-1), - ego_state.vehicle_length.unsqueeze(-1), - ego_state.vehicle_width.unsqueeze(-1), - ego_state.is_collided.unsqueeze(-1), - ] + guidance = [] if mask is None: - base_fields.append( - self.previous_action_value_tensor[:, :, :2] - / constants.MAX_ACTION_VALUE, # Previous accel, steering + valid_timesteps_mask = self.reference_trajectory.valids.bool() + + # Provide agent with index to pay attention to through one-hot encoding + next_step_in_world = torch.clamp( + self.step_in_world[:, 0, :].squeeze(-1) + 1, + min=0, + max=self.reference_traj_len - 1, ) + time_one_hot = torch.zeros( + ( + self.num_worlds, + self.max_agent_count, + self.reference_traj_len, + 1, + ), + device=self.device, + ) + time_one_hot[:, :, next_step_in_world, :] = 1.0 if self.config.add_reference_speed: - - avg_ref_speed = ( - self.log_trajectory.ref_speed.clone().mean(axis=-1) + reference_speed = ( + self.reference_trajectory.ref_speed.clone() / constants.MAX_SPEED ) + reference_speed[~valid_timesteps_mask] = constants.INVALID_ID + guidance.append(reference_speed) - base_fields.append(avg_ref_speed.unsqueeze(-1)) - - if self.config.add_reference_path: - - state = ( - self.sim.absolute_self_observation_tensor() - .to_torch() - .clone().to(self.device) + states = None + if ( + self.config.add_reference_pos_xy + or self.config.add_reference_heading + ): + states = GlobalEgoState.from_tensor( + self.sim.absolute_self_observation_tensor(), + self.backend, + self.device, ) - global_ego_pos_xy = state[:, :, :2] - global_ego_yaw = state[:, :, 7] - glob_reference_xy = self.log_trajectory.pos_xy - agent_indices = torch.arange(self.max_cont_agents) + + if self.config.add_reference_pos_xy: + glob_reference_xy = self.reference_trajectory.pos_xy local_reference_xy = torch.empty_like(glob_reference_xy) - valid_timesteps_mask = self.log_trajectory.valids.bool() # 
Transform reference path to be relative to current # agent positions and heading @@ -894,8 +874,10 @@ def _get_ego_state(self, mask=None) -> torch.Tensor: global_pos_xy=glob_reference_xy[ world_idx, agent_idx, :, : ], - ego_pos=global_ego_pos_xy[world_idx, agent_idx], - ego_yaw=global_ego_yaw[world_idx, agent_idx], + ego_pos=states.pos_xy[world_idx, agent_idx], + ego_yaw=states.rotation_angle[ + world_idx, agent_idx + ], device=self.device, ) @@ -909,23 +891,6 @@ def _get_ego_state(self, mask=None) -> torch.Tensor: ~valid_timesteps_mask.expand_as(local_reference_xy) ] = constants.INVALID_ID - # Provide agent with index to pay attention to through one-hot encoding - next_step_in_world = torch.clamp( - self.step_in_world[:, 0, :].squeeze(-1) + 1, - min=0, - max=self.episode_len, - ) - time_one_hot = torch.zeros( - ( - self.num_worlds, - self.max_agent_count, - self.reference_path_length, - 1, - ), - device=self.device, - ) - time_one_hot[:, :, next_step_in_world, :] = 1.0 - # Make unnormalized reference path available for plotting self.reference_path = torch.cat( (local_ref_xy_orig, time_one_hot), dim=-1 @@ -934,100 +899,91 @@ def _get_ego_state(self, mask=None) -> torch.Tensor: reference_path = torch.cat( (local_reference_xy, time_one_hot), dim=-1 ) + guidance.append(reference_path) - # Flatten the dimensions for stacking - base_fields.append(reference_path.flatten(start_dim=2)) - - # batch_size = local_reference_xy.shape[0] - # num_points = local_reference_xy.shape[1] - # time_steps = local_reference_xy.shape[2] - - # Create dropout mask for the time dimension - # Shape: [batch_size, num_points, time_steps, 1] - # point_dropout_mask = torch.bernoulli( - # torch.ones( - # batch_size, - # num_points, - # time_steps, - # 1, - # device=local_reference_xy.device, - # ) - # * (1 - self.config.prob_reference_dropout) - # ).bool() - - # Apply dropout mask - # self.local_reference_xy = ( - # local_reference_xy * point_dropout_mask - # ) + if 
self.config.add_reference_heading: + reference_headings = self.reference_trajectory.yaw.clone() - if self.config.reward_type == "reward_conditioned": + # Transform headings to local coordinate frame + for world_idx in range(self.num_worlds): + for agent_idx in range(self.max_cont_agents): + # Subtract current agent heading to get relative heading + reference_headings[ + world_idx, agent_idx + ] -= states.rotation_angle[world_idx, agent_idx] - # Create expanded weights for all environments - # Expand from [max_agents, 3] to [num_worlds, max_agents] - collision_weights = self.reward_weights_tensor[:, 0].expand( - self.num_worlds, -1 - ) - goal_weights = self.reward_weights_tensor[:, 1].expand( - self.num_worlds, -1 - ) - off_road_weights = self.reward_weights_tensor[:, 2].expand( - self.num_worlds, -1 + # Normalize + reference_headings = ( + reference_headings / constants.MAX_ORIENTATION_RAD ) - full_fields = base_fields + [ - collision_weights, - goal_weights, - off_road_weights, - ] - return torch.stack(full_fields).permute(1, 2, 0) - else: - return torch.cat(base_fields, dim=-1) - else: + # Set invalid timesteps to -1 + reference_headings[ + ~valid_timesteps_mask + ] = constants.INVALID_ID + guidance.append(reference_headings) - base_fields.append( - self.previous_action_value_tensor[mask][:, :2] - / constants.MAX_ACTION_VALUE, # Previous accel, steering + return torch.cat(guidance, dim=-1).flatten(start_dim=2) + + else: + batch_size = mask.sum() + batch_indices = torch.arange(batch_size) + + # Provide agent with index to pay attention to through one-hot encoding + next_step_in_world = torch.clamp( + self.step_in_world[mask] + 1, + min=0, + max=self.reference_traj_len - 1, + ) + time_one_hot = torch.zeros( + (batch_size, self.reference_traj_len, 1), + device=self.device, ) + time_one_hot[batch_indices, next_step_in_world] = 1.0 + + valid_timesteps_mask = self.reference_trajectory.valids.bool()[ + mask + ] + + states = None + if ( + self.config.add_reference_pos_xy 
+ or self.config.add_reference_heading + ): + states = GlobalEgoState.from_tensor( + self.sim.absolute_self_observation_tensor(), + self.backend, + self.device, + ) if self.config.add_reference_speed: - avg_ref_speed = ( - self.log_trajectory.ref_speed[mask].clone().mean(axis=-1) + reference_speed = ( + self.reference_trajectory.ref_speed[mask].clone() / constants.MAX_SPEED ) - base_fields.append(avg_ref_speed.unsqueeze(-1)) - - if self.config.add_reference_path: - - # State information - state = ( - self.sim.absolute_self_observation_tensor() - .to_torch() - .clone()[mask] - ).to(self.device) - global_ego_pos_xy = state[:, :2] # Shape: [batch, 2] - global_ego_yaw = state[:, 7] # Shape: [batch] - global_reference_xy = self.log_trajectory.pos_xy.clone()[mask] - valid_timesteps_mask = self.log_trajectory.valids.bool()[mask] - batch_size = global_reference_xy.shape[0] - batch_indices = torch.arange(batch_size) + reference_speed[~valid_timesteps_mask] = constants.INVALID_ID + guidance.append(reference_speed) - # Translate all points to a local coordinate frame - translated = global_reference_xy - global_ego_pos_xy.unsqueeze( - 1 - ) + if self.config.add_reference_pos_xy: + global_reference_xy = self.reference_trajectory.pos_xy.clone()[ + mask + ] - # Create rotation matrices for all agents at once - cos_yaw = torch.cos(global_ego_yaw) - sin_yaw = torch.sin(global_ego_yaw) + # Translate all points to a local coordinate frame + translated = global_reference_xy - states.pos_xy[ + mask + ].unsqueeze(1) # Create batch of rotation matrices: [batch, 2, 2] + cos_yaw = torch.cos(states.rotation_angle[mask]) + sin_yaw = torch.sin(states.rotation_angle[mask]) rotation_matrices = torch.stack( [ torch.stack([cos_yaw, sin_yaw], dim=1), torch.stack([-sin_yaw, cos_yaw], dim=1), ], dim=1, - ) # Shape: [batch, 2, 2] + ) # Apply rotation to all points local_reference_xy = torch.bmm( @@ -1044,16 +1000,6 @@ def _get_ego_state(self, mask=None) -> torch.Tensor: 
~valid_timesteps_mask.expand_as(local_reference_xy) ] = constants.INVALID_ID - # Provide agent with index to pay attention to through one-hot encoding - next_step_in_world = torch.clamp( - self.step_in_world[mask] + 1, min=0, max=self.episode_len - ) - time_one_hot = torch.zeros( - (batch_size, self.reference_path_length, 1), - device=self.device, - ) - time_one_hot[batch_indices, next_step_in_world] = 1.0 - # Stack reference_path = torch.cat( (local_reference_xy, time_one_hot), dim=2 @@ -1062,15 +1008,80 @@ def _get_ego_state(self, mask=None) -> torch.Tensor: self.reference_path = torch.cat( (local_reference_xy_orig, time_one_hot), dim=2 ) + guidance.append(reference_path) - # Stack - base_fields.append(reference_path.flatten(start_dim=1)) + if self.config.add_reference_heading: + reference_headings = self.reference_trajectory.yaw[ + mask + ].clone() + + # Translate headings to local coordinate frame by subtracting current global agent headings + reference_headings = ( + reference_headings + - states.rotation_angle[mask].view(-1, 1, 1) + ) + + # Normalize by 2pi to ensure values are in [-1, 1] + reference_headings = ( + reference_headings / constants.MAX_ORIENTATION_RAD + ) + + # Set invalid timesteps to -1 + reference_headings[ + ~valid_timesteps_mask + ] = constants.INVALID_ID + guidance.append(reference_headings) + + return torch.cat(guidance, dim=-1).flatten(start_dim=1) + + def _get_ego_state(self, mask=None) -> torch.Tensor: + """Get the ego state.""" + + if not self.config.ego_state: + return torch.Tensor().to(self.device) + + ego_state = LocalEgoState.from_tensor( + self_obs_tensor=self.sim.self_observation_tensor(), + backend=self.backend, + device=self.device, + mask=mask, + ) + + if self.config.norm_obs: + ego_state.normalize() + + base_fields = [ + ego_state.speed.unsqueeze(-1), + ego_state.vehicle_length.unsqueeze(-1), + ego_state.vehicle_width.unsqueeze(-1), + ego_state.is_collided.unsqueeze(-1), + ] + + if mask is None: + + if 
self.config.add_previous_action: + normalized_prev_actions = ( + self.previous_action_value_tensor[:, :, :2] + / constants.MAX_ACTION_VALUE + ) + base_fields.append(normalized_prev_actions) if self.config.reward_type == "reward_conditioned": - # For masked agents, we need to extract agent indices from the mask - world_indices, agent_indices = torch.where(mask) - # Get the reward weights for these specific agents + full_fields = base_fields + [ + self.reward_weights_tensor.expand(self.num_worlds, -1) + ] + return torch.stack(full_fields).permute(1, 2, 0) + else: + return torch.cat(base_fields, dim=-1) + else: + + base_fields.append( + self.previous_action_value_tensor[mask][:, :2] + / constants.MAX_ACTION_VALUE, # Previous accel, steering + ) + if self.config.reward_type == "reward_conditioned": + _, agent_indices = torch.where(mask) weights_for_masked_agents = self.reward_weights_tensor.to( self.device )[agent_indices] @@ -1437,12 +1448,12 @@ def get_obs(self, mask=None): Returns: torch.Tensor: (num_worlds, max_agent_count, num_features) """ - # Base observations ego_states = self._get_ego_state(mask) partner_observations = self._get_partner_obs(mask) road_map_observations = self._get_road_map_obs(mask) + guidance = self._get_guidance(mask) - if self.use_vbd and self.config.vbd_in_obs: + if self.config.use_vbd and self.config.vbd_in_obs: # Add ego-centric VBD trajectories vbd_observations = self._get_vbd_obs(mask) @@ -1461,6 +1472,7 @@ def get_obs(self, mask=None): ego_states, partner_observations, road_map_observations, + guidance, ), dim=-1, ) @@ -1546,7 +1558,10 @@ def remove_agents_by_id( ) # Reset static scenario data for the visualizer - self.vis.initialize_static_scenario_data(self.cont_agent_mask) + self.vis.initialize_static_scenario_data( + controlled_agent_mask=self.cont_agent_mask, + reference_trajectory=self.reference_trajectory, + ) def swap_data_batch(self, data_batch=None): """ @@ -1576,36 +1591,15 @@ def swap_data_batch(self, data_batch=None): 
self.cont_agent_mask.sum().item() ) - # Load VBD trajectories for the new batch if VBD is enabled - if self.use_vbd: - self._load_vbd_trajectories() + # Receive guidance trajectories from the new batch of scenarios + self.setup_guidance() # Reset static scenario data for the visualizer - self.vis.initialize_static_scenario_data(self.cont_agent_mask) - - # Obtain new log trajectory - self.log_trajectory = LogTrajectory.from_tensor( - self.sim.expert_trajectory_tensor(), - self.num_worlds, - self.max_agent_count, - backend=self.backend, - device=self.device - ) - - def _load_vbd_trajectories(self): - """Load VBD trajectories directly from the simulator.""" - if not self.use_vbd: - return - - # Get VBD trajectories from the simulator - vbd_traj = VBDTrajectory.from_tensor( - self.sim.vbd_trajectory_tensor(), - backend=self.backend, - device=self.device, + self.vis.initialize_static_scenario_data( + controlled_agent_mask=self.cont_agent_mask, + reference_trajectory=self.reference_trajectory, ) - self.vbd_trajectories = vbd_traj - def get_expert_actions(self): """Get expert actions for the full trajectories across worlds. 
@@ -1621,7 +1615,7 @@ def get_expert_actions(self): self.num_worlds, self.max_agent_count, backend=self.backend, - device=self.device + device=self.device, ) if self.config.dynamics_model == "delta_local": @@ -1704,11 +1698,12 @@ def get_scenario_ids(self): if __name__ == "__main__": env_config = EnvConfig( - dynamics_model="delta_local", - reward_type="follow_waypoints", - add_reference_path=True, - init_mode="womd_tracks_to_predict", - init_steps=10, + guidance=True, + guidance_mode="log_replay", # "vbd_amortized" + add_reference_pos_xy=True, + add_reference_speed=True, + add_reference_heading=True, + reward_type="guided_autonomy", ) render_config = RenderConfig() @@ -1719,7 +1714,7 @@ def get_scenario_ids(self): dataset_size=1, sample_with_replacement=False, shuffle=False, - file_prefix="" + file_prefix="", ) # Make env @@ -1731,6 +1726,7 @@ def get_scenario_ids(self): ) control_mask = env.cont_agent_mask + print(f"Number of controlled agents: {control_mask.sum()}") # Rollout @@ -1742,7 +1738,7 @@ def get_scenario_ids(self): expert_actions, _, _, _ = env.get_expert_actions() env_idx = 0 - + highlight_agent = torch.where(control_mask[env_idx, :])[0][0].item() agent_positions = [] @@ -1750,12 +1746,8 @@ def get_scenario_ids(self): env.sim.absolute_self_observation_tensor(), device=env.device, ) - means_xy = ( - env.sim.world_means_tensor().to_torch()[:, :2].to(env.device) - ) - init_state.restore_mean( - mean_x=means_xy[:, 0], mean_y=means_xy[:, 1] - ) + means_xy = env.sim.world_means_tensor().to_torch()[:, :2].to(env.device) + init_state.restore_mean(mean_x=means_xy[:, 0], mean_y=means_xy[:, 1]) agent_positions.append(init_state.pos_xy[env_idx, highlight_agent]) print(f"Highlighted agent: {highlight_agent}") @@ -1783,7 +1775,7 @@ def get_scenario_ids(self): zoom_radius=70, time_steps=[t], center_agent_indices=[highlight_agent], - plot_waypoints=True, + plot_guidance_pos_xy=True, ) agent_obs = env.vis.plot_agent_observation( @@ -1798,10 +1790,6 @@ def 
get_scenario_ids(self): sim_frames.append(img_from_fig(sim_states[0])) agent_obs_frames.append(img_from_fig(agent_obs)) - world_time_steps = ( - torch.Tensor([t]).repeat((1, env.num_worlds)).long().to(env.device) - ) - obs = env.get_obs(control_mask) reward = env.get_rewards() diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index 3bf2ab1f9..70cad1461 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -672,6 +672,12 @@ def flatten_batch(self, advantages_np): self.b_values = self.b_values[b_flat] self.b_returns = self.b_advantages + self.b_values + if self.b_obs.max() > 1.0 or self.b_obs.min() < -1.0: + print( + f"Warning: The batch of observations contains features outside the range [-1, 1]." + f"Please check your observation normalization; min {self.b_obs.min()}, max {self.b_obs.max()}" + ) + class Utilization(Thread): def __init__(self, delay=1, maxlen=20): diff --git a/gpudrive/networks/agents.py b/gpudrive/networks/agents.py index 23f5ba88d..565374a35 100644 --- a/gpudrive/networks/agents.py +++ b/gpudrive/networks/agents.py @@ -40,26 +40,38 @@ def __init__( self.action_dim = action_dim self.top_k = top_k - # Indices for unpacking the observation modalities - self.ego_state_idx = ( - 9 - if self.config["reward_type"] == "reward_conditioned" - else constants.EGO_FEAT_DIM - ) - if self.config[ - "add_reference_path" - ]: # Every agent receives a reference path - # NOTE: Hardcoded to 91 for now - self.ego_state_idx += 91 * 3 + # Indices for unpacking the different observation modalities + self.ego_state_idx = constants.EGO_FEAT_DIM + if self.config["reward_type"] == "reward_conditioned": + self.ego_state_idx += 3 + if self.config["add_previous_action"]: + self.ego_state_idx += 2 - if self.config["add_reference_speed"]: - self.ego_state_idx += 1 self.max_controlled_agents = madrona_gpudrive.kMaxAgentCount self.max_observable_agents = self.max_controlled_agents - 1 self.partner_obs_idx 
= self.ego_state_idx + ( constants.PARTNER_FEAT_DIM * self.max_observable_agents ) + self.road_map_idx = self.partner_obs_idx + ( + constants.ROAD_GRAPH_TOP_K * constants.ROAD_GRAPH_FEAT_DIM + ) + + if self.config["guidance"]: + self.guidance_feature_dim = 0 + # One-hot encoding signalling the next time step + self.guidance_feature_dim += constants.LOG_TRAJECTORY_LENGTH + if self.config["add_reference_pos_xy"]: + self.guidance_feature_dim += ( + constants.LOG_TRAJECTORY_LENGTH * 2 + ) + if self.config["add_reference_speed"]: + # TODO: Change to reference path-length (90) if using the vbd_amortized trajs + self.guidance_feature_dim += constants.LOG_TRAJECTORY_LENGTH + if self.config["add_reference_heading"]: + self.guidance_feature_dim += constants.LOG_TRAJECTORY_LENGTH + self.guidance_idx = self.road_map_idx + self.guidance_feature_dim + # Shared embedding networks for both actor and critic self.ego_embed = nn.Sequential( layer_init(nn.Linear(self.ego_state_idx, embed_dim)), @@ -85,9 +97,17 @@ def __init__( layer_init(nn.Linear(embed_dim, embed_dim)), ) + self.guidance_embed = nn.Sequential( + layer_init(nn.Linear(self.guidance_feature_dim, embed_dim)), + nn.LayerNorm(embed_dim), + self.act_func, + nn.Dropout(self.dropout), + layer_init(nn.Linear(embed_dim, embed_dim)), + ) + # Critic network self.critic = nn.Sequential( - layer_init(nn.Linear((2 * top_k + 1) * embed_dim, 32)), + layer_init(nn.Linear((2 * top_k + 2) * embed_dim, 32)), nn.LayerNorm(32), self.act_func, layer_init(nn.Linear(32, 1), std=1.0), @@ -95,7 +115,7 @@ def __init__( # Actor network self.actor = nn.Sequential( - layer_init(nn.Linear((2 * top_k + 1) * embed_dim, 64)), + layer_init(nn.Linear((2 * top_k + 2) * embed_dim, 64)), nn.LayerNorm(64), self.act_func, layer_init(nn.Linear(64, action_dim), std=0.01), @@ -109,12 +129,13 @@ def forward(self, x, action=None): If None, a new actions are sampled. 
""" # Unpack into modalities - ego_state, partner_obs, road_graph = self.unpack_obs(x) + ego_state, partner_obs, road_graph, guidance = self.unpack_obs(x) # Use shared embedding networks for both actor and critic ego_embed = self.ego_embed(ego_state) partner_embed = self.partner_embed(partner_obs) road_embed = self.road_map_embed(road_graph) + guidance_embed = self.guidance_embed(guidance) # Take top k features from partner and road embeddings partner_max_pool = torch.topk(partner_embed, k=self.top_k, dim=1)[ @@ -127,7 +148,10 @@ def forward(self, x, action=None): ) # Concatenate the embeddings - z = torch.cat([ego_embed, partner_max_pool, road_max_pool], dim=-1) + z = torch.cat( + [ego_embed, partner_max_pool, road_max_pool, guidance_embed], + dim=-1, + ) # Pass to the actor and critic networks logits = self.actor(z) @@ -151,7 +175,8 @@ def unpack_obs(self, obs_flat): ego_state = obs_flat[:, : self.ego_state_idx] partner_obs = obs_flat[:, self.ego_state_idx : self.partner_obs_idx] - roadgraph_obs = obs_flat[:, self.partner_obs_idx :] + roadgraph_obs = obs_flat[:, self.partner_obs_idx : self.road_map_idx] + guidance = obs_flat[:, self.road_map_idx :] road_objects = partner_obs.view( -1, self.max_observable_agents, constants.PARTNER_FEAT_DIM @@ -160,7 +185,7 @@ def unpack_obs(self, obs_flat): -1, TOP_K_ROAD_POINTS, constants.ROAD_GRAPH_FEAT_DIM ) - return ego_state, road_objects, road_graph + return ego_state, road_objects, road_graph, guidance class SeparateActorCriticAgent(nn.Module): @@ -191,7 +216,7 @@ def __init__( else constants.EGO_FEAT_DIM ) if self.config[ - "add_reference_path" + "add_reference_pos_xy" ]: # Every agent receives a reference path self.ego_state_idx += 91 * 2 diff --git a/gpudrive/networks/late_fusion.py b/gpudrive/networks/late_fusion.py index ccda595d8..98b18fc57 100644 --- a/gpudrive/networks/late_fusion.py +++ b/gpudrive/networks/late_fusion.py @@ -107,7 +107,7 @@ def __init__( # Agents know their "type", consisting of three weights # 
that determine the reward (collision, goal, off-road) self.ego_state_idx += 3 - if "add_reference_path" in self.config: + if "add_reference_pos_xy" in self.config: self.ego_state_idx += 2 * 91 self.vbd_in_obs = self.config.vbd_in_obs diff --git a/gpudrive/utils/generate_sbatch.py b/gpudrive/utils/generate_sbatch.py index d4c61e975..d81078a58 100644 --- a/gpudrive/utils/generate_sbatch.py +++ b/gpudrive/utils/generate_sbatch.py @@ -243,30 +243,33 @@ def save_script(filename, file_path, fields, params, param_order=None): if __name__ == "__main__": - group = "minimal_tiny" + group = "wosac_scale_100_v3" fields = { - "time_h": 6, # Max time per job (job will finish if run is done before) + "time_h": 47, # Max time per job (job will finish if run is done before) "num_gpus": 1, # GPUs per job "max_sim_jobs": 30, # Max jobs at the same time "memory": 70, "job_name": group, - "run_file": "baselines/ppo/ppo_waypoint.py", + "run_file": "baselines/ppo/ppo_guided_autonomy.py", } - + hyperparams = { - "group": [group], # Group name - "num_worlds": [500], - "resample_scenes": [0], - "k_unique_scenes": [4], - #"resample_interval": [5_000_000], - #"resample_dataset_size": [10_000], - #"total_timesteps": [3_000_000_000], + "group": [group], # Group name + "num_worlds": [700], + "resample_scenes": [0], + "k_unique_scenes": [100], + # "resample_interval": [5_000_000], + # "resample_dataset_size": [10_000], + # "total_timesteps": [3_000_000_000], "batch_size": [262_144], "minibatch_size": [16_384], - "waypoint_distance_scale": [0.0, 0.05, 0.1], - "ent_coef": [0.001], + # "guidance_pos_xy_weight": [0.01], + # "guidance_speed_weight": [0.01], + # "guidance_heading_weight": [0.01], + "ent_coef": [0.005, 0.001, 0.003, 0.01], "vf_coef": [0.5], + "update_epochs": [4], "render": [0], } diff --git a/gpudrive/visualize/core.py b/gpudrive/visualize/core.py index 02545c92f..c51a8677d 100644 --- a/gpudrive/visualize/core.py +++ b/gpudrive/visualize/core.py @@ -41,8 +41,8 @@ def __init__( self, 
sim_object, controlled_agent_mask, + reference_trajectory, goal_radius, - backend: str, num_worlds: int, render_config: Dict[str, Any], env_config: Dict[str, Any], @@ -56,12 +56,17 @@ def __init__( self.figsize = (15, 15) self.env_config = env_config self.render_3d = render_config.render_3d - self.vehicle_height = ( - render_config.vehicle_height - ) # Default vehicle height - self.initialize_static_scenario_data(controlled_agent_mask) + self.vehicle_height = render_config.vehicle_height + self.initialize_static_scenario_data( + controlled_agent_mask=controlled_agent_mask, + reference_trajectory=reference_trajectory, + ) - def initialize_static_scenario_data(self, controlled_agent_mask): + def initialize_static_scenario_data( + self, + controlled_agent_mask, + reference_trajectory, + ): """ Initialize key information for visualization based on the current batch of scenarios. @@ -76,28 +81,16 @@ def initialize_static_scenario_data(self, controlled_agent_mask): backend=self.backend, device=self.device, ) - self.controlled_agent_mask = controlled_agent_mask + self.controlled_agent_mask = controlled_agent_mask.clone().to( + self.device + ) if isinstance(controlled_agent_mask, ArrayImpl): self.controlled_agent_mask = torch.from_numpy( np.array(controlled_agent_mask) ) - self.controlled_agent_mask = self.controlled_agent_mask.to(self.device) - - self.log_trajectory = LogTrajectory.from_tensor( - self.sim_object.expert_trajectory_tensor(), - self.num_worlds, - self.controlled_agent_mask.shape[1], - backend=self.backend, - device=self.device, - ) - - self.vbd_trajectory = VBDTrajectory.from_tensor( - self.sim_object.vbd_trajectory_tensor(), - backend=self.backend, - device=self.device, - ) + self.trajectory = reference_trajectory def plot_simulator_state( self, @@ -105,8 +98,7 @@ def plot_simulator_state( time_steps: Optional[List[int]] = None, center_agent_indices: Optional[List[int]] = None, zoom_radius: int = 100, - plot_waypoints: bool = False, - plot_vbd_trajectory: 
bool = False, + plot_guidance_pos_xy: bool = False, agent_positions: Optional[torch.Tensor] = None, backward_goals: bool = False, policy_masks: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, @@ -126,7 +118,7 @@ def plot_simulator_state( center_agent_indices: Optional list of center agent indices for zooming. figsize: Tuple for figure size of each subplot. zoom_radius: Radius for zooming in around the center agent. - plot_waypoints: If True, plots the waypoints from the human replays. + plot_guidance_pos_xy: If True, plots the waypoints from the human replays. agent_positions: Optional tensor to plot rolled out agent positions. backward_goals: If True, plots backward goals for controlled agents. policy_mask: dict @@ -257,7 +249,6 @@ def plot_simulator_state( line_width_scale = max(self.figsize) / 15 if policy_masks: - world_based_policy_mask = {} for policy_name, (fn, mask) in policy_masks.items(): @@ -265,7 +256,6 @@ def plot_simulator_state( if world not in world_based_policy_mask: world_based_policy_mask[world] = {} world_based_policy_mask[world][policy_name] = mask[world] - else: world_based_policy_mask = None @@ -273,19 +263,27 @@ def plot_simulator_state( for idx, (env_idx, time_step, center_agent_idx) in enumerate( zip(env_indices, time_steps, center_agent_indices) ): + # Create a completely new figure and axis for each environment to prevent carryover + plt.close( + "all" + ) # Close all existing figures first to prevent memory leaks - # Initialize figure and axes from cached road graph - fig, ax = plt.subplots( - figsize=self.figsize, - subplot_kw={"projection": "3d"} if self.render_3d else {}, - ) + # Initialize a new figure for each environment + fig = plt.figure(figsize=self.figsize) + + # Create a new axis with proper projection if self.render_3d: + ax = fig.add_subplot(111, projection="3d") ax.view_init(elev=30, azim=45) # Set default 3D view angle + else: + ax = fig.add_subplot(111) + + # Set up the figure and axis fig.subplots_adjust(left=0, 
right=1, bottom=0, top=1) - ax.clear() # Clear any existing content ax.set_aspect("equal", adjustable="box") - figs.append(fig) # Add the new figure - plt.close(fig) # Close the figure to prevent carryover + + # Add to figures list - use a copy to ensure it's detached from the current matplotlib state + figs.append(fig) # Get control mask and omit out-of-bound agents (dead agents) controlled = self.controlled_agent_mask[env_idx, :] @@ -308,21 +306,12 @@ def plot_simulator_state( marker_size_scale=marker_scale, ) - if plot_waypoints: - self._plot_waypoints( + if plot_guidance_pos_xy: + self._plot_reference_xy( ax=ax, control_mask=controlled_live, env_idx=env_idx, - log_trajectory=self.log_trajectory, - line_width_scale=line_width_scale, - ) - - if plot_vbd_trajectory: - self._plot_vbd_trajectory( - ax=ax, - control_mask=controlled_live, - env_idx=env_idx, - vbd_trajectory=self.vbd_trajectory, + trajectory=self.trajectory, line_width_scale=line_width_scale, ) @@ -389,9 +378,12 @@ def plot_simulator_state( ax.set_axis_off() - for fig in figs: + # Apply tight layout to current figure fig.tight_layout(pad=2, rect=[0.00, 0.00, 0.9, 1]) + # Close the figure to prevent memory leaks and cleanup + plt.close(fig) + return figs def plot_agent_trajectories( @@ -626,173 +618,125 @@ def plot_agent_trajectories( return None - def _plot_waypoints( + def _plot_reference_xy( self, ax: matplotlib.axes.Axes, env_idx: int, control_mask: torch.Tensor, - log_trajectory: LogTrajectory, + trajectory, line_width_scale: int = 1.0, ): """Plot the log replay trajectory for controlled agents in either 2D or 3D.""" if self.render_3d: - # Get trajectory points - trajectory_points = log_trajectory.pos_xy[ - env_idx, control_mask, :, : - ].numpy() - - # Set a fixed height for trajectory visualization - trajectory_height = 0.05 # Small height above ground - - # Plot trajectories for each controlled agent - for agent_trajectory in trajectory_points: - # Filter out invalid points (zeros or out of 
bounds) - valid_mask = ( - (agent_trajectory[:, 0] != 0) - & (agent_trajectory[:, 1] != 0) - & (np.abs(agent_trajectory[:, 0]) < OUT_OF_BOUNDS) - & (np.abs(agent_trajectory[:, 1]) < OUT_OF_BOUNDS) + # Get trajectory points - make a clean copy to avoid reference issues + try: + trajectory_points = ( + trajectory.pos_xy[env_idx, control_mask, :, :] + .clone() + .numpy() ) - valid_points = agent_trajectory[valid_mask] - if len(valid_points) > 1: - # Create segments for the trajectory - segments = [] - for i in range(len(valid_points) - 1): - segment = np.array( - [ - [ - valid_points[i, 0], - valid_points[i, 1], - trajectory_height, - ], + # Set a fixed height for trajectory visualization + trajectory_height = 0.05 # Small height above ground + + # Plot trajectories for each controlled agent + for agent_trajectory in trajectory_points: + # Filter out invalid points (zeros or out of bounds) + valid_mask = ( + (agent_trajectory[:, 0] != 0) + & (agent_trajectory[:, 1] != 0) + & (np.abs(agent_trajectory[:, 0]) < OUT_OF_BOUNDS) + & (np.abs(agent_trajectory[:, 1]) < OUT_OF_BOUNDS) + ) + valid_points = agent_trajectory[valid_mask] + + if len(valid_points) > 1: + # Create segments for the trajectory + segments = [] + for i in range(len(valid_points) - 1): + segment = np.array( [ - valid_points[i + 1, 0], - valid_points[i + 1, 1], - trajectory_height, - ], - ] - ) - segments.append(segment) + [ + valid_points[i, 0], + valid_points[i, 1], + trajectory_height, + ], + [ + valid_points[i + 1, 0], + valid_points[i + 1, 1], + trajectory_height, + ], + ] + ) + segments.append(segment) - # Create line collection with fade effect - colors = np.zeros((len(segments), 4)) - colors[:, 1] = 0.9 # Green component - colors[:, 3] = np.linspace( - 0.2, 0.6, len(segments) - ) # Alpha gradient + # Create line collection with fade effect + colors = np.zeros((len(segments), 4)) + colors[:, 1] = 0.9 # Green component + colors[:, 3] = np.linspace( + 0.2, 0.6, len(segments) + ) # Alpha gradient - lc = 
Line3DCollection( - segments, colors=colors, linewidth=2 * line_width_scale - ) - ax.add_collection3d(lc) + # Create a fresh line collection for each plot + lc = Line3DCollection( + segments, + colors=colors, + linewidth=2 * line_width_scale, + ) + ax.add_collection3d(lc) - # Add points at trajectory positions - ax.scatter3D( - valid_points[:, 0], - valid_points[:, 1], - np.full_like(valid_points[:, 0], trajectory_height), - color="lightgreen", - s=10, - alpha=0.5, - zorder=0, - ) + # Add points at trajectory positions - creating a new scatter object each time + ax.scatter3D( + valid_points[:, 0], + valid_points[:, 1], + np.full_like( + valid_points[:, 0], trajectory_height + ), + color="lightgreen", + s=10, + alpha=0.5, + zorder=0, + ) + except Exception as e: + print(f"Error plotting 3D reference trajectory: {e}") else: - # Original 2D plotting - ax.scatter( - log_trajectory.pos_xy[env_idx, control_mask, :, 0] - .cpu() - .numpy(), - log_trajectory.pos_xy[env_idx, control_mask, :, 1] - .cpu() - .numpy(), - color="lightgreen", - linewidth=0.35 * line_width_scale, - alpha=0.35, - zorder=0, - ) - - def _plot_vbd_trajectory( - self, - ax: matplotlib.axes.Axes, - env_idx: int, - control_mask: torch.Tensor, - vbd_trajectory: VBDTrajectory, - line_width_scale: int = 1.0, - ): - """Plot the VBD trajectory for controlled agents in either 2D or 3D.""" - if self.render_3d: - # Get trajectory points - trajectory_points = vbd_trajectory.pos_xy[ - env_idx, control_mask, :, : - ].numpy() - - # Set a fixed height for trajectory visualization - trajectory_height = 0.05 + try: + # Create a new scatter plot for this specific environment and control mask + pos_x = ( + trajectory.pos_xy.clone()[env_idx, control_mask, :, 0] + .cpu() + .numpy() + ) + pos_y = ( + trajectory.pos_xy.clone()[env_idx, control_mask, :, 1] + .cpu() + .numpy() + ) - # Plot trajectories for each controlled agent - for agent_trajectory in trajectory_points: # Filter out invalid points (zeros or out of bounds) 
valid_mask = ( - (agent_trajectory[:, 0] != 0) - & (agent_trajectory[:, 1] != 0) - & (np.abs(agent_trajectory[:, 0]) < OUT_OF_BOUNDS) - & (np.abs(agent_trajectory[:, 1]) < OUT_OF_BOUNDS) + (pos_x != 0) + & (pos_y != 0) + & (np.abs(pos_x) < OUT_OF_BOUNDS) + & (np.abs(pos_y) < OUT_OF_BOUNDS) ) - valid_points = agent_trajectory[valid_mask] - if len(valid_points) > 1: - # Create segments for the trajectory - segments = [] - for i in range(len(valid_points) - 1): - segment = np.array( - [ - [ - valid_points[i, 0], - valid_points[i, 1], - trajectory_height, - ], - [ - valid_points[i + 1, 0], - valid_points[i + 1, 1], - trajectory_height, - ], - ] - ) - segments.append(segment) - - # Create line collection with fade effect - colors = np.zeros((len(segments), 4)) - colors[:, 1] = 0.9 # Green component - colors[:, 3] = np.linspace( - 0.2, 0.6, len(segments) - ) # Alpha gradient - - lc = Line3DCollection( - segments, colors=colors, linewidth=2 * line_width_scale - ) - ax.add_collection3d(lc) + # Apply mask if any valid points exist + if np.any(valid_mask): + pos_x = pos_x[valid_mask] + pos_y = pos_y[valid_mask] - # Add points at trajectory positions - ax.scatter3D( - valid_points[:, 0], - valid_points[:, 1], - np.full_like(valid_points[:, 0], trajectory_height), + # Create a fresh scatter plot + ax.scatter( + pos_x, + pos_y, color="lightgreen", - s=10, - alpha=0.5, + linewidth=0.3 * line_width_scale, + alpha=0.25, zorder=0, ) - else: - # Original 2D plotting - ax.scatter( - vbd_trajectory.pos_xy[env_idx, control_mask, :, 0].numpy(), - vbd_trajectory.pos_xy[env_idx, control_mask, :, 1].numpy(), - color="lightgreen", - linewidth=0.35 * line_width_scale, - alpha=0.35, - zorder=0, - ) + except Exception as e: + print(f"Error plotting 2D reference trajectory: {e}") def _plot_vbd_trajectory( self,