From 045780ecda53efae2f77b03b0b2f974bc0181e35 Mon Sep 17 00:00:00 2001 From: daphnedemekas Date: Tue, 22 Apr 2025 12:10:07 -0700 Subject: [PATCH] reformat multienv probs --- mettagrid/mettagrid_env.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mettagrid/mettagrid_env.py b/mettagrid/mettagrid_env.py index 56f03293..acd20e67 100644 --- a/mettagrid/mettagrid_env.py +++ b/mettagrid/mettagrid_env.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import gymnasium as gym import hydra @@ -180,21 +180,20 @@ class MettaGridEnvSet(MettaGridEnv): def __init__( self, env_cfg: DictConfig, - probabilities: List[float] | None, render_mode: str, buf=None, **kwargs, ): - self._env_cfgs = env_cfg.envs + self._envs = list(env_cfg.envs.keys()) + self._probabilities = list(env_cfg.envs.values()) self._num_agents_global = env_cfg.num_agents - self._probabilities = probabilities self._env_cfg = self._get_new_env_cfg() super().__init__(env_cfg, render_mode, buf, **kwargs) self._cfg_template = None # we don't use this with multiple envs, so we clear it to emphasize that fact def _get_new_env_cfg(self): - selected_env = np.random.choice(self._env_cfgs, p=self._probabilities) + selected_env = np.random.choice(self._envs, p=self._probabilities) env_cfg = config_from_path(selected_env) if self._num_agents_global != env_cfg.game.num_agents: raise ValueError(