Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ class PjRtClient {
absl::Span<PjRtBuffer* const> buffers,
absl::Span<const PjRtGlobalDeviceId> dst_global_device_ids,
std::vector<CrossHostTransferKey> transfer_keys) {
return absl::InternalError(
return absl::UnimplementedError(
"Cross-host data transfers are not supported by this client.");
}

Expand Down
120 changes: 94 additions & 26 deletions xla/python/pjrt_ifrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1313,13 +1313,23 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
std::optional<MemoryKind> memory_kind) {
std::vector<ArrayRef> new_arrays;
new_arrays.reserve(arrays.size());
std::vector<PjRtBuffers> recv_buffers;
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> recv_buffers;
recv_buffers.reserve(dst_devices->AddressableDeviceList()->size());
auto on_send_done = [](absl::Status status) {
if (!status.ok()) {
LOG(ERROR) << "xla::PjRtClient::CrossHostSendBuffers failed: " << status;
}
};
auto on_recv_done = [](absl::Status status) {
if (!status.ok()) {
LOG(ERROR) << "Cross-host receive buffer failed: " << status;
}
};
int j = 0; // Counter for the addressable buffers.
for (int i = 0; i < dst_devices->size(); ++i) {
// TODO(emilyaf): Extend CreateNewTransferKey to take N and return N keys
// as a performance optimization.
std::vector<int64_t> transfer_keys;
std::vector<CrossHostTransferKey> transfer_keys;
transfer_keys.reserve(arrays.size());
for (int k = 0; k < arrays.size(); ++k) {
transfer_keys.push_back(CreateNewTransferKey());
Expand All @@ -1337,23 +1347,49 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
dst_devices->devices()[i]->DebugString()));
}

PjRtArray::PjRtBuffers send_buffers;
// Create vector of dst devices; we send each array to
// dst_devices->devices()[i].
TF_ASSIGN_OR_RETURN(
xla::PjRtGlobalDeviceId dst_global_device_id,
GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id()));
std::vector<PjRtGlobalDeviceId> dst_global_device_ids(
arrays.size(), dst_global_device_id);

// Create send buffers.
std::vector<PjRtBuffer*> send_buffers;
send_buffers.reserve(arrays.size());
for (ArrayRef& array : arrays) {
if (auto* const pjrt_array = llvm::dyn_cast<PjRtArray>(array.get())) {
auto buffers = pjrt_array->pjrt_buffers();
send_buffers.push_back(buffers[j]);
send_buffers.push_back(buffers[j].get());
} else {
// TODO(emilyaf): Support string arrays.
return absl::InvalidArgumentError(
"Unsupported array type for cross-host "
"PjRtClient::CopyArraysForCrossHost");
}
}
TF_RETURN_IF_ERROR(
CrossHostSendBuffers(send_buffers, std::move(transfer_keys)));

// If the PJRT plugin implements the `CrossHostSendBuffers` API, use it.
// Otherwise, call this class's `CrossHostSendBuffers` method to use the
// plugin's `CopyToRemoteDevice` API, getting the buffer descriptors from
// the KV store.
absl::StatusOr<std::vector<Future<>>> send_futures =
pjrt_client_->CrossHostSendBuffers(
send_buffers, std::move(dst_global_device_ids), transfer_keys);
if (send_futures.ok()) {
for (Future<>& send_future : *send_futures) {
send_future.OnReady(on_send_done);
}
} else if (absl::IsUnimplemented(send_futures.status())) {
TF_RETURN_IF_ERROR(
CrossHostSendBuffers(send_buffers, std::move(transfer_keys)));
} else {
return send_futures.status();
}
++j;
} else if (dst_devices->devices()[i]->IsAddressable()) {
// Create vector of shapes to receive.
std::vector<xla::Shape> recv_shapes;
recv_shapes.reserve(arrays.size());
for (const ArrayRef& array : arrays) {
Expand All @@ -1371,14 +1407,42 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
"PjRtClient::CopyArraysForCrossHost");
}
}

// Get the dst device we receive into.
TF_ASSIGN_OR_RETURN(
xla::PjRtGlobalDeviceId pjrt_global_device_id,
GetPjRtGlobalDeviceId(dst_devices->devices()[i]->Id()));
TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device,
pjrt_client_->LookupDevice(pjrt_global_device_id));
TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(),
CrossHostReceiveBuffers(recv_shapes, pjrt_device,
std::move(transfer_keys)));

// Create vector of src devices; we receive each array from
// src_devices->devices()[i].
TF_ASSIGN_OR_RETURN(
xla::PjRtGlobalDeviceId src_global_device_id,
GetPjRtGlobalDeviceId(src_devices->devices()[i]->Id()));
std::vector<PjRtGlobalDeviceId> src_global_device_ids(
arrays.size(), src_global_device_id);

// If the PJRT plugin implements the `CrossHostReceiveBuffers` API, use
// it. Otherwise, call this class's `CrossHostReceiveBuffers` method to
// use the plugin's `MakeCrossHostReceiveBuffers` API, transmitting the
// buffer descriptors via the KV store.
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
received_buffers = pjrt_client_->CrossHostReceiveBuffers(
pjrt_device, recv_shapes, std::move(src_global_device_ids),
transfer_keys);
if (received_buffers.ok()) {
recv_buffers.push_back(std::move(*received_buffers));
} else if (absl::IsUnimplemented(received_buffers.status())) {
TF_ASSIGN_OR_RETURN(recv_buffers.emplace_back(),
CrossHostReceiveBuffers(recv_shapes, pjrt_device,
std::move(transfer_keys)));
} else {
return received_buffers.status();
}
for (auto& buffer : recv_buffers.back()) {
buffer->GetReadyFuture().OnReady(on_recv_done);
}
}
}

Expand Down Expand Up @@ -1428,7 +1492,9 @@ PjRtClient::CopyArraysForCrossHostFallback(
memory_kind);
}

int64_t PjRtClient::CreateNewTransferKey() { return next_transfer_key_++; }
// Hands out the next cross-host transfer key: reads the current value of
// `next_transfer_key_`, advances the counter by one, and wraps the value that
// was read in a `CrossHostTransferKey`.
CrossHostTransferKey PjRtClient::CreateNewTransferKey() {
  const auto key_value = next_transfer_key_++;
  return CrossHostTransferKey(key_value);
}

absl::Status PjRtClient::WatchGlobalProcessInfo(
tsl::CoordinationServiceAgent& agent) {
Expand Down Expand Up @@ -1507,7 +1573,8 @@ absl::Status PjRtClient::WatchGlobalProcessInfo(
}

absl::Status PjRtClient::CrossHostSendBuffers(
PjRtBuffers buffers, const std::vector<int64_t>& keys) {
std::vector<PjRtBuffer*> buffers,
const std::vector<CrossHostTransferKey>& keys) {
if (keys.size() != buffers.size()) {
return absl::InternalError(
"CrossHostSendBuffers: keys must be the same size as buffers.");
Expand All @@ -1518,7 +1585,7 @@ absl::Status PjRtClient::CrossHostSendBuffers(
auto [promise, descriptor_future] = tsl::Future<std::string>::MakePromise();
work_queue_->Schedule(
[this, k = keys[i], promise = std::move(promise).ToShared()]() mutable {
std::string key = absl::StrCat(kKeyPrefix, k);
std::string key = absl::StrCat(kKeyPrefix, k.value());
absl::StatusOr<std::string> descriptor =
kv_store_->Get(key, cross_host_transfer_timeout_);
if (!descriptor.ok()) {
Expand All @@ -1528,16 +1595,24 @@ absl::Status PjRtClient::CrossHostSendBuffers(
promise->Set(std::move(*descriptor));
});
auto on_done = [](absl::Status status, bool sends_were_enqueued) {
CHECK_OK(status);
if (!status.ok()) {
LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` failed: "
<< status;
}
if (!sends_were_enqueued) {
LOG(ERROR) << "`xla::PjRtBuffer::CopyToRemoteDevice` did not enqueue "
"sends.";
}
};
buffers[i]->CopyToRemoteDevice(std::move(descriptor_future), on_done);
}
return absl::OkStatus();
}

absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
absl::Span<const xla::Shape> shapes, xla::PjRtDevice* device,
std::vector<int64_t> keys) {
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtClient::CrossHostReceiveBuffers(absl::Span<const xla::Shape> shapes,
xla::PjRtDevice* device,
std::vector<CrossHostTransferKey> keys) {
auto notifier = [this, keys = std::move(keys)](
absl::StatusOr<xla::PjRtCrossHostRecvState> recv_state) {
if (!recv_state.ok()) {
Expand Down Expand Up @@ -1567,7 +1642,7 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
return;
}
for (int i = 0, n = keys.size(); i < n; ++i) {
std::string key = absl::StrCat(kKeyPrefix, keys[i]);
std::string key = absl::StrCat(kKeyPrefix, keys[i].value());
absl::Status kv_status = kv_store_->Set(
key, recv_state->descriptors[i].serialized_descriptors.front());
if (!kv_status.ok()) {
Expand All @@ -1581,15 +1656,8 @@ absl::StatusOr<PjRtArray::PjRtBuffers> PjRtClient::CrossHostReceiveBuffers(
}
}
};
TF_ASSIGN_OR_RETURN(auto recv_buffers,
pjrt_client_->MakeCrossHostReceiveBuffers(
shapes, device, std::move(notifier)));
PjRtArray::PjRtBuffers buffers;
buffers.reserve(recv_buffers.size());
for (auto& recv_buffer : recv_buffers) {
buffers.push_back(std::move(recv_buffer));
}
return buffers;
return pjrt_client_->MakeCrossHostReceiveBuffers(shapes, device,
std::move(notifier));
}

absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> PjRtClient::RemapArrays(
Expand Down
20 changes: 12 additions & 8 deletions xla/python/pjrt_ifrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,19 @@ class PjRtClient final
DeviceListRef dst_devices, std::optional<MemoryKind> memory_kind);

// Extracts receive descriptors from a key-value store and sends buffers to a
// remote device.
absl::Status CrossHostSendBuffers(PjRtBuffers buffers,
const std::vector<int64_t>& keys);
// remote device. This is used when the backend does not implement the
// CrossHostSendBuffers API.
absl::Status CrossHostSendBuffers(
std::vector<PjRtBuffer*> buffers,
const std::vector<CrossHostTransferKey>& keys);

// Populates a key-value store with receive descriptors and places buffers
// from a cross-host send onto device.
absl::StatusOr<PjRtBuffers> CrossHostReceiveBuffers(
absl::Span<const xla::Shape> shapes, xla::PjRtDevice* device,
std::vector<int64_t> keys);
// from a cross-host send onto device. This is used when the backend does not
// implement the CrossHostReceiveBuffers API.
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
CrossHostReceiveBuffers(absl::Span<const xla::Shape> shapes,
xla::PjRtDevice* device,
std::vector<CrossHostTransferKey> keys);

// Copies arrays from source to destination devices when at least one of the
// (source, destination) pairs is cross-host using an experimental DCN
Expand All @@ -409,7 +413,7 @@ class PjRtClient final
// Creates a unique identifier for each cross-host transfer. Every process
// must call it, regardless of whether it participates in the cross-host
// transfer, so that the returned value is the same in all processes.
int64_t CreateNewTransferKey();
CrossHostTransferKey CreateNewTransferKey();

// If true, the backend implements the cross-host transfer APIs.
bool pjrt_supports_cross_host_transfers_ = false;
Expand Down
Loading