From c1a2687c98913de9864942804f44eb386fee00ed Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 29 Apr 2025 11:54:18 -0400 Subject: [PATCH 1/5] detect parked script --- data_utils/detect_parked.py | 338 ++++++++++++++++++++++++++++++++++++ 1 file changed, 338 insertions(+) create mode 100644 data_utils/detect_parked.py diff --git a/data_utils/detect_parked.py b/data_utils/detect_parked.py new file mode 100644 index 000000000..20ecb7aac --- /dev/null +++ b/data_utils/detect_parked.py @@ -0,0 +1,338 @@ +import os +import json +import logging +import psutil +import argparse +import math +from pathlib import Path +import numpy as np +from multiprocessing import Pool, cpu_count +from tqdm import tqdm +import trimesh + +logging.basicConfig(level=logging.INFO) + +def _filter_small_segments(segments, min_length=1e-6): + """Filter out segments that are too short.""" + valid_segments = [] + for segment in segments: + start, end = segment + length = np.linalg.norm(np.array(end) - np.array(start)) + if length >= min_length: + valid_segments.append(segment) + return valid_segments + + +def _generate_mesh(segments, height=2.0, width=0.2): + """Generate a mesh from line segments.""" + segments = np.array(segments, dtype=np.float64) + starts, ends = segments[:, 0, :], segments[:, 1, :] + directions = ends - starts + lengths = np.linalg.norm(directions, axis=1, keepdims=True) + unit_directions = directions / lengths + + # Create the base box mesh with the height along the z-axis + base_box = trimesh.creation.box(extents=[1.0, width, height]) + base_box.apply_translation([0.5, 0, 0]) # Align box's origin to its start + z_axis = np.array([0, 0, 1]) + angles = np.arctan2( + unit_directions[:, 1], unit_directions[:, 0] + ) # Rotation in the XY plane + + rectangles = [] + lengths = lengths.flatten() + + for i, (start, length, angle) in enumerate(zip(starts, lengths, angles)): + # Copy the base box and scale to match segment length + scaled_box = base_box.copy() + scaled_box.apply_scale([length, 1.0, 1.0]) + + # Apply rotation around the z-axis + rotation_matrix = trimesh.transformations.rotation_matrix( + angle, z_axis + ) + scaled_box.apply_transform(rotation_matrix) + + # Translate the box to the segment's starting point + scaled_box.apply_translation(start) + + rectangles.append(scaled_box) + + # Concatenate all boxes into a single mesh + mesh = trimesh.util.concatenate(rectangles) + return mesh + + +def _create_agent_box_mesh(position, heading, length, width, height): + """Create a box mesh for an agent at a given position and orientation. + + Args: + position (list): [x, y, z] position + heading (float): yaw angle in radians + length (float): length of the box + width (float): width of the box + height (float): height of the box + + Returns: + trimesh.Trimesh: Box mesh positioned and oriented correctly + """ + # Create box centered at origin + box = trimesh.creation.box(extents=[length, width, height]) + + # Rotate box to align with heading + z_axis = np.array([0, 0, 1]) + rotation_matrix = trimesh.transformations.rotation_matrix(heading, z_axis) + box.apply_transform(rotation_matrix) + + # Move box to position + box.apply_translation(position) + + return box + + +def calculate_trajectory_length(positions, valid_mask, start_idx, end_idx=90): + """ + Calculate the sum of valid trajectory segments from start_idx to end_idx. + + Args: + positions: List of position dictionaries with x, y, z coordinates + valid_mask: List of boolean values indicating valid timesteps + start_idx: Starting index for calculation + end_idx: Ending index for calculation + + Returns: + float: Sum of valid trajectory segments length + """ + # Limit end_idx to the maximum available index + end_idx = min(end_idx, len(positions) - 1) + + # Initialize total length + total_length = 0.0 + + # Iterate through positions from start_idx to end_idx + for i in range(start_idx, end_idx): + # Check if both current and next positions are valid + if valid_mask[i] and valid_mask[i + 1]: + # Calculate distance between consecutive valid positions + current_pos = np.array([positions[i]['x'], positions[i]['y'], positions[i]['z']]) + next_pos = np.array([positions[i + 1]['x'], positions[i + 1]['y'], positions[i + 1]['z']]) + + segment_length = np.linalg.norm(next_pos - current_pos) + total_length += segment_length + + return total_length + + +def process_scene(args): + """Process a single scene file to detect parked vehicles.""" + filepath, init_steps, threshold = args + try: + with open(filepath, 'r') as f: + scene = json.load(f) + + # Initialize counters for this scene + valid_vehicles = 0 + colliding_vehicles = 0 + parked_vehicles = 0 + moving_vehicles = 0 + + # Extract road data for collision checking + roads = scene['roads'] + edge_segments = [] + + # Collect road edge segments for collision checking + for road in roads: + if road["type"] == "road_edge": + edge_vertices = [[r["x"], r["y"], r["z"]] for r in road["geometry"]] + edge_segments.extend([ + [edge_vertices[i], edge_vertices[i + 1]] + for i in range(len(edge_vertices) - 1) + ]) + + # Generate road edge mesh + edge_segments = _filter_small_segments(edge_segments) + edge_mesh = _generate_mesh(edge_segments) + + # Create collision managers + road_collision_manager = trimesh.collision.CollisionManager() + road_collision_manager.add_object("road_edges", edge_mesh) + vehicle_collision_manager = trimesh.collision.CollisionManager() + + # First, collect all valid vehicles at init_steps + valid_vehicle_objects = [] + for obj in scene['objects']: + # Check if object is a vehicle and valid at init_steps + if (obj['type'] == 'vehicle' and + init_steps < len(obj['valid']) and + obj['valid'][init_steps]): + + valid_vehicles += 1 + valid_vehicle_objects.append(obj) + + # Create vehicle box at init_steps + position = [ + obj['position'][init_steps]['x'], + obj['position'][init_steps]['y'], + obj['position'][init_steps]['z'] + ] + heading = obj['heading'][init_steps] + vehicle_box = _create_agent_box_mesh( + position, + heading, + obj['length'], + obj['width'], + obj['height'] + ) + vehicle_collision_manager.add_object(str(obj['id']), vehicle_box) + + # Check for collisions with other vehicles + _, vehicle_collision_pairs = vehicle_collision_manager.in_collision_internal(return_names=True) + + # Check for collisions with road edges + _, road_collision_pairs = vehicle_collision_manager.in_collision_other( + road_collision_manager, return_names=True + ) + + # Combine all colliding vehicle IDs + colliding_vehicle_ids = set() + + # Add vehicles that collide with each other + for v1, v2 in vehicle_collision_pairs: + colliding_vehicle_ids.add(v1) + colliding_vehicle_ids.add(v2) + + # Add vehicles that collide with road edges + for v_id, _ in road_collision_pairs: + colliding_vehicle_ids.add(v_id) + + colliding_vehicles = len(colliding_vehicle_ids) + + # Process each colliding vehicle to determine if parked or moving + for obj in valid_vehicle_objects: + if str(obj['id']) in colliding_vehicle_ids: + # Calculate trajectory length from init_steps to 90 or last timestep + trajectory_length = calculate_trajectory_length( + obj['position'], + obj['valid'], + init_steps, + min(90, len(obj['valid']) - 1) + ) + + # Classify as parked or moving based on threshold + if trajectory_length < threshold: + parked_vehicles += 1 + else: + moving_vehicles += 1 + + return filepath, (valid_vehicles, colliding_vehicles, parked_vehicles, moving_vehicles) + + except Exception as e: + logging.error(f"Error processing {filepath}: {e}") + return filepath, None + + +def process_directory(args): + """Process all JSON files in directory.""" + input_dir = Path(args.input_dir) + num_workers = args.num_workers + init_steps = args.init_steps + threshold = args.threshold + + # Get all JSON files + json_files = list(input_dir.glob("*.json")) + if not json_files: + logging.error(f"No JSON files found in {input_dir}") + return + + logging.info(f"Found {len(json_files)} JSON files to process") + + # Calculate batch size based on available memory + mem_info = psutil.virtual_memory() + available_memory = mem_info.available / (1024**3) # Convert to GB + usable_memory = int(available_memory * 0.9) # Use 90% of available memory + batch_size = min(1000 * usable_memory, len(json_files)) + + # Initialize counters using numpy int64 to handle large numbers + total_processed = np.int64(0) + total_valid_vehicles = np.int64(0) + total_colliding_vehicles = np.int64(0) + total_parked_vehicles = np.int64(0) + total_moving_vehicles = np.int64(0) + + # Process files in batches + for i in range(0, len(json_files), int(batch_size)): + batch = json_files[i:i + int(batch_size)] + + # Process batch in parallel + with Pool(num_workers) as pool: + results = list(tqdm( + pool.imap(process_scene, [(str(f), init_steps, threshold) for f in batch]), + total=len(batch), + desc=f"Processing batch {i//int(batch_size) + 1}" + )) + + # Count results + for filepath, counts in results: + if counts is not None: + valid_vehicles, colliding_vehicles, parked_vehicles, moving_vehicles = counts + total_processed += 1 + total_valid_vehicles += valid_vehicles + total_colliding_vehicles += colliding_vehicles + total_parked_vehicles += parked_vehicles + total_moving_vehicles += moving_vehicles + + # Calculate percentages using float64 for precision + parked_percentage = (float(total_parked_vehicles) / float(total_colliding_vehicles) * 100) if total_colliding_vehicles > 0 else 0.0 + moving_percentage = (float(total_moving_vehicles) / float(total_colliding_vehicles) * 100) if total_colliding_vehicles > 0 else 0.0 + + logging.info(f"Processing complete!") + logging.info(f"Total files processed: {total_processed:,d}") + logging.info(f"Total vehicles valid at t={init_steps}: {total_valid_vehicles:,d}") + logging.info(f"Total vehicles in collision at t={init_steps}: {total_colliding_vehicles:,d}") + logging.info(f"Total parked vehicles: {total_parked_vehicles:,d} ({parked_percentage:.2f}%)") + logging.info(f"Total moving vehicles: {total_moving_vehicles:,d} ({moving_percentage:.2f}%)") + + # Save results to a JSON file for future reference + results = { + "total_files_processed": int(total_processed), + "total_valid_vehicles": int(total_valid_vehicles), + "total_colliding_vehicles": int(total_colliding_vehicles), + "total_parked_vehicles": int(total_parked_vehicles), + "total_moving_vehicles": int(total_moving_vehicles), + "parked_percentage": float(parked_percentage), + "moving_percentage": float(moving_percentage) + } + + with open('parked_vehicle_results.json', 'w') as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Analyze vehicle parking behavior in JSON files" + ) + parser.add_argument( + "input_dir", + help="Directory containing JSON files to process" + ) + parser.add_argument( + "--num_workers", + type=int, + default=cpu_count(), + help="Number of worker processes (default: number of CPU cores)" + ) + parser.add_argument( + "--init_steps", + type=int, + default=10, + help="Initial timestep to check for vehicle validity and collisions (default: 10)" + ) + parser.add_argument( + "--threshold", + type=float, + default=0.4, + help="Threshold for trajectory length to classify a vehicle as moving (default: 0.2)" + ) + + args = parser.parse_args() + process_directory(args) \ No newline at end of file From 78f73c346bcee4ef5d5e64926a3040cdfc763964 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 29 Apr 2025 12:31:36 -0400 Subject: [PATCH 2/5] default arg change --- data_utils/detect_parked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_utils/detect_parked.py b/data_utils/detect_parked.py index 20ecb7aac..7a3750e7e 100644 --- a/data_utils/detect_parked.py +++ b/data_utils/detect_parked.py @@ -330,7 +330,7 @@ def process_directory(args): parser.add_argument( "--threshold", type=float, - default=0.4, + default=1.0, help="Threshold for trajectory length to classify a vehicle as moving (default: 0.2)" ) From e5a9c8bba89942e93f644c3dba1fdf7014c5a42e Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 29 Apr 2025 13:28:49 -0400 Subject: [PATCH 3/5] parked vehicle mask --- examples/eval/run_wosac_eval.py | 13 +++++++++++-- gpudrive/datatypes/info.py | 2 ++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py index a939cc8aa..5add2ecde 100644 --- a/examples/eval/run_wosac_eval.py +++ b/examples/eval/run_wosac_eval.py @@ -13,8 +13,10 @@ from gpudrive.env.env_torch import GPUDriveTorchEnv from gpudrive.env.dataset import SceneDataLoader from gpudrive.datatypes.observation import GlobalEgoState +from gpudrive.datatypes.info import Info from gpudrive.utils.checkpoint import load_agent from gpudrive.visualize.utils import img_from_fig +import madrona_gpudrive # WOSAC sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -67,10 +69,17 @@ def rollout( start_env_rollout = perf_counter() - control_mask = env.cont_agent_mask - next_obs = env.reset(control_mask) + info = Info.from_tensor( + env.sim.info_tensor(), + backend=env.backend, + device=env.device, + ) + # Zero out actions for parked vehicles + zero_action_mask = (info.off_road == 1) & (info.collided_with_vehicle == 1) & (info.type == int(madrona_gpudrive.EntityType.Vehicle)) + control_mask = env.cont_agent_mask & ~zero_action_mask + # Get scenario ids scenario_ids_dict = env.get_scenario_ids() scenario_ids = list(scenario_ids_dict.values()) diff --git a/gpudrive/datatypes/info.py b/gpudrive/datatypes/info.py index 3556294df..c0cca1a84 100644 --- a/gpudrive/datatypes/info.py +++ b/gpudrive/datatypes/info.py @@ -12,7 +12,9 @@ def __init__(self, info_tensor: torch.Tensor): """Initializes the ego state with an observation tensor.""" self.off_road = info_tensor[:, :, 0] self.collided = info_tensor[:, :, 1:3].sum(axis=2) + self.collided_with_vehicle = info_tensor[:, :, 1] self.goal_achieved = info_tensor[:, :, 3] + self.type = info_tensor[:, :, 4] @classmethod def from_tensor( From 57c86e1ab2377d9e008611fa5a0475b6122ce58d Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 29 Apr 2025 14:57:13 -0400 Subject: [PATCH 4/5] train and eval init modes --- baselines/ppo/config/ppo_guided_autonomy.yaml | 2 +- examples/eval/run_wosac_eval.py | 2 +- gpudrive/env/base_env.py | 6 +++++- gpudrive/env/config.py | 2 +- gpudrive/integrations/vbd/data/amortize.py | 2 +- src/init.hpp | 3 ++- src/level_gen.cpp | 8 ++++++++ 7 files changed, 19 insertions(+), 6 deletions(-) diff --git a/baselines/ppo/config/ppo_guided_autonomy.yaml b/baselines/ppo/config/ppo_guided_autonomy.yaml index e3395d25d..ec8da96eb 100644 --- a/baselines/ppo/config/ppo_guided_autonomy.yaml +++ b/baselines/ppo/config/ppo_guided_autonomy.yaml @@ -34,7 +34,7 @@ environment: # Overrides default environment configs (see pygpudrive/env/config. guidance_heading_weight: 0.01 smoothness_weight: 0.001 - init_mode: womd_tracks_to_predict + init_mode: wosac_train dynamics_model: "classic" remove_non_vehicles: false collision_behavior: "ignore" diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py index 5add2ecde..4fabc00dc 100644 --- a/examples/eval/run_wosac_eval.py +++ b/examples/eval/run_wosac_eval.py @@ -77,7 +77,7 @@ def rollout( device=env.device, ) # Zero out actions for parked vehicles - zero_action_mask = (info.off_road == 1) & (info.collided_with_vehicle == 1) & (info.type == int(madrona_gpudrive.EntityType.Vehicle)) + zero_action_mask = (info.off_road == 1) | (info.collided_with_vehicle == 1) & (info.type == int(madrona_gpudrive.EntityType.Vehicle)) control_mask = env.cont_agent_mask & ~zero_action_mask # Get scenario ids diff --git a/gpudrive/env/base_env.py b/gpudrive/env/base_env.py index 7c5afc16b..3006c3c0f 100755 --- a/gpudrive/env/base_env.py +++ b/gpudrive/env/base_env.py @@ -117,11 +117,15 @@ def _setup_environment_parameters(self): ) params.rewardParams = self._set_reward_params() params.maxNumControlledAgents = self.max_cont_agents - if self.config.init_mode == "womd_tracks_to_predict": + if self.config.init_mode == "wosac_eval": # Bypasses all gpudrive initialization rules and directly reads from the tracks_to_predict # flag in the WOMD dataset metadata params.readFromTracksToPredict = True params.isStaticAgentControlled = True + params.controlExperts = True + elif self.config.init_mode == "wosac_train": + params.readFromTracksToPredict = True + params.isStaticAgentControlled = True elif self.config.init_mode == "all_objects": params.isStaticAgentControlled = True params.initOnlyValidAgentsAtFirstStep = False diff --git a/gpudrive/env/config.py b/gpudrive/env/config.py index b1fb1ddf0..9e3af7609 100755 --- a/gpudrive/env/config.py +++ b/gpudrive/env/config.py @@ -173,7 +173,7 @@ class EnvConfig: agent_size_scale: float = madrona_gpudrive.vehicleScale # Initialization mode - init_mode: str = "all_non_trivial" # Options: all_non_trivial, all_objects, all_valid, womd_tracks_to_predict + init_mode: str = "all_non_trivial" # Options: all_non_trivial, all_objects, all_valid, wosac_eval, wosac_train # VBD model settings use_vbd: bool = False diff --git a/gpudrive/integrations/vbd/data/amortize.py b/gpudrive/integrations/vbd/data/amortize.py index 194926203..35fa375d6 100644 --- a/gpudrive/integrations/vbd/data/amortize.py +++ b/gpudrive/integrations/vbd/data/amortize.py @@ -64,7 +64,7 @@ def main(): init_steps=INIT_STEPS, # Warmup period dynamics_model="state", # Use state-based dynamics model dist_to_goal_threshold=1e-5, # Trick to make sure the agents don't disappear when they reach the goal - init_mode = 'womd_tracks_to_predict', + init_mode = 'wosac_eval', max_controlled_agents=MAX_CONTROLLED_AGENTS, goal_behavior='ignore' ) diff --git a/src/init.hpp b/src/init.hpp index 0c7e25d3c..8d4ff60dc 100755 --- a/src/init.hpp +++ b/src/init.hpp @@ -131,8 +131,9 @@ namespace madrona_gpudrive bool enableLidar = false; bool disableClassicalObs = false; DynamicsModel dynamicsModel = DynamicsModel::Classic; - bool readFromTracksToPredict = false; // Default: false - for womd_tracks_to_predict initialization mode + bool readFromTracksToPredict = false; // Default: false - for wosac initialization mode uint32_t initSteps = 0; + bool controlExperts = false; // Default: false - for wosac initialization mode }; struct WorldInit diff --git a/src/level_gen.cpp b/src/level_gen.cpp index 76da34ed7..c43ad403f 100755 --- a/src/level_gen.cpp +++ b/src/level_gen.cpp @@ -133,12 +133,20 @@ static inline bool isAgentControllable(Engine &ctx, Entity agent, bool markAsExp if (ctx.get(agent_iface).valids[initSteps] == 0) { return false; } + else if (!ctx.data().params.controlExperts && markAsExpert) { + return false; + } else { return ctx.data().numControlledAgents < ctx.data().params.maxNumControlledAgents; } } // Original logic for other initialization modes + if (ctx.data().params.controlExperts) { + return ctx.data().numControlledAgents < ctx.data().params.maxNumControlledAgents && + ctx.get(agent_iface).valids[0] && + ctx.get(agent) == ResponseType::Dynamic; + } return ctx.data().numControlledAgents < ctx.data().params.maxNumControlledAgents && ctx.get(agent_iface).valids[0] && ctx.get(agent) == ResponseType::Dynamic && From 1e1127f564eb282b74233e75df53f90371bb3c32 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Tue, 29 Apr 2025 16:13:35 -0400 Subject: [PATCH 5/5] Minor fixes --- examples/eval/run_wosac_eval.py | 18 ++++++++++++------ gpudrive/env/base_env.py | 1 + gpudrive/env/env_torch.py | 5 +++-- src/bindings.cpp | 3 ++- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py index 4fabc00dc..6e5178737 100644 --- a/examples/eval/run_wosac_eval.py +++ b/examples/eval/run_wosac_eval.py @@ -69,16 +69,21 @@ def rollout( start_env_rollout = perf_counter() - next_obs = env.reset(control_mask) + _ = env.reset() + # Zero out actions for parked vehicles info = Info.from_tensor( env.sim.info_tensor(), backend=env.backend, device=env.device, ) - # Zero out actions for parked vehicles - zero_action_mask = (info.off_road == 1) | (info.collided_with_vehicle == 1) & (info.type == int(madrona_gpudrive.EntityType.Vehicle)) - control_mask = env.cont_agent_mask & ~zero_action_mask + control_mask_all = env.cont_agent_mask.clone() + zero_action_mask = (info.off_road == 1) | ( + info.collided_with_vehicle == 1 + ) & (info.type == int(madrona_gpudrive.EntityType.Vehicle)) + control_mask = control_mask_all & ~zero_action_mask + + next_obs = env.reset(control_mask) # Get scenario ids scenario_ids_dict = env.get_scenario_ids() @@ -226,7 +231,7 @@ def rollout( NUM_ENVS = 3 DEVICE = "cuda" # where to run the env rollouts NUM_ROLLOUTS_PER_BATCH = 1 - NUM_DATA_BATCHES = 2 + NUM_DATA_BATCHES = 1 INIT_STEPS = 10 DATASET_SIZE = 100 RENDER = True @@ -247,7 +252,8 @@ def rollout( # Load agent agent = load_agent( # path_to_cpt="checkpoints/model_guidance_log_replay__S_1__04_26_09_02_20_677_000833.pt", - path_to_cpt="checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt", + # path_to_cpt="checkpoints/model_guidance_log_replay__S_3__04_27_13_13_33_780_013762.pt", # Trained with collision penalty + path_to_cpt="checkpoints/model_guidance_log_replay__S_3__04_28_15_56_44_152_014083.pt", # Trained without collision penalty ).to(DEVICE) # Override default environment settings to match those the agent was trained with diff --git a/gpudrive/env/base_env.py b/gpudrive/env/base_env.py index 3006c3c0f..cb1258238 100755 --- a/gpudrive/env/base_env.py +++ b/gpudrive/env/base_env.py @@ -126,6 +126,7 @@ def _setup_environment_parameters(self): elif self.config.init_mode == "wosac_train": params.readFromTracksToPredict = True params.isStaticAgentControlled = True + params.controlExperts = False elif self.config.init_mode == "all_objects": params.isStaticAgentControlled = True params.initOnlyValidAgentsAtFirstStep = False diff --git a/gpudrive/env/env_torch.py b/gpudrive/env/env_torch.py index 60cec0fef..cdafc62ca 100755 --- a/gpudrive/env/env_torch.py +++ b/gpudrive/env/env_torch.py @@ -1704,12 +1704,13 @@ def get_scenario_ids(self): add_reference_speed=True, add_reference_heading=True, reward_type="guided_autonomy", + init_mode="wosac_train", ) render_config = RenderConfig() # Create data loader train_loader = SceneDataLoader( - root="data/processed/examples", + root="data/processed/wosac/debug", batch_size=1, dataset_size=1, sample_with_replacement=False, @@ -1753,7 +1754,7 @@ def get_scenario_ids(self): print(f"Highlighted agent: {highlight_agent}") print(f"Position: {agent_positions[-1]}") - for t in range(env.init_steps, env.episode_len): + for t in range(env.init_steps, env.init_steps + 10): print(f"Step: {t+1}") # Step the environment diff --git a/src/bindings.cpp b/src/bindings.cpp index e44ff4619..e3edfe6cd 100755 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -61,7 +61,8 @@ namespace madrona_gpudrive .def_rw("disableClassicalObs", &Parameters::disableClassicalObs) .def_rw("isStaticAgentControlled", &Parameters::isStaticAgentControlled) .def_rw("readFromTracksToPredict", &Parameters::readFromTracksToPredict) - .def_rw("initSteps", &Parameters::initSteps); + .def_rw("initSteps", &Parameters::initSteps) + .def_rw("controlExperts", &Parameters::controlExperts); // Define CollisionBehaviour enum nb::enum_(m, "CollisionBehaviour")