From 96889051f888abd142298f451f1283d0d8eb3902 Mon Sep 17 00:00:00 2001 From: Hanyu Zhao Date: Mon, 28 Feb 2022 19:55:15 +0800 Subject: [PATCH 1/2] Add prototypes for data workers. --- .../data_worker_controller.h | 88 +++++++++ .../data_worker_rendezvous_mgr.h | 160 +++++++++++++++++ .../data_worker_rendezvous_mgr_interface.h | 49 +++++ .../core/framework/data_worker_rendezvous.h | 167 ++++++++++++++++++ .../core/graph/data_worker_graph_partition.h | 57 ++++++ 5 files changed, 521 insertions(+) create mode 100644 tensorflow/core/distributed_runtime/data_worker_controller.h create mode 100644 tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr.h create mode 100644 tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h create mode 100644 tensorflow/core/framework/data_worker_rendezvous.h create mode 100644 tensorflow/core/graph/data_worker_graph_partition.h diff --git a/tensorflow/core/distributed_runtime/data_worker_controller.h b/tensorflow/core/distributed_runtime/data_worker_controller.h new file mode 100644 index 00000000000..c27c51d5755 --- /dev/null +++ b/tensorflow/core/distributed_runtime/data_worker_controller.h @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_CONTROLLER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_CONTROLLER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/data_worker_graph_partition.h" + +namespace tensorflow { + +// Maintains and dispatches subgraphs to data workers. +class DataWorkerController { +private: + // A data-processing graph partitioned from each *training* + // worker task to be dispatched to data workers. + struct TaskDataWorkerGraph { + string task_name; + std::vector> + registered_data_workers; + std::shared_ptr g; + // Names of the DataWorkerSend ops that should be run + // on the data worker clients. + std::vector node_names; + // Names of the tensors to be sent from data workers. + std::vector tensor_names; + int num_registered() const { return registered_data_workers.size(); } + void RegisterDataWorker(const string& name, const string& host_port) { + registered_data_workers.emplace_back(name, host_port); + } + + TaskDataWorkerGraph(const string& task_name, + std::shared_ptr g, + const std::vector& node_names, + const std::vector& tensor_names) + : task_name(task_name), g(g), node_names(node_names), tensor_names(tensor_names) {} + ~TaskDataWorkerGraph() {} + }; + + mutex mu_; + std::vector graphs_ GUARDED_BY(mu_); + bool use_default_split_points_ = true; + bool extend_default_split_ = false; + bool fuse_recv_ = false; + // Used for sequencing DataWorkerSend/Recv nodes. + int64 next_node_id_ GUARDED_BY(mu_) = 0; + + // Returns the graph that has been allocated to the least number of data workers. + TaskDataWorkerGraph& GetGraphForNewDataWorker(); + // Resets the device names to the target data worker. + void ResetDeviceNamesForGraph(GraphDef* const g, const string& dw_name); + void ResetDeviceNameForNode(NodeDef* node, const string& dw_name); + +public: + DataWorkerController() {} + DataWorkerController(bool use_default_split_points, bool extend_default_split, bool fuse_recv); + ~DataWorkerController() {} + Status Partition(Graph* g, PartitionForDataWorkerOptions& popts); + Status RegisterDataWorker(GraphDef* dst_graph, + const string& name, + const string& host_port, + string& training_worker_name, + std::vector& node_names); + const std::vector* GetTensorNames(const string& task_name); +}; + +} // namespace tensorflow + + +#endif diff --git a/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr.h new file mode 100644 index 00000000000..097ef7f6dff --- /dev/null +++ b/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr.h @@ -0,0 +1,160 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_ + +#include +#include +#include + +#include "tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h" +#include "tensorflow/core/framework/data_worker_rendezvous.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/distributed_runtime/worker_session.h" + +namespace tensorflow { +class GenericDataWorkerRendezvous; +class DataWorkerRecvTensorThread; + +class DataWorkerRendezvousMgr: public DataWorkerRendezvousMgrInterface{ + public: + struct DataWorkerRendezvousMgrOptions{ + int queue_size = 100; + int num_recv_threads = 1; + int num_send_threads = 4; + string protocol = "grpc"; + bool fuse_recv = false; + }; + + explicit DataWorkerRendezvousMgr(const DataWorkerRendezvousMgrOptions& options); + ~DataWorkerRendezvousMgr(); + + void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key, + DataWorkerRendezvous::DoneCallback done) override; + + void FuseRecvLocalAsync(const std::vector& keys, + DataWorkerRendezvous::FuseDoneCallback done) override; + + void RegisterDataWorker(const string& task_name, const string& host_port) override; + + void SetTensorNames(const std::vector& tensor_names) override; + + DataWorkerRendezvous* Find(); + + private: + mutex mu_; + const int queue_size_; + const int num_recv_threads_; + const int num_send_threads_; + const string protocol_; + const bool fuse_recv_; + GenericDataWorkerRendezvous* rdwr_ GUARDED_BY(mu_); + GenericDataWorkerRendezvous* FindOrCreate(); +}; + +// GenericDataWorkerRendezvous supports both grpc and grpc++ as the underlying +// communication protocol. It also supports transferring data from the local +// data worker directly. +class GenericDataWorkerRendezvous: public DataWorkerRendezvous { + public: + GenericDataWorkerRendezvous(const int& queue_size, + const int& num_recv_threads, + const int& num_send_threads, + const string& protocol, + const bool& fuse_recv); + ~GenericDataWorkerRendezvous(); + + Status Initialize(WorkerSession* session) override; + void StartAbort(const Status& status) override; + void SetTensorNames(const std::vector& tensor_names) override; + Status SetRecvAttrs(const DataWorkerRendezvous::ParsedKey& key, + const AllocatorAttributes& alloc_attrs, + const string& device) override; + void DataWorkerSendAsync(const DataWorkerRendezvous::ParsedKey& key, + const Tensor& val, + const DataWorkerRendezvous::Args& send_args, + DataWorkerRendezvous::DoneCallback done) override; + Status LocalDataWorkerSend(const DataWorkerRendezvous::ParsedKey& key, + const string& tensor_name, + const Tensor& val, + const DataWorkerRendezvous::Args& send_args) override; + void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key, DataWorkerRendezvous::DoneCallback done) override; + void FuseRecvLocalAsync(const std::vector& keys, + DataWorkerRendezvous::FuseDoneCallback done) override; + void DataWorkerRecvAsync(const DataWorkerRendezvous::ParsedKey& key, + const DataWorkerRendezvous::Args& recv_args, + DataWorkerRendezvous::DoneCallback done) override; + void DataWorkerFuseRecvAsync(const DataWorkerRendezvous::Args& recv_args, + DataWorkerRendezvous::FuseDoneCallback done) override; + void RegisterDataWorker(const string& task_name, const string& host_port); + + private: + void RecvAsync(const DataWorkerRendezvous::ParsedKey& key, + const DataWorkerRendezvous::Args& recv_args, + DataWorkerRendezvous::DoneCallback done); + void EnqueueRecvItems(std::vector& items); + void EnqueueFuseRecvItem(FuseItem* item); + void SameWorkerRecvDone(const DataWorkerRendezvous::ParsedKey& parsed, + const DataWorkerRendezvous::Args& send_args, + const DataWorkerRendezvous::Args& recv_args, + const Tensor& in, Tensor* out, StatusCallback done); + + static uint64 KeyHash(const StringPiece& k) { + return Hash64(k.data(), k.size()); + } + + const string protocol_; + const bool fuse_recv_; + mutex attrs_mu_; + std::unordered_map> recv_nodes_attrs_ GUARDED_BY(attrs_mu_); + const int num_recv_threads_; + std::vector> recv_threads_; + std::unique_ptr send_threads_; + + typedef std::deque ItemQueue; + typedef std::deque FuseItemQueue; + typedef gtl::FlatMap Table; + + std::mutex mu_; + std::mutex local_tmp_mu_; + std::condition_variable cv_; + Status status_ GUARDED_BY(mu_); + WorkerSession* session_ GUARDED_BY(mu_); + + // Table is used for both data workers and training workers for storing the items to enable async execution: + // Data workers put the produced tensors in the Table and wait for the training workers + // to fetch them. Training workers put the fetched tensors in their local Table. + Table table_ GUARDED_BY(mu_); + FuseItemQueue fuse_queue_ GUARDED_BY(mu_); + std::vector local_tmp_ GUARDED_BY(local_tmp_mu_); + std::vector tensor_names_; + const int queue_size_; + + friend class DataWorkerRecvTensorThread; + friend class GrpcDataWorkerRecvTensorThread; + friend class StarDataWorkerRecvTensorThread; + TF_DISALLOW_COPY_AND_ASSIGN(GenericDataWorkerRendezvous); +}; + + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_ \ No newline at end of file diff --git a/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h new file mode 100644 index 00000000000..7f351ae81d2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h @@ -0,0 +1,49 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/data_worker_rendezvous.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class DataWorkerRendezvousMgrInterface { + public: + DataWorkerRendezvousMgrInterface() {} + virtual ~DataWorkerRendezvousMgrInterface() {} + + virtual DataWorkerRendezvous* Find() = 0; + + virtual void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key, + DataWorkerRendezvous::DoneCallback done) = 0; + + virtual void FuseRecvLocalAsync(const std::vector& keys, + DataWorkerRendezvous::FuseDoneCallback done) = 0; + + virtual void RegisterDataWorker(const string& task_name, const string& host_port) = 0; + + virtual void SetTensorNames(const std::vector& tensor_names) = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_ diff --git a/tensorflow/core/framework/data_worker_rendezvous.h b/tensorflow/core/framework/data_worker_rendezvous.h new file mode 100644 index 00000000000..06ab2815423 --- /dev/null +++ b/tensorflow/core/framework/data_worker_rendezvous.h @@ -0,0 +1,167 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_DATA_WORKER_RENDEZVOUS_H_ +#define TENSORFLOW_FRAMEWORK_DATA_WORKER_RENDEZVOUS_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +struct WorkerSession; + +class DataWorkerRendezvous : public core::RefCounted { + public: + explicit DataWorkerRendezvous() {} + struct Args { + AllocatorAttributes alloc_attrs; + DeviceContext* device_context = nullptr; + CancellationManager* cancellation_manager = nullptr; // not owned. + bool local = false; + string device_type; + string device; + }; + + struct DataWorkerInfo { + string task_name; + string host_port; + string tensor_name; + DataWorkerInfo() {} + DataWorkerInfo(string task_name, string host_port, string tensor_name) + : task_name(task_name), host_port(host_port), tensor_name(tensor_name) {} + }; + + virtual ~DataWorkerRendezvous() {} + + // Constructs a data worker rendezvous key for the tensor of "name". + static string CreateKey(const string& name); + + // Parses the key constructed by CreateKey and parse src/dst device + // names into structures respectively. + // (TODO) Reserved for future design of multi-edge data worker send/recv op + struct ParsedKey { + StringPiece tensor_name; + ParsedKey() {} + ParsedKey(const ParsedKey& b) { *this = b; } + + ParsedKey& operator=(const ParsedKey& b); + StringPiece FullKey() const { return buf_; } + + private: + friend class DataWorkerRendezvous; + friend class BaseDataWorkerSendOp; + friend class DataWorkerSendOp; + friend class LocalDataWorkerSendOp; + friend class DataWorkerRecvOp; + friend class DataWorkerFuseRecvOp; + string buf_; + }; + static Status ParseKey(StringPiece key, ParsedKey* out); + + // Synchronous wrapper for RecvAsync. + Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead, int64 timeout_ms); + Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead); + + virtual Status Initialize(WorkerSession* session) = 0; + virtual void StartAbort(const Status& status) = 0; + virtual void SetTensorNames(const std::vector& tensor_names) = 0; + // DataWorkerRendezvous receives tensors asynchronously and cannot + // rely on the RecvOp to set the attrs. We therefore expose this + // interface to set the attrs in advance. + virtual Status SetRecvAttrs(const ParsedKey& key, const AllocatorAttributes& alloc_attrs, const string& device) = 0; + + typedef std::function + DoneCallback; + typedef std::function&, + const Args&, + const std::vector&)> + // is_dead omitted + FuseDoneCallback; + struct Item { + DoneCallback waiter = nullptr; + DataWorkerInfo data_worker_info; + Tensor value; + bool is_dead = false; + ParsedKey key; + Args send_args; + Args recv_args; + CancellationToken cancellation_token; + // Returns true iff this item represents a value being sent. + bool IsSendValue() const { return this->waiter == nullptr; } + + ~Item() { + if (send_args.device_context) { + send_args.device_context->Unref(); + } + if (recv_args.device_context) { + recv_args.device_context->Unref(); + } + } + }; + struct FuseItem { + FuseDoneCallback waiter = nullptr; + DataWorkerInfo data_worker_info; + std::vector values; + std::vector keys; + std::vector send_args; + Args recv_args; + CancellationToken cancellation_token; + // Returns true iff this item represents a value being sent. + bool IsSendValue() const { return this->waiter == nullptr; } + + ~FuseItem() { + for (Args args : send_args) { + if (args.device_context) { + args.device_context->Unref(); + } + } + if (recv_args.device_context) { + recv_args.device_context->Unref(); + } + } + }; + + // We make the send operation an async method because it may block due to the queue size limit. + virtual void DataWorkerSendAsync(const ParsedKey& key, + const Tensor& val, + const Args& send_args, + DoneCallback done) = 0; + virtual Status LocalDataWorkerSend(const ParsedKey& key, + const string& tensor_name, + const Tensor& val, + const Args& send_args) = 0; + virtual void RecvLocalAsync(const ParsedKey& key, DoneCallback done) = 0; + virtual void FuseRecvLocalAsync(const std::vector& keys, FuseDoneCallback done) = 0; + virtual void DataWorkerRecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) = 0; + virtual void DataWorkerFuseRecvAsync(const Args& recv_args, + FuseDoneCallback done) = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_DATA_WORKER_RENDEZVOUS_H_ \ No newline at end of file diff --git a/tensorflow/core/graph/data_worker_graph_partition.h b/tensorflow/core/graph/data_worker_graph_partition.h new file mode 100644 index 00000000000..84fbeafebf8 --- /dev/null +++ b/tensorflow/core/graph/data_worker_graph_partition.h @@ -0,0 +1,57 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_DATA_WORKER_GRAPH_PARTITION_H_ +#define TENSORFLOW_CORE_GRAPH_DATA_WORKER_GRAPH_PARTITION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" + + +namespace tensorflow { + +struct PartitionForDataWorkerOptions { + // A function that returns a unique graph node name with the given + // prefix. + typedef std::function NewNameFunc; + NewNameFunc new_name = nullptr; + + // If specified, flib_def defines a function library that should be + // partitioned and replicated into each resulting partition graphs. + const FunctionLibraryDefinition* flib_def = nullptr; + + std::unordered_set targets; + bool use_default_split_points = true; + bool extend_default_split = false; + bool fuse_recv = false; +}; + +Status PartitionForDataWorker( + const PartitionForDataWorkerOptions& opts, + Graph* g, + std::unordered_map>& data_worker_graphs, + std::unordered_map>& node_names, + std::unordered_map>& tensor_names); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_DATA_WORKER_GRAPH_PARTITION_H_ \ No newline at end of file From 6ec534b8d2b00b9e69eae27080c379cdd35d0cd7 Mon Sep 17 00:00:00 2001 From: Hanyu Zhao Date: Mon, 28 Feb 2022 20:03:12 +0800 Subject: [PATCH 2/2] add DataWorkerSendRecv ops --- .../core/kernels/data_worker_sendrecv_ops.h | 92 ++++++++++++ .../core/ops/data_worker_sendrecv_ops.cc | 136 ++++++++++++++++++ 2 files changed, 228 insertions(+) create mode 100644 tensorflow/core/kernels/data_worker_sendrecv_ops.h create mode 100644 tensorflow/core/ops/data_worker_sendrecv_ops.cc diff --git a/tensorflow/core/kernels/data_worker_sendrecv_ops.h b/tensorflow/core/kernels/data_worker_sendrecv_ops.h new file mode 100644 index 00000000000..6ed3623c0be --- /dev/null +++ b/tensorflow/core/kernels/data_worker_sendrecv_ops.h @@ -0,0 +1,92 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_WORKER_SENDRECV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_WORKER_SENDRECV_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/data_worker_rendezvous.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class BaseDataWorkerSendOp : public AsyncOpKernel { + public: + explicit BaseDataWorkerSendOp(OpKernelConstruction* ctx); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + protected: + virtual void Send(OpKernelContext* ctx, + const DataWorkerRendezvous::Args& args, + DoneCallback done) = 0; + string tensor_name_; + string send_device_; + string send_device_type_; + DataWorkerRendezvous::ParsedKey parsed_key_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(BaseDataWorkerSendOp); +}; + +class DataWorkerSendOp : public BaseDataWorkerSendOp { + public: + explicit DataWorkerSendOp(OpKernelConstruction* ctx) + : BaseDataWorkerSendOp(ctx) {} + + void Send(OpKernelContext* ctx, + const DataWorkerRendezvous::Args& args, + DoneCallback done) override; +}; + +class LocalDataWorkerSendOp : public BaseDataWorkerSendOp { + public: + explicit LocalDataWorkerSendOp(OpKernelConstruction* ctx) + : BaseDataWorkerSendOp(ctx) {} + + void Send(OpKernelContext* ctx, + const DataWorkerRendezvous::Args& args, + DoneCallback done) override; +}; + +class DataWorkerRecvOp : public AsyncOpKernel { + public: + explicit DataWorkerRecvOp(OpKernelConstruction* ctx); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + string recv_device_; + string recv_device_type_; + DataWorkerRendezvous::ParsedKey parsed_key_; + bool attrs_set_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(DataWorkerRecvOp); +}; + +class DataWorkerFuseRecvOp : public AsyncOpKernel { + public: + explicit DataWorkerFuseRecvOp(OpKernelConstruction* ctx); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + string recv_device_; + string recv_device_type_; + std::vector parsed_keys_; + bool attrs_set_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(DataWorkerFuseRecvOp); +}; +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_WORKER_SENDRECV_OPS_H_ diff --git a/tensorflow/core/ops/data_worker_sendrecv_ops.cc b/tensorflow/core/ops/data_worker_sendrecv_ops.cc new file mode 100644 index 00000000000..a4fe4a31e01 --- /dev/null +++ b/tensorflow/core/ops/data_worker_sendrecv_ops.cc @@ -0,0 +1,136 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("_DataWorkerSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor from data worker to training worker. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +)doc"); + +REGISTER_OP("_LocalDataWorkerSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor from the local data worker +(the data worker subgraph that still resides +in the training worker) to training worker. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +)doc"); + +REGISTER_OP("_DataWorkerRecv") + .Output("tensor: tensor_type") + .Attr("tensor_type: type") + .Attr("tensor_name: string") + .Attr("recv_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Receives the named tensor from data worker. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +recv_device: The name of the device receiving the tensor. +)doc"); + +REGISTER_OP("_HostDataWorkerSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor from data worker to training worker. + +_HostDataWorkerSend requires its input on host memory whereas _DataWorkerSend requires its +input on device memory. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +)doc"); + +REGISTER_OP("_HostLocalDataWorkerSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor from the local data worker +(the data worker subgraph that still resides +in the training worker) to training worker. + +_HostLocalDataWorkerSend requires its input on host memory whereas _LocalDataWorkerSend requires its +input on device memory. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +)doc"); + +REGISTER_OP("_HostDataWorkerRecv") + .Output("tensor: tensor_type") + .Attr("tensor_type: type") + .Attr("tensor_name: string") + .Attr("recv_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Receives the named tensor from data worker. + +_HostDataWorkerRecv produces its output on host memory whereas _DataWorkerRecv produces its +output on device memory. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +recv_device: The name of the device receiving the tensor. +)doc"); + +REGISTER_OP("_DataWorkerFuseRecv") + .Output("tensors: tensor_types") + .Attr("tensor_types: list(type)") + .Attr("tensor_names: list(string)") + .Attr("recv_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("_HostDataWorkerFuseRecv") + .Output("tensors: tensor_types") + .Attr("tensor_types: list(type)") + .Attr("tensor_names: list(string)") + .Attr("recv_device: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape); + +} // end namespace tensorflow