|
| 1 | +import torch |
| 2 | +import torchax |
| 3 | +import re |
| 4 | +import sys |
| 5 | +import unittest |
| 6 | + |
| 7 | +from torchax.tensor import Tensor |
| 8 | +from torchax.view import View |
| 9 | + |
| 10 | +class TrainTest(unittest.TestCase): |
| 11 | + |
| 12 | + def setUp(self): |
| 13 | + torch.manual_seed(0) |
| 14 | + torchax.enable_globally() |
| 15 | + |
| 16 | + def test_copy_(self): |
| 17 | + x = torch.zeros((10, 10), device="jax") |
| 18 | + y = torch.ones((5, 5), device="jax") |
| 19 | + x[0:5, :][:, 0:5].copy_(y[:, :]) |
| 20 | + self.assertEqual(type(x), Tensor) |
| 21 | + self.assertEqual(x.shape, (10, 10)) |
| 22 | + self.assertEqual(x[0:5, 0:5].sum(), 25) |
| 23 | + self.assertEqual(x.sum(), 25) |
| 24 | + |
| 25 | + def test_transivity(self): |
| 26 | + x = torch.zeros((10, 10), device="jax") |
| 27 | + x_view = x[0:5, :][:, 0:5].add_(1) |
| 28 | + y_view = x_view[0:5, :][:, 0:5].add_(1) |
| 29 | + self.assertEqual(type(x), Tensor) |
| 30 | + self.assertEqual(type(x_view), View) |
| 31 | + self.assertEqual(type(y_view), View) |
| 32 | + self.assertEqual(x.shape, (10, 10)) |
| 33 | + self.assertEqual(x[0:5, 0:5].sum(), 50) |
| 34 | + self.assertEqual(x.sum(), 50) |
| 35 | + |
| 36 | + def test_outofplace_add(self): |
| 37 | + x = torch.zeros((10, 10), device="jax") |
| 38 | + x2 = x[0:5, :][:, 0:5].add(1) |
| 39 | + x3 = x2[0:5, :][:, 0:5].add_(x2) |
| 40 | + self.assertEqual(type(x), Tensor) |
| 41 | + self.assertEqual(type(x2), Tensor) |
| 42 | + self.assertEqual(type(x3), View) |
| 43 | + self.assertEqual(x.shape, (10, 10)) |
| 44 | + self.assertEqual(x.sum(), 0) |
| 45 | + self.assertEqual(x2.sum(), 50) |
| 46 | + |
| 47 | + def test_multiply_tensor_and_view(self): |
| 48 | + x = torch.ones((10, 10), device="jax")*2 |
| 49 | + y = torch.ones((10, 10), device="jax") |
| 50 | + x1 = x[:, :] |
| 51 | + res = x1.mul(y) |
| 52 | + self.assertEqual(type(x), Tensor) |
| 53 | + self.assertEqual(type(y), Tensor) |
| 54 | + self.assertEqual(type(x1), View) |
| 55 | + self.assertEqual(type(res), Tensor) |
| 56 | + self.assertEqual(res.sum(), 200) |
| 57 | + |
| 58 | + def test_multiply_views(self): |
| 59 | + x = torch.ones((10, 10), device="jax")*2 |
| 60 | + y = torch.ones((10, 10), device="jax") |
| 61 | + x1 = x[0:1, :] |
| 62 | + y1 = y[0:1, :] |
| 63 | + res = x1.mul(y1) |
| 64 | + self.assertEqual(type(x), Tensor) |
| 65 | + self.assertEqual(type(y), Tensor) |
| 66 | + self.assertEqual(type(x1), View) |
| 67 | + self.assertEqual(type(y1), View) |
| 68 | + self.assertEqual(type(res), Tensor) |
| 69 | + self.assertEqual(res.sum(), 20) |
| 70 | + |
| 71 | + def test_setitem(self): |
| 72 | + a = torch.zeros(10, device = "jax") |
| 73 | + a[0:5][0:3] = 1 |
| 74 | + self.assertEqual(type(a), Tensor) |
| 75 | + self.assertEqual(a.shape, (10,)) |
| 76 | + self.assertEqual(a.sum(), 3) |
| 77 | + |
| 78 | + # Test all in-place operations |
| 79 | + def test_add_(self): |
| 80 | + x = torch.zeros((10, 10), device="jax") |
| 81 | + x[0:5, :][:, 0:5].add_(1) |
| 82 | + self.assertEqual(type(x), Tensor) |
| 83 | + self.assertEqual(x.shape, (10, 10)) |
| 84 | + self.assertEqual(x.sum(), 25) |
| 85 | + |
| 86 | + def test_sub_(self): |
| 87 | + x = torch.zeros((10, 10), device="jax") |
| 88 | + x[0:5, :][:, 0:5].sub_(1) |
| 89 | + self.assertEqual(type(x), Tensor) |
| 90 | + self.assertEqual(x.shape, (10, 10)) |
| 91 | + self.assertEqual(x.sum(), -25) |
| 92 | + |
| 93 | + def test_mul_(self): |
| 94 | + x = torch.ones((10, 10), device="jax") |
| 95 | + x[0:5, :][:, 0:5].mul_(2) |
| 96 | + self.assertEqual(type(x), Tensor) |
| 97 | + self.assertEqual(x.shape, (10, 10)) |
| 98 | + self.assertEqual(x.sum(), 125) |
| 99 | + |
| 100 | + def test_div_(self): |
| 101 | + x = torch.ones((10, 10), device="jax") |
| 102 | + x[0:10, :][:, 0:10].div_(2) |
| 103 | + self.assertEqual(type(x), Tensor) |
| 104 | + self.assertEqual(x.shape, (10, 10)) |
| 105 | + self.assertEqual(x.sum(), 50) |
| 106 | + |
| 107 | + def test_pow_(self): |
| 108 | + x = torch.full((10, 10), fill_value=2, device="jax") |
| 109 | + x[0:5, :][:, 0:5].pow_(2) |
| 110 | + self.assertEqual(type(x), Tensor) |
| 111 | + self.assertEqual(x.shape, (10, 10)) |
| 112 | + self.assertEqual(x.sum(), 250) |
| 113 | + |
| 114 | + def test_clamp_(self): |
| 115 | + x = torch.arange(100, device="jax", dtype=torch.float).reshape(10, 10) |
| 116 | + x[0:5, :][:, 0:5].clamp_(min=50, max=80) |
| 117 | + self.assertEqual(type(x), Tensor) |
| 118 | + self.assertEqual(x.shape, (10, 10)) |
| 119 | + self.assertTrue((x[0:5, 0:5] >= 50).all()) |
| 120 | + self.assertTrue((x[0:5, 0:5] <= 80).all()) |
| 121 | + |
| 122 | + def test_lt_(self): |
| 123 | + x = torch.ones((10, 10), device="jax") |
| 124 | + y = torch.zeros((10, 10), device="jax") |
| 125 | + x[0:5, :][:, 0:5].lt_(0.5) |
| 126 | + self.assertEqual(type(x), Tensor) |
| 127 | + self.assertEqual(x.shape, (10, 10)) |
| 128 | + self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region |
| 129 | + self.assertEqual(x[5:, 5:].sum(), 25) # All True (1) in the unmodified region |
| 130 | + |
| 131 | + def test_le_(self): |
| 132 | + x = torch.ones((10, 10), device="jax") |
| 133 | + x[0:5, :][:, 0:5].le_(1) |
| 134 | + self.assertEqual(type(x), Tensor) |
| 135 | + self.assertEqual(x.shape, (10, 10)) |
| 136 | + self.assertEqual(x.sum(), 100) # All True (1) |
| 137 | + |
| 138 | + def test_gt_(self): |
| 139 | + x = torch.ones((10, 10), device="jax") |
| 140 | + x[0:5, :][:, 0:5].gt_(1) |
| 141 | + self.assertEqual(type(x), Tensor) |
| 142 | + self.assertEqual(x.shape, (10, 10)) |
| 143 | + self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region |
| 144 | + self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1) |
| 145 | + |
| 146 | + def test_ge_(self): |
| 147 | + x = torch.ones((10, 10), device="jax") |
| 148 | + x[0:5, :][:, 0:5].ge_(1) |
| 149 | + self.assertEqual(type(x), Tensor) |
| 150 | + self.assertEqual(x.shape, (10, 10)) |
| 151 | + self.assertEqual(x.sum(), 100) # All True (1) |
| 152 | + |
| 153 | + def test_eq_(self): |
| 154 | + x = torch.ones((10, 10), device="jax") |
| 155 | + x[0:5, :][:, 0:5].eq_(1) |
| 156 | + self.assertEqual(type(x), Tensor) |
| 157 | + self.assertEqual(x.shape, (10, 10)) |
| 158 | + self.assertEqual(x.sum(), 100) # All True (1) |
| 159 | + |
| 160 | + def test_ne_(self): |
| 161 | + x = torch.ones((10, 10), device="jax") |
| 162 | + x[0:5, :][:, 0:5].ne_(1) |
| 163 | + self.assertEqual(type(x), Tensor) |
| 164 | + self.assertEqual(x.shape, (10, 10)) |
| 165 | + self.assertEqual(x[0:5, 0:5].sum(), 0) # All False (0) in the modified region |
| 166 | + self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1) |
| 167 | + |
| 168 | + def test_bernoulli_(self): |
| 169 | + # Set a fixed seed for deterministic behavior |
| 170 | + torch.manual_seed(42) |
| 171 | + x = torch.full((10, 10), fill_value=0.5, device="jax") |
| 172 | + y = x[0:5, :][:, 0:5] |
| 173 | + y.bernoulli_() |
| 174 | + self.assertEqual(type(x), Tensor) |
| 175 | + self.assertEqual(x.shape, (10, 10)) |
| 176 | + # Values will be 0 or 1 in the modified region |
| 177 | + self.assertTrue(((x[0:5, 0:5] == 0) | (x[0:5, 0:5] == 1)).all()) |
| 178 | + # Unmodified region remains 0.5 |
| 179 | + self.assertTrue((x[5:, 5:] == 0.5).all()) |
| 180 | + |
| 181 | + def test_geometric_(self): |
| 182 | + torch.manual_seed(42) |
| 183 | + x = torch.full((10, 10), fill_value=0.5, device="jax") |
| 184 | + y = x[0:5, :][:, 0:5] |
| 185 | + y.geometric_(p=0.5) |
| 186 | + self.assertEqual(type(x), Tensor) |
| 187 | + self.assertEqual(x.shape, (10, 10)) |
| 188 | + # Geometric distribution values are positive integers |
| 189 | + self.assertTrue((x[0:5, 0:5] >= 1).all()) |
| 190 | + # Unmodified region remains 0.5 |
| 191 | + self.assertTrue((x[5:, 5:] == 0.5).all()) |
| 192 | + |
| 193 | + def test_normal_(self): |
| 194 | + torch.manual_seed(42) |
| 195 | + x = torch.zeros((10, 10), device="jax") |
| 196 | + x[0:5, :][:, 0:5].normal_(mean=0, std=1) |
| 197 | + self.assertEqual(type(x), Tensor) |
| 198 | + self.assertEqual(x.shape, (10, 10)) |
| 199 | + # Unmodified region remains 0 |
| 200 | + self.assertEqual(x[5:, 5:].sum(), 0) |
| 201 | + |
| 202 | + def test_uniform_(self): |
| 203 | + torch.manual_seed(42) |
| 204 | + x = torch.zeros((10, 10), device="jax") |
| 205 | + x[0:5, :][:, 0:5].uniform_(0, 1) |
| 206 | + self.assertEqual(type(x), Tensor) |
| 207 | + self.assertEqual(x.shape, (10, 10)) |
| 208 | + # Values in modified region are between 0 and 1 |
| 209 | + self.assertTrue((x[0:5, 0:5] >= 0).all()) |
| 210 | + self.assertTrue((x[0:5, 0:5] <= 1).all()) |
| 211 | + # Unmodified region remains 0 |
| 212 | + self.assertEqual(x[5:, 5:].sum(), 0) |
| 213 | + |
| 214 | + def test_relu_(self): |
| 215 | + x = torch.randn((10, 10), device="jax") |
| 216 | + x_copy = x.clone() |
| 217 | + x[0:5, :][:, 0:5].relu_() |
| 218 | + self.assertEqual(type(x), Tensor) |
| 219 | + self.assertEqual(x.shape, (10, 10)) |
| 220 | + # Modified region has no negative values |
| 221 | + self.assertTrue((x[0:5, 0:5] >= 0).all()) |
| 222 | + # Unmodified region remains the same |
| 223 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 224 | + |
| 225 | + def test_squeeze_(self): |
| 226 | + x = torch.randn((10, 1, 10), device="jax") |
| 227 | + x_clone = x.clone() |
| 228 | + # Squeeze the middle dimension |
| 229 | + x.squeeze_(1) |
| 230 | + self.assertEqual(type(x), Tensor) |
| 231 | + self.assertEqual(x.shape, (10, 10)) |
| 232 | + # Content should remain the same |
| 233 | + self.assertTrue(torch.allclose(x, x_clone.squeeze())) |
| 234 | + |
| 235 | + def test_sqrt_(self): |
| 236 | + x = torch.randn( |
| 237 | + (10, 10), device="jax" |
| 238 | + ).abs() # Use abs to ensure positive values |
| 239 | + x_copy = x.clone() |
| 240 | + x[0:5, :][:, 0:5].sqrt_() |
| 241 | + self.assertEqual(type(x), Tensor) |
| 242 | + self.assertEqual(x.shape, (10, 10)) |
| 243 | + # Modified region is square root of original values |
| 244 | + self.assertTrue(torch.allclose(x[0:5, 0:5], torch.sqrt(x_copy[0:5, 0:5]))) |
| 245 | + # Unmodified region remains the same |
| 246 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 247 | + |
| 248 | + def test_clamp_min_(self): |
| 249 | + x = torch.randn((10, 10), device="jax") |
| 250 | + x_copy = x.clone() |
| 251 | + x[0:5, :][:, 0:5].clamp_min_(0) |
| 252 | + self.assertEqual(type(x), Tensor) |
| 253 | + self.assertEqual(x.shape, (10, 10)) |
| 254 | + # Modified region has no values below 0 |
| 255 | + self.assertTrue((x[0:5, 0:5] >= 0).all()) |
| 256 | + # Unmodified region remains the same |
| 257 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 258 | + |
| 259 | + def test_sigmoid_(self): |
| 260 | + x = torch.randn((10, 10), device="jax") |
| 261 | + x_copy = x.clone() |
| 262 | + x[0:5, :][:, 0:5].sigmoid_() |
| 263 | + self.assertEqual(type(x), Tensor) |
| 264 | + self.assertEqual(x.shape, (10, 10)) |
| 265 | + # Modified region values are between 0 and 1 |
| 266 | + self.assertTrue((x[0:5, 0:5] >= 0).all()) |
| 267 | + self.assertTrue((x[0:5, 0:5] <= 1).all()) |
| 268 | + # Unmodified region remains the same |
| 269 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 270 | + |
| 271 | + def test_tanh_(self): |
| 272 | + x = torch.randn((10, 10), device="jax") |
| 273 | + x_copy = x.clone() |
| 274 | + x[0:5, :][:, 0:5].tanh_() |
| 275 | + self.assertEqual(type(x), Tensor) |
| 276 | + self.assertEqual(x.shape, (10, 10)) |
| 277 | + # Modified region values are between -1 and 1 |
| 278 | + self.assertTrue((x[0:5, 0:5] >= -1).all()) |
| 279 | + self.assertTrue((x[0:5, 0:5] <= 1).all()) |
| 280 | + # Unmodified region remains the same |
| 281 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 282 | + |
| 283 | + def test_ceil_(self): |
| 284 | + x = torch.randn((10, 10), device="jax") |
| 285 | + x_copy = x.clone() |
| 286 | + x[0:5, :][:, 0:5].ceil_() |
| 287 | + self.assertEqual(type(x), Tensor) |
| 288 | + self.assertEqual(x.shape, (10, 10)) |
| 289 | + # Check that ceil operation was applied correctly |
| 290 | + self.assertTrue(torch.allclose(x[0:5, 0:5], torch.ceil(x_copy[0:5, 0:5]))) |
| 291 | + # Unmodified region remains the same |
| 292 | + self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) |
| 293 | + |
| 294 | + def test_logical_not_(self): |
| 295 | + x = torch.zeros((10, 10), device="jax") |
| 296 | + x[0:5, 0:5] = 1 # Set some values to 1 |
| 297 | + x[0:5, :][:, 0:5].logical_not_() |
| 298 | + self.assertEqual(type(x), Tensor) |
| 299 | + self.assertEqual(x.shape, (10, 10)) |
| 300 | + # Modified region has all values flipped |
| 301 | + self.assertEqual(x[0:5, 0:5].sum(), 0) # All now 0 |
| 302 | + # Unmodified region remains 0 |
| 303 | + self.assertEqual(x[5:, 5:].sum(), 0) |
| 304 | + |
| 305 | + def test_unsqueeze_(self): |
| 306 | + x = torch.randn((10, 10), device="jax") |
| 307 | + x_copy = x.clone() |
| 308 | + # Add dimension at index 1 |
| 309 | + x.unsqueeze_(1) |
| 310 | + self.assertEqual(type(x), Tensor) |
| 311 | + self.assertEqual(x.shape, (10, 1, 10)) |
| 312 | + # Content should remain the same |
| 313 | + self.assertTrue(torch.equal(x.squeeze(1), x_copy)) |
| 314 | + |
| 315 | + def test_transpose_(self): |
| 316 | + x = torch.randn((10, 5), device="jax") |
| 317 | + x_copy = x.clone() |
| 318 | + # Transpose dimensions 0 and 1 |
| 319 | + x.transpose_(0, 1) |
| 320 | + self.assertEqual(type(x), Tensor) |
| 321 | + self.assertEqual(x.shape, (5, 10)) |
| 322 | + # Check transposition worked correctly |
| 323 | + self.assertTrue(torch.equal(x, x_copy.transpose(0, 1))) |
| 324 | + |
| 325 | + def test_log_normal_(self): |
| 326 | + torch.manual_seed(42) |
| 327 | + x = torch.zeros((10, 10), device="jax") |
| 328 | + x[0:5, :][:, 0:5].log_normal_(mean=0, std=1) |
| 329 | + self.assertEqual(type(x), Tensor) |
| 330 | + self.assertEqual(x.shape, (10, 10)) |
| 331 | + # Log-normal values are positive |
| 332 | + self.assertTrue((x[0:5, 0:5] > 0).all()) |
| 333 | + # Unmodified region remains 0 |
| 334 | + self.assertEqual(x[5:, 5:].sum(), 0) |
| 335 | + |
| 336 | + def test_scatter_add_(self): |
| 337 | + # Initialize test tensors |
| 338 | + x = torch.zeros((5, 5), device="jax") |
| 339 | + indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") |
| 340 | + values = torch.ones((2, 3), device="jax") |
| 341 | + |
| 342 | + # Apply scatter_add_ operation |
| 343 | + x.scatter_add_(0, indices, values) |
| 344 | + |
| 345 | + self.assertEqual(type(x), Tensor) |
| 346 | + self.assertEqual(x.shape, (5, 5)) |
| 347 | + # Check specific values were added |
| 348 | + self.assertTrue(torch.all(x[0, 0] == 2.0)) |
| 349 | + self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values |
| 350 | + |
| 351 | + def test_scatter_(self): |
| 352 | + # Initialize test tensors |
| 353 | + x = torch.zeros((5, 5), device="jax") |
| 354 | + indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") |
| 355 | + values = torch.ones((2, 3), device="jax") * 2.0 |
| 356 | + |
| 357 | + # Apply scatter_ operation |
| 358 | + x.scatter_(0, indices, values) |
| 359 | + |
| 360 | + self.assertEqual(type(x), Tensor) |
| 361 | + self.assertEqual(x.shape, (5, 5)) |
| 362 | + # Check specific values were replaced |
| 363 | + self.assertEqual(x[0, 0], 2.0) |
| 364 | + self.assertEqual(x[1, 1], 2.0) |
| 365 | + self.assertEqual(x[2, 2], 2.0) |
| 366 | + self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values |
| 367 | + |
| 368 | + def test_scatter_reduce_(self): |
| 369 | + # Initialize test tensors |
| 370 | + x = torch.ones((5, 5), device="jax") |
| 371 | + indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") |
| 372 | + values = torch.ones((2, 3), device="jax") * 2.0 |
| 373 | + |
| 374 | + # Apply scatter_reduce_ operation with "sum" reduction |
| 375 | + x.scatter_reduce_(0, indices, values, reduce="sum") |
| 376 | + |
| 377 | + self.assertEqual(type(x), Tensor) |
| 378 | + self.assertEqual(x.shape, (5, 5)) |
| 379 | + # Check specific values were reduced |
| 380 | + self.assertTrue(torch.all(x[0, 0] == 5.0)) |
| 381 | + self.assertEqual(x.sum(), 37.0) |
| 382 | + |
0 commit comments