Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
**/outputs/*
*.egg-info/*
*.eggs/*
*.so
*.so*
compile_commands.json
1 change: 1 addition & 0 deletions examples/atari/dqn/atari_apex_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import copy
import logging
import time
import os

import hydra

Expand Down
7 changes: 4 additions & 3 deletions examples/tutorials/remote_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
import time

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -62,18 +63,18 @@ def main():
adder = Adder()
adder_server = Server(name="adder_server", addr="127.0.0.1:4411")
adder_server.add_service(adder)

adder_client = remote_utils.make_remote(adder, adder_server)

adder_server.start()
time.sleep(2)
adder_client.connect()

a = 1
b = 2
c = adder_client.add(a, b)
print(f"{a} + {b} = {c}")
print("")
asyncio.run(run_batch(adder_client, send_tensor=False))
# print("")
# asyncio.run(run_batch(adder_client, send_tensor=False))
print("")
asyncio.run(run_batch(adder_client, send_tensor=True))

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
grpcio==1.47.0
gym
hydra-core
matplotlib
moolib@git+https://github.com/facebookresearch/moolib
numpy
opencv-python
rich
tabulate
torch>=1.5.1
rich
30 changes: 27 additions & 3 deletions rlmeta/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ set(
-march=native -Wfatal-errors -fvisibility=hidden"
)

# set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)

# PyTorch dependency
Expand Down Expand Up @@ -48,13 +48,37 @@ add_subdirectory(
${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11
)

add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/rpc
${CMAKE_CURRENT_BINARY_DIR}/rpc
)

pybind11_add_module(
_rlmeta_extension
${CMAKE_CURRENT_SOURCE_DIR}/cc/circular_buffer.cc
${CMAKE_CURRENT_SOURCE_DIR}/cc/nested_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/cc/pybind.cc
${CMAKE_CURRENT_SOURCE_DIR}/cc/timestamp_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/blocking_counter.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/client.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/computation_queue.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_future.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/server.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/task.cc
${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/tensor_wrapper.cc
)
target_include_directories(
_rlmeta_extension PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_link_libraries(_rlmeta_extension PUBLIC torch ${TORCH_PYTHON_LIBRARIES})
_rlmeta_extension
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_BINARY_DIR}/rpc
)
target_link_libraries(
_rlmeta_extension
PUBLIC
torch
${TORCH_PYTHON_LIBRARIES}
grpc++
rpc_grpc_proto
)
6 changes: 3 additions & 3 deletions rlmeta/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __call__(self, index: int) -> Agent:
return self._cls(*args, **kwargs)

def _make_arg(self, arg: Any, index: int) -> Any:
if isinstance(arg, remote.Remote):
arg = copy.deepcopy(arg)
arg.name = moolib_utils.expend_name_by_index(arg.name, index)
# if isinstance(arg, remote.Remote):
# arg = copy.deepcopy(arg)
# arg.name = moolib_utils.expend_name_by_index(arg.name, index)
return arg
16 changes: 9 additions & 7 deletions rlmeta/agents/dqn/apex_dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from typing import Callable, Dict, List, Optional, Sequence

import numpy as np

import torch
import torch.nn as nn

Expand Down Expand Up @@ -138,10 +140,10 @@ def train(self, num_steps: int) -> Optional[StatsDict]:
console.log(f"Training for num_steps = {num_steps}")
for _ in track(range(num_steps), description="Training..."):
t0 = time.perf_counter()
batch, weight, index, timestamp = self.replay_buffer.sample(
index, batch, weight, timestamp = self.replay_buffer.sample(
self.batch_size)
t1 = time.perf_counter()
step_stats = self.train_step(batch, weight, index, timestamp)
step_stats = self.train_step(index, batch, weight, timestamp)
t2 = time.perf_counter()
time_stats = {
"sample_data_time/ms": (t1 - t0) * 1000.0,
Expand All @@ -163,7 +165,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]:
for m in self._additional_models_to_update:
m.push()

episode_stats = self.controller.get_stats()
episode_stats = StatsDict.from_dict(self.controller.get_stats())
stats.update(episode_stats)

return stats
Expand All @@ -172,7 +174,7 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]:
self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True)
while self.controller.get_count() < num_episodes:
time.sleep(1)
stats = self.controller.get_stats()
stats = StatsDict.from_dict(self.controller.get_stats())
return stats

def make_replay(self) -> Optional[List[NestedTensor]]:
Expand Down Expand Up @@ -202,9 +204,9 @@ def make_replay(self) -> Optional[List[NestedTensor]]:

return replay

def train_step(self, batch: NestedTensor, weight: torch.Tensor,
index: torch.Tensor,
timestamp: torch.Tensor) -> Dict[str, float]:
def train_step(self, index: np.ndarray, batch: NestedTensor,
weight: torch.Tensor,
timestamp: np.ndarray) -> Dict[str, float]:
device = next(self.model.parameters()).device
batch = nested_utils.map_nested(lambda x: x.to(device), batch)
self.optimizer.zero_grad()
Expand Down
4 changes: 2 additions & 2 deletions rlmeta/agents/ppo/ppo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]:
if self.step_counter % self.push_every_n_steps == 0:
self.model.push()

episode_stats = self.controller.get_stats()
episode_stats = StatsDict.from_dict(self.controller.get_stats())
stats.update(episode_stats)

return stats
Expand All @@ -165,7 +165,7 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]:
self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True)
while self.controller.get_count() < num_episodes:
time.sleep(1)
stats = self.controller.get_stats()
stats = StatsDict.from_dict(self.controller.get_stats())
return stats

def device(self) -> torch.device:
Expand Down
27 changes: 15 additions & 12 deletions rlmeta/cc/nested_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "rlmeta/cc/nested_utils.h"

#include <utility>
#include <vector>

namespace rlmeta {

Expand Down Expand Up @@ -34,8 +33,10 @@ void VisitNestedImpl(Function func, const py::object& obj) {

if (py::isinstance<py::dict>(obj)) {
const py::dict src = py::reinterpret_borrow<py::dict>(obj);
for (const auto [k, v] : src) {
VisitNestedImpl(func, py::reinterpret_borrow<py::object>(v));
const std::vector<std::string> keys = SortedKeys(src);
for (const std::string& k : keys) {
VisitNestedImpl(func,
py::reinterpret_borrow<py::object>(src[py::str(k)]));
}
return;
}
Expand Down Expand Up @@ -68,8 +69,11 @@ py::object MapNestedImpl(Function func, const py::object& obj) {
if (py::isinstance<py::dict>(obj)) {
const py::dict src = py::reinterpret_borrow<py::dict>(obj);
py::dict dst;
for (const auto [k, v] : src) {
dst[k] = MapNestedImpl(func, py::reinterpret_borrow<py::object>(v));
const std::vector<std::string> keys = SortedKeys(src);
for (const std::string& k : keys) {
const py::str key = py::str(k);
dst[key] =
MapNestedImpl(func, py::reinterpret_borrow<py::object>(src[key]));
}
return std::move(dst);
}
Expand Down Expand Up @@ -150,12 +154,14 @@ py::tuple UnbatchNestedImpl(std::function<py::tuple(const py::object&)> func,
for (int64_t i = 0; i < batch_size; ++i) {
dst[i] = py::dict();
}
for (const auto [k, v] : src) {
const std::vector<std::string> keys = SortedKeys(src);
for (const std::string& k : keys) {
const py::str key = py::str(k);
py::tuple cur = UnbatchNestedImpl(
func, py::reinterpret_borrow<py::object>(v), batch_size);
func, py::reinterpret_borrow<py::object>(src[key]), batch_size);
for (int64_t i = 0; i < batch_size; ++i) {
py::dict y = py::reinterpret_borrow<py::dict>(dst[i]);
y[k] = cur[i];
y[key] = cur[i];
}
}
return dst;
Expand Down Expand Up @@ -201,10 +207,7 @@ py::tuple UnbatchNested(std::function<py::tuple(const py::object&)> func,
} // namespace nested_utils

void DefineNestedUtils(py::module& m) {
py::module sub =
m.def_submodule("nested_utils", "A submodule of \"_rlmeta_extension\"");

sub.def("flatten_nested", &nested_utils::FlattenNested)
m.def("flatten_nested", &nested_utils::FlattenNested)
.def("map_nested", &nested_utils::MapNested)
.def("collate_nested",
py::overload_cast<std::function<py::object(const py::tuple&)>,
Expand Down
24 changes: 24 additions & 0 deletions rlmeta/cc/nested_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,37 @@
#include <pybind11/pybind11.h>

#include <functional>
#include <string>
#include <vector>

namespace py = pybind11;

namespace rlmeta {

namespace nested_utils {

template <class Dict>
inline std::vector<std::string> SortedKeys(const Dict& dict) {
std::vector<std::string> ret;
ret.reserve(dict.size());
for (const auto [k, v] : dict) {
ret.push_back(k);
}
std::sort(ret.begin(), ret.end());
return ret;
}

template <>
inline std::vector<std::string> SortedKeys<py::dict>(const py::dict& dict) {
std::vector<std::string> ret;
ret.reserve(dict.size());
for (const auto [k, v] : dict) {
ret.push_back(py::reinterpret_borrow<py::str>(k));
}
std::sort(ret.begin(), ret.end());
return ret;
}

py::tuple FlattenNested(const py::object& obj);

py::object MapNested(std::function<py::object(const py::object&)> func,
Expand Down
5 changes: 1 addition & 4 deletions rlmeta/cc/numpy_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ std::vector<int64_t> NumpyArrayShape(const py::array_t<T>& arr) {

template <typename T_SRC, typename T_DST = T_SRC>
py::array_t<T_DST> NumpyEmptyLike(const py::array_t<T_SRC>& src) {
py::array_t<T_DST> dst(src.size());
const std::vector<int64_t> shape = NumpyArrayShape(src);
dst.resize(shape);
return dst;
return py::array_t<T_DST>(NumpyArrayShape(src));
}

} // namespace utils
Expand Down
26 changes: 25 additions & 1 deletion rlmeta/cc/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
#include "rlmeta/cc/nested_utils.h"
#include "rlmeta/cc/segment_tree.h"
#include "rlmeta/cc/timestamp_manager.h"
#include "rlmeta/rpc/cc/client.h"
#include "rlmeta/rpc/cc/computation_queue.h"
#include "rlmeta/rpc/cc/rpc_future.h"
#include "rlmeta/rpc/cc/rpc_utils.h"
#include "rlmeta/rpc/cc/server.h"
#include "rlmeta/rpc/cc/task.h"

namespace py = pybind11;

Expand All @@ -21,8 +27,26 @@ PYBIND11_MODULE(_rlmeta_extension, m) {
rlmeta::DefineMinSegmentTree<double>("Fp64", m);

rlmeta::DefineCircularBuffer(m);
rlmeta::DefineNestedUtils(m);
rlmeta::DefineTimestampManager(m);

py::module nested_utils = m.def_submodule(
"nested_utils", "A submodule of \"_rlmeta_extension\" for nested_utils");
rlmeta::DefineNestedUtils(nested_utils);

py::module rpc =
m.def_submodule("rpc", "A submodule of \"_rlmeta_extension\" for RPC");
// rlmeta::rpc::DefineTaskBase(rpc);
rlmeta::rpc::DefineTask(rpc);
rlmeta::rpc::DefineBatchedTask(rpc);
rlmeta::rpc::DefineComputationQueue(rpc);
rlmeta::rpc::DefineBatchedComputationQueue(rpc);
rlmeta::rpc::DefineServer(rpc);
rlmeta::rpc::DefineClient(rpc);
rlmeta::rpc::DefineRpcFuture(rpc);

py::module rpc_utils = rpc.def_submodule(
"rpc_utils", "A submodule of \"_rlmeta_extension.rpc\" for rpc_utils");
rlmeta::rpc::DefineRpcUtils(rpc_utils);
}

} // namespace
34 changes: 34 additions & 0 deletions rlmeta/cc/torch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#pragma once

#include <torch/extension.h>
#include <torch/python.h>
#include <torch/torch.h>

#include <cstdint>
Expand All @@ -20,6 +22,26 @@ struct TorchDataType<bool> {
static constexpr torch::ScalarType value = torch::kBool;
};

template <>
struct TorchDataType<uint8_t> {
static constexpr torch::ScalarType value = torch::kUInt8;
};

template <>
struct TorchDataType<int8_t> {
static constexpr torch::ScalarType value = torch::kInt8;
};

template <>
struct TorchDataType<int16_t> {
static constexpr torch::ScalarType value = torch::kInt16;
};

template <>
struct TorchDataType<int32_t> {
static constexpr torch::ScalarType value = torch::kInt32;
};

template <>
struct TorchDataType<int64_t> {
static constexpr torch::ScalarType value = torch::kInt64;
Expand All @@ -35,5 +57,17 @@ struct TorchDataType<double> {
static constexpr torch::ScalarType value = torch::kDouble;
};

inline bool IsTorchTensor(const py::object& obj) {
return THPVariable_Check(obj.ptr());
}

inline torch::Tensor PyObjectToTorchTensor(const py::object& obj) {
return THPVariable_Unpack(obj.ptr());
}

inline py::object TorchTensorToPyObject(const torch::Tensor& tensor) {
return py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
}

} // namespace utils
} // namespace rlmeta
Loading