Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ examples/experimental/logs/*
# Sbatch scripts
*.sh

# Dataset
wosac/
other/

# Videos
videos/
output_videos_larger_dataset/
Expand Down
14 changes: 7 additions & 7 deletions baselines/ppo/config/ppo_waypoint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ mode: "train"
use_rnn: false
eval_model_path: null
baseline: false
data_dir: data/processed/wosac/validation_json_100
data_dir: data/processed/wosac/validation_json_1
continue_training: false
model_cpt: null

environment: # Overrides default environment configs (see pygpudrive/env/config.py)
name: "gpudrive"
num_worlds: 100 # Number of parallel environments
k_unique_scenes: 100 # Number of unique scenes to sample from
k_unique_scenes: 1 # Number of unique scenes to sample from
max_controlled_agents: 64 # Maximum number of agents controlled by the model. Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp
ego_state: true
road_map_obs: true
Expand Down Expand Up @@ -84,7 +84,7 @@ train:
clip_coef: 0.2
clip_vloss: false
vf_clip_coef: 0.2
ent_coef: 0.001
ent_coef: 0.005
vf_coef: 0.5
max_grad_norm: 0.5
target_kl: null
Expand All @@ -102,14 +102,14 @@ train:
num_parameters: 0 # Total trainable parameters, to be filled at runtime

# # # Checkpointing # # #
checkpoint_interval: 500 # Save policy every k iterations
checkpoint_interval: 250 # Save policy every k iterations
checkpoint_path: "./runs"

# # # Rendering # # #
render: false # Determines whether to render the environment (note: will slow down training)
render: true # Determines whether to render the environment (note: will slow down training)
render_3d: false # Render simulator state in 3d or 2d
render_interval: 150 # Render every k iterations
render_k_scenarios: 2 # Number of scenarios to render
render_interval: 200 # Render every k iterations
render_k_scenarios: 1 # Number of scenarios to render
render_format: "mp4" # Options: gif, mp4
render_fps: 20 # Frames per second
zoom_radius: 100
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion data_utils/process_waymo_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def process_data(args):
)
parser.add_argument(
"--id_as_filename",
default=False,
default=True,
action="store_true",
help="Use the unique scenario id as the filename",
)
Expand Down
28 changes: 28 additions & 0 deletions examples/eval/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
## Waymo Open Sim Agent Challenge (WOSAC) evaluation


## Requirements
Prerequisite
```
pip install --no-deps waymo-open-dataset-tf-2-12-0==1.6.4
pip install --no-deps git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax
```

## Dataset
Extract TF example from raw waymo TF example dataset using
```
python examples/eval/extract_dataset.py --data_dir XXXX --save_dir xxxx --dataset [train/val/val_interactive]
```

e.g.

```
python examples/eval/extract_dataset.py --data_dir data/raw --save_dir data/processed/wosac --dataset all
```


## Evaluation
Run eval with
```
python wosac_eval.py
```
File renamed without changes.
137 changes: 137 additions & 0 deletions examples/eval/extract_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
# Force JAX to use CPU only
os.environ["JAX_PLATFORMS"] = "cpu"

import tensorflow as tf
import glob
import argparse
# import pickle
from tqdm import tqdm
from waymo_open_dataset.protos import scenario_pb2


# from waymax import dataloader
# from waymax.config import DataFormat
import functools

MAX_NUM_OBJECTS = 64
MAX_POLYLINES = 256
MAX_TRAFFIC_LIGHTS = 16
CURRENT_INDEX = 10
NUM_POINTS_POLYLINE = 30

def tf_preprocess(serialized: bytes) -> dict[str, tf.Tensor]:
"""
Preprocesses the serialized data.

Args:
serialized (bytes): The serialized data.

Returns:
dict[str, tf.Tensor]: The preprocessed data.
"""
womd_features = dataloader.womd_utils.get_features_description(
include_sdc_paths=False,
max_num_rg_points=30000,
num_paths=None,
num_points_per_path=None,
)
womd_features['scenario/id'] = tf.io.FixedLenFeature([1], tf.string)

deserialized = tf.io.parse_example(serialized, womd_features)
parsed_id = deserialized.pop('scenario/id')
deserialized['scenario/id'] = tf.io.decode_raw(parsed_id, tf.uint8)
return dataloader.preprocess_womd_example(
deserialized,
aggregate_timesteps=True,
max_num_objects=None,
)

def tf_postprocess(example: dict[str, tf.Tensor]):
"""
Postprocesses the example.

Args:
example (dict[str, tf.Tensor]): The example to be postprocessed.

Returns:
tuple: A tuple containing the scenario ID and the postprocessed scenario.
"""
scenario = dataloader.simulator_state_from_womd_dict(example)
scenario_id = example['scenario/id']
return scenario_id, scenario

def data_process(
data_dir: str,
save_dir: str,
):
"""
Process the Waymax dataset and save the processed data.

Args:
data_dir (str): Directory path of the Waymax dataset.
save_dir (str): Directory path to save the processed data.
"""
dataset = tf.data.TFRecordDataset(
data_dir, compression_type=""
)

os.makedirs(save_dir, exist_ok=True)

for tf_data in dataset:
tf_data = tf_data.numpy()
scenario = scenario_pb2.Scenario()
scenario.ParseFromString(bytes(tf_data))
scenario_id = scenario.scenario_id

scenario_filename = os.path.join(save_dir, scenario_id+'.tfrecords')

# check if file exists
if os.path.exists(scenario_filename):
continue

# Remove the .as_posix() method call since scenario_filename is a string
with tf.io.TFRecordWriter(scenario_filename) as file_writer:
file_writer.write(tf_data)


if __name__ == '__main__':
# add arguments
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='/data/Dataset/Waymo/V1_2_tf')
parser.add_argument('--save_dir', type=str, default='/data/Dataset/Waymo/VBD')
parser.add_argument('--dataset', type=str, default='all')
parser.add_argument('--num_workers', type=int, default=1) # Change default to 1
args = parser.parse_args()

os.makedirs(args.save_dir, exist_ok=True)

print(f'Processing data from {args.data_dir} and Saving to {args.save_dir}')

def process(dataset):
"""
Process a specific dataset and save the processed data.

Args:
dataset (str): Name of the dataset to process.
"""
data_path = os.path.join(args.data_dir, dataset+'/*')
data_files = glob.glob(data_path)
save_dir = os.path.join(args.save_dir, dataset+'_extracted')
n_files = len(data_files)
print(f'Processing {n_files} files in {data_path} dataset')
os.makedirs(save_dir, exist_ok=True)
print(f'Saving to {save_dir}')

# Process files one by one instead of using multiprocessing
for data_file in tqdm(data_files):
data_process(data_file, save_dir)

if args.dataset == 'all' or args.dataset == 'train':
process('training')

if args.dataset == 'all' or args.dataset == 'val':
process('validation')

if args.dataset == 'all' or args.dataset == 'val_interactive':
process('validation_interactive')
Loading