Skip to content

Commit 844b5b7

Browse files
Propagate mutations of Tensor slices to the source Tensor. (#9025)
1 parent 8d42e9a commit 844b5b7

File tree

10 files changed

+1470
-541
lines changed

10 files changed

+1470
-541
lines changed

.github/workflows/torchax.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,6 @@ jobs:
5656
pytest test/test_libraries.py
5757
pytest test/test_symbolic_shapes.py
5858
pytest test/test_exports.py
59+
pytest test/test_view.py
5960
pytest test/test_util.py
6061
XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/

torchax/docs/ops_registry.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Operator:
3030
is_jax_function: bool
3131
is_user_defined: bool
3232
needs_env: bool
33+
is_view_op: bool
3334
```
3435

3536
The `torch_op` is the corresponding torch callable, and `func` the implementation. `is_jax_function` is True if `func` is implemented using Jax, False if `func` is implemented using other torch ops. We can use this information to decide how to call it.

torchax/test/test_view.py

Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
import torch
2+
import torchax
3+
import re
4+
import sys
5+
import unittest
6+
7+
from torchax.tensor import Tensor
8+
from torchax.view import View
9+
10+
class TrainTest(unittest.TestCase):
11+
12+
def setUp(self):
13+
torch.manual_seed(0)
14+
torchax.enable_globally()
15+
16+
def test_copy_(self):
17+
x = torch.zeros((10, 10), device="jax")
18+
y = torch.ones((5, 5), device="jax")
19+
x[0:5, :][:, 0:5].copy_(y[:, :])
20+
self.assertEqual(type(x), Tensor)
21+
self.assertEqual(x.shape, (10, 10))
22+
self.assertEqual(x[0:5, 0:5].sum(), 25)
23+
self.assertEqual(x.sum(), 25)
24+
25+
def test_transivity(self):
26+
x = torch.zeros((10, 10), device="jax")
27+
x_view = x[0:5, :][:, 0:5].add_(1)
28+
y_view = x_view[0:5, :][:, 0:5].add_(1)
29+
self.assertEqual(type(x), Tensor)
30+
self.assertEqual(type(x_view), View)
31+
self.assertEqual(type(y_view), View)
32+
self.assertEqual(x.shape, (10, 10))
33+
self.assertEqual(x[0:5, 0:5].sum(), 50)
34+
self.assertEqual(x.sum(), 50)
35+
36+
def test_outofplace_add(self):
37+
x = torch.zeros((10, 10), device="jax")
38+
x2 = x[0:5, :][:, 0:5].add(1)
39+
x3 = x2[0:5, :][:, 0:5].add_(x2)
40+
self.assertEqual(type(x), Tensor)
41+
self.assertEqual(type(x2), Tensor)
42+
self.assertEqual(type(x3), View)
43+
self.assertEqual(x.shape, (10, 10))
44+
self.assertEqual(x.sum(), 0)
45+
self.assertEqual(x2.sum(), 50)
46+
47+
def test_multiply_tensor_and_view(self):
48+
x = torch.ones((10, 10), device="jax")*2
49+
y = torch.ones((10, 10), device="jax")
50+
x1 = x[:, :]
51+
res = x1.mul(y)
52+
self.assertEqual(type(x), Tensor)
53+
self.assertEqual(type(y), Tensor)
54+
self.assertEqual(type(x1), View)
55+
self.assertEqual(type(res), Tensor)
56+
self.assertEqual(res.sum(), 200)
57+
58+
def test_multiply_views(self):
59+
x = torch.ones((10, 10), device="jax")*2
60+
y = torch.ones((10, 10), device="jax")
61+
x1 = x[0:1, :]
62+
y1 = y[0:1, :]
63+
res = x1.mul(y1)
64+
self.assertEqual(type(x), Tensor)
65+
self.assertEqual(type(y), Tensor)
66+
self.assertEqual(type(x1), View)
67+
self.assertEqual(type(y1), View)
68+
self.assertEqual(type(res), Tensor)
69+
self.assertEqual(res.sum(), 20)
70+
71+
def test_setitem(self):
72+
a = torch.zeros(10, device = "jax")
73+
a[0:5][0:3] = 1
74+
self.assertEqual(type(a), Tensor)
75+
self.assertEqual(a.shape, (10,))
76+
self.assertEqual(a.sum(), 3)
77+
78+
# Test all in-place operations
79+
def test_add_(self):
80+
x = torch.zeros((10, 10), device="jax")
81+
x[0:5, :][:, 0:5].add_(1)
82+
self.assertEqual(type(x), Tensor)
83+
self.assertEqual(x.shape, (10, 10))
84+
self.assertEqual(x.sum(), 25)
85+
86+
def test_sub_(self):
87+
x = torch.zeros((10, 10), device="jax")
88+
x[0:5, :][:, 0:5].sub_(1)
89+
self.assertEqual(type(x), Tensor)
90+
self.assertEqual(x.shape, (10, 10))
91+
self.assertEqual(x.sum(), -25)
92+
93+
def test_mul_(self):
94+
x = torch.ones((10, 10), device="jax")
95+
x[0:5, :][:, 0:5].mul_(2)
96+
self.assertEqual(type(x), Tensor)
97+
self.assertEqual(x.shape, (10, 10))
98+
self.assertEqual(x.sum(), 125)
99+
100+
def test_div_(self):
101+
x = torch.ones((10, 10), device="jax")
102+
x[0:10, :][:, 0:10].div_(2)
103+
self.assertEqual(type(x), Tensor)
104+
self.assertEqual(x.shape, (10, 10))
105+
self.assertEqual(x.sum(), 50)
106+
107+
def test_pow_(self):
108+
x = torch.full((10, 10), fill_value=2, device="jax")
109+
x[0:5, :][:, 0:5].pow_(2)
110+
self.assertEqual(type(x), Tensor)
111+
self.assertEqual(x.shape, (10, 10))
112+
self.assertEqual(x.sum(), 250)
113+
114+
def test_clamp_(self):
115+
x = torch.arange(100, device="jax", dtype=torch.float).reshape(10, 10)
116+
x[0:5, :][:, 0:5].clamp_(min=50, max=80)
117+
self.assertEqual(type(x), Tensor)
118+
self.assertEqual(x.shape, (10, 10))
119+
self.assertTrue((x[0:5, 0:5] >= 50).all())
120+
self.assertTrue((x[0:5, 0:5] <= 80).all())
121+
122+
def test_lt_(self):
123+
x = torch.ones((10, 10), device="jax")
124+
y = torch.zeros((10, 10), device="jax")
125+
x[0:5, :][:, 0:5].lt_(0.5)
126+
self.assertEqual(type(x), Tensor)
127+
self.assertEqual(x.shape, (10, 10))
128+
self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region
129+
self.assertEqual(x[5:, 5:].sum(), 25) # All True (1) in the unmodified region
130+
131+
def test_le_(self):
132+
x = torch.ones((10, 10), device="jax")
133+
x[0:5, :][:, 0:5].le_(1)
134+
self.assertEqual(type(x), Tensor)
135+
self.assertEqual(x.shape, (10, 10))
136+
self.assertEqual(x.sum(), 100) # All True (1)
137+
138+
def test_gt_(self):
139+
x = torch.ones((10, 10), device="jax")
140+
x[0:5, :][:, 0:5].gt_(1)
141+
self.assertEqual(type(x), Tensor)
142+
self.assertEqual(x.shape, (10, 10))
143+
self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region
144+
self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1)
145+
146+
def test_ge_(self):
147+
x = torch.ones((10, 10), device="jax")
148+
x[0:5, :][:, 0:5].ge_(1)
149+
self.assertEqual(type(x), Tensor)
150+
self.assertEqual(x.shape, (10, 10))
151+
self.assertEqual(x.sum(), 100) # All True (1)
152+
153+
def test_eq_(self):
154+
x = torch.ones((10, 10), device="jax")
155+
x[0:5, :][:, 0:5].eq_(1)
156+
self.assertEqual(type(x), Tensor)
157+
self.assertEqual(x.shape, (10, 10))
158+
self.assertEqual(x.sum(), 100) # All True (1)
159+
160+
def test_ne_(self):
161+
x = torch.ones((10, 10), device="jax")
162+
x[0:5, :][:, 0:5].ne_(1)
163+
self.assertEqual(type(x), Tensor)
164+
self.assertEqual(x.shape, (10, 10))
165+
self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region
166+
self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1)
167+
168+
def test_bernoulli_(self):
169+
# Set a fixed seed for deterministic behavior
170+
torch.manual_seed(42)
171+
x = torch.full((10, 10), fill_value=0.5, device="jax")
172+
y = x[0:5, :][:, 0:5]
173+
y.bernoulli_()
174+
self.assertEqual(type(x), Tensor)
175+
self.assertEqual(x.shape, (10, 10))
176+
# Values will be 0 or 1 in the modified region
177+
self.assertTrue(((x[0:5, 0:5] == 0) | (x[0:5, 0:5] == 1)).all())
178+
# Unmodified region remains 0.5
179+
self.assertTrue((x[5:, 5:] == 0.5).all())
180+
181+
def test_geometric_(self):
182+
torch.manual_seed(42)
183+
x = torch.full((10, 10), fill_value=0.5, device="jax")
184+
y = x[0:5, :][:, 0:5]
185+
y.geometric_(p=0.5)
186+
self.assertEqual(type(x), Tensor)
187+
self.assertEqual(x.shape, (10, 10))
188+
# Geometric distribution values are positive integers
189+
self.assertTrue((x[0:5, 0:5] >= 1).all())
190+
# Unmodified region remains 0.5
191+
self.assertTrue((x[5:, 5:] == 0.5).all())
192+
193+
def test_normal_(self):
194+
torch.manual_seed(42)
195+
x = torch.zeros((10, 10), device="jax")
196+
x[0:5, :][:, 0:5].normal_(mean=0, std=1)
197+
self.assertEqual(type(x), Tensor)
198+
self.assertEqual(x.shape, (10, 10))
199+
# Unmodified region remains 0
200+
self.assertEqual(x[5:, 5:].sum(), 0)
201+
202+
def test_uniform_(self):
203+
torch.manual_seed(42)
204+
x = torch.zeros((10, 10), device="jax")
205+
x[0:5, :][:, 0:5].uniform_(0, 1)
206+
self.assertEqual(type(x), Tensor)
207+
self.assertEqual(x.shape, (10, 10))
208+
# Values in modified region are between 0 and 1
209+
self.assertTrue((x[0:5, 0:5] >= 0).all())
210+
self.assertTrue((x[0:5, 0:5] <= 1).all())
211+
# Unmodified region remains 0
212+
self.assertEqual(x[5:, 5:].sum(), 0)
213+
214+
def test_relu_(self):
215+
x = torch.randn((10, 10), device="jax")
216+
x_copy = x.clone()
217+
x[0:5, :][:, 0:5].relu_()
218+
self.assertEqual(type(x), Tensor)
219+
self.assertEqual(x.shape, (10, 10))
220+
# Modified region has no negative values
221+
self.assertTrue((x[0:5, 0:5] >= 0).all())
222+
# Unmodified region remains the same
223+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
224+
225+
def test_squeeze_(self):
226+
x = torch.randn((10, 1, 10), device="jax")
227+
x_clone = x.clone()
228+
# Squeeze the middle dimension
229+
x.squeeze_(1)
230+
self.assertEqual(type(x), Tensor)
231+
self.assertEqual(x.shape, (10, 10))
232+
# Content should remain the same
233+
self.assertTrue(torch.allclose(x, x_clone.squeeze()))
234+
235+
def test_sqrt_(self):
236+
x = torch.randn(
237+
(10, 10), device="jax"
238+
).abs() # Use abs to ensure positive values
239+
x_copy = x.clone()
240+
x[0:5, :][:, 0:5].sqrt_()
241+
self.assertEqual(type(x), Tensor)
242+
self.assertEqual(x.shape, (10, 10))
243+
# Modified region is square root of original values
244+
self.assertTrue(torch.allclose(x[0:5, 0:5], torch.sqrt(x_copy[0:5, 0:5])))
245+
# Unmodified region remains the same
246+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
247+
248+
def test_clamp_min_(self):
249+
x = torch.randn((10, 10), device="jax")
250+
x_copy = x.clone()
251+
x[0:5, :][:, 0:5].clamp_min_(0)
252+
self.assertEqual(type(x), Tensor)
253+
self.assertEqual(x.shape, (10, 10))
254+
# Modified region has no values below 0
255+
self.assertTrue((x[0:5, 0:5] >= 0).all())
256+
# Unmodified region remains the same
257+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
258+
259+
def test_sigmoid_(self):
260+
x = torch.randn((10, 10), device="jax")
261+
x_copy = x.clone()
262+
x[0:5, :][:, 0:5].sigmoid_()
263+
self.assertEqual(type(x), Tensor)
264+
self.assertEqual(x.shape, (10, 10))
265+
# Modified region values are between 0 and 1
266+
self.assertTrue((x[0:5, 0:5] >= 0).all())
267+
self.assertTrue((x[0:5, 0:5] <= 1).all())
268+
# Unmodified region remains the same
269+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
270+
271+
def test_tanh_(self):
272+
x = torch.randn((10, 10), device="jax")
273+
x_copy = x.clone()
274+
x[0:5, :][:, 0:5].tanh_()
275+
self.assertEqual(type(x), Tensor)
276+
self.assertEqual(x.shape, (10, 10))
277+
# Modified region values are between -1 and 1
278+
self.assertTrue((x[0:5, 0:5] >= -1).all())
279+
self.assertTrue((x[0:5, 0:5] <= 1).all())
280+
# Unmodified region remains the same
281+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
282+
283+
def test_ceil_(self):
284+
x = torch.randn((10, 10), device="jax")
285+
x_copy = x.clone()
286+
x[0:5, :][:, 0:5].ceil_()
287+
self.assertEqual(type(x), Tensor)
288+
self.assertEqual(x.shape, (10, 10))
289+
# Check that ceil operation was applied correctly
290+
self.assertTrue(torch.allclose(x[0:5, 0:5], torch.ceil(x_copy[0:5, 0:5])))
291+
# Unmodified region remains the same
292+
self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:]))
293+
294+
def test_logical_not_(self):
295+
x = torch.zeros((10, 10), device="jax")
296+
x[0:5, 0:5] = 1 # Set some values to 1
297+
x[0:5, :][:, 0:5].logical_not_()
298+
self.assertEqual(type(x), Tensor)
299+
self.assertEqual(x.shape, (10, 10))
300+
# Modified region has all values flipped
301+
self.assertEqual(x[0:5, 0:5].sum(), 0) # All now 0
302+
# Unmodified region remains 0
303+
self.assertEqual(x[5:, 5:].sum(), 0)
304+
305+
def test_unsqueeze_(self):
306+
x = torch.randn((10, 10), device="jax")
307+
x_copy = x.clone()
308+
# Add dimension at index 1
309+
x.unsqueeze_(1)
310+
self.assertEqual(type(x), Tensor)
311+
self.assertEqual(x.shape, (10, 1, 10))
312+
# Content should remain the same
313+
self.assertTrue(torch.equal(x.squeeze(1), x_copy))
314+
315+
def test_transpose_(self):
316+
x = torch.randn((10, 5), device="jax")
317+
x_copy = x.clone()
318+
# Transpose dimensions 0 and 1
319+
x.transpose_(0, 1)
320+
self.assertEqual(type(x), Tensor)
321+
self.assertEqual(x.shape, (5, 10))
322+
# Check transposition worked correctly
323+
self.assertTrue(torch.equal(x, x_copy.transpose(0, 1)))
324+
325+
def test_log_normal_(self):
326+
torch.manual_seed(42)
327+
x = torch.zeros((10, 10), device="jax")
328+
x[0:5, :][:, 0:5].log_normal_(mean=0, std=1)
329+
self.assertEqual(type(x), Tensor)
330+
self.assertEqual(x.shape, (10, 10))
331+
# Log-normal values are positive
332+
self.assertTrue((x[0:5, 0:5] > 0).all())
333+
# Unmodified region remains 0
334+
self.assertEqual(x[5:, 5:].sum(), 0)
335+
336+
def test_scatter_add_(self):
337+
# Initialize test tensors
338+
x = torch.zeros((5, 5), device="jax")
339+
indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax")
340+
values = torch.ones((2, 3), device="jax")
341+
342+
# Apply scatter_add_ operation
343+
x.scatter_add_(0, indices, values)
344+
345+
self.assertEqual(type(x), Tensor)
346+
self.assertEqual(x.shape, (5, 5))
347+
# Check specific values were added
348+
self.assertTrue(torch.all(x[0, 0] == 2.0))
349+
self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values
350+
351+
def test_scatter_(self):
352+
# Initialize test tensors
353+
x = torch.zeros((5, 5), device="jax")
354+
indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax")
355+
values = torch.ones((2, 3), device="jax") * 2.0
356+
357+
# Apply scatter_ operation
358+
x.scatter_(0, indices, values)
359+
360+
self.assertEqual(type(x), Tensor)
361+
self.assertEqual(x.shape, (5, 5))
362+
# Check specific values were replaced
363+
self.assertEqual(x[0, 0], 2.0)
364+
self.assertEqual(x[1, 1], 2.0)
365+
self.assertEqual(x[2, 2], 2.0)
366+
self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values
367+
368+
def test_scatter_reduce_(self):
369+
# Initialize test tensors
370+
x = torch.ones((5, 5), device="jax")
371+
indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax")
372+
values = torch.ones((2, 3), device="jax") * 2.0
373+
374+
# Apply scatter_reduce_ operation with "sum" reduction
375+
x.scatter_reduce_(0, indices, values, reduce="sum")
376+
377+
self.assertEqual(type(x), Tensor)
378+
self.assertEqual(x.shape, (5, 5))
379+
# Check specific values were reduced
380+
self.assertTrue(torch.all(x[0, 0] == 5.0))
381+
self.assertEqual(x.sum(), 37.0)
382+

0 commit comments

Comments
 (0)