Add part (non-quantized K/V pages) of paged_attention_kernel tests back for TPU v6.

The paged_attention_kernel tests were disabled on TPU v6 in the past, but I discovered that all of the failing tests have `are_kv_quantized=True`. So we can still run the non-quantized part of the test suite on TPU v6.

PiperOrigin-RevId: 716375593
Google-ML-Automation committed Jan 17, 2025
1 parent a4a657b commit 59ae7df
Showing 1 changed file with 2 additions and 4 deletions: tests/pallas/tpu_paged_attention_kernel_test.py
@@ -108,10 +108,6 @@ def _megacore_enabled():
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class PagedAttentionKernelTest(jtu.JaxTestCase):
-  def setUp(self):
-    super().setUp()
-    if jtu.is_device_tpu_at_least(6):
-      self.skipTest('Not implemented for TPU v6')
 
   @parameterized.product(
       dtype=(jnp.float32, jnp.bfloat16),
@@ -144,6 +140,8 @@ def test_paged_attention(
       # weight and scale tensors for quantized tensors. When enabled on TPUv4,
       # the tests sometimes failed with resource exhausted error.
       self.skipTest("Quantization is not supported on TPU v4")
+    if jtu.is_device_tpu_at_least(6) and are_kv_quantized:
+      self.skipTest("Quantization is not supported on TPU v6")
     if megacore_mode and not _megacore_enabled():
       self.skipTest("Megacore is only available on TPU v4 or TPU v5p")
     if num_kv_heads % 2 != 0 and megacore_mode == "kv_head":
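For readers skimming the diff: the change replaces a blanket class-level skip in setUp with a per-case skip keyed on the are_kv_quantized test parameter, so the non-quantized cases in the parameterized product still run on TPU v6. Below is a minimal, self-contained sketch of that pattern, not the real test file: the class name is invented, and the hypothetical is_tpu_v6_plus helper stands in for jtu.is_device_tpu_at_least(6) so the sketch runs without a TPU attached.

from absl.testing import absltest, parameterized


def is_tpu_v6_plus() -> bool:
  # Hypothetical stand-in for jtu.is_device_tpu_at_least(6), hardcoded so
  # the sketch runs anywhere; the real test queries the attached device.
  return True


class PagedAttentionSkipSketch(parameterized.TestCase):

  @parameterized.product(are_kv_quantized=(False, True))
  def test_paged_attention(self, are_kv_quantized):
    # Skip only the quantized-K/V case on TPU v6; the non-quantized case
    # keeps running, which is what this commit restores.
    if is_tpu_v6_plus() and are_kv_quantized:
      self.skipTest("Quantization is not supported on TPU v6")
    # ... the real test builds paged K/V caches and checks the kernel ...


if __name__ == "__main__":
  absltest.main()

Because each generated case decides for itself whether to skip, the test report still lists the quantized cases as skipped (with the reason string) rather than silently dropping them, which keeps the v6 coverage gap visible.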
