diff --git a/.gitignore b/.gitignore index d96fadd..02aadb5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,5 @@ **/outputs/* *.egg-info/* *.eggs/* -*.so +*.so* compile_commands.json diff --git a/examples/atari/dqn/atari_apex_dqn.py b/examples/atari/dqn/atari_apex_dqn.py index 346b088..17e27c1 100644 --- a/examples/atari/dqn/atari_apex_dqn.py +++ b/examples/atari/dqn/atari_apex_dqn.py @@ -6,6 +6,7 @@ import copy import logging import time +import os import hydra diff --git a/examples/tutorials/remote_example.py b/examples/tutorials/remote_example.py index 1b7bcb0..7bddad9 100644 --- a/examples/tutorials/remote_example.py +++ b/examples/tutorials/remote_example.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import asyncio +import time import torch import torch.multiprocessing as mp @@ -62,18 +63,18 @@ def main(): adder = Adder() adder_server = Server(name="adder_server", addr="127.0.0.1:4411") adder_server.add_service(adder) - adder_client = remote_utils.make_remote(adder, adder_server) adder_server.start() + time.sleep(2) adder_client.connect() a = 1 b = 2 c = adder_client.add(a, b) print(f"{a} + {b} = {c}") - print("") - asyncio.run(run_batch(adder_client, send_tensor=False)) + # print("") + # asyncio.run(run_batch(adder_client, send_tensor=False)) print("") asyncio.run(run_batch(adder_client, send_tensor=True)) diff --git a/requirements.txt b/requirements.txt index 3e16b67..2890d42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ +grpcio==1.47.0 gym hydra-core matplotlib -moolib@git+https://github.com/facebookresearch/moolib numpy opencv-python +rich tabulate torch>=1.5.1 -rich \ No newline at end of file diff --git a/rlmeta/CMakeLists.txt b/rlmeta/CMakeLists.txt index 0cf1b2f..5d2b743 100644 --- a/rlmeta/CMakeLists.txt +++ b/rlmeta/CMakeLists.txt @@ -9,7 +9,7 @@ set( -march=native -Wfatal-errors -fvisibility=hidden" ) -# set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) # PyTorch dependency @@ -48,13 +48,37 @@ add_subdirectory( ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11 ) +add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/rpc + ${CMAKE_CURRENT_BINARY_DIR}/rpc +) + pybind11_add_module( _rlmeta_extension ${CMAKE_CURRENT_SOURCE_DIR}/cc/circular_buffer.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/nested_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/pybind.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/timestamp_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/blocking_counter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/client.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/computation_queue.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_future.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/server.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/task.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/tensor_wrapper.cc ) target_include_directories( - _rlmeta_extension PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) -target_link_libraries(_rlmeta_extension PUBLIC torch ${TORCH_PYTHON_LIBRARIES}) + _rlmeta_extension + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. 
+ ${CMAKE_CURRENT_BINARY_DIR}/rpc +) +target_link_libraries( + _rlmeta_extension + PUBLIC + torch + ${TORCH_PYTHON_LIBRARIES} + grpc++ + rpc_grpc_proto +) diff --git a/rlmeta/agents/agent.py b/rlmeta/agents/agent.py index b8ef727..55bdf46 100644 --- a/rlmeta/agents/agent.py +++ b/rlmeta/agents/agent.py @@ -101,7 +101,7 @@ def __call__(self, index: int) -> Agent: return self._cls(*args, **kwargs) def _make_arg(self, arg: Any, index: int) -> Any: - if isinstance(arg, remote.Remote): - arg = copy.deepcopy(arg) - arg.name = moolib_utils.expend_name_by_index(arg.name, index) + # if isinstance(arg, remote.Remote): + # arg = copy.deepcopy(arg) + # arg.name = moolib_utils.expend_name_by_index(arg.name, index) return arg diff --git a/rlmeta/agents/dqn/apex_dqn_agent.py b/rlmeta/agents/dqn/apex_dqn_agent.py index 1f65ea7..eaa6208 100644 --- a/rlmeta/agents/dqn/apex_dqn_agent.py +++ b/rlmeta/agents/dqn/apex_dqn_agent.py @@ -7,6 +7,8 @@ from typing import Callable, Dict, List, Optional, Sequence +import numpy as np + import torch import torch.nn as nn @@ -138,10 +140,10 @@ def train(self, num_steps: int) -> Optional[StatsDict]: console.log(f"Training for num_steps = {num_steps}") for _ in track(range(num_steps), description="Training..."): t0 = time.perf_counter() - batch, weight, index, timestamp = self.replay_buffer.sample( + index, batch, weight, timestamp = self.replay_buffer.sample( self.batch_size) t1 = time.perf_counter() - step_stats = self.train_step(batch, weight, index, timestamp) + step_stats = self.train_step(index, batch, weight, timestamp) t2 = time.perf_counter() time_stats = { "sample_data_time/ms": (t1 - t0) * 1000.0, @@ -163,7 +165,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]: for m in self._additional_models_to_update: m.push() - episode_stats = self.controller.get_stats() + episode_stats = StatsDict.from_dict(self.controller.get_stats()) stats.update(episode_stats) return stats @@ -172,7 +174,7 @@ def eval(self, num_episodes: Optional[int] = 
None) -> Optional[StatsDict]: self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True) while self.controller.get_count() < num_episodes: time.sleep(1) - stats = self.controller.get_stats() + stats = StatsDict.from_dict(self.controller.get_stats()) return stats def make_replay(self) -> Optional[List[NestedTensor]]: @@ -202,9 +204,9 @@ def make_replay(self) -> Optional[List[NestedTensor]]: return replay - def train_step(self, batch: NestedTensor, weight: torch.Tensor, - index: torch.Tensor, - timestamp: torch.Tensor) -> Dict[str, float]: + def train_step(self, index: np.ndarray, batch: NestedTensor, + weight: torch.Tensor, + timestamp: np.ndarray) -> Dict[str, float]: device = next(self.model.parameters()).device batch = nested_utils.map_nested(lambda x: x.to(device), batch) self.optimizer.zero_grad() diff --git a/rlmeta/agents/ppo/ppo_agent.py b/rlmeta/agents/ppo/ppo_agent.py index d86716a..ce0e641 100644 --- a/rlmeta/agents/ppo/ppo_agent.py +++ b/rlmeta/agents/ppo/ppo_agent.py @@ -156,7 +156,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]: if self.step_counter % self.push_every_n_steps == 0: self.model.push() - episode_stats = self.controller.get_stats() + episode_stats = StatsDict.from_dict(self.controller.get_stats()) stats.update(episode_stats) return stats @@ -165,7 +165,7 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]: self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True) while self.controller.get_count() < num_episodes: time.sleep(1) - stats = self.controller.get_stats() + stats = StatsDict.from_dict(self.controller.get_stats()) return stats def device(self) -> torch.device: diff --git a/rlmeta/cc/nested_utils.cc b/rlmeta/cc/nested_utils.cc index fc6600d..1660a6f 100644 --- a/rlmeta/cc/nested_utils.cc +++ b/rlmeta/cc/nested_utils.cc @@ -6,7 +6,6 @@ #include "rlmeta/cc/nested_utils.h" #include -#include namespace rlmeta { @@ -34,8 +33,10 @@ void VisitNestedImpl(Function func, const 
py::object& obj) { if (py::isinstance(obj)) { const py::dict src = py::reinterpret_borrow(obj); - for (const auto [k, v] : src) { - VisitNestedImpl(func, py::reinterpret_borrow(v)); + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + VisitNestedImpl(func, + py::reinterpret_borrow(src[py::str(k)])); } return; } @@ -68,8 +69,11 @@ py::object MapNestedImpl(Function func, const py::object& obj) { if (py::isinstance(obj)) { const py::dict src = py::reinterpret_borrow(obj); py::dict dst; - for (const auto [k, v] : src) { - dst[k] = MapNestedImpl(func, py::reinterpret_borrow(v)); + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + const py::str key = py::str(k); + dst[key] = + MapNestedImpl(func, py::reinterpret_borrow(src[key])); } return std::move(dst); } @@ -150,12 +154,14 @@ py::tuple UnbatchNestedImpl(std::function func, for (int64_t i = 0; i < batch_size; ++i) { dst[i] = py::dict(); } - for (const auto [k, v] : src) { + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + const py::str key = py::str(k); py::tuple cur = UnbatchNestedImpl( - func, py::reinterpret_borrow(v), batch_size); + func, py::reinterpret_borrow(src[key]), batch_size); for (int64_t i = 0; i < batch_size; ++i) { py::dict y = py::reinterpret_borrow(dst[i]); - y[k] = cur[i]; + y[key] = cur[i]; } } return dst; @@ -201,10 +207,7 @@ py::tuple UnbatchNested(std::function func, } // namespace nested_utils void DefineNestedUtils(py::module& m) { - py::module sub = - m.def_submodule("nested_utils", "A submodule of \"_rlmeta_extension\""); - - sub.def("flatten_nested", &nested_utils::FlattenNested) + m.def("flatten_nested", &nested_utils::FlattenNested) .def("map_nested", &nested_utils::MapNested) .def("collate_nested", py::overload_cast, diff --git a/rlmeta/cc/nested_utils.h b/rlmeta/cc/nested_utils.h index ac1c6e6..dd5c2d1 100644 --- a/rlmeta/cc/nested_utils.h +++ b/rlmeta/cc/nested_utils.h @@ -9,6 +9,8 @@ 
#include #include +#include +#include namespace py = pybind11; @@ -16,6 +18,28 @@ namespace rlmeta { namespace nested_utils { +template +inline std::vector SortedKeys(const Dict& dict) { + std::vector ret; + ret.reserve(dict.size()); + for (const auto [k, v] : dict) { + ret.push_back(k); + } + std::sort(ret.begin(), ret.end()); + return ret; +} + +template <> +inline std::vector SortedKeys(const py::dict& dict) { + std::vector ret; + ret.reserve(dict.size()); + for (const auto [k, v] : dict) { + ret.push_back(py::reinterpret_borrow(k)); + } + std::sort(ret.begin(), ret.end()); + return ret; +} + py::tuple FlattenNested(const py::object& obj); py::object MapNested(std::function func, diff --git a/rlmeta/cc/numpy_utils.h b/rlmeta/cc/numpy_utils.h index f2b6a85..38e4c0a 100644 --- a/rlmeta/cc/numpy_utils.h +++ b/rlmeta/cc/numpy_utils.h @@ -28,10 +28,7 @@ std::vector NumpyArrayShape(const py::array_t& arr) { template py::array_t NumpyEmptyLike(const py::array_t& src) { - py::array_t dst(src.size()); - const std::vector shape = NumpyArrayShape(src); - dst.resize(shape); - return dst; + return py::array_t(NumpyArrayShape(src)); } } // namespace utils diff --git a/rlmeta/cc/pybind.cc b/rlmeta/cc/pybind.cc index 0f69fca..c36bf5e 100644 --- a/rlmeta/cc/pybind.cc +++ b/rlmeta/cc/pybind.cc @@ -9,6 +9,12 @@ #include "rlmeta/cc/nested_utils.h" #include "rlmeta/cc/segment_tree.h" #include "rlmeta/cc/timestamp_manager.h" +#include "rlmeta/rpc/cc/client.h" +#include "rlmeta/rpc/cc/computation_queue.h" +#include "rlmeta/rpc/cc/rpc_future.h" +#include "rlmeta/rpc/cc/rpc_utils.h" +#include "rlmeta/rpc/cc/server.h" +#include "rlmeta/rpc/cc/task.h" namespace py = pybind11; @@ -21,8 +27,26 @@ PYBIND11_MODULE(_rlmeta_extension, m) { rlmeta::DefineMinSegmentTree("Fp64", m); rlmeta::DefineCircularBuffer(m); - rlmeta::DefineNestedUtils(m); rlmeta::DefineTimestampManager(m); + + py::module nested_utils = m.def_submodule( + "nested_utils", "A submodule of \"_rlmeta_extension\" for 
nested_utils"); + rlmeta::DefineNestedUtils(nested_utils); + + py::module rpc = + m.def_submodule("rpc", "A submodule of \"_rlmeta_extension\" for RPC"); + // rlmeta::rpc::DefineTaskBase(rpc); + rlmeta::rpc::DefineTask(rpc); + rlmeta::rpc::DefineBatchedTask(rpc); + rlmeta::rpc::DefineComputationQueue(rpc); + rlmeta::rpc::DefineBatchedComputationQueue(rpc); + rlmeta::rpc::DefineServer(rpc); + rlmeta::rpc::DefineClient(rpc); + rlmeta::rpc::DefineRpcFuture(rpc); + + py::module rpc_utils = rpc.def_submodule( + "rpc_utils", "A submodule of \"_rlmeta_extension.rpc\" for rpc_utils"); + rlmeta::rpc::DefineRpcUtils(rpc_utils); } } // namespace diff --git a/rlmeta/cc/torch_utils.h b/rlmeta/cc/torch_utils.h index ab74bda..7420813 100644 --- a/rlmeta/cc/torch_utils.h +++ b/rlmeta/cc/torch_utils.h @@ -5,6 +5,8 @@ #pragma once +#include +#include #include #include @@ -20,6 +22,26 @@ struct TorchDataType { static constexpr torch::ScalarType value = torch::kBool; }; +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kUInt8; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt8; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt16; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt32; +}; + template <> struct TorchDataType { static constexpr torch::ScalarType value = torch::kInt64; @@ -35,5 +57,17 @@ struct TorchDataType { static constexpr torch::ScalarType value = torch::kDouble; }; +inline bool IsTorchTensor(const py::object& obj) { + return THPVariable_Check(obj.ptr()); +} + +inline torch::Tensor PyObjectToTorchTensor(const py::object& obj) { + return THPVariable_Unpack(obj.ptr()); +} + +inline py::object TorchTensorToPyObject(const torch::Tensor& tensor) { + return py::reinterpret_steal(THPVariable_Wrap(tensor)); +} + } // namespace utils } // namespace rlmeta diff --git a/rlmeta/core/controller.py 
b/rlmeta/core/controller.py index 0255379..91e451e 100644 --- a/rlmeta/core/controller.py +++ b/rlmeta/core/controller.py @@ -65,8 +65,8 @@ def get_count(self) -> int: return self.count @remote.remote_method(batch_size=None) - def get_stats(self) -> StatsDict: - return self.stats + def get_stats(self) -> Dict[str, Dict[str, float]]: + return self.stats.dict() @remote.remote_method(batch_size=None) def add_episode(self, phase: Phase, stats: Dict[str, float]) -> None: diff --git a/rlmeta/core/model.py b/rlmeta/core/model.py index 03fe9f6..80212d7 100644 --- a/rlmeta/core/model.py +++ b/rlmeta/core/model.py @@ -45,9 +45,10 @@ def __init__(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: self._wrapped = model - self._reset(server_name, server_addr, name, timeout) + self._reset(server_name, server_addr, name, timeout, py_aio_client) # TODO: Find a better way to implement this def __getattribute__(self, attr: str) -> Any: @@ -64,26 +65,31 @@ def __call__(self, *args, **kwargs) -> Any: return self.wrapped(*args, **kwargs) def pull(self) -> None: - state_dict = self.client.sync(self.server_name, - self.remote_method_name("pull")) + # state_dict = self.client.sync(self.server_name, + # self.remote_method_name("pull")) + state_dict = self.client.rpc(self.remote_method_name("pull")) self.wrapped.load_state_dict(state_dict) async def async_pull(self) -> None: - state_dict = await self.client.async_(self.server_name, - self.remote_method_name("pull")) + # state_dict = await self.client.async_(self.server_name, + # self.remote_method_name("pull")) + state_dict = await self.client.async_rpc(self.remote_method_name("pull") + ) self.wrapped.load_state_dict(state_dict) def push(self) -> None: state_dict = self.wrapped.state_dict() state_dict = nested_utils.map_nested(lambda x: x.cpu(), state_dict) - self.client.sync(self.server_name, self.remote_method_name("push"), - 
state_dict) + # self.client.sync(self.server_name, self.remote_method_name("push"), + # state_dict) + self.client.rpc(self.remote_method_name("push"), state_dict) async def async_push(self) -> None: state_dict = self.wrapped.state_dict() state_dict = nested_utils.map_nested(lambda x: x.cpu(), state_dict) - await self.client.async_(self.server_name, - self.remote_method_name("push"), state_dict) + # await self.client.async_(self.server_name, + # self.remote_method_name("push"), state_dict) + await self.client.async_rpc(self.remote_method_name("push"), state_dict) def _bind(self) -> None: pass diff --git a/rlmeta/core/remote.py b/rlmeta/core/remote.py index f70e14b..bbddb26 100644 --- a/rlmeta/core/remote.py +++ b/rlmeta/core/remote.py @@ -10,7 +10,8 @@ from typing import Any, Callable, List, Optional -import moolib +# import moolib +import rlmeta.rpc as rpc from rlmeta.core.launchable import Launchable from rlmeta.utils.moolib_utils import generate_random_name @@ -54,14 +55,15 @@ def __init__(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: self._target_repr = repr(target) self._server_name = server_name self._server_addr = server_addr self._remote_methods = target.remote_methods self._identifier = target.identifier - self._reset(server_name, server_addr, name, timeout) + self._reset(server_name, server_addr, name, timeout, py_aio_client) self._client_methods = {} # TODO: Find a better way to implement this @@ -97,8 +99,11 @@ def server_name(self) -> str: def server_addr(self) -> str: return self._server_addr + # @property + # def client(self) -> Optional[moolib.Client]: + # return self._client @property - def client(self) -> Optional[moolib.Client]: + def client(self) -> Optional[rpc.Client]: return self._client @property @@ -116,11 +121,15 @@ def remote_method_name(self, method: str) -> str: def connect(self) -> None: if self._connected: return - 
self._client = moolib.Rpc() - self._client.set_transports(["uv"]) - self._client.set_name(self._name) - self._client.set_timeout(self._timeout) + # self._client = moolib.Rpc() + # self._client.set_transports(["uv"]) + # self._client.set_name(self._name) + # self._client.set_timeout(self._timeout) + # self._client.connect(self._server_addr) + + self._client = rpc.Client(self._py_aio_client) self._client.connect(self._server_addr) + self._bind() self._connected = True @@ -128,24 +137,31 @@ def _reset(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: if name is None: name = generate_random_name() self._server_name = server_name self._server_addr = server_addr self._name = name self._timeout = timeout + self._py_aio_client = py_aio_client self._client = None self._connected = False def _bind(self) -> None: for method in self._remote_methods: + # self._client_methods[method] = functools.partial( + # self.client.sync, self.server_name, + # self.remote_method_name(method)) + # self._client_methods["async_" + method] = functools.partial( + # self.client.async_, self.server_name, + # self.remote_method_name(method)) + self._client_methods[method] = functools.partial( - self.client.sync, self.server_name, - self.remote_method_name(method)) + self.client.rpc, self.remote_method_name(method)) self._client_methods["async_" + method] = functools.partial( - self.client.async_, self.server_name, - self.remote_method_name(method)) + self.client.async_rpc, self.remote_method_name(method)) def remote_method(batch_size: Optional[int] = None) -> Callable[..., Any]: diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index a6ac98e..a06a744 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -4,15 +4,19 @@ # LICENSE file in the root directory of this source tree. 
import collections -import time +import concurrent.futures import logging +import time from typing import Callable, Optional, Sequence, Tuple, Union -from rich.console import Console + import numpy as np import torch +from rich.console import Console + import rlmeta.core.remote as remote +import rlmeta.rpc as rpc import rlmeta.utils.data_utils as data_utils import rlmeta.utils.nested_utils as nested_utils @@ -117,7 +121,7 @@ def __init__(self, self._priority_type = priority_type self._sum_tree = SumSegmentTree(capacity, dtype=priority_type) - self._min_tree = MinSegmentTree(capacity, dtype=priority_type) + # self._min_tree = MinSegmentTree(capacity, dtype=priority_type) self._max_priority = 1.0 self._timestamps = TimestampManager(capacity) @@ -169,9 +173,10 @@ def extend(self, def sample( self, batch_size: int ) -> Tuple[NestedTensor, torch.Tensor, torch.Tensor, torch.Tensor]: - data, weight, index, timestamp = self._sample(batch_size) - return data, weight, torch.from_numpy(index), torch.from_numpy( - timestamp) + # data, weight, index, timestamp = self._sample(batch_size) + # return data, weight, torch.from_numpy(index), torch.from_numpy( + # timestamp) + return self._sample(batch_size) @remote.remote_method(batch_size=None) def update_priority(self, @@ -196,7 +201,7 @@ def warm_up(self, learning_starts: Optional[int] = None) -> None: def _init_priority(self, index) -> None: priority = self._max_priority**self.alpha self._sum_tree[index] = priority - self._min_tree[index] = priority + # self._min_tree[index] = priority def _update_priority( self, @@ -225,17 +230,18 @@ def _update_priority( if isinstance(index, int): if mask is None or mask: self._sum_tree[index] = priority - self._min_tree[index] = priority + # self._min_tree[index] = priority else: self._sum_tree.update(index, priority, mask) - self._min_tree.update(index, priority, mask) + # self._min_tree.update(index, priority, mask) def _compute_weight( self, index: Union[int, Tensor]) -> Union[float, 
torch.Tensor]: p = self._sum_tree[index] if isinstance(p, np.ndarray): p = torch.from_numpy(p) - p_min = self._min_tree.query(0, self.capacity) + # p_min = self._min_tree.query(0, self.capacity) + p_min = p.min() # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta) @@ -254,7 +260,8 @@ def _append(self, self._init_priority(index) else: self._update_priority(index, priority) - self._update_timestamp(index) + # self._update_timestamp(index) + self._timestamps.update(index) return index def _extend(self, @@ -277,7 +284,7 @@ def _sample( index = self._sum_tree.scan_lower_bound(mass) data, weight = self.__getitem__(index) timestamp = self._timestamps[index] - return data, weight, index, timestamp + return index, data, weight, timestamp class RemoteReplayBuffer(remote.Remote): @@ -289,12 +296,19 @@ def __init__(self, name: Optional[str] = None, prefetch: int = 0, timeout: float = 60) -> None: - super().__init__(target, server_name, server_addr, name, timeout) - self._prefetch = prefetch - self._futures = collections.deque() + # Disable python asyncio client for large data transmission. 
+ super().__init__(target, + server_name, + server_addr, + name, + timeout, + py_aio_client=False) self._server_name = server_name self._server_addr = server_addr + self._prefetch = prefetch + self._futures = collections.deque() + def __repr__(self): return (f"RemoteReplayBuffer(server_name={self._server_name}, " + f"server_addr={self._server_addr})") @@ -303,21 +317,37 @@ def __repr__(self): def prefetch(self) -> Optional[int]: return self._prefetch + # def connect(self) -> None: + # if self._connected: + # return + # + # self._client = rpc.Client() + # self._client.connect(self._server_addr) + # + # self._bind() + # self._connected = True + def sample( self, batch_size: int ) -> Union[NestedTensor, Tuple[NestedTensor, torch.Tensor, torch.Tensor, torch.Tensor]]: if len(self._futures) > 0: - ret = self._futures.popleft().result() + # ret = self._futures.popleft().result() + ret = self._futures.popleft().get() else: - ret = self.client.sync(self.server_name, - self.remote_method_name("sample"), - batch_size) - - while len(self._futures) < self.prefetch: - fut = self.client.async_(self.server_name, - self.remote_method_name("sample"), - batch_size) + # ret = self.client.sync(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + ret = self.client.rpc(self.remote_method_name("sample"), batch_size) + + # while len(self._futures) < self.prefetch: + # fut = self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + # self._futures.append(fut) + while len(self._futures) < self._prefetch: + fut = self.client.rpc_future(self.remote_method_name("sample"), + batch_size) self._futures.append(fut) return ret @@ -329,14 +359,21 @@ async def async_sample( if len(self._futures) > 0: ret = await self._futures.popleft() else: - ret = await self.client.async_(self.server_name, - self.remote_method_name("sample"), - batch_size) - - while len(self._futures) < self.prefetch: - fut = self.client.async_(self.server_name, - 
self.remote_method_name("sample"), - batch_size) + # ret = await self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + ret = await self.client.async_rpc(self.remote_method_name("sample"), + batch_size) + + # while len(self._futures) < self.prefetch: + # fut = self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + # self._futures.append(fut) + + while len(self._futures) < self._prefetch: + fut = self.client.async_rpc(self.remote_method_name("sample"), + batch_size) self._futures.append(fut) return ret diff --git a/rlmeta/core/server.py b/rlmeta/core/server.py index b2ff30e..70146b5 100644 --- a/rlmeta/core/server.py +++ b/rlmeta/core/server.py @@ -14,8 +14,9 @@ import torch.multiprocessing as mp from rich.console import Console -import moolib +# import moolib +import rlmeta.rpc as rpc import rlmeta.utils.asyncio_utils as asyncio_utils from rlmeta.core.launchable import Launchable @@ -85,64 +86,75 @@ def init_execution(self) -> None: if isinstance(service, Launchable): service.init_execution() - self._server = moolib.Rpc() - self._server.set_transports(["uv"]) - self._server.set_name(self._name) - self._server.set_timeout(self._timeout) - console.log(f"Server={self.name} listening to {self._addr}") - try: - self._server.listen(self._addr) - except: - console.log(f"ERROR on listen({self._addr}) from: server={self}") - raise - - def _start_services(self) -> NoReturn: - self._loop = asyncio.get_event_loop() - self._tasks = [] - console.log(f"Server={self.name} starting services: {self._services}") + # self._server = moolib.Rpc() + # self._server.set_transports(["uv"]) + # self._server.set_name(self._name) + # self._server.set_timeout(self._timeout) + # console.log(f"Server={self.name} listening to {self._addr}") + # try: + # self._server.listen(self._addr) + # except: + # console.log(f"ERROR on listen({self._addr}) from: server={self}") + # raise + + self._server = rpc.Server(self._addr) for 
service in self._services: for method in service.remote_methods: method_impl = getattr(service, method) batch_size = getattr(method_impl, "__batch_size__", None) - self._add_server_task(service.remote_method_name(method), + self._server.register(service.remote_method_name(method), method_impl, batch_size) - try: - if not self._loop.is_running(): - self._loop.run_forever() - except Exception as e: - logging.error(e) - raise - finally: - for task in self._tasks: - task.cancel() - self._loop.stop() - self._loop.close() + + def _start_services(self) -> NoReturn: + # self._loop = asyncio.get_event_loop() + # self._tasks = [] + # console.log(f"Server={self.name} starting services: {self._services}") + # for service in self._services: + # for method in service.remote_methods: + # method_impl = getattr(service, method) + # batch_size = getattr(method_impl, "__batch_size__", None) + # self._add_server_task(service.remote_method_name(method), + # method_impl, batch_size) + # try: + # if not self._loop.is_running(): + # self._loop.run_forever() + # except Exception as e: + # logging.error(e) + # raise + # finally: + # for task in self._tasks: + # task.cancel() + # self._loop.stop() + # self._loop.close() + + self._server.start() + console.log(f"Server={self.name} listening to {self._addr}") console.log(f"Server={self.name} services started") - def _add_server_task(self, func_name: str, func_impl: Callable[..., Any], - batch_size: Optional[int]) -> None: - if batch_size is None: - que = self._server.define_queue(func_name) - else: - que = self._server.define_queue(func_name, - batch_size=batch_size, - dynamic_batching=True) - task = asyncio_utils.create_task(self._loop, - self._async_process(que, func_impl)) - self._tasks.append(task) - - async def _async_process(self, que: moolib.Queue, - func: Callable[..., Any]) -> None: - try: - while True: - ret_cb, args, kwargs = await que - ret = func(*args, **kwargs) - ret_cb(ret) - except asyncio.CancelledError: - pass - except 
Exception as e: - logging.error(e) - raise e + # def _add_server_task(self, func_name: str, func_impl: Callable[..., Any], + # batch_size: Optional[int]) -> None: + # if batch_size is None: + # que = self._server.define_queue(func_name) + # else: + # que = self._server.define_queue(func_name, + # batch_size=batch_size, + # dynamic_batching=True) + # task = asyncio_utils.create_task(self._loop, + # self._async_process(que, func_impl)) + # self._tasks.append(task) + # + # async def _async_process(self, que: moolib.Queue, + # func: Callable[..., Any]) -> None: + # try: + # while True: + # ret_cb, args, kwargs = await que + # ret = func(*args, **kwargs) + # ret_cb(ret) + # except asyncio.CancelledError: + # pass + # except Exception as e: + # logging.error(e) + # raise e class ServerList: diff --git a/rlmeta/rpc/CMakeLists.txt b/rlmeta/rpc/CMakeLists.txt new file mode 100644 index 0000000..9cd2ec3 --- /dev/null +++ b/rlmeta/rpc/CMakeLists.txt @@ -0,0 +1,92 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(rlmeta) + +set(CMAKE_CXX_STANDARD 17) +set( + CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -O3 -Wall -Wextra -Wno-register -Wno-comment -fPIC \ + -march=native -Wfatal-errors -fvisibility=hidden" +) + +include(FetchContent) + +# gRPC dependency +FetchContent_Declare( + grpc + GIT_REPOSITORY https://github.com/grpc/grpc + GIT_TAG v1.47.0 +) + +set(FETCHCONTENT_QUIET OFF) +FetchContent_MakeAvailable(grpc) + +set(_PROTOBUF_LIBPROTOBUF libprotobuf) +set(_REFLECTION grpc++_reflection) +set(_PROTOBUF_PROTOC $) +set(_GRPC_GRPCPP grpc++) +if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) +else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) +endif() + +# proto files +get_filename_component(rpc_proto "${CMAKE_CURRENT_SOURCE_DIR}/protos/rpc.proto" ABSOLUTE) +get_filename_component(rpc_proto_path "${rpc_proto}" PATH) + +# generated files +set(rpc_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/rpc.pb.cc") +set(rpc_proto_hdrs 
"${CMAKE_CURRENT_BINARY_DIR}/rpc.pb.h") +set(rpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/rpc.grpc.pb.cc") +set(rpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/rpc.grpc.pb.h") +add_custom_command( + OUTPUT "${rpc_proto_srcs}" "${rpc_proto_hdrs}" "${rpc_grpc_srcs}" "${rpc_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${rpc_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${rpc_proto}" + DEPENDS "${rpc_proto}" +) + +# rpc_grpc_proto +add_library( + rpc_grpc_proto + ${rpc_grpc_srcs} + ${rpc_grpc_hdrs} + ${rpc_proto_srcs} + ${rpc_proto_hdrs} +) +target_include_directories( + rpc_grpc_proto + PUBLIC + ${CMAKE_CURRENT_BINARY_DIR} +) +target_link_libraries( + rpc_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF} +) + + +# pybind11_add_module( +# _rlmeta_rpc +# ${CMAKE_CURRENT_SOURCE_DIR}/pybind.cc +# ${CMAKE_CURRENT_SOURCE_DIR}/server.cc +# ) +# target_include_directories( +# _rlmeta_rpc +# PUBLIC +# ${CMAKE_CURRENT_BINARY_DIR} +# ${CMAKE_CURRENT_SOURCE_DIR} +# ${CMAKE_CURRENT_SOURCE_DIR}/../.. +# ) +# target_link_libraries( +# _rlmeta_rpc +# PUBLIC +# grpc++ +# rpc_grpc_proto +# ) diff --git a/rlmeta/rpc/__init__.py b/rlmeta/rpc/__init__.py new file mode 100644 index 0000000..9f5b711 --- /dev/null +++ b/rlmeta/rpc/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from _rlmeta_extension.rpc import ComputationQueue, BatchedComputationQueue +from _rlmeta_extension.rpc import Task, BatchedTask + +from rlmeta.rpc.client import Client +from rlmeta.rpc.server import Server + +__all__ = [ + "ComputationQueue", + "BatchedComputationQueue", + "Task", + "BatchedTask", + "Client", + "Server", +] diff --git a/rlmeta/rpc/cc/blocking_counter.cc b/rlmeta/rpc/cc/blocking_counter.cc new file mode 100644 index 0000000..4ee5a75 --- /dev/null +++ b/rlmeta/rpc/cc/blocking_counter.cc @@ -0,0 +1,33 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/blocking_counter.h" + +#include + +namespace rlmeta { +namespace rpc { + +bool BlockingCounter::DecrementCount() { + bool ret = false; + { + std::unique_lock lk(mu_); + assert(count_ >= 1); + --count_; + ret = (count_ == 0); + } + cv_.notify_all(); + return ret; +} + +void BlockingCounter::Wait() { + std::unique_lock lk(mu_); + assert(num_waiting_ == 0); + ++num_waiting_; + cv_.wait(lk, [this]() { return count_ == 0; }); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/blocking_counter.h b/rlmeta/rpc/cc/blocking_counter.h new file mode 100644 index 0000000..5bc70a9 --- /dev/null +++ b/rlmeta/rpc/cc/blocking_counter.h @@ -0,0 +1,53 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace rlmeta { +namespace rpc { + +// BlockingCounter implementation is copied and modified from +// the BlockingCounter class in Abseil. +// https://github.com/abseil/abseil-cpp/blob/ce42de10fbea616379826e91c7c23c16bffe6e61/absl/synchronization/blocking_counter.h +// +// Copyright 2017 The Abseil Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +class BlockingCounter { + public: + explicit BlockingCounter(int64_t initial_count) : count_(initial_count) {} + + BlockingCounter(const BlockingCounter&) = delete; + BlockingCounter& operator=(const BlockingCounter&) = delete; + + bool DecrementCount(); + + void Wait(); + + private: + int64_t count_; + int64_t num_waiting_ = 0; + + std::mutex mu_; + std::condition_variable cv_; +}; + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/client.cc b/rlmeta/rpc/cc/client.cc new file mode 100644 index 0000000..b6a2106 --- /dev/null +++ b/rlmeta/rpc/cc/client.cc @@ -0,0 +1,97 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "rlmeta/rpc/cc/client.h" + +#include +#include + +#include "rlmeta/rpc/cc/rpc_utils.h" + +namespace rlmeta { +namespace rpc { + +void Client::Connect(const std::string& addr, int64_t timeout) { + if (connected_) { + Disconnect(); + } + + grpc::ChannelArguments ch_args; + ch_args.SetMaxSendMessageSize(-1); // Unlimited + ch_args.SetMaxReceiveMessageSize(-1); // Unlimited + channel_ = grpc::CreateCustomChannel(addr, grpc::InsecureChannelCredentials(), + ch_args); + stub_ = Rpc::NewStub(channel_); + + const auto deadline = + std::chrono::system_clock::now() + std::chrono::seconds(timeout); + if (!channel_->WaitForConnected(deadline)) { + std::cerr << "[Client::connect] timeout" << std::endl; + } + connected_ = true; +} + +void Client::Disconnect() { + if (connected_) { + connected_ = false; + } +} + +py::object Client::Rpc(const std::string& func, const py::args& args, + const py::kwargs& kwargs) { + assert(connected_); + RpcRequest request; + request.set_function(func); + *request.mutable_args() = rpc_utils::PythonToNestedData(args); + *request.mutable_kwargs() = rpc_utils::PythonToNestedData(kwargs); + + NestedData ret; + { + py::gil_scoped_release release; + ret = RpcImpl(std::move(request)); + } + + return rpc_utils::NestedDataToPython(std::move(ret)); +} + +rlmeta::rpc::RpcFuture Client::RpcFuture(const std::string& func, + const py::args& args, + const py::kwargs& kwargs) { + assert(connected_); + RpcRequest request; + request.set_function(func); + *request.mutable_args() = rpc_utils::PythonToNestedData(args); + *request.mutable_kwargs() = rpc_utils::PythonToNestedData(kwargs); + + py::gil_scoped_release release; + rlmeta::rpc::RpcFuture fut = + std::async(&Client::RpcImpl, this, std::move(request)); + return fut; +} + +NestedData Client::RpcImpl(RpcRequest&& request) { + RpcResponse response; + grpc::ClientContext context; + grpc::Status status = stub_->RemoteCall(&context, request, &response); + assert(status.ok()); + NestedData ret = 
std::move(*response.mutable_return_value()); + return ret; +} + +void DefineClient(py::module& m) { + py::class_>(m, "Client") + .def(py::init<>()) + .def_property_readonly("addr", &Client::addr) + .def_property_readonly("connected", &Client::connected) + .def("connect", &Client::Connect, py::arg("addr"), + py::arg("timeout") = 60, py::call_guard()) + .def("disconnect", &Client::Disconnect, + py::call_guard()) + .def("rpc", &Client::Rpc) + .def("rpc_future", &Client::RpcFuture); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/client.h b/rlmeta/rpc/cc/client.h new file mode 100644 index 0000000..fa86b74 --- /dev/null +++ b/rlmeta/rpc/cc/client.h @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/rpc_future.h" +#include "rpc.grpc.pb.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class Client { + public: + Client() = default; + ~Client() { Disconnect(); } + + const std::string& addr() const { return addr_; } + bool connected() const { return connected_; } + + void Connect(const std::string& addr, int64_t timeout = 60); + void Disconnect(); + + py::object Rpc(const std::string& func, const py::args& args, + const py::kwargs& kwargs); + rlmeta::rpc::RpcFuture RpcFuture(const std::string& func, + const py::args& args, + const py::kwargs& kwargs); + + protected: + struct AsyncClientCall { + RpcResponse response; + grpc::ClientContext context; + grpc::Status status; + std::promise promise; + std::unique_ptr> + response_reader; + }; + + NestedData RpcImpl(RpcRequest&& request); + + std::string addr_; + bool connected_ = false; + + std::shared_ptr channel_; + std::unique_ptr stub_; +}; + +void DefineClient(py::module& m); + +} // namespace rpc +} 
// namespace rlmeta diff --git a/rlmeta/rpc/cc/computation_queue.cc b/rlmeta/rpc/cc/computation_queue.cc new file mode 100644 index 0000000..aa24eaf --- /dev/null +++ b/rlmeta/rpc/cc/computation_queue.cc @@ -0,0 +1,79 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/computation_queue.h" + +namespace rlmeta { +namespace rpc { + +std::future BatchedComputationQueue::Put(const NestedData& args, + const NestedData& kwargs) { + std::scoped_lock lk(mu_); + if (cur_computation_ == nullptr) { + cur_computation_ = std::make_shared(batch_size_); + queue_impl_.Put(cur_computation_); + } + std::future ret = cur_computation_->Add(args, kwargs); + if (cur_computation_->Full()) { + cur_computation_.reset(); + } + return ret; +} + +std::future BatchedComputationQueue::Put(NestedData&& args, + NestedData&& kwargs) { + std::scoped_lock lk(mu_); + if (cur_computation_ == nullptr) { + cur_computation_ = std::make_shared(batch_size_); + queue_impl_.Put(cur_computation_); + } + std::future ret = + cur_computation_->Add(std::move(args), std::move(kwargs)); + if (cur_computation_->Full()) { + cur_computation_.reset(); + } + return ret; +} + +std::shared_ptr BatchedComputationQueue::Get() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); + if (ret != nullptr) { + std::scoped_lock lk(mu_); + if (!dynamic_cast(ret.get())->Full()) { + cur_computation_.reset(); + } + } + return ret; +} + +std::shared_ptr BatchedComputationQueue::GetFullBatch() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); + if (ret != nullptr) { + dynamic_cast(ret.get())->Wait(); + } + return ret; +} + +void DefineComputationQueue(py::module& m) { + py::class_>( + m, "ComputationQueue") + .def(py::init<>()) + .def("get", &ComputationQueue::Get, + py::call_guard()) + .def("shutdown", &ComputationQueue::Shutdown, + py::call_guard()); 
+} + +void DefineBatchedComputationQueue(py::module& m) { + py::class_>( + m, "BatchedComputationQueue") + .def(py::init()) + .def("get_full_batch", &BatchedComputationQueue::GetFullBatch, + py::call_guard()); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/computation_queue.h b/rlmeta/rpc/cc/computation_queue.h new file mode 100644 index 0000000..eee8826 --- /dev/null +++ b/rlmeta/rpc/cc/computation_queue.h @@ -0,0 +1,77 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include + +#include "rlmeta/rpc/cc/queue_impl.h" +#include "rlmeta/rpc/cc/task.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class ComputationQueue { + public: + ComputationQueue() = default; + explicit ComputationQueue(int64_t capacity) : queue_impl_(capacity) {} + + virtual std::future Put(const NestedData& args, + const NestedData& kwargs) { + std::shared_ptr task = std::make_shared(args, kwargs); + queue_impl_.Put(task); + return task->Future(); + } + + virtual std::future Put(NestedData&& args, NestedData&& kwargs) { + std::shared_ptr task = + std::make_shared(std::move(args), std::move(kwargs)); + queue_impl_.Put(task); + return task->Future(); + } + + virtual std::shared_ptr Get() { + return queue_impl_.Get().value_or(nullptr); + } + + virtual void Shutdown() { queue_impl_.Shutdown(); } + + protected: + QueueImpl> queue_impl_; +}; + +class BatchedComputationQueue : public ComputationQueue { + public: + explicit BatchedComputationQueue(int64_t batch_size) + : batch_size_(batch_size) {} + BatchedComputationQueue(int64_t capacity, int64_t batch_size) + : ComputationQueue(capacity), batch_size_(batch_size) {} + + std::future Put(const NestedData& args, + const NestedData& kwargs) override; + std::future Put(NestedData&& args, 
NestedData&& kwargs) override; + + std::shared_ptr Get() override; + std::shared_ptr GetFullBatch(); + + protected: + const int64_t batch_size_; + std::shared_ptr cur_computation_ = nullptr; + + std::mutex mu_; +}; + +void DefineComputationQueue(py::module& m); +void DefineBatchedComputationQueue(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/queue_impl.h b/rlmeta/rpc/cc/queue_impl.h new file mode 100644 index 0000000..a350c54 --- /dev/null +++ b/rlmeta/rpc/cc/queue_impl.h @@ -0,0 +1,109 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include + +namespace rlmeta { +namespace rpc { + +template +class QueueImpl { + public: + QueueImpl() = default; + explicit QueueImpl(int64_t capacity) : capacity_(capacity){}; + + int64_t Size() const { + std::scoped_lock lk(mu_); + return data_.size(); + } + + bool Empty() const { + std::scoped_lock lk(mu_); + return data_.empty(); + } + + bool Full() const { + std::scoped_lock lk(mu_); + return data_.size() == capacity_; + } + + bool Put(const T& o) { + { + std::unique_lock lk(mu_); + can_put_.wait(lk, [this]() { + return !is_alive_ || static_cast(data_.size()) < capacity_; + }); + if (!is_alive_) { + return false; + } + data_.push_back(o); + } + can_get_.notify_one(); + return true; + } + + bool Put(T&& o) { + { + std::unique_lock lk(mu_); + can_put_.wait(lk, [this]() { + return !is_alive_ || static_cast(data_.size()) < capacity_; + }); + if (!is_alive_) { + return false; + } + data_.push_back(std::move(o)); + } + can_get_.notify_one(); + return true; + } + + std::optional Get() { + std::optional ret = [this]() -> std::optional { + std::unique_lock lk(mu_); + can_get_.wait(lk, [this]() { return !is_alive_ || !data_.empty(); }); + if (!is_alive_) { + return std::nullopt; + } + T ret = 
std::move(data_.front()); + data_.pop_front(); + return std::make_optional(ret); + }(); + if (ret.has_value()) { + can_put_.notify_one(); + } + return ret; + } + + void Shutdown() { + { + std::scoped_lock lk(mu_); + if (!is_alive_) { + return; + } + is_alive_ = false; + data_.clear(); + } + can_put_.notify_all(); + can_get_.notify_all(); + } + + protected: + const int64_t capacity_ = 1024; + std::deque data_; + bool is_alive_ = true; + + std::mutex mu_; + std::condition_variable can_put_; + std::condition_variable can_get_; +}; + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_future.cc b/rlmeta/rpc/cc/rpc_future.cc new file mode 100644 index 0000000..52d0e41 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_future.cc @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/rpc_future.h" + +#include + +#include "rlmeta/rpc/cc/rpc_utils.h" + +namespace rlmeta { +namespace rpc { + +py::object RpcFuture::Get() { + if (!Valid()) { + return py::none(); + } + py::object ret = rpc_utils::NestedDataToPython(std::move(future_.get())); + valid_ = false; + return ret; +} + +void DefineRpcFuture(py::module& m) { + py::class_>(m, "RpcFuture") + .def("valid", &RpcFuture::Valid) + .def("get", &RpcFuture::Get) + .def("wait", &RpcFuture::Wait); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_future.h b/rlmeta/rpc/cc/rpc_future.h new file mode 100644 index 0000000..4aac466 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_future.h @@ -0,0 +1,37 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +#include + +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class RpcFuture { + public: + RpcFuture(std::future&& fut) : future_(std::move(fut)) {} + + bool Valid() const { return future_.valid() && valid_; } + + py::object Get(); + + void Wait() { future_.wait(); } + + protected: + std::future future_; + bool valid_ = true; +}; + +void DefineRpcFuture(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_utils.cc b/rlmeta/rpc/cc/rpc_utils.cc new file mode 100644 index 0000000..5314cf3 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_utils.cc @@ -0,0 +1,177 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/rpc_utils.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "rlmeta/cc/nested_utils.h" +#include "rlmeta/cc/torch_utils.h" +#include "rlmeta/rpc/cc/tensor_wrapper.h" + +namespace rlmeta { +namespace rpc { + +namespace rpc_utils { + +TensorProto PythonToTensorProto(const py::object& obj) { + if (py::isinstance(obj)) { + return TensorWrapper(obj).TensorProto(); + } + if (rlmeta::utils::IsTorchTensor(obj)) { + return TensorWrapper(obj).TensorProto(); + } + return TensorProto(); +} + +py::object TensorProtoToPython(TensorProto&& tensor) { + if (tensor.tensor_type() == TensorProto::NUMPY) { + return TensorWrapper(std::move(tensor)).Python(); + } + if (tensor.tensor_type() == TensorProto::TORCH) { + return TensorWrapper(std::move(tensor)).Python(); + } + return py::none(); +} + +SimpleData PythonToSimpleData(const py::object& obj) { + SimpleData ret; + if (obj.is_none()) { + return ret; + } + if (py::isinstance(obj)) { + ret.set_bool_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_int_val(obj.cast()); + } else if (py::isinstance(obj)) { + 
ret.set_float_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_str_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_bytes_val(obj.cast()); + } else if (py::isinstance(obj)) { + *ret.mutable_tensor_val() = PythonToTensorProto(obj); + } else if (rlmeta::utils::IsTorchTensor(obj)) { + *ret.mutable_tensor_val() = PythonToTensorProto(obj); + } + return ret; +} + +py::object SimpleDataToPython(SimpleData&& proto) { + if (proto.has_bool_val()) { + return py::cast(proto.bool_val()); + } + if (proto.has_int_val()) { + return py::cast(proto.int_val()); + } + if (proto.has_float_val()) { + return py::cast(proto.float_val()); + } + if (proto.has_str_val()) { + return py::cast(proto.str_val()); + } + if (proto.has_bytes_val()) { + return py::bytes(proto.bytes_val()); + } + if (proto.has_tensor_val()) { + return TensorProtoToPython(std::move(*(proto.mutable_tensor_val()))); + } + return py::none(); +} + +NestedData PythonToNestedData(const py::object& obj) { + NestedData ret; + if (obj.is_none()) { + return ret; + } + + if (py::isinstance(obj)) { + py::tuple src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_vec(); + for (const auto x : src) { + *dst->add_data() = + PythonToNestedData(py::reinterpret_borrow(x)); + } + } else if (py::isinstance(obj)) { + py::list src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_vec(); + for (const auto x : src) { + *dst->add_data() = + PythonToNestedData(py::reinterpret_borrow(x)); + } + } else if (py::isinstance(obj)) { + py::dict src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_map(); + const std::vector keys = nested_utils::SortedKeys(src); + for (const std::string& k : keys) { + dst->mutable_data()->insert( + {k, PythonToNestedData( + py::reinterpret_borrow(src[py::str(k)]))}); + } + } else { + *ret.mutable_val() = PythonToSimpleData(obj); + } + + return ret; +} + +py::object NestedDataToPython(NestedData&& proto) { + if (proto.has_val()) { + return 
SimpleDataToPython(std::move(*proto.mutable_val())); + } + if (proto.has_vec()) { + auto* src = proto.mutable_vec(); + const int64_t n = src->data_size(); + py::tuple ret(n); + for (int64_t i = 0; i < n; ++i) { + ret[i] = NestedDataToPython(std::move(*(src->mutable_data(i)))); + } + return ret; + } + if (proto.has_map()) { + auto* src = proto.mutable_map()->mutable_data(); + py::dict ret; + const std::vector keys = nested_utils::SortedKeys(*src); + for (const std::string& k : keys) { + ret[py::str(k)] = NestedDataToPython(std::move(src->at(k))); + } + return ret; + } + return py::none(); +} + +std::string Dumps(const py::object& obj) { + const NestedData src = PythonToNestedData(obj); + py::gil_scoped_release release; + return src.SerializeAsString(); +} + +py::object Loads(const std::string& src) { + NestedData dst; + { + py::gil_scoped_release release; + dst.ParseFromString(src); + } + return NestedDataToPython(std::move(dst)); +} + +} // namespace rpc_utils + +void DefineRpcUtils(py::module& m) { + m.def("dumps", [](const py::object& obj) { + return py::bytes(rpc_utils::Dumps(obj)); + }).def("loads", &rpc_utils::Loads); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_utils.h b/rlmeta/rpc/cc/rpc_utils.h new file mode 100644 index 0000000..524136e --- /dev/null +++ b/rlmeta/rpc/cc/rpc_utils.h @@ -0,0 +1,38 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +#include + +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +namespace rpc_utils { + +TensorProto PythonToTensorProto(const py::object& obj); +py::object TensorProtoToPython(TensorProto&& tensor); + +SimpleData PythonToSimpleData(const py::object& obj); +py::object SimpleDataToPython(SimpleData&& proto); + +NestedData PythonToNestedData(const py::object& obj); +py::object NestedDataToPython(NestedData&& proto); + +std::string Dumps(const py::object& obj); +py::object Loads(const std::string& src); + +} // namespace rpc_utils + +void DefineRpcUtils(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/server.cc b/rlmeta/rpc/cc/server.cc new file mode 100644 index 0000000..b93ebbe --- /dev/null +++ b/rlmeta/rpc/cc/server.cc @@ -0,0 +1,85 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "rlmeta/rpc/cc/server.h" + +#include +#include + +#include +#include + +#include "rpc.pb.h" + +namespace rlmeta { +namespace rpc { + +grpc::Status ServiceImpl::RemoteCall(grpc::ServerContext* /* context */, + const RpcRequest* request, + RpcResponse* response) { + auto& func = functions_.at(request->function()).first; + *response->mutable_return_value() = func(request->args(), request->kwargs()); + return grpc::Status::OK; +} + +grpc::Status ServiceImpl::PyRemoteCall(grpc::ServerContext* /* context */, + const PyRpcRequest* request, + PyRpcResponse* response) { + NestedData args; + NestedData kwargs; + args.ParseFromString(request->args()); + kwargs.ParseFromString(request->kwargs()); + auto& func = functions_.at(request->function()).second; + NestedData ret = func(std::move(args), std::move(kwargs)); + ret.SerializeToString(response->mutable_return_value()); + return grpc::Status::OK; +} + +void Server::Start() { + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + grpc::ServerBuilder builder; + builder.SetMaxReceiveMessageSize(-1); // Unlimited. 
+ builder.AddListeningPort(addr_, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); +} + +void Server::Stop() { + server_->Shutdown(); + server_->Wait(); +} + +std::shared_ptr Server::RegisterQueue( + const std::string& func_name, int64_t batch_size) { + std::shared_ptr ret = nullptr; + if (batch_size == 0) { + ret = std::make_shared(); + } else { + ret = std::make_shared(batch_size); + } + service_.Register( + func_name, + [que = ret](const NestedData& args, const NestedData& kwargs) { + return que->Put(args, kwargs).get(); + }, + [que = ret](NestedData&& args, NestedData&& kwargs) { + return que->Put(std::move(args), std::move(kwargs)).get(); + }); + return ret; +} + +void DefineServer(py::module& m) { + py::class_>(m, "Server") + .def(py::init()) + .def_property_readonly("addr", &Server::addr) + .def("start", &Server::Start, py::call_guard()) + .def("stop", &Server::Stop, py::call_guard()) + .def("register_queue", &Server::RegisterQueue, py::arg("func_name"), + py::arg("batch_size") = 0); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/server.h b/rlmeta/rpc/cc/server.h new file mode 100644 index 0000000..938497a --- /dev/null +++ b/rlmeta/rpc/cc/server.h @@ -0,0 +1,80 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/computation_queue.h" +#include "rlmeta/rpc/cc/task.h" +#include "rpc.grpc.pb.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +using PyFunc = std::function; +using PyFuncRvalue = std::function; +using PyFuncDict = + std::unordered_map>; + +class ServiceImpl final : public Rpc::Service { + public: + grpc::Status Register(const std::string& func_name, PyFunc&& func_impl, + PyFuncRvalue&& func_rvalue_impl) { + functions_.emplace(func_name, std::make_pair(std::move(func_impl), + std::move(func_rvalue_impl))); + return grpc::Status::OK; + } + + private: + grpc::Status RemoteCall(grpc::ServerContext* context, + const RpcRequest* request, + RpcResponse* response) override; + + grpc::Status PyRemoteCall(grpc::ServerContext* context, + const PyRpcRequest* request, + PyRpcResponse* response) override; + + PyFuncDict functions_; +}; + +class Server { + public: + explicit Server(const std::string& addr) : addr_(addr) {} + ~Server() { Stop(); } + + const std::string& addr() const { return addr_; } + + void Start(); + void Stop(); + + std::shared_ptr RegisterQueue(const std::string& func_name, + int64_t batch_size = 0); + + protected: + void ServePyFuncQueue(const std::string& func_name); + + const std::string addr_; + std::unique_ptr server_; + ServiceImpl service_; +}; + +void DefineServer(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/task.cc b/rlmeta/rpc/cc/task.cc new file mode 100644 index 0000000..51dcb03 --- /dev/null +++ b/rlmeta/rpc/cc/task.cc @@ -0,0 +1,62 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "rlmeta/rpc/cc/task.h" + +#include + +namespace rlmeta { +namespace rpc { + +void BatchedTask::SetReturnValue(const py::object& return_value) { + NestedData ret = rpc_utils::PythonToNestedData(return_value); + py::gil_scoped_release release; + assert(ret.has_vec()); + assert(ret.vec_size() == batch_size_); + for (int64_t i = 0; i < batch_size_; ++i) { + promises_[i].set_value(std::move(*ret.mutable_vec()->mutable_data(i))); + } +} + +std::future BatchedTask::Add(const NestedData& args, + const NestedData& kwargs) { + assert(batch_size_ < capacity_); + std::promise& p = promises_.emplace_back(); + *args_.mutable_vec()->add_data() = args; + *kwargs_.mutable_vec()->add_data() = kwargs; + ++batch_size_; + num_to_wait_.DecrementCount(); + return p.get_future(); +} + +std::future BatchedTask::Add(NestedData&& args, + NestedData&& kwargs) { + assert(batch_size_ < capacity_); + std::promise& p = promises_.emplace_back(); + *args_.mutable_vec()->add_data() = std::move(args); + *kwargs_.mutable_vec()->add_data() = std::move(kwargs); + ++batch_size_; + num_to_wait_.DecrementCount(); + return p.get_future(); +} + +void DefineTask(py::module& m) { + py::class_>(m, "Task") + .def("args", &Task::Args) + .def("kwargs", &Task::Kwargs) + .def("set_return_value", &Task::SetReturnValue); +} + +void DefineBatchedTask(py::module& m) { + py::class_>(m, "BatchedTask") + .def("__len__", &BatchedTask::batch_size) + .def_property_readonly("capacity", &BatchedTask::capacity) + .def_property_readonly("batch_size", &BatchedTask::batch_size) + .def("empty", &BatchedTask::Empty) + .def("full", &BatchedTask::Full); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/task.h b/rlmeta/rpc/cc/task.h new file mode 100644 index 0000000..797f084 --- /dev/null +++ b/rlmeta/rpc/cc/task.h @@ -0,0 +1,88 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/blocking_counter.h" +#include "rlmeta/rpc/cc/rpc_utils.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class Task { + public: + Task() = default; + + Task(const NestedData& args, const NestedData& kwargs) + : args_(args), kwargs_(kwargs) {} + + Task(NestedData&& args, NestedData&& kwargs) + : args_(std::move(args)), kwargs_(std::move(kwargs)) {} + + virtual py::object Args() { + return rpc_utils::NestedDataToPython(std::move(args_)); + } + + virtual py::object Kwargs() { + return rpc_utils::NestedDataToPython(std::move(kwargs_)); + } + + std::future Future() { return promise_.get_future(); } + + virtual void SetReturnValue(const py::object& return_value) { + promise_.set_value(rpc_utils::PythonToNestedData(return_value)); + } + + protected: + NestedData args_; + NestedData kwargs_; + std::promise promise_; +}; + +class BatchedTask : public Task { + public: + explicit BatchedTask(int64_t capacity) + : capacity_(capacity), num_to_wait_(capacity) { + promises_.reserve(capacity); + } + + int64_t capacity() const { return capacity_; } + int64_t batch_size() const { return batch_size_; } + + bool Empty() const { return batch_size_ == 0; } + bool Full() const { return batch_size_ == capacity_; } + + void SetReturnValue(const py::object& return_value) override; + + std::future Add(const NestedData& args, const NestedData& kwargs); + std::future Add(NestedData&& args, NestedData&& kwargs); + + void Wait() { num_to_wait_.Wait(); } + + protected: + const int64_t capacity_; + int64_t batch_size_ = 0; + std::vector> promises_; + + BlockingCounter num_to_wait_; +}; + +void DefineTaskBase(py::module& m); +void DefineTask(py::module& m); +void DefineBatchedTask(py::module& m); + +} // namespace rpc +} 
// namespace rlmeta
diff --git a/rlmeta/rpc/cc/tensor_wrapper.cc b/rlmeta/rpc/cc/tensor_wrapper.cc
new file mode 100644
index 0000000..ddb2df6
--- /dev/null
+++ b/rlmeta/rpc/cc/tensor_wrapper.cc
@@ -0,0 +1,105 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "rlmeta/rpc/cc/tensor_wrapper.h"
+
+#include
+
+#include
+
+#include "rlmeta/cc/torch_utils.h"
+
+namespace rlmeta {
+namespace rpc {
+
+// NumPy-backed specializations: tensor_ is used through the py::array API.
+
+// Casts the Python object and stores the result in tensor_.
+template <>
+void TensorWrapper::FromPython(const py::object& obj) {
+  tensor_ = obj.cast();
+}
+
+// Builds tensor_ from a proto tagged NUMPY. Ownership of the proto's data
+// string is released into a capsule whose destructor deletes it, and that
+// capsule is passed to py::array together with the raw buffer pointer.
+template <>
+void TensorWrapper::FromTensorProto(
+    rlmeta::rpc::TensorProto&& proto) {
+  assert(proto.tensor_type() == TensorProto::NUMPY);
+  const std::vector shape(proto.shape().cbegin(),
+                          proto.shape().cend());
+  // Read the buffer address before release_data() detaches the string.
+  const void* data = proto.data().data();
+  std::unique_ptr data_ptr(proto.release_data());
+  auto capsule = py::capsule(data_ptr.get(), [](void* p) {
+    std::unique_ptr(reinterpret_cast(p));
+  });
+  // The capsule owns the string from here on, so give up our ownership.
+  data_ptr.release();
+  tensor_ = py::array(py::dtype(proto.dtype()), shape, data, capsule);
+}
+
+// Returns the wrapped array; tensor_ is moved from.
+template <>
+py::object TensorWrapper::Python() {
+  return std::move(tensor_);
+}
+
+// Serializes dtype number, shape and raw bytes into a proto tagged NUMPY.
+// NOTE(review): copies nbytes() bytes starting at data(), which presumably
+// requires a C-contiguous array; FromPython does not enforce that - confirm.
+template <>
+rlmeta::rpc::TensorProto TensorWrapper::TensorProto() {
+  rlmeta::rpc::TensorProto ret;
+  ret.set_tensor_type(TensorProto::NUMPY);
+  ret.set_dtype(tensor_.dtype().num());
+  ret.mutable_shape()->Assign(tensor_.shape(),
+                              tensor_.shape() + tensor_.ndim());
+  ret.set_data(tensor_.data(), tensor_.nbytes());
+  return ret;
+}
+
+// Torch-backed specializations: tensor_ is used through the torch API.
+
+// Converts the Python object to a torch tensor and makes it contiguous.
+template <>
+void TensorWrapper::FromPython(const py::object& obj) {
+  tensor_ = rlmeta::utils::PyObjectToTorchTensor(obj).contiguous();
+}
+
+// Builds tensor_ from a proto tagged TORCH via at::from_blob. The proto's
+// data string is released and deleted by the deleter given to from_blob.
+template <>
+void TensorWrapper::FromTensorProto(
+    rlmeta::rpc::TensorProto&& proto) {
+  assert(proto.tensor_type() == TensorProto::TORCH);
+  const std::vector shape(proto.shape().cbegin(),
+                          proto.shape().cend());
+  std::string* data = proto.release_data();
+  tensor_ = at::from_blob(
+      data->data(), shape, /*deleter=*/[data](void* /*p*/) { delete data; },
+      static_cast(proto.dtype()));
+}
+
+// Converts tensor_ to a Python object; tensor_ is moved from.
+template <>
+py::object TensorWrapper::Python() {
+  return rlmeta::utils::TorchTensorToPyObject(std::move(tensor_));
+}
+
+// Serializes scalar type, sizes and raw bytes into a proto tagged TORCH.
+template <>
+rlmeta::rpc::TensorProto TensorWrapper::TensorProto() {
+  rlmeta::rpc::TensorProto ret;
+  ret.set_tensor_type(TensorProto::TORCH);
+  ret.set_dtype(static_cast(tensor_.scalar_type()));
+  ret.mutable_shape()->Assign(tensor_.sizes().cbegin(), tensor_.sizes().cend());
+  ret.set_data(tensor_.data_ptr(), tensor_.nbytes());
+  return ret;
+}
+
+}  // namespace rpc
+}  // namespace rlmeta
diff --git a/rlmeta/rpc/cc/tensor_wrapper.h b/rlmeta/rpc/cc/tensor_wrapper.h
new file mode 100644
index 0000000..bd72657
--- /dev/null
+++ b/rlmeta/rpc/cc/tensor_wrapper.h
@@ -0,0 +1,56 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include
+
+#include "rpc.pb.h"
+
+namespace py = pybind11;
+
+namespace rlmeta {
+namespace rpc {
+
+// Interface for a tensor that converts between a Python object and a
+// TensorProto message.
+class TensorWrapperBase {
+ public:
+  // Virtual so that a wrapper destroyed through a base pointer runs the
+  // destructor of the concrete TensorWrapper and releases its tensor.
+  virtual ~TensorWrapperBase() = default;
+
+  // Replaces the wrapped tensor with one obtained from `obj` / `proto`.
+  virtual void FromPython(const py::object& obj) = 0;
+  virtual void FromTensorProto(rlmeta::rpc::TensorProto&& proto) = 0;
+
+  // Exports the wrapped tensor as a Python object / TensorProto message.
+  virtual py::object Python() = 0;
+  virtual rlmeta::rpc::TensorProto TensorProto() = 0;
+};
+
+// Wrapper holding a tensor of type TensorType. The four conversion methods
+// are only declared here; their definitions live outside this header.
+template
+class TensorWrapper : public TensorWrapperBase {
+ public:
+  // Both constructors initialise tensor_ through the matching From* method.
+  TensorWrapper(const py::object& obj) { FromPython(obj); }
+  TensorWrapper(rlmeta::rpc::TensorProto&& proto) {
+    FromTensorProto(std::move(proto));
+  }
+
+  void FromPython(const py::object& obj) override;
+  void FromTensorProto(rlmeta::rpc::TensorProto&& proto) override;
+
+  py::object Python() override;
+  rlmeta::rpc::TensorProto TensorProto() override;
+
+ protected:
+  TensorType tensor_;
+};
+
+}  // namespace rpc
+}  // namespace rlmeta
diff --git a/rlmeta/rpc/client.py b/rlmeta/rpc/client.py
new file mode 100644
index 0000000..f06c3a5
--- /dev/null
+++ b/rlmeta/rpc/client.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import pickle
+
+from typing import Any
+
+import grpc
+import grpc.experimental
+
+import rlmeta.rpc.rpc_pb2 as rpc_pb2
+import rlmeta.rpc.rpc_pb2_grpc as rpc_pb2_grpc
+import _rlmeta_extension.rpc as _rpc
+import _rlmeta_extension.rpc.rpc_utils as _rpc_utils
+
+
+class Client(_rpc.Client):
+    """RPC client offering blocking, future-based and asyncio calls.
+
+    Blocking and future calls are delegated to the ``_rpc.Client`` base
+    class. ``async_rpc`` either waits on a base-class future or sends the
+    request over a ``grpc.aio`` channel, depending on ``py_aio_client``.
+    """
+
+    def __init__(self, py_aio_client: bool = True) -> None:
+        """Create a client that is not yet connected.
+
+        Args:
+            py_aio_client: Selects the ``grpc.aio`` implementation of
+                ``async_rpc`` when true, the base-class future otherwise.
+        """
+        super().__init__()
+        self._addr = None
+        self._timeout = None
+        self._options = None
+        self._connected = False
+        self._py_aio_client = py_aio_client
+
+    def connect(self, addr: str, timeout: int = 60) -> None:
+        """Connect the base-class client and remember the channel settings.
+
+        Args:
+            addr: Server address, passed to the base class and reused as
+                the ``grpc.aio`` channel target.
+            timeout: Passed to the base-class ``connect``.
+        """
+        super().connect(addr, timeout)
+        self._addr = addr
+        self._timeout = timeout
+        self._options = [
+            (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1),
+            ("grpc.enable_http_proxy", 0),
+            ("grpc.max_send_message_length", -1),
+            ("grpc.max_receive_message_length", -1),
+        ]
+        self._connected = True
+
+    def rpc(self, function: str, *args, **kwargs) -> Any:
+        """Call ``function`` through the base class and return its result."""
+        return super().rpc(function, *args, **kwargs)
+
+    def rpc_future(self, function: str, *args, **kwargs) -> _rpc.RpcFuture:
+        """Start a call through the base class and return its future."""
+        return super().rpc_future(function, *args, **kwargs)
+
+    async def async_rpc(self, function: str, *args, **kwargs) -> Any:
+        """Await a remote call using the implementation chosen at init."""
+        if self._py_aio_client:
+            return await self._async_rpc_py(function, *args, **kwargs)
+        else:
+            return await self._async_rpc_cc(function, *args, **kwargs)
+
+    async def _async_rpc_cc(self, function: str, *args, **kwargs) -> Any:
+        """Await a base-class future without stalling the event loop.
+
+        ``RpcFuture.get()`` is awaited on the loop's default executor, so
+        other coroutines keep running and an exception raised by ``get()``
+        propagates to the awaiting caller.
+        """
+        loop = asyncio.get_running_loop()
+        ret = super().rpc_future(function, *args, **kwargs)
+        # NOTE(review): assumes get() blocks until the reply arrives and
+        # releases the GIL while waiting - confirm against the binding.
+        return await loop.run_in_executor(None, ret.get)
+
+    # TODO: Add efficient native asyncio support in C++ client.
+    async def _async_rpc_py(self, function: str, *args, **kwargs) -> Any:
+        """Send the call over a fresh ``grpc.aio`` channel and decode it.
+
+        Raises:
+            RuntimeError: If ``connect`` has not been called, or if the
+                response carries a non-empty ``error`` string.
+        """
+        if not self._connected:
+            raise RuntimeError("Client.connect() must be called before "
+                               "async_rpc().")
+        async with grpc.aio.insecure_channel(target=self._addr,
+                                             options=self._options) as channel:
+            stub = rpc_pb2_grpc.RpcStub(channel)
+            response = await stub.PyRemoteCall(
+                rpc_pb2.PyRpcRequest(function=function,
+                                     args=_rpc_utils.dumps(args),
+                                     kwargs=_rpc_utils.dumps(kwargs)))
+            # Surface a server-reported failure instead of decoding a payload
+            # that may be absent.
+            if response.error:
+                raise RuntimeError(response.error)
+            return _rpc_utils.loads(response.return_value)
diff --git a/rlmeta/rpc/protos/rpc.proto b/rlmeta/rpc/protos/rpc.proto
new file mode 100644
index 0000000..cb0e784
--- /dev/null
+++ b/rlmeta/rpc/protos/rpc.proto
@@ -0,0 +1,79 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+syntax = "proto2";
+
+package rlmeta.rpc;
+
+service Rpc {
+  rpc RemoteCall(RpcRequest) returns (RpcResponse) {}
+  rpc PyRemoteCall(PyRpcRequest) returns (PyRpcResponse) {}
+}
+
+message RpcRequest {
+  optional string function = 1;
+  optional NestedData args = 2;
+  optional NestedData kwargs = 3;
+}
+
+message RpcResponse {
+  optional NestedData return_value = 1;
+  optional Error error = 2;
+}
+
+message PyRpcRequest {
+  optional string function = 1;
+  optional bytes args = 2;
+  optional bytes kwargs = 3;
+}
+
+message PyRpcResponse {
+  optional bytes return_value = 1;
+  optional string error = 2;
+}
+
+message Error {
+  optional string msg = 1;
+}
+
+message TensorProto {
+  enum TensorType {
+    UNKNOWN = 0;
+    NUMPY = 1;
+    TORCH = 2;
+  }
+
+  optional TensorType tensor_type = 1;
+  optional int32 dtype = 2;
+  repeated int64 shape = 3 [packed = true];
+  optional bytes data = 4;
+}
+
+message SimpleData {
+  oneof value {
+    bool bool_val = 1;
+    int64 int_val = 2;
+    double float_val = 3;
+    string str_val = 4;
+    bytes bytes_val = 5;
+    TensorProto tensor_val = 6;
+  }
+}
+
+message DataVector {
+  repeated NestedData data = 1;
+}
+
+message DataMap {
+  map data = 1;
+}
+
+message NestedData {
+  oneof nested_data {
+
SimpleData val = 1; + DataVector vec = 2; + DataMap map = 3; + } +} diff --git a/rlmeta/rpc/rpc_pb2.py b/rlmeta/rpc/rpc_pb2.py new file mode 100644 index 0000000..9348f50 --- /dev/null +++ b/rlmeta/rpc/rpc_pb2.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: rpc.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\nrlmeta.rpc\"l\n\nRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12$\n\x04\x61rgs\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\x12&\n\x06kwargs\x18\x03 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\"]\n\x0bRpcResponse\x12,\n\x0creturn_value\x18\x01 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\x12 \n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.rlmeta.rpc.Error\">\n\x0cPyRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\x12\x0e\n\x06kwargs\x18\x03 \x01(\x0c\"4\n\rPyRpcResponse\x12\x14\n\x0creturn_value\x18\x01 \x01(\x0c\x12\r\n\x05\x65rror\x18\x02 \x01(\t\"\x14\n\x05\x45rror\x12\x0b\n\x03msg\x18\x01 \x01(\t\"\xa7\x01\n\x0bTensorProto\x12\x37\n\x0btensor_type\x18\x01 \x01(\x0e\x32\".rlmeta.rpc.TensorProto.TensorType\x12\r\n\x05\x64type\x18\x02 \x01(\x05\x12\x11\n\x05shape\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0c\n\x04\x64\x61ta\x18\x04 
\x01(\x0c\"/\n\nTensorType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\t\n\x05NUMPY\x10\x01\x12\t\n\x05TORCH\x10\x02\"\xa8\x01\n\nSimpleData\x12\x12\n\x08\x62ool_val\x18\x01 \x01(\x08H\x00\x12\x11\n\x07int_val\x18\x02 \x01(\x03H\x00\x12\x13\n\tfloat_val\x18\x03 \x01(\x01H\x00\x12\x11\n\x07str_val\x18\x04 \x01(\tH\x00\x12\x13\n\tbytes_val\x18\x05 \x01(\x0cH\x00\x12-\n\ntensor_val\x18\x06 \x01(\x0b\x32\x17.rlmeta.rpc.TensorProtoH\x00\x42\x07\n\x05value\"2\n\nDataVector\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.rlmeta.rpc.NestedData\"{\n\x07\x44\x61taMap\x12+\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1d.rlmeta.rpc.DataMap.DataEntry\x1a\x43\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData:\x02\x38\x01\"\x8d\x01\n\nNestedData\x12%\n\x03val\x18\x01 \x01(\x0b\x32\x16.rlmeta.rpc.SimpleDataH\x00\x12%\n\x03vec\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.DataVectorH\x00\x12\"\n\x03map\x18\x03 \x01(\x0b\x32\x13.rlmeta.rpc.DataMapH\x00\x42\r\n\x0bnested_data2\x8d\x01\n\x03Rpc\x12?\n\nRemoteCall\x12\x16.rlmeta.rpc.RpcRequest\x1a\x17.rlmeta.rpc.RpcResponse\"\x00\x12\x45\n\x0cPyRemoteCall\x12\x18.rlmeta.rpc.PyRpcRequest\x1a\x19.rlmeta.rpc.PyRpcResponse\"\x00') + + + +_RPCREQUEST = DESCRIPTOR.message_types_by_name['RpcRequest'] +_RPCRESPONSE = DESCRIPTOR.message_types_by_name['RpcResponse'] +_PYRPCREQUEST = DESCRIPTOR.message_types_by_name['PyRpcRequest'] +_PYRPCRESPONSE = DESCRIPTOR.message_types_by_name['PyRpcResponse'] +_ERROR = DESCRIPTOR.message_types_by_name['Error'] +_TENSORPROTO = DESCRIPTOR.message_types_by_name['TensorProto'] +_SIMPLEDATA = DESCRIPTOR.message_types_by_name['SimpleData'] +_DATAVECTOR = DESCRIPTOR.message_types_by_name['DataVector'] +_DATAMAP = DESCRIPTOR.message_types_by_name['DataMap'] +_DATAMAP_DATAENTRY = _DATAMAP.nested_types_by_name['DataEntry'] +_NESTEDDATA = DESCRIPTOR.message_types_by_name['NestedData'] +_TENSORPROTO_TENSORTYPE = _TENSORPROTO.enum_types_by_name['TensorType'] +RpcRequest = 
_reflection.GeneratedProtocolMessageType('RpcRequest', (_message.Message,), { + 'DESCRIPTOR' : _RPCREQUEST, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcRequest) + }) +_sym_db.RegisterMessage(RpcRequest) + +RpcResponse = _reflection.GeneratedProtocolMessageType('RpcResponse', (_message.Message,), { + 'DESCRIPTOR' : _RPCRESPONSE, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcResponse) + }) +_sym_db.RegisterMessage(RpcResponse) + +PyRpcRequest = _reflection.GeneratedProtocolMessageType('PyRpcRequest', (_message.Message,), { + 'DESCRIPTOR' : _PYRPCREQUEST, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.PyRpcRequest) + }) +_sym_db.RegisterMessage(PyRpcRequest) + +PyRpcResponse = _reflection.GeneratedProtocolMessageType('PyRpcResponse', (_message.Message,), { + 'DESCRIPTOR' : _PYRPCRESPONSE, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.PyRpcResponse) + }) +_sym_db.RegisterMessage(PyRpcResponse) + +Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), { + 'DESCRIPTOR' : _ERROR, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.Error) + }) +_sym_db.RegisterMessage(Error) + +TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), { + 'DESCRIPTOR' : _TENSORPROTO, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.TensorProto) + }) +_sym_db.RegisterMessage(TensorProto) + +SimpleData = _reflection.GeneratedProtocolMessageType('SimpleData', (_message.Message,), { + 'DESCRIPTOR' : _SIMPLEDATA, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.SimpleData) + }) +_sym_db.RegisterMessage(SimpleData) + +DataVector = _reflection.GeneratedProtocolMessageType('DataVector', (_message.Message,), { + 'DESCRIPTOR' : _DATAVECTOR, + '__module__' : 'rpc_pb2' + # 
@@protoc_insertion_point(class_scope:rlmeta.rpc.DataVector) + }) +_sym_db.RegisterMessage(DataVector) + +DataMap = _reflection.GeneratedProtocolMessageType('DataMap', (_message.Message,), { + + 'DataEntry' : _reflection.GeneratedProtocolMessageType('DataEntry', (_message.Message,), { + 'DESCRIPTOR' : _DATAMAP_DATAENTRY, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.DataMap.DataEntry) + }) + , + 'DESCRIPTOR' : _DATAMAP, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.DataMap) + }) +_sym_db.RegisterMessage(DataMap) +_sym_db.RegisterMessage(DataMap.DataEntry) + +NestedData = _reflection.GeneratedProtocolMessageType('NestedData', (_message.Message,), { + 'DESCRIPTOR' : _NESTEDDATA, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.NestedData) + }) +_sym_db.RegisterMessage(NestedData) + +_RPC = DESCRIPTOR.services_by_name['Rpc'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TENSORPROTO.fields_by_name['shape']._options = None + _TENSORPROTO.fields_by_name['shape']._serialized_options = b'\020\001' + _DATAMAP_DATAENTRY._options = None + _DATAMAP_DATAENTRY._serialized_options = b'8\001' + _RPCREQUEST._serialized_start=25 + _RPCREQUEST._serialized_end=133 + _RPCRESPONSE._serialized_start=135 + _RPCRESPONSE._serialized_end=228 + _PYRPCREQUEST._serialized_start=230 + _PYRPCREQUEST._serialized_end=292 + _PYRPCRESPONSE._serialized_start=294 + _PYRPCRESPONSE._serialized_end=346 + _ERROR._serialized_start=348 + _ERROR._serialized_end=368 + _TENSORPROTO._serialized_start=371 + _TENSORPROTO._serialized_end=538 + _TENSORPROTO_TENSORTYPE._serialized_start=491 + _TENSORPROTO_TENSORTYPE._serialized_end=538 + _SIMPLEDATA._serialized_start=541 + _SIMPLEDATA._serialized_end=709 + _DATAVECTOR._serialized_start=711 + _DATAVECTOR._serialized_end=761 + _DATAMAP._serialized_start=763 + _DATAMAP._serialized_end=886 + _DATAMAP_DATAENTRY._serialized_start=819 + 
_DATAMAP_DATAENTRY._serialized_end=886 + _NESTEDDATA._serialized_start=889 + _NESTEDDATA._serialized_end=1030 + _RPC._serialized_start=1033 + _RPC._serialized_end=1174 +# @@protoc_insertion_point(module_scope) diff --git a/rlmeta/rpc/rpc_pb2_grpc.py b/rlmeta/rpc/rpc_pb2_grpc.py new file mode 100644 index 0000000..934ac72 --- /dev/null +++ b/rlmeta/rpc/rpc_pb2_grpc.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import rlmeta.rpc.rpc_pb2 as rpc__pb2 + + +class RpcStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.RemoteCall = channel.unary_unary( + '/rlmeta.rpc.Rpc/RemoteCall', + request_serializer=rpc__pb2.RpcRequest.SerializeToString, + response_deserializer=rpc__pb2.RpcResponse.FromString, + ) + self.PyRemoteCall = channel.unary_unary( + '/rlmeta.rpc.Rpc/PyRemoteCall', + request_serializer=rpc__pb2.PyRpcRequest.SerializeToString, + response_deserializer=rpc__pb2.PyRpcResponse.FromString, + ) + + +class RpcServicer(object): + """Missing associated documentation comment in .proto file.""" + + def RemoteCall(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PyRemoteCall(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def 
add_RpcServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'RemoteCall': grpc.unary_unary_rpc_method_handler(
+                    servicer.RemoteCall,
+                    request_deserializer=rpc__pb2.RpcRequest.FromString,
+                    response_serializer=rpc__pb2.RpcResponse.SerializeToString,
+            ),
+            'PyRemoteCall': grpc.unary_unary_rpc_method_handler(
+                    servicer.PyRemoteCall,
+                    request_deserializer=rpc__pb2.PyRpcRequest.FromString,
+                    response_serializer=rpc__pb2.PyRpcResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'rlmeta.rpc.Rpc', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class Rpc(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def RemoteCall(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/RemoteCall',
+            rpc__pb2.RpcRequest.SerializeToString,
+            rpc__pb2.RpcResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def PyRemoteCall(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/PyRemoteCall',
+            rpc__pb2.PyRpcRequest.SerializeToString,
+            rpc__pb2.PyRpcResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/rlmeta/rpc/server.py b/rlmeta/rpc/server.py
new file mode 100644
index 0000000..2d14962
--- /dev/null
+++ b/rlmeta/rpc/server.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+import threading
+
+from typing import Any, Callable, Optional, NoReturn
+
+import rlmeta.utils.data_utils as data_utils
+import _rlmeta_extension.rpc as _rpc
+
+
+class Server(_rpc.Server):
+    """RPC server that runs each registered function on its own thread.
+
+    ``register`` creates a computation queue plus a worker thread for a
+    function; ``start`` and ``stop`` control the base-class server and
+    the workers together.
+    """
+
+    def __init__(self, addr: str):
+        """Create a server for ``addr`` with no registered functions."""
+        super().__init__(addr)
+        self._tasks = []
+        self._lock = threading.Lock()
+
+    @property
+    def addr(self) -> str:
+        """Address reported by the base-class server."""
+        return super().addr
+
+    def register(self,
+                 func_name: str,
+                 func_impl: Callable[..., Any],
+                 batch_size: Optional[int] = None) -> None:
+        """Register ``func_impl`` under ``func_name``.
+
+        Args:
+            func_name: Name passed to ``register_queue``.
+            func_impl: Callable invoked for every task taken from the queue.
+            batch_size: Passed to ``register_queue``; ``None`` is sent as 0.
+        """
+        if batch_size is None:
+            batch_size = 0
+        q = self.register_queue(func_name, batch_size)
+        t = threading.Thread(target=self._process, args=(q, func_impl))
+        self._tasks.append((q, t))
+
+    def start(self) -> None:
+        """Start the base-class server, then every worker thread."""
+        super().start()
+        for _, t in self._tasks:
+            t.start()
+
+        print("rpc.Server started")
+
+    def stop(self) -> None:
+        """Stop the base-class server, shut the queues down, join workers."""
+        super().stop()
+        for q, t in self._tasks:
+            q.shutdown()
+            # join() raises RuntimeError on a thread that was never started,
+            # which happens when stop() is called without a prior start().
+            if t.is_alive():
+                t.join()
+
+        print("rpc.Server stopped")
+
+    def _process(self, queue: _rpc.ComputationQueue,
+                 func: Callable[..., Any]) -> None:
+        """Worker loop: run ``func`` on tasks until the queue is exhausted.
+
+        The loop ends when ``queue.get()`` returns ``None`` or raises
+        ``StopIteration``.
+        """
+        try:
+            while True:
+                task = queue.get()
+                if task is None:
+                    break
+                self._wrap_func(task, func)
+        except StopIteration:
+            return
+
+    def _wrap_func(self, task: _rpc.Task, func: Callable[..., Any]) -> None:
+        """Run ``func`` on one task and hand the result back to the task.
+
+        For a ``BatchedTask`` the arguments are stacked before the call and
+        the result is unstacked into ``batch_size`` pieces afterwards; a
+        ``None`` result is expanded to one ``None`` per batch entry.
+        """
+        batch_size = None
+        args = task.args()
+        kwargs = task.kwargs()
+        if isinstance(task, _rpc.BatchedTask):
+            batch_size = task.batch_size
+            args = data_utils.stack_fields(args)
+            kwargs = data_utils.stack_fields(kwargs)
+
+        # Lock to protect any state inside func.
+        # TODO: Find a better way to do this (e.g. asyncio)
+        # NOTE(review): an exception raised by func leaves this method
+        # before set_return_value() is reached - confirm how callers recover.
+        with self._lock:
+            ret = func(*args, **kwargs)
+
+        if batch_size is not None:
+            if ret is None:
+                ret = (None,) * batch_size
+            else:
+                ret = data_utils.unstack_fields(ret, batch_size)
+
+        task.set_return_value(ret)
diff --git a/rlmeta/utils/stats_dict.py b/rlmeta/utils/stats_dict.py
index 15d2832..7d91e24 100644
--- a/rlmeta/utils/stats_dict.py
+++ b/rlmeta/utils/stats_dict.py
@@ -23,6 +23,17 @@ def __init__(self,
         if val is not None:
             self.add(val)
 
+    @classmethod
+    def from_dict(cls, data: Dict[str, float]) -> StatsItem:
+        ret = cls(key=data.get("key", None))
+        ret._m0 = data.get("count", 0)
+        ret._m1 = data.get("mean", 0.0)
+        std = data.get("std", 0.0)
+        ret._m2 = std * std * ret._m0
+        ret._min_val = data.get("min", float("inf"))
+        ret._max_val = data.get("max", float("-inf"))
+        return ret
+
     @property
     def key(self) -> str:
         return self._key
@@ -84,6 +95,12 @@ def __init__(self) -> None:
     def __getitem__(self, key: str) -> StatsItem:
         return self._dict[key]
 
+    @classmethod
+    def from_dict(cls, data: Dict[str, Dict[str, float]]) -> StatsDict:
+        ret = cls()
+        ret._dict = {k: StatsItem.from_dict(v) for k, v in data.items()}
+        return ret
+
     def reset(self):
         self._dict.clear()
 
@@ -100,7 +117,7 @@ def extend(self, d: Dict[str, float]) -> None:
     def update(self, stats: StatsDict) -> None:
         self._dict.update(stats._dict)
 
-    def dict(self) -> Dict[str, float]:
+    def dict(self) -> Dict[str, Dict[str, float]]:
         return {k: v.dict() for k, v in self._dict.items()}
 
     def json(self, info: Optional[str] = None, **kwargs) -> str:
diff --git a/third_party/pybind11 b/third_party/pybind11
index d4b9f34..1e3400b 160000
--- a/third_party/pybind11
+++ b/third_party/pybind11
@@ -1 +1 @@
-Subproject commit d4b9f3471f465f0cc6d05556a837c26589b08b29
+Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae