@@ -8,6 +8,7 @@
 from unittest.mock import patch
 
 import torch
+import torch.nn.functional as F
 
 from torchao.testing.utils import skip_if_no_cuda
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
@@ -344,6 +345,53 @@ def __init__(
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
+    def test_implements_and_torch_function_together(self):
+        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
+        counter = {"calls": 0}
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata: torch.Tensor, attr: str = "attr", device=None):
+                kwargs = {}
+                if device is None:
+                    device = qdata.device
+                kwargs["device"] = device
+                kwargs["dtype"] = qdata.dtype
+                r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
+                r.qdata = qdata
+                r.attr = attr
+                return r
+
+            def __init__(self, qdata: torch.Tensor, attr: str = "attr", device=None):
+                pass
+
+        implements = MyTensor.implements
+        implements_torch_function = MyTensor.implements_torch_function
+
+        @implements([torch.ops.aten.t.default])
+        @implements_torch_function([F.linear])
+        def fake_linear(func, types, args, kwargs):
+            counter["calls"] += 1
+
+        l = torch.nn.Linear(2, 3)
+        l.weight = torch.nn.Parameter(MyTensor(l.weight.detach(), "attr", None))
+        x = torch.randn(4, 2)
+
+        # Torch function path
+        F.linear(x, l.weight, l.bias)
+        self.assertEqual(
+            counter["calls"], 1, "Expected fake_linear to be called via F.linear"
+        )
+
+        # ATen path
+        mt = MyTensor(torch.randn(3, 4))
+        torch.ops.aten.t.default(mt)
+        self.assertEqual(
+            counter["calls"], 2, "Expected fake_linear to be called via aten.t.default"
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
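
For context, the pattern this test exercises registers one handler for both dispatch paths: `implements` routes ATen ops through `__torch_dispatch__`, while `implements_torch_function` routes Python-level calls such as `F.linear` through `__torch_function__`. Below is a minimal sketch of what a real (non-counting) handler stacked the same way might look like, reusing `MyTensor` and the two decorator aliases from the test above; the `_linear_impl` name and the unwrapping logic are illustrative assumptions, not part of this change.

@implements([torch.ops.aten.t.default])
@implements_torch_function([F.linear])
def _linear_impl(func, types, args, kwargs):
    # Hypothetical handler body (not from the PR): unwrap MyTensor
    # arguments and delegate to the plain op on the inner qdata.
    if func is torch.ops.aten.t.default:
        (self,) = args
        # Rewrap so the result stays a MyTensor.
        return MyTensor(self.qdata.t(), self.attr, self.device)
    # Otherwise func is F.linear, dispatched via __torch_function__.
    input, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return F.linear(input, weight.qdata, bias)

Because the same function object is registered in both tables, callers hit it whether they go through `F.linear` on a module whose weight is a `MyTensor`, or call `torch.ops.aten.t.default` on the subclass directly, which is exactly the pair of assertions the test makes.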