Skip to content

[Benchmark] Data benchmarks #799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions benchmarks/data_collection/atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Atari game data collection benchmark
====================================

Runs an Atari game with a random policy using a multiprocess async data collector.

Image size: torch.Size([210, 160, 3])

Performance results with default configuration:
+-------------------------------+--------------------------------------------------+
| Machine specs | 3x A100 GPUs, |
| | Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz |
| | |
+===============================+==================================================+
| Batched transforms | 1775.2762 fps |
+-------------------------------+--------------------------------------------------+
| Single env transform | 2593.7481 fps |
+-------------------------------+--------------------------------------------------+

"""
import argparse
import time

import torch.cuda
import tqdm

from torchrl.collectors.collectors import MultiaSyncDataCollector, RandomPolicy
from torchrl.envs import (
Compose,
EnvCreator,
GrayScale,
ParallelEnv,
Resize,
ToTensorImage,
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv

# Total number of environment frames to collect before the benchmark stops.
total_frames = 100000

parser = argparse.ArgumentParser(
    description="Atari data collection benchmark (random policy, async multiprocess collector)."
)

parser.add_argument(
    "--batched",
    action="store_true",
    # store_true flag: present => transforms run once on the batched parallel-env
    # output instead of once per worker env.
    help="if set, the transforms will be applied on batches of images.",
)
parser.add_argument(
    "--n_envs",
    type=int,
    default=16,
    help="Number of environments to be run in parallel in each collector.",
)
parser.add_argument(
    "--n_workers_collector",
    type=int,
    default=3,
    help="Number of sub-collectors in the data collector.",
)
parser.add_argument(
    "--n_frames",
    type=int,
    default=64,
    help="Number of frames in each batch of data collected.",
)

if __name__ == "__main__":

    def make_env():
        """Return a raw ALE/Pong-v5 gym environment (frames of shape [210, 160, 3])."""
        return GymEnv("ALE/Pong-v5")

    # Print the raw env output so the benchmarked observation spec is visible in logs.
    print(make_env().fake_tensordict())

    def make_transformed_env(env):
        """Wrap ``env`` with the image pipeline: to-tensor, grayscale, 84x84 resize."""
        return TransformedEnv(
            env,
            Compose(
                ToTensorImage(),
                GrayScale(),
                Resize(84, 84),
            ),
        )

    args = parser.parse_args()
    if args.batched:
        # Transforms applied once, on the stacked output of the parallel env.
        parallel_env = make_transformed_env(
            ParallelEnv(args.n_envs, EnvCreator(make_env))
        )
    else:
        # Each worker env applies its own copy of the transforms.
        parallel_env = ParallelEnv(
            args.n_envs, EnvCreator(lambda: make_transformed_env(make_env()))
        )
    # One CUDA device per sub-collector; fall back to a scalar when there is
    # a single device (API accepts either a device or a list of devices).
    devices = list(range(torch.cuda.device_count()))[: args.n_workers_collector]
    if len(devices) == 1:
        devices = devices[0]
    elif len(devices) < args.n_workers_collector:
        raise RuntimeError(
            "This benchmark requires at least as many GPUs as the number of collector workers."
        )
    collector = MultiaSyncDataCollector(
        [parallel_env] * args.n_workers_collector,
        RandomPolicy(parallel_env.action_spec),
        total_frames=total_frames,
        frames_per_batch=args.n_frames,
        devices=devices,
        passing_devices=devices,
        split_trajs=False,
    )
    frames = 0
    t = None
    pbar = tqdm.tqdm(total=total_frames)
    for i, data in enumerate(collector):
        pbar.update(data.numel())
        if i == 10:
            # Warm-up over: start the clock. The batch received at i == 10 was
            # collected before the clock started, so it must not be counted —
            # only batches i > 10 contribute to `frames`.
            t = time.time()
        elif i > 10:
            frames += data.numel()
    if t is None:
        # Fewer than 11 batches: nothing was timed; fail loudly instead of
        # raising a NameError on the line below.
        raise RuntimeError(
            "Not enough batches were collected to measure throughput; "
            "lower --n_frames or raise total_frames."
        )
    t = time.time() - t
    del collector
    print(f"\n\nframes per sec: {frames/t: 4.4f} (frames={frames}, t={t})\n\n")
    # NOTE(review): exit() kept from the original — presumably to make sure any
    # lingering collector worker processes do not keep the script alive.
    exit()
62 changes: 62 additions & 0 deletions benchmarks/data_collection/atari_sb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Atari game data collection benchmark with stable-baselines3
===========================================================

Runs an Atari game with a random policy using a multiprocess async data collector.

Image size: torch.Size([210, 160, 3])

Performance results with default configuration:
+-------------------------------+--------------------------------------------------+
| Machine specs | 3x A100 GPUs, |
| | Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz |
| | |
+===============================+==================================================+
| | 1176.7944 fps |
+-------------------------------+--------------------------------------------------+

"""

import time

import tqdm
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we also run multiple workers (n_envs=32 => 32 environments stepped per call).
n_envs = 32
env = make_atari_env("PongNoFrameskip-v4", n_envs=n_envs, seed=0)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

# The model is only used to produce one action (see the loop below); no training occurs.
model = A2C("CnnPolicy", env, verbose=1)

frames = 0
total_frames = 100_000
pbar = tqdm.tqdm(total=total_frames)
obs = env.reset()
action = None

i = 0
while True:
    # Batches 0-9 are warm-up: the timer starts at i == 10 and, because of the
    # `elif`, the batch stepped at i == 10 is not counted in `frames`.
    if i == 10:
        t0 = time.time()
    elif i >= 10:
        frames += n_envs
        pbar.update(n_envs)
    # The action is predicted once and then replayed for every subsequent step,
    # so the benchmark measures env stepping throughput, not policy inference.
    if action is None:
        action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    if frames > total_frames:
        break
    i += 1
t = frames / (time.time() - t0)
print(f"fps: {t}")
129 changes: 129 additions & 0 deletions benchmarks/data_collection/dmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
DeepMind control suite data collection benchmark
================================================

Runs a "cheetah"-"run" dm-control task with a random policy using a multiprocess async data collector.

Image size: torch.Size([240, 320, 3])

Performance results with default configuration:
+-------------------------------+--------------------------------------------------+
| Machine specs | 3x A100 GPUs, |
| | Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz |
| | |
+===============================+==================================================+
| Batched transforms | 1885.2913 fps |
+-------------------------------+--------------------------------------------------+
| Single env transform | 1903.3575 fps |
+-------------------------------+--------------------------------------------------+

"""
import argparse
import time

import torch.cuda
import tqdm

from torchrl.collectors.collectors import MultiaSyncDataCollector, RandomPolicy
from torchrl.envs import (
Compose,
EnvCreator,
GrayScale,
ParallelEnv,
Resize,
ToTensorImage,
TransformedEnv,
)
from torchrl.envs.libs.dm_control import DMControlEnv

# Total number of environment frames to collect before the benchmark stops.
total_frames = 100000

parser = argparse.ArgumentParser(
    description="DeepMind control suite data collection benchmark (random policy, async multiprocess collector)."
)

parser.add_argument(
    "--batched",
    action="store_true",
    # store_true flag: present => transforms run once on the batched parallel-env
    # output instead of once per worker env.
    help="if set, the transforms will be applied on batches of images.",
)
parser.add_argument(
    "--n_envs",
    type=int,
    default=8,
    help="Number of environments to be run in parallel in each collector.",
)
parser.add_argument(
    "--n_workers_collector",
    type=int,
    default=4,
    help="Number of sub-collectors in the data collector.",
)
parser.add_argument(
    "--n_frames",
    type=int,
    default=64,
    help="Number of frames in each batch of data collected.",
)

if __name__ == "__main__":

    def make_env():
        """Return a pixel-based dm-control cheetah-run env (frames of shape [240, 320, 3])."""
        return DMControlEnv("cheetah", "run", from_pixels=True)

    # Print the raw env output so the benchmarked observation spec is visible in logs.
    print(make_env().fake_tensordict())

    def make_transformed_env(env):
        """Wrap ``env`` with the image pipeline: to-tensor, grayscale, 84x84 resize."""
        return TransformedEnv(
            env,
            Compose(
                ToTensorImage(),
                GrayScale(),
                Resize(84, 84),
            ),
        )

    args = parser.parse_args()
    if args.batched:
        # Transforms applied once, on the stacked output of the parallel env.
        parallel_env = make_transformed_env(
            ParallelEnv(args.n_envs, EnvCreator(make_env))
        )
    else:
        # Each worker env applies its own copy of the transforms.
        parallel_env = ParallelEnv(
            args.n_envs, EnvCreator(lambda: make_transformed_env(make_env()))
        )
    # One CUDA device per sub-collector; fall back to a scalar when there is
    # a single device (API accepts either a device or a list of devices).
    devices = list(range(torch.cuda.device_count()))[: args.n_workers_collector]
    if len(devices) == 1:
        devices = devices[0]
    elif len(devices) < args.n_workers_collector:
        raise RuntimeError(
            "This benchmark requires at least as many GPUs as the number of collector workers."
        )
    collector = MultiaSyncDataCollector(
        [parallel_env] * args.n_workers_collector,
        RandomPolicy(parallel_env.action_spec),
        total_frames=total_frames,
        frames_per_batch=args.n_frames,
        devices=devices,
        passing_devices=devices,
        split_trajs=False,
    )
    frames = 0
    t = None
    pbar = tqdm.tqdm(total=total_frames)
    for i, data in enumerate(collector):
        pbar.update(data.numel())
        if i == 10:
            # Warm-up over: start the clock. The batch received at i == 10 was
            # collected before the clock started, so it must not be counted —
            # only batches i > 10 contribute to `frames`.
            t = time.time()
        elif i > 10:
            frames += data.numel()
    if t is None:
        # Fewer than 11 batches: nothing was timed; fail loudly instead of
        # raising a NameError on the line below.
        raise RuntimeError(
            "Not enough batches were collected to measure throughput; "
            "lower --n_frames or raise total_frames."
        )
    t = time.time() - t
    del collector
    print(f"\n\nframes per sec: {frames/t: 4.4f} (frames={frames}, t={t})\n\n")
    # NOTE(review): exit() kept from the original — presumably to make sure any
    # lingering collector worker processes do not keep the script alive.
    exit()
Loading