Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6c38340
Add tl_states to extraction script.
daphne-cornelisse Apr 28, 2025
2eb2d18
WIP
daphne-cornelisse Apr 28, 2025
c9fe81d
Add back metadata
daphne-cornelisse May 9, 2025
0e990fa
Merge dev into branch
daphne-cornelisse May 9, 2025
c972d47
Add tl_states struct and access through simulator.
daphne-cornelisse May 9, 2025
cf644ee
Add time index
daphne-cornelisse May 9, 2025
b18e9e6
Cleanup
daphne-cornelisse May 9, 2025
ce0a9f0
Cleanup
daphne-cornelisse May 9, 2025
672dd15
SMall fix
daphne-cornelisse May 9, 2025
8629618
Add tl state data struct
daphne-cornelisse May 9, 2025
3a668df
Add minimum test script
daphne-cornelisse May 9, 2025
667508e
Add minimum test script
daphne-cornelisse May 9, 2025
e8927fd
[Work In progress] Fixed initialization and made TL a singleton tensor
aaravpandya May 25, 2025
3252482
Fix json init and export
aaravpandya May 26, 2025
0cfb694
mean centering
aaravpandya May 26, 2025
529939e
Merge branch 'dev' into feat/add_tl_states
eugenevinitsky May 26, 2025
20a33ca
Test and omit unnesecessary code in tl obs
daphne-cornelisse Jun 10, 2025
1978703
Changed data access format in python code
Jun 13, 2025
8c5dfb3
Added a more interpretable positional element arrays
Jun 13, 2025
6efa90f
Unpack tl_states correctly
daphne-cornelisse Jun 15, 2025
f1e86c1
Add tl state plotting function in visualizer
daphne-cornelisse Jun 15, 2025
1235797
mini bug
daphne-cornelisse Jun 15, 2025
5a45e2a
Fix traffic lights by exporting everything as float32
daphne-cornelisse Jun 16, 2025
a0de436
Improve colors
daphne-cornelisse Jun 16, 2025
52d7379
Remove test file
daphne-cornelisse Jun 16, 2025
16413cb
Merge branch 'dev' into feat/add_tl_states
daphne-cornelisse Jun 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

17 changes: 7 additions & 10 deletions data_utils/process_waymo_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _init_tl_object(mapstate: scenario_pb2.DynamicMapState) -> Dict[int, Any]:
"x": lane_state.stop_point.x,
"y": lane_state.stop_point.y,
"z": lane_state.stop_point.z,
"lane_id": lane_state.lane,
}
return returned_dict

Expand Down Expand Up @@ -327,20 +328,16 @@ def waymo_to_scenario(

# Construct the traffic light states
tl_dict = defaultdict(
lambda: {"state": [], "x": [], "y": [], "z": [], "time_index": []}
lambda: {"state": [], "x": [], "y": [], "z": [], "time_index": [], "lane_id": []}
)
all_keys = ["state", "x", "y", "z"]
i = 0
for dynamic_map_state in protobuf.dynamic_map_states:
for i, dynamic_map_state in enumerate(protobuf.dynamic_map_states):
traffic_light_dict = _init_tl_object(dynamic_map_state)
# there is a traffic light but we don't want traffic light scenes so just return
if len(traffic_light_dict) > 0:
return
for id, value in traffic_light_dict.items():
for lane_id, value in traffic_light_dict.items():
for key in all_keys:
tl_dict[id][key].append(value[key])
tl_dict[id]["time_index"].append(i)
i += 1
tl_dict[lane_id][key].append(value[key])
tl_dict[lane_id]["time_index"].append(i)
tl_dict[lane_id]["lane_id"].append(lane_id)

# Construct the map states
roads = []
Expand Down
141 changes: 141 additions & 0 deletions gpudrive/datatypes/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,144 @@ def one_hot_encode_bev_map(self):
self.bev_segmentation_map.long(),
num_classes=constants.NUM_MADRONA_ENTITY_TYPES, # From size of Madrona EntityType
)


@dataclass
class TrafficLightObs:
"""
A dataclass that represents traffic light information in the scenario.

This data struct contains the time series of traffic light information.
It contains the state (unknown: 0, stop: 1, caution: 2, go: 3), position, and lane_ids.

Initialized from tl_states_tensor (from Manager.trafficLightTensor()).
For details, see `TrafficLightState` in src/types.hpp.
Shape: (num_worlds, max_traffic_lights, num_timesteps * features).

Attributes:
state: The state of each traffic light (0=unknown, 1=stop, 2=caution, 3=go)
pos_x: X-coordinate of the traffic light
pos_y: Y-coordinate of the traffic light
pos_z: Z-coordinate of the traffic light
time_index: Time index of the traffic light state
lane_id: Lane ID associated with the traffic light
valid_mask: Boolean mask indicating valid traffic lights
"""

state: torch.Tensor
pos_x: torch.Tensor
pos_y: torch.Tensor
pos_z: torch.Tensor
time_index: torch.Tensor
lane_id: torch.Tensor
valid_mask: torch.Tensor

def __init__(
self,
tl_states_tensor: torch.Tensor,

):
"""Initializes the traffic light observation from a tensor."""
traj_length = constants.LOG_TRAJECTORY_LENGTH

# Calculate indices based on C++ struct layout:
# laneId (1) + state[traj_length] + x[traj_length] + y[traj_length] + z[traj_length] + timeIndex[traj_length] + numStates (1)

lane_id_end_idx = 1
state_end_idx = lane_id_end_idx + traj_length
pos_x_end_idx = state_end_idx + traj_length
pos_y_end_idx = pos_x_end_idx + traj_length
pos_z_end_idx = pos_y_end_idx + traj_length
time_index_end_idx = pos_z_end_idx + traj_length

# Extract fields according to C++ struct layout
# See `TrafficLightState` in src/types.hpp for details
self.lane_id = tl_states_tensor[:, :, 0] # Single lane ID value
self.state = tl_states_tensor[
:, :, lane_id_end_idx:state_end_idx
] # state[traj_length]
self.pos_x = tl_states_tensor[
:, :, state_end_idx:pos_x_end_idx
].float() # x[traj_length]
self.pos_y = tl_states_tensor[
:, :, pos_x_end_idx:pos_y_end_idx
].float() # y[traj_length]
self.pos_z = tl_states_tensor[
:, :, pos_y_end_idx:pos_z_end_idx
].float() # z[traj_length]
self.time_index = tl_states_tensor[
:, :, pos_z_end_idx:time_index_end_idx
] # timeIndex[traj_length]
self.num_states = tl_states_tensor[
:, :, time_index_end_idx
] # Single numStates value

# Create a valid mask based on numStates
# Traffic lights are valid if they have numStates > 0
self.valid_mask = self.num_states > 0

@classmethod
def from_tensor(
cls,
tl_states_tensor: madrona_gpudrive.madrona.Tensor,
backend="torch",
device="cuda",
):
"""Creates a TrafficLightObs from a tensor.

Args:
tl_states_tensor: The traffic light state tensor from the simulation
backend: Which backend to use ("torch" or "jax")
device: The device to place tensors on

Returns:
A TrafficLightObs instance
"""
if backend == "torch":
tensor = tl_states_tensor.to_torch().clone().to(device)
obj = cls(tensor)
return obj
elif backend == "jax":
raise NotImplementedError("JAX backend not implemented yet.")

def normalize(self):
"""Normalizes the traffic light observation coordinates."""

# Normalize position coordinates
self.pos_x = normalize_min_max(
tensor=self.pos_x,
min_val=constants.MIN_REL_COORD,
max_val=constants.MAX_REL_COORD,
)
self.pos_y = normalize_min_max(
tensor=self.pos_y,
min_val=constants.MIN_REL_COORD,
max_val=constants.MAX_REL_COORD,
)
self.pos_z = normalize_min_max(
tensor=self.pos_z,
min_val=constants.MIN_Z_COORD,
max_val=constants.MAX_Z_COORD,
)
def one_hot_encode_states(self):
"""One-hot encodes the traffic light states.

Converts the state values to one-hot encoded vectors with 4 classes:
0: Unknown
1: Stop
2: Caution
3: Go
"""
# Make sure values are in range 0-3
state_clamped = torch.clamp(self.state, 0, 3)
# One-hot encode
self.state_onehot = torch.nn.functional.one_hot(
state_clamped, num_classes=4
) * self.valid_mask.unsqueeze(-1)

return self.state_onehot

@property
def shape(self) -> tuple[int, ...]:
"""Shape: (num_worlds, max_traffic_lights, num_timesteps)."""
return self.state.shape
8 changes: 4 additions & 4 deletions gpudrive/env/env_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,7 @@ def render(self, focus_env_idx=0, focus_agent_idx=[0, 1]):

if __name__ == "__main__":

FOCUS_AGENTS = [0, 1, 2, 3, 4]
FOCUS_AGENTS = [0] #[0, 1, 2, 3, 4]

env_config = EnvConfig(
guidance=True,
Expand All @@ -2079,8 +2079,8 @@ def render(self, focus_env_idx=0, focus_agent_idx=[0, 1]):

# Create data loader
train_loader = SceneDataLoader(
root="data/processed/wosac/validation_interactive/json",
batch_size=10,
root="data/processed/tl",
batch_size=1,
dataset_size=100,
sample_with_replacement=False,
shuffle=False,
Expand Down Expand Up @@ -2116,7 +2116,7 @@ def render(self, focus_env_idx=0, focus_agent_idx=[0, 1]):

obs = env.get_obs(control_mask)
reward = env.get_rewards()
if time_step % 10 == 0 or time_step > env.episode_len - 3:
if time_step % 20 == 0 or time_step > env.episode_len - 3:
sim_states, agent_obs = env.render(focus_agent_idx=FOCUS_AGENTS)
sim_frames.append(img_from_fig(sim_states[0]))
for i in FOCUS_AGENTS:
Expand Down
100 changes: 100 additions & 0 deletions gpudrive/visualize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GlobalEgoState,
PartnerObs,
LidarObs,
TrafficLightObs,
)
from gpudrive.datatypes.trajectory import LogTrajectory, VBDTrajectory
from gpudrive.datatypes.control import ResponseType
Expand Down Expand Up @@ -94,6 +95,11 @@ def initialize_static_scenario_data(
)

self.trajectory = reference_trajectory
self.tl_obs = TrafficLightObs.from_tensor(
tl_states_tensor=self.sim_object.tl_state_tensor(),
backend=self.backend,
device=self.device,
)

def plot_simulator_state(
self,
Expand Down Expand Up @@ -320,6 +326,14 @@ def plot_simulator_state(
plot_guidance_up_to_time=plot_guidance_up_to_time,
)

self._plot_traffic_lights(
ax=ax,
env_idx=env_idx,
tl_obs=self.tl_obs,
time_step=time_step if time_step is not None else 0,
marker_size_scale=marker_scale,
)

# Draw the agents
self._plot_filtered_agent_bounding_boxes(
ax=ax,
Expand Down Expand Up @@ -946,6 +960,92 @@ def _plot_vbd_trajectory(
zorder=0,
)

def _plot_traffic_lights(
self,
ax: matplotlib.axes.Axes,
env_idx: int,
tl_obs: "TrafficLightObs",
time_step: int = 0,
marker_size_scale: float = 1.0,
):
"""Plot traffic light states as colored dots.

Args:
ax: Matplotlib axis to plot on
env_idx: Environment index
tl_obs: Traffic light observation object
time_step: Current time step
marker_size_scale: Scale factor for marker size
"""

# Traffic light state colors
TL_STATE_COLORS = {
0: "#C5C5C5", # Unknown - gray
1: "r", # Stop - red
2: "tab:orange", # Caution - orange
3: "g", # Go - green
}

# Get valid traffic lights for this environment
valid_mask = tl_obs.valid_mask[env_idx, :]
if not valid_mask.any():
return

# Clamp time_step to available data
max_time_idx = tl_obs.state.shape[2] - 1
time_step = min(time_step, max_time_idx)

# Get traffic light data for valid lights at current time step
valid_indices = torch.where(valid_mask)[0]

for tl_idx in valid_indices:
# Get position (use first valid position if time series)
pos_x = tl_obs.pos_x[env_idx, tl_idx, time_step].item()
pos_y = tl_obs.pos_y[env_idx, tl_idx, time_step].item()

# Skip if position is invalid (0,0 or out of bounds)
if (
(pos_x == 0 and pos_y == 0)
or abs(pos_x) > 1000
or abs(pos_y) > 1000
):
continue

# Get current state
state = int(tl_obs.state[env_idx, tl_idx, time_step].item())
state = max(0, min(3, state)) # Clamp to valid range

color = TL_STATE_COLORS[state]

if self.render_3d:
# Plot as elevated marker in 3D
height = 0.2 # Height above ground for visibility
ax.scatter3D(
[pos_x],
[pos_y],
[height],
color=color,
s=60 * marker_size_scale,
marker="o",
edgecolors="black",
linewidth=0.5,
alpha=0.8,
zorder=10,
)
else:
# Plot as 2D marker
ax.scatter(
pos_x,
pos_y,
color=color,
s=30 * marker_size_scale,
marker="o",
edgecolors="black",
linewidth=0.5,
alpha=0.9,
zorder=10,
)

def _get_endpoints(self, x, y, length, yaw):
"""Compute the start and end points of a road segment."""
center = np.array([x, y])
Expand Down
1 change: 1 addition & 0 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ namespace madrona_gpudrive
.def("set_maps", &Manager::setMaps)
.def("world_means_tensor", &Manager::worldMeansTensor)
.def("metadata_tensor", &Manager::metadataTensor)
.def("tl_state_tensor", &Manager::trafficLightTensor)
.def("vbd_trajectory_tensor", &Manager::vbdTrajectoryTensor)
.def("map_name_tensor", &Manager::mapNameTensor)
.def("deleteAgents", [](Manager &self, nb::dict py_agents_to_delete) {
Expand Down
1 change: 1 addition & 0 deletions src/consts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace consts {
inline constexpr madrona::CountT kMaxAgentCount = 32;
inline constexpr madrona::CountT kMaxRoadEntityCount = 10000;
inline constexpr madrona::CountT kMaxAgentMapObservationsCount = 186;
inline constexpr madrona::CountT kMaxTrafficLightCount = 16;

inline constexpr bool useEstimatedYaw = true;

Expand Down
3 changes: 3 additions & 0 deletions src/init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@ namespace madrona_gpudrive
{
MapObject objects[MAX_OBJECTS];
MapRoad roads[MAX_ROADS];
TrafficLightState trafficLightStates[consts::kMaxTrafficLightCount];

uint32_t numObjects;
uint32_t numRoads;
uint32_t numRoadSegments;
uint32_t numTrafficLights;
bool hasTrafficLights;
MapVector2 mean;

char mapName[32];
Expand Down
Loading
Loading