Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/lint_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ on:
- main
paths:
- "nle_language_wrapper/**.py"
- "setup.py"
pull_request:
paths:
- "nle_language_wrapper/**.py"
- "setup.py"

jobs:
check_python:
Expand Down
25 changes: 12 additions & 13 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,25 @@ jobs:
runs-on: ubuntu-latest

steps:
- name: Setup Python 3.9 env
uses: actions/setup-python@v1
with:
python-version: "3.9"
- name: Clone repository
uses: actions/checkout@v2
- name: Checkout the repository
uses: actions/checkout@v4

- name: Install the latest version of uv.
uses: astral-sh/setup-uv@v5
with:
submodules: recursive
- name: Upgrade pip if necessary
run: "pip install -q --upgrade pip"
enable-cache: true
cache-dependency-glob: "uv.lock"

- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential autoconf libtool pkg-config
sudo apt-get install -y python3-dev python3-pip python3-numpy
sudo apt-get install -y flex bison libbz2-dev cmake
- name: Install package
run: "pip install -e .['dev']"
- name: Run tests
run: "make test"
- name: Install package with uv
run: "uv sync"
- name: Run tests with uv
run: "uv run pytest"
- name: Upload code coverage
run: |
curl -Os https://uploader.codecov.io/latest/linux/codecov
Expand Down
6 changes: 0 additions & 6 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
[submodule "pybind11"]
path = pybind11
url = https://github.com/pybind/pybind11
[submodule "nle"]
path = nle
url = https://github.com/facebookresearch/nle.git
52 changes: 33 additions & 19 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
cmake_minimum_required(VERSION 3.15)
project(nethack_language_wrapper)
cmake_policy(SET CMP0148 OLD)

message(STATUS "Building nle_language_wrapper backend")

set(CMAKE_POSITION_INDEPENDENT_CODE ON)

add_compile_definitions(
include(FetchContent)
FetchContent_Declare(
nle
GIT_REPOSITORY https://github.com/jbcoe/nle-nethack.git
GIT_TAG 88a8da68dbd95e55ae6dbdd2fd2f1810bbad8207)
FetchContent_MakeAvailable(nle)
find_package(pybind11 CONFIG REQUIRED)

pybind11_add_module(
nle_language_obsv
src/main.cpp )

target_compile_definitions(
nle_language_obsv
PRIVATE
GCC_WARN
NOCLIPPING
NOMAIL
Expand All @@ -14,23 +29,22 @@ add_compile_definitions(
DLB
NOCWD_ASSUMPTIONS)

set(CMAKE_CXX_FLAGS "-O3 -Wall -Wextra")
target_compile_options(
nle_language_obsv
PRIVATE -O3 -Wall -Wextra)

set(NLE_BASE ${PYTHON_SRC_PARENT}/nle/nle)
set(NLE_INC ${PYTHON_SRC_PARENT}/nle/)

file(GLOB_RECURSE PM_H_PATH "./nle/build" pm.h)
get_filename_component(NLE_INC_GEN ${PM_H_PATH} DIRECTORY)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/pybind11)
pybind11_add_module(
set_target_properties(
nle_language_obsv
src/main.cpp )

set_target_properties(nle_language_obsv PROPERTIES CXX_STANDARD 17)
target_link_directories(nle_language_obsv PUBLIC ${PYTHON_SRC_PARENT}/nle_language_wrapper)
target_include_directories(nle_language_obsv PUBLIC ${NLE_INC})
target_include_directories(nle_language_obsv PUBLIC ${NLE_INC_GEN})
target_link_libraries(nle_language_obsv PUBLIC nethack)
# Add relative rpath to enable finding the nethack shared library.
set_target_properties(nle_language_obsv PROPERTIES LINK_FLAGS "-Wl,-rpath,'$ORIGIN'")
PROPERTIES
CXX_STANDARD 17
POSITION_INDEPENDENT_CODE ON
# Add relative rpath to enable finding the nethack shared library in site_packages/nle.
LINK_FLAGS "-Wl,-rpath,'$ORIGIN/../nle'")

target_link_libraries(
nle_language_obsv
PUBLIC
nethack)

install(TARGETS nle_language_obsv
LIBRARY DESTINATION nle_language_wrapper)
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
.PHONY: test

format-python:
isort setup.py nle_language_wrapper
black setup.py nle_language_wrapper --config pyproject.toml
isort nle_language_wrapper
black nle_language_wrapper --config pyproject.toml

format-cpp:
clang-format -style=Google -i src/main.cpp

format-python-check:
isort -c --diff setup.py nle_language_wrapper
black --check --diff setup.py nle_language_wrapper
pylint setup.py \
isort -c --diff nle_language_wrapper
black --check --diff nle_language_wrapper
pylint \
nle_language_wrapper/agents/ \
nle_language_wrapper/wrappers/ \
nle_language_wrapper/scripts/ \
Expand Down
1 change: 0 additions & 1 deletion nle
Submodule nle deleted from 2fa1be
4 changes: 2 additions & 2 deletions nle_language_wrapper/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def fake_nle_env(mocker):
info = None
nle_env.reset = mocker.MagicMock(return_value=obsv)
nle_env.step = mocker.MagicMock(return_value=(obsv, reward, done, info))
nle_env.actions = [nethack_actions.CompassDirection.N]
nle_env._actions = [nethack_actions.CompassDirection.N]
nle_env.observation_space = spaces.Dict(
{
"glyphs": spaces.Space(),
Expand Down Expand Up @@ -138,7 +138,7 @@ def fake_nethack_multiple_monsters_env(mocker):
info = None
nle_env.reset = mocker.MagicMock(return_value=obsv)
nle_env.step = mocker.MagicMock(return_value=(obsv, reward, done, info))
nle_env.actions = [nethack_actions.CompassDirection.N]
nle_env._actions = [nethack_actions.CompassDirection.N]
nle_env.observation_space = spaces.Dict(
{
"glyphs": spaces.Space(),
Expand Down
6 changes: 3 additions & 3 deletions nle_language_wrapper/tests/test_nle_language_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def test_action_actions_maps_reflect_valid_actions(fake_nle_env):


def test_step_valid_action_not_supported(real_nethack_env):
real_nethack_env.actions = [
real_nethack_env._actions = [
action
for action in list(real_nethack_env.actions)
for action in list(real_nethack_env._actions)
if action != nethack_actions.Command.TRAVEL
]

dut = NLELanguageWrapper(real_nethack_env)
dut.reset()
dut.env.actions = list(dut.env.actions)
dut.env._actions = list(dut.env._actions)
with pytest.raises(ValueError):
dut.step("travel")

Expand Down
11 changes: 7 additions & 4 deletions nle_language_wrapper/wrappers/nle_language_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def __init__(self, env, use_language_action=True):
use_language_action(bool): Use language action or discrete integer actions
"""
super().__init__(env)
assert isinstance(env, NLE), "Only NLE environments are supported"
assert isinstance(
env, NLE
), f"Only NLE environments are supported {env} {type(env)}"
missing_obsv_keys = self.REQUIRED_NLE_OBSV_KEYS.difference(
env.observation_space.spaces.keys()
)
Expand All @@ -245,16 +247,17 @@ def __init__(self, env, use_language_action=True):

# Build map for action string to NLE Action Enum
self.action_str_enum_map = {}

for nle_action_enum, action_strs in self.all_nle_action_map.items():
if nle_action_enum in self.env.actions:
if nle_action_enum in self.env._actions:
for action_str in action_strs:
self.action_str_enum_map[action_str] = nle_action_enum

# Build map for NLE Action Enum to NLE action index
self.action_enum_index_map = {}
for nle_action_enum, _ in self.all_nle_action_map.items():
if nle_action_enum in self.env.actions:
self.action_enum_index_map[nle_action_enum] = self.env.actions.index(
if nle_action_enum in self.env._actions:
self.action_enum_index_map[nle_action_enum] = self.env._actions.index(
nle_action_enum
)

Expand Down
1 change: 0 additions & 1 deletion pybind11
Submodule pybind11 deleted from 479e9a
75 changes: 69 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,57 @@
[project]
name = "nle-language-wrapper"
version = "1.0.0a"
description = "Language Wrapper for the NetHack Learning Environment (NLE)"
authors = [{ name = "Nikolaj Goodger", email = "[email protected]" }]
readme = "README.md"
keywords = ["nle", "language", "wrapper"]
requires-python = ">=3.8"
classifiers=[
"License :: OSI Approved :: MIT License",
"Development Status :: 2 - Pre-Alpha",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: C++",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Games/Entertainment",
]
dependencies = [
"nle @ git+https://github.com/jbcoe/nle-nethack.git@88a8da68dbd95e55ae6dbdd2fd2f1810bbad8207",
"minihack>=0.1.4",
"gym>=0.15,<=0.23",
]

[project.license]
file = "LICENSE"

[project.urls]
Homepage = "https://github.com/ngoodger/nle-language-wrapper"

[dependency-groups]
"dev"= [
"black>=22.6.0",
"flake8>=4.0.1",
"pylint>=2.15.8",
"pytest>=7.1.2",
"pytest-cov>=3.0.0",
"pytest-mock>=3.7.0",
"pygame>=2.1.2",
"isort>=5.10.1",
"numpy>=1.21.0",
]

# [project.optional-dependencies]
# "agent" = [
# "sample_factory>=1.121.4",
# "transformers>=4.17.0",
# "torch@https://download.pytorch.org/whl/cu111/",
# "torch-1.9.1%2Bcu111-cp39-cp39-linux_x86_64.whl",
# ]

[tool.black]
line-length = 88
target-version = ['py37']
Expand Down Expand Up @@ -29,10 +83,19 @@ disable = [
"wrong-import-order"
]


[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
requires = ["scikit-build-core>=0.10", "pybind11>=2.2", "wheel", "setuptools-scm"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
wheel.packages = ["nle_language_wrapper"]
cmake.build-type = "Release"
cmake.args = ["-DCMAKE_POLICY_VERSION_MINIMUM=3.5"]
[tool.scikit-build.editable]
mode = "redirect"

[tool.scikit-build.metadata.version]
provider = "scikit_build_core.metadata.setuptools_scm"

[tool.pytest.ini_options]
testpaths = ["nle_language_wrapper/tests"]
Loading