@@ -550,11 +550,14 @@ torch::lazy::BackendDataPtr TensorToXlaData(
     const at::Tensor& tensor, const xla::Shape& shape,
     const torch::lazy::BackendDevice& device) {
   TORCH_LAZY_TIMED("TensorToData");
+
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
     // The tensor is bypassing the virtual device, so it should be replicated
     // to all devices.
-    std::vector<std::string> local_devices =
-        runtime::GetComputationClientOrDie()->GetLocalDevices();
+    std::vector<std::string> local_devices = client->GetLocalDevices();
     auto replicated_data =
         std::vector<at::Tensor>(local_devices.size(), tensor);
     return ShardingUtil::CreateShardedData(replicated_data, local_devices,
@@ -565,8 +568,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
   source_tensors.push_back(
       std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
 
-  auto handles =
-      runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors);
+  auto handles = client->TransferToDevice(source_tensors);
   XLA_CHECK_EQ(handles.size(), 1);
   return handles.front();
 }
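The hunks above all follow one pattern: instead of calling runtime::GetComputationClientOrDie() at every use site, each function resolves the client once through the status-returning runtime::GetComputationClient(), unwraps it with XLA_ASSIGN_OR_THROW, and reuses the resulting non-null pointer. A minimal sketch of that unwrap-or-throw idiom, using absl::StatusOr and hypothetical stand-ins (Client, GetClient, ValueOrThrow) rather than the actual torch_xla macro:

#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"

// Hypothetical stand-in for runtime::ComputationClient.
struct Client {
  std::vector<std::string> GetLocalDevices() const {
    return {"TPU:0", "TPU:1"};
  }
};

// Hypothetical stand-in for runtime::GetComputationClient(): may return an
// error status, e.g. if the runtime has not been initialized.
absl::StatusOr<Client*> GetClient() {
  static Client client;
  return &client;
}

// Unwrap-or-throw: a failed status becomes a C++ exception (which the Python
// bindings can surface as a Python error) instead of terminating the process
// the way an *OrDie accessor would.
template <typename T>
T ValueOrThrow(absl::StatusOr<T> value) {
  if (!value.ok()) {
    throw std::runtime_error(std::string(value.status().message()));
  }
  return *std::move(value);
}

void Example() {
  // Resolve the client once per function, then reuse it at every call site.
  Client* const client = ValueOrThrow(GetClient());
  std::vector<std::string> local_devices = client->GetLocalDevices();
  (void)local_devices;
}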
@@ -806,15 +808,17 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     return {};
   }
 
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   // CreateTensorsData should be implicitly replicated to all devices.
   if (IsVirtualDevice(devices[0])) {
     XLA_CHECK(
         std::all_of(devices.begin(), devices.end(),
                     [&](const std::string& s) { return s == devices[0]; }))
         << "can't mix virtual device and real device.";
 
-    std::vector<std::string> local_devices =
-        runtime::GetComputationClientOrDie()->GetLocalDevices();
+    std::vector<std::string> local_devices = client->GetLocalDevices();
     std::vector<runtime::ComputationClient::DataPtr> handles;
     for (size_t i = 0; i < tensors.size(); ++i) {
       auto device = ParseDeviceString(devices[i]);
@@ -834,8 +838,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
         tensors[i], std::move(shape), devices[i]));
   }
-  return WrapXlaData(
-      runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
+  return WrapXlaData(client->TransferToDevice(source_tensors));
 }
 
 std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
@@ -846,6 +849,9 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   XLA_CHECK_EQ(tensors.size(), shardings.size());
   XLA_CHECK_EQ(tensors.size(), devices.size());
 
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   std::vector<runtime::ComputationClient::DataPtr> handles;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
@@ -858,8 +864,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
       // GetLocalDevices returns the list of local devices specified by their
       // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
 
-      std::vector<std::string> local_devices =
-          runtime::GetComputationClientOrDie()->GetLocalDevices();
+      std::vector<std::string> local_devices = client->GetLocalDevices();
       // Shards the input tensors with padding, to split evenly.
       // The execution requires consistent shard sizes, and the zero-padded
       // values should be ignored.
@@ -871,8 +876,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     } else {
      source_tensors.push_back(std::make_shared<runtime::AtenSource>(
          tensors[i], std::move(shape), devices[i]));
-      new_handles = runtime::GetComputationClientOrDie()->TransferToDevice(
-          source_tensors);
+      new_handles = client->TransferToDevice(source_tensors);
     }
     handles.insert(handles.end(), new_handles.begin(), new_handles.end());
   }
@@ -910,7 +914,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     save = PyEval_SaveThread();
   }
 
-  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient* client,
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient* absl_nonnull const client,
                        runtime::GetComputationClient());
   XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
                        client->TransferFromDevice(UnwrapXlaData(xla_data)));
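In ReleaseGilAndTransferData, which already returns absl::StatusOr, the same lookup uses XLA_ASSIGN_OR_RETURN so a failed client lookup propagates to the caller instead of throwing. A rough, hand-expanded sketch of that propagation, with illustrative stand-in types rather than the actual torch_xla macro expansion:

#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct Literal {};  // stand-in for xla::Literal

struct Client {
  absl::StatusOr<std::vector<Literal>> TransferFromDevice() {
    return std::vector<Literal>{};
  }
};

absl::StatusOr<Client*> GetClient() {
  static Client client;
  return &client;
}

absl::StatusOr<std::vector<Literal>> Transfer() {
  // Roughly what XLA_ASSIGN_OR_RETURN(Client* client, GetClient()) does:
  absl::StatusOr<Client*> client_or = GetClient();
  if (!client_or.ok()) {
    return client_or.status();  // propagate the error to the caller
  }
  Client* const client = *client_or;
  return client->TransferFromDevice();
}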