Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/eval/run_wosac_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# WOSAC
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from waymo_open_dataset.protos import sim_agents_submission_pb2
from eval.wosac_eval import WOSACMetrics
# from eval.wosac_eval import WOSACMetrics
from eval.wosac_eval_origin import WOSACMetrics


def get_state(env):
Expand Down
31 changes: 23 additions & 8 deletions examples/eval/wosac_eval_origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import tensorflow as tf
import waymo_open_dataset.wdl_limited.sim_agents_metrics.metrics as wosac_metrics
from waymo_open_dataset.utils.sim_agents import submission_specs

# import eval.wosac_metrics.metrics as wosac_metrics

from google.protobuf import text_format
Expand All @@ -23,12 +25,17 @@ class WOSACMetrics(Metric):
validation metrics based on ground truth trajectory, using waymo_open_dataset api
"""

def __init__(self, prefix: str = "", ego_only: bool = False) -> None:
def __init__(self, challenge_type = None, prefix: str = "", ego_only: bool = False) -> None:
super().__init__()
self.is_mp_init = False
self.prefix = prefix
self.ego_only = ego_only
self.wosac_config = load_metrics_config()
if challenge_type is None:
self.challenge_type = submission_specs.ChallengeType.SIM_AGENTS
else:
self.challenge_type = challenge_type

self.wosac_config = load_metrics_config(self.challenge_type)

self.field_names = [
"metametric",
Expand All @@ -43,6 +50,7 @@ def __init__(self, prefix: str = "", ego_only: bool = False) -> None:
"distance_to_road_edge_likelihood",
"offroad_indication_likelihood",
"min_average_displacement_error",
"traffic_light_violation_likelihood",
"simulated_collision_rate",
"simulated_offroad_rate",
]
Expand Down Expand Up @@ -126,6 +134,9 @@ def update(
self.min_average_displacement_error += (
scenario_metrics.min_average_displacement_error
)
self.traffic_light_violation_likelihood += (
scenario_metrics.traffic_light_violation_likelihood
)
self.simulated_collision_rate += scenario_metrics.simulated_collision_rate
self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate

Expand Down Expand Up @@ -155,14 +166,18 @@ def compute(self) -> Dict[str, Tensor]:
return out_dict


def load_metrics_config() -> sim_agents_metrics_pb2.SimAgentMetricsConfig:
def load_metrics_config(
challenge_type: submission_specs.ChallengeType,
) -> sim_agents_metrics_pb2.SimAgentMetricsConfig:
"""Loads the `SimAgentMetricsConfig` used for the challenge."""
# pylint: disable=line-too-long
# pyformat: disable
import waymo_open_dataset
config_path = '{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2024_config.textproto'.format(
pyglib_resource=waymo_open_dataset.__path__[0]
)
pyglib_resource = waymo_open_dataset.__path__[0]
if challenge_type == submission_specs.ChallengeType.SIM_AGENTS:
config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_sim_agents_config.textproto' # pylint: disable=line-too-long
elif challenge_type == submission_specs.ChallengeType.SCENARIO_GEN:
config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_scenario_gen_config.textproto' # pylint: disable=line-too-long
else:
raise ValueError(f'Unsupported {challenge_type=}')
with open(config_path, 'r') as f:
config = sim_agents_metrics_pb2.SimAgentMetricsConfig()
text_format.Parse(f.read(), config)
Expand Down