
Add features needed for vllm #9092

Draft · wants to merge 14 commits into base: master
32 changes: 31 additions & 1 deletion torchax/test/test_view.py
@@ -12,7 +12,38 @@ class TrainTest(unittest.TestCase):
  def setUp(self):
    torch.manual_seed(0)
    torchax.enable_globally()

  def test_index_copy_(self):
    x = torch.zeros((10, 10), device="jax")
    x_view = x[0, :]
    indices = torch.arange(5, device="jax")
    new_value = torch.ones((5,), device="jax")
    x_view.index_copy_(0, indices, new_value)
    self.assertEqual(type(x), Tensor)
    self.assertEqual(type(x_view), View)
    self.assertEqual(x.shape, (10, 10))
    self.assertEqual(x.sum(), 5)

  def test_flatten(self):
    x = torch.zeros((10, 10), device="jax")
    x1 = x.flatten(0, 1)
    y = torch.ones(100, device="jax")
    x1.copy_(y)
    self.assertEqual(type(x), Tensor)
    self.assertEqual(type(x1), View)
    self.assertEqual(x.shape, (10, 10))
    self.assertEqual(x.sum(), 100)

  def test_narrow(self):
    x = torch.zeros((10, 10), device="jax")
    x = x.narrow(0, 0, 5).narrow(0, 0, 5)
    y = torch.ones((5, 10), device="jax")
    x.copy_(y)
    self.assertEqual(type(x), View)
    self.assertEqual(x.shape, (5, 10))
    self.assertEqual(x.sum(), 50)

  def test_copy_(self):
    x = torch.zeros((10, 10), device="jax")
    y = torch.ones((5, 5), device="jax")
@@ -379,4 +410,3 @@ def test_scatter_reduce_(self):
    # Check specific values were reduced
    self.assertTrue(torch.all(x[0, 0] == 5.0))
    self.assertEqual(x.sum(), 37.0)
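For context on the behavior these new tests pin down: a View defers the indexing or reshaping, and an in-place write folds the new values back into the base jax array with a functional update. A minimal sketch of that write-back, illustrative only and using plain jax.numpy rather than the real torchax internals:

import jax.numpy as jnp

base = jnp.zeros((10, 10))
slices = (slice(0, 5), slice(None))      # roughly what a NarrowInfo records
new_value = jnp.ones((5, 10))
base = base.at[slices].set(new_value)    # functional equivalent of copy_ into the view
assert base.sum() == 50                  # matches test_narrow's expectation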

3 changes: 3 additions & 0 deletions torchax/torchax/__init__.py
@@ -120,3 +120,6 @@ def compile(fn, options: Optional[CompileOptions] = None):
    raise RuntimeError('dynamo mode is not supported yet')
  elif options.mode == 'export':
    raise RuntimeError('export mode is not supported yet')

# Intercept torch._sync and turn it into a no-op.
torch._sync = lambda *args, **kwargs: None
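The effect, sketched below with a hypothetical caller, is that once torchax is imported any code path reaching torch._sync simply gets None back:

import torch
import torchax  # installs the no-op shown above

torchax.enable_globally()
t = torch.zeros((2, 2), device="jax")
torch._sync(t)  # now a no-op that returns None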
1 change: 1 addition & 0 deletions torchax/torchax/decompositions.py
@@ -776,4 +776,5 @@ def get_summand(
MUTABLE_DECOMPOSITION = [
    torch.ops.aten.bernoulli_.Tensor,
    torch.ops.aten.bernoulli_.float,
    torch.ops.aten.index_copy_.default,
]
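Adding aten.index_copy_ to this list routes the in-place op through a decomposition. As a rough, hedged sketch of the general idea (not the registered decomposition itself), an in-place index_copy_ reduces to its functional counterpart plus a write-back:

# Sketch only: torchax relies on the registered decomposition, not this helper.
def index_copy__sketch(x, dim, index, source):
  x.copy_(torch.index_copy(x, dim, index, source))  # functional op, then copy back
  return x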
3 changes: 2 additions & 1 deletion torchax/torchax/interop.py
@@ -12,6 +12,7 @@
from torchax import tensor
from torchax import util
import torchax
from torchax.view import View

from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable

@@ -172,7 +173,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
  # t is an object from torch land
  # view it as-if it's a jax land object
  if isinstance(t, torch.Tensor):
    assert isinstance(t, tensor.Tensor), type(t)
    assert isinstance(t, (tensor.Tensor, View)), type(t)
    return t.jax()
  if isinstance(t, type(torch.int32)):
    return tensor.t2j_dtype(t)
15 changes: 9 additions & 6 deletions torchax/torchax/ops/jaten.py
@@ -15,7 +15,7 @@
from torchax.ops import op_base, mappings
from torchax import interop
from torchax.ops import jax_reimplement
from torchax.view import View
from torchax.view import View, NarrowInfo, ReshapeInfo
from torchax.tensor import Tensor
# Keys are OpOverload, value is a callable that takes
# Tensor
@@ -57,6 +57,7 @@
    torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
    torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
    torch.ops.aten.scatter_: torch.ops.aten.scatter,
    torch.ops.aten.index_put_: torch.ops.aten.index_put,
}
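Entries in this table map an in-place aten op to its functional counterpart; adding index_put_ means the in-place form is executed functionally and the result written back into the target tensor. A rough sketch of that pattern (hypothetical helper, not torchax's actual dispatcher):

# Hypothetical illustration of the in-place -> functional mapping above.
def run_inplace_via_functional(functional_op, self_tensor, *args, **kwargs):
  result = functional_op(self_tensor, *args, **kwargs)  # pure computation
  self_tensor.copy_(result)                             # write the result back in place
  return self_tensor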

# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -103,9 +104,10 @@ def inner(func):
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten.reshape,
    is_jax_function=False,
)
def _aten_unsafe_view(x, shape):
  return jnp.reshape(x, shape)
  return View(x, ReshapeInfo(shape=shape), env=x._env)
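With view/reshape now returning a ReshapeInfo-backed View instead of an eager jnp.reshape, writes through the reshaped alias propagate to the base tensor, which is what test_flatten above exercises. A user-level sketch, assuming torchax is enabled globally and these ops dispatch through the registration above:

x = torch.zeros((10, 10), device="jax")    # torchax Tensor
flat = x.reshape(100)                      # now a View carrying ReshapeInfo((100,))
flat.copy_(torch.ones(100, device="jax"))  # write-through: x.sum() becomes 100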


@op(torch.ops.aten.add.Tensor)
@@ -128,6 +130,8 @@ def _aten_copy(x, y, memory_format=None):
  if isinstance(x, View):
    x.update(y)
    return x
  if isinstance(y, View):
    y = y.torch()

  if x.ndim == 1 and y.ndim == 0:
    # case of torch.empty((1,)).copy_(tensor(N))
@@ -397,8 +401,8 @@ def _aten_triu(m, k):
  return jnp.triu(m, k)


@op(torch.ops.aten.slice)
@op(torch.ops.aten.slice_copy)
@op(torch.ops.aten.slice, is_jax_function=False, is_view_op=True)
@op(torch.ops.aten.slice_copy, is_jax_function=False, is_view_op=True)
def _aten_slice(self, dim=0, start=None, end=None, step=1):
  if dim < 0:
    dim += self.ndim
@@ -411,7 +415,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
      dims.append(sl)
    else:
      dims.append(slice(None, None, None))
  return self[tuple(dims)]
  return View(self, NarrowInfo(slices=tuple(dims)), env=self._env)
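Because slicing now yields a NarrowInfo-backed View rather than a materialized copy, in-place ops through a basic slice reach the source tensor, mirroring test_narrow above. A user-level sketch, again assuming torchax is enabled globally:

x = torch.zeros((10, 10), device="jax")
top = x.narrow(0, 0, 5)                       # a View over rows 0..4, not a copy
top.copy_(torch.ones((5, 10), device="jax"))  # write-through: x.sum() becomes 50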


@op(torch.ops.aten.detach)
@@ -752,7 +756,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
  return jnp.empty(sizes, dtype=dtype)


@op(torch.ops.aten.index_put_)
@op(torch.ops.aten.index_put)
def _aten_index_put(self, indexes, values, accumulate=False):
  indexes = [slice(None, None, None) if i is None else i for i in indexes]
48 changes: 48 additions & 0 deletions torchax/torchax/ops/jtorch.py
@@ -512,3 +512,51 @@ def functional_linear(self, weights, bias=None):
  if bias is not None:
    res += bias
  return res

try:
  # TODO: The following ops are wrapped in a try/except block because
  # torch.ops.xla is not in the torch ops registry. Either we import
  # torch_xla at the top level, or we modify register_function to
  # support this.
  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
  def _dynamo_set_buffer_donor(self, donor):
    pass

  @register_function(torch.ops.xla.ragged_paged_attention)
  def _ragged_paged_attention(
      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
      kv_lens: jax.Array,  # i32[max_num_seqs]
      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
      num_seqs: jax.Array,  # i32[1]
      use_kernel: bool = True,
      sm_scale: float = 1.0,
      sliding_window: int | None = None,
      soft_cap: float | None = None,
      mask_value: float | None = None,
      num_kv_pages_per_block: int | None = None,
      num_queries_per_block: int | None = None,
      vmem_limit_bytes: int | None = None,
  ):
    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel

    return ragged_paged_attention_kernel(
        q=q,
        kv_pages=kv_pages,
        kv_lens=kv_lens,
        page_indices=page_indices,
        cu_q_lens=cu_q_lens,
        num_seqs=num_seqs,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )
except Exception as e:
  pass


9 changes: 0 additions & 9 deletions torchax/torchax/tensor.py
@@ -100,15 +100,6 @@ def shape(self):
  def ndim(self):
    return len(self._elem.shape)

  def flatten(self, start_dim=0, end_dim=-1):
    if end_dim == -1:
      end_dim = self.ndim
    new_shape = (
        self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]
    )
    new_elem = jnp.reshape(self._elem, new_shape)
    return Tensor(new_elem, self._env)
    # return torch.reshape(self, new_shape)

  def __setitem__(self, key, val):
    key, val = self._env.t2j_iso((key, val))
40 changes: 39 additions & 1 deletion torchax/torchax/view.py
@@ -4,6 +4,7 @@
from enum import Enum
from typing import Union, List, Tuple, Optional, Any, cast
from abc import ABC, abstractmethod
import torch.utils._pytree as pytree

# Reference to original PyTorch native functions
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
@@ -112,6 +113,35 @@ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array
  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    return source[self.slices].shape

class ReshapeInfo(ViewInfo):
  """
  Represents a reshape operation on a tensor.
  Handles operations like tensor.reshape(1, 2, 3) and tensor.reshape(-1, 1).
  """

  def __init__(self, shape: Tuple[int, ...]) -> None:
    """
    Args:
      shape: The shape to reshape the tensor to.
        E.g. jax_array.reshape(shape) will return the transformed tensor.
    """
    super().__init__(ViewInfoType.RESHAPE)
    self.shape = shape

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, ReshapeInfo):
      return False
    return self.shape == other.shape

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    return jax_array.reshape(self.shape)

  def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
    return new_value.reshape(jax_array.shape)

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    return source.reshape(self.shape).shape
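A small self-contained illustration of the contract above: transform_tensor produces the reshaped alias, and update_tensor folds an updated alias back into the parent's original shape.

import jax.numpy as jnp
from torchax.view import ReshapeInfo  # available after this change

info = ReshapeInfo(shape=(100,))
parent = jnp.zeros((10, 10))
alias = info.transform_tensor(parent)                # shape (100,)
updated = info.update_tensor(jnp.ones(100), parent)  # reshaped back to (10, 10)
assert updated.shape == (10, 10) and updated.sum() == 100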


class SelectInfo(ViewInfo):
"""
@@ -321,6 +351,8 @@ def update(
    for view_info, parent_array in zip(
        reversed(view_infos), reversed(intermediate_values)
    ):
      assert isinstance(new_values, jax.Array)
      assert isinstance(parent_array, jax.Array)
      # Apply the inverse transformation to propagate changes back
      new_values = view_info.update_tensor(new_values, parent_array)

@@ -366,6 +398,8 @@ def jax(self) -> jax.Array:
    return result

  def __setitem__(self, indexes, val):
    # Convert any torch.Tensor indices to jax arrays before wrapping them in a NarrowInfo
    indexes = pytree.tree_map(
        lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes)
    view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
    self.update(view_infos=view_infos, new_values=val)
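This makes indexed assignment on a View work when the index itself is a torch tensor on the jax device. A rough usage sketch (illustrative; the exact supported index forms depend on NarrowInfo):

x = torch.zeros((10, 10), device="jax")
v = x[0:5, :]                               # a View over the first five rows
idx = torch.arange(3, device="jax")         # tensor index, mapped to a jax array above
v[idx] = torch.ones((3, 10), device="jax")  # writes through to x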

@@ -383,5 +417,9 @@ def jax_device(self):
  @property
  def ndim(self):
    return len(self.shape)

  @property
  def data(self):
    return self

  __repr__ = __str__