Skip to content

Commit

Permalink
Add level infos on reset (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
taufeeque9 authored Jul 11, 2024
2 parents 93bdd8c + dbeb9a2 commit aa270fc
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 17 deletions.
4 changes: 2 additions & 2 deletions envpool/sokoban/astar_log.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name,
if (line.empty()) {
continue;
}
SokobanLevel level = *level_loader.GetLevel(gen);
SokobanLevel level = level_loader.GetLevel(gen).data;
level_idx++;
}
}
Expand All @@ -50,7 +50,7 @@ void RunAStar(const std::string& level_file_name,
while (level_idx < total_levels_to_run) {
std::AStarSearch<SokobanNode> astarsearch(fsa_limit);
std::cout << "Running level " << level_idx << std::endl;
SokobanLevel level = *level_loader.GetLevel(gen);
SokobanLevel level = level_loader.GetLevel(gen).data;

SokobanNode node_start(dim_room, level, false);
SokobanNode node_end(dim_room, level, true);
Expand Down
2 changes: 1 addition & 1 deletion envpool/sokoban/astar_log_level.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void RunAStar(const std::string& level_file_name,
}
std::AStarSearch<SokobanNode> astarsearch(fsa_limit);
std::cout << "Running level " << level_idx << std::endl;
SokobanLevel level = *level_loader.GetLevel(gen);
SokobanLevel level = level_loader.GetLevel(gen).data;

SokobanNode node_start(dim_room, level, false);
SokobanNode node_end(dim_room, level, true);
Expand Down
23 changes: 15 additions & 8 deletions envpool/sokoban/level_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,24 +118,26 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
if (cur_file_ == level_file_paths_.end()) {
throw std::runtime_error("No more files to load.");
}
cur_level_file_++;
file_path = *cur_file_;
cur_file_++;
} else {
const size_t load_file_idx = SafeUniformInt(
static_cast<size_t>(0), level_file_paths_.size() - 1, gen);
file_path = level_file_paths_.at(load_file_idx);
cur_level_file_ = SafeUniformInt(static_cast<size_t>(0),
level_file_paths_.size() - 1, gen);
file_path = level_file_paths_.at(cur_level_file_);
}
std::ifstream file(file_path);

levels_.clear();
int cur_level_idx = 0;
std::string line;
while (std::getline(file, line)) {
if (line.empty()) {
continue;
}

if (line.at(0) == '#') {
SokobanLevel& cur_level = levels_.emplace_back(0);
SokobanLevel cur_level(0);
cur_level.reserve(10 * 10); // In practice most levels are this size

// Count contiguous '#' characters and use this as the box dimension
Expand Down Expand Up @@ -163,6 +165,8 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
<< "x" << dim_room << std::endl;
throw std::runtime_error(msg.str());
}
levels_.emplace_back(
std::make_pair(cur_level_idx++, std::move(cur_level)));
}
}
if (!load_sequentially_) {
Expand All @@ -178,20 +182,21 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
std::cout << "***Loaded " << levels_.size() << " levels from " << file_path
<< std::endl;
if (verbose >= 2) {
PrintLevel(std::cout, levels_.at(0));
PrintLevel(std::cout, levels_.at(0).second);
std::cout << std::endl;
PrintLevel(std::cout, levels_.at(1));
PrintLevel(std::cout, levels_.at(1).second);
std::cout << std::endl;
}
}
}

std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
TaggedSokobanLevel LevelLoader::GetLevel(std::mt19937& gen) {
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
// std::cerr << "Warning: All levels loaded. Looping around now." <<
// std::endl;
levels_loaded_ = 0;
cur_file_ = level_file_paths_.begin();
cur_level_file_ = -1;
LoadFile(gen);
// re-start from the `env_id`th level, like we do in the constructor.
cur_level_ = env_id_;
Expand All @@ -206,7 +211,9 @@ std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
auto out = levels_.begin() + cur_level_;
cur_level_ += num_envs_;
levels_loaded_++;
return out;

TaggedSokobanLevel tagged_level{cur_level_file_, out->first, out->second};
return tagged_level;
}

} // namespace sokoban
12 changes: 9 additions & 3 deletions envpool/sokoban/level_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <filesystem>
#include <random>
#include <utility>
#include <vector>

namespace sokoban {
Expand All @@ -34,23 +35,28 @@ constexpr uint8_t kPlayer = 5;
constexpr uint8_t kPlayerOnTarget = 6;
constexpr uint8_t kMaxLevelObject = kPlayerOnTarget;

struct TaggedSokobanLevel {
int file_idx, level_idx;
SokobanLevel data;
};

class LevelLoader {
protected:
bool load_sequentially_;
int n_levels_to_load_;
int levels_loaded_{0};
int env_id_{0};
int num_envs_{1};
std::vector<SokobanLevel> levels_{0};
int cur_level_;
std::vector<std::pair<int, SokobanLevel>> levels_{0};
int cur_level_{-1}, cur_level_file_{-1};
std::vector<std::filesystem::path> level_file_paths_{0};
std::vector<std::filesystem::path>::iterator cur_file_;
void LoadFile(std::mt19937& gen);

public:
int verbose;

std::vector<SokobanLevel>::iterator GetLevel(std::mt19937& gen);
TaggedSokobanLevel GetLevel(std::mt19937& gen);
explicit LevelLoader(const std::filesystem::path& base_path,
bool load_sequentially, int n_levels_to_load,
int env_id = 0, int num_envs = 1, int verbose = 0);
Expand Down
11 changes: 10 additions & 1 deletion envpool/sokoban/sokoban_envpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
#include "envpool/sokoban/sokoban_envpool.h"

#include <array>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <utility>
#include <vector>

#include "envpool/core/py_envpool.h"
Expand All @@ -31,7 +33,11 @@ void SokobanEnv::ResetWithoutWrite() {
current_max_episode_steps_ =
SafeUniformInt(min_episode_steps, max_episode_steps, gen_);

world_ = *(level_loader_.GetLevel(gen_));
TaggedSokobanLevel level = level_loader_.GetLevel(gen_);
world_ = level.data;
level_idx_ = level.level_idx;
level_file_idx_ = level.file_idx;

if (world_.size() != dim_room_ * dim_room_) {
std::stringstream msg;
msg << "Loaded level is not dim_room x dim_room. world_.size()="
Expand Down Expand Up @@ -204,6 +210,9 @@ void SokobanEnv::WriteState(float reward) {
}
}
obs.Assign(out.data(), out.size());

state["info:level_file_idx"_] = level_file_idx_;
state["info:level_idx"_] = level_idx_;
}

} // namespace sokoban
Expand Down
5 changes: 4 additions & 1 deletion envpool/sokoban/sokoban_envpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class SokobanEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
int dim_room = conf["dim_room"_];
return MakeDict("obs"_.Bind(Spec<uint8_t>({3, dim_room, dim_room})));
return MakeDict("obs"_.Bind(Spec<uint8_t>({3, dim_room, dim_room})),
"info:level_file_idx"_.Bind(Spec<int>({-1})),
"info:level_idx"_.Bind(Spec<int>({-1})));
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down Expand Up @@ -106,6 +108,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
std::filesystem::path levels_dir_;

LevelLoader level_loader_;
int level_file_idx_{-1}, level_idx_{-1};
SokobanLevel world_;
int verbose_;

Expand Down
13 changes: 12 additions & 1 deletion envpool/sokoban/sokoban_py_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ def test_load_sequentially_with_multiple_envs() -> None:
levels_dir = "/app/envpool/sokoban/sample_levels"
files = glob.glob(f"{levels_dir}/*.txt")
levels_by_files = []
levels_per_file = []
total_levels, num_envs = 8, 2
for file in sorted(files):
levels = read_levels_file(file)
levels_by_files.extend(levels)
levels_per_file.append(len(levels))
assert len(levels_by_files) == total_levels, "8 levels stored in files."

env = envpool.make(
Expand All @@ -307,8 +309,17 @@ def test_load_sequentially_with_multiple_envs() -> None:
printed_obs = []

for _ in range(2): # check loader loops around and loads levels again
gt_file_idx, gt_level_idx = 0, 0
for _ in range(total_levels // num_envs):
obs, _ = env.reset()
obs, info = env.reset()
level_file_idxs, level_idxs = info["level_file_idx"], info["level_idx"]
for lfi, li in zip(level_file_idxs, level_idxs):
assert lfi == gt_file_idx, f"lfi: {lfi}, gt_file_idx: {gt_file_idx}"
assert li == gt_level_idx, f"li: {li}, gt_level_idx: {gt_level_idx}"
gt_level_idx += 1
if gt_level_idx == levels_per_file[gt_file_idx]:
gt_file_idx += 1
gt_level_idx = 0
assert obs.shape == (
num_envs,
3,
Expand Down

0 comments on commit aa270fc

Please sign in to comment.