Update PyTorch and XLA pin. #9668
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR updates the following pins:
libtpu
: 0.0.21 to 0.0.24jaxlib
): 0.7.1 to 0.8.0Key Changes:
@python
was replaced by@rules_python
atBUILD
file (ref: jax-ml/jax#31709)TF_ATTRIBUTE_NORETURN
was removed in favor of abseil (ref: openxla/xla#31699)xla/pjrt/tfrt_cpu_pjrt_client.h
file byxla/pjrt/cpu/cpu_client.h
inpjrt_registry.cpp
(openxla/xla#30936)xla/tsl/platform/default/logging.*
totorch_xla/csrc/runtime/tsl_platform_logging.*
Update (Oct 3):
static_assert(false)
for GCC < 13 (ref)flax
pin, since it does not overwritejax
anymoreTPU*
prefix ofjax.experimental.pallas.tpu
components (ref: jax-ml/jax#29115)