Skip to content

Commit fa2036a

Browse files
emilyfertigGoogle-ML-Automation
authored andcommitted
[PjRt-IFRT] For cross-host transfers, call the new CrossHost{Send,Receive}Buffers API if the plugin implements it.
This enables efficient cross-host transfers on GPU. PiperOrigin-RevId: 840867337
1 parent 49ab0c1 commit fa2036a

File tree

3 files changed

+107
-35
lines changed

3 files changed

+107
-35
lines changed

xla/pjrt/pjrt_client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ class PjRtClient {
10321032
absl::Span<PjRtBuffer* const> buffers,
10331033
absl::Span<const PjRtGlobalDeviceId> dst_global_device_ids,
10341034
std::vector<CrossHostTransferKey> transfer_keys) {
1035-
return absl::InternalError(
1035+
return absl::UnimplementedError(
10361036
"Cross-host data transfers are not supported by this client.");
10371037
}
10381038

xla/python/pjrt_ifrt/pjrt_client.cc

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,13 +1313,23 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13131313
std::optional<MemoryKind> memory_kind) {
13141314
std::vector<ArrayRef> new_arrays;
13151315
new_arrays.reserve(arrays.size());
1316-
std::vector<PjRtBuffers> recv_buffers;
1316+
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> recv_buffers;
13171317
recv_buffers.reserve(dst_devices->AddressableDeviceList()->size());
1318+
auto on_send_done = [](absl::Status status) {
1319+
if (!status.ok()) {
1320+
LOG(ERROR) << "xla::PjRtClient::CrossHostSendBuffers failed: " << status;
1321+
}
1322+
};
1323+
auto on_recv_done = [](absl::Status status) {
1324+
if (!status.ok()) {
1325+
LOG(ERROR) << "Cross-host receive buffer failed: " << status;
1326+
}
1327+
};
13181328
int j = 0; // Counter for the addressable buffers.
13191329
for (int i = 0; i < dst_devices->size(); ++i) {
13201330
// TODO(emilyaf): Extend CreateNewTransferKey to take N and return N keys
13211331
// as a performance optimization.
1322-
std::vector<int64_t> transfer_keys;
1332+
std::vector<CrossHostTransferKey> transfer_keys;
13231333
transfer_keys.reserve(arrays.size());
13241334
for (int k = 0; k < arrays.size(); ++k) {
13251335
transfer_keys.push_back(CreateNewTransferKey());
@@ -1337,23 +1347,49 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13371347
dst_devices->devices()[i]->DebugString()));
13381348
}
13391349

1340-
PjRtArray::PjRtBuffers send_buffers;
1350+
// Create vector of dst devices; we send each array to
1351+
// dst_devices->devices()[i].
1352+
TF_ASSIGN_OR_RETURN(
1353+
xla::PjRtGlobalDeviceId dst_global_device_id,
1354+
GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id()));
1355+
std::vector<PjRtGlobalDeviceId> dst_global_device_ids(
1356+
arrays.size(), dst_global_device_id);
1357+
1358+
// Create send buffers.
1359+
std::vector<PjRtBuffer*> send_buffers;
13411360
send_buffers.reserve(arrays.size());
13421361
for (ArrayRef& array : arrays) {
13431362
if (auto* const pjrt_array = llvm::dyn_cast<PjRtArray>(array.get())) {
13441363
auto buffers = pjrt_array->pjrt_buffers();
1345-
send_buffers.push_back(buffers[j]);
1364+
send_buffers.push_back(buffers[j].get());
13461365
} else {
13471366
// TODO(emilyaf): Support string arrays.
13481367
return absl::InvalidArgumentError(
13491368
"Unsupported array type for cross-host "
13501369
"PjRtClient::CopyArraysForCrossHost");
13511370
}
13521371
}
1353-
TF_RETURN_IF_ERROR(
1354-
CrossHostSendBuffers(send_buffers, std::move(transfer_keys)));
1372+
1373+
// If the PJRT plugin implements the `CrossHostSendBuffers` API, use it.
1374+
// Otherwise, call this class's `CrossHostSendBuffers` method to use the
1375+
// plugin's `CopyToRemoteDevice` API, getting the buffer descriptors from
1376+
// the KV store.
1377+
absl::StatusOr<std::vector<Future<>>> send_futures =
1378+
pjrt_client_->CrossHostSendBuffers(
1379+
send_buffers, std::move(dst_global_device_ids), transfer_keys);
1380+
if (send_futures.ok()) {
1381+
for (Future<>& send_future : *send_futures) {
1382+
send_future.OnReady(on_send_done);
1383+
}
1384+
} else if (absl::IsUnimplemented(send_futures.status())) {
1385+
TF_RETURN_IF_ERROR(
1386+
CrossHostSendBuffers(send_buffers, std::move(transfer_keys)));
1387+
} else {
1388+
return send_futures.status();
1389+
}
13551390
++j;
13561391
} else if (dst_devices->devices()[i]->IsAddressable()) {
1392+
// Create vector of shapes to receive.
13571393
std::vector<xla::Shape> recv_shapes;
13581394
recv_shapes.reserve(arrays.size());
13591395
for (const ArrayRef& array : arrays) {
@@ -1371,14 +1407,42 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13711407
"PjRtClient::CopyArraysForCrossHost");
13721408
}
13731409
}
1410+
1411+
// Get the dst device we receive into.
13741412
TF_ASSIGN_OR_RETURN(
13751413
xla::PjRtGlobalDeviceId pjrt_global_device_id,
13761414
GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id()));
13771415
TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device,
13781416
pjrt_client_->LookupDevice(pjrt_global_device_id));
1379-
TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(),
1380-
CrossHostReceiveBuffers(recv_shapes, pjrt_device,
1381-
std::move(transfer_keys)));
1417+
1418+
// Create vector of src devices; we receive each array from
1419+
// src_devices->devices()[i].
1420+
TF_ASSIGN_OR_RETURN(
1421+
xla::PjRtGlobalDeviceId src_global_device_id,
1422+
GetPjRtGlobalDeviceId(src_devices->devices()[i]->Id()));
1423+
std::vector<PjRtGlobalDeviceId> src_global_device_ids(
1424+
arrays.size(), src_global_device_id);
1425+
1426+
// If the PJRT plugin implements the `CrossHostReceiveBuffers` API, use
1427+
// it. Otherwise, call this class's `CrossHostReceiveBuffers` method to
1428+
// use the plugin's `MakeCrossHostReceiveBuffers` API, transmitting the
1429+
// buffer descriptors via the KV store.
1430+
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1431+
received_buffers = pjrt_client_->CrossHostReceiveBuffers(
1432+
pjrt_device, recv_shapes, std::move(src_global_device_ids),
1433+
transfer_keys);
1434+
if (received_buffers.ok()) {
1435+
recv_buffers.push_back(std::move(*received_buffers));
1436+
} else if (absl::IsUnimplemented(received_buffers.status())) {
1437+
TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(),
1438+
CrossHostReceiveBuffers(recv_shapes, pjrt_device,
1439+
std::move(transfer_keys)));
1440+
} else {
1441+
return received_buffers.status();
1442+
}
1443+
for (auto& buffer : recv_buffers.back()) {
1444+
buffer->GetReadyFuture().OnReady(on_recv_done);
1445+
}
13821446
}
13831447
}
13841448

@@ -1428,7 +1492,9 @@ PjRtClient::CopyArraysForCrossHostFallback(
14281492
memory_kind);
14291493
}
14301494

1431-
int64_t PjRtClient::CreateNewTransferKey() { return next_transfer_key_++; }
1495+
CrossHostTransferKey PjRtClient::CreateNewTransferKey() {
1496+
return CrossHostTransferKey(next_transfer_key_++);
1497+
}
14321498

14331499
absl::Status PjRtClient::WatchGlobalProcessInfo(
14341500
tsl::CoordinationServiceAgent& agent) {
@@ -1507,7 +1573,8 @@ absl::Status PjRtClient::WatchGlobalProcessInfo(
15071573
}
15081574

15091575
absl::Status PjRtClient::CrossHostSendBuffers(
1510-
PjRtBuffers buffers, const std::vector<int64_t>& keys) {
1576+
std::vector<PjRtBuffer*> buffers,
1577+
const std::vector<CrossHostTransferKey>& keys) {
15111578
if (keys.size() != buffers.size()) {
15121579
return absl::InternalError(
15131580
"CrossHostSendBuffers: keys must be the same size as buffers.");
@@ -1518,7 +1585,7 @@ absl::Status PjRtClient::CrossHostSendBuffers(
15181585
auto [promise, descriptor_future] = tsl::Future<std::string>::MakePromise();
15191586
work_queue_->Schedule(
15201587
[this, k = keys[i], promise = std::move(promise).ToShared()]() mutable {
1521-
std::string key = absl::StrCat(kKeyPrefix, k);
1588+
std::string key = absl::StrCat(kKeyPrefix, k.value());
15221589
absl::StatusOr<std::string> descriptor =
15231590
kv_store_->Get(key, cross_host_transfer_timeout_);
15241591
if (!descriptor.ok()) {
@@ -1528,16 +1595,24 @@ absl::Status PjRtClient::CrossHostSendBuffers(
15281595
promise->Set(std::move(*descriptor));
15291596
});
15301597
auto on_done = [](absl::Status status, bool sends_were_enqueued) {
1531-
CHECK_OK(status);
1598+
if (!status.ok()) {
1599+
LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` failed: "
1600+
<< status;
1601+
}
1602+
if (!sends_were_enqueued) {
1603+
LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` did not enqueue "
1604+
"sends.";
1605+
}
15321606
};
15331607
buffers[i]->CopyToRemoteDevice(std::move(descriptor_future), on_done);
15341608
}
15351609
return absl::OkStatus();
15361610
}
15371611

1538-
absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
1539-
absl::Span<const xla::Shape> shapes, xla::PjRtDevice* device,
1540-
std::vector<int64_t> keys) {
1612+
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1613+
PjRtClient::CrossHostReceiveBuffers(absl::Span<const xla::Shape> shapes,
1614+
xla::PjRtDevice* device,
1615+
std::vector<CrossHostTransferKey> keys) {
15411616
auto notifier = [this, keys = std::move(keys)](
15421617
absl::StatusOr<xla::PjRtCrossHostRecvState> recv_state) {
15431618
if (!recv_state.ok()) {
@@ -1567,7 +1642,7 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
15671642
return;
15681643
}
15691644
for (int i = 0, n = keys.size(); i < n; ++i) {
1570-
std::string key = absl::StrCat(kKeyPrefix, keys[i]);
1645+
std::string key = absl::StrCat(kKeyPrefix, keys[i].value());
15711646
absl::Status kv_status = kv_store_->Set(
15721647
key, recv_state->descriptors[i].serialized_descriptors.front());
15731648
if (!kv_status.ok()) {
@@ -1581,15 +1656,8 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
15811656
}
15821657
}
15831658
};
1584-
TF_ASSIGN_OR_RETURN(auto recv_buffers,
1585-
pjrt_client_->MakeCrossHostReceiveBuffers(
1586-
shapes, device, std::move(notifier)));
1587-
PjRtArray::PjRtBuffers buffers;
1588-
buffers.reserve(recv_buffers.size());
1589-
for (auto& recv_buffer : recv_buffers) {
1590-
buffers.push_back(std::move(recv_buffer));
1591-
}
1592-
return buffers;
1659+
return pjrt_client_->MakeCrossHostReceiveBuffers(shapes, device,
1660+
std::move(notifier));
15931661
}
15941662

15951663
absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> PjRtClient::RemapArrays(

xla/python/pjrt_ifrt/pjrt_client.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,15 +388,19 @@ class PjRtClient final
388388
DeviceListRef dst_devices, std::optional<MemoryKind> memory_kind);
389389

390390
// Extracts receive descriptors from a key-value store and sends buffers to a
391-
// remote device.
392-
absl::Status CrossHostSendBuffers(PjRtBuffers buffers,
393-
const std::vector<int64_t>& keys);
391+
// remote device. This is used when the backend does not implement the
392+
// CrossHostSendBuffers API.
393+
absl::Status CrossHostSendBuffers(
394+
std::vector<PjRtBuffer*> buffers,
395+
const std::vector<CrossHostTransferKey>& keys);
394396

395397
// Populates a key-value store with receive descriptors and places buffers
396-
// from a cross-host send onto device.
397-
absl::StatusOr<PjRtBuffers> CrossHostReceiveBuffers(
398-
absl::Span<const xla::Shape> shapes, xla::PjRtDevice* device,
399-
std::vector<int64_t> keys);
398+
// from a cross-host send onto device. This is used when the backend does not
399+
// implement the CrossHostReceiveBuffers API.
400+
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
401+
CrossHostReceiveBuffers(absl::Span<const xla::Shape> shapes,
402+
xla::PjRtDevice* device,
403+
std::vector<CrossHostTransferKey> keys);
400404

401405
// Copies arrays from source to destination devices when at least one of the
402406
// (source, destination) pairs is cross-host using an experimental DCN
@@ -409,7 +413,7 @@ class PjRtClient final
409413
// Creates a unique identifier for each cross-host transfer. Every process
410414
// must call it, regardless of whether it participates in the cross-host
411415
// transfer, so that the returned value must be the same in all processes.
412-
int64_t CreateNewTransferKey();
416+
CrossHostTransferKey CreateNewTransferKey();
413417

414418
// If true, the backend implements the cross-host transfer APIs.
415419
bool pjrt_supports_cross_host_transfers_ = false;

0 commit comments

Comments
 (0)