From edf6a94661121def2554a150099bdc1c107a276f Mon Sep 17 00:00:00 2001
From: co63oc
Date: Mon, 22 May 2023 13:18:08 +0800
Subject: [PATCH 01/73] Update mnist.py

---
 python/jittor/dataset/mnist.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py
index 7aa1d883..f0945f94 100644
--- a/python/jittor/dataset/mnist.py
+++ b/python/jittor/dataset/mnist.py
@@ -26,7 +26,7 @@ class MNIST(Dataset):
         [in] data_root(str): your data root.
         [in] train(bool): choose model train or val.
-        [in] download(bool): Download data automatically if download is Ture.
+        [in] download(bool): Download data automatically if download is True.
         [in] batch_size(int): Data batch size.
         [in] shuffle(bool): Shuffle data if true.
         [in] transform(jittor.transform): transform data.
@@ -106,7 +106,7 @@ class EMNIST(Dataset):
         [in] data_root(str): your data root.
         [in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
         [in] train(bool): choose model train or val.
-        [in] download(bool): Download data automatically if download is Ture.
+        [in] download(bool): Download data automatically if download is True.
         [in] batch_size(int): Data batch size.
         [in] shuffle(bool): Shuffle data if true.
         [in] transform(jittor.transform): transform data.

From 04a02e0be1536cdefff2b4e3c4210611ebfc784d Mon Sep 17 00:00:00 2001
From: DongYang Li <62846124+LDYang694@users.noreply.github.com>
Date: Thu, 16 May 2024 14:35:21 +0800
Subject: [PATCH 02/73] polish PixelShuffle in nn.py

---
 python/jittor/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 75b8f42b..4a30eefd 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -1767,7 +1767,7 @@ def execute(self, x):
         n,c,h,w = x.shape
         r = self.upscale_factor
         assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
-        if r<0:
+        if r<=0:
             raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}")
         return x.reindex([n,int(c/r**2),h*r,w*r], [
             "i0",

From e001b4c053bd6830e12cb687c01e92ba496620ad Mon Sep 17 00:00:00 2001
From: lidongyang
Date: Mon, 20 May 2024 21:34:26 +0800
Subject: [PATCH 03/73] polish rocm support

---
 python/jittor/compatibility/__init__.py       | 798 +++++++++---------
 python/jittor/compatibility/utils/data.py     |  66 +-
 .../extern/cuda/cublas/inc/cublas_wrapper.h   |   2 +
 .../extern/cuda/cudnn/inc/cudnn_wrapper.h     |   5 +-
 python/jittor/src/misc/cuda_atomic.h          |   2 +
 python/jittor/src/misc/nan_checker.cc         |  14 +-
 python/jittor/src/misc/nan_checker.cu         |  13 +-
 python/jittor/src/type/fp16_compute.h         |   9 +-
 python/jittor/src/var_holder.cc               |   5 +-
 9 files changed, 472 insertions(+), 442 deletions(-)

diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py
index 6f88fad5..94d2e40b 100644
--- a/python/jittor/compatibility/__init__.py
+++ b/python/jittor/compatibility/__init__.py
@@ -1,430 +1,430 @@
-import os
-os.environ["FIX_TORCH_ERROR"] = "0"
-
-import jittor as jt
-from jittor import *
-from typing import Tuple
-
-org_int = int = type(1)
-org_float = float = type(1.0)
-org_bool = bool = type(True)
-
-import jtorch.compiler
-
-import jtorch_core
-from jtorch_core import *
-
-device.__reduce__ = lambda self: (device, (self.type,))
-device.__module__ = "jtorch"
-jt.jittor_core.device = device
-
-def handle_dtype(args, kw, dtype):
-    def convert(x):
-        if isinstance(x, jt.Var):
-            return x.cast(dtype)
-        return x
-    if dtype is not None:
-        if args 
is not None: - if isinstance(args, (tuple,list)): - args = [ convert(a) for a in args ] - else: - args = convert(x) - if kw is not None: - kw = { k:convert(v) for k,v in kw.items() } - return args, kw - -def get_args_names(func): - import inspect - spec = inspect.getfullargspec(func) - return spec[0] + spec[4] - -def wrapper(func): - has_dtype = False - if hasattr(func, "__code__"): - has_dtype = "dtype" in get_args_names(func) - def inner(*args, **kw): - requires_grad = None - dtype = None - if "requires_grad" in kw: - requires_grad = kw["requires_grad"] - del kw["requires_grad"] - if not has_dtype and "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - if "device" in kw: - del kw["device"] - if 'pin_memory' in kw: - del kw['pin_memory'] - args, kw = handle_dtype(args, kw, dtype) - ret = func(*args, **kw) - if isinstance(ret, jt.Var): - if requires_grad is not None: - ret.requires_grad = requires_grad - if dtype is not None: - ret.astype(dtype) - return ret - return inner +# import os +# os.environ["FIX_TORCH_ERROR"] = "0" + +# import jittor as jt +# from jittor import * +# from typing import Tuple + +# org_int = int = type(1) +# org_float = float = type(1.0) +# org_bool = bool = type(True) + +# import jtorch.compiler + +# import jtorch_core +# from jtorch_core import * + +# device.__reduce__ = lambda self: (device, (self.type,)) +# device.__module__ = "jtorch" +# jt.jittor_core.device = device + +# def handle_dtype(args, kw, dtype): +# def convert(x): +# if isinstance(x, jt.Var): +# return x.cast(dtype) +# return x +# if dtype is not None: +# if args is not None: +# if isinstance(args, (tuple,list)): +# args = [ convert(a) for a in args ] +# else: +# args = convert(x) +# if kw is not None: +# kw = { k:convert(v) for k,v in kw.items() } +# return args, kw + +# def get_args_names(func): +# import inspect +# spec = inspect.getfullargspec(func) +# return spec[0] + spec[4] + +# def wrapper(func): +# has_dtype = False +# if hasattr(func, "__code__"): +# has_dtype = "dtype" in get_args_names(func) +# def inner(*args, **kw): +# requires_grad = None +# dtype = None +# if "requires_grad" in kw: +# requires_grad = kw["requires_grad"] +# del kw["requires_grad"] +# if not has_dtype and "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# if "device" in kw: +# del kw["device"] +# if 'pin_memory' in kw: +# del kw['pin_memory'] +# args, kw = handle_dtype(args, kw, dtype) +# ret = func(*args, **kw) +# if isinstance(ret, jt.Var): +# if requires_grad is not None: +# ret.requires_grad = requires_grad +# if dtype is not None: +# ret.astype(dtype) +# return ret +# return inner -import inspect -_wrapper_keys = set(["shape", "start", "size"]) -_wrapper_keys.add("x") -for k,v in list(globals().items()): - if callable(v) and not isinstance(v, type): - try: - spec = inspect.getfullargspec(v) - args_name = spec[0] - if len(args_name) and args_name[0] in _wrapper_keys: - globals()[k] = wrapper(v) - elif spec.varargs in _wrapper_keys: - globals()[k] = wrapper(v) - except: - pass - -def empty(*size, dtype=jt.float32, device=None, requires_grad=False): - if len(size) == 1 and not isinstance(size[0], org_int): - size = size[0] - return jt.empty(size, dtype) - -Tensor = Var - -Tensor.backward = lambda x: jtorch_core.backward(x) -Tensor.grad = property(grad_get, grad_set, grad_del) -Tensor.retains_grad = property(retain_grad_get, retain_grad_set) -def retain_grad(x:Tensor, value:bool=True): - x.retains_grad = value - return value -Tensor.retain_grad = retain_grad - -Tensor.dim = lambda self: self.ndim 
-Tensor.ndimension = lambda self: self.ndim -Tensor.nelement = lambda self: self.numel() -Tensor.cuda = lambda self: self -def device_get(x:Tensor): - return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") -Tensor.device = property(device_get) - -def argmax(x: Var, dim=None, keepdim: bool = False): - return jt.argmax(x, dim, keepdim)[0] -Tensor.argmax = argmax - -def tensor_type(x: Var, dtype=None, **kwargs): - if dtype: - return x.astype(dtype) - else: - return x.dtype -Tensor.type = tensor_type - -def is_floating_point(x: Var): - return "float" in str(x.dtype) -Tensor.is_floating_point = is_floating_point - -from . import autograd -from .autograd import * - -def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): - if isinstance(data,list): - data_list = [] - check = True - for p in data: - if isinstance(p, Tensor) and p.numel()==1: - data_list.append(p.item()) - elif isinstance(p, (org_int,org_float)): - data_list.append(p) - else: - check = False - break - if check: - data = data_list - return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) - -# tensor = wrapper(array) -from_numpy = wrapper(array) -strided = None - -def mod_zero_grad(self): - for p in self.parameters(): - p.grad = None -Module.zero_grad = mod_zero_grad - -class ModuleMisc: - def parameters(self): - return iter(super().parameters()) - - def load_state_dict(self, state_dict, strict=False): - return super().load_state_dict(state_dict) - - def to(self, device=None,dtype=None): - ''' do nothing but return its self''' - return self - def register_parameter(self,name,data): - self.name = data - - def buffers(self): - for _, buf in self.named_buffers(): - yield buf +# import inspect +# _wrapper_keys = set(["shape", "start", "size"]) +# _wrapper_keys.add("x") +# for k,v in list(globals().items()): +# if callable(v) and not isinstance(v, type): +# try: +# spec = inspect.getfullargspec(v) +# args_name = spec[0] +# if len(args_name) and args_name[0] in _wrapper_keys: +# globals()[k] = wrapper(v) +# elif spec.varargs in _wrapper_keys: +# globals()[k] = wrapper(v) +# except: +# pass + +# def empty(*size, dtype=jt.float32, device=None, requires_grad=False): +# if len(size) == 1 and not isinstance(size[0], org_int): +# size = size[0] +# return jt.empty(size, dtype) + +# Tensor = Var + +# Tensor.backward = lambda x: jtorch_core.backward(x) +# Tensor.grad = property(grad_get, grad_set, grad_del) +# Tensor.retains_grad = property(retain_grad_get, retain_grad_set) +# def retain_grad(x:Tensor, value:bool=True): +# x.retains_grad = value +# return value +# Tensor.retain_grad = retain_grad + +# Tensor.dim = lambda self: self.ndim +# Tensor.ndimension = lambda self: self.ndim +# Tensor.nelement = lambda self: self.numel() +# Tensor.cuda = lambda self: self +# def device_get(x:Tensor): +# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") +# Tensor.device = property(device_get) + +# def argmax(x: Var, dim=None, keepdim: bool = False): +# return jt.argmax(x, dim, keepdim)[0] +# Tensor.argmax = argmax + +# def tensor_type(x: Var, dtype=None, **kwargs): +# if dtype: +# return x.astype(dtype) +# else: +# return x.dtype +# Tensor.type = tensor_type + +# def is_floating_point(x: Var): +# return "float" in str(x.dtype) +# Tensor.is_floating_point = is_floating_point + +# from . 
import autograd +# from .autograd import * + +# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): +# if isinstance(data,list): +# data_list = [] +# check = True +# for p in data: +# if isinstance(p, Tensor) and p.numel()==1: +# data_list.append(p.item()) +# elif isinstance(p, (org_int,org_float)): +# data_list.append(p) +# else: +# check = False +# break +# if check: +# data = data_list +# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) + +# # tensor = wrapper(array) +# from_numpy = wrapper(array) +# strided = None + +# def mod_zero_grad(self): +# for p in self.parameters(): +# p.grad = None +# Module.zero_grad = mod_zero_grad + +# class ModuleMisc: +# def parameters(self): +# return iter(super().parameters()) + +# def load_state_dict(self, state_dict, strict=False): +# return super().load_state_dict(state_dict) + +# def to(self, device=None,dtype=None): +# ''' do nothing but return its self''' +# return self +# def register_parameter(self,name,data): +# self.name = data + +# def buffers(self): +# for _, buf in self.named_buffers(): +# yield buf -def make_module(cls): - class TMod(ModuleMisc, cls): - def __init__(self, *args, **kw): - dtype = None - if "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - self._dtype = dtype - with jt.flag_scope(th_mode=0): - if "device" in kw: - del kw["device"] - super().__init__(*args, **kw) - for k,v in self.__dict__.items(): - if not k.startswith("_") and isinstance(v, Var) \ - and v.requires_grad: - v.retain_grad() - if dtype is not None and isinstance(v, Var): - v.assign(v.cast(dtype)) - def __call__(self, *args, **kw): - args, kw = handle_dtype(args, kw, self._dtype) - # if forward is override by user, call forward - if self.__class__.forward is not TMod.forward: - return self.forward(*args, **kw) - return self.execute(*args, **kw) - def forward(self, *args, **kw): - args, kw = handle_dtype(args, kw, self._dtype) - return self.execute(*args, **kw) +# def make_module(cls): +# class TMod(ModuleMisc, cls): +# def __init__(self, *args, **kw): +# dtype = None +# if "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# self._dtype = dtype +# with jt.flag_scope(th_mode=0): +# if "device" in kw: +# del kw["device"] +# super().__init__(*args, **kw) +# for k,v in self.__dict__.items(): +# if not k.startswith("_") and isinstance(v, Var) \ +# and v.requires_grad: +# v.retain_grad() +# if dtype is not None and isinstance(v, Var): +# v.assign(v.cast(dtype)) +# def __call__(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# # if forward is override by user, call forward +# if self.__class__.forward is not TMod.forward: +# return self.forward(*args, **kw) +# return self.execute(*args, **kw) +# def forward(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# return self.execute(*args, **kw) - @property - def training(self): - if not hasattr(self, "is_train"): - self.is_train = True - return self.is_train - @training.setter - def training(self, value): - self.is_train = value - - TMod.__name__ = cls.__name__ - return TMod - -import jtorch.cuda -import jtorch.nn -from jtorch.nn import Module, Parameter -import jtorch.optim - -from jtorch.utils.dtype import Dtype, get_string_dtype - -def frombuffer(buffer: bytearray, - *, - dtype: Dtype, - count: int = -1, - offset: int = 0, - requires_grad: bool = True) -> Tensor: - dtype = get_string_dtype(dtype) - tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) 
- if requires_grad and tensor.dtype.is_float(): - tensor.requires_grad = True - return tensor - -def conflict_wrapper(origin_func, new_func): - def wrapper(*args, **kw): - if jt.flags.th_mode: - return new_func(*args, **kw) - else: - return origin_func(*args, **kw) - return wrapper - -def min(*args, **kw): - dim = None - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmin(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.minimum(args[0], args[1]) - else: - return jt.min(*args, **kw) -Tensor.min = conflict_wrapper(jt.min, min) - -def max(*args, **kw): - dim = None - if "dim" in kw: - x = kw["dim"] - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmax(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.maximum(args[0], args[1]) - else: - return jt.max(*args, **kw) -Tensor.max = conflict_wrapper(jt.max, max) - -def argsort(*args, **kw): - k, v = jt.argsort(*args, **kw) - return k -Tensor.argsort = conflict_wrapper(jt.argsort, argsort) - -LongTensor = jt.int64 -FloatTensor = jt.float -HalfTensor = jt.float16 -BoolTensor = jt.bool -IntTensor = jt.int32 - -class JDType: - def __init__(self, func, str): - self.func = func - self.str = str - self.__name__ = str.split(".")[-1] - def __call__(self, *args, **kw): - return self.func(*args, **kw) - def __str__(self): - return self.str - def is_floating_point(self): - return "float" in str(self.str) - -int8 = JDType(jt.int8, "torch.int8") -int16 = JDType(jt.int16, "torch.int16") -int = int32 = JDType(jt.int32, "torch.int32") -long = int64 = JDType(jt.int64, "torch.int64") - -half = float16 = JDType(jt.float16, "torch.float16") -float = float32 = JDType(jt.float32, "torch.float32") -double = float64 = JDType(jt.float64, "torch.float64") -bfloat16 = "bfloat16" # TODO -complex64 = "complex64" # TODO -complex128 = "complex128" # TODO -def get_JDtype(dtype): - if dtype=='float32' or dtype == jt.float32: - return float32 - elif dtype=='float64' or dtype == jt.float64: - return float64 - elif dtype=='float16' or dtype == jt.float16: - return float16 - elif dtype=='int32' or dtype == jt.int32: - return int32 - elif dtype=='int64' or dtype == jt.int64: - return int64 - elif dtype=='int16' or dtype == jt.int16: - return int16 - elif dtype=='int8' or dtype == jt.int8: - return int8 - else: - raise Exception("dtype {} not supported".format(dtype)) - -def load(path,**kwargs): - def _to_jittor(data): - if isinstance(data,dict): - return {k:_to_jittor(d) for k,d in data.items()} - if isinstance(data,list): - return [_to_jittor(d) for d in data] - if isinstance(data,np.ndarray): - return jt.array(data) - return data - data = jt.load(path) +# @property +# def training(self): +# if not hasattr(self, "is_train"): +# self.is_train = True +# return self.is_train +# @training.setter +# def training(self, value): +# self.is_train = value + +# TMod.__name__ = cls.__name__ +# return TMod + +# import jtorch.cuda +# import jtorch.nn +# from jtorch.nn import Module, Parameter +# import jtorch.optim + +# from jtorch.utils.dtype import Dtype, get_string_dtype + +# def frombuffer(buffer: bytearray, +# *, +# dtype: Dtype, +# count: int = -1, +# offset: int = 0, +# requires_grad: bool = True) -> Tensor: +# dtype = get_string_dtype(dtype) +# 
tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) +# if requires_grad and tensor.dtype.is_float(): +# tensor.requires_grad = True +# return tensor + +# def conflict_wrapper(origin_func, new_func): +# def wrapper(*args, **kw): +# if jt.flags.th_mode: +# return new_func(*args, **kw) +# else: +# return origin_func(*args, **kw) +# return wrapper + +# def min(*args, **kw): +# dim = None +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmin(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.minimum(args[0], args[1]) +# else: +# return jt.min(*args, **kw) +# Tensor.min = conflict_wrapper(jt.min, min) + +# def max(*args, **kw): +# dim = None +# if "dim" in kw: +# x = kw["dim"] +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmax(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.maximum(args[0], args[1]) +# else: +# return jt.max(*args, **kw) +# Tensor.max = conflict_wrapper(jt.max, max) + +# def argsort(*args, **kw): +# k, v = jt.argsort(*args, **kw) +# return k +# Tensor.argsort = conflict_wrapper(jt.argsort, argsort) + +# LongTensor = jt.int64 +# FloatTensor = jt.float +# HalfTensor = jt.float16 +# BoolTensor = jt.bool +# IntTensor = jt.int32 + +# class JDType: +# def __init__(self, func, str): +# self.func = func +# self.str = str +# self.__name__ = str.split(".")[-1] +# def __call__(self, *args, **kw): +# return self.func(*args, **kw) +# def __str__(self): +# return self.str +# def is_floating_point(self): +# return "float" in str(self.str) + +# int8 = JDType(jt.int8, "torch.int8") +# int16 = JDType(jt.int16, "torch.int16") +# int = int32 = JDType(jt.int32, "torch.int32") +# long = int64 = JDType(jt.int64, "torch.int64") + +# half = float16 = JDType(jt.float16, "torch.float16") +# float = float32 = JDType(jt.float32, "torch.float32") +# double = float64 = JDType(jt.float64, "torch.float64") +# bfloat16 = "bfloat16" # TODO +# complex64 = "complex64" # TODO +# complex128 = "complex128" # TODO +# def get_JDtype(dtype): +# if dtype=='float32' or dtype == jt.float32: +# return float32 +# elif dtype=='float64' or dtype == jt.float64: +# return float64 +# elif dtype=='float16' or dtype == jt.float16: +# return float16 +# elif dtype=='int32' or dtype == jt.int32: +# return int32 +# elif dtype=='int64' or dtype == jt.int64: +# return int64 +# elif dtype=='int16' or dtype == jt.int16: +# return int16 +# elif dtype=='int8' or dtype == jt.int8: +# return int8 +# else: +# raise Exception("dtype {} not supported".format(dtype)) + +# def load(path,**kwargs): +# def _to_jittor(data): +# if isinstance(data,dict): +# return {k:_to_jittor(d) for k,d in data.items()} +# if isinstance(data,list): +# return [_to_jittor(d) for d in data] +# if isinstance(data,np.ndarray): +# return jt.array(data) +# return data +# data = jt.load(path) - return _to_jittor(data) +# return _to_jittor(data) -def is_tensor(x): - return isinstance(x, Tensor) +# def is_tensor(x): +# return isinstance(x, Tensor) -manual_seed = jt.set_global_seed -jt.flags.amp_level = 3 -Size = jt.NanoVector +# manual_seed = jt.set_global_seed +# jt.flags.amp_level = 3 +# Size = jt.NanoVector -class Generator: - def __init__(self,*args,**kw) -> None: - self.seed = None - 
def manual_seed(self,seed): - self.seed = seed +# class Generator: +# def __init__(self,*args,**kw) -> None: +# self.seed = None +# def manual_seed(self,seed): +# self.seed = seed -from . import fx +# from . import fx -_default_type = "float32" +# _default_type = "float32" -def get_default_dtype(): - return _default_type -def set_default_dtype(dtype): - global _default_type - _default_type = dtype +# def get_default_dtype(): +# return _default_type +# def set_default_dtype(dtype): +# global _default_type +# _default_type = dtype -dtype = JDType +# dtype = JDType -def div(x,y,rounding_mode="floor"): - assert rounding_mode == "floor" - z = (x / y) - if rounding_mode == "floor": - z = z.floor() - if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): - z = z.int32() - return z +# def div(x,y,rounding_mode="floor"): +# assert rounding_mode == "floor" +# z = (x / y) +# if rounding_mode == "floor": +# z = z.floor() +# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): +# z = z.int32() +# return z -def randn(*args,**kw): - wrap_randn = wrapper(jt.randn) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_randn(*args,**kw) +# def randn(*args,**kw): +# wrap_randn = wrapper(jt.randn) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_randn(*args,**kw) -def rand(*args,**kw): - print("rand") - wrap_rand = wrapper(jt.rand) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_rand(*args,**kw) +# def rand(*args,**kw): +# print("rand") +# wrap_rand = wrapper(jt.rand) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_rand(*args,**kw) -def set_default_tensor_type(t: type or str): - if isinstance(t, str): - info = t.split(".") - if len(info) == 3 and info[1] == 'cuda': - jt.flags.use_cuda = 1 - #TODO: type +# def set_default_tensor_type(t: type or str): +# if isinstance(t, str): +# info = t.split(".") +# if len(info) == 3 and info[1] == 'cuda': +# jt.flags.use_cuda = 1 +# #TODO: type -def clamp(x, min=None, max=None): - return jt.clamp(x, min, max) +# def clamp(x, min=None, max=None): +# return jt.clamp(x, min, max) -def to(x,*args,**kw): - device = None - if len(args) == 1: - device = args[0] - if isinstance(device, jt.NanoString) or callable(device): - return jt.to(x,*args,**kw) - if 'cpu' in str(device): - args = [] - device = kw.get("device",None) - if 'cpu' in str(device): - kw.pop('device',None) - print("to cpu") - # print(kw) - return jt.to(x,*args,**kw) -Tensor.to = conflict_wrapper(jt.to, to) +# def to(x,*args,**kw): +# device = None +# if len(args) == 1: +# device = args[0] +# if isinstance(device, jt.NanoString) or callable(device): +# return jt.to(x,*args,**kw) +# if 'cpu' in str(device): +# args = [] +# device = kw.get("device",None) +# if 'cpu' in str(device): +# kw.pop('device',None) +# print("to cpu") +# # print(kw) +# return jt.to(x,*args,**kw) +# Tensor.to = conflict_wrapper(jt.to, to) -mm = 
wrapper(jt.matmul) +# mm = wrapper(jt.matmul) -def _data_get(x): - return x +# def _data_get(x): +# return x -def _data_set(x, value): - x.assign(value) +# def _data_set(x, value): +# x.assign(value) -Tensor.data = property(_data_get, _data_set) -Tensor.layout = None \ No newline at end of file +# Tensor.data = property(_data_get, _data_set) +# Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py index 71946a23..5fcfcaa6 100644 --- a/python/jittor/compatibility/utils/data.py +++ b/python/jittor/compatibility/utils/data.py @@ -99,39 +99,39 @@ def inner_iter(self): current_batch = self.collate_batch(current_batch) yield self.to_jittor(current_batch) -def get_worker_info(): - # always return the fake worker info - return namedtuple('WorkerInfo', 'id num_workers')(0, 1) - -class RandomSampler(jt.dataset.RandomSampler): - def __init__(self, dataset, generator=None, **kwargs): - super().__init__(dataset, **kwargs) - - def __iter__(self): - if getattr(self.dataset, "support_random_access", True): - return super().__iter__() - else: - self.dataset.shuffle() - return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) - -class DistributedSampler(jt.dataset.Sampler): - def __init__(self, sampler: RandomSampler): - assert(isinstance(sampler, RandomSampler)) - self.sampler = sampler - - def set_epoch(self, epoch: int): - ### do nothing, let jittor's inner dataset handle - pass - - def __iter__(self): - return self.sampler.__iter__() +# def get_worker_info(): +# # always return the fake worker info +# return namedtuple('WorkerInfo', 'id num_workers')(0, 1) + +# class RandomSampler(jt.dataset.RandomSampler): +# def __init__(self, dataset, generator=None, **kwargs): +# super().__init__(dataset, **kwargs) + +# def __iter__(self): +# if getattr(self.dataset, "support_random_access", True): +# return super().__iter__() +# else: +# self.dataset.shuffle() +# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) + +# class DistributedSampler(jt.dataset.Sampler): +# def __init__(self, sampler: RandomSampler): +# assert(isinstance(sampler, RandomSampler)) +# self.sampler = sampler + +# def set_epoch(self, epoch: int): +# ### do nothing, let jittor's inner dataset handle +# pass + +# def __iter__(self): +# return self.sampler.__iter__() - def __len__(self): - return self.sampler.__len__() +# def __len__(self): +# return self.sampler.__len__() -BatchSampler = jt.dataset.BatchSampler -Sampler = jt.dataset.Sampler -SequentialSampler = jt.dataset.SequentialSampler -SubsetRandomSampler = jt.dataset.SubsetRandomSampler +# BatchSampler = jt.dataset.BatchSampler +# Sampler = jt.dataset.Sampler +# SequentialSampler = jt.dataset.SequentialSampler +# SubsetRandomSampler = jt.dataset.SubsetRandomSampler -TensorDataset = Dataset +# TensorDataset = Dataset diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h index be01348f..7667e495 100644 --- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h +++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h @@ -25,7 +25,9 @@ static inline cudaDataType get_dtype(NanoString dtype) { if (dtype == ns_float32) return CUDA_R_32F; if (dtype == ns_float64) return CUDA_R_64F; if (dtype == ns_float16) return CUDA_R_16F; + #ifndef IS_ROCM if (dtype == ns_bfloat16) return CUDA_R_16BF; + #endif LOGf << "not 
support type" << dtype; return CUDA_R_32F; } diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h index 2109a03a..73fb3233 100644 --- a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h @@ -8,8 +8,9 @@ #include #include #include +#ifndef IS_ROCM #include - +#endif #include "utils/log.h" #include "helper_cuda.h" #include "fp16_emu.h" @@ -31,6 +32,8 @@ void set_max_workspace_ratio(float64 ratio); template __inline__ cudnnDataType_t getDataType(); template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } +#ifndef IS_ROCM template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; } +#endif } // jittor diff --git a/python/jittor/src/misc/cuda_atomic.h b/python/jittor/src/misc/cuda_atomic.h index 6348befb..15b39412 100644 --- a/python/jittor/src/misc/cuda_atomic.h +++ b/python/jittor/src/misc/cuda_atomic.h @@ -6,7 +6,9 @@ // *************************************************************** #pragma once #include +#ifndef IS_ROCM #include +#endif #include "common.h" namespace jittor { diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc index bbec4ccb..d9a6f486 100644 --- a/python/jittor/src/misc/nan_checker.cc +++ b/python/jittor/src/misc/nan_checker.cc @@ -11,7 +11,9 @@ #include "misc/cuda_flags.h" #include #include +#ifndef IS_ROCM #include +#endif #include "helper_cuda.h" #endif #include "mem/allocator.h" @@ -22,7 +24,9 @@ namespace jittor { #ifdef IS_CUDA EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num); +#ifndef IS_ROCM EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num); +#endif EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num); EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num); #endif @@ -33,7 +37,9 @@ void dump_var(Var* v, string name) { name = ss.str(); LOGe << "dump" << v << "to" << name; char* buffer = new char[v->size]; - #ifdef HAS_CUDA + #ifdef IS_ROCM + hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault); + #elif HAS_CUDA cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault); #else std::memcpy(buffer, v->mem_ptr, v->size); @@ -57,9 +63,11 @@ bool check_nan(Var* v, Op* op) { if (v->dtype() == ns_float16) { nan_index = check_nan_float16((__half*)v->mem_ptr, v->num); } + #ifndef IS_ROCM if (v->dtype() == ns_bfloat16) { nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num); } + #endif if (v->dtype() == ns_float32) { nan_index = check_nan_float32((float32*)v->mem_ptr, v->num); } else @@ -104,14 +112,16 @@ bool check_nan(Var* v, Op* op) { auto* ptr = input->ptr<__half>(); __half value; cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost); - LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; + // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #ifndef IS_ROCM if (input->dtype() == ns_bfloat16) { auto* ptr = input->ptr<__nv_bfloat16>(); __nv_bfloat16 value; cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #endif if (input->dtype() == ns_float32) { auto* ptr = input->ptr(); float32 value; diff --git a/python/jittor/src/misc/nan_checker.cu 
b/python/jittor/src/misc/nan_checker.cu index ec7998ba..b2e631f1 100644 --- a/python/jittor/src/misc/nan_checker.cu +++ b/python/jittor/src/misc/nan_checker.cu @@ -7,9 +7,13 @@ #include "misc/cuda_flags.h" #include #include -#include + #include "helper_cuda.h" #include +//TODO:FIX in ROCM +#ifndef IS_ROCM +#include +#endif namespace jittor { @@ -37,7 +41,7 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt print_nan(float(ptr[i]), i, cnt); } } - +#ifndef IS_ROCM __global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) { int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; if (i check_nan_float16(__half* ptr, int64 num) { _check_nan_float16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#ifndef IS_ROCM vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) { int block_num = std::max((int64)1, (num-1)/1024+1); int thread_num = std::min((int64)1024, num); _check_nan_bfloat16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#endif #endif } \ No newline at end of file diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h index 56f70789..f5c85d3b 100644 --- a/python/jittor/src/type/fp16_compute.h +++ b/python/jittor/src/type/fp16_compute.h @@ -11,12 +11,17 @@ #include #include +#ifndef IS_ROCM #include +#endif namespace jittor { typedef __half float16; +#ifndef IS_ROCM typedef __nv_bfloat16 bfloat16; +#endif + #if CUDA_ARCH >= 800 inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); } @@ -32,7 +37,7 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)= 800 inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); } inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); } @@ -45,7 +50,7 @@ inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a) __device__ inline typename std::enable_if::type diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc index 092eddd2..ed4d4c0a 100644 --- a/python/jittor/src/var_holder.cc +++ b/python/jittor/src/var_holder.cc @@ -215,11 +215,14 @@ inline static void cast_item_data(ItemData& data) { auto* fp16 = (float16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(fp16[0]); - } else if (data.dtype == ns_bfloat16) { + } + #ifndef IS_ROCM + else if (data.dtype == ns_bfloat16) { auto* bf16 = (bfloat16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(bf16[0]); } + #endif data.dtype = ns_float32; } From 05f4cf3f11bb7b3c23082dab4dfcfbf7ffad4fbe Mon Sep 17 00:00:00 2001 From: DongYang Li <62846124+LDYang694@users.noreply.github.com> Date: Mon, 20 May 2024 21:43:48 +0800 Subject: [PATCH 04/73] Update version to 1.3.9.8 --- python/jittor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index fc1811e8..4f4fe5b3 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. 
# *************************************************************** -__version__ = '1.3.9.6' +__version__ = '1.3.9.8' from jittor_utils import lock with lock.lock_scope(): ori_int = int From 7714ce31d1e69084211f374845a075a0e9140a6f Mon Sep 17 00:00:00 2001 From: zhc7 Date: Wed, 22 May 2024 11:08:28 +0800 Subject: [PATCH 05/73] fix: a minimal quick fix for issue #544 --- python/jittor/misc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index fc528d62..d8e3d551 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -2010,6 +2010,7 @@ def contiguous(x): return x.clone() def cpu(x): return x.clone() jt.Var.cpu = cpu def to(x, *args, **kargs): + args += tuple(kargs.values()) if len(args) >= 1: s = args[0] if isinstance(s, jt.NanoString) or callable(s): From 5df1673608dc7c708e5739260d80d5a4d97c92de Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Wed, 29 May 2024 11:05:53 +0800 Subject: [PATCH 06/73] fix: jt.Var.expand with valid index -1 --- python/jittor/misc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index d8e3d551..5d05f13e 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -294,9 +294,10 @@ def expand(x, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)): shape = shape[0] shape = list(shape) - for i in range(len(shape)): - if shape[i] == -1: - shape[i] = x.shape[i] + offset = len(shape) - len(x.shape) + for i in range(len(x.shape)): + if shape[offset + i] == -1: + shape[offset + i] = x.shape[i] return x.broadcast(shape) jt.Var.expand = expand From 14de5fa8bdfe36f4d40c409a5d9b9ba921d0da49 Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Thu, 30 May 2024 14:00:27 +0800 Subject: [PATCH 07/73] a IndexError fix of issue #448 --- python/jittor/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index f5deae17..6dc49360 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -698,6 +698,8 @@ def flatten(input, start_dim=0, end_dim=-1): start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function" + if len(in_shape) <= end_dim: + raise IndexError("Dimension out of range (expected to be in range of [%d, %d], but got %d)" % (-len(in_shape),len(in_shape) - 1,end_dim)) out_shape = [] for i in range(0,start_dim,1): out_shape.append(in_shape[i]) dims = 1 From 862bce9d5bb3eb88a221060e4b886e39ec74269c Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Thu, 30 May 2024 14:29:34 +0800 Subject: [PATCH 08/73] a ValueError fix of issue #450 --- python/jittor/nn.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 30f96dcd..98cd48d3 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2175,15 +2175,21 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner class Upsample(Module): def __init__(self, scale_factor=None, mode='nearest'): - self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor) + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None self.mode = mode def execute(self, x): - return 
upsample(x, - size=( - int(x.shape[2]*self.scale_factor[0]), - int(x.shape[3]*self.scale_factor[1])), - mode=self.mode) + if self.scale_factor is None: + raise ValueError("scale_factor should be defined") + else: + return upsample(x, + size=( + int(x.shape[2]*self.scale_factor[0]), + int(x.shape[3]*self.scale_factor[1])), + mode=self.mode) class UpsamplingBilinear2d(Upsample): def __init__(self, scale_factor=None): From cd8b19ada5881df3db0dccfdc843bffa152b041c Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Thu, 30 May 2024 15:19:19 +0800 Subject: [PATCH 09/73] fix illegal parameters of Pool and Pool3d of issue #451,#453,#456,#457 --- python/jittor/__init__.py | 2 +- python/jittor/pool.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 6dc49360..007d702a 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -699,7 +699,7 @@ def flatten(input, start_dim=0, end_dim=-1): end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function" if len(in_shape) <= end_dim: - raise IndexError("Dimension out of range (expected to be in range of [%d, %d], but got %d)" % (-len(in_shape),len(in_shape) - 1,end_dim)) + raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})") out_shape = [] for i in range(0,start_dim,1): out_shape.append(in_shape[i]) dims = 1 diff --git a/python/jittor/pool.py b/python/jittor/pool.py index 79c693b0..aa89e897 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -21,6 +21,12 @@ class Pool(Module): def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): assert dilation == None assert return_indices == None or op == "maximum" + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + if self.padding[0] < 0 or self.padding[1] < 0: + raise RuntimeError(f"padding must be non-negative, but got {padding}") self.return_indices = return_indices self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) self.op = op @@ -29,12 +35,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 - if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0: - raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") - if self.stride[0] <= 0 or self.stride[1] <= 0: - raise RuntimeError(f"stride must be greater than zero, but got {stride}") - if self.padding[0] < 0 or self.padding[1] < 0: - raise RuntimeError(f"padding must be non-negative, but got {padding}") def execute(self, x): N,C,H,W = x.shape @@ -203,6 +203,12 @@ class Pool3d(Module): def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): assert dilation == None assert return_indices == None or op == "maximum" + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: 
+ raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: + raise RuntimeError(f"padding must be non-negative, but got {padding}") self.return_indices = return_indices self.kernel_size = _triple(kernel_size) self.op = op @@ -211,12 +217,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = _triple(padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 - if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: - raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") - if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: - raise RuntimeError(f"stride must be greater than zero, but got {stride}") - if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: - raise RuntimeError(f"padding must be non-negative, but got {padding}") def execute(self, x): N,C,D,H,W = x.shape @@ -518,7 +518,7 @@ def execute(self, x): f"i3*{self.sh}+i6", # Hid f"i4*{self.sw}+i7", # Wid ]) - return xx.reduce("maximun", [5,6,7]) + return xx.reduce("maximum", [5,6,7]) def pool(x, kernel_size, op, padding=0, stride=None): return Pool(kernel_size, stride, padding, op=op)(x) From 793d63894c84dfe4734a97a863c819d79e23b2cf Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Thu, 30 May 2024 16:10:06 +0800 Subject: [PATCH 10/73] fix illegal parameters of Conv2d issue #471,#472,#473,#474,#475,#476,#477 --- python/jittor/nn.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 98cd48d3..dd828594 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -926,6 +926,42 @@ class Conv(Module): >>> output = conv(input) ''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + if in_channels <= 0: + raise ValueError(f"in_channels must be greater than zero, got {in_channels}") + if out_channels <= 0: + raise ValueError(f"out_channels must be greater than zero, got {out_channels}") + if groups <= 0: + raise ValueError(f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + else: + if kernel_size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + else: + if stride <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") + else: + if dilation <= 0: + raise ValueError(f"dilation must be 
greater than zero, got {dilation}") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) @@ -936,8 +972,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: self.depthwise_conv = DepthwiseConv(stride, padding, dilation) - assert in_channels % groups == 0, 'in_channels must be divisible by groups' - assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw = self.kernel_size # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") From 9e60eb6d1c53a63edc4014db433cdcfb594c6062 Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Thu, 30 May 2024 19:03:17 +0800 Subject: [PATCH 11/73] fix illegal parameters of PixelShuffle of issue #458,fix validity of concat of issue #459 --- python/jittor/contrib.py | 10 +++++++++- python/jittor/nn.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 7235f7d4..c4026eee 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -243,9 +243,17 @@ def concat(arr, dim=0): if len(arr) == 0: raise ValueError("need at least one array to concat") total_dim = 0 - if dim < 0: dim += len(arr[0].shape) + base_dim = len(arr[0].shape) + if dim < 0: dim += base_dim + if dim < 0 or dim >= base_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})") dtypes = [] for a in arr: + if len(a.shape) != base_dim: + raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}") + for i in range(base_dim): + if i != dim and a.shape[i] != arr[0].shape[i]: + raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. 
Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.") total_dim += a.shape[dim] dtypes.append(str(a.dtype)) cdim = 0 diff --git a/python/jittor/nn.py b/python/jittor/nn.py index dd828594..9f7e8aeb 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1809,6 +1809,7 @@ def embedding(input, weight): class PixelShuffle(Module): def __init__(self, upscale_factor): + assert upscale_factor > 0,f"upscale_factor must be greater than zero,got {upscale_factor}" self.upscale_factor = upscale_factor def execute(self, x): From 9a23f5ce25705979301c56265c05862c6ae4c531 Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Fri, 31 May 2024 14:33:07 +0800 Subject: [PATCH 12/73] check x.shape and kernel_size of Pool and Pool3d,issue #461,#463 --- python/jittor/nn.py | 9 +++++++++ python/jittor/pool.py | 31 +++++++++++++++++++------------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 9f7e8aeb..57e2bdc0 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1865,6 +1865,15 @@ def execute(self, x): class Resize(Module): def __init__(self, size, mode="nearest", align_corners=False): super().__init__() + if isinstance(size,int): + if size <= 0: + raise ValueError(f"sizes must be positive, got {size}") + elif isinstance(size,tuple) or isinstance(size,list): + for item in size: + if item <= 0: + raise ValueError(f"sizes must be positive, got {item}") + else: + raise ValueError(f"size must be int or tuple") self.size = size self.mode = mode self.align_corners = align_corners diff --git a/python/jittor/pool.py b/python/jittor/pool.py index aa89e897..7e9e808a 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -21,12 +21,6 @@ class Pool(Module): def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): assert dilation == None assert return_indices == None or op == "maximum" - if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0: - raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") - if self.stride[0] <= 0 or self.stride[1] <= 0: - raise RuntimeError(f"stride must be greater than zero, but got {stride}") - if self.padding[0] < 0 or self.padding[1] < 0: - raise RuntimeError(f"padding must be non-negative, but got {padding}") self.return_indices = return_indices self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) self.op = op @@ -35,9 +29,20 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + for item in self.kernel_size: + if item <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {item}") + for item in self.stride: + if item <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {item}") + for item in self.padding: + if item < 0: + raise RuntimeError(f"padding must be non-negative, but got {item}") def execute(self, x): N,C,H,W = x.shape + if H <= self.kernel_size[0] or W <= self.kernel_size[1]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 @@ -203,12 +208,6 @@ class Pool3d(Module): def __init__(self, 
kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): assert dilation == None assert return_indices == None or op == "maximum" - if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: - raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") - if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: - raise RuntimeError(f"stride must be greater than zero, but got {stride}") - if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: - raise RuntimeError(f"padding must be non-negative, but got {padding}") self.return_indices = return_indices self.kernel_size = _triple(kernel_size) self.op = op @@ -217,9 +216,17 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = _triple(padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: + raise RuntimeError(f"padding must be non-negative, but got {padding}") def execute(self, x): N,C,D,H,W = x.shape + if D <= self.kernel_size[0] or H <= self.kernel_size[1] or W <= self.kernel_size[2]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: d = (D+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 h = (H+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 From 9d7e6342b5f63988ce3daf773d3c11d780e3e013 Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Fri, 31 May 2024 14:50:47 +0800 Subject: [PATCH 13/73] fix Pad2d with illegal padding,issue #464,#465,#466,#467 --- python/jittor/nn.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 57e2bdc0..5d408a53 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1688,6 +1688,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1716,6 +1718,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1734,6 +1738,8 @@ def __init__(self, padding, value): else: raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}") self.value = value + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): assert len(x.shape) >= 2 @@ -1760,6 +1766,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise 
ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape From c79142d8e9910ced72294a242a9bd8065bd7eeb3 Mon Sep 17 00:00:00 2001 From: Hanyuxuan Date: Fri, 31 May 2024 15:23:42 +0800 Subject: [PATCH 14/73] fix illegal parameters of ConvTranspose and Pool,issue #478,#480,#481,#482,#483 --- python/jittor/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 5d408a53..8f8012a8 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1393,6 +1393,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.padding[0] >= 0 or self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" From 69674751ca43a787f3a42da7d9e79c07fd003c79 Mon Sep 17 00:00:00 2001 From: DongYang Li <62846124+LDYang694@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:55:59 +0800 Subject: [PATCH 15/73] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 414eb868..90601b43 100644 --- a/README.md +++ b/README.md @@ -382,10 +382,10 @@ Email: jittor@qq.com File an issue: https://github.com/Jittor/jittor/issues -QQ Group: 761222083 +QQ Group: 836860279 - + ## The Team From 8d26bb8006353afb8d077d48df667822d116fd2f Mon Sep 17 00:00:00 2001 From: lidongyang Date: Wed, 5 Jun 2024 22:31:20 +0800 Subject: [PATCH 16/73] polish nn.Sequential attribute --- python/jittor/nn.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 8f8012a8..b423b45c 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2308,6 +2308,15 @@ def __len__(self): def named_children(self,): return list(self.layers.items()) + + def __setattr__(self, key, value) -> None: + if isinstance(key, str) and key.isdigit(): + if int(key) Date: Thu, 6 Jun 2024 18:02:29 +0800 Subject: [PATCH 17/73] check target shape and output shape in jt.nn.binary_cross_entropy_with_logits --- python/jittor/nn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..844fdfb6 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -494,6 +494,9 @@ def execute(self, output, target): return l1_loss(output, target) def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): + if not (target.shape == output.shape): + raise ValueError(f"Target size ({target.shape}) must be the same as output size ({output.shape})") + max_val = jt.clamp(-output,min_v=0) if pos_weight is not None: log_weight = (pos_weight-1)*target + 1 From 3310abbe2cd3433f351527e8407c87c9f332c95f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:04:35 +0800 Subject: [PATCH 18/73] check input1 and input2 shape in jt.nn.Bilinear() --- python/jittor/nn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..f5be5576 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -3006,6 +3006,10 @@ def 
call_rnn_cell(self, input, hidden, suffix): return h, h def bilinear(in1, in2, weight, bias): + if weight.shape[1] != in1.shape[1]: + raise RuntimeError(f"bilinear(): input1 size deos not match weight size: got {in1.shape[1]} but expected {weight.shape[1]}") + if weight.shape[2] != in2.shape[1]: + raise RuntimeError(f"bilinear(): input2 size deos not match weight size: got {in2.shape[1]} but expected {weight.shape[2]}") w = weight.transpose((1,0,2)) w = w.reshape((w.shape[0], -1)) x = jt.matmul(in1, w) From e7e3ea392ecb198449f771f17d365f737aa5c157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:18:31 +0800 Subject: [PATCH 19/73] check input shape in jt.nn.Conv1d --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..88eb77af 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1106,6 +1106,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.bias = self._conv[0].bias def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") N,C,D = x.shape assert C==self.in_channels self._conv[0].weight = self.weight.unsqueeze(-1) From 8cc95a6ce1ce8c6fe2c632ca34ce691726d37584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:25:29 +0800 Subject: [PATCH 20/73] check input shape in jt.nn.Conv1d_sp --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..1de651f3 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1187,6 +1187,8 @@ def __init__(self, inchannels, outchannels, kernel_size=1, bias=True): assert kernel_size == 1 def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") x = x.transpose(0, 2, 1) x = super().execute(x) x = x.transpose(0, 2, 1) From d7bfb0591d2042d69d8273f2e90eca900a24487e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:35:05 +0800 Subject: [PATCH 21/73] jt.nn.Conv1d_sp in_channels and out_channels must be positive --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 1de651f3..d99cedbf 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1183,6 +1183,8 @@ def execute(self, x): class Conv1d_sp(Linear): def __init__(self, inchannels, outchannels, kernel_size=1, bias=True): + assert inchannels > 0, 'in_channels must be positive' + assert outchannels > 0, 'out_channels must be positive' super().__init__(inchannels, outchannels, bias=bias) assert kernel_size == 1 From 4c04060a1f425e6a2adda3642a92a4771544640d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:39:10 +0800 Subject: [PATCH 22/73] jt.nn.Conv1d in_channels and out_channels must be positive --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 88eb77af..d4e2b2d1 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1088,6 +1088,8 @@ class Conv1d(Module): >>> output = conv(input) ''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + assert in_channels > 0, 'in_channels must be positive' + assert out_channels > 0, 'out_channels must be 
positive' self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = (kernel_size, 1) From ecc73d051ef5eee0b225716eb8d2d3263f9f18e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 6 Jun 2024 20:55:46 +0800 Subject: [PATCH 23/73] check input shape in jt.nn.ConvTranspose --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..18a43817 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1411,6 +1411,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.bias = None def execute(self, x): + if x.dim() != 4: + raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}') if self.groups == 1: N,C,H,W = x.shape i,o,h,w = self.weight.shape From 9c26755e2e47681614d99d332d8bdf64a9780cbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:08:48 +0800 Subject: [PATCH 24/73] modify stride positive check in jt.nn.transpose3d --- python/jittor/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..c3cdd974 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1608,11 +1608,11 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_paddi x = input N,C,D,H,W = x.shape i,o,d,h,w = weight.shape - if stride <= 0: - raise RuntimeError("non-positive stride is not supported") assert C==i assert groups==1, "Group conv not supported yet." stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding, padding) From 5d35972ece3c3e061e74a0ac7556fcaee4074f8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:29:58 +0800 Subject: [PATCH 25/73] add input shape check in jt.nn.transpose3d --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index c3cdd974..d992b360 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1606,6 +1606,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): x = input + if x.dim() != 5: + raise RuntimeError(f'Expected 5D input to conv_transpose3d, but got input of size: {x.shape}') N,C,D,H,W = x.shape i,o,d,h,w = weight.shape assert C==i From 00c6cb104708b2c95f1b441fb4868baac92864cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:39:05 +0800 Subject: [PATCH 26/73] modify stride positive check in jt.nn.conv_transpose; add input shape check in jt.nn.conv_transpose --- python/jittor/nn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index d992b360..f4dfcb1e 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1516,14 +1516,16 @@ def execute(self, x): return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, 
self.output_padding, self.group, self.dilation) def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - if stride <= 0: - raise RuntimeError("non-positive stride is not supported") if groups == 1: x = input + if x.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {x.shape}') N,C,H,W = x.shape i,o,h,w = weight.shape assert C==i stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) @@ -1555,6 +1557,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding assert not bias, "Bias should be none or jittor var" return y else: + if input.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {input.shape}') N,C,H,W = input.shape i,o,h,w = weight.shape G = groups @@ -1563,6 +1567,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding assert C % G == 0 assert C==i, (C, i) stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) From dc824d1bd7693f1ec42d9c68d903f08072627d8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:43:22 +0800 Subject: [PATCH 27/73] remove 3D(unbatch) description --- python/jittor/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 18a43817..0382797c 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1412,7 +1412,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ def execute(self, x): if x.dim() != 4: - raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}') + raise RuntimeError(f'Expected 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}') if self.groups == 1: N,C,H,W = x.shape i,o,h,w = self.weight.shape From 5cd0051bb158a0ddaaa17279e0370501ff0bebc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:45:53 +0800 Subject: [PATCH 28/73] add stride check in jt.nn.ConvTranspose --- python/jittor/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 0382797c..3db11962 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1393,6 +1393,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.stride[0] > 0 or self.stride[1] > 0,"stride must be non-negative" assert self.padding[0] >= 0 or self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ From 
f1925689e7322f1827791627603e6bfbf5103427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 16:48:27 +0800 Subject: [PATCH 29/73] modify error information --- python/jittor/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 3db11962..c886e855 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1393,7 +1393,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) - assert self.stride[0] > 0 or self.stride[1] > 0,"stride must be non-negative" + assert self.stride[0] > 0 or self.stride[1] > 0,"stride must be positive" assert self.padding[0] >= 0 or self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ From 0f5c7f86a0b85464ce9e0d28e54c9cdb28b2461a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 17:02:05 +0800 Subject: [PATCH 30/73] check input shape in nn.Dropout2d --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..15ee9a9b 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -589,6 +589,8 @@ def __init__(self, p=0.5, is_train=False): #TODO: test model.train() to change self.is_train def execute(self, input): output = input + if (input.dim() != 4) and (input.dim() != 3): + raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to Dropout2d, but got input of size: {input.shape}') shape = input.shape[:-2] if self.p > 0 and self.is_train: if self.p == 1: From bd105581f2317560b420a3c50e457eb692168793 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 19:05:49 +0800 Subject: [PATCH 31/73] check input shape in jt.nn.ZeroPad2d --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..ada2e878 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1723,6 +1723,8 @@ def __init__(self, padding): raise ValueError(f"padding must be non-negative") def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") n,c,h,w = x.shape return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"]) From 81eccbfca502beba8886dd96c5df8e5ed2c2af8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 19:08:53 +0800 Subject: [PATCH 32/73] check input shape in jt.nn.ReplicationPad2d --- python/jittor/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..87db5784 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1771,6 +1771,8 @@ def __init__(self, padding): raise ValueError(f"padding must be non-negative") def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") n,c,h,w = x.shape oh=h+self.pt+self.pb ow=w+self.pl+self.pr From 0dc433d4bd3be774b46591724371c5deaae5693b Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 19:25:53 +0800 Subject: [PATCH 33/73] check input shape and scale factor's positiveness in jt.nn.Upsample --- python/jittor/nn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..7837ed50 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1937,8 +1937,12 @@ def _interpolate(img, x, y, ids, mode): # TODO: tf_mode to another function def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): + if img.dim() != 3: + raise ValueError("Input shape must be `(N, C, H, W)`!") n, c, h, w = img.shape H, W = size + if h <= 0 or w <= 0 or H <= 0 or W <= 0: + raise RuntimeError(f"Input and output sizes should be greater than 0, but got input (H: {h}, W: {w}) output (H: {H}, W: {W})") nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: x = hid * ((h - 1) / max(1, H - 1)) From ca63d37d8bf51ecfc43cb60088672bad61ea16d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 19:26:40 +0800 Subject: [PATCH 34/73] resume --- python/jittor/nn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 7837ed50..b423b45c 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1937,12 +1937,8 @@ def _interpolate(img, x, y, ids, mode): # TODO: tf_mode to another function def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): - if img.dim() != 3: - raise ValueError("Input shape must be `(N, C, H, W)`!") n, c, h, w = img.shape H, W = size - if h <= 0 or w <= 0 or H <= 0 or W <= 0: - raise RuntimeError(f"Input and output sizes should be greater than 0, but got input (H: {h}, W: {w}) output (H: {H}, W: {W})") nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: x = hid * ((h - 1) / max(1, H - 1)) From d42bbda19501743789d1bb39d822b022a0157a3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Mon, 10 Jun 2024 19:27:29 +0800 Subject: [PATCH 35/73] check input shape and scale factor's positiveness in jt.nn.Upsample --- python/jittor/nn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..7837ed50 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1937,8 +1937,12 @@ def _interpolate(img, x, y, ids, mode): # TODO: tf_mode to another function def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): + if img.dim() != 3: + raise ValueError("Input shape must be `(N, C, H, W)`!") n, c, h, w = img.shape H, W = size + if h <= 0 or w <= 0 or H <= 0 or W <= 0: + raise RuntimeError(f"Input and output sizes should be greater than 0, but got input (H: {h}, W: {w}) output (H: {H}, W: {W})") nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: x = hid * ((h - 1) / max(1, H - 1)) From 0ea0fd9351c6dc92668bea729922dc9ef2e72c86 Mon Sep 17 00:00:00 2001 From: DongYang Li <62846124+LDYang694@users.noreply.github.com> Date: Tue, 25 Jun 2024 16:47:09 +0800 Subject: [PATCH 36/73] Update setup.py fix numpy version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 741165d2..beec50f6 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']}, # include_package_data=True, install_requires=[ - "numpy", + "numpy<2.0", "tqdm", 
"pillow", "astunparse", From 7416cfb7e461d93ccc7862d6edaf24b798b8399a Mon Sep 17 00:00:00 2001 From: DongYang Li <62846124+LDYang694@users.noreply.github.com> Date: Tue, 25 Jun 2024 16:49:44 +0800 Subject: [PATCH 37/73] update version --- python/jittor/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 007d702a..205888d9 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.9.8' +__version__ = '1.3.9.9' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -2174,4 +2174,4 @@ def inplace_wrapper(new_k, prev_func): if jt.compiler.has_acl: from jittor.extern.acl.acl_compiler import change_function - change_function() \ No newline at end of file + change_function() From 21e7409aff05c85e3d08e19d26cfdf3b489c04b8 Mon Sep 17 00:00:00 2001 From: fansunqi <392443298@qq.com> Date: Mon, 1 Jul 2024 12:23:03 +0800 Subject: [PATCH 38/73] check parameters' positive in jt.nn.fold --- python/jittor/nn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..3208a179 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2434,14 +2434,19 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1): def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1): assert X.ndim==3 + assert output_size[0] > 0 and output_size[1] > 0, "output size must be positive." if not isinstance(kernel_size,tuple): kernel_size = (kernel_size,kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" if not isinstance(dilation,tuple): dilation = (dilation,dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" if not isinstance(padding,tuple): padding = (padding,padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" if not isinstance(stride,tuple): stride = (stride,stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" n,cl,num = X.shape area = kernel_size[0] * kernel_size[1] block_nums = [] From 810a0699bf7d3d5a58b7a9639401c7825cd60d5f Mon Sep 17 00:00:00 2001 From: fansunqi <392443298@qq.com> Date: Mon, 1 Jul 2024 15:41:24 +0800 Subject: [PATCH 39/73] check parameter's positive in jt.nn.unfold --- python/jittor/nn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index b423b45c..51068c93 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2410,12 +2410,16 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1): assert X.ndim == 4 if not isinstance(kernel_size, tuple): kernel_size = (kernel_size, kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" if not isinstance(dilation, tuple): dilation = (dilation, dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" if not isinstance(padding, tuple): padding = (padding, padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" if not isinstance(stride, tuple): stride = (stride, stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" n, c, h, w = X.shape shape = X.shape area = kernel_size[0] * kernel_size[1] From b2f7f26bea05be67ffb0a21b5a9eb1a7c5f528fc Mon Sep 17 00:00:00 2001 From: DongYang Li 
<62846124+LDYang694@users.noreply.github.com> Date: Tue, 2 Jul 2024 20:02:02 +0800 Subject: [PATCH 40/73] update version --- python/jittor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 205888d9..b495d55c 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.9.9' +__version__ = '1.3.9.10' from jittor_utils import lock with lock.lock_scope(): ori_int = int From 7852283458b86192cb390725b8289d6cf12e6983 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Fri, 5 Jul 2024 18:12:43 +0800 Subject: [PATCH 41/73] add isin --- python/jittor/misc.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 5d05f13e..5357baf2 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -2236,3 +2236,16 @@ def cuda(x): def expm1(x): return jt.exp(x) - 1 + + +def isin(elements, test_elements, assume_unique=False, invert=False): + + elements = elements.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + comparison = elements == test_elements + result = comparison.any(dim=-1) + + if invert: + result = jt.logical_not(result) + + return result \ No newline at end of file From 46b290aefc572f28055756831885575e404831ac Mon Sep 17 00:00:00 2001 From: DongYang Li <62846124+LDYang694@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:13:32 +0800 Subject: [PATCH 42/73] Update nn.py --- python/jittor/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index c886e855..5a4a9b68 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1393,8 +1393,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) - assert self.stride[0] > 0 or self.stride[1] > 0,"stride must be positive" - assert self.padding[0] >= 0 or self.padding[1] >= 0,"padding must be non-negative" + assert self.stride[0] > 0 and self.stride[1] > 0,"stride must be positive" + assert self.padding[0] >= 0 and self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" From d4886b043527971ab02c914d38c67975527a81a0 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Tue, 9 Jul 2024 14:28:17 +0800 Subject: [PATCH 43/73] polish nn.Sequential __getattr__ --- python/jittor/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 3208a179..c8f36105 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2320,8 +2320,8 @@ def __setattr__(self, key, value) -> None: def __getattr__(self, key): - if key in self.layers: - return self.layers[key] + if 'layers' in self.__dict__ and key in self.__dict__['layers']: + return self.__dict__['layers'][key] return super().__getattr__(key) From 02f6f6d1ca804eec9a522b2ae8948cd8219241af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BB=AA?= Date: Tue, 9 Jul 2024 19:50:35 +0800 
Subject: [PATCH 44/73] update acl --- python/jittor/extern/acl/acl_compiler.py | 373 ++++++++++++++++++++--- python/jittor/extern/acl/acl_jittor.cc | 2 +- python/jittor/extern/acl/acl_op_exec.cc | 15 + 3 files changed, 347 insertions(+), 43 deletions(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 0ab4b6eb..9478ce7a 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -12,6 +12,7 @@ import jittor.compiler as compiler import jittor as jt + has_acl = 0 cc_flags = "" tikcc_path = env_or_try_find('tikcc_path', 'ccec') @@ -120,40 +121,52 @@ def post_process(): def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, attr: dict): - + nchw_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] + attr_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad'] + input_code = '' for i in range(len(inputs)): - if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1': + if name in nchw_op: input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" else: input_code += f"op.add(in{i}, true);\n" output_code = '' for i in range(len(output_dtypes)): - if name == 'MaxPoolWithArgmaxV1'or name == 'MaxPoolGradWithArgmaxV1': + if name in nchw_op: output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" else: output_code += f"op.add(out{i}, false);\n" # add attr to op attr_code = '' - if name == 'MaxPoolWithArgmaxV1' or name == 'MaxPoolGradWithArgmaxV1': + if name in attr_op: for k, v in attr.items(): - if k == "ceil_mode": - v = 'false' if v == False else 'true' - attr_code += f"op.set_attr(\"{k}\", {v});\n" + if isinstance(v, bool): + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + elif k == 'divisor_override_value': + attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" else: v = str(v).replace('[', '{').replace(']', '}') attr_code += f"op.set_attr(\"{k}\", vector{v});\n" else: for k, v in attr.items(): if isinstance(v, bool): - attr_code += f"op.set_attr(\"{k}\", {str(v).lower()});\n" + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" elif isinstance(v, str): attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" else: - attr_code += f"op.set_attr(\"{k}\", int({v}));\n" - + attr_code += f"op.set_attr(\"{k}\", int({v}));\n" + + # print(attr_code) import jittor as jt return jt.code( output_shapes, @@ -317,29 +330,45 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, input_data.push_back(data); input_host_32.emplace_back(move(v)); } - + void set_attr(const string& key, bool value) { + // LOGir << "string bool" << "set_attr" << key << value; CHECK(aclopSetAttrBool(attr, key.c_str(), value)==0); } + void set_attr(const string& key, int value, int is_bool) { + // LOGir << "string bool" << "set_attr" << key << value << is_bool; + CHECK(aclopSetAttrBool(attr, key.c_str(), value==is_bool)==0); + } void set_attr(const string& key, float value) { + // LOGir << "string float" <<"set_attr" << key << value; CHECK(aclopSetAttrFloat(attr, key.c_str(), value)==0); } void set_attr(const string& key, int64_t value) { + // LOGir << "string int64" << "set_attr" << key << value; + CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); + } + void set_attr(const string& key, 
int64_t value, int placeholder) { + // LOGir << "string int64" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); } void set_attr(const string& key, int32 value) { + // LOGir << "string int32" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); } void set_attr(const string& key, vector value) { + // LOGir << "string vector" << "set_attr" << key << value; CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); } void set_attr(const string& key, string value) { + // LOGir << "string string" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); } void set_attr(const char* key, const char* value) { + // LOGir << "char" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key, value)==0); } + void run() { // printDeviceData(input_desc, input_data, name); @@ -380,29 +409,63 @@ class IndexACL(Function): def __init__(self): super(IndexACL, self).__init__() - def execute(self, inshape: list, dim: int, dtype="int32"): + def execute(self, inshape: list, dim, dtype="int32"): # zeros a tensor, shape is inshape, dtype is dtype - max_len = inshape[dim] - tmp = jt.zeros(max_len, dtype=dtype) - result = acl_cmd( - "Range", - [jt.Var(0), jt.Var(max_len), jt.Var(1)], - output_dtypes=[tmp.dtype], - output_shapes=[tmp.shape], - attr={})[0] - broadcast_dim = [] - for i in range(len(inshape)): - if i != dim: - broadcast_dim.append(i) - result = jt.broadcast(result, shape=inshape, dims=broadcast_dim) - return result + if dim == None: + dim = [i for i in range(len(inshape))] + elif type(dim) == int: + dim = [dim] + results = [] + for d in dim: + max_len = inshape[d] + tmp = jt.zeros(max_len, dtype=dtype) + result = acl_cmd( + "Range", + [jt.Var(0), jt.Var(max_len), jt.Var(1)], + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr={})[0] + broadcast_dim = [] + for i in range(len(inshape)): + if i != d: + broadcast_dim.append(i) + result = jt.broadcast(result, shape=inshape, dims=broadcast_dim) + results.append(result) + if len(results) != 1: + return tuple(results) + else: + return results[0] def grad(self, grad_output): return grad_output - class PoolACL(Function): + def get_paddings(self): + pad_top = self.padding[0] + pad_left = self.padding[1] + H = self.input.shape[-2] + W = self.input.shape[-1] + + totalH = H + 2 * self.padding[0] - self.kernel_size[0] + totalW = W + 2 * self.padding[1] - self.kernel_size[1] + + kH = (totalH + self.stride[0] - 1) // self.stride[0] + 1 if self.attr['ceil_mode'] else totalH // self.stride[0] + 1 + kW = (totalW + self.stride[1] - 1) // self.stride[1] + 1 if self.attr['ceil_mode'] else totalW // self.stride[1] + 1 + + if self.attr['ceil_mode']: + if (kH - 1) * self.stride[0] >= H + self.padding[0]: + kH -= 1 + need_pad_h = (kH - 1) * self.stride[0] + self.kernel_size[0] - H + pad_top = need_pad_h - self.padding[0] + if (kW - 1) * self.stride[1] >= W + self.padding[1]: + kW -= 1 + need_pad_w = (kW - 1) * self.stride[1] + self.kernel_size[1] - W + pad_left = need_pad_w - self.padding[1] + + pads = [self.padding[0], pad_top, self.padding[1], pad_left] + return pads + def __init__(self, kernel_size, stride=None, @@ -410,9 +473,9 @@ def __init__(self, dilation=None, return_indices=None, ceil_mode=False, + count_include_pad=True, op='maximum'): super(PoolACL, self).__init__() - import jittor as jt # set attr self.kernel_size = kernel_size if isinstance( kernel_size, tuple) else (kernel_size, kernel_size) @@ -424,16 +487,28 @@ def 
__init__(self, self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) attr = {} - attr['ksize'] = [1, self.kernel_size[0], self.kernel_size[1], 1] - attr['strides'] = [1, self.stride[0], self.stride[1], 1] - attr['pads'] = [1, self.padding[0], self.padding[1], 1] - attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] - attr['ceil_mode'] = ceil_mode - - self.attr = attr + self.return_indices = return_indices self.uint16 = jt.Var(1).int32().dtype self.op = op + + if op == 'mean': + attr['exclusive'] = not count_include_pad + attr['global_pooling'] = False + attr['divisor_override_value'] = 0 + attr['ksize'] = [1, 1, self.kernel_size[0], self.kernel_size[1]] + attr['strides'] = [1, 1, self.stride[0], self.stride[1]] + attr['ceil_mode'] = ceil_mode + attr['padding_mode'] = 'CALCULATED' + attr['data_format'] = 'NCHW' + elif op == 'maximum': + attr['ksize'] = [1, self.kernel_size[0], self.kernel_size[1], 1] + attr['strides'] = [1, self.stride[0], self.stride[1], 1] + attr['pads'] = [1, self.padding[0], self.padding[1], 1] + attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] + # attr['ceil_mode'] = ceil_mode + + self.attr = attr def execute(self, input): @@ -442,7 +517,6 @@ def execute(self, input): input_dtype = input.dtype self.input = input - # create output output_shape = [ input_shape[0], input_shape[1], @@ -452,22 +526,42 @@ def execute(self, input): (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 ] output_dtype = input_dtype - result = acl_cmd("MaxPoolWithArgmaxV1", [input], + + if self.op == 'mean': + self.attr['pads'] = self.get_paddings() + result = acl_cmd("AvgPoolV2", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + elif self.op == 'maximum': + result = acl_cmd("MaxPoolWithArgmaxV1", [input], output_dtypes=[output_dtype, self.uint16], output_shapes=[output_shape, output_shape], attr=self.attr) - self.index = result[1] + else: + raise ValueError('no this type pool') + + if self.op == 'maximum': + self.index = result[1] + if self.return_indices: return result[0], result[1] else: return result[0] def grad(self, grad_output): - - grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], + if self.op == 'maximum': + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], output_dtypes=[grad_output.dtype], output_shapes=[self.input.shape], attr=self.attr)[0] + elif self.op == 'mean': + grad_input = acl_cmd("AvgPoolV2", [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + else: + grad_input = None return grad_input class BmmACL(Function): @@ -791,19 +885,211 @@ def execute(self, input, perm): def grad(self, grad_output): return grad_output + class AdaptiveMaxPool2dACL(Function): + def __init__(self, + output_size, + return_indices=False, + ): + super(AdaptiveMaxPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + + attr = {} + attr['ceil_mode'] = False + attr['dilations'] = [1,1,1,1] + self.attr = attr + + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + output_shape = [ + input_shape[0], input_shape[1], + self.output_size[0], self.output_size[1] + ] + output_dtype = input_dtype + self.input = input + + stride_h = input_shape[2] // output_shape[2]; + stride_w = 
input_shape[3] // output_shape[3]; + kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h; + kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w; + + stride = [0, 0] + kernel_size = [0, 0] + padding = [0, 0] + + stride[0] = stride_h; + stride[1] = stride_w; + kernel_size[0] = kernel_size_h; + kernel_size[1] = kernel_size_w; + padding[0] = padding[1] = 0; + kernel_sizes = [1, kernel_size[0], kernel_size[1], 1]; + strides_size = [1, stride[0], stride[1], 1]; + paddings = [1, padding[0], padding[1], 1]; + + self.attr['ksize'] = kernel_sizes + self.attr['strides'] = strides_size + self.attr['pads'] = paddings + + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + return grad_input + + class AdaptiveAvgPool2dACL(Function): + def __init__(self, + output_size + ): + super(AdaptiveAvgPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + attr = {} + if isinstance(output_size, tuple): + output_size = [output_size[0], output_size[1]] + attr['output_size'] = output_size + self.attr = attr + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + self.original_shape = input_shape + + output_shape = [ + input_shape[0], input_shape[1], + self.attr['output_size'][0], self.attr['output_size'][1] + ] + output_dtype = input_dtype + self.input = input + + result = acl_cmd("AdaptiveAvgPool2d", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + + + return result[0] + + def grad(self, grad_output): + attr = {} + attr['orig_input_shape'] = list(self.original_shape) + grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[self.original_shape], + attr=attr)[0] + return grad_input + + class CumsumACL(Function): + def __init__(self): + super(CumsumACL, self).__init__() + + def execute(self, input, dim=-1): + self.input = input + self.dim = dim + result = acl_cmd("Cumsum", [input, jt.Var(dim)], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + # TODO flip算子未适配 + flipped_grad_output = jt.flip(grad_output, dims=[self.dim]) + cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output, jt.Var(self.dim)], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return jt.flip(cumulative_grad, dims=[self.dim]) + + class GatherACL(Function): + def __init__(self): + super(GatherACL, self).__init__() + + def execute(self, input, dim, index): + self.input = input + self.dim = dim + self.index = index + result = acl_cmd("GatherElements", [input, index], + output_dtypes=[index.dtype], + output_shapes=[index.shape], + attr={'dim':dim})[0] + return result + + def grad(self, grad_output): + # TODO + grad_input = acl_cmd("ScatterElements", [jt.zeros(self.input.shape, dtype=grad_output.dtype), self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr={'axis':self.dim})[0] + return grad_input, None, None + + class 
WhereACL(Function): + def __init__(self): + super(WhereACL, self).__init__() + + def execute(self, condition, x, y): + self.condition = condition + + if x.dtype != y.dtype: + if x.dtype == jt.float32: + y = y.float32() + elif y.dtype == jt.float32: + x = x.float32() + else: + x = x.to(y.dtype) + + self.x = x + self.y = y + + result = acl_cmd("Select", [condition, x, y], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] + return result + + def grad(self, grad_output): + # TODO + return grad_output, None, None + def warp(origin_func, new_func): def warpper(*args, **kwargs): - if jt.flags.use_acl: - return new_func(*args, **kwargs) if origin_func == jt.index: if len(args) == 2 and args[1] == None: args = tuple(list(args[0:1])) + if jt.flags.use_acl: + if isinstance(new_func, IndexACL): + if len(args) == 1: + args = (args[0], None) + return new_func(*args, **kwargs) return origin_func(*args, **kwargs) return warpper + jt.index = warp(jt.index, IndexACL()) jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) jt.nn.Pool = warp(jt.nn.Pool, PoolACL) + jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, AdaptiveMaxPool2dACL) + jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, AdaptiveAvgPool2dACL) jt.triu = warp(jt.triu, TriuACL()) jt.triu_ = warp(jt.triu, TriuACL()) @@ -815,6 +1101,9 @@ def warpper(*args, **kwargs): jt.setitem = warp(jt.setitem, SetItemACL()) jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, SetItemACL())(x, slices, value) + jt.cumsum = warp(jt.cumsum, CumsumACL()) + jt.gather = warp(jt.gather, GatherACL()) + jt.where = warp(jt.where, WhereACL()) # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) # jt.bmm = warp(jt.bmm, BmmACL()) # jt.nn.matmul = warp(jt.matmul, MatmulACL()) diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc index 07639ab0..be0c17f4 100644 --- a/python/jittor/extern/acl/acl_jittor.cc +++ b/python/jittor/extern/acl/acl_jittor.cc @@ -232,7 +232,7 @@ void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& src = new_src; new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;"); - new_src = token_replace_all(new_src, "bool", "int8"); + // new_src = token_replace_all(new_src, "bool", "int8"); new_src = token_replace_all(new_src, "::numeric_min()", "-1e30"); new_src = token_replace_all(new_src, "::numeric_max()", "1e30"); // TODO: support max diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc index 275f0010..07b35145 100644 --- a/python/jittor/extern/acl/acl_op_exec.cc +++ b/python/jittor/extern/acl/acl_op_exec.cc @@ -188,24 +188,39 @@ struct AclOpRunner { } void set_attr(const string& key, bool value) { + // LOGir << "string bool" << "set_attr" << key << value; CHECK(aclopSetAttrBool(attr, key.c_str(), value)==0); } + void set_attr(const string& key, int value, int is_bool) { + // LOGir << "string bool" << "set_attr" << key << value << is_bool; + CHECK(aclopSetAttrBool(attr, key.c_str(), value==is_bool)==0); + } void set_attr(const string& key, float value) { + // LOGir << "string float" <<"set_attr" << key << value; CHECK(aclopSetAttrFloat(attr, key.c_str(), value)==0); } void set_attr(const string& key, int64_t value) { + // LOGir << "string int64" << "set_attr" << key << value; + CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); + } + void set_attr(const string& key, int64_t value, int placeholder) { + // LOGir << "string int64" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, 
key.c_str(), value)==0); } void set_attr(const string& key, int32 value) { + // LOGir << "string int32" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); } void set_attr(const string& key, vector value) { + // LOGir << "string vector" << "set_attr" << key << value; CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); } void set_attr(const string& key, string value) { + // LOGir << "string string" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); } void set_attr(const char* key, const char* value) { + // LOGir << "char" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key, value)==0); } From 6f782bf174e6fb7c1c3bd45622fb63441bb0d5c1 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 9 Jul 2024 21:48:35 +0800 Subject: [PATCH 45/73] Update acl_compiler.py --- python/jittor/extern/acl/acl_compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 9478ce7a..68197e58 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -1080,6 +1080,9 @@ def warpper(*args, **kwargs): if isinstance(new_func, IndexACL): if len(args) == 1: args = (args[0], None) + if isinstance(new_func, CumsumACL): + args = (args[0], kwargs.get('dim', -1)) + kwargs = {} return new_func(*args, **kwargs) return origin_func(*args, **kwargs) return warpper From 28e1920e97915938c86d9657e2f83fcf4d0ae0e3 Mon Sep 17 00:00:00 2001 From: hanyx Date: Tue, 9 Jul 2024 22:27:28 +0800 Subject: [PATCH 46/73] ComplexNumber:polar,view_as_complex,view_as_real --- python/jittor/nn.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 30f96dcd..f2181a8f 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -3171,6 +3171,17 @@ def ifft2(self): return ComplexNumber(_fft2(self.value, inverse=True), is_concat_value=True) +def polar(abs:jt.Var, angle: jt.Var) -> ComplexNumber: + assert abs.shape == angle.shape + return ComplexNumber(abs * angle.cos(),abs * angle.sin()) + +def view_as_complex(x: jt.Var) -> ComplexNumber: + assert x.shape[-1] == 2 + return ComplexNumber(x[...,0],x[...,1]) + +def view_as_real(x: ComplexNumber) -> jt.Var: + return jt.stack([x.value[...,0],x.value[...,1]],dim=-1) + def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var: ''' Returns the one_hot encoding of inputs. 
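A minimal usage sketch for the ComplexNumber helpers added in the patch above. It assumes polar, view_as_complex and view_as_real are importable from jittor.nn exactly as defined in that diff; the function names come from the patch, while the shapes and values below are illustrative only.

    import jittor as jt
    from jittor import nn

    # polar(abs, angle) builds a ComplexNumber from magnitude and phase,
    # i.e. real = mag * cos(phase), imag = mag * sin(phase) per the diff above.
    mag = jt.array([1.0, 2.0])
    phase = jt.array([0.0, 1.5707963])   # roughly 0 and pi/2
    z = nn.polar(mag, phase)

    # view_as_complex / view_as_real convert between a (..., 2) Var whose last
    # dimension holds (real, imag) pairs and a ComplexNumber, and back again.
    x = jt.randn(4, 2)
    c = nn.view_as_complex(x)            # -> ComplexNumber
    r = nn.view_as_real(c)               # -> Var with a trailing dim of size 2
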
From 6eef0f8c6e47644a5de70b293af710accbda5b81 Mon Sep 17 00:00:00 2001 From: Jiapeng Zhang <46623500+zjp-shadow@users.noreply.github.com> Date: Wed, 10 Jul 2024 19:58:07 +0800 Subject: [PATCH 47/73] fix load bugs fix load bugs of state --- python/jittor/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index b495d55c..dc17a228 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -1614,6 +1614,8 @@ def load_parameters(self, params): else: if hasattr(v, k): v = getattr(v, k) + if v is None: + continue assert isinstance(v, (Module, Var)), \ f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}" else: From d1d39d27ef56283dba168338311da9b8d088e390 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Fri, 12 Jul 2024 15:07:35 +0800 Subject: [PATCH 48/73] add no gpu device error --- python/jittor/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index c6be01bb..a5dcd136 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -1002,6 +1002,8 @@ def check_debug_flags(): r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc") if r==0: s = sorted(list(set(s.strip().split()))) + if len(s)==0: + LOG.e("No GPU Device Found!") cu += "_sm_" + "_".join(s) if "cuda_arch" not in os.environ: os.environ["cuda_arch"] = " ".join(cu) From 7347b7079c440c09f9c4efdc71619347148ecbf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=93=E4=B8=80=E8=BD=A9?= <2021013404@secoder.net> Date: Thu, 18 Jul 2024 20:04:40 +0800 Subject: [PATCH 49/73] FEAT! where,scatter,cumsum,gather,flip --- python/jittor/extern/acl/acl_compiler.py | 106 +++++++++++++++++++---- 1 file changed, 89 insertions(+), 17 deletions(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 68197e58..80c3c8d1 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -11,7 +11,7 @@ import glob import jittor.compiler as compiler import jittor as jt - +import pdb has_acl = 0 cc_flags = "" @@ -122,7 +122,7 @@ def post_process(): def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, attr: dict): nchw_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] - attr_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad'] + attr_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2'] input_code = '' for i in range(len(inputs)): @@ -166,7 +166,8 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, else: attr_code += f"op.set_attr(\"{k}\", int({v}));\n" - # print(attr_code) + #print("input_code",input_code) + #print("attr_code",attr_code) import jittor as jt return jt.code( output_shapes, @@ -776,7 +777,7 @@ def stride(self, x, dim): stride *= x.shape[i] return stride - def execute(self, x, slices, value): + def execute(self, x, slices, value, reduce = 'void'): self.is_tensor = type(value) == jt.Var if type(value) != jt.Var: value = jt.array(value) @@ -1013,14 +1014,19 @@ def execute(self, input, dim=-1): return result def grad(self, grad_output): - # TODO flip算子未适配 - flipped_grad_output = jt.flip(grad_output, dims=[self.dim]) + flipped_grad_output = acl_cmd("ReverseV2", [grad_output, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + 
output_shapes=[grad_output.shape], + attr={})[0] cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output, jt.Var(self.dim)], output_dtypes=[grad_output.dtype], output_shapes=[grad_output.shape], attr={})[0] - return jt.flip(cumulative_grad, dims=[self.dim]) - + grad_input = acl_cmd("ReverseV2", [cumulative_grad, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input class GatherACL(Function): def __init__(self): super(GatherACL, self).__init__() @@ -1029,20 +1035,43 @@ def execute(self, input, dim, index): self.input = input self.dim = dim self.index = index + result = acl_cmd("GatherElements", [input, index], - output_dtypes=[index.dtype], + output_dtypes=[input.dtype], output_shapes=[index.shape], attr={'dim':dim})[0] return result def grad(self, grad_output): - # TODO - grad_input = acl_cmd("ScatterElements", [jt.zeros(self.input.shape, dtype=grad_output.dtype), self.index, grad_output], + tmp = jt.zeros(self.index.shape,dtype=grad_output.dtype) + grad_input = acl_cmd("ScatterElements", [tmp, self.index, grad_output], output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr={'axis':self.dim})[0] - return grad_input, None, None + output_shapes=[self.index.shape], + attr={'axis':self.dim, 'reduction':"add"})[0] + return grad_input + + class ScatterACL(Function): + def __init__(self): + super(ScatterACL, self).__init__() + def execute(self, input, dim, index, src, reduce='void'): + self.input = input + self.dim = dim + self.index = index + self.reduce = reduce + result = acl_cmd("ScatterElements", [input, self.index, src], + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr={'axis':self.dim, 'reduction':reduce})[0] + return result + + def grad(self, grad_output): + grad_input = acl_cmd("GatherElements", [grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr={'dim':self.dim})[0] + return grad_output, None, None, grad_input + class WhereACL(Function): def __init__(self): super(WhereACL, self).__init__() @@ -1068,9 +1097,44 @@ def execute(self, condition, x, y): return result def grad(self, grad_output): - # TODO - return grad_output, None, None + tmp = jt.zeros(grad_output.shape,dtype=grad_output.dtype) + grad_x = acl_cmd("Select", [self.condition, grad_output, tmp], + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr={})[0] + + grad_y = acl_cmd("Select", [self.condition, tmp, grad_output], + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr={})[0] + return grad_output, grad_x, grad_y + class FlipACL(Function): + def __init__(self): + super(FlipACL, self).__init__() + + + def execute(self, input, dim): + self.input = input + #if isinstance(dim_vector, tuple): + dim_vector = jt.Var(list(dim)) + #print(dim_vector.dtype) + self.dim_vector = dim_vector + #print(input, dim_vector) + result = acl_cmd("ReverseV2", [input, dim_vector], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + #print(grad_output) + grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input + def warp(origin_func, new_func): def warpper(*args, **kwargs): if origin_func == jt.index: @@ -1083,6 +1147,9 @@ def warpper(*args, **kwargs): if isinstance(new_func, CumsumACL): args = (args[0], kwargs.get('dim', -1)) kwargs = {} + if 
isinstance(new_func, ScatterACL): + args = (args[0], args[1], args[2], args[3], kwargs.get('reduce', 'void')) + kwargs = {} return new_func(*args, **kwargs) return origin_func(*args, **kwargs) return warpper @@ -1102,11 +1169,16 @@ def warpper(*args, **kwargs): jt.getitem = warp(jt.getitem, GetItem()) jt.Var.getitem = lambda x, slices: warp(jt.getitem, GetItem())(x, slices) jt.setitem = warp(jt.setitem, SetItemACL()) - jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, SetItemACL())(x, slices, value) + jt.Var.setitem = lambda x, slices, value, reduce='void': warp(jt.setitem, SetItemACL())(x, slices, value, reduce) + jt.misc.flip = warp(jt.misc.flip, FlipACL()) + jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())(x, dim_vector) jt.cumsum = warp(jt.cumsum, CumsumACL()) jt.gather = warp(jt.gather, GatherACL()) + jt.scatter = warp(jt.scatter, ScatterACL()) jt.where = warp(jt.where, WhereACL()) + + # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) # jt.bmm = warp(jt.bmm, BmmACL()) # jt.nn.matmul = warp(jt.matmul, MatmulACL()) From 357a0de9b41b196da51ffd5566921b5a675ade5b Mon Sep 17 00:00:00 2001 From: CHEN Xinsheng Date: Fri, 19 Jul 2024 16:42:55 +0800 Subject: [PATCH 50/73] fix dtype mismatch in `nn.cross_entropy_loss` --- python/jittor/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 5cc8a95d..623daeeb 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -386,7 +386,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction= if ignore_index is not None: target_weight = jt.ternary( target==ignore_index, - jt.array(0).broadcast(target_weight), + jt.array(0).broadcast(target_weight).type_as(target_weight), target_weight ) From bbc448a31b919f674b1660cf6f2704904580ab5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=93=E4=B8=80=E8=BD=A9?= <2021013404@secoder.net> Date: Fri, 19 Jul 2024 17:01:17 +0800 Subject: [PATCH 51/73] FEAT! 
add aclop unittest --- python/jittor/test/test_aclop.py | 297 +++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 python/jittor/test/test_aclop.py diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py new file mode 100644 index 00000000..2cb52bce --- /dev/null +++ b/python/jittor/test/test_aclop.py @@ -0,0 +1,297 @@ +import unittest +import jittor as jt +from .test_core import expect_error +import numpy as np +from jittor import init, Module +import numpy as np +@unittest.skipIf(not jt.compiler.has_acl, "No ACL found") +class TestACL(unittest.TestCase): + @jt.flag_scope(use_acl=1) + def test_getitem(self): + a = jt.ones(100, 2) + b = a[0:2, 0:2] + np.testing.assert_allclose(b.numpy(), [[1,1],[1,1]]) + print("test getitem success") + + @jt.flag_scope(use_acl=1) + def test_setitem(self): + a = jt.ones(2, 2) + b = jt.Var(0) + a[0:1, 0:1] = b + np.testing.assert_allclose(a.numpy(), [[0,1],[1,1]]) + print("test setitem success") + + @jt.flag_scope(use_acl=1) + def test_getitem_grad(self): + a = jt.ones(2, 2) + b = a[0:1, 0:1] + optimizer = jt.optim.SGD([a], 0.1) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res=a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[1,0],[0,0]]) + print("test getitem grad success") + + @jt.flag_scope(use_acl=1) + def test_setitem_grad(self): + a = jt.ones(3, 3) + b = jt.ones(2, 2) + a[0:2, 0:2] = b * 2 + optimizer = jt.optim.SGD([a, b], 0.1) + loss = a.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[1,1,1],[1,1,1],[1,1,1]]) + np.testing.assert_allclose(res_b.numpy(), [[2,2],[2,2]]) + print("test setitem grad success") + + @jt.flag_scope(use_acl=1) + def test_concat(self): + a = jt.ones(2, 2) + b = jt.ones(2, 2) + c = jt.concat([a, b], 0) + np.testing.assert_allclose(c.numpy(), [[1,1],[1,1],[1,1],[1,1]]) + print("test concat success") + + @jt.flag_scope(use_acl=1) + def test_maxpool_grad(self): + a = jt.ones(1, 1, 4, 4) + max_pool = jt.nn.Pool(2, op='maximum') + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res=a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[[[1,0,1,0],[0,0,0,0],[1,0,1,0],[0,0,0,0]]]]) + print("test maxpool grad success") + + @jt.flag_scope(use_acl=1) + def test_triu(self): + a = jt.ones(3, 3) + b = jt.triu_(a, 0) + c = jt.triu_(a, 1) + np.testing.assert_allclose(b.numpy(), [[1,1,1],[0,1,1],[0,0,1]]) + np.testing.assert_allclose(c.numpy(), [[0,1,1],[0,0,1],[0,0,0]]) + print("test triu success") + + @jt.flag_scope(use_acl=1) + def test_bmm(self): + a = jt.ones(3, 2, 2).float32() + b = jt.bmm(a, a) + np.testing.assert_allclose(b.numpy(), [[[2,2],[2,2]],[[2,2],[2,2]],[[2,2],[2,2]]]) + print("test bmm success") + + @jt.flag_scope(use_acl=1) + def test_matmul(self): + a = jt.ones(1, 4, 4) + b = jt.ones(4, 2) + c = jt.matmul(a, b) + np.testing.assert_allclose(c.numpy(), [[[4,4],[4,4],[4,4],[4,4]]]) + print("test matmul success") + + @jt.flag_scope(use_acl=1) + def test_maxpool(self): + a = jt.ones(1, 1, 4, 4) + max_pool = jt.nn.Pool(2, op='maximum') + np.testing.assert_allclose(max_pool(a).numpy(), [[[[1,1],[1,1]]]]) + print("test maxpool success") + + @jt.flag_scope(use_acl=1) + def test_transpose(self): + a = jt.ones(1, 2, 2) + b = a.transpose(0, 2) + np.testing.assert_allclose(b.numpy(), 
[[[1],[1]],[[1],[1]]]) + print("test transpose success") + + @jt.flag_scope(use_acl=1) + def test_matmul_grad(self): + a = jt.ones(1, 2, 2) + b = jt.ones(2, 2) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = jt.matmul(a, b).sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[[2,2],[2,2]]]) + np.testing.assert_allclose(res_b.numpy(), [[2,2],[2,2]]) + print("test matmul grad success") + + @jt.flag_scope(use_acl=1) + def test_bmm_grad(self): + a = jt.ones(3, 2, 2).float32() + optimizer = jt.optim.SGD([a], 0.1) + c = jt.bmm(a, a) + loss = c.sum() + + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[[4,4],[4,4]],[[4,4],[4,4]],[[4,4],[4,4]]]) + print("test bmm grad success") + + @jt.flag_scope(use_acl=1) + def test_avgpool(self): + a = jt.ones(1, 1, 4, 4) + avg_pool = jt.nn.Pool(2, op='mean') + b = avg_pool(a) + np.testing.assert_allclose(b.numpy(), [[[[1,1],[1,1]]]]) + print("test avgpool success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool = jt.nn.AdaptiveMaxPool2d((2, 2)) + b = pool(a) + np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) + print("test adaptive_maxpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d_grad(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) + print("test adaptive_maxpool2d grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool = jt.nn.AdaptiveAvgPool2d((2, 2)) + b = pool(a) + np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) + print("test adaptive_avgpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d_grad(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = avg_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) + print("test adaptive_avgpool2d grad success") + @jt.flag_scope(use_acl=1) + def test_index(self): + a = jt.ones(2, 3) + [s1,s2] = jt.index(a.shape) + np.testing.assert_allclose(s1.numpy(), [[0,0,0],[1,1,1]]) + np.testing.assert_allclose(s2.numpy(), [[0,1,2],[0,1,2]]) + print("test index success") + + @jt.flag_scope(use_acl=1) + def test_gather(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) + print("test gather success") + + @jt.flag_scope(use_acl=1) + def test_gather_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + loss = 
b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) + print("test gather grad success") + + @jt.flag_scope(use_acl=1) + def test_scatter(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[0, 0], [0, 0]]) + b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]]) + print("test scatter success") + + @jt.flag_scope(use_acl=1) + def test_scatter_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.float32([[0, 0], [0, 0]]) + optimizer = jt.optim.SGD([a,b], 0.1) + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + loss = c.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) + print("test scatter grad success") + + @jt.flag_scope(use_acl=1) + def test_where(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.ones(2, 2) + c = jt.where(a > 2, a, b) + np.testing.assert_allclose(c.numpy(), [[1,1],[3,4]]) + print("test where success") + + @jt.flag_scope(use_acl=1) + def test_where_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.array([[2., 2.], [2., 2.]]) + c = jt.where(a > 2, a, b) + optimizer = jt.optim.SGD([a,b], 0.1) + loss = c.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) + print("test where grad success") + + @jt.flag_scope(use_acl=1) + def test_flip(self): + a = jt.array([[1., 2.], [3., 4.]]) + b = a.flip((0,1)) + np.testing.assert_allclose(b.numpy(), [[4,3],[2,1]]) + print("test flip success") + + @jt.flag_scope(use_acl=1) + def test_flip_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = a.flip((0,1)) + loss = b.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]]) + print("test flip grad success") + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 0de90f2574632752b3af1c663d98d0d48431044a Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Fri, 19 Jul 2024 17:04:42 +0800 Subject: [PATCH 52/73] Format test_aclop.py --- python/jittor/test/test_aclop.py | 157 +++++++++++++++++-------------- 1 file changed, 87 insertions(+), 70 deletions(-) diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index 2cb52bce..196d2576 100644 --- a/python/jittor/test/test_aclop.py +++ b/python/jittor/test/test_aclop.py @@ -4,13 +4,16 @@ import numpy as np from jittor import init, Module import numpy as np + + @unittest.skipIf(not jt.compiler.has_acl, "No ACL found") class TestACL(unittest.TestCase): + @jt.flag_scope(use_acl=1) def test_getitem(self): a = jt.ones(100, 2) b = a[0:2, 0:2] - np.testing.assert_allclose(b.numpy(), [[1,1],[1,1]]) + np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]]) print("test getitem success") @jt.flag_scope(use_acl=1) @@ -18,11 +21,11 @@ def test_setitem(self): a = jt.ones(2, 2) b = jt.Var(0) a[0:1, 0:1] = b - np.testing.assert_allclose(a.numpy(), [[0,1],[1,1]]) + np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) print("test setitem 
success") @jt.flag_scope(use_acl=1) - def test_getitem_grad(self): + def test_getitem_grad(self): a = jt.ones(2, 2) b = a[0:1, 0:1] optimizer = jt.optim.SGD([a], 0.1) @@ -30,10 +33,10 @@ def test_getitem_grad(self): optimizer.zero_grad() optimizer.backward(loss) optimizer.step() - res=a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[1,0],[0,0]]) + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[1, 0], [0, 0]]) print("test getitem grad success") - + @jt.flag_scope(use_acl=1) def test_setitem_grad(self): a = jt.ones(3, 3) @@ -46,20 +49,21 @@ def test_setitem_grad(self): optimizer.step() res_a = a.opt_grad(optimizer) res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[1,1,1],[1,1,1],[1,1,1]]) - np.testing.assert_allclose(res_b.numpy(), [[2,2],[2,2]]) + np.testing.assert_allclose(res_a.numpy(), + [[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) print("test setitem grad success") - + @jt.flag_scope(use_acl=1) - def test_concat(self): + def test_concat(self): a = jt.ones(2, 2) b = jt.ones(2, 2) c = jt.concat([a, b], 0) - np.testing.assert_allclose(c.numpy(), [[1,1],[1,1],[1,1],[1,1]]) + np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]]) print("test concat success") - + @jt.flag_scope(use_acl=1) - def test_maxpool_grad(self): + def test_maxpool_grad(self): a = jt.ones(1, 1, 4, 4) max_pool = jt.nn.Pool(2, op='maximum') optimizer = jt.optim.SGD([a], 0.1) @@ -68,46 +72,52 @@ def test_maxpool_grad(self): optimizer.zero_grad() optimizer.backward(loss) optimizer.step() - res=a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[[[1,0,1,0],[0,0,0,0],[1,0,1,0],[0,0,0,0]]]]) + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[1, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]]]) print("test maxpool grad success") - + @jt.flag_scope(use_acl=1) def test_triu(self): a = jt.ones(3, 3) b = jt.triu_(a, 0) c = jt.triu_(a, 1) - np.testing.assert_allclose(b.numpy(), [[1,1,1],[0,1,1],[0,0,1]]) - np.testing.assert_allclose(c.numpy(), [[0,1,1],[0,0,1],[0,0,0]]) + np.testing.assert_allclose(b.numpy(), + [[1, 1, 1], [0, 1, 1], [0, 0, 1]]) + np.testing.assert_allclose(c.numpy(), + [[0, 1, 1], [0, 0, 1], [0, 0, 0]]) print("test triu success") - + @jt.flag_scope(use_acl=1) def test_bmm(self): a = jt.ones(3, 2, 2).float32() b = jt.bmm(a, a) - np.testing.assert_allclose(b.numpy(), [[[2,2],[2,2]],[[2,2],[2,2]],[[2,2],[2,2]]]) + np.testing.assert_allclose( + b.numpy(), [[[2, 2], [2, 2]], [[2, 2], [2, 2]], [[2, 2], [2, 2]]]) print("test bmm success") - + @jt.flag_scope(use_acl=1) def test_matmul(self): a = jt.ones(1, 4, 4) b = jt.ones(4, 2) c = jt.matmul(a, b) - np.testing.assert_allclose(c.numpy(), [[[4,4],[4,4],[4,4],[4,4]]]) + np.testing.assert_allclose(c.numpy(), + [[[4, 4], [4, 4], [4, 4], [4, 4]]]) print("test matmul success") - + @jt.flag_scope(use_acl=1) def test_maxpool(self): a = jt.ones(1, 1, 4, 4) - max_pool = jt.nn.Pool(2, op='maximum') - np.testing.assert_allclose(max_pool(a).numpy(), [[[[1,1],[1,1]]]]) + max_pool = jt.nn.Pool(2, op='maximum') + np.testing.assert_allclose(max_pool(a).numpy(), [[[[1, 1], [1, 1]]]]) print("test maxpool success") - + @jt.flag_scope(use_acl=1) def test_transpose(self): a = jt.ones(1, 2, 2) b = a.transpose(0, 2) - np.testing.assert_allclose(b.numpy(), [[[1],[1]],[[1],[1]]]) + np.testing.assert_allclose(b.numpy(), [[[1], [1]], [[1], [1]]]) print("test transpose success") @jt.flag_scope(use_acl=1) @@ -121,12 
+131,12 @@ def test_matmul_grad(self): optimizer.step() res_a = a.opt_grad(optimizer) res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[[2,2],[2,2]]]) - np.testing.assert_allclose(res_b.numpy(), [[2,2],[2,2]]) + np.testing.assert_allclose(res_a.numpy(), [[[2, 2], [2, 2]]]) + np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) print("test matmul grad success") - + @jt.flag_scope(use_acl=1) - def test_bmm_grad(self): + def test_bmm_grad(self): a = jt.ones(3, 2, 2).float32() optimizer = jt.optim.SGD([a], 0.1) c = jt.bmm(a, a) @@ -135,19 +145,21 @@ def test_bmm_grad(self): optimizer.zero_grad() optimizer.backward(loss) optimizer.step() - + res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[[4,4],[4,4]],[[4,4],[4,4]],[[4,4],[4,4]]]) + np.testing.assert_allclose( + res.numpy(), + [[[4, 4], [4, 4]], [[4, 4], [4, 4]], [[4, 4], [4, 4]]]) print("test bmm grad success") - + @jt.flag_scope(use_acl=1) def test_avgpool(self): a = jt.ones(1, 1, 4, 4) avg_pool = jt.nn.Pool(2, op='mean') b = avg_pool(a) - np.testing.assert_allclose(b.numpy(), [[[[1,1],[1,1]]]]) + np.testing.assert_allclose(b.numpy(), [[[[1, 1], [1, 1]]]]) print("test avgpool success") - + @jt.flag_scope(use_acl=1) def test_adaptive_maxpool2d(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], @@ -156,9 +168,9 @@ def test_adaptive_maxpool2d(self): b = pool(a) np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) print("test adaptive_maxpool2d success") - + @jt.flag_scope(use_acl=1) - def test_adaptive_maxpool2d_grad(self): + def test_adaptive_maxpool2d_grad(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) @@ -168,21 +180,23 @@ def test_adaptive_maxpool2d_grad(self): optimizer.zero_grad() optimizer.backward(loss) optimizer.step() - res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) print("test adaptive_maxpool2d grad success") - + @jt.flag_scope(use_acl=1) def test_adaptive_avgpool2d(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) pool = jt.nn.AdaptiveAvgPool2d((2, 2)) b = pool(a) - np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) + np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) print("test adaptive_avgpool2d success") - + @jt.flag_scope(use_acl=1) - def test_adaptive_avgpool2d_grad(self): + def test_adaptive_avgpool2d_grad(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2)) @@ -192,15 +206,19 @@ def test_adaptive_avgpool2d_grad(self): optimizer.zero_grad() optimizer.backward(loss) optimizer.step() - res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) print("test adaptive_avgpool2d grad success") + @jt.flag_scope(use_acl=1) def test_index(self): a = jt.ones(2, 3) - [s1,s2] = jt.index(a.shape) - np.testing.assert_allclose(s1.numpy(), [[0,0,0],[1,1,1]]) - 
np.testing.assert_allclose(s2.numpy(), [[0,1,2],[0,1,2]]) + [s1, s2] = jt.index(a.shape) + np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]]) + np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]]) print("test index success") @jt.flag_scope(use_acl=1) @@ -209,7 +227,7 @@ def test_gather(self): b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) print("test gather success") - + @jt.flag_scope(use_acl=1) def test_gather_grad(self): a = jt.float32([[1, 2], [3, 4]]) @@ -222,20 +240,20 @@ def test_gather_grad(self): res = a.opt_grad(optimizer) np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) print("test gather grad success") - + @jt.flag_scope(use_acl=1) def test_scatter(self): a = jt.array([[1, 2], [3, 4]]) - b = jt.array([[0, 0], [0, 0]]) + b = jt.array([[0, 0], [0, 0]]) b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]]) print("test scatter success") - + @jt.flag_scope(use_acl=1) def test_scatter_grad(self): a = jt.float32([[1, 2], [3, 4]]) - b = jt.float32([[0, 0], [0, 0]]) - optimizer = jt.optim.SGD([a,b], 0.1) + b = jt.float32([[0, 0], [0, 0]]) + optimizer = jt.optim.SGD([a, b], 0.1) c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") loss = c.max() optimizer.zero_grad() @@ -243,8 +261,8 @@ def test_scatter_grad(self): optimizer.step() res_a = a.opt_grad(optimizer) res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) - np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) print("test scatter grad success") @jt.flag_scope(use_acl=1) @@ -252,15 +270,15 @@ def test_where(self): a = jt.array([[1, 2], [3, 4]]) b = jt.ones(2, 2) c = jt.where(a > 2, a, b) - np.testing.assert_allclose(c.numpy(), [[1,1],[3,4]]) + np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]]) print("test where success") - + @jt.flag_scope(use_acl=1) def test_where_grad(self): a = jt.float32([[1, 2], [3, 4]]) b = jt.array([[2., 2.], [2., 2.]]) c = jt.where(a > 2, a, b) - optimizer = jt.optim.SGD([a,b], 0.1) + optimizer = jt.optim.SGD([a, b], 0.1) loss = c.sum() optimizer.zero_grad() optimizer.backward(loss) @@ -270,19 +288,19 @@ def test_where_grad(self): np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) print("test where grad success") - + @jt.flag_scope(use_acl=1) def test_flip(self): - a = jt.array([[1., 2.], [3., 4.]]) - b = a.flip((0,1)) - np.testing.assert_allclose(b.numpy(), [[4,3],[2,1]]) + a = jt.array([[1., 2.], [3., 4.]]) + b = a.flip((0, 1)) + np.testing.assert_allclose(b.numpy(), [[4, 3], [2, 1]]) print("test flip success") - + @jt.flag_scope(use_acl=1) def test_flip_grad(self): a = jt.float32([[1, 2], [3, 4]]) optimizer = jt.optim.SGD([a], 0.1) - b = a.flip((0,1)) + b = a.flip((0, 1)) loss = b.max() optimizer.zero_grad() optimizer.backward(loss) @@ -290,8 +308,7 @@ def test_flip_grad(self): res = a.opt_grad(optimizer) np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]]) print("test flip grad success") - - - + + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 80e3d292197fddaecf56b567cfb9fe4b4492bed5 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jul 2024 14:53:04 +0800 Subject: [PATCH 53/73] Update acl_compiler.py --- 
python/jittor/extern/acl/acl_compiler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 80c3c8d1..b81eef95 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -55,17 +55,18 @@ def install(): else: cc_files2.append(name) cc_files = cc_files2 + ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') cc_flags += f" -DHAS_CUDA -DIS_ACL \ - -I/usr/local/Ascend/ascend-toolkit/latest/include/ \ - -L/usr/local/Ascend/ascend-toolkit/latest/lib64/ \ + -I{ascend_toolkit_home}/include/ \ + -L{ascend_toolkit_home}/lib64/ \ -I{acl_compiler_home} -lascendcl -lacl_op_compiler " ctypes.CDLL("libascendcl.so", dlopen_flags) - ''' + f''' -ltikc_runtime -I/usr/local/Ascend/driver/include/ \ - -L/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/ \ - -L/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/ \ + -L{ascend_toolkit_home}/compiler/lib64/ \ + -L{ascend_toolkit_home}/runtime/lib64/ \ ''' jittor_utils.LOG.i("ACL detected") From ff0c295c99510cb695f34aa1d2fc4e678b4d7bb4 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jul 2024 15:01:30 +0800 Subject: [PATCH 54/73] Update compile_extern.py --- python/jittor/compile_extern.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 4df000b4..4328f134 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -626,6 +626,7 @@ def setup_mpi(): global mpi_ops, mpi, use_mpi global mpicc_path, has_mpi use_mpi = os.environ.get("use_mpi", "1")=="1" + if not use_mpi: return mpi_ops = None mpi = None has_mpi = False @@ -711,4 +712,4 @@ def inner(self, *args, **kw): # install backend extern library for mod in jit_utils.backends: if mod.install_extern(): - break \ No newline at end of file + break From c0e9e0f301c73aee095ba1a13024703f57c80bd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=93=E4=B8=80=E8=BD=A9?= <2021013404@secoder.net> Date: Wed, 24 Jul 2024 15:13:57 +0800 Subject: [PATCH 55/73] FEAT! 
add floor_int --- python/jittor/extern/acl/acl_compiler.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index b81eef95..92cc1697 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -1135,7 +1135,23 @@ def grad(self, grad_output): output_shapes=[grad_output.shape], attr={})[0] return grad_input + + class FloorIntACL(Function): + def __init__(self): + super(FloorIntACL, self).__init__() + def execute(self, input): + self.input = input + self.shape = input.shape + result = acl_cmd("Floor", [input], + output_dtypes=[jt.int], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + return jt.zeros(self.shape, dtype=grad_output.dtype) + def warp(origin_func, new_func): def warpper(*args, **kwargs): if origin_func == jt.index: @@ -1178,6 +1194,8 @@ def warpper(*args, **kwargs): jt.gather = warp(jt.gather, GatherACL()) jt.scatter = warp(jt.scatter, ScatterACL()) jt.where = warp(jt.where, WhereACL()) + jt.floor_int = warp(jt.floor_int, FloorIntACL()) + jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) From 3973e075852283288301119687404b5c2cf14811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BB=AA?= Date: Wed, 24 Jul 2024 15:25:10 +0800 Subject: [PATCH 56/73] feat: enable ACL optimization in split function --- python/jittor/extern/acl/acl_compiler.py | 753 +++++++++++++---------- python/jittor/misc.py | 2 +- 2 files changed, 427 insertions(+), 328 deletions(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 92cc1697..4c26bc98 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -1,6 +1,6 @@ # *************************************************************** -# Copyright (c) 2023 Jittor. All Rights Reserved. -# Maintainers: Dun Liang . +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. 
# *************************************************************** @@ -11,7 +11,6 @@ import glob import jittor.compiler as compiler import jittor as jt -import pdb has_acl = 0 cc_flags = "" @@ -122,21 +121,24 @@ def post_process(): def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, attr: dict): - nchw_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] - attr_op = ['MaxPoolWithArgmaxV1','MaxPoolGradWithArgmaxV1', 'AvgPoolV2', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2'] - + nchw_op = ['MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] + attr_op = [ + 'MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2', + 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2' + ] + input_code = '' for i in range(len(inputs)): if name in nchw_op: input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" - else: + else: input_code += f"op.add(in{i}, true);\n" output_code = '' for i in range(len(output_dtypes)): if name in nchw_op: output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" - else: + else: output_code += f"op.add(out{i}, false);\n" # add attr to op @@ -151,7 +153,7 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, elif isinstance(v, str): attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" elif k == 'divisor_override_value': - attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" + attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" else: v = str(v).replace('[', '{').replace(']', '}') attr_code += f"op.set_attr(\"{k}\", vector{v});\n" @@ -165,16 +167,15 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, elif isinstance(v, str): attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" else: - attr_code += f"op.set_attr(\"{k}\", int({v}));\n" - + attr_code += f"op.set_attr(\"{k}\", int({v}));\n" + #print("input_code",input_code) #print("attr_code",attr_code) import jittor as jt - return jt.code( - output_shapes, - output_dtypes, - inputs, - cuda_header=""" + return jt.code(output_shapes, + output_dtypes, + inputs, + cuda_header=""" #include #include #include @@ -393,21 +394,21 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, } """, - cuda_src=f""" + cuda_src=f""" // aclop AclOpRunner op("{name}"); {input_code} {output_code} {attr_code} - op.run();""" - ) + op.run();""") def change_function(): import jittor as jt from jittor import Function - + class IndexACL(Function): + def __init__(self): super(IndexACL, self).__init__() @@ -422,8 +423,8 @@ def execute(self, inshape: list, dim, dtype="int32"): max_len = inshape[d] tmp = jt.zeros(max_len, dtype=dtype) result = acl_cmd( - "Range", - [jt.Var(0), jt.Var(max_len), jt.Var(1)], + "Range", [jt.Var(0), jt.Var(max_len), + jt.Var(1)], output_dtypes=[tmp.dtype], output_shapes=[tmp.shape], attr={})[0] @@ -431,7 +432,9 @@ def execute(self, inshape: list, dim, dtype="int32"): for i in range(len(inshape)): if i != d: broadcast_dim.append(i) - result = jt.broadcast(result, shape=inshape, dims=broadcast_dim) + result = jt.broadcast(result, + shape=inshape, + dims=broadcast_dim) results.append(result) if len(results) != 1: return tuple(results) @@ -452,231 +455,255 @@ def get_paddings(self): totalH = H + 2 * self.padding[0] - self.kernel_size[0] totalW = W + 2 * self.padding[1] - self.kernel_size[1] - kH = (totalH + self.stride[0] - 1) // self.stride[0] + 1 if self.attr['ceil_mode'] else totalH // self.stride[0] + 1 - kW = (totalW + self.stride[1] - 1) // self.stride[1] + 1 if 
self.attr['ceil_mode'] else totalW // self.stride[1] + 1 + kH = (totalH + self.stride[0] - + 1) // self.stride[0] + 1 if self.attr[ + 'ceil_mode'] else totalH // self.stride[0] + 1 + kW = (totalW + self.stride[1] - + 1) // self.stride[1] + 1 if self.attr[ + 'ceil_mode'] else totalW // self.stride[1] + 1 if self.attr['ceil_mode']: if (kH - 1) * self.stride[0] >= H + self.padding[0]: kH -= 1 - need_pad_h = (kH - 1) * self.stride[0] + self.kernel_size[0] - H + need_pad_h = (kH - + 1) * self.stride[0] + self.kernel_size[0] - H pad_top = need_pad_h - self.padding[0] if (kW - 1) * self.stride[1] >= W + self.padding[1]: kW -= 1 - need_pad_w = (kW - 1) * self.stride[1] + self.kernel_size[1] - W + need_pad_w = (kW - + 1) * self.stride[1] + self.kernel_size[1] - W pad_left = need_pad_w - self.padding[1] pads = [self.padding[0], pad_top, self.padding[1], pad_left] return pads - + def __init__(self, - kernel_size, - stride=None, - padding=0, - dilation=None, - return_indices=None, - ceil_mode=False, - count_include_pad=True, - op='maximum'): + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + count_include_pad=True, + op='maximum'): super(PoolACL, self).__init__() # set attr self.kernel_size = kernel_size if isinstance( kernel_size, tuple) else (kernel_size, kernel_size) stride = stride if stride else kernel_size - self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.stride = stride if isinstance(stride, tuple) else (stride, + stride) self.padding = padding if isinstance(padding, tuple) else (padding, - padding) + padding) dilation = dilation if dilation else 1 - self.dilation = dilation if isinstance(dilation, tuple) else (dilation, - dilation) + self.dilation = dilation if isinstance( + dilation, tuple) else (dilation, dilation) attr = {} - + self.return_indices = return_indices self.uint16 = jt.Var(1).int32().dtype self.op = op - + if op == 'mean': attr['exclusive'] = not count_include_pad attr['global_pooling'] = False attr['divisor_override_value'] = 0 - attr['ksize'] = [1, 1, self.kernel_size[0], self.kernel_size[1]] + attr['ksize'] = [ + 1, 1, self.kernel_size[0], self.kernel_size[1] + ] attr['strides'] = [1, 1, self.stride[0], self.stride[1]] attr['ceil_mode'] = ceil_mode attr['padding_mode'] = 'CALCULATED' attr['data_format'] = 'NCHW' elif op == 'maximum': - attr['ksize'] = [1, self.kernel_size[0], self.kernel_size[1], 1] + attr['ksize'] = [ + 1, self.kernel_size[0], self.kernel_size[1], 1 + ] attr['strides'] = [1, self.stride[0], self.stride[1], 1] attr['pads'] = [1, self.padding[0], self.padding[1], 1] attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] # attr['ceil_mode'] = ceil_mode - + self.attr = attr def execute(self, input): - + # create input input_shape = input.shape input_dtype = input.dtype - + self.input = input # create output output_shape = [ input_shape[0], input_shape[1], (input_shape[2] + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, + (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, (input_shape[3] + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 ] output_dtype = input_dtype if self.op == 'mean': self.attr['pads'] = self.get_paddings() result = acl_cmd("AvgPoolV2", [input], - output_dtypes=[output_dtype], - output_shapes=[output_shape], - attr=self.attr) - elif self.op == 'maximum': + output_dtypes=[output_dtype], + 
output_shapes=[output_shape], + attr=self.attr) + elif self.op == 'maximum': result = acl_cmd("MaxPoolWithArgmaxV1", [input], - output_dtypes=[output_dtype, self.uint16], - output_shapes=[output_shape, output_shape], - attr=self.attr) + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) else: raise ValueError('no this type pool') - - if self.op == 'maximum': + + if self.op == 'maximum': self.index = result[1] - + if self.return_indices: return result[0], result[1] else: return result[0] - + def grad(self, grad_output): if self.op == 'maximum': - grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] elif self.op == 'mean': - grad_input = acl_cmd("AvgPoolV2", [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] + grad_input = acl_cmd("AvgPoolV2", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] else: grad_input = None return grad_input - + class BmmACL(Function): + def __init__(self, adj_x1=False, adj_x2=False): super(BmmACL, self).__init__() self.adj_x1 = adj_x1 self.adj_x2 = adj_x2 - + def execute(self, x1, x2): self.input = [x1, x2] result = acl_cmd("BatchMatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] return result def grad(self, grad_output): x1, x2 = self.input - grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] + grad_x1 = acl_cmd( + "BatchMatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] return grad_x1, grad_x2 - + class MatmulACL(Function): + def __init__(self, adj_x1=False, adj_x2=False): super(MatmulACL, self).__init__() self.adj_x1 = adj_x1 self.adj_x2 = adj_x2 - + def execute(self, x1, x2): self.input = [x1, x2] if len(x1.shape) > 2 or len(x2.shape) > 2: result = acl_cmd("BatchMatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] else: result = acl_cmd("MatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] return result def grad(self, grad_output): x1, x2 = self.input if len(x1.shape) > 2 or len(x2.shape) > 2: - grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + 
x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd("BatchMatMul", [x1.transpose(-2, -1) , grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] + grad_x1 = acl_cmd( + "BatchMatMul", + [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] else: - grad_x1 = acl_cmd("MatMul", [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd("MatMul", [x1.transpose(-2, -1) , grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] + grad_x1 = acl_cmd( + "MatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "MatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] return grad_x1, grad_x2 - + class GetItem(Function): + def __init__(self): super(GetItem, self).__init__() - + def stride(self, x, dim): stride = 1 - for i in range(dim+1, len(x.shape)): + for i in range(dim + 1, len(x.shape)): stride *= x.shape[i] return stride - + def execute(self, x, slices): if isinstance(slices, jt.Var) or isinstance(slices, tuple): if isinstance(slices, jt.Var): - slices = (slices,) + slices = (slices, ) if isinstance(slices[0], jt.Var): slices_len = len(slices) masks = jt.ones(slices_len, dtype=jt.int64) - + output = slices[0].shape output += x.shape[slices_len:] - + input_ = [x, masks, jt.Var(list(output)).int64()] for i in range(slices_len): input_.append(slices[i].int32()) - result = acl_cmd("Index", input_, - output_dtypes=[x.dtype], - output_shapes=[output], - attr={})[0] + result = acl_cmd("Index", + input_, + output_dtypes=[x.dtype], + output_shapes=[output], + attr={})[0] return result - + # use AsStrided operator to implement the getitem function # get the shape and stride of the input tensor x_dim = len(x.shape) - + if not isinstance(slices, tuple): - slices = (slices,) + slices = (slices, ) if len(slices) < x_dim: - slices += (slice(None, None, None),) * (x_dim - len(slices)) - + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + self.inputs = [x, slices] - + sizes = [] strides = [] offset = 0 @@ -696,7 +723,7 @@ def execute(self, x, slices): strides.append(stride) else: raise ValueError("Invalid slice type") - + if not sizes: sizes = [1] strides = [0] @@ -705,25 +732,40 @@ def execute(self, x, slices): self.strides = strides self.offset = offset self.shape = x.shape - result = acl_cmd("AsStrided", [x, jt.Var(sizes), jt.Var(strides), jt.Var(offset)], - output_dtypes=[x.dtype], - output_shapes=[jt.empty(sizes).shape], - attr={})[0] + result = acl_cmd( + "AsStrided", + [x, jt.Var(sizes), + jt.Var(strides), + jt.Var(offset)], + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr={})[0] return result - + def grad(self, grad_output): result = jt.zeros(self.shape, dtype=grad_output.dtype) sizes = list(grad_output.shape) - strides = [self.stride(grad_output, dim) for dim in range(len(grad_output.shape))] - result = acl_cmd("ViewCopy", [result, jt.Var(self.sizes), jt.Var(self.strides), jt.Var(self.offset), - 
grad_output, jt.Var(sizes), jt.Var(strides), jt.Var(0)], - output_dtypes=[result.dtype], - output_shapes=[result.shape], - attr={})[0] + strides = [ + self.stride(grad_output, dim) + for dim in range(len(grad_output.shape)) + ] + result = acl_cmd("ViewCopy", [ + result, + jt.Var(self.sizes), + jt.Var(self.strides), + jt.Var(self.offset), grad_output, + jt.Var(sizes), + jt.Var(strides), + jt.Var(0) + ], + output_dtypes=[result.dtype], + output_shapes=[result.shape], + attr={})[0] result.sync() return result, None class ConcatACL(Function): + def __init__(self): super(ConcatACL, self).__init__() @@ -731,16 +773,25 @@ def execute(self, input_tensors, dim=0): self.input = input_tensors for i in range(len(input_tensors)): if input_tensors[i].dtype != input_tensors[0].dtype: - raise ValueError("All input tensors must have the same dtype") - if input_tensors[i].shape[:dim] != input_tensors[0].shape[:dim] or input_tensors[i].shape[dim+1:] != input_tensors[0].shape[dim+1:]: - raise ValueError("All input tensors must have the same shape") - result = acl_cmd("ConcatD", input_tensors, - output_dtypes=[input_tensors[0].dtype], - output_shapes=[jt.empty(self.calculate_output_shape(input_tensors, dim)).shape], - attr={ - "N": len(input_tensors), - "concat_dim": dim - })[0] + raise ValueError( + "All input tensors must have the same dtype") + if input_tensors[i].shape[:dim] != input_tensors[ + 0].shape[:dim] or input_tensors[i].shape[ + dim + 1:] != input_tensors[0].shape[dim + 1:]: + raise ValueError( + "All input tensors must have the same shape") + result = acl_cmd( + "ConcatD", + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[ + jt.empty(self.calculate_output_shape(input_tensors, + dim)).shape + ], + attr={ + "N": len(input_tensors), + "concat_dim": dim + })[0] return result def grad(self, grad_output): @@ -757,17 +808,16 @@ def split_grad(self, grad_output, input_tensors, axis): offset = 0 grad_inputs = [] for tensor in input_tensors: - grad_input = acl_cmd("Slice", - [grad_output, - [0]*axis + [offset] + [0]*(len(tensor.shape)-axis-1), - tensor.shape] - ) + grad_input = acl_cmd("Slice", [ + grad_output, [0] * axis + [offset] + [0] * + (len(tensor.shape) - axis - 1), tensor.shape + ]) grad_inputs.append(grad_input) offset += tensor.shape[axis] return grad_inputs - - + class SetItemACL(Function): + def __init__(self): super(SetItemACL, self).__init__() @@ -778,7 +828,7 @@ def stride(self, x, dim): stride *= x.shape[i] return stride - def execute(self, x, slices, value, reduce = 'void'): + def execute(self, x, slices, value, reduce='void'): self.is_tensor = type(value) == jt.Var if type(value) != jt.Var: value = jt.array(value) @@ -786,11 +836,11 @@ def execute(self, x, slices, value, reduce = 'void'): # 确保slices是一个元组 if not isinstance(slices, tuple): - slices = (slices,) + slices = (slices, ) # 补齐slices使其长度等于x的维度 if len(slices) < x_dim: - slices += (slice(None, None, None),) * (x_dim - len(slices)) + slices += (slice(None, None, None), ) * (x_dim - len(slices)) self.inputs = [x, slices, value] @@ -802,7 +852,7 @@ def execute(self, x, slices, value, reduce = 'void'): if isinstance(s, int): if s < 0: s += x.shape[dim] - s = slice(s, s+1, None) + s = slice(s, s + 1, None) if isinstance(s, slice): # 解包切片 start, stop, step = s.indices(x.shape[dim]) @@ -817,269 +867,310 @@ def execute(self, x, slices, value, reduce = 'void'): # 计算value的size、stride和offset value_sizes = list(value.shape) - value_strides = [self.stride(value, dim) for dim in range(len(value.shape))] + value_strides = 
[ + self.stride(value, dim) for dim in range(len(value.shape)) + ] self.target_sizes = target_sizes self.target_strides = target_strides self.offset = offset self.value_sizes = value_sizes self.value_strides = value_strides - + #import pdb; pdb.set_trace() - result = acl_cmd("ViewCopy", [x, jt.Var(target_sizes), jt.Var(target_strides), jt.Var(offset), - value, jt.Var(value_sizes), jt.Var(value_strides), jt.Var(0)], + result = acl_cmd("ViewCopy", [ + x, + jt.Var(target_sizes), + jt.Var(target_strides), + jt.Var(offset), value, + jt.Var(value_sizes), + jt.Var(value_strides), + jt.Var(0) + ], output_dtypes=[x.dtype], output_shapes=[x.shape], attr={})[0] result.sync() return result - + def grad(self, grad_output): - result = acl_cmd("AsStrided", [grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset)], - output_dtypes=[grad_output.dtype], - output_shapes=[jt.empty(self.target_sizes).shape], - attr={})[0] + result = acl_cmd("AsStrided", [ + grad_output, + jt.Var(self.target_sizes), + jt.Var(self.target_strides), + jt.Var(self.offset) + ], + output_dtypes=[grad_output.dtype], + output_shapes=[jt.empty(self.target_sizes).shape], + attr={})[0] # copy grad_output to new_grad_output new_grad_output = acl_cmd("Copy", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={ "N": 1 })[0] - new_grad_output = acl_cmd("ViewCopy", [new_grad_output, jt.Var(self.target_sizes), jt.Var(self.target_strides), jt.Var(self.offset), - jt.zeros(self.value_sizes, dtype=grad_output.dtype), jt.Var(self.value_sizes), jt.Var(self.value_strides), jt.Var(0)], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={"N": 1})[0] + new_grad_output = acl_cmd("ViewCopy", [ + new_grad_output, + jt.Var(self.target_sizes), + jt.Var(self.target_strides), + jt.Var(self.offset), + jt.zeros(self.value_sizes, dtype=grad_output.dtype), + jt.Var(self.value_sizes), + jt.Var(self.value_strides), + jt.Var(0) + ], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] new_grad_output.sync() return new_grad_output, None, result if self.is_tensor else None class TriuACL(Function): + def __init__(self): super(TriuACL, self).__init__() def execute(self, input, k): self.input = input result = acl_cmd("Triu", [input], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={'diagonal': k})[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={'diagonal': k})[0] return result def grad(self, grad_output): return grad_output class TransposeACL(Function): + def __init__(self): super(TransposeACL, self).__init__() def execute(self, input, perm): self.input = input - - output_shape = input.shape[perm[0]:perm[0]+1] + + output_shape = input.shape[perm[0]:perm[0] + 1] for i in range(1, len(perm)): - output_shape += input.shape[perm[i]:perm[i]+1] + output_shape += input.shape[perm[i]:perm[i] + 1] result = acl_cmd("Transpose", [input, jt.Var(perm)], - output_dtypes=[input.dtype], - output_shapes=[output_shape], - attr={})[0] + output_dtypes=[input.dtype], + output_shapes=[output_shape], + attr={})[0] return result def grad(self, grad_output): return grad_output - + class AdaptiveMaxPool2dACL(Function): - def __init__(self, - output_size, - return_indices=False, - ): + + def __init__( + self, + output_size, + return_indices=False, + ): super(AdaptiveMaxPool2dACL, self).__init__() self.output_size = 
(output_size, output_size) if isinstance( output_size, int) else output_size - + self.return_indices = return_indices self.uint16 = jt.Var(1).int32().dtype - + attr = {} attr['ceil_mode'] = False - attr['dilations'] = [1,1,1,1] - self.attr = attr - - + attr['dilations'] = [1, 1, 1, 1] + self.attr = attr + def execute(self, input): input_shape = input.shape input_dtype = input.dtype - + output_shape = [ - input_shape[0], input_shape[1], - self.output_size[0], self.output_size[1] + input_shape[0], input_shape[1], self.output_size[0], + self.output_size[1] ] output_dtype = input_dtype self.input = input - - stride_h = input_shape[2] // output_shape[2]; - stride_w = input_shape[3] // output_shape[3]; - kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h; - kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w; - + + stride_h = input_shape[2] // output_shape[2] + stride_w = input_shape[3] // output_shape[3] + kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h + kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w + stride = [0, 0] kernel_size = [0, 0] padding = [0, 0] - - stride[0] = stride_h; - stride[1] = stride_w; - kernel_size[0] = kernel_size_h; - kernel_size[1] = kernel_size_w; - padding[0] = padding[1] = 0; - kernel_sizes = [1, kernel_size[0], kernel_size[1], 1]; - strides_size = [1, stride[0], stride[1], 1]; - paddings = [1, padding[0], padding[1], 1]; - + + stride[0] = stride_h + stride[1] = stride_w + kernel_size[0] = kernel_size_h + kernel_size[1] = kernel_size_w + padding[0] = padding[1] = 0 + kernel_sizes = [1, kernel_size[0], kernel_size[1], 1] + strides_size = [1, stride[0], stride[1], 1] + paddings = [1, padding[0], padding[1], 1] + self.attr['ksize'] = kernel_sizes self.attr['strides'] = strides_size self.attr['pads'] = paddings - + result = acl_cmd("MaxPoolWithArgmaxV1", [input], - output_dtypes=[output_dtype, self.uint16], - output_shapes=[output_shape, output_shape], - attr=self.attr) - + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + self.index = result[1] - + if self.return_indices: return result[0], result[1] else: return result[0] - + def grad(self, grad_output): - grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] return grad_input - + class AdaptiveAvgPool2dACL(Function): - def __init__(self, - output_size - ): + + def __init__(self, output_size): super(AdaptiveAvgPool2dACL, self).__init__() self.output_size = (output_size, output_size) if isinstance( output_size, int) else output_size - + attr = {} if isinstance(output_size, tuple): output_size = [output_size[0], output_size[1]] attr['output_size'] = output_size self.attr = attr - + def execute(self, input): input_shape = input.shape input_dtype = input.dtype - + self.original_shape = input_shape - + output_shape = [ - input_shape[0], input_shape[1], - self.attr['output_size'][0], self.attr['output_size'][1] + input_shape[0], input_shape[1], self.attr['output_size'][0], + self.attr['output_size'][1] ] output_dtype = input_dtype self.input = input - + result = acl_cmd("AdaptiveAvgPool2d", [input], - output_dtypes=[output_dtype], - output_shapes=[output_shape], - attr=self.attr) - - + 
output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + return result[0] - + def grad(self, grad_output): attr = {} attr['orig_input_shape'] = list(self.original_shape) grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[self.original_shape], - attr=attr)[0] + output_dtypes=[grad_output.dtype], + output_shapes=[self.original_shape], + attr=attr)[0] return grad_input - + class CumsumACL(Function): + def __init__(self): super(CumsumACL, self).__init__() - + def execute(self, input, dim=-1): self.input = input self.dim = dim result = acl_cmd("Cumsum", [input, jt.Var(dim)], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={})[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] return result - + def grad(self, grad_output): - flipped_grad_output = acl_cmd("ReverseV2", [grad_output, jt.Var([self.dim])], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output, jt.Var(self.dim)], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - grad_input = acl_cmd("ReverseV2", [cumulative_grad, jt.Var([self.dim])], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] + flipped_grad_output = acl_cmd( + "ReverseV2", [grad_output, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + cumulative_grad = acl_cmd( + "Cumsum", + [flipped_grad_output, jt.Var(self.dim)], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + grad_input = acl_cmd( + "ReverseV2", + [cumulative_grad, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] return grad_input + class GatherACL(Function): + def __init__(self): super(GatherACL, self).__init__() - + def execute(self, input, dim, index): self.input = input self.dim = dim self.index = index - + result = acl_cmd("GatherElements", [input, index], - output_dtypes=[input.dtype], - output_shapes=[index.shape], - attr={'dim':dim})[0] + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr={'dim': dim})[0] return result - + def grad(self, grad_output): - tmp = jt.zeros(self.index.shape,dtype=grad_output.dtype) - grad_input = acl_cmd("ScatterElements", [tmp, self.index, grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[self.index.shape], - attr={'axis':self.dim, 'reduction':"add"})[0] + tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype) + grad_input = acl_cmd("ScatterElements", + [tmp, self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr={ + 'axis': self.dim, + 'reduction': "add" + })[0] return grad_input - + class ScatterACL(Function): + def __init__(self): super(ScatterACL, self).__init__() - + def execute(self, input, dim, index, src, reduce='void'): self.input = input self.dim = dim self.index = index self.reduce = reduce result = acl_cmd("ScatterElements", [input, self.index, src], - output_dtypes=[input.dtype], - output_shapes=[index.shape], - attr={'axis':self.dim, 'reduction':reduce})[0] + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr={ + 'axis': self.dim, + 'reduction': reduce + })[0] return result def grad(self, grad_output): grad_input = acl_cmd("GatherElements", [grad_output, self.index], - output_dtypes=[grad_output.dtype], - 
output_shapes=[self.index.shape], - attr={'dim':self.dim})[0] + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr={'dim': self.dim})[0] return grad_output, None, None, grad_input - + class WhereACL(Function): + def __init__(self): super(WhereACL, self).__init__() - + def execute(self, condition, x, y): self.condition = condition - + if x.dtype != y.dtype: if x.dtype == jt.float32: y = y.float32() @@ -1087,72 +1178,74 @@ def execute(self, condition, x, y): x = x.float32() else: x = x.to(y.dtype) - + self.x = x self.y = y - + result = acl_cmd("Select", [condition, x, y], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr={})[0] + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] return result - + def grad(self, grad_output): - tmp = jt.zeros(grad_output.shape,dtype=grad_output.dtype) + tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) grad_x = acl_cmd("Select", [self.condition, grad_output, tmp], - output_dtypes=[self.x.dtype], - output_shapes=[self.x.shape], - attr={})[0] - + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr={})[0] + grad_y = acl_cmd("Select", [self.condition, tmp, grad_output], - output_dtypes=[self.y.dtype], - output_shapes=[self.y.shape], - attr={})[0] + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr={})[0] return grad_output, grad_x, grad_y - + class FlipACL(Function): + def __init__(self): super(FlipACL, self).__init__() - - + def execute(self, input, dim): self.input = input - #if isinstance(dim_vector, tuple): + #if isinstance(dim_vector, tuple): dim_vector = jt.Var(list(dim)) #print(dim_vector.dtype) self.dim_vector = dim_vector #print(input, dim_vector) result = acl_cmd("ReverseV2", [input, dim_vector], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={})[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] return result - + def grad(self, grad_output): #print(grad_output) grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] return grad_input - + class FloorIntACL(Function): + def __init__(self): super(FloorIntACL, self).__init__() - + def execute(self, input): self.input = input self.shape = input.shape result = acl_cmd("Floor", [input], - output_dtypes=[jt.int], - output_shapes=[input.shape], - attr={})[0] + output_dtypes=[jt.int], + output_shapes=[input.shape], + attr={})[0] return result - + def grad(self, grad_output): return jt.zeros(self.shape, dtype=grad_output.dtype) - + def warp(origin_func, new_func): + def warpper(*args, **kwargs): if origin_func == jt.index: if len(args) == 2 and args[1] == None: @@ -1165,39 +1258,45 @@ def warpper(*args, **kwargs): args = (args[0], kwargs.get('dim', -1)) kwargs = {} if isinstance(new_func, ScatterACL): - args = (args[0], args[1], args[2], args[3], kwargs.get('reduce', 'void')) + args = (args[0], args[1], args[2], args[3], + kwargs.get('reduce', 'void')) kwargs = {} return new_func(*args, **kwargs) return origin_func(*args, **kwargs) + return warpper - - + jt.index = warp(jt.index, IndexACL()) jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) jt.nn.Pool = warp(jt.nn.Pool, PoolACL) - jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, AdaptiveMaxPool2dACL) - jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, AdaptiveAvgPool2dACL) - + 
jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, + AdaptiveMaxPool2dACL) + jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, + AdaptiveAvgPool2dACL) + jt.triu = warp(jt.triu, TriuACL()) jt.triu_ = warp(jt.triu, TriuACL()) jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) - + jt.getitem = warp(jt.getitem, GetItem()) - jt.Var.getitem = lambda x, slices: warp(jt.getitem, GetItem())(x, slices) + jt.Var.getitem = lambda x, slices, return_x=None: warp( + jt.getitem, GetItem())(x, slices, return_x) + jt.setitem = warp(jt.setitem, SetItemACL()) - jt.Var.setitem = lambda x, slices, value, reduce='void': warp(jt.setitem, SetItemACL())(x, slices, value, reduce) - + jt.Var.setitem = lambda x, slices, value, reduce='void': warp( + jt.setitem, SetItemACL())(x, slices, value, reduce) + jt.misc.flip = warp(jt.misc.flip, FlipACL()) - jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())(x, dim_vector) + jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())( + x, dim_vector) jt.cumsum = warp(jt.cumsum, CumsumACL()) jt.gather = warp(jt.gather, GatherACL()) jt.scatter = warp(jt.scatter, ScatterACL()) jt.where = warp(jt.where, WhereACL()) jt.floor_int = warp(jt.floor_int, FloorIntACL()) jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) - - + # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) # jt.bmm = warp(jt.bmm, BmmACL()) # jt.nn.matmul = warp(jt.matmul, MatmulACL()) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 5357baf2..2760f9d4 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -882,7 +882,7 @@ def split(d, split_size, dim=0): ans = [] last = 0 s_last = len(split_size)-1 - gopt_disable = jt.flags.gopt_disable + gopt_disable = jt.flags.gopt_disable or jt.flags.use_acl for j, i in enumerate(split_size): if i==0: shape = list(d.shape) From 7759a5440382e786432f76860d41c6e4fabe6e75 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jul 2024 15:43:19 +0800 Subject: [PATCH 57/73] Update compile_extern.py --- python/jittor/compile_extern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 4328f134..f6c93b54 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -626,10 +626,10 @@ def setup_mpi(): global mpi_ops, mpi, use_mpi global mpicc_path, has_mpi use_mpi = os.environ.get("use_mpi", "1")=="1" - if not use_mpi: return mpi_ops = None mpi = None has_mpi = False + if not use_mpi: return mpicc_path = env_or_try_find('mpicc_path', 'mpicc') if mpicc_path == "": # LOG.i("mpicc not found, distribution disabled.") From bab510cd92a103c844a3996b42c8227e60985e59 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jul 2024 16:20:02 +0800 Subject: [PATCH 58/73] Update acl_compiler.py --- python/jittor/extern/acl/acl_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 4c26bc98..d6a09c5f 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -1281,7 +1281,7 @@ def warpper(*args, **kwargs): jt.getitem = warp(jt.getitem, GetItem()) jt.Var.getitem = lambda x, slices, return_x=None: warp( - jt.getitem, GetItem())(x, slices, return_x) + jt.getitem, GetItem())(x, slices) jt.setitem = warp(jt.setitem, SetItemACL()) jt.Var.setitem = lambda x, slices, value, reduce='void': warp( From 
fa33e0a163a2b714466f1ddce36be50a79274499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BB=AA?= Date: Thu, 25 Jul 2024 15:54:57 +0800 Subject: [PATCH 59/73] Fixed the BUG of ACL op memory --- python/jittor/extern/acl/acl_compiler.py | 104 +++++++++++++----- .../src/mem/allocator/sfrl_allocator.cc | 5 + 2 files changed, 81 insertions(+), 28 deletions(-) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index d6a09c5f..22e366eb 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -371,7 +371,6 @@ def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, CHECK(aclopSetAttrString(attr, key, value)==0); } - void run() { // printDeviceData(input_desc, input_data, name); @@ -414,6 +413,7 @@ def __init__(self): def execute(self, inshape: list, dim, dtype="int32"): # zeros a tensor, shape is inshape, dtype is dtype + dim_input = dim if dim == None: dim = [i for i in range(len(inshape))] elif type(dim) == int: @@ -436,7 +436,7 @@ def execute(self, inshape: list, dim, dtype="int32"): shape=inshape, dims=broadcast_dim) results.append(result) - if len(results) != 1: + if len(results) != 1 or dim_input == None: return tuple(results) else: return results[0] @@ -663,6 +663,7 @@ class GetItem(Function): def __init__(self): super(GetItem, self).__init__() + self.type_ = 'index' def stride(self, x, dim): stride = 1 @@ -670,32 +671,34 @@ def stride(self, x, dim): stride *= x.shape[i] return stride - def execute(self, x, slices): + def execute(self, x, slices, return_x=None): if isinstance(slices, jt.Var) or isinstance(slices, tuple): if isinstance(slices, jt.Var): slices = (slices, ) if isinstance(slices[0], jt.Var): slices_len = len(slices) masks = jt.ones(slices_len, dtype=jt.int64) - output = slices[0].shape output += x.shape[slices_len:] - input_ = [x, masks, jt.Var(list(output)).int64()] for i in range(slices_len): input_.append(slices[i].int32()) - result = acl_cmd("Index", input_, output_dtypes=[x.dtype], output_shapes=[output], attr={})[0] + self.shape = x.shape + self.sizes = list(output) + self.type_ = 'index' + self.slices = slices + # self.strides return result # use AsStrided operator to implement the getitem function # get the shape and stride of the input tensor x_dim = len(x.shape) - + # int type if not isinstance(slices, tuple): slices = (slices, ) @@ -732,6 +735,7 @@ def execute(self, x, slices): self.strides = strides self.offset = offset self.shape = x.shape + self.type_ = 'as_strided' result = acl_cmd( "AsStrided", [x, jt.Var(sizes), @@ -743,24 +747,62 @@ def execute(self, x, slices): return result def grad(self, grad_output): - result = jt.zeros(self.shape, dtype=grad_output.dtype) - sizes = list(grad_output.shape) - strides = [ - self.stride(grad_output, dim) - for dim in range(len(grad_output.shape)) - ] - result = acl_cmd("ViewCopy", [ - result, - jt.Var(self.sizes), - jt.Var(self.strides), - jt.Var(self.offset), grad_output, - jt.Var(sizes), - jt.Var(strides), - jt.Var(0) - ], - output_dtypes=[result.dtype], - output_shapes=[result.shape], - attr={})[0] + if self.type_ == 'as_strided': + result = jt.zeros(self.shape, dtype=grad_output.dtype) + sizes = list(grad_output.shape) + strides = [ + self.stride(grad_output, dim) + for dim in range(len(grad_output.shape)) + ] + result = acl_cmd("ViewCopy", [ + result, + jt.Var(self.sizes), + jt.Var(self.strides), + jt.Var(self.offset), grad_output, + jt.Var(sizes), + jt.Var(strides), + jt.Var(0) + ], + 
output_dtypes=[result.dtype], + output_shapes=[result.shape], + attr={})[0] + elif self.type_ == 'index': + #TODO: use IndexPutV2 to implement the grad function + assert len(self.slices) == 1 + index = self.slices[0] + input = jt.zeros(self.shape, dtype=grad_output.dtype) + input_flatten = input.reshape(input.shape[0], -1) + index_flatten = index.reshape(-1).unsqueeze(-1).repeat( + 1, input_flatten.shape[1]) + grad_output_flatten = grad_output.reshape(index.numel(), -1) + result = acl_cmd( + "ScatterElements", + [input_flatten, index_flatten, grad_output_flatten], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={ + 'axis': 0, + 'reduction': 'add' + })[0] + result = result.reshape(self.shape) + # result = jt.zeros(self.shape, dtype=grad_output.dtype) + # # masks = jt.ones(len(self.slices), dtype=jt.int64) + # masks = jt.array([1,1], dtype=jt.int64) + # expand_masks = jt.array([1,1], dtype=jt.int64) + # inputs_ = [result,grad_output,masks,expand_masks] + # slices_len = len(self.slices) + # for i in range(slices_len): + # inputs_.append(self.slices[i].int64()) + # # breakpoint() + # jt.sync_all(True) + # print(inputs_) + # result_ = acl_cmd("IndexPutV2", inputs_, + # output_dtypes=[result.dtype], + # output_shapes=[result.shape], + # attr={"accumulate":True})[0] + # result = result_ + else: + raise ValueError("Invalid slice type") result.sync() return result, None @@ -1130,7 +1172,7 @@ def grad(self, grad_output): grad_input = acl_cmd("ScatterElements", [tmp, self.index, grad_output], output_dtypes=[grad_output.dtype], - output_shapes=[self.index.shape], + output_shapes=[tmp.shape], attr={ 'axis': self.dim, 'reduction': "add" @@ -1149,7 +1191,7 @@ def execute(self, input, dim, index, src, reduce='void'): self.reduce = reduce result = acl_cmd("ScatterElements", [input, self.index, src], output_dtypes=[input.dtype], - output_shapes=[index.shape], + output_shapes=[input.shape], attr={ 'axis': self.dim, 'reduction': reduce @@ -1257,10 +1299,12 @@ def warpper(*args, **kwargs): if isinstance(new_func, CumsumACL): args = (args[0], kwargs.get('dim', -1)) kwargs = {} - if isinstance(new_func, ScatterACL): + if isinstance(new_func, + ScatterACL) and kwargs.get('reduce') is not None: args = (args[0], args[1], args[2], args[3], kwargs.get('reduce', 'void')) kwargs = {} + return new_func(*args, **kwargs) return origin_func(*args, **kwargs) @@ -1292,7 +1336,11 @@ def warpper(*args, **kwargs): x, dim_vector) jt.cumsum = warp(jt.cumsum, CumsumACL()) jt.gather = warp(jt.gather, GatherACL()) + jt.Var.gather = lambda x, dim, index: warp(jt.gather, GatherACL())(x, dim, + index) jt.scatter = warp(jt.scatter, ScatterACL()) + jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp( + jt.scatter, ScatterACL())(x, dim, index, src, reduce) jt.where = warp(jt.where, WhereACL()) jt.floor_int = warp(jt.floor_int, FloorIntACL()) jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.cc b/python/jittor/src/mem/allocator/sfrl_allocator.cc index ce5460fd..add3aef9 100644 --- a/python/jittor/src/mem/allocator/sfrl_allocator.cc +++ b/python/jittor/src/mem/allocator/sfrl_allocator.cc @@ -242,7 +242,12 @@ std::mutex sfrl_allocator_mutex; void* SFRLAllocator::alloc(size_t size, size_t& allocation) { std::unique_lock lock(sfrl_allocator_mutex); + #ifdef IS_ACL + // output of acl op need additional 32 bytes + size = align_size(size+32); + #else size = align_size(size); + #endif CachingBlockPool* blocks = get_blocks(size); //search 
cached block CachingBlock* block = blocks->pop_block(size); From 21ac78e4ee3e3f8668c6917cb244b51f55431341 Mon Sep 17 00:00:00 2001 From: Yuxuan Han Date: Fri, 26 Jul 2024 21:16:09 +0800 Subject: [PATCH 60/73] complement of test_aclop --- python/jittor/test/test_aclop.py | 79 ++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index 196d2576..b79ae5a9 100644 --- a/python/jittor/test/test_aclop.py +++ b/python/jittor/test/test_aclop.py @@ -16,6 +16,13 @@ def test_getitem(self): np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]]) print("test getitem success") + @jt.flag_scope(use_acl=1) + def test_getitem_neg(self): + a = jt.ones(2, 3, 2) + b = a[0:1,0:-2] + np.testing.assert_allclose(b.numpy(), [[[1,1]]]) + print("test getitem neg success") + @jt.flag_scope(use_acl=1) def test_setitem(self): a = jt.ones(2, 2) @@ -24,6 +31,14 @@ def test_setitem(self): np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) print("test setitem success") + @jt.flag_scope(use_acl=1) + def test_setitem_neg(self): + a = jt.ones(2, 3, 2) + b = jt.Var(0) + a[0:1, 0:-2] = b + np.testing.assert_allclose(a.numpy(), [[[0,0],[1,1],[1,1]],[[1,1],[1,1],[1,1]]]) + print("test setitem neg success") + @jt.flag_scope(use_acl=1) def test_getitem_grad(self): a = jt.ones(2, 2) @@ -62,9 +77,25 @@ def test_concat(self): np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]]) print("test concat success") + @jt.flag_scope(use_acl=1) + def test_concat_neg(self): + a = jt.ones(2, 2) + b = jt.ones(2, 2) + c = jt.concat([a, b], -1) + np.testing.assert_allclose(c.numpy(), [[1,1,1,1],[1,1,1,1]]) + print("test concat neg success") + + @jt.flag_scope(use_acl=1) + def test_concat_zero_dim(self): + a = jt.ones([]) + b = jt.zeros([]) + c = jt.concat([a, b], 0) + np.testing.assert_allclose(c.numpy(), [1,0]) + print("test concat zero dim success") + @jt.flag_scope(use_acl=1) def test_maxpool_grad(self): - a = jt.ones(1, 1, 4, 4) + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) max_pool = jt.nn.Pool(2, op='maximum') optimizer = jt.optim.SGD([a], 0.1) b = max_pool(a) @@ -75,7 +106,7 @@ def test_maxpool_grad(self): res = a.opt_grad(optimizer) np.testing.assert_allclose( res.numpy(), - [[[[1, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]]]) + [[[[0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]]]) print("test maxpool grad success") @jt.flag_scope(use_acl=1) @@ -83,47 +114,57 @@ def test_triu(self): a = jt.ones(3, 3) b = jt.triu_(a, 0) c = jt.triu_(a, 1) + d = jt.triu_(a, -1) np.testing.assert_allclose(b.numpy(), [[1, 1, 1], [0, 1, 1], [0, 0, 1]]) np.testing.assert_allclose(c.numpy(), [[0, 1, 1], [0, 0, 1], [0, 0, 0]]) + np.testing.assert_allclose(d.numpy(), + [[1, 1, 1], [1, 1, 1], [0, 1, 1]]) print("test triu success") @jt.flag_scope(use_acl=1) def test_bmm(self): - a = jt.ones(3, 2, 2).float32() + a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) b = jt.bmm(a, a) np.testing.assert_allclose( - b.numpy(), [[[2, 2], [2, 2]], [[2, 2], [2, 2]], [[2, 2], [2, 2]]]) + b.numpy(), [[[7, 10], [15, 22]], [[8, 5], [20, 13]], [[9, 8], [16, 17]]]) print("test bmm success") @jt.flag_scope(use_acl=1) def test_matmul(self): - a = jt.ones(1, 4, 4) - b = jt.ones(4, 2) + a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) + b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) c = jt.matmul(a, b) np.testing.assert_allclose(c.numpy(), - [[[4, 4], [4, 4], [4, 4], [4, 4]]]) + [[[10, 10], [26, 26], [42, 
42], [58, 58]]]) print("test matmul success") @jt.flag_scope(use_acl=1) def test_maxpool(self): - a = jt.ones(1, 1, 4, 4) + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) max_pool = jt.nn.Pool(2, op='maximum') - np.testing.assert_allclose(max_pool(a).numpy(), [[[[1, 1], [1, 1]]]]) + np.testing.assert_allclose(max_pool(a).numpy(), [[[[3, 4], [4, 3]]]]) print("test maxpool success") @jt.flag_scope(use_acl=1) def test_transpose(self): - a = jt.ones(1, 2, 2) + a = jt.float32([[[1,2],[3,4]]]) b = a.transpose(0, 2) - np.testing.assert_allclose(b.numpy(), [[[1], [1]], [[1], [1]]]) + np.testing.assert_allclose(b.numpy(), [[[1], [3]], [[2], [4]]]) print("test transpose success") + @jt.flag_scope(use_acl=1) + def test_transpose_neg(self): + a = jt.float32([[[1,2],[3,4]]]) + b = a.transpose(1, -1) + np.testing.assert_allclose(b.numpy(), [[[1,3], [2,4]]]) + print("test transpose neg success") + @jt.flag_scope(use_acl=1) def test_matmul_grad(self): - a = jt.ones(1, 2, 2) - b = jt.ones(2, 2) + a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) + b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) optimizer = jt.optim.SGD([a, b], 0.1) loss = jt.matmul(a, b).sum() optimizer.zero_grad() @@ -131,13 +172,13 @@ def test_matmul_grad(self): optimizer.step() res_a = a.opt_grad(optimizer) res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[[2, 2], [2, 2]]]) - np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) + np.testing.assert_allclose(res_a.numpy(), [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]]) + np.testing.assert_allclose(res_b.numpy(), [[28, 28], [32, 32], [36, 36], [40, 40]]) print("test matmul grad success") @jt.flag_scope(use_acl=1) def test_bmm_grad(self): - a = jt.ones(3, 2, 2).float32() + a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) optimizer = jt.optim.SGD([a], 0.1) c = jt.bmm(a, a) loss = c.sum() @@ -149,15 +190,15 @@ def test_bmm_grad(self): res = a.opt_grad(optimizer) np.testing.assert_allclose( res.numpy(), - [[[4, 4], [4, 4]], [[4, 4], [4, 4]], [[4, 4], [4, 4]]]) + [[[7, 11], [9, 13]], [[9, 13], [7, 11]], [[8, 12], [8, 12]]]) print("test bmm grad success") @jt.flag_scope(use_acl=1) def test_avgpool(self): - a = jt.ones(1, 1, 4, 4) + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) avg_pool = jt.nn.Pool(2, op='mean') b = avg_pool(a) - np.testing.assert_allclose(b.numpy(), [[[[1, 1], [1, 1]]]]) + np.testing.assert_allclose(b.numpy(), [[[[2, 3], [3, 2]]]]) print("test avgpool success") @jt.flag_scope(use_acl=1) From b010cc650c7a7704a38faac5f4c5c1855d3afe58 Mon Sep 17 00:00:00 2001 From: Yuxuan Han Date: Thu, 1 Aug 2024 16:00:00 +0800 Subject: [PATCH 61/73] complement of test_aclop --- python/jittor/test/test_aclop.py | 76 ++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index b79ae5a9..cb2235fc 100644 --- a/python/jittor/test/test_aclop.py +++ b/python/jittor/test/test_aclop.py @@ -205,13 +205,16 @@ def test_avgpool(self): def test_adaptive_maxpool2d(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) - pool = jt.nn.AdaptiveMaxPool2d((2, 2)) - b = pool(a) + pool_1 = jt.nn.AdaptiveMaxPool2d((2, 2)) + pool_2 = jt.nn.AdaptiveMaxPool2d((3, 4)) + b = pool_1(a) + c = pool_2(a) np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) + np.testing.assert_allclose(c.numpy(), [[[[5,6,7,8],[9,10,11,12],[13,14,15,16]]]]) print("test adaptive_maxpool2d success") 
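# A rough sketch of where the (3, 4) expectation above comes from, assuming the
# usual adaptive-pooling window rule start = i*in//out, end = ceil((i+1)*in/out)
# (the expected values are consistent with it); the helper below is a
# hypothetical illustration, not Jittor API.
import math

def adaptive_bins(in_size, out_size):
    # (start, end) index range of each pooling window along one axis
    return [(i * in_size // out_size, math.ceil((i + 1) * in_size / out_size))
            for i in range(out_size)]

# adaptive_bins(4, 3) -> [(0, 2), (1, 3), (2, 4)]: taking the column-wise max of
# input rows 0..1, 1..2 and 2..3 yields [5,6,7,8], [9,10,11,12], [13,14,15,16].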
@jt.flag_scope(use_acl=1) - def test_adaptive_maxpool2d_grad(self): + def test_adaptive_maxpool2d_grad_1(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) @@ -225,15 +228,35 @@ def test_adaptive_maxpool2d_grad(self): np.testing.assert_allclose( res.numpy(), [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) - print("test adaptive_maxpool2d grad success") + print("test adaptive_maxpool2d_1 grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d_grad_2(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + max_pool = jt.nn.AdaptiveMaxPool2d((1, 3)) + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]]]) + print("test adaptive_maxpool2d_2 grad success") @jt.flag_scope(use_acl=1) def test_adaptive_avgpool2d(self): a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]) - pool = jt.nn.AdaptiveAvgPool2d((2, 2)) - b = pool(a) + pool_1 = jt.nn.AdaptiveAvgPool2d((2, 2)) + pool_2 = jt.nn.AdaptiveAvgPool2d((1, 3)) + b = pool_1(a) + c = pool_2(a) np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) + np.testing.assert_allclose(c.numpy(), [[[[7.5, 8.5, 9.5]]]]) print("test adaptive_avgpool2d success") @jt.flag_scope(use_acl=1) @@ -253,10 +276,28 @@ def test_adaptive_avgpool2d_grad(self): [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) print("test adaptive_avgpool2d grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d_grad_2(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + avg_pool = jt.nn.AdaptiveAvgPool2d((1, 3)) + optimizer = jt.optim.SGD([a], 0.1) + b = avg_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125], + [0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125]]]]) + print("test adaptive_avgpool2d_2 grad success") @jt.flag_scope(use_acl=1) def test_index(self): - a = jt.ones(2, 3) + a = jt.rand(2, 3) [s1, s2] = jt.index(a.shape) np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]]) np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]]) @@ -267,20 +308,37 @@ def test_gather(self): a = jt.array([[1, 2], [3, 4]]) b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) + b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 2], [3, 2]]) + b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) print("test gather success") @jt.flag_scope(use_acl=1) def test_gather_grad(self): a = jt.float32([[1, 2], [3, 4]]) optimizer = jt.optim.SGD([a], 0.1) - b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) loss = b.sum() optimizer.zero_grad() optimizer.backward(loss) optimizer.step() res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) + np.testing.assert_allclose(res.numpy(), [[1, 2], [1, 0]]) print("test gather grad success") + + @jt.flag_scope(use_acl=1) 
+ def test_gather_grad_neg(self): + a = jt.float32([[4, 3], [2, 1]]) + optimizer = jt.optim.SGD([a], 0.1) + b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) + print("test gather grad neg success") @jt.flag_scope(use_acl=1) def test_scatter(self): From 7cf4732904831fc1d72de57fb15cf6b2c6ad3a09 Mon Sep 17 00:00:00 2001 From: Yuxuan Han Date: Mon, 12 Aug 2024 19:28:01 +0800 Subject: [PATCH 62/73] complement of test_aclop --- python/jittor/test/test_aclop.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index cb2235fc..647a589f 100644 --- a/python/jittor/test/test_aclop.py +++ b/python/jittor/test/test_aclop.py @@ -341,12 +341,20 @@ def test_gather_grad_neg(self): print("test gather grad neg success") @jt.flag_scope(use_acl=1) - def test_scatter(self): + def test_scatter_add(self): a = jt.array([[1, 2], [3, 4]]) b = jt.array([[0, 0], [0, 0]]) b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]]) - print("test scatter success") + print("test scatter add success") + + @jt.flag_scope(use_acl=1) + def test_scatter_multi(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + b = jt.scatter(b, 0, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") + np.testing.assert_allclose(b.numpy(), [[5, 48], [21, 8]]) + print("test scatter multiply success") @jt.flag_scope(use_acl=1) def test_scatter_grad(self): From 4159ace94137aeeca14c13d7dc13eb1596a518b7 Mon Sep 17 00:00:00 2001 From: Yuxuan Han Date: Mon, 12 Aug 2024 19:50:29 +0800 Subject: [PATCH 63/73] complement of test_aclop: error of scatter()-multiple and where() --- python/jittor/test/test_aclop.py | 48 ++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index 647a589f..41a6d748 100644 --- a/python/jittor/test/test_aclop.py +++ b/python/jittor/test/test_aclop.py @@ -357,7 +357,7 @@ def test_scatter_multi(self): print("test scatter multiply success") @jt.flag_scope(use_acl=1) - def test_scatter_grad(self): + def test_scatter_add_grad(self): a = jt.float32([[1, 2], [3, 4]]) b = jt.float32([[0, 0], [0, 0]]) optimizer = jt.optim.SGD([a, b], 0.1) @@ -370,7 +370,23 @@ def test_scatter_grad(self): res_b = b.opt_grad(optimizer) np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) - print("test scatter grad success") + print("test scatter add grad success") + + @jt.flag_scope(use_acl=1) + def test_scatter_mult_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.float32([[5, 6], [7, 8]]) + optimizer = jt.optim.SGD([a, b], 0.1) + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") + loss = c.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 6], [0, 6]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 8], [0, 0]]) + print("test scatter mult grad success") @jt.flag_scope(use_acl=1) def test_where(self): @@ -380,8 +396,34 @@ def test_where(self): np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]]) print("test where success") + @jt.flag_scope(use_acl=1) + def 
test_where_2(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + cond = jt.array([[1, 0], [0, 1]]) + c = jt.where(cond, a, b) + np.testing.assert_allclose(c.numpy(), [[1, 6], [7, 4]]) + print("test where_2 success") + @jt.flag_scope(use_acl=1) def test_where_grad(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + cond = jt.array([[1, 0], [0, 1]]) + c = jt.where(cond, a, b) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = c.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) + print("test where grad success") + + @jt.flag_scope(use_acl=1) + def test_where_grad_2(self): a = jt.float32([[1, 2], [3, 4]]) b = jt.array([[2., 2.], [2., 2.]]) c = jt.where(a > 2, a, b) @@ -394,7 +436,7 @@ def test_where_grad(self): res_b = b.opt_grad(optimizer) np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) - print("test where grad success") + print("test where grad 2 success") @jt.flag_scope(use_acl=1) def test_flip(self): From 466722519d1f447aa0b1839e542cd6e358ba5c8b Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Tue, 20 Aug 2024 15:08:19 +0800 Subject: [PATCH 64/73] add several ffunctions --- python/jittor/__init__.py | 1 + python/jittor/__init__.pyi | 2 +- python/jittor/gradfunctional/__init__.py | 2 + python/jittor/gradfunctional/functional.py | 420 +++++++++++++++++++++ python/jittor/nn.py | 131 +++++++ python/jittor/test/test_complex.py | 317 +++++++++++++++- 6 files changed, 856 insertions(+), 17 deletions(-) create mode 100644 python/jittor/gradfunctional/__init__.py create mode 100644 python/jittor/gradfunctional/functional.py diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index dc17a228..18838634 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -2140,6 +2140,7 @@ def is_var(v): from . import optim from . import dataset from . import init +from . import gradfunctional dtype = NanoString diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi index b849af4c..6cfce692 100644 --- a/python/jittor/__init__.pyi +++ b/python/jittor/__init__.pyi @@ -1,7 +1,7 @@ from jittor_core import * from jittor_core.ops import * from .misc import * -from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse +from . 
import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops from .contrib import concat as concat diff --git a/python/jittor/gradfunctional/__init__.py b/python/jittor/gradfunctional/__init__.py new file mode 100644 index 00000000..259897e1 --- /dev/null +++ b/python/jittor/gradfunctional/__init__.py @@ -0,0 +1,2 @@ +from .functional import jvp, vjp + diff --git a/python/jittor/gradfunctional/functional.py b/python/jittor/gradfunctional/functional.py new file mode 100644 index 00000000..df183cc1 --- /dev/null +++ b/python/jittor/gradfunctional/functional.py @@ -0,0 +1,420 @@ +# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/autograd/functional.py +import jittor as jt + +__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"] + +# Utility functions +def _as_tuple_nocheck(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return (x,) + +def _as_tuple(inp, arg_name=None, fn_name=None): + # Ensures that inp is a tuple of Tensors + # Returns whether or not the original inp was a tuple and the tupled version of the input + if arg_name is None and fn_name is None: + return _as_tuple_nocheck(inp) + + is_inp_tuple = True + if not isinstance(inp, tuple): + inp = (inp,) + is_inp_tuple = False + + for i, el in enumerate(inp): + if not isinstance(el, (jt.Var, jt.nn.ComplexNumber)): + if is_inp_tuple: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" value at index {i} has type {type(el)}." + ) + else: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" given {arg_name} has type {type(el)}." + ) + + return is_inp_tuple, inp + + +def _tuple_postprocess(res, to_unpack): + # Unpacks a potentially nested tuple of Tensors + # to_unpack should be a single boolean or a tuple of two booleans. + # It is used to: + # - invert _as_tuple when res should match the inp given to _as_tuple + # - optionally remove nesting of two tuples created by multiple calls to _as_tuple + if isinstance(to_unpack, tuple): + assert len(to_unpack) == 2 + if not to_unpack[1]: + res = tuple(el[0] for el in res) + if not to_unpack[0]: + res = res[0] + else: + if not to_unpack: + res = res[0] + return res + + +def _grad_preprocess(inputs, create_graph, need_graph): + # Preprocess the inputs to make sure they require gradient + # inputs is a tuple of Tensors to preprocess + # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs + # need_graph specifies if we internally want gradients to flow back to the Tensors in res + # Note that we *always* create a new Tensor object to be able to see the difference between + # inputs given as arguments and the same Tensors automatically captured by the user function. 
+ # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576 + res = [] + for inp in inputs: + if create_graph and inp.requires_grad: + # Create at least a new Tensor object in a differentiable way + # Use .reshae() to get a shallow copy + res.append(inp.reshape(inp.shape)) + else: + if need_graph: + ninp = inp.detach().start_grad() + else: + ninp = inp.detach().stop_grad() + res.append(ninp) + return tuple(res) + + +def _grad_postprocess(inputs, create_graph): + # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not + # request it. + if isinstance(inputs[0], (jt.Var, jt.nn.ComplexNumber)): + if not create_graph: + return tuple(inp.detach() for inp in inputs) + else: + return inputs + else: + return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) + + +def _validate_v(v, other, is_other_tuple): + # This assumes that other is the correct shape, and v should match + # Both are assumed to be tuples of Tensors + if len(other) != len(v): + if is_other_tuple: + raise RuntimeError( + f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}." + ) + else: + raise RuntimeError("The given v should contain a single Tensor.") + + for idx, (el_v, el_other) in enumerate(zip(v, other)): + if el_v.shape != el_other.shape: + prepend = "" + if is_other_tuple: + prepend = f"Entry {idx} in " + raise RuntimeError( + f"{prepend}v has invalid size: should be {el_other.shape} but got {el_v.shape}." + ) + + +def _check_requires_grad(inputs, input_type, strict): + # Used to make all the necessary checks to raise nice errors in strict mode. + if not strict: + return + + if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]: + raise RuntimeError("Invalid input_type to _check_requires_grad") + for i, inp in enumerate(inputs): + if inp is None: + # This can only be reached for grad_inputs. + raise RuntimeError( + f"The output of the user-provided function is independent of input {i}." + " This is not allowed in strict mode." + ) + if not inp.requires_grad: + if input_type == "hessian": + raise RuntimeError( + f"The hessian of the user-provided function with respect to input {i}" + " is independent of the input. This is not allowed in strict mode." + " You should ensure that your function is thrice differentiable and that" + " the hessian depends on the inputs." + ) + elif input_type == "jacobian": + raise RuntimeError( + "While computing the hessian, found that the jacobian of the user-provided" + f" function with respect to input {i} is independent of the input. This is not" + " allowed in strict mode. You should ensure that your function is twice" + " differentiable and that the jacobian depends on the inputs (this would be" + " violated by a linear function for example)." + ) + elif input_type == "grad_inputs": + raise RuntimeError( + f"The gradient with respect to input {i} is independent of the inputs of the" + " user-provided function. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + f"Output {i} of the user-provided function does not require gradients." + " The outputs must be computed in a differentiable manner from the input" + " when running in strict mode." + ) + + +def _autograd_grad( + outputs, + inputs, + grad_outputs=None, + create_graph=True, +): + # Version of grad that accepts `None` in outputs and do not compute gradients for them. 
+ # This has the extra constraint that inputs has to be a tuple + assert isinstance(outputs, tuple) + if grad_outputs is None: + grad_outputs = (None,) * len(outputs) + assert isinstance(grad_outputs, tuple) + assert len(outputs) == len(grad_outputs) + + new_outputs = () + new_grad_outputs = () + for out, grad_out in zip(outputs, grad_outputs): + if out is not None and out.requires_grad: + new_outputs += (out,) + new_grad_outputs += (grad_out,) + + if len(new_outputs) == 0: + # No differentiable output, we don't need to call the autograd engine + return (None,) * len(inputs) + else: + acc_loss = None + for new_output, grad_output in zip(new_outputs, grad_outputs): + if isinstance(new_output, jt.nn.ComplexNumber): + if grad_output is not None: + loss = (new_output.value * grad_output.value).sum() + else: + loss = new_output.value.sum() + else: + if grad_output is not None: + new_output = new_output * grad_output + loss = new_output.sum() + if acc_loss is None: + acc_loss = loss + else: + acc_loss += loss + + complex_inds = [] + var_inputs = [] + for idx, inp in enumerate(inputs): + if isinstance(inp, jt.nn.ComplexNumber): + var_inputs.append(inp.value) + complex_inds.append(idx) + else: + var_inputs.append(inp) + + grads = jt.grad(acc_loss, var_inputs, retain_graph=create_graph) + for complex_ind in complex_inds: + grads[complex_ind] = jt.nn.ComplexNumber(grads[complex_ind], is_concat_value=True) + return tuple(grads) + + +def _fill_in_zeros(grads, refs, strict, create_graph, stage): + # Used to detect None in the grads and depending on the flags, either replace them + # with Tensors full of 0s of the appropriate size based on the refs or raise an error. + # strict and create graph allow us to detect when it is appropriate to raise an error + # stage gives us information of which backward call we consider to give good error message + if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: + raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros") + + res = () + for i, grads_i in enumerate(grads): + if grads_i is None: + if strict: + if stage == "back": + raise RuntimeError( + "The output of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + elif stage == "back_trick": + raise RuntimeError( + f"The gradient with respect to the input is independent of entry {i}" + " in the grad_outputs when using the double backward trick to compute" + " forward mode gradients. This is not allowed in strict mode." + ) + elif stage == "double_back": + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"entry {i} in the grad_jacobian. This is not allowed in strict " + "mode as it prevents from using the double backward trick to " + "replace forward mode AD." + ) + + refs_i = refs[i] + if isinstance(refs_i, jt.nn.ComplexNumber): + grads_i = jt.nn.ComplexNumber(jt.zeros_like(refs_i.value), is_concat_value=True) + else: + grads_i = jt.zeros_like(refs_i) + else: + if strict and create_graph and not grads_i.requires_grad: + if "double" not in stage: + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode when create_graph=True." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"input {i}. 
This is not allowed in strict mode when create_graph=True." + ) + + res += (grads_i,) + + return res + + +# Public API + +def vjp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the vector + Jacobian product is computed. Must be the same size as the output + of ``func``. This argument is optional when the output of ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + vjp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + vjp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the inputs. + """ + with jt.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vjp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if v is not None: + _, v = _as_tuple(v, "v", "vjp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, outputs, is_outputs_tuple) + else: + if len(outputs) != 1 or outputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the " + "user-provided function returns " + "a single Tensor with a single element." + ) + + with jt.enable_grad(): + grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph) + vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + vjp = _grad_postprocess(vjp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vjp, is_inputs_tuple + ) + + +def jvp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the Jacobian + vector product is computed. Must be the same size as the input of + ``func``. This argument is optional when the input to ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. 
+ create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + jvp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + jvp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the output. + + """ + with jt.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + if v is not None: + _, v = _as_tuple(v, "v", "jvp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, inputs, is_inputs_tuple) + else: + if len(inputs) != 1 or inputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the input to " + "the user-provided function is a single Tensor " + "with a single element." + ) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jvp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + # The backward is linear so the value of grad_outputs is not important as + # it won't appear in the double backward graph. We only need to ensure that + # it does not contain inf or nan. + grad_outputs = tuple( + jt.nn.ComplexNumber(jt.zeros_like(out.value), is_concat_value=True) if isinstance(out, jt.nn.ComplexNumber) else jt.zeros_like(out) + for out in outputs + ) + + grad_inputs = _autograd_grad(outputs, inputs, grad_outputs=grad_outputs, create_graph=True) + _check_requires_grad(grad_inputs, "grad_inputs", strict=strict) + + if create_graph: + with jt.enable_grad(): + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + else: + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + jvp = _grad_postprocess(jvp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + jvp, is_outputs_tuple + ) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index cafaf352..aad40f5c 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -3130,6 +3130,10 @@ def __init__(self, real: jt.Var, imag: jt.Var=None, is_concat_value=False): assert real.dtype == imag.dtype self.value = jt.stack([real, imag], dim=-1) + @property + def requires_grad(self): + return self.value.requires_grad + @property def real(self): return self.value[..., 0] @@ -3142,6 +3146,10 @@ def imag(self): def shape(self): return self.value.shape[:-1] + @property + def dtype(self): + return "complex64" + def norm(self): return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag)) @@ -3287,6 +3295,129 @@ def view_as_complex(x: jt.Var) -> ComplexNumber: def view_as_real(x: ComplexNumber) -> jt.Var: return jt.stack([x.value[...,0],x.value[...,1]],dim=-1) +# reference: 
https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/functional.py#L1258 +def tensordot(a, b, dims=2): + r"""Returns a contraction of a and b over multiple dimensions. + + :attr:`tensordot` implements a generalized matrix product. + + Args: + a (Tensor): Left tensor to contract + b (Tensor): Right tensor to contract + dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to + contract or explicit lists of dimensions for :attr:`a` and + :attr:`b` respectively + + When called with a non-negative integer argument :attr:`dims` = :math:`d`, and + the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, + respectively, :func:`tensordot` computes + + .. math:: + r_{i_0,...,i_{m-d}, i_d,...,i_n} + = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}. + + When called with :attr:`dims` of the list form, the given dimensions will be contracted + in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes + in these dimensions must match. + + """ + if not isinstance(dims, (tuple, list, int)): + raise RuntimeError( + "tensordot expects dims to be int or " + + "Tuple[List[int], List[int]] or " + + "List[List[int]] containing two lists, but got " + + f"dims={dims}" + ) + + dims_a, dims_b = [], [] + + if isinstance(dims, (tuple, list)): + dims_a, dims_b = dims + + if isinstance(dims, (int)): + if dims < 0: + raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") + if dims > min(len(a.shape), len(b.shape)): + raise RuntimeError( + f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}" + ) + dims_a = list(range(len(a.shape)-dims, len(a.shape))) + dims_b = list(range(dims)) + + # reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/aten/src/ATen/native/Linear.cpp#L769 + def __tensordot_native(input1:jt.Var, input2:jt.Var, dims1, dims2): + if not isinstance(dims1, (list, tuple)): + raise RuntimeError("tensordot expects dims1 to be List[Int], but got dims={}".format(dims1)) + if not isinstance(dims2, (list, tuple)): + raise RuntimeError("tensordot expects dims2 to be List[Int], but got dims={}".format(dims2)) + dims1 = list(dims1) + dims2 = list(dims2) + if len(dims1) != len(dims2): + raise RuntimeError("both dimension lists should have the same length") + if input1.dtype != input2.dtype: + raise RuntimeError("both inputs should have the same dtype") + t1 = input1 + t2 = input2 + csize = 1 + input1_bitmap = np.zeros(len(input1.shape), dtype='bool') + input2_bitmap = np.zeros(len(input2.shape), dtype='bool') + for i in range(len(dims1)): + s1 = input1.shape[dims1[i]] + s2 = input2.shape[dims2[i]] + input1_bitmap[dims1] = True + input2_bitmap[dims2] = True + if s2 == 1: #broadcasted dimensions can be summed right away + t1 = t1.sum(dims1[i], keepdims=True) + elif s1 == 1: + t2 = t2.sum(dims2[i], keepdims=True) + else: + if s1 != s2: + raise RuntimeError("contracted dimensions need to match, but first has size {}, in dim {}, and second has size {}".format(s1, i, s2)) + csize *= s1 + + p1, p2 = [], [] # p1, p2: input permutations + rsizes = [] + size1, size2 = 1, 1 # number of non-contracted elements + for i in range(len(input1.shape)): + if not input1_bitmap[i]: + p1.append(i) + size1 *= t1.shape[i] + rsizes.append(t1.shape[i]) + p1 += dims1 + p2 += dims2 + for i in range(len(input2.shape)): + if not input2_bitmap[i]: + p2.append(i) + size2 *= 
t2.shape[i] + rsizes.append(t2.shape[i]) + + # permute and reshape for matrix multiplication + t1 = t1.permute(p1).reshape((size1, csize)) + t2 = t2.permute(p2).reshape((csize, size2)) + # multiply and reshape to target size + return jt.matmul(t1, t2).reshape(rsizes) + + return __tensordot_native(a, b, dims_a, dims_b) + +# reference: https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/aten/src/ATen/native/LinearAlgebra.cpp#L3375 +def kron(a:jt.Var, b:jt.Var): + a_dim, b_dim = len(a.shape), len(b.shape) + max_dim = max(a_dim, b_dim) + pad_a, pad_b = max_dim-a_dim, max_dim-b_dim + a_reshape, b_reshape = [], [] + result_reshape = [] + for i in range(max_dim): + a_2i_shape = a.shape[i - pad_a] if i >= pad_a else 1 + b_2i1_shape = b.shape[i - pad_b] if i >= pad_b else 1 + a_reshape.append(a_2i_shape) + a_reshape.append(1) + b_reshape.append(1) + b_reshape.append(b_2i1_shape) + result_reshape.append(a_2i_shape * b_2i1_shape) + a = a.reshape(a_reshape) + b = b.reshape(b_reshape) + return (a * b).reshape(result_reshape) + def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var: ''' Returns the one_hot encoding of inputs. diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py index 19686009..4c1a6c53 100644 --- a/python/jittor/test/test_complex.py +++ b/python/jittor/test/test_complex.py @@ -2,6 +2,7 @@ from jittor.nn import ComplexNumber import unittest import numpy as np +from functools import partial __skip_torch_test = False try: @@ -10,6 +11,15 @@ __skip_torch_test = True class TestResultAndGrad: + def flatten_list(self, list_like): + results = [] + if isinstance(list_like, (list, tuple)): + for x in list_like: + results.extend(self.flatten_list(x)) + return results + else: + return [list_like] + def check_results(self, rlist1, rlist2): assert len(rlist1) == len(rlist2) for r1, r2 in zip(rlist1, rlist2): @@ -36,13 +46,21 @@ def grad_torch(self, inputs, losses): grads.append(g.detach().cpu().numpy()) return grads - def run_jittor_op(self, op, input_list, weights=None): - def _np_to_jittor(x:np.ndarray): - if x.dtype == np.complex64 or x.dtype == np.complex128: - nx = np.stack([np.real(x), np.imag(x)], axis=-1) - return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True) - elif x.dtype == np.float32 or x.dtype == np.float64: - return jt.array(x, dtype=jt.float32) + def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs): + def _np_to_jittor(x): + if isinstance(x, np.ndarray): + if x.dtype == np.complex64 or x.dtype == np.complex128: + nx = np.stack([np.real(x), np.imag(x)], axis=-1) + return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True) + elif x.dtype == np.float32 or x.dtype == np.float64: + return jt.array(x, dtype=jt.float32) + else: + assert False + elif isinstance(x, (list, tuple)): + nx = [_np_to_jittor(vx) for vx in x] + if isinstance(x, tuple): + return tuple(nx) + return nx else: assert False def _jittor_to_np(x): @@ -51,11 +69,19 @@ def _jittor_to_np(x): elif isinstance(x, ComplexNumber): return x.real.numpy() + 1j * x.imag.numpy() assert False - ninput_list = [_np_to_jittor(x) for x in input_list] - output_list = op(*ninput_list) + + if key_names != None: + assert len(ninput_list) == len(key_names) + nkwargs = kwargs.copy() + for k, v in zip(key_names, ninput_list): + nkwargs[k] = v + output_list = op(**nkwargs) + else: + output_list = op(*ninput_list, **kwargs) if isinstance(output_list, (jt.Var, ComplexNumber)): output_list = [output_list] + output_list = 
self.flatten_list(output_list) losses = [] if weights is None: weights = [] @@ -73,15 +99,31 @@ def _jittor_to_np(x): output_list = [_jittor_to_np(x) for x in output_list] return ninput_list, output_list, losses, weights - def run_torch_op(self, op, input_list, weights=None): - def _np_to_torch(x:np.ndarray): - return torch.from_numpy(x).requires_grad_(True) + def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs): + def _np_to_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x).requires_grad_(True) + elif isinstance(x, (list, tuple)): + nx = [_np_to_torch(vx) for vx in x] + if isinstance(x, tuple): + return tuple(nx) + return nx + else: + assert False def _torch_to_np(x:torch.Tensor) -> np.ndarray: return x.detach().cpu().numpy() ninput_list = [_np_to_torch(x) for x in input_list] - output_list = op(*ninput_list) + if key_names != None: + assert len(ninput_list) == len(key_names) + nkwargs = kwargs.copy() + for k, v in zip(key_names, ninput_list): + nkwargs[k] = v + output_list = op(**nkwargs) + else: + output_list = op(*ninput_list, **kwargs) if isinstance(output_list, torch.Tensor): output_list = [output_list] + output_list = self.flatten_list(output_list) losses = [] if weights is None: weights = [] @@ -99,10 +141,10 @@ def _torch_to_np(x:torch.Tensor) -> np.ndarray: output_list = [_torch_to_np(x) for x in output_list] return ninput_list, output_list, losses, weights - def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True): + def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs): weights = None - jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights) - torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights) + jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs) + torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs) self.check_results(jittor_output, torch_output) if check_grad: @@ -195,6 +237,249 @@ def test_complex_svd_batch(self): inputs = [m1] self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) +class TestTensordot(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def random_real_matrix(self, shape): + return np.random.randn(*shape) + + def test_complex_tensordot_numberdim(self): + s1 = (3, 4, 5) + s2 = (4, 5, 6) + dims = 2 + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) + + def test_complex_tensordot_tupledim(self): + s1 = (3, 5, 4, 6) + s2 = (6, 4, 5, 3) + dims = ([2, 1, 3], [1, 2, 0]) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) + + def test_real_tensordot_numberdim(self): + s1 = (3, 4, 5) + s2 = (4, 5, 6) + dims = 2 + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) + + def test_real_tensordot_tupledim(self): + s1 = (3, 5, 4, 6) + s2 = (6, 4, 5, 3) + dims = ([2, 1, 3], [1, 2, 0]) + 
m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) + +class TestKron(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def random_real_matrix(self, shape): + return np.random.randn(*shape) + + def test_complex_firstlarge(self): + s1 = (2, 3, 4) + s2 = (5, 2) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) + + def test_complex_second_large(self): + s1 = (2, 3) + s2 = (5, 2, 4) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) + + def test_real_firstlarge(self): + s1 = (2, 3, 4) + s2 = (5, 2) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) + + def test_real_second_large(self): + s1 = (2, 3) + s2 = (5, 2, 4) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) + +@unittest.skipIf(__skip_torch_test, "No Torch found") +class TestGradFunctional(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def random_real_matrix(self, shape): + return np.random.randn(*shape) * 0.0 + 1.0 + + def test_real_jvp_exp(self): + def exp_reducer(x): + return x.exp().sum(dim=1) + s1 = (5, 6) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s1) + inputs = [m1, m2] + self.check_op_with_torch( + partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True), + partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False) + + def test_complex_jvp_exp(self): + def exp_reducer(x): + return x.exp().sum(1) + s1 = (5, 6) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s1) + inputs = [m1, m2] + self.check_op_with_torch( + partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True), + partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) + + def test_real_jvp_add(self): + w1, w2 = np.random.rand(), np.random.rand() + def adder(x, y): + return w1 * x + w2 * y + s1 = (5, 6) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s1) + m3 = self.random_real_matrix(s1) + m4 = self.random_real_matrix(s1) + inputs = [(m1, m2), (m3, m4)] + self.check_op_with_torch( + partial(jt.gradfunctional.jvp, func=adder, create_graph=True), + partial(torch.autograd.functional.jvp, func=adder, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) + + def test_complex_jvp_add(self): + w1r, w1i = np.random.rand(), np.random.rand() + w2r, w2i = np.random.rand(), np.random.rand() + def adder_pt(x, y): + return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y + def adder_jt(x, y): + w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1)) + w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = 
jt.array(w2i).reshape(1,1)) + return w1 * x + w2 * y + s1 = (5, 6) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s1) + m3 = self.random_complex_matrix(s1) + m4 = self.random_complex_matrix(s1) + inputs = [(m1, m2), (m3, m4)] + self.check_op_with_torch( + partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True), + partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) + + def test_real_vjp_exp(self): + def exp_reducer(x): + return x.exp().sum(dim=1) + s1 = (5, 6) + s2 = (5,) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch( + partial(jt.gradfunctional.vjp, func=exp_reducer), + partial(torch.autograd.functional.vjp, func=exp_reducer), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False) + + def test_complex_vjp_exp(self): + def exp_reducer(x): + return x.exp().sum(1) + s1 = (5, 6) + s2 = (5,) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s2) + inputs = [m1, m2] + self.check_op_with_torch( + partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True), + partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) + + def test_real_vjp_add(self): + w1, w2 = np.random.rand(), np.random.rand() + def adder(x, y): + return w1 * x + w2 * y + s1 = (5, 6) + m1 = self.random_real_matrix(s1) + m2 = self.random_real_matrix(s1) + m3 = self.random_real_matrix(s1) + inputs = [(m1, m2), m3] + self.check_op_with_torch( + partial(jt.gradfunctional.vjp, func=adder, create_graph=True), + partial(torch.autograd.functional.vjp, func=adder, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) + + def test_complex_vjp_add(self): + w1r, w1i = np.random.rand(), np.random.rand() + w2r, w2i = np.random.rand(), np.random.rand() + def adder_pt(x, y): + return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y + def adder_jt(x, y): + w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1)) + w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1)) + return w1 * x + w2 * y + s1 = (5, 6) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s1) + m3 = self.random_complex_matrix(s1) + inputs = [(m1, m2), (m3)] + self.check_op_with_torch( + partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True), + partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True), + inputs, + jittor_knames = ['inputs', 'v'], + torch_knames = ['inputs', 'v'], + check_grad=False, + ) if __name__ == "__main__": unittest.main() From a55e2960534630b76792c6d7930528c521d3275b Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Thu, 22 Aug 2024 12:53:03 +0800 Subject: [PATCH 65/73] fix unique bug --- python/jittor/__init__.py | 24 ++++++++++------ python/jittor/misc.py | 51 ++++++++++++++++++++++++++++++++++ python/jittor/nn.py | 14 ++++++++-- python/jittor/test/test_var.py | 27 ++++++++++++++++++ 4 files changed, 105 insertions(+), 11 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 18838634..3d9e7aa4 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -652,14 +652,22 @@ def var(x, dim=None, dims=None, 
unbiased=False, keepdims=False): return sqr Var.var = var -def std(x): - matsize=1 - for i in x.shape: - matsize *= i - out=(x-x.mean()).sqr().sum() - out=out/(matsize-1) - out=out.maximum(1e-6).sqrt() - return out +def std(x, dim=None, keepdim=False): + if dim is None: + matsize=1 + for i in x.shape: + matsize *= i + out=(x-x.mean()).sqr().sum() + out=out/(matsize-1) + out=out.maximum(1e-6).sqrt() + return out + else: + dimsize=x.size(dim) + mean=jt.mean(x, dim, keepdim=True) + out=(x - mean).sqr().sum(dim=dim, keepdim=keepdim) + out=out/(dimsize-1) + out=out.maximum(1e-6).sqrt() + return out Var.std = std def norm(x, p=2, dim=-1, keepdims=False, eps=1e-30, keepdim=False): diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 2760f9d4..8561393e 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -623,6 +623,10 @@ def unique( #include #include + #include + #include + #include + #include #include ''', @@ -705,6 +709,11 @@ def unique( #include #include #include + + #include + #include + #include + #include #include @@ -923,6 +932,48 @@ def diag(x,diagonal=0): output_shape = (x.shape[0]-d,) return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0']) +# reference: https://github.com/pytorch/pytorch/blob/25d5a815f74db80ef19a3f714709b55b05675245/torch/_refs/__init__.py +def diagonal(x, offset=0, dim1=0, dim2=1): + def __normalize_dim(d, rank): + if d < 0: + d += rank + if d < 0 or d >= rank: + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {d})" + raise IndexError(msg) + return d + assert x.ndim >= 2, f"diagonal dimensions requires ndim larger than 2, but got {x.ndim}" + dim1 = __normalize_dim(dim1, x.ndim) + dim2 = __normalize_dim(dim2, x.ndim) + assert dim1 != dim2, f"diagonal dimensions cannot be identical {dim1}, {dim2}" + + if offset >= 0: + diag_size = max(min(x.shape[dim1], x.shape[dim2] - offset), 0) + else: + diag_size = max(min(x.shape[dim1] + offset, x.shape[dim2]), 0) + + sizes = [] + indices = [] + lsizes = 0 + dim_diag = x.ndim - 2 + abs_offset = offset if offset >= 0 else -offset + for i, s in enumerate(x.shape): + if i == dim1: + if offset >= 0: + indices.append(f"i{dim_diag}") + else: + indices.append(f"i{dim_diag}+{abs_offset}") + elif i == dim2: + if offset >= 0: + indices.append(f"i{dim_diag}+{abs_offset}") + else: + indices.append(f"i{dim_diag}") + else: + indices.append(f"i{lsizes}") + sizes.append(s) + lsizes += 1 + out_shape = tuple(sizes + [diag_size]) + return x.reindex(out_shape, indices) + jt.Var.diag = diag diff --git a/python/jittor/nn.py b/python/jittor/nn.py index aad40f5c..2080a1c8 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -2255,22 +2255,30 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner class Upsample(Module): - def __init__(self, scale_factor=None, mode='nearest'): + def __init__(self, scale_factor=None, mode='nearest', align_corners=False): if isinstance(scale_factor, tuple): self.scale_factor = tuple(float(factor) for factor in scale_factor) else: self.scale_factor = float(scale_factor) if scale_factor else None self.mode = mode - + self.align_corners = align_corners + def execute(self, x): if self.scale_factor is None: raise ValueError("scale_factor should be defined") + elif isinstance(self.scale_factor, float): + return upsample(x, + size=(int(x.shape[2]*self.scale_factor), + int(x.shape[3]*self.scale_factor)), + mode=self.mode, + align_corners=self.align_corners) else: return upsample(x, 
size=( int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), - mode=self.mode) + mode=self.mode, + align_corners=self.align_cornerss) class UpsamplingBilinear2d(Upsample): def __init__(self, scale_factor=None): diff --git a/python/jittor/test/test_var.py b/python/jittor/test/test_var.py index 116df6c7..c18b041b 100644 --- a/python/jittor/test/test_var.py +++ b/python/jittor/test/test_var.py @@ -46,6 +46,33 @@ def test_norm(self): np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6) np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6) + def test_std_with_dim(self): + x=np.random.randn(100, 1000).astype(np.float32) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.std(dim=-1).numpy(), tc_x.std(dim=-1).numpy(), 1e-4) + np.testing.assert_allclose(jt_x.std(dim=0, keepdim=True).numpy(), tc_x.std(dim=0, keepdim=True).numpy(), 1e-4) + + def test_diagonal(self): + x = np.reshape(np.arange(5*6*7*8), (5,6,7,8)) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + def __assert_equal(a:np.ndarray, b:np.ndarray, rtol=1e-6, atol=1e-6): + assert a.shape == b.shape, f"{a.shape}!={b.shape}" + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + __assert_equal(jt.misc.diagonal(jt_x, 0, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=0, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, -1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-1, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, -2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-2, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, -6, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-6, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=2, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 7, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=7, dim1=1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=-2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-2, dim2=-1).numpy(), tc_x.diagonal(offset=1, dim1=-2, dim2=-1).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=0, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=0, dim2=-2).numpy()) + __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=2, dim2=1).numpy(), tc_x.diagonal(offset=1, dim1=2, dim2=1).numpy()) + if __name__ == "__main__": unittest.main() From a725f200423a5c0f72e3f099dfb8c4abe779675d Mon Sep 17 00:00:00 2001 From: liylo <2813164552@qq.com> Date: Wed, 28 Aug 2024 20:50:37 +0800 Subject: [PATCH 66/73] fix load --- python/jittor/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 3d9e7aa4..7df64b54 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -1619,6 +1619,12 @@ def load_parameters(self, params): else: end=1 break + elif isinstance(v, nn.ParameterList): + if k in v.keys(): + v = v[k] + else: + end = 1 + break else: if hasattr(v, k): v = getattr(v, k) From 3b63820624a82feb5de3898f57e1c02f73f6e393 Mon Sep 17 00:00:00 2001 From: liylo <2813164552@qq.com> Date: Wed, 28 Aug 2024 21:13:00 +0800 Subject: [PATCH 67/73] simple implementation for 
block diag --- python/jittor/__init__.py | 6 ----- python/jittor/contrib.py | 47 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 7df64b54..3d9e7aa4 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -1619,12 +1619,6 @@ def load_parameters(self, params): else: end=1 break - elif isinstance(v, nn.ParameterList): - if k in v.keys(): - v = v[k] - else: - end = 1 - break else: if hasattr(v, k): v = getattr(v, k) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index c4026eee..040f25fe 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -57,6 +57,53 @@ def concat(arr, dim): cdim += a.shape[dim] return s +def block_diag(*tensors): + """Create a block diagonal matrix from provided tensors. + + Args: + *tensors: One or more tensors with 0, 1, or 2 dimensions. + + Returns: + Tensor: A 2 dimensional tensor with all the input tensors arranged in + order such that their upper left and lower right corners are + diagonally adjacent. All other elements are set to 0. + """ + rows = 0 + cols = 0 + for tensor in tensors: + shape = tensor.shape + if len(shape) == 0: # 0-d tensor + rows += 1 + cols += 1 + elif len(shape) == 1: # 1-d tensor + rows += 1 + cols += shape[0] + elif len(shape) == 2: # 2-d tensor + rows += shape[0] + cols += shape[1] + + result = jt.zeros((rows, cols)) + result.requires_grad = True + + current_row = 0 + current_col = 0 + for tensor in tensors: + shape = tensor.shape + if len(shape) == 0: # 0-d tensor + result[current_row, current_col] = tensor + current_row += 1 + current_col += 1 + elif len(shape) == 1: # 1-d tensor + result[current_row, current_col:current_col + shape[0]] = tensor + current_row += 1 + current_col += shape[0] + elif len(shape) == 2: # 2-d tensor + result[current_row:current_row + shape[0], current_col:current_col + shape[1]] = tensor + current_row += shape[0] + current_col += shape[1] + + return result + def check(bc): bc = np.array(bc) if ((bc != 1) * (bc != bc.max(0))).sum() > 0: From 362d09c08b9ce354f5f43b13ba488bd019227a43 Mon Sep 17 00:00:00 2001 From: liylo <2813164552@qq.com> Date: Wed, 28 Aug 2024 21:18:56 +0800 Subject: [PATCH 68/73] simple implementation for block diag with proper grad --- python/jittor/contrib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 040f25fe..2495629c 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -68,6 +68,8 @@ def block_diag(*tensors): order such that their upper left and lower right corners are diagonally adjacent. All other elements are set to 0. 
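        Example (a minimal usage sketch, assuming block_diag is called as it is
        added here in jittor.contrib)::

            import jittor as jt
            from jittor.contrib import block_diag

            a = jt.ones((2, 2))            # 2-d block
            b = jt.array([1.0, 2.0, 3.0])  # 1-d block is placed as a 1x3 row
            m = block_diag(a, b)           # shape (3, 5); all other entries are 0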
""" + requires_grad = tensors[0].requires_grad + rows = 0 cols = 0 for tensor in tensors: @@ -83,7 +85,7 @@ def block_diag(*tensors): cols += shape[1] result = jt.zeros((rows, cols)) - result.requires_grad = True + result.requires_grad = requires_grad current_row = 0 current_col = 0 From f56ddbd5a51358ab0fb782e20047468c87781bed Mon Sep 17 00:00:00 2001 From: liylo <2813164552@qq.com> Date: Wed, 28 Aug 2024 21:27:02 +0800 Subject: [PATCH 69/73] init --- python/jittor/contrib.py | 49 ---------------------------------------- 1 file changed, 49 deletions(-) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 2495629c..c4026eee 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -57,55 +57,6 @@ def concat(arr, dim): cdim += a.shape[dim] return s -def block_diag(*tensors): - """Create a block diagonal matrix from provided tensors. - - Args: - *tensors: One or more tensors with 0, 1, or 2 dimensions. - - Returns: - Tensor: A 2 dimensional tensor with all the input tensors arranged in - order such that their upper left and lower right corners are - diagonally adjacent. All other elements are set to 0. - """ - requires_grad = tensors[0].requires_grad - - rows = 0 - cols = 0 - for tensor in tensors: - shape = tensor.shape - if len(shape) == 0: # 0-d tensor - rows += 1 - cols += 1 - elif len(shape) == 1: # 1-d tensor - rows += 1 - cols += shape[0] - elif len(shape) == 2: # 2-d tensor - rows += shape[0] - cols += shape[1] - - result = jt.zeros((rows, cols)) - result.requires_grad = requires_grad - - current_row = 0 - current_col = 0 - for tensor in tensors: - shape = tensor.shape - if len(shape) == 0: # 0-d tensor - result[current_row, current_col] = tensor - current_row += 1 - current_col += 1 - elif len(shape) == 1: # 1-d tensor - result[current_row, current_col:current_col + shape[0]] = tensor - current_row += 1 - current_col += shape[0] - elif len(shape) == 2: # 2-d tensor - result[current_row:current_row + shape[0], current_col:current_col + shape[1]] = tensor - current_row += shape[0] - current_col += shape[1] - - return result - def check(bc): bc = np.array(bc) if ((bc != 1) * (bc != bc.max(0))).sum() > 0: From 29fa67bb920d0a6094ba53a32b87e8774f925528 Mon Sep 17 00:00:00 2001 From: liylo <2813164552@qq.com> Date: Wed, 28 Aug 2024 21:35:12 +0800 Subject: [PATCH 70/73] forward hooks now could modifiy inputs and outputs --- python/jittor/__init__.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 3d9e7aa4..c2df0aa7 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -1464,9 +1464,17 @@ def requires_grad_(self, requires_grad=True): def __hooked_call__(self, *args, **kw): if hasattr(self, "__fhook2__"): if len(kw): - self.__fhook2__(self, args, kw) + args_kw_result = self.__fhook2__(self, args, kw) else: - self.__fhook2__(self, args) + args_kw_result = self.__fhook2__(self, args) + if args_kw_result is not None: + if isinstance(args_kw_result, tuple) and len(args_kw_result) == 2: + args, kw = args_kw_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kw_result}." 
+ ) if hasattr(self, "__bihook__"): if len(kw): LOG.w("backward hook not support kw") @@ -1485,9 +1493,11 @@ def __hooked_call__(self, *args, **kw): ret = grad_hooker(ret, self.__bohook__) if hasattr(self, "__fhook__"): if len(kw): - self.__fhook__(self, args, ret, kw) + res = self.__fhook__(self, args, ret, kw) else: - self.__fhook__(self, args, ret) + res = self.__fhook__(self, args, ret) + if res is not None: + ret = res return ret def _place_hooker(self): From 58cd6b538f20f4fed9efcd00a3c2307eb3d31547 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Wed, 4 Sep 2024 16:11:51 +0800 Subject: [PATCH 71/73] remove compatibility --- python/jittor/compatibility/__init__.py | 430 ---- python/jittor/compatibility/autograd.py | 134 -- python/jittor/compatibility/compiler.py | 39 - python/jittor/compatibility/cuda.py | 64 - python/jittor/compatibility/distributed.py | 53 - python/jittor/compatibility/distributions.py | 15 - python/jittor/compatibility/fft/__init__.py | 5 - python/jittor/compatibility/fx.py | 2 - python/jittor/compatibility/gradscaler.py | 519 ----- python/jittor/compatibility/gradscaler_old.py | 556 ----- python/jittor/compatibility/misc.py | 12 - python/jittor/compatibility/nn/__init__.py | 281 --- python/jittor/compatibility/nn/init.py | 16 - .../jittor/compatibility/nn/utils/__init__.py | 1 - python/jittor/compatibility/nn/utils/rnn.py | 20 - python/jittor/compatibility/optim.py | 1854 ----------------- .../jittor/compatibility/src/jtorch_core.cc | 102 - python/jittor/compatibility/src/jtorch_core.h | 40 - .../compatibility/test/test_conflict_func.py | 25 - .../compatibility/test/test_function.py | 58 - python/jittor/compatibility/test/test_misc.py | 24 - .../compatibility/test/test_tutorial.py | 56 - .../compatibility/tutorial/auto_grad1.py | 44 - .../compatibility/tutorial/auto_grad2.py | 60 - .../compatibility/tutorial/auto_grad3.py | 85 - .../compatibility/tutorial/auto_grad4.py | 71 - .../tutorial/auto_grad5_optim.py | 53 - .../tutorial/auto_grad6_module.py | 59 - .../tutorial/auto_grad7_dynet.py | 69 - .../compatibility/tutorial/quickstart.py | 106 - python/jittor/compatibility/utils/__init__.py | 5 - python/jittor/compatibility/utils/_pytree.py | 3 - .../jittor/compatibility/utils/checkpoint.py | 8 - python/jittor/compatibility/utils/data.py | 137 -- python/jittor/compatibility/utils/dtype.py | 9 - python/jittor/compatibility/utils/hooks.py | 0 .../jittor/compatibility/utils/pip_publish.py | 34 - .../vision/_internally_replaced_utils.py | 46 - .../compatibility/vision/datasets/__init__.py | 9 - .../compatibility/vision/datasets/mnist.py | 558 ----- .../compatibility/vision/datasets/utils.py | 522 ----- .../compatibility/vision/datasets/vision.py | 104 - .../jittor/compatibility/vision/transforms.py | 1 - python/jittor/compatibility/vision/utils.py | 582 ------ 44 files changed, 6871 deletions(-) delete mode 100644 python/jittor/compatibility/__init__.py delete mode 100644 python/jittor/compatibility/autograd.py delete mode 100644 python/jittor/compatibility/compiler.py delete mode 100644 python/jittor/compatibility/cuda.py delete mode 100644 python/jittor/compatibility/distributed.py delete mode 100644 python/jittor/compatibility/distributions.py delete mode 100644 python/jittor/compatibility/fft/__init__.py delete mode 100644 python/jittor/compatibility/fx.py delete mode 100644 python/jittor/compatibility/gradscaler.py delete mode 100644 python/jittor/compatibility/gradscaler_old.py delete mode 100644 python/jittor/compatibility/misc.py delete mode 100644 
python/jittor/compatibility/nn/__init__.py delete mode 100644 python/jittor/compatibility/nn/init.py delete mode 100644 python/jittor/compatibility/nn/utils/__init__.py delete mode 100644 python/jittor/compatibility/nn/utils/rnn.py delete mode 100644 python/jittor/compatibility/optim.py delete mode 100644 python/jittor/compatibility/src/jtorch_core.cc delete mode 100644 python/jittor/compatibility/src/jtorch_core.h delete mode 100644 python/jittor/compatibility/test/test_conflict_func.py delete mode 100644 python/jittor/compatibility/test/test_function.py delete mode 100644 python/jittor/compatibility/test/test_misc.py delete mode 100644 python/jittor/compatibility/test/test_tutorial.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad1.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad2.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad3.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad4.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad5_optim.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad6_module.py delete mode 100644 python/jittor/compatibility/tutorial/auto_grad7_dynet.py delete mode 100644 python/jittor/compatibility/tutorial/quickstart.py delete mode 100644 python/jittor/compatibility/utils/__init__.py delete mode 100644 python/jittor/compatibility/utils/_pytree.py delete mode 100644 python/jittor/compatibility/utils/checkpoint.py delete mode 100644 python/jittor/compatibility/utils/data.py delete mode 100644 python/jittor/compatibility/utils/dtype.py delete mode 100644 python/jittor/compatibility/utils/hooks.py delete mode 100644 python/jittor/compatibility/utils/pip_publish.py delete mode 100644 python/jittor/compatibility/vision/_internally_replaced_utils.py delete mode 100644 python/jittor/compatibility/vision/datasets/__init__.py delete mode 100644 python/jittor/compatibility/vision/datasets/mnist.py delete mode 100644 python/jittor/compatibility/vision/datasets/utils.py delete mode 100644 python/jittor/compatibility/vision/datasets/vision.py delete mode 100644 python/jittor/compatibility/vision/transforms.py delete mode 100644 python/jittor/compatibility/vision/utils.py diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py deleted file mode 100644 index 94d2e40b..00000000 --- a/python/jittor/compatibility/__init__.py +++ /dev/null @@ -1,430 +0,0 @@ -# import os -# os.environ["FIX_TORCH_ERROR"] = "0" - -# import jittor as jt -# from jittor import * -# from typing import Tuple - -# org_int = int = type(1) -# org_float = float = type(1.0) -# org_bool = bool = type(True) - -# import jtorch.compiler - -# import jtorch_core -# from jtorch_core import * - -# device.__reduce__ = lambda self: (device, (self.type,)) -# device.__module__ = "jtorch" -# jt.jittor_core.device = device - -# def handle_dtype(args, kw, dtype): -# def convert(x): -# if isinstance(x, jt.Var): -# return x.cast(dtype) -# return x -# if dtype is not None: -# if args is not None: -# if isinstance(args, (tuple,list)): -# args = [ convert(a) for a in args ] -# else: -# args = convert(x) -# if kw is not None: -# kw = { k:convert(v) for k,v in kw.items() } -# return args, kw - -# def get_args_names(func): -# import inspect -# spec = inspect.getfullargspec(func) -# return spec[0] + spec[4] - -# def wrapper(func): -# has_dtype = False -# if hasattr(func, "__code__"): -# has_dtype = "dtype" in get_args_names(func) -# def inner(*args, **kw): -# requires_grad = None -# dtype 
= None -# if "requires_grad" in kw: -# requires_grad = kw["requires_grad"] -# del kw["requires_grad"] -# if not has_dtype and "dtype" in kw: -# dtype = kw["dtype"] -# del kw["dtype"] -# if "device" in kw: -# del kw["device"] -# if 'pin_memory' in kw: -# del kw['pin_memory'] -# args, kw = handle_dtype(args, kw, dtype) -# ret = func(*args, **kw) -# if isinstance(ret, jt.Var): -# if requires_grad is not None: -# ret.requires_grad = requires_grad -# if dtype is not None: -# ret.astype(dtype) -# return ret -# return inner - - -# import inspect -# _wrapper_keys = set(["shape", "start", "size"]) -# _wrapper_keys.add("x") -# for k,v in list(globals().items()): -# if callable(v) and not isinstance(v, type): -# try: -# spec = inspect.getfullargspec(v) -# args_name = spec[0] -# if len(args_name) and args_name[0] in _wrapper_keys: -# globals()[k] = wrapper(v) -# elif spec.varargs in _wrapper_keys: -# globals()[k] = wrapper(v) -# except: -# pass - -# def empty(*size, dtype=jt.float32, device=None, requires_grad=False): -# if len(size) == 1 and not isinstance(size[0], org_int): -# size = size[0] -# return jt.empty(size, dtype) - -# Tensor = Var - -# Tensor.backward = lambda x: jtorch_core.backward(x) -# Tensor.grad = property(grad_get, grad_set, grad_del) -# Tensor.retains_grad = property(retain_grad_get, retain_grad_set) -# def retain_grad(x:Tensor, value:bool=True): -# x.retains_grad = value -# return value -# Tensor.retain_grad = retain_grad - -# Tensor.dim = lambda self: self.ndim -# Tensor.ndimension = lambda self: self.ndim -# Tensor.nelement = lambda self: self.numel() -# Tensor.cuda = lambda self: self -# def device_get(x:Tensor): -# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") -# Tensor.device = property(device_get) - -# def argmax(x: Var, dim=None, keepdim: bool = False): -# return jt.argmax(x, dim, keepdim)[0] -# Tensor.argmax = argmax - -# def tensor_type(x: Var, dtype=None, **kwargs): -# if dtype: -# return x.astype(dtype) -# else: -# return x.dtype -# Tensor.type = tensor_type - -# def is_floating_point(x: Var): -# return "float" in str(x.dtype) -# Tensor.is_floating_point = is_floating_point - -# from . 
import autograd -# from .autograd import * - -# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): -# if isinstance(data,list): -# data_list = [] -# check = True -# for p in data: -# if isinstance(p, Tensor) and p.numel()==1: -# data_list.append(p.item()) -# elif isinstance(p, (org_int,org_float)): -# data_list.append(p) -# else: -# check = False -# break -# if check: -# data = data_list -# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) - -# # tensor = wrapper(array) -# from_numpy = wrapper(array) -# strided = None - -# def mod_zero_grad(self): -# for p in self.parameters(): -# p.grad = None -# Module.zero_grad = mod_zero_grad - -# class ModuleMisc: -# def parameters(self): -# return iter(super().parameters()) - -# def load_state_dict(self, state_dict, strict=False): -# return super().load_state_dict(state_dict) - -# def to(self, device=None,dtype=None): -# ''' do nothing but return its self''' -# return self -# def register_parameter(self,name,data): -# self.name = data - -# def buffers(self): -# for _, buf in self.named_buffers(): -# yield buf - - -# def make_module(cls): -# class TMod(ModuleMisc, cls): -# def __init__(self, *args, **kw): -# dtype = None -# if "dtype" in kw: -# dtype = kw["dtype"] -# del kw["dtype"] -# self._dtype = dtype -# with jt.flag_scope(th_mode=0): -# if "device" in kw: -# del kw["device"] -# super().__init__(*args, **kw) -# for k,v in self.__dict__.items(): -# if not k.startswith("_") and isinstance(v, Var) \ -# and v.requires_grad: -# v.retain_grad() -# if dtype is not None and isinstance(v, Var): -# v.assign(v.cast(dtype)) -# def __call__(self, *args, **kw): -# args, kw = handle_dtype(args, kw, self._dtype) -# # if forward is override by user, call forward -# if self.__class__.forward is not TMod.forward: -# return self.forward(*args, **kw) -# return self.execute(*args, **kw) -# def forward(self, *args, **kw): -# args, kw = handle_dtype(args, kw, self._dtype) -# return self.execute(*args, **kw) - -# @property -# def training(self): -# if not hasattr(self, "is_train"): -# self.is_train = True -# return self.is_train -# @training.setter -# def training(self, value): -# self.is_train = value - -# TMod.__name__ = cls.__name__ -# return TMod - -# import jtorch.cuda -# import jtorch.nn -# from jtorch.nn import Module, Parameter -# import jtorch.optim - -# from jtorch.utils.dtype import Dtype, get_string_dtype - -# def frombuffer(buffer: bytearray, -# *, -# dtype: Dtype, -# count: int = -1, -# offset: int = 0, -# requires_grad: bool = True) -> Tensor: -# dtype = get_string_dtype(dtype) -# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) -# if requires_grad and tensor.dtype.is_float(): -# tensor.requires_grad = True -# return tensor - -# def conflict_wrapper(origin_func, new_func): -# def wrapper(*args, **kw): -# if jt.flags.th_mode: -# return new_func(*args, **kw) -# else: -# return origin_func(*args, **kw) -# return wrapper - -# def min(*args, **kw): -# dim = None -# if len(args) >= 2 and isinstance(args[1], org_int): -# dim = args[1] -# elif "dim" in kw and isinstance(kw["dim"], org_int): -# dim = kw["dim"] -# if dim is not None: -# k, v = jt.argmin(*args, **kw) -# return v, k -# elif len(args) == 2 and isinstance(args[1], jt.Var): -# return jt.minimum(args[0], args[1]) -# else: -# return jt.min(*args, **kw) -# Tensor.min = conflict_wrapper(jt.min, min) - -# def max(*args, **kw): -# dim = None -# if "dim" in kw: -# x = kw["dim"] -# if len(args) 
>= 2 and isinstance(args[1], org_int): -# dim = args[1] -# elif "dim" in kw and isinstance(kw["dim"], org_int): -# dim = kw["dim"] -# if dim is not None: -# k, v = jt.argmax(*args, **kw) -# return v, k -# elif len(args) == 2 and isinstance(args[1], jt.Var): -# return jt.maximum(args[0], args[1]) -# else: -# return jt.max(*args, **kw) -# Tensor.max = conflict_wrapper(jt.max, max) - -# def argsort(*args, **kw): -# k, v = jt.argsort(*args, **kw) -# return k -# Tensor.argsort = conflict_wrapper(jt.argsort, argsort) - -# LongTensor = jt.int64 -# FloatTensor = jt.float -# HalfTensor = jt.float16 -# BoolTensor = jt.bool -# IntTensor = jt.int32 - -# class JDType: -# def __init__(self, func, str): -# self.func = func -# self.str = str -# self.__name__ = str.split(".")[-1] -# def __call__(self, *args, **kw): -# return self.func(*args, **kw) -# def __str__(self): -# return self.str -# def is_floating_point(self): -# return "float" in str(self.str) - -# int8 = JDType(jt.int8, "torch.int8") -# int16 = JDType(jt.int16, "torch.int16") -# int = int32 = JDType(jt.int32, "torch.int32") -# long = int64 = JDType(jt.int64, "torch.int64") - -# half = float16 = JDType(jt.float16, "torch.float16") -# float = float32 = JDType(jt.float32, "torch.float32") -# double = float64 = JDType(jt.float64, "torch.float64") -# bfloat16 = "bfloat16" # TODO -# complex64 = "complex64" # TODO -# complex128 = "complex128" # TODO -# def get_JDtype(dtype): -# if dtype=='float32' or dtype == jt.float32: -# return float32 -# elif dtype=='float64' or dtype == jt.float64: -# return float64 -# elif dtype=='float16' or dtype == jt.float16: -# return float16 -# elif dtype=='int32' or dtype == jt.int32: -# return int32 -# elif dtype=='int64' or dtype == jt.int64: -# return int64 -# elif dtype=='int16' or dtype == jt.int16: -# return int16 -# elif dtype=='int8' or dtype == jt.int8: -# return int8 -# else: -# raise Exception("dtype {} not supported".format(dtype)) - -# def load(path,**kwargs): -# def _to_jittor(data): -# if isinstance(data,dict): -# return {k:_to_jittor(d) for k,d in data.items()} -# if isinstance(data,list): -# return [_to_jittor(d) for d in data] -# if isinstance(data,np.ndarray): -# return jt.array(data) -# return data -# data = jt.load(path) - -# return _to_jittor(data) - -# def is_tensor(x): -# return isinstance(x, Tensor) - -# manual_seed = jt.set_global_seed -# jt.flags.amp_level = 3 -# Size = jt.NanoVector - -# class Generator: -# def __init__(self,*args,**kw) -> None: -# self.seed = None -# def manual_seed(self,seed): -# self.seed = seed - - - -# from . 
import fx - - -# _default_type = "float32" - -# def get_default_dtype(): -# return _default_type -# def set_default_dtype(dtype): -# global _default_type -# _default_type = dtype - -# dtype = JDType - -# def div(x,y,rounding_mode="floor"): -# assert rounding_mode == "floor" -# z = (x / y) -# if rounding_mode == "floor": -# z = z.floor() -# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): -# z = z.int32() -# return z - - -# def randn(*args,**kw): -# wrap_randn = wrapper(jt.randn) -# generator = kw.get('generator',None) -# kw.pop('generator',None) -# if 'layout' in kw: -# del kw['layout'] -# if generator is not None and generator.seed is not None: -# jt.set_global_seed(generator.seed) -# return wrap_randn(*args,**kw) - -# def rand(*args,**kw): -# print("rand") -# wrap_rand = wrapper(jt.rand) -# generator = kw.get('generator',None) -# kw.pop('generator',None) -# if 'layout' in kw: -# del kw['layout'] -# if generator is not None and generator.seed is not None: -# jt.set_global_seed(generator.seed) -# return wrap_rand(*args,**kw) - - - -# def set_default_tensor_type(t: type or str): -# if isinstance(t, str): -# info = t.split(".") -# if len(info) == 3 and info[1] == 'cuda': -# jt.flags.use_cuda = 1 -# #TODO: type - - -# def clamp(x, min=None, max=None): -# return jt.clamp(x, min, max) - - -# def to(x,*args,**kw): -# device = None -# if len(args) == 1: -# device = args[0] -# if isinstance(device, jt.NanoString) or callable(device): -# return jt.to(x,*args,**kw) -# if 'cpu' in str(device): -# args = [] -# device = kw.get("device",None) -# if 'cpu' in str(device): -# kw.pop('device',None) -# print("to cpu") -# # print(kw) -# return jt.to(x,*args,**kw) -# Tensor.to = conflict_wrapper(jt.to, to) - -# mm = wrapper(jt.matmul) - -# def _data_get(x): -# return x - -# def _data_set(x, value): -# x.assign(value) - -# Tensor.data = property(_data_get, _data_set) -# Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/autograd.py b/python/jittor/compatibility/autograd.py deleted file mode 100644 index 5ed88dde..00000000 --- a/python/jittor/compatibility/autograd.py +++ /dev/null @@ -1,134 +0,0 @@ -import jittor as jt -from jittor import Var -from collections.abc import Sequence, Mapping - -Variable = Var - -class FunctionContext: - def save_for_backward(self, *args): - self.saved_tensors = args - -class Function: - ''' Function Module for customized backward operations - -Example 1 (Function can have multiple input and multiple output, and user -can store value for backward computation):: - - import jtorch - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - return grad0 * self.y, grad1 * self.x - - a = jtorch.array(3.0) - a.requires_grad = True - b = jtorch.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - (c+d*3).backward() - assert a.grad.data == 4 - assert b.grad.data == 9 - -Example 2(Function can return None for no gradiant, and gradiant -can also be None):: - - import jtorch - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - assert grad1 is None - return grad0 * self.y, None - a = jt.array(3.0) - a.requires_grad = True - b = jt.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - d.stop_grad() 
- da, db = jt.grad(c+d*3, [a, b]) - assert da.data == 4 - assert db.data == 0 - - ''' - def __call__(self, *args): - backup = args - args = list(args) - taped_inputs = [] - taped_outputs = [] - input_mask = [-1] * len(args) - for i,v in enumerate(args): - if isinstance(v, Var): - if v.is_stop_grad(): - # -2 in input_mask represents it is stop_grad - input_mask[i] = -2 - continue - v = v.tape() - input_mask[i] = len(taped_inputs) - args[i] = v - taped_inputs.append(v) - ctx = FunctionContext() - ori_res = self.forward(ctx, *args) - # ori_res = self.execute(*args) - if not isinstance(ori_res, Sequence): - res = [ori_res] - else: - res = list(ori_res) - output_mask = [-1] * len(res) - for i,v in enumerate(res): - if isinstance(v, Var): - v = v.tape() - output_mask[i] = len(taped_outputs) - res[i] = v - taped_outputs.append(v) - ctx.input_mask = input_mask - ctx.output_mask = output_mask - # tape output and input together so - # backward treat them as one operator - jt.tape_together(taped_inputs, taped_outputs, - lambda *args: self._grad(ctx, self, *args)) - if isinstance(ori_res, Sequence): - return res - else: - return res[0] - - @staticmethod - def _grad(ctx, func, *args): - new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask ) - ret = func.backward(ctx, *new_args) - if not isinstance(ret, Sequence): - ret = (ret,) - new_ret = [] - for i, r in enumerate(ret): - j = ctx.input_mask[i] - if j<0: - # -2 in input_mask represents it is stop_grad - assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ - "because the input value is not jittor variable." - else: - new_ret.append(r) - return new_ret - - def dfs(self, parents, k, callback, callback_leave=None): - pass - - @classmethod - def apply(cls, *args, **kw): - func = cls() - return func(*args, **kw) diff --git a/python/jittor/compatibility/compiler.py b/python/jittor/compatibility/compiler.py deleted file mode 100644 index 77bab138..00000000 --- a/python/jittor/compatibility/compiler.py +++ /dev/null @@ -1,39 +0,0 @@ -import jittor as jt -import jittor_utils -import glob -import os -from jittor import pyjt_compiler -import sys -from jittor_utils import lock - - -jtorch_path = os.path.dirname(__file__) -cache_path = os.path.join(jt.compiler.cache_path, "jtorch") -# os.makedirs(cache_path, exist_ok=True) -os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True) - -with lock.lock_scope(): - pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path) - -ext_args = 'c[cu]' if jt.has_cuda else 'cc' -files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True) -files += pyjt_gen_src -cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" " -if os.environ.get("use_data_o", "1") == "1": - files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True) - files = [f for f in files if "__data__" not in f] - - -with lock.lock_scope(): - jt.compiler.compile( - jt.compiler.cc_path, - jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags, - files, - "jtorch_core"+jt.compiler.extension_suffix, - obj_dirname="jtorch_objs") - - -with jittor_utils.import_scope(jt.compiler.import_flags): - import jtorch_core as core - -jt.flags.th_mode = 1 diff --git a/python/jittor/compatibility/cuda.py b/python/jittor/compatibility/cuda.py deleted file mode 100644 index 75665c7c..00000000 --- a/python/jittor/compatibility/cuda.py +++ /dev/null @@ -1,64 +0,0 @@ -import jittor as jt -import jtorch - -def is_available(): - return jt.has_cuda - -def device_count(): - return int(jt.has_cuda) - -def set_device(device=None): - pass 
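# A minimal sketch of the native jittor calls these thin wrappers forward to
# (taken from the shim bodies in this file; the use_cuda flag line is the usual
# way to enable CUDA in jittor rather than a literal equivalent of set_device):
#     import jittor as jt
#     jt.has_cuda            # backs is_available() / device_count()
#     jt.flags.use_cuda = 1  # run subsequent ops on the GPU
#     jt.sync_all(True)      # backs synchronize()
#     jt.set_global_seed(0)  # backs manual_seed / manual_seed_all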
- -def get_rng_state(device=None): - pass - -def current_device(): - return jtorch.device("cuda") - -def mem_get_info(i): - return ("75GB",) - - -class Generator: - def __init__(self): - pass - - def set_state(self, state): - self.state = state - -default_generators = [Generator()] -_lazy_call = lambda func: func() -device = None - -LongTensor = jt.int64 -FloatTensor = jt.float -HalfTensor = jt.float16 -BoolTensor = jt.bool - -manual_seed = jt.set_global_seed -manual_seed_all = jt.set_global_seed - -def synchronize(): - jt.sync_all(True) - -class Event: - pass - -class Stream: - pass - -from typing import Any - -from .gradscaler import GradScaler - -class autocast: - def __init__(self,**kwargs): - pass - - def __enter__(self,): - pass - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): - pass - diff --git a/python/jittor/compatibility/distributed.py b/python/jittor/compatibility/distributed.py deleted file mode 100644 index e39f559a..00000000 --- a/python/jittor/compatibility/distributed.py +++ /dev/null @@ -1,53 +0,0 @@ -import datetime -from enum import Enum -import jittor as jt - - -class DistributedDataParallel: - def __new__(cls, model): - return model - -def is_initialized(): - return True - -def get_rank(group=None): - return 0 - -def get_world_size(group=None): - return 1 - -def get_backend(group=None): - return "nccl" - -def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None): - return 1 - -def barrier(): - pass - -def is_available(): - return True - -def is_built(): - return True - -class ReduceOp: - SUM = 0 - -class GroupMember: - WORLD = 0 - -class ProcessGroup: - pass - -class Join: - pass - -dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL")) -_backend = dist_backend.NCCL - -def is_mpi_available(): - return jt.in_mpi - -def DistributedDataParallel(model, *args, **kw): - return model diff --git a/python/jittor/compatibility/distributions.py b/python/jittor/compatibility/distributions.py deleted file mode 100644 index a98dfe29..00000000 --- a/python/jittor/compatibility/distributions.py +++ /dev/null @@ -1,15 +0,0 @@ -import jittor as jt - -class RelaxedBernoulli: - def __init__(self, temperature, probs=None, logits=None): - self.temperature = temperature - self.probs = probs - self.logits = logits - - def rsample(self): - noise = jt.rand_like(self.logits) - eps = 1e-20 - noise = jt.clamp(noise, eps, 1.0 - eps) - logit_noise = jt.log(noise) - jt.log(1 - noise) - sample = (self.logits + logit_noise) / self.temperature - return jt.sigmoid(sample) diff --git a/python/jittor/compatibility/fft/__init__.py b/python/jittor/compatibility/fft/__init__.py deleted file mode 100644 index 7a89fc9c..00000000 --- a/python/jittor/compatibility/fft/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#TODO: Implement FFT and IFFT -fftn = None -fftshift = None -ifftn = None -ifftshift = None \ No newline at end of file diff --git a/python/jittor/compatibility/fx.py b/python/jittor/compatibility/fx.py deleted file mode 100644 index 0f0eb4f8..00000000 --- a/python/jittor/compatibility/fx.py +++ /dev/null @@ -1,2 +0,0 @@ -class Proxy: - pass \ No newline at end of file diff --git a/python/jittor/compatibility/gradscaler.py b/python/jittor/compatibility/gradscaler.py deleted file mode 100644 index 087d6bb2..00000000 --- a/python/jittor/compatibility/gradscaler.py +++ /dev/null @@ -1,519 +0,0 @@ -from collections import defaultdict, abc -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast -import inspect -import 
warnings - -import jittor as jt -# import torch - -def _refresh_per_optimizer_state(): - return {} - - -class GradScaler: - _scale: Optional[jt.Var] - _grows_tracker: Optional[jt.Var] - _per_optimizer_states: Dict[int, Dict[str, Any]] - """ - An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling - conveniently. - - * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. - * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. - * ``scaler.update()`` updates ``scaler``'s scale factor. - - Example:: - - # Creates a GradScaler once at the beginning of training. - scaler = GradScaler() - - for epoch in epochs: - for input, target in data: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - - # Scales loss. Calls backward() on scaled loss to create scaled gradients. - scaler.scale(loss).backward() - - # scaler.step() first unscales gradients of the optimizer's params. - # If gradients don't contain infs/NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - - # Updates the scale for next iteration. - scaler.update() - - See the :ref:`Automatic Mixed Precision examples` for usage - (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, - and multiple losses/optimizers. - - ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, - a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if - the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used - without incurring inf or NaN gradient values. - ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every - ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). - - * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params - themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. - - * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. - If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by - ``growth_factor``. - - The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its - value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these - iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - - Args: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. 
- Default: ``True`` - """ - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): - self._enabled = enabled - - if self._enabled: - assert growth_factor > 1.0, "The growth factor must be > 1.0." - assert backoff_factor < 1.0, "The backoff factor must be < 1.0." - - self._init_scale = init_scale - # self._scale will be lazily initialized during the first call to scale() - self._scale = None - self._growth_factor = growth_factor - self._backoff_factor = backoff_factor - self._growth_interval = growth_interval - self._init_growth_tracker = 0 - # self._growth_tracker will be lazily initialized during the first call to scale() - self._growth_tracker = None - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: - fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix - assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix - return (self._scale, self._growth_tracker) - - def _lazy_init_scale_growth_tracker(self): - assert self._growth_tracker is None, "_growth_tracker initialized before _scale" - self._scale = self._init_scale - self._growth_tracker = self._init_growth_tracker - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Args: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - if not self._enabled: - return outputs - - - # Short-circuit for the common case. - if isinstance(outputs, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return outputs * self._scale - - def apply_scale(val): - if isinstance(val, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return val * self._scale - elif isinstance(val, abc.Iterable): - iterable = map(apply_scale, val) - if isinstance(val, (list, tuple)): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) - - def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): - with jt.no_grad(): - optimizer.pre_step() - for group in optimizer.param_groups: - for to_unscale in group["grads"]: - if to_unscale is None or isinstance(to_unscale,(int,float)): - continue - if (not allow_fp16) and str(to_unscale.dtype) == "float16": - raise ValueError("Attempting to unscale FP16 gradients.") - - if not (to_unscale.isinf().any()): - if inv_scale != 1.0: - to_unscale.update(to_unscale*inv_scale) - else: - found_inf = 1.0 - - return found_inf - - def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - - ... 
- scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - - .. note:: - :meth:`unscale_` does not incur a CPU-GPU sync. - - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if hasattr(optimizer,"get_find_inf"): - return - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - inv_scale = 1.0 / self._scale - found_inf = 0.0 - optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) - - - def step(self, optimizer, *args, **kwargs): - """ - :meth:`step` carries out the following two operations: - - 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` - earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. - 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled - gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. - - ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. - - Returns the return value of ``optimizer.step(*args, **kwargs)``. - - Args: - optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. - args: Any arguments. - kwargs: Any keyword arguments. - - .. warning:: - Closure use is not currently supported. - """ - if (not self._enabled): - return optimizer.step(*args, **kwargs) - - if "closure" in kwargs: - raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") - - self._check_scale_growth_tracker("step") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - retval = None - - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): - # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. - # The contract with custom optimizers is that their step() should accept an additional, - # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: - # it can query its own state, invoke unscale_ on itself, etc - # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument - # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` - # and `found_inf` to the passed optimizer so that the optimizer can utilize those - # to skip the parameter updates or unscale gradients before updating parameters in - # the fused kernel, e.g. `FusedAdamMathFunctor`. - # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, - # while the method is expected to be called by users side, i.e. their optimizers. 
- kwargs_ = kwargs - has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters - if has_grad_scaler_kwarg: - warnings.warn( - "GradScaler is going to stop passing itself as a keyword argument to the passed " - "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " - "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", - FutureWarning) - kwargs_.update({"grad_scaler": self}) - else: - if optimizer_state["stage"] is OptState.READY: - self._check_inf_per_device(optimizer) - scaler = self._get_scale_async() - found_inf = cast( - jt.Var, - sum([ - t for t in optimizer_state["found_inf_per_device"].values() - ]) - ) - optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler - optimizer.found_inf = found_inf - retval = optimizer.step(*args, **kwargs_) - optimizer_state["stage"] = OptState.STEPPED - if not has_grad_scaler_kwarg: - del optimizer.grad_scale - del optimizer.found_inf - return retval - - if hasattr(optimizer,"get_find_inf"): - optimizer.set_grad_scale(self._scale) - optimizer.step() - optimizer_state["found_inf_per_device"] = optimizer.get_find_inf() - return - - retval = None - if not optimizer_state["found_inf_per_device"]: - retval = optimizer.step(*args, **kwargs) - else: - optimizer.post_step() - - return retval - - - def update(self, new_scale=None): - """ - Updates the scale factor. - - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [state["found_inf_per_device"] - for state in self._per_optimizer_states.values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." 
- - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - - current_scale = _scale - if found_inf_combined: - current_scale *=self._backoff_factor - _growth_tracker = 0 - else: - successful = _growth_tracker+1 - if successful == self._growth_interval: - new_scale = current_scale*self._growth_factor - if new_scale < 1e9: - current_scale = new_scale - _growth_tracker = 0 - else: - _growth_tracker = successful - - self._scale, self._growth_tracker = current_scale,_growth_tracker - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _get_scale_async(self): - return self._scale - - def get_scale(self): - """ - Returns a Python float containing the current scale, or 1.0 if scaling is disabled. - - .. warning:: - :meth:`get_scale` incurs a CPU-GPU sync. - """ - if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async() - else: - return 1.0 - - def get_growth_factor(self): - r""" - Returns a Python float containing the scale growth factor. - """ - return self._growth_factor - - def set_growth_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale growth factor. - """ - self._growth_factor = new_factor - - def get_backoff_factor(self): - r""" - Returns a Python float containing the scale backoff factor. - """ - return self._backoff_factor - - def set_backoff_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale backoff factor. - """ - self._backoff_factor = new_factor - - def get_growth_interval(self): - r""" - Returns a Python int containing the growth interval. - """ - return self._growth_interval - - def set_growth_interval(self, new_interval): - r""" - Args: - new_interval (int): Value to use as the new growth interval. - """ - self._growth_interval = new_interval - - def _get_growth_tracker(self): - if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() - else: - return 0 - - def is_enabled(self): - r""" - Returns a bool indicating whether this instance is enabled. - """ - return self._enabled - - def state_dict(self): - r""" - Returns the state of the scaler as a :class:`dict`. It contains five entries: - - * ``"scale"`` - a Python float containing the current scale - * ``"growth_factor"`` - a Python float containing the current growth factor - * ``"backoff_factor"`` - a Python float containing the current backoff factor - * ``"growth_interval"`` - a Python int containing the current growth interval - * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. - - If this instance is not enabled, returns an empty dict. - - .. note:: - If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` - should be called after :meth:`update`. - """ - return {"scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} - - def load_state_dict(self, state_dict): - r""" - Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - - Args: - state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 
- """ - if not self._enabled: - return - - if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") - - self._init_scale = state_dict["scale"] - if self._scale is not None: - self._scale.fill_(state_dict["scale"]) - self._growth_factor = state_dict["growth_factor"] - self._backoff_factor = state_dict["backoff_factor"] - self._growth_interval = state_dict["growth_interval"] - self._init_growth_tracker = state_dict["_growth_tracker"] - if self._growth_tracker is not None: - self._growth_tracker.fill_(state_dict["_growth_tracker"]) - - def __getstate__(self): - state = self.__dict__.copy() - if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." - # Pickling _scale and _growth_tracker Tensors directly triggers - # "warnings.warn("pickle support for Storage will be removed in 1.5..." - # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def _check_inf_per_device(self, optimizer): - _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - - dummy_inv_scale = 1.0 - found_inf = 0.0 - - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) - - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] - - def _found_inf_per_device(self, optimizer): - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/gradscaler_old.py b/python/jittor/compatibility/gradscaler_old.py deleted file mode 100644 index 389be2cf..00000000 --- a/python/jittor/compatibility/gradscaler_old.py +++ /dev/null @@ -1,556 +0,0 @@ -from collections import defaultdict, abc -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast -import inspect -import warnings - -import jittor as jt -# import torch - - -__all__ = ["OptState", "GradScaler"] - - -# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, -# as well as associated "enum" values. Prefers defining these at top level because -# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. -# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler -# causes a circular reference, which we'd rather avoid. -class OptState(Enum): - READY = 0 - UNSCALED = 1 - STEPPED = 2 - - -def _refresh_per_optimizer_state(): - return {"stage": OptState.READY, "found_inf_per_device": {}} - - -class GradScaler: - _scale: Optional[jt.Var] - _grows_tracker: Optional[jt.Var] - _per_optimizer_states: Dict[int, Dict[str, Any]] - """ - An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling - conveniently. - - * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. - * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. - * ``scaler.update()`` updates ``scaler``'s scale factor. - - Example:: - - # Creates a GradScaler once at the beginning of training. 
- scaler = GradScaler() - - for epoch in epochs: - for input, target in data: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - - # Scales loss. Calls backward() on scaled loss to create scaled gradients. - scaler.scale(loss).backward() - - # scaler.step() first unscales gradients of the optimizer's params. - # If gradients don't contain infs/NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - - # Updates the scale for next iteration. - scaler.update() - - See the :ref:`Automatic Mixed Precision examples` for usage - (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, - and multiple losses/optimizers. - - ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, - a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if - the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used - without incurring inf or NaN gradient values. - ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every - ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). - - * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params - themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. - - * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. - If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by - ``growth_factor``. - - The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its - value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these - iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - - Args: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. - Default: ``True`` - """ - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): - self._enabled = enabled - - if self._enabled: - assert growth_factor > 1.0, "The growth factor must be > 1.0." - assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 
- - self._init_scale = init_scale - # self._scale will be lazily initialized during the first call to scale() - self._scale = None - self._growth_factor = growth_factor - self._backoff_factor = backoff_factor - self._growth_interval = growth_interval - self._init_growth_tracker = 0 - # self._growth_tracker will be lazily initialized during the first call to scale() - self._growth_tracker = None - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: - fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix - assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix - return (self._scale, self._growth_tracker) - - def _lazy_init_scale_growth_tracker(self): - assert self._growth_tracker is None, "_growth_tracker initialized before _scale" - self._scale = self._init_scale - self._growth_tracker = self._init_growth_tracker - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Args: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - print("scale") - if not self._enabled: - return outputs - - - # Short-circuit for the common case. - if isinstance(outputs, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return outputs * self._scale - - def apply_scale(val): - if isinstance(val, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return val * self._scale - elif isinstance(val, abc.Iterable): - iterable = map(apply_scale, val) - if isinstance(val, (list, tuple)): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) - - def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. - - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - with jt.no_grad(): - optimizer.pre_step() - for group in optimizer.param_groups: - for to_unscale in group["grads"]: - if to_unscale is None or isinstance(to_unscale,(int,float)): - continue - if (not allow_fp16) and str(to_unscale.dtype) == "float16": - raise ValueError("Attempting to unscale FP16 gradients.") - - if not (to_unscale.isinf().any()): - if inv_scale != 1.0: - to_unscale.update(to_unscale*inv_scale) - else: - found_inf = 1.0 - - return found_inf - - def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 
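The _unscale_grads_ helper above makes a single pass over every gradient: missing or scalar grads are skipped, a gradient containing inf leaves the found_inf flag set, and every other gradient is multiplied in place by the reciprocal of the scale (the original additionally rejects fp16 gradients unless allow_fp16 is set). A pure-Python analogue of that control flow, with lists of floats standing in for jt.Var; the helper name is illustrative only:

    import math

    def unscale_grads(grads, inv_scale):
        # Same flow as above: skip missing grads, flag overflowed ones,
        # and scale the rest by 1/scale in place.
        found_inf = 0.0
        for g in grads:
            if g is None:
                continue
            if any(math.isinf(x) for x in g):
                found_inf = 1.0
                continue
            for i, x in enumerate(g):
                g[i] = x * inv_scale
        return found_inf

    grads = [[2.0, 4.0], None, [float("inf"), 1.0]]
    print(unscale_grads(grads, inv_scale=1.0 / 65536), grads)
    # 1.0 [[~3.05e-05, ~6.10e-05], None, [inf, 1.0]]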
- - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - - .. note:: - :meth:`unscale_` does not incur a CPU-GPU sync. - - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: - raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - inv_scale = 1.0 / self._scale - found_inf = 0.0 - optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) - optimizer_state["stage"] = OptState.UNSCALED - - def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): - retval = None - if not optimizer_state["found_inf_per_device"]: - retval = optimizer.step(*args, **kwargs) - else: - optimizer.post_step() - - return retval - - def step(self, optimizer, *args, **kwargs): - """ - :meth:`step` carries out the following two operations: - - 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` - earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. - 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled - gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. - - ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. - - Returns the return value of ``optimizer.step(*args, **kwargs)``. - - Args: - optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. - args: Any arguments. - kwargs: Any keyword arguments. - - .. warning:: - Closure use is not currently supported. - """ - if (not self._enabled): - return optimizer.step(*args, **kwargs) - - if "closure" in kwargs: - raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") - - self._check_scale_growth_tracker("step") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("step() has already been called since the last update().") - - retval = None - - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): - # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. - # The contract with custom optimizers is that their step() should accept an additional, - # optional grad_scaler kwarg. 
We append self to the kwargs so the custom optimizer has full information: - # it can query its own state, invoke unscale_ on itself, etc - # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument - # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` - # and `found_inf` to the passed optimizer so that the optimizer can utilize those - # to skip the parameter updates or unscale gradients before updating parameters in - # the fused kernel, e.g. `FusedAdamMathFunctor`. - # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, - # while the method is expected to be called by users side, i.e. their optimizers. - kwargs_ = kwargs - has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters - if has_grad_scaler_kwarg: - warnings.warn( - "GradScaler is going to stop passing itself as a keyword argument to the passed " - "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " - "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", - FutureWarning) - kwargs_.update({"grad_scaler": self}) - else: - if optimizer_state["stage"] is OptState.READY: - self._check_inf_per_device(optimizer) - scaler = self._get_scale_async() - found_inf = cast( - jt.Var, - sum([ - t for t in optimizer_state["found_inf_per_device"].values() - ]) - ) - optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler - optimizer.found_inf = found_inf - retval = optimizer.step(*args, **kwargs_) - optimizer_state["stage"] = OptState.STEPPED - if not has_grad_scaler_kwarg: - del optimizer.grad_scale - del optimizer.found_inf - return retval - - - if optimizer_state["stage"] is OptState.READY: - self.unscale_(optimizer) - - assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer." - - retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) - - optimizer_state["stage"] = OptState.STEPPED - - return retval - - def update(self, new_scale=None): - """ - Updates the scale factor. - - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." 
- assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [state["found_inf_per_device"] - for state in self._per_optimizer_states.values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - - current_scale = _scale - if found_inf_combined: - current_scale *=self._backoff_factor - _growth_tracker = 0 - else: - successful = _growth_tracker+1 - if successful == self._growth_interval: - new_scale = current_scale*self._growth_factor - if new_scale < 1e9: - current_scale = new_scale - _growth_tracker = 0 - else: - _growth_tracker = successful - - self._scale, self._growth_tracker = current_scale,_growth_tracker - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _get_scale_async(self): - return self._scale - - def get_scale(self): - """ - Returns a Python float containing the current scale, or 1.0 if scaling is disabled. - - .. warning:: - :meth:`get_scale` incurs a CPU-GPU sync. - """ - if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async() - else: - return 1.0 - - def get_growth_factor(self): - r""" - Returns a Python float containing the scale growth factor. - """ - return self._growth_factor - - def set_growth_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale growth factor. - """ - self._growth_factor = new_factor - - def get_backoff_factor(self): - r""" - Returns a Python float containing the scale backoff factor. - """ - return self._backoff_factor - - def set_backoff_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale backoff factor. - """ - self._backoff_factor = new_factor - - def get_growth_interval(self): - r""" - Returns a Python int containing the growth interval. - """ - return self._growth_interval - - def set_growth_interval(self, new_interval): - r""" - Args: - new_interval (int): Value to use as the new growth interval. - """ - self._growth_interval = new_interval - - def _get_growth_tracker(self): - if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() - else: - return 0 - - def is_enabled(self): - r""" - Returns a bool indicating whether this instance is enabled. - """ - return self._enabled - - def state_dict(self): - r""" - Returns the state of the scaler as a :class:`dict`. It contains five entries: - - * ``"scale"`` - a Python float containing the current scale - * ``"growth_factor"`` - a Python float containing the current growth factor - * ``"backoff_factor"`` - a Python float containing the current backoff factor - * ``"growth_interval"`` - a Python int containing the current growth interval - * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. - - If this instance is not enabled, returns an empty dict. - - .. 
note:: - If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` - should be called after :meth:`update`. - """ - return {"scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} - - def load_state_dict(self, state_dict): - r""" - Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - - Args: - state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. - """ - if not self._enabled: - return - - if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") - - self._init_scale = state_dict["scale"] - if self._scale is not None: - self._scale.fill_(state_dict["scale"]) - self._growth_factor = state_dict["growth_factor"] - self._backoff_factor = state_dict["backoff_factor"] - self._growth_interval = state_dict["growth_interval"] - self._init_growth_tracker = state_dict["_growth_tracker"] - if self._growth_tracker is not None: - self._growth_tracker.fill_(state_dict["_growth_tracker"]) - - def __getstate__(self): - state = self.__dict__.copy() - if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." - # Pickling _scale and _growth_tracker Tensors directly triggers - # "warnings.warn("pickle support for Storage will be removed in 1.5..." - # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def _check_inf_per_device(self, optimizer): - _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - - dummy_inv_scale = 1.0 - found_inf = 0.0 - - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) - - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] - - def _found_inf_per_device(self, optimizer): - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/misc.py b/python/jittor/compatibility/misc.py deleted file mode 100644 index 8e9ed20d..00000000 --- a/python/jittor/compatibility/misc.py +++ /dev/null @@ -1,12 +0,0 @@ -import math - -def _jit_set_profiling_mode(x): pass -def _jit_set_profiling_executor(x): pass -def _jit_override_can_fuse_on_cpu(x): pass -def _jit_override_can_fuse_on_gpu(x): pass - -def script(func): - return func - -inf = math.inf -nan = math.nan \ No newline at end of file diff --git a/python/jittor/compatibility/nn/__init__.py b/python/jittor/compatibility/nn/__init__.py deleted file mode 100644 index ae0ff3ae..00000000 --- a/python/jittor/compatibility/nn/__init__.py +++ /dev/null @@ -1,281 +0,0 @@ -import jtorch -from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict -from typing_extensions import Self -import jittor as jt -from jtorch import make_module, Tensor, ModuleMisc, wrapper -#from . 
import init -from jittor import Function -import operator -import warnings - -for k,v in jt.nn.__dict__.items(): - if callable(v): - globals()[k] = wrapper(v) - -for k,v in jt.nn.__dict__.items(): - if isinstance(v, type) and issubclass(v, jt.Module): - globals()[k] = make_module(v) - -from collections import OrderedDict -from collections import abc as container_abcs - -class Module(ModuleMisc, jt.Module): - - def __call__(self, *args, **kw): - return self.execute(*args, **kw) - - def execute(self, *args, **kw): - return self.forward(*args, **kw) - - def get_submodule(self, target: str): - if target == "": - return self - - atoms: List[str] = target.split(".") - mod: jt.nn.Module = self - - for item in atoms: - if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no " - "attribute `" + item + "`") - - mod = getattr(mod, item) - - if not isinstance(mod, jt.nn.Module): - raise AttributeError("`" + item + "` is not " - "an nn.Module") - return mod - - - -def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor: - x = x.clone() - x.requires_grad = requires_grad - x.retains_grad = requires_grad - return x - -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False): - return jt.nn.embedding(input, weight) - -def dropout(x, p=0.5, training=False): - return jt.nn.dropout(x, p, training) - - -class Flatten(Module): - ''' Flattens the contiguous range of dimensions in a Var. - :param start_dim: the first dimension to be flattened. Defaults: 1. - :type start_dim: int - :param end_dim: the last dimension to be flattened. Defaults: -1. - :type end_dim: int - ''' - def __init__(self, start_dim=1, end_dim=-1): - self.start_dim = start_dim - self.end_dim = end_dim - - def forward(self, x) -> jt.Var: - return x.flatten(self.start_dim, self.end_dim) - -class _IncompatibleKeys: - def __init__(self, missing_keys, unexpected_keys): - self.missing_keys = missing_keys - self.unexpected_keys = unexpected_keys - -_BatchNorm = None - -#from . import utils -normalize = wrapper(jt.normalize) - -T = TypeVar('T', bound=Module) - -class ModuleDict(Module): - _modules: Dict[str, Module] # type: ignore[assignment] - - def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: - super().__init__() - if modules is not None: - self.update(modules) - - def __getitem__(self, key: str) -> Module: - return self._modules[key] - - def __setitem__(self, key: str, module: Module) -> None: - self.add_module(key, module) - - def __delitem__(self, key: str) -> None: - del self._modules[key] - - def __len__(self) -> int: - return len(self._modules) - - def __iter__(self) -> Iterator[str]: - return iter(self._modules) - - def __contains__(self, key: str) -> bool: - return key in self._modules - - def clear(self) -> None: - """Remove all items from the ModuleDict.""" - self._modules.clear() - - def pop(self, key: str) -> Module: - r"""Remove key from the ModuleDict and return its module. 
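get_submodule above is essentially a getattr walk over a dotted path, raising when a component is missing or is not a module. A short sketch of the same traversal against plain Jittor modules, assuming a working Jittor install; the model classes and the standalone helper are illustrative only and not part of this patch:

    from jittor import nn

    class Block(nn.Module):
        def __init__(self):
            self.linear = nn.Linear(4, 4)
        def execute(self, x):
            return self.linear(x)

    class Net(nn.Module):
        def __init__(self):
            self.block = Block()
        def execute(self, x):
            return self.block(x)

    def get_submodule(root, target):
        # Simplified walk of the method above ("" returns the root; the
        # original also checks that every hop is an nn.Module).
        mod = root
        for item in (target.split(".") if target else []):
            mod = getattr(mod, item)
        return mod

    net = Net()
    print(type(get_submodule(net, "block.linear")).__name__)   # Linear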
- - Args: - key (str): key to pop from the ModuleDict - """ - v = self[key] - del self[key] - return v - - def keys(self) -> Iterable[str]: - r"""Return an iterable of the ModuleDict keys.""" - return self._modules.keys() - - def items(self) -> Iterable[Tuple[str, Module]]: - r"""Return an iterable of the ModuleDict key/value pairs.""" - return self._modules.items() - - def values(self) -> Iterable[Module]: - r"""Return an iterable of the ModuleDict values.""" - return self._modules.values() - - def update(self, modules: Mapping[str, Module]) -> None: - r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. - - .. note:: - If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or - an iterable of key-value pairs, the order of new elements in it is preserved. - - Args: - modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, - or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) - """ - if not isinstance(modules, container_abcs.Iterable): - raise TypeError("ModuleDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) - - if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): - for key, module in modules.items(): - self[key] = module - else: - # modules here can be a list with two items - for j, m in enumerate(modules): - if not isinstance(m, container_abcs.Iterable): - raise TypeError("ModuleDict update sequence element " - "#" + str(j) + " should be Iterable; is" + - type(m).__name__) - if not len(m) == 2: - raise ValueError("ModuleDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") - # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] - # that's too cumbersome to type correctly with overloads, so we add an ignore here - self[m[0]] = m[1] # type: ignore[assignment] - - # remove forward alltogether to fallback on Module's _forward_unimplemented - - -class ParameterList(Module): - - def __init__(self, values: Optional[Iterable[Any]] = None) -> None: - super().__init__() - self._size = 0 - if values is not None: - self += values - - def _get_abs_string_index(self, idx): - """Get the absolute index for the list of modules.""" - idx = operator.index(idx) - if not (-len(self) <= idx < len(self)): - raise IndexError(f'index {idx} is out of range') - if idx < 0: - idx += len(self) - return str(idx) - - @overload - def __getitem__(self, idx: int) -> Any: - ... - - @overload - def __getitem__(self: T, idx: slice) -> T: - ... - - def __getitem__(self, idx): - if isinstance(idx, slice): - start, stop, step = idx.indices(len(self)) - out = self.__class__() - for i in range(start, stop, step): - out.append(self[i]) - return out - else: - idx = self._get_abs_string_index(idx) - return getattr(self, str(idx)) - - def __setitem__(self, idx: int, param: Any) -> None: - # Note that all other function that add an entry to the list part of - # the ParameterList end up here. So this is the only place where we need - # to wrap things into Parameter if needed. - # Objects added via setattr() are not in the list part and thus won't - # call into this function. 
- idx = self._get_abs_string_index(idx) - if isinstance(param, jt.Var) and not isinstance(param, Parameter): - param = Parameter(param) - return setattr(self, str(idx), param) - - def __len__(self) -> int: - return self._size - - def __iter__(self) -> Iterator[Any]: - return iter(self[i] for i in range(len(self))) - - def __iadd__(self, parameters: Iterable[Any]) -> Self: - return self.extend(parameters) - - def __dir__(self): - keys = super().__dir__() - keys = [key for key in keys if not key.isdigit()] - return keys - - def append(self, value: Any) -> 'ParameterList': - """Append a given value at the end of the list. - - Args: - value (Any): value to append - """ - new_idx = len(self) - self._size += 1 - self[new_idx] = value - return self - - def extend(self, values: Iterable[Any]) -> Self: - """Append values from a Python iterable to the end of the list. - - Args: - values (iterable): iterable of values to append - """ - # Tensor is an iterable but we never want to unpack it here - if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var): - raise TypeError("ParameterList.extend should be called with an " - "iterable, but got " + type(values).__name__) - for value in values: - self.append(value) - return self - - def extra_repr(self) -> str: - child_lines = [] - for k, p in enumerate(self): - if isinstance(p, jt.Var): - size_str = 'x'.join(str(size) for size in p.size()) - parastr = '{} containing: [{} of size {}{}]'.format( - "Parameter" if isinstance(p, Parameter) else "Tensor", - p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu") - child_lines.append(' (' + str(k) + '): ' + parastr) - else: - child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) - - tmpstr = '\n'.join(child_lines) - return tmpstr - - def __call__(self, *args, **kwargs): - raise RuntimeError('ParameterList should not be called.') \ No newline at end of file diff --git a/python/jittor/compatibility/nn/init.py b/python/jittor/compatibility/nn/init.py deleted file mode 100644 index 3b9f0907..00000000 --- a/python/jittor/compatibility/nn/init.py +++ /dev/null @@ -1,16 +0,0 @@ -import jittor as jt - -for k,v in jt.nn.init.__dict__.items(): - if callable(v): - globals()[k] = v - - -normal = gauss -normal_ = gauss_ -xavier_normal = xavier_gauss -xavier_normal_ = xavier_gauss_ -zeros_ = zero_ - - -jt.Var.normal_ = normal_ - diff --git a/python/jittor/compatibility/nn/utils/__init__.py b/python/jittor/compatibility/nn/utils/__init__.py deleted file mode 100644 index 83409f5f..00000000 --- a/python/jittor/compatibility/nn/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import rnn \ No newline at end of file diff --git a/python/jittor/compatibility/nn/utils/rnn.py b/python/jittor/compatibility/nn/utils/rnn.py deleted file mode 100644 index b32da8c3..00000000 --- a/python/jittor/compatibility/nn/utils/rnn.py +++ /dev/null @@ -1,20 +0,0 @@ -import jittor as jt - -PackedSequence = None - -def pad_sequence(sequences,batch_first=False,padding_value=0.0): - max_f = max([len(s) for s in sequences]) - # max_f = 512 - b = len(sequences) - if batch_first: - ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) - for i,s in enumerate(sequences): - ret[i,:len(s)] = s - else: - ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) - for i,s in enumerate(sequences): - ret[:len(s),i] = s - # print(ret.shape) - # ret = ret[:,:406] - return ret - \ No newline at end of file diff --git a/python/jittor/compatibility/optim.py b/python/jittor/compatibility/optim.py deleted file mode 100644 index 2410917f..00000000 --- a/python/jittor/compatibility/optim.py +++ /dev/null @@ -1,1854 +0,0 @@ -import jittor as jt -import math -from jittor.optim import * -from functools import partial - -class Optimizer(jt.optim.Optimizer): - def pre_step(self, loss=None, retain_graph=False): - jt.flags.node_order = 1 - params_has_grad = [] - for pg in self.param_groups: - pg["grads"] = [ jt.zeros_like(p) if p.grad is None else p.grad#.float32() - for p in pg["params"] ] - for p in pg["params"]: - if p.requires_grad: - params_has_grad.append(p) - jt.sync(params_has_grad) - self.n_step += 1 - - def zero_grad(self): - for pg in self.param_groups: - pg["grads"] = [ None for p in pg["params"] ] - for p in pg["params"]: p.grad = None - - def post_step(self): - jt.flags.node_order = 0 - - def clip_grad_norm(self, max_norm:float, norm_type:int=2): - r"""Clips gradient norm of this optimizer. - The norm is computed over all gradients together. 
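The clipping implemented a few lines below flattens and concatenates every gradient, takes the norm of that single vector, and rescales all gradients by clip_coef = min(max_norm / (total_norm + 1e-6), 1.0). A quick plain-Python check of the numbers quoted in the docstring example that follows:

    import math

    grads = [2.0, 2.0]                                  # d(a*a)/da at a = jt.ones(2)
    total_norm = math.sqrt(sum(g * g for g in grads))   # ~2.83, the first value printed below
    clip_coef = min(0.01 / (total_norm + 1e-6), 1.0)
    clipped = [g * clip_coef for g in grads]
    print(round(total_norm, 2), round(math.sqrt(sum(g * g for g in clipped)), 2))
    # 2.83 0.01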
- - Args: - max_norm (float or int): max norm of the gradients - norm_type (int): 1-norm or 2-norm - - Example:: - - a = jt.ones(2) - opt = jt.optim.SGD([a], 0.1) - - loss = a*a - opt.zero_grad() - opt.backward(loss) - - print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83 - opt.clip_grad_norm(0.01, 2) - print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01 - - opt.step() - - """ - self.pre_step(None) - grads = [] - for pg in self.param_groups: - for p, g in zip(pg["params"], pg["grads"]): - if p.is_stop_grad(): continue - grads.append(g.flatten()) - if len(grads) == 0: return - total_norm = jt.norm(jt.concat(grads), norm_type) - clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0) - for pg in self.param_groups: - for p, g in zip(pg["params"], pg["grads"]): - if p.is_stop_grad(): continue - g.update(g*clip_coef) - - -class AdamW(Optimizer): - def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0,use_fp32=True): - print("lr:", lr) - super().__init__(params, lr) - self.eps = eps - self.betas = betas - self.weight_decay = weight_decay - - self.use_fp32 = use_fp32 - # assert weight_decay==0, "weight_decay is not supported yet" - - # initialize required arguments for each param_groups - for pg in self.param_groups: - values = pg["values"] = [] - m = pg["m"] = [] - mp = pg['masterparams'] = [] - for p in pg["params"]: - values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - if self.use_fp32: - mp.append(p.detach().clone().stop_grad()) - - def add_param_group(self, group): - values = group["values"] = [] - m = group["m"] = [] - mp = group['masterparams'] = [] - for p in group["params"]: - values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - if self.use_fp32: - mp.append(p.detach().clone().stop_grad()) - self.param_groups.append(group) - - def step(self, loss=None, retain_graph=False): - self.pre_step(loss, retain_graph) - if loss is None: - self.n_step += 1 - n = float(self.n_step) - for pg in self.param_groups: - # get arguments from each param_groups - lr = pg.get("lr", self.lr) - eps = pg.get("eps", self.eps) - weight_decay = pg.get("weight_decay", self.weight_decay) - b0, b1 = pg.get("betas", self.betas) - - for p, g, v, m,mp in zip(pg["params"], pg["grads"], pg["values"], pg["m"],pg['masterparams']): - if p.is_stop_grad(): continue - #if g.abs().sum().item() < 1e-8: continue - #import pdb; pdb.set_trace() - c_p = (mp * (1 - lr * weight_decay)) - mp.update(c_p) - if self.use_fp32: - g = g.float32() - bias_correction1 = 1 - b0 ** n - bias_correction2 = 1 - b1 ** n - m.update(b0 * m + (1-b0) * g) #exp_avg - v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq - denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps - step_size = lr / bias_correction1 - new_p = (mp - step_size * m / denom) - mp.update(new_p) - p.update(mp.cast(p.dtype)) - self.post_step() - -for k,v in jt.optim.__dict__.items(): - if k == "AdamW":continue - if isinstance(v, type) and issubclass(v, jt.optim.Optimizer) and \ - not v is jt.optim.Optimizer: - class OptimWrap(v, Optimizer): - pass - globals()[k] = OptimWrap - - -class Adagrad(Optimizer): - pass - - - -import types -import math -from functools import wraps -import warnings -import weakref -from collections import Counter -from bisect import bisect_right - - -class LRScheduler: - - def __init__(self, 
optimizer, last_epoch=-1, verbose=False): - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - # Initialize epoch and base learning rates - if last_epoch == -1: - for group in optimizer.param_groups: - group.setdefault('initial_lr', group.get("lr",optimizer.lr)) - else: - for i, group in enumerate(optimizer.param_groups): - if 'initial_lr' not in group: - raise KeyError("param 'initial_lr' is not specified " - "in param_groups[{}] when resuming an optimizer".format(i)) - self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] - self.last_epoch = last_epoch - - # Following https://github.com/pytorch/pytorch/issues/20124 - # We would like to ensure that `lr_scheduler.step()` is called after - # `optimizer.step()` - def with_counter(method): - if getattr(method, '_with_counter', False): - # `optimizer.step()` has already been replaced, return. - return method - - # Keep a weak reference to the optimizer instance to prevent - # cyclic references. - instance_ref = weakref.ref(method.__self__) - # Get the unbound method for the same purpose. - func = method.__func__ - cls = instance_ref().__class__ - del method - - @wraps(func) - def wrapper(*args, **kwargs): - instance = instance_ref() - instance._step_count += 1 - wrapped = func.__get__(instance, cls) - return wrapped(*args, **kwargs) - - # Note that the returned function here is no longer a bound method, - # so attributes like `__func__` and `__self__` no longer exist. - wrapper._with_counter = True - return wrapper - - self.optimizer.step = with_counter(self.optimizer.step) - self.verbose = verbose - - self._initial_step() - - def _initial_step(self): - """Initialize step counts and performs a step""" - self.optimizer._step_count = 0 - self._step_count = 0 - self.step() - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self): - """ Return last computed learning rate by current scheduler. - """ - return self._last_lr - - def get_lr(self): - # Compute learning rate using chainable form of the scheduler - raise NotImplementedError - - def print_lr(self, is_verbose, group, lr, epoch=None): - """Display the current learning rate. - """ - if is_verbose: - if epoch is None: - print('Adjusting learning rate' - ' of group {} to {:.4e}.'.format(group, lr)) - else: - epoch_str = ("%.2f" if isinstance(epoch, float) else - "%.5d") % epoch - print('Epoch {}: adjusting learning rate' - ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) - - - def step(self, epoch=None): - # Raise a warning if old pattern is detected - # https://github.com/pytorch/pytorch/issues/20124 - if self._step_count == 1: - if not hasattr(self.optimizer.step, "_with_counter"): - warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " - "initialization. Please, make sure to call `optimizer.step()` before " - "`lr_scheduler.step()`. 
See more details at " - "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) - - # Just check if there were two first lr_scheduler.step() calls before optimizer.step() - elif self.optimizer._step_count < 1: - warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " - "In PyTorch 1.1.0 and later, you should call them in the opposite order: " - "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " - "will result in PyTorch skipping the first value of the learning rate schedule. " - "See more details at " - "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) - self._step_count += 1 - - with _enable_get_lr_call(self): - if epoch is None: - self.last_epoch += 1 - values = self.get_lr() - else: - warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) - self.last_epoch = epoch - if hasattr(self, "_get_closed_form_lr"): - values = self._get_closed_form_lr() - else: - values = self.get_lr() - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group['lr'] = lr - self.print_lr(self.verbose, i, lr, epoch) - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - -# Including _LRScheduler for backwards compatibility -# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). -class _LRScheduler(LRScheduler): - pass - - -class _enable_get_lr_call: - - def __init__(self, o): - self.o = o - - def __enter__(self): - self.o._get_lr_called_within_step = True - return self - - def __exit__(self, type, value, traceback): - self.o._get_lr_called_within_step = False - - -class LambdaLR(LRScheduler): - """Sets the learning rate of each parameter group to the initial lr - times a given function. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - lr_lambda (function or list): A function which computes a multiplicative - factor given an integer parameter epoch, or a list of such - functions, one for each group in optimizer.param_groups. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer has two groups. - >>> lambda1 = lambda epoch: epoch // 30 - >>> lambda2 = lambda epoch: 0.95 ** epoch - >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): - self.optimizer = optimizer - - if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): - self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) - else: - if len(lr_lambda) != len(optimizer.param_groups): - raise ValueError("Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(lr_lambda))) - self.lr_lambdas = list(lr_lambda) - super().__init__(optimizer, last_epoch, verbose) - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The learning rate lambda functions will only be saved if they are callable objects - and not if they are functions or lambdas. - - When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. 
- """ - - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} - state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) - - for idx, fn in enumerate(self.lr_lambdas): - if not isinstance(fn, types.FunctionType): - state_dict['lr_lambdas'][idx] = fn.__dict__.copy() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - - lr_lambdas = state_dict.pop('lr_lambdas') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['lr_lambdas'] = lr_lambdas - - for idx, fn in enumerate(lr_lambdas): - if fn is not None: - self.lr_lambdas[idx].__dict__.update(fn) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") - return [base_lr * lmbda(self.last_epoch) - for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] - - -class MultiplicativeLR(LRScheduler): - """Multiply the learning rate of each parameter group by the factor given - in the specified function. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - lr_lambda (function or list): A function which computes a multiplicative - factor given an integer parameter epoch, or a list of such - functions, one for each group in optimizer.param_groups. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> lmbda = lambda epoch: 0.95 - >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): - self.optimizer = optimizer - - if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): - self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) - else: - if len(lr_lambda) != len(optimizer.param_groups): - raise ValueError("Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(lr_lambda))) - self.lr_lambdas = list(lr_lambda) - super().__init__(optimizer, last_epoch, verbose) - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The learning rate lambda functions will only be saved if they are callable objects - and not if they are functions or lambdas. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} - state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) - - for idx, fn in enumerate(self.lr_lambdas): - if not isinstance(fn, types.FunctionType): - state_dict['lr_lambdas'][idx] = fn.__dict__.copy() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. 
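The state_dict methods of LambdaLR and MultiplicativeLR above persist the lr lambdas only when they are callable objects: plain functions and lambdas are stored as None, while an object's __dict__ is copied and later restored by load_state_dict. A small illustration of that filtering; the WarmupLambda class is a hypothetical example, not part of this patch:

    import types

    class WarmupLambda:
        # A callable object: its attributes survive checkpointing.
        def __init__(self, warmup):
            self.warmup = warmup
        def __call__(self, epoch):
            return min(1.0, (epoch + 1) / self.warmup)

    lr_lambdas = [lambda epoch: 0.95 ** epoch, WarmupLambda(10)]

    # Same filter as state_dict(): functions/lambdas -> None, objects -> __dict__ copy.
    saved = [None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
             for fn in lr_lambdas]
    print(saved)   # [None, {'warmup': 10}]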
- """ - lr_lambdas = state_dict.pop('lr_lambdas') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['lr_lambdas'] = lr_lambdas - - for idx, fn in enumerate(lr_lambdas): - if fn is not None: - self.lr_lambdas[idx].__dict__.update(fn) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch > 0: - return [group['lr'] * lmbda(self.last_epoch) - for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] - else: - return [group['lr'] for group in self.optimizer.param_groups] - - -class StepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the learning rate from outside this scheduler. When - last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - step_size (int): Period of learning rate decay. - gamma (float): Multiplicative factor of learning rate decay. - Default: 0.1. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.05 if epoch < 30 - >>> # lr = 0.005 if 30 <= epoch < 60 - >>> # lr = 0.0005 if 60 <= epoch < 90 - >>> # ... - >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False): - self.step_size = step_size - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * self.gamma ** (self.last_epoch // self.step_size) - for base_lr in self.base_lrs] - - -class MultiStepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma once the - number of epoch reaches one of the milestones. Notice that such decay can - happen simultaneously with other changes to the learning rate from outside - this scheduler. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - milestones (list): List of epoch indices. Must be increasing. - gamma (float): Multiplicative factor of learning rate decay. - Default: 0.1. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.05 if epoch < 30 - >>> # lr = 0.005 if 30 <= epoch < 80 - >>> # lr = 0.0005 if epoch >= 80 - >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) 
- >>> scheduler.step() - """ - - def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False): - self.milestones = Counter(milestones) - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch not in self.milestones: - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - milestones = sorted(self.milestones.elements()) - return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) - for base_lr in self.base_lrs] - - -class ConstantLR(LRScheduler): - """Decays the learning rate of each parameter group by a small constant factor until the - number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can - happen simultaneously with other changes to the learning rate from outside this scheduler. - When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - factor (float): The number we multiply learning rate until the milestone. Default: 1./3. - total_iters (int): The number of steps that the scheduler decays the learning rate. - Default: 5. - last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.025 if epoch == 0 - >>> # lr = 0.025 if epoch == 1 - >>> # lr = 0.025 if epoch == 2 - >>> # lr = 0.025 if epoch == 3 - >>> # lr = 0.05 if epoch >= 4 - >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): - if factor > 1.0 or factor < 0: - raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') - - self.factor = factor - self.total_iters = total_iters - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] * self.factor for group in self.optimizer.param_groups] - - if (self.last_epoch > self.total_iters or - (self.last_epoch != self.total_iters)): - return [group['lr'] for group in self.optimizer.param_groups] - - if (self.last_epoch == self.total_iters): - return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) - for base_lr in self.base_lrs] - - -class LinearLR(LRScheduler): - """Decays the learning rate of each parameter group by linearly changing small - multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. - Notice that such decay can happen simultaneously with other changes to the learning rate - from outside this scheduler. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. 
- start_factor (float): The number we multiply learning rate in the first epoch. - The multiplication factor changes towards end_factor in the following epochs. - Default: 1./3. - end_factor (float): The number we multiply learning rate at the end of linear changing - process. Default: 1.0. - total_iters (int): The number of iterations that multiplicative factor reaches to 1. - Default: 5. - last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.025 if epoch == 0 - >>> # lr = 0.03125 if epoch == 1 - >>> # lr = 0.0375 if epoch == 2 - >>> # lr = 0.04375 if epoch == 3 - >>> # lr = 0.05 if epoch >= 4 - >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, - verbose=False): - if start_factor > 1.0 or start_factor <= 0: - raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.') - - if end_factor > 1.0 or end_factor < 0: - raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') - - self.start_factor = start_factor - self.end_factor = end_factor - self.total_iters = total_iters - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] - - if self.last_epoch > self.total_iters: - return [group['lr'] for group in self.optimizer.param_groups] - - return [group['lr'] * (1. + (self.end_factor - self.start_factor) / - (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * (self.start_factor + - (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) - for base_lr in self.base_lrs] - - -class ExponentialLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma every epoch. - When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - gamma (float): Multiplicative factor of learning rate decay. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. 
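The epoch-by-epoch values listed in the LinearLR docstring above follow directly from the closed form in _get_closed_form_lr: base_lr * (start_factor + (end_factor - start_factor) * min(total_iters, epoch) / total_iters). A quick plain-Python check; the function name is illustrative only:

    base_lr, start_factor, end_factor, total_iters = 0.05, 0.5, 1.0, 4

    def linear_lr(epoch):
        # Closed form used by LinearLR._get_closed_form_lr above.
        factor = start_factor + (end_factor - start_factor) * min(total_iters, epoch) / total_iters
        return base_lr * factor

    print([round(linear_lr(e), 5) for e in range(6)])
    # [0.025, 0.03125, 0.0375, 0.04375, 0.05, 0.05] -- the schedule quoted in the docstring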
- """ - - def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False): - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * self.gamma ** self.last_epoch - for base_lr in self.base_lrs] - - -class SequentialLR(LRScheduler): - """Receives the list of schedulers that is expected to be called sequentially during - optimization process and milestone points that provides exact intervals to reflect - which scheduler is supposed to be called at a given epoch. - - Args: - optimizer (Optimizer): Wrapped optimizer. - schedulers (list): List of chained schedulers. - milestones (list): List of integers that reflects milestone points. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): Does nothing. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 1. for all groups - >>> # lr = 0.1 if epoch == 0 - >>> # lr = 0.1 if epoch == 1 - >>> # lr = 0.9 if epoch == 2 - >>> # lr = 0.81 if epoch == 3 - >>> # lr = 0.729 if epoch == 4 - >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) - >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) - >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): - for scheduler_idx in range(len(schedulers)): - if schedulers[scheduler_idx].optimizer != optimizer: - raise ValueError( - "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " - f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in." - ) - - if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): - raise ValueError( - "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " - f"got schedulers at index {0} and {scheduler_idx} to be different." 
- ) - if (len(milestones) != len(schedulers) - 1): - raise ValueError( - "Sequential Schedulers expects number of schedulers provided to be one more " - "than the number of milestone points, but got number of schedulers {} and the " - "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) - ) - self._schedulers = schedulers - self._milestones = milestones - self.last_epoch = last_epoch + 1 - self.optimizer = optimizer - - # Reset learning rates back to initial values - for group in self.optimizer.param_groups: - group["lr"] = group["initial_lr"] - - # "Undo" the step performed by other schedulers - for scheduler in self._schedulers: - scheduler.last_epoch -= 1 - - # Perform the initial step for only the first scheduler - self._schedulers[0]._initial_step() - - self._last_lr = schedulers[0].get_last_lr() - - def step(self): - self.last_epoch += 1 - idx = bisect_right(self._milestones, self.last_epoch) - scheduler = self._schedulers[idx] - if idx > 0 and self._milestones[idx - 1] == self.last_epoch: - scheduler.step(0) - else: - scheduler.step() - - self._last_lr = scheduler.get_last_lr() - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The wrapped scheduler states will also be saved. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} - state_dict['_schedulers'] = [None] * len(self._schedulers) - - for idx, s in enumerate(self._schedulers): - state_dict['_schedulers'][idx] = s.state_dict() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - _schedulers = state_dict.pop('_schedulers') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['_schedulers'] = _schedulers - - for idx, s in enumerate(_schedulers): - self._schedulers[idx].load_state_dict(s) - - -class PolynomialLR(LRScheduler): - """Decays the learning rate of each parameter group using a polynomial function - in the given total_iters. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. - power (int): The power of the polynomial. Default: 1.0. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP("undefined vars") - >>> # Assuming optimizer uses lr = 0.001 for all groups - >>> # lr = 0.001 if epoch == 0 - >>> # lr = 0.00075 if epoch == 1 - >>> # lr = 0.00050 if epoch == 2 - >>> # lr = 0.00025 if epoch == 3 - >>> # lr = 0.0 if epoch >= 4 - >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) 
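The milestone hand-off in SequentialLR.step() above reduces to a bisect_right lookup over the milestone list; a minimal sketch with arbitrary milestone values:

from bisect import bisect_right

milestones = [2, 6]          # hand-off points (illustrative values)
for epoch in range(8):
    idx = bisect_right(milestones, epoch)   # index of the scheduler active at this epoch
    print(epoch, "-> scheduler", idx)
# epochs 0-1 use scheduler 0, epochs 2-5 use scheduler 1, epochs 6+ use scheduler 2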
- >>> scheduler.step() - """ - def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False): - self.total_iters = total_iters - self.power = power - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0 or self.last_epoch > self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] - - decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power - return [group["lr"] * decay_factor for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [ - ( - base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power - ) - for base_lr in self.base_lrs - ] - - -class CosineAnnealingLR(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - When last_epoch=-1, sets initial lr as lr. Notice that because the schedule - is defined recursively, the learning rate can be simultaneously modified - outside this scheduler by other operators. If the learning rate is set - solely by this scheduler, the learning rate at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only - implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer): Wrapped optimizer. - T_max (int): Maximum number of iterations. - eta_min (float): Minimum learning rate. Default: 0. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False): - self.T_max = T_max - self.eta_min = eta_min - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] for group in self.optimizer.param_groups] - elif self._step_count == 1 and self.last_epoch > 0: - return [self.eta_min + (base_lr - self.eta_min) * - (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] - elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: - return [group['lr'] + (base_lr - self.eta_min) * - (1 - math.cos(math.pi / self.T_max)) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] - return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / - (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * - (group['lr'] - self.eta_min) + self.eta_min - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [self.eta_min + (base_lr - self.eta_min) * - (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 - for base_lr in self.base_lrs] - - -class ChainedScheduler(LRScheduler): - """Chains list of learning rate schedulers. It takes a list of chainable learning - rate schedulers and performs consecutive step() functions belonging to them by just - one call. - - Args: - schedulers (list): List of chained schedulers. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 1. for all groups - >>> # lr = 0.09 if epoch == 0 - >>> # lr = 0.081 if epoch == 1 - >>> # lr = 0.729 if epoch == 2 - >>> # lr = 0.6561 if epoch == 3 - >>> # lr = 0.59049 if epoch >= 4 - >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) - >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) - >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, schedulers): - for scheduler_idx in range(1, len(schedulers)): - if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): - raise ValueError( - "ChainedScheduler expects all schedulers to belong to the same optimizer, but " - "got schedulers at index {} and {} to be different".format(0, scheduler_idx) - ) - self._schedulers = list(schedulers) - self.optimizer = schedulers[0].optimizer - self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] - - def step(self): - for scheduler in self._schedulers: - scheduler.step() - self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The wrapped scheduler states will also be saved. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} - state_dict['_schedulers'] = [None] * len(self._schedulers) - - for idx, s in enumerate(self._schedulers): - state_dict['_schedulers'][idx] = s.state_dict() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. 
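The closed form of the cosine annealing schedule above can be sanity-checked with a plain function (illustrative, not the scheduler class itself):

import math

def cosine_lr(base_lr, t, T_max, eta_min=0.0):
    # eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2

# base lr 0.1, T_max = 10: halfway down at t=5, reaches eta_min at t=T_max
print([round(cosine_lr(0.1, t, 10), 4) for t in (0, 5, 10)])   # [0.1, 0.05, 0.0]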
- - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - _schedulers = state_dict.pop('_schedulers') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['_schedulers'] = _schedulers - - for idx, s in enumerate(_schedulers): - self._schedulers[idx].load_state_dict(s) - - -class ReduceLROnPlateau: - """Reduce learning rate when a metric has stopped improving. - Models often benefit from reducing the learning rate by a factor - of 2-10 once learning stagnates. This scheduler reads a metrics - quantity and if no improvement is seen for a 'patience' number - of epochs, the learning rate is reduced. - - Args: - optimizer (Optimizer): Wrapped optimizer. - mode (str): One of `min`, `max`. In `min` mode, lr will - be reduced when the quantity monitored has stopped - decreasing; in `max` mode it will be reduced when the - quantity monitored has stopped increasing. Default: 'min'. - factor (float): Factor by which the learning rate will be - reduced. new_lr = lr * factor. Default: 0.1. - patience (int): Number of epochs with no improvement after - which learning rate will be reduced. For example, if - `patience = 2`, then we will ignore the first 2 epochs - with no improvement, and will only decrease the LR after the - 3rd epoch if the loss still hasn't improved then. - Default: 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Default: 1e-4. - threshold_mode (str): One of `rel`, `abs`. In `rel` mode, - dynamic_threshold = best * ( 1 + threshold ) in 'max' - mode or best * ( 1 - threshold ) in `min` mode. - In `abs` mode, dynamic_threshold = best + threshold in - `max` mode or best - threshold in `min` mode. Default: 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after lr has been reduced. Default: 0. - min_lr (float or list): A scalar or a list of scalars. A - lower bound on the learning rate of all param groups - or each group respectively. Default: 0. - eps (float): Minimal decay applied to lr. If the difference - between new and old lr is smaller than eps, the update is - ignored. Default: 1e-8. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = ReduceLROnPlateau(optimizer, 'min') - >>> for epoch in range(10): - >>> train(...) - >>> val_loss = validate(...) 
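The composition performed by ChainedScheduler.step() above is simply a loop over the wrapped schedulers; a toy sketch with stand-in classes (not the real scheduler types):

class ToyScheduler:
    def __init__(self, name):
        self.name = name
    def step(self):
        print(f"{self.name}.step()")

class ToyChained:
    def __init__(self, schedulers):
        self._schedulers = list(schedulers)
    def step(self):
        # one call steps every wrapped scheduler in order
        for s in self._schedulers:
            s.step()

ToyChained([ToyScheduler("constant"), ToyScheduler("exponential")]).step()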
- >>> # Note that step should be called after validate() - >>> scheduler.step(val_loss) - """ - - def __init__(self, optimizer, mode='min', factor=0.1, patience=10, - threshold=1e-4, threshold_mode='rel', cooldown=0, - min_lr=0, eps=1e-8, verbose=False): - - if factor >= 1.0: - raise ValueError('Factor should be < 1.0.') - self.factor = factor - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - if isinstance(min_lr, (list, tuple)): - if len(min_lr) != len(optimizer.param_groups): - raise ValueError("expected {} min_lrs, got {}".format( - len(optimizer.param_groups), len(min_lr))) - self.min_lrs = list(min_lr) - else: - self.min_lrs = [min_lr] * len(optimizer.param_groups) - - self.patience = patience - self.verbose = verbose - self.cooldown = cooldown - self.cooldown_counter = 0 - self.mode = mode - self.threshold = threshold - self.threshold_mode = threshold_mode - self.best = None - self.num_bad_epochs = None - self.mode_worse = None # the worse value for the chosen mode - self.eps = eps - self.last_epoch = 0 - self._init_is_better(mode=mode, threshold=threshold, - threshold_mode=threshold_mode) - self._reset() - - def _reset(self): - """Resets num_bad_epochs counter and cooldown counter.""" - self.best = self.mode_worse - self.cooldown_counter = 0 - self.num_bad_epochs = 0 - - def step(self, metrics, epoch=None): - # convert `metrics` to float, in case it's a zero-dim Tensor - current = float(metrics) - if epoch is None: - epoch = self.last_epoch + 1 - else: - warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) - self.last_epoch = epoch - - if self.is_better(current, self.best): - self.best = current - self.num_bad_epochs = 0 - else: - self.num_bad_epochs += 1 - - if self.in_cooldown: - self.cooldown_counter -= 1 - self.num_bad_epochs = 0 # ignore any bad epochs in cooldown - - if self.num_bad_epochs > self.patience: - self._reduce_lr(epoch) - self.cooldown_counter = self.cooldown - self.num_bad_epochs = 0 - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - - def _reduce_lr(self, epoch): - for i, param_group in enumerate(self.optimizer.param_groups): - old_lr = float(param_group['lr']) - new_lr = max(old_lr * self.factor, self.min_lrs[i]) - if old_lr - new_lr > self.eps: - param_group['lr'] = new_lr - if self.verbose: - epoch_str = ("%.2f" if isinstance(epoch, float) else - "%.5d") % epoch - print('Epoch {}: reducing learning rate' - ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr)) - - @property - def in_cooldown(self): - return self.cooldown_counter > 0 - - def is_better(self, a, best): - if self.mode == 'min' and self.threshold_mode == 'rel': - rel_epsilon = 1. - self.threshold - return a < best * rel_epsilon - - elif self.mode == 'min' and self.threshold_mode == 'abs': - return a < best - self.threshold - - elif self.mode == 'max' and self.threshold_mode == 'rel': - rel_epsilon = self.threshold + 1. 
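The relative-threshold comparison used by is_better() can be illustrated with a standalone sketch of the 'min'/'rel' branch (plain function, arbitrary values):

def is_better_min_rel(a, best, threshold=1e-4):
    # 'min' mode with 'rel' threshold: a must undercut best by a relative margin
    return a < best * (1.0 - threshold)

print(is_better_min_rel(0.99995, 1.0))   # False: within the 1e-4 band, not an improvement
print(is_better_min_rel(0.999, 1.0))     # True: a significant improvement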
- return a > best * rel_epsilon - - else: # mode == 'max' and epsilon_mode == 'abs': - return a > best + self.threshold - - def _init_is_better(self, mode, threshold, threshold_mode): - if mode not in {'min', 'max'}: - raise ValueError('mode ' + mode + ' is unknown!') - if threshold_mode not in {'rel', 'abs'}: - raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') - - if mode == 'min': - self.mode_worse = inf - else: # mode == 'max': - self.mode_worse = -inf - - self.mode = mode - self.threshold = threshold - self.threshold_mode = threshold_mode - - def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} - - def load_state_dict(self, state_dict): - self.__dict__.update(state_dict) - self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) - - -class CyclicLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to - cyclical learning rate policy (CLR). The policy cycles the learning - rate between two boundaries with a constant frequency, as detailed in - the paper `Cyclical Learning Rates for Training Neural Networks`_. - The distance between the two boundaries can be scaled on a per-iteration - or per-cycle basis. - - Cyclical learning rate policy changes the learning rate after every batch. - `step` should be called after a batch has been used for training. - - This class has three built-in policies, as put forth in the paper: - - * "triangular": A basic triangular cycle without amplitude scaling. - * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. - * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` - at each cycle iteration. - - This implementation was adapted from the github repo: `bckenstler/CLR`_ - - Args: - optimizer (Optimizer): Wrapped optimizer. - base_lr (float or list): Initial learning rate which is the - lower boundary in the cycle for each parameter group. - max_lr (float or list): Upper learning rate boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_lr - base_lr). - The lr at any cycle is the sum of base_lr - and some scaling of the amplitude; therefore - max_lr may not actually be reached depending on - scaling function. - step_size_up (int): Number of training iterations in the - increasing half of a cycle. Default: 2000 - step_size_down (int): Number of training iterations in the - decreasing half of a cycle. If step_size_down is None, - it is set to step_size_up. Default: None - mode (str): One of {triangular, triangular2, exp_range}. - Values correspond to policies detailed above. - If scale_fn is not None, this argument is ignored. - Default: 'triangular' - gamma (float): Constant in 'exp_range' scaling function: - gamma**(cycle iterations) - Default: 1.0 - scale_fn (function): Custom scaling policy defined by a single - argument lambda function, where - 0 <= scale_fn(x) <= 1 for all x >= 0. - If specified, then 'mode' is ignored. - Default: None - scale_mode (str): {'cycle', 'iterations'}. - Defines whether scale_fn is evaluated on - cycle number or cycle iterations (training - iterations since start of cycle). - Default: 'cycle' - cycle_momentum (bool): If ``True``, momentum is cycled inversely - to learning rate between 'base_momentum' and 'max_momentum'. - Default: True - base_momentum (float or list): Lower momentum boundaries in the cycle - for each parameter group. 
Note that momentum is cycled inversely - to learning rate; at the peak of a cycle, momentum is - 'base_momentum' and learning rate is 'max_lr'. - Default: 0.8 - max_momentum (float or list): Upper momentum boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_momentum - base_momentum). - The momentum at any cycle is the difference of max_momentum - and some scaling of the amplitude; therefore - base_momentum may not actually be reached depending on - scaling function. Note that momentum is cycled inversely - to learning rate; at the start of a cycle, momentum is 'max_momentum' - and learning rate is 'base_lr' - Default: 0.9 - last_epoch (int): The index of the last batch. This parameter is used when - resuming a training job. Since `step()` should be invoked after each - batch instead of after each epoch, this number represents the total - number of *batches* computed, not the total number of epochs computed. - When last_epoch=-1, the schedule is started from the beginning. - Default: -1 - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) - >>> data_loader = torch.utils.data.DataLoader(...) - >>> for epoch in range(10): - >>> for batch in data_loader: - >>> train_batch(...) - >>> scheduler.step() - - - .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 - .. _bckenstler/CLR: https://github.com/bckenstler/CLR - """ - - def __init__(self, - optimizer, - base_lr, - max_lr, - step_size_up=2000, - step_size_down=None, - mode='triangular', - gamma=1., - scale_fn=None, - scale_mode='cycle', - cycle_momentum=True, - base_momentum=0.8, - max_momentum=0.9, - last_epoch=-1, - verbose=False): - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - base_lrs = self._format_param('base_lr', optimizer, base_lr) - if last_epoch == -1: - for lr, group in zip(base_lrs, optimizer.param_groups): - group['lr'] = lr - - self.max_lrs = self._format_param('max_lr', optimizer, max_lr) - - step_size_up = float(step_size_up) - step_size_down = float(step_size_down) if step_size_down is not None else step_size_up - self.total_size = step_size_up + step_size_down - self.step_ratio = step_size_up / self.total_size - - if mode not in ['triangular', 'triangular2', 'exp_range'] \ - and scale_fn is None: - raise ValueError('mode is invalid and scale_fn is None') - - self.mode = mode - self.gamma = gamma - - self._scale_fn_ref = None - self._scale_fn_custom = scale_fn - self.scale_mode = scale_mode - self._init_scale_fn() - - self.cycle_momentum = cycle_momentum - if cycle_momentum: - if 'momentum' not in optimizer.defaults: - raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') - - base_momentums = self._format_param('base_momentum', optimizer, base_momentum) - if last_epoch == -1: - for momentum, group in zip(base_momentums, optimizer.param_groups): - group['momentum'] = momentum - self.base_momentums = [group['momentum'] for group in optimizer.param_groups] - self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) - - super().__init__(optimizer, last_epoch, verbose) - self.base_lrs = base_lrs - - def 
_init_scale_fn(self): - if self._scale_fn_custom is not None: - return - if self.mode == 'triangular': - self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn) - self.scale_mode = 'cycle' - elif self.mode == 'triangular2': - self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn) - self.scale_mode = 'cycle' - elif self.mode == 'exp_range': - self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn) - self.scale_mode = 'iterations' - - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError("expected {} values for {}, got {}".format( - len(optimizer.param_groups), name, len(param))) - return param - else: - return [param] * len(optimizer.param_groups) - - def scale_fn(self, x): - if self._scale_fn_custom is not None: - return self._scale_fn_custom(x) - - else: - return self._scale_fn_ref()(x) - - def _triangular_scale_fn(self, x): - return 1. - - def _triangular2_scale_fn(self, x): - return 1 / (2. ** (x - 1)) - - def _exp_range_scale_fn(self, x): - return self.gamma**(x) - - def get_lr(self): - """Calculates the learning rate at batch index. This function treats - `self.last_epoch` as the last batch index. - - If `self.cycle_momentum` is ``True``, this function has a side effect of - updating the optimizer's momentum. - """ - - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - cycle = math.floor(1 + self.last_epoch / self.total_size) - x = 1. + self.last_epoch / self.total_size - cycle - if x <= self.step_ratio: - scale_factor = x / self.step_ratio - else: - scale_factor = (x - 1) / (self.step_ratio - 1) - - lrs = [] - for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): - base_height = (max_lr - base_lr) * scale_factor - if self.scale_mode == 'cycle': - lr = base_lr + base_height * self.scale_fn(cycle) - else: - lr = base_lr + base_height * self.scale_fn(self.last_epoch) - lrs.append(lr) - - if self.cycle_momentum: - momentums = [] - for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): - base_height = (max_momentum - base_momentum) * scale_factor - if self.scale_mode == 'cycle': - momentum = max_momentum - base_height * self.scale_fn(cycle) - else: - momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) - momentums.append(momentum) - for param_group, momentum in zip(self.optimizer.param_groups, momentums): - param_group['momentum'] = momentum - - return lrs - - def state_dict(self): - state = super().state_dict() - # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled - state.pop("_scale_fn_ref") - return state - - def load_state_dict(self, state_dict): - super().load_state_dict(state_dict) - self._init_scale_fn() - - - -class CosineAnnealingWarmRestarts(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` - is the number of epochs since the last restart and :math:`T_{i}` is the number - of epochs between two warm restarts in SGDR: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) - - When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 
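The cycle arithmetic in CyclicLR.get_lr() above can be traced with a standalone 'triangular'-mode sketch (plain function, scale factor fixed at 1, values illustrative):

import math

def triangular_lr(base_lr, max_lr, step, step_size_up=4, step_size_down=4):
    total_size = step_size_up + step_size_down
    step_ratio = step_size_up / total_size
    cycle = math.floor(1 + step / total_size)
    x = 1.0 + step / total_size - cycle
    # rising half: scale goes 0 -> 1; falling half: 1 -> 0
    scale = x / step_ratio if x <= step_ratio else (x - 1) / (step_ratio - 1)
    return base_lr + (max_lr - base_lr) * scale

# rises 0.01 -> 0.1 over 4 steps, falls back over the next 4, then repeats
print([round(triangular_lr(0.01, 0.1, s), 4) for s in range(9)])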
- When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. - - Args: - optimizer (Optimizer): Wrapped optimizer. - T_0 (int): Number of iterations for the first restart. - T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. - eta_min (float, optional): Minimum learning rate. Default: 0. - last_epoch (int, optional): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False): - if T_0 <= 0 or not isinstance(T_0, int): - raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) - if T_mult < 1 or not isinstance(T_mult, int): - raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) - self.T_0 = T_0 - self.T_i = T_0 - self.T_mult = T_mult - self.eta_min = eta_min - self.T_cur = last_epoch - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 - for base_lr in self.base_lrs] - - def step(self, epoch=None): - """Step could be called after every batch update - - Example: - >>> # xdoctest: +SKIP("Undefined vars") - >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) - >>> iters = len(dataloader) - >>> for epoch in range(20): - >>> for i, sample in enumerate(dataloader): - >>> inputs, labels = sample['inputs'], sample['labels'] - >>> optimizer.zero_grad() - >>> outputs = net(inputs) - >>> loss = criterion(outputs, labels) - >>> loss.backward() - >>> optimizer.step() - >>> scheduler.step(epoch + i / iters) - - This function can be called in an interleaved way. 
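How the restart period grows with T_mult can be seen from a few lines of arithmetic (values illustrative; this is not the scheduler's own bookkeeping):

T_0, T_mult = 10, 2
T_i, restart_epochs, epoch = T_0, [], 0
for _ in range(4):
    epoch += T_i                 # a restart happens after T_i epochs
    restart_epochs.append(epoch)
    T_i *= T_mult                # the next period is T_mult times longer
print(restart_epochs)            # [10, 30, 70, 150]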
- - Example: - >>> # xdoctest: +SKIP("Undefined vars") - >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) - >>> for epoch in range(20): - >>> scheduler.step() - >>> scheduler.step(26) - >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) - """ - - if epoch is None and self.last_epoch < 0: - epoch = 0 - - if epoch is None: - epoch = self.last_epoch + 1 - self.T_cur = self.T_cur + 1 - if self.T_cur >= self.T_i: - self.T_cur = self.T_cur - self.T_i - self.T_i = self.T_i * self.T_mult - else: - if epoch < 0: - raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) - if epoch >= self.T_0: - if self.T_mult == 1: - self.T_cur = epoch % self.T_0 - else: - n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) - self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) - self.T_i = self.T_0 * self.T_mult ** (n) - else: - self.T_i = self.T_0 - self.T_cur = epoch - self.last_epoch = math.floor(epoch) - - class _enable_get_lr_call: - - def __init__(self, o): - self.o = o - - def __enter__(self): - self.o._get_lr_called_within_step = True - return self - - def __exit__(self, type, value, traceback): - self.o._get_lr_called_within_step = False - return self - - with _enable_get_lr_call(self): - for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): - param_group, lr = data - param_group['lr'] = lr - self.print_lr(self.verbose, i, lr, epoch) - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - - -class OneCycleLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to the - 1cycle learning rate policy. The 1cycle policy anneals the learning - rate from an initial learning rate to some maximum learning rate and then - from that maximum learning rate to some minimum learning rate much lower - than the initial learning rate. - This policy was initially described in the paper `Super-Convergence: - Very Fast Training of Neural Networks Using Large Learning Rates`_. - - The 1cycle learning rate policy changes the learning rate after every batch. - `step` should be called after a batch has been used for training. - - This scheduler is not chainable. - - Note also that the total number of steps in the cycle can be determined in one - of two ways (listed in order of precedence): - - #. A value for total_steps is explicitly provided. - #. A number of epochs (epochs) and a number of steps per epoch - (steps_per_epoch) are provided. - In this case, the number of total steps is inferred by - total_steps = epochs * steps_per_epoch - - You must either provide a value for total_steps or provide a value for both - epochs and steps_per_epoch. - - The default behaviour of this scheduler follows the fastai implementation of 1cycle, which - claims that "unpublished work has shown even better results by using only two phases". To - mimic the behaviour of the original paper instead, set ``three_phase=True``. - - Args: - optimizer (Optimizer): Wrapped optimizer. - max_lr (float or list): Upper learning rate boundaries in the cycle - for each parameter group. - total_steps (int): The total number of steps in the cycle. Note that - if a value is not provided here, then it must be inferred by providing - a value for epochs and steps_per_epoch. - Default: None - epochs (int): The number of epochs to train for. This is used along - with steps_per_epoch in order to infer the total number of steps in the cycle - if a value for total_steps is not provided. 
- Default: None - steps_per_epoch (int): The number of steps per epoch to train for. This is - used along with epochs in order to infer the total number of steps in the - cycle if a value for total_steps is not provided. - Default: None - pct_start (float): The percentage of the cycle (in number of steps) spent - increasing the learning rate. - Default: 0.3 - anneal_strategy (str): {'cos', 'linear'} - Specifies the annealing strategy: "cos" for cosine annealing, "linear" for - linear annealing. - Default: 'cos' - cycle_momentum (bool): If ``True``, momentum is cycled inversely - to learning rate between 'base_momentum' and 'max_momentum'. - Default: True - base_momentum (float or list): Lower momentum boundaries in the cycle - for each parameter group. Note that momentum is cycled inversely - to learning rate; at the peak of a cycle, momentum is - 'base_momentum' and learning rate is 'max_lr'. - Default: 0.85 - max_momentum (float or list): Upper momentum boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_momentum - base_momentum). - Note that momentum is cycled inversely - to learning rate; at the start of a cycle, momentum is 'max_momentum' - and learning rate is 'base_lr' - Default: 0.95 - div_factor (float): Determines the initial learning rate via - initial_lr = max_lr/div_factor - Default: 25 - final_div_factor (float): Determines the minimum learning rate via - min_lr = initial_lr/final_div_factor - Default: 1e4 - three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the - learning rate according to 'final_div_factor' instead of modifying the second - phase (the first two phases will be symmetrical about the step indicated by - 'pct_start'). - last_epoch (int): The index of the last batch. This parameter is used when - resuming a training job. Since `step()` should be invoked after each - batch instead of after each epoch, this number represents the total - number of *batches* computed, not the total number of epochs computed. - When last_epoch=-1, the schedule is started from the beginning. - Default: -1 - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> data_loader = torch.utils.data.DataLoader(...) - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) - >>> for epoch in range(10): - >>> for batch in data_loader: - >>> train_batch(...) - >>> optimizer.step() - >>> scheduler.step() - - - .. 
_Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: - https://arxiv.org/abs/1708.07120 - """ - def __init__(self, - optimizer, - max_lr, - total_steps=None, - epochs=None, - steps_per_epoch=None, - pct_start=0.3, - anneal_strategy='cos', - cycle_momentum=True, - base_momentum=0.85, - max_momentum=0.95, - div_factor=25., - final_div_factor=1e4, - three_phase=False, - last_epoch=-1, - verbose=False): - - # Validate optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - # Validate total_steps - if total_steps is None and epochs is None and steps_per_epoch is None: - raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") - elif total_steps is not None: - if total_steps <= 0 or not isinstance(total_steps, int): - raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps)) - self.total_steps = total_steps - else: - if epochs <= 0 or not isinstance(epochs, int): - raise ValueError("Expected positive integer epochs, but got {}".format(epochs)) - if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): - raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch)) - self.total_steps = epochs * steps_per_epoch - - if three_phase: - self._schedule_phases = [ - { - 'end_step': float(pct_start * self.total_steps) - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'max_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'base_momentum', - }, - { - 'end_step': float(2 * pct_start * self.total_steps) - 2, - 'start_lr': 'max_lr', - 'end_lr': 'initial_lr', - 'start_momentum': 'base_momentum', - 'end_momentum': 'max_momentum', - }, - { - 'end_step': self.total_steps - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'min_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'max_momentum', - }, - ] - else: - self._schedule_phases = [ - { - 'end_step': float(pct_start * self.total_steps) - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'max_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'base_momentum', - }, - { - 'end_step': self.total_steps - 1, - 'start_lr': 'max_lr', - 'end_lr': 'min_lr', - 'start_momentum': 'base_momentum', - 'end_momentum': 'max_momentum', - }, - ] - - # Validate pct_start - if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): - raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) - - # Validate anneal_strategy - if anneal_strategy not in ['cos', 'linear']: - raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) - elif anneal_strategy == 'cos': - self.anneal_func = self._annealing_cos - elif anneal_strategy == 'linear': - self.anneal_func = self._annealing_linear - - # Initialize learning rate variables - max_lrs = self._format_param('max_lr', self.optimizer, max_lr) - if last_epoch == -1: - for idx, group in enumerate(self.optimizer.param_groups): - group['initial_lr'] = max_lrs[idx] / div_factor - group['max_lr'] = max_lrs[idx] - group['min_lr'] = group['initial_lr'] / final_div_factor - - # Initialize momentum variables - self.cycle_momentum = cycle_momentum - if self.cycle_momentum: - if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: - raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') - self.use_beta1 = 'betas' in self.optimizer.defaults - 
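The step budget and phase boundaries assembled in __init__ above follow directly from epochs, steps_per_epoch and pct_start; a quick check with illustrative numbers for the default two-phase schedule:

epochs, steps_per_epoch, pct_start = 10, 100, 0.3
total_steps = epochs * steps_per_epoch
# warm-up ends at pct_start * total_steps - 1, annealing runs to total_steps - 1
phase_ends = [float(pct_start * total_steps) - 1, total_steps - 1]
print(total_steps, phase_ends)   # 1000 [299.0, 999]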
max_momentums = self._format_param('max_momentum', optimizer, max_momentum) - base_momentums = self._format_param('base_momentum', optimizer, base_momentum) - if last_epoch == -1: - for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): - if self.use_beta1: - group['betas'] = (m_momentum, *group['betas'][1:]) - else: - group['momentum'] = m_momentum - group['max_momentum'] = m_momentum - group['base_momentum'] = b_momentum - - super().__init__(optimizer, last_epoch, verbose) - - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError("expected {} values for {}, got {}".format( - len(optimizer.param_groups), name, len(param))) - return param - else: - return [param] * len(optimizer.param_groups) - - def _annealing_cos(self, start, end, pct): - "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." - cos_out = math.cos(math.pi * pct) + 1 - return end + (start - end) / 2.0 * cos_out - - def _annealing_linear(self, start, end, pct): - "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." - return (end - start) * pct + start - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - lrs = [] - step_num = self.last_epoch - - if step_num > self.total_steps: - raise ValueError("Tried to step {} times. The specified number of total steps is {}" - .format(step_num, self.total_steps)) - - for group in self.optimizer.param_groups: - start_step = 0 - for i, phase in enumerate(self._schedule_phases): - end_step = phase['end_step'] - if step_num <= end_step or i == len(self._schedule_phases) - 1: - pct = (step_num - start_step) / (end_step - start_step) - computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct) - if self.cycle_momentum: - computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct) - break - start_step = phase['end_step'] - - lrs.append(computed_lr) - if self.cycle_momentum: - if self.use_beta1: - group['betas'] = (computed_momentum, *group['betas'][1:]) - else: - group['momentum'] = computed_momentum - - return lrs \ No newline at end of file diff --git a/python/jittor/compatibility/src/jtorch_core.cc b/python/jittor/compatibility/src/jtorch_core.cc deleted file mode 100644 index 1102b107..00000000 --- a/python/jittor/compatibility/src/jtorch_core.cc +++ /dev/null @@ -1,102 +0,0 @@ - -#include "pyjt/py_obj_holder.h" -#include "utils/str_utils.h" -#include "jtorch_core.h" -#include "graph.h" -#include "grad.h" -#include "ops/op_register.h" - -namespace jittor { - -void pyjt_def_all(PyObject* m); - -EXTERN_LIB void setter_use_cuda(int value); - -Device::Device(const string& name, int ordinal) : name(name) { - if (startswith(name, "cpu")) - setter_use_cuda(0); - else - setter_use_cuda(1); -} - -unordered_map grad_backup; -EXTERN_LIB void (*_var_free_hook)(Var*); -EXTERN_LIB unordered_map* _grad_backup_ptr; - -void jtorch_var_free_hook(Var* v) { - auto iter = grad_backup.find(v->id); - if (iter != grad_backup.end()) { - grad_backup.erase(iter); - } -} - -void jtorch_init() { - _var_free_hook = &jtorch_var_free_hook; - _grad_backup_ptr = &grad_backup; -} - -inline static VarPtr& get_grad(Var* v) { - return grad_backup[v->id]; -} -static auto make_binary = 
get_op_info("binary") - .get_constructor(); - -inline static void add_grad(VarPtr& a, VarPtr&& b) { - if (!a) a = move(b); - else { - a = make_binary(a, b, ns_add); - } -} - - -void grad_set(VarHolder* x, Maybe v) { - if (!v) { - grad_del(x); - return; - } - grad_backup[x->var->id] = v.ptr->var; -} - -Maybe grad_get(VarHolder* x) { - auto iter = grad_backup.find(x->var->id); - if (iter != grad_backup.end()) { - if (!iter->second.ptr) return nullptr; - return new VarHolder(iter->second.ptr); - } - return nullptr; -} - -void grad_del(VarHolder* x) { - auto iter = grad_backup.find(x->var->id); - if (iter != grad_backup.end()) - grad_backup.erase(iter); -} - -void backward(VarHolder* x) { - vector gnodes({x->var}); - bfs_backward(gnodes, [&](Node* node) { - if (node->is_stop_grad()) - return false; - return true; - }); - vector targets; - for (auto* node : gnodes) { - if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad)) - targets.push_back(node->var()); - } - auto grads = grad(x->var, targets); - for (int i=0; im_doc = "Inner c++ core of jtorch"; - jittor::pyjt_def_all(m); -} -PYJT_MODULE_INIT(jtorch_core); diff --git a/python/jittor/compatibility/src/jtorch_core.h b/python/jittor/compatibility/src/jtorch_core.h deleted file mode 100644 index 36de6522..00000000 --- a/python/jittor/compatibility/src/jtorch_core.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once -#include "common.h" -#include "var_holder.h" -#include "misc/fast_shared_ptr.h" - -namespace jittor { - -// @pyjt(device) -// @attrs(heaptype) -struct Device { - string name; - - // @pyjt(__init__) - Device(const string& name, int ordinal=0); - // @pyjt(__get__type, __str__) - inline string get_type() {return name;} - // @pyjt(__get__index) - inline int index() {return 0;} -}; - -// @pyjt(backward) -void backward(VarHolder* x); - -// @pyjt(grad_set) -void grad_set(VarHolder* x, Maybe v); -// @pyjt(grad_get) -Maybe grad_get(VarHolder* x); -// @pyjt(grad_del) -void grad_del(VarHolder* x); - -// @pyjt(retain_grad_set) -inline void retain_grad_set(VarHolder* x, bool v) { - x->var->flags.set(NodeFlags::_th_require_grad, v); -} -// @pyjt(retain_grad_get) -inline bool retain_grad_get(VarHolder* x) { - return x->var->flags.get(NodeFlags::_th_require_grad); -} - -} \ No newline at end of file diff --git a/python/jittor/compatibility/test/test_conflict_func.py b/python/jittor/compatibility/test/test_conflict_func.py deleted file mode 100644 index 97bd7d8f..00000000 --- a/python/jittor/compatibility/test/test_conflict_func.py +++ /dev/null @@ -1,25 +0,0 @@ -import unittest -import numpy as np -import torch -import jittor as jt - -class TestConflictFunc(unittest.TestCase): - def test_max(self): - a = torch.Tensor([1,4,2]) - assert a.max() == 4 - v, k = a.max(dim=0) - assert v==4 and k==1 - - def test_argsort(self): - a = torch.Tensor([1,4,2]) - k = a.argsort() - assert jt.all_equal(k, [0,2,1]) - - with jt.flag_scope(th_mode=0): - k, v = a.argsort() - assert jt.all_equal(k, [0,2,1]) - - - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_function.py b/python/jittor/compatibility/test/test_function.py deleted file mode 100644 index 9959dbae..00000000 --- a/python/jittor/compatibility/test/test_function.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest -import numpy as np -import torch - -class TestFunction(unittest.TestCase): - def test_example1(self): - import jtorch - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y 
- return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - return grad0 * self.y, grad1 * self.x - - a = jtorch.array(3.0) - a.requires_grad = True - b = jtorch.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - (c+d*3).backward() - assert a.grad.data == 4 - assert b.grad.data == 9 - - def test_example2(self): - import jtorch as jt - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - assert grad1 is None - return grad0 * self.y, None - a = jt.array(3.0) - a.requires_grad = True - b = jt.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - d.stop_grad() - da, db = jt.grad(c+d*3, [a, b]) - assert da.data == 4 - assert db.data == 0 - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_misc.py b/python/jittor/compatibility/test/test_misc.py deleted file mode 100644 index 00bf1b70..00000000 --- a/python/jittor/compatibility/test/test_misc.py +++ /dev/null @@ -1,24 +0,0 @@ -import unittest -import numpy as np -import torch - -class TestMisc(unittest.TestCase): - def test_update_grad(self): - class Net(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) - net = Net() - assert(net.a.requires_grad) - net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) - assert(net.a.requires_grad) - - def test_reshape(self): - a = torch.ones(3,3) - a.requires_grad = True - b = torch.reshape(a, [9]) - assert b.requires_grad == True - - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_tutorial.py b/python/jittor/compatibility/test/test_tutorial.py deleted file mode 100644 index 92c087c7..00000000 --- a/python/jittor/compatibility/test/test_tutorial.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -import numpy as np -import os -import subprocess as sp -import sys - -def check_two(cmd, parser=None, checker=None): - jtorch_out = sp.getoutput(cmd) - print("=========JTORCH OUT==========") - print(jtorch_out) - torch_out = sp.getoutput("PYTHONPATH= "+cmd) - print("=========TORCH OUT==========") - print(torch_out) - if parser: - torch_out = parser(torch_out) - jtorch_out = parser(jtorch_out) - if checker: - checker(torch_out, jtorch_out) - else: - assert torch_out == jtorch_out - return jtorch_out, torch_out - -jtorch_path = os.path.join(os.path.dirname(__file__), "..") -# come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html -class TestTutorial(unittest.TestCase): - def test_auto_grad1(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad2(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad3(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py", - parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad4(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - 
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad5(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) - def test_auto_grad6(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad7(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py", - parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad1.py b/python/jittor/compatibility/tutorial/auto_grad1.py deleted file mode 100644 index 60a090ad..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad1.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import math - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create random input and output data -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Randomly initialize weights -a = torch.randn((), device=device, dtype=dtype) -b = torch.randn((), device=device, dtype=dtype) -c = torch.randn((), device=device, dtype=dtype) -d = torch.randn((), device=device, dtype=dtype) - -learning_rate = 1e-6 -for t in range(20000): - # Forward pass: compute predicted y - y_pred = a + b * x + c * x ** 2 + d * x ** 3 - - # Compute and print loss - loss = (y_pred - y).pow(2).sum().item() - if t % 1000 == 999: - print(t, loss) - - # Backprop to compute gradients of a, b, c, d with respect to loss - grad_y_pred = 2.0 * (y_pred - y) - grad_a = grad_y_pred.sum() - grad_b = (grad_y_pred * x).sum() - grad_c = (grad_y_pred * x ** 2).sum() - grad_d = (grad_y_pred * x ** 3).sum() - - # Update weights using gradient descent - a -= learning_rate * grad_a - b -= learning_rate * grad_b - c -= learning_rate * grad_c - d -= learning_rate * grad_d - # print(t, torch.liveness_info()) - # torch.sync_all() - - -print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad2.py b/python/jittor/compatibility/tutorial/auto_grad2.py deleted file mode 100644 index a3bbc9a8..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad2.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create Tensors to hold input and outputs. -# By default, requires_grad=False, which indicates that we do not need to -# compute gradients with respect to these Tensors during the backward pass. -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Create random Tensors for weights. For a third order polynomial, we need -# 4 weights: y = a + b x + c x^2 + d x^3 -# Setting requires_grad=True indicates that we want to compute gradients with -# respect to these Tensors during the backward pass. 
-a = torch.randn((), device=device, dtype=dtype, requires_grad=True) -b = torch.randn((), device=device, dtype=dtype, requires_grad=True) -c = torch.randn((), device=device, dtype=dtype, requires_grad=True) -d = torch.randn((), device=device, dtype=dtype, requires_grad=True) - -learning_rate = 1e-6 -for t in range(20000): - # Forward pass: compute predicted y using operations on Tensors. - y_pred = a + b * x + c * x ** 2 + d * x ** 3 - # print(y_pred.requires_grad) - # y_pred.requires_grad = False - - # Compute and print loss using operations on Tensors. - # Now loss is a Tensor of shape (1,) - # loss.item() gets the scalar value held in the loss. - loss = (y_pred - y).pow(2).sum() - if t % 1000 == 990: - print(t, loss.item()) - - # Use autograd to compute the backward pass. This call will compute the - # gradient of loss with respect to all Tensors with requires_grad=True. - # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding - # the gradient of the loss with respect to a, b, c, d respectively. - # torch.backward(loss) - loss.backward() - - # Manually update weights using gradient descent. Wrap in torch.no_grad() - # because weights have requires_grad=True, but we don't need to track this - # in autograd. - with torch.no_grad(): - a -= learning_rate * a.grad - b -= learning_rate * b.grad - c -= learning_rate * c.grad - d -= learning_rate * d.grad - - # Manually zero the gradients after updating weights - a.grad = None - b.grad = None - c.grad = None - d.grad = None - -print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad3.py b/python/jittor/compatibility/tutorial/auto_grad3.py deleted file mode 100644 index 654ec447..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad3.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -class LegendrePolynomial3(torch.autograd.Function): - """ - We can implement our own custom autograd Functions by subclassing - torch.autograd.Function and implementing the forward and backward passes - which operate on Tensors. - """ - - @staticmethod - def forward(ctx, input): - """ - In the forward pass we receive a Tensor containing the input and return - a Tensor containing the output. ctx is a context object that can be used - to stash information for backward computation. You can cache arbitrary - objects for use in the backward pass using the ctx.save_for_backward method. - """ - ctx.save_for_backward(input) - return 0.5 * (5 * input ** 3 - 3 * input) - - @staticmethod - def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. - """ - input, = ctx.saved_tensors - return grad_output * 1.5 * (5 * input ** 2 - 1) - - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create Tensors to hold input and outputs. -# By default, requires_grad=False, which indicates that we do not need to -# compute gradients with respect to these Tensors during the backward pass. -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Create random Tensors for weights. For this example, we need -# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized -# not too far from the correct result to ensure convergence. 
-# Setting requires_grad=True indicates that we want to compute gradients with -# respect to these Tensors during the backward pass. -a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) -b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) -c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) -d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) - -learning_rate = 5e-6 -for t in range(2000): - # To apply our Function, we use Function.apply method. We alias this as 'P3'. - P3 = LegendrePolynomial3.apply - - # Forward pass: compute predicted y using operations; we compute - # P3 using our custom autograd operation. - y_pred = a + b * P3(c + d * x) - - # Compute and print loss - loss = (y_pred - y).pow(2).sum() - if t % 100 == 99: - print(t, loss.item()) - - # Use autograd to compute the backward pass. - loss.backward() - - # Update weights using gradient descent - with torch.no_grad(): - a -= learning_rate * a.grad - b -= learning_rate * b.grad - c -= learning_rate * c.grad - d -= learning_rate * d.grad - - # Manually zero the gradients after updating weights - a.grad = None - b.grad = None - c.grad = None - d.grad = None - -print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad4.py b/python/jittor/compatibility/tutorial/auto_grad4.py deleted file mode 100644 index 062d0b0e..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad4.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# For this example, the output y is a linear function of (x, x^2, x^3), so -# we can consider it as a linear layer neural network. Let's prepare the -# tensor (x, x^2, x^3). -p = torch.tensor([1, 2, 3]) -xx = x.unsqueeze(-1).pow(p) - -# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape -# (3,), for this case, broadcasting semantics will apply to obtain a tensor -# of shape (2000, 3) - -# Use the nn package to define our model as a sequence of layers. nn.Sequential -# is a Module which contains other Modules, and applies them in sequence to -# produce its output. The Linear Module computes output from input using a -# linear function, and holds internal Tensors for its weight and bias. -# The Flatten layer flatens the output of the linear layer to a 1D tensor, -# to match the shape of `y`. -model = torch.nn.Sequential( - torch.nn.Linear(3, 1), - torch.nn.Flatten(0, 1) -) - -# The nn package also contains definitions of popular loss functions; in this -# case we will use Mean Squared Error (MSE) as our loss function. -loss_fn = torch.nn.MSELoss(reduction='sum') -# print(model[0].weight.requires_grad) - -learning_rate = 1e-6 -for t in range(8000): - - # Forward pass: compute predicted y by passing x to the model. Module objects - # override the __call__ operator so you can call them like functions. When - # doing so you pass a Tensor of input data to the Module and it produces - # a Tensor of output data. - y_pred = model(xx) - - # Compute and print loss. We pass Tensors containing the predicted and true - # values of y, and the loss function returns a Tensor containing the - # loss. - loss = loss_fn(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Zero the gradients before running the backward pass. 
- model.zero_grad() - - # Backward pass: compute gradient of the loss with respect to all the learnable - # parameters of the model. Internally, the parameters of each Module are stored - # in Tensors with requires_grad=True, so this call will compute gradients for - # all learnable parameters in the model. - loss.backward() - - # Update the weights using gradient descent. Each parameter is a Tensor, so - # we can access its gradients like we did before. - with torch.no_grad(): - for param in model.parameters(): - param -= learning_rate * param.grad - -# You can access the first layer of `model` like accessing the first item of a list -linear_layer = model[0] - -# For linear layer, its parameters are stored as `weight` and `bias`. -print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad5_optim.py b/python/jittor/compatibility/tutorial/auto_grad5_optim.py deleted file mode 100644 index 04949320..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad5_optim.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Prepare the input tensor (x, x^2, x^3). -p = torch.tensor([1, 2, 3]) -xx = x.unsqueeze(-1).pow(p) - -# Use the nn package to define our model and loss function. -model = torch.nn.Sequential( - torch.nn.Linear(3, 1), - torch.nn.Flatten(0, 1) -) -loss_fn = torch.nn.MSELoss(reduction='sum') - -# Use the optim package to define an Optimizer that will update the weights of -# the model for us. Here we will use RMSprop; the optim package contains many other -# optimization algorithms. The first argument to the RMSprop constructor tells the -# optimizer which Tensors it should update. -learning_rate = 1e-3 -optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) -for t in range(8000): - # Forward pass: compute predicted y by passing x to the model. - y_pred = model(xx) - - # Compute and print loss. - loss = loss_fn(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Before the backward pass, use the optimizer object to zero all of the - # gradients for the variables it will update (which are the learnable - # weights of the model). This is because by default, gradients are - # accumulated in buffers( i.e, not overwritten) whenever .backward() - # is called. Checkout docs of torch.autograd.backward for more details. 
- optimizer.zero_grad() - - # Backward pass: compute gradient of the loss with respect to model - # parameters - loss.backward() - - # Calling the step function on an Optimizer makes an update to its - # parameters - optimizer.step() - - -linear_layer = model[0] -print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad6_module.py b/python/jittor/compatibility/tutorial/auto_grad6_module.py deleted file mode 100644 index a240e2b5..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad6_module.py +++ /dev/null @@ -1,59 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -class Polynomial3(torch.nn.Module): - def __init__(self): - """ - In the constructor we instantiate four parameters and assign them as - member parameters. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - - def forward(self, x): - """ - In the forward function we accept a Tensor of input data and we must return - a Tensor of output data. We can use Modules defined in the constructor as - well as arbitrary operators on Tensors. - """ - return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - - def string(self): - """ - Just like any class in Python, you can also define custom method on PyTorch modules - """ - return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Construct our model by instantiating the class defined above -model = Polynomial3() - -# Construct our loss function and an Optimizer. The call to model.parameters() -# in the SGD constructor will contain the learnable parameters (defined -# with torch.nn.Parameter) which are members of the model. -criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) -for t in range(8000): - # Forward pass: Compute predicted y by passing x to the model - y_pred = model(x) - - # Compute and print loss - loss = criterion(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Zero gradients, perform a backward pass, and update the weights. - optimizer.zero_grad() - loss.backward() - optimizer.step() - -print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py deleted file mode 100644 index fa954771..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -import random -import torch -import math - - -class DynamicNet(torch.nn.Module): - def __init__(self): - """ - In the constructor we instantiate five parameters and assign them as members. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - self.e = torch.nn.Parameter(torch.randn(())) - - def forward(self, x): - """ - For the forward pass of the model, we randomly choose either 4, 5 - and reuse the e parameter to compute the contribution of these orders. 
- - Since each forward pass builds a dynamic computation graph, we can use normal - Python control-flow operators like loops or conditional statements when - defining the forward pass of the model. - - Here we also see that it is perfectly safe to reuse the same parameter many - times when defining a computational graph. - """ - y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - for exp in range(4, random.randint(4, 6)): - y = y + self.e * x ** exp - return y - - def string(self): - """ - Just like any class in Python, you can also define custom method on PyTorch modules - """ - return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Construct our model by instantiating the class defined above -model = DynamicNet() - -# Construct our loss function and an Optimizer. Training this strange model with -# vanilla stochastic gradient descent is tough, so we use momentum -criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) -for t in range(60000): - # Forward pass: Compute predicted y by passing x to the model - y_pred = model(x) - - # Compute and print loss - loss = criterion(y_pred, y) - if t % 2000 == 1999: - print(t, loss.item()) - - # Zero gradients, perform a backward pass, and update the weights. - optimizer.zero_grad() - loss.backward() - optimizer.step() - # print(torch.liveness_info()) - -print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/quickstart.py b/python/jittor/compatibility/tutorial/quickstart.py deleted file mode 100644 index f0401a9b..00000000 --- a/python/jittor/compatibility/tutorial/quickstart.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -from torch import nn -# from jtorch.utils import DataLoader -from torch.utils.data import DataLoader -from torchvision import datasets -from torchvision.transforms import ToTensor - -# Download training data from open datasets. -training_data = datasets.FashionMNIST( - root="data", - train=True, - download=True, - transform=ToTensor(), -) - -# Download test data from open datasets. -test_data = datasets.FashionMNIST( - root="data", - train=False, - download=True, - transform=ToTensor(), -) - -batch_size = 64 - -# Create data loaders. -train_dataloader = DataLoader(training_data, batch_size=batch_size) -test_dataloader = DataLoader(test_data, batch_size=batch_size) - -print(len(train_dataloader)) -for X, y in test_dataloader: - print(f"Shape of X [N, C, H, W]: {X.shape}") - print(f"Shape of y: {y.shape} {y.dtype}") - break - -# Get cpu or gpu device for training. 
-device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using {device} device") - -# Define model -class NeuralNetwork(nn.Module): - def __init__(self): - super(NeuralNetwork, self).__init__() - self.flatten = nn.Flatten() - self.linear_relu_stack = nn.Sequential( - nn.Linear(28*28, 512), - nn.ReLU(), - nn.Linear(512, 512), - nn.ReLU(), - nn.Linear(512, 10) - ) - - def forward(self, x): - x = self.flatten(x) - logits = self.linear_relu_stack(x) - return logits - -model = NeuralNetwork().to(device) -print(model) - - -loss_fn = nn.CrossEntropyLoss() -optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - -def train(dataloader, model, loss_fn, optimizer): - size = len(dataloader.dataset) - model.train() - for batch, (X, y) in enumerate(dataloader): - X, y = X.to(device), y.to(device) - - # Compute prediction error - pred = model(X) - loss = loss_fn(pred, y) - - # Backpropagation - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if batch % 100 == 0: - loss, current = loss.item(), batch * len(X) - print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") - -def test(dataloader, model, loss_fn): - size = len(dataloader.dataset) - num_batches = len(dataloader) - model.eval() - test_loss, correct = 0, 0 - with torch.no_grad(): - for X, y in dataloader: - X, y = X.to(device), y.to(device) - pred = model(X) - test_loss += loss_fn(pred, y).item() - correct += (pred.argmax(1) == y).type(torch.float).sum().item() - test_loss /= num_batches - correct /= size - print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") - - -epochs = 5 -test(test_dataloader, model, loss_fn) -for t in range(epochs): - print(f"Epoch {t+1}\n-------------------------------") - train(train_dataloader, model, loss_fn, optimizer) - test(test_dataloader, model, loss_fn) -print("Done!") \ No newline at end of file diff --git a/python/jittor/compatibility/utils/__init__.py b/python/jittor/compatibility/utils/__init__.py deleted file mode 100644 index ac2c2bd8..00000000 --- a/python/jittor/compatibility/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -cpp_extension = None -_flatten_dense_tensors = None -_unflatten_dense_tensors = None - -tensorboard = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/_pytree.py b/python/jittor/compatibility/utils/_pytree.py deleted file mode 100644 index c3118964..00000000 --- a/python/jittor/compatibility/utils/_pytree.py +++ /dev/null @@ -1,3 +0,0 @@ -#TODO: Implement this -_register_pytree_node = None -_dict_flatten = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/checkpoint.py b/python/jittor/compatibility/utils/checkpoint.py deleted file mode 100644 index ba3c3e8e..00000000 --- a/python/jittor/compatibility/utils/checkpoint.py +++ /dev/null @@ -1,8 +0,0 @@ -detach_variable = None - - -def checkpoint( - *args, - **kwargs -): - pass diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py deleted file mode 100644 index 5fcfcaa6..00000000 --- a/python/jittor/compatibility/utils/data.py +++ /dev/null @@ -1,137 +0,0 @@ -import jittor as jt -import jittor.dataset -from jittor.dataset import Dataset as JDataset - -from collections import namedtuple -from typing import Any, Callable, Iterable, Optional, Sequence, Union - - -class Dataset: - def __getitem__(self, index): - raise NotImplementedError - -class IterableDataset: - def __iter__(self): - raise NotImplementedError - - -class DataLoader(JDataset): - def __init__(self, dataset, - 
batch_size: Optional[int] = 1, - shuffle: Optional[bool] = False, - sampler = None, - batch_sampler = None, - num_workers: int = 0, - collate_fn = None, - pin_memory: bool = False, - drop_last: bool = False, - timeout: float = 0, - worker_init_fn = None, - multiprocessing_context=None, - generator=None, - *, prefetch_factor: int = 2, - persistent_workers: bool = False, - pin_memory_device: str = "") -> None: - super().__init__(batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - drop_last=drop_last) - - unsupported_kwargs = { - "batch_sampler": batch_sampler, - "pin_memory": pin_memory, - "timeout": timeout, - "worker_init_fn": worker_init_fn, - "multiprocessing_context": multiprocessing_context, - "generator": generator, - "persistent_workers": persistent_workers, - "pin_memory_device": pin_memory_device - } - for kwarg, value in unsupported_kwargs.items(): - if value: - jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}") - - self.dataset = dataset - self.collate_fn = collate_fn - self.sampler = sampler - - if not isinstance(dataset, IterableDataset): - self.total_len = len(dataset) - else: - # TODO: support multiple worker for iterable dataset - assert(num_workers == 0) - - def collate_batch(self, batch): - if self.collate_fn is not None: - return self.collate_fn(batch) - else: - return super().collate_batch(batch) - - def __getitem__(self, i): - return self.dataset[i] - - def __iter__(self): - if isinstance(self.dataset, IterableDataset): - return self.inner_iter() - else: - return super().__iter__() - - def inner_iter(self): - current_batch = [] - - if jt.world_size > 1: - assert self.batch_size % jt.world_size == 0, \ - f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}" - real_batch_size = int(self.batch_size / jt.world_size) - else: - real_batch_size = self.batch_size - - for element in self.dataset: - current_batch.append(element) - - if len(current_batch) == real_batch_size: - current_batch = self.collate_batch(current_batch) - current_batch = self.to_jittor(current_batch) - yield current_batch - current_batch = [] - - if not self.drop_last and len(current_batch) > 0: - current_batch = self.collate_batch(current_batch) - yield self.to_jittor(current_batch) - -# def get_worker_info(): -# # always return the fake worker info -# return namedtuple('WorkerInfo', 'id num_workers')(0, 1) - -# class RandomSampler(jt.dataset.RandomSampler): -# def __init__(self, dataset, generator=None, **kwargs): -# super().__init__(dataset, **kwargs) - -# def __iter__(self): -# if getattr(self.dataset, "support_random_access", True): -# return super().__iter__() -# else: -# self.dataset.shuffle() -# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) - -# class DistributedSampler(jt.dataset.Sampler): -# def __init__(self, sampler: RandomSampler): -# assert(isinstance(sampler, RandomSampler)) -# self.sampler = sampler - -# def set_epoch(self, epoch: int): -# ### do nothing, let jittor's inner dataset handle -# pass - -# def __iter__(self): -# return self.sampler.__iter__() - -# def __len__(self): -# return self.sampler.__len__() - -# BatchSampler = jt.dataset.BatchSampler -# Sampler = jt.dataset.Sampler -# SequentialSampler = jt.dataset.SequentialSampler -# SubsetRandomSampler = jt.dataset.SubsetRandomSampler - -# TensorDataset = Dataset diff --git a/python/jittor/compatibility/utils/dtype.py 
b/python/jittor/compatibility/utils/dtype.py deleted file mode 100644 index 41728383..00000000 --- a/python/jittor/compatibility/utils/dtype.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Callable, Union -Dtype = Union[Callable, str] - -def get_string_dtype(dtype): - if callable(dtype): - dtype = dtype.__name__ - if not isinstance(dtype, str): - raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") - return dtype \ No newline at end of file diff --git a/python/jittor/compatibility/utils/hooks.py b/python/jittor/compatibility/utils/hooks.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/jittor/compatibility/utils/pip_publish.py b/python/jittor/compatibility/utils/pip_publish.py deleted file mode 100644 index 72ff245f..00000000 --- a/python/jittor/compatibility/utils/pip_publish.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import glob -import shutil -import sys - -home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") -home_path = os.path.abspath(home_path) - -def callback(func, path, exc_info): - print(f"remove \"{path}\" failed.") - -def rmtree(path): - if os.path.isdir(path): - print(f"remove \"{path}\" recursive.") - shutil.rmtree(path, onerror=callback) - -def remove_tmpfile(): - dist_file = home_path+"/dist" - egg_file = glob.glob(home_path+"/**/*egg-info") - rmtree(dist_file) - for e in egg_file: - rmtree(e) - -def run_cmd(cmd): - print("[CMD]", cmd) - assert os.system(cmd)==0 - -os.chdir(home_path) -remove_tmpfile() - -run_cmd(f"{sys.executable} ./setup.py sdist") -run_cmd(f"{sys.executable} -m twine upload dist/*") - -remove_tmpfile() \ No newline at end of file diff --git a/python/jittor/compatibility/vision/_internally_replaced_utils.py b/python/jittor/compatibility/vision/_internally_replaced_utils.py deleted file mode 100644 index 748fa2ea..00000000 --- a/python/jittor/compatibility/vision/_internally_replaced_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -import importlib.machinery -import os - - -def _download_file_from_remote_location(fpath: str, url: str) -> None: - pass - - -def _is_remote_location_available() -> bool: - return False - - -def _get_extension_path(lib_name): - - lib_dir = os.path.dirname(__file__) - if os.name == "nt": - # Register the main torchvision library location on the default DLL path - import ctypes - import sys - - kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) - with_load_library_flags = hasattr(kernel32, "AddDllDirectory") - prev_error_mode = kernel32.SetErrorMode(0x0001) - - if with_load_library_flags: - kernel32.AddDllDirectory.restype = ctypes.c_void_p - - if sys.version_info >= (3, 8): - os.add_dll_directory(lib_dir) - elif with_load_library_flags: - res = kernel32.AddDllDirectory(lib_dir) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
- raise err - - kernel32.SetErrorMode(prev_error_mode) - - loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) - ext_specs = extfinder.find_spec(lib_name) - if ext_specs is None: - raise ImportError - - return ext_specs.origin \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/__init__.py b/python/jittor/compatibility/vision/datasets/__init__.py deleted file mode 100644 index d04187f1..00000000 --- a/python/jittor/compatibility/vision/datasets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST - -__all__ = ( - "EMNIST", - "FashionMNIST", - "QMNIST", - "MNIST", - "KMNIST", -) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/mnist.py b/python/jittor/compatibility/vision/datasets/mnist.py deleted file mode 100644 index dfc3787b..00000000 --- a/python/jittor/compatibility/vision/datasets/mnist.py +++ /dev/null @@ -1,558 +0,0 @@ -import codecs -import os -import os.path -import shutil -import string -import sys -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple -from urllib.error import URLError - -import numpy as np -import torch -from PIL import Image - -from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg -from .vision import VisionDataset - - -class MNIST(VisionDataset): - """`MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` - and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. 
- """ - - mirrors = [ - "http://yann.lecun.com/exdb/mnist/", - "https://ossci-datasets.s3.amazonaws.com/mnist/", - ] - - resources = [ - ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), - ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), - ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), - ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), - ] - - training_file = "training.pt" - test_file = "test.pt" - classes = [ - "0 - zero", - "1 - one", - "2 - two", - "3 - three", - "4 - four", - "5 - five", - "6 - six", - "7 - seven", - "8 - eight", - "9 - nine", - ] - - @property - def train_labels(self): - warnings.warn("train_labels has been renamed targets") - return self.targets - - @property - def test_labels(self): - warnings.warn("test_labels has been renamed targets") - return self.targets - - @property - def train_data(self): - warnings.warn("train_data has been renamed data") - return self.data - - @property - def test_data(self): - warnings.warn("test_data has been renamed data") - return self.data - - def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, - ) -> None: - super().__init__(root, transform=transform, target_transform=target_transform) - self.train = train # training set or test set - - if self._check_legacy_exist(): - self.data, self.targets = self._load_legacy_data() - return - - if download: - self.download() - - if not self._check_exists(): - raise RuntimeError("Dataset not found. You can use download=True to download it") - - self.data, self.targets = self._load_data() - - def _check_legacy_exist(self): - processed_folder_exists = os.path.exists(self.processed_folder) - if not processed_folder_exists: - return False - - return all( - check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) - ) - - def _load_legacy_data(self): - # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data - # directly. - data_file = self.training_file if self.train else self.test_file - return torch.load(os.path.join(self.processed_folder, data_file)) - - def _load_data(self): - image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" - data = read_image_file(os.path.join(self.raw_folder, image_file)) - - label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" - targets = read_label_file(os.path.join(self.raw_folder, label_file)) - - return data, targets - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """ - Args: - index (int): Index - - Returns: - tuple: (image, target) where target is index of the target class. 
- """ - img, target = self.data[index], int(self.targets[index]) - - # doing this so that it is consistent with all other datasets - # to return a PIL Image - img = Image.fromarray(img.numpy(), mode="L") - - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - - return img, target - - def __len__(self) -> int: - return len(self.data) - - @property - def raw_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, "raw") - - @property - def processed_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, "processed") - - @property - def class_to_idx(self) -> Dict[str, int]: - return {_class: i for i, _class in enumerate(self.classes)} - - def _check_exists(self) -> bool: - return all( - check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) - for url, _ in self.resources - ) - - def download(self) -> None: - """Download the MNIST data if it doesn't exist already.""" - - if self._check_exists(): - return - - os.makedirs(self.raw_folder, exist_ok=True) - - # download files - for filename, md5 in self.resources: - for mirror in self.mirrors: - url = f"{mirror}{filename}" - try: - print(f"Downloading {url}") - download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) - except URLError as error: - print(f"Failed to download (trying next):\n{error}") - continue - finally: - print() - break - else: - raise RuntimeError(f"Error downloading {filename}") - - def extra_repr(self) -> str: - split = "Train" if self.train is True else "Test" - return f"Split: {split}" - - -class FashionMNIST(MNIST): - """`Fashion-MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` - and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] - - resources = [ - ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), - ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), - ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), - ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), - ] - classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] - - -class KMNIST(MNIST): - """`Kuzushiji-MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` - and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. 
- transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] - - resources = [ - ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), - ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), - ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), - ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), - ] - classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] - - -class EMNIST(MNIST): - """`EMNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte`` - and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist. - split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, - ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies - which one to use. - train (bool, optional): If True, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" - md5 = "58c8d27c78d21e728a6bc7b3cc06412e" - splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") - # Merged Classes assumes Same structure for both uppercase and lowercase version - _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} - _all_classes = set(string.digits + string.ascii_letters) - classes_split_dict = { - "byclass": sorted(list(_all_classes)), - "bymerge": sorted(list(_all_classes - _merged_classes)), - "balanced": sorted(list(_all_classes - _merged_classes)), - "letters": ["N/A"] + list(string.ascii_lowercase), - "digits": list(string.digits), - "mnist": list(string.digits), - } - - def __init__(self, root: str, split: str, **kwargs: Any) -> None: - self.split = verify_str_arg(split, "split", self.splits) - self.training_file = self._training_file(split) - self.test_file = self._test_file(split) - super().__init__(root, **kwargs) - self.classes = self.classes_split_dict[self.split] - - @staticmethod - def _training_file(split) -> str: - return f"training_{split}.pt" - - @staticmethod - def _test_file(split) -> str: - return f"test_{split}.pt" - - @property - def _file_prefix(self) -> str: - return f"emnist-{self.split}-{'train' if self.train else 'test'}" - - @property - def images_file(self) -> str: - return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") - - @property - def labels_file(self) -> str: - return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") - - def _load_data(self): - return read_image_file(self.images_file), read_label_file(self.labels_file) - - def _check_exists(self) -> bool: - return all(check_integrity(file) for file in (self.images_file, self.labels_file)) - - def download(self) -> None: - """Download the EMNIST data if it doesn't exist already.""" - - if self._check_exists(): - 
return - - os.makedirs(self.raw_folder, exist_ok=True) - - download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) - gzip_folder = os.path.join(self.raw_folder, "gzip") - for gzip_file in os.listdir(gzip_folder): - if gzip_file.endswith(".gz"): - extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) - shutil.rmtree(gzip_folder) - - -class QMNIST(MNIST): - """`QMNIST `_ Dataset. - - Args: - root (string): Root directory of dataset whose ``raw`` - subdir contains binary files of the datasets. - what (string,optional): Can be 'train', 'test', 'test10k', - 'test50k', or 'nist' for respectively the mnist compatible - training set, the 60k qmnist testing set, the 10k qmnist - examples that match the mnist testing set, the 50k - remaining qmnist testing examples, or all the nist - digits. The default is to select 'train' or 'test' - according to the compatibility argument 'train'. - compat (bool,optional): A boolean that says whether the target - for each example is class number (for compatibility with - the MNIST dataloader) or a torch vector containing the - full qmnist information. Default=True. - download (bool, optional): If True, downloads the dataset from - the internet and puts it in root directory. If dataset is - already downloaded, it is not downloaded again. - transform (callable, optional): A function/transform that - takes in an PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform - that takes in the target and transforms it. - train (bool,optional,compatibility): When argument 'what' is - not specified, this boolean decides whether to load the - training set ot the testing set. Default: True. - """ - - subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} - resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] - "train": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", - "ed72d4157d28c017586c42bc6afe6370", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", - "0058f8dd561b90ffdd0f734c6a30e5e4", - ), - ], - "test": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", - "1394631089c404de565df7b7aeaf9412", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", - "5b5b05890a5e13444e108efe57b788aa", - ), - ], - "nist": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", - "7f124b3b8ab81486c9d8c2749c17f834", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", - "5ed0e788978e45d4a8bd4b7caec3d79d", - ), - ], - } - classes = [ - "0 - zero", - "1 - one", - "2 - two", - "3 - three", - "4 - four", - "5 - five", - "6 - six", - "7 - seven", - "8 - eight", - "9 - nine", - ] - - def __init__( - self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any - ) -> None: - if what is None: - what = "train" if train else "test" - self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) - self.compat = compat - self.data_file = what + ".pt" - self.training_file = self.data_file - self.test_file = self.data_file - super().__init__(root, train, **kwargs) - - @property - def images_file(self) -> str: - (url, _), _ = 
self.resources[self.subsets[self.what]] - return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) - - @property - def labels_file(self) -> str: - _, (url, _) = self.resources[self.subsets[self.what]] - return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) - - def _check_exists(self) -> bool: - return all(check_integrity(file) for file in (self.images_file, self.labels_file)) - - def _load_data(self): - data = read_sn3_pascalvincent_tensor(self.images_file) - if data.dtype != torch.uint8: - raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}") - if data.ndimension() != 3: - raise ValueError("data should have 3 dimensions instead of {data.ndimension()}") - - targets = read_sn3_pascalvincent_tensor(self.labels_file).long() - if targets.ndimension() != 2: - raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}") - - if self.what == "test10k": - data = data[0:10000, :, :].clone() - targets = targets[0:10000, :].clone() - elif self.what == "test50k": - data = data[10000:, :, :].clone() - targets = targets[10000:, :].clone() - - return data, targets - - def download(self) -> None: - """Download the QMNIST data if it doesn't exist already. - Note that we only download what has been asked for (argument 'what'). - """ - if self._check_exists(): - return - - os.makedirs(self.raw_folder, exist_ok=True) - split = self.resources[self.subsets[self.what]] - - for url, md5 in split: - download_and_extract_archive(url, self.raw_folder, md5=md5) - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - # redefined to handle the compat flag - img, target = self.data[index], self.targets[index] - img = Image.fromarray(img.numpy(), mode="L") - if self.transform is not None: - img = self.transform(img) - if self.compat: - target = int(target[0]) - if self.target_transform is not None: - target = self.target_transform(target) - return img, target - - def extra_repr(self) -> str: - return f"Split: {self.what}" - - -def get_int(b: bytes) -> int: - return int(codecs.encode(b, "hex"), 16) - - -SN3_PASCALVINCENT_BITSMAP = { - 8: torch.uint8, - 9: torch.int8, - 11: torch.int16, - 12: torch.int32, - 13: torch.float32, - 14: torch.float64, -} - -TORCH_TYPE_BITS = { - torch.uint8: 8, - torch.int8: 8, - torch.int16: 16, - torch.int32: 32, - torch.float32: 32, - torch.float64: 64, -} - - -def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: - """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). - Argument may be a filename, compressed filename, or file object. - """ - # read - with open(path, "rb") as f: - data = f.read() - # parse - magic = get_int(data[0:4]) - nd = magic % 256 - ty = magic // 256 - assert 1 <= nd <= 3 - assert 8 <= ty <= 14 - torch_type = SN3_PASCALVINCENT_BITSMAP[ty] - s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - - num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8 - # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, - # we need to reverse the bytes before we can read them with torch.frombuffer(). 
- needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 - parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) - if needs_byte_reversal: - parsed = parsed.flip(0) - - assert parsed.shape[0] == np.prod(s) or not strict - return parsed.view(*s) - - -def read_label_file(path: str) -> torch.Tensor: - x = read_sn3_pascalvincent_tensor(path, strict=False) - if x.dtype != torch.uint8: - raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") - if x.ndimension() != 1: - raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}") - return x.long() - - -def read_image_file(path: str) -> torch.Tensor: - x = read_sn3_pascalvincent_tensor(path, strict=False) - if x.dtype != torch.uint8: - raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") - if x.ndimension() != 3: - raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}") - return x \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/utils.py b/python/jittor/compatibility/vision/datasets/utils.py deleted file mode 100644 index f9ae1a89..00000000 --- a/python/jittor/compatibility/vision/datasets/utils.py +++ /dev/null @@ -1,522 +0,0 @@ -import bz2 -import contextlib -import gzip -import hashlib -import itertools -import lzma -import os -import os.path -import pathlib -import re -import sys -import tarfile -import urllib -import urllib.error -import urllib.request -import warnings -import zipfile -from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar -from urllib.parse import urlparse - -import numpy as np -import requests -import torch -from tqdm import tqdm - -from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available - -USER_AGENT = "pytorch/vision" - - -def _save_response_content( - content: Iterator[bytes], - destination: str, - length: Optional[int] = None, -) -> None: - with open(destination, "wb") as fh, tqdm(total=length) as pbar: - for chunk in content: - # filter out keep-alive new chunks - if not chunk: - continue - - fh.write(chunk) - pbar.update(len(chunk)) - - -def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None: - with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: - _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) - - -def gen_bar_updater() -> Callable[[int, int, int], None]: - warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.") - pbar = tqdm(total=None) - - def bar_update(count, block_size, total_size): - if pbar.total is None and total_size: - pbar.total = total_size - progress_bytes = count * block_size - pbar.update(progress_bytes - pbar.n) - - return bar_update - - -def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: - # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are - # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without - # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. 
-    if sys.version_info >= (3, 9):
-        md5 = hashlib.md5(usedforsecurity=False)
-    else:
-        md5 = hashlib.md5()
-    with open(fpath, "rb") as f:
-        for chunk in iter(lambda: f.read(chunk_size), b""):
-            md5.update(chunk)
-    return md5.hexdigest()
-
-
-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
-    return md5 == calculate_md5(fpath, **kwargs)
-
-
-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
-    if not os.path.isfile(fpath):
-        return False
-    if md5 is None:
-        return True
-    return check_md5(fpath, md5)
-
-
-def _get_redirect_url(url: str, max_hops: int = 3) -> str:
-    initial_url = url
-    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
-
-    for _ in range(max_hops + 1):
-        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
-            if response.url == url or response.url is None:
-                return url
-
-            url = response.url
-    else:
-        raise RecursionError(
-            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
-        )
-
-
-def _get_google_drive_file_id(url: str) -> Optional[str]:
-    parts = urlparse(url)
-
-    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
-        return None
-
-    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
-    if match is None:
-        return None
-
-    return match.group("id")
-
-
-def download_url(
-    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
-) -> None:
-    """Download a file from a url and place it in root.
-
-    Args:
-        url (str): URL to download file from
-        root (str): Directory to place downloaded file in
-        filename (str, optional): Name to save the file under. If None, use the basename of the URL
-        md5 (str, optional): MD5 checksum of the download. If None, do not check
-        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
-    """
-    root = os.path.expanduser(root)
-    if not filename:
-        filename = os.path.basename(url)
-    fpath = os.path.join(root, filename)
-
-    os.makedirs(root, exist_ok=True)
-
-    # check if file is already present locally
-    if check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-        return
-
-    if _is_remote_location_available():
-        _download_file_from_remote_location(fpath, url)
-    else:
-        # expand redirect chain if needed
-        url = _get_redirect_url(url, max_hops=max_redirect_hops)
-
-        # check if file is located on Google Drive
-        file_id = _get_google_drive_file_id(url)
-        if file_id is not None:
-            return download_file_from_google_drive(file_id, root, filename, md5)
-
-        # download the file
-        try:
-            print("Downloading " + url + " to " + fpath)
-            _urlretrieve(url, fpath)
-        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
-            if url[:5] == "https":
-                url = url.replace("https:", "http:")
-                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
-                _urlretrieve(url, fpath)
-            else:
-                raise e
-
-    # check integrity of downloaded file
-    if not check_integrity(fpath, md5):
-        raise RuntimeError("File not found or corrupted.")
-
-
-def list_dir(root: str, prefix: bool = False) -> List[str]:
-    """List all directories at a given root
-
-    Args:
-        root (str): Path to directory whose folders need to be listed
-        prefix (bool, optional): If true, prepends the path to each result, otherwise
-            only returns the name of the directories found
-    """
-    root = os.path.expanduser(root)
-    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
-    if prefix is True:
-        directories = [os.path.join(root, d) for d in directories]
-    return directories
-
-
-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
-    """List all files ending with a suffix at a given root
-
-    Args:
-        root (str): Path to directory whose folders need to be listed
-        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
-            It uses the Python "str.endswith" method and is passed directly
-        prefix (bool, optional): If true, prepends the path to each result, otherwise
-            only returns the name of the files found
-    """
-    root = os.path.expanduser(root)
-    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
-    if prefix is True:
-        files = [os.path.join(root, d) for d in files]
-    return files
-
-
-def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
-    content = response.iter_content(chunk_size)
-    first_chunk = None
-    # filter out keep-alive new chunks
-    while not first_chunk:
-        first_chunk = next(content)
-    content = itertools.chain([first_chunk], content)
-
-    try:
-        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
-        api_response = match["api_response"] if match is not None else None
-    except UnicodeDecodeError:
-        api_response = None
-    return api_response, content
-
-
-def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
-    """Download a Google Drive file from and place it in root.
-
-    Args:
-        file_id (str): id of file to be downloaded
-        root (str): Directory to place downloaded file in
-        filename (str, optional): Name to save the file under. If None, use the id of the file.
-        md5 (str, optional): MD5 checksum of the download. If None, do not check
-    """
-    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-
-    root = os.path.expanduser(root)
-    if not filename:
-        filename = file_id
-    fpath = os.path.join(root, filename)
-
-    os.makedirs(root, exist_ok=True)
-
-    if check_integrity(fpath, md5):
-        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
-        return
-
-    url = "https://drive.google.com/uc"
-    params = dict(id=file_id, export="download")
-    with requests.Session() as session:
-        response = session.get(url, params=params, stream=True)
-
-        for key, value in response.cookies.items():
-            if key.startswith("download_warning"):
-                token = value
-                break
-        else:
-            api_response, content = _extract_gdrive_api_response(response)
-            token = "t" if api_response == "Virus scan warning" else None
-
-        if token is not None:
-            response = session.get(url, params=dict(params, confirm=token), stream=True)
-            api_response, content = _extract_gdrive_api_response(response)
-
-        if api_response == "Quota exceeded":
-            raise RuntimeError(
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-
-        _save_response_content(content, fpath)
-
-    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
-    if os.stat(fpath).st_size < 10 * 1024:
-        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
-            text = fh.read()
-            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
-            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
-                warnings.warn(
-                    f"We detected some HTML elements in the downloaded file. "
-                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
-                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
-                    f"the response:\n\n{text}"
-                )
-
-    if md5 and not check_md5(fpath, md5):
-        raise RuntimeError(
-            f"The MD5 checksum of the download file {fpath} does not match the one on record."
-            f"Please delete the file and try again. "
-            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
-        )
-
-
-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
-    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
-        tar.extractall(to_path)
-
-
-_ZIP_COMPRESSION_MAP: Dict[str, int] = {
-    ".bz2": zipfile.ZIP_BZIP2,
-    ".xz": zipfile.ZIP_LZMA,
-}
-
-
-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
-    with zipfile.ZipFile(
-        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
-    ) as zip:
-        zip.extractall(to_path)
-
-
-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
-    ".tar": _extract_tar,
-    ".zip": _extract_zip,
-}
-_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
-    ".bz2": bz2.open,
-    ".gz": gzip.open,
-    ".xz": lzma.open,
-}
-_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
-    ".tbz": (".tar", ".bz2"),
-    ".tbz2": (".tar", ".bz2"),
-    ".tgz": (".tar", ".gz"),
-}
-
-
-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
-    """Detect the archive type and/or compression of a file.
- - Args: - file (str): the filename - - Returns: - (tuple): tuple of suffix, archive type, and compression - - Raises: - RuntimeError: if file has no suffix or suffix is not supported - """ - suffixes = pathlib.Path(file).suffixes - if not suffixes: - raise RuntimeError( - f"File '{file}' has no suffixes that could be used to detect the archive type and compression." - ) - suffix = suffixes[-1] - - # check if the suffix is a known alias - if suffix in _FILE_TYPE_ALIASES: - return (suffix, *_FILE_TYPE_ALIASES[suffix]) - - # check if the suffix is an archive type - if suffix in _ARCHIVE_EXTRACTORS: - return suffix, suffix, None - - # check if the suffix is a compression - if suffix in _COMPRESSED_FILE_OPENERS: - # check for suffix hierarchy - if len(suffixes) > 1: - suffix2 = suffixes[-2] - - # check if the suffix2 is an archive type - if suffix2 in _ARCHIVE_EXTRACTORS: - return suffix2 + suffix, suffix2, suffix - - return suffix, None, suffix - - valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) - raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") - - -def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: - r"""Decompress a file. - - The compression is automatically detected from the file name. - - Args: - from_path (str): Path to the file to be decompressed. - to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. - remove_finished (bool): If ``True``, remove the file after the extraction. - - Returns: - (str): Path to the decompressed file. - """ - suffix, archive_type, compression = _detect_file_type(from_path) - if not compression: - raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") - - if to_path is None: - to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") - - # We don't need to check for a missing key here, since this was already done in _detect_file_type() - compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] - - with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: - wfh.write(rfh.read()) - - if remove_finished: - os.remove(from_path) - - return to_path - - -def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: - """Extract an archive. - - The archive type and a possible compression is automatically detected from the file name. If the file is compressed - but not an archive the call is dispatched to :func:`decompress`. - - Args: - from_path (str): Path to the file to be extracted. - to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is - used. - remove_finished (bool): If ``True``, remove the file after the extraction. - - Returns: - (str): Path to the directory the file was extracted to. 
- """ - if to_path is None: - to_path = os.path.dirname(from_path) - - suffix, archive_type, compression = _detect_file_type(from_path) - if not archive_type: - return _decompress( - from_path, - os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), - remove_finished=remove_finished, - ) - - # We don't need to check for a missing key here, since this was already done in _detect_file_type() - extractor = _ARCHIVE_EXTRACTORS[archive_type] - - extractor(from_path, to_path, compression) - if remove_finished: - os.remove(from_path) - - return to_path - - -def download_and_extract_archive( - url: str, - download_root: str, - extract_root: Optional[str] = None, - filename: Optional[str] = None, - md5: Optional[str] = None, - remove_finished: bool = False, -) -> None: - download_root = os.path.expanduser(download_root) - if extract_root is None: - extract_root = download_root - if not filename: - filename = os.path.basename(url) - - download_url(url, download_root, filename, md5) - - archive = os.path.join(download_root, filename) - print(f"Extracting {archive} to {extract_root}") - extract_archive(archive, extract_root, remove_finished) - - -def iterable_to_str(iterable: Iterable) -> str: - return "'" + "', '".join([str(item) for item in iterable]) + "'" - - -T = TypeVar("T", str, bytes) - - -def verify_str_arg( - value: T, - arg: Optional[str] = None, - valid_values: Optional[Iterable[T]] = None, - custom_msg: Optional[str] = None, -) -> T: - if not isinstance(value, torch._six.string_classes): - if arg is None: - msg = "Expected type str, but got type {type}." - else: - msg = "Expected type str for argument {arg}, but got type {type}." - msg = msg.format(type=type(value), arg=arg) - raise ValueError(msg) - - if valid_values is None: - return value - - if value not in valid_values: - if custom_msg is not None: - msg = custom_msg - else: - msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." - msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) - raise ValueError(msg) - - return value - - -def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: - """Read file in .pfm format. Might contain either 1 or 3 channels of data. - - Args: - file_name (str): Path to the file. - slice_channels (int): Number of channels to slice out of the file. - Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. 
- """ - - with open(file_name, "rb") as f: - header = f.readline().rstrip() - if header not in [b"PF", b"Pf"]: - raise ValueError("Invalid PFM file") - - dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) - if not dim_match: - raise Exception("Malformed PFM header.") - w, h = (int(dim) for dim in dim_match.groups()) - - scale = float(f.readline().rstrip()) - if scale < 0: # little-endian - endian = "<" - scale = -scale - else: - endian = ">" # big-endian - - data = np.fromfile(f, dtype=endian + "f") - - pfm_channels = 3 if header == b"PF" else 1 - - data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) - data = np.flip(data, axis=1) # flip on h dimension - data = data[:slice_channels, :, :] - return data.astype(np.float32) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/vision.py b/python/jittor/compatibility/vision/datasets/vision.py deleted file mode 100644 index d71dc2a5..00000000 --- a/python/jittor/compatibility/vision/datasets/vision.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from typing import Any, Callable, List, Optional, Tuple - -import torch -import torch.utils.data as data - -from ..utils import _log_api_usage_once - - -class VisionDataset(data.Dataset): - """ - Base Class For making datasets which are compatible with torchvision. - It is necessary to override the ``__getitem__`` and ``__len__`` method. - Args: - root (string): Root directory of dataset. - transforms (callable, optional): A function/transforms that takes in - an image and a label and returns the transformed versions of both. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - .. note:: - :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. - """ - - _repr_indent = 4 - - def __init__( - self, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - self.root = root - - has_transforms = transforms is not None - has_separate_transform = transform is not None or target_transform is not None - if has_transforms and has_separate_transform: - raise ValueError("Only transforms or transform/target_transform can be passed as argument") - - # for backwards-compatibility - self.transform = transform - self.target_transform = target_transform - - if has_separate_transform: - transforms = StandardTransform(transform, target_transform) - self.transforms = transforms - - def __getitem__(self, index: int) -> Any: - """ - Args: - index (int): Index - Returns: - (Any): Sample and meta data, optionally transformed by the respective transforms. 
- """ - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - head = "Dataset " + self.__class__.__name__ - body = [f"Number of datapoints: {self.__len__()}"] - if self.root is not None: - body.append(f"Root location: {self.root}") - body += self.extra_repr().splitlines() - if hasattr(self, "transforms") and self.transforms is not None: - body += [repr(self.transforms)] - lines = [head] + [" " * self._repr_indent + line for line in body] - return "\n".join(lines) - - def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: - lines = transform.__repr__().splitlines() - return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] - - def extra_repr(self) -> str: - return "" - - -class StandardTransform: - def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: - self.transform = transform - self.target_transform = target_transform - - def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: - if self.transform is not None: - input = self.transform(input) - if self.target_transform is not None: - target = self.target_transform(target) - return input, target - - def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: - lines = transform.__repr__().splitlines() - return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] - - def __repr__(self) -> str: - body = [self.__class__.__name__] - if self.transform is not None: - body += self._format_transform_repr(self.transform, "Transform: ") - if self.target_transform is not None: - body += self._format_transform_repr(self.target_transform, "Target transform: ") - - return "\n".join(body) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/transforms.py b/python/jittor/compatibility/vision/transforms.py deleted file mode 100644 index 416057c7..00000000 --- a/python/jittor/compatibility/vision/transforms.py +++ /dev/null @@ -1 +0,0 @@ -from jittor.transform import * \ No newline at end of file diff --git a/python/jittor/compatibility/vision/utils.py b/python/jittor/compatibility/vision/utils.py deleted file mode 100644 index 4be36c64..00000000 --- a/python/jittor/compatibility/vision/utils.py +++ /dev/null @@ -1,582 +0,0 @@ -import collections -import math -import pathlib -import warnings -from itertools import repeat -from types import FunctionType -from typing import Any, BinaryIO, List, Optional, Tuple, Union - -import numpy as np -import torch -from PIL import Image, ImageColor, ImageDraw, ImageFont - -__all__ = [ - "make_grid", - "save_image", - "draw_bounding_boxes", - "draw_segmentation_masks", - "draw_keypoints", - "flow_to_image", -] - - -@torch.no_grad() -def make_grid( - tensor: Union[torch.Tensor, List[torch.Tensor]], - nrow: int = 8, - padding: int = 2, - normalize: bool = False, - value_range: Optional[Tuple[int, int]] = None, - scale_each: bool = False, - pad_value: float = 0.0, - **kwargs, -) -> torch.Tensor: - """ - Make a grid of images. - - Args: - tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) - or a list of images all of the same size. - nrow (int, optional): Number of images displayed in each row of the grid. - The final grid size is ``(B / nrow, nrow)``. Default: ``8``. - padding (int, optional): amount of padding. Default: ``2``. 
- normalize (bool, optional): If True, shift the image to the range (0, 1), - by the min and max values specified by ``value_range``. Default: ``False``. - value_range (tuple, optional): tuple (min, max) where min and max are numbers, - then these numbers are used to normalize the image. By default, min and max - are computed from the tensor. - range (tuple. optional): - .. warning:: - This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` - instead. - scale_each (bool, optional): If ``True``, scale each image in the batch of - images separately rather than the (min, max) over all images. Default: ``False``. - pad_value (float, optional): Value for the padded pixels. Default: ``0``. - - Returns: - grid (Tensor): the tensor containing grid of images. - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(make_grid) - if not torch.is_tensor(tensor): - if isinstance(tensor, list): - for t in tensor: - if not torch.is_tensor(t): - raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") - else: - raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") - - if "range" in kwargs.keys(): - warnings.warn( - "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " - "Please use 'value_range' instead." - ) - value_range = kwargs["range"] - - # if list of tensors, convert to a 4D mini-batch Tensor - if isinstance(tensor, list): - tensor = torch.stack(tensor, dim=0) - - if tensor.dim() == 2: # single image H x W - tensor = tensor.unsqueeze(0) - if tensor.dim() == 3: # single image - if tensor.size(0) == 1: # if single-channel, convert to 3-channel - tensor = torch.cat((tensor, tensor, tensor), 0) - tensor = tensor.unsqueeze(0) - - if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images - tensor = torch.cat((tensor, tensor, tensor), 1) - - if normalize is True: - tensor = tensor.clone() # avoid modifying tensor in-place - if value_range is not None and not isinstance(value_range, tuple): - raise TypeError("value_range has to be a tuple (min, max) if specified. 
min and max are numbers") - - def norm_ip(img, low, high): - img.clamp_(min=low, max=high) - img.sub_(low).div_(max(high - low, 1e-5)) - - def norm_range(t, value_range): - if value_range is not None: - norm_ip(t, value_range[0], value_range[1]) - else: - norm_ip(t, float(t.min()), float(t.max())) - - if scale_each is True: - for t in tensor: # loop over mini-batch dimension - norm_range(t, value_range) - else: - norm_range(tensor, value_range) - - if not isinstance(tensor, torch.Tensor): - raise TypeError("tensor should be of type torch.Tensor") - if tensor.size(0) == 1: - return tensor.squeeze(0) - - # make the mini-batch of images into a grid - nmaps = tensor.size(0) - xmaps = min(nrow, nmaps) - ymaps = int(math.ceil(float(nmaps) / xmaps)) - height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) - num_channels = tensor.size(1) - grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) - k = 0 - for y in range(ymaps): - for x in range(xmaps): - if k >= nmaps: - break - # Tensor.copy_() is a valid method but seems to be missing from the stubs - # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ - grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] - 2, x * width + padding, width - padding - ).copy_(tensor[k]) - k = k + 1 - return grid - - -@torch.no_grad() -def save_image( - tensor: Union[torch.Tensor, List[torch.Tensor]], - fp: Union[str, pathlib.Path, BinaryIO], - format: Optional[str] = None, - **kwargs, -) -> None: - """ - Save a given Tensor into an image file. - - Args: - tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, - saves the tensor as a grid of images by calling ``make_grid``. - fp (string or file object): A filename or a file object - format(Optional): If omitted, the format to use is determined from the filename extension. - If a file object was used instead of a filename, this parameter should always be used. - **kwargs: Other arguments are documented in ``make_grid``. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(save_image) - grid = make_grid(tensor, **kwargs) - # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer - ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() - im = Image.fromarray(ndarr) - im.save(fp, format=format) - - -@torch.no_grad() -def draw_bounding_boxes( - image: torch.Tensor, - boxes: torch.Tensor, - labels: Optional[List[str]] = None, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - fill: Optional[bool] = False, - width: int = 1, - font: Optional[str] = None, - font_size: Optional[int] = None, -) -> torch.Tensor: - - """ - Draws bounding boxes on given image. - The values of the input image should be uint8 between 0 and 255. - If fill is True, Resulting Tensor should be saved as PNG image. - - Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8. - boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that - the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and - `0 <= ymin < ymax < H`. - labels (List[str]): List containing the labels of bounding boxes. - colors (color or list of colors, optional): List containing the colors - of the boxes or single color for all boxes. The color can be represented as - PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, random colors are generated for boxes. - fill (bool): If `True` fills the bounding box with specified color. - width (int): Width of bounding box. - font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may - also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, - `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. - font_size (int): The requested font size in points. - - Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_bounding_boxes) - if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size(0) not in {1, 3}: - raise ValueError("Only grayscale and RGB images are supported") - elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): - raise ValueError( - "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" - ) - - num_boxes = boxes.shape[0] - - if num_boxes == 0: - warnings.warn("boxes doesn't contain any box. No box was drawn") - return image - - if labels is None: - labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] - elif len(labels) != num_boxes: - raise ValueError( - f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." - ) - - if colors is None: - colors = _generate_color_palette(num_boxes) - elif isinstance(colors, list): - if len(colors) < num_boxes: - raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") - else: # colors specifies a single color for all boxes - colors = [colors] * num_boxes - - colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] - - if font is None: - if font_size is not None: - warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") - txt_font = ImageFont.load_default() - else: - txt_font = ImageFont.truetype(font=font, size=font_size or 10) - - # Handle Grayscale images - if image.size(0) == 1: - image = torch.tile(image, (3, 1, 1)) - - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) - img_boxes = boxes.to(torch.int64).tolist() - - if fill: - draw = ImageDraw.Draw(img_to_draw, "RGBA") - else: - draw = ImageDraw.Draw(img_to_draw) - - for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] - if fill: - fill_color = color + (100,) - draw.rectangle(bbox, width=width, outline=color, fill=fill_color) - else: - draw.rectangle(bbox, width=width, outline=color) - - if label is not None: - margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) - - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) - - -@torch.no_grad() -def draw_segmentation_masks( - image: torch.Tensor, - masks: torch.Tensor, - alpha: float = 0.8, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, -) -> torch.Tensor: - - """ - Draws segmentation masks on given RGB image. 
- The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. - alpha (float): Float number between 0 and 1 denoting the transparency of the masks. - 0 means full transparency, 1 means no transparency. - colors (color or list of colors, optional): List containing the colors - of the masks or single color for all masks. The color can be represented as - PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, random colors are generated for each mask. - - Returns: - img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_segmentation_masks) - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. Other Image formats are not supported") - if masks.ndim == 2: - masks = masks[None, :, :] - if masks.ndim != 3: - raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") - if masks.dtype != torch.bool: - raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") - if masks.shape[-2:] != image.shape[-2:]: - raise ValueError("The image and the masks must have the same height and width") - - num_masks = masks.size()[0] - if colors is not None and num_masks > len(colors): - raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") - - if num_masks == 0: - warnings.warn("masks doesn't contain any mask. No mask was drawn") - return image - - if colors is None: - colors = _generate_color_palette(num_masks) - - if not isinstance(colors, list): - colors = [colors] - if not isinstance(colors[0], (tuple, str)): - raise ValueError("colors must be a tuple or a string, or a list thereof") - if isinstance(colors[0], tuple) and len(colors[0]) != 3: - raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") - - out_dtype = torch.uint8 - - colors_ = [] - for color in colors: - if isinstance(color, str): - color = ImageColor.getrgb(color) - colors_.append(torch.tensor(color, dtype=out_dtype)) - - img_to_draw = image.detach().clone() - # TODO: There might be a way to vectorize this - for mask, color in zip(masks, colors_): - img_to_draw[:, mask] = color[:, None] - - out = image * (1 - alpha) + img_to_draw * alpha - return out.to(out_dtype) - - -@torch.no_grad() -def draw_keypoints( - image: torch.Tensor, - keypoints: torch.Tensor, - connectivity: Optional[List[Tuple[int, int]]] = None, - colors: Optional[Union[str, Tuple[int, int, int]]] = None, - radius: int = 2, - width: int = 3, -) -> torch.Tensor: - - """ - Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, - in the format [x, y]. - connectivity (List[Tuple[int, int]]]): A List of tuple where, - each tuple contains pair of keypoints to be connected. - colors (str, Tuple): The color can be represented as - PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - radius (int): Integer denoting radius of keypoint. - width (int): Integer denoting width of line connecting keypoints. - - Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_keypoints) - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. Other Image formats are not supported") - - if keypoints.ndim != 3: - raise ValueError("keypoints must be of shape (num_instances, K, 2)") - - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) - draw = ImageDraw.Draw(img_to_draw) - img_kpts = keypoints.to(torch.int64).tolist() - - for kpt_id, kpt_inst in enumerate(img_kpts): - for inst_id, kpt in enumerate(kpt_inst): - x1 = kpt[0] - radius - x2 = kpt[0] + radius - y1 = kpt[1] - radius - y2 = kpt[1] + radius - draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) - - if connectivity: - for connection in connectivity: - start_pt_x = kpt_inst[connection[0]][0] - start_pt_y = kpt_inst[connection[0]][1] - - end_pt_x = kpt_inst[connection[1]][0] - end_pt_y = kpt_inst[connection[1]][1] - - draw.line( - ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), - width=width, - ) - - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) - - -# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization -@torch.no_grad() -def flow_to_image(flow: torch.Tensor) -> torch.Tensor: - - """ - Converts a flow to an RGB image. - - Args: - flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. - - Returns: - img (Tensor): Image Tensor of dtype uint8 where each color corresponds - to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. - """ - - if flow.dtype != torch.float: - raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") - - orig_shape = flow.shape - if flow.ndim == 3: - flow = flow[None] # Add batch dim - - if flow.ndim != 4 or flow.shape[1] != 2: - raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") - - max_norm = torch.sum(flow**2, dim=1).sqrt().max() - epsilon = torch.finfo((flow).dtype).eps - normalized_flow = flow / (max_norm + epsilon) - img = _normalized_flow_to_image(normalized_flow) - - if len(orig_shape) == 3: - img = img[0] # Remove batch dim - return img - - -@torch.no_grad() -def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: - - """ - Converts a batch of normalized flow to an RGB image. - - Args: - normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) - Returns: - img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
- """ - - N, _, H, W = normalized_flow.shape - device = normalized_flow.device - flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) - colorwheel = _make_colorwheel().to(device) # shape [55x3] - num_cols = colorwheel.shape[0] - norm = torch.sum(normalized_flow**2, dim=1).sqrt() - a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi - fk = (a + 1) / 2 * (num_cols - 1) - k0 = torch.floor(fk).to(torch.long) - k1 = k0 + 1 - k1[k1 == num_cols] = 0 - f = fk - k0 - - for c in range(colorwheel.shape[1]): - tmp = colorwheel[:, c] - col0 = tmp[k0] / 255.0 - col1 = tmp[k1] / 255.0 - col = (1 - f) * col0 + f * col1 - col = 1 - norm * (1 - col) - flow_image[:, c, :, :] = torch.floor(255 * col) - return flow_image - - -def _make_colorwheel() -> torch.Tensor: - """ - Generates a color wheel for optical flow visualization as presented in: - Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) - URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. - - Returns: - colorwheel (Tensor[55, 3]): Colorwheel Tensor. - """ - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - colorwheel = torch.zeros((ncols, 3)) - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) - col = col + RY - # YG - colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) - colorwheel[col : col + YG, 1] = 255 - col = col + YG - # GC - colorwheel[col : col + GC, 1] = 255 - colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) - col = col + GC - # CB - colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) - colorwheel[col : col + CB, 2] = 255 - col = col + CB - # BM - colorwheel[col : col + BM, 2] = 255 - colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) - col = col + BM - # MR - colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) - colorwheel[col : col + MR, 0] = 255 - return colorwheel - - -def _generate_color_palette(num_objects: int): - palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) - return [tuple((i * palette) % 255) for i in range(num_objects)] - - -def _log_api_usage_once(obj: Any) -> None: - - """ - Logs API usage(module and name) within an organization. - In a large ecosystem, it's often useful to track the PyTorch and - TorchVision APIs usage. This API provides the similar functionality to the - logging module in the Python stdlib. It can be used for debugging purpose - to log which methods are used and by default it is inactive, unless the user - manually subscribes a logger via the `SetAPIUsageLogger method `_. - Please note it is triggered only once for the same API call within a process. - It does not collect any data from open-source users since it is no-op by default. - For more information, please refer to - * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; - * Logging policy: https://github.com/pytorch/vision/issues/5052; - - Args: - obj (class instance or method): an object to extract info from. - """ - pass - - -def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: - """ - Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. - Otherwise we will make a tuple of length n, all with value of x. 
- reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 - - Args: - x (Any): input value - n (int): length of the resulting tuple - """ - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) \ No newline at end of file From c0996cd09fb443b36ae96236b7b526446772f825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=AD=99=E5=A5=87?= <2020012497@secoder.net> Date: Thu, 5 Sep 2024 20:14:36 +0800 Subject: [PATCH 72/73] fix dim=3 error --- python/jittor/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 82440021..46ca78ec 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1965,7 +1965,7 @@ def _interpolate(img, x, y, ids, mode): # TODO: tf_mode to another function def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): - if img.dim() != 3: + if img.dim() != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") n, c, h, w = img.shape H, W = size From 106380c42fbe5d534a143ed9ace83f0e7de098bd Mon Sep 17 00:00:00 2001 From: Zhiya Luo Date: Wed, 11 Sep 2024 11:00:15 +0800 Subject: [PATCH 73/73] fix `jittor.nn.AdaptiveMaxPool3d` doc --- doc/source/jittor.nn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/jittor.nn.md b/doc/source/jittor.nn.md index 0ef2b09f..0e5cd35d 100644 --- a/doc/source/jittor.nn.md +++ b/doc/source/jittor.nn.md @@ -10,7 +10,7 @@ jittor.nn .. automodule:: jittor.nn :imported-members: - :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool2d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d + :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool3d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d :undoc-members: .. autoclass:: jittor.nn.ReLU
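
Note on PATCH 72/73 ("fix dim=3 error"): with the corrected guard, `jittor.nn.resize` accepts only batched 4-D input in `(N, C, H, W)` layout, so a single `(C, H, W)` image needs a batch axis before resizing. A minimal sketch — the shapes and the `size` tuple below are illustrative and not taken from the patch:

```python
import jittor as jt
from jittor import nn

img = jt.random((3, 32, 32))      # single image, (C, H, W)
batched = img.unsqueeze(0)        # add batch axis -> (1, 3, 32, 32)

# resize expects (N, C, H, W); with this patch a 3-D input now raises
# ValueError("Input shape must be `(N, C, H, W)`!") instead of failing later.
out = nn.resize(batched, size=(64, 64), mode="nearest")
print(out.shape)                  # expected: [1, 3, 64, 64]
```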
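
Note on PATCH 73/73: the autodoc member list previously contained `AdaptiveMaxPool2d` twice and omitted `AdaptiveMaxPool3d`; the patch makes the 3-D variant appear in the generated docs. A hedged usage sketch — it assumes the class follows the usual adaptive-pooling interface (a single `output_size` argument and 5-D `(N, C, D, H, W)` input), which is not spelled out in this patch:

```python
import jittor as jt
from jittor import nn

# assumed interface: AdaptiveMaxPool3d(output_size), mirroring the 2-D variant
pool = nn.AdaptiveMaxPool3d((2, 4, 4))
x = jt.random((1, 16, 8, 32, 32))   # (N, C, D, H, W)
y = pool(x)
print(y.shape)                      # expected: [1, 16, 2, 4, 4]
```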