diff --git a/README.md b/README.md
index 414eb868..90601b43 100644
--- a/README.md
+++ b/README.md
@@ -382,10 +382,10 @@ Email: jittor@qq.com
File an issue: https://github.com/Jittor/jittor/issues
-QQ Group: 761222083
+QQ Group: 836860279
-
+
## The Team
diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index cfd39c42..b495d55c 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
-__version__ = '1.3.9.6'
+__version__ = '1.3.9.10'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@@ -428,7 +428,9 @@ def random(shape, dtype="float32", type="uniform"):
jt.Var([[0.96788853 0.28334728 0.30482838]
[0.46107793 0.62798643 0.03457401]], dtype=float32)
'''
-
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
ret = ops.random(shape, "float32", type)
## TODO: move those code to core
#if dtype in ["float16", "bfloat16"]:
@@ -484,6 +486,9 @@ def ones(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(1, dtype).broadcast(shape)
def new_ones(x, size):
@@ -515,6 +520,9 @@ def zeros(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(0, dtype).broadcast(shape)
def new_zeros(x, size):
@@ -547,6 +555,9 @@ def full(shape,val,dtype="float32"):
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(val, dtype).broadcast(shape)
def new_full(x, size, val):
@@ -687,6 +698,8 @@ def flatten(input, start_dim=0, end_dim=-1):
start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function"
+ if len(in_shape) <= end_dim:
+ raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})")
out_shape = []
for i in range(0,start_dim,1): out_shape.append(in_shape[i])
dims = 1
@@ -917,6 +930,9 @@ def randn(*size, dtype="float32", requires_grad=True) -> Var:
[-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
+ for dim in size:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}")
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
@@ -1013,6 +1029,9 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
[1 1 1]], dtype=int32)
'''
if high is None: low, high = 0, low
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor_int(v)
return v.astype(dtype)
@@ -2152,3 +2171,7 @@ def inplace_wrapper(new_k, prev_func):
from . import math_util
from .math_util import *
from . import distributions
+
+if jt.compiler.has_acl:
+ from jittor.extern.acl.acl_compiler import change_function
+ change_function()
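A quick, hedged illustration of the shape checks introduced above (a minimal sketch assuming a build that contains these hunks; the expected messages follow the strings added in this file):

import jittor as jt

# negative dimensions are now rejected up front by ones/zeros/full/random/randn/randint
try:
    jt.zeros((2, -3))
except RuntimeError as e:
    print(e)  # Trying to create tensor with negative dimension -3: (2, -3)

# flatten now validates that end_dim is a real axis of the input
x = jt.ones((2, 3, 4))
try:
    jt.flatten(x, start_dim=0, end_dim=3)
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-3, 2], but got 3)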
diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py
index 6f88fad5..94d2e40b 100644
--- a/python/jittor/compatibility/__init__.py
+++ b/python/jittor/compatibility/__init__.py
@@ -1,430 +1,430 @@
-import os
-os.environ["FIX_TORCH_ERROR"] = "0"
-
-import jittor as jt
-from jittor import *
-from typing import Tuple
-
-org_int = int = type(1)
-org_float = float = type(1.0)
-org_bool = bool = type(True)
-
-import jtorch.compiler
-
-import jtorch_core
-from jtorch_core import *
-
-device.__reduce__ = lambda self: (device, (self.type,))
-device.__module__ = "jtorch"
-jt.jittor_core.device = device
-
-def handle_dtype(args, kw, dtype):
- def convert(x):
- if isinstance(x, jt.Var):
- return x.cast(dtype)
- return x
- if dtype is not None:
- if args is not None:
- if isinstance(args, (tuple,list)):
- args = [ convert(a) for a in args ]
- else:
- args = convert(x)
- if kw is not None:
- kw = { k:convert(v) for k,v in kw.items() }
- return args, kw
-
-def get_args_names(func):
- import inspect
- spec = inspect.getfullargspec(func)
- return spec[0] + spec[4]
-
-def wrapper(func):
- has_dtype = False
- if hasattr(func, "__code__"):
- has_dtype = "dtype" in get_args_names(func)
- def inner(*args, **kw):
- requires_grad = None
- dtype = None
- if "requires_grad" in kw:
- requires_grad = kw["requires_grad"]
- del kw["requires_grad"]
- if not has_dtype and "dtype" in kw:
- dtype = kw["dtype"]
- del kw["dtype"]
- if "device" in kw:
- del kw["device"]
- if 'pin_memory' in kw:
- del kw['pin_memory']
- args, kw = handle_dtype(args, kw, dtype)
- ret = func(*args, **kw)
- if isinstance(ret, jt.Var):
- if requires_grad is not None:
- ret.requires_grad = requires_grad
- if dtype is not None:
- ret.astype(dtype)
- return ret
- return inner
+# import os
+# os.environ["FIX_TORCH_ERROR"] = "0"
+
+# import jittor as jt
+# from jittor import *
+# from typing import Tuple
+
+# org_int = int = type(1)
+# org_float = float = type(1.0)
+# org_bool = bool = type(True)
+
+# import jtorch.compiler
+
+# import jtorch_core
+# from jtorch_core import *
+
+# device.__reduce__ = lambda self: (device, (self.type,))
+# device.__module__ = "jtorch"
+# jt.jittor_core.device = device
+
+# def handle_dtype(args, kw, dtype):
+# def convert(x):
+# if isinstance(x, jt.Var):
+# return x.cast(dtype)
+# return x
+# if dtype is not None:
+# if args is not None:
+# if isinstance(args, (tuple,list)):
+# args = [ convert(a) for a in args ]
+# else:
+# args = convert(x)
+# if kw is not None:
+# kw = { k:convert(v) for k,v in kw.items() }
+# return args, kw
+
+# def get_args_names(func):
+# import inspect
+# spec = inspect.getfullargspec(func)
+# return spec[0] + spec[4]
+
+# def wrapper(func):
+# has_dtype = False
+# if hasattr(func, "__code__"):
+# has_dtype = "dtype" in get_args_names(func)
+# def inner(*args, **kw):
+# requires_grad = None
+# dtype = None
+# if "requires_grad" in kw:
+# requires_grad = kw["requires_grad"]
+# del kw["requires_grad"]
+# if not has_dtype and "dtype" in kw:
+# dtype = kw["dtype"]
+# del kw["dtype"]
+# if "device" in kw:
+# del kw["device"]
+# if 'pin_memory' in kw:
+# del kw['pin_memory']
+# args, kw = handle_dtype(args, kw, dtype)
+# ret = func(*args, **kw)
+# if isinstance(ret, jt.Var):
+# if requires_grad is not None:
+# ret.requires_grad = requires_grad
+# if dtype is not None:
+# ret.astype(dtype)
+# return ret
+# return inner
-import inspect
-_wrapper_keys = set(["shape", "start", "size"])
-_wrapper_keys.add("x")
-for k,v in list(globals().items()):
- if callable(v) and not isinstance(v, type):
- try:
- spec = inspect.getfullargspec(v)
- args_name = spec[0]
- if len(args_name) and args_name[0] in _wrapper_keys:
- globals()[k] = wrapper(v)
- elif spec.varargs in _wrapper_keys:
- globals()[k] = wrapper(v)
- except:
- pass
-
-def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
- if len(size) == 1 and not isinstance(size[0], org_int):
- size = size[0]
- return jt.empty(size, dtype)
-
-Tensor = Var
-
-Tensor.backward = lambda x: jtorch_core.backward(x)
-Tensor.grad = property(grad_get, grad_set, grad_del)
-Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
-def retain_grad(x:Tensor, value:bool=True):
- x.retains_grad = value
- return value
-Tensor.retain_grad = retain_grad
-
-Tensor.dim = lambda self: self.ndim
-Tensor.ndimension = lambda self: self.ndim
-Tensor.nelement = lambda self: self.numel()
-Tensor.cuda = lambda self: self
-def device_get(x:Tensor):
- return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
-Tensor.device = property(device_get)
-
-def argmax(x: Var, dim=None, keepdim: bool = False):
- return jt.argmax(x, dim, keepdim)[0]
-Tensor.argmax = argmax
-
-def tensor_type(x: Var, dtype=None, **kwargs):
- if dtype:
- return x.astype(dtype)
- else:
- return x.dtype
-Tensor.type = tensor_type
-
-def is_floating_point(x: Var):
- return "float" in str(x.dtype)
-Tensor.is_floating_point = is_floating_point
-
-from . import autograd
-from .autograd import *
-
-def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
- if isinstance(data,list):
- data_list = []
- check = True
- for p in data:
- if isinstance(p, Tensor) and p.numel()==1:
- data_list.append(p.item())
- elif isinstance(p, (org_int,org_float)):
- data_list.append(p)
- else:
- check = False
- break
- if check:
- data = data_list
- return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)
-
-# tensor = wrapper(array)
-from_numpy = wrapper(array)
-strided = None
-
-def mod_zero_grad(self):
- for p in self.parameters():
- p.grad = None
-Module.zero_grad = mod_zero_grad
-
-class ModuleMisc:
- def parameters(self):
- return iter(super().parameters())
-
- def load_state_dict(self, state_dict, strict=False):
- return super().load_state_dict(state_dict)
-
- def to(self, device=None,dtype=None):
- ''' do nothing but return its self'''
- return self
- def register_parameter(self,name,data):
- self.name = data
-
- def buffers(self):
- for _, buf in self.named_buffers():
- yield buf
+# import inspect
+# _wrapper_keys = set(["shape", "start", "size"])
+# _wrapper_keys.add("x")
+# for k,v in list(globals().items()):
+# if callable(v) and not isinstance(v, type):
+# try:
+# spec = inspect.getfullargspec(v)
+# args_name = spec[0]
+# if len(args_name) and args_name[0] in _wrapper_keys:
+# globals()[k] = wrapper(v)
+# elif spec.varargs in _wrapper_keys:
+# globals()[k] = wrapper(v)
+# except:
+# pass
+
+# def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
+# if len(size) == 1 and not isinstance(size[0], org_int):
+# size = size[0]
+# return jt.empty(size, dtype)
+
+# Tensor = Var
+
+# Tensor.backward = lambda x: jtorch_core.backward(x)
+# Tensor.grad = property(grad_get, grad_set, grad_del)
+# Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
+# def retain_grad(x:Tensor, value:bool=True):
+# x.retains_grad = value
+# return value
+# Tensor.retain_grad = retain_grad
+
+# Tensor.dim = lambda self: self.ndim
+# Tensor.ndimension = lambda self: self.ndim
+# Tensor.nelement = lambda self: self.numel()
+# Tensor.cuda = lambda self: self
+# def device_get(x:Tensor):
+# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
+# Tensor.device = property(device_get)
+
+# def argmax(x: Var, dim=None, keepdim: bool = False):
+# return jt.argmax(x, dim, keepdim)[0]
+# Tensor.argmax = argmax
+
+# def tensor_type(x: Var, dtype=None, **kwargs):
+# if dtype:
+# return x.astype(dtype)
+# else:
+# return x.dtype
+# Tensor.type = tensor_type
+
+# def is_floating_point(x: Var):
+# return "float" in str(x.dtype)
+# Tensor.is_floating_point = is_floating_point
+
+# from . import autograd
+# from .autograd import *
+
+# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
+# if isinstance(data,list):
+# data_list = []
+# check = True
+# for p in data:
+# if isinstance(p, Tensor) and p.numel()==1:
+# data_list.append(p.item())
+# elif isinstance(p, (org_int,org_float)):
+# data_list.append(p)
+# else:
+# check = False
+# break
+# if check:
+# data = data_list
+# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)
+
+# # tensor = wrapper(array)
+# from_numpy = wrapper(array)
+# strided = None
+
+# def mod_zero_grad(self):
+# for p in self.parameters():
+# p.grad = None
+# Module.zero_grad = mod_zero_grad
+
+# class ModuleMisc:
+# def parameters(self):
+# return iter(super().parameters())
+
+# def load_state_dict(self, state_dict, strict=False):
+# return super().load_state_dict(state_dict)
+
+# def to(self, device=None,dtype=None):
+# ''' do nothing but return its self'''
+# return self
+# def register_parameter(self,name,data):
+# self.name = data
+
+# def buffers(self):
+# for _, buf in self.named_buffers():
+# yield buf
-def make_module(cls):
- class TMod(ModuleMisc, cls):
- def __init__(self, *args, **kw):
- dtype = None
- if "dtype" in kw:
- dtype = kw["dtype"]
- del kw["dtype"]
- self._dtype = dtype
- with jt.flag_scope(th_mode=0):
- if "device" in kw:
- del kw["device"]
- super().__init__(*args, **kw)
- for k,v in self.__dict__.items():
- if not k.startswith("_") and isinstance(v, Var) \
- and v.requires_grad:
- v.retain_grad()
- if dtype is not None and isinstance(v, Var):
- v.assign(v.cast(dtype))
- def __call__(self, *args, **kw):
- args, kw = handle_dtype(args, kw, self._dtype)
- # if forward is override by user, call forward
- if self.__class__.forward is not TMod.forward:
- return self.forward(*args, **kw)
- return self.execute(*args, **kw)
- def forward(self, *args, **kw):
- args, kw = handle_dtype(args, kw, self._dtype)
- return self.execute(*args, **kw)
+# def make_module(cls):
+# class TMod(ModuleMisc, cls):
+# def __init__(self, *args, **kw):
+# dtype = None
+# if "dtype" in kw:
+# dtype = kw["dtype"]
+# del kw["dtype"]
+# self._dtype = dtype
+# with jt.flag_scope(th_mode=0):
+# if "device" in kw:
+# del kw["device"]
+# super().__init__(*args, **kw)
+# for k,v in self.__dict__.items():
+# if not k.startswith("_") and isinstance(v, Var) \
+# and v.requires_grad:
+# v.retain_grad()
+# if dtype is not None and isinstance(v, Var):
+# v.assign(v.cast(dtype))
+# def __call__(self, *args, **kw):
+# args, kw = handle_dtype(args, kw, self._dtype)
+# # if forward is override by user, call forward
+# if self.__class__.forward is not TMod.forward:
+# return self.forward(*args, **kw)
+# return self.execute(*args, **kw)
+# def forward(self, *args, **kw):
+# args, kw = handle_dtype(args, kw, self._dtype)
+# return self.execute(*args, **kw)
- @property
- def training(self):
- if not hasattr(self, "is_train"):
- self.is_train = True
- return self.is_train
- @training.setter
- def training(self, value):
- self.is_train = value
-
- TMod.__name__ = cls.__name__
- return TMod
-
-import jtorch.cuda
-import jtorch.nn
-from jtorch.nn import Module, Parameter
-import jtorch.optim
-
-from jtorch.utils.dtype import Dtype, get_string_dtype
-
-def frombuffer(buffer: bytearray,
- *,
- dtype: Dtype,
- count: int = -1,
- offset: int = 0,
- requires_grad: bool = True) -> Tensor:
- dtype = get_string_dtype(dtype)
- tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
- if requires_grad and tensor.dtype.is_float():
- tensor.requires_grad = True
- return tensor
-
-def conflict_wrapper(origin_func, new_func):
- def wrapper(*args, **kw):
- if jt.flags.th_mode:
- return new_func(*args, **kw)
- else:
- return origin_func(*args, **kw)
- return wrapper
-
-def min(*args, **kw):
- dim = None
- if len(args) >= 2 and isinstance(args[1], org_int):
- dim = args[1]
- elif "dim" in kw and isinstance(kw["dim"], org_int):
- dim = kw["dim"]
- if dim is not None:
- k, v = jt.argmin(*args, **kw)
- return v, k
- elif len(args) == 2 and isinstance(args[1], jt.Var):
- return jt.minimum(args[0], args[1])
- else:
- return jt.min(*args, **kw)
-Tensor.min = conflict_wrapper(jt.min, min)
-
-def max(*args, **kw):
- dim = None
- if "dim" in kw:
- x = kw["dim"]
- if len(args) >= 2 and isinstance(args[1], org_int):
- dim = args[1]
- elif "dim" in kw and isinstance(kw["dim"], org_int):
- dim = kw["dim"]
- if dim is not None:
- k, v = jt.argmax(*args, **kw)
- return v, k
- elif len(args) == 2 and isinstance(args[1], jt.Var):
- return jt.maximum(args[0], args[1])
- else:
- return jt.max(*args, **kw)
-Tensor.max = conflict_wrapper(jt.max, max)
-
-def argsort(*args, **kw):
- k, v = jt.argsort(*args, **kw)
- return k
-Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
-
-LongTensor = jt.int64
-FloatTensor = jt.float
-HalfTensor = jt.float16
-BoolTensor = jt.bool
-IntTensor = jt.int32
-
-class JDType:
- def __init__(self, func, str):
- self.func = func
- self.str = str
- self.__name__ = str.split(".")[-1]
- def __call__(self, *args, **kw):
- return self.func(*args, **kw)
- def __str__(self):
- return self.str
- def is_floating_point(self):
- return "float" in str(self.str)
-
-int8 = JDType(jt.int8, "torch.int8")
-int16 = JDType(jt.int16, "torch.int16")
-int = int32 = JDType(jt.int32, "torch.int32")
-long = int64 = JDType(jt.int64, "torch.int64")
-
-half = float16 = JDType(jt.float16, "torch.float16")
-float = float32 = JDType(jt.float32, "torch.float32")
-double = float64 = JDType(jt.float64, "torch.float64")
-bfloat16 = "bfloat16" # TODO
-complex64 = "complex64" # TODO
-complex128 = "complex128" # TODO
-def get_JDtype(dtype):
- if dtype=='float32' or dtype == jt.float32:
- return float32
- elif dtype=='float64' or dtype == jt.float64:
- return float64
- elif dtype=='float16' or dtype == jt.float16:
- return float16
- elif dtype=='int32' or dtype == jt.int32:
- return int32
- elif dtype=='int64' or dtype == jt.int64:
- return int64
- elif dtype=='int16' or dtype == jt.int16:
- return int16
- elif dtype=='int8' or dtype == jt.int8:
- return int8
- else:
- raise Exception("dtype {} not supported".format(dtype))
-
-def load(path,**kwargs):
- def _to_jittor(data):
- if isinstance(data,dict):
- return {k:_to_jittor(d) for k,d in data.items()}
- if isinstance(data,list):
- return [_to_jittor(d) for d in data]
- if isinstance(data,np.ndarray):
- return jt.array(data)
- return data
- data = jt.load(path)
+# @property
+# def training(self):
+# if not hasattr(self, "is_train"):
+# self.is_train = True
+# return self.is_train
+# @training.setter
+# def training(self, value):
+# self.is_train = value
+
+# TMod.__name__ = cls.__name__
+# return TMod
+
+# import jtorch.cuda
+# import jtorch.nn
+# from jtorch.nn import Module, Parameter
+# import jtorch.optim
+
+# from jtorch.utils.dtype import Dtype, get_string_dtype
+
+# def frombuffer(buffer: bytearray,
+# *,
+# dtype: Dtype,
+# count: int = -1,
+# offset: int = 0,
+# requires_grad: bool = True) -> Tensor:
+# dtype = get_string_dtype(dtype)
+# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
+# if requires_grad and tensor.dtype.is_float():
+# tensor.requires_grad = True
+# return tensor
+
+# def conflict_wrapper(origin_func, new_func):
+# def wrapper(*args, **kw):
+# if jt.flags.th_mode:
+# return new_func(*args, **kw)
+# else:
+# return origin_func(*args, **kw)
+# return wrapper
+
+# def min(*args, **kw):
+# dim = None
+# if len(args) >= 2 and isinstance(args[1], org_int):
+# dim = args[1]
+# elif "dim" in kw and isinstance(kw["dim"], org_int):
+# dim = kw["dim"]
+# if dim is not None:
+# k, v = jt.argmin(*args, **kw)
+# return v, k
+# elif len(args) == 2 and isinstance(args[1], jt.Var):
+# return jt.minimum(args[0], args[1])
+# else:
+# return jt.min(*args, **kw)
+# Tensor.min = conflict_wrapper(jt.min, min)
+
+# def max(*args, **kw):
+# dim = None
+# if "dim" in kw:
+# x = kw["dim"]
+# if len(args) >= 2 and isinstance(args[1], org_int):
+# dim = args[1]
+# elif "dim" in kw and isinstance(kw["dim"], org_int):
+# dim = kw["dim"]
+# if dim is not None:
+# k, v = jt.argmax(*args, **kw)
+# return v, k
+# elif len(args) == 2 and isinstance(args[1], jt.Var):
+# return jt.maximum(args[0], args[1])
+# else:
+# return jt.max(*args, **kw)
+# Tensor.max = conflict_wrapper(jt.max, max)
+
+# def argsort(*args, **kw):
+# k, v = jt.argsort(*args, **kw)
+# return k
+# Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
+
+# LongTensor = jt.int64
+# FloatTensor = jt.float
+# HalfTensor = jt.float16
+# BoolTensor = jt.bool
+# IntTensor = jt.int32
+
+# class JDType:
+# def __init__(self, func, str):
+# self.func = func
+# self.str = str
+# self.__name__ = str.split(".")[-1]
+# def __call__(self, *args, **kw):
+# return self.func(*args, **kw)
+# def __str__(self):
+# return self.str
+# def is_floating_point(self):
+# return "float" in str(self.str)
+
+# int8 = JDType(jt.int8, "torch.int8")
+# int16 = JDType(jt.int16, "torch.int16")
+# int = int32 = JDType(jt.int32, "torch.int32")
+# long = int64 = JDType(jt.int64, "torch.int64")
+
+# half = float16 = JDType(jt.float16, "torch.float16")
+# float = float32 = JDType(jt.float32, "torch.float32")
+# double = float64 = JDType(jt.float64, "torch.float64")
+# bfloat16 = "bfloat16" # TODO
+# complex64 = "complex64" # TODO
+# complex128 = "complex128" # TODO
+# def get_JDtype(dtype):
+# if dtype=='float32' or dtype == jt.float32:
+# return float32
+# elif dtype=='float64' or dtype == jt.float64:
+# return float64
+# elif dtype=='float16' or dtype == jt.float16:
+# return float16
+# elif dtype=='int32' or dtype == jt.int32:
+# return int32
+# elif dtype=='int64' or dtype == jt.int64:
+# return int64
+# elif dtype=='int16' or dtype == jt.int16:
+# return int16
+# elif dtype=='int8' or dtype == jt.int8:
+# return int8
+# else:
+# raise Exception("dtype {} not supported".format(dtype))
+
+# def load(path,**kwargs):
+# def _to_jittor(data):
+# if isinstance(data,dict):
+# return {k:_to_jittor(d) for k,d in data.items()}
+# if isinstance(data,list):
+# return [_to_jittor(d) for d in data]
+# if isinstance(data,np.ndarray):
+# return jt.array(data)
+# return data
+# data = jt.load(path)
- return _to_jittor(data)
+# return _to_jittor(data)
-def is_tensor(x):
- return isinstance(x, Tensor)
+# def is_tensor(x):
+# return isinstance(x, Tensor)
-manual_seed = jt.set_global_seed
-jt.flags.amp_level = 3
-Size = jt.NanoVector
+# manual_seed = jt.set_global_seed
+# jt.flags.amp_level = 3
+# Size = jt.NanoVector
-class Generator:
- def __init__(self,*args,**kw) -> None:
- self.seed = None
- def manual_seed(self,seed):
- self.seed = seed
+# class Generator:
+# def __init__(self,*args,**kw) -> None:
+# self.seed = None
+# def manual_seed(self,seed):
+# self.seed = seed
-from . import fx
+# from . import fx
-_default_type = "float32"
+# _default_type = "float32"
-def get_default_dtype():
- return _default_type
-def set_default_dtype(dtype):
- global _default_type
- _default_type = dtype
+# def get_default_dtype():
+# return _default_type
+# def set_default_dtype(dtype):
+# global _default_type
+# _default_type = dtype
-dtype = JDType
+# dtype = JDType
-def div(x,y,rounding_mode="floor"):
- assert rounding_mode == "floor"
- z = (x / y)
- if rounding_mode == "floor":
- z = z.floor()
- if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"):
- z = z.int32()
- return z
+# def div(x,y,rounding_mode="floor"):
+# assert rounding_mode == "floor"
+# z = (x / y)
+# if rounding_mode == "floor":
+# z = z.floor()
+# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"):
+# z = z.int32()
+# return z
-def randn(*args,**kw):
- wrap_randn = wrapper(jt.randn)
- generator = kw.get('generator',None)
- kw.pop('generator',None)
- if 'layout' in kw:
- del kw['layout']
- if generator is not None and generator.seed is not None:
- jt.set_global_seed(generator.seed)
- return wrap_randn(*args,**kw)
+# def randn(*args,**kw):
+# wrap_randn = wrapper(jt.randn)
+# generator = kw.get('generator',None)
+# kw.pop('generator',None)
+# if 'layout' in kw:
+# del kw['layout']
+# if generator is not None and generator.seed is not None:
+# jt.set_global_seed(generator.seed)
+# return wrap_randn(*args,**kw)
-def rand(*args,**kw):
- print("rand")
- wrap_rand = wrapper(jt.rand)
- generator = kw.get('generator',None)
- kw.pop('generator',None)
- if 'layout' in kw:
- del kw['layout']
- if generator is not None and generator.seed is not None:
- jt.set_global_seed(generator.seed)
- return wrap_rand(*args,**kw)
+# def rand(*args,**kw):
+# print("rand")
+# wrap_rand = wrapper(jt.rand)
+# generator = kw.get('generator',None)
+# kw.pop('generator',None)
+# if 'layout' in kw:
+# del kw['layout']
+# if generator is not None and generator.seed is not None:
+# jt.set_global_seed(generator.seed)
+# return wrap_rand(*args,**kw)
-def set_default_tensor_type(t: type or str):
- if isinstance(t, str):
- info = t.split(".")
- if len(info) == 3 and info[1] == 'cuda':
- jt.flags.use_cuda = 1
- #TODO: type
+# def set_default_tensor_type(t: type or str):
+# if isinstance(t, str):
+# info = t.split(".")
+# if len(info) == 3 and info[1] == 'cuda':
+# jt.flags.use_cuda = 1
+# #TODO: type
-def clamp(x, min=None, max=None):
- return jt.clamp(x, min, max)
+# def clamp(x, min=None, max=None):
+# return jt.clamp(x, min, max)
-def to(x,*args,**kw):
- device = None
- if len(args) == 1:
- device = args[0]
- if isinstance(device, jt.NanoString) or callable(device):
- return jt.to(x,*args,**kw)
- if 'cpu' in str(device):
- args = []
- device = kw.get("device",None)
- if 'cpu' in str(device):
- kw.pop('device',None)
- print("to cpu")
- # print(kw)
- return jt.to(x,*args,**kw)
-Tensor.to = conflict_wrapper(jt.to, to)
+# def to(x,*args,**kw):
+# device = None
+# if len(args) == 1:
+# device = args[0]
+# if isinstance(device, jt.NanoString) or callable(device):
+# return jt.to(x,*args,**kw)
+# if 'cpu' in str(device):
+# args = []
+# device = kw.get("device",None)
+# if 'cpu' in str(device):
+# kw.pop('device',None)
+# print("to cpu")
+# # print(kw)
+# return jt.to(x,*args,**kw)
+# Tensor.to = conflict_wrapper(jt.to, to)
-mm = wrapper(jt.matmul)
+# mm = wrapper(jt.matmul)
-def _data_get(x):
- return x
+# def _data_get(x):
+# return x
-def _data_set(x, value):
- x.assign(value)
+# def _data_set(x, value):
+# x.assign(value)
-Tensor.data = property(_data_get, _data_set)
-Tensor.layout = None
\ No newline at end of file
+# Tensor.data = property(_data_get, _data_set)
+# Tensor.layout = None
\ No newline at end of file
diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py
index 71946a23..5fcfcaa6 100644
--- a/python/jittor/compatibility/utils/data.py
+++ b/python/jittor/compatibility/utils/data.py
@@ -99,39 +99,39 @@ def inner_iter(self):
current_batch = self.collate_batch(current_batch)
yield self.to_jittor(current_batch)
-def get_worker_info():
- # always return the fake worker info
- return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
-
-class RandomSampler(jt.dataset.RandomSampler):
- def __init__(self, dataset, generator=None, **kwargs):
- super().__init__(dataset, **kwargs)
-
- def __iter__(self):
- if getattr(self.dataset, "support_random_access", True):
- return super().__iter__()
- else:
- self.dataset.shuffle()
- return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
-
-class DistributedSampler(jt.dataset.Sampler):
- def __init__(self, sampler: RandomSampler):
- assert(isinstance(sampler, RandomSampler))
- self.sampler = sampler
-
- def set_epoch(self, epoch: int):
- ### do nothing, let jittor's inner dataset handle
- pass
-
- def __iter__(self):
- return self.sampler.__iter__()
+# def get_worker_info():
+# # always return the fake worker info
+# return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
+
+# class RandomSampler(jt.dataset.RandomSampler):
+# def __init__(self, dataset, generator=None, **kwargs):
+# super().__init__(dataset, **kwargs)
+
+# def __iter__(self):
+# if getattr(self.dataset, "support_random_access", True):
+# return super().__iter__()
+# else:
+# self.dataset.shuffle()
+# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
+
+# class DistributedSampler(jt.dataset.Sampler):
+# def __init__(self, sampler: RandomSampler):
+# assert(isinstance(sampler, RandomSampler))
+# self.sampler = sampler
+
+# def set_epoch(self, epoch: int):
+# ### do nothing, let jittor's inner dataset handle
+# pass
+
+# def __iter__(self):
+# return self.sampler.__iter__()
- def __len__(self):
- return self.sampler.__len__()
+# def __len__(self):
+# return self.sampler.__len__()
-BatchSampler = jt.dataset.BatchSampler
-Sampler = jt.dataset.Sampler
-SequentialSampler = jt.dataset.SequentialSampler
-SubsetRandomSampler = jt.dataset.SubsetRandomSampler
+# BatchSampler = jt.dataset.BatchSampler
+# Sampler = jt.dataset.Sampler
+# SequentialSampler = jt.dataset.SequentialSampler
+# SubsetRandomSampler = jt.dataset.SubsetRandomSampler
-TensorDataset = Dataset
+# TensorDataset = Dataset
diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py
index 31dd72de..c4026eee 100644
--- a/python/jittor/contrib.py
+++ b/python/jittor/contrib.py
@@ -15,6 +15,8 @@
def argmax_pool(x, size, stride, padding=0):
+ if stride<=0:
+ raise RuntimeError(f"stride must be > 0, but got {stride}")
return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
@@ -241,9 +243,17 @@ def concat(arr, dim=0):
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
- if dim < 0: dim += len(arr[0].shape)
+ base_dim = len(arr[0].shape)
+ if dim < 0: dim += base_dim
+ if dim < 0 or dim >= base_dim:
+ raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})")
dtypes = []
for a in arr:
+ if len(a.shape) != base_dim:
+ raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}")
+ for i in range(base_dim):
+ if i != dim and a.shape[i] != arr[0].shape[i]:
+ raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.")
total_dim += a.shape[dim]
dtypes.append(str(a.dtype))
cdim = 0
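A short sketch of the stricter concat validation above (assuming the checks are reached through jittor.contrib.concat as shown; error texts follow the messages added in this hunk):

import jittor as jt
from jittor import contrib

a = jt.ones((2, 3))
b = jt.ones((2, 4))

try:
    contrib.concat([a, b], dim=2)       # only dims in [-2, 1] are valid for 2-D inputs
except IndexError as e:
    print(e)

try:
    contrib.concat([a, b], dim=0)       # sizes differ in dim 1, which is not the concat dim
except RuntimeError as e:
    print(e)

print(contrib.concat([a, b], dim=1).shape)  # concatenating along dim 1 still works: [2,7]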
diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py
index 7aa1d883..f0945f94 100644
--- a/python/jittor/dataset/mnist.py
+++ b/python/jittor/dataset/mnist.py
@@ -26,7 +26,7 @@ class MNIST(Dataset):
[in] data_root(str): your data root.
[in] train(bool): choose model train or val.
- [in] download(bool): Download data automatically if download is Ture.
+ [in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
@@ -106,7 +106,7 @@ class EMNIST(Dataset):
[in] data_root(str): your data root.
[in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
[in] train(bool): choose model train or val.
- [in] download(bool): Download data automatically if download is Ture.
+ [in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
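The docstrings above list the dataset constructor arguments; a typical use sketched from them (data paths, batch size and the transform choice are illustrative only, not part of this patch):

import jittor.transform as transform
from jittor.dataset.mnist import MNIST

train_loader = MNIST(train=True, download=True, batch_size=64, shuffle=True,
                     transform=transform.Resize(28))
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break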
diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py
index fea9cfce..0ab4b6eb 100644
--- a/python/jittor/extern/acl/acl_compiler.py
+++ b/python/jittor/extern/acl/acl_compiler.py
@@ -10,6 +10,7 @@
import ctypes
import glob
import jittor.compiler as compiler
+import jittor as jt
has_acl = 0
cc_flags = ""
@@ -34,16 +35,18 @@
# export DUMP_GRAPH_LEVEL=1
# build pytorch-npu
-# bash ./ci/build.sh
-# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
+# bash ./ci/build.sh
+# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark
# python3 ./mm_bench_pt_npu.py
+
def install():
import jittor.compiler as compiler
global has_acl, cc_flags
acl_compiler_home = os.path.dirname(__file__)
- cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
+ cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc",
+ recursive=True))
cc_files2 = []
for name in cc_files:
if "acl_op_exec" in name:
@@ -52,20 +55,22 @@ def install():
cc_files2.append(name)
cc_files = cc_files2
cc_flags += f" -DHAS_CUDA -DIS_ACL \
- -I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \
- -L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \
+ -I/usr/local/Ascend/ascend-toolkit/latest/include/ \
+ -L/usr/local/Ascend/ascend-toolkit/latest/lib64/ \
-I{acl_compiler_home} -lascendcl -lacl_op_compiler "
+
ctypes.CDLL("libascendcl.so", dlopen_flags)
'''
-ltikc_runtime
- -I/usr/local/Ascend/driver/include \
- -L/usr/local/Ascend/compiler/lib64 \
- -L/usr/local/Ascend/runtime/lib64 \
+ -I/usr/local/Ascend/driver/include/ \
+ -L/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/ \
+ -L/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/ \
'''
jittor_utils.LOG.i("ACL detected")
global mod
- mod = jittor_utils.compile_module('''
+ mod = jittor_utils.compile_module(
+ '''
#include "common.h"
namespace jittor {
// @pyjt(process)
@@ -98,9 +103,10 @@ def check():
if not has_acl: return False
compiler.cc_flags += cc_flags
compiler.nvcc_path = tikcc_path
- compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
+ compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "")
return True
+
def post_process():
if has_acl:
from jittor import pool
@@ -108,5 +114,711 @@ def post_process():
import jittor as jt
jt.flags.use_cuda_host_allocator = 1
jt.flags.use_parallel_op_compiler = 0
- jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
- mod.init_acl_ops()
\ No newline at end of file
+ jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
+ mod.init_acl_ops()
+
+
+def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list,
+ attr: dict):
+
+ input_code = ''
+ for i in range(len(inputs)):
+ if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1':
+ input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n"
+ else:
+ input_code += f"op.add(in{i}, true);\n"
+
+ output_code = ''
+ for i in range(len(output_dtypes)):
+        if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1':
+ output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n"
+ else:
+ output_code += f"op.add(out{i}, false);\n"
+
+ # add attr to op
+ attr_code = ''
+ if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1':
+ for k, v in attr.items():
+ if k == "ceil_mode":
+ v = 'false' if v == False else 'true'
+ attr_code += f"op.set_attr(\"{k}\", {v});\n"
+ else:
+ v = str(v).replace('[', '{').replace(']', '}')
+ attr_code += f"op.set_attr(\"{k}\", vector{v});\n"
+ else:
+ for k, v in attr.items():
+ if isinstance(v, bool):
+ attr_code += f"op.set_attr(\"{k}\", {str(v).lower()});\n"
+ elif isinstance(v, str):
+ attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n"
+ else:
+ attr_code += f"op.set_attr(\"{k}\", int({v}));\n"
+
+ import jittor as jt
+ return jt.code(
+ output_shapes,
+ output_dtypes,
+ inputs,
+ cuda_header="""
+ #include
+ #include
+ #include
+ #include
+
+ namespace jittor {
+
+    void printDeviceData(const vector<aclTensorDesc*>& output_desc, const vector<aclDataBuffer*>& output_data, const string& name = "", bool input=true) {
+ LOGir << "name: " << name;
+ if(input)
+ LOGir << "is input";
+ else
+            LOGir << "is output";
+ for (size_t i = 0; i < output_desc.size(); ++i) {
+ void* base_addr = aclGetDataBufferAddr(output_data[i]);
+ LOGir << "addr of data[" << i << "] :" << base_addr;
+ size_t num_dims = aclGetTensorDescNumDims(output_desc[i]);
+ size_t total_size = 1;
+            std::vector<int64_t> dims(num_dims);
+
+ std::cout << "shape of data: ";
+ for (size_t j = 0; j < num_dims; ++j) {
+ aclGetTensorDescDimV2(output_desc[i], j, &dims[j]);
+ total_size *= dims[j];
+ std::cout << dims[j] << ", ";
+ }
+ int evey_batch_size = total_size/dims[0];
+ std::cout << std::endl;
+
+ // for(int i= 0; i < dims[0]; i++) {
+ // evey_batch_size = 16;
+ // std::vector host_buffer(evey_batch_size);
+ // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float);
+ // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+ // std::cout << "batch[" << i << "]:";
+ // for (size_t k = 0; k < evey_batch_size; ++k) {
+ // std::cout << host_buffer[k] << ", ";
+ // }
+ // std::cout << std::endl;
+ // }
+ }
+ }
+
+ struct AclOpRunner {
+ string name;
+        vector<aclTensorDesc*> input_desc;
+        vector<aclTensorDesc*> output_desc;
+        vector<aclDataBuffer*> input_data;
+        vector<aclDataBuffer*> output_data;
+        aclopAttr *attr;
+        vector<vector<uint64>> input_host;
+        vector<vector<int>> input_host_32;
+
+ AclOpRunner(const string& name) : name(name) {
+ attr = aclopCreateAttr();
+ }
+
+ ~AclOpRunner() {
+ for (auto i : input_desc) aclDestroyTensorDesc(i);
+ for (auto i : output_desc) aclDestroyTensorDesc(i);
+ for (auto i : input_data) aclDestroyDataBuffer(i);
+ for (auto i : output_data) aclDestroyDataBuffer(i);
+ aclopDestroyAttr(attr);
+ }
+
+ aclDataType get_dtype(NanoString s) {
+ if (s == ns_float32) return ACL_FLOAT;
+ if (s == ns_float16) return ACL_FLOAT16;
+ if (s == ns_int64) return ACL_INT64;
+ if (s == ns_int32) return ACL_INT32;
+ if (s == ns_int8) return ACL_INT8;
+ if (s == ns_int16) return ACL_INT16;
+ if (s == ns_uint8) return ACL_UINT8;
+ if (s == ns_uint16) return ACL_UINT16;
+ if (s == ns_uint32) return ACL_UINT32;
+ if (s == ns_bool) return ACL_BOOL;
+ LOGf << "Not supported dtype: " << s;
+ return ACL_FLOAT;
+ }
+
+ void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) {
+ int64_t shape[v->shape.size()];
+            for (int i=0; i<v->shape.size(); i++) shape[i] = v->shape[i];
+
+ auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format);
+ aclSetTensorFormat(desc, (aclFormat)format);
+ aclSetTensorShape(desc, v->shape.size(), &shape[0]);
+ LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format;
+ auto data = aclCreateDataBuffer(v->mem_ptr, v->size);
+ LOGv << "aclCreateDataBuffer" << v->mem_ptr << v->size;
+ ASSERT(desc && data);
+ if (is_input) {
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ } else {
+ output_desc.push_back(desc);
+ output_data.push_back(data);
+ }
+ }
+
+        void add_input_host(vector<uint64> v, int dtype=ACL_UINT64) {
+ int64_t shape[1];
+ shape[0] = v.size();
+ auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND);
+ aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT);
+ LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND;
+ auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64));
+ ASSERT(desc && data);
+ LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64);
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ input_host.emplace_back(move(v));
+ LOGv << "move" << input_host.back().data();
+ }
+
+        void add_input_host_scalar(vector<uint64> v, int dtype=ACL_UINT32) {
+ int64_t shape[1];
+ shape[0] = v.size();
+ auto x = (int*)&v[0];
+ x[0] = (int32)v[0];
+ auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND);
+ aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT);
+ LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND;
+ auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32));
+ ASSERT(desc && data);
+ LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32);
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ input_host.emplace_back(move(v));
+ }
+
+ void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) {
+            vector<uint64> v(nv.size());
+            for (int i=0; i<nv.size(); i++) v[i] = nv[i];
+            add_input_host(move(v), dtype);
+        }
+
+        void set_attr(const string& key, vector<int64> value) {
+ CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0);
+ }
+ void set_attr(const string& key, string value) {
+ CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0);
+ }
+ void set_attr(const char* key, const char* value) {
+ CHECK(aclopSetAttrString(attr, key, value)==0);
+ }
+
+ void run() {
+ // printDeviceData(input_desc, input_data, name);
+
+ LOGv << "run" << name << input_desc.size() << output_desc.size();
+ if (!PyGILState_Check()) {
+ ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream));
+ } else {
+ int ret;
+ Py_BEGIN_ALLOW_THREADS
+ ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream);
+ Py_END_ALLOW_THREADS
+ if (ret != 0)
+ LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret;
+ }
+ ASSERT(0==aclrtSynchronizeDevice());
+
+ // printDeviceData(output_desc, output_data, name, false);
+ }
+ };
+
+ }
+ """,
+ cuda_src=f"""
+ // aclop
+ AclOpRunner op("{name}");
+ {input_code}
+ {output_code}
+ {attr_code}
+ op.run();"""
+ )
+
+
+def change_function():
+ import jittor as jt
+ from jittor import Function
+
+ class IndexACL(Function):
+ def __init__(self):
+ super(IndexACL, self).__init__()
+
+ def execute(self, inshape: list, dim: int, dtype="int32"):
+ # zeros a tensor, shape is inshape, dtype is dtype
+ max_len = inshape[dim]
+ tmp = jt.zeros(max_len, dtype=dtype)
+ result = acl_cmd(
+ "Range",
+ [jt.Var(0), jt.Var(max_len), jt.Var(1)],
+ output_dtypes=[tmp.dtype],
+ output_shapes=[tmp.shape],
+ attr={})[0]
+ broadcast_dim = []
+ for i in range(len(inshape)):
+ if i != dim:
+ broadcast_dim.append(i)
+ result = jt.broadcast(result, shape=inshape, dims=broadcast_dim)
+ return result
+
+ def grad(self, grad_output):
+ return grad_output
+
+
+ class PoolACL(Function):
+
+ def __init__(self,
+ kernel_size,
+ stride=None,
+ padding=0,
+ dilation=None,
+ return_indices=None,
+ ceil_mode=False,
+ op='maximum'):
+ super(PoolACL, self).__init__()
+ import jittor as jt
+ # set attr
+ self.kernel_size = kernel_size if isinstance(
+ kernel_size, tuple) else (kernel_size, kernel_size)
+ stride = stride if stride else kernel_size
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride)
+ self.padding = padding if isinstance(padding, tuple) else (padding,
+ padding)
+ dilation = dilation if dilation else 1
+ self.dilation = dilation if isinstance(dilation, tuple) else (dilation,
+ dilation)
+ attr = {}
+ attr['ksize'] = [1, self.kernel_size[0], self.kernel_size[1], 1]
+ attr['strides'] = [1, self.stride[0], self.stride[1], 1]
+ attr['pads'] = [1, self.padding[0], self.padding[1], 1]
+ attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1]
+ attr['ceil_mode'] = ceil_mode
+
+ self.attr = attr
+ self.return_indices = return_indices
+ self.uint16 = jt.Var(1).int32().dtype
+ self.op = op
+
+ def execute(self, input):
+
+ # create input
+ input_shape = input.shape
+ input_dtype = input.dtype
+
+ self.input = input
+
+ # create output
+ output_shape = [
+ input_shape[0], input_shape[1],
+ (input_shape[2] + 2 * self.padding[0] - self.dilation[0] *
+ (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1,
+ (input_shape[3] + 2 * self.padding[1] - self.dilation[1] *
+ (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
+ ]
+ output_dtype = input_dtype
+ result = acl_cmd("MaxPoolWithArgmaxV1", [input],
+ output_dtypes=[output_dtype, self.uint16],
+ output_shapes=[output_shape, output_shape],
+ attr=self.attr)
+ self.index = result[1]
+ if self.return_indices:
+ return result[0], result[1]
+ else:
+ return result[0]
+
+ def grad(self, grad_output):
+
+ grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.input.shape],
+ attr=self.attr)[0]
+ return grad_input
+
+ class BmmACL(Function):
+ def __init__(self, adj_x1=False, adj_x2=False):
+ super(BmmACL, self).__init__()
+ self.adj_x1 = adj_x1
+ self.adj_x2 = adj_x2
+
+ def execute(self, x1, x2):
+ self.input = [x1, x2]
+ result = acl_cmd("BatchMatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ x1, x2 = self.input
+ grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ return grad_x1, grad_x2
+
+ class MatmulACL(Function):
+ def __init__(self, adj_x1=False, adj_x2=False):
+ super(MatmulACL, self).__init__()
+ self.adj_x1 = adj_x1
+ self.adj_x2 = adj_x2
+
+ def execute(self, x1, x2):
+ self.input = [x1, x2]
+ if len(x1.shape) > 2 or len(x2.shape) > 2:
+ result = acl_cmd("BatchMatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ else:
+ result = acl_cmd("MatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ x1, x2 = self.input
+ if len(x1.shape) > 2 or len(x2.shape) > 2:
+ grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ else:
+ grad_x1 = acl_cmd("MatMul", [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd("MatMul", [x1.transpose(-2, -1) , grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ return grad_x1, grad_x2
+
+ class GetItem(Function):
+ def __init__(self):
+ super(GetItem, self).__init__()
+
+ def stride(self, x, dim):
+ stride = 1
+ for i in range(dim+1, len(x.shape)):
+ stride *= x.shape[i]
+ return stride
+
+ def execute(self, x, slices):
+ if isinstance(slices, jt.Var) or isinstance(slices, tuple):
+ if isinstance(slices, jt.Var):
+ slices = (slices,)
+ if isinstance(slices[0], jt.Var):
+ slices_len = len(slices)
+ masks = jt.ones(slices_len, dtype=jt.int64)
+
+ output = slices[0].shape
+ output += x.shape[slices_len:]
+
+ input_ = [x, masks, jt.Var(list(output)).int64()]
+ for i in range(slices_len):
+ input_.append(slices[i].int32())
+
+ result = acl_cmd("Index", input_,
+ output_dtypes=[x.dtype],
+ output_shapes=[output],
+ attr={})[0]
+ return result
+
+ # use AsStrided operator to implement the getitem function
+ # get the shape and stride of the input tensor
+ x_dim = len(x.shape)
+
+ if not isinstance(slices, tuple):
+ slices = (slices,)
+
+ if len(slices) < x_dim:
+ slices += (slice(None, None, None),) * (x_dim - len(slices))
+
+ self.inputs = [x, slices]
+
+ sizes = []
+ strides = []
+ offset = 0
+
+ for dim, s in enumerate(slices):
+ if isinstance(s, int):
+ if s < 0: # Handle negative indices.
+ s += x.shape[dim]
+ offset += s * self.stride(x, dim)
+ elif isinstance(s, slice):
+ # Unpack the slice
+ start, stop, step = s.indices(x.size(dim))
+ size = (stop - start - 1) // step + 1
+ stride = self.stride(x, dim) * step
+ offset += start * self.stride(x, dim)
+ sizes.append(size)
+ strides.append(stride)
+ else:
+ raise ValueError("Invalid slice type")
+
+ if not sizes:
+ sizes = [1]
+ strides = [0]
+ # AsStrided same with as_strided of pytorch
+ self.sizes = sizes
+ self.strides = strides
+ self.offset = offset
+ self.shape = x.shape
+ result = acl_cmd("AsStrided", [x, jt.Var(sizes), jt.Var(strides), jt.Var(offset)],
+ output_dtypes=[x.dtype],
+ output_shapes=[jt.empty(sizes).shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ result = jt.zeros(self.shape, dtype=grad_output.dtype)
+ sizes = list(grad_output.shape)
+ strides = [self.stride(grad_output, dim) for dim in range(len(grad_output.shape))]
+ result = acl_cmd("ViewCopy", [result, jt.Var(self.sizes), jt.Var(self.strides), jt.Var(self.offset),
+ grad_output, jt.Var(sizes), jt.Var(strides), jt.Var(0)],
+ output_dtypes=[result.dtype],
+ output_shapes=[result.shape],
+ attr={})[0]
+ result.sync()
+ return result, None
+
+ class ConcatACL(Function):
+ def __init__(self):
+ super(ConcatACL, self).__init__()
+
+ def execute(self, input_tensors, dim=0):
+            self.input = input_tensors
+            self.axis = dim
+ for i in range(len(input_tensors)):
+ if input_tensors[i].dtype != input_tensors[0].dtype:
+ raise ValueError("All input tensors must have the same dtype")
+ if input_tensors[i].shape[:dim] != input_tensors[0].shape[:dim] or input_tensors[i].shape[dim+1:] != input_tensors[0].shape[dim+1:]:
+ raise ValueError("All input tensors must have the same shape")
+ result = acl_cmd("ConcatD", input_tensors,
+ output_dtypes=[input_tensors[0].dtype],
+ output_shapes=[jt.empty(self.calculate_output_shape(input_tensors, dim)).shape],
+ attr={
+ "N": len(input_tensors),
+ "concat_dim": dim
+ })[0]
+ return result
+
+ def grad(self, grad_output):
+ grad_inputs = self.split_grad(grad_output, self.input, self.axis)
+ return grad_inputs
+
+ def calculate_output_shape(self, input_tensors, axis):
+ shape = list(input_tensors[0].shape)
+ for tensor in input_tensors[1:]:
+ shape[axis] += tensor.shape[axis]
+ return tuple(shape)
+
+ def split_grad(self, grad_output, input_tensors, axis):
+ offset = 0
+ grad_inputs = []
+ for tensor in input_tensors:
+ grad_input = acl_cmd("Slice",
+ [grad_output,
+ [0]*axis + [offset] + [0]*(len(tensor.shape)-axis-1),
+ tensor.shape]
+ )
+ grad_inputs.append(grad_input)
+ offset += tensor.shape[axis]
+ return grad_inputs
+
+
+ class SetItemACL(Function):
+ def __init__(self):
+ super(SetItemACL, self).__init__()
+
+ def stride(self, x, dim):
+            # compute the stride of the given dimension
+ stride = 1
+ for i in range(dim + 1, len(x.shape)):
+ stride *= x.shape[i]
+ return stride
+
+ def execute(self, x, slices, value):
+ self.is_tensor = type(value) == jt.Var
+ if type(value) != jt.Var:
+ value = jt.array(value)
+ x_dim = len(x.shape)
+
+            # make sure slices is a tuple
+ if not isinstance(slices, tuple):
+ slices = (slices,)
+
+            # pad slices so its length equals the number of dimensions of x
+ if len(slices) < x_dim:
+ slices += (slice(None, None, None),) * (x_dim - len(slices))
+
+ self.inputs = [x, slices, value]
+
+ target_sizes = []
+ target_strides = []
+ offset = 0
+
+ for dim, s in enumerate(slices):
+ if isinstance(s, int):
+ if s < 0:
+ s += x.shape[dim]
+ s = slice(s, s+1, None)
+ if isinstance(s, slice):
+                    # unpack the slice
+ start, stop, step = s.indices(x.shape[dim])
+ size = (stop - start - 1) // step + 1
+ stride = self.stride(x, dim) * step
+ offset += start * self.stride(x, dim)
+ target_sizes.append(size)
+ target_strides.append(stride)
+ else:
+ print("slices: ", s, type(s))
+ raise ValueError("Invalid slice type")
+
+            # compute the size, stride and offset of value
+ value_sizes = list(value.shape)
+ value_strides = [self.stride(value, dim) for dim in range(len(value.shape))]
+
+ self.target_sizes = target_sizes
+ self.target_strides = target_strides
+ self.offset = offset
+ self.value_sizes = value_sizes
+ self.value_strides = value_strides
+
+ #import pdb; pdb.set_trace()
+ result = acl_cmd("ViewCopy", [x, jt.Var(target_sizes), jt.Var(target_strides), jt.Var(offset),
+ value, jt.Var(value_sizes), jt.Var(value_strides), jt.Var(0)],
+ output_dtypes=[x.dtype],
+ output_shapes=[x.shape],
+ attr={})[0]
+ result.sync()
+ return result
+
+ def grad(self, grad_output):
+ result = acl_cmd("AsStrided", [grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset)],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[jt.empty(self.target_sizes).shape],
+ attr={})[0]
+ # copy grad_output to new_grad_output
+ new_grad_output = acl_cmd("Copy", [grad_output],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={ "N": 1 })[0]
+ new_grad_output = acl_cmd("ViewCopy", [new_grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset),
+ jt.zeros(self.value_sizes, dtype=grad_output.dtype), jt.Var(self.value_sizes), jt.Var(self.value_strides), jt.Var(0)],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ new_grad_output.sync()
+ return new_grad_output, None, result if self.is_tensor else None
+
+ class TriuACL(Function):
+ def __init__(self):
+ super(TriuACL, self).__init__()
+
+ def execute(self, input, k):
+ self.input = input
+ result = acl_cmd("Triu", [input],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={'diagonal': k})[0]
+ return result
+
+ def grad(self, grad_output):
+ return grad_output
+
+ class TransposeACL(Function):
+ def __init__(self):
+ super(TransposeACL, self).__init__()
+
+ def execute(self, input, perm):
+ self.input = input
+
+ output_shape = input.shape[perm[0]:perm[0]+1]
+
+ for i in range(1, len(perm)):
+ output_shape += input.shape[perm[i]:perm[i]+1]
+ result = acl_cmd("Transpose", [input, jt.Var(perm)],
+ output_dtypes=[input.dtype],
+ output_shapes=[output_shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ return grad_output
+
+ def warp(origin_func, new_func):
+ def warpper(*args, **kwargs):
+ if jt.flags.use_acl:
+ return new_func(*args, **kwargs)
+ if origin_func == jt.index:
+ if len(args) == 2 and args[1] == None:
+ args = tuple(list(args[0:1]))
+ return origin_func(*args, **kwargs)
+ return warpper
+
+ jt.index = warp(jt.index, IndexACL())
+ jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim)
+ jt.nn.Pool = warp(jt.nn.Pool, PoolACL)
+
+ jt.triu = warp(jt.triu, TriuACL())
+ jt.triu_ = warp(jt.triu, TriuACL())
+ jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x)
+ jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x)
+
+ jt.getitem = warp(jt.getitem, GetItem())
+ jt.Var.getitem = lambda x, slices: warp(jt.getitem, GetItem())(x, slices)
+ jt.setitem = warp(jt.setitem, SetItemACL())
+ jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, SetItemACL())(x, slices, value)
+
+ # jt.nn.bmm = warp(jt.nn.bmm, BmmACL())
+ # jt.bmm = warp(jt.bmm, BmmACL())
+ # jt.nn.matmul = warp(jt.matmul, MatmulACL())
+ # jt.matmul = warp(jt.matmul, MatmulACL())
+ # jt.transpose = warp(jt.transpose, TransposeACL())
+ # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm)
+ # jt.concat = warp(jt.concat, ConcatACL())
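All of the overrides registered above go through the same small dispatcher; a self-contained sketch of that pattern (acl_impl is a hypothetical stand-in for the Function objects defined above and is not part of the patch):

import jittor as jt

def warp(origin_func, new_func):
    # mirror of the helper above: use the ACL implementation only when ACL is active
    def wrapper(*args, **kwargs):
        if getattr(jt.flags, "use_acl", 0):
            return new_func(*args, **kwargs)
        return origin_func(*args, **kwargs)
    return wrapper

def acl_impl(arr, dim=0):
    # placeholder: a real implementation would issue an acl_cmd(...) call on an Ascend device
    raise NotImplementedError

concat = warp(jt.concat, acl_impl)
print(concat([jt.ones((2, 2)), jt.zeros((2, 2))], dim=0))  # falls back to jt.concat when use_acl is off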
diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc
index dc99887f..07639ab0 100644
--- a/python/jittor/extern/acl/acl_jittor.cc
+++ b/python/jittor/extern/acl/acl_jittor.cc
@@ -16,6 +16,7 @@ namespace jittor {
uint64_t acl_jittor_tid;
int acl_jittor_thread_running=0;
aclrtContext acl_jittor_context;
+aclrtStream aclstream;
#define CHECK_ACL(x) ASSERTop(x,==,0)
@@ -28,7 +29,7 @@ static void* acl_jittor_process_callback(void*) {
// LOGir << "acl_jittor_process_callback";
auto ret = aclrtProcessReport(1000);
if (ret) {
- if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT)
+ if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE)
LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
break;
}
@@ -59,10 +60,9 @@ acl_jittor_initer() {
CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0));
// simple callback test
- // aclrtStream stream;
- // CHECK_ACL(aclrtCreateStream(&stream));
- // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream));
- // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream));
+ CHECK_ACL(aclrtCreateStream(&aclstream));
+ // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,aclstream));
+ // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, aclstream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0));
}
@@ -87,7 +87,7 @@ string process_acl(const string& src, const string& name, const map fake_class = {
"cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t",
@@ -117,6 +117,7 @@ string process_acl(const string& src, const string& name, const map
+void printDeviceData(const vector<aclTensorDesc*>& output_desc, const vector<aclDataBuffer*>& output_data, const string& name = "", bool input=true) {
+ LOGir << "name: " << name;
+ if(input)
+ LOGir << "is input";
+ else
+        LOGir << "is output";
+ for (size_t i = 0; i < output_desc.size(); ++i) {
+ void* base_addr = aclGetDataBufferAddr(output_data[i]);
+ LOGir << "addr of data[" << i << "] :" << base_addr;
+ size_t num_dims = aclGetTensorDescNumDims(output_desc[i]);
+ size_t total_size = 1;
+        std::vector<int64_t> dims(num_dims);
+
+ std::cout << "shape of data: ";
+ for (size_t j = 0; j < num_dims; ++j) {
+ aclGetTensorDescDimV2(output_desc[i], j, &dims[j]);
+ total_size *= dims[j];
+ std::cout << dims[j] << ", ";
+ }
+ int every_batch_size = total_size/dims[0];
+ std::cout << std::endl;
+
+ // for(int i= 0; i < dims[0]; i++) {
+ // every_batch_size = 16;
+ // std::vector<float> host_buffer(every_batch_size);
+ // void* offset_addr = static_cast(base_addr) + i * every_batch_size * sizeof(float);
+ // aclrtMemcpy(host_buffer.data(), every_batch_size * sizeof(float), offset_addr, every_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+ // std::cout << "batch[" << i << "]:";
+ // for (size_t k = 0; k < every_batch_size; ++k) {
+ // std::cout << host_buffer[k] << ", ";
+ // }
+ // std::cout << std::endl;
+ // if(i >= 3)
+ // break;
+ // }
+ }
+}
+
struct AclOpRunner {
string name;
vector input_desc;
@@ -40,6 +80,7 @@ struct AclOpRunner {
vector output_data;
aclopAttr *attr;
vector> input_host;
+ vector<vector<int>> input_host_32;
AclOpRunner(const string& name) : name(name) {
attr = aclopCreateAttr();
@@ -56,9 +97,14 @@ struct AclOpRunner {
aclDataType get_dtype(NanoString s) {
if (s == ns_float32) return ACL_FLOAT;
if (s == ns_float16) return ACL_FLOAT16;
+ if (s == ns_int64) return ACL_INT64;
if (s == ns_int32) return ACL_INT32;
if (s == ns_int8) return ACL_INT8;
+ if (s == ns_int16) return ACL_INT16;
if (s == ns_uint8) return ACL_UINT8;
+ if (s == ns_uint16) return ACL_UINT16;
+ if (s == ns_uint32) return ACL_UINT32;
+ if (s == ns_bool) return ACL_BOOL;
LOGf << "Not supported dtype: " << s;
return ACL_FLOAT;
}
@@ -138,7 +184,7 @@ struct AclOpRunner {
auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(int));
input_desc.push_back(desc);
input_data.push_back(data);
- // input_host.emplace_back(move(v));
+ input_host_32.emplace_back(move(v));
}
void set_attr(const string& key, bool value) {
@@ -164,18 +210,22 @@ struct AclOpRunner {
}
void run() {
+ // printDeviceData(input_desc, input_data, name);
+
LOGv << "run" << name << input_desc.size() << output_desc.size();
if (!PyGILState_Check()) {
- ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL));
+ ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream));
} else {
int ret;
Py_BEGIN_ALLOW_THREADS
- ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL);
+ ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream);
Py_END_ALLOW_THREADS
if (ret != 0)
LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret;
}
- // ASSERT(0==aclrtSynchronizeDevice());
+ ASSERT(0==aclrtSynchronizeDevice());
+
+ // printDeviceData(output_desc, output_data, name, false);
}
};
@@ -318,6 +368,17 @@ void try_exec_and_fallback_cpu(Op* op) {
auto iter = opname_map.find(bop->ns);
ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
op.name = iter->second;
+ if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
+ {
+ // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
+ if (bop->ns == ns_bitwise_or) {
+ op.name = "LogicalOr";
+ } else if (bop->ns == ns_bitwise_and) {
+ op.name = "LogicalAnd";
+ } else if (bop->ns == ns_bitwise_xor) {
+ op.name = "LogicalXor";
+ }
+ }
op.run();
} else
if (op->name() == string("ternary")) {
@@ -345,7 +406,7 @@ void try_exec_and_fallback_cpu(Op* op) {
else if (rop->ns == ns_minimum)
op.name = "ReduceMin";
else if (rop->ns == ns_mean)
- op.name = "Reduce";
+ op.name = "ReduceMean";
else
LOGf << "op " << rop->ns << " not supported";
op.add(rop->x, true);
@@ -381,6 +442,19 @@ void try_exec_and_fallback_cpu(Op* op) {
op.add_input_host_nv(zshape, ACL_INT64);
op.add(bop->z, false);
op.run();
+ }
+ else
+ if (op->name() == string("fuse_transpose")) {
+ // replace fuse_transpose with transpose
+ auto top = (TransposeOp*)op;
+ AclOpRunner op("Transpose");
+ op.add(top->x, true);
+ op.add(top->y, false);
+ vector<int64> axes;
+ for (int i=0; i<top->axes.size(); i++)
+ axes.push_back(top->axes[i]);
+ op.add_input_host(axes, ACL_INT64);
+ op.run();
} else
{
LOGf << "op " << op->name() << " not supported";
@@ -388,6 +462,7 @@ void try_exec_and_fallback_cpu(Op* op) {
}
} catch (std::exception& e) {
fallback = 1;
+ LOGir << "fallback cpu" << e.what();
}
for (auto v : new_alloced) {
free_var_mem(v);
@@ -401,7 +476,7 @@ extern int current_seed;
extern int64 current_offset;
static unordered_map<string, std::function<void(Op*)>> acl_ops = {
-{"curand_random", [&](Op* op) {
+{"curand_random", [¤t_seed, ¤t_offset](Op* op) {
auto _op = (RandomOp*)op;
AclOpRunner runner(_op->type == ns_uniform ? "StatelessRandomUniformV2" : "StatelessRandomNormalV2");
auto out = op->output(0);
@@ -429,7 +504,21 @@ static unordered_map> acl_ops = {
runner.set_attr("transpose_x2", _op->trans_b);
runner.run();
}},
-{"cudnn_conv", [&](Op* op) {
+{"cublas_batched_matmul", [&](Op* op) {
+ struct BatchedMatmulOp : Op {
+ Var* a, *b, *c;
+ bool adj_x1, adj_x2;
+ };
+ auto _op = (BatchedMatmulOp*)op;
+ AclOpRunner runner("BatchMatMul");
+ runner.add(_op->a, true);
+ runner.add(_op->b, true);
+ runner.add(_op->c, false);
+ runner.set_attr("adj_x1", _op->adj_x1);
+ runner.set_attr("adj_x2", _op->adj_x2);
+ runner.run();
+}},
+{"cudnn_conv", [](Op* op) {
struct ConvOp : Op {
Var* x, * w, * y;
int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -451,7 +540,7 @@ static unordered_map> acl_ops = {
auto _op = (ConvOp*)op;
_op->run_acl();
}},
-{"cudnn_conv_backward_x", [&](Op* op) {
+{"cudnn_conv_backward_x", [](Op* op) {
struct ConvBackwardXOp : Op {
Var* w, * dy, * dx;
int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -480,7 +569,7 @@ static unordered_map> acl_ops = {
auto _op = (ConvBackwardXOp*)op;
_op->run_acl();
}},
-{"cudnn_conv_backward_w", [&](Op* op) {
+{"cudnn_conv_backward_w", [](Op* op) {
struct ConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -488,7 +577,7 @@ static unordered_map> acl_ops = {
void run_acl() {
AclOpRunner runner("Conv2DBackpropFilter");
runner.add(x, true, ACL_FORMAT_NCHW);
- runner.add_input_host_nv32(x->shape);
+ runner.add_input_host_nv32(dw->shape);
runner.add(dy, true, ACL_FORMAT_NCHW);
runner.add(dw, false, ACL_FORMAT_NCHW);
runner.set_attr("strides", vector{1,1,strideh,stridew});
@@ -538,7 +627,7 @@ static jit_op_entry_t acl_do_compile(Op* op) {
return oc.compile(op->get_jit_key(get_jk()), *src);
}
if (op->name() == string("fused")) {
- FusedOp* fop = (FusedOp*)op;
+ FusedOp* fop = (FusedOp*)op;
// if is a relayed op
if (fop->context->vrm.relay_groups.size()) {
LOGv << "relay fused op";
@@ -546,7 +635,17 @@ static jit_op_entry_t acl_do_compile(Op* op) {
} else {
return &try_exec_and_fallback_cpu;
}
- } else {
+ } else
+ if (op->name() == string("code")) {
+ CodeOp* cop = (CodeOp*)op;
+ if (cop->cuda_src.find("acl") != string::npos) {
+ LOGv << "compile acl op";
+ return oc.compile(op->get_jit_key(get_jk()), *src);
+ } else {
+ return &exec_mapped_acl_ops;
+ }
+ } else
+ {
LOGv << "compile finish" << op;
return &exec_mapped_acl_ops;
}
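With the mapping above, boolean bitwise ops run on ACL as the Logical* operators, and any op that still fails inside try_exec_and_fallback_cpu is re-run on the CPU, with the reason now logged via LOGir. A rough Python-level illustration (a sketch, assuming an ACL build and that jittor's |, &, ^ operators map to bitwise_or/and/xor as usual):

    import jittor as jt

    a = jt.array([True, False, True])
    b = jt.array([False, False, True])
    c = a | b   # bool-bool bitwise_or, dispatched as ACL "LogicalOr"
    d = a & b   # dispatched as "LogicalAnd"
    e = a ^ b   # dispatched as "LogicalXor"
    # Ops without an ACL mapping raise inside try_exec_and_fallback_cpu
    # and fall back to the CPU implementation, printing "fallback cpu <reason>".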
diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
index be01348f..7667e495 100644
--- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
+++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
@@ -25,7 +25,9 @@ static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
if (dtype == ns_float64) return CUDA_R_64F;
if (dtype == ns_float16) return CUDA_R_16F;
+ #ifndef IS_ROCM
if (dtype == ns_bfloat16) return CUDA_R_16BF;
+ #endif
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}
diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
index 2109a03a..73fb3233 100644
--- a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
+++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
@@ -8,8 +8,9 @@
#include
#include
#include
+#ifndef IS_ROCM
#include
-
+#endif
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
@@ -31,6 +32,8 @@ void set_max_workspace_ratio(float64 ratio);
template <class T> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
+#ifndef IS_ROCM
template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; }
+#endif
} // jittor
diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py
index d601335c..2053d86c 100644
--- a/python/jittor/linalg.py
+++ b/python/jittor/linalg.py
@@ -10,6 +10,214 @@
# ***************************************************************
import jittor as jt
from functools import partial
+from .nn import ComplexNumber
+
+def complex_inv(x:ComplexNumber):
+ r"""
+ calculate the inverse of x.
+ :param x (...,M,M):
+ :return:x^-1 (...,M,M).
+
+ TODO: Faster Implementation; Check backward.
+ """
+ assert isinstance(x, ComplexNumber), "complex_inv is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_inv"
+
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+
+ a = _stack_to_complex(data["inputs"][0])
+ m_a = data["outputs"][0]
+ t_a = np.linalg.inv(a)
+ np.copyto(m_a, _complex_to_stack(t_a))
+
+
+ def backward_code(np, data):
+ def T(x):
+ return np.conj(np.swapaxes(x, -1, -2))
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ _dot = partial(np.einsum, '...ij,...jk->...ik')
+ dout = _stack_to_complex(data["dout"])
+ out = data["outputs"][0]
+ mx = _stack_to_complex(data["f_outputs"][0])
+ t = -_dot(_dot(T(mx), dout), T(mx))
+ np.copyto(out, _complex_to_stack(t))
+
+ lmx = jt.numpy_code(
+ x.value.shape,
+ x.value.dtype,
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+
+ return ComplexNumber(lmx, is_concat_value=True)
+
+def complex_eig(x:ComplexNumber):
+ r"""
+ calculate the eigenvalues and eigenvectors of x.
+ :param x (...,M,M):
+ :return:w, v.
+ w (...,M) : the eigenvalues.
+ v (...,M,M) : normalized eigenvectors.
+ """
+ assert isinstance(x, ComplexNumber), "complex_eig is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_eig"
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ w, v = data["outputs"]
+ tw, tv = np.linalg.eig(a)
+ np.copyto(w, _complex_to_stack(tw))
+ np.copyto(v, _complex_to_stack(tv))
+
+ def backward_code(np, data):
+ raise NotImplementedError
+
+ sw = x.shape[:-2] + x.shape[-1:] + (2,)
+ sv = x.value.shape
+ w, v = jt.numpy_code(
+ [sw, sv],
+ [x.value.dtype, x.value.dtype],
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ return ComplexNumber(w, is_concat_value=True), ComplexNumber(v, is_concat_value=True)
+
+def complex_qr(x):
+ r"""
+ do the QR factorization of x using the formula below:
+ x = QR, where Q is a unitary matrix and R is an upper-triangular matrix.
+ :param x (...,M,M):
+ :return: q, r as the result of the QR factorization. They are both of shape (...,M,M).
+ """
+ assert isinstance(x, ComplexNumber), "linalg_qr is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for linalg_qr"
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ qr = data["outputs"][0]
+ Q, R = np.linalg.qr(a)
+ QR = np.stack([Q, R], axis=0)
+ np.copyto(qr, _complex_to_stack(QR))
+
+ def backward_code(np, data):
+ # reference: https://github.com/tencent-quantum-lab/tensorcircuit/blob/master/tensorcircuit/backends/pytorch_ops.py
+ def H(x):
+ return np.conj(np.swapaxes(x, -1, -2))
+ def _TriangularSolve(x, r):
+ return H(np.linalg.solve(r, H(x)))
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ _dot = partial(np.einsum, '...ij,...jk->...ik')
+ _diag = partial(np.einsum, '...ii->...i')
+
+ dout = data["dout"]
+ out = data["outputs"][0]
+ qr = data["f_outputs"][0]
+ dout = _stack_to_complex(dout)
+ dq, dr = dout[0], dout[1]
+ qr = _stack_to_complex(qr)
+ q, r = qr[0], qr[1]
+
+
+ qdq = _dot(H(q), dq)
+ qdq_ = qdq - H(qdq)
+ rdr = _dot(r, H(dr))
+ rdr_ = rdr - H(rdr)
+ tril = np.tril(qdq_ + rdr_)
+
+ grad_a = _dot(q, dr + _TriangularSolve(tril, r))
+ grad_b = _TriangularSolve(dq - _dot(q, qdq), r)
+ ret = grad_a + grad_b
+
+ m = rdr - H(qdq)
+ eyem = np.zeros_like(m)
+ _diag(eyem)[:] = _diag(m)
+ correction = eyem - np.real(eyem)
+ ret = ret + _TriangularSolve(_dot(q, H(correction)), r)
+
+ ret = _complex_to_stack(ret)
+ np.copyto(out,ret)
+
+ qr = jt.numpy_code(
+ (2,) + x.value.shape,
+ x.value.dtype,
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ q, r = qr[0], qr[1]
+ return ComplexNumber(q, is_concat_value=True), ComplexNumber(r, is_concat_value=True)
+
+def complex_svd(x:ComplexNumber):
+ r'''
+ calculate the Singular Value Decomposition of x. It follows the formula below:
+ x = u s v*
+ only full_matrices == False is supported for now, which means:
+ x's shape (...,M,N)
+ u's shape (...,M,K)
+ s's shape (...,K)
+ v's shape (...,K,N)
+ where K is min(M,N).
+ :param x:
+ :return:u,s,v.
+ '''
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ u, s, v = data["outputs"]
+ #TODO:remove copyto
+ tu, ts, tv = np.linalg.svd(a, full_matrices=0)
+ np.copyto(u, _complex_to_stack(tu))
+ np.copyto(s, _complex_to_stack(ts))
+ np.copyto(v, _complex_to_stack(tv))
+
+ def backward_code(np, data):
+ raise NotImplementedError
+
+ m, n = x.shape[-2:]
+ k = min(m, n)
+ s1 = list(x.shape)
+ s1[-1] = k
+ s2 = list(x.shape)
+ s2[-2] = k
+ s3 = list(x.shape)[:-2]
+ s3.append(k)
+ s1.append(2)
+ s2.append(2)
+ s3.append(2)
+ u, s, v = jt.numpy_code(
+ [s1, s3, s2],
+ [x.value.dtype, x.value.dtype, x.value.dtype],
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ return ComplexNumber(u, is_concat_value=True), \
+ ComplexNumber(s, is_concat_value=True), \
+ ComplexNumber(v, is_concat_value=True)
#TODO:full_matrices=1
def svd(x):
@@ -25,6 +233,8 @@ def svd(x):
:param x:
:return:u,s,v.
'''
+ if isinstance(x, ComplexNumber):
+ return complex_svd(x)
def forward_code(np, data):
a = data["inputs"][0]
u, s, v = data["outputs"]
@@ -92,6 +302,17 @@ def T(x):
)
return u, s, v
+def eig(x):
+ r"""
+ calculate the eigenvalues and eigenvectors of x.
+ :param x (...,M,M):
+ :return (ComplexNumber):w, v.
+ w (...,M) : the eigenvalues.
+ v (...,M,M) : normalized eigenvectors.
+ """
+ if isinstance(x, ComplexNumber):
+ return complex_eig(x)
+ return complex_eig(ComplexNumber(x))
def eigh(x):
r"""
@@ -147,6 +368,8 @@ def inv(x):
:param x (...,M,M):
:return:x^-1 (...,M,M).
"""
+ if isinstance(x, ComplexNumber):
+ return complex_inv(x)
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
@@ -387,6 +610,8 @@ def qr(x):
:param x (...,M,M):
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
"""
+ if isinstance(x, ComplexNumber):
+ return complex_qr(x)
def forward_code(np, data):
a = data["inputs"][0]
q, r = data["outputs"]
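After this change the jt.linalg entry points dispatch ComplexNumber inputs to the complex_* helpers above (eig also wraps a real Var into a ComplexNumber first). A usage sketch, assuming the two-argument ComplexNumber(real, imag) constructor from jittor.nn:

    import jittor as jt
    from jittor.nn import ComplexNumber

    a = ComplexNumber(jt.randn(4, 4), jt.randn(4, 4))

    a_inv = jt.linalg.inv(a)       # routed to complex_inv
    q, r  = jt.linalg.qr(a)        # routed to complex_qr (backward supported)
    w, v  = jt.linalg.eig(a)       # eigenvalues / eigenvectors (forward only)
    u, s, v2 = jt.linalg.svd(a)    # reduced complex SVD (forward only)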
diff --git a/python/jittor/misc.py b/python/jittor/misc.py
index fc528d62..5357baf2 100644
--- a/python/jittor/misc.py
+++ b/python/jittor/misc.py
@@ -294,9 +294,10 @@ def expand(x, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)):
shape = shape[0]
shape = list(shape)
- for i in range(len(shape)):
- if shape[i] == -1:
- shape[i] = x.shape[i]
+ offset = len(shape) - len(x.shape)
+ for i in range(len(x.shape)):
+ if shape[offset + i] == -1:
+ shape[offset + i] = x.shape[i]
return x.broadcast(shape)
jt.Var.expand = expand
@@ -2010,6 +2011,7 @@ def contiguous(x): return x.clone()
def cpu(x): return x.clone()
jt.Var.cpu = cpu
def to(x, *args, **kargs):
+ args += tuple(kargs.values())
if len(args) >= 1:
s = args[0]
if isinstance(s, jt.NanoString) or callable(s):
@@ -2234,3 +2236,16 @@ def cuda(x):
def expm1(x):
return jt.exp(x) - 1
+
+
+def isin(elements, test_elements, assume_unique=False, invert=False):
+
+ elements = elements.unsqueeze(-1)
+ test_elements = test_elements.unsqueeze(0)
+ comparison = elements == test_elements
+ result = comparison.any(dim=-1)
+
+ if invert:
+ result = jt.logical_not(result)
+
+ return result
\ No newline at end of file
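Two of the Python-level changes above are easiest to see from user code: expand() now resolves -1 against the source tensor even when extra leading dimensions are added, and isin() is a new broadcasting-based membership test. A small sketch (assuming isin is reachable via jittor.misc; it may also be re-exported on the jt namespace by the existing "from .misc import *"):

    import jittor as jt
    from jittor.misc import isin

    x = jt.randn(3)
    y = x.expand(2, -1)                  # -1 maps to x.shape[0]; result shape (2, 3)

    elems = jt.array([1, 2, 3, 4])
    tests = jt.array([2, 4])
    mask = isin(elems, tests)            # [False, True, False, True]
    mask_inv = isin(elems, tests, invert=True)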
diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index a8190c79..22b934c9 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -22,6 +22,7 @@
from jittor.optim import *
from jittor.misc import _pair, _triple
from jittor_utils import LOG
+from functools import partial
def matmul_transpose(a, b):
@@ -363,10 +364,14 @@ def __init__(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
self.weight = init.constant((num_parameters,), "float32", init_)
+
def execute(self, x):
if self.num_parameters != 1:
- assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
- return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
+ assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
+ # Adjust broadcasting logic to ensure it matches the input dimensions
+ shape = [x.shape[0], self.num_parameters] + [1] * (len(x.shape) - 2)
+ weight_broadcasted = self.weight.broadcast(shape)
+ return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
else:
return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
@@ -925,6 +930,42 @@ class Conv(Module):
>>> output = conv(input)
'''
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ if in_channels <= 0:
+ raise ValueError(f"in_channels must be greater than zero, got {in_channels}")
+ if out_channels <= 0:
+ raise ValueError(f"out_channels must be greater than zero, got {out_channels}")
+ if groups <= 0:
+ raise ValueError(f"groups must must be greater than zero, got {groups}")
+ assert in_channels % groups == 0, 'in_channels must be divisible by groups'
+ assert out_channels % groups == 0, 'out_channels must be divisible by groups'
+ if isinstance(kernel_size, tuple):
+ for size in kernel_size:
+ if size <= 0:
+ raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}")
+ else:
+ if kernel_size <= 0:
+ raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}")
+ if isinstance(stride, tuple):
+ for size in stride:
+ if size <= 0:
+ raise ValueError(f"stride must be greater than zero, got {stride}")
+ else:
+ if stride <= 0:
+ raise ValueError(f"stride must be greater than zero, got {stride}")
+ if isinstance(padding, tuple):
+ for size in padding:
+ if size < 0:
+ raise ValueError(f"padding must be nonnegative, got {padding}")
+ else:
+ if padding < 0:
+ raise ValueError(f"padding must be nonnegative, got {padding}")
+ if isinstance(dilation, tuple):
+ for size in dilation:
+ if size <= 0:
+ raise ValueError(f"dilation must be greater than zero, got {dilation}")
+ else:
+ if dilation <= 0:
+ raise ValueError(f"dilation must be greater than zero, got {dilation}")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
@@ -935,8 +976,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
- assert in_channels % groups == 0, 'in_channels must be divisible by groups'
- assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw = self.kernel_size
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
@@ -1061,6 +1100,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.dilation = (dilation, 1)
self.groups = groups
self.bias = bias
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
# using list to escape module dfs
@@ -1121,6 +1162,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
self.groups = groups
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw, Kd = self.kernel_size
@@ -1187,7 +1230,8 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
-
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
if groups == 1:
N,C,H,W = x.shape
Kh, Kw = weight.shape[-2:]
@@ -1276,7 +1320,8 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
stride = _triple(stride)
dilation = _triple(dilation)
out_channels = weight.shape[0]
-
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
if jt.flags.use_cuda and jt.cudnn:
y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
elif groups == 1:
@@ -1352,6 +1397,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
+ assert self.padding[0] >= 0 and self.padding[1] >= 0, "padding must be non-negative"
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
@@ -1474,6 +1520,8 @@ def execute(self, x):
return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if stride <= 0:
+ raise RuntimeError("non-positive stride is not supported")
if groups == 1:
x = input
N,C,H,W = x.shape
@@ -1564,6 +1612,8 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_paddi
x = input
N,C,D,H,W = x.shape
i,o,d,h,w = weight.shape
+ if stride <= 0:
+ raise RuntimeError("non-positive stride is not supported")
assert C==i
assert groups==1, "Group conv not supported yet."
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
@@ -1631,6 +1681,8 @@ def pad(x,padding, mode='constant', value=0):
class ReflectionPad2d(Module):
def __init__(self, padding):
+ if isinstance(padding, int) and padding < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
@@ -1641,6 +1693,8 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
n,c,h,w = x.shape
@@ -1669,6 +1723,8 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
n,c,h,w = x.shape
@@ -1687,6 +1743,8 @@ def __init__(self, padding, value):
else:
raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}")
self.value = value
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
assert len(x.shape) >= 2
@@ -1701,6 +1759,8 @@ def execute(self, x):
class ReplicationPad2d(Module):
def __init__(self, padding):
+ if isinstance(padding, int) and padding < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
@@ -1711,6 +1771,8 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
n,c,h,w = x.shape
@@ -1760,13 +1822,14 @@ def embedding(input, weight):
class PixelShuffle(Module):
def __init__(self, upscale_factor):
+ assert upscale_factor > 0,f"upscale_factor must be greater than zero,got {upscale_factor}"
self.upscale_factor = upscale_factor
def execute(self, x):
n,c,h,w = x.shape
r = self.upscale_factor
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
- if r<0:
+ if r<=0:
raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}")
return x.reindex([n,int(c/r**2),h*r,w*r], [
"i0",
@@ -1815,6 +1878,15 @@ def execute(self, x):
class Resize(Module):
def __init__(self, size, mode="nearest", align_corners=False):
super().__init__()
+ if isinstance(size,int):
+ if size <= 0:
+ raise ValueError(f"sizes must be positive, got {size}")
+ elif isinstance(size,tuple) or isinstance(size,list):
+ for item in size:
+ if item <= 0:
+ raise ValueError(f"sizes must be positive, got {item}")
+ else:
+ raise ValueError(f"size must be int or tuple")
self.size = size
self.mode = mode
self.align_corners = align_corners
@@ -2160,15 +2232,21 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner
class Upsample(Module):
def __init__(self, scale_factor=None, mode='nearest'):
- self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = (float(scale_factor), float(scale_factor)) if scale_factor else None
self.mode = mode
def execute(self, x):
- return upsample(x,
- size=(
- int(x.shape[2]*self.scale_factor[0]),
- int(x.shape[3]*self.scale_factor[1])),
- mode=self.mode)
+ if self.scale_factor is None:
+ raise ValueError("scale_factor should be defined")
+ else:
+ return upsample(x,
+ size=(
+ int(x.shape[2]*self.scale_factor[0]),
+ int(x.shape[3]*self.scale_factor[1])),
+ mode=self.mode)
class UpsamplingBilinear2d(Upsample):
def __init__(self, scale_factor=None):
@@ -2234,6 +2312,15 @@ def __len__(self):
def named_children(self,):
return list(self.layers.items())
+
+ def __setattr__(self, key, value) -> None:
+ if isinstance(key, str) and key.isdigit():
+ if int(key) 0 and output_size[1] > 0, "output size must be positive."
if not isinstance(kernel_size,tuple):
kernel_size = (kernel_size,kernel_size)
+ assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive"
if not isinstance(dilation,tuple):
dilation = (dilation,dilation)
+ assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive"
if not isinstance(padding,tuple):
padding = (padding,padding)
+ assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative"
if not isinstance(stride,tuple):
stride = (stride,stride)
+ assert stride[0] > 0 and stride[1] > 0, "stride must be positive"
n,cl,num = X.shape
area = kernel_size[0] * kernel_size[1]
block_nums = []
@@ -3049,6 +3141,12 @@ def permute(self, *axes):
def transpose(self, *axes):
return ComplexNumber(jt.transpose(self.real, *axes), jt.transpose(self.imag, *axes))
+ def broadcast(self, shape, dims):
+ return ComplexNumber(self.real.broadcast(shape, dims), self.imag.broadcast(shape, dims))
+
+ def sum(self, dims, keepdims: bool=False):
+ return ComplexNumber(self.real.sum(dims, keepdims=keepdims), self.imag.sum(dims, keepdims=keepdims))
+
def exp(self):
er = jt.exp(self.real)
return ComplexNumber(er * jt.cos(self.imag), er * jt.sin(self.imag))
@@ -3059,7 +3157,7 @@ def conj(self):
def __add__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real + other.real, self.imag + other.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(self.real + other, self.imag)
else:
raise NotImplementedError
@@ -3067,7 +3165,7 @@ def __add__(self, other):
def __radd__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real + self.real, other.imag + self.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other + self.real, self.imag)
else:
raise NotImplementedError
@@ -3075,7 +3173,7 @@ def __radd__(self, other):
def __sub__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real - other.real, self.imag - other.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(self.real - other, self.imag)
else:
raise NotImplementedError
@@ -3083,7 +3181,7 @@ def __sub__(self, other):
def __rsub__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real - self.real, other.imag - self.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other - self.real, self.imag)
else:
raise NotImplementedError
@@ -3094,8 +3192,6 @@ def __mul__(self, other):
self.real * other.imag + self.imag * other.real)
elif isinstance(other, (int, float)):
return ComplexNumber(self.value * other, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real * other, self.imag * other)
else:
raise NotImplementedError
@@ -3105,8 +3201,6 @@ def __rmul__(self, other):
other.imag * self.real + other.real * self.imag)
elif isinstance(other, (int, float)):
return ComplexNumber(other * self.value, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(other * self.real, other * self.imag)
else:
raise NotImplementedError
@@ -3117,8 +3211,6 @@ def __truediv__(self, other):
(self.imag * other.real - self.real * other.imag) / norm)
elif isinstance(other, (int, float)):
return ComplexNumber(self.value / other, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real / other, self.imag / other)
else:
raise NotImplementedError
@@ -3127,7 +3219,7 @@ def __rtruediv__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber((other.real * self.real + other.imag * self.imag) / norm,
(other.imag * self.real - other.real * self.imag) / norm)
- elif isinstance(other, (int, float, jt.Var)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other * self.real / norm, - other * self.imag / norm)
else:
raise NotImplementedError
@@ -3136,8 +3228,6 @@ def __matmul__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real @ other.real - self.imag @ other.imag,
self.real @ other.imag + self.imag @ other.real)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real @ other, self.imag @ other)
else:
raise NotImplementedError
@@ -3145,8 +3235,6 @@ def __imatmul__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real @ self.real - other.imag @ self.imag,
other.imag @ self.real + other.real @ self.imag)
- elif isinstance(other, jt.Var):
- return ComplexNumber(other @ self.real, other @ self.imag)
else:
raise NotImplementedError
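Two of the nn.py behaviour changes above, in user terms: PReLU now broadcasts its per-channel weight against inputs of any rank >= 2 (not only NCHW), and Upsample normalizes scale_factor to floats and rejects a missing value at call time. A minimal sketch:

    import jittor as jt
    from jittor import nn

    prelu = nn.PReLU(num_parameters=3)
    y = prelu(jt.randn(2, 3, 8, 8, 8))   # 5-D input; weight broadcast to (2, 3, 1, 1, 1)

    up = nn.Upsample(scale_factor=2, mode='nearest')
    z = up(jt.randn(1, 3, 16, 16))       # -> (1, 3, 32, 32)
    # nn.Upsample()(z) would raise ValueError("scale_factor should be defined")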
diff --git a/python/jittor/pool.py b/python/jittor/pool.py
index f9ad9a68..7e9e808a 100644
--- a/python/jittor/pool.py
+++ b/python/jittor/pool.py
@@ -29,9 +29,20 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
+ for item in self.kernel_size:
+ if item <= 0:
+ raise RuntimeError(f"kernel_size must be greater than zero, but got {item}")
+ for item in self.stride:
+ if item <= 0:
+ raise RuntimeError(f"stride must be greater than zero, but got {item}")
+ for item in self.padding:
+ if item < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {item}")
def execute(self, x):
N,C,H,W = x.shape
+ if H <= self.kernel_size[0] or W <= self.kernel_size[1]:
+ raise RuntimeError(f"size of var should be larger than kernel_size")
if self.ceil_mode == False:
h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
@@ -205,9 +216,17 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
self.padding = _triple(padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
+ if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0:
+ raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
+ if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0:
+ raise RuntimeError(f"stride must be greater than zero, but got {stride}")
+ if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
def execute(self, x):
N,C,D,H,W = x.shape
+ if D <= self.kernel_size[0] or H <= self.kernel_size[1] or W <= self.kernel_size[2]:
+ raise RuntimeError(f"size of var should be larger than kernel_size")
if self.ceil_mode == False:
d = (D+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
h = (H+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
@@ -506,7 +525,7 @@ def execute(self, x):
f"i3*{self.sh}+i6", # Hid
f"i4*{self.sw}+i7", # Wid
])
- return xx.reduce("maximun", [5,6,7])
+ return xx.reduce("maximum", [5,6,7])
def pool(x, kernel_size, op, padding=0, stride=None):
return Pool(kernel_size, stride, padding, op=op)(x)
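The pooling changes add up-front argument validation and fix a reduce op name typo ("maximun" to "maximum"). A minimal sketch of the new error behaviour:

    import jittor as jt
    from jittor import nn

    pool = nn.Pool(kernel_size=2, stride=2)
    y = pool(jt.randn(1, 3, 8, 8))       # -> (1, 3, 4, 4)

    try:
        nn.Pool(kernel_size=0)           # rejected by the new checks
    except RuntimeError as e:
        print(e)                         # kernel_size must be greater than zero, but got 0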
diff --git a/python/jittor/src/misc/cuda_atomic.h b/python/jittor/src/misc/cuda_atomic.h
index 6348befb..15b39412 100644
--- a/python/jittor/src/misc/cuda_atomic.h
+++ b/python/jittor/src/misc/cuda_atomic.h
@@ -6,7 +6,9 @@
// ***************************************************************
#pragma once
#include
+#ifndef IS_ROCM
#include
+#endif
#include "common.h"
namespace jittor {
diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc
index bbec4ccb..00fb34aa 100644
--- a/python/jittor/src/misc/nan_checker.cc
+++ b/python/jittor/src/misc/nan_checker.cc
@@ -11,7 +11,9 @@
#include "misc/cuda_flags.h"
#include
#include
+#ifndef IS_ROCM
#include
+#endif
#include "helper_cuda.h"
#endif
#include "mem/allocator.h"
@@ -22,7 +24,9 @@ namespace jittor {
#ifdef IS_CUDA
EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num);
+#ifndef IS_ROCM
EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num);
+#endif
EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num);
EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num);
#endif
@@ -33,7 +37,9 @@ void dump_var(Var* v, string name) {
name = ss.str();
LOGe << "dump" << v << "to" << name;
char* buffer = new char[v->size];
- #ifdef HAS_CUDA
+ #ifdef IS_ROCM
+ hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault);
+ #elif IS_CUDA
cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault);
#else
std::memcpy(buffer, v->mem_ptr, v->size);
@@ -57,9 +63,11 @@ bool check_nan(Var* v, Op* op) {
if (v->dtype() == ns_float16) {
nan_index = check_nan_float16((__half*)v->mem_ptr, v->num);
}
+ #ifndef IS_ROCM
if (v->dtype() == ns_bfloat16) {
nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num);
}
+ #endif
if (v->dtype() == ns_float32) {
nan_index = check_nan_float32((float32*)v->mem_ptr, v->num);
} else
@@ -104,14 +112,16 @@ bool check_nan(Var* v, Op* op) {
auto* ptr = input->ptr<__half>();
__half value;
cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost);
- LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
+ // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
+ #ifndef IS_ROCM
if (input->dtype() == ns_bfloat16) {
auto* ptr = input->ptr<__nv_bfloat16>();
__nv_bfloat16 value;
cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
+ #endif
if (input->dtype() == ns_float32) {
auto* ptr = input->ptr();
float32 value;
diff --git a/python/jittor/src/misc/nan_checker.cu b/python/jittor/src/misc/nan_checker.cu
index ec7998ba..b2e631f1 100644
--- a/python/jittor/src/misc/nan_checker.cu
+++ b/python/jittor/src/misc/nan_checker.cu
@@ -7,9 +7,13 @@
#include "misc/cuda_flags.h"
#include
#include
-#include
+
#include "helper_cuda.h"
#include
+//TODO:FIX in ROCM
+#ifndef IS_ROCM
+#include
+#endif
namespace jittor {
@@ -37,7 +41,7 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt
print_nan(float(ptr[i]), i, cnt);
}
}
-
+#ifndef IS_ROCM
__global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i check_nan_float16(__half* ptr, int64 num) {
_check_nan_float16<<<block_num, thread_num>>>(ptr, num, check_nan_get_device_ptr());
return report_nan();
}
-
+#ifndef IS_ROCM
vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) {
int block_num = std::max((int64)1, (num-1)/1024+1);
int thread_num = std::min((int64)1024, num);
_check_nan_bfloat16<<<block_num, thread_num>>>(ptr, num, check_nan_get_device_ptr());
return report_nan();
}
-
+#endif
#endif
}
\ No newline at end of file
diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc
index 01a8ef2d..848d40a4 100644
--- a/python/jittor/src/ops/binary_op.cc
+++ b/python/jittor/src/ops/binary_op.cc
@@ -440,6 +440,17 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
return;
}
+ #ifdef IS_ACL
+ if (x->dtype() != y->dtype()) {
+ auto dtype = binary_dtype_infer(ns_add, x->ns, y->ns, 0, 0);
+ auto xp = make_unary(x, dtype);
+ auto yp = make_unary(y, dtype);
+ auto zp = make_binary(xp, yp, op);
+ forward(zp);
+ return;
+ }
+ #endif
+
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::element);
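The binary_op.cc change makes mixed-dtype binary ops on the ACL backend first cast both operands to the promoted dtype (via binary_dtype_infer) and then run the op, matching what users already see on CPU/CUDA. Illustration of the promotion being targeted (a sketch; the cast insertion itself only happens under IS_ACL):

    import jittor as jt

    a = jt.array([1, 2, 3])          # int32
    b = jt.array([0.5, 1.5, 2.5])    # float32
    c = a + b                        # promoted to float32; on ACL both inputs are cast first
    print(c.dtype)                   # float32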
diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h
index 56f70789..f5c85d3b 100644
--- a/python/jittor/src/type/fp16_compute.h
+++ b/python/jittor/src/type/fp16_compute.h
@@ -11,12 +11,17 @@
#include
#include
+#ifndef IS_ROCM
#include
+#endif
namespace jittor {
typedef __half float16;
+#ifndef IS_ROCM
typedef __nv_bfloat16 bfloat16;
+#endif
+
#if CUDA_ARCH >= 800
inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); }
@@ -32,7 +37,7 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)= 800
inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); }
inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); }
@@ -45,7 +50,7 @@ inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a)
__device__ inline
typename std::enable_if::type
diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc
index 092eddd2..ed4d4c0a 100644
--- a/python/jittor/src/var_holder.cc
+++ b/python/jittor/src/var_holder.cc
@@ -215,11 +215,14 @@ inline static void cast_item_data(ItemData& data) {
auto* fp16 = (float16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(fp16[0]);
- } else if (data.dtype == ns_bfloat16) {
+ }
+ #ifndef IS_ROCM
+ else if (data.dtype == ns_bfloat16) {
auto* bf16 = (bfloat16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(bf16[0]);
}
+ #endif
data.dtype = ns_float32;
}
diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py
new file mode 100644
index 00000000..19686009
--- /dev/null
+++ b/python/jittor/test/test_complex.py
@@ -0,0 +1,200 @@
+import jittor as jt
+from jittor.nn import ComplexNumber
+import unittest
+import numpy as np
+
+__skip_torch_test = False
+try:
+ import torch
+except:
+ __skip_torch_test = True
+
+class TestResultAndGrad:
+ def check_results(self, rlist1, rlist2):
+ assert len(rlist1) == len(rlist2)
+ for r1, r2 in zip(rlist1, rlist2):
+ assert r1.shape == r2.shape
+ assert np.allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+ def grad_jittor(self, inputs, losses):
+ grads = []
+ for i in inputs:
+ for loss in losses:
+ if isinstance(i, ComplexNumber):
+ g = jt.grad(loss, i.value, retain_graph=True)
+ grads.append(g[..., 0].numpy() + 1j * g[..., 1].numpy())
+ else:
+ g = jt.grad(loss, i, retain_graph=True)
+ grads.append(g.numpy())
+ return grads
+
+ def grad_torch(self, inputs, losses):
+ grads = []
+ for i in inputs:
+ for loss in losses:
+ g = torch.autograd.grad(loss, i, retain_graph=True)[0]
+ grads.append(g.detach().cpu().numpy())
+ return grads
+
+ def run_jittor_op(self, op, input_list, weights=None):
+ def _np_to_jittor(x:np.ndarray):
+ if x.dtype == np.complex64 or x.dtype == np.complex128:
+ nx = np.stack([np.real(x), np.imag(x)], axis=-1)
+ return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
+ elif x.dtype == np.float32 or x.dtype == np.float64:
+ return jt.array(x, dtype=jt.float32)
+ else:
+ assert False
+ def _jittor_to_np(x):
+ if isinstance(x, jt.Var):
+ return x.numpy()
+ elif isinstance(x, ComplexNumber):
+ return x.real.numpy() + 1j * x.imag.numpy()
+ assert False
+
+ ninput_list = [_np_to_jittor(x) for x in input_list]
+ output_list = op(*ninput_list)
+ if isinstance(output_list, (jt.Var, ComplexNumber)):
+ output_list = [output_list]
+ losses = []
+ if weights is None:
+ weights = []
+ for o in output_list:
+ no = o.value if isinstance(o, ComplexNumber) else o
+ w = np.random.randn(*no.shape)
+ weights.append(w)
+ losses.append(jt.sum(no * jt.array(w)))
+ else:
+ assert len(output_list) == len(weights)
+ for o, w in zip(output_list, weights):
+ no = o.value if isinstance(o, ComplexNumber) else o
+ assert w.shape == no.shape
+ losses.append(jt.sum(no * jt.array(w)))
+ output_list = [_jittor_to_np(x) for x in output_list]
+ return ninput_list, output_list, losses, weights
+
+ def run_torch_op(self, op, input_list, weights=None):
+ def _np_to_torch(x:np.ndarray):
+ return torch.from_numpy(x).requires_grad_(True)
+ def _torch_to_np(x:torch.Tensor) -> np.ndarray:
+ return x.detach().cpu().numpy()
+ ninput_list = [_np_to_torch(x) for x in input_list]
+ output_list = op(*ninput_list)
+ if isinstance(output_list, torch.Tensor):
+ output_list = [output_list]
+ losses = []
+ if weights is None:
+ weights = []
+ for o in output_list:
+ no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o
+ w = np.random.randn(*no.shape)
+ weights.append(w)
+ losses.append(torch.sum(no * torch.from_numpy(w)))
+ else:
+ assert len(output_list) == len(weights)
+ for o, w in zip(output_list, weights):
+ no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o
+ assert w.shape == no.shape
+ losses.append(torch.sum(no * torch.from_numpy(w)))
+ output_list = [_torch_to_np(x) for x in output_list]
+ return ninput_list, output_list, losses, weights
+
+ def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True):
+ weights = None
+ jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights)
+ torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights)
+ self.check_results(jittor_output, torch_output)
+
+ if check_grad:
+ jittor_grads = self.grad_jittor(jittor_input, jittor_losses)
+ torch_grads = self.grad_torch(torch_input, torch_losses)
+ self.check_results(jittor_grads, torch_grads)
+
+ def check_op_with_numpy(self, jittor_op, numpy_op, input_list):
+ _, jittor_output, _, _ = self.run_jittor_op(jittor_op, input_list, None)
+ numpy_output = numpy_op(*input_list)
+ if isinstance(numpy_output, np.ndarray):
+ numpy_output = [numpy_output]
+
+ self.check_results(jittor_output, numpy_output)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
+ def random_complex_matrix(self, shape):
+ r = np.random.randn(*shape)
+ i = np.random.randn(*shape)
+ return r + 1j * i
+
+ def test_complex_matmul(self):
+ s1 = (50, 200)
+ s2 = (200, 50)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.matmul, torch.matmul, inputs)
+
+ def test_complex_matmul_batch(self):
+ s1 = (10, 50, 30)
+ s2 = (10, 30, 40)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.matmul, torch.matmul, inputs)
+
+ def test_complex_inv(self):
+ s1 = (200, 200)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs)
+
+ def test_complex_inv_batch(self):
+ s1 = (10, 50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs)
+
+ def test_complex_eig(self):
+ # Unstable
+ s1 = (20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs)
+
+ def test_complex_eig_batch(self):
+ # Unstable
+ s1 = (5, 10, 10)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs)
+
+ def test_complex_qr(self):
+ s1 = (50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs)
+
+ def test_complex_qr_batch(self):
+ s1 = (10, 20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs)
+
+ def test_complex_svd(self):
+ # Unstable
+ s1 = (50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
+
+ def test_complex_svd_batch(self):
+ # Unstable
+ s1 = (10, 20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/setup.py b/setup.py
index 741165d2..beec50f6 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@
package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']},
# include_package_data=True,
install_requires=[
- "numpy",
+ "numpy<2.0",
"tqdm",
"pillow",
"astunparse",