Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib; Offline RL] Add docstrings to 'MARWIL'. #47157

Merged
merged 26 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
613acfa
Added docstring to 'MARWILOfflinePreLearner'.
simonsays1980 Aug 15, 2024
c1563e7
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Aug 19, 2024
8b2d98d
Added test code to 'MARWILOfflinePreLearner'.
simonsays1980 Aug 19, 2024
d9b08a5
Small nit.
simonsays1980 Aug 19, 2024
80271a7
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Aug 20, 2024
c35a0a5
Added docstirngs to 'MARWILLearner' and '_training_step_old_api_stack…
simonsays1980 Aug 20, 2024
abf80d2
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Aug 23, 2024
731db3b
Fixed multiple bugs in MARWIL's test code.
simonsays1980 Aug 23, 2024
0f9fe3f
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Aug 28, 2024
937d0ac
Some small nits here and there because examples were not running.
simonsays1980 Aug 28, 2024
f00eb1c
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Aug 29, 2024
69fe6f3
Added data file to the doc tests.
simonsays1980 Aug 29, 2024
614e735
Added data file to the RLlib doctest in 'ray/doc/BUILD'.
simonsays1980 Aug 30, 2024
f61c0c2
Changed data file in 'doc/BUILD' to 'rllib/..'.
simonsays1980 Sep 2, 2024
6da5a82
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Sep 2, 2024
2a7e239
Changed path for data in BUILD file for doctests b/c data could still…
simonsays1980 Sep 6, 2024
4b217f7
Merged master.
simonsays1980 Sep 11, 2024
f747100
[rllib] add data to doctest
can-anyscale Sep 12, 2024
dec99e6
rllib tests
can-anyscale Sep 12, 2024
1fe3541
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Sep 17, 2024
a81d852
let go
can-anyscale Sep 12, 2024
f5bdbfb
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Sep 19, 2024
09c851e
Merged Cuong's branch with correct paths.
simonsays1980 Sep 19, 2024
de42842
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Sep 25, 2024
6953130
Merge branch 'master' into add-docstrings-to-marwil
simonsays1980 Sep 25, 2024
9db326e
Added stop criterium to MARWIL testcode to avoid timeout.
simonsays1980 Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ doctest(
"source/rllib/rllib-sample-collection.rst",
],
),
data = ["//rllib:cartpole-v1_large"],
tags = ["team:rllib"],
)

Expand Down
9 changes: 8 additions & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@
load("//bazel:python.bzl", "py_test_module_list")
load("//bazel:python.bzl", "doctest")

filegroup(
name = "cartpole-v1_large",
data = glob(["tests/data/cartpole/cartpole-v1_large/*.parquet"]),
visibility = ["//visibility:public"],
)

doctest(
files = glob(
["**/*.py"],
Expand Down Expand Up @@ -112,7 +118,8 @@ doctest(
]
),
tags = ["team:rllib"],
size = "enormous"
data = glob(["tests/data/cartpole/cartpole-v1_large/*.parquet"]),
size = "enormous",
)

# --------------------------------------------------------------------
Expand Down
129 changes: 93 additions & 36 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,92 @@
class MARWILConfig(AlgorithmConfig):
"""Defines a configuration class from which a MARWIL Algorithm can be built.

.. testcode::

Example:
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> # Run this from the ray directory root.
>>> config = MARWILConfig() # doctest: +SKIP
>>> config = config.training(beta=1.0, lr=0.00001, gamma=0.99) # doctest: +SKIP
>>> config = config.offline_data( # doctest: +SKIP
... input_=["./rllib/tests/data/cartpole/large.json"])
>>> print(config.to_dict()) # doctest: +SKIP
...
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> algo = config.build() # doctest: +SKIP
>>> algo.train() # doctest: +SKIP

Example:
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> from ray import tune
>>> config = MARWILConfig()
>>> # Print out some default values.
>>> print(config.beta) # doctest: +SKIP
>>> # Update the config object.
>>> config.training(lr=tune.grid_search( # doctest: +SKIP
... [0.001, 0.0001]), beta=0.75)
>>> # Set the config object's data path.
>>> # Run this from the ray directory root.
>>> config.offline_data( # doctest: +SKIP
... input_=["./rllib/tests/data/cartpole/large.json"])
>>> # Set the config object's env, used for evaluation.
>>> config.environment(env="CartPole-v1") # doctest: +SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest: +SKIP
... "MARWIL",
... param_space=config.to_dict(),
... ).fit()
from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig

# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"

config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# Define the environment for which to learn a policy
# from offline data.
config.environment("CartPole-v1")
# Set the training parameters.
config.training(
beta=1.0,
lr=1e-5,
gamma=0.99,
# We must define a train batch size for each
# learner (here 1 local learner).
train_batch_size_per_learner=2000,
)
# Define the data source for offline data.
config.offline_data(
input_=[data_path.as_posix()],
# Run exactly one update per training iteration.
dataset_num_iters_per_learner=1,
)

# Build an `Algorithm` object from the config and run 1 training
# iteration.
algo = config.build()
algo.train()

.. testcode::

from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig
from ray import train, tune

# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"

config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# Print out some default values
print(f"beta: {config.beta}")
# Update the config object.
config.training(
lr=tune.grid_search([1e-3, 1e-4]),
beta=0.75,
# We must define a train batch size for each
# learner (here 1 local learner).
train_batch_size_per_learner=2000,
)
# Set the config's data path.
config.offline_data(
input_=[data_path.as_posix()],
# Set the number of updates to be run per learner
# per training step.
dataset_num_iters_per_learner=1,
)
# Set the config's environment for evaluation.
config.environment(env="CartPole-v1")
# Set up a tuner to run the experiment.
tuner = tune.Tuner(
"MARWIL",
param_space=config,
run_config=train.RunConfig(
stop={"training_iteration": 1},
),
)
# Run the experiment.
tuner.fit()
"""

def __init__(self, algo_class=None):
Expand Down Expand Up @@ -162,11 +213,12 @@ def training(
see bc.py algorithm in this same directory.
bc_logstd_coeff: A coefficient to encourage higher action distribution
entropy for exploration.
moving_average_sqd_adv_norm_update_rate: The rate for updating the
squared moving average advantage norm (c^2). A higher rate leads
to faster updates of this moving average.
moving_average_sqd_adv_norm_start: Starting value for the
squared moving average advantage norm (c^2).
vf_coeff: Balancing value estimation loss and policy optimization loss.
moving_average_sqd_adv_norm_update_rate: Update rate for the
squared moving average advantage norm (c^2).
grad_clip: If specified, clip the global norm of gradients by this amount.

Returns:
Expand Down Expand Up @@ -458,6 +510,11 @@ class (multi-/single-learner setup) and evaluation on
return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
"""Implements training step for the old stack.

Note, there is no hybrid stack anymore. If you need to use `RLModule`s,
use the new api stack.
"""
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
Expand Down
5 changes: 5 additions & 0 deletions rllib/algorithms/marwil/torch/marwil_torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@


class MARWILTorchLearner(MARWILLearner, TorchLearner):
"""Implements torch-specific MARWIL loss on top of MARWILLearner.

This class implements the MARWIL loss under `self.compute_loss_for_module()`.
"""

def compute_loss_for_module(
self,
*,
Expand Down
1 change: 1 addition & 0 deletions rllib/tuned_examples/bc/pendulum_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
.offline_data(
input_=[data_path],
input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 1)},
dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
)
.training(
# To increase learning speed with multiple learners,
Expand Down
Loading