From 5cb534336c23002a2c741db7f0734c8baf7ca031 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Tue, 17 Jun 2025 22:01:45 +0000 Subject: [PATCH 1/8] Add jax_device context manager to control the device target --- torchax/test/test_interop.py | 47 +++++++++++++++++++++++++++++- torchax/torchax/__init__.py | 17 +++++++++++ torchax/torchax/tensor.py | 55 +++++++++++++++++++++++++++++------- 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py index dde90f235e2b..81ae5a244fc7 100644 --- a/torchax/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -2,8 +2,21 @@ import torch import unittest import torchax -from torchax import interop +from torchax import interop, jax_device import torchax +import jax + + +def is_tpu_available(): + """Checks if any TPU devices are available to JAX.""" + try: + # jax.devices('tpu') will return a list of TPU devices if available. + # If no TPUs are found or JAX is not configured for TPU, + # it will raise a RuntimeError. + tpu_devices = jax.devices('tpu') + return len(tpu_devices) > 0 + except RuntimeError: + return False class InteropTest(unittest.TestCase): @@ -116,6 +129,38 @@ def forward(self, x): # assert torch.testing.assert_allclose(actual, expected) + def test_to_jax_device(self): + a = torch.ones(3, 3) + + with jax_device("cpu"): + # move torch.tensor to torchax.tensor CPU + b = a.to("jax") + self.assertEqual(b.jax_device.platform, "cpu") + self.assertEqual(b.device.type, "jax") + + if is_tpu_available: + # move torch.tensor to torchax.tensor TPU + with jax_device("tpu"): + c = a.to("jax") + self.assertEqual(c.jax_device.platform, "tpu") + self.assertEqual(c.device.type, "jax") + + # move torchax.tensor on CPU to TPU + with jax_device("tpu"): + self.assertEqual(b.jax_device.platform, "cpu") + self.assertEqual(c.device.type, "jax") + c = b.to("jax") + self.assertEqual(c.jax_device.platform, "tpu") + self.assertEqual(c.device.type, "jax") + + # move torchax.tensor on TPU to CPU + with jax_device("cpu"): + self.assertEqual(c.jax_device.platform, "tpu") + self.assertEqual(c.device.type, "jax") + d = c.to("jax") + self.assertEqual(d.jax_device.platform, "cpu") + self.assertEqual(d.device.type, "jax") + if __name__ == '__main__': unittest.main() diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 36a49f80572a..9df90e010476 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -7,6 +7,7 @@ from torch.utils import _pytree as pytree from torchax import tensor from torchax import distributed # noqa: F401 +from contextlib import contextmanager __version__ = "0.0.4" VERSION = __version__ @@ -128,3 +129,19 @@ def compile(fn, options: Optional[CompileOptions] = None): raise RuntimeError('dynamo mode is not supported yet') elif options.mode == 'export': raise RuntimeError('export mode is not supported yet') + + +@contextmanager +def jax_device(target_device: str, env: tensor.Environment | None = None): + """ + When moving to Jax, manage where it tensor's device is + """ + if env is None: + env = default_env() + + prev_target_device = env.target_device + try: + env.target_device = target_device + yield env + finally: + env.target_device = prev_target_device diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 2ce9f1fa9a73..9f00d41eb84f 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -293,6 +293,8 @@ def _name_of_func(func): torch.as_tensor, } +# TODO(wen): use existing types, either from torch or jax 
+SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"] class Environment(contextlib.ContextDecorator): """This class holds a set of configurations and "globals" needed @@ -322,10 +324,19 @@ def __init__(self, configuration=None): self._manually_entered = False self.enabled = False - self._jax_devices = set(["jax", "jax_cpu", "xla"]) + self._prng_key = mutable_array( jax.random.key(torch.initial_seed() % (1 << 63))) self.autocast_dtype = None + self._target_device = "cpu" + + @property + def target_device(self): + return self._target_device + + @target_device.setter + def target_device(self, device:str): + self._target_device = device.lower() def manual_seed(self, key): self._prng_key = mutable_array(jax.random.key(key)) @@ -348,9 +359,18 @@ def get_as_jax_device(self, device: Any): if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): return jax.local_devices()[0] - if device.startswith("jax") or device.startswith("xla"): + if device.startswith("xla"): return jax.local_devices()[0] + # TODO (wen): jax is NOT a device type, + # once we can register more than one backend, revisit + if device.startswith("jax"): + match self.target_device: + case "cpu": + return jax.devices("cpu")[0] + case "tpu": + return jax.devices("tpu")[0] + return None # fallback to torch def load_ops(self): @@ -401,19 +421,34 @@ def _get_from_dict(op_dict, op): def _to_copy(self, the_tensor, new_dtype, new_device): if isinstance(the_tensor, View): the_tensor = the_tensor.torch() + if isinstance(the_tensor, Tensor): + arr = the_tensor.jax() + if new_dtype is not None and new_dtype != arr.dtype: arr = arr.astype(mappings.t2j_dtype(new_dtype)) + + # convert xla tensor to other device + # only supported is CPU if new_device is not None: - # convert xla tensor to other device - # only supported is CPU - if str(new_device).startswith("cpu"): - # converting to a non-jax device: let torch native handle it - torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor, - Tensor) else arr - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(new_device) + match str(new_device).lower(): + case "cpu": + # converting to a non-jax device: let torch native handle it + torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor, + Tensor) else arr + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return torch_tensor.to(new_device) + case "jax": + # move torchax.tensor / jax tensor between devices + # I don't know ifgit this will work after the model is jitted + if self.target_device != the_tensor.jax_device.platform: + arr = jax.device_put(the_tensor.jax(), + jax.devices(self.target_device)[0]) + return Tensor(arr, self) + case _: + logging.error(f"torchax.Tenosr cannot handle device {new_device}") + else: if new_dtype is not None and new_dtype != the_tensor.dtype: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): From ea7c1d14c4f15c5d7b490aa78194d92726dcf94e Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Tue, 17 Jun 2025 22:25:25 +0000 Subject: [PATCH 2/8] Add some comment --- torchax/torchax/__init__.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 9df90e010476..617970a56c4f 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -134,8 +134,23 @@ def compile(fn, options: Optional[CompileOptions] = None): @contextmanager def jax_device(target_device: str, env: tensor.Environment | None = None): """ - When moving to Jax, manage where it 
tensor's device is - """ + to("jax") cannot differentiate the device/platform (cpu vs tpu). + Use this context manager to control jax array's storage device + + Examples: + + a = torch.ones(3, 3) + + with jax_device("cpu"): + b = a.to("jax") + + with jax_device("tpu"): + c = a.to("jax") + + with jax_device("tpu"): + c = b.to("jax") + + """ if env is None: env = default_env() From 39a99804e14e3031c530adebc70953e31371bee2 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Tue, 17 Jun 2025 22:32:45 +0000 Subject: [PATCH 3/8] Add default to match --- torchax/torchax/tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 9f00d41eb84f..26b625edeb59 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -370,6 +370,8 @@ def get_as_jax_device(self, device: Any): return jax.devices("cpu")[0] case "tpu": return jax.devices("tpu")[0] + case _: + raise AttributeError(f"Cannot handle env.target_device {self.target_device}") return None # fallback to torch From 279c06ad1cb392a249b169d97425ebbfc8453cd2 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Tue, 17 Jun 2025 22:38:00 +0000 Subject: [PATCH 4/8] lint --- torchax/torchax/tensor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 26b625edeb59..6f8f337e1117 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -296,6 +296,7 @@ def _name_of_func(func): # TODO(wen): use existing types, either from torch or jax SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"] + class Environment(contextlib.ContextDecorator): """This class holds a set of configurations and "globals" needed @@ -335,7 +336,7 @@ def target_device(self): return self._target_device @target_device.setter - def target_device(self, device:str): + def target_device(self, device: str): self._target_device = device.lower() def manual_seed(self, key): @@ -362,7 +363,7 @@ def get_as_jax_device(self, device: Any): if device.startswith("xla"): return jax.local_devices()[0] - # TODO (wen): jax is NOT a device type, + # TODO (wen): jax is NOT a device type, # once we can register more than one backend, revisit if device.startswith("jax"): match self.target_device: @@ -371,7 +372,8 @@ def get_as_jax_device(self, device: Any): case "tpu": return jax.devices("tpu")[0] case _: - raise AttributeError(f"Cannot handle env.target_device {self.target_device}") + raise AttributeError( + f"Cannot handle env.target_device {self.target_device}") return None # fallback to torch @@ -446,7 +448,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): # I don't know ifgit this will work after the model is jitted if self.target_device != the_tensor.jax_device.platform: arr = jax.device_put(the_tensor.jax(), - jax.devices(self.target_device)[0]) + jax.devices(self.target_device)[0]) return Tensor(arr, self) case _: logging.error(f"torchax.Tenosr cannot handle device {new_device}") From 74d15ef8a415584e1e9063834694a0c180a5682d Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Wed, 18 Jun 2025 09:33:56 -0700 Subject: [PATCH 5/8] fix test --- torchax/test/test_interop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py index 81ae5a244fc7..ce0edca2f728 100644 --- a/torchax/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -138,7 +138,7 @@ def test_to_jax_device(self): self.assertEqual(b.jax_device.platform, "cpu") self.assertEqual(b.device.type, "jax") - if 
is_tpu_available: + if is_tpu_available(): # move torch.tensor to torchax.tensor TPU with jax_device("tpu"): c = a.to("jax") From 5ec61278c31075d443e6da4499476777c3bb051b Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Fri, 20 Jun 2025 10:10:20 -0700 Subject: [PATCH 6/8] Clean up comment --- torchax/torchax/tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 6f8f337e1117..e3067dd3ec77 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -433,8 +433,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): if new_dtype is not None and new_dtype != arr.dtype: arr = arr.astype(mappings.t2j_dtype(new_dtype)) - # convert xla tensor to other device - # only supported is CPU + if new_device is not None: match str(new_device).lower(): case "cpu": From d3596f3cef7fd95e35fc47d298f9856bace43e16 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Fri, 20 Jun 2025 11:06:44 -0700 Subject: [PATCH 7/8] Setup ruff and apply the standard linter rule to torchax/torchax --- torchax/pyproject.toml | 16 + torchax/torchax/__init__.py | 188 +- torchax/torchax/amp.py | 449 +- torchax/torchax/config.py | 32 +- torchax/torchax/decompositions.py | 392 +- torchax/torchax/device_module.py | 14 +- torchax/torchax/distributed.py | 383 +- torchax/torchax/export.py | 438 +- torchax/torchax/flax.py | 63 +- torchax/torchax/interop.py | 512 +- torchax/torchax/mesh_util.py | 376 +- torchax/torchax/ops/__init__.py | 16 +- torchax/torchax/ops/jaten.py | 7571 ++++++++++++----------- torchax/torchax/ops/jax_reimplement.py | 302 +- torchax/torchax/ops/jc10d.py | 52 +- torchax/torchax/ops/jimage.py | 167 +- torchax/torchax/ops/jlibrary.py | 110 +- torchax/torchax/ops/jtorch.py | 622 +- torchax/torchax/ops/jtorchvision_nms.py | 434 +- torchax/torchax/ops/mappings.py | 186 +- torchax/torchax/ops/op_base.py | 183 +- torchax/torchax/ops/ops_registry.py | 86 +- torchax/torchax/tensor.py | 1191 ++-- torchax/torchax/tf_integration.py | 178 +- torchax/torchax/train.py | 195 +- torchax/torchax/types.py | 8 +- torchax/torchax/util.py | 165 +- torchax/torchax/view.py | 502 +- 28 files changed, 7598 insertions(+), 7233 deletions(-) diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml index 3882b77b78f3..689b7983e680 100644 --- a/torchax/pyproject.toml +++ b/torchax/pyproject.toml @@ -51,3 +51,19 @@ packages = ["torchax"] [tool.pytest.ini_options] addopts="-n auto" + +[tool.ruff] +# Equivalent to column_limit +line-length = 80 + +# Enable preview mode to use rules like E306 +preview = true + +[tool.ruff.lint] +select = [ + "E", "F", "W", # Your existing rule selections + "Q", +] +# Enforces a blank line before nested functions and classes +extend-select = ["E306"] + diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 617970a56c4f..2745a2a47db6 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -13,150 +13,158 @@ VERSION = __version__ __all__ = [ - 'default_env', - 'extract_jax', - 'enable_globally', + "default_env", + "extract_jax", + "enable_globally", ] from jax._src import xla_bridge -os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') +os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1") # torchax:oss-begin -if getattr(jax.config, 'jax_pjrt_client_create_options', None): - jax.config.update( - 'jax_pjrt_client_create_options', - f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}') +if getattr(jax.config, "jax_pjrt_client_create_options", None): + 
jax.config.update( + "jax_pjrt_client_create_options", + f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}", + ) # torchax:oss-end env = None def default_env(): - global env + global env - if env is None: - env = tensor.Environment() - return env + if env is None: + env = tensor.Environment() + return env def extract_jax(mod: torch.nn.Module, env=None): - """Returns a pytree of jax.ndarray and a jax callable.""" - if env is None: - env = default_env() - states = dict(mod.named_buffers()) - states.update(mod.named_parameters()) + """Returns a pytree of jax.ndarray and a jax callable.""" + if env is None: + env = default_env() + states = dict(mod.named_buffers()) + states.update(mod.named_parameters()) - states = env.t2j_copy(states) + states = env.t2j_copy(states) - #@jax.jit - def jax_func(states, inputs): - (states, inputs) = env.j2t_iso((states, inputs)) - with env: - res = torch.func.functional_call(mod, states, inputs, tie_weights=False) - return env.t2j_iso(res) + # @jax.jit + def jax_func(states, inputs): + (states, inputs) = env.j2t_iso((states, inputs)) + with env: + res = torch.func.functional_call( + mod, states, inputs, tie_weights=False + ) + return env.t2j_iso(res) - return states, jax_func + return states, jax_func def enable_globally(): - env = default_env().enable_torch_modes() - return env + env = default_env().enable_torch_modes() + return env def disable_globally(): - global env - default_env().disable_torch_modes() + global env + default_env().disable_torch_modes() @contextlib.contextmanager def disable_temporarily(): - prev = default_env().enabled - if prev: - disable_globally() - yield () - if prev: - enable_globally() + prev = default_env().enabled + if prev: + disable_globally() + yield () + if prev: + enable_globally() -torch.utils.rename_privateuse1_backend('jax') +torch.utils.rename_privateuse1_backend("jax") unsupported_dtype = [torch.quint8] torch.utils.generate_methods_for_privateuse1_backend( for_tensor=True, for_module=True, for_storage=True, - unsupported_dtype=unsupported_dtype) + unsupported_dtype=unsupported_dtype, +) import jax import torchax.device_module -torch._register_device_module('jax', torchax.device_module) +torch._register_device_module("jax", torchax.device_module) def enable_accuracy_mode(): - jax.config.update('jax_enable_x64', True) - jax.config.update('jax_default_matmul_precision', 'highest') - default_env().config.internal_respect_torch_return_dtypes = True + jax.config.update("jax_enable_x64", True) + jax.config.update("jax_default_matmul_precision", "highest") + default_env().config.internal_respect_torch_return_dtypes = True def enable_performance_mode(): - jax.config.update('jax_enable_x64', False) - jax.config.update('jax_default_matmul_precision', 'default') - default_env().config.internal_respect_torch_return_dtypes = False + jax.config.update("jax_enable_x64", False) + jax.config.update("jax_default_matmul_precision", "default") + default_env().config.internal_respect_torch_return_dtypes = False @dataclasses.dataclass class CompileOptions: - # only valid if compiling nn.Module - methods_to_compile: List[str] = dataclasses.field( - default_factory=lambda: ['forward']) - jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - mode: str = 'jax' # or dynamo or export + # only valid if compiling nn.Module + methods_to_compile: List[str] = dataclasses.field( + default_factory=lambda: ["forward"] + ) + jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + mode: str = "jax" # or 
dynamo or export def compile(fn, options: Optional[CompileOptions] = None): - options = options or CompileOptions() - if options.mode == 'jax': - from torchax import interop - if isinstance(fn, torch.nn.Module): - module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs) - for n in options.methods_to_compile: - module.make_jitted(n) - return module - else: - return interop.jax_jit(fn) - elif options.mode == 'dynamo': - raise RuntimeError('dynamo mode is not supported yet') - elif options.mode == 'export': - raise RuntimeError('export mode is not supported yet') + options = options or CompileOptions() + if options.mode == "jax": + from torchax import interop + + if isinstance(fn, torch.nn.Module): + module = interop.JittableModule( + fn, extra_jit_args=options.jax_jit_kwargs + ) + for n in options.methods_to_compile: + module.make_jitted(n) + return module + else: + return interop.jax_jit(fn) + elif options.mode == "dynamo": + raise RuntimeError("dynamo mode is not supported yet") + elif options.mode == "export": + raise RuntimeError("export mode is not supported yet") @contextmanager def jax_device(target_device: str, env: tensor.Environment | None = None): - """ - to("jax") cannot differentiate the device/platform (cpu vs tpu). - Use this context manager to control jax array's storage device - - Examples: - - a = torch.ones(3, 3) - - with jax_device("cpu"): - b = a.to("jax") - - with jax_device("tpu"): - c = a.to("jax") - - with jax_device("tpu"): - c = b.to("jax") - - """ - if env is None: - env = default_env() - - prev_target_device = env.target_device - try: - env.target_device = target_device - yield env - finally: - env.target_device = prev_target_device + """ + to("jax") cannot differentiate the device/platform (cpu vs tpu). + Use this context manager to control jax array's storage device + + Examples: + + a = torch.ones(3, 3) + + with jax_device("cpu"): + b = a.to("jax") + + with jax_device("tpu"): + c = a.to("jax") + + with jax_device("tpu"): + c = b.to("jax") + + """ + if env is None: + env = default_env() + + prev_target_device = env.target_device + try: + env.target_device = target_device + yield env + finally: + env.target_device = prev_target_device diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index ef06e884a8a8..465c2fff2fb4 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -28,306 +28,179 @@ # promote, // Run in the widest dtype among several args. 
# }; class CastPolicy(enum.Enum): - LOWER_PRECISION_FP = 0 - FP32 = 1 - FP32_SET_OPT_DTYPE = 2 - FP32_APPEND_DTYPE = 3 - PROMOTE = 4 + LOWER_PRECISION_FP = 0 + FP32 = 1 + FP32_SET_OPT_DTYPE = 2 + FP32_APPEND_DTYPE = 3 + PROMOTE = 4 def execute_policy(policy, args, kwargs, target_lower_fp): + def is_float(a): + return isinstance(a, torch.Tensor) and a.is_floating_point() - def is_float(a): - return isinstance(a, torch.Tensor) and a.is_floating_point() - match policy: - case CastPolicy.LOWER_PRECISION_FP: - return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp), - (args, kwargs)) - case CastPolicy.FP32: - return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32), - (args, kwargs)) - case CastPolicy.PROMOTE: - dtypes = set(a.dtype for a in args) - widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] - return pytree.tree_map_only(is_float, lambda a: a.to(widest), - (args, kwargs)) - case _: - raise AssertionError(f'Policy {policy} not implemented yet.') + match policy: + case CastPolicy.LOWER_PRECISION_FP: + return pytree.tree_map_only( + is_float, lambda a: a.to(target_lower_fp), (args, kwargs) + ) + case CastPolicy.FP32: + return pytree.tree_map_only( + is_float, lambda a: a.to(torch.float32), (args, kwargs) + ) + case CastPolicy.PROMOTE: + dtypes = set(a.dtype for a in args) + widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] + return pytree.tree_map_only( + is_float, lambda a: a.to(widest), (args, kwargs) + ) + case _: + raise AssertionError(f"Policy {policy} not implemented yet.") @contextlib.contextmanager def autocast(device, dtype=torch.bfloat16, env=None): - del device - if env is None: - import torchax - env = torchax.default_env() - env.autocast_dtype, old = dtype, env.autocast_dtype - yield - env.autocast_dtype = old + del device + if env is None: + import torchax + + env = torchax.default_env() + env.autocast_dtype, old = dtype, env.autocast_dtype + yield + env.autocast_dtype = old # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327 autocast_policy = { - torch.ops.aten.conv1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv1d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.bmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linalg_vecdot.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.baddbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._addmm_activation.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linear.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._convolution.deprecated: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.matmul.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_tbc.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mkldnn_rnn_layer.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose2d.input: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose3d.input: - CastPolicy.LOWER_PRECISION_FP, - 
torch.ops.aten.prelu.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.scaled_dot_product_attention.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._native_multi_head_attention.default: - CastPolicy.LOWER_PRECISION_FP, - + torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP, # fp32 cast policy - torch.ops.aten.avg_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler.default: - CastPolicy.FP32, - torch.ops.aten.polar.default: - CastPolicy.FP32, - torch.ops.aten.prod.default: - CastPolicy.FP32, - torch.ops.aten.prod.dim_int: - CastPolicy.FP32, - torch.ops.aten.prod.dim_Dimname: - CastPolicy.FP32, - torch.ops.aten.quantile.default: - CastPolicy.FP32, - torch.ops.aten.quantile.scalar: - CastPolicy.FP32, - torch.ops.aten.nanquantile.default: - CastPolicy.FP32, - torch.ops.aten.nanquantile.scalar: - CastPolicy.FP32, - torch.ops.aten.stft.default: - CastPolicy.FP32, - torch.ops.aten.stft.center: - CastPolicy.FP32, - torch.ops.aten.cdist.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_2d.default: - CastPolicy.FP32, - torch.ops.aten._grid_sampler_2d_cpu_fallback.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_3d.default: - CastPolicy.FP32, - torch.ops.aten.trace.default: - CastPolicy.FP32, - torch.ops.aten.view_as_complex.default: - CastPolicy.FP32, - torch.ops.aten.cholesky.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_inverse.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_solve.default: - CastPolicy.FP32, - torch.ops.aten.inverse.default: - CastPolicy.FP32, - torch.ops.aten.lu_solve.default: - CastPolicy.FP32, - torch.ops.aten.orgqr.default: - CastPolicy.FP32, - torch.ops.aten.ormqr.default: - CastPolicy.FP32, - torch.ops.aten.pinverse.default: - CastPolicy.FP32, - torch.ops.aten.max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool2d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool3d.default: - 
CastPolicy.FP32, - torch.ops.aten.adaptive_avg_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad3d.default: - CastPolicy.FP32, - torch.ops.aten.mse_loss.default: - CastPolicy.FP32, - torch.ops.aten.cosine_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss2d.default: - CastPolicy.FP32, - torch.ops.aten.hinge_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.poisson_nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.smooth_l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.cross_entropy_loss.default: - CastPolicy.FP32, - torch.ops.aten.l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.huber_loss.default: - CastPolicy.FP32, - torch.ops.aten.margin_ranking_loss.default: - CastPolicy.FP32, - torch.ops.aten.soft_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.triplet_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.multi_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.IntList: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.Tensor: - CastPolicy.FP32, - torch.ops.aten.kl_div.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy_with_logits.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_fftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_hfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ihfft.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.p_str: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.default: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.tol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_float: - CastPolicy.FP32, - torch.ops.aten.linalg_solve.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svdvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvalsh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_householder_product.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorinv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorsolve.default: - CastPolicy.FP32, - torch.ops.aten.fake_quantize_per_tensor_affine.default: - CastPolicy.FP32, - torch.ops.aten.geqrf.default: - CastPolicy.FP32, - torch.ops.aten._lu_with_info.default: - CastPolicy.FP32, - torch.ops.aten.qr.default: - CastPolicy.FP32, - torch.ops.aten.svd.default: - CastPolicy.FP32, - 
torch.ops.aten.triangular_solve.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool2d.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.adaptive_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss_forward.default: - CastPolicy.FP32, - torch.ops.aten.linalg_qr.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky_ex.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svd.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eig.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_lstsq.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv_ex.default: - CastPolicy.FP32, - + torch.ops.aten.avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler.default: CastPolicy.FP32, + torch.ops.aten.polar.default: CastPolicy.FP32, + torch.ops.aten.prod.default: CastPolicy.FP32, + torch.ops.aten.prod.dim_int: CastPolicy.FP32, + torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32, + torch.ops.aten.quantile.default: CastPolicy.FP32, + torch.ops.aten.quantile.scalar: CastPolicy.FP32, + torch.ops.aten.nanquantile.default: CastPolicy.FP32, + torch.ops.aten.nanquantile.scalar: CastPolicy.FP32, + torch.ops.aten.stft.default: CastPolicy.FP32, + torch.ops.aten.stft.center: CastPolicy.FP32, + torch.ops.aten.cdist.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32, + torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32, + torch.ops.aten.trace.default: CastPolicy.FP32, + torch.ops.aten.view_as_complex.default: CastPolicy.FP32, + torch.ops.aten.cholesky.default: CastPolicy.FP32, + torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32, + torch.ops.aten.cholesky_solve.default: CastPolicy.FP32, + torch.ops.aten.inverse.default: CastPolicy.FP32, + torch.ops.aten.lu_solve.default: CastPolicy.FP32, + torch.ops.aten.orgqr.default: CastPolicy.FP32, + torch.ops.aten.ormqr.default: CastPolicy.FP32, + torch.ops.aten.pinverse.default: CastPolicy.FP32, + torch.ops.aten.max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool2d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad1d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad3d.default: CastPolicy.FP32, + torch.ops.aten.mse_loss.default: CastPolicy.FP32, + torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss2d.default: CastPolicy.FP32, + torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32, + torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32, + torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32, + torch.ops.aten.l1_loss.default: CastPolicy.FP32, + torch.ops.aten.huber_loss.default: CastPolicy.FP32, + torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32, + torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32, + 
torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32, + torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32, + torch.ops.aten.kl_div.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32, + torch.ops.aten.fft_fft.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft.default: CastPolicy.FP32, + torch.ops.aten.fft_fft2.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft2.default: CastPolicy.FP32, + torch.ops.aten.fft_fftn.default: CastPolicy.FP32, + torch.ops.aten.fft_ifftn.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_rfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_irfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_hfft.default: CastPolicy.FP32, + torch.ops.aten.fft_ihfft.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32, + torch.ops.aten.linalg_solve.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32, + torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv.default: CastPolicy.FP32, + torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32, + torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32, + torch.ops.aten.geqrf.default: CastPolicy.FP32, + torch.ops.aten._lu_with_info.default: CastPolicy.FP32, + torch.ops.aten.qr.default: CastPolicy.FP32, + torch.ops.aten.svd.default: CastPolicy.FP32, + torch.ops.aten.triangular_solve.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32, + torch.ops.aten.linalg_qr.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32, + torch.ops.aten.linalg_svd.default: CastPolicy.FP32, + torch.ops.aten.linalg_eig.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigh.default: CastPolicy.FP32, + torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32, # promote - torch.ops.aten.stack.default: - CastPolicy.PROMOTE, - torch.ops.aten.cat.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.dimname: - CastPolicy.PROMOTE, + torch.ops.aten.stack.default: CastPolicy.PROMOTE, + torch.ops.aten.cat.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE, } diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index 9370625e85cb..336563add738 100644 --- a/torchax/torchax/config.py +++ 
b/torchax/torchax/config.py @@ -3,24 +3,24 @@ @dataclasses.dataclass class Configuration: - debug_print_each_op: bool = False - debug_accuracy_for_each_op: bool = False - debug_mixed_tensor: bool = False - debug_print_each_op_operands: bool = False + debug_print_each_op: bool = False + debug_accuracy_for_each_op: bool = False + debug_mixed_tensor: bool = False + debug_print_each_op_operands: bool = False - use_int32_for_index: bool = False + use_int32_for_index: bool = False - # If true, we will convert Views into torchax.Tensors eagerly - force_materialize_views: bool = False + # If true, we will convert Views into torchax.Tensors eagerly + force_materialize_views: bool = False - # Use DLPack for converting jax.Arrays <-> and torch.Tensor - use_dlpack_for_data_conversion: bool = False + # Use DLPack for converting jax.Arrays <-> and torch.Tensor + use_dlpack_for_data_conversion: bool = False - # Flash attention - use_tpu_flash_attention: bool = False - shmap_flash_attention: bool = False + # Flash attention + use_tpu_flash_attention: bool = False + shmap_flash_attention: bool = False - # device - treat_cuda_as_jax_device: bool = True - use_torch_native_for_cpu_tensor: bool = True - internal_respect_torch_return_dtypes: bool = False + # device + treat_cuda_as_jax_device: bool = True + use_torch_native_for_cpu_tensor: bool = True + internal_respect_torch_return_dtypes: bool = False diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index 81cbcd02e3ac..369000661bb5 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -28,24 +28,23 @@ def _try_register(op, impl): - try: - register_decomposition(op)(impl) - except: - pass + try: + register_decomposition(op)(impl) + except: + pass @out_wrapper() def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return middle - 1 - (middle - 1 - dim_idx.abs()).abs() - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) + return _reflection_or_replication_pad( + a, + padding, + idx, + ) _try_register(aten.reflection_pad1d, _reflection_pad) @@ -55,20 +54,20 @@ def idx(left, middle, right): @out_wrapper() def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return torch.clamp(dim_idx, 0, middle - 1) - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) + return _reflection_or_replication_pad( + a, + padding, + idx, + ) decomp.global_decomposition_table["post_autograd"][ - aten.replication_pad2d.default] = _replication_pad + aten.replication_pad2d.default +] = _replication_pad def _reflection_or_replication_pad( @@ -76,27 +75,29 @@ def _reflection_or_replication_pad( padding: Tuple[int, ...], idx_fn: Callable[[int, int, int], Tensor], ) -> Tensor: - dim = len(padding) // 2 - torch._check( - a.dim() in (dim + 1, dim + 2), - lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", - ) - inp_shape = a.shape[-dim:] - nc_dim = a.dim() - dim - - padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] - padding_right = [padding[2 * 
(dim - 1 - i) + 1] for i in range(dim)] - - result = a - for i in range(dim): - idx: List[Any] = [None] * result.dim() - idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) - result = aten._unsafe_index(result, idx) - - # convert output to correct memory format, if necessary - memory_format = utils.suggest_memory_format(result) - result = result.contiguous(memory_format=memory_format) - return result + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: List[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn( + padding_left[i], inp_shape[i], padding_right[i] + ) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result _try_register(aten.replication_pad1d, _replication_pad) @@ -104,24 +105,24 @@ def _reflection_or_replication_pad( def bernoulli(self, *, generator=None): - return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) + return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) _try_register(aten.bernoulli.default, bernoulli) def rand_like(self, **kwargs): - dtype = kwargs.get("dtype", self.dtype) - return torch.rand(self.shape, dtype=dtype) + dtype = kwargs.get("dtype", self.dtype) + return torch.rand(self.shape, dtype=dtype) def channel_shuffle(self, groups): - batchsize, channels, height, width = self.shape - channels_per_group = channels // groups - self = self.reshape(batchsize, groups, channels_per_group, height, width) - self = self.transpose(1, 2) - self = self.reshape(batchsize, channels, height, width) - return self + batchsize, channels, height, width = self.shape + channels_per_group = channels // groups + self = self.reshape(batchsize, groups, channels_per_group, height, width) + self = self.transpose(1, 2) + self = self.reshape(batchsize, channels, height, width) + return self _try_register(aten.channel_shuffle, channel_shuffle) @@ -131,7 +132,7 @@ def channel_shuffle(self, groups): def bernoulli_float(self, p=0.5): - return self.bernoulli_(p) + return self.bernoulli_(p) _try_register(aten.bernoulli_.float, bernoulli_float) @@ -139,7 +140,7 @@ def bernoulli_float(self, p=0.5): def _sum_tensors(ts) -> Tensor: - return functools.reduce(torch.add, ts) + return functools.reduce(torch.add, ts) @register_decomposition(aten.grid_sampler_3d) @@ -150,141 +151,152 @@ def _grid_sampler_3d( padding_mode: int = 0, align_corners: bool = False, ) -> Tensor: - """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 - - The above implement the 2d case. 
- """ - _expand_grid = False - torch._check( - interpolation_mode in (0, 1), - lambda: f"Invalid interpolation mode {interpolation_mode}", - ) - torch._check( - padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}") - - # a is 5D: [B, C, D, H, W] - - def unnormalize(coords: Tensor, size: int) -> Tensor: - # Rescale coordinates from [-1, 1] to: - # [0, size - 1] if align_corners is True - # [-.5, size -.5] if align_corners is False - mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) - ofs = size * 0.5 - 0.5 - return coords * mul + ofs - - # Reflects coordinates until they fall between low and high (inclusive). - # The bounds are passed as twice their value so that half-integer values - # can be represented as ints. - def reflect_coordinates(coords: Tensor, twice_low: int, - twice_high: int) -> Tensor: - if twice_low == twice_high: - return torch.zeros_like(coords) - coords_min = twice_low / 2 - coords_span = (twice_high - twice_low) / 2 - coords2 = (coords - coords_min).abs() - extra = torch.fmod(coords2, coords_span) - flips = (coords2 / coords_span).floor().to(dtype=torch.int8) - return torch.where(flips & 1 == 0, extra + coords_min, - coords_span + coords_min - extra) - - def compute_coordinates(coords: Tensor, size: int) -> Tensor: - if padding_mode == 0: # Zero - return coords - elif padding_mode == 1: # Borders - return torch.clamp(coords, 0, size - 1) - else: # padding_mode == 2, Reflection - if align_corners: - coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) - else: - coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) - return torch.clamp(coords_reflected, 0, size - 1) - - def compute_source_index(coords: Tensor, size: int) -> Tensor: - coords_un = unnormalize(coords, size) - return compute_coordinates(coords_un, size) - - N, C, iD, iH, iW = a.shape - _, oD, oH, oW, three = grid.shape - assert three == 3, "Last dim of grid must be 3. 
got {}".format(three) - - def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: - xcheck = torch.logical_and(0 <= xs, xs < iW) - ycheck = torch.logical_and(0 <= ys, ys < iH) - zcheck = torch.logical_and(0 <= zs, zs < iD) - return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck)) - - N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) - C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) - - def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): - cond = in_bounds_cond(xs, ys, zs) - # To clip to inside valid coordinates, we map the coordinates - # to (x, y) = (0, 0) and also set the weight to 0 - # We also change the shape of the tensor to the appropriate one for - # broadcasting with N_idx, C_idx for the purposes of advanced indexing - c = C if _expand_grid else 1 - return tuple( - torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in ( - xs.to(dtype=torch.int64), - ys.to(dtype=torch.int64), - zs.to(dtype=torch.int64), - ws, - )) - - def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, - w) -> Tensor: - # Perform clipping, index into input tensor and multiply by weight - idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) - return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ - - x = grid[..., 0] - y = grid[..., 1] - d = grid[..., 2] - - if interpolation_mode == 0: # Bilinear - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - id_ = compute_source_index(d, iD) - - ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() - ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf - ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf - ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf - ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 - ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 - ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 - ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 - - w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) - w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_) - w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) - w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) - w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) - w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) - w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) - w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) - - return _sum_tensors( - get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in ( - (ix_nwf, iy_nwf, id_nwf, w_nwf), - (ix_nef, iy_nef, id_nef, w_nef), - (ix_swf, iy_swf, id_swf, w_swf), - (ix_sef, iy_sef, id_sef, w_sef), - (ix_nwb, iy_nwb, id_nwb, w_nwb), - (ix_neb, iy_neb, id_neb, w_neb), - (ix_swb, iy_swb, id_swb, w_swb), - (ix_seb, iy_seb, id_seb, w_seb), - )) - else: # interpolation_mode == 1: # Nearest - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - iz = compute_source_index(d, iD) - - ix_nearest = ix.round() - iy_nearest = iy.round() - iz_nearest = iz.round() - - return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) + """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 + + The above implement the 2d case. 
+ """ + _expand_grid = False + torch._check( + interpolation_mode in (0, 1), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), + lambda: f"Invalid padding mode {padding_mode}", + ) + + # a is 5D: [B, C, D, H, W] + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates( + coords: Tensor, twice_low: int, twice_high: int + ) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates( + coords, 0, 2 * (size - 1) + ) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iD, iH, iW = a.shape + _, oD, oH, oW, three = grid.shape + assert three == 3, "Last dim of grid must be 3. 
got {}".format(three) + + def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: + xcheck = torch.logical_and(0 <= xs, xs < iW) + ycheck = torch.logical_and(0 <= ys, ys < iH) + zcheck = torch.logical_and(0 <= zs, zs < iD) + return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck)) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) + + def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): + cond = in_bounds_cond(xs, ys, zs) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oD, oH, oW) + for t in ( + xs.to(dtype=torch.int64), + ys.to(dtype=torch.int64), + zs.to(dtype=torch.int64), + ws, + ) + ) + + def get_summand( + ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w + ) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) + return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + d = grid[..., 2] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + id_ = compute_source_index(d, iD) + + ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() + ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf + ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf + ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf + ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 + ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 + ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 + ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 + + w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) + w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_) + w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) + w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) + w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) + w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) + w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) + w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) + + return _sum_tensors( + get_summand(ix, iy, id_, w) + for (ix, iy, id_, w) in ( + (ix_nwf, iy_nwf, id_nwf, w_nwf), + (ix_nef, iy_nef, id_nef, w_nef), + (ix_swf, iy_swf, id_swf, w_swf), + (ix_sef, iy_sef, id_sef, w_sef), + (ix_nwb, iy_nwb, id_nwb, w_nwb), + (ix_neb, iy_neb, id_neb, w_neb), + (ix_swb, iy_swb, id_swb, w_swb), + (ix_seb, iy_seb, id_seb, w_seb), + ) + ) + else: # interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + iz = compute_source_index(d, iD) + + ix_nearest = ix.round() + iy_nearest = iy.round() + iz_nearest = iz.round() + + return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) DECOMPOSITIONS = decomp.get_decompositions([ diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py index 20fceaf06b43..2df98eb41423 100644 --- a/torchax/torchax/device_module.py +++ b/torchax/torchax/device_module.py @@ -1,26 +1,26 @@ def _is_in_bad_fork(): - return False + return False def manual_seed_all(seed): - pass + pass def device_count(): - return 1 + return 1 def get_rng_state(): - return [] + return [] def set_rng_state(new_state, device): - pass + pass def is_available(): - return 
True + return True def current_device(): - return 0 + return 0 diff --git a/torchax/torchax/distributed.py b/torchax/torchax/distributed.py index eb12f4eb2d56..a7dd0fea04fe 100644 --- a/torchax/torchax/distributed.py +++ b/torchax/torchax/distributed.py @@ -30,212 +30,225 @@ class ProcessGroupJax(ProcessGroup): - """Distributed backend implemented with JAX.""" - - def __init__(self, prefix_store, rank, size, timeout): - super().__init__(rank, size) - self._group_name = None - - def getBackendName(self): - return "jax" - - # TODO(wcromar): why doesn't default group name setter work? - # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152 - def _set_group_name(self, name: str) -> None: - self._group_name = name - - @property - def group_name(self): - assert self._group_name - return self._group_name - - @staticmethod - def _work( - tensors: Union[torch.Tensor, List[torch.Tensor], - List[List[torch.Tensor]]], - ) -> dist.Work: - fut = torch.futures.Future() - fut.set_result(tensors) - return torch._C._distributed_c10d._create_work_from_future(fut) - - def _allgather_base( - self, - output: torch.Tensor, - input: torch.Tensor, - opts=..., - ) -> dist.Work: - assert isinstance(input, torchax.tensor.Tensor) - assert isinstance(output, torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_gather_tensor_inplace( - output, input, group=self) - return self._work(output) - - def allreduce( - self, - tensors: List[torch.Tensor], - opts: dist.AllreduceOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_reduce_inplace( - tensors[0], - torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ - opts.reduceOp.op], + """Distributed backend implemented with JAX.""" + + def __init__(self, prefix_store, rank, size, timeout): + super().__init__(rank, size) + self._group_name = None + + def getBackendName(self): + return "jax" + + # TODO(wcromar): why doesn't default group name setter work? 
+ # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152 + def _set_group_name(self, name: str) -> None: + self._group_name = name + + @property + def group_name(self): + assert self._group_name + return self._group_name + + @staticmethod + def _work( + tensors: Union[ + torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]] + ], + ) -> dist.Work: + fut = torch.futures.Future() + fut.set_result(tensors) + return torch._C._distributed_c10d._create_work_from_future(fut) + + def _allgather_base( self, - ) + output: torch.Tensor, + input: torch.Tensor, + opts=..., + ) -> dist.Work: + assert isinstance(input, torchax.tensor.Tensor) + assert isinstance(output, torchax.tensor.Tensor) + torch.distributed._functional_collectives.all_gather_tensor_inplace( + output, input, group=self + ) + return self._work(output) + + def allreduce( + self, + tensors: List[torch.Tensor], + opts: dist.AllreduceOptions = ..., + ) -> dist.Work: + assert len(tensors) == 1 + assert isinstance(tensors[0], torchax.tensor.Tensor) + torch.distributed._functional_collectives.all_reduce_inplace( + tensors[0], + torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ + opts.reduceOp.op + ], + self, + ) - return self._work(tensors) + return self._work(tensors) - def broadcast( - self, - tensors: List[torch.Tensor], - opts: dist.BroadcastOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - tensors[0].copy_( - torch.distributed._functional_collectives.broadcast( - tensors[0], opts.rootRank, group=self)) + def broadcast( + self, + tensors: List[torch.Tensor], + opts: dist.BroadcastOptions = ..., + ) -> dist.Work: + assert len(tensors) == 1 + assert isinstance(tensors[0], torchax.tensor.Tensor) + tensors[0].copy_( + torch.distributed._functional_collectives.broadcast( + tensors[0], opts.rootRank, group=self + ) + ) - return self._work(tensors) + return self._work(tensors) dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"]) -def jax_rendezvous_handler(url: str, - timeout: datetime.timedelta = ..., - **kwargs): - """Initialize distributed store with JAX process IDs. - - Requires `$MASTER_ADDR` and `$MASTER_PORT`. - """ - # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU - # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part - # of their public Python API - master_ip = os.environ["MASTER_ADDR"] - master_port = int(os.environ["MASTER_PORT"]) - # TODO(wcromar): Use `torchrun`'s store if available - store = dist.TCPStore( - master_ip, - master_port, - jax.process_count(), - is_master=jax.process_index() == 0, - ) +def jax_rendezvous_handler( + url: str, timeout: datetime.timedelta = ..., **kwargs +): + """Initialize distributed store with JAX process IDs. + + Requires `$MASTER_ADDR` and `$MASTER_PORT`. + """ + # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU + # TODO(wcromar): Can we use the XLA coordinator as a Store? 
This isn't part + # of their public Python API + master_ip = os.environ["MASTER_ADDR"] + master_port = int(os.environ["MASTER_PORT"]) + # TODO(wcromar): Use `torchrun`'s store if available + store = dist.TCPStore( + master_ip, + master_port, + jax.process_count(), + is_master=jax.process_index() == 0, + ) - yield (store, jax.process_index(), jax.process_count()) + yield (store, jax.process_index(), jax.process_count()) dist.register_rendezvous_handler("jax", jax_rendezvous_handler) def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None): - """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. - `f` is expected to take the replica index as a positional argument, similar - to `torch.multiprocessing.spawn`. - Note: `spawn` does not actually create parallel processes. - """ - env = env or torchax.default_env() - - def jax_wrapper(index, jax_args): - index, args = env.j2t_iso([index, jax_args]) - torch_outputs = f(index, *args) - return env.t2j_iso(torch_outputs) - - jax_outputs = jax.pmap( - jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()), - env.t2j_iso(args)) - return env.j2t_iso(jax_outputs) + """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. + `f` is expected to take the replica index as a positional argument, similar + to `torch.multiprocessing.spawn`. + Note: `spawn` does not actually create parallel processes. + """ + env = env or torchax.default_env() + + def jax_wrapper(index, jax_args): + index, args = env.j2t_iso([index, jax_args]) + torch_outputs = f(index, *args) + return env.t2j_iso(torch_outputs) + + jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")( + np.arange(jax.device_count()), env.t2j_iso(args) + ) + return env.j2t_iso(jax_outputs) class DistributedDataParallel(torch.nn.Module): - """Re-implementation of DistributedDataParallel using JAX SPMD. - - Splits inputs along batch dimension (assumed to be 0) across all devices in - JAX runtime, including remote devices. Each process should load a distinct - shard of the input data using e.g. DistributedSampler. Each process' shard - is then further split among the addressable devices (e.g. local TPU chips) - by `shard_input`. - - Note: since parameters are replicated across addressable devices, inputs - must also be SPMD sharded using `shard_input` or `replicate_input`. 
- - Example usage: - - ``` - jax_model = torchax.distributed.DistributedDataParallel(create_model()) - for data, dataloader: - jax_data = jax_model.shard_input(data) - jax_output = jax_model(jax_data) - ``` - """ - - def __init__( - self, - module: torch.nn.Module, - env: Optional[torchax.tensor.Environment] = None, - **kwargs, - ): - if kwargs: - logging.warning(f"Unsupported kwargs {kwargs}") - - super().__init__() - self._env = env or torchax.default_env() - self._mesh = Mesh( - mesh_utils.create_device_mesh((jax.device_count(),)), - axis_names=("batch",), - ) - replicated_state = torch_pytree.tree_map_only( - torch.Tensor, - lambda t: self._env.j2t_iso( - jax.device_put( - self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))), - module.state_dict(), - ) - # TODO: broadcast - module.load_state_dict(replicated_state, assign=True) - self._module = module - - def shard_input(self, inp): - per_process_batch_size = inp.shape[0] # assumes batch dim is 0 - per_replica_batch_size = per_process_batch_size // jax.local_device_count() - per_replica_batches = torch.chunk(inp, jax.local_device_count()) - global_batch_size = per_replica_batch_size * jax.device_count() - global_batch_shape = (global_batch_size,) + inp.shape[1:] - - sharding = NamedSharding(self._mesh, P("batch")) - return self._env.j2t_iso( - jax.make_array_from_single_device_arrays( - global_batch_shape, - NamedSharding(self._mesh, P("batch")), - arrays=[ - jax.device_put(self._env.to_xla(batch)._elem, device) for batch, - device in zip(per_replica_batches, sharding.addressable_devices) - ], - )) - - def replicate_input(self, inp): - return self._env.j2t_iso( - jax.device_put(inp._elem, NamedSharding(self._mesh, P()))) + """Re-implementation of DistributedDataParallel using JAX SPMD. - def jit_step(self, func): + Splits inputs along batch dimension (assumed to be 0) across all devices in + JAX runtime, including remote devices. Each process should load a distinct + shard of the input data using e.g. DistributedSampler. Each process' shard + is then further split among the addressable devices (e.g. local TPU chips) + by `shard_input`. - @functools.partial( - interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0}) - def _jit_fn(states, args): - self.load_state_dict(states) - outputs = func(*args) - return self.state_dict(), outputs + Note: since parameters are replicated across addressable devices, inputs + must also be SPMD sharded using `shard_input` or `replicate_input`. 
- @functools.wraps(func) - def inner(*args): - jax_states = self.state_dict() - new_states, outputs = _jit_fn(jax_states, args) - self.load_state_dict(new_states) - return outputs + Example usage: - return inner + ``` + jax_model = torchax.distributed.DistributedDataParallel(create_model()) + for data, dataloader: + jax_data = jax_model.shard_input(data) + jax_output = jax_model(jax_data) + ``` + """ - def forward(self, *args): - with self._env: - return self._module(*args) + def __init__( + self, + module: torch.nn.Module, + env: Optional[torchax.tensor.Environment] = None, + **kwargs, + ): + if kwargs: + logging.warning(f"Unsupported kwargs {kwargs}") + + super().__init__() + self._env = env or torchax.default_env() + self._mesh = Mesh( + mesh_utils.create_device_mesh((jax.device_count(),)), + axis_names=("batch",), + ) + replicated_state = torch_pytree.tree_map_only( + torch.Tensor, + lambda t: self._env.j2t_iso( + jax.device_put( + self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()) + ) + ), + module.state_dict(), + ) + # TODO: broadcast + module.load_state_dict(replicated_state, assign=True) + self._module = module + + def shard_input(self, inp): + per_process_batch_size = inp.shape[0] # assumes batch dim is 0 + per_replica_batch_size = ( + per_process_batch_size // jax.local_device_count() + ) + per_replica_batches = torch.chunk(inp, jax.local_device_count()) + global_batch_size = per_replica_batch_size * jax.device_count() + global_batch_shape = (global_batch_size,) + inp.shape[1:] + + sharding = NamedSharding(self._mesh, P("batch")) + return self._env.j2t_iso( + jax.make_array_from_single_device_arrays( + global_batch_shape, + NamedSharding(self._mesh, P("batch")), + arrays=[ + jax.device_put(self._env.to_xla(batch)._elem, device) + for batch, device in zip( + per_replica_batches, sharding.addressable_devices + ) + ], + ) + ) + + def replicate_input(self, inp): + return self._env.j2t_iso( + jax.device_put(inp._elem, NamedSharding(self._mesh, P())) + ) + + def jit_step(self, func): + @functools.partial( + interop.jax_jit, kwargs_for_jax_jit={"donate_argnums": 0} + ) + def _jit_fn(states, args): + self.load_state_dict(states) + outputs = func(*args) + return self.state_dict(), outputs + + @functools.wraps(func) + def inner(*args): + jax_states = self.state_dict() + new_states, outputs = _jit_fn(jax_states, args) + self.load_state_dict(new_states) + return outputs + + return inner + + def forward(self, *args): + with self._env: + return self._module(*args) diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py index 987fb92ba6ee..d91636300b27 100644 --- a/torchax/torchax/export.py +++ b/torchax/torchax/export.py @@ -1,5 +1,6 @@ # pylint: disable """Utilities for exporting a torch program to jax/stablehlo.""" + import copy from typing import Any, Dict, Tuple import torch @@ -16,40 +17,41 @@ class JaxInterpreter(torch.fx.Interpreter): - """Experimental.""" - - def __init__(self, graph_module): - super().__init__(graph_module) - import torchax.ops.jaten - import torchax.ops.jtorch - - def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - if not isinstance(target, - (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): - return super().call_function(target, args, kwargs) - - if DEBUG: - print('Running ', target.name(), '--------') - - op = ops_registry.all_aten_ops.get(target) - if op is None: - op = ops_registry.all_aten_ops.get(target.overloadpacket) - assert op is not None, target - assert op.is_jax_function, op - if op is None: - op = 
ops_registry.all_aten_ops.get(target.overloadpacket) - if op is None: - print(target.name(), target.tags) - raise RuntimeError('No lowering found for', target.name()) - return op.func(*args, **kwargs) - - def run_node(self, n) -> Any: - res = super().run_node(n) - if DEBUG: - if n.op == 'call_function': - if hasattr(res, 'shape'): - print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape) - return res + """Experimental.""" + + def __init__(self, graph_module): + super().__init__(graph_module) + import torchax.ops.jaten + import torchax.ops.jtorch + + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: + if not isinstance( + target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ): + return super().call_function(target, args, kwargs) + + if DEBUG: + print("Running ", target.name(), "--------") + + op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + assert op is not None, target + assert op.is_jax_function, op + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + if op is None: + print(target.name(), target.tags) + raise RuntimeError("No lowering found for", target.name()) + return op.func(*args, **kwargs) + + def run_node(self, n) -> Any: + res = super().run_node(n) + if DEBUG: + if n.op == "call_function": + if hasattr(res, "shape"): + print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape) + return res from torch._decomp import get_decompositions @@ -59,187 +61,219 @@ def run_node(self, n) -> Any: def _extract_states_from_exported_program(exported_model): - # NOTE call convention: (parameters, buffers, user_inputs) - param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers - state_dict = copy.copy(exported_model.state_dict) - if (constants := getattr(exported_model, 'constants', None)) is not None: - state_dict.update(constants) - param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) + # NOTE call convention: (parameters, buffers, user_inputs) + param_and_buffer_keys = ( + exported_model.graph_signature.parameters + + exported_model.graph_signature.buffers + ) + state_dict = copy.copy(exported_model.state_dict) + if (constants := getattr(exported_model, "constants", None)) is not None: + state_dict.update(constants) + param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) - if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): - for name in exported_model.graph_signature.lifted_tensor_constants: - param_buffer_values.append(exported_model.tensor_constants[name]) + if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): + for name in exported_model.graph_signature.lifted_tensor_constants: + param_buffer_values.append(exported_model.tensor_constants[name]) - return param_and_buffer_keys, param_buffer_values + return param_and_buffer_keys, param_buffer_values def exported_program_to_jax(exported_program, export_raw: bool = False): - """returns a pytree of jax arrays(state), and - - a callable(func) that is jax function. - - func(state, input) would be how you call it. 
- """ - if torch.__version__ >= '2.2': - # torch version 2.1 didn't expose this yet - exported_program = exported_program.run_decompositions() - exported_program = exported_program.run_decompositions( - decompositions.DECOMPOSITIONS) - if DEBUG: - print(exported_program.graph_module.code) - - names, states = _extract_states_from_exported_program(exported_program) - - def _extract_args(args, kwargs): - flat_args, received_spec = pytree.tree_flatten( - (args, kwargs)) # type: ignore[possibly-undefined] - return flat_args - - num_mutations = len(exported_program.graph_signature.buffers_to_mutate) - - def func(states, inputs): - args = _extract_args(inputs, {}) - res = JaxInterpreter(exported_program.graph_module).run( - *states, - *args, - enable_io_processing=False, - ) - res = res[num_mutations:] - return res + """returns a pytree of jax arrays(state), and - if export_raw: - return names, states, func - env = torchax.default_env() - states = env.t2j_copy(states) - return states, func + a callable(func) that is jax function. + func(state, input) would be how you call it. + """ + if torch.__version__ >= "2.2": + # torch version 2.1 didn't expose this yet + exported_program = exported_program.run_decompositions() + exported_program = exported_program.run_decompositions( + decompositions.DECOMPOSITIONS + ) + if DEBUG: + print(exported_program.graph_module.code) -def extract_avals(exported): - """Return JAX Abstract Value shapes for all input parameters of the exported - program. This supports dynamic batch dimensions, including with constraints. - """ + names, states = _extract_states_from_exported_program(exported_program) - def _to_aval(arg_meta, symbolic_shapes): - """Convet from torch type to jax abstract value for export tracing - """ + def _extract_args(args, kwargs): + flat_args, received_spec = pytree.tree_flatten((args, kwargs)) # type: ignore[possibly-undefined] + return flat_args + + num_mutations = len(exported_program.graph_signature.buffers_to_mutate) + + def func(states, inputs): + args = _extract_args(inputs, {}) + res = JaxInterpreter(exported_program.graph_module).run( + *states, + *args, + enable_io_processing=False, + ) + res = res[num_mutations:] + return res - def _get_dim(d): - if isinstance(d, torch.SymInt): - return symbolic_shapes[str(d)] - return d - - val = arg_meta['val'] - is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance( - val, bool) - if is_scalar: - return jax.ShapeDtypeStruct([], type(arg_meta['val'])) - - tensor_meta = arg_meta['tensor_meta'] - shape = [_get_dim(d) for d in tensor_meta.shape] - return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype)) - - def _get_inputs(exported): - """Return placeholders with input metadata""" - placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"] - input_placeholders = [ - p for p, s in zip(placeholders, exported.graph_signature.input_specs) - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT - ] - return input_placeholders - - def _build_symbolic_shapes(range_constraints): - """Convert torch SymInt to JAX symbolic_shape and stores in a map using the - string name of the torch symbolic int. - - TODO: There is probably a better way of storing a key for a symbolic int. - This value needs to be looked up again in `_to_aval` to figure out which - JAX symbolic to map to for a given torch tensor. 
+ if export_raw: + return names, states, func + env = torchax.default_env() + states = env.t2j_copy(states) + return states, func + + +def extract_avals(exported): + """Return JAX Abstract Value shapes for all input parameters of the exported + program. This supports dynamic batch dimensions, including with constraints. """ - if len(range_constraints) == 0: - return None - - def _build_symbolic_constraints(symbol_name, torch_constraint): - """Convert torch SymInt constraints to string for JAX symbolic_shape - Using sympy may be overkill here, currently PyTorch only uses ValueRanges - which allow specifying the min and the max of a value, for example: - torch.export.Dim("a", min=5, max=10) - ==> ("a >= 5", "a <= 10",) - """ - if not isinstance(torch_constraint, torch.utils._sympy.value_ranges. - ValueRanges) or torch_constraint.is_bool: - raise TypeError( - f"No symbolic constraint handler for: {torch_constraint}") - - constraints = [] - symbol = sympy.Symbol(symbol_name) - if torch_constraint.lower != 2: - constraints.append(symbol >= torch_constraint.lower) - from sympy.core.singleton import S - if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity: - constraints.append(symbol <= torch_constraint.upper) - - return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) - - def _build_symbolic_shape(sym, constraint, free_symbols): - """Returns a JAX symbolic shape for a given symbol and constraint - - There are two possible sympy `sym` inputs: - 1. Symbol - (s0) These can have custom constraints. - 2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override. - - Currently support is limited to operations with a symbol and and int, - in `torch/export/dynamic_shapes.py`: - "Only increasing linear operations with integer coefficients are supported." - """ - symbol_name = str(sym) - constraints = _build_symbolic_constraints(symbol_name, constraint) - if sym.is_symbol: - symbolic_shape = jax.export.symbolic_shape( - symbol_name, constraints=constraints) - else: - assert len(sym.free_symbols) > 0 - scope = free_symbols[str(list(sym.free_symbols)[0])].scope - symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope) - assert len(symbolic_shape) == 1 - return symbolic_shape[0] - - # Populate symbol variables before expressions, exprs need to use the same - # Symbolic scope as the variable they operate on. Expressions can only be - # integer compuations on symbol variables, so each symbol variable is OK to - # have its own scope. 
- symbolic_shapes = {} - symbol_variables = [ - (s, v) for s, v in range_constraints.items() if s.is_symbol - ] - symbol_exprs = [ - (s, v) for s, v in range_constraints.items() if not s.is_symbol - ] - for sym, constraint in symbol_variables + symbol_exprs: - symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) - symbolic_shapes[str(sym)] = symbolic_shape - return symbolic_shapes - - symbolic_shapes = _build_symbolic_shapes(exported.range_constraints) - args = _get_inputs(exported) - - if DEBUG: - print('Inputs to aval:', args, '--------') - print('Symbolic shapes:', symbolic_shapes) - for arg in args: - print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes)) - - return [_to_aval(arg.meta, symbolic_shapes) for arg in args] + + def _to_aval(arg_meta, symbolic_shapes): + """Convet from torch type to jax abstract value for export tracing""" + + def _get_dim(d): + if isinstance(d, torch.SymInt): + return symbolic_shapes[str(d)] + return d + + val = arg_meta["val"] + is_scalar = ( + isinstance(val, float) + or isinstance(val, int) + or isinstance(val, bool) + ) + if is_scalar: + return jax.ShapeDtypeStruct([], type(arg_meta["val"])) + + tensor_meta = arg_meta["tensor_meta"] + shape = [_get_dim(d) for d in tensor_meta.shape] + return jax.ShapeDtypeStruct( + shape, mappings.t2j_dtype(tensor_meta.dtype) + ) + + def _get_inputs(exported): + """Return placeholders with input metadata""" + placeholders = [ + p for p in exported.graph.nodes if p.op == "placeholder" + ] + input_placeholders = [ + p + for p, s in zip(placeholders, exported.graph_signature.input_specs) + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + ] + return input_placeholders + + def _build_symbolic_shapes(range_constraints): + """Convert torch SymInt to JAX symbolic_shape and stores in a map using the + string name of the torch symbolic int. + + TODO: There is probably a better way of storing a key for a symbolic int. + This value needs to be looked up again in `_to_aval` to figure out which + JAX symbolic to map to for a given torch tensor. + """ + if len(range_constraints) == 0: + return None + + def _build_symbolic_constraints(symbol_name, torch_constraint): + """Convert torch SymInt constraints to string for JAX symbolic_shape + Using sympy may be overkill here, currently PyTorch only uses ValueRanges + which allow specifying the min and the max of a value, for example: + torch.export.Dim("a", min=5, max=10) + ==> ("a >= 5", "a <= 10",) + """ + if ( + not isinstance( + torch_constraint, + torch.utils._sympy.value_ranges.ValueRanges, + ) + or torch_constraint.is_bool + ): + raise TypeError( + f"No symbolic constraint handler for: {torch_constraint}" + ) + + constraints = [] + symbol = sympy.Symbol(symbol_name) + if torch_constraint.lower != 2: + constraints.append(symbol >= torch_constraint.lower) + from sympy.core.singleton import S + + if ( + not torch_constraint.upper.is_infinite + and torch_constraint.upper is not S.IntInfinity + ): + constraints.append(symbol <= torch_constraint.upper) + + return tuple( + sympy.pretty(c, use_unicode=False) for c in constraints + ) + + def _build_symbolic_shape(sym, constraint, free_symbols): + """Returns a JAX symbolic shape for a given symbol and constraint + + There are two possible sympy `sym` inputs: + 1. Symbol - (s0) These can have custom constraints. + 2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override. 
+
+            Currently support is limited to operations with a symbol and an int,
+            in `torch/export/dynamic_shapes.py`:
+            "Only increasing linear operations with integer coefficients are supported."
+            """
+            symbol_name = str(sym)
+            constraints = _build_symbolic_constraints(symbol_name, constraint)
+            if sym.is_symbol:
+                symbolic_shape = jax.export.symbolic_shape(
+                    symbol_name, constraints=constraints
+                )
+            else:
+                assert len(sym.free_symbols) > 0
+                scope = free_symbols[str(list(sym.free_symbols)[0])].scope
+                symbolic_shape = jax.export.symbolic_shape(
+                    symbol_name, scope=scope
+                )
+            assert len(symbolic_shape) == 1
+            return symbolic_shape[0]
+
+        # Populate symbol variables before expressions, exprs need to use the same
+        # Symbolic scope as the variable they operate on. Expressions can only be
+        # integer computations on symbol variables, so each symbol variable is OK to
+        # have its own scope.
+        symbolic_shapes = {}
+        symbol_variables = [
+            (s, v) for s, v in range_constraints.items() if s.is_symbol
+        ]
+        symbol_exprs = [
+            (s, v) for s, v in range_constraints.items() if not s.is_symbol
+        ]
+        for sym, constraint in symbol_variables + symbol_exprs:
+            symbolic_shape = _build_symbolic_shape(
+                sym, constraint, symbolic_shapes
+            )
+            symbolic_shapes[str(sym)] = symbolic_shape
+        return symbolic_shapes
+
+    symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
+    args = _get_inputs(exported)
+
+    if DEBUG:
+        print("Inputs to aval:", args, "--------")
+        print("Symbolic shapes:", symbolic_shapes)
+        for arg in args:
+            print(
+                "Meta2Aval",
+                arg.meta,
+                "--> ",
+                _to_aval(arg.meta, symbolic_shapes),
+            )
+
+    return [_to_aval(arg.meta, symbolic_shapes) for arg in args]


 def exported_program_to_stablehlo(exported_program):
-  """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo
+    """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo

-  Convert a program exported via torch.export to StableHLO.
+    Convert a program exported via torch.export to StableHLO.

-  This supports dynamic dimension sizes and generates explicit checks for
-  dynamo guards in the IR using shape_assertion custom_call ops.
-  """
-  weights, func = exported_program_to_jax(exported_program)
-  jax_avals = extract_avals(exported_program)
-  jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
-  return weights, jax_export
+    This supports dynamic dimension sizes and generates explicit checks for
+    dynamo guards in the IR using shape_assertion custom_call ops.
+ """ + weights, func = exported_program_to_jax(exported_program) + jax_avals = extract_avals(exported_program) + jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,)) + return weights, jax_export diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py index 28542d79c90e..1b10867890e6 100644 --- a/torchax/torchax/flax.py +++ b/torchax/torchax/flax.py @@ -6,34 +6,35 @@ class FlaxNNModule(torch.nn.Module): - - def __init__(self, env, flax_module, sample_args, sample_kwargs=None): - super().__init__() - prng = env.prng_key - sample_kwargs = sample_kwargs or {} - parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args, - **sample_kwargs) - - self._params = self._encode_nested_dict(parameter_dict) - - self._flax_module = flax_module - - def _encode_nested_dict(self, nested_dict): - child_module = torch.nn.Module() - for k, v in nested_dict.items(): - if isinstance(v, dict): - child_module.add_module(k, self._encode_nested_dict(v)) - else: - child_module.register_parameter(k, torch.nn.Parameter(v)) - return child_module - - def _decode_nested_dict(self, child_module): - result = dict(child_module.named_parameters(recurse=False)) - for k, v in child_module.named_children(): - result[k] = self._decode_nested_dict(v) - return result - - def forward(self, *args, **kwargs): - nested_dict_params = self._decode_nested_dict(self._params) - return tx.interop.call_jax(self._flax_module.apply, nested_dict_params, - *args, **kwargs) + def __init__(self, env, flax_module, sample_args, sample_kwargs=None): + super().__init__() + prng = env.prng_key + sample_kwargs = sample_kwargs or {} + parameter_dict = tx.interop.call_jax( + flax_module.init, prng, *sample_args, **sample_kwargs + ) + + self._params = self._encode_nested_dict(parameter_dict) + + self._flax_module = flax_module + + def _encode_nested_dict(self, nested_dict): + child_module = torch.nn.Module() + for k, v in nested_dict.items(): + if isinstance(v, dict): + child_module.add_module(k, self._encode_nested_dict(v)) + else: + child_module.register_parameter(k, torch.nn.Parameter(v)) + return child_module + + def _decode_nested_dict(self, child_module): + result = dict(child_module.named_parameters(recurse=False)) + for k, v in child_module.named_children(): + result[k] = self._decode_nested_dict(v) + return result + + def forward(self, *args, **kwargs): + nested_dict_params = self._decode_nested_dict(self._params) + return tx.interop.call_jax( + self._flax_module.apply, nested_dict_params, *args, **kwargs + ) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 419e3232773f..c7a1d50feea4 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -17,327 +17,339 @@ def extract_all_buffers(m: torch.nn.Module): - buffers = {} - params = {} - - def extract_one(module, prefix): - for k in dir(module): - try: - v = getattr(module, k) - except: - continue - qual_name = prefix + k - if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad: - params[qual_name] = v - elif isinstance(v, torch.Tensor): - buffers[qual_name] = v - for name, child in module.named_children(): - extract_one(child, prefix + name + '.') - - extract_one(m, '') - return params, buffers + buffers = {} + params = {} + + def extract_one(module, prefix): + for k in dir(module): + try: + v = getattr(module, k) + except: + continue + qual_name = prefix + k + if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad: + params[qual_name] = v + elif isinstance(v, torch.Tensor): + buffers[qual_name] 
= v + for name, child in module.named_children(): + extract_one(child, prefix + name + ".") + + extract_one(m, "") + return params, buffers def set_all_buffers(m, params, buffers): + def set_one(module, prefix): + for k in dir(module): + qual_name = prefix + k + if (potential_v := buffers.get(qual_name)) is not None: + setattr(module, k, potential_v) + elif (potential_v := params.get(qual_name)) is not None: + print(k, potential_v) + setattr(module, k, torch.nn.Parameter(potential_v)) + for name, child in module.named_children(): + set_one(child, prefix + name + ".") - def set_one(module, prefix): - for k in dir(module): - qual_name = prefix + k - if (potential_v := buffers.get(qual_name)) is not None: - setattr(module, k, potential_v) - elif (potential_v := params.get(qual_name)) is not None: - print(k, potential_v) - setattr(module, k, torch.nn.Parameter(potential_v)) - for name, child in module.named_children(): - set_one(child, prefix + name + '.') - - set_one(m, '') + set_one(m, "") class JittableModule(torch.nn.Module): - - def __init__(self, - m: torch.nn.Module, - extra_jit_args={}, - dedup_parameters=True): - super().__init__() - self.params, self.buffers = extract_all_buffers(m) - self._model = m - self._jitted = {} - - self._extra_jit_args = extra_jit_args - - self._extra_dumped_weights = {} - - if dedup_parameters: - temp = collections.defaultdict(list) - for k, v in self.params.items(): - temp[id(v)].append(k) - - for v in temp.values(): - if len(v) > 1: - # duplicated weights with different name - self._extra_dumped_weights[v[0]] = v[1:] - for extra_keys in v[1:]: - del self.params[extra_keys] - - @property - def __class__(self): - # Lie about the class type so that - # isinstance(jittable_module, self._model.__class__) works - return self._model.__class__ - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def functional_call(self, method_name, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - # reinflate the state dict so there are not any missing keys - for k, v in self._extra_dumped_weights.items(): - for new_key in v: - params_copy[new_key] = params_copy[k] - with torch_stateless._reparametrize_module(self._model, params_copy): - res = getattr(self._model, method_name)(*args, **kwargs) - return res - - def forward(self, *args, **kwargs): - if 'forward' not in self._jitted: - jitted = jax_jit( - functools.partial(self.functional_call, 'forward'), - kwargs_for_jax_jit=self._extra_jit_args, - ) - - def jitted_forward(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) - - self._jitted['forward'] = jitted_forward - return self._jitted['forward'](*args, **kwargs) - - def __getattr__(self, key): - if key == '_model': - return super().__getattr__(key) - if key in self._jitted: - return self._jitted[key] - return getattr(self._model, key) - - def make_jitted(self, key): - jitted = jax_jit( - functools.partial(self.functional_call, key), - kwargs_for_jax_jit=self._extra_jit_args) - - def call(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) - - self._jitted[key] = call + def __init__( + self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True + ): + super().__init__() + self.params, self.buffers = extract_all_buffers(m) + self._model = m + self._jitted = {} + + self._extra_jit_args = extra_jit_args + + self._extra_dumped_weights = {} + + if dedup_parameters: + temp = collections.defaultdict(list) + for k, v in 
self.params.items(): + temp[id(v)].append(k) + + for v in temp.values(): + if len(v) > 1: + # duplicated weights with different name + self._extra_dumped_weights[v[0]] = v[1:] + for extra_keys in v[1:]: + del self.params[extra_keys] + + @property + def __class__(self): + # Lie about the class type so that + # isinstance(jittable_module, self._model.__class__) works + return self._model.__class__ + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def functional_call(self, method_name, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + # reinflate the state dict so there are not any missing keys + for k, v in self._extra_dumped_weights.items(): + for new_key in v: + params_copy[new_key] = params_copy[k] + with torch_stateless._reparametrize_module(self._model, params_copy): + res = getattr(self._model, method_name)(*args, **kwargs) + return res + + def forward(self, *args, **kwargs): + if "forward" not in self._jitted: + jitted = jax_jit( + functools.partial(self.functional_call, "forward"), + kwargs_for_jax_jit=self._extra_jit_args, + ) + + def jitted_forward(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + + self._jitted["forward"] = jitted_forward + return self._jitted["forward"](*args, **kwargs) + + def __getattr__(self, key): + if key == "_model": + return super().__getattr__(key) + if key in self._jitted: + return self._jitted[key] + return getattr(self._model, key) + + def make_jitted(self, key): + jitted = jax_jit( + functools.partial(self.functional_call, key), + kwargs_for_jax_jit=self._extra_jit_args, + ) + + def call(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + + self._jitted[key] = call class CompileMixin: + def functional_call(self, method, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + with torch_stateless._reparametrize_module(self, params_copy): + res = method(*args, **kwargs) + return res - def functional_call(self, method, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - with torch_stateless._reparametrize_module(self, params_copy): - res = method(*args, **kwargs) - return res - - def jit(self, method): - jitted = jax_jit(functools.partial(self.functional_call, method_name)) + def jit(self, method): + jitted = jax_jit(functools.partial(self.functional_call, method_name)) - def call(*args, **kwargs): - return jitted(self.named_paramters(), self.named_buffers(), *args, - **kwargs) + def call(*args, **kwargs): + return jitted( + self.named_paramters(), self.named_buffers(), *args, **kwargs + ) - return call + return call def compile_nn_module(m: torch.nn.Module, methods=None): - if methods is None: - methods = ['forward'] + if methods is None: + methods = ["forward"] - new_parent = type( - m.__class__.__name__ + '_with_CompileMixin', - (CompileMixin, m.__class__), - ) - m.__class__ = NewParent + new_parent = type( + m.__class__.__name__ + "_with_CompileMixin", + (CompileMixin, m.__class__), + ) + m.__class__ = NewParent def _torch_view(t: JaxValue) -> TorchValue: - # t is an object from jax land - # view it as-if it's a torch land object - if isinstance(t, jax.Array): - # TODO - return tensor.Tensor(t, torchax.default_env()) - if isinstance(t, type(jnp.int32)): - return tensor.t2j_type(t) - if callable(t): # t is a JaxCallable - return 
functools.partial(call_jax, t) - # regular types are not changed - return t + # t is an object from jax land + # view it as-if it's a torch land object + if isinstance(t, jax.Array): + # TODO + return tensor.Tensor(t, torchax.default_env()) + if isinstance(t, type(jnp.int32)): + return tensor.t2j_type(t) + if callable(t): # t is a JaxCallable + return functools.partial(call_jax, t) + # regular types are not changed + return t torch_view = functools.partial(pytree.tree_map, _torch_view) def _jax_view(t: TorchValue) -> JaxValue: - # t is an object from torch land - # view it as-if it's a jax land object - if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t) - return t.jax() - if isinstance(t, type(torch.int32)): - return tensor.t2j_dtype(t) - - # torch.nn.Module needs special handling - if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable - return functools.partial(call_torch, t) - # regular types are not changed - return t + # t is an object from torch land + # view it as-if it's a jax land object + if isinstance(t, torch.Tensor): + assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type( + t + ) + return t.jax() + if isinstance(t, type(torch.int32)): + return tensor.t2j_dtype(t) + + # torch.nn.Module needs special handling + if not isinstance(t, torch.nn.Module) and callable( + t + ): # t is a TorchCallable + return functools.partial(call_torch, t) + # regular types are not changed + return t jax_view = functools.partial(pytree.tree_map, _jax_view) -def call_jax(jax_func: JaxCallable, *args: TorchValue, - **kwargs: TorchValue) -> TorchValue: - args, kwargs = jax_view((args, kwargs)) - res: JaxValue = jax_func(*args, **kwargs) - return torch_view(res) +def call_jax( + jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue +) -> TorchValue: + args, kwargs = jax_view((args, kwargs)) + res: JaxValue = jax_func(*args, **kwargs) + return torch_view(res) -def call_torch(torch_func: TorchCallable, *args: JaxValue, - **kwargs: JaxValue) -> JaxValue: - args, kwargs = torch_view((args, kwargs)) - with torchax.default_env(): - res: TorchValue = torch_func(*args, **kwargs) - return jax_view(res) +def call_torch( + torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue +) -> JaxValue: + args, kwargs = torch_view((args, kwargs)) + with torchax.default_env(): + res: TorchValue = torch_func(*args, **kwargs) + return jax_view(res) def j2t_autograd(fn, call_jax=call_jax): - """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. + """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate activations). The wrapped function is then run via `call_jax` and integrated into the PyTorch autograd framework by saving the residuals into the context object. """ - @wraps(fn) - def inner(*args, **kwargs): - from jax.tree_util import tree_flatten, tree_unflatten - from jax.util import safe_zip - - class JaxFun(torch.autograd.Function): - - @staticmethod - def forward(ctx, tree_def, *flat_args_kwargs): - - tensors, other = util.partition(flat_args_kwargs, - lambda x: isinstance(x, torch.Tensor)) - # We want the arguments that don't require grads to be closured? - - y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors) - - # Save necessary information for backward - # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. 
- # `residuals` contains the tensors needed for the backward pass.` - residuals, vjp_spec = tree_flatten(fun_vjp) - ctx.vjp_spec = vjp_spec - ctx.save_for_backward(*residuals) + @wraps(fn) + def inner(*args, **kwargs): + from jax.tree_util import tree_flatten, tree_unflatten + from jax.util import safe_zip + + class JaxFun(torch.autograd.Function): + @staticmethod + def forward(ctx, tree_def, *flat_args_kwargs): + tensors, other = util.partition( + flat_args_kwargs, lambda x: isinstance(x, torch.Tensor) + ) + # We want the arguments that don't require grads to be closured? + + y, fun_vjp = call_jax( + _jax_forward, fn, other, tree_def, tensors + ) + + # Save necessary information for backward + # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. + # `residuals` contains the tensors needed for the backward pass.` + residuals, vjp_spec = tree_flatten(fun_vjp) + ctx.vjp_spec = vjp_spec + ctx.save_for_backward(*residuals) + return y + + @staticmethod + def backward(ctx, *grad_out): + assert len(grad_out) > 0 + grad_out = grad_out if len(grad_out) > 1 else grad_out[0] + + input_grads_structured = call_jax( + _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out + ) + + # Construct the gradient tuple to be returned. + # It needs to match the inputs to forward: (tree_def, *flat_inputs) + # The first gradient (for tree_def) is None. + # The subsequent gradients correspond to flat_inputs. + # We need to put a None for inputs that did not require gradients. + final_grads = [None] + for needs_grad, grad in safe_zip( + ctx.needs_input_grad[1:], input_grads_structured + ): + final_grads.append(grad if needs_grad else None) + + return tuple(final_grads) + + sig = signature(fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) + y = JaxFun.apply(tree_def, *flat_args_kwargs) return y - @staticmethod - def backward(ctx, *grad_out): - assert len(grad_out) > 0 - grad_out = grad_out if len(grad_out) > 1 else grad_out[0] - - input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec, - ctx.saved_tensors, grad_out) - - # Construct the gradient tuple to be returned. - # It needs to match the inputs to forward: (tree_def, *flat_inputs) - # The first gradient (for tree_def) is None. - # The subsequent gradients correspond to flat_inputs. - # We need to put a None for inputs that did not require gradients. - final_grads = [None] - for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:], - input_grads_structured): - final_grads.append(grad if needs_grad else None) - - return tuple(final_grads) - - sig = signature(fn) - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) - y = JaxFun.apply(tree_def, *flat_args_kwargs) - return y - - return inner + return inner # NOTE(qihqi): This function cannot be inlined from the callsite # Becuase if it does, then it won't hit the compilation cache for # call_jax. Call jax uses functions' id as key. def _jax_forward(fn, other, tree_def, tensors): - """JAX function to compute output and vjp function. + """JAX function to compute output and vjp function. - primals should be a tuple (args, kwargs). - """ - import jax - from jax.tree_util import tree_flatten, tree_unflatten + primals should be a tuple (args, kwargs). 
+ """ + import jax + from jax.tree_util import tree_flatten, tree_unflatten - def fn_wrapper(*tensors): - # Reconstruct the original args and kwargs - flat_inputs = util.merge(tensors, other) - args, kwargs = tree_unflatten(tree_def, flat_inputs) - return fn(*args, **kwargs) + def fn_wrapper(*tensors): + # Reconstruct the original args and kwargs + flat_inputs = util.merge(tensors, other) + args, kwargs = tree_unflatten(tree_def, flat_inputs) + return fn(*args, **kwargs) - return jax.vjp(fn_wrapper, *tensors) + return jax.vjp(fn_wrapper, *tensors) def _jax_backward(vjp_spec, saved_tensors, grad_out): - """JAX function to compute input gradients. + """JAX function to compute input gradients. + + Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. + """ + from jax.tree_util import tree_unflatten - Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. - """ - from jax.tree_util import tree_unflatten - fun_vjp = tree_unflatten(vjp_spec, saved_tensors) - return fun_vjp(grad_out) + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) + return fun_vjp(grad_out) fori_loop = torch_view(jax.lax.fori_loop) def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): - kwargs_for_jax = kwargs_for_jax or {} - jax_func = jax_view(torch_function) - jitted = jax_jit_func(jax_func, **kwargs_for_jax) - return torch_view(jitted) + kwargs_for_jax = kwargs_for_jax or {} + jax_func = jax_view(torch_function) + jitted = jax_jit_func(jax_func, **kwargs_for_jax) + return torch_view(jitted) -def jax_jit(torch_function, - kwargs_for_jax_jit=None, - fix_for_buffer_donation=False): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit) +def jax_jit( + torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False +): + return wrap_jax_jit( + torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit + ) def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=shard_map, - kwargs_for_jax=kwargs_for_jax_shard_map) + return wrap_jax_jit( + torch_function, + jax_jit_func=shard_map, + kwargs_for_jax=kwargs_for_jax_shard_map, + ) def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=jax.value_and_grad, - kwargs_for_jax=kwargs_for_value_and_grad) + return wrap_jax_jit( + torch_function, + jax_jit_func=jax.value_and_grad, + kwargs_for_jax=kwargs_for_value_and_grad, + ) def gradient_checkpoint(torch_function, kwargs=None): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs) + return wrap_jax_jit( + torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs + ) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index 3f65b8440b59..281f5fd80da5 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -6,206 +6,208 @@ def _shard_first_multiple_of(axis_name, shape, multiple_of): - """Creates a PartitionSpec to shard the first dimension divisible by a number. - - Iterates through the dimensions specified by `shape`. Finds the first dimension - whose size is a multiple of `multiple_of` and returns a PartitionSpec that - shards that dimension along the given `axis_name`. All preceding dimensions - are not sharded (marked as None in the PartitionSpec). All subsequent dimensions - skipped, which would be implicitly treated as replicated. 
- - Args: - axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl"). - shape: A tuple or list representing the shape of the tensor to be sharded. - multiple_of: The integer value that a dimension size must be divisible by - in order to be sharded. Typically the size of the mesh axis. - - Returns: - A jax.sharding.PartitionSpec object specifying how to shard the tensor. - For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, - it would return PartitionSpec(None, 'x', None). - If none divides then it should return a replicated PartitionSpec - """ - sharding = [] - found = False - for size in shape: - if not found and size % multiple_of == 0: - found = True - sharding.append(axis_name) - else: - sharding.append(None) - return PartitionSpec(*sharding) + """Creates a PartitionSpec to shard the first dimension divisible by a number. - -class SingleAxisSharder: - """A callable object that generates PartitionSpecs for single-axis sharding. - - This sharder strategy attempts to shard the *first* dimension of a tensor - that is divisible by the specified `axis_size` along the given `axis_name`. - It's useful for simple 1D mesh sharding scenarios like FSDP where parameters - are typically sharded along one dimension. - - Attributes: - axis_name: The name of the mesh axis to shard along. - axis_size: The size of the mesh axis (number of devices along that axis). - """ - - def __init__(self, axis_name, axis_size, replicate_unshardable=False): - """Initializes the SingleAxisSharder. - - Args: - axis_name: The name of the mesh axis (e.g., "fsdp", "data"). - axis_size: The number of devices along the specified mesh axis. - replicate_unshardable: indicate whether it should return replicated sharding - (P()) when none of the axis is divisible by the axis size. - """ - self.axis_name = axis_name - self.axis_size = axis_size - self.replicate_unshardable = replicate_unshardable - - def __call__(self, name, shapedtype): - """Generates a PartitionSpec for a given tensor name and shaped type. + Iterates through the dimensions specified by `shape`. Finds the first dimension + whose size is a multiple of `multiple_of` and returns a PartitionSpec that + shards that dimension along the given `axis_name`. All preceding dimensions + are not sharded (marked as None in the PartitionSpec). All subsequent dimensions + skipped, which would be implicitly treated as replicated. Args: - name: The name of the tensor (e.g., parameter name). This argument is - provided for compatibility with more complex sharders but is not used - by this simple sharder. - shapedtype: An object with a `.shape` attribute describing the tensor's shape, - and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct - or a torch.Tensor) + axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl"). + shape: A tuple or list representing the shape of the tensor to be sharded. + multiple_of: The integer value that a dimension size must be divisible by + in order to be sharded. Typically the size of the mesh axis. Returns: - A jax.sharding.PartitionSpec determined by finding the first dimension - in `shapedtype.shape` divisible by `self.axis_size` using the helper - `_shard_first_multiple_of`. + A jax.sharding.PartitionSpec object specifying how to shard the tensor. + For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, + it would return PartitionSpec(None, 'x', None). 
+ If none divides then it should return a replicated PartitionSpec """ - del name - sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape, - self.axis_size) - if not self.replicate_unshardable and all(s is None for s in sharding): - raise AssertionError( - f"Unable to find a dim to shard because " - f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" - ) - return sharding - - -class Mesh: - """A helper class that wraps `jax.sharding.Mesh` object. - - The goal of this class is to provide helper methods that facilitate the - sharding of PyTorch tensors or models given a JAX device mesh configuration. - It simplifies initializing models directly into a sharded state. - - Attributes: - jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid - and axis names. - _sharder: The default sharding strategy callable (like SingleAxisSharder) - used to determine the PartitionSpec for each parameter if not overridden - during method calls. Can be None if no default is appropriate or set. - """ + sharding = [] + found = False + for size in shape: + if not found and size % multiple_of == 0: + found = True + sharding.append(axis_name) + else: + sharding.append(None) + return PartitionSpec(*sharding) - @classmethod - def fsdp_mesh(cls, axis_name="fsdp"): - """Creates a Mesh instance suitable for 1D FSDP-style sharding. - This named constructor creates a 1D mesh encompassing all available XLA - devices. It assigns the specified `axis_name` to this single dimension. - It then creates a `Mesh` instance using this JAX mesh and a - `SingleAxisSharder` configured appropriately for this 1D mesh. +class SingleAxisSharder: + """A callable object that generates PartitionSpecs for single-axis sharding. - Args: - axis_name: The name to assign to the single mesh axis (default: "fsdp"). - This name will be used by the default `SingleAxisSharder`. + This sharder strategy attempts to shard the *first* dimension of a tensor + that is divisible by the specified `axis_size` along the given `axis_name`. + It's useful for simple 1D mesh sharding scenarios like FSDP where parameters + are typically sharded along one dimension. - Returns: - A Mesh instance configured with a 1D JAX mesh across all devices and a - corresponding SingleAxisSharder. + Attributes: + axis_name: The name of the mesh axis to shard along. + axis_size: The size of the mesh axis (number of devices along that axis). """ - ndevice = jax.device_count() - jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) - # replicate_unshardable so scalars and small model attributes are replicated. - return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) - def __init__(self, jax_mesh, sharder=None): - """Initializes the Mesh helper. + def __init__(self, axis_name, axis_size, replicate_unshardable=False): + """Initializes the SingleAxisSharder. + + Args: + axis_name: The name of the mesh axis (e.g., "fsdp", "data"). + axis_size: The number of devices along the specified mesh axis. + replicate_unshardable: indicate whether it should return replicated sharding + (P()) when none of the axis is divisible by the axis size. + """ + self.axis_name = axis_name + self.axis_size = axis_size + self.replicate_unshardable = replicate_unshardable + + def __call__(self, name, shapedtype): + """Generates a PartitionSpec for a given tensor name and shaped type. + + Args: + name: The name of the tensor (e.g., parameter name). 
This argument is + provided for compatibility with more complex sharders but is not used + by this simple sharder. + shapedtype: An object with a `.shape` attribute describing the tensor's shape, + and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct + or a torch.Tensor) + + Returns: + A jax.sharding.PartitionSpec determined by finding the first dimension + in `shapedtype.shape` divisible by `self.axis_size` using the helper + `_shard_first_multiple_of`. + """ + del name + sharding = _shard_first_multiple_of( + self.axis_name, shapedtype.shape, self.axis_size + ) + if not self.replicate_unshardable and all(s is None for s in sharding): + raise AssertionError( + f"Unable to find a dim to shard because " + f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" + ) + return sharding - Args: - jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the - physical device grid and logical axis names. - sharder: An optional callable (e.g., an instance of SingleAxisSharder) - that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. - This serves as the default sharding strategy. - If None, and the provided `jax_mesh` has exactly one axis, a - `SingleAxisSharder` is created automatically for that single axis. - If None and the mesh has multiple axes, `_sharder` remains None, and - an `override_sharder` must be provided to methods like - `initialize_model_sharded`. - """ - self.jax_mesh = jax_mesh - if sharder is None: - assert len(self.jax_mesh.axis_names) == 1 - sharder = SingleAxisSharder(self.jax_mesh.axis_names[0], - len(self.mesh.device_ids)) - self._sharder = sharder - - def initialize_model_sharded(self, - model_class, - init_args, - init_kwargs=None, - override_sharder=None): - """Initializes a PyTorch model with its parameters sharded across the mesh. - - This method orchestrates the initialization of a `torch.nn.Module` such - that its parameters are created directly on the target devices according - to the sharding specifications derived from the mesh and the chosen sharder. - It leverages `torchax.interop.jax_jit` to achieve this. - - Args: - model_class: The PyTorch model class (a subclass of `torch.nn.Module`). - init_args: A tuple containing the positional arguments required by the - `model_class.__init__` method. - init_kwargs: An optional dictionary containing the keyword arguments for - the `model_class.__init__` method. Defaults to None (treated as {}). - override_sharder: An optional callable sharding strategy to use - specifically for this initialization. If provided, it takes precedence - over the mesh's default `_sharder`. It must accept `(name, shapedtype)` - and return a `PartitionSpec`. If None, the mesh's default `_sharder` - is used. - Returns: - An instance of `model_class` whose parameters have been initialized and - are represented by sharded tensors distributed across the devices in the - `jax_mesh`. - - Raises: - ValueError: If no sharder is available (i.e., `override_sharder` is None - and the mesh's default `_sharder` is also None). - AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) - if it fails to determine a valid sharding for any parameter. - TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. - Other errors from JAX JIT compilation or PyTorch model initialization. +class Mesh: + """A helper class that wraps `jax.sharding.Mesh` object. 
+ + The goal of this class is to provide helper methods that facilitate the + sharding of PyTorch tensors or models given a JAX device mesh configuration. + It simplifies initializing models directly into a sharded state. + + Attributes: + jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid + and axis names. + _sharder: The default sharding strategy callable (like SingleAxisSharder) + used to determine the PartitionSpec for each parameter if not overridden + during method calls. Can be None if no default is appropriate or set. """ - init_kwargs = init_kwargs or {} - with torch.device("meta"), torchax.disable_temporarily(): - model = model_class(*init_args, **init_kwargs) - - sharder = override_sharder or self._sharder - - states = model.state_dict() - output_shards = { - name: NamedSharding(self.jax_mesh, sharder(name, tensor)) - for name, tensor in states.items() - } - - def model_initializer(): - with torchax.default_env(): - model = model_class(*init_args, **init_kwargs) - return dict(model.state_dict()) - - jitted = interop.jax_jit( - model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}) - weights_dict = jitted() - model.load_state_dict(weights_dict, assign=True) - return model + @classmethod + def fsdp_mesh(cls, axis_name="fsdp"): + """Creates a Mesh instance suitable for 1D FSDP-style sharding. + + This named constructor creates a 1D mesh encompassing all available XLA + devices. It assigns the specified `axis_name` to this single dimension. + It then creates a `Mesh` instance using this JAX mesh and a + `SingleAxisSharder` configured appropriately for this 1D mesh. + + Args: + axis_name: The name to assign to the single mesh axis (default: "fsdp"). + This name will be used by the default `SingleAxisSharder`. + + Returns: + A Mesh instance configured with a 1D JAX mesh across all devices and a + corresponding SingleAxisSharder. + """ + ndevice = jax.device_count() + jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) + # replicate_unshardable so scalars and small model attributes are replicated. + return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) + + def __init__(self, jax_mesh, sharder=None): + """Initializes the Mesh helper. + + Args: + jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the + physical device grid and logical axis names. + sharder: An optional callable (e.g., an instance of SingleAxisSharder) + that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. + This serves as the default sharding strategy. + If None, and the provided `jax_mesh` has exactly one axis, a + `SingleAxisSharder` is created automatically for that single axis. + If None and the mesh has multiple axes, `_sharder` remains None, and + an `override_sharder` must be provided to methods like + `initialize_model_sharded`. + """ + self.jax_mesh = jax_mesh + if sharder is None: + assert len(self.jax_mesh.axis_names) == 1 + sharder = SingleAxisSharder( + self.jax_mesh.axis_names[0], len(self.mesh.device_ids) + ) + self._sharder = sharder + + def initialize_model_sharded( + self, model_class, init_args, init_kwargs=None, override_sharder=None + ): + """Initializes a PyTorch model with its parameters sharded across the mesh. + + This method orchestrates the initialization of a `torch.nn.Module` such + that its parameters are created directly on the target devices according + to the sharding specifications derived from the mesh and the chosen sharder. + It leverages `torchax.interop.jax_jit` to achieve this. 
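+
+        A minimal usage sketch (illustrative only; `MyModule` and `config`
+        are placeholder names, not part of this library):
+
+            mesh = Mesh.fsdp_mesh()
+            model = mesh.initialize_model_sharded(MyModule, (config,))
+
+        The returned `MyModule` instance holds parameters that are already
+        sharded across the devices of `jax_mesh`.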
+ + Args: + model_class: The PyTorch model class (a subclass of `torch.nn.Module`). + init_args: A tuple containing the positional arguments required by the + `model_class.__init__` method. + init_kwargs: An optional dictionary containing the keyword arguments for + the `model_class.__init__` method. Defaults to None (treated as {}). + override_sharder: An optional callable sharding strategy to use + specifically for this initialization. If provided, it takes precedence + over the mesh's default `_sharder`. It must accept `(name, shapedtype)` + and return a `PartitionSpec`. If None, the mesh's default `_sharder` + is used. + + Returns: + An instance of `model_class` whose parameters have been initialized and + are represented by sharded tensors distributed across the devices in the + `jax_mesh`. + + Raises: + ValueError: If no sharder is available (i.e., `override_sharder` is None + and the mesh's default `_sharder` is also None). + AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) + if it fails to determine a valid sharding for any parameter. + TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. + Other errors from JAX JIT compilation or PyTorch model initialization. + """ + init_kwargs = init_kwargs or {} + with torch.device("meta"), torchax.disable_temporarily(): + model = model_class(*init_args, **init_kwargs) + + sharder = override_sharder or self._sharder + + states = model.state_dict() + output_shards = { + name: NamedSharding(self.jax_mesh, sharder(name, tensor)) + for name, tensor in states.items() + } + + def model_initializer(): + with torchax.default_env(): + model = model_class(*init_args, **init_kwargs) + return dict(model.state_dict()) + + jitted = interop.jax_jit( + model_initializer, + kwargs_for_jax_jit={"out_shardings": output_shards}, + ) + weights_dict = jitted() + + model.load_state_dict(weights_dict, assign=True) + return model diff --git a/torchax/torchax/ops/__init__.py b/torchax/torchax/ops/__init__.py index 71c1b137132f..d306871dd7ac 100644 --- a/torchax/torchax/ops/__init__.py +++ b/torchax/torchax/ops/__init__.py @@ -1,10 +1,10 @@ def all_aten_jax_ops(): - # to load the ops - import torchax.ops.jaten # type: ignore - import torchax.ops.ops_registry # type: ignore + # to load the ops + import torchax.ops.jaten # type: ignore + import torchax.ops.ops_registry # type: ignore - return { - key: val.func - for key, val in torchax.ops.ops_registry.all_aten_ops.items() - if val.is_jax_function - } + return { + key: val.func + for key, val in torchax.ops.ops_registry.all_aten_ops.items() + if val.is_jax_function + } diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 8d2242fdb59b..1f7e4a06902b 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -16,33 +16,37 @@ from torchax import interop from torchax.ops import jax_reimplement from torchax.view import View + # Keys are OpOverload, value is a callable that takes # Tensor all_ops = {} def op(*aten, **kwargs): - - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - continue - - if isinstance(a, torch._ops.OpOverloadPacket): - opname = a.default.name() if 'default' in a.overloads( - ) else a._qualified_op_name - elif isinstance(a, torch._ops.OpOverload): - opname = a.name() - else: - raise RuntimeError(f'oops {a}') - - torchfunc = functools.partial(interop.call_jax, func) - # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor - 
torch.library.impl(opname, 'privateuseone')( - torchfunc if a != torch.ops.aten._to_copy else func) - return func - - return inner + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + continue + + if isinstance(a, torch._ops.OpOverloadPacket): + opname = ( + a.default.name() + if "default" in a.overloads() + else a._qualified_op_name + ) + elif isinstance(a, torch._ops.OpOverload): + opname = a.name() + else: + raise RuntimeError(f"oops {a}") + + torchfunc = functools.partial(interop.call_jax, func) + # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor + torch.library.impl(opname, "privateuseone")( + torchfunc if a != torch.ops.aten._to_copy else func + ) + return func + + return inner @op( @@ -52,542 +56,549 @@ def inner(func): torch.ops.aten.reshape, ) def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) + return jnp.reshape(x, shape) @op(torch.ops.aten.add.Tensor) @op(torch.ops.aten.add.Scalar) def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - res = x + y * alpha - if isinstance(x, float) or isinstance(y, float): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = res.astype(new_dtype) - return res + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + res = x + y * alpha + if isinstance(x, float) or isinstance(y, float): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = res.astype(new_dtype) + return res -@op(torch.ops.aten.copy_, - is_jax_function=False, - is_view_op=True, - needs_env=True) +@op( + torch.ops.aten.copy_, is_jax_function=False, is_view_op=True, needs_env=True +) def _aten_copy(x, y, memory_format=None, env=None): - - if y.device.type == 'cpu': - y = env.to_xla(y) - - if isinstance(x, View): - x.update(y) + if y.device.type == "cpu": + y = env.to_xla(y) + + if isinstance(x, View): + x.update(y) + return x + + if x.ndim == 1 and y.ndim == 0: + # case of torch.empty((1,)).copy_(tensor(N)) + # we need to return 0D tensor([N]) and not scalar tensor(N) + # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 + x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) + else: + x._elem = y._elem.astype(x._elem.dtype) return x - if x.ndim == 1 and y.ndim == 0: - # case of torch.empty((1,)).copy_(tensor(N)) - # we need to return 0D tensor([N]) and not scalar tensor(N) - # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 - x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) - else: - x._elem = y._elem.astype(x._elem.dtype) - return x - @op(torch.ops.aten.clone) def _aten_clone(x, memory_format=None): - return x + return x # aten.trunc @op(torch.ops.aten.trunc) def _aten_trunc(x): - res = jnp.trunc(x) - return res.astype(x) + res = jnp.trunc(x) + return res.astype(x) @op(torch.ops.aten.index_copy) def _aten_index_copy(x, dim, indexes, source): - if x.ndim == 0: - return source - if x.ndim == 1: - source = jnp.squeeze(source) - # return jax.lax.scatter(x, index, dim) - if dim < 0: - dim = dim + x.ndim - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x.at[tuple(dims)].set(source) + if x.ndim == 0: + return source + if x.ndim == 1: + source = jnp.squeeze(source) + # return jax.lax.scatter(x, index, dim) + if dim < 0: + dim = dim + x.ndim + dims = [] + for i in 
range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x.at[tuple(dims)].set(source) # aten.cauchy_ @op(torch.ops.aten.cauchy_) def _aten_cauchy_(x, median=0, sigma=1): - """ - Fills the input array with values drawn from a Cauchy distribution. + """ + Fills the input array with values drawn from a Cauchy distribution. - Args: - x: An array to be filled with Cauchy samples. - median: The median of the Cauchy distribution. - sigma: The scale parameter of the Cauchy distribution. + Args: + x: An array to be filled with Cauchy samples. + median: The median of the Cauchy distribution. + sigma: The scale parameter of the Cauchy distribution. - Returns: - The input array filled with Cauchy samples. - """ - key = jax.random.PRNGKey(0) # You should use a different key for each call - samples = jax.random.cauchy(key, x.shape) * sigma + median - return x.at[:].set(samples) + Returns: + The input array filled with Cauchy samples. + """ + key = jax.random.PRNGKey(0) # You should use a different key for each call + samples = jax.random.cauchy(key, x.shape) * sigma + median + return x.at[:].set(samples) @op(torch.ops.aten.atleast_2d) def _aten_atleast_2d(inputs): - return jnp.atleast_2d(inputs) + return jnp.atleast_2d(inputs) @op(torch.ops.aten.atleast_1d) def _aten_atleast_1d(inputs): - return jnp.atleast_1d(inputs) + return jnp.atleast_1d(inputs) # aten.complex @op(torch.ops.aten.complex) def _aten_complex(real, imag): - """ - Constructs a complex array from real and imaginary parts. + """ + Constructs a complex array from real and imaginary parts. - Args: - real: An array of real values. - imag: An array of imaginary values. + Args: + real: An array of real values. + imag: An array of imaginary values. - Returns: - A complex array with the specified real and imaginary parts. - """ - return jnp.array( - real, dtype=jnp.float32) + 1j * jnp.array( - imag, dtype=jnp.float32) + Returns: + A complex array with the specified real and imaginary parts. + """ + return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array( + imag, dtype=jnp.float32 + ) # aten.exponential_ @op(torch.ops.aten.exponential_) def _aten_exponential_(x, lambd=1.0): - """ - Fills the input array with values drawn from an exponential distribution. + """ + Fills the input array with values drawn from an exponential distribution. - Args: - x: An array to be filled with exponential samples. - lambd: The rate parameter of the exponential distribution. + Args: + x: An array to be filled with exponential samples. + lambd: The rate parameter of the exponential distribution. - Returns: - The input array filled with exponential samples. - """ - key = jax.random.PRNGKey(0) # Use a different key for each call - samples = jax.random.exponential(key, x.shape) / lambd - return x.at[:].set(samples) + Returns: + The input array filled with exponential samples. 
+ """ + key = jax.random.PRNGKey(0) # Use a different key for each call + samples = jax.random.exponential(key, x.shape) / lambd + return x.at[:].set(samples) # aten.linalg_householder_product @op(torch.ops.aten.linalg_householder_product) def _aten_linalg_householder_product(input, tau): - return jax.lax.linalg.householder_product(a=input, taus=tau) + return jax.lax.linalg.householder_product(a=input, taus=tau) @op(torch.ops.aten.select) def _aten_select(x, dim, indexes): - return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) + return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) @op(torch.ops.aten.index_select) @op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, index): - if x.shape == (): - return x - return jnp.take(x, index, dim) + if x.shape == (): + return x + return jnp.take(x, index, dim) @op(torch.ops.aten.cholesky) def _aten_cholesky(input, upper=False): - return jax.scipy.linalg.cholesky(input, lower=(not upper)) + return jax.scipy.linalg.cholesky(input, lower=(not upper)) @op(torch.ops.aten.linalg_cholesky_ex) def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False): - if check_errors: - raise NotImplementedError( - "check_errors=True is not supported in this JAX implementation. " - "Check for positive definiteness using jnp.linalg.eigvalsh before " - "calling this function.") - - L = jax.scipy.linalg.cholesky(input, lower=not upper) - if len(L.shape) > 2: - info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) - else: - info = jnp.array(0, dtype=jnp.int32) - return L, info + if check_errors: + raise NotImplementedError( + "check_errors=True is not supported in this JAX implementation. " + "Check for positive definiteness using jnp.linalg.eigvalsh before " + "calling this function." 
+ ) + + L = jax.scipy.linalg.cholesky(input, lower=not upper) + if len(L.shape) > 2: + info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) + else: + info = jnp.array(0, dtype=jnp.int32) + return L, info @op(torch.ops.aten.cholesky_solve) def _aten_cholesky_solve(input, input2, upper=False): - # Ensure input2 is lower triangular for cho_solve - L = input2 if not upper else input2.T - # Use cho_solve to solve the linear system - solution = jax.scipy.linalg.cho_solve((L, True), input) - return solution + # Ensure input2 is lower triangular for cho_solve + L = input2 if not upper else input2.T + # Use cho_solve to solve the linear system + solution = jax.scipy.linalg.cho_solve((L, True), input) + return solution @op(torch.ops.aten.special_zeta) def _aten_special_zeta(x, q): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jax.scipy.special.zeta(x, q) - if isinstance(x, int) or isinstance(q, int): - res = res.astype(new_dtype) - return res # jax.scipy.special.zeta(x, q) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jax.scipy.special.zeta(x, q) + if isinstance(x, int) or isinstance(q, int): + res = res.astype(new_dtype) + return res # jax.scipy.special.zeta(x, q) # aten.igammac @op(torch.ops.aten.igammac) def _aten_igammac(input, other): - if isinstance(input, jnp.ndarray): - input = jnp.where(input < 0, jnp.nan, input) - if isinstance(other, jnp.ndarray): - other = jnp.where(other < 0, jnp.nan, other) - else: - if (input == 0 and other == 0) or (input < 0) or (other < 0): - other = jnp.nan - return jnp.array(jax.scipy.special.gammaincc(input, other)) + if isinstance(input, jnp.ndarray): + input = jnp.where(input < 0, jnp.nan, input) + if isinstance(other, jnp.ndarray): + other = jnp.where(other < 0, jnp.nan, other) + else: + if (input == 0 and other == 0) or (input < 0) or (other < 0): + other = jnp.nan + return jnp.array(jax.scipy.special.gammaincc(input, other)) @op(torch.ops.aten.mean) def _aten_mean(x, dim=None, keepdim=False): - if x.shape == () and dim is not None: - dim = None # disable dim for jax array without dim - return jnp.mean(x, dim, keepdims=keepdim) + if x.shape == () and dim is not None: + dim = None # disable dim for jax array without dim + return jnp.mean(x, dim, keepdims=keepdim) def _torch_binary_scalar_type(scalar, tensor): - if "float" in str(tensor.dtype) or "complex" in str(tensor.dtype): - return tensor.dtype + if "float" in str(tensor.dtype) or "complex" in str(tensor.dtype): + return tensor.dtype - if isinstance(scalar, int): - if "int" in str(tensor.dtype): - return tensor.dtype + if isinstance(scalar, int): + if "int" in str(tensor.dtype): + return tensor.dtype - return jnp.float32 + return jnp.float32 @op(torch.ops.aten.searchsorted.Tensor) def _aten_searchsorted(sorted_sequence, values): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jnp.searchsorted(sorted_sequence, values) - if sorted_sequence.dtype == np.dtype( - np.int32) or sorted_sequence.dtype == np.dtype(np.int32): - # res = res.astype(new_dtype) - res = res.astype(np.dtype(np.int64)) - return res # jnp.searchsorted(sorted_sequence, values) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jnp.searchsorted(sorted_sequence, values) + if sorted_sequence.dtype == np.dtype( + np.int32 + ) or sorted_sequence.dtype == np.dtype(np.int32): + # res = res.astype(new_dtype) + res = res.astype(np.dtype(np.int64)) + return res # jnp.searchsorted(sorted_sequence, values) @op(torch.ops.aten.sub.Tensor) @op(torch.ops.aten.sub.Scalar) def 
_aten_sub(x, y, alpha=1): - if isinstance(x, float): - dtype = _torch_binary_scalar_type(x, y) - x = jnp.array(x, dtype=dtype) - if isinstance(y, float): - dtype = _torch_binary_scalar_type(y, x) - y = jnp.array(y, dtype=dtype) - return x - y * alpha + if isinstance(x, float): + dtype = _torch_binary_scalar_type(x, y) + x = jnp.array(x, dtype=dtype) + if isinstance(y, float): + dtype = _torch_binary_scalar_type(y, x) + y = jnp.array(y, dtype=dtype) + return x - y * alpha @op(torch.ops.aten.numpy_T) def _aten_numpy_T(input): - """ - Jax implementation of torch.numpy_T. + """ + Jax implementation of torch.numpy_T. - Args: - input: JAX array. + Args: + input: JAX array. - Returns: - Transposed JAX array. - """ - return jnp.transpose(input) + Returns: + Transposed JAX array. + """ + return jnp.transpose(input) @op(torch.ops.aten.mm) def _aten_mm(x, y): - res = x @ y - return res + res = x @ y + return res @op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) def _aten_mul(x, y): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = x * y - if isinstance(x, float) or isinstance(y, float): - res = res.astype(new_dtype) - else: - if (not isinstance(x, int)) and (not isinstance(y, int)): - if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = x * y + if isinstance(x, float) or isinstance(y, float): res = res.astype(new_dtype) - return res + else: + if (not isinstance(x, int)) and (not isinstance(y, int)): + if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype( + np.float64 + ): + res = res.astype(new_dtype) + return res @op(torch.ops.aten.silu) @op(torch.ops.aten.silu.default) def _aten_silu(x): - return jax.nn.silu(x) + return jax.nn.silu(x) @op(torch.ops.aten.t) def _aten_t(x): - return jnp.transpose(x) + return jnp.transpose(x) @op(torch.ops.aten.transpose) @op(torch.ops.aten.transpose_copy) def _aten_transpose(x, dim0, dim1): - if x.ndim == 0: - return x - dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim - dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim - return jnp.swapaxes(x, dim0, dim1) + if x.ndim == 0: + return x + dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim + dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim + return jnp.swapaxes(x, dim0, dim1) @op(torch.ops.aten.triu) def _aten_triu(m, k=0): - return jnp.triu(m, k) + return jnp.triu(m, k) @op(torch.ops.aten.slice) @op(torch.ops.aten.slice_copy) def _aten_slice(self, dim=0, start=None, end=None, step=1): - if dim < 0: - dim += self.ndim - if end == sys.maxsize: - end = self.shape[dim] - sl = slice(start, end, step) - dims = [] - for i in range(len(self.shape)): - if i == dim: - dims.append(sl) - else: - dims.append(slice(None, None, None)) - return self[tuple(dims)] + if dim < 0: + dim += self.ndim + if end == sys.maxsize: + end = self.shape[dim] + sl = slice(start, end, step) + dims = [] + for i in range(len(self.shape)): + if i == dim: + dims.append(sl) + else: + dims.append(slice(None, None, None)) + return self[tuple(dims)] @op(torch.ops.aten.detach) def _aten_detach(self): - return self + return self @op(torch.ops.aten.imag) def _aten_imag(x): - return jnp.imag(x) + return jnp.imag(x) @op(torch.ops.aten.isfinite) def _aten_isfinite(x): - return jnp.isfinite(x) + return jnp.isfinite(x) @op(torch.ops.aten.real) def _aten_real(x): - return jnp.real(x) + return jnp.real(x) @op(torch.Tensor.resize_) -def _aten_resize_(x, size, interpolation='linear'): - new_size = tuple(size) - return jax.numpy.resize(x, new_size) +def _aten_resize_(x, size, 
interpolation="linear"): + new_size = tuple(size) + return jax.numpy.resize(x, new_size) @op(torch.ops.aten.resize_as_) def _aten_resize_as_(x, y): - return jax.numpy.resize(x, y.shape) + return jax.numpy.resize(x, y.shape) @op(torch.ops.aten.repeat_interleave.Tensor) def repeat_interleave(repeats, dim=0): - return jnp.repeat(np.arange(repeats.shape[dim]), repeats) + return jnp.repeat(np.arange(repeats.shape[dim]), repeats) @op(torch.ops.aten.repeat_interleave.self_int) @op(torch.ops.aten.repeat_interleave.self_Tensor) def repeat_interleave(self, repeats, dim=0): - total_repeat_length = None - if isinstance(repeats, int): - total_repeat_length = self.shape[dim] * repeats - repeats = np.array([repeats] * self.shape[dim]) - return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length) + total_repeat_length = None + if isinstance(repeats, int): + total_repeat_length = self.shape[dim] * repeats + repeats = np.array([repeats] * self.shape[dim]) + return jnp.repeat( + self, repeats, dim, total_repeat_length=total_repeat_length + ) @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): - real = jnp.real(x) - im = jnp.imag(x) - res = jnp.stack([real, im], -1) - return res + real = jnp.real(x) + im = jnp.imag(x) + res = jnp.stack([real, im], -1) + return res @op(torch.ops.aten.stack) def _aten_stack(tensors, dim=0): - return jnp.stack(tensors, dim) + return jnp.stack(tensors, dim) @op(torch.ops.aten._softmax) @op(torch.ops.aten.softmax) @op(torch.ops.aten.softmax.int) def _aten_softmax(x, dim, halftofloat=False): - if x.shape == (): - return jax.nn.softmax(x.reshape([1]), axis=0).reshape([]) - return jax.nn.softmax(x, dim) + if x.shape == (): + return jax.nn.softmax(x.reshape([1]), axis=0).reshape([]) + return jax.nn.softmax(x, dim) def _is_int(x): - if isinstance(x, int): - return True - if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or - x.dtype.name.startswith('uint')): - return True - return False + if isinstance(x, int): + return True + if isinstance(x, jax.Array) and ( + x.dtype.name.startswith("int") or x.dtype.name.startswith("uint") + ): + return True + return False def highest_precision_int_dtype(tensor1, tensor2): - if isinstance(tensor1, int): - return tensor2.dtype - if isinstance(tensor2, int): - return tensor1.dtype - - dtype_hierarchy = { - 'uint8': 8, - 'int8': 8, - 'uint16': 16, - 'int16': 16, - 'uint32': 32, - 'int32': 32, - 'uint64': 64, - 'int64': 64, - } - return max( - tensor1.dtype, - tensor2.dtype, - key=lambda dtype: dtype_hierarchy[str(dtype)]) + if isinstance(tensor1, int): + return tensor2.dtype + if isinstance(tensor2, int): + return tensor1.dtype + + dtype_hierarchy = { + "uint8": 8, + "int8": 8, + "uint16": 16, + "int16": 16, + "uint32": 32, + "int32": 32, + "uint64": 64, + "int64": 64, + } + return max( + tensor1.dtype, + tensor2.dtype, + key=lambda dtype: dtype_hierarchy[str(dtype)], + ) @op(torch.ops.aten.pow) def _aten_pow(x, y): - y_orig = y - if isinstance(y, int): - y = float(y) - if _is_int(x) and _is_int(y_orig): - # Do the math in float then cast - res = jnp.power(jnp.astype(x, jnp.dtype('float')), y) - return res.astype(highest_precision_int_dtype(x, y_orig)) - res = jnp.power(x, y) - if isinstance(x, float): - return res.astype(_torch_binary_scalar_type(x, y_orig)) - if isinstance(y_orig, float): - return res.astype(_torch_binary_scalar_type(y_orig, x)) - return res + y_orig = y + if isinstance(y, int): + y = float(y) + if _is_int(x) and _is_int(y_orig): + # Do the math in float then cast + res = 
jnp.power(jnp.astype(x, jnp.dtype("float")), y) + return res.astype(highest_precision_int_dtype(x, y_orig)) + res = jnp.power(x, y) + if isinstance(x, float): + return res.astype(_torch_binary_scalar_type(x, y_orig)) + if isinstance(y_orig, float): + return res.astype(_torch_binary_scalar_type(y_orig, x)) + return res @op(torch.ops.aten.view_as_complex) def _aten_view_as_complex(input): - if input.dtype == jnp.bfloat16: - input = input.astype(jnp.float32) - x, y = input[..., 0], input[..., 1] - return jax.lax.complex(x, y) + if input.dtype == jnp.bfloat16: + input = input.astype(jnp.float32) + x, y = input[..., 0], input[..., 1] + return jax.lax.complex(x, y) @op(torch.ops.aten.div) def _aten_div(x, y, rounding_mode=""): - res_dtype = None - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('float32') + res_dtype = None + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("float32") - if (isinstance(x, float) or isinstance(y, float)): - res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if isinstance(x, float) or isinstance(y, float): + res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if rounding_mode == "floor": - res = jnp.floor_divide(x, y) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') - else: - res = x / y - if rounding_mode == "trunc": - res = jnp.trunc(res) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') - if res_dtype: - res = res.astype(res_dtype) - return res + if rounding_mode == "floor": + res = jnp.floor_divide(x, y) + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("int64") + else: + res = x / y + if rounding_mode == "trunc": + res = jnp.trunc(res) + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("int64") + if res_dtype: + res = res.astype(res_dtype) + return res @op(torch.ops.aten.true_divide) def _aten_true_divide(x, y): - return x / y + return x / y @op(torch.ops.aten.dist) def _aten_dist(input, other, p=2): - diff = jnp.abs(jnp.subtract(input, other)) - return _aten_linalg_vector_norm(diff, ord=p) + diff = jnp.abs(jnp.subtract(input, other)) + return _aten_linalg_vector_norm(diff, ord=p) @op(torch.ops.aten.bmm) def _aten_bmm(x, y): - res = x @ y - return res - # return jnp.einsum('bnm,bmk->bnk', x, y) + res = x @ y + return res + # return jnp.einsum('bnm,bmk->bnk', x, y) @op(torch.ops.aten.embedding) # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -def _aten_embedding(a, - w, - padding_idx=-1, - scale_grad_by_freq=False, - sparse=False): - return jnp.take(a, w, axis=0) +def _aten_embedding( + a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False +): + return jnp.take(a, w, axis=0) @op(torch.ops.aten.embedding_renorm_) def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): - # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp - unique_indices = jnp.unique(indices) + # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp + unique_indices = jnp.unique(indices) - norm = jnp.linalg.norm( - _aten_embedding(weight, unique_indices), - ord=norm_type, - axis=1, - ) + norm = jnp.linalg.norm( + _aten_embedding(weight, unique_indices), + ord=norm_type, + axis=1, + ) - indice_idx = jnp.where(norm > max_norm) + indice_idx = jnp.where(norm > max_norm) - scale = max_norm / (norm[indice_idx] + 1e-7) + scale = max_norm / (norm[indice_idx] + 1e-7) - indices_to_update = unique_indices[indice_idx] + indices_to_update = 
unique_indices[indice_idx] - weight = weight.at[indices_to_update].set(weight[indices_to_update] * - scale[:, None]) - return weight + weight = weight.at[indices_to_update].set( + weight[indices_to_update] * scale[:, None] + ) + return weight -#- func: _embedding_bag_forward_only( +# - func: _embedding_bag_forward_only( # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, # int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) @op(torch.ops.aten._embedding_bag) @op(torch.ops.aten._embedding_bag_forward_only) -def _aten__embedding_bag(weight, - indices, - offsets=None, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=-1): - """Jax implementation of the PyTorch _embedding_bag function. +def _aten__embedding_bag( + weight, + indices, + offsets=None, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, +): + """Jax implementation of the PyTorch _embedding_bag function. Args: weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). @@ -603,207 +614,210 @@ def _aten__embedding_bag(weight, Returns: A tuple of (output, offset2bag, bag_size, max_indices). """ - embedded = _aten_embedding(weight, indices, padding_idx) - - if offsets is None: - # offsets is None only when indices.ndim > 1 - if mode == 0: # sum - output = jnp.sum(embedded, axis=1) - elif mode == 1: # mean - output = jnp.mean(embedded, axis=1) - elif mode == 2: # max - output = jnp.max(embedded, axis=1) - return output, None, None, None - - if isinstance(offsets, jax.Array): - offsets_np = np.array(offsets) - else: - offsets_np = offsets - offset2bag = np.zeros(indices.shape[0], dtype=np.int64) - bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64) - max_indices = jnp.full_like(indices, -1) - - for bag in range(offsets_np.shape[0]): - start = int(offsets_np[bag]) - - end = int(indices.shape[0] if bag + - 1 == offsets_np.shape[0] else offsets_np[bag + 1]) - bag_size[bag] = end - start - offset2bag = offset2bag.at[start:end].set(bag) - - if end - start > 0: - if mode == 0: - output_bag = jnp.sum(embedded[start:end], axis=0) - elif mode == 1: - output_bag = jnp.mean(embedded[start:end], axis=0) - elif mode == 2: - output_bag = jnp.max(embedded[start:end], axis=0) - max_indices = max_indices.at[start:end].set( - jnp.argmax(embedded[start:end], axis=0)) - - # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. - # Converting them to JAX arrays for consistency. 
- offset2bag = jnp.array(offset2bag) - bag_size = jnp.array(bag_size) - - return output_bag, offset2bag, bag_size, max_indices + embedded = _aten_embedding(weight, indices, padding_idx) + + if offsets is None: + # offsets is None only when indices.ndim > 1 + if mode == 0: # sum + output = jnp.sum(embedded, axis=1) + elif mode == 1: # mean + output = jnp.mean(embedded, axis=1) + elif mode == 2: # max + output = jnp.max(embedded, axis=1) + return output, None, None, None + + if isinstance(offsets, jax.Array): + offsets_np = np.array(offsets) + else: + offsets_np = offsets + offset2bag = np.zeros(indices.shape[0], dtype=np.int64) + bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64) + max_indices = jnp.full_like(indices, -1) + + for bag in range(offsets_np.shape[0]): + start = int(offsets_np[bag]) + + end = int( + indices.shape[0] + if bag + 1 == offsets_np.shape[0] + else offsets_np[bag + 1] + ) + bag_size[bag] = end - start + offset2bag = offset2bag.at[start:end].set(bag) + + if end - start > 0: + if mode == 0: + output_bag = jnp.sum(embedded[start:end], axis=0) + elif mode == 1: + output_bag = jnp.mean(embedded[start:end], axis=0) + elif mode == 2: + output_bag = jnp.max(embedded[start:end], axis=0) + max_indices = max_indices.at[start:end].set( + jnp.argmax(embedded[start:end], axis=0) + ) + + # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. + # Converting them to JAX arrays for consistency. + offset2bag = jnp.array(offset2bag) + bag_size = jnp.array(bag_size) + + return output_bag, offset2bag, bag_size, max_indices @op(torch.ops.aten.rsqrt) @op_base.promote_int_input def _aten_rsqrt(x): - return jax.lax.rsqrt(x) + return jax.lax.rsqrt(x) @op(torch.ops.aten.expand) @op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): - - def fix_dims(d, xs): - if d == -1: - return xs - return d - - shape = list(x.shape) - if len(shape) < len(dims): - shape = [ - 1, - ] * (len(dims) - len(shape)) + shape - # make sure that dims and shape is the same by - # left pad with 1s. Otherwise the zip below will - # truncate - dims = [fix_dims(p, s) for p, s in zip(dims, shape)] - return jnp.broadcast_to(x, dims) + def fix_dims(d, xs): + if d == -1: + return xs + return d + + shape = list(x.shape) + if len(shape) < len(dims): + shape = [ + 1, + ] * (len(dims) - len(shape)) + shape + # make sure that dims and shape is the same by + # left pad with 1s. 
Otherwise the zip below will + # truncate + dims = [fix_dims(p, s) for p, s in zip(dims, shape)] + return jnp.broadcast_to(x, dims) @op(torch.ops.aten.dot) def _aten_dot(x, y): - return jnp.dot(x, y) + return jnp.dot(x, y) @op(torch.ops.aten._to_copy) def _aten__to_copy(self, **kwargs): - dtype = mappings.t2j_dtype(kwargs["dtype"]) - if dtype != self.dtype: - return self.astype(dtype) - return jnp.copy(self) + dtype = mappings.t2j_dtype(kwargs["dtype"]) + if dtype != self.dtype: + return self.astype(dtype) + return jnp.copy(self) @op(torch.ops.aten.empty) @op_base.convert_dtype(use_default_dtype=False) def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs): - return jnp.empty(size, dtype=dtype) + return jnp.empty(size, dtype=dtype) @op(torch.ops.aten.empty_like) @op_base.convert_dtype(use_default_dtype=False) def _aten_empty_like(input, *, dtype=None, **kwargs): - return jnp.empty_like(input, dtype) + return jnp.empty_like(input, dtype) @op(torch.ops.aten.ones) @op_base.convert_dtype() def _ones(size: Sequence[int], dtype=None, **kwargs): - return jnp.ones(size, dtype) + return jnp.ones(size, dtype) @op(torch.ops.aten.zeros) @op_base.convert_dtype() def _zeros(size: Sequence[int], dtype=None, **kwargs): - return jnp.zeros(size, dtype) + return jnp.zeros(size, dtype) @op(torch.ops.aten.full) @op_base.convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) @op(torch.ops.aten.empty_permuted) @op_base.convert_dtype() def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): - # Ignore the physical layout, - # since JAX and torch tensor doesn't share the same memory. - return jnp.empty(sizes, dtype=dtype) + # Ignore the physical layout, + # since JAX and torch tensor doesn't share the same memory. + return jnp.empty(sizes, dtype=dtype) @op(torch.ops.aten.empty_strided) @op_base.convert_dtype() def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensor doesn't share the same memory. - return jnp.empty(sizes, dtype=dtype) + # Ignore stride, since JAX and torch tensor doesn't share the same memory. + return jnp.empty(sizes, dtype=dtype) @op(torch.ops.aten.index_put_) @op(torch.ops.aten.index_put) def _aten_index_put(self, indexes, values, accumulate=False): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - if accumulate: - return self.at[indexes].add(values) - else: - return self.at[indexes].set(values) + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + if accumulate: + return self.at[indexes].add(values) + else: + return self.at[indexes].set(values) @op(torch.ops.aten.index) @op(torch.ops.aten._unsafe_index) @op(torch.ops.aten.index.Tensor) def _aten_index(self, indexes): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - return self[indexes] + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self[indexes] @op(torch.ops.aten.split) @op(torch.ops.aten.split_copy) @op(torch.ops.aten.split_with_sizes) def split_with_sizes(x, sizes, dim=0): - """Splits an array `x` into sub-arrays based on static sizes `sizes`. + """Splits an array `x` into sub-arrays based on static sizes `sizes`. - Args: - x: The input array to split. 
- sizes: A 1D array of integer sizes for each sub-array. + Args: + x: The input array to split. + sizes: A 1D array of integer sizes for each sub-array. - Returns: - A list of sub-arrays. - """ - if isinstance(sizes, int): - # split equal size, round up - new_sizes = [sizes] * (-(-x.shape[dim] // sizes)) - sizes = new_sizes - rank = x.ndim - splits = np.cumsum(sizes) # Cumulative sum for split points + Returns: + A list of sub-arrays. + """ + if isinstance(sizes, int): + # split equal size, round up + new_sizes = [sizes] * (-(-x.shape[dim] // sizes)) + sizes = new_sizes + rank = x.ndim + splits = np.cumsum(sizes) # Cumulative sum for split points - def make_range(rank, dim, start, end): - res = [slice(None, None, None)] * rank - res[dim] = slice(start, end) - return tuple(res) + def make_range(rank, dim, start, end): + res = [slice(None, None, None)] * rank + res[dim] = slice(start, end) + return tuple(res) - return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) - ] + return [ + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) + ] @op(torch.ops.aten.permute) @op(torch.ops.aten.permute_copy) def permute(t, dims): - # TODO: return a View instead - return jnp.transpose(t, dims) + # TODO: return a View instead + return jnp.transpose(t, dims) @op(torch.ops.aten.unsqueeze) @op(torch.ops.aten.unsqueeze_copy) def _aten_unsqueeze(self, dim): - if dim < 0: - dim += self.ndim + 1 - return jnp.expand_dims(self, dim) + if dim < 0: + dim += self.ndim + 1 + return jnp.expand_dims(self, dim) @op(torch.ops.aten.ne) def _aten_ne(x, y): - return jnp.not_equal(x, y) + return jnp.not_equal(x, y) # Create indices along a specific axis @@ -817,200 +831,195 @@ def _aten_ne(x, y): # _indices_along_axis(x, axis=1) # >> [[0, 1, 2, 3]] shape (1, 4) def _indices_along_axis(x, axis): - return jnp.expand_dims( - jnp.arange(x.shape[axis]), - axis=[d for d in range(len(x.shape)) if d != axis]) + return jnp.expand_dims( + jnp.arange(x.shape[axis]), + axis=[d for d in range(len(x.shape)) if d != axis], + ) def _broadcast_indices(indices, shape): - return jnp.broadcast_to(indices, shape) + return jnp.broadcast_to(indices, shape) @op(torch.ops.aten.cummax) def _aten_cummax(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) + if not x.shape: + return x, jnp.zeros_like(x, dtype=jnp.int64) - axis = dim + axis = dim - indice_along_axis = _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) + indice_along_axis = _indices_along_axis(x, axis) + indices = _broadcast_indices(indice_along_axis, x.shape) - def cummax_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] + def cummax_reduce_func(carry, elem): + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] - v = jnp.maximum(v1, v2) - i = jnp.where(v1 > v2, i1, i2) - return {'val': v, 'idx': i} + v = jnp.maximum(v1, v2) + i = jnp.where(v1 > v2, i1, i2) + return {"val": v, "idx": i} - res = jax.lax.associative_scan( - cummax_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] + res = jax.lax.associative_scan( + cummax_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cummin) def _aten_cummin(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) + if not x.shape: + return x, jnp.zeros_like(x, dtype=jnp.int64) - axis = dim + axis = dim - indice_along_axis 
= _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) + indice_along_axis = _indices_along_axis(x, axis) + indices = _broadcast_indices(indice_along_axis, x.shape) - def cummin_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] + def cummin_reduce_func(carry, elem): + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] - v = jnp.minimum(v1, v2) - i = jnp.where(v1 < v2, i1, i2) - return {'val': v, 'idx': i} + v = jnp.minimum(v1, v2) + i = jnp.where(v1 < v2, i1, i2) + return {"val": v, "idx": i} - res = jax.lax.associative_scan( - cummin_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] + res = jax.lax.associative_scan( + cummin_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cumsum) def _aten_cumsum(x, y, dtype=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if not x.shape: - return x - res = jnp.cumsum(x, y, dtype) - return res + if dtype: + dtype = mappings.t2j_dtype(dtype) + if not x.shape: + return x + res = jnp.cumsum(x, y, dtype) + return res @op(torch.ops.aten.cumprod) def _aten_cumprod(input, dim, dtype=None, out=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if len(input.shape) > 0: - res = jnp.cumprod(input, axis=dim, dtype=dtype) - elif dtype: - res = input.astype(dtype) - else: - res = input - return res + if dtype: + dtype = mappings.t2j_dtype(dtype) + if len(input.shape) > 0: + res = jnp.cumprod(input, axis=dim, dtype=dtype) + elif dtype: + res = input.astype(dtype) + else: + res = input + return res @op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): - """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. - - Args: - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - output: The normalized tensor. - mean: The calculated mean tensor. - std: The calculated standard deviation tensor. - """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - axis = [len(input.shape) - i - 1 for i in range(len(normalized_shape))] - - # Calculate mean and standard deviation - mean = jnp.mean(input, axis=axis, keepdims=True) - var = jnp.var(input, axis=axis, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) - - # Normalize the input - norm_x = (input - mean) * rstd - - # Apply affine transformation (if provided) - if weight is not None: - norm_x *= weight - if bias is not None: - norm_x += bias - return norm_x, mean, rstd +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): + """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. + + Args: + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + output: The normalized tensor. + mean: The calculated mean tensor. + std: The calculated standard deviation tensor. 
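+
+    Note: the third return value is the reciprocal of the standard deviation
+    (rstd = 1 / sqrt(var + eps), computed below with jax.lax.rsqrt), not the
+    standard deviation itself.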
+ """ + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + axis = [len(input.shape) - i - 1 for i in range(len(normalized_shape))] + + # Calculate mean and standard deviation + mean = jnp.mean(input, axis=axis, keepdims=True) + var = jnp.var(input, axis=axis, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) + + # Normalize the input + norm_x = (input - mean) * rstd + + # Apply affine transformation (if provided) + if weight is not None: + norm_x *= weight + if bias is not None: + norm_x += bias + return norm_x, mean, rstd @op(torch.ops.aten.matmul) def _aten_matmul(x, y): - return x @ y + return x @ y # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor @op(torch.ops.aten.addmm) @op(torch.ops.aten.addmv) def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) - return self + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) + return self @op(torch.ops.aten.sparse_sampled_addmm) def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) * (self != 0) - return self + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) * (self != 0) + return self @op(torch.ops.aten.addbmm.default) def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): - alpha = jnp.array(alpha).astype(batch1.dtype) - beta = jnp.array(beta).astype(batch1.dtype) - mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond(beta == 0, lambda: alpha * mm, - lambda: beta * input + alpha * mm) + alpha = jnp.array(alpha).astype(batch1.dtype) + beta = jnp.array(beta).astype(batch1.dtype) + mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) @op(torch.ops.aten.gelu) def _aten_gelu(self, *, approximate="none"): - approx = approximate == "tanh" - return jax.nn.gelu(self, approx) + approx = approximate == "tanh" + return jax.nn.gelu(self, approx) @op(torch.ops.aten.squeeze) @op(torch.ops.aten.squeeze_copy) def _aten_squeeze_dim(self, dim=None): - if self.ndim == 0: - return self - if dim is not None: - if isinstance(dim, int): - if self.shape[dim] != 1: + if self.ndim == 0: return self - if dim < 0: - dim += self.ndim - else: - # NOTE: torch leaves the dims that is not 1 unchanged, - # but jax raises error. - dim = [ - i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1 - ] - - return jnp.squeeze(self, dim) + if dim is not None: + if isinstance(dim, int): + if self.shape[dim] != 1: + return self + if dim < 0: + dim += self.ndim + else: + # NOTE: torch leaves the dims that is not 1 unchanged, + # but jax raises error. 
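+            # Illustrative example: for x of shape (2, 1, 3) and dim=(0, 1),
+            # torch.squeeze keeps dim 0 (size 2) and drops only dim 1, so the
+            # requested dims are filtered down to those of size 1 before
+            # calling jnp.squeeze.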
+ dim = [ + i if i >= 0 else (i + self.ndim) + for i in dim + if self.shape[i] == 1 + ] + + return jnp.squeeze(self, dim) @op(torch.ops.aten.bucketize) -def _aten_bucketize(input, - boundaries, - *, - out_int32=False, - right=False, - out=None): - return_type = jnp.int32 if out_int32 else jnp.int64 - return jnp.digitize(input, boundaries, right=not right).astype(return_type) +def _aten_bucketize( + input, boundaries, *, out_int32=False, right=False, out=None +): + return_type = jnp.int32 if out_int32 else jnp.int64 + return jnp.digitize(input, boundaries, right=not right).astype(return_type) @op(torch.ops.aten.conv2d) @@ -1023,16 +1032,17 @@ def _aten_conv2d( dilation, groups, ): - return _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed=False, - output_padding=1, - groups=groups) + return _aten_convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed=False, + output_padding=1, + groups=groups, + ) @op(torch.ops.aten.convolution) @@ -1047,169 +1057,176 @@ def _aten_convolution( output_padding, groups, ): - num_shape_dim = weight.ndim - 1 - batch_dims = input.shape[:-num_shape_dim] + num_shape_dim = weight.ndim - 1 + batch_dims = input.shape[:-num_shape_dim] + + input = input.reshape((-1, *input.shape[-num_shape_dim:])) + + def make_padding(padding, num_spatial_dims): + # Expand single padding to pairs expected by jax + if len(padding) == 1 and len(padding) < num_spatial_dims: + padding *= num_spatial_dims + if transposed: + # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + pad_out = [] + for i in range(num_spatial_dims): + front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i] + back = front + output_padding[i] + pad_out.append((front, back)) + return pad_out + else: + return ((p, p) for p in padding) + + def create_default_conv_dimension_numbers(num_spatial_dims): + # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 + # (batch dimension, feature dimension, spatial dimensions...) + lhs_spec = [0, 1] + # (out feature dimension, in feature dimension, spatial dimensions...) + # swapped for transposed convolution + rhs_spec = [1, 0] if transposed else [0, 1] + # (batch dimension, feature dimension, spatial dimensions...) 
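+        # e.g. with two spatial dims, lhs_spec and out_spec both become
+        # (0, 1, 2, 3), i.e. the NCHW layout that the aten convolution expects.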
+ out_spec = [0, 1] + for i in range(0, num_spatial_dims): + lhs_spec.append(i + 2) + rhs_spec.append(i + 2) + out_spec.append(i + 2) + return jax.lax.ConvDimensionNumbers( + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) - input = input.reshape((-1, *input.shape[-num_shape_dim:])) - - def make_padding(padding, num_spatial_dims): - # Expand single padding to pairs expected by jax - if len(padding) == 1 and len(padding) < num_spatial_dims: - padding *= num_spatial_dims if transposed: - # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html - pad_out = [] - for i in range(num_spatial_dims): - front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i] - back = front + output_padding[i] - pad_out.append((front, back)) - return pad_out + rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) + if groups != 1: + # reshape filters for tranposed depthwise convolution + assert rhs.shape[0] % groups == 0 + rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups] + rhs_shape.extend(rhs.shape[2:]) + rhs = jnp.reshape(rhs, rhs_shape) + res = jax.lax.conv_general_dilated( + input, + rhs, + (1,) * len(stride), + make_padding(padding, len(stride)), + lhs_dilation=stride, + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers( + len(stride) + ), + feature_group_count=groups, + batch_group_count=1, + ) else: - return ((p, p) for p in padding) - - def create_default_conv_dimension_numbers(num_spatial_dims): - # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 - # (batch dimension, feature dimension, spatial dimensions...) - lhs_spec = [0, 1] - # (out feature dimension, in feature dimension, spatial dimensions...) - # swapped for transposed convolution - rhs_spec = [1, 0] if transposed else [0, 1] - # (batch dimension, feature dimension, spatial dimensions...) - out_spec = [0, 1] - for i in range(0, num_spatial_dims): - lhs_spec.append(i + 2) - rhs_spec.append(i + 2) - out_spec.append(i + 2) - return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) - - if transposed: - rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) - if groups != 1: - # reshape filters for tranposed depthwise convolution - assert rhs.shape[0] % groups == 0 - rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups] - rhs_shape.extend(rhs.shape[2:]) - rhs = jnp.reshape(rhs, rhs_shape) - res = jax.lax.conv_general_dilated( - input, - rhs, - (1,) * len(stride), - make_padding(padding, len(stride)), - lhs_dilation=stride, - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - else: - res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding, len(stride)), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - - if bias is not None: - # TODO(qihqi): bias always on channel? 
- if len(bias.shape) == 1: - shape = [1] * len(res.shape) - shape[1] = bias.shape[0] - bias = bias.reshape(tuple(shape)) - res = res + bias - - res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:])) - return res + res = jax.lax.conv_general_dilated( + input, + weight, + stride, + make_padding(padding, len(stride)), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers( + len(stride) + ), + feature_group_count=groups, + batch_group_count=1, + ) + + if bias is not None: + # TODO(qihqi): bias always on channel? + if len(bias.shape) == 1: + shape = [1] * len(res.shape) + shape[1] = bias.shape[0] + bias = bias.reshape(tuple(shape)) + res = res + bias + + res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:])) + return res # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) @op(torch.ops.aten._native_batch_norm_legit.default) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): - """JAX implementation of batch normalization with optional parameters. - Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. - - Args: - input (DeviceArray): Input data (N, C, H, W). - running_mean ([DeviceArray]): Running mean of input (C,). - running_var ([DeviceArray]): Running variance of input (C,). - weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. - bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. - training (bool): If True, use batch statistics for normalization. - If False, use running statistics. - momentum (float): Momentum factor for updating running statistics. - eps (float): Small constant for numerical stability. 
- - Returns: - DeviceArray: Normalized output - DeviceArray: Batch mean (C,) or empty if training is False - DeviceArray: Reversed batch variance (C,) or empty if training is False - """ - reduction_dims = [0] + list(range(2, input.ndim)) - reshape_dims = [1, -1] + [1] * (input.ndim - 2) - if training: - # Calculate batch mean and variance - mean = jnp.mean(input, axis=reduction_dims, keepdims=True) - saved_mean = jnp.squeeze(mean, reduction_dims) - var = jnp.var(input, axis=reduction_dims) - rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) - # Update running statistics using momentum - running_mean = (1 - momentum) * running_mean + momentum * saved_mean - running_var = (1 - momentum) * running_var + momentum * var - saved_rstd = jnp.squeeze(rstd, reduction_dims) - else: - rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) - saved_mean = jnp.array( - [], dtype=input.dtype - ) # No need to calculate batch statistics in inference mode - saved_rstd = jnp.array([], dtype=input.dtype) - - # Normalize - if training: - # use batch statistics if training - x_hat = (input - mean) * rstd - else: - # Use running statistics in inference mode - x_hat = (input - running_mean.reshape(reshape_dims)) * rstd - - # Scale and shift - if weight is not None: - x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting - if bias is not None: - x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting - - return x_hat, saved_mean, saved_rstd +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): + """JAX implementation of batch normalization with optional parameters. + Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. + + Args: + input (DeviceArray): Input data (N, C, H, W). + running_mean ([DeviceArray]): Running mean of input (C,). + running_var ([DeviceArray]): Running variance of input (C,). + weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. + bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. + training (bool): If True, use batch statistics for normalization. + If False, use running statistics. + momentum (float): Momentum factor for updating running statistics. + eps (float): Small constant for numerical stability. 
+ + Returns: + DeviceArray: Normalized output + DeviceArray: Batch mean (C,) or empty if training is False + DeviceArray: Reversed batch variance (C,) or empty if training is False + """ + reduction_dims = [0] + list(range(2, input.ndim)) + reshape_dims = [1, -1] + [1] * (input.ndim - 2) + if training: + # Calculate batch mean and variance + mean = jnp.mean(input, axis=reduction_dims, keepdims=True) + saved_mean = jnp.squeeze(mean, reduction_dims) + var = jnp.var(input, axis=reduction_dims) + rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) + # Update running statistics using momentum + running_mean = (1 - momentum) * running_mean + momentum * saved_mean + running_var = (1 - momentum) * running_var + momentum * var + saved_rstd = jnp.squeeze(rstd, reduction_dims) + else: + rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) + saved_mean = jnp.array( + [], dtype=input.dtype + ) # No need to calculate batch statistics in inference mode + saved_rstd = jnp.array([], dtype=input.dtype) + + # Normalize + if training: + # use batch statistics if training + x_hat = (input - mean) * rstd + else: + # Use running statistics in inference mode + x_hat = (input - running_mean.reshape(reshape_dims)) * rstd + + # Scale and shift + if weight is not None: + x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting + if bias is not None: + x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting + + return x_hat, saved_mean, saved_rstd @op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, False, momentum, eps) +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): + return _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, False, momentum, eps + ) @op(torch.ops.aten.relu) def _aten_relu(self): - return jax.nn.relu(self) + return jax.nn.relu(self) @op(torch.ops.aten.cat) def _aten_cat(tensors, dims=0): - # handle empty tensors as a special case. - # torch.cat will ignore the empty tensor, while jnp.concatenate - # will error if the dims > 0. - filtered_tensors = [ - t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) - ] - if filtered_tensors: - return jnp.concatenate(filtered_tensors, dims) - return tensors[0] + # handle empty tensors as a special case. + # torch.cat will ignore the empty tensor, while jnp.concatenate + # will error if the dims > 0. + filtered_tensors = [ + t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) + ] + if filtered_tensors: + return jnp.concatenate(filtered_tensors, dims) + return tensors[0] def _ceil_mode_padding( @@ -1220,219 +1237,237 @@ def _ceil_mode_padding( dilation: list[int], ceil_mode: bool, ): - """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. - - Additional high padding could be required when ceil mode is set. 
- """ - ceil_mode_padding = [] - for i in range(len(padding)): - left_padding = padding[i] - right_padding = left_padding - - input_size = input_shape[2 + i] - output_size_rem = (input_size + 2 * left_padding - - (kernel_size[i] - 1) * dilation[i] - 1) % stride[i] - if ceil_mode and output_size_rem != 0: - extra_padding = stride[i] - output_size_rem - new_output_size = (input_size + left_padding + right_padding + - extra_padding - (kernel_size[i] - 1) * dilation[i] - - 1 + stride[i] - 1) // stride[i] + 1 - # Ensure that the last pooling starts inside the image. - size_to_compare = input_size + left_padding + """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. - if (new_output_size - 1) * stride[i] < size_to_compare: - right_padding += extra_padding - - ceil_mode_padding.append((left_padding, right_padding)) - return ceil_mode_padding + Additional high padding could be required when ceil mode is set. + """ + ceil_mode_padding = [] + for i in range(len(padding)): + left_padding = padding[i] + right_padding = left_padding + + input_size = input_shape[2 + i] + output_size_rem = ( + input_size + + 2 * left_padding + - (kernel_size[i] - 1) * dilation[i] + - 1 + ) % stride[i] + if ceil_mode and output_size_rem != 0: + extra_padding = stride[i] - output_size_rem + new_output_size = ( + input_size + + left_padding + + right_padding + + extra_padding + - (kernel_size[i] - 1) * dilation[i] + - 1 + + stride[i] + - 1 + ) // stride[i] + 1 + # Ensure that the last pooling starts inside the image. + size_to_compare = input_size + left_padding + + if (new_output_size - 1) * stride[i] < size_to_compare: + right_padding += extra_padding + + ceil_mode_padding.append((left_padding, right_padding)) + return ceil_mode_padding @op(torch.ops.aten.max_pool2d_with_indices) @op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides=None, - padding=0, - dilation=1, - ceil_mode=False): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - # Default stride is kernel_size - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - if isinstance(dilation, int): - dilation = tuple(dilation for _ in range(len(kernel_size))) - elif isinstance(dilation, list): - dilation = tuple(dilation) - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - dilation, ceil_mode) - - assert len(kernel_size) == len( - strides), f"len({kernel_size=}) must equal len({strides=})" - assert len(kernel_size) == len( - dilation), f"len({kernel_size=}) must equal len({dilation=})" - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + kernel_size - dilation = (1,) * (1 + num_batch_dims) + dilation - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. 
- inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - dilation = (1,) + dilation - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(kernel_size), ( - f"padding {padding} must specify pads for same number of dims as " - f"kernel_size {kernel_size}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" - padding = ((0, 0), (0, 0)) + padding - - indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):])) - indices = indices.reshape(inputs.shape[-len(kernel_size):]) - indices = jnp.broadcast_to(indices, inputs.shape) - - def reduce_fn(a, b): - ai, av = a - bi, bv = b - which = av >= bv # torch breaks ties in favor of later indices - return jnp.where(which, ai, bi), jnp.where(which, av, bv) - - init_val = -jnp.inf - if inputs.dtype in (jnp.int32, jnp.int64): - init_val = -(1 << 31) - init_val = jnp.array(init_val).astype(inputs.dtype) - - # Separate maxpool result and indices into two reduce_window ops. Since - # the indices tensor is usually unused in inference, separating the two - # can help DCE computations for argmax. - y = jax.lax.reduce_window( - inputs, - init_val, - jax.lax.max, - dims, - strides, - padding, - window_dilation=dilation) - indices, _ = jax.lax.reduce_window( - (indices, inputs), - (0, init_val), - reduce_fn, - dims, - strides, - padding, - window_dilation=dilation, - ) - if is_single_input: - indices = jnp.squeeze(indices, axis=0) - y = jnp.squeeze(y, axis=0) - - return y, indices +def _aten_max_pool2d_with_indices( + inputs, kernel_size, strides=None, padding=0, dilation=1, ceil_mode=False +): + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + # Default stride is kernel_size + strides = tuple(strides) if strides else kernel_size + if isinstance(padding, int): + padding = [padding for _ in range(len(kernel_size))] + if isinstance(dilation, int): + dilation = tuple(dilation for _ in range(len(kernel_size))) + elif isinstance(dilation, list): + dilation = tuple(dilation) + + input_shape = inputs.shape + if num_batch_dims == 0: + input_shape = [1, *input_shape] + padding = _ceil_mode_padding( + padding, input_shape, kernel_size, strides, dilation, ceil_mode + ) + + assert len(kernel_size) == len(strides), ( + f"len({kernel_size=}) must equal len({strides=})" + ) + assert len(kernel_size) == len(dilation), ( + f"len({kernel_size=}) must equal len({dilation=})" + ) + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + kernel_size + dilation = (1,) * (1 + num_batch_dims) + dilation + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. 
+ inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + dilation = (1,) + dilation + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(kernel_size), ( + f"padding {padding} must specify pads for same number of dims as " + f"kernel_size {kernel_size}" + ) + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" + ) + padding = ((0, 0), (0, 0)) + padding + + indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size) :])) + indices = indices.reshape(inputs.shape[-len(kernel_size) :]) + indices = jnp.broadcast_to(indices, inputs.shape) + + def reduce_fn(a, b): + ai, av = a + bi, bv = b + which = av >= bv # torch breaks ties in favor of later indices + return jnp.where(which, ai, bi), jnp.where(which, av, bv) + + init_val = -jnp.inf + if inputs.dtype in (jnp.int32, jnp.int64): + init_val = -(1 << 31) + init_val = jnp.array(init_val).astype(inputs.dtype) + + # Separate maxpool result and indices into two reduce_window ops. Since + # the indices tensor is usually unused in inference, separating the two + # can help DCE computations for argmax. + y = jax.lax.reduce_window( + inputs, + init_val, + jax.lax.max, + dims, + strides, + padding, + window_dilation=dilation, + ) + indices, _ = jax.lax.reduce_window( + (indices, inputs), + (0, init_val), + reduce_fn, + dims, + strides, + padding, + window_dilation=dilation, + ) + if is_single_input: + indices = jnp.squeeze(indices, axis=0) + y = jnp.squeeze(y, axis=0) + + return y, indices # Aten ops registered under the `xla` library. try: - @op(torch.ops.xla.max_pool2d_forward) - def _xla_max_pool2d_forward(*args, **kwargs): - return _aten_max_pool2d_with_indices(*args, **kwargs)[0] - - @op(torch.ops.xla.aot_mark_sharding) - def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): - from jax.sharding import PartitionSpec as P, NamedSharding - import ast - import torch_xla.distributed.spmd as xs - pmesh = xs.Mesh.from_str(mesh) - assert pmesh is not None - partition_spec_eval = ast.literal_eval(partition_spec) - jmesh = pmesh.get_jax_mesh() - return jax.lax.with_sharding_constraint( - t, NamedSharding(jmesh, P(*partition_spec_eval))) - - @op(torch.ops.xla.einsum_linear_forward) - def _xla_einsum_linear_forward(input, weight, bias): - with jax.named_scope('einsum_linear_forward'): - product = jax.numpy.einsum('...n,mn->...m', input, weight) - if bias is not None: - return product + bias - return product + @op(torch.ops.xla.max_pool2d_forward) + def _xla_max_pool2d_forward(*args, **kwargs): + return _aten_max_pool2d_with_indices(*args, **kwargs)[0] + + @op(torch.ops.xla.aot_mark_sharding) + def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): + from jax.sharding import PartitionSpec as P, NamedSharding + import ast + import torch_xla.distributed.spmd as xs + + pmesh = xs.Mesh.from_str(mesh) + assert pmesh is not None + partition_spec_eval = ast.literal_eval(partition_spec) + jmesh = pmesh.get_jax_mesh() + return jax.lax.with_sharding_constraint( + t, NamedSharding(jmesh, P(*partition_spec_eval)) + ) + + @op(torch.ops.xla.einsum_linear_forward) + def _xla_einsum_linear_forward(input, weight, bias): + with jax.named_scope("einsum_linear_forward"): + product = jax.numpy.einsum("...n,mn->...m", input, weight) + if bias is not None: + return product + bias + return product except AttributeError: - pass + pass # TODO add more ops 
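
The two-operand `jax.lax.reduce_window` call in `_aten_max_pool2d_with_indices` above is the interesting part: by reducing (index, value) pairs with a custom `reduce_fn`, the argmax position of each window is recovered in the same pass that computes the max. A minimal standalone sketch of that trick follows; it is illustrative only (the 4x4 input, the 2x2 window, and all variable names are assumptions, not values from the patch).

import jax
import jax.numpy as jnp

# Toy NCHW input chosen so no window contains duplicate values,
# which keeps the argmax of every window unambiguous.
x = jnp.array([[1., 3., 2., 0.],
               [4., 6., 1., 5.],
               [0., 2., 9., 7.],
               [8., 5., 2., 2.]])[None, None]      # shape (1, 1, 4, 4)

# Flat indices over the spatial plane, broadcast to the input shape,
# mirroring the `indices` tensor built in the hunk above.
idx = jnp.arange(x.shape[-2] * x.shape[-1]).reshape(x.shape[-2:])
idx = jnp.broadcast_to(idx, x.shape)

def reduce_fn(a, b):
    # Reduce (index, value) pairs: keep the index of the larger value.
    ai, av = a
    bi, bv = b
    keep_a = av >= bv
    return jnp.where(keep_a, ai, bi), jnp.where(keep_a, av, bv)

window = (1, 1, 2, 2)          # no pooling over the N and C dims
strides = (1, 1, 2, 2)
padding = ((0, 0),) * 4
init_val = jnp.array(-jnp.inf, dtype=x.dtype)

# Plain max pooling.
y = jax.lax.reduce_window(x, init_val, jax.lax.max, window, strides, padding)

# Same windows, but reducing (index, value) pairs to also get the argmax.
indices, _ = jax.lax.reduce_window(
    (idx, x), (jnp.array(0, idx.dtype), init_val), reduce_fn,
    window, strides, padding)

print(y.squeeze())        # [[6. 5.]  [8. 9.]]
print(indices.squeeze())  # [[ 5  7]  [12 10]] -- flat positions in the 4x4 plane

For this tie-free input the recovered indices should agree with `torch.nn.functional.max_pool2d(..., return_indices=True)`, which also reports flat positions within each H*W plane; tie handling between equal values is the subtle case that the `reduce_fn` comment in the patch ("torch breaks ties in favor of later indices") is about.
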
@op(torch.ops.aten.min) def _aten_min(x, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar(jnp.min, x, dim, - keepdim), _with_reduction_scalar( - jnp.argmin, x, dim, - keepdim).astype(jnp.int64) - else: - return _with_reduction_scalar(jnp.min, x, dim, keepdim) + if dim is not None: + return _with_reduction_scalar( + jnp.min, x, dim, keepdim + ), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64) + else: + return _with_reduction_scalar(jnp.min, x, dim, keepdim) @op(torch.ops.aten.mode) def _aten_mode(input, dim=-1, keepdim=False, *, out=None): - if input.ndim == 0: # single number - return input, jnp.array(0) - dim = (input.ndim + - dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim - # keepdims must be True for accurate broadcasting - mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) - mode_broadcast = jnp.broadcast_to(mode, input.shape) - if not keepdim: - mode = mode.squeeze(axis=dim) - indices = jnp.argmax( - jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim) - return mode, indices + if input.ndim == 0: # single number + return input, jnp.array(0) + dim = ( + input.ndim + dim + ) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim + # keepdims must be True for accurate broadcasting + mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) + mode_broadcast = jnp.broadcast_to(mode, input.shape) + if not keepdim: + mode = mode.squeeze(axis=dim) + indices = jnp.argmax( + jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim + ) + return mode, indices @op(torch.ops.aten.amin) def _aten_amin(x, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amin, x, dim, keepdim) + return _with_reduction_scalar(jnp.amin, x, dim, keepdim) @op(torch.ops.aten.argmin) def _aten_argmin(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) + return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) @op(torch.ops.aten.sin) @op_base.promote_int_input def _aten_sin(x): - return jnp.sin(x) + return jnp.sin(x) @op(torch.ops.aten.sym_size) def _aten_sym_size(x, dim): - return x.shape[dim] + return x.shape[dim] @op(torch.ops.aten.var.correction) @op(torch.ops.prims.var) def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): - return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) + return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) @op(torch.ops.prims.broadcast_in_dim) def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): - return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) + return jax.lax.broadcast_in_dim( + t, shape, broadcast_dimensions=broadcast_dimensions + ) # aten.native_group_norm -- should use decomp table @@ -1441,171 +1476,176 @@ def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): @op(torch.ops.aten.native_group_norm) def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): - """Group Normalization implementation in JAX. - - Args: - input: Input tensor. Expected shape (batch_size, channels, ... spatial dims - ...) - weight: Optional scaling (gamma) parameter. Shape (channels,) - bias: Optional shifting (beta) parameter. Shape (channels,) - N: Batch size. - C: Number of channels. - HxW: Product of spatial dimensions (number of elements per channel after - flattening). - group: Number of groups for Group Normalization. - eps: Small value added for numerical stability. 
- - Returns: - A tuple of (normalized_output, mean, rstd) - """ - - input_shape = input.shape - - if 0 in input_shape: - return input, input, input - - # Reshape for group-wise normalization - reshaped_input = jnp.reshape(input, (1, N * group, -1)) - - # **Core Group Normalization** - def group_norm_body(x): # Function to apply within each group - mean = jnp.mean(x, axis=-1, keepdims=True) - var = jnp.var(x, axis=-1, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon - normalized = (x - mean) * rstd - return normalized, mean, rstd - - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) - - # Reshape back to original input shape - output = jnp.reshape(normalized, input_shape) - - # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting - if weight is not None and bias is not None: - output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) - elif weight is not None: - output = output * weight.reshape(affine_shape) - elif bias is not None: - output = output + bias.reshape(affine_shape) - - # Reshape mean and rstd - mean = jnp.reshape(group_mean, (N, group)) - rstd = jnp.reshape(group_rstd, (N, group)) - - return output, mean, rstd + """Group Normalization implementation in JAX. + Args: + input: Input tensor. Expected shape (batch_size, channels, ... spatial dims + ...) + weight: Optional scaling (gamma) parameter. Shape (channels,) + bias: Optional shifting (beta) parameter. Shape (channels,) + N: Batch size. + C: Number of channels. + HxW: Product of spatial dimensions (number of elements per channel after + flattening). + group: Number of groups for Group Normalization. + eps: Small value added for numerical stability. -@op(torch.ops.aten.linalg_vector_norm) -def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): - """Calculates the vector norm along specified dimensions. - - Args: - self: The input tensor. - ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. - Default is 2 (Euclidean norm). - dim: Dimensions along which to calculate the norm. If None, the norm is - calculated over all dimensions. - keepdim: Whether to keep the reduced dimensions. - dtype: Optional data type for the output. - - Returns: - The tensor containing the calculated vector norms. - """ - - if ord not in {2, float("inf"), float("-inf"), "fro" - } and not isinstance(ord, (int, float)): - raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") - - # Special cases (for efficiency and clarity) - if ord == 0: - if self.shape == (): - # float sets it to float64. 
set it back to input type - result = jnp.astype(jnp.array(float(self != 0)), self.dtype) - else: - result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, - keepdim) + Returns: + A tuple of (normalized_output, mean, rstd) + """ + + input_shape = input.shape + + if 0 in input_shape: + return input, input, input + + # Reshape for group-wise normalization + reshaped_input = jnp.reshape(input, (1, N * group, -1)) + + # **Core Group Normalization** + def group_norm_body(x): # Function to apply within each group + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon + normalized = (x - mean) * rstd + return normalized, mean, rstd + + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) + + # Reshape back to original input shape + output = jnp.reshape(normalized, input_shape) + + # **Affine transformation** + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting + if weight is not None and bias is not None: + output = bias.reshape(affine_shape) + output * weight.reshape( + affine_shape + ) + elif weight is not None: + output = output * weight.reshape(affine_shape) + elif bias is not None: + output = output + bias.reshape(affine_shape) - elif ord == 2: # Euclidean norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) + # Reshape mean and rstd + mean = jnp.reshape(group_mean, (N, group)) + rstd = jnp.reshape(group_rstd, (N, group)) - elif ord == float("inf"): - result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) + return output, mean, rstd - elif ord == float("-inf"): - result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim) - elif ord == "fro": # Frobenius norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) +@op(torch.ops.aten.linalg_vector_norm) +def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): + """Calculates the vector norm along specified dimensions. + + Args: + self: The input tensor. + ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. + Default is 2 (Euclidean norm). + dim: Dimensions along which to calculate the norm. If None, the norm is + calculated over all dimensions. + keepdim: Whether to keep the reduced dimensions. + dtype: Optional data type for the output. - else: # General case (e.g., ord = 1, ord = 3) - result = _with_reduction_scalar(jnp.sum, - jnp.abs(self)**ord, dim, - keepdim)**(1.0 / ord) + Returns: + The tensor containing the calculated vector norms. + """ - # (Optional) dtype conversion - if dtype is not None: - result = jnp.astype(result, self.dtype) + if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance( + ord, (int, float) + ): + raise ValueError( + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) + + # Special cases (for efficiency and clarity) + if ord == 0: + if self.shape == (): + # float sets it to float64. 
set it back to input type + result = jnp.astype(jnp.array(float(self != 0)), self.dtype) + else: + result = _with_reduction_scalar( + jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim + ) + + elif ord == 2: # Euclidean norm + result = jnp.sqrt( + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) + + elif ord == float("inf"): + result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) + + elif ord == float("-inf"): + result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim) + + elif ord == "fro": # Frobenius norm + result = jnp.sqrt( + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) + + else: # General case (e.g., ord = 1, ord = 3) + result = _with_reduction_scalar( + jnp.sum, jnp.abs(self) ** ord, dim, keepdim + ) ** (1.0 / ord) + + # (Optional) dtype conversion + if dtype is not None: + result = jnp.astype(result, self.dtype) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if result.dtype == jax.numpy.int64: - result = result.astype(new_dtype) - return result + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if result.dtype == jax.numpy.int64: + result = result.astype(new_dtype) + return result # aten.reflection_pad1d @op(torch.ops.aten.reflection_pad1d) def _aten_reflection_pad1d(input, padding): - rank = len(input.shape) - pad_size = [(0, 0)] * rank - pad_size[-1] = padding - return jnp.pad(input, pad_size, mode="reflect") + rank = len(input.shape) + pad_size = [(0, 0)] * rank + pad_size[-1] = padding + return jnp.pad(input, pad_size, mode="reflect") # aten.alias @op(torch.ops.aten.alias) def _aten_alias(self, *args): - return self + return self # aten.sinh @op(torch.ops.aten.sinh) @op_base.promote_int_input def _aten_sinh(self): - return jnp.sinh(self) + return jnp.sinh(self) # aten.native_layer_norm_backward @op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): - """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. - - Args: - grad_out: The gradient of the output tensor. - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - A tuple of (grad_input, grad_weight, grad_bias). - """ - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): + """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. + + Args: + grad_out: The gradient of the output tensor. + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + A tuple of (grad_input, grad_weight, grad_bias). 
+ """ + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) # aten.reflection_pad3d_backward @@ -1616,34 +1656,34 @@ def _aten_native_layer_norm_backward(grad_out, @op(torch.ops.aten.atanh) @op_base.promote_int_input def _aten_atanh(self): - res = jnp.arctanh(self) - return res + res = jnp.arctanh(self) + return res # aten.bincount @op(torch.ops.aten.bincount) def _aten_bincount(input, weights=None, minlength=0): - return jnp.bincount(input, weights, minlength) + return jnp.bincount(input, weights, minlength) # aten.bitwise_not @op(torch.ops.aten.bitwise_not) def _aten_bitwise_not(self): - return ~self + return ~self # aten.bitwise_left_shift @op(torch.ops.aten.__lshift__) @op(torch.ops.aten.bitwise_left_shift) def _aten_bitwise_left_shift(input, other): - return jnp.left_shift(input, other) + return jnp.left_shift(input, other) # aten.bitwise_right_shift @op(torch.ops.aten.__rshift__) @op(torch.ops.aten.bitwise_right_shift) def _aten_bitwise_right_shift(input, other): - return jnp.right_shift(input, other) + return jnp.right_shift(input, other) # aten.embedding_dense_backward @@ -1652,125 +1692,127 @@ def _aten_bitwise_right_shift(input, other): # aten.sum @op(torch.ops.aten.sum) def _aten_sum(self, dim=None, keepdim=False, dtype=None): - if not dim: - dim = None - return _with_reduction_scalar(jnp.sum, self, dim, keepdim) + if not dim: + dim = None + return _with_reduction_scalar(jnp.sum, self, dim, keepdim) # aten.sqrt @op(torch.ops.aten.sqrt) @op_base.promote_int_input def _aten_sqrt(self): - return jnp.sqrt(self) + return jnp.sqrt(self) @op(torch.ops.aten.tan) @op_base.promote_int_input def _aten_tanh(self): - res = jnp.tan(self) - return res + res = jnp.tan(self) + return res # aten.tanh @op(torch.ops.aten.tanh) @op_base.promote_int_input def _aten_tanh(self): - res = jnp.tanh(self) - return res + res = jnp.tanh(self) + return res # aten.ceil @op(torch.ops.aten.ceil) def _aten_ceil(self): - return jnp.ceil(self).astype(self) + return jnp.ceil(self).astype(self) # aten.asin @op(torch.ops.aten.asin) @op_base.promote_int_input def _aten_asin(self): - res = jnp.arcsin(self) - return res + res = jnp.arcsin(self) + return res # aten.minimum @op(torch.ops.aten.minimum) def _aten_minimum(self, other): - return jnp.minimum(self, other) + return jnp.minimum(self, other) # aten.max_pool2d_backward def _scatter_index(dim, index): - """Returns a tuple of indexes; - - The first is to select in input (to modify), - the second is to select from the values. - """ - index_shape = list(index.shape) - input_indexes = [] - source_indexes = [] - if dim < 0: - dim += len(index_shape) - for i in range(len(index_shape)): - source_indexes.append(slice(0, index_shape[i])) - if i == dim: - input_indexes.append(index) - else: - target_shape = [1] * len(index_shape) - target_shape[i] = index_shape[i] - input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) - return tuple(input_indexes), tuple(source_indexes) + """Returns a tuple of indexes; + + The first is to select in input (to modify), + the second is to select from the values. 
+ """ + index_shape = list(index.shape) + input_indexes = [] + source_indexes = [] + if dim < 0: + dim += len(index_shape) + for i in range(len(index_shape)): + source_indexes.append(slice(0, index_shape[i])) + if i == dim: + input_indexes.append(index) + else: + target_shape = [1] * len(index_shape) + target_shape[i] = index_shape[i] + input_indexes.append( + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), + index_shape, + ) + ) + return tuple(input_indexes), tuple(source_indexes) # aten.scatter_add @op(torch.ops.aten.scatter_add) def _aten_scatter_add(input, dim, index, src): - """JAX implementation of scatter, mimicking torch.scatter behavior""" + """JAX implementation of scatter, mimicking torch.scatter behavior""" - input_indexes, source_indexes = _scatter_index(dim, index) - return input.at[input_indexes].add(src[source_indexes]) + input_indexes, source_indexes = _scatter_index(dim, index) + return input.at[input_indexes].add(src[source_indexes]) # aten.masked_scatter @op(torch.ops.aten.masked_scatter) def _aten_masked_scatter(self, mask, source): + broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) + if self.shape != broadcast_shape: + self = jnp.broadcast_to(self, broadcast_shape) + elif mask.shape != broadcast_shape: + mask = jnp.broadcast_to(mask, broadcast_shape) - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - elif mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) + self_flat = self.flatten() + mask_flat = mask.flatten() + source_flat = source.flatten() - self_flat = self.flatten() - mask_flat = mask.flatten() - source_flat = source.flatten() + true_indices = jnp.where(mask_flat)[0] + self_flat = self_flat.at[true_indices].set(source_flat[: len(true_indices)]) + final_arr = self_flat.reshape(self.shape) - true_indices = jnp.where(mask_flat)[0] - self_flat = self_flat.at[true_indices].set(source_flat[:len(true_indices)]) - final_arr = self_flat.reshape(self.shape) - - return final_arr + return final_arr @op(torch.ops.aten.masked_select) def _aten_masked_select(self, mask, *args, **kwargs): - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) + broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - if mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) + if self.shape != broadcast_shape: + self = jnp.broadcast_to(self, broadcast_shape) + if mask.shape != broadcast_shape: + mask = jnp.broadcast_to(mask, broadcast_shape) - self_flat = self.flatten() - mask_flat = mask.flatten() - true_indices = jnp.where(mask_flat)[0] + self_flat = self.flatten() + mask_flat = mask.flatten() + true_indices = jnp.where(mask_flat)[0] - return self_flat[true_indices] + return self_flat[true_indices] # aten.logical_not @@ -1779,90 +1821,86 @@ def _aten_masked_select(self, mask, *args, **kwargs): # aten.sign @op(torch.ops.aten.sign) def _aten_sign(x): - return jnp.sign(x) + return jnp.sign(x) # aten.signbit @op(torch.ops.aten.signbit) def _aten_signbit(x): - return jnp.signbit(x) + return jnp.signbit(x) # aten.sigmoid @op(torch.ops.aten.sigmoid) @op_base.promote_int_input def _aten_sigmoid(x): - return jax.nn.sigmoid(x) + return jax.nn.sigmoid(x) # implement aten.asinh in jax @op(torch.ops.aten.asinh) @op_base.promote_int_input def _aten_asinh(self): - res = jnp.arcsinh(self) - return res + res = 
jnp.arcsinh(self) + return res # aten.atan @op(torch.ops.aten.atan) @op_base.promote_int_input def _aten_atan(self): - res = jnp.arctan(self) - return res + res = jnp.arctan(self) + return res @op(torch.ops.aten.scatter_reduce) @op(torch.ops.aten.scatter) -def _aten_scatter_reduce(input, - dim, - index, - src, - reduce=None, - *, - include_self=True): - if not isinstance(src, jnp.ndarray): - src = jnp.array(src, dtype=input.dtype) - input_indexes, source_indexes = _scatter_index(dim, index) - # "Zero out" target elements when not included - if not include_self: - if reduce in ["sum", "mean"]: - base_input = jnp.zeros_like(src) - elif reduce == "prod": - base_input = jnp.ones_like(src) +def _aten_scatter_reduce( + input, dim, index, src, reduce=None, *, include_self=True +): + if not isinstance(src, jnp.ndarray): + src = jnp.array(src, dtype=input.dtype) + input_indexes, source_indexes = _scatter_index(dim, index) + # "Zero out" target elements when not included + if not include_self: + if reduce in ["sum", "mean"]: + base_input = jnp.zeros_like(src) + elif reduce == "prod": + base_input = jnp.ones_like(src) + elif reduce == "amax": + base_input = jnp.full_like(src, -jnp.inf) + else: # amin + base_input = jnp.full_like(src, jnp.inf) + input = input.at[input_indexes].set(base_input[source_indexes]) + + if reduce == "sum" or reduce == "add": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "prod" or reduce == "multiply": + return input.at[input_indexes].multiply(src[source_indexes]) + elif reduce == "mean": + if include_self: + count = jnp.ones_like(input) + else: + count = jnp.zeros_like(input) + count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) + count = jnp.clip(count, min=1) + mean = input.at[input_indexes].add(src[source_indexes]) + if _is_int(input): + return mean // count + return mean / count elif reduce == "amax": - base_input = jnp.full_like(src, -jnp.inf) - else: # amin - base_input = jnp.full_like(src, jnp.inf) - input = input.at[input_indexes].set(base_input[source_indexes]) - - if reduce == "sum" or reduce == "add": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "prod" or reduce == "multiply": - return input.at[input_indexes].multiply(src[source_indexes]) - elif reduce == "mean": - if include_self: - count = jnp.ones_like(input) + return input.at[input_indexes].max(src[source_indexes]) + elif reduce == "amin": + return input.at[input_indexes].min(src[source_indexes]) else: - count = jnp.zeros_like(input) - count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) - count = jnp.clip(count, min=1) - mean = input.at[input_indexes].add(src[source_indexes]) - if _is_int(input): - return mean // count - return mean / count - elif reduce == "amax": - return input.at[input_indexes].max(src[source_indexes]) - elif reduce == "amin": - return input.at[input_indexes].min(src[source_indexes]) - else: - return input.at[input_indexes].set(src[source_indexes]) + return input.at[input_indexes].set(src[source_indexes]) # aten.acos @op(torch.ops.aten.acos) @op_base.promote_int_input def _aten_acos(self): - return jnp.arccos(self) + return jnp.arccos(self) # aten.sym_storage_offset @@ -1873,73 +1911,77 @@ def _aten_acos(self): # aten.gt @op(torch.ops.aten.gt) def _aten_gt(self, other): - return self > other + return self > other # aten.sym_stride # aten.lt @op(torch.ops.aten.lt) def _aten_lt(self, other): - return self < other + return self < other def pool(inputs, init, reduce_fn, window_shape, strides, 
padding): - """Helper function to define pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply that - pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form ``(T, T) -> T``. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of ``n`` integers, representing the inter-window - strides (default: ``(1, ..., 1)``). - padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence - of ``n`` ``(low, high)`` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f"len({window_shape}) must equal len({strides})" - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" - padding = ((0, 0), (0, 0)) + padding - y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - if is_single_input: - y = jnp.squeeze(y, axis=0) - return y + """Helper function to define pooling functions. + + Pooling functions are implemented using the ReduceWindow XLA op. + NOTE: Be aware that pooling is not generally differentiable. + That means providing a reduce_fn that is differentiable does not imply that + pool is differentiable. + + Args: + inputs: input data with dimensions (batch, window dims..., features). + init: the initial value for the reduction + reduce_fn: a reduce function of the form ``(T, T) -> T``. + window_shape: a shape tuple defining the window to reduce over. + strides: a sequence of ``n`` integers, representing the inter-window + strides (default: ``(1, ..., 1)``). + padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence + of ``n`` ``(low, high)`` integer pairs that give the padding to apply before + and after each spatial dimension. + Returns: + The output of the reduction for each window slice. + """ + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len(strides), ( + f"len({window_shape}) must equal len({strides})" + ) + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. 
+ inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" + ) + padding = ((0, 0), (0, 0)) + padding + y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) + if is_single_input: + y = jnp.squeeze(y, axis=0) + return y @op(torch.ops.aten._adaptive_avg_pool2d) @op(torch.ops.aten._adaptive_avg_pool3d) -def adaptive_avg_pool2or3d(input: jnp.ndarray, - output_size: Tuple[int, int]) -> jnp.ndarray: - """ +def adaptive_avg_pool2or3d( + input: jnp.ndarray, output_size: Tuple[int, int] +) -> jnp.ndarray: + """ Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. @@ -1951,115 +1993,127 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, Context: https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 """ - shape = input.shape - ndim = len(shape) - out_dim = len(output_size) - num_spatial_dim = ndim - out_dim - - # Preconditions - - assert ndim in ( - out_dim + 1, out_dim + 2 - ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}" - for d in input.shape[-2:]: - assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \ - f"non-batch dimensions, but input has shape {tuple(shape)}." - - # Optimisation (we should also do this in the kernel implementation) - if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): - stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) - kernel = tuple(i - (o - 1) * s - for i, o, s in zip(shape[-out_dim:], output_size, stride)) - return _aten_avg_pool( - input, - kernel, - strides=stride, - ) + shape = input.shape + ndim = len(shape) + out_dim = len(output_size) + num_spatial_dim = ndim - out_dim - def start_index(a, b, c): - return (a * c) // b - - def end_index(a, b, c): - return ((a + 1) * c + b - 1) // b - - def compute_idx(in_size, out_size): - orange = jnp.arange(out_size, dtype=jnp.int64) - i0 = start_index(orange, out_size, in_size) - # Let length = end_index - start_index, i.e. 
the length of the pooling kernels - # length.max() can be computed analytically as follows: - maxlength = in_size // out_size + 1 - in_size_mod = in_size % out_size - # adaptive = True iff there are kernels with different lengths - adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) - if adaptive: - maxlength += 1 - elif in_size_mod == 0: - maxlength -= 1 - - range_max = jnp.arange(maxlength, dtype=jnp.int64) - idx = i0[:, None] + range_max - if adaptive: - # Need to clamp to avoid accessing out-of-bounds memory - idx = jnp.minimum(idx, in_size - 1) - - # Compute the length - i1 = end_index(orange, out_size, in_size) - length = i1 - i0 - else: - length = maxlength - return idx, length, range_max, adaptive - - idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)] - # length is not None if it's constant, otherwise we'll need to compute it - for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)): - idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o) - - def _unsqueeze_to_dim(x, dim): - ndim = len(x.shape) - return jax.lax.expand_dims(x, tuple(range(ndim, dim))) - - if out_dim == 2: - # NOTE: unsqueeze to insert extra 1 in ranks; so they - # would broadcast - vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]] - reduce_axis = (-3, -1) - else: - assert out_dim == 3 - vals = input[..., - _unsqueeze_to_dim(idx[0], 6), - _unsqueeze_to_dim(idx[1], 4), idx[2]] - reduce_axis = (-5, -3, -1) - - # Shortcut for the simpler case - if not any(adaptive): - return jnp.mean(vals, axis=reduce_axis) - - def maybe_mask(vals, length, range_max, adaptive, dim): - if isinstance(length, int): - return vals, length + # Preconditions + + assert ndim in (out_dim + 1, out_dim + 2), ( + f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim + 1}D or {num_spatial_dim + 2}D tensor, but got {ndim}" + ) + for d in input.shape[-2:]: + assert d != 0, ( + "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}." + ) + + # Optimisation (we should also do this in the kernel implementation) + if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): + stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) + kernel = tuple( + i - (o - 1) * s + for i, o, s in zip(shape[-out_dim:], output_size, stride) + ) + return _aten_avg_pool( + input, + kernel, + strides=stride, + ) + + def start_index(a, b, c): + return (a * c) // b + + def end_index(a, b, c): + return ((a + 1) * c + b - 1) // b + + def compute_idx(in_size, out_size): + orange = jnp.arange(out_size, dtype=jnp.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. 
the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = jnp.arange(maxlength, dtype=jnp.int64) + idx = i0[:, None] + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + idx = jnp.minimum(idx, in_size - 1) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)] + # length is not None if it's constant, otherwise we'll need to compute it + for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)): + idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o) + + def _unsqueeze_to_dim(x, dim): + ndim = len(x.shape) + return jax.lax.expand_dims(x, tuple(range(ndim, dim))) + + if out_dim == 2: + # NOTE: unsqueeze to insert extra 1 in ranks; so they + # would broadcast + vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]] + reduce_axis = (-3, -1) else: - # zero-out the things we didn't really want to select - assert dim < 0 - # hack - mask = range_max >= length[:, None] - if dim == -2: - mask = _unsqueeze_to_dim(mask, 4) - elif dim == -3: - mask = _unsqueeze_to_dim(mask, 6) - vals = jnp.where(mask, 0.0, vals) - # Compute the length of each window - length = _unsqueeze_to_dim(length, -dim) - return vals, length - - for i in range(len(length)): - vals, length[i] = maybe_mask( - vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim)) - - # We unroll the sum as we assume that the kernels are going to be small - ret = jnp.sum(vals, axis=reduce_axis) - # NOTE: math.prod because we want to expand it to length[0] * length[1] * ... - # this is multiplication with broadcasting, not regular pointwise product - return ret / math.prod(length) + assert out_dim == 3 + vals = input[ + ..., + _unsqueeze_to_dim(idx[0], 6), + _unsqueeze_to_dim(idx[1], 4), + idx[2], + ] + reduce_axis = (-5, -3, -1) + + # Shortcut for the simpler case + if not any(adaptive): + return jnp.mean(vals, axis=reduce_axis) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, int): + return vals, length + else: + # zero-out the things we didn't really want to select + assert dim < 0 + # hack + mask = range_max >= length[:, None] + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + elif dim == -3: + mask = _unsqueeze_to_dim(mask, 6) + vals = jnp.where(mask, 0.0, vals) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + for i in range(len(length)): + vals, length[i] = maybe_mask( + vals, + length[i], + range_max[i], + adaptive=adaptive[i], + dim=(i - out_dim), + ) + + # We unroll the sum as we assume that the kernels are going to be small + ret = jnp.sum(vals, axis=reduce_axis) + # NOTE: math.prod because we want to expand it to length[0] * length[1] * ... 
+ # this is multiplication with broadcasting, not regular pointwise product + return ret / math.prod(length) @op(torch.ops.aten.avg_pool1d) @@ -2074,120 +2128,132 @@ def _aten_avg_pool( count_include_pad=True, divisor_override=None, ): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, list) and len(padding) == 1: - padding = padding[0] - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - [1] * len(kernel_size), ceil_mode) - - y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) - if divisor_override is not None: - y = y / jnp.array(divisor_override, y.dtype) - elif count_include_pad: - div_shape = list(y.shape) - div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size) - unequal_paddings = map(lambda pad: pad[0] != pad[1], padding) - unequal_padding_indices = np.where(list(unequal_paddings))[0] - if len(unequal_padding_indices) > 0: - # indices to update kernel size - offset = len(div_shape) - len(padding) - skip_indices = list(map(lambda x: x + offset, unequal_padding_indices)) - indices = _generate_indices(div_shape, skip_dim_indices=skip_indices) - # updated kernel size accounting for maximum padding - new_kernel_size = list(kernel_size) - for j in unequal_padding_indices: - new_kernel_size[j] = kernel_size[j] - padding[j][1] + padding[j][0] - - for idx in indices: - for j in unequal_padding_indices: - idx[j + offset] = -1 - div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size)) - - y = y / div_by - else: - div_shape = list(inputs.shape) - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_size): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape, y.dtype), - jnp.array(0.0, y.dtype), - jax.lax.add, + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) if strides else kernel_size + if isinstance(padding, list) and len(padding) == 1: + padding = padding[0] + if isinstance(padding, int): + padding = [padding for _ in range(len(kernel_size))] + + input_shape = inputs.shape + if num_batch_dims == 0: + input_shape = [1, *input_shape] + padding = _ceil_mode_padding( + padding, + input_shape, kernel_size, strides, - padding, + [1] * len(kernel_size), + ceil_mode, ) - return y.astype(inputs.dtype) + + y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) + if divisor_override is not None: + y = y / jnp.array(divisor_override, y.dtype) + elif count_include_pad: + div_shape = list(y.shape) + div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size) + unequal_paddings = map(lambda pad: pad[0] != pad[1], padding) + unequal_padding_indices = np.where(list(unequal_paddings))[0] + if len(unequal_padding_indices) > 0: + # indices to update kernel size + offset = len(div_shape) - len(padding) + skip_indices = list( + map(lambda x: x + offset, unequal_padding_indices) + ) + indices = _generate_indices( + div_shape, skip_dim_indices=skip_indices + ) + # updated kernel size accounting for maximum padding + new_kernel_size = list(kernel_size) + for j in unequal_padding_indices: + new_kernel_size[j] = ( + kernel_size[j] - padding[j][1] + padding[j][0] + ) + + for idx in indices: + for j in unequal_padding_indices: + 
idx[j + offset] = -1 + div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size)) + + y = y / div_by + else: + div_shape = list(inputs.shape) + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_size): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape, y.dtype), + jnp.array(0.0, y.dtype), + jax.lax.add, + kernel_size, + strides, + padding, + ) + return y.astype(inputs.dtype) # helper function to generate all indices to iterate through ndarray def _generate_indices(dims, skip_dim_indices=[]): - res = [] - - def _helper(curr_dim_idx, sofar): - if curr_dim_idx in skip_dim_indices: - _helper(curr_dim_idx + 1, sofar[:]) - return - if curr_dim_idx >= len(dims): - res.append(sofar) - return - for i in range(dims[curr_dim_idx]): - sofar[curr_dim_idx] = i - _helper(curr_dim_idx + 1, sofar[:]) - - _helper(0, [0 for _ in dims]) - return res + res = [] + + def _helper(curr_dim_idx, sofar): + if curr_dim_idx in skip_dim_indices: + _helper(curr_dim_idx + 1, sofar[:]) + return + if curr_dim_idx >= len(dims): + res.append(sofar) + return + for i in range(dims[curr_dim_idx]): + sofar[curr_dim_idx] = i + _helper(curr_dim_idx + 1, sofar[:]) + + _helper(0, [0 for _ in dims]) + return res # aten.sym_numel # aten.reciprocal @op(torch.ops.aten.reciprocal) def _aten_reciprocal(a): - if _is_int(a): - return (1 / a).astype(jnp.dtype('float32')) - return 1 / a + if _is_int(a): + return (1 / a).astype(jnp.dtype("float32")) + return 1 / a # aten.select_scatter @op(torch.ops.aten.select_scatter) def _aten_select_scatter(input, src, dim, index): - input_indexes = [] - if dim < 0: - dim += len(input.shape) - for x in range(len(input.shape)): - if x == dim: - input_indexes.append(index) - else: - input_indexes.append(slice(None, None, None)) - return input.at[tuple(input_indexes)].set(src) + input_indexes = [] + if dim < 0: + dim += len(input.shape) + for x in range(len(input.shape)): + if x == dim: + input_indexes.append(index) + else: + input_indexes.append(slice(None, None, None)) + return input.at[tuple(input_indexes)].set(src) @op(torch.ops.aten.scatter.src) def _aten_scatter_src(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src[source_indexes]) + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src[source_indexes]) @op(torch.ops.aten.scatter.value) def _aten_scatter(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src) + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src) # aten.acosh @op(torch.ops.aten.acosh) @op_base.promote_int_input def _aten_acosh(self): - return jnp.arccosh(self) + return jnp.arccosh(self) # aten.avg_pool2d_backward @@ -2196,58 +2262,59 @@ def _aten_acosh(self): # aten.round @op(torch.ops.aten.round) def _aten_round(input, decimals=0): - return jnp.round(input, decimals) + return jnp.round(input, decimals) # aten.max @op(torch.ops.aten.max) def _aten_max(self, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar(jnp.max, self, dim, - keepdim), _with_reduction_scalar( - jnp.argmax, self, dim, - keepdim).astype(jnp.int64) - else: - return _with_reduction_scalar(jnp.max, self, dim, keepdim) + if dim is not None: + return _with_reduction_scalar( + jnp.max, self, dim, keepdim + ), _with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype( + 
jnp.int64 + ) + else: + return _with_reduction_scalar(jnp.max, self, dim, keepdim) # aten.maximum @op(torch.ops.aten.maximum) def _aten_maximum(self, other): - return jnp.maximum(self, other) + return jnp.maximum(self, other) # aten.abs @op(torch.ops.aten.abs) def _aten_abs(self): - return jnp.abs(self) + return jnp.abs(self) # generate aten.amax only @op(torch.ops.aten.amax) def _aten_amax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amax, self, dim, keepdim) + return _with_reduction_scalar(jnp.amax, self, dim, keepdim) def _with_reduction_scalar(jax_func, self, dim, keepdim): - expanded = False - if self.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - self = jnp.expand_dims(self, 0) - res = jax_func(self, axis=dim, keepdims=keepdim) - if expanded: - res = res.squeeze() - return res + expanded = False + if self.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + self = jnp.expand_dims(self, 0) + res = jax_func(self, axis=dim, keepdims=keepdim) + if expanded: + res = res.squeeze() + return res # aten.any @op(torch.ops.aten.any) def _aten_any(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.any, self, dim, keepdim) + return _with_reduction_scalar(jnp.any, self, dim, keepdim) # aten.arange @@ -2266,739 +2333,759 @@ def _aten_arange( device=None, pin_memory=False, ): - return jnp.arange( - op_base.maybe_convert_constant_dtype(start, dtype), - op_base.maybe_convert_constant_dtype(end, dtype), - op_base.maybe_convert_constant_dtype(step, dtype), - dtype=dtype, - ) + return jnp.arange( + op_base.maybe_convert_constant_dtype(start, dtype), + op_base.maybe_convert_constant_dtype(end, dtype), + op_base.maybe_convert_constant_dtype(step, dtype), + dtype=dtype, + ) # aten.argmax @op(torch.ops.aten.argmax) def _aten_argmax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) + return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) def _strided_index(sizes, strides, storage_offset=None): - ind = jnp.zeros(sizes, dtype=jnp.int32) + ind = jnp.zeros(sizes, dtype=jnp.int32) - for i, (size, stride) in enumerate(zip(sizes, strides)): - result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) - indexes = (jnp.arange(size) * stride).reshape(result_shape) - ind += indexes + for i, (size, stride) in enumerate(zip(sizes, strides)): + result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) + indexes = (jnp.arange(size) * stride).reshape(result_shape) + ind += indexes - if storage_offset is not None: - ind += storage_offset - return ind + if storage_offset is not None: + ind += storage_offset + return ind # aten.as_strided @op(torch.ops.aten.as_strided) @op(torch.ops.aten.as_strided_copy) def _aten_as_strided(x, sizes, strides, storage_offset=None): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - return flattened[ind] + ind = _strided_index(sizes, strides, storage_offset) + flattened = jnp.ravel(x) + return flattened[ind] @op(torch.ops.aten.as_strided_scatter) def _aten_as_strided_scatter(x, src, sizes, strides, storage_offset): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - modified = flattened.at[ind].set(src) - return 
modified.reshape(x.shape) + ind = _strided_index(sizes, strides, storage_offset) + flattened = jnp.ravel(x) + modified = flattened.at[ind].set(src) + return modified.reshape(x.shape) # aten.atan2 @op(torch.ops.aten.atan2) @op_base.promote_int_input def _aten_atan2(input, other): - return jnp.arctan2(input, other) + return jnp.arctan2(input, other) # aten.bitwise_and @op(torch.ops.aten.bitwise_and) @op(torch.ops.aten.__and__) def _aten_bitwise_and(self, other): - return self & other + return self & other # aten.bitwise_or @op(torch.ops.aten.bitwise_or) def _aten_bitwise_or(self, other): - return self | other + return self | other # aten.bitwise_xor @op(torch.ops.aten.bitwise_xor) def _aten_bitwise_xor(self, other): - return self ^ other + return self ^ other # aten.broadcast_tensors @op(torch.ops.aten.broadcast_tensors) def _aten_broadcast_tensors(*tensors): - - def _get_broadcast_shape(shapes): - """ - Determines the output shape by broadcasting all input shapes. - - Args: - shapes: A list of tuples representing the shapes of the input tensors. - - Returns: - A tuple representing the broadcasted output shape. - """ - - # Find the maximum number of dimensions among all input tensors - max_dims = max(len(shape) for shape in shapes) - # Pad shorter shapes with 1s on the left to match the maximum number of dimensions - padded_shapes = [(1,) * (max_dims - len(shape)) + shape for shape in shapes] - - # Initialize the output shape with 1s - output_shape = [1] * max_dims - # Iterate through each dimension and apply broadcasting rules - for dim in range(max_dims): - dim_sizes = [shape[dim] for shape in padded_shapes] - max_size = max(dim_sizes) - if all(size == 1 or size == max_size for size in dim_sizes): - output_shape[dim] = max_size - else: - raise ValueError("Incompatible shapes for broadcasting") - return tuple(output_shape) - - def _broadcast_dimensions(input_shape, output_shape): - """ - Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim. - - Args: - input_shape: The shape of the input tensor. - output_shape: The desired output shape after broadcasting. - - Returns: - A tuple specifying which dimensions of the input tensor should be broadcasted. - """ - - res = tuple( - i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape))) - return res - - # clean some function's previous wrap - if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance( - tensors[0][0], jax.Array): - tensors = tensors[0] - - # Get the shapes of all input tensors - shapes = [t.shape for t in tensors] - # Find the output shape by broadcasting all input shapes - output_shape = _get_broadcast_shape(shapes) - # Broadcast each tensor to the output shape - broadcasted_tensors = [ - jax.lax.broadcast_in_dim(t, output_shape, - _broadcast_dimensions(t.shape, output_shape)) - for t in tensors - ] - - return broadcasted_tensors + def _get_broadcast_shape(shapes): + """ + Determines the output shape by broadcasting all input shapes. + + Args: + shapes: A list of tuples representing the shapes of the input tensors. + + Returns: + A tuple representing the broadcasted output shape. 
+ """ + + # Find the maximum number of dimensions among all input tensors + max_dims = max(len(shape) for shape in shapes) + # Pad shorter shapes with 1s on the left to match the maximum number of dimensions + padded_shapes = [ + (1,) * (max_dims - len(shape)) + shape for shape in shapes + ] + + # Initialize the output shape with 1s + output_shape = [1] * max_dims + # Iterate through each dimension and apply broadcasting rules + for dim in range(max_dims): + dim_sizes = [shape[dim] for shape in padded_shapes] + max_size = max(dim_sizes) + if all(size == 1 or size == max_size for size in dim_sizes): + output_shape[dim] = max_size + else: + raise ValueError("Incompatible shapes for broadcasting") + return tuple(output_shape) + + def _broadcast_dimensions(input_shape, output_shape): + """ + Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim. + + Args: + input_shape: The shape of the input tensor. + output_shape: The desired output shape after broadcasting. + + Returns: + A tuple specifying which dimensions of the input tensor should be broadcasted. + """ + + res = tuple( + i + for i, (in_dim, out_dim) in enumerate( + zip(input_shape, output_shape) + ) + ) + return res + + # clean some function's previous wrap + if ( + len(tensors) == 1 + and len(tensors[0]) >= 1 + and isinstance(tensors[0][0], jax.Array) + ): + tensors = tensors[0] + + # Get the shapes of all input tensors + shapes = [t.shape for t in tensors] + # Find the output shape by broadcasting all input shapes + output_shape = _get_broadcast_shape(shapes) + # Broadcast each tensor to the output shape + broadcasted_tensors = [ + jax.lax.broadcast_in_dim( + t, output_shape, _broadcast_dimensions(t.shape, output_shape) + ) + for t in tensors + ] + + return broadcasted_tensors # aten.broadcast_to @op(torch.ops.aten.broadcast_to) def _aten_broadcast_to(input, shape): - return jnp.broadcast_to(input, shape) + return jnp.broadcast_to(input, shape) # aten.clamp @op(torch.ops.aten.clamp.default) @op(torch.ops.aten.clamp.Tensor) def _aten_clamp(self, min=None, max=None): - return jnp.clip(self, min, max) + return jnp.clip(self, min, max) @op(torch.ops.aten.clamp_min) def _aten_clamp_min(input, min): - return jnp.clip(input, min=min) + return jnp.clip(input, min=min) # aten.constant_pad_nd @op(torch.ops.aten.constant_pad_nd) def _aten_constant_pad_nd(input, padding, value=0): - # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) - # means last dim get padded 1 in front and 1 in back; - # and second last dim get padded 2 in front and 2 in back. - # Jax padding tuple of 3-tuple: the same padding is - # [(0, 0, 0), ..., (2,2,0), (1,1,0)], where the last dimension - # is the amount of padding added between any two elements in each dimension - m = len(padding) - rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)] - pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding) - value_casted = jax.numpy.array(value, dtype=input.dtype) - return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim) + # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) + # means last dim get padded 1 in front and 1 in back; + # and second last dim get padded 2 in front and 2 in back. 
+ # Jax padding tuple of 3-tuple: the same padding is + # [(0, 0, 0), ..., (2,2,0), (1,1,0)], where the last dimension + # is the amount of padding added between any two elements in each dimension + m = len(padding) + rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)] + pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding) + value_casted = jax.numpy.array(value, dtype=input.dtype) + return jax.lax.pad( + input, padding_value=value_casted, padding_config=pad_dim + ) # aten.convolution_backward @op(torch.ops.aten.lift_fresh_copy) def _aten_lift_fresh_copy(x): - return jnp.copy(x) + return jnp.copy(x) @op(torch.ops.aten.copy) def _aten_copy(self, src): - return jnp.broadcast_to(src, self.shape).astype(self.dtype) + return jnp.broadcast_to(src, self.shape).astype(self.dtype) @op(torch.ops.aten._cdist_forward) def _aten_cdist_forward(x1, x2, p, compute_mode=""): - # x1 is B x P x M - # x2 is B x Q x M - # res is B x P x Q - x1 = jnp.expand_dims(x1, len(x1.shape) - 1) - x2 = jnp.expand_dims(x2, len(x2.shape) - 2) - return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) + # x1 is B x P x M + # x2 is B x Q x M + # res is B x P x Q + x1 = jnp.expand_dims(x1, len(x1.shape) - 1) + x2 = jnp.expand_dims(x2, len(x2.shape) - 2) + return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) @op(torch.ops.aten._pdist_forward) def _aten__pdist_forward(x, p=2): - pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] - return condensed_dists + pairwise_dists = _aten_cdist_forward(x, x, p) + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] + return condensed_dists @op(torch.ops.aten.cholesky_inverse) def _aten_cholesky_inverse(input, upper=False): - t = jnp.matrix_transpose(input) - if "complex" in str(input.dtype): - t = t.conjugate() - return jnp.linalg.inv(input @ t) + t = jnp.matrix_transpose(input) + if "complex" in str(input.dtype): + t = t.conjugate() + return jnp.linalg.inv(input @ t) # aten.cos @op(torch.ops.aten.cos) @op_base.promote_int_input def _aten_cos(input): - return jnp.cos(input) + return jnp.cos(input) # aten.cosh @op(torch.ops.aten.cosh) @op_base.promote_int_input def _aten_cosh(input): - return jnp.cosh(input) + return jnp.cosh(input) @op(torch.ops.aten.diag) def _aten_diag(input, diagonal=0): - return jnp.diag(input, diagonal) + return jnp.diag(input, diagonal) # aten.diagonal @op(torch.ops.aten.diagonal) def _aten_diagonal(input, offset=0, dim1=0, dim2=1): - return jnp.diagonal(input, offset, dim1, dim2) + return jnp.diagonal(input, offset, dim1, dim2) def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1): - input_len = len(input_shape) - if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): - raise ValueError("dim1 and dim2 must be different and in range [0, " + - str(input_len - 1) + "]") - - size1, size2 = input_shape[dim1], input_shape[dim2] - if offset >= 0: - indices1 = jnp.arange(min(size1, size2 - offset)) - indices2 = jnp.arange(offset, offset + len(indices1)) - else: - indices2 = jnp.arange(min(size1 + offset, size2)) - indices1 = jnp.arange(-offset, -offset + len(indices2)) - return [indices1, indices2] + input_len = len(input_shape) + if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): + raise ValueError( + "dim1 and dim2 must be different and in range [0, " + + str(input_len - 1) + + "]" + ) + + size1, size2 = input_shape[dim1], input_shape[dim2] + if offset >= 0: + indices1 = 
jnp.arange(min(size1, size2 - offset)) + indices2 = jnp.arange(offset, offset + len(indices1)) + else: + indices2 = jnp.arange(min(size1 + offset, size2)) + indices1 = jnp.arange(-offset, -offset + len(indices2)) + return [indices1, indices2] @op(torch.ops.aten.diagonal_scatter) def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1): - indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2) - - if input.ndim == 2: - return input.at[tuple(indexes)].set(src) - else: - # src has the same shape as the output of - # jnp.diagonal(input, offset, dim1, dim2). - # Last dimension always contains the diagonal elements, - # while the preceding dimensions represent the "slices" - # from which these diagonals are extracted. Thus, - # we alter input axes to match this assumption, write src - # and then move the axes back to the original state. - input = jnp.moveaxis(input, (dim1, dim2), (-2, -1)) - multi_indexes = [slice(None)] * (input.ndim - 2) + indexes - input = input.at[tuple(multi_indexes)].set(src) - return jnp.moveaxis(input, (-2, -1), (dim1, dim2)) + indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2) + + if input.ndim == 2: + return input.at[tuple(indexes)].set(src) + else: + # src has the same shape as the output of + # jnp.diagonal(input, offset, dim1, dim2). + # Last dimension always contains the diagonal elements, + # while the preceding dimensions represent the "slices" + # from which these diagonals are extracted. Thus, + # we alter input axes to match this assumption, write src + # and then move the axes back to the original state. + input = jnp.moveaxis(input, (dim1, dim2), (-2, -1)) + multi_indexes = [slice(None)] * (input.ndim - 2) + indexes + input = input.at[tuple(multi_indexes)].set(src) + return jnp.moveaxis(input, (-2, -1), (dim1, dim2)) # aten.diagflat @op(torch.ops.aten.diagflat) def _aten_diagflat(input, offset=0): - return jnp.diagflat(jnp.array(input), offset) + return jnp.diagflat(jnp.array(input), offset) @op(torch.ops.aten.movedim) def _aten_movedim(input, source, destination): - return jnp.moveaxis(input, source, destination) + return jnp.moveaxis(input, source, destination) # aten.eq @op(torch.ops.aten.eq) def _aten_eq(input1, input2): - return input1 == input2 + return input1 == input2 # aten.equal @op(torch.ops.aten.equal) def _aten_equal(input, other): - res = jnp.array_equal(input, other) - return bool(res) + res = jnp.array_equal(input, other) + return bool(res) # aten.erf @op(torch.ops.aten.erf) @op_base.promote_int_input def _aten_erf(x): - return jax.lax.erf(x) + return jax.lax.erf(x) @op(torch.ops.aten.erfinv) @op_base.promote_int_input def _aten_erfinv(input): - return jax.lax.erf_inv(input) + return jax.lax.erf_inv(input) # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input): - res = jnp.exp(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.exp(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.expm1 @op(torch.ops.aten.expm1) def _aten_expm1(input): - res = jnp.expm1(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.expm1(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.exp2 @op(torch.ops.aten.exp2) 
def _aten_exp2(input): - res = jnp.exp2(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.exp2(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.fill @op(torch.ops.aten.fill) @op(torch.ops.aten.full_like) -def _aten_fill(x, - value, - dtype=None, - pin_memory=None, - memory_format=None, - device=None): - if dtype is None: - dtype = x.dtype - else: - dtype = mappings.t2j_dtype(dtype) - return jnp.full(x.shape, value, dtype) +def _aten_fill( + x, value, dtype=None, pin_memory=None, memory_format=None, device=None +): + if dtype is None: + dtype = x.dtype + else: + dtype = mappings.t2j_dtype(dtype) + return jnp.full(x.shape, value, dtype) # aten.flip @op(torch.ops.aten.flip) def _aten_flip(input, dims): - if dims is not None: - return jnp.flip(input, tuple(dims)) - else: - return jnp.flip(input) + if dims is not None: + return jnp.flip(input, tuple(dims)) + else: + return jnp.flip(input) # aten.floor @op(torch.ops.aten.floor) def _aten_floor(input): - return jnp.floor(input).astype(input.dtype) + return jnp.floor(input).astype(input.dtype) # aten.fmax @op(torch.ops.aten.fmax) def _aten_fmax(input, other): - return jnp.fmax(input, other) + return jnp.fmax(input, other) # aten.fmin @op(torch.ops.aten.fmin) def _aten_fmin(input, other): - return jnp.fmin(input, other) + return jnp.fmin(input, other) # aten.fmod @op(torch.ops.aten.fmod) def _aten_fmod(input, other): - return input - other * _aten_div(input, other, "trunc") + return input - other * _aten_div(input, other, "trunc") # aten.frexp @op(torch.ops.aten.frexp) def _aten_frexp(input): - return jnp.frexp(input) + return jnp.frexp(input) # aten.gather @op(torch.ops.aten.gather) def _aten_gather(input, dim, index): - if input.ndim == 0: - return jnp.broadcast_to(input, index.shape) - # short circuit for empty outputs - if not all(index.shape): - return jnp.zeros(index.shape, dtype=input.dtype) - if dim < 0: - dim += input.ndim - input_indexes, source_indexes = _scatter_index(dim, index) - return input[input_indexes] + if input.ndim == 0: + return jnp.broadcast_to(input, index.shape) + # short circuit for empty outputs + if not all(index.shape): + return jnp.zeros(index.shape, dtype=input.dtype) + if dim < 0: + dim += input.ndim + input_indexes, source_indexes = _scatter_index(dim, index) + return input[input_indexes] # aten.ge @op(torch.ops.aten.ge) def _aten_ge(self, other): - return self >= other + return self >= other @op(torch.ops.aten.glu) def _aten_glu(x, dim=-1): - return jax.nn.glu(x, dim) + return jax.nn.glu(x, dim) # aten.hardtanh @op(torch.ops.aten.hardtanh) def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False): - if input.dtype == np.int64 and isinstance(max_val, float) and isinstance( - min_val, float): - min_val = int(min_val) - max_val = int(max_val) - return jnp.clip(input, min_val, max_val) + if ( + input.dtype == np.int64 + and isinstance(max_val, float) + and isinstance(min_val, float) + ): + min_val = int(min_val) + max_val = int(max_val) + return jnp.clip(input, min_val, max_val) # aten.histc @op(torch.ops.aten.histc) def _aten_histc(input, bins=100, min=0, max=0): - # TODO(@manfei): this function might cause some uncertainty - if min == 0 and max == 0: - if isinstance(input, jnp.ndarray) and input.size == 0: - min = 0 - max = 0 - else: - min = jnp.min(input) - max = jnp.max(input) - range_value = (min, 
max) - hist, bin_edges = jnp.histogram( - input, bins=bins, range=range_value, weights=None, density=None) - return hist + # TODO(@manfei): this function might cause some uncertainty + if min == 0 and max == 0: + if isinstance(input, jnp.ndarray) and input.size == 0: + min = 0 + max = 0 + else: + min = jnp.min(input) + max = jnp.max(input) + range_value = (min, max) + hist, bin_edges = jnp.histogram( + input, bins=bins, range=range_value, weights=None, density=None + ) + return hist @op(torch.ops.aten.hypot) def _aten_hypot(input, other): - return jnp.hypot(input, other) + return jnp.hypot(input, other) @op(torch.ops.aten.digamma) def _aten_digamma(input, *, out=None): - res = jax.scipy.special.digamma(input).astype(jnp.float32) - # replace indices where input == 0 with -inf in res - return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) + res = jax.scipy.special.digamma(input).astype(jnp.float32) + # replace indices where input == 0 with -inf in res + return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) @op(torch.ops.aten.igamma) def _aten_igamma(input, other): - return jax.scipy.special.gammainc(input, other) + return jax.scipy.special.gammainc(input, other) @op(torch.ops.aten.lgamma) def _aten_lgamma(input, *, out=None): - return jax.scipy.special.gammaln(input).astype(jnp.float32) + return jax.scipy.special.gammaln(input).astype(jnp.float32) @op(torch.ops.aten.mvlgamma) def _aten_mvlgamma(input, p, *, out=None): - input = input.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.scipy.special.multigammaln(input, p) + input = input.astype(mappings.t2j_dtype(torch.get_default_dtype())) + return jax.scipy.special.multigammaln(input, p) @op(torch.ops.aten.linalg_eig) def _aten_linalg_eig(A): - return jnp.linalg.eig(A) + return jnp.linalg.eig(A) @op(torch.ops.aten._linalg_eigh) -def _aten_linalg_eigh(A, UPLO='L'): - return jnp.linalg.eigh(A, UPLO) +def _aten_linalg_eigh(A, UPLO="L"): + return jnp.linalg.eigh(A, UPLO) @op(torch.ops.aten.linalg_lstsq) -def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): - input_dtype = A.dtype +def _aten_linalg_lstsq(A, B, rcond=None, driver="gelsy"): + input_dtype = A.dtype - m = A.shape[-2] - n = A.shape[-1] + m = A.shape[-2] + n = A.shape[-1] - is_batched = A.ndim > 2 + is_batched = A.ndim > 2 - if is_batched: + if is_batched: + batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) + batch_size = int(np.prod(batch_shape)) + A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) + B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) - batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) - batch_size = int(np.prod(batch_shape)) - A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) - B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) + X, residuals, rank, singular_values = jax.vmap( + jnp.linalg.lstsq, in_axes=(0, 0) + )(A_reshaped, B_reshaped, rcond=rcond) - X, residuals, rank, singular_values = jax.vmap( - jnp.linalg.lstsq, in_axes=(0, - 0))(A_reshaped, B_reshaped, rcond=rcond) + X = X.reshape(batch_shape + X.shape[-2:]) - X = X.reshape(batch_shape + X.shape[-2:]) + if driver in ["gelsd", "gelsy", "gelss"]: + rank = rank.reshape(batch_shape) + else: + rank = jnp.array([], dtype=jnp.int64) - if driver in ['gelsd', 'gelsy', 'gelss']: - rank = rank.reshape(batch_shape) - else: - rank = jnp.array([], dtype=jnp.int64) + full_rank = jnp.all(rank == n) + if driver == "gelsy" or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) + else: + residuals = 
residuals.reshape(batch_shape + residuals.shape[-1:]) - full_rank = jnp.all(rank == n) - if driver == 'gelsy' or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) - else: - residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) + if driver in ["gelsd", "gelss"]: + singular_values = singular_values.reshape( + batch_shape + singular_values.shape[-1:] + ) + else: + singular_values = jnp.array([], dtype=input_dtype) - if driver in ['gelsd', 'gelss']: - singular_values = singular_values.reshape(batch_shape + - singular_values.shape[-1:]) else: - singular_values = jnp.array([], dtype=input_dtype) - - else: + X, residuals, rank, singular_values = jnp.linalg.lstsq( + A, B, rcond=rcond + ) - X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond) + if driver not in ["gelsd", "gelsy", "gelss"]: + rank = jnp.array([], dtype=jnp.int64) - if driver not in ['gelsd', 'gelsy', 'gelss']: - rank = jnp.array([], dtype=jnp.int64) + rank_value = None + if rank.size > 0: + rank_value = int(rank.item()) + rank = jnp.array(rank_value, dtype=jnp.int64) - rank_value = None - if rank.size > 0: - rank_value = int(rank.item()) - rank = jnp.array(rank_value, dtype=jnp.int64) + # When driver is ‘gels’, assume that A is full-rank. + full_rank = driver == "gels" or rank_value == n + if driver == "gelsy" or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) - # When driver is ‘gels’, assume that A is full-rank. - full_rank = driver == 'gels' or rank_value == n - if driver == 'gelsy' or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) + if driver not in ["gelsd", "gelss"]: + singular_values = jnp.array([], dtype=input_dtype) - if driver not in ['gelsd', 'gelss']: - singular_values = jnp.array([], dtype=input_dtype) - - return X, residuals, rank, singular_values + return X, residuals, rank, singular_values @op(torch.ops.aten.linalg_ldl_factor_ex) def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): - # TODO: Replace with native LDL when available: - # https://github.com/jax-ml/jax/issues/12779 - # TODO: Not tested for complex inputs. Does not support hermitian=True - pivots = jnp.broadcast_to( - jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1]) - info = jnp.zeros(A.shape[:-2], jnp.int32) - C = jnp.linalg.cholesky(A) - if C.size == 0: - return C, pivots, info - - # Fill diagonals of stacked matrices - @functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)') - def fill_diagonal_batch(x, y): - return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) - - D = C * jnp.eye(C.shape[-1], dtype=A.dtype) - LD = C @ jnp.linalg.inv(D) - LD = fill_diagonal_batch(LD, D * D) - return LD, pivots, info + # TODO: Replace with native LDL when available: + # https://github.com/jax-ml/jax/issues/12779 + # TODO: Not tested for complex inputs. 
Does not support hermitian=True + pivots = jnp.broadcast_to( + jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1] + ) + info = jnp.zeros(A.shape[:-2], jnp.int32) + C = jnp.linalg.cholesky(A) + if C.size == 0: + return C, pivots, info + + # Fill diagonals of stacked matrices + @functools.partial(jnp.vectorize, signature="(k,k),(k,k)->(k,k)") + def fill_diagonal_batch(x, y): + return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) + + D = C * jnp.eye(C.shape[-1], dtype=A.dtype) + LD = C @ jnp.linalg.inv(D) + LD = fill_diagonal_batch(LD, D * D) + return LD, pivots, info @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None): - dtype = A.dtype + dtype = A.dtype - *_, m, n = A.shape - k = jnp.minimum(m, n) + *_, m, n = A.shape + k = jnp.minimum(m, n) - lu, _, permutation = jax.lax.linalg.lu(A) + lu, _, permutation = jax.lax.linalg.lu(A) - L = jnp.tril(lu[..., :, :k], k=-1) - eye_L = jnp.eye(m, k, dtype=dtype) - L = L + eye_L + L = jnp.tril(lu[..., :, :k], k=-1) + eye_L = jnp.eye(m, k, dtype=dtype) + L = L + eye_L - U = jnp.triu(lu[..., :k, :]) + U = jnp.triu(lu[..., :k, :]) - def perm_to_P(perm): - m = perm.shape[-1] - P = jnp.eye(m, dtype=dtype)[perm].T - return P + def perm_to_P(perm): + m = perm.shape[-1] + P = jnp.eye(m, dtype=dtype)[perm].T + return P - if permutation.ndim > 1: - num_batch_dims = permutation.ndim - 1 - for _ in range(num_batch_dims): - perm_to_P = jax.vmap(perm_to_P, in_axes=0) + if permutation.ndim > 1: + num_batch_dims = permutation.ndim - 1 + for _ in range(num_batch_dims): + perm_to_P = jax.vmap(perm_to_P, in_axes=0) - P = perm_to_P(permutation) + P = perm_to_P(permutation) - return P, L, U + return P, L, U @op(torch.ops.aten.linalg_lu_factor_ex) def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False): - lu, pivots, _ = jax.lax.linalg.lu(A) - # PT pivots vector is 1-indexed - pivots = pivots + 1 - info = jnp.zeros(A.shape[:-2], jnp.int32) - return lu, pivots, info + lu, pivots, _ = jax.lax.linalg.lu(A) + # PT pivots vector is 1-indexed + pivots = pivots + 1 + info = jnp.zeros(A.shape[:-2], jnp.int32) + return lu, pivots, info @op(torch.ops.aten.linalg_lu_solve) def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False): - # JAX pivots are offset by 1 compared to torch - pivots = pivots - 1 - if not left: - # XA = B is same as A'X = B' - trans = 0 if adjoint else 2 - x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans) - x = jnp.matrix_transpose(x) - else: - trans = 2 if adjoint else 0 - x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans) - return x + # JAX pivots are offset by 1 compared to torch + pivots = pivots - 1 + if not left: + # XA = B is same as A'X = B' + trans = 0 if adjoint else 2 + x = jax.scipy.linalg.lu_solve( + (LU, pivots), jnp.matrix_transpose(B), trans + ) + x = jnp.matrix_transpose(x) + else: + trans = 2 if adjoint else 0 + x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans) + return x @op(torch.ops.aten.gcd) def _aten_gcd(input, other): - return jnp.gcd(input, other) + return jnp.gcd(input, other) # aten.lcm @op(torch.ops.aten.lcm) def _aten_lcm(input, other): - return jnp.lcm(input, other) + return jnp.lcm(input, other) # aten.isinf @op(torch.ops.aten.isinf) def _aten_isinf(input): - return jnp.isinf(input) + return jnp.isinf(input) # aten.isnan @op(torch.ops.aten.isnan) def _aten_isnan(input): - return jnp.isnan(input) + return jnp.isnan(input) @op(torch.ops.aten.le) def _aten_le(self, other): - return self <= other + return self <= other # aten.leaky_relu 
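# A minimal standalone sketch (assuming jax and numpy are importable in this
# environment) of the pivot convention the LU lowerings above translate:
# torch expects 1-indexed pivots, while jax.lax.linalg.lu and
# jax.scipy.linalg.lu_solve use 0-indexed pivots, so the factor lowering adds
# 1 and the solve lowering subtracts 1 again.
import jax
import jax.numpy as jnp
import numpy as np

A = jnp.asarray(np.random.rand(4, 4), dtype=jnp.float32)
b = jnp.asarray(np.random.rand(4, 1), dtype=jnp.float32)

lu, pivots, _ = jax.lax.linalg.lu(A)   # 0-indexed pivots (JAX convention)
torch_pivots = pivots + 1              # what the factor lowering hands back to torch
x = jax.scipy.linalg.lu_solve((lu, torch_pivots - 1), b)  # solve lowering undoes the offset
print(jnp.allclose(A @ x, b, atol=1e-4))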
@op(torch.ops.aten.leaky_relu) def _aten_leaky_relu(x, negative_slope=0.01): - return jax.nn.leaky_relu(x, negative_slope) + return jax.nn.leaky_relu(x, negative_slope) # aten.log @op(torch.ops.aten.log) @op_base.promote_int_input def _aten_log(x): - return jnp.log(x) + return jnp.log(x) # aten.log10 @op(torch.ops.aten.log10) @op_base.promote_int_input def _aten_log10(x): - return jnp.log10(x) + return jnp.log10(x) # aten.log1p @op(torch.ops.aten.log1p) @op_base.promote_int_input def _aten_log1p(x): - return jnp.log1p(x) + return jnp.log1p(x) # aten.log2 @op(torch.ops.aten.log2) @op_base.promote_int_input def _aten_log2(x): - return jnp.log2(x) + return jnp.log2(x) # aten.logical_and @op(torch.ops.aten.logical_and) def _aten_logical_and(self, other): - return jnp.logical_and(self, other) + return jnp.logical_and(self, other) # aten.logical_or @op(torch.ops.aten.logical_or) def _aten_logical_or(self, other): - return jnp.logical_or(self, other) + return jnp.logical_or(self, other) # aten.logical_not @op(torch.ops.aten.logical_not) def _aten_logical_not(self): - return jnp.logical_not(self) + return jnp.logical_not(self) # aten.log_softmax @op(torch.ops.aten._log_softmax) def _aten_log_softmax(self, axis=-1, half_to_float=False): - if self.shape == (): - return jnp.astype(0.0, self.dtype) - return jax.nn.log_softmax(self, axis) + if self.shape == (): + return jnp.astype(0.0, self.dtype) + return jax.nn.log_softmax(self, axis) # aten.logaddexp @op(torch.ops.aten.logaddexp) def _aten_logaddexp(self, other): - return jnp.logaddexp(self, other) + return jnp.logaddexp(self, other) # aten.logaddexp2 @op(torch.ops.aten.logaddexp2) def _aten_logaddexp2(self, other): - return jnp.logaddexp2(self, other) + return jnp.logaddexp2(self, other) # aten.logcumsumexp @op(torch.ops.aten.logcumsumexp) def _aten_logcumsumexp(self, dim=None): - if self.shape == (): - return self - return jax.lax.cumlogsumexp(self, axis=dim) + if self.shape == (): + return self + return jax.lax.cumlogsumexp(self, axis=dim) # aten.max_pool3d_backward # aten.logical_xor @op(torch.ops.aten.logical_xor) def _aten_logical_xor(self, other): - return jnp.logical_xor(self, other) + return jnp.logical_xor(self, other) # aten.max_pool2d_with_indices_backward @@ -3007,86 +3094,90 @@ def _aten_logical_xor(self, other): # aten.neg @op(torch.ops.aten.neg) def _aten_neg(x): - return -1 * x + return -1 * x @op(torch.ops.aten.nextafter) def _aten_nextafter(input, other, *, out=None): - return jnp.nextafter(input, other) + return jnp.nextafter(input, other) @op(torch.ops.aten.nonzero_static) def _aten_nonzero_static(input, size, fill_value=-1): - indices = jnp.argwhere(input) + indices = jnp.argwhere(input) - if size < indices.shape[0]: - indices = indices[:size] - elif size > indices.shape[0]: - padding = jnp.full((size - indices.shape[0], indices.shape[1]), - fill_value, - dtype=indices.dtype) - indices = jnp.concatenate((indices, padding)) + if size < indices.shape[0]: + indices = indices[:size] + elif size > indices.shape[0]: + padding = jnp.full( + (size - indices.shape[0], indices.shape[1]), + fill_value, + dtype=indices.dtype, + ) + indices = jnp.concatenate((indices, padding)) - return indices + return indices # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x, as_tuple=False): - if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): - return torch.empty(0, 0, dtype=torch.int64) - if jnp.ndim( - x - ) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) - res = torch.empty(1, 0, dtype=torch.int64) - 
return jnp.array(res.numpy()) - index_tuple = jnp.nonzero(x) - index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] - return jnp.concatenate(index_tuple, axis=-1) + if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): + return torch.empty(0, 0, dtype=torch.int64) + if ( + jnp.ndim(x) == 0 + ): # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) + res = torch.empty(1, 0, dtype=torch.int64) + return jnp.array(res.numpy()) + index_tuple = jnp.nonzero(x) + index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] + return jnp.concatenate(index_tuple, axis=-1) # aten.prod @op(torch.ops.aten.prod) def _aten_prod(input, dim=None, keepdim=False, *, dtype=None): - if dtype: - input = input.astype(mappings.t2j_dtype(dtype)) - return _with_reduction_scalar(jnp.prod, input, dim, keepdim) + if dtype: + input = input.astype(mappings.t2j_dtype(dtype)) + return _with_reduction_scalar(jnp.prod, input, dim, keepdim) @op(torch.ops.aten.put) def _aten_put(self, index, source, accumulate=False): - expanded = False - res = None + expanded = False + res = None - if self.ndim == 0: - expanded = True - self = jnp.expand_dims(self, 0) + if self.ndim == 0: + expanded = True + self = jnp.expand_dims(self, 0) - if accumulate: - tmp = jnp.zeros(self.shape) - tmp = jnp.put(tmp, index, source, inplace=False) - res = jnp.add(self, tmp).astype(self.dtype) - else: - res = jnp.put(self, index, source, inplace=False) + if accumulate: + tmp = jnp.zeros(self.shape) + tmp = jnp.put(tmp, index, source, inplace=False) + res = jnp.add(self, tmp).astype(self.dtype) + else: + res = jnp.put(self, index, source, inplace=False) - if expanded: - res = res.squeeze() + if expanded: + res = res.squeeze() - return res + return res # aten.randperm # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.randperm, needs_env=True) -def _aten_randperm(n, - *, - generator=None, - dtype=None, - layout=None, - device=None, - pin_memory=None, - env=None): - """ +def _aten_randperm( + n, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + env=None, +): + """ Generates a random permutation of integers from 0 to n-1. Args: @@ -3100,14 +3191,14 @@ def _aten_randperm(n, Returns: A DeviceArray containing a random permutation of integers from 0 to n-1. 
""" - if dtype: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = jnp.int64.dtype - key = env.get_and_rotate_prng_key(generator) - indices = jnp.arange(n, dtype=dtype) - permutation = jax.random.permutation(key, indices) - return permutation + if dtype: + dtype = mappings.t2j_dtype(dtype) + else: + dtype = jnp.int64.dtype + key = env.get_and_rotate_prng_key(generator) + indices = jnp.arange(n, dtype=dtype) + permutation = jax.random.permutation(key, indices) + return permutation # aten.reflection_pad3d @@ -3116,13 +3207,13 @@ def _aten_randperm(n, # aten.remainder @op(torch.ops.aten.remainder) def _aten_remainder(inputs, other): - return inputs % other + return inputs % other # aten.repeat @op(torch.ops.aten.repeat) def _aten_repeat(x, reps): - return jnp.tile(x, reps) + return jnp.tile(x, reps) # aten.replication_pad2d @@ -3130,31 +3221,31 @@ def _aten_repeat(x, reps): # aten.roll @op(torch.ops.aten.roll) def _aten_roll(input, shifts, dims=None): - return jnp.roll(input, shifts, dims) + return jnp.roll(input, shifts, dims) # aten.slice_scatter @op(torch.ops.aten.slice_scatter) def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): - input_index = [] - for x in range(len(input.shape)): - if x == dim: - input_index.append(slice(start, end, step)) - else: - input_index.append(slice(None, None, None)) - return input.at[tuple(input_index)].set(src) + input_index = [] + for x in range(len(input.shape)): + if x == dim: + input_index.append(slice(start, end, step)) + else: + input_index.append(slice(None, None, None)) + return input.at[tuple(input_index)].set(src) # aten.sort # torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): - if a.shape == (): - return (a, jnp.astype(0, 'int64')) - return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), - ) + if a.shape == (): + return (a, jnp.astype(0, "int64")) + return ( + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), + ) # aten.sym_size @@ -3163,96 +3254,101 @@ def _aten_sort(a, dim=-1, descending=False, stable=False): # aten.topk @op(torch.ops.aten.topk) def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - """JAX top-k implementation using jax.lax.top_k for improved efficiency. - - Args: - input: The input JAX array. - k: The number of top elements to return. - dim: The dimension along which to find the top-k. If None, operates on the - flattened array. - largest: If True, returns the largest k elements. Otherwise, smallest k. - sorted: If True, returns the elements in sorted order. - - Returns: - A tuple (values, indices) containing: - - values: The top k values. - - indices: The indices of the top k values in the original array. - """ - if dim is None: - # last dim is chosen - dim = input.ndim - 1 - - if dim < 0: - dim = dim + input.ndim - - if not largest: - input = -input # Find top-k of negated input if we want the smallest - - if input.ndim == 0: - return input, jnp.array(0, dtype=jnp.int64.dtype) - - transpose_shape = None - if dim != -1 and dim != len(input.shape) - 1: - transpose_shape = list(range(len(input.shape))) - transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], - ) - input = jnp.transpose(input, transpose_shape) + """JAX top-k implementation using jax.lax.top_k for improved efficiency. 
- values, indices = jax.lax.top_k(input, k) + Args: + input: The input JAX array. + k: The number of top elements to return. + dim: The dimension along which to find the top-k. If None, operates on the + flattened array. + largest: If True, returns the largest k elements. Otherwise, smallest k. + sorted: If True, returns the elements in sorted order. - if sorted: - values = jnp.sort(values, descending=True) - indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) + Returns: + A tuple (values, indices) containing: + - values: The top k values. + - indices: The indices of the top k values in the original array. + """ + if dim is None: + # last dim is chosen + dim = input.ndim - 1 - if not largest: - values = -values # Negate values back if we found smallest + if dim < 0: + dim = dim + input.ndim + + if not largest: + input = -input # Find top-k of negated input if we want the smallest + + if input.ndim == 0: + return input, jnp.array(0, dtype=jnp.int64.dtype) + + transpose_shape = None + if dim != -1 and dim != len(input.shape) - 1: + transpose_shape = list(range(len(input.shape))) + transpose_shape[dim], transpose_shape[-1] = ( + transpose_shape[-1], + transpose_shape[dim], + ) + input = jnp.transpose(input, transpose_shape) - if transpose_shape is not None: - values = jnp.transpose(values, transpose_shape) - indices = jnp.transpose(indices, transpose_shape) + values, indices = jax.lax.top_k(input, k) - return values, indices + if sorted: + values = jnp.sort(values, descending=True) + indices = jnp.take_along_axis( + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) + + if not largest: + values = -values # Negate values back if we found smallest + + if transpose_shape is not None: + values = jnp.transpose(values, transpose_shape) + indices = jnp.transpose(indices, transpose_shape) + + return values, indices # aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) +# tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.tril_indices) -def _aten_tril_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): - a, b = jnp.tril_indices(row, offset, col) - return jnp.stack((a, b)) +def _aten_tril_indices( + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, +): + a, b = jnp.tril_indices(row, offset, col) + return jnp.stack((a, b)) # aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) +# tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) @op(torch.ops.aten.triu_indices) -def _aten_triu_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): - a, b = jnp.triu_indices(row, offset, col) - return jnp.stack((a, b)) +def _aten_triu_indices( + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, +): + a, b = jnp.triu_indices(row, offset, col) + return jnp.stack((a, b)) @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, dim=0): - return [ - jax.lax.index_in_dim(a, i, dim, keepdims=False) - for i in range(a.shape[dim]) - ] + return [ + jax.lax.index_in_dim(a, i, dim, keepdims=False) + for i in range(a.shape[dim]) + ] # aten.unique_dim @@ -3260,35 +3356,36 @@ def _aten_unbind(a, dim=0): # NOTE: Like the CUDA and CPU implementations, this implementation always sorts # the tensor regardless of the `sorted` argument passed to `torch.unique`. @op(torch.ops.aten.unique_dim) -def _aten_unique_dim(input_tensor, - dim, - sort=True, - return_inverse=False, - return_counts=False): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=return_counts, - axis=dim, - equal_nan=False) - result_list = ( - list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple) - else [result_tensor_or_tuple]) - - if not return_inverse: - result_list.insert(1, None) - elif _jax_version < (0, 4, 31) and dim is not None: - result_list[1] = result_list[1].flatten() - - if not return_counts: - result_list.insert(2, None) - - # [result, None, None] if return_inverse=False and return_counts=False - # [result, inverse, None] if return_inverse=True and return_counts=False - # [result, None, counts] if return_inverse=False and return_counts=True - # [result, inverse, counts] if return_inverse=True and return_counts=True - return result_list +def _aten_unique_dim( + input_tensor, dim, sort=True, return_inverse=False, return_counts=False +): + result_tensor_or_tuple = jnp.unique( + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=return_counts, + axis=dim, + equal_nan=False, + ) + result_list = ( + list(result_tensor_or_tuple) + if isinstance(result_tensor_or_tuple, tuple) + else [result_tensor_or_tuple] + ) + + if not return_inverse: + result_list.insert(1, None) + elif _jax_version < (0, 4, 31) and dim is not None: + result_list[1] = result_list[1].flatten() + + if not return_counts: + result_list.insert(2, None) + + # [result, None, None] if return_inverse=False and return_counts=False + # [result, inverse, None] if return_inverse=True and return_counts=False + # [result, None, counts] if return_inverse=False and return_counts=True + # [result, inverse, counts] if return_inverse=True and return_counts=True + return result_list # aten._unique @@ -3297,17 +3394,18 @@ def _aten_unique_dim(input_tensor, # the tensor regardless of the `sorted` argument passed to `torch.unique`. 
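# A small sketch (assuming jax is importable in this environment) of the
# jnp.unique call that backs _aten_unique_dim above: with axis=dim it
# deduplicates whole slices, and the optional inverse/counts outputs are what
# the lowering slots into its [values, inverse_or_None, counts_or_None] result.
import jax.numpy as jnp

x = jnp.array([[1, 2], [1, 2], [3, 4]])
values, inverse, counts = jnp.unique(
    x, return_inverse=True, return_counts=True, axis=0
)
print(values)   # the unique rows, always sorted
print(inverse)  # index of each input row within `values`
print(counts)   # number of occurrences of each unique row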
@op(torch.ops.aten._unique) def _aten_unique(input_tensor, sort=True, return_inverse=False): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=False, - axis=None, - equal_nan=False) - if return_inverse: - return result_tensor_or_tuple - else: - return (result_tensor_or_tuple, None) + result_tensor_or_tuple = jnp.unique( + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=False, + axis=None, + equal_nan=False, + ) + if return_inverse: + return result_tensor_or_tuple + else: + return (result_tensor_or_tuple, None) # aten._unique2 @@ -3315,79 +3413,86 @@ def _aten_unique(input_tensor, sort=True, return_inverse=False): # NOTE: Like the CUDA and CPU implementations, this implementation always sorts # the tensor regardless of the `sorted` argument passed to `torch.unique`. @op(torch.ops.aten._unique2) -def _aten_unique2(input_tensor, - sort=True, - return_inverse=False, - return_counts=False): - return _aten_unique_dim( - input_tensor=input_tensor, - dim=None, - sort=sort, - return_inverse=return_inverse, - return_counts=return_counts) +def _aten_unique2( + input_tensor, sort=True, return_inverse=False, return_counts=False +): + return _aten_unique_dim( + input_tensor=input_tensor, + dim=None, + sort=sort, + return_inverse=return_inverse, + return_counts=return_counts, + ) # aten.unique_consecutive @op(torch.ops.aten.unique_consecutive) -def _aten_unique_consecutive(input_tensor, - return_inverse=False, - return_counts=None, - dim=None): - # Explanation of computations (shown in 1D for simplicity): - # - # Input [a b b c c c d d d d e e e e e] - # Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e] - # Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e] - # Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Filter input by the resulting bool array [a b c d e ] - # Output [a b c d e] - - if dim is None: - inverse_shape = input_tensor.shape - input_tensor = input_tensor.flatten() - ndim = 1 - dim = 0 - else: - inverse_shape = input_tensor.shape[dim] - ndim = input_tensor.ndim - if dim < 0: - dim += ndim +def _aten_unique_consecutive( + input_tensor, return_inverse=False, return_counts=None, dim=None +): + # Explanation of computations (shown in 1D for simplicity): + # + # Input [a b b c c c d d d d e e e e e] + # Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e] + # Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e] + # Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0] + # Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0] + # Filter input by the resulting bool array [a b c d e ] + # Output [a b c d e] + + if dim is None: + inverse_shape = input_tensor.shape + input_tensor = input_tensor.flatten() + ndim = 1 + dim = 0 + else: + inverse_shape = input_tensor.shape[dim] + ndim = input_tensor.ndim + if dim < 0: + dim += ndim - nd_slice_0 = tuple( - slice(None, -1) if d == dim else slice(None) for d in range(ndim)) - nd_slice_1 = tuple( - slice(1, None) if d == dim else slice(None) for d in range(ndim)) + nd_slice_0 = tuple( + slice(None, -1) if d == dim else slice(None) for d in range(ndim) + ) + nd_slice_1 = tuple( + slice(1, None) if d == dim else slice(None) for d in range(ndim) + ) - axes_to_reduce = tuple(d for d in range(ndim) if d != dim) + axes_to_reduce = tuple(d for 
d in range(ndim) if d != dim) - does_not_equal_prior = ( - jnp.any( - input_tensor[nd_slice_0] != input_tensor[nd_slice_1], - axis=axes_to_reduce, - keepdims=False)) + does_not_equal_prior = jnp.any( + input_tensor[nd_slice_0] != input_tensor[nd_slice_1], + axis=axes_to_reduce, + keepdims=False, + ) - if input_tensor.shape[dim] != 0: - # Prepend `True` to represent the first element of the input. - does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True) + if input_tensor.shape[dim] != 0: + # Prepend `True` to represent the first element of the input. + does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True) - include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] + include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] - output_tensor = input_tensor[tuple( - include_indices if d == dim else slice(None) for d in range(ndim))] + output_tensor = input_tensor[ + tuple(include_indices if d == dim else slice(None) for d in range(ndim)) + ] - if return_inverse or return_counts: - counts = ( - jnp.append(include_indices[1:], input_tensor.shape[dim]) - - include_indices[:]) + if return_inverse or return_counts: + counts = ( + jnp.append(include_indices[1:], input_tensor.shape[dim]) + - include_indices[:] + ) - inverse = ( - jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape) - if return_inverse else None) + inverse = ( + jnp.reshape( + jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape + ) + if return_inverse + else None + ) - return output_tensor, inverse, counts + return output_tensor, inverse, counts - return output_tensor, None, None + return output_tensor, None, None # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d @@ -3402,38 +3507,39 @@ def _aten_unique_consecutive(input_tensor, @op(torch.ops.aten.where.ScalarOther) @op(torch.ops.aten.where.Scalar) def _aten_where(condition, x=None, y=None): - return jnp.where(condition, x, y) + return jnp.where(condition, x, y) # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): - if dtype: - jaxdtype = mappings.t2j_dtype(dtype) - return a.astype(jaxdtype) +def _aten_to_dtype( + a, dtype, non_blocking=False, copy=False, memory_format=None +): + if dtype: + jaxdtype = mappings.t2j_dtype(dtype) + return a.astype(jaxdtype) @op(torch.ops.aten.to.dtype_layout) -def _aten_to_dtype_layout(a, - *, - dtype=None, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None): - return _aten_to_dtype( - a, - dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format) +def _aten_to_dtype_layout( + a, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, +): + return _aten_to_dtype( + a, + dtype, + non_blocking=non_blocking, + copy=copy, + memory_format=memory_format, + ) # aten.to.device @@ -3442,87 +3548,97 @@ def _aten_to_dtype_layout(a, # Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False @op(torch.ops.aten.var_mean.correction) def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False): - # The internal API technically has a default `correction` argument of `None`, - # but the public API has a default argument of 1. Therefore, we simply set our - # default argument to 1. 
However, since the argument is officially supposed to - # be nullable, we still need to check for `None` per the API contract. - if correction is None: - correction = 1 - mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) - # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. - var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) - return var, mean + # The internal API technically has a default `correction` argument of `None`, + # but the public API has a default argument of 1. Therefore, we simply set our + # default argument to 1. However, since the argument is officially supposed to + # be nullable, we still need to check for `None` per the API contract. + if correction is None: + correction = 1 + mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) + # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. + var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) + return var, mean @op(torch.ops.aten.scalar_tensor) @op_base.convert_dtype() -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): - return jnp.array(s, dtype=dtype) +def _aten_scalar_tensor( + s, dtype=None, layout=None, device=None, pin_memory=None +): + return jnp.array(s, dtype=dtype) @op(torch.ops.aten.to.device) def _aten_to_device(x, device, dtype): - return x + return x @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, - stride, padding, dilation, - ceil_mode, indices): - """ - Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. - - Args: - grad_output: The gradient tensor from the preceding layer. - self: The input tensor on which the original max pooling was performed. - kernel_size: The size of the pooling window. - stride: The stride of the pooling window. - padding: The padding applied during max pooling. - dilation: The dilation factor for the pooling operation. - ceil_mode: Whether to use ceil or floor when calculating output shapes. - indices: The indices of the maximum values, as produced by max_pool2d_with_indices. - - Returns: - The calculated gradient with respect to the input (grad_input). - """ - - kH, kW = kernel_size - dH, dW = stride - padH, padW = padding - dilH, dilW = dilation - - # Calculate output shape (may need adjustment based on ceil_mode) - out_shape = jnp.array(self.shape) - grad_input = jnp.zeros_like(self) - - # Iterate over the flattened input and output tensors - for i, idx in enumerate(indices.flatten()): - # Calculate input coordinates corresponding to the maximum value - out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] - in_y = out_y * dH - padH + out_y * (dilH - 1) - in_x = out_x * dW - padW + out_x * (dilW - 1) - - # Scatter the gradient to the appropriate input locations (handling potential overlaps) - for y in range(in_y, in_y + kH): - for x in range(in_x, in_x + kW): - if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: - grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) - - return grad_input +def max_pool2d_with_indices_backward_custom( + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. 
+ stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if ( + 0 <= y < grad_input.shape[2] + and 0 <= x < grad_input.shape[3] + ): + grad_input = grad_input.at[y, x].add( + grad_output.flatten()[i] + ) + + return grad_input @op(torch.ops.aten._local_scalar_dense) def _aten_local_scalar_dense(x): - return x.item() + return x.item() @op(torch.ops.aten.tensor_split.sections) def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) + return jnp.array_split(ary, indices_or_sections, axis) @op(torch.ops.aten.randn, needs_env=True) @@ -3538,14 +3654,14 @@ def _randn( pin_memory=False, env=None, ): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.normal(key, shape) - if dtype is not None: - res = res.astype(dtype) - return res + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key(generator) + res = jax.random.normal(key, shape) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.bernoulli.p, needs_env=True) @@ -3556,16 +3672,16 @@ def _aten_bernoulli( generator=None, env=None, ): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, self.shape) < p - return res + key = env.get_and_rotate_prng_key(generator) + res = jax.random.uniform(key, self.shape) < p + return res @op(torch.ops.aten.geometric, needs_env=True) def geometric(self, p, *, generator=None, env=None): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.geometric(key, p, self.shape) - return res + key = env.get_and_rotate_prng_key(generator) + res = jax.random.geometric(key, p, self.shape) + return res @op(torch.ops.aten.randn_like, needs_env=True) @@ -3580,8 +3696,8 @@ def _aten_randn_like( memory_format=torch.preserve_format, env=None, ): - key = env.get_and_rotate_prng_key() - return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape) + key = env.get_and_rotate_prng_key() + return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape) @op(torch.ops.aten.rand, needs_env=True) @@ -3597,76 +3713,88 @@ def _rand( pin_memory=False, env=None, ): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, shape) - if dtype is not None: - res = 
res.astype(dtype) - return res + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key(generator) + res = jax.random.uniform(key, shape) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.outer) def _aten_outer(a, b): - return jnp.outer(a, b) + return jnp.outer(a, b) @op(torch.ops.aten.allclose) def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) + return jnp.allclose(input, other, rtol, atol, equal_nan) @op(torch.ops.aten.native_batch_norm) -def _aten_native_batch_norm(input, - weight, - bias, - running_mean, - running_var, - training=False, - momentum=0.1, - eps=1e-5): - - if running_mean is None: - running_mean = jnp.zeros( - input.shape[1], dtype=input.dtype) # Initialize running mean if None - if running_var is None: - running_var = jnp.ones( - input.shape[1], - dtype=input.dtype) # Initialize running variance if None - - if training: - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps) - else: - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) +def _aten_native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=1e-5, +): + if running_mean is None: + running_mean = jnp.zeros( + input.shape[1], dtype=input.dtype + ) # Initialize running mean if None + if running_var is None: + running_var = jnp.ones( + input.shape[1], dtype=input.dtype + ) # Initialize running variance if None + + if training: + return _aten__native_batch_norm_legit( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + ) + else: + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) @op(torch.ops.aten.normal, needs_env=True) def _aten_normal(self, mean=0, std=1, generator=None, env=None): - shape = self.shape - res = _randn(*shape, generator=generator, env=env) - return res * std + mean + shape = self.shape + res = _randn(*shape, generator=generator, env=env) + return res * std + mean # TODO: not clear what this function should actually do # https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940 @op(torch.ops.aten.lift_fresh) def _aten_lift_fresh(self): - return self + return self @op(torch.ops.aten.uniform, needs_env=True) def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): - assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})' - shape = self.shape - res = _rand(*shape, generator=generator, env=env) - return res * (to - from_) + from_ + assert from_ <= to, ( + f"Uniform from(passed in {from_}) must be less than to(passed in {to})" + ) + shape = self.shape + res = _rand(*shape, generator=generator, env=env) + return res * (to - from_) + from_ -#func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +# func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor @op(torch.ops.aten.randint, needs_env=True) @@ -3678,26 +3806,29 @@ def _aten_randint( env=None, **kwargs, ): - if len(args) == 3: - # low, high, size - low, high, size = args - elif len(args) == 2: - high, size = args - low = 0 - else: - raise AssertionError( - f'Expected at 2 or 3 args for Aten::randint, got {len(args)}') - - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, size, low, high) - if dtype is not None: - res = res.astype(dtype) - return res - - -@op(torch.ops.aten.randint_like, + if len(args) == 3: + # low, high, size + low, high, size = args + elif len(args) == 2: + high, size = args + low = 0 + else: + raise AssertionError( + f"Expected at 2 or 3 args for Aten::randint, got {len(args)}" + ) + + key = env.get_and_rotate_prng_key(generator) + res = jax.random.randint(key, size, low, high) + if dtype is not None: + res = res.astype(dtype) + return res + + +@op( + torch.ops.aten.randint_like, torch.ops.aten.randint.generator, - needs_env=True) + needs_env=True, +) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint_like( input, @@ -3707,1863 +3838,1962 @@ def _aten_randint_like( env=None, **kwargs, ): - if len(args) == 2: - low, high = args - elif len(args) == 1: - high = args[0] - low = 0 - else: - raise AssertionError( - f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}') - - shape = input.shape - dtype = dtype or input.dtype - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, shape, low, high) - if dtype is not None: - res = res.astype(dtype) - return res + if len(args) == 2: + low, high = args + elif len(args) == 1: + high = args[0] + low = 0 + else: + raise AssertionError( + f"Expected at 1 or 2 args for Aten::randint_like, got {len(args)}" + ) + + shape = input.shape + dtype = dtype or input.dtype + key = env.get_and_rotate_prng_key(generator) + res = jax.random.randint(key, shape, low, high) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.dim, is_jax_function=False) def _aten_dim(self): - return len(self.shape) + return len(self.shape) @op(torch.ops.aten.copysign) def _aten_copysign(input, other, *, out=None): - result = jnp.copysign(input, other) - # torch.copysign(x, y) returns float32 for integer x and y, - # regardless of their exact integer dtype, whereas jax.copysign returns - # float64 when one or both of them is int64. - if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( - other.dtype, jnp.integer): - result = result.astype(jnp.float32) - return result + result = jnp.copysign(input, other) + # torch.copysign(x, y) returns float32 for integer x and y, + # regardless of their exact integer dtype, whereas jax.copysign returns + # float64 when one or both of them is int64. 
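+    # Concretely, per the note above: torch.copysign(torch.tensor([2, 3]),
+    # torch.tensor([-1, 1])) produces a torch.float32 result, while
+    # jnp.copysign on two int64 arrays (with jax_enable_x64 enabled) would
+    # produce float64; the explicit float32 cast below restores the PyTorch
+    # behaviour.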
+ if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( + other.dtype, jnp.integer + ): + result = result.astype(jnp.float32) + return result @op(torch.ops.aten.i0) @op_base.promote_int_input def _aten_i0(self): - return jax.scipy.special.i0(self) + return jax.scipy.special.i0(self) @op(torch.ops.aten.special_i0e) @op_base.promote_int_input def _aten_i0e(self): - return jax.scipy.special.i0e(self) + return jax.scipy.special.i0e(self) @op(torch.ops.aten.special_i1) @op_base.promote_int_input def _aten_special_i1(self): - return jax.scipy.special.i1(self) + return jax.scipy.special.i1(self) @op(torch.ops.aten.special_i1e) @op_base.promote_int_input def _aten_special_i1e(self): - return jax.scipy.special.i1e(self) + return jax.scipy.special.i1e(self) @op(torch.ops.aten.special_laguerre_polynomial_l) @op_base.promote_int_input def _aten_special_laguerre_polynomial_l(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3106-L3134 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3106-L3134 - @jnp.vectorize - def vectorized(x, n_i): + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def negative_n(x): - return jnp.zeros_like(x) + def zero_n(x): + return jnp.ones_like(x) - def zero_n(x): - return jnp.ones_like(x) + def one_n(x): + return jnp.ones_like(x) - x - def one_n(x): - return jnp.ones_like(x) - x + def zero_abs(x): + return jnp.ones_like(x) - def zero_abs(x): - return jnp.ones_like(x) + def default(x): + def f(k, carry): + p, q = carry + return ( + q, + ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1), + ) - def default(x): - - def f(k, carry): - p, q = carry - return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1)) + _, q = jax.lax.fori_loop( + 1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x) + ) + return q - _, q = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x)) - return q + return jnp.piecewise( + x, + [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], + [one_n, zero_n, zero_abs, negative_n, default], + ) - return jnp.piecewise( - x, [n_i == 1, n_i == 0, - jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], - [one_n, zero_n, zero_abs, negative_n, default]) - - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_modified_bessel_i0) @op_base.promote_int_input def _aten_special_modified_bessel_i0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3182-L3268 - - def small(x): - A = jnp.array( - [ - -4.41534164647933937950e-18, - 3.33079451882223809783e-17, - -2.43127984654795469359e-16, - 1.71539128555513303061e-15, - -1.16853328779934516808e-14, - 7.67618549860493561688e-14, - -4.85644678311192946090e-13, - 2.95505266312963983461e-12, - -1.72682629144155570723e-11, - 9.67580903537323691224e-11, - -5.18979560163526290666e-10, - 2.65982372468238665035e-09, - -1.30002500998624804212e-08, - 6.04699502254191894932e-08, - -2.67079385394061173391e-07, - 1.11738753912010371815e-06, - -4.41673835845875056359e-06, - 1.64484480707288970893e-05, - -5.75419501008210370398e-05, - 1.88502885095841655729e-04, - -5.76375574538582365885e-04, - 1.63947561694133579842e-03, - -4.32430999505057594430e-03, - 1.05464603945949983183e-02, - -2.37374148058994688156e-02, - 4.93052842396707084878e-02, - 
-9.49010970480476444210e-02, - 1.71620901522208775349e-01, - -3.04682672343198398683e-01, - 6.76795274409476084995e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((x / 2.0) - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jnp.exp(x) * (0.5 * (a - p)) - - def default(x): - B = jnp.array( - [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - 4.46562142029675999901e-17, - 3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - 1.77256013305652638360e-15, - 3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - 1.54008621752140982691e-14, - 3.85277838274214270114e-13, - 7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - 1.18891471078464383424e-11, - 4.94060238822496958910e-10, - 3.39623202570838634515e-09, - 2.26666899049817806459e-08, - 2.04891858946906374183e-07, - 2.89137052083475648297e-06, - 6.88975834691682398426e-05, - 3.36911647825569408990e-03, - 8.04490411014108831608e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / x - 2.0) * q - p + val), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3182-L3268 + + def small(x): + A = jnp.array( + [ + -4.41534164647933937950e-18, + 3.33079451882223809783e-17, + -2.43127984654795469359e-16, + 1.71539128555513303061e-15, + -1.16853328779934516808e-14, + 7.67618549860493561688e-14, + -4.85644678311192946090e-13, + 2.95505266312963983461e-12, + -1.72682629144155570723e-11, + 9.67580903537323691224e-11, + -5.18979560163526290666e-10, + 2.65982372468238665035e-09, + -1.30002500998624804212e-08, + 6.04699502254191894932e-08, + -2.67079385394061173391e-07, + 1.11738753912010371815e-06, + -4.41673835845875056359e-06, + 1.64484480707288970893e-05, + -5.75419501008210370398e-05, + 1.88502885095841655729e-04, + -5.76375574538582365885e-04, + 1.63947561694133579842e-03, + -4.32430999505057594430e-03, + 1.05464603945949983183e-02, + -2.37374148058994688156e-02, + 4.93052842396707084878e-02, + -9.49010970480476444210e-02, + 1.71620901522208775349e-01, + -3.04682672343198398683e-01, + 6.76795274409476084995e-01, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, ((x / 2.0) - 2.0) * q - p + val), None + + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) + + return jnp.exp(x) * (0.5 * (a - p)) - return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) - - self = jnp.abs(self) - return jnp.piecewise(self, [self <= 8], [small, default]) + def default(x): + B = jnp.array( + [ + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + 4.46562142029675999901e-17, + 3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + 1.77256013305652638360e-15, + 3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + 1.54008621752140982691e-14, + 3.85277838274214270114e-13, + 7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + 1.18891471078464383424e-11, + 4.94060238822496958910e-10, + 3.39623202570838634515e-09, + 
2.26666899049817806459e-08, + 2.04891858946906374183e-07, + 2.89137052083475648297e-06, + 6.88975834691682398426e-05, + 3.36911647825569408990e-03, + 8.04490411014108831608e-01, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (32.0 / x - 2.0) * q - p + val), None + + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) + + return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) + + self = jnp.abs(self) + return jnp.piecewise(self, [self <= 8], [small, default]) @op(torch.ops.aten.special_modified_bessel_i1) @op_base.promote_int_input def _aten_special_modified_bessel_i1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364 - - def small(x): - A = jnp.array( - [ - 2.77791411276104639959e-18, - -2.11142121435816608115e-17, - 1.55363195773620046921e-16, - -1.10559694773538630805e-15, - 7.60068429473540693410e-15, - -5.04218550472791168711e-14, - 3.22379336594557470981e-13, - -1.98397439776494371520e-12, - 1.17361862988909016308e-11, - -6.66348972350202774223e-11, - 3.62559028155211703701e-10, - -1.88724975172282928790e-09, - 9.38153738649577178388e-09, - -4.44505912879632808065e-08, - 2.00329475355213526229e-07, - -8.56872026469545474066e-07, - 3.47025130813767847674e-06, - -1.32731636560394358279e-05, - 4.78156510755005422638e-05, - -1.61760815825896745588e-04, - 5.12285956168575772895e-04, - -1.51357245063125314899e-03, - 4.15642294431288815669e-03, - -1.05640848946261981558e-02, - 2.47264490306265168283e-02, - -5.29459812080949914269e-02, - 1.02643658689847095384e-01, - -1.76416518357834055153e-01, - 2.52587186443633654823e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jax.lax.cond( - x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), - lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))) - - def default(x): - B = jnp.array( - [ - 7.51729631084210481353e-18, - 4.41434832307170791151e-18, - -4.65030536848935832153e-17, - -3.20952592199342395980e-17, - 2.96262899764595013876e-16, - 3.30820231092092828324e-16, - -1.88035477551078244854e-15, - -3.81440307243700780478e-15, - 1.04202769841288027642e-14, - 4.27244001671195135429e-14, - -2.10154184277266431302e-14, - -4.08355111109219731823e-13, - -7.19855177624590851209e-13, - 2.03562854414708950722e-12, - 1.41258074366137813316e-11, - 3.25260358301548823856e-11, - -1.89749581235054123450e-11, - -5.58974346219658380687e-10, - -3.83538038596423702205e-09, - -2.63146884688951950684e-08, - -2.51223623787020892529e-07, - -3.88256480887769039346e-06, - -1.10588938762623716291e-04, - -9.76109749136146840777e-03, - 7.78576235018280120474e-01, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364 + + def small(x): + A = jnp.array( + [ + 2.77791411276104639959e-18, + -2.11142121435816608115e-17, + 1.55363195773620046921e-16, + -1.10559694773538630805e-15, + 7.60068429473540693410e-15, + -5.04218550472791168711e-14, + 3.22379336594557470981e-13, + -1.98397439776494371520e-12, + 1.17361862988909016308e-11, + -6.66348972350202774223e-11, + 3.62559028155211703701e-10, + -1.88724975172282928790e-09, + 9.38153738649577178388e-09, + 
-4.44505912879632808065e-08, + 2.00329475355213526229e-07, + -8.56872026469545474066e-07, + 3.47025130813767847674e-06, + -1.32731636560394358279e-05, + 4.78156510755005422638e-05, + -1.61760815825896745588e-04, + 5.12285956168575772895e-04, + -1.51357245063125314899e-03, + 4.15642294431288815669e-03, + -1.05640848946261981558e-02, + 2.47264490306265168283e-02, + -5.29459812080949914269e-02, + 1.02643658689847095384e-01, + -1.76416518357834055153e-01, + 2.52587186443633654823e-01, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None + + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) + + return jax.lax.cond( + x < 0, + lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), + lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)), + ) - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jax.lax.cond( - x < 0, lambda: -(jnp.exp(jnp.abs(x)) * - (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), - lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))) - - return jnp.piecewise(self, [self <= 8], [small, default]) + def default(x): + B = jnp.array( + [ + 7.51729631084210481353e-18, + 4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + 2.96262899764595013876e-16, + 3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + 1.04202769841288027642e-14, + 4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + 2.03562854414708950722e-12, + 1.41258074366137813316e-11, + 3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + 7.78576235018280120474e-01, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None + + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) + + return jax.lax.cond( + x < 0, + lambda: -( + jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)) + ), + lambda: jnp.exp(jnp.abs(x)) + * (0.5 * (b - p)) + / jnp.sqrt(jnp.abs(x)), + ) + + return jnp.piecewise(self, [self <= 8], [small, default]) @op(torch.ops.aten.special_modified_bessel_k0) @op_base.promote_int_input def _aten_special_modified_bessel_k0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - 1.37446543561352307156e-16, - 4.25981614279661018399e-14, - 1.03496952576338420167e-11, - 1.90451637722020886025e-09, - 2.53479107902614945675e-07, - 2.28621210311945178607e-05, - 1.26461541144692592338e-03, - 3.59799365153615016266e-02, - 3.44289899924628486886e-01, - -5.35327393233902768720e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, (x * x - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), 
jnp.zeros_like(x), 0), xs=A) - - return 0.5 * (a - p) - jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i0(x) - - def default(x): - B = jnp.array( - [ - 5.30043377268626276149e-18, - -1.64758043015242134646e-17, - 5.21039150503902756861e-17, - -1.67823109680541210385e-16, - 5.51205597852431940784e-16, - -1.84859337734377901440e-15, - 6.34007647740507060557e-15, - -2.22751332699166985548e-14, - 8.03289077536357521100e-14, - -2.98009692317273043925e-13, - 1.14034058820847496303e-12, - -4.51459788337394416547e-12, - 1.85594911495471785253e-11, - -7.95748924447710747776e-11, - 3.57739728140030116597e-10, - -1.69753450938905987466e-09, - 8.57403401741422608519e-09, - -4.66048989768794782956e-08, - 2.76681363944501510342e-07, - -1.83175552271911948767e-06, - 1.39498137188764993662e-05, - -1.28495495816278026384e-04, - 1.56988388573005337491e-03, - -3.14481013119645005427e-02, - 2.44030308206595545468e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (8.0 / x - 2.0) * q - p + val), None + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441 + + def zero(x): + return jnp.array(jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + A = jnp.array( + [ + 1.37446543561352307156e-16, + 4.25981614279661018399e-14, + 1.03496952576338420167e-11, + 1.90451637722020886025e-09, + 2.53479107902614945675e-07, + 2.28621210311945178607e-05, + 1.26461541144692592338e-03, + 3.59799365153615016266e-02, + 3.44289899924628486886e-01, + -5.35327393233902768720e-01, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, (x * x - 2.0) * q - p + val), None + + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) + + return 0.5 * (a - p) - jnp.log( + 0.5 * x + ) * _aten_special_modified_bessel_i0(x) - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) + def default(x): + B = jnp.array( + [ + 5.30043377268626276149e-18, + -1.64758043015242134646e-17, + 5.21039150503902756861e-17, + -1.67823109680541210385e-16, + 5.51205597852431940784e-16, + -1.84859337734377901440e-15, + 6.34007647740507060557e-15, + -2.22751332699166985548e-14, + 8.03289077536357521100e-14, + -2.98009692317273043925e-13, + 1.14034058820847496303e-12, + -4.51459788337394416547e-12, + 1.85594911495471785253e-11, + -7.95748924447710747776e-11, + 3.57739728140030116597e-10, + -1.69753450938905987466e-09, + 8.57403401741422608519e-09, + -4.66048989768794782956e-08, + 2.76681363944501510342e-07, + -1.83175552271911948767e-06, + 1.39498137188764993662e-05, + -1.28495495816278026384e-04, + 1.56988388573005337491e-03, + -3.14481013119645005427e-02, + 2.44030308206595545468e00, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (8.0 / x - 2.0) * q - p + val), None + + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) + + return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.special_modified_bessel_k1) @op_base.promote_int_input def _aten_special_modified_bessel_k1(self): - # Adapted from 
https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - -7.02386347938628759343e-18, - -2.42744985051936593393e-15, - -6.66690169419932900609e-13, - -1.41148839263352776110e-10, - -2.21338763073472585583e-08, - -2.43340614156596823496e-06, - -1.73028895751305206302e-04, - -6.97572385963986435018e-03, - -1.22611180822657148235e-01, - -3.53155960776544875667e-01, - 1.52530022733894777053e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - a = (x * x - 2.0) * q - p + val - return (p, q, a), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x - - def default(x): - B = jnp.array( - [ - -5.75674448366501715755e-18, - 1.79405087314755922667e-17, - -5.68946255844285935196e-17, - 1.83809354436663880070e-16, - -6.05704724837331885336e-16, - 2.03870316562433424052e-15, - -7.01983709041831346144e-15, - 2.47715442448130437068e-14, - -8.97670518232499435011e-14, - +3.34841966607842919884e-13, - -1.28917396095102890680e-12, - 5.13963967348173025100e-12, - -2.12996783842756842877e-11, - 9.21831518760500529508e-11, - -4.19035475934189648750e-10, - 2.01504975519703286596e-09, - -1.03457624656780970260e-08, - 5.74108412545004946722e-08, - -3.50196060308781257119e-07, - 2.40648494783721712015e-06, - -1.93619797416608296024e-05, - 1.95215518471351631108e-04, - -2.85781685962277938680e-03, - 1.03923736576817238437e-01, - 2.72062619048444266945e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - b = (8.0 / x - 2.0) * q - p + val - return (p, q, b), None + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519 + + def zero(x): + return jnp.array(jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + A = jnp.array( + [ + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + 1.52530022733894777053e00, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, a = carry + p, q = q, a + a = (x * x - 2.0) * q - p + val + return (p, q, a), None + + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) + + return ( + jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + + 0.5 * (a - p) / x + ) - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) + def default(x): + B = jnp.array( + [ + -5.75674448366501715755e-18, + 1.79405087314755922667e-17, + -5.68946255844285935196e-17, + 1.83809354436663880070e-16, + -6.05704724837331885336e-16, + 2.03870316562433424052e-15, + -7.01983709041831346144e-15, + 2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + 5.13963967348173025100e-12, + -2.12996783842756842877e-11, + 9.21831518760500529508e-11, + -4.19035475934189648750e-10, + 2.01504975519703286596e-09, + -1.03457624656780970260e-08, + 
5.74108412545004946722e-08, + -3.50196060308781257119e-07, + 2.40648494783721712015e-06, + -1.93619797416608296024e-05, + 1.95215518471351631108e-04, + -2.85781685962277938680e-03, + 1.03923736576817238437e-01, + 2.72062619048444266945e00, + ], + dtype=self.dtype, + ) + + def f(carry, val): + p, q, b = carry + p, q = q, b + b = (8.0 / x - 2.0) * q - p + val + return (p, q, b), None + + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) + + return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.polygamma) def _aten_polygamma(x, n): - if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - n = n.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.lax.polygamma(jnp.float32(x), n) + if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: + n = n.astype(mappings.t2j_dtype(torch.get_default_dtype())) + return jax.lax.polygamma(jnp.float32(x), n) @op(torch.ops.aten.special_ndtri) @op_base.promote_int_input def _aten_special_ndtri(self): - return jax.scipy.special.ndtri(self) + return jax.scipy.special.ndtri(self) @op(torch.ops.aten.special_bessel_j0) @op_base.promote_int_input def _aten_special_bessel_j0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2379-L2489 - - def very_small(x): - return 1.0 - x * x / 4.0 - - def small(x): - RP = jnp.array( - [ - -4.79443220978201773821e09, - 1.95617491946556577543e12, - -2.49248344360967716204e14, - 9.70862251047306323952e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 4.99563147152651017219e02, - 1.73785401676374683123e05, - 4.84409658339962045305e07, - 1.11855537045356834862e10, - 2.11277520115489217587e12, - 3.10518229857422583814e14, - 3.18121955943204943306e16, - 1.71086294081043136091e18, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2379-L2489 + + def very_small(x): + return 1.0 - x * x / 4.0 + + def small(x): + RP = jnp.array( + [ + -4.79443220978201773821e09, + 1.95617491946556577543e12, + -2.49248344360967716204e14, + 9.70862251047306323952e15, + ], + dtype=self.dtype, + ) + RQ = jnp.array( + [ + 4.99563147152651017219e02, + 1.73785401676374683123e05, + 4.84409658339962045305e07, + 1.11855537045356834862e10, + 2.11277520115489217587e12, + 3.10518229857422583814e14, + 3.18121955943204943306e16, + 1.71086294081043136091e18, + ], + dtype=self.dtype, + ) + + rp = op_base.foreach_loop( + RP, lambda carry, rp_i: carry * (x * x) + rp_i + ) + rq = op_base.foreach_loop( + RQ, lambda carry, rq_i: carry * (x * x) + rq_i + ) + + return ( + (x * x - 5.78318596294678452118e00) + * (x * x - 3.04712623436620863991e01) + * rp + / rq + ) - rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) - rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - - return ((x * x - 5.78318596294678452118e00) * - (x * x - 3.04712623436620863991e01) * rp / rq) - - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 
9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, + def default(x): + PP = jnp.array( + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, + ) + + pp = op_base.foreach_loop( + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) + pq = op_base.foreach_loop( + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) + qp = op_base.foreach_loop( + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) + qq = op_base.foreach_loop( + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) + + return ( + ( + pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) + + self = jnp.abs(self) + # Last True condition in `piecewise` takes priority, but last function is + # default. See https://github.com/numpy/numpy/issues/16475 + return jnp.piecewise( + self, [self <= 5.0, self < 0.00001], [small, very_small, default] ) - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) - - return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - self = jnp.abs(self) - # Last True condition in `piecewise` takes priority, but last function is - # default. 
See https://github.com/numpy/numpy/issues/16475 - return jnp.piecewise(self, [self <= 5.0, self < 0.00001], - [small, very_small, default]) - @op(torch.ops.aten.special_bessel_j1) @op_base.promote_int_input def _aten_special_bessel_j1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2491-L2597 - - def small(x): - RP = jnp.array( - [ - -8.99971225705559398224e08, - 4.52228297998194034323e11, - -7.27494245221818276015e13, - 3.68295732863852883286e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 6.20836478118054335476e02, - 2.56987256757748830383e05, - 8.35146791431949253037e07, - 2.21511595479792499675e10, - 4.74914122079991414898e12, - 7.84369607876235854894e14, - 8.95222336184627338078e16, - 5.32278620332680085395e18, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2491-L2597 + + def small(x): + RP = jnp.array( + [ + -8.99971225705559398224e08, + 4.52228297998194034323e11, + -7.27494245221818276015e13, + 3.68295732863852883286e15, + ], + dtype=self.dtype, + ) + RQ = jnp.array( + [ + 6.20836478118054335476e02, + 2.56987256757748830383e05, + 8.35146791431949253037e07, + 2.21511595479792499675e10, + 4.74914122079991414898e12, + 7.84369607876235854894e14, + 8.95222336184627338078e16, + 5.32278620332680085395e18, + ], + dtype=self.dtype, + ) + + rp = op_base.foreach_loop( + RP, lambda carry, rp_i: carry * (x * x) + rp_i + ) + rq = op_base.foreach_loop( + RQ, lambda carry, rq_i: carry * (x * x) + rq_i + ) + + return ( + rp + / rq + * x + * (x * x - 1.46819706421238932572e01) + * (x * x - 4.92184563216946036703e01) + ) - rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) - rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - - return (rp / rq * x * (x * x - 1.46819706421238932572e01) * - (x * x - 4.92184563216946036703e01)) - - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, + def default(x): + PP = jnp.array( + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 
9.99999999999999997461e-01, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], + dtype=self.dtype, + ) + + pp = op_base.foreach_loop( + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) + pq = op_base.foreach_loop( + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) + qp = op_base.foreach_loop( + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) + qq = op_base.foreach_loop( + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) + + return ( + ( + pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) + + # If x < 0, bessel_j1(x) = -bessel_j1(-x) + sign = jnp.sign(self) + self = jnp.abs(self) + return sign * jnp.piecewise( + self, + [self <= 5.0], + [small, default], ) - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) - - return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - # If x < 0, bessel_j1(x) = -bessel_j1(-x) - sign = jnp.sign(self) - self = jnp.abs(self) - return sign * jnp.piecewise( - self, - [self <= 5.0], - [small, default], - ) - @op(torch.ops.aten.special_bessel_y0) @op_base.promote_int_input def _aten_special_bessel_y0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2599-L2712 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.55924367855235737965e04, - -1.46639295903971606143e07, - 5.43526477051876500413e09, - -9.82136065717911466409e11, - 8.75906394395366999549e13, - -3.46628303384729719441e15, - 4.42733268572569800351e16, - -1.84950800436986690637e16, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 1.04128353664259848412e03, - 6.26107330137134956842e05, - 2.68919633393814121987e08, - 8.64002487103935000337e10, - 2.02979612750105546709e13, - 3.17157752842975028269e15, - 2.50596256172653059228e17, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2599-L2712 + + def zero(x): + return jnp.array(-jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + YP = jnp.array( + [ + 1.55924367855235737965e04, + -1.46639295903971606143e07, + 5.43526477051876500413e09, + -9.82136065717911466409e11, + 8.75906394395366999549e13, + -3.46628303384729719441e15, + 4.42733268572569800351e16, + -1.84950800436986690637e16, + ], + 
dtype=self.dtype, + ) + YQ = jnp.array( + [ + 1.04128353664259848412e03, + 6.26107330137134956842e05, + 2.68919633393814121987e08, + 8.64002487103935000337e10, + 2.02979612750105546709e13, + 3.17157752842975028269e15, + 2.50596256172653059228e17, + ], + dtype=self.dtype, + ) + + yp = op_base.foreach_loop( + YP, lambda carry, yp_i: carry * (x * x) + yp_i + ) + yq = op_base.foreach_loop( + YQ, lambda carry, yq_i: carry * (x * x) + yq_i + ) + + return yp / yq + ( + 0.636619772367581343075535053490057448 + * jnp.log(x) + * _aten_special_bessel_j0(x) + ) - yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) - yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - - return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) * - _aten_special_bessel_j0(x)) - - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, - ) - - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - - return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + def default(x): + PP = jnp.array( + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, + ) + + factor = 25.0 / (x * x) + pp = 
op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) + pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) + qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) + qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) + + return ( + ( + pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) - return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], - ) + return jnp.piecewise( + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], + ) @op(torch.ops.aten.special_bessel_y1) @op_base.promote_int_input def _aten_special_bessel_y1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2714-L2826 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.26320474790178026440e09, - -6.47355876379160291031e11, - 1.14509511541823727583e14, - -8.12770255501325109621e15, - 2.02439475713594898196e17, - -7.78877196265950026825e17, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 5.94301592346128195359e02, - 2.35564092943068577943e05, - 7.34811944459721705660e07, - 1.87601316108706159478e10, - 3.88231277496238566008e12, - 6.20557727146953693363e14, - 6.87141087355300489866e16, - 3.97270608116560655612e18, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2714-L2826 + + def zero(x): + return jnp.array(-jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + YP = jnp.array( + [ + 1.26320474790178026440e09, + -6.47355876379160291031e11, + 1.14509511541823727583e14, + -8.12770255501325109621e15, + 2.02439475713594898196e17, + -7.78877196265950026825e17, + ], + dtype=self.dtype, + ) + YQ = jnp.array( + [ + 5.94301592346128195359e02, + 2.35564092943068577943e05, + 7.34811944459721705660e07, + 1.87601316108706159478e10, + 3.88231277496238566008e12, + 6.20557727146953693363e14, + 6.87141087355300489866e16, + 3.97270608116560655612e18, + ], + dtype=self.dtype, + ) + + yp = op_base.foreach_loop( + YP, lambda carry, yp_i: carry * (x * x) + yp_i + ) + yq = op_base.foreach_loop( + YQ, lambda carry, yq_i: carry * (x * x) + yq_i + ) + + return x * (yp / yq) + ( + 0.636619772367581343075535053490057448 + * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x) + ) - yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) - yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - - return (x * (yp / yq) + - (0.636619772367581343075535053490057448 * - (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x))) - - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 
5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, - ) - - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - - return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + def default(x): + PP = jnp.array( + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 9.99999999999999997461e-01, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], + dtype=self.dtype, + ) + + factor = 25.0 / (x * x) + pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) + pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) + qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) + qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) + + return ( + ( + pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) - return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], - ) + return jnp.piecewise( + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], + ) @op(torch.ops.aten.special_chebyshev_polynomial_t) @op_base.promote_int_input def _aten_special_chebyshev_polynomial_t(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2828-L2865 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2828-L2865 - @jnp.vectorize - def vectorized(x, n_i): + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def negative_n(x): - return jnp.zeros_like(x) + def one_x(x): + return jnp.where( + (x > 0) | (n_i % 2 
== 0), jnp.ones_like(x), -jnp.ones_like(x) + ) - def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x), - -jnp.ones_like(x)) + def large_n_small_x(x): + return jnp.cos(n_i * jnp.acos(x)) - def large_n_small_x(x): - return jnp.cos(n_i * jnp.acos(x)) + def zero_n(x): + return jnp.ones_like(x) - def zero_n(x): - return jnp.ones_like(x) + def one_n(x): + return x - def one_n(x): - return x + def default(x): + def f(_, carry): + p, q = carry + return (q, 2 * x * q - p) - def default(x): - - def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) + _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, x)) + return r - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x)) - return r + return jnp.piecewise( + x, + [ + n_i == 1, + n_i == 0, + (n_i == 6) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + ) - return jnp.piecewise(x, [ - n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1), - jnp.abs(x) == 1., n_i < 0 - ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default]) - - # Explcicitly vectorize since we must vectorizes over both self and n - return vectorized(self, n.astype(jnp.int64)) + # Explcicitly vectorize since we must vectorizes over both self and n + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_chebyshev_polynomial_u) @op_base.promote_int_input def _aten_special_chebyshev_polynomial_u(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2872-L2913 - - @jnp.vectorize - def vectorized(x, n_i): + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2872-L2913 - def negative_n(x): - return jnp.zeros_like(x) + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), n_i + 1, -(n_i + 1)) + def one_x(x): + return jnp.where((x > 0) | (n_i % 2 == 0), n_i + 1, -(n_i + 1)) - def large_n_small_x(x): - sin_acos_x = jnp.sin(jnp.acos(x)) - return jnp.where( - sin_acos_x != 0, - jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, - (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, - ) + def large_n_small_x(x): + sin_acos_x = jnp.sin(jnp.acos(x)) + return jnp.where( + sin_acos_x != 0, + jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, + (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, + ) - def zero_n(x): - return jnp.ones_like(x) + def zero_n(x): + return jnp.ones_like(x) - def one_n(x): - return 2 * x + def one_n(x): + return 2 * x - def default(x): - - def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) + def default(x): + def f(_, carry): + p, q = carry + return (q, 2 * x * q - p) - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, 2 * x)) - return r + _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, 2 * x)) + return r - return jnp.piecewise( - x, - [ - n_i == 1, - n_i == 0, - (n_i > 8) & (jnp.abs(x) < 1), - jnp.abs(x) == 1.0, - n_i < 0, - ], - [one_n, zero_n, large_n_small_x, one_x, negative_n, default], - ) + return jnp.piecewise( + x, + [ + n_i == 1, + n_i == 0, + (n_i > 8) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + ) - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_erfcx) @op_base.promote_int_input def _aten_special_erfcx(x): - return jnp.exp(x * x) * 
jax.lax.erfc(x) + return jnp.exp(x * x) * jax.lax.erfc(x) @op(torch.ops.aten.erfc) @op_base.promote_int_input def _aten_erfcx(x): - return jax.lax.erfc(x) + return jax.lax.erfc(x) @op(torch.ops.aten.special_hermite_polynomial_h) @op_base.promote_int_input def _aten_special_hermite_polynomial_h(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3036-L3061 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3036-L3061 - @jnp.vectorize - def vectorized(x, n_i): + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def negative_n(x): - return jnp.zeros_like(x) + def zero_n(x): + return jnp.ones_like(x) - def zero_n(x): - return jnp.ones_like(x) + def one_n(x): + return 2 * x - def one_n(x): - return 2 * x + def default(x): + def f(k, carry): + p, q = carry + return (q, 2 * x * q - 2 * k * p) - def default(x): - - def f(k, carry): - p, q = carry - return (q, 2 * x * q - 2 * k * p) - - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) - return r + _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) + return r - return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0], - [one_n, zero_n, negative_n, default]) + return jnp.piecewise( + x, + [n_i == 1, n_i == 0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_hermite_polynomial_he) @op_base.promote_int_input def _aten_special_hermite_polynomial_he(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3073-L3098 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3073-L3098 - @jnp.vectorize - def vectorized(x, n_i): + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def negative_n(x): - return jnp.zeros_like(x) + def zero_n(x): + return jnp.ones_like(x) - def zero_n(x): - return jnp.ones_like(x) + def one_n(x): + return x - def one_n(x): - return x - - def default(x): + def default(x): + def f(k, carry): + p, q = carry + return (q, x * q - k * p) - def f(k, carry): - p, q = carry - return (q, x * q - k * p) + _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) + return r - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) - return r + return jnp.piecewise( + x, + [n_i == 1.0, n_i == 0.0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) - return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0], - [one_n, zero_n, negative_n, default]) - - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.multinomial, needs_env=True) -def _aten_multinomial(input, - num_samples, - replacement=False, - *, - generator=None, - out=None, - env=None): - assert num_samples <= input.shape[ - -1] or replacement, "cannot take a larger sample than population when replacement=False" - key = env.get_and_rotate_prng_key(generator) - if input.ndim == 1: - return jax.random.choice( - key, input.shape[-1], (num_samples,), replace=replacement, p=input) - else: - return jnp.array([ - jax.random.choice( - key, - input.shape[-1], (num_samples,), - replace=replacement, - p=input[i, :]) for i in range(input.shape[0]) - ]) +def _aten_multinomial( + input, num_samples, replacement=False, *, 
generator=None, out=None, env=None +): + assert num_samples <= input.shape[-1] or replacement, ( + "cannot take a larger sample than population when replacement=False" + ) + key = env.get_and_rotate_prng_key(generator) + if input.ndim == 1: + return jax.random.choice( + key, input.shape[-1], (num_samples,), replace=replacement, p=input + ) + else: + return jnp.array([ + jax.random.choice( + key, + input.shape[-1], + (num_samples,), + replace=replacement, + p=input[i, :], + ) + for i in range(input.shape[0]) + ]) @op(torch.ops.aten.narrow) @op(torch.ops.aten.narrow_copy) def _aten_narrow(input, dim, start, length): - return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim) + return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim) @op(torch.ops.aten.flatten) def _aten_flatten(x, start_dim=0, end_dim=-1): - """ - Flattens a JAX array (similar to torch.flatten). + """ + Flattens a JAX array (similar to torch.flatten). - Args: - x: The JAX array to be flattened. - start_dim: The first dimension to include in the flattening. - end_dim: The last dimension to include in the flattening. + Args: + x: The JAX array to be flattened. + start_dim: The first dimension to include in the flattening. + end_dim: The last dimension to include in the flattening. - Returns: - A flattened JAX array. - """ - shape = x.shape + Returns: + A flattened JAX array. + """ + shape = x.shape - if end_dim < 0: - end_dim += len(shape) # Handle negative indexing + if end_dim < 0: + end_dim += len(shape) # Handle negative indexing - new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1:]) - return jnp.reshape(x, new_shape) + new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1 :]) + return jnp.reshape(x, new_shape) @op(torch.ops.aten.new_empty) def _new_empty(self, size, **kwargs): - dtype = kwargs.get('dtype') - if dtype is not None: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = self.dtype - return jnp.empty(size, dtype=dtype) + dtype = kwargs.get("dtype") + if dtype is not None: + dtype = mappings.t2j_dtype(dtype) + else: + dtype = self.dtype + return jnp.empty(size, dtype=dtype) @op(torch.ops.aten.new_empty_strided) def _new_empty_strided(self, size, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensor doesn't share the same memory. - if not dtype: - return jnp.empty(size, dtype=self.dtype) - else: - jax_dtype = mappings.t2j_dtype(dtype) - return jnp.empty(size, dtype=jax_dtype) + # Ignore stride, since JAX and torch tensor doesn't share the same memory. 
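+    # For example, torch.empty_strided((2, 3), (1, 2)) asks for a specific
+    # memory layout, but JAX arrays do not expose user-controllable strides,
+    # so only the requested shape and dtype are honored below.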
+ if not dtype: + return jnp.empty(size, dtype=self.dtype) + else: + jax_dtype = mappings.t2j_dtype(dtype) + return jnp.empty(size, dtype=jax_dtype) @op(torch.ops.aten._unsafe_index_put) def _aten_unsafe_index_put(self, indices, values, accumulate=False): - return _aten_index_put(self, indices, values, accumulate) + return _aten_index_put(self, indices, values, accumulate) -@op(torch.ops.aten.conj_physical, torch.ops.aten.conj, - torch.ops.aten._conj_physical, torch.ops.aten._conj) +@op( + torch.ops.aten.conj_physical, + torch.ops.aten.conj, + torch.ops.aten._conj_physical, + torch.ops.aten._conj, +) def _aten_conj_physical(self): - return jnp.conjugate(self) + return jnp.conjugate(self) @op(torch.ops.aten.log_sigmoid) def _aten_log_sigmoid(x): - return jax.nn.log_sigmoid(x) + return jax.nn.log_sigmoid(x) # torch.qr @op(torch.ops.aten.qr) def _aten_qr(input, *args, **kwargs): - jax_mode = "reduced" - # torch bool param 'simple=True' corresponds to jax 'reduced' mode, - # and simple=False corresponds to jax 'complete' mode. - if kwargs.get("simple") is False: - jax_mode = "complete" - return jax.numpy.linalg.qr(input, mode=jax_mode) + jax_mode = "reduced" + # torch bool param 'simple=True' corresponds to jax 'reduced' mode, + # and simple=False corresponds to jax 'complete' mode. + if kwargs.get("simple") is False: + jax_mode = "complete" + return jax.numpy.linalg.qr(input, mode=jax_mode) # torch.linalg.qr @op(torch.ops.aten.linalg_qr) def _aten_linalg_qr(input, *args, **kwargs): - mode = kwargs.get("mode", "reduced") - return jax.numpy.linalg.qr(input, mode=mode) + mode = kwargs.get("mode", "reduced") + return jax.numpy.linalg.qr(input, mode=mode) # torch.linalg.matrix_exp @op(torch.ops.aten.linalg_matrix_exp) def _aten_linalg_matrix_exp(input): - return jax.scipy.linalg.expm(input) + return jax.scipy.linalg.expm(input) # torch._linalg.slogdet @op(torch.ops.aten._linalg_slogdet) def _aten__linalg_slogdet(input): - res = jnp.linalg.slogdet(input) - return res.sign, res.logabsdet + res = jnp.linalg.slogdet(input) + return res.sign, res.logabsdet # torch.linalg.svd @op(torch.ops.aten._linalg_svd) def _aten__linalg_svd(a, full_matrices=False, **kwargs): - return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) + return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) # torch.linalg.pinv @op(torch.ops.aten.linalg_pinv.atol_rtol_tensor) def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs): - return jnp.linalg.pinv(a, rtol, hermitian=False) + return jnp.linalg.pinv(a, rtol, hermitian=False) # torch.linalg.solve @op(torch.ops.aten._linalg_solve_ex) def _aten__linalg_solve_ex(a, b): - batched = False - if b.ndim > 1 and b.shape[-1] == a.shape[-1]: - batched = True - b = b[..., None] - res = jnp.linalg.solve(a, b) - if batched: - res = res.squeeze(-1) - info_shape = a.shape[0] if len(a.shape) >= 3 else [] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - return res, info + batched = False + if b.ndim > 1 and b.shape[-1] == a.shape[-1]: + batched = True + b = b[..., None] + res = jnp.linalg.solve(a, b) + if batched: + res = res.squeeze(-1) + info_shape = a.shape[0] if len(a.shape) >= 3 else [] + info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) + return res, info # torch.linalg.solve_triangular @op(torch.ops.aten.linalg_solve_triangular) -def _aten_linalg_solve_triangular(a, - b, - *, - upper=True, - left=True, - unitriangular=False): - if left is False: - a = jnp.matrix_transpose(a) - b = jnp.matrix_transpose(b) - upper = not upper - 
res = jax.scipy.linalg.solve_triangular( - a, b, lower=not upper, unit_diagonal=unitriangular) - if left is False: - res = jnp.matrix_transpose(res) - return res +def _aten_linalg_solve_triangular( + a, b, *, upper=True, left=True, unitriangular=False +): + if left is False: + a = jnp.matrix_transpose(a) + b = jnp.matrix_transpose(b) + upper = not upper + res = jax.scipy.linalg.solve_triangular( + a, b, lower=not upper, unit_diagonal=unitriangular + ) + if left is False: + res = jnp.matrix_transpose(res) + return res @op(torch.ops.aten.linalg_inv_ex) def _aten_linalg_inv_ex(a): - ainv = jnp.linalg.inv(a) - info = jnp.zeros(a.shape[:-2], jnp.int32) - return ainv, info + ainv = jnp.linalg.inv(a) + info = jnp.zeros(a.shape[:-2], jnp.int32) + return ainv, info @op(torch.ops.aten._linalg_check_errors) def _aten__linalg_check_errors(*args, **kwargs): - pass + pass @op(torch.ops.aten.median) def _aten_median(self, dim=None, keepdim=False): - output = _with_reduction_scalar( - functools.partial(jnp.quantile, q=0.5, method='lower'), - self, - dim=dim, - keepdim=keepdim).astype(self.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar(_get_median_index, self, dim, - keepdim).astype(jnp.int64) - return output, index + output = _with_reduction_scalar( + functools.partial(jnp.quantile, q=0.5, method="lower"), + self, + dim=dim, + keepdim=keepdim, + ).astype(self.dtype) + if dim is None: + return output + else: + index = _with_reduction_scalar( + _get_median_index, self, dim, keepdim + ).astype(jnp.int64) + return output, index @op(torch.ops.aten.nanmedian) def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None): - output = _with_reduction_scalar( - functools.partial(jnp.nanquantile, q=0.5, method='lower'), - input, - dim=dim, - keepdim=keepdim).astype(input.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar(_get_median_index, input, dim, - keepdim).astype(jnp.int64) - return output, index + output = _with_reduction_scalar( + functools.partial(jnp.nanquantile, q=0.5, method="lower"), + input, + dim=dim, + keepdim=keepdim, + ).astype(input.dtype) + if dim is None: + return output + else: + index = _with_reduction_scalar( + _get_median_index, input, dim, keepdim + ).astype(jnp.int64) + return output, index def _get_median_index(x, axis=None, keepdims=False): - sorted_arg = jnp.argsort(x, axis=axis) - n = x.shape[axis] if axis is not None else x.size - if n % 2 == 1: - index = n // 2 - else: - index = (n // 2) - 1 - if axis is None: - median_index = sorted_arg[index] - else: - median_index = jnp.take(sorted_arg, index, axis=axis) - if keepdims and axis is not None: - median_index = jnp.expand_dims(median_index, axis) - return median_index + sorted_arg = jnp.argsort(x, axis=axis) + n = x.shape[axis] if axis is not None else x.size + if n % 2 == 1: + index = n // 2 + else: + index = (n // 2) - 1 + if axis is None: + median_index = sorted_arg[index] + else: + median_index = jnp.take(sorted_arg, index, axis=axis) + if keepdims and axis is not None: + median_index = jnp.expand_dims(median_index, axis) + return median_index @op(torch.ops.aten.triangular_solve) -def _aten_triangular_solve(b, - a, - upper=True, - transpose=False, - unittriangular=False): - return (jax.lax.linalg.triangular_solve( - a, - b, - left_side=True, - lower=not upper, - transpose_a=transpose, - unit_diagonal=unittriangular), a) +def _aten_triangular_solve( + b, a, upper=True, transpose=False, unittriangular=False +): + return ( + jax.lax.linalg.triangular_solve( + a, + b, 
+ left_side=True, + lower=not upper, + transpose_a=transpose, + unit_diagonal=unittriangular, + ), + a, + ) # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor @op(torch.ops.aten._fft_c2c) def _aten__fft_c2c(self, dim, normalization, forward): - if forward: - norm = [ - 'backward', - 'ortho', - 'forward', - ][normalization] - return jnp.fft.fftn(self, axes=dim, norm=norm) - else: - norm = [ - 'forward', - 'ortho', - 'backward', - ][normalization] - return jnp.fft.ifftn(self, axes=dim, norm=norm) + if forward: + norm = [ + "backward", + "ortho", + "forward", + ][normalization] + return jnp.fft.fftn(self, axes=dim, norm=norm) + else: + norm = [ + "forward", + "ortho", + "backward", + ][normalization] + return jnp.fft.ifftn(self, axes=dim, norm=norm) @op(torch.ops.aten._fft_r2c) def _aten__fft_r2c(self, dim, normalization, onesided): - norm = [ - 'backward', - 'ortho', - 'forward', - ][normalization] - if onesided: - return jnp.fft.rfftn(self, axes=dim, norm=norm) - else: - return jnp.fft.fftn(self, axes=dim, norm=norm) + norm = [ + "backward", + "ortho", + "forward", + ][normalization] + if onesided: + return jnp.fft.rfftn(self, axes=dim, norm=norm) + else: + return jnp.fft.fftn(self, axes=dim, norm=norm) @op(torch.ops.aten._fft_c2r) def _aten__fft_c2r(self, dim, normalization, last_dim_size): - norm = [ - 'forward', - 'ortho', - 'backward', - ][normalization] - if len(dim) == 1: - s = [last_dim_size] - else: - s = None - return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) + norm = [ + "forward", + "ortho", + "backward", + ][normalization] + if len(dim) == 1: + s = [last_dim_size] + else: + s = None + return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) @op(torch.ops.aten._trilinear) -def _aten_trilinear(i1, - i2, - i3, - expand1, - expand2, - expand3, - sumdim, - unroll_dim=1): - return _aten_sum( - jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * - jnp.expand_dims(i3, expand3), sumdim) +def _aten_trilinear( + i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1 +): + return _aten_sum( + jnp.expand_dims(i1, expand1) + * jnp.expand_dims(i2, expand2) + * jnp.expand_dims(i3, expand3), + sumdim, + ) @op(torch.ops.aten.max_unpool2d) @op(torch.ops.aten.max_unpool3d) def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): - if output_size is None: - raise ValueError( - "output_size value is not set correctly. It cannot be None or empty.") - - output_size = [input.shape[0], input.shape[1]] + output_size - output = jnp.zeros(output_size, dtype=input.dtype) - - for idx in np.ndindex(input.shape): - max_index = indices[idx] - spatial_dims = output_size[2:] # (D, H, W) - unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) - - return output - - -def _aten_upsample(input, - output_size, - align_corners, - antialias, - method, - scale_factors=None, - scales_h=None, - scales_w=None): - # input: is of type jaxlib.xla_extension.ArrayImpl - image = input - - # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html - # Resize does not distinguish batch, channel size. 
- # We need to leave them as is - # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions - # pytorch image shape is (C,H,W) or (N,C,H,W) - # N - batch size - # C - no of channels - # H,W - heigth, width - - shape = list(image.shape) - # overriding output_size - if scale_factors: - shape[-1] = int(math.floor(shape[-1] * scale_factors[-1])) - shape[-2] = int(math.floor(shape[-2] * scale_factors[-2])) - if scales_h: - shape[-2] = int(math.floor(shape[-2] * scales_h)) - if scales_w: - shape[-1] = int(math.floor(shape[-1] * scales_w)) - # output_size overrides scale_factors, scales_* - if output_size: - shape[-1] = output_size[-1] - shape[-2] = output_size[-2] - - # pytorch upsample_bilinear returns the input as is when the shape is the same as input - if shape == list(image.shape): - return image - - spatial_dims = (2, 3) - if len(shape) == 3: - spatial_dims = (1, 2) - - scale = list([shape[i] / image.shape[i] for i in spatial_dims]) - if scale_factors: - scale = scale_factors - if scales_h: - scale[0] = scales_h - if scales_w: - scale[1] = scales_w - scale = jnp.array(scale) - - # align_corners is not supported in resize() - # https://github.com/jax-ml/jax/issues/11206 - if align_corners: - scale = jnp.array([ - (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims - ]) - - translation = jnp.array([0 for i in spatial_dims]) - - return jax_reimplement.scale_and_translate( - image, - shape, - method=method, - scale=scale, - spatial_dims=spatial_dims, - translation=translation, - antialias=antialias, - ) + if output_size is None: + raise ValueError( + "output_size value is not set correctly. It cannot be None or empty." + ) + + output_size = [input.shape[0], input.shape[1]] + output_size + output = jnp.zeros(output_size, dtype=input.dtype) + + for idx in np.ndindex(input.shape): + max_index = indices[idx] + spatial_dims = output_size[2:] # (D, H, W) + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + full_idx = idx[:2] + unpooled_spatial_idx + output = output.at[full_idx].set(input[idx]) + + return output + + +def _aten_upsample( + input, + output_size, + align_corners, + antialias, + method, + scale_factors=None, + scales_h=None, + scales_w=None, +): + # input: is of type jaxlib.xla_extension.ArrayImpl + image = input + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html + # Resize does not distinguish batch, channel size. 
+ # We need to leave them as is + # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions + # pytorch image shape is (C,H,W) or (N,C,H,W) + # N - batch size + # C - no of channels + # H,W - heigth, width + + shape = list(image.shape) + # overriding output_size + if scale_factors: + shape[-1] = int(math.floor(shape[-1] * scale_factors[-1])) + shape[-2] = int(math.floor(shape[-2] * scale_factors[-2])) + if scales_h: + shape[-2] = int(math.floor(shape[-2] * scales_h)) + if scales_w: + shape[-1] = int(math.floor(shape[-1] * scales_w)) + # output_size overrides scale_factors, scales_* + if output_size: + shape[-1] = output_size[-1] + shape[-2] = output_size[-2] + + # pytorch upsample_bilinear returns the input as is when the shape is the same as input + if shape == list(image.shape): + return image + + spatial_dims = (2, 3) + if len(shape) == 3: + spatial_dims = (1, 2) + + scale = list([shape[i] / image.shape[i] for i in spatial_dims]) + if scale_factors: + scale = scale_factors + if scales_h: + scale[0] = scales_h + if scales_w: + scale[1] = scales_w + scale = jnp.array(scale) + + # align_corners is not supported in resize() + # https://github.com/jax-ml/jax/issues/11206 + if align_corners: + scale = jnp.array([ + (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims + ]) + + translation = jnp.array([0 for i in spatial_dims]) + + return jax_reimplement.scale_and_translate( + image, + shape, + method=method, + scale=scale, + spatial_dims=spatial_dims, + translation=translation, + antialias=antialias, + ) @op(torch.ops.aten._upsample_bilinear2d_aa) -def _aten_upsample_billinear_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bilinear", # method - scale_factors, - scales_h, - scales_w) +def _aten_upsample_billinear_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): + return _aten_upsample( + input, + output_size, + align_corners, + True, # antialias + "bilinear", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten._upsample_bicubic2d_aa) -def _aten_upsample_bicubic2d_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bicubic", # method - scale_factors, - scales_h, - scales_w) +def _aten_upsample_bicubic2d_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): + return _aten_upsample( + input, + output_size, + align_corners, + True, # antialias + "bicubic", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten.polar) def _aten_polar(abs, angle, *, out=None): - return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) + return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) @op(torch.ops.aten.cdist) -def _aten_cdist(x1, - x2, - p=2.0, - compute_mode='use_mm_for_euclid_dist_if_necessary'): - x1 = x1.astype(jnp.float32) - x2 = x2.astype(jnp.float32) - - if p == 0.0: - # For p = 0, use Hamming-like distance multiplied by the number of elements - return _hamming_distance(x1, x2).astype(jnp.float32) - elif p == 2.0: - # Use optimized Euclidean distance calculation - if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and ( - x1.shape[-2] > 25 or x2.shape[-2] > 25): - return _euclidean_mm(x1, x2) - elif compute_mode == 
'use_mm_for_euclid_dist': - return _euclidean_mm(x1, x2) +def _aten_cdist( + x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary" +): + x1 = x1.astype(jnp.float32) + x2 = x2.astype(jnp.float32) + + if p == 0.0: + # For p = 0, use Hamming-like distance multiplied by the number of elements + return _hamming_distance(x1, x2).astype(jnp.float32) + elif p == 2.0: + # Use optimized Euclidean distance calculation + if compute_mode == "use_mm_for_euclid_dist_if_necessary" and ( + x1.shape[-2] > 25 or x2.shape[-2] > 25 + ): + return _euclidean_mm(x1, x2) + elif compute_mode == "use_mm_for_euclid_dist": + return _euclidean_mm(x1, x2) + else: + return _euclidean_direct(x1, x2) else: - return _euclidean_direct(x1, x2) - else: - # General p-norm distance calculation - diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) - return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p) + # General p-norm distance calculation + diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) + return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** ( + 1 / p + ) def _hamming_distance(x1, x2): - """ - Computes the Hamming-like distance for p=0. + """ + Computes the Hamming-like distance for p=0. - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Hamming distances. - """ - diff = jnp.not_equal(jnp.expand_dims(x1, -2), jnp.expand_dims(x2, -3)) + Returns: + JAX array of shape (..., P, R) representing pairwise Hamming distances. + """ + diff = jnp.not_equal(jnp.expand_dims(x1, -2), jnp.expand_dims(x2, -3)) - hamming_dist = jnp.sum(diff, axis=-1).astype(jnp.float32) + hamming_dist = jnp.sum(diff, axis=-1).astype(jnp.float32) - return hamming_dist + return hamming_dist def _euclidean_mm(x1, x2): - """ - Computes the Euclidean distance using matrix multiplication. + """ + Computes the Euclidean distance using matrix multiplication. - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. - """ - x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32) - x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32) + Returns: + JAX array of shape (..., P, R) representing pairwise Euclidean distances. + """ + x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32) + x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32) - x2_sq = jnp.swapaxes(x2_sq, -2, -1) + x2_sq = jnp.swapaxes(x2_sq, -2, -1) - dot_product = jnp.matmul(x1, jnp.swapaxes(x2, -1, -2)) + dot_product = jnp.matmul(x1, jnp.swapaxes(x2, -1, -2)) - dist_sq = x1_sq + x2_sq - 2 * dot_product - dist_sq = jnp.maximum(dist_sq, 0.0) - dist = jnp.sqrt(dist_sq).astype(jnp.float32) + dist_sq = x1_sq + x2_sq - 2 * dot_product + dist_sq = jnp.maximum(dist_sq, 0.0) + dist = jnp.sqrt(dist_sq).astype(jnp.float32) - return dist + return dist def _euclidean_direct(x1, x2): - """ - Computes the Euclidean distance directly without matrix multiplication. + """ + Computes the Euclidean distance directly without matrix multiplication. 
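For the cdist helpers in this hunk, a small check (with made-up inputs) that the matmul expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b used by _euclidean_mm agrees with the direct pairwise computation:

import jax.numpy as jnp

a = jnp.array([[0.0, 1.0], [2.0, 3.0]])              # (P=2, M=2)
b = jnp.array([[1.0, 1.0], [4.0, 0.0], [0.0, 0.0]])  # (R=3, M=2)

direct = jnp.sqrt(jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1))
via_mm = jnp.sqrt(jnp.maximum(
    jnp.sum(a**2, axis=-1)[:, None] + jnp.sum(b**2, axis=-1)[None, :]
    - 2 * a @ b.T, 0.0))
assert jnp.allclose(direct, via_mm)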
- Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. - """ - diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3) + Returns: + JAX array of shape (..., P, R) representing pairwise Euclidean distances. + """ + diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3) - dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32) + dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32) - dist_sq = jnp.maximum(dist_sq, 0.0) + dist_sq = jnp.maximum(dist_sq, 0.0) - dist = jnp.sqrt(dist_sq).astype(jnp.float32) + dist = jnp.sqrt(dist_sq).astype(jnp.float32) - return dist + return dist @op(torch.ops.aten.lu_unpack) def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): - # lu_unpack doesnt exist in jax. - # Get commonly used data shape variables - n = LU_data.shape[-2] - m = LU_data.shape[-1] - dim = min(n, m) - - ### Compute the Lower and Upper triangle - if unpack_data: - # Extract lower triangle - L = jnp.tril(LU_data, k=-1) - - #emulate pytorch behavior: Add ones to the diagonal of L - eye = jnp.eye(n, m, dtype=LU_data.dtype) - L = L + eye - - # emulate pytorch behavior: Reshape lower triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-1] = dim - L = jax.lax.slice(L, start_indices, limit_indices) - - # Extract upper triangle - U = jnp.triu(LU_data) - - # emulate pytorch behavior: Reshape upper triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-2] = dim - U = jax.lax.slice(U, start_indices, limit_indices) - else: - # emulate pytroch behavior: return empty tensors - L = torch.empty(torch.Size([0])) - U = torch.empty(torch.Size([0])) - - ### Compute the Permutation matrix - if unpack_pivots: - # We should return a permutation matrix (2D) for each pivot array (1D) - # The shape of the final Permutation matrix depends on the shape of the input - # data and the pivots - - # start with a 2D identity matrix and tile it to the other dims of input data - identity2d = jnp.identity(n, dtype=jnp.float32) - tile_shape = list(LU_data.shape) - tile_shape[-1] = 1 - tile_shape[-2] = 1 - P = jnp.tile(identity2d, tile_shape) - - # closure to be called for each input 2D matrix. - def _lu_unpack_2d(p, pivot): - _pivot = pivot - 1 # pivots are offset by 1 in jax - indices = jnp.array([*range(n)], dtype=jnp.int32) - - def update_indices(i, _indices): - tmp = _indices[i] - _indices = _indices.at[i].set(_indices[_pivot[i]]) - _indices = _indices.at[_pivot[i]].set(tmp) - return _indices - - indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) - p = p[jnp.array(indices)] - p = jnp.transpose(p) - return p - - if len(LU_pivots.shape) == 1: - # if we are dealing with a simple 2D input and 1D pivot, call the closure directly - P = _lu_unpack_2d(P, LU_pivots) + # lu_unpack doesnt exist in jax. 
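For the L/U extraction in _aten_lu_unpack, a tiny illustrative example (the packed factor values are made up) showing that the unit-diagonal lower triangle plus the upper triangle reconstructs the matrix the factors came from, up to the row permutation:

import jax.numpy as jnp

LU = jnp.array([[4.0, 3.0],
                [0.5, 1.5]])            # packed LU factors, illustrative values
L = jnp.tril(LU, k=-1) + jnp.eye(2)     # unit lower triangle, as in the unpack branch
U = jnp.triu(LU)
assert jnp.allclose(L @ U, jnp.array([[4.0, 3.0], [2.0, 3.0]]))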
+ # Get commonly used data shape variables + n = LU_data.shape[-2] + m = LU_data.shape[-1] + dim = min(n, m) + + ### Compute the Lower and Upper triangle + if unpack_data: + # Extract lower triangle + L = jnp.tril(LU_data, k=-1) + + # emulate pytorch behavior: Add ones to the diagonal of L + eye = jnp.eye(n, m, dtype=LU_data.dtype) + L = L + eye + + # emulate pytorch behavior: Reshape lower triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-1] = dim + L = jax.lax.slice(L, start_indices, limit_indices) + + # Extract upper triangle + U = jnp.triu(LU_data) + + # emulate pytorch behavior: Reshape upper triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-2] = dim + U = jax.lax.slice(U, start_indices, limit_indices) else: - # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the - # closure for each 2D matrix. Finally unflatten the result to match the input data - # shape. - - # reshape permutation matrix to 3d - dim_size = jnp.prod(jnp.array(P.shape[:-2])) - newPshape = (dim_size, P.shape[-2], P.shape[-1]) - reshapedP = P.reshape(newPshape) - - # reshape pivots to 3d - dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) - newPivotshape = (dim_size, LU_pivots.shape[-1]) - reshapedPivot = LU_pivots.reshape(newPivotshape) - - # vmap the reshaped 3d tensors - v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0)) - unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) - - # reshape result back to P's shape - newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1]) - P = unpackedP.reshape(newRetshape) - else: - # emulate pytroch behavior: return empty tensors - P = torch.empty(torch.Size([0])) + # emulate pytroch behavior: return empty tensors + L = torch.empty(torch.Size([0])) + U = torch.empty(torch.Size([0])) + + ### Compute the Permutation matrix + if unpack_pivots: + # We should return a permutation matrix (2D) for each pivot array (1D) + # The shape of the final Permutation matrix depends on the shape of the input + # data and the pivots + + # start with a 2D identity matrix and tile it to the other dims of input data + identity2d = jnp.identity(n, dtype=jnp.float32) + tile_shape = list(LU_data.shape) + tile_shape[-1] = 1 + tile_shape[-2] = 1 + P = jnp.tile(identity2d, tile_shape) + + # closure to be called for each input 2D matrix. + def _lu_unpack_2d(p, pivot): + _pivot = pivot - 1 # pivots are offset by 1 in jax + indices = jnp.array([*range(n)], dtype=jnp.int32) + + def update_indices(i, _indices): + tmp = _indices[i] + _indices = _indices.at[i].set(_indices[_pivot[i]]) + _indices = _indices.at[_pivot[i]].set(tmp) + return _indices + + indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) + p = p[jnp.array(indices)] + p = jnp.transpose(p) + return p + + if len(LU_pivots.shape) == 1: + # if we are dealing with a simple 2D input and 1D pivot, call the closure directly + P = _lu_unpack_2d(P, LU_pivots) + else: + # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the + # closure for each 2D matrix. Finally unflatten the result to match the input data + # shape. 
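The >=3-D branch described just above flattens the leading dimensions, vmaps a per-matrix helper, and reshapes back. A standalone sketch of that pattern, with an illustrative stand-in for the per-matrix work and arbitrary batch dimensions:

import jax
import jax.numpy as jnp

def per_matrix(m):                               # stands in for the 2-D closure
    return m.T

batched = jnp.arange(24.0).reshape(2, 3, 2, 2)   # arbitrary leading dims
flat = batched.reshape(-1, *batched.shape[-2:])  # collapse to (N, 2, 2)
out = jax.vmap(per_matrix)(flat)                 # apply per 2-D matrix
out = out.reshape(*batched.shape[:-2], *out.shape[-2:])
assert out.shape == (2, 3, 2, 2)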
+ + # reshape permutation matrix to 3d + dim_size = jnp.prod(jnp.array(P.shape[:-2])) + newPshape = (dim_size, P.shape[-2], P.shape[-1]) + reshapedP = P.reshape(newPshape) + + # reshape pivots to 3d + dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) + newPivotshape = (dim_size, LU_pivots.shape[-1]) + reshapedPivot = LU_pivots.reshape(newPivotshape) + + # vmap the reshaped 3d tensors + v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0)) + unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) + + # reshape result back to P's shape + newRetshape = ( + *P.shape[:-2], + unpackedP.shape[-2], + unpackedP.shape[-1], + ) + P = unpackedP.reshape(newRetshape) + else: + # emulate pytroch behavior: return empty tensors + P = torch.empty(torch.Size([0])) - return P, L, U + return P, L, U @op(torch.ops.aten.linear) def linear(input, weight, bias=None): - res = input @ jnp.transpose(weight) - if bias is not None: - res += bias - return res + res = input @ jnp.transpose(weight) + if bias is not None: + res += bias + return res @op(torch.ops.aten.kthvalue) def kthvalue(input, k, dim=None, keepdim=False, *, out=None): - if input.ndim == 0: - return input, jnp.array(0) - dimension = -1 - if dim is not None: - dimension = dim - while dimension < 0: - dimension = dimension + input.ndim - values = jax.lax.index_in_dim( - jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim) - indices = jax.lax.index_in_dim( - jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1, - dimension, keepdim) - return values, indices + if input.ndim == 0: + return input, jnp.array(0) + dimension = -1 + if dim is not None: + dimension = dim + while dimension < 0: + dimension = dimension + input.ndim + values = jax.lax.index_in_dim( + jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim + ) + indices = jax.lax.index_in_dim( + jnp.argpartition(input, k - 1, dimension).astype("int64"), + k - 1, + dimension, + keepdim, + ) + return values, indices @op(torch.ops.aten.take) def _aten_take(self, index): - return self.flatten()[index] + return self.flatten()[index] # func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor @op(torch.ops.aten.pad) -def _aten_pad(self, pad, mode='constant', value=None): - if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: - raise ValueError("Padding must be a sequence of even length.") - - num_dims = self.ndim - if len(pad) > 2 * num_dims: - raise ValueError( - f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." - ) - - # JAX's pad function expects padding for each dimension as a tuple of (low, high) - # We need to reverse the pad sequence and group them for JAX. - # pad = [p_l0, p_r0, p_l1, p_r1, ...] - # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0))) - jax_pad_width = [] - # Iterate in reverse pairs - for i in range(len(pad) // 2): - jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)])) - - # Pad any leading dimensions with (0, 0) if the pad sequence is shorter - # than the number of dimensions. 
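For the pad-width reordering that _aten_pad performs below, a minimal example with illustrative sizes; torch lists pads starting from the last dimension, while jnp.pad expects per-dimension (low, high) pairs in dimension order:

import jax.numpy as jnp

x = jnp.ones((2, 3))
pad = [1, 2, 0, 4]                # torch order: last dim (low, high), then second-to-last
jax_pad_width = [(0, 4), (1, 2)]  # regrouped and reversed for jnp.pad
assert jnp.pad(x, jax_pad_width).shape == (6, 6)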
- for _ in range(num_dims - len(pad) // 2): - jax_pad_width.append((0, 0)) - - # Reverse the jax_pad_width list to match the dimension order - jax_pad_width.reverse() - - if mode == "constant": - if value is None: - value = 0.0 - return jnp.pad( - self, pad_width=jax_pad_width, mode="constant", constant_values=value) - elif mode == "reflect": - return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") - elif mode == "edge": - return jnp.pad(self, pad_width=jax_pad_width, mode="edge") - else: - raise ValueError( - f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." - ) +def _aten_pad(self, pad, mode="constant", value=None): + if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: + raise ValueError("Padding must be a sequence of even length.") + + num_dims = self.ndim + if len(pad) > 2 * num_dims: + raise ValueError( + f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." + ) + + # JAX's pad function expects padding for each dimension as a tuple of (low, high) + # We need to reverse the pad sequence and group them for JAX. + # pad = [p_l0, p_r0, p_l1, p_r1, ...] + # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0))) + jax_pad_width = [] + # Iterate in reverse pairs + for i in range(len(pad) // 2): + jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)])) + + # Pad any leading dimensions with (0, 0) if the pad sequence is shorter + # than the number of dimensions. + for _ in range(num_dims - len(pad) // 2): + jax_pad_width.append((0, 0)) + + # Reverse the jax_pad_width list to match the dimension order + jax_pad_width.reverse() + + if mode == "constant": + if value is None: + value = 0.0 + return jnp.pad( + self, + pad_width=jax_pad_width, + mode="constant", + constant_values=value, + ) + elif mode == "reflect": + return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") + elif mode == "edge": + return jnp.pad(self, pad_width=jax_pad_width, mode="edge") + else: + raise ValueError( + f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." 
+ ) mutation_ops_to_functional = { - torch.ops.aten.add_: - op_base.InplaceOp(torch.ops.aten.add), - torch.ops.aten.sub_: - op_base.InplaceOp(torch.ops.aten.sub), - torch.ops.aten.mul_: - op_base.InplaceOp(torch.ops.aten.mul), - torch.ops.aten.div_: - op_base.InplaceOp(torch.ops.aten.div), - torch.ops.aten.pow_: - op_base.InplaceOp(torch.ops.aten.pow), - torch.ops.aten.lt_: - op_base.InplaceOp(torch.ops.aten.lt), - torch.ops.aten.le_: - op_base.InplaceOp(torch.ops.aten.le), - torch.ops.aten.gt_: - op_base.InplaceOp(torch.ops.aten.gt), - torch.ops.aten.ge_: - op_base.InplaceOp(torch.ops.aten.ge), - torch.ops.aten.eq_: - op_base.InplaceOp(torch.ops.aten.eq), - torch.ops.aten.ne_: - op_base.InplaceOp(torch.ops.aten.ne), - torch.ops.aten.bernoulli_: - op_base.InplaceOp(torch.ops.aten.bernoulli.p), - torch.ops.aten.bernoulli_.float: - op_base.InplaceOp(_aten_bernoulli, is_jax_func=True), - torch.ops.aten.geometric_: - op_base.InplaceOp(torch.ops.aten.geometric), - torch.ops.aten.normal_: - op_base.InplaceOp(torch.ops.aten.normal), - torch.ops.aten.random_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.uniform_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.relu_: - op_base.InplaceOp(torch.ops.aten.relu), + torch.ops.aten.add_: op_base.InplaceOp(torch.ops.aten.add), + torch.ops.aten.sub_: op_base.InplaceOp(torch.ops.aten.sub), + torch.ops.aten.mul_: op_base.InplaceOp(torch.ops.aten.mul), + torch.ops.aten.div_: op_base.InplaceOp(torch.ops.aten.div), + torch.ops.aten.pow_: op_base.InplaceOp(torch.ops.aten.pow), + torch.ops.aten.lt_: op_base.InplaceOp(torch.ops.aten.lt), + torch.ops.aten.le_: op_base.InplaceOp(torch.ops.aten.le), + torch.ops.aten.gt_: op_base.InplaceOp(torch.ops.aten.gt), + torch.ops.aten.ge_: op_base.InplaceOp(torch.ops.aten.ge), + torch.ops.aten.eq_: op_base.InplaceOp(torch.ops.aten.eq), + torch.ops.aten.ne_: op_base.InplaceOp(torch.ops.aten.ne), + torch.ops.aten.bernoulli_: op_base.InplaceOp(torch.ops.aten.bernoulli.p), + torch.ops.aten.bernoulli_.float: op_base.InplaceOp( + _aten_bernoulli, is_jax_func=True + ), + torch.ops.aten.geometric_: op_base.InplaceOp(torch.ops.aten.geometric), + torch.ops.aten.normal_: op_base.InplaceOp(torch.ops.aten.normal), + torch.ops.aten.random_: op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.uniform_: op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.relu_: op_base.InplaceOp(torch.ops.aten.relu), # squeeze_ is expected to change tensor's shape. 
So replace with new value - torch.ops.aten.squeeze_: - op_base.InplaceOp(torch.ops.aten.squeeze, True), - torch.ops.aten.sqrt_: - op_base.InplaceOp(torch.ops.aten.sqrt), - torch.ops.aten.clamp_: - op_base.InplaceOp(torch.ops.aten.clamp), - torch.ops.aten.clamp_min_: - op_base.InplaceOp(torch.ops.aten.clamp_min), - torch.ops.aten.sigmoid_: - op_base.InplaceOp(torch.ops.aten.sigmoid), - torch.ops.aten.tanh_: - op_base.InplaceOp(torch.ops.aten.tanh), - torch.ops.aten.ceil_: - op_base.InplaceOp(torch.ops.aten.ceil), - torch.ops.aten.logical_not_: - op_base.InplaceOp(torch.ops.aten.logical_not), - torch.ops.aten.unsqueeze_: - op_base.InplaceOp(torch.ops.aten.unsqueeze), - torch.ops.aten.transpose_: - op_base.InplaceOp(torch.ops.aten.transpose), - torch.ops.aten.log_normal_: - op_base.InplaceOp(torch.ops.aten.log_normal), - torch.ops.aten.scatter_add_: - op_base.InplaceOp(torch.ops.aten.scatter_add), - torch.ops.aten.scatter_reduce_.two: - op_base.InplaceOp(torch.ops.aten.scatter_reduce), - torch.ops.aten.scatter_: - op_base.InplaceOp(torch.ops.aten.scatter), - torch.ops.aten.bitwise_or_: - op_base.InplaceOp(torch.ops.aten.bitwise_or), + torch.ops.aten.squeeze_: op_base.InplaceOp(torch.ops.aten.squeeze, True), + torch.ops.aten.sqrt_: op_base.InplaceOp(torch.ops.aten.sqrt), + torch.ops.aten.clamp_: op_base.InplaceOp(torch.ops.aten.clamp), + torch.ops.aten.clamp_min_: op_base.InplaceOp(torch.ops.aten.clamp_min), + torch.ops.aten.sigmoid_: op_base.InplaceOp(torch.ops.aten.sigmoid), + torch.ops.aten.tanh_: op_base.InplaceOp(torch.ops.aten.tanh), + torch.ops.aten.ceil_: op_base.InplaceOp(torch.ops.aten.ceil), + torch.ops.aten.logical_not_: op_base.InplaceOp(torch.ops.aten.logical_not), + torch.ops.aten.unsqueeze_: op_base.InplaceOp(torch.ops.aten.unsqueeze), + torch.ops.aten.transpose_: op_base.InplaceOp(torch.ops.aten.transpose), + torch.ops.aten.log_normal_: op_base.InplaceOp(torch.ops.aten.log_normal), + torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add), + torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp( + torch.ops.aten.scatter_reduce + ), + torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter), + torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or), } # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`. @@ -5575,9 +5805,10 @@ def _aten_pad(self, pad, mode='constant', value=None): } for operator, mutation in mutation_ops_to_functional.items(): - ops_registry.register_torch_dispatch_op( - operator, - mutation, - is_jax_function=False, - is_view_op=True, - needs_env=(operator in mutation_needs_env)) + ops_registry.register_torch_dispatch_op( + operator, + mutation, + is_jax_function=False, + is_view_op=True, + needs_env=(operator in mutation_needs_env), + ) diff --git a/torchax/torchax/ops/jax_reimplement.py b/torchax/torchax/ops/jax_reimplement.py index d9acc3be51ab..236fb253de38 100644 --- a/torchax/torchax/ops/jax_reimplement.py +++ b/torchax/torchax/ops/jax_reimplement.py @@ -15,66 +15,93 @@ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52 -def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize, - scale, translation, kernel: Callable, antialias: bool): - dtype = jnp.result_type(scale, translation) - inv_scale = 1. / scale - # When downsampling the kernel should be scaled since we want to low pass - # filter and interpolate, but when upsampling it should not be since we only - # want to interpolate. 
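A one-line illustration of the kernel-scaling rule stated in the comment above, as a minimal sketch with illustrative scale factors:

import jax.numpy as jnp

def kernel_scale(scale, antialias=True):
    inv_scale = 1.0 / scale
    return jnp.maximum(inv_scale, 1.0) if antialias else 1.0

assert kernel_scale(0.5) == 2.0   # downsampling by 2x: widen the kernel
assert kernel_scale(2.0) == 1.0   # upsampling: kernel left unchanged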
- kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1. - sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale - - translation * inv_scale - 0.5) - x = ( - jnp.abs(sample_f[jnp.newaxis, :] - - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) / - kernel_scale) - weights = kernel(x) - - total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) - weights = jnp.where( - jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps), - jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, - 1)), 0) - # Zero out weights where the sample location is completely outside the input - # range. - # Note sample_f has already had the 0.5 removed, hence the weird range below. - - # (barney-s) -------------- returning weights without zeroing --------------------- - return weights - input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 - return jnp.where( - jnp.logical_and(sample_f >= -0.5, sample_f - <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0) - # (barney-s) -------------- END returning weights without zeroing --------------------- +def compute_weight_mat( + input_size: core.DimSize, + output_size: core.DimSize, + scale, + translation, + kernel: Callable, + antialias: bool, +): + dtype = jnp.result_type(scale, translation) + inv_scale = 1.0 / scale + # When downsampling the kernel should be scaled since we want to low pass + # filter and interpolate, but when upsampling it should not be since we only + # want to interpolate. + kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0 + sample_f = ( + (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale + - translation * inv_scale + - 0.5 + ) + x = ( + jnp.abs( + sample_f[jnp.newaxis, :] + - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis] + ) + / kernel_scale + ) + weights = kernel(x) + + total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) + weights = jnp.where( + jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), + jnp.divide( + weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1) + ), + 0, + ) + # Zero out weights where the sample location is completely outside the input + # range. + # Note sample_f has already had the 0.5 removed, hence the weird range below. 
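The weight matrix built here has shape (input_size, output_size), so resampling one axis reduces to a contraction of that axis with the weights, which is what _scale_and_translate assembles via einsum further down. A small sketch with illustrative weights whose columns sum to 1:

import jax.numpy as jnp

w = jnp.array([[0.75, 0.25, 0.00],
               [0.25, 0.50, 0.25],
               [0.00, 0.25, 0.75]])   # illustrative (input, output) weights
signal = jnp.array([1.0, 2.0, 4.0])
resampled = signal @ w                # per-axis contraction
assert jnp.allclose(resampled, jnp.array([1.25, 2.25, 3.5]))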
+ + # (barney-s) -------------- returning weights without zeroing --------------------- + return weights + input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 + return jnp.where( + jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + jnp.newaxis, : + ], + weights, + 0, + ) + # (barney-s) -------------- END returning weights without zeroing --------------------- # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86 -def _scale_and_translate(x, output_shape: core.Shape, - spatial_dims: Sequence[int], scale, translation, - kernel, antialias: bool, precision): - input_shape = x.shape - assert len(input_shape) == len(output_shape) - assert len(spatial_dims) == len(scale) - assert len(spatial_dims) == len(translation) - if len(spatial_dims) == 0: - return x - contractions = [] - in_indices = list(range(len(output_shape))) - out_indices = list(range(len(output_shape))) - for i, d in enumerate(spatial_dims): - d = canonicalize_axis(d, x.ndim) - m = input_shape[d] - n = output_shape[d] - w = compute_weight_mat(m, n, scale[i], translation[i], kernel, - antialias).astype(x.dtype) - contractions.append(w) - contractions.append([d, len(output_shape) + i]) - out_indices[d] = len(output_shape) + i - contractions.append(out_indices) - return jnp.einsum(x, in_indices, *contractions, precision=precision) +def _scale_and_translate( + x, + output_shape: core.Shape, + spatial_dims: Sequence[int], + scale, + translation, + kernel, + antialias: bool, + precision, +): + input_shape = x.shape + assert len(input_shape) == len(output_shape) + assert len(spatial_dims) == len(scale) + assert len(spatial_dims) == len(translation) + if len(spatial_dims) == 0: + return x + contractions = [] + in_indices = list(range(len(output_shape))) + out_indices = list(range(len(output_shape))) + for i, d in enumerate(spatial_dims): + d = canonicalize_axis(d, x.ndim) + m = input_shape[d] + n = output_shape[d] + w = compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ).astype(x.dtype) + contractions.append(w) + contractions.append([d, len(output_shape) + i]) + out_indices[d] = len(output_shape) + i + contractions.append(out_indices) + return jnp.einsum(x, in_indices, *contractions, precision=precision) # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172 @@ -89,83 +116,96 @@ def scale_and_translate( scale, translation, # (barney-s) use string - method: str, #(barney-s) | ResizeMethod, + method: str, # (barney-s) | ResizeMethod, antialias: bool = True, - precision=lax.Precision.HIGHEST): - """Apply a scale and translation to an image. - - Generates a new image of shape 'shape' by resampling from the input image - using the sampling method corresponding to method. For 2D images, this - operation transforms a location in the input images, (x, y), to a location - in the output image according to:: - - (x * scale[1] + translation[1], y * scale[0] + translation[0]) - - (Note the *inverse* warp is used to generate the sample locations.) - Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` - has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input - image dimensions. - - If an output location(pixel) maps to an input sample location that is outside - the input boundaries then the value for the output location will be set to - zero. 
- - The ``method`` argument expects one of the following resize methods: - - ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, - ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a - triangular filter when downsampling. - - ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"`` - `Cubic interpolation`_, using the Keys cubic kernel. - - ``ResizeMethod.LANCZOS3``, ``"lanczos3"`` - `Lanczos resampling`_, using a kernel of radius 3. - - ``ResizeMethod.LANCZOS5``, ``"lanczos5"`` - `Lanczos resampling`_, using a kernel of radius 5. - - .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation - .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation - .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling - - Args: - image: a JAX array. - shape: the output shape, as a sequence of integers with length equal to the - number of dimensions of `image`. - spatial_dims: A length K tuple specifying the spatial dimensions that the - passed scale and translation should be applied to. - scale: A [K] array with the same number of dimensions as image, containing - the scale to apply in each dimension. - translation: A [K] array with the same number of dimensions as image, - containing the translation to apply in each dimension. - method: the resizing method to use; either a ``ResizeMethod`` instance or a - string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC. - antialias: Should an antialiasing filter be used when downsampling? Defaults - to ``True``. Has no effect when upsampling. - - Returns: - The scale and translated image. - """ - shape = core.canonicalize_shape(shape) - if len(shape) != image.ndim: - msg = ('shape must have length equal to the number of dimensions of x; ' - f' {shape} vs {image.shape}') - raise ValueError(msg) - if isinstance(method, str): - method = ResizeMethod.from_string(method) - if method == ResizeMethod.NEAREST: - # Nearest neighbor is currently special-cased for straight resize, so skip - # for now. - raise ValueError('Nearest neighbor resampling is not currently supported ' - 'for scale_and_translate.') - assert isinstance(method, ResizeMethod) - - kernel = _kernels[method] - image, = promote_dtypes_inexact(image) - scale, translation = promote_dtypes_inexact(scale, translation) - return _scale_and_translate(image, shape, spatial_dims, scale, translation, - kernel, antialias, precision) + precision=lax.Precision.HIGHEST, +): + """Apply a scale and translation to an image. + + Generates a new image of shape 'shape' by resampling from the input image + using the sampling method corresponding to method. For 2D images, this + operation transforms a location in the input images, (x, y), to a location + in the output image according to:: + + (x * scale[1] + translation[1], y * scale[0] + translation[0]) + + (Note the *inverse* warp is used to generate the sample locations.) + Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` + has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input + image dimensions. + + If an output location(pixel) maps to an input sample location that is outside + the input boundaries then the value for the output location will be set to + zero. + + The ``method`` argument expects one of the following resize methods: + + ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, + ``"triangle"`` `Linear interpolation`_. 
If ``antialias`` is ``True``, uses a + triangular filter when downsampling. + + ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"`` + `Cubic interpolation`_, using the Keys cubic kernel. + + ``ResizeMethod.LANCZOS3``, ``"lanczos3"`` + `Lanczos resampling`_, using a kernel of radius 3. + + ``ResizeMethod.LANCZOS5``, ``"lanczos5"`` + `Lanczos resampling`_, using a kernel of radius 5. + + .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation + .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling + + Args: + image: a JAX array. + shape: the output shape, as a sequence of integers with length equal to the + number of dimensions of `image`. + spatial_dims: A length K tuple specifying the spatial dimensions that the + passed scale and translation should be applied to. + scale: A [K] array with the same number of dimensions as image, containing + the scale to apply in each dimension. + translation: A [K] array with the same number of dimensions as image, + containing the translation to apply in each dimension. + method: the resizing method to use; either a ``ResizeMethod`` instance or a + string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC. + antialias: Should an antialiasing filter be used when downsampling? Defaults + to ``True``. Has no effect when upsampling. + + Returns: + The scale and translated image. + """ + shape = core.canonicalize_shape(shape) + if len(shape) != image.ndim: + msg = ( + "shape must have length equal to the number of dimensions of x; " + f" {shape} vs {image.shape}" + ) + raise ValueError(msg) + if isinstance(method, str): + method = ResizeMethod.from_string(method) + if method == ResizeMethod.NEAREST: + # Nearest neighbor is currently special-cased for straight resize, so skip + # for now. + raise ValueError( + "Nearest neighbor resampling is not currently supported " + "for scale_and_translate." 
+ ) + assert isinstance(method, ResizeMethod) + + kernel = _kernels[method] + (image,) = promote_dtypes_inexact(image) + scale, translation = promote_dtypes_inexact(scale, translation) + return _scale_and_translate( + image, + shape, + spatial_dims, + scale, + translation, + kernel, + antialias, + precision, + ) # END ----------------- END JAX code copied for testing ----------------------------- diff --git a/torchax/torchax/ops/jc10d.py b/torchax/torchax/ops/jc10d.py index 79544943f918..0d730d39a9ad 100644 --- a/torchax/torchax/ops/jc10d.py +++ b/torchax/torchax/ops/jc10d.py @@ -6,47 +6,45 @@ def op(*aten, **kwargs): + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + return func - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - return func - - return inner + return inner @op(torch.ops._c10d_functional.all_gather_into_tensor) def _c10d_all_gather(input, group_size: int, group_name: str): - return jax.lax.all_gather(input, "torch_dist") + return jax.lax.all_gather(input, "torch_dist") @op(torch.ops._c10d_functional.all_reduce) def _c10d_all_reduce(self, reduceOp: str, group_name: str): - - if reduceOp == "sum": - res = jax.lax.psum(self, axis_name="torch_dist") - elif reduceOp == "avg": - res = jax.lax.pmean(self, axis_name="torch_dist") - elif reduceOp == "min": - res = jax.lax.pmin(self, axis_name="torch_dist") - elif reduceOp == "max": - res = jax.lax.pmax(self, axis_name="torch_dist") - else: - raise RuntimeError(f"Reduce op {reduceOp} not implemented") - return res + if reduceOp == "sum": + res = jax.lax.psum(self, axis_name="torch_dist") + elif reduceOp == "avg": + res = jax.lax.pmean(self, axis_name="torch_dist") + elif reduceOp == "min": + res = jax.lax.pmin(self, axis_name="torch_dist") + elif reduceOp == "max": + res = jax.lax.pmax(self, axis_name="torch_dist") + else: + raise RuntimeError(f"Reduce op {reduceOp} not implemented") + return res @op(torch.ops._c10d_functional.broadcast) def _c10d_broadcast(self, src: int, group_name: str): - masked = jnp.where( - jax.lax.axis_index("torch_dist") == src, - self, - jnp.zeros_like(self), - ) - return jax.lax.psum(masked, "torch_dist") + masked = jnp.where( + jax.lax.axis_index("torch_dist") == src, + self, + jnp.zeros_like(self), + ) + return jax.lax.psum(masked, "torch_dist") @op(torch.ops._c10d_functional.wait_tensor) def _c10d_wait_tensor(tensor): - # Async tensor is aleady `wait`ed by dispatcher - return tensor + # Async tensor is aleady `wait`ed by dispatcher + return tensor diff --git a/torchax/torchax/ops/jimage.py b/torchax/torchax/ops/jimage.py index 947be0a5e3f0..dbfc77c64cd2 100644 --- a/torchax/torchax/ops/jimage.py +++ b/torchax/torchax/ops/jimage.py @@ -3,111 +3,110 @@ def cubic_kernel(x, a=-0.75): - """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" - absx = jnp.abs(x) - x2 = absx * absx - x3 = x2 * absx - cond1 = (absx <= 1) - cond2 = (absx > 1) & (absx < 2) - f1 = (a + 2) * x3 - (a + 3) * x2 + 1 - f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a - return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) - - -def compute_contribs(in_size, - out_size, - scale, - support=2.0, - align_corners=False, - dtype=None): - if align_corners: - if out_size == 1: - in_coords = jnp.zeros((1,), dtype=dtype) + """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" + absx = jnp.abs(x) + x2 = absx * absx + x3 = x2 * absx + cond1 = absx <= 1 + cond2 = (absx > 1) & (absx < 2) + f1 = (a + 2) * x3 - (a + 3) * x2 + 1 + f2 = a * x3 - 5 
* a * x2 + 8 * a * absx - 4 * a + return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) + + +def compute_contribs( + in_size, out_size, scale, support=2.0, align_corners=False, dtype=None +): + if align_corners: + if out_size == 1: + in_coords = jnp.zeros((1,), dtype=dtype) + else: + in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) else: - in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) - else: - out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 - in_coords = out_coords / scale - 0.5 + out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 + in_coords = out_coords / scale - 0.5 - left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 - idxs = left_idx[:, None] + jnp.arange(4) + left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 + idxs = left_idx[:, None] + jnp.arange(4) - dx = in_coords[:, None] - idxs + dx = in_coords[:, None] - idxs - weights = cubic_kernel(dx) + weights = cubic_kernel(dx) - weights = weights / jnp.sum(weights, axis=1, keepdims=True) - return idxs, weights + weights = weights / jnp.sum(weights, axis=1, keepdims=True) + return idxs, weights def gather_weights(img, idxs, axis): - """Safely gather with boundary handling""" - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) - return jnp.take(img, idxs, axis=axis) + """Safely gather with boundary handling""" + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) + return jnp.take(img, idxs, axis=axis) def interpolate_along_axis_bchw(img, idxs, weights, axis): - """ + """ Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). idxs: (out_size, 4) int32 indices weights: (out_size, 4) float32 weights """ - assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" - out_size = idxs.shape[0] - k = idxs.shape[1] # Typically 4 for cubic + assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" + out_size = idxs.shape[0] + k = idxs.shape[1] # Typically 4 for cubic - # Clip to input bounds - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) + # Clip to input bounds + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) - def gather_and_weight(i): - idx = idxs[i] # (4,) - w = weights[i] # (4,) + def gather_and_weight(i): + idx = idxs[i] # (4,) + w = weights[i] # (4,) - def gather_one(offset): - return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) + def gather_one(offset): + return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) - gathered = jnp.stack([gather_one(o) for o in range(k)], - axis=0) # (4, B, C, H, W) - weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) - return weighted + gathered = jnp.stack( + [gather_one(o) for o in range(k)], axis=0 + ) # (4, B, C, H, W) + weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) + return weighted - out = jax.vmap(gather_and_weight)( - jnp.arange(out_size)) # (out_size, B, C, H, W) + out = jax.vmap(gather_and_weight)( + jnp.arange(out_size) + ) # (out_size, B, C, H, W) - # Move the interpolated axis back into place - if axis == 2: # interpolated over H - return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) - else: # axis == 3, interpolated over W - return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) + # Move the interpolated axis back into place + if axis == 2: # interpolated over H + return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) + else: # axis == 3, interpolated over W + return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): - h, w = img.shape[-2:] - if align_corners and out_h > 1: - scale_y = (h - 1) / (out_h - 1) - else: - 
scale_y = out_h / h - - if align_corners and out_w > 1: - scale_x = (w - 1) / (out_w - 1) - else: - scale_x = out_w / w - - idxs_y, weights_y = compute_contribs( - h, - out_h, - scale_y, - align_corners=align_corners, - dtype=img.dtype, - ) - tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) - - idxs_x, weights_x = compute_contribs( - w, - out_w, - scale_x, - align_corners=align_corners, - dtype=img.dtype, - ) - out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) - return out + h, w = img.shape[-2:] + if align_corners and out_h > 1: + scale_y = (h - 1) / (out_h - 1) + else: + scale_y = out_h / h + + if align_corners and out_w > 1: + scale_x = (w - 1) / (out_w - 1) + else: + scale_x = out_w / w + + idxs_y, weights_y = compute_contribs( + h, + out_h, + scale_y, + align_corners=align_corners, + dtype=img.dtype, + ) + tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) + + idxs_x, weights_x = compute_contribs( + w, + out_w, + scale_x, + align_corners=align_corners, + dtype=img.dtype, + ) + out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) + return out diff --git a/torchax/torchax/ops/jlibrary.py b/torchax/torchax/ops/jlibrary.py index 17cdb161c3c3..697cf6fe646d 100644 --- a/torchax/torchax/ops/jlibrary.py +++ b/torchax/torchax/ops/jlibrary.py @@ -11,70 +11,70 @@ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args): - """Wrap a jaxpr in a jitted function with the proper composite name - TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op. - """ + """Wrap a jaxpr in a jitted function with the proper composite name + TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op. + """ - def composite_impl(*args): - return jaxpr_impl(*args) + def composite_impl(*args): + return jaxpr_impl(*args) - composite_impl.__name__ = composite_name - composite_impl.__qualname__ = composite_name - return jax.jit(composite_impl, **jit_args) + composite_impl.__name__ = composite_name + composite_impl.__qualname__ = composite_name + return jax.jit(composite_impl, **jit_args) def register_jax_composite(composite_name, impl, *ops, **jit_args): - """Register a composite using a JAX implementation. - composite_name - The name of the library op to use in the exported composite - impl - A JAX lowering for the library operation - *ops - Variadic torch.ops to lower using `impl`. - **jit_args - Additional parameters to forward to JAX jit. + """Register a composite using a JAX implementation. + composite_name - The name of the library op to use in the exported composite + impl - A JAX lowering for the library operation + *ops - Variadic torch.ops to lower using `impl`. + **jit_args - Additional parameters to forward to JAX jit. - This is used to register custom lowerings with an explicit jaxpr - implementation, such as preserving a specific aten op using a jaten impl. + This is used to register custom lowerings with an explicit jaxpr + implementation, such as preserving a specific aten op using a jaten impl. - For custom torch op registration with a decomposition written in torch, - use `register_torch_composite`. + For custom torch op registration with a decomposition written in torch, + use `register_torch_composite`. 
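A hedged usage sketch for register_jax_composite defined in this file. Only the call shape comes from the function itself; the library name my_lib, the op fused_gelu, and its JAX lowering are hypothetical, so this will not run unless such a custom torch op has been registered elsewhere:

import jax.numpy as jnp
import torch
from torchax.ops import jlibrary

def fused_gelu_jax(x):
    # hypothetical JAX lowering for a hypothetical my_lib::fused_gelu op
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x**3)))

jlibrary.register_jax_composite(
    "my_lib.fused_gelu",           # composite name to emit
    fused_gelu_jax,                # JAX implementation
    torch.ops.my_lib.fused_gelu,   # hypothetical torch op to lower
)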
- For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ + For jit params and troubleshooting see: + https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html + """ - @jaten.op(*ops) - def _composite_impl(*args): - return _jit_composite_impl(composite_name, impl, **jit_args)(*args) + @jaten.op(*ops) + def _composite_impl(*args): + return _jit_composite_impl(composite_name, impl, **jit_args)(*args) def register_torch_composite(composite_name, impl, *ops, **jit_args): - """Register a torch decomposition as a composite. - This is useful for registerring custom torch op libraries as composite ops. - - The `impl` can be the `@impl` used to define the torch custom library op. - This must be a function or module impl that provides the decompositions, and - not an instance of the custom op. - - TODO: Better error handling, or can we make this an instance of the op as a param? - - For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ - - @jaten.op(*ops) - def _composite_impl(*args): - - class ImplWrapper(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, *args): - return impl(*args) - - # Note: avoid refactoring to share code with register_jaxpr_composite. - # The `extract_jax` call must live in the `@jaten.op` handler. If called - # outside of the handler, we would build the jaxpr representation of the - # module once during registration, potentially missing op registrations that - # come after. I.e. may miss nested abstractions if we build jaxpr AoT. - state, jfn = torchax.extract_jax(ImplWrapper()) - jaxpr_impl = lambda *args: jfn(state, tuple([*args])) - return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args) + """Register a torch decomposition as a composite. + This is useful for registerring custom torch op libraries as composite ops. + + The `impl` can be the `@impl` used to define the torch custom library op. + This must be a function or module impl that provides the decompositions, and + not an instance of the custom op. + + TODO: Better error handling, or can we make this an instance of the op as a param? + + For jit params and troubleshooting see: + https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html + """ + + @jaten.op(*ops) + def _composite_impl(*args): + class ImplWrapper(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args): + return impl(*args) + + # Note: avoid refactoring to share code with register_jaxpr_composite. + # The `extract_jax` call must live in the `@jaten.op` handler. If called + # outside of the handler, we would build the jaxpr representation of the + # module once during registration, potentially missing op registrations that + # come after. I.e. may miss nested abstractions if we build jaxpr AoT. 
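        # extract_jax traces ImplWrapper -- and therefore the torch
        # decomposition `impl` -- into a (state, jfn) pair at call time; the
        # lambda below closes over that state so the resulting jaxpr can be
        # jitted under `composite_name` by _jit_composite_impl, mirroring the
        # register_jax_composite path above.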
+ state, jfn = torchax.extract_jax(ImplWrapper()) + jaxpr_impl = lambda *args: jfn(state, tuple([*args])) + return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)( + *args + ) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 935c214d78f5..b3d5340cc7b3 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -20,92 +20,94 @@ def register_function(torch_func, **kwargs): - return functools.partial(register_torch_function_op, torch_func, **kwargs) + return functools.partial(register_torch_function_op, torch_func, **kwargs) @register_function(torch.as_tensor, is_jax_function=False, needs_env=True) @op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements + use_default_dtype=False +) # Attempt to infer type from elements def _as_tensor(data, dtype=None, device=None, env=None): - if isinstance(data, torch.Tensor): - return env._to_copy(data, dtype, device) - if isinstance(data, np.ndarray): - jax_res = jnp.asarray(data) - else: - jax_res = _tensor(data, dtype=dtype) - return torchax.tensor.Tensor(jax_res, env) + if isinstance(data, torch.Tensor): + return env._to_copy(data, dtype, device) + if isinstance(data, np.ndarray): + jax_res = jnp.asarray(data) + else: + jax_res = _tensor(data, dtype=dtype) + return torchax.tensor.Tensor(jax_res, env) @register_function(torch.tensor) @op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements + use_default_dtype=False +) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): - python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, - } - if not dtype: - leaves = jax.tree_util.tree_leaves(data) - if len(leaves) > 0: - dtype = python_types_to_torch_types.get(type(leaves[0])) - - return jnp.array( - data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())) + python_types_to_torch_types = { + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, + } + if not dtype: + leaves = jax.tree_util.tree_leaves(data) + if len(leaves) > 0: + dtype = python_types_to_torch_types.get(type(leaves[0])) + + return jnp.array( + data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()) + ) @register_function(torch.allclose) def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) + return jnp.allclose(input, other, rtol, atol, equal_nan) @register_function(torch.angle) def _torch_angle(input): - if input.dtype.name == "int64": - input = input.astype(jnp.dtype("float32")) - return jnp.angle(input) + if input.dtype.name == "int64": + input = input.astype(jnp.dtype("float32")) + return jnp.angle(input) @register_function(torch.argsort) def _torch_argsort(input, dim=-1, descending=False, stable=False): - expanded = False - if input.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - input = jnp.expand_dims(input, 0) - res = jnp.argsort(input, axis=dim, descending=descending, stable=stable) - if expanded: - res = res.squeeze() - return res + expanded = False + if input.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + input = jnp.expand_dims(input, 0) + res 
= jnp.argsort(input, axis=dim, descending=descending, stable=stable) + if expanded: + res = res.squeeze() + return res @register_function(torch.diag) def _diag(input, diagonal=0): - return jnp.diag(input, k=diagonal) + return jnp.diag(input, k=diagonal) @register_function(torch.einsum) @register_function(torch.ops.aten.einsum) def _einsum(equation, *operands): - - def get_params(*a): - inner_list = a[0] - if not isinstance(inner_list, jax.Array): - if len(inner_list) == 1: - A = inner_list - return A - elif len(inner_list) == 2: - A, B = inner_list - return A, B - return operands - - assert isinstance(equation, str), "Only accept str equation" - filtered_operands = get_params(*operands) - return jnp.einsum(equation, *filtered_operands) + def get_params(*a): + inner_list = a[0] + if not isinstance(inner_list, jax.Array): + if len(inner_list) == 1: + A = inner_list + return A + elif len(inner_list) == 2: + A, B = inner_list + return A, B + return operands + + assert isinstance(equation, str), "Only accept str equation" + filtered_operands = get_params(*operands) + return jnp.einsum(equation, *filtered_operands) def _sdpa_reference( @@ -118,115 +120,121 @@ def _sdpa_reference( scale=None, enable_gqa=False, ) -> torch.Tensor: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones( - L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - if enable_gqa: - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p > 0: - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return attn_weight @ value + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones( + L, S, dtype=torch.bool, device=query.device + ).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p > 0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value from jax.sharding import PartitionSpec def _tpu_flash_attention(query, key, value, env): - fsdp_partition = PartitionSpec("fsdp") - - def wrap_flash_attention(query, key, value): - block_sizes = flash_attention.BlockSizes( - block_b=min(2, query.shape[0]), - block_q=min(512, query.shape[2]), - block_k_major=min(512, key.shape[2]), - 
block_k=min(512, key.shape[2]), - block_q_major_dkv=min(512, query.shape[2]), - block_k_major_dkv=min(512, key.shape[2]), - block_k_dkv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_k_major_dq=min(512, key.shape[2]), - block_k_dq=min(256, key.shape[2]), - block_q_dq=min(1024, query.shape[2]), - ) - return flash_attention.flash_attention( - query, key, value, causal=True, block_sizes=block_sizes) - - if env.config.shmap_flash_attention: - wrap_flash_attention = shard_map( - wrap_flash_attention, - mesh=env._mesh, - in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), - out_specs=fsdp_partition, - check_rep=False, - ) - # return flash_attn_mapped(query, key, value) - return wrap_flash_attention(query, key, value) + fsdp_partition = PartitionSpec("fsdp") + + def wrap_flash_attention(query, key, value): + block_sizes = flash_attention.BlockSizes( + block_b=min(2, query.shape[0]), + block_q=min(512, query.shape[2]), + block_k_major=min(512, key.shape[2]), + block_k=min(512, key.shape[2]), + block_q_major_dkv=min(512, query.shape[2]), + block_k_major_dkv=min(512, key.shape[2]), + block_k_dkv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_k_major_dq=min(512, key.shape[2]), + block_k_dq=min(256, key.shape[2]), + block_q_dq=min(1024, query.shape[2]), + ) + return flash_attention.flash_attention( + query, key, value, causal=True, block_sizes=block_sizes + ) + + if env.config.shmap_flash_attention: + wrap_flash_attention = shard_map( + wrap_flash_attention, + mesh=env._mesh, + in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), + out_specs=fsdp_partition, + check_rep=False, + ) + # return flash_attn_mapped(query, key, value) + return wrap_flash_attention(query, key, value) @register_function(torch.nn.functional.pad) def pad(tensor, pad, mode="constant", value=None): - # For padding modes that have different names between Torch and NumPy, this - # dict provides a Torch-to-NumPy translation. Any string not in this dict will - # be passed through as-is. - MODE_NAME_TRANSLATION = { - "circular": "wrap", - "replicate": "edge", - } - - numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) - - num_prefix_dims = tensor.ndim - len(pad) // 2 - - numpy_pad_width = [(0, 0)] * num_prefix_dims - nd_slice = [slice(None)] * num_prefix_dims - - for i in range(len(pad) - 2, -1, -2): - pad_start, pad_end = pad[i:i + 2] - slice_start, slice_end = None, None - - if pad_start < 0: - slice_start = -pad_start - pad_start = 0 - - if pad_end < 0: - slice_end = pad_end - pad_end = 0 - - numpy_pad_width.append((pad_start, pad_end)) - nd_slice.append(slice(slice_start, slice_end)) - - nd_slice = tuple(nd_slice) - - # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg, - # even if the value we pass in is `None`. (It treats `None` as `nan`.) - kwargs = dict() - if mode == "constant" and value is not None: - kwargs["constant_values"] = value - - # The "replicate" mode pads first and then slices, whereas the "circular" mode - # slices first and then pads. The latter approach deals with smaller tensors, - # so we default to that option in modes where the order of operations doesn't - # affect the result. - if mode == "replicate": - return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice] - else: - return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs) + # For padding modes that have different names between Torch and NumPy, this + # dict provides a Torch-to-NumPy translation. 
Any string not in this dict will + # be passed through as-is. + MODE_NAME_TRANSLATION = { + "circular": "wrap", + "replicate": "edge", + } + + numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) + + num_prefix_dims = tensor.ndim - len(pad) // 2 + + numpy_pad_width = [(0, 0)] * num_prefix_dims + nd_slice = [slice(None)] * num_prefix_dims + + for i in range(len(pad) - 2, -1, -2): + pad_start, pad_end = pad[i : i + 2] + slice_start, slice_end = None, None + + if pad_start < 0: + slice_start = -pad_start + pad_start = 0 + + if pad_end < 0: + slice_end = pad_end + pad_end = 0 + + numpy_pad_width.append((pad_start, pad_end)) + nd_slice.append(slice(slice_start, slice_end)) + + nd_slice = tuple(nd_slice) + + # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg, + # even if the value we pass in is `None`. (It treats `None` as `nan`.) + kwargs = dict() + if mode == "constant" and value is not None: + kwargs["constant_values"] = value + + # The "replicate" mode pads first and then slices, whereas the "circular" mode + # slices first and then pads. The latter approach deals with smaller tensors, + # so we default to that option in modes where the order of operations doesn't + # affect the result. + if mode == "replicate": + return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[ + nd_slice + ] + else: + return jnp.pad( + tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs + ) @register_function( @@ -237,7 +245,8 @@ def pad(tensor, pad, mode="constant", value=None): @register_function( torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, - needs_env=True) + needs_env=True, +) def scaled_dot_product_attention( query, key, @@ -249,96 +258,98 @@ def scaled_dot_product_attention( enable_gqa=False, env=None, ) -> torch.Tensor: + if env.config.use_tpu_flash_attention: + jquery, jkey, jvalue = env.t2j_iso((query, key, value)) + res = _tpu_flash_attention(jquery, jkey, jvalue, env) + return env.j2t_iso(res) - if env.config.use_tpu_flash_attention: - jquery, jkey, jvalue = env.t2j_iso((query, key, value)) - res = _tpu_flash_attention(jquery, jkey, jvalue, env) - return env.j2t_iso(res) - - return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, - scale, enable_gqa) + return _sdpa_reference( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa + ) @register_function( - torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True) + torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True +) def getitem(self, indexes): + if isinstance(indexes, list) and isinstance(indexes[0], int): + # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) + indexes = (indexes,) + elif isinstance(indexes, list): + indexes = tuple(indexes) - if isinstance(indexes, list) and isinstance(indexes[0], int): - # list of int, i.e. 
x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) - indexes = (indexes,) - elif isinstance(indexes, list): - indexes = tuple(indexes) + def is_narrow_slicing(): + tensor_free = not pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), + indexes, + ) + list_free = not isinstance(indexes, tuple) or all([ + False if isinstance(x, list) else True for x in indexes + ]) + return tensor_free and list_free - def is_narrow_slicing(): - tensor_free = not pytree.tree_any( - lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), - indexes) - list_free = not isinstance(indexes, tuple) or all( - [False if isinstance(x, list) else True for x in indexes]) - return tensor_free and list_free + if is_narrow_slicing(): + return View(self, view_info=NarrowInfo(indexes), env=self._env) - if is_narrow_slicing(): - return View(self, view_info=NarrowInfo(indexes), env=self._env) - - indexes = self._env.t2j_iso(indexes) - return torchax.tensor.Tensor(self._elem[indexes], self._env) + indexes = self._env.t2j_iso(indexes) + return torchax.tensor.Tensor(self._elem[indexes], self._env) @register_function(torch.corrcoef) def _corrcoef(x): - if x.dtype.name == "int64": - return jnp.corrcoef(x).astype(jnp.float32) - return jnp.corrcoef(x) + if x.dtype.name == "int64": + return jnp.corrcoef(x).astype(jnp.float32) + return jnp.corrcoef(x) @register_function(torch.sparse.mm, is_jax_function=False) def _sparse_mm(mat1, mat2, reduce="sum"): - return torch.mm(mat1, mat2) + return torch.mm(mat1, mat2) @register_function(torch.isclose) def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.isclose(input, other, rtol, atol, equal_nan) + return jnp.isclose(input, other, rtol, atol, equal_nan) @register_function(torch.linalg.det) def linalg_det(input): - return jnp.linalg.det(input) + return jnp.linalg.det(input) @register_function(torch.ones) def _ones(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._ones(size, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jaten._ones(size, dtype=dtype) @register_function(torch.zeros, is_jax_function=True) def _zeros(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._zeros(size, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jaten._zeros(size, dtype=dtype) @register_function(torch.eye) @op_base.convert_dtype() def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): - return jnp.eye(n, m, dtype=dtype) + return jnp.eye(n, m, dtype=dtype) @register_function(torch.full) @op_base.convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) @register_function(torch.empty) @op_base.convert_dtype() def empty(*size: Sequence[int], dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jnp.empty(size, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jnp.empty(size, dtype=dtype) @register_function(torch.arange, is_jax_function=False) @@ -353,12 +364,12 @@ def arange( requires_grad=False, pin_memory=None, ): - if end is 
None: - end = start - start = 0 - if step is None: - step = 1 - return torch.ops.aten.arange(start, end, step, dtype=dtype) + if end is None: + end = start + start = 0 + if step is None: + step = 1 + return torch.ops.aten.arange(start, end, step, dtype=dtype) @register_function(torch.empty_strided, is_jax_function=False) @@ -372,19 +383,19 @@ def empty_strided( requires_grad=False, pin_memory=False, ): - return empty(size, dtype=dtype) + return empty(size, dtype=dtype) @register_function(torch.unravel_index) def unravel_index(indices, shape): - return jnp.unravel_index(indices, shape) + return jnp.unravel_index(indices, shape) @register_function(torch.rand, is_jax_function=False) def rand(*size, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return torch.ops.aten.rand(size, **kwargs) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.rand(size, **kwargs) @register_function(torch.randn, is_jax_function=False) @@ -398,120 +409,120 @@ def randn( requires_grad=False, pin_memory=False, ): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return torch.ops.aten.randn(size, generator=generator, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.randn(size, generator=generator, dtype=dtype) @register_function(torch.randint, is_jax_function=False) def randint(*args, **kwargs): - return torch.ops.aten.randint(*args, **kwargs) + return torch.ops.aten.randint(*args, **kwargs) @register_function(torch.logdet) def logdet(input): - _, logabsdet = jaten._aten__linalg_slogdet(input) - return logabsdet + _, logabsdet = jaten._aten__linalg_slogdet(input) + return logabsdet @register_function(torch.linalg.slogdet) def linalg_slogdet(input): - sign, logabsdet = jaten._aten__linalg_slogdet(input) - return torch.return_types.slogdet((sign, logabsdet)) + sign, logabsdet = jaten._aten__linalg_slogdet(input) + return torch.return_types.slogdet((sign, logabsdet)) @register_function(torch.tensor_split) def tensor_split(input, indices_or_sections, dim=0): - return jnp.array_split(input, indices_or_sections, axis=dim) + return jnp.array_split(input, indices_or_sections, axis=dim) @register_function(torch.linalg.solve) def linalg_solve(a, b): - res, _ = jaten._aten__linalg_solve_ex(a, b) - return res + res, _ = jaten._aten__linalg_solve_ex(a, b) + return res @register_function(torch.linalg.solve_ex) def linalg_solve_ex(a, b): - res, info = jaten._aten__linalg_solve_ex(a, b) - return res, info + res, info = jaten._aten__linalg_solve_ex(a, b) + return res, info @register_function(torch.linalg.svd) def linalg_svd(a, full_matrices=True): - return jaten._aten__linalg_svd(a, full_matrices=full_matrices) + return jaten._aten__linalg_svd(a, full_matrices=full_matrices) @register_function(torch.linalg.matrix_power) def matrix_power(A, n, *, out=None): - return jnp.linalg.matrix_power(A, n) + return jnp.linalg.matrix_power(A, n) @register_function(torch.svd) def svd(a, some=True, compute_uv=True): - if not compute_uv: - S = jaten._aten__linalg_svd(a, full_matrices=False)[1] - U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype) - V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype) - return U, S, V - U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some) - return U, S, jnp.matrix_transpose(V) + if not compute_uv: + S = jaten._aten__linalg_svd(a, full_matrices=False)[1] + U = 
jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype) + V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype) + return U, S, V + U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some) + return U, S, jnp.matrix_transpose(V) @register_function(torch.cdist) def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): - return jaten._aten_cdist(x1, x2, p, compute_mode) + return jaten._aten_cdist(x1, x2, p, compute_mode) @register_function(torch.lu) def lu(A, **kwargs): - lu, pivots, _ = jax.lax.linalg.lu(A) - # JAX pivots are offset by 1 compared to torch - _pivots = pivots + 1 - info_shape = pivots.shape[:-1] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - if kwargs["get_infos"] == True: - return lu, _pivots, info - return lu, _pivots + lu, pivots, _ = jax.lax.linalg.lu(A) + # JAX pivots are offset by 1 compared to torch + _pivots = pivots + 1 + info_shape = pivots.shape[:-1] + info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) + if kwargs["get_infos"] == True: + return lu, _pivots, info + return lu, _pivots @register_function(torch.lu_solve) def lu_solve(b, LU_data, LU_pivots, **kwargs): - # JAX pivots are offset by 1 compared to torch - _pivots = LU_pivots - 1 - x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b) - return x + # JAX pivots are offset by 1 compared to torch + _pivots = LU_pivots - 1 + x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b) + return x @register_function(torch.linalg.tensorsolve) def linalg_tensorsolve(A, b, dims=None): - # examples: - # A = torch.randn(2, 3, 6), b = torch.randn(3, 2) - # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6]) - # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3]) - # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6]) - # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3]) - # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6]) - # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6]) - - # torch allows b to be shaped differently. - # especially when axes are moved using dims. - # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3) - # So we are handling the moveaxis and forcing b's shape to match what jax expects - if dims is not None: - A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) - dims = None - if A.shape[:b.ndim] != b.shape: - b = jnp.reshape(b, A.shape[:b.ndim]) - return jnp.linalg.tensorsolve(A, b, axes=dims) + # examples: + # A = torch.randn(2, 3, 6), b = torch.randn(3, 2) + # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6]) + # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3]) + # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6]) + # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3]) + # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6]) + # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6]) + + # torch allows b to be shaped differently. + # especially when axes are moved using dims. + # ValueError: After moving axes to end, leading shape of a must match shape of b. 
got a.shape=(3, 2, 6), b.shape=(2, 3) + # So we are handling the moveaxis and forcing b's shape to match what jax expects + if dims is not None: + A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) + dims = None + if A.shape[: b.ndim] != b.shape: + b = jnp.reshape(b, A.shape[: b.ndim]) + return jnp.linalg.tensorsolve(A, b, axes=dims) @register_function(torch.nn.functional.linear) def functional_linear(self, weights, bias=None): - res = jnp.einsum("...a,ba->...b", self, weights) - if bias is not None: - res += bias - return res + res = jnp.einsum("...a,ba->...b", self, weights) + if bias is not None: + res += bias + return res @register_function(torch.nn.functional.interpolate) @@ -524,45 +535,46 @@ def functional_interpolate( recompute_scale_factor: bool, antialias: bool, ): - supported_methods = ( - "nearest", - "linear", - "bilinear", - "trilinear", - "cubic", - "bicubic", - "tricubic", - "lanczos3", - "lanczos5", - ) - is_jax_supported = mode in supported_methods - if not is_jax_supported: - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" - ) - # None check - antialias = antialias or False - align_corners = align_corners or False - - if mode in ('cubic', 'bicubic', - 'tricubic') and not antialias and size is not None: - return jimage.interpolate_bicubic_no_aa( - input, - size[0], - size[1], - align_corners, - ) - else: - # fallback - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + supported_methods = ( + "nearest", + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", ) + is_jax_supported = mode in supported_methods + if not is_jax_supported: + raise torchax.tensor.OperatorNotFound( + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + ) + # None check + antialias = antialias or False + align_corners = align_corners or False + + if ( + mode in ("cubic", "bicubic", "tricubic") + and not antialias + and size is not None + ): + return jimage.interpolate_bicubic_no_aa( + input, + size[0], + size[1], + align_corners, + ) + else: + # fallback + raise torchax.tensor.OperatorNotFound( + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + ) @register_function(torch.Tensor.repeat_interleave) -def torch_Tensor_repeat_interleave(self, - repeats, - dim=None, - *, - output_size=None): - return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) +def torch_Tensor_repeat_interleave( + self, repeats, dim=None, *, output_size=None +): + return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) diff --git a/torchax/torchax/ops/jtorchvision_nms.py b/torchax/torchax/ops/jtorchvision_nms.py index 57832b560b03..401b248af345 100644 --- a/torchax/torchax/ops/jtorchvision_nms.py +++ b/torchax/torchax/ops/jtorchvision_nms.py @@ -14,221 +14,263 @@ def _bbox_overlap(boxes, gt_boxes): - """Find Bounding box overlap. - - Args: - boxes: first set of bounding boxes - gt_boxes: second set of boxes to compute IOU - - Returns: - iou: Intersection over union matrix of all input bounding boxes - """ - bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( - ary=boxes, indices_or_sections=4, axis=2) - gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( - ary=gt_boxes, indices_or_sections=4, axis=2) - - # Calculates the intersection area. 
- i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) - i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) - i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) - i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) - i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0) - - # Calculates the union area. - bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) - gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) - # Adds a small epsilon to avoid divide-by-zero. - u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 - - # Calculates IoU. - iou = i_area / u_area - - return iou + """Find Bounding box overlap. + + Args: + boxes: first set of bounding boxes + gt_boxes: second set of boxes to compute IOU + + Returns: + iou: Intersection over union matrix of all input bounding boxes + """ + bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( + ary=boxes, indices_or_sections=4, axis=2 + ) + gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( + ary=gt_boxes, indices_or_sections=4, axis=2 + ) + + # Calculates the intersection area. + i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) + i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) + i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) + i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) + i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum( + (i_ymax - i_ymin), 0 + ) + + # Calculates the union area. + bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) + gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) + # Adds a small epsilon to avoid divide-by-zero. + u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 + + # Calculates IoU. 
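    # i.e. IoU = intersection / (area_a + area_b - intersection); the 1e-8
    # added to u_area above keeps the division well-defined when both boxes
    # have zero area.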
+ iou = i_area / u_area + + return iou def _self_suppression(in_args): - iou, _, iou_sum = in_args - batch_size = iou.shape[0] - can_suppress_others = jnp.reshape( - jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype) - iou_suppressed = jnp.reshape( - (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype( - iou.dtype), [batch_size, -1, 1]) * iou - iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) - return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new + iou, _, iou_sum = in_args + batch_size = iou.shape[0] + can_suppress_others = jnp.reshape( + jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1] + ).astype(iou.dtype) + iou_suppressed = ( + jnp.reshape( + (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype), + [batch_size, -1, 1], + ) + * iou + ) + iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) + return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new def _cross_suppression(in_args): - boxes, box_slice, iou_threshold, inner_idx = in_args - batch_size = boxes.shape[0] - new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) - iou = _bbox_overlap(new_slice, box_slice) - ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype( - box_slice.dtype), 2) * box_slice - return boxes, ret_slice, iou_threshold, inner_idx + 1 + boxes, box_slice, iou_threshold, inner_idx = in_args + batch_size = boxes.shape[0] + new_slice = lax.dynamic_slice( + boxes, + [0, inner_idx * _NMS_TILE_SIZE, 0], + [batch_size, _NMS_TILE_SIZE, 4], + ) + iou = _bbox_overlap(new_slice, box_slice) + ret_slice = ( + jnp.expand_dims( + (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2 + ) + * box_slice + ) + return boxes, ret_slice, iou_threshold, inner_idx + 1 def _suppression_loop_body(in_args): - """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). - - Args: - in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx - - Returns: - boxes: updated boxes. - iou_threshold: pass down iou_threshold to the next iteration. - output_size: the updated output_size. - idx: the updated induction variable. - """ - boxes, iou_threshold, output_size, idx = in_args - num_tiles = boxes.shape[1] // _NMS_TILE_SIZE - batch_size = boxes.shape[0] - - # Iterates over tiles that can possibly suppress the current tile. - box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) - - def _loop_cond(in_args): - _, _, _, inner_idx = in_args - return inner_idx < idx - - _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression, - (boxes, box_slice, iou_threshold, 0)) - - # Iterates over the current tile to compute self-suppression. - iou = _bbox_overlap(box_slice, box_slice) - mask = jnp.expand_dims( - jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) - > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0) - iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) - - def _loop_cond2(in_args): - _, loop_condition, _ = in_args - return loop_condition - - suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression, - (iou, True, jnp.sum(iou, [1, 2]))) - suppressed_box = jnp.sum(suppressed_iou, 1) > 0 - box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2) - - # Uses box_slice to update the input boxes. 
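    # `mask` below is a one-hot selector over tiles: tile `idx` receives the
    # freshly suppressed box_slice while every other tile keeps its original
    # boxes, and the result is flattened back to [batch_size, num_boxes, 4].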
- mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles), - idx)).astype(boxes.dtype), [1, -1, 1, 1]) - boxes = jnp.tile(jnp.expand_dims( - box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape( - boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask) - boxes = jnp.reshape(boxes, [batch_size, -1, 4]) - - # Updates output_size. - output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) - return boxes, iou_threshold, output_size, idx + 1 + """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). + + Args: + in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx + + Returns: + boxes: updated boxes. + iou_threshold: pass down iou_threshold to the next iteration. + output_size: the updated output_size. + idx: the updated induction variable. + """ + boxes, iou_threshold, output_size, idx = in_args + num_tiles = boxes.shape[1] // _NMS_TILE_SIZE + batch_size = boxes.shape[0] + + # Iterates over tiles that can possibly suppress the current tile. + box_slice = lax.dynamic_slice( + boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] + ) + + def _loop_cond(in_args): + _, _, _, inner_idx = in_args + return inner_idx < idx + + _, box_slice, _, _ = lax.while_loop( + _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0) + ) + + # Iterates over the current tile to compute self-suppression. + iou = _bbox_overlap(box_slice, box_slice) + mask = jnp.expand_dims( + jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) + > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), + 0, + ) + iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) + + def _loop_cond2(in_args): + _, loop_condition, _ = in_args + return loop_condition + + suppressed_iou, _, _ = lax.while_loop( + _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2])) + ) + suppressed_box = jnp.sum(suppressed_iou, 1) > 0 + box_slice *= jnp.expand_dims( + 1.0 - suppressed_box.astype(box_slice.dtype), 2 + ) + + # Uses box_slice to update the input boxes. + mask = jnp.reshape( + (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), + [1, -1, 1, 1], + ) + boxes = jnp.tile( + jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + jnp.reshape( + boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4] + ) * (1 - mask) + boxes = jnp.reshape(boxes, [batch_size, -1, 4]) + + # Updates output_size. + output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) + return boxes, iou_threshold, output_size, idx + 1 def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): - """A wrapper that handles non-maximum suppression. - - Assumption: - * The boxes are sorted by scores unless the box is a dot (all coordinates - are zero). - * Boxes with higher scores can be used to suppress boxes with lower scores. - - The overal design of the algorithm is to handle boxes tile-by-tile: - - boxes = boxes.pad_to_multiply_of(tile_size) - num_tiles = len(boxes) // tile_size - output_boxes = [] - for i in range(num_tiles): - box_tile = boxes[i*tile_size : (i+1)*tile_size] - for j in range(i - 1): - suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] - iou = _bbox_overlap(box_tile, suppressing_tile) - # if the box is suppressed in iou, clear it to a dot - box_tile *= _update_boxes(iou) - # Iteratively handle the diagnal tile. 
- iou = _box_overlap(box_tile, box_tile) - iou_changed = True - while iou_changed: - # boxes that are not suppressed by anything else - suppressing_boxes = _get_suppressing_boxes(iou) - # boxes that are suppressed by suppressing_boxes - suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) - # clear iou to 0 for boxes that are suppressed, as they cannot be used - # to suppress other boxes any more - new_iou = _clear_iou(iou, suppressed_boxes) - iou_changed = (new_iou != iou) - iou = new_iou - # remaining boxes that can still suppress others, are selected boxes. - output_boxes.append(_get_suppressing_boxes(iou)) - if len(output_boxes) >= max_output_size: - break - - Args: - scores: a tensor with a shape of [batch_size, anchors]. - boxes: a tensor with a shape of [batch_size, anchors, 4]. - max_output_size: a scalar integer `Tensor` representing the maximum number - of boxes to be selected by non max suppression. - iou_threshold: a float representing the threshold for deciding whether boxes - overlap too much with respect to IOU. - Returns: - nms_scores: a tensor with a shape of [batch_size, anchors]. It has same - dtype as input scores. - nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has - same dtype as input boxes. - """ - batch_size = boxes.shape[0] - num_boxes = boxes.shape[1] - pad = int(jnp.ceil( - float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes - boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) - scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) - num_boxes += pad - - def _loop_cond(in_args): - unused_boxes, unused_threshold, output_size, idx = in_args - return jnp.logical_and( - jnp.min(output_size) < max_output_size, idx - < num_boxes // _NMS_TILE_SIZE) - - selected_boxes, _, output_size, _ = lax.while_loop( - _loop_cond, _suppression_loop_body, - (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0)) - idx = num_boxes - lax.top_k( - jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) * - jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), - max_output_size)[0].astype(jnp.int32) - idx = jnp.minimum(idx, num_boxes - 1) - idx = jnp.reshape( - idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]) - - return idx - boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx], - [batch_size, max_output_size, 4]) - boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) - < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype) - scores = jnp.reshape( - jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size]) - scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1]) - < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype) - return scores, boxes + """A wrapper that handles non-maximum suppression. + + Assumption: + * The boxes are sorted by scores unless the box is a dot (all coordinates + are zero). + * Boxes with higher scores can be used to suppress boxes with lower scores. + + The overal design of the algorithm is to handle boxes tile-by-tile: + + boxes = boxes.pad_to_multiply_of(tile_size) + num_tiles = len(boxes) // tile_size + output_boxes = [] + for i in range(num_tiles): + box_tile = boxes[i*tile_size : (i+1)*tile_size] + for j in range(i - 1): + suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] + iou = _bbox_overlap(box_tile, suppressing_tile) + # if the box is suppressed in iou, clear it to a dot + box_tile *= _update_boxes(iou) + # Iteratively handle the diagnal tile. 
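      # (fixed point: repeat until the set of suppressed boxes stops changing;
      #  this is what _self_suppression implements with lax.while_loop)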
+ iou = _box_overlap(box_tile, box_tile) + iou_changed = True + while iou_changed: + # boxes that are not suppressed by anything else + suppressing_boxes = _get_suppressing_boxes(iou) + # boxes that are suppressed by suppressing_boxes + suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) + # clear iou to 0 for boxes that are suppressed, as they cannot be used + # to suppress other boxes any more + new_iou = _clear_iou(iou, suppressed_boxes) + iou_changed = (new_iou != iou) + iou = new_iou + # remaining boxes that can still suppress others, are selected boxes. + output_boxes.append(_get_suppressing_boxes(iou)) + if len(output_boxes) >= max_output_size: + break + + Args: + scores: a tensor with a shape of [batch_size, anchors]. + boxes: a tensor with a shape of [batch_size, anchors, 4]. + max_output_size: a scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + Returns: + nms_scores: a tensor with a shape of [batch_size, anchors]. It has same + dtype as input scores. + nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has + same dtype as input boxes. + """ + batch_size = boxes.shape[0] + num_boxes = boxes.shape[1] + pad = ( + int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE + - num_boxes + ) + boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) + scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) + num_boxes += pad + + def _loop_cond(in_args): + unused_boxes, unused_threshold, output_size, idx = in_args + return jnp.logical_and( + jnp.min(output_size) < max_output_size, + idx < num_boxes // _NMS_TILE_SIZE, + ) + + selected_boxes, _, output_size, _ = lax.while_loop( + _loop_cond, + _suppression_loop_body, + (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0), + ) + idx = num_boxes - lax.top_k( + jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) + * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), + max_output_size, + )[0].astype(jnp.int32) + idx = jnp.minimum(idx, num_boxes - 1) + idx = jnp.reshape( + idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1] + ) + + return idx + boxes = jnp.reshape( + (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4] + ) + boxes = boxes * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) + < jnp.reshape(output_size, [-1, 1, 1]) + ).astype(boxes.dtype) + scores = jnp.reshape( + jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size] + ) + scores = scores * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1]) + < jnp.reshape(output_size, [-1, 1]) + ).astype(scores.dtype) + return scores, boxes # registry: def nms(boxes, scores, iou_threshold): - max_output_size = boxes.shape[0] - boxes = boxes.reshape((1, *boxes.shape)) - scores = scores.reshape((1, *scores.shape)) - res = non_max_suppression_padded(scores, boxes, max_output_size, - iou_threshold) - return res + max_output_size = boxes.shape[0] + boxes = boxes.reshape((1, *boxes.shape)) + scores = scores.reshape((1, *scores.shape)) + res = non_max_suppression_padded( + scores, boxes, max_output_size, iou_threshold + ) + return res try: - import torch - import torchvision - ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) + import torch + import torchvision + + ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) except Exception: - pass + pass diff --git 
a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index 409a6d8350be..d363e19f1a09 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -8,132 +8,114 @@ def t2j(t, use_dlpack=True): - is_bool = False - if t.dtype == torch.bool: - is_bool = True - t = t.to(torch.int8) - - t = t.to_dense() - - if not t.is_contiguous(): - t = t.contiguous() - - res = None - if use_dlpack: - try: - res = jaxdl.from_dlpack(t) - except Exception: - pass - - if res is None: - # https://github.com/google/jax/issues/7657 - # https://github.com/google/jax/issues/17784 - if t.dtype == torch.bfloat16: - nparray = (t.cpu().detach().to(torch.float32).numpy() - ) # numpy don't support bfloat16 - else: - nparray = t.cpu().detach().numpy() - res = jnp.asarray(nparray) - if t.dtype == torch.bfloat16: - res = res.astype(jnp.bfloat16) - - if is_bool: - res = res.astype(jnp.bool_) - return res + is_bool = False + if t.dtype == torch.bool: + is_bool = True + t = t.to(torch.int8) + t = t.to_dense() + + if not t.is_contiguous(): + t = t.contiguous() -def j2t(x, use_dlpack=True): - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): res = None if use_dlpack: - try: - dl = jaxdl.to_dlpack(x) - res = torchdl.from_dlpack(dl) - except Exception: - res = None + try: + res = jaxdl.from_dlpack(t) + except Exception: + pass - orig_dtype = None if res is None: - orig_dtype = None - if x.dtype == jnp.bfloat16.dtype: - orig_dtype = x.dtype - x = x.astype(jnp.float32.dtype) - res = torch.from_numpy(numpy.asarray(x)) + # https://github.com/google/jax/issues/7657 + # https://github.com/google/jax/issues/17784 + if t.dtype == torch.bfloat16: + nparray = ( + t.cpu().detach().to(torch.float32).numpy() + ) # numpy don't support bfloat16 + else: + nparray = t.cpu().detach().numpy() + res = jnp.asarray(nparray) + if t.dtype == torch.bfloat16: + res = res.astype(jnp.bfloat16) + + if is_bool: + res = res.astype(jnp.bool_) + return res - if x.dtype == jnp.bool_: - res = res.to(torch.bool) - if orig_dtype is not None: - res = res.to(j2t_dtype(orig_dtype)) - return res +def j2t(x, use_dlpack=True): + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + res = None + if use_dlpack: + try: + dl = jaxdl.to_dlpack(x) + res = torchdl.from_dlpack(dl) + except Exception: + res = None + + orig_dtype = None + if res is None: + orig_dtype = None + if x.dtype == jnp.bfloat16.dtype: + orig_dtype = x.dtype + x = x.astype(jnp.float32.dtype) + res = torch.from_numpy(numpy.asarray(x)) + + if x.dtype == jnp.bool_: + res = res.to(torch.bool) + + if orig_dtype is not None: + res = res.to(j2t_dtype(orig_dtype)) + return res TORCH_DTYPE_TO_JAX = { # NO_MAPPING : jnp.float0.dtype (signless scalar int), - torch.bool: - jnp.bool_.dtype, + torch.bool: jnp.bool_.dtype, # NO_MAPPING : jnp.int4.dtype, - torch.int8: - jnp.int8.dtype, - torch.int16: - jnp.int16.dtype, - torch.int32: - jnp.int32.dtype, - torch.int64: - jnp.int64.dtype, - torch.long: - jnp.int64.dtype, + torch.int8: jnp.int8.dtype, + torch.int16: jnp.int16.dtype, + torch.int32: jnp.int32.dtype, + torch.int64: jnp.int64.dtype, + torch.long: jnp.int64.dtype, # NO_MAPPING : jnp.uint4 - torch.uint8: - jnp.uint8.dtype, - torch.uint16: - jnp.uint16.dtype, - torch.uint32: - jnp.uint32.dtype, - torch.uint64: - jnp.uint64.dtype, + torch.uint8: jnp.uint8.dtype, + torch.uint16: jnp.uint16.dtype, + torch.uint32: jnp.uint32.dtype, + torch.uint64: jnp.uint64.dtype, # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, - torch.float8_e4m3fn: - jnp.float8_e4m3fn.dtype, + 
torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype, # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, - torch.float8_e5m2: - jnp.float8_e5m2.dtype, + torch.float8_e5m2: jnp.float8_e5m2.dtype, # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, - torch.bfloat16: - jnp.bfloat16.dtype, - torch.half: - jnp.float16.dtype, - torch.float16: - jnp.float16.dtype, - torch.float32: - jnp.float32.dtype, - torch.float64: - jnp.float64.dtype, - torch.double: - jnp.double.dtype, - torch.complex64: - jnp.complex64.dtype, - torch.complex128: - jnp.complex128.dtype, - None: - None, + torch.bfloat16: jnp.bfloat16.dtype, + torch.half: jnp.float16.dtype, + torch.float16: jnp.float16.dtype, + torch.float32: jnp.float32.dtype, + torch.float64: jnp.float64.dtype, + torch.double: jnp.double.dtype, + torch.complex64: jnp.complex64.dtype, + torch.complex128: jnp.complex128.dtype, + None: None, } JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} # Add imprecise mappings for some JAX dtypes which don't have torch analogues -JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8 -JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8 +JAX_DTYPE_TO_TORCH[jnp.dtype("int4")] = torch.int8 +JAX_DTYPE_TO_TORCH[jnp.dtype("uint4")] = torch.uint8 def t2j_dtype(dtype): - if dtype not in TORCH_DTYPE_TO_JAX: - raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to jax type,') - return TORCH_DTYPE_TO_JAX[dtype] + if dtype not in TORCH_DTYPE_TO_JAX: + raise RuntimeError( + f"Attempting to convert unknown type: {dtype} to jax type," + ) + return TORCH_DTYPE_TO_JAX[dtype] def j2t_dtype(dtype): - if dtype not in JAX_DTYPE_TO_TORCH: - raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to torch type,') - return JAX_DTYPE_TO_TORCH[dtype] + if dtype not in JAX_DTYPE_TO_TORCH: + raise RuntimeError( + f"Attempting to convert unknown type: {dtype} to torch type," + ) + return JAX_DTYPE_TO_TORCH[dtype] diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py index d69e85ae50a6..06cdd4b4806a 100644 --- a/torchax/torchax/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -12,120 +12,125 @@ class InplaceOp: - - def __init__(self, - functional_op, - replace=False, - position_to_mutate=0, - is_jax_func=False): - self.functional = functional_op - self.replace = replace - self.position_to_mutate = position_to_mutate - self.is_jax_func = is_jax_func - - def __call__(self, *args, **kwargs): - to_mutate = args[self.position_to_mutate] - view_value = to_mutate - if isinstance(to_mutate, View): - view_value = to_mutate.torch() - # Convert the target View to a Tensor, and - # leave the rest args as is. If other args are - # also View, they will be converted to tensors - # in the self.functional dispatch. 
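      # In short: run the functional op (round-tripping through jax.Array when
      # is_jax_func is set), then write the result back -- View targets get
      # .update(), plain tensors either swap out ._elem (replace=True) or
      # fall back to .copy_().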
- env = view_value._env - if self.is_jax_func: - view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs)) - new_value_jax = self.functional(view_value, *args[1:], **kwargs) - new_value = env.j2t_iso(new_value_jax) - else: - new_value = self.functional(view_value, *args[1:], **kwargs) - - if isinstance(to_mutate, View): - to_mutate.update(new_value) - else: - if self.replace: - to_mutate._elem = new_value._elem - else: - to_mutate.copy_(new_value) - return to_mutate + def __init__( + self, + functional_op, + replace=False, + position_to_mutate=0, + is_jax_func=False, + ): + self.functional = functional_op + self.replace = replace + self.position_to_mutate = position_to_mutate + self.is_jax_func = is_jax_func + + def __call__(self, *args, **kwargs): + to_mutate = args[self.position_to_mutate] + view_value = to_mutate + if isinstance(to_mutate, View): + view_value = to_mutate.torch() + # Convert the target View to a Tensor, and + # leave the rest args as is. If other args are + # also View, they will be converted to tensors + # in the self.functional dispatch. + env = view_value._env + if self.is_jax_func: + view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs)) + new_value_jax = self.functional(view_value, *args[1:], **kwargs) + new_value = env.j2t_iso(new_value_jax) + else: + new_value = self.functional(view_value, *args[1:], **kwargs) + + if isinstance(to_mutate, View): + to_mutate.update(new_value) + else: + if self.replace: + to_mutate._elem = new_value._elem + else: + to_mutate.copy_(new_value) + return to_mutate class OutVariant: - - def __call__(self, *args, **kwargs): - to_mutate = kwargs['out'] - del kwargs['out'] - to_mutate._elem = self.functional(*args, **kwargs)._elem - return to_mutate + def __call__(self, *args, **kwargs): + to_mutate = kwargs["out"] + del kwargs["out"] + to_mutate._elem = self.functional(*args, **kwargs)._elem + return to_mutate -P = ParamSpec('P') +P = ParamSpec("P") def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. + """Converts `dtype` kwarg of function from torch to JAX. - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. - Returns: - A decorator that wraps a JAX implementation of a torch function. - """ + Returns: + A decorator that wraps a JAX implementation of a torch function. 
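    Example (as used by `_eye` in jtorch.py elsewhere in this patch):

        @register_function(torch.eye)
        @op_base.convert_dtype()
        def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
            return jnp.eye(n, m, dtype=dtype)

    A torch dtype such as `torch.float64` is rewritten to the matching `jnp`
    dtype via `mappings.t2j_dtype` before `_eye` runs.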
+ """ - def decorator(func: types.TorchCallable): + def decorator(func: types.TorchCallable): + @functools.wraps(func) + def wrapper( + *args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs, + ): + if not dtype and use_default_dtype: + dtype = torch.get_default_dtype() + if isinstance(dtype, torch.dtype): + jax_dtype = mappings.t2j_dtype(dtype) + else: + jax_dtype = dtype - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - if isinstance(dtype, torch.dtype): - jax_dtype = mappings.t2j_dtype(dtype) - else: - jax_dtype = dtype + return func(*args, dtype=jax_dtype, **kwargs) - return func(*args, dtype=jax_dtype, **kwargs) + return wrapper - return wrapper + return decorator - return decorator +def maybe_convert_constant_dtype( + val: Optional[types.JaxValue], dtype: Optional[jnp.dtype] +): + """Optionally converts scalar constant's dtype using `numpy` -def maybe_convert_constant_dtype(val: Optional[types.JaxValue], - dtype: Optional[jnp.dtype]): - """Optionally converts scalar constant's dtype using `numpy` + Use in cases where you require a constant and can't handle a traced array. + """ + if val and dtype: + if isinstance(val, jax.Array): + return maybe_convert_constant_dtype(val.item(), dtype) - Use in cases where you require a constant and can't handle a traced array. - """ - if val and dtype: - if isinstance(val, jax.Array): - return maybe_convert_constant_dtype(val.item(), dtype) + return np.array(val, dtype) - return np.array(val, dtype) - - return val + return val def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]): - """If the first argument is an int array, promote it to float32.""" + """If the first argument is an int array, promote it to float32.""" - @functools.wraps(f) - def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): - if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - x = x.astype(mappings.t2j_dtype(torch.get_default_dtype())) + @functools.wraps(f) + def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): + if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: + x = x.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return f(x, *args, **kwargs) + return f(x, *args, **kwargs) - return wrapper + return wrapper -def foreach_loop(seq: jax.Array, - fn: Callable[[jax.Array, jax.Array], jax.Array], - init_val=0.0): - """Run `fn` for each element of 1D array `seq`. +def foreach_loop( + seq: jax.Array, + fn: Callable[[jax.Array, jax.Array], jax.Array], + init_val=0.0, +): + """Run `fn` for each element of 1D array `seq`. 
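    For example (a minimal sketch), a running sum over a 1-D array:

        total = foreach_loop(
            jnp.arange(5.0),
            lambda carry, x: carry + x,
            init_val=jnp.float32(0.0),
        )
        # total == 10.0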
- Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" - assert len(seq.shape) == 1 - return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), - init_val) + Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" + assert len(seq.shape) == 1 + return jax.lax.fori_loop( + 0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val + ) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py index aa0d61cbb491..4d8cb770a72c 100644 --- a/torchax/torchax/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -7,49 +7,55 @@ @dataclasses.dataclass class Operator: - torch_op: TorchCallable - func: Union[TorchCallable, JaxCallable] - is_jax_function: bool - is_user_defined: bool - needs_env: bool - is_view_op: bool + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool + is_view_op: bool all_aten_ops: Dict[TorchCallable, Operator] = {} all_torch_functions: Dict[TorchCallable, Operator] = {} -def register_torch_dispatch_op(aten_op, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): - op = Operator( - aten_op, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) - if aten_op in all_aten_ops: - logging.warning(f'Duplicate op registration for {aten_op}') - all_aten_ops[aten_op] = op - return impl_callable - - -def register_torch_function_op(torch_func, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): - op = Operator( - torch_func, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) - all_torch_functions[torch_func] = op - return impl_callable +def register_torch_dispatch_op( + aten_op, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): + op = Operator( + aten_op, + impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) + if aten_op in all_aten_ops: + logging.warning(f"Duplicate op registration for {aten_op}") + all_aten_ops[aten_op] = op + return impl_callable + + +def register_torch_function_op( + torch_func, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): + op = Operator( + torch_func, + impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) + all_torch_functions[torch_func] = op + return impl_callable diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index e3067dd3ec77..dfbc851a55ce 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -22,259 +22,276 @@ class OperatorNotFound(Exception): - pass + pass def wrap(jaxarray): - return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) + return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) def unwrap(torchtensors): - return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) + return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) @contextlib.contextmanager def log_nested(env, message): - if env.config.debug_print_each_op: - print((" " * log_nested.level) + message, file=sys.stderr) - log_nested.level += 1 - yield - log_nested.level -= 1 + 
if env.config.debug_print_each_op: + print((" " * log_nested.level) + message, file=sys.stderr) + log_nested.level += 1 + yield + log_nested.level -= 1 log_nested.level = 0 class Tensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, env): + dtype = mappings.j2t_dtype(elem.dtype) + shape = list(elem.shape) + for i, s in enumerate(shape): + if not isinstance(s, int): + shape[i] = 1 + if dtype is None: + dtype = torch.float32 + return torch.Tensor._make_wrapper_subclass( + cls, + shape, + dtype=dtype, + device="meta", + requires_grad=False, + ) - @staticmethod - def __new__(cls, elem, env): - dtype = mappings.j2t_dtype(elem.dtype) - shape = list(elem.shape) - for i, s in enumerate(shape): - if not isinstance(s, int): - shape[i] = 1 - if dtype is None: - dtype = torch.float32 - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device="meta", - requires_grad=False, - ) - - def __init__(self, elem: jax.Array, env: "Environment"): - super().__init__() - self._elem = elem - self._env = env - - def __str__(self): - return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) - - __repr__ = __str__ - - def __jax_array__(self): - return self._elem - - @property - def shape(self): - return torch.Size(self._elem.shape) + def __init__(self, elem: jax.Array, env: "Environment"): + super().__init__() + self._elem = elem + self._env = env - @property - def ndim(self): - return len(self._elem.shape) + def __str__(self): + return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) - def flatten(self, start_dim=0, end_dim=-1): - if end_dim == -1: - end_dim = self.ndim - new_shape = ( - self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]) - new_elem = jnp.reshape(self._elem, new_shape) - return Tensor(new_elem, self._env) - # return torch.reshape(self, new_shape) + __repr__ = __str__ - def __setitem__(self, key, val): - key, val = self._env.t2j_iso((key, val)) - self._elem = self._elem.at[key].set(val) + def __jax_array__(self): + return self._elem - def type_as(self, other): - self._elem = self._elem.astype(other._elem.dtype) - return self + @property + def shape(self): + return torch.Size(self._elem.shape) - __torch_function__ = torch._C._disabled_torch_function_impl + @property + def ndim(self): + return len(self._elem.shape) - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # TODO(hanq): figure out why is dispatch mode not sufficient - if func == torch.ops._c10d_functional.wait_tensor.default: - return args[0]._env.dispatch(func, types, args, kwargs) - raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' 
- 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') + def flatten(self, start_dim=0, end_dim=-1): + if end_dim == -1: + end_dim = self.ndim + new_shape = ( + self._elem.shape[:start_dim] + + (-1,) + + self._elem.shape[end_dim + 1 :] + ) + new_elem = jnp.reshape(self._elem, new_shape) + return Tensor(new_elem, self._env) + # return torch.reshape(self, new_shape) + + def __setitem__(self, key, val): + key, val = self._env.t2j_iso((key, val)) + self._elem = self._elem.at[key].set(val) + + def type_as(self, other): + self._elem = self._elem.astype(other._elem.dtype) + return self + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # TODO(hanq): figure out why is dispatch mode not sufficient + if func == torch.ops._c10d_functional.wait_tensor.default: + return args[0]._env.dispatch(func, types, args, kwargs) + raise AssertionError( + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) - def detach(self): - return Tensor(jax.lax.stop_gradient(self.jax()), self._env) + def detach(self): + return Tensor(jax.lax.stop_gradient(self.jax()), self._env) - def numpy(self) -> numpy.ndarray: - import numpy as np + def numpy(self) -> numpy.ndarray: + import numpy as np - return np.array(self._elem) + return np.array(self._elem) - def jax(self) -> jax.Array: - return self._elem + def jax(self) -> jax.Array: + return self._elem - def torch(self) -> torch.Tensor: - return self._env.j2t_copy(self.jax()) + def torch(self) -> torch.Tensor: + return self._env.j2t_copy(self.jax()) - @property - def dtype(self): - return mappings.j2t_dtype(self._elem.dtype) + @property + def dtype(self): + return mappings.j2t_dtype(self._elem.dtype) - def dim(self): - return self.ndim + def dim(self): + return self.ndim - @property - def device(self): - return torch.device("jax:0") + @property + def device(self): + return torch.device("jax:0") - @property - def jax_device(self): - return self._elem.device + @property + def jax_device(self): + return self._elem.device - @property - def data(self): - logger.warn("In-place to .data modifications still results a copy on TPU") - return self + @property + def data(self): + logger.warn( + "In-place to .data modifications still results a copy on TPU" + ) + return self - @data.setter - def data(self, other): - if isinstance(other, Tensor): - self._elem = other._elem + @data.setter + def data(self, other): + if isinstance(other, Tensor): + self._elem = other._elem - def apply_jax(self, jax_function, *args, **kwargs): - # Call a jax function on _elem - res = jax_function(self._elem, *args, **kwargs) - return self._env.j2t_iso(res) + def apply_jax(self, jax_function, *args, **kwargs): + # Call a jax function on _elem + res = jax_function(self._elem, *args, **kwargs) + return self._env.j2t_iso(res) - def apply_jax_(self, jax_function, *args, **kwargs): - self._elem = jax_function(self._elem, *args, **kwargs) - return self + def apply_jax_(self, jax_function, *args, **kwargs): + self._elem = jax_function(self._elem, *args, **kwargs) + return self - def tolist(self): - return self._elem.tolist() + def tolist(self): + return self._elem.tolist() - def shard_(self, sharding): - self.apply_jax_(jax.lax.with_sharding_constraint, sharding) + def shard_(self, sharding): + self.apply_jax_(jax.lax.with_sharding_constraint, 
sharding) def debug_accuracy(func, args, kwargs, current_output): - args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( - torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)) + args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( + torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output) + ) - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - if "device" in kwargs_torch: - kwargs_torch["device"] = "cpu" # do the torch native for comparison - expected_out = func(*args_torch, **kwargs_torch) + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + if "device" in kwargs_torch: + kwargs_torch["device"] = "cpu" # do the torch native for comparison + expected_out = func(*args_torch, **kwargs_torch) - flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) - flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) + flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) + flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) - for ex, real in zip(flattened_expected_out, flattened_current_out): - if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype: - ex = ex.to(real.dtype) - try: - if isinstance(ex, torch.Tensor) and not torch.allclose( - ex, real, atol=1e-3, equal_nan=True): - import pdb + for ex, real in zip(flattened_expected_out, flattened_current_out): + if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype: + ex = ex.to(real.dtype) + try: + if isinstance(ex, torch.Tensor) and not torch.allclose( + ex, real, atol=1e-3, equal_nan=True + ): + import pdb - pdb.set_trace() - except: - import pdb + pdb.set_trace() + except: + import pdb - pdb.set_trace() + pdb.set_trace() - return True + return True def _make_debug_msg(is_dispatch, log_args, func, args, kwargs): + def _display(a): + if isinstance(a, torch.Tensor): + return f"Tensor of {type(a)}: {a.dtype}{a.shape}" + elif isinstance(a, jax.Array): + return f"Jax Array of {type(a)}: {a.dtype}{a.shape}" + else: + return str(a) - def _display(a): - if isinstance(a, torch.Tensor): - return f"Tensor of {type(a)}: {a.dtype}{a.shape}" - elif isinstance(a, jax.Array): - return f"Jax Array of {type(a)}: {a.dtype}{a.shape}" - else: - return str(a) - - kwargs = kwargs or {} - title = "DISPATCH" if is_dispatch else "FUNCTION" - args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else "" - kwargs_msg = ("kwargs: " + - ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) - if log_args else "") - return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" + kwargs = kwargs or {} + title = "DISPATCH" if is_dispatch else "FUNCTION" + args_msg = ( + "args: " + ",".join(_display(a) for a in args) if log_args else "" + ) + kwargs_msg = ( + "kwargs: " + + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) + if log_args + else "" + ) + return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" class XLAFunctionMode(torch.overrides.TorchFunctionMode): - """Context manager that dispatches torch function calls to JAX.""" - - def __init__(self, env): - self.env = env - - def __torch_function__(self, - func, - types, - args=(), - kwargs=None) -> torch.Tensor: - message = f"FUNCTION: {_name_of_func(func)}" - if self.env.config.debug_print_each_op_operands: - message = message + "f" - message = _make_debug_msg(False, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) - with log_nested(self.env, message): - try: - return self.env.dispatch(func, types, args, kwargs) - except 
OperatorNotFound: - pass - if _name_of_func(func) in ( - "rot90"): # skip rot90 with k%4==0 due to no change - if len(args) >= 2 and type(args[1]) == int: - if (args[1]) % 4 == 0: - return args[0] - return func(*args, **(kwargs or {})) + """Context manager that dispatches torch function calls to JAX.""" + + def __init__(self, env): + self.env = env + + def __torch_function__( + self, func, types, args=(), kwargs=None + ) -> torch.Tensor: + message = f"FUNCTION: {_name_of_func(func)}" + if self.env.config.debug_print_each_op_operands: + message = message + "f" + message = _make_debug_msg( + False, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) + with log_nested(self.env, message): + try: + return self.env.dispatch(func, types, args, kwargs) + except OperatorNotFound: + pass + if _name_of_func(func) in ( + "rot90" + ): # skip rot90 with k%4==0 due to no change + if len(args) >= 2 and type(args[1]) == int: + if (args[1]) % 4 == 0: + return args[0] + return func(*args, **(kwargs or {})) class XLADispatchMode(torch_dispatch.TorchDispatchMode): - - def __init__(self, env): - self.env = env - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - message = _make_debug_msg(True, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) - with log_nested(self.env, message): - if isinstance(func, torch._ops.OpOverloadPacket): - with self: - return func(*args, **kwargs) - # Only functions under these namespaces will be intercepted - if func.namespace not in ( - "aten", - "_c10d_functional", - "torchvision", - "xla", - ): - return func(*args, **kwargs) - return self.env.dispatch(func, types, args, kwargs) + def __init__(self, env): + self.env = env + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + message = _make_debug_msg( + True, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) + with log_nested(self.env, message): + if isinstance(func, torch._ops.OpOverloadPacket): + with self: + return func(*args, **kwargs) + # Only functions under these namespaces will be intercepted + if func.namespace not in ( + "aten", + "_c10d_functional", + "torchvision", + "xla", + ): + return func(*args, **kwargs) + return self.env.dispatch(func, types, args, kwargs) def _name_of_func(func): - if hasattr(func, "name"): - return func.name() - return func.__name__ + if hasattr(func, "name"): + return func.name() + return func.__name__ # Constructors that don't take other tensor as input @@ -298,402 +315,432 @@ def _name_of_func(func): class Environment(contextlib.ContextDecorator): - """This class holds a set of configurations and "globals" needed - - for executing torch program using jax. - Things included so far: - - op registry - PRNGKey - Configs - - Also helper functions to manipulate those. 
- """ - - def __init__(self, configuration=None): - self._function_mode = XLAFunctionMode(self) - self._dispatch_mode = XLADispatchMode(self) - - # name is torch callable - self._ops = {} - self._decomps = {} - - self.load_ops() - - self._mesh = None - self.config = configuration or config.Configuration() - - self._manually_entered = False - self.enabled = False - - self._prng_key = mutable_array( - jax.random.key(torch.initial_seed() % (1 << 63))) - self.autocast_dtype = None - self._target_device = "cpu" - - @property - def target_device(self): - return self._target_device - - @target_device.setter - def target_device(self, device: str): - self._target_device = device.lower() - - def manual_seed(self, key): - self._prng_key = mutable_array(jax.random.key(key)) - - @property - def prng_key(self): - return self._prng_key[...] - - def get_as_jax_device(self, device: Any): - if device is None: - device = torch.get_default_device() + """This class holds a set of configurations and "globals" needed - if isinstance(device, torch.device): - device = str(device) + for executing torch program using jax. + Things included so far: - if not self.config.use_torch_native_for_cpu_tensor and device.startswith( - "cpu"): - return jax.devices("cpu")[0] + op registry + PRNGKey + Configs - if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): - return jax.local_devices()[0] - - if device.startswith("xla"): - return jax.local_devices()[0] + Also helper functions to manipulate those. + """ - # TODO (wen): jax is NOT a device type, - # once we can register more than one backend, revisit - if device.startswith("jax"): - match self.target_device: - case "cpu": - return jax.devices("cpu")[0] - case "tpu": - return jax.devices("tpu")[0] - case _: - raise AttributeError( - f"Cannot handle env.target_device {self.target_device}") + def __init__(self, configuration=None): + self._function_mode = XLAFunctionMode(self) + self._dispatch_mode = XLADispatchMode(self) - return None # fallback to torch + # name is torch callable + self._ops = {} + self._decomps = {} - def load_ops(self): - from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms + self.load_ops() - for k, v in itertools.chain(ops_registry.all_aten_ops.items(), - ops_registry.all_torch_functions.items()): - if v.is_jax_function: - self._ops[k] = v - else: - self._decomps[k] = v + self._mesh = None + self.config = configuration or config.Configuration() - from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION + self._manually_entered = False + self.enabled = False - for k, v in DECOMPOSITIONS.items(): - if k not in self._decomps: - self._decomps[k] = ops_registry.Operator( - k, - v, - is_jax_function=False, - is_user_defined=False, - needs_env=False, - is_view_op=k in MUTABLE_DECOMPOSITION, + self._prng_key = mutable_array( + jax.random.key(torch.initial_seed() % (1 << 63)) ) + self.autocast_dtype = None + self._target_device = "cpu" + + @property + def target_device(self): + return self._target_device + + @target_device.setter + def target_device(self, device: str): + self._target_device = device.lower() + + def manual_seed(self, key): + self._prng_key = mutable_array(jax.random.key(key)) + + @property + def prng_key(self): + return self._prng_key[...] 
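The reworked Environment boils down to three small pieces of per-environment state that the rest of this patch relies on: the context manager that enables the torch modes, the lower-cased target_device, and the mutable PRNG key. A short sketch of how a caller drives them, assuming only a CPU JAX backend and the default environment:

    import torch
    import torchax

    env = torchax.default_env()

    # the setter normalizes the platform name to lower case
    env.target_device = "CPU"
    assert env.target_device == "cpu"

    # reseed the environment's PRNG and read the current key back
    env.manual_seed(1234)
    key = env.prng_key          # a jax.random key, not a torch.Generator

    # torch ops only dispatch through JAX while the environment is active
    a = torch.randn(4, 4)
    with env:
        x = a.to("jax")         # torchax Tensor backed by a jax.Array
        y = (x @ x).sum()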
+ + def get_as_jax_device(self, device: Any): + if device is None: + device = torch.get_default_device() + + if isinstance(device, torch.device): + device = str(device) + + if ( + not self.config.use_torch_native_for_cpu_tensor + and device.startswith("cpu") + ): + return jax.devices("cpu")[0] + + if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): + return jax.local_devices()[0] + + if device.startswith("xla"): + return jax.local_devices()[0] + + # TODO (wen): jax is NOT a device type, + # once we can register more than one backend, revisit + if device.startswith("jax"): + match self.target_device: + case "cpu": + return jax.devices("cpu")[0] + case "tpu": + return jax.devices("tpu")[0] + case _: + raise AttributeError( + f"Cannot handle env.target_device {self.target_device}" + ) + + return None # fallback to torch + + def load_ops(self): + from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms + + for k, v in itertools.chain( + ops_registry.all_aten_ops.items(), + ops_registry.all_torch_functions.items(), + ): + if v.is_jax_function: + self._ops[k] = v + else: + self._decomps[k] = v + + from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION + + for k, v in DECOMPOSITIONS.items(): + if k not in self._decomps: + self._decomps[k] = ops_registry.Operator( + k, + v, + is_jax_function=False, + is_user_defined=False, + needs_env=False, + is_view_op=k in MUTABLE_DECOMPOSITION, + ) + + def _get_op_or_decomp(self, func): + def _get_from_dict(op_dict, op): + op = op_dict.get(func) + if op is None and isinstance(func, torch._ops.OpOverloadPacket): + op = op_dict.get(func.default) + if op is None and isinstance(func, torch._ops.OpOverload): + op = op_dict.get(func.overloadpacket) + return op + + op = _get_from_dict(self._ops, func) + + if op is None: + # fallback to decompose + op = _get_from_dict(self._decomps, func) + + if op is None: + raise OperatorNotFound( + f"Operator with name {_name_of_func(func)} has no lowering" + ) + + return op + + def _to_copy(self, the_tensor, new_dtype, new_device): + if isinstance(the_tensor, View): + the_tensor = the_tensor.torch() + + if isinstance(the_tensor, Tensor): + arr = the_tensor.jax() + + if new_dtype is not None and new_dtype != arr.dtype: + arr = arr.astype(mappings.t2j_dtype(new_dtype)) + + if new_device is not None: + match str(new_device).lower(): + case "cpu": + # converting to a non-jax device: let torch native handle it + torch_tensor = ( + self.j2t_copy(arr) + if isinstance(the_tensor, Tensor) + else arr + ) + with ( + mode_utils.no_dispatch(), + torch._C.DisableTorchFunction(), + ): + return torch_tensor.to(new_device) + case "jax": + # move torchax.tensor / jax tensor between devices + # I don't know ifgit this will work after the model is jitted + if self.target_device != the_tensor.jax_device.platform: + arr = jax.device_put( + the_tensor.jax(), + jax.devices(self.target_device)[0], + ) + return Tensor(arr, self) + case _: + logging.error( + f"torchax.Tenosr cannot handle device {new_device}" + ) - def _get_op_or_decomp(self, func): - - def _get_from_dict(op_dict, op): - op = op_dict.get(func) - if op is None and isinstance(func, torch._ops.OpOverloadPacket): - op = op_dict.get(func.default) - if op is None and isinstance(func, torch._ops.OpOverload): - op = op_dict.get(func.overloadpacket) - return op - - op = _get_from_dict(self._ops, func) - - if op is None: - # fallback to decompose - op = _get_from_dict(self._decomps, func) - - if op is None: - raise OperatorNotFound( - f"Operator with name 
{_name_of_func(func)} has no lowering") - - return op - - def _to_copy(self, the_tensor, new_dtype, new_device): - if isinstance(the_tensor, View): - the_tensor = the_tensor.torch() - - if isinstance(the_tensor, Tensor): + else: + if new_dtype is not None and new_dtype != the_tensor.dtype: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + the_tensor = the_tensor.to(new_dtype) - arr = the_tensor.jax() + if new_device is None: ## device is None means don't change device + return the_tensor - if new_dtype is not None and new_dtype != arr.dtype: - arr = arr.astype(mappings.t2j_dtype(new_dtype)) + jax_device = self.get_as_jax_device(new_device) + if jax_device: + arr = self.t2j_copy(the_tensor) + arr = jax.device_put(arr, jax_device) + else: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return the_tensor.to(new_device) + return Tensor(arr, self) - if new_device is not None: - match str(new_device).lower(): - case "cpu": - # converting to a non-jax device: let torch native handle it - torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor, - Tensor) else arr - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(new_device) - case "jax": - # move torchax.tensor / jax tensor between devices - # I don't know ifgit this will work after the model is jitted - if self.target_device != the_tensor.jax_device.platform: - arr = jax.device_put(the_tensor.jax(), - jax.devices(self.target_device)[0]) - return Tensor(arr, self) - case _: - logging.error(f"torchax.Tenosr cannot handle device {new_device}") - - else: - if new_dtype is not None and new_dtype != the_tensor.dtype: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - the_tensor = the_tensor.to(new_dtype) - - if new_device is None: ## device is None means don't change device - return the_tensor - - jax_device = self.get_as_jax_device(new_device) - if jax_device: - arr = self.t2j_copy(the_tensor) - arr = jax.device_put(arr, jax_device) - else: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return the_tensor.to(new_device) - - return Tensor(arr, self) - - def get_and_rotate_prng_key(self, - generator: Optional[torch.Generator] = None): - if generator is not None: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63)) - old_key = self._prng_key[...] - new_prng_key, next_key = jax.random.split(old_key) - self._prng_key[...] 
= new_prng_key - return next_key - - def _handle_tensor_constructor(self, func, args, kwargs): - device = kwargs.get("device") - jax_device = self.get_as_jax_device(device) - # TODO(qihqi) figure out better ways for device propagation - if not self._manually_entered and jax_device is None: - # let torch handle it - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return func(*args, **kwargs) - with jax.default_device(jax_device): - requires_grad = kwargs.get("requires_grad", False) - op = self._get_op_or_decomp(func) - res = op.func(*args, **kwargs) - if isinstance(res, jax.Array): - res = Tensor(res, self) - if requires_grad: - res.requires_grad = True - return res - - def _torch_Tensor_to(self, args, kwargs): - the_tensor = args[0] - args = args[1:] - if len(args) >= 1 and isinstance(args[0], torch.Tensor): - dtype = args[0].dtype - device = args[0].device - return self._to_copy(the_tensor, dtype, device) - device = kwargs.get("device") - dtype = kwargs.get("dtype") - # args like pin_memory etc that we will ignore - args = list(filter(lambda x: not isinstance(x, bool), args)) - if len(args) >= 2: - device, dtype, *_ = args - elif len(args) == 1 and isinstance(args[0], torch.dtype): - dtype = args[0] - elif len(args) == 1: - device = args[0] - return self._to_copy(the_tensor, dtype, device) - - def dispatch(self, func, types, args, kwargs): - kwargs = kwargs or {} - if func in TENSOR_CONSTRUCTORS: - return self._handle_tensor_constructor(func, args, kwargs) - if func in ( - torch.Tensor.to, - torch.ops.aten.lift_fresh.default, - torch.ops.aten._to_copy, - torch.ops.aten._to_copy.default, + def get_and_rotate_prng_key( + self, generator: Optional[torch.Generator] = None ): - return self._torch_Tensor_to(args, kwargs) - - # If the func doesn't act on Tensor, and is not a tensor constructor, - # We should skip and let torch handle it. - - tensor_args = [ - t for t in torch_pytree.tree_flatten(args)[0] - if isinstance(t, torch.Tensor) - ] - - def is_not_torchax_tensor(x): - return not isinstance(x, Tensor) and not isinstance(x, View) - - if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args): - res = func(*args, **kwargs) - return res - - with jax.named_scope(_name_of_func(func)): - op = self._get_op_or_decomp(func) - - old_args, old_kwargs = args, kwargs - with self._dispatch_mode: - args, kwargs = torch_pytree.tree_map_only( - torch.distributed._functional_collectives.AsyncCollectiveTensor, - torch.distributed._functional_collectives.wait_tensor, - (args, kwargs), + if generator is not None: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + self._prng_key[...] = jax.random.key( + generator.initial_seed() % (2**63) + ) + old_key = self._prng_key[...] + new_prng_key, next_key = jax.random.split(old_key) + self._prng_key[...] 
= new_prng_key + return next_key + + def _handle_tensor_constructor(self, func, args, kwargs): + device = kwargs.get("device") + jax_device = self.get_as_jax_device(device) + # TODO(qihqi) figure out better ways for device propagation + if not self._manually_entered and jax_device is None: + # let torch handle it + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return func(*args, **kwargs) + with jax.default_device(jax_device): + requires_grad = kwargs.get("requires_grad", False) + op = self._get_op_or_decomp(func) + res = op.func(*args, **kwargs) + if isinstance(res, jax.Array): + res = Tensor(res, self) + if requires_grad: + res.requires_grad = True + return res + + def _torch_Tensor_to(self, args, kwargs): + the_tensor = args[0] + args = args[1:] + if len(args) >= 1 and isinstance(args[0], torch.Tensor): + dtype = args[0].dtype + device = args[0].device + return self._to_copy(the_tensor, dtype, device) + device = kwargs.get("device") + dtype = kwargs.get("dtype") + # args like pin_memory etc that we will ignore + args = list(filter(lambda x: not isinstance(x, bool), args)) + if len(args) >= 2: + device, dtype, *_ = args + elif len(args) == 1 and isinstance(args[0], torch.dtype): + dtype = args[0] + elif len(args) == 1: + device = args[0] + return self._to_copy(the_tensor, dtype, device) + + def dispatch(self, func, types, args, kwargs): + kwargs = kwargs or {} + if func in TENSOR_CONSTRUCTORS: + return self._handle_tensor_constructor(func, args, kwargs) + if func in ( + torch.Tensor.to, + torch.ops.aten.lift_fresh.default, + torch.ops.aten._to_copy, + torch.ops.aten._to_copy.default, + ): + return self._torch_Tensor_to(args, kwargs) + + # If the func doesn't act on Tensor, and is not a tensor constructor, + # We should skip and let torch handle it. 
+ + tensor_args = [ + t + for t in torch_pytree.tree_flatten(args)[0] + if isinstance(t, torch.Tensor) + ] + + def is_not_torchax_tensor(x): + return not isinstance(x, Tensor) and not isinstance(x, View) + + if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args): + res = func(*args, **kwargs) + return res + + with jax.named_scope(_name_of_func(func)): + op = self._get_op_or_decomp(func) + + old_args, old_kwargs = args, kwargs + with self._dispatch_mode: + args, kwargs = torch_pytree.tree_map_only( + torch.distributed._functional_collectives.AsyncCollectiveTensor, + torch.distributed._functional_collectives.wait_tensor, + (args, kwargs), + ) + + try: + if not op.is_view_op: + args, kwargs = self.v2t_iso((args, kwargs)) + + with self: + if self.autocast_dtype is not None: + autocast_policy = amp.autocast_policy.get(func) + if autocast_policy is not None: + args, kwargs = amp.execute_policy( + autocast_policy, + args, + kwargs, + self.autocast_dtype, + ) + + if op.is_jax_function: + args, kwargs = self.t2j_iso((args, kwargs)) + except AssertionError: + if self.config.debug_mixed_tensor: + breakpoint() + else: + raise + + if op.needs_env: + kwargs["env"] = self + + if op.is_jax_function: + res = op.func(*args, **kwargs) + else: + # enable dispatch mode because this op could be a composite autograd op + # meaning, it will decompose in C++ + with self._dispatch_mode: + res = op.func(*args, **kwargs) + + if op.is_jax_function: + res = self.j2t_iso(res) + + if self.config.force_materialize_views and isinstance(res, View): + res = res.torch() + + if self.config.debug_accuracy_for_each_op: + debug_accuracy(func, old_args, old_kwargs, res) + return res + + def enable_torch_modes(self): + self._dispatch_mode.__enter__() + self._function_mode.__enter__() + self.enabled = True + + def disable_torch_modes(self, *exc): + if not exc: + exc = (None, None, None) + self._function_mode.__exit__(*exc) + self._dispatch_mode.__exit__(*exc) + self.enabled = False + + def __enter__(self): + self.enable_torch_modes() + self._manually_entered = True + return self + + def __exit__(self, *exc): + self._manually_entered = False + self.disable_torch_modes(*exc) + + def _move_one_value(self, val): + if isinstance(val, torch.nn.Module): + with self: + return val.to("jax") + if isinstance(val, Tensor): + return val + if isinstance(val, torch.Tensor): + return Tensor(self.t2j_copy(val), self) + return val + + def to_xla(self, torchvalues): + # tensors are torch.Tensors (not XLATensor) + res = torch_pytree.tree_map(self._move_one_value, torchvalues) + return res + + def t2j_iso(self, torchtensors): + """Convert torchax Tensor to jax array. + + This function will not copy, will just unwrap the inner jax array out. + Note: iso is short for "isomorphic" + """ + + def to_jax(x): + if isinstance( + x, + torch.distributed._functional_collectives.AsyncCollectiveTensor, + ): + x = x.wait() + assert isinstance(x, Tensor) or isinstance(x, View), ( + f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor" + ) + return x.jax() + + res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors) + return res + + def v2t_iso(self, views): + def to_tensor(x): + if isinstance(x, View): + return x.torch() + return x + + res = torch_pytree.tree_map_only(View, to_tensor, views) + return res + + def j2t_iso(self, jaxarray): + """Convert jax array to torchax Tensor. 
+ + This function will not copy, will just wrap the jax array with a torchax Tensor + Note: iso is short for "isomorphic" + """ + return torch_pytree.tree_map_only( + jax.Array, lambda x: Tensor(x, self), jaxarray ) - try: - if not op.is_view_op: - args, kwargs = self.v2t_iso((args, kwargs)) - - with self: - if self.autocast_dtype is not None: - autocast_policy = amp.autocast_policy.get(func) - if autocast_policy is not None: - args, kwargs = amp.execute_policy(autocast_policy, args, kwargs, - self.autocast_dtype) - - if op.is_jax_function: - args, kwargs = self.t2j_iso((args, kwargs)) - except AssertionError: - if self.config.debug_mixed_tensor: - breakpoint() - else: - raise - - if op.needs_env: - kwargs["env"] = self - - if op.is_jax_function: - res = op.func(*args, **kwargs) - else: - # enable dispatch mode because this op could be a composite autograd op - # meaning, it will decompose in C++ - with self._dispatch_mode: - res = op.func(*args, **kwargs) - - if op.is_jax_function: - res = self.j2t_iso(res) - - if self.config.force_materialize_views and isinstance(res, View): - res = res.torch() - - if self.config.debug_accuracy_for_each_op: - debug_accuracy(func, old_args, old_kwargs, res) - return res - - def enable_torch_modes(self): - self._dispatch_mode.__enter__() - self._function_mode.__enter__() - self.enabled = True - - def disable_torch_modes(self, *exc): - if not exc: - exc = (None, None, None) - self._function_mode.__exit__(*exc) - self._dispatch_mode.__exit__(*exc) - self.enabled = False - - def __enter__(self): - self.enable_torch_modes() - self._manually_entered = True - return self - - def __exit__(self, *exc): - self._manually_entered = False - self.disable_torch_modes(*exc) - - def _move_one_value(self, val): - if isinstance(val, torch.nn.Module): - with self: - return val.to("jax") - if isinstance(val, Tensor): - return val - if isinstance(val, torch.Tensor): - return Tensor(self.t2j_copy(val), self) - return val - - def to_xla(self, torchvalues): - # tensors are torch.Tensors (not XLATensor) - res = torch_pytree.tree_map(self._move_one_value, torchvalues) - return res - - def t2j_iso(self, torchtensors): - """Convert torchax Tensor to jax array. - - This function will not copy, will just unwrap the inner jax array out. - Note: iso is short for "isomorphic" - """ + def j2t_copy(self, args): + """Convert torch.Tensor in cpu to a jax array + + This might involves copying the data (depending if dlpack is enabled) + """ + return torch_pytree.tree_map_only( + jax.Array, + lambda x: mappings.j2t( + x, self.config.use_dlpack_for_data_conversion + ), + args, + ) - def to_jax(x): - if isinstance( - x, torch.distributed._functional_collectives.AsyncCollectiveTensor): - x = x.wait() - assert isinstance(x, Tensor) or isinstance(x, View), ( - f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor" - ) - return x.jax() - - res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors) - return res - - def v2t_iso(self, views): - - def to_tensor(x): - if isinstance(x, View): - return x.torch() - return x - - res = torch_pytree.tree_map_only(View, to_tensor, views) - return res - - def j2t_iso(self, jaxarray): - """Convert jax array to torchax Tensor. 
- - This function will not copy, will just wrap the jax array with a torchax Tensor - Note: iso is short for "isomorphic" - """ - return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), - jaxarray) + def t2j_copy(self, args): + """Convert jax array to torch.Tensor in cpu. + + This might involves copying the data (depending if dlpack is enabled) + """ + return torch_pytree.tree_map_only( + torch.Tensor, + lambda x: mappings.t2j( + x, self.config.use_dlpack_for_data_conversion + ), + args, + ) - def j2t_copy(self, args): - """Convert torch.Tensor in cpu to a jax array - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - jax.Array, - lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion), - args) - - def t2j_copy(self, args): - """Convert jax array to torch.Tensor in cpu. - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - torch.Tensor, - lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion), - args) - - def override_op_definition(self, op_to_override, op_impl): - self._ops[op_to_override] = ops_registry.Operator( - op_to_override, - op_impl, - is_jax_function=False, - is_user_defined=True, - needs_env=False, - ) + def override_op_definition(self, op_to_override, op_impl): + self._ops[op_to_override] = ops_registry.Operator( + op_to_override, + op_impl, + is_jax_function=False, + is_user_defined=True, + needs_env=False, + ) diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py index c9842089bfcf..54049604d729 100644 --- a/torchax/torchax/tf_integration.py +++ b/torchax/torchax/tf_integration.py @@ -9,30 +9,31 @@ def exported_program_to_tf_function(ep, enable_xla=True): - weights, jax_program = export.exported_program_to_jax(ep) - wrapped = lambda *args: jax_program(weights, (args,)) - avals = export.extract_avals(ep) - input_signature = [ - tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") - for i, t in enumerate(avals) - ] - tf_f = tf.function( - jax2tf.convert( - wrapped, - with_gradient=False, - enable_xla=enable_xla, - ), - autograph=False, - input_signature=input_signature, - ) - return tf_f - - -def exported_program_to_tf_module(ep: torch.export.ExportedProgram, - enable_xla=True) -> tf.Module: - tfm = tf.Module() - tfm.f = exported_program_to_tf_function(ep, enable_xla) - return tfm + weights, jax_program = export.exported_program_to_jax(ep) + wrapped = lambda *args: jax_program(weights, (args,)) + avals = export.extract_avals(ep) + input_signature = [ + tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") + for i, t in enumerate(avals) + ] + tf_f = tf.function( + jax2tf.convert( + wrapped, + with_gradient=False, + enable_xla=enable_xla, + ), + autograph=False, + input_signature=input_signature, + ) + return tf_f + + +def exported_program_to_tf_module( + ep: torch.export.ExportedProgram, enable_xla=True +) -> tf.Module: + tfm = tf.Module() + tfm.f = exported_program_to_tf_function(ep, enable_xla) + return tfm def save_exported_program_as_tf_saved_model( @@ -42,36 +43,38 @@ def save_exported_program_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """This function will export and save a pytorch ExportedProgram to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. 
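A usage sketch for the saver being reflowed here; the tiny module and the output path are arbitrary, and the import path is assumed from the file location (torchax/torchax/tf_integration.py):

    import torch
    from torchax import tf_integration

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    ep = torch.export.export(TinyModel(), (torch.randn(2, 3),))
    tf_integration.save_exported_program_as_tf_saved_model(
        ep,
        "/tmp/exported_saved_model",
        serving_key="serving_default",
        function_alias="tiny_model_fn",
        enable_xla=True,
    )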
- - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) - signatures = { - serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) - } - save_options = tf.saved_model.SaveOptions(function_aliases={ - function_alias: tfm.f, - }) - tf.saved_model.save( - tfm, - saved_model_dir, - signatures=signatures, - options=save_options, - ) + """This function will export and save a pytorch ExportedProgram to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. + """ + tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) + signatures = { + serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) + } + save_options = tf.saved_model.SaveOptions( + function_aliases={ + function_alias: tfm.f, + } + ) + tf.saved_model.save( + tfm, + saved_model_dir, + signatures=signatures, + options=save_options, + ) def save_torch_module_as_tf_saved_model( @@ -82,38 +85,41 @@ def save_torch_module_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """This function will export and save a pytorch nn.Module to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - ep = torch.export.export(torch_model, args) - save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, - function_alias, enable_xla) + """This function will export and save a pytorch nn.Module to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. 
+ function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. + """ + ep = torch.export.export(torch_model, args) + save_exported_program_as_tf_saved_model( + ep, saved_model_dir, serving_key, function_alias, enable_xla + ) def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): - tfm = exported_program_to_tf_module(ep) - tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [tf_concrete_func], tfm) - tflite_model = converter.convert() - return tflite_model - - -def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, - args: Tuple[Any]): - ep = torch.export.export(torch_model, args) - return exported_program_to_tflite_flatbuffer(ep) + tfm = exported_program_to_tf_module(ep) + tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [tf_concrete_func], tfm + ) + tflite_model = converter.convert() + return tflite_model + + +def torch_module_to_tflite_flatbuffer( + torch_model: torch.nn.Module, args: Tuple[Any] +): + ep = torch.export.export(torch_model, args) + return exported_program_to_tflite_flatbuffer(ep) diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py index fb4e16fc48ee..78639090321f 100644 --- a/torchax/torchax/train.py +++ b/torchax/torchax/train.py @@ -12,106 +12,107 @@ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): - """Make a function that do one train step given model and loss. - - model_fn: a function representing the model's forward: - i.e. has signature Callable[weights, buffers, args] -> result. Where, - weights is a pytree of trainable parameters - buffers is a pytree of non-trainable parameters / constants - args is the input data loaded from the data set - result is the return value of the model - loss_fn: a function to compute loss. - i.e. it has signature of Callable[result, label] -> loss - where, result is what model_fn returned - loss is loaded from the dataloader. - optax_optimizer: the optimizer from optax library. for example, optax.adam - remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how - to do gradient checkpointing. If None, then it means checkpoint everything. - """ - env = torchax.default_env() - - def loss(weights, buffers, args, label): # inputs are XLATensor - with env, jax.named_scope('compute_loss'): - res = model_fn(weights, buffers, args) - l = loss_fn(res, label) - return l - - loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy}) - grad_fn = interop.jax_value_and_grad(loss) - - def step(weights, buffers, opt_state, args, label): #inputs are array - with jax.named_scope('compute_gradient'): - loss, gradient = grad_fn(weights, buffers, args, label) - - with jax.named_scope("optimizer_updates"): - updates, opt_state = interop.call_jax(optax_optimizer.update, gradient, - opt_state, weights) - weights = interop.call_jax(optax.apply_updates, weights, updates) - return loss, weights, opt_state - - # TODO: apply jax.jit so the user don't have to. - return step + """Make a function that do one train step given model and loss. + + model_fn: a function representing the model's forward: + i.e. has signature Callable[weights, buffers, args] -> result. 
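A rough sketch of wiring up make_train_step; model_fn and loss_fn here are toy stand-ins, the weights/buffers/opt_state/batch/label pytrees are caller-provided as described in this docstring, and the import path is assumed from the file location (torchax/torchax/train.py):

    import optax
    import torch
    from torchax import train

    def model_fn(weights, buffers, args):
        # toy forward pass: one linear map stored in the weights pytree
        return args @ weights["w"]

    def loss_fn(result, label):
        return torch.nn.functional.mse_loss(result, label)

    optimizer = optax.adam(1e-3)
    step = train.make_train_step(model_fn, loss_fn, optimizer)

    # each call returns the updated state alongside the loss:
    # loss, weights, opt_state = step(weights, buffers, opt_state, batch, label)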
Where, + weights is a pytree of trainable parameters + buffers is a pytree of non-trainable parameters / constants + args is the input data loaded from the data set + result is the return value of the model + loss_fn: a function to compute loss. + i.e. it has signature of Callable[result, label] -> loss + where, result is what model_fn returned + loss is loaded from the dataloader. + optax_optimizer: the optimizer from optax library. for example, optax.adam + remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how + to do gradient checkpointing. If None, then it means checkpoint everything. + """ + env = torchax.default_env() + + def loss(weights, buffers, args, label): # inputs are XLATensor + with env, jax.named_scope("compute_loss"): + res = model_fn(weights, buffers, args) + l = loss_fn(res, label) + return l + + loss = interop.gradient_checkpoint(loss, kwargs={"policy": remat_policy}) + grad_fn = interop.jax_value_and_grad(loss) + + def step(weights, buffers, opt_state, args, label): # inputs are array + with jax.named_scope("compute_gradient"): + loss, gradient = grad_fn(weights, buffers, args, label) + + with jax.named_scope("optimizer_updates"): + updates, opt_state = interop.call_jax( + optax_optimizer.update, gradient, opt_state, weights + ) + weights = interop.call_jax(optax.apply_updates, weights, updates) + return loss, weights, opt_state + + # TODO: apply jax.jit so the user don't have to. + return step class Container: - pass + pass class ScannedModule(torch.nn.Module): - - def __init__(self, module_list, checkpoint_policy=None): - super().__init__() - - self.c = None - assert module_list - self.c = Container() - self.c.one_mod = module_list[0] - self.checkpoint_policy = checkpoint_policy - - weights = self._stack_layer_weights(module_list) - self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) - self.params = torch.nn.ParameterDict({ - self._param_name_new(k): v for k, v in weights.items() - }) - - def _stack_layer_weights(self, module_list): - # Create weights such that, for every [n, m] weights - # becomes [k, n, m] where k is number of layer - # i.e. 
stacking layer weights together - temp = collections.defaultdict(list) - for m in module_list: - for k, v in m.state_dict().items(): - temp[k].append(v) - res = {k: torch.stack(v) for k, v in temp.items()} - return res - - def _param_name_new(self, old): - return '___'.join(old.split('.')) - - def _param_name_old(self, new): - return '.'.join(new.split('___')) - - def forward(self, *args, **kwargs): - assert not kwargs - weights = { - k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys - } - scan = interop.torch_view(jax.lax.scan) - - def eval_one_layer(args, weight): - # unpack args - h, *rest = args - newh = torch.func.functional_call(self.c.one_mod, weight, args) - # next layer's input; and residual to be added to list - return (newh, *rest), None - - _eval_one_layer = interop.gradient_checkpoint( - eval_one_layer, - kwargs={'policy': self.checkpoint_policy}, - ) - h, _ = scan( - _eval_one_layer, - args, - weights, - ) - return h[0] + def __init__(self, module_list, checkpoint_policy=None): + super().__init__() + + self.c = None + assert module_list + self.c = Container() + self.c.one_mod = module_list[0] + self.checkpoint_policy = checkpoint_policy + + weights = self._stack_layer_weights(module_list) + self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) + self.params = torch.nn.ParameterDict({ + self._param_name_new(k): v for k, v in weights.items() + }) + + def _stack_layer_weights(self, module_list): + # Create weights such that, for every [n, m] weights + # becomes [k, n, m] where k is number of layer + # i.e. stacking layer weights together + temp = collections.defaultdict(list) + for m in module_list: + for k, v in m.state_dict().items(): + temp[k].append(v) + res = {k: torch.stack(v) for k, v in temp.items()} + return res + + def _param_name_new(self, old): + return "___".join(old.split(".")) + + def _param_name_old(self, new): + return ".".join(new.split("___")) + + def forward(self, *args, **kwargs): + assert not kwargs + weights = { + k: self.params[self._param_name_new(k)] + for k in self.layer_weights_keys + } + scan = interop.torch_view(jax.lax.scan) + + def eval_one_layer(args, weight): + # unpack args + h, *rest = args + newh = torch.func.functional_call(self.c.one_mod, weight, args) + # next layer's input; and residual to be added to list + return (newh, *rest), None + + _eval_one_layer = interop.gradient_checkpoint( + eval_one_layer, + kwargs={"policy": self.checkpoint_policy}, + ) + h, _ = scan( + _eval_one_layer, + args, + weights, + ) + return h[0] diff --git a/torchax/torchax/types.py b/torchax/torchax/types.py index 72a2f678c961..d61e1444eb45 100644 --- a/torchax/torchax/types.py +++ b/torchax/torchax/types.py @@ -4,9 +4,9 @@ import jax.numpy as jnp import sys -P = ParamSpec('P') +P = ParamSpec("P") -TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] +TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, "TorchCallable", Any] TorchCallable: TypeAlias = Callable[P, TorchValue] -JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] -JaxCallable: TypeAlias = Callable[P, JaxValue] \ No newline at end of file +JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, "JaxCallable", Any] +JaxCallable: TypeAlias = Callable[P, JaxValue] diff --git a/torchax/torchax/util.py b/torchax/torchax/util.py index e34f77119d6f..7f6f8cd638dc 100644 --- a/torchax/torchax/util.py +++ b/torchax/torchax/util.py @@ -1,88 +1,89 @@ from typing import Any, Callable -def partition(original: list[Any], - func: 
Callable[[Any], bool]) -> tuple[list[Any], list[Any]]: - """Partitions elements into two parallel lists based on a predicate function. - - Iterates through the 'original' list, applying 'func' to each element 'a'. - - If `func(a)` returns True, 'a' is appended to the first list ('truthy') - and `None` is appended to the second list ('falsy'). - - If `func(a)` returns False, `None` is appended to the first list ('truthy') - and 'a' is appended to the second list ('falsy'). - - The result is two lists of the same length as the 'original' list, acting - as parallel representations of the partitioned elements, using `None` as - placeholders. - - This is useful when we want to mark a group of elements as static (via passing - static_argnums) or donated (via donate_argnums) when combining with jax.jit - and friends. - - Args: - original: The list of elements to partition. - func: A callable (function or lambda) that accepts an element from - 'original' and returns a boolean value (True or False). - - Returns: - A tuple containing two lists (`truthy`, `falsy`), both of the same - length as `original`: - - The first list contains elements `x` where `func(x)` was True, and - `None` otherwise. - - The second list contains elements `x` where `func(x)` was False, and - `None` otherwise. - - Example: - >>> def is_even(n): return n % 2 == 0 - >>> nums = [1, 2, 3, 4, 5, 6] - >>> truthy_list, falsy_list = partition(nums, is_even) - >>> truthy_list - [None, 2, None, 4, None, 6] - >>> falsy_list - [1, None, 3, None, 5, None] - """ - truthy = [] - falsy = [] - for a in original: - t, f = (a, None) if func(a) else (None, a) - truthy.append(t) - falsy.append(f) - return truthy, falsy +def partition( + original: list[Any], func: Callable[[Any], bool] +) -> tuple[list[Any], list[Any]]: + """Partitions elements into two parallel lists based on a predicate function. + + Iterates through the 'original' list, applying 'func' to each element 'a'. + - If `func(a)` returns True, 'a' is appended to the first list ('truthy') + and `None` is appended to the second list ('falsy'). + - If `func(a)` returns False, `None` is appended to the first list ('truthy') + and 'a' is appended to the second list ('falsy'). + + The result is two lists of the same length as the 'original' list, acting + as parallel representations of the partitioned elements, using `None` as + placeholders. + + This is useful when we want to mark a group of elements as static (via passing + static_argnums) or donated (via donate_argnums) when combining with jax.jit + and friends. + + Args: + original: The list of elements to partition. + func: A callable (function or lambda) that accepts an element from + 'original' and returns a boolean value (True or False). + + Returns: + A tuple containing two lists (`truthy`, `falsy`), both of the same + length as `original`: + - The first list contains elements `x` where `func(x)` was True, and + `None` otherwise. + - The second list contains elements `x` where `func(x)` was False, and + `None` otherwise. 
+ + Example: + >>> def is_even(n): return n % 2 == 0 + >>> nums = [1, 2, 3, 4, 5, 6] + >>> truthy_list, falsy_list = partition(nums, is_even) + >>> truthy_list + [None, 2, None, 4, None, 6] + >>> falsy_list + [1, None, 3, None, 5, None] + """ + truthy = [] + falsy = [] + for a in original: + t, f = (a, None) if func(a) else (None, a) + truthy.append(t) + falsy.append(f) + return truthy, falsy def merge(list1: list[Any], list2: list[Any]) -> list[Any]: - """Merges two lists element-wise, prioritizing non-None elements from list1. - - Creates a new list where each element is taken from the corresponding position - in 'list1', unless that element is None, in which case the element from the - corresponding position in 'list2' is used. Assumes both lists have the - same length. - - Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate - - Args: - list1: The primary list. Its elements are preferred unless they are None. - list2: The secondary list. Its elements are used as fallbacks when the - corresponding element in list1 is None. - - Returns: - A new list representing the merged result. - - Raises: - AssertionError: If 'list1' and 'list2' do not have the same length. - - Example: - >>> l1 = [1, None, 3, None] - >>> l2 = [None, 2, None, 4] - >>> merge(l1, l2) - [1, 2, 3, 4] - >>> l3 = [None, 'b', None] - >>> l4 = ['a', None, 'c'] - >>> merge(l3, l4) - ['a', 'b', 'c'] - """ - assert len(list1) == len(list2) - res = [] - for a, b in zip(list1, list2): - res.append(b if a is None else a) - return res + """Merges two lists element-wise, prioritizing non-None elements from list1. + + Creates a new list where each element is taken from the corresponding position + in 'list1', unless that element is None, in which case the element from the + corresponding position in 'list2' is used. Assumes both lists have the + same length. + + Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate + + Args: + list1: The primary list. Its elements are preferred unless they are None. + list2: The secondary list. Its elements are used as fallbacks when the + corresponding element in list1 is None. + + Returns: + A new list representing the merged result. + + Raises: + AssertionError: If 'list1' and 'list2' do not have the same length. + + Example: + >>> l1 = [1, None, 3, None] + >>> l2 = [None, 2, None, 4] + >>> merge(l1, l2) + [1, 2, 3, 4] + >>> l3 = [None, 'b', None] + >>> l4 = ['a', None, 'c'] + >>> merge(l3, l4) + ['a', 'b', 'c'] + """ + assert len(list1) == len(list2) + res = [] + for a, b in zip(list1, list2): + res.append(b if a is None else a) + return res diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 040fa24ef9e8..c2a5851be6dc 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -10,39 +10,40 @@ class ViewInfoType(Enum): - INVALID = 0 - NARROW = 1 - NO_OP = 2 - PERMUTE = 3 - RESHAPE = 4 - RESIZE = 5 - SELECT = 6 - AS_STRIDED = 7 - DIAGONAL = 8 + INVALID = 0 + NARROW = 1 + NO_OP = 2 + PERMUTE = 3 + RESHAPE = 4 + RESIZE = 5 + SELECT = 6 + AS_STRIDED = 7 + DIAGONAL = 8 class ViewInfo(ABC): - """ + """ Abstract base class for all view operations. Defines the interface for applying and updating view transformations. """ - def __init__( - self, - view_info_type: ViewInfoType = ViewInfoType.INVALID, - ): - """ + def __init__( + self, + view_info_type: ViewInfoType = ViewInfoType.INVALID, + ): + """ Initialize a ViewInfo object. 
Args: view_info_type: The type of view operation """ - self.view_info_type = view_info_type + self.view_info_type = view_info_type - @abstractmethod - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - """ + @abstractmethod + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + """ Apply this view transformation to a JAX array and update its value. Args: @@ -52,11 +53,11 @@ def update_tensor(self, new_value: jax.Array, Returns: Updated array """ - pass + pass - @abstractmethod - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - """ + @abstractmethod + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + """ Apply this view transformation to a JAX array. Args: @@ -65,11 +66,11 @@ def transform_tensor(self, jax_array: jax.Array) -> jax.Array: Returns: Transformed array """ - pass + pass - @abstractmethod - def calculate_output_shape(self, source: jax.Array) -> List[int]: - """ + @abstractmethod + def calculate_output_shape(self, source: jax.Array) -> List[int]: + """ Calculate the resulting shape after applying this view. Args: @@ -78,300 +79,323 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: Returns: Resulting shape after transformation """ - pass + pass class NarrowInfo(ViewInfo): - """ + """ Represents a slicing operation on a tensor. Handles operations like tensor[1:3, :, 2:5:2]. """ - def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: - """ + def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: + """ Args: slices: The slice(s) to apply to the tensor. E.g. jax_array.at[slices] will return the transformed tensor. """ - super().__init__(ViewInfoType.NARROW) - self.slices = slices + super().__init__(ViewInfoType.NARROW) + self.slices = slices - def __eq__(self, other: object) -> bool: - if not isinstance(other, NarrowInfo): - return False - return self.slices == other.slices + def __eq__(self, other: object) -> bool: + if not isinstance(other, NarrowInfo): + return False + return self.slices == other.slices - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - try: - return jax_array[self.slices] - except IndexError as e: - raise IndexError("Invalid slice operation") from e + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + try: + return jax_array[self.slices] + except IndexError as e: + raise IndexError("Invalid slice operation") from e - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - return jax_array.at[self.slices].set(new_value) + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + return jax_array.at[self.slices].set(new_value) - def calculate_output_shape(self, source: jax.Array) -> List[int]: - return source[self.slices].shape + def calculate_output_shape(self, source: jax.Array) -> List[int]: + return source[self.slices].shape class SelectInfo(ViewInfo): - """ + """ Represents a selection operation on a tensor. Typically used for indexing operations that select specific elements. 
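To make the `NarrowInfo` contract above concrete (read through the slice with `transform_tensor`, write back functionally with `update_tensor`), here is a minimal sketch on a bare jax array; the array contents and slices are invented for illustration:

  import jax.numpy as jnp
  from torchax.view import NarrowInfo

  src = jnp.arange(12).reshape(3, 4)
  info = NarrowInfo(slices=(slice(0, 2), slice(1, 3)))

  window = info.transform_tensor(src)                       # src[0:2, 1:3], shape (2, 2)
  updated = info.update_tensor(jnp.zeros((2, 2), src.dtype), src)
  assert window.shape == info.calculate_output_shape(src)   # both (2, 2)
  assert updated.shape == src.shape                         # write-back keeps the full shape

Because `update_tensor` goes through `jax_array.at[...].set(...)`, it returns a new array instead of mutating the source, which is what lets the view machinery stay purely functional.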
""" - def __init__(self, - dim: int = 0, - start: int = 0, - end: int = 0, - stride: int = 0) -> None: - super().__init__(ViewInfoType.SELECT) - self.dim: int = dim - self.start: int = start - self.end: int = end - self.stride: int = stride - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SelectInfo): - return False - return (self.dim == other.dim and self.start == other.start and - self.end == other.end and self.stride == other.stride) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("SelectInfo.apply not implemented") - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("SelectInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "SelectInfo.calculate_output_shape not implemented") + def __init__( + self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0 + ) -> None: + super().__init__(ViewInfoType.SELECT) + self.dim: int = dim + self.start: int = start + self.end: int = end + self.stride: int = stride + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SelectInfo): + return False + return ( + self.dim == other.dim + and self.start == other.start + and self.end == other.end + and self.stride == other.stride + ) + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("SelectInfo.apply not implemented") + + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("SelectInfo.update not implemented") + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "SelectInfo.calculate_output_shape not implemented" + ) class AsStridedInfo(ViewInfo): - """ + """ Information for as_strided operations. """ - def __init__(self, stride: List[int], offset: int = 0) -> None: - super().__init__(ViewInfoType.AS_STRIDED) - self.stride: List[int] = stride - self.offset: int = offset + def __init__(self, stride: List[int], offset: int = 0) -> None: + super().__init__(ViewInfoType.AS_STRIDED) + self.stride: List[int] = stride + self.offset: int = offset - def __eq__(self, other: object) -> bool: - if not isinstance(other, AsStridedInfo): - return False - return self.offset == other.offset and self.stride == other.stride + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsStridedInfo): + return False + return self.offset == other.offset and self.stride == other.stride - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("AsStridedInfo.apply not implemented") + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("AsStridedInfo.apply not implemented") - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("AsStridedInfo.update not implemented") + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("AsStridedInfo.update not implemented") - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "AsStridedInfo.calculate_output_shape not implemented") + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "AsStridedInfo.calculate_output_shape not implemented" + ) class DiagonalInfo(ViewInfo): - """ + """ Information for diagonal operations. 
Extracts diagonal elements from a tensor. """ - def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: - """ + def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: + """ Args: offset: Offset from the main diagonal dim1: First dimension for diagonal extraction dim2: Second dimension for diagonal extraction """ - super().__init__(ViewInfoType.DIAGONAL) - self.offset: int = offset - self.dim1: int = dim1 - self.dim2: int = dim2 - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DiagonalInfo): - return False - return (self.offset == other.offset and self.dim1 == other.dim1 and - self.dim2 == other.dim2) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("DiagonalInfo.apply not implemented") - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("DiagonalInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "DiagonalInfo.calculate_output_shape not implemented") + super().__init__(ViewInfoType.DIAGONAL) + self.offset: int = offset + self.dim1: int = dim1 + self.dim2: int = dim2 + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DiagonalInfo): + return False + return ( + self.offset == other.offset + and self.dim1 == other.dim1 + and self.dim2 == other.dim2 + ) + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("DiagonalInfo.apply not implemented") + + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("DiagonalInfo.update not implemented") + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "DiagonalInfo.calculate_output_shape not implemented" + ) class View(torch.Tensor): - """ + """ A View is a reference to another Tensor or another View, with a transformation applied to it. """ - @staticmethod - def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo, - env: Any) -> "View": - """ + @staticmethod + def __new__( + cls, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> "View": + """ Args: parent: Parent tensor or view view_info: Information about the view transformation env: Environment for tensor operations """ - shape = view_info.calculate_output_shape(parent.jax()) - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - device="meta", - dtype=parent.dtype, - requires_grad=False, - ) - - def __init__(self, parent: Union["torchax.Tensor", "View"], - view_info: ViewInfo, env: Any) -> None: - super().__init__() - self.parent = parent - self.view_info = view_info - self._env = env - - def get_transformation_chain(self) -> List[ViewInfo]: - """ + shape = view_info.calculate_output_shape(parent.jax()) + return torch.Tensor._make_wrapper_subclass( + cls, + shape, + device="meta", + dtype=parent.dtype, + requires_grad=False, + ) + + def __init__( + self, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> None: + super().__init__() + self.parent = parent + self.view_info = view_info + self._env = env + + def get_transformation_chain(self) -> List[ViewInfo]: + """ Get all view transformations from the source tensor to this view. 
""" - if isinstance(self.parent, View): - transformations = self.parent.get_transformation_chain() - transformations.append(self.view_info) - return transformations - else: - return [self.view_info] + if isinstance(self.parent, View): + transformations = self.parent.get_transformation_chain() + transformations.append(self.view_info) + return transformations + else: + return [self.view_info] - __torch_function__ = torch._C._disabled_torch_function_impl + __torch_function__ = torch._C._disabled_torch_function_impl - def source_jax(self) -> jax.Array: - """ + def source_jax(self) -> jax.Array: + """ Returns the source tensor. """ - if isinstance(self.parent, View): - return self.parent.source_jax() - else: - return self.parent.jax() + if isinstance(self.parent, View): + return self.parent.source_jax() + else: + return self.parent.jax() - def replace_source_jax(self, new_value: jax.Array) -> None: - """ + def replace_source_jax(self, new_value: jax.Array) -> None: + """ Update the source tensor with new values. """ - if isinstance(self.parent, View): - self.parent.replace_source_jax(new_value) - else: - assert new_value.shape == self.parent._elem.shape - self.parent._elem = new_value + if isinstance(self.parent, View): + self.parent.replace_source_jax(new_value) + else: + assert new_value.shape == self.parent._elem.shape + self.parent._elem = new_value - def torch(self) -> "torchax.Tensor": - """ + def torch(self) -> "torchax.Tensor": + """ Returns a Torchax tensor representing this view after all transformations """ - from torchax.tensor import Tensor + from torchax.tensor import Tensor - return Tensor(self.jax(), self._env) + return Tensor(self.jax(), self._env) - def update( - self, - new_values: Union[jax.Array, "View", "torchax.Tensor"], - view_infos: Optional[List[ViewInfo]] = None, - ) -> None: - """ + def update( + self, + new_values: Union[jax.Array, "View", "torchax.Tensor"], + view_infos: Optional[List[ViewInfo]] = None, + ) -> None: + """ Update this view with new values, propagating changes back to source. If view_infos is None, it will use the transformation chain from the source tensor. """ - if view_infos is None: - view_infos = self.get_transformation_chain() - - # Get the source JAX array - source_array = self.source_jax() - - # Get the new value - from torchax.tensor import Tensor - - if isinstance(new_values, View) or isinstance(new_values, Tensor): - new_values = new_values.jax() - - # Apply all view transformations to the source array - # And store intermediate values - intermediate_values = [source_array] - for view_info in view_infos[:-1]: - intermediate_values.append( - view_info.transform_tensor(intermediate_values[-1])) - - # TODO: Investigate efficiency of this algorithm - # Update the source array with the new value by - # applying inverse transformations in reverse order - for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values)): - # Apply the inverse transformation to propagate changes back - new_values = view_info.update_tensor(new_values, parent_array) - - # Update the source tensor with the new values - self.replace_source_jax(new_values) - - @classmethod - def __torch_dispatch__( - cls, - func: Any, - types: Tuple[Any, ...], - args: Tuple[Any, ...] = (), - kwargs: Optional[dict] = None, - ) -> Any: - raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' 
- 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') - - def create_sub_view(self, view_info: ViewInfo) -> "View": - """ + if view_infos is None: + view_infos = self.get_transformation_chain() + + # Get the source JAX array + source_array = self.source_jax() + + # Get the new value + from torchax.tensor import Tensor + + if isinstance(new_values, View) or isinstance(new_values, Tensor): + new_values = new_values.jax() + + # Apply all view transformations to the source array + # And store intermediate values + intermediate_values = [source_array] + for view_info in view_infos[:-1]: + intermediate_values.append( + view_info.transform_tensor(intermediate_values[-1]) + ) + + # TODO: Investigate efficiency of this algorithm + # Update the source array with the new value by + # applying inverse transformations in reverse order + for view_info, parent_array in zip( + reversed(view_infos), reversed(intermediate_values) + ): + # Apply the inverse transformation to propagate changes back + new_values = view_info.update_tensor(new_values, parent_array) + + # Update the source tensor with the new values + self.replace_source_jax(new_values) + + @classmethod + def __torch_dispatch__( + cls, + func: Any, + types: Tuple[Any, ...], + args: Tuple[Any, ...] = (), + kwargs: Optional[dict] = None, + ) -> Any: + raise AssertionError( + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) + + def create_sub_view(self, view_info: ViewInfo) -> "View": + """ Create a new view that is a child of this view. """ - return View(self, view_info, self._env) + return View(self, view_info, self._env) - def __str__(self) -> str: - return f"View({self.torch()})" + def __str__(self) -> str: + return f"View({self.torch()})" - def jax(self) -> jax.Array: - """ + def jax(self) -> jax.Array: + """ Returns a copy of the source tensor after transformations. 
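To make the forward/backward propagation in `View.update` above concrete, here is a small standalone sketch that runs the same two loops with a chain of two `NarrowInfo` transformations over a bare jax array; the shapes and slices are invented for illustration:

  import jax.numpy as jnp
  from torchax.view import NarrowInfo

  source = jnp.zeros((4, 4))
  # rows 0:2 of the source, then row 1 of that 2x4 block
  chain = [NarrowInfo(slice(0, 2)), NarrowInfo(slice(1, 2))]

  # Forward pass: materialize every intermediate view except the last one.
  intermediates = [source]
  for info in chain[:-1]:
    intermediates.append(info.transform_tensor(intermediates[-1]))

  # Backward pass: write the new value into each parent, innermost first.
  new_values = jnp.ones((1, 4))
  for info, parent in zip(reversed(chain), reversed(intermediates)):
    new_values = info.update_tensor(new_values, parent)

  # new_values is now a full-size copy of `source` with row 1 set to ones;
  # View.update would install it back with replace_source_jax.
  assert new_values.shape == source.shape

Each `update_tensor` call returns a new array, so the reversed loop rebuilds a full-size source array rather than mutating anything in place, matching jax's immutable-array model.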
""" - result = self.source_jax() - for view_info in self.get_transformation_chain(): - result = view_info.transform_tensor(result) - return result + result = self.source_jax() + for view_info in self.get_transformation_chain(): + result = view_info.transform_tensor(result) + return result - def __setitem__(self, indexes, val): - view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] - self.update(view_infos=view_infos, new_values=val) + def __setitem__(self, indexes, val): + view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] + self.update(view_infos=view_infos, new_values=val) - def dim(self): - return self.ndim + def dim(self): + return self.ndim - @property - def device(self): - return torch.device("jax:0") + @property + def device(self): + return torch.device("jax:0") - @property - def jax_device(self): - return self.jax().device + @property + def jax_device(self): + return self.jax().device - @property - def ndim(self): - return len(self.shape) + @property + def ndim(self): + return len(self.shape) - __repr__ = __str__ + __repr__ = __str__ From 29323a2d0e7e78f15a112578a88fd13ff4f286b7 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Fri, 20 Jun 2025 18:54:52 +0000 Subject: [PATCH 8/8] use 2 space for tab instead --- torchax/pyproject.toml | 2 + torchax/torchax/__init__.py | 174 +- torchax/torchax/amp.py | 324 +- torchax/torchax/config.py | 32 +- torchax/torchax/decompositions.py | 1356 ++-- torchax/torchax/device_module.py | 14 +- torchax/torchax/distributed.py | 398 +- torchax/torchax/export.py | 438 +- torchax/torchax/flax.py | 64 +- torchax/torchax/interop.py | 510 +- torchax/torchax/mesh_util.py | 380 +- torchax/torchax/ops/__init__.py | 16 +- torchax/torchax/ops/jaten.py | 7740 +++++++++++------------ torchax/torchax/ops/jax_reimplement.py | 338 +- torchax/torchax/ops/jc10d.py | 50 +- torchax/torchax/ops/jimage.py | 178 +- torchax/torchax/ops/jlibrary.py | 108 +- torchax/torchax/ops/jtorch.py | 728 ++- torchax/torchax/ops/jtorchvision_nms.py | 450 +- torchax/torchax/ops/mappings.py | 176 +- torchax/torchax/ops/op_base.py | 178 +- torchax/torchax/ops/ops_registry.py | 76 +- torchax/torchax/tensor.py | 1234 ++-- torchax/torchax/tf_integration.py | 188 +- torchax/torchax/train.py | 195 +- torchax/torchax/util.py | 162 +- torchax/torchax/view.py | 688 +- 27 files changed, 8035 insertions(+), 8162 deletions(-) diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml index 689b7983e680..85d752964ab8 100644 --- a/torchax/pyproject.toml +++ b/torchax/pyproject.toml @@ -59,6 +59,8 @@ line-length = 80 # Enable preview mode to use rules like E306 preview = true +indent-width = 2 + [tool.ruff.lint] select = [ "E", "F", "W", # Your existing rule selections diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 2745a2a47db6..f16a61333189 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -13,9 +13,9 @@ VERSION = __version__ __all__ = [ - "default_env", - "extract_jax", - "enable_globally", + "default_env", + "extract_jax", + "enable_globally", ] from jax._src import xla_bridge @@ -24,71 +24,69 @@ # torchax:oss-begin if getattr(jax.config, "jax_pjrt_client_create_options", None): - jax.config.update( - "jax_pjrt_client_create_options", - f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}", - ) + jax.config.update( + "jax_pjrt_client_create_options", + f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}", + ) # torchax:oss-end env = None def default_env(): - global env + global env - if 
env is None: - env = tensor.Environment() - return env + if env is None: + env = tensor.Environment() + return env def extract_jax(mod: torch.nn.Module, env=None): - """Returns a pytree of jax.ndarray and a jax callable.""" - if env is None: - env = default_env() - states = dict(mod.named_buffers()) - states.update(mod.named_parameters()) + """Returns a pytree of jax.ndarray and a jax callable.""" + if env is None: + env = default_env() + states = dict(mod.named_buffers()) + states.update(mod.named_parameters()) - states = env.t2j_copy(states) + states = env.t2j_copy(states) - # @jax.jit - def jax_func(states, inputs): - (states, inputs) = env.j2t_iso((states, inputs)) - with env: - res = torch.func.functional_call( - mod, states, inputs, tie_weights=False - ) - return env.t2j_iso(res) + # @jax.jit + def jax_func(states, inputs): + (states, inputs) = env.j2t_iso((states, inputs)) + with env: + res = torch.func.functional_call(mod, states, inputs, tie_weights=False) + return env.t2j_iso(res) - return states, jax_func + return states, jax_func def enable_globally(): - env = default_env().enable_torch_modes() - return env + env = default_env().enable_torch_modes() + return env def disable_globally(): - global env - default_env().disable_torch_modes() + global env + default_env().disable_torch_modes() @contextlib.contextmanager def disable_temporarily(): - prev = default_env().enabled - if prev: - disable_globally() - yield () - if prev: - enable_globally() + prev = default_env().enabled + if prev: + disable_globally() + yield () + if prev: + enable_globally() torch.utils.rename_privateuse1_backend("jax") unsupported_dtype = [torch.quint8] torch.utils.generate_methods_for_privateuse1_backend( - for_tensor=True, - for_module=True, - for_storage=True, - unsupported_dtype=unsupported_dtype, + for_tensor=True, + for_module=True, + for_storage=True, + unsupported_dtype=unsupported_dtype, ) import jax @@ -98,73 +96,71 @@ def disable_temporarily(): def enable_accuracy_mode(): - jax.config.update("jax_enable_x64", True) - jax.config.update("jax_default_matmul_precision", "highest") - default_env().config.internal_respect_torch_return_dtypes = True + jax.config.update("jax_enable_x64", True) + jax.config.update("jax_default_matmul_precision", "highest") + default_env().config.internal_respect_torch_return_dtypes = True def enable_performance_mode(): - jax.config.update("jax_enable_x64", False) - jax.config.update("jax_default_matmul_precision", "default") - default_env().config.internal_respect_torch_return_dtypes = False + jax.config.update("jax_enable_x64", False) + jax.config.update("jax_default_matmul_precision", "default") + default_env().config.internal_respect_torch_return_dtypes = False @dataclasses.dataclass class CompileOptions: - # only valid if compiling nn.Module - methods_to_compile: List[str] = dataclasses.field( - default_factory=lambda: ["forward"] - ) - jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - mode: str = "jax" # or dynamo or export + # only valid if compiling nn.Module + methods_to_compile: List[str] = dataclasses.field( + default_factory=lambda: ["forward"] + ) + jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + mode: str = "jax" # or dynamo or export def compile(fn, options: Optional[CompileOptions] = None): - options = options or CompileOptions() - if options.mode == "jax": - from torchax import interop - - if isinstance(fn, torch.nn.Module): - module = interop.JittableModule( - fn, extra_jit_args=options.jax_jit_kwargs - 
) - for n in options.methods_to_compile: - module.make_jitted(n) - return module - else: - return interop.jax_jit(fn) - elif options.mode == "dynamo": - raise RuntimeError("dynamo mode is not supported yet") - elif options.mode == "export": - raise RuntimeError("export mode is not supported yet") + options = options or CompileOptions() + if options.mode == "jax": + from torchax import interop + + if isinstance(fn, torch.nn.Module): + module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs) + for n in options.methods_to_compile: + module.make_jitted(n) + return module + else: + return interop.jax_jit(fn) + elif options.mode == "dynamo": + raise RuntimeError("dynamo mode is not supported yet") + elif options.mode == "export": + raise RuntimeError("export mode is not supported yet") @contextmanager def jax_device(target_device: str, env: tensor.Environment | None = None): - """ - to("jax") cannot differentiate the device/platform (cpu vs tpu). - Use this context manager to control jax array's storage device + """ + to("jax") cannot differentiate the device/platform (cpu vs tpu). + Use this context manager to control jax array's storage device - Examples: + Examples: - a = torch.ones(3, 3) + a = torch.ones(3, 3) - with jax_device("cpu"): - b = a.to("jax") + with jax_device("cpu"): + b = a.to("jax") - with jax_device("tpu"): - c = a.to("jax") + with jax_device("tpu"): + c = a.to("jax") - with jax_device("tpu"): - c = b.to("jax") + with jax_device("tpu"): + c = b.to("jax") - """ - if env is None: - env = default_env() + """ + if env is None: + env = default_env() - prev_target_device = env.target_device - try: - env.target_device = target_device - yield env - finally: - env.target_device = prev_target_device + prev_target_device = env.target_device + try: + env.target_device = target_device + yield env + finally: + env.target_device = prev_target_device diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index 465c2fff2fb4..6e38585acbd2 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -28,179 +28,179 @@ # promote, // Run in the widest dtype among several args. 
# }; class CastPolicy(enum.Enum): - LOWER_PRECISION_FP = 0 - FP32 = 1 - FP32_SET_OPT_DTYPE = 2 - FP32_APPEND_DTYPE = 3 - PROMOTE = 4 + LOWER_PRECISION_FP = 0 + FP32 = 1 + FP32_SET_OPT_DTYPE = 2 + FP32_APPEND_DTYPE = 3 + PROMOTE = 4 def execute_policy(policy, args, kwargs, target_lower_fp): - def is_float(a): - return isinstance(a, torch.Tensor) and a.is_floating_point() + def is_float(a): + return isinstance(a, torch.Tensor) and a.is_floating_point() - match policy: - case CastPolicy.LOWER_PRECISION_FP: - return pytree.tree_map_only( - is_float, lambda a: a.to(target_lower_fp), (args, kwargs) - ) - case CastPolicy.FP32: - return pytree.tree_map_only( - is_float, lambda a: a.to(torch.float32), (args, kwargs) - ) - case CastPolicy.PROMOTE: - dtypes = set(a.dtype for a in args) - widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] - return pytree.tree_map_only( - is_float, lambda a: a.to(widest), (args, kwargs) - ) - case _: - raise AssertionError(f"Policy {policy} not implemented yet.") + match policy: + case CastPolicy.LOWER_PRECISION_FP: + return pytree.tree_map_only( + is_float, lambda a: a.to(target_lower_fp), (args, kwargs) + ) + case CastPolicy.FP32: + return pytree.tree_map_only( + is_float, lambda a: a.to(torch.float32), (args, kwargs) + ) + case CastPolicy.PROMOTE: + dtypes = set(a.dtype for a in args) + widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] + return pytree.tree_map_only( + is_float, lambda a: a.to(widest), (args, kwargs) + ) + case _: + raise AssertionError(f"Policy {policy} not implemented yet.") @contextlib.contextmanager def autocast(device, dtype=torch.bfloat16, env=None): - del device - if env is None: - import torchax + del device + if env is None: + import torchax - env = torchax.default_env() - env.autocast_dtype, old = dtype, env.autocast_dtype - yield - env.autocast_dtype = old + env = torchax.default_env() + env.autocast_dtype, old = dtype, env.autocast_dtype + yield + env.autocast_dtype = old # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327 autocast_policy = { - torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.prelu.default: 
CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP, - # fp32 cast policy - torch.ops.aten.avg_pool3d.default: CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32, - torch.ops.aten.grid_sampler.default: CastPolicy.FP32, - torch.ops.aten.polar.default: CastPolicy.FP32, - torch.ops.aten.prod.default: CastPolicy.FP32, - torch.ops.aten.prod.dim_int: CastPolicy.FP32, - torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32, - torch.ops.aten.quantile.default: CastPolicy.FP32, - torch.ops.aten.quantile.scalar: CastPolicy.FP32, - torch.ops.aten.nanquantile.default: CastPolicy.FP32, - torch.ops.aten.nanquantile.scalar: CastPolicy.FP32, - torch.ops.aten.stft.default: CastPolicy.FP32, - torch.ops.aten.stft.center: CastPolicy.FP32, - torch.ops.aten.cdist.default: CastPolicy.FP32, - torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32, - torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32, - torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32, - torch.ops.aten.trace.default: CastPolicy.FP32, - torch.ops.aten.view_as_complex.default: CastPolicy.FP32, - torch.ops.aten.cholesky.default: CastPolicy.FP32, - torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32, - torch.ops.aten.cholesky_solve.default: CastPolicy.FP32, - torch.ops.aten.inverse.default: CastPolicy.FP32, - torch.ops.aten.lu_solve.default: CastPolicy.FP32, - torch.ops.aten.orgqr.default: CastPolicy.FP32, - torch.ops.aten.ormqr.default: CastPolicy.FP32, - torch.ops.aten.pinverse.default: CastPolicy.FP32, - torch.ops.aten.max_pool3d.default: CastPolicy.FP32, - torch.ops.aten.max_unpool2d.default: CastPolicy.FP32, - torch.ops.aten.max_unpool3d.default: CastPolicy.FP32, - torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32, - torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32, - torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32, - torch.ops.aten.replication_pad1d.default: CastPolicy.FP32, - torch.ops.aten.replication_pad2d.default: CastPolicy.FP32, - torch.ops.aten.replication_pad3d.default: CastPolicy.FP32, - torch.ops.aten.mse_loss.default: CastPolicy.FP32, - torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32, - torch.ops.aten.nll_loss.default: CastPolicy.FP32, - torch.ops.aten.nll_loss2d.default: CastPolicy.FP32, - torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32, - torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32, - torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32, - torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32, - torch.ops.aten.l1_loss.default: CastPolicy.FP32, - torch.ops.aten.huber_loss.default: CastPolicy.FP32, - torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32, - torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32, - torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32, - torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32, - torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32, - torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32, - torch.ops.aten.kl_div.default: CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32, - torch.ops.aten.fft_fft.default: CastPolicy.FP32, - torch.ops.aten.fft_ifft.default: CastPolicy.FP32, - torch.ops.aten.fft_fft2.default: CastPolicy.FP32, - torch.ops.aten.fft_ifft2.default: CastPolicy.FP32, - 
torch.ops.aten.fft_fftn.default: CastPolicy.FP32, - torch.ops.aten.fft_ifftn.default: CastPolicy.FP32, - torch.ops.aten.fft_rfft.default: CastPolicy.FP32, - torch.ops.aten.fft_irfft.default: CastPolicy.FP32, - torch.ops.aten.fft_rfft2.default: CastPolicy.FP32, - torch.ops.aten.fft_irfft2.default: CastPolicy.FP32, - torch.ops.aten.fft_rfftn.default: CastPolicy.FP32, - torch.ops.aten.fft_irfftn.default: CastPolicy.FP32, - torch.ops.aten.fft_hfft.default: CastPolicy.FP32, - torch.ops.aten.fft_ihfft.default: CastPolicy.FP32, - torch.ops.aten.linalg_cond.default: CastPolicy.FP32, - torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32, - torch.ops.aten.linalg_solve.default: CastPolicy.FP32, - torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32, - torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32, - torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32, - torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32, - torch.ops.aten.linalg_inv.default: CastPolicy.FP32, - torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32, - torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32, - torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32, - torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32, - torch.ops.aten.geqrf.default: CastPolicy.FP32, - torch.ops.aten._lu_with_info.default: CastPolicy.FP32, - torch.ops.aten.qr.default: CastPolicy.FP32, - torch.ops.aten.svd.default: CastPolicy.FP32, - torch.ops.aten.triangular_solve.default: CastPolicy.FP32, - torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32, - torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32, - torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32, - torch.ops.aten.linalg_qr.default: CastPolicy.FP32, - torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32, - torch.ops.aten.linalg_svd.default: CastPolicy.FP32, - torch.ops.aten.linalg_eig.default: CastPolicy.FP32, - torch.ops.aten.linalg_eigh.default: CastPolicy.FP32, - torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32, - torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32, - # promote - torch.ops.aten.stack.default: CastPolicy.PROMOTE, - torch.ops.aten.cat.default: CastPolicy.PROMOTE, - torch.ops.aten.index_copy.default: CastPolicy.PROMOTE, - torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE, + torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linear.default: 
CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP, + # fp32 cast policy + torch.ops.aten.avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler.default: CastPolicy.FP32, + torch.ops.aten.polar.default: CastPolicy.FP32, + torch.ops.aten.prod.default: CastPolicy.FP32, + torch.ops.aten.prod.dim_int: CastPolicy.FP32, + torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32, + torch.ops.aten.quantile.default: CastPolicy.FP32, + torch.ops.aten.quantile.scalar: CastPolicy.FP32, + torch.ops.aten.nanquantile.default: CastPolicy.FP32, + torch.ops.aten.nanquantile.scalar: CastPolicy.FP32, + torch.ops.aten.stft.default: CastPolicy.FP32, + torch.ops.aten.stft.center: CastPolicy.FP32, + torch.ops.aten.cdist.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32, + torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32, + torch.ops.aten.trace.default: CastPolicy.FP32, + torch.ops.aten.view_as_complex.default: CastPolicy.FP32, + torch.ops.aten.cholesky.default: CastPolicy.FP32, + torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32, + torch.ops.aten.cholesky_solve.default: CastPolicy.FP32, + torch.ops.aten.inverse.default: CastPolicy.FP32, + torch.ops.aten.lu_solve.default: CastPolicy.FP32, + torch.ops.aten.orgqr.default: CastPolicy.FP32, + torch.ops.aten.ormqr.default: CastPolicy.FP32, + torch.ops.aten.pinverse.default: CastPolicy.FP32, + torch.ops.aten.max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool2d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad1d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad3d.default: CastPolicy.FP32, + torch.ops.aten.mse_loss.default: CastPolicy.FP32, + torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss2d.default: CastPolicy.FP32, + torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32, + torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32, + torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32, + torch.ops.aten.l1_loss.default: CastPolicy.FP32, + torch.ops.aten.huber_loss.default: CastPolicy.FP32, + torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32, + torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32, + 
torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32, + torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32, + torch.ops.aten.kl_div.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32, + torch.ops.aten.fft_fft.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft.default: CastPolicy.FP32, + torch.ops.aten.fft_fft2.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft2.default: CastPolicy.FP32, + torch.ops.aten.fft_fftn.default: CastPolicy.FP32, + torch.ops.aten.fft_ifftn.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_rfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_irfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_hfft.default: CastPolicy.FP32, + torch.ops.aten.fft_ihfft.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32, + torch.ops.aten.linalg_solve.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32, + torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv.default: CastPolicy.FP32, + torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32, + torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32, + torch.ops.aten.geqrf.default: CastPolicy.FP32, + torch.ops.aten._lu_with_info.default: CastPolicy.FP32, + torch.ops.aten.qr.default: CastPolicy.FP32, + torch.ops.aten.svd.default: CastPolicy.FP32, + torch.ops.aten.triangular_solve.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32, + torch.ops.aten.linalg_qr.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32, + torch.ops.aten.linalg_svd.default: CastPolicy.FP32, + torch.ops.aten.linalg_eig.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigh.default: CastPolicy.FP32, + torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32, + # promote + torch.ops.aten.stack.default: CastPolicy.PROMOTE, + torch.ops.aten.cat.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE, } diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index 336563add738..9370625e85cb 100644 --- a/torchax/torchax/config.py +++ b/torchax/torchax/config.py @@ -3,24 +3,24 @@ @dataclasses.dataclass class Configuration: - debug_print_each_op: bool = False - debug_accuracy_for_each_op: bool = False - 
debug_mixed_tensor: bool = False - debug_print_each_op_operands: bool = False + debug_print_each_op: bool = False + debug_accuracy_for_each_op: bool = False + debug_mixed_tensor: bool = False + debug_print_each_op_operands: bool = False - use_int32_for_index: bool = False + use_int32_for_index: bool = False - # If true, we will convert Views into torchax.Tensors eagerly - force_materialize_views: bool = False + # If true, we will convert Views into torchax.Tensors eagerly + force_materialize_views: bool = False - # Use DLPack for converting jax.Arrays <-> and torch.Tensor - use_dlpack_for_data_conversion: bool = False + # Use DLPack for converting jax.Arrays <-> and torch.Tensor + use_dlpack_for_data_conversion: bool = False - # Flash attention - use_tpu_flash_attention: bool = False - shmap_flash_attention: bool = False + # Flash attention + use_tpu_flash_attention: bool = False + shmap_flash_attention: bool = False - # device - treat_cuda_as_jax_device: bool = True - use_torch_native_for_cpu_tensor: bool = True - internal_respect_torch_return_dtypes: bool = False + # device + treat_cuda_as_jax_device: bool = True + use_torch_native_for_cpu_tensor: bool = True + internal_respect_torch_return_dtypes: bool = False diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index 369000661bb5..f6d0a2891b7e 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -28,23 +28,23 @@ def _try_register(op, impl): - try: - register_decomposition(op)(impl) - except: - pass + try: + register_decomposition(op)(impl) + except: + pass @out_wrapper() def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return middle - 1 - (middle - 1 - dim_idx.abs()).abs() - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) _try_register(aten.reflection_pad1d, _reflection_pad) @@ -54,50 +54,48 @@ def idx(left, middle, right): @out_wrapper() def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return torch.clamp(dim_idx, 0, middle - 1) - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) decomp.global_decomposition_table["post_autograd"][ - aten.replication_pad2d.default + aten.replication_pad2d.default ] = _replication_pad def _reflection_or_replication_pad( - a: Tensor, - padding: Tuple[int, ...], - idx_fn: Callable[[int, int, int], Tensor], + a: Tensor, + padding: Tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], ) -> Tensor: - dim = len(padding) // 2 - torch._check( - a.dim() in (dim + 1, dim + 2), - lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", - ) - inp_shape = a.shape[-dim:] - nc_dim = a.dim() - dim + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim - padding_left = [padding[2 * 
(dim - 1 - i)] for i in range(dim)] - padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] - result = a - for i in range(dim): - idx: List[Any] = [None] * result.dim() - idx[i + nc_dim] = idx_fn( - padding_left[i], inp_shape[i], padding_right[i] - ) - result = aten._unsafe_index(result, idx) + result = a + for i in range(dim): + idx: List[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) - # convert output to correct memory format, if necessary - memory_format = utils.suggest_memory_format(result) - result = result.contiguous(memory_format=memory_format) - return result + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result _try_register(aten.replication_pad1d, _replication_pad) @@ -105,24 +103,24 @@ def _reflection_or_replication_pad( def bernoulli(self, *, generator=None): - return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) + return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) _try_register(aten.bernoulli.default, bernoulli) def rand_like(self, **kwargs): - dtype = kwargs.get("dtype", self.dtype) - return torch.rand(self.shape, dtype=dtype) + dtype = kwargs.get("dtype", self.dtype) + return torch.rand(self.shape, dtype=dtype) def channel_shuffle(self, groups): - batchsize, channels, height, width = self.shape - channels_per_group = channels // groups - self = self.reshape(batchsize, groups, channels_per_group, height, width) - self = self.transpose(1, 2) - self = self.reshape(batchsize, channels, height, width) - return self + batchsize, channels, height, width = self.shape + channels_per_group = channels // groups + self = self.reshape(batchsize, groups, channels_per_group, height, width) + self = self.transpose(1, 2) + self = self.reshape(batchsize, channels, height, width) + return self _try_register(aten.channel_shuffle, channel_shuffle) @@ -132,7 +130,7 @@ def channel_shuffle(self, groups): def bernoulli_float(self, p=0.5): - return self.bernoulli_(p) + return self.bernoulli_(p) _try_register(aten.bernoulli_.float, bernoulli_float) @@ -140,642 +138,640 @@ def bernoulli_float(self, p=0.5): def _sum_tensors(ts) -> Tensor: - return functools.reduce(torch.add, ts) + return functools.reduce(torch.add, ts) @register_decomposition(aten.grid_sampler_3d) def _grid_sampler_3d( - a: torch.Tensor, - grid: torch.Tensor, - interpolation_mode: int = 0, - padding_mode: int = 0, - align_corners: bool = False, + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, ) -> Tensor: - """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 - - The above implement the 2d case. - """ - _expand_grid = False - torch._check( - interpolation_mode in (0, 1), - lambda: f"Invalid interpolation mode {interpolation_mode}", + """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 + + The above implement the 2d case. 
+ """ + _expand_grid = False + torch._check( + interpolation_mode in (0, 1), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), + lambda: f"Invalid padding mode {padding_mode}", + ) + + # a is 5D: [B, C, D, H, W] + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates( + coords: Tensor, twice_low: int, twice_high: int + ) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra ) - torch._check( - padding_mode in (0, 1, 2), - lambda: f"Invalid padding mode {padding_mode}", + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iD, iH, iW = a.shape + _, oD, oH, oW, three = grid.shape + assert three == 3, "Last dim of grid must be 3. 
got {}".format(three) + + def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: + xcheck = torch.logical_and(0 <= xs, xs < iW) + ycheck = torch.logical_and(0 <= ys, ys < iH) + zcheck = torch.logical_and(0 <= zs, zs < iD) + return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck)) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) + + def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): + cond = in_bounds_cond(xs, ys, zs) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oD, oH, oW) + for t in ( + xs.to(dtype=torch.int64), + ys.to(dtype=torch.int64), + zs.to(dtype=torch.int64), + ws, + ) + ) + + def get_summand( + ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w + ) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) + return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + d = grid[..., 2] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + id_ = compute_source_index(d, iD) + + ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() + ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf + ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf + ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf + ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 + ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 + ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 + ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 + + w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) + w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_) + w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) + w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) + w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) + w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) + w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) + w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) + + return _sum_tensors( + get_summand(ix, iy, id_, w) + for (ix, iy, id_, w) in ( + (ix_nwf, iy_nwf, id_nwf, w_nwf), + (ix_nef, iy_nef, id_nef, w_nef), + (ix_swf, iy_swf, id_swf, w_swf), + (ix_sef, iy_sef, id_sef, w_sef), + (ix_nwb, iy_nwb, id_nwb, w_nwb), + (ix_neb, iy_neb, id_neb, w_neb), + (ix_swb, iy_swb, id_swb, w_swb), + (ix_seb, iy_seb, id_seb, w_seb), + ) ) + else: # interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + iz = compute_source_index(d, iD) + + ix_nearest = ix.round() + iy_nearest = iy.round() + iz_nearest = iz.round() - # a is 5D: [B, C, D, H, W] - - def unnormalize(coords: Tensor, size: int) -> Tensor: - # Rescale coordinates from [-1, 1] to: - # [0, size - 1] if align_corners is True - # [-.5, size -.5] if align_corners is False - mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) - ofs = size * 0.5 - 0.5 - return coords * mul + ofs - - # Reflects coordinates until they fall between low and high (inclusive). - # The bounds are passed as twice their value so that half-integer values - # can be represented as ints. 
- def reflect_coordinates( - coords: Tensor, twice_low: int, twice_high: int - ) -> Tensor: - if twice_low == twice_high: - return torch.zeros_like(coords) - coords_min = twice_low / 2 - coords_span = (twice_high - twice_low) / 2 - coords2 = (coords - coords_min).abs() - extra = torch.fmod(coords2, coords_span) - flips = (coords2 / coords_span).floor().to(dtype=torch.int8) - return torch.where( - flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra - ) - - def compute_coordinates(coords: Tensor, size: int) -> Tensor: - if padding_mode == 0: # Zero - return coords - elif padding_mode == 1: # Borders - return torch.clamp(coords, 0, size - 1) - else: # padding_mode == 2, Reflection - if align_corners: - coords_reflected = reflect_coordinates( - coords, 0, 2 * (size - 1) - ) - else: - coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) - return torch.clamp(coords_reflected, 0, size - 1) - - def compute_source_index(coords: Tensor, size: int) -> Tensor: - coords_un = unnormalize(coords, size) - return compute_coordinates(coords_un, size) - - N, C, iD, iH, iW = a.shape - _, oD, oH, oW, three = grid.shape - assert three == 3, "Last dim of grid must be 3. got {}".format(three) - - def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: - xcheck = torch.logical_and(0 <= xs, xs < iW) - ycheck = torch.logical_and(0 <= ys, ys < iH) - zcheck = torch.logical_and(0 <= zs, zs < iD) - return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck)) - - N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) - C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) - - def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): - cond = in_bounds_cond(xs, ys, zs) - # To clip to inside valid coordinates, we map the coordinates - # to (x, y) = (0, 0) and also set the weight to 0 - # We also change the shape of the tensor to the appropriate one for - # broadcasting with N_idx, C_idx for the purposes of advanced indexing - c = C if _expand_grid else 1 - return tuple( - torch.where(cond, t, 0).view(N, c, oD, oH, oW) - for t in ( - xs.to(dtype=torch.int64), - ys.to(dtype=torch.int64), - zs.to(dtype=torch.int64), - ws, - ) - ) - - def get_summand( - ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w - ) -> Tensor: - # Perform clipping, index into input tensor and multiply by weight - idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) - return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ - - x = grid[..., 0] - y = grid[..., 1] - d = grid[..., 2] - - if interpolation_mode == 0: # Bilinear - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - id_ = compute_source_index(d, iD) - - ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() - ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf - ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf - ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf - ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 - ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 - ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 - ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 - - w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) - w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_) - w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) - w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) - w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) - w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) - w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) - w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - 
id_nwf) - - return _sum_tensors( - get_summand(ix, iy, id_, w) - for (ix, iy, id_, w) in ( - (ix_nwf, iy_nwf, id_nwf, w_nwf), - (ix_nef, iy_nef, id_nef, w_nef), - (ix_swf, iy_swf, id_swf, w_swf), - (ix_sef, iy_sef, id_sef, w_sef), - (ix_nwb, iy_nwb, id_nwb, w_nwb), - (ix_neb, iy_neb, id_neb, w_neb), - (ix_swb, iy_swb, id_swb, w_swb), - (ix_seb, iy_seb, id_seb, w_seb), - ) - ) - else: # interpolation_mode == 1: # Nearest - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - iz = compute_source_index(d, iD) - - ix_nearest = ix.round() - iy_nearest = iy.round() - iz_nearest = iz.round() - - return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) + return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) DECOMPOSITIONS = decomp.get_decompositions([ - torch.ops.aten.upsample_bicubic2d, - torch.ops.aten.upsample_nearest1d, - torch.ops.aten.upsample_nearest2d, - torch.ops.aten.upsample_nearest3d, - torch.ops.aten._upsample_nearest_exact1d, - torch.ops.aten._upsample_nearest_exact2d, - torch.ops.aten._upsample_nearest_exact3d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._native_batch_norm_legit_functional.default, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.grid_sampler_3d, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, - torch.ops.aten.bernoulli, - torch.ops.aten.rand_like, - torch.ops.aten._batch_norm_with_update, - torch.ops.aten.channel_shuffle, - torch.ops.aten.nll_loss2d_forward, - torch.ops.aten.nll_loss2d_backward, - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, - torch.ops.aten.log_normal, - torch.ops.aten.addcdiv.default, - torch.ops.aten.addcdiv.out, - torch.ops.aten.addcdiv_.default, - torch.ops.aten.addcmul.default, - torch.ops.aten.addcmul.out, - torch.ops.aten.addcmul_.default, - torch.ops.aten.addr.default, - torch.ops.aten.addr.out, - torch.ops.aten.affine_grid_generator.default, - torch.ops.aten.affine_grid_generator.out, - torch.ops.aten.alias_copy.default, - torch.ops.aten.alias_copy.out, - torch.ops.aten.all.default, - torch.ops.aten.all.dim, - torch.ops.aten.all.dims, - torch.ops.aten.all.out, - torch.ops.aten.all.dims_out, - torch.ops.aten.all.all_out, - torch.ops.aten.all.dimname, - torch.ops.aten.all.dimname_out, - torch.ops.aten.aminmax.default, - torch.ops.aten.aminmax.out, - torch.ops.aten.arange.default, - torch.ops.aten.arange.start, - torch.ops.aten.baddbmm.default, - torch.ops.aten.baddbmm.out, - torch.ops.aten.binary_cross_entropy.default, - torch.ops.aten.binary_cross_entropy.out, - torch.ops.aten.binary_cross_entropy_backward.default, - torch.ops.aten.binary_cross_entropy_backward.grad_input, - torch.ops.aten.binary_cross_entropy_with_logits.default, - torch.ops.aten.binary_cross_entropy_with_logits.out, - torch.ops.aten.block_diag.default, - torch.ops.aten.block_diag.out, - torch.ops.aten.celu.default, - torch.ops.aten.celu.out, - torch.ops.aten.celu_.default, - torch.ops.aten.channel_shuffle.default, - torch.ops.aten.channel_shuffle.out, - torch.ops.aten.clamp_max.default, - torch.ops.aten.clamp_max.Tensor, - torch.ops.aten.clamp_max.out, - torch.ops.aten.clamp_max.Tensor_out, - torch.ops.aten.clamp_min.default, - torch.ops.aten.clamp_min.Tensor, - torch.ops.aten.clamp_min.out, - 
torch.ops.aten.clamp_min.Tensor_out, - torch.ops.aten.col2im.default, - torch.ops.aten.col2im.out, - torch.ops.aten.count_nonzero.dim_IntList, - torch.ops.aten.count_nonzero.dim_IntList_out, - torch.ops.aten.count_nonzero.default, - torch.ops.aten.count_nonzero.out, - torch.ops.aten.linalg_cross.default, - torch.ops.aten.linalg_cross.out, - torch.ops.aten.cudnn_batch_norm.default, - torch.ops.aten.cudnn_batch_norm.out, - torch.ops.aten.cudnn_batch_norm_backward.default, - torch.ops.aten.cudnn_batch_norm_backward.out, - torch.ops.aten.miopen_batch_norm_backward.default, - torch.ops.aten.miopen_batch_norm_backward.out, - torch.ops.aten.deg2rad.default, - torch.ops.aten.deg2rad.out, - torch.ops.aten.deg2rad_.default, - torch.ops.aten.detach.default, - torch.ops.aten.diag_embed.default, - torch.ops.aten.diag_embed.out, - torch.ops.aten.diagonal_backward.default, - torch.ops.aten.diagonal_backward.out, - torch.ops.aten.dot.default, - torch.ops.aten.dot.out, - torch.ops.aten.vdot.default, - torch.ops.aten.vdot.out, - torch.ops.aten.elu.default, - torch.ops.aten.elu.out, - torch.ops.aten.elu_.default, - torch.ops.aten.elu_backward.default, - torch.ops.aten.elu_backward.grad_input, - torch.ops.aten.embedding_dense_backward.default, - torch.ops.aten.embedding_dense_backward.out, - torch.ops.aten.empty_like.default, - torch.ops.aten.empty_like.out, - torch.ops.aten._euclidean_dist.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand_copy.out, - torch.ops.aten.eye.default, - torch.ops.aten.eye.m, - torch.ops.aten.eye.out, - torch.ops.aten.eye.m_out, - torch.ops.aten.fill.Scalar, - torch.ops.aten.fill.Tensor, - torch.ops.aten.fill_.Scalar, - torch.ops.aten.fill_.Tensor, - torch.ops.aten.floor_divide.default, - torch.ops.aten.floor_divide.Scalar, - torch.ops.aten.floor_divide.out, - torch.ops.aten.floor_divide.Scalar_out, - torch.ops.aten.frac.default, - torch.ops.aten.frac.out, - torch.ops.aten.frac_.default, - torch.ops.aten.gelu_.default, - torch.ops.aten.gelu_backward.default, - torch.ops.aten.gelu_backward.grad_input, - torch.ops.aten.glu.default, - torch.ops.aten.glu.out, - torch.ops.aten.glu_backward.default, - torch.ops.aten.glu_backward.grad_input, - torch.ops.aten.hardshrink.default, - torch.ops.aten.hardshrink.out, - torch.ops.aten.hardsigmoid.default, - torch.ops.aten.hardsigmoid.out, - torch.ops.aten.hardsigmoid_.default, - torch.ops.aten.hardsigmoid_backward.default, - torch.ops.aten.hardsigmoid_backward.grad_input, - torch.ops.aten.hardswish.default, - torch.ops.aten.hardswish.out, - torch.ops.aten.hardswish_.default, - torch.ops.aten.hardswish_backward.default, - torch.ops.aten.hardswish_backward.out, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.hardtanh_backward.default, - torch.ops.aten.hardtanh_backward.grad_input, - torch.ops.aten.heaviside.default, - torch.ops.aten.heaviside.out, - torch.ops.aten.heaviside_.default, - torch.ops.aten.huber_loss.default, - torch.ops.aten.huber_loss.out, - torch.ops.aten.huber_loss_backward.default, - torch.ops.aten.huber_loss_backward.out, - torch.ops.aten.im2col.default, - torch.ops.aten.im2col.out, - torch.ops.aten.index_add.default, - torch.ops.aten.index_add.out, - torch.ops.aten.index_add.dimname, - torch.ops.aten.index_add_.default, - torch.ops.aten.index_copy.default, - torch.ops.aten.index_copy.dimname, - torch.ops.aten.index_copy.out, - torch.ops.aten.index_copy_.default, - torch.ops.aten.index_copy_.dimname, - torch.ops.aten.index_fill.int_Tensor, - torch.ops.aten.index_fill.int_Scalar, - 
torch.ops.aten.index_fill.Dimname_Scalar, - torch.ops.aten.index_fill.Dimname_Tensor, - torch.ops.aten.index_fill.int_Scalar_out, - torch.ops.aten.index_fill.int_Tensor_out, - torch.ops.aten.index_fill_.int_Tensor, - torch.ops.aten.index_fill_.int_Scalar, - torch.ops.aten.index_fill_.Dimname_Scalar, - torch.ops.aten.index_fill_.Dimname_Tensor, - torch.ops.aten.isin.Tensor_Tensor, - torch.ops.aten.isin.Tensor_Tensor_out, - torch.ops.aten.isin.Tensor_Scalar, - torch.ops.aten.isin.Tensor_Scalar_out, - torch.ops.aten.isin.Scalar_Tensor, - torch.ops.aten.isin.Scalar_Tensor_out, - torch.ops.aten.isneginf.default, - torch.ops.aten.isneginf.out, - torch.ops.aten.isposinf.default, - torch.ops.aten.isposinf.out, - torch.ops.aten.leaky_relu_.default, - torch.ops.aten.leaky_relu_backward.default, - torch.ops.aten.leaky_relu_backward.grad_input, - torch.ops.aten.lerp.Scalar, - torch.ops.aten.lerp.Tensor, - torch.ops.aten.lerp.Scalar_out, - torch.ops.aten.lerp.Tensor_out, - torch.ops.aten.lerp_.Scalar, - torch.ops.aten.lerp_.Tensor, - torch.ops.aten.linspace.Tensor_Tensor, - torch.ops.aten.linspace.Tensor_Scalar, - torch.ops.aten.linspace.Scalar_Tensor, - torch.ops.aten.linspace.default, - torch.ops.aten.linspace.out, - torch.ops.aten.linspace.Tensor_Tensor_out, - torch.ops.aten.linspace.Tensor_Scalar_out, - torch.ops.aten.linspace.Scalar_Tensor_out, - torch.ops.aten.logaddexp.default, - torch.ops.aten.logaddexp.out, - torch.ops.aten.logaddexp2.default, - torch.ops.aten.logaddexp2.out, - torch.ops.aten.logit.default, - torch.ops.aten.logit.out, - torch.ops.aten.logit_.default, - torch.ops.aten.logit_backward.default, - torch.ops.aten.log_sigmoid_backward.default, - torch.ops.aten.log_sigmoid_backward.grad_input, - torch.ops.aten.log_sigmoid_forward.default, - torch.ops.aten.log_sigmoid_forward.output, - torch.ops.aten._log_softmax_backward_data.default, - torch.ops.aten._log_softmax_backward_data.out, - torch.ops.aten.logspace.Tensor_Tensor, - torch.ops.aten.logspace.Tensor_Scalar, - torch.ops.aten.logspace.Scalar_Tensor, - torch.ops.aten.logspace.default, - torch.ops.aten.logspace.out, - torch.ops.aten.logspace.Tensor_Tensor_out, - torch.ops.aten.logspace.Tensor_Scalar_out, - torch.ops.aten.logspace.Scalar_Tensor_out, - torch.ops.aten.logsumexp.default, - torch.ops.aten.masked_fill.Scalar, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar_out, - torch.ops.aten.masked_fill.Tensor_out, - torch.ops.aten.masked_fill_.Scalar, - torch.ops.aten.masked_fill_.Tensor, - torch.ops.aten.mish.default, - torch.ops.aten.mish.out, - torch.ops.aten.mish_.default, - torch.ops.aten.mse_loss.default, - torch.ops.aten.mse_loss.out, - torch.ops.aten.mse_loss_backward.default, - torch.ops.aten.mse_loss_backward.grad_input, - torch.ops.aten.multi_margin_loss.default, - torch.ops.aten.multi_margin_loss.out, - torch.ops.aten.multilabel_margin_loss_forward.default, - torch.ops.aten.multilabel_margin_loss_forward.output, - torch.ops.aten.mv.default, - torch.ops.aten.mv.out, - torch.ops.aten.mvlgamma.default, - torch.ops.aten.mvlgamma.out, - torch.ops.aten.mvlgamma_.default, - torch.ops.aten.nansum.default, - torch.ops.aten.nansum.out, - torch.ops.aten.nan_to_num.default, - torch.ops.aten.nan_to_num.out, - torch.ops.aten.nan_to_num_.default, - torch.ops.aten.native_batch_norm_backward.default, - torch.ops.aten.native_batch_norm_backward.out, - torch.ops.aten.native_dropout_backward.default, - torch.ops.aten.native_dropout_backward.out, - torch.ops.aten.native_group_norm_backward.default, - 
torch.ops.aten.native_group_norm_backward.out, - torch.ops.aten.native_layer_norm_backward.default, - torch.ops.aten.native_layer_norm_backward.out, - torch.ops.aten.new_empty.default, - torch.ops.aten.new_empty.out, - torch.ops.aten.new_full.default, - torch.ops.aten.new_full.out, - torch.ops.aten.new_ones.default, - torch.ops.aten.new_ones.out, - torch.ops.aten.new_zeros.default, - torch.ops.aten.new_zeros.out, - torch.ops.aten.nll_loss2d_forward.default, - torch.ops.aten.nll_loss2d_forward.output, - torch.ops.aten.nll_loss2d_backward.default, - torch.ops.aten.nll_loss2d_backward.grad_input, - torch.ops.aten.nll_loss_backward.default, - torch.ops.aten.nll_loss_backward.grad_input, - torch.ops.aten.nll_loss_forward.default, - torch.ops.aten.nll_loss_forward.output, - torch.ops.aten.norm.Scalar, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.norm.names_ScalarOpt_dim, - torch.ops.aten.norm.ScalarOpt_dim_dtype, - torch.ops.aten.norm.dtype_out, - torch.ops.aten.norm.out, - torch.ops.aten.norm.ScalarOpt_dtype, - torch.ops.aten.norm.ScalarOpt_dtype_out, - torch.ops.aten.norm.Scalar_out, - torch.ops.aten.norm.names_ScalarOpt_dim_dtype, - torch.ops.aten.norm.names_dtype_out, - torch.ops.aten.norm.names_out, - torch.ops.aten.ones.default, - torch.ops.aten.ones_like.default, - torch.ops.aten.ones_like.out, - torch.ops.aten.pixel_shuffle.default, - torch.ops.aten.pixel_shuffle.out, - torch.ops.aten.pixel_unshuffle.default, - torch.ops.aten.pixel_unshuffle.out, - torch.ops.aten._prelu_kernel.default, - torch.ops.aten._prelu_kernel_backward.default, - torch.ops.aten._reshape_alias.default, - torch.ops.aten.rad2deg.default, - torch.ops.aten.rad2deg.out, - torch.ops.aten.rad2deg_.default, - torch.ops.aten.reflection_pad1d.default, - torch.ops.aten.reflection_pad1d.out, - torch.ops.aten.reflection_pad1d_backward.default, - torch.ops.aten.reflection_pad1d_backward.grad_input, - torch.ops.aten.reflection_pad2d.default, - torch.ops.aten.reflection_pad2d.out, - torch.ops.aten.reflection_pad2d_backward.default, - torch.ops.aten.reflection_pad2d_backward.grad_input, - torch.ops.aten.reflection_pad3d.default, - torch.ops.aten.reflection_pad3d.out, - torch.ops.aten.reflection_pad3d_backward.default, - torch.ops.aten.reflection_pad3d_backward.grad_input, - torch.ops.aten.replication_pad1d.default, - torch.ops.aten.replication_pad1d.out, - torch.ops.aten.replication_pad2d.default, - torch.ops.aten.replication_pad2d.out, - torch.ops.aten.replication_pad3d.default, - torch.ops.aten.replication_pad3d.out, - torch.ops.aten.renorm.default, - torch.ops.aten.renorm.out, - torch.ops.aten.renorm_.default, - torch.ops.aten.resize_as.default, - torch.ops.aten.resize_as.out, - torch.ops.aten.roll.default, - torch.ops.aten.roll.out, - torch.ops.aten.rot90.default, - torch.ops.aten.rot90.out, - torch.ops.aten.rrelu_with_noise.default, - torch.ops.aten.rrelu_with_noise.out, - torch.ops.aten.rrelu_with_noise_.default, - torch.ops.aten.rsub.Tensor, - torch.ops.aten.rsub.Scalar, - torch.ops.aten.rsub.Tensor_out, - torch.ops.aten.rsub.Scalar_out, - torch.ops.aten._safe_softmax.default, - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, - torch.ops.aten.select_backward.default, - torch.ops.aten.select_backward.out, - torch.ops.aten.select_scatter.default, - torch.ops.aten.select_scatter.out, - torch.ops.aten.sgn.default, - torch.ops.aten.sgn.out, - torch.ops.aten.sgn_.default, - torch.ops.aten.sigmoid_backward.default, - torch.ops.aten.sigmoid_backward.grad_input, - torch.ops.aten.silu.default, - 
torch.ops.aten.silu.out, - torch.ops.aten.silu_.default, - torch.ops.aten.silu_backward.default, - torch.ops.aten.silu_backward.grad_input, - torch.ops.aten.sinc.default, - torch.ops.aten.sinc.out, - torch.ops.aten.sinc_.default, - torch.ops.aten.slice_backward.default, - torch.ops.aten.slice_backward.out, - torch.ops.aten.smooth_l1_loss.default, - torch.ops.aten.smooth_l1_loss.out, - torch.ops.aten.smooth_l1_loss_backward.default, - torch.ops.aten.smooth_l1_loss_backward.grad_input, - torch.ops.aten.soft_margin_loss.default, - torch.ops.aten.soft_margin_loss.out, - torch.ops.aten.soft_margin_loss_backward.default, - torch.ops.aten.soft_margin_loss_backward.grad_input, - torch.ops.aten._softmax_backward_data.default, - torch.ops.aten._softmax_backward_data.out, - torch.ops.aten.softplus.default, - torch.ops.aten.softplus.out, - torch.ops.aten.softplus_backward.default, - torch.ops.aten.softplus_backward.grad_input, - torch.ops.aten.softshrink.default, - torch.ops.aten.softshrink.out, - torch.ops.aten.special_entr.default, - torch.ops.aten.special_entr.out, - torch.ops.aten.special_log_ndtr.default, - torch.ops.aten.special_log_ndtr.out, - torch.ops.aten.special_xlog1py.default, - torch.ops.aten.special_xlog1py.other_scalar, - torch.ops.aten.special_xlog1py.self_scalar, - torch.ops.aten.special_xlog1py.out, - torch.ops.aten.special_xlog1py.self_scalar_out, - torch.ops.aten.special_xlog1py.other_scalar_out, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes_copy.default, - torch.ops.aten.split_with_sizes_copy.out, - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze.dim, - torch.ops.aten.std.default, - torch.ops.aten.std.dim, - torch.ops.aten.std.correction, - torch.ops.aten.std.names_dim, - torch.ops.aten.std.names_out, - torch.ops.aten.std.out, - torch.ops.aten.std.correction_out, - torch.ops.aten.std.correction_names, - torch.ops.aten.std.correction_names_out, - torch.ops.aten.std_mean.default, - torch.ops.aten.std_mean.dim, - torch.ops.aten.std_mean.correction, - torch.ops.aten.std_mean.names_dim, - torch.ops.aten.std_mean.correction_names, - torch.ops.aten.std_mean.correction_out, - torch.ops.aten.stack.default, - torch.ops.aten.stack.out, - torch.ops.aten.sum.default, - torch.ops.aten.sum.out, - torch.ops.aten.t.default, - torch.ops.aten.t_copy.out, - torch.ops.aten.t_copy.default, - torch.ops.aten.take.default, - torch.ops.aten.take.out, - torch.ops.aten.tanh_backward.default, - torch.ops.aten.tanh_backward.grad_input, - torch.ops.aten.threshold.default, - torch.ops.aten.threshold.out, - torch.ops.aten.threshold_.default, - torch.ops.aten.threshold_backward.default, - torch.ops.aten.threshold_backward.grad_input, - torch.ops.aten.trace.default, - torch.ops.aten.trace.out, - torch.ops.aten.transpose.int, - torch.ops.aten.tril.default, - torch.ops.aten.tril.out, - torch.ops.aten.tril_.default, - torch.ops.aten.triu.default, - torch.ops.aten.triu.out, - torch.ops.aten.triu_.default, - torch.ops.aten.unbind.int, - torch.ops.aten.unbind.Dimname, - torch.ops.aten.unfold_backward.default, - torch.ops.aten.unfold_backward.out, - torch.ops.aten.unfold_copy.default, - torch.ops.aten.unfold_copy.out, - torch.ops.aten._unsafe_index.Tensor, - torch.ops.aten._unsafe_index_put.default, - torch.ops.aten._unsafe_masked_index.default, - torch.ops.aten._unsafe_masked_index_put_accumulate.default, - torch.ops.aten.unsafe_split.Tensor, - torch.ops.aten.unsafe_split_with_sizes.default, - torch.ops.aten.unsqueeze_copy.out, - torch.ops.aten.unsqueeze_copy.default, - 
torch.ops.aten._unsafe_view.default, - torch.ops.aten._unsafe_view.out, - torch.ops.aten.upsample_linear1d.default, - torch.ops.aten.upsample_linear1d.out, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.upsample_bilinear2d.default, - torch.ops.aten.upsample_bilinear2d.out, - torch.ops.aten.upsample_trilinear3d.vec, - torch.ops.aten.upsample_trilinear3d.default, - torch.ops.aten.upsample_trilinear3d.out, - torch.ops.aten.xlogy.Tensor, - torch.ops.aten.xlogy.Scalar_Other, - torch.ops.aten.xlogy.Scalar_Self, - torch.ops.aten.xlogy.OutTensor, - torch.ops.aten.xlogy.OutScalar_Self, - torch.ops.aten.xlogy.OutScalar_Other, - torch.ops.aten.xlogy_.Tensor, - torch.ops.aten.xlogy_.Scalar_Other, - torch.ops.aten.zero.default, - torch.ops.aten.zero.out, - torch.ops.aten.zero_.default, - torch.ops.aten.zeros.default, - torch.ops.aten.zeros_like.default, - torch.ops.aten.zeros_like.out, - torch.ops.aten._chunk_cat.default, - torch.ops.aten._chunk_cat.out, - torch.ops.aten._weight_norm_interface.default, - torch.ops.aten._weight_norm_interface.out, + torch.ops.aten.upsample_bicubic2d, + torch.ops.aten.upsample_nearest1d, + torch.ops.aten.upsample_nearest2d, + torch.ops.aten.upsample_nearest3d, + torch.ops.aten._upsample_nearest_exact1d, + torch.ops.aten._upsample_nearest_exact2d, + torch.ops.aten._upsample_nearest_exact3d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._native_batch_norm_legit_functional.default, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.grid_sampler_3d, + torch.ops.aten.native_dropout, + torch.ops.aten.reflection_pad1d, + torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, + torch.ops.aten.bernoulli, + torch.ops.aten.rand_like, + torch.ops.aten._batch_norm_with_update, + torch.ops.aten.channel_shuffle, + torch.ops.aten.nll_loss2d_forward, + torch.ops.aten.nll_loss2d_backward, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, + torch.ops.aten.log_normal, + torch.ops.aten.addcdiv.default, + torch.ops.aten.addcdiv.out, + torch.ops.aten.addcdiv_.default, + torch.ops.aten.addcmul.default, + torch.ops.aten.addcmul.out, + torch.ops.aten.addcmul_.default, + torch.ops.aten.addr.default, + torch.ops.aten.addr.out, + torch.ops.aten.affine_grid_generator.default, + torch.ops.aten.affine_grid_generator.out, + torch.ops.aten.alias_copy.default, + torch.ops.aten.alias_copy.out, + torch.ops.aten.all.default, + torch.ops.aten.all.dim, + torch.ops.aten.all.dims, + torch.ops.aten.all.out, + torch.ops.aten.all.dims_out, + torch.ops.aten.all.all_out, + torch.ops.aten.all.dimname, + torch.ops.aten.all.dimname_out, + torch.ops.aten.aminmax.default, + torch.ops.aten.aminmax.out, + torch.ops.aten.arange.default, + torch.ops.aten.arange.start, + torch.ops.aten.baddbmm.default, + torch.ops.aten.baddbmm.out, + torch.ops.aten.binary_cross_entropy.default, + torch.ops.aten.binary_cross_entropy.out, + torch.ops.aten.binary_cross_entropy_backward.default, + torch.ops.aten.binary_cross_entropy_backward.grad_input, + torch.ops.aten.binary_cross_entropy_with_logits.default, + torch.ops.aten.binary_cross_entropy_with_logits.out, + torch.ops.aten.block_diag.default, + torch.ops.aten.block_diag.out, + torch.ops.aten.celu.default, + torch.ops.aten.celu.out, + torch.ops.aten.celu_.default, + torch.ops.aten.channel_shuffle.default, + 
torch.ops.aten.channel_shuffle.out, + torch.ops.aten.clamp_max.default, + torch.ops.aten.clamp_max.Tensor, + torch.ops.aten.clamp_max.out, + torch.ops.aten.clamp_max.Tensor_out, + torch.ops.aten.clamp_min.default, + torch.ops.aten.clamp_min.Tensor, + torch.ops.aten.clamp_min.out, + torch.ops.aten.clamp_min.Tensor_out, + torch.ops.aten.col2im.default, + torch.ops.aten.col2im.out, + torch.ops.aten.count_nonzero.dim_IntList, + torch.ops.aten.count_nonzero.dim_IntList_out, + torch.ops.aten.count_nonzero.default, + torch.ops.aten.count_nonzero.out, + torch.ops.aten.linalg_cross.default, + torch.ops.aten.linalg_cross.out, + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.cudnn_batch_norm.out, + torch.ops.aten.cudnn_batch_norm_backward.default, + torch.ops.aten.cudnn_batch_norm_backward.out, + torch.ops.aten.miopen_batch_norm_backward.default, + torch.ops.aten.miopen_batch_norm_backward.out, + torch.ops.aten.deg2rad.default, + torch.ops.aten.deg2rad.out, + torch.ops.aten.deg2rad_.default, + torch.ops.aten.detach.default, + torch.ops.aten.diag_embed.default, + torch.ops.aten.diag_embed.out, + torch.ops.aten.diagonal_backward.default, + torch.ops.aten.diagonal_backward.out, + torch.ops.aten.dot.default, + torch.ops.aten.dot.out, + torch.ops.aten.vdot.default, + torch.ops.aten.vdot.out, + torch.ops.aten.elu.default, + torch.ops.aten.elu.out, + torch.ops.aten.elu_.default, + torch.ops.aten.elu_backward.default, + torch.ops.aten.elu_backward.grad_input, + torch.ops.aten.embedding_dense_backward.default, + torch.ops.aten.embedding_dense_backward.out, + torch.ops.aten.empty_like.default, + torch.ops.aten.empty_like.out, + torch.ops.aten._euclidean_dist.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand_copy.out, + torch.ops.aten.eye.default, + torch.ops.aten.eye.m, + torch.ops.aten.eye.out, + torch.ops.aten.eye.m_out, + torch.ops.aten.fill.Scalar, + torch.ops.aten.fill.Tensor, + torch.ops.aten.fill_.Scalar, + torch.ops.aten.fill_.Tensor, + torch.ops.aten.floor_divide.default, + torch.ops.aten.floor_divide.Scalar, + torch.ops.aten.floor_divide.out, + torch.ops.aten.floor_divide.Scalar_out, + torch.ops.aten.frac.default, + torch.ops.aten.frac.out, + torch.ops.aten.frac_.default, + torch.ops.aten.gelu_.default, + torch.ops.aten.gelu_backward.default, + torch.ops.aten.gelu_backward.grad_input, + torch.ops.aten.glu.default, + torch.ops.aten.glu.out, + torch.ops.aten.glu_backward.default, + torch.ops.aten.glu_backward.grad_input, + torch.ops.aten.hardshrink.default, + torch.ops.aten.hardshrink.out, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid.out, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardsigmoid_backward.default, + torch.ops.aten.hardsigmoid_backward.grad_input, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish.out, + torch.ops.aten.hardswish_.default, + torch.ops.aten.hardswish_backward.default, + torch.ops.aten.hardswish_backward.out, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.hardtanh_backward.default, + torch.ops.aten.hardtanh_backward.grad_input, + torch.ops.aten.heaviside.default, + torch.ops.aten.heaviside.out, + torch.ops.aten.heaviside_.default, + torch.ops.aten.huber_loss.default, + torch.ops.aten.huber_loss.out, + torch.ops.aten.huber_loss_backward.default, + torch.ops.aten.huber_loss_backward.out, + torch.ops.aten.im2col.default, + torch.ops.aten.im2col.out, + torch.ops.aten.index_add.default, + torch.ops.aten.index_add.out, + torch.ops.aten.index_add.dimname, + torch.ops.aten.index_add_.default, + 
torch.ops.aten.index_copy.default, + torch.ops.aten.index_copy.dimname, + torch.ops.aten.index_copy.out, + torch.ops.aten.index_copy_.default, + torch.ops.aten.index_copy_.dimname, + torch.ops.aten.index_fill.int_Tensor, + torch.ops.aten.index_fill.int_Scalar, + torch.ops.aten.index_fill.Dimname_Scalar, + torch.ops.aten.index_fill.Dimname_Tensor, + torch.ops.aten.index_fill.int_Scalar_out, + torch.ops.aten.index_fill.int_Tensor_out, + torch.ops.aten.index_fill_.int_Tensor, + torch.ops.aten.index_fill_.int_Scalar, + torch.ops.aten.index_fill_.Dimname_Scalar, + torch.ops.aten.index_fill_.Dimname_Tensor, + torch.ops.aten.isin.Tensor_Tensor, + torch.ops.aten.isin.Tensor_Tensor_out, + torch.ops.aten.isin.Tensor_Scalar, + torch.ops.aten.isin.Tensor_Scalar_out, + torch.ops.aten.isin.Scalar_Tensor, + torch.ops.aten.isin.Scalar_Tensor_out, + torch.ops.aten.isneginf.default, + torch.ops.aten.isneginf.out, + torch.ops.aten.isposinf.default, + torch.ops.aten.isposinf.out, + torch.ops.aten.leaky_relu_.default, + torch.ops.aten.leaky_relu_backward.default, + torch.ops.aten.leaky_relu_backward.grad_input, + torch.ops.aten.lerp.Scalar, + torch.ops.aten.lerp.Tensor, + torch.ops.aten.lerp.Scalar_out, + torch.ops.aten.lerp.Tensor_out, + torch.ops.aten.lerp_.Scalar, + torch.ops.aten.lerp_.Tensor, + torch.ops.aten.linspace.Tensor_Tensor, + torch.ops.aten.linspace.Tensor_Scalar, + torch.ops.aten.linspace.Scalar_Tensor, + torch.ops.aten.linspace.default, + torch.ops.aten.linspace.out, + torch.ops.aten.linspace.Tensor_Tensor_out, + torch.ops.aten.linspace.Tensor_Scalar_out, + torch.ops.aten.linspace.Scalar_Tensor_out, + torch.ops.aten.logaddexp.default, + torch.ops.aten.logaddexp.out, + torch.ops.aten.logaddexp2.default, + torch.ops.aten.logaddexp2.out, + torch.ops.aten.logit.default, + torch.ops.aten.logit.out, + torch.ops.aten.logit_.default, + torch.ops.aten.logit_backward.default, + torch.ops.aten.log_sigmoid_backward.default, + torch.ops.aten.log_sigmoid_backward.grad_input, + torch.ops.aten.log_sigmoid_forward.default, + torch.ops.aten.log_sigmoid_forward.output, + torch.ops.aten._log_softmax_backward_data.default, + torch.ops.aten._log_softmax_backward_data.out, + torch.ops.aten.logspace.Tensor_Tensor, + torch.ops.aten.logspace.Tensor_Scalar, + torch.ops.aten.logspace.Scalar_Tensor, + torch.ops.aten.logspace.default, + torch.ops.aten.logspace.out, + torch.ops.aten.logspace.Tensor_Tensor_out, + torch.ops.aten.logspace.Tensor_Scalar_out, + torch.ops.aten.logspace.Scalar_Tensor_out, + torch.ops.aten.logsumexp.default, + torch.ops.aten.masked_fill.Scalar, + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar_out, + torch.ops.aten.masked_fill.Tensor_out, + torch.ops.aten.masked_fill_.Scalar, + torch.ops.aten.masked_fill_.Tensor, + torch.ops.aten.mish.default, + torch.ops.aten.mish.out, + torch.ops.aten.mish_.default, + torch.ops.aten.mse_loss.default, + torch.ops.aten.mse_loss.out, + torch.ops.aten.mse_loss_backward.default, + torch.ops.aten.mse_loss_backward.grad_input, + torch.ops.aten.multi_margin_loss.default, + torch.ops.aten.multi_margin_loss.out, + torch.ops.aten.multilabel_margin_loss_forward.default, + torch.ops.aten.multilabel_margin_loss_forward.output, + torch.ops.aten.mv.default, + torch.ops.aten.mv.out, + torch.ops.aten.mvlgamma.default, + torch.ops.aten.mvlgamma.out, + torch.ops.aten.mvlgamma_.default, + torch.ops.aten.nansum.default, + torch.ops.aten.nansum.out, + torch.ops.aten.nan_to_num.default, + torch.ops.aten.nan_to_num.out, + torch.ops.aten.nan_to_num_.default, + 
torch.ops.aten.native_batch_norm_backward.default, + torch.ops.aten.native_batch_norm_backward.out, + torch.ops.aten.native_dropout_backward.default, + torch.ops.aten.native_dropout_backward.out, + torch.ops.aten.native_group_norm_backward.default, + torch.ops.aten.native_group_norm_backward.out, + torch.ops.aten.native_layer_norm_backward.default, + torch.ops.aten.native_layer_norm_backward.out, + torch.ops.aten.new_empty.default, + torch.ops.aten.new_empty.out, + torch.ops.aten.new_full.default, + torch.ops.aten.new_full.out, + torch.ops.aten.new_ones.default, + torch.ops.aten.new_ones.out, + torch.ops.aten.new_zeros.default, + torch.ops.aten.new_zeros.out, + torch.ops.aten.nll_loss2d_forward.default, + torch.ops.aten.nll_loss2d_forward.output, + torch.ops.aten.nll_loss2d_backward.default, + torch.ops.aten.nll_loss2d_backward.grad_input, + torch.ops.aten.nll_loss_backward.default, + torch.ops.aten.nll_loss_backward.grad_input, + torch.ops.aten.nll_loss_forward.default, + torch.ops.aten.nll_loss_forward.output, + torch.ops.aten.norm.Scalar, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.norm.names_ScalarOpt_dim, + torch.ops.aten.norm.ScalarOpt_dim_dtype, + torch.ops.aten.norm.dtype_out, + torch.ops.aten.norm.out, + torch.ops.aten.norm.ScalarOpt_dtype, + torch.ops.aten.norm.ScalarOpt_dtype_out, + torch.ops.aten.norm.Scalar_out, + torch.ops.aten.norm.names_ScalarOpt_dim_dtype, + torch.ops.aten.norm.names_dtype_out, + torch.ops.aten.norm.names_out, + torch.ops.aten.ones.default, + torch.ops.aten.ones_like.default, + torch.ops.aten.ones_like.out, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_shuffle.out, + torch.ops.aten.pixel_unshuffle.default, + torch.ops.aten.pixel_unshuffle.out, + torch.ops.aten._prelu_kernel.default, + torch.ops.aten._prelu_kernel_backward.default, + torch.ops.aten._reshape_alias.default, + torch.ops.aten.rad2deg.default, + torch.ops.aten.rad2deg.out, + torch.ops.aten.rad2deg_.default, + torch.ops.aten.reflection_pad1d.default, + torch.ops.aten.reflection_pad1d.out, + torch.ops.aten.reflection_pad1d_backward.default, + torch.ops.aten.reflection_pad1d_backward.grad_input, + torch.ops.aten.reflection_pad2d.default, + torch.ops.aten.reflection_pad2d.out, + torch.ops.aten.reflection_pad2d_backward.default, + torch.ops.aten.reflection_pad2d_backward.grad_input, + torch.ops.aten.reflection_pad3d.default, + torch.ops.aten.reflection_pad3d.out, + torch.ops.aten.reflection_pad3d_backward.default, + torch.ops.aten.reflection_pad3d_backward.grad_input, + torch.ops.aten.replication_pad1d.default, + torch.ops.aten.replication_pad1d.out, + torch.ops.aten.replication_pad2d.default, + torch.ops.aten.replication_pad2d.out, + torch.ops.aten.replication_pad3d.default, + torch.ops.aten.replication_pad3d.out, + torch.ops.aten.renorm.default, + torch.ops.aten.renorm.out, + torch.ops.aten.renorm_.default, + torch.ops.aten.resize_as.default, + torch.ops.aten.resize_as.out, + torch.ops.aten.roll.default, + torch.ops.aten.roll.out, + torch.ops.aten.rot90.default, + torch.ops.aten.rot90.out, + torch.ops.aten.rrelu_with_noise.default, + torch.ops.aten.rrelu_with_noise.out, + torch.ops.aten.rrelu_with_noise_.default, + torch.ops.aten.rsub.Tensor, + torch.ops.aten.rsub.Scalar, + torch.ops.aten.rsub.Tensor_out, + torch.ops.aten.rsub.Scalar_out, + torch.ops.aten._safe_softmax.default, + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, + torch.ops.aten.select_backward.default, + torch.ops.aten.select_backward.out, + torch.ops.aten.select_scatter.default, + 
torch.ops.aten.select_scatter.out, + torch.ops.aten.sgn.default, + torch.ops.aten.sgn.out, + torch.ops.aten.sgn_.default, + torch.ops.aten.sigmoid_backward.default, + torch.ops.aten.sigmoid_backward.grad_input, + torch.ops.aten.silu.default, + torch.ops.aten.silu.out, + torch.ops.aten.silu_.default, + torch.ops.aten.silu_backward.default, + torch.ops.aten.silu_backward.grad_input, + torch.ops.aten.sinc.default, + torch.ops.aten.sinc.out, + torch.ops.aten.sinc_.default, + torch.ops.aten.slice_backward.default, + torch.ops.aten.slice_backward.out, + torch.ops.aten.smooth_l1_loss.default, + torch.ops.aten.smooth_l1_loss.out, + torch.ops.aten.smooth_l1_loss_backward.default, + torch.ops.aten.smooth_l1_loss_backward.grad_input, + torch.ops.aten.soft_margin_loss.default, + torch.ops.aten.soft_margin_loss.out, + torch.ops.aten.soft_margin_loss_backward.default, + torch.ops.aten.soft_margin_loss_backward.grad_input, + torch.ops.aten._softmax_backward_data.default, + torch.ops.aten._softmax_backward_data.out, + torch.ops.aten.softplus.default, + torch.ops.aten.softplus.out, + torch.ops.aten.softplus_backward.default, + torch.ops.aten.softplus_backward.grad_input, + torch.ops.aten.softshrink.default, + torch.ops.aten.softshrink.out, + torch.ops.aten.special_entr.default, + torch.ops.aten.special_entr.out, + torch.ops.aten.special_log_ndtr.default, + torch.ops.aten.special_log_ndtr.out, + torch.ops.aten.special_xlog1py.default, + torch.ops.aten.special_xlog1py.other_scalar, + torch.ops.aten.special_xlog1py.self_scalar, + torch.ops.aten.special_xlog1py.out, + torch.ops.aten.special_xlog1py.self_scalar_out, + torch.ops.aten.special_xlog1py.other_scalar_out, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes_copy.default, + torch.ops.aten.split_with_sizes_copy.out, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.std.default, + torch.ops.aten.std.dim, + torch.ops.aten.std.correction, + torch.ops.aten.std.names_dim, + torch.ops.aten.std.names_out, + torch.ops.aten.std.out, + torch.ops.aten.std.correction_out, + torch.ops.aten.std.correction_names, + torch.ops.aten.std.correction_names_out, + torch.ops.aten.std_mean.default, + torch.ops.aten.std_mean.dim, + torch.ops.aten.std_mean.correction, + torch.ops.aten.std_mean.names_dim, + torch.ops.aten.std_mean.correction_names, + torch.ops.aten.std_mean.correction_out, + torch.ops.aten.stack.default, + torch.ops.aten.stack.out, + torch.ops.aten.sum.default, + torch.ops.aten.sum.out, + torch.ops.aten.t.default, + torch.ops.aten.t_copy.out, + torch.ops.aten.t_copy.default, + torch.ops.aten.take.default, + torch.ops.aten.take.out, + torch.ops.aten.tanh_backward.default, + torch.ops.aten.tanh_backward.grad_input, + torch.ops.aten.threshold.default, + torch.ops.aten.threshold.out, + torch.ops.aten.threshold_.default, + torch.ops.aten.threshold_backward.default, + torch.ops.aten.threshold_backward.grad_input, + torch.ops.aten.trace.default, + torch.ops.aten.trace.out, + torch.ops.aten.transpose.int, + torch.ops.aten.tril.default, + torch.ops.aten.tril.out, + torch.ops.aten.tril_.default, + torch.ops.aten.triu.default, + torch.ops.aten.triu.out, + torch.ops.aten.triu_.default, + torch.ops.aten.unbind.int, + torch.ops.aten.unbind.Dimname, + torch.ops.aten.unfold_backward.default, + torch.ops.aten.unfold_backward.out, + torch.ops.aten.unfold_copy.default, + torch.ops.aten.unfold_copy.out, + torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten._unsafe_index_put.default, + torch.ops.aten._unsafe_masked_index.default, + 
torch.ops.aten._unsafe_masked_index_put_accumulate.default, + torch.ops.aten.unsafe_split.Tensor, + torch.ops.aten.unsafe_split_with_sizes.default, + torch.ops.aten.unsqueeze_copy.out, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten._unsafe_view.out, + torch.ops.aten.upsample_linear1d.default, + torch.ops.aten.upsample_linear1d.out, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_bilinear2d.default, + torch.ops.aten.upsample_bilinear2d.out, + torch.ops.aten.upsample_trilinear3d.vec, + torch.ops.aten.upsample_trilinear3d.default, + torch.ops.aten.upsample_trilinear3d.out, + torch.ops.aten.xlogy.Tensor, + torch.ops.aten.xlogy.Scalar_Other, + torch.ops.aten.xlogy.Scalar_Self, + torch.ops.aten.xlogy.OutTensor, + torch.ops.aten.xlogy.OutScalar_Self, + torch.ops.aten.xlogy.OutScalar_Other, + torch.ops.aten.xlogy_.Tensor, + torch.ops.aten.xlogy_.Scalar_Other, + torch.ops.aten.zero.default, + torch.ops.aten.zero.out, + torch.ops.aten.zero_.default, + torch.ops.aten.zeros.default, + torch.ops.aten.zeros_like.default, + torch.ops.aten.zeros_like.out, + torch.ops.aten._chunk_cat.default, + torch.ops.aten._chunk_cat.out, + torch.ops.aten._weight_norm_interface.default, + torch.ops.aten._weight_norm_interface.out, ]) MUTABLE_DECOMPOSITION = [ - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, ] diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py index 2df98eb41423..20fceaf06b43 100644 --- a/torchax/torchax/device_module.py +++ b/torchax/torchax/device_module.py @@ -1,26 +1,26 @@ def _is_in_bad_fork(): - return False + return False def manual_seed_all(seed): - pass + pass def device_count(): - return 1 + return 1 def get_rng_state(): - return [] + return [] def set_rng_state(new_state, device): - pass + pass def is_available(): - return True + return True def current_device(): - return 0 + return 0 diff --git a/torchax/torchax/distributed.py b/torchax/torchax/distributed.py index a7dd0fea04fe..b73bf202c633 100644 --- a/torchax/torchax/distributed.py +++ b/torchax/torchax/distributed.py @@ -30,225 +30,221 @@ class ProcessGroupJax(ProcessGroup): - """Distributed backend implemented with JAX.""" - - def __init__(self, prefix_store, rank, size, timeout): - super().__init__(rank, size) - self._group_name = None - - def getBackendName(self): - return "jax" - - # TODO(wcromar): why doesn't default group name setter work? 
- # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152 - def _set_group_name(self, name: str) -> None: - self._group_name = name - - @property - def group_name(self): - assert self._group_name - return self._group_name - - @staticmethod - def _work( - tensors: Union[ - torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]] - ], - ) -> dist.Work: - fut = torch.futures.Future() - fut.set_result(tensors) - return torch._C._distributed_c10d._create_work_from_future(fut) - - def _allgather_base( - self, - output: torch.Tensor, - input: torch.Tensor, - opts=..., - ) -> dist.Work: - assert isinstance(input, torchax.tensor.Tensor) - assert isinstance(output, torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_gather_tensor_inplace( - output, input, group=self - ) - return self._work(output) - - def allreduce( - self, - tensors: List[torch.Tensor], - opts: dist.AllreduceOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_reduce_inplace( - tensors[0], - torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ - opts.reduceOp.op - ], - self, - ) + """Distributed backend implemented with JAX.""" + + def __init__(self, prefix_store, rank, size, timeout): + super().__init__(rank, size) + self._group_name = None + + def getBackendName(self): + return "jax" + + # TODO(wcromar): why doesn't default group name setter work? + # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152 + def _set_group_name(self, name: str) -> None: + self._group_name = name + + @property + def group_name(self): + assert self._group_name + return self._group_name + + @staticmethod + def _work( + tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]], + ) -> dist.Work: + fut = torch.futures.Future() + fut.set_result(tensors) + return torch._C._distributed_c10d._create_work_from_future(fut) + + def _allgather_base( + self, + output: torch.Tensor, + input: torch.Tensor, + opts=..., + ) -> dist.Work: + assert isinstance(input, torchax.tensor.Tensor) + assert isinstance(output, torchax.tensor.Tensor) + torch.distributed._functional_collectives.all_gather_tensor_inplace( + output, input, group=self + ) + return self._work(output) + + def allreduce( + self, + tensors: List[torch.Tensor], + opts: dist.AllreduceOptions = ..., + ) -> dist.Work: + assert len(tensors) == 1 + assert isinstance(tensors[0], torchax.tensor.Tensor) + torch.distributed._functional_collectives.all_reduce_inplace( + tensors[0], + torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ + opts.reduceOp.op + ], + self, + ) - return self._work(tensors) - - def broadcast( - self, - tensors: List[torch.Tensor], - opts: dist.BroadcastOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - tensors[0].copy_( - torch.distributed._functional_collectives.broadcast( - tensors[0], opts.rootRank, group=self - ) - ) + return self._work(tensors) + + def broadcast( + self, + tensors: List[torch.Tensor], + opts: dist.BroadcastOptions = ..., + ) -> dist.Work: + assert len(tensors) == 1 + assert isinstance(tensors[0], torchax.tensor.Tensor) + tensors[0].copy_( + torch.distributed._functional_collectives.broadcast( + tensors[0], opts.rootRank, group=self + ) + ) - return self._work(tensors) + return 
self._work(tensors) dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"]) def jax_rendezvous_handler( - url: str, timeout: datetime.timedelta = ..., **kwargs + url: str, timeout: datetime.timedelta = ..., **kwargs ): - """Initialize distributed store with JAX process IDs. - - Requires `$MASTER_ADDR` and `$MASTER_PORT`. - """ - # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU - # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part - # of their public Python API - master_ip = os.environ["MASTER_ADDR"] - master_port = int(os.environ["MASTER_PORT"]) - # TODO(wcromar): Use `torchrun`'s store if available - store = dist.TCPStore( - master_ip, - master_port, - jax.process_count(), - is_master=jax.process_index() == 0, - ) - - yield (store, jax.process_index(), jax.process_count()) + """Initialize distributed store with JAX process IDs. + + Requires `$MASTER_ADDR` and `$MASTER_PORT`. + """ + # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU + # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part + # of their public Python API + master_ip = os.environ["MASTER_ADDR"] + master_port = int(os.environ["MASTER_PORT"]) + # TODO(wcromar): Use `torchrun`'s store if available + store = dist.TCPStore( + master_ip, + master_port, + jax.process_count(), + is_master=jax.process_index() == 0, + ) + + yield (store, jax.process_index(), jax.process_count()) dist.register_rendezvous_handler("jax", jax_rendezvous_handler) def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None): - """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. - `f` is expected to take the replica index as a positional argument, similar - to `torch.multiprocessing.spawn`. - Note: `spawn` does not actually create parallel processes. - """ - env = env or torchax.default_env() - - def jax_wrapper(index, jax_args): - index, args = env.j2t_iso([index, jax_args]) - torch_outputs = f(index, *args) - return env.t2j_iso(torch_outputs) - - jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")( - np.arange(jax.device_count()), env.t2j_iso(args) - ) - return env.j2t_iso(jax_outputs) + """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. + `f` is expected to take the replica index as a positional argument, similar + to `torch.multiprocessing.spawn`. + Note: `spawn` does not actually create parallel processes. + """ + env = env or torchax.default_env() + + def jax_wrapper(index, jax_args): + index, args = env.j2t_iso([index, jax_args]) + torch_outputs = f(index, *args) + return env.t2j_iso(torch_outputs) + + jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")( + np.arange(jax.device_count()), env.t2j_iso(args) + ) + return env.j2t_iso(jax_outputs) class DistributedDataParallel(torch.nn.Module): - """Re-implementation of DistributedDataParallel using JAX SPMD. - - Splits inputs along batch dimension (assumed to be 0) across all devices in - JAX runtime, including remote devices. Each process should load a distinct - shard of the input data using e.g. DistributedSampler. Each process' shard - is then further split among the addressable devices (e.g. local TPU chips) - by `shard_input`. - - Note: since parameters are replicated across addressable devices, inputs - must also be SPMD sharded using `shard_input` or `replicate_input`. 
- - Example usage: - - ``` - jax_model = torchax.distributed.DistributedDataParallel(create_model()) - for data, dataloader: - jax_data = jax_model.shard_input(data) - jax_output = jax_model(jax_data) - ``` - """ - - def __init__( - self, - module: torch.nn.Module, - env: Optional[torchax.tensor.Environment] = None, - **kwargs, - ): - if kwargs: - logging.warning(f"Unsupported kwargs {kwargs}") - - super().__init__() - self._env = env or torchax.default_env() - self._mesh = Mesh( - mesh_utils.create_device_mesh((jax.device_count(),)), - axis_names=("batch",), - ) - replicated_state = torch_pytree.tree_map_only( - torch.Tensor, - lambda t: self._env.j2t_iso( - jax.device_put( - self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()) - ) - ), - module.state_dict(), - ) - # TODO: broadcast - module.load_state_dict(replicated_state, assign=True) - self._module = module - - def shard_input(self, inp): - per_process_batch_size = inp.shape[0] # assumes batch dim is 0 - per_replica_batch_size = ( - per_process_batch_size // jax.local_device_count() - ) - per_replica_batches = torch.chunk(inp, jax.local_device_count()) - global_batch_size = per_replica_batch_size * jax.device_count() - global_batch_shape = (global_batch_size,) + inp.shape[1:] - - sharding = NamedSharding(self._mesh, P("batch")) - return self._env.j2t_iso( - jax.make_array_from_single_device_arrays( - global_batch_shape, - NamedSharding(self._mesh, P("batch")), - arrays=[ - jax.device_put(self._env.to_xla(batch)._elem, device) - for batch, device in zip( - per_replica_batches, sharding.addressable_devices - ) - ], - ) + """Re-implementation of DistributedDataParallel using JAX SPMD. + + Splits inputs along batch dimension (assumed to be 0) across all devices in + JAX runtime, including remote devices. Each process should load a distinct + shard of the input data using e.g. DistributedSampler. Each process' shard + is then further split among the addressable devices (e.g. local TPU chips) + by `shard_input`. + + Note: since parameters are replicated across addressable devices, inputs + must also be SPMD sharded using `shard_input` or `replicate_input`. 
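+
+  A training step can additionally be wrapped with `jit_step`: the wrapper
+  passes the module's state dict through `interop.jax_jit` (donating the
+  previous state) and loads the updated state dict back after each call.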
+ + Example usage: + + ``` + jax_model = torchax.distributed.DistributedDataParallel(create_model()) + for data, dataloader: + jax_data = jax_model.shard_input(data) + jax_output = jax_model(jax_data) + ``` + """ + + def __init__( + self, + module: torch.nn.Module, + env: Optional[torchax.tensor.Environment] = None, + **kwargs, + ): + if kwargs: + logging.warning(f"Unsupported kwargs {kwargs}") + + super().__init__() + self._env = env or torchax.default_env() + self._mesh = Mesh( + mesh_utils.create_device_mesh((jax.device_count(),)), + axis_names=("batch",), + ) + replicated_state = torch_pytree.tree_map_only( + torch.Tensor, + lambda t: self._env.j2t_iso( + jax.device_put( + self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()) ) + ), + module.state_dict(), + ) + # TODO: broadcast + module.load_state_dict(replicated_state, assign=True) + self._module = module + + def shard_input(self, inp): + per_process_batch_size = inp.shape[0] # assumes batch dim is 0 + per_replica_batch_size = per_process_batch_size // jax.local_device_count() + per_replica_batches = torch.chunk(inp, jax.local_device_count()) + global_batch_size = per_replica_batch_size * jax.device_count() + global_batch_shape = (global_batch_size,) + inp.shape[1:] + + sharding = NamedSharding(self._mesh, P("batch")) + return self._env.j2t_iso( + jax.make_array_from_single_device_arrays( + global_batch_shape, + NamedSharding(self._mesh, P("batch")), + arrays=[ + jax.device_put(self._env.to_xla(batch)._elem, device) + for batch, device in zip( + per_replica_batches, sharding.addressable_devices + ) + ], + ) + ) - def replicate_input(self, inp): - return self._env.j2t_iso( - jax.device_put(inp._elem, NamedSharding(self._mesh, P())) - ) + def replicate_input(self, inp): + return self._env.j2t_iso( + jax.device_put(inp._elem, NamedSharding(self._mesh, P())) + ) - def jit_step(self, func): - @functools.partial( - interop.jax_jit, kwargs_for_jax_jit={"donate_argnums": 0} - ) - def _jit_fn(states, args): - self.load_state_dict(states) - outputs = func(*args) - return self.state_dict(), outputs - - @functools.wraps(func) - def inner(*args): - jax_states = self.state_dict() - new_states, outputs = _jit_fn(jax_states, args) - self.load_state_dict(new_states) - return outputs - - return inner - - def forward(self, *args): - with self._env: - return self._module(*args) + def jit_step(self, func): + @functools.partial( + interop.jax_jit, kwargs_for_jax_jit={"donate_argnums": 0} + ) + def _jit_fn(states, args): + self.load_state_dict(states) + outputs = func(*args) + return self.state_dict(), outputs + + @functools.wraps(func) + def inner(*args): + jax_states = self.state_dict() + new_states, outputs = _jit_fn(jax_states, args) + self.load_state_dict(new_states) + return outputs + + return inner + + def forward(self, *args): + with self._env: + return self._module(*args) diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py index d91636300b27..636de0db820b 100644 --- a/torchax/torchax/export.py +++ b/torchax/torchax/export.py @@ -17,41 +17,41 @@ class JaxInterpreter(torch.fx.Interpreter): - """Experimental.""" - - def __init__(self, graph_module): - super().__init__(graph_module) - import torchax.ops.jaten - import torchax.ops.jtorch - - def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - if not isinstance( - target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) - ): - return super().call_function(target, args, kwargs) - - if DEBUG: - print("Running ", target.name(), "--------") - - op = 
ops_registry.all_aten_ops.get(target) - if op is None: - op = ops_registry.all_aten_ops.get(target.overloadpacket) - assert op is not None, target - assert op.is_jax_function, op - if op is None: - op = ops_registry.all_aten_ops.get(target.overloadpacket) - if op is None: - print(target.name(), target.tags) - raise RuntimeError("No lowering found for", target.name()) - return op.func(*args, **kwargs) - - def run_node(self, n) -> Any: - res = super().run_node(n) - if DEBUG: - if n.op == "call_function": - if hasattr(res, "shape"): - print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape) - return res + """Experimental.""" + + def __init__(self, graph_module): + super().__init__(graph_module) + import torchax.ops.jaten + import torchax.ops.jtorch + + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: + if not isinstance( + target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ): + return super().call_function(target, args, kwargs) + + if DEBUG: + print("Running ", target.name(), "--------") + + op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + assert op is not None, target + assert op.is_jax_function, op + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + if op is None: + print(target.name(), target.tags) + raise RuntimeError("No lowering found for", target.name()) + return op.func(*args, **kwargs) + + def run_node(self, n) -> Any: + res = super().run_node(n) + if DEBUG: + if n.op == "call_function": + if hasattr(res, "shape"): + print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape) + return res from torch._decomp import get_decompositions @@ -61,219 +61,207 @@ def run_node(self, n) -> Any: def _extract_states_from_exported_program(exported_model): - # NOTE call convention: (parameters, buffers, user_inputs) - param_and_buffer_keys = ( - exported_model.graph_signature.parameters - + exported_model.graph_signature.buffers - ) - state_dict = copy.copy(exported_model.state_dict) - if (constants := getattr(exported_model, "constants", None)) is not None: - state_dict.update(constants) - param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) + # NOTE call convention: (parameters, buffers, user_inputs) + param_and_buffer_keys = ( + exported_model.graph_signature.parameters + + exported_model.graph_signature.buffers + ) + state_dict = copy.copy(exported_model.state_dict) + if (constants := getattr(exported_model, "constants", None)) is not None: + state_dict.update(constants) + param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) - if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): - for name in exported_model.graph_signature.lifted_tensor_constants: - param_buffer_values.append(exported_model.tensor_constants[name]) + if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): + for name in exported_model.graph_signature.lifted_tensor_constants: + param_buffer_values.append(exported_model.tensor_constants[name]) - return param_and_buffer_keys, param_buffer_values + return param_and_buffer_keys, param_buffer_values def exported_program_to_jax(exported_program, export_raw: bool = False): - """returns a pytree of jax arrays(state), and + """returns a pytree of jax arrays(state), and - a callable(func) that is jax function. + a callable(func) that is jax function. - func(state, input) would be how you call it. 
- """ - if torch.__version__ >= "2.2": - # torch version 2.1 didn't expose this yet - exported_program = exported_program.run_decompositions() - exported_program = exported_program.run_decompositions( - decompositions.DECOMPOSITIONS - ) - if DEBUG: - print(exported_program.graph_module.code) + func(state, input) would be how you call it. + """ + if torch.__version__ >= "2.2": + # torch version 2.1 didn't expose this yet + exported_program = exported_program.run_decompositions() + exported_program = exported_program.run_decompositions( + decompositions.DECOMPOSITIONS + ) + if DEBUG: + print(exported_program.graph_module.code) - names, states = _extract_states_from_exported_program(exported_program) + names, states = _extract_states_from_exported_program(exported_program) - def _extract_args(args, kwargs): - flat_args, received_spec = pytree.tree_flatten((args, kwargs)) # type: ignore[possibly-undefined] - return flat_args + def _extract_args(args, kwargs): + flat_args, received_spec = pytree.tree_flatten((args, kwargs)) # type: ignore[possibly-undefined] + return flat_args - num_mutations = len(exported_program.graph_signature.buffers_to_mutate) + num_mutations = len(exported_program.graph_signature.buffers_to_mutate) - def func(states, inputs): - args = _extract_args(inputs, {}) - res = JaxInterpreter(exported_program.graph_module).run( - *states, - *args, - enable_io_processing=False, - ) - res = res[num_mutations:] - return res + def func(states, inputs): + args = _extract_args(inputs, {}) + res = JaxInterpreter(exported_program.graph_module).run( + *states, + *args, + enable_io_processing=False, + ) + res = res[num_mutations:] + return res - if export_raw: - return names, states, func - env = torchax.default_env() - states = env.t2j_copy(states) - return states, func + if export_raw: + return names, states, func + env = torchax.default_env() + states = env.t2j_copy(states) + return states, func def extract_avals(exported): - """Return JAX Abstract Value shapes for all input parameters of the exported - program. This supports dynamic batch dimensions, including with constraints. - """ + """Return JAX Abstract Value shapes for all input parameters of the exported + program. This supports dynamic batch dimensions, including with constraints. 
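+
+  For example, a dimension exported with torch.export.Dim("a", min=5, max=10)
+  is mapped to a JAX symbolic dimension "a" with the constraint strings
+  ("a >= 5", "a <= 10").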
+ """ - def _to_aval(arg_meta, symbolic_shapes): - """Convet from torch type to jax abstract value for export tracing""" + def _to_aval(arg_meta, symbolic_shapes): + """Convet from torch type to jax abstract value for export tracing""" - def _get_dim(d): - if isinstance(d, torch.SymInt): - return symbolic_shapes[str(d)] - return d + def _get_dim(d): + if isinstance(d, torch.SymInt): + return symbolic_shapes[str(d)] + return d - val = arg_meta["val"] - is_scalar = ( - isinstance(val, float) - or isinstance(val, int) - or isinstance(val, bool) + val = arg_meta["val"] + is_scalar = ( + isinstance(val, float) or isinstance(val, int) or isinstance(val, bool) + ) + if is_scalar: + return jax.ShapeDtypeStruct([], type(arg_meta["val"])) + + tensor_meta = arg_meta["tensor_meta"] + shape = [_get_dim(d) for d in tensor_meta.shape] + return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype)) + + def _get_inputs(exported): + """Return placeholders with input metadata""" + placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"] + input_placeholders = [ + p + for p, s in zip(placeholders, exported.graph_signature.input_specs) + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + ] + return input_placeholders + + def _build_symbolic_shapes(range_constraints): + """Convert torch SymInt to JAX symbolic_shape and stores in a map using the + string name of the torch symbolic int. + + TODO: There is probably a better way of storing a key for a symbolic int. + This value needs to be looked up again in `_to_aval` to figure out which + JAX symbolic to map to for a given torch tensor. + """ + if len(range_constraints) == 0: + return None + + def _build_symbolic_constraints(symbol_name, torch_constraint): + """Convert torch SymInt constraints to string for JAX symbolic_shape + Using sympy may be overkill here, currently PyTorch only uses ValueRanges + which allow specifying the min and the max of a value, for example: + torch.export.Dim("a", min=5, max=10) + ==> ("a >= 5", "a <= 10",) + """ + if ( + not isinstance( + torch_constraint, + torch.utils._sympy.value_ranges.ValueRanges, ) - if is_scalar: - return jax.ShapeDtypeStruct([], type(arg_meta["val"])) - - tensor_meta = arg_meta["tensor_meta"] - shape = [_get_dim(d) for d in tensor_meta.shape] - return jax.ShapeDtypeStruct( - shape, mappings.t2j_dtype(tensor_meta.dtype) + or torch_constraint.is_bool + ): + raise TypeError( + f"No symbolic constraint handler for: {torch_constraint}" ) - def _get_inputs(exported): - """Return placeholders with input metadata""" - placeholders = [ - p for p in exported.graph.nodes if p.op == "placeholder" - ] - input_placeholders = [ - p - for p, s in zip(placeholders, exported.graph_signature.input_specs) - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT - ] - return input_placeholders - - def _build_symbolic_shapes(range_constraints): - """Convert torch SymInt to JAX symbolic_shape and stores in a map using the - string name of the torch symbolic int. - - TODO: There is probably a better way of storing a key for a symbolic int. - This value needs to be looked up again in `_to_aval` to figure out which - JAX symbolic to map to for a given torch tensor. 
- """ - if len(range_constraints) == 0: - return None - - def _build_symbolic_constraints(symbol_name, torch_constraint): - """Convert torch SymInt constraints to string for JAX symbolic_shape - Using sympy may be overkill here, currently PyTorch only uses ValueRanges - which allow specifying the min and the max of a value, for example: - torch.export.Dim("a", min=5, max=10) - ==> ("a >= 5", "a <= 10",) - """ - if ( - not isinstance( - torch_constraint, - torch.utils._sympy.value_ranges.ValueRanges, - ) - or torch_constraint.is_bool - ): - raise TypeError( - f"No symbolic constraint handler for: {torch_constraint}" - ) - - constraints = [] - symbol = sympy.Symbol(symbol_name) - if torch_constraint.lower != 2: - constraints.append(symbol >= torch_constraint.lower) - from sympy.core.singleton import S - - if ( - not torch_constraint.upper.is_infinite - and torch_constraint.upper is not S.IntInfinity - ): - constraints.append(symbol <= torch_constraint.upper) - - return tuple( - sympy.pretty(c, use_unicode=False) for c in constraints - ) - - def _build_symbolic_shape(sym, constraint, free_symbols): - """Returns a JAX symbolic shape for a given symbol and constraint - - There are two possible sympy `sym` inputs: - 1. Symbol - (s0) These can have custom constraints. - 2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override. - - Currently support is limited to operations with a symbol and and int, - in `torch/export/dynamic_shapes.py`: - "Only increasing linear operations with integer coefficients are supported." - """ - symbol_name = str(sym) - constraints = _build_symbolic_constraints(symbol_name, constraint) - if sym.is_symbol: - symbolic_shape = jax.export.symbolic_shape( - symbol_name, constraints=constraints - ) - else: - assert len(sym.free_symbols) > 0 - scope = free_symbols[str(list(sym.free_symbols)[0])].scope - symbolic_shape = jax.export.symbolic_shape( - symbol_name, scope=scope - ) - assert len(symbolic_shape) == 1 - return symbolic_shape[0] - - # Populate symbol variables before expressions, exprs need to use the same - # Symbolic scope as the variable they operate on. Expressions can only be - # integer compuations on symbol variables, so each symbol variable is OK to - # have its own scope. 
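# A minimal sketch (assumed example, not from the patch) of the constraint
# mapping described in the docstring above: torch.export.Dim("a", min=5, max=10)
# becomes the strings ("a >= 5", "a <= 10"), which are fed to
# jax.export.symbolic_shape to build an abstract input shape.
import jax
import jax.numpy as jnp

batch, = jax.export.symbolic_shape("a", constraints=("a >= 5", "a <= 10"))
# An aval with a symbolic leading dimension, analogous to what _to_aval builds
# from the exported program's tensor_meta; the trailing 16 is illustrative.
example_aval = jax.ShapeDtypeStruct((batch, 16), jnp.float32)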
- symbolic_shapes = {} - symbol_variables = [ - (s, v) for s, v in range_constraints.items() if s.is_symbol - ] - symbol_exprs = [ - (s, v) for s, v in range_constraints.items() if not s.is_symbol - ] - for sym, constraint in symbol_variables + symbol_exprs: - symbolic_shape = _build_symbolic_shape( - sym, constraint, symbolic_shapes - ) - symbolic_shapes[str(sym)] = symbolic_shape - return symbolic_shapes - - symbolic_shapes = _build_symbolic_shapes(exported.range_constraints) - args = _get_inputs(exported) - - if DEBUG: - print("Inputs to aval:", args, "--------") - print("Symbolic shapes:", symbolic_shapes) - for arg in args: - print( - "Meta2Aval", - arg.meta, - "--> ", - _to_aval(arg.meta, symbolic_shapes), - ) - - return [_to_aval(arg.meta, symbolic_shapes) for arg in args] + constraints = [] + symbol = sympy.Symbol(symbol_name) + if torch_constraint.lower != 2: + constraints.append(symbol >= torch_constraint.lower) + from sympy.core.singleton import S + + if ( + not torch_constraint.upper.is_infinite + and torch_constraint.upper is not S.IntInfinity + ): + constraints.append(symbol <= torch_constraint.upper) + + return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) + + def _build_symbolic_shape(sym, constraint, free_symbols): + """Returns a JAX symbolic shape for a given symbol and constraint + + There are two possible sympy `sym` inputs: + 1. Symbol - (s0) These can have custom constraints. + 2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override. + + Currently support is limited to operations with a symbol and and int, + in `torch/export/dynamic_shapes.py`: + "Only increasing linear operations with integer coefficients are supported." + """ + symbol_name = str(sym) + constraints = _build_symbolic_constraints(symbol_name, constraint) + if sym.is_symbol: + symbolic_shape = jax.export.symbolic_shape( + symbol_name, constraints=constraints + ) + else: + assert len(sym.free_symbols) > 0 + scope = free_symbols[str(list(sym.free_symbols)[0])].scope + symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope) + assert len(symbolic_shape) == 1 + return symbolic_shape[0] + + # Populate symbol variables before expressions, exprs need to use the same + # Symbolic scope as the variable they operate on. Expressions can only be + # integer compuations on symbol variables, so each symbol variable is OK to + # have its own scope. + symbolic_shapes = {} + symbol_variables = [ + (s, v) for s, v in range_constraints.items() if s.is_symbol + ] + symbol_exprs = [ + (s, v) for s, v in range_constraints.items() if not s.is_symbol + ] + for sym, constraint in symbol_variables + symbol_exprs: + symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) + symbolic_shapes[str(sym)] = symbolic_shape + return symbolic_shapes + + symbolic_shapes = _build_symbolic_shapes(exported.range_constraints) + args = _get_inputs(exported) + + if DEBUG: + print("Inputs to aval:", args, "--------") + print("Symbolic shapes:", symbolic_shapes) + for arg in args: + print( + "Meta2Aval", + arg.meta, + "--> ", + _to_aval(arg.meta, symbolic_shapes), + ) + + return [_to_aval(arg.meta, symbolic_shapes) for arg in args] def exported_program_to_stablehlo(exported_program): - """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo + """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo - Convert a program exported via torch.export to StableHLO. + Convert a program exported via torch.export to StableHLO. 
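# A minimal usage sketch (illustrative only) for the conversion described
# above, assuming exported_program_to_stablehlo from this module is in scope
# and using a toy model; the model and input shapes are placeholders.
import torch

class TinyModel(torch.nn.Module):

  def forward(self, x):
    return torch.sin(x) + 1.0

exported = torch.export.export(TinyModel(), (torch.randn(4, 8),))
weights, jax_exported = exported_program_to_stablehlo(exported)
# `jax_exported` is a jax.export.Exported object carrying the StableHLO module;
# `weights` are the model parameters converted to JAX arrays.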
- This supports dynamic dimension sizes and generates explicit checks for - dynamo guards in the IR using shape_assertion custom_call ops. - """ - weights, func = exported_program_to_jax(exported_program) - jax_avals = extract_avals(exported_program) - jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,)) - return weights, jax_export + This supports dynamic dimension sizes and generates explicit checks for + dynamo guards in the IR using shape_assertion custom_call ops. + """ + weights, func = exported_program_to_jax(exported_program) + jax_avals = extract_avals(exported_program) + jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,)) + return weights, jax_export diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py index 1b10867890e6..5503dd9b71f6 100644 --- a/torchax/torchax/flax.py +++ b/torchax/torchax/flax.py @@ -6,35 +6,35 @@ class FlaxNNModule(torch.nn.Module): - def __init__(self, env, flax_module, sample_args, sample_kwargs=None): - super().__init__() - prng = env.prng_key - sample_kwargs = sample_kwargs or {} - parameter_dict = tx.interop.call_jax( - flax_module.init, prng, *sample_args, **sample_kwargs - ) - - self._params = self._encode_nested_dict(parameter_dict) - - self._flax_module = flax_module - - def _encode_nested_dict(self, nested_dict): - child_module = torch.nn.Module() - for k, v in nested_dict.items(): - if isinstance(v, dict): - child_module.add_module(k, self._encode_nested_dict(v)) - else: - child_module.register_parameter(k, torch.nn.Parameter(v)) - return child_module - - def _decode_nested_dict(self, child_module): - result = dict(child_module.named_parameters(recurse=False)) - for k, v in child_module.named_children(): - result[k] = self._decode_nested_dict(v) - return result - - def forward(self, *args, **kwargs): - nested_dict_params = self._decode_nested_dict(self._params) - return tx.interop.call_jax( - self._flax_module.apply, nested_dict_params, *args, **kwargs - ) + def __init__(self, env, flax_module, sample_args, sample_kwargs=None): + super().__init__() + prng = env.prng_key + sample_kwargs = sample_kwargs or {} + parameter_dict = tx.interop.call_jax( + flax_module.init, prng, *sample_args, **sample_kwargs + ) + + self._params = self._encode_nested_dict(parameter_dict) + + self._flax_module = flax_module + + def _encode_nested_dict(self, nested_dict): + child_module = torch.nn.Module() + for k, v in nested_dict.items(): + if isinstance(v, dict): + child_module.add_module(k, self._encode_nested_dict(v)) + else: + child_module.register_parameter(k, torch.nn.Parameter(v)) + return child_module + + def _decode_nested_dict(self, child_module): + result = dict(child_module.named_parameters(recurse=False)) + for k, v in child_module.named_children(): + result[k] = self._decode_nested_dict(v) + return result + + def forward(self, *args, **kwargs): + nested_dict_params = self._decode_nested_dict(self._params) + return tx.interop.call_jax( + self._flax_module.apply, nested_dict_params, *args, **kwargs + ) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index c7a1d50feea4..b0b3533fe226 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -17,339 +17,333 @@ def extract_all_buffers(m: torch.nn.Module): - buffers = {} - params = {} - - def extract_one(module, prefix): - for k in dir(module): - try: - v = getattr(module, k) - except: - continue - qual_name = prefix + k - if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad: - params[qual_name] = v - elif isinstance(v, 
torch.Tensor): - buffers[qual_name] = v - for name, child in module.named_children(): - extract_one(child, prefix + name + ".") - - extract_one(m, "") - return params, buffers + buffers = {} + params = {} + + def extract_one(module, prefix): + for k in dir(module): + try: + v = getattr(module, k) + except: + continue + qual_name = prefix + k + if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad: + params[qual_name] = v + elif isinstance(v, torch.Tensor): + buffers[qual_name] = v + for name, child in module.named_children(): + extract_one(child, prefix + name + ".") + + extract_one(m, "") + return params, buffers def set_all_buffers(m, params, buffers): - def set_one(module, prefix): - for k in dir(module): - qual_name = prefix + k - if (potential_v := buffers.get(qual_name)) is not None: - setattr(module, k, potential_v) - elif (potential_v := params.get(qual_name)) is not None: - print(k, potential_v) - setattr(module, k, torch.nn.Parameter(potential_v)) - for name, child in module.named_children(): - set_one(child, prefix + name + ".") + def set_one(module, prefix): + for k in dir(module): + qual_name = prefix + k + if (potential_v := buffers.get(qual_name)) is not None: + setattr(module, k, potential_v) + elif (potential_v := params.get(qual_name)) is not None: + print(k, potential_v) + setattr(module, k, torch.nn.Parameter(potential_v)) + for name, child in module.named_children(): + set_one(child, prefix + name + ".") - set_one(m, "") + set_one(m, "") class JittableModule(torch.nn.Module): - def __init__( - self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True - ): - super().__init__() - self.params, self.buffers = extract_all_buffers(m) - self._model = m - self._jitted = {} - - self._extra_jit_args = extra_jit_args - - self._extra_dumped_weights = {} - - if dedup_parameters: - temp = collections.defaultdict(list) - for k, v in self.params.items(): - temp[id(v)].append(k) - - for v in temp.values(): - if len(v) > 1: - # duplicated weights with different name - self._extra_dumped_weights[v[0]] = v[1:] - for extra_keys in v[1:]: - del self.params[extra_keys] - - @property - def __class__(self): - # Lie about the class type so that - # isinstance(jittable_module, self._model.__class__) works - return self._model.__class__ - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def functional_call(self, method_name, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - # reinflate the state dict so there are not any missing keys - for k, v in self._extra_dumped_weights.items(): - for new_key in v: - params_copy[new_key] = params_copy[k] - with torch_stateless._reparametrize_module(self._model, params_copy): - res = getattr(self._model, method_name)(*args, **kwargs) - return res - - def forward(self, *args, **kwargs): - if "forward" not in self._jitted: - jitted = jax_jit( - functools.partial(self.functional_call, "forward"), - kwargs_for_jax_jit=self._extra_jit_args, - ) - - def jitted_forward(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) - - self._jitted["forward"] = jitted_forward - return self._jitted["forward"](*args, **kwargs) - - def __getattr__(self, key): - if key == "_model": - return super().__getattr__(key) - if key in self._jitted: - return self._jitted[key] - return getattr(self._model, key) - - def make_jitted(self, key): - jitted = jax_jit( - functools.partial(self.functional_call, key), - 
kwargs_for_jax_jit=self._extra_jit_args, - ) + def __init__( + self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True + ): + super().__init__() + self.params, self.buffers = extract_all_buffers(m) + self._model = m + self._jitted = {} + + self._extra_jit_args = extra_jit_args + + self._extra_dumped_weights = {} + + if dedup_parameters: + temp = collections.defaultdict(list) + for k, v in self.params.items(): + temp[id(v)].append(k) + + for v in temp.values(): + if len(v) > 1: + # duplicated weights with different name + self._extra_dumped_weights[v[0]] = v[1:] + for extra_keys in v[1:]: + del self.params[extra_keys] + + @property + def __class__(self): + # Lie about the class type so that + # isinstance(jittable_module, self._model.__class__) works + return self._model.__class__ + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def functional_call(self, method_name, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + # reinflate the state dict so there are not any missing keys + for k, v in self._extra_dumped_weights.items(): + for new_key in v: + params_copy[new_key] = params_copy[k] + with torch_stateless._reparametrize_module(self._model, params_copy): + res = getattr(self._model, method_name)(*args, **kwargs) + return res + + def forward(self, *args, **kwargs): + if "forward" not in self._jitted: + jitted = jax_jit( + functools.partial(self.functional_call, "forward"), + kwargs_for_jax_jit=self._extra_jit_args, + ) + + def jitted_forward(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + + self._jitted["forward"] = jitted_forward + return self._jitted["forward"](*args, **kwargs) + + def __getattr__(self, key): + if key == "_model": + return super().__getattr__(key) + if key in self._jitted: + return self._jitted[key] + return getattr(self._model, key) + + def make_jitted(self, key): + jitted = jax_jit( + functools.partial(self.functional_call, key), + kwargs_for_jax_jit=self._extra_jit_args, + ) - def call(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) + def call(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) - self._jitted[key] = call + self._jitted[key] = call class CompileMixin: - def functional_call(self, method, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - with torch_stateless._reparametrize_module(self, params_copy): - res = method(*args, **kwargs) - return res + def functional_call(self, method, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + with torch_stateless._reparametrize_module(self, params_copy): + res = method(*args, **kwargs) + return res - def jit(self, method): - jitted = jax_jit(functools.partial(self.functional_call, method_name)) + def jit(self, method): + jitted = jax_jit(functools.partial(self.functional_call, method_name)) - def call(*args, **kwargs): - return jitted( - self.named_paramters(), self.named_buffers(), *args, **kwargs - ) + def call(*args, **kwargs): + return jitted( + self.named_paramters(), self.named_buffers(), *args, **kwargs + ) - return call + return call def compile_nn_module(m: torch.nn.Module, methods=None): - if methods is None: - methods = ["forward"] + if methods is None: + methods = ["forward"] - new_parent = type( - m.__class__.__name__ + "_with_CompileMixin", - 
(CompileMixin, m.__class__), - ) - m.__class__ = NewParent + new_parent = type( + m.__class__.__name__ + "_with_CompileMixin", + (CompileMixin, m.__class__), + ) + m.__class__ = NewParent def _torch_view(t: JaxValue) -> TorchValue: - # t is an object from jax land - # view it as-if it's a torch land object - if isinstance(t, jax.Array): - # TODO - return tensor.Tensor(t, torchax.default_env()) - if isinstance(t, type(jnp.int32)): - return tensor.t2j_type(t) - if callable(t): # t is a JaxCallable - return functools.partial(call_jax, t) - # regular types are not changed - return t + # t is an object from jax land + # view it as-if it's a torch land object + if isinstance(t, jax.Array): + # TODO + return tensor.Tensor(t, torchax.default_env()) + if isinstance(t, type(jnp.int32)): + return tensor.t2j_type(t) + if callable(t): # t is a JaxCallable + return functools.partial(call_jax, t) + # regular types are not changed + return t torch_view = functools.partial(pytree.tree_map, _torch_view) def _jax_view(t: TorchValue) -> JaxValue: - # t is an object from torch land - # view it as-if it's a jax land object - if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type( - t - ) - return t.jax() - if isinstance(t, type(torch.int32)): - return tensor.t2j_dtype(t) + # t is an object from torch land + # view it as-if it's a jax land object + if isinstance(t, torch.Tensor): + assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t) + return t.jax() + if isinstance(t, type(torch.int32)): + return tensor.t2j_dtype(t) - # torch.nn.Module needs special handling - if not isinstance(t, torch.nn.Module) and callable( - t - ): # t is a TorchCallable - return functools.partial(call_torch, t) - # regular types are not changed - return t + # torch.nn.Module needs special handling + if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable + return functools.partial(call_torch, t) + # regular types are not changed + return t jax_view = functools.partial(pytree.tree_map, _jax_view) def call_jax( - jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue + jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue ) -> TorchValue: - args, kwargs = jax_view((args, kwargs)) - res: JaxValue = jax_func(*args, **kwargs) - return torch_view(res) + args, kwargs = jax_view((args, kwargs)) + res: JaxValue = jax_func(*args, **kwargs) + return torch_view(res) def call_torch( - torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue + torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue ) -> JaxValue: - args, kwargs = torch_view((args, kwargs)) - with torchax.default_env(): - res: TorchValue = torch_func(*args, **kwargs) - return jax_view(res) + args, kwargs = torch_view((args, kwargs)) + with torchax.default_env(): + res: TorchValue = torch_func(*args, **kwargs) + return jax_view(res) def j2t_autograd(fn, call_jax=call_jax): - """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. - - It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate - activations). The wrapped function is then run via `call_jax` and integrated into - the PyTorch autograd framework by saving the residuals into the context object. 
- """ - - @wraps(fn) - def inner(*args, **kwargs): - from jax.tree_util import tree_flatten, tree_unflatten - from jax.util import safe_zip - - class JaxFun(torch.autograd.Function): - @staticmethod - def forward(ctx, tree_def, *flat_args_kwargs): - tensors, other = util.partition( - flat_args_kwargs, lambda x: isinstance(x, torch.Tensor) - ) - # We want the arguments that don't require grads to be closured? - - y, fun_vjp = call_jax( - _jax_forward, fn, other, tree_def, tensors - ) - - # Save necessary information for backward - # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. - # `residuals` contains the tensors needed for the backward pass.` - residuals, vjp_spec = tree_flatten(fun_vjp) - ctx.vjp_spec = vjp_spec - ctx.save_for_backward(*residuals) - return y - - @staticmethod - def backward(ctx, *grad_out): - assert len(grad_out) > 0 - grad_out = grad_out if len(grad_out) > 1 else grad_out[0] - - input_grads_structured = call_jax( - _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out - ) - - # Construct the gradient tuple to be returned. - # It needs to match the inputs to forward: (tree_def, *flat_inputs) - # The first gradient (for tree_def) is None. - # The subsequent gradients correspond to flat_inputs. - # We need to put a None for inputs that did not require gradients. - final_grads = [None] - for needs_grad, grad in safe_zip( - ctx.needs_input_grad[1:], input_grads_structured - ): - final_grads.append(grad if needs_grad else None) - - return tuple(final_grads) - - sig = signature(fn) - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) - y = JaxFun.apply(tree_def, *flat_args_kwargs) + """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. + + It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate + activations). The wrapped function is then run via `call_jax` and integrated into + the PyTorch autograd framework by saving the residuals into the context object. + """ + + @wraps(fn) + def inner(*args, **kwargs): + from jax.tree_util import tree_flatten, tree_unflatten + from jax.util import safe_zip + + class JaxFun(torch.autograd.Function): + @staticmethod + def forward(ctx, tree_def, *flat_args_kwargs): + tensors, other = util.partition( + flat_args_kwargs, lambda x: isinstance(x, torch.Tensor) + ) + # We want the arguments that don't require grads to be closured? + + y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors) + + # Save necessary information for backward + # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. + # `residuals` contains the tensors needed for the backward pass.` + residuals, vjp_spec = tree_flatten(fun_vjp) + ctx.vjp_spec = vjp_spec + ctx.save_for_backward(*residuals) return y - return inner + @staticmethod + def backward(ctx, *grad_out): + assert len(grad_out) > 0 + grad_out = grad_out if len(grad_out) > 1 else grad_out[0] + + input_grads_structured = call_jax( + _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out + ) + + # Construct the gradient tuple to be returned. + # It needs to match the inputs to forward: (tree_def, *flat_inputs) + # The first gradient (for tree_def) is None. + # The subsequent gradients correspond to flat_inputs. + # We need to put a None for inputs that did not require gradients. 
+ final_grads = [None] + for needs_grad, grad in safe_zip( + ctx.needs_input_grad[1:], input_grads_structured + ): + final_grads.append(grad if needs_grad else None) + + return tuple(final_grads) + + sig = signature(fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) + y = JaxFun.apply(tree_def, *flat_args_kwargs) + return y + + return inner # NOTE(qihqi): This function cannot be inlined from the callsite # Becuase if it does, then it won't hit the compilation cache for # call_jax. Call jax uses functions' id as key. def _jax_forward(fn, other, tree_def, tensors): - """JAX function to compute output and vjp function. + """JAX function to compute output and vjp function. - primals should be a tuple (args, kwargs). - """ - import jax - from jax.tree_util import tree_flatten, tree_unflatten + primals should be a tuple (args, kwargs). + """ + import jax + from jax.tree_util import tree_flatten, tree_unflatten - def fn_wrapper(*tensors): - # Reconstruct the original args and kwargs - flat_inputs = util.merge(tensors, other) - args, kwargs = tree_unflatten(tree_def, flat_inputs) - return fn(*args, **kwargs) + def fn_wrapper(*tensors): + # Reconstruct the original args and kwargs + flat_inputs = util.merge(tensors, other) + args, kwargs = tree_unflatten(tree_def, flat_inputs) + return fn(*args, **kwargs) - return jax.vjp(fn_wrapper, *tensors) + return jax.vjp(fn_wrapper, *tensors) def _jax_backward(vjp_spec, saved_tensors, grad_out): - """JAX function to compute input gradients. + """JAX function to compute input gradients. - Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. - """ - from jax.tree_util import tree_unflatten + Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. 
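# A minimal end-to-end sketch (illustrative only) of j2t_autograd defined
# above, assuming tensors are placed in the torchax environment with .to("jax");
# the exact placement calls may differ depending on configuration.
import jax.numpy as jnp
import torch
import torchax
from torchax import interop

def scaled_sin(x):
  return jnp.sin(x) * 2.0

torch_scaled_sin = interop.j2t_autograd(scaled_sin)

env = torchax.default_env()
with env:
  x = torch.randn(3).to("jax").requires_grad_()
  loss = torch_scaled_sin(x).sum()
  loss.backward()  # gradients flow back through the saved jax.vjp residuals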
+ """ + from jax.tree_util import tree_unflatten - fun_vjp = tree_unflatten(vjp_spec, saved_tensors) - return fun_vjp(grad_out) + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) + return fun_vjp(grad_out) fori_loop = torch_view(jax.lax.fori_loop) def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): - kwargs_for_jax = kwargs_for_jax or {} - jax_func = jax_view(torch_function) - jitted = jax_jit_func(jax_func, **kwargs_for_jax) - return torch_view(jitted) + kwargs_for_jax = kwargs_for_jax or {} + jax_func = jax_view(torch_function) + jitted = jax_jit_func(jax_func, **kwargs_for_jax) + return torch_view(jitted) def jax_jit( - torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False + torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False ): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit - ) + return wrap_jax_jit( + torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit + ) def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=shard_map, - kwargs_for_jax=kwargs_for_jax_shard_map, - ) + return wrap_jax_jit( + torch_function, + jax_jit_func=shard_map, + kwargs_for_jax=kwargs_for_jax_shard_map, + ) def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=jax.value_and_grad, - kwargs_for_jax=kwargs_for_value_and_grad, - ) + return wrap_jax_jit( + torch_function, + jax_jit_func=jax.value_and_grad, + kwargs_for_jax=kwargs_for_value_and_grad, + ) def gradient_checkpoint(torch_function, kwargs=None): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs - ) + return wrap_jax_jit( + torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs + ) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index 281f5fd80da5..7b74ea091428 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -6,208 +6,208 @@ def _shard_first_multiple_of(axis_name, shape, multiple_of): - """Creates a PartitionSpec to shard the first dimension divisible by a number. + """Creates a PartitionSpec to shard the first dimension divisible by a number. + + Iterates through the dimensions specified by `shape`. Finds the first dimension + whose size is a multiple of `multiple_of` and returns a PartitionSpec that + shards that dimension along the given `axis_name`. All preceding dimensions + are not sharded (marked as None in the PartitionSpec). All subsequent dimensions + skipped, which would be implicitly treated as replicated. + + Args: + axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl"). + shape: A tuple or list representing the shape of the tensor to be sharded. + multiple_of: The integer value that a dimension size must be divisible by + in order to be sharded. Typically the size of the mesh axis. + + Returns: + A jax.sharding.PartitionSpec object specifying how to shard the tensor. + For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, + it would return PartitionSpec(None, 'x', None). + If none divides then it should return a replicated PartitionSpec + """ + sharding = [] + found = False + for size in shape: + if not found and size % multiple_of == 0: + found = True + sharding.append(axis_name) + else: + sharding.append(None) + return PartitionSpec(*sharding) - Iterates through the dimensions specified by `shape`. 
Finds the first dimension - whose size is a multiple of `multiple_of` and returns a PartitionSpec that - shards that dimension along the given `axis_name`. All preceding dimensions - are not sharded (marked as None in the PartitionSpec). All subsequent dimensions - skipped, which would be implicitly treated as replicated. + +class SingleAxisSharder: + """A callable object that generates PartitionSpecs for single-axis sharding. + + This sharder strategy attempts to shard the *first* dimension of a tensor + that is divisible by the specified `axis_size` along the given `axis_name`. + It's useful for simple 1D mesh sharding scenarios like FSDP where parameters + are typically sharded along one dimension. + + Attributes: + axis_name: The name of the mesh axis to shard along. + axis_size: The size of the mesh axis (number of devices along that axis). + """ + + def __init__(self, axis_name, axis_size, replicate_unshardable=False): + """Initializes the SingleAxisSharder. + + Args: + axis_name: The name of the mesh axis (e.g., "fsdp", "data"). + axis_size: The number of devices along the specified mesh axis. + replicate_unshardable: indicate whether it should return replicated sharding + (P()) when none of the axis is divisible by the axis size. + """ + self.axis_name = axis_name + self.axis_size = axis_size + self.replicate_unshardable = replicate_unshardable + + def __call__(self, name, shapedtype): + """Generates a PartitionSpec for a given tensor name and shaped type. Args: - axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl"). - shape: A tuple or list representing the shape of the tensor to be sharded. - multiple_of: The integer value that a dimension size must be divisible by - in order to be sharded. Typically the size of the mesh axis. + name: The name of the tensor (e.g., parameter name). This argument is + provided for compatibility with more complex sharders but is not used + by this simple sharder. + shapedtype: An object with a `.shape` attribute describing the tensor's shape, + and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct + or a torch.Tensor) Returns: - A jax.sharding.PartitionSpec object specifying how to shard the tensor. - For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, - it would return PartitionSpec(None, 'x', None). - If none divides then it should return a replicated PartitionSpec + A jax.sharding.PartitionSpec determined by finding the first dimension + in `shapedtype.shape` divisible by `self.axis_size` using the helper + `_shard_first_multiple_of`. """ - sharding = [] - found = False - for size in shape: - if not found and size % multiple_of == 0: - found = True - sharding.append(axis_name) - else: - sharding.append(None) - return PartitionSpec(*sharding) + del name + sharding = _shard_first_multiple_of( + self.axis_name, shapedtype.shape, self.axis_size + ) + if not self.replicate_unshardable and all(s is None for s in sharding): + raise AssertionError( + f"Unable to find a dim to shard because " + f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" + ) + return sharding -class SingleAxisSharder: - """A callable object that generates PartitionSpecs for single-axis sharding. +class Mesh: + """A helper class that wraps `jax.sharding.Mesh` object. - This sharder strategy attempts to shard the *first* dimension of a tensor - that is divisible by the specified `axis_size` along the given `axis_name`. 
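# A quick check (illustrative only) of the behaviour documented for
# _shard_first_multiple_of, using the docstring's own example values.
from jax.sharding import PartitionSpec

spec = _shard_first_multiple_of("x", (10, 20, 30), 4)
assert spec == PartitionSpec(None, "x", None)  # 20 is the first dim divisible by 4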
- It's useful for simple 1D mesh sharding scenarios like FSDP where parameters - are typically sharded along one dimension. + The goal of this class is to provide helper methods that facilitate the + sharding of PyTorch tensors or models given a JAX device mesh configuration. + It simplifies initializing models directly into a sharded state. - Attributes: - axis_name: The name of the mesh axis to shard along. - axis_size: The size of the mesh axis (number of devices along that axis). - """ + Attributes: + jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid + and axis names. + _sharder: The default sharding strategy callable (like SingleAxisSharder) + used to determine the PartitionSpec for each parameter if not overridden + during method calls. Can be None if no default is appropriate or set. + """ - def __init__(self, axis_name, axis_size, replicate_unshardable=False): - """Initializes the SingleAxisSharder. - - Args: - axis_name: The name of the mesh axis (e.g., "fsdp", "data"). - axis_size: The number of devices along the specified mesh axis. - replicate_unshardable: indicate whether it should return replicated sharding - (P()) when none of the axis is divisible by the axis size. - """ - self.axis_name = axis_name - self.axis_size = axis_size - self.replicate_unshardable = replicate_unshardable - - def __call__(self, name, shapedtype): - """Generates a PartitionSpec for a given tensor name and shaped type. - - Args: - name: The name of the tensor (e.g., parameter name). This argument is - provided for compatibility with more complex sharders but is not used - by this simple sharder. - shapedtype: An object with a `.shape` attribute describing the tensor's shape, - and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct - or a torch.Tensor) - - Returns: - A jax.sharding.PartitionSpec determined by finding the first dimension - in `shapedtype.shape` divisible by `self.axis_size` using the helper - `_shard_first_multiple_of`. - """ - del name - sharding = _shard_first_multiple_of( - self.axis_name, shapedtype.shape, self.axis_size - ) - if not self.replicate_unshardable and all(s is None for s in sharding): - raise AssertionError( - f"Unable to find a dim to shard because " - f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" - ) - return sharding + @classmethod + def fsdp_mesh(cls, axis_name="fsdp"): + """Creates a Mesh instance suitable for 1D FSDP-style sharding. + This named constructor creates a 1D mesh encompassing all available XLA + devices. It assigns the specified `axis_name` to this single dimension. + It then creates a `Mesh` instance using this JAX mesh and a + `SingleAxisSharder` configured appropriately for this 1D mesh. -class Mesh: - """A helper class that wraps `jax.sharding.Mesh` object. - - The goal of this class is to provide helper methods that facilitate the - sharding of PyTorch tensors or models given a JAX device mesh configuration. - It simplifies initializing models directly into a sharded state. - - Attributes: - jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid - and axis names. - _sharder: The default sharding strategy callable (like SingleAxisSharder) - used to determine the PartitionSpec for each parameter if not overridden - during method calls. Can be None if no default is appropriate or set. + Args: + axis_name: The name to assign to the single mesh axis (default: "fsdp"). + This name will be used by the default `SingleAxisSharder`. 
+ + Returns: + A Mesh instance configured with a 1D JAX mesh across all devices and a + corresponding SingleAxisSharder. + """ + ndevice = jax.device_count() + jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) + # replicate_unshardable so scalars and small model attributes are replicated. + return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) + + def __init__(self, jax_mesh, sharder=None): + """Initializes the Mesh helper. + + Args: + jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the + physical device grid and logical axis names. + sharder: An optional callable (e.g., an instance of SingleAxisSharder) + that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. + This serves as the default sharding strategy. + If None, and the provided `jax_mesh` has exactly one axis, a + `SingleAxisSharder` is created automatically for that single axis. + If None and the mesh has multiple axes, `_sharder` remains None, and + an `override_sharder` must be provided to methods like + `initialize_model_sharded`. """ + self.jax_mesh = jax_mesh + if sharder is None: + assert len(self.jax_mesh.axis_names) == 1 + sharder = SingleAxisSharder( + self.jax_mesh.axis_names[0], len(self.mesh.device_ids) + ) + self._sharder = sharder + + def initialize_model_sharded( + self, model_class, init_args, init_kwargs=None, override_sharder=None + ): + """Initializes a PyTorch model with its parameters sharded across the mesh. + + This method orchestrates the initialization of a `torch.nn.Module` such + that its parameters are created directly on the target devices according + to the sharding specifications derived from the mesh and the chosen sharder. + It leverages `torchax.interop.jax_jit` to achieve this. + + Args: + model_class: The PyTorch model class (a subclass of `torch.nn.Module`). + init_args: A tuple containing the positional arguments required by the + `model_class.__init__` method. + init_kwargs: An optional dictionary containing the keyword arguments for + the `model_class.__init__` method. Defaults to None (treated as {}). + override_sharder: An optional callable sharding strategy to use + specifically for this initialization. If provided, it takes precedence + over the mesh's default `_sharder`. It must accept `(name, shapedtype)` + and return a `PartitionSpec`. If None, the mesh's default `_sharder` + is used. - @classmethod - def fsdp_mesh(cls, axis_name="fsdp"): - """Creates a Mesh instance suitable for 1D FSDP-style sharding. - - This named constructor creates a 1D mesh encompassing all available XLA - devices. It assigns the specified `axis_name` to this single dimension. - It then creates a `Mesh` instance using this JAX mesh and a - `SingleAxisSharder` configured appropriately for this 1D mesh. - - Args: - axis_name: The name to assign to the single mesh axis (default: "fsdp"). - This name will be used by the default `SingleAxisSharder`. - - Returns: - A Mesh instance configured with a 1D JAX mesh across all devices and a - corresponding SingleAxisSharder. - """ - ndevice = jax.device_count() - jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) - # replicate_unshardable so scalars and small model attributes are replicated. - return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) - - def __init__(self, jax_mesh, sharder=None): - """Initializes the Mesh helper. - - Args: - jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the - physical device grid and logical axis names. 
- sharder: An optional callable (e.g., an instance of SingleAxisSharder) - that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. - This serves as the default sharding strategy. - If None, and the provided `jax_mesh` has exactly one axis, a - `SingleAxisSharder` is created automatically for that single axis. - If None and the mesh has multiple axes, `_sharder` remains None, and - an `override_sharder` must be provided to methods like - `initialize_model_sharded`. - """ - self.jax_mesh = jax_mesh - if sharder is None: - assert len(self.jax_mesh.axis_names) == 1 - sharder = SingleAxisSharder( - self.jax_mesh.axis_names[0], len(self.mesh.device_ids) - ) - self._sharder = sharder - - def initialize_model_sharded( - self, model_class, init_args, init_kwargs=None, override_sharder=None - ): - """Initializes a PyTorch model with its parameters sharded across the mesh. - - This method orchestrates the initialization of a `torch.nn.Module` such - that its parameters are created directly on the target devices according - to the sharding specifications derived from the mesh and the chosen sharder. - It leverages `torchax.interop.jax_jit` to achieve this. - - Args: - model_class: The PyTorch model class (a subclass of `torch.nn.Module`). - init_args: A tuple containing the positional arguments required by the - `model_class.__init__` method. - init_kwargs: An optional dictionary containing the keyword arguments for - the `model_class.__init__` method. Defaults to None (treated as {}). - override_sharder: An optional callable sharding strategy to use - specifically for this initialization. If provided, it takes precedence - over the mesh's default `_sharder`. It must accept `(name, shapedtype)` - and return a `PartitionSpec`. If None, the mesh's default `_sharder` - is used. - - Returns: - An instance of `model_class` whose parameters have been initialized and - are represented by sharded tensors distributed across the devices in the - `jax_mesh`. - - Raises: - ValueError: If no sharder is available (i.e., `override_sharder` is None - and the mesh's default `_sharder` is also None). - AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) - if it fails to determine a valid sharding for any parameter. - TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. - Other errors from JAX JIT compilation or PyTorch model initialization. - """ - init_kwargs = init_kwargs or {} - with torch.device("meta"), torchax.disable_temporarily(): - model = model_class(*init_args, **init_kwargs) - - sharder = override_sharder or self._sharder - - states = model.state_dict() - output_shards = { - name: NamedSharding(self.jax_mesh, sharder(name, tensor)) - for name, tensor in states.items() - } - - def model_initializer(): - with torchax.default_env(): - model = model_class(*init_args, **init_kwargs) - return dict(model.state_dict()) - - jitted = interop.jax_jit( - model_initializer, - kwargs_for_jax_jit={"out_shardings": output_shards}, - ) - weights_dict = jitted() - - model.load_state_dict(weights_dict, assign=True) - return model + Returns: + An instance of `model_class` whose parameters have been initialized and + are represented by sharded tensors distributed across the devices in the + `jax_mesh`. + + Raises: + ValueError: If no sharder is available (i.e., `override_sharder` is None + and the mesh's default `_sharder` is also None). 
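# A minimal usage sketch (illustrative only) of the helpers above, assuming
# torchax.mesh_util as the import path (matching this file) and a placeholder
# model class with placeholder constructor arguments.
import torch
from torchax.mesh_util import Mesh

class MyModel(torch.nn.Module):

  def __init__(self, dim):
    super().__init__()
    self.linear = torch.nn.Linear(dim, dim)

  def forward(self, x):
    return self.linear(x)

mesh = Mesh.fsdp_mesh()  # 1-D mesh over all devices, axis named "fsdp"
model = mesh.initialize_model_sharded(MyModel, (1024,))
# model now holds parameters initialized directly into sharded JAX arrays.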
+ AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) + if it fails to determine a valid sharding for any parameter. + TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. + Other errors from JAX JIT compilation or PyTorch model initialization. + """ + init_kwargs = init_kwargs or {} + with torch.device("meta"), torchax.disable_temporarily(): + model = model_class(*init_args, **init_kwargs) + + sharder = override_sharder or self._sharder + + states = model.state_dict() + output_shards = { + name: NamedSharding(self.jax_mesh, sharder(name, tensor)) + for name, tensor in states.items() + } + + def model_initializer(): + with torchax.default_env(): + model = model_class(*init_args, **init_kwargs) + return dict(model.state_dict()) + + jitted = interop.jax_jit( + model_initializer, + kwargs_for_jax_jit={"out_shardings": output_shards}, + ) + weights_dict = jitted() + + model.load_state_dict(weights_dict, assign=True) + return model diff --git a/torchax/torchax/ops/__init__.py b/torchax/torchax/ops/__init__.py index d306871dd7ac..a68521616574 100644 --- a/torchax/torchax/ops/__init__.py +++ b/torchax/torchax/ops/__init__.py @@ -1,10 +1,10 @@ def all_aten_jax_ops(): - # to load the ops - import torchax.ops.jaten # type: ignore - import torchax.ops.ops_registry # type: ignore + # to load the ops + import torchax.ops.jaten # type: ignore + import torchax.ops.ops_registry # type: ignore - return { - key: val.func - for key, val in torchax.ops.ops_registry.all_aten_ops.items() - if val.is_jax_function - } + return { + key: val.func + for key, val in torchax.ops.ops_registry.all_aten_ops.items() + if val.is_jax_function + } diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 1f7e4a06902b..2d9a93a794fc 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -23,563 +23,559 @@ def op(*aten, **kwargs): - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - continue - - if isinstance(a, torch._ops.OpOverloadPacket): - opname = ( - a.default.name() - if "default" in a.overloads() - else a._qualified_op_name - ) - elif isinstance(a, torch._ops.OpOverload): - opname = a.name() - else: - raise RuntimeError(f"oops {a}") - - torchfunc = functools.partial(interop.call_jax, func) - # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor - torch.library.impl(opname, "privateuseone")( - torchfunc if a != torch.ops.aten._to_copy else func - ) - return func - - return inner + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + continue + + if isinstance(a, torch._ops.OpOverloadPacket): + opname = ( + a.default.name() + if "default" in a.overloads() + else a._qualified_op_name + ) + elif isinstance(a, torch._ops.OpOverload): + opname = a.name() + else: + raise RuntimeError(f"oops {a}") + + torchfunc = functools.partial(interop.call_jax, func) + # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor + torch.library.impl(opname, "privateuseone")( + torchfunc if a != torch.ops.aten._to_copy else func + ) + return func + + return inner @op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, ) def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) + return jnp.reshape(x, shape) @op(torch.ops.aten.add.Tensor) 
@op(torch.ops.aten.add.Scalar) def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - res = x + y * alpha - if isinstance(x, float) or isinstance(y, float): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = res.astype(new_dtype) - return res + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + res = x + y * alpha + if isinstance(x, float) or isinstance(y, float): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = res.astype(new_dtype) + return res @op( - torch.ops.aten.copy_, is_jax_function=False, is_view_op=True, needs_env=True + torch.ops.aten.copy_, is_jax_function=False, is_view_op=True, needs_env=True ) def _aten_copy(x, y, memory_format=None, env=None): - if y.device.type == "cpu": - y = env.to_xla(y) - - if isinstance(x, View): - x.update(y) - return x - - if x.ndim == 1 and y.ndim == 0: - # case of torch.empty((1,)).copy_(tensor(N)) - # we need to return 0D tensor([N]) and not scalar tensor(N) - # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 - x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) - else: - x._elem = y._elem.astype(x._elem.dtype) + if y.device.type == "cpu": + y = env.to_xla(y) + + if isinstance(x, View): + x.update(y) return x + if x.ndim == 1 and y.ndim == 0: + # case of torch.empty((1,)).copy_(tensor(N)) + # we need to return 0D tensor([N]) and not scalar tensor(N) + # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 + x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) + else: + x._elem = y._elem.astype(x._elem.dtype) + return x + @op(torch.ops.aten.clone) def _aten_clone(x, memory_format=None): - return x + return x # aten.trunc @op(torch.ops.aten.trunc) def _aten_trunc(x): - res = jnp.trunc(x) - return res.astype(x) + res = jnp.trunc(x) + return res.astype(x) @op(torch.ops.aten.index_copy) def _aten_index_copy(x, dim, indexes, source): - if x.ndim == 0: - return source - if x.ndim == 1: - source = jnp.squeeze(source) - # return jax.lax.scatter(x, index, dim) - if dim < 0: - dim = dim + x.ndim - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x.at[tuple(dims)].set(source) + if x.ndim == 0: + return source + if x.ndim == 1: + source = jnp.squeeze(source) + # return jax.lax.scatter(x, index, dim) + if dim < 0: + dim = dim + x.ndim + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x.at[tuple(dims)].set(source) # aten.cauchy_ @op(torch.ops.aten.cauchy_) def _aten_cauchy_(x, median=0, sigma=1): - """ - Fills the input array with values drawn from a Cauchy distribution. + """ + Fills the input array with values drawn from a Cauchy distribution. - Args: - x: An array to be filled with Cauchy samples. - median: The median of the Cauchy distribution. - sigma: The scale parameter of the Cauchy distribution. + Args: + x: An array to be filled with Cauchy samples. + median: The median of the Cauchy distribution. + sigma: The scale parameter of the Cauchy distribution. - Returns: - The input array filled with Cauchy samples. 
- """ - key = jax.random.PRNGKey(0) # You should use a different key for each call - samples = jax.random.cauchy(key, x.shape) * sigma + median - return x.at[:].set(samples) + Returns: + The input array filled with Cauchy samples. + """ + key = jax.random.PRNGKey(0) # You should use a different key for each call + samples = jax.random.cauchy(key, x.shape) * sigma + median + return x.at[:].set(samples) @op(torch.ops.aten.atleast_2d) def _aten_atleast_2d(inputs): - return jnp.atleast_2d(inputs) + return jnp.atleast_2d(inputs) @op(torch.ops.aten.atleast_1d) def _aten_atleast_1d(inputs): - return jnp.atleast_1d(inputs) + return jnp.atleast_1d(inputs) # aten.complex @op(torch.ops.aten.complex) def _aten_complex(real, imag): - """ - Constructs a complex array from real and imaginary parts. + """ + Constructs a complex array from real and imaginary parts. - Args: - real: An array of real values. - imag: An array of imaginary values. + Args: + real: An array of real values. + imag: An array of imaginary values. - Returns: - A complex array with the specified real and imaginary parts. - """ - return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array( - imag, dtype=jnp.float32 - ) + Returns: + A complex array with the specified real and imaginary parts. + """ + return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array( + imag, dtype=jnp.float32 + ) # aten.exponential_ @op(torch.ops.aten.exponential_) def _aten_exponential_(x, lambd=1.0): - """ - Fills the input array with values drawn from an exponential distribution. + """ + Fills the input array with values drawn from an exponential distribution. - Args: - x: An array to be filled with exponential samples. - lambd: The rate parameter of the exponential distribution. + Args: + x: An array to be filled with exponential samples. + lambd: The rate parameter of the exponential distribution. - Returns: - The input array filled with exponential samples. - """ - key = jax.random.PRNGKey(0) # Use a different key for each call - samples = jax.random.exponential(key, x.shape) / lambd - return x.at[:].set(samples) + Returns: + The input array filled with exponential samples. + """ + key = jax.random.PRNGKey(0) # Use a different key for each call + samples = jax.random.exponential(key, x.shape) / lambd + return x.at[:].set(samples) # aten.linalg_householder_product @op(torch.ops.aten.linalg_householder_product) def _aten_linalg_householder_product(input, tau): - return jax.lax.linalg.householder_product(a=input, taus=tau) + return jax.lax.linalg.householder_product(a=input, taus=tau) @op(torch.ops.aten.select) def _aten_select(x, dim, indexes): - return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) + return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) @op(torch.ops.aten.index_select) @op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, index): - if x.shape == (): - return x - return jnp.take(x, index, dim) + if x.shape == (): + return x + return jnp.take(x, index, dim) @op(torch.ops.aten.cholesky) def _aten_cholesky(input, upper=False): - return jax.scipy.linalg.cholesky(input, lower=(not upper)) + return jax.scipy.linalg.cholesky(input, lower=(not upper)) @op(torch.ops.aten.linalg_cholesky_ex) def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False): - if check_errors: - raise NotImplementedError( - "check_errors=True is not supported in this JAX implementation. " - "Check for positive definiteness using jnp.linalg.eigvalsh before " - "calling this function." 
- ) + if check_errors: + raise NotImplementedError( + "check_errors=True is not supported in this JAX implementation. " + "Check for positive definiteness using jnp.linalg.eigvalsh before " + "calling this function." + ) - L = jax.scipy.linalg.cholesky(input, lower=not upper) - if len(L.shape) > 2: - info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) - else: - info = jnp.array(0, dtype=jnp.int32) - return L, info + L = jax.scipy.linalg.cholesky(input, lower=not upper) + if len(L.shape) > 2: + info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) + else: + info = jnp.array(0, dtype=jnp.int32) + return L, info @op(torch.ops.aten.cholesky_solve) def _aten_cholesky_solve(input, input2, upper=False): - # Ensure input2 is lower triangular for cho_solve - L = input2 if not upper else input2.T - # Use cho_solve to solve the linear system - solution = jax.scipy.linalg.cho_solve((L, True), input) - return solution + # Ensure input2 is lower triangular for cho_solve + L = input2 if not upper else input2.T + # Use cho_solve to solve the linear system + solution = jax.scipy.linalg.cho_solve((L, True), input) + return solution @op(torch.ops.aten.special_zeta) def _aten_special_zeta(x, q): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jax.scipy.special.zeta(x, q) - if isinstance(x, int) or isinstance(q, int): - res = res.astype(new_dtype) - return res # jax.scipy.special.zeta(x, q) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jax.scipy.special.zeta(x, q) + if isinstance(x, int) or isinstance(q, int): + res = res.astype(new_dtype) + return res # jax.scipy.special.zeta(x, q) # aten.igammac @op(torch.ops.aten.igammac) def _aten_igammac(input, other): - if isinstance(input, jnp.ndarray): - input = jnp.where(input < 0, jnp.nan, input) - if isinstance(other, jnp.ndarray): - other = jnp.where(other < 0, jnp.nan, other) - else: - if (input == 0 and other == 0) or (input < 0) or (other < 0): - other = jnp.nan - return jnp.array(jax.scipy.special.gammaincc(input, other)) + if isinstance(input, jnp.ndarray): + input = jnp.where(input < 0, jnp.nan, input) + if isinstance(other, jnp.ndarray): + other = jnp.where(other < 0, jnp.nan, other) + else: + if (input == 0 and other == 0) or (input < 0) or (other < 0): + other = jnp.nan + return jnp.array(jax.scipy.special.gammaincc(input, other)) @op(torch.ops.aten.mean) def _aten_mean(x, dim=None, keepdim=False): - if x.shape == () and dim is not None: - dim = None # disable dim for jax array without dim - return jnp.mean(x, dim, keepdims=keepdim) + if x.shape == () and dim is not None: + dim = None # disable dim for jax array without dim + return jnp.mean(x, dim, keepdims=keepdim) def _torch_binary_scalar_type(scalar, tensor): - if "float" in str(tensor.dtype) or "complex" in str(tensor.dtype): - return tensor.dtype + if "float" in str(tensor.dtype) or "complex" in str(tensor.dtype): + return tensor.dtype - if isinstance(scalar, int): - if "int" in str(tensor.dtype): - return tensor.dtype + if isinstance(scalar, int): + if "int" in str(tensor.dtype): + return tensor.dtype - return jnp.float32 + return jnp.float32 @op(torch.ops.aten.searchsorted.Tensor) def _aten_searchsorted(sorted_sequence, values): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jnp.searchsorted(sorted_sequence, values) - if sorted_sequence.dtype == np.dtype( - np.int32 - ) or sorted_sequence.dtype == np.dtype(np.int32): - # res = res.astype(new_dtype) - res = res.astype(np.dtype(np.int64)) - return res # jnp.searchsorted(sorted_sequence, 
values) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = jnp.searchsorted(sorted_sequence, values) + if sorted_sequence.dtype == np.dtype( + np.int32 + ) or sorted_sequence.dtype == np.dtype(np.int32): + # res = res.astype(new_dtype) + res = res.astype(np.dtype(np.int64)) + return res # jnp.searchsorted(sorted_sequence, values) @op(torch.ops.aten.sub.Tensor) @op(torch.ops.aten.sub.Scalar) def _aten_sub(x, y, alpha=1): - if isinstance(x, float): - dtype = _torch_binary_scalar_type(x, y) - x = jnp.array(x, dtype=dtype) - if isinstance(y, float): - dtype = _torch_binary_scalar_type(y, x) - y = jnp.array(y, dtype=dtype) - return x - y * alpha + if isinstance(x, float): + dtype = _torch_binary_scalar_type(x, y) + x = jnp.array(x, dtype=dtype) + if isinstance(y, float): + dtype = _torch_binary_scalar_type(y, x) + y = jnp.array(y, dtype=dtype) + return x - y * alpha @op(torch.ops.aten.numpy_T) def _aten_numpy_T(input): - """ - Jax implementation of torch.numpy_T. + """ + Jax implementation of torch.numpy_T. - Args: - input: JAX array. + Args: + input: JAX array. - Returns: - Transposed JAX array. - """ - return jnp.transpose(input) + Returns: + Transposed JAX array. + """ + return jnp.transpose(input) @op(torch.ops.aten.mm) def _aten_mm(x, y): - res = x @ y - return res + res = x @ y + return res @op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) def _aten_mul(x, y): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = x * y - if isinstance(x, float) or isinstance(y, float): + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + res = x * y + if isinstance(x, float) or isinstance(y, float): + res = res.astype(new_dtype) + else: + if (not isinstance(x, int)) and (not isinstance(y, int)): + if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64): res = res.astype(new_dtype) - else: - if (not isinstance(x, int)) and (not isinstance(y, int)): - if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype( - np.float64 - ): - res = res.astype(new_dtype) - return res + return res @op(torch.ops.aten.silu) @op(torch.ops.aten.silu.default) def _aten_silu(x): - return jax.nn.silu(x) + return jax.nn.silu(x) @op(torch.ops.aten.t) def _aten_t(x): - return jnp.transpose(x) + return jnp.transpose(x) @op(torch.ops.aten.transpose) @op(torch.ops.aten.transpose_copy) def _aten_transpose(x, dim0, dim1): - if x.ndim == 0: - return x - dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim - dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim - return jnp.swapaxes(x, dim0, dim1) + if x.ndim == 0: + return x + dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim + dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim + return jnp.swapaxes(x, dim0, dim1) @op(torch.ops.aten.triu) def _aten_triu(m, k=0): - return jnp.triu(m, k) + return jnp.triu(m, k) @op(torch.ops.aten.slice) @op(torch.ops.aten.slice_copy) def _aten_slice(self, dim=0, start=None, end=None, step=1): - if dim < 0: - dim += self.ndim - if end == sys.maxsize: - end = self.shape[dim] - sl = slice(start, end, step) - dims = [] - for i in range(len(self.shape)): - if i == dim: - dims.append(sl) - else: - dims.append(slice(None, None, None)) - return self[tuple(dims)] + if dim < 0: + dim += self.ndim + if end == sys.maxsize: + end = self.shape[dim] + sl = slice(start, end, step) + dims = [] + for i in range(len(self.shape)): + if i == dim: + dims.append(sl) + else: + dims.append(slice(None, None, None)) + return self[tuple(dims)] @op(torch.ops.aten.detach) def _aten_detach(self): - return self + return self @op(torch.ops.aten.imag) def 
_aten_imag(x): - return jnp.imag(x) + return jnp.imag(x) @op(torch.ops.aten.isfinite) def _aten_isfinite(x): - return jnp.isfinite(x) + return jnp.isfinite(x) @op(torch.ops.aten.real) def _aten_real(x): - return jnp.real(x) + return jnp.real(x) @op(torch.Tensor.resize_) def _aten_resize_(x, size, interpolation="linear"): - new_size = tuple(size) - return jax.numpy.resize(x, new_size) + new_size = tuple(size) + return jax.numpy.resize(x, new_size) @op(torch.ops.aten.resize_as_) def _aten_resize_as_(x, y): - return jax.numpy.resize(x, y.shape) + return jax.numpy.resize(x, y.shape) @op(torch.ops.aten.repeat_interleave.Tensor) def repeat_interleave(repeats, dim=0): - return jnp.repeat(np.arange(repeats.shape[dim]), repeats) + return jnp.repeat(np.arange(repeats.shape[dim]), repeats) @op(torch.ops.aten.repeat_interleave.self_int) @op(torch.ops.aten.repeat_interleave.self_Tensor) def repeat_interleave(self, repeats, dim=0): - total_repeat_length = None - if isinstance(repeats, int): - total_repeat_length = self.shape[dim] * repeats - repeats = np.array([repeats] * self.shape[dim]) - return jnp.repeat( - self, repeats, dim, total_repeat_length=total_repeat_length - ) + total_repeat_length = None + if isinstance(repeats, int): + total_repeat_length = self.shape[dim] * repeats + repeats = np.array([repeats] * self.shape[dim]) + return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length) @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): - real = jnp.real(x) - im = jnp.imag(x) - res = jnp.stack([real, im], -1) - return res + real = jnp.real(x) + im = jnp.imag(x) + res = jnp.stack([real, im], -1) + return res @op(torch.ops.aten.stack) def _aten_stack(tensors, dim=0): - return jnp.stack(tensors, dim) + return jnp.stack(tensors, dim) @op(torch.ops.aten._softmax) @op(torch.ops.aten.softmax) @op(torch.ops.aten.softmax.int) def _aten_softmax(x, dim, halftofloat=False): - if x.shape == (): - return jax.nn.softmax(x.reshape([1]), axis=0).reshape([]) - return jax.nn.softmax(x, dim) + if x.shape == (): + return jax.nn.softmax(x.reshape([1]), axis=0).reshape([]) + return jax.nn.softmax(x, dim) def _is_int(x): - if isinstance(x, int): - return True - if isinstance(x, jax.Array) and ( - x.dtype.name.startswith("int") or x.dtype.name.startswith("uint") - ): - return True - return False + if isinstance(x, int): + return True + if isinstance(x, jax.Array) and ( + x.dtype.name.startswith("int") or x.dtype.name.startswith("uint") + ): + return True + return False def highest_precision_int_dtype(tensor1, tensor2): - if isinstance(tensor1, int): - return tensor2.dtype - if isinstance(tensor2, int): - return tensor1.dtype - - dtype_hierarchy = { - "uint8": 8, - "int8": 8, - "uint16": 16, - "int16": 16, - "uint32": 32, - "int32": 32, - "uint64": 64, - "int64": 64, - } - return max( - tensor1.dtype, - tensor2.dtype, - key=lambda dtype: dtype_hierarchy[str(dtype)], - ) + if isinstance(tensor1, int): + return tensor2.dtype + if isinstance(tensor2, int): + return tensor1.dtype + + dtype_hierarchy = { + "uint8": 8, + "int8": 8, + "uint16": 16, + "int16": 16, + "uint32": 32, + "int32": 32, + "uint64": 64, + "int64": 64, + } + return max( + tensor1.dtype, + tensor2.dtype, + key=lambda dtype: dtype_hierarchy[str(dtype)], + ) @op(torch.ops.aten.pow) def _aten_pow(x, y): - y_orig = y - if isinstance(y, int): - y = float(y) - if _is_int(x) and _is_int(y_orig): - # Do the math in float then cast - res = jnp.power(jnp.astype(x, jnp.dtype("float")), y) - return res.astype(highest_precision_int_dtype(x, 
y_orig)) - res = jnp.power(x, y) - if isinstance(x, float): - return res.astype(_torch_binary_scalar_type(x, y_orig)) - if isinstance(y_orig, float): - return res.astype(_torch_binary_scalar_type(y_orig, x)) - return res + y_orig = y + if isinstance(y, int): + y = float(y) + if _is_int(x) and _is_int(y_orig): + # Do the math in float then cast + res = jnp.power(jnp.astype(x, jnp.dtype("float")), y) + return res.astype(highest_precision_int_dtype(x, y_orig)) + res = jnp.power(x, y) + if isinstance(x, float): + return res.astype(_torch_binary_scalar_type(x, y_orig)) + if isinstance(y_orig, float): + return res.astype(_torch_binary_scalar_type(y_orig, x)) + return res @op(torch.ops.aten.view_as_complex) def _aten_view_as_complex(input): - if input.dtype == jnp.bfloat16: - input = input.astype(jnp.float32) - x, y = input[..., 0], input[..., 1] - return jax.lax.complex(x, y) + if input.dtype == jnp.bfloat16: + input = input.astype(jnp.float32) + x, y = input[..., 0], input[..., 1] + return jax.lax.complex(x, y) @op(torch.ops.aten.div) def _aten_div(x, y, rounding_mode=""): - res_dtype = None - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype("float32") + res_dtype = None + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("float32") - if isinstance(x, float) or isinstance(y, float): - res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if isinstance(x, float) or isinstance(y, float): + res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if rounding_mode == "floor": - res = jnp.floor_divide(x, y) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype("int64") - else: - res = x / y - if rounding_mode == "trunc": - res = jnp.trunc(res) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype("int64") - if res_dtype: - res = res.astype(res_dtype) - return res + if rounding_mode == "floor": + res = jnp.floor_divide(x, y) + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("int64") + else: + res = x / y + if rounding_mode == "trunc": + res = jnp.trunc(res) + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype("int64") + if res_dtype: + res = res.astype(res_dtype) + return res @op(torch.ops.aten.true_divide) def _aten_true_divide(x, y): - return x / y + return x / y @op(torch.ops.aten.dist) def _aten_dist(input, other, p=2): - diff = jnp.abs(jnp.subtract(input, other)) - return _aten_linalg_vector_norm(diff, ord=p) + diff = jnp.abs(jnp.subtract(input, other)) + return _aten_linalg_vector_norm(diff, ord=p) @op(torch.ops.aten.bmm) def _aten_bmm(x, y): - res = x @ y - return res - # return jnp.einsum('bnm,bmk->bnk', x, y) + res = x @ y + return res + # return jnp.einsum('bnm,bmk->bnk', x, y) @op(torch.ops.aten.embedding) # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) def _aten_embedding( - a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False + a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False ): - return jnp.take(a, w, axis=0) + return jnp.take(a, w, axis=0) @op(torch.ops.aten.embedding_renorm_) def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): - # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp - unique_indices = jnp.unique(indices) + # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp + unique_indices = jnp.unique(indices) - norm = jnp.linalg.norm( - _aten_embedding(weight, unique_indices), - ord=norm_type, - axis=1, - ) + norm = jnp.linalg.norm( + 
_aten_embedding(weight, unique_indices), + ord=norm_type, + axis=1, + ) - indice_idx = jnp.where(norm > max_norm) + indice_idx = jnp.where(norm > max_norm) - scale = max_norm / (norm[indice_idx] + 1e-7) + scale = max_norm / (norm[indice_idx] + 1e-7) - indices_to_update = unique_indices[indice_idx] + indices_to_update = unique_indices[indice_idx] - weight = weight.at[indices_to_update].set( - weight[indices_to_update] * scale[:, None] - ) - return weight + weight = weight.at[indices_to_update].set( + weight[indices_to_update] * scale[:, None] + ) + return weight # - func: _embedding_bag_forward_only( @@ -588,236 +584,236 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): @op(torch.ops.aten._embedding_bag) @op(torch.ops.aten._embedding_bag_forward_only) def _aten__embedding_bag( - weight, - indices, - offsets=None, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=-1, + weight, + indices, + offsets=None, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, ): - """Jax implementation of the PyTorch _embedding_bag function. - - Args: - weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). - indices: A LongTensor containing the indices to extract. - offsets: A LongTensor containing the starting offset of each bag. - scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch. - mode: 0 = "sum", 1 = "mean" or 2 = "max" - sparse: Whether the gradients with respect to weight should be a sparse tensor. - per_sample_weights: If given, each embedding vector is weighted by per_sample_weights - include_last_offset: Whether to include the last offset as a valid bag. - padding_idx: If specified, the entries at padding_idx do not contribute to the gradient. - - Returns: - A tuple of (output, offset2bag, bag_size, max_indices). - """ - embedded = _aten_embedding(weight, indices, padding_idx) - - if offsets is None: - # offsets is None only when indices.ndim > 1 - if mode == 0: # sum - output = jnp.sum(embedded, axis=1) - elif mode == 1: # mean - output = jnp.mean(embedded, axis=1) - elif mode == 2: # max - output = jnp.max(embedded, axis=1) - return output, None, None, None - - if isinstance(offsets, jax.Array): - offsets_np = np.array(offsets) - else: - offsets_np = offsets - offset2bag = np.zeros(indices.shape[0], dtype=np.int64) - bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64) - max_indices = jnp.full_like(indices, -1) - - for bag in range(offsets_np.shape[0]): - start = int(offsets_np[bag]) - - end = int( - indices.shape[0] - if bag + 1 == offsets_np.shape[0] - else offsets_np[bag + 1] + """Jax implementation of the PyTorch _embedding_bag function. + + Args: + weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). + indices: A LongTensor containing the indices to extract. + offsets: A LongTensor containing the starting offset of each bag. + scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch. + mode: 0 = "sum", 1 = "mean" or 2 = "max" + sparse: Whether the gradients with respect to weight should be a sparse tensor. + per_sample_weights: If given, each embedding vector is weighted by per_sample_weights + include_last_offset: Whether to include the last offset as a valid bag. + padding_idx: If specified, the entries at padding_idx do not contribute to the gradient. 
+ + Returns: + A tuple of (output, offset2bag, bag_size, max_indices). + """ + embedded = _aten_embedding(weight, indices, padding_idx) + + if offsets is None: + # offsets is None only when indices.ndim > 1 + if mode == 0: # sum + output = jnp.sum(embedded, axis=1) + elif mode == 1: # mean + output = jnp.mean(embedded, axis=1) + elif mode == 2: # max + output = jnp.max(embedded, axis=1) + return output, None, None, None + + if isinstance(offsets, jax.Array): + offsets_np = np.array(offsets) + else: + offsets_np = offsets + offset2bag = np.zeros(indices.shape[0], dtype=np.int64) + bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64) + max_indices = jnp.full_like(indices, -1) + + for bag in range(offsets_np.shape[0]): + start = int(offsets_np[bag]) + + end = int( + indices.shape[0] + if bag + 1 == offsets_np.shape[0] + else offsets_np[bag + 1] + ) + bag_size[bag] = end - start + offset2bag = offset2bag.at[start:end].set(bag) + + if end - start > 0: + if mode == 0: + output_bag = jnp.sum(embedded[start:end], axis=0) + elif mode == 1: + output_bag = jnp.mean(embedded[start:end], axis=0) + elif mode == 2: + output_bag = jnp.max(embedded[start:end], axis=0) + max_indices = max_indices.at[start:end].set( + jnp.argmax(embedded[start:end], axis=0) ) - bag_size[bag] = end - start - offset2bag = offset2bag.at[start:end].set(bag) - if end - start > 0: - if mode == 0: - output_bag = jnp.sum(embedded[start:end], axis=0) - elif mode == 1: - output_bag = jnp.mean(embedded[start:end], axis=0) - elif mode == 2: - output_bag = jnp.max(embedded[start:end], axis=0) - max_indices = max_indices.at[start:end].set( - jnp.argmax(embedded[start:end], axis=0) - ) + # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. + # Converting them to JAX arrays for consistency. + offset2bag = jnp.array(offset2bag) + bag_size = jnp.array(bag_size) - # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. - # Converting them to JAX arrays for consistency. - offset2bag = jnp.array(offset2bag) - bag_size = jnp.array(bag_size) - - return output_bag, offset2bag, bag_size, max_indices + return output_bag, offset2bag, bag_size, max_indices @op(torch.ops.aten.rsqrt) @op_base.promote_int_input def _aten_rsqrt(x): - return jax.lax.rsqrt(x) + return jax.lax.rsqrt(x) @op(torch.ops.aten.expand) @op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): - def fix_dims(d, xs): - if d == -1: - return xs - return d - - shape = list(x.shape) - if len(shape) < len(dims): - shape = [ - 1, - ] * (len(dims) - len(shape)) + shape - # make sure that dims and shape is the same by - # left pad with 1s. Otherwise the zip below will - # truncate - dims = [fix_dims(p, s) for p, s in zip(dims, shape)] - return jnp.broadcast_to(x, dims) + def fix_dims(d, xs): + if d == -1: + return xs + return d + + shape = list(x.shape) + if len(shape) < len(dims): + shape = [ + 1, + ] * (len(dims) - len(shape)) + shape + # make sure that dims and shape is the same by + # left pad with 1s. 
Otherwise the zip below will + truncate + dims = [fix_dims(p, s) for p, s in zip(dims, shape)] + return jnp.broadcast_to(x, dims) @op(torch.ops.aten.dot) def _aten_dot(x, y): - return jnp.dot(x, y) + return jnp.dot(x, y) @op(torch.ops.aten._to_copy) def _aten__to_copy(self, **kwargs): - dtype = mappings.t2j_dtype(kwargs["dtype"]) - if dtype != self.dtype: - return self.astype(dtype) - return jnp.copy(self) + dtype = mappings.t2j_dtype(kwargs["dtype"]) + if dtype != self.dtype: + return self.astype(dtype) + return jnp.copy(self) @op(torch.ops.aten.empty) @op_base.convert_dtype(use_default_dtype=False) def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs): - return jnp.empty(size, dtype=dtype) + return jnp.empty(size, dtype=dtype) @op(torch.ops.aten.empty_like) @op_base.convert_dtype(use_default_dtype=False) def _aten_empty_like(input, *, dtype=None, **kwargs): - return jnp.empty_like(input, dtype) + return jnp.empty_like(input, dtype) @op(torch.ops.aten.ones) @op_base.convert_dtype() def _ones(size: Sequence[int], dtype=None, **kwargs): - return jnp.ones(size, dtype) + return jnp.ones(size, dtype) @op(torch.ops.aten.zeros) @op_base.convert_dtype() def _zeros(size: Sequence[int], dtype=None, **kwargs): - return jnp.zeros(size, dtype) + return jnp.zeros(size, dtype) @op(torch.ops.aten.full) @op_base.convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) @op(torch.ops.aten.empty_permuted) @op_base.convert_dtype() def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): - # Ignore the physical layout, - # since JAX and torch tensors don't share the same memory. - return jnp.empty(sizes, dtype=dtype) + # Ignore the physical layout, + # since JAX and torch tensors don't share the same memory. + return jnp.empty(sizes, dtype=dtype) @op(torch.ops.aten.empty_strided) @op_base.convert_dtype() def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensors don't share the same memory. - return jnp.empty(sizes, dtype=dtype) + # Ignore stride, since JAX and torch tensors don't share the same memory. + return jnp.empty(sizes, dtype=dtype) @op(torch.ops.aten.index_put_) @op(torch.ops.aten.index_put) def _aten_index_put(self, indexes, values, accumulate=False): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - if accumulate: - return self.at[indexes].add(values) - else: - return self.at[indexes].set(values) + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + if accumulate: + return self.at[indexes].add(values) + else: + return self.at[indexes].set(values) @op(torch.ops.aten.index) @op(torch.ops.aten._unsafe_index) @op(torch.ops.aten.index.Tensor) def _aten_index(self, indexes): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - return self[indexes] + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self[indexes] @op(torch.ops.aten.split) @op(torch.ops.aten.split_copy) @op(torch.ops.aten.split_with_sizes) def split_with_sizes(x, sizes, dim=0): - """Splits an array `x` into sub-arrays based on static sizes `sizes`. + """Splits an array `x` into sub-arrays based on static sizes `sizes`. - Args: - x: The input array to split.
- sizes: A 1D array of integer sizes for each sub-array. + Args: + x: The input array to split. + sizes: A 1D array of integer sizes for each sub-array. - Returns: - A list of sub-arrays. - """ - if isinstance(sizes, int): - # split equal size, round up - new_sizes = [sizes] * (-(-x.shape[dim] // sizes)) - sizes = new_sizes - rank = x.ndim - splits = np.cumsum(sizes) # Cumulative sum for split points - - def make_range(rank, dim, start, end): - res = [slice(None, None, None)] * rank - res[dim] = slice(start, end) - return tuple(res) - - return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) - ] + Returns: + A list of sub-arrays. + """ + if isinstance(sizes, int): + # split equal size, round up + new_sizes = [sizes] * (-(-x.shape[dim] // sizes)) + sizes = new_sizes + rank = x.ndim + splits = np.cumsum(sizes) # Cumulative sum for split points + + def make_range(rank, dim, start, end): + res = [slice(None, None, None)] * rank + res[dim] = slice(start, end) + return tuple(res) + + return [ + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) + ] @op(torch.ops.aten.permute) @op(torch.ops.aten.permute_copy) def permute(t, dims): - # TODO: return a View instead - return jnp.transpose(t, dims) + # TODO: return a View instead + return jnp.transpose(t, dims) @op(torch.ops.aten.unsqueeze) @op(torch.ops.aten.unsqueeze_copy) def _aten_unsqueeze(self, dim): - if dim < 0: - dim += self.ndim + 1 - return jnp.expand_dims(self, dim) + if dim < 0: + dim += self.ndim + 1 + return jnp.expand_dims(self, dim) @op(torch.ops.aten.ne) def _aten_ne(x, y): - return jnp.not_equal(x, y) + return jnp.not_equal(x, y) # Create indices along a specific axis @@ -831,643 +827,634 @@ def _aten_ne(x, y): # _indices_along_axis(x, axis=1) # >> [[0, 1, 2, 3]] shape (1, 4) def _indices_along_axis(x, axis): - return jnp.expand_dims( - jnp.arange(x.shape[axis]), - axis=[d for d in range(len(x.shape)) if d != axis], - ) + return jnp.expand_dims( + jnp.arange(x.shape[axis]), + axis=[d for d in range(len(x.shape)) if d != axis], + ) def _broadcast_indices(indices, shape): - return jnp.broadcast_to(indices, shape) + return jnp.broadcast_to(indices, shape) @op(torch.ops.aten.cummax) def _aten_cummax(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) + if not x.shape: + return x, jnp.zeros_like(x, dtype=jnp.int64) - axis = dim + axis = dim - indice_along_axis = _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) + indice_along_axis = _indices_along_axis(x, axis) + indices = _broadcast_indices(indice_along_axis, x.shape) - def cummax_reduce_func(carry, elem): - v1, v2 = carry["val"], elem["val"] - i1, i2 = carry["idx"], elem["idx"] + def cummax_reduce_func(carry, elem): + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] - v = jnp.maximum(v1, v2) - i = jnp.where(v1 > v2, i1, i2) - return {"val": v, "idx": i} + v = jnp.maximum(v1, v2) + i = jnp.where(v1 > v2, i1, i2) + return {"val": v, "idx": i} - res = jax.lax.associative_scan( - cummax_reduce_func, {"val": x, "idx": indices}, axis=axis - ) - return res["val"], res["idx"] + res = jax.lax.associative_scan( + cummax_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cummin) def _aten_cummin(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) + if not x.shape: + return x, jnp.zeros_like(x, dtype=jnp.int64) - axis = dim + axis = dim - 
indice_along_axis = _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) + indice_along_axis = _indices_along_axis(x, axis) + indices = _broadcast_indices(indice_along_axis, x.shape) - def cummin_reduce_func(carry, elem): - v1, v2 = carry["val"], elem["val"] - i1, i2 = carry["idx"], elem["idx"] + def cummin_reduce_func(carry, elem): + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] - v = jnp.minimum(v1, v2) - i = jnp.where(v1 < v2, i1, i2) - return {"val": v, "idx": i} + v = jnp.minimum(v1, v2) + i = jnp.where(v1 < v2, i1, i2) + return {"val": v, "idx": i} - res = jax.lax.associative_scan( - cummin_reduce_func, {"val": x, "idx": indices}, axis=axis - ) - return res["val"], res["idx"] + res = jax.lax.associative_scan( + cummin_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cumsum) def _aten_cumsum(x, y, dtype=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if not x.shape: - return x - res = jnp.cumsum(x, y, dtype) - return res + if dtype: + dtype = mappings.t2j_dtype(dtype) + if not x.shape: + return x + res = jnp.cumsum(x, y, dtype) + return res @op(torch.ops.aten.cumprod) def _aten_cumprod(input, dim, dtype=None, out=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if len(input.shape) > 0: - res = jnp.cumprod(input, axis=dim, dtype=dtype) - elif dtype: - res = input.astype(dtype) - else: - res = input - return res + if dtype: + dtype = mappings.t2j_dtype(dtype) + if len(input.shape) > 0: + res = jnp.cumprod(input, axis=dim, dtype=dtype) + elif dtype: + res = input.astype(dtype) + else: + res = input + return res @op(torch.ops.aten.native_layer_norm) def _aten_native_layer_norm( - input, normalized_shape, weight=None, bias=None, eps=1e-5 + input, normalized_shape, weight=None, bias=None, eps=1e-5 ): - """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. - - Args: - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - output: The normalized tensor. - mean: The calculated mean tensor. - std: The calculated standard deviation tensor. - """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - axis = [len(input.shape) - i - 1 for i in range(len(normalized_shape))] - - # Calculate mean and standard deviation - mean = jnp.mean(input, axis=axis, keepdims=True) - var = jnp.var(input, axis=axis, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) - - # Normalize the input - norm_x = (input - mean) * rstd - - # Apply affine transformation (if provided) - if weight is not None: - norm_x *= weight - if bias is not None: - norm_x += bias - return norm_x, mean, rstd + """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. + + Args: + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + output: The normalized tensor. + mean: The calculated mean tensor. + std: The calculated standard deviation tensor. 
+ """ + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + axis = [len(input.shape) - i - 1 for i in range(len(normalized_shape))] + + # Calculate mean and standard deviation + mean = jnp.mean(input, axis=axis, keepdims=True) + var = jnp.var(input, axis=axis, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) + + # Normalize the input + norm_x = (input - mean) * rstd + + # Apply affine transformation (if provided) + if weight is not None: + norm_x *= weight + if bias is not None: + norm_x += bias + return norm_x, mean, rstd @op(torch.ops.aten.matmul) def _aten_matmul(x, y): - return x @ y + return x @ y # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor @op(torch.ops.aten.addmm) @op(torch.ops.aten.addmv) def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) - return self + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) + return self @op(torch.ops.aten.sparse_sampled_addmm) def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) * (self != 0) - return self + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) * (self != 0) + return self @op(torch.ops.aten.addbmm.default) def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): - alpha = jnp.array(alpha).astype(batch1.dtype) - beta = jnp.array(beta).astype(batch1.dtype) - mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond( - beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm - ) + alpha = jnp.array(alpha).astype(batch1.dtype) + beta = jnp.array(beta).astype(batch1.dtype) + mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) @op(torch.ops.aten.gelu) def _aten_gelu(self, *, approximate="none"): - approx = approximate == "tanh" - return jax.nn.gelu(self, approx) + approx = approximate == "tanh" + return jax.nn.gelu(self, approx) @op(torch.ops.aten.squeeze) @op(torch.ops.aten.squeeze_copy) def _aten_squeeze_dim(self, dim=None): - if self.ndim == 0: + if self.ndim == 0: + return self + if dim is not None: + if isinstance(dim, int): + if self.shape[dim] != 1: return self - if dim is not None: - if isinstance(dim, int): - if self.shape[dim] != 1: - return self - if dim < 0: - dim += self.ndim - else: - # NOTE: torch leaves the dims that is not 1 unchanged, - # but jax raises error. - dim = [ - i if i >= 0 else (i + self.ndim) - for i in dim - if self.shape[i] == 1 - ] - - return jnp.squeeze(self, dim) + if dim < 0: + dim += self.ndim + else: + # NOTE: torch leaves the dims that is not 1 unchanged, + # but jax raises error. 
+ dim = [ + i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1 + ] + + return jnp.squeeze(self, dim) @op(torch.ops.aten.bucketize) def _aten_bucketize( - input, boundaries, *, out_int32=False, right=False, out=None + input, boundaries, *, out_int32=False, right=False, out=None ): - return_type = jnp.int32 if out_int32 else jnp.int64 - return jnp.digitize(input, boundaries, right=not right).astype(return_type) + return_type = jnp.int32 if out_int32 else jnp.int64 + return jnp.digitize(input, boundaries, right=not right).astype(return_type) @op(torch.ops.aten.conv2d) def _aten_conv2d( + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + return _aten_convolution( input, weight, bias, stride, padding, dilation, - groups, -): - return _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed=False, - output_padding=1, - groups=groups, - ) + transposed=False, + output_padding=1, + groups=groups, + ) @op(torch.ops.aten.convolution) def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, ): - num_shape_dim = weight.ndim - 1 - batch_dims = input.shape[:-num_shape_dim] - - input = input.reshape((-1, *input.shape[-num_shape_dim:])) - - def make_padding(padding, num_spatial_dims): - # Expand single padding to pairs expected by jax - if len(padding) == 1 and len(padding) < num_spatial_dims: - padding *= num_spatial_dims - if transposed: - # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html - pad_out = [] - for i in range(num_spatial_dims): - front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i] - back = front + output_padding[i] - pad_out.append((front, back)) - return pad_out - else: - return ((p, p) for p in padding) - - def create_default_conv_dimension_numbers(num_spatial_dims): - # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 - # (batch dimension, feature dimension, spatial dimensions...) - lhs_spec = [0, 1] - # (out feature dimension, in feature dimension, spatial dimensions...) - # swapped for transposed convolution - rhs_spec = [1, 0] if transposed else [0, 1] - # (batch dimension, feature dimension, spatial dimensions...) 
- out_spec = [0, 1] - for i in range(0, num_spatial_dims): - lhs_spec.append(i + 2) - rhs_spec.append(i + 2) - out_spec.append(i + 2) - return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec)) - ) + num_shape_dim = weight.ndim - 1 + batch_dims = input.shape[:-num_shape_dim] + input = input.reshape((-1, *input.shape[-num_shape_dim:])) + + def make_padding(padding, num_spatial_dims): + # Expand single padding to pairs expected by jax + if len(padding) == 1 and len(padding) < num_spatial_dims: + padding *= num_spatial_dims if transposed: - rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) - if groups != 1: - # reshape filters for transposed depthwise convolution - assert rhs.shape[0] % groups == 0 - rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups] - rhs_shape.extend(rhs.shape[2:]) - rhs = jnp.reshape(rhs, rhs_shape) - res = jax.lax.conv_general_dilated( - input, - rhs, - (1,) * len(stride), - make_padding(padding, len(stride)), - lhs_dilation=stride, - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers( - len(stride) - ), - feature_group_count=groups, - batch_group_count=1, - ) + # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + pad_out = [] + for i in range(num_spatial_dims): + front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i] + back = front + output_padding[i] + pad_out.append((front, back)) + return pad_out else: - res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding, len(stride)), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers( - len(stride) - ), - feature_group_count=groups, - batch_group_count=1, - ) + return ((p, p) for p in padding) + + def create_default_conv_dimension_numbers(num_spatial_dims): + # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 + # (batch dimension, feature dimension, spatial dimensions...) + lhs_spec = [0, 1] + # (out feature dimension, in feature dimension, spatial dimensions...) + # swapped for transposed convolution + rhs_spec = [1, 0] if transposed else [0, 1] + # (batch dimension, feature dimension, spatial dimensions...) + out_spec = [0, 1] + for i in range(0, num_spatial_dims): + lhs_spec.append(i + 2) + rhs_spec.append(i + 2) + out_spec.append(i + 2) + return jax.lax.ConvDimensionNumbers( + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) + + if transposed: + rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) + if groups != 1: + # reshape filters for transposed depthwise convolution + assert rhs.shape[0] % groups == 0 + rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups] + rhs_shape.extend(rhs.shape[2:]) + rhs = jnp.reshape(rhs, rhs_shape) + res = jax.lax.conv_general_dilated( + input, + rhs, + (1,) * len(stride), + make_padding(padding, len(stride)), + lhs_dilation=stride, + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, + ) + else: + res = jax.lax.conv_general_dilated( + input, + weight, + stride, + make_padding(padding, len(stride)), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, + ) - if bias is not None: - # TODO(qihqi): bias always on channel?
- if len(bias.shape) == 1: - shape = [1] * len(res.shape) - shape[1] = bias.shape[0] - bias = bias.reshape(tuple(shape)) - res = res + bias + if bias is not None: + # TODO(qihqi): bias always on channel? + if len(bias.shape) == 1: + shape = [1] * len(res.shape) + shape[1] = bias.shape[0] + bias = bias.reshape(tuple(shape)) + res = res + bias - res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:])) - return res + res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:])) + return res # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) @op(torch.ops.aten._native_batch_norm_legit.default) def _aten__native_batch_norm_legit( - input, weight, bias, running_mean, running_var, training, momentum, eps + input, weight, bias, running_mean, running_var, training, momentum, eps ): - """JAX implementation of batch normalization with optional parameters. - Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. - - Args: - input (DeviceArray): Input data (N, C, H, W). - running_mean ([DeviceArray]): Running mean of input (C,). - running_var ([DeviceArray]): Running variance of input (C,). - weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. - bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. - training (bool): If True, use batch statistics for normalization. - If False, use running statistics. - momentum (float): Momentum factor for updating running statistics. - eps (float): Small constant for numerical stability. - - Returns: - DeviceArray: Normalized output - DeviceArray: Batch mean (C,) or empty if training is False - DeviceArray: Reversed batch variance (C,) or empty if training is False - """ - reduction_dims = [0] + list(range(2, input.ndim)) - reshape_dims = [1, -1] + [1] * (input.ndim - 2) - if training: - # Calculate batch mean and variance - mean = jnp.mean(input, axis=reduction_dims, keepdims=True) - saved_mean = jnp.squeeze(mean, reduction_dims) - var = jnp.var(input, axis=reduction_dims) - rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) - # Update running statistics using momentum - running_mean = (1 - momentum) * running_mean + momentum * saved_mean - running_var = (1 - momentum) * running_var + momentum * var - saved_rstd = jnp.squeeze(rstd, reduction_dims) - else: - rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) - saved_mean = jnp.array( - [], dtype=input.dtype - ) # No need to calculate batch statistics in inference mode - saved_rstd = jnp.array([], dtype=input.dtype) - - # Normalize - if training: - # use batch statistics if training - x_hat = (input - mean) * rstd - else: - # Use running statistics in inference mode - x_hat = (input - running_mean.reshape(reshape_dims)) * rstd - - # Scale and shift - if weight is not None: - x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting - if bias is not None: - x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting - - return x_hat, saved_mean, saved_rstd + """JAX implementation of batch normalization with optional parameters. + Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. + + Args: + input (DeviceArray): Input data (N, C, H, W). + running_mean ([DeviceArray]): Running mean of input (C,). + running_var ([DeviceArray]): Running variance of input (C,). 
+ weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. + bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. + training (bool): If True, use batch statistics for normalization. + If False, use running statistics. + momentum (float): Momentum factor for updating running statistics. + eps (float): Small constant for numerical stability. + + Returns: + DeviceArray: Normalized output + DeviceArray: Batch mean (C,) or empty if training is False + DeviceArray: Reversed batch variance (C,) or empty if training is False + """ + reduction_dims = [0] + list(range(2, input.ndim)) + reshape_dims = [1, -1] + [1] * (input.ndim - 2) + if training: + # Calculate batch mean and variance + mean = jnp.mean(input, axis=reduction_dims, keepdims=True) + saved_mean = jnp.squeeze(mean, reduction_dims) + var = jnp.var(input, axis=reduction_dims) + rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) + # Update running statistics using momentum + running_mean = (1 - momentum) * running_mean + momentum * saved_mean + running_var = (1 - momentum) * running_var + momentum * var + saved_rstd = jnp.squeeze(rstd, reduction_dims) + else: + rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) + saved_mean = jnp.array( + [], dtype=input.dtype + ) # No need to calculate batch statistics in inference mode + saved_rstd = jnp.array([], dtype=input.dtype) + + # Normalize + if training: + # use batch statistics if training + x_hat = (input - mean) * rstd + else: + # Use running statistics in inference mode + x_hat = (input - running_mean.reshape(reshape_dims)) * rstd + + # Scale and shift + if weight is not None: + x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting + if bias is not None: + x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting + + return x_hat, saved_mean, saved_rstd @op(torch.ops.aten._native_batch_norm_legit_no_training) def _aten__native_batch_norm_legit_no_training( - input, weight, bias, running_mean, running_var, momentum, eps + input, weight, bias, running_mean, running_var, momentum, eps ): - return _aten__native_batch_norm_legit( - input, weight, bias, running_mean, running_var, False, momentum, eps - ) + return _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, False, momentum, eps + ) @op(torch.ops.aten.relu) def _aten_relu(self): - return jax.nn.relu(self) + return jax.nn.relu(self) @op(torch.ops.aten.cat) def _aten_cat(tensors, dims=0): - # handle empty tensors as a special case. - # torch.cat will ignore the empty tensor, while jnp.concatenate - # will error if the dims > 0. - filtered_tensors = [ - t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) - ] - if filtered_tensors: - return jnp.concatenate(filtered_tensors, dims) - return tensors[0] + # handle empty tensors as a special case. + # torch.cat will ignore the empty tensor, while jnp.concatenate + # will error if the dims > 0. + filtered_tensors = [ + t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) + ] + if filtered_tensors: + return jnp.concatenate(filtered_tensors, dims) + return tensors[0] def _ceil_mode_padding( - padding: list[int], - input_shape: list[int], - kernel_size: list[int], - stride: list[int], - dilation: list[int], - ceil_mode: bool, + padding: list[int], + input_shape: list[int], + kernel_size: list[int], + stride: list[int], + dilation: list[int], + ceil_mode: bool, ): - """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. 
- - Additional high padding could be required when ceil mode is set. - """ - ceil_mode_padding = [] - for i in range(len(padding)): - left_padding = padding[i] - right_padding = left_padding - - input_size = input_shape[2 + i] - output_size_rem = ( - input_size - + 2 * left_padding - - (kernel_size[i] - 1) * dilation[i] - - 1 - ) % stride[i] - if ceil_mode and output_size_rem != 0: - extra_padding = stride[i] - output_size_rem - new_output_size = ( - input_size - + left_padding - + right_padding - + extra_padding - - (kernel_size[i] - 1) * dilation[i] - - 1 - + stride[i] - - 1 - ) // stride[i] + 1 - # Ensure that the last pooling starts inside the image. - size_to_compare = input_size + left_padding - - if (new_output_size - 1) * stride[i] < size_to_compare: - right_padding += extra_padding - - ceil_mode_padding.append((left_padding, right_padding)) - return ceil_mode_padding + """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. + + Additional high padding could be required when ceil mode is set. + """ + ceil_mode_padding = [] + for i in range(len(padding)): + left_padding = padding[i] + right_padding = left_padding + + input_size = input_shape[2 + i] + output_size_rem = ( + input_size + 2 * left_padding - (kernel_size[i] - 1) * dilation[i] - 1 + ) % stride[i] + if ceil_mode and output_size_rem != 0: + extra_padding = stride[i] - output_size_rem + new_output_size = ( + input_size + + left_padding + + right_padding + + extra_padding + - (kernel_size[i] - 1) * dilation[i] + - 1 + + stride[i] + - 1 + ) // stride[i] + 1 + # Ensure that the last pooling starts inside the image. + size_to_compare = input_size + left_padding + + if (new_output_size - 1) * stride[i] < size_to_compare: + right_padding += extra_padding + + ceil_mode_padding.append((left_padding, right_padding)) + return ceil_mode_padding @op(torch.ops.aten.max_pool2d_with_indices) @op(torch.ops.aten.max_pool3d_with_indices) def _aten_max_pool2d_with_indices( - inputs, kernel_size, strides=None, padding=0, dilation=1, ceil_mode=False + inputs, kernel_size, strides=None, padding=0, dilation=1, ceil_mode=False ): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - # Default stride is kernel_size - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - if isinstance(dilation, int): - dilation = tuple(dilation for _ in range(len(kernel_size))) - elif isinstance(dilation, list): - dilation = tuple(dilation) - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding( - padding, input_shape, kernel_size, strides, dilation, ceil_mode - ) - - assert len(kernel_size) == len(strides), ( - f"len({kernel_size=}) must equal len({strides=})" - ) - assert len(kernel_size) == len(dilation), ( - f"len({kernel_size=}) must equal len({dilation=})" - ) - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + kernel_size - dilation = (1,) * (1 + num_batch_dims) + dilation - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. 
- inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - dilation = (1,) + dilation - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(kernel_size), ( - f"padding {padding} must specify pads for same number of dims as " - f"kernel_size {kernel_size}" - ) - assert all([len(x) == 2 for x in padding]), ( - f"each entry in padding {padding} must be length 2" - ) - padding = ((0, 0), (0, 0)) + padding - - indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size) :])) - indices = indices.reshape(inputs.shape[-len(kernel_size) :]) - indices = jnp.broadcast_to(indices, inputs.shape) - - def reduce_fn(a, b): - ai, av = a - bi, bv = b - which = av >= bv # torch breaks ties in favor of later indices - return jnp.where(which, ai, bi), jnp.where(which, av, bv) - - init_val = -jnp.inf - if inputs.dtype in (jnp.int32, jnp.int64): - init_val = -(1 << 31) - init_val = jnp.array(init_val).astype(inputs.dtype) - - # Separate maxpool result and indices into two reduce_window ops. Since - # the indices tensor is usually unused in inference, separating the two - # can help DCE computations for argmax. - y = jax.lax.reduce_window( - inputs, - init_val, - jax.lax.max, - dims, - strides, - padding, - window_dilation=dilation, + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + # Default stride is kernel_size + strides = tuple(strides) if strides else kernel_size + if isinstance(padding, int): + padding = [padding for _ in range(len(kernel_size))] + if isinstance(dilation, int): + dilation = tuple(dilation for _ in range(len(kernel_size))) + elif isinstance(dilation, list): + dilation = tuple(dilation) + + input_shape = inputs.shape + if num_batch_dims == 0: + input_shape = [1, *input_shape] + padding = _ceil_mode_padding( + padding, input_shape, kernel_size, strides, dilation, ceil_mode + ) + + assert len(kernel_size) == len(strides), ( + f"len({kernel_size=}) must equal len({strides=})" + ) + assert len(kernel_size) == len(dilation), ( + f"len({kernel_size=}) must equal len({dilation=})" + ) + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + kernel_size + dilation = (1,) * (1 + num_batch_dims) + dilation + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. 
+ inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + dilation = (1,) + dilation + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(kernel_size), ( + f"padding {padding} must specify pads for same number of dims as " + f"kernel_size {kernel_size}" ) - indices, _ = jax.lax.reduce_window( - (indices, inputs), - (0, init_val), - reduce_fn, - dims, - strides, - padding, - window_dilation=dilation, + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" ) - if is_single_input: - indices = jnp.squeeze(indices, axis=0) - y = jnp.squeeze(y, axis=0) + padding = ((0, 0), (0, 0)) + padding + + indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size) :])) + indices = indices.reshape(inputs.shape[-len(kernel_size) :]) + indices = jnp.broadcast_to(indices, inputs.shape) + + def reduce_fn(a, b): + ai, av = a + bi, bv = b + which = av >= bv # torch breaks ties in favor of later indices + return jnp.where(which, ai, bi), jnp.where(which, av, bv) + + init_val = -jnp.inf + if inputs.dtype in (jnp.int32, jnp.int64): + init_val = -(1 << 31) + init_val = jnp.array(init_val).astype(inputs.dtype) + + # Separate maxpool result and indices into two reduce_window ops. Since + # the indices tensor is usually unused in inference, separating the two + # can help DCE computations for argmax. + y = jax.lax.reduce_window( + inputs, + init_val, + jax.lax.max, + dims, + strides, + padding, + window_dilation=dilation, + ) + indices, _ = jax.lax.reduce_window( + (indices, inputs), + (0, init_val), + reduce_fn, + dims, + strides, + padding, + window_dilation=dilation, + ) + if is_single_input: + indices = jnp.squeeze(indices, axis=0) + y = jnp.squeeze(y, axis=0) - return y, indices + return y, indices # Aten ops registered under the `xla` library. 
try: - @op(torch.ops.xla.max_pool2d_forward) - def _xla_max_pool2d_forward(*args, **kwargs): - return _aten_max_pool2d_with_indices(*args, **kwargs)[0] - - @op(torch.ops.xla.aot_mark_sharding) - def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): - from jax.sharding import PartitionSpec as P, NamedSharding - import ast - import torch_xla.distributed.spmd as xs - - pmesh = xs.Mesh.from_str(mesh) - assert pmesh is not None - partition_spec_eval = ast.literal_eval(partition_spec) - jmesh = pmesh.get_jax_mesh() - return jax.lax.with_sharding_constraint( - t, NamedSharding(jmesh, P(*partition_spec_eval)) - ) + @op(torch.ops.xla.max_pool2d_forward) + def _xla_max_pool2d_forward(*args, **kwargs): + return _aten_max_pool2d_with_indices(*args, **kwargs)[0] + + @op(torch.ops.xla.aot_mark_sharding) + def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): + from jax.sharding import PartitionSpec as P, NamedSharding + import ast + import torch_xla.distributed.spmd as xs + + pmesh = xs.Mesh.from_str(mesh) + assert pmesh is not None + partition_spec_eval = ast.literal_eval(partition_spec) + jmesh = pmesh.get_jax_mesh() + return jax.lax.with_sharding_constraint( + t, NamedSharding(jmesh, P(*partition_spec_eval)) + ) - @op(torch.ops.xla.einsum_linear_forward) - def _xla_einsum_linear_forward(input, weight, bias): - with jax.named_scope("einsum_linear_forward"): - product = jax.numpy.einsum("...n,mn->...m", input, weight) - if bias is not None: - return product + bias - return product + @op(torch.ops.xla.einsum_linear_forward) + def _xla_einsum_linear_forward(input, weight, bias): + with jax.named_scope("einsum_linear_forward"): + product = jax.numpy.einsum("...n,mn->...m", input, weight) + if bias is not None: + return product + bias + return product except AttributeError: - pass + pass # TODO add more ops @op(torch.ops.aten.min) def _aten_min(x, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar( - jnp.min, x, dim, keepdim - ), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64) - else: - return _with_reduction_scalar(jnp.min, x, dim, keepdim) + if dim is not None: + return _with_reduction_scalar( + jnp.min, x, dim, keepdim + ), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64) + else: + return _with_reduction_scalar(jnp.min, x, dim, keepdim) @op(torch.ops.aten.mode) def _aten_mode(input, dim=-1, keepdim=False, *, out=None): - if input.ndim == 0: # single number - return input, jnp.array(0) - dim = ( - input.ndim + dim - ) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim - # keepdims must be True for accurate broadcasting - mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) - mode_broadcast = jnp.broadcast_to(mode, input.shape) - if not keepdim: - mode = mode.squeeze(axis=dim) - indices = jnp.argmax( - jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim - ) - return mode, indices + if input.ndim == 0: # single number + return input, jnp.array(0) + dim = ( + input.ndim + dim + ) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim + # keepdims must be True for accurate broadcasting + mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) + mode_broadcast = jnp.broadcast_to(mode, input.shape) + if not keepdim: + mode = mode.squeeze(axis=dim) + indices = jnp.argmax( + jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim + ) + return mode, indices @op(torch.ops.aten.amin) def _aten_amin(x, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amin, x, dim, 
keepdim) + return _with_reduction_scalar(jnp.amin, x, dim, keepdim) @op(torch.ops.aten.argmin) def _aten_argmin(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) + return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) @op(torch.ops.aten.sin) @op_base.promote_int_input def _aten_sin(x): - return jnp.sin(x) + return jnp.sin(x) @op(torch.ops.aten.sym_size) def _aten_sym_size(x, dim): - return x.shape[dim] + return x.shape[dim] @op(torch.ops.aten.var.correction) @op(torch.ops.prims.var) def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): - return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) + return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) @op(torch.ops.prims.broadcast_in_dim) def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): - return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions - ) + return jax.lax.broadcast_in_dim( + t, shape, broadcast_dimensions=broadcast_dimensions + ) # aten.native_group_norm -- should use decomp table @@ -1476,176 +1463,174 @@ def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): @op(torch.ops.aten.native_group_norm) def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): - """Group Normalization implementation in JAX. - - Args: - input: Input tensor. Expected shape (batch_size, channels, ... spatial dims - ...) - weight: Optional scaling (gamma) parameter. Shape (channels,) - bias: Optional shifting (beta) parameter. Shape (channels,) - N: Batch size. - C: Number of channels. - HxW: Product of spatial dimensions (number of elements per channel after - flattening). - group: Number of groups for Group Normalization. - eps: Small value added for numerical stability. - - Returns: - A tuple of (normalized_output, mean, rstd) - """ - - input_shape = input.shape - - if 0 in input_shape: - return input, input, input - - # Reshape for group-wise normalization - reshaped_input = jnp.reshape(input, (1, N * group, -1)) - - # **Core Group Normalization** - def group_norm_body(x): # Function to apply within each group - mean = jnp.mean(x, axis=-1, keepdims=True) - var = jnp.var(x, axis=-1, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon - normalized = (x - mean) * rstd - return normalized, mean, rstd - - normalized, group_mean, group_rstd = jax.lax.map( - group_norm_body, reshaped_input - ) - - # Reshape back to original input shape - output = jnp.reshape(normalized, input_shape) - - # **Affine transformation** - affine_shape = [ - -1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting - if weight is not None and bias is not None: - output = bias.reshape(affine_shape) + output * weight.reshape( - affine_shape - ) - elif weight is not None: - output = output * weight.reshape(affine_shape) - elif bias is not None: - output = output + bias.reshape(affine_shape) - - # Reshape mean and rstd - mean = jnp.reshape(group_mean, (N, group)) - rstd = jnp.reshape(group_rstd, (N, group)) - - return output, mean, rstd + """Group Normalization implementation in JAX. + + Args: + input: Input tensor. Expected shape (batch_size, channels, ... spatial dims + ...) + weight: Optional scaling (gamma) parameter. Shape (channels,) + bias: Optional shifting (beta) parameter. Shape (channels,) + N: Batch size. + C: Number of channels. + HxW: Product of spatial dimensions (number of elements per channel after + flattening). + group: Number of groups for Group Normalization. 
+ eps: Small value added for numerical stability. + + Returns: + A tuple of (normalized_output, mean, rstd) + """ + + input_shape = input.shape + + if 0 in input_shape: + return input, input, input + + # Reshape for group-wise normalization + reshaped_input = jnp.reshape(input, (1, N * group, -1)) + + # **Core Group Normalization** + def group_norm_body(x): # Function to apply within each group + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon + normalized = (x - mean) * rstd + return normalized, mean, rstd + + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) + + # Reshape back to original input shape + output = jnp.reshape(normalized, input_shape) + + # **Affine transformation** + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting + if weight is not None and bias is not None: + output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) + elif weight is not None: + output = output * weight.reshape(affine_shape) + elif bias is not None: + output = output + bias.reshape(affine_shape) + + # Reshape mean and rstd + mean = jnp.reshape(group_mean, (N, group)) + rstd = jnp.reshape(group_rstd, (N, group)) + + return output, mean, rstd @op(torch.ops.aten.linalg_vector_norm) def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): - """Calculates the vector norm along specified dimensions. - - Args: - self: The input tensor. - ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. - Default is 2 (Euclidean norm). - dim: Dimensions along which to calculate the norm. If None, the norm is - calculated over all dimensions. - keepdim: Whether to keep the reduced dimensions. - dtype: Optional data type for the output. - - Returns: - The tensor containing the calculated vector norms. - """ + """Calculates the vector norm along specified dimensions. + + Args: + self: The input tensor. + ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. + Default is 2 (Euclidean norm). + dim: Dimensions along which to calculate the norm. If None, the norm is + calculated over all dimensions. + keepdim: Whether to keep the reduced dimensions. + dtype: Optional data type for the output. + + Returns: + The tensor containing the calculated vector norms. + """ + + if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance( + ord, (int, float) + ): + raise ValueError( + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) - if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance( - ord, (int, float) - ): - raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'." - ) + # Special cases (for efficiency and clarity) + if ord == 0: + if self.shape == (): + # float sets it to float64. set it back to input type + result = jnp.astype(jnp.array(float(self != 0)), self.dtype) + else: + result = _with_reduction_scalar( + jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim + ) - # Special cases (for efficiency and clarity) - if ord == 0: - if self.shape == (): - # float sets it to float64. 
set it back to input type - result = jnp.astype(jnp.array(float(self != 0)), self.dtype) - else: - result = _with_reduction_scalar( - jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim - ) - - elif ord == 2: # Euclidean norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) - ) + elif ord == 2: # Euclidean norm + result = jnp.sqrt( + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) - elif ord == float("inf"): - result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) + elif ord == float("inf"): + result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) - elif ord == float("-inf"): - result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim) + elif ord == float("-inf"): + result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim) - elif ord == "fro": # Frobenius norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) - ) + elif ord == "fro": # Frobenius norm + result = jnp.sqrt( + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) - else: # General case (e.g., ord = 1, ord = 3) - result = _with_reduction_scalar( - jnp.sum, jnp.abs(self) ** ord, dim, keepdim - ) ** (1.0 / ord) + else: # General case (e.g., ord = 1, ord = 3) + result = _with_reduction_scalar( + jnp.sum, jnp.abs(self) ** ord, dim, keepdim + ) ** (1.0 / ord) - # (Optional) dtype conversion - if dtype is not None: - result = jnp.astype(result, self.dtype) + # (Optional) dtype conversion + if dtype is not None: + result = jnp.astype(result, self.dtype) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if result.dtype == jax.numpy.int64: - result = result.astype(new_dtype) - return result + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if result.dtype == jax.numpy.int64: + result = result.astype(new_dtype) + return result # aten.reflection_pad1d @op(torch.ops.aten.reflection_pad1d) def _aten_reflection_pad1d(input, padding): - rank = len(input.shape) - pad_size = [(0, 0)] * rank - pad_size[-1] = padding - return jnp.pad(input, pad_size, mode="reflect") + rank = len(input.shape) + pad_size = [(0, 0)] * rank + pad_size[-1] = padding + return jnp.pad(input, pad_size, mode="reflect") # aten.alias @op(torch.ops.aten.alias) def _aten_alias(self, *args): - return self + return self # aten.sinh @op(torch.ops.aten.sinh) @op_base.promote_int_input def _aten_sinh(self): - return jnp.sinh(self) + return jnp.sinh(self) # aten.native_layer_norm_backward @op(torch.ops.aten.native_layer_norm_backward) def _aten_native_layer_norm_backward( - grad_out, input, normalized_shape, weight, bias, eps=1e-5 + grad_out, input, normalized_shape, weight, bias, eps=1e-5 ): - """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. + """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. - Args: - grad_out: The gradient of the output tensor. - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. + Args: + grad_out: The gradient of the output tensor. + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. 
+ bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. - Returns: - A tuple of (grad_input, grad_weight, grad_bias). - """ - return jax.lax.native_layer_norm_backward( - grad_out, input, normalized_shape, weight, bias, eps - ) + Returns: + A tuple of (grad_input, grad_weight, grad_bias). + """ + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) # aten.reflection_pad3d_backward @@ -1656,34 +1641,34 @@ def _aten_native_layer_norm_backward( @op(torch.ops.aten.atanh) @op_base.promote_int_input def _aten_atanh(self): - res = jnp.arctanh(self) - return res + res = jnp.arctanh(self) + return res # aten.bincount @op(torch.ops.aten.bincount) def _aten_bincount(input, weights=None, minlength=0): - return jnp.bincount(input, weights, minlength) + return jnp.bincount(input, weights, minlength) # aten.bitwise_not @op(torch.ops.aten.bitwise_not) def _aten_bitwise_not(self): - return ~self + return ~self # aten.bitwise_left_shift @op(torch.ops.aten.__lshift__) @op(torch.ops.aten.bitwise_left_shift) def _aten_bitwise_left_shift(input, other): - return jnp.left_shift(input, other) + return jnp.left_shift(input, other) # aten.bitwise_right_shift @op(torch.ops.aten.__rshift__) @op(torch.ops.aten.bitwise_right_shift) def _aten_bitwise_right_shift(input, other): - return jnp.right_shift(input, other) + return jnp.right_shift(input, other) # aten.embedding_dense_backward @@ -1692,127 +1677,127 @@ def _aten_bitwise_right_shift(input, other): # aten.sum @op(torch.ops.aten.sum) def _aten_sum(self, dim=None, keepdim=False, dtype=None): - if not dim: - dim = None - return _with_reduction_scalar(jnp.sum, self, dim, keepdim) + if not dim: + dim = None + return _with_reduction_scalar(jnp.sum, self, dim, keepdim) # aten.sqrt @op(torch.ops.aten.sqrt) @op_base.promote_int_input def _aten_sqrt(self): - return jnp.sqrt(self) + return jnp.sqrt(self) @op(torch.ops.aten.tan) @op_base.promote_int_input def _aten_tanh(self): - res = jnp.tan(self) - return res + res = jnp.tan(self) + return res # aten.tanh @op(torch.ops.aten.tanh) @op_base.promote_int_input def _aten_tanh(self): - res = jnp.tanh(self) - return res + res = jnp.tanh(self) + return res # aten.ceil @op(torch.ops.aten.ceil) def _aten_ceil(self): - return jnp.ceil(self).astype(self) + return jnp.ceil(self).astype(self) # aten.asin @op(torch.ops.aten.asin) @op_base.promote_int_input def _aten_asin(self): - res = jnp.arcsin(self) - return res + res = jnp.arcsin(self) + return res # aten.minimum @op(torch.ops.aten.minimum) def _aten_minimum(self, other): - return jnp.minimum(self, other) + return jnp.minimum(self, other) # aten.max_pool2d_backward def _scatter_index(dim, index): - """Returns a tuple of indexes; - - The first is to select in input (to modify), - the second is to select from the values. - """ - index_shape = list(index.shape) - input_indexes = [] - source_indexes = [] - if dim < 0: - dim += len(index_shape) - for i in range(len(index_shape)): - source_indexes.append(slice(0, index_shape[i])) - if i == dim: - input_indexes.append(index) - else: - target_shape = [1] * len(index_shape) - target_shape[i] = index_shape[i] - input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), - index_shape, - ) - ) - return tuple(input_indexes), tuple(source_indexes) + """Returns a tuple of indexes; + + The first is to select in input (to modify), + the second is to select from the values. 
+ """ + index_shape = list(index.shape) + input_indexes = [] + source_indexes = [] + if dim < 0: + dim += len(index_shape) + for i in range(len(index_shape)): + source_indexes.append(slice(0, index_shape[i])) + if i == dim: + input_indexes.append(index) + else: + target_shape = [1] * len(index_shape) + target_shape[i] = index_shape[i] + input_indexes.append( + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), + index_shape, + ) + ) + return tuple(input_indexes), tuple(source_indexes) # aten.scatter_add @op(torch.ops.aten.scatter_add) def _aten_scatter_add(input, dim, index, src): - """JAX implementation of scatter, mimicking torch.scatter behavior""" + """JAX implementation of scatter, mimicking torch.scatter behavior""" - input_indexes, source_indexes = _scatter_index(dim, index) - return input.at[input_indexes].add(src[source_indexes]) + input_indexes, source_indexes = _scatter_index(dim, index) + return input.at[input_indexes].add(src[source_indexes]) # aten.masked_scatter @op(torch.ops.aten.masked_scatter) def _aten_masked_scatter(self, mask, source): - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) + broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - elif mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) + if self.shape != broadcast_shape: + self = jnp.broadcast_to(self, broadcast_shape) + elif mask.shape != broadcast_shape: + mask = jnp.broadcast_to(mask, broadcast_shape) - self_flat = self.flatten() - mask_flat = mask.flatten() - source_flat = source.flatten() + self_flat = self.flatten() + mask_flat = mask.flatten() + source_flat = source.flatten() - true_indices = jnp.where(mask_flat)[0] - self_flat = self_flat.at[true_indices].set(source_flat[: len(true_indices)]) - final_arr = self_flat.reshape(self.shape) + true_indices = jnp.where(mask_flat)[0] + self_flat = self_flat.at[true_indices].set(source_flat[: len(true_indices)]) + final_arr = self_flat.reshape(self.shape) - return final_arr + return final_arr @op(torch.ops.aten.masked_select) def _aten_masked_select(self, mask, *args, **kwargs): - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) + broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - if mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) + if self.shape != broadcast_shape: + self = jnp.broadcast_to(self, broadcast_shape) + if mask.shape != broadcast_shape: + mask = jnp.broadcast_to(mask, broadcast_shape) - self_flat = self.flatten() - mask_flat = mask.flatten() - true_indices = jnp.where(mask_flat)[0] + self_flat = self.flatten() + mask_flat = mask.flatten() + true_indices = jnp.where(mask_flat)[0] - return self_flat[true_indices] + return self_flat[true_indices] # aten.logical_not @@ -1821,86 +1806,86 @@ def _aten_masked_select(self, mask, *args, **kwargs): # aten.sign @op(torch.ops.aten.sign) def _aten_sign(x): - return jnp.sign(x) + return jnp.sign(x) # aten.signbit @op(torch.ops.aten.signbit) def _aten_signbit(x): - return jnp.signbit(x) + return jnp.signbit(x) # aten.sigmoid @op(torch.ops.aten.sigmoid) @op_base.promote_int_input def _aten_sigmoid(x): - return jax.nn.sigmoid(x) + return jax.nn.sigmoid(x) # implement aten.asinh in jax @op(torch.ops.aten.asinh) @op_base.promote_int_input def _aten_asinh(self): - res = jnp.arcsinh(self) - return res + res = 
jnp.arcsinh(self) + return res # aten.atan @op(torch.ops.aten.atan) @op_base.promote_int_input def _aten_atan(self): - res = jnp.arctan(self) - return res + res = jnp.arctan(self) + return res @op(torch.ops.aten.scatter_reduce) @op(torch.ops.aten.scatter) def _aten_scatter_reduce( - input, dim, index, src, reduce=None, *, include_self=True + input, dim, index, src, reduce=None, *, include_self=True ): - if not isinstance(src, jnp.ndarray): - src = jnp.array(src, dtype=input.dtype) - input_indexes, source_indexes = _scatter_index(dim, index) - # "Zero out" target elements when not included - if not include_self: - if reduce in ["sum", "mean"]: - base_input = jnp.zeros_like(src) - elif reduce == "prod": - base_input = jnp.ones_like(src) - elif reduce == "amax": - base_input = jnp.full_like(src, -jnp.inf) - else: # amin - base_input = jnp.full_like(src, jnp.inf) - input = input.at[input_indexes].set(base_input[source_indexes]) - - if reduce == "sum" or reduce == "add": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "prod" or reduce == "multiply": - return input.at[input_indexes].multiply(src[source_indexes]) - elif reduce == "mean": - if include_self: - count = jnp.ones_like(input) - else: - count = jnp.zeros_like(input) - count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) - count = jnp.clip(count, min=1) - mean = input.at[input_indexes].add(src[source_indexes]) - if _is_int(input): - return mean // count - return mean / count + if not isinstance(src, jnp.ndarray): + src = jnp.array(src, dtype=input.dtype) + input_indexes, source_indexes = _scatter_index(dim, index) + # "Zero out" target elements when not included + if not include_self: + if reduce in ["sum", "mean"]: + base_input = jnp.zeros_like(src) + elif reduce == "prod": + base_input = jnp.ones_like(src) elif reduce == "amax": - return input.at[input_indexes].max(src[source_indexes]) - elif reduce == "amin": - return input.at[input_indexes].min(src[source_indexes]) + base_input = jnp.full_like(src, -jnp.inf) + else: # amin + base_input = jnp.full_like(src, jnp.inf) + input = input.at[input_indexes].set(base_input[source_indexes]) + + if reduce == "sum" or reduce == "add": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "prod" or reduce == "multiply": + return input.at[input_indexes].multiply(src[source_indexes]) + elif reduce == "mean": + if include_self: + count = jnp.ones_like(input) else: - return input.at[input_indexes].set(src[source_indexes]) + count = jnp.zeros_like(input) + count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) + count = jnp.clip(count, min=1) + mean = input.at[input_indexes].add(src[source_indexes]) + if _is_int(input): + return mean // count + return mean / count + elif reduce == "amax": + return input.at[input_indexes].max(src[source_indexes]) + elif reduce == "amin": + return input.at[input_indexes].min(src[source_indexes]) + else: + return input.at[input_indexes].set(src[source_indexes]) # aten.acos @op(torch.ops.aten.acos) @op_base.promote_int_input def _aten_acos(self): - return jnp.arccos(self) + return jnp.arccos(self) # aten.sym_storage_offset @@ -1911,349 +1896,342 @@ def _aten_acos(self): # aten.gt @op(torch.ops.aten.gt) def _aten_gt(self, other): - return self > other + return self > other # aten.sym_stride # aten.lt @op(torch.ops.aten.lt) def _aten_lt(self, other): - return self < other + return self < other def pool(inputs, init, reduce_fn, window_shape, strides, padding): - """Helper function to define 
pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply that - pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form ``(T, T) -> T``. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of ``n`` integers, representing the inter-window - strides (default: ``(1, ..., 1)``). - padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence - of ``n`` ``(low, high)`` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len(strides), ( - f"len({window_shape}) must equal len({strides})" + """Helper function to define pooling functions. + + Pooling functions are implemented using the ReduceWindow XLA op. + NOTE: Be aware that pooling is not generally differentiable. + That means providing a reduce_fn that is differentiable does not imply that + pool is differentiable. + + Args: + inputs: input data with dimensions (batch, window dims..., features). + init: the initial value for the reduction + reduce_fn: a reduce function of the form ``(T, T) -> T``. + window_shape: a shape tuple defining the window to reduce over. + strides: a sequence of ``n`` integers, representing the inter-window + strides (default: ``(1, ..., 1)``). + padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence + of ``n`` ``(low, high)`` integer pairs that give the padding to apply before + and after each spatial dimension. + Returns: + The output of the reduction for each window slice. + """ + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len(strides), ( + f"len({window_shape}) must equal len({strides})" + ) + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" ) - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. 
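# --- Editor's illustrative sketch (not part of the patch): how `pool`
# drives jax.lax.reduce_window. A 1-D max pool with window 2 and stride 2;
# the helper above additionally inserts the singleton batch dimension and
# per-dimension (low, high) padding pairs before making this call.
import jax
import jax.numpy as jnp

x = jnp.arange(6.0)  # [0., 1., 2., 3., 4., 5.]
y = jax.lax.reduce_window(
    x,
    -jnp.inf,        # init value for a max reduction
    jax.lax.max,     # reduce_fn of the form (T, T) -> T
    window_dimensions=(2,),
    window_strides=(2,),
    padding="VALID",
)
# y -> [1., 3., 5.]
# --- end of editor's sketch ---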
- inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}" - ) - assert all([len(x) == 2 for x in padding]), ( - f"each entry in padding {padding} must be length 2" - ) - padding = ((0, 0), (0, 0)) + padding - y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - if is_single_input: - y = jnp.squeeze(y, axis=0) - return y + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" + ) + padding = ((0, 0), (0, 0)) + padding + y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) + if is_single_input: + y = jnp.squeeze(y, axis=0) + return y @op(torch.ops.aten._adaptive_avg_pool2d) @op(torch.ops.aten._adaptive_avg_pool3d) def adaptive_avg_pool2or3d( - input: jnp.ndarray, output_size: Tuple[int, int] + input: jnp.ndarray, output_size: Tuple[int, int] ) -> jnp.ndarray: - """ - Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. - - See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. - - Args: - input: input tensor - output_size: the target output size (single integer or double-integer tuple) - - Context: - https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 - """ - shape = input.shape - ndim = len(shape) - out_dim = len(output_size) - num_spatial_dim = ndim - out_dim - - # Preconditions + """ + Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. + + Args: + input: input tensor + output_size: the target output size (single integer or double-integer tuple) + + Context: + https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 + """ + shape = input.shape + ndim = len(shape) + out_dim = len(output_size) + num_spatial_dim = ndim - out_dim + + # Preconditions + + assert ndim in (out_dim + 1, out_dim + 2), ( + f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim + 1}D or {num_spatial_dim + 2}D tensor, but got {ndim}" + ) + for d in input.shape[-2:]: + assert d != 0, ( + "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}." + ) - assert ndim in (out_dim + 1, out_dim + 2), ( - f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim + 1}D or {num_spatial_dim + 2}D tensor, but got {ndim}" + # Optimisation (we should also do this in the kernel implementation) + if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): + stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-out_dim:], output_size, stride) + ) + return _aten_avg_pool( + input, + kernel, + strides=stride, ) - for d in input.shape[-2:]: - assert d != 0, ( - "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " - f"non-batch dimensions, but input has shape {tuple(shape)}." 
- ) - # Optimisation (we should also do this in the kernel implementation) - if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): - stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) - kernel = tuple( - i - (o - 1) * s - for i, o, s in zip(shape[-out_dim:], output_size, stride) - ) - return _aten_avg_pool( - input, - kernel, - strides=stride, - ) + def start_index(a, b, c): + return (a * c) // b + + def end_index(a, b, c): + return ((a + 1) * c + b - 1) // b + + def compute_idx(in_size, out_size): + orange = jnp.arange(out_size, dtype=jnp.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = jnp.arange(maxlength, dtype=jnp.int64) + idx = i0[:, None] + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + idx = jnp.minimum(idx, in_size - 1) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)] + # length is not None if it's constant, otherwise we'll need to compute it + for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)): + idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o) + + def _unsqueeze_to_dim(x, dim): + ndim = len(x.shape) + return jax.lax.expand_dims(x, tuple(range(ndim, dim))) + + if out_dim == 2: + # NOTE: unsqueeze to insert extra 1 in ranks; so they + # would broadcast + vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]] + reduce_axis = (-3, -1) + else: + assert out_dim == 3 + vals = input[ + ..., + _unsqueeze_to_dim(idx[0], 6), + _unsqueeze_to_dim(idx[1], 4), + idx[2], + ] + reduce_axis = (-5, -3, -1) - def start_index(a, b, c): - return (a * c) // b - - def end_index(a, b, c): - return ((a + 1) * c + b - 1) // b - - def compute_idx(in_size, out_size): - orange = jnp.arange(out_size, dtype=jnp.int64) - i0 = start_index(orange, out_size, in_size) - # Let length = end_index - start_index, i.e. 
the length of the pooling kernels - # length.max() can be computed analytically as follows: - maxlength = in_size // out_size + 1 - in_size_mod = in_size % out_size - # adaptive = True iff there are kernels with different lengths - adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) - if adaptive: - maxlength += 1 - elif in_size_mod == 0: - maxlength -= 1 - - range_max = jnp.arange(maxlength, dtype=jnp.int64) - idx = i0[:, None] + range_max - if adaptive: - # Need to clamp to avoid accessing out-of-bounds memory - idx = jnp.minimum(idx, in_size - 1) - - # Compute the length - i1 = end_index(orange, out_size, in_size) - length = i1 - i0 - else: - length = maxlength - return idx, length, range_max, adaptive - - idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)] - # length is not None if it's constant, otherwise we'll need to compute it - for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)): - idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o) - - def _unsqueeze_to_dim(x, dim): - ndim = len(x.shape) - return jax.lax.expand_dims(x, tuple(range(ndim, dim))) - - if out_dim == 2: - # NOTE: unsqueeze to insert extra 1 in ranks; so they - # would broadcast - vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]] - reduce_axis = (-3, -1) + # Shortcut for the simpler case + if not any(adaptive): + return jnp.mean(vals, axis=reduce_axis) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, int): + return vals, length else: - assert out_dim == 3 - vals = input[ - ..., - _unsqueeze_to_dim(idx[0], 6), - _unsqueeze_to_dim(idx[1], 4), - idx[2], - ] - reduce_axis = (-5, -3, -1) - - # Shortcut for the simpler case - if not any(adaptive): - return jnp.mean(vals, axis=reduce_axis) - - def maybe_mask(vals, length, range_max, adaptive, dim): - if isinstance(length, int): - return vals, length - else: - # zero-out the things we didn't really want to select - assert dim < 0 - # hack - mask = range_max >= length[:, None] - if dim == -2: - mask = _unsqueeze_to_dim(mask, 4) - elif dim == -3: - mask = _unsqueeze_to_dim(mask, 6) - vals = jnp.where(mask, 0.0, vals) - # Compute the length of each window - length = _unsqueeze_to_dim(length, -dim) - return vals, length - - for i in range(len(length)): - vals, length[i] = maybe_mask( - vals, - length[i], - range_max[i], - adaptive=adaptive[i], - dim=(i - out_dim), - ) + # zero-out the things we didn't really want to select + assert dim < 0 + # hack + mask = range_max >= length[:, None] + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + elif dim == -3: + mask = _unsqueeze_to_dim(mask, 6) + vals = jnp.where(mask, 0.0, vals) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + for i in range(len(length)): + vals, length[i] = maybe_mask( + vals, + length[i], + range_max[i], + adaptive=adaptive[i], + dim=(i - out_dim), + ) - # We unroll the sum as we assume that the kernels are going to be small - ret = jnp.sum(vals, axis=reduce_axis) - # NOTE: math.prod because we want to expand it to length[0] * length[1] * ... - # this is multiplication with broadcasting, not regular pointwise product - return ret / math.prod(length) + # We unroll the sum as we assume that the kernels are going to be small + ret = jnp.sum(vals, axis=reduce_axis) + # NOTE: math.prod because we want to expand it to length[0] * length[1] * ... 
+ # this is multiplication with broadcasting, not regular pointwise product + return ret / math.prod(length) @op(torch.ops.aten.avg_pool1d) @op(torch.ops.aten.avg_pool2d) @op(torch.ops.aten.avg_pool3d) def _aten_avg_pool( - inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, ): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, list) and len(padding) == 1: - padding = padding[0] - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding( - padding, - input_shape, - kernel_size, - strides, - [1] * len(kernel_size), - ceil_mode, + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) if strides else kernel_size + if isinstance(padding, list) and len(padding) == 1: + padding = padding[0] + if isinstance(padding, int): + padding = [padding for _ in range(len(kernel_size))] + + input_shape = inputs.shape + if num_batch_dims == 0: + input_shape = [1, *input_shape] + padding = _ceil_mode_padding( + padding, + input_shape, + kernel_size, + strides, + [1] * len(kernel_size), + ceil_mode, + ) + + y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) + if divisor_override is not None: + y = y / jnp.array(divisor_override, y.dtype) + elif count_include_pad: + div_shape = list(y.shape) + div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size) + unequal_paddings = map(lambda pad: pad[0] != pad[1], padding) + unequal_padding_indices = np.where(list(unequal_paddings))[0] + if len(unequal_padding_indices) > 0: + # indices to update kernel size + offset = len(div_shape) - len(padding) + skip_indices = list(map(lambda x: x + offset, unequal_padding_indices)) + indices = _generate_indices(div_shape, skip_dim_indices=skip_indices) + # updated kernel size accounting for maximum padding + new_kernel_size = list(kernel_size) + for j in unequal_padding_indices: + new_kernel_size[j] = kernel_size[j] - padding[j][1] + padding[j][0] + + for idx in indices: + for j in unequal_padding_indices: + idx[j + offset] = -1 + div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size)) + + y = y / div_by + else: + div_shape = list(inputs.shape) + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_size): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape, y.dtype), + jnp.array(0.0, y.dtype), + jax.lax.add, + kernel_size, + strides, + padding, ) - - y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) - if divisor_override is not None: - y = y / jnp.array(divisor_override, y.dtype) - elif count_include_pad: - div_shape = list(y.shape) - div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size) - unequal_paddings = map(lambda pad: pad[0] != pad[1], padding) - unequal_padding_indices = np.where(list(unequal_paddings))[0] - if len(unequal_padding_indices) > 0: - # indices to update kernel size - offset = len(div_shape) - len(padding) - skip_indices = list( - map(lambda x: x + offset, unequal_padding_indices) - ) - indices = _generate_indices( - div_shape, skip_dim_indices=skip_indices - ) - # updated kernel 
size accounting for maximum padding - new_kernel_size = list(kernel_size) - for j in unequal_padding_indices: - new_kernel_size[j] = ( - kernel_size[j] - padding[j][1] + padding[j][0] - ) - - for idx in indices: - for j in unequal_padding_indices: - idx[j + offset] = -1 - div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size)) - - y = y / div_by - else: - div_shape = list(inputs.shape) - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_size): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape, y.dtype), - jnp.array(0.0, y.dtype), - jax.lax.add, - kernel_size, - strides, - padding, - ) - return y.astype(inputs.dtype) + return y.astype(inputs.dtype) # helper function to generate all indices to iterate through ndarray def _generate_indices(dims, skip_dim_indices=[]): - res = [] - - def _helper(curr_dim_idx, sofar): - if curr_dim_idx in skip_dim_indices: - _helper(curr_dim_idx + 1, sofar[:]) - return - if curr_dim_idx >= len(dims): - res.append(sofar) - return - for i in range(dims[curr_dim_idx]): - sofar[curr_dim_idx] = i - _helper(curr_dim_idx + 1, sofar[:]) - - _helper(0, [0 for _ in dims]) - return res + res = [] + + def _helper(curr_dim_idx, sofar): + if curr_dim_idx in skip_dim_indices: + _helper(curr_dim_idx + 1, sofar[:]) + return + if curr_dim_idx >= len(dims): + res.append(sofar) + return + for i in range(dims[curr_dim_idx]): + sofar[curr_dim_idx] = i + _helper(curr_dim_idx + 1, sofar[:]) + + _helper(0, [0 for _ in dims]) + return res # aten.sym_numel # aten.reciprocal @op(torch.ops.aten.reciprocal) def _aten_reciprocal(a): - if _is_int(a): - return (1 / a).astype(jnp.dtype("float32")) - return 1 / a + if _is_int(a): + return (1 / a).astype(jnp.dtype("float32")) + return 1 / a # aten.select_scatter @op(torch.ops.aten.select_scatter) def _aten_select_scatter(input, src, dim, index): - input_indexes = [] - if dim < 0: - dim += len(input.shape) - for x in range(len(input.shape)): - if x == dim: - input_indexes.append(index) - else: - input_indexes.append(slice(None, None, None)) - return input.at[tuple(input_indexes)].set(src) + input_indexes = [] + if dim < 0: + dim += len(input.shape) + for x in range(len(input.shape)): + if x == dim: + input_indexes.append(index) + else: + input_indexes.append(slice(None, None, None)) + return input.at[tuple(input_indexes)].set(src) @op(torch.ops.aten.scatter.src) def _aten_scatter_src(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src[source_indexes]) + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src[source_indexes]) @op(torch.ops.aten.scatter.value) def _aten_scatter(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src) + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src) # aten.acosh @op(torch.ops.aten.acosh) @op_base.promote_int_input def _aten_acosh(self): - return jnp.arccosh(self) + return jnp.arccosh(self) # aten.avg_pool2d_backward @@ -2262,59 +2240,57 @@ def _aten_acosh(self): # aten.round @op(torch.ops.aten.round) def _aten_round(input, decimals=0): - return jnp.round(input, decimals) + return jnp.round(input, decimals) # aten.max @op(torch.ops.aten.max) def _aten_max(self, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar( - jnp.max, self, dim, keepdim - ), 
_with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype( - jnp.int64 - ) - else: - return _with_reduction_scalar(jnp.max, self, dim, keepdim) + if dim is not None: + return _with_reduction_scalar( + jnp.max, self, dim, keepdim + ), _with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype(jnp.int64) + else: + return _with_reduction_scalar(jnp.max, self, dim, keepdim) # aten.maximum @op(torch.ops.aten.maximum) def _aten_maximum(self, other): - return jnp.maximum(self, other) + return jnp.maximum(self, other) # aten.abs @op(torch.ops.aten.abs) def _aten_abs(self): - return jnp.abs(self) + return jnp.abs(self) # generate aten.amax only @op(torch.ops.aten.amax) def _aten_amax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amax, self, dim, keepdim) + return _with_reduction_scalar(jnp.amax, self, dim, keepdim) def _with_reduction_scalar(jax_func, self, dim, keepdim): - expanded = False - if self.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - self = jnp.expand_dims(self, 0) - res = jax_func(self, axis=dim, keepdims=keepdim) - if expanded: - res = res.squeeze() - return res + expanded = False + if self.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + self = jnp.expand_dims(self, 0) + res = jax_func(self, axis=dim, keepdims=keepdim) + if expanded: + res = res.squeeze() + return res # aten.any @op(torch.ops.aten.any) def _aten_any(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.any, self, dim, keepdim) + return _with_reduction_scalar(jnp.any, self, dim, keepdim) # aten.arange @@ -2323,769 +2299,758 @@ def _aten_any(self, dim=None, keepdim=False): @op(torch.ops.aten.arange.default) @op_base.convert_dtype(use_default_dtype=False) def _aten_arange( - start, - end=None, - step=None, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False, + start, + end=None, + step=None, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, ): - return jnp.arange( - op_base.maybe_convert_constant_dtype(start, dtype), - op_base.maybe_convert_constant_dtype(end, dtype), - op_base.maybe_convert_constant_dtype(step, dtype), - dtype=dtype, - ) + return jnp.arange( + op_base.maybe_convert_constant_dtype(start, dtype), + op_base.maybe_convert_constant_dtype(end, dtype), + op_base.maybe_convert_constant_dtype(step, dtype), + dtype=dtype, + ) # aten.argmax @op(torch.ops.aten.argmax) def _aten_argmax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) + return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) def _strided_index(sizes, strides, storage_offset=None): - ind = jnp.zeros(sizes, dtype=jnp.int32) + ind = jnp.zeros(sizes, dtype=jnp.int32) - for i, (size, stride) in enumerate(zip(sizes, strides)): - result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) - indexes = (jnp.arange(size) * stride).reshape(result_shape) - ind += indexes + for i, (size, stride) in enumerate(zip(sizes, strides)): + result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) + indexes = (jnp.arange(size) * stride).reshape(result_shape) + ind += indexes - if storage_offset is not None: - ind += storage_offset - return ind + if storage_offset is not None: + ind += 
storage_offset + return ind # aten.as_strided @op(torch.ops.aten.as_strided) @op(torch.ops.aten.as_strided_copy) def _aten_as_strided(x, sizes, strides, storage_offset=None): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - return flattened[ind] + ind = _strided_index(sizes, strides, storage_offset) + flattened = jnp.ravel(x) + return flattened[ind] @op(torch.ops.aten.as_strided_scatter) def _aten_as_strided_scatter(x, src, sizes, strides, storage_offset): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - modified = flattened.at[ind].set(src) - return modified.reshape(x.shape) + ind = _strided_index(sizes, strides, storage_offset) + flattened = jnp.ravel(x) + modified = flattened.at[ind].set(src) + return modified.reshape(x.shape) # aten.atan2 @op(torch.ops.aten.atan2) @op_base.promote_int_input def _aten_atan2(input, other): - return jnp.arctan2(input, other) + return jnp.arctan2(input, other) # aten.bitwise_and @op(torch.ops.aten.bitwise_and) @op(torch.ops.aten.__and__) def _aten_bitwise_and(self, other): - return self & other + return self & other # aten.bitwise_or @op(torch.ops.aten.bitwise_or) def _aten_bitwise_or(self, other): - return self | other + return self | other # aten.bitwise_xor @op(torch.ops.aten.bitwise_xor) def _aten_bitwise_xor(self, other): - return self ^ other + return self ^ other # aten.broadcast_tensors @op(torch.ops.aten.broadcast_tensors) def _aten_broadcast_tensors(*tensors): - def _get_broadcast_shape(shapes): - """ - Determines the output shape by broadcasting all input shapes. - - Args: - shapes: A list of tuples representing the shapes of the input tensors. - - Returns: - A tuple representing the broadcasted output shape. - """ - - # Find the maximum number of dimensions among all input tensors - max_dims = max(len(shape) for shape in shapes) - # Pad shorter shapes with 1s on the left to match the maximum number of dimensions - padded_shapes = [ - (1,) * (max_dims - len(shape)) + shape for shape in shapes - ] - - # Initialize the output shape with 1s - output_shape = [1] * max_dims - # Iterate through each dimension and apply broadcasting rules - for dim in range(max_dims): - dim_sizes = [shape[dim] for shape in padded_shapes] - max_size = max(dim_sizes) - if all(size == 1 or size == max_size for size in dim_sizes): - output_shape[dim] = max_size - else: - raise ValueError("Incompatible shapes for broadcasting") - return tuple(output_shape) - - def _broadcast_dimensions(input_shape, output_shape): - """ - Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim. - - Args: - input_shape: The shape of the input tensor. - output_shape: The desired output shape after broadcasting. - - Returns: - A tuple specifying which dimensions of the input tensor should be broadcasted. - """ - - res = tuple( - i - for i, (in_dim, out_dim) in enumerate( - zip(input_shape, output_shape) - ) - ) - return res + def _get_broadcast_shape(shapes): + """ + Determines the output shape by broadcasting all input shapes. 
- # clean some function's previous wrap - if ( - len(tensors) == 1 - and len(tensors[0]) >= 1 - and isinstance(tensors[0][0], jax.Array) - ): - tensors = tensors[0] - - # Get the shapes of all input tensors - shapes = [t.shape for t in tensors] - # Find the output shape by broadcasting all input shapes - output_shape = _get_broadcast_shape(shapes) - # Broadcast each tensor to the output shape - broadcasted_tensors = [ - jax.lax.broadcast_in_dim( - t, output_shape, _broadcast_dimensions(t.shape, output_shape) - ) - for t in tensors - ] + Args: + shapes: A list of tuples representing the shapes of the input tensors. - return broadcasted_tensors + Returns: + A tuple representing the broadcasted output shape. + """ + # Find the maximum number of dimensions among all input tensors + max_dims = max(len(shape) for shape in shapes) + # Pad shorter shapes with 1s on the left to match the maximum number of dimensions + padded_shapes = [(1,) * (max_dims - len(shape)) + shape for shape in shapes] + + # Initialize the output shape with 1s + output_shape = [1] * max_dims + # Iterate through each dimension and apply broadcasting rules + for dim in range(max_dims): + dim_sizes = [shape[dim] for shape in padded_shapes] + max_size = max(dim_sizes) + if all(size == 1 or size == max_size for size in dim_sizes): + output_shape[dim] = max_size + else: + raise ValueError("Incompatible shapes for broadcasting") + return tuple(output_shape) + + def _broadcast_dimensions(input_shape, output_shape): + """ + Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim. -# aten.broadcast_to -@op(torch.ops.aten.broadcast_to) -def _aten_broadcast_to(input, shape): - return jnp.broadcast_to(input, shape) + Args: + input_shape: The shape of the input tensor. + output_shape: The desired output shape after broadcasting. + Returns: + A tuple specifying which dimensions of the input tensor should be broadcasted. + """ -# aten.clamp -@op(torch.ops.aten.clamp.default) -@op(torch.ops.aten.clamp.Tensor) + res = tuple( + i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)) + ) + return res + + # clean some function's previous wrap + if ( + len(tensors) == 1 + and len(tensors[0]) >= 1 + and isinstance(tensors[0][0], jax.Array) + ): + tensors = tensors[0] + + # Get the shapes of all input tensors + shapes = [t.shape for t in tensors] + # Find the output shape by broadcasting all input shapes + output_shape = _get_broadcast_shape(shapes) + # Broadcast each tensor to the output shape + broadcasted_tensors = [ + jax.lax.broadcast_in_dim( + t, output_shape, _broadcast_dimensions(t.shape, output_shape) + ) + for t in tensors + ] + + return broadcasted_tensors + + +# aten.broadcast_to +@op(torch.ops.aten.broadcast_to) +def _aten_broadcast_to(input, shape): + return jnp.broadcast_to(input, shape) + + +# aten.clamp +@op(torch.ops.aten.clamp.default) +@op(torch.ops.aten.clamp.Tensor) def _aten_clamp(self, min=None, max=None): - return jnp.clip(self, min, max) + return jnp.clip(self, min, max) @op(torch.ops.aten.clamp_min) def _aten_clamp_min(input, min): - return jnp.clip(input, min=min) + return jnp.clip(input, min=min) # aten.constant_pad_nd @op(torch.ops.aten.constant_pad_nd) def _aten_constant_pad_nd(input, padding, value=0): - # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) - # means last dim get padded 1 in front and 1 in back; - # and second last dim get padded 2 in front and 2 in back. 
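# --- Editor's illustrative sketch (not part of the patch): the padding
# conversion described in this NOTE, for a rank-3 input and torch padding
# (1, 1, 2, 2). It mirrors the rev_padding/pad_dim lines of
# _aten_constant_pad_nd below.
padding = (1, 1, 2, 2)
m = len(padding)
rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
# rev_padding -> [(2, 2, 0), (1, 1, 0)]
pad_config = [(0, 0, 0)] * (3 - m // 2) + rev_padding
# pad_config -> [(0, 0, 0), (2, 2, 0), (1, 1, 0)], ready for jax.lax.pad
# --- end of editor's sketch ---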
- # Jax padding tuple of 3-tuple: the same padding is - # [(0, 0, 0), ..., (2,2,0), (1,1,0)], where the last dimension - # is the amount of padding added between any two elements in each dimension - m = len(padding) - rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)] - pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding) - value_casted = jax.numpy.array(value, dtype=input.dtype) - return jax.lax.pad( - input, padding_value=value_casted, padding_config=pad_dim - ) + # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) + # means last dim get padded 1 in front and 1 in back; + # and second last dim get padded 2 in front and 2 in back. + # Jax padding tuple of 3-tuple: the same padding is + # [(0, 0, 0), ..., (2,2,0), (1,1,0)], where the last dimension + # is the amount of padding added between any two elements in each dimension + m = len(padding) + rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)] + pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding) + value_casted = jax.numpy.array(value, dtype=input.dtype) + return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim) # aten.convolution_backward @op(torch.ops.aten.lift_fresh_copy) def _aten_lift_fresh_copy(x): - return jnp.copy(x) + return jnp.copy(x) @op(torch.ops.aten.copy) def _aten_copy(self, src): - return jnp.broadcast_to(src, self.shape).astype(self.dtype) + return jnp.broadcast_to(src, self.shape).astype(self.dtype) @op(torch.ops.aten._cdist_forward) def _aten_cdist_forward(x1, x2, p, compute_mode=""): - # x1 is B x P x M - # x2 is B x Q x M - # res is B x P x Q - x1 = jnp.expand_dims(x1, len(x1.shape) - 1) - x2 = jnp.expand_dims(x2, len(x2.shape) - 2) - return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) + # x1 is B x P x M + # x2 is B x Q x M + # res is B x P x Q + x1 = jnp.expand_dims(x1, len(x1.shape) - 1) + x2 = jnp.expand_dims(x2, len(x2.shape) - 2) + return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) @op(torch.ops.aten._pdist_forward) def _aten__pdist_forward(x, p=2): - pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[ - jnp.triu_indices(pairwise_dists.shape[0], k=1) - ] - return condensed_dists + pairwise_dists = _aten_cdist_forward(x, x, p) + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] + return condensed_dists @op(torch.ops.aten.cholesky_inverse) def _aten_cholesky_inverse(input, upper=False): - t = jnp.matrix_transpose(input) - if "complex" in str(input.dtype): - t = t.conjugate() - return jnp.linalg.inv(input @ t) + t = jnp.matrix_transpose(input) + if "complex" in str(input.dtype): + t = t.conjugate() + return jnp.linalg.inv(input @ t) # aten.cos @op(torch.ops.aten.cos) @op_base.promote_int_input def _aten_cos(input): - return jnp.cos(input) + return jnp.cos(input) # aten.cosh @op(torch.ops.aten.cosh) @op_base.promote_int_input def _aten_cosh(input): - return jnp.cosh(input) + return jnp.cosh(input) @op(torch.ops.aten.diag) def _aten_diag(input, diagonal=0): - return jnp.diag(input, diagonal) + return jnp.diag(input, diagonal) # aten.diagonal @op(torch.ops.aten.diagonal) def _aten_diagonal(input, offset=0, dim1=0, dim2=1): - return jnp.diagonal(input, offset, dim1, dim2) + return jnp.diagonal(input, offset, dim1, dim2) def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1): - input_len = len(input_shape) - if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): - raise ValueError( - "dim1 and dim2 must be 
different and in range [0, " - + str(input_len - 1) - + "]" - ) + input_len = len(input_shape) + if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): + raise ValueError( + "dim1 and dim2 must be different and in range [0, " + + str(input_len - 1) + + "]" + ) - size1, size2 = input_shape[dim1], input_shape[dim2] - if offset >= 0: - indices1 = jnp.arange(min(size1, size2 - offset)) - indices2 = jnp.arange(offset, offset + len(indices1)) - else: - indices2 = jnp.arange(min(size1 + offset, size2)) - indices1 = jnp.arange(-offset, -offset + len(indices2)) - return [indices1, indices2] + size1, size2 = input_shape[dim1], input_shape[dim2] + if offset >= 0: + indices1 = jnp.arange(min(size1, size2 - offset)) + indices2 = jnp.arange(offset, offset + len(indices1)) + else: + indices2 = jnp.arange(min(size1 + offset, size2)) + indices1 = jnp.arange(-offset, -offset + len(indices2)) + return [indices1, indices2] @op(torch.ops.aten.diagonal_scatter) def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1): - indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2) - - if input.ndim == 2: - return input.at[tuple(indexes)].set(src) - else: - # src has the same shape as the output of - # jnp.diagonal(input, offset, dim1, dim2). - # Last dimension always contains the diagonal elements, - # while the preceding dimensions represent the "slices" - # from which these diagonals are extracted. Thus, - # we alter input axes to match this assumption, write src - # and then move the axes back to the original state. - input = jnp.moveaxis(input, (dim1, dim2), (-2, -1)) - multi_indexes = [slice(None)] * (input.ndim - 2) + indexes - input = input.at[tuple(multi_indexes)].set(src) - return jnp.moveaxis(input, (-2, -1), (dim1, dim2)) + indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2) + + if input.ndim == 2: + return input.at[tuple(indexes)].set(src) + else: + # src has the same shape as the output of + # jnp.diagonal(input, offset, dim1, dim2). + # Last dimension always contains the diagonal elements, + # while the preceding dimensions represent the "slices" + # from which these diagonals are extracted. Thus, + # we alter input axes to match this assumption, write src + # and then move the axes back to the original state. 
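# --- Editor's illustrative sketch (not part of the patch): what the
# diag_indices_with_offset helper above returns for a 4x5 matrix and
# offset=1, and how diagonal_scatter uses it in the 2-D case.
import jax.numpy as jnp

x = jnp.zeros((4, 5))
rows, cols = diag_indices_with_offset(x.shape, 1)  # helper defined above
# rows -> [0, 1, 2, 3]; cols -> [1, 2, 3, 4]  (indices of jnp.diagonal(x, 1))
updated = x.at[rows, cols].set(1.0)  # write values along that diagonal
# --- end of editor's sketch ---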
+ input = jnp.moveaxis(input, (dim1, dim2), (-2, -1)) + multi_indexes = [slice(None)] * (input.ndim - 2) + indexes + input = input.at[tuple(multi_indexes)].set(src) + return jnp.moveaxis(input, (-2, -1), (dim1, dim2)) # aten.diagflat @op(torch.ops.aten.diagflat) def _aten_diagflat(input, offset=0): - return jnp.diagflat(jnp.array(input), offset) + return jnp.diagflat(jnp.array(input), offset) @op(torch.ops.aten.movedim) def _aten_movedim(input, source, destination): - return jnp.moveaxis(input, source, destination) + return jnp.moveaxis(input, source, destination) # aten.eq @op(torch.ops.aten.eq) def _aten_eq(input1, input2): - return input1 == input2 + return input1 == input2 # aten.equal @op(torch.ops.aten.equal) def _aten_equal(input, other): - res = jnp.array_equal(input, other) - return bool(res) + res = jnp.array_equal(input, other) + return bool(res) # aten.erf @op(torch.ops.aten.erf) @op_base.promote_int_input def _aten_erf(x): - return jax.lax.erf(x) + return jax.lax.erf(x) @op(torch.ops.aten.erfinv) @op_base.promote_int_input def _aten_erfinv(input): - return jax.lax.erf_inv(input) + return jax.lax.erf_inv(input) # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input): - res = jnp.exp(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.exp(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.expm1 @op(torch.ops.aten.expm1) def _aten_expm1(input): - res = jnp.expm1(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.expm1(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.exp2 @op(torch.ops.aten.exp2) def _aten_exp2(input): - res = jnp.exp2(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res + res = jnp.exp2(input) + new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) + if input.dtype == jax.numpy.int64: + res = res.astype(new_dtype) + return res # aten.fill @op(torch.ops.aten.fill) @op(torch.ops.aten.full_like) def _aten_fill( - x, value, dtype=None, pin_memory=None, memory_format=None, device=None + x, value, dtype=None, pin_memory=None, memory_format=None, device=None ): - if dtype is None: - dtype = x.dtype - else: - dtype = mappings.t2j_dtype(dtype) - return jnp.full(x.shape, value, dtype) + if dtype is None: + dtype = x.dtype + else: + dtype = mappings.t2j_dtype(dtype) + return jnp.full(x.shape, value, dtype) # aten.flip @op(torch.ops.aten.flip) def _aten_flip(input, dims): - if dims is not None: - return jnp.flip(input, tuple(dims)) - else: - return jnp.flip(input) + if dims is not None: + return jnp.flip(input, tuple(dims)) + else: + return jnp.flip(input) # aten.floor @op(torch.ops.aten.floor) def _aten_floor(input): - return jnp.floor(input).astype(input.dtype) + return jnp.floor(input).astype(input.dtype) # aten.fmax @op(torch.ops.aten.fmax) def _aten_fmax(input, other): - return jnp.fmax(input, other) + return jnp.fmax(input, other) # aten.fmin @op(torch.ops.aten.fmin) def _aten_fmin(input, other): - return jnp.fmin(input, other) + return jnp.fmin(input, other) # aten.fmod @op(torch.ops.aten.fmod) def _aten_fmod(input, other): - return input - other * 
_aten_div(input, other, "trunc") + return input - other * _aten_div(input, other, "trunc") # aten.frexp @op(torch.ops.aten.frexp) def _aten_frexp(input): - return jnp.frexp(input) + return jnp.frexp(input) # aten.gather @op(torch.ops.aten.gather) def _aten_gather(input, dim, index): - if input.ndim == 0: - return jnp.broadcast_to(input, index.shape) - # short circuit for empty outputs - if not all(index.shape): - return jnp.zeros(index.shape, dtype=input.dtype) - if dim < 0: - dim += input.ndim - input_indexes, source_indexes = _scatter_index(dim, index) - return input[input_indexes] + if input.ndim == 0: + return jnp.broadcast_to(input, index.shape) + # short circuit for empty outputs + if not all(index.shape): + return jnp.zeros(index.shape, dtype=input.dtype) + if dim < 0: + dim += input.ndim + input_indexes, source_indexes = _scatter_index(dim, index) + return input[input_indexes] # aten.ge @op(torch.ops.aten.ge) def _aten_ge(self, other): - return self >= other + return self >= other @op(torch.ops.aten.glu) def _aten_glu(x, dim=-1): - return jax.nn.glu(x, dim) + return jax.nn.glu(x, dim) # aten.hardtanh @op(torch.ops.aten.hardtanh) def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False): - if ( - input.dtype == np.int64 - and isinstance(max_val, float) - and isinstance(min_val, float) - ): - min_val = int(min_val) - max_val = int(max_val) - return jnp.clip(input, min_val, max_val) + if ( + input.dtype == np.int64 + and isinstance(max_val, float) + and isinstance(min_val, float) + ): + min_val = int(min_val) + max_val = int(max_val) + return jnp.clip(input, min_val, max_val) # aten.histc @op(torch.ops.aten.histc) def _aten_histc(input, bins=100, min=0, max=0): - # TODO(@manfei): this function might cause some uncertainty - if min == 0 and max == 0: - if isinstance(input, jnp.ndarray) and input.size == 0: - min = 0 - max = 0 - else: - min = jnp.min(input) - max = jnp.max(input) - range_value = (min, max) - hist, bin_edges = jnp.histogram( - input, bins=bins, range=range_value, weights=None, density=None - ) - return hist + # TODO(@manfei): this function might cause some uncertainty + if min == 0 and max == 0: + if isinstance(input, jnp.ndarray) and input.size == 0: + min = 0 + max = 0 + else: + min = jnp.min(input) + max = jnp.max(input) + range_value = (min, max) + hist, bin_edges = jnp.histogram( + input, bins=bins, range=range_value, weights=None, density=None + ) + return hist @op(torch.ops.aten.hypot) def _aten_hypot(input, other): - return jnp.hypot(input, other) + return jnp.hypot(input, other) @op(torch.ops.aten.digamma) def _aten_digamma(input, *, out=None): - res = jax.scipy.special.digamma(input).astype(jnp.float32) - # replace indices where input == 0 with -inf in res - return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) + res = jax.scipy.special.digamma(input).astype(jnp.float32) + # replace indices where input == 0 with -inf in res + return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) @op(torch.ops.aten.igamma) def _aten_igamma(input, other): - return jax.scipy.special.gammainc(input, other) + return jax.scipy.special.gammainc(input, other) @op(torch.ops.aten.lgamma) def _aten_lgamma(input, *, out=None): - return jax.scipy.special.gammaln(input).astype(jnp.float32) + return jax.scipy.special.gammaln(input).astype(jnp.float32) @op(torch.ops.aten.mvlgamma) def _aten_mvlgamma(input, p, *, out=None): - input = input.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.scipy.special.multigammaln(input, p) + input 
= input.astype(mappings.t2j_dtype(torch.get_default_dtype())) + return jax.scipy.special.multigammaln(input, p) @op(torch.ops.aten.linalg_eig) def _aten_linalg_eig(A): - return jnp.linalg.eig(A) + return jnp.linalg.eig(A) @op(torch.ops.aten._linalg_eigh) def _aten_linalg_eigh(A, UPLO="L"): - return jnp.linalg.eigh(A, UPLO) + return jnp.linalg.eigh(A, UPLO) @op(torch.ops.aten.linalg_lstsq) def _aten_linalg_lstsq(A, B, rcond=None, driver="gelsy"): - input_dtype = A.dtype - - m = A.shape[-2] - n = A.shape[-1] + input_dtype = A.dtype - is_batched = A.ndim > 2 + m = A.shape[-2] + n = A.shape[-1] - if is_batched: - batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) - batch_size = int(np.prod(batch_shape)) - A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) - B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) + is_batched = A.ndim > 2 - X, residuals, rank, singular_values = jax.vmap( - jnp.linalg.lstsq, in_axes=(0, 0) - )(A_reshaped, B_reshaped, rcond=rcond) + if is_batched: + batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) + batch_size = int(np.prod(batch_shape)) + A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) + B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) - X = X.reshape(batch_shape + X.shape[-2:]) + X, residuals, rank, singular_values = jax.vmap( + jnp.linalg.lstsq, in_axes=(0, 0) + )(A_reshaped, B_reshaped, rcond=rcond) - if driver in ["gelsd", "gelsy", "gelss"]: - rank = rank.reshape(batch_shape) - else: - rank = jnp.array([], dtype=jnp.int64) + X = X.reshape(batch_shape + X.shape[-2:]) - full_rank = jnp.all(rank == n) - if driver == "gelsy" or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) - else: - residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) + if driver in ["gelsd", "gelsy", "gelss"]: + rank = rank.reshape(batch_shape) + else: + rank = jnp.array([], dtype=jnp.int64) - if driver in ["gelsd", "gelss"]: - singular_values = singular_values.reshape( - batch_shape + singular_values.shape[-1:] - ) - else: - singular_values = jnp.array([], dtype=input_dtype) + full_rank = jnp.all(rank == n) + if driver == "gelsy" or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) + else: + residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) + if driver in ["gelsd", "gelss"]: + singular_values = singular_values.reshape( + batch_shape + singular_values.shape[-1:] + ) else: - X, residuals, rank, singular_values = jnp.linalg.lstsq( - A, B, rcond=rcond - ) + singular_values = jnp.array([], dtype=input_dtype) - if driver not in ["gelsd", "gelsy", "gelss"]: - rank = jnp.array([], dtype=jnp.int64) + else: + X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond) - rank_value = None - if rank.size > 0: - rank_value = int(rank.item()) - rank = jnp.array(rank_value, dtype=jnp.int64) + if driver not in ["gelsd", "gelsy", "gelss"]: + rank = jnp.array([], dtype=jnp.int64) - # When driver is ‘gels’, assume that A is full-rank. - full_rank = driver == "gels" or rank_value == n - if driver == "gelsy" or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) + rank_value = None + if rank.size > 0: + rank_value = int(rank.item()) + rank = jnp.array(rank_value, dtype=jnp.int64) - if driver not in ["gelsd", "gelss"]: - singular_values = jnp.array([], dtype=input_dtype) + # When driver is ‘gels’, assume that A is full-rank. 
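# --- Editor's illustrative sketch (not part of the patch): the underlying
# JAX call used on this non-batched path. jnp.linalg.lstsq returns
# (solution, residuals, rank, singular_values); the driver/rank checks
# around this comment decide which of those torch actually reports.
import jax.numpy as jnp

A = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # m = 3 > n = 2
B = jnp.array([[1.0], [2.0], [3.0]])
X, residuals, rank, sv = jnp.linalg.lstsq(A, B, rcond=None)
# X solves min ||A @ X - B|| in the least-squares sense; rank == 2 here.
# --- end of editor's sketch ---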
+ full_rank = driver == "gels" or rank_value == n + if driver == "gelsy" or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) - return X, residuals, rank, singular_values + if driver not in ["gelsd", "gelss"]: + singular_values = jnp.array([], dtype=input_dtype) + + return X, residuals, rank, singular_values @op(torch.ops.aten.linalg_ldl_factor_ex) def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): - # TODO: Replace with native LDL when available: - # https://github.com/jax-ml/jax/issues/12779 - # TODO: Not tested for complex inputs. Does not support hermitian=True - pivots = jnp.broadcast_to( - jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1] - ) - info = jnp.zeros(A.shape[:-2], jnp.int32) - C = jnp.linalg.cholesky(A) - if C.size == 0: - return C, pivots, info - - # Fill diagonals of stacked matrices - @functools.partial(jnp.vectorize, signature="(k,k),(k,k)->(k,k)") - def fill_diagonal_batch(x, y): - return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) - - D = C * jnp.eye(C.shape[-1], dtype=A.dtype) - LD = C @ jnp.linalg.inv(D) - LD = fill_diagonal_batch(LD, D * D) - return LD, pivots, info + # TODO: Replace with native LDL when available: + # https://github.com/jax-ml/jax/issues/12779 + # TODO: Not tested for complex inputs. Does not support hermitian=True + pivots = jnp.broadcast_to( + jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1] + ) + info = jnp.zeros(A.shape[:-2], jnp.int32) + C = jnp.linalg.cholesky(A) + if C.size == 0: + return C, pivots, info + + # Fill diagonals of stacked matrices + @functools.partial(jnp.vectorize, signature="(k,k),(k,k)->(k,k)") + def fill_diagonal_batch(x, y): + return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) + + D = C * jnp.eye(C.shape[-1], dtype=A.dtype) + LD = C @ jnp.linalg.inv(D) + LD = fill_diagonal_batch(LD, D * D) + return LD, pivots, info @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None): - dtype = A.dtype + dtype = A.dtype - *_, m, n = A.shape - k = jnp.minimum(m, n) + *_, m, n = A.shape + k = jnp.minimum(m, n) - lu, _, permutation = jax.lax.linalg.lu(A) + lu, _, permutation = jax.lax.linalg.lu(A) - L = jnp.tril(lu[..., :, :k], k=-1) - eye_L = jnp.eye(m, k, dtype=dtype) - L = L + eye_L + L = jnp.tril(lu[..., :, :k], k=-1) + eye_L = jnp.eye(m, k, dtype=dtype) + L = L + eye_L - U = jnp.triu(lu[..., :k, :]) + U = jnp.triu(lu[..., :k, :]) - def perm_to_P(perm): - m = perm.shape[-1] - P = jnp.eye(m, dtype=dtype)[perm].T - return P + def perm_to_P(perm): + m = perm.shape[-1] + P = jnp.eye(m, dtype=dtype)[perm].T + return P - if permutation.ndim > 1: - num_batch_dims = permutation.ndim - 1 - for _ in range(num_batch_dims): - perm_to_P = jax.vmap(perm_to_P, in_axes=0) + if permutation.ndim > 1: + num_batch_dims = permutation.ndim - 1 + for _ in range(num_batch_dims): + perm_to_P = jax.vmap(perm_to_P, in_axes=0) - P = perm_to_P(permutation) + P = perm_to_P(permutation) - return P, L, U + return P, L, U @op(torch.ops.aten.linalg_lu_factor_ex) def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False): - lu, pivots, _ = jax.lax.linalg.lu(A) - # PT pivots vector is 1-indexed - pivots = pivots + 1 - info = jnp.zeros(A.shape[:-2], jnp.int32) - return lu, pivots, info + lu, pivots, _ = jax.lax.linalg.lu(A) + # PT pivots vector is 1-indexed + pivots = pivots + 1 + info = jnp.zeros(A.shape[:-2], jnp.int32) + return lu, pivots, info @op(torch.ops.aten.linalg_lu_solve) def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False): - # JAX pivots 
are offset by 1 compared to torch - pivots = pivots - 1 - if not left: - # XA = B is same as A'X = B' - trans = 0 if adjoint else 2 - x = jax.scipy.linalg.lu_solve( - (LU, pivots), jnp.matrix_transpose(B), trans - ) - x = jnp.matrix_transpose(x) - else: - trans = 2 if adjoint else 0 - x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans) - return x + # JAX pivots are offset by 1 compared to torch + pivots = pivots - 1 + if not left: + # XA = B is same as A'X = B' + trans = 0 if adjoint else 2 + x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans) + x = jnp.matrix_transpose(x) + else: + trans = 2 if adjoint else 0 + x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans) + return x @op(torch.ops.aten.gcd) def _aten_gcd(input, other): - return jnp.gcd(input, other) + return jnp.gcd(input, other) # aten.lcm @op(torch.ops.aten.lcm) def _aten_lcm(input, other): - return jnp.lcm(input, other) + return jnp.lcm(input, other) # aten.isinf @op(torch.ops.aten.isinf) def _aten_isinf(input): - return jnp.isinf(input) + return jnp.isinf(input) # aten.isnan @op(torch.ops.aten.isnan) def _aten_isnan(input): - return jnp.isnan(input) + return jnp.isnan(input) @op(torch.ops.aten.le) def _aten_le(self, other): - return self <= other + return self <= other # aten.leaky_relu @op(torch.ops.aten.leaky_relu) def _aten_leaky_relu(x, negative_slope=0.01): - return jax.nn.leaky_relu(x, negative_slope) + return jax.nn.leaky_relu(x, negative_slope) # aten.log @op(torch.ops.aten.log) @op_base.promote_int_input def _aten_log(x): - return jnp.log(x) + return jnp.log(x) # aten.log10 @op(torch.ops.aten.log10) @op_base.promote_int_input def _aten_log10(x): - return jnp.log10(x) + return jnp.log10(x) # aten.log1p @op(torch.ops.aten.log1p) @op_base.promote_int_input def _aten_log1p(x): - return jnp.log1p(x) + return jnp.log1p(x) # aten.log2 @op(torch.ops.aten.log2) @op_base.promote_int_input def _aten_log2(x): - return jnp.log2(x) + return jnp.log2(x) # aten.logical_and @op(torch.ops.aten.logical_and) def _aten_logical_and(self, other): - return jnp.logical_and(self, other) + return jnp.logical_and(self, other) # aten.logical_or @op(torch.ops.aten.logical_or) def _aten_logical_or(self, other): - return jnp.logical_or(self, other) + return jnp.logical_or(self, other) # aten.logical_not @op(torch.ops.aten.logical_not) def _aten_logical_not(self): - return jnp.logical_not(self) + return jnp.logical_not(self) # aten.log_softmax @op(torch.ops.aten._log_softmax) def _aten_log_softmax(self, axis=-1, half_to_float=False): - if self.shape == (): - return jnp.astype(0.0, self.dtype) - return jax.nn.log_softmax(self, axis) + if self.shape == (): + return jnp.astype(0.0, self.dtype) + return jax.nn.log_softmax(self, axis) # aten.logaddexp @op(torch.ops.aten.logaddexp) def _aten_logaddexp(self, other): - return jnp.logaddexp(self, other) + return jnp.logaddexp(self, other) # aten.logaddexp2 @op(torch.ops.aten.logaddexp2) def _aten_logaddexp2(self, other): - return jnp.logaddexp2(self, other) + return jnp.logaddexp2(self, other) # aten.logcumsumexp @op(torch.ops.aten.logcumsumexp) def _aten_logcumsumexp(self, dim=None): - if self.shape == (): - return self - return jax.lax.cumlogsumexp(self, axis=dim) + if self.shape == (): + return self + return jax.lax.cumlogsumexp(self, axis=dim) # aten.max_pool3d_backward # aten.logical_xor @op(torch.ops.aten.logical_xor) def _aten_logical_xor(self, other): - return jnp.logical_xor(self, other) + return jnp.logical_xor(self, other) # aten.max_pool2d_with_indices_backward @@ 
-3094,111 +3059,111 @@ def _aten_logical_xor(self, other): # aten.neg @op(torch.ops.aten.neg) def _aten_neg(x): - return -1 * x + return -1 * x @op(torch.ops.aten.nextafter) def _aten_nextafter(input, other, *, out=None): - return jnp.nextafter(input, other) + return jnp.nextafter(input, other) @op(torch.ops.aten.nonzero_static) def _aten_nonzero_static(input, size, fill_value=-1): - indices = jnp.argwhere(input) - - if size < indices.shape[0]: - indices = indices[:size] - elif size > indices.shape[0]: - padding = jnp.full( - (size - indices.shape[0], indices.shape[1]), - fill_value, - dtype=indices.dtype, - ) - indices = jnp.concatenate((indices, padding)) + indices = jnp.argwhere(input) + + if size < indices.shape[0]: + indices = indices[:size] + elif size > indices.shape[0]: + padding = jnp.full( + (size - indices.shape[0], indices.shape[1]), + fill_value, + dtype=indices.dtype, + ) + indices = jnp.concatenate((indices, padding)) - return indices + return indices # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x, as_tuple=False): - if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): - return torch.empty(0, 0, dtype=torch.int64) - if ( - jnp.ndim(x) == 0 - ): # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) - res = torch.empty(1, 0, dtype=torch.int64) - return jnp.array(res.numpy()) - index_tuple = jnp.nonzero(x) - index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] - return jnp.concatenate(index_tuple, axis=-1) + if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): + return torch.empty(0, 0, dtype=torch.int64) + if ( + jnp.ndim(x) == 0 + ): # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) + res = torch.empty(1, 0, dtype=torch.int64) + return jnp.array(res.numpy()) + index_tuple = jnp.nonzero(x) + index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] + return jnp.concatenate(index_tuple, axis=-1) # aten.prod @op(torch.ops.aten.prod) def _aten_prod(input, dim=None, keepdim=False, *, dtype=None): - if dtype: - input = input.astype(mappings.t2j_dtype(dtype)) - return _with_reduction_scalar(jnp.prod, input, dim, keepdim) + if dtype: + input = input.astype(mappings.t2j_dtype(dtype)) + return _with_reduction_scalar(jnp.prod, input, dim, keepdim) @op(torch.ops.aten.put) def _aten_put(self, index, source, accumulate=False): - expanded = False - res = None + expanded = False + res = None - if self.ndim == 0: - expanded = True - self = jnp.expand_dims(self, 0) + if self.ndim == 0: + expanded = True + self = jnp.expand_dims(self, 0) - if accumulate: - tmp = jnp.zeros(self.shape) - tmp = jnp.put(tmp, index, source, inplace=False) - res = jnp.add(self, tmp).astype(self.dtype) - else: - res = jnp.put(self, index, source, inplace=False) + if accumulate: + tmp = jnp.zeros(self.shape) + tmp = jnp.put(tmp, index, source, inplace=False) + res = jnp.add(self, tmp).astype(self.dtype) + else: + res = jnp.put(self, index, source, inplace=False) - if expanded: - res = res.squeeze() + if expanded: + res = res.squeeze() - return res + return res # aten.randperm # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) @op(torch.ops.aten.randperm, needs_env=True) def _aten_randperm( - n, - *, - generator=None, - dtype=None, - layout=None, - device=None, - pin_memory=None, - env=None, + n, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + env=None, ): - """ - Generates a random permutation of integers from 0 to n-1. - - Args: - n: The upper bound (exclusive) of the permutation range. - generator: A PRNGKey used as the random key. If None, a new key is created. - dtype: The desired data type of the output array. Default is jnp.int64. - layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). - device: The desired device on which to place the output array (e.g., jax.devices()[0]). - pin_memory: Whether to pin the output array's memory to the host. - - Returns: - A DeviceArray containing a random permutation of integers from 0 to n-1. - """ - if dtype: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = jnp.int64.dtype - key = env.get_and_rotate_prng_key(generator) - indices = jnp.arange(n, dtype=dtype) - permutation = jax.random.permutation(key, indices) - return permutation + """ + Generates a random permutation of integers from 0 to n-1. + + Args: + n: The upper bound (exclusive) of the permutation range. + generator: A PRNGKey used as the random key. If None, a new key is created. + dtype: The desired data type of the output array. Default is jnp.int64. + layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). + device: The desired device on which to place the output array (e.g., jax.devices()[0]). + pin_memory: Whether to pin the output array's memory to the host. + + Returns: + A DeviceArray containing a random permutation of integers from 0 to n-1. + """ + if dtype: + dtype = mappings.t2j_dtype(dtype) + else: + dtype = jnp.int64.dtype + key = env.get_and_rotate_prng_key(generator) + indices = jnp.arange(n, dtype=dtype) + permutation = jax.random.permutation(key, indices) + return permutation # aten.reflection_pad3d @@ -3207,13 +3172,13 @@ def _aten_randperm( # aten.remainder @op(torch.ops.aten.remainder) def _aten_remainder(inputs, other): - return inputs % other + return inputs % other # aten.repeat @op(torch.ops.aten.repeat) def _aten_repeat(x, reps): - return jnp.tile(x, reps) + return jnp.tile(x, reps) # aten.replication_pad2d @@ -3221,31 +3186,31 @@ def _aten_repeat(x, reps): # aten.roll @op(torch.ops.aten.roll) def _aten_roll(input, shifts, dims=None): - return jnp.roll(input, shifts, dims) + return jnp.roll(input, shifts, dims) # aten.slice_scatter @op(torch.ops.aten.slice_scatter) def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): - input_index = [] - for x in range(len(input.shape)): - if x == dim: - input_index.append(slice(start, end, step)) - else: - input_index.append(slice(None, None, None)) - return input.at[tuple(input_index)].set(src) + input_index = [] + for x in range(len(input.shape)): + if x == dim: + input_index.append(slice(start, end, step)) + else: + input_index.append(slice(None, None, None)) + return input.at[tuple(input_index)].set(src) # aten.sort # torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): - if a.shape == (): - return (a, jnp.astype(0, "int64")) - return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), - ) + if a.shape == (): + return (a, 
jnp.astype(0, "int64")) + return ( + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), + ) # aten.sym_size @@ -3254,101 +3219,100 @@ def _aten_sort(a, dim=-1, descending=False, stable=False): # aten.topk @op(torch.ops.aten.topk) def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - """JAX top-k implementation using jax.lax.top_k for improved efficiency. - - Args: - input: The input JAX array. - k: The number of top elements to return. - dim: The dimension along which to find the top-k. If None, operates on the - flattened array. - largest: If True, returns the largest k elements. Otherwise, smallest k. - sorted: If True, returns the elements in sorted order. - - Returns: - A tuple (values, indices) containing: - - values: The top k values. - - indices: The indices of the top k values in the original array. - """ - if dim is None: - # last dim is chosen - dim = input.ndim - 1 - - if dim < 0: - dim = dim + input.ndim - - if not largest: - input = -input # Find top-k of negated input if we want the smallest - - if input.ndim == 0: - return input, jnp.array(0, dtype=jnp.int64.dtype) - - transpose_shape = None - if dim != -1 and dim != len(input.shape) - 1: - transpose_shape = list(range(len(input.shape))) - transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], - ) - input = jnp.transpose(input, transpose_shape) + """JAX top-k implementation using jax.lax.top_k for improved efficiency. + + Args: + input: The input JAX array. + k: The number of top elements to return. + dim: The dimension along which to find the top-k. If None, operates on the + flattened array. + largest: If True, returns the largest k elements. Otherwise, smallest k. + sorted: If True, returns the elements in sorted order. + + Returns: + A tuple (values, indices) containing: + - values: The top k values. + - indices: The indices of the top k values in the original array. + """ + if dim is None: + # last dim is chosen + dim = input.ndim - 1 + + if dim < 0: + dim = dim + input.ndim + + if not largest: + input = -input # Find top-k of negated input if we want the smallest + + if input.ndim == 0: + return input, jnp.array(0, dtype=jnp.int64.dtype) + + transpose_shape = None + if dim != -1 and dim != len(input.shape) - 1: + transpose_shape = list(range(len(input.shape))) + transpose_shape[dim], transpose_shape[-1] = ( + transpose_shape[-1], + transpose_shape[dim], + ) + input = jnp.transpose(input, transpose_shape) - values, indices = jax.lax.top_k(input, k) + values, indices = jax.lax.top_k(input, k) - if sorted: - values = jnp.sort(values, descending=True) - indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 - ) + if sorted: + values = jnp.sort(values, descending=True) + indices = jnp.take_along_axis( + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) - if not largest: - values = -values # Negate values back if we found smallest + if not largest: + values = -values # Negate values back if we found smallest - if transpose_shape is not None: - values = jnp.transpose(values, transpose_shape) - indices = jnp.transpose(indices, transpose_shape) + if transpose_shape is not None: + values = jnp.transpose(values, transpose_shape) + indices = jnp.transpose(indices, transpose_shape) - return values, indices + return values, indices # aten.tril_indices # tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.tril_indices) def _aten_tril_indices( - row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None, + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, ): - a, b = jnp.tril_indices(row, offset, col) - return jnp.stack((a, b)) + a, b = jnp.tril_indices(row, offset, col) + return jnp.stack((a, b)) # aten.tril_indices # tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.triu_indices) def _aten_triu_indices( - row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None, + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, ): - a, b = jnp.triu_indices(row, offset, col) - return jnp.stack((a, b)) + a, b = jnp.triu_indices(row, offset, col) + return jnp.stack((a, b)) @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, dim=0): - return [ - jax.lax.index_in_dim(a, i, dim, keepdims=False) - for i in range(a.shape[dim]) - ] + return [ + jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim]) + ] # aten.unique_dim @@ -3357,35 +3321,35 @@ def _aten_unbind(a, dim=0): # the tensor regardless of the `sorted` argument passed to `torch.unique`. @op(torch.ops.aten.unique_dim) def _aten_unique_dim( - input_tensor, dim, sort=True, return_inverse=False, return_counts=False + input_tensor, dim, sort=True, return_inverse=False, return_counts=False ): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=return_counts, - axis=dim, - equal_nan=False, - ) - result_list = ( - list(result_tensor_or_tuple) - if isinstance(result_tensor_or_tuple, tuple) - else [result_tensor_or_tuple] - ) - - if not return_inverse: - result_list.insert(1, None) - elif _jax_version < (0, 4, 31) and dim is not None: - result_list[1] = result_list[1].flatten() - - if not return_counts: - result_list.insert(2, None) - - # [result, None, None] if return_inverse=False and return_counts=False - # [result, inverse, None] if return_inverse=True and return_counts=False - # [result, None, counts] if return_inverse=False and return_counts=True - # [result, inverse, counts] if return_inverse=True and return_counts=True - return result_list + result_tensor_or_tuple = jnp.unique( + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=return_counts, + axis=dim, + equal_nan=False, + ) + result_list = ( + list(result_tensor_or_tuple) + if isinstance(result_tensor_or_tuple, tuple) + else [result_tensor_or_tuple] + ) + + if not return_inverse: + result_list.insert(1, None) + elif _jax_version < (0, 4, 31) and dim is not None: + result_list[1] = result_list[1].flatten() + + if not return_counts: + result_list.insert(2, None) + + # [result, None, None] if return_inverse=False and return_counts=False + # [result, inverse, None] if return_inverse=True and return_counts=False + # [result, None, counts] if return_inverse=False and return_counts=True + # [result, inverse, counts] if return_inverse=True and return_counts=True + return result_list # aten._unique @@ -3394,18 +3358,18 @@ def _aten_unique_dim( # the tensor regardless of the `sorted` argument passed to `torch.unique`. 
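# Illustrative sketch (not part of the patch): jnp.unique always returns its
# result sorted, which is why these unique lowerings ignore torch's `sorted` flag.
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1, 3])
values, inverse = jnp.unique(x, return_inverse=True)
# values  -> [1 2 3]
# inverse -> [2 0 1 0 2]   (index of each element of x within `values`)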
@op(torch.ops.aten._unique) def _aten_unique(input_tensor, sort=True, return_inverse=False): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=False, - axis=None, - equal_nan=False, - ) - if return_inverse: - return result_tensor_or_tuple - else: - return (result_tensor_or_tuple, None) + result_tensor_or_tuple = jnp.unique( + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=False, + axis=None, + equal_nan=False, + ) + if return_inverse: + return result_tensor_or_tuple + else: + return (result_tensor_or_tuple, None) # aten._unique2 @@ -3414,85 +3378,83 @@ def _aten_unique(input_tensor, sort=True, return_inverse=False): # the tensor regardless of the `sorted` argument passed to `torch.unique`. @op(torch.ops.aten._unique2) def _aten_unique2( - input_tensor, sort=True, return_inverse=False, return_counts=False + input_tensor, sort=True, return_inverse=False, return_counts=False ): - return _aten_unique_dim( - input_tensor=input_tensor, - dim=None, - sort=sort, - return_inverse=return_inverse, - return_counts=return_counts, - ) + return _aten_unique_dim( + input_tensor=input_tensor, + dim=None, + sort=sort, + return_inverse=return_inverse, + return_counts=return_counts, + ) # aten.unique_consecutive @op(torch.ops.aten.unique_consecutive) def _aten_unique_consecutive( - input_tensor, return_inverse=False, return_counts=None, dim=None + input_tensor, return_inverse=False, return_counts=None, dim=None ): - # Explanation of computations (shown in 1D for simplicity): - # - # Input [a b b c c c d d d d e e e e e] - # Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e] - # Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e] - # Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Filter input by the resulting bool array [a b c d e ] - # Output [a b c d e] - - if dim is None: - inverse_shape = input_tensor.shape - input_tensor = input_tensor.flatten() - ndim = 1 - dim = 0 - else: - inverse_shape = input_tensor.shape[dim] - ndim = input_tensor.ndim - if dim < 0: - dim += ndim + # Explanation of computations (shown in 1D for simplicity): + # + # Input [a b b c c c d d d d e e e e e] + # Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e] + # Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e] + # Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0] + # Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0] + # Filter input by the resulting bool array [a b c d e ] + # Output [a b c d e] + + if dim is None: + inverse_shape = input_tensor.shape + input_tensor = input_tensor.flatten() + ndim = 1 + dim = 0 + else: + inverse_shape = input_tensor.shape[dim] + ndim = input_tensor.ndim + if dim < 0: + dim += ndim - nd_slice_0 = tuple( - slice(None, -1) if d == dim else slice(None) for d in range(ndim) - ) - nd_slice_1 = tuple( - slice(1, None) if d == dim else slice(None) for d in range(ndim) - ) + nd_slice_0 = tuple( + slice(None, -1) if d == dim else slice(None) for d in range(ndim) + ) + nd_slice_1 = tuple( + slice(1, None) if d == dim else slice(None) for d in range(ndim) + ) - axes_to_reduce = tuple(d for d in range(ndim) if d != dim) + axes_to_reduce = tuple(d for d in range(ndim) if d != dim) - does_not_equal_prior = jnp.any( - input_tensor[nd_slice_0] != input_tensor[nd_slice_1], - 
axis=axes_to_reduce, - keepdims=False, - ) + does_not_equal_prior = jnp.any( + input_tensor[nd_slice_0] != input_tensor[nd_slice_1], + axis=axes_to_reduce, + keepdims=False, + ) - if input_tensor.shape[dim] != 0: - # Prepend `True` to represent the first element of the input. - does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True) + if input_tensor.shape[dim] != 0: + # Prepend `True` to represent the first element of the input. + does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True) - include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] + include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] - output_tensor = input_tensor[ - tuple(include_indices if d == dim else slice(None) for d in range(ndim)) - ] + output_tensor = input_tensor[ + tuple(include_indices if d == dim else slice(None) for d in range(ndim)) + ] - if return_inverse or return_counts: - counts = ( - jnp.append(include_indices[1:], input_tensor.shape[dim]) - - include_indices[:] - ) + if return_inverse or return_counts: + counts = ( + jnp.append(include_indices[1:], input_tensor.shape[dim]) + - include_indices[:] + ) - inverse = ( - jnp.reshape( - jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape - ) - if return_inverse - else None - ) + inverse = ( + jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape) + if return_inverse + else None + ) - return output_tensor, inverse, counts + return output_tensor, inverse, counts - return output_tensor, None, None + return output_tensor, None, None # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d @@ -3507,39 +3469,39 @@ def _aten_unique_consecutive( @op(torch.ops.aten.where.ScalarOther) @op(torch.ops.aten.where.Scalar) def _aten_where(condition, x=None, y=None): - return jnp.where(condition, x, y) + return jnp.where(condition, x, y) # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) def _aten_to_dtype( - a, dtype, non_blocking=False, copy=False, memory_format=None + a, dtype, non_blocking=False, copy=False, memory_format=None ): - if dtype: - jaxdtype = mappings.t2j_dtype(dtype) - return a.astype(jaxdtype) + if dtype: + jaxdtype = mappings.t2j_dtype(dtype) + return a.astype(jaxdtype) @op(torch.ops.aten.to.dtype_layout) def _aten_to_dtype_layout( - a, - *, - dtype=None, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None, + a, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, ): - return _aten_to_dtype( - a, - dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format, - ) + return _aten_to_dtype( + a, + dtype, + non_blocking=non_blocking, + copy=copy, + memory_format=memory_format, + ) # aten.to.device @@ -3548,250 +3510,245 @@ def _aten_to_dtype_layout( # Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False @op(torch.ops.aten.var_mean.correction) def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False): - # The internal API technically has a default `correction` argument of `None`, - # but the public API has a default argument of 1. Therefore, we simply set our - # default argument to 1. However, since the argument is officially supposed to - # be nullable, we still need to check for `None` per the API contract. 
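# Illustrative sketch (not part of the patch): `correction` maps directly onto
# jnp.var's ddof, so correction=1 reproduces torch's default unbiased variance.
import jax.numpy as jnp

t = jnp.array([1.0, 2.0, 3.0, 4.0])
var = jnp.var(t, ddof=1)   # 1.6666667, same as torch.var_mean(torch.arange(1., 5.))[0]
mean = jnp.mean(t)         # 2.5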
- if correction is None: - correction = 1 - mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) - # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. - var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) - return var, mean + # The internal API technically has a default `correction` argument of `None`, + # but the public API has a default argument of 1. Therefore, we simply set our + # default argument to 1. However, since the argument is officially supposed to + # be nullable, we still need to check for `None` per the API contract. + if correction is None: + correction = 1 + mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) + # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. + var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) + return var, mean @op(torch.ops.aten.scalar_tensor) @op_base.convert_dtype() def _aten_scalar_tensor( - s, dtype=None, layout=None, device=None, pin_memory=None + s, dtype=None, layout=None, device=None, pin_memory=None ): - return jnp.array(s, dtype=dtype) + return jnp.array(s, dtype=dtype) @op(torch.ops.aten.to.device) def _aten_to_device(x, device, dtype): - return x + return x @op(torch.ops.aten.max_pool2d_with_indices_backward) def max_pool2d_with_indices_backward_custom( - grad_output, - self, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - indices, + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, ): - """ - Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. - - Args: - grad_output: The gradient tensor from the preceding layer. - self: The input tensor on which the original max pooling was performed. - kernel_size: The size of the pooling window. - stride: The stride of the pooling window. - padding: The padding applied during max pooling. - dilation: The dilation factor for the pooling operation. - ceil_mode: Whether to use ceil or floor when calculating output shapes. - indices: The indices of the maximum values, as produced by max_pool2d_with_indices. - - Returns: - The calculated gradient with respect to the input (grad_input). - """ - - kH, kW = kernel_size - dH, dW = stride - padH, padW = padding - dilH, dilW = dilation - - # Calculate output shape (may need adjustment based on ceil_mode) - out_shape = jnp.array(self.shape) - grad_input = jnp.zeros_like(self) - - # Iterate over the flattened input and output tensors - for i, idx in enumerate(indices.flatten()): - # Calculate input coordinates corresponding to the maximum value - out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] - in_y = out_y * dH - padH + out_y * (dilH - 1) - in_x = out_x * dW - padW + out_x * (dilW - 1) - - # Scatter the gradient to the appropriate input locations (handling potential overlaps) - for y in range(in_y, in_y + kH): - for x in range(in_x, in_x + kW): - if ( - 0 <= y < grad_input.shape[2] - and 0 <= x < grad_input.shape[3] - ): - grad_input = grad_input.at[y, x].add( - grad_output.flatten()[i] - ) - - return grad_input + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. 
+ ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input @op(torch.ops.aten._local_scalar_dense) def _aten_local_scalar_dense(x): - return x.item() + return x.item() @op(torch.ops.aten.tensor_split.sections) def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) + return jnp.array_split(ary, indices_or_sections, axis) @op(torch.ops.aten.randn, needs_env=True) @op_base.convert_dtype() def _randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.normal(key, shape) - if dtype is not None: - res = res.astype(dtype) - return res + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key(generator) + res = jax.random.normal(key, shape) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.bernoulli.p, needs_env=True) def _aten_bernoulli( - self, - p=0.5, - *, - generator=None, - env=None, + self, + p=0.5, + *, + generator=None, + env=None, ): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, self.shape) < p - return res + key = env.get_and_rotate_prng_key(generator) + res = jax.random.uniform(key, self.shape) < p + return res @op(torch.ops.aten.geometric, needs_env=True) def geometric(self, p, *, generator=None, env=None): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.geometric(key, p, self.shape) - return res + key = env.get_and_rotate_prng_key(generator) + res = jax.random.geometric(key, p, self.shape) + return res @op(torch.ops.aten.randn_like, needs_env=True) @op_base.convert_dtype() def _aten_randn_like( - x, - *, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=torch.preserve_format, - env=None, + x, + *, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, + env=None, ): - key = env.get_and_rotate_prng_key() - return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape) + key = env.get_and_rotate_prng_key() + return jax.random.normal(key, 
dtype=dtype or x.dtype, shape=x.shape) @op(torch.ops.aten.rand, needs_env=True) @op_base.convert_dtype() def _rand( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, shape) - if dtype is not None: - res = res.astype(dtype) - return res + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key(generator) + res = jax.random.uniform(key, shape) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.outer) def _aten_outer(a, b): - return jnp.outer(a, b) + return jnp.outer(a, b) @op(torch.ops.aten.allclose) def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) + return jnp.allclose(input, other, rtol, atol, equal_nan) @op(torch.ops.aten.native_batch_norm) def _aten_native_batch_norm( - input, - weight, - bias, - running_mean, - running_var, - training=False, - momentum=0.1, - eps=1e-5, + input, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=1e-5, ): - if running_mean is None: - running_mean = jnp.zeros( - input.shape[1], dtype=input.dtype - ) # Initialize running mean if None - if running_var is None: - running_var = jnp.ones( - input.shape[1], dtype=input.dtype - ) # Initialize running variance if None - - if training: - return _aten__native_batch_norm_legit( - input, - weight, - bias, - running_mean, - running_var, - training, - momentum, - eps, - ) - else: - return _aten__native_batch_norm_legit_no_training( - input, weight, bias, running_mean, running_var, momentum, eps - ) + if running_mean is None: + running_mean = jnp.zeros( + input.shape[1], dtype=input.dtype + ) # Initialize running mean if None + if running_var is None: + running_var = jnp.ones( + input.shape[1], dtype=input.dtype + ) # Initialize running variance if None + + if training: + return _aten__native_batch_norm_legit( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + ) + else: + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) @op(torch.ops.aten.normal, needs_env=True) def _aten_normal(self, mean=0, std=1, generator=None, env=None): - shape = self.shape - res = _randn(*shape, generator=generator, env=env) - return res * std + mean + shape = self.shape + res = _randn(*shape, generator=generator, env=env) + return res * std + mean # TODO: not clear what this function should actually do # https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940 @op(torch.ops.aten.lift_fresh) def _aten_lift_fresh(self): - return self + return self @op(torch.ops.aten.uniform, needs_env=True) def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): - assert from_ <= to, ( - f"Uniform from(passed in {from_}) must be less than to(passed in {to})" - ) - shape = self.shape - res = _rand(*shape, generator=generator, env=env) - return res * (to - from_) + from_ + assert from_ <= to, ( + f"Uniform from(passed in {from_}) 
must be less than to(passed in {to})" + ) + shape = self.shape + res = _rand(*shape, generator=generator, env=env) + return res * (to - from_) + from_ # func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3800,2015 +3757,1990 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): @op(torch.ops.aten.randint, needs_env=True) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint( - *args, - generator=None, - dtype=None, - env=None, - **kwargs, + *args, + generator=None, + dtype=None, + env=None, + **kwargs, ): - if len(args) == 3: - # low, high, size - low, high, size = args - elif len(args) == 2: - high, size = args - low = 0 - else: - raise AssertionError( - f"Expected at 2 or 3 args for Aten::randint, got {len(args)}" - ) + if len(args) == 3: + # low, high, size + low, high, size = args + elif len(args) == 2: + high, size = args + low = 0 + else: + raise AssertionError( + f"Expected at 2 or 3 args for Aten::randint, got {len(args)}" + ) - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, size, low, high) - if dtype is not None: - res = res.astype(dtype) - return res + key = env.get_and_rotate_prng_key(generator) + res = jax.random.randint(key, size, low, high) + if dtype is not None: + res = res.astype(dtype) + return res @op( - torch.ops.aten.randint_like, - torch.ops.aten.randint.generator, - needs_env=True, + torch.ops.aten.randint_like, + torch.ops.aten.randint.generator, + needs_env=True, ) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint_like( - input, - *args, - generator=None, - dtype=None, - env=None, - **kwargs, + input, + *args, + generator=None, + dtype=None, + env=None, + **kwargs, ): - if len(args) == 2: - low, high = args - elif len(args) == 1: - high = args[0] - low = 0 - else: - raise AssertionError( - f"Expected at 1 or 2 args for Aten::randint_like, got {len(args)}" - ) + if len(args) == 2: + low, high = args + elif len(args) == 1: + high = args[0] + low = 0 + else: + raise AssertionError( + f"Expected at 1 or 2 args for Aten::randint_like, got {len(args)}" + ) - shape = input.shape - dtype = dtype or input.dtype - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, shape, low, high) - if dtype is not None: - res = res.astype(dtype) - return res + shape = input.shape + dtype = dtype or input.dtype + key = env.get_and_rotate_prng_key(generator) + res = jax.random.randint(key, shape, low, high) + if dtype is not None: + res = res.astype(dtype) + return res @op(torch.ops.aten.dim, is_jax_function=False) def _aten_dim(self): - return len(self.shape) + return len(self.shape) @op(torch.ops.aten.copysign) def _aten_copysign(input, other, *, out=None): - result = jnp.copysign(input, other) - # torch.copysign(x, y) returns float32 for integer x and y, - # regardless of their exact integer dtype, whereas jax.copysign returns - # float64 when one or both of them is int64. - if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( - other.dtype, jnp.integer - ): - result = result.astype(jnp.float32) - return result + result = jnp.copysign(input, other) + # torch.copysign(x, y) returns float32 for integer x and y, + # regardless of their exact integer dtype, whereas jax.copysign returns + # float64 when one or both of them is int64. 
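# Illustrative sketch (not part of the patch): the astype above mirrors
# torch.copysign's promotion of two integer inputs to float32.
import jax.numpy as jnp

a = jnp.array([1, -2, 3], dtype=jnp.int32)
b = jnp.array([-1, 1, -1], dtype=jnp.int32)
out = jnp.copysign(a, b).astype(jnp.float32)
# out -> [-1.  2. -3.], dtype float32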
+ if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( + other.dtype, jnp.integer + ): + result = result.astype(jnp.float32) + return result @op(torch.ops.aten.i0) @op_base.promote_int_input def _aten_i0(self): - return jax.scipy.special.i0(self) + return jax.scipy.special.i0(self) @op(torch.ops.aten.special_i0e) @op_base.promote_int_input def _aten_i0e(self): - return jax.scipy.special.i0e(self) + return jax.scipy.special.i0e(self) @op(torch.ops.aten.special_i1) @op_base.promote_int_input def _aten_special_i1(self): - return jax.scipy.special.i1(self) + return jax.scipy.special.i1(self) @op(torch.ops.aten.special_i1e) @op_base.promote_int_input def _aten_special_i1e(self): - return jax.scipy.special.i1e(self) + return jax.scipy.special.i1e(self) @op(torch.ops.aten.special_laguerre_polynomial_l) @op_base.promote_int_input def _aten_special_laguerre_polynomial_l(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3106-L3134 - - @jnp.vectorize - def vectorized(x, n_i): - def negative_n(x): - return jnp.zeros_like(x) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return jnp.ones_like(x) - x - - def zero_abs(x): - return jnp.ones_like(x) - - def default(x): - def f(k, carry): - p, q = carry - return ( - q, - ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1), - ) - - _, q = jax.lax.fori_loop( - 1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x) - ) - return q - - return jnp.piecewise( - x, - [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], - [one_n, zero_n, zero_abs, negative_n, default], + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3106-L3134 + + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) + + def zero_n(x): + return jnp.ones_like(x) + + def one_n(x): + return jnp.ones_like(x) - x + + def zero_abs(x): + return jnp.ones_like(x) + + def default(x): + def f(k, carry): + p, q = carry + return ( + q, + ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1), ) - return vectorized(self, n.astype(jnp.int64)) + _, q = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x)) + return q + + return jnp.piecewise( + x, + [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], + [one_n, zero_n, zero_abs, negative_n, default], + ) + + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_modified_bessel_i0) @op_base.promote_int_input def _aten_special_modified_bessel_i0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3182-L3268 - - def small(x): - A = jnp.array( - [ - -4.41534164647933937950e-18, - 3.33079451882223809783e-17, - -2.43127984654795469359e-16, - 1.71539128555513303061e-15, - -1.16853328779934516808e-14, - 7.67618549860493561688e-14, - -4.85644678311192946090e-13, - 2.95505266312963983461e-12, - -1.72682629144155570723e-11, - 9.67580903537323691224e-11, - -5.18979560163526290666e-10, - 2.65982372468238665035e-09, - -1.30002500998624804212e-08, - 6.04699502254191894932e-08, - -2.67079385394061173391e-07, - 1.11738753912010371815e-06, - -4.41673835845875056359e-06, - 1.64484480707288970893e-05, - -5.75419501008210370398e-05, - 1.88502885095841655729e-04, - -5.76375574538582365885e-04, - 1.63947561694133579842e-03, - -4.32430999505057594430e-03, - 1.05464603945949983183e-02, - -2.37374148058994688156e-02, - 
4.93052842396707084878e-02, - -9.49010970480476444210e-02, - 1.71620901522208775349e-01, - -3.04682672343198398683e-01, - 6.76795274409476084995e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((x / 2.0) - 2.0) * q - p + val), None + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3182-L3268 + + def small(x): + A = jnp.array( + [ + -4.41534164647933937950e-18, + 3.33079451882223809783e-17, + -2.43127984654795469359e-16, + 1.71539128555513303061e-15, + -1.16853328779934516808e-14, + 7.67618549860493561688e-14, + -4.85644678311192946090e-13, + 2.95505266312963983461e-12, + -1.72682629144155570723e-11, + 9.67580903537323691224e-11, + -5.18979560163526290666e-10, + 2.65982372468238665035e-09, + -1.30002500998624804212e-08, + 6.04699502254191894932e-08, + -2.67079385394061173391e-07, + 1.11738753912010371815e-06, + -4.41673835845875056359e-06, + 1.64484480707288970893e-05, + -5.75419501008210370398e-05, + 1.88502885095841655729e-04, + -5.76375574538582365885e-04, + 1.63947561694133579842e-03, + -4.32430999505057594430e-03, + 1.05464603945949983183e-02, + -2.37374148058994688156e-02, + 4.93052842396707084878e-02, + -9.49010970480476444210e-02, + 1.71620901522208775349e-01, + -3.04682672343198398683e-01, + 6.76795274409476084995e-01, + ], + dtype=self.dtype, + ) - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A - ) + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, ((x / 2.0) - 2.0) * q - p + val), None - return jnp.exp(x) * (0.5 * (a - p)) + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - def default(x): - B = jnp.array( - [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - 4.46562142029675999901e-17, - 3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - 1.77256013305652638360e-15, - 3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - 1.54008621752140982691e-14, - 3.85277838274214270114e-13, - 7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - 1.18891471078464383424e-11, - 4.94060238822496958910e-10, - 3.39623202570838634515e-09, - 2.26666899049817806459e-08, - 2.04891858946906374183e-07, - 2.89137052083475648297e-06, - 6.88975834691682398426e-05, - 3.36911647825569408990e-03, - 8.04490411014108831608e-01, - ], - dtype=self.dtype, - ) + return jnp.exp(x) * (0.5 * (a - p)) + + def default(x): + B = jnp.array( + [ + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + 4.46562142029675999901e-17, + 3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + 1.77256013305652638360e-15, + 3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + 1.54008621752140982691e-14, + 3.85277838274214270114e-13, + 7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + 1.18891471078464383424e-11, + 4.94060238822496958910e-10, + 3.39623202570838634515e-09, + 2.26666899049817806459e-08, + 2.04891858946906374183e-07, + 2.89137052083475648297e-06, + 6.88975834691682398426e-05, + 3.36911647825569408990e-03, + 8.04490411014108831608e-01, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / x - 2.0) * q - p + val), None + def 
f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (32.0 / x - 2.0) * q - p + val), None - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B - ) + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) - return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) + return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) - self = jnp.abs(self) - return jnp.piecewise(self, [self <= 8], [small, default]) + self = jnp.abs(self) + return jnp.piecewise(self, [self <= 8], [small, default]) @op(torch.ops.aten.special_modified_bessel_i1) @op_base.promote_int_input def _aten_special_modified_bessel_i1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364 - - def small(x): - A = jnp.array( - [ - 2.77791411276104639959e-18, - -2.11142121435816608115e-17, - 1.55363195773620046921e-16, - -1.10559694773538630805e-15, - 7.60068429473540693410e-15, - -5.04218550472791168711e-14, - 3.22379336594557470981e-13, - -1.98397439776494371520e-12, - 1.17361862988909016308e-11, - -6.66348972350202774223e-11, - 3.62559028155211703701e-10, - -1.88724975172282928790e-09, - 9.38153738649577178388e-09, - -4.44505912879632808065e-08, - 2.00329475355213526229e-07, - -8.56872026469545474066e-07, - 3.47025130813767847674e-06, - -1.32731636560394358279e-05, - 4.78156510755005422638e-05, - -1.61760815825896745588e-04, - 5.12285956168575772895e-04, - -1.51357245063125314899e-03, - 4.15642294431288815669e-03, - -1.05640848946261981558e-02, - 2.47264490306265168283e-02, - -5.29459812080949914269e-02, - 1.02643658689847095384e-01, - -1.76416518357834055153e-01, - 2.52587186443633654823e-01, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364 + + def small(x): + A = jnp.array( + [ + 2.77791411276104639959e-18, + -2.11142121435816608115e-17, + 1.55363195773620046921e-16, + -1.10559694773538630805e-15, + 7.60068429473540693410e-15, + -5.04218550472791168711e-14, + 3.22379336594557470981e-13, + -1.98397439776494371520e-12, + 1.17361862988909016308e-11, + -6.66348972350202774223e-11, + 3.62559028155211703701e-10, + -1.88724975172282928790e-09, + 9.38153738649577178388e-09, + -4.44505912879632808065e-08, + 2.00329475355213526229e-07, + -8.56872026469545474066e-07, + 3.47025130813767847674e-06, + -1.32731636560394358279e-05, + 4.78156510755005422638e-05, + -1.61760815825896745588e-04, + 5.12285956168575772895e-04, + -1.51357245063125314899e-03, + 4.15642294431288815669e-03, + -1.05640848946261981558e-02, + 2.47264490306265168283e-02, + -5.29459812080949914269e-02, + 1.02643658689847095384e-01, + -1.76416518357834055153e-01, + 2.52587186443633654823e-01, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A - ) + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - return jax.lax.cond( - x < 0, - lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), - lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)), - ) + return jax.lax.cond( + x < 0, + lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), + lambda: 0.5 * (a - p) * 
jnp.abs(x) * jnp.exp(jnp.abs(x)), + ) - def default(x): - B = jnp.array( - [ - 7.51729631084210481353e-18, - 4.41434832307170791151e-18, - -4.65030536848935832153e-17, - -3.20952592199342395980e-17, - 2.96262899764595013876e-16, - 3.30820231092092828324e-16, - -1.88035477551078244854e-15, - -3.81440307243700780478e-15, - 1.04202769841288027642e-14, - 4.27244001671195135429e-14, - -2.10154184277266431302e-14, - -4.08355111109219731823e-13, - -7.19855177624590851209e-13, - 2.03562854414708950722e-12, - 1.41258074366137813316e-11, - 3.25260358301548823856e-11, - -1.89749581235054123450e-11, - -5.58974346219658380687e-10, - -3.83538038596423702205e-09, - -2.63146884688951950684e-08, - -2.51223623787020892529e-07, - -3.88256480887769039346e-06, - -1.10588938762623716291e-04, - -9.76109749136146840777e-03, - 7.78576235018280120474e-01, - ], - dtype=self.dtype, - ) + def default(x): + B = jnp.array( + [ + 7.51729631084210481353e-18, + 4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + 2.96262899764595013876e-16, + 3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + 1.04202769841288027642e-14, + 4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + 2.03562854414708950722e-12, + 1.41258074366137813316e-11, + 3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + 7.78576235018280120474e-01, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None + def f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B - ) + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) - return jax.lax.cond( - x < 0, - lambda: -( - jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)) - ), - lambda: jnp.exp(jnp.abs(x)) - * (0.5 * (b - p)) - / jnp.sqrt(jnp.abs(x)), - ) + return jax.lax.cond( + x < 0, + lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), + lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)), + ) - return jnp.piecewise(self, [self <= 8], [small, default]) + return jnp.piecewise(self, [self <= 8], [small, default]) @op(torch.ops.aten.special_modified_bessel_k0) @op_base.promote_int_input def _aten_special_modified_bessel_k0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - 1.37446543561352307156e-16, - 4.25981614279661018399e-14, - 1.03496952576338420167e-11, - 1.90451637722020886025e-09, - 2.53479107902614945675e-07, - 2.28621210311945178607e-05, - 1.26461541144692592338e-03, - 3.59799365153615016266e-02, - 3.44289899924628486886e-01, - -5.35327393233902768720e-01, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441 + + def zero(x): + return jnp.array(jnp.inf, x.dtype) + + def 
negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + A = jnp.array( + [ + 1.37446543561352307156e-16, + 4.25981614279661018399e-14, + 1.03496952576338420167e-11, + 1.90451637722020886025e-09, + 2.53479107902614945675e-07, + 2.28621210311945178607e-05, + 1.26461541144692592338e-03, + 3.59799365153615016266e-02, + 3.44289899924628486886e-01, + -5.35327393233902768720e-01, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, (x * x - 2.0) * q - p + val), None + def f(carry, val): + p, q, a = carry + p, q = q, a + return (p, q, (x * x - 2.0) * q - p + val), None - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A - ) + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - return 0.5 * (a - p) - jnp.log( - 0.5 * x - ) * _aten_special_modified_bessel_i0(x) + return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0( + x + ) - def default(x): - B = jnp.array( - [ - 5.30043377268626276149e-18, - -1.64758043015242134646e-17, - 5.21039150503902756861e-17, - -1.67823109680541210385e-16, - 5.51205597852431940784e-16, - -1.84859337734377901440e-15, - 6.34007647740507060557e-15, - -2.22751332699166985548e-14, - 8.03289077536357521100e-14, - -2.98009692317273043925e-13, - 1.14034058820847496303e-12, - -4.51459788337394416547e-12, - 1.85594911495471785253e-11, - -7.95748924447710747776e-11, - 3.57739728140030116597e-10, - -1.69753450938905987466e-09, - 8.57403401741422608519e-09, - -4.66048989768794782956e-08, - 2.76681363944501510342e-07, - -1.83175552271911948767e-06, - 1.39498137188764993662e-05, - -1.28495495816278026384e-04, - 1.56988388573005337491e-03, - -3.14481013119645005427e-02, - 2.44030308206595545468e00, - ], - dtype=self.dtype, - ) + def default(x): + B = jnp.array( + [ + 5.30043377268626276149e-18, + -1.64758043015242134646e-17, + 5.21039150503902756861e-17, + -1.67823109680541210385e-16, + 5.51205597852431940784e-16, + -1.84859337734377901440e-15, + 6.34007647740507060557e-15, + -2.22751332699166985548e-14, + 8.03289077536357521100e-14, + -2.98009692317273043925e-13, + 1.14034058820847496303e-12, + -4.51459788337394416547e-12, + 1.85594911495471785253e-11, + -7.95748924447710747776e-11, + 3.57739728140030116597e-10, + -1.69753450938905987466e-09, + 8.57403401741422608519e-09, + -4.66048989768794782956e-08, + 2.76681363944501510342e-07, + -1.83175552271911948767e-06, + 1.39498137188764993662e-05, + -1.28495495816278026384e-04, + 1.56988388573005337491e-03, + -3.14481013119645005427e-02, + 2.44030308206595545468e00, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (8.0 / x - 2.0) * q - p + val), None + def f(carry, val): + p, q, b = carry + p, q = q, b + return (p, q, (8.0 / x - 2.0) * q - p + val), None - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B - ) + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) + return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise( - self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] - ) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.special_modified_bessel_k1) @op_base.promote_int_input def _aten_special_modified_bessel_k1(self): - # Adapted from 
https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - -7.02386347938628759343e-18, - -2.42744985051936593393e-15, - -6.66690169419932900609e-13, - -1.41148839263352776110e-10, - -2.21338763073472585583e-08, - -2.43340614156596823496e-06, - -1.73028895751305206302e-04, - -6.97572385963986435018e-03, - -1.22611180822657148235e-01, - -3.53155960776544875667e-01, - 1.52530022733894777053e00, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519 + + def zero(x): + return jnp.array(jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + A = jnp.array( + [ + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + 1.52530022733894777053e00, + ], + dtype=self.dtype, + ) - def f(carry, val): - p, q, a = carry - p, q = q, a - a = (x * x - 2.0) * q - p + val - return (p, q, a), None + def f(carry, val): + p, q, a = carry + p, q = q, a + a = (x * x - 2.0) * q - p + val + return (p, q, a), None - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A - ) + (p, _, a), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - return ( - jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) - + 0.5 * (a - p) / x - ) + return ( + jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x + ) - def default(x): - B = jnp.array( - [ - -5.75674448366501715755e-18, - 1.79405087314755922667e-17, - -5.68946255844285935196e-17, - 1.83809354436663880070e-16, - -6.05704724837331885336e-16, - 2.03870316562433424052e-15, - -7.01983709041831346144e-15, - 2.47715442448130437068e-14, - -8.97670518232499435011e-14, - +3.34841966607842919884e-13, - -1.28917396095102890680e-12, - 5.13963967348173025100e-12, - -2.12996783842756842877e-11, - 9.21831518760500529508e-11, - -4.19035475934189648750e-10, - 2.01504975519703286596e-09, - -1.03457624656780970260e-08, - 5.74108412545004946722e-08, - -3.50196060308781257119e-07, - 2.40648494783721712015e-06, - -1.93619797416608296024e-05, - 1.95215518471351631108e-04, - -2.85781685962277938680e-03, - 1.03923736576817238437e-01, - 2.72062619048444266945e00, - ], - dtype=self.dtype, - ) + def default(x): + B = jnp.array( + [ + -5.75674448366501715755e-18, + 1.79405087314755922667e-17, + -5.68946255844285935196e-17, + 1.83809354436663880070e-16, + -6.05704724837331885336e-16, + 2.03870316562433424052e-15, + -7.01983709041831346144e-15, + 2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + 5.13963967348173025100e-12, + -2.12996783842756842877e-11, + 9.21831518760500529508e-11, + -4.19035475934189648750e-10, + 2.01504975519703286596e-09, + -1.03457624656780970260e-08, + 5.74108412545004946722e-08, + -3.50196060308781257119e-07, + 2.40648494783721712015e-06, + -1.93619797416608296024e-05, + 1.95215518471351631108e-04, + -2.85781685962277938680e-03, + 1.03923736576817238437e-01, + 2.72062619048444266945e00, + ], + dtype=self.dtype, + ) - 
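# Illustrative sketch (not from the patch): the modified Bessel ports above all
# evaluate a Chebyshev coefficient table with the three-term recurrence
# b_k = factor(x) * b_{k-1} - b_{k-2} + B[k] carried through jax.lax.scan and
# then return 0.5 * (b_n - b_{n-2}). A minimal standalone version of that
# pattern, with dummy coefficients:
import jax
import jax.numpy as jnp

def chebyshev_series(x, coeffs, factor):
  def step(carry, c):
    p, q, b = carry
    p, q = q, b
    return (p, q, factor * q - p + c), None

  zeros = jnp.zeros_like(x)
  (p, _, b), _ = jax.lax.scan(step, (zeros, zeros, zeros), coeffs)
  return 0.5 * (b - p)

x = jnp.asarray(3.0, dtype=jnp.float32)
value = chebyshev_series(x, jnp.array([0.1, -0.2, 0.3], jnp.float32), x * x - 2.0)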
def f(carry, val): - p, q, b = carry - p, q = q, b - b = (8.0 / x - 2.0) * q - p + val - return (p, q, b), None + def f(carry, val): + p, q, b = carry + p, q = q, b + b = (8.0 / x - 2.0) * q - p + val + return (p, q, b), None - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B - ) + (p, _, b), _ = jax.lax.scan( + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) + return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise( - self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] - ) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.polygamma) def _aten_polygamma(x, n): - if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - n = n.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.lax.polygamma(jnp.float32(x), n) + if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: + n = n.astype(mappings.t2j_dtype(torch.get_default_dtype())) + return jax.lax.polygamma(jnp.float32(x), n) @op(torch.ops.aten.special_ndtri) @op_base.promote_int_input def _aten_special_ndtri(self): - return jax.scipy.special.ndtri(self) + return jax.scipy.special.ndtri(self) @op(torch.ops.aten.special_bessel_j0) @op_base.promote_int_input def _aten_special_bessel_j0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2379-L2489 - - def very_small(x): - return 1.0 - x * x / 4.0 - - def small(x): - RP = jnp.array( - [ - -4.79443220978201773821e09, - 1.95617491946556577543e12, - -2.49248344360967716204e14, - 9.70862251047306323952e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 4.99563147152651017219e02, - 1.73785401676374683123e05, - 4.84409658339962045305e07, - 1.11855537045356834862e10, - 2.11277520115489217587e12, - 3.10518229857422583814e14, - 3.18121955943204943306e16, - 1.71086294081043136091e18, - ], - dtype=self.dtype, - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2379-L2489 + + def very_small(x): + return 1.0 - x * x / 4.0 + + def small(x): + RP = jnp.array( + [ + -4.79443220978201773821e09, + 1.95617491946556577543e12, + -2.49248344360967716204e14, + 9.70862251047306323952e15, + ], + dtype=self.dtype, + ) + RQ = jnp.array( + [ + 4.99563147152651017219e02, + 1.73785401676374683123e05, + 4.84409658339962045305e07, + 1.11855537045356834862e10, + 2.11277520115489217587e12, + 3.10518229857422583814e14, + 3.18121955943204943306e16, + 1.71086294081043136091e18, + ], + dtype=self.dtype, + ) - rp = op_base.foreach_loop( - RP, lambda carry, rp_i: carry * (x * x) + rp_i - ) - rq = op_base.foreach_loop( - RQ, lambda carry, rq_i: carry * (x * x) + rq_i - ) + rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) + rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - return ( - (x * x - 5.78318596294678452118e00) - * (x * x - 3.04712623436620863991e01) - * rp - / rq - ) - - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 
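# Illustrative sketch, assuming op_base.foreach_loop is a plain left fold over
# the coefficient table starting from zero: the rational approximations in the
# Bessel J ports evaluate fixed polynomials in Horner form, e.g.
# rp = ((RP[0] * z + RP[1]) * z + RP[2]) * z + RP[3] with z = x * x. The same
# fold written directly with jax.lax.scan:
import jax
import jax.numpy as jnp

def horner(coeffs, z):
  def step(acc, c):
    return acc * z + c, None

  acc, _ = jax.lax.scan(step, jnp.zeros_like(z), coeffs)
  return acc

RP = jnp.array([-4.79443220978201773821e09, 1.95617491946556577543e12,
                -2.49248344360967716204e14, 9.70862251047306323952e15])
x = jnp.asarray(2.0)
rp = horner(RP, x * x)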
8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, - ) + return ( + (x * x - 5.78318596294678452118e00) + * (x * x - 3.04712623436620863991e01) + * rp + / rq + ) - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i - ) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i - ) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i - ) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i - ) + def default(x): + PP = jnp.array( + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, + ) - return ( - ( - pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) - - 5.0 - / x - * (qp / qq) - * jnp.sin(x - 0.785398163397448309615660845819875721) - ) - * 0.797884560802865355879892119868763737 - / jnp.sqrt(x) - ) + pp = op_base.foreach_loop( + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) + pq = op_base.foreach_loop( + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) + qp = op_base.foreach_loop( + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) + qq = op_base.foreach_loop( + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) - self = jnp.abs(self) - # Last True condition in `piecewise` takes priority, but last function is - # default. See https://github.com/numpy/numpy/issues/16475 - return jnp.piecewise( - self, [self <= 5.0, self < 0.00001], [small, very_small, default] + return ( + ( + pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) ) + self = jnp.abs(self) + # Last True condition in `piecewise` takes priority, but last function is + # default. 
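# Standalone illustration of the jnp.piecewise ordering the comment above
# refers to: where several conditions are True the last matching one wins,
# and the extra trailing function is the default for unmatched elements.
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0, 10.0])
out = jnp.piecewise(
    x,
    [x <= 5.0, x < 1.0],
    [lambda v: v * 10.0,    # x <= 5
     lambda v: v * 100.0,   # x < 1, takes priority where both conditions hold
     lambda v: v])          # default
# out == [-200., 50., 30., 10.]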
See https://github.com/numpy/numpy/issues/16475 + return jnp.piecewise( + self, [self <= 5.0, self < 0.00001], [small, very_small, default] + ) + @op(torch.ops.aten.special_bessel_j1) @op_base.promote_int_input def _aten_special_bessel_j1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2491-L2597 - - def small(x): - RP = jnp.array( - [ - -8.99971225705559398224e08, - 4.52228297998194034323e11, - -7.27494245221818276015e13, - 3.68295732863852883286e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 6.20836478118054335476e02, - 2.56987256757748830383e05, - 8.35146791431949253037e07, - 2.21511595479792499675e10, - 4.74914122079991414898e12, - 7.84369607876235854894e14, - 8.95222336184627338078e16, - 5.32278620332680085395e18, - ], - dtype=self.dtype, - ) - - rp = op_base.foreach_loop( - RP, lambda carry, rp_i: carry * (x * x) + rp_i - ) - rq = op_base.foreach_loop( - RQ, lambda carry, rq_i: carry * (x * x) + rq_i - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2491-L2597 + + def small(x): + RP = jnp.array( + [ + -8.99971225705559398224e08, + 4.52228297998194034323e11, + -7.27494245221818276015e13, + 3.68295732863852883286e15, + ], + dtype=self.dtype, + ) + RQ = jnp.array( + [ + 6.20836478118054335476e02, + 2.56987256757748830383e05, + 8.35146791431949253037e07, + 2.21511595479792499675e10, + 4.74914122079991414898e12, + 7.84369607876235854894e14, + 8.95222336184627338078e16, + 5.32278620332680085395e18, + ], + dtype=self.dtype, + ) - return ( - rp - / rq - * x - * (x * x - 1.46819706421238932572e01) - * (x * x - 4.92184563216946036703e01) - ) + rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) + rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, - ) + return ( + rp + / rq + * x + * (x * x - 1.46819706421238932572e01) + * (x * x - 4.92184563216946036703e01) + ) - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i - ) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i - ) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i - ) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i - ) + def default(x): + PP = jnp.array( + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 
1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 9.99999999999999997461e-01, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], + dtype=self.dtype, + ) - return ( - ( - pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) - - 5.0 - / x - * (qp / qq) - * jnp.sin(x - 2.356194490192344928846982537459627163) - ) - * 0.797884560802865355879892119868763737 - / jnp.sqrt(x) - ) + pp = op_base.foreach_loop( + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) + pq = op_base.foreach_loop( + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) + qp = op_base.foreach_loop( + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) + qq = op_base.foreach_loop( + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) - # If x < 0, bessel_j1(x) = -bessel_j1(-x) - sign = jnp.sign(self) - self = jnp.abs(self) - return sign * jnp.piecewise( - self, - [self <= 5.0], - [small, default], + return ( + ( + pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) ) + # If x < 0, bessel_j1(x) = -bessel_j1(-x) + sign = jnp.sign(self) + self = jnp.abs(self) + return sign * jnp.piecewise( + self, + [self <= 5.0], + [small, default], + ) + @op(torch.ops.aten.special_bessel_y0) @op_base.promote_int_input def _aten_special_bessel_y0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2599-L2712 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.55924367855235737965e04, - -1.46639295903971606143e07, - 5.43526477051876500413e09, - -9.82136065717911466409e11, - 8.75906394395366999549e13, - -3.46628303384729719441e15, - 4.42733268572569800351e16, - -1.84950800436986690637e16, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 1.04128353664259848412e03, - 6.26107330137134956842e05, - 2.68919633393814121987e08, - 8.64002487103935000337e10, - 2.02979612750105546709e13, - 3.17157752842975028269e15, - 2.50596256172653059228e17, - ], - dtype=self.dtype, - ) - - yp = op_base.foreach_loop( - YP, lambda carry, yp_i: carry * (x * x) + yp_i - ) - yq = op_base.foreach_loop( - YQ, lambda carry, yq_i: carry * (x * x) + yq_i - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2599-L2712 + + def zero(x): + return jnp.array(-jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + YP = jnp.array( + [ + 1.55924367855235737965e04, + -1.46639295903971606143e07, 
+ 5.43526477051876500413e09, + -9.82136065717911466409e11, + 8.75906394395366999549e13, + -3.46628303384729719441e15, + 4.42733268572569800351e16, + -1.84950800436986690637e16, + ], + dtype=self.dtype, + ) + YQ = jnp.array( + [ + 1.04128353664259848412e03, + 6.26107330137134956842e05, + 2.68919633393814121987e08, + 8.64002487103935000337e10, + 2.02979612750105546709e13, + 3.17157752842975028269e15, + 2.50596256172653059228e17, + ], + dtype=self.dtype, + ) - return yp / yq + ( - 0.636619772367581343075535053490057448 - * jnp.log(x) - * _aten_special_bessel_j0(x) - ) + yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) + yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, - ) + return yp / yq + ( + 0.636619772367581343075535053490057448 + * jnp.log(x) + * _aten_special_bessel_j0(x) + ) - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) + def default(x): + PP = jnp.array( + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, + ) - return ( - ( - pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) - + 5.0 - / x - * (qp / qq) - * jnp.cos(x - 0.785398163397448309615660845819875721) - ) - * 0.797884560802865355879892119868763737 - / jnp.sqrt(x) - ) + factor 
= 25.0 / (x * x) + pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) + pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) + qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) + qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - return jnp.piecewise( - self, - [self <= 5.0, self < 0.0, self == 0.0], - [small, negative, zero, default], + return ( + ( + pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) ) + return jnp.piecewise( + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], + ) + @op(torch.ops.aten.special_bessel_y1) @op_base.promote_int_input def _aten_special_bessel_y1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2714-L2826 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.26320474790178026440e09, - -6.47355876379160291031e11, - 1.14509511541823727583e14, - -8.12770255501325109621e15, - 2.02439475713594898196e17, - -7.78877196265950026825e17, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 5.94301592346128195359e02, - 2.35564092943068577943e05, - 7.34811944459721705660e07, - 1.87601316108706159478e10, - 3.88231277496238566008e12, - 6.20557727146953693363e14, - 6.87141087355300489866e16, - 3.97270608116560655612e18, - ], - dtype=self.dtype, - ) - - yp = op_base.foreach_loop( - YP, lambda carry, yp_i: carry * (x * x) + yp_i - ) - yq = op_base.foreach_loop( - YQ, lambda carry, yq_i: carry * (x * x) + yq_i - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2714-L2826 + + def zero(x): + return jnp.array(-jnp.inf, x.dtype) + + def negative(x): + return jnp.array(jnp.nan, x.dtype) + + def small(x): + YP = jnp.array( + [ + 1.26320474790178026440e09, + -6.47355876379160291031e11, + 1.14509511541823727583e14, + -8.12770255501325109621e15, + 2.02439475713594898196e17, + -7.78877196265950026825e17, + ], + dtype=self.dtype, + ) + YQ = jnp.array( + [ + 5.94301592346128195359e02, + 2.35564092943068577943e05, + 7.34811944459721705660e07, + 1.87601316108706159478e10, + 3.88231277496238566008e12, + 6.20557727146953693363e14, + 6.87141087355300489866e16, + 3.97270608116560655612e18, + ], + dtype=self.dtype, + ) - return x * (yp / yq) + ( - 0.636619772367581343075535053490057448 - * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x) - ) + yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) + yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 
7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, - ) + return x * (yp / yq) + ( + 0.636619772367581343075535053490057448 + * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x) + ) - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) + def default(x): + PP = jnp.array( + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, + ) + PQ = jnp.array( + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 9.99999999999999997461e-01, + ], + dtype=self.dtype, + ) + QP = jnp.array( + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, + ) + QQ = jnp.array( + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], + dtype=self.dtype, + ) - return ( - ( - pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) - + 5.0 - / x - * (qp / qq) - * jnp.cos(x - 2.356194490192344928846982537459627163) - ) - * 0.797884560802865355879892119868763737 - / jnp.sqrt(x) - ) + factor = 25.0 / (x * x) + pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) + pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) + qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) + qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - return jnp.piecewise( - self, - [self <= 5.0, self < 0.0, self == 0.0], - [small, negative, zero, default], + return ( + ( + pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) ) + return jnp.piecewise( + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], + ) + @op(torch.ops.aten.special_chebyshev_polynomial_t) @op_base.promote_int_input def _aten_special_chebyshev_polynomial_t(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2828-L2865 - - @jnp.vectorize - def vectorized(x, n_i): - def negative_n(x): - return jnp.zeros_like(x) - - def one_x(x): - return jnp.where( - (x > 0) | (n_i % 2 == 0), jnp.ones_like(x), -jnp.ones_like(x) - ) - - def large_n_small_x(x): - return jnp.cos(n_i * jnp.acos(x)) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return x - - def default(x): 
- def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) - - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, x)) - return r - - return jnp.piecewise( - x, - [ - n_i == 1, - n_i == 0, - (n_i == 6) & (jnp.abs(x) < 1), - jnp.abs(x) == 1.0, - n_i < 0, - ], - [one_n, zero_n, large_n_small_x, one_x, negative_n, default], - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2828-L2865 - # Explcicitly vectorize since we must vectorizes over both self and n - return vectorized(self, n.astype(jnp.int64)) + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) + + def one_x(x): + return jnp.where( + (x > 0) | (n_i % 2 == 0), jnp.ones_like(x), -jnp.ones_like(x) + ) + + def large_n_small_x(x): + return jnp.cos(n_i * jnp.acos(x)) + + def zero_n(x): + return jnp.ones_like(x) + + def one_n(x): + return x + + def default(x): + def f(_, carry): + p, q = carry + return (q, 2 * x * q - p) + + _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, x)) + return r + + return jnp.piecewise( + x, + [ + n_i == 1, + n_i == 0, + (n_i == 6) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + ) + + # Explcicitly vectorize since we must vectorizes over both self and n + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_chebyshev_polynomial_u) @op_base.promote_int_input def _aten_special_chebyshev_polynomial_u(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2872-L2913 - - @jnp.vectorize - def vectorized(x, n_i): - def negative_n(x): - return jnp.zeros_like(x) - - def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), n_i + 1, -(n_i + 1)) - - def large_n_small_x(x): - sin_acos_x = jnp.sin(jnp.acos(x)) - return jnp.where( - sin_acos_x != 0, - jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, - (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, - ) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return 2 * x - - def default(x): - def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) - - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, 2 * x)) - return r - - return jnp.piecewise( - x, - [ - n_i == 1, - n_i == 0, - (n_i > 8) & (jnp.abs(x) < 1), - jnp.abs(x) == 1.0, - n_i < 0, - ], - [one_n, zero_n, large_n_small_x, one_x, negative_n, default], - ) + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2872-L2913 + + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) + + def one_x(x): + return jnp.where((x > 0) | (n_i % 2 == 0), n_i + 1, -(n_i + 1)) - return vectorized(self, n.astype(jnp.int64)) + def large_n_small_x(x): + sin_acos_x = jnp.sin(jnp.acos(x)) + return jnp.where( + sin_acos_x != 0, + jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, + (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, + ) + + def zero_n(x): + return jnp.ones_like(x) + + def one_n(x): + return 2 * x + + def default(x): + def f(_, carry): + p, q = carry + return (q, 2 * x * q - p) + + _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, 2 * x)) + return r + + return jnp.piecewise( + x, + [ + n_i == 1, + n_i == 0, + (n_i > 8) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + ) + + return vectorized(self, n.astype(jnp.int64)) 
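# Illustrative sketch (not from the patch): both Chebyshev ports above run the
# three-term recurrence T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x) (U likewise, but
# seeded with 2x) under jax.lax.fori_loop, with jnp.vectorize broadcasting the
# scalar kernel over both self and n. A stripped-down standalone T_n for n >= 0:
import jax
import jax.numpy as jnp

@jnp.vectorize
def chebyshev_t(x, n):
  def body(_, carry):
    p, q = carry
    return (q, 2.0 * x * q - p)

  _, r = jax.lax.fori_loop(0, n - 1, body, (jnp.ones_like(x), x))
  return jnp.where(n == 0, jnp.ones_like(x), r)

vals = chebyshev_t(jnp.array([0.5, -0.25]), 3)
# T_3(x) = 4x^3 - 3x  ->  [-1.0, 0.6875]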
@op(torch.ops.aten.special_erfcx) @op_base.promote_int_input def _aten_special_erfcx(x): - return jnp.exp(x * x) * jax.lax.erfc(x) + return jnp.exp(x * x) * jax.lax.erfc(x) @op(torch.ops.aten.erfc) @op_base.promote_int_input def _aten_erfcx(x): - return jax.lax.erfc(x) + return jax.lax.erfc(x) @op(torch.ops.aten.special_hermite_polynomial_h) @op_base.promote_int_input def _aten_special_hermite_polynomial_h(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3036-L3061 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3036-L3061 - @jnp.vectorize - def vectorized(x, n_i): - def negative_n(x): - return jnp.zeros_like(x) + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def zero_n(x): - return jnp.ones_like(x) + def zero_n(x): + return jnp.ones_like(x) - def one_n(x): - return 2 * x + def one_n(x): + return 2 * x - def default(x): - def f(k, carry): - p, q = carry - return (q, 2 * x * q - 2 * k * p) + def default(x): + def f(k, carry): + p, q = carry + return (q, 2 * x * q - 2 * k * p) - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) - return r + _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) + return r - return jnp.piecewise( - x, - [n_i == 1, n_i == 0, n_i < 0], - [one_n, zero_n, negative_n, default], - ) + return jnp.piecewise( + x, + [n_i == 1, n_i == 0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.special_hermite_polynomial_he) @op_base.promote_int_input def _aten_special_hermite_polynomial_he(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3073-L3098 + # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3073-L3098 - @jnp.vectorize - def vectorized(x, n_i): - def negative_n(x): - return jnp.zeros_like(x) + @jnp.vectorize + def vectorized(x, n_i): + def negative_n(x): + return jnp.zeros_like(x) - def zero_n(x): - return jnp.ones_like(x) + def zero_n(x): + return jnp.ones_like(x) - def one_n(x): - return x + def one_n(x): + return x - def default(x): - def f(k, carry): - p, q = carry - return (q, x * q - k * p) + def default(x): + def f(k, carry): + p, q = carry + return (q, x * q - k * p) - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) - return r + _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) + return r - return jnp.piecewise( - x, - [n_i == 1.0, n_i == 0.0, n_i < 0], - [one_n, zero_n, negative_n, default], - ) + return jnp.piecewise( + x, + [n_i == 1.0, n_i == 0.0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) - return vectorized(self, n.astype(jnp.int64)) + return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.multinomial, needs_env=True) def _aten_multinomial( - input, num_samples, replacement=False, *, generator=None, out=None, env=None + input, num_samples, replacement=False, *, generator=None, out=None, env=None ): - assert num_samples <= input.shape[-1] or replacement, ( - "cannot take a larger sample than population when replacement=False" + assert num_samples <= input.shape[-1] or replacement, ( + "cannot take a larger sample than population when replacement=False" + ) + key = env.get_and_rotate_prng_key(generator) + if input.ndim == 1: + 
return jax.random.choice( + key, input.shape[-1], (num_samples,), replace=replacement, p=input ) - key = env.get_and_rotate_prng_key(generator) - if input.ndim == 1: - return jax.random.choice( - key, input.shape[-1], (num_samples,), replace=replacement, p=input - ) - else: - return jnp.array([ - jax.random.choice( - key, - input.shape[-1], - (num_samples,), - replace=replacement, - p=input[i, :], - ) - for i in range(input.shape[0]) - ]) + else: + return jnp.array([ + jax.random.choice( + key, + input.shape[-1], + (num_samples,), + replace=replacement, + p=input[i, :], + ) + for i in range(input.shape[0]) + ]) @op(torch.ops.aten.narrow) @op(torch.ops.aten.narrow_copy) def _aten_narrow(input, dim, start, length): - return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim) + return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim) @op(torch.ops.aten.flatten) def _aten_flatten(x, start_dim=0, end_dim=-1): - """ - Flattens a JAX array (similar to torch.flatten). + """ + Flattens a JAX array (similar to torch.flatten). - Args: - x: The JAX array to be flattened. - start_dim: The first dimension to include in the flattening. - end_dim: The last dimension to include in the flattening. + Args: + x: The JAX array to be flattened. + start_dim: The first dimension to include in the flattening. + end_dim: The last dimension to include in the flattening. - Returns: - A flattened JAX array. - """ - shape = x.shape + Returns: + A flattened JAX array. + """ + shape = x.shape - if end_dim < 0: - end_dim += len(shape) # Handle negative indexing + if end_dim < 0: + end_dim += len(shape) # Handle negative indexing - new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1 :]) - return jnp.reshape(x, new_shape) + new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1 :]) + return jnp.reshape(x, new_shape) @op(torch.ops.aten.new_empty) def _new_empty(self, size, **kwargs): - dtype = kwargs.get("dtype") - if dtype is not None: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = self.dtype - return jnp.empty(size, dtype=dtype) + dtype = kwargs.get("dtype") + if dtype is not None: + dtype = mappings.t2j_dtype(dtype) + else: + dtype = self.dtype + return jnp.empty(size, dtype=dtype) @op(torch.ops.aten.new_empty_strided) def _new_empty_strided(self, size, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensor doesn't share the same memory. - if not dtype: - return jnp.empty(size, dtype=self.dtype) - else: - jax_dtype = mappings.t2j_dtype(dtype) - return jnp.empty(size, dtype=jax_dtype) + # Ignore stride, since JAX and torch tensor doesn't share the same memory. 
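# Standalone usage sketch for the multinomial port above: jax.random.choice
# draws category indices according to the weight vector passed as p, which is
# what the 1-D branch relies on.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weights = jnp.array([0.1, 0.3, 0.6])
idx = jax.random.choice(key, weights.shape[-1], (5,), replace=True, p=weights)
# idx: 5 draws from {0, 1, 2}, index 2 being the most likely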
+ if not dtype: + return jnp.empty(size, dtype=self.dtype) + else: + jax_dtype = mappings.t2j_dtype(dtype) + return jnp.empty(size, dtype=jax_dtype) @op(torch.ops.aten._unsafe_index_put) def _aten_unsafe_index_put(self, indices, values, accumulate=False): - return _aten_index_put(self, indices, values, accumulate) + return _aten_index_put(self, indices, values, accumulate) @op( - torch.ops.aten.conj_physical, - torch.ops.aten.conj, - torch.ops.aten._conj_physical, - torch.ops.aten._conj, + torch.ops.aten.conj_physical, + torch.ops.aten.conj, + torch.ops.aten._conj_physical, + torch.ops.aten._conj, ) def _aten_conj_physical(self): - return jnp.conjugate(self) + return jnp.conjugate(self) @op(torch.ops.aten.log_sigmoid) def _aten_log_sigmoid(x): - return jax.nn.log_sigmoid(x) + return jax.nn.log_sigmoid(x) # torch.qr @op(torch.ops.aten.qr) def _aten_qr(input, *args, **kwargs): - jax_mode = "reduced" - # torch bool param 'simple=True' corresponds to jax 'reduced' mode, - # and simple=False corresponds to jax 'complete' mode. - if kwargs.get("simple") is False: - jax_mode = "complete" - return jax.numpy.linalg.qr(input, mode=jax_mode) + jax_mode = "reduced" + # torch bool param 'simple=True' corresponds to jax 'reduced' mode, + # and simple=False corresponds to jax 'complete' mode. + if kwargs.get("simple") is False: + jax_mode = "complete" + return jax.numpy.linalg.qr(input, mode=jax_mode) # torch.linalg.qr @op(torch.ops.aten.linalg_qr) def _aten_linalg_qr(input, *args, **kwargs): - mode = kwargs.get("mode", "reduced") - return jax.numpy.linalg.qr(input, mode=mode) + mode = kwargs.get("mode", "reduced") + return jax.numpy.linalg.qr(input, mode=mode) # torch.linalg.matrix_exp @op(torch.ops.aten.linalg_matrix_exp) def _aten_linalg_matrix_exp(input): - return jax.scipy.linalg.expm(input) + return jax.scipy.linalg.expm(input) # torch._linalg.slogdet @op(torch.ops.aten._linalg_slogdet) def _aten__linalg_slogdet(input): - res = jnp.linalg.slogdet(input) - return res.sign, res.logabsdet + res = jnp.linalg.slogdet(input) + return res.sign, res.logabsdet # torch.linalg.svd @op(torch.ops.aten._linalg_svd) def _aten__linalg_svd(a, full_matrices=False, **kwargs): - return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) + return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) # torch.linalg.pinv @op(torch.ops.aten.linalg_pinv.atol_rtol_tensor) def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs): - return jnp.linalg.pinv(a, rtol, hermitian=False) + return jnp.linalg.pinv(a, rtol, hermitian=False) # torch.linalg.solve @op(torch.ops.aten._linalg_solve_ex) def _aten__linalg_solve_ex(a, b): - batched = False - if b.ndim > 1 and b.shape[-1] == a.shape[-1]: - batched = True - b = b[..., None] - res = jnp.linalg.solve(a, b) - if batched: - res = res.squeeze(-1) - info_shape = a.shape[0] if len(a.shape) >= 3 else [] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - return res, info + batched = False + if b.ndim > 1 and b.shape[-1] == a.shape[-1]: + batched = True + b = b[..., None] + res = jnp.linalg.solve(a, b) + if batched: + res = res.squeeze(-1) + info_shape = a.shape[0] if len(a.shape) >= 3 else [] + info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) + return res, info # torch.linalg.solve_triangular @op(torch.ops.aten.linalg_solve_triangular) def _aten_linalg_solve_triangular( - a, b, *, upper=True, left=True, unitriangular=False + a, b, *, upper=True, left=True, unitriangular=False ): - if left is False: - a = jnp.matrix_transpose(a) - b = 
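# Illustrative sketch: the qr ports above only differ in how the mode string
# is chosen; jnp.linalg.qr follows numpy's conventions for the result shapes.
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(4, 3)
q_r, r_r = jnp.linalg.qr(a, mode="reduced")    # q_r: (4, 3), r_r: (3, 3)
q_c, r_c = jnp.linalg.qr(a, mode="complete")   # q_c: (4, 4), r_c: (4, 3)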
jnp.matrix_transpose(b) - upper = not upper - res = jax.scipy.linalg.solve_triangular( - a, b, lower=not upper, unit_diagonal=unitriangular - ) - if left is False: - res = jnp.matrix_transpose(res) - return res + if left is False: + a = jnp.matrix_transpose(a) + b = jnp.matrix_transpose(b) + upper = not upper + res = jax.scipy.linalg.solve_triangular( + a, b, lower=not upper, unit_diagonal=unitriangular + ) + if left is False: + res = jnp.matrix_transpose(res) + return res @op(torch.ops.aten.linalg_inv_ex) def _aten_linalg_inv_ex(a): - ainv = jnp.linalg.inv(a) - info = jnp.zeros(a.shape[:-2], jnp.int32) - return ainv, info + ainv = jnp.linalg.inv(a) + info = jnp.zeros(a.shape[:-2], jnp.int32) + return ainv, info @op(torch.ops.aten._linalg_check_errors) def _aten__linalg_check_errors(*args, **kwargs): - pass + pass @op(torch.ops.aten.median) def _aten_median(self, dim=None, keepdim=False): - output = _with_reduction_scalar( - functools.partial(jnp.quantile, q=0.5, method="lower"), - self, - dim=dim, - keepdim=keepdim, - ).astype(self.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar( - _get_median_index, self, dim, keepdim - ).astype(jnp.int64) - return output, index + output = _with_reduction_scalar( + functools.partial(jnp.quantile, q=0.5, method="lower"), + self, + dim=dim, + keepdim=keepdim, + ).astype(self.dtype) + if dim is None: + return output + else: + index = _with_reduction_scalar( + _get_median_index, self, dim, keepdim + ).astype(jnp.int64) + return output, index @op(torch.ops.aten.nanmedian) def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None): - output = _with_reduction_scalar( - functools.partial(jnp.nanquantile, q=0.5, method="lower"), - input, - dim=dim, - keepdim=keepdim, - ).astype(input.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar( - _get_median_index, input, dim, keepdim - ).astype(jnp.int64) - return output, index + output = _with_reduction_scalar( + functools.partial(jnp.nanquantile, q=0.5, method="lower"), + input, + dim=dim, + keepdim=keepdim, + ).astype(input.dtype) + if dim is None: + return output + else: + index = _with_reduction_scalar( + _get_median_index, input, dim, keepdim + ).astype(jnp.int64) + return output, index def _get_median_index(x, axis=None, keepdims=False): - sorted_arg = jnp.argsort(x, axis=axis) - n = x.shape[axis] if axis is not None else x.size - if n % 2 == 1: - index = n // 2 - else: - index = (n // 2) - 1 - if axis is None: - median_index = sorted_arg[index] - else: - median_index = jnp.take(sorted_arg, index, axis=axis) - if keepdims and axis is not None: - median_index = jnp.expand_dims(median_index, axis) - return median_index + sorted_arg = jnp.argsort(x, axis=axis) + n = x.shape[axis] if axis is not None else x.size + if n % 2 == 1: + index = n // 2 + else: + index = (n // 2) - 1 + if axis is None: + median_index = sorted_arg[index] + else: + median_index = jnp.take(sorted_arg, index, axis=axis) + if keepdims and axis is not None: + median_index = jnp.expand_dims(median_index, axis) + return median_index @op(torch.ops.aten.triangular_solve) def _aten_triangular_solve( - b, a, upper=True, transpose=False, unittriangular=False + b, a, upper=True, transpose=False, unittriangular=False ): - return ( - jax.lax.linalg.triangular_solve( - a, - b, - left_side=True, - lower=not upper, - transpose_a=transpose, - unit_diagonal=unittriangular, - ), - a, - ) + return ( + jax.lax.linalg.triangular_solve( + a, + b, + left_side=True, + lower=not upper, + 
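# Standalone check of the transposition trick used by the
# linalg_solve_triangular port above for left=False: X @ A = B is solved as
# A^T Y = B^T with Y = X^T, and `upper` is flipped because transposing swaps
# the triangles.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])                  # upper triangular
B = jnp.array([[4.0, 7.0]])
Y = solve_triangular(A.T, B.T, lower=True)   # A^T is lower triangular
X = Y.T
# X @ A ~= B  ->  X == [[2.0, 5/3]]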
transpose_a=transpose, + unit_diagonal=unittriangular, + ), + a, + ) # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor @op(torch.ops.aten._fft_c2c) def _aten__fft_c2c(self, dim, normalization, forward): - if forward: - norm = [ - "backward", - "ortho", - "forward", - ][normalization] - return jnp.fft.fftn(self, axes=dim, norm=norm) - else: - norm = [ - "forward", - "ortho", - "backward", - ][normalization] - return jnp.fft.ifftn(self, axes=dim, norm=norm) + if forward: + norm = [ + "backward", + "ortho", + "forward", + ][normalization] + return jnp.fft.fftn(self, axes=dim, norm=norm) + else: + norm = [ + "forward", + "ortho", + "backward", + ][normalization] + return jnp.fft.ifftn(self, axes=dim, norm=norm) @op(torch.ops.aten._fft_r2c) def _aten__fft_r2c(self, dim, normalization, onesided): - norm = [ - "backward", - "ortho", - "forward", - ][normalization] - if onesided: - return jnp.fft.rfftn(self, axes=dim, norm=norm) - else: - return jnp.fft.fftn(self, axes=dim, norm=norm) + norm = [ + "backward", + "ortho", + "forward", + ][normalization] + if onesided: + return jnp.fft.rfftn(self, axes=dim, norm=norm) + else: + return jnp.fft.fftn(self, axes=dim, norm=norm) @op(torch.ops.aten._fft_c2r) def _aten__fft_c2r(self, dim, normalization, last_dim_size): - norm = [ - "forward", - "ortho", - "backward", - ][normalization] - if len(dim) == 1: - s = [last_dim_size] - else: - s = None - return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) + norm = [ + "forward", + "ortho", + "backward", + ][normalization] + if len(dim) == 1: + s = [last_dim_size] + else: + s = None + return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) @op(torch.ops.aten._trilinear) def _aten_trilinear( - i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1 + i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1 ): - return _aten_sum( - jnp.expand_dims(i1, expand1) - * jnp.expand_dims(i2, expand2) - * jnp.expand_dims(i3, expand3), - sumdim, - ) + return _aten_sum( + jnp.expand_dims(i1, expand1) + * jnp.expand_dims(i2, expand2) + * jnp.expand_dims(i3, expand3), + sumdim, + ) @op(torch.ops.aten.max_unpool2d) @op(torch.ops.aten.max_unpool3d) def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): - if output_size is None: - raise ValueError( - "output_size value is not set correctly. It cannot be None or empty." - ) + if output_size is None: + raise ValueError( + "output_size value is not set correctly. It cannot be None or empty." 
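# Illustrative sketch: the _fft_* ports above translate torch's integer
# `normalization` flag, which appears to encode the scaling applied to this
# particular transform (0: none, 1: 1/sqrt(n), 2: 1/n); that is why the
# inverse path reads its norm string from a reversed table. Round trip:
import jax.numpy as jnp

x = jnp.arange(8.0) + 0.0j
freq = jnp.fft.fftn(x, axes=[0], norm="backward")                # forward, code 0
back = jnp.fft.ifftn(freq, axes=[0],
                     norm=["forward", "ortho", "backward"][2])   # inverse, code 2
# back ~= x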
+ ) - output_size = [input.shape[0], input.shape[1]] + output_size - output = jnp.zeros(output_size, dtype=input.dtype) + output_size = [input.shape[0], input.shape[1]] + output_size + output = jnp.zeros(output_size, dtype=input.dtype) - for idx in np.ndindex(input.shape): - max_index = indices[idx] - spatial_dims = output_size[2:] # (D, H, W) - unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) + for idx in np.ndindex(input.shape): + max_index = indices[idx] + spatial_dims = output_size[2:] # (D, H, W) + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + full_idx = idx[:2] + unpooled_spatial_idx + output = output.at[full_idx].set(input[idx]) - return output + return output def _aten_upsample( - input, - output_size, - align_corners, - antialias, - method, - scale_factors=None, - scales_h=None, - scales_w=None, + input, + output_size, + align_corners, + antialias, + method, + scale_factors=None, + scales_h=None, + scales_w=None, ): - # input: is of type jaxlib.xla_extension.ArrayImpl - image = input - - # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html - # Resize does not distinguish batch, channel size. - # We need to leave them as is - # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions - # pytorch image shape is (C,H,W) or (N,C,H,W) - # N - batch size - # C - no of channels - # H,W - heigth, width - - shape = list(image.shape) - # overriding output_size - if scale_factors: - shape[-1] = int(math.floor(shape[-1] * scale_factors[-1])) - shape[-2] = int(math.floor(shape[-2] * scale_factors[-2])) - if scales_h: - shape[-2] = int(math.floor(shape[-2] * scales_h)) - if scales_w: - shape[-1] = int(math.floor(shape[-1] * scales_w)) - # output_size overrides scale_factors, scales_* - if output_size: - shape[-1] = output_size[-1] - shape[-2] = output_size[-2] - - # pytorch upsample_bilinear returns the input as is when the shape is the same as input - if shape == list(image.shape): - return image - - spatial_dims = (2, 3) - if len(shape) == 3: - spatial_dims = (1, 2) - - scale = list([shape[i] / image.shape[i] for i in spatial_dims]) - if scale_factors: - scale = scale_factors - if scales_h: - scale[0] = scales_h - if scales_w: - scale[1] = scales_w - scale = jnp.array(scale) - - # align_corners is not supported in resize() - # https://github.com/jax-ml/jax/issues/11206 - if align_corners: - scale = jnp.array([ - (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims - ]) - - translation = jnp.array([0 for i in spatial_dims]) - - return jax_reimplement.scale_and_translate( - image, - shape, - method=method, - scale=scale, - spatial_dims=spatial_dims, - translation=translation, - antialias=antialias, - ) + # input: is of type jaxlib.xla_extension.ArrayImpl + image = input + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html + # Resize does not distinguish batch, channel size. 
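# Illustrative sketch of the index bookkeeping in the max_unpool port above:
# each entry of `indices` is a flat offset into the spatial plane, and
# unravel_index recovers the coordinates to scatter the value back into.
import numpy as np
import jax.numpy as jnp

output = jnp.zeros((1, 1, 4, 4))            # (N, C, H, W)
value, flat_idx = 7.0, 9
h, w = np.unravel_index(flat_idx, (4, 4))   # 9 -> (2, 1)
output = output.at[0, 0, h, w].set(value)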
+ # We need to leave them as is + # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions + # pytorch image shape is (C,H,W) or (N,C,H,W) + # N - batch size + # C - no of channels + # H,W - heigth, width + + shape = list(image.shape) + # overriding output_size + if scale_factors: + shape[-1] = int(math.floor(shape[-1] * scale_factors[-1])) + shape[-2] = int(math.floor(shape[-2] * scale_factors[-2])) + if scales_h: + shape[-2] = int(math.floor(shape[-2] * scales_h)) + if scales_w: + shape[-1] = int(math.floor(shape[-1] * scales_w)) + # output_size overrides scale_factors, scales_* + if output_size: + shape[-1] = output_size[-1] + shape[-2] = output_size[-2] + + # pytorch upsample_bilinear returns the input as is when the shape is the same as input + if shape == list(image.shape): + return image + + spatial_dims = (2, 3) + if len(shape) == 3: + spatial_dims = (1, 2) + + scale = list([shape[i] / image.shape[i] for i in spatial_dims]) + if scale_factors: + scale = scale_factors + if scales_h: + scale[0] = scales_h + if scales_w: + scale[1] = scales_w + scale = jnp.array(scale) + + # align_corners is not supported in resize() + # https://github.com/jax-ml/jax/issues/11206 + if align_corners: + scale = jnp.array([ + (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims + ]) + + translation = jnp.array([0 for i in spatial_dims]) + + return jax_reimplement.scale_and_translate( + image, + shape, + method=method, + scale=scale, + spatial_dims=spatial_dims, + translation=translation, + antialias=antialias, + ) @op(torch.ops.aten._upsample_bilinear2d_aa) def _aten_upsample_billinear_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): + return _aten_upsample( input, output_size, align_corners, - scale_factors=None, - scales_h=None, - scales_w=None, -): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bilinear", # method - scale_factors, - scales_h, - scales_w, - ) + True, # antialias + "bilinear", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten._upsample_bicubic2d_aa) def _aten_upsample_bicubic2d_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): + return _aten_upsample( input, output_size, align_corners, - scale_factors=None, - scales_h=None, - scales_w=None, -): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bicubic", # method - scale_factors, - scales_h, - scales_w, - ) + True, # antialias + "bicubic", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten.polar) def _aten_polar(abs, angle, *, out=None): - return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) + return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) @op(torch.ops.aten.cdist) def _aten_cdist( - x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary" + x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary" ): - x1 = x1.astype(jnp.float32) - x2 = x2.astype(jnp.float32) - - if p == 0.0: - # For p = 0, use Hamming-like distance multiplied by the number of elements - return _hamming_distance(x1, x2).astype(jnp.float32) - elif p == 2.0: - # Use optimized Euclidean distance calculation - if compute_mode == "use_mm_for_euclid_dist_if_necessary" and ( - x1.shape[-2] > 25 or x2.shape[-2] > 25 - ): - return _euclidean_mm(x1, x2) - elif compute_mode == "use_mm_for_euclid_dist": - return _euclidean_mm(x1, x2) - else: - return 
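# Standalone sketch of the shape and scale arithmetic in the upsample helper
# above (the resize itself goes through the local jax_reimplement wrapper):
import math
import jax.numpy as jnp

image = jnp.ones((1, 3, 10, 16))                             # (N, C, H, W)
scale_factors = (2.0, 1.5)
shape = list(image.shape)
shape[-2] = int(math.floor(shape[-2] * scale_factors[-2]))   # H -> 20
shape[-1] = int(math.floor(shape[-1] * scale_factors[-1]))   # W -> 24
spatial_dims = (2, 3)
align_corners = True
scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) if align_corners
                   else shape[i] / image.shape[i] for i in spatial_dims])
# scale ~= [19/9, 23/15]; without align_corners it would be [2.0, 1.5]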
_euclidean_direct(x1, x2) + x1 = x1.astype(jnp.float32) + x2 = x2.astype(jnp.float32) + + if p == 0.0: + # For p = 0, use Hamming-like distance multiplied by the number of elements + return _hamming_distance(x1, x2).astype(jnp.float32) + elif p == 2.0: + # Use optimized Euclidean distance calculation + if compute_mode == "use_mm_for_euclid_dist_if_necessary" and ( + x1.shape[-2] > 25 or x2.shape[-2] > 25 + ): + return _euclidean_mm(x1, x2) + elif compute_mode == "use_mm_for_euclid_dist": + return _euclidean_mm(x1, x2) else: - # General p-norm distance calculation - diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) - return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** ( - 1 / p - ) + return _euclidean_direct(x1, x2) + else: + # General p-norm distance calculation + diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) + return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** (1 / p) def _hamming_distance(x1, x2): - """ - Computes the Hamming-like distance for p=0. + """ + Computes the Hamming-like distance for p=0. - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Hamming distances. - """ - diff = jnp.not_equal(jnp.expand_dims(x1, -2), jnp.expand_dims(x2, -3)) + Returns: + JAX array of shape (..., P, R) representing pairwise Hamming distances. + """ + diff = jnp.not_equal(jnp.expand_dims(x1, -2), jnp.expand_dims(x2, -3)) - hamming_dist = jnp.sum(diff, axis=-1).astype(jnp.float32) + hamming_dist = jnp.sum(diff, axis=-1).astype(jnp.float32) - return hamming_dist + return hamming_dist def _euclidean_mm(x1, x2): - """ - Computes the Euclidean distance using matrix multiplication. + """ + Computes the Euclidean distance using matrix multiplication. - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. - """ - x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32) - x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32) + Returns: + JAX array of shape (..., P, R) representing pairwise Euclidean distances. + """ + x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32) + x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32) - x2_sq = jnp.swapaxes(x2_sq, -2, -1) + x2_sq = jnp.swapaxes(x2_sq, -2, -1) - dot_product = jnp.matmul(x1, jnp.swapaxes(x2, -1, -2)) + dot_product = jnp.matmul(x1, jnp.swapaxes(x2, -1, -2)) - dist_sq = x1_sq + x2_sq - 2 * dot_product - dist_sq = jnp.maximum(dist_sq, 0.0) - dist = jnp.sqrt(dist_sq).astype(jnp.float32) + dist_sq = x1_sq + x2_sq - 2 * dot_product + dist_sq = jnp.maximum(dist_sq, 0.0) + dist = jnp.sqrt(dist_sq).astype(jnp.float32) - return dist + return dist def _euclidean_direct(x1, x2): - """ - Computes the Euclidean distance directly without matrix multiplication. + """ + Computes the Euclidean distance directly without matrix multiplication. - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) + Args: + x1: JAX array of shape (..., P, M) + x2: JAX array of shape (..., R, M) - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. 
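# Standalone check of the matmul-based Euclidean path above: it expands
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b and clamps at zero before the square
# root to guard against small negative values caused by cancellation.
import jax.numpy as jnp

x1 = jnp.array([[0.0, 1.0], [2.0, 2.0]])               # (P, M)
x2 = jnp.array([[1.0, 1.0], [3.0, 0.0], [0.0, 0.0]])   # (R, M)
sq = (jnp.sum(x1**2, -1, keepdims=True)
      + jnp.sum(x2**2, -1)[None, :]
      - 2.0 * x1 @ x2.T)
d_mm = jnp.sqrt(jnp.maximum(sq, 0.0))
d_direct = jnp.sqrt(jnp.sum((x1[:, None, :] - x2[None, :, :])**2, -1))
# d_mm ~= d_direct, shape (P, R) = (2, 3)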
- """ - diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3) + Returns: + JAX array of shape (..., P, R) representing pairwise Euclidean distances. + """ + diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3) - dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32) + dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32) - dist_sq = jnp.maximum(dist_sq, 0.0) + dist_sq = jnp.maximum(dist_sq, 0.0) - dist = jnp.sqrt(dist_sq).astype(jnp.float32) + dist = jnp.sqrt(dist_sq).astype(jnp.float32) - return dist + return dist @op(torch.ops.aten.lu_unpack) def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): - # lu_unpack doesnt exist in jax. - # Get commonly used data shape variables - n = LU_data.shape[-2] - m = LU_data.shape[-1] - dim = min(n, m) - - ### Compute the Lower and Upper triangle - if unpack_data: - # Extract lower triangle - L = jnp.tril(LU_data, k=-1) - - # emulate pytorch behavior: Add ones to the diagonal of L - eye = jnp.eye(n, m, dtype=LU_data.dtype) - L = L + eye - - # emulate pytorch behavior: Reshape lower triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-1] = dim - L = jax.lax.slice(L, start_indices, limit_indices) - - # Extract upper triangle - U = jnp.triu(LU_data) - - # emulate pytorch behavior: Reshape upper triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-2] = dim - U = jax.lax.slice(U, start_indices, limit_indices) - else: - # emulate pytroch behavior: return empty tensors - L = torch.empty(torch.Size([0])) - U = torch.empty(torch.Size([0])) - - ### Compute the Permutation matrix - if unpack_pivots: - # We should return a permutation matrix (2D) for each pivot array (1D) - # The shape of the final Permutation matrix depends on the shape of the input - # data and the pivots - - # start with a 2D identity matrix and tile it to the other dims of input data - identity2d = jnp.identity(n, dtype=jnp.float32) - tile_shape = list(LU_data.shape) - tile_shape[-1] = 1 - tile_shape[-2] = 1 - P = jnp.tile(identity2d, tile_shape) - - # closure to be called for each input 2D matrix. - def _lu_unpack_2d(p, pivot): - _pivot = pivot - 1 # pivots are offset by 1 in jax - indices = jnp.array([*range(n)], dtype=jnp.int32) - - def update_indices(i, _indices): - tmp = _indices[i] - _indices = _indices.at[i].set(_indices[_pivot[i]]) - _indices = _indices.at[_pivot[i]].set(tmp) - return _indices - - indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) - p = p[jnp.array(indices)] - p = jnp.transpose(p) - return p - - if len(LU_pivots.shape) == 1: - # if we are dealing with a simple 2D input and 1D pivot, call the closure directly - P = _lu_unpack_2d(P, LU_pivots) - else: - # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the - # closure for each 2D matrix. Finally unflatten the result to match the input data - # shape. 
- - # reshape permutation matrix to 3d - dim_size = jnp.prod(jnp.array(P.shape[:-2])) - newPshape = (dim_size, P.shape[-2], P.shape[-1]) - reshapedP = P.reshape(newPshape) - - # reshape pivots to 3d - dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) - newPivotshape = (dim_size, LU_pivots.shape[-1]) - reshapedPivot = LU_pivots.reshape(newPivotshape) - - # vmap the reshaped 3d tensors - v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0)) - unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) - - # reshape result back to P's shape - newRetshape = ( - *P.shape[:-2], - unpackedP.shape[-2], - unpackedP.shape[-1], - ) - P = unpackedP.reshape(newRetshape) + # lu_unpack doesnt exist in jax. + # Get commonly used data shape variables + n = LU_data.shape[-2] + m = LU_data.shape[-1] + dim = min(n, m) + + ### Compute the Lower and Upper triangle + if unpack_data: + # Extract lower triangle + L = jnp.tril(LU_data, k=-1) + + # emulate pytorch behavior: Add ones to the diagonal of L + eye = jnp.eye(n, m, dtype=LU_data.dtype) + L = L + eye + + # emulate pytorch behavior: Reshape lower triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-1] = dim + L = jax.lax.slice(L, start_indices, limit_indices) + + # Extract upper triangle + U = jnp.triu(LU_data) + + # emulate pytorch behavior: Reshape upper triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-2] = dim + U = jax.lax.slice(U, start_indices, limit_indices) + else: + # emulate pytroch behavior: return empty tensors + L = torch.empty(torch.Size([0])) + U = torch.empty(torch.Size([0])) + + ### Compute the Permutation matrix + if unpack_pivots: + # We should return a permutation matrix (2D) for each pivot array (1D) + # The shape of the final Permutation matrix depends on the shape of the input + # data and the pivots + + # start with a 2D identity matrix and tile it to the other dims of input data + identity2d = jnp.identity(n, dtype=jnp.float32) + tile_shape = list(LU_data.shape) + tile_shape[-1] = 1 + tile_shape[-2] = 1 + P = jnp.tile(identity2d, tile_shape) + + # closure to be called for each input 2D matrix. + def _lu_unpack_2d(p, pivot): + _pivot = pivot - 1 # pivots are offset by 1 in jax + indices = jnp.array([*range(n)], dtype=jnp.int32) + + def update_indices(i, _indices): + tmp = _indices[i] + _indices = _indices.at[i].set(_indices[_pivot[i]]) + _indices = _indices.at[_pivot[i]].set(tmp) + return _indices + + indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) + p = p[jnp.array(indices)] + p = jnp.transpose(p) + return p + + if len(LU_pivots.shape) == 1: + # if we are dealing with a simple 2D input and 1D pivot, call the closure directly + P = _lu_unpack_2d(P, LU_pivots) else: - # emulate pytroch behavior: return empty tensors - P = torch.empty(torch.Size([0])) - - return P, L, U + # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the + # closure for each 2D matrix. Finally unflatten the result to match the input data + # shape. 
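# Illustrative-only sketch of the "flatten leading dims, vmap a 2-D routine,
# then unflatten" pattern the comment above describes; `per_matrix` is a toy
# stand-in for `_lu_unpack_2d`, not code from this patch.
import jax
import jax.numpy as jnp

def per_matrix(m):                                   # operates on one 2-D matrix
  return m.T

def batched_apply(x):                                # x: (..., R, C)
  flat = x.reshape((-1,) + x.shape[-2:])             # (B, R, C)
  out = jax.vmap(per_matrix)(flat)                   # (B, C, R)
  return out.reshape(x.shape[:-2] + out.shape[-2:])  # (..., C, R)

print(batched_apply(jnp.ones((2, 3, 4, 5))).shape)   # (2, 3, 5, 4)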
+ + # reshape permutation matrix to 3d + dim_size = jnp.prod(jnp.array(P.shape[:-2])) + newPshape = (dim_size, P.shape[-2], P.shape[-1]) + reshapedP = P.reshape(newPshape) + + # reshape pivots to 3d + dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) + newPivotshape = (dim_size, LU_pivots.shape[-1]) + reshapedPivot = LU_pivots.reshape(newPivotshape) + + # vmap the reshaped 3d tensors + v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0)) + unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) + + # reshape result back to P's shape + newRetshape = ( + *P.shape[:-2], + unpackedP.shape[-2], + unpackedP.shape[-1], + ) + P = unpackedP.reshape(newRetshape) + else: + # emulate pytroch behavior: return empty tensors + P = torch.empty(torch.Size([0])) + + return P, L, U @op(torch.ops.aten.linear) def linear(input, weight, bias=None): - res = input @ jnp.transpose(weight) - if bias is not None: - res += bias - return res + res = input @ jnp.transpose(weight) + if bias is not None: + res += bias + return res @op(torch.ops.aten.kthvalue) def kthvalue(input, k, dim=None, keepdim=False, *, out=None): - if input.ndim == 0: - return input, jnp.array(0) - dimension = -1 - if dim is not None: - dimension = dim - while dimension < 0: - dimension = dimension + input.ndim - values = jax.lax.index_in_dim( - jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim - ) - indices = jax.lax.index_in_dim( - jnp.argpartition(input, k - 1, dimension).astype("int64"), - k - 1, - dimension, - keepdim, - ) - return values, indices + if input.ndim == 0: + return input, jnp.array(0) + dimension = -1 + if dim is not None: + dimension = dim + while dimension < 0: + dimension = dimension + input.ndim + values = jax.lax.index_in_dim( + jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim + ) + indices = jax.lax.index_in_dim( + jnp.argpartition(input, k - 1, dimension).astype("int64"), + k - 1, + dimension, + keepdim, + ) + return values, indices @op(torch.ops.aten.take) def _aten_take(self, index): - return self.flatten()[index] + return self.flatten()[index] # func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor @op(torch.ops.aten.pad) def _aten_pad(self, pad, mode="constant", value=None): - if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: - raise ValueError("Padding must be a sequence of even length.") + if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: + raise ValueError("Padding must be a sequence of even length.") - num_dims = self.ndim - if len(pad) > 2 * num_dims: - raise ValueError( - f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." - ) + num_dims = self.ndim + if len(pad) > 2 * num_dims: + raise ValueError( + f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." + ) - # JAX's pad function expects padding for each dimension as a tuple of (low, high) - # We need to reverse the pad sequence and group them for JAX. - # pad = [p_l0, p_r0, p_l1, p_r1, ...] - # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0))) - jax_pad_width = [] - # Iterate in reverse pairs - for i in range(len(pad) // 2): - jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)])) - - # Pad any leading dimensions with (0, 0) if the pad sequence is shorter - # than the number of dimensions. 
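# Standalone sketch (helper name made up, not part of the patch) of how a
# torch-style flat pad list [l_last, r_last, l_prev, r_prev, ...] maps to the
# per-dimension (low, high) tuples that jnp.pad expects, mirroring the logic
# of _aten_pad.
import jax.numpy as jnp

def torch_pad_to_jax_width(pad, ndim):
  width = [(pad[2 * i], pad[2 * i + 1]) for i in range(len(pad) // 2)]
  width += [(0, 0)] * (ndim - len(pad) // 2)   # leading dims are left untouched
  width.reverse()                              # torch lists dims last-to-first
  return width

x = jnp.ones((2, 3))
print(torch_pad_to_jax_width([1, 2], x.ndim))                     # [(0, 0), (1, 2)]
print(jnp.pad(x, torch_pad_to_jax_width([1, 2], x.ndim)).shape)   # (2, 6)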
- for _ in range(num_dims - len(pad) // 2): - jax_pad_width.append((0, 0)) - - # Reverse the jax_pad_width list to match the dimension order - jax_pad_width.reverse() - - if mode == "constant": - if value is None: - value = 0.0 - return jnp.pad( - self, - pad_width=jax_pad_width, - mode="constant", - constant_values=value, - ) - elif mode == "reflect": - return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") - elif mode == "edge": - return jnp.pad(self, pad_width=jax_pad_width, mode="edge") - else: - raise ValueError( - f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." - ) + # JAX's pad function expects padding for each dimension as a tuple of (low, high) + # We need to reverse the pad sequence and group them for JAX. + # pad = [p_l0, p_r0, p_l1, p_r1, ...] + # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0))) + jax_pad_width = [] + # Iterate in reverse pairs + for i in range(len(pad) // 2): + jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)])) + + # Pad any leading dimensions with (0, 0) if the pad sequence is shorter + # than the number of dimensions. + for _ in range(num_dims - len(pad) // 2): + jax_pad_width.append((0, 0)) + + # Reverse the jax_pad_width list to match the dimension order + jax_pad_width.reverse() + + if mode == "constant": + if value is None: + value = 0.0 + return jnp.pad( + self, + pad_width=jax_pad_width, + mode="constant", + constant_values=value, + ) + elif mode == "reflect": + return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") + elif mode == "edge": + return jnp.pad(self, pad_width=jax_pad_width, mode="edge") + else: + raise ValueError( + f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." + ) mutation_ops_to_functional = { - torch.ops.aten.add_: op_base.InplaceOp(torch.ops.aten.add), - torch.ops.aten.sub_: op_base.InplaceOp(torch.ops.aten.sub), - torch.ops.aten.mul_: op_base.InplaceOp(torch.ops.aten.mul), - torch.ops.aten.div_: op_base.InplaceOp(torch.ops.aten.div), - torch.ops.aten.pow_: op_base.InplaceOp(torch.ops.aten.pow), - torch.ops.aten.lt_: op_base.InplaceOp(torch.ops.aten.lt), - torch.ops.aten.le_: op_base.InplaceOp(torch.ops.aten.le), - torch.ops.aten.gt_: op_base.InplaceOp(torch.ops.aten.gt), - torch.ops.aten.ge_: op_base.InplaceOp(torch.ops.aten.ge), - torch.ops.aten.eq_: op_base.InplaceOp(torch.ops.aten.eq), - torch.ops.aten.ne_: op_base.InplaceOp(torch.ops.aten.ne), - torch.ops.aten.bernoulli_: op_base.InplaceOp(torch.ops.aten.bernoulli.p), - torch.ops.aten.bernoulli_.float: op_base.InplaceOp( - _aten_bernoulli, is_jax_func=True - ), - torch.ops.aten.geometric_: op_base.InplaceOp(torch.ops.aten.geometric), - torch.ops.aten.normal_: op_base.InplaceOp(torch.ops.aten.normal), - torch.ops.aten.random_: op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.uniform_: op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.relu_: op_base.InplaceOp(torch.ops.aten.relu), - # squeeze_ is expected to change tensor's shape. 
So replace with new value - torch.ops.aten.squeeze_: op_base.InplaceOp(torch.ops.aten.squeeze, True), - torch.ops.aten.sqrt_: op_base.InplaceOp(torch.ops.aten.sqrt), - torch.ops.aten.clamp_: op_base.InplaceOp(torch.ops.aten.clamp), - torch.ops.aten.clamp_min_: op_base.InplaceOp(torch.ops.aten.clamp_min), - torch.ops.aten.sigmoid_: op_base.InplaceOp(torch.ops.aten.sigmoid), - torch.ops.aten.tanh_: op_base.InplaceOp(torch.ops.aten.tanh), - torch.ops.aten.ceil_: op_base.InplaceOp(torch.ops.aten.ceil), - torch.ops.aten.logical_not_: op_base.InplaceOp(torch.ops.aten.logical_not), - torch.ops.aten.unsqueeze_: op_base.InplaceOp(torch.ops.aten.unsqueeze), - torch.ops.aten.transpose_: op_base.InplaceOp(torch.ops.aten.transpose), - torch.ops.aten.log_normal_: op_base.InplaceOp(torch.ops.aten.log_normal), - torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add), - torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp( - torch.ops.aten.scatter_reduce - ), - torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter), - torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or), + torch.ops.aten.add_: op_base.InplaceOp(torch.ops.aten.add), + torch.ops.aten.sub_: op_base.InplaceOp(torch.ops.aten.sub), + torch.ops.aten.mul_: op_base.InplaceOp(torch.ops.aten.mul), + torch.ops.aten.div_: op_base.InplaceOp(torch.ops.aten.div), + torch.ops.aten.pow_: op_base.InplaceOp(torch.ops.aten.pow), + torch.ops.aten.lt_: op_base.InplaceOp(torch.ops.aten.lt), + torch.ops.aten.le_: op_base.InplaceOp(torch.ops.aten.le), + torch.ops.aten.gt_: op_base.InplaceOp(torch.ops.aten.gt), + torch.ops.aten.ge_: op_base.InplaceOp(torch.ops.aten.ge), + torch.ops.aten.eq_: op_base.InplaceOp(torch.ops.aten.eq), + torch.ops.aten.ne_: op_base.InplaceOp(torch.ops.aten.ne), + torch.ops.aten.bernoulli_: op_base.InplaceOp(torch.ops.aten.bernoulli.p), + torch.ops.aten.bernoulli_.float: op_base.InplaceOp( + _aten_bernoulli, is_jax_func=True + ), + torch.ops.aten.geometric_: op_base.InplaceOp(torch.ops.aten.geometric), + torch.ops.aten.normal_: op_base.InplaceOp(torch.ops.aten.normal), + torch.ops.aten.random_: op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.uniform_: op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.relu_: op_base.InplaceOp(torch.ops.aten.relu), + # squeeze_ is expected to change tensor's shape. So replace with new value + torch.ops.aten.squeeze_: op_base.InplaceOp(torch.ops.aten.squeeze, True), + torch.ops.aten.sqrt_: op_base.InplaceOp(torch.ops.aten.sqrt), + torch.ops.aten.clamp_: op_base.InplaceOp(torch.ops.aten.clamp), + torch.ops.aten.clamp_min_: op_base.InplaceOp(torch.ops.aten.clamp_min), + torch.ops.aten.sigmoid_: op_base.InplaceOp(torch.ops.aten.sigmoid), + torch.ops.aten.tanh_: op_base.InplaceOp(torch.ops.aten.tanh), + torch.ops.aten.ceil_: op_base.InplaceOp(torch.ops.aten.ceil), + torch.ops.aten.logical_not_: op_base.InplaceOp(torch.ops.aten.logical_not), + torch.ops.aten.unsqueeze_: op_base.InplaceOp(torch.ops.aten.unsqueeze), + torch.ops.aten.transpose_: op_base.InplaceOp(torch.ops.aten.transpose), + torch.ops.aten.log_normal_: op_base.InplaceOp(torch.ops.aten.log_normal), + torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add), + torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp( + torch.ops.aten.scatter_reduce + ), + torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter), + torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or), } # Note: tuple comparisons work intuitively, e.g. 
`_jax_version >= (0, 4, 32)`. _jax_version = tuple(int(v) for v in jax.version._version.split(".")) mutation_needs_env = { - torch.ops.aten.bernoulli_, - torch.ops.aten.bernoulli_.float, + torch.ops.aten.bernoulli_, + torch.ops.aten.bernoulli_.float, } for operator, mutation in mutation_ops_to_functional.items(): - ops_registry.register_torch_dispatch_op( - operator, - mutation, - is_jax_function=False, - is_view_op=True, - needs_env=(operator in mutation_needs_env), - ) + ops_registry.register_torch_dispatch_op( + operator, + mutation, + is_jax_function=False, + is_view_op=True, + needs_env=(operator in mutation_needs_env), + ) diff --git a/torchax/torchax/ops/jax_reimplement.py b/torchax/torchax/ops/jax_reimplement.py index 236fb253de38..f98dc240437d 100644 --- a/torchax/torchax/ops/jax_reimplement.py +++ b/torchax/torchax/ops/jax_reimplement.py @@ -16,92 +16,90 @@ def compute_weight_mat( - input_size: core.DimSize, - output_size: core.DimSize, - scale, - translation, - kernel: Callable, - antialias: bool, + input_size: core.DimSize, + output_size: core.DimSize, + scale, + translation, + kernel: Callable, + antialias: bool, ): - dtype = jnp.result_type(scale, translation) - inv_scale = 1.0 / scale - # When downsampling the kernel should be scaled since we want to low pass - # filter and interpolate, but when upsampling it should not be since we only - # want to interpolate. - kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0 - sample_f = ( - (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale - - translation * inv_scale - - 0.5 - ) - x = ( - jnp.abs( - sample_f[jnp.newaxis, :] - - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis] - ) - / kernel_scale - ) - weights = kernel(x) - - total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) - weights = jnp.where( - jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), - jnp.divide( - weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1) - ), - 0, - ) - # Zero out weights where the sample location is completely outside the input - # range. - # Note sample_f has already had the 0.5 removed, hence the weird range below. - - # (barney-s) -------------- returning weights without zeroing --------------------- - return weights - input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 - return jnp.where( - jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ - jnp.newaxis, : - ], - weights, - 0, + dtype = jnp.result_type(scale, translation) + inv_scale = 1.0 / scale + # When downsampling the kernel should be scaled since we want to low pass + # filter and interpolate, but when upsampling it should not be since we only + # want to interpolate. + kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0 + sample_f = ( + (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale + - translation * inv_scale + - 0.5 + ) + x = ( + jnp.abs( + sample_f[jnp.newaxis, :] + - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis] ) - # (barney-s) -------------- END returning weights without zeroing --------------------- + / kernel_scale + ) + weights = kernel(x) + + total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) + weights = jnp.where( + jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), + jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)), + 0, + ) + # Zero out weights where the sample location is completely outside the input + # range. + # Note sample_f has already had the 0.5 removed, hence the weird range below. 
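# Toy sketch of the idea behind compute_weight_mat: resizing one axis amounts
# to a matmul with an (input_size, output_size) weight matrix built from a
# kernel. Antialiasing and translation are omitted here; names are made up.
import jax.numpy as jnp

def triangle_kernel(x):
  return jnp.maximum(0.0, 1.0 - jnp.abs(x))

def toy_weight_mat(in_size, out_size):
  scale = out_size / in_size
  sample = (jnp.arange(out_size) + 0.5) / scale - 0.5   # output -> input coords
  w = triangle_kernel(sample[None, :] - jnp.arange(in_size)[:, None])
  return w / jnp.sum(w, axis=0, keepdims=True)          # each output's weights sum to 1

signal = jnp.array([0.0, 1.0, 2.0, 3.0])
print(signal @ toy_weight_mat(4, 8))    # 1-D linear-style upsample to 8 samples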
+ + # (barney-s) -------------- returning weights without zeroing --------------------- + return weights + input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 + return jnp.where( + jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + jnp.newaxis, : + ], + weights, + 0, + ) + # (barney-s) -------------- END returning weights without zeroing --------------------- # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86 def _scale_and_translate( - x, - output_shape: core.Shape, - spatial_dims: Sequence[int], - scale, - translation, - kernel, - antialias: bool, - precision, + x, + output_shape: core.Shape, + spatial_dims: Sequence[int], + scale, + translation, + kernel, + antialias: bool, + precision, ): - input_shape = x.shape - assert len(input_shape) == len(output_shape) - assert len(spatial_dims) == len(scale) - assert len(spatial_dims) == len(translation) - if len(spatial_dims) == 0: - return x - contractions = [] - in_indices = list(range(len(output_shape))) - out_indices = list(range(len(output_shape))) - for i, d in enumerate(spatial_dims): - d = canonicalize_axis(d, x.ndim) - m = input_shape[d] - n = output_shape[d] - w = compute_weight_mat( - m, n, scale[i], translation[i], kernel, antialias - ).astype(x.dtype) - contractions.append(w) - contractions.append([d, len(output_shape) + i]) - out_indices[d] = len(output_shape) + i - contractions.append(out_indices) - return jnp.einsum(x, in_indices, *contractions, precision=precision) + input_shape = x.shape + assert len(input_shape) == len(output_shape) + assert len(spatial_dims) == len(scale) + assert len(spatial_dims) == len(translation) + if len(spatial_dims) == 0: + return x + contractions = [] + in_indices = list(range(len(output_shape))) + out_indices = list(range(len(output_shape))) + for i, d in enumerate(spatial_dims): + d = canonicalize_axis(d, x.ndim) + m = input_shape[d] + n = output_shape[d] + w = compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ).astype(x.dtype) + contractions.append(w) + contractions.append([d, len(output_shape) + i]) + out_indices[d] = len(output_shape) + i + contractions.append(out_indices) + return jnp.einsum(x, in_indices, *contractions, precision=precision) # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172 @@ -110,102 +108,102 @@ def _scale_and_translate( # scale and translation here are scalar elements of an np.array, what is the # correct type annotation? def scale_and_translate( + image, + shape: core.Shape, + spatial_dims: Sequence[int], + scale, + translation, + # (barney-s) use string + method: str, # (barney-s) | ResizeMethod, + antialias: bool = True, + precision=lax.Precision.HIGHEST, +): + """Apply a scale and translation to an image. + + Generates a new image of shape 'shape' by resampling from the input image + using the sampling method corresponding to method. For 2D images, this + operation transforms a location in the input images, (x, y), to a location + in the output image according to:: + + (x * scale[1] + translation[1], y * scale[0] + translation[0]) + + (Note the *inverse* warp is used to generate the sample locations.) + Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` + has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input + image dimensions. 
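# Orientation-only usage sketch of the upstream jax.image.scale_and_translate
# API that this file re-implements; the keyword names follow the signature
# copied above and current JAX releases.
import jax.numpy as jnp
from jax import image as jax_image

img = jnp.arange(16.0).reshape(1, 1, 4, 4)          # (N, C, H, W)
out = jax_image.scale_and_translate(
    img, (1, 1, 8, 8), (2, 3),
    scale=jnp.array([2.0, 2.0]),
    translation=jnp.array([0.0, 0.0]),
    method="linear",
    antialias=True,
)
print(out.shape)                                    # (1, 1, 8, 8)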
+ + If an output location(pixel) maps to an input sample location that is outside + the input boundaries then the value for the output location will be set to + zero. + + The ``method`` argument expects one of the following resize methods: + + ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, + ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a + triangular filter when downsampling. + + ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"`` + `Cubic interpolation`_, using the Keys cubic kernel. + + ``ResizeMethod.LANCZOS3``, ``"lanczos3"`` + `Lanczos resampling`_, using a kernel of radius 3. + + ``ResizeMethod.LANCZOS5``, ``"lanczos5"`` + `Lanczos resampling`_, using a kernel of radius 5. + + .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation + .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling + + Args: + image: a JAX array. + shape: the output shape, as a sequence of integers with length equal to the + number of dimensions of `image`. + spatial_dims: A length K tuple specifying the spatial dimensions that the + passed scale and translation should be applied to. + scale: A [K] array with the same number of dimensions as image, containing + the scale to apply in each dimension. + translation: A [K] array with the same number of dimensions as image, + containing the translation to apply in each dimension. + method: the resizing method to use; either a ``ResizeMethod`` instance or a + string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC. + antialias: Should an antialiasing filter be used when downsampling? Defaults + to ``True``. Has no effect when upsampling. + + Returns: + The scale and translated image. + """ + shape = core.canonicalize_shape(shape) + if len(shape) != image.ndim: + msg = ( + "shape must have length equal to the number of dimensions of x; " + f" {shape} vs {image.shape}" + ) + raise ValueError(msg) + if isinstance(method, str): + method = ResizeMethod.from_string(method) + if method == ResizeMethod.NEAREST: + # Nearest neighbor is currently special-cased for straight resize, so skip + # for now. + raise ValueError( + "Nearest neighbor resampling is not currently supported " + "for scale_and_translate." + ) + assert isinstance(method, ResizeMethod) + + kernel = _kernels[method] + (image,) = promote_dtypes_inexact(image) + scale, translation = promote_dtypes_inexact(scale, translation) + return _scale_and_translate( image, - shape: core.Shape, - spatial_dims: Sequence[int], + shape, + spatial_dims, scale, translation, - # (barney-s) use string - method: str, # (barney-s) | ResizeMethod, - antialias: bool = True, - precision=lax.Precision.HIGHEST, -): - """Apply a scale and translation to an image. - - Generates a new image of shape 'shape' by resampling from the input image - using the sampling method corresponding to method. For 2D images, this - operation transforms a location in the input images, (x, y), to a location - in the output image according to:: - - (x * scale[1] + translation[1], y * scale[0] + translation[0]) - - (Note the *inverse* warp is used to generate the sample locations.) - Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` - has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input - image dimensions. 
- - If an output location(pixel) maps to an input sample location that is outside - the input boundaries then the value for the output location will be set to - zero. - - The ``method`` argument expects one of the following resize methods: - - ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, - ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a - triangular filter when downsampling. - - ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"`` - `Cubic interpolation`_, using the Keys cubic kernel. - - ``ResizeMethod.LANCZOS3``, ``"lanczos3"`` - `Lanczos resampling`_, using a kernel of radius 3. - - ``ResizeMethod.LANCZOS5``, ``"lanczos5"`` - `Lanczos resampling`_, using a kernel of radius 5. - - .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation - .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation - .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling - - Args: - image: a JAX array. - shape: the output shape, as a sequence of integers with length equal to the - number of dimensions of `image`. - spatial_dims: A length K tuple specifying the spatial dimensions that the - passed scale and translation should be applied to. - scale: A [K] array with the same number of dimensions as image, containing - the scale to apply in each dimension. - translation: A [K] array with the same number of dimensions as image, - containing the translation to apply in each dimension. - method: the resizing method to use; either a ``ResizeMethod`` instance or a - string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC. - antialias: Should an antialiasing filter be used when downsampling? Defaults - to ``True``. Has no effect when upsampling. - - Returns: - The scale and translated image. - """ - shape = core.canonicalize_shape(shape) - if len(shape) != image.ndim: - msg = ( - "shape must have length equal to the number of dimensions of x; " - f" {shape} vs {image.shape}" - ) - raise ValueError(msg) - if isinstance(method, str): - method = ResizeMethod.from_string(method) - if method == ResizeMethod.NEAREST: - # Nearest neighbor is currently special-cased for straight resize, so skip - # for now. - raise ValueError( - "Nearest neighbor resampling is not currently supported " - "for scale_and_translate." 
- ) - assert isinstance(method, ResizeMethod) - - kernel = _kernels[method] - (image,) = promote_dtypes_inexact(image) - scale, translation = promote_dtypes_inexact(scale, translation) - return _scale_and_translate( - image, - shape, - spatial_dims, - scale, - translation, - kernel, - antialias, - precision, - ) + kernel, + antialias, + precision, + ) # END ----------------- END JAX code copied for testing ----------------------------- diff --git a/torchax/torchax/ops/jc10d.py b/torchax/torchax/ops/jc10d.py index 0d730d39a9ad..9cfbc8ba97e4 100644 --- a/torchax/torchax/ops/jc10d.py +++ b/torchax/torchax/ops/jc10d.py @@ -6,45 +6,45 @@ def op(*aten, **kwargs): - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - return func + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + return func - return inner + return inner @op(torch.ops._c10d_functional.all_gather_into_tensor) def _c10d_all_gather(input, group_size: int, group_name: str): - return jax.lax.all_gather(input, "torch_dist") + return jax.lax.all_gather(input, "torch_dist") @op(torch.ops._c10d_functional.all_reduce) def _c10d_all_reduce(self, reduceOp: str, group_name: str): - if reduceOp == "sum": - res = jax.lax.psum(self, axis_name="torch_dist") - elif reduceOp == "avg": - res = jax.lax.pmean(self, axis_name="torch_dist") - elif reduceOp == "min": - res = jax.lax.pmin(self, axis_name="torch_dist") - elif reduceOp == "max": - res = jax.lax.pmax(self, axis_name="torch_dist") - else: - raise RuntimeError(f"Reduce op {reduceOp} not implemented") - return res + if reduceOp == "sum": + res = jax.lax.psum(self, axis_name="torch_dist") + elif reduceOp == "avg": + res = jax.lax.pmean(self, axis_name="torch_dist") + elif reduceOp == "min": + res = jax.lax.pmin(self, axis_name="torch_dist") + elif reduceOp == "max": + res = jax.lax.pmax(self, axis_name="torch_dist") + else: + raise RuntimeError(f"Reduce op {reduceOp} not implemented") + return res @op(torch.ops._c10d_functional.broadcast) def _c10d_broadcast(self, src: int, group_name: str): - masked = jnp.where( - jax.lax.axis_index("torch_dist") == src, - self, - jnp.zeros_like(self), - ) - return jax.lax.psum(masked, "torch_dist") + masked = jnp.where( + jax.lax.axis_index("torch_dist") == src, + self, + jnp.zeros_like(self), + ) + return jax.lax.psum(masked, "torch_dist") @op(torch.ops._c10d_functional.wait_tensor) def _c10d_wait_tensor(tensor): - # Async tensor is aleady `wait`ed by dispatcher - return tensor + # Async tensor is aleady `wait`ed by dispatcher + return tensor diff --git a/torchax/torchax/ops/jimage.py b/torchax/torchax/ops/jimage.py index dbfc77c64cd2..8ca0653ce460 100644 --- a/torchax/torchax/ops/jimage.py +++ b/torchax/torchax/ops/jimage.py @@ -3,110 +3,110 @@ def cubic_kernel(x, a=-0.75): - """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" - absx = jnp.abs(x) - x2 = absx * absx - x3 = x2 * absx - cond1 = absx <= 1 - cond2 = (absx > 1) & (absx < 2) - f1 = (a + 2) * x3 - (a + 3) * x2 + 1 - f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a - return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) + """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" + absx = jnp.abs(x) + x2 = absx * absx + x3 = x2 * absx + cond1 = absx <= 1 + cond2 = (absx > 1) & (absx < 2) + f1 = (a + 2) * x3 - (a + 3) * x2 + 1 + f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a + return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) def compute_contribs( - in_size, out_size, scale, support=2.0, 
align_corners=False, dtype=None + in_size, out_size, scale, support=2.0, align_corners=False, dtype=None ): - if align_corners: - if out_size == 1: - in_coords = jnp.zeros((1,), dtype=dtype) - else: - in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) + if align_corners: + if out_size == 1: + in_coords = jnp.zeros((1,), dtype=dtype) else: - out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 - in_coords = out_coords / scale - 0.5 + in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) + else: + out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 + in_coords = out_coords / scale - 0.5 - left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 - idxs = left_idx[:, None] + jnp.arange(4) + left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 + idxs = left_idx[:, None] + jnp.arange(4) - dx = in_coords[:, None] - idxs + dx = in_coords[:, None] - idxs - weights = cubic_kernel(dx) + weights = cubic_kernel(dx) - weights = weights / jnp.sum(weights, axis=1, keepdims=True) - return idxs, weights + weights = weights / jnp.sum(weights, axis=1, keepdims=True) + return idxs, weights def gather_weights(img, idxs, axis): - """Safely gather with boundary handling""" - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) - return jnp.take(img, idxs, axis=axis) + """Safely gather with boundary handling""" + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) + return jnp.take(img, idxs, axis=axis) def interpolate_along_axis_bchw(img, idxs, weights, axis): - """ - Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). - idxs: (out_size, 4) int32 indices - weights: (out_size, 4) float32 weights - """ - assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" - out_size = idxs.shape[0] - k = idxs.shape[1] # Typically 4 for cubic - - # Clip to input bounds - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) - - def gather_and_weight(i): - idx = idxs[i] # (4,) - w = weights[i] # (4,) - - def gather_one(offset): - return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) - - gathered = jnp.stack( - [gather_one(o) for o in range(k)], axis=0 - ) # (4, B, C, H, W) - weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) - return weighted - - out = jax.vmap(gather_and_weight)( - jnp.arange(out_size) - ) # (out_size, B, C, H, W) - - # Move the interpolated axis back into place - if axis == 2: # interpolated over H - return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) - else: # axis == 3, interpolated over W - return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) + """ + Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). 
+ idxs: (out_size, 4) int32 indices + weights: (out_size, 4) float32 weights + """ + assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" + out_size = idxs.shape[0] + k = idxs.shape[1] # Typically 4 for cubic + + # Clip to input bounds + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) + + def gather_and_weight(i): + idx = idxs[i] # (4,) + w = weights[i] # (4,) + + def gather_one(offset): + return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) + + gathered = jnp.stack( + [gather_one(o) for o in range(k)], axis=0 + ) # (4, B, C, H, W) + weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) + return weighted + + out = jax.vmap(gather_and_weight)( + jnp.arange(out_size) + ) # (out_size, B, C, H, W) + + # Move the interpolated axis back into place + if axis == 2: # interpolated over H + return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) + else: # axis == 3, interpolated over W + return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): - h, w = img.shape[-2:] - if align_corners and out_h > 1: - scale_y = (h - 1) / (out_h - 1) - else: - scale_y = out_h / h - - if align_corners and out_w > 1: - scale_x = (w - 1) / (out_w - 1) - else: - scale_x = out_w / w - - idxs_y, weights_y = compute_contribs( - h, - out_h, - scale_y, - align_corners=align_corners, - dtype=img.dtype, - ) - tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) - - idxs_x, weights_x = compute_contribs( - w, - out_w, - scale_x, - align_corners=align_corners, - dtype=img.dtype, - ) - out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) - return out + h, w = img.shape[-2:] + if align_corners and out_h > 1: + scale_y = (h - 1) / (out_h - 1) + else: + scale_y = out_h / h + + if align_corners and out_w > 1: + scale_x = (w - 1) / (out_w - 1) + else: + scale_x = out_w / w + + idxs_y, weights_y = compute_contribs( + h, + out_h, + scale_y, + align_corners=align_corners, + dtype=img.dtype, + ) + tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) + + idxs_x, weights_x = compute_contribs( + w, + out_w, + scale_x, + align_corners=align_corners, + dtype=img.dtype, + ) + out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) + return out diff --git a/torchax/torchax/ops/jlibrary.py b/torchax/torchax/ops/jlibrary.py index 697cf6fe646d..90156b9c8432 100644 --- a/torchax/torchax/ops/jlibrary.py +++ b/torchax/torchax/ops/jlibrary.py @@ -11,70 +11,68 @@ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args): - """Wrap a jaxpr in a jitted function with the proper composite name - TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op. - """ + """Wrap a jaxpr in a jitted function with the proper composite name + TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op. + """ - def composite_impl(*args): - return jaxpr_impl(*args) + def composite_impl(*args): + return jaxpr_impl(*args) - composite_impl.__name__ = composite_name - composite_impl.__qualname__ = composite_name - return jax.jit(composite_impl, **jit_args) + composite_impl.__name__ = composite_name + composite_impl.__qualname__ = composite_name + return jax.jit(composite_impl, **jit_args) def register_jax_composite(composite_name, impl, *ops, **jit_args): - """Register a composite using a JAX implementation. - composite_name - The name of the library op to use in the exported composite - impl - A JAX lowering for the library operation - *ops - Variadic torch.ops to lower using `impl`. 
- **jit_args - Additional parameters to forward to JAX jit. + """Register a composite using a JAX implementation. + composite_name - The name of the library op to use in the exported composite + impl - A JAX lowering for the library operation + *ops - Variadic torch.ops to lower using `impl`. + **jit_args - Additional parameters to forward to JAX jit. - This is used to register custom lowerings with an explicit jaxpr - implementation, such as preserving a specific aten op using a jaten impl. + This is used to register custom lowerings with an explicit jaxpr + implementation, such as preserving a specific aten op using a jaten impl. - For custom torch op registration with a decomposition written in torch, - use `register_torch_composite`. + For custom torch op registration with a decomposition written in torch, + use `register_torch_composite`. - For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ + For jit params and troubleshooting see: + https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html + """ - @jaten.op(*ops) - def _composite_impl(*args): - return _jit_composite_impl(composite_name, impl, **jit_args)(*args) + @jaten.op(*ops) + def _composite_impl(*args): + return _jit_composite_impl(composite_name, impl, **jit_args)(*args) def register_torch_composite(composite_name, impl, *ops, **jit_args): - """Register a torch decomposition as a composite. - This is useful for registerring custom torch op libraries as composite ops. - - The `impl` can be the `@impl` used to define the torch custom library op. - This must be a function or module impl that provides the decompositions, and - not an instance of the custom op. - - TODO: Better error handling, or can we make this an instance of the op as a param? - - For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ - - @jaten.op(*ops) - def _composite_impl(*args): - class ImplWrapper(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, *args): - return impl(*args) - - # Note: avoid refactoring to share code with register_jaxpr_composite. - # The `extract_jax` call must live in the `@jaten.op` handler. If called - # outside of the handler, we would build the jaxpr representation of the - # module once during registration, potentially missing op registrations that - # come after. I.e. may miss nested abstractions if we build jaxpr AoT. - state, jfn = torchax.extract_jax(ImplWrapper()) - jaxpr_impl = lambda *args: jfn(state, tuple([*args])) - return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)( - *args - ) + """Register a torch decomposition as a composite. + This is useful for registerring custom torch op libraries as composite ops. + + The `impl` can be the `@impl` used to define the torch custom library op. + This must be a function or module impl that provides the decompositions, and + not an instance of the custom op. + + TODO: Better error handling, or can we make this an instance of the op as a param? + + For jit params and troubleshooting see: + https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html + """ + + @jaten.op(*ops) + def _composite_impl(*args): + class ImplWrapper(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args): + return impl(*args) + + # Note: avoid refactoring to share code with register_jaxpr_composite. + # The `extract_jax` call must live in the `@jaten.op` handler. 
If called + # outside of the handler, we would build the jaxpr representation of the + # module once during registration, potentially missing op registrations that + # come after. I.e. may miss nested abstractions if we build jaxpr AoT. + state, jfn = torchax.extract_jax(ImplWrapper()) + jaxpr_impl = lambda *args: jfn(state, tuple([*args])) + return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index b3d5340cc7b3..a2bb6b1a6bdb 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -20,561 +20,557 @@ def register_function(torch_func, **kwargs): - return functools.partial(register_torch_function_op, torch_func, **kwargs) + return functools.partial(register_torch_function_op, torch_func, **kwargs) @register_function(torch.as_tensor, is_jax_function=False, needs_env=True) @op_base.convert_dtype( - use_default_dtype=False + use_default_dtype=False ) # Attempt to infer type from elements def _as_tensor(data, dtype=None, device=None, env=None): - if isinstance(data, torch.Tensor): - return env._to_copy(data, dtype, device) - if isinstance(data, np.ndarray): - jax_res = jnp.asarray(data) - else: - jax_res = _tensor(data, dtype=dtype) - return torchax.tensor.Tensor(jax_res, env) + if isinstance(data, torch.Tensor): + return env._to_copy(data, dtype, device) + if isinstance(data, np.ndarray): + jax_res = jnp.asarray(data) + else: + jax_res = _tensor(data, dtype=dtype) + return torchax.tensor.Tensor(jax_res, env) @register_function(torch.tensor) @op_base.convert_dtype( - use_default_dtype=False + use_default_dtype=False ) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): - python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, - } - if not dtype: - leaves = jax.tree_util.tree_leaves(data) - if len(leaves) > 0: - dtype = python_types_to_torch_types.get(type(leaves[0])) - - return jnp.array( - data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()) - ) + python_types_to_torch_types = { + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, + } + if not dtype: + leaves = jax.tree_util.tree_leaves(data) + if len(leaves) > 0: + dtype = python_types_to_torch_types.get(type(leaves[0])) + + return jnp.array( + data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()) + ) @register_function(torch.allclose) def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) + return jnp.allclose(input, other, rtol, atol, equal_nan) @register_function(torch.angle) def _torch_angle(input): - if input.dtype.name == "int64": - input = input.astype(jnp.dtype("float32")) - return jnp.angle(input) + if input.dtype.name == "int64": + input = input.astype(jnp.dtype("float32")) + return jnp.angle(input) @register_function(torch.argsort) def _torch_argsort(input, dim=-1, descending=False, stable=False): - expanded = False - if input.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - input = jnp.expand_dims(input, 0) - res = jnp.argsort(input, axis=dim, descending=descending, stable=stable) - if expanded: - res = res.squeeze() - return res + expanded = False + if input.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # 
torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + input = jnp.expand_dims(input, 0) + res = jnp.argsort(input, axis=dim, descending=descending, stable=stable) + if expanded: + res = res.squeeze() + return res @register_function(torch.diag) def _diag(input, diagonal=0): - return jnp.diag(input, k=diagonal) + return jnp.diag(input, k=diagonal) @register_function(torch.einsum) @register_function(torch.ops.aten.einsum) def _einsum(equation, *operands): - def get_params(*a): - inner_list = a[0] - if not isinstance(inner_list, jax.Array): - if len(inner_list) == 1: - A = inner_list - return A - elif len(inner_list) == 2: - A, B = inner_list - return A, B - return operands - - assert isinstance(equation, str), "Only accept str equation" - filtered_operands = get_params(*operands) - return jnp.einsum(equation, *filtered_operands) + def get_params(*a): + inner_list = a[0] + if not isinstance(inner_list, jax.Array): + if len(inner_list) == 1: + A = inner_list + return A + elif len(inner_list) == 2: + A, B = inner_list + return A, B + return operands + + assert isinstance(equation, str), "Only accept str equation" + filtered_operands = get_params(*operands) + return jnp.einsum(equation, *filtered_operands) def _sdpa_reference( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, ) -> torch.Tensor: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones( - L, S, dtype=torch.bool, device=query.device - ).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - if enable_gqa: - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p > 0: - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return attn_weight @ value + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril( + diagonal=0 + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p > 0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value from jax.sharding import 
PartitionSpec def _tpu_flash_attention(query, key, value, env): - fsdp_partition = PartitionSpec("fsdp") - - def wrap_flash_attention(query, key, value): - block_sizes = flash_attention.BlockSizes( - block_b=min(2, query.shape[0]), - block_q=min(512, query.shape[2]), - block_k_major=min(512, key.shape[2]), - block_k=min(512, key.shape[2]), - block_q_major_dkv=min(512, query.shape[2]), - block_k_major_dkv=min(512, key.shape[2]), - block_k_dkv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_k_major_dq=min(512, key.shape[2]), - block_k_dq=min(256, key.shape[2]), - block_q_dq=min(1024, query.shape[2]), - ) - return flash_attention.flash_attention( - query, key, value, causal=True, block_sizes=block_sizes - ) - - if env.config.shmap_flash_attention: - wrap_flash_attention = shard_map( - wrap_flash_attention, - mesh=env._mesh, - in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), - out_specs=fsdp_partition, - check_rep=False, - ) - # return flash_attn_mapped(query, key, value) - return wrap_flash_attention(query, key, value) + fsdp_partition = PartitionSpec("fsdp") + + def wrap_flash_attention(query, key, value): + block_sizes = flash_attention.BlockSizes( + block_b=min(2, query.shape[0]), + block_q=min(512, query.shape[2]), + block_k_major=min(512, key.shape[2]), + block_k=min(512, key.shape[2]), + block_q_major_dkv=min(512, query.shape[2]), + block_k_major_dkv=min(512, key.shape[2]), + block_k_dkv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_k_major_dq=min(512, key.shape[2]), + block_k_dq=min(256, key.shape[2]), + block_q_dq=min(1024, query.shape[2]), + ) + return flash_attention.flash_attention( + query, key, value, causal=True, block_sizes=block_sizes + ) + + if env.config.shmap_flash_attention: + wrap_flash_attention = shard_map( + wrap_flash_attention, + mesh=env._mesh, + in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), + out_specs=fsdp_partition, + check_rep=False, + ) + # return flash_attn_mapped(query, key, value) + return wrap_flash_attention(query, key, value) @register_function(torch.nn.functional.pad) def pad(tensor, pad, mode="constant", value=None): - # For padding modes that have different names between Torch and NumPy, this - # dict provides a Torch-to-NumPy translation. Any string not in this dict will - # be passed through as-is. - MODE_NAME_TRANSLATION = { - "circular": "wrap", - "replicate": "edge", - } - - numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) - - num_prefix_dims = tensor.ndim - len(pad) // 2 - - numpy_pad_width = [(0, 0)] * num_prefix_dims - nd_slice = [slice(None)] * num_prefix_dims - - for i in range(len(pad) - 2, -1, -2): - pad_start, pad_end = pad[i : i + 2] - slice_start, slice_end = None, None - - if pad_start < 0: - slice_start = -pad_start - pad_start = 0 - - if pad_end < 0: - slice_end = pad_end - pad_end = 0 - - numpy_pad_width.append((pad_start, pad_end)) - nd_slice.append(slice(slice_start, slice_end)) - - nd_slice = tuple(nd_slice) - - # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg, - # even if the value we pass in is `None`. (It treats `None` as `nan`.) - kwargs = dict() - if mode == "constant" and value is not None: - kwargs["constant_values"] = value - - # The "replicate" mode pads first and then slices, whereas the "circular" mode - # slices first and then pads. The latter approach deals with smaller tensors, - # so we default to that option in modes where the order of operations doesn't - # affect the result. 
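# Quick illustrative check (values made up) of the Torch->NumPy mode mapping
# used above ("circular" -> "wrap", "replicate" -> "edge") and of how a
# negative torch pad becomes a slice before padding.
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(1, 6)
print(jnp.pad(x, [(0, 0), (2, 2)], mode="wrap"))    # torch mode="circular"
print(jnp.pad(x, [(0, 0), (1, 1)], mode="edge"))    # torch mode="replicate"
# torch pad=(-1, 2) on the last dim: drop one element on the left, then pad
# two on the right -- i.e. slice first, pad second.
print(jnp.pad(x[:, 1:], [(0, 0), (0, 2)], mode="edge"))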
- if mode == "replicate": - return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[ - nd_slice - ] - else: - return jnp.pad( - tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs - ) + # For padding modes that have different names between Torch and NumPy, this + # dict provides a Torch-to-NumPy translation. Any string not in this dict will + # be passed through as-is. + MODE_NAME_TRANSLATION = { + "circular": "wrap", + "replicate": "edge", + } + + numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) + + num_prefix_dims = tensor.ndim - len(pad) // 2 + + numpy_pad_width = [(0, 0)] * num_prefix_dims + nd_slice = [slice(None)] * num_prefix_dims + + for i in range(len(pad) - 2, -1, -2): + pad_start, pad_end = pad[i : i + 2] + slice_start, slice_end = None, None + + if pad_start < 0: + slice_start = -pad_start + pad_start = 0 + + if pad_end < 0: + slice_end = pad_end + pad_end = 0 + + numpy_pad_width.append((pad_start, pad_end)) + nd_slice.append(slice(slice_start, slice_end)) + + nd_slice = tuple(nd_slice) + + # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg, + # even if the value we pass in is `None`. (It treats `None` as `nan`.) + kwargs = dict() + if mode == "constant" and value is not None: + kwargs["constant_values"] = value + + # The "replicate" mode pads first and then slices, whereas the "circular" mode + # slices first and then pads. The latter approach deals with smaller tensors, + # so we default to that option in modes where the order of operations doesn't + # affect the result. + if mode == "replicate": + return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice] + else: + return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs) @register_function( - torch.nn.functional.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True, + torch.nn.functional.scaled_dot_product_attention, + is_jax_function=False, + needs_env=True, ) @register_function( - torch.ops.aten.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True, + torch.ops.aten.scaled_dot_product_attention, + is_jax_function=False, + needs_env=True, ) def scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, - env=None, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + env=None, ) -> torch.Tensor: - if env.config.use_tpu_flash_attention: - jquery, jkey, jvalue = env.t2j_iso((query, key, value)) - res = _tpu_flash_attention(jquery, jkey, jvalue, env) - return env.j2t_iso(res) + if env.config.use_tpu_flash_attention: + jquery, jkey, jvalue = env.t2j_iso((query, key, value)) + res = _tpu_flash_attention(jquery, jkey, jvalue, env) + return env.j2t_iso(res) - return _sdpa_reference( - query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa - ) + return _sdpa_reference( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa + ) @register_function( - torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True + torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True ) def getitem(self, indexes): - if isinstance(indexes, list) and isinstance(indexes[0], int): - # list of int, i.e. 
x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) - indexes = (indexes,) - elif isinstance(indexes, list): - indexes = tuple(indexes) - - def is_narrow_slicing(): - tensor_free = not pytree.tree_any( - lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), - indexes, - ) - list_free = not isinstance(indexes, tuple) or all([ - False if isinstance(x, list) else True for x in indexes - ]) - return tensor_free and list_free + if isinstance(indexes, list) and isinstance(indexes[0], int): + # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) + indexes = (indexes,) + elif isinstance(indexes, list): + indexes = tuple(indexes) + + def is_narrow_slicing(): + tensor_free = not pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), + indexes, + ) + list_free = not isinstance(indexes, tuple) or all([ + False if isinstance(x, list) else True for x in indexes + ]) + return tensor_free and list_free - if is_narrow_slicing(): - return View(self, view_info=NarrowInfo(indexes), env=self._env) + if is_narrow_slicing(): + return View(self, view_info=NarrowInfo(indexes), env=self._env) - indexes = self._env.t2j_iso(indexes) - return torchax.tensor.Tensor(self._elem[indexes], self._env) + indexes = self._env.t2j_iso(indexes) + return torchax.tensor.Tensor(self._elem[indexes], self._env) @register_function(torch.corrcoef) def _corrcoef(x): - if x.dtype.name == "int64": - return jnp.corrcoef(x).astype(jnp.float32) - return jnp.corrcoef(x) + if x.dtype.name == "int64": + return jnp.corrcoef(x).astype(jnp.float32) + return jnp.corrcoef(x) @register_function(torch.sparse.mm, is_jax_function=False) def _sparse_mm(mat1, mat2, reduce="sum"): - return torch.mm(mat1, mat2) + return torch.mm(mat1, mat2) @register_function(torch.isclose) def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.isclose(input, other, rtol, atol, equal_nan) + return jnp.isclose(input, other, rtol, atol, equal_nan) @register_function(torch.linalg.det) def linalg_det(input): - return jnp.linalg.det(input) + return jnp.linalg.det(input) @register_function(torch.ones) def _ones(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._ones(size, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jaten._ones(size, dtype=dtype) @register_function(torch.zeros, is_jax_function=True) def _zeros(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._zeros(size, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jaten._zeros(size, dtype=dtype) @register_function(torch.eye) @op_base.convert_dtype() def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): - return jnp.eye(n, m, dtype=dtype) + return jnp.eye(n, m, dtype=dtype) @register_function(torch.full) @op_base.convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) @register_function(torch.empty) @op_base.convert_dtype() def empty(*size: Sequence[int], dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jnp.empty(size, dtype=dtype) + if len(size) == 1 and 
isinstance(size[0], collections.abc.Iterable): + size = size[0] + return jnp.empty(size, dtype=dtype) @register_function(torch.arange, is_jax_function=False) def arange( - start, - end=None, - step=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=None, + start, + end=None, + step=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=None, ): - if end is None: - end = start - start = 0 - if step is None: - step = 1 - return torch.ops.aten.arange(start, end, step, dtype=dtype) + if end is None: + end = start + start = 0 + if step is None: + step = 1 + return torch.ops.aten.arange(start, end, step, dtype=dtype) @register_function(torch.empty_strided, is_jax_function=False) def empty_strided( - size, - stride, - *, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, + size, + stride, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + pin_memory=False, ): - return empty(size, dtype=dtype) + return empty(size, dtype=dtype) @register_function(torch.unravel_index) def unravel_index(indices, shape): - return jnp.unravel_index(indices, shape) + return jnp.unravel_index(indices, shape) @register_function(torch.rand, is_jax_function=False) def rand(*size, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return torch.ops.aten.rand(size, **kwargs) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.rand(size, **kwargs) @register_function(torch.randn, is_jax_function=False) def randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, ): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return torch.ops.aten.randn(size, generator=generator, dtype=dtype) + if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): + size = size[0] + return torch.ops.aten.randn(size, generator=generator, dtype=dtype) @register_function(torch.randint, is_jax_function=False) def randint(*args, **kwargs): - return torch.ops.aten.randint(*args, **kwargs) + return torch.ops.aten.randint(*args, **kwargs) @register_function(torch.logdet) def logdet(input): - _, logabsdet = jaten._aten__linalg_slogdet(input) - return logabsdet + _, logabsdet = jaten._aten__linalg_slogdet(input) + return logabsdet @register_function(torch.linalg.slogdet) def linalg_slogdet(input): - sign, logabsdet = jaten._aten__linalg_slogdet(input) - return torch.return_types.slogdet((sign, logabsdet)) + sign, logabsdet = jaten._aten__linalg_slogdet(input) + return torch.return_types.slogdet((sign, logabsdet)) @register_function(torch.tensor_split) def tensor_split(input, indices_or_sections, dim=0): - return jnp.array_split(input, indices_or_sections, axis=dim) + return jnp.array_split(input, indices_or_sections, axis=dim) @register_function(torch.linalg.solve) def linalg_solve(a, b): - res, _ = jaten._aten__linalg_solve_ex(a, b) - return res + res, _ = jaten._aten__linalg_solve_ex(a, b) + return res @register_function(torch.linalg.solve_ex) def linalg_solve_ex(a, b): - res, info = jaten._aten__linalg_solve_ex(a, b) - return res, info + res, info = jaten._aten__linalg_solve_ex(a, b) + return res, info 
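Several of the wrappers above (`_ones`, `_zeros`, `empty`, `rand`, `randn`) repeat the same `*size` normalization because torch accepts both `torch.ones(2, 3)` and `torch.ones((2, 3))`, while the `jax.numpy` constructors expect a single shape sequence. A minimal standalone sketch of that idiom, with illustrative helper names that are not part of this patch:

import collections.abc

import jax.numpy as jnp


def _normalize_size(size):
  # torch.ones(2, 3) arrives here as size == (2, 3);
  # torch.ones((2, 3)) arrives as size == ((2, 3),), so unwrap the inner tuple.
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
    return tuple(size[0])
  return tuple(size)


def ones_like_torch(*size, dtype=None):
  return jnp.ones(_normalize_size(size), dtype=dtype)


assert ones_like_torch(2, 3).shape == (2, 3)
assert ones_like_torch((2, 3)).shape == (2, 3)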
@register_function(torch.linalg.svd) def linalg_svd(a, full_matrices=True): - return jaten._aten__linalg_svd(a, full_matrices=full_matrices) + return jaten._aten__linalg_svd(a, full_matrices=full_matrices) @register_function(torch.linalg.matrix_power) def matrix_power(A, n, *, out=None): - return jnp.linalg.matrix_power(A, n) + return jnp.linalg.matrix_power(A, n) @register_function(torch.svd) def svd(a, some=True, compute_uv=True): - if not compute_uv: - S = jaten._aten__linalg_svd(a, full_matrices=False)[1] - U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype) - V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype) - return U, S, V - U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some) - return U, S, jnp.matrix_transpose(V) + if not compute_uv: + S = jaten._aten__linalg_svd(a, full_matrices=False)[1] + U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype) + V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype) + return U, S, V + U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some) + return U, S, jnp.matrix_transpose(V) @register_function(torch.cdist) def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): - return jaten._aten_cdist(x1, x2, p, compute_mode) + return jaten._aten_cdist(x1, x2, p, compute_mode) @register_function(torch.lu) def lu(A, **kwargs): - lu, pivots, _ = jax.lax.linalg.lu(A) - # JAX pivots are offset by 1 compared to torch - _pivots = pivots + 1 - info_shape = pivots.shape[:-1] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - if kwargs["get_infos"] == True: - return lu, _pivots, info - return lu, _pivots + lu, pivots, _ = jax.lax.linalg.lu(A) + # JAX pivots are offset by 1 compared to torch + _pivots = pivots + 1 + info_shape = pivots.shape[:-1] + info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) + if kwargs["get_infos"] == True: + return lu, _pivots, info + return lu, _pivots @register_function(torch.lu_solve) def lu_solve(b, LU_data, LU_pivots, **kwargs): - # JAX pivots are offset by 1 compared to torch - _pivots = LU_pivots - 1 - x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b) - return x + # JAX pivots are offset by 1 compared to torch + _pivots = LU_pivots - 1 + x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b) + return x @register_function(torch.linalg.tensorsolve) def linalg_tensorsolve(A, b, dims=None): - # examples: - # A = torch.randn(2, 3, 6), b = torch.randn(3, 2) - # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6]) - # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3]) - # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6]) - # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3]) - # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6]) - # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6]) - - # torch allows b to be shaped differently. - # especially when axes are moved using dims. - # ValueError: After moving axes to end, leading shape of a must match shape of b. 
got a.shape=(3, 2, 6), b.shape=(2, 3) - # So we are handling the moveaxis and forcing b's shape to match what jax expects - if dims is not None: - A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) - dims = None - if A.shape[: b.ndim] != b.shape: - b = jnp.reshape(b, A.shape[: b.ndim]) - return jnp.linalg.tensorsolve(A, b, axes=dims) + # examples: + # A = torch.randn(2, 3, 6), b = torch.randn(3, 2) + # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6]) + # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3]) + # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6]) + # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3]) + # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6]) + # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6]) + + # torch allows b to be shaped differently. + # especially when axes are moved using dims. + # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3) + # So we are handling the moveaxis and forcing b's shape to match what jax expects + if dims is not None: + A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) + dims = None + if A.shape[: b.ndim] != b.shape: + b = jnp.reshape(b, A.shape[: b.ndim]) + return jnp.linalg.tensorsolve(A, b, axes=dims) @register_function(torch.nn.functional.linear) def functional_linear(self, weights, bias=None): - res = jnp.einsum("...a,ba->...b", self, weights) - if bias is not None: - res += bias - return res + res = jnp.einsum("...a,ba->...b", self, weights) + if bias is not None: + res += bias + return res @register_function(torch.nn.functional.interpolate) def functional_interpolate( - input, - size: Tuple[int, int], - scale_factor: Optional[float], - mode: str, - align_corners: bool, - recompute_scale_factor: bool, - antialias: bool, + input, + size: Tuple[int, int], + scale_factor: Optional[float], + mode: str, + align_corners: bool, + recompute_scale_factor: bool, + antialias: bool, ): - supported_methods = ( - "nearest", - "linear", - "bilinear", - "trilinear", - "cubic", - "bicubic", - "tricubic", - "lanczos3", - "lanczos5", + supported_methods = ( + "nearest", + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", + ) + is_jax_supported = mode in supported_methods + if not is_jax_supported: + raise torchax.tensor.OperatorNotFound( + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + ) + # None check + antialias = antialias or False + align_corners = align_corners or False + + if ( + mode in ("cubic", "bicubic", "tricubic") + and not antialias + and size is not None + ): + return jimage.interpolate_bicubic_no_aa( + input, + size[0], + size[1], + align_corners, + ) + else: + # fallback + raise torchax.tensor.OperatorNotFound( + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" ) - is_jax_supported = mode in supported_methods - if not is_jax_supported: - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. 
Supported modes are: {supported_methods}" - ) - # None check - antialias = antialias or False - align_corners = align_corners or False - - if ( - mode in ("cubic", "bicubic", "tricubic") - and not antialias - and size is not None - ): - return jimage.interpolate_bicubic_no_aa( - input, - size[0], - size[1], - align_corners, - ) - else: - # fallback - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" - ) @register_function(torch.Tensor.repeat_interleave) def torch_Tensor_repeat_interleave( - self, repeats, dim=None, *, output_size=None + self, repeats, dim=None, *, output_size=None ): - return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) + return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) diff --git a/torchax/torchax/ops/jtorchvision_nms.py b/torchax/torchax/ops/jtorchvision_nms.py index 401b248af345..6a639ddd6383 100644 --- a/torchax/torchax/ops/jtorchvision_nms.py +++ b/torchax/torchax/ops/jtorchvision_nms.py @@ -14,263 +14,259 @@ def _bbox_overlap(boxes, gt_boxes): - """Find Bounding box overlap. + """Find Bounding box overlap. - Args: - boxes: first set of bounding boxes - gt_boxes: second set of boxes to compute IOU + Args: + boxes: first set of bounding boxes + gt_boxes: second set of boxes to compute IOU - Returns: - iou: Intersection over union matrix of all input bounding boxes - """ - bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( - ary=boxes, indices_or_sections=4, axis=2 - ) - gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( - ary=gt_boxes, indices_or_sections=4, axis=2 - ) + Returns: + iou: Intersection over union matrix of all input bounding boxes + """ + bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( + ary=boxes, indices_or_sections=4, axis=2 + ) + gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( + ary=gt_boxes, indices_or_sections=4, axis=2 + ) - # Calculates the intersection area. - i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) - i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) - i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) - i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) - i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum( - (i_ymax - i_ymin), 0 - ) + # Calculates the intersection area. + i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) + i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) + i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) + i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) + i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0) - # Calculates the union area. - bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) - gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) - # Adds a small epsilon to avoid divide-by-zero. - u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 + # Calculates the union area. + bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) + gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) + # Adds a small epsilon to avoid divide-by-zero. + u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 - # Calculates IoU. - iou = i_area / u_area + # Calculates IoU. 
+ iou = i_area / u_area - return iou + return iou def _self_suppression(in_args): - iou, _, iou_sum = in_args - batch_size = iou.shape[0] - can_suppress_others = jnp.reshape( - jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1] - ).astype(iou.dtype) - iou_suppressed = ( - jnp.reshape( - (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype), - [batch_size, -1, 1], - ) - * iou + iou, _, iou_sum = in_args + batch_size = iou.shape[0] + can_suppress_others = jnp.reshape( + jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1] + ).astype(iou.dtype) + iou_suppressed = ( + jnp.reshape( + (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype), + [batch_size, -1, 1], ) - iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) - return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new + * iou + ) + iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) + return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new def _cross_suppression(in_args): - boxes, box_slice, iou_threshold, inner_idx = in_args - batch_size = boxes.shape[0] - new_slice = lax.dynamic_slice( - boxes, - [0, inner_idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4], + boxes, box_slice, iou_threshold, inner_idx = in_args + batch_size = boxes.shape[0] + new_slice = lax.dynamic_slice( + boxes, + [0, inner_idx * _NMS_TILE_SIZE, 0], + [batch_size, _NMS_TILE_SIZE, 4], + ) + iou = _bbox_overlap(new_slice, box_slice) + ret_slice = ( + jnp.expand_dims( + (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2 ) - iou = _bbox_overlap(new_slice, box_slice) - ret_slice = ( - jnp.expand_dims( - (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2 - ) - * box_slice - ) - return boxes, ret_slice, iou_threshold, inner_idx + 1 + * box_slice + ) + return boxes, ret_slice, iou_threshold, inner_idx + 1 def _suppression_loop_body(in_args): - """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). - - Args: - in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx - - Returns: - boxes: updated boxes. - iou_threshold: pass down iou_threshold to the next iteration. - output_size: the updated output_size. - idx: the updated induction variable. - """ - boxes, iou_threshold, output_size, idx = in_args - num_tiles = boxes.shape[1] // _NMS_TILE_SIZE - batch_size = boxes.shape[0] - - # Iterates over tiles that can possibly suppress the current tile. - box_slice = lax.dynamic_slice( - boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] - ) - - def _loop_cond(in_args): - _, _, _, inner_idx = in_args - return inner_idx < idx - - _, box_slice, _, _ = lax.while_loop( - _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0) - ) - - # Iterates over the current tile to compute self-suppression. - iou = _bbox_overlap(box_slice, box_slice) - mask = jnp.expand_dims( - jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) - > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), - 0, - ) - iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) - - def _loop_cond2(in_args): - _, loop_condition, _ = in_args - return loop_condition - - suppressed_iou, _, _ = lax.while_loop( - _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2])) - ) - suppressed_box = jnp.sum(suppressed_iou, 1) > 0 - box_slice *= jnp.expand_dims( - 1.0 - suppressed_box.astype(box_slice.dtype), 2 - ) - - # Uses box_slice to update the input boxes. 
- mask = jnp.reshape( - (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), - [1, -1, 1, 1], - ) - boxes = jnp.tile( - jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] - ) * mask + jnp.reshape( - boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4] - ) * (1 - mask) - boxes = jnp.reshape(boxes, [batch_size, -1, 4]) - - # Updates output_size. - output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) - return boxes, iou_threshold, output_size, idx + 1 + """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). + + Args: + in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx + + Returns: + boxes: updated boxes. + iou_threshold: pass down iou_threshold to the next iteration. + output_size: the updated output_size. + idx: the updated induction variable. + """ + boxes, iou_threshold, output_size, idx = in_args + num_tiles = boxes.shape[1] // _NMS_TILE_SIZE + batch_size = boxes.shape[0] + + # Iterates over tiles that can possibly suppress the current tile. + box_slice = lax.dynamic_slice( + boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] + ) + + def _loop_cond(in_args): + _, _, _, inner_idx = in_args + return inner_idx < idx + + _, box_slice, _, _ = lax.while_loop( + _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0) + ) + + # Iterates over the current tile to compute self-suppression. + iou = _bbox_overlap(box_slice, box_slice) + mask = jnp.expand_dims( + jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) + > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), + 0, + ) + iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) + + def _loop_cond2(in_args): + _, loop_condition, _ = in_args + return loop_condition + + suppressed_iou, _, _ = lax.while_loop( + _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2])) + ) + suppressed_box = jnp.sum(suppressed_iou, 1) > 0 + box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2) + + # Uses box_slice to update the input boxes. + mask = jnp.reshape( + (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), + [1, -1, 1, 1], + ) + boxes = jnp.tile( + jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * ( + 1 - mask + ) + boxes = jnp.reshape(boxes, [batch_size, -1, 4]) + + # Updates output_size. + output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) + return boxes, iou_threshold, output_size, idx + 1 def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): - """A wrapper that handles non-maximum suppression. - - Assumption: - * The boxes are sorted by scores unless the box is a dot (all coordinates - are zero). - * Boxes with higher scores can be used to suppress boxes with lower scores. - - The overal design of the algorithm is to handle boxes tile-by-tile: - - boxes = boxes.pad_to_multiply_of(tile_size) - num_tiles = len(boxes) // tile_size - output_boxes = [] - for i in range(num_tiles): - box_tile = boxes[i*tile_size : (i+1)*tile_size] - for j in range(i - 1): - suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] - iou = _bbox_overlap(box_tile, suppressing_tile) - # if the box is suppressed in iou, clear it to a dot - box_tile *= _update_boxes(iou) - # Iteratively handle the diagnal tile. 
- iou = _box_overlap(box_tile, box_tile) - iou_changed = True - while iou_changed: - # boxes that are not suppressed by anything else - suppressing_boxes = _get_suppressing_boxes(iou) - # boxes that are suppressed by suppressing_boxes - suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) - # clear iou to 0 for boxes that are suppressed, as they cannot be used - # to suppress other boxes any more - new_iou = _clear_iou(iou, suppressed_boxes) - iou_changed = (new_iou != iou) - iou = new_iou - # remaining boxes that can still suppress others, are selected boxes. - output_boxes.append(_get_suppressing_boxes(iou)) - if len(output_boxes) >= max_output_size: - break - - Args: - scores: a tensor with a shape of [batch_size, anchors]. - boxes: a tensor with a shape of [batch_size, anchors, 4]. - max_output_size: a scalar integer `Tensor` representing the maximum number - of boxes to be selected by non max suppression. - iou_threshold: a float representing the threshold for deciding whether boxes - overlap too much with respect to IOU. - Returns: - nms_scores: a tensor with a shape of [batch_size, anchors]. It has same - dtype as input scores. - nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has - same dtype as input boxes. - """ - batch_size = boxes.shape[0] - num_boxes = boxes.shape[1] - pad = ( - int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - - num_boxes - ) - boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) - scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) - num_boxes += pad - - def _loop_cond(in_args): - unused_boxes, unused_threshold, output_size, idx = in_args - return jnp.logical_and( - jnp.min(output_size) < max_output_size, - idx < num_boxes // _NMS_TILE_SIZE, - ) - - selected_boxes, _, output_size, _ = lax.while_loop( - _loop_cond, - _suppression_loop_body, - (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0), - ) - idx = num_boxes - lax.top_k( - jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) - * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), - max_output_size, - )[0].astype(jnp.int32) - idx = jnp.minimum(idx, num_boxes - 1) - idx = jnp.reshape( - idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1] + """A wrapper that handles non-maximum suppression. + + Assumption: + * The boxes are sorted by scores unless the box is a dot (all coordinates + are zero). + * Boxes with higher scores can be used to suppress boxes with lower scores. + + The overal design of the algorithm is to handle boxes tile-by-tile: + + boxes = boxes.pad_to_multiply_of(tile_size) + num_tiles = len(boxes) // tile_size + output_boxes = [] + for i in range(num_tiles): + box_tile = boxes[i*tile_size : (i+1)*tile_size] + for j in range(i - 1): + suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] + iou = _bbox_overlap(box_tile, suppressing_tile) + # if the box is suppressed in iou, clear it to a dot + box_tile *= _update_boxes(iou) + # Iteratively handle the diagnal tile. 
+ iou = _box_overlap(box_tile, box_tile) + iou_changed = True + while iou_changed: + # boxes that are not suppressed by anything else + suppressing_boxes = _get_suppressing_boxes(iou) + # boxes that are suppressed by suppressing_boxes + suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) + # clear iou to 0 for boxes that are suppressed, as they cannot be used + # to suppress other boxes any more + new_iou = _clear_iou(iou, suppressed_boxes) + iou_changed = (new_iou != iou) + iou = new_iou + # remaining boxes that can still suppress others, are selected boxes. + output_boxes.append(_get_suppressing_boxes(iou)) + if len(output_boxes) >= max_output_size: + break + + Args: + scores: a tensor with a shape of [batch_size, anchors]. + boxes: a tensor with a shape of [batch_size, anchors, 4]. + max_output_size: a scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + Returns: + nms_scores: a tensor with a shape of [batch_size, anchors]. It has same + dtype as input scores. + nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has + same dtype as input boxes. + """ + batch_size = boxes.shape[0] + num_boxes = boxes.shape[1] + pad = ( + int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE + - num_boxes + ) + boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) + scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) + num_boxes += pad + + def _loop_cond(in_args): + unused_boxes, unused_threshold, output_size, idx = in_args + return jnp.logical_and( + jnp.min(output_size) < max_output_size, + idx < num_boxes // _NMS_TILE_SIZE, ) - return idx - boxes = jnp.reshape( - (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4] - ) - boxes = boxes * ( - jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) - < jnp.reshape(output_size, [-1, 1, 1]) - ).astype(boxes.dtype) - scores = jnp.reshape( - jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size] - ) - scores = scores * ( - jnp.reshape(jnp.arange(max_output_size), [1, -1]) - < jnp.reshape(output_size, [-1, 1]) - ).astype(scores.dtype) - return scores, boxes + selected_boxes, _, output_size, _ = lax.while_loop( + _loop_cond, + _suppression_loop_body, + (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0), + ) + idx = num_boxes - lax.top_k( + jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) + * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), + max_output_size, + )[0].astype(jnp.int32) + idx = jnp.minimum(idx, num_boxes - 1) + idx = jnp.reshape( + idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1] + ) + + return idx + boxes = jnp.reshape( + (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4] + ) + boxes = boxes * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) + < jnp.reshape(output_size, [-1, 1, 1]) + ).astype(boxes.dtype) + scores = jnp.reshape( + jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size] + ) + scores = scores * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1]) + < jnp.reshape(output_size, [-1, 1]) + ).astype(scores.dtype) + return scores, boxes # registry: def nms(boxes, scores, iou_threshold): - max_output_size = boxes.shape[0] - boxes = boxes.reshape((1, *boxes.shape)) - scores = scores.reshape((1, *scores.shape)) - res = non_max_suppression_padded( - scores, boxes, max_output_size, iou_threshold - ) - return 
res + max_output_size = boxes.shape[0] + boxes = boxes.reshape((1, *boxes.shape)) + scores = scores.reshape((1, *scores.shape)) + res = non_max_suppression_padded( + scores, boxes, max_output_size, iou_threshold + ) + return res try: - import torch - import torchvision + import torch + import torchvision - ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) + ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) except Exception: - pass + pass diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index d363e19f1a09..c5dc29cc8543 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -8,95 +8,95 @@ def t2j(t, use_dlpack=True): - is_bool = False - if t.dtype == torch.bool: - is_bool = True - t = t.to(torch.int8) + is_bool = False + if t.dtype == torch.bool: + is_bool = True + t = t.to(torch.int8) + + t = t.to_dense() + + if not t.is_contiguous(): + t = t.contiguous() + + res = None + if use_dlpack: + try: + res = jaxdl.from_dlpack(t) + except Exception: + pass + + if res is None: + # https://github.com/google/jax/issues/7657 + # https://github.com/google/jax/issues/17784 + if t.dtype == torch.bfloat16: + nparray = ( + t.cpu().detach().to(torch.float32).numpy() + ) # numpy don't support bfloat16 + else: + nparray = t.cpu().detach().numpy() + res = jnp.asarray(nparray) + if t.dtype == torch.bfloat16: + res = res.astype(jnp.bfloat16) + + if is_bool: + res = res.astype(jnp.bool_) + return res - t = t.to_dense() - - if not t.is_contiguous(): - t = t.contiguous() +def j2t(x, use_dlpack=True): + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): res = None if use_dlpack: - try: - res = jaxdl.from_dlpack(t) - except Exception: - pass - - if res is None: - # https://github.com/google/jax/issues/7657 - # https://github.com/google/jax/issues/17784 - if t.dtype == torch.bfloat16: - nparray = ( - t.cpu().detach().to(torch.float32).numpy() - ) # numpy don't support bfloat16 - else: - nparray = t.cpu().detach().numpy() - res = jnp.asarray(nparray) - if t.dtype == torch.bfloat16: - res = res.astype(jnp.bfloat16) - - if is_bool: - res = res.astype(jnp.bool_) - return res - - -def j2t(x, use_dlpack=True): - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + try: + dl = jaxdl.to_dlpack(x) + res = torchdl.from_dlpack(dl) + except Exception: res = None - if use_dlpack: - try: - dl = jaxdl.to_dlpack(x) - res = torchdl.from_dlpack(dl) - except Exception: - res = None - orig_dtype = None - if res is None: - orig_dtype = None - if x.dtype == jnp.bfloat16.dtype: - orig_dtype = x.dtype - x = x.astype(jnp.float32.dtype) - res = torch.from_numpy(numpy.asarray(x)) + orig_dtype = None + if res is None: + orig_dtype = None + if x.dtype == jnp.bfloat16.dtype: + orig_dtype = x.dtype + x = x.astype(jnp.float32.dtype) + res = torch.from_numpy(numpy.asarray(x)) - if x.dtype == jnp.bool_: - res = res.to(torch.bool) + if x.dtype == jnp.bool_: + res = res.to(torch.bool) - if orig_dtype is not None: - res = res.to(j2t_dtype(orig_dtype)) - return res + if orig_dtype is not None: + res = res.to(j2t_dtype(orig_dtype)) + return res TORCH_DTYPE_TO_JAX = { - # NO_MAPPING : jnp.float0.dtype (signless scalar int), - torch.bool: jnp.bool_.dtype, - # NO_MAPPING : jnp.int4.dtype, - torch.int8: jnp.int8.dtype, - torch.int16: jnp.int16.dtype, - torch.int32: jnp.int32.dtype, - torch.int64: jnp.int64.dtype, - torch.long: jnp.int64.dtype, - # NO_MAPPING : jnp.uint4 - torch.uint8: jnp.uint8.dtype, - torch.uint16: 
jnp.uint16.dtype, - torch.uint32: jnp.uint32.dtype, - torch.uint64: jnp.uint64.dtype, - # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, - torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype, - # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, - torch.float8_e5m2: jnp.float8_e5m2.dtype, - # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, - torch.bfloat16: jnp.bfloat16.dtype, - torch.half: jnp.float16.dtype, - torch.float16: jnp.float16.dtype, - torch.float32: jnp.float32.dtype, - torch.float64: jnp.float64.dtype, - torch.double: jnp.double.dtype, - torch.complex64: jnp.complex64.dtype, - torch.complex128: jnp.complex128.dtype, - None: None, + # NO_MAPPING : jnp.float0.dtype (signless scalar int), + torch.bool: jnp.bool_.dtype, + # NO_MAPPING : jnp.int4.dtype, + torch.int8: jnp.int8.dtype, + torch.int16: jnp.int16.dtype, + torch.int32: jnp.int32.dtype, + torch.int64: jnp.int64.dtype, + torch.long: jnp.int64.dtype, + # NO_MAPPING : jnp.uint4 + torch.uint8: jnp.uint8.dtype, + torch.uint16: jnp.uint16.dtype, + torch.uint32: jnp.uint32.dtype, + torch.uint64: jnp.uint64.dtype, + # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, + torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype, + # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, + torch.float8_e5m2: jnp.float8_e5m2.dtype, + # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, + torch.bfloat16: jnp.bfloat16.dtype, + torch.half: jnp.float16.dtype, + torch.float16: jnp.float16.dtype, + torch.float32: jnp.float32.dtype, + torch.float64: jnp.float64.dtype, + torch.double: jnp.double.dtype, + torch.complex64: jnp.complex64.dtype, + torch.complex128: jnp.complex128.dtype, + None: None, } JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} @@ -106,16 +106,16 @@ def j2t(x, use_dlpack=True): def t2j_dtype(dtype): - if dtype not in TORCH_DTYPE_TO_JAX: - raise RuntimeError( - f"Attempting to convert unknown type: {dtype} to jax type," - ) - return TORCH_DTYPE_TO_JAX[dtype] + if dtype not in TORCH_DTYPE_TO_JAX: + raise RuntimeError( + f"Attempting to convert unknown type: {dtype} to jax type," + ) + return TORCH_DTYPE_TO_JAX[dtype] def j2t_dtype(dtype): - if dtype not in JAX_DTYPE_TO_TORCH: - raise RuntimeError( - f"Attempting to convert unknown type: {dtype} to torch type," - ) - return JAX_DTYPE_TO_TORCH[dtype] + if dtype not in JAX_DTYPE_TO_TORCH: + raise RuntimeError( + f"Attempting to convert unknown type: {dtype} to torch type," + ) + return JAX_DTYPE_TO_TORCH[dtype] diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py index 06cdd4b4806a..78b63771d379 100644 --- a/torchax/torchax/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -12,125 +12,125 @@ class InplaceOp: - def __init__( - self, - functional_op, - replace=False, - position_to_mutate=0, - is_jax_func=False, - ): - self.functional = functional_op - self.replace = replace - self.position_to_mutate = position_to_mutate - self.is_jax_func = is_jax_func - - def __call__(self, *args, **kwargs): - to_mutate = args[self.position_to_mutate] - view_value = to_mutate - if isinstance(to_mutate, View): - view_value = to_mutate.torch() - # Convert the target View to a Tensor, and - # leave the rest args as is. If other args are - # also View, they will be converted to tensors - # in the self.functional dispatch. 
- env = view_value._env - if self.is_jax_func: - view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs)) - new_value_jax = self.functional(view_value, *args[1:], **kwargs) - new_value = env.j2t_iso(new_value_jax) - else: - new_value = self.functional(view_value, *args[1:], **kwargs) - - if isinstance(to_mutate, View): - to_mutate.update(new_value) - else: - if self.replace: - to_mutate._elem = new_value._elem - else: - to_mutate.copy_(new_value) - return to_mutate + def __init__( + self, + functional_op, + replace=False, + position_to_mutate=0, + is_jax_func=False, + ): + self.functional = functional_op + self.replace = replace + self.position_to_mutate = position_to_mutate + self.is_jax_func = is_jax_func + + def __call__(self, *args, **kwargs): + to_mutate = args[self.position_to_mutate] + view_value = to_mutate + if isinstance(to_mutate, View): + view_value = to_mutate.torch() + # Convert the target View to a Tensor, and + # leave the rest args as is. If other args are + # also View, they will be converted to tensors + # in the self.functional dispatch. + env = view_value._env + if self.is_jax_func: + view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs)) + new_value_jax = self.functional(view_value, *args[1:], **kwargs) + new_value = env.j2t_iso(new_value_jax) + else: + new_value = self.functional(view_value, *args[1:], **kwargs) + + if isinstance(to_mutate, View): + to_mutate.update(new_value) + else: + if self.replace: + to_mutate._elem = new_value._elem + else: + to_mutate.copy_(new_value) + return to_mutate class OutVariant: - def __call__(self, *args, **kwargs): - to_mutate = kwargs["out"] - del kwargs["out"] - to_mutate._elem = self.functional(*args, **kwargs)._elem - return to_mutate + def __call__(self, *args, **kwargs): + to_mutate = kwargs["out"] + del kwargs["out"] + to_mutate._elem = self.functional(*args, **kwargs)._elem + return to_mutate P = ParamSpec("P") def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. + """Converts `dtype` kwarg of function from torch to JAX. - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. - Returns: - A decorator that wraps a JAX implementation of a torch function. - """ + Returns: + A decorator that wraps a JAX implementation of a torch function. 
+ """ - def decorator(func: types.TorchCallable): - @functools.wraps(func) - def wrapper( - *args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs, - ): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - if isinstance(dtype, torch.dtype): - jax_dtype = mappings.t2j_dtype(dtype) - else: - jax_dtype = dtype + def decorator(func: types.TorchCallable): + @functools.wraps(func) + def wrapper( + *args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs, + ): + if not dtype and use_default_dtype: + dtype = torch.get_default_dtype() + if isinstance(dtype, torch.dtype): + jax_dtype = mappings.t2j_dtype(dtype) + else: + jax_dtype = dtype - return func(*args, dtype=jax_dtype, **kwargs) + return func(*args, dtype=jax_dtype, **kwargs) - return wrapper + return wrapper - return decorator + return decorator def maybe_convert_constant_dtype( - val: Optional[types.JaxValue], dtype: Optional[jnp.dtype] + val: Optional[types.JaxValue], dtype: Optional[jnp.dtype] ): - """Optionally converts scalar constant's dtype using `numpy` + """Optionally converts scalar constant's dtype using `numpy` - Use in cases where you require a constant and can't handle a traced array. - """ - if val and dtype: - if isinstance(val, jax.Array): - return maybe_convert_constant_dtype(val.item(), dtype) + Use in cases where you require a constant and can't handle a traced array. + """ + if val and dtype: + if isinstance(val, jax.Array): + return maybe_convert_constant_dtype(val.item(), dtype) - return np.array(val, dtype) + return np.array(val, dtype) - return val + return val def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]): - """If the first argument is an int array, promote it to float32.""" + """If the first argument is an int array, promote it to float32.""" - @functools.wraps(f) - def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): - if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - x = x.astype(mappings.t2j_dtype(torch.get_default_dtype())) + @functools.wraps(f) + def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): + if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: + x = x.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return f(x, *args, **kwargs) + return f(x, *args, **kwargs) - return wrapper + return wrapper def foreach_loop( - seq: jax.Array, - fn: Callable[[jax.Array, jax.Array], jax.Array], - init_val=0.0, + seq: jax.Array, + fn: Callable[[jax.Array, jax.Array], jax.Array], + init_val=0.0, ): - """Run `fn` for each element of 1D array `seq`. + """Run `fn` for each element of 1D array `seq`. 
- Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" - assert len(seq.shape) == 1 - return jax.lax.fori_loop( - 0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val - ) + Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" + assert len(seq.shape) == 1 + return jax.lax.fori_loop( + 0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val + ) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py index 4d8cb770a72c..7ba29cdcbd3b 100644 --- a/torchax/torchax/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -7,12 +7,12 @@ @dataclasses.dataclass class Operator: - torch_op: TorchCallable - func: Union[TorchCallable, JaxCallable] - is_jax_function: bool - is_user_defined: bool - needs_env: bool - is_view_op: bool + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool + is_view_op: bool all_aten_ops: Dict[TorchCallable, Operator] = {} @@ -20,42 +20,42 @@ class Operator: def register_torch_dispatch_op( + aten_op, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): + op = Operator( aten_op, impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False, -): - op = Operator( - aten_op, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op, - ) - if aten_op in all_aten_ops: - logging.warning(f"Duplicate op registration for {aten_op}") - all_aten_ops[aten_op] = op - return impl_callable + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) + if aten_op in all_aten_ops: + logging.warning(f"Duplicate op registration for {aten_op}") + all_aten_ops[aten_op] = op + return impl_callable def register_torch_function_op( + torch_func, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): + op = Operator( torch_func, impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False, -): - op = Operator( - torch_func, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op, - ) - all_torch_functions[torch_func] = op - return impl_callable + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) + all_torch_functions[torch_func] = op + return impl_callable diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index dfbc851a55ce..f2b87a87d269 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -22,292 +22,285 @@ class OperatorNotFound(Exception): - pass + pass def wrap(jaxarray): - return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) + return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) def unwrap(torchtensors): - return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) + return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) @contextlib.contextmanager def log_nested(env, message): - if env.config.debug_print_each_op: - print((" " * log_nested.level) + message, file=sys.stderr) - log_nested.level += 1 - yield - log_nested.level -= 1 + if env.config.debug_print_each_op: + print((" " * log_nested.level) + message, file=sys.stderr) + log_nested.level += 1 + yield + 
log_nested.level -= 1 log_nested.level = 0 class Tensor(torch.Tensor): - @staticmethod - def __new__(cls, elem, env): - dtype = mappings.j2t_dtype(elem.dtype) - shape = list(elem.shape) - for i, s in enumerate(shape): - if not isinstance(s, int): - shape[i] = 1 - if dtype is None: - dtype = torch.float32 - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device="meta", - requires_grad=False, - ) + @staticmethod + def __new__(cls, elem, env): + dtype = mappings.j2t_dtype(elem.dtype) + shape = list(elem.shape) + for i, s in enumerate(shape): + if not isinstance(s, int): + shape[i] = 1 + if dtype is None: + dtype = torch.float32 + return torch.Tensor._make_wrapper_subclass( + cls, + shape, + dtype=dtype, + device="meta", + requires_grad=False, + ) - def __init__(self, elem: jax.Array, env: "Environment"): - super().__init__() - self._elem = elem - self._env = env + def __init__(self, elem: jax.Array, env: "Environment"): + super().__init__() + self._elem = elem + self._env = env - def __str__(self): - return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) + def __str__(self): + return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) - __repr__ = __str__ + __repr__ = __str__ - def __jax_array__(self): - return self._elem + def __jax_array__(self): + return self._elem - @property - def shape(self): - return torch.Size(self._elem.shape) + @property + def shape(self): + return torch.Size(self._elem.shape) - @property - def ndim(self): - return len(self._elem.shape) + @property + def ndim(self): + return len(self._elem.shape) - def flatten(self, start_dim=0, end_dim=-1): - if end_dim == -1: - end_dim = self.ndim - new_shape = ( - self._elem.shape[:start_dim] - + (-1,) - + self._elem.shape[end_dim + 1 :] - ) - new_elem = jnp.reshape(self._elem, new_shape) - return Tensor(new_elem, self._env) - # return torch.reshape(self, new_shape) - - def __setitem__(self, key, val): - key, val = self._env.t2j_iso((key, val)) - self._elem = self._elem.at[key].set(val) - - def type_as(self, other): - self._elem = self._elem.astype(other._elem.dtype) - return self - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # TODO(hanq): figure out why is dispatch mode not sufficient - if func == torch.ops._c10d_functional.wait_tensor.default: - return args[0]._env.dispatch(func, types, args, kwargs) - raise AssertionError( - "torchax Tensors can only do math within the torchax environment." - "Please wrap your code with `with torchax.default_env()` or " - "call torchax.enable_globally() before." 
- ) + def flatten(self, start_dim=0, end_dim=-1): + if end_dim == -1: + end_dim = self.ndim + new_shape = ( + self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] + ) + new_elem = jnp.reshape(self._elem, new_shape) + return Tensor(new_elem, self._env) + # return torch.reshape(self, new_shape) + + def __setitem__(self, key, val): + key, val = self._env.t2j_iso((key, val)) + self._elem = self._elem.at[key].set(val) + + def type_as(self, other): + self._elem = self._elem.astype(other._elem.dtype) + return self + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # TODO(hanq): figure out why is dispatch mode not sufficient + if func == torch.ops._c10d_functional.wait_tensor.default: + return args[0]._env.dispatch(func, types, args, kwargs) + raise AssertionError( + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) - def detach(self): - return Tensor(jax.lax.stop_gradient(self.jax()), self._env) + def detach(self): + return Tensor(jax.lax.stop_gradient(self.jax()), self._env) - def numpy(self) -> numpy.ndarray: - import numpy as np + def numpy(self) -> numpy.ndarray: + import numpy as np - return np.array(self._elem) + return np.array(self._elem) - def jax(self) -> jax.Array: - return self._elem + def jax(self) -> jax.Array: + return self._elem - def torch(self) -> torch.Tensor: - return self._env.j2t_copy(self.jax()) + def torch(self) -> torch.Tensor: + return self._env.j2t_copy(self.jax()) - @property - def dtype(self): - return mappings.j2t_dtype(self._elem.dtype) + @property + def dtype(self): + return mappings.j2t_dtype(self._elem.dtype) - def dim(self): - return self.ndim + def dim(self): + return self.ndim - @property - def device(self): - return torch.device("jax:0") + @property + def device(self): + return torch.device("jax:0") - @property - def jax_device(self): - return self._elem.device + @property + def jax_device(self): + return self._elem.device - @property - def data(self): - logger.warn( - "In-place to .data modifications still results a copy on TPU" - ) - return self + @property + def data(self): + logger.warn("In-place to .data modifications still results a copy on TPU") + return self - @data.setter - def data(self, other): - if isinstance(other, Tensor): - self._elem = other._elem + @data.setter + def data(self, other): + if isinstance(other, Tensor): + self._elem = other._elem - def apply_jax(self, jax_function, *args, **kwargs): - # Call a jax function on _elem - res = jax_function(self._elem, *args, **kwargs) - return self._env.j2t_iso(res) + def apply_jax(self, jax_function, *args, **kwargs): + # Call a jax function on _elem + res = jax_function(self._elem, *args, **kwargs) + return self._env.j2t_iso(res) - def apply_jax_(self, jax_function, *args, **kwargs): - self._elem = jax_function(self._elem, *args, **kwargs) - return self + def apply_jax_(self, jax_function, *args, **kwargs): + self._elem = jax_function(self._elem, *args, **kwargs) + return self - def tolist(self): - return self._elem.tolist() + def tolist(self): + return self._elem.tolist() - def shard_(self, sharding): - self.apply_jax_(jax.lax.with_sharding_constraint, sharding) + def shard_(self, sharding): + self.apply_jax_(jax.lax.with_sharding_constraint, sharding) def debug_accuracy(func, args, kwargs, current_output): - args_torch, kwargs_torch, out_torch = 
torch_pytree.tree_map_only( - torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output) - ) + args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( + torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output) + ) - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - if "device" in kwargs_torch: - kwargs_torch["device"] = "cpu" # do the torch native for comparison - expected_out = func(*args_torch, **kwargs_torch) + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + if "device" in kwargs_torch: + kwargs_torch["device"] = "cpu" # do the torch native for comparison + expected_out = func(*args_torch, **kwargs_torch) - flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) - flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) + flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) + flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) - for ex, real in zip(flattened_expected_out, flattened_current_out): - if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype: - ex = ex.to(real.dtype) - try: - if isinstance(ex, torch.Tensor) and not torch.allclose( - ex, real, atol=1e-3, equal_nan=True - ): - import pdb + for ex, real in zip(flattened_expected_out, flattened_current_out): + if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype: + ex = ex.to(real.dtype) + try: + if isinstance(ex, torch.Tensor) and not torch.allclose( + ex, real, atol=1e-3, equal_nan=True + ): + import pdb - pdb.set_trace() - except: - import pdb + pdb.set_trace() + except: + import pdb - pdb.set_trace() + pdb.set_trace() - return True + return True def _make_debug_msg(is_dispatch, log_args, func, args, kwargs): - def _display(a): - if isinstance(a, torch.Tensor): - return f"Tensor of {type(a)}: {a.dtype}{a.shape}" - elif isinstance(a, jax.Array): - return f"Jax Array of {type(a)}: {a.dtype}{a.shape}" - else: - return str(a) - - kwargs = kwargs or {} - title = "DISPATCH" if is_dispatch else "FUNCTION" - args_msg = ( - "args: " + ",".join(_display(a) for a in args) if log_args else "" - ) - kwargs_msg = ( - "kwargs: " - + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) - if log_args - else "" - ) - return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" + def _display(a): + if isinstance(a, torch.Tensor): + return f"Tensor of {type(a)}: {a.dtype}{a.shape}" + elif isinstance(a, jax.Array): + return f"Jax Array of {type(a)}: {a.dtype}{a.shape}" + else: + return str(a) + + kwargs = kwargs or {} + title = "DISPATCH" if is_dispatch else "FUNCTION" + args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else "" + kwargs_msg = ( + "kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) + if log_args + else "" + ) + return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" class XLAFunctionMode(torch.overrides.TorchFunctionMode): - """Context manager that dispatches torch function calls to JAX.""" - - def __init__(self, env): - self.env = env - - def __torch_function__( - self, func, types, args=(), kwargs=None - ) -> torch.Tensor: - message = f"FUNCTION: {_name_of_func(func)}" - if self.env.config.debug_print_each_op_operands: - message = message + "f" - message = _make_debug_msg( - False, - self.env.config.debug_print_each_op_operands, - func, - args, - kwargs, - ) - with log_nested(self.env, message): - try: - return self.env.dispatch(func, types, args, kwargs) - except OperatorNotFound: - pass - if _name_of_func(func) in ( - "rot90" - ): # skip rot90 
with k%4==0 due to no change - if len(args) >= 2 and type(args[1]) == int: - if (args[1]) % 4 == 0: - return args[0] - return func(*args, **(kwargs or {})) + """Context manager that dispatches torch function calls to JAX.""" + + def __init__(self, env): + self.env = env + + def __torch_function__( + self, func, types, args=(), kwargs=None + ) -> torch.Tensor: + message = f"FUNCTION: {_name_of_func(func)}" + if self.env.config.debug_print_each_op_operands: + message = message + "f" + message = _make_debug_msg( + False, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) + with log_nested(self.env, message): + try: + return self.env.dispatch(func, types, args, kwargs) + except OperatorNotFound: + pass + if _name_of_func(func) in ( + "rot90" + ): # skip rot90 with k%4==0 due to no change + if len(args) >= 2 and type(args[1]) == int: + if (args[1]) % 4 == 0: + return args[0] + return func(*args, **(kwargs or {})) class XLADispatchMode(torch_dispatch.TorchDispatchMode): - def __init__(self, env): - self.env = env - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - message = _make_debug_msg( - True, - self.env.config.debug_print_each_op_operands, - func, - args, - kwargs, - ) - with log_nested(self.env, message): - if isinstance(func, torch._ops.OpOverloadPacket): - with self: - return func(*args, **kwargs) - # Only functions under these namespaces will be intercepted - if func.namespace not in ( - "aten", - "_c10d_functional", - "torchvision", - "xla", - ): - return func(*args, **kwargs) - return self.env.dispatch(func, types, args, kwargs) + def __init__(self, env): + self.env = env + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + message = _make_debug_msg( + True, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) + with log_nested(self.env, message): + if isinstance(func, torch._ops.OpOverloadPacket): + with self: + return func(*args, **kwargs) + # Only functions under these namespaces will be intercepted + if func.namespace not in ( + "aten", + "_c10d_functional", + "torchvision", + "xla", + ): + return func(*args, **kwargs) + return self.env.dispatch(func, types, args, kwargs) def _name_of_func(func): - if hasattr(func, "name"): - return func.name() - return func.__name__ + if hasattr(func, "name"): + return func.name() + return func.__name__ # Constructors that don't take other tensor as input TENSOR_CONSTRUCTORS = { - torch.ones, - torch.zeros, - torch.empty, - torch.empty_strided, - torch.tensor, - torch.arange, - torch.eye, - torch.randn, - torch.rand, - torch.randint, - torch.full, - torch.as_tensor, + torch.ones, + torch.zeros, + torch.empty, + torch.empty_strided, + torch.tensor, + torch.arange, + torch.eye, + torch.randn, + torch.rand, + torch.randint, + torch.full, + torch.as_tensor, } # TODO(wen): use existing types, either from torch or jax @@ -315,432 +308,421 @@ def _name_of_func(func): class Environment(contextlib.ContextDecorator): - """This class holds a set of configurations and "globals" needed + """This class holds a set of configurations and "globals" needed - for executing torch program using jax. - Things included so far: + for executing torch program using jax. + Things included so far: - op registry - PRNGKey - Configs + op registry + PRNGKey + Configs - Also helper functions to manipulate those. - """ + Also helper functions to manipulate those. 
+ """ - def __init__(self, configuration=None): - self._function_mode = XLAFunctionMode(self) - self._dispatch_mode = XLADispatchMode(self) + def __init__(self, configuration=None): + self._function_mode = XLAFunctionMode(self) + self._dispatch_mode = XLADispatchMode(self) - # name is torch callable - self._ops = {} - self._decomps = {} + # name is torch callable + self._ops = {} + self._decomps = {} - self.load_ops() + self.load_ops() - self._mesh = None - self.config = configuration or config.Configuration() + self._mesh = None + self.config = configuration or config.Configuration() - self._manually_entered = False - self.enabled = False + self._manually_entered = False + self.enabled = False - self._prng_key = mutable_array( - jax.random.key(torch.initial_seed() % (1 << 63)) - ) - self.autocast_dtype = None - self._target_device = "cpu" - - @property - def target_device(self): - return self._target_device - - @target_device.setter - def target_device(self, device: str): - self._target_device = device.lower() - - def manual_seed(self, key): - self._prng_key = mutable_array(jax.random.key(key)) - - @property - def prng_key(self): - return self._prng_key[...] - - def get_as_jax_device(self, device: Any): - if device is None: - device = torch.get_default_device() - - if isinstance(device, torch.device): - device = str(device) - - if ( - not self.config.use_torch_native_for_cpu_tensor - and device.startswith("cpu") - ): - return jax.devices("cpu")[0] - - if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): - return jax.local_devices()[0] - - if device.startswith("xla"): - return jax.local_devices()[0] - - # TODO (wen): jax is NOT a device type, - # once we can register more than one backend, revisit - if device.startswith("jax"): - match self.target_device: - case "cpu": - return jax.devices("cpu")[0] - case "tpu": - return jax.devices("tpu")[0] - case _: - raise AttributeError( - f"Cannot handle env.target_device {self.target_device}" - ) - - return None # fallback to torch - - def load_ops(self): - from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms - - for k, v in itertools.chain( - ops_registry.all_aten_ops.items(), - ops_registry.all_torch_functions.items(), - ): - if v.is_jax_function: - self._ops[k] = v - else: - self._decomps[k] = v - - from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION - - for k, v in DECOMPOSITIONS.items(): - if k not in self._decomps: - self._decomps[k] = ops_registry.Operator( - k, - v, - is_jax_function=False, - is_user_defined=False, - needs_env=False, - is_view_op=k in MUTABLE_DECOMPOSITION, - ) - - def _get_op_or_decomp(self, func): - def _get_from_dict(op_dict, op): - op = op_dict.get(func) - if op is None and isinstance(func, torch._ops.OpOverloadPacket): - op = op_dict.get(func.default) - if op is None and isinstance(func, torch._ops.OpOverload): - op = op_dict.get(func.overloadpacket) - return op - - op = _get_from_dict(self._ops, func) - - if op is None: - # fallback to decompose - op = _get_from_dict(self._decomps, func) - - if op is None: - raise OperatorNotFound( - f"Operator with name {_name_of_func(func)} has no lowering" - ) + self._prng_key = mutable_array( + jax.random.key(torch.initial_seed() % (1 << 63)) + ) + self.autocast_dtype = None + self._target_device = "cpu" - return op - - def _to_copy(self, the_tensor, new_dtype, new_device): - if isinstance(the_tensor, View): - the_tensor = the_tensor.torch() - - if isinstance(the_tensor, Tensor): - arr = the_tensor.jax() - - if new_dtype is not None 
and new_dtype != arr.dtype: - arr = arr.astype(mappings.t2j_dtype(new_dtype)) - - if new_device is not None: - match str(new_device).lower(): - case "cpu": - # converting to a non-jax device: let torch native handle it - torch_tensor = ( - self.j2t_copy(arr) - if isinstance(the_tensor, Tensor) - else arr - ) - with ( - mode_utils.no_dispatch(), - torch._C.DisableTorchFunction(), - ): - return torch_tensor.to(new_device) - case "jax": - # move torchax.tensor / jax tensor between devices - # I don't know ifgit this will work after the model is jitted - if self.target_device != the_tensor.jax_device.platform: - arr = jax.device_put( - the_tensor.jax(), - jax.devices(self.target_device)[0], - ) - return Tensor(arr, self) - case _: - logging.error( - f"torchax.Tenosr cannot handle device {new_device}" - ) + @property + def target_device(self): + return self._target_device - else: - if new_dtype is not None and new_dtype != the_tensor.dtype: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - the_tensor = the_tensor.to(new_dtype) + @target_device.setter + def target_device(self, device: str): + self._target_device = device.lower() - if new_device is None: ## device is None means don't change device - return the_tensor + def manual_seed(self, key): + self._prng_key = mutable_array(jax.random.key(key)) - jax_device = self.get_as_jax_device(new_device) - if jax_device: - arr = self.t2j_copy(the_tensor) - arr = jax.device_put(arr, jax_device) - else: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return the_tensor.to(new_device) + @property + def prng_key(self): + return self._prng_key[...] - return Tensor(arr, self) + def get_as_jax_device(self, device: Any): + if device is None: + device = torch.get_default_device() - def get_and_rotate_prng_key( - self, generator: Optional[torch.Generator] = None + if isinstance(device, torch.device): + device = str(device) + + if not self.config.use_torch_native_for_cpu_tensor and device.startswith( + "cpu" ): - if generator is not None: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - self._prng_key[...] = jax.random.key( - generator.initial_seed() % (2**63) - ) - old_key = self._prng_key[...] - new_prng_key, next_key = jax.random.split(old_key) - self._prng_key[...] 
= new_prng_key - return next_key - - def _handle_tensor_constructor(self, func, args, kwargs): - device = kwargs.get("device") - jax_device = self.get_as_jax_device(device) - # TODO(qihqi) figure out better ways for device propagation - if not self._manually_entered and jax_device is None: - # let torch handle it - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return func(*args, **kwargs) - with jax.default_device(jax_device): - requires_grad = kwargs.get("requires_grad", False) - op = self._get_op_or_decomp(func) - res = op.func(*args, **kwargs) - if isinstance(res, jax.Array): - res = Tensor(res, self) - if requires_grad: - res.requires_grad = True - return res - - def _torch_Tensor_to(self, args, kwargs): - the_tensor = args[0] - args = args[1:] - if len(args) >= 1 and isinstance(args[0], torch.Tensor): - dtype = args[0].dtype - device = args[0].device - return self._to_copy(the_tensor, dtype, device) - device = kwargs.get("device") - dtype = kwargs.get("dtype") - # args like pin_memory etc that we will ignore - args = list(filter(lambda x: not isinstance(x, bool), args)) - if len(args) >= 2: - device, dtype, *_ = args - elif len(args) == 1 and isinstance(args[0], torch.dtype): - dtype = args[0] - elif len(args) == 1: - device = args[0] - return self._to_copy(the_tensor, dtype, device) - - def dispatch(self, func, types, args, kwargs): - kwargs = kwargs or {} - if func in TENSOR_CONSTRUCTORS: - return self._handle_tensor_constructor(func, args, kwargs) - if func in ( - torch.Tensor.to, - torch.ops.aten.lift_fresh.default, - torch.ops.aten._to_copy, - torch.ops.aten._to_copy.default, - ): - return self._torch_Tensor_to(args, kwargs) - - # If the func doesn't act on Tensor, and is not a tensor constructor, - # We should skip and let torch handle it. 
- - tensor_args = [ - t - for t in torch_pytree.tree_flatten(args)[0] - if isinstance(t, torch.Tensor) - ] - - def is_not_torchax_tensor(x): - return not isinstance(x, Tensor) and not isinstance(x, View) - - if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args): - res = func(*args, **kwargs) - return res - - with jax.named_scope(_name_of_func(func)): - op = self._get_op_or_decomp(func) - - old_args, old_kwargs = args, kwargs - with self._dispatch_mode: - args, kwargs = torch_pytree.tree_map_only( - torch.distributed._functional_collectives.AsyncCollectiveTensor, - torch.distributed._functional_collectives.wait_tensor, - (args, kwargs), - ) - - try: - if not op.is_view_op: - args, kwargs = self.v2t_iso((args, kwargs)) - - with self: - if self.autocast_dtype is not None: - autocast_policy = amp.autocast_policy.get(func) - if autocast_policy is not None: - args, kwargs = amp.execute_policy( - autocast_policy, - args, - kwargs, - self.autocast_dtype, - ) - - if op.is_jax_function: - args, kwargs = self.t2j_iso((args, kwargs)) - except AssertionError: - if self.config.debug_mixed_tensor: - breakpoint() - else: - raise - - if op.needs_env: - kwargs["env"] = self - - if op.is_jax_function: - res = op.func(*args, **kwargs) - else: - # enable dispatch mode because this op could be a composite autograd op - # meaning, it will decompose in C++ - with self._dispatch_mode: - res = op.func(*args, **kwargs) - - if op.is_jax_function: - res = self.j2t_iso(res) - - if self.config.force_materialize_views and isinstance(res, View): - res = res.torch() - - if self.config.debug_accuracy_for_each_op: - debug_accuracy(func, old_args, old_kwargs, res) - return res - - def enable_torch_modes(self): - self._dispatch_mode.__enter__() - self._function_mode.__enter__() - self.enabled = True - - def disable_torch_modes(self, *exc): - if not exc: - exc = (None, None, None) - self._function_mode.__exit__(*exc) - self._dispatch_mode.__exit__(*exc) - self.enabled = False - - def __enter__(self): - self.enable_torch_modes() - self._manually_entered = True - return self - - def __exit__(self, *exc): - self._manually_entered = False - self.disable_torch_modes(*exc) - - def _move_one_value(self, val): - if isinstance(val, torch.nn.Module): - with self: - return val.to("jax") - if isinstance(val, Tensor): - return val - if isinstance(val, torch.Tensor): - return Tensor(self.t2j_copy(val), self) - return val - - def to_xla(self, torchvalues): - # tensors are torch.Tensors (not XLATensor) - res = torch_pytree.tree_map(self._move_one_value, torchvalues) - return res - - def t2j_iso(self, torchtensors): - """Convert torchax Tensor to jax array. - - This function will not copy, will just unwrap the inner jax array out. 
-        Note: iso is short for "isomorphic"
-        """
-
-        def to_jax(x):
-            if isinstance(
-                x,
-                torch.distributed._functional_collectives.AsyncCollectiveTensor,
-            ):
-                x = x.wait()
-            assert isinstance(x, Tensor) or isinstance(x, View), (
-                f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
+      return jax.devices("cpu")[0]
+
+    if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
+      return jax.local_devices()[0]
+
+    if device.startswith("xla"):
+      return jax.local_devices()[0]
+
+    # TODO (wen): jax is NOT a device type,
+    # once we can register more than one backend, revisit
+    if device.startswith("jax"):
+      match self.target_device:
+        case "cpu":
+          return jax.devices("cpu")[0]
+        case "tpu":
+          return jax.devices("tpu")[0]
+        case _:
+          raise AttributeError(
+              f"Cannot handle env.target_device {self.target_device}"
+          )
+
+    return None  # fallback to torch
+
+  def load_ops(self):
+    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+
+    for k, v in itertools.chain(
+        ops_registry.all_aten_ops.items(),
+        ops_registry.all_torch_functions.items(),
+    ):
+      if v.is_jax_function:
+        self._ops[k] = v
+      else:
+        self._decomps[k] = v
+
+    from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
+
+    for k, v in DECOMPOSITIONS.items():
+      if k not in self._decomps:
+        self._decomps[k] = ops_registry.Operator(
+            k,
+            v,
+            is_jax_function=False,
+            is_user_defined=False,
+            needs_env=False,
+            is_view_op=k in MUTABLE_DECOMPOSITION,
+        )
+
+  def _get_op_or_decomp(self, func):
+    def _get_from_dict(op_dict, op):
+      op = op_dict.get(func)
+      if op is None and isinstance(func, torch._ops.OpOverloadPacket):
+        op = op_dict.get(func.default)
+      if op is None and isinstance(func, torch._ops.OpOverload):
+        op = op_dict.get(func.overloadpacket)
+      return op
+
+    op = _get_from_dict(self._ops, func)
+
+    if op is None:
+      # fallback to decompose
+      op = _get_from_dict(self._decomps, func)
+
+    if op is None:
+      raise OperatorNotFound(
+          f"Operator with name {_name_of_func(func)} has no lowering"
+      )
+
+    return op
+
+  def _to_copy(self, the_tensor, new_dtype, new_device):
+    if isinstance(the_tensor, View):
+      the_tensor = the_tensor.torch()
+
+    if isinstance(the_tensor, Tensor):
+      arr = the_tensor.jax()
+
+      if new_dtype is not None and new_dtype != arr.dtype:
+        arr = arr.astype(mappings.t2j_dtype(new_dtype))
+
+      if new_device is not None:
+        match str(new_device).lower():
+          case "cpu":
+            # converting to a non-jax device: let torch native handle it
+            torch_tensor = (
+                self.j2t_copy(arr) if isinstance(the_tensor, Tensor) else arr
+            )
+            with (
+                mode_utils.no_dispatch(),
+                torch._C.DisableTorchFunction(),
+            ):
+              return torch_tensor.to(new_device)
+          case "jax":
+            # move torchax.tensor / jax tensor between devices
+            # I don't know if this will work after the model is jitted
+            if self.target_device != the_tensor.jax_device.platform:
+              arr = jax.device_put(
+                  the_tensor.jax(),
+                  jax.devices(self.target_device)[0],
+              )
+            return Tensor(arr, self)
+          case _:
+            logging.error(f"torchax.Tensor cannot handle device {new_device}")
+
+    else:
+      if new_dtype is not None and new_dtype != the_tensor.dtype:
+        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+          the_tensor = the_tensor.to(new_dtype)
+
+    if new_device is None:  ## device is None means don't change device
+      return the_tensor
+
+    jax_device = self.get_as_jax_device(new_device)
+    if jax_device:
+      arr = self.t2j_copy(the_tensor)
+      arr = jax.device_put(arr, jax_device)
+
else: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return the_tensor.to(new_device) + + return Tensor(arr, self) + + def get_and_rotate_prng_key( + self, generator: Optional[torch.Generator] = None + ): + if generator is not None: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63)) + old_key = self._prng_key[...] + new_prng_key, next_key = jax.random.split(old_key) + self._prng_key[...] = new_prng_key + return next_key + + def _handle_tensor_constructor(self, func, args, kwargs): + device = kwargs.get("device") + jax_device = self.get_as_jax_device(device) + # TODO(qihqi) figure out better ways for device propagation + if not self._manually_entered and jax_device is None: + # let torch handle it + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return func(*args, **kwargs) + with jax.default_device(jax_device): + requires_grad = kwargs.get("requires_grad", False) + op = self._get_op_or_decomp(func) + res = op.func(*args, **kwargs) + if isinstance(res, jax.Array): + res = Tensor(res, self) + if requires_grad: + res.requires_grad = True + return res + + def _torch_Tensor_to(self, args, kwargs): + the_tensor = args[0] + args = args[1:] + if len(args) >= 1 and isinstance(args[0], torch.Tensor): + dtype = args[0].dtype + device = args[0].device + return self._to_copy(the_tensor, dtype, device) + device = kwargs.get("device") + dtype = kwargs.get("dtype") + # args like pin_memory etc that we will ignore + args = list(filter(lambda x: not isinstance(x, bool), args)) + if len(args) >= 2: + device, dtype, *_ = args + elif len(args) == 1 and isinstance(args[0], torch.dtype): + dtype = args[0] + elif len(args) == 1: + device = args[0] + return self._to_copy(the_tensor, dtype, device) + + def dispatch(self, func, types, args, kwargs): + kwargs = kwargs or {} + if func in TENSOR_CONSTRUCTORS: + return self._handle_tensor_constructor(func, args, kwargs) + if func in ( + torch.Tensor.to, + torch.ops.aten.lift_fresh.default, + torch.ops.aten._to_copy, + torch.ops.aten._to_copy.default, + ): + return self._torch_Tensor_to(args, kwargs) - res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors) - return res + # If the func doesn't act on Tensor, and is not a tensor constructor, + # We should skip and let torch handle it. - def v2t_iso(self, views): - def to_tensor(x): - if isinstance(x, View): - return x.torch() - return x + tensor_args = [ + t + for t in torch_pytree.tree_flatten(args)[0] + if isinstance(t, torch.Tensor) + ] - res = torch_pytree.tree_map_only(View, to_tensor, views) - return res + def is_not_torchax_tensor(x): + return not isinstance(x, Tensor) and not isinstance(x, View) - def j2t_iso(self, jaxarray): - """Convert jax array to torchax Tensor. 
+ if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args): + res = func(*args, **kwargs) + return res - This function will not copy, will just wrap the jax array with a torchax Tensor - Note: iso is short for "isomorphic" - """ - return torch_pytree.tree_map_only( - jax.Array, lambda x: Tensor(x, self), jaxarray - ) + with jax.named_scope(_name_of_func(func)): + op = self._get_op_or_decomp(func) - def j2t_copy(self, args): - """Convert torch.Tensor in cpu to a jax array - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - jax.Array, - lambda x: mappings.j2t( - x, self.config.use_dlpack_for_data_conversion - ), - args, + old_args, old_kwargs = args, kwargs + with self._dispatch_mode: + args, kwargs = torch_pytree.tree_map_only( + torch.distributed._functional_collectives.AsyncCollectiveTensor, + torch.distributed._functional_collectives.wait_tensor, + (args, kwargs), ) - def t2j_copy(self, args): - """Convert jax array to torch.Tensor in cpu. - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - torch.Tensor, - lambda x: mappings.t2j( - x, self.config.use_dlpack_for_data_conversion - ), - args, - ) + try: + if not op.is_view_op: + args, kwargs = self.v2t_iso((args, kwargs)) + + with self: + if self.autocast_dtype is not None: + autocast_policy = amp.autocast_policy.get(func) + if autocast_policy is not None: + args, kwargs = amp.execute_policy( + autocast_policy, + args, + kwargs, + self.autocast_dtype, + ) + + if op.is_jax_function: + args, kwargs = self.t2j_iso((args, kwargs)) + except AssertionError: + if self.config.debug_mixed_tensor: + breakpoint() + else: + raise + + if op.needs_env: + kwargs["env"] = self + + if op.is_jax_function: + res = op.func(*args, **kwargs) + else: + # enable dispatch mode because this op could be a composite autograd op + # meaning, it will decompose in C++ + with self._dispatch_mode: + res = op.func(*args, **kwargs) + + if op.is_jax_function: + res = self.j2t_iso(res) + + if self.config.force_materialize_views and isinstance(res, View): + res = res.torch() + + if self.config.debug_accuracy_for_each_op: + debug_accuracy(func, old_args, old_kwargs, res) + return res + + def enable_torch_modes(self): + self._dispatch_mode.__enter__() + self._function_mode.__enter__() + self.enabled = True + + def disable_torch_modes(self, *exc): + if not exc: + exc = (None, None, None) + self._function_mode.__exit__(*exc) + self._dispatch_mode.__exit__(*exc) + self.enabled = False + + def __enter__(self): + self.enable_torch_modes() + self._manually_entered = True + return self + + def __exit__(self, *exc): + self._manually_entered = False + self.disable_torch_modes(*exc) + + def _move_one_value(self, val): + if isinstance(val, torch.nn.Module): + with self: + return val.to("jax") + if isinstance(val, Tensor): + return val + if isinstance(val, torch.Tensor): + return Tensor(self.t2j_copy(val), self) + return val + + def to_xla(self, torchvalues): + # tensors are torch.Tensors (not XLATensor) + res = torch_pytree.tree_map(self._move_one_value, torchvalues) + return res + + def t2j_iso(self, torchtensors): + """Convert torchax Tensor to jax array. + + This function will not copy, will just unwrap the inner jax array out. 
+      Note: iso is short for "isomorphic"
+    """

-    def override_op_definition(self, op_to_override, op_impl):
-        self._ops[op_to_override] = ops_registry.Operator(
-            op_to_override,
-            op_impl,
-            is_jax_function=False,
-            is_user_defined=True,
-            needs_env=False,
-        )
+    def to_jax(x):
+      if isinstance(
+          x,
+          torch.distributed._functional_collectives.AsyncCollectiveTensor,
+      ):
+        x = x.wait()
+      assert isinstance(x, Tensor) or isinstance(x, View), (
+          f"Expected a Tensor or a View but got {type(x)}; usually this means there is mixed math between XLATensor and torch.Tensor"
+      )
+      return x.jax()
+
+    res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+    return res
+
+  def v2t_iso(self, views):
+    def to_tensor(x):
+      if isinstance(x, View):
+        return x.torch()
+      return x
+
+    res = torch_pytree.tree_map_only(View, to_tensor, views)
+    return res
+
+  def j2t_iso(self, jaxarray):
+    """Convert jax array to torchax Tensor.
+
+    This function will not copy; it will just wrap the jax array with a torchax Tensor.
+    Note: iso is short for "isomorphic"
+    """
+    return torch_pytree.tree_map_only(
+        jax.Array, lambda x: Tensor(x, self), jaxarray
+    )
+
+  def j2t_copy(self, args):
+    """Convert a jax array to a torch.Tensor on CPU.
+
+    This might involve copying the data (depending on whether dlpack is enabled).
+    """
+    return torch_pytree.tree_map_only(
+        jax.Array,
+        lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+        args,
+    )
+
+  def t2j_copy(self, args):
+    """Convert a torch.Tensor on CPU to a jax array.
+
+    This might involve copying the data (depending on whether dlpack is enabled).
+    """
+    return torch_pytree.tree_map_only(
+        torch.Tensor,
+        lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+        args,
+    )
+
+  def override_op_definition(self, op_to_override, op_impl):
+    self._ops[op_to_override] = ops_registry.Operator(
+        op_to_override,
+        op_impl,
+        is_jax_function=False,
+        is_user_defined=True,
+        needs_env=False,
+    )
diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py
index 54049604d729..2542df318589 100644
--- a/torchax/torchax/tf_integration.py
+++ b/torchax/torchax/tf_integration.py
@@ -9,117 +9,117 @@


 def exported_program_to_tf_function(ep, enable_xla=True):
-    weights, jax_program = export.exported_program_to_jax(ep)
-    wrapped = lambda *args: jax_program(weights, (args,))
-    avals = export.extract_avals(ep)
-    input_signature = [
-        tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}")
-        for i, t in enumerate(avals)
-    ]
-    tf_f = tf.function(
-        jax2tf.convert(
-            wrapped,
-            with_gradient=False,
-            enable_xla=enable_xla,
-        ),
-        autograph=False,
-        input_signature=input_signature,
-    )
-    return tf_f
+  weights, jax_program = export.exported_program_to_jax(ep)
+  wrapped = lambda *args: jax_program(weights, (args,))
+  avals = export.extract_avals(ep)
+  input_signature = [
+      tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}")
+      for i, t in enumerate(avals)
+  ]
+  tf_f = tf.function(
+      jax2tf.convert(
+          wrapped,
+          with_gradient=False,
+          enable_xla=enable_xla,
+      ),
+      autograph=False,
+      input_signature=input_signature,
+  )
+  return tf_f


 def exported_program_to_tf_module(
-    ep: torch.export.ExportedProgram, enable_xla=True
+    ep: torch.export.ExportedProgram, enable_xla=True
 ) -> tf.Module:
-    tfm = tf.Module()
-    tfm.f = exported_program_to_tf_function(ep, enable_xla)
-    return tfm
+  tfm = tf.Module()
+  tfm.f = exported_program_to_tf_function(ep, enable_xla)
+  return tfm


 def save_exported_program_as_tf_saved_model(
-    ep: torch.export.ExportedProgram,
-
saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, + ep: torch.export.ExportedProgram, + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, ): - """This function will export and save a pytorch ExportedProgram to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) - signatures = { - serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) + """This function will export and save a pytorch ExportedProgram to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. + """ + tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) + signatures = { + serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) + } + save_options = tf.saved_model.SaveOptions( + function_aliases={ + function_alias: tfm.f, } - save_options = tf.saved_model.SaveOptions( - function_aliases={ - function_alias: tfm.f, - } - ) - tf.saved_model.save( - tfm, - saved_model_dir, - signatures=signatures, - options=save_options, - ) + ) + tf.saved_model.save( + tfm, + saved_model_dir, + signatures=signatures, + options=save_options, + ) def save_torch_module_as_tf_saved_model( - torch_model: torch.nn.Module, - args: Tuple[Any], - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, + torch_model: torch.nn.Module, + args: Tuple[Any], + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, ): - """This function will export and save a pytorch nn.Module to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. 
- function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - ep = torch.export.export(torch_model, args) - save_exported_program_as_tf_saved_model( - ep, saved_model_dir, serving_key, function_alias, enable_xla - ) + """This function will export and save a pytorch nn.Module to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. + """ + ep = torch.export.export(torch_model, args) + save_exported_program_as_tf_saved_model( + ep, saved_model_dir, serving_key, function_alias, enable_xla + ) def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): - tfm = exported_program_to_tf_module(ep) - tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [tf_concrete_func], tfm - ) - tflite_model = converter.convert() - return tflite_model + tfm = exported_program_to_tf_module(ep) + tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [tf_concrete_func], tfm + ) + tflite_model = converter.convert() + return tflite_model def torch_module_to_tflite_flatbuffer( - torch_model: torch.nn.Module, args: Tuple[Any] + torch_model: torch.nn.Module, args: Tuple[Any] ): - ep = torch.export.export(torch_model, args) - return exported_program_to_tflite_flatbuffer(ep) + ep = torch.export.export(torch_model, args) + return exported_program_to_tflite_flatbuffer(ep) diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py index 78639090321f..c1be6ab9901e 100644 --- a/torchax/torchax/train.py +++ b/torchax/torchax/train.py @@ -12,107 +12,106 @@ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): - """Make a function that do one train step given model and loss. - - model_fn: a function representing the model's forward: - i.e. has signature Callable[weights, buffers, args] -> result. Where, - weights is a pytree of trainable parameters - buffers is a pytree of non-trainable parameters / constants - args is the input data loaded from the data set - result is the return value of the model - loss_fn: a function to compute loss. - i.e. it has signature of Callable[result, label] -> loss - where, result is what model_fn returned - loss is loaded from the dataloader. - optax_optimizer: the optimizer from optax library. for example, optax.adam - remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how - to do gradient checkpointing. If None, then it means checkpoint everything. 
- """ - env = torchax.default_env() - - def loss(weights, buffers, args, label): # inputs are XLATensor - with env, jax.named_scope("compute_loss"): - res = model_fn(weights, buffers, args) - l = loss_fn(res, label) - return l - - loss = interop.gradient_checkpoint(loss, kwargs={"policy": remat_policy}) - grad_fn = interop.jax_value_and_grad(loss) - - def step(weights, buffers, opt_state, args, label): # inputs are array - with jax.named_scope("compute_gradient"): - loss, gradient = grad_fn(weights, buffers, args, label) - - with jax.named_scope("optimizer_updates"): - updates, opt_state = interop.call_jax( - optax_optimizer.update, gradient, opt_state, weights - ) - weights = interop.call_jax(optax.apply_updates, weights, updates) - return loss, weights, opt_state - - # TODO: apply jax.jit so the user don't have to. - return step + """Make a function that do one train step given model and loss. + + model_fn: a function representing the model's forward: + i.e. has signature Callable[weights, buffers, args] -> result. Where, + weights is a pytree of trainable parameters + buffers is a pytree of non-trainable parameters / constants + args is the input data loaded from the data set + result is the return value of the model + loss_fn: a function to compute loss. + i.e. it has signature of Callable[result, label] -> loss + where, result is what model_fn returned + loss is loaded from the dataloader. + optax_optimizer: the optimizer from optax library. for example, optax.adam + remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how + to do gradient checkpointing. If None, then it means checkpoint everything. + """ + env = torchax.default_env() + + def loss(weights, buffers, args, label): # inputs are XLATensor + with env, jax.named_scope("compute_loss"): + res = model_fn(weights, buffers, args) + l = loss_fn(res, label) + return l + + loss = interop.gradient_checkpoint(loss, kwargs={"policy": remat_policy}) + grad_fn = interop.jax_value_and_grad(loss) + + def step(weights, buffers, opt_state, args, label): # inputs are array + with jax.named_scope("compute_gradient"): + loss, gradient = grad_fn(weights, buffers, args, label) + + with jax.named_scope("optimizer_updates"): + updates, opt_state = interop.call_jax( + optax_optimizer.update, gradient, opt_state, weights + ) + weights = interop.call_jax(optax.apply_updates, weights, updates) + return loss, weights, opt_state + + # TODO: apply jax.jit so the user don't have to. + return step class Container: - pass + pass class ScannedModule(torch.nn.Module): - def __init__(self, module_list, checkpoint_policy=None): - super().__init__() - - self.c = None - assert module_list - self.c = Container() - self.c.one_mod = module_list[0] - self.checkpoint_policy = checkpoint_policy - - weights = self._stack_layer_weights(module_list) - self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) - self.params = torch.nn.ParameterDict({ - self._param_name_new(k): v for k, v in weights.items() - }) - - def _stack_layer_weights(self, module_list): - # Create weights such that, for every [n, m] weights - # becomes [k, n, m] where k is number of layer - # i.e. 
stacking layer weights together - temp = collections.defaultdict(list) - for m in module_list: - for k, v in m.state_dict().items(): - temp[k].append(v) - res = {k: torch.stack(v) for k, v in temp.items()} - return res - - def _param_name_new(self, old): - return "___".join(old.split(".")) - - def _param_name_old(self, new): - return ".".join(new.split("___")) - - def forward(self, *args, **kwargs): - assert not kwargs - weights = { - k: self.params[self._param_name_new(k)] - for k in self.layer_weights_keys - } - scan = interop.torch_view(jax.lax.scan) - - def eval_one_layer(args, weight): - # unpack args - h, *rest = args - newh = torch.func.functional_call(self.c.one_mod, weight, args) - # next layer's input; and residual to be added to list - return (newh, *rest), None - - _eval_one_layer = interop.gradient_checkpoint( - eval_one_layer, - kwargs={"policy": self.checkpoint_policy}, - ) - h, _ = scan( - _eval_one_layer, - args, - weights, - ) - return h[0] + def __init__(self, module_list, checkpoint_policy=None): + super().__init__() + + self.c = None + assert module_list + self.c = Container() + self.c.one_mod = module_list[0] + self.checkpoint_policy = checkpoint_policy + + weights = self._stack_layer_weights(module_list) + self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) + self.params = torch.nn.ParameterDict({ + self._param_name_new(k): v for k, v in weights.items() + }) + + def _stack_layer_weights(self, module_list): + # Create weights such that, for every [n, m] weights + # becomes [k, n, m] where k is number of layer + # i.e. stacking layer weights together + temp = collections.defaultdict(list) + for m in module_list: + for k, v in m.state_dict().items(): + temp[k].append(v) + res = {k: torch.stack(v) for k, v in temp.items()} + return res + + def _param_name_new(self, old): + return "___".join(old.split(".")) + + def _param_name_old(self, new): + return ".".join(new.split("___")) + + def forward(self, *args, **kwargs): + assert not kwargs + weights = { + k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys + } + scan = interop.torch_view(jax.lax.scan) + + def eval_one_layer(args, weight): + # unpack args + h, *rest = args + newh = torch.func.functional_call(self.c.one_mod, weight, args) + # next layer's input; and residual to be added to list + return (newh, *rest), None + + _eval_one_layer = interop.gradient_checkpoint( + eval_one_layer, + kwargs={"policy": self.checkpoint_policy}, + ) + h, _ = scan( + _eval_one_layer, + args, + weights, + ) + return h[0] diff --git a/torchax/torchax/util.py b/torchax/torchax/util.py index 7f6f8cd638dc..4b4c4297dcd4 100644 --- a/torchax/torchax/util.py +++ b/torchax/torchax/util.py @@ -2,88 +2,88 @@ def partition( - original: list[Any], func: Callable[[Any], bool] + original: list[Any], func: Callable[[Any], bool] ) -> tuple[list[Any], list[Any]]: - """Partitions elements into two parallel lists based on a predicate function. - - Iterates through the 'original' list, applying 'func' to each element 'a'. - - If `func(a)` returns True, 'a' is appended to the first list ('truthy') - and `None` is appended to the second list ('falsy'). - - If `func(a)` returns False, `None` is appended to the first list ('truthy') - and 'a' is appended to the second list ('falsy'). - - The result is two lists of the same length as the 'original' list, acting - as parallel representations of the partitioned elements, using `None` as - placeholders. 
- - This is useful when we want to mark a group of elements as static (via passing - static_argnums) or donated (via donate_argnums) when combining with jax.jit - and friends. - - Args: - original: The list of elements to partition. - func: A callable (function or lambda) that accepts an element from - 'original' and returns a boolean value (True or False). - - Returns: - A tuple containing two lists (`truthy`, `falsy`), both of the same - length as `original`: - - The first list contains elements `x` where `func(x)` was True, and - `None` otherwise. - - The second list contains elements `x` where `func(x)` was False, and - `None` otherwise. - - Example: - >>> def is_even(n): return n % 2 == 0 - >>> nums = [1, 2, 3, 4, 5, 6] - >>> truthy_list, falsy_list = partition(nums, is_even) - >>> truthy_list - [None, 2, None, 4, None, 6] - >>> falsy_list - [1, None, 3, None, 5, None] - """ - truthy = [] - falsy = [] - for a in original: - t, f = (a, None) if func(a) else (None, a) - truthy.append(t) - falsy.append(f) - return truthy, falsy + """Partitions elements into two parallel lists based on a predicate function. + + Iterates through the 'original' list, applying 'func' to each element 'a'. + - If `func(a)` returns True, 'a' is appended to the first list ('truthy') + and `None` is appended to the second list ('falsy'). + - If `func(a)` returns False, `None` is appended to the first list ('truthy') + and 'a' is appended to the second list ('falsy'). + + The result is two lists of the same length as the 'original' list, acting + as parallel representations of the partitioned elements, using `None` as + placeholders. + + This is useful when we want to mark a group of elements as static (via passing + static_argnums) or donated (via donate_argnums) when combining with jax.jit + and friends. + + Args: + original: The list of elements to partition. + func: A callable (function or lambda) that accepts an element from + 'original' and returns a boolean value (True or False). + + Returns: + A tuple containing two lists (`truthy`, `falsy`), both of the same + length as `original`: + - The first list contains elements `x` where `func(x)` was True, and + `None` otherwise. + - The second list contains elements `x` where `func(x)` was False, and + `None` otherwise. + + Example: + >>> def is_even(n): return n % 2 == 0 + >>> nums = [1, 2, 3, 4, 5, 6] + >>> truthy_list, falsy_list = partition(nums, is_even) + >>> truthy_list + [None, 2, None, 4, None, 6] + >>> falsy_list + [1, None, 3, None, 5, None] + """ + truthy = [] + falsy = [] + for a in original: + t, f = (a, None) if func(a) else (None, a) + truthy.append(t) + falsy.append(f) + return truthy, falsy def merge(list1: list[Any], list2: list[Any]) -> list[Any]: - """Merges two lists element-wise, prioritizing non-None elements from list1. - - Creates a new list where each element is taken from the corresponding position - in 'list1', unless that element is None, in which case the element from the - corresponding position in 'list2' is used. Assumes both lists have the - same length. - - Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate - - Args: - list1: The primary list. Its elements are preferred unless they are None. - list2: The secondary list. Its elements are used as fallbacks when the - corresponding element in list1 is None. - - Returns: - A new list representing the merged result. - - Raises: - AssertionError: If 'list1' and 'list2' do not have the same length. 
- - Example: - >>> l1 = [1, None, 3, None] - >>> l2 = [None, 2, None, 4] - >>> merge(l1, l2) - [1, 2, 3, 4] - >>> l3 = [None, 'b', None] - >>> l4 = ['a', None, 'c'] - >>> merge(l3, l4) - ['a', 'b', 'c'] - """ - assert len(list1) == len(list2) - res = [] - for a, b in zip(list1, list2): - res.append(b if a is None else a) - return res + """Merges two lists element-wise, prioritizing non-None elements from list1. + + Creates a new list where each element is taken from the corresponding position + in 'list1', unless that element is None, in which case the element from the + corresponding position in 'list2' is used. Assumes both lists have the + same length. + + Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate + + Args: + list1: The primary list. Its elements are preferred unless they are None. + list2: The secondary list. Its elements are used as fallbacks when the + corresponding element in list1 is None. + + Returns: + A new list representing the merged result. + + Raises: + AssertionError: If 'list1' and 'list2' do not have the same length. + + Example: + >>> l1 = [1, None, 3, None] + >>> l2 = [None, 2, None, 4] + >>> merge(l1, l2) + [1, 2, 3, 4] + >>> l3 = [None, 'b', None] + >>> l4 = ['a', None, 'c'] + >>> merge(l3, l4) + ['a', 'b', 'c'] + """ + assert len(list1) == len(list2) + res = [] + for a, b in zip(list1, list2): + res.append(b if a is None else a) + return res diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index c2a5851be6dc..0012273aa213 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -10,392 +10,392 @@ class ViewInfoType(Enum): - INVALID = 0 - NARROW = 1 - NO_OP = 2 - PERMUTE = 3 - RESHAPE = 4 - RESIZE = 5 - SELECT = 6 - AS_STRIDED = 7 - DIAGONAL = 8 + INVALID = 0 + NARROW = 1 + NO_OP = 2 + PERMUTE = 3 + RESHAPE = 4 + RESIZE = 5 + SELECT = 6 + AS_STRIDED = 7 + DIAGONAL = 8 class ViewInfo(ABC): + """ + Abstract base class for all view operations. + Defines the interface for applying and updating view transformations. + """ + + def __init__( + self, + view_info_type: ViewInfoType = ViewInfoType.INVALID, + ): """ - Abstract base class for all view operations. - Defines the interface for applying and updating view transformations. - """ - - def __init__( - self, - view_info_type: ViewInfoType = ViewInfoType.INVALID, - ): - """ - Initialize a ViewInfo object. + Initialize a ViewInfo object. - Args: - view_info_type: The type of view operation - """ - self.view_info_type = view_info_type + Args: + view_info_type: The type of view operation + """ + self.view_info_type = view_info_type - @abstractmethod - def update_tensor( - self, new_value: jax.Array, jax_array: jax.Array - ) -> jax.Array: - """ - Apply this view transformation to a JAX array and update its value. + @abstractmethod + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + """ + Apply this view transformation to a JAX array and update its value. - Args: - new_value: The new values to set in the view - jax_array: The parent array to update + Args: + new_value: The new values to set in the view + jax_array: The parent array to update - Returns: - Updated array - """ - pass + Returns: + Updated array + """ + pass - @abstractmethod - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - """ - Apply this view transformation to a JAX array. + @abstractmethod + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + """ + Apply this view transformation to a JAX array. 
- Args: - jax_array: The array to transform + Args: + jax_array: The array to transform - Returns: - Transformed array - """ - pass + Returns: + Transformed array + """ + pass - @abstractmethod - def calculate_output_shape(self, source: jax.Array) -> List[int]: - """ - Calculate the resulting shape after applying this view. + @abstractmethod + def calculate_output_shape(self, source: jax.Array) -> List[int]: + """ + Calculate the resulting shape after applying this view. - Args: - source: Original jax array before transformation + Args: + source: Original jax array before transformation - Returns: - Resulting shape after transformation - """ - pass + Returns: + Resulting shape after transformation + """ + pass class NarrowInfo(ViewInfo): + """ + Represents a slicing operation on a tensor. + Handles operations like tensor[1:3, :, 2:5:2]. + """ + + def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: """ - Represents a slicing operation on a tensor. - Handles operations like tensor[1:3, :, 2:5:2]. + Args: + slices: The slice(s) to apply to the tensor. + E.g. jax_array.at[slices] will return the transformed tensor. """ + super().__init__(ViewInfoType.NARROW) + self.slices = slices - def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: - """ - Args: - slices: The slice(s) to apply to the tensor. - E.g. jax_array.at[slices] will return the transformed tensor. - """ - super().__init__(ViewInfoType.NARROW) - self.slices = slices - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NarrowInfo): - return False - return self.slices == other.slices + def __eq__(self, other: object) -> bool: + if not isinstance(other, NarrowInfo): + return False + return self.slices == other.slices - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - try: - return jax_array[self.slices] - except IndexError as e: - raise IndexError("Invalid slice operation") from e + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + try: + return jax_array[self.slices] + except IndexError as e: + raise IndexError("Invalid slice operation") from e - def update_tensor( - self, new_value: jax.Array, jax_array: jax.Array - ) -> jax.Array: - return jax_array.at[self.slices].set(new_value) + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + return jax_array.at[self.slices].set(new_value) - def calculate_output_shape(self, source: jax.Array) -> List[int]: - return source[self.slices].shape + def calculate_output_shape(self, source: jax.Array) -> List[int]: + return source[self.slices].shape class SelectInfo(ViewInfo): - """ - Represents a selection operation on a tensor. - Typically used for indexing operations that select specific elements. 
- """ - - def __init__( - self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0 - ) -> None: - super().__init__(ViewInfoType.SELECT) - self.dim: int = dim - self.start: int = start - self.end: int = end - self.stride: int = stride - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SelectInfo): - return False - return ( - self.dim == other.dim - and self.start == other.start - and self.end == other.end - and self.stride == other.stride - ) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("SelectInfo.apply not implemented") - - def update_tensor( - self, new_value: jax.Array, jax_array: jax.Array - ) -> jax.Array: - raise NotImplementedError("SelectInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "SelectInfo.calculate_output_shape not implemented" - ) + """ + Represents a selection operation on a tensor. + Typically used for indexing operations that select specific elements. + """ + + def __init__( + self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0 + ) -> None: + super().__init__(ViewInfoType.SELECT) + self.dim: int = dim + self.start: int = start + self.end: int = end + self.stride: int = stride + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SelectInfo): + return False + return ( + self.dim == other.dim + and self.start == other.start + and self.end == other.end + and self.stride == other.stride + ) + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("SelectInfo.apply not implemented") + + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("SelectInfo.update not implemented") + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "SelectInfo.calculate_output_shape not implemented" + ) class AsStridedInfo(ViewInfo): - """ - Information for as_strided operations. - """ + """ + Information for as_strided operations. 
+ """ - def __init__(self, stride: List[int], offset: int = 0) -> None: - super().__init__(ViewInfoType.AS_STRIDED) - self.stride: List[int] = stride - self.offset: int = offset + def __init__(self, stride: List[int], offset: int = 0) -> None: + super().__init__(ViewInfoType.AS_STRIDED) + self.stride: List[int] = stride + self.offset: int = offset - def __eq__(self, other: object) -> bool: - if not isinstance(other, AsStridedInfo): - return False - return self.offset == other.offset and self.stride == other.stride + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsStridedInfo): + return False + return self.offset == other.offset and self.stride == other.stride - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("AsStridedInfo.apply not implemented") + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("AsStridedInfo.apply not implemented") - def update_tensor( - self, new_value: jax.Array, jax_array: jax.Array - ) -> jax.Array: - raise NotImplementedError("AsStridedInfo.update not implemented") + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("AsStridedInfo.update not implemented") - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "AsStridedInfo.calculate_output_shape not implemented" - ) + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "AsStridedInfo.calculate_output_shape not implemented" + ) class DiagonalInfo(ViewInfo): + """ + Information for diagonal operations. + Extracts diagonal elements from a tensor. + """ + + def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: """ - Information for diagonal operations. - Extracts diagonal elements from a tensor. 
+ Args: + offset: Offset from the main diagonal + dim1: First dimension for diagonal extraction + dim2: Second dimension for diagonal extraction """ - - def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: - """ - Args: - offset: Offset from the main diagonal - dim1: First dimension for diagonal extraction - dim2: Second dimension for diagonal extraction - """ - super().__init__(ViewInfoType.DIAGONAL) - self.offset: int = offset - self.dim1: int = dim1 - self.dim2: int = dim2 - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DiagonalInfo): - return False - return ( - self.offset == other.offset - and self.dim1 == other.dim1 - and self.dim2 == other.dim2 - ) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("DiagonalInfo.apply not implemented") - - def update_tensor( - self, new_value: jax.Array, jax_array: jax.Array - ) -> jax.Array: - raise NotImplementedError("DiagonalInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "DiagonalInfo.calculate_output_shape not implemented" - ) + super().__init__(ViewInfoType.DIAGONAL) + self.offset: int = offset + self.dim1: int = dim1 + self.dim2: int = dim2 + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DiagonalInfo): + return False + return ( + self.offset == other.offset + and self.dim1 == other.dim1 + and self.dim2 == other.dim2 + ) + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + raise NotImplementedError("DiagonalInfo.apply not implemented") + + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: + raise NotImplementedError("DiagonalInfo.update not implemented") + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + raise NotImplementedError( + "DiagonalInfo.calculate_output_shape not implemented" + ) class View(torch.Tensor): + """ + A View is a reference to another Tensor or another View, + with a transformation applied to it. + """ + + @staticmethod + def __new__( + cls, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> "View": + """ + Args: + parent: Parent tensor or view + view_info: Information about the view transformation + env: Environment for tensor operations + """ + shape = view_info.calculate_output_shape(parent.jax()) + return torch.Tensor._make_wrapper_subclass( + cls, + shape, + device="meta", + dtype=parent.dtype, + requires_grad=False, + ) + + def __init__( + self, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> None: + super().__init__() + self.parent = parent + self.view_info = view_info + self._env = env + + def get_transformation_chain(self) -> List[ViewInfo]: """ - A View is a reference to another Tensor or another View, - with a transformation applied to it. + Get all view transformations from the source tensor to this view. + """ + if isinstance(self.parent, View): + transformations = self.parent.get_transformation_chain() + transformations.append(self.view_info) + return transformations + else: + return [self.view_info] + + __torch_function__ = torch._C._disabled_torch_function_impl + + def source_jax(self) -> jax.Array: """ + Returns the source tensor. + """ + if isinstance(self.parent, View): + return self.parent.source_jax() + else: + return self.parent.jax() + + def replace_source_jax(self, new_value: jax.Array) -> None: + """ + Update the source tensor with new values. 
+ """ + if isinstance(self.parent, View): + self.parent.replace_source_jax(new_value) + else: + assert new_value.shape == self.parent._elem.shape + self.parent._elem = new_value + + def torch(self) -> "torchax.Tensor": + """ + Returns a Torchax tensor representing this view after all transformations + """ + from torchax.tensor import Tensor + + return Tensor(self.jax(), self._env) + + def update( + self, + new_values: Union[jax.Array, "View", "torchax.Tensor"], + view_infos: Optional[List[ViewInfo]] = None, + ) -> None: + """ + Update this view with new values, propagating changes back to source. + If view_infos is None, it will use the transformation chain + from the source tensor. + """ + if view_infos is None: + view_infos = self.get_transformation_chain() + + # Get the source JAX array + source_array = self.source_jax() + + # Get the new value + from torchax.tensor import Tensor + + if isinstance(new_values, View) or isinstance(new_values, Tensor): + new_values = new_values.jax() + + # Apply all view transformations to the source array + # And store intermediate values + intermediate_values = [source_array] + for view_info in view_infos[:-1]: + intermediate_values.append( + view_info.transform_tensor(intermediate_values[-1]) + ) + + # TODO: Investigate efficiency of this algorithm + # Update the source array with the new value by + # applying inverse transformations in reverse order + for view_info, parent_array in zip( + reversed(view_infos), reversed(intermediate_values) + ): + # Apply the inverse transformation to propagate changes back + new_values = view_info.update_tensor(new_values, parent_array) + + # Update the source tensor with the new values + self.replace_source_jax(new_values) + + @classmethod + def __torch_dispatch__( + cls, + func: Any, + types: Tuple[Any, ...], + args: Tuple[Any, ...] = (), + kwargs: Optional[dict] = None, + ) -> Any: + raise AssertionError( + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) + + def create_sub_view(self, view_info: ViewInfo) -> "View": + """ + Create a new view that is a child of this view. + """ + return View(self, view_info, self._env) + + def __str__(self) -> str: + return f"View({self.torch()})" + + def jax(self) -> jax.Array: + """ + Returns a copy of the source tensor after transformations. 
+ """ + result = self.source_jax() + for view_info in self.get_transformation_chain(): + result = view_info.transform_tensor(result) + return result + + def __setitem__(self, indexes, val): + view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] + self.update(view_infos=view_infos, new_values=val) + + def dim(self): + return self.ndim + + @property + def device(self): + return torch.device("jax:0") + + @property + def jax_device(self): + return self.jax().device + + @property + def ndim(self): + return len(self.shape) - @staticmethod - def __new__( - cls, - parent: Union["torchax.Tensor", "View"], - view_info: ViewInfo, - env: Any, - ) -> "View": - """ - Args: - parent: Parent tensor or view - view_info: Information about the view transformation - env: Environment for tensor operations - """ - shape = view_info.calculate_output_shape(parent.jax()) - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - device="meta", - dtype=parent.dtype, - requires_grad=False, - ) - - def __init__( - self, - parent: Union["torchax.Tensor", "View"], - view_info: ViewInfo, - env: Any, - ) -> None: - super().__init__() - self.parent = parent - self.view_info = view_info - self._env = env - - def get_transformation_chain(self) -> List[ViewInfo]: - """ - Get all view transformations from the source tensor to this view. - """ - if isinstance(self.parent, View): - transformations = self.parent.get_transformation_chain() - transformations.append(self.view_info) - return transformations - else: - return [self.view_info] - - __torch_function__ = torch._C._disabled_torch_function_impl - - def source_jax(self) -> jax.Array: - """ - Returns the source tensor. - """ - if isinstance(self.parent, View): - return self.parent.source_jax() - else: - return self.parent.jax() - - def replace_source_jax(self, new_value: jax.Array) -> None: - """ - Update the source tensor with new values. - """ - if isinstance(self.parent, View): - self.parent.replace_source_jax(new_value) - else: - assert new_value.shape == self.parent._elem.shape - self.parent._elem = new_value - - def torch(self) -> "torchax.Tensor": - """ - Returns a Torchax tensor representing this view after all transformations - """ - from torchax.tensor import Tensor - - return Tensor(self.jax(), self._env) - - def update( - self, - new_values: Union[jax.Array, "View", "torchax.Tensor"], - view_infos: Optional[List[ViewInfo]] = None, - ) -> None: - """ - Update this view with new values, propagating changes back to source. - If view_infos is None, it will use the transformation chain - from the source tensor. 
- """ - if view_infos is None: - view_infos = self.get_transformation_chain() - - # Get the source JAX array - source_array = self.source_jax() - - # Get the new value - from torchax.tensor import Tensor - - if isinstance(new_values, View) or isinstance(new_values, Tensor): - new_values = new_values.jax() - - # Apply all view transformations to the source array - # And store intermediate values - intermediate_values = [source_array] - for view_info in view_infos[:-1]: - intermediate_values.append( - view_info.transform_tensor(intermediate_values[-1]) - ) - - # TODO: Investigate efficiency of this algorithm - # Update the source array with the new value by - # applying inverse transformations in reverse order - for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values) - ): - # Apply the inverse transformation to propagate changes back - new_values = view_info.update_tensor(new_values, parent_array) - - # Update the source tensor with the new values - self.replace_source_jax(new_values) - - @classmethod - def __torch_dispatch__( - cls, - func: Any, - types: Tuple[Any, ...], - args: Tuple[Any, ...] = (), - kwargs: Optional[dict] = None, - ) -> Any: - raise AssertionError( - "torchax Tensors can only do math within the torchax environment." - "Please wrap your code with `with torchax.default_env()` or " - "call torchax.enable_globally() before." - ) - - def create_sub_view(self, view_info: ViewInfo) -> "View": - """ - Create a new view that is a child of this view. - """ - return View(self, view_info, self._env) - - def __str__(self) -> str: - return f"View({self.torch()})" - - def jax(self) -> jax.Array: - """ - Returns a copy of the source tensor after transformations. - """ - result = self.source_jax() - for view_info in self.get_transformation_chain(): - result = view_info.transform_tensor(result) - return result - - def __setitem__(self, indexes, val): - view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] - self.update(view_infos=view_infos, new_values=val) - - def dim(self): - return self.ndim - - @property - def device(self): - return torch.device("jax:0") - - @property - def jax_device(self): - return self.jax().device - - @property - def ndim(self): - return len(self.shape) - - __repr__ = __str__ + __repr__ = __str__