diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py index 0a7e4f654..c08fa9e36 100644 --- a/examples/eval/run_wosac_eval.py +++ b/examples/eval/run_wosac_eval.py @@ -14,7 +14,8 @@ # WOSAC sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from waymo_open_dataset.protos import sim_agents_submission_pb2 -from eval.wosac_eval import WOSACMetrics +# from eval.wosac_eval import WOSACMetrics +from eval.wosac_eval_origin import WOSACMetrics def get_state(env): diff --git a/examples/eval/wosac_eval_origin.py b/examples/eval/wosac_eval_origin.py index 55a581816..943b4e8da 100644 --- a/examples/eval/wosac_eval_origin.py +++ b/examples/eval/wosac_eval_origin.py @@ -6,6 +6,8 @@ import tensorflow as tf import waymo_open_dataset.wdl_limited.sim_agents_metrics.metrics as wosac_metrics +from waymo_open_dataset.utils.sim_agents import submission_specs + # import eval.wosac_metrics.metrics as wosac_metrics from google.protobuf import text_format @@ -23,12 +25,17 @@ class WOSACMetrics(Metric): validation metrics based on ground truth trajectory, using waymo_open_dataset api """ - def __init__(self, prefix: str = "", ego_only: bool = False) -> None: + def __init__(self, challenge_type = None, prefix: str = "", ego_only: bool = False) -> None: super().__init__() self.is_mp_init = False self.prefix = prefix self.ego_only = ego_only - self.wosac_config = load_metrics_config() + if challenge_type is None: + self.challenge_type = submission_specs.ChallengeType.SIM_AGENTS + else: + self.challenge_type = challenge_type + + self.wosac_config = load_metrics_config(self.challenge_type) self.field_names = [ "metametric", @@ -43,6 +50,7 @@ def __init__(self, prefix: str = "", ego_only: bool = False) -> None: "distance_to_road_edge_likelihood", "offroad_indication_likelihood", "min_average_displacement_error", + "traffic_light_violation_likelihood", "simulated_collision_rate", "simulated_offroad_rate", ] @@ -126,6 +134,9 @@ def update( self.min_average_displacement_error += ( scenario_metrics.min_average_displacement_error ) + self.traffic_light_violation_likelihood += ( + scenario_metrics.traffic_light_violation_likelihood + ) self.simulated_collision_rate += scenario_metrics.simulated_collision_rate self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate @@ -155,14 +166,18 @@ def compute(self) -> Dict[str, Tensor]: return out_dict -def load_metrics_config() -> sim_agents_metrics_pb2.SimAgentMetricsConfig: +def load_metrics_config( + challenge_type: submission_specs.ChallengeType, +) -> sim_agents_metrics_pb2.SimAgentMetricsConfig: """Loads the `SimAgentMetricsConfig` used for the challenge.""" - # pylint: disable=line-too-long - # pyformat: disable import waymo_open_dataset - config_path = '{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2024_config.textproto'.format( - pyglib_resource=waymo_open_dataset.__path__[0] - ) + pyglib_resource = waymo_open_dataset.__path__[0] + if challenge_type == submission_specs.ChallengeType.SIM_AGENTS: + config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_sim_agents_config.textproto' # pylint: disable=line-too-long + elif challenge_type == submission_specs.ChallengeType.SCENARIO_GEN: + config_path = f'{pyglib_resource}/wdl_limited/sim_agents_metrics/challenge_2025_scenario_gen_config.textproto' # pylint: disable=line-too-long + else: + raise ValueError(f'Unsupported {challenge_type=}') with open(config_path, 'r') as f: config = sim_agents_metrics_pb2.SimAgentMetricsConfig() text_format.Parse(f.read(), config)