diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml
index 2f48391c96e..ffa7de73407 100644
--- a/.github/workflows/_tpu_ci.yml
+++ b/.github/workflows/_tpu_ci.yml
@@ -51,10 +51,14 @@ jobs:
          pip install fsspec
          pip install rich
 
+         # Test dependencies
+         pip install --upgrade protobuf
+         pip install flax
+
          # PyTorch/XLA Optional Dependencies
          # =================================
          #
-         # Install `JAX` and `libtpu` dependencies for pallas and TPU tests.
+         # Install `jax` and `libtpu` dependencies for pallas and TPU tests.
          #
          # Note that we might need to install pre-release versions of both, in
          # external artifact repositories.
@@ -70,18 +74,6 @@ jobs:
          pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS
          pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS
 
-         pip install --upgrade protobuf
-
-         # Flax Pin
-         # ========
-         #
-         # Be careful when bumping the `flax` version, since it can cause tests that
-         # depend on `jax` to start breaking.
-         #
-         # Newer `flax` versions might pull newer `jax` versions, which might be incompatible
-         # with the current version of PyTorch/XLA.
-         pip install flax==0.11.2
-
      - name: Run Tests (${{ matrix.test_script }})
        if: inputs.has_code_changes == 'true'
        env:
diff --git a/.torch_commit b/.torch_commit
index 715a8bee47e..0536d88def3 100644
--- a/.torch_commit
+++ b/.torch_commit
@@ -1,2 +1,2 @@
-# 2025-09-17
-928ac57c2ab03f9f79376f9995553eea2e6f4ca8
\ No newline at end of file
+# 2025-09-29
+21fec65781bebe867faf209f89bb687ffd236ca4
\ No newline at end of file
diff --git a/BUILD b/BUILD
index 128f83dcd56..d445d37ed6a 100644
--- a/BUILD
+++ b/BUILD
@@ -1,5 +1,5 @@
-load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
+load("@rules_python//python:pip.bzl", "compile_pip_requirements")
 
 compile_pip_requirements(
     name = "requirements",
diff --git a/WORKSPACE b/WORKSPACE
index 78e928d2a0f..2edf2be524a 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -52,7 +52,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with a new revision of OpenXLA, update the xla_hash to
 # the openxla git commit hash and note the date of the commit.
-xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513' # Committed on 2025-08-15.
+xla_hash = '9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca' # Committed on 2025-10-01.
 
 http_archive(
     name = "xla",
@@ -63,6 +63,7 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:no_fortify.diff",
+        "//openxla_patches:if_constexpr_static_assert.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [
diff --git a/openxla_patches/if_constexpr_static_assert.diff b/openxla_patches/if_constexpr_static_assert.diff
new file mode 100644
index 00000000000..b258d6dfb67
--- /dev/null
+++ b/openxla_patches/if_constexpr_static_assert.diff
@@ -0,0 +1,40 @@
+diff --git a/xla/python/ifrt/attribute_map.h b/xla/python/ifrt/attribute_map.h
+index a8c9f11c8d..e5bb70bcf8 100644
+--- a/xla/python/ifrt/attribute_map.h
++++ b/xla/python/ifrt/attribute_map.h
+@@ -106,7 +106,9 @@ class AttributeMap {
+     } else if constexpr (std::is_same_v) {
+       return Get(key);
+     } else {
+-      static_assert(false, "Unsupported type for AttributeMap::Get");
++      // Same as: static_assert(false).
++      // Make it compilable with GCC versions < 13.
++      static_assert(!sizeof(T), "Unsupported type for AttributeMap::Get");
+     }
+   }
+
+diff --git a/xla/stream_executor/plugin_registry.cc b/xla/stream_executor/plugin_registry.cc
+index f16a4f6707..8bfd51b238 100644
+--- a/xla/stream_executor/plugin_registry.cc
++++ b/xla/stream_executor/plugin_registry.cc
+@@ -41,7 +41,9 @@ PluginKind GetPluginKind() {
+   } else if constexpr (std::is_same_v) {
+     return PluginKind::kFft;
+   } else {
+-    static_assert(false, "Unsupported factory type");
++    // Same as: static_assert(false).
++    // Make it compilable with GCC versions < 13.
++    static_assert(!sizeof(FactoryT), "Unsupported factory type");
+   }
+ }
+ template <typename FactoryT>
+@@ -53,7 +55,9 @@ absl::string_view GetPluginName() {
+   } else if constexpr (std::is_same_v) {
+     return "FFT";
+   } else {
+-    static_assert(false, "Unsupported factory type");
++    // Same as: static_assert(false).
++    // Make it compilable with GCC versions < 13.
++    static_assert(!sizeof(FactoryT), "Unsupported factory type");
+   }
+ }
diff --git a/setup.py b/setup.py
index 33642f9a3f6..d646be3a4f3 100644
--- a/setup.py
+++ b/setup.py
@@ -112,12 +112,12 @@
 
 USE_NIGHTLY = True  # Whether to use nightly or stable libtpu and JAX.
 
-_libtpu_version = '0.0.21'
-_libtpu_date = '20250813'
+_libtpu_version = '0.0.24'
+_libtpu_date = '20250929'
 
-_jax_version = '0.7.1'
-_jaxlib_version = '0.7.1'
-_jax_date = '20250813'  # Date for jax and jaxlib.
+_jax_version = '0.8.0'
+_jaxlib_version = '0.8.0'
+_jax_date = '20251001'  # Date for jax and jaxlib.
 
 if USE_NIGHTLY:
   _libtpu_version += f".dev{_libtpu_date}+nightly"
diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py
index 778fb729460..7c69e374fe8 100644
--- a/test/spmd/test_fsdp_v2.py
+++ b/test/spmd/test_fsdp_v2.py
@@ -55,7 +55,7 @@ def test_fsdp_v2_basic(self):
     # Make sure optimization barrier is applied.
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37',
+        'opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.2',
         hlo)
 
     # Make sure the model can execute without error.
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 48b760f6e3f..46e9785fee7 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,7 +713,8 @@ def test_xla_sharded_hlo_dump(self):
                                          partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    print(hlo)
+    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -828,13 +829,13 @@ def test_mark_sharding_ir(self): (0, 1)) hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor]) self.assertIn( - '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=', + '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=', hlo) actual += 0 hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor]) self.assertIn( - '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)', + '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)', hlo) self.assertTrue(torch.allclose(expected, actual.cpu())) @@ -1141,7 +1142,7 @@ def test_backward_optimization_barrier(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad]) self.assertIn( - '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)', + '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)', hlo) def test_mark_shard_scalar(self): @@ -1198,7 +1199,7 @@ def test_spmd_full_to_shard_shape(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx]) self.assertEqual(xx.shape, (8, 8 // self.n_devices)) - self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo) + self.assertIn(f'%custom-call.1 = f32[8,{8//self.n_devices}]{{1,0}}', hlo) self.assertIn( f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo) self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}") @@ -1215,7 +1216,7 @@ def test_spmd_full_to_shard_shape(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx]) self.assertEqual(xx.shape, (8, 4)) - self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo) + self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo) self.assertIn( f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo) self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}") @@ -1246,7 +1247,7 @@ def test_spmd_shard_to_full_shape(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx]) self.assertEqual(xx.shape, x.shape) - self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo) + self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo) self.assertIn( 'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo) self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}") @@ -1297,7 +1298,7 @@ def test_spmd_reduce_scatter(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([x]) self.assertIn( - f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3", + f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1", hlo) expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices @@ -1318,7 +1319,7 @@ def test_spmd_reduce_scatter_canonical_index(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([x]) self.assertIn( - f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3", + f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, 
replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1", hlo) expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices @@ -1338,7 +1339,7 @@ def test_spmd_all_reduce(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([x]) self.assertIn( - f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3", + f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1", hlo) expected_x = torch.ones(8, 8) * self.n_devices @@ -1359,7 +1360,7 @@ def test_spmd_all_reduce_scale(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([x]) self.assertIn( - f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3", + f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1", hlo) expected_x = torch.ones(8, 8) * int(self.n_devices * scale) @@ -1713,7 +1714,7 @@ def test_annotate_custom_sharding(self): f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}', hlo) self.assertIn( - f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}', + f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}', hlo) xm.mark_step() # Ensure that the resulting sharding spec is preserved diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index f0545534155..9bfbd70b982 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -124,8 +124,8 @@ class LoweringContext : public torch::lazy::LoweringContext { }; // Reports an XLA builder error for the given node. 
- TF_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node, - absl::string_view error_msg); + ABSL_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node, + absl::string_view error_msg); xla::XlaBuilder builder_; std::unordered_map diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 4f0f3bf384e..655b91d3fc1 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -381,18 +381,34 @@ cc_test( ], ) +cc_library( + name = "tsl_platform_logging", + srcs = ["tsl_platform_logging.cpp"], + hdrs = ["tsl_platform_logging.h"], + deps = [ + "@xla//xla/tsl/platform:env_time", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:macros", + "@xla//xla/tsl/platform:types", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "tf_logging", srcs = ["tf_logging.cpp"], hdrs = ["tf_logging.h"], deps = [ + ":tsl_platform_logging", "//torch_xla/csrc:status", "@torch//:headers", "@torch//:runtime_headers", - "@tsl//tsl/platform:stacktrace", - "@tsl//tsl/platform:statusor", - "@xla//xla/service:platform_util", "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cpp b/torch_xla/csrc/runtime/pjrt_registry.cpp index 162e6dca9d2..73c6c624ad3 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cpp +++ b/torch_xla/csrc/runtime/pjrt_registry.cpp @@ -12,12 +12,12 @@ #include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/status.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" namespace torch_xla { namespace runtime { diff --git a/torch_xla/csrc/runtime/tf_logging.h b/torch_xla/csrc/runtime/tf_logging.h index 58998fdfbcf..5cecbea07b0 100644 --- a/torch_xla/csrc/runtime/tf_logging.h +++ b/torch_xla/csrc/runtime/tf_logging.h @@ -3,32 +3,29 @@ #include -#include "absl/base/log_severity.h" -#include "tsl/platform/logging.h" +#include "absl/log/absl_log.h" +#include "torch_xla/csrc/runtime/tsl_platform_logging.h" namespace torch_xla { namespace runtime { namespace internal { -// It happens that Caffe defined the same exact Google macros, hiding the TF -// ones, and making log messages disappear. -// Unfortunately to get those back, we have to poke through the TF -// implementaiton of them. -#define TF_LOG(severity) _TF_LOG_##severity - -#define TF_VLOG_IS_ON(lvl) \ - (([](int level, const char* fname) { \ - static const bool vmodule_activated = \ - ::tsl::internal::LogMessage::VmoduleActivated(fname, level); \ - return vmodule_activated; \ - })(lvl, __FILE__)) - -#define TF_VLOG(level) \ - TF_PREDICT_TRUE(!TF_VLOG_IS_ON(level)) \ - ? (void)0 \ - : ::tsl::internal::Voidifier() & \ - ::tsl::internal::LogMessage(__FILE__, __LINE__, \ - absl::LogSeverity::kInfo) +// TODO: replace all TF_*LOG macro calls with ABSL_*LOG() +// +// Why are we using Abseil logging, now? +// ===================================== +// Ref: https://github.com/openxla/xla/pull/29477 +// +// OpenXLA removed their internal logging in favor of Abseil. 
+// +// Why do we have the `TF_` prefix? +// ================================ +// Ref: https://github.com/pytorch/xla/pull/34 +// +// So as not to clash with C10 definition. +// Maybe this is not a problem anymore, though. +#define TF_LOG(severity) ABSL_LOG(severity) +#define TF_VLOG(level) ABSL_VLOG(level) struct ErrorSink : public std::basic_ostringstream {}; @@ -39,7 +36,7 @@ class ErrorGenerator { // Use a dummy & operator as it has lower precedence WRT the streaming // operator, and hence allows collecting user error messages before we finally // throw. - TF_ATTRIBUTE_NORETURN void operator&( + ABSL_ATTRIBUTE_NORETURN void operator&( const std::basic_ostream& oss) const; private: @@ -55,10 +52,12 @@ class ErrorGenerator { while (TF_PREDICT_FALSE(!(condition))) \ TF_ERROR_STREAM() << "Check failed: " #condition ": " -#define TF_CHECK_OP_LOG(name, op, val1, val2) \ - while (::tsl::internal::CheckOpString _result{::tsl::internal::name##Impl( \ - ::tsl::internal::GetReferenceableValue(val1), \ - ::tsl::internal::GetReferenceableValue(val2), #val1 " " #op " " #val2)}) \ +#define TF_CHECK_OP_LOG(name, op, val1, val2) \ + while (::tsl::torch_xla::internal::CheckOpString _result{ \ + ::tsl::torch_xla::internal::name##Impl( \ + ::tsl::torch_xla::internal::GetReferenceableValue(val1), \ + ::tsl::torch_xla::internal::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)}) \ TF_ERROR_STREAM() << *(_result.str_) #define TF_CHECK_OP(name, op, val1, val2) TF_CHECK_OP_LOG(name, op, val1, val2) @@ -75,10 +74,6 @@ class ErrorGenerator { #define TF_CHECK_GE(val1, val2) TF_CHECK_LE(val2, val1) #define TF_CHECK_GT(val1, val2) TF_CHECK_LT(val2, val1) -#undef TF_CHECK_OK -#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) -#define TF_CHECK_NOTNULL(val) TF_CHECK(val != nullptr) - } // namespace internal } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/tsl_platform_logging.cpp b/torch_xla/csrc/runtime/tsl_platform_logging.cpp new file mode 100644 index 00000000000..4d023138b6d --- /dev/null +++ b/torch_xla/csrc/runtime/tsl_platform_logging.cpp @@ -0,0 +1,579 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + * This file was copied from the OpenXLA repository + * (https://github.com/openxla/xla), before it was deleted. + * + * Commit: 20358a12f26199d016e6e690fe31a4a0a141226e + * Date: 2025-08-26 + */ + +#include "torch_xla/csrc/runtime/tsl_platform_logging.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(b/142492876): Avoid depending on absl internal. 
+#include "absl/base/internal/cycleclock.h" +#include "absl/base/internal/sysinfo.h" +#include "absl/base/log_severity.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/types.h" + +#if defined(PLATFORM_POSIX_ANDROID) +#include + +#include +#include +#endif + +#include +#include +#include + +#include +#include + +namespace tsl { + +namespace torch_xla { + +namespace internal { +namespace { + +// This is an internal singleton class that manages the log sinks. It allows +// adding and removing the log sinks, as well as handling sending log messages +// to all the registered log sinks. +class TFLogSinks { + public: + // Gets the TFLogSinks instance. This is the entry point for using this class. + static TFLogSinks& Instance(); + + // Adds a log sink. The sink argument must not be a nullptr. TFLogSinks + // takes ownership of the pointer, the user must not free the pointer. + // The pointer will remain valid until the application terminates or + // until TFLogSinks::Remove is called for the same pointer value. + void Add(TFLogSink* sink); + + // Removes a log sink. This will also erase the sink object. The pointer + // to the sink becomes invalid after this call. + void Remove(TFLogSink* sink); + + // Gets the currently registered log sinks. + std::vector GetSinks() const; + + // Sends a log message to all registered log sinks. + // + // If there are no log sinks are registered: + // + // NO_DEFAULT_LOGGER is defined: + // Up to 128 messages will be queued until a log sink is added. + // The queue will then be logged to the first added log sink. + // + // NO_DEFAULT_LOGGER is not defined: + // The messages will be logged using the default logger. The default logger + // will log to stdout on all platforms except for Android. On Androit the + // default Android logger will be used. 
+ void Send(const TFLogEntry& entry); + + private: + TFLogSinks(); + void SendToSink(TFLogSink& sink, const TFLogEntry& entry); + + std::queue log_entry_queue_; + static const size_t kMaxLogEntryQueueSize = 128; + + mutable absl::Mutex mutex_; + std::vector sinks_; +}; + +TFLogSinks::TFLogSinks() { +#ifndef NO_DEFAULT_LOGGER + static TFDefaultLogSink* const default_sink = new TFDefaultLogSink(); + sinks_.push_back(default_sink); +#endif +} + +TFLogSinks& TFLogSinks::Instance() { + static TFLogSinks* const instance = new TFLogSinks(); + return *instance; +} + +void TFLogSinks::Add(TFLogSink* sink) { + assert(sink != nullptr && "The sink must not be a nullptr"); + + absl::MutexLock lock(&mutex_); + sinks_.push_back(sink); + + // If this is the only sink log all the queued up messages to this sink + if (sinks_.size() == 1) { + while (!log_entry_queue_.empty()) { + for (const auto& sink : sinks_) { + SendToSink(*sink, log_entry_queue_.front()); + } + log_entry_queue_.pop(); + } + } +} + +void TFLogSinks::Remove(TFLogSink* sink) { + assert(sink != nullptr && "The sink must not be a nullptr"); + + absl::MutexLock lock(&mutex_); + auto it = std::find(sinks_.begin(), sinks_.end(), sink); + if (it != sinks_.end()) sinks_.erase(it); +} + +std::vector TFLogSinks::GetSinks() const { + absl::MutexLock lock(&mutex_); + return sinks_; +} + +void TFLogSinks::Send(const TFLogEntry& entry) { + absl::MutexLock lock(&mutex_); + + // If we don't have any sinks registered, queue them up + if (sinks_.empty()) { + // If we've exceeded the maximum queue size, drop the oldest entries + while (log_entry_queue_.size() >= kMaxLogEntryQueueSize) { + log_entry_queue_.pop(); + } + log_entry_queue_.push(entry); + return; + } + + // If we have items in the queue, push them out first + while (!log_entry_queue_.empty()) { + for (const auto& sink : sinks_) { + SendToSink(*sink, log_entry_queue_.front()); + } + log_entry_queue_.pop(); + } + + // ... and now we can log the current log entry + for (const auto& sink : sinks_) { + SendToSink(*sink, entry); + } +} + +void TFLogSinks::SendToSink(TFLogSink& sink, const TFLogEntry& entry) { + sink.Send(entry); + sink.WaitTillSent(); +} + +// A class for managing the text file to which VLOG output is written. +// If the environment variable TF_CPP_VLOG_FILENAME is set, all VLOG +// calls are redirected from stderr to a file with corresponding name. +class VlogFileMgr { + public: + // Determines if the env variable is set and if necessary + // opens the file for write access. + VlogFileMgr(); + // Closes the file. + ~VlogFileMgr(); + // Returns either a pointer to the file or stderr. + FILE* FilePtr() const; + + private: + FILE* vlog_file_ptr; + char* vlog_file_name; +}; + +VlogFileMgr::VlogFileMgr() { + vlog_file_name = getenv("TF_CPP_VLOG_FILENAME"); + vlog_file_ptr = + vlog_file_name == nullptr ? 
nullptr : fopen(vlog_file_name, "w"); + + if (vlog_file_ptr == nullptr) { + vlog_file_ptr = stderr; + } +} + +VlogFileMgr::~VlogFileMgr() { + if (vlog_file_ptr != stderr) { + fclose(vlog_file_ptr); + } +} + +FILE* VlogFileMgr::FilePtr() const { return vlog_file_ptr; } + +int ParseInteger(absl::string_view str) { + int level; + if (!absl::SimpleAtoi(str, &level)) { + return 0; + } + return level; +} + +// Parse log level (int64) from environment variable (char*) +int64_t LogLevelStrToInt(const char* tf_env_var_val) { + if (tf_env_var_val == nullptr) { + return 0; + } + return ParseInteger(tf_env_var_val); +} + +using VmoduleMap = absl::flat_hash_map; + +// Returns a mapping from module name to VLOG level, derived from the +// TF_CPP_VMODULE environment variable; ownership is transferred to the caller. +VmoduleMap* VmodulesMapFromEnv() { + // The value of the env var is supposed to be of the form: + // "foo=1,bar=2,baz=3" + const char* env = getenv("TF_CPP_VMODULE"); + if (env == nullptr) { + // If there is no TF_CPP_VMODULE configuration (most common case), return + // nullptr so that the ShouldVlogModule() API can fast bail out of it. + return nullptr; + } + // The memory returned by getenv() can be invalidated by following getenv() or + // setenv() calls. And since we keep references to it in the VmoduleMap in + // form of StringData objects, make a copy of it. + const char* env_data = strdup(env); + absl::string_view env_view(env_data); + VmoduleMap* result = new VmoduleMap(); + while (!env_view.empty()) { + size_t eq_pos = env_view.find('='); + if (eq_pos == absl::string_view::npos) { + break; + } + absl::string_view module_name = env_view.substr(0, eq_pos); + env_view.remove_prefix(eq_pos + 1); + + // Comma either points at the next comma delimiter, or at a null terminator. + // We check that the integer we parse ends at this delimiter. + size_t level_end_pos = env_view.find(','); + absl::string_view level_str = env_view.substr(0, level_end_pos); + (*result)[module_name] = ParseInteger(level_str); + if (level_end_pos != absl::string_view::npos) { + env_view.remove_prefix(level_end_pos + 1); + } + } + return result; +} + +bool EmitThreadIdFromEnv() { + const char* tf_env_var_val = getenv("TF_CPP_LOG_THREAD_ID"); + return tf_env_var_val == nullptr ? false : ParseInteger(tf_env_var_val) != 0; +} + +} // namespace + +absl::LogSeverityAtLeast MinLogLevelFromEnv() { + // We don't want to print logs during fuzzing as that would slow fuzzing down + // by almost 2x. So, if we are in fuzzing mode (not just running a test), we + // return a value so that nothing is actually printed. Since LOG uses >= + // (see ~LogMessage in this file) to see if log messages need to be printed, + // the value we're interested on to disable printing is the maximum severity. + // See also http://llvm.org/docs/LibFuzzer.html#fuzzer-friendly-build-mode +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + return absl::LogSeverityAtLeast::kInfinity; +#else + const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); + return static_cast( + LogLevelStrToInt(tf_env_var_val)); +#endif +} + +int MaxVLogLevelFromEnv() { + // We don't want to print logs during fuzzing as that would slow fuzzing down + // by almost 2x. So, if we are in fuzzing mode (not just running a test), we + // return a value so that nothing is actually printed. Since VLOG uses <= + // (see VLOG_IS_ON in logging.h) to see if log messages need to be printed, + // the value we're interested on to disable printing is 0. 
+ // See also http://llvm.org/docs/LibFuzzer.html#fuzzer-friendly-build-mode +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + return 0; +#else + const char* tf_env_var_val = getenv("TF_CPP_MAX_VLOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +#endif +} + +LogMessage::LogMessage(const char* fname, int line, absl::LogSeverity severity) + : fname_(fname), line_(line), severity_(severity) {} + +LogMessage& LogMessage::AtLocation(absl::string_view fname, int line) { + fname_ = fname; + line_ = line; + return *this; +} + +LogMessage::~LogMessage() { + // Read the min log level once during the first call to logging. + static absl::LogSeverityAtLeast min_log_level = MinLogLevelFromEnv(); + if (severity_ >= min_log_level) { + GenerateLogMessage(); + } +} + +void LogMessage::GenerateLogMessage() { + TFLogSinks::Instance().Send(TFLogEntry(severity_, fname_, line_, str())); +} + +int LogMessage::MaxVLogLevel() { + static int max_vlog_level = MaxVLogLevelFromEnv(); + return max_vlog_level; +} + +bool LogMessage::VmoduleActivated(const char* fname, int level) { + if (level <= MaxVLogLevel()) { + return true; + } + static VmoduleMap* vmodules = VmodulesMapFromEnv(); + if (ABSL_PREDICT_TRUE(vmodules == nullptr)) { + return false; + } + absl::string_view module(fname); + if (size_t last_slash = module.rfind('/'); + last_slash != absl::string_view::npos) { + module.remove_prefix(last_slash + 1); + } + if (size_t dot_after = module.find('.'); + dot_after != absl::string_view::npos) { + module.remove_suffix(module.size() - dot_after); + } + auto it = vmodules->find(module); + return it != vmodules->end() && it->second >= level; +} + +LogMessageFatal::LogMessageFatal(const char* file, int line) + : LogMessage(file, line, absl::LogSeverity::kFatal) {} +LogMessageFatal::~LogMessageFatal() { + // abort() ensures we don't return (we promised we would not via + // ATTRIBUTE_NORETURN). + GenerateLogMessage(); + abort(); +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "char value " << static_cast(v); + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "signed char value " << static_cast(v); + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "unsigned char value " << static_cast(v); + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& v) { + (*os) << "nullptr"; +} + +CheckOpMessageBuilder::CheckOpMessageBuilder(const char* exprtext) + : stream_(new std::ostringstream) { + *stream_ << "Check failed: " << exprtext << " ("; +} + +CheckOpMessageBuilder::~CheckOpMessageBuilder() { delete stream_; } + +std::ostream* CheckOpMessageBuilder::ForVar2() { + *stream_ << " vs. "; + return stream_; +} + +string* CheckOpMessageBuilder::NewString() { + *stream_ << ")"; + return new string(stream_->str()); +} + +namespace { +// The following code behaves like AtomicStatsCounter::LossyAdd() for +// speed since it is fine to lose occasional updates. +// Returns old value of *counter. 
+uint32 LossyIncrement(std::atomic* counter) { + const uint32 value = counter->load(std::memory_order_relaxed); + counter->store(value + 1, std::memory_order_relaxed); + return value; +} +} // namespace + +bool LogEveryNState::ShouldLog(int n) { + return n != 0 && (LossyIncrement(&counter_) % n) == 0; +} + +bool LogFirstNState::ShouldLog(int n) { + const int counter_value = + static_cast(counter_.load(std::memory_order_relaxed)); + if (counter_value < n) { + counter_.store(counter_value + 1, std::memory_order_relaxed); + return true; + } + return false; +} + +bool LogEveryPow2State::ShouldLog(int ignored) { + const uint32 new_value = LossyIncrement(&counter_) + 1; + return (new_value & (new_value - 1)) == 0; +} + +bool LogEveryNSecState::ShouldLog(double seconds) { + LossyIncrement(&counter_); + const int64_t now_cycles = absl::base_internal::CycleClock::Now(); + int64_t next_cycles = next_log_time_cycles_.load(std::memory_order_relaxed); + do { + if (now_cycles <= next_cycles) return false; + } while (!next_log_time_cycles_.compare_exchange_weak( + next_cycles, + now_cycles + seconds * absl::base_internal::CycleClock::Frequency(), + std::memory_order_relaxed, std::memory_order_relaxed)); + return true; +} + +} // namespace internal + +void TFAddLogSink(TFLogSink* sink) { + internal::TFLogSinks::Instance().Add(sink); +} + +void TFRemoveLogSink(TFLogSink* sink) { + internal::TFLogSinks::Instance().Remove(sink); +} + +std::vector TFGetLogSinks() { + return internal::TFLogSinks::Instance().GetSinks(); +} + +void TFDefaultLogSink::Send(const TFLogEntry& entry) { +#ifdef PLATFORM_POSIX_ANDROID + int android_log_level; + switch (entry.log_severity()) { + case absl::LogSeverity::kInfo: + android_log_level = ANDROID_LOG_INFO; + break; + case absl::LogSeverity::kWarning: + android_log_level = ANDROID_LOG_WARN; + break; + case absl::LogSeverity::kError: + android_log_level = ANDROID_LOG_ERROR; + break; + case absl::LogSeverity::kFatal: + android_log_level = ANDROID_LOG_FATAL; + break; + default: + if (entry.log_severity() < absl::LogSeverity::kInfo) { + android_log_level = ANDROID_LOG_VERBOSE; + } else { + android_log_level = ANDROID_LOG_ERROR; + } + break; + } + + std::stringstream ss; + const auto& fname = entry.FName(); + auto pos = fname.find("/"); + ss << (pos != std::string::npos ? fname.substr(pos + 1) : fname) << ":" + << entry.Line() << " " << entry.ToString(); + __android_log_write(android_log_level, "native", ss.str().c_str()); + + // Also log to stderr (for standalone Android apps). + // Don't use 'std::cerr' since it crashes on Android. + fprintf(stderr, "native : %s\n", ss.str().c_str()); + + // Android logging at level FATAL does not terminate execution, so abort() + // is still required to stop the program. 
+ if (entry.log_severity() == absl::LogSeverity::kFatal) { + abort(); + } +#else // PLATFORM_POSIX_ANDROID + static const internal::VlogFileMgr vlog_file; + static bool log_thread_id = internal::EmitThreadIdFromEnv(); + uint64_t now_micros = EnvTime::NowMicros(); + time_t now_seconds = static_cast(now_micros / 1000000); + int32_t micros_remainder = static_cast(now_micros % 1000000); + const size_t time_buffer_size = 30; + char time_buffer[time_buffer_size]; + struct tm* tp; +#if defined(__linux__) || defined(__APPLE__) + struct tm now_tm; + tp = localtime_r(&now_seconds, &now_tm); +#else + tp = localtime(&now_seconds); // NOLINT(runtime/threadsafe_fn) +#endif + strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", tp); + uint64_t tid = absl::base_internal::GetTID(); + constexpr size_t kTidBufferSize = + (1 + std::numeric_limits::digits10 + 1); + char tid_buffer[kTidBufferSize] = ""; + if (log_thread_id) { + absl::SNPrintF(tid_buffer, sizeof(tid_buffer), " %7u", tid); + } + + char sev; + switch (entry.log_severity()) { + case absl::LogSeverity::kInfo: + sev = 'I'; + break; + + case absl::LogSeverity::kWarning: + sev = 'W'; + break; + + case absl::LogSeverity::kError: + sev = 'E'; + break; + + case absl::LogSeverity::kFatal: + sev = 'F'; + break; + + default: + assert(false && "Unknown logging severity"); + sev = '?'; + break; + } + + absl::FPrintF(vlog_file.FilePtr(), "%s.%06d: %c%s %s:%d] %s\n", time_buffer, + micros_remainder, sev, tid_buffer, entry.FName().c_str(), + entry.Line(), entry.ToString().c_str()); + fflush(vlog_file.FilePtr()); // Ensure logs are written immediately. +#endif // PLATFORM_POSIX_ANDROID +} + +void UpdateLogVerbosityIfDefined(const char* env_var) {} + +} // namespace torch_xla +} // namespace tsl diff --git a/torch_xla/csrc/runtime/tsl_platform_logging.h b/torch_xla/csrc/runtime/tsl_platform_logging.h new file mode 100644 index 00000000000..befff2988ad --- /dev/null +++ b/torch_xla/csrc/runtime/tsl_platform_logging.h @@ -0,0 +1,684 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + * This file was copied from the OpenXLA repository + * (https://github.com/openxla/xla), before it was deleted. 
+ * + * Commit: 20358a12f26199d016e6e690fe31a4a0a141226e + * Date: 2025-08-26 + */ + +#if defined(_WIN32) +// prevent compile error because MSVC doesn't realize in debug build that +// LOG(FATAL) finally invokes abort() +#pragma warning(disable : 4716) +#endif // _WIN32 + +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_TSL_PLATFORM_LOGGING_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_TSL_PLATFORM_LOGGING_H_ + +// IWYU pragma: private, include "xla/tsl/platform/logging.h" +// IWYU pragma: friend third_party/tensorflow/compiler/xla/tsl/platform/logging.h + +#include +#include +#include +#include +#include +#include + +#include "absl/base/log_severity.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" + +/* + * Do not undef. + * + * // TODO(mrry): Prevent this Windows.h #define from leaking out of our + * headers. #undef ERROR + * + * // Undef everything in case we're being mixed with some other Google library + * // which already defined them itself. Presumably all Google libraries will + * // support the same syntax for these so it should not be a big deal if they + * // end up using our definitions instead. + * #undef LOG + * #undef LOG_EVERY_N + * #undef LOG_FIRST_N + * #undef LOG_EVERY_POW_2 + * #undef LOG_EVERY_N_SEC + * #undef VLOG + * + * #undef CHECK + * #undef CHECK_EQ + * #undef CHECK_NE + * #undef CHECK_LT + * #undef CHECK_LE + * #undef CHECK_GT + * #undef CHECK_GE + * + * #undef DCHECK + * #undef DCHECK_EQ + * #undef DCHECK_NE + * #undef DCHECK_LT + * #undef DCHECK_LE + * #undef DCHECK_GT + * #undef DCHECK_GE + * + * #undef QCHECK + * #undef QCHECK_EQ + * #undef QCHECK_NE + * #undef QCHECK_LT + * #undef QCHECK_LE + * #undef QCHECK_GT + * #undef QCHECK_GE + * + * #undef PCHECK + * + */ + +namespace tsl { + +namespace torch_xla { + +namespace internal { + +class LogMessage : public std::basic_ostringstream { + public: + LogMessage(const char* fname, int line, absl::LogSeverity severity); + ~LogMessage() override; + + // Change the location of the log message. + LogMessage& AtLocation(absl::string_view fname, int line); + + // Returns the maximum log level for VLOG statements. + // E.g., if MaxVLogLevel() is 2, then VLOG(2) statements will produce output, + // but VLOG(3) will not. Defaults to 0. + static int MaxVLogLevel(); + + // Returns whether VLOG level lvl is activated for the file fname. + // + // E.g. if the environment variable TF_CPP_VMODULE contains foo=3 and fname is + // foo.cc and lvl is <= 3, this will return true. It will also return true if + // the level is lower or equal to TF_CPP_MAX_VLOG_LEVEL (default zero). + // + // It is expected that the result of this query will be cached in the VLOG-ing + // call site to avoid repeated lookups. This routine performs a hash-map + // access against the VLOG-ing specification provided by the env var. + static bool VmoduleActivated(const char* fname, int level); + + protected: + void GenerateLogMessage(); + + private: + absl::string_view fname_; + int line_; + absl::LogSeverity severity_; +}; + +// Uses the lower operator & precedence to voidify a LogMessage reference, so +// that the ternary VLOG() implementation is balanced, type wise. +struct Voidifier { + template + void operator&(const T&) const {} +}; + +// LogMessageFatal ensures the process will exit in failure after +// logging this message. 
+class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) ABSL_ATTRIBUTE_COLD; + ABSL_ATTRIBUTE_NORETURN ~LogMessageFatal() override; +}; + +// LogMessageNull supports the DVLOG macro by simply dropping any log messages. +class LogMessageNull : public std::basic_ostringstream { + public: + LogMessageNull() = default; + ~LogMessageNull() override {} +}; + +#define _TF_LOG_INFO \ + ::tsl::torch_xla::internal::LogMessage(__FILE__, __LINE__, \ + absl::LogSeverity::kInfo) +#define _TF_LOG_WARNING \ + ::tsl::torch_xla::internal::LogMessage(__FILE__, __LINE__, \ + absl::LogSeverity::kWarning) +#define _TF_LOG_ERROR \ + ::tsl::torch_xla::internal::LogMessage(__FILE__, __LINE__, \ + absl::LogSeverity::kError) +#define _TF_LOG_FATAL \ + ::tsl::torch_xla::internal::LogMessageFatal(__FILE__, __LINE__) + +#define _TF_LOG_QFATAL _TF_LOG_FATAL + +#ifdef NDEBUG +#define _TF_LOG_DFATAL _TF_LOG_ERROR +#else +#define _TF_LOG_DFATAL _TF_LOG_FATAL +#endif + +/* + * Avoiding duplicated macro. + * + * #define LOG(severity) _TF_LOG_##severity + * + * + * + * #ifdef IS_MOBILE_PLATFORM + * + * // Turn VLOG off when under mobile devices for considerations of binary size. + * #define VLOG_IS_ON(lvl) ((lvl) <= 0) + * + * #else + * + * // Otherwise, set TF_CPP_MAX_VLOG_LEVEL environment to update minimum log + * level + * // of VLOG, or TF_CPP_VMODULE to set the minimum log level for individual + * // translation units. + * #define VLOG_IS_ON(lvl) \ + * (([](int level, const char* fname) { \ + * static const bool vmodule_activated = \ + * ::tsl::torch_xla::internal::LogMessage::VmoduleActivated(fname, + * level); \ + * return vmodule_activated; \ + * })(lvl, __FILE__)) + * + * #endif + * + * #define VLOG(level) \ + * TF_PREDICT_TRUE(!VLOG_IS_ON(level)) \ + * ? (void)0 \ + * : ::tsl::torch_xla::internal::Voidifier() & \ + * ::tsl::torch_xla::internal::LogMessage(__FILE__, __LINE__, \ + * absl::LogSeverity::kInfo) + * + * // `DVLOG` behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`). + * // Otherwise, it compiles away and does nothing. + * #ifndef NDEBUG + * #define DVLOG VLOG + * #else + * #define DVLOG(verbose_level) \ + * while (false && (verbose_level) > 0) + * ::tsl::torch_xla::internal::LogMessageNull() #endif + * + */ + +class LogEveryNState { + public: + bool ShouldLog(int n); + uint32_t counter() { return counter_.load(std::memory_order_relaxed); } + + private: + std::atomic counter_{0}; +}; + +class LogFirstNState { + public: + bool ShouldLog(int n); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + + private: + std::atomic counter_{0}; +}; + +class LogEveryPow2State { + public: + bool ShouldLog(int ignored); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + + private: + std::atomic counter_{0}; +}; + +class LogEveryNSecState { + public: + bool ShouldLog(double seconds); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + + private: + std::atomic counter_{0}; + // Cycle count according to CycleClock that we should next log at. + std::atomic next_log_time_cycles_{0}; +}; + +/* + * Avoiding duplicated macro. + * // This macro has a lot going on! + * // + * // * A local static (`logging_internal_stateful_condition_state`) is + * // declared in a scope such that each `LOG_EVERY_N` (etc.) line has its own + * // state. + * // * `COUNTER`, the third variable, is used to support `<< COUNTER`. 
It is + * not + * // mangled, so shadowing can be a problem, albeit more of a + * // shoot-yourself-in-the-foot one. Don't name your variables `COUNTER`. + * // * A single for loop can declare state and also test + * // `condition && state.ShouldLog()`, but there's no way to constrain it to + * run + * // only once (or not at all) without declaring another variable. The outer + * // for-loop declares this variable (`do_log`). + * // * Using for loops instead of if statements means there's no risk of an + * // ambiguous dangling else statement. + * #define LOGGING_INTERNAL_STATEFUL_CONDITION(kind, condition, arg) \ + * for (bool logging_internal_stateful_condition_do_log(condition); \ + * logging_internal_stateful_condition_do_log; \ + * logging_internal_stateful_condition_do_log = false) \ + * for (static ::tsl::torch_xla::internal::Log##kind##State \ + * logging_internal_stateful_condition_state; \ + * logging_internal_stateful_condition_do_log && \ + * logging_internal_stateful_condition_state.ShouldLog(arg); \ + * logging_internal_stateful_condition_do_log = false) \ + * for (const uint32_t COUNTER ABSL_ATTRIBUTE_UNUSED = \ + * logging_internal_stateful_condition_state.counter(); \ + * logging_internal_stateful_condition_do_log; \ + * logging_internal_stateful_condition_do_log = false) + * + * // An instance of `LOG_EVERY_N` increments a hidden zero-initialized counter + * // every time execution passes through it and logs the specified message when + * // the counter's value is a multiple of `n`, doing nothing otherwise. Each + * // instance has its own counter. The counter's value can be logged by + * streaming + * // the symbol `COUNTER`. `LOG_EVERY_N` is thread-safe. + * // Example: + * // + * // for (const auto& user : all_users) { + * // LOG_EVERY_N(INFO, 1000) << "Processing user #" << COUNTER; + * // ProcessUser(user); + * // } + * #define LOG_EVERY_N(severity, n) \ + * LOGGING_INTERNAL_STATEFUL_CONDITION(EveryN, true, n) \ + * LOG(severity) + * // `LOG_FIRST_N` behaves like `LOG_EVERY_N` except that the specified message + * is + * // logged when the counter's value is less than `n`. `LOG_FIRST_N` is + * // thread-safe. + * #define LOG_FIRST_N(severity, n) \ + * LOGGING_INTERNAL_STATEFUL_CONDITION(FirstN, true, n) \ + * LOG(severity) + * // `LOG_EVERY_POW_2` behaves like `LOG_EVERY_N` except that the specified + * // message is logged when the counter's value is a power of 2. + * // `LOG_EVERY_POW_2` is thread-safe. + * #define LOG_EVERY_POW_2(severity) \ + * LOGGING_INTERNAL_STATEFUL_CONDITION(EveryPow2, true, 0) \ + * LOG(severity) + * // An instance of `LOG_EVERY_N_SEC` uses a hidden state variable to log the + * // specified message at most once every `n_seconds`. A hidden counter of + * // executions (whether a message is logged or not) is also maintained and can + * be + * // logged by streaming the symbol `COUNTER`. `LOG_EVERY_N_SEC` is + * thread-safe. + * // Example: + * // + * // LOG_EVERY_N_SEC(INFO, 2.5) << "Got " << COUNTER << " cookies so far"; + * #define LOG_EVERY_N_SEC(severity, n_seconds) \ + * LOGGING_INTERNAL_STATEFUL_CONDITION(EveryNSec, true, n_seconds) \ + * LOG(severity) + * + * // CHECK dies with a fatal error if condition is not true. It is *not* + * // controlled by NDEBUG, so the check will be executed regardless of + * // compilation mode. 
Therefore, it is safe to do things like: + * // CHECK(fp->Write(x) == 4) + * #define CHECK(condition) \ + * if (TF_PREDICT_FALSE(!(condition))) \ + * LOG(FATAL) << "Check failed: " #condition " " + * + */ + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. +template +inline const T& GetReferenceableValue(const T& t) { + return t; +} +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline int16 GetReferenceableValue(int16_t t) { return t; } +inline uint16 GetReferenceableValue(uint16 t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline int64_t GetReferenceableValue(int64_t t) { return t; } +inline uint64 GetReferenceableValue(uint64 t) { return t; } + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream* os, const T& v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); + +#if LANG_CXX11 +// We need an explicit specialization for std::nullptr_t. +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& v); +#endif + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is non-NULL. +struct CheckOpString { + explicit CheckOpString(string* str) : str_(str) {} + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + explicit operator bool() const { return TF_PREDICT_FALSE(str_ != nullptr); } + string* str_; +}; + +// Build the error message string. Specify no inlining for code size. +template +string* MakeCheckOpString(const T1& v1, const T2& v2, + const char* exprtext) ABSL_ATTRIBUTE_NOINLINE; + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class CheckOpMessageBuilder { + public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char* exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream* ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream* ForVar2(); + // Get the result (inserts the closing ")"). 
+  string* NewString();
+
+ private:
+  std::ostringstream* stream_;
+};
+
+template <typename T1, typename T2>
+string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) {
+  CheckOpMessageBuilder comb(exprtext);
+  MakeCheckOpValueString(comb.ForVar1(), v1);
+  MakeCheckOpValueString(comb.ForVar2(), v2);
+  return comb.NewString();
+}
+
+// Helper functions for the CHECK_OP macro.
+// We use the full name Check_EQ, Check_NE, etc. in case the file including
+// absl/log/log.h provides its own #defines for the simpler names EQ, NE, etc.
+// This happens if, for example, those are used as token names in a
+// yacc grammar.
+// The (int, int) overload works around the issue that the compiler
+// will not instantiate the template version of the function on values of
+// unnamed enum type - see comment below.
+#define TF_DEFINE_CHECK_OP_IMPL(name, op)                                  \
+  template <typename T1, typename T2>                                      \
+  inline string* name##Impl(const T1& v1, const T2& v2,                    \
+                            const char* exprtext) {                        \
+    if (TF_PREDICT_TRUE(v1 op v2))                                         \
+      return NULL;                                                         \
+    else                                                                   \
+      return ::tsl::torch_xla::internal::MakeCheckOpString(v1, v2,         \
+                                                           exprtext);      \
+  }                                                                        \
+  inline string* name##Impl(int v1, int v2, const char* exprtext) {        \
+    return name##Impl<int, int>(v1, v2, exprtext);                         \
+  }
+
+// The (size_t, int) and (int, size_t) specializations are to handle unsigned
+// comparison errors while still being thorough with the comparison.
+
+TF_DEFINE_CHECK_OP_IMPL(Check_EQ, ==)
+// Compilation error with CHECK_EQ(NULL, x)?
+// Use CHECK(x == NULL) instead.
+
+inline string* Check_EQImpl(int v1, size_t v2, const char* exprtext) {
+  if (TF_PREDICT_FALSE(v1 < 0))
+    return ::tsl::torch_xla::internal::MakeCheckOpString(v1, v2, exprtext);
+
+  return Check_EQImpl(size_t(v1), v2, exprtext);
+}
+
+inline string* Check_EQImpl(size_t v1, int v2, const char* exprtext) {
+  return Check_EQImpl(v2, v1, exprtext);
+}
+
+TF_DEFINE_CHECK_OP_IMPL(Check_NE, !=)
+
+inline string* Check_NEImpl(int v1, size_t v2, const char* exprtext) {
+  if (v1 < 0) return NULL;
+
+  return Check_NEImpl(size_t(v1), v2, exprtext);
+}
+
+inline string* Check_NEImpl(size_t v1, int v2, const char* exprtext) {
+  return Check_NEImpl(v2, v1, exprtext);
+}
+
+TF_DEFINE_CHECK_OP_IMPL(Check_LE, <=)
+
+inline string* Check_LEImpl(int v1, size_t v2, const char* exprtext) {
+  if (v1 <= 0) return NULL;
+
+  return Check_LEImpl(size_t(v1), v2, exprtext);
+}
+
+inline string* Check_LEImpl(size_t v1, int v2, const char* exprtext) {
+  if (TF_PREDICT_FALSE(v2 < 0))
+    return ::tsl::torch_xla::internal::MakeCheckOpString(v1, v2, exprtext);
+  return Check_LEImpl(v1, size_t(v2), exprtext);
+}
+
+TF_DEFINE_CHECK_OP_IMPL(Check_LT, <)
+
+inline string* Check_LTImpl(int v1, size_t v2, const char* exprtext) {
+  if (v1 < 0) return NULL;
+
+  return Check_LTImpl(size_t(v1), v2, exprtext);
+}
+
+inline string* Check_LTImpl(size_t v1, int v2, const char* exprtext) {
+  if (v2 < 0)
+    return ::tsl::torch_xla::internal::MakeCheckOpString(v1, v2, exprtext);
+  return Check_LTImpl(v1, size_t(v2), exprtext);
+}
+
+// Implement GE,GT in terms of LE,LT
+template <typename T1, typename T2>
+inline string* Check_GEImpl(const T1& v1, const T2& v2, const char* exprtext) {
+  return Check_LEImpl(v2, v1, exprtext);
+}
+
+template <typename T1, typename T2>
+inline string* Check_GTImpl(const T1& v1, const T2& v2, const char* exprtext) {
+  return Check_LTImpl(v2, v1, exprtext);
+}
+
+#undef TF_DEFINE_CHECK_OP_IMPL
+
+/*
+ * Avoiding duplicated macro.
+ *
+ * // In optimized mode, use CheckOpString to hint to compiler that
+ * // the while condition is unlikely.
+ * #define CHECK_OP_LOG(name, op, val1, val2) \ + * while (::tsl::torch_xla::internal::CheckOpString + * _result{::tsl::torch_xla::internal::name##Impl( \ + * ::tsl::torch_xla::internal::GetReferenceableValue(val1), \ + * ::tsl::torch_xla::internal::GetReferenceableValue(val2), #val1 " " #op + * " " #val2)}) \ + * ::tsl::torch_xla::internal::LogMessageFatal(__FILE__, __LINE__) << + * *(_result.str_) + * + * #define CHECK_OP(name, op, val1, val2) CHECK_OP_LOG(name, op, val1, val2) + * + * // CHECK_EQ/NE/... + * #define CHECK_EQ(val1, val2) CHECK_OP(Check_EQ, ==, val1, val2) + * #define CHECK_NE(val1, val2) CHECK_OP(Check_NE, !=, val1, val2) + * #define CHECK_LE(val1, val2) CHECK_OP(Check_LE, <=, val1, val2) + * #define CHECK_LT(val1, val2) CHECK_OP(Check_LT, <, val1, val2) + * #define CHECK_GE(val1, val2) CHECK_OP(Check_GE, >=, val1, val2) + * #define CHECK_GT(val1, val2) CHECK_OP(Check_GT, >, val1, val2) + * #define CHECK_NOTNULL(val) \ + * ::tsl::torch_xla::internal::CheckNotNull(__FILE__, __LINE__, \ + * "'" #val "' Must be non NULL", (val)) + * + * #ifndef NDEBUG + * // DCHECK_EQ/NE/... + * #define DCHECK(condition) CHECK(condition) + * #define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) + * #define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) + * #define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) + * #define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) + * #define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) + * #define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) + * + * #else + * + * #define DCHECK(condition) \ + * while (false && (condition)) LOG(FATAL) + * + * // NDEBUG is defined, so DCHECK_EQ(x, y) and so on do nothing. + * // However, we still want the compiler to parse x and y, because + * // we don't want to lose potentially useful errors and warnings. + * // _DCHECK_NOP is a helper, and should not be used outside of this file. + * #define _TF_DCHECK_NOP(x, y) \ + * while (false && ((void)(x), (void)(y), 0)) LOG(FATAL) + * + * #define DCHECK_EQ(x, y) _TF_DCHECK_NOP(x, y) + * #define DCHECK_NE(x, y) _TF_DCHECK_NOP(x, y) + * #define DCHECK_LE(x, y) _TF_DCHECK_NOP(x, y) + * #define DCHECK_LT(x, y) _TF_DCHECK_NOP(x, y) + * #define DCHECK_GE(x, y) _TF_DCHECK_NOP(x, y) + * #define DCHECK_GT(x, y) _TF_DCHECK_NOP(x, y) + * + * #endif + * + * // These are for when you don't want a CHECK failure to print a verbose + * // stack trace. The implementation of CHECK* in this file already doesn't. + * #define QCHECK(condition) CHECK(condition) + * #define QCHECK_EQ(x, y) CHECK_EQ(x, y) + * #define QCHECK_NE(x, y) CHECK_NE(x, y) + * #define QCHECK_LE(x, y) CHECK_LE(x, y) + * #define QCHECK_LT(x, y) CHECK_LT(x, y) + * #define QCHECK_GE(x, y) CHECK_GE(x, y) + * #define QCHECK_GT(x, y) CHECK_GT(x, y) + * + */ + +template +T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) { + if (t == nullptr) { + LogMessageFatal(file, line) << string(exprtext); + } + return std::forward(t); +} + +absl::LogSeverityAtLeast MinLogLevelFromEnv(); + +int MaxVLogLevelFromEnv(); + +} // namespace internal + +// LogSink support adapted from absl/log/log.h +// +// `LogSink` is an interface which can be extended to intercept and process +// all log messages. LogSink implementations must be thread-safe. A single +// instance will be called from whichever thread is performing a logging +// operation. 
+class TFLogEntry {
+ public:
+  explicit TFLogEntry(absl::LogSeverity severity, absl::string_view message)
+      : severity_(severity), message_(message) {}
+
+  explicit TFLogEntry(absl::LogSeverity severity, absl::string_view fname,
+                      int line, absl::string_view message)
+      : severity_(severity), fname_(fname), line_(line), message_(message) {}
+
+  absl::LogSeverity log_severity() const { return severity_; }
+  std::string FName() const { return fname_; }
+  int Line() const { return line_; }
+  std::string ToString() const { return message_; }
+  absl::string_view text_message() const { return message_; }
+
+  // Returns the same result as `text_message`, as there is no prefix in this
+  // implementation.
+  absl::string_view text_message_with_prefix() const { return message_; }
+
+ private:
+  const absl::LogSeverity severity_;
+  const std::string fname_;
+  int line_ = -1;
+  const std::string message_;
+};
+
+class TFLogSink {
+ public:
+  virtual ~TFLogSink() = default;
+
+  // `Send` is called synchronously during the log statement. The logging
+  // module guarantees not to call `Send` concurrently on the same log sink.
+  // Implementations should be careful not to call `LOG` or `CHECK` or take
+  // any locks that might be held by the `LOG` caller, to avoid deadlock.
+  //
+  // `e` is guaranteed to remain valid until the subsequent call to
+  // `WaitTillSent` completes, so implementations may store a pointer to or
+  // copy of `e` (e.g. in a thread local variable) for use in `WaitTillSent`.
+  virtual void Send(const TFLogEntry& entry) = 0;
+
+  // `WaitTillSent` blocks the calling thread (the thread that generated a log
+  // message) until the sink has finished processing the log message.
+  // `WaitTillSent` is called once per log message, following the call to
+  // `Send`. This may be useful when log messages are buffered or processed
+  // asynchronously by an expensive log sink.
+  // The default implementation returns immediately. Like `Send`,
+  // implementations should be careful not to call `LOG` or `CHECK` or take
+  // any locks that might be held by the `LOG` caller, to avoid deadlock.
+  virtual void WaitTillSent() {}
+};
+
+// This is the default log sink. This log sink is used if there are no other
+// log sinks registered. To disable the default log sink, set the
+// "no_default_logger" Bazel config setting to true or define a
+// NO_DEFAULT_LOGGER preprocessor symbol. This log sink will always log to
+// stderr.
+class TFDefaultLogSink : public TFLogSink {
+ public:
+  void Send(const TFLogEntry& entry) override;
+};
+
+// Add or remove a `LogSink` as a consumer of logging data. Thread-safe.
+void TFAddLogSink(TFLogSink* sink);
+void TFRemoveLogSink(TFLogSink* sink);
+
+// Get all the log sinks. Thread-safe.
+std::vector<TFLogSink*> TFGetLogSinks();
+
+// Changes the verbose level of pre-defined files if the environment
+// variable `env_var` is defined. This is currently a no-op.
+void UpdateLogVerbosityIfDefined(const char* env_var); + +} // namespace torch_xla +} // namespace tsl + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_TSL_PLATFORM_LOGGING_H_ diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index dc03d7bca85..8835252216b 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -560,9 +560,9 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_): in_specs = [ q_block_spec, # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, q_scales_pages. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), None, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), None, ] @@ -639,7 +639,7 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_): grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( # due to compute_block_indices, we loop batch, kv_head, q_blk, kv_blk, the order matters. dimension_semantics=( "arbitrary", diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index b4bd0c081f0..1664f33a4eb 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -295,7 +295,7 @@ def quantized_matmul_int8( grid=(n_bs, n_out, n_in), ), out_shape=jax.ShapeDtypeStruct((padded_bs, padded_out_features), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary"), vmem_limit_bytes=vmem_limit_bytes, ), diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index 5a1be44515b..d1b3e39e066 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -854,7 +854,7 @@ def cur_page_indices_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, cur_page_indices_spec = pl.BlockSpec( (None, None, num_kv_pages_per_block), cur_page_indices_index_map, - memory_space=pltpu.TPUMemorySpace.SMEM, + memory_space=pltpu.MemorySpace.SMEM, ) page_size = k_pages.shape[2] @@ -874,14 +874,14 @@ def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx, next_page_indices_spec = pl.BlockSpec( (None, None, num_kv_pages_per_block), next_kv_blk_page_indices_index_map, - memory_space=pltpu.TPUMemorySpace.SMEM, + memory_space=pltpu.MemorySpace.SMEM, ) in_specs = [ q_block_spec, # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, q_scales_pages. 
- pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), None, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), None, cur_page_indices_spec, next_page_indices_spec, @@ -955,7 +955,7 @@ def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx, grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( # due to compute_block_indices, we loop kv_head, q_blk, kv_blk, the order matters. dimension_semantics=( "arbitrary", diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 8e124f2c41c..2d32ad0fe2e 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -2339,7 +2339,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "arbitrary", "arbitrary",
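Note on the static_assert rewrite in openxla_patches/if_constexpr_static_assert.diff:
a plain static_assert(false, ...) in the discarded else branch of an if constexpr
chain is rejected outright by GCC releases older than 13, because the condition is
not template-dependent and is therefore checked before any instantiation. The
expression !sizeof(T) is always false (sizeof of a complete type is never zero) but
it depends on T, so the assertion only fires for template arguments that actually
select the else branch. A minimal self-contained sketch of the idiom follows; the
Describe function is a hypothetical example, not code from the patch.

    #include <string>
    #include <type_traits>

    template <typename T>
    std::string Describe(const T& value) {
      if constexpr (std::is_same_v<T, std::string>) {
        return value;
      } else if constexpr (std::is_integral_v<T>) {
        return std::to_string(value);
      } else {
        // static_assert(false, ...) here would fail to compile on GCC < 13
        // even when this branch is discarded; !sizeof(T) is dependent and
        // always false, so it only fires on instantiation.
        static_assert(!sizeof(T), "Unsupported type for Describe");
      }
    }

    int main() {
      Describe(std::string("ok"));  // selects the first branch
      Describe(42);                 // selects the integral branch
      // Describe(3.14);            // would trigger the static_assert
    }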