Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions .github/workflows/_tpu_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ jobs:
pip install fsspec
pip install rich
# Test dependencies
pip install --upgrade protobuf
pip install flax
# PyTorch/XLA Optional Dependencies
# =================================
#
# Install `JAX` and `libtpu` dependencies for pallas and TPU tests.
# Install `jax` and `libtpu` dependencies for pallas and TPU tests.
#
# Note that we might need to install pre-release versions of both, in
# external artifact repositories.
Expand All @@ -70,18 +74,6 @@ jobs:
pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS
pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS
pip install --upgrade protobuf
# Flax Pin
# ========
#
# Be careful when bumping the `flax` version, since it can cause tests that
# depend on `jax` to start breaking.
#
# Newer `flax` versions might pull newer `jax` versions, which might be incompatible
# with the current version of PyTorch/XLA.
pip install flax==0.11.2
- name: Run Tests (${{ matrix.test_script }})
if: inputs.has_code_changes == 'true'
env:
Expand Down
4 changes: 2 additions & 2 deletions .torch_commit
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# 2025-09-17
928ac57c2ab03f9f79376f9995553eea2e6f4ca8
# 2025-09-29
21fec65781bebe867faf209f89bb687ffd236ca4
2 changes: 1 addition & 1 deletion BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("@python//:defs.bzl", "compile_pip_requirements")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
load("@rules_python//python:pip.bzl", "compile_pip_requirements")

compile_pip_requirements(
name = "requirements",
Expand Down
3 changes: 2 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ new_local_repository(

# To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
# the openxla git commit hash and note the date of the commit.
xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513' # Committed on 2025-08-15.
xla_hash = '9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca' # Committed on 2025-10-01.

http_archive(
name = "xla",
Expand All @@ -63,6 +63,7 @@ http_archive(
patch_tool = "patch",
patches = [
"//openxla_patches:no_fortify.diff",
"//openxla_patches:if_constexpr_static_assert.diff",
],
strip_prefix = "xla-" + xla_hash,
urls = [
Expand Down
40 changes: 40 additions & 0 deletions openxla_patches/if_constexpr_static_assert.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
diff --git a/xla/python/ifrt/attribute_map.h b/xla/python/ifrt/attribute_map.h
index a8c9f11c8d..e5bb70bcf8 100644
--- a/xla/python/ifrt/attribute_map.h
+++ b/xla/python/ifrt/attribute_map.h
@@ -106,7 +106,9 @@ class AttributeMap {
} else if constexpr (std::is_same_v<T, float>) {
return Get<T, FloatValue>(key);
} else {
- static_assert(false, "Unsupported type for AttributeMap::Get");
+ // Same as: static_assert(false).
+ // Make it compileable by GCC version < 13.
+ static_assert(!sizeof(T), "Unsupported type for AttributeMap::Get");
}
}

diff --git a/xla/stream_executor/plugin_registry.cc b/xla/stream_executor/plugin_registry.cc
index f16a4f6707..8bfd51b238 100644
--- a/xla/stream_executor/plugin_registry.cc
+++ b/xla/stream_executor/plugin_registry.cc
@@ -41,7 +41,9 @@ PluginKind GetPluginKind() {
} else if constexpr (std::is_same_v<FactoryT, PluginRegistry::FftFactory>) {
return PluginKind::kFft;
} else {
- static_assert(false, "Unsupported factory type");
+ // Same as: static_assert(false).
+ // Make it compileable by GCC version < 13.
+ static_assert(!sizeof(FactoryT), "Unsupported factory type");
}
}
template <typename FactoryT>
@@ -53,7 +55,9 @@ absl::string_view GetPluginName() {
} else if constexpr (std::is_same_v<FactoryT, PluginRegistry::FftFactory>) {
return "FFT";
} else {
- static_assert(false, "Unsupported factory type");
+ // Same as: static_assert(false).
+ // Make it compileable by GCC version < 13.
+ static_assert(!sizeof(FactoryT), "Unsupported factory type");
}
}
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@

USE_NIGHTLY = True # Whether to use nightly or stable libtpu and JAX.

_libtpu_version = '0.0.21'
_libtpu_date = '20250813'
_libtpu_version = '0.0.24'
_libtpu_date = '20250929'

_jax_version = '0.7.1'
_jaxlib_version = '0.7.1'
_jax_date = '20250813' # Date for jax and jaxlib.
_jax_version = '0.8.0'
_jaxlib_version = '0.8.0'
_jax_date = '20251001' # Date for jax and jaxlib.

if USE_NIGHTLY:
_libtpu_version += f".dev{_libtpu_date}+nightly"
Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_fsdp_v2_basic(self):
# Make sure optimization barrier is applied.
hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
self.assertIn(
'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37',
'opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.2',
hlo)

# Make sure the model can execute without error.
Expand Down
27 changes: 14 additions & 13 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
self.assertIn(
'%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
'%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
hlo)

# avoid calling xr.addressable_device_count here otherwise it will init the test
Expand Down Expand Up @@ -713,7 +713,8 @@ def test_xla_sharded_hlo_dump(self):
partition_spec)
xst2 = xst1 + 5
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
print(hlo)
self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
if torch_xla._XLAC._xla_get_auto_sharding():
# scalar 5 should be implicitly replicated, so the pre-optimization HLO
# shouldn't mark it with sharding.
Expand Down Expand Up @@ -828,13 +829,13 @@ def test_mark_sharding_ir(self):
(0, 1))
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
'%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
hlo)

actual += 0
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
'%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
hlo)

self.assertTrue(torch.allclose(expected, actual.cpu()))
Expand Down Expand Up @@ -1141,7 +1142,7 @@ def test_backward_optimization_barrier(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
self.assertIn(
'%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
'%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
hlo)

def test_mark_shard_scalar(self):
Expand Down Expand Up @@ -1198,7 +1199,7 @@ def test_spmd_full_to_shard_shape(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
self.assertEqual(xx.shape, (8, 8 // self.n_devices))
self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
self.assertIn(f'%custom-call.1 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
self.assertIn(
f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
Expand All @@ -1215,7 +1216,7 @@ def test_spmd_full_to_shard_shape(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
self.assertEqual(xx.shape, (8, 4))
self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
self.assertIn(
f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
Expand Down Expand Up @@ -1246,7 +1247,7 @@ def test_spmd_shard_to_full_shape(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
self.assertEqual(xx.shape, x.shape)
self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
self.assertIn(
'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
Expand Down Expand Up @@ -1297,7 +1298,7 @@ def test_spmd_reduce_scatter(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
hlo)

expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
Expand All @@ -1318,7 +1319,7 @@ def test_spmd_reduce_scatter_canonical_index(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
hlo)

expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
Expand All @@ -1338,7 +1339,7 @@ def test_spmd_all_reduce(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
hlo)

expected_x = torch.ones(8, 8) * self.n_devices
Expand All @@ -1359,7 +1360,7 @@ def test_spmd_all_reduce_scale(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
hlo)

expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
Expand Down Expand Up @@ -1713,7 +1714,7 @@ def test_annotate_custom_sharding(self):
f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
hlo)
self.assertIn(
f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
hlo)
xm.mark_step()
# Ensure that the resulting sharding spec is preserved
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
};

// Reports an XLA builder error for the given node.
TF_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
absl::string_view error_msg);
ABSL_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
absl::string_view error_msg);

xla::XlaBuilder builder_;
std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
Expand Down
22 changes: 19 additions & 3 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -381,18 +381,34 @@ cc_test(
],
)

cc_library(
name = "tsl_platform_logging",
srcs = ["tsl_platform_logging.cpp"],
hdrs = ["tsl_platform_logging.h"],
deps = [
"@xla//xla/tsl/platform:env_time",
"@xla//xla/tsl/platform:logging",
"@xla//xla/tsl/platform:macros",
"@xla//xla/tsl/platform:types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
],
)

cc_library(
name = "tf_logging",
srcs = ["tf_logging.cpp"],
hdrs = ["tf_logging.h"],
deps = [
":tsl_platform_logging",
"//torch_xla/csrc:status",
"@torch//:headers",
"@torch//:runtime_headers",
"@tsl//tsl/platform:stacktrace",
"@tsl//tsl/platform:statusor",
"@xla//xla/service:platform_util",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/log:absl_log",
],
)

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/status.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"

namespace torch_xla {
namespace runtime {
Expand Down
Loading