From fa2036af8241f792f0918b1a18bab4fe2d71aa4c Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 5 Dec 2025 14:07:27 -0800 Subject: [PATCH] [PjRt-IFRT] For cross-host transfers, call the new CrossHost{Send,Receive}Buffers API if the plugin implements it. This enables efficient cross-host transfers on GPU. PiperOrigin-RevId: 840867337 --- xla/pjrt/pjrt_client.h | 2 +- xla/python/pjrt_ifrt/pjrt_client.cc | 120 ++++++++++++++++++++++------ xla/python/pjrt_ifrt/pjrt_client.h | 20 +++-- 3 files changed, 107 insertions(+), 35 deletions(-) diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index f35a13e9aafe4..827415e8e0515 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -1032,7 +1032,7 @@ class PjRtClient { absl::Span buffers, absl::Span dst_global_device_ids, std::vector transfer_keys) { - return absl::InternalError( + return absl::UnimplementedError( "Cross-host data transfers are not supported by this client."); } diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index d94d21da799b8..c2d4458ba0663 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -1313,13 +1313,23 @@ PjRtClient::CopyArraysForCrossHost(absl::Span arrays, std::optional memory_kind) { std::vector new_arrays; new_arrays.reserve(arrays.size()); - std::vector recv_buffers; + std::vector>> recv_buffers; recv_buffers.reserve(dst_devices->AddressableDeviceList()->size()); + auto on_send_done = [](absl::Status status) { + if (!status.ok()) { + LOG(ERROR) << "xla::PjRtClient::CrossHostSendBuffers failed: " << status; + } + }; + auto on_recv_done = [](absl::Status status) { + if (!status.ok()) { + LOG(ERROR) << "Cross-host receive buffer failed: " << status; + } + }; int j = 0; // Counter for the addressable buffers. for (int i = 0; i < dst_devices->size(); ++i) { // TODO(emilyaf): Extend CreateNewTransferKey to take N and return N keys // as a performance optimization. - std::vector transfer_keys; + std::vector transfer_keys; transfer_keys.reserve(arrays.size()); for (int k = 0; k < arrays.size(); ++k) { transfer_keys.push_back(CreateNewTransferKey()); @@ -1337,12 +1347,21 @@ PjRtClient::CopyArraysForCrossHost(absl::Span arrays, dst_devices->devices()[i]->DebugString())); } - PjRtArray::PjRtBuffers send_buffers; + // Create vector of dst devices; we send each array to + // dst_devices->devices()[i]. + TF_ASSIGN_OR_RETURN( + xla::PjRtGlobalDeviceId dst_global_device_id, + GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id())); + std::vector dst_global_device_ids( + arrays.size(), dst_global_device_id); + + // Create send buffers. + std::vector send_buffers; send_buffers.reserve(arrays.size()); for (ArrayRef& array : arrays) { if (auto* const pjrt_array = llvm::dyn_cast(array.get())) { auto buffers = pjrt_array->pjrt_buffers(); - send_buffers.push_back(buffers[j]); + send_buffers.push_back(buffers[j].get()); } else { // TODO(emilyaf): Support string arrays. return absl::InvalidArgumentError( @@ -1350,10 +1369,27 @@ PjRtClient::CopyArraysForCrossHost(absl::Span arrays, "PjRtClient::CopyArraysForCrossHost"); } } - TF_RETURN_IF_ERROR( - CrossHostSendBuffers(send_buffers, std::move(transfer_keys))); + + // If the PJRT plugin implements the `CrossHostSendBuffers` API, use it. + // Otherwise, call this class's `CrossHostSendBuffers` method to use the + // plugin's `CopyToRemoteDevice` API, getting the buffer descriptors from + // the KV store. + absl::StatusOr>> send_futures = + pjrt_client_->CrossHostSendBuffers( + send_buffers, std::move(dst_global_device_ids), transfer_keys); + if (send_futures.ok()) { + for (Future<>& send_future : *send_futures) { + send_future.OnReady(on_send_done); + } + } else if (absl::IsUnimplemented(send_futures.status())) { + TF_RETURN_IF_ERROR( + CrossHostSendBuffers(send_buffers, std::move(transfer_keys))); + } else { + return send_futures.status(); + } ++j; } else if (dst_devices->devices()[i]->IsAddressable()) { + // Create vector of shapes to receive. std::vector recv_shapes; recv_shapes.reserve(arrays.size()); for (const ArrayRef& array : arrays) { @@ -1371,14 +1407,42 @@ PjRtClient::CopyArraysForCrossHost(absl::Span arrays, "PjRtClient::CopyArraysForCrossHost"); } } + + // Get the dst device we receive into. TF_ASSIGN_OR_RETURN( xla::PjRtGlobalDeviceId pjrt_global_device_id, GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id())); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, pjrt_client_->LookupDevice(pjrt_global_device_id)); - TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(), - CrossHostReceiveBuffers(recv_shapes, pjrt_device, - std::move(transfer_keys))); + + // Create vector of src devices; we receive each array from + // src_devices->devices()[i]. + TF_ASSIGN_OR_RETURN( + xla::PjRtGlobalDeviceId src_global_device_id, + GetPjRtGlobalDeviceId(src_devices->devices()[i]->Id())); + std::vector src_global_device_ids( + arrays.size(), src_global_device_id); + + // If the PJRT plugin implements the `CrossHostReceiveBuffers` API, use + // it. Otherwise, call this class's `CrossHostReceiveBuffers` method to + // use the plugin's `MakeCrossHostReceiveBuffers` API, transmitting the + // buffer descriptors via the KV store. + absl::StatusOr>> + received_buffers = pjrt_client_->CrossHostReceiveBuffers( + pjrt_device, recv_shapes, std::move(src_global_device_ids), + transfer_keys); + if (received_buffers.ok()) { + recv_buffers.push_back(std::move(*received_buffers)); + } else if (absl::IsUnimplemented(received_buffers.status())) { + TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(), + CrossHostReceiveBuffers(recv_shapes, pjrt_device, + std::move(transfer_keys))); + } else { + return received_buffers.status(); + } + for (auto& buffer : recv_buffers.back()) { + buffer->GetReadyFuture().OnReady(on_recv_done); + } } } @@ -1428,7 +1492,9 @@ PjRtClient::CopyArraysForCrossHostFallback( memory_kind); } -int64_t PjRtClient::CreateNewTransferKey() { return next_transfer_key_++; } +CrossHostTransferKey PjRtClient::CreateNewTransferKey() { + return CrossHostTransferKey(next_transfer_key_++); +} absl::Status PjRtClient::WatchGlobalProcessInfo( tsl::CoordinationServiceAgent& agent) { @@ -1507,7 +1573,8 @@ absl::Status PjRtClient::WatchGlobalProcessInfo( } absl::Status PjRtClient::CrossHostSendBuffers( - PjRtBuffers buffers, const std::vector& keys) { + std::vector buffers, + const std::vector& keys) { if (keys.size() != buffers.size()) { return absl::InternalError( "CrossHostSendBuffers: keys must be the same size as buffers."); @@ -1518,7 +1585,7 @@ absl::Status PjRtClient::CrossHostSendBuffers( auto [promise, descriptor_future] = tsl::Future::MakePromise(); work_queue_->Schedule( [this, k = keys[i], promise = std::move(promise).ToShared()]() mutable { - std::string key = absl::StrCat(kKeyPrefix, k); + std::string key = absl::StrCat(kKeyPrefix, k.value()); absl::StatusOr descriptor = kv_store_->Get(key, cross_host_transfer_timeout_); if (!descriptor.ok()) { @@ -1528,16 +1595,24 @@ absl::Status PjRtClient::CrossHostSendBuffers( promise->Set(std::move(*descriptor)); }); auto on_done = [](absl::Status status, bool sends_were_enqueued) { - CHECK_OK(status); + if (!status.ok()) { + LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` failed: " + << status; + } + if (!sends_were_enqueued) { + LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` did not enqueue " + "sends."; + } }; buffers[i]->CopyToRemoteDevice(std::move(descriptor_future), on_done); } return absl::OkStatus(); } -absl::StatusOr PjRtClient::CrossHostReceiveBuffers( - absl::Span shapes, xla::PjRtDevice* device, - std::vector keys) { +absl::StatusOr>> +PjRtClient::CrossHostReceiveBuffers(absl::Span shapes, + xla::PjRtDevice* device, + std::vector keys) { auto notifier = [this, keys = std::move(keys)]( absl::StatusOr recv_state) { if (!recv_state.ok()) { @@ -1567,7 +1642,7 @@ absl::StatusOr PjRtClient::CrossHostReceiveBuffers( return; } for (int i = 0, n = keys.size(); i < n; ++i) { - std::string key = absl::StrCat(kKeyPrefix, keys[i]); + std::string key = absl::StrCat(kKeyPrefix, keys[i].value()); absl::Status kv_status = kv_store_->Set( key, recv_state->descriptors[i].serialized_descriptors.front()); if (!kv_status.ok()) { @@ -1581,15 +1656,8 @@ absl::StatusOr PjRtClient::CrossHostReceiveBuffers( } } }; - TF_ASSIGN_OR_RETURN(auto recv_buffers, - pjrt_client_->MakeCrossHostReceiveBuffers( - shapes, device, std::move(notifier))); - PjRtArray::PjRtBuffers buffers; - buffers.reserve(recv_buffers.size()); - for (auto& recv_buffer : recv_buffers) { - buffers.push_back(std::move(recv_buffer)); - } - return buffers; + return pjrt_client_->MakeCrossHostReceiveBuffers(shapes, device, + std::move(notifier)); } absl::StatusOr> PjRtClient::RemapArrays( diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index d6a653e053609..040775d7a98dd 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -388,15 +388,19 @@ class PjRtClient final DeviceListRef dst_devices, std::optional memory_kind); // Extracts receive descriptors from a key-value store and sends buffers to a - // remote device. - absl::Status CrossHostSendBuffers(PjRtBuffers buffers, - const std::vector& keys); + // remote device. This is used when the backend does not implement the + // CrossHostSendBuffers API. + absl::Status CrossHostSendBuffers( + std::vector buffers, + const std::vector& keys); // Populates a key-value store with receive descriptors and places buffers - // from a cross-host send onto device. - absl::StatusOr CrossHostReceiveBuffers( - absl::Span shapes, xla::PjRtDevice* device, - std::vector keys); + // from a cross-host send onto device. This is used when the backend does not + // implement the CrossHostReceiveBuffers API. + absl::StatusOr>> + CrossHostReceiveBuffers(absl::Span shapes, + xla::PjRtDevice* device, + std::vector keys); // Copies arrays from source to destination devices when at least one of the // (source, destination) pairs is cross-host using an experimental DCN @@ -409,7 +413,7 @@ class PjRtClient final // Creates a unique identifier for each cross-host transfer. Every process // must call it, regardless of whether it participates in the cross-host // transfer, so that the returned value must be the same in all processes. - int64_t CreateNewTransferKey(); + CrossHostTransferKey CreateNewTransferKey(); // If true, the backend implements the cross-host transfer APIs. bool pjrt_supports_cross_host_transfers_ = false;