From 91f805f544788b4e3e048420b33cac08045c0c6c Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Wed, 27 Jul 2022 23:57:07 -0700 Subject: [PATCH 1/9] Implement gRPC based backend to resolve moolib performance issue --- .gitignore | 2 +- rlmeta/CMakeLists.txt | 26 ++++++- rlmeta/cc/pybind.cc | 12 ++++ rlmeta/rpc/CMakeLists.txt | 92 ++++++++++++++++++++++++ rlmeta/rpc/__init__.py | 4 ++ rlmeta/rpc/cc/blocking_counter.cc | 28 ++++++++ rlmeta/rpc/cc/blocking_counter.h | 38 ++++++++++ rlmeta/rpc/cc/computation_queue.cc | 64 +++++++++++++++++ rlmeta/rpc/cc/computation_queue.h | 69 ++++++++++++++++++ rlmeta/rpc/cc/queue_impl.h | 109 +++++++++++++++++++++++++++++ rlmeta/rpc/cc/server.cc | 65 +++++++++++++++++ rlmeta/rpc/cc/server.h | 74 ++++++++++++++++++++ rlmeta/rpc/cc/task.cc | 71 +++++++++++++++++++ rlmeta/rpc/cc/task.h | 104 +++++++++++++++++++++++++++ rlmeta/rpc/protos/rpc.proto | 27 +++++++ rlmeta/rpc/rpc_pb2.py | 57 +++++++++++++++ rlmeta/rpc/rpc_pb2_grpc.py | 66 +++++++++++++++++ rlmeta/utils/data_utils.py | 2 +- 18 files changed, 905 insertions(+), 5 deletions(-) create mode 100644 rlmeta/rpc/CMakeLists.txt create mode 100644 rlmeta/rpc/__init__.py create mode 100644 rlmeta/rpc/cc/blocking_counter.cc create mode 100644 rlmeta/rpc/cc/blocking_counter.h create mode 100644 rlmeta/rpc/cc/computation_queue.cc create mode 100644 rlmeta/rpc/cc/computation_queue.h create mode 100644 rlmeta/rpc/cc/queue_impl.h create mode 100644 rlmeta/rpc/cc/server.cc create mode 100644 rlmeta/rpc/cc/server.h create mode 100644 rlmeta/rpc/cc/task.cc create mode 100644 rlmeta/rpc/cc/task.h create mode 100644 rlmeta/rpc/protos/rpc.proto create mode 100644 rlmeta/rpc/rpc_pb2.py create mode 100644 rlmeta/rpc/rpc_pb2_grpc.py diff --git a/.gitignore b/.gitignore index d96fadd..02aadb5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,5 @@ **/outputs/* *.egg-info/* *.eggs/* -*.so +*.so* compile_commands.json diff --git a/rlmeta/CMakeLists.txt b/rlmeta/CMakeLists.txt index 0cf1b2f..64415e2 100644 --- a/rlmeta/CMakeLists.txt +++ b/rlmeta/CMakeLists.txt @@ -9,7 +9,7 @@ set( -march=native -Wfatal-errors -fvisibility=hidden" ) -# set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) # PyTorch dependency @@ -48,13 +48,33 @@ add_subdirectory( ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11 ) +add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/rpc + ${CMAKE_CURRENT_BINARY_DIR}/rpc +) + pybind11_add_module( _rlmeta_extension ${CMAKE_CURRENT_SOURCE_DIR}/cc/circular_buffer.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/nested_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/pybind.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/timestamp_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/blocking_counter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/computation_queue.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/server.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/task.cc ) target_include_directories( - _rlmeta_extension PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) -target_link_libraries(_rlmeta_extension PUBLIC torch ${TORCH_PYTHON_LIBRARIES}) + _rlmeta_extension + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_BINARY_DIR}/rpc +) +target_link_libraries( + _rlmeta_extension + PUBLIC + torch + ${TORCH_PYTHON_LIBRARIES} + grpc++ + rpc_grpc_proto +) diff --git a/rlmeta/cc/pybind.cc b/rlmeta/cc/pybind.cc index 0f69fca..598ff73 100644 --- a/rlmeta/cc/pybind.cc +++ b/rlmeta/cc/pybind.cc @@ -9,6 +9,9 @@ #include "rlmeta/cc/nested_utils.h" #include "rlmeta/cc/segment_tree.h" #include "rlmeta/cc/timestamp_manager.h" +#include "rlmeta/rpc/cc/computation_queue.h" +#include "rlmeta/rpc/cc/server.h" +#include "rlmeta/rpc/cc/task.h" namespace py = pybind11; @@ -23,6 +26,15 @@ PYBIND11_MODULE(_rlmeta_extension, m) { rlmeta::DefineCircularBuffer(m); rlmeta::DefineNestedUtils(m); rlmeta::DefineTimestampManager(m); + + rlmeta::rpc::DefineTaskBase(m); + rlmeta::rpc::DefineTask(m); + rlmeta::rpc::DefineBatchedTask(m); + + rlmeta::rpc::DefineComputationQueue(m); + rlmeta::rpc::DefineBatchedComputationQueue(m); + + rlmeta::rpc::DefineServer(m); } } // namespace diff --git a/rlmeta/rpc/CMakeLists.txt b/rlmeta/rpc/CMakeLists.txt new file mode 100644 index 0000000..9cd2ec3 --- /dev/null +++ b/rlmeta/rpc/CMakeLists.txt @@ -0,0 +1,92 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(rlmeta) + +set(CMAKE_CXX_STANDARD 17) +set( + CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -O3 -Wall -Wextra -Wno-register -Wno-comment -fPIC \ + -march=native -Wfatal-errors -fvisibility=hidden" +) + +include(FetchContent) + +# gRPC dependency +FetchContent_Declare( + grpc + GIT_REPOSITORY https://github.com/grpc/grpc + GIT_TAG v1.47.0 +) + +set(FETCHCONTENT_QUIET OFF) +FetchContent_MakeAvailable(grpc) + +set(_PROTOBUF_LIBPROTOBUF libprotobuf) +set(_REFLECTION grpc++_reflection) +set(_PROTOBUF_PROTOC $) +set(_GRPC_GRPCPP grpc++) +if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) +else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) +endif() + +# proto files +get_filename_component(rpc_proto "${CMAKE_CURRENT_SOURCE_DIR}/protos/rpc.proto" ABSOLUTE) +get_filename_component(rpc_proto_path "${rpc_proto}" PATH) + +# generated files +set(rpc_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/rpc.pb.cc") +set(rpc_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/rpc.pb.h") +set(rpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/rpc.grpc.pb.cc") +set(rpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/rpc.grpc.pb.h") +add_custom_command( + OUTPUT "${rpc_proto_srcs}" "${rpc_proto_hdrs}" "${rpc_grpc_srcs}" "${rpc_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${rpc_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${rpc_proto}" + DEPENDS "${rpc_proto}" +) + +# rpc_grpc_proto +add_library( + rpc_grpc_proto + ${rpc_grpc_srcs} + ${rpc_grpc_hdrs} + ${rpc_proto_srcs} + ${rpc_proto_hdrs} +) +target_include_directories( + rpc_grpc_proto + PUBLIC + ${CMAKE_CURRENT_BINARY_DIR} +) +target_link_libraries( + rpc_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF} +) + + +# pybind11_add_module( +# _rlmeta_rpc +# ${CMAKE_CURRENT_SOURCE_DIR}/pybind.cc +# ${CMAKE_CURRENT_SOURCE_DIR}/server.cc +# ) +# target_include_directories( +# _rlmeta_rpc +# PUBLIC +# ${CMAKE_CURRENT_BINARY_DIR} +# ${CMAKE_CURRENT_SOURCE_DIR} +# ${CMAKE_CURRENT_SOURCE_DIR}/../.. +# ) +# target_link_libraries( +# _rlmeta_rpc +# PUBLIC +# grpc++ +# rpc_grpc_proto +# ) diff --git a/rlmeta/rpc/__init__.py b/rlmeta/rpc/__init__.py new file mode 100644 index 0000000..7bec24c --- /dev/null +++ b/rlmeta/rpc/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/rlmeta/rpc/cc/blocking_counter.cc b/rlmeta/rpc/cc/blocking_counter.cc new file mode 100644 index 0000000..ae517bd --- /dev/null +++ b/rlmeta/rpc/cc/blocking_counter.cc @@ -0,0 +1,28 @@ +#include "rlmeta/rpc/cc/blocking_counter.h" + +#include + +namespace rlmeta { +namespace rpc { + +bool BlockingCounter::DecrementCount() { + bool ret = false; + { + std::unique_lock lk(mu_); + assert(count_ >= 1); + --count_; + ret = (count_ == 0); + } + cv_.notify_all(); + return ret; +} + +void BlockingCounter::Wait() { + std::unique_lock lk(mu_); + assert(num_waiting_ == 0); + ++num_waiting_; + cv_.wait(lk, [this]() { return count_ == 0; }); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/blocking_counter.h b/rlmeta/rpc/cc/blocking_counter.h new file mode 100644 index 0000000..d1b4a6c --- /dev/null +++ b/rlmeta/rpc/cc/blocking_counter.h @@ -0,0 +1,38 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include + +namespace rlmeta { +namespace rpc { + +// https://github.com/abseil/abseil-cpp/blob/ce42de10fbea616379826e91c7c23c16bffe6e61/absl/synchronization/blocking_counter.h + +class BlockingCounter { + public: + explicit BlockingCounter(int64_t initial_count) : count_(initial_count) {} + + BlockingCounter(const BlockingCounter&) = delete; + BlockingCounter& operator=(const BlockingCounter&) = delete; + + bool DecrementCount(); + + void Wait(); + + private: + int64_t count_; + int64_t num_waiting_ = 0; + + std::mutex mu_; + std::condition_variable cv_; +}; + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/computation_queue.cc b/rlmeta/rpc/cc/computation_queue.cc new file mode 100644 index 0000000..8f2982f --- /dev/null +++ b/rlmeta/rpc/cc/computation_queue.cc @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/computation_queue.h" + +namespace rlmeta { +namespace rpc { + +std::future BatchedComputationQueue::Put( + const std::string& args, const std::string& kwargs) { + std::scoped_lock lk(mu_); + if (cur_computation_ == nullptr) { + cur_computation_ = std::make_shared(batch_size_); + queue_impl_.Put(cur_computation_); + } + std::future ret = cur_computation_->Add(args, kwargs); + if (cur_computation_->Full()) { + cur_computation_.reset(); + } + return ret; +} + +std::shared_ptr BatchedComputationQueue::Get() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); + if (ret != nullptr) { + std::scoped_lock lk(mu_); + if (!dynamic_cast(ret.get())->Full()) { + cur_computation_.reset(); + } + } + return ret; +} + +std::shared_ptr BatchedComputationQueue::GetFullBatch() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); + if (ret != nullptr) { + dynamic_cast(ret.get())->Wait(); + } + return ret; +} + +void DefineComputationQueue(py::module& m) { + py::class_>( + m, "ComputationQueue") + .def(py::init<>()) + .def("get", &ComputationQueue::Get, + py::call_guard()) + .def("shutdown", &ComputationQueue::Shutdown, + py::call_guard()); +} + +void DefineBatchedComputationQueue(py::module& m) { + py::class_>( + m, "BatchedComputationQueue") + .def(py::init()) + .def("get_full_batch", &BatchedComputationQueue::GetFullBatch, + py::call_guard()); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/computation_queue.h b/rlmeta/rpc/cc/computation_queue.h new file mode 100644 index 0000000..dd961b4 --- /dev/null +++ b/rlmeta/rpc/cc/computation_queue.h @@ -0,0 +1,69 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include + +#include "rlmeta/rpc/cc/queue_impl.h" +#include "rlmeta/rpc/cc/task.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class ComputationQueue { + public: + ComputationQueue() = default; + explicit ComputationQueue(int64_t capacity) : queue_impl_(capacity) {} + + virtual std::future Put(const std::string& args, + const std::string& kwargs) { + std::shared_ptr task = std::make_shared(args, kwargs); + queue_impl_.Put(task); + return task->Future(); + } + + virtual std::shared_ptr Get() { + return queue_impl_.Get().value_or(nullptr); + } + + virtual void Shutdown() { queue_impl_.Shutdown(); } + + protected: + QueueImpl> queue_impl_; +}; + +class BatchedComputationQueue : public ComputationQueue { + public: + explicit BatchedComputationQueue(int64_t batch_size) + : batch_size_(batch_size) {} + BatchedComputationQueue(int64_t capacity, int64_t batch_size) + : ComputationQueue(capacity), batch_size_(batch_size) {} + + std::future Put(const std::string& args, + const std::string& kwargs) override; + + std::shared_ptr Get() override; + + std::shared_ptr GetFullBatch(); + + protected: + const int64_t batch_size_; + std::shared_ptr cur_computation_ = nullptr; + + std::mutex mu_; +}; + +void DefineComputationQueue(py::module& m); +void DefineBatchedComputationQueue(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/queue_impl.h b/rlmeta/rpc/cc/queue_impl.h new file mode 100644 index 0000000..a350c54 --- /dev/null +++ b/rlmeta/rpc/cc/queue_impl.h @@ -0,0 +1,109 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include + +namespace rlmeta { +namespace rpc { + +template +class QueueImpl { + public: + QueueImpl() = default; + explicit QueueImpl(int64_t capacity) : capacity_(capacity){}; + + int64_t Size() const { + std::scoped_lock lk(mu_); + return data_.size(); + } + + bool Empty() const { + std::scoped_lock lk(mu_); + return data_.empty(); + } + + bool Full() const { + std::scoped_lock lk(mu_); + return data_.size() == capacity_; + } + + bool Put(const T& o) { + { + std::unique_lock lk(mu_); + can_put_.wait(lk, [this]() { + return !is_alive_ || static_cast(data_.size()) < capacity_; + }); + if (!is_alive_) { + return false; + } + data_.push_back(o); + } + can_get_.notify_one(); + return true; + } + + bool Put(T&& o) { + { + std::unique_lock lk(mu_); + can_put_.wait(lk, [this]() { + return !is_alive_ || static_cast(data_.size()) < capacity_; + }); + if (!is_alive_) { + return false; + } + data_.push_back(std::move(o)); + } + can_get_.notify_one(); + return true; + } + + std::optional Get() { + std::optional ret = [this]() -> std::optional { + std::unique_lock lk(mu_); + can_get_.wait(lk, [this]() { return !is_alive_ || !data_.empty(); }); + if (!is_alive_) { + return std::nullopt; + } + T ret = std::move(data_.front()); + data_.pop_front(); + return std::make_optional(ret); + }(); + if (ret.has_value()) { + can_put_.notify_one(); + } + return ret; + } + + void Shutdown() { + { + std::scoped_lock lk(mu_); + if (!is_alive_) { + return; + } + is_alive_ = false; + data_.clear(); + } + can_put_.notify_all(); + can_get_.notify_all(); + } + + protected: + const int64_t capacity_ = 1024; + std::deque data_; + bool is_alive_ = true; + + std::mutex mu_; + std::condition_variable can_put_; + std::condition_variable can_get_; +}; + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/server.cc b/rlmeta/rpc/cc/server.cc new file mode 100644 index 0000000..0b845c6 --- /dev/null +++ b/rlmeta/rpc/cc/server.cc @@ -0,0 +1,65 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/server.h" + +#include +#include + +#include +#include + +namespace rlmeta { +namespace rpc { + +grpc::Status ServiceImpl::RemoteCall(grpc::ServerContext* /*context*/, + const RpcRequest* request, + RpcResponse* response) { + auto& func = functions_.at(request->function()); + response->set_return_value(func(request->args(), request->kwargs())); + return grpc::Status::OK; +} + +void Server::Start() { + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + grpc::ServerBuilder builder; + builder.AddListeningPort(addr_, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); +} + +void Server::Stop() { + server_->Shutdown(); + server_->Wait(); +} + +std::shared_ptr Server::RegisterQueue( + const std::string& func_name, int64_t batch_size) { + std::shared_ptr ret = nullptr; + if (batch_size == 0) { + ret = std::make_shared(); + } else { + ret = std::make_shared(batch_size); + } + service_.Register(func_name, [que = ret](const std::string& args, + const std::string& kwargs) { + std::future ret = que->Put(args, kwargs); + return ret.get(); + }); + return ret; +} + +void DefineServer(py::module& m) { + py::class_>(m, "Server") + .def(py::init()) + .def("start", &Server::Start) + .def("stop", &Server::Stop, py::call_guard()) + .def("register_queue", &Server::RegisterQueue, py::arg("func_name"), + py::arg("batch_size") = 0); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/server.h b/rlmeta/rpc/cc/server.h new file mode 100644 index 0000000..cea3802 --- /dev/null +++ b/rlmeta/rpc/cc/server.h @@ -0,0 +1,74 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/computation_queue.h" +#include "rlmeta/rpc/cc/task.h" +#include "rpc.grpc.pb.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +using PyFunc = + std::function; +using PyFuncDict = std::unordered_map; + +class ServiceImpl final : public Rpc::Service { + public: + grpc::Status Register(const std::string& func_name, PyFunc&& func_impl) { + functions_.emplace(func_name, std::move(func_impl)); + return grpc::Status::OK; + } + + private: + grpc::Status RemoteCall(grpc::ServerContext* context, + const RpcRequest* request, + RpcResponse* response) override; + + PyFuncDict functions_; +}; + +class Server { + public: + explicit Server(const std::string& addr) : addr_(addr) {} + + ~Server() { Stop(); } + + void Start(); + void Stop(); + + std::shared_ptr RegisterQueue(const std::string& func_name, + int64_t batch_size = 0); + + protected: + void ServePyFuncQueue(const std::string& func_name); + + void ServeImpl(const std::string& func_name, TaskBase& task); + + const std::string addr_; + std::unique_ptr server_; + ServiceImpl service_; +}; + +void DefineServer(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/task.cc b/rlmeta/rpc/cc/task.cc new file mode 100644 index 0000000..3bc2562 --- /dev/null +++ b/rlmeta/rpc/cc/task.cc @@ -0,0 +1,71 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/task.h" + +#include + +namespace rlmeta { +namespace rpc { + +py::object BatchedTask::Args() { + const int64_t batch_size = batch_.size(); + py::tuple ret(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + ret[i] = batch_[i].Args(); + } + return ret; +} + +py::object BatchedTask::Kwargs() { + const int64_t batch_size = batch_.size(); + py::tuple ret(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + ret[i] = batch_[i].Kwargs(); + } + return ret; +} + +void BatchedTask::SetReturnValue(py::object&& return_value) { + const int64_t batch_size = batch_.size(); + py::tuple rets = py::reinterpret_borrow(return_value); + assert(rets.size() == batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + batch_[i].SetReturnValue(std::move(rets[i])); + } +} + +std::future BatchedTask::Add(const std::string& args, + const std::string& kwargs) { + const int64_t batch_size = batch_.size(); + assert(batch_size < capacity_); + auto& task = batch_.emplace_back(args, kwargs); + num_to_wait_.DecrementCount(); + return task.Future(); +} + +void DefineTaskBase(py::module& m) { + py::class_>(m, "TaskBase") + .def("args", &TaskBase::Args) + .def("kwargs", &TaskBase::Kwargs) + .def("set_return_value", &TaskBase::SetReturnValue); +} + +void DefineTask(py::module& m) { + py::class_>(m, "Task"); +} + +void DefineBatchedTask(py::module& m) { + py::class_>(m, + "BatchedTask") + .def("__len__", &BatchedTask::BatchSize) + .def_property_readonly("capacity", &BatchedTask::capacity) + .def_property_readonly("batch_size", &BatchedTask::BatchSize) + .def("empty", &BatchedTask::Empty) + .def("full", &BatchedTask::Full); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/task.h b/rlmeta/rpc/cc/task.h new file mode 100644 index 0000000..8fd2cb8 --- /dev/null +++ b/rlmeta/rpc/cc/task.h @@ -0,0 +1,104 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/blocking_counter.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class TaskBase { + public: + virtual py::object Args() = 0; + virtual py::object Kwargs() = 0; + virtual void SetReturnValue(py::object&& return_value) = 0; +}; + +// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtuals +class PyTaskBase : public TaskBase { + public: + using TaskBase::TaskBase; + + py::object Args() override { + PYBIND11_OVERRIDE_PURE(py::object, TaskBase, Args); + } + + py::object Kwargs() override { + PYBIND11_OVERRIDE_PURE(py::object, TaskBase, Kwargs); + } + + void SetReturnValue(py::object&& return_value) override { + PYBIND11_OVERRIDE_PURE(void, TaskBase, SetReturnValue, return_value); + } +}; + +class Task : public TaskBase { + public: + Task(const std::string& args, const std::string& kwargs) + : args_(args), kwargs_(kwargs) {} + + py::object Args() override { return py::bytes(std::move(args_)); } + py::object Kwargs() override { return py::bytes(std::move(kwargs_)); } + + std::future Future() { return promise_.get_future(); } + + void SetReturnValue(py::object&& return_value) override { + promise_.set_value( + py::reinterpret_borrow(std::move(return_value))); + } + + protected: + std::string args_; + std::string kwargs_; + std::promise promise_; +}; + +class BatchedTask : public TaskBase { + public: + explicit BatchedTask(int64_t capacity) + : capacity_(capacity), num_to_wait_(capacity) { + batch_.reserve(capacity_); + } + + int64_t capacity() const { return capacity_; } + int64_t BatchSize() const { return batch_.size(); } + + bool Empty() const { return batch_.empty(); } + bool Full() const { return static_cast(batch_.size()) == capacity_; } + + py::object Args() override; + py::object Kwargs() override; + + void SetReturnValue(py::object&& return_value) override; + + std::future Add(const std::string& args, + const std::string& kwargs); + + void Wait() { num_to_wait_.Wait(); } + + protected: + const int64_t capacity_; + std::vector batch_; + + BlockingCounter num_to_wait_; +}; + +void DefineTaskBase(py::module& m); +void DefineTask(py::module& m); +void DefineBatchedTask(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/protos/rpc.proto b/rlmeta/rpc/protos/rpc.proto new file mode 100644 index 0000000..b3824b7 --- /dev/null +++ b/rlmeta/rpc/protos/rpc.proto @@ -0,0 +1,27 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +syntax = "proto3"; + +package rlmeta.rpc; + +service Rpc { + rpc RemoteCall(RpcRequest) returns (RpcResponse) {} +} + +message Error { + string message = 1; +} + +message RpcRequest { + string function = 1; + bytes args = 2; + bytes kwargs = 3; +} + +message RpcResponse { + bytes return_value = 1; + optional Error error = 2; +} diff --git a/rlmeta/rpc/rpc_pb2.py b/rlmeta/rpc/rpc_pb2.py new file mode 100644 index 0000000..0ca2add --- /dev/null +++ b/rlmeta/rpc/rpc_pb2.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: rpc.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\nrlmeta.rpc\"\x18\n\x05\x45rror\x12\x0f\n\x07message\x18\x01 \x01(\t\"<\n\nRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\x12\x0e\n\x06kwargs\x18\x03 \x01(\x0c\"T\n\x0bRpcResponse\x12\x14\n\x0creturn_value\x18\x01 \x01(\x0c\x12%\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.rlmeta.rpc.ErrorH\x00\x88\x01\x01\x42\x08\n\x06_error2F\n\x03Rpc\x12?\n\nRemoteCall\x12\x16.rlmeta.rpc.RpcRequest\x1a\x17.rlmeta.rpc.RpcResponse\"\x00\x62\x06proto3') + + + +_ERROR = DESCRIPTOR.message_types_by_name['Error'] +_RPCREQUEST = DESCRIPTOR.message_types_by_name['RpcRequest'] +_RPCRESPONSE = DESCRIPTOR.message_types_by_name['RpcResponse'] +Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), { + 'DESCRIPTOR' : _ERROR, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.Error) + }) +_sym_db.RegisterMessage(Error) + +RpcRequest = _reflection.GeneratedProtocolMessageType('RpcRequest', (_message.Message,), { + 'DESCRIPTOR' : _RPCREQUEST, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcRequest) + }) +_sym_db.RegisterMessage(RpcRequest) + +RpcResponse = _reflection.GeneratedProtocolMessageType('RpcResponse', (_message.Message,), { + 'DESCRIPTOR' : _RPCRESPONSE, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcResponse) + }) +_sym_db.RegisterMessage(RpcResponse) + +_RPC = DESCRIPTOR.services_by_name['Rpc'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _ERROR._serialized_start=25 + _ERROR._serialized_end=49 + _RPCREQUEST._serialized_start=51 + _RPCREQUEST._serialized_end=111 + _RPCRESPONSE._serialized_start=113 + _RPCRESPONSE._serialized_end=197 + _RPC._serialized_start=199 + _RPC._serialized_end=269 +# @@protoc_insertion_point(module_scope) diff --git a/rlmeta/rpc/rpc_pb2_grpc.py b/rlmeta/rpc/rpc_pb2_grpc.py new file mode 100644 index 0000000..4393220 --- /dev/null +++ b/rlmeta/rpc/rpc_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import rlmeta.rpc.rpc_pb2 as rpc__pb2 + + +class RpcStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.RemoteCall = channel.unary_unary( + '/rlmeta.rpc.Rpc/RemoteCall', + request_serializer=rpc__pb2.RpcRequest.SerializeToString, + response_deserializer=rpc__pb2.RpcResponse.FromString, + ) + + +class RpcServicer(object): + """Missing associated documentation comment in .proto file.""" + + def RemoteCall(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RpcServicer_to_server(servicer, server): + rpc_method_handlers = { + 'RemoteCall': grpc.unary_unary_rpc_method_handler( + servicer.RemoteCall, + request_deserializer=rpc__pb2.RpcRequest.FromString, + response_serializer=rpc__pb2.RpcResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'rlmeta.rpc.Rpc', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Rpc(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def RemoteCall(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/RemoteCall', + rpc__pb2.RpcRequest.SerializeToString, + rpc__pb2.RpcResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/rlmeta/utils/data_utils.py b/rlmeta/utils/data_utils.py index fba239c..45406e8 100644 --- a/rlmeta/utils/data_utils.py +++ b/rlmeta/utils/data_utils.py @@ -94,7 +94,7 @@ def stack_fields(input: Sequence[NestedTensor]) -> NestedTensor: def unstack_fields(input: NestedTensor, batch_size: int) -> Tuple[NestedTensor, ...]: if batch_size == 1: - return nested_utils.map_nested(lambda x: x.squeeze(0), input) + return (nested_utils.map_nested(lambda x: x.squeeze(0), input),) else: return nested_utils.unbatch_nested(lambda x: torch.unbind(x), input, batch_size) From d6b8d20573e665fde9282059031b3fde1be82321 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 28 Jul 2022 13:49:36 -0700 Subject: [PATCH 2/9] Add python grpc server and client --- rlmeta/cc/nested_utils.cc | 5 +- rlmeta/cc/pybind.cc | 21 ++++---- rlmeta/rpc/__init__.py | 16 ++++++ rlmeta/rpc/cc/blocking_counter.cc | 5 ++ rlmeta/rpc/cc/blocking_counter.h | 17 +++++- rlmeta/rpc/cc/server.cc | 1 + rlmeta/rpc/cc/server.h | 3 +- rlmeta/rpc/client.py | 48 +++++++++++++++++ rlmeta/rpc/rpc_pb2.py | 6 +++ rlmeta/rpc/rpc_pb2_grpc.py | 5 ++ rlmeta/rpc/server.py | 90 +++++++++++++++++++++++++++++++ 11 files changed, 202 insertions(+), 15 deletions(-) create mode 100644 rlmeta/rpc/client.py create mode 100644 rlmeta/rpc/server.py diff --git a/rlmeta/cc/nested_utils.cc b/rlmeta/cc/nested_utils.cc index fc6600d..f1418f5 100644 --- a/rlmeta/cc/nested_utils.cc +++ b/rlmeta/cc/nested_utils.cc @@ -201,10 +201,7 @@ py::tuple UnbatchNested(std::function func, } // namespace nested_utils void DefineNestedUtils(py::module& m) { - py::module sub = - m.def_submodule("nested_utils", "A submodule of \"_rlmeta_extension\""); - - sub.def("flatten_nested", &nested_utils::FlattenNested) + m.def("flatten_nested", &nested_utils::FlattenNested) .def("map_nested", &nested_utils::MapNested) .def("collate_nested", py::overload_cast, diff --git a/rlmeta/cc/pybind.cc b/rlmeta/cc/pybind.cc index 598ff73..c318139 100644 --- a/rlmeta/cc/pybind.cc +++ b/rlmeta/cc/pybind.cc @@ -24,17 +24,20 @@ PYBIND11_MODULE(_rlmeta_extension, m) { rlmeta::DefineMinSegmentTree("Fp64", m); rlmeta::DefineCircularBuffer(m); - rlmeta::DefineNestedUtils(m); rlmeta::DefineTimestampManager(m); - rlmeta::rpc::DefineTaskBase(m); - rlmeta::rpc::DefineTask(m); - rlmeta::rpc::DefineBatchedTask(m); - - rlmeta::rpc::DefineComputationQueue(m); - rlmeta::rpc::DefineBatchedComputationQueue(m); - - rlmeta::rpc::DefineServer(m); + py::module nested_utils = m.def_submodule( + "nested_utils", "A submodule of \"_rlmeta_extension\" for nested_utils"); + rlmeta::DefineNestedUtils(nested_utils); + + py::module rpc = + m.def_submodule("rpc", "A submodule of \"_rlmeta_extension\" for RPC"); + rlmeta::rpc::DefineTaskBase(rpc); + rlmeta::rpc::DefineTask(rpc); + rlmeta::rpc::DefineBatchedTask(rpc); + rlmeta::rpc::DefineComputationQueue(rpc); + rlmeta::rpc::DefineBatchedComputationQueue(rpc); + rlmeta::rpc::DefineServer(rpc); } } // namespace diff --git a/rlmeta/rpc/__init__.py b/rlmeta/rpc/__init__.py index 7bec24c..14cadb1 100644 --- a/rlmeta/rpc/__init__.py +++ b/rlmeta/rpc/__init__.py @@ -2,3 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from _rlmeta_extension.rpc import ComputationQueue, BatchedComputationQueue +from _rlmeta_extension.rpc import TaskBase, Task, BatchedTask + +from rlmeta.rpc.client import Client +from rlmeta.rpc.server import Server + +__all__ = [ + "ComputationQueue", + "BatchedComputationQueue", + "TaskBase", + "Task", + "BatchedTask", + "Client", + "Server", +] diff --git a/rlmeta/rpc/cc/blocking_counter.cc b/rlmeta/rpc/cc/blocking_counter.cc index ae517bd..4ee5a75 100644 --- a/rlmeta/rpc/cc/blocking_counter.cc +++ b/rlmeta/rpc/cc/blocking_counter.cc @@ -1,3 +1,8 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + #include "rlmeta/rpc/cc/blocking_counter.h" #include diff --git a/rlmeta/rpc/cc/blocking_counter.h b/rlmeta/rpc/cc/blocking_counter.h index d1b4a6c..5bc70a9 100644 --- a/rlmeta/rpc/cc/blocking_counter.h +++ b/rlmeta/rpc/cc/blocking_counter.h @@ -5,7 +5,6 @@ #pragma once -#include #include #include #include @@ -13,7 +12,23 @@ namespace rlmeta { namespace rpc { +// BlockingCounter implementation is copied and modified from +// the BlockingCounter class in Abseil. // https://github.com/abseil/abseil-cpp/blob/ce42de10fbea616379826e91c7c23c16bffe6e61/absl/synchronization/blocking_counter.h +// +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. class BlockingCounter { public: diff --git a/rlmeta/rpc/cc/server.cc b/rlmeta/rpc/cc/server.cc index 0b845c6..15cac5c 100644 --- a/rlmeta/rpc/cc/server.cc +++ b/rlmeta/rpc/cc/server.cc @@ -55,6 +55,7 @@ std::shared_ptr Server::RegisterQueue( void DefineServer(py::module& m) { py::class_>(m, "Server") .def(py::init()) + .def_property_readonly("addr", &Server::addr) .def("start", &Server::Start) .def("stop", &Server::Stop, py::call_guard()) .def("register_queue", &Server::RegisterQueue, py::arg("func_name"), diff --git a/rlmeta/rpc/cc/server.h b/rlmeta/rpc/cc/server.h index cea3802..840813c 100644 --- a/rlmeta/rpc/cc/server.h +++ b/rlmeta/rpc/cc/server.h @@ -49,9 +49,10 @@ class ServiceImpl final : public Rpc::Service { class Server { public: explicit Server(const std::string& addr) : addr_(addr) {} - ~Server() { Stop(); } + const std::string& addr() const { return addr_; } + void Start(); void Stop(); diff --git a/rlmeta/rpc/client.py b/rlmeta/rpc/client.py new file mode 100644 index 0000000..7761f24 --- /dev/null +++ b/rlmeta/rpc/client.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pickle + +from typing import Any + +import grpc +import grpc.experimental + +import rlmeta.rpc.rpc_pb2 as rpc_pb2 +import rlmeta.rpc.rpc_pb2_grpc as rpc_pb2_grpc + + +class Client: + + def connect(self, addr: str) -> None: + # self._channel = grpc.insecure_channel(addr) + # self._rpc_stub = rpc_pb2_grpc.RpcStub(self._channel) + + self._addr = addr + self._channel_options = [ + (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1) + ] + + def rpc(self, function: str, *args, **kwargs) -> Any: + with grpc.insecure_channel(self._addr, + options=self._channel_options) as channel: + stub = rpc_pb2_grpc.RpcStub(channel) + # ret = self._rpc_stub.RemoteCall( + ret = stub.RemoteCall( + rpc_pb2.RpcRequest(function=function, + args=pickle.dumps(args), + kwargs=pickle.dumps(kwargs))) + return pickle.loads(ret.return_value) + + async def async_rpc(self, function: str, *args, **kwargs) -> Any: + async with grpc.aio.insecure_channel( + self._addr, options=self._channel_options) as channel: + stub = rpc_pb2_grpc.RpcStub(channel) + # ret = await self._rpc_stub.RemoteCall( + ret = await stub.RemoteCall( + rpc_pb2.RpcRequest(function=function, + args=pickle.dumps(args), + kwargs=pickle.dumps(kwargs))) + return pickle.loads(ret.return_value) diff --git a/rlmeta/rpc/rpc_pb2.py b/rlmeta/rpc/rpc_pb2.py index 0ca2add..3afafab 100644 --- a/rlmeta/rpc/rpc_pb2.py +++ b/rlmeta/rpc/rpc_pb2.py @@ -1,4 +1,10 @@ # -*- coding: utf-8 -*- + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + # Generated by the protocol buffer compiler. DO NOT EDIT! # source: rpc.proto """Generated protocol buffer code.""" diff --git a/rlmeta/rpc/rpc_pb2_grpc.py b/rlmeta/rpc/rpc_pb2_grpc.py index 4393220..d46b6ad 100644 --- a/rlmeta/rpc/rpc_pb2_grpc.py +++ b/rlmeta/rpc/rpc_pb2_grpc.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/rlmeta/rpc/server.py b/rlmeta/rpc/server.py new file mode 100644 index 0000000..004f3b9 --- /dev/null +++ b/rlmeta/rpc/server.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +import threading + +from typing import Any, Callable, Optional, NoReturn + +import rlmeta.utils.data_utils as data_utils + +import _rlmeta_extension.rpc as rpc + + +class Server(rpc.Server): + + def __init__(self, addr: str): + super().__init__(addr) + self._tasks = [] + self._lock = threading.Lock() + + @property + def addr(self) -> str: + return super().addr + + def register(self, + func_name: str, + func_impl: Callable[..., Any], + batch_size: Optional[int] = None) -> None: + if batch_size is None: + batch_size = 0 + q = self.register_queue(func_name, batch_size) + t = threading.Thread(target=self._process, args=(q, func_impl)) + self._tasks.append((q, t)) + + def start(self) -> None: + super().start() + for _, t in self._tasks: + t.start() + + print("rpc.Server started") + + def stop(self) -> None: + super().stop() + for q, t in self._tasks: + q.shutdown() + t.join() + + print("rpc.Server stopped") + + def _process(self, queue: rpc.ComputationQueue, + func: Callable[..., Any]) -> NoReturn: + try: + while True: + task = queue.get() + if task is None: + break + self._wrap_func(task, func) + except StopIteration: + return + + def _wrap_func(self, task: rpc.TaskBase, func: Callable[..., Any]) -> None: + batch_size = None + + # TODO: Find better serialization method to replace pickle here. + if isinstance(task, rpc.Task): + args = pickle.loads(task.args()) + kwargs = pickle.loads(task.kwargs()) + else: + batch_size = task.batch_size + args = tuple(pickle.loads(i) for i in task.args()) + args = data_utils.stack_fields(args) + kwargs = tuple(pickle.loads(i) for i in task.kwargs()) + kwargs = data_utils.stack_fields(kwargs) + + # Lock to protect any state inside func. + # TODO: Find a better way to do this (e.g. asyncio) + with self._lock: + ret = func(*args, **kwargs) + + if batch_size is None: + ret = pickle.dumps(ret) + else: + if ret is None: + ret = (pickle.dumps(None),) * batch_size + else: + ret = data_utils.unstack_fields(ret, batch_size) + ret = tuple(pickle.dumps(i) for i in ret) + task.set_return_value(ret) From 70a124f2ce275962a208021e5281c84673146f0c Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 4 Aug 2022 14:45:50 -0700 Subject: [PATCH 3/9] Add C++ grpc server and client --- examples/tutorials/remote_example.py | 7 +- rlmeta/CMakeLists.txt | 4 + rlmeta/agents/agent.py | 6 +- rlmeta/agents/dqn/apex_dqn_agent.py | 4 +- rlmeta/agents/ppo/ppo_agent.py | 4 +- rlmeta/cc/nested_utils.cc | 22 ++-- rlmeta/cc/nested_utils.h | 24 ++++ rlmeta/cc/numpy_utils.h | 5 +- rlmeta/cc/pybind.cc | 11 +- rlmeta/cc/torch_utils.h | 34 +++++ rlmeta/core/controller.py | 4 +- rlmeta/core/model.py | 21 ++-- rlmeta/core/remote.py | 33 +++-- rlmeta/core/replay_buffer.py | 59 +++++---- rlmeta/core/server.py | 116 ++++++++++-------- rlmeta/rpc/CMakeLists.txt | 2 +- rlmeta/rpc/__init__.py | 3 +- rlmeta/rpc/cc/client.cc | 96 +++++++++++++++ rlmeta/rpc/cc/client.h | 84 +++++++++++++ rlmeta/rpc/cc/computation_queue.cc | 29 +++-- rlmeta/rpc/cc/computation_queue.h | 26 ++-- rlmeta/rpc/cc/rpc_future.cc | 32 +++++ rlmeta/rpc/cc/rpc_future.h | 37 ++++++ rlmeta/rpc/cc/rpc_utils.cc | 177 +++++++++++++++++++++++++++ rlmeta/rpc/cc/rpc_utils.h | 38 ++++++ rlmeta/rpc/cc/server.cc | 18 +-- rlmeta/rpc/cc/server.h | 6 +- rlmeta/rpc/cc/task.cc | 71 +++++------ rlmeta/rpc/cc/task.h | 72 +++++------ rlmeta/rpc/cc/tensor_wrapper.cc | 86 +++++++++++++ rlmeta/rpc/cc/tensor_wrapper.h | 45 +++++++ rlmeta/rpc/client.py | 48 +++----- rlmeta/rpc/protos/rpc.proto | 53 +++++++- rlmeta/rpc/rpc_pb2.py | 63 ---------- rlmeta/rpc/rpc_pb2_grpc.py | 71 ----------- rlmeta/rpc/server.py | 24 ++-- rlmeta/utils/stats_dict.py | 19 ++- third_party/pybind11 | 2 +- 38 files changed, 1035 insertions(+), 421 deletions(-) create mode 100644 rlmeta/rpc/cc/client.cc create mode 100644 rlmeta/rpc/cc/client.h create mode 100644 rlmeta/rpc/cc/rpc_future.cc create mode 100644 rlmeta/rpc/cc/rpc_future.h create mode 100644 rlmeta/rpc/cc/rpc_utils.cc create mode 100644 rlmeta/rpc/cc/rpc_utils.h create mode 100644 rlmeta/rpc/cc/tensor_wrapper.cc create mode 100644 rlmeta/rpc/cc/tensor_wrapper.h delete mode 100644 rlmeta/rpc/rpc_pb2.py delete mode 100644 rlmeta/rpc/rpc_pb2_grpc.py diff --git a/examples/tutorials/remote_example.py b/examples/tutorials/remote_example.py index 1b7bcb0..7bddad9 100644 --- a/examples/tutorials/remote_example.py +++ b/examples/tutorials/remote_example.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import asyncio +import time import torch import torch.multiprocessing as mp @@ -62,18 +63,18 @@ def main(): adder = Adder() adder_server = Server(name="adder_server", addr="127.0.0.1:4411") adder_server.add_service(adder) - adder_client = remote_utils.make_remote(adder, adder_server) adder_server.start() + time.sleep(2) adder_client.connect() a = 1 b = 2 c = adder_client.add(a, b) print(f"{a} + {b} = {c}") - print("") - asyncio.run(run_batch(adder_client, send_tensor=False)) + # print("") + # asyncio.run(run_batch(adder_client, send_tensor=False)) print("") asyncio.run(run_batch(adder_client, send_tensor=True)) diff --git a/rlmeta/CMakeLists.txt b/rlmeta/CMakeLists.txt index 64415e2..5d2b743 100644 --- a/rlmeta/CMakeLists.txt +++ b/rlmeta/CMakeLists.txt @@ -60,9 +60,13 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/cc/pybind.cc ${CMAKE_CURRENT_SOURCE_DIR}/cc/timestamp_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/blocking_counter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/client.cc ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/computation_queue.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_future.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/rpc_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/server.cc ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/task.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rpc/cc/tensor_wrapper.cc ) target_include_directories( _rlmeta_extension diff --git a/rlmeta/agents/agent.py b/rlmeta/agents/agent.py index b8ef727..55bdf46 100644 --- a/rlmeta/agents/agent.py +++ b/rlmeta/agents/agent.py @@ -101,7 +101,7 @@ def __call__(self, index: int) -> Agent: return self._cls(*args, **kwargs) def _make_arg(self, arg: Any, index: int) -> Any: - if isinstance(arg, remote.Remote): - arg = copy.deepcopy(arg) - arg.name = moolib_utils.expend_name_by_index(arg.name, index) + # if isinstance(arg, remote.Remote): + # arg = copy.deepcopy(arg) + # arg.name = moolib_utils.expend_name_by_index(arg.name, index) return arg diff --git a/rlmeta/agents/dqn/apex_dqn_agent.py b/rlmeta/agents/dqn/apex_dqn_agent.py index b2bb7f8..fca97af 100644 --- a/rlmeta/agents/dqn/apex_dqn_agent.py +++ b/rlmeta/agents/dqn/apex_dqn_agent.py @@ -163,7 +163,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]: for m in self._additional_models_to_update: m.push() - episode_stats = self.controller.get_stats() + episode_stats = StatsDict.from_dict(self.controller.get_stats()) stats.update(episode_stats) return stats @@ -172,7 +172,7 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]: self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True) while self.controller.get_count() < num_episodes: time.sleep(1) - stats = self.controller.get_stats() + stats = StatsDict.from_dict(self.controller.get_stats()) return stats def make_replay(self) -> Optional[List[NestedTensor]]: diff --git a/rlmeta/agents/ppo/ppo_agent.py b/rlmeta/agents/ppo/ppo_agent.py index d86716a..ce0e641 100644 --- a/rlmeta/agents/ppo/ppo_agent.py +++ b/rlmeta/agents/ppo/ppo_agent.py @@ -156,7 +156,7 @@ def train(self, num_steps: int) -> Optional[StatsDict]: if self.step_counter % self.push_every_n_steps == 0: self.model.push() - episode_stats = self.controller.get_stats() + episode_stats = StatsDict.from_dict(self.controller.get_stats()) stats.update(episode_stats) return stats @@ -165,7 +165,7 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]: self.controller.set_phase(Phase.EVAL, limit=num_episodes, reset=True) while self.controller.get_count() < num_episodes: time.sleep(1) - stats = self.controller.get_stats() + stats = StatsDict.from_dict(self.controller.get_stats()) return stats def device(self) -> torch.device: diff --git a/rlmeta/cc/nested_utils.cc b/rlmeta/cc/nested_utils.cc index f1418f5..1660a6f 100644 --- a/rlmeta/cc/nested_utils.cc +++ b/rlmeta/cc/nested_utils.cc @@ -6,7 +6,6 @@ #include "rlmeta/cc/nested_utils.h" #include -#include namespace rlmeta { @@ -34,8 +33,10 @@ void VisitNestedImpl(Function func, const py::object& obj) { if (py::isinstance(obj)) { const py::dict src = py::reinterpret_borrow(obj); - for (const auto [k, v] : src) { - VisitNestedImpl(func, py::reinterpret_borrow(v)); + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + VisitNestedImpl(func, + py::reinterpret_borrow(src[py::str(k)])); } return; } @@ -68,8 +69,11 @@ py::object MapNestedImpl(Function func, const py::object& obj) { if (py::isinstance(obj)) { const py::dict src = py::reinterpret_borrow(obj); py::dict dst; - for (const auto [k, v] : src) { - dst[k] = MapNestedImpl(func, py::reinterpret_borrow(v)); + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + const py::str key = py::str(k); + dst[key] = + MapNestedImpl(func, py::reinterpret_borrow(src[key])); } return std::move(dst); } @@ -150,12 +154,14 @@ py::tuple UnbatchNestedImpl(std::function func, for (int64_t i = 0; i < batch_size; ++i) { dst[i] = py::dict(); } - for (const auto [k, v] : src) { + const std::vector keys = SortedKeys(src); + for (const std::string& k : keys) { + const py::str key = py::str(k); py::tuple cur = UnbatchNestedImpl( - func, py::reinterpret_borrow(v), batch_size); + func, py::reinterpret_borrow(src[key]), batch_size); for (int64_t i = 0; i < batch_size; ++i) { py::dict y = py::reinterpret_borrow(dst[i]); - y[k] = cur[i]; + y[key] = cur[i]; } } return dst; diff --git a/rlmeta/cc/nested_utils.h b/rlmeta/cc/nested_utils.h index ac1c6e6..dd5c2d1 100644 --- a/rlmeta/cc/nested_utils.h +++ b/rlmeta/cc/nested_utils.h @@ -9,6 +9,8 @@ #include #include +#include +#include namespace py = pybind11; @@ -16,6 +18,28 @@ namespace rlmeta { namespace nested_utils { +template +inline std::vector SortedKeys(const Dict& dict) { + std::vector ret; + ret.reserve(dict.size()); + for (const auto [k, v] : dict) { + ret.push_back(k); + } + std::sort(ret.begin(), ret.end()); + return ret; +} + +template <> +inline std::vector SortedKeys(const py::dict& dict) { + std::vector ret; + ret.reserve(dict.size()); + for (const auto [k, v] : dict) { + ret.push_back(py::reinterpret_borrow(k)); + } + std::sort(ret.begin(), ret.end()); + return ret; +} + py::tuple FlattenNested(const py::object& obj); py::object MapNested(std::function func, diff --git a/rlmeta/cc/numpy_utils.h b/rlmeta/cc/numpy_utils.h index f2b6a85..38e4c0a 100644 --- a/rlmeta/cc/numpy_utils.h +++ b/rlmeta/cc/numpy_utils.h @@ -28,10 +28,7 @@ std::vector NumpyArrayShape(const py::array_t& arr) { template py::array_t NumpyEmptyLike(const py::array_t& src) { - py::array_t dst(src.size()); - const std::vector shape = NumpyArrayShape(src); - dst.resize(shape); - return dst; + return py::array_t(NumpyArrayShape(src)); } } // namespace utils diff --git a/rlmeta/cc/pybind.cc b/rlmeta/cc/pybind.cc index c318139..c36bf5e 100644 --- a/rlmeta/cc/pybind.cc +++ b/rlmeta/cc/pybind.cc @@ -9,7 +9,10 @@ #include "rlmeta/cc/nested_utils.h" #include "rlmeta/cc/segment_tree.h" #include "rlmeta/cc/timestamp_manager.h" +#include "rlmeta/rpc/cc/client.h" #include "rlmeta/rpc/cc/computation_queue.h" +#include "rlmeta/rpc/cc/rpc_future.h" +#include "rlmeta/rpc/cc/rpc_utils.h" #include "rlmeta/rpc/cc/server.h" #include "rlmeta/rpc/cc/task.h" @@ -32,12 +35,18 @@ PYBIND11_MODULE(_rlmeta_extension, m) { py::module rpc = m.def_submodule("rpc", "A submodule of \"_rlmeta_extension\" for RPC"); - rlmeta::rpc::DefineTaskBase(rpc); + // rlmeta::rpc::DefineTaskBase(rpc); rlmeta::rpc::DefineTask(rpc); rlmeta::rpc::DefineBatchedTask(rpc); rlmeta::rpc::DefineComputationQueue(rpc); rlmeta::rpc::DefineBatchedComputationQueue(rpc); rlmeta::rpc::DefineServer(rpc); + rlmeta::rpc::DefineClient(rpc); + rlmeta::rpc::DefineRpcFuture(rpc); + + py::module rpc_utils = rpc.def_submodule( + "rpc_utils", "A submodule of \"_rlmeta_extension.rpc\" for rpc_utils"); + rlmeta::rpc::DefineRpcUtils(rpc_utils); } } // namespace diff --git a/rlmeta/cc/torch_utils.h b/rlmeta/cc/torch_utils.h index ab74bda..7420813 100644 --- a/rlmeta/cc/torch_utils.h +++ b/rlmeta/cc/torch_utils.h @@ -5,6 +5,8 @@ #pragma once +#include +#include #include #include @@ -20,6 +22,26 @@ struct TorchDataType { static constexpr torch::ScalarType value = torch::kBool; }; +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kUInt8; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt8; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt16; +}; + +template <> +struct TorchDataType { + static constexpr torch::ScalarType value = torch::kInt32; +}; + template <> struct TorchDataType { static constexpr torch::ScalarType value = torch::kInt64; @@ -35,5 +57,17 @@ struct TorchDataType { static constexpr torch::ScalarType value = torch::kDouble; }; +inline bool IsTorchTensor(const py::object& obj) { + return THPVariable_Check(obj.ptr()); +} + +inline torch::Tensor PyObjectToTorchTensor(const py::object& obj) { + return THPVariable_Unpack(obj.ptr()); +} + +inline py::object TorchTensorToPyObject(const torch::Tensor& tensor) { + return py::reinterpret_steal(THPVariable_Wrap(tensor)); +} + } // namespace utils } // namespace rlmeta diff --git a/rlmeta/core/controller.py b/rlmeta/core/controller.py index 0255379..91e451e 100644 --- a/rlmeta/core/controller.py +++ b/rlmeta/core/controller.py @@ -65,8 +65,8 @@ def get_count(self) -> int: return self.count @remote.remote_method(batch_size=None) - def get_stats(self) -> StatsDict: - return self.stats + def get_stats(self) -> Dict[str, Dict[str, float]]: + return self.stats.dict() @remote.remote_method(batch_size=None) def add_episode(self, phase: Phase, stats: Dict[str, float]) -> None: diff --git a/rlmeta/core/model.py b/rlmeta/core/model.py index 03fe9f6..40e8ee1 100644 --- a/rlmeta/core/model.py +++ b/rlmeta/core/model.py @@ -64,26 +64,31 @@ def __call__(self, *args, **kwargs) -> Any: return self.wrapped(*args, **kwargs) def pull(self) -> None: - state_dict = self.client.sync(self.server_name, - self.remote_method_name("pull")) + # state_dict = self.client.sync(self.server_name, + # self.remote_method_name("pull")) + state_dict = self.client.rpc(self.remote_method_name("pull")) self.wrapped.load_state_dict(state_dict) async def async_pull(self) -> None: - state_dict = await self.client.async_(self.server_name, - self.remote_method_name("pull")) + # state_dict = await self.client.async_(self.server_name, + # self.remote_method_name("pull")) + state_dict = await self.client.async_rpc(self.remote_method_name("pull") + ) self.wrapped.load_state_dict(state_dict) def push(self) -> None: state_dict = self.wrapped.state_dict() state_dict = nested_utils.map_nested(lambda x: x.cpu(), state_dict) - self.client.sync(self.server_name, self.remote_method_name("push"), - state_dict) + # self.client.sync(self.server_name, self.remote_method_name("push"), + # state_dict) + self.client.rpc(self.remote_method_name("push"), state_dict) async def async_push(self) -> None: state_dict = self.wrapped.state_dict() state_dict = nested_utils.map_nested(lambda x: x.cpu(), state_dict) - await self.client.async_(self.server_name, - self.remote_method_name("push"), state_dict) + # await self.client.async_(self.server_name, + # self.remote_method_name("push"), state_dict) + await self.client.async_rpc(self.remote_method_name("push"), state_dict) def _bind(self) -> None: pass diff --git a/rlmeta/core/remote.py b/rlmeta/core/remote.py index f70e14b..26c3957 100644 --- a/rlmeta/core/remote.py +++ b/rlmeta/core/remote.py @@ -10,7 +10,8 @@ from typing import Any, Callable, List, Optional -import moolib +# import moolib +import rlmeta.rpc as rpc from rlmeta.core.launchable import Launchable from rlmeta.utils.moolib_utils import generate_random_name @@ -97,8 +98,11 @@ def server_name(self) -> str: def server_addr(self) -> str: return self._server_addr + # @property + # def client(self) -> Optional[moolib.Client]: + # return self._client @property - def client(self) -> Optional[moolib.Client]: + def client(self) -> Optional[rpc.Client]: return self._client @property @@ -116,11 +120,15 @@ def remote_method_name(self, method: str) -> str: def connect(self) -> None: if self._connected: return - self._client = moolib.Rpc() - self._client.set_transports(["uv"]) - self._client.set_name(self._name) - self._client.set_timeout(self._timeout) + # self._client = moolib.Rpc() + # self._client.set_transports(["uv"]) + # self._client.set_name(self._name) + # self._client.set_timeout(self._timeout) + # self._client.connect(self._server_addr) + + self._client = rpc.Client() self._client.connect(self._server_addr) + self._bind() self._connected = True @@ -140,12 +148,17 @@ def _reset(self, def _bind(self) -> None: for method in self._remote_methods: + # self._client_methods[method] = functools.partial( + # self.client.sync, self.server_name, + # self.remote_method_name(method)) + # self._client_methods["async_" + method] = functools.partial( + # self.client.async_, self.server_name, + # self.remote_method_name(method)) + self._client_methods[method] = functools.partial( - self.client.sync, self.server_name, - self.remote_method_name(method)) + self.client.rpc, self.remote_method_name(method)) self._client_methods["async_" + method] = functools.partial( - self.client.async_, self.server_name, - self.remote_method_name(method)) + self.client.async_rpc, self.remote_method_name(method)) def remote_method(batch_size: Optional[int] = None) -> Callable[..., Any]: diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index a6ac98e..a0ecfe1 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -4,14 +4,17 @@ # LICENSE file in the root directory of this source tree. import collections -import time +import concurrent.futures import logging +import time from typing import Callable, Optional, Sequence, Tuple, Union -from rich.console import Console + import numpy as np import torch +from rich.console import Console + import rlmeta.core.remote as remote import rlmeta.utils.data_utils as data_utils import rlmeta.utils.nested_utils as nested_utils @@ -290,11 +293,12 @@ def __init__(self, prefetch: int = 0, timeout: float = 60) -> None: super().__init__(target, server_name, server_addr, name, timeout) - self._prefetch = prefetch - self._futures = collections.deque() self._server_name = server_name self._server_addr = server_addr + self._prefetch = prefetch + self._futures = collections.deque() + def __repr__(self): return (f"RemoteReplayBuffer(server_name={self._server_name}, " + f"server_addr={self._server_addr})") @@ -308,16 +312,22 @@ def sample( ) -> Union[NestedTensor, Tuple[NestedTensor, torch.Tensor, torch.Tensor, torch.Tensor]]: if len(self._futures) > 0: - ret = self._futures.popleft().result() + # ret = self._futures.popleft().result() + ret = self._futures.popleft().get() else: - ret = self.client.sync(self.server_name, - self.remote_method_name("sample"), - batch_size) - - while len(self._futures) < self.prefetch: - fut = self.client.async_(self.server_name, - self.remote_method_name("sample"), - batch_size) + # ret = self.client.sync(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + ret = self.client.rpc(self.remote_method_name("sample"), batch_size) + + # while len(self._futures) < self.prefetch: + # fut = self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + # self._futures.append(fut) + while len(self._futures) < self._prefetch: + fut = self.client.rpc_future(self.remote_method_name("sample"), + batch_size) self._futures.append(fut) return ret @@ -329,14 +339,21 @@ async def async_sample( if len(self._futures) > 0: ret = await self._futures.popleft() else: - ret = await self.client.async_(self.server_name, - self.remote_method_name("sample"), - batch_size) - - while len(self._futures) < self.prefetch: - fut = self.client.async_(self.server_name, - self.remote_method_name("sample"), - batch_size) + # ret = await self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + ret = await self.client.async_rpc(self.remote_method_name("sample"), + batch_size) + + # while len(self._futures) < self.prefetch: + # fut = self.client.async_(self.server_name, + # self.remote_method_name("sample"), + # batch_size) + # self._futures.append(fut) + + while len(self._futures) < self._prefetch: + fut = self.client.async_rpc(self.remote_method_name("sample"), + batch_size) self._futures.append(fut) return ret diff --git a/rlmeta/core/server.py b/rlmeta/core/server.py index b2ff30e..70146b5 100644 --- a/rlmeta/core/server.py +++ b/rlmeta/core/server.py @@ -14,8 +14,9 @@ import torch.multiprocessing as mp from rich.console import Console -import moolib +# import moolib +import rlmeta.rpc as rpc import rlmeta.utils.asyncio_utils as asyncio_utils from rlmeta.core.launchable import Launchable @@ -85,64 +86,75 @@ def init_execution(self) -> None: if isinstance(service, Launchable): service.init_execution() - self._server = moolib.Rpc() - self._server.set_transports(["uv"]) - self._server.set_name(self._name) - self._server.set_timeout(self._timeout) - console.log(f"Server={self.name} listening to {self._addr}") - try: - self._server.listen(self._addr) - except: - console.log(f"ERROR on listen({self._addr}) from: server={self}") - raise - - def _start_services(self) -> NoReturn: - self._loop = asyncio.get_event_loop() - self._tasks = [] - console.log(f"Server={self.name} starting services: {self._services}") + # self._server = moolib.Rpc() + # self._server.set_transports(["uv"]) + # self._server.set_name(self._name) + # self._server.set_timeout(self._timeout) + # console.log(f"Server={self.name} listening to {self._addr}") + # try: + # self._server.listen(self._addr) + # except: + # console.log(f"ERROR on listen({self._addr}) from: server={self}") + # raise + + self._server = rpc.Server(self._addr) for service in self._services: for method in service.remote_methods: method_impl = getattr(service, method) batch_size = getattr(method_impl, "__batch_size__", None) - self._add_server_task(service.remote_method_name(method), + self._server.register(service.remote_method_name(method), method_impl, batch_size) - try: - if not self._loop.is_running(): - self._loop.run_forever() - except Exception as e: - logging.error(e) - raise - finally: - for task in self._tasks: - task.cancel() - self._loop.stop() - self._loop.close() + + def _start_services(self) -> NoReturn: + # self._loop = asyncio.get_event_loop() + # self._tasks = [] + # console.log(f"Server={self.name} starting services: {self._services}") + # for service in self._services: + # for method in service.remote_methods: + # method_impl = getattr(service, method) + # batch_size = getattr(method_impl, "__batch_size__", None) + # self._add_server_task(service.remote_method_name(method), + # method_impl, batch_size) + # try: + # if not self._loop.is_running(): + # self._loop.run_forever() + # except Exception as e: + # logging.error(e) + # raise + # finally: + # for task in self._tasks: + # task.cancel() + # self._loop.stop() + # self._loop.close() + + self._server.start() + console.log(f"Server={self.name} listening to {self._addr}") console.log(f"Server={self.name} services started") - def _add_server_task(self, func_name: str, func_impl: Callable[..., Any], - batch_size: Optional[int]) -> None: - if batch_size is None: - que = self._server.define_queue(func_name) - else: - que = self._server.define_queue(func_name, - batch_size=batch_size, - dynamic_batching=True) - task = asyncio_utils.create_task(self._loop, - self._async_process(que, func_impl)) - self._tasks.append(task) - - async def _async_process(self, que: moolib.Queue, - func: Callable[..., Any]) -> None: - try: - while True: - ret_cb, args, kwargs = await que - ret = func(*args, **kwargs) - ret_cb(ret) - except asyncio.CancelledError: - pass - except Exception as e: - logging.error(e) - raise e + # def _add_server_task(self, func_name: str, func_impl: Callable[..., Any], + # batch_size: Optional[int]) -> None: + # if batch_size is None: + # que = self._server.define_queue(func_name) + # else: + # que = self._server.define_queue(func_name, + # batch_size=batch_size, + # dynamic_batching=True) + # task = asyncio_utils.create_task(self._loop, + # self._async_process(que, func_impl)) + # self._tasks.append(task) + # + # async def _async_process(self, que: moolib.Queue, + # func: Callable[..., Any]) -> None: + # try: + # while True: + # ret_cb, args, kwargs = await que + # ret = func(*args, **kwargs) + # ret_cb(ret) + # except asyncio.CancelledError: + # pass + # except Exception as e: + # logging.error(e) + # raise e class ServerList: diff --git a/rlmeta/rpc/CMakeLists.txt b/rlmeta/rpc/CMakeLists.txt index 9cd2ec3..4f35c9d 100644 --- a/rlmeta/rpc/CMakeLists.txt +++ b/rlmeta/rpc/CMakeLists.txt @@ -15,7 +15,7 @@ include(FetchContent) FetchContent_Declare( grpc GIT_REPOSITORY https://github.com/grpc/grpc - GIT_TAG v1.47.0 + GIT_TAG v1.48.0 ) set(FETCHCONTENT_QUIET OFF) diff --git a/rlmeta/rpc/__init__.py b/rlmeta/rpc/__init__.py index 14cadb1..9f5b711 100644 --- a/rlmeta/rpc/__init__.py +++ b/rlmeta/rpc/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from _rlmeta_extension.rpc import ComputationQueue, BatchedComputationQueue -from _rlmeta_extension.rpc import TaskBase, Task, BatchedTask +from _rlmeta_extension.rpc import Task, BatchedTask from rlmeta.rpc.client import Client from rlmeta.rpc.server import Server @@ -12,7 +12,6 @@ __all__ = [ "ComputationQueue", "BatchedComputationQueue", - "TaskBase", "Task", "BatchedTask", "Client", diff --git a/rlmeta/rpc/cc/client.cc b/rlmeta/rpc/cc/client.cc new file mode 100644 index 0000000..9af5123 --- /dev/null +++ b/rlmeta/rpc/cc/client.cc @@ -0,0 +1,96 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/client.h" + +#include +#include + +#include "rlmeta/rpc/cc/rpc_utils.h" + +namespace rlmeta { +namespace rpc { + +void Client::Connect(const std::string& addr, int64_t timeout) { + if (connected_) { + Disconnect(); + } + + grpc::ChannelArguments ch_args; + ch_args.SetMaxSendMessageSize(-1); // Unlimited + ch_args.SetMaxReceiveMessageSize(-1); // Unlimited + channel_ = grpc::CreateCustomChannel(addr, grpc::InsecureChannelCredentials(), + ch_args); + stub_ = Rpc::NewStub(channel_); + + const auto deadline = + std::chrono::system_clock::now() + std::chrono::seconds(timeout); + if (!channel_->WaitForConnected(deadline)) { + std::cerr << "[Client::connect] timeout" << std::endl; + } + thread_ = std::make_unique(&Client::AsyncCompleteRpc, this); + connected_ = true; +} + +void Client::Disconnect() { + if (connected_) { + connected_ = false; + cq_.Shutdown(); + thread_->join(); + thread_.reset(); + } +} + +py::object Client::Rpc(const std::string& func, const py::args& args, + const py::kwargs& kwargs) { + return RpcFuture(func, args, kwargs).Get(); +} + +rlmeta::rpc::RpcFuture Client::RpcFuture(const std::string& func, + const py::args& args, + const py::kwargs& kwargs) { + assert(connected_); + RpcRequest request; + request.set_function(func); + *request.mutable_args() = rpc_utils::PythonToNestedData(args); + *request.mutable_kwargs() = rpc_utils::PythonToNestedData(kwargs); + + py::gil_scoped_release release; + AsyncClientCall* call = new AsyncClientCall(); + rlmeta::rpc::RpcFuture fut = call->promise.get_future(); + call->response_reader = + stub_->PrepareAsyncRemoteCall(&call->context, request, &cq_); + call->response_reader->StartCall(); + call->response_reader->Finish(&call->response, &call->status, (void*)call); + return fut; +} + +void Client::AsyncCompleteRpc() { + void* got_tag; + bool ok = false; + while (cq_.Next(&got_tag, &ok)) { + std::unique_ptr call( + static_cast(got_tag)); + GPR_ASSERT(ok); + assert(call->status.ok()); + call->promise.set_value(std::move(*call->response.mutable_return_value())); + } +} + +void DefineClient(py::module& m) { + py::class_>(m, "Client") + .def(py::init<>()) + .def_property_readonly("addr", &Client::addr) + .def_property_readonly("connected", &Client::connected) + .def("connect", &Client::Connect, py::arg("addr"), + py::arg("timeout") = 60, py::call_guard()) + .def("disconnect", &Client::Disconnect, + py::call_guard()) + .def("rpc", &Client::Rpc) + .def("rpc_future", &Client::RpcFuture); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/client.h b/rlmeta/rpc/cc/client.h new file mode 100644 index 0000000..0703a67 --- /dev/null +++ b/rlmeta/rpc/cc/client.h @@ -0,0 +1,84 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "rlmeta/rpc/cc/rpc_future.h" +#include "rpc.grpc.pb.h" +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +// The Client class is adapted from gRPC C++ async client example. +// https://github.com/grpc/grpc/blob/2d4f3c56001cd1e1f85734b2f7c5ce5f2797c38a/examples/cpp/helloworld/greeter_async_client2.cc +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +class Client { + public: + Client() = default; + ~Client() { Disconnect(); } + + const std::string& addr() const { return addr_; } + bool connected() const { return connected_; } + + void Connect(const std::string& addr, int64_t timeout = 60); + void Disconnect(); + + py::object Rpc(const std::string& func, const py::args& args, + const py::kwargs& kwargs); + rlmeta::rpc::RpcFuture RpcFuture(const std::string& func, + const py::args& args, + const py::kwargs& kwargs); + + protected: + struct AsyncClientCall { + RpcResponse response; + grpc::ClientContext context; + grpc::Status status; + std::promise promise; + std::unique_ptr> + response_reader; + }; + + void AsyncCompleteRpc(); + + std::string addr_; + bool connected_ = false; + + std::shared_ptr channel_; + std::unique_ptr stub_; + + grpc::CompletionQueue cq_; + std::unique_ptr thread_; +}; + +void DefineClient(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/computation_queue.cc b/rlmeta/rpc/cc/computation_queue.cc index 8f2982f..aa24eaf 100644 --- a/rlmeta/rpc/cc/computation_queue.cc +++ b/rlmeta/rpc/cc/computation_queue.cc @@ -8,22 +8,37 @@ namespace rlmeta { namespace rpc { -std::future BatchedComputationQueue::Put( - const std::string& args, const std::string& kwargs) { +std::future BatchedComputationQueue::Put(const NestedData& args, + const NestedData& kwargs) { std::scoped_lock lk(mu_); if (cur_computation_ == nullptr) { cur_computation_ = std::make_shared(batch_size_); queue_impl_.Put(cur_computation_); } - std::future ret = cur_computation_->Add(args, kwargs); + std::future ret = cur_computation_->Add(args, kwargs); if (cur_computation_->Full()) { cur_computation_.reset(); } return ret; } -std::shared_ptr BatchedComputationQueue::Get() { - std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); +std::future BatchedComputationQueue::Put(NestedData&& args, + NestedData&& kwargs) { + std::scoped_lock lk(mu_); + if (cur_computation_ == nullptr) { + cur_computation_ = std::make_shared(batch_size_); + queue_impl_.Put(cur_computation_); + } + std::future ret = + cur_computation_->Add(std::move(args), std::move(kwargs)); + if (cur_computation_->Full()) { + cur_computation_.reset(); + } + return ret; +} + +std::shared_ptr BatchedComputationQueue::Get() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); if (ret != nullptr) { std::scoped_lock lk(mu_); if (!dynamic_cast(ret.get())->Full()) { @@ -33,8 +48,8 @@ std::shared_ptr BatchedComputationQueue::Get() { return ret; } -std::shared_ptr BatchedComputationQueue::GetFullBatch() { - std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); +std::shared_ptr BatchedComputationQueue::GetFullBatch() { + std::shared_ptr ret = queue_impl_.Get().value_or(nullptr); if (ret != nullptr) { dynamic_cast(ret.get())->Wait(); } diff --git a/rlmeta/rpc/cc/computation_queue.h b/rlmeta/rpc/cc/computation_queue.h index dd961b4..eee8826 100644 --- a/rlmeta/rpc/cc/computation_queue.h +++ b/rlmeta/rpc/cc/computation_queue.h @@ -13,6 +13,7 @@ #include "rlmeta/rpc/cc/queue_impl.h" #include "rlmeta/rpc/cc/task.h" +#include "rpc.pb.h" namespace py = pybind11; @@ -24,21 +25,28 @@ class ComputationQueue { ComputationQueue() = default; explicit ComputationQueue(int64_t capacity) : queue_impl_(capacity) {} - virtual std::future Put(const std::string& args, - const std::string& kwargs) { + virtual std::future Put(const NestedData& args, + const NestedData& kwargs) { std::shared_ptr task = std::make_shared(args, kwargs); queue_impl_.Put(task); return task->Future(); } - virtual std::shared_ptr Get() { + virtual std::future Put(NestedData&& args, NestedData&& kwargs) { + std::shared_ptr task = + std::make_shared(std::move(args), std::move(kwargs)); + queue_impl_.Put(task); + return task->Future(); + } + + virtual std::shared_ptr Get() { return queue_impl_.Get().value_or(nullptr); } virtual void Shutdown() { queue_impl_.Shutdown(); } protected: - QueueImpl> queue_impl_; + QueueImpl> queue_impl_; }; class BatchedComputationQueue : public ComputationQueue { @@ -48,12 +56,12 @@ class BatchedComputationQueue : public ComputationQueue { BatchedComputationQueue(int64_t capacity, int64_t batch_size) : ComputationQueue(capacity), batch_size_(batch_size) {} - std::future Put(const std::string& args, - const std::string& kwargs) override; - - std::shared_ptr Get() override; + std::future Put(const NestedData& args, + const NestedData& kwargs) override; + std::future Put(NestedData&& args, NestedData&& kwargs) override; - std::shared_ptr GetFullBatch(); + std::shared_ptr Get() override; + std::shared_ptr GetFullBatch(); protected: const int64_t batch_size_; diff --git a/rlmeta/rpc/cc/rpc_future.cc b/rlmeta/rpc/cc/rpc_future.cc new file mode 100644 index 0000000..52d0e41 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_future.cc @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/rpc_future.h" + +#include + +#include "rlmeta/rpc/cc/rpc_utils.h" + +namespace rlmeta { +namespace rpc { + +py::object RpcFuture::Get() { + if (!Valid()) { + return py::none(); + } + py::object ret = rpc_utils::NestedDataToPython(std::move(future_.get())); + valid_ = false; + return ret; +} + +void DefineRpcFuture(py::module& m) { + py::class_>(m, "RpcFuture") + .def("valid", &RpcFuture::Valid) + .def("get", &RpcFuture::Get) + .def("wait", &RpcFuture::Wait); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_future.h b/rlmeta/rpc/cc/rpc_future.h new file mode 100644 index 0000000..4aac466 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_future.h @@ -0,0 +1,37 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include + +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class RpcFuture { + public: + RpcFuture(std::future&& fut) : future_(std::move(fut)) {} + + bool Valid() const { return future_.valid() && valid_; } + + py::object Get(); + + void Wait() { future_.wait(); } + + protected: + std::future future_; + bool valid_ = true; +}; + +void DefineRpcFuture(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_utils.cc b/rlmeta/rpc/cc/rpc_utils.cc new file mode 100644 index 0000000..5314cf3 --- /dev/null +++ b/rlmeta/rpc/cc/rpc_utils.cc @@ -0,0 +1,177 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/rpc_utils.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "rlmeta/cc/nested_utils.h" +#include "rlmeta/cc/torch_utils.h" +#include "rlmeta/rpc/cc/tensor_wrapper.h" + +namespace rlmeta { +namespace rpc { + +namespace rpc_utils { + +TensorProto PythonToTensorProto(const py::object& obj) { + if (py::isinstance(obj)) { + return TensorWrapper(obj).TensorProto(); + } + if (rlmeta::utils::IsTorchTensor(obj)) { + return TensorWrapper(obj).TensorProto(); + } + return TensorProto(); +} + +py::object TensorProtoToPython(TensorProto&& tensor) { + if (tensor.tensor_type() == TensorProto::NUMPY) { + return TensorWrapper(std::move(tensor)).Python(); + } + if (tensor.tensor_type() == TensorProto::TORCH) { + return TensorWrapper(std::move(tensor)).Python(); + } + return py::none(); +} + +SimpleData PythonToSimpleData(const py::object& obj) { + SimpleData ret; + if (obj.is_none()) { + return ret; + } + if (py::isinstance(obj)) { + ret.set_bool_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_int_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_float_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_str_val(obj.cast()); + } else if (py::isinstance(obj)) { + ret.set_bytes_val(obj.cast()); + } else if (py::isinstance(obj)) { + *ret.mutable_tensor_val() = PythonToTensorProto(obj); + } else if (rlmeta::utils::IsTorchTensor(obj)) { + *ret.mutable_tensor_val() = PythonToTensorProto(obj); + } + return ret; +} + +py::object SimpleDataToPython(SimpleData&& proto) { + if (proto.has_bool_val()) { + return py::cast(proto.bool_val()); + } + if (proto.has_int_val()) { + return py::cast(proto.int_val()); + } + if (proto.has_float_val()) { + return py::cast(proto.float_val()); + } + if (proto.has_str_val()) { + return py::cast(proto.str_val()); + } + if (proto.has_bytes_val()) { + return py::bytes(proto.bytes_val()); + } + if (proto.has_tensor_val()) { + return TensorProtoToPython(std::move(*(proto.mutable_tensor_val()))); + } + return py::none(); +} + +NestedData PythonToNestedData(const py::object& obj) { + NestedData ret; + if (obj.is_none()) { + return ret; + } + + if (py::isinstance(obj)) { + py::tuple src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_vec(); + for (const auto x : src) { + *dst->add_data() = + PythonToNestedData(py::reinterpret_borrow(x)); + } + } else if (py::isinstance(obj)) { + py::list src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_vec(); + for (const auto x : src) { + *dst->add_data() = + PythonToNestedData(py::reinterpret_borrow(x)); + } + } else if (py::isinstance(obj)) { + py::dict src = py::reinterpret_borrow(obj); + auto* dst = ret.mutable_map(); + const std::vector keys = nested_utils::SortedKeys(src); + for (const std::string& k : keys) { + dst->mutable_data()->insert( + {k, PythonToNestedData( + py::reinterpret_borrow(src[py::str(k)]))}); + } + } else { + *ret.mutable_val() = PythonToSimpleData(obj); + } + + return ret; +} + +py::object NestedDataToPython(NestedData&& proto) { + if (proto.has_val()) { + return SimpleDataToPython(std::move(*proto.mutable_val())); + } + if (proto.has_vec()) { + auto* src = proto.mutable_vec(); + const int64_t n = src->data_size(); + py::tuple ret(n); + for (int64_t i = 0; i < n; ++i) { + ret[i] = NestedDataToPython(std::move(*(src->mutable_data(i)))); + } + return ret; + } + if (proto.has_map()) { + auto* src = proto.mutable_map()->mutable_data(); + py::dict ret; + const std::vector keys = nested_utils::SortedKeys(*src); + for (const std::string& k : keys) { + ret[py::str(k)] = NestedDataToPython(std::move(src->at(k))); + } + return ret; + } + return py::none(); +} + +std::string Dumps(const py::object& obj) { + const NestedData src = PythonToNestedData(obj); + py::gil_scoped_release release; + return src.SerializeAsString(); +} + +py::object Loads(const std::string& src) { + NestedData dst; + { + py::gil_scoped_release release; + dst.ParseFromString(src); + } + return NestedDataToPython(std::move(dst)); +} + +} // namespace rpc_utils + +void DefineRpcUtils(py::module& m) { + m.def("dumps", [](const py::object& obj) { + return py::bytes(rpc_utils::Dumps(obj)); + }).def("loads", &rpc_utils::Loads); +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/rpc_utils.h b/rlmeta/rpc/cc/rpc_utils.h new file mode 100644 index 0000000..524136e --- /dev/null +++ b/rlmeta/rpc/cc/rpc_utils.h @@ -0,0 +1,38 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include + +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +namespace rpc_utils { + +TensorProto PythonToTensorProto(const py::object& obj); +py::object TensorProtoToPython(TensorProto&& tensor); + +SimpleData PythonToSimpleData(const py::object& obj); +py::object SimpleDataToPython(SimpleData&& proto); + +NestedData PythonToNestedData(const py::object& obj); +py::object NestedDataToPython(NestedData&& proto); + +std::string Dumps(const py::object& obj); +py::object Loads(const std::string& src); + +} // namespace rpc_utils + +void DefineRpcUtils(py::module& m); + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/server.cc b/rlmeta/rpc/cc/server.cc index 15cac5c..f04d7a0 100644 --- a/rlmeta/rpc/cc/server.cc +++ b/rlmeta/rpc/cc/server.cc @@ -11,14 +11,16 @@ #include #include +#include "rpc.pb.h" + namespace rlmeta { namespace rpc { -grpc::Status ServiceImpl::RemoteCall(grpc::ServerContext* /*context*/, +grpc::Status ServiceImpl::RemoteCall(grpc::ServerContext* /* context */, const RpcRequest* request, RpcResponse* response) { auto& func = functions_.at(request->function()); - response->set_return_value(func(request->args(), request->kwargs())); + *response->mutable_return_value() = func(request->args(), request->kwargs()); return grpc::Status::OK; } @@ -26,6 +28,7 @@ void Server::Start() { grpc::EnableDefaultHealthCheckService(true); grpc::reflection::InitProtoReflectionServerBuilderPlugin(); grpc::ServerBuilder builder; + builder.SetMaxReceiveMessageSize(-1); // Unlimited. builder.AddListeningPort(addr_, grpc::InsecureServerCredentials()); builder.RegisterService(&service_); server_ = builder.BuildAndStart(); @@ -44,11 +47,10 @@ std::shared_ptr Server::RegisterQueue( } else { ret = std::make_shared(batch_size); } - service_.Register(func_name, [que = ret](const std::string& args, - const std::string& kwargs) { - std::future ret = que->Put(args, kwargs); - return ret.get(); - }); + service_.Register( + func_name, [que = ret](const NestedData& args, const NestedData& kwargs) { + return que->Put(args, kwargs).get(); + }); return ret; } @@ -56,7 +58,7 @@ void DefineServer(py::module& m) { py::class_>(m, "Server") .def(py::init()) .def_property_readonly("addr", &Server::addr) - .def("start", &Server::Start) + .def("start", &Server::Start, py::call_guard()) .def("stop", &Server::Stop, py::call_guard()) .def("register_queue", &Server::RegisterQueue, py::arg("func_name"), py::arg("batch_size") = 0); diff --git a/rlmeta/rpc/cc/server.h b/rlmeta/rpc/cc/server.h index 840813c..0e533fc 100644 --- a/rlmeta/rpc/cc/server.h +++ b/rlmeta/rpc/cc/server.h @@ -7,7 +7,6 @@ #include #include -#include #include #include @@ -27,8 +26,7 @@ namespace py = pybind11; namespace rlmeta { namespace rpc { -using PyFunc = - std::function; +using PyFunc = std::function; using PyFuncDict = std::unordered_map; class ServiceImpl final : public Rpc::Service { @@ -62,8 +60,6 @@ class Server { protected: void ServePyFuncQueue(const std::string& func_name); - void ServeImpl(const std::string& func_name, TaskBase& task); - const std::string addr_; std::unique_ptr server_; ServiceImpl service_; diff --git a/rlmeta/rpc/cc/task.cc b/rlmeta/rpc/cc/task.cc index 3bc2562..51dcb03 100644 --- a/rlmeta/rpc/cc/task.cc +++ b/rlmeta/rpc/cc/task.cc @@ -10,59 +10,50 @@ namespace rlmeta { namespace rpc { -py::object BatchedTask::Args() { - const int64_t batch_size = batch_.size(); - py::tuple ret(batch_size); - for (int64_t i = 0; i < batch_size; ++i) { - ret[i] = batch_[i].Args(); +void BatchedTask::SetReturnValue(const py::object& return_value) { + NestedData ret = rpc_utils::PythonToNestedData(return_value); + py::gil_scoped_release release; + assert(ret.has_vec()); + assert(ret.vec_size() == batch_size_); + for (int64_t i = 0; i < batch_size_; ++i) { + promises_[i].set_value(std::move(*ret.mutable_vec()->mutable_data(i))); } - return ret; } -py::object BatchedTask::Kwargs() { - const int64_t batch_size = batch_.size(); - py::tuple ret(batch_size); - for (int64_t i = 0; i < batch_size; ++i) { - ret[i] = batch_[i].Kwargs(); - } - return ret; -} - -void BatchedTask::SetReturnValue(py::object&& return_value) { - const int64_t batch_size = batch_.size(); - py::tuple rets = py::reinterpret_borrow(return_value); - assert(rets.size() == batch_size); - for (int64_t i = 0; i < batch_size; ++i) { - batch_[i].SetReturnValue(std::move(rets[i])); - } -} - -std::future BatchedTask::Add(const std::string& args, - const std::string& kwargs) { - const int64_t batch_size = batch_.size(); - assert(batch_size < capacity_); - auto& task = batch_.emplace_back(args, kwargs); +std::future BatchedTask::Add(const NestedData& args, + const NestedData& kwargs) { + assert(batch_size_ < capacity_); + std::promise& p = promises_.emplace_back(); + *args_.mutable_vec()->add_data() = args; + *kwargs_.mutable_vec()->add_data() = kwargs; + ++batch_size_; num_to_wait_.DecrementCount(); - return task.Future(); + return p.get_future(); } -void DefineTaskBase(py::module& m) { - py::class_>(m, "TaskBase") - .def("args", &TaskBase::Args) - .def("kwargs", &TaskBase::Kwargs) - .def("set_return_value", &TaskBase::SetReturnValue); +std::future BatchedTask::Add(NestedData&& args, + NestedData&& kwargs) { + assert(batch_size_ < capacity_); + std::promise& p = promises_.emplace_back(); + *args_.mutable_vec()->add_data() = std::move(args); + *kwargs_.mutable_vec()->add_data() = std::move(kwargs); + ++batch_size_; + num_to_wait_.DecrementCount(); + return p.get_future(); } void DefineTask(py::module& m) { - py::class_>(m, "Task"); + py::class_>(m, "Task") + .def("args", &Task::Args) + .def("kwargs", &Task::Kwargs) + .def("set_return_value", &Task::SetReturnValue); } void DefineBatchedTask(py::module& m) { - py::class_>(m, - "BatchedTask") - .def("__len__", &BatchedTask::BatchSize) + py::class_>(m, "BatchedTask") + .def("__len__", &BatchedTask::batch_size) .def_property_readonly("capacity", &BatchedTask::capacity) - .def_property_readonly("batch_size", &BatchedTask::BatchSize) + .def_property_readonly("batch_size", &BatchedTask::batch_size) .def("empty", &BatchedTask::Empty) .def("full", &BatchedTask::Full); } diff --git a/rlmeta/rpc/cc/task.h b/rlmeta/rpc/cc/task.h index 8fd2cb8..797f084 100644 --- a/rlmeta/rpc/cc/task.h +++ b/rlmeta/rpc/cc/task.h @@ -14,84 +14,68 @@ #include #include "rlmeta/rpc/cc/blocking_counter.h" +#include "rlmeta/rpc/cc/rpc_utils.h" +#include "rpc.pb.h" namespace py = pybind11; namespace rlmeta { namespace rpc { -class TaskBase { +class Task { public: - virtual py::object Args() = 0; - virtual py::object Kwargs() = 0; - virtual void SetReturnValue(py::object&& return_value) = 0; -}; + Task() = default; -// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtuals -class PyTaskBase : public TaskBase { - public: - using TaskBase::TaskBase; + Task(const NestedData& args, const NestedData& kwargs) + : args_(args), kwargs_(kwargs) {} - py::object Args() override { - PYBIND11_OVERRIDE_PURE(py::object, TaskBase, Args); - } + Task(NestedData&& args, NestedData&& kwargs) + : args_(std::move(args)), kwargs_(std::move(kwargs)) {} - py::object Kwargs() override { - PYBIND11_OVERRIDE_PURE(py::object, TaskBase, Kwargs); + virtual py::object Args() { + return rpc_utils::NestedDataToPython(std::move(args_)); } - void SetReturnValue(py::object&& return_value) override { - PYBIND11_OVERRIDE_PURE(void, TaskBase, SetReturnValue, return_value); + virtual py::object Kwargs() { + return rpc_utils::NestedDataToPython(std::move(kwargs_)); } -}; -class Task : public TaskBase { - public: - Task(const std::string& args, const std::string& kwargs) - : args_(args), kwargs_(kwargs) {} - - py::object Args() override { return py::bytes(std::move(args_)); } - py::object Kwargs() override { return py::bytes(std::move(kwargs_)); } + std::future Future() { return promise_.get_future(); } - std::future Future() { return promise_.get_future(); } - - void SetReturnValue(py::object&& return_value) override { - promise_.set_value( - py::reinterpret_borrow(std::move(return_value))); + virtual void SetReturnValue(const py::object& return_value) { + promise_.set_value(rpc_utils::PythonToNestedData(return_value)); } protected: - std::string args_; - std::string kwargs_; - std::promise promise_; + NestedData args_; + NestedData kwargs_; + std::promise promise_; }; -class BatchedTask : public TaskBase { +class BatchedTask : public Task { public: explicit BatchedTask(int64_t capacity) : capacity_(capacity), num_to_wait_(capacity) { - batch_.reserve(capacity_); + promises_.reserve(capacity); } int64_t capacity() const { return capacity_; } - int64_t BatchSize() const { return batch_.size(); } - - bool Empty() const { return batch_.empty(); } - bool Full() const { return static_cast(batch_.size()) == capacity_; } + int64_t batch_size() const { return batch_size_; } - py::object Args() override; - py::object Kwargs() override; + bool Empty() const { return batch_size_ == 0; } + bool Full() const { return batch_size_ == capacity_; } - void SetReturnValue(py::object&& return_value) override; + void SetReturnValue(const py::object& return_value) override; - std::future Add(const std::string& args, - const std::string& kwargs); + std::future Add(const NestedData& args, const NestedData& kwargs); + std::future Add(NestedData&& args, NestedData&& kwargs); void Wait() { num_to_wait_.Wait(); } protected: const int64_t capacity_; - std::vector batch_; + int64_t batch_size_ = 0; + std::vector> promises_; BlockingCounter num_to_wait_; }; diff --git a/rlmeta/rpc/cc/tensor_wrapper.cc b/rlmeta/rpc/cc/tensor_wrapper.cc new file mode 100644 index 0000000..ddb2df6 --- /dev/null +++ b/rlmeta/rpc/cc/tensor_wrapper.cc @@ -0,0 +1,86 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "rlmeta/rpc/cc/tensor_wrapper.h" + +#include + +#include + +#include "rlmeta/cc/torch_utils.h" + +namespace rlmeta { +namespace rpc { + +template <> +void TensorWrapper::FromPython(const py::object& obj) { + tensor_ = obj.cast(); +} + +template <> +void TensorWrapper::FromTensorProto( + rlmeta::rpc::TensorProto&& proto) { + assert(tensor.tensor_type() == TensorProto::NUMPY); + const std::vector shape(proto.shape().cbegin(), + proto.shape().cend()); + const void* data = proto.data().data(); + std::unique_ptr data_ptr(proto.release_data()); + auto capsule = py::capsule(data_ptr.get(), [](void* p) { + std::unique_ptr(reinterpret_cast(p)); + }); + data_ptr.release(); + tensor_ = py::array(py::dtype(proto.dtype()), shape, data, capsule); +} + +template <> +py::object TensorWrapper::Python() { + return std::move(tensor_); +} + +template <> +rlmeta::rpc::TensorProto TensorWrapper::TensorProto() { + rlmeta::rpc::TensorProto ret; + ret.set_tensor_type(TensorProto::NUMPY); + ret.set_dtype(tensor_.dtype().num()); + ret.mutable_shape()->Assign(tensor_.shape(), + tensor_.shape() + tensor_.ndim()); + ret.set_data(tensor_.data(), tensor_.nbytes()); + return ret; +} + +template <> +void TensorWrapper::FromPython(const py::object& obj) { + tensor_ = rlmeta::utils::PyObjectToTorchTensor(obj).contiguous(); +} + +template <> +void TensorWrapper::FromTensorProto( + rlmeta::rpc::TensorProto&& proto) { + assert(proto.tensor_type() == Tensor::TORCH); + const std::vector shape(proto.shape().cbegin(), + proto.shape().cend()); + std::string* data = proto.release_data(); + tensor_ = at::from_blob( + data->data(), shape, /*deleter=*/[data](void* /*p*/) { delete data; }, + static_cast(proto.dtype())); +} + +template <> +py::object TensorWrapper::Python() { + return rlmeta::utils::TorchTensorToPyObject(std::move(tensor_)); +} + +template <> +rlmeta::rpc::TensorProto TensorWrapper::TensorProto() { + rlmeta::rpc::TensorProto ret; + ret.set_tensor_type(TensorProto::TORCH); + ret.set_dtype(static_cast(tensor_.scalar_type())); + ret.mutable_shape()->Assign(tensor_.sizes().cbegin(), tensor_.sizes().cend()); + ret.set_data(tensor_.data_ptr(), tensor_.nbytes()); + return ret; +} + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/cc/tensor_wrapper.h b/rlmeta/rpc/cc/tensor_wrapper.h new file mode 100644 index 0000000..bd72657 --- /dev/null +++ b/rlmeta/rpc/cc/tensor_wrapper.h @@ -0,0 +1,45 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include "rpc.pb.h" + +namespace py = pybind11; + +namespace rlmeta { +namespace rpc { + +class TensorWrapperBase { + public: + virtual void FromPython(const py::object& obj) = 0; + virtual void FromTensorProto(rlmeta::rpc::TensorProto&& proto) = 0; + + virtual py::object Python() = 0; + virtual rlmeta::rpc::TensorProto TensorProto() = 0; +}; + +template +class TensorWrapper : public TensorWrapperBase { + public: + TensorWrapper(const py::object& obj) { FromPython(obj); } + TensorWrapper(rlmeta::rpc::TensorProto&& proto) { + FromTensorProto(std::move(proto)); + } + + void FromPython(const py::object& obj) override; + void FromTensorProto(rlmeta::rpc::TensorProto&& proto) override; + + py::object Python() override; + rlmeta::rpc::TensorProto TensorProto() override; + + protected: + TensorType tensor_; +}; + +} // namespace rpc +} // namespace rlmeta diff --git a/rlmeta/rpc/client.py b/rlmeta/rpc/client.py index 7761f24..4e64ad7 100644 --- a/rlmeta/rpc/client.py +++ b/rlmeta/rpc/client.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import asyncio import pickle from typing import Any @@ -10,39 +11,22 @@ import grpc import grpc.experimental -import rlmeta.rpc.rpc_pb2 as rpc_pb2 -import rlmeta.rpc.rpc_pb2_grpc as rpc_pb2_grpc +import _rlmeta_extension.rpc as _rpc +import _rlmeta_extension.rpc.rpc_utils as rpc_utils -class Client: - - def connect(self, addr: str) -> None: - # self._channel = grpc.insecure_channel(addr) - # self._rpc_stub = rpc_pb2_grpc.RpcStub(self._channel) - - self._addr = addr - self._channel_options = [ - (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1) - ] +class Client(_rpc.Client): def rpc(self, function: str, *args, **kwargs) -> Any: - with grpc.insecure_channel(self._addr, - options=self._channel_options) as channel: - stub = rpc_pb2_grpc.RpcStub(channel) - # ret = self._rpc_stub.RemoteCall( - ret = stub.RemoteCall( - rpc_pb2.RpcRequest(function=function, - args=pickle.dumps(args), - kwargs=pickle.dumps(kwargs))) - return pickle.loads(ret.return_value) - - async def async_rpc(self, function: str, *args, **kwargs) -> Any: - async with grpc.aio.insecure_channel( - self._addr, options=self._channel_options) as channel: - stub = rpc_pb2_grpc.RpcStub(channel) - # ret = await self._rpc_stub.RemoteCall( - ret = await stub.RemoteCall( - rpc_pb2.RpcRequest(function=function, - args=pickle.dumps(args), - kwargs=pickle.dumps(kwargs))) - return pickle.loads(ret.return_value) + return super().rpc(function, *args, **kwargs) + + def rpc_future(self, function: str, *args, **kwargs) -> _rpc.RpcFuture: + return super().rpc_future(function, *args, **kwargs) + + def async_rpc(self, function: str, *args, **kwargs) -> Any: + loop = asyncio.get_running_loop() + ret = super().rpc_future(function, *args, **kwargs) + fut = asyncio.Future() + # loop.call_soon_threadsafe(lambda x: fut.set_result(x.get()), ret) + loop.call_soon(lambda x: fut.set_result(x.get()), ret) + return fut diff --git a/rlmeta/rpc/protos/rpc.proto b/rlmeta/rpc/protos/rpc.proto index b3824b7..03a80d8 100644 --- a/rlmeta/rpc/protos/rpc.proto +++ b/rlmeta/rpc/protos/rpc.proto @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -syntax = "proto3"; +syntax = "proto2"; package rlmeta.rpc; @@ -12,16 +12,57 @@ service Rpc { } message Error { - string message = 1; + optional string message = 1; } message RpcRequest { - string function = 1; - bytes args = 2; - bytes kwargs = 3; + optional string function = 1; + optional NestedData args = 2; + optional NestedData kwargs = 3; } message RpcResponse { - bytes return_value = 1; + optional NestedData return_value = 1; optional Error error = 2; } + + +message TensorProto { + enum TensorType { + UNKNOWN = 0; + NUMPY = 1; + TORCH = 2; + } + + optional TensorType tensor_type = 1; + optional int32 dtype = 2; + repeated int64 shape = 3 [packed = true]; + optional bytes data = 4; +} + +message SimpleData { + oneof value { + bool bool_val = 1; + int64 int_val = 2; + double float_val = 3; + string str_val = 4; + bytes bytes_val = 5; + TensorProto tensor_val = 6; + } +} + +message DataVector { + repeated NestedData data = 1; +} + +message DataMap { + map data = 1; +} + +message NestedData { + oneof nested_data { + SimpleData val = 1; + DataVector vec = 2; + DataMap map = 3; + } +} diff --git a/rlmeta/rpc/rpc_pb2.py b/rlmeta/rpc/rpc_pb2.py deleted file mode 100644 index 3afafab..0000000 --- a/rlmeta/rpc/rpc_pb2.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: rpc.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\nrlmeta.rpc\"\x18\n\x05\x45rror\x12\x0f\n\x07message\x18\x01 \x01(\t\"<\n\nRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\x12\x0e\n\x06kwargs\x18\x03 \x01(\x0c\"T\n\x0bRpcResponse\x12\x14\n\x0creturn_value\x18\x01 \x01(\x0c\x12%\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.rlmeta.rpc.ErrorH\x00\x88\x01\x01\x42\x08\n\x06_error2F\n\x03Rpc\x12?\n\nRemoteCall\x12\x16.rlmeta.rpc.RpcRequest\x1a\x17.rlmeta.rpc.RpcResponse\"\x00\x62\x06proto3') - - - -_ERROR = DESCRIPTOR.message_types_by_name['Error'] -_RPCREQUEST = DESCRIPTOR.message_types_by_name['RpcRequest'] -_RPCRESPONSE = DESCRIPTOR.message_types_by_name['RpcResponse'] -Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), { - 'DESCRIPTOR' : _ERROR, - '__module__' : 'rpc_pb2' - # @@protoc_insertion_point(class_scope:rlmeta.rpc.Error) - }) -_sym_db.RegisterMessage(Error) - -RpcRequest = _reflection.GeneratedProtocolMessageType('RpcRequest', (_message.Message,), { - 'DESCRIPTOR' : _RPCREQUEST, - '__module__' : 'rpc_pb2' - # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcRequest) - }) -_sym_db.RegisterMessage(RpcRequest) - -RpcResponse = _reflection.GeneratedProtocolMessageType('RpcResponse', (_message.Message,), { - 'DESCRIPTOR' : _RPCRESPONSE, - '__module__' : 'rpc_pb2' - # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcResponse) - }) -_sym_db.RegisterMessage(RpcResponse) - -_RPC = DESCRIPTOR.services_by_name['Rpc'] -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _ERROR._serialized_start=25 - _ERROR._serialized_end=49 - _RPCREQUEST._serialized_start=51 - _RPCREQUEST._serialized_end=111 - _RPCRESPONSE._serialized_start=113 - _RPCRESPONSE._serialized_end=197 - _RPC._serialized_start=199 - _RPC._serialized_end=269 -# @@protoc_insertion_point(module_scope) diff --git a/rlmeta/rpc/rpc_pb2_grpc.py b/rlmeta/rpc/rpc_pb2_grpc.py deleted file mode 100644 index d46b6ad..0000000 --- a/rlmeta/rpc/rpc_pb2_grpc.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -import rlmeta.rpc.rpc_pb2 as rpc__pb2 - - -class RpcStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.RemoteCall = channel.unary_unary( - '/rlmeta.rpc.Rpc/RemoteCall', - request_serializer=rpc__pb2.RpcRequest.SerializeToString, - response_deserializer=rpc__pb2.RpcResponse.FromString, - ) - - -class RpcServicer(object): - """Missing associated documentation comment in .proto file.""" - - def RemoteCall(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_RpcServicer_to_server(servicer, server): - rpc_method_handlers = { - 'RemoteCall': grpc.unary_unary_rpc_method_handler( - servicer.RemoteCall, - request_deserializer=rpc__pb2.RpcRequest.FromString, - response_serializer=rpc__pb2.RpcResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'rlmeta.rpc.Rpc', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class Rpc(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def RemoteCall(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/RemoteCall', - rpc__pb2.RpcRequest.SerializeToString, - rpc__pb2.RpcResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/rlmeta/rpc/server.py b/rlmeta/rpc/server.py index 004f3b9..9b50f35 100644 --- a/rlmeta/rpc/server.py +++ b/rlmeta/rpc/server.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import pickle +import time import threading from typing import Any, Callable, Optional, NoReturn @@ -11,6 +11,7 @@ import rlmeta.utils.data_utils as data_utils import _rlmeta_extension.rpc as rpc +import _rlmeta_extension.rpc.rpc_utils as rpc_utils class Server(rpc.Server): @@ -60,18 +61,13 @@ def _process(self, queue: rpc.ComputationQueue, except StopIteration: return - def _wrap_func(self, task: rpc.TaskBase, func: Callable[..., Any]) -> None: + def _wrap_func(self, task: rpc.Task, func: Callable[..., Any]) -> None: batch_size = None - - # TODO: Find better serialization method to replace pickle here. - if isinstance(task, rpc.Task): - args = pickle.loads(task.args()) - kwargs = pickle.loads(task.kwargs()) - else: + args = task.args() + kwargs = task.kwargs() + if isinstance(task, rpc.BatchedTask): batch_size = task.batch_size - args = tuple(pickle.loads(i) for i in task.args()) args = data_utils.stack_fields(args) - kwargs = tuple(pickle.loads(i) for i in task.kwargs()) kwargs = data_utils.stack_fields(kwargs) # Lock to protect any state inside func. @@ -79,12 +75,10 @@ def _wrap_func(self, task: rpc.TaskBase, func: Callable[..., Any]) -> None: with self._lock: ret = func(*args, **kwargs) - if batch_size is None: - ret = pickle.dumps(ret) - else: + if batch_size is not None: if ret is None: - ret = (pickle.dumps(None),) * batch_size + ret = (None,) * batch_size else: ret = data_utils.unstack_fields(ret, batch_size) - ret = tuple(pickle.dumps(i) for i in ret) + task.set_return_value(ret) diff --git a/rlmeta/utils/stats_dict.py b/rlmeta/utils/stats_dict.py index 15d2832..7d91e24 100644 --- a/rlmeta/utils/stats_dict.py +++ b/rlmeta/utils/stats_dict.py @@ -23,6 +23,17 @@ def __init__(self, if val is not None: self.add(val) + @classmethod + def from_dict(cls, data: Dict[str, float]) -> StatsItem: + ret = cls(key=data.get("key", None)) + ret._m0 = data.get("count", 0) + ret._m1 = data.get("mean", 0.0) + std = data.get("std", 0.0) + ret._m2 = std * std * ret._m0 + ret._min_val = data.get("min", float("inf")) + ret._max_val = data.get("max", float("-inf")) + return ret + @property def key(self) -> str: return self._key @@ -84,6 +95,12 @@ def __init__(self) -> None: def __getitem__(self, key: str) -> StatsItem: return self._dict[key] + @classmethod + def from_dict(cls, data: Dict[str, Dict[str, float]]) -> StatsDict: + ret = cls() + ret._dict = {k: StatsItem.from_dict(v) for k, v in data.items()} + return ret + def reset(self): self._dict.clear() @@ -100,7 +117,7 @@ def extend(self, d: Dict[str, float]) -> None: def update(self, stats: StatsDict) -> None: self._dict.update(stats._dict) - def dict(self) -> Dict[str, float]: + def dict(self) -> Dict[str, Dict[str, float]]: return {k: v.dict() for k, v in self._dict.items()} def json(self, info: Optional[str] = None, **kwargs) -> str: diff --git a/third_party/pybind11 b/third_party/pybind11 index d4b9f34..1e3400b 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit d4b9f3471f465f0cc6d05556a837c26589b08b29 +Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae From 3d06ba66b4191a41ec56b1d7f84d58ba6ff91913 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Tue, 9 Aug 2022 20:12:59 -0700 Subject: [PATCH 4/9] Add python asyncion client to improve exploration performance --- examples/atari/dqn/atari_apex_dqn.py | 15 ++- rlmeta/agents/dqn/apex_dqn_agent.py | 21 ++-- rlmeta/core/model.py | 5 +- rlmeta/core/remote.py | 11 +- rlmeta/core/replay_buffer.py | 31 ++++-- rlmeta/rpc/CMakeLists.txt | 2 +- rlmeta/rpc/cc/client.cc | 43 ++++---- rlmeta/rpc/cc/client.h | 22 +--- rlmeta/rpc/cc/server.cc | 21 +++- rlmeta/rpc/cc/server.h | 15 ++- rlmeta/rpc/client.py | 45 +++++++- rlmeta/rpc/protos/rpc.proto | 19 +++- rlmeta/rpc/rpc_pb2.py | 151 +++++++++++++++++++++++++++ rlmeta/rpc/rpc_pb2_grpc.py | 104 ++++++++++++++++++ rlmeta/rpc/server.py | 12 +-- 15 files changed, 430 insertions(+), 87 deletions(-) create mode 100644 rlmeta/rpc/rpc_pb2.py create mode 100644 rlmeta/rpc/rpc_pb2_grpc.py diff --git a/examples/atari/dqn/atari_apex_dqn.py b/examples/atari/dqn/atari_apex_dqn.py index 5caa9cc..d142ad3 100644 --- a/examples/atari/dqn/atari_apex_dqn.py +++ b/examples/atari/dqn/atari_apex_dqn.py @@ -6,6 +6,7 @@ import copy import logging import time +import os import hydra @@ -63,8 +64,10 @@ def main(cfg): timeout=120) t_rb = make_remote_replay_buffer(rb, r_server) - env_fac = gym_wrappers.AtariWrapperFactory( - cfg.env, max_episode_steps=cfg.max_episode_steps) + t_env_fac = gym_wrappers.AtariWrapperFactory( + cfg.env, max_episode_steps=cfg.max_episode_steps, clip_rewards=True) + e_env_fac = gym_wrappers.AtariWrapperFactory( + cfg.env, max_episode_steps=cfg.max_episode_steps, clip_rewards=False) agent = ApexDQNAgent(a_model, replay_buffer=a_rb, @@ -81,7 +84,7 @@ def main(cfg): replay_buffer=t_rb) e_agent_fac = ApexDQNAgentFactory(e_model, ConstantEpsFunc(cfg.eval_eps)) - t_loop = ParallelLoop(env_fac, + t_loop = ParallelLoop(t_env_fac, t_agent_fac, t_ctrl, running_phase=Phase.TRAIN, @@ -89,7 +92,7 @@ def main(cfg): num_rollouts=cfg.num_train_rollouts, num_workers=cfg.num_train_workers, seed=cfg.train_seed) - e_loop = ParallelLoop(env_fac, + e_loop = ParallelLoop(e_env_fac, e_agent_fac, e_ctrl, running_phase=Phase.EVAL, @@ -133,4 +136,8 @@ def main(cfg): if __name__ == "__main__": mp.set_start_method("spawn") + if os.environ.get('https_proxy'): + del os.environ['https_proxy'] + if os.environ.get('http_proxy'): + del os.environ['http_proxy'] main() diff --git a/rlmeta/agents/dqn/apex_dqn_agent.py b/rlmeta/agents/dqn/apex_dqn_agent.py index fca97af..9261280 100644 --- a/rlmeta/agents/dqn/apex_dqn_agent.py +++ b/rlmeta/agents/dqn/apex_dqn_agent.py @@ -7,6 +7,8 @@ from typing import Callable, Dict, List, Optional, Sequence +import numpy as np + import torch import torch.nn as nn @@ -138,10 +140,10 @@ def train(self, num_steps: int) -> Optional[StatsDict]: console.log(f"Training for num_steps = {num_steps}") for _ in track(range(num_steps), description="Training..."): t0 = time.perf_counter() - batch, weight, index, timestamp = self.replay_buffer.sample( + index, batch, weight, timestamp = self.replay_buffer.sample( self.batch_size) t1 = time.perf_counter() - step_stats = self.train_step(batch, weight, index, timestamp) + step_stats = self.train_step(index, batch, weight, timestamp) t2 = time.perf_counter() time_stats = { "sample_data_time/ms": (t1 - t0) * 1000.0, @@ -177,20 +179,21 @@ def eval(self, num_episodes: Optional[int] = None) -> Optional[StatsDict]: def make_replay(self) -> Optional[List[NestedTensor]]: trajectory_len = len(self.trajectory) - if trajectory_len <= self.multi_step: + if trajectory_len <= 2: return None replay = [] append = replay.append - for i in range(0, trajectory_len - self.multi_step): + # for i in range(0, trajectory_len - self.multi_step): + for i in range(0, trajectory_len - 1): cur = self.trajectory[i] - nxt = self.trajectory[i + self.multi_step] + nxt = self.trajectory[min(i + self.multi_step, trajectory_len - 1)] obs = cur["obs"] act = cur["action"] next_obs = nxt["obs"] done = nxt["done"] reward = 0.0 - for j in range(self.multi_step): + for j in range(min(self.multi_step, trajectory_len - 1 - i)): reward += (self.gamma**j) * self.trajectory[i + j]["reward"] append({ "obs": obs, @@ -202,9 +205,9 @@ def make_replay(self) -> Optional[List[NestedTensor]]: return replay - def train_step(self, batch: NestedTensor, weight: torch.Tensor, - index: torch.Tensor, - timestamp: torch.Tensor) -> Dict[str, float]: + def train_step(self, index: np.ndarray, batch: NestedTensor, + weight: torch.Tensor, + timestamp: np.ndarray) -> Dict[str, float]: device = next(self.model.parameters()).device batch = nested_utils.map_nested(lambda x: x.to(device), batch) self.optimizer.zero_grad() diff --git a/rlmeta/core/model.py b/rlmeta/core/model.py index 40e8ee1..80212d7 100644 --- a/rlmeta/core/model.py +++ b/rlmeta/core/model.py @@ -45,9 +45,10 @@ def __init__(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: self._wrapped = model - self._reset(server_name, server_addr, name, timeout) + self._reset(server_name, server_addr, name, timeout, py_aio_client) # TODO: Find a better way to implement this def __getattribute__(self, attr: str) -> Any: diff --git a/rlmeta/core/remote.py b/rlmeta/core/remote.py index 26c3957..bbddb26 100644 --- a/rlmeta/core/remote.py +++ b/rlmeta/core/remote.py @@ -55,14 +55,15 @@ def __init__(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: self._target_repr = repr(target) self._server_name = server_name self._server_addr = server_addr self._remote_methods = target.remote_methods self._identifier = target.identifier - self._reset(server_name, server_addr, name, timeout) + self._reset(server_name, server_addr, name, timeout, py_aio_client) self._client_methods = {} # TODO: Find a better way to implement this @@ -126,7 +127,7 @@ def connect(self) -> None: # self._client.set_timeout(self._timeout) # self._client.connect(self._server_addr) - self._client = rpc.Client() + self._client = rpc.Client(self._py_aio_client) self._client.connect(self._server_addr) self._bind() @@ -136,13 +137,15 @@ def _reset(self, server_name: str, server_addr: str, name: Optional[str] = None, - timeout: float = 60) -> None: + timeout: float = 60, + py_aio_client: bool = True) -> None: if name is None: name = generate_random_name() self._server_name = server_name self._server_addr = server_addr self._name = name self._timeout = timeout + self._py_aio_client = py_aio_client self._client = None self._connected = False diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index a0ecfe1..8461a4a 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -16,6 +16,7 @@ from rich.console import Console import rlmeta.core.remote as remote +import rlmeta.rpc as rpc import rlmeta.utils.data_utils as data_utils import rlmeta.utils.nested_utils as nested_utils @@ -172,9 +173,10 @@ def extend(self, def sample( self, batch_size: int ) -> Tuple[NestedTensor, torch.Tensor, torch.Tensor, torch.Tensor]: - data, weight, index, timestamp = self._sample(batch_size) - return data, weight, torch.from_numpy(index), torch.from_numpy( - timestamp) + # data, weight, index, timestamp = self._sample(batch_size) + # return data, weight, torch.from_numpy(index), torch.from_numpy( + # timestamp) + return self._sample(batch_size) @remote.remote_method(batch_size=None) def update_priority(self, @@ -257,7 +259,8 @@ def _append(self, self._init_priority(index) else: self._update_priority(index, priority) - self._update_timestamp(index) + # self._update_timestamp(index) + self._timestamps.update(index) return index def _extend(self, @@ -280,7 +283,7 @@ def _sample( index = self._sum_tree.scan_lower_bound(mass) data, weight = self.__getitem__(index) timestamp = self._timestamps[index] - return data, weight, index, timestamp + return index, data, weight, timestamp class RemoteReplayBuffer(remote.Remote): @@ -292,7 +295,13 @@ def __init__(self, name: Optional[str] = None, prefetch: int = 0, timeout: float = 60) -> None: - super().__init__(target, server_name, server_addr, name, timeout) + # Disable python asyncio client for large data transmission. + super().__init__(target, + server_name, + server_addr, + name, + timeout, + py_aio_client=False) self._server_name = server_name self._server_addr = server_addr @@ -307,6 +316,16 @@ def __repr__(self): def prefetch(self) -> Optional[int]: return self._prefetch + def connect(self) -> None: + if self._connected: + return + + self._client = rpc.Client() + self._client.connect(self._server_addr) + + self._bind() + self._connected = True + def sample( self, batch_size: int ) -> Union[NestedTensor, Tuple[NestedTensor, torch.Tensor, torch.Tensor, diff --git a/rlmeta/rpc/CMakeLists.txt b/rlmeta/rpc/CMakeLists.txt index 4f35c9d..9cd2ec3 100644 --- a/rlmeta/rpc/CMakeLists.txt +++ b/rlmeta/rpc/CMakeLists.txt @@ -15,7 +15,7 @@ include(FetchContent) FetchContent_Declare( grpc GIT_REPOSITORY https://github.com/grpc/grpc - GIT_TAG v1.48.0 + GIT_TAG v1.47.0 ) set(FETCHCONTENT_QUIET OFF) diff --git a/rlmeta/rpc/cc/client.cc b/rlmeta/rpc/cc/client.cc index 9af5123..b6a2106 100644 --- a/rlmeta/rpc/cc/client.cc +++ b/rlmeta/rpc/cc/client.cc @@ -30,22 +30,30 @@ void Client::Connect(const std::string& addr, int64_t timeout) { if (!channel_->WaitForConnected(deadline)) { std::cerr << "[Client::connect] timeout" << std::endl; } - thread_ = std::make_unique(&Client::AsyncCompleteRpc, this); connected_ = true; } void Client::Disconnect() { if (connected_) { connected_ = false; - cq_.Shutdown(); - thread_->join(); - thread_.reset(); } } py::object Client::Rpc(const std::string& func, const py::args& args, const py::kwargs& kwargs) { - return RpcFuture(func, args, kwargs).Get(); + assert(connected_); + RpcRequest request; + request.set_function(func); + *request.mutable_args() = rpc_utils::PythonToNestedData(args); + *request.mutable_kwargs() = rpc_utils::PythonToNestedData(kwargs); + + NestedData ret; + { + py::gil_scoped_release release; + ret = RpcImpl(std::move(request)); + } + + return rpc_utils::NestedDataToPython(std::move(ret)); } rlmeta::rpc::RpcFuture Client::RpcFuture(const std::string& func, @@ -58,25 +66,18 @@ rlmeta::rpc::RpcFuture Client::RpcFuture(const std::string& func, *request.mutable_kwargs() = rpc_utils::PythonToNestedData(kwargs); py::gil_scoped_release release; - AsyncClientCall* call = new AsyncClientCall(); - rlmeta::rpc::RpcFuture fut = call->promise.get_future(); - call->response_reader = - stub_->PrepareAsyncRemoteCall(&call->context, request, &cq_); - call->response_reader->StartCall(); - call->response_reader->Finish(&call->response, &call->status, (void*)call); + rlmeta::rpc::RpcFuture fut = + std::async(&Client::RpcImpl, this, std::move(request)); return fut; } -void Client::AsyncCompleteRpc() { - void* got_tag; - bool ok = false; - while (cq_.Next(&got_tag, &ok)) { - std::unique_ptr call( - static_cast(got_tag)); - GPR_ASSERT(ok); - assert(call->status.ok()); - call->promise.set_value(std::move(*call->response.mutable_return_value())); - } +NestedData Client::RpcImpl(RpcRequest&& request) { + RpcResponse response; + grpc::ClientContext context; + grpc::Status status = stub_->RemoteCall(&context, request, &response); + assert(status.ok()); + NestedData ret = std::move(*response.mutable_return_value()); + return ret; } void DefineClient(py::module& m) { diff --git a/rlmeta/rpc/cc/client.h b/rlmeta/rpc/cc/client.h index 0703a67..fa86b74 100644 --- a/rlmeta/rpc/cc/client.h +++ b/rlmeta/rpc/cc/client.h @@ -22,23 +22,6 @@ namespace py = pybind11; namespace rlmeta { namespace rpc { -// The Client class is adapted from gRPC C++ async client example. -// https://github.com/grpc/grpc/blob/2d4f3c56001cd1e1f85734b2f7c5ce5f2797c38a/examples/cpp/helloworld/greeter_async_client2.cc -// -// Copyright 2015 gRPC authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - class Client { public: Client() = default; @@ -66,16 +49,13 @@ class Client { response_reader; }; - void AsyncCompleteRpc(); + NestedData RpcImpl(RpcRequest&& request); std::string addr_; bool connected_ = false; std::shared_ptr channel_; std::unique_ptr stub_; - - grpc::CompletionQueue cq_; - std::unique_ptr thread_; }; void DefineClient(py::module& m); diff --git a/rlmeta/rpc/cc/server.cc b/rlmeta/rpc/cc/server.cc index f04d7a0..b93ebbe 100644 --- a/rlmeta/rpc/cc/server.cc +++ b/rlmeta/rpc/cc/server.cc @@ -19,11 +19,24 @@ namespace rpc { grpc::Status ServiceImpl::RemoteCall(grpc::ServerContext* /* context */, const RpcRequest* request, RpcResponse* response) { - auto& func = functions_.at(request->function()); + auto& func = functions_.at(request->function()).first; *response->mutable_return_value() = func(request->args(), request->kwargs()); return grpc::Status::OK; } +grpc::Status ServiceImpl::PyRemoteCall(grpc::ServerContext* /* context */, + const PyRpcRequest* request, + PyRpcResponse* response) { + NestedData args; + NestedData kwargs; + args.ParseFromString(request->args()); + kwargs.ParseFromString(request->kwargs()); + auto& func = functions_.at(request->function()).second; + NestedData ret = func(std::move(args), std::move(kwargs)); + ret.SerializeToString(response->mutable_return_value()); + return grpc::Status::OK; +} + void Server::Start() { grpc::EnableDefaultHealthCheckService(true); grpc::reflection::InitProtoReflectionServerBuilderPlugin(); @@ -48,8 +61,12 @@ std::shared_ptr Server::RegisterQueue( ret = std::make_shared(batch_size); } service_.Register( - func_name, [que = ret](const NestedData& args, const NestedData& kwargs) { + func_name, + [que = ret](const NestedData& args, const NestedData& kwargs) { return que->Put(args, kwargs).get(); + }, + [que = ret](NestedData&& args, NestedData&& kwargs) { + return que->Put(std::move(args), std::move(kwargs)).get(); }); return ret; } diff --git a/rlmeta/rpc/cc/server.h b/rlmeta/rpc/cc/server.h index 0e533fc..938497a 100644 --- a/rlmeta/rpc/cc/server.h +++ b/rlmeta/rpc/cc/server.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "rlmeta/rpc/cc/computation_queue.h" #include "rlmeta/rpc/cc/task.h" @@ -27,12 +28,16 @@ namespace rlmeta { namespace rpc { using PyFunc = std::function; -using PyFuncDict = std::unordered_map; +using PyFuncRvalue = std::function; +using PyFuncDict = + std::unordered_map>; class ServiceImpl final : public Rpc::Service { public: - grpc::Status Register(const std::string& func_name, PyFunc&& func_impl) { - functions_.emplace(func_name, std::move(func_impl)); + grpc::Status Register(const std::string& func_name, PyFunc&& func_impl, + PyFuncRvalue&& func_rvalue_impl) { + functions_.emplace(func_name, std::make_pair(std::move(func_impl), + std::move(func_rvalue_impl))); return grpc::Status::OK; } @@ -41,6 +46,10 @@ class ServiceImpl final : public Rpc::Service { const RpcRequest* request, RpcResponse* response) override; + grpc::Status PyRemoteCall(grpc::ServerContext* context, + const PyRpcRequest* request, + PyRpcResponse* response) override; + PyFuncDict functions_; }; diff --git a/rlmeta/rpc/client.py b/rlmeta/rpc/client.py index 4e64ad7..f06c3a5 100644 --- a/rlmeta/rpc/client.py +++ b/rlmeta/rpc/client.py @@ -11,22 +11,61 @@ import grpc import grpc.experimental +import rlmeta.rpc.rpc_pb2 as rpc_pb2 +import rlmeta.rpc.rpc_pb2_grpc as rpc_pb2_grpc import _rlmeta_extension.rpc as _rpc -import _rlmeta_extension.rpc.rpc_utils as rpc_utils +import _rlmeta_extension.rpc.rpc_utils as _rpc_utils class Client(_rpc.Client): + def __init__(self, py_aio_client: bool = True) -> None: + super().__init__() + self._addr = None + self._timeout = None + self._options = None + self._connected = False + self._py_aio_client = py_aio_client + + def connect(self, addr: str, timeout: int = 60) -> None: + super().connect(addr, timeout) + self._addr = addr + self._timeout = timeout + self._options = [ + (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1), + ("grpc.enable_http_proxy", 0), + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ] + self._connected = True + def rpc(self, function: str, *args, **kwargs) -> Any: return super().rpc(function, *args, **kwargs) def rpc_future(self, function: str, *args, **kwargs) -> _rpc.RpcFuture: return super().rpc_future(function, *args, **kwargs) - def async_rpc(self, function: str, *args, **kwargs) -> Any: + async def async_rpc(self, function: str, *args, **kwargs) -> Any: + if self._py_aio_client: + return await self._async_rpc_py(function, *args, **kwargs) + else: + return await self._async_rpc_cc(function, *args, **kwargs) + + async def _async_rpc_cc(self, function: str, *args, **kwargs) -> Any: loop = asyncio.get_running_loop() ret = super().rpc_future(function, *args, **kwargs) fut = asyncio.Future() # loop.call_soon_threadsafe(lambda x: fut.set_result(x.get()), ret) loop.call_soon(lambda x: fut.set_result(x.get()), ret) - return fut + return await fut + + # TODO: Add efficient native asyncio support in C++ client. + async def _async_rpc_py(self, function: str, *args, **kwargs) -> Any: + async with grpc.aio.insecure_channel(target=self._addr, + options=self._options) as channel: + stub = rpc_pb2_grpc.RpcStub(channel) + response = await stub.PyRemoteCall( + rpc_pb2.PyRpcRequest(function=function, + args=_rpc_utils.dumps(args), + kwargs=_rpc_utils.dumps(kwargs))) + return _rpc_utils.loads(response.return_value) diff --git a/rlmeta/rpc/protos/rpc.proto b/rlmeta/rpc/protos/rpc.proto index 03a80d8..cb0e784 100644 --- a/rlmeta/rpc/protos/rpc.proto +++ b/rlmeta/rpc/protos/rpc.proto @@ -9,10 +9,7 @@ package rlmeta.rpc; service Rpc { rpc RemoteCall(RpcRequest) returns (RpcResponse) {} -} - -message Error { - optional string message = 1; + rpc PyRemoteCall(PyRpcRequest) returns (PyRpcResponse) {} } message RpcRequest { @@ -26,6 +23,20 @@ message RpcResponse { optional Error error = 2; } +message PyRpcRequest { + optional string function = 1; + optional bytes args = 2; + optional bytes kwargs = 3; +} + +message PyRpcResponse { + optional bytes return_value = 1; + optional string error = 2; +} + +message Error { + optional string msg = 1; +} message TensorProto { enum TensorType { diff --git a/rlmeta/rpc/rpc_pb2.py b/rlmeta/rpc/rpc_pb2.py new file mode 100644 index 0000000..9348f50 --- /dev/null +++ b/rlmeta/rpc/rpc_pb2.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: rpc.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\trpc.proto\x12\nrlmeta.rpc\"l\n\nRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12$\n\x04\x61rgs\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\x12&\n\x06kwargs\x18\x03 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\"]\n\x0bRpcResponse\x12,\n\x0creturn_value\x18\x01 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData\x12 \n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.rlmeta.rpc.Error\">\n\x0cPyRpcRequest\x12\x10\n\x08\x66unction\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\x12\x0e\n\x06kwargs\x18\x03 \x01(\x0c\"4\n\rPyRpcResponse\x12\x14\n\x0creturn_value\x18\x01 \x01(\x0c\x12\r\n\x05\x65rror\x18\x02 \x01(\t\"\x14\n\x05\x45rror\x12\x0b\n\x03msg\x18\x01 \x01(\t\"\xa7\x01\n\x0bTensorProto\x12\x37\n\x0btensor_type\x18\x01 \x01(\x0e\x32\".rlmeta.rpc.TensorProto.TensorType\x12\r\n\x05\x64type\x18\x02 \x01(\x05\x12\x11\n\x05shape\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"/\n\nTensorType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\t\n\x05NUMPY\x10\x01\x12\t\n\x05TORCH\x10\x02\"\xa8\x01\n\nSimpleData\x12\x12\n\x08\x62ool_val\x18\x01 \x01(\x08H\x00\x12\x11\n\x07int_val\x18\x02 \x01(\x03H\x00\x12\x13\n\tfloat_val\x18\x03 \x01(\x01H\x00\x12\x11\n\x07str_val\x18\x04 \x01(\tH\x00\x12\x13\n\tbytes_val\x18\x05 \x01(\x0cH\x00\x12-\n\ntensor_val\x18\x06 \x01(\x0b\x32\x17.rlmeta.rpc.TensorProtoH\x00\x42\x07\n\x05value\"2\n\nDataVector\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.rlmeta.rpc.NestedData\"{\n\x07\x44\x61taMap\x12+\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1d.rlmeta.rpc.DataMap.DataEntry\x1a\x43\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.NestedData:\x02\x38\x01\"\x8d\x01\n\nNestedData\x12%\n\x03val\x18\x01 \x01(\x0b\x32\x16.rlmeta.rpc.SimpleDataH\x00\x12%\n\x03vec\x18\x02 \x01(\x0b\x32\x16.rlmeta.rpc.DataVectorH\x00\x12\"\n\x03map\x18\x03 \x01(\x0b\x32\x13.rlmeta.rpc.DataMapH\x00\x42\r\n\x0bnested_data2\x8d\x01\n\x03Rpc\x12?\n\nRemoteCall\x12\x16.rlmeta.rpc.RpcRequest\x1a\x17.rlmeta.rpc.RpcResponse\"\x00\x12\x45\n\x0cPyRemoteCall\x12\x18.rlmeta.rpc.PyRpcRequest\x1a\x19.rlmeta.rpc.PyRpcResponse\"\x00') + + + +_RPCREQUEST = DESCRIPTOR.message_types_by_name['RpcRequest'] +_RPCRESPONSE = DESCRIPTOR.message_types_by_name['RpcResponse'] +_PYRPCREQUEST = DESCRIPTOR.message_types_by_name['PyRpcRequest'] +_PYRPCRESPONSE = DESCRIPTOR.message_types_by_name['PyRpcResponse'] +_ERROR = DESCRIPTOR.message_types_by_name['Error'] +_TENSORPROTO = DESCRIPTOR.message_types_by_name['TensorProto'] +_SIMPLEDATA = DESCRIPTOR.message_types_by_name['SimpleData'] +_DATAVECTOR = DESCRIPTOR.message_types_by_name['DataVector'] +_DATAMAP = DESCRIPTOR.message_types_by_name['DataMap'] +_DATAMAP_DATAENTRY = _DATAMAP.nested_types_by_name['DataEntry'] +_NESTEDDATA = DESCRIPTOR.message_types_by_name['NestedData'] +_TENSORPROTO_TENSORTYPE = _TENSORPROTO.enum_types_by_name['TensorType'] +RpcRequest = _reflection.GeneratedProtocolMessageType('RpcRequest', (_message.Message,), { + 'DESCRIPTOR' : _RPCREQUEST, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcRequest) + }) +_sym_db.RegisterMessage(RpcRequest) + +RpcResponse = _reflection.GeneratedProtocolMessageType('RpcResponse', (_message.Message,), { + 'DESCRIPTOR' : _RPCRESPONSE, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.RpcResponse) + }) +_sym_db.RegisterMessage(RpcResponse) + +PyRpcRequest = _reflection.GeneratedProtocolMessageType('PyRpcRequest', (_message.Message,), { + 'DESCRIPTOR' : _PYRPCREQUEST, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.PyRpcRequest) + }) +_sym_db.RegisterMessage(PyRpcRequest) + +PyRpcResponse = _reflection.GeneratedProtocolMessageType('PyRpcResponse', (_message.Message,), { + 'DESCRIPTOR' : _PYRPCRESPONSE, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.PyRpcResponse) + }) +_sym_db.RegisterMessage(PyRpcResponse) + +Error = _reflection.GeneratedProtocolMessageType('Error', (_message.Message,), { + 'DESCRIPTOR' : _ERROR, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.Error) + }) +_sym_db.RegisterMessage(Error) + +TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), { + 'DESCRIPTOR' : _TENSORPROTO, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.TensorProto) + }) +_sym_db.RegisterMessage(TensorProto) + +SimpleData = _reflection.GeneratedProtocolMessageType('SimpleData', (_message.Message,), { + 'DESCRIPTOR' : _SIMPLEDATA, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.SimpleData) + }) +_sym_db.RegisterMessage(SimpleData) + +DataVector = _reflection.GeneratedProtocolMessageType('DataVector', (_message.Message,), { + 'DESCRIPTOR' : _DATAVECTOR, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.DataVector) + }) +_sym_db.RegisterMessage(DataVector) + +DataMap = _reflection.GeneratedProtocolMessageType('DataMap', (_message.Message,), { + + 'DataEntry' : _reflection.GeneratedProtocolMessageType('DataEntry', (_message.Message,), { + 'DESCRIPTOR' : _DATAMAP_DATAENTRY, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.DataMap.DataEntry) + }) + , + 'DESCRIPTOR' : _DATAMAP, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.DataMap) + }) +_sym_db.RegisterMessage(DataMap) +_sym_db.RegisterMessage(DataMap.DataEntry) + +NestedData = _reflection.GeneratedProtocolMessageType('NestedData', (_message.Message,), { + 'DESCRIPTOR' : _NESTEDDATA, + '__module__' : 'rpc_pb2' + # @@protoc_insertion_point(class_scope:rlmeta.rpc.NestedData) + }) +_sym_db.RegisterMessage(NestedData) + +_RPC = DESCRIPTOR.services_by_name['Rpc'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TENSORPROTO.fields_by_name['shape']._options = None + _TENSORPROTO.fields_by_name['shape']._serialized_options = b'\020\001' + _DATAMAP_DATAENTRY._options = None + _DATAMAP_DATAENTRY._serialized_options = b'8\001' + _RPCREQUEST._serialized_start=25 + _RPCREQUEST._serialized_end=133 + _RPCRESPONSE._serialized_start=135 + _RPCRESPONSE._serialized_end=228 + _PYRPCREQUEST._serialized_start=230 + _PYRPCREQUEST._serialized_end=292 + _PYRPCRESPONSE._serialized_start=294 + _PYRPCRESPONSE._serialized_end=346 + _ERROR._serialized_start=348 + _ERROR._serialized_end=368 + _TENSORPROTO._serialized_start=371 + _TENSORPROTO._serialized_end=538 + _TENSORPROTO_TENSORTYPE._serialized_start=491 + _TENSORPROTO_TENSORTYPE._serialized_end=538 + _SIMPLEDATA._serialized_start=541 + _SIMPLEDATA._serialized_end=709 + _DATAVECTOR._serialized_start=711 + _DATAVECTOR._serialized_end=761 + _DATAMAP._serialized_start=763 + _DATAMAP._serialized_end=886 + _DATAMAP_DATAENTRY._serialized_start=819 + _DATAMAP_DATAENTRY._serialized_end=886 + _NESTEDDATA._serialized_start=889 + _NESTEDDATA._serialized_end=1030 + _RPC._serialized_start=1033 + _RPC._serialized_end=1174 +# @@protoc_insertion_point(module_scope) diff --git a/rlmeta/rpc/rpc_pb2_grpc.py b/rlmeta/rpc/rpc_pb2_grpc.py new file mode 100644 index 0000000..934ac72 --- /dev/null +++ b/rlmeta/rpc/rpc_pb2_grpc.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import rlmeta.rpc.rpc_pb2 as rpc__pb2 + + +class RpcStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.RemoteCall = channel.unary_unary( + '/rlmeta.rpc.Rpc/RemoteCall', + request_serializer=rpc__pb2.RpcRequest.SerializeToString, + response_deserializer=rpc__pb2.RpcResponse.FromString, + ) + self.PyRemoteCall = channel.unary_unary( + '/rlmeta.rpc.Rpc/PyRemoteCall', + request_serializer=rpc__pb2.PyRpcRequest.SerializeToString, + response_deserializer=rpc__pb2.PyRpcResponse.FromString, + ) + + +class RpcServicer(object): + """Missing associated documentation comment in .proto file.""" + + def RemoteCall(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PyRemoteCall(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RpcServicer_to_server(servicer, server): + rpc_method_handlers = { + 'RemoteCall': grpc.unary_unary_rpc_method_handler( + servicer.RemoteCall, + request_deserializer=rpc__pb2.RpcRequest.FromString, + response_serializer=rpc__pb2.RpcResponse.SerializeToString, + ), + 'PyRemoteCall': grpc.unary_unary_rpc_method_handler( + servicer.PyRemoteCall, + request_deserializer=rpc__pb2.PyRpcRequest.FromString, + response_serializer=rpc__pb2.PyRpcResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'rlmeta.rpc.Rpc', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Rpc(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def RemoteCall(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/RemoteCall', + rpc__pb2.RpcRequest.SerializeToString, + rpc__pb2.RpcResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PyRemoteCall(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/rlmeta.rpc.Rpc/PyRemoteCall', + rpc__pb2.PyRpcRequest.SerializeToString, + rpc__pb2.PyRpcResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/rlmeta/rpc/server.py b/rlmeta/rpc/server.py index 9b50f35..2d14962 100644 --- a/rlmeta/rpc/server.py +++ b/rlmeta/rpc/server.py @@ -9,12 +9,10 @@ from typing import Any, Callable, Optional, NoReturn import rlmeta.utils.data_utils as data_utils +import _rlmeta_extension.rpc as _rpc -import _rlmeta_extension.rpc as rpc -import _rlmeta_extension.rpc.rpc_utils as rpc_utils - -class Server(rpc.Server): +class Server(_rpc.Server): def __init__(self, addr: str): super().__init__(addr) @@ -50,7 +48,7 @@ def stop(self) -> None: print("rpc.Server stopped") - def _process(self, queue: rpc.ComputationQueue, + def _process(self, queue: _rpc.ComputationQueue, func: Callable[..., Any]) -> NoReturn: try: while True: @@ -61,11 +59,11 @@ def _process(self, queue: rpc.ComputationQueue, except StopIteration: return - def _wrap_func(self, task: rpc.Task, func: Callable[..., Any]) -> None: + def _wrap_func(self, task: _rpc.Task, func: Callable[..., Any]) -> None: batch_size = None args = task.args() kwargs = task.kwargs() - if isinstance(task, rpc.BatchedTask): + if isinstance(task, _rpc.BatchedTask): batch_size = task.batch_size args = data_utils.stack_fields(args) kwargs = data_utils.stack_fields(kwargs) From 52bb0a28914566cdfd3f2efa38fceae9af4d40c3 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 11 Aug 2022 15:30:47 -0700 Subject: [PATCH 5/9] Remove unused code --- examples/atari/dqn/atari_apex_dqn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/atari/dqn/atari_apex_dqn.py b/examples/atari/dqn/atari_apex_dqn.py index d142ad3..17e27c1 100644 --- a/examples/atari/dqn/atari_apex_dqn.py +++ b/examples/atari/dqn/atari_apex_dqn.py @@ -136,8 +136,4 @@ def main(cfg): if __name__ == "__main__": mp.set_start_method("spawn") - if os.environ.get('https_proxy'): - del os.environ['https_proxy'] - if os.environ.get('http_proxy'): - del os.environ['http_proxy'] main() From 29f9b4fe49fd970bf632f4faaa5542e24e4c21c2 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Mon, 15 Aug 2022 13:41:03 -0700 Subject: [PATCH 6/9] Remove moolib requirement --- requirements.txt | 3 +-- rlmeta/core/replay_buffer.py | 7 +------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3e16b67..ba73fe6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ gym hydra-core matplotlib -moolib@git+https://github.com/facebookresearch/moolib numpy opencv-python tabulate torch>=1.5.1 -rich \ No newline at end of file +rich diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index 8461a4a..e6202df 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -296,12 +296,7 @@ def __init__(self, prefetch: int = 0, timeout: float = 60) -> None: # Disable python asyncio client for large data transmission. - super().__init__(target, - server_name, - server_addr, - name, - timeout, - py_aio_client=False) + super().__init__(target, server_name, server_addr, name, timeout) self._server_name = server_name self._server_addr = server_addr From 88e8bb6b8bfc11020f54827cc208bd464f7c0367 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Mon, 15 Aug 2022 14:22:47 -0700 Subject: [PATCH 7/9] Add grpcio requirement --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ba73fe6..2890d42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ +grpcio==1.47.0 gym hydra-core matplotlib numpy opencv-python +rich tabulate torch>=1.5.1 -rich From c7267964f355b905a5b46de8853c09f688509036 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Mon, 15 Aug 2022 18:28:53 -0700 Subject: [PATCH 8/9] Disable python async client for replay_buffer --- rlmeta/core/replay_buffer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index e6202df..a2d59ab 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -121,7 +121,7 @@ def __init__(self, self._priority_type = priority_type self._sum_tree = SumSegmentTree(capacity, dtype=priority_type) - self._min_tree = MinSegmentTree(capacity, dtype=priority_type) + # self._min_tree = MinSegmentTree(capacity, dtype=priority_type) self._max_priority = 1.0 self._timestamps = TimestampManager(capacity) @@ -201,7 +201,7 @@ def warm_up(self, learning_starts: Optional[int] = None) -> None: def _init_priority(self, index) -> None: priority = self._max_priority**self.alpha self._sum_tree[index] = priority - self._min_tree[index] = priority + # self._min_tree[index] = priority def _update_priority( self, @@ -230,17 +230,18 @@ def _update_priority( if isinstance(index, int): if mask is None or mask: self._sum_tree[index] = priority - self._min_tree[index] = priority + # self._min_tree[index] = priority else: self._sum_tree.update(index, priority, mask) - self._min_tree.update(index, priority, mask) + # self._min_tree.update(index, priority, mask) def _compute_weight( self, index: Union[int, Tensor]) -> Union[float, torch.Tensor]: p = self._sum_tree[index] if isinstance(p, np.ndarray): p = torch.from_numpy(p) - p_min = self._min_tree.query(0, self.capacity) + # p_min = self._min_tree.query(0, self.capacity) + p_min = p.min() # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta) @@ -296,7 +297,12 @@ def __init__(self, prefetch: int = 0, timeout: float = 60) -> None: # Disable python asyncio client for large data transmission. - super().__init__(target, server_name, server_addr, name, timeout) + super().__init__(target, + server_name, + server_addr, + name, + timeout, + py_aio_client=False) self._server_name = server_name self._server_addr = server_addr From 94b7c48231b587ade4c7bf143e057771ab2ba43a Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Mon, 15 Aug 2022 23:23:36 -0700 Subject: [PATCH 9/9] Disable python async client for replay_buffer --- rlmeta/core/replay_buffer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/rlmeta/core/replay_buffer.py b/rlmeta/core/replay_buffer.py index a2d59ab..a06a744 100644 --- a/rlmeta/core/replay_buffer.py +++ b/rlmeta/core/replay_buffer.py @@ -317,15 +317,15 @@ def __repr__(self): def prefetch(self) -> Optional[int]: return self._prefetch - def connect(self) -> None: - if self._connected: - return - - self._client = rpc.Client() - self._client.connect(self._server_addr) - - self._bind() - self._connected = True + # def connect(self) -> None: + # if self._connected: + # return + # + # self._client = rpc.Client() + # self._client.connect(self._server_addr) + # + # self._bind() + # self._connected = True def sample( self, batch_size: int