From 34d16cf6fb97b21376e0503964f7f0a8082709d4 Mon Sep 17 00:00:00 2001 From: Aarav Pandya Date: Fri, 12 Jan 2024 20:39:04 -0500 Subject: [PATCH 1/2] Init --- examples/06_ppo_with_sb3_ma_menv_control.py | 193 + examples/hr_rl.ipynb | 6640 +++++++++++++++++++ examples/temp.ipynb | 384 ++ nocturne/envs/nocturne_gymnasium.py | 136 + 4 files changed, 7353 insertions(+) create mode 100644 examples/06_ppo_with_sb3_ma_menv_control.py create mode 100644 examples/hr_rl.ipynb create mode 100644 examples/temp.ipynb create mode 100644 nocturne/envs/nocturne_gymnasium.py diff --git a/examples/06_ppo_with_sb3_ma_menv_control.py b/examples/06_ppo_with_sb3_ma_menv_control.py new file mode 100644 index 00000000..2f82e87d --- /dev/null +++ b/examples/06_ppo_with_sb3_ma_menv_control.py @@ -0,0 +1,193 @@ +"""Train HR-PPO agent.""" +import logging +from contextlib import nullcontext +from datetime import datetime + +import numpy as np +import torch +from box import Box +from stable_baselines3.common.policies import ActorCriticPolicy + +import wandb + +from typing import Callable + +# Import networks +from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy +# Permutation equivariant network +from networks.perm_eq_late_fusion import LateFusionNet, LateFusionPolicy + +# Multi-agent as vectorized environment +from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv +from utils.config import load_config +from utils.random_utils import init_seed +from utils.render import make_video + +# Custom callback +from utils.sb3.callbacks import CustomMultiAgentCallback + +# Custom PPO class that supports multi-agent control +from utils.sb3.reg_ppo import RegularizedPPO +from utils.string_utils import datetime_to_str + +logging.basicConfig(level=logging.INFO) + + +def linear_schedule(initial_value: float) -> Callable[[float], float]: + """ + Linear learning rate schedule. + + :param initial_value: Initial learning rate. + :return: schedule that computes + current learning rate depending on remaining progress + """ + def func(progress_remaining: float) -> float: + """ + Progress will decrease from 1 (beginning) to 0. + + :param progress_remaining: + :return: current learning rate + """ + return progress_remaining * initial_value + + return func + +def train(env_config, exp_config, video_config, model_config): # pylint: disable=redefined-outer-name + """Train RL agent using PPO.""" + # Ensure reproducability + init_seed(env_config, exp_config, exp_config.seed) + + # Make environment + from nocturne.envs.nocturne_gymnasium import NocturneGymnasium + from stable_baselines3.common.vec_env import SubprocVecEnv + from nocturne.envs.base_env import BaseEnv + def make_env(env_config): + return NocturneGymnasium(BaseEnv(config=env_config)) + + env = SubprocVecEnv([lambda: make_env(env_config) for _ in range(4)]) + + # Set up run + datetime_ = datetime_to_str(dt=datetime.now()) + run_id = f"{datetime_}" if exp_config.track_wandb else None + + # Add scene to config + # exp_config.scene = env.filename + exp_config.track_wandb = False + + with wandb.init( + project=exp_config.project, + name=run_id, + group=exp_config.group, + config={**exp_config, **env_config}, + id=run_id, + **exp_config.wandb, + ) if exp_config.track_wandb else nullcontext() as run: + # Set device + exp_config.ppo.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # logging.info(f"Created env. Max # agents = {env_config.max_num_vehicles}.") + # logging.info(f"Learning in {len(env.env.files)} scene(s): {env.env.files} | using {exp_config.ppo.device}") + # logging.info(f"--- obs_space: {env.observation_space.shape[0]} ---") + # logging.info(f"Action_space\n: {env.env.idx_to_actions}") + + # if exp_config.reg_weight > 0.0: + # logging.info(f"Regularization weight: {exp_config.reg_weight} with policy: {exp_config.human_policy_path}") + + # # Initialize custom callback + custom_callback = CustomMultiAgentCallback( + env_config=env_config, + exp_config=exp_config, + video_config=video_config, + wandb_run=run if run_id is not None else None, + ) + + # Make scene init video to check expert actions + # if exp_config.track_wandb: + # for model in exp_config.wandb_init_videos: + # make_video( + # env_config=env_config, + # exp_config=exp_config, + # video_config=video_config, + # filenames=[env.filename], + # model=model, + # n_steps=None, + # ) + exp_config.track_wandb = False + + human_policy = None + # Load human reference policy if regularization is used + if exp_config.reg_weight > 0.0: + saved_variables = torch.load(exp_config.human_policy_path, map_location=exp_config.ppo.device) + human_policy = ActorCriticPolicy(**saved_variables["data"]) + human_policy.load_state_dict(saved_variables["state_dict"]) + human_policy.to(exp_config.ppo.device) + + # Set up PPO + model = RegularizedPPO( + learning_rate=linear_schedule(1e-4), + reg_policy=human_policy, + reg_weight=exp_config.reg_weight, # Regularization weight; lambda + env=env, + n_steps=exp_config.ppo.n_steps, + policy=LateFusionPolicy, + ent_coef=exp_config.ppo.ent_coef, + vf_coef=exp_config.ppo.vf_coef, + seed=exp_config.seed, # Seed for the pseudo random generators + verbose=exp_config.verbose, + tensorboard_log=f"runs/{run_id}" if run_id is not None else None, + device=exp_config.ppo.device, + env_config=env_config, + mlp_class=LateFusionNet, + mlp_config=model_config, + ) + + # Log number of trainable parameters + policy_params = filter(lambda p: p.requires_grad, model.policy.parameters()) + params = sum(np.prod(p.size()) for p in policy_params) + exp_config.n_policy_params = params + logging.info(f"Policy | trainable params: {params:,} \n") + + # Architecture + logging.info(f"Policy | arch: \n {model.policy}") + + + return env, model + # Learn + # model.learn( + # **exp_config.learn, + # callback=custom_callback, + # ) + + +if __name__ == "__main__": + env_config = load_config("env_config") + exp_config = load_config("exp_config") + video_config = load_config("video_config") + + env_config.num_files = 10 + + # Define model architecture + model_config = None + # model_config = Box( + # { + # "arch_ego_state": [8], + # "arch_road_objects": [64], + # "arch_road_graph": [128, 64], + # "arch_shared_net": [128], + # "act_func": "tanh", + # "dropout": 0.0, + # "last_layer_dim_pi": 64, + # "last_layer_dim_vf": 64, + # } + # ) + + # Train + env, model = train( + env_config=env_config, + exp_config=exp_config, + video_config=video_config, + model_config=model_config, + ) + + + model.learn(1) \ No newline at end of file diff --git a/examples/hr_rl.ipynb b/examples/hr_rl.ipynb new file mode 100644 index 00000000..94ed4698 --- /dev/null +++ b/examples/hr_rl.ipynb @@ -0,0 +1,6640 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Train HR-PPO agent.\"\"\"\n", + "import logging\n", + "from contextlib import nullcontext\n", + "from datetime import datetime\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from box import Box\n", + "from stable_baselines3.common.policies import ActorCriticPolicy\n", + "\n", + "import wandb\n", + "\n", + "from typing import Callable\n", + "\n", + "# Import networks\n", + "from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy\n", + "# Permutation equivariant network\n", + "from networks.perm_eq_late_fusion import LateFusionNet, LateFusionPolicy \n", + "\n", + "# Multi-agent as vectorized environment\n", + "from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv\n", + "from utils.config import load_config_nb\n", + "from utils.random_utils import init_seed\n", + "from utils.render import make_video\n", + "\n", + "# Custom callback\n", + "from utils.sb3.callbacks import CustomMultiAgentCallback\n", + "\n", + "# Custom PPO class that supports multi-agent control\n", + "from utils.sb3.reg_ppo import RegularizedPPO\n", + "from utils.string_utils import datetime_to_str\n", + "\n", + "logging.basicConfig(level=logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def linear_schedule(initial_value: float) -> Callable[[float], float]:\n", + " \"\"\"\n", + " Linear learning rate schedule.\n", + "\n", + " :param initial_value: Initial learning rate.\n", + " :return: schedule that computes\n", + " current learning rate depending on remaining progress\n", + " \"\"\"\n", + " def func(progress_remaining: float) -> float:\n", + " \"\"\"\n", + " Progress will decrease from 1 (beginning) to 0.\n", + "\n", + " :param progress_remaining:\n", + " :return: current learning rate\n", + " \"\"\"\n", + " return progress_remaining * initial_value\n", + "\n", + " return func\n", + "\n", + "def train(env_config, exp_config, video_config, model_config): # pylint: disable=redefined-outer-name\n", + " \"\"\"Train RL agent using PPO.\"\"\"\n", + " # Ensure reproducability\n", + " init_seed(env_config, exp_config, exp_config.seed)\n", + "\n", + " # Make environment\n", + " from nocturne.envs.nocturne_gymnasium import NocturneGymnasium\n", + " from stable_baselines3.common.vec_env import SubprocVecEnv\n", + " from nocturne.envs.base_env import BaseEnv\n", + " def make_env(env_config):\n", + " return NocturneGymnasium(BaseEnv(config=env_config)) \n", + "\n", + " env = SubprocVecEnv([lambda: make_env(env_config) for _ in range(4)])\n", + "\n", + " # Set up run\n", + " datetime_ = datetime_to_str(dt=datetime.now())\n", + " run_id = f\"{datetime_}\" if exp_config.track_wandb else None\n", + "\n", + " # Add scene to config\n", + " # exp_config.scene = env.filename\n", + " exp_config.track_wandb = False\n", + "\n", + " with wandb.init(\n", + " project=exp_config.project,\n", + " name=run_id,\n", + " group=exp_config.group,\n", + " config={**exp_config, **env_config},\n", + " id=run_id,\n", + " **exp_config.wandb,\n", + " ) if exp_config.track_wandb else nullcontext() as run:\n", + " # Set device\n", + " exp_config.ppo.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + " # logging.info(f\"Created env. Max # agents = {env_config.max_num_vehicles}.\")\n", + " # logging.info(f\"Learning in {len(env.env.files)} scene(s): {env.env.files} | using {exp_config.ppo.device}\")\n", + " # logging.info(f\"--- obs_space: {env.observation_space.shape[0]} ---\")\n", + " # logging.info(f\"Action_space\\n: {env.env.idx_to_actions}\")\n", + " \n", + " # if exp_config.reg_weight > 0.0:\n", + " # logging.info(f\"Regularization weight: {exp_config.reg_weight} with policy: {exp_config.human_policy_path}\")\n", + "\n", + " # # Initialize custom callback\n", + " custom_callback = CustomMultiAgentCallback(\n", + " env_config=env_config,\n", + " exp_config=exp_config,\n", + " video_config=video_config,\n", + " wandb_run=run if run_id is not None else None,\n", + " )\n", + "\n", + " # Make scene init video to check expert actions\n", + " # if exp_config.track_wandb:\n", + " # for model in exp_config.wandb_init_videos:\n", + " # make_video(\n", + " # env_config=env_config,\n", + " # exp_config=exp_config,\n", + " # video_config=video_config,\n", + " # filenames=[env.filename],\n", + " # model=model,\n", + " # n_steps=None,\n", + " # )\n", + " exp_config.track_wandb = False\n", + " \n", + " human_policy = None\n", + " # Load human reference policy if regularization is used\n", + " if exp_config.reg_weight > 0.0:\n", + " saved_variables = torch.load(exp_config.human_policy_path, map_location=exp_config.ppo.device)\n", + " human_policy = ActorCriticPolicy(**saved_variables[\"data\"])\n", + " human_policy.load_state_dict(saved_variables[\"state_dict\"])\n", + " human_policy.to(exp_config.ppo.device)\n", + "\n", + " # Set up PPO\n", + " model = RegularizedPPO(\n", + " learning_rate=linear_schedule(1e-4),\n", + " reg_policy=human_policy,\n", + " reg_weight=exp_config.reg_weight, # Regularization weight; lambda\n", + " env=env,\n", + " n_steps=exp_config.ppo.n_steps,\n", + " policy=LateFusionPolicy,\n", + " ent_coef=exp_config.ppo.ent_coef,\n", + " vf_coef=exp_config.ppo.vf_coef,\n", + " seed=exp_config.seed, # Seed for the pseudo random generators\n", + " verbose=exp_config.verbose,\n", + " tensorboard_log=f\"runs/{run_id}\" if run_id is not None else None,\n", + " device=exp_config.ppo.device,\n", + " env_config=env_config,\n", + " mlp_class=LateFusionNet,\n", + " mlp_config=model_config,\n", + " )\n", + "\n", + " # Log number of trainable parameters\n", + " policy_params = filter(lambda p: p.requires_grad, model.policy.parameters())\n", + " params = sum(np.prod(p.size()) for p in policy_params)\n", + " exp_config.n_policy_params = params\n", + " logging.info(f\"Policy | trainable params: {params:,} \\n\")\n", + "\n", + " # Architecture\n", + " logging.info(f\"Policy | arch: \\n {model.policy}\")\n", + "\n", + "\n", + " return env, model\n", + " # Learn\n", + " # model.learn(\n", + " # **exp_config.learn,\n", + " # callback=custom_callback,\n", + " # )" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Policy | trainable params: 148,566 \n", + "\n", + "INFO:root:Policy | arch: \n", + " LateFusionPolicy(\n", + " (features_extractor): FlattenExtractor(\n", + " (flatten): Flatten(start_dim=1, end_dim=-1)\n", + " )\n", + " (pi_features_extractor): FlattenExtractor(\n", + " (flatten): Flatten(start_dim=1, end_dim=-1)\n", + " )\n", + " (vf_features_extractor): FlattenExtractor(\n", + " (flatten): Flatten(start_dim=1, end_dim=-1)\n", + " )\n", + " (mlp_extractor): LateFusionNet(\n", + " (act_func): Tanh()\n", + " (actor_ego_state_net): Sequential(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (actor_ro_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (actor_rg_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (actor_ss_net): Sequential(\n", + " (0): Linear(in_features=3, out_features=3, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((3,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (actor_out_net): Sequential(\n", + " (0): Linear(in_features=77, out_features=256, bias=True)\n", + " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (2): Tanh()\n", + " (3): Dropout(p=0.0, inplace=False)\n", + " (4): Linear(in_features=256, out_features=128, bias=True)\n", + " (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (6): Tanh()\n", + " (7): Dropout(p=0.0, inplace=False)\n", + " (8): Linear(in_features=128, out_features=64, bias=True)\n", + " (9): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (10): Tanh()\n", + " (11): Dropout(p=0.0, inplace=False)\n", + " (12): Linear(in_features=64, out_features=64, bias=True)\n", + " (13): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (val_ego_state_net): Sequential(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (val_ro_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (val_rg_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (val_ss_net): Sequential(\n", + " (0): Linear(in_features=3, out_features=3, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((3,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (val_out_net): Sequential(\n", + " (0): Linear(in_features=77, out_features=256, bias=True)\n", + " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (2): Tanh()\n", + " (3): Dropout(p=0.0, inplace=False)\n", + " (4): Linear(in_features=256, out_features=128, bias=True)\n", + " (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (6): Tanh()\n", + " (7): Dropout(p=0.0, inplace=False)\n", + " (8): Linear(in_features=128, out_features=64, bias=True)\n", + " (9): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (10): Tanh()\n", + " (11): Dropout(p=0.0, inplace=False)\n", + " (12): Linear(in_features=64, out_features=64, bias=True)\n", + " (13): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (action_net): Linear(in_features=64, out_features=45, bias=True)\n", + " (value_net): Linear(in_features=64, out_features=1, bias=True)\n", + ")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LateFusionNet(\n", + " (act_func): Tanh()\n", + " (actor_ego_state_net): Sequential(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (actor_ro_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (actor_rg_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (actor_ss_net): Sequential(\n", + " (0): Linear(in_features=3, out_features=3, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((3,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (actor_out_net): Sequential(\n", + " (0): Linear(in_features=77, out_features=256, bias=True)\n", + " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (2): Tanh()\n", + " (3): Dropout(p=0.0, inplace=False)\n", + " (4): Linear(in_features=256, out_features=128, bias=True)\n", + " (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (6): Tanh()\n", + " (7): Dropout(p=0.0, inplace=False)\n", + " (8): Linear(in_features=128, out_features=64, bias=True)\n", + " (9): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (10): Tanh()\n", + " (11): Dropout(p=0.0, inplace=False)\n", + " (12): Linear(in_features=64, out_features=64, bias=True)\n", + " (13): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (val_ego_state_net): Sequential(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (val_ro_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (val_rg_net): Sequential(\n", + " (0): Linear(in_features=13, out_features=64, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): Dropout(p=0.0, inplace=False)\n", + " (6): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (7): Tanh()\n", + " )\n", + " (val_ss_net): Sequential(\n", + " (0): Linear(in_features=3, out_features=3, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): LayerNorm((3,), eps=1e-05, elementwise_affine=True)\n", + " (3): Tanh()\n", + " )\n", + " (val_out_net): Sequential(\n", + " (0): Linear(in_features=77, out_features=256, bias=True)\n", + " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (2): Tanh()\n", + " (3): Dropout(p=0.0, inplace=False)\n", + " (4): Linear(in_features=256, out_features=128, bias=True)\n", + " (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (6): Tanh()\n", + " (7): Dropout(p=0.0, inplace=False)\n", + " (8): Linear(in_features=128, out_features=64, bias=True)\n", + " (9): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " (10): Tanh()\n", + " (11): Dropout(p=0.0, inplace=False)\n", + " (12): Linear(in_features=64, out_features=64, bias=True)\n", + " (13): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "env_config = load_config_nb(\"env_config\")\n", + "exp_config = load_config_nb(\"exp_config\")\n", + "video_config = load_config_nb(\"video_config\")\n", + "\n", + "env_config.num_files = 10\n", + "\n", + "# Define model architecture\n", + "model_config = None\n", + "# model_config = Box(\n", + "# {\n", + "# \"arch_ego_state\": [8],\n", + "# \"arch_road_objects\": [64],\n", + "# \"arch_road_graph\": [128, 64],\n", + "# \"arch_shared_net\": [128],\n", + "# \"act_func\": \"tanh\",\n", + "# \"dropout\": 0.0,\n", + "# \"last_layer_dim_pi\": 64,\n", + "# \"last_layer_dim_vf\": 64,\n", + "# }\n", + "# )\n", + "\n", + "# Train\n", + "env, model = train(\n", + " env_config=env_config,\n", + " exp_config=exp_config,\n", + " video_config=video_config,\n", + " model_config=model_config,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([{37: array([0.28523394, 0.49977985, 0.10757028, ..., 0. , 0. ,\n", + " 0. ]), 0: array([0.2631481 , 0.50686294, 0.01567911, ..., 0.01 , 0.68215334,\n", + " 0.01584669]), 51: array([0.33037499, 0.583 , 0.02133806, ..., 0. , 0. ,\n", + " 0. ]), 31: array([0.29642355, 0.51919246, 0.05890008, ..., 0. , 0. ,\n", + " 0. ]), 34: array([0.29417717, 0.51871204, 0.11663519, ..., 0. , 0. ,\n", + " 0. ]), 2: array([0.29118758, 0.50871187, 0.02653475, ..., 0. , 0. ,\n", + " 0. ]), 1: array([0.28495339, 0.51357961, 0.02791023, ..., 0. , 0. ,\n", + " 0. ]), 32: array([0.29227659, 0.5093497 , 0.11704891, ..., 0. , 0. ,\n", + " 0. ]), 30: array([0.28525367, 0.50776368, 0.125852 , ..., 0. , 0. ,\n", + " 0. ]), 33: array([0.3665545 , 0.60222054, 0.10613482, ..., 0. , 0. ,\n", + " 0. ]), 41: array([0.27929321, 0.49081314, 0.11201638, ..., 0. , 0. ,\n", + " 0. ]), 4: array([0.29417551, 0.52149814, 0.11362467, ..., 0. , 0. ,\n", + " 0. ])} ,\n", + " {37: array([0.28523394, 0.49977985, 0.10757028, ..., 0. , 0. ,\n", + " 0. ]), 0: array([0.2631481 , 0.50686294, 0.01567911, ..., 0.01 , 0.68215334,\n", + " 0.01584669]), 51: array([0.33037499, 0.583 , 0.02133806, ..., 0. , 0. ,\n", + " 0. ]), 31: array([0.29642355, 0.51919246, 0.05890008, ..., 0. , 0. ,\n", + " 0. ]), 34: array([0.29417717, 0.51871204, 0.11663519, ..., 0. , 0. ,\n", + " 0. ]), 2: array([0.29118758, 0.50871187, 0.02653475, ..., 0. , 0. ,\n", + " 0. ]), 1: array([0.28495339, 0.51357961, 0.02791023, ..., 0. , 0. ,\n", + " 0. ]), 32: array([0.29227659, 0.5093497 , 0.11704891, ..., 0. , 0. ,\n", + " 0. ]), 30: array([0.28525367, 0.50776368, 0.125852 , ..., 0. , 0. ,\n", + " 0. ]), 33: array([0.3665545 , 0.60222054, 0.10613482, ..., 0. , 0. ,\n", + " 0. ]), 41: array([0.27929321, 0.49081314, 0.11201638, ..., 0. , 0. ,\n", + " 0. ]), 4: array([0.29417551, 0.52149814, 0.11362467, ..., 0. , 0. ,\n", + " 0. ])} ,\n", + " {37: array([0.28523394, 0.49977985, 0.10757028, ..., 0. , 0. ,\n", + " 0. ]), 0: array([0.2631481 , 0.50686294, 0.01567911, ..., 0.01 , 0.68215334,\n", + " 0.01584669]), 51: array([0.33037499, 0.583 , 0.02133806, ..., 0. , 0. ,\n", + " 0. ]), 31: array([0.29642355, 0.51919246, 0.05890008, ..., 0. , 0. ,\n", + " 0. ]), 34: array([0.29417717, 0.51871204, 0.11663519, ..., 0. , 0. ,\n", + " 0. ]), 2: array([0.29118758, 0.50871187, 0.02653475, ..., 0. , 0. ,\n", + " 0. ]), 1: array([0.28495339, 0.51357961, 0.02791023, ..., 0. , 0. ,\n", + " 0. ]), 32: array([0.29227659, 0.5093497 , 0.11704891, ..., 0. , 0. ,\n", + " 0. ]), 30: array([0.28525367, 0.50776368, 0.125852 , ..., 0. , 0. ,\n", + " 0. ]), 33: array([0.3665545 , 0.60222054, 0.10613482, ..., 0. , 0. ,\n", + " 0. ]), 41: array([0.27929321, 0.49081314, 0.11201638, ..., 0. , 0. ,\n", + " 0. ]), 4: array([0.29417551, 0.52149814, 0.11362467, ..., 0. , 0. ,\n", + " 0. ])} ,\n", + " {37: array([0.28523394, 0.49977985, 0.10757028, ..., 0. , 0. ,\n", + " 0. ]), 0: array([0.2631481 , 0.50686294, 0.01567911, ..., 0.01 , 0.68215334,\n", + " 0.01584669]), 51: array([0.33037499, 0.583 , 0.02133806, ..., 0. , 0. ,\n", + " 0. ]), 31: array([0.29642355, 0.51919246, 0.05890008, ..., 0. , 0. ,\n", + " 0. ]), 34: array([0.29417717, 0.51871204, 0.11663519, ..., 0. , 0. ,\n", + " 0. ]), 2: array([0.29118758, 0.50871187, 0.02653475, ..., 0. , 0. ,\n", + " 0. ]), 1: array([0.28495339, 0.51357961, 0.02791023, ..., 0. , 0. ,\n", + " 0. ]), 32: array([0.29227659, 0.5093497 , 0.11704891, ..., 0. , 0. ,\n", + " 0. ]), 30: array([0.28525367, 0.50776368, 0.125852 , ..., 0. , 0. ,\n", + " 0. ]), 33: array([0.3665545 , 0.60222054, 0.10613482, ..., 0. , 0. ,\n", + " 0. ]), 41: array([0.27929321, 0.49081314, 0.11201638, ..., 0. , 0. ,\n", + " 0. ]), 4: array([0.29417551, 0.52149814, 0.11362467, ..., 0. , 0. ,\n", + " 0. ])} ],\n", + " dtype=object)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/nocturne_lab/.venv/lib/python3.11/site-packages/stable_baselines3/ppo/ppo.py:315\u001b[0m, in \u001b[0;36mPPO.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28mself\u001b[39m: SelfPPO,\n\u001b[1;32m 308\u001b[0m total_timesteps: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 314\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m SelfPPO:\n\u001b[0;32m--> 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[43m \u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 317\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mtb_log_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtb_log_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/nocturne_lab/.venv/lib/python3.11/site-packages/stable_baselines3/common/on_policy_algorithm.py:277\u001b[0m, in \u001b[0;36mOnPolicyAlgorithm.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 276\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 277\u001b[0m continue_training \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect_rollouts\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrollout_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_rollout_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_steps\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m continue_training:\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/nocturne_lab/utils/sb3/ma_ppo.py:56\u001b[0m, in \u001b[0;36mMultiAgentPPO.collect_rollouts\u001b[0;34m(self, env, callback, rollout_buffer, n_rollout_steps)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpolicy\u001b[38;5;241m.\u001b[39mreset_noise(env\u001b[38;5;241m.\u001b[39mnum_envs)\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# Convert to pytorch tensor or to TensorDict\u001b[39;00m\n\u001b[0;32m---> 56\u001b[0m obs_tensor \u001b[38;5;241m=\u001b[39m \u001b[43mobs_as_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_last_obs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;66;03m# EDIT_1: Mask out invalid observations (NaN dimensions and/or dead agents)\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;66;03m# Create dummy actions, values and log_probs (NaN)\u001b[39;00m\n\u001b[1;32m 60\u001b[0m actions \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfull(fill_value\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mnan, size\u001b[38;5;241m=\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_envs,))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n", + "File \u001b[0;32m~/nocturne_lab/.venv/lib/python3.11/site-packages/stable_baselines3/common/utils.py:483\u001b[0m, in \u001b[0;36mobs_as_tensor\u001b[0;34m(obs, device)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;124;03mMoves the observation to the given device.\u001b[39;00m\n\u001b[1;32m 477\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[38;5;124;03m:return: PyTorch tensor of the observation on a desired device.\u001b[39;00m\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obs, np\u001b[38;5;241m.\u001b[39mndarray):\n\u001b[0;32m--> 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mth\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mas_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obs, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {key: th\u001b[38;5;241m.\u001b[39mas_tensor(_obs, device\u001b[38;5;241m=\u001b[39mdevice) \u001b[38;5;28;01mfor\u001b[39;00m (key, _obs) \u001b[38;5;129;01min\u001b[39;00m obs\u001b[38;5;241m.\u001b[39mitems()}\n", + "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + ] + } + ], + "source": [ + "model.learn(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "y = env.observation_space.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Discrete(25)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env.action_space" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "m = model.policy.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([8], device='cuda:0'),\n", + " tensor([[-0.5645]], device='cuda:0', grad_fn=),\n", + " tensor([-3.6230], device='cuda:0', grad_fn=))" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m(torch.Tensor(y).unsqueeze(0).cuda())" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from torchviz import make_dot" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[-0.6441, -1.3309, 0.5179, -0.4558, 0.6692, 0.3656, 1.1654, 1.2158,\n", + " 0.9431, -0.3549, -1.8127, -0.9146, 0.6600, 1.6707, -0.7357, 1.0185,\n", + " 1.4673, 0.1360, 1.5719, -0.4908, -0.5020, -0.1510, -2.7177, -0.9720,\n", + " -1.5061, 0.1968, 0.1585, 0.3488, -0.5726, 0.5933, -0.0064, -1.4018,\n", + " 0.3319, -2.0170, 0.8157, -1.6580, 0.5165, 0.6804, -1.1788, -0.5321,\n", + " -0.3591, 0.4718, 0.4771, 0.6652, 0.6524, 0.5661, -1.1508, 0.2073,\n", + " -0.4792, 1.4354, 1.1228, 1.2609, 0.6219, -1.6648, 0.5761, 0.1294,\n", + " 0.4942, -1.1621, 1.7667, -1.1287, 0.1603, 0.8658, 0.2625, -0.8793]],\n", + " device='cuda:0', grad_fn=),\n", + " tensor([[-0.3148, 1.7813, 3.3665, 0.3266, -0.4068, 0.3553, -1.4078, -1.1260,\n", + " 0.0797, 0.4915, 0.1167, -0.4714, -0.7219, 1.0971, -0.2553, -0.5615,\n", + " 0.3163, 1.6734, -0.1957, 0.0232, -0.9513, 0.4992, -0.1210, -0.9332,\n", + " 0.1590, -1.3791, 0.4740, -0.7048, -0.2949, -0.6597, 0.9464, 0.3830,\n", + " -2.3730, -0.4302, -0.7398, -0.3818, 0.6071, 1.6034, 1.2792, -1.6859,\n", + " 0.6433, 0.4296, -2.0910, 0.2231, 0.6841, 0.2559, -0.5256, -0.4875,\n", + " 2.0396, -0.4128, 0.8378, -0.6700, -0.4007, 0.2891, -1.4553, -1.0730,\n", + " -0.8699, 1.3477, 0.5756, 1.1768, -0.5317, 0.2053, 0.5588, -0.2122]],\n", + " device='cuda:0', grad_fn=))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.policy.mlp_extractor.cuda()(torch.Tensor(y).unsqueeze(0).cuda())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# make_dot(model.policy(torch.Tensor(y).unsqueeze(0).cuda()), params=dict(model.policy.named_parameters()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from torchview import draw_graph" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.2501, 0.3726, 0.1859, ..., -0.5912, 0.7823, -1.0088]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.Tensor(y).unsqueeze(0).cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "model\n", + "\n", + "\n", + "\n", + "0\n", + "\n", + "\n", + "input-tensor\n", + "depth:0\n", + "\n", + "(1, 6730)\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "0->1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6720) \n", + "\n", + "\n", + "\n", + "0->2\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "59\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "0->59\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "60\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6720) \n", + "\n", + "\n", + "\n", + "0->60\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "10\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "1->10\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 208) \n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 6500) \n", + "\n", + "\n", + "\n", + "2->5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 0) \n", + "\n", + "\n", + "\n", + "2->7\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 12) \n", + "\n", + "\n", + "\n", + "2->8\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 208) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 13) \n", + "\n", + "\n", + "\n", + "3->4\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "14\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "4->14\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6500) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 13) \n", + "\n", + "\n", + "\n", + "5->6\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "26\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "6->26\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "9\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 12) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "8->9\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "22\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "9->22\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "11\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "10->11\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "12\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "11->12\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "13\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "12->13\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "43\n", + "\n", + "\n", + "cat\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 10), 2 x (1, 32), (1, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 77) \n", + "\n", + "\n", + "\n", + "13->43\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "15\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "14->15\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "16\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "15->16\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "17\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "16->17\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "18\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "17->18\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "19\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "18->19\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "20\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "19->20\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "21\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "20->21\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "34\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 16) \n", + "\n", + "\n", + "\n", + "21->34\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "23\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "22->23\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "24\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "23->24\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "25\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "24->25\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "37\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 4) \n", + "\n", + "\n", + "\n", + "25->37\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "27\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "26->27\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "28\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "27->28\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "29\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "28->29\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "30\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "29->30\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "31\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "30->31\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "32\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "31->32\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "33\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "32->33\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "40\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 500) \n", + "\n", + "\n", + "\n", + "33->40\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "35\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 16) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "34->35\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "36\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "35->36\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "36->43\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "38\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 3, 4) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 1) \n", + "\n", + "\n", + "\n", + "37->38\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "39\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 3, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 3) \n", + "\n", + "\n", + "\n", + "38->39\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "39->43\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "41\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 500) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "40->41\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "42\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "41->42\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "42->43\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "44\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 77) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "43->44\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "45\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "44->45\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "46\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "45->46\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "47\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "46->47\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "48\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "47->48\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "49\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "48->49\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "50\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "49->50\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "51\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "50->51\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "52\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "51->52\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "53\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "52->53\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "54\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "53->54\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "55\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "54->55\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "56\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "55->56\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "57\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "56->57\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "58\n", + "\n", + "\n", + "output-tensor\n", + "depth:0\n", + "\n", + "(1, 64)\n", + "\n", + "\n", + "\n", + "57->58\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "68\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "59->68\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "61\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 208) \n", + "\n", + "\n", + "\n", + "60->61\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "63\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 6500) \n", + "\n", + "\n", + "\n", + "60->63\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "65\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 0) \n", + "\n", + "\n", + "\n", + "60->65\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "66\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 12) \n", + "\n", + "\n", + "\n", + "60->66\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "62\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 208) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 13) \n", + "\n", + "\n", + "\n", + "61->62\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "72\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "62->72\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "64\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6500) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 13) \n", + "\n", + "\n", + "\n", + "63->64\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "84\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "64->84\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "67\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 12) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "66->67\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "80\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "67->80\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "69\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "68->69\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "70\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "69->70\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "71\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "70->71\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "101\n", + "\n", + "\n", + "cat\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 10), 2 x (1, 32), (1, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 77) \n", + "\n", + "\n", + "\n", + "71->101\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "73\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "72->73\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "74\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "73->74\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "75\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "74->75\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "76\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "75->76\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "77\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "76->77\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "78\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "77->78\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "79\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "78->79\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "92\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 16) \n", + "\n", + "\n", + "\n", + "79->92\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "81\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "80->81\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "82\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "81->82\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "83\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "82->83\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "95\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 4) \n", + "\n", + "\n", + "\n", + "83->95\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "85\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "84->85\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "86\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "85->86\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "87\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "86->87\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "88\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "87->88\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "89\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "88->89\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "90\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "89->90\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "91\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "90->91\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "98\n", + "\n", + "\n", + "permute\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 500) \n", + "\n", + "\n", + "\n", + "91->98\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "93\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 16) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "92->93\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "94\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "93->94\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "94->101\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "96\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 3, 4) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 1) \n", + "\n", + "\n", + "\n", + "95->96\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "97\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 3, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 3) \n", + "\n", + "\n", + "\n", + "96->97\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "97->101\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "99\n", + "\n", + "\n", + "max_pool1d\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 500) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "98->99\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "100\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "99->100\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "100->101\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "102\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 77) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "101->102\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "103\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "102->103\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "104\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "103->104\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "105\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "104->105\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "106\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "105->106\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "107\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "106->107\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "108\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "107->108\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "109\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "108->109\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "110\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "109->110\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "111\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "110->111\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "112\n", + "\n", + "\n", + "Tanh\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "111->112\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "113\n", + "\n", + "\n", + "Dropout\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "112->113\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "114\n", + "\n", + "\n", + "Linear\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "113->114\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "115\n", + "\n", + "\n", + "LayerNorm\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "114->115\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "116\n", + "\n", + "\n", + "output-tensor\n", + "depth:0\n", + "\n", + "(1, 64)\n", + "\n", + "\n", + "\n", + "115->116\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_graph = draw_graph(model.policy.mlp_extractor, input_data=torch.Tensor(y).unsqueeze(0).cuda(), device='cuda')\n", + "model_graph.visual_graph" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "model_graph = draw_graph(model.policy, input_data=torch.Tensor(y).unsqueeze(0).cuda(), device='cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "model\n", + "\n", + "\n", + "\n", + "0\n", + "\n", + "\n", + "input-tensor\n", + "depth:0\n", + "\n", + "(1, 6730)\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "\n", + "float\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6730) \n", + "\n", + "\n", + "\n", + "0->1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "\n", + "Flatten\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6730) \n", + "\n", + "\n", + "\n", + "1->2\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6720) \n", + "\n", + "\n", + "\n", + "2->4\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "60\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "2->60\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "61\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6730) \n", + "\n", + "output: \n", + "\n", + "(1, 6720) \n", + "\n", + "\n", + "\n", + "2->61\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "12\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "3->12\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 208) \n", + "\n", + "\n", + "\n", + "4->5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 6500) \n", + "\n", + "\n", + "\n", + "4->7\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "9\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 0) \n", + "\n", + "\n", + "\n", + "4->9\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "10\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 12) \n", + "\n", + "\n", + "\n", + "4->10\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 208) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 13) \n", + "\n", + "\n", + "\n", + "5->6\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "16\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "6->16\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6500) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 13) \n", + "\n", + "\n", + "\n", + "7->8\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "28\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "8->28\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "11\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 12) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "10->11\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "24\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "11->24\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "13\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "12->13\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "14\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "13->14\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "15\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "14->15\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "45\n", + "\n", + "\n", + "cat\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10), 2 x (1, 32), (1, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 77) \n", + "\n", + "\n", + "\n", + "15->45\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "17\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "16->17\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "18\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "17->18\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "19\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "18->19\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "20\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "19->20\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "21\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "20->21\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "22\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "21->22\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "23\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "22->23\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "36\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 16) \n", + "\n", + "\n", + "\n", + "23->36\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "25\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "24->25\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "26\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "25->26\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "27\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "26->27\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "39\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 4) \n", + "\n", + "\n", + "\n", + "27->39\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "29\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "28->29\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "30\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "29->30\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "31\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "30->31\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "32\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "31->32\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "33\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "32->33\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "34\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "33->34\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "35\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "34->35\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "42\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 500) \n", + "\n", + "\n", + "\n", + "35->42\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "37\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 16) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "36->37\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "38\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "37->38\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "38->45\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "40\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 3, 4) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 1) \n", + "\n", + "\n", + "\n", + "39->40\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "41\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 3, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 3) \n", + "\n", + "\n", + "\n", + "40->41\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "41->45\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "43\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 500) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "42->43\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "44\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "43->44\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "44->45\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "46\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 77) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "45->46\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "47\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "46->47\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "48\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "47->48\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "49\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "48->49\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "50\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "49->50\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "51\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "50->51\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "52\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "51->52\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "53\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "52->53\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "54\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "53->54\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "55\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "54->55\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "56\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "55->56\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "57\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "56->57\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "58\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "57->58\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "59\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "58->59\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "119\n", + "\n", + "\n", + "Linear\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "59->119\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "69\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "60->69\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "62\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 208) \n", + "\n", + "\n", + "\n", + "61->62\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "64\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 6500) \n", + "\n", + "\n", + "\n", + "61->64\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "66\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 0) \n", + "\n", + "\n", + "\n", + "61->66\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "67\n", + "\n", + "\n", + "__getitem__\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6720) \n", + "\n", + "output: \n", + "\n", + "(1, 12) \n", + "\n", + "\n", + "\n", + "61->67\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "63\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 208) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 13) \n", + "\n", + "\n", + "\n", + "62->63\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "73\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "63->73\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "65\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 6500) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 13) \n", + "\n", + "\n", + "\n", + "64->65\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "85\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 13) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "65->85\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "68\n", + "\n", + "\n", + "reshape\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 12) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "67->68\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "81\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "68->81\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "70\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "69->70\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "71\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "70->71\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "72\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 10) \n", + "\n", + "output: \n", + "\n", + "(1, 10) \n", + "\n", + "\n", + "\n", + "71->72\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "102\n", + "\n", + "\n", + "cat\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 10), 2 x (1, 32), (1, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 77) \n", + "\n", + "\n", + "\n", + "72->102\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "74\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "73->74\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "75\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "74->75\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "76\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 64) \n", + "\n", + "\n", + "\n", + "75->76\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "77\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "76->77\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "78\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "77->78\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "79\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "78->79\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "80\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 16, 32) \n", + "\n", + "\n", + "\n", + "79->80\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "93\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 16, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 16) \n", + "\n", + "\n", + "\n", + "80->93\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "82\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "81->82\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "83\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "82->83\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "84\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 4, 3) \n", + "\n", + "\n", + "\n", + "83->84\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "96\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 4, 3) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 4) \n", + "\n", + "\n", + "\n", + "84->96\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "86\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "85->86\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "87\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "86->87\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "88\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 64) \n", + "\n", + "\n", + "\n", + "87->88\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "89\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "88->89\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "90\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "89->90\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "91\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "90->91\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "92\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 500, 32) \n", + "\n", + "\n", + "\n", + "91->92\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "99\n", + "\n", + "\n", + "permute\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 500, 32) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 500) \n", + "\n", + "\n", + "\n", + "92->99\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "94\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 16) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "93->94\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "95\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "94->95\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "95->102\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "97\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 3, 4) \n", + "\n", + "output: \n", + "\n", + "(1, 3, 1) \n", + "\n", + "\n", + "\n", + "96->97\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "98\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 3, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 3) \n", + "\n", + "\n", + "\n", + "97->98\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "98->102\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "100\n", + "\n", + "\n", + "max_pool1d\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 500) \n", + "\n", + "output: \n", + "\n", + "(1, 32, 1) \n", + "\n", + "\n", + "\n", + "99->100\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "101\n", + "\n", + "\n", + "squeeze\n", + "depth:2\n", + "\n", + "input:\n", + "\n", + "(1, 32, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 32) \n", + "\n", + "\n", + "\n", + "100->101\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "101->102\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "103\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 77) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "102->103\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "104\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "103->104\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "105\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "104->105\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "106\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 256) \n", + "\n", + "\n", + "\n", + "105->106\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "107\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 256) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "106->107\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "108\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "107->108\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "109\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "108->109\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "110\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 128) \n", + "\n", + "\n", + "\n", + "109->110\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "111\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 128) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "110->111\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "112\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "111->112\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "113\n", + "\n", + "\n", + "Tanh\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "112->113\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "114\n", + "\n", + "\n", + "Dropout\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "113->114\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "115\n", + "\n", + "\n", + "Linear\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "114->115\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "116\n", + "\n", + "\n", + "LayerNorm\n", + "depth:3\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 64) \n", + "\n", + "\n", + "\n", + "115->116\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "117\n", + "\n", + "\n", + "Linear\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 64) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "116->117\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "118\n", + "\n", + "\n", + "output-tensor\n", + "depth:0\n", + "\n", + "(1, 1)\n", + "\n", + "\n", + "\n", + "117->118\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "120\n", + "\n", + "\n", + "logsumexp\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "119->120\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "121\n", + "\n", + "\n", + "sub\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25), (1, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "119->121\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "120->121\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "122\n", + "\n", + "\n", + "eq\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "2 x (1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "121->122\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "126\n", + "\n", + "\n", + "softmax\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "121->126\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140\n", + "\n", + "\n", + "broadcast_tensors\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 1), (1, 25) \n", + "\n", + "output: \n", + "\n", + "2 x (1, 25) \n", + "\n", + "\n", + "\n", + "121->140\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "123\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "122->123\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "124\n", + "\n", + "\n", + "all\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "123->124\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "125\n", + "\n", + "\n", + "all\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "() \n", + "\n", + "\n", + "\n", + "124->125\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "127\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 25) \n", + "\n", + "\n", + "\n", + "126->127\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "128\n", + "\n", + "\n", + "multinomial\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "127->128\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "129\n", + "\n", + "\n", + "__get__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "130\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 1) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "129->130\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "131\n", + "\n", + "\n", + "remainder\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "130->131\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "133\n", + "\n", + "\n", + "ge\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "130->133\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "135\n", + "\n", + "\n", + "le\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "130->135\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "138\n", + "\n", + "\n", + "long\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "130->138\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132\n", + "\n", + "\n", + "eq\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "131->132\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "134\n", + "\n", + "\n", + "__and__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "2 x (1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "132->134\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "133->134\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "136\n", + "\n", + "\n", + "__and__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "2 x (1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "134->136\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "135->136\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "137\n", + "\n", + "\n", + "all\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "() \n", + "\n", + "\n", + "\n", + "136->137\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "139\n", + "\n", + "\n", + "unsqueeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "138->139\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "145\n", + "\n", + "\n", + "reshape\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1,) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "138->145\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "139->140\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "141\n", + "\n", + "\n", + "__getitem__\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "140->141\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "142\n", + "\n", + "\n", + "gather\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 25), (1, 1) \n", + "\n", + "output: \n", + "\n", + "(1, 1) \n", + "\n", + "\n", + "\n", + "140->142\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "141->142\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "143\n", + "\n", + "\n", + "squeeze\n", + "depth:1\n", + "\n", + "input:\n", + "\n", + "(1, 1) \n", + "\n", + "output: \n", + "\n", + "(1,) \n", + "\n", + "\n", + "\n", + "142->143\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "144\n", + "\n", + "\n", + "output-tensor\n", + "depth:0\n", + "\n", + "(1,)\n", + "\n", + "\n", + "\n", + "143->144\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "146\n", + "\n", + "\n", + "output-tensor\n", + "depth:0\n", + "\n", + "(1,)\n", + "\n", + "\n", + "\n", + "145->146\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_graph.visual_graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/temp.ipynb b/examples/temp.ipynb new file mode 100644 index 00000000..44a1cf02 --- /dev/null +++ b/examples/temp.ipynb @@ -0,0 +1,384 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([15, 10])" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from gymnasium.spaces import MultiDiscrete\n", + "import numpy as np\n", + "action_space = MultiDiscrete(np.array([20,25]), seed=42)\n", + "action_space.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([17, 17])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "action_space.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from nocturne.envs.nocturne_gymnasium import NocturneGymnasium\n", + "import yaml\n", + "from nocturne.envs.base_env import BaseEnv\n", + "from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv\n", + "\n", + "# Load environment settings\n", + "with open(f\"../configs/env_config.yaml\", \"r\") as stream:\n", + " env_config = yaml.safe_load(stream)\n", + "\n", + "# Initialize environment\n", + "env = BaseEnv(config=env_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "gymnasiumEnv = NocturneGymnasium(env)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MultiDiscrete([20 25])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gymnasiumEnv.action_space" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{6: array([0.30362597, 0.54050583, 0.16309013, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 23: array([0.33037499, 0.583 , 0.15996636, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 2: array([0.29274434, 0.50419724, 0.15200901, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 8: array([0.30145869, 0.53301573, 0.16035738, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 9: array([0.27850392, 0.50491506, 0.16618462, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 1: array([0.28742164, 0.52419186, 0.17056067, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 14: array([0.27822894, 0.51315653, 0.1849147 , ..., 0. , 0. ,\n", + " 0. ]),\n", + " 18: array([0.28400436, 0.51011646, 0.13745898, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 4: array([0.31240156, 0.52112043, 0.16890202, ..., 0. , 0. ,\n", + " 0. ]),\n", + " 0: array([0.27941427, 0.51273805, 0.16114239, ..., 0. , 0. ,\n", + " 0. ])}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gymnasiumEnv.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Reset\n", + "obs_dict = gymnasiumEnv.reset()\n", + "\n", + "# Get info\n", + "agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + "dead_agent_ids = []\n", + "num_agents = len(agent_ids)\n", + "rewards = {agent_id: 0 for agent_id in agent_ids}\n", + "\n", + "for step in range(1000):\n", + "\n", + " # Sample actions\n", + " action_dict = {\n", + " agent_id: env.action_space.sample() \n", + " for agent_id in agent_ids\n", + " if agent_id not in dead_agent_ids\n", + " }\n", + " # Step in env\n", + " obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)\n", + "\n", + " for agent_id in action_dict.keys():\n", + " rewards[agent_id] += rew_dict[agent_id]\n", + "\n", + " # Update dead agents\n", + " for agent_id, is_done in done_dict.items():\n", + " if is_done and agent_id not in dead_agent_ids:\n", + " dead_agent_ids.append(agent_id)\n", + "\n", + " # Reset if all agents are done\n", + " if done_dict[\"__all__\"]:\n", + " print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", + " obs_dict = gymnasiumEnv.reset()\n", + " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + " dead_agent_ids = []\n", + " rewards = {agent_id: 0 for agent_id in agent_ids}\n", + "\n", + "# Close environment\n", + "env.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.vec_env import SubprocVecEnv" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def make_env(env_config):\n", + " return NocturneGymnasium(BaseEnv(config=env_config)) " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "envs = SubprocVecEnv([lambda: make_env(env_config) for _ in range(4)])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Reset\n", + "obs_dicts = envs.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "agent_ids_batch = []\n", + "dead_agent_ids_batch = []\n", + "num_agents_batch = []\n", + "rewards_batch = []\n", + "for obs_dict in obs_dicts:\n", + " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + " dead_agent_ids = []\n", + " num_agents = len(agent_ids)\n", + " rewards = {agent_id: 0 for agent_id in agent_ids}\n", + " agent_ids_batch.append(agent_ids)\n", + " dead_agent_ids_batch.append(dead_agent_ids)\n", + " num_agents_batch.append(num_agents)\n", + " rewards_batch.append(rewards)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{6: 6, 23: 18, 2: 24, 8: 1, 9: 2, 1: 18, 14: 1, 18: 13, 4: 11, 0: 18},\n", + " {6: 0, 23: 14, 2: 8, 8: 22, 9: 1, 1: 4, 14: 18, 18: 18, 4: 1, 0: 23},\n", + " {6: 9, 23: 22, 2: 1, 8: 2, 9: 0, 1: 20, 14: 5, 18: 14, 4: 7, 0: 13},\n", + " {6: 13, 23: 1, 2: 8, 8: 13, 9: 9, 1: 16, 14: 9, 18: 9, 4: 12, 0: 23}]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "action_dicts = [\n", + " {\n", + " agent_id: env.action_space.sample() \n", + " for agent_id in agent_ids\n", + " if agent_id not in dead_agent_ids\n", + " }\n", + " for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)\n", + " ]\n", + "action_dicts" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "7", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 17\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m rew_dict, rewards \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(rew_dicts, rewards_batch):\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m agent_id \u001b[38;5;129;01min\u001b[39;00m rew_dict\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m---> 17\u001b[0m \u001b[43mrewards\u001b[49m\u001b[43m[\u001b[49m\u001b[43magent_id\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m rew_dict[agent_id] \n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# Update dead agents\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m done_dict, dead_agent_ids \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(done_dicts, dead_agent_ids_batch):\n", + "\u001b[0;31mKeyError\u001b[0m: 7" + ] + } + ], + "source": [ + "for step in range(1000):\n", + "\n", + " # Sample actions\n", + " action_dicts = [\n", + " {\n", + " agent_id: env.action_space.sample() \n", + " for agent_id in agent_ids\n", + " if agent_id not in dead_agent_ids\n", + " }\n", + " for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)\n", + " ]\n", + " # Step in env\n", + " obs_dicts, rew_dicts, done_dicts, info_dicts = envs.step(action_dicts)\n", + "\n", + " for rew_dict, rewards in zip(rew_dicts, rewards_batch):\n", + " for agent_id in rew_dict.keys():\n", + " rewards[agent_id] += rew_dict[agent_id] \n", + " \n", + " # Update dead agents\n", + " for done_dict, dead_agent_ids in zip(done_dicts, dead_agent_ids_batch):\n", + " for agent_id, is_done in done_dict.items():\n", + " if is_done and agent_id not in dead_agent_ids:\n", + " dead_agent_ids.append(agent_id)\n", + "\n", + " # Reset if all agents are done\n", + " if all([done_dict[\"__all__\"] for done_dict in done_dicts]):\n", + " print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", + " obs_dicts = envs.reset()\n", + " agent_ids_batch = []\n", + " dead_agent_ids_batch = []\n", + " num_agents_batch = []\n", + " rewards_batch = []\n", + " for obs_dict in obs_dicts:\n", + " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + " dead_agent_ids = []\n", + " num_agents = len(agent_ids)\n", + " rewards = {agent_id: 0 for agent_id in agent_ids}\n", + " agent_ids_batch.append(agent_ids)\n", + " dead_agent_ids_batch.append(dead_agent_ids)\n", + " num_agents_batch.append(num_agents)\n", + " rewards_batch.append(rewards)\n", + "\n", + " # # Sample actions\n", + " # action_dict = {\n", + " # agent_id: env.action_space.sample() \n", + " # for agent_id in agent_ids\n", + " # if agent_id not in dead_agent_ids\n", + " # }\n", + " # # Step in env\n", + " # obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)\n", + "\n", + " # for agent_id in action_dict.keys():\n", + " # rewards[agent_id] += rew_dict[agent_id]\n", + "\n", + " # # Update dead agents\n", + " # for agent_id, is_done in done_dict.items():\n", + " # if is_done and agent_id not in dead_agent_ids:\n", + " # dead_agent_ids.append(agent_id)\n", + "\n", + " # # Reset if all agents are done\n", + " # if done_dict[\"__all__\"]:\n", + " # print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", + " # obs_dict = gymnasiumEnv.reset()\n", + " # agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", + " # dead_agent_ids = []\n", + " # rewards = {agent_id: 0 for agent_id in agent_ids}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nocturne/envs/nocturne_gymnasium.py b/nocturne/envs/nocturne_gymnasium.py new file mode 100644 index 00000000..099031b2 --- /dev/null +++ b/nocturne/envs/nocturne_gymnasium.py @@ -0,0 +1,136 @@ +"""Gymnasium vectorizable environment wrapper for Nocturne.""" +import logging +import time +from copy import deepcopy +from typing import Any, Dict, List, TypeVar + +import gym +import gymnasium +import numpy as np + +from nocturne.envs.base_env import BaseEnv +from utils.config import load_config + +logging.basicConfig(level=logging.INFO) + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + +class NocturneGymnasium(gymnasium.Env): + """Nocturne environment wrapper for compatible with SB3. + """ + + def __init__(self, config, num_agents, psr=False): + self.env = BaseEnv(config) + + # Make action and observation spaces compatible with SB3 (requires gymnasium) + self.action_space = gymnasium.spaces.MultiDiscrete([self.env.config.max_num_vehicles, self.env.action_space.n]) + self.observation_space = gym.spaces.Box(-np.inf, np.inf, self.env.observation_space.shape, np.float32) + self.num_agents = num_agents # The maximum number of agents allowed in the environmen + self.psr = psr # Whether to use PSR or not + + self.psr_dict = self.init_scene_dict() if psr else None # Initialize dict to keep track of the average reward obtained in each scene + self.n_episodes = 0 + self.episode_lengths = [] + self.rewards = [] # Log reward per step + self.dead_agent_ids = [] # Log dead agents per step + self.num_agents_collided = 0 # Keep track of how many agents collided + self.total_agents_in_rollout = 0 # Log total number of agents in rollout + self.num_agents_goal_achieved = 0 # Keep track of how many agents reached their goal + self.agents_in_scene = [] + self.filename = None # If provided, always use the same file + + def step(self, actions): + """Take a step in the environment, convert dicts to np arrays. + + Args + ---- + action (Dict): Dictionary with a single action for the controlled vehicle. + + Returns + ------- + observation, reward, terminated, truncated, info (np.ndarray, float, bool, bool, dict) + """ + next_obs_dict, rewards_dict, dones_dict, info_dict = self.env.step( + action_dict=actions + ) + + return ( + next_obs_dict, + rewards_dict, + dones_dict, + False, + info_dict, + ) + + def reset(self, seed=None): + """Reset environment and return initial observations.""" + obs_dict = self.env.reset() + + # Reset Nocturne env + obs_dict = self.env.reset(self.filename, self.psr_dict) + + # Reset storage + self.agent_ids = [] + self.rewards = [] + self.dead_agent_ids = [] + self.ep_collisions = 0 + self.ep_goal_achived = 0 + + obs_all = np.full(fill_value=np.nan, shape=(self.num_envs, self.env.observation_space.shape[0])) + for idx, agent_id in enumerate(obs_dict.keys()): + self.agent_ids.append(agent_id) + obs_all[idx, :] = obs_dict[agent_id] + + # Save obs in buffer + self._save_obs(obs_all) + + logging.debug(f"RESET - agent ids: {self.agent_ids}") + + # Make dict for storing the last info set for each agent + self.last_info_dicts = {agent_id: {} for agent_id in self.agent_ids} + + return self._obs_from_buf(), {} + + def _obs_from_buf(self) -> ObsType: + """Get observation from buffer.""" + return np.copy(self.buf_obs) + + @property + def action_space(self): + return self.env.action_space + + @action_space.setter + def action_space(self, action_space): + self.env.action_space = action_space + + @property + def observation_space(self): + return self.env.observation_space + + @observation_space.setter + def observation_space(self, observation_space): + self.env.observation_space = observation_space + + def render(self): + pass + + def close(self): + pass + + @property + def seed(self, seed=None): + return None + + @seed.setter + def seed(self, seed=None): + pass + + def __getattr__(self, name): + return getattr(self._env, name) + + def get_attr(self, attr_name: str): + return getattr(self._env, attr_name) + + def set_attr(self, attr_name: str): + setattr(self._env, attr_name) From 33b65ea795e779215194c7cf7666e5a2489b117a Mon Sep 17 00:00:00 2001 From: Aarav Pandya Date: Sun, 14 Jan 2024 19:54:00 -0500 Subject: [PATCH 2/2] Pufferlib vectorization support --- examples/07_nocturne_pufferlib.ipynb | 169 ++++++++++++ examples/temp.ipynb | 384 --------------------------- nocturne/envs/nocturne_gymnasium.py | 225 +++++++++++++--- 3 files changed, 352 insertions(+), 426 deletions(-) create mode 100644 examples/07_nocturne_pufferlib.ipynb delete mode 100644 examples/temp.ipynb diff --git a/examples/07_nocturne_pufferlib.ipynb b/examples/07_nocturne_pufferlib.ipynb new file mode 100644 index 00000000..ac23d938 --- /dev/null +++ b/examples/07_nocturne_pufferlib.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from nocturne.envs.nocturne_gymnasium import NocturneGymnasium, CustomPostprocessor\n", + "import yaml\n", + "from nocturne.envs.base_env import BaseEnv\n", + "from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv\n", + "\n", + "import pufferlib.vectorization\n", + "vec = pufferlib.vectorization.Multiprocessing\n", + "# vec = pufferlib.vectorization.Serial\n", + "\n", + "import pufferlib.emulation\n", + "import pufferlib.wrappers\n", + "\n", + "from time import perf_counter\n", + "\n", + "# Load environment settings\n", + "with open(f\"../configs/env_config.yaml\", \"r\") as stream:\n", + " env_config = yaml.safe_load(stream)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def make_env(env_config):\n", + " return NocturneGymnasium(config=env_config, num_agents=env_config[\"max_num_vehicles\"]) \n", + "\n", + "def nocturne_creator(env_config):\n", + " return pufferlib.emulation.GymnasiumPufferEnv(env_creator=make_env, env_args=(env_config,), postprocessor_cls=CustomPostprocessor)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# envs = vec(nocturne_creator,env_args=[env_config,], num_envs=4, envs_per_worker=2, env_pool=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# envs.async_reset()\n", + "# obs = envs.recv()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# actions = [envs.single_action_space.sample() for _ in range(4)]\n", + "# envs.step(actions)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time for single env step: 0.002120851782616228\n", + "Average FPS for single env step: 471.50866844944056\n" + ] + } + ], + "source": [ + "NUM_STEPS = 1000\n", + "\n", + "env = make_env(env_config)\n", + "env.reset()\n", + "\n", + "total_time = 0\n", + "\n", + "for i in range(NUM_STEPS):\n", + " actions = env.action_space.sample()\n", + " start = perf_counter()\n", + " env.step(actions)\n", + " end = perf_counter()\n", + " total_time += end - start\n", + "\n", + "print(f\"Average time for single env step: {total_time/NUM_STEPS}\")\n", + "print(f\"Average FPS for single env step: {NUM_STEPS/(total_time)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time for 1000 env step: 0.0010244677745140506\n", + "Average FPS for 1000 env step: 976.1165991525143\n" + ] + } + ], + "source": [ + "NUM_STEPS = 1000\n", + "NUM_PROCESSES = 32\n", + "NUM_ENVS = 128\n", + "assert(NUM_ENVS % NUM_PROCESSES == 0)\n", + "NUM_ENVS_PER_WORKER = NUM_ENVS // NUM_PROCESSES\n", + "\n", + "envs = vec(nocturne_creator,env_args=[env_config,], num_envs=NUM_ENVS, envs_per_worker=1, env_pool=False)\n", + "envs.async_reset()\n", + "obs = envs.recv()[0]\n", + "\n", + "total_time = 0\n", + "for i in range(NUM_STEPS):\n", + " actions = [envs.single_action_space.sample() for _ in range(NUM_ENVS)]\n", + " start = perf_counter()\n", + " envs.step(actions)\n", + " end = perf_counter()\n", + " total_time += end - start\n", + "envs.close()\n", + "print(f\"Average time for {NUM_STEPS} env step: {total_time/(NUM_STEPS*NUM_ENVS)}\")\n", + "print(f\"Average FPS for {NUM_STEPS} env step: {(NUM_STEPS/(total_time))*NUM_ENVS}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/temp.ipynb b/examples/temp.ipynb deleted file mode 100644 index 44a1cf02..00000000 --- a/examples/temp.ipynb +++ /dev/null @@ -1,384 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([15, 10])" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from gymnasium.spaces import MultiDiscrete\n", - "import numpy as np\n", - "action_space = MultiDiscrete(np.array([20,25]), seed=42)\n", - "action_space.sample()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([17, 17])" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "action_space.sample()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from nocturne.envs.nocturne_gymnasium import NocturneGymnasium\n", - "import yaml\n", - "from nocturne.envs.base_env import BaseEnv\n", - "from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv\n", - "\n", - "# Load environment settings\n", - "with open(f\"../configs/env_config.yaml\", \"r\") as stream:\n", - " env_config = yaml.safe_load(stream)\n", - "\n", - "# Initialize environment\n", - "env = BaseEnv(config=env_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "gymnasiumEnv = NocturneGymnasium(env)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "MultiDiscrete([20 25])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gymnasiumEnv.action_space" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{6: array([0.30362597, 0.54050583, 0.16309013, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 23: array([0.33037499, 0.583 , 0.15996636, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 2: array([0.29274434, 0.50419724, 0.15200901, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 8: array([0.30145869, 0.53301573, 0.16035738, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 9: array([0.27850392, 0.50491506, 0.16618462, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 1: array([0.28742164, 0.52419186, 0.17056067, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 14: array([0.27822894, 0.51315653, 0.1849147 , ..., 0. , 0. ,\n", - " 0. ]),\n", - " 18: array([0.28400436, 0.51011646, 0.13745898, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 4: array([0.31240156, 0.52112043, 0.16890202, ..., 0. , 0. ,\n", - " 0. ]),\n", - " 0: array([0.27941427, 0.51273805, 0.16114239, ..., 0. , 0. ,\n", - " 0. ])}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gymnasiumEnv.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Reset\n", - "obs_dict = gymnasiumEnv.reset()\n", - "\n", - "# Get info\n", - "agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", - "dead_agent_ids = []\n", - "num_agents = len(agent_ids)\n", - "rewards = {agent_id: 0 for agent_id in agent_ids}\n", - "\n", - "for step in range(1000):\n", - "\n", - " # Sample actions\n", - " action_dict = {\n", - " agent_id: env.action_space.sample() \n", - " for agent_id in agent_ids\n", - " if agent_id not in dead_agent_ids\n", - " }\n", - " # Step in env\n", - " obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)\n", - "\n", - " for agent_id in action_dict.keys():\n", - " rewards[agent_id] += rew_dict[agent_id]\n", - "\n", - " # Update dead agents\n", - " for agent_id, is_done in done_dict.items():\n", - " if is_done and agent_id not in dead_agent_ids:\n", - " dead_agent_ids.append(agent_id)\n", - "\n", - " # Reset if all agents are done\n", - " if done_dict[\"__all__\"]:\n", - " print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", - " obs_dict = gymnasiumEnv.reset()\n", - " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", - " dead_agent_ids = []\n", - " rewards = {agent_id: 0 for agent_id in agent_ids}\n", - "\n", - "# Close environment\n", - "env.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from stable_baselines3.common.vec_env import SubprocVecEnv" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def make_env(env_config):\n", - " return NocturneGymnasium(BaseEnv(config=env_config)) " - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "envs = SubprocVecEnv([lambda: make_env(env_config) for _ in range(4)])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# Reset\n", - "obs_dicts = envs.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "agent_ids_batch = []\n", - "dead_agent_ids_batch = []\n", - "num_agents_batch = []\n", - "rewards_batch = []\n", - "for obs_dict in obs_dicts:\n", - " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", - " dead_agent_ids = []\n", - " num_agents = len(agent_ids)\n", - " rewards = {agent_id: 0 for agent_id in agent_ids}\n", - " agent_ids_batch.append(agent_ids)\n", - " dead_agent_ids_batch.append(dead_agent_ids)\n", - " num_agents_batch.append(num_agents)\n", - " rewards_batch.append(rewards)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{6: 6, 23: 18, 2: 24, 8: 1, 9: 2, 1: 18, 14: 1, 18: 13, 4: 11, 0: 18},\n", - " {6: 0, 23: 14, 2: 8, 8: 22, 9: 1, 1: 4, 14: 18, 18: 18, 4: 1, 0: 23},\n", - " {6: 9, 23: 22, 2: 1, 8: 2, 9: 0, 1: 20, 14: 5, 18: 14, 4: 7, 0: 13},\n", - " {6: 13, 23: 1, 2: 8, 8: 13, 9: 9, 1: 16, 14: 9, 18: 9, 4: 12, 0: 23}]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "action_dicts = [\n", - " {\n", - " agent_id: env.action_space.sample() \n", - " for agent_id in agent_ids\n", - " if agent_id not in dead_agent_ids\n", - " }\n", - " for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)\n", - " ]\n", - "action_dicts" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "ename": "KeyError", - "evalue": "7", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[9], line 17\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m rew_dict, rewards \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(rew_dicts, rewards_batch):\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m agent_id \u001b[38;5;129;01min\u001b[39;00m rew_dict\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m---> 17\u001b[0m \u001b[43mrewards\u001b[49m\u001b[43m[\u001b[49m\u001b[43magent_id\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m rew_dict[agent_id] \n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# Update dead agents\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m done_dict, dead_agent_ids \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(done_dicts, dead_agent_ids_batch):\n", - "\u001b[0;31mKeyError\u001b[0m: 7" - ] - } - ], - "source": [ - "for step in range(1000):\n", - "\n", - " # Sample actions\n", - " action_dicts = [\n", - " {\n", - " agent_id: env.action_space.sample() \n", - " for agent_id in agent_ids\n", - " if agent_id not in dead_agent_ids\n", - " }\n", - " for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)\n", - " ]\n", - " # Step in env\n", - " obs_dicts, rew_dicts, done_dicts, info_dicts = envs.step(action_dicts)\n", - "\n", - " for rew_dict, rewards in zip(rew_dicts, rewards_batch):\n", - " for agent_id in rew_dict.keys():\n", - " rewards[agent_id] += rew_dict[agent_id] \n", - " \n", - " # Update dead agents\n", - " for done_dict, dead_agent_ids in zip(done_dicts, dead_agent_ids_batch):\n", - " for agent_id, is_done in done_dict.items():\n", - " if is_done and agent_id not in dead_agent_ids:\n", - " dead_agent_ids.append(agent_id)\n", - "\n", - " # Reset if all agents are done\n", - " if all([done_dict[\"__all__\"] for done_dict in done_dicts]):\n", - " print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", - " obs_dicts = envs.reset()\n", - " agent_ids_batch = []\n", - " dead_agent_ids_batch = []\n", - " num_agents_batch = []\n", - " rewards_batch = []\n", - " for obs_dict in obs_dicts:\n", - " agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", - " dead_agent_ids = []\n", - " num_agents = len(agent_ids)\n", - " rewards = {agent_id: 0 for agent_id in agent_ids}\n", - " agent_ids_batch.append(agent_ids)\n", - " dead_agent_ids_batch.append(dead_agent_ids)\n", - " num_agents_batch.append(num_agents)\n", - " rewards_batch.append(rewards)\n", - "\n", - " # # Sample actions\n", - " # action_dict = {\n", - " # agent_id: env.action_space.sample() \n", - " # for agent_id in agent_ids\n", - " # if agent_id not in dead_agent_ids\n", - " # }\n", - " # # Step in env\n", - " # obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)\n", - "\n", - " # for agent_id in action_dict.keys():\n", - " # rewards[agent_id] += rew_dict[agent_id]\n", - "\n", - " # # Update dead agents\n", - " # for agent_id, is_done in done_dict.items():\n", - " # if is_done and agent_id not in dead_agent_ids:\n", - " # dead_agent_ids.append(agent_id)\n", - "\n", - " # # Reset if all agents are done\n", - " # if done_dict[\"__all__\"]:\n", - " # print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')\n", - " # obs_dict = gymnasiumEnv.reset()\n", - " # agent_ids = [agent_id for agent_id in obs_dict.keys()]\n", - " # dead_agent_ids = []\n", - " # rewards = {agent_id: 0 for agent_id in agent_ids}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/nocturne/envs/nocturne_gymnasium.py b/nocturne/envs/nocturne_gymnasium.py index 099031b2..9dbe76d2 100644 --- a/nocturne/envs/nocturne_gymnasium.py +++ b/nocturne/envs/nocturne_gymnasium.py @@ -2,12 +2,14 @@ import logging import time from copy import deepcopy -from typing import Any, Dict, List, TypeVar +from typing import Any, Dict, List, TypeVar, SupportsFloat -import gym import gymnasium import numpy as np +from pufferlib.emulation import Postprocessor + + from nocturne.envs.base_env import BaseEnv from utils.config import load_config @@ -24,11 +26,13 @@ def __init__(self, config, num_agents, psr=False): self.env = BaseEnv(config) # Make action and observation spaces compatible with SB3 (requires gymnasium) - self.action_space = gymnasium.spaces.MultiDiscrete([self.env.config.max_num_vehicles, self.env.action_space.n]) - self.observation_space = gym.spaces.Box(-np.inf, np.inf, self.env.observation_space.shape, np.float32) + # self.action_space = gymnasium.spaces.MultiDiscrete([self.env.config.max_num_vehicles, self.env.action_space.n]) self.num_agents = num_agents # The maximum number of agents allowed in the environmen + self.action_space = gymnasium.spaces.MultiDiscrete([self.env.action_space.n] * self.num_agents) + self.observation_space = gymnasium.spaces.Box(-np.inf, np.inf, [self.num_agents, self.env.observation_space.shape[0]], np.float32) self.psr = psr # Whether to use PSR or not + self.buf_obs = None # type: ObsType self.psr_dict = self.init_scene_dict() if psr else None # Initialize dict to keep track of the average reward obtained in each scene self.n_episodes = 0 self.episode_lengths = [] @@ -40,33 +44,106 @@ def __init__(self, config, num_agents, psr=False): self.agents_in_scene = [] self.filename = None # If provided, always use the same file - def step(self, actions): - """Take a step in the environment, convert dicts to np arrays. + def step(self, actions) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Convert action vector to dict and call env.step().""" - Args - ---- - action (Dict): Dictionary with a single action for the controlled vehicle. + agent_actions = { + agent_id: actions[idx] for idx, agent_id in enumerate(self.agent_ids) if agent_id not in self.dead_agent_ids + } - Returns - ------- - observation, reward, terminated, truncated, info (np.ndarray, float, bool, bool, dict) - """ - next_obs_dict, rewards_dict, dones_dict, info_dict = self.env.step( - action_dict=actions - ) + # Take a step to obtain dicts + next_obses_dict, rew_dict, done_dict, info_dict = self.env.step(agent_actions) + + # Update dead agents based on most recent done_dict + for agent_id, is_done in done_dict.items(): + if is_done and agent_id not in self.dead_agent_ids: + self.dead_agent_ids.append(agent_id) + # Store agents' last info dict + self.last_info_dicts[agent_id] = info_dict[agent_id].copy() + + # Storage + obs = np.full(fill_value=np.nan, shape=self.observation_space.shape) + self.buf_dones = np.full(fill_value=np.nan, shape=(self.num_agents,)) + self.buf_rews = np.full_like(self.buf_dones, fill_value=np.nan) + self.buf_infos = [{} for _ in range(self.num_agents)] + + # Override NaN placeholder for each agent that is alive + for idx, key in enumerate(self.agent_ids): + if key in next_obses_dict: + self.buf_rews[idx] = rew_dict[key] + self.buf_dones[idx] = done_dict[key] * 1 + self.buf_infos[idx] = info_dict[key] + obs[idx, :] = next_obses_dict[key] + + # Save step reward obtained across all agents + self.rewards.append(sum(rew_dict.values())) + self.agents_in_scene.append(len(self.agent_ids)) + + # Store observation + self._save_obs(obs) + + # Reset episode if ALL agents are done + if done_dict["__all__"]: + for agent_id in self.agent_ids: + self.ep_collisions += self.last_info_dicts[agent_id]["collided"] * 1 + self.ep_goal_achived += self.last_info_dicts[agent_id]["goal_achieved"] * 1 + + # Store the fraction of agents that collided in episode + self.num_agents_collided += self.ep_collisions + self.num_agents_goal_achieved += self.ep_goal_achived + self.total_agents_in_rollout += len(self.agent_ids) + + # Save final observation where user can get it, then reset + for idx in range(len(self.agent_ids)): + self.buf_infos[idx]["terminal_observation"] = obs[idx] + + # Log episode stats + ep_len = self.step_num + self.n_episodes += 1 + self.episode_lengths.append(ep_len) + + # Store reward at scene level + if self.psr: + self.psr_dict[self.env.file]["count"] += 1 + self.psr_dict[self.env.file]["reward"] += (sum(rew_dict.values())) / len(self.agent_ids) + self.psr_dict[self.env.file]["goal_rate"] += self.ep_goal_achived / len(self.agent_ids) + + # Reset + obs = self.reset() return ( - next_obs_dict, - rewards_dict, - dones_dict, + self._obs_from_buf(), + np.copy(self.buf_rews), + self.buf_dones.all(), False, - info_dict, + {'infos': deepcopy(self.buf_infos)}, ) + # def step(self, actions) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + # """Take a step in the environment, convert dicts to np arrays. + + # Args + # ---- + # action (Dict): Dictionary with a single action for the controlled vehicle. + + # Returns + # ------- + # observation, reward, terminated, truncated, info (np.ndarray, float, bool, bool, dict) + # """ + # next_obs_dict, rewards_dict, dones_dict, info_dict = self.env.step( + # action_dict=actions + # ) + + # return ( + # next_obs_dict, + # rewards_dict, + # dones_dict, + # False, + # info_dict, + # ) + def reset(self, seed=None): """Reset environment and return initial observations.""" - obs_dict = self.env.reset() - # Reset Nocturne env obs_dict = self.env.reset(self.filename, self.psr_dict) @@ -77,7 +154,7 @@ def reset(self, seed=None): self.ep_collisions = 0 self.ep_goal_achived = 0 - obs_all = np.full(fill_value=np.nan, shape=(self.num_envs, self.env.observation_space.shape[0])) + obs_all = np.full(fill_value=-np.pi*1e7, shape=self.observation_space.shape, dtype=self.observation_space.dtype) for idx, agent_id in enumerate(obs_dict.keys()): self.agent_ids.append(agent_id) obs_all[idx, :] = obs_dict[agent_id] @@ -96,21 +173,14 @@ def _obs_from_buf(self) -> ObsType: """Get observation from buffer.""" return np.copy(self.buf_obs) - @property - def action_space(self): - return self.env.action_space - - @action_space.setter - def action_space(self, action_space): - self.env.action_space = action_space + def _save_obs(self, obs: ObsType) -> None: + """Save observations into buffer.""" + self.buf_obs = obs @property - def observation_space(self): - return self.env.observation_space - - @observation_space.setter - def observation_space(self, observation_space): - self.env.observation_space = observation_space + def step_num(self) -> List[int]: + """The episodic timestep.""" + return self.env.step_num def render(self): pass @@ -126,11 +196,82 @@ def seed(self, seed=None): def seed(self, seed=None): pass - def __getattr__(self, name): - return getattr(self._env, name) + def get_attr(self, attr_name, indices=None): + raise NotImplementedError() + + def set_attr(self, attr_name, value, indices=None) -> None: + raise NotImplementedError() + +class CustomPostprocessor(Postprocessor): + '''Basic postprocessor that injects returns and lengths information into infos and + provides an option to pad to a maximum episode length. Works for single-agent and + team-based multi-agent environments''' + def reset(self, obs): + self.epoch_return = 0 + self.epoch_length = 0 + self.done = False + + def reward_done_truncated_info(self, reward, done, truncated, info): + if isinstance(reward, (list, np.ndarray)): + reward = sum(reward) + + # Env is done + if self.done: + return reward, done, truncated, info + + self.epoch_length += 1 + self.epoch_return += reward + + if done.all() or truncated: + info['return'] = self.epoch_return + info['length'] = self.epoch_length + self.done = True + + return reward, done, truncated, info + +def make_env(env_config, num_agents): + return NocturneGymnasium(config=env_config, num_agents=num_agents) + + +def nocturne_creator(env_config, num_agents): + return pufferlib.emulation.GymnasiumPufferEnv(env_creator=make_env, env_args=(env_config,num_agents,), postprocessor_cls=CustomPostprocessor) + +if __name__ == "__main__": + MAX_AGENTS = 3 + NUM_STEPS = 400 + + # Load environment variables and config + env_config = load_config("env_config") + + # Set the number of max vehicles + env_config.max_num_vehicles = MAX_AGENTS + + # from stable_baselines3.common.vec_env import SubprocVecEnv + + # # Make environment + # envs = SubprocVecEnv([lambda: make_env(env_config, MAX_AGENTS) for _ in range(4)]) + env = make_env(env_config, MAX_AGENTS) + import pufferlib.emulation + env = pufferlib.emulation.GymnasiumPufferEnv(env, postprocessor_cls=CustomPostprocessor) + env.reset() + env.step(env.action_space.sample()) + import pufferlib.vectorization + vec = pufferlib.vectorization.Multiprocessing + envs = vec(nocturne_creator,env_args=[env_config, MAX_AGENTS], num_envs=4, envs_per_worker=2, env_pool=True) + envs.async_reset() + obs = envs.recv()[0] + actions = [envs.single_action_space.sample() for _ in range(4)] + envs.step(actions) + envs.step(actions) + + for global_step in range(NUM_STEPS): + # Take random action(s) -- you'd obtain this from a policy + actions = np.array([envs.action_space.sample() for _ in range(4)]) + + # Step + obs, rew, done, info = envs.step(actions) - def get_attr(self, attr_name: str): - return getattr(self._env, attr_name) + # Log + # logging.info(f"step_num: {env.step_num} (global = {global_step}) | done: {done} | rew: {rew}") - def set_attr(self, attr_name: str): - setattr(self._env, attr_name) + time.sleep(0.2)