@@ -1313,13 +1313,23 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13131313 std::optional<MemoryKind> memory_kind) {
13141314 std::vector<ArrayRef> new_arrays;
13151315 new_arrays.reserve (arrays.size ());
1316- std::vector<PjRtBuffers > recv_buffers;
1316+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>> > recv_buffers;
13171317 recv_buffers.reserve (dst_devices->AddressableDeviceList ()->size ());
1318+ auto on_send_done = [](absl::Status status) {
1319+ if (!status.ok ()) {
1320+ LOG (ERROR) << " xla::PjRtClient::CrossHostSendBuffers failed: " << status;
1321+ }
1322+ };
1323+ auto on_recv_done = [](absl::Status status) {
1324+ if (!status.ok ()) {
1325+ LOG (ERROR) << " Cross-host receive buffer failed: " << status;
1326+ }
1327+ };
13181328 int j = 0 ; // Counter for the addressable buffers.
13191329 for (int i = 0 ; i < dst_devices->size (); ++i) {
13201330 // TODO(emilyaf): Extend CreateNewTransferKey to take N and return N keys
13211331 // as a performance optimization.
1322- std::vector<int64_t > transfer_keys;
1332+ std::vector<CrossHostTransferKey > transfer_keys;
13231333 transfer_keys.reserve (arrays.size ());
13241334 for (int k = 0 ; k < arrays.size (); ++k) {
13251335 transfer_keys.push_back (CreateNewTransferKey ());
@@ -1337,23 +1347,49 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13371347 dst_devices->devices ()[i]->DebugString ()));
13381348 }
13391349
1340- PjRtArray::PjRtBuffers send_buffers;
1350+ // Create vector of dst devices; we send each array to
1351+ // dst_devices->devices()[i].
1352+ TF_ASSIGN_OR_RETURN (
1353+ xla::PjRtGlobalDeviceId dst_global_device_id,
1354+ GetPjRtGlobalDeviceId (dst_devices->devices ()[i]->Id ()));
1355+ std::vector<PjRtGlobalDeviceId> dst_global_device_ids (
1356+ arrays.size (), dst_global_device_id);
1357+
1358+ // Create send buffers.
1359+ std::vector<PjRtBuffer*> send_buffers;
13411360 send_buffers.reserve (arrays.size ());
13421361 for (ArrayRef& array : arrays) {
13431362 if (auto * const pjrt_array = llvm::dyn_cast<PjRtArray>(array.get ())) {
13441363 auto buffers = pjrt_array->pjrt_buffers ();
1345- send_buffers.push_back (buffers[j]);
1364+ send_buffers.push_back (buffers[j]. get () );
13461365 } else {
13471366 // TODO(emilyaf): Support string arrays.
13481367 return absl::InvalidArgumentError (
13491368 " Unsupported array type for cross-host "
13501369 " PjRtClient::CopyArraysForCrossHost" );
13511370 }
13521371 }
1353- TF_RETURN_IF_ERROR (
1354- CrossHostSendBuffers (send_buffers, std::move (transfer_keys)));
1372+
1373+ // If the PJRT plugin implements the `CrossHostSendBuffers` API, use it.
1374+ // Otherwise, call this class's `CrossHostSendBuffers` method to use the
1375+ // plugin's `CopyToRemoteDevice` API, getting the buffer descriptors from
1376+ // the KV store.
1377+ absl::StatusOr<std::vector<Future<>>> send_futures =
1378+ pjrt_client_->CrossHostSendBuffers (
1379+ send_buffers, std::move (dst_global_device_ids), transfer_keys);
1380+ if (send_futures.ok ()) {
1381+ for (Future<>& send_future : *send_futures) {
1382+ send_future.OnReady (on_send_done);
1383+ }
1384+ } else if (absl::IsUnimplemented (send_futures.status ())) {
1385+ TF_RETURN_IF_ERROR (
1386+ CrossHostSendBuffers (send_buffers, std::move (transfer_keys)));
1387+ } else {
1388+ return send_futures.status ();
1389+ }
13551390 ++j;
13561391 } else if (dst_devices->devices ()[i]->IsAddressable ()) {
1392+ // Create vector of shapes to receive.
13571393 std::vector<xla::Shape> recv_shapes;
13581394 recv_shapes.reserve (arrays.size ());
13591395 for (const ArrayRef& array : arrays) {
@@ -1371,14 +1407,42 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
13711407 " PjRtClient::CopyArraysForCrossHost" );
13721408 }
13731409 }
1410+
1411+ // Get the dst device we receive into.
13741412 TF_ASSIGN_OR_RETURN (
13751413 xla::PjRtGlobalDeviceId pjrt_global_device_id,
13761414 GetPjRtGlobalDeviceId (dst_devices->devices ()[i]->Id ()));
13771415 TF_ASSIGN_OR_RETURN (xla::PjRtDevice * pjrt_device,
13781416 pjrt_client_->LookupDevice (pjrt_global_device_id));
1379- TF_ASSIGN_OR_RETURN (recv_buffers.emplace_back (),
1380- CrossHostReceiveBuffers (recv_shapes, pjrt_device,
1381- std::move (transfer_keys)));
1417+
1418+ // Create vector of src devices; we receive each array from
1419+ // src_devices->devices()[i].
1420+ TF_ASSIGN_OR_RETURN (
1421+ xla::PjRtGlobalDeviceId src_global_device_id,
1422+ GetPjRtGlobalDeviceId (src_devices->devices ()[i]->Id ()));
1423+ std::vector<PjRtGlobalDeviceId> src_global_device_ids (
1424+ arrays.size (), src_global_device_id);
1425+
1426+ // If the PJRT plugin implements the `CrossHostReceiveBuffers` API, use
1427+ // it. Otherwise, call this class's `CrossHostReceiveBuffers` method to
1428+ // use the plugin's `MakeCrossHostReceiveBuffers` API, transmitting the
1429+ // buffer descriptors via the KV store.
1430+ absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1431+ received_buffers = pjrt_client_->CrossHostReceiveBuffers (
1432+ pjrt_device, recv_shapes, std::move (src_global_device_ids),
1433+ transfer_keys);
1434+ if (received_buffers.ok ()) {
1435+ recv_buffers.push_back (std::move (*received_buffers));
1436+ } else if (absl::IsUnimplemented (received_buffers.status ())) {
1437+ TF_ASSIGN_OR_RETURN (recv_buffers.emplace_back (),
1438+ CrossHostReceiveBuffers (recv_shapes, pjrt_device,
1439+ std::move (transfer_keys)));
1440+ } else {
1441+ return received_buffers.status ();
1442+ }
1443+ for (auto & buffer : recv_buffers.back ()) {
1444+ buffer->GetReadyFuture ().OnReady (on_recv_done);
1445+ }
13821446 }
13831447 }
13841448
@@ -1428,7 +1492,9 @@ PjRtClient::CopyArraysForCrossHostFallback(
14281492 memory_kind);
14291493}
14301494
1431- int64_t PjRtClient::CreateNewTransferKey () { return next_transfer_key_++; }
1495+ CrossHostTransferKey PjRtClient::CreateNewTransferKey () {
1496+ return CrossHostTransferKey (next_transfer_key_++);
1497+ }
14321498
14331499absl::Status PjRtClient::WatchGlobalProcessInfo (
14341500 tsl::CoordinationServiceAgent& agent) {
@@ -1507,7 +1573,8 @@ absl::Status PjRtClient::WatchGlobalProcessInfo(
15071573}
15081574
15091575absl::Status PjRtClient::CrossHostSendBuffers (
1510- PjRtBuffers buffers, const std::vector<int64_t >& keys) {
1576+ std::vector<PjRtBuffer*> buffers,
1577+ const std::vector<CrossHostTransferKey>& keys) {
15111578 if (keys.size () != buffers.size ()) {
15121579 return absl::InternalError (
15131580 " CrossHostSendBuffers: keys must be the same size as buffers." );
@@ -1518,7 +1585,7 @@ absl::Status PjRtClient::CrossHostSendBuffers(
15181585 auto [promise, descriptor_future] = tsl::Future<std::string>::MakePromise ();
15191586 work_queue_->Schedule (
15201587 [this , k = keys[i], promise = std::move (promise).ToShared ()]() mutable {
1521- std::string key = absl::StrCat (kKeyPrefix , k);
1588+ std::string key = absl::StrCat (kKeyPrefix , k. value () );
15221589 absl::StatusOr<std::string> descriptor =
15231590 kv_store_->Get (key, cross_host_transfer_timeout_);
15241591 if (!descriptor.ok ()) {
@@ -1528,16 +1595,24 @@ absl::Status PjRtClient::CrossHostSendBuffers(
15281595 promise->Set (std::move (*descriptor));
15291596 });
15301597 auto on_done = [](absl::Status status, bool sends_were_enqueued) {
1531- CHECK_OK (status);
1598+ if (!status.ok ()) {
1599+ LOG (ERROR) << " `xla::PjRtBuffer::CopyToRemoteDevice` failed: "
1600+ << status;
1601+ }
1602+ if (!sends_were_enqueued) {
1603+ LOG (ERROR) << " `xla::PjRtBuffer::CopyToRemoteDevice` did not enqueue "
1604+ " sends." ;
1605+ }
15321606 };
15331607 buffers[i]->CopyToRemoteDevice (std::move (descriptor_future), on_done);
15341608 }
15351609 return absl::OkStatus ();
15361610}
15371611
1538- absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers (
1539- absl::Span<const xla::Shape> shapes, xla::PjRtDevice* device,
1540- std::vector<int64_t > keys) {
1612+ absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1613+ PjRtClient::CrossHostReceiveBuffers (absl::Span<const xla::Shape> shapes,
1614+ xla::PjRtDevice* device,
1615+ std::vector<CrossHostTransferKey> keys) {
15411616 auto notifier = [this , keys = std::move (keys)](
15421617 absl::StatusOr<xla::PjRtCrossHostRecvState> recv_state) {
15431618 if (!recv_state.ok ()) {
@@ -1567,7 +1642,7 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
15671642 return ;
15681643 }
15691644 for (int i = 0 , n = keys.size (); i < n; ++i) {
1570- std::string key = absl::StrCat (kKeyPrefix , keys[i]);
1645+ std::string key = absl::StrCat (kKeyPrefix , keys[i]. value () );
15711646 absl::Status kv_status = kv_store_->Set (
15721647 key, recv_state->descriptors [i].serialized_descriptors .front ());
15731648 if (!kv_status.ok ()) {
@@ -1581,15 +1656,8 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
15811656 }
15821657 }
15831658 };
1584- TF_ASSIGN_OR_RETURN (auto recv_buffers,
1585- pjrt_client_->MakeCrossHostReceiveBuffers (
1586- shapes, device, std::move (notifier)));
1587- PjRtArray::PjRtBuffers buffers;
1588- buffers.reserve (recv_buffers.size ());
1589- for (auto & recv_buffer : recv_buffers) {
1590- buffers.push_back (std::move (recv_buffer));
1591- }
1592- return buffers;
1659+ return pjrt_client_->MakeCrossHostReceiveBuffers (shapes, device,
1660+ std::move (notifier));
15931661}
15941662
15951663absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> PjRtClient::RemapArrays (
0 commit comments