diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py
index 79fdd5afb8d4..9b776aec5fc0 100644
--- a/torchax/test/test_view.py
+++ b/torchax/test/test_view.py
@@ -12,7 +12,38 @@ class TrainTest(unittest.TestCase):
   def setUp(self):
     torch.manual_seed(0)
     torchax.enable_globally()
+
+  def test_index_copy_(self):
+    x = torch.zeros((10, 10), device="jax")
+    x_view = x[0, :]
+    indices = torch.arange(5, device="jax")
+    new_value = torch.ones((5,), device="jax")
+    x_view.index_copy_(0, indices, new_value)
+    self.assertEqual(type(x), Tensor)
+    self.assertEqual(type(x_view), View)
+    self.assertEqual(x.shape, (10, 10))
+    self.assertEqual(x.sum(), 5)
+
+  def test_flatten(self):
+    x = torch.zeros((10, 10), device="jax")
+    x1 = x.flatten(0, 1)
+    y = torch.ones(100, device="jax")
+    x1.copy_(y)
+    self.assertEqual(type(x), Tensor)
+    self.assertEqual(type(x1), View)
+    self.assertEqual(x.shape, (10, 10))
+    self.assertEqual(x.sum(), 100)
+
+  def test_narrow(self):
+    x = torch.zeros((10, 10), device="jax")
+    x = x.narrow(0, 0, 5).narrow(0, 0, 5)
+    y = torch.ones((5, 10), device="jax")
+    x.copy_(y)
+    self.assertEqual(type(x), View)
+    self.assertEqual(x.shape, (5, 10))
+    self.assertEqual(x.sum(), 50)
+
   def test_copy_(self):
     x = torch.zeros((10, 10), device="jax")
     y = torch.ones((5, 5), device="jax")
@@ -379,4 +410,3 @@ def test_scatter_reduce_(self):
     # Check specific values were reduced
     self.assertTrue(torch.all(x[0, 0] == 5.0))
     self.assertEqual(x.sum(), 37.0)
-
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index c5806db34a42..df2e94c28329 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -120,3 +120,6 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
   elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
+
+# Intercept torch._sync as a no-op
+torch._sync = lambda *args, **kwargs: None
diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py
index 257eca59ce36..03213fe4307b 100644
--- a/torchax/torchax/decompositions.py
+++ b/torchax/torchax/decompositions.py
@@ -776,4 +776,5 @@ def get_summand(
 MUTABLE_DECOMPOSITION = [
     torch.ops.aten.bernoulli_.Tensor,
     torch.ops.aten.bernoulli_.float,
+    torch.ops.aten.index_copy_.default,
 ]
diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py
index 8fa2e63b83dc..e1ab4088035f 100644
--- a/torchax/torchax/interop.py
+++ b/torchax/torchax/interop.py
@@ -12,6 +12,7 @@ from torchax import tensor
 from torchax import util
 import torchax
+from torchax.view import View
 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
 
@@ -172,7 +173,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
   # t is an object from torch land
   # view it as-if it's a jax land object
   if isinstance(t, torch.Tensor):
-    assert isinstance(t, tensor.Tensor), type(t)
+    assert isinstance(t, (tensor.Tensor, View)), type(t)
     return t.jax()
   if isinstance(t, type(torch.int32)):
     return tensor.t2j_dtype(t)
diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py
index c47068165949..7c49e6e30f01 100644
--- a/torchax/torchax/ops/jaten.py
+++ b/torchax/torchax/ops/jaten.py
@@ -15,7 +15,7 @@ from torchax.ops import op_base, mappings
 from torchax import interop
 from torchax.ops import jax_reimplement
-from torchax.view import View
+from torchax.view import View, NarrowInfo, ReshapeInfo
 from torchax.tensor import Tensor
 
 # Keys are OpOverload, value is a callable that takes
 # Tensor
@@ -57,6 +57,7 @@ torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
     torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
     torch.ops.aten.scatter_: torch.ops.aten.scatter,
+    torch.ops.aten.index_put_: torch.ops.aten.index_put,
 }
 
 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -103,9 +104,10 @@ def inner(func):
     torch.ops.aten.view,
     torch.ops.aten._unsafe_view,
     torch.ops.aten.reshape,
+    is_jax_function=False,
 )
 def _aten_unsafe_view(x, shape):
-  return jnp.reshape(x, shape)
+  return View(x, ReshapeInfo(shape=shape), env=x._env)
 
 
 @op(torch.ops.aten.add.Tensor)
@@ -128,6 +130,8 @@ def _aten_copy(x, y, memory_format=None):
   if isinstance(x, View):
     x.update(y)
     return x
+  if isinstance(y, View):
+    y = y.torch()
 
   if x.ndim == 1 and y.ndim == 0:
     # case of torch.empty((1,)).copy_(tensor(N))
@@ -397,8 +401,8 @@ def _aten_triu(m, k):
   return jnp.triu(m, k)
 
 
-@op(torch.ops.aten.slice)
-@op(torch.ops.aten.slice_copy)
+@op(torch.ops.aten.slice, is_jax_function=False, is_view_op=True)
+@op(torch.ops.aten.slice_copy, is_jax_function=False, is_view_op=True)
 def _aten_slice(self, dim=0, start=None, end=None, step=1):
   if dim < 0:
     dim += self.ndim
@@ -411,7 +415,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
       dims.append(sl)
     else:
       dims.append(slice(None, None, None))
-  return self[tuple(dims)]
+  return View(self, NarrowInfo(slices=tuple(dims)), env=self._env)
 
 
 @op(torch.ops.aten.detach)
@@ -752,7 +756,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   return jnp.empty(sizes, dtype=dtype)
 
 
-@op(torch.ops.aten.index_put_)
 @op(torch.ops.aten.index_put)
 def _aten_index_put(self, indexes, values, accumulate=False):
   indexes = [slice(None, None, None) if i is None else i for i in indexes]
diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py
index 0578194c7bf1..74120c792728 100644
--- a/torchax/torchax/ops/jtorch.py
+++ b/torchax/torchax/ops/jtorch.py
@@ -512,3 +512,51 @@ def functional_linear(self, weights, bias=None):
   if bias is not None:
     res += bias
   return res
+
+try:
+  # TODO: The following ops are wrapped in a try/except block because
+  # torch.ops.xla is not in the torch ops registry. Either we import
+  # torch_xla at the top level, or modify register_function to support this.
+  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
+  def _dynamo_set_buffer_donor(self, donor):
+    pass
+
+  @register_function(torch.ops.xla.ragged_paged_attention)
+  def _ragged_paged_attention(
+      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
+      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+      kv_lens: jax.Array,  # i32[max_num_seqs]
+      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
+      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+      num_seqs: jax.Array,  # i32[1]
+      use_kernel: bool = True,
+      sm_scale: float = 1.0,
+      sliding_window: int | None = None,
+      soft_cap: float | None = None,
+      mask_value: float | None = None,
+      num_kv_pages_per_block: int | None = None,
+      num_queries_per_block: int | None = None,
+      vmem_limit_bytes: int | None = None,
+  ):
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
+    return ragged_paged_attention_kernel(
+        q=q,
+        kv_pages=kv_pages,
+        kv_lens=kv_lens,
+        page_indices=page_indices,
+        cu_q_lens=cu_q_lens,
+        num_seqs=num_seqs,
+        sm_scale=sm_scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+        mask_value=mask_value,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        num_queries_per_block=num_queries_per_block,
+        vmem_limit_bytes=vmem_limit_bytes,
+    )
+except Exception:
+  pass
diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py
index 90afa30dbd9c..8cf77f3d0071 100644
--- a/torchax/torchax/tensor.py
+++ b/torchax/torchax/tensor.py
@@ -100,15 +100,6 @@ def shape(self):
   def ndim(self):
     return len(self._elem.shape)
 
-  def flatten(self, start_dim=0, end_dim=-1):
-    if end_dim == -1:
-      end_dim = self.ndim
-    new_shape = (
-        self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]
-    )
-    new_elem = jnp.reshape(self._elem, new_shape)
-    return Tensor(new_elem, self._env)
-    # return torch.reshape(self, new_shape)
 
   def __setitem__(self, key, val):
     key, val = self._env.t2j_iso((key, val))
diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py
index 8c8bdf538104..30a4f0338899 100644
--- a/torchax/torchax/view.py
+++ b/torchax/torchax/view.py
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Union, List, Tuple, Optional, Any, cast
 from abc import ABC, abstractmethod
+import torch.utils._pytree as pytree
 
 # Reference to original PyTorch native functions
 # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
@@ -112,6 +113,35 @@ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
 
   def calculate_output_shape(self, source: jax.Array) -> List[int]:
     return source[self.slices].shape
 
+
+class ReshapeInfo(ViewInfo):
+  """
+  Represents a reshape operation on a tensor.
+  Handles operations like tensor.reshape(1, 2, 3) and tensor.reshape(-1, 1).
+  """
+
+  def __init__(self, shape: Tuple[int, ...]) -> None:
+    """
+    Args:
+      shape: The shape to reshape the tensor to.
+        E.g. jax_array.reshape(shape) will return the transformed tensor.
+ """ + super().__init__(ViewInfoType.RESHAPE) + self.shape = shape + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ReshapeInfo): + return False + return self.shape == other.shape + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + return jax_array.reshape(self.shape) + + def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: + return new_value.reshape(jax_array.shape) + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + return source.reshape(self.shape).shape + class SelectInfo(ViewInfo): """ @@ -321,6 +351,8 @@ def update( for view_info, parent_array in zip( reversed(view_infos), reversed(intermediate_values) ): + assert isinstance(new_values, jax.Array) + assert isinstance(parent_array, jax.Array) # Apply the inverse transformation to propagate changes back new_values = view_info.update_tensor(new_values, parent_array) @@ -366,6 +398,8 @@ def jax(self) -> jax.Array: return result def __setitem__(self, indexes, val): + # Handle tensor indexing + indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes) view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] self.update(view_infos=view_infos, new_values=val) @@ -383,5 +417,9 @@ def jax_device(self): @property def ndim(self): return len(self.shape) - + + @property + def data(self): + return self + __repr__ = __str__