Commit 40eea6d

pritamdamania authored and facebook-github-bot committed
Support device map for distributed autograd while using TensorPipe. (pytorch#44859)
Summary:
Pull Request resolved: pytorch#44859

TensorPipe's `set_device_map` option was applied during the forward pass. However, when we ran the backward pass for the graph, we would not automatically pick up the reverse device mapping. As a result, users had to specify both the forward and backward device mappings, which is very tedious.

In this PR, I've added this functionality so that TensorPipe automatically picks up the reverse device mapping during the backward pass. This is done by storing the appropriate device mapping in the "recv" autograd function for distributed autograd.

Closes: pytorch#44170
ghstack-source-id: 119950842

Test Plan:
1) waitforbuildbot
2) Unit test added.

Reviewed By: mrshenli

Differential Revision: D23751975

fbshipit-source-id: 2717d0ef5bde3db029a6172d98aad95734d52140
1 parent 6d09809 commit 40eea6d

20 files changed, +224 -63 lines
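
The core of the change: the forward device map configured via `set_device_map` is carried along with the "recv" autograd function, and the receiving worker inverts it before gradients are sent back. Below is a minimal standalone sketch of that inversion, using plain std types rather than the actual RpcAgent/RpcWithAutograd classes touched by this commit; the `reverseDeviceMap` helper and the concrete device indices are illustrative only.

    // Minimal sketch (assumption: plain std types standing in for c10::DeviceIndex
    // and the real RPC classes). Mirrors the reversal loop this commit adds to
    // processForwardAutogradReq on the server side.
    #include <cstdint>
    #include <iostream>
    #include <unordered_map>

    using DeviceIndex = int16_t; // stand-in for c10::DeviceIndex
    using DeviceMap = std::unordered_map<DeviceIndex, DeviceIndex>;

    // Invert the forward map so gradients sent during the backward pass land
    // back on the devices the corresponding tensors originally came from.
    DeviceMap reverseDeviceMap(const DeviceMap& forwardMap) {
      DeviceMap reversed;
      for (const auto& mapEntry : forwardMap) {
        reversed.insert({mapEntry.second, mapEntry.first});
      }
      return reversed;
    }

    int main() {
      // Forward pass: caller cuda:0 -> callee cuda:1, caller cuda:2 -> callee cuda:3.
      DeviceMap forwardMap = {{0, 1}, {2, 3}};
      // Backward pass: callee cuda:1 -> caller cuda:0, callee cuda:3 -> caller cuda:2.
      for (const auto& entry : reverseDeviceMap(forwardMap)) {
        std::cout << "cuda:" << entry.first << " -> cuda:" << entry.second << "\n";
      }
      return 0;
    }

With this in place, a user only needs to configure the forward mapping; the backward mapping is derived on the receiving worker, as the diffs below show.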

test/cpp/rpc/e2e_test_base.h

+8
@@ -40,6 +40,14 @@ class TestE2EBase : public ::testing::Test {
     RpcAgent::setCurrentRpcAgent(rpcAgent);
     std::shared_ptr<TypeResolver> typeResolver =
         std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
+          // For Dict that is used for device map.
+          auto pos = qn.name().find("Dict");
+          if (pos != std::string::npos) {
+            return c10::StrongTypePtr(
+                nullptr,
+                c10::DictType::create(
+                    c10::IntType::create(), c10::IntType::create()));
+          }
           return c10::StrongTypePtr(
               nullptr, c10::TensorType::create(at::Tensor()));
         });

torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp

+7 -3

@@ -13,10 +13,12 @@ using torch::autograd::variable_list;
 RecvRpcBackward::RecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     ContextPtr autogradContext,
-    rpc::worker_id_t fromWorkerId)
+    rpc::worker_id_t fromWorkerId,
+    std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
     : autogradMetadata_(autogradMetadata),
       autogradContext_(std::move(autogradContext)),
-      fromWorkerId_(fromWorkerId) {}
+      fromWorkerId_(fromWorkerId),
+      deviceMap_(std::move(deviceMap)) {}

 variable_list RecvRpcBackward::apply(variable_list&& grads) {
   std::vector<Variable> outputGrads;
@@ -49,7 +51,9 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) {
   auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
   auto jitFuture = rpcAgent->send(
       rpcAgent->getWorkerInfo(fromWorkerId_),
-      std::move(gradCall).toMessage());
+      std::move(gradCall).toMessage(),
+      rpc::kUnsetRpcTimeout,
+      deviceMap_);

   // Record the future in the context.
   sharedContext->addOutstandingRpc(jitFuture);

torch/csrc/distributed/autograd/functions/recvrpc_backward.h

+5 -1

@@ -22,7 +22,8 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
   explicit RecvRpcBackward(
       const AutogradMetadata& autogradMetadata,
       std::shared_ptr<DistAutogradContext> autogradContext,
-      rpc::worker_id_t fromWorkerId);
+      rpc::worker_id_t fromWorkerId,
+      std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap);

   torch::autograd::variable_list apply(
       torch::autograd::variable_list&& grads) override;
@@ -38,6 +39,9 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
   // The worker id from which the RPC was received. During the backward pass,
   // we need to propagate the gradients to this workerId.
   rpc::worker_id_t fromWorkerId_;
+
+  // Device mapping for tensors sent over RPC.
+  const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
 };

 } // namespace autograd

torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp

+31 -7

@@ -18,11 +18,13 @@ RpcWithAutograd::RpcWithAutograd(
     worker_id_t fromWorkerId,
     MessageType messageType,
     const AutogradMetadata& autogradMetadata,
-    rpc::Message&& wrappedMessage)
+    rpc::Message&& wrappedMessage,
+    std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
     : fromWorkerId_(fromWorkerId),
       messageType_(messageType),
       autogradMetadata_(autogradMetadata),
-      wrappedMessage_(std::move(wrappedMessage)) {
+      wrappedMessage_(std::move(wrappedMessage)),
+      deviceMap_(std::move(deviceMap)) {
   TORCH_INTERNAL_ASSERT(
       messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
       messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
@@ -36,13 +38,15 @@ RpcWithAutograd::RpcWithAutograd(
     const AutogradMetadata& autogradMetadata,
     std::unique_ptr<RpcCommandBase> wrappedRpc,
     MessageType wrappedMessageType,
-    std::vector<torch::Tensor> tensors)
+    std::vector<torch::Tensor> tensors,
+    std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
     : fromWorkerId_(fromWorkerId),
       messageType_(messageType),
       autogradMetadata_(autogradMetadata),
       wrappedRpc_(std::move(wrappedRpc)),
       wrappedMessageType_(wrappedMessageType),
-      tensors_(std::move(tensors)) {
+      tensors_(std::move(tensors)),
+      deviceMap_(std::move(deviceMap)) {
   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
   TORCH_INTERNAL_ASSERT(
       messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
@@ -56,10 +60,17 @@ Message RpcWithAutograd::toMessageImpl() && {
   auto payload = std::move(wrappedMessage_).movePayload();
   TORCH_INTERNAL_ASSERT(!payload.empty());

+  // Convert deviceMap to c10::Dict for serialization.
+  c10::Dict<int64_t, int64_t> deviceMap;
+  for (const auto& mapEntry : deviceMap_) {
+    deviceMap.insert(mapEntry.first, mapEntry.second);
+  }
+
   std::vector<at::IValue> ivalues{wrappedMessageType,
                                   autogradMetadata_.autogradContextId,
                                   autogradMetadata_.autogradMessageId,
-                                  fromWorkerId_};
+                                  fromWorkerId_,
+                                  deviceMap};

   // Now pickle using JIT pickler.
   std::vector<torch::Tensor> tensorTable;
@@ -92,12 +103,19 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
   auto tupleElements = rpc::readWrappedPayload(payload, message);

   // Gather all the fields.
-  TORCH_INTERNAL_ASSERT(tupleElements.size() == 4);
+  TORCH_INTERNAL_ASSERT(tupleElements.size() == 5);
   MessageType wrappedMessageType =
       static_cast<MessageType>(tupleElements[0].toInt());
   AutogradMetadata autogradMetadata(
       tupleElements[1].toInt(), tupleElements[2].toInt());
   worker_id_t workerId = tupleElements[3].toInt();
+  auto c10DeviceMap = tupleElements[4].to<c10::Dict<int64_t, int64_t>>();
+
+  // Convert to regular map.
+  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap;
+  for (const auto& mapEntry : c10DeviceMap) {
+    deviceMap.insert({mapEntry.key(), mapEntry.value()});
+  }

   // Create new message type and build wrapped RPC.
   Message wrappedMessage(
@@ -116,7 +134,8 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
       autogradMetadata,
       std::move(wrappedRpc),
       wrappedMessageType,
-      wrappedMessage.tensors());
+      wrappedMessage.tensors(),
+      deviceMap);
 }

 std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
@@ -150,6 +169,11 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
   return fromWorkerId_;
 }

+const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& RpcWithAutograd::
+    deviceMap() {
+  return deviceMap_;
+}
+
 } // namespace autograd
 } // namespace distributed
 } // namespace torch

torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h

+10 -2

@@ -18,7 +18,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
       rpc::worker_id_t fromWorkerId,
       rpc::MessageType messageType,
       const AutogradMetadata& autogradMetadata,
-      rpc::Message&& wrappedMessage);
+      rpc::Message&& wrappedMessage,
+      std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

   // Used when receiving an RPC over the wire.
   RpcWithAutograd(
@@ -27,7 +28,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
       const AutogradMetadata& autogradMetadata,
       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
       rpc::MessageType wrappedMessageType,
-      std::vector<torch::Tensor> tensors);
+      std::vector<torch::Tensor> tensors,
+      std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

   rpc::Message toMessageImpl() && override;

@@ -52,6 +54,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
   // Retrieve the worker id from which the RPC originated.
   rpc::worker_id_t fromWorkerId() const;

+  // Retrieve the device map.
+  const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap();
+
  private:
   // WorkerId from which this RPC originated. This is necessary for knowing
   // which worker we need to contact during the backward pass.
@@ -83,6 +88,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {

   // Tensors part of the wrappedRpc that need to be considered for autograd.
   std::vector<torch::Tensor> tensors_;
+
+  // Device mapping for tensors that are sent across an RPC to another node.
+  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
 };

 } // namespace autograd

torch/csrc/distributed/autograd/utils.cpp

+9 -5

@@ -52,7 +52,8 @@ void addSendRpcBackward(
 ContextPtr addRecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
-    rpc::worker_id_t fromWorkerId) {
+    rpc::worker_id_t fromWorkerId,
+    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
   // Initialize autograd context if necessary.
   auto& autogradContainer = DistAutogradContainer::getInstance();
   auto autogradContext =
@@ -61,7 +62,7 @@ ContextPtr addRecvRpcBackward(
   if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
     // Attach the tensors as inputs to the autograd function.
     auto grad_fn = std::make_shared<RecvRpcBackward>(
-        autogradMetadata, autogradContext, fromWorkerId);
+        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
     for (auto& tensor : tensors) {
       if (tensor.requires_grad()) {
         torch::autograd::set_history(tensor, grad_fn);
@@ -102,7 +103,8 @@ Message getMessageWithAutograd(
     const rpc::worker_id_t dstId,
     torch::distributed::rpc::Message&& wrappedRpcMsg,
     MessageType msgType,
-    bool forceGradRecording) {
+    bool forceGradRecording,
+    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
   auto& autogradContainer = DistAutogradContainer::getInstance();

   // If there is no valid context and no tensor requires grads, send original
@@ -125,7 +127,8 @@ Message getMessageWithAutograd(
       RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
       msgType,
       autogradMetadata,
-      std::move(wrappedRpcMsg));
+      std::move(wrappedRpcMsg),
+      deviceMap);

   if (tensorsRequireGrad) {
     // Record autograd information for 'send'.
@@ -149,7 +152,8 @@ std::shared_ptr<JitFuture> sendMessageWithAutograd(
       dst.id_,
       std::move(wrappedRpcMsg),
       MessageType::FORWARD_AUTOGRAD_REQ,
-      forceGradRecording);
+      forceGradRecording,
+      agent.getDeviceMap(dst));

   std::shared_ptr<JitFuture> fut;
   // If profiler is enabled, wrap this message with profiling metadata that will

torch/csrc/distributed/autograd/utils.h

+5 -2

@@ -30,7 +30,8 @@ TORCH_API void addSendRpcBackward(
 TORCH_API ContextPtr addRecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
-    rpc::worker_id_t fromWorkerId);
+    rpc::worker_id_t fromWorkerId,
+    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap);

 // This method is a wrapper utility used internally to wrap autograd info
 // and attach autograd function for each type of rpc call if it has valid
@@ -42,7 +43,9 @@ TORCH_API rpc::Message getMessageWithAutograd(
     const rpc::worker_id_t dstId,
     rpc::Message&& wrappedRpcMsg,
     rpc::MessageType msgType,
-    bool forceGradRecording = false);
+    bool forceGradRecording = false,
+    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
+        {});

 // Send message after autograd checking
 TORCH_API std::shared_ptr<c10::ivalue::Future>

torch/csrc/distributed/rpc/process_group_agent.cpp

+2 -1

@@ -290,7 +290,8 @@ void ProcessGroupAgent::shutdownImpl() {
 std::shared_ptr<JitFuture> ProcessGroupAgent::send(
     const WorkerInfo& to,
     Message&& message,
-    const float rpcTimeoutSeconds) {
+    const float rpcTimeoutSeconds,
+    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
   // Throw if we previously encountered an exception in ::listenLoop.
   {
     std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);

torch/csrc/distributed/rpc/process_group_agent.h

+3 -1

@@ -91,7 +91,9 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
   std::shared_ptr<JitFuture> send(
       const WorkerInfo& to,
       Message&& message,
-      const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
+      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
+      const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
+          {}) override;

   // put SendWork into a queue and notify the worker thread
   virtual void enqueueSend(SendWork work);

torch/csrc/distributed/rpc/request_callback_no_python.cpp

+10 -1

@@ -345,11 +345,20 @@ void RequestCallbackNoPython::processForwardAutogradReq(
     const std::shared_ptr<JitFuture>& responseFuture) const {
   auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

+  // Need to reverse the device map for the backward pass of distributed
+  // autograd.
+  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap;
+  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
+    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
+  }
+
+
   // Attach 'recv' autograd function.
   auto autogradContext = addRecvRpcBackward(
       rpcWithAutograd.autogradMetadata(),
       rpcWithAutograd.tensors(),
-      rpcWithAutograd.fromWorkerId());
+      rpcWithAutograd.fromWorkerId(),
+      reverseDeviceMap);
   // For this recv thread on server side, before processRpc(),
   // set current_context_id_ to be context_id passed from client.
   // In this way, if there is nested rpc call in python rpc call, original

torch/csrc/distributed/rpc/rpc_agent.cpp

+6

@@ -286,6 +286,12 @@ bool RpcAgent::isGILProfilingEnabled() {
   return profilingEnabled_.load();
 }

+std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> RpcAgent::getDeviceMap(
+    const WorkerInfo& dest) {
+  // Default implementation has no device map.
+  return {};
+}
+
 std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
   /* This would later include more info other than metrics for eg: may include
      stack traces for the threads owned by the agent */

torch/csrc/distributed/rpc/rpc_agent.h

+7 -1

@@ -160,7 +160,9 @@ class TORCH_API RpcAgent {
   virtual std::shared_ptr<JitFuture> send(
       const WorkerInfo& to,
       Message&& message,
-      const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0;
+      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
+      const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
+          {}) = 0;

   // Retries sending the message up to maxRetries times until an ACK is
   // receieved. The duration between consecutive sends is increased over
@@ -259,6 +261,10 @@ class TORCH_API RpcAgent {
   // Get the type resolver
   std::shared_ptr<TypeResolver> getTypeResolver();

+  // Retrieves the device map for the provided destination worker.
+  virtual std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> getDeviceMap(
+      const WorkerInfo& dest);
+
  protected:
   const WorkerInfo workerInfo_;
   const std::unique_ptr<RequestCallback> cb_;
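
The new virtual `getDeviceMap` hook above defaults to an empty map in the base `RpcAgent`; an agent that supports device maps (the TensorPipe agent's override is not part of the hunks shown here) can return the per-destination mapping it was configured with. A simplified, self-contained sketch of that override pattern, with hypothetical class names standing in for the real hierarchy:

    // Standalone illustration of the virtual getDeviceMap() hook
    // (hypothetical Agent/MappedAgent classes; not the real RpcAgent).
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    using DeviceIndex = int16_t; // stand-in for c10::DeviceIndex
    using DeviceMap = std::unordered_map<DeviceIndex, DeviceIndex>;

    struct WorkerInfo {
      std::string name;
    };

    struct Agent {
      virtual ~Agent() = default;
      // Default implementation has no device map, matching RpcAgent::getDeviceMap.
      virtual DeviceMap getDeviceMap(const WorkerInfo& /*dest*/) {
        return {};
      }
    };

    // Hypothetical agent that maps local cuda:0 to cuda:1 on every destination.
    struct MappedAgent : Agent {
      DeviceMap getDeviceMap(const WorkerInfo& /*dest*/) override {
        return {{0, 1}};
      }
    };

    int main() {
      MappedAgent agent;
      for (const auto& entry : agent.getDeviceMap(WorkerInfo{"worker1"})) {
        std::cout << "cuda:" << entry.first << " -> cuda:" << entry.second << "\n";
      }
      return 0;
    }

sendMessageWithAutograd then passes agent.getDeviceMap(dst) into getMessageWithAutograd, as the utils.cpp hunk above shows, so the map travels with the forward RPC and can be reversed on the server.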
