diff --git a/README.md b/README.md index 414eb868..90601b43 100644 --- a/README.md +++ b/README.md @@ -382,10 +382,10 @@ Email: jittor@qq.com File an issue: https://github.com/Jittor/jittor/issues -QQ Group: 761222083 +QQ Group: 836860279 - + ## The Team diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index cfd39c42..b495d55c 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.9.6' +__version__ = '1.3.9.10' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -428,7 +428,9 @@ def random(shape, dtype="float32", type="uniform"): jt.Var([[0.96788853 0.28334728 0.30482838] [0.46107793 0.62798643 0.03457401]], dtype=float32) ''' - + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") ret = ops.random(shape, "float32", type) ## TODO: move those code to core #if dtype in ["float16", "bfloat16"]: @@ -484,6 +486,9 @@ def ones(*shape, dtype="float32"): shape = shape[:-1] if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): shape = shape[0] + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") return unary(1, dtype).broadcast(shape) def new_ones(x, size): @@ -515,6 +520,9 @@ def zeros(*shape, dtype="float32"): shape = shape[:-1] if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): shape = shape[0] + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") return unary(0, dtype).broadcast(shape) def new_zeros(x, size): @@ -547,6 +555,9 @@ def full(shape,val,dtype="float32"): ''' if not isinstance(shape, (NanoVector, Sequence)): shape = (shape,) + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") return unary(val, dtype).broadcast(shape) def new_full(x, size, val): @@ -687,6 +698,8 @@ def flatten(input, start_dim=0, end_dim=-1): start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function" + if len(in_shape) <= end_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})") out_shape = [] for i in range(0,start_dim,1): out_shape.append(in_shape[i]) dims = 1 @@ -917,6 +930,9 @@ def randn(*size, dtype="float32", requires_grad=True) -> Var: [-0.612632 -1.1471151 -1.1879086 ]], dtype=float32) ''' if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0] + for dim in size: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}") arr = jt.random(size, dtype, "normal") if not requires_grad: return arr.stop_grad() return arr @@ -1013,6 +1029,9 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var: [1 1 1]], dtype=int32) ''' if high is None: low, high = 0, low + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5) v = jt.floor_int(v) return v.astype(dtype) @@ -2152,3 +2171,7 @@ def inplace_wrapper(new_k, 
prev_func): from . import math_util from .math_util import * from . import distributions + +if jt.compiler.has_acl: + from jittor.extern.acl.acl_compiler import change_function + change_function() diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py index 6f88fad5..94d2e40b 100644 --- a/python/jittor/compatibility/__init__.py +++ b/python/jittor/compatibility/__init__.py @@ -1,430 +1,430 @@ -import os -os.environ["FIX_TORCH_ERROR"] = "0" - -import jittor as jt -from jittor import * -from typing import Tuple - -org_int = int = type(1) -org_float = float = type(1.0) -org_bool = bool = type(True) - -import jtorch.compiler - -import jtorch_core -from jtorch_core import * - -device.__reduce__ = lambda self: (device, (self.type,)) -device.__module__ = "jtorch" -jt.jittor_core.device = device - -def handle_dtype(args, kw, dtype): - def convert(x): - if isinstance(x, jt.Var): - return x.cast(dtype) - return x - if dtype is not None: - if args is not None: - if isinstance(args, (tuple,list)): - args = [ convert(a) for a in args ] - else: - args = convert(x) - if kw is not None: - kw = { k:convert(v) for k,v in kw.items() } - return args, kw - -def get_args_names(func): - import inspect - spec = inspect.getfullargspec(func) - return spec[0] + spec[4] - -def wrapper(func): - has_dtype = False - if hasattr(func, "__code__"): - has_dtype = "dtype" in get_args_names(func) - def inner(*args, **kw): - requires_grad = None - dtype = None - if "requires_grad" in kw: - requires_grad = kw["requires_grad"] - del kw["requires_grad"] - if not has_dtype and "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - if "device" in kw: - del kw["device"] - if 'pin_memory' in kw: - del kw['pin_memory'] - args, kw = handle_dtype(args, kw, dtype) - ret = func(*args, **kw) - if isinstance(ret, jt.Var): - if requires_grad is not None: - ret.requires_grad = requires_grad - if dtype is not None: - ret.astype(dtype) - return ret - return inner +# import os +# os.environ["FIX_TORCH_ERROR"] = "0" + +# import jittor as jt +# from jittor import * +# from typing import Tuple + +# org_int = int = type(1) +# org_float = float = type(1.0) +# org_bool = bool = type(True) + +# import jtorch.compiler + +# import jtorch_core +# from jtorch_core import * + +# device.__reduce__ = lambda self: (device, (self.type,)) +# device.__module__ = "jtorch" +# jt.jittor_core.device = device + +# def handle_dtype(args, kw, dtype): +# def convert(x): +# if isinstance(x, jt.Var): +# return x.cast(dtype) +# return x +# if dtype is not None: +# if args is not None: +# if isinstance(args, (tuple,list)): +# args = [ convert(a) for a in args ] +# else: +# args = convert(x) +# if kw is not None: +# kw = { k:convert(v) for k,v in kw.items() } +# return args, kw + +# def get_args_names(func): +# import inspect +# spec = inspect.getfullargspec(func) +# return spec[0] + spec[4] + +# def wrapper(func): +# has_dtype = False +# if hasattr(func, "__code__"): +# has_dtype = "dtype" in get_args_names(func) +# def inner(*args, **kw): +# requires_grad = None +# dtype = None +# if "requires_grad" in kw: +# requires_grad = kw["requires_grad"] +# del kw["requires_grad"] +# if not has_dtype and "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# if "device" in kw: +# del kw["device"] +# if 'pin_memory' in kw: +# del kw['pin_memory'] +# args, kw = handle_dtype(args, kw, dtype) +# ret = func(*args, **kw) +# if isinstance(ret, jt.Var): +# if requires_grad is not None: +# ret.requires_grad = requires_grad +# if dtype is not 
None: +# ret.astype(dtype) +# return ret +# return inner -import inspect -_wrapper_keys = set(["shape", "start", "size"]) -_wrapper_keys.add("x") -for k,v in list(globals().items()): - if callable(v) and not isinstance(v, type): - try: - spec = inspect.getfullargspec(v) - args_name = spec[0] - if len(args_name) and args_name[0] in _wrapper_keys: - globals()[k] = wrapper(v) - elif spec.varargs in _wrapper_keys: - globals()[k] = wrapper(v) - except: - pass - -def empty(*size, dtype=jt.float32, device=None, requires_grad=False): - if len(size) == 1 and not isinstance(size[0], org_int): - size = size[0] - return jt.empty(size, dtype) - -Tensor = Var - -Tensor.backward = lambda x: jtorch_core.backward(x) -Tensor.grad = property(grad_get, grad_set, grad_del) -Tensor.retains_grad = property(retain_grad_get, retain_grad_set) -def retain_grad(x:Tensor, value:bool=True): - x.retains_grad = value - return value -Tensor.retain_grad = retain_grad - -Tensor.dim = lambda self: self.ndim -Tensor.ndimension = lambda self: self.ndim -Tensor.nelement = lambda self: self.numel() -Tensor.cuda = lambda self: self -def device_get(x:Tensor): - return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") -Tensor.device = property(device_get) - -def argmax(x: Var, dim=None, keepdim: bool = False): - return jt.argmax(x, dim, keepdim)[0] -Tensor.argmax = argmax - -def tensor_type(x: Var, dtype=None, **kwargs): - if dtype: - return x.astype(dtype) - else: - return x.dtype -Tensor.type = tensor_type - -def is_floating_point(x: Var): - return "float" in str(x.dtype) -Tensor.is_floating_point = is_floating_point - -from . import autograd -from .autograd import * - -def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): - if isinstance(data,list): - data_list = [] - check = True - for p in data: - if isinstance(p, Tensor) and p.numel()==1: - data_list.append(p.item()) - elif isinstance(p, (org_int,org_float)): - data_list.append(p) - else: - check = False - break - if check: - data = data_list - return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) - -# tensor = wrapper(array) -from_numpy = wrapper(array) -strided = None - -def mod_zero_grad(self): - for p in self.parameters(): - p.grad = None -Module.zero_grad = mod_zero_grad - -class ModuleMisc: - def parameters(self): - return iter(super().parameters()) - - def load_state_dict(self, state_dict, strict=False): - return super().load_state_dict(state_dict) - - def to(self, device=None,dtype=None): - ''' do nothing but return its self''' - return self - def register_parameter(self,name,data): - self.name = data - - def buffers(self): - for _, buf in self.named_buffers(): - yield buf +# import inspect +# _wrapper_keys = set(["shape", "start", "size"]) +# _wrapper_keys.add("x") +# for k,v in list(globals().items()): +# if callable(v) and not isinstance(v, type): +# try: +# spec = inspect.getfullargspec(v) +# args_name = spec[0] +# if len(args_name) and args_name[0] in _wrapper_keys: +# globals()[k] = wrapper(v) +# elif spec.varargs in _wrapper_keys: +# globals()[k] = wrapper(v) +# except: +# pass + +# def empty(*size, dtype=jt.float32, device=None, requires_grad=False): +# if len(size) == 1 and not isinstance(size[0], org_int): +# size = size[0] +# return jt.empty(size, dtype) + +# Tensor = Var + +# Tensor.backward = lambda x: jtorch_core.backward(x) +# Tensor.grad = property(grad_get, grad_set, grad_del) +# Tensor.retains_grad = property(retain_grad_get, 
retain_grad_set) +# def retain_grad(x:Tensor, value:bool=True): +# x.retains_grad = value +# return value +# Tensor.retain_grad = retain_grad + +# Tensor.dim = lambda self: self.ndim +# Tensor.ndimension = lambda self: self.ndim +# Tensor.nelement = lambda self: self.numel() +# Tensor.cuda = lambda self: self +# def device_get(x:Tensor): +# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") +# Tensor.device = property(device_get) + +# def argmax(x: Var, dim=None, keepdim: bool = False): +# return jt.argmax(x, dim, keepdim)[0] +# Tensor.argmax = argmax + +# def tensor_type(x: Var, dtype=None, **kwargs): +# if dtype: +# return x.astype(dtype) +# else: +# return x.dtype +# Tensor.type = tensor_type + +# def is_floating_point(x: Var): +# return "float" in str(x.dtype) +# Tensor.is_floating_point = is_floating_point + +# from . import autograd +# from .autograd import * + +# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): +# if isinstance(data,list): +# data_list = [] +# check = True +# for p in data: +# if isinstance(p, Tensor) and p.numel()==1: +# data_list.append(p.item()) +# elif isinstance(p, (org_int,org_float)): +# data_list.append(p) +# else: +# check = False +# break +# if check: +# data = data_list +# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) + +# # tensor = wrapper(array) +# from_numpy = wrapper(array) +# strided = None + +# def mod_zero_grad(self): +# for p in self.parameters(): +# p.grad = None +# Module.zero_grad = mod_zero_grad + +# class ModuleMisc: +# def parameters(self): +# return iter(super().parameters()) + +# def load_state_dict(self, state_dict, strict=False): +# return super().load_state_dict(state_dict) + +# def to(self, device=None,dtype=None): +# ''' do nothing but return its self''' +# return self +# def register_parameter(self,name,data): +# self.name = data + +# def buffers(self): +# for _, buf in self.named_buffers(): +# yield buf -def make_module(cls): - class TMod(ModuleMisc, cls): - def __init__(self, *args, **kw): - dtype = None - if "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - self._dtype = dtype - with jt.flag_scope(th_mode=0): - if "device" in kw: - del kw["device"] - super().__init__(*args, **kw) - for k,v in self.__dict__.items(): - if not k.startswith("_") and isinstance(v, Var) \ - and v.requires_grad: - v.retain_grad() - if dtype is not None and isinstance(v, Var): - v.assign(v.cast(dtype)) - def __call__(self, *args, **kw): - args, kw = handle_dtype(args, kw, self._dtype) - # if forward is override by user, call forward - if self.__class__.forward is not TMod.forward: - return self.forward(*args, **kw) - return self.execute(*args, **kw) - def forward(self, *args, **kw): - args, kw = handle_dtype(args, kw, self._dtype) - return self.execute(*args, **kw) +# def make_module(cls): +# class TMod(ModuleMisc, cls): +# def __init__(self, *args, **kw): +# dtype = None +# if "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# self._dtype = dtype +# with jt.flag_scope(th_mode=0): +# if "device" in kw: +# del kw["device"] +# super().__init__(*args, **kw) +# for k,v in self.__dict__.items(): +# if not k.startswith("_") and isinstance(v, Var) \ +# and v.requires_grad: +# v.retain_grad() +# if dtype is not None and isinstance(v, Var): +# v.assign(v.cast(dtype)) +# def __call__(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# # if forward is override by user, call forward +# if 
self.__class__.forward is not TMod.forward: +# return self.forward(*args, **kw) +# return self.execute(*args, **kw) +# def forward(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# return self.execute(*args, **kw) - @property - def training(self): - if not hasattr(self, "is_train"): - self.is_train = True - return self.is_train - @training.setter - def training(self, value): - self.is_train = value - - TMod.__name__ = cls.__name__ - return TMod - -import jtorch.cuda -import jtorch.nn -from jtorch.nn import Module, Parameter -import jtorch.optim - -from jtorch.utils.dtype import Dtype, get_string_dtype - -def frombuffer(buffer: bytearray, - *, - dtype: Dtype, - count: int = -1, - offset: int = 0, - requires_grad: bool = True) -> Tensor: - dtype = get_string_dtype(dtype) - tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) - if requires_grad and tensor.dtype.is_float(): - tensor.requires_grad = True - return tensor - -def conflict_wrapper(origin_func, new_func): - def wrapper(*args, **kw): - if jt.flags.th_mode: - return new_func(*args, **kw) - else: - return origin_func(*args, **kw) - return wrapper - -def min(*args, **kw): - dim = None - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmin(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.minimum(args[0], args[1]) - else: - return jt.min(*args, **kw) -Tensor.min = conflict_wrapper(jt.min, min) - -def max(*args, **kw): - dim = None - if "dim" in kw: - x = kw["dim"] - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmax(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.maximum(args[0], args[1]) - else: - return jt.max(*args, **kw) -Tensor.max = conflict_wrapper(jt.max, max) - -def argsort(*args, **kw): - k, v = jt.argsort(*args, **kw) - return k -Tensor.argsort = conflict_wrapper(jt.argsort, argsort) - -LongTensor = jt.int64 -FloatTensor = jt.float -HalfTensor = jt.float16 -BoolTensor = jt.bool -IntTensor = jt.int32 - -class JDType: - def __init__(self, func, str): - self.func = func - self.str = str - self.__name__ = str.split(".")[-1] - def __call__(self, *args, **kw): - return self.func(*args, **kw) - def __str__(self): - return self.str - def is_floating_point(self): - return "float" in str(self.str) - -int8 = JDType(jt.int8, "torch.int8") -int16 = JDType(jt.int16, "torch.int16") -int = int32 = JDType(jt.int32, "torch.int32") -long = int64 = JDType(jt.int64, "torch.int64") - -half = float16 = JDType(jt.float16, "torch.float16") -float = float32 = JDType(jt.float32, "torch.float32") -double = float64 = JDType(jt.float64, "torch.float64") -bfloat16 = "bfloat16" # TODO -complex64 = "complex64" # TODO -complex128 = "complex128" # TODO -def get_JDtype(dtype): - if dtype=='float32' or dtype == jt.float32: - return float32 - elif dtype=='float64' or dtype == jt.float64: - return float64 - elif dtype=='float16' or dtype == jt.float16: - return float16 - elif dtype=='int32' or dtype == jt.int32: - return int32 - elif dtype=='int64' or dtype == jt.int64: - return int64 - elif dtype=='int16' or dtype == jt.int16: - return int16 - elif dtype=='int8' or dtype == jt.int8: - return int8 - else: - raise Exception("dtype {} not supported".format(dtype)) - -def 
load(path,**kwargs): - def _to_jittor(data): - if isinstance(data,dict): - return {k:_to_jittor(d) for k,d in data.items()} - if isinstance(data,list): - return [_to_jittor(d) for d in data] - if isinstance(data,np.ndarray): - return jt.array(data) - return data - data = jt.load(path) +# @property +# def training(self): +# if not hasattr(self, "is_train"): +# self.is_train = True +# return self.is_train +# @training.setter +# def training(self, value): +# self.is_train = value + +# TMod.__name__ = cls.__name__ +# return TMod + +# import jtorch.cuda +# import jtorch.nn +# from jtorch.nn import Module, Parameter +# import jtorch.optim + +# from jtorch.utils.dtype import Dtype, get_string_dtype + +# def frombuffer(buffer: bytearray, +# *, +# dtype: Dtype, +# count: int = -1, +# offset: int = 0, +# requires_grad: bool = True) -> Tensor: +# dtype = get_string_dtype(dtype) +# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) +# if requires_grad and tensor.dtype.is_float(): +# tensor.requires_grad = True +# return tensor + +# def conflict_wrapper(origin_func, new_func): +# def wrapper(*args, **kw): +# if jt.flags.th_mode: +# return new_func(*args, **kw) +# else: +# return origin_func(*args, **kw) +# return wrapper + +# def min(*args, **kw): +# dim = None +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmin(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.minimum(args[0], args[1]) +# else: +# return jt.min(*args, **kw) +# Tensor.min = conflict_wrapper(jt.min, min) + +# def max(*args, **kw): +# dim = None +# if "dim" in kw: +# x = kw["dim"] +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmax(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.maximum(args[0], args[1]) +# else: +# return jt.max(*args, **kw) +# Tensor.max = conflict_wrapper(jt.max, max) + +# def argsort(*args, **kw): +# k, v = jt.argsort(*args, **kw) +# return k +# Tensor.argsort = conflict_wrapper(jt.argsort, argsort) + +# LongTensor = jt.int64 +# FloatTensor = jt.float +# HalfTensor = jt.float16 +# BoolTensor = jt.bool +# IntTensor = jt.int32 + +# class JDType: +# def __init__(self, func, str): +# self.func = func +# self.str = str +# self.__name__ = str.split(".")[-1] +# def __call__(self, *args, **kw): +# return self.func(*args, **kw) +# def __str__(self): +# return self.str +# def is_floating_point(self): +# return "float" in str(self.str) + +# int8 = JDType(jt.int8, "torch.int8") +# int16 = JDType(jt.int16, "torch.int16") +# int = int32 = JDType(jt.int32, "torch.int32") +# long = int64 = JDType(jt.int64, "torch.int64") + +# half = float16 = JDType(jt.float16, "torch.float16") +# float = float32 = JDType(jt.float32, "torch.float32") +# double = float64 = JDType(jt.float64, "torch.float64") +# bfloat16 = "bfloat16" # TODO +# complex64 = "complex64" # TODO +# complex128 = "complex128" # TODO +# def get_JDtype(dtype): +# if dtype=='float32' or dtype == jt.float32: +# return float32 +# elif dtype=='float64' or dtype == jt.float64: +# return float64 +# elif dtype=='float16' or dtype == jt.float16: +# return float16 +# elif dtype=='int32' or dtype == jt.int32: +# return int32 +# elif dtype=='int64' or dtype == jt.int64: +# return int64 +# elif 
dtype=='int16' or dtype == jt.int16: +# return int16 +# elif dtype=='int8' or dtype == jt.int8: +# return int8 +# else: +# raise Exception("dtype {} not supported".format(dtype)) + +# def load(path,**kwargs): +# def _to_jittor(data): +# if isinstance(data,dict): +# return {k:_to_jittor(d) for k,d in data.items()} +# if isinstance(data,list): +# return [_to_jittor(d) for d in data] +# if isinstance(data,np.ndarray): +# return jt.array(data) +# return data +# data = jt.load(path) - return _to_jittor(data) +# return _to_jittor(data) -def is_tensor(x): - return isinstance(x, Tensor) +# def is_tensor(x): +# return isinstance(x, Tensor) -manual_seed = jt.set_global_seed -jt.flags.amp_level = 3 -Size = jt.NanoVector +# manual_seed = jt.set_global_seed +# jt.flags.amp_level = 3 +# Size = jt.NanoVector -class Generator: - def __init__(self,*args,**kw) -> None: - self.seed = None - def manual_seed(self,seed): - self.seed = seed +# class Generator: +# def __init__(self,*args,**kw) -> None: +# self.seed = None +# def manual_seed(self,seed): +# self.seed = seed -from . import fx +# from . import fx -_default_type = "float32" +# _default_type = "float32" -def get_default_dtype(): - return _default_type -def set_default_dtype(dtype): - global _default_type - _default_type = dtype +# def get_default_dtype(): +# return _default_type +# def set_default_dtype(dtype): +# global _default_type +# _default_type = dtype -dtype = JDType +# dtype = JDType -def div(x,y,rounding_mode="floor"): - assert rounding_mode == "floor" - z = (x / y) - if rounding_mode == "floor": - z = z.floor() - if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): - z = z.int32() - return z +# def div(x,y,rounding_mode="floor"): +# assert rounding_mode == "floor" +# z = (x / y) +# if rounding_mode == "floor": +# z = z.floor() +# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): +# z = z.int32() +# return z -def randn(*args,**kw): - wrap_randn = wrapper(jt.randn) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_randn(*args,**kw) +# def randn(*args,**kw): +# wrap_randn = wrapper(jt.randn) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_randn(*args,**kw) -def rand(*args,**kw): - print("rand") - wrap_rand = wrapper(jt.rand) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_rand(*args,**kw) +# def rand(*args,**kw): +# print("rand") +# wrap_rand = wrapper(jt.rand) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_rand(*args,**kw) -def set_default_tensor_type(t: type or str): - if isinstance(t, str): - info = t.split(".") - if len(info) == 3 and info[1] == 'cuda': - jt.flags.use_cuda = 1 - #TODO: type +# def set_default_tensor_type(t: type or str): +# if isinstance(t, str): +# info = t.split(".") +# if len(info) == 3 and info[1] == 'cuda': +# jt.flags.use_cuda = 1 +# #TODO: type -def clamp(x, min=None, max=None): - return jt.clamp(x, 
min, max) +# def clamp(x, min=None, max=None): +# return jt.clamp(x, min, max) -def to(x,*args,**kw): - device = None - if len(args) == 1: - device = args[0] - if isinstance(device, jt.NanoString) or callable(device): - return jt.to(x,*args,**kw) - if 'cpu' in str(device): - args = [] - device = kw.get("device",None) - if 'cpu' in str(device): - kw.pop('device',None) - print("to cpu") - # print(kw) - return jt.to(x,*args,**kw) -Tensor.to = conflict_wrapper(jt.to, to) +# def to(x,*args,**kw): +# device = None +# if len(args) == 1: +# device = args[0] +# if isinstance(device, jt.NanoString) or callable(device): +# return jt.to(x,*args,**kw) +# if 'cpu' in str(device): +# args = [] +# device = kw.get("device",None) +# if 'cpu' in str(device): +# kw.pop('device',None) +# print("to cpu") +# # print(kw) +# return jt.to(x,*args,**kw) +# Tensor.to = conflict_wrapper(jt.to, to) -mm = wrapper(jt.matmul) +# mm = wrapper(jt.matmul) -def _data_get(x): - return x +# def _data_get(x): +# return x -def _data_set(x, value): - x.assign(value) +# def _data_set(x, value): +# x.assign(value) -Tensor.data = property(_data_get, _data_set) -Tensor.layout = None \ No newline at end of file +# Tensor.data = property(_data_get, _data_set) +# Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py index 71946a23..5fcfcaa6 100644 --- a/python/jittor/compatibility/utils/data.py +++ b/python/jittor/compatibility/utils/data.py @@ -99,39 +99,39 @@ def inner_iter(self): current_batch = self.collate_batch(current_batch) yield self.to_jittor(current_batch) -def get_worker_info(): - # always return the fake worker info - return namedtuple('WorkerInfo', 'id num_workers')(0, 1) - -class RandomSampler(jt.dataset.RandomSampler): - def __init__(self, dataset, generator=None, **kwargs): - super().__init__(dataset, **kwargs) - - def __iter__(self): - if getattr(self.dataset, "support_random_access", True): - return super().__iter__() - else: - self.dataset.shuffle() - return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) - -class DistributedSampler(jt.dataset.Sampler): - def __init__(self, sampler: RandomSampler): - assert(isinstance(sampler, RandomSampler)) - self.sampler = sampler - - def set_epoch(self, epoch: int): - ### do nothing, let jittor's inner dataset handle - pass - - def __iter__(self): - return self.sampler.__iter__() +# def get_worker_info(): +# # always return the fake worker info +# return namedtuple('WorkerInfo', 'id num_workers')(0, 1) + +# class RandomSampler(jt.dataset.RandomSampler): +# def __init__(self, dataset, generator=None, **kwargs): +# super().__init__(dataset, **kwargs) + +# def __iter__(self): +# if getattr(self.dataset, "support_random_access", True): +# return super().__iter__() +# else: +# self.dataset.shuffle() +# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) + +# class DistributedSampler(jt.dataset.Sampler): +# def __init__(self, sampler: RandomSampler): +# assert(isinstance(sampler, RandomSampler)) +# self.sampler = sampler + +# def set_epoch(self, epoch: int): +# ### do nothing, let jittor's inner dataset handle +# pass + +# def __iter__(self): +# return self.sampler.__iter__() - def __len__(self): - return self.sampler.__len__() +# def __len__(self): +# return self.sampler.__len__() -BatchSampler = jt.dataset.BatchSampler -Sampler = jt.dataset.Sampler -SequentialSampler = 
jt.dataset.SequentialSampler -SubsetRandomSampler = jt.dataset.SubsetRandomSampler +# BatchSampler = jt.dataset.BatchSampler +# Sampler = jt.dataset.Sampler +# SequentialSampler = jt.dataset.SequentialSampler +# SubsetRandomSampler = jt.dataset.SubsetRandomSampler -TensorDataset = Dataset +# TensorDataset = Dataset diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 31dd72de..c4026eee 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -15,6 +15,8 @@ def argmax_pool(x, size, stride, padding=0): + if stride<=0: + raise RuntimeError(f"stride must be > 0, but got {stride}") return pool.pool(x, size, 'maximum', padding, stride) def concat(arr, dim): @@ -241,9 +243,17 @@ def concat(arr, dim=0): if len(arr) == 0: raise ValueError("need at least one array to concat") total_dim = 0 - if dim < 0: dim += len(arr[0].shape) + base_dim = len(arr[0].shape) + if dim < 0: dim += base_dim + if dim < 0 or dim >= base_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})") dtypes = [] for a in arr: + if len(a.shape) != base_dim: + raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}") + for i in range(base_dim): + if i != dim and a.shape[i] != arr[0].shape[i]: + raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.") total_dim += a.shape[dim] dtypes.append(str(a.dtype)) cdim = 0 diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py index 7aa1d883..f0945f94 100644 --- a/python/jittor/dataset/mnist.py +++ b/python/jittor/dataset/mnist.py @@ -26,7 +26,7 @@ class MNIST(Dataset): [in] data_root(str): your data root. [in] train(bool): choose model train or val. - [in] download(bool): Download data automatically if download is Ture. + [in] download(bool): Download data automatically if download is True. [in] batch_size(int): Data batch size. [in] shuffle(bool): Shuffle data if true. [in] transform(jittor.transform): transform data. @@ -106,7 +106,7 @@ class EMNIST(Dataset): [in] data_root(str): your data root. [in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'. [in] train(bool): choose model train or val. - [in] download(bool): Download data automatically if download is Ture. + [in] download(bool): Download data automatically if download is True. [in] batch_size(int): Data batch size. [in] shuffle(bool): Shuffle data if true. [in] transform(jittor.transform): transform data. 
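
The hunks above add early argument validation to the tensor constructors (`random`, `ones`, `zeros`, `full`, `randn`, `randint`), to `flatten`, and to `contrib.concat`. Below is a minimal usage sketch of the intended failure modes, not part of the patch itself; it assumes a Jittor build that already contains these hunks, that `jt.concat` resolves to the patched `contrib.concat`, and the exception text is taken from the f-strings in the hunks.

import jittor as jt

# Negative dimensions are rejected up front instead of surfacing later
# inside ops.random / broadcast (new RuntimeError added in __init__.py).
try:
    jt.ones((2, -3))
except RuntimeError as e:
    print(e)  # Trying to create tensor with negative dimension -3: (2, -3)

# flatten() now validates end_dim against the input rank.
x = jt.zeros((2, 3))
try:
    jt.flatten(x, 0, 2)
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-2, 1], but got 2)

# concat() now requires equal rank and matching sizes outside the concat
# dimension (checks added in contrib.py above).
try:
    jt.concat([jt.zeros((2, 3)), jt.zeros((2, 4))], dim=0)
except RuntimeError as e:
    print(e)  # Sizes of vars must match except in dimension 0. ...
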
diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index fea9cfce..0ab4b6eb 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -10,6 +10,7 @@ import ctypes import glob import jittor.compiler as compiler +import jittor as jt has_acl = 0 cc_flags = "" @@ -34,16 +35,18 @@ # export DUMP_GRAPH_LEVEL=1 # build pytorch-npu -# bash ./ci/build.sh -# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall +# bash ./ci/build.sh +# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall # pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark # python3 ./mm_bench_pt_npu.py + def install(): import jittor.compiler as compiler global has_acl, cc_flags acl_compiler_home = os.path.dirname(__file__) - cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True)) + cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc", + recursive=True)) cc_files2 = [] for name in cc_files: if "acl_op_exec" in name: @@ -52,20 +55,22 @@ def install(): cc_files2.append(name) cc_files = cc_files2 cc_flags += f" -DHAS_CUDA -DIS_ACL \ - -I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \ - -L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \ + -I/usr/local/Ascend/ascend-toolkit/latest/include/ \ + -L/usr/local/Ascend/ascend-toolkit/latest/lib64/ \ -I{acl_compiler_home} -lascendcl -lacl_op_compiler " + ctypes.CDLL("libascendcl.so", dlopen_flags) ''' -ltikc_runtime - -I/usr/local/Ascend/driver/include \ - -L/usr/local/Ascend/compiler/lib64 \ - -L/usr/local/Ascend/runtime/lib64 \ + -I/usr/local/Ascend/driver/include/ \ + -L/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/ \ + -L/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/ \ ''' jittor_utils.LOG.i("ACL detected") global mod - mod = jittor_utils.compile_module(''' + mod = jittor_utils.compile_module( + ''' #include "common.h" namespace jittor { // @pyjt(process) @@ -98,9 +103,10 @@ def check(): if not has_acl: return False compiler.cc_flags += cc_flags compiler.nvcc_path = tikcc_path - compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","") + compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "") return True + def post_process(): if has_acl: from jittor import pool @@ -108,5 +114,711 @@ def post_process(): import jittor as jt jt.flags.use_cuda_host_allocator = 1 jt.flags.use_parallel_op_compiler = 0 - jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type - mod.init_acl_ops() \ No newline at end of file + jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type + mod.init_acl_ops() + + +def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, + attr: dict): + + input_code = '' + for i in range(len(inputs)): + if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1': + input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" + else: + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(output_dtypes)): + if name == 'MaxPoolWithArgmaxV1'or name == 'MaxPoolGradWithArgmaxV1': + output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" + else: + output_code += f"op.add(out{i}, false);\n" + + # add attr to op + attr_code = '' + if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1': + for k, v in attr.items(): + if k == 
"ceil_mode": + v = 'false' if v == False else 'true' + attr_code += f"op.set_attr(\"{k}\", {v});\n" + else: + v = str(v).replace('[', '{').replace(']', '}') + attr_code += f"op.set_attr(\"{k}\", vector{v});\n" + else: + for k, v in attr.items(): + if isinstance(v, bool): + attr_code += f"op.set_attr(\"{k}\", {str(v).lower()});\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + else: + attr_code += f"op.set_attr(\"{k}\", int({v}));\n" + + import jittor as jt + return jt.code( + output_shapes, + output_dtypes, + inputs, + cuda_header=""" + #include + #include + #include + #include + + namespace jittor { + + void printDeviceData(const vector& output_desc, const vector& output_data, const string& name = "", bool input=true) { + LOGir << "name: " << name; + if(input) + LOGir << "is input"; + else + LOGir << "is ouput"; + for (size_t i = 0; i < output_desc.size(); ++i) { + void* base_addr = aclGetDataBufferAddr(output_data[i]); + LOGir << "addr of data[" << i << "] :" << base_addr; + size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); + size_t total_size = 1; + std::vector dims(num_dims); + + std::cout << "shape of data: "; + for (size_t j = 0; j < num_dims; ++j) { + aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); + total_size *= dims[j]; + std::cout << dims[j] << ", "; + } + int evey_batch_size = total_size/dims[0]; + std::cout << std::endl; + + // for(int i= 0; i < dims[0]; i++) { + // evey_batch_size = 16; + // std::vector host_buffer(evey_batch_size); + // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); + // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); + // std::cout << "batch[" << i << "]:"; + // for (size_t k = 0; k < evey_batch_size; ++k) { + // std::cout << host_buffer[k] << ", "; + // } + // std::cout << std::endl; + // } + } + } + + struct AclOpRunner { + string name; + vector input_desc; + vector output_desc; + vector input_data; + vector output_data; + aclopAttr *attr; + vector> input_host; + vector> input_host_32; + + AclOpRunner(const string& name) : name(name) { + attr = aclopCreateAttr(); + } + + ~AclOpRunner() { + for (auto i : input_desc) aclDestroyTensorDesc(i); + for (auto i : output_desc) aclDestroyTensorDesc(i); + for (auto i : input_data) aclDestroyDataBuffer(i); + for (auto i : output_data) aclDestroyDataBuffer(i); + aclopDestroyAttr(attr); + } + + aclDataType get_dtype(NanoString s) { + if (s == ns_float32) return ACL_FLOAT; + if (s == ns_float16) return ACL_FLOAT16; + if (s == ns_int64) return ACL_INT64; + if (s == ns_int32) return ACL_INT32; + if (s == ns_int8) return ACL_INT8; + if (s == ns_int16) return ACL_INT16; + if (s == ns_uint8) return ACL_UINT8; + if (s == ns_uint16) return ACL_UINT16; + if (s == ns_uint32) return ACL_UINT32; + if (s == ns_bool) return ACL_BOOL; + LOGf << "Not supported dtype: " << s; + return ACL_FLOAT; + } + + void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) { + int64_t shape[v->shape.size()]; + for (int i=0; ishape.size(); i++) shape[i] = v->shape[i]; + + auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format); + aclSetTensorFormat(desc, (aclFormat)format); + aclSetTensorShape(desc, v->shape.size(), &shape[0]); + LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format; + auto data = aclCreateDataBuffer(v->mem_ptr, v->size); + LOGv << "aclCreateDataBuffer" << v->mem_ptr << 
v->size; + ASSERT(desc && data); + if (is_input) { + input_desc.push_back(desc); + input_data.push_back(data); + } else { + output_desc.push_back(desc); + output_data.push_back(data); + } + } + + void add_input_host(vector v, int dtype=ACL_UINT64) { + int64_t shape[1]; + shape[0] = v.size(); + auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND); + aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); + LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; + auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64)); + ASSERT(desc && data); + LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64); + input_desc.push_back(desc); + input_data.push_back(data); + input_host.emplace_back(move(v)); + LOGv << "move" << input_host.back().data(); + } + + void add_input_host_scalar(vector v, int dtype=ACL_UINT32) { + int64_t shape[1]; + shape[0] = v.size(); + auto x = (int*)&v[0]; + x[0] = (int32)v[0]; + auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND); + aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); + LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; + auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32)); + ASSERT(desc && data); + LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32); + input_desc.push_back(desc); + input_data.push_back(data); + input_host.emplace_back(move(v)); + } + + void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) { + vector v(nv.size()); + for (int i=0; i v(nv.size()); + for (int i=0; i value) { + CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); + } + void set_attr(const string& key, string value) { + CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); + } + void set_attr(const char* key, const char* value) { + CHECK(aclopSetAttrString(attr, key, value)==0); + } + + void run() { + // printDeviceData(input_desc, input_data, name); + + LOGv << "run" << name << input_desc.size() << output_desc.size(); + if (!PyGILState_Check()) { + ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); + } else { + int ret; + Py_BEGIN_ALLOW_THREADS + ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); + Py_END_ALLOW_THREADS + if (ret != 0) + LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; + } + ASSERT(0==aclrtSynchronizeDevice()); + + // printDeviceData(output_desc, output_data, name, false); + } + }; + + } + """, + cuda_src=f""" + // aclop + AclOpRunner op("{name}"); + {input_code} + {output_code} + {attr_code} + op.run();""" + ) + + +def change_function(): + import jittor as jt + from jittor import Function + + class IndexACL(Function): + def __init__(self): + super(IndexACL, self).__init__() + + def execute(self, inshape: list, dim: int, dtype="int32"): + # zeros a tensor, shape is inshape, dtype is dtype + max_len = inshape[dim] + tmp = jt.zeros(max_len, dtype=dtype) + result = acl_cmd( + "Range", + [jt.Var(0), jt.Var(max_len), jt.Var(1)], + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr={})[0] + broadcast_dim = [] + for i in range(len(inshape)): + if i != dim: + broadcast_dim.append(i) + result = jt.broadcast(result, 
shape=inshape, dims=broadcast_dim) + return result + + def grad(self, grad_output): + return grad_output + + + class PoolACL(Function): + + def __init__(self, + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + op='maximum'): + super(PoolACL, self).__init__() + import jittor as jt + # set attr + self.kernel_size = kernel_size if isinstance( + kernel_size, tuple) else (kernel_size, kernel_size) + stride = stride if stride else kernel_size + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, + padding) + dilation = dilation if dilation else 1 + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, + dilation) + attr = {} + attr['ksize'] = [1, self.kernel_size[0], self.kernel_size[1], 1] + attr['strides'] = [1, self.stride[0], self.stride[1], 1] + attr['pads'] = [1, self.padding[0], self.padding[1], 1] + attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] + attr['ceil_mode'] = ceil_mode + + self.attr = attr + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + self.op = op + + def execute(self, input): + + # create input + input_shape = input.shape + input_dtype = input.dtype + + self.input = input + + # create output + output_shape = [ + input_shape[0], input_shape[1], + (input_shape[2] + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, + (input_shape[3] + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + ] + output_dtype = input_dtype + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + self.index = result[1] + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + return grad_input + + class BmmACL(Function): + def __init__(self, adj_x1=False, adj_x2=False): + super(BmmACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class MatmulACL(Function): + def __init__(self, adj_x1=False, adj_x2=False): + super(MatmulACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + if len(x1.shape) > 2 or len(x2.shape) > 2: + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + else: + result = acl_cmd("MatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + 
if len(x1.shape) > 2 or len(x2.shape) > 2: + grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + else: + grad_x1 = acl_cmd("MatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd("MatMul", [x1.transpose(-2, -1) , grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class GetItem(Function): + def __init__(self): + super(GetItem, self).__init__() + + def stride(self, x, dim): + stride = 1 + for i in range(dim+1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices): + if isinstance(slices, jt.Var) or isinstance(slices, tuple): + if isinstance(slices, jt.Var): + slices = (slices,) + if isinstance(slices[0], jt.Var): + slices_len = len(slices) + masks = jt.ones(slices_len, dtype=jt.int64) + + output = slices[0].shape + output += x.shape[slices_len:] + + input_ = [x, masks, jt.Var(list(output)).int64()] + for i in range(slices_len): + input_.append(slices[i].int32()) + + result = acl_cmd("Index", input_, + output_dtypes=[x.dtype], + output_shapes=[output], + attr={})[0] + return result + + # use AsStrided operator to implement the getitem function + # get the shape and stride of the input tensor + x_dim = len(x.shape) + + if not isinstance(slices, tuple): + slices = (slices,) + + if len(slices) < x_dim: + slices += (slice(None, None, None),) * (x_dim - len(slices)) + + self.inputs = [x, slices] + + sizes = [] + strides = [] + offset = 0 + + for dim, s in enumerate(slices): + if isinstance(s, int): + if s < 0: # Handle negative indices. 
+ s += x.shape[dim] + offset += s * self.stride(x, dim) + elif isinstance(s, slice): + # Unpack the slice + start, stop, step = s.indices(x.size(dim)) + size = (stop - start - 1) // step + 1 + stride = self.stride(x, dim) * step + offset += start * self.stride(x, dim) + sizes.append(size) + strides.append(stride) + else: + raise ValueError("Invalid slice type") + + if not sizes: + sizes = [1] + strides = [0] + # AsStrided same with as_strided of pytorch + self.sizes = sizes + self.strides = strides + self.offset = offset + self.shape = x.shape + result = acl_cmd("AsStrided", [x, jt.Var(sizes), jt.Var(strides), jt.Var(offset)], + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr={})[0] + return result + + def grad(self, grad_output): + result = jt.zeros(self.shape, dtype=grad_output.dtype) + sizes = list(grad_output.shape) + strides = [self.stride(grad_output, dim) for dim in range(len(grad_output.shape))] + result = acl_cmd("ViewCopy", [result, jt.Var(self.sizes), jt.Var(self.strides), jt.Var(self.offset), + grad_output, jt.Var(sizes), jt.Var(strides), jt.Var(0)], + output_dtypes=[result.dtype], + output_shapes=[result.shape], + attr={})[0] + result.sync() + return result, None + + class ConcatACL(Function): + def __init__(self): + super(ConcatACL, self).__init__() + + def execute(self, input_tensors, dim=0): + self.input = input_tensors + for i in range(len(input_tensors)): + if input_tensors[i].dtype != input_tensors[0].dtype: + raise ValueError("All input tensors must have the same dtype") + if input_tensors[i].shape[:dim] != input_tensors[0].shape[:dim] or input_tensors[i].shape[dim+1:] != input_tensors[0].shape[dim+1:]: + raise ValueError("All input tensors must have the same shape") + result = acl_cmd("ConcatD", input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[jt.empty(self.calculate_output_shape(input_tensors, dim)).shape], + attr={ + "N": len(input_tensors), + "concat_dim": dim + })[0] + return result + + def grad(self, grad_output): + grad_inputs = self.split_grad(grad_output, self.input, self.axis) + return grad_inputs + + def calculate_output_shape(self, input_tensors, axis): + shape = list(input_tensors[0].shape) + for tensor in input_tensors[1:]: + shape[axis] += tensor.shape[axis] + return tuple(shape) + + def split_grad(self, grad_output, input_tensors, axis): + offset = 0 + grad_inputs = [] + for tensor in input_tensors: + grad_input = acl_cmd("Slice", + [grad_output, + [0]*axis + [offset] + [0]*(len(tensor.shape)-axis-1), + tensor.shape] + ) + grad_inputs.append(grad_input) + offset += tensor.shape[axis] + return grad_inputs + + + class SetItemACL(Function): + def __init__(self): + super(SetItemACL, self).__init__() + + def stride(self, x, dim): + # 计算给定维度的步长 + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, value): + self.is_tensor = type(value) == jt.Var + if type(value) != jt.Var: + value = jt.array(value) + x_dim = len(x.shape) + + # 确保slices是一个元组 + if not isinstance(slices, tuple): + slices = (slices,) + + # 补齐slices使其长度等于x的维度 + if len(slices) < x_dim: + slices += (slice(None, None, None),) * (x_dim - len(slices)) + + self.inputs = [x, slices, value] + + target_sizes = [] + target_strides = [] + offset = 0 + + for dim, s in enumerate(slices): + if isinstance(s, int): + if s < 0: + s += x.shape[dim] + s = slice(s, s+1, None) + if isinstance(s, slice): + # 解包切片 + start, stop, step = s.indices(x.shape[dim]) + size = (stop - start - 1) // step + 1 + 
stride = self.stride(x, dim) * step + offset += start * self.stride(x, dim) + target_sizes.append(size) + target_strides.append(stride) + else: + print("slices: ", s, type(s)) + raise ValueError("Invalid slice type") + + # 计算value的size、stride和offset + value_sizes = list(value.shape) + value_strides = [self.stride(value, dim) for dim in range(len(value.shape))] + + self.target_sizes = target_sizes + self.target_strides = target_strides + self.offset = offset + self.value_sizes = value_sizes + self.value_strides = value_strides + + #import pdb; pdb.set_trace() + result = acl_cmd("ViewCopy", [x, jt.Var(target_sizes), jt.Var(target_strides), jt.Var(offset), + value, jt.Var(value_sizes), jt.Var(value_strides), jt.Var(0)], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] + result.sync() + return result + + def grad(self, grad_output): + result = acl_cmd("AsStrided", [grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset)], + output_dtypes=[grad_output.dtype], + output_shapes=[jt.empty(self.target_sizes).shape], + attr={})[0] + # copy grad_output to new_grad_output + new_grad_output = acl_cmd("Copy", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={ "N": 1 })[0] + new_grad_output = acl_cmd("ViewCopy", [new_grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset), + jt.zeros(self.value_sizes, dtype=grad_output.dtype), jt.Var(self.value_sizes), jt.Var(self.value_strides), jt.Var(0)], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + new_grad_output.sync() + return new_grad_output, None, result if self.is_tensor else None + + class TriuACL(Function): + def __init__(self): + super(TriuACL, self).__init__() + + def execute(self, input, k): + self.input = input + result = acl_cmd("Triu", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={'diagonal': k})[0] + return result + + def grad(self, grad_output): + return grad_output + + class TransposeACL(Function): + def __init__(self): + super(TransposeACL, self).__init__() + + def execute(self, input, perm): + self.input = input + + output_shape = input.shape[perm[0]:perm[0]+1] + + for i in range(1, len(perm)): + output_shape += input.shape[perm[i]:perm[i]+1] + result = acl_cmd("Transpose", [input, jt.Var(perm)], + output_dtypes=[input.dtype], + output_shapes=[output_shape], + attr={})[0] + return result + + def grad(self, grad_output): + return grad_output + + def warp(origin_func, new_func): + def warpper(*args, **kwargs): + if jt.flags.use_acl: + return new_func(*args, **kwargs) + if origin_func == jt.index: + if len(args) == 2 and args[1] == None: + args = tuple(list(args[0:1])) + return origin_func(*args, **kwargs) + return warpper + + jt.index = warp(jt.index, IndexACL()) + jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) + jt.nn.Pool = warp(jt.nn.Pool, PoolACL) + + jt.triu = warp(jt.triu, TriuACL()) + jt.triu_ = warp(jt.triu, TriuACL()) + jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) + jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) + + jt.getitem = warp(jt.getitem, GetItem()) + jt.Var.getitem = lambda x, slices: warp(jt.getitem, GetItem())(x, slices) + jt.setitem = warp(jt.setitem, SetItemACL()) + jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, SetItemACL())(x, slices, value) + + # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) + # jt.bmm = warp(jt.bmm, BmmACL()) + # jt.nn.matmul = warp(jt.matmul, 
MatmulACL()) + # jt.matmul = warp(jt.matmul, MatmulACL()) + # jt.transpose = warp(jt.transpose, TransposeACL()) + # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm) + # jt.concat = warp(jt.concat, ConcatACL()) diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc index dc99887f..07639ab0 100644 --- a/python/jittor/extern/acl/acl_jittor.cc +++ b/python/jittor/extern/acl/acl_jittor.cc @@ -16,6 +16,7 @@ namespace jittor { uint64_t acl_jittor_tid; int acl_jittor_thread_running=0; aclrtContext acl_jittor_context; +aclrtStream aclstream; #define CHECK_ACL(x) ASSERTop(x,==,0) @@ -28,7 +29,7 @@ static void* acl_jittor_process_callback(void*) { // LOGir << "acl_jittor_process_callback"; auto ret = aclrtProcessReport(1000); if (ret) { - if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT) + if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE) LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret); break; } @@ -59,10 +60,9 @@ acl_jittor_initer() { CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0)); // simple callback test - // aclrtStream stream; - // CHECK_ACL(aclrtCreateStream(&stream)); - // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream)); - // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream)); + CHECK_ACL(aclrtCreateStream(&aclstream)); + // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,aclstream)); + // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, aclstream)); // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0)); } @@ -87,7 +87,7 @@ string process_acl(const string& src, const string& name, const map fake_class = { "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t", @@ -117,6 +117,7 @@ string process_acl(const string& src, const string& name, const map& output_desc, const vector& output_data, const string& name = "", bool input=true) { + LOGir << "name: " << name; + if(input) + LOGir << "is input"; + else + LOGir << "is ouput"; + for (size_t i = 0; i < output_desc.size(); ++i) { + void* base_addr = aclGetDataBufferAddr(output_data[i]); + LOGir << "addr of data[" << i << "] :" << base_addr; + size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); + size_t total_size = 1; + std::vector dims(num_dims); + + std::cout << "shape of data: "; + for (size_t j = 0; j < num_dims; ++j) { + aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); + total_size *= dims[j]; + std::cout << dims[j] << ", "; + } + int evey_batch_size = total_size/dims[0]; + std::cout << std::endl; + + // for(int i= 0; i < dims[0]; i++) { + // evey_batch_size = 16; + // std::vector host_buffer(evey_batch_size); + // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); + // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); + // std::cout << "batch[" << i << "]:"; + // for (size_t k = 0; k < evey_batch_size; ++k) { + // std::cout << host_buffer[k] << ", "; + // } + // std::cout << std::endl; + // if(i >= 3) + // break; + // } + } +} + struct AclOpRunner { string name; vector input_desc; @@ -40,6 +80,7 @@ struct AclOpRunner { vector output_data; aclopAttr *attr; vector> input_host; + vector> input_host_32; AclOpRunner(const string& name) : name(name) { attr = aclopCreateAttr(); @@ -56,9 +97,14 @@ struct AclOpRunner { aclDataType get_dtype(NanoString s) { if (s == 
ns_float32) return ACL_FLOAT; if (s == ns_float16) return ACL_FLOAT16; + if (s == ns_int64) return ACL_INT64; if (s == ns_int32) return ACL_INT32; if (s == ns_int8) return ACL_INT8; + if (s == ns_int16) return ACL_INT16; if (s == ns_uint8) return ACL_UINT8; + if (s == ns_uint16) return ACL_UINT16; + if (s == ns_uint32) return ACL_UINT32; + if (s == ns_bool) return ACL_BOOL; LOGf << "Not supported dtype: " << s; return ACL_FLOAT; } @@ -138,7 +184,7 @@ struct AclOpRunner { auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(int)); input_desc.push_back(desc); input_data.push_back(data); - // input_host.emplace_back(move(v)); + input_host_32.emplace_back(move(v)); } void set_attr(const string& key, bool value) { @@ -164,18 +210,22 @@ struct AclOpRunner { } void run() { + // printDeviceData(input_desc, input_data, name); + LOGv << "run" << name << input_desc.size() << output_desc.size(); if (!PyGILState_Check()) { - ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL)); + ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); } else { int ret; Py_BEGIN_ALLOW_THREADS - ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL); + ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); Py_END_ALLOW_THREADS if (ret != 0) LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; } - // ASSERT(0==aclrtSynchronizeDevice()); + ASSERT(0==aclrtSynchronizeDevice()); + + // printDeviceData(output_desc, output_data, name, false); } }; @@ -318,6 +368,17 @@ void try_exec_and_fallback_cpu(Op* op) { auto iter = opname_map.find(bop->ns); ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found"; op.name = iter->second; + if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool) + { + // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor + if (bop->ns == ns_bitwise_or) { + op.name = "LogicalOr"; + } else if (bop->ns == ns_bitwise_and) { + op.name = "LogicalAnd"; + } else if (bop->ns == ns_bitwise_xor) { + op.name = "LogicalXor"; + } + } op.run(); } else if (op->name() == string("ternary")) { @@ -345,7 +406,7 @@ void try_exec_and_fallback_cpu(Op* op) { else if (rop->ns == ns_minimum) op.name = "ReduceMin"; else if (rop->ns == ns_mean) - op.name = "Reduce"; + op.name = "ReduceMean"; else LOGf << "op " << rop->ns << " not supported"; op.add(rop->x, true); @@ -381,6 +442,19 @@ void try_exec_and_fallback_cpu(Op* op) { op.add_input_host_nv(zshape, ACL_INT64); op.add(bop->z, false); op.run(); + } + else + if (op->name() == string("fuse_transpose")) { + // replace fuse_transpose with transpose + auto top = (TransposeOp*)op; + AclOpRunner op("Transpose"); + op.add(top->x, true); + op.add(top->y, false); + vector axes; + for (int i=0; iaxes.size(); i++) + axes.push_back(top->axes[i]); + op.add_input_host(axes, ACL_INT64); + op.run(); } else { LOGf << "op " << op->name() << " not supported"; @@ -388,6 +462,7 @@ void try_exec_and_fallback_cpu(Op* op) { } } catch (std::exception& e) { 
fallback = 1; + LOGir << "fallback cpu" << e.what(); } for (auto v : new_alloced) { free_var_mem(v); } @@ -401,7 +476,7 @@ extern int current_seed; extern int64 current_offset; static unordered_map> acl_ops = { -{"curand_random", [&](Op* op) { +{"curand_random", [&current_seed, &current_offset](Op* op) { auto _op = (RandomOp*)op; AclOpRunner runner(_op->type == ns_uniform ? "StatelessRandomUniformV2" : "StatelessRandomNormalV2"); auto out = op->output(0); @@ -429,7 +504,21 @@ static unordered_map> acl_ops = { runner.set_attr("transpose_x2", _op->trans_b); runner.run(); }}, -{"cudnn_conv", [&](Op* op) { +{"cublas_batched_matmul", [&](Op* op) { + struct BatchedMatmulOp : Op { + Var* a, *b, *c; + bool adj_x1, adj_x2; + }; + auto _op = (BatchedMatmulOp*)op; + AclOpRunner runner("BatchMatMul"); + runner.add(_op->a, true); + runner.add(_op->b, true); + runner.add(_op->c, false); + runner.set_attr("adj_x1", _op->adj_x1); + runner.set_attr("adj_x2", _op->adj_x2); + runner.run(); +}}, +{"cudnn_conv", [](Op* op) { struct ConvOp : Op { Var* x, * w, * y; int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -451,7 +540,7 @@ static unordered_map> acl_ops = { auto _op = (ConvOp*)op; _op->run_acl(); }}, -{"cudnn_conv_backward_x", [&](Op* op) { +{"cudnn_conv_backward_x", [](Op* op) { struct ConvBackwardXOp : Op { Var* w, * dy, * dx; int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -480,7 +569,7 @@ static unordered_map> acl_ops = { auto _op = (ConvBackwardXOp*)op; _op->run_acl(); }}, -{"cudnn_conv_backward_w", [&](Op* op) { +{"cudnn_conv_backward_w", [](Op* op) { struct ConvBackwardWOp : Op { Var* x, * dy, * dw; int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -488,7 +577,7 @@ static unordered_map> acl_ops = { void run_acl() { AclOpRunner runner("Conv2DBackpropFilter"); runner.add(x, true, ACL_FORMAT_NCHW); - runner.add_input_host_nv32(x->shape); + runner.add_input_host_nv32(dw->shape); runner.add(dy, true, ACL_FORMAT_NCHW); runner.add(dw, false, ACL_FORMAT_NCHW); runner.set_attr("strides", vector{1,1,strideh,stridew}); @@ -538,7 +627,7 @@ static jit_op_entry_t acl_do_compile(Op* op) { return oc.compile(op->get_jit_key(get_jk()), *src); } if (op->name() == string("fused")) { - FusedOp* fop = (FusedOp*)op; + FusedOp* fop = (FusedOp*)op; // if is a relayed op if (fop->context->vrm.relay_groups.size()) { LOGv << "relay fused op"; @@ -546,7 +635,17 @@ static jit_op_entry_t acl_do_compile(Op* op) { } else { return &try_exec_and_fallback_cpu; } - } else { + } else + if (op->name() == string("code")) { + CodeOp* cop = (CodeOp*)op; + if (cop->cuda_src.find("acl") != string::npos) { + LOGv << "compile acl op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } else { + return &exec_mapped_acl_ops; + } + } else + { LOGv << "compile finish" << op; return &exec_mapped_acl_ops; } diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h index be01348f..7667e495 100644 --- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h +++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h @@ -25,7 +25,9 @@ static inline cudaDataType get_dtype(NanoString dtype) { if (dtype == ns_float32) return CUDA_R_32F; if (dtype == ns_float64) return CUDA_R_64F; if (dtype == ns_float16) return CUDA_R_16F; + #ifndef IS_ROCM if (dtype == ns_bfloat16) return CUDA_R_16BF; + #endif LOGf << "not support type" << dtype; return CUDA_R_32F; } diff --git 
a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h index 2109a03a..73fb3233 100644 --- a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h @@ -8,8 +8,9 @@ #include #include #include +#ifndef IS_ROCM #include - +#endif #include "utils/log.h" #include "helper_cuda.h" #include "fp16_emu.h" @@ -31,6 +32,8 @@ void set_max_workspace_ratio(float64 ratio); template __inline__ cudnnDataType_t getDataType(); template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } +#ifndef IS_ROCM template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; } +#endif } // jittor diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index d601335c..2053d86c 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -10,6 +10,214 @@ # *************************************************************** import jittor as jt from functools import partial +from .nn import ComplexNumber + +def complex_inv(x:ComplexNumber): + r""" + calculate the inverse of x. + :param x (...,M,M): + :return:x^-1 (...,M,M). + + TODO: Faster Implementation; Check backward. + """ + assert isinstance(x, ComplexNumber), "complex_inv is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_inv" + + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + + a = _stack_to_complex(data["inputs"][0]) + m_a = data["outputs"][0] + t_a = np.linalg.inv(a) + np.copyto(m_a, _complex_to_stack(t_a)) + + + def backward_code(np, data): + def T(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = _stack_to_complex(data["dout"]) + out = data["outputs"][0] + mx = _stack_to_complex(data["f_outputs"][0]) + t = -_dot(_dot(T(mx), dout), T(mx)) + np.copyto(out, _complex_to_stack(t)) + + lmx = jt.numpy_code( + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + + return ComplexNumber(lmx, is_concat_value=True) + +def complex_eig(x:ComplexNumber): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return:w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. 
+ """ + assert isinstance(x, ComplexNumber), "complex_eig is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_eig" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + w, v = data["outputs"] + tw, tv = np.linalg.eig(a) + np.copyto(w, _complex_to_stack(tw)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + sw = x.shape[:-2] + x.shape[-1:] + (2,) + sv = x.value.shape + w, v = jt.numpy_code( + [sw, sv], + [x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(w, is_concat_value=True), ComplexNumber(v, is_concat_value=True) + +def complex_qr(x): + r""" + do the qr factorization of x in the below formula: + x = QR where Q is orthogonal matrix and R is upper-triangle matrix. + :param x (...,M,M): + :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). + """ + assert isinstance(x, ComplexNumber), "linalg_qr is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for linalg_qr" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + qr = data["outputs"][0] + Q, R = np.linalg.qr(a) + QR = np.stack([Q, R], axis=0) + np.copyto(qr, _complex_to_stack(QR)) + + def backward_code(np, data): + # reference: https://github.com/tencent-quantum-lab/tensorcircuit/blob/master/tensorcircuit/backends/pytorch_ops.py + def H(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _TriangularSolve(x, r): + return H(np.linalg.solve(r, H(x))) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + _diag = partial(np.einsum, '...ii->...i') + + dout = data["dout"] + out = data["outputs"][0] + qr = data["f_outputs"][0] + dout = _stack_to_complex(dout) + dq, dr = dout[0], dout[1] + qr = _stack_to_complex(qr) + q, r = qr[0], qr[1] + + + qdq = _dot(H(q), dq) + qdq_ = qdq - H(qdq) + rdr = _dot(r, H(dr)) + rdr_ = rdr - H(rdr) + tril = np.tril(qdq_ + rdr_) + + grad_a = _dot(q, dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - _dot(q, qdq), r) + ret = grad_a + grad_b + + m = rdr - H(qdq) + eyem = np.zeros_like(m) + _diag(eyem)[:] = _diag(m) + correction = eyem - np.real(eyem) + ret = ret + _TriangularSolve(_dot(q, H(correction)), r) + + ret = _complex_to_stack(ret) + np.copyto(out,ret) + + qr = jt.numpy_code( + (2,) + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + q, r = qr[0], qr[1] + return ComplexNumber(q, is_concat_value=True), ComplexNumber(r, is_concat_value=True) + +def complex_svd(x:ComplexNumber): + r''' + calculate the Singular Value Decomposition of x.It follows the below fomula: + x = usv* + only support full matrices == False ver now, which means: + x's shape (...,M,K) + u's shape (...,M,K) + s's shape (...,K) + 
v's shape (...,K,N) + where K is min(M,N). + :param x: + :return:u,s,v. + ''' + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + u, s, v = data["outputs"] + #TODO:remove copyto + tu, ts, tv = np.linalg.svd(a, full_matrices=0) + np.copyto(u, _complex_to_stack(tu)) + np.copyto(s, _complex_to_stack(ts)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + m, n = x.shape[-2:] + k = min(m, n) + s1 = list(x.shape) + s1[-1] = k + s2 = list(x.shape) + s2[-2] = k + s3 = list(x.shape)[:-2] + s3.append(k) + s1.append(2) + s2.append(2) + s3.append(2) + u, s, v = jt.numpy_code( + [s1, s3, s2], + [x.value.dtype, x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(u, is_concat_value=True), \ + ComplexNumber(s, is_concat_value=True), \ + ComplexNumber(v, is_concat_value=True) #TODO:full_matrices=1 def svd(x): @@ -25,6 +233,8 @@ def svd(x): :param x: :return:u,s,v. ''' + if isinstance(x, ComplexNumber): + return complex_svd(x) def forward_code(np, data): a = data["inputs"][0] u, s, v = data["outputs"] @@ -92,6 +302,17 @@ def T(x): ) return u, s, v +def eig(x): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return (ComplexNumber):w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + if isinstance(x, ComplexNumber): + return complex_eig(x) + return complex_eig(ComplexNumber(x)) def eigh(x): r""" @@ -147,6 +368,8 @@ def inv(x): :param x (...,M,M): :return:x^-1 (...,M,M). """ + if isinstance(x, ComplexNumber): + return complex_inv(x) def forward_code(np, data): a = data["inputs"][0] m_a = data["outputs"][0] @@ -387,6 +610,8 @@ def qr(x): :param x (...,M,M): :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). 
""" + if isinstance(x, ComplexNumber): + return complex_qr(x) def forward_code(np, data): a = data["inputs"][0] q, r = data["outputs"] diff --git a/python/jittor/misc.py b/python/jittor/misc.py index fc528d62..5357baf2 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -294,9 +294,10 @@ def expand(x, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)): shape = shape[0] shape = list(shape) - for i in range(len(shape)): - if shape[i] == -1: - shape[i] = x.shape[i] + offset = len(shape) - len(x.shape) + for i in range(len(x.shape)): + if shape[offset + i] == -1: + shape[offset + i] = x.shape[i] return x.broadcast(shape) jt.Var.expand = expand @@ -2010,6 +2011,7 @@ def contiguous(x): return x.clone() def cpu(x): return x.clone() jt.Var.cpu = cpu def to(x, *args, **kargs): + args += tuple(kargs.values()) if len(args) >= 1: s = args[0] if isinstance(s, jt.NanoString) or callable(s): @@ -2234,3 +2236,16 @@ def cuda(x): def expm1(x): return jt.exp(x) - 1 + + +def isin(elements, test_elements, assume_unique=False, invert=False): + + elements = elements.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + comparison = elements == test_elements + result = comparison.any(dim=-1) + + if invert: + result = jt.logical_not(result) + + return result \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py index a8190c79..22b934c9 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -22,6 +22,7 @@ from jittor.optim import * from jittor.misc import _pair, _triple from jittor_utils import LOG +from functools import partial def matmul_transpose(a, b): @@ -363,10 +364,14 @@ def __init__(self, num_parameters=1, init_=0.25): self.num_parameters = num_parameters self.weight = init.constant((num_parameters,), "float32", init_) + def execute(self, x): if self.num_parameters != 1: - assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU" - return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x) + assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU" + # Adjust broadcasting logic to ensure it matches the input dimensions + shape = [x.shape[0], self.num_parameters] + [1] * (len(x.shape) - 2) + weight_broadcasted = self.weight.broadcast(shape) + return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x) else: return jt.maximum(0, x) + self.weight * jt.minimum(0, x) @@ -925,6 +930,42 @@ class Conv(Module): >>> output = conv(input) ''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + if in_channels <= 0: + raise ValueError(f"in_channels must be greater than zero, got {in_channels}") + if out_channels <= 0: + raise ValueError(f"out_channels must be greater than zero, got {out_channels}") + if groups <= 0: + raise ValueError(f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + else: + if kernel_size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + else: + if stride <= 
0: + raise ValueError(f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") + else: + if dilation <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) @@ -935,8 +976,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: self.depthwise_conv = DepthwiseConv(stride, padding, dilation) - assert in_channels % groups == 0, 'in_channels must be divisible by groups' - assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw = self.kernel_size # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") @@ -1061,6 +1100,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.dilation = (dilation, 1) self.groups = groups self.bias = bias + if groups <= 0: + raise ValueError("groups must be a positive integer") assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' # using list to escape module dfs @@ -1121,6 +1162,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) self.groups = groups + if groups <= 0: + raise ValueError("groups must be a positive integer") assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw, Kd = self.kernel_size @@ -1187,7 +1230,8 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): stride = _pair(stride) dilation = _pair(dilation) out_channels = weight.shape[0] - + if groups <= 0: + raise ValueError("groups must be a positive integer") if groups == 1: N,C,H,W = x.shape Kh, Kw = weight.shape[-2:] @@ -1276,7 +1320,8 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): stride = _triple(stride) dilation = _triple(dilation) out_channels = weight.shape[0] - + if groups <= 0: + raise ValueError("groups must be a positive integer") if jt.flags.use_cuda and jt.cudnn: y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups) elif groups == 1: @@ -1352,6 +1397,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.padding[0] >= 0 or self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) 
and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" @@ -1474,6 +1520,8 @@ def execute(self, x): return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation) def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if stride <= 0: + raise RuntimeError("non-positive stride is not supported") if groups == 1: x = input N,C,H,W = x.shape @@ -1564,6 +1612,8 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_paddi x = input N,C,D,H,W = x.shape i,o,d,h,w = weight.shape + if stride <= 0: + raise RuntimeError("non-positive stride is not supported") assert C==i assert groups==1, "Group conv not supported yet." stride = stride if isinstance(stride, tuple) else (stride, stride, stride) @@ -1631,6 +1681,8 @@ def pad(x,padding, mode='constant', value=0): class ReflectionPad2d(Module): def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") self.padding = padding if isinstance(self.padding, int): self.pl = self.padding @@ -1641,6 +1693,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1669,6 +1723,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1687,6 +1743,8 @@ def __init__(self, padding, value): else: raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}") self.value = value + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): assert len(x.shape) >= 2 @@ -1701,6 +1759,8 @@ def execute(self, x): class ReplicationPad2d(Module): def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") self.padding = padding if isinstance(self.padding, int): self.pl = self.padding @@ -1711,6 +1771,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1760,13 +1822,14 @@ def embedding(input, weight): class PixelShuffle(Module): def __init__(self, upscale_factor): + assert upscale_factor > 0,f"upscale_factor must be greater than zero,got {upscale_factor}" self.upscale_factor = upscale_factor def execute(self, x): n,c,h,w = x.shape r = self.upscale_factor assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" - if r<0: + if r<=0: raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}") return x.reindex([n,int(c/r**2),h*r,w*r], [ "i0", @@ -1815,6 +1878,15 @@ def execute(self, x): class Resize(Module): def __init__(self, size, mode="nearest", 
align_corners=False): super().__init__() + if isinstance(size,int): + if size <= 0: + raise ValueError(f"sizes must be positive, got {size}") + elif isinstance(size,tuple) or isinstance(size,list): + for item in size: + if item <= 0: + raise ValueError(f"sizes must be positive, got {item}") + else: + raise ValueError(f"size must be int or tuple") self.size = size self.mode = mode self.align_corners = align_corners @@ -2160,15 +2232,21 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner class Upsample(Module): def __init__(self, scale_factor=None, mode='nearest'): - self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor) + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None self.mode = mode def execute(self, x): - return upsample(x, - size=( - int(x.shape[2]*self.scale_factor[0]), - int(x.shape[3]*self.scale_factor[1])), - mode=self.mode) + if self.scale_factor is None: + raise ValueError("scale_factor should be defined") + else: + return upsample(x, + size=( + int(x.shape[2]*self.scale_factor[0]), + int(x.shape[3]*self.scale_factor[1])), + mode=self.mode) class UpsamplingBilinear2d(Upsample): def __init__(self, scale_factor=None): @@ -2234,6 +2312,15 @@ def __len__(self): def named_children(self,): return list(self.layers.items()) + + def __setattr__(self, key, value) -> None: + if isinstance(key, str) and key.isdigit(): + if int(key) 0 and output_size[1] > 0, "output size must be positive." if not isinstance(kernel_size,tuple): kernel_size = (kernel_size,kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" if not isinstance(dilation,tuple): dilation = (dilation,dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" if not isinstance(padding,tuple): padding = (padding,padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" if not isinstance(stride,tuple): stride = (stride,stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" n,cl,num = X.shape area = kernel_size[0] * kernel_size[1] block_nums = [] @@ -3049,6 +3141,12 @@ def permute(self, *axes): def transpose(self, *axes): return ComplexNumber(jt.transpose(self.real, *axes), jt.transpose(self.imag, *axes)) + def broadcast(self, shape, dims): + return ComplexNumber(self.real.broadcast(shape, dims), self.imag.broadcast(shape, dims)) + + def sum(self, dims, keepdims: bool=False): + return ComplexNumber(self.real.sum(dims, keepdims=keepdims), self.imag.sum(dims, keepdims=keepdims)) + def exp(self): er = jt.exp(self.real) return ComplexNumber(er * jt.cos(self.imag), er * jt.sin(self.imag)) @@ -3059,7 +3157,7 @@ def conj(self): def __add__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real + other.real, self.imag + other.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(self.real + other, self.imag) else: raise NotImplementedError @@ -3067,7 +3165,7 @@ def __add__(self, other): def __radd__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real + self.real, other.imag + self.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(other + self.real, self.imag) else: raise NotImplementedError @@ -3075,7 +3173,7 @@ def __radd__(self, other): 
def __sub__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real - other.real, self.imag - other.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(self.real - other, self.imag) else: raise NotImplementedError @@ -3083,7 +3181,7 @@ def __sub__(self, other): def __rsub__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real - self.real, other.imag - self.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(other - self.real, self.imag) else: raise NotImplementedError @@ -3094,8 +3192,6 @@ def __mul__(self, other): self.real * other.imag + self.imag * other.real) elif isinstance(other, (int, float)): return ComplexNumber(self.value * other, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real * other, self.imag * other) else: raise NotImplementedError @@ -3105,8 +3201,6 @@ def __rmul__(self, other): other.imag * self.real + other.real * self.imag) elif isinstance(other, (int, float)): return ComplexNumber(other * self.value, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(other * self.real, other * self.imag) else: raise NotImplementedError @@ -3117,8 +3211,6 @@ def __truediv__(self, other): (self.imag * other.real - self.real * other.imag) / norm) elif isinstance(other, (int, float)): return ComplexNumber(self.value / other, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real / other, self.imag / other) else: raise NotImplementedError @@ -3127,7 +3219,7 @@ def __rtruediv__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber((other.real * self.real + other.imag * self.imag) / norm, (other.imag * self.real - other.real * self.imag) / norm) - elif isinstance(other, (int, float, jt.Var)): + elif isinstance(other, (int, float)): return ComplexNumber(other * self.real / norm, - other * self.imag / norm) else: raise NotImplementedError @@ -3136,8 +3228,6 @@ def __matmul__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real @ other.real - self.imag @ other.imag, self.real @ other.imag + self.imag @ other.real) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real @ other, self.imag @ other) else: raise NotImplementedError @@ -3145,8 +3235,6 @@ def __imatmul__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real @ self.real - other.imag @ self.imag, other.imag @ self.real + other.real @ self.imag) - elif isinstance(other, jt.Var): - return ComplexNumber(other @ self.real, other @ self.imag) else: raise NotImplementedError diff --git a/python/jittor/pool.py b/python/jittor/pool.py index f9ad9a68..7e9e808a 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -29,9 +29,20 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + for item in self.kernel_size: + if item <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {item}") + for item in self.stride: + if item <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {item}") + for item in self.padding: + if item < 0: + raise RuntimeError(f"padding must be non-negative, but got {item}") def execute(self, x): N,C,H,W = x.shape + if H <= 
self.kernel_size[0] or W <= self.kernel_size[1]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 @@ -205,9 +216,17 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = _triple(padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: + raise RuntimeError(f"padding must be non-negative, but got {padding}") def execute(self, x): N,C,D,H,W = x.shape + if D <= self.kernel_size[0] or H <= self.kernel_size[1] or W <= self.kernel_size[2]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: d = (D+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 h = (H+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 @@ -506,7 +525,7 @@ def execute(self, x): f"i3*{self.sh}+i6", # Hid f"i4*{self.sw}+i7", # Wid ]) - return xx.reduce("maximun", [5,6,7]) + return xx.reduce("maximum", [5,6,7]) def pool(x, kernel_size, op, padding=0, stride=None): return Pool(kernel_size, stride, padding, op=op)(x) diff --git a/python/jittor/src/misc/cuda_atomic.h b/python/jittor/src/misc/cuda_atomic.h index 6348befb..15b39412 100644 --- a/python/jittor/src/misc/cuda_atomic.h +++ b/python/jittor/src/misc/cuda_atomic.h @@ -6,7 +6,9 @@ // *************************************************************** #pragma once #include +#ifndef IS_ROCM #include +#endif #include "common.h" namespace jittor { diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc index bbec4ccb..00fb34aa 100644 --- a/python/jittor/src/misc/nan_checker.cc +++ b/python/jittor/src/misc/nan_checker.cc @@ -11,7 +11,9 @@ #include "misc/cuda_flags.h" #include #include +#ifndef IS_ROCM #include +#endif #include "helper_cuda.h" #endif #include "mem/allocator.h" @@ -22,7 +24,9 @@ namespace jittor { #ifdef IS_CUDA EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num); +#ifndef IS_ROCM EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num); +#endif EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num); EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num); #endif @@ -33,7 +37,9 @@ void dump_var(Var* v, string name) { name = ss.str(); LOGe << "dump" << v << "to" << name; char* buffer = new char[v->size]; - #ifdef HAS_CUDA + #ifdef IS_ROCM + hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault); + #elif IS_CUDA cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault); #else std::memcpy(buffer, v->mem_ptr, v->size); @@ -57,9 +63,11 @@ bool check_nan(Var* v, Op* op) { if (v->dtype() == ns_float16) { nan_index = check_nan_float16((__half*)v->mem_ptr, v->num); } + #ifndef IS_ROCM if (v->dtype() == ns_bfloat16) { nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num); } + #endif if (v->dtype() == ns_float32) { nan_index = check_nan_float32((float32*)v->mem_ptr, v->num); } else @@ -104,14 +112,16 @@ bool check_nan(Var* v, Op* op) { auto* ptr = input->ptr<__half>(); __half value; cudaMemcpy(&value, 
ptr+index, sizeof(__half), cudaMemcpyDeviceToHost); - LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; + // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #ifndef IS_ROCM if (input->dtype() == ns_bfloat16) { auto* ptr = input->ptr<__nv_bfloat16>(); __nv_bfloat16 value; cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #endif if (input->dtype() == ns_float32) { auto* ptr = input->ptr(); float32 value; diff --git a/python/jittor/src/misc/nan_checker.cu b/python/jittor/src/misc/nan_checker.cu index ec7998ba..b2e631f1 100644 --- a/python/jittor/src/misc/nan_checker.cu +++ b/python/jittor/src/misc/nan_checker.cu @@ -7,9 +7,13 @@ #include "misc/cuda_flags.h" #include #include -#include + #include "helper_cuda.h" #include +//TODO:FIX in ROCM +#ifndef IS_ROCM +#include +#endif namespace jittor { @@ -37,7 +41,7 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt print_nan(float(ptr[i]), i, cnt); } } - +#ifndef IS_ROCM __global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) { int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; if (i check_nan_float16(__half* ptr, int64 num) { _check_nan_float16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#ifndef IS_ROCM vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) { int block_num = std::max((int64)1, (num-1)/1024+1); int thread_num = std::min((int64)1024, num); _check_nan_bfloat16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#endif #endif } \ No newline at end of file diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc index 01a8ef2d..848d40a4 100644 --- a/python/jittor/src/ops/binary_op.cc +++ b/python/jittor/src/ops/binary_op.cc @@ -440,6 +440,17 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) { return; } + #ifdef IS_ACL + if (x->dtype() != y->dtype()) { + auto dtype = binary_dtype_infer(ns_add, x->ns, y->ns, 0, 0); + auto xp = make_unary(x, dtype); + auto yp = make_unary(y, dtype); + auto zp = make_binary(xp, yp, op); + forward(zp); + return; + } + #endif + flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::element); diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h index 56f70789..f5c85d3b 100644 --- a/python/jittor/src/type/fp16_compute.h +++ b/python/jittor/src/type/fp16_compute.h @@ -11,12 +11,17 @@ #include #include +#ifndef IS_ROCM #include +#endif namespace jittor { typedef __half float16; +#ifndef IS_ROCM typedef __nv_bfloat16 bfloat16; +#endif + #if CUDA_ARCH >= 800 inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); } @@ -32,7 +37,7 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)= 800 inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); } inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); } @@ -45,7 +50,7 @@ inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a) __device__ inline typename std::enable_if::type diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc index 092eddd2..ed4d4c0a 100644 --- a/python/jittor/src/var_holder.cc +++ b/python/jittor/src/var_holder.cc @@ -215,11 +215,14 @@ inline static void 
cast_item_data(ItemData& data) { auto* fp16 = (float16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(fp16[0]); - } else if (data.dtype == ns_bfloat16) { + } + #ifndef IS_ROCM + else if (data.dtype == ns_bfloat16) { auto* bf16 = (bfloat16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(bf16[0]); } + #endif data.dtype = ns_float32; } diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py new file mode 100644 index 00000000..19686009 --- /dev/null +++ b/python/jittor/test/test_complex.py @@ -0,0 +1,200 @@ +import jittor as jt +from jittor.nn import ComplexNumber +import unittest +import numpy as np + +__skip_torch_test = False +try: + import torch +except: + __skip_torch_test = True + +class TestResultAndGrad: + def check_results(self, rlist1, rlist2): + assert len(rlist1) == len(rlist2) + for r1, r2 in zip(rlist1, rlist2): + assert r1.shape == r2.shape + assert np.allclose(r1, r2, rtol=1e-3, atol=1e-3) + + def grad_jittor(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + if isinstance(i, ComplexNumber): + g = jt.grad(loss, i.value, retain_graph=True) + grads.append(g[..., 0].numpy() + 1j * g[..., 1].numpy()) + else: + g = jt.grad(loss, i, retain_graph=True) + grads.append(g.numpy()) + return grads + + def grad_torch(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + g = torch.autograd.grad(loss, i, retain_graph=True)[0] + grads.append(g.detach().cpu().numpy()) + return grads + + def run_jittor_op(self, op, input_list, weights=None): + def _np_to_jittor(x:np.ndarray): + if x.dtype == np.complex64 or x.dtype == np.complex128: + nx = np.stack([np.real(x), np.imag(x)], axis=-1) + return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True) + elif x.dtype == np.float32 or x.dtype == np.float64: + return jt.array(x, dtype=jt.float32) + else: + assert False + def _jittor_to_np(x): + if isinstance(x, jt.Var): + return x.numpy() + elif isinstance(x, ComplexNumber): + return x.real.numpy() + 1j * x.imag.numpy() + assert False + + ninput_list = [_np_to_jittor(x) for x in input_list] + output_list = op(*ninput_list) + if isinstance(output_list, (jt.Var, ComplexNumber)): + output_list = [output_list] + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = o.value if isinstance(o, ComplexNumber) else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(jt.sum(no * jt.array(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = o.value if isinstance(o, ComplexNumber) else o + assert w.shape == no.shape + losses.append(jt.sum(no * jt.array(w))) + output_list = [_jittor_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def run_torch_op(self, op, input_list, weights=None): + def _np_to_torch(x:np.ndarray): + return torch.from_numpy(x).requires_grad_(True) + def _torch_to_np(x:torch.Tensor) -> np.ndarray: + return x.detach().cpu().numpy() + ninput_list = [_np_to_torch(x) for x in input_list] + output_list = op(*ninput_list) + if isinstance(output_list, torch.Tensor): + output_list = [output_list] + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(torch.sum(no * torch.from_numpy(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = 
torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + assert w.shape == no.shape + losses.append(torch.sum(no * torch.from_numpy(w))) + output_list = [_torch_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True): + weights = None + jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights) + torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights) + self.check_results(jittor_output, torch_output) + + if check_grad: + jittor_grads = self.grad_jittor(jittor_input, jittor_losses) + torch_grads = self.grad_torch(torch_input, torch_losses) + self.check_results(jittor_grads, torch_grads) + + def check_op_with_numpy(self, jittor_op, numpy_op, input_list): + _, jittor_output, _, _ = self.run_jittor_op(jittor_op, input_list, None) + numpy_output = numpy_op(*input_list) + if isinstance(numpy_output, np.ndarray): + numpy_output = [numpy_output] + + self.check_results(jittor_output, numpy_output) + +@unittest.skipIf(__skip_torch_test, "No Torch found") +class TestComplexLinalg(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def test_complex_matmul(self): + s1 = (50, 200) + s2 = (200, 50) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_matmul_batch(self): + s1 = (10, 50, 30) + s2 = (10, 30, 40) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_inv(self): + s1 = (200, 200) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_inv_batch(self): + s1 = (10, 50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_eig(self): + # Unstable + s1 = (20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_eig_batch(self): + # Unstable + s1 = (5, 10, 10) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_qr(self): + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_qr_batch(self): + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_svd(self): + # Unstable + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + + def test_complex_svd_batch(self): + # Unstable + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py index 741165d2..beec50f6 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']}, # include_package_data=True, 
install_requires=[ - "numpy", + "numpy<2.0", "tqdm", "pillow", "astunparse",
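
The complex linear-algebra routines added in python/jittor/linalg.py above (complex_inv, complex_eig, complex_qr, complex_svd, dispatched from jt.linalg.inv/eig/qr/svd when the input is a ComplexNumber) are exercised by the new test_complex.py. A minimal usage sketch under the same assumptions as that test, i.e. a stacked (real, imag) float32 layout passed with is_concat_value=True; the to_np helper, the 6x6 shape, and the tolerances are illustrative only:

```python
import numpy as np
import jittor as jt
from jittor.nn import ComplexNumber

# Build a well-conditioned square complex matrix and store it as a stacked
# (real, imag) float32 array, the layout used with is_concat_value=True.
a = np.random.randn(6, 6) + 1j * np.random.randn(6, 6) + 4 * np.eye(6)
x = ComplexNumber(jt.array(np.stack([a.real, a.imag], axis=-1).astype(np.float32)),
                  is_concat_value=True)

def to_np(c):
    # Hypothetical helper: convert a ComplexNumber back to a NumPy complex array.
    return c.real.numpy() + 1j * c.imag.numpy()

q, r = jt.linalg.qr(x)    # dispatches to complex_qr for ComplexNumber inputs
x_inv = jt.linalg.inv(x)  # dispatches to complex_inv

print(np.allclose(to_np(q @ r), a, atol=1e-4))              # Q @ R reconstructs the input
print(np.allclose(to_np(x_inv) @ a, np.eye(6), atol=1e-3))  # A^-1 @ A is close to I
```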
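The jt.Var.expand change in misc.py resolves -1 placeholders against the input's own dimensions aligned from the right, instead of indexing x.shape from the left, so new leading dimensions can be added while -1 keeps an existing trailing size. A small sketch of the intended behaviour; the shapes are illustrative:

```python
import jittor as jt

x = jt.ones((3, 1))

# A new leading dimension is added; -1 is resolved against x's own trailing
# dimensions, so it keeps the size 3 instead of mis-reading x.shape[1].
y = x.expand(2, -1, 4)
print(y.shape)  # expected: [2, 3, 4]
```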
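Many of the nn.py and pool.py hunks above add early argument validation (non-positive channels, kernel_size, stride, dilation, groups; negative padding), so invalid configurations fail with a clear Python exception instead of a backend error. A brief hedged illustration, assuming the usual nn.Conv2d alias for nn.Conv; the exception types and messages come from the checks added above:

```python
import jittor as jt
from jittor import nn

try:
    nn.Conv2d(3, 16, kernel_size=0)   # rejected by the new kernel_size check in Conv.__init__
except ValueError as e:
    print("Conv:", e)

try:
    nn.Pool(kernel_size=0)            # rejected by the new checks in pool.Pool.__init__
except RuntimeError as e:
    print("Pool:", e)
```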