diff --git a/.github/workflows/run_build_docs.yml b/.github/workflows/run_build_docs.yml
new file mode 100644
index 00000000..9c66225d
--- /dev/null
+++ b/.github/workflows/run_build_docs.yml
@@ -0,0 +1,70 @@
+# Workflow for building Sphinx docs and deploying to GH Pages.
+#
+# Pull requests only build the docs (so broken docs are caught in review);
+# the Pages deployment steps run for pushes to main and manual dispatches.
+name: Build Sphinx docs and deploy to GH Pages
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+  workflow_dispatch:
+    inputs:
+      branch:
+        description: Branch to build docs for
+        required: true
+        default: main
+
+# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
+# Serialise runs per ref, cancelling in-progress runs of the same ref, so a
+# pull-request build can never cancel a deployment running for main.
+concurrency:
+  group: pages-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+    - run: echo "The job was automatically triggered by a ${{ github.event_name }} event."
+    - run: echo "This job is now running on a ${{ runner.os }} server."
+    - run: echo "Running on branch ${{ github.ref }} of repository ${{ github.repository }}."
+    - name: Check out repository code.
+      uses: actions/checkout@v3
+      with:
+        # Honour the manually selected branch; the expression is empty for
+        # push/pull_request events, so checkout falls back to the event ref.
+        ref: ${{ github.event.inputs.branch }}
+    - name: Python environment setup
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.10'
+    - name: Install dependencies.
+      run: |
+        sudo apt-get update
+        sudo apt-get install -y libsfml-dev
+        git submodule sync
+        git submodule update --init --recursive
+        python -m pip install -U pip poetry
+        poetry install --with=docs
+    - name: Build Sphinx docs
+      run: poetry run sphinx-build -b html ${{ github.workspace }}/docs/source ${{ github.workspace }}/docs/build/html
+    # The remaining steps publish the site, so they are skipped for pull requests.
+    - name: Setup Pages
+      if: github.event_name != 'pull_request'
+      uses: actions/configure-pages@v3
+    - name: Upload artifact
+      if: github.event_name != 'pull_request'
+      uses: actions/upload-pages-artifact@v1
+      with:
+        path: ./docs/build/html  # Upload HTML docs only
+    - name: Deploy to GitHub Pages
+      if: github.event_name != 'pull_request'
+      id: deployment
+      uses: actions/deploy-pages@v2
diff --git a/.gitignore b/.gitignore
index 1529576e..a8e131b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,5 +51,8 @@ configs.json
 # poetry
 poetry.lock
 
+# Static doc files
+!docs/source/_static/*
+
 # wandb
 wandb
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d1117cb4..018753c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -45,6 +45,22 @@ repos:
     args: [--autofix, --no-sort]
   - id: pretty-format-yaml
     args: [--autofix]
+- repo: https://github.com/mwouts/jupytext
+  rev: v1.15.2
+  hooks:
+  - id: jupytext
+    args: [--from, ipynb, --to, md:myst, --sync]
+- repo: https://github.com/nbQA-dev/nbQA
+  rev: 1.7.0
+  hooks:
+  - id: nbqa-pyupgrade
+    args: [--py310-plus]
+  - id: nbqa-black
+    args: [--line-length=120]
+  - id: nbqa-isort
+    args: [--profile=black]
+  - id: nbqa-flake8
+    args: [--max-line-length=120, --extend-ignore=E203]
 - repo: local
   hooks:
   - id: pylint
@@ -52,6 +68,7 @@ repos:
     entry: poetry run pylint
     language: system
     types: [python]
+    exclude: ^docs/.*
   - id: poetry-export-requirements
     name: poetry-export-requirements
     entry: poetry export --without-hashes --with=main,research -f requirements.txt -o requirements.txt
@@ -64,3 +81,9 @@ repos:
     language: system
     types: [python]
     pass_filenames: false
+  - id: poetry-export-requirements-docs
+    name: poetry-export-requirements-docs
+    entry: poetry export --without-hashes --only docs -f requirements.txt -o requirements.docs.txt
+    language: system
+    types: [python]
+    pass_filenames: false
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d0c3cbf1
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..dc1312ab
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/01_data_structure.ipynb b/docs/source/01_data_structure.ipynb new file mode 100644 index 00000000..62a7ccca --- /dev/null +++ b/docs/source/01_data_structure.ipynb @@ -0,0 +1,4230 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data format of a traffic scene\n", + "\n", + "This notebook dives into the data format used to create simulations in Nocturne.\n", + "\n", + "_Last update: 10/2023_" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "\n", + "os.chdir(\"..\")\n", + "\n", + "cmap = [\"r\", \"g\", \"b\", \"y\", \"c\"]\n", + "%config InlineBackend.figure_format = 'svg'\n", + "sns.set(\"notebook\", font_scale=1.1, rc={\"figure.figsize\": (8, 3)})\n", + "sns.set_style(\"ticks\", rc={\"figure.facecolor\": \"none\", \"axes.facecolor\": \"none\"})" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Traffic scenes are constructed by utilizing the [Waymo Open Motion dataset](https://waymo.com/open/). Though every scene is unique, they all have the same basic data structure. 
\n", + "\n", + "To load a traffic scene:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['name', 'objects', 'roads', 'tl_states'])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Take an example scene\n", + "data_path = \"./data/example_scenario.json\"\n", + "\n", + "with open(data_path) as file:\n", + " traffic_scene = json.load(file)\n", + "\n", + "traffic_scene.keys()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Global Overview \n", + "A traffic scene consists of:\n", + "- `name`: the name of the traffic scenario.\n", + "- `objects`: the road objects or moving vehicles in the scene.\n", + "- `roads`: the road points in the scene, these are all the stationary objects.\n", + "- `tl_states`: the states of the traffic lights, which are filtered out for now. " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "traffic_scene[\"tl_states\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'tfrecord-00358-of-01000_65.json'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "traffic_scene[\"name\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-03T10:23:25.972593\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "pd.Series([traffic_scene[\"objects\"][idx][\"type\"] for idx in range(len(traffic_scene[\"objects\"]))]).value_counts().plot(\n", + " kind=\"bar\", rot=45, color=cmap\n", + ")\n", + "plt.title(f'Distribution of road objects in traffic scene. Total # objects: {len(traffic_scene[\"objects\"])}')\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This traffic scenario only contains vehicles and pedestrians, some scenes have cyclists as well." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-03T10:23:26.839616\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ 
+ "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "pd.Series([traffic_scene[\"roads\"][idx][\"type\"] for idx in range(len(traffic_scene[\"roads\"]))]).value_counts().plot(\n", + " kind=\"bar\", rot=45, color=cmap\n", + ")\n", + "plt.title(f'Distribution of road points in traffic scene. Total # points: {len(traffic_scene[\"roads\"])}')\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### In-Depth: Road Objects\n", + "\n", + "This is a list of different road objects in the traffic scene. For each road object, we have information about its position, velocity, size, in which direction it's heading, whether it's a valid object, the type, and the final position of the vehicle." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['position', 'width', 'length', 'heading', 'velocity', 'valid', 'goalPosition', 'type'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Take the first object\n", + "idx = 0\n", + "\n", + "# For each object, we have this information:\n", + "traffic_scene[\"objects\"][idx].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " \"x\": 9037.7138671875,\n", + " \"y\": -2720.373779296875\n", + " },\n", + " {\n", + " \"x\": 9037.7607421875,\n", + " \"y\": -2720.306640625\n", + " },\n", + " {\n", + " \"x\": 9037.822265625,\n", + " \"y\": -2720.217529296875\n", + " },\n", + " {\n", + " \"x\": 9037.8916015625,\n", + " \"y\": -2720.146240234375\n", + " },\n", + " {\n", + " \"x\": 9037.9482421875,\n", + " \"y\": -2720.070068359375\n", + " },\n", + " {\n", + " \"x\": 9038.01953125,\n", + " \"y\": -2719.994384765625\n", + " },\n", + " {\n", + 
" \"x\": 9038.1005859375,\n", + " \"y\": -2719.903076171875\n", + " },\n", + " {\n", + " \"x\": 9038.1953125,\n", + " \"y\": -2719.830810546875\n", + " },\n", + " {\n", + " \"x\": 9038.279296875,\n", + " \"y\": -2719.74462890625\n", + " },\n", + " {\n", + " \"x\": 9038.3564453125,\n", + " \"y\": -2719.674560546875\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "# Position contains the (x, y) coordinates for the vehicle at every time step\n", + "print(json.dumps(traffic_scene[\"objects\"][idx][\"position\"][:10], indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.6877052187919617, 0.6777269244194031)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Width and length together make the size of the object, and is used to see if there is a collision\n", + "traffic_scene[\"objects\"][idx][\"width\"], traffic_scene[\"objects\"][idx][\"length\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "An object's heading refers to the direction it is pointing or moving in. The default coordinate system in Nocturne is right-handed, where the positive x and y axes point to the right and downwards, respectively. In a right-handed coordinate system, 0 degrees is located on the x-axis and the angle increases counter-clockwise.\n", + "\n", + "Because the scene is created from the viewpoint of an ego driver, there may be instances where the heading of certain vehicles is not available. These cases are represented by the value `-10_000`, to indicate that these steps should be filtered out or are invalid." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-03T10:23:28.800884\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Heading is the direction in which the vehicle is pointing\n", + "plt.plot(traffic_scene[\"objects\"][idx][\"heading\"])\n", + "plt.xlabel(\"Time step\")\n", + "plt.ylabel(\"Heading\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " \"x\": 0.634765625,\n", + " \"y\": 0.72265625\n", + " },\n", + " {\n", + " \"x\": 0.46875,\n", + " \"y\": 0.67138671875\n", + " },\n", + " {\n", + " \"x\": 0.615234375,\n", + " \"y\": 0.89111328125\n", + " },\n", + " {\n", + " \"x\": 0.693359375,\n", + " \"y\": 0.712890625\n", + " },\n", + " {\n", + " \"x\": 0.56640625,\n", + " \"y\": 0.76171875\n", + " },\n", + " {\n", + " \"x\": 0.712890625,\n", + " \"y\": 0.7568359375\n", + " },\n", + " {\n", + " \"x\": 0.810546875,\n", + " \"y\": 0.9130859375\n", + " },\n", + " {\n", + " \"x\": 0.947265625,\n", + " \"y\": 0.72265625\n", + " },\n", + " {\n", + " \"x\": 0.83984375,\n", + " \"y\": 0.86181640625\n", + " },\n", + " {\n", + " \"x\": 0.771484375,\n", + " \"y\": 0.70068359375\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "# Velocity shows the velocity in the x- and y- directions\n", + "print(json.dumps(traffic_scene[\"objects\"][idx][\"velocity\"][:10], indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-10-03T10:23:29.389521\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Valid indicates if the state of the vehicle was observed for each timepoint\n", + "plt.xlabel(\"Time step\")\n", + "plt.ylabel(\"IS VALID\")\n", + "plt.plot(traffic_scene[\"objects\"][idx][\"valid\"], \"_\", lw=5)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': 9041.1259765625, 'y': -2716.647216796875}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Each object has a goalPosition, an (x, y) position within the scene\n", + "traffic_scene[\"objects\"][idx][\"goalPosition\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'pedestrian'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Finally, we have the type of the vehicle\n", + "traffic_scene[\"objects\"][idx][\"type\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### In-Depth: Road Points\n", + "\n", + "Road points are static objects in the scene." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['geometry', 'type'])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "traffic_scene[\"roads\"][idx].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'road_edge'" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This point represents the edge of a road\n", + "traffic_scene[\"roads\"][idx][\"type\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " \"x\": 8922.911733810946,\n", + " \"y\": -2849.426741530589\n", + " },\n", + " {\n", + " \"x\": 8923.216436260553,\n", + " \"y\": -2849.038518766975\n", + " },\n", + " {\n", + " \"x\": 8923.50673911804,\n", + " \"y\": -2848.63941352788\n", + " },\n", + " {\n", + " \"x\": 8923.782254084921,\n", + " \"y\": -2848.2299596442986\n", + " },\n", + " {\n", + " \"x\": 8924.042612639492,\n", + " \"y\": -2847.8107047886665\n", + " },\n", + " {\n", + " \"x\": 8924.287466537296,\n", + " \"y\": -2847.382209743547\n", + " },\n", + " {\n", + " \"x\": 8924.516488266596,\n", + " \"y\": -2846.945047650609\n", + " },\n", + " {\n", + " \"x\": 8924.729371495881,\n", + " \"y\": -2846.49980324385\n", + " },\n", + " {\n", + " \"x\": 8924.91688626026,\n", + " \"y\": -2846.067714357487\n", + " },\n", + " {\n", + " \"x\": 8925.087545312272,\n", + " \"y\": -2845.6286986979553\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "# Geometry contains the (x, y) position(s) for a road point\n", + "# Note that this will be a list for road lanes and edges but a single (x, y) tuple for stop signs and alike\n", + 
"print(json.dumps(traffic_scene[\"roads\"][idx][\"geometry\"][:10], indent=4));" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nocturne-research", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/02_nocturne_concepts.ipynb b/docs/source/02_nocturne_concepts.ipynb new file mode 100644 index 00000000..a31d262c --- /dev/null +++ b/docs/source/02_nocturne_concepts.ipynb @@ -0,0 +1,786 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Nocturne concepts\n", + "\n", + "This page introduces the most basic elements of nocturne. You can find further information about these [in Section 3 of the Nocturne paper](https://arxiv.org/abs/2206.09889).\n", + "\n", + "_Last update: 10/2023_" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "\n", + "os.chdir(\"..\")\n", + "\n", + "data_path = \"./data/example_scenario.json\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Summary\n", + "\n", + "- Nocturne simulations are **discretized traffic scenarios**. A scenario is a constructed snapshot of traffic situation at a particular timepoint.\n", + "- The state of the vehicle of focus is referred to as the **ego state**. Each vehicle has their **own partial view of the traffic scene**; and a visible state is constructed by parameterizing the view distance, head angle and cone radius of the driver. The action for each vehicle is a `(1, 3)` tuple with the acceleration, steering and head angle of the vehicle. 
\n", + "- The **step method advances the simulation** with a desired step size. By default, the dynamics of vehicles are driven by a kinematic bicycle model. If a vehicle is set to expert-controlled mode, its position, heading, and speed will be updated according to a trajectory recorded from a human driver." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simulation\n", + "\n", + "In Nocturne, a simulation discretizes an existing traffic scenario. At the moment, Nocturne supports traffic scenarios from the Waymo Open Dataset, but can be further extended to work with other driving datasets. \n", + "\n", + "
\n", + "
\n", + "\n", + "
An example of a set of traffic scenarios in Nocturne. Upon initialization, a start time is chosen. After each iteration we take a step in the simulation, which gets us to the next scenario. This is done until we reach the end of the simulation.
\n", + "
\n", + "\n", + "We show an example of this using `example_scenario.json`, where our traffic data is extracted from the Waymo Open Motion Dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from nocturne import Simulation\n", + "\n", + "scenario_config = {\n", + " \"start_time\": 0, # When to start the simulation\n", + " \"allow_non_vehicles\": True, # Whether to include cyclists and pedestrians\n", + " \"max_visible_road_points\": 10, # Maximum number of road points for a vehicle\n", + " \"max_visible_objects\": 10, # Maximum number of road objects for a vehicle\n", + " \"max_visible_traffic_lights\": 10, # Maximum number of traffic lights in constructed view\n", + " \"max_visible_stop_signs\": 10, # Maximum number of stop signs in constructed view\n", + "}\n", + "\n", + "# Create simulation\n", + "sim = Simulation(data_path, scenario_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Scenario\n", + "\n", + "A simulation consists of a set of scenarios. A scenario is a snapshot of the traffic scene at a particular timepoint. \n", + "\n", + "Here is how to create a scenario object:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Get traffic scenario at timepoint\n", + "scenario = sim.getScenario()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `scenario` object holds the information we are interested in. 
Here are a couple of examples:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "33" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The number of road objects in the scene\n", + "len(scenario.getObjects())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total # moving objects: 15\n", + "\n", + "Object IDs of moving vehicles: \n", + " [0, 1, 2, 3, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] \n" + ] + } + ], + "source": [ + "# The road objects that moved at a particular timepoint\n", + "objects_that_moved = scenario.getObjectsThatMoved()\n", + "\n", + "print(f\"Total # moving objects: {len(objects_that_moved)}\\n\")\n", + "print(f\"Object IDs of moving vehicles: \\n {[obj.getID() for obj in objects_that_moved]} \")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "128" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Number of road lines\n", + "len(scenario.road_lines())" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scenario.getVehicles()[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# No cyclists in this scene\n", + "scenario.getCyclists()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ 
+ { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 moving vehicles in scene: [3, 32]\n" + ] + } + ], + "source": [ + "# Select all moving vehicles that move\n", + "moving_vehicles = [obj for obj in scenario.getVehicles() if obj in objects_that_moved]\n", + "\n", + "print(f\"Found {len(moving_vehicles)} moving vehicles in scene: {[vehicle.getID() for vehicle in moving_vehicles]}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Ego state\n", + "\n", + "The **ego state** is an array with features that describe the current vehicle. This array holds the following information: \n", + "- 0: length of ego vehicle\n", + "- 1: width of ego vehicle\n", + "- 2: speed of ego vehicle\n", + "- 3: distance to the goal position of ego vehicle\n", + "- 4: angle to the goal (target azimuth) \n", + "- 5: desired heading at goal position\n", + "- 6: desired speed at goal position\n", + "- 7: current acceleration\n", + "- 8: current steering position\n", + "- 9: current head angle" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected vehicle # 3\n" + ] + }, + { + "data": { + "text/plain": [ + "array([ 4.4936213 , 1.9770377 , 0.07662283, 4.24219 , -0.05617166,\n", + " -0.05909407, 1.6792779 , 0. , 0. , 0. 
],\n", + " dtype=float32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Select an arbitrary vehicle\n", + "ego_vehicle = moving_vehicles[0]\n", + "\n", + "print(f\"Selected vehicle # {ego_vehicle.getID()}\")\n", + "\n", + "# Get the state for ego vehicle\n", + "scenario.ego_state(ego_vehicle)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Visible state\n", + "\n", + "We use the ego vehicle state, together with a view distance (how far the vehicle can see) and a view angle to construct the **visible state**. The figure below shows this procedure for a simplified traffic scene. \n", + "\n", + "Calling `scenario.visible_state()` returns a dictionary with four matrices:\n", + "- `stop_signs`: The visible stop signs \n", + "- `traffic_lights`: The states for the traffic lights from the perspective of the ego driver(red, yellow, green).\n", + "- `road_points`: The observable road points (static elements in the scene).\n", + "- `objects`: The observable road objects (vehicles, pedestrians and cyclists).\n", + "\n", + "
\n", + "
\n", + "\n", + "
To investigate coordination under partial observability, agents in Nocturne can only see an obstructed view of their environment. In this simplified traffic scene, we construct the state for the red ego driver. Note that Nocturne assumes that stop signs can be viewed, even if they are behind another driver.
\n", + "
\n", + "\n", + "\\begin{align*}\n", + "\\end{align*}\n", + "\n", + "
\n", + "
\n", + "\n", + "
The same scene, this time showing the view of the yellow car.
\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The shape of the visible state is a function of the maximum number of visible objects defined at initialization (traffic lights, stop signs, road objects, and road points) and whether we add padding. If `padding = True`, an array of size `(max visible objects, # features)` is always constructed, even if there are no visible objects. Otherwise, if `padding = False`, new entries are only created when objects are visible. \n", + "\n", + "For example, say a vehicle does not observe any stop signs at a given timepoint. If we set `padding=False`, and run `visible_state['stop_signs']`, we'll get back an empty array with the shape `(0, 3)`, where 3 is the number of features per stop sign. However, if the vehicle observes two stop signs using the same setting, then `visible_state['stop_signs']` will return an array with the shape `(2, 3)`.\n", + "\n", + "On the other hand, if we set `padding=True`, the resulting array will always have a shape of `(max visible stop signs, 3)`, irrespective of how many stop signs the vehicle actually observes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['stop_signs', 'traffic_lights', 'road_points', 'objects'])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define viewing distance, radius and head angle\n", + "view_distance = 80\n", + "view_angle = np.radians(120)\n", + "head_angle = 0\n", + "padding = True\n", + "\n", + "# Construct the visible state for ego vehicle\n", + "visible_state = scenario.visible_state(\n", + " ego_vehicle,\n", + " view_dist=view_distance,\n", + " view_angle=view_angle,\n", + " head_angle=head_angle,\n", + " padding=padding,\n", + ")\n", + "\n", + "visible_state.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# There are no visible stop signs at this point\n", + "visible_state[\"stop_signs\"].T" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)" + ] + }, + 
"execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Traffic light states are filtered out in this version of Nocturne\n", + "visible_state[\"traffic_lights\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(10, 13)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Max visible road points x 13 features\n", + "visible_state[\"road_points\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(10, 13)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Number of visible road objects x 13 features\n", + "visible_state[\"objects\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dimension flattened visible state: 410\n" + ] + } + ], + "source": [ + "visible_state_dim = sum([val.flatten().shape[0] for key, val in visible_state.items()])\n", + "\n", + "print(f\"Dimension flattened visible state: {visible_state_dim}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(410,)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We can also flatten the visible state\n", + "# flattened has padding: if we miss an object --> zeros\n", + "visible_state_flat = scenario.flattened_visible_state(\n", + " ego_vehicle,\n", + " view_dist=view_distance,\n", + " view_angle=view_angle,\n", + " head_angle=head_angle,\n", + ")\n", + "\n", + "visible_state_flat.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that 
`.flattened_visible_state()` adds padding by default." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step \n", + "\n", + "`step(dt)` is a method of the `Simulation` class, where `dt` is a scalar that represents the length of each simulation timestep in seconds. It advances the simulation by one timestep, which can result in changes to the state of the simulation (for example, new positions of objects, updated velocities, etc.) based on the physical laws and rules defined in the simulation.\n", + "\n", + "In the Waymo dataset, the length of the expert data is 9 seconds, and a step size of 0.1 seconds is used to discretize each traffic scene. The first second is used as a warm-start, leaving the remaining 8 seconds (80 steps) for the simulation (details in Section 3.3 of the paper)." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "dt = 0.1\n", + "\n", + "# Step the simulation\n", + "sim.step(dt)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vehicle control\n", + "\n", + "By default, vehicles in Nocturne are driven by a **kinematic bicycle model**. 
This means that calling the `step(dt)` method evolves the dynamics of a vehicle according to the following set of equations (Appendix D in the paper):\n", + "\n", + "\\begin{align*}\n", + " \\textbf{position: } x_{t+1} &= x_t + \\dot{x} \\, \\Delta t \\\\\n", + " y_{t+1} &= y_t + \\dot{y} \\, \\Delta t \\\\\n", + " \\textbf{heading: } \\theta_{t+1} &= \\theta_t + \\dot{\\theta} \\, \\Delta t \\\\ \n", + " \\textbf{speed: } v_{t+1} &= \\text{clip}(v_t + \\dot{v} \\, \\Delta t, -v_{\\text{max}}, v_{\\text{max}}) \\\\\n", + "\\end{align*}\n", + "\n", + "with\n", + "\n", + "\\begin{align*}\n", + " \\dot{v} &= a \\\\ \n", + " \\bar{v} &= \\text{clip}(v_t + 0.5 \\, \\dot{v} \\, \\Delta t, -v_{\\text{max}}, v_{\\text{max}}) \\\\\n", + " \\beta &= \\tan^{-1} \\left( \\frac{l_r \\tan (\\delta)}{L} \\right) \\\\\n", + " &= \\tan^{-1} (0.5 \\tan(\\delta)) \\\\\n", + " \\dot{x} &= \\bar{v} \\cos (\\theta + \\beta) \\\\\n", + " \\dot{y} &= \\bar{v} \\sin (\\theta + \\beta) \\\\\n", + " \\dot{\\theta} &= \\frac{\\bar{v} \\cos (\\beta)\\tan(\\delta)}{L}\n", + "\\end{align*}\n", + "\n", + "where $(x_t, y_t)$ is the position of a vehicle at time $t$, $\\theta_t$ is the vehicle's heading angle, $a$ is the acceleration and $\\delta$ is the steering angle. Finally, $L$ is the length of the car and $l_r = 0.5L$ is the distance to the rear axle of the car.\n", + "\n", + "If we set a vehicle to be **expert-controlled** instead, it will follow the same path as the respective human driver. This means that when we call the `step(dt)` function, the vehicle's position, heading, and speed will be updated to match the next point in the recorded human trajectory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# By default, all vehicles are not expert controlled\n", + "ego_vehicle.expert_control" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# Set a vehicle to be expert controlled:\n", + "ego_vehicle.expert_control = True" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "> **Pseudocode**: How `step(dt)` advances the simulation for every vehicle. Full code is implemented in [scenario.cc](https://github.com/facebookresearch/nocturne/blob/ae0a4e361457caf6b7e397675cc86f46161405ed/nocturne/cpp/src/scenario.cc#L264)\n", + "\n", + "---\n", + "\n", + "```Python\n", + "for vehicle in vehicles:\n", + "\n", + " if object is not expert controlled:\n", + " step vehicle dynamics following the kinematic bicycle model\n", + " \n", + " if vehicle is expert controlled:\n", + " get current time & vehicle idx\n", + " vehicle position = expert trajectories[vehicle_idx, time]\n", + " vehicle heading = expert headings[vehicle_idx, time]\n", + " vehicle speed = expert speeds[vehicle_idx, time]\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Action space\n", + "\n", + "The action set for a vehicle consists of three components: acceleration, steering and the head angle. 
Actions are discretized based on a provided upper and lower bound.\n", + "\n", + "The experiments in the paper use:\n", + "- 6 discrete actions for **acceleration** uniformly split between $[-3, 2] \\, \\frac{m}{s^2}$\n", + "- 21 discrete actions for **steering** between $[-0.7, 0.7]$ radians \n", + "- 5 discrete actions for **head tilt** between $[-1.6, 1.6]$ radians\n", + "\n", + "This is how you can access an expert action for a vehicle in Nocturne:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{acceleration: -0.224648, steering: -0.360994, head_angle: 0.000000}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Choose an arbitrary timepoint\n", + "time = 5\n", + "\n", + "# Show expert action at timepoint\n", + "scenario.expert_action(ego_vehicle, time)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "expert_action = scenario.expert_action(ego_vehicle, time)\n", + "\n", + "expert_action = expert_action.numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "acceleration = expert_action[0]\n", + "steering = expert_action[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "nocturne_cpp.Action" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(scenario.expert_action(ego_vehicle, time))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(-0.005859, 0.004639)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# How did the vehicle's position change after taking this action?\n", + 
"scenario.expert_pos_shift(ego_vehicle, time)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-0.0007097125053405762" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# How did the head angle change?\n", + "scenario.expert_heading_shift(ego_vehicle, time)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nocturne-research", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/03_basic_rl_usage.ipynb b/docs/source/03_basic_rl_usage.ipynb new file mode 100644 index 00000000..76ed79b9 --- /dev/null +++ b/docs/source/03_basic_rl_usage.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic RL usage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initializing environments\n", + "\n", + "\n", + "#### **Environment settings**\n", + "\n", + "- Initializing an environment is done with the `BaseEnv` class. The `BaseEnv` class leverages the `nocturne` simulator to create a basic RL interface, based on the provided traffic scenario(s). \n", + "\n", + "---\n", + "> 📝 The `env_config.yaml` file defines our environment settings, such as the action space, observation space and traffic scenarios to use.\n", + "---\n", + "\n", + "Check out `configs/env_config` for all the details!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "Exception", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/dejongmathijs/Library/CloudStorage/OneDrive-TheBostonConsultingGroup,Inc/Documents/Personal/nocturne_lab/examples/03_basic_rl_usage.ipynb Cell 3\u001b[0m line \u001b[0;36m2\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39myaml\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mnocturne\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39menvs\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase_env\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseEnv\n\u001b[1;32m 4\u001b[0m \u001b[39m# Load environment settings\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m../configs/env_config.yaml\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m stream:\n", + "File \u001b[0;32m~/.pyenv-i386/versions/nocturne_lab/lib/python3.10/site-packages/nocturne/envs/__init__.py:2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m\"\"\"Import file for tests.\"\"\"\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mnocturne\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39menvs\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase_env\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseEnv\n\u001b[1;32m 4\u001b[0m __all__ \u001b[39m=\u001b[39m [\n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mBaseEnv\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 6\u001b[0m ]\n", + "File \u001b[0;32m~/.pyenv-i386/versions/nocturne_lab/lib/python3.10/site-packages/nocturne/envs/base_env.py:24\u001b[0m\n\u001b[1;32m 
20\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mgym\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mspaces\u001b[39;00m \u001b[39mimport\u001b[39;00m Box, Discrete\n\u001b[1;32m 22\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mnocturne\u001b[39;00m \u001b[39mimport\u001b[39;00m Action, Simulation, Vector2D, Vehicle\n\u001b[0;32m---> 24\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m()\n\u001b[1;32m 26\u001b[0m _MAX_NUM_TRIES_TO_FIND_VALID_VEHICLE \u001b[39m=\u001b[39m \u001b[39m1_000\u001b[39m\n\u001b[1;32m 28\u001b[0m logging\u001b[39m.\u001b[39mgetLogger(\u001b[39m__name__\u001b[39m)\n", + "\u001b[0;31mException\u001b[0m: " + ] + } + ], + "source": [ + "import os\n", + "\n", + "import yaml\n", + "\n", + "from nocturne.envs.base_env import BaseEnv\n", + "\n", + "os.chdir(\"..\")\n", + "\n", + "# Load environment settings\n", + "with open(\"./configs/env_config.yaml\") as stream:\n", + " env_config = yaml.safe_load(stream)\n", + "\n", + "# Initialize environment\n", + "env = BaseEnv(config=env_config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"controlling agents # {[agent.id for agent in env.controlled_vehicles]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### **Data**\n", + "\n", + "- Within `env_config.yaml`, we specify the path to the folder containing the traffic scenarios to use as follows:\n", + "\n", + "```yaml\n", + "# Path to folder with traffic scene(s) from which to create an environment\n", + "data_path: ../data\n", + "```\n", + "\n", + "- [Here](https://github.com/facebookresearch/nocturne/tree/main#downloading-the-dataset) are the instructions to access the complete dataset of traffic scenes. \n", + "\n", + "- The data folder also has a file named `valid_files.json`. This file lists the names of all the valid traffic scenarios along with the ids of the vehicles that are not valid. 
These vehicles are excluded from our experiment.\n", + "\n", + "For simplicity, we currently added a single traffic scenario that includes two vehicles in our data folder. Both vehicles can be used, so our `valid_files.json` looks like this:\n", + "\n", + "```yaml\n", + "{\n", + " \"example_scenario.json\": []\n", + "}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interacting with the environment\n", + "\n", + "The classic agent-environment loop of reinforcement learning is implemented as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Reset\n", + "obs_dict = env.reset()\n", + "\n", + "# Get info\n", + "agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + "dead_agent_ids = []\n", + "num_agents = len(agent_ids)\n", + "rewards = {agent_id: 0 for agent_id in agent_ids}\n", + "\n", + "for step in range(1000):\n", + " # Sample actions\n", + " action_dict = {agent_id: env.action_space.sample() for agent_id in agent_ids if agent_id not in dead_agent_ids}\n", + "\n", + " # Step in env\n", + " obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)\n", + "\n", + " for agent_id in action_dict.keys():\n", + " rewards[agent_id] += rew_dict[agent_id]\n", + "\n", + " # Update dead agents\n", + " for agent_id, is_done in done_dict.items():\n", + " if is_done and agent_id not in dead_agent_ids:\n", + " dead_agent_ids.append(agent_id)\n", + "\n", + " # Reset if all agents are done\n", + " if done_dict[\"__all__\"]:\n", + " print(f\"Done after {env.step_num} steps -- total return in episode: {rewards}\")\n", + " obs_dict = env.reset()\n", + " dead_agent_ids = []\n", + " rewards = {agent_id: 0 for agent_id in agent_ids}\n", + "\n", + "# Close environment\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Accessing information about the environment" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The observation space\n", + "env.observation_space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The size of the joint action space\n", + "env.action_space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Which agents are controlled?\n", + "env.controlled_vehicles" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### \n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/04_ppo_with_sb3.ipynb b/docs/source/04_ppo_with_sb3.ipynb new file mode 100644 index 00000000..3969a88e --- /dev/null +++ b/docs/source/04_ppo_with_sb3.ipynb @@ -0,0 +1,164 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## PPO with single-agent control\n", + "\n", + "In this notebook, we show how to use Proximal Policy Optimization (PPO) with Nocturne and [Stable Baselines 3 (SB3)](https://stable-baselines3.readthedocs.io/en/master/index.html). SB3 is a library that has implementations of various well-known RL algorithms." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Wrappers\n", + "\n", + "The Nocturne `BaseEnv` returns output as dictionaries, but the SB3 `PPO` class expects numpy arrays. To make our environment compatible with SB3, we create a wrapper class. 
Wrappers modify an environment without altering its code directly, which reduces boilerplate and increases modularity." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import yaml\n", + "\n", + "# Import base environment and wrapper\n", + "from nocturne.envs.base_env import BaseEnv\n", + "from nocturne.wrappers.sb3_wrappers import NocturneToSB3\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Load environment settings\n", + "with open(\"./configs/env_config.yaml\") as stream:\n", + " env_config = yaml.safe_load(stream)\n", + "\n", + "# Make sure to only control a single agent at a time. This is achieved by setting max_num_vehicles = 1\n", + "env_config[\"max_num_vehicles\"] = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize env and wrap it with SB3 wrapper\n", + "env = NocturneToSB3(BaseEnv(env_config))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PPO\n", + "\n", + "Now all we have to do is initialize the SB3 `PPO` class and we're ready to learn! We use Weights & Biases (`wandb`) to take care of the logging. If you prefer not to use `wandb`, set `LOGGING = False` and `verbose=1`. 
\n", + "\n", + "\n", + "---\n", + "\n", + "> 🔦 More info on PPO and settings can be found in the [SB3 docs](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html).\n", + "\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "from stable_baselines3 import PPO" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "LOGGING = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if LOGGING:\n", + " wandb.login()\n", + " run = wandb.init(\n", + " project=\"single_agent_control_sb3_ppo\",\n", + " sync_tensorboard=True,\n", + " )\n", + " run_id = run.id\n", + "else:\n", + " run_id = None\n", + "\n", + "# Init PPO algorithm\n", + "model = PPO(\n", + " policy=\"MlpPolicy\", # Policy type\n", + " n_steps=4096, # Number of steps per rollout\n", + " batch_size=128, # Minibatch size\n", + " env=env, # Our wrapped environment\n", + " seed=42, # Always seed for reproducibility\n", + " verbose=0,\n", + " tensorboard_log=f\"runs/{run_id}\" if run_id is not None else None, # Sync with wandb\n", + ")\n", + "\n", + "# Learn\n", + "model.learn(total_timesteps=200_000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🤔 How good is your policy?\n", + "\n", + "Hooray! You have just trained your first PPO agent in Nocturne! 🏁 \n", + "\n", + "Now take a look at the information you've logged over training; did we learn? (if you want to compare, [this is what my run looks like](https://api.wandb.ai/links/daphnecor/iarufxw9))\n", + "\n", + "One important metric for assessing the effectiveness of your policy is the average cumulative reward per episode. In our case, the **maximum** achievable return per episode is approximately between 9 and 10 (it varies per traffic scene and per agent). 
With the configurations above, your policy should approach this value in 150,000 steps. Here, steps (the `global_step`) represents the total number of **frames** our policy network has seen, you can think of it as the accumulated experience." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nocturne_lab", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/_static/logo.png b/docs/source/_static/logo.png new file mode 100644 index 00000000..c7c26e1d Binary files /dev/null and b/docs/source/_static/logo.png differ diff --git a/docs/source/api_python.md b/docs/source/api_python.md new file mode 100644 index 00000000..0fd775fc --- /dev/null +++ b/docs/source/api_python.md @@ -0,0 +1,9 @@ +--- +title: Nocturne | API Reference +--- +# API Reference + +```{eval-rst} +.. autosummary:: nocturne + :toctree: generated +``` diff --git a/docs/source/changelog.md b/docs/source/changelog.md new file mode 100644 index 00000000..9695d069 --- /dev/null +++ b/docs/source/changelog.md @@ -0,0 +1,6 @@ +--- +title: Changelog +--- + +```{include} ../../CHANGELOG.md +``` diff --git a/docs/source/coc.md b/docs/source/coc.md new file mode 100644 index 00000000..39a1be4f --- /dev/null +++ b/docs/source/coc.md @@ -0,0 +1,6 @@ +--- +title: Code of Conduct +--- + +```{include} ../../CODE_OF_CONDUCT.md +``` diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..6cb631d4 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,139 @@ +""" +Configuration file for the Sphinx documentation builder. 
+ +For the full list of built-in configuration values, see the documentation: +https://www.sphinx-doc.org/en/master/usage/configuration.html +""" + +# -- Path setup -------------------------------------------------------------- + +import pathlib +import sys + +sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix()) + + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "Nocturne" +copyright = "2023, The Nocturne Authors" +author = "The Nocturne Authors" + +# The short X.Y version +version = "0.0.1" +# The full version, including alpha/beta/rc tags +release = "0.0.1" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.autosectionlabel", + "sphinx.ext.duration", + "sphinx.ext.githubpages", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "myst_nb", # See: https://myst-nb.readthedocs.io/en/latest/ + "sphinxcontrib.bibtex", # See: https://sphinxcontrib-bibtex.readthedocs.io/en/latest/quickstart.html + "sphinx_autodoc_typehints", # See: https://github.com/tox-dev/sphinx-autodoc-typehints +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +intersphinx_mapping = { + # Resolve stdlib cross-references against the current Python 3 docs (the project targets Python >= 3.10). 
+ "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy-1.8.1/", None), +} + +bibtex_bibfiles = ["references.bib"] + +autosummary_generate = True + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_book_theme" +html_theme_options = { + "path_to_docs": "docs/source", 
"use_download_button": True, + "use_edit_page_button": True, + "use_fullscreen_button": True, + "use_issues_button": True, + "use_source_button": True, + "use_repository_button": True, + "use_sidenotes": True, + "repository_url": "https://github.com/emerge-lab/nocturne", + "repository_branch": "main", + "launch_buttons": {"colab_url": "https://colab.research.google.com"}, + "home_page_in_toc": True, + "show_navbar_depth": 1, + "show_toc_level": 2, + "icon_links": [ + { + "name": "Nocturne GitHub", + "url": "https://github.com/emerge-lab/nocturne", + "icon": "fa-brands fa-github", + }, + ], +} +html_static_path = ["_static"] +html_logo = "_static/logo.png" +# html_favicon = "_static/logo-square.svg" +html_title = "Nocturne" +html_copy_source = True + +html_sidebars = {"**/**": ["sbt-sidebar-nav.html"]} + +# -- Options for MyST-NB output ------------------------------------------------- +# https://myst-nb.readthedocs.io/en/latest/index.html +source_suffix = { + ".rst": "restructuredtext", + ".ipynb": "myst-nb", + ".myst": "myst-nb", +} +myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", +] +myst_url_schemes = ("http", "https", "mailto") +nb_execution_mode = "force" +nb_execution_allow_errors = False +nb_merge_streams = True + +nb_execution_excludepatterns = [ + # NOTE(review): "force" re-executes every notebook on each build; confirm the PPO tutorial + # (wandb.login() + 200_000 training steps) is meant to run in CI, otherwise list it here. +] + +# -- Options for autodoc ---------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration + +# Automatically extract typehints when specified and place them in +# descriptions of the relevant function/method. +autodoc_typehints = "description" + +# Don't show class signature with the class' name.
+autodoc_class_signature = "separated" + +# -- Extension configuration ------------------------------------------------- + +# Tell sphinx-autodoc-typehints to generate stub parameter annotations including +# types, even if the parameters aren't explicitly documented. +always_document_param_types = True + + +# Tell sphinx autodoc how to render type aliases. +autodoc_type_aliases = { + "ArrayLike": "ArrayLike", + "DTypeLike": "DTypeLike", +} diff --git a/docs/source/contributing.md b/docs/source/contributing.md new file mode 100644 index 00000000..60539c0e --- /dev/null +++ b/docs/source/contributing.md @@ -0,0 +1,6 @@ +--- +title: Contributing +--- + +```{include} ../../CONTRIBUTING.md +``` diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 00000000..8db35682 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,39 @@ +--- +title: Nocturne +--- + +```{toctree} +:maxdepth: 1 +:hidden: +:caption: Quickstart +self +``` + +```{toctree} +:maxdepth: 1 +:hidden: +:caption: Tutorial +01_data_structure +02_nocturne_concepts +03_basic_rl_usage +04_ppo_with_sb3 +``` + +```{toctree} +:maxdepth: 1 +:hidden: +:caption: Reference +api_python +``` + +```{toctree} +:maxdepth: 1 +:hidden: +:caption: Developer +contributing +changelog +coc +``` + +```{include} ../../README.md +``` diff --git a/docs/source/references.bib b/docs/source/references.bib new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index e62c4129..21870abc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ pandas = "^2.1.1" wandb = "^0.15.12" tensorboard = "^2.14.1" +[tool.poetry.group.dev] +optional = true + [tool.poetry.group.dev.dependencies] pre-commit = "^3.4.0" flake8 = "^6.1.0" @@ -63,6 +66,16 @@ isort = "^5.12.0" pylint = "^3.0.0" tomli = "^2.0.1" +[tool.poetry.group.docs] +optional = true + +[tool.poetry.group.docs.dependencies] +sphinx = "^5.3.0" +sphinx-book-theme = "^1.0.1" +myst-nb = "^0.17.2" +sphinxcontrib-bibtex = 
"^2.6.1" +sphinx-autodoc-typehints = "^1.23.0" + [tool.poetry.build] script = "build.py" generate-setup-file = true @@ -79,6 +92,7 @@ convention = "google" [tool.pylint] max-line-length = 120 +ignore-paths = ["^docs/.*$"] disable = "W1514" [tool.isort] diff --git a/requirements.docs.txt b/requirements.docs.txt new file mode 100644 index 00000000..475180f2 --- /dev/null +++ b/requirements.docs.txt @@ -0,0 +1,87 @@ +accessible-pygments==0.0.4 ; python_version >= "3.10" and python_version < "3.13" +alabaster==0.7.13 ; python_version >= "3.10" and python_version < "3.13" +appnope==0.1.3 ; python_version >= "3.10" and python_version < "3.13" and (platform_system == "Darwin" or sys_platform == "darwin") +asttokens==2.4.0 ; python_version >= "3.10" and python_version < "3.13" +attrs==23.1.0 ; python_version >= "3.10" and python_version < "3.13" +babel==2.13.0 ; python_version >= "3.10" and python_version < "3.13" +backcall==0.2.0 ; python_version >= "3.10" and python_version < "3.13" +beautifulsoup4==4.12.2 ; python_version >= "3.10" and python_version < "3.13" +certifi==2023.7.22 ; python_version >= "3.10" and python_version < "3.13" +cffi==1.16.0 ; python_version >= "3.10" and python_version < "3.13" and implementation_name == "pypy" +charset-normalizer==3.3.0 ; python_version >= "3.10" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.10" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +comm==0.1.4 ; python_version >= "3.10" and python_version < "3.13" +debugpy==1.8.0 ; python_version >= "3.10" and python_version < "3.13" +decorator==5.1.1 ; python_version >= "3.10" and python_version < "3.13" +docutils==0.17.1 ; python_version >= "3.10" and python_version < "3.13" +exceptiongroup==1.1.3 ; python_version >= "3.10" and python_version < "3.11" +executing==2.0.0 ; python_version >= "3.10" and python_version < "3.13" +fastjsonschema==2.18.1 ; 
python_version >= "3.10" and python_version < "3.13" +greenlet==3.0.0 ; python_version >= "3.10" and python_version < "3.13" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") +idna==3.4 ; python_version >= "3.10" and python_version < "3.13" +imagesize==1.4.1 ; python_version >= "3.10" and python_version < "3.13" +importlib-metadata==6.8.0 ; python_version >= "3.10" and python_version < "3.13" +ipykernel==6.25.2 ; python_version >= "3.10" and python_version < "3.13" +ipython==8.16.1 ; python_version >= "3.10" and python_version < "3.13" +jedi==0.19.1 ; python_version >= "3.10" and python_version < "3.13" +jinja2==3.1.2 ; python_version >= "3.10" and python_version < "3.13" +jsonschema-specifications==2023.7.1 ; python_version >= "3.10" and python_version < "3.13" +jsonschema==4.19.1 ; python_version >= "3.10" and python_version < "3.13" +jupyter-cache==0.6.1 ; python_version >= "3.10" and python_version < "3.13" +jupyter-client==8.3.1 ; python_version >= "3.10" and python_version < "3.13" +jupyter-core==5.3.2 ; python_version >= "3.10" and python_version < "3.13" +latexcodec==2.0.1 ; python_version >= "3.10" and python_version < "3.13" +markdown-it-py==2.2.0 ; python_version >= "3.10" and python_version < "3.13" +markupsafe==2.1.3 ; python_version >= "3.10" and python_version < "3.13" +matplotlib-inline==0.1.6 ; python_version >= "3.10" and python_version < "3.13" +mdit-py-plugins==0.3.5 ; python_version >= "3.10" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.10" and python_version < "3.13" +myst-nb==0.17.2 ; python_version >= "3.10" and python_version < "3.13" +myst-parser==0.18.1 ; python_version >= "3.10" and python_version < "3.13" +nbclient==0.7.4 ; python_version >= "3.10" and python_version < "3.13" +nbformat==5.9.2 ; python_version >= "3.10" and python_version < 
"3.13" +nest-asyncio==1.5.8 ; python_version >= "3.10" and python_version < "3.13" +packaging==23.2 ; python_version >= "3.10" and python_version < "3.13" +parso==0.8.3 ; python_version >= "3.10" and python_version < "3.13" +pexpect==4.8.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "win32" +pickleshare==0.7.5 ; python_version >= "3.10" and python_version < "3.13" +platformdirs==3.11.0 ; python_version >= "3.10" and python_version < "3.13" +prompt-toolkit==3.0.39 ; python_version >= "3.10" and python_version < "3.13" +psutil==5.9.5 ; python_version >= "3.10" and python_version < "3.13" +ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "win32" +pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "3.13" +pybtex-docutils==1.0.3 ; python_version >= "3.10" and python_version < "3.13" +pybtex==0.24.0 ; python_version >= "3.10" and python_version < "3.13" +pycparser==2.21 ; python_version >= "3.10" and python_version < "3.13" and implementation_name == "pypy" +pydata-sphinx-theme==0.14.1 ; python_version >= "3.10" and python_version < "3.13" +pygments==2.16.1 ; python_version >= "3.10" and python_version < "3.13" +python-dateutil==2.8.2 ; python_version >= "3.10" and python_version < "3.13" +pywin32==306 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.10" and python_version < "3.13" +pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "3.13" +pyzmq==25.1.1 ; python_version >= "3.10" and python_version < "3.13" +referencing==0.30.2 ; python_version >= "3.10" and python_version < "3.13" +requests==2.31.0 ; python_version >= "3.10" and python_version < "3.13" +rpds-py==0.10.3 ; python_version >= "3.10" and python_version < "3.13" +six==1.16.0 ; python_version >= "3.10" and python_version < "3.13" +snowballstemmer==2.2.0 ; python_version >= "3.10" and python_version < "3.13" +soupsieve==2.5 ; python_version >= "3.10" and 
python_version < "3.13" +sphinx-autodoc-typehints==1.23.0 ; python_version >= "3.10" and python_version < "3.13" +sphinx-book-theme==1.0.1 ; python_version >= "3.10" and python_version < "3.13" +sphinx==5.3.0 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-applehelp==1.0.7 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-bibtex==2.6.1 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-devhelp==1.0.5 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-htmlhelp==2.0.4 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-jsmath==1.0.1 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-qthelp==1.0.6 ; python_version >= "3.10" and python_version < "3.13" +sphinxcontrib-serializinghtml==1.1.9 ; python_version >= "3.10" and python_version < "3.13" +sqlalchemy==2.0.21 ; python_version >= "3.10" and python_version < "3.13" +stack-data==0.6.3 ; python_version >= "3.10" and python_version < "3.13" +tabulate==0.9.0 ; python_version >= "3.10" and python_version < "3.13" +tornado==6.3.3 ; python_version >= "3.10" and python_version < "3.13" +traitlets==5.11.1 ; python_version >= "3.10" and python_version < "3.13" +typing-extensions==4.8.0 ; python_version >= "3.10" and python_version < "3.13" +urllib3==2.0.6 ; python_version >= "3.10" and python_version < "3.13" +wcwidth==0.2.8 ; python_version >= "3.10" and python_version < "3.13" +zipp==3.17.0 ; python_version >= "3.10" and python_version < "3.13"