Merge pull request #24 from airboxlab/pearl
Training with Meta Pearl example
antoine-galataud authored Jan 11, 2024
2 parents 27978ee + 1b15d44 commit 248111e
Showing 8 changed files with 914 additions and 1,450 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -33,10 +33,10 @@ repos:
args: ["-f", "requirements.txt", "-o", "requirements.txt"]

- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py37-plus]
args: [--py38-plus]
name: Upgrade code

# python formatting
6 changes: 6 additions & 0 deletions README.md
@@ -98,15 +98,21 @@ Run the amphitheater example with default parameters using Ray RLlib PPO algorithm
### Using Poetry

```shell
# Using Ray RLlib
poetry run rllib --env AmphitheaterEnv
# Using Meta Pearl
poetry run pearl --env AmphitheaterEnv
```

### Using Python

If you installed dependencies with pip, you can run the example with:

```shell
# Using Ray RLlib
python3 rleplus/train/rllib.py --env AmphitheaterEnv
# Using Meta Pearl
python3 rleplus/train/pearl.py --env AmphitheaterEnv
```

Example of episode reward stats obtained by training with PPO for 1e5 timesteps, 2 workers, default parameters + LSTM, and a short E+ run period (first 2 weeks of January).
774 changes: 752 additions & 22 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rl-energyplus"
version = "0.6.0"
version = "0.7.0"
description = "EnergyPlus Gym Environments for Reinforcement Learning"
authors = ["Antoine Galataud <[email protected]>"]
packages = [
@@ -19,10 +19,12 @@ numpy = "1.23.5"
protobuf = "3.20.3"
tensorboard = "^2.9.0"
torch = "^2.1.1"
pearl = { git = "https://github.com/facebookresearch/Pearl.git" }
scipy = "^1.10.0"

[tool.poetry.scripts]
rllib = "rleplus.train.rllib:main"
pearl = "rleplus.train.pearl:main"
tests = "tests:run"

[build-system]
1,422 changes: 0 additions & 1,422 deletions requirements.txt

This file was deleted.

7 changes: 5 additions & 2 deletions rleplus/env/energyplus.py
@@ -118,7 +118,8 @@ def _report_progress(progress: int) -> None:

# run EnergyPlus in a non-blocking way
def _run_energyplus(rn, cmd_args, state, results):
print(f"running EnergyPlus with args: {cmd_args}")
if self.verbose:
print(f"running EnergyPlus with args: {cmd_args}")

# start simulation
results["exit_code"] = rn.run_energyplus(state, cmd_args)
@@ -268,6 +269,8 @@ class EnergyPlusEnv(gym.Env, metaclass=abc.ABCMeta):
"""

def __init__(self, env_config: Dict[str, Any]):
self.spec = gym.envs.registration.EnvSpec(f"{self.__class__.__name__}")

self.env_config = env_config
self.episode = -1
self.timestep = 0
@@ -379,7 +382,6 @@ def step(self, action):
else:
# post-process action
action_to_apply = self.post_process_action(action)

# Enqueue action (sent to EnergyPlus through dedicated callback)
# then wait to get next observation.
# Timeout is set to 2s to handle end of simulation cases, which happens async
@@ -405,6 +407,7 @@ def step(self, action):
# compute reward
reward = self.compute_reward(obs)

# print("obs", obs, "reward", reward, "done", done, "action", action)
obs_vec = np.array(list(obs.values()))
return obs_vec, reward, done, False, {}

11 changes: 10 additions & 1 deletion rleplus/examples/amphitheater/env.py
@@ -9,7 +9,16 @@


class AmphitheaterEnv(EnergyPlusEnv):
"""University amphitheatre environment."""
"""University amphitheatre environment.
This environment is based on an actual university amphitheatre in Luxembourg. The building model
(calibrated against actual energy consumption) of this amphitheatre is available in the same folder.
The weather file is a typical meteorological year (TMY) weather file.
HVAC: an AHU with a heating hot water coil, and supply and exhaust air fans.
Target actuator: supply air temperature setpoint.
"""

base_path = Path(__file__).parent

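As a quick-start illustration of the environment described in the docstring above, here is a hypothetical usage sketch (not part of this commit). It assumes the gymnasium-style `reset()`/`step()` API used elsewhere in this diff and the same `env_config` keys passed by `rleplus/train/pearl.py` below:

```python
# Hypothetical sketch, not part of the commit: drive AmphitheaterEnv directly.
from tempfile import TemporaryDirectory

from rleplus.examples.amphitheater.env import AmphitheaterEnv

env = AmphitheaterEnv(env_config={"csv": False, "verbose": False, "output": TemporaryDirectory().name})
obs, info = env.reset()
action = env.action_space.sample()
# step() returns the 5-tuple shown in rleplus/env/energyplus.py above
obs, reward, terminated, truncated, info = env.step(action)
env.close()
```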
136 changes: 136 additions & 0 deletions rleplus/train/pearl.py
@@ -0,0 +1,136 @@
"""An example of how to use Pearl to train a Bootstrapped DQN agent on the Amphitheater
environment.
See https://github.com/facebookresearch/Pearl for more configuration options.
"""
import argparse
from tempfile import TemporaryDirectory

from pearl.action_representation_modules.identity_action_representation_module import (
IdentityActionRepresentationModule,
)
from pearl.history_summarization_modules.lstm_history_summarization_module import (
LSTMHistorySummarizationModule,
)
from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.bootstrapped_dqn import (
BootstrappedDQN,
)
from pearl.replay_buffers.sequential_decision_making.bootstrap_replay_buffer import (
BootstrapReplayBuffer,
)
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

from rleplus.examples.registry import env_creator


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--env",
help="The gym environment to use.",
required=False,
default="AmphitheaterEnv",
)
parser.add_argument(
"--csv", help="Generate eplusout.csv at end of simulation", required=False, default=False, action="store_true"
)
parser.add_argument(
"--verbose",
help="In verbose mode, EnergyPlus will print to stdout",
required=False,
default=False,
action="store_true",
)
parser.add_argument(
"--output",
help="EnergyPlus output directory. Default is a generated one in /tmp/",
required=False,
default=TemporaryDirectory().name,
)
parser.add_argument("--timesteps", "-t", help="Number of timesteps to train", required=False, default=1e6)

built_args = parser.parse_args()
print(f"Running with following CLI args: {built_args}")
return built_args


def main():
args = parse_args()

# build the environment: we need to wrap the original gym environment in a Pearl environment
env_cls = env_creator(args.env)
env = GymEnvironment(
env_or_env_name=env_cls(
env_config=dict(
csv=args.csv,
verbose=args.verbose,
output=args.output,
)
)
)
assert isinstance(env.action_space, DiscreteActionSpace)

# declare some variables about environment dimensions
num_actions = env.action_space.n
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.action_dim
# Policy learner state dim, as well as hidden dim for the LSTM history summarization module.
# Note that the Pearl flow is: (LSTM) history summarization module -> policy learner, hence the LSTM output/hidden dim
# is the same as the policy learner's state dim.
state_dim = 128

# Bootstrapped DQN is an extension of DQN that uses the so-called "deep exploration" mechanism.
# The main idea is to keep an ensemble of k Q-value networks and on each episode, one of them is sampled and the
# greedy policy associated with that network is used for exploration.
# See: https://arxiv.org/abs/1602.04621
k = 10
policy_learner = BootstrappedDQN(
q_ensemble_network=EnsembleQValueNetwork(
state_dim=state_dim,
action_dim=act_dim,
ensemble_size=k,
output_dim=1,
hidden_dims=[64, 64],
prior_scale=0.3,
),
action_space=env.action_space,
training_rounds=50,
action_representation_module=IdentityActionRepresentationModule(
max_number_actions=num_actions,
representation_dim=act_dim,
),
)

# History summarization module: we use the LSTM history summarization module
history_summarization_module = LSTMHistorySummarizationModule(
observation_dim=obs_dim,
action_dim=act_dim,
hidden_dim=state_dim,
history_length=8,
)

# Pearl agent
agent = PearlAgent(
policy_learner=policy_learner,
history_summarization_module=history_summarization_module,
replay_buffer=BootstrapReplayBuffer(100_000, 1.0, k),
device_id=-1,
)

# run the online learning loop
online_learning(
agent=agent,
env=env,
number_of_steps=args.timesteps,
print_every_x_steps=100,
record_period=10000,
learn_after_episode=True,
)


if __name__ == "__main__":
main()
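
For reference, the new training script can be launched either through the `pearl` Poetry script entry added in `pyproject.toml` or directly with Python. The invocations below are illustrative sketches based on the CLI flags defined above; the flag values (timesteps, output path) are examples, not defaults from the commit:

```shell
# Via the Poetry script entry (flag values are illustrative)
poetry run pearl --env AmphitheaterEnv --timesteps 100000 --output /tmp/eplus-pearl --csv

# Or directly with Python
python3 rleplus/train/pearl.py --env AmphitheaterEnv --verbose
```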
