diff --git a/baselines/ppo/config/ppo_guided_autonomy.yaml b/baselines/ppo/config/ppo_guided_autonomy.yaml
index 985302f7c..a9302f829 100644
--- a/baselines/ppo/config/ppo_guided_autonomy.yaml
+++ b/baselines/ppo/config/ppo_guided_autonomy.yaml
@@ -28,8 +28,8 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.
# Reward function
reward_type: "guided_autonomy"
- collision_weight: -0.01
- off_road_weight: -0.01
+ collision_weight: -0.1
+ off_road_weight: -0.1
guidance_speed_weight: 0.005
guidance_heading_weight: 0.005
smoothness_weight: 0.0
@@ -52,7 +52,7 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.
wandb:
entity: ""
project: "humanlike"
- group: "debug_mini"
+ group: "debug_mini"
mode: "online" # Options: online, offline, disabled
tags: ["ppo", "ff"]
@@ -69,7 +69,7 @@ train:
resample_scenes: false
resample_dataset_size: 10_000 # Number of unique scenes to sample from
resample_interval: 5_000_000
- sample_with_replacement: false
+ sample_with_replacement: true
shuffle_dataset: true
file_prefix: ""
diff --git a/examples/eval/notebooks/02_guidance_data_analyis.ipynb b/examples/eval/notebooks/02_guidance_data_analyis.ipynb
new file mode 100644
index 000000000..6018db22e
--- /dev/null
+++ b/examples/eval/notebooks/02_guidance_data_analyis.ipynb
@@ -0,0 +1,14960 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "ae958675",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib\n",
+ "import gpudrive\n",
+ "importlib.reload(gpudrive)\n",
+ "\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import torch\n",
+ "from pathlib import Path\n",
+ "\n",
+ "# Set working directory to the repository base directory 'gpudrive'\n",
+ "working_dir = Path.cwd()\n",
+ "while working_dir.name != 'gpudrive':\n",
+ " working_dir = working_dir.parent\n",
+ " if working_dir == Path.home():\n",
+ " raise FileNotFoundError(\"Base directory 'gpudrive' not found\")\n",
+ "os.chdir(working_dir)\n",
+ "import torch\n",
+ "from PIL import Image\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import matplotlib as mpl\n",
+ "import matplotlib.pyplot as plt\n",
+ "import datetime\n",
+ "\n",
+ "from gpudrive.env.env_torch import GPUDriveTorchEnv\n",
+ "from gpudrive.env.config import EnvConfig, RenderConfig\n",
+ "from gpudrive.env.dataset import SceneDataLoader\n",
+ "from gpudrive.visualize.utils import img_from_fig\n",
+ "\n",
+ "sns.set(\"notebook\", font_scale=1.05, rc={\"figure.figsize\": (10, 5)})\n",
+ "sns.set_style(\"ticks\", rc={\"figure.facecolor\": \"none\", \"axes.facecolor\": \"none\"})\n",
+ "%config InlineBackend.figure_format = 'svg'\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "plt.set_loglevel(\"WARNING\")\n",
+ "mpl.rcParams[\"lines.markersize\"] = 8\n",
+ "\n",
+ "plt.set_loglevel(\"WARNING\")\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "937797aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MS = 9\n",
+ " \n",
+ "def plot_agent_guidance_info(guidance_obs, center_agent_idx):\n",
+ " \n",
+ " agent_guidance = guidance_obs[center_agent_idx, :, :]\n",
+ " mask = agent_guidance[:, 0] != -1.0\n",
+ " \n",
+ " time_steps = np.arange(agent_guidance.shape[0])\n",
+ " \n",
+ " fig, axes = plt.subplots(1, 3, figsize=(9.5, 3))\n",
+ " \n",
+ " axes[0].scatter(agent_guidance[:, 0][mask], agent_guidance[:, 1][mask], marker='o', s=MS, color='g') \n",
+ " axes[0].set_xlabel(r'Suggested $x$')\n",
+ " axes[0].set_ylabel(r'Suggested $y$')\n",
+ " #axes[0].legend()\n",
+ " axes[0].grid(True, alpha=0.3)\n",
+ " \n",
+ " axes[1].scatter(time_steps[mask], agent_guidance[:, 2][mask], marker='o', color='b', s=MS)\n",
+ " axes[1].set_xlabel('Time')\n",
+ " axes[1].set_ylabel('Suggested speed')\n",
+ " axes[1].grid(True, alpha=0.3)\n",
+ " \n",
+ " axes[2].scatter(time_steps[mask], agent_guidance[:, 3][mask], marker='o', color='#bc6c25', s=MS)\n",
+ " axes[2].set_xlabel('Time')\n",
+ " axes[2].set_ylabel('Suggested heading (rad)')\n",
+ " axes[2].grid(True, alpha=0.3)\n",
+ " \n",
+ " #fig.suptitle(f\"Agent's {center_agent_idx} normalized guidance information\")\n",
+ " \n",
+ " plt.tight_layout()\n",
+ " plt.subplots_adjust(top=0.85, bottom=0.15)\n",
+ " \n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "db3838a5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATASET = \"data/processed/wosac/validation_json_100\"\n",
+ "FIGURES_DIR = \"examples/eval/figures\"\n",
+ "DATA_FOLDER = \"examples/eval/figures_data/\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2aaa50e1",
+ "metadata": {},
+ "source": [
+ "### Checking guidance data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "dce673ef",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load npy array from the file\n",
+ "FOLDER = \"examples/eval/figures_data/guidance/\"\n",
+ "\n",
+ "ref_log_replay = np.load(f\"{FOLDER}reference_log_replay.npy\")\n",
+ "ref_vbd_amortized = np.load(f\"{FOLDER}reference_vbd_amortized.npy\")\n",
+ "ref_vbd_online = np.load(f\"{FOLDER}reference_vbd_online.npy\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "c216b641",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((204, 91, 6), (204, 91, 6), (210, 91, 6))"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The time axis (dim 1) of all three arrays should have length 91\n",
+ "\n",
+ "# Controlled agent elements for 10 scenarios\n",
+ "ref_log_replay.shape, ref_vbd_amortized.shape, ref_vbd_online.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "a706660b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 3, figsize=(13, 3),)# sharex=True, sharey=True)\n",
+ "\n",
+ "axs[0].hist(ref_log_replay[:, :, :2].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
+ "axs[1].hist(ref_vbd_amortized[:, :, :2].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
+ "axs[2].hist(ref_vbd_online[:, :, :2].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
+ "\n",
+ "axs[0].set_xlabel(\"Demeaned position in global coord. frame\")\n",
+ "axs[0].set_ylabel(\"Count\")\n",
+ "\n",
+ "fig.legend()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "a08ab3a2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 3, figsize=(13, 3) )# sharex=True, sharey=True)\n",
+ "\n",
+ "axs[0].hist(ref_log_replay[:, :, 2:4].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
+ "axs[1].hist(ref_vbd_amortized[:, :, 2:4].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
+ "axs[2].hist(ref_vbd_online[:, :, 2:4].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
+ "\n",
+ "axs[0].set_xlabel(\"Demeaned velocity [m/s]\")\n",
+ "axs[0].set_ylabel(\"Count\")\n",
+ "\n",
+ "fig.legend()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "920285ab",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 3, figsize=(13, 3),)# sharex=True, sharey=True)\n",
+ "\n",
+ "axs[0].hist(ref_log_replay[:, :, 4].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
+ "axs[1].hist(ref_vbd_amortized[:, :, 4].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
+ "axs[2].hist(ref_vbd_online[:, :, 4].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
+ "\n",
+ "axs[0].set_xlabel(\"Headings in global coord. frame [rad]\")\n",
+ "axs[0].set_ylabel(\"Count\")\n",
+ "\n",
+ "# add legend\n",
+ "fig.legend()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "47af0e1f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "for agent_idx in range(3):\n",
+ " \n",
+ " valid = ref_log_replay[agent_idx, :, -1].astype( bool)\n",
+ "\n",
+ " fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n",
+ " axs[0].set_title(f\"Global position (x, y); agent {agent_idx}\")\n",
+ " axs[0].plot(ref_log_replay[agent_idx, :, 0][valid], ref_log_replay[agent_idx, :, 1][valid], color='b')\n",
+ " #axs[0].plot(ref_vbd_online[agent_idx, :, 0], ref_vbd_online[agent_idx, :, 1], color='g', alpha=0.5)\n",
+ " axs[0].plot(ref_vbd_amortized[agent_idx, :, 0], ref_vbd_amortized[agent_idx, :, 1], color='r')\n",
+ " axs[0].set_xlabel(\"x\")\n",
+ " axs[0].set_ylabel(\"y\")\n",
+ "\n",
+ " axs[1].set_title(f\"Global velocity (x, y); agent {agent_idx}\")\n",
+ " axs[1].plot(ref_log_replay[agent_idx, :, 2][valid], ref_log_replay[agent_idx, :, 3][valid], color='b')\n",
+ " axs[1].plot(ref_vbd_online[agent_idx, :, 2], ref_vbd_online[agent_idx, :, 3], color='g')\n",
+ " axs[1].plot(ref_vbd_amortized[agent_idx, :, 2], ref_vbd_amortized[agent_idx, :, 3], color='r')\n",
+ " axs[1].set_xlabel(r\"$v_x$\")\n",
+ " axs[1].set_ylabel(r\"$v_y$\")\n",
+ "\n",
+ " axs[2].set_title(f\"Global heading; agent {agent_idx}\")\n",
+ " axs[2].plot(ref_log_replay[agent_idx, :, 4][valid], color='b', label=\"log_replay\")\n",
+ " axs[2].plot(ref_vbd_online[agent_idx, :, 4], color='g', label=\"vbd_online\")\n",
+ " axs[2].plot(ref_vbd_amortized[agent_idx, :, 4], color='r', label=\"vbd_amortized\")\n",
+ " axs[2].set_ylabel(r\"$\\theta$\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " fig.legend(loc=\"center right\", facecolor=\"white\", bbox_to_anchor=(1.12, 0.5))\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "abac4ce3",
+ "metadata": {},
+ "source": [
+ "### Dataset stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "59043fd6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "valid_guidance_points = ref_log_replay[:, :, 5].sum(axis=-1)\n",
+ "valid_guidance_points.shape\n",
+ "\n",
+ "# Some trajectories, while valid, belong to parked cars.\n",
+ "# Count, per agent, the valid steps whose x-velocity (channel 2) is non-zero.\n",
+ "valid_and_non_zero = ((ref_log_replay[:, :, 2] != 0) & (ref_log_replay[:, :, 5] == 1)).sum(axis=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "4708aa7f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 2, figsize=(9, 3.5), sharey=True)\n",
+ "\n",
+ "fig.suptitle(f\"Distribution of guidance points per episode; A = {ref_log_replay.shape[0]} (50 scenarios)\")\n",
+ "\n",
+ "sns.histplot(valid_guidance_points, bins=10, stat='percent', ax=axs[0])\n",
+ "axs[0].grid(True, alpha=0.2)\n",
+ "axs[0].set_xlabel(\"Guidance points per episode.\")\n",
+ "mean_value = valid_guidance_points.mean()\n",
+ "axs[0].axvline(x=mean_value, color='purple', linestyle='--')\n",
+ "axs[0].text(mean_value-3, axs[0].get_ylim()[1]*0.95, r'$\\mu$', color='purple', \n",
+ " fontsize=14, ha='center', va='center')\n",
+ "\n",
+ "sns.histplot(valid_and_non_zero, bins=10, stat='percent', ax=axs[1])\n",
+ "axs[1].grid(True, alpha=0.2)\n",
+ "axs[1].set_xlabel(r\"$\\bf{Nonzero}$ guidance points per episode.\")\n",
+ "mean_value_nonzero = valid_and_non_zero.mean()\n",
+ "axs[1].axvline(x=mean_value_nonzero, color='purple', linestyle='--')\n",
+ "axs[1].text(mean_value_nonzero-3, axs[1].get_ylim()[1]*0.95, r'$\\mu$', color='purple', fontsize=14, ha='center', va='center')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "sns.despine()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/guidance_points_distribution.pdf\", dpi=300, bbox_inches=\"tight\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2cd9d59d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "09f01b81",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c0ea8a40",
+ "metadata": {},
+ "source": [
+ "### Make guidance figs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ee3b198a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# GUIDANCE_MODE = \"log_replay\"\n",
+ "\n",
+ "# env_config = EnvConfig(\n",
+ "# reward_type=\"guided_autonomy\",\n",
+ "# guidance=True,\n",
+ "# guidance_mode=GUIDANCE_MODE,\n",
+ "# add_reference_heading=True,\n",
+ "# add_reference_speed=True,\n",
+ "# add_reference_pos_xy=True,\n",
+ "# init_mode=\"wosac_train\",\n",
+ "# smoothen_trajectory=False,\n",
+ "# dynamics_model=\"delta_local\",\n",
+ "# guidance_dropout_prob=0.7,\n",
+ "# init_steps=0,\n",
+ "# )\n",
+ "# render_config = RenderConfig()\n",
+ "\n",
+ "# train_loader = SceneDataLoader(\n",
+ "# root=DATASET,\n",
+ "# batch_size=1,\n",
+ "# dataset_size=100,\n",
+ "# sample_with_replacement=False,\n",
+ "# shuffle=False,\n",
+ "# file_prefix=\"\",\n",
+ " \n",
+ "# )\n",
+ "\n",
+ "# env = GPUDriveTorchEnv(\n",
+ "# config=env_config,\n",
+ "# data_loader=train_loader,\n",
+ "# max_cont_agents=32, \n",
+ "# device=\"cpu\",\n",
+ "# )\n",
+ "\n",
+ "# print(env.data_batch)\n",
+ "\n",
+ "# obs = env.reset(env.cont_agent_mask)\n",
+ "# expert_actions, _, _, _ = env.get_expert_actions()\n",
+ "\n",
+ "# # for time_step in range(env.init_steps, 20):\n",
+ "# # print(f\"Step: {env.step_in_world[0, 0, 0].item()}\")\n",
+ "\n",
+ "# # # Step the environment\n",
+ "# # expert_actions, _, _, _ = env.get_expert_actions()\n",
+ "# # env.step_dynamics(expert_actions[:, :, time_step, :])\n",
+ "\n",
+ "# # obs = env.get_obs(env.cont_agent_mask)\n",
+ "# # rew = env.get_rewards()\n",
+ " \n",
+ "\n",
+ "# # # Save for analysis\n",
+ "# # reference_traj = torch.cat([\n",
+ "# # env.reference_trajectory.pos_xy,\n",
+ "# # env.reference_trajectory.vel_xy,\n",
+ "# # env.reference_trajectory.yaw,\n",
+ "# # env.reference_trajectory.valids,\n",
+ "# # ], dim=-1)\n",
+ "# # reference_traj_np = reference_traj[env.cont_agent_mask].numpy()\n",
+ "\n",
+ "# # np.save(f\"reference_{GUIDANCE_MODE}.npy\", reference_traj_np)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "43a62f4b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "CENTER_AGENT_IDX = 0\n",
+ "\n",
+ "fig = env.vis.plot_agent_observation(\n",
+ " env_idx=0,\n",
+ " agent_idx=CENTER_AGENT_IDX,\n",
+ " figsize=(10, 10),\n",
+ " trajectory=env.reference_path[CENTER_AGENT_IDX, :, :],\n",
+ " # step_reward=env.guidance_reward[\n",
+ " # 0, 1\n",
+ " # ].item(),\n",
+ " # route_progress=env.route_progress[1],\n",
+ ")\n",
+ "\n",
+ "plt.savefig(\n",
+ " os.path.join(FIGURES_DIR, f\"agent_observation_{CENTER_AGENT_IDX}_{GUIDANCE_MODE}.pdf\"),\n",
+ " dpi=300,\n",
+ " bbox_inches=\"tight\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4020148",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_agent_guidance_info(env.guidance_obs, center_agent_idx=1);\n",
+ "\n",
+ "plt.savefig(\n",
+ " os.path.join(FIGURES_DIR, f\"agent_guidance_info_{CENTER_AGENT_IDX}_{GUIDANCE_MODE}.pdf\"),\n",
+ " dpi=300,\n",
+ " bbox_inches=\"tight\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82ac8a30",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "gpudrive",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/eval/notebooks/02_guidance_data_analysis.ipynb b/examples/eval/notebooks/02_guidance_data_analysis.ipynb
deleted file mode 100644
index 60eae77a4..000000000
--- a/examples/eval/notebooks/02_guidance_data_analysis.ipynb
+++ /dev/null
@@ -1,63838 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 1,
-=======
- "execution_count": 52,
->>>>>>> dev
- "metadata": {},
- "outputs": [],
- "source": [
- "import importlib\n",
- "import gpudrive\n",
- "importlib.reload(gpudrive)\n",
- "\n",
- "import numpy as np\n",
- "import os\n",
- "import torch\n",
- "from pathlib import Path\n",
- "\n",
- "# Set working directory to the base directory 'gpudrive_madrona'\n",
- "working_dir = Path.cwd()\n",
- "while working_dir.name != 'gpudrive':\n",
- " working_dir = working_dir.parent\n",
- " if working_dir == Path.home():\n",
- " raise FileNotFoundError(\"Base directory 'gpudrive_madrona' not found\")\n",
- "os.chdir(working_dir)\n",
- "import torch\n",
- "from PIL import Image\n",
- "import seaborn as sns\n",
- "import warnings\n",
- "import matplotlib as mpl\n",
- "import matplotlib.pyplot as plt\n",
- "import datetime\n",
- "\n",
- "from gpudrive.env.env_torch import GPUDriveTorchEnv\n",
- "from gpudrive.env.config import EnvConfig, RenderConfig\n",
- "from gpudrive.env.dataset import SceneDataLoader\n",
- "from gpudrive.visualize.utils import img_from_fig\n",
- "\n",
- "sns.set(\"notebook\", font_scale=1.05, rc={\"figure.figsize\": (10, 5)})\n",
- "sns.set_style(\"ticks\", rc={\"figure.facecolor\": \"none\", \"axes.facecolor\": \"none\"})\n",
- "%config InlineBackend.figure_format = 'svg'\n",
- "warnings.filterwarnings(\"ignore\")\n",
- "plt.set_loglevel(\"WARNING\")\n",
- "mpl.rcParams[\"lines.markersize\"] = 8\n",
- "\n",
- "plt.set_loglevel(\"WARNING\")\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "MS = 9\n",
- " \n",
- "def plot_agent_guidance_info(guidance_obs, center_agent_idx):\n",
- " \n",
- " agent_guidance = guidance_obs[center_agent_idx, :, :]\n",
- " mask = agent_guidance[:, 0] != -1.0\n",
- " \n",
- " time_steps = np.arange(agent_guidance.shape[0])\n",
- " \n",
- " fig, axes = plt.subplots(1, 3, figsize=(9.5, 3))\n",
- " \n",
- " axes[0].scatter(agent_guidance[:, 0][mask], agent_guidance[:, 1][mask], marker='o', s=MS, color='g') \n",
- " axes[0].set_xlabel(r'Suggested $x$')\n",
- " axes[0].set_ylabel(r'Suggested $y$')\n",
- " #axes[0].legend()\n",
- " axes[0].grid(True, alpha=0.3)\n",
- " \n",
- " axes[1].scatter(time_steps[mask], agent_guidance[:, 2][mask], marker='o', color='b', s=MS)\n",
- " axes[1].set_xlabel('Time')\n",
- " axes[1].set_ylabel('Suggested speed')\n",
- " axes[1].grid(True, alpha=0.3)\n",
- " \n",
- " axes[2].scatter(time_steps[mask], agent_guidance[:, 3][mask], marker='o', color='#bc6c25', s=MS)\n",
- " axes[2].set_xlabel('Time')\n",
- " axes[2].set_ylabel('Suggested heading (rad)')\n",
- " axes[2].grid(True, alpha=0.3)\n",
- " \n",
- " #fig.suptitle(f\"Agent's {center_agent_idx} normalized guidance information\")\n",
- " \n",
- " plt.tight_layout()\n",
- " plt.subplots_adjust(top=0.85, bottom=0.15)\n",
- " \n",
- " return fig"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Settings"
- ]
- },
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 3,
-=======
- "execution_count": 53,
->>>>>>> dev
- "metadata": {},
- "outputs": [],
- "source": [
- "DATASET = \"data/processed/wosac/validation_json_100\"\n",
- "FIGURES_DIR = \"examples/eval/figures\"\n",
- "DATA_FOLDER = \"examples/eval/figures_data/\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Make environment"
- ]
- },
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 4,
-=======
- "execution_count": 39,
->>>>>>> dev
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['data/processed/wosac/validation_json_100/12858b7f79387840.json']\n"
- ]
- }
- ],
- "source": [
- "GUIDANCE_MODE = \"log_replay\"\n",
- "\n",
- "env_config = EnvConfig(\n",
- " reward_type=\"guided_autonomy\",\n",
- " guidance=True,\n",
- " guidance_mode=GUIDANCE_MODE,\n",
- " add_reference_heading=True,\n",
- " add_reference_speed=True,\n",
- " add_reference_pos_xy=True,\n",
- " init_mode=\"wosac_train\",\n",
- " smoothen_trajectory=False,\n",
- " dynamics_model=\"delta_local\",\n",
- " guidance_dropout_prob=0.7,\n",
- " init_steps=0,\n",
- ")\n",
- "render_config = RenderConfig()\n",
- "\n",
- "train_loader = SceneDataLoader(\n",
- " root=DATASET,\n",
- " batch_size=1,\n",
- " dataset_size=100,\n",
- " sample_with_replacement=False,\n",
- " shuffle=False,\n",
- " file_prefix=\"\",\n",
- " \n",
- ")\n",
- "\n",
- "env = GPUDriveTorchEnv(\n",
- " config=env_config,\n",
- " data_loader=train_loader,\n",
- " max_cont_agents=32, \n",
- " device=\"cpu\",\n",
- ")\n",
- "\n",
- "print(env.data_batch)\n",
- "\n",
- "obs = env.reset(env.cont_agent_mask)\n",
- "expert_actions, _, _, _ = env.get_expert_actions()\n",
- "\n",
- "# for time_step in range(env.init_steps, 20):\n",
- "# print(f\"Step: {env.step_in_world[0, 0, 0].item()}\")\n",
- "\n",
- "# # Step the environment\n",
- "# expert_actions, _, _, _ = env.get_expert_actions()\n",
- "# env.step_dynamics(expert_actions[:, :, time_step, :])\n",
- "\n",
- "# obs = env.get_obs(env.cont_agent_mask)\n",
- "# rew = env.get_rewards()\n",
- " \n",
- "\n",
- "# # Save for analysis\n",
- "# reference_traj = torch.cat([\n",
- "# env.reference_trajectory.pos_xy,\n",
- "# env.reference_trajectory.vel_xy,\n",
- "# env.reference_trajectory.yaw,\n",
- "# env.reference_trajectory.valids,\n",
- "# ], dim=-1)\n",
- "# reference_traj_np = reference_traj[env.cont_agent_mask].numpy()\n",
- "\n",
- "# np.save(f\"reference_{GUIDANCE_MODE}.npy\", reference_traj_np)"
- ]
- },
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([10, 2974])"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "obs.shape"
-=======
- "execution_count": 41,
- "metadata": {},
- "outputs": [],
- "source": [
- "#env.reference_trajectory.pos_xy[0, 0, :, 0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Smoothening trajectory data"
->>>>>>> dev
- ]
- },
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 6,
-=======
- "execution_count": 42,
->>>>>>> dev
- "metadata": {},
- "outputs": [],
- "source": [
- "sim_states = env.vis.plot_simulator_state(\n",
- " env_indices=0,\n",
- " zoom_radius=60,\n",
- " plot_guidance_pos_xy=True,\n",
- " plot_guidance_up_to_time=False,\n",
- " center_agent_indices=[0],\n",
- ")\n",
- "fig = sim_states[0]\n",
- "\n",
- "fig.savefig(\n",
- " os.path.join(FIGURES_DIR, f\"wosac_adapted_sim_state.pdf\"),\n",
- " dpi=300,\n",
- " bbox_inches=\"tight\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
-<<<<<<< HEAD
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "CENTER_AGENT_IDX = 0\n",
- "\n",
- "fig = env.vis.plot_agent_observation(\n",
- " env_idx=0,\n",
- " agent_idx=CENTER_AGENT_IDX,\n",
- " figsize=(10, 10),\n",
- " trajectory=env.reference_path[CENTER_AGENT_IDX, :, :],\n",
- " # step_reward=env.guidance_reward[\n",
- " # 0, 1\n",
- " # ].item(),\n",
- " # route_progress=env.route_progress[1],\n",
- ")\n",
- "\n",
- "plt.savefig(\n",
- " os.path.join(FIGURES_DIR, f\"agent_observation_{CENTER_AGENT_IDX}_{GUIDANCE_MODE}.pdf\"),\n",
- " dpi=300,\n",
- " bbox_inches=\"tight\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig = plot_agent_guidance_info(env.guidance_obs, center_agent_idx=1);\n",
- "\n",
- "plt.savefig(\n",
- " os.path.join(FIGURES_DIR, f\"agent_guidance_info_{CENTER_AGENT_IDX}_{GUIDANCE_MODE}.pdf\"),\n",
- " dpi=300,\n",
- " bbox_inches=\"tight\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([-32.8585, -32.8577, -32.8582, -32.8602, -32.8622, -32.8634, -32.8660,\n",
- " -32.8688, -32.8728, -32.8773, -32.8801, -32.8816, -32.8817, -32.8812,\n",
- " -32.8783, -32.8721, -32.8564, -32.8297, -32.7953, -32.7494, -32.6895,\n",
- " -32.6174, -32.5349, -32.4381, -32.3291, -32.2056, -32.0679, -31.9155,\n",
- " -31.7463, -31.5636, -31.3660, -31.1512, -30.9199, -30.6727, -30.4093,\n",
- " -30.1285, -29.8302, -29.5161, -29.1851, -28.8341, -28.4636, -28.0751,\n",
- " -27.6704, -27.2509, -26.8136, -26.3599, -25.8922, -25.4067, -24.9075,\n",
- " -24.3964, -23.8717, -23.3363, -22.7909, -22.2343, -21.6653, -21.0818,\n",
- " -20.4884, -19.8846, -19.2682, -18.6404, -18.0094, -17.3685, -16.7175,\n",
- " -16.0576, -15.3887, -14.7155, -14.0410, -13.3705, -12.6958, -12.0195,\n",
- " -11.3450, -10.6720, -10.0023, -9.3430, -8.6844, -8.0244, -7.3569,\n",
- " -6.6813, -6.0006, -5.3140, -4.6201, -3.9159, -3.1992, -2.4728,\n",
- " -1.7352, -0.9873, -0.2267, 0.5480, 1.3348, 2.1340, 2.9464])"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "env.reference_trajectory.pos_xy[0, 0, :, 0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Smoothening trajectory data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [],
- "source": [
- "ref_log_replay_raw = np.load(f\"{DATA_FOLDER}reference_log_replay_raw.npy\")\n",
- "ref_log_replay_smooth = np.load(f\"{DATA_FOLDER}reference_log_replay_smooth.npy\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
-=======
- "execution_count": 43,
->>>>>>> dev
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_trajectories(raw, smooth, agent_idx):\n",
- " \n",
- " valid = raw[agent_idx, :, 5].astype(bool)\n",
- " \n",
- " fig, axs = plt.subplots(1, 3, figsize=(10, 3))\n",
- " axs[0].set_title(f\"Global position (x, y); agent {agent_idx}\")\n",
- " axs[0].plot(raw[agent_idx, :, 0][valid], raw[agent_idx, :, 1][valid], color='b')\n",
- " axs[0].plot(smooth[agent_idx, :, 0][valid], smooth[agent_idx, :, 1][valid], color='orange')\n",
- " axs[0].set_xlabel(\"x\")\n",
- " axs[0].set_ylabel(\"y\")\n",
- "\n",
- " axs[1].set_title(f\"Global velocity (x, y); agent {agent_idx}\")\n",
- " axs[1].plot(raw[agent_idx, :, 2][valid], raw[agent_idx, :, 3][valid], color='b')\n",
- " axs[1].plot(smooth[agent_idx, :, 2][valid], smooth[agent_idx, :, 3][valid], color='orange')\n",
- " axs[1].set_xlabel(r\"$v_x$\")\n",
- " axs[1].set_ylabel(r\"$v_y$\")\n",
- "\n",
- " axs[2].set_title(f\"Global heading; agent {agent_idx}\")\n",
- " axs[2].plot(list(range(valid.sum())), raw[agent_idx, :, 4][valid], color='b', label=\"raw\")\n",
- " axs[2].plot(list(range(valid.sum())), smooth[agent_idx, :, 4][valid], color='orange', label=\"smoothed\")\n",
- " axs[2].set_ylabel(r\"$\\theta$\")\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " fig.legend(loc=\"center right\", facecolor=\"white\", bbox_to_anchor=(1.15, 0.5))\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 44,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "for i in range(5):\n",
- " plot_trajectories(raw=ref_log_replay_raw, smooth=ref_log_replay_smooth, agent_idx=i)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Aligning the scale of guidance data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Q1. Are the distributions from different modes (`vbd_online`, `log_replay` and `vbd_amortized`) on the same scale?\n",
- "\n",
- "- Unnormalized values\n",
- "- Macroscopic (full dist)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 54,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load npy array from the file\n",
- "FOLDER = \"examples/eval/figures_data/\"\n",
- "\n",
- "ref_log_replay = np.load(f\"{DATA_FOLDER}reference_log_replay.npy\")\n",
- "ref_vbd_amortized = np.load(f\"{DATA_FOLDER}reference_vbd_amortized.npy\")\n",
- "ref_vbd_online = np.load(f\"{DATA_FOLDER}reference_vbd_online.npy\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 55,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "((1180, 91, 6), (1180, 90, 6), (38, 80, 6))"
- ]
- },
- "execution_count": 55,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Controlled agent elements for 10 scenarios\n",
- "ref_log_replay.shape, ref_vbd_amortized.shape, ref_vbd_online.shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Q1 | **Global positions** $(x,y)$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 56,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1, 3, figsize=(13, 3),)# sharex=True, sharey=True)\n",
- "\n",
- "axs[0].hist(ref_log_replay[:, :, :2].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
- "axs[1].hist(ref_vbd_amortized[:, :, :2].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
- "axs[2].hist(ref_vbd_online[:, :, :2].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
- "\n",
- "axs[0].set_xlabel(\"Demeaned position in global coord. frame\")\n",
- "axs[0].set_ylabel(\"Count\")\n",
- "\n",
- "fig.legend()\n",
- "sns.despine()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Q1 | **Velocities** $(x,y)$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 49,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1, 3, figsize=(13, 3) )# sharex=True, sharey=True)\n",
- "\n",
- "axs[0].hist(ref_log_replay[:, :, 2:4].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
- "axs[1].hist(ref_vbd_amortized[:, :, 2:4].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
- "axs[2].hist(ref_vbd_online[:, :, 2:4].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
- "\n",
- "axs[0].set_xlabel(\"Demeaned velocity [m/s]\")\n",
- "axs[0].set_ylabel(\"Count\")\n",
- "\n",
- "fig.legend()\n",
- "sns.despine()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Q1 | **Headings**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 50,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1, 3, figsize=(13, 3),)# sharex=True, sharey=True)\n",
- "\n",
- "axs[0].hist(ref_log_replay[:, :, 4].flatten(), color='b', alpha=0.5, bins=20, label=\"log_replay\")\n",
- "axs[1].hist(ref_vbd_amortized[:, :, 4].flatten(), color='r', alpha=1.0, bins=20, label=\"vbd_amortized\")\n",
- "axs[2].hist(ref_vbd_online[:, :, 4].flatten(), color='g', alpha=1.0, bins=20, label=\"vbd_online\")\n",
- "\n",
- "axs[0].set_xlabel(\"Headings in global coord. frame [rad]\")\n",
- "axs[0].set_ylabel(\"Count\")\n",
- "\n",
- "# add legend\n",
- "fig.legend()\n",
- "sns.despine()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Q2. How do the trajectories differ qualitatively?\n",
- "\n",
- "- Plot several guidance trajectories for a particular agent to ensure that they are on the same scale"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 51,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "for agent_idx in range(10):\n",
- " \n",
- " valid = ref_log_replay[agent_idx, :, -1].astype( bool)\n",
- "\n",
- " fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n",
- " axs[0].set_title(f\"Global position (x, y); agent {agent_idx}\")\n",
- " axs[0].plot(ref_log_replay[agent_idx, :, 0][valid], ref_log_replay[agent_idx, :, 1][valid], color='b')\n",
- " #axs[0].plot(ref_vbd_online[agent_idx, :, 0], ref_vbd_online[agent_idx, :, 1], color='g', alpha=0.5)\n",
- " axs[0].plot(ref_vbd_amortized[agent_idx, :, 0], ref_vbd_amortized[agent_idx, :, 1], color='r')\n",
- " axs[0].set_xlabel(\"x\")\n",
- " axs[0].set_ylabel(\"y\")\n",
- "\n",
- " axs[1].set_title(f\"Global velocity (x, y); agent {agent_idx}\")\n",
- " axs[1].plot(ref_log_replay[agent_idx, :, 2][valid], ref_log_replay[agent_idx, :, 3][valid], color='b')\n",
- " axs[1].plot(ref_vbd_online[agent_idx, :, 2], ref_vbd_online[agent_idx, :, 3], color='g')\n",
- " axs[1].plot(ref_vbd_amortized[agent_idx, :, 2], ref_vbd_amortized[agent_idx, :, 3], color='r')\n",
- " axs[2].set_xlabel(r\"$v_x$\")\n",
- " axs[2].set_ylabel(r\"$v_y$\")\n",
- "\n",
- " axs[2].set_title(f\"Global heading; agent {agent_idx}\")\n",
- " axs[2].plot(ref_log_replay[agent_idx, :, 4][valid], color='b', label=\"log_replay\")\n",
- " axs[2].plot(ref_vbd_online[agent_idx, :, 4], color='g', label=\"vbd_online\")\n",
- " axs[2].plot(ref_vbd_amortized[agent_idx, :, 4], color='r', label=\"vbd_amortized\")\n",
- " axs[2].set_ylabel(r\"$\\theta$\")\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " fig.legend(loc=\"center right\", facecolor=\"white\", bbox_to_anchor=(1.12, 0.5))\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Show corresponding full agent views"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "center_agent_idx=3\n",
- "\n",
- "traj_masked = env.reference_path[center_agent_idx, :, :]\n",
- "\n",
- "agent_obs_masked = env.vis.plot_agent_observation(\n",
- " env_idx=0,\n",
- " agent_idx=center_agent_idx,\n",
- " figsize=(8, 8),\n",
- " trajectory=traj_masked,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "center_agent_idx=1\n",
- "\n",
- "traj_masked = env.reference_path[center_agent_idx, :, :]\n",
- "\n",
- "agent_obs_masked = env.vis.plot_agent_observation(\n",
- " env_idx=0,\n",
- " agent_idx=center_agent_idx,\n",
- " figsize=(8, 8),\n",
- " trajectory=traj_masked,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "center_agent_idx=0\n",
- "\n",
- "traj_masked = env.reference_path[center_agent_idx, :, :]\n",
- "\n",
- "agent_obs_masked = env.vis.plot_agent_observation(\n",
- " env_idx=0,\n",
- " agent_idx=center_agent_idx,\n",
- " figsize=(8, 8),\n",
- " trajectory=traj_masked,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Simulator state for that scene"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sim_states = env.vis.plot_simulator_state(\n",
- " env_indices=0,\n",
- " zoom_radius=70,\n",
- " plot_guidance_pos_xy=True,\n",
- " center_agent_indices=[center_agent_idx],\n",
- ")\n",
- "sim_states[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Useful dataset statistics\n",
- "\n",
- "- What is the distribution of the guidance trajectories in the dataset?\n",
- " - What is the average length $T$?, what is the variance?\n",
- "- How often is the guidance trajectory to \"do nothing\"?"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(1180, 91, 6)"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ref_log_replay.shape # [agents, time, features]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [],
- "source": [
- "valid_guidance_points = ref_log_replay[:, :, 5].sum(axis=-1)\n",
- "valid_guidance_points.shape\n",
- "\n",
- "# Some trajectories, while valid, are just parked cars\n",
- "# Detect these by checking for zero velocity trajectories\n",
- "valid_and_non_zero = ((ref_log_replay[:, :, 2] != 0) & (ref_log_replay[:, :, 5] == 1)).sum(axis=1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1, 2, figsize=(9, 3.5), sharey=True)\n",
- "\n",
- "fig.suptitle(f\"Distribution of guidance points per episode; A = {ref_log_replay.shape[0]} (50 scenarios)\")\n",
- "\n",
- "sns.histplot(valid_guidance_points, bins=10, stat='percent', ax=axs[0])\n",
- "axs[0].grid(True, alpha=0.2)\n",
- "axs[0].set_xlabel(\"Guidance points per episode.\")\n",
- "mean_value = valid_guidance_points.mean()\n",
- "axs[0].axvline(x=mean_value, color='purple', linestyle='--')\n",
- "axs[0].text(mean_value-3, axs[0].get_ylim()[1]*0.95, r'$\\mu$', color='purple', \n",
- " fontsize=14, ha='center', va='center')\n",
- "\n",
- "sns.histplot(valid_and_non_zero, bins=10, stat='percent', ax=axs[1])\n",
- "axs[1].grid(True, alpha=0.2)\n",
- "axs[1].set_xlabel(r\"$\\bf{Nonzero}$ guidance points per episode.\")\n",
- "mean_value_nonzero = valid_and_non_zero.mean()\n",
- "axs[1].axvline(x=mean_value_nonzero, color='purple', linestyle='--')\n",
- "axs[1].text(mean_value_nonzero-3, axs[1].get_ylim()[1]*0.95, r'$\\mu$', color='purple', fontsize=14, ha='center', va='center')\n",
- "\n",
- "plt.tight_layout()\n",
- "sns.despine()\n",
- "plt.savefig(f\"{FIGURES_DIR}/guidance_points_distribution.pdf\", dpi=300, bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(69.444916, 30.798435, 91.0, 1.0)"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "valid_guidance_points.mean(), valid_guidance_points.std(), valid_guidance_points.max(), valid_guidance_points.min()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "gpudrive",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/eval/notebooks/obtain_guidance_data.py b/examples/eval/notebooks/obtain_guidance_data.py
index f81b3c286..d70863863 100644
--- a/examples/eval/notebooks/obtain_guidance_data.py
+++ b/examples/eval/notebooks/obtain_guidance_data.py
@@ -11,9 +11,9 @@
if __name__ == "__main__":
- GUIDANCE_MODE = "log_replay"
+ GUIDANCE_MODE = "vbd_online"
DATASET = "data/processed/wosac/validation_json_100" # Ensure VBD trajectory structures are in here
- SAVE_PATH = "examples/eval/figures_data/"
+ SAVE_PATH = "examples/eval/figures_data/guidance/"
env_config = EnvConfig(
dynamics_model="classic",
@@ -23,14 +23,14 @@
add_reference_heading=True,
add_reference_speed=True,
add_reference_pos_xy=True,
- init_mode="wosac_train",
+ init_mode="wosac_eval",
smoothen_trajectory=False,
)
render_config = RenderConfig()
train_loader = SceneDataLoader(
root=DATASET,
- batch_size=2,
+ batch_size=10,
dataset_size=100,
sample_with_replacement=False,
shuffle=False,
@@ -40,7 +40,7 @@
env = GPUDriveTorchEnv(
config=env_config,
data_loader=train_loader,
- max_cont_agents=64,
+ max_cont_agents=32,
device="cuda",
)
diff --git a/examples/eval/run_wosac_eval.py b/examples/eval/run_wosac_eval.py
index ad589cabf..3ab2928a5 100644
--- a/examples/eval/run_wosac_eval.py
+++ b/examples/eval/run_wosac_eval.py
@@ -92,7 +92,7 @@ def rollout(
info.collided_with_vehicle == 1
) & (info.type == int(madrona_gpudrive.EntityType.Vehicle))
control_mask = control_mask_all & ~zero_action_mask
-
+
next_obs = env.reset(mask=control_mask)
# Get scenario ids
@@ -143,7 +143,7 @@ def rollout(
agent_observation_frames[idx].append(img_from_fig(agent_obs))
# Get next observation
- next_obs = env.get_obs(control_mask)
+ next_obs = env.get_obs(control_mask)
# NOTE(dc): Make sure to decouple the obs from the reward function
reward = env.get_rewards()
done = env.get_dones()
@@ -235,6 +235,7 @@ def rollout(
return scenario_ids, scenario_rollouts, scenario_rollout_masks
+
def load_config(config_path):
"""Load the configuration file."""
with open(config_path, "r") as f:
@@ -252,9 +253,11 @@ def load_config(config_path):
NUM_DATA_BATCHES = 1
INIT_STEPS = 10
DATASET_SIZE = 100
- RENDER = False
+ RENDER = True
LOG_DIR = "examples/eval/figures_data/wosac/"
- GUIDANCE_MODE = "log_replay" #"vbd_amortized"
+ GUIDANCE_MODE = (
+ "log_replay" # Options: "vbd_amortized", "vbd_online", "log_replay"
+ )
DATA_JSON = "data/processed/wosac/validation_json_100"
DATA_TFRECORD = "data/processed/wosac/validation_tfrecord_100"
@@ -269,16 +272,16 @@ def load_config(config_path):
batch_size=NUM_ENVS,
dataset_size=DATASET_SIZE,
sample_with_replacement=True,
- shuffle=False,
+ shuffle=True,
file_prefix="",
)
# Load agent
agent = load_agent(path_to_cpt=CPT_PATH).to(DEVICE)
-
+
# config = load_config("baselines/ppo/config/ppo_guided_autonomy.yaml")
# config = config.environment
-
+
config = agent.config
# Override default environment settings to match those the agent was trained with
@@ -304,25 +307,29 @@ def load_config(config_path):
goal_behavior=config.goal_behavior,
polyline_reduction_threshold=config.polyline_reduction_threshold,
remove_non_vehicles=config.remove_non_vehicles,
- lidar_obs=False,
+ lidar_obs=False,
obs_radius=config.obs_radius,
max_steer_angle=config.max_steer_angle,
max_accel_value=config.max_accel_value,
action_space_steer_disc=config.action_space_steer_disc,
action_space_accel_disc=config.action_space_accel_disc,
# Override action space
- steer_actions = torch.round(
+ steer_actions=torch.round(
torch.linspace(
- -config.max_steer_angle, config.max_steer_angle, config.action_space_steer_disc
+ -config.max_steer_angle,
+ config.max_steer_angle,
+ config.action_space_steer_disc,
),
decimals=3,
+ ),
+ accel_actions=torch.round(
+ torch.linspace(
+ -config.max_accel_value,
+ config.max_accel_value,
+ config.action_space_accel_disc,
),
- accel_actions = torch.round(
- torch.linspace(
- -config.max_accel_value, config.max_accel_value, config.action_space_accel_disc
- ),
- decimals=3,
- ),
+ decimals=3,
+ ),
init_mode="wosac_eval",
init_steps=INIT_STEPS,
guidance_mode=GUIDANCE_MODE,
diff --git a/examples/eval/wosac_eval_origin.py b/examples/eval/wosac_eval_origin.py
index fa4364b69..fcf4d50c8 100644
--- a/examples/eval/wosac_eval_origin.py
+++ b/examples/eval/wosac_eval_origin.py
@@ -55,7 +55,7 @@ def __init__(
else:
self.baselines_df = baselines_df
- self.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+ self.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
self.field_names = [
"metametric",
diff --git a/gpudrive/datatypes/trajectory.py b/gpudrive/datatypes/trajectory.py
index 30881eed5..7d50b8309 100644
--- a/gpudrive/datatypes/trajectory.py
+++ b/gpudrive/datatypes/trajectory.py
@@ -228,7 +228,6 @@ def __init__(self, vbd_traj_tensor: torch.Tensor):
torch.int32
)
-
@classmethod
def from_tensor(
cls,
diff --git a/gpudrive/env/env_torch.py b/gpudrive/env/env_torch.py
index dd91d6a09..1b3463b15 100755
--- a/gpudrive/env/env_torch.py
+++ b/gpudrive/env/env_torch.py
@@ -190,7 +190,7 @@ def setup_guidance(self):
self.max_agent_count,
madrona_gpudrive.kTrajectoryLength,
6,
- )
+ ).to(self.device)
reference_trajectory[
:, :, : self.init_steps + 1, :2
] = log_trajectory.pos_xy[:, :, : self.init_steps + 1]
diff --git a/src/dynamics.hpp b/src/dynamics.hpp
index 7fba26845..ed2fead12 100755
--- a/src/dynamics.hpp
+++ b/src/dynamics.hpp
@@ -25,9 +25,9 @@ namespace madrona_gpudrive
float speed = velocity.linear.length();
float yaw = utils::quatToYaw(rotation);
- float x_dot = speed * cosf(utils::AngleAdd(yaw, velocity.angular.z));
- float y_dot = speed * sinf(utils::AngleAdd(yaw, velocity.angular.z));
- float theta_dot = speed * tanf(velocity.angular.z) / size.length;
+ float x_dot = speed * cosf(yaw);
+ float y_dot = speed * sinf(yaw);
+ float theta_dot = speed * tanf(velocity.angular.z) / (0.8 * size.length);
float delta_dot = action.classic.steering;
// Update the yaw
float new_yaw = utils::AngleAdd(yaw, theta_dot * dt);