diff --git a/envpool/sokoban/astar_log.cc b/envpool/sokoban/astar_log.cc index 976cc67a..75be9ce9 100644 --- a/envpool/sokoban/astar_log.cc +++ b/envpool/sokoban/astar_log.cc @@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name, if (line.empty()) { continue; } - SokobanLevel level = *level_loader.GetLevel(gen); + SokobanLevel level = level_loader.GetLevel(gen).data; level_idx++; } } @@ -50,7 +50,7 @@ void RunAStar(const std::string& level_file_name, while (level_idx < total_levels_to_run) { std::AStarSearch astarsearch(fsa_limit); std::cout << "Running level " << level_idx << std::endl; - SokobanLevel level = *level_loader.GetLevel(gen); + SokobanLevel level = level_loader.GetLevel(gen).data; SokobanNode node_start(dim_room, level, false); SokobanNode node_end(dim_room, level, true); diff --git a/envpool/sokoban/astar_log_level.cc b/envpool/sokoban/astar_log_level.cc index 96c3802f..91b59620 100644 --- a/envpool/sokoban/astar_log_level.cc +++ b/envpool/sokoban/astar_log_level.cc @@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name, } std::AStarSearch astarsearch(fsa_limit); std::cout << "Running level " << level_idx << std::endl; - SokobanLevel level = *level_loader.GetLevel(gen); + SokobanLevel level = level_loader.GetLevel(gen).data; SokobanNode node_start(dim_room, level, false); SokobanNode node_end(dim_room, level, true); diff --git a/envpool/sokoban/level_loader.cc b/envpool/sokoban/level_loader.cc index 17e2bbc0..e2803663 100644 --- a/envpool/sokoban/level_loader.cc +++ b/envpool/sokoban/level_loader.cc @@ -118,16 +118,18 @@ void LevelLoader::LoadFile(std::mt19937& gen) { if (cur_file_ == level_file_paths_.end()) { throw std::runtime_error("No more files to load."); } + cur_level_file_++; file_path = *cur_file_; cur_file_++; } else { - const size_t load_file_idx = SafeUniformInt( - static_cast(0), level_file_paths_.size() - 1, gen); - file_path = level_file_paths_.at(load_file_idx); + cur_level_file_ = SafeUniformInt(static_cast(0), + level_file_paths_.size() - 1, gen); + file_path = level_file_paths_.at(cur_level_file_); } std::ifstream file(file_path); levels_.clear(); + int cur_level_idx = 0; std::string line; while (std::getline(file, line)) { if (line.empty()) { @@ -135,7 +137,7 @@ void LevelLoader::LoadFile(std::mt19937& gen) { } if (line.at(0) == '#') { - SokobanLevel& cur_level = levels_.emplace_back(0); + SokobanLevel cur_level(0); cur_level.reserve(10 * 10); // In practice most levels are this size // Count contiguous '#' characters and use this as the box dimension @@ -163,6 +165,8 @@ void LevelLoader::LoadFile(std::mt19937& gen) { << "x" << dim_room << std::endl; throw std::runtime_error(msg.str()); } + levels_.emplace_back( + std::make_pair(cur_level_idx++, std::move(cur_level))); } } if (!load_sequentially_) { @@ -178,20 +182,21 @@ void LevelLoader::LoadFile(std::mt19937& gen) { std::cout << "***Loaded " << levels_.size() << " levels from " << file_path << std::endl; if (verbose >= 2) { - PrintLevel(std::cout, levels_.at(0)); + PrintLevel(std::cout, levels_.at(0).second); std::cout << std::endl; - PrintLevel(std::cout, levels_.at(1)); + PrintLevel(std::cout, levels_.at(1).second); std::cout << std::endl; } } } -std::vector::iterator LevelLoader::GetLevel(std::mt19937& gen) { +TaggedSokobanLevel LevelLoader::GetLevel(std::mt19937& gen) { if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) { // std::cerr << "Warning: All levels loaded. Looping around now." << // std::endl; levels_loaded_ = 0; cur_file_ = level_file_paths_.begin(); + cur_level_file_ = -1; LoadFile(gen); // re-start from the `env_id`th level, like we do in the constructor. cur_level_ = env_id_; @@ -206,7 +211,9 @@ std::vector::iterator LevelLoader::GetLevel(std::mt19937& gen) { auto out = levels_.begin() + cur_level_; cur_level_ += num_envs_; levels_loaded_++; - return out; + + TaggedSokobanLevel tagged_level{cur_level_file_, out->first, out->second}; + return tagged_level; } } // namespace sokoban diff --git a/envpool/sokoban/level_loader.h b/envpool/sokoban/level_loader.h index c24cd416..b8a558d1 100644 --- a/envpool/sokoban/level_loader.h +++ b/envpool/sokoban/level_loader.h @@ -19,6 +19,7 @@ #include #include +#include #include namespace sokoban { @@ -34,6 +35,11 @@ constexpr uint8_t kPlayer = 5; constexpr uint8_t kPlayerOnTarget = 6; constexpr uint8_t kMaxLevelObject = kPlayerOnTarget; +struct TaggedSokobanLevel { + int file_idx, level_idx; + SokobanLevel data; +}; + class LevelLoader { protected: bool load_sequentially_; @@ -41,8 +47,8 @@ class LevelLoader { int levels_loaded_{0}; int env_id_{0}; int num_envs_{1}; - std::vector levels_{0}; - int cur_level_; + std::vector> levels_{0}; + int cur_level_{-1}, cur_level_file_{-1}; std::vector level_file_paths_{0}; std::vector::iterator cur_file_; void LoadFile(std::mt19937& gen); @@ -50,7 +56,7 @@ class LevelLoader { public: int verbose; - std::vector::iterator GetLevel(std::mt19937& gen); + TaggedSokobanLevel GetLevel(std::mt19937& gen); explicit LevelLoader(const std::filesystem::path& base_path, bool load_sequentially, int n_levels_to_load, int env_id = 0, int num_envs = 1, int verbose = 0); diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index ac5d9a75..003f00a7 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -15,9 +15,11 @@ #include "envpool/sokoban/sokoban_envpool.h" #include +#include #include #include #include +#include #include #include "envpool/core/py_envpool.h" @@ -31,7 +33,11 @@ void SokobanEnv::ResetWithoutWrite() { current_max_episode_steps_ = SafeUniformInt(min_episode_steps, max_episode_steps, gen_); - world_ = *(level_loader_.GetLevel(gen_)); + TaggedSokobanLevel level = level_loader_.GetLevel(gen_); + world_ = level.data; + level_idx_ = level.level_idx; + level_file_idx_ = level.file_idx; + if (world_.size() != dim_room_ * dim_room_) { std::stringstream msg; msg << "Loaded level is not dim_room x dim_room. world_.size()=" @@ -204,6 +210,9 @@ void SokobanEnv::WriteState(float reward) { } } obs.Assign(out.data(), out.size()); + + state["info:level_file_idx"_] = level_file_idx_; + state["info:level_idx"_] = level_idx_; } } // namespace sokoban diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index d2cd597d..96f74e4a 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -48,7 +48,9 @@ class SokobanEnvFns { template static decltype(auto) StateSpec(const Config& conf) { int dim_room = conf["dim_room"_]; - return MakeDict("obs"_.Bind(Spec({3, dim_room, dim_room}))); + return MakeDict("obs"_.Bind(Spec({3, dim_room, dim_room})), + "info:level_file_idx"_.Bind(Spec({-1})), + "info:level_idx"_.Bind(Spec({-1}))); } template static decltype(auto) ActionSpec(const Config& conf) { @@ -106,6 +108,7 @@ class SokobanEnv : public Env { std::filesystem::path levels_dir_; LevelLoader level_loader_; + int level_file_idx_{-1}, level_idx_{-1}; SokobanLevel world_; int verbose_; diff --git a/envpool/sokoban/sokoban_py_envpool_test.py b/envpool/sokoban/sokoban_py_envpool_test.py index d2905c39..34e842ab 100644 --- a/envpool/sokoban/sokoban_py_envpool_test.py +++ b/envpool/sokoban/sokoban_py_envpool_test.py @@ -285,10 +285,12 @@ def test_load_sequentially_with_multiple_envs() -> None: levels_dir = "/app/envpool/sokoban/sample_levels" files = glob.glob(f"{levels_dir}/*.txt") levels_by_files = [] + levels_per_file = [] total_levels, num_envs = 8, 2 for file in sorted(files): levels = read_levels_file(file) levels_by_files.extend(levels) + levels_per_file.append(len(levels)) assert len(levels_by_files) == total_levels, "8 levels stored in files." env = envpool.make( @@ -307,8 +309,17 @@ def test_load_sequentially_with_multiple_envs() -> None: printed_obs = [] for _ in range(2): # check loader loops around and loads levels again + gt_file_idx, gt_level_idx = 0, 0 for _ in range(total_levels // num_envs): - obs, _ = env.reset() + obs, info = env.reset() + level_file_idxs, level_idxs = info["level_file_idx"], info["level_idx"] + for lfi, li in zip(level_file_idxs, level_idxs): + assert lfi == gt_file_idx, f"lfi: {lfi}, gt_file_idx: {gt_file_idx}" + assert li == gt_level_idx, f"li: {li}, gt_level_idx: {gt_level_idx}" + gt_level_idx += 1 + if gt_level_idx == levels_per_file[gt_file_idx]: + gt_file_idx += 1 + gt_level_idx = 0 assert obs.shape == ( num_envs, 3,