diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py
index 3f5caee5f1f..576435ea9b0 100644
--- a/torchax/test/test_view.py
+++ b/torchax/test/test_view.py
@@ -14,6 +14,36 @@ def setUp(self):
     torch.manual_seed(0)
     torchax.enable_globally()
 
+  def test_index_copy_(self):
+    x = torch.zeros((10, 10), device="jax")
+    x_view = x[0, :]
+    indices = torch.arange(5, device="jax")
+    new_value = torch.ones((5,), device="jax")
+    x_view.index_copy_(0, indices, new_value)
+    self.assertEqual(type(x), Tensor)
+    self.assertEqual(type(x_view), View)
+    self.assertEqual(x.shape, (10, 10))
+    self.assertEqual(x.sum(), 5)
+
+  def test_flatten(self):
+    x = torch.zeros((10, 10), device="jax")
+    x1 = x.flatten(0, 1)
+    y = torch.ones(100, device="jax")
+    x1.copy_(y)
+    self.assertEqual(type(x), Tensor)
+    self.assertEqual(type(x1), View)
+    self.assertEqual(x.shape, (10, 10))
+    self.assertEqual(x.sum(), 100)
+
+  def test_narrow(self):
+    x = torch.zeros((10, 10), device="jax")
+    x = x.narrow(0, 0, 5).narrow(0, 0, 5)
+    y = torch.ones((5, 10), device="jax")
+    x.copy_(y)
+    self.assertEqual(type(x), View)
+    self.assertEqual(x.shape, (5, 10))
+    self.assertEqual(x.sum(), 50)
+
   def test_copy_(self):
     x = torch.zeros((10, 10), device="jax")
     y = torch.ones((5, 5), device="jax")
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index d18c983e252..bc35e8d129a 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -127,3 +127,6 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
   elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
+
+# Intercept torch._sync as no-op
+torch._sync = lambda *args, **kwargs: None
diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py
index f116d42f3d6..47bbb535e25 100644
--- a/torchax/torchax/decompositions.py
+++ b/torchax/torchax/decompositions.py
@@ -766,4 +766,5 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
 MUTABLE_DECOMPOSITION = [
     torch.ops.aten.bernoulli_.Tensor,
     torch.ops.aten.bernoulli_.float,
+    torch.ops.aten.index_copy_.default,
 ]
diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py
index 460441b308c..7b675a239f3 100644
--- a/torchax/torchax/interop.py
+++ b/torchax/torchax/interop.py
@@ -12,6 +12,7 @@
 from torchax import tensor
 from torchax import util
 import torchax
+from torchax.view import View
 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
 
 
@@ -179,7 +180,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
   # t is an object from torch land
   # view it as-if it's a jax land object
   if isinstance(t, torch.Tensor):
-    assert isinstance(t, tensor.Tensor), type(t)
+    assert isinstance(t, tensor.Tensor) or isinstance(t, View), type(t)
     return t.jax()
   if isinstance(t, type(torch.int32)):
     return tensor.t2j_dtype(t)
diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py
index fc8dcc71e46..cb7aea7b419 100644
--- a/torchax/torchax/ops/jaten.py
+++ b/torchax/torchax/ops/jaten.py
@@ -15,7 +15,7 @@
 from torchax.ops import op_base, mappings
 from torchax import interop
 from torchax.ops import jax_reimplement
-from torchax.view import View
+from torchax.view import View, NarrowInfo, ReshapeInfo
 from torchax.tensor import Tensor
 # Keys are OpOverload, value is a callable that takes
 # Tensor
@@ -57,6 +57,7 @@
     torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
     torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
     torch.ops.aten.scatter_: torch.ops.aten.scatter,
+    torch.ops.aten.index_put_: torch.ops.aten.index_put,
 }
 
 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -102,13 +103,14 @@ def inner(func):
 
 
 @op(
-    torch.ops.aten.view_copy,
-    torch.ops.aten.view,
-    torch.ops.aten._unsafe_view,
-    torch.ops.aten.reshape,
+    torch.ops.aten.view_copy,
+    torch.ops.aten.view,
+    torch.ops.aten._unsafe_view,
+    torch.ops.aten.reshape,
+    is_jax_function=False,
 )
 def _aten_unsafe_view(x, shape):
-  return jnp.reshape(x, shape)
+  return View(x, ReshapeInfo(shape=shape), env=x._env)
 
 
 @op(torch.ops.aten.add.Tensor)
@@ -131,6 +133,8 @@ def _aten_copy(x, y, memory_format=None):
   if isinstance(x, View):
     x.update(y)
     return x
+  if isinstance(y, View):
+    y = y.torch()
 
   if x.ndim == 1 and y.ndim == 0:
     # case of torch.empty((1,)).copy_(tensor(N))
@@ -402,8 +406,8 @@ def _aten_triu(m, k):
   return jnp.triu(m, k)
 
 
-@op(torch.ops.aten.slice)
-@op(torch.ops.aten.slice_copy)
+@op(torch.ops.aten.slice, is_jax_function=False, is_view_op=True)
+@op(torch.ops.aten.slice_copy, is_jax_function=False, is_view_op=True)
 def _aten_slice(self, dim=0, start=None, end=None, step=1):
   if dim < 0:
     dim += self.ndim
@@ -416,7 +420,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
       dims.append(sl)
     else:
       dims.append(slice(None, None, None))
-  return self[tuple(dims)]
+  return View(self, NarrowInfo(slices=tuple(dims)), env=self._env)
 
 
 @op(torch.ops.aten.detach)
@@ -779,7 +783,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   return jnp.empty(sizes, dtype=dtype)
 
 
-@op(torch.ops.aten.index_put_)
 @op(torch.ops.aten.index_put)
 def _aten_index_put(self, indexes, values, accumulate=False):
   indexes = [slice(None, None, None) if i is None else i for i in indexes]
diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py
index f03e5cbf7a0..1443d064ab8 100644
--- a/torchax/torchax/ops/jtorch.py
+++ b/torchax/torchax/ops/jtorch.py
@@ -269,7 +269,7 @@ def getitem(self, indexes):
   elif isinstance(indexes, list):
     indexes = tuple(indexes)
 
-  def is_narrow_slicing():
+  def is_view_slicing():
     tensor_free = not pytree.tree_any(
         lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
         indexes)
@@ -277,7 +277,7 @@ def is_narrow_slicing():
         [False if isinstance(x, list) else True for x in indexes])
     return tensor_free and list_free
 
-  if is_narrow_slicing():
+  if is_view_slicing():
     return View(self, view_info=NarrowInfo(indexes), env=self._env)
 
   indexes = self._env.t2j_iso(indexes)
@@ -512,3 +512,52 @@ def functional_linear(self, weights, bias=None):
   if bias is not None:
     res += bias
   return res
+
+
+try:
+  # TODO: The following ops are wrapped in a try/except block because
+  # torch.ops.xla is not in the torch ops registry. Either we import
+  # torch_xla at the top level, or modify register_function to support this.
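+
+  # dynamo_set_buffer_donor_ is registered below as a no-op handler so that
+  # calls to it simply succeed under torchax.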
+  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
+  def _dynamo_set_buffer_donor(self, donor):
+    pass
+
+  @register_function(torch.ops.xla.ragged_paged_attention)
+  def _ragged_paged_attention(
+      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
+      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+      kv_lens: jax.Array,  # i32[max_num_seqs]
+      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
+      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+      num_seqs: jax.Array,  # i32[1]
+      use_kernel: bool = True,
+      sm_scale: float = 1.0,
+      sliding_window: int | None = None,
+      soft_cap: float | None = None,
+      mask_value: float | None = None,
+      num_kv_pages_per_block: int | None = None,
+      num_queries_per_block: int | None = None,
+      vmem_limit_bytes: int | None = None,
+  ):
+
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
+    return ragged_paged_attention_kernel(
+        q=q,
+        kv_pages=kv_pages,
+        kv_lens=kv_lens,
+        page_indices=page_indices,
+        cu_q_lens=cu_q_lens,
+        num_seqs=num_seqs,
+        sm_scale=sm_scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+        mask_value=mask_value,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        num_queries_per_block=num_queries_per_block,
+        vmem_limit_bytes=vmem_limit_bytes,
+    )
+except Exception:
+  pass
+
+
diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py
index 66e2b55994b..ca386406d10 100644
--- a/torchax/torchax/tensor.py
+++ b/torchax/torchax/tensor.py
@@ -101,14 +101,6 @@ def shape(self):
   def ndim(self):
     return len(self._elem.shape)
 
-  def flatten(self, start_dim=0, end_dim=-1):
-    if end_dim == -1:
-      end_dim = self.ndim
-    new_shape = (
-        self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
-    new_elem = jnp.reshape(self._elem, new_shape)
-    return Tensor(new_elem, self._env)
-    # return torch.reshape(self, new_shape)
 
   def __setitem__(self, key, val):
     key, val = self._env.t2j_iso((key, val))
@@ -381,7 +373,7 @@ def load_ops(self):
     )
 
   def _to_copy(self, the_tensor, new_dtype, new_device):
-    if isinstance(the_tensor, Tensor):
+    if isinstance(the_tensor, Tensor) or isinstance(the_tensor, View):
       arr = the_tensor.jax()
       if new_dtype is not None and new_dtype != arr.dtype:
         arr = arr.astype(mappings.t2j_dtype(new_dtype))
diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py
index e9272871a4e..104a025de50 100644
--- a/torchax/torchax/view.py
+++ b/torchax/torchax/view.py
@@ -4,6 +4,7 @@
 from enum import Enum
 from typing import Union, List, Tuple, Optional, Any, cast
 from abc import ABC, abstractmethod
+import torch.utils._pytree as pytree
 
 # Reference to original PyTorch native functions
 # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
@@ -114,6 +115,35 @@ def update_tensor(self, new_value: jax.Array,
   def calculate_output_shape(self, source: jax.Array) -> List[int]:
     return source[self.slices].shape
 
+class ReshapeInfo(ViewInfo):
+  """
+  Represents a reshape operation on a tensor.
+  Handles operations like tensor.reshape(1, 2, 3) and tensor.reshape(-1, 1).
+  """
+
+  def __init__(self, shape: Tuple[int, ...]) -> None:
+    """
+    Args:
+      shape: The shape to reshape the tensor to.
+        E.g. jax_array.reshape(shape) will return the transformed tensor.
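+        E.g. ReshapeInfo(shape=(-1,)) flattens the source array.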
+ """ + super().__init__(ViewInfoType.RESHAPE) + self.shape = shape + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ReshapeInfo): + return False + return self.shape == other.shape + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + return jax_array.reshape(self.shape) + + def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: + return new_value.reshape(jax_array.shape) + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + return source.reshape(self.shape).shape + class SelectInfo(ViewInfo): """ @@ -318,6 +348,8 @@ def update( # applying inverse transformations in reverse order for view_info, parent_array in zip( reversed(view_infos), reversed(intermediate_values)): + assert isinstance(new_values, jax.Array) + assert isinstance(parent_array, jax.Array) # Apply the inverse transformation to propagate changes back new_values = view_info.update_tensor(new_values, parent_array) @@ -353,6 +385,10 @@ def create_sub_view(self, view_info: ViewInfo) -> "View": def __str__(self) -> str: return f"View({self.torch()})" + @property + def _elem(self) -> jax.Array: + return self.jax() + def jax(self) -> jax.Array: """ Returns a copy of the source tensor after transformations. @@ -363,6 +399,8 @@ def jax(self) -> jax.Array: return result def __setitem__(self, indexes, val): + # Handle tensor indexing + indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes) view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] self.update(view_infos=view_infos, new_values=val) @@ -381,4 +419,20 @@ def jax_device(self): def ndim(self): return len(self.shape) + @property + def data(self): + return self + __repr__ = __str__ + + +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_linear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_trilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_upsample_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_float32 - AttributeError: 'View' object has no attribute '_elem' +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_int64 - AttributeError: 'View' object has no attribute '_elem' \ No newline at end of file