diff --git a/README.md b/README.md
index 414eb868..90601b43 100644
--- a/README.md
+++ b/README.md
@@ -382,10 +382,10 @@
 Email: jittor@qq.com
 
 File an issue: https://github.com/Jittor/jittor/issues
 
-QQ Group: 761222083
+QQ Group: 836860279
 
-
+
 
 ## The Team
diff --git a/doc/source/jittor.nn.md b/doc/source/jittor.nn.md
index 0ef2b09f..0e5cd35d 100644
--- a/doc/source/jittor.nn.md
+++ b/doc/source/jittor.nn.md
@@ -10,7 +10,7 @@ jittor.nn
 
 .. automodule:: jittor.nn
    :imported-members:
-   :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool2d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d
+   :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool3d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d
    :undoc-members:
 
 .. autoclass:: jittor.nn.ReLU
diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index cfd39c42..c2df0aa7 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.9.6'
+__version__ = '1.3.9.10'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -428,7 +428,9 @@ def random(shape, dtype="float32", type="uniform"):
     jt.Var([[0.96788853 0.28334728 0.30482838]
      [0.46107793 0.62798643 0.03457401]], dtype=float32)
     '''
-
+    for dim in shape:
+        if dim < 0:
+            raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
     ret = ops.random(shape, "float32", type)
     ## TODO: move those code to core
     #if dtype in ["float16", "bfloat16"]:
@@ -484,6 +486,9 @@ def ones(*shape, dtype="float32"):
         shape = shape[:-1]
     if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
         shape = shape[0]
+    for dim in shape:
+        if dim < 0:
+            raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
     return unary(1, dtype).broadcast(shape)
 
 def new_ones(x, size):
@@ -515,6 +520,9 @@ def zeros(*shape, dtype="float32"):
         shape = shape[:-1]
     if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
         shape = shape[0]
+    for dim in shape:
+        if dim < 0:
+            raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
     return unary(0, dtype).broadcast(shape)
 
 def new_zeros(x, size):
@@ -547,6 +555,9 @@ def full(shape,val,dtype="float32"):
     '''
     if not isinstance(shape, (NanoVector, Sequence)):
         shape = (shape,)
+    for dim in shape:
+        if dim < 0:
+            raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
     return unary(val, dtype).broadcast(shape)
 
 def new_full(x, size, val):
@@ -641,14 +652,22 @@ def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
     return sqr
 Var.var = var
 
-def std(x):
-    matsize=1
-    for i in x.shape:
-        matsize *= i
-    out=(x-x.mean()).sqr().sum()
-    out=out/(matsize-1)
-    out=out.maximum(1e-6).sqrt()
-    return out
+def std(x, dim=None, keepdim=False):
+    if dim is None:
+        matsize=1
+        for i in x.shape:
+            matsize *= i
+        out=(x-x.mean()).sqr().sum()
+        out=out/(matsize-1)
+        out=out.maximum(1e-6).sqrt()
+        return out
+    else:
+        dimsize=x.size(dim)
+        mean=jt.mean(x, dim, keepdim=True)
+        out=(x - mean).sqr().sum(dim=dim, keepdim=keepdim)
+        out=out/(dimsize-1)
+        out=out.maximum(1e-6).sqrt()
+        return out
 Var.std = std
 
 def norm(x, p=2, dim=-1, keepdims=False,
eps=1e-30, keepdim=False): @@ -687,6 +706,8 @@ def flatten(input, start_dim=0, end_dim=-1): start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function" + if len(in_shape) <= end_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})") out_shape = [] for i in range(0,start_dim,1): out_shape.append(in_shape[i]) dims = 1 @@ -917,6 +938,9 @@ def randn(*size, dtype="float32", requires_grad=True) -> Var: [-0.612632 -1.1471151 -1.1879086 ]], dtype=float32) ''' if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0] + for dim in size: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}") arr = jt.random(size, dtype, "normal") if not requires_grad: return arr.stop_grad() return arr @@ -1013,6 +1037,9 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var: [1 1 1]], dtype=int32) ''' if high is None: low, high = 0, low + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5) v = jt.floor_int(v) return v.astype(dtype) @@ -1437,9 +1464,17 @@ def requires_grad_(self, requires_grad=True): def __hooked_call__(self, *args, **kw): if hasattr(self, "__fhook2__"): if len(kw): - self.__fhook2__(self, args, kw) + args_kw_result = self.__fhook2__(self, args, kw) else: - self.__fhook2__(self, args) + args_kw_result = self.__fhook2__(self, args) + if args_kw_result is not None: + if isinstance(args_kw_result, tuple) and len(args_kw_result) == 2: + args, kw = args_kw_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kw_result}." + ) if hasattr(self, "__bihook__"): if len(kw): LOG.w("backward hook not support kw") @@ -1458,9 +1493,11 @@ def __hooked_call__(self, *args, **kw): ret = grad_hooker(ret, self.__bohook__) if hasattr(self, "__fhook__"): if len(kw): - self.__fhook__(self, args, ret, kw) + res = self.__fhook__(self, args, ret, kw) else: - self.__fhook__(self, args, ret) + res = self.__fhook__(self, args, ret) + if res is not None: + ret = res return ret def _place_hooker(self): @@ -1595,6 +1632,8 @@ def load_parameters(self, params): else: if hasattr(v, k): v = getattr(v, k) + if v is None: + continue assert isinstance(v, (Module, Var)), \ f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}" else: @@ -2119,6 +2158,7 @@ def is_var(v): from . import optim from . import dataset from . import init +from . import gradfunctional dtype = NanoString @@ -2152,3 +2192,7 @@ def inplace_wrapper(new_k, prev_func): from . import math_util from .math_util import * from . import distributions + +if jt.compiler.has_acl: + from jittor.extern.acl.acl_compiler import change_function + change_function() diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi index b849af4c..6cfce692 100644 --- a/python/jittor/__init__.pyi +++ b/python/jittor/__init__.pyi @@ -1,7 +1,7 @@ from jittor_core import * from jittor_core.ops import * from .misc import * -from . 
import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse +from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops from .contrib import concat as concat diff --git a/python/jittor/attention.py b/python/jittor/attention.py index a8a486cb..5ae59c1b 100644 --- a/python/jittor/attention.py +++ b/python/jittor/attention.py @@ -9,168 +9,575 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -import jittor as jt -from jittor import init, Module, nn -import numpy as np +from typing import Optional, Tuple, List +import warnings import math +import jittor as jt +from jittor import Var +from jittor.nn import Module, Linear, softmax, pad, linear, dropout +from jittor.init import xavier_uniform_, xavier_gauss_, constant_ + +def _canonical_mask( + mask: Optional[Var], + mask_name: str, + other_type, + other_name: str, + target_type, + check_other: bool = True, +) -> Optional[Var]: + + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = mask.dtype == jt.float16 or mask.dtype == jt.float32 or mask.dtype == jt.float64 + if _mask_dtype != jt.bool and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported") + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." 
+                )
+        if not _mask_is_float:
+            # WARNING(514flowey): Check Here
+            new_mask = jt.zeros_like(mask, dtype=target_type)
+            new_mask[mask] = float("-inf")
+            mask = new_mask
+    return mask
+
+def _none_or_dtype(input: Optional[Var]):
+    if input is None:
+        return None
+    elif isinstance(input, jt.Var):
+        return input.dtype
+    raise RuntimeError("input to _none_or_dtype() must be None or jt.Var")
+
+def baddbmm(input_var:jt.Var, batch1:jt.Var, batch2:jt.Var, beta=1, alpha=1) -> jt.Var:
+    # WARNING(514flowey): Check here
+    return beta * input_var + alpha * (batch1 @ batch2)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> jt.Var:
+    # Efficient implementation equivalent to the following:
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = jt.zeros(L, S, dtype=query.dtype)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = jt.ones(L, S, dtype=jt.bool).tril(diagonal=0)
+        attn_bias[jt.logical_not(temp_mask)] = float("-inf")
+        # attn_bias.to(query.dtype)
+        attn_bias = jt.array(attn_bias, query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == jt.bool:
+            attn_bias[jt.logical_not(attn_mask)] = float("-inf")
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = softmax(attn_weight, dim=-1)
+    attn_weight = dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
+
+def _mha_shape_check(query: Var, key: Var, value: Var,
+                     key_padding_mask: Optional[Var], attn_mask: Optional[Var], num_heads: int):
+    if query.dim() == 3:
+        is_batched = True
+        assert key.dim() == 3 and value.dim() == 3, \
+            ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+             f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+        if key_padding_mask is not None:
+            assert key_padding_mask.dim() == 2, \
+                ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+                 f" but found {key_padding_mask.dim()}-D tensor instead")
+        if attn_mask is not None:
+            assert attn_mask.dim() in (2, 3), \
+                ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+                 f" but found {attn_mask.dim()}-D tensor instead")
+    elif query.dim() == 2:
+        is_batched = False
+        assert key.dim() == 2 and value.dim() == 2, \
+            ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+             f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.dim() == 1, \
+                ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+                 f" but found {key_padding_mask.dim()}-D tensor instead")
+
+        if attn_mask is not None:
+            assert attn_mask.dim() in (2, 3), \
+                ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+                 f" but found {attn_mask.dim()}-D tensor instead")
+            if attn_mask.dim() == 3:
+                expected_shape = (num_heads, query.shape[0], key.shape[0])
+                assert attn_mask.shape == expected_shape, \
+                    (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+    else:
+        raise AssertionError(
+            f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+
+    return is_batched
+
+def _in_projection_packed(
+    q: Var,
+    k: Var,
+    v: Var,
+    w: Var,
+    b: Optional[Var] = None,
+) -> List[Var]:
+    E = q.size(-1)
+    if k is v:
+        if q is k:
+            # self-attention
+            proj = linear(q, w, b)
+            # reshape to 3, E and
not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + nshape = proj.shape[:-1] + (3, E) + proj = proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2) + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + nshape = kv_proj.shape[:-1] + (2, E) + kv_proj = kv_proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2) + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + +def _in_projection( + q: Var, + k: Var, + v: Var, + w_q: Var, + w_k: Var, + w_v: Var, + b_q: Optional[Var] = None, + b_k: Optional[Var] = None, + b_v: Optional[Var] = None, +) -> Tuple[Var, Var, Var]: + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + +def multi_head_attention_forward( + query: Var, + key: Var, + value: Var, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Var], + in_proj_bias: Optional[Var], + bias_k: Optional[Var], + bias_v: Optional[Var], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Var, + out_proj_bias: Optional[Var], + training: bool = True, + key_padding_mask: Optional[Var] = None, + need_weights: bool = True, + attn_mask: Optional[Var] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Var] = None, + k_proj_weight: Optional[Var] = None, + v_proj_weight: Optional[Var] = None, + static_k: Optional[Var] = None, + static_v: Optional[Var] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[Var, Optional[Var]]: + + is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. 
+ if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. + is_causal = False + + assert embed_dim == embed_dim_to_check, \ + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, jt.Var): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], \ + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not 
supported") + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = jt.concat([k, bias_k.repeat(1, bsz, 1)]) + v = jt.concat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.size(0) == bsz * num_heads, \ + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, \ + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.size(0) == bsz * num_heads, \ + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, \ + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = jt.concat([k, jt.zeros(zero_attn_shape, dtype=k.dtype)], dim=1) + v = jt.concat([v, jt.zeros(zero_attn_shape, dtype=v.dtype)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, src_len), \ + f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). 
\ + expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + B, Nt, E = q.shape + q_scaled = q / math.sqrt(E) + + assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + + if attn_mask is not None: + attn_output_weights = baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + else: + attn_output_weights = jt.bmm(q_scaled, k.transpose(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = jt.bmm(attn_output_weights, v) + + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None + + class MultiheadAttention(Module): - def __init__( - self, - embed_dim, - num_heads, - kdim=None, - vdim=None, - dropout=0.0, - bias=True, - add_bias_kv=False, - add_zero_attn=False, - self_attention=False, - encoder_decoder_attention=False, - q_noise=0.0, - qn_block_size=8, - ): + __constants__ = ['batch_first'] + bias_k: Optional[jt.Var] + bias_v: Optional[jt.Var] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, + kdim=None, vdim=None, batch_first=False, dtype=jt.float32) -> None: + if embed_dim <= 0 or num_heads <= 0: + raise ValueError( + f"embed_dim and num_heads must be greater than 0," + f" got embed_dim={embed_dim} and num_heads={num_heads} instead" + ) + factory_kwargs = {'dtype': dtype} super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim 
== embed_dim self.num_heads = num_heads - assert dropout==0, "TODO: dropout>0" - + self.dropout = dropout + self.batch_first = batch_first self.head_dim = embed_dim // num_heads - assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - - self.self_attention = self_attention - self.encoder_decoder_attention = encoder_decoder_attention - - assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size") + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - #TODO: quant_noise - self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) - self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + if not self._qkv_same_embed_dim: + self.q_proj_weight = jt.empty((embed_dim, embed_dim), **factory_kwargs) + self.k_proj_weight = jt.empty((embed_dim, self.kdim), **factory_kwargs) + self.v_proj_weight = jt.empty((embed_dim, self.vdim), **factory_kwargs) + self.in_proj_weight = None + else: + self.q_proj_weight = None + self.k_proj_weight = None + self.v_proj_weight = None + self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim), **factory_kwargs) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + if bias: + self.in_proj_bias = jt.empty(3 * embed_dim, **factory_kwargs) + else: + self.in_proj_bias = None + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) - assert not add_bias_kv, "TODO: add_bias_kv=True" - self.bias_k = self.bias_v = None + if add_bias_kv: + self.bias_k = jt.empty((1, 1, embed_dim), **factory_kwargs) + self.bias_v = jt.empty((1, 1, embed_dim), **factory_kwargs) + else: + self.bias_k = self.bias_v = None self.add_zero_attn = add_zero_attn - self.reset_parameters() - - self.onnx_trace = False - self.tpu = False - - def reset_parameters(self): - if self.qkv_same_dim: - # Empirically observed the convergence to be much better with - # the scaled initialization - init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) - init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) - init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) else: - init.xavier_uniform_(self.k_proj.weight) - init.xavier_uniform_(self.v_proj.weight) - init.xavier_uniform_(self.q_proj.weight) + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) - # init.xavier_uniform_(self.out_proj.weight) - if self.out_proj.bias is not None: - init.constant_(self.out_proj.bias, 0.) + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) 
if self.bias_k is not None: - init.xavier_normal_(self.bias_k) + xavier_gauss_(self.bias_k) if self.bias_v is not None: - init.xavier_normal_(self.bias_v) + xavier_gauss_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super().__setstate__(state) def execute( - self, - query, - key = None, - value = None, - key_padding_mask = None, - incremental_state = None, - need_weights = True, - static_kv = False, - attn_mask = None, - before_softmax = False, - need_head_weights = False, - ): - if need_head_weights: - need_weights = True - - tgt_len, bsz, embed_dim = query.shape - assert embed_dim == self.embed_dim - assert list(query.shape) == [tgt_len, bsz, embed_dim] - - assert incremental_state is None, "TODO: incremental_state is not None" - saved_state = None - - if self.self_attention: - q = self.q_proj(query) - k = self.k_proj(query) - v = self.v_proj(query) - elif self.encoder_decoder_attention: - # encoder-decoder attention - q = self.q_proj(query) - if key is None: - assert value is None - k = v = None + self, + query: Var, + key: Var, + value: Var, + key_padding_mask: Optional[Var] = None, + need_weights: bool = True, + attn_mask: Optional[Var] = None, + average_attn_weights: bool = True, + is_causal : bool = False) -> Tuple[Var, Optional[Var]]: + + ##### + # Fast Path is not Supported. + ##### + + is_batched = query.dim() == 3 + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key else: - k = self.k_proj(key) - v = self.v_proj(key) + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.is_training(), + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal) else: - assert key is not None and value is not None - q = self.q_proj(query) - k = self.k_proj(key) - v = self.v_proj(value) - q = q*self.scaling - - assert self.bias_k is None, "TODO: self.bias_k is not None:" - - q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) - if k is not None: - k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) - if v is not None: - v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) - - assert saved_state is None, "TODO: saved_state is not None" - assert k is not None - src_len = k.shape[1] - - assert key_padding_mask is None, "TODO: key_padding_mask is not None" - 
assert not self.add_zero_attn, "TODO: self.add_zero_attn=True" - - attn_weights = nn.bmm(q, k.transpose(0, 2, 1)) - - assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] - - assert attn_mask is None, "TODO: attn_mask is not None" - assert key_padding_mask is None, "TODO: key_padding_mask is not None" - - if before_softmax: - return attn_weights, v - - attn_weights_float = nn.softmax(attn_weights, dim=-1) - attn_weights = attn_weights_float.type_as(attn_weights) - - assert v is not None - attn = nn.bmm(attn_weights, v) - assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - if self.onnx_trace and attn.shape[1] == 1: - # when ONNX tracing a single decoder step (sequence length == 1) - # the transpose is a no-op copy before view, thus unnecessary - attn = attn.view(tgt_len, bsz, embed_dim) + attn_output, attn_output_weights = multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.is_training(), + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights else: - attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim) - attn = self.out_proj(attn) - attn_weights = None - if need_weights: - attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3) - if not need_head_weights: - # average attention weights over heads - attn_weights = attn_weights.mean(dims=[0]) - - return attn, attn_weights + return attn_output, attn_output_weights diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py deleted file mode 100644 index 6f88fad5..00000000 --- a/python/jittor/compatibility/__init__.py +++ /dev/null @@ -1,430 +0,0 @@ -import os -os.environ["FIX_TORCH_ERROR"] = "0" - -import jittor as jt -from jittor import * -from typing import Tuple - -org_int = int = type(1) -org_float = float = type(1.0) -org_bool = bool = type(True) - -import jtorch.compiler - -import jtorch_core -from jtorch_core import * - -device.__reduce__ = lambda self: (device, (self.type,)) -device.__module__ = "jtorch" -jt.jittor_core.device = device - -def handle_dtype(args, kw, dtype): - def convert(x): - if isinstance(x, jt.Var): - return x.cast(dtype) - return x - if dtype is not None: - if args is not None: - if isinstance(args, (tuple,list)): - args = [ convert(a) for a in args ] - else: - args = convert(x) - if kw is not None: - kw = { k:convert(v) for k,v in kw.items() } - return args, kw - -def get_args_names(func): - import inspect - spec = inspect.getfullargspec(func) - return spec[0] + spec[4] - -def wrapper(func): - has_dtype = False - if hasattr(func, "__code__"): - has_dtype = "dtype" in get_args_names(func) - def inner(*args, **kw): - requires_grad = None - dtype = None - if "requires_grad" in kw: - requires_grad = kw["requires_grad"] - del kw["requires_grad"] - if not has_dtype and "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - if "device" in kw: - del kw["device"] - if 'pin_memory' in kw: - del kw['pin_memory'] - args, kw = handle_dtype(args, kw, dtype) - ret = func(*args, **kw) - if isinstance(ret, jt.Var): - if requires_grad is not None: - ret.requires_grad = requires_grad - if dtype is not None: - 
ret.astype(dtype) - return ret - return inner - - -import inspect -_wrapper_keys = set(["shape", "start", "size"]) -_wrapper_keys.add("x") -for k,v in list(globals().items()): - if callable(v) and not isinstance(v, type): - try: - spec = inspect.getfullargspec(v) - args_name = spec[0] - if len(args_name) and args_name[0] in _wrapper_keys: - globals()[k] = wrapper(v) - elif spec.varargs in _wrapper_keys: - globals()[k] = wrapper(v) - except: - pass - -def empty(*size, dtype=jt.float32, device=None, requires_grad=False): - if len(size) == 1 and not isinstance(size[0], org_int): - size = size[0] - return jt.empty(size, dtype) - -Tensor = Var - -Tensor.backward = lambda x: jtorch_core.backward(x) -Tensor.grad = property(grad_get, grad_set, grad_del) -Tensor.retains_grad = property(retain_grad_get, retain_grad_set) -def retain_grad(x:Tensor, value:bool=True): - x.retains_grad = value - return value -Tensor.retain_grad = retain_grad - -Tensor.dim = lambda self: self.ndim -Tensor.ndimension = lambda self: self.ndim -Tensor.nelement = lambda self: self.numel() -Tensor.cuda = lambda self: self -def device_get(x:Tensor): - return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") -Tensor.device = property(device_get) - -def argmax(x: Var, dim=None, keepdim: bool = False): - return jt.argmax(x, dim, keepdim)[0] -Tensor.argmax = argmax - -def tensor_type(x: Var, dtype=None, **kwargs): - if dtype: - return x.astype(dtype) - else: - return x.dtype -Tensor.type = tensor_type - -def is_floating_point(x: Var): - return "float" in str(x.dtype) -Tensor.is_floating_point = is_floating_point - -from . import autograd -from .autograd import * - -def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): - if isinstance(data,list): - data_list = [] - check = True - for p in data: - if isinstance(p, Tensor) and p.numel()==1: - data_list.append(p.item()) - elif isinstance(p, (org_int,org_float)): - data_list.append(p) - else: - check = False - break - if check: - data = data_list - return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) - -# tensor = wrapper(array) -from_numpy = wrapper(array) -strided = None - -def mod_zero_grad(self): - for p in self.parameters(): - p.grad = None -Module.zero_grad = mod_zero_grad - -class ModuleMisc: - def parameters(self): - return iter(super().parameters()) - - def load_state_dict(self, state_dict, strict=False): - return super().load_state_dict(state_dict) - - def to(self, device=None,dtype=None): - ''' do nothing but return its self''' - return self - def register_parameter(self,name,data): - self.name = data - - def buffers(self): - for _, buf in self.named_buffers(): - yield buf - - -def make_module(cls): - class TMod(ModuleMisc, cls): - def __init__(self, *args, **kw): - dtype = None - if "dtype" in kw: - dtype = kw["dtype"] - del kw["dtype"] - self._dtype = dtype - with jt.flag_scope(th_mode=0): - if "device" in kw: - del kw["device"] - super().__init__(*args, **kw) - for k,v in self.__dict__.items(): - if not k.startswith("_") and isinstance(v, Var) \ - and v.requires_grad: - v.retain_grad() - if dtype is not None and isinstance(v, Var): - v.assign(v.cast(dtype)) - def __call__(self, *args, **kw): - args, kw = handle_dtype(args, kw, self._dtype) - # if forward is override by user, call forward - if self.__class__.forward is not TMod.forward: - return self.forward(*args, **kw) - return self.execute(*args, **kw) - def forward(self, *args, **kw): - args, kw = 
handle_dtype(args, kw, self._dtype) - return self.execute(*args, **kw) - - @property - def training(self): - if not hasattr(self, "is_train"): - self.is_train = True - return self.is_train - @training.setter - def training(self, value): - self.is_train = value - - TMod.__name__ = cls.__name__ - return TMod - -import jtorch.cuda -import jtorch.nn -from jtorch.nn import Module, Parameter -import jtorch.optim - -from jtorch.utils.dtype import Dtype, get_string_dtype - -def frombuffer(buffer: bytearray, - *, - dtype: Dtype, - count: int = -1, - offset: int = 0, - requires_grad: bool = True) -> Tensor: - dtype = get_string_dtype(dtype) - tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) - if requires_grad and tensor.dtype.is_float(): - tensor.requires_grad = True - return tensor - -def conflict_wrapper(origin_func, new_func): - def wrapper(*args, **kw): - if jt.flags.th_mode: - return new_func(*args, **kw) - else: - return origin_func(*args, **kw) - return wrapper - -def min(*args, **kw): - dim = None - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmin(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.minimum(args[0], args[1]) - else: - return jt.min(*args, **kw) -Tensor.min = conflict_wrapper(jt.min, min) - -def max(*args, **kw): - dim = None - if "dim" in kw: - x = kw["dim"] - if len(args) >= 2 and isinstance(args[1], org_int): - dim = args[1] - elif "dim" in kw and isinstance(kw["dim"], org_int): - dim = kw["dim"] - if dim is not None: - k, v = jt.argmax(*args, **kw) - return v, k - elif len(args) == 2 and isinstance(args[1], jt.Var): - return jt.maximum(args[0], args[1]) - else: - return jt.max(*args, **kw) -Tensor.max = conflict_wrapper(jt.max, max) - -def argsort(*args, **kw): - k, v = jt.argsort(*args, **kw) - return k -Tensor.argsort = conflict_wrapper(jt.argsort, argsort) - -LongTensor = jt.int64 -FloatTensor = jt.float -HalfTensor = jt.float16 -BoolTensor = jt.bool -IntTensor = jt.int32 - -class JDType: - def __init__(self, func, str): - self.func = func - self.str = str - self.__name__ = str.split(".")[-1] - def __call__(self, *args, **kw): - return self.func(*args, **kw) - def __str__(self): - return self.str - def is_floating_point(self): - return "float" in str(self.str) - -int8 = JDType(jt.int8, "torch.int8") -int16 = JDType(jt.int16, "torch.int16") -int = int32 = JDType(jt.int32, "torch.int32") -long = int64 = JDType(jt.int64, "torch.int64") - -half = float16 = JDType(jt.float16, "torch.float16") -float = float32 = JDType(jt.float32, "torch.float32") -double = float64 = JDType(jt.float64, "torch.float64") -bfloat16 = "bfloat16" # TODO -complex64 = "complex64" # TODO -complex128 = "complex128" # TODO -def get_JDtype(dtype): - if dtype=='float32' or dtype == jt.float32: - return float32 - elif dtype=='float64' or dtype == jt.float64: - return float64 - elif dtype=='float16' or dtype == jt.float16: - return float16 - elif dtype=='int32' or dtype == jt.int32: - return int32 - elif dtype=='int64' or dtype == jt.int64: - return int64 - elif dtype=='int16' or dtype == jt.int16: - return int16 - elif dtype=='int8' or dtype == jt.int8: - return int8 - else: - raise Exception("dtype {} not supported".format(dtype)) - -def load(path,**kwargs): - def _to_jittor(data): - if isinstance(data,dict): - return {k:_to_jittor(d) for k,d in data.items()} - if isinstance(data,list): - return 
[_to_jittor(d) for d in data] - if isinstance(data,np.ndarray): - return jt.array(data) - return data - data = jt.load(path) - - return _to_jittor(data) - -def is_tensor(x): - return isinstance(x, Tensor) - -manual_seed = jt.set_global_seed -jt.flags.amp_level = 3 -Size = jt.NanoVector - -class Generator: - def __init__(self,*args,**kw) -> None: - self.seed = None - def manual_seed(self,seed): - self.seed = seed - - - -from . import fx - - -_default_type = "float32" - -def get_default_dtype(): - return _default_type -def set_default_dtype(dtype): - global _default_type - _default_type = dtype - -dtype = JDType - -def div(x,y,rounding_mode="floor"): - assert rounding_mode == "floor" - z = (x / y) - if rounding_mode == "floor": - z = z.floor() - if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): - z = z.int32() - return z - - -def randn(*args,**kw): - wrap_randn = wrapper(jt.randn) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_randn(*args,**kw) - -def rand(*args,**kw): - print("rand") - wrap_rand = wrapper(jt.rand) - generator = kw.get('generator',None) - kw.pop('generator',None) - if 'layout' in kw: - del kw['layout'] - if generator is not None and generator.seed is not None: - jt.set_global_seed(generator.seed) - return wrap_rand(*args,**kw) - - - -def set_default_tensor_type(t: type or str): - if isinstance(t, str): - info = t.split(".") - if len(info) == 3 and info[1] == 'cuda': - jt.flags.use_cuda = 1 - #TODO: type - - -def clamp(x, min=None, max=None): - return jt.clamp(x, min, max) - - -def to(x,*args,**kw): - device = None - if len(args) == 1: - device = args[0] - if isinstance(device, jt.NanoString) or callable(device): - return jt.to(x,*args,**kw) - if 'cpu' in str(device): - args = [] - device = kw.get("device",None) - if 'cpu' in str(device): - kw.pop('device',None) - print("to cpu") - # print(kw) - return jt.to(x,*args,**kw) -Tensor.to = conflict_wrapper(jt.to, to) - -mm = wrapper(jt.matmul) - -def _data_get(x): - return x - -def _data_set(x, value): - x.assign(value) - -Tensor.data = property(_data_get, _data_set) -Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/autograd.py b/python/jittor/compatibility/autograd.py deleted file mode 100644 index 5ed88dde..00000000 --- a/python/jittor/compatibility/autograd.py +++ /dev/null @@ -1,134 +0,0 @@ -import jittor as jt -from jittor import Var -from collections.abc import Sequence, Mapping - -Variable = Var - -class FunctionContext: - def save_for_backward(self, *args): - self.saved_tensors = args - -class Function: - ''' Function Module for customized backward operations - -Example 1 (Function can have multiple input and multiple output, and user -can store value for backward computation):: - - import jtorch - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - return grad0 * self.y, grad1 * self.x - - a = jtorch.array(3.0) - a.requires_grad = True - b = jtorch.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - (c+d*3).backward() - assert a.grad.data == 4 - assert b.grad.data == 9 - -Example 2(Function can return None for no gradiant, and gradiant -can also be None):: - - import jtorch - from jtorch import Function - - class 
MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - assert grad1 is None - return grad0 * self.y, None - a = jt.array(3.0) - a.requires_grad = True - b = jt.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - d.stop_grad() - da, db = jt.grad(c+d*3, [a, b]) - assert da.data == 4 - assert db.data == 0 - - ''' - def __call__(self, *args): - backup = args - args = list(args) - taped_inputs = [] - taped_outputs = [] - input_mask = [-1] * len(args) - for i,v in enumerate(args): - if isinstance(v, Var): - if v.is_stop_grad(): - # -2 in input_mask represents it is stop_grad - input_mask[i] = -2 - continue - v = v.tape() - input_mask[i] = len(taped_inputs) - args[i] = v - taped_inputs.append(v) - ctx = FunctionContext() - ori_res = self.forward(ctx, *args) - # ori_res = self.execute(*args) - if not isinstance(ori_res, Sequence): - res = [ori_res] - else: - res = list(ori_res) - output_mask = [-1] * len(res) - for i,v in enumerate(res): - if isinstance(v, Var): - v = v.tape() - output_mask[i] = len(taped_outputs) - res[i] = v - taped_outputs.append(v) - ctx.input_mask = input_mask - ctx.output_mask = output_mask - # tape output and input together so - # backward treat them as one operator - jt.tape_together(taped_inputs, taped_outputs, - lambda *args: self._grad(ctx, self, *args)) - if isinstance(ori_res, Sequence): - return res - else: - return res[0] - - @staticmethod - def _grad(ctx, func, *args): - new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask ) - ret = func.backward(ctx, *new_args) - if not isinstance(ret, Sequence): - ret = (ret,) - new_ret = [] - for i, r in enumerate(ret): - j = ctx.input_mask[i] - if j<0: - # -2 in input_mask represents it is stop_grad - assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ - "because the input value is not jittor variable." 
- else: - new_ret.append(r) - return new_ret - - def dfs(self, parents, k, callback, callback_leave=None): - pass - - @classmethod - def apply(cls, *args, **kw): - func = cls() - return func(*args, **kw) diff --git a/python/jittor/compatibility/compiler.py b/python/jittor/compatibility/compiler.py deleted file mode 100644 index 77bab138..00000000 --- a/python/jittor/compatibility/compiler.py +++ /dev/null @@ -1,39 +0,0 @@ -import jittor as jt -import jittor_utils -import glob -import os -from jittor import pyjt_compiler -import sys -from jittor_utils import lock - - -jtorch_path = os.path.dirname(__file__) -cache_path = os.path.join(jt.compiler.cache_path, "jtorch") -# os.makedirs(cache_path, exist_ok=True) -os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True) - -with lock.lock_scope(): - pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path) - -ext_args = 'c[cu]' if jt.has_cuda else 'cc' -files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True) -files += pyjt_gen_src -cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" " -if os.environ.get("use_data_o", "1") == "1": - files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True) - files = [f for f in files if "__data__" not in f] - - -with lock.lock_scope(): - jt.compiler.compile( - jt.compiler.cc_path, - jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags, - files, - "jtorch_core"+jt.compiler.extension_suffix, - obj_dirname="jtorch_objs") - - -with jittor_utils.import_scope(jt.compiler.import_flags): - import jtorch_core as core - -jt.flags.th_mode = 1 diff --git a/python/jittor/compatibility/cuda.py b/python/jittor/compatibility/cuda.py deleted file mode 100644 index 75665c7c..00000000 --- a/python/jittor/compatibility/cuda.py +++ /dev/null @@ -1,64 +0,0 @@ -import jittor as jt -import jtorch - -def is_available(): - return jt.has_cuda - -def device_count(): - return int(jt.has_cuda) - -def set_device(device=None): - pass - -def get_rng_state(device=None): - pass - -def current_device(): - return jtorch.device("cuda") - -def mem_get_info(i): - return ("75GB",) - - -class Generator: - def __init__(self): - pass - - def set_state(self, state): - self.state = state - -default_generators = [Generator()] -_lazy_call = lambda func: func() -device = None - -LongTensor = jt.int64 -FloatTensor = jt.float -HalfTensor = jt.float16 -BoolTensor = jt.bool - -manual_seed = jt.set_global_seed -manual_seed_all = jt.set_global_seed - -def synchronize(): - jt.sync_all(True) - -class Event: - pass - -class Stream: - pass - -from typing import Any - -from .gradscaler import GradScaler - -class autocast: - def __init__(self,**kwargs): - pass - - def __enter__(self,): - pass - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): - pass - diff --git a/python/jittor/compatibility/distributed.py b/python/jittor/compatibility/distributed.py deleted file mode 100644 index e39f559a..00000000 --- a/python/jittor/compatibility/distributed.py +++ /dev/null @@ -1,53 +0,0 @@ -import datetime -from enum import Enum -import jittor as jt - - -class DistributedDataParallel: - def __new__(cls, model): - return model - -def is_initialized(): - return True - -def get_rank(group=None): - return 0 - -def get_world_size(group=None): - return 1 - -def get_backend(group=None): - return "nccl" - -def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None): - return 1 - -def barrier(): - pass - -def is_available(): - return True - -def is_built(): - return True - -class ReduceOp: - SUM = 0 - -class 
GroupMember: - WORLD = 0 - -class ProcessGroup: - pass - -class Join: - pass - -dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL")) -_backend = dist_backend.NCCL - -def is_mpi_available(): - return jt.in_mpi - -def DistributedDataParallel(model, *args, **kw): - return model diff --git a/python/jittor/compatibility/distributions.py b/python/jittor/compatibility/distributions.py deleted file mode 100644 index a98dfe29..00000000 --- a/python/jittor/compatibility/distributions.py +++ /dev/null @@ -1,15 +0,0 @@ -import jittor as jt - -class RelaxedBernoulli: - def __init__(self, temperature, probs=None, logits=None): - self.temperature = temperature - self.probs = probs - self.logits = logits - - def rsample(self): - noise = jt.rand_like(self.logits) - eps = 1e-20 - noise = jt.clamp(noise, eps, 1.0 - eps) - logit_noise = jt.log(noise) - jt.log(1 - noise) - sample = (self.logits + logit_noise) / self.temperature - return jt.sigmoid(sample) diff --git a/python/jittor/compatibility/fft/__init__.py b/python/jittor/compatibility/fft/__init__.py deleted file mode 100644 index 7a89fc9c..00000000 --- a/python/jittor/compatibility/fft/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#TODO: Implement FFT and IFFT -fftn = None -fftshift = None -ifftn = None -ifftshift = None \ No newline at end of file diff --git a/python/jittor/compatibility/fx.py b/python/jittor/compatibility/fx.py deleted file mode 100644 index 0f0eb4f8..00000000 --- a/python/jittor/compatibility/fx.py +++ /dev/null @@ -1,2 +0,0 @@ -class Proxy: - pass \ No newline at end of file diff --git a/python/jittor/compatibility/gradscaler.py b/python/jittor/compatibility/gradscaler.py deleted file mode 100644 index 087d6bb2..00000000 --- a/python/jittor/compatibility/gradscaler.py +++ /dev/null @@ -1,519 +0,0 @@ -from collections import defaultdict, abc -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast -import inspect -import warnings - -import jittor as jt -# import torch - -def _refresh_per_optimizer_state(): - return {} - - -class GradScaler: - _scale: Optional[jt.Var] - _grows_tracker: Optional[jt.Var] - _per_optimizer_states: Dict[int, Dict[str, Any]] - """ - An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling - conveniently. - - * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. - * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. - * ``scaler.update()`` updates ``scaler``'s scale factor. - - Example:: - - # Creates a GradScaler once at the beginning of training. - scaler = GradScaler() - - for epoch in epochs: - for input, target in data: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - - # Scales loss. Calls backward() on scaled loss to create scaled gradients. - scaler.scale(loss).backward() - - # scaler.step() first unscales gradients of the optimizer's params. - # If gradients don't contain infs/NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - - # Updates the scale for next iteration. - scaler.update() - - See the :ref:`Automatic Mixed Precision examples` for usage - (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, - and multiple losses/optimizers. - - ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, - a large scale factor should be used. 
However, ``float16`` values can "overflow" (become inf or NaN) if - the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used - without incurring inf or NaN gradient values. - ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every - ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). - - * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params - themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. - - * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. - If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by - ``growth_factor``. - - The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its - value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these - iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - - Args: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. - Default: ``True`` - """ - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): - self._enabled = enabled - - if self._enabled: - assert growth_factor > 1.0, "The growth factor must be > 1.0." - assert backoff_factor < 1.0, "The backoff factor must be < 1.0." - - self._init_scale = init_scale - # self._scale will be lazily initialized during the first call to scale() - self._scale = None - self._growth_factor = growth_factor - self._backoff_factor = backoff_factor - self._growth_interval = growth_interval - self._init_growth_tracker = 0 - # self._growth_tracker will be lazily initialized during the first call to scale() - self._growth_tracker = None - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: - fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix - assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix - return (self._scale, self._growth_tracker) - - def _lazy_init_scale_growth_tracker(self): - assert self._growth_tracker is None, "_growth_tracker initialized before _scale" - self._scale = self._init_scale - self._growth_tracker = self._init_growth_tracker - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. 
- - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Args: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - if not self._enabled: - return outputs - - - # Short-circuit for the common case. - if isinstance(outputs, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return outputs * self._scale - - def apply_scale(val): - if isinstance(val, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return val * self._scale - elif isinstance(val, abc.Iterable): - iterable = map(apply_scale, val) - if isinstance(val, (list, tuple)): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) - - def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): - with jt.no_grad(): - optimizer.pre_step() - for group in optimizer.param_groups: - for to_unscale in group["grads"]: - if to_unscale is None or isinstance(to_unscale,(int,float)): - continue - if (not allow_fp16) and str(to_unscale.dtype) == "float16": - raise ValueError("Attempting to unscale FP16 gradients.") - - if not (to_unscale.isinf().any()): - if inv_scale != 1.0: - to_unscale.update(to_unscale*inv_scale) - else: - found_inf = 1.0 - - return found_inf - - def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - - .. note:: - :meth:`unscale_` does not incur a CPU-GPU sync. - - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if hasattr(optimizer,"get_find_inf"): - return - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - inv_scale = 1.0 / self._scale - found_inf = 0.0 - optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) - - - def step(self, optimizer, *args, **kwargs): - """ - :meth:`step` carries out the following two operations: - - 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` - earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. 
- 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled - gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. - - ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. - - Returns the return value of ``optimizer.step(*args, **kwargs)``. - - Args: - optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. - args: Any arguments. - kwargs: Any keyword arguments. - - .. warning:: - Closure use is not currently supported. - """ - if (not self._enabled): - return optimizer.step(*args, **kwargs) - - if "closure" in kwargs: - raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") - - self._check_scale_growth_tracker("step") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - retval = None - - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): - # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. - # The contract with custom optimizers is that their step() should accept an additional, - # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: - # it can query its own state, invoke unscale_ on itself, etc - # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument - # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` - # and `found_inf` to the passed optimizer so that the optimizer can utilize those - # to skip the parameter updates or unscale gradients before updating parameters in - # the fused kernel, e.g. `FusedAdamMathFunctor`. - # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, - # while the method is expected to be called by users side, i.e. their optimizers. - kwargs_ = kwargs - has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters - if has_grad_scaler_kwarg: - warnings.warn( - "GradScaler is going to stop passing itself as a keyword argument to the passed " - "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " - "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", - FutureWarning) - kwargs_.update({"grad_scaler": self}) - else: - if optimizer_state["stage"] is OptState.READY: - self._check_inf_per_device(optimizer) - scaler = self._get_scale_async() - found_inf = cast( - jt.Var, - sum([ - t for t in optimizer_state["found_inf_per_device"].values() - ]) - ) - optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler - optimizer.found_inf = found_inf - retval = optimizer.step(*args, **kwargs_) - optimizer_state["stage"] = OptState.STEPPED - if not has_grad_scaler_kwarg: - del optimizer.grad_scale - del optimizer.found_inf - return retval - - if hasattr(optimizer,"get_find_inf"): - optimizer.set_grad_scale(self._scale) - optimizer.step() - optimizer_state["found_inf_per_device"] = optimizer.get_find_inf() - return - - retval = None - if not optimizer_state["found_inf_per_device"]: - retval = optimizer.step(*args, **kwargs) - else: - optimizer.post_step() - - return retval - - - def update(self, new_scale=None): - """ - Updates the scale factor. - - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. 
- - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [state["found_inf_per_device"] - for state in self._per_optimizer_states.values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - - current_scale = _scale - if found_inf_combined: - current_scale *=self._backoff_factor - _growth_tracker = 0 - else: - successful = _growth_tracker+1 - if successful == self._growth_interval: - new_scale = current_scale*self._growth_factor - if new_scale < 1e9: - current_scale = new_scale - _growth_tracker = 0 - else: - _growth_tracker = successful - - self._scale, self._growth_tracker = current_scale,_growth_tracker - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _get_scale_async(self): - return self._scale - - def get_scale(self): - """ - Returns a Python float containing the current scale, or 1.0 if scaling is disabled. - - .. warning:: - :meth:`get_scale` incurs a CPU-GPU sync. - """ - if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async() - else: - return 1.0 - - def get_growth_factor(self): - r""" - Returns a Python float containing the scale growth factor. - """ - return self._growth_factor - - def set_growth_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale growth factor. - """ - self._growth_factor = new_factor - - def get_backoff_factor(self): - r""" - Returns a Python float containing the scale backoff factor. - """ - return self._backoff_factor - - def set_backoff_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale backoff factor. - """ - self._backoff_factor = new_factor - - def get_growth_interval(self): - r""" - Returns a Python int containing the growth interval. - """ - return self._growth_interval - - def set_growth_interval(self, new_interval): - r""" - Args: - new_interval (int): Value to use as the new growth interval. 
- """ - self._growth_interval = new_interval - - def _get_growth_tracker(self): - if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() - else: - return 0 - - def is_enabled(self): - r""" - Returns a bool indicating whether this instance is enabled. - """ - return self._enabled - - def state_dict(self): - r""" - Returns the state of the scaler as a :class:`dict`. It contains five entries: - - * ``"scale"`` - a Python float containing the current scale - * ``"growth_factor"`` - a Python float containing the current growth factor - * ``"backoff_factor"`` - a Python float containing the current backoff factor - * ``"growth_interval"`` - a Python int containing the current growth interval - * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. - - If this instance is not enabled, returns an empty dict. - - .. note:: - If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` - should be called after :meth:`update`. - """ - return {"scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} - - def load_state_dict(self, state_dict): - r""" - Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - - Args: - state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. - """ - if not self._enabled: - return - - if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") - - self._init_scale = state_dict["scale"] - if self._scale is not None: - self._scale.fill_(state_dict["scale"]) - self._growth_factor = state_dict["growth_factor"] - self._backoff_factor = state_dict["backoff_factor"] - self._growth_interval = state_dict["growth_interval"] - self._init_growth_tracker = state_dict["_growth_tracker"] - if self._growth_tracker is not None: - self._growth_tracker.fill_(state_dict["_growth_tracker"]) - - def __getstate__(self): - state = self.__dict__.copy() - if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." - # Pickling _scale and _growth_tracker Tensors directly triggers - # "warnings.warn("pickle support for Storage will be removed in 1.5..." - # so instead, we set the unpickled instance up to reinitialize them lazily. 
- state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def _check_inf_per_device(self, optimizer): - _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - - dummy_inv_scale = 1.0 - found_inf = 0.0 - - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) - - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] - - def _found_inf_per_device(self, optimizer): - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/gradscaler_old.py b/python/jittor/compatibility/gradscaler_old.py deleted file mode 100644 index 389be2cf..00000000 --- a/python/jittor/compatibility/gradscaler_old.py +++ /dev/null @@ -1,556 +0,0 @@ -from collections import defaultdict, abc -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast -import inspect -import warnings - -import jittor as jt -# import torch - - -__all__ = ["OptState", "GradScaler"] - - -# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, -# as well as associated "enum" values. Prefers defining these at top level because -# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. -# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler -# causes a circular reference, which we'd rather avoid. -class OptState(Enum): - READY = 0 - UNSCALED = 1 - STEPPED = 2 - - -def _refresh_per_optimizer_state(): - return {"stage": OptState.READY, "found_inf_per_device": {}} - - -class GradScaler: - _scale: Optional[jt.Var] - _grows_tracker: Optional[jt.Var] - _per_optimizer_states: Dict[int, Dict[str, Any]] - """ - An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling - conveniently. - - * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. - * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. - * ``scaler.update()`` updates ``scaler``'s scale factor. - - Example:: - - # Creates a GradScaler once at the beginning of training. - scaler = GradScaler() - - for epoch in epochs: - for input, target in data: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - - # Scales loss. Calls backward() on scaled loss to create scaled gradients. - scaler.scale(loss).backward() - - # scaler.step() first unscales gradients of the optimizer's params. - # If gradients don't contain infs/NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - - # Updates the scale for next iteration. - scaler.update() - - See the :ref:`Automatic Mixed Precision examples` for usage - (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, - and multiple losses/optimizers. - - ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, - a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if - the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used - without incurring inf or NaN gradient values. 
- ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every - ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). - - * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params - themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. - - * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. - If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by - ``growth_factor``. - - The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its - value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these - iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - - Args: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. - Default: ``True`` - """ - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): - self._enabled = enabled - - if self._enabled: - assert growth_factor > 1.0, "The growth factor must be > 1.0." - assert backoff_factor < 1.0, "The backoff factor must be < 1.0." - - self._init_scale = init_scale - # self._scale will be lazily initialized during the first call to scale() - self._scale = None - self._growth_factor = growth_factor - self._backoff_factor = backoff_factor - self._growth_interval = growth_interval - self._init_growth_tracker = 0 - # self._growth_tracker will be lazily initialized during the first call to scale() - self._growth_tracker = None - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: - fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix - assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix - return (self._scale, self._growth_tracker) - - def _lazy_init_scale_growth_tracker(self): - assert self._growth_tracker is None, "_growth_tracker initialized before _scale" - self._scale = self._init_scale - self._growth_tracker = self._init_growth_tracker - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Args: - outputs (Tensor or iterable of Tensors): Outputs to scale. 
- """ - print("scale") - if not self._enabled: - return outputs - - - # Short-circuit for the common case. - if isinstance(outputs, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return outputs * self._scale - - def apply_scale(val): - if isinstance(val, jt.Var): - assert jt.flags.use_cuda == 1 - if self._scale is None: - self._lazy_init_scale_growth_tracker() - assert self._scale is not None - return val * self._scale - elif isinstance(val, abc.Iterable): - iterable = map(apply_scale, val) - if isinstance(val, (list, tuple)): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) - - def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. - - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - with jt.no_grad(): - optimizer.pre_step() - for group in optimizer.param_groups: - for to_unscale in group["grads"]: - if to_unscale is None or isinstance(to_unscale,(int,float)): - continue - if (not allow_fp16) and str(to_unscale.dtype) == "float16": - raise ValueError("Attempting to unscale FP16 gradients.") - - if not (to_unscale.isinf().any()): - if inv_scale != 1.0: - to_unscale.update(to_unscale*inv_scale) - else: - found_inf = 1.0 - - return found_inf - - def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - - .. note:: - :meth:`unscale_` does not incur a CPU-GPU sync. - - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: - raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 
- assert self._scale is not None - inv_scale = 1.0 / self._scale - found_inf = 0.0 - optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) - optimizer_state["stage"] = OptState.UNSCALED - - def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): - retval = None - if not optimizer_state["found_inf_per_device"]: - retval = optimizer.step(*args, **kwargs) - else: - optimizer.post_step() - - return retval - - def step(self, optimizer, *args, **kwargs): - """ - :meth:`step` carries out the following two operations: - - 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` - earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. - 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled - gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. - - ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. - - Returns the return value of ``optimizer.step(*args, **kwargs)``. - - Args: - optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. - args: Any arguments. - kwargs: Any keyword arguments. - - .. warning:: - Closure use is not currently supported. - """ - if (not self._enabled): - return optimizer.step(*args, **kwargs) - - if "closure" in kwargs: - raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") - - self._check_scale_growth_tracker("step") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("step() has already been called since the last update().") - - retval = None - - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): - # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. - # The contract with custom optimizers is that their step() should accept an additional, - # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: - # it can query its own state, invoke unscale_ on itself, etc - # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument - # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` - # and `found_inf` to the passed optimizer so that the optimizer can utilize those - # to skip the parameter updates or unscale gradients before updating parameters in - # the fused kernel, e.g. `FusedAdamMathFunctor`. - # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, - # while the method is expected to be called by users side, i.e. their optimizers. - kwargs_ = kwargs - has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters - if has_grad_scaler_kwarg: - warnings.warn( - "GradScaler is going to stop passing itself as a keyword argument to the passed " - "optimizer. 
In the near future GradScaler registers `grad_scale: Tensor` and " - "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", - FutureWarning) - kwargs_.update({"grad_scaler": self}) - else: - if optimizer_state["stage"] is OptState.READY: - self._check_inf_per_device(optimizer) - scaler = self._get_scale_async() - found_inf = cast( - jt.Var, - sum([ - t for t in optimizer_state["found_inf_per_device"].values() - ]) - ) - optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler - optimizer.found_inf = found_inf - retval = optimizer.step(*args, **kwargs_) - optimizer_state["stage"] = OptState.STEPPED - if not has_grad_scaler_kwarg: - del optimizer.grad_scale - del optimizer.found_inf - return retval - - - if optimizer_state["stage"] is OptState.READY: - self.unscale_(optimizer) - - assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer." - - retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) - - optimizer_state["stage"] = OptState.STEPPED - - return retval - - def update(self, new_scale=None): - """ - Updates the scale factor. - - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [state["found_inf_per_device"] - for state in self._per_optimizer_states.values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." 
- - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - - current_scale = _scale - if found_inf_combined: - current_scale *=self._backoff_factor - _growth_tracker = 0 - else: - successful = _growth_tracker+1 - if successful == self._growth_interval: - new_scale = current_scale*self._growth_factor - if new_scale < 1e9: - current_scale = new_scale - _growth_tracker = 0 - else: - _growth_tracker = successful - - self._scale, self._growth_tracker = current_scale,_growth_tracker - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - - def _get_scale_async(self): - return self._scale - - def get_scale(self): - """ - Returns a Python float containing the current scale, or 1.0 if scaling is disabled. - - .. warning:: - :meth:`get_scale` incurs a CPU-GPU sync. - """ - if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async() - else: - return 1.0 - - def get_growth_factor(self): - r""" - Returns a Python float containing the scale growth factor. - """ - return self._growth_factor - - def set_growth_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale growth factor. - """ - self._growth_factor = new_factor - - def get_backoff_factor(self): - r""" - Returns a Python float containing the scale backoff factor. - """ - return self._backoff_factor - - def set_backoff_factor(self, new_factor): - r""" - Args: - new_scale (float): Value to use as the new scale backoff factor. - """ - self._backoff_factor = new_factor - - def get_growth_interval(self): - r""" - Returns a Python int containing the growth interval. - """ - return self._growth_interval - - def set_growth_interval(self, new_interval): - r""" - Args: - new_interval (int): Value to use as the new growth interval. - """ - self._growth_interval = new_interval - - def _get_growth_tracker(self): - if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() - else: - return 0 - - def is_enabled(self): - r""" - Returns a bool indicating whether this instance is enabled. - """ - return self._enabled - - def state_dict(self): - r""" - Returns the state of the scaler as a :class:`dict`. It contains five entries: - - * ``"scale"`` - a Python float containing the current scale - * ``"growth_factor"`` - a Python float containing the current growth factor - * ``"backoff_factor"`` - a Python float containing the current backoff factor - * ``"growth_interval"`` - a Python int containing the current growth interval - * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. - - If this instance is not enabled, returns an empty dict. - - .. note:: - If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` - should be called after :meth:`update`. - """ - return {"scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} - - def load_state_dict(self, state_dict): - r""" - Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - - Args: - state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 
- """ - if not self._enabled: - return - - if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") - - self._init_scale = state_dict["scale"] - if self._scale is not None: - self._scale.fill_(state_dict["scale"]) - self._growth_factor = state_dict["growth_factor"] - self._backoff_factor = state_dict["backoff_factor"] - self._growth_interval = state_dict["growth_interval"] - self._init_growth_tracker = state_dict["_growth_tracker"] - if self._growth_tracker is not None: - self._growth_tracker.fill_(state_dict["_growth_tracker"]) - - def __getstate__(self): - state = self.__dict__.copy() - if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." - # Pickling _scale and _growth_tracker Tensors directly triggers - # "warnings.warn("pickle support for Storage will be removed in 1.5..." - # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def _check_inf_per_device(self, optimizer): - _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - - dummy_inv_scale = 1.0 - found_inf = 0.0 - - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) - - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] - - def _found_inf_per_device(self, optimizer): - return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/misc.py b/python/jittor/compatibility/misc.py deleted file mode 100644 index 8e9ed20d..00000000 --- a/python/jittor/compatibility/misc.py +++ /dev/null @@ -1,12 +0,0 @@ -import math - -def _jit_set_profiling_mode(x): pass -def _jit_set_profiling_executor(x): pass -def _jit_override_can_fuse_on_cpu(x): pass -def _jit_override_can_fuse_on_gpu(x): pass - -def script(func): - return func - -inf = math.inf -nan = math.nan \ No newline at end of file diff --git a/python/jittor/compatibility/nn/__init__.py b/python/jittor/compatibility/nn/__init__.py deleted file mode 100644 index ae0ff3ae..00000000 --- a/python/jittor/compatibility/nn/__init__.py +++ /dev/null @@ -1,281 +0,0 @@ -import jtorch -from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict -from typing_extensions import Self -import jittor as jt -from jtorch import make_module, Tensor, ModuleMisc, wrapper -#from . 
import init -from jittor import Function -import operator -import warnings - -for k,v in jt.nn.__dict__.items(): - if callable(v): - globals()[k] = wrapper(v) - -for k,v in jt.nn.__dict__.items(): - if isinstance(v, type) and issubclass(v, jt.Module): - globals()[k] = make_module(v) - -from collections import OrderedDict -from collections import abc as container_abcs - -class Module(ModuleMisc, jt.Module): - - def __call__(self, *args, **kw): - return self.execute(*args, **kw) - - def execute(self, *args, **kw): - return self.forward(*args, **kw) - - def get_submodule(self, target: str): - if target == "": - return self - - atoms: List[str] = target.split(".") - mod: jt.nn.Module = self - - for item in atoms: - if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no " - "attribute `" + item + "`") - - mod = getattr(mod, item) - - if not isinstance(mod, jt.nn.Module): - raise AttributeError("`" + item + "` is not " - "an nn.Module") - return mod - - - -def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor: - x = x.clone() - x.requires_grad = requires_grad - x.retains_grad = requires_grad - return x - -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False): - return jt.nn.embedding(input, weight) - -def dropout(x, p=0.5, training=False): - return jt.nn.dropout(x, p, training) - - -class Flatten(Module): - ''' Flattens the contiguous range of dimensions in a Var. - :param start_dim: the first dimension to be flattened. Defaults: 1. - :type start_dim: int - :param end_dim: the last dimension to be flattened. Defaults: -1. - :type end_dim: int - ''' - def __init__(self, start_dim=1, end_dim=-1): - self.start_dim = start_dim - self.end_dim = end_dim - - def forward(self, x) -> jt.Var: - return x.flatten(self.start_dim, self.end_dim) - -class _IncompatibleKeys: - def __init__(self, missing_keys, unexpected_keys): - self.missing_keys = missing_keys - self.unexpected_keys = unexpected_keys - -_BatchNorm = None - -#from . import utils -normalize = wrapper(jt.normalize) - -T = TypeVar('T', bound=Module) - -class ModuleDict(Module): - _modules: Dict[str, Module] # type: ignore[assignment] - - def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: - super().__init__() - if modules is not None: - self.update(modules) - - def __getitem__(self, key: str) -> Module: - return self._modules[key] - - def __setitem__(self, key: str, module: Module) -> None: - self.add_module(key, module) - - def __delitem__(self, key: str) -> None: - del self._modules[key] - - def __len__(self) -> int: - return len(self._modules) - - def __iter__(self) -> Iterator[str]: - return iter(self._modules) - - def __contains__(self, key: str) -> bool: - return key in self._modules - - def clear(self) -> None: - """Remove all items from the ModuleDict.""" - self._modules.clear() - - def pop(self, key: str) -> Module: - r"""Remove key from the ModuleDict and return its module. 
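The `get_submodule` implementation above resolves a dotted path by walking attributes one segment at a time. A hedged sketch of the equivalent lookup on a small, hypothetical Jittor module tree (`Encoder` and `Net` are made-up names used only for illustration):

```python
import jittor as jt
from jittor import Module, nn

class Encoder(Module):
    def __init__(self):
        self.linear = nn.Linear(8, 4)
    def execute(self, x):
        return self.linear(x)

class Net(Module):
    def __init__(self):
        self.encoder = Encoder()
    def execute(self, x):
        return self.encoder(x)

net = Net()

# Equivalent of net.get_submodule("encoder.linear") from the deleted class:
mod = net
for item in "encoder.linear".split("."):
    mod = getattr(mod, item)   # raises AttributeError if a segment is missing
assert isinstance(mod, nn.Linear)
```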
- - Args: - key (str): key to pop from the ModuleDict - """ - v = self[key] - del self[key] - return v - - def keys(self) -> Iterable[str]: - r"""Return an iterable of the ModuleDict keys.""" - return self._modules.keys() - - def items(self) -> Iterable[Tuple[str, Module]]: - r"""Return an iterable of the ModuleDict key/value pairs.""" - return self._modules.items() - - def values(self) -> Iterable[Module]: - r"""Return an iterable of the ModuleDict values.""" - return self._modules.values() - - def update(self, modules: Mapping[str, Module]) -> None: - r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. - - .. note:: - If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or - an iterable of key-value pairs, the order of new elements in it is preserved. - - Args: - modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, - or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) - """ - if not isinstance(modules, container_abcs.Iterable): - raise TypeError("ModuleDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) - - if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): - for key, module in modules.items(): - self[key] = module - else: - # modules here can be a list with two items - for j, m in enumerate(modules): - if not isinstance(m, container_abcs.Iterable): - raise TypeError("ModuleDict update sequence element " - "#" + str(j) + " should be Iterable; is" + - type(m).__name__) - if not len(m) == 2: - raise ValueError("ModuleDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") - # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] - # that's too cumbersome to type correctly with overloads, so we add an ignore here - self[m[0]] = m[1] # type: ignore[assignment] - - # remove forward alltogether to fallback on Module's _forward_unimplemented - - -class ParameterList(Module): - - def __init__(self, values: Optional[Iterable[Any]] = None) -> None: - super().__init__() - self._size = 0 - if values is not None: - self += values - - def _get_abs_string_index(self, idx): - """Get the absolute index for the list of modules.""" - idx = operator.index(idx) - if not (-len(self) <= idx < len(self)): - raise IndexError(f'index {idx} is out of range') - if idx < 0: - idx += len(self) - return str(idx) - - @overload - def __getitem__(self, idx: int) -> Any: - ... - - @overload - def __getitem__(self: T, idx: slice) -> T: - ... - - def __getitem__(self, idx): - if isinstance(idx, slice): - start, stop, step = idx.indices(len(self)) - out = self.__class__() - for i in range(start, stop, step): - out.append(self[i]) - return out - else: - idx = self._get_abs_string_index(idx) - return getattr(self, str(idx)) - - def __setitem__(self, idx: int, param: Any) -> None: - # Note that all other function that add an entry to the list part of - # the ParameterList end up here. So this is the only place where we need - # to wrap things into Parameter if needed. - # Objects added via setattr() are not in the list part and thus won't - # call into this function. 
- idx = self._get_abs_string_index(idx) - if isinstance(param, jt.Var) and not isinstance(param, Parameter): - param = Parameter(param) - return setattr(self, str(idx), param) - - def __len__(self) -> int: - return self._size - - def __iter__(self) -> Iterator[Any]: - return iter(self[i] for i in range(len(self))) - - def __iadd__(self, parameters: Iterable[Any]) -> Self: - return self.extend(parameters) - - def __dir__(self): - keys = super().__dir__() - keys = [key for key in keys if not key.isdigit()] - return keys - - def append(self, value: Any) -> 'ParameterList': - """Append a given value at the end of the list. - - Args: - value (Any): value to append - """ - new_idx = len(self) - self._size += 1 - self[new_idx] = value - return self - - def extend(self, values: Iterable[Any]) -> Self: - """Append values from a Python iterable to the end of the list. - - Args: - values (iterable): iterable of values to append - """ - # Tensor is an iterable but we never want to unpack it here - if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var): - raise TypeError("ParameterList.extend should be called with an " - "iterable, but got " + type(values).__name__) - for value in values: - self.append(value) - return self - - def extra_repr(self) -> str: - child_lines = [] - for k, p in enumerate(self): - if isinstance(p, jt.Var): - size_str = 'x'.join(str(size) for size in p.size()) - parastr = '{} containing: [{} of size {}{}]'.format( - "Parameter" if isinstance(p, Parameter) else "Tensor", - p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu") - child_lines.append(' (' + str(k) + '): ' + parastr) - else: - child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) - - tmpstr = '\n'.join(child_lines) - return tmpstr - - def __call__(self, *args, **kwargs): - raise RuntimeError('ParameterList should not be called.') \ No newline at end of file diff --git a/python/jittor/compatibility/nn/init.py b/python/jittor/compatibility/nn/init.py deleted file mode 100644 index 3b9f0907..00000000 --- a/python/jittor/compatibility/nn/init.py +++ /dev/null @@ -1,16 +0,0 @@ -import jittor as jt - -for k,v in jt.nn.init.__dict__.items(): - if callable(v): - globals()[k] = v - - -normal = gauss -normal_ = gauss_ -xavier_normal = xavier_gauss -xavier_normal_ = xavier_gauss_ -zeros_ = zero_ - - -jt.Var.normal_ = normal_ - diff --git a/python/jittor/compatibility/nn/utils/__init__.py b/python/jittor/compatibility/nn/utils/__init__.py deleted file mode 100644 index 83409f5f..00000000 --- a/python/jittor/compatibility/nn/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import rnn \ No newline at end of file diff --git a/python/jittor/compatibility/nn/utils/rnn.py b/python/jittor/compatibility/nn/utils/rnn.py deleted file mode 100644 index b32da8c3..00000000 --- a/python/jittor/compatibility/nn/utils/rnn.py +++ /dev/null @@ -1,20 +0,0 @@ -import jittor as jt - -PackedSequence = None - -def pad_sequence(sequences,batch_first=False,padding_value=0.0): - max_f = max([len(s) for s in sequences]) - # max_f = 512 - b = len(sequences) - if batch_first: - ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) - for i,s in enumerate(sequences): - ret[i,:len(s)] = s - else: - ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) - for i,s in enumerate(sequences): - ret[:len(s),i] = s - # print(ret.shape) - # ret = ret[:,:406] - return ret - \ No newline at end of file diff --git a/python/jittor/compatibility/optim.py b/python/jittor/compatibility/optim.py deleted file mode 100644 index 2410917f..00000000 --- a/python/jittor/compatibility/optim.py +++ /dev/null @@ -1,1854 +0,0 @@ -import jittor as jt -import math -from jittor.optim import * -from functools import partial - -class Optimizer(jt.optim.Optimizer): - def pre_step(self, loss=None, retain_graph=False): - jt.flags.node_order = 1 - params_has_grad = [] - for pg in self.param_groups: - pg["grads"] = [ jt.zeros_like(p) if p.grad is None else p.grad#.float32() - for p in pg["params"] ] - for p in pg["params"]: - if p.requires_grad: - params_has_grad.append(p) - jt.sync(params_has_grad) - self.n_step += 1 - - def zero_grad(self): - for pg in self.param_groups: - pg["grads"] = [ None for p in pg["params"] ] - for p in pg["params"]: p.grad = None - - def post_step(self): - jt.flags.node_order = 0 - - def clip_grad_norm(self, max_norm:float, norm_type:int=2): - r"""Clips gradient norm of this optimizer. - The norm is computed over all gradients together. 
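The `pad_sequence` helper deleted above right-pads a list of variable-length `jt.Var` sequences to the length of the longest one. A standalone sketch of the same padding logic for `batch_first=True`; the shapes and padding value are illustrative:

```python
import jittor as jt

seqs = [jt.ones((3, 4)), jt.ones((2, 4)), jt.ones((1, 4))]   # lengths 3, 2, 1; feature size 4
padding_value = 0.0

max_len = max(len(s) for s in seqs)
padded = jt.full((len(seqs), max_len) + tuple(seqs[0].shape[1:]), padding_value)
for i, s in enumerate(seqs):
    padded[i, :len(s)] = s    # copy each sequence into the front of its row

print(padded.shape)           # (batch, max_len, features) = (3, 3, 4)
```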
- - Args: - max_norm (float or int): max norm of the gradients - norm_type (int): 1-norm or 2-norm - - Example:: - - a = jt.ones(2) - opt = jt.optim.SGD([a], 0.1) - - loss = a*a - opt.zero_grad() - opt.backward(loss) - - print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83 - opt.clip_grad_norm(0.01, 2) - print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01 - - opt.step() - - """ - self.pre_step(None) - grads = [] - for pg in self.param_groups: - for p, g in zip(pg["params"], pg["grads"]): - if p.is_stop_grad(): continue - grads.append(g.flatten()) - if len(grads) == 0: return - total_norm = jt.norm(jt.concat(grads), norm_type) - clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0) - for pg in self.param_groups: - for p, g in zip(pg["params"], pg["grads"]): - if p.is_stop_grad(): continue - g.update(g*clip_coef) - - -class AdamW(Optimizer): - def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0,use_fp32=True): - print("lr:", lr) - super().__init__(params, lr) - self.eps = eps - self.betas = betas - self.weight_decay = weight_decay - - self.use_fp32 = use_fp32 - # assert weight_decay==0, "weight_decay is not supported yet" - - # initialize required arguments for each param_groups - for pg in self.param_groups: - values = pg["values"] = [] - m = pg["m"] = [] - mp = pg['masterparams'] = [] - for p in pg["params"]: - values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - if self.use_fp32: - mp.append(p.detach().clone().stop_grad()) - - def add_param_group(self, group): - values = group["values"] = [] - m = group["m"] = [] - mp = group['masterparams'] = [] - for p in group["params"]: - values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) - if self.use_fp32: - mp.append(p.detach().clone().stop_grad()) - self.param_groups.append(group) - - def step(self, loss=None, retain_graph=False): - self.pre_step(loss, retain_graph) - if loss is None: - self.n_step += 1 - n = float(self.n_step) - for pg in self.param_groups: - # get arguments from each param_groups - lr = pg.get("lr", self.lr) - eps = pg.get("eps", self.eps) - weight_decay = pg.get("weight_decay", self.weight_decay) - b0, b1 = pg.get("betas", self.betas) - - for p, g, v, m,mp in zip(pg["params"], pg["grads"], pg["values"], pg["m"],pg['masterparams']): - if p.is_stop_grad(): continue - #if g.abs().sum().item() < 1e-8: continue - #import pdb; pdb.set_trace() - c_p = (mp * (1 - lr * weight_decay)) - mp.update(c_p) - if self.use_fp32: - g = g.float32() - bias_correction1 = 1 - b0 ** n - bias_correction2 = 1 - b1 ** n - m.update(b0 * m + (1-b0) * g) #exp_avg - v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq - denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps - step_size = lr / bias_correction1 - new_p = (mp - step_size * m / denom) - mp.update(new_p) - p.update(mp.cast(p.dtype)) - self.post_step() - -for k,v in jt.optim.__dict__.items(): - if k == "AdamW":continue - if isinstance(v, type) and issubclass(v, jt.optim.Optimizer) and \ - not v is jt.optim.Optimizer: - class OptimWrap(v, Optimizer): - pass - globals()[k] = OptimWrap - - -class Adagrad(Optimizer): - pass - - - -import types -import math -from functools import wraps -import warnings -import weakref -from collections import Counter -from bisect import bisect_right - - -class LRScheduler: - - def __init__(self, 
optimizer, last_epoch=-1, verbose=False): - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - # Initialize epoch and base learning rates - if last_epoch == -1: - for group in optimizer.param_groups: - group.setdefault('initial_lr', group.get("lr",optimizer.lr)) - else: - for i, group in enumerate(optimizer.param_groups): - if 'initial_lr' not in group: - raise KeyError("param 'initial_lr' is not specified " - "in param_groups[{}] when resuming an optimizer".format(i)) - self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] - self.last_epoch = last_epoch - - # Following https://github.com/pytorch/pytorch/issues/20124 - # We would like to ensure that `lr_scheduler.step()` is called after - # `optimizer.step()` - def with_counter(method): - if getattr(method, '_with_counter', False): - # `optimizer.step()` has already been replaced, return. - return method - - # Keep a weak reference to the optimizer instance to prevent - # cyclic references. - instance_ref = weakref.ref(method.__self__) - # Get the unbound method for the same purpose. - func = method.__func__ - cls = instance_ref().__class__ - del method - - @wraps(func) - def wrapper(*args, **kwargs): - instance = instance_ref() - instance._step_count += 1 - wrapped = func.__get__(instance, cls) - return wrapped(*args, **kwargs) - - # Note that the returned function here is no longer a bound method, - # so attributes like `__func__` and `__self__` no longer exist. - wrapper._with_counter = True - return wrapper - - self.optimizer.step = with_counter(self.optimizer.step) - self.verbose = verbose - - self._initial_step() - - def _initial_step(self): - """Initialize step counts and performs a step""" - self.optimizer._step_count = 0 - self._step_count = 0 - self.step() - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self): - """ Return last computed learning rate by current scheduler. - """ - return self._last_lr - - def get_lr(self): - # Compute learning rate using chainable form of the scheduler - raise NotImplementedError - - def print_lr(self, is_verbose, group, lr, epoch=None): - """Display the current learning rate. - """ - if is_verbose: - if epoch is None: - print('Adjusting learning rate' - ' of group {} to {:.4e}.'.format(group, lr)) - else: - epoch_str = ("%.2f" if isinstance(epoch, float) else - "%.5d") % epoch - print('Epoch {}: adjusting learning rate' - ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) - - - def step(self, epoch=None): - # Raise a warning if old pattern is detected - # https://github.com/pytorch/pytorch/issues/20124 - if self._step_count == 1: - if not hasattr(self.optimizer.step, "_with_counter"): - warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " - "initialization. Please, make sure to call `optimizer.step()` before " - "`lr_scheduler.step()`. 
See more details at " - "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) - - # Just check if there were two first lr_scheduler.step() calls before optimizer.step() - elif self.optimizer._step_count < 1: - warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " - "In PyTorch 1.1.0 and later, you should call them in the opposite order: " - "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " - "will result in PyTorch skipping the first value of the learning rate schedule. " - "See more details at " - "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) - self._step_count += 1 - - with _enable_get_lr_call(self): - if epoch is None: - self.last_epoch += 1 - values = self.get_lr() - else: - warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) - self.last_epoch = epoch - if hasattr(self, "_get_closed_form_lr"): - values = self._get_closed_form_lr() - else: - values = self.get_lr() - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group['lr'] = lr - self.print_lr(self.verbose, i, lr, epoch) - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - -# Including _LRScheduler for backwards compatibility -# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). -class _LRScheduler(LRScheduler): - pass - - -class _enable_get_lr_call: - - def __init__(self, o): - self.o = o - - def __enter__(self): - self.o._get_lr_called_within_step = True - return self - - def __exit__(self, type, value, traceback): - self.o._get_lr_called_within_step = False - - -class LambdaLR(LRScheduler): - """Sets the learning rate of each parameter group to the initial lr - times a given function. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - lr_lambda (function or list): A function which computes a multiplicative - factor given an integer parameter epoch, or a list of such - functions, one for each group in optimizer.param_groups. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer has two groups. - >>> lambda1 = lambda epoch: epoch // 30 - >>> lambda2 = lambda epoch: 0.95 ** epoch - >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): - self.optimizer = optimizer - - if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): - self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) - else: - if len(lr_lambda) != len(optimizer.param_groups): - raise ValueError("Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(lr_lambda))) - self.lr_lambdas = list(lr_lambda) - super().__init__(optimizer, last_epoch, verbose) - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The learning rate lambda functions will only be saved if they are callable objects - and not if they are functions or lambdas. - - When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. 
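As the `LambdaLR` docstring above states, each parameter group's learning rate is its initial lr multiplied by the value of its lambda at the current epoch. A plain-Python illustration using the ``0.95 ** epoch`` lambda from that docstring example:

```python
base_lr = 0.05                           # illustrative initial lr
lr_lambda = lambda epoch: 0.95 ** epoch

# lr applied at epochs 0..4: base_lr * lr_lambda(epoch)
lrs = [base_lr * lr_lambda(epoch) for epoch in range(5)]
print(lrs)   # approximately [0.05, 0.0475, 0.0451, 0.0429, 0.0407]
```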
- """ - - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} - state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) - - for idx, fn in enumerate(self.lr_lambdas): - if not isinstance(fn, types.FunctionType): - state_dict['lr_lambdas'][idx] = fn.__dict__.copy() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - - lr_lambdas = state_dict.pop('lr_lambdas') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['lr_lambdas'] = lr_lambdas - - for idx, fn in enumerate(lr_lambdas): - if fn is not None: - self.lr_lambdas[idx].__dict__.update(fn) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") - return [base_lr * lmbda(self.last_epoch) - for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] - - -class MultiplicativeLR(LRScheduler): - """Multiply the learning rate of each parameter group by the factor given - in the specified function. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - lr_lambda (function or list): A function which computes a multiplicative - factor given an integer parameter epoch, or a list of such - functions, one for each group in optimizer.param_groups. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> lmbda = lambda epoch: 0.95 - >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): - self.optimizer = optimizer - - if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): - self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) - else: - if len(lr_lambda) != len(optimizer.param_groups): - raise ValueError("Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(lr_lambda))) - self.lr_lambdas = list(lr_lambda) - super().__init__(optimizer, last_epoch, verbose) - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The learning rate lambda functions will only be saved if they are callable objects - and not if they are functions or lambdas. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} - state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) - - for idx, fn in enumerate(self.lr_lambdas): - if not isinstance(fn, types.FunctionType): - state_dict['lr_lambdas'][idx] = fn.__dict__.copy() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - lr_lambdas = state_dict.pop('lr_lambdas') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['lr_lambdas'] = lr_lambdas - - for idx, fn in enumerate(lr_lambdas): - if fn is not None: - self.lr_lambdas[idx].__dict__.update(fn) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch > 0: - return [group['lr'] * lmbda(self.last_epoch) - for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] - else: - return [group['lr'] for group in self.optimizer.param_groups] - - -class StepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the learning rate from outside this scheduler. When - last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - step_size (int): Period of learning rate decay. - gamma (float): Multiplicative factor of learning rate decay. - Default: 0.1. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.05 if epoch < 30 - >>> # lr = 0.005 if 30 <= epoch < 60 - >>> # lr = 0.0005 if 60 <= epoch < 90 - >>> # ... - >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False): - self.step_size = step_size - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * self.gamma ** (self.last_epoch // self.step_size) - for base_lr in self.base_lrs] - - -class MultiStepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma once the - number of epoch reaches one of the milestones. Notice that such decay can - happen simultaneously with other changes to the learning rate from outside - this scheduler. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - milestones (list): List of epoch indices. Must be increasing. - gamma (float): Multiplicative factor of learning rate decay. - Default: 0.1. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.05 if epoch < 30 - >>> # lr = 0.005 if 30 <= epoch < 80 - >>> # lr = 0.0005 if epoch >= 80 - >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) 
- >>> scheduler.step() - """ - - def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False): - self.milestones = Counter(milestones) - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch not in self.milestones: - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - milestones = sorted(self.milestones.elements()) - return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) - for base_lr in self.base_lrs] - - -class ConstantLR(LRScheduler): - """Decays the learning rate of each parameter group by a small constant factor until the - number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can - happen simultaneously with other changes to the learning rate from outside this scheduler. - When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - factor (float): The number we multiply learning rate until the milestone. Default: 1./3. - total_iters (int): The number of steps that the scheduler decays the learning rate. - Default: 5. - last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.025 if epoch == 0 - >>> # lr = 0.025 if epoch == 1 - >>> # lr = 0.025 if epoch == 2 - >>> # lr = 0.025 if epoch == 3 - >>> # lr = 0.05 if epoch >= 4 - >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): - if factor > 1.0 or factor < 0: - raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') - - self.factor = factor - self.total_iters = total_iters - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] * self.factor for group in self.optimizer.param_groups] - - if (self.last_epoch > self.total_iters or - (self.last_epoch != self.total_iters)): - return [group['lr'] for group in self.optimizer.param_groups] - - if (self.last_epoch == self.total_iters): - return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) - for base_lr in self.base_lrs] - - -class LinearLR(LRScheduler): - """Decays the learning rate of each parameter group by linearly changing small - multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. - Notice that such decay can happen simultaneously with other changes to the learning rate - from outside this scheduler. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. 
- start_factor (float): The number we multiply learning rate in the first epoch. - The multiplication factor changes towards end_factor in the following epochs. - Default: 1./3. - end_factor (float): The number we multiply learning rate at the end of linear changing - process. Default: 1.0. - total_iters (int): The number of iterations that multiplicative factor reaches to 1. - Default: 5. - last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 0.05 for all groups - >>> # lr = 0.025 if epoch == 0 - >>> # lr = 0.03125 if epoch == 1 - >>> # lr = 0.0375 if epoch == 2 - >>> # lr = 0.04375 if epoch == 3 - >>> # lr = 0.05 if epoch >= 4 - >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, - verbose=False): - if start_factor > 1.0 or start_factor <= 0: - raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.') - - if end_factor > 1.0 or end_factor < 0: - raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') - - self.start_factor = start_factor - self.end_factor = end_factor - self.total_iters = total_iters - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] - - if self.last_epoch > self.total_iters: - return [group['lr'] for group in self.optimizer.param_groups] - - return [group['lr'] * (1. + (self.end_factor - self.start_factor) / - (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * (self.start_factor + - (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) - for base_lr in self.base_lrs] - - -class ExponentialLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma every epoch. - When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - gamma (float): Multiplicative factor of learning rate decay. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. 
- """ - - def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False): - self.gamma = gamma - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * self.gamma ** self.last_epoch - for base_lr in self.base_lrs] - - -class SequentialLR(LRScheduler): - """Receives the list of schedulers that is expected to be called sequentially during - optimization process and milestone points that provides exact intervals to reflect - which scheduler is supposed to be called at a given epoch. - - Args: - optimizer (Optimizer): Wrapped optimizer. - schedulers (list): List of chained schedulers. - milestones (list): List of integers that reflects milestone points. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): Does nothing. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 1. for all groups - >>> # lr = 0.1 if epoch == 0 - >>> # lr = 0.1 if epoch == 1 - >>> # lr = 0.9 if epoch == 2 - >>> # lr = 0.81 if epoch == 3 - >>> # lr = 0.729 if epoch == 4 - >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) - >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) - >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): - for scheduler_idx in range(len(schedulers)): - if schedulers[scheduler_idx].optimizer != optimizer: - raise ValueError( - "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " - f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in." - ) - - if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): - raise ValueError( - "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " - f"got schedulers at index {0} and {scheduler_idx} to be different." 
- ) - if (len(milestones) != len(schedulers) - 1): - raise ValueError( - "Sequential Schedulers expects number of schedulers provided to be one more " - "than the number of milestone points, but got number of schedulers {} and the " - "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) - ) - self._schedulers = schedulers - self._milestones = milestones - self.last_epoch = last_epoch + 1 - self.optimizer = optimizer - - # Reset learning rates back to initial values - for group in self.optimizer.param_groups: - group["lr"] = group["initial_lr"] - - # "Undo" the step performed by other schedulers - for scheduler in self._schedulers: - scheduler.last_epoch -= 1 - - # Perform the initial step for only the first scheduler - self._schedulers[0]._initial_step() - - self._last_lr = schedulers[0].get_last_lr() - - def step(self): - self.last_epoch += 1 - idx = bisect_right(self._milestones, self.last_epoch) - scheduler = self._schedulers[idx] - if idx > 0 and self._milestones[idx - 1] == self.last_epoch: - scheduler.step(0) - else: - scheduler.step() - - self._last_lr = scheduler.get_last_lr() - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The wrapped scheduler states will also be saved. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} - state_dict['_schedulers'] = [None] * len(self._schedulers) - - for idx, s in enumerate(self._schedulers): - state_dict['_schedulers'][idx] = s.state_dict() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - _schedulers = state_dict.pop('_schedulers') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['_schedulers'] = _schedulers - - for idx, s in enumerate(_schedulers): - self._schedulers[idx].load_state_dict(s) - - -class PolynomialLR(LRScheduler): - """Decays the learning rate of each parameter group using a polynomial function - in the given total_iters. When last_epoch=-1, sets initial lr as lr. - - Args: - optimizer (Optimizer): Wrapped optimizer. - total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. - power (int): The power of the polynomial. Default: 1.0. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP("undefined vars") - >>> # Assuming optimizer uses lr = 0.001 for all groups - >>> # lr = 0.001 if epoch == 0 - >>> # lr = 0.00075 if epoch == 1 - >>> # lr = 0.00050 if epoch == 2 - >>> # lr = 0.00025 if epoch == 3 - >>> # lr = 0.0 if epoch >= 4 - >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) 
- >>> scheduler.step() - """ - def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False): - self.total_iters = total_iters - self.power = power - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0 or self.last_epoch > self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] - - decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power - return [group["lr"] * decay_factor for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [ - ( - base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power - ) - for base_lr in self.base_lrs - ] - - -class CosineAnnealingLR(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - When last_epoch=-1, sets initial lr as lr. Notice that because the schedule - is defined recursively, the learning rate can be simultaneously modified - outside this scheduler by other operators. If the learning rate is set - solely by this scheduler, the learning rate at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only - implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer): Wrapped optimizer. - T_max (int): Maximum number of iterations. - eta_min (float): Minimum learning rate. Default: 0. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False): - self.T_max = T_max - self.eta_min = eta_min - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == 0: - return [group['lr'] for group in self.optimizer.param_groups] - elif self._step_count == 1 and self.last_epoch > 0: - return [self.eta_min + (base_lr - self.eta_min) * - (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] - elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: - return [group['lr'] + (base_lr - self.eta_min) * - (1 - math.cos(math.pi / self.T_max)) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] - return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / - (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * - (group['lr'] - self.eta_min) + self.eta_min - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [self.eta_min + (base_lr - self.eta_min) * - (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 - for base_lr in self.base_lrs] - - -class ChainedScheduler(LRScheduler): - """Chains list of learning rate schedulers. It takes a list of chainable learning - rate schedulers and performs consecutive step() functions belonging to them by just - one call. - - Args: - schedulers (list): List of chained schedulers. - - Example: - >>> # xdoctest: +SKIP - >>> # Assuming optimizer uses lr = 1. for all groups - >>> # lr = 0.09 if epoch == 0 - >>> # lr = 0.081 if epoch == 1 - >>> # lr = 0.729 if epoch == 2 - >>> # lr = 0.6561 if epoch == 3 - >>> # lr = 0.59049 if epoch >= 4 - >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) - >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) - >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - """ - - def __init__(self, schedulers): - for scheduler_idx in range(1, len(schedulers)): - if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): - raise ValueError( - "ChainedScheduler expects all schedulers to belong to the same optimizer, but " - "got schedulers at index {} and {} to be different".format(0, scheduler_idx) - ) - self._schedulers = list(schedulers) - self.optimizer = schedulers[0].optimizer - self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] - - def step(self): - for scheduler in self._schedulers: - scheduler.step() - self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - The wrapped scheduler states will also be saved. - """ - state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} - state_dict['_schedulers'] = [None] * len(self._schedulers) - - for idx, s in enumerate(self._schedulers): - state_dict['_schedulers'][idx] = s.state_dict() - - return state_dict - - def load_state_dict(self, state_dict): - """Loads the schedulers state. 
- - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - _schedulers = state_dict.pop('_schedulers') - self.__dict__.update(state_dict) - # Restore state_dict keys in order to prevent side effects - # https://github.com/pytorch/pytorch/issues/32756 - state_dict['_schedulers'] = _schedulers - - for idx, s in enumerate(_schedulers): - self._schedulers[idx].load_state_dict(s) - - -class ReduceLROnPlateau: - """Reduce learning rate when a metric has stopped improving. - Models often benefit from reducing the learning rate by a factor - of 2-10 once learning stagnates. This scheduler reads a metrics - quantity and if no improvement is seen for a 'patience' number - of epochs, the learning rate is reduced. - - Args: - optimizer (Optimizer): Wrapped optimizer. - mode (str): One of `min`, `max`. In `min` mode, lr will - be reduced when the quantity monitored has stopped - decreasing; in `max` mode it will be reduced when the - quantity monitored has stopped increasing. Default: 'min'. - factor (float): Factor by which the learning rate will be - reduced. new_lr = lr * factor. Default: 0.1. - patience (int): Number of epochs with no improvement after - which learning rate will be reduced. For example, if - `patience = 2`, then we will ignore the first 2 epochs - with no improvement, and will only decrease the LR after the - 3rd epoch if the loss still hasn't improved then. - Default: 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Default: 1e-4. - threshold_mode (str): One of `rel`, `abs`. In `rel` mode, - dynamic_threshold = best * ( 1 + threshold ) in 'max' - mode or best * ( 1 - threshold ) in `min` mode. - In `abs` mode, dynamic_threshold = best + threshold in - `max` mode or best - threshold in `min` mode. Default: 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after lr has been reduced. Default: 0. - min_lr (float or list): A scalar or a list of scalars. A - lower bound on the learning rate of all param groups - or each group respectively. Default: 0. - eps (float): Minimal decay applied to lr. If the difference - between new and old lr is smaller than eps, the update is - ignored. Default: 1e-8. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = ReduceLROnPlateau(optimizer, 'min') - >>> for epoch in range(10): - >>> train(...) - >>> val_loss = validate(...) 
- >>> # Note that step should be called after validate() - >>> scheduler.step(val_loss) - """ - - def __init__(self, optimizer, mode='min', factor=0.1, patience=10, - threshold=1e-4, threshold_mode='rel', cooldown=0, - min_lr=0, eps=1e-8, verbose=False): - - if factor >= 1.0: - raise ValueError('Factor should be < 1.0.') - self.factor = factor - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - if isinstance(min_lr, (list, tuple)): - if len(min_lr) != len(optimizer.param_groups): - raise ValueError("expected {} min_lrs, got {}".format( - len(optimizer.param_groups), len(min_lr))) - self.min_lrs = list(min_lr) - else: - self.min_lrs = [min_lr] * len(optimizer.param_groups) - - self.patience = patience - self.verbose = verbose - self.cooldown = cooldown - self.cooldown_counter = 0 - self.mode = mode - self.threshold = threshold - self.threshold_mode = threshold_mode - self.best = None - self.num_bad_epochs = None - self.mode_worse = None # the worse value for the chosen mode - self.eps = eps - self.last_epoch = 0 - self._init_is_better(mode=mode, threshold=threshold, - threshold_mode=threshold_mode) - self._reset() - - def _reset(self): - """Resets num_bad_epochs counter and cooldown counter.""" - self.best = self.mode_worse - self.cooldown_counter = 0 - self.num_bad_epochs = 0 - - def step(self, metrics, epoch=None): - # convert `metrics` to float, in case it's a zero-dim Tensor - current = float(metrics) - if epoch is None: - epoch = self.last_epoch + 1 - else: - warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) - self.last_epoch = epoch - - if self.is_better(current, self.best): - self.best = current - self.num_bad_epochs = 0 - else: - self.num_bad_epochs += 1 - - if self.in_cooldown: - self.cooldown_counter -= 1 - self.num_bad_epochs = 0 # ignore any bad epochs in cooldown - - if self.num_bad_epochs > self.patience: - self._reduce_lr(epoch) - self.cooldown_counter = self.cooldown - self.num_bad_epochs = 0 - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - - def _reduce_lr(self, epoch): - for i, param_group in enumerate(self.optimizer.param_groups): - old_lr = float(param_group['lr']) - new_lr = max(old_lr * self.factor, self.min_lrs[i]) - if old_lr - new_lr > self.eps: - param_group['lr'] = new_lr - if self.verbose: - epoch_str = ("%.2f" if isinstance(epoch, float) else - "%.5d") % epoch - print('Epoch {}: reducing learning rate' - ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr)) - - @property - def in_cooldown(self): - return self.cooldown_counter > 0 - - def is_better(self, a, best): - if self.mode == 'min' and self.threshold_mode == 'rel': - rel_epsilon = 1. - self.threshold - return a < best * rel_epsilon - - elif self.mode == 'min' and self.threshold_mode == 'abs': - return a < best - self.threshold - - elif self.mode == 'max' and self.threshold_mode == 'rel': - rel_epsilon = self.threshold + 1. 
- return a > best * rel_epsilon - - else: # mode == 'max' and epsilon_mode == 'abs': - return a > best + self.threshold - - def _init_is_better(self, mode, threshold, threshold_mode): - if mode not in {'min', 'max'}: - raise ValueError('mode ' + mode + ' is unknown!') - if threshold_mode not in {'rel', 'abs'}: - raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') - - if mode == 'min': - self.mode_worse = inf - else: # mode == 'max': - self.mode_worse = -inf - - self.mode = mode - self.threshold = threshold - self.threshold_mode = threshold_mode - - def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} - - def load_state_dict(self, state_dict): - self.__dict__.update(state_dict) - self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) - - -class CyclicLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to - cyclical learning rate policy (CLR). The policy cycles the learning - rate between two boundaries with a constant frequency, as detailed in - the paper `Cyclical Learning Rates for Training Neural Networks`_. - The distance between the two boundaries can be scaled on a per-iteration - or per-cycle basis. - - Cyclical learning rate policy changes the learning rate after every batch. - `step` should be called after a batch has been used for training. - - This class has three built-in policies, as put forth in the paper: - - * "triangular": A basic triangular cycle without amplitude scaling. - * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. - * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` - at each cycle iteration. - - This implementation was adapted from the github repo: `bckenstler/CLR`_ - - Args: - optimizer (Optimizer): Wrapped optimizer. - base_lr (float or list): Initial learning rate which is the - lower boundary in the cycle for each parameter group. - max_lr (float or list): Upper learning rate boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_lr - base_lr). - The lr at any cycle is the sum of base_lr - and some scaling of the amplitude; therefore - max_lr may not actually be reached depending on - scaling function. - step_size_up (int): Number of training iterations in the - increasing half of a cycle. Default: 2000 - step_size_down (int): Number of training iterations in the - decreasing half of a cycle. If step_size_down is None, - it is set to step_size_up. Default: None - mode (str): One of {triangular, triangular2, exp_range}. - Values correspond to policies detailed above. - If scale_fn is not None, this argument is ignored. - Default: 'triangular' - gamma (float): Constant in 'exp_range' scaling function: - gamma**(cycle iterations) - Default: 1.0 - scale_fn (function): Custom scaling policy defined by a single - argument lambda function, where - 0 <= scale_fn(x) <= 1 for all x >= 0. - If specified, then 'mode' is ignored. - Default: None - scale_mode (str): {'cycle', 'iterations'}. - Defines whether scale_fn is evaluated on - cycle number or cycle iterations (training - iterations since start of cycle). - Default: 'cycle' - cycle_momentum (bool): If ``True``, momentum is cycled inversely - to learning rate between 'base_momentum' and 'max_momentum'. - Default: True - base_momentum (float or list): Lower momentum boundaries in the cycle - for each parameter group. 
Note that momentum is cycled inversely - to learning rate; at the peak of a cycle, momentum is - 'base_momentum' and learning rate is 'max_lr'. - Default: 0.8 - max_momentum (float or list): Upper momentum boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_momentum - base_momentum). - The momentum at any cycle is the difference of max_momentum - and some scaling of the amplitude; therefore - base_momentum may not actually be reached depending on - scaling function. Note that momentum is cycled inversely - to learning rate; at the start of a cycle, momentum is 'max_momentum' - and learning rate is 'base_lr' - Default: 0.9 - last_epoch (int): The index of the last batch. This parameter is used when - resuming a training job. Since `step()` should be invoked after each - batch instead of after each epoch, this number represents the total - number of *batches* computed, not the total number of epochs computed. - When last_epoch=-1, the schedule is started from the beginning. - Default: -1 - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) - >>> data_loader = torch.utils.data.DataLoader(...) - >>> for epoch in range(10): - >>> for batch in data_loader: - >>> train_batch(...) - >>> scheduler.step() - - - .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 - .. _bckenstler/CLR: https://github.com/bckenstler/CLR - """ - - def __init__(self, - optimizer, - base_lr, - max_lr, - step_size_up=2000, - step_size_down=None, - mode='triangular', - gamma=1., - scale_fn=None, - scale_mode='cycle', - cycle_momentum=True, - base_momentum=0.8, - max_momentum=0.9, - last_epoch=-1, - verbose=False): - - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - base_lrs = self._format_param('base_lr', optimizer, base_lr) - if last_epoch == -1: - for lr, group in zip(base_lrs, optimizer.param_groups): - group['lr'] = lr - - self.max_lrs = self._format_param('max_lr', optimizer, max_lr) - - step_size_up = float(step_size_up) - step_size_down = float(step_size_down) if step_size_down is not None else step_size_up - self.total_size = step_size_up + step_size_down - self.step_ratio = step_size_up / self.total_size - - if mode not in ['triangular', 'triangular2', 'exp_range'] \ - and scale_fn is None: - raise ValueError('mode is invalid and scale_fn is None') - - self.mode = mode - self.gamma = gamma - - self._scale_fn_ref = None - self._scale_fn_custom = scale_fn - self.scale_mode = scale_mode - self._init_scale_fn() - - self.cycle_momentum = cycle_momentum - if cycle_momentum: - if 'momentum' not in optimizer.defaults: - raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') - - base_momentums = self._format_param('base_momentum', optimizer, base_momentum) - if last_epoch == -1: - for momentum, group in zip(base_momentums, optimizer.param_groups): - group['momentum'] = momentum - self.base_momentums = [group['momentum'] for group in optimizer.param_groups] - self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) - - super().__init__(optimizer, last_epoch, verbose) - self.base_lrs = base_lrs - - def 
_init_scale_fn(self): - if self._scale_fn_custom is not None: - return - if self.mode == 'triangular': - self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn) - self.scale_mode = 'cycle' - elif self.mode == 'triangular2': - self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn) - self.scale_mode = 'cycle' - elif self.mode == 'exp_range': - self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn) - self.scale_mode = 'iterations' - - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError("expected {} values for {}, got {}".format( - len(optimizer.param_groups), name, len(param))) - return param - else: - return [param] * len(optimizer.param_groups) - - def scale_fn(self, x): - if self._scale_fn_custom is not None: - return self._scale_fn_custom(x) - - else: - return self._scale_fn_ref()(x) - - def _triangular_scale_fn(self, x): - return 1. - - def _triangular2_scale_fn(self, x): - return 1 / (2. ** (x - 1)) - - def _exp_range_scale_fn(self, x): - return self.gamma**(x) - - def get_lr(self): - """Calculates the learning rate at batch index. This function treats - `self.last_epoch` as the last batch index. - - If `self.cycle_momentum` is ``True``, this function has a side effect of - updating the optimizer's momentum. - """ - - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - cycle = math.floor(1 + self.last_epoch / self.total_size) - x = 1. + self.last_epoch / self.total_size - cycle - if x <= self.step_ratio: - scale_factor = x / self.step_ratio - else: - scale_factor = (x - 1) / (self.step_ratio - 1) - - lrs = [] - for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): - base_height = (max_lr - base_lr) * scale_factor - if self.scale_mode == 'cycle': - lr = base_lr + base_height * self.scale_fn(cycle) - else: - lr = base_lr + base_height * self.scale_fn(self.last_epoch) - lrs.append(lr) - - if self.cycle_momentum: - momentums = [] - for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): - base_height = (max_momentum - base_momentum) * scale_factor - if self.scale_mode == 'cycle': - momentum = max_momentum - base_height * self.scale_fn(cycle) - else: - momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) - momentums.append(momentum) - for param_group, momentum in zip(self.optimizer.param_groups, momentums): - param_group['momentum'] = momentum - - return lrs - - def state_dict(self): - state = super().state_dict() - # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled - state.pop("_scale_fn_ref") - return state - - def load_state_dict(self, state_dict): - super().load_state_dict(state_dict) - self._init_scale_fn() - - - -class CosineAnnealingWarmRestarts(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` - is the number of epochs since the last restart and :math:`T_{i}` is the number - of epochs between two warm restarts in SGDR: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) - - When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 
- When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. - - Args: - optimizer (Optimizer): Wrapped optimizer. - T_0 (int): Number of iterations for the first restart. - T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. - eta_min (float, optional): Minimum learning rate. Default: 0. - last_epoch (int, optional): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False): - if T_0 <= 0 or not isinstance(T_0, int): - raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) - if T_mult < 1 or not isinstance(T_mult, int): - raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) - self.T_0 = T_0 - self.T_i = T_0 - self.T_mult = T_mult - self.eta_min = eta_min - self.T_cur = last_epoch - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 - for base_lr in self.base_lrs] - - def step(self, epoch=None): - """Step could be called after every batch update - - Example: - >>> # xdoctest: +SKIP("Undefined vars") - >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) - >>> iters = len(dataloader) - >>> for epoch in range(20): - >>> for i, sample in enumerate(dataloader): - >>> inputs, labels = sample['inputs'], sample['labels'] - >>> optimizer.zero_grad() - >>> outputs = net(inputs) - >>> loss = criterion(outputs, labels) - >>> loss.backward() - >>> optimizer.step() - >>> scheduler.step(epoch + i / iters) - - This function can be called in an interleaved way. 
- - Example: - >>> # xdoctest: +SKIP("Undefined vars") - >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) - >>> for epoch in range(20): - >>> scheduler.step() - >>> scheduler.step(26) - >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) - """ - - if epoch is None and self.last_epoch < 0: - epoch = 0 - - if epoch is None: - epoch = self.last_epoch + 1 - self.T_cur = self.T_cur + 1 - if self.T_cur >= self.T_i: - self.T_cur = self.T_cur - self.T_i - self.T_i = self.T_i * self.T_mult - else: - if epoch < 0: - raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) - if epoch >= self.T_0: - if self.T_mult == 1: - self.T_cur = epoch % self.T_0 - else: - n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) - self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) - self.T_i = self.T_0 * self.T_mult ** (n) - else: - self.T_i = self.T_0 - self.T_cur = epoch - self.last_epoch = math.floor(epoch) - - class _enable_get_lr_call: - - def __init__(self, o): - self.o = o - - def __enter__(self): - self.o._get_lr_called_within_step = True - return self - - def __exit__(self, type, value, traceback): - self.o._get_lr_called_within_step = False - return self - - with _enable_get_lr_call(self): - for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): - param_group, lr = data - param_group['lr'] = lr - self.print_lr(self.verbose, i, lr, epoch) - - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - - -class OneCycleLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to the - 1cycle learning rate policy. The 1cycle policy anneals the learning - rate from an initial learning rate to some maximum learning rate and then - from that maximum learning rate to some minimum learning rate much lower - than the initial learning rate. - This policy was initially described in the paper `Super-Convergence: - Very Fast Training of Neural Networks Using Large Learning Rates`_. - - The 1cycle learning rate policy changes the learning rate after every batch. - `step` should be called after a batch has been used for training. - - This scheduler is not chainable. - - Note also that the total number of steps in the cycle can be determined in one - of two ways (listed in order of precedence): - - #. A value for total_steps is explicitly provided. - #. A number of epochs (epochs) and a number of steps per epoch - (steps_per_epoch) are provided. - In this case, the number of total steps is inferred by - total_steps = epochs * steps_per_epoch - - You must either provide a value for total_steps or provide a value for both - epochs and steps_per_epoch. - - The default behaviour of this scheduler follows the fastai implementation of 1cycle, which - claims that "unpublished work has shown even better results by using only two phases". To - mimic the behaviour of the original paper instead, set ``three_phase=True``. - - Args: - optimizer (Optimizer): Wrapped optimizer. - max_lr (float or list): Upper learning rate boundaries in the cycle - for each parameter group. - total_steps (int): The total number of steps in the cycle. Note that - if a value is not provided here, then it must be inferred by providing - a value for epochs and steps_per_epoch. - Default: None - epochs (int): The number of epochs to train for. This is used along - with steps_per_epoch in order to infer the total number of steps in the cycle - if a value for total_steps is not provided. 
- Default: None - steps_per_epoch (int): The number of steps per epoch to train for. This is - used along with epochs in order to infer the total number of steps in the - cycle if a value for total_steps is not provided. - Default: None - pct_start (float): The percentage of the cycle (in number of steps) spent - increasing the learning rate. - Default: 0.3 - anneal_strategy (str): {'cos', 'linear'} - Specifies the annealing strategy: "cos" for cosine annealing, "linear" for - linear annealing. - Default: 'cos' - cycle_momentum (bool): If ``True``, momentum is cycled inversely - to learning rate between 'base_momentum' and 'max_momentum'. - Default: True - base_momentum (float or list): Lower momentum boundaries in the cycle - for each parameter group. Note that momentum is cycled inversely - to learning rate; at the peak of a cycle, momentum is - 'base_momentum' and learning rate is 'max_lr'. - Default: 0.85 - max_momentum (float or list): Upper momentum boundaries in the cycle - for each parameter group. Functionally, - it defines the cycle amplitude (max_momentum - base_momentum). - Note that momentum is cycled inversely - to learning rate; at the start of a cycle, momentum is 'max_momentum' - and learning rate is 'base_lr' - Default: 0.95 - div_factor (float): Determines the initial learning rate via - initial_lr = max_lr/div_factor - Default: 25 - final_div_factor (float): Determines the minimum learning rate via - min_lr = initial_lr/final_div_factor - Default: 1e4 - three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the - learning rate according to 'final_div_factor' instead of modifying the second - phase (the first two phases will be symmetrical about the step indicated by - 'pct_start'). - last_epoch (int): The index of the last batch. This parameter is used when - resuming a training job. Since `step()` should be invoked after each - batch instead of after each epoch, this number represents the total - number of *batches* computed, not the total number of epochs computed. - When last_epoch=-1, the schedule is started from the beginning. - Default: -1 - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - Example: - >>> # xdoctest: +SKIP - >>> data_loader = torch.utils.data.DataLoader(...) - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) - >>> for epoch in range(10): - >>> for batch in data_loader: - >>> train_batch(...) - >>> optimizer.step() - >>> scheduler.step() - - - .. 
_Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: - https://arxiv.org/abs/1708.07120 - """ - def __init__(self, - optimizer, - max_lr, - total_steps=None, - epochs=None, - steps_per_epoch=None, - pct_start=0.3, - anneal_strategy='cos', - cycle_momentum=True, - base_momentum=0.85, - max_momentum=0.95, - div_factor=25., - final_div_factor=1e4, - three_phase=False, - last_epoch=-1, - verbose=False): - - # Validate optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - - # Validate total_steps - if total_steps is None and epochs is None and steps_per_epoch is None: - raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") - elif total_steps is not None: - if total_steps <= 0 or not isinstance(total_steps, int): - raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps)) - self.total_steps = total_steps - else: - if epochs <= 0 or not isinstance(epochs, int): - raise ValueError("Expected positive integer epochs, but got {}".format(epochs)) - if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): - raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch)) - self.total_steps = epochs * steps_per_epoch - - if three_phase: - self._schedule_phases = [ - { - 'end_step': float(pct_start * self.total_steps) - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'max_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'base_momentum', - }, - { - 'end_step': float(2 * pct_start * self.total_steps) - 2, - 'start_lr': 'max_lr', - 'end_lr': 'initial_lr', - 'start_momentum': 'base_momentum', - 'end_momentum': 'max_momentum', - }, - { - 'end_step': self.total_steps - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'min_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'max_momentum', - }, - ] - else: - self._schedule_phases = [ - { - 'end_step': float(pct_start * self.total_steps) - 1, - 'start_lr': 'initial_lr', - 'end_lr': 'max_lr', - 'start_momentum': 'max_momentum', - 'end_momentum': 'base_momentum', - }, - { - 'end_step': self.total_steps - 1, - 'start_lr': 'max_lr', - 'end_lr': 'min_lr', - 'start_momentum': 'base_momentum', - 'end_momentum': 'max_momentum', - }, - ] - - # Validate pct_start - if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): - raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) - - # Validate anneal_strategy - if anneal_strategy not in ['cos', 'linear']: - raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) - elif anneal_strategy == 'cos': - self.anneal_func = self._annealing_cos - elif anneal_strategy == 'linear': - self.anneal_func = self._annealing_linear - - # Initialize learning rate variables - max_lrs = self._format_param('max_lr', self.optimizer, max_lr) - if last_epoch == -1: - for idx, group in enumerate(self.optimizer.param_groups): - group['initial_lr'] = max_lrs[idx] / div_factor - group['max_lr'] = max_lrs[idx] - group['min_lr'] = group['initial_lr'] / final_div_factor - - # Initialize momentum variables - self.cycle_momentum = cycle_momentum - if self.cycle_momentum: - if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: - raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') - self.use_beta1 = 'betas' in self.optimizer.defaults - 
max_momentums = self._format_param('max_momentum', optimizer, max_momentum) - base_momentums = self._format_param('base_momentum', optimizer, base_momentum) - if last_epoch == -1: - for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): - if self.use_beta1: - group['betas'] = (m_momentum, *group['betas'][1:]) - else: - group['momentum'] = m_momentum - group['max_momentum'] = m_momentum - group['base_momentum'] = b_momentum - - super().__init__(optimizer, last_epoch, verbose) - - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError("expected {} values for {}, got {}".format( - len(optimizer.param_groups), name, len(param))) - return param - else: - return [param] * len(optimizer.param_groups) - - def _annealing_cos(self, start, end, pct): - "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." - cos_out = math.cos(math.pi * pct) + 1 - return end + (start - end) / 2.0 * cos_out - - def _annealing_linear(self, start, end, pct): - "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." - return (end - start) * pct + start - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - lrs = [] - step_num = self.last_epoch - - if step_num > self.total_steps: - raise ValueError("Tried to step {} times. The specified number of total steps is {}" - .format(step_num, self.total_steps)) - - for group in self.optimizer.param_groups: - start_step = 0 - for i, phase in enumerate(self._schedule_phases): - end_step = phase['end_step'] - if step_num <= end_step or i == len(self._schedule_phases) - 1: - pct = (step_num - start_step) / (end_step - start_step) - computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct) - if self.cycle_momentum: - computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct) - break - start_step = phase['end_step'] - - lrs.append(computed_lr) - if self.cycle_momentum: - if self.use_beta1: - group['betas'] = (computed_momentum, *group['betas'][1:]) - else: - group['momentum'] = computed_momentum - - return lrs \ No newline at end of file diff --git a/python/jittor/compatibility/src/jtorch_core.cc b/python/jittor/compatibility/src/jtorch_core.cc deleted file mode 100644 index 1102b107..00000000 --- a/python/jittor/compatibility/src/jtorch_core.cc +++ /dev/null @@ -1,102 +0,0 @@ - -#include "pyjt/py_obj_holder.h" -#include "utils/str_utils.h" -#include "jtorch_core.h" -#include "graph.h" -#include "grad.h" -#include "ops/op_register.h" - -namespace jittor { - -void pyjt_def_all(PyObject* m); - -EXTERN_LIB void setter_use_cuda(int value); - -Device::Device(const string& name, int ordinal) : name(name) { - if (startswith(name, "cpu")) - setter_use_cuda(0); - else - setter_use_cuda(1); -} - -unordered_map grad_backup; -EXTERN_LIB void (*_var_free_hook)(Var*); -EXTERN_LIB unordered_map* _grad_backup_ptr; - -void jtorch_var_free_hook(Var* v) { - auto iter = grad_backup.find(v->id); - if (iter != grad_backup.end()) { - grad_backup.erase(iter); - } -} - -void jtorch_init() { - _var_free_hook = &jtorch_var_free_hook; - _grad_backup_ptr = &grad_backup; -} - -inline static VarPtr& get_grad(Var* v) { - return grad_backup[v->id]; -} -static auto make_binary = 
get_op_info("binary") - .get_constructor(); - -inline static void add_grad(VarPtr& a, VarPtr&& b) { - if (!a) a = move(b); - else { - a = make_binary(a, b, ns_add); - } -} - - -void grad_set(VarHolder* x, Maybe v) { - if (!v) { - grad_del(x); - return; - } - grad_backup[x->var->id] = v.ptr->var; -} - -Maybe grad_get(VarHolder* x) { - auto iter = grad_backup.find(x->var->id); - if (iter != grad_backup.end()) { - if (!iter->second.ptr) return nullptr; - return new VarHolder(iter->second.ptr); - } - return nullptr; -} - -void grad_del(VarHolder* x) { - auto iter = grad_backup.find(x->var->id); - if (iter != grad_backup.end()) - grad_backup.erase(iter); -} - -void backward(VarHolder* x) { - vector gnodes({x->var}); - bfs_backward(gnodes, [&](Node* node) { - if (node->is_stop_grad()) - return false; - return true; - }); - vector targets; - for (auto* node : gnodes) { - if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad)) - targets.push_back(node->var()); - } - auto grads = grad(x->var, targets); - for (int i=0; im_doc = "Inner c++ core of jtorch"; - jittor::pyjt_def_all(m); -} -PYJT_MODULE_INIT(jtorch_core); diff --git a/python/jittor/compatibility/src/jtorch_core.h b/python/jittor/compatibility/src/jtorch_core.h deleted file mode 100644 index 36de6522..00000000 --- a/python/jittor/compatibility/src/jtorch_core.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once -#include "common.h" -#include "var_holder.h" -#include "misc/fast_shared_ptr.h" - -namespace jittor { - -// @pyjt(device) -// @attrs(heaptype) -struct Device { - string name; - - // @pyjt(__init__) - Device(const string& name, int ordinal=0); - // @pyjt(__get__type, __str__) - inline string get_type() {return name;} - // @pyjt(__get__index) - inline int index() {return 0;} -}; - -// @pyjt(backward) -void backward(VarHolder* x); - -// @pyjt(grad_set) -void grad_set(VarHolder* x, Maybe v); -// @pyjt(grad_get) -Maybe grad_get(VarHolder* x); -// @pyjt(grad_del) -void grad_del(VarHolder* x); - -// @pyjt(retain_grad_set) -inline void retain_grad_set(VarHolder* x, bool v) { - x->var->flags.set(NodeFlags::_th_require_grad, v); -} -// @pyjt(retain_grad_get) -inline bool retain_grad_get(VarHolder* x) { - return x->var->flags.get(NodeFlags::_th_require_grad); -} - -} \ No newline at end of file diff --git a/python/jittor/compatibility/test/test_conflict_func.py b/python/jittor/compatibility/test/test_conflict_func.py deleted file mode 100644 index 97bd7d8f..00000000 --- a/python/jittor/compatibility/test/test_conflict_func.py +++ /dev/null @@ -1,25 +0,0 @@ -import unittest -import numpy as np -import torch -import jittor as jt - -class TestConflictFunc(unittest.TestCase): - def test_max(self): - a = torch.Tensor([1,4,2]) - assert a.max() == 4 - v, k = a.max(dim=0) - assert v==4 and k==1 - - def test_argsort(self): - a = torch.Tensor([1,4,2]) - k = a.argsort() - assert jt.all_equal(k, [0,2,1]) - - with jt.flag_scope(th_mode=0): - k, v = a.argsort() - assert jt.all_equal(k, [0,2,1]) - - - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_function.py b/python/jittor/compatibility/test/test_function.py deleted file mode 100644 index 9959dbae..00000000 --- a/python/jittor/compatibility/test/test_function.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest -import numpy as np -import torch - -class TestFunction(unittest.TestCase): - def test_example1(self): - import jtorch - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y 
- return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - return grad0 * self.y, grad1 * self.x - - a = jtorch.array(3.0) - a.requires_grad = True - b = jtorch.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - (c+d*3).backward() - assert a.grad.data == 4 - assert b.grad.data == 9 - - def test_example2(self): - import jtorch as jt - from jtorch import Function - - class MyFunc(Function): - @staticmethod - def forward(self, x, y): - self.x = x - self.y = y - return x*y, x/y - - @staticmethod - def backward(self, grad0, grad1): - assert grad1 is None - return grad0 * self.y, None - a = jt.array(3.0) - a.requires_grad = True - b = jt.array(4.0) - b.requires_grad = True - func = MyFunc.apply - c,d = func(a, b) - d.stop_grad() - da, db = jt.grad(c+d*3, [a, b]) - assert da.data == 4 - assert db.data == 0 - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_misc.py b/python/jittor/compatibility/test/test_misc.py deleted file mode 100644 index 00bf1b70..00000000 --- a/python/jittor/compatibility/test/test_misc.py +++ /dev/null @@ -1,24 +0,0 @@ -import unittest -import numpy as np -import torch - -class TestMisc(unittest.TestCase): - def test_update_grad(self): - class Net(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) - net = Net() - assert(net.a.requires_grad) - net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) - assert(net.a.requires_grad) - - def test_reshape(self): - a = torch.ones(3,3) - a.requires_grad = True - b = torch.reshape(a, [9]) - assert b.requires_grad == True - - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/compatibility/test/test_tutorial.py b/python/jittor/compatibility/test/test_tutorial.py deleted file mode 100644 index 92c087c7..00000000 --- a/python/jittor/compatibility/test/test_tutorial.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -import numpy as np -import os -import subprocess as sp -import sys - -def check_two(cmd, parser=None, checker=None): - jtorch_out = sp.getoutput(cmd) - print("=========JTORCH OUT==========") - print(jtorch_out) - torch_out = sp.getoutput("PYTHONPATH= "+cmd) - print("=========TORCH OUT==========") - print(torch_out) - if parser: - torch_out = parser(torch_out) - jtorch_out = parser(jtorch_out) - if checker: - checker(torch_out, jtorch_out) - else: - assert torch_out == jtorch_out - return jtorch_out, torch_out - -jtorch_path = os.path.join(os.path.dirname(__file__), "..") -# come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html -class TestTutorial(unittest.TestCase): - def test_auto_grad1(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad2(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad3(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py", - parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad4(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - 
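The tests above exercise the jtorch wrapper around custom autograd functions; the native Jittor spelling of the same example is roughly the following sketch, assuming jittor.Function's execute/grad interface (one returned gradient per forward output, with None meaning no gradient):

import jittor as jt
from jittor import Function

class MyFunc(Function):
    def execute(self, x, y):
        # stash the inputs for the backward pass
        self.x, self.y = x, y
        return x * y, x / y

    def grad(self, grad0, grad1):
        # user-defined backward: one gradient per forward output
        return grad0 * self.y, grad1 * self.x

a, b = jt.array(3.0), jt.array(4.0)
c, d = MyFunc.apply(a, b)
da, db = jt.grad(c + d * 3, [a, b])
print(da.item(), db.item())   # 4.0 and 9.0, matching the asserts in the deleted test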
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad5(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) - def test_auto_grad6(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py", - parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) - def test_auto_grad7(self): - check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py", - parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float), - checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad1.py b/python/jittor/compatibility/tutorial/auto_grad1.py deleted file mode 100644 index 60a090ad..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad1.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import math - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create random input and output data -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Randomly initialize weights -a = torch.randn((), device=device, dtype=dtype) -b = torch.randn((), device=device, dtype=dtype) -c = torch.randn((), device=device, dtype=dtype) -d = torch.randn((), device=device, dtype=dtype) - -learning_rate = 1e-6 -for t in range(20000): - # Forward pass: compute predicted y - y_pred = a + b * x + c * x ** 2 + d * x ** 3 - - # Compute and print loss - loss = (y_pred - y).pow(2).sum().item() - if t % 1000 == 999: - print(t, loss) - - # Backprop to compute gradients of a, b, c, d with respect to loss - grad_y_pred = 2.0 * (y_pred - y) - grad_a = grad_y_pred.sum() - grad_b = (grad_y_pred * x).sum() - grad_c = (grad_y_pred * x ** 2).sum() - grad_d = (grad_y_pred * x ** 3).sum() - - # Update weights using gradient descent - a -= learning_rate * grad_a - b -= learning_rate * grad_b - c -= learning_rate * grad_c - d -= learning_rate * grad_d - # print(t, torch.liveness_info()) - # torch.sync_all() - - -print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad2.py b/python/jittor/compatibility/tutorial/auto_grad2.py deleted file mode 100644 index a3bbc9a8..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad2.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create Tensors to hold input and outputs. -# By default, requires_grad=False, which indicates that we do not need to -# compute gradients with respect to these Tensors during the backward pass. -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Create random Tensors for weights. For a third order polynomial, we need -# 4 weights: y = a + b x + c x^2 + d x^3 -# Setting requires_grad=True indicates that we want to compute gradients with -# respect to these Tensors during the backward pass. 
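The auto_grad*.py scripts being compared above are the stock PyTorch polynomial-fitting tutorials run through the compatibility layer; written directly against Jittor's functional autograd, the same fit looks roughly like this (a sketch: jt.grad and Var.update are assumed to behave as in current Jittor releases, and numpy supplies linspace to keep it self-contained):

import math
import numpy as np
import jittor as jt

x = jt.array(np.linspace(-math.pi, math.pi, 2000, dtype="float32"))
y = jt.sin(x)
a, b, c, d = [jt.randn(1) for _ in range(4)]

lr = 1e-6
for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = ((y_pred - y) ** 2).sum()
    ga, gb, gc, gd = jt.grad(loss, [a, b, c, d])
    for p, g in zip((a, b, c, d), (ga, gb, gc, gd)):
        p.update(p - lr * g)   # in-place style update keeps p a leaf Var
    if t % 500 == 499:
        print(t, loss.item())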
-a = torch.randn((), device=device, dtype=dtype, requires_grad=True) -b = torch.randn((), device=device, dtype=dtype, requires_grad=True) -c = torch.randn((), device=device, dtype=dtype, requires_grad=True) -d = torch.randn((), device=device, dtype=dtype, requires_grad=True) - -learning_rate = 1e-6 -for t in range(20000): - # Forward pass: compute predicted y using operations on Tensors. - y_pred = a + b * x + c * x ** 2 + d * x ** 3 - # print(y_pred.requires_grad) - # y_pred.requires_grad = False - - # Compute and print loss using operations on Tensors. - # Now loss is a Tensor of shape (1,) - # loss.item() gets the scalar value held in the loss. - loss = (y_pred - y).pow(2).sum() - if t % 1000 == 990: - print(t, loss.item()) - - # Use autograd to compute the backward pass. This call will compute the - # gradient of loss with respect to all Tensors with requires_grad=True. - # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding - # the gradient of the loss with respect to a, b, c, d respectively. - # torch.backward(loss) - loss.backward() - - # Manually update weights using gradient descent. Wrap in torch.no_grad() - # because weights have requires_grad=True, but we don't need to track this - # in autograd. - with torch.no_grad(): - a -= learning_rate * a.grad - b -= learning_rate * b.grad - c -= learning_rate * c.grad - d -= learning_rate * d.grad - - # Manually zero the gradients after updating weights - a.grad = None - b.grad = None - c.grad = None - d.grad = None - -print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad3.py b/python/jittor/compatibility/tutorial/auto_grad3.py deleted file mode 100644 index 654ec447..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad3.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -class LegendrePolynomial3(torch.autograd.Function): - """ - We can implement our own custom autograd Functions by subclassing - torch.autograd.Function and implementing the forward and backward passes - which operate on Tensors. - """ - - @staticmethod - def forward(ctx, input): - """ - In the forward pass we receive a Tensor containing the input and return - a Tensor containing the output. ctx is a context object that can be used - to stash information for backward computation. You can cache arbitrary - objects for use in the backward pass using the ctx.save_for_backward method. - """ - ctx.save_for_backward(input) - return 0.5 * (5 * input ** 3 - 3 * input) - - @staticmethod - def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. - """ - input, = ctx.saved_tensors - return grad_output * 1.5 * (5 * input ** 2 - 1) - - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU - -# Create Tensors to hold input and outputs. -# By default, requires_grad=False, which indicates that we do not need to -# compute gradients with respect to these Tensors during the backward pass. -x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) -y = torch.sin(x) - -# Create random Tensors for weights. For this example, we need -# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized -# not too far from the correct result to ensure convergence. 
-# Setting requires_grad=True indicates that we want to compute gradients with -# respect to these Tensors during the backward pass. -a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) -b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) -c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) -d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) - -learning_rate = 5e-6 -for t in range(2000): - # To apply our Function, we use Function.apply method. We alias this as 'P3'. - P3 = LegendrePolynomial3.apply - - # Forward pass: compute predicted y using operations; we compute - # P3 using our custom autograd operation. - y_pred = a + b * P3(c + d * x) - - # Compute and print loss - loss = (y_pred - y).pow(2).sum() - if t % 100 == 99: - print(t, loss.item()) - - # Use autograd to compute the backward pass. - loss.backward() - - # Update weights using gradient descent - with torch.no_grad(): - a -= learning_rate * a.grad - b -= learning_rate * b.grad - c -= learning_rate * c.grad - d -= learning_rate * d.grad - - # Manually zero the gradients after updating weights - a.grad = None - b.grad = None - c.grad = None - d.grad = None - -print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad4.py b/python/jittor/compatibility/tutorial/auto_grad4.py deleted file mode 100644 index 062d0b0e..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad4.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# For this example, the output y is a linear function of (x, x^2, x^3), so -# we can consider it as a linear layer neural network. Let's prepare the -# tensor (x, x^2, x^3). -p = torch.tensor([1, 2, 3]) -xx = x.unsqueeze(-1).pow(p) - -# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape -# (3,), for this case, broadcasting semantics will apply to obtain a tensor -# of shape (2000, 3) - -# Use the nn package to define our model as a sequence of layers. nn.Sequential -# is a Module which contains other Modules, and applies them in sequence to -# produce its output. The Linear Module computes output from input using a -# linear function, and holds internal Tensors for its weight and bias. -# The Flatten layer flatens the output of the linear layer to a 1D tensor, -# to match the shape of `y`. -model = torch.nn.Sequential( - torch.nn.Linear(3, 1), - torch.nn.Flatten(0, 1) -) - -# The nn package also contains definitions of popular loss functions; in this -# case we will use Mean Squared Error (MSE) as our loss function. -loss_fn = torch.nn.MSELoss(reduction='sum') -# print(model[0].weight.requires_grad) - -learning_rate = 1e-6 -for t in range(8000): - - # Forward pass: compute predicted y by passing x to the model. Module objects - # override the __call__ operator so you can call them like functions. When - # doing so you pass a Tensor of input data to the Module and it produces - # a Tensor of output data. - y_pred = model(xx) - - # Compute and print loss. We pass Tensors containing the predicted and true - # values of y, and the loss function returns a Tensor containing the - # loss. - loss = loss_fn(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Zero the gradients before running the backward pass. 
- model.zero_grad() - - # Backward pass: compute gradient of the loss with respect to all the learnable - # parameters of the model. Internally, the parameters of each Module are stored - # in Tensors with requires_grad=True, so this call will compute gradients for - # all learnable parameters in the model. - loss.backward() - - # Update the weights using gradient descent. Each parameter is a Tensor, so - # we can access its gradients like we did before. - with torch.no_grad(): - for param in model.parameters(): - param -= learning_rate * param.grad - -# You can access the first layer of `model` like accessing the first item of a list -linear_layer = model[0] - -# For linear layer, its parameters are stored as `weight` and `bias`. -print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad5_optim.py b/python/jittor/compatibility/tutorial/auto_grad5_optim.py deleted file mode 100644 index 04949320..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad5_optim.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Prepare the input tensor (x, x^2, x^3). -p = torch.tensor([1, 2, 3]) -xx = x.unsqueeze(-1).pow(p) - -# Use the nn package to define our model and loss function. -model = torch.nn.Sequential( - torch.nn.Linear(3, 1), - torch.nn.Flatten(0, 1) -) -loss_fn = torch.nn.MSELoss(reduction='sum') - -# Use the optim package to define an Optimizer that will update the weights of -# the model for us. Here we will use RMSprop; the optim package contains many other -# optimization algorithms. The first argument to the RMSprop constructor tells the -# optimizer which Tensors it should update. -learning_rate = 1e-3 -optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) -for t in range(8000): - # Forward pass: compute predicted y by passing x to the model. - y_pred = model(xx) - - # Compute and print loss. - loss = loss_fn(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Before the backward pass, use the optimizer object to zero all of the - # gradients for the variables it will update (which are the learnable - # weights of the model). This is because by default, gradients are - # accumulated in buffers( i.e, not overwritten) whenever .backward() - # is called. Checkout docs of torch.autograd.backward for more details. 
- optimizer.zero_grad() - - # Backward pass: compute gradient of the loss with respect to model - # parameters - loss.backward() - - # Calling the step function on an Optimizer makes an update to its - # parameters - optimizer.step() - - -linear_layer = model[0] -print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad6_module.py b/python/jittor/compatibility/tutorial/auto_grad6_module.py deleted file mode 100644 index a240e2b5..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad6_module.py +++ /dev/null @@ -1,59 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import math - - -class Polynomial3(torch.nn.Module): - def __init__(self): - """ - In the constructor we instantiate four parameters and assign them as - member parameters. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - - def forward(self, x): - """ - In the forward function we accept a Tensor of input data and we must return - a Tensor of output data. We can use Modules defined in the constructor as - well as arbitrary operators on Tensors. - """ - return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - - def string(self): - """ - Just like any class in Python, you can also define custom method on PyTorch modules - """ - return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Construct our model by instantiating the class defined above -model = Polynomial3() - -# Construct our loss function and an Optimizer. The call to model.parameters() -# in the SGD constructor will contain the learnable parameters (defined -# with torch.nn.Parameter) which are members of the model. -criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) -for t in range(8000): - # Forward pass: Compute predicted y by passing x to the model - y_pred = model(x) - - # Compute and print loss - loss = criterion(y_pred, y) - if t % 1000 == 999: - print(t, loss.item()) - - # Zero gradients, perform a backward pass, and update the weights. - optimizer.zero_grad() - loss.backward() - optimizer.step() - -print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py deleted file mode 100644 index fa954771..00000000 --- a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -import random -import torch -import math - - -class DynamicNet(torch.nn.Module): - def __init__(self): - """ - In the constructor we instantiate five parameters and assign them as members. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - self.e = torch.nn.Parameter(torch.randn(())) - - def forward(self, x): - """ - For the forward pass of the model, we randomly choose either 4, 5 - and reuse the e parameter to compute the contribution of these orders. 
- - Since each forward pass builds a dynamic computation graph, we can use normal - Python control-flow operators like loops or conditional statements when - defining the forward pass of the model. - - Here we also see that it is perfectly safe to reuse the same parameter many - times when defining a computational graph. - """ - y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - for exp in range(4, random.randint(4, 6)): - y = y + self.e * x ** exp - return y - - def string(self): - """ - Just like any class in Python, you can also define custom method on PyTorch modules - """ - return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' - - -# Create Tensors to hold input and outputs. -x = torch.linspace(-math.pi, math.pi, 2000) -y = torch.sin(x) - -# Construct our model by instantiating the class defined above -model = DynamicNet() - -# Construct our loss function and an Optimizer. Training this strange model with -# vanilla stochastic gradient descent is tough, so we use momentum -criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) -for t in range(60000): - # Forward pass: Compute predicted y by passing x to the model - y_pred = model(x) - - # Compute and print loss - loss = criterion(y_pred, y) - if t % 2000 == 1999: - print(t, loss.item()) - - # Zero gradients, perform a backward pass, and update the weights. - optimizer.zero_grad() - loss.backward() - optimizer.step() - # print(torch.liveness_info()) - -print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/quickstart.py b/python/jittor/compatibility/tutorial/quickstart.py deleted file mode 100644 index f0401a9b..00000000 --- a/python/jittor/compatibility/tutorial/quickstart.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -from torch import nn -# from jtorch.utils import DataLoader -from torch.utils.data import DataLoader -from torchvision import datasets -from torchvision.transforms import ToTensor - -# Download training data from open datasets. -training_data = datasets.FashionMNIST( - root="data", - train=True, - download=True, - transform=ToTensor(), -) - -# Download test data from open datasets. -test_data = datasets.FashionMNIST( - root="data", - train=False, - download=True, - transform=ToTensor(), -) - -batch_size = 64 - -# Create data loaders. -train_dataloader = DataLoader(training_data, batch_size=batch_size) -test_dataloader = DataLoader(test_data, batch_size=batch_size) - -print(len(train_dataloader)) -for X, y in test_dataloader: - print(f"Shape of X [N, C, H, W]: {X.shape}") - print(f"Shape of y: {y.shape} {y.dtype}") - break - -# Get cpu or gpu device for training. 
-device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using {device} device") - -# Define model -class NeuralNetwork(nn.Module): - def __init__(self): - super(NeuralNetwork, self).__init__() - self.flatten = nn.Flatten() - self.linear_relu_stack = nn.Sequential( - nn.Linear(28*28, 512), - nn.ReLU(), - nn.Linear(512, 512), - nn.ReLU(), - nn.Linear(512, 10) - ) - - def forward(self, x): - x = self.flatten(x) - logits = self.linear_relu_stack(x) - return logits - -model = NeuralNetwork().to(device) -print(model) - - -loss_fn = nn.CrossEntropyLoss() -optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - -def train(dataloader, model, loss_fn, optimizer): - size = len(dataloader.dataset) - model.train() - for batch, (X, y) in enumerate(dataloader): - X, y = X.to(device), y.to(device) - - # Compute prediction error - pred = model(X) - loss = loss_fn(pred, y) - - # Backpropagation - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if batch % 100 == 0: - loss, current = loss.item(), batch * len(X) - print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") - -def test(dataloader, model, loss_fn): - size = len(dataloader.dataset) - num_batches = len(dataloader) - model.eval() - test_loss, correct = 0, 0 - with torch.no_grad(): - for X, y in dataloader: - X, y = X.to(device), y.to(device) - pred = model(X) - test_loss += loss_fn(pred, y).item() - correct += (pred.argmax(1) == y).type(torch.float).sum().item() - test_loss /= num_batches - correct /= size - print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") - - -epochs = 5 -test(test_dataloader, model, loss_fn) -for t in range(epochs): - print(f"Epoch {t+1}\n-------------------------------") - train(train_dataloader, model, loss_fn, optimizer) - test(test_dataloader, model, loss_fn) -print("Done!") \ No newline at end of file diff --git a/python/jittor/compatibility/utils/__init__.py b/python/jittor/compatibility/utils/__init__.py deleted file mode 100644 index ac2c2bd8..00000000 --- a/python/jittor/compatibility/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -cpp_extension = None -_flatten_dense_tensors = None -_unflatten_dense_tensors = None - -tensorboard = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/_pytree.py b/python/jittor/compatibility/utils/_pytree.py deleted file mode 100644 index c3118964..00000000 --- a/python/jittor/compatibility/utils/_pytree.py +++ /dev/null @@ -1,3 +0,0 @@ -#TODO: Implement this -_register_pytree_node = None -_dict_flatten = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/checkpoint.py b/python/jittor/compatibility/utils/checkpoint.py deleted file mode 100644 index ba3c3e8e..00000000 --- a/python/jittor/compatibility/utils/checkpoint.py +++ /dev/null @@ -1,8 +0,0 @@ -detach_variable = None - - -def checkpoint( - *args, - **kwargs -): - pass diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py deleted file mode 100644 index 71946a23..00000000 --- a/python/jittor/compatibility/utils/data.py +++ /dev/null @@ -1,137 +0,0 @@ -import jittor as jt -import jittor.dataset -from jittor.dataset import Dataset as JDataset - -from collections import namedtuple -from typing import Any, Callable, Iterable, Optional, Sequence, Union - - -class Dataset: - def __getitem__(self, index): - raise NotImplementedError - -class IterableDataset: - def __iter__(self): - raise NotImplementedError - - -class DataLoader(JDataset): - def __init__(self, dataset, - 
batch_size: Optional[int] = 1, - shuffle: Optional[bool] = False, - sampler = None, - batch_sampler = None, - num_workers: int = 0, - collate_fn = None, - pin_memory: bool = False, - drop_last: bool = False, - timeout: float = 0, - worker_init_fn = None, - multiprocessing_context=None, - generator=None, - *, prefetch_factor: int = 2, - persistent_workers: bool = False, - pin_memory_device: str = "") -> None: - super().__init__(batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - drop_last=drop_last) - - unsupported_kwargs = { - "batch_sampler": batch_sampler, - "pin_memory": pin_memory, - "timeout": timeout, - "worker_init_fn": worker_init_fn, - "multiprocessing_context": multiprocessing_context, - "generator": generator, - "persistent_workers": persistent_workers, - "pin_memory_device": pin_memory_device - } - for kwarg, value in unsupported_kwargs.items(): - if value: - jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}") - - self.dataset = dataset - self.collate_fn = collate_fn - self.sampler = sampler - - if not isinstance(dataset, IterableDataset): - self.total_len = len(dataset) - else: - # TODO: support multiple worker for iterable dataset - assert(num_workers == 0) - - def collate_batch(self, batch): - if self.collate_fn is not None: - return self.collate_fn(batch) - else: - return super().collate_batch(batch) - - def __getitem__(self, i): - return self.dataset[i] - - def __iter__(self): - if isinstance(self.dataset, IterableDataset): - return self.inner_iter() - else: - return super().__iter__() - - def inner_iter(self): - current_batch = [] - - if jt.world_size > 1: - assert self.batch_size % jt.world_size == 0, \ - f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}" - real_batch_size = int(self.batch_size / jt.world_size) - else: - real_batch_size = self.batch_size - - for element in self.dataset: - current_batch.append(element) - - if len(current_batch) == real_batch_size: - current_batch = self.collate_batch(current_batch) - current_batch = self.to_jittor(current_batch) - yield current_batch - current_batch = [] - - if not self.drop_last and len(current_batch) > 0: - current_batch = self.collate_batch(current_batch) - yield self.to_jittor(current_batch) - -def get_worker_info(): - # always return the fake worker info - return namedtuple('WorkerInfo', 'id num_workers')(0, 1) - -class RandomSampler(jt.dataset.RandomSampler): - def __init__(self, dataset, generator=None, **kwargs): - super().__init__(dataset, **kwargs) - - def __iter__(self): - if getattr(self.dataset, "support_random_access", True): - return super().__iter__() - else: - self.dataset.shuffle() - return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) - -class DistributedSampler(jt.dataset.Sampler): - def __init__(self, sampler: RandomSampler): - assert(isinstance(sampler, RandomSampler)) - self.sampler = sampler - - def set_epoch(self, epoch: int): - ### do nothing, let jittor's inner dataset handle - pass - - def __iter__(self): - return self.sampler.__iter__() - - def __len__(self): - return self.sampler.__len__() - -BatchSampler = jt.dataset.BatchSampler -Sampler = jt.dataset.Sampler -SequentialSampler = jt.dataset.SequentialSampler -SubsetRandomSampler = jt.dataset.SubsetRandomSampler - -TensorDataset = Dataset diff --git a/python/jittor/compatibility/utils/dtype.py b/python/jittor/compatibility/utils/dtype.py deleted file mode 100644 index 
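The DataLoader shim above handles IterableDataset by batching elements itself inside inner_iter(); stripped of the collate_fn and multi-process details, the core pattern is just this generator (illustrative, not part of the deleted API):

def iter_batches(iterable, batch_size, collate, drop_last=False):
    batch = []
    for element in iterable:
        batch.append(element)
        if len(batch) == batch_size:
            yield collate(batch)      # hand a full batch to the collate function
            batch = []
    if batch and not drop_last:
        yield collate(batch)          # flush the final, possibly short batch

# e.g. list(iter_batches(range(7), 3, list)) -> [[0, 1, 2], [3, 4, 5], [6]]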
41728383..00000000 --- a/python/jittor/compatibility/utils/dtype.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Callable, Union -Dtype = Union[Callable, str] - -def get_string_dtype(dtype): - if callable(dtype): - dtype = dtype.__name__ - if not isinstance(dtype, str): - raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") - return dtype \ No newline at end of file diff --git a/python/jittor/compatibility/utils/hooks.py b/python/jittor/compatibility/utils/hooks.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/jittor/compatibility/utils/pip_publish.py b/python/jittor/compatibility/utils/pip_publish.py deleted file mode 100644 index 72ff245f..00000000 --- a/python/jittor/compatibility/utils/pip_publish.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import glob -import shutil -import sys - -home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") -home_path = os.path.abspath(home_path) - -def callback(func, path, exc_info): - print(f"remove \"{path}\" failed.") - -def rmtree(path): - if os.path.isdir(path): - print(f"remove \"{path}\" recursive.") - shutil.rmtree(path, onerror=callback) - -def remove_tmpfile(): - dist_file = home_path+"/dist" - egg_file = glob.glob(home_path+"/**/*egg-info") - rmtree(dist_file) - for e in egg_file: - rmtree(e) - -def run_cmd(cmd): - print("[CMD]", cmd) - assert os.system(cmd)==0 - -os.chdir(home_path) -remove_tmpfile() - -run_cmd(f"{sys.executable} ./setup.py sdist") -run_cmd(f"{sys.executable} -m twine upload dist/*") - -remove_tmpfile() \ No newline at end of file diff --git a/python/jittor/compatibility/vision/_internally_replaced_utils.py b/python/jittor/compatibility/vision/_internally_replaced_utils.py deleted file mode 100644 index 748fa2ea..00000000 --- a/python/jittor/compatibility/vision/_internally_replaced_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -import importlib.machinery -import os - - -def _download_file_from_remote_location(fpath: str, url: str) -> None: - pass - - -def _is_remote_location_available() -> bool: - return False - - -def _get_extension_path(lib_name): - - lib_dir = os.path.dirname(__file__) - if os.name == "nt": - # Register the main torchvision library location on the default DLL path - import ctypes - import sys - - kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) - with_load_library_flags = hasattr(kernel32, "AddDllDirectory") - prev_error_mode = kernel32.SetErrorMode(0x0001) - - if with_load_library_flags: - kernel32.AddDllDirectory.restype = ctypes.c_void_p - - if sys.version_info >= (3, 8): - os.add_dll_directory(lib_dir) - elif with_load_library_flags: - res = kernel32.AddDllDirectory(lib_dir) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
- raise err - - kernel32.SetErrorMode(prev_error_mode) - - loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) - ext_specs = extfinder.find_spec(lib_name) - if ext_specs is None: - raise ImportError - - return ext_specs.origin \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/__init__.py b/python/jittor/compatibility/vision/datasets/__init__.py deleted file mode 100644 index d04187f1..00000000 --- a/python/jittor/compatibility/vision/datasets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST - -__all__ = ( - "EMNIST", - "FashionMNIST", - "QMNIST", - "MNIST", - "KMNIST", -) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/mnist.py b/python/jittor/compatibility/vision/datasets/mnist.py deleted file mode 100644 index dfc3787b..00000000 --- a/python/jittor/compatibility/vision/datasets/mnist.py +++ /dev/null @@ -1,558 +0,0 @@ -import codecs -import os -import os.path -import shutil -import string -import sys -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple -from urllib.error import URLError - -import numpy as np -import torch -from PIL import Image - -from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg -from .vision import VisionDataset - - -class MNIST(VisionDataset): - """`MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` - and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. 
- """ - - mirrors = [ - "http://yann.lecun.com/exdb/mnist/", - "https://ossci-datasets.s3.amazonaws.com/mnist/", - ] - - resources = [ - ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), - ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), - ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), - ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), - ] - - training_file = "training.pt" - test_file = "test.pt" - classes = [ - "0 - zero", - "1 - one", - "2 - two", - "3 - three", - "4 - four", - "5 - five", - "6 - six", - "7 - seven", - "8 - eight", - "9 - nine", - ] - - @property - def train_labels(self): - warnings.warn("train_labels has been renamed targets") - return self.targets - - @property - def test_labels(self): - warnings.warn("test_labels has been renamed targets") - return self.targets - - @property - def train_data(self): - warnings.warn("train_data has been renamed data") - return self.data - - @property - def test_data(self): - warnings.warn("test_data has been renamed data") - return self.data - - def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, - ) -> None: - super().__init__(root, transform=transform, target_transform=target_transform) - self.train = train # training set or test set - - if self._check_legacy_exist(): - self.data, self.targets = self._load_legacy_data() - return - - if download: - self.download() - - if not self._check_exists(): - raise RuntimeError("Dataset not found. You can use download=True to download it") - - self.data, self.targets = self._load_data() - - def _check_legacy_exist(self): - processed_folder_exists = os.path.exists(self.processed_folder) - if not processed_folder_exists: - return False - - return all( - check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) - ) - - def _load_legacy_data(self): - # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data - # directly. - data_file = self.training_file if self.train else self.test_file - return torch.load(os.path.join(self.processed_folder, data_file)) - - def _load_data(self): - image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" - data = read_image_file(os.path.join(self.raw_folder, image_file)) - - label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" - targets = read_label_file(os.path.join(self.raw_folder, label_file)) - - return data, targets - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """ - Args: - index (int): Index - - Returns: - tuple: (image, target) where target is index of the target class. 
- """ - img, target = self.data[index], int(self.targets[index]) - - # doing this so that it is consistent with all other datasets - # to return a PIL Image - img = Image.fromarray(img.numpy(), mode="L") - - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - - return img, target - - def __len__(self) -> int: - return len(self.data) - - @property - def raw_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, "raw") - - @property - def processed_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, "processed") - - @property - def class_to_idx(self) -> Dict[str, int]: - return {_class: i for i, _class in enumerate(self.classes)} - - def _check_exists(self) -> bool: - return all( - check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) - for url, _ in self.resources - ) - - def download(self) -> None: - """Download the MNIST data if it doesn't exist already.""" - - if self._check_exists(): - return - - os.makedirs(self.raw_folder, exist_ok=True) - - # download files - for filename, md5 in self.resources: - for mirror in self.mirrors: - url = f"{mirror}{filename}" - try: - print(f"Downloading {url}") - download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) - except URLError as error: - print(f"Failed to download (trying next):\n{error}") - continue - finally: - print() - break - else: - raise RuntimeError(f"Error downloading {filename}") - - def extra_repr(self) -> str: - split = "Train" if self.train is True else "Test" - return f"Split: {split}" - - -class FashionMNIST(MNIST): - """`Fashion-MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` - and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] - - resources = [ - ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), - ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), - ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), - ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), - ] - classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] - - -class KMNIST(MNIST): - """`Kuzushiji-MNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` - and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. - train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, - otherwise from ``t10k-images-idx3-ubyte``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. 
- transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] - - resources = [ - ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), - ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), - ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), - ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), - ] - classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] - - -class EMNIST(MNIST): - """`EMNIST `_ Dataset. - - Args: - root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte`` - and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist. - split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, - ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies - which one to use. - train (bool, optional): If True, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - """ - - url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" - md5 = "58c8d27c78d21e728a6bc7b3cc06412e" - splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") - # Merged Classes assumes Same structure for both uppercase and lowercase version - _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} - _all_classes = set(string.digits + string.ascii_letters) - classes_split_dict = { - "byclass": sorted(list(_all_classes)), - "bymerge": sorted(list(_all_classes - _merged_classes)), - "balanced": sorted(list(_all_classes - _merged_classes)), - "letters": ["N/A"] + list(string.ascii_lowercase), - "digits": list(string.digits), - "mnist": list(string.digits), - } - - def __init__(self, root: str, split: str, **kwargs: Any) -> None: - self.split = verify_str_arg(split, "split", self.splits) - self.training_file = self._training_file(split) - self.test_file = self._test_file(split) - super().__init__(root, **kwargs) - self.classes = self.classes_split_dict[self.split] - - @staticmethod - def _training_file(split) -> str: - return f"training_{split}.pt" - - @staticmethod - def _test_file(split) -> str: - return f"test_{split}.pt" - - @property - def _file_prefix(self) -> str: - return f"emnist-{self.split}-{'train' if self.train else 'test'}" - - @property - def images_file(self) -> str: - return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") - - @property - def labels_file(self) -> str: - return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") - - def _load_data(self): - return read_image_file(self.images_file), read_label_file(self.labels_file) - - def _check_exists(self) -> bool: - return all(check_integrity(file) for file in (self.images_file, self.labels_file)) - - def download(self) -> None: - """Download the EMNIST data if it doesn't exist already.""" - - if self._check_exists(): - 
return - - os.makedirs(self.raw_folder, exist_ok=True) - - download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) - gzip_folder = os.path.join(self.raw_folder, "gzip") - for gzip_file in os.listdir(gzip_folder): - if gzip_file.endswith(".gz"): - extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) - shutil.rmtree(gzip_folder) - - -class QMNIST(MNIST): - """`QMNIST `_ Dataset. - - Args: - root (string): Root directory of dataset whose ``raw`` - subdir contains binary files of the datasets. - what (string,optional): Can be 'train', 'test', 'test10k', - 'test50k', or 'nist' for respectively the mnist compatible - training set, the 60k qmnist testing set, the 10k qmnist - examples that match the mnist testing set, the 50k - remaining qmnist testing examples, or all the nist - digits. The default is to select 'train' or 'test' - according to the compatibility argument 'train'. - compat (bool,optional): A boolean that says whether the target - for each example is class number (for compatibility with - the MNIST dataloader) or a torch vector containing the - full qmnist information. Default=True. - download (bool, optional): If True, downloads the dataset from - the internet and puts it in root directory. If dataset is - already downloaded, it is not downloaded again. - transform (callable, optional): A function/transform that - takes in an PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform - that takes in the target and transforms it. - train (bool,optional,compatibility): When argument 'what' is - not specified, this boolean decides whether to load the - training set ot the testing set. Default: True. - """ - - subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} - resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] - "train": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", - "ed72d4157d28c017586c42bc6afe6370", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", - "0058f8dd561b90ffdd0f734c6a30e5e4", - ), - ], - "test": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", - "1394631089c404de565df7b7aeaf9412", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", - "5b5b05890a5e13444e108efe57b788aa", - ), - ], - "nist": [ - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", - "7f124b3b8ab81486c9d8c2749c17f834", - ), - ( - "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", - "5ed0e788978e45d4a8bd4b7caec3d79d", - ), - ], - } - classes = [ - "0 - zero", - "1 - one", - "2 - two", - "3 - three", - "4 - four", - "5 - five", - "6 - six", - "7 - seven", - "8 - eight", - "9 - nine", - ] - - def __init__( - self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any - ) -> None: - if what is None: - what = "train" if train else "test" - self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) - self.compat = compat - self.data_file = what + ".pt" - self.training_file = self.data_file - self.test_file = self.data_file - super().__init__(root, train, **kwargs) - - @property - def images_file(self) -> str: - (url, _), _ = 
self.resources[self.subsets[self.what]] - return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) - - @property - def labels_file(self) -> str: - _, (url, _) = self.resources[self.subsets[self.what]] - return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) - - def _check_exists(self) -> bool: - return all(check_integrity(file) for file in (self.images_file, self.labels_file)) - - def _load_data(self): - data = read_sn3_pascalvincent_tensor(self.images_file) - if data.dtype != torch.uint8: - raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}") - if data.ndimension() != 3: - raise ValueError("data should have 3 dimensions instead of {data.ndimension()}") - - targets = read_sn3_pascalvincent_tensor(self.labels_file).long() - if targets.ndimension() != 2: - raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}") - - if self.what == "test10k": - data = data[0:10000, :, :].clone() - targets = targets[0:10000, :].clone() - elif self.what == "test50k": - data = data[10000:, :, :].clone() - targets = targets[10000:, :].clone() - - return data, targets - - def download(self) -> None: - """Download the QMNIST data if it doesn't exist already. - Note that we only download what has been asked for (argument 'what'). - """ - if self._check_exists(): - return - - os.makedirs(self.raw_folder, exist_ok=True) - split = self.resources[self.subsets[self.what]] - - for url, md5 in split: - download_and_extract_archive(url, self.raw_folder, md5=md5) - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - # redefined to handle the compat flag - img, target = self.data[index], self.targets[index] - img = Image.fromarray(img.numpy(), mode="L") - if self.transform is not None: - img = self.transform(img) - if self.compat: - target = int(target[0]) - if self.target_transform is not None: - target = self.target_transform(target) - return img, target - - def extra_repr(self) -> str: - return f"Split: {self.what}" - - -def get_int(b: bytes) -> int: - return int(codecs.encode(b, "hex"), 16) - - -SN3_PASCALVINCENT_BITSMAP = { - 8: torch.uint8, - 9: torch.int8, - 11: torch.int16, - 12: torch.int32, - 13: torch.float32, - 14: torch.float64, -} - -TORCH_TYPE_BITS = { - torch.uint8: 8, - torch.int8: 8, - torch.int16: 16, - torch.int32: 32, - torch.float32: 32, - torch.float64: 64, -} - - -def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: - """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). - Argument may be a filename, compressed filename, or file object. - """ - # read - with open(path, "rb") as f: - data = f.read() - # parse - magic = get_int(data[0:4]) - nd = magic % 256 - ty = magic // 256 - assert 1 <= nd <= 3 - assert 8 <= ty <= 14 - torch_type = SN3_PASCALVINCENT_BITSMAP[ty] - s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - - num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8 - # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, - # we need to reverse the bytes before we can read them with torch.frombuffer(). 
- needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 - parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) - if needs_byte_reversal: - parsed = parsed.flip(0) - - assert parsed.shape[0] == np.prod(s) or not strict - return parsed.view(*s) - - -def read_label_file(path: str) -> torch.Tensor: - x = read_sn3_pascalvincent_tensor(path, strict=False) - if x.dtype != torch.uint8: - raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") - if x.ndimension() != 1: - raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}") - return x.long() - - -def read_image_file(path: str) -> torch.Tensor: - x = read_sn3_pascalvincent_tensor(path, strict=False) - if x.dtype != torch.uint8: - raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") - if x.ndimension() != 3: - raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}") - return x \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/utils.py b/python/jittor/compatibility/vision/datasets/utils.py deleted file mode 100644 index f9ae1a89..00000000 --- a/python/jittor/compatibility/vision/datasets/utils.py +++ /dev/null @@ -1,522 +0,0 @@ -import bz2 -import contextlib -import gzip -import hashlib -import itertools -import lzma -import os -import os.path -import pathlib -import re -import sys -import tarfile -import urllib -import urllib.error -import urllib.request -import warnings -import zipfile -from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar -from urllib.parse import urlparse - -import numpy as np -import requests -import torch -from tqdm import tqdm - -from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available - -USER_AGENT = "pytorch/vision" - - -def _save_response_content( - content: Iterator[bytes], - destination: str, - length: Optional[int] = None, -) -> None: - with open(destination, "wb") as fh, tqdm(total=length) as pbar: - for chunk in content: - # filter out keep-alive new chunks - if not chunk: - continue - - fh.write(chunk) - pbar.update(len(chunk)) - - -def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None: - with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: - _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) - - -def gen_bar_updater() -> Callable[[int, int, int], None]: - warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.") - pbar = tqdm(total=None) - - def bar_update(count, block_size, total_size): - if pbar.total is None and total_size: - pbar.total = total_size - progress_bytes = count * block_size - pbar.update(progress_bytes - pbar.n) - - return bar_update - - -def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: - # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are - # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without - # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. 
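read_sn3_pascalvincent_tensor() above parses the MNIST "idx" layout: a 4-byte big-endian magic whose low byte gives the number of dimensions and whose next byte selects the element type, then one 4-byte size per dimension, then the raw data. A numpy-only sketch of the same parser (dtype codes copied from the table above; the function name is illustrative):

import numpy as np

IDX_DTYPES = {8: np.uint8, 9: np.int8, 11: np.int16,
              12: np.int32, 13: np.float32, 14: np.float64}

def read_idx(path):
    with open(path, "rb") as f:
        data = f.read()
    magic = int.from_bytes(data[0:4], "big")
    nd, ty = magic % 256, magic // 256
    shape = [int.from_bytes(data[4 * (i + 1):4 * (i + 2)], "big") for i in range(nd)]
    dtype = np.dtype(IDX_DTYPES[ty]).newbyteorder(">")   # idx data is big-endian
    return np.frombuffer(data, dtype=dtype, offset=4 * (nd + 1)).reshape(shape)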
- if sys.version_info >= (3, 9): - md5 = hashlib.md5(usedforsecurity=False) - else: - md5 = hashlib.md5() - with open(fpath, "rb") as f: - for chunk in iter(lambda: f.read(chunk_size), b""): - md5.update(chunk) - return md5.hexdigest() - - -def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: - return md5 == calculate_md5(fpath, **kwargs) - - -def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: - if not os.path.isfile(fpath): - return False - if md5 is None: - return True - return check_md5(fpath, md5) - - -def _get_redirect_url(url: str, max_hops: int = 3) -> str: - initial_url = url - headers = {"Method": "HEAD", "User-Agent": USER_AGENT} - - for _ in range(max_hops + 1): - with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: - if response.url == url or response.url is None: - return url - - url = response.url - else: - raise RecursionError( - f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}." - ) - - -def _get_google_drive_file_id(url: str) -> Optional[str]: - parts = urlparse(url) - - if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: - return None - - match = re.match(r"/file/d/(?P[^/]*)", parts.path) - if match is None: - return None - - return match.group("id") - - -def download_url( - url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 -) -> None: - """Download a file from a url and place it in root. - - Args: - url (str): URL to download file from - root (str): Directory to place downloaded file in - filename (str, optional): Name to save the file under. If None, use the basename of the URL - md5 (str, optional): MD5 checksum of the download. If None, do not check - max_redirect_hops (int, optional): Maximum number of redirect hops allowed - """ - root = os.path.expanduser(root) - if not filename: - filename = os.path.basename(url) - fpath = os.path.join(root, filename) - - os.makedirs(root, exist_ok=True) - - # check if file is already present locally - if check_integrity(fpath, md5): - print("Using downloaded and verified file: " + fpath) - return - - if _is_remote_location_available(): - _download_file_from_remote_location(fpath, url) - else: - # expand redirect chain if needed - url = _get_redirect_url(url, max_hops=max_redirect_hops) - - # check if file is located on Google Drive - file_id = _get_google_drive_file_id(url) - if file_id is not None: - return download_file_from_google_drive(file_id, root, filename, md5) - - # download the file - try: - print("Downloading " + url + " to " + fpath) - _urlretrieve(url, fpath) - except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined] - if url[:5] == "https": - url = url.replace("https:", "http:") - print("Failed download. Trying https -> http instead. 
Downloading " + url + " to " + fpath) - _urlretrieve(url, fpath) - else: - raise e - - # check integrity of downloaded file - if not check_integrity(fpath, md5): - raise RuntimeError("File not found or corrupted.") - - -def list_dir(root: str, prefix: bool = False) -> List[str]: - """List all directories at a given root - - Args: - root (str): Path to directory whose folders need to be listed - prefix (bool, optional): If true, prepends the path to each result, otherwise - only returns the name of the directories found - """ - root = os.path.expanduser(root) - directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] - if prefix is True: - directories = [os.path.join(root, d) for d in directories] - return directories - - -def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: - """List all files ending with a suffix at a given root - - Args: - root (str): Path to directory whose folders need to be listed - suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). - It uses the Python "str.endswith" method and is passed directly - prefix (bool, optional): If true, prepends the path to each result, otherwise - only returns the name of the files found - """ - root = os.path.expanduser(root) - files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] - if prefix is True: - files = [os.path.join(root, d) for d in files] - return files - - -def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]: - content = response.iter_content(chunk_size) - first_chunk = None - # filter out keep-alive new chunks - while not first_chunk: - first_chunk = next(content) - content = itertools.chain([first_chunk], content) - - try: - match = re.search("Google Drive - (?P<api_response>.+?)", first_chunk.decode()) - api_response = match["api_response"] if match is not None else None - except UnicodeDecodeError: - api_response = None - return api_response, content - - -def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): - """Download a Google Drive file from and place it in root. - - Args: - file_id (str): id of file to be downloaded - root (str): Directory to place downloaded file in - filename (str, optional): Name to save the file under. If None, use the id of the file. - md5 (str, optional): MD5 checksum of the download. 
If None, do not check - """ - # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url - - root = os.path.expanduser(root) - if not filename: - filename = file_id - fpath = os.path.join(root, filename) - - os.makedirs(root, exist_ok=True) - - if check_integrity(fpath, md5): - print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}") - return - - url = "https://drive.google.com/uc" - params = dict(id=file_id, export="download") - with requests.Session() as session: - response = session.get(url, params=params, stream=True) - - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - token = value - break - else: - api_response, content = _extract_gdrive_api_response(response) - token = "t" if api_response == "Virus scan warning" else None - - if token is not None: - response = session.get(url, params=dict(params, confirm=token), stream=True) - api_response, content = _extract_gdrive_api_response(response) - - if api_response == "Quota exceeded": - raise RuntimeError( - f"The daily quota of the file {filename} is exceeded and it " - f"can't be downloaded. This is a limitation of Google Drive " - f"and can only be overcome by trying again later." - ) - - _save_response_content(content, fpath) - - # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text - if os.stat(fpath).st_size < 10 * 1024: - with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh: - text = fh.read() - # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 - if re.search(r"]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text): - warnings.warn( - f"We detected some HTML elements in the downloaded file. " - f"This most likely means that the download triggered an unhandled API response by GDrive. " - f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " - f"the response:\n\n{text}" - ) - - if md5 and not check_md5(fpath, md5): - raise RuntimeError( - f"The MD5 checksum of the download file {fpath} does not match the one on record." - f"Please delete the file and try again. " - f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." - ) - - -def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: - with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: - tar.extractall(to_path) - - -_ZIP_COMPRESSION_MAP: Dict[str, int] = { - ".bz2": zipfile.ZIP_BZIP2, - ".xz": zipfile.ZIP_LZMA, -} - - -def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: - with zipfile.ZipFile( - from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED - ) as zip: - zip.extractall(to_path) - - -_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { - ".tar": _extract_tar, - ".zip": _extract_zip, -} -_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { - ".bz2": bz2.open, - ".gz": gzip.open, - ".xz": lzma.open, -} -_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { - ".tbz": (".tar", ".bz2"), - ".tbz2": (".tar", ".bz2"), - ".tgz": (".tar", ".gz"), -} - - -def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: - """Detect the archive type and/or compression of a file. 
- - Args: - file (str): the filename - - Returns: - (tuple): tuple of suffix, archive type, and compression - - Raises: - RuntimeError: if file has no suffix or suffix is not supported - """ - suffixes = pathlib.Path(file).suffixes - if not suffixes: - raise RuntimeError( - f"File '{file}' has no suffixes that could be used to detect the archive type and compression." - ) - suffix = suffixes[-1] - - # check if the suffix is a known alias - if suffix in _FILE_TYPE_ALIASES: - return (suffix, *_FILE_TYPE_ALIASES[suffix]) - - # check if the suffix is an archive type - if suffix in _ARCHIVE_EXTRACTORS: - return suffix, suffix, None - - # check if the suffix is a compression - if suffix in _COMPRESSED_FILE_OPENERS: - # check for suffix hierarchy - if len(suffixes) > 1: - suffix2 = suffixes[-2] - - # check if the suffix2 is an archive type - if suffix2 in _ARCHIVE_EXTRACTORS: - return suffix2 + suffix, suffix2, suffix - - return suffix, None, suffix - - valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) - raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") - - -def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: - r"""Decompress a file. - - The compression is automatically detected from the file name. - - Args: - from_path (str): Path to the file to be decompressed. - to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. - remove_finished (bool): If ``True``, remove the file after the extraction. - - Returns: - (str): Path to the decompressed file. - """ - suffix, archive_type, compression = _detect_file_type(from_path) - if not compression: - raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") - - if to_path is None: - to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") - - # We don't need to check for a missing key here, since this was already done in _detect_file_type() - compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] - - with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: - wfh.write(rfh.read()) - - if remove_finished: - os.remove(from_path) - - return to_path - - -def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: - """Extract an archive. - - The archive type and a possible compression is automatically detected from the file name. If the file is compressed - but not an archive the call is dispatched to :func:`decompress`. - - Args: - from_path (str): Path to the file to be extracted. - to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is - used. - remove_finished (bool): If ``True``, remove the file after the extraction. - - Returns: - (str): Path to the directory the file was extracted to. 
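# --- Editor's sketch (not part of the patch): the suffix dispatch used by
# _detect_file_type/extract_archive above, re-created stand-alone with pathlib.
# Aliases such as ".tgz" expand to an (archive, compression) pair, a bare
# archive suffix maps straight to an extractor, and a compression suffix may
# stack on an inner archive suffix. Names below are illustrative only.
import pathlib

ALIASES = {".tbz": (".tar", ".bz2"), ".tbz2": (".tar", ".bz2"), ".tgz": (".tar", ".gz")}
ARCHIVES = {".tar", ".zip"}
COMPRESSIONS = {".bz2", ".gz", ".xz"}

def detect(file: str):
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(f"'{file}' has no suffix to detect the archive type from")
    suffix = suffixes[-1]
    if suffix in ALIASES:                      # e.g. ".tgz" -> (".tar", ".gz")
        return (suffix, *ALIASES[suffix])
    if suffix in ARCHIVES:                     # plain archive, no compression
        return suffix, suffix, None
    if suffix in COMPRESSIONS:
        if len(suffixes) > 1 and suffixes[-2] in ARCHIVES:
            return suffixes[-2] + suffix, suffixes[-2], suffix
        return suffix, None, suffix            # compressed single file -> decompress only
    raise RuntimeError(f"unknown compression or archive type: '{suffix}'")

print(detect("data.tar.gz"))   # ('.tar.gz', '.tar', '.gz')
print(detect("model.tgz"))     # ('.tgz', '.tar', '.gz')
print(detect("labels.xz"))     # ('.xz', None, '.xz')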
- """ - if to_path is None: - to_path = os.path.dirname(from_path) - - suffix, archive_type, compression = _detect_file_type(from_path) - if not archive_type: - return _decompress( - from_path, - os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), - remove_finished=remove_finished, - ) - - # We don't need to check for a missing key here, since this was already done in _detect_file_type() - extractor = _ARCHIVE_EXTRACTORS[archive_type] - - extractor(from_path, to_path, compression) - if remove_finished: - os.remove(from_path) - - return to_path - - -def download_and_extract_archive( - url: str, - download_root: str, - extract_root: Optional[str] = None, - filename: Optional[str] = None, - md5: Optional[str] = None, - remove_finished: bool = False, -) -> None: - download_root = os.path.expanduser(download_root) - if extract_root is None: - extract_root = download_root - if not filename: - filename = os.path.basename(url) - - download_url(url, download_root, filename, md5) - - archive = os.path.join(download_root, filename) - print(f"Extracting {archive} to {extract_root}") - extract_archive(archive, extract_root, remove_finished) - - -def iterable_to_str(iterable: Iterable) -> str: - return "'" + "', '".join([str(item) for item in iterable]) + "'" - - -T = TypeVar("T", str, bytes) - - -def verify_str_arg( - value: T, - arg: Optional[str] = None, - valid_values: Optional[Iterable[T]] = None, - custom_msg: Optional[str] = None, -) -> T: - if not isinstance(value, torch._six.string_classes): - if arg is None: - msg = "Expected type str, but got type {type}." - else: - msg = "Expected type str for argument {arg}, but got type {type}." - msg = msg.format(type=type(value), arg=arg) - raise ValueError(msg) - - if valid_values is None: - return value - - if value not in valid_values: - if custom_msg is not None: - msg = custom_msg - else: - msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." - msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) - raise ValueError(msg) - - return value - - -def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: - """Read file in .pfm format. Might contain either 1 or 3 channels of data. - - Args: - file_name (str): Path to the file. - slice_channels (int): Number of channels to slice out of the file. - Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. 
- """ - - with open(file_name, "rb") as f: - header = f.readline().rstrip() - if header not in [b"PF", b"Pf"]: - raise ValueError("Invalid PFM file") - - dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) - if not dim_match: - raise Exception("Malformed PFM header.") - w, h = (int(dim) for dim in dim_match.groups()) - - scale = float(f.readline().rstrip()) - if scale < 0: # little-endian - endian = "<" - scale = -scale - else: - endian = ">" # big-endian - - data = np.fromfile(f, dtype=endian + "f") - - pfm_channels = 3 if header == b"PF" else 1 - - data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) - data = np.flip(data, axis=1) # flip on h dimension - data = data[:slice_channels, :, :] - return data.astype(np.float32) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/vision.py b/python/jittor/compatibility/vision/datasets/vision.py deleted file mode 100644 index d71dc2a5..00000000 --- a/python/jittor/compatibility/vision/datasets/vision.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from typing import Any, Callable, List, Optional, Tuple - -import torch -import torch.utils.data as data - -from ..utils import _log_api_usage_once - - -class VisionDataset(data.Dataset): - """ - Base Class For making datasets which are compatible with torchvision. - It is necessary to override the ``__getitem__`` and ``__len__`` method. - Args: - root (string): Root directory of dataset. - transforms (callable, optional): A function/transforms that takes in - an image and a label and returns the transformed versions of both. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - .. note:: - :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. - """ - - _repr_indent = 4 - - def __init__( - self, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - self.root = root - - has_transforms = transforms is not None - has_separate_transform = transform is not None or target_transform is not None - if has_transforms and has_separate_transform: - raise ValueError("Only transforms or transform/target_transform can be passed as argument") - - # for backwards-compatibility - self.transform = transform - self.target_transform = target_transform - - if has_separate_transform: - transforms = StandardTransform(transform, target_transform) - self.transforms = transforms - - def __getitem__(self, index: int) -> Any: - """ - Args: - index (int): Index - Returns: - (Any): Sample and meta data, optionally transformed by the respective transforms. 
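# --- Editor's sketch (not part of the patch): the PFM layout parsed by
# _read_pfm above: a "PF" (3-channel) or "Pf" (1-channel) magic line, a
# "width height" line, a scale line whose sign encodes endianness
# (negative = little-endian), then raw float32 rows stored bottom-up
# (hence the flip on the height axis). Minimal numpy-only round trip,
# assuming a little-endian grayscale image.
import io
import numpy as np

def write_pfm_gray(buf, data):
    h, w = data.shape
    buf.write(b"Pf\n")
    buf.write(f"{w} {h}\n".encode())
    buf.write(b"-1.0\n")                                  # negative scale -> little-endian
    buf.write(np.flipud(data).astype("<f4").tobytes())    # rows written bottom-up

def read_pfm_gray(buf):
    assert buf.readline().rstrip() == b"Pf"
    w, h = (int(v) for v in buf.readline().split())
    endian = "<" if float(buf.readline()) < 0 else ">"
    data = np.frombuffer(buf.read(), dtype=endian + "f4").reshape(h, w)
    return np.flipud(data).copy()                         # undo bottom-up order

img = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = io.BytesIO()
write_pfm_gray(buf, img)
buf.seek(0)
assert np.allclose(read_pfm_gray(buf), img)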
- """ - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - head = "Dataset " + self.__class__.__name__ - body = [f"Number of datapoints: {self.__len__()}"] - if self.root is not None: - body.append(f"Root location: {self.root}") - body += self.extra_repr().splitlines() - if hasattr(self, "transforms") and self.transforms is not None: - body += [repr(self.transforms)] - lines = [head] + [" " * self._repr_indent + line for line in body] - return "\n".join(lines) - - def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: - lines = transform.__repr__().splitlines() - return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] - - def extra_repr(self) -> str: - return "" - - -class StandardTransform: - def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: - self.transform = transform - self.target_transform = target_transform - - def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: - if self.transform is not None: - input = self.transform(input) - if self.target_transform is not None: - target = self.target_transform(target) - return input, target - - def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: - lines = transform.__repr__().splitlines() - return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] - - def __repr__(self) -> str: - body = [self.__class__.__name__] - if self.transform is not None: - body += self._format_transform_repr(self.transform, "Transform: ") - if self.target_transform is not None: - body += self._format_transform_repr(self.target_transform, "Target transform: ") - - return "\n".join(body) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/transforms.py b/python/jittor/compatibility/vision/transforms.py deleted file mode 100644 index 416057c7..00000000 --- a/python/jittor/compatibility/vision/transforms.py +++ /dev/null @@ -1 +0,0 @@ -from jittor.transform import * \ No newline at end of file diff --git a/python/jittor/compatibility/vision/utils.py b/python/jittor/compatibility/vision/utils.py deleted file mode 100644 index 4be36c64..00000000 --- a/python/jittor/compatibility/vision/utils.py +++ /dev/null @@ -1,582 +0,0 @@ -import collections -import math -import pathlib -import warnings -from itertools import repeat -from types import FunctionType -from typing import Any, BinaryIO, List, Optional, Tuple, Union - -import numpy as np -import torch -from PIL import Image, ImageColor, ImageDraw, ImageFont - -__all__ = [ - "make_grid", - "save_image", - "draw_bounding_boxes", - "draw_segmentation_masks", - "draw_keypoints", - "flow_to_image", -] - - -@torch.no_grad() -def make_grid( - tensor: Union[torch.Tensor, List[torch.Tensor]], - nrow: int = 8, - padding: int = 2, - normalize: bool = False, - value_range: Optional[Tuple[int, int]] = None, - scale_each: bool = False, - pad_value: float = 0.0, - **kwargs, -) -> torch.Tensor: - """ - Make a grid of images. - - Args: - tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) - or a list of images all of the same size. - nrow (int, optional): Number of images displayed in each row of the grid. - The final grid size is ``(B / nrow, nrow)``. Default: ``8``. - padding (int, optional): amount of padding. Default: ``2``. 
- normalize (bool, optional): If True, shift the image to the range (0, 1), - by the min and max values specified by ``value_range``. Default: ``False``. - value_range (tuple, optional): tuple (min, max) where min and max are numbers, - then these numbers are used to normalize the image. By default, min and max - are computed from the tensor. - range (tuple. optional): - .. warning:: - This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` - instead. - scale_each (bool, optional): If ``True``, scale each image in the batch of - images separately rather than the (min, max) over all images. Default: ``False``. - pad_value (float, optional): Value for the padded pixels. Default: ``0``. - - Returns: - grid (Tensor): the tensor containing grid of images. - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(make_grid) - if not torch.is_tensor(tensor): - if isinstance(tensor, list): - for t in tensor: - if not torch.is_tensor(t): - raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") - else: - raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") - - if "range" in kwargs.keys(): - warnings.warn( - "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " - "Please use 'value_range' instead." - ) - value_range = kwargs["range"] - - # if list of tensors, convert to a 4D mini-batch Tensor - if isinstance(tensor, list): - tensor = torch.stack(tensor, dim=0) - - if tensor.dim() == 2: # single image H x W - tensor = tensor.unsqueeze(0) - if tensor.dim() == 3: # single image - if tensor.size(0) == 1: # if single-channel, convert to 3-channel - tensor = torch.cat((tensor, tensor, tensor), 0) - tensor = tensor.unsqueeze(0) - - if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images - tensor = torch.cat((tensor, tensor, tensor), 1) - - if normalize is True: - tensor = tensor.clone() # avoid modifying tensor in-place - if value_range is not None and not isinstance(value_range, tuple): - raise TypeError("value_range has to be a tuple (min, max) if specified. 
min and max are numbers") - - def norm_ip(img, low, high): - img.clamp_(min=low, max=high) - img.sub_(low).div_(max(high - low, 1e-5)) - - def norm_range(t, value_range): - if value_range is not None: - norm_ip(t, value_range[0], value_range[1]) - else: - norm_ip(t, float(t.min()), float(t.max())) - - if scale_each is True: - for t in tensor: # loop over mini-batch dimension - norm_range(t, value_range) - else: - norm_range(tensor, value_range) - - if not isinstance(tensor, torch.Tensor): - raise TypeError("tensor should be of type torch.Tensor") - if tensor.size(0) == 1: - return tensor.squeeze(0) - - # make the mini-batch of images into a grid - nmaps = tensor.size(0) - xmaps = min(nrow, nmaps) - ymaps = int(math.ceil(float(nmaps) / xmaps)) - height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) - num_channels = tensor.size(1) - grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) - k = 0 - for y in range(ymaps): - for x in range(xmaps): - if k >= nmaps: - break - # Tensor.copy_() is a valid method but seems to be missing from the stubs - # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ - grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] - 2, x * width + padding, width - padding - ).copy_(tensor[k]) - k = k + 1 - return grid - - -@torch.no_grad() -def save_image( - tensor: Union[torch.Tensor, List[torch.Tensor]], - fp: Union[str, pathlib.Path, BinaryIO], - format: Optional[str] = None, - **kwargs, -) -> None: - """ - Save a given Tensor into an image file. - - Args: - tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, - saves the tensor as a grid of images by calling ``make_grid``. - fp (string or file object): A filename or a file object - format(Optional): If omitted, the format to use is determined from the filename extension. - If a file object was used instead of a filename, this parameter should always be used. - **kwargs: Other arguments are documented in ``make_grid``. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(save_image) - grid = make_grid(tensor, **kwargs) - # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer - ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() - im = Image.fromarray(ndarr) - im.save(fp, format=format) - - -@torch.no_grad() -def draw_bounding_boxes( - image: torch.Tensor, - boxes: torch.Tensor, - labels: Optional[List[str]] = None, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - fill: Optional[bool] = False, - width: int = 1, - font: Optional[str] = None, - font_size: Optional[int] = None, -) -> torch.Tensor: - - """ - Draws bounding boxes on given image. - The values of the input image should be uint8 between 0 and 255. - If fill is True, Resulting Tensor should be saved as PNG image. - - Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8. - boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that - the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and - `0 <= ymin < ymax < H`. - labels (List[str]): List containing the labels of bounding boxes. - colors (color or list of colors, optional): List containing the colors - of the boxes or single color for all boxes. The color can be represented as - PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, random colors are generated for boxes. - fill (bool): If `True` fills the bounding box with specified color. - width (int): Width of bounding box. - font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may - also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, - `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. - font_size (int): The requested font size in points. - - Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_bounding_boxes) - if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size(0) not in {1, 3}: - raise ValueError("Only grayscale and RGB images are supported") - elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): - raise ValueError( - "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" - ) - - num_boxes = boxes.shape[0] - - if num_boxes == 0: - warnings.warn("boxes doesn't contain any box. No box was drawn") - return image - - if labels is None: - labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] - elif len(labels) != num_boxes: - raise ValueError( - f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." - ) - - if colors is None: - colors = _generate_color_palette(num_boxes) - elif isinstance(colors, list): - if len(colors) < num_boxes: - raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") - else: # colors specifies a single color for all boxes - colors = [colors] * num_boxes - - colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] - - if font is None: - if font_size is not None: - warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") - txt_font = ImageFont.load_default() - else: - txt_font = ImageFont.truetype(font=font, size=font_size or 10) - - # Handle Grayscale images - if image.size(0) == 1: - image = torch.tile(image, (3, 1, 1)) - - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) - img_boxes = boxes.to(torch.int64).tolist() - - if fill: - draw = ImageDraw.Draw(img_to_draw, "RGBA") - else: - draw = ImageDraw.Draw(img_to_draw) - - for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] - if fill: - fill_color = color + (100,) - draw.rectangle(bbox, width=width, outline=color, fill=fill_color) - else: - draw.rectangle(bbox, width=width, outline=color) - - if label is not None: - margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) - - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) - - -@torch.no_grad() -def draw_segmentation_masks( - image: torch.Tensor, - masks: torch.Tensor, - alpha: float = 0.8, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, -) -> torch.Tensor: - - """ - Draws segmentation masks on given RGB image. 
- The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. - alpha (float): Float number between 0 and 1 denoting the transparency of the masks. - 0 means full transparency, 1 means no transparency. - colors (color or list of colors, optional): List containing the colors - of the masks or single color for all masks. The color can be represented as - PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, random colors are generated for each mask. - - Returns: - img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_segmentation_masks) - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. Other Image formats are not supported") - if masks.ndim == 2: - masks = masks[None, :, :] - if masks.ndim != 3: - raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") - if masks.dtype != torch.bool: - raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") - if masks.shape[-2:] != image.shape[-2:]: - raise ValueError("The image and the masks must have the same height and width") - - num_masks = masks.size()[0] - if colors is not None and num_masks > len(colors): - raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") - - if num_masks == 0: - warnings.warn("masks doesn't contain any mask. No mask was drawn") - return image - - if colors is None: - colors = _generate_color_palette(num_masks) - - if not isinstance(colors, list): - colors = [colors] - if not isinstance(colors[0], (tuple, str)): - raise ValueError("colors must be a tuple or a string, or a list thereof") - if isinstance(colors[0], tuple) and len(colors[0]) != 3: - raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") - - out_dtype = torch.uint8 - - colors_ = [] - for color in colors: - if isinstance(color, str): - color = ImageColor.getrgb(color) - colors_.append(torch.tensor(color, dtype=out_dtype)) - - img_to_draw = image.detach().clone() - # TODO: There might be a way to vectorize this - for mask, color in zip(masks, colors_): - img_to_draw[:, mask] = color[:, None] - - out = image * (1 - alpha) + img_to_draw * alpha - return out.to(out_dtype) - - -@torch.no_grad() -def draw_keypoints( - image: torch.Tensor, - keypoints: torch.Tensor, - connectivity: Optional[List[Tuple[int, int]]] = None, - colors: Optional[Union[str, Tuple[int, int, int]]] = None, - radius: int = 2, - width: int = 3, -) -> torch.Tensor: - - """ - Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, - in the format [x, y]. - connectivity (List[Tuple[int, int]]]): A List of tuple where, - each tuple contains pair of keypoints to be connected. - colors (str, Tuple): The color can be represented as - PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - radius (int): Integer denoting radius of keypoint. - width (int): Integer denoting width of line connecting keypoints. - - Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. - """ - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(draw_keypoints) - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. Other Image formats are not supported") - - if keypoints.ndim != 3: - raise ValueError("keypoints must be of shape (num_instances, K, 2)") - - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) - draw = ImageDraw.Draw(img_to_draw) - img_kpts = keypoints.to(torch.int64).tolist() - - for kpt_id, kpt_inst in enumerate(img_kpts): - for inst_id, kpt in enumerate(kpt_inst): - x1 = kpt[0] - radius - x2 = kpt[0] + radius - y1 = kpt[1] - radius - y2 = kpt[1] + radius - draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) - - if connectivity: - for connection in connectivity: - start_pt_x = kpt_inst[connection[0]][0] - start_pt_y = kpt_inst[connection[0]][1] - - end_pt_x = kpt_inst[connection[1]][0] - end_pt_y = kpt_inst[connection[1]][1] - - draw.line( - ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), - width=width, - ) - - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) - - -# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization -@torch.no_grad() -def flow_to_image(flow: torch.Tensor) -> torch.Tensor: - - """ - Converts a flow to an RGB image. - - Args: - flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. - - Returns: - img (Tensor): Image Tensor of dtype uint8 where each color corresponds - to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. - """ - - if flow.dtype != torch.float: - raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") - - orig_shape = flow.shape - if flow.ndim == 3: - flow = flow[None] # Add batch dim - - if flow.ndim != 4 or flow.shape[1] != 2: - raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") - - max_norm = torch.sum(flow**2, dim=1).sqrt().max() - epsilon = torch.finfo((flow).dtype).eps - normalized_flow = flow / (max_norm + epsilon) - img = _normalized_flow_to_image(normalized_flow) - - if len(orig_shape) == 3: - img = img[0] # Remove batch dim - return img - - -@torch.no_grad() -def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: - - """ - Converts a batch of normalized flow to an RGB image. - - Args: - normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) - Returns: - img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
- """ - - N, _, H, W = normalized_flow.shape - device = normalized_flow.device - flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) - colorwheel = _make_colorwheel().to(device) # shape [55x3] - num_cols = colorwheel.shape[0] - norm = torch.sum(normalized_flow**2, dim=1).sqrt() - a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi - fk = (a + 1) / 2 * (num_cols - 1) - k0 = torch.floor(fk).to(torch.long) - k1 = k0 + 1 - k1[k1 == num_cols] = 0 - f = fk - k0 - - for c in range(colorwheel.shape[1]): - tmp = colorwheel[:, c] - col0 = tmp[k0] / 255.0 - col1 = tmp[k1] / 255.0 - col = (1 - f) * col0 + f * col1 - col = 1 - norm * (1 - col) - flow_image[:, c, :, :] = torch.floor(255 * col) - return flow_image - - -def _make_colorwheel() -> torch.Tensor: - """ - Generates a color wheel for optical flow visualization as presented in: - Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) - URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. - - Returns: - colorwheel (Tensor[55, 3]): Colorwheel Tensor. - """ - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - colorwheel = torch.zeros((ncols, 3)) - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) - col = col + RY - # YG - colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) - colorwheel[col : col + YG, 1] = 255 - col = col + YG - # GC - colorwheel[col : col + GC, 1] = 255 - colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) - col = col + GC - # CB - colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) - colorwheel[col : col + CB, 2] = 255 - col = col + CB - # BM - colorwheel[col : col + BM, 2] = 255 - colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) - col = col + BM - # MR - colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) - colorwheel[col : col + MR, 0] = 255 - return colorwheel - - -def _generate_color_palette(num_objects: int): - palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) - return [tuple((i * palette) % 255) for i in range(num_objects)] - - -def _log_api_usage_once(obj: Any) -> None: - - """ - Logs API usage(module and name) within an organization. - In a large ecosystem, it's often useful to track the PyTorch and - TorchVision APIs usage. This API provides the similar functionality to the - logging module in the Python stdlib. It can be used for debugging purpose - to log which methods are used and by default it is inactive, unless the user - manually subscribes a logger via the `SetAPIUsageLogger method `_. - Please note it is triggered only once for the same API call within a process. - It does not collect any data from open-source users since it is no-op by default. - For more information, please refer to - * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; - * Logging policy: https://github.com/pytorch/vision/issues/5052; - - Args: - obj (class instance or method): an object to extract info from. - """ - pass - - -def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: - """ - Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. - Otherwise we will make a tuple of length n, all with value of x. 
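# --- Editor's sketch (not part of the patch): the per-pixel mapping used by
# _normalized_flow_to_image above, written out with numpy for one flow vector.
# The angle of (u, v) selects a (fractionally interpolated) row of the
# colorwheel and the magnitude blends the colour toward white; the wheel here
# is a dummy 6-entry stand-in for the 55-entry wheel built by _make_colorwheel.
import numpy as np

def flow_pixel_to_rgb(u, v, colorwheel):
    num_cols = colorwheel.shape[0]
    norm = np.sqrt(u * u + v * v)              # assumed <= 1 after normalisation
    a = np.arctan2(-v, -u) / np.pi             # angle in (-1, 1]
    fk = (a + 1) / 2 * (num_cols - 1)          # fractional wheel index
    k0 = int(np.floor(fk))
    k1 = (k0 + 1) % num_cols                   # wrap around, as in k1[k1 == num_cols] = 0
    f = fk - k0
    col = (1 - f) * colorwheel[k0] / 255.0 + f * colorwheel[k1] / 255.0
    col = 1 - norm * (1 - col)                 # small flow -> washed out toward white
    return np.floor(255 * col).astype(np.uint8)

wheel = np.array([[255, 0, 0], [255, 255, 0], [0, 255, 0],
                  [0, 255, 255], [0, 0, 255], [255, 0, 255]], dtype=np.float64)
print(flow_pixel_to_rgb(0.5, 0.0, wheel))      # -> [255 127 255] for this dummy wheel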
- reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 - - Args: - x (Any): input value - n (int): length of the resulting tuple - """ - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) \ No newline at end of file diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 4df000b4..f6c93b54 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -629,6 +629,7 @@ def setup_mpi(): mpi_ops = None mpi = None has_mpi = False + if not use_mpi: return mpicc_path = env_or_try_find('mpicc_path', 'mpicc') if mpicc_path == "": # LOG.i("mpicc not found, distribution disabled.") @@ -711,4 +712,4 @@ def inner(self, *args, **kw): # install backend extern library for mod in jit_utils.backends: if mod.install_extern(): - break \ No newline at end of file + break diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index c6be01bb..a5dcd136 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -1002,6 +1002,8 @@ def check_debug_flags(): r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc") if r==0: s = sorted(list(set(s.strip().split()))) + if len(s)==0: + LOG.e("No GPU Device Found!") cu += "_sm_" + "_".join(s) if "cuda_arch" not in os.environ: os.environ["cuda_arch"] = " ".join(cu) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 31dd72de..c4026eee 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -15,6 +15,8 @@ def argmax_pool(x, size, stride, padding=0): + if stride<=0: + raise RuntimeError(f"stride must be > 0, but got {stride}") return pool.pool(x, size, 'maximum', padding, stride) def concat(arr, dim): @@ -241,9 +243,17 @@ def concat(arr, dim=0): if len(arr) == 0: raise ValueError("need at least one array to concat") total_dim = 0 - if dim < 0: dim += len(arr[0].shape) + base_dim = len(arr[0].shape) + if dim < 0: dim += base_dim + if dim < 0 or dim >= base_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})") dtypes = [] for a in arr: + if len(a.shape) != base_dim: + raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}") + for i in range(base_dim): + if i != dim and a.shape[i] != arr[0].shape[i]: + raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.") total_dim += a.shape[dim] dtypes.append(str(a.dtype)) cdim = 0 diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py index 7aa1d883..f0945f94 100644 --- a/python/jittor/dataset/mnist.py +++ b/python/jittor/dataset/mnist.py @@ -26,7 +26,7 @@ class MNIST(Dataset): [in] data_root(str): your data root. [in] train(bool): choose model train or val. - [in] download(bool): Download data automatically if download is Ture. + [in] download(bool): Download data automatically if download is True. [in] batch_size(int): Data batch size. [in] shuffle(bool): Shuffle data if true. [in] transform(jittor.transform): transform data. @@ -106,7 +106,7 @@ class EMNIST(Dataset): [in] data_root(str): your data root. [in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'. [in] train(bool): choose model train or val. - [in] download(bool): Download data automatically if download is Ture. + [in] download(bool): Download data automatically if download is True. 
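# --- Editor's sketch (not part of the patch): the shape checks added to
# python/jittor/contrib.py's concat above make mismatched inputs fail fast
# with a descriptive error instead of producing a malformed result. Assuming
# a jittor build that contains the patched contrib.concat:
import jittor as jt
from jittor import contrib

a = jt.ones((2, 3))
b = jt.ones((2, 4))
print(contrib.concat([a, b], dim=1).shape)     # sizes differ only on dim 1 -> [2, 7]
try:
    contrib.concat([a, b], dim=0)              # dim 1 sizes differ -> now rejected
except RuntimeError as e:
    print("rejected:", e)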
[in] batch_size(int): Data batch size. [in] shuffle(bool): Shuffle data if true. [in] transform(jittor.transform): transform data. diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index fea9cfce..22e366eb 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -1,6 +1,6 @@ # *************************************************************** -# Copyright (c) 2023 Jittor. All Rights Reserved. -# Maintainers: Dun Liang . +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** @@ -10,6 +10,7 @@ import ctypes import glob import jittor.compiler as compiler +import jittor as jt has_acl = 0 cc_flags = "" @@ -34,16 +35,18 @@ # export DUMP_GRAPH_LEVEL=1 # build pytorch-npu -# bash ./ci/build.sh -# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall +# bash ./ci/build.sh +# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall # pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark # python3 ./mm_bench_pt_npu.py + def install(): import jittor.compiler as compiler global has_acl, cc_flags acl_compiler_home = os.path.dirname(__file__) - cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True)) + cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc", + recursive=True)) cc_files2 = [] for name in cc_files: if "acl_op_exec" in name: @@ -51,21 +54,24 @@ def install(): else: cc_files2.append(name) cc_files = cc_files2 + ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') cc_flags += f" -DHAS_CUDA -DIS_ACL \ - -I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \ - -L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \ + -I{ascend_toolkit_home}/include/ \ + -L{ascend_toolkit_home}/lib64/ \ -I{acl_compiler_home} -lascendcl -lacl_op_compiler " + ctypes.CDLL("libascendcl.so", dlopen_flags) - ''' + f''' -ltikc_runtime - -I/usr/local/Ascend/driver/include \ - -L/usr/local/Ascend/compiler/lib64 \ - -L/usr/local/Ascend/runtime/lib64 \ + -I/usr/local/Ascend/driver/include/ \ + -L{ascend_toolkit_home}/compiler/lib64/ \ + -L{ascend_toolkit_home}/runtime/lib64/ \ ''' jittor_utils.LOG.i("ACL detected") global mod - mod = jittor_utils.compile_module(''' + mod = jittor_utils.compile_module( + ''' #include "common.h" namespace jittor { // @pyjt(process) @@ -98,9 +104,10 @@ def check(): if not has_acl: return False compiler.cc_flags += cc_flags compiler.nvcc_path = tikcc_path - compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","") + compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "") return True + def post_process(): if has_acl: from jittor import pool @@ -108,5 +115,1240 @@ def post_process(): import jittor as jt jt.flags.use_cuda_host_allocator = 1 jt.flags.use_parallel_op_compiler = 0 - jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type - mod.init_acl_ops() \ No newline at end of file + jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type + mod.init_acl_ops() + + +def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, + attr: dict): + nchw_op = ['MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 
'AvgPoolV2'] + attr_op = [ + 'MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2', + 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2' + ] + + input_code = '' + for i in range(len(inputs)): + if name in nchw_op: + input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" + else: + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(output_dtypes)): + if name in nchw_op: + output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" + else: + output_code += f"op.add(out{i}, false);\n" + + # add attr to op + attr_code = '' + if name in attr_op: + for k, v in attr.items(): + if isinstance(v, bool): + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + elif k == 'divisor_override_value': + attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" + else: + v = str(v).replace('[', '{').replace(']', '}') + attr_code += f"op.set_attr(\"{k}\", vector{v});\n" + else: + for k, v in attr.items(): + if isinstance(v, bool): + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + else: + attr_code += f"op.set_attr(\"{k}\", int({v}));\n" + + #print("input_code",input_code) + #print("attr_code",attr_code) + import jittor as jt + return jt.code(output_shapes, + output_dtypes, + inputs, + cuda_header=""" + #include + #include + #include + #include + + namespace jittor { + + void printDeviceData(const vector& output_desc, const vector& output_data, const string& name = "", bool input=true) { + LOGir << "name: " << name; + if(input) + LOGir << "is input"; + else + LOGir << "is ouput"; + for (size_t i = 0; i < output_desc.size(); ++i) { + void* base_addr = aclGetDataBufferAddr(output_data[i]); + LOGir << "addr of data[" << i << "] :" << base_addr; + size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); + size_t total_size = 1; + std::vector dims(num_dims); + + std::cout << "shape of data: "; + for (size_t j = 0; j < num_dims; ++j) { + aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); + total_size *= dims[j]; + std::cout << dims[j] << ", "; + } + int evey_batch_size = total_size/dims[0]; + std::cout << std::endl; + + // for(int i= 0; i < dims[0]; i++) { + // evey_batch_size = 16; + // std::vector host_buffer(evey_batch_size); + // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); + // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); + // std::cout << "batch[" << i << "]:"; + // for (size_t k = 0; k < evey_batch_size; ++k) { + // std::cout << host_buffer[k] << ", "; + // } + // std::cout << std::endl; + // } + } + } + + struct AclOpRunner { + string name; + vector input_desc; + vector output_desc; + vector input_data; + vector output_data; + aclopAttr *attr; + vector> input_host; + vector> input_host_32; + + AclOpRunner(const string& name) : name(name) { + attr = aclopCreateAttr(); + } + + ~AclOpRunner() { + for (auto i : input_desc) aclDestroyTensorDesc(i); + for (auto i : output_desc) aclDestroyTensorDesc(i); + for (auto i : input_data) aclDestroyDataBuffer(i); + for (auto i : output_data) aclDestroyDataBuffer(i); + aclopDestroyAttr(attr); + } + + aclDataType get_dtype(NanoString s) { + if (s == ns_float32) return ACL_FLOAT; + if (s == ns_float16) return 
ACL_FLOAT16; + if (s == ns_int64) return ACL_INT64; + if (s == ns_int32) return ACL_INT32; + if (s == ns_int8) return ACL_INT8; + if (s == ns_int16) return ACL_INT16; + if (s == ns_uint8) return ACL_UINT8; + if (s == ns_uint16) return ACL_UINT16; + if (s == ns_uint32) return ACL_UINT32; + if (s == ns_bool) return ACL_BOOL; + LOGf << "Not supported dtype: " << s; + return ACL_FLOAT; + } + + void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) { + int64_t shape[v->shape.size()]; + for (int i=0; ishape.size(); i++) shape[i] = v->shape[i]; + + auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format); + aclSetTensorFormat(desc, (aclFormat)format); + aclSetTensorShape(desc, v->shape.size(), &shape[0]); + LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format; + auto data = aclCreateDataBuffer(v->mem_ptr, v->size); + LOGv << "aclCreateDataBuffer" << v->mem_ptr << v->size; + ASSERT(desc && data); + if (is_input) { + input_desc.push_back(desc); + input_data.push_back(data); + } else { + output_desc.push_back(desc); + output_data.push_back(data); + } + } + + void add_input_host(vector v, int dtype=ACL_UINT64) { + int64_t shape[1]; + shape[0] = v.size(); + auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND); + aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); + LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; + auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64)); + ASSERT(desc && data); + LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64); + input_desc.push_back(desc); + input_data.push_back(data); + input_host.emplace_back(move(v)); + LOGv << "move" << input_host.back().data(); + } + + void add_input_host_scalar(vector v, int dtype=ACL_UINT32) { + int64_t shape[1]; + shape[0] = v.size(); + auto x = (int*)&v[0]; + x[0] = (int32)v[0]; + auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND); + aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); + LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; + auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32)); + ASSERT(desc && data); + LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32); + input_desc.push_back(desc); + input_data.push_back(data); + input_host.emplace_back(move(v)); + } + + void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) { + vector v(nv.size()); + for (int i=0; i v(nv.size()); + for (int i=0; i value) { + // LOGir << "string vector" << "set_attr" << key << value; + CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); + } + void set_attr(const string& key, string value) { + // LOGir << "string string" << "set_attr" << key << value; + CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); + } + void set_attr(const char* key, const char* value) { + // LOGir << "char" << "set_attr" << key << value; + CHECK(aclopSetAttrString(attr, key, value)==0); + } + + void run() { + // printDeviceData(input_desc, input_data, name); + + LOGv << "run" << name << input_desc.size() << output_desc.size(); + if (!PyGILState_Check()) { + ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); + } else { + int ret; + Py_BEGIN_ALLOW_THREADS + ret = aclopCompileAndExecuteV2(name.c_str(), 
input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); + Py_END_ALLOW_THREADS + if (ret != 0) + LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; + } + ASSERT(0==aclrtSynchronizeDevice()); + + // printDeviceData(output_desc, output_data, name, false); + } + }; + + } + """, + cuda_src=f""" + // aclop + AclOpRunner op("{name}"); + {input_code} + {output_code} + {attr_code} + op.run();""") + + +def change_function(): + import jittor as jt + from jittor import Function + + class IndexACL(Function): + + def __init__(self): + super(IndexACL, self).__init__() + + def execute(self, inshape: list, dim, dtype="int32"): + # zeros a tensor, shape is inshape, dtype is dtype + dim_input = dim + if dim == None: + dim = [i for i in range(len(inshape))] + elif type(dim) == int: + dim = [dim] + results = [] + for d in dim: + max_len = inshape[d] + tmp = jt.zeros(max_len, dtype=dtype) + result = acl_cmd( + "Range", [jt.Var(0), jt.Var(max_len), + jt.Var(1)], + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr={})[0] + broadcast_dim = [] + for i in range(len(inshape)): + if i != d: + broadcast_dim.append(i) + result = jt.broadcast(result, + shape=inshape, + dims=broadcast_dim) + results.append(result) + if len(results) != 1 or dim_input == None: + return tuple(results) + else: + return results[0] + + def grad(self, grad_output): + return grad_output + + class PoolACL(Function): + + def get_paddings(self): + pad_top = self.padding[0] + pad_left = self.padding[1] + H = self.input.shape[-2] + W = self.input.shape[-1] + + totalH = H + 2 * self.padding[0] - self.kernel_size[0] + totalW = W + 2 * self.padding[1] - self.kernel_size[1] + + kH = (totalH + self.stride[0] - + 1) // self.stride[0] + 1 if self.attr[ + 'ceil_mode'] else totalH // self.stride[0] + 1 + kW = (totalW + self.stride[1] - + 1) // self.stride[1] + 1 if self.attr[ + 'ceil_mode'] else totalW // self.stride[1] + 1 + + if self.attr['ceil_mode']: + if (kH - 1) * self.stride[0] >= H + self.padding[0]: + kH -= 1 + need_pad_h = (kH - + 1) * self.stride[0] + self.kernel_size[0] - H + pad_top = need_pad_h - self.padding[0] + if (kW - 1) * self.stride[1] >= W + self.padding[1]: + kW -= 1 + need_pad_w = (kW - + 1) * self.stride[1] + self.kernel_size[1] - W + pad_left = need_pad_w - self.padding[1] + + pads = [self.padding[0], pad_top, self.padding[1], pad_left] + return pads + + def __init__(self, + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + count_include_pad=True, + op='maximum'): + super(PoolACL, self).__init__() + # set attr + self.kernel_size = kernel_size if isinstance( + kernel_size, tuple) else (kernel_size, kernel_size) + stride = stride if stride else kernel_size + self.stride = stride if isinstance(stride, tuple) else (stride, + stride) + self.padding = padding if isinstance(padding, tuple) else (padding, + padding) + dilation = dilation if dilation else 1 + self.dilation = dilation if isinstance( + dilation, tuple) else (dilation, dilation) + attr = {} + + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + self.op = op + + if op == 'mean': + attr['exclusive'] = not count_include_pad + attr['global_pooling'] = False + attr['divisor_override_value'] = 0 + attr['ksize'] = [ + 1, 1, self.kernel_size[0], self.kernel_size[1] + ] + attr['strides'] = [1, 1, self.stride[0], self.stride[1]] + attr['ceil_mode'] = ceil_mode + 
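# --- Editor's sketch (not part of the patch): acl_cmd above pastes the Python
# `attr` dict into cuda_src as AclOpRunner::set_attr calls. A stand-alone
# re-creation of that string generation, for illustration only; `list_style`
# stands in for "the op name is in the nchw/attr op lists" branch above.
def gen_attr_code(attr, list_style=False):
    lines = []
    for k, v in attr.items():
        if isinstance(v, bool):                               # bools -> (1, 1) / (1, 0)
            lines.append(f'op.set_attr("{k}", 1, {1 if v else 0});')
        elif isinstance(v, str):
            lines.append(f'op.set_attr("{k}", "{v}");')
        elif list_style and k == "divisor_override_value":
            lines.append(f'op.set_attr("{k}", int64_t({v}), 0);')
        elif list_style:                                      # lists -> C++ brace init
            vv = str(v).replace("[", "{").replace("]", "}")
            lines.append(f'op.set_attr("{k}", vector{vv});')
        else:
            lines.append(f'op.set_attr("{k}", int({v}));')
    return "\n".join(lines)

print(gen_attr_code({"ksize": [1, 1, 2, 2], "ceil_mode": False,
                     "padding_mode": "CALCULATED"}, list_style=True))
# op.set_attr("ksize", vector{1, 1, 2, 2});
# op.set_attr("ceil_mode", 1, 0);
# op.set_attr("padding_mode", "CALCULATED");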
attr['padding_mode'] = 'CALCULATED' + attr['data_format'] = 'NCHW' + elif op == 'maximum': + attr['ksize'] = [ + 1, self.kernel_size[0], self.kernel_size[1], 1 + ] + attr['strides'] = [1, self.stride[0], self.stride[1], 1] + attr['pads'] = [1, self.padding[0], self.padding[1], 1] + attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] + # attr['ceil_mode'] = ceil_mode + + self.attr = attr + + def execute(self, input): + + # create input + input_shape = input.shape + input_dtype = input.dtype + + self.input = input + # create output + output_shape = [ + input_shape[0], input_shape[1], + (input_shape[2] + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, + (input_shape[3] + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + ] + output_dtype = input_dtype + + if self.op == 'mean': + self.attr['pads'] = self.get_paddings() + result = acl_cmd("AvgPoolV2", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + elif self.op == 'maximum': + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + else: + raise ValueError('no this type pool') + + if self.op == 'maximum': + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + if self.op == 'maximum': + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + elif self.op == 'mean': + grad_input = acl_cmd("AvgPoolV2", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + else: + grad_input = None + return grad_input + + class BmmACL(Function): + + def __init__(self, adj_x1=False, adj_x2=False): + super(BmmACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + grad_x1 = acl_cmd( + "BatchMatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class MatmulACL(Function): + + def __init__(self, adj_x1=False, adj_x2=False): + super(MatmulACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + if len(x1.shape) > 2 or len(x2.shape) > 2: + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + else: + result = acl_cmd("MatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + if len(x1.shape) > 2 or len(x2.shape) > 2: + grad_x1 = acl_cmd( + "BatchMatMul", + [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + 
grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + else: + grad_x1 = acl_cmd( + "MatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "MatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class GetItem(Function): + + def __init__(self): + super(GetItem, self).__init__() + self.type_ = 'index' + + def stride(self, x, dim): + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, return_x=None): + if isinstance(slices, jt.Var) or isinstance(slices, tuple): + if isinstance(slices, jt.Var): + slices = (slices, ) + if isinstance(slices[0], jt.Var): + slices_len = len(slices) + masks = jt.ones(slices_len, dtype=jt.int64) + output = slices[0].shape + output += x.shape[slices_len:] + input_ = [x, masks, jt.Var(list(output)).int64()] + for i in range(slices_len): + input_.append(slices[i].int32()) + result = acl_cmd("Index", + input_, + output_dtypes=[x.dtype], + output_shapes=[output], + attr={})[0] + self.shape = x.shape + self.sizes = list(output) + self.type_ = 'index' + self.slices = slices + # self.strides + return result + + # use AsStrided operator to implement the getitem function + # get the shape and stride of the input tensor + x_dim = len(x.shape) + # int type + if not isinstance(slices, tuple): + slices = (slices, ) + + if len(slices) < x_dim: + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + + self.inputs = [x, slices] + + sizes = [] + strides = [] + offset = 0 + + for dim, s in enumerate(slices): + if isinstance(s, int): + if s < 0: # Handle negative indices. 
+ s += x.shape[dim] + offset += s * self.stride(x, dim) + elif isinstance(s, slice): + # Unpack the slice + start, stop, step = s.indices(x.size(dim)) + size = (stop - start - 1) // step + 1 + stride = self.stride(x, dim) * step + offset += start * self.stride(x, dim) + sizes.append(size) + strides.append(stride) + else: + raise ValueError("Invalid slice type") + + if not sizes: + sizes = [1] + strides = [0] + # AsStrided same with as_strided of pytorch + self.sizes = sizes + self.strides = strides + self.offset = offset + self.shape = x.shape + self.type_ = 'as_strided' + result = acl_cmd( + "AsStrided", + [x, jt.Var(sizes), + jt.Var(strides), + jt.Var(offset)], + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr={})[0] + return result + + def grad(self, grad_output): + if self.type_ == 'as_strided': + result = jt.zeros(self.shape, dtype=grad_output.dtype) + sizes = list(grad_output.shape) + strides = [ + self.stride(grad_output, dim) + for dim in range(len(grad_output.shape)) + ] + result = acl_cmd("ViewCopy", [ + result, + jt.Var(self.sizes), + jt.Var(self.strides), + jt.Var(self.offset), grad_output, + jt.Var(sizes), + jt.Var(strides), + jt.Var(0) + ], + output_dtypes=[result.dtype], + output_shapes=[result.shape], + attr={})[0] + elif self.type_ == 'index': + #TODO: use IndexPutV2 to implement the grad function + assert len(self.slices) == 1 + index = self.slices[0] + input = jt.zeros(self.shape, dtype=grad_output.dtype) + input_flatten = input.reshape(input.shape[0], -1) + index_flatten = index.reshape(-1).unsqueeze(-1).repeat( + 1, input_flatten.shape[1]) + grad_output_flatten = grad_output.reshape(index.numel(), -1) + result = acl_cmd( + "ScatterElements", + [input_flatten, index_flatten, grad_output_flatten], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={ + 'axis': 0, + 'reduction': 'add' + })[0] + result = result.reshape(self.shape) + # result = jt.zeros(self.shape, dtype=grad_output.dtype) + # # masks = jt.ones(len(self.slices), dtype=jt.int64) + # masks = jt.array([1,1], dtype=jt.int64) + # expand_masks = jt.array([1,1], dtype=jt.int64) + # inputs_ = [result,grad_output,masks,expand_masks] + # slices_len = len(self.slices) + # for i in range(slices_len): + # inputs_.append(self.slices[i].int64()) + # # breakpoint() + # jt.sync_all(True) + # print(inputs_) + # result_ = acl_cmd("IndexPutV2", inputs_, + # output_dtypes=[result.dtype], + # output_shapes=[result.shape], + # attr={"accumulate":True})[0] + # result = result_ + else: + raise ValueError("Invalid slice type") + result.sync() + return result, None + + class ConcatACL(Function): + + def __init__(self): + super(ConcatACL, self).__init__() + + def execute(self, input_tensors, dim=0): + self.input = input_tensors + for i in range(len(input_tensors)): + if input_tensors[i].dtype != input_tensors[0].dtype: + raise ValueError( + "All input tensors must have the same dtype") + if input_tensors[i].shape[:dim] != input_tensors[ + 0].shape[:dim] or input_tensors[i].shape[ + dim + 1:] != input_tensors[0].shape[dim + 1:]: + raise ValueError( + "All input tensors must have the same shape") + result = acl_cmd( + "ConcatD", + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[ + jt.empty(self.calculate_output_shape(input_tensors, + dim)).shape + ], + attr={ + "N": len(input_tensors), + "concat_dim": dim + })[0] + return result + + def grad(self, grad_output): + grad_inputs = self.split_grad(grad_output, self.input, self.axis) + return grad_inputs + + def 
calculate_output_shape(self, input_tensors, axis):
+            shape = list(input_tensors[0].shape)
+            for tensor in input_tensors[1:]:
+                shape[axis] += tensor.shape[axis]
+            return tuple(shape)
+
+        def split_grad(self, grad_output, input_tensors, axis):
+            offset = 0
+            grad_inputs = []
+            for tensor in input_tensors:
+                grad_input = acl_cmd("Slice", [
+                    grad_output, [0] * axis + [offset] + [0] *
+                    (len(tensor.shape) - axis - 1), tensor.shape
+                ])
+                grad_inputs.append(grad_input)
+                offset += tensor.shape[axis]
+            return grad_inputs
+
+    class SetItemACL(Function):
+
+        def __init__(self):
+            super(SetItemACL, self).__init__()
+
+        def stride(self, x, dim):
+            # compute the stride of the given dimension
+            stride = 1
+            for i in range(dim + 1, len(x.shape)):
+                stride *= x.shape[i]
+            return stride
+
+        def execute(self, x, slices, value, reduce='void'):
+            self.is_tensor = type(value) == jt.Var
+            if type(value) != jt.Var:
+                value = jt.array(value)
+            x_dim = len(x.shape)
+
+            # make sure slices is a tuple
+            if not isinstance(slices, tuple):
+                slices = (slices, )
+
+            # pad slices so its length equals the number of dimensions of x
+            if len(slices) < x_dim:
+                slices += (slice(None, None, None), ) * (x_dim - len(slices))
+
+            self.inputs = [x, slices, value]
+
+            target_sizes = []
+            target_strides = []
+            offset = 0
+
+            for dim, s in enumerate(slices):
+                if isinstance(s, int):
+                    if s < 0:
+                        s += x.shape[dim]
+                    s = slice(s, s + 1, None)
+                if isinstance(s, slice):
+                    # unpack the slice
+                    start, stop, step = s.indices(x.shape[dim])
+                    size = (stop - start - 1) // step + 1
+                    stride = self.stride(x, dim) * step
+                    offset += start * self.stride(x, dim)
+                    target_sizes.append(size)
+                    target_strides.append(stride)
+                else:
+                    print("slices: ", s, type(s))
+                    raise ValueError("Invalid slice type")
+
+            # compute the size, stride and offset of value
+            value_sizes = list(value.shape)
+            value_strides = [
+                self.stride(value, dim) for dim in range(len(value.shape))
+            ]
+
+            self.target_sizes = target_sizes
+            self.target_strides = target_strides
+            self.offset = offset
+            self.value_sizes = value_sizes
+            self.value_strides = value_strides
+
+            #import pdb; pdb.set_trace()
+            result = acl_cmd("ViewCopy", [
+                x,
+                jt.Var(target_sizes),
+                jt.Var(target_strides),
+                jt.Var(offset), value,
+                jt.Var(value_sizes),
+                jt.Var(value_strides),
+                jt.Var(0)
+            ],
+                             output_dtypes=[x.dtype],
+                             output_shapes=[x.shape],
+                             attr={})[0]
+            result.sync()
+            return result
+
+        def grad(self, grad_output):
+            result = acl_cmd("AsStrided", [
+                grad_output,
+                jt.Var(self.target_sizes),
+                jt.Var(self.target_strides),
+                jt.Var(self.offset)
+            ],
+                             output_dtypes=[grad_output.dtype],
+                             output_shapes=[jt.empty(self.target_sizes).shape],
+                             attr={})[0]
+            # copy grad_output to new_grad_output
+            new_grad_output = acl_cmd("Copy", [grad_output],
+                                      output_dtypes=[grad_output.dtype],
+                                      output_shapes=[grad_output.shape],
+                                      attr={"N": 1})[0]
+            new_grad_output = acl_cmd("ViewCopy", [
+                new_grad_output,
+                jt.Var(self.target_sizes),
+                jt.Var(self.target_strides),
+                jt.Var(self.offset),
+                jt.zeros(self.value_sizes, dtype=grad_output.dtype),
+                jt.Var(self.value_sizes),
+                jt.Var(self.value_strides),
+                jt.Var(0)
+            ],
+                                      output_dtypes=[grad_output.dtype],
+                                      output_shapes=[grad_output.shape],
+                                      attr={})[0]
+            new_grad_output.sync()
+            return new_grad_output, None, result if self.is_tensor else None
+
+    class TriuACL(Function):
+
+        def __init__(self):
+            super(TriuACL, self).__init__()
+
+        def execute(self, input, k):
+            self.input = input
+            result = acl_cmd("Triu", [input],
+                             output_dtypes=[input.dtype],
+                             output_shapes=[input.shape],
+                             attr={'diagonal': k})[0]
+            return result
+
+        def grad(self, grad_output):
+            return
grad_output + + class TransposeACL(Function): + + def __init__(self): + super(TransposeACL, self).__init__() + + def execute(self, input, perm): + self.input = input + + output_shape = input.shape[perm[0]:perm[0] + 1] + + for i in range(1, len(perm)): + output_shape += input.shape[perm[i]:perm[i] + 1] + result = acl_cmd("Transpose", [input, jt.Var(perm)], + output_dtypes=[input.dtype], + output_shapes=[output_shape], + attr={})[0] + return result + + def grad(self, grad_output): + return grad_output + + class AdaptiveMaxPool2dACL(Function): + + def __init__( + self, + output_size, + return_indices=False, + ): + super(AdaptiveMaxPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + + attr = {} + attr['ceil_mode'] = False + attr['dilations'] = [1, 1, 1, 1] + self.attr = attr + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + output_shape = [ + input_shape[0], input_shape[1], self.output_size[0], + self.output_size[1] + ] + output_dtype = input_dtype + self.input = input + + stride_h = input_shape[2] // output_shape[2] + stride_w = input_shape[3] // output_shape[3] + kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h + kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w + + stride = [0, 0] + kernel_size = [0, 0] + padding = [0, 0] + + stride[0] = stride_h + stride[1] = stride_w + kernel_size[0] = kernel_size_h + kernel_size[1] = kernel_size_w + padding[0] = padding[1] = 0 + kernel_sizes = [1, kernel_size[0], kernel_size[1], 1] + strides_size = [1, stride[0], stride[1], 1] + paddings = [1, padding[0], padding[1], 1] + + self.attr['ksize'] = kernel_sizes + self.attr['strides'] = strides_size + self.attr['pads'] = paddings + + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + return grad_input + + class AdaptiveAvgPool2dACL(Function): + + def __init__(self, output_size): + super(AdaptiveAvgPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + attr = {} + if isinstance(output_size, tuple): + output_size = [output_size[0], output_size[1]] + attr['output_size'] = output_size + self.attr = attr + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + self.original_shape = input_shape + + output_shape = [ + input_shape[0], input_shape[1], self.attr['output_size'][0], + self.attr['output_size'][1] + ] + output_dtype = input_dtype + self.input = input + + result = acl_cmd("AdaptiveAvgPool2d", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + + return result[0] + + def grad(self, grad_output): + attr = {} + attr['orig_input_shape'] = list(self.original_shape) + grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[self.original_shape], + attr=attr)[0] + return grad_input + + class CumsumACL(Function): + + def __init__(self): + 
super(CumsumACL, self).__init__() + + def execute(self, input, dim=-1): + self.input = input + self.dim = dim + result = acl_cmd("Cumsum", [input, jt.Var(dim)], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + flipped_grad_output = acl_cmd( + "ReverseV2", [grad_output, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + cumulative_grad = acl_cmd( + "Cumsum", + [flipped_grad_output, jt.Var(self.dim)], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + grad_input = acl_cmd( + "ReverseV2", + [cumulative_grad, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input + + class GatherACL(Function): + + def __init__(self): + super(GatherACL, self).__init__() + + def execute(self, input, dim, index): + self.input = input + self.dim = dim + self.index = index + + result = acl_cmd("GatherElements", [input, index], + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr={'dim': dim})[0] + return result + + def grad(self, grad_output): + tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype) + grad_input = acl_cmd("ScatterElements", + [tmp, self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[tmp.shape], + attr={ + 'axis': self.dim, + 'reduction': "add" + })[0] + return grad_input + + class ScatterACL(Function): + + def __init__(self): + super(ScatterACL, self).__init__() + + def execute(self, input, dim, index, src, reduce='void'): + self.input = input + self.dim = dim + self.index = index + self.reduce = reduce + result = acl_cmd("ScatterElements", [input, self.index, src], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={ + 'axis': self.dim, + 'reduction': reduce + })[0] + return result + + def grad(self, grad_output): + grad_input = acl_cmd("GatherElements", [grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr={'dim': self.dim})[0] + return grad_output, None, None, grad_input + + class WhereACL(Function): + + def __init__(self): + super(WhereACL, self).__init__() + + def execute(self, condition, x, y): + self.condition = condition + + if x.dtype != y.dtype: + if x.dtype == jt.float32: + y = y.float32() + elif y.dtype == jt.float32: + x = x.float32() + else: + x = x.to(y.dtype) + + self.x = x + self.y = y + + result = acl_cmd("Select", [condition, x, y], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] + return result + + def grad(self, grad_output): + tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) + grad_x = acl_cmd("Select", [self.condition, grad_output, tmp], + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr={})[0] + + grad_y = acl_cmd("Select", [self.condition, tmp, grad_output], + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr={})[0] + return grad_output, grad_x, grad_y + + class FlipACL(Function): + + def __init__(self): + super(FlipACL, self).__init__() + + def execute(self, input, dim): + self.input = input + #if isinstance(dim_vector, tuple): + dim_vector = jt.Var(list(dim)) + #print(dim_vector.dtype) + self.dim_vector = dim_vector + #print(input, dim_vector) + result = acl_cmd("ReverseV2", [input, dim_vector], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + #print(grad_output) + 
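+            # Reversal is its own inverse, so the backward pass simply flips the
+            # incoming gradient along the same dims (hence ReverseV2 again below).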
grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input + + class FloorIntACL(Function): + + def __init__(self): + super(FloorIntACL, self).__init__() + + def execute(self, input): + self.input = input + self.shape = input.shape + result = acl_cmd("Floor", [input], + output_dtypes=[jt.int], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + return jt.zeros(self.shape, dtype=grad_output.dtype) + + def warp(origin_func, new_func): + + def warpper(*args, **kwargs): + if origin_func == jt.index: + if len(args) == 2 and args[1] == None: + args = tuple(list(args[0:1])) + if jt.flags.use_acl: + if isinstance(new_func, IndexACL): + if len(args) == 1: + args = (args[0], None) + if isinstance(new_func, CumsumACL): + args = (args[0], kwargs.get('dim', -1)) + kwargs = {} + if isinstance(new_func, + ScatterACL) and kwargs.get('reduce') is not None: + args = (args[0], args[1], args[2], args[3], + kwargs.get('reduce', 'void')) + kwargs = {} + + return new_func(*args, **kwargs) + return origin_func(*args, **kwargs) + + return warpper + + jt.index = warp(jt.index, IndexACL()) + jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) + jt.nn.Pool = warp(jt.nn.Pool, PoolACL) + jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, + AdaptiveMaxPool2dACL) + jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, + AdaptiveAvgPool2dACL) + + jt.triu = warp(jt.triu, TriuACL()) + jt.triu_ = warp(jt.triu, TriuACL()) + jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) + jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) + + jt.getitem = warp(jt.getitem, GetItem()) + jt.Var.getitem = lambda x, slices, return_x=None: warp( + jt.getitem, GetItem())(x, slices) + + jt.setitem = warp(jt.setitem, SetItemACL()) + jt.Var.setitem = lambda x, slices, value, reduce='void': warp( + jt.setitem, SetItemACL())(x, slices, value, reduce) + + jt.misc.flip = warp(jt.misc.flip, FlipACL()) + jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())( + x, dim_vector) + jt.cumsum = warp(jt.cumsum, CumsumACL()) + jt.gather = warp(jt.gather, GatherACL()) + jt.Var.gather = lambda x, dim, index: warp(jt.gather, GatherACL())(x, dim, + index) + jt.scatter = warp(jt.scatter, ScatterACL()) + jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp( + jt.scatter, ScatterACL())(x, dim, index, src, reduce) + jt.where = warp(jt.where, WhereACL()) + jt.floor_int = warp(jt.floor_int, FloorIntACL()) + jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) + + # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) + # jt.bmm = warp(jt.bmm, BmmACL()) + # jt.nn.matmul = warp(jt.matmul, MatmulACL()) + # jt.matmul = warp(jt.matmul, MatmulACL()) + # jt.transpose = warp(jt.transpose, TransposeACL()) + # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm) + # jt.concat = warp(jt.concat, ConcatACL()) diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc index dc99887f..be0c17f4 100644 --- a/python/jittor/extern/acl/acl_jittor.cc +++ b/python/jittor/extern/acl/acl_jittor.cc @@ -16,6 +16,7 @@ namespace jittor { uint64_t acl_jittor_tid; int acl_jittor_thread_running=0; aclrtContext acl_jittor_context; +aclrtStream aclstream; #define CHECK_ACL(x) ASSERTop(x,==,0) @@ -28,7 +29,7 @@ static void* acl_jittor_process_callback(void*) { // LOGir << "acl_jittor_process_callback"; auto ret = 
aclrtProcessReport(1000); if (ret) { - if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT) + if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE) LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret); break; } @@ -59,10 +60,9 @@ acl_jittor_initer() { CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0)); // simple callback test - // aclrtStream stream; - // CHECK_ACL(aclrtCreateStream(&stream)); - // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream)); - // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream)); + CHECK_ACL(aclrtCreateStream(&aclstream)); + // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,aclstream)); + // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, aclstream)); // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0)); } @@ -87,7 +87,7 @@ string process_acl(const string& src, const string& name, const map fake_class = { "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t", @@ -117,6 +117,7 @@ string process_acl(const string& src, const string& name, const map()", "-1e30"); new_src = token_replace_all(new_src, "::numeric_max()", "1e30"); // TODO: support max diff --git a/python/jittor/extern/acl/acl_jittor.h b/python/jittor/extern/acl/acl_jittor.h index 2fc6613b..0ef90b40 100644 --- a/python/jittor/extern/acl/acl_jittor.h +++ b/python/jittor/extern/acl/acl_jittor.h @@ -13,6 +13,7 @@ std::string acl_error_to_string(aclError error); namespace jittor { EXTERN_LIB uint64_t acl_jittor_tid; +EXTERN_LIB aclrtStream aclstream; void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags); diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc index 0eb130d3..07b35145 100644 --- a/python/jittor/extern/acl/acl_op_exec.cc +++ b/python/jittor/extern/acl/acl_op_exec.cc @@ -16,7 +16,9 @@ #include "ops/reduce_op.h" #include "ops/binary_op.h" #include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" #include "ops/array_op.h" +#include "ops/code_op.h" #include "fused_op.h" #include "ops/unary_op.h" #include "ops/ternary_op.h" @@ -32,6 +34,44 @@ namespace jittor { using std::swap; +void printDeviceData(const vector& output_desc, const vector& output_data, const string& name = "", bool input=true) { + LOGir << "name: " << name; + if(input) + LOGir << "is input"; + else + LOGir << "is ouput"; + for (size_t i = 0; i < output_desc.size(); ++i) { + void* base_addr = aclGetDataBufferAddr(output_data[i]); + LOGir << "addr of data[" << i << "] :" << base_addr; + size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); + size_t total_size = 1; + std::vector dims(num_dims); + + std::cout << "shape of data: "; + for (size_t j = 0; j < num_dims; ++j) { + aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); + total_size *= dims[j]; + std::cout << dims[j] << ", "; + } + int evey_batch_size = total_size/dims[0]; + std::cout << std::endl; + + // for(int i= 0; i < dims[0]; i++) { + // evey_batch_size = 16; + // std::vector host_buffer(evey_batch_size); + // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); + // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); + // std::cout << "batch[" << i << "]:"; + // for (size_t k = 0; k < evey_batch_size; ++k) { + // std::cout << host_buffer[k] << ", "; + // } + // std::cout << std::endl; + // if(i >= 3) 
+ // break; + // } + } +} + struct AclOpRunner { string name; vector input_desc; @@ -40,6 +80,7 @@ struct AclOpRunner { vector output_data; aclopAttr *attr; vector> input_host; + vector> input_host_32; AclOpRunner(const string& name) : name(name) { attr = aclopCreateAttr(); @@ -56,9 +97,14 @@ struct AclOpRunner { aclDataType get_dtype(NanoString s) { if (s == ns_float32) return ACL_FLOAT; if (s == ns_float16) return ACL_FLOAT16; + if (s == ns_int64) return ACL_INT64; if (s == ns_int32) return ACL_INT32; if (s == ns_int8) return ACL_INT8; + if (s == ns_int16) return ACL_INT16; if (s == ns_uint8) return ACL_UINT8; + if (s == ns_uint16) return ACL_UINT16; + if (s == ns_uint32) return ACL_UINT32; + if (s == ns_bool) return ACL_BOOL; LOGf << "Not supported dtype: " << s; return ACL_FLOAT; } @@ -138,44 +184,63 @@ struct AclOpRunner { auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(int)); input_desc.push_back(desc); input_data.push_back(data); - // input_host.emplace_back(move(v)); + input_host_32.emplace_back(move(v)); } void set_attr(const string& key, bool value) { + // LOGir << "string bool" << "set_attr" << key << value; CHECK(aclopSetAttrBool(attr, key.c_str(), value)==0); } + void set_attr(const string& key, int value, int is_bool) { + // LOGir << "string bool" << "set_attr" << key << value << is_bool; + CHECK(aclopSetAttrBool(attr, key.c_str(), value==is_bool)==0); + } void set_attr(const string& key, float value) { + // LOGir << "string float" <<"set_attr" << key << value; CHECK(aclopSetAttrFloat(attr, key.c_str(), value)==0); } void set_attr(const string& key, int64_t value) { + // LOGir << "string int64" << "set_attr" << key << value; + CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); + } + void set_attr(const string& key, int64_t value, int placeholder) { + // LOGir << "string int64" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); } void set_attr(const string& key, int32 value) { + // LOGir << "string int32" << "set_attr" << key << value; CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0); } void set_attr(const string& key, vector value) { + // LOGir << "string vector" << "set_attr" << key << value; CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); } void set_attr(const string& key, string value) { + // LOGir << "string string" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); } void set_attr(const char* key, const char* value) { + // LOGir << "char" << "set_attr" << key << value; CHECK(aclopSetAttrString(attr, key, value)==0); } void run() { + // printDeviceData(input_desc, input_data, name); + LOGv << "run" << name << input_desc.size() << output_desc.size(); if (!PyGILState_Check()) { - ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL)); + ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); } else { int ret; Py_BEGIN_ALLOW_THREADS - ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL); + ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), 
&output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); Py_END_ALLOW_THREADS if (ret != 0) LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; } - // ASSERT(0==aclrtSynchronizeDevice()); + ASSERT(0==aclrtSynchronizeDevice()); + + // printDeviceData(output_desc, output_data, name, false); } }; @@ -318,6 +383,17 @@ void try_exec_and_fallback_cpu(Op* op) { auto iter = opname_map.find(bop->ns); ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found"; op.name = iter->second; + if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool) + { + // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor + if (bop->ns == ns_bitwise_or) { + op.name = "LogicalOr"; + } else if (bop->ns == ns_bitwise_and) { + op.name = "LogicalAnd"; + } else if (bop->ns == ns_bitwise_xor) { + op.name = "LogicalXor"; + } + } op.run(); } else if (op->name() == string("ternary")) { @@ -345,7 +421,7 @@ void try_exec_and_fallback_cpu(Op* op) { else if (rop->ns == ns_minimum) op.name = "ReduceMin"; else if (rop->ns == ns_mean) - op.name = "Reduce"; + op.name = "ReduceMean"; else LOGf << "op " << rop->ns << " not supported"; op.add(rop->x, true); @@ -381,6 +457,19 @@ void try_exec_and_fallback_cpu(Op* op) { op.add_input_host_nv(zshape, ACL_INT64); op.add(bop->z, false); op.run(); + } + else + if (op->name() == string("fuse_transpose")) { + // replace fuse_transpose with transpose + auto top = (TransposeOp*)op; + AclOpRunner op("Transpose"); + op.add(top->x, true); + op.add(top->y, false); + vector axes; + for (int i=0; iaxes.size(); i++) + axes.push_back(top->axes[i]); + op.add_input_host(axes, ACL_INT64); + op.run(); } else { LOGf << "op " << op->name() << " not supported"; @@ -388,6 +477,7 @@ void try_exec_and_fallback_cpu(Op* op) { } } catch (std::exception& e) { fallback = 1; + LOGir << "fallback cpu" << e.what(); } for (auto v : new_alloced) { free_var_mem(v); @@ -401,7 +491,7 @@ extern int current_seed; extern int64 current_offset; static unordered_map> acl_ops = { -{"curand_random", [&](Op* op) { +{"curand_random", [¤t_seed, ¤t_offset](Op* op) { auto _op = (RandomOp*)op; AclOpRunner runner(_op->type == ns_uniform ? 
"StatelessRandomUniformV2" : "StatelessRandomNormalV2"); auto out = op->output(0); @@ -429,7 +519,21 @@ static unordered_map> acl_ops = { runner.set_attr("transpose_x2", _op->trans_b); runner.run(); }}, -{"cudnn_conv", [&](Op* op) { +{"cublas_batched_matmul", [&](Op* op) { + struct BatchedMatmulOp : Op { + Var* a, *b, *c; + bool adj_x1, adj_x2; + }; + auto _op = (BatchedMatmulOp*)op; + AclOpRunner runner("BatchMatMul"); + runner.add(_op->a, true); + runner.add(_op->b, true); + runner.add(_op->c, false); + runner.set_attr("adj_x1", _op->adj_x1); + runner.set_attr("adj_x2", _op->adj_x2); + runner.run(); +}}, +{"cudnn_conv", [](Op* op) { struct ConvOp : Op { Var* x, * w, * y; int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -451,7 +555,7 @@ static unordered_map> acl_ops = { auto _op = (ConvOp*)op; _op->run_acl(); }}, -{"cudnn_conv_backward_x", [&](Op* op) { +{"cudnn_conv_backward_x", [](Op* op) { struct ConvBackwardXOp : Op { Var* w, * dy, * dx; int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -480,7 +584,7 @@ static unordered_map> acl_ops = { auto _op = (ConvBackwardXOp*)op; _op->run_acl(); }}, -{"cudnn_conv_backward_w", [&](Op* op) { +{"cudnn_conv_backward_w", [](Op* op) { struct ConvBackwardWOp : Op { Var* x, * dy, * dw; int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; @@ -488,7 +592,7 @@ static unordered_map> acl_ops = { void run_acl() { AclOpRunner runner("Conv2DBackpropFilter"); runner.add(x, true, ACL_FORMAT_NCHW); - runner.add_input_host_nv32(x->shape); + runner.add_input_host_nv32(dw->shape); runner.add(dy, true, ACL_FORMAT_NCHW); runner.add(dw, false, ACL_FORMAT_NCHW); runner.set_attr("strides", vector{1,1,strideh,stridew}); @@ -538,7 +642,7 @@ static jit_op_entry_t acl_do_compile(Op* op) { return oc.compile(op->get_jit_key(get_jk()), *src); } if (op->name() == string("fused")) { - FusedOp* fop = (FusedOp*)op; + FusedOp* fop = (FusedOp*)op; // if is a relayed op if (fop->context->vrm.relay_groups.size()) { LOGv << "relay fused op"; @@ -546,7 +650,17 @@ static jit_op_entry_t acl_do_compile(Op* op) { } else { return &try_exec_and_fallback_cpu; } - } else { + } else + if (op->name() == string("code")) { + CodeOp* cop = (CodeOp*)op; + if (cop->cuda_src.find("acl") != string::npos) { + LOGv << "compile acl op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } else { + return &exec_mapped_acl_ops; + } + } else + { LOGv << "compile finish" << op; return &exec_mapped_acl_ops; } diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h index be01348f..7667e495 100644 --- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h +++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h @@ -25,7 +25,9 @@ static inline cudaDataType get_dtype(NanoString dtype) { if (dtype == ns_float32) return CUDA_R_32F; if (dtype == ns_float64) return CUDA_R_64F; if (dtype == ns_float16) return CUDA_R_16F; + #ifndef IS_ROCM if (dtype == ns_bfloat16) return CUDA_R_16BF; + #endif LOGf << "not support type" << dtype; return CUDA_R_32F; } diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h index 2109a03a..73fb3233 100644 --- a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h @@ -8,8 +8,9 @@ #include #include #include +#ifndef IS_ROCM #include - +#endif #include "utils/log.h" #include "helper_cuda.h" #include "fp16_emu.h" @@ 
-31,6 +32,8 @@ void set_max_workspace_ratio(float64 ratio); template __inline__ cudnnDataType_t getDataType(); template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } +#ifndef IS_ROCM template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; } +#endif } // jittor diff --git a/python/jittor/gradfunctional/__init__.py b/python/jittor/gradfunctional/__init__.py new file mode 100644 index 00000000..259897e1 --- /dev/null +++ b/python/jittor/gradfunctional/__init__.py @@ -0,0 +1,2 @@ +from .functional import jvp, vjp + diff --git a/python/jittor/gradfunctional/functional.py b/python/jittor/gradfunctional/functional.py new file mode 100644 index 00000000..df183cc1 --- /dev/null +++ b/python/jittor/gradfunctional/functional.py @@ -0,0 +1,420 @@ +# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/autograd/functional.py +import jittor as jt + +__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"] + +# Utility functions +def _as_tuple_nocheck(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return (x,) + +def _as_tuple(inp, arg_name=None, fn_name=None): + # Ensures that inp is a tuple of Tensors + # Returns whether or not the original inp was a tuple and the tupled version of the input + if arg_name is None and fn_name is None: + return _as_tuple_nocheck(inp) + + is_inp_tuple = True + if not isinstance(inp, tuple): + inp = (inp,) + is_inp_tuple = False + + for i, el in enumerate(inp): + if not isinstance(el, (jt.Var, jt.nn.ComplexNumber)): + if is_inp_tuple: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" value at index {i} has type {type(el)}." + ) + else: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" given {arg_name} has type {type(el)}." + ) + + return is_inp_tuple, inp + + +def _tuple_postprocess(res, to_unpack): + # Unpacks a potentially nested tuple of Tensors + # to_unpack should be a single boolean or a tuple of two booleans. + # It is used to: + # - invert _as_tuple when res should match the inp given to _as_tuple + # - optionally remove nesting of two tuples created by multiple calls to _as_tuple + if isinstance(to_unpack, tuple): + assert len(to_unpack) == 2 + if not to_unpack[1]: + res = tuple(el[0] for el in res) + if not to_unpack[0]: + res = res[0] + else: + if not to_unpack: + res = res[0] + return res + + +def _grad_preprocess(inputs, create_graph, need_graph): + # Preprocess the inputs to make sure they require gradient + # inputs is a tuple of Tensors to preprocess + # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs + # need_graph specifies if we internally want gradients to flow back to the Tensors in res + # Note that we *always* create a new Tensor object to be able to see the difference between + # inputs given as arguments and the same Tensors automatically captured by the user function. 
+ # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576 + res = [] + for inp in inputs: + if create_graph and inp.requires_grad: + # Create at least a new Tensor object in a differentiable way + # Use .reshae() to get a shallow copy + res.append(inp.reshape(inp.shape)) + else: + if need_graph: + ninp = inp.detach().start_grad() + else: + ninp = inp.detach().stop_grad() + res.append(ninp) + return tuple(res) + + +def _grad_postprocess(inputs, create_graph): + # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not + # request it. + if isinstance(inputs[0], (jt.Var, jt.nn.ComplexNumber)): + if not create_graph: + return tuple(inp.detach() for inp in inputs) + else: + return inputs + else: + return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) + + +def _validate_v(v, other, is_other_tuple): + # This assumes that other is the correct shape, and v should match + # Both are assumed to be tuples of Tensors + if len(other) != len(v): + if is_other_tuple: + raise RuntimeError( + f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}." + ) + else: + raise RuntimeError("The given v should contain a single Tensor.") + + for idx, (el_v, el_other) in enumerate(zip(v, other)): + if el_v.shape != el_other.shape: + prepend = "" + if is_other_tuple: + prepend = f"Entry {idx} in " + raise RuntimeError( + f"{prepend}v has invalid size: should be {el_other.shape} but got {el_v.shape}." + ) + + +def _check_requires_grad(inputs, input_type, strict): + # Used to make all the necessary checks to raise nice errors in strict mode. + if not strict: + return + + if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]: + raise RuntimeError("Invalid input_type to _check_requires_grad") + for i, inp in enumerate(inputs): + if inp is None: + # This can only be reached for grad_inputs. + raise RuntimeError( + f"The output of the user-provided function is independent of input {i}." + " This is not allowed in strict mode." + ) + if not inp.requires_grad: + if input_type == "hessian": + raise RuntimeError( + f"The hessian of the user-provided function with respect to input {i}" + " is independent of the input. This is not allowed in strict mode." + " You should ensure that your function is thrice differentiable and that" + " the hessian depends on the inputs." + ) + elif input_type == "jacobian": + raise RuntimeError( + "While computing the hessian, found that the jacobian of the user-provided" + f" function with respect to input {i} is independent of the input. This is not" + " allowed in strict mode. You should ensure that your function is twice" + " differentiable and that the jacobian depends on the inputs (this would be" + " violated by a linear function for example)." + ) + elif input_type == "grad_inputs": + raise RuntimeError( + f"The gradient with respect to input {i} is independent of the inputs of the" + " user-provided function. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + f"Output {i} of the user-provided function does not require gradients." + " The outputs must be computed in a differentiable manner from the input" + " when running in strict mode." + ) + + +def _autograd_grad( + outputs, + inputs, + grad_outputs=None, + create_graph=True, +): + # Version of grad that accepts `None` in outputs and do not compute gradients for them. 
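+    # (Illustrative: for y0, y1 = f(x) where y1 is detached, a call like
+    #  _autograd_grad((y0, y1), (x,), (v0, v1)) only backpropagates through y0,
+    #  since outputs that are None or do not require grad are filtered out below.)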
+ # This has the extra constraint that inputs has to be a tuple + assert isinstance(outputs, tuple) + if grad_outputs is None: + grad_outputs = (None,) * len(outputs) + assert isinstance(grad_outputs, tuple) + assert len(outputs) == len(grad_outputs) + + new_outputs = () + new_grad_outputs = () + for out, grad_out in zip(outputs, grad_outputs): + if out is not None and out.requires_grad: + new_outputs += (out,) + new_grad_outputs += (grad_out,) + + if len(new_outputs) == 0: + # No differentiable output, we don't need to call the autograd engine + return (None,) * len(inputs) + else: + acc_loss = None + for new_output, grad_output in zip(new_outputs, grad_outputs): + if isinstance(new_output, jt.nn.ComplexNumber): + if grad_output is not None: + loss = (new_output.value * grad_output.value).sum() + else: + loss = new_output.value.sum() + else: + if grad_output is not None: + new_output = new_output * grad_output + loss = new_output.sum() + if acc_loss is None: + acc_loss = loss + else: + acc_loss += loss + + complex_inds = [] + var_inputs = [] + for idx, inp in enumerate(inputs): + if isinstance(inp, jt.nn.ComplexNumber): + var_inputs.append(inp.value) + complex_inds.append(idx) + else: + var_inputs.append(inp) + + grads = jt.grad(acc_loss, var_inputs, retain_graph=create_graph) + for complex_ind in complex_inds: + grads[complex_ind] = jt.nn.ComplexNumber(grads[complex_ind], is_concat_value=True) + return tuple(grads) + + +def _fill_in_zeros(grads, refs, strict, create_graph, stage): + # Used to detect None in the grads and depending on the flags, either replace them + # with Tensors full of 0s of the appropriate size based on the refs or raise an error. + # strict and create graph allow us to detect when it is appropriate to raise an error + # stage gives us information of which backward call we consider to give good error message + if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: + raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros") + + res = () + for i, grads_i in enumerate(grads): + if grads_i is None: + if strict: + if stage == "back": + raise RuntimeError( + "The output of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + elif stage == "back_trick": + raise RuntimeError( + f"The gradient with respect to the input is independent of entry {i}" + " in the grad_outputs when using the double backward trick to compute" + " forward mode gradients. This is not allowed in strict mode." + ) + elif stage == "double_back": + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"entry {i} in the grad_jacobian. This is not allowed in strict " + "mode as it prevents from using the double backward trick to " + "replace forward mode AD." + ) + + refs_i = refs[i] + if isinstance(refs_i, jt.nn.ComplexNumber): + grads_i = jt.nn.ComplexNumber(jt.zeros_like(refs_i.value), is_concat_value=True) + else: + grads_i = jt.zeros_like(refs_i) + else: + if strict and create_graph and not grads_i.requires_grad: + if "double" not in stage: + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode when create_graph=True." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"input {i}. 
This is not allowed in strict mode when create_graph=True." + ) + + res += (grads_i,) + + return res + + +# Public API + +def vjp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the vector + Jacobian product is computed. Must be the same size as the output + of ``func``. This argument is optional when the output of ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + vjp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + vjp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the inputs. + """ + with jt.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vjp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if v is not None: + _, v = _as_tuple(v, "v", "vjp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, outputs, is_outputs_tuple) + else: + if len(outputs) != 1 or outputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the " + "user-provided function returns " + "a single Tensor with a single element." + ) + + with jt.enable_grad(): + grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph) + vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + vjp = _grad_postprocess(vjp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vjp, is_inputs_tuple + ) + + +def jvp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the Jacobian + vector product is computed. Must be the same size as the input of + ``func``. This argument is optional when the input to ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. 
+ create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + jvp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + jvp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the output. + + """ + with jt.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + if v is not None: + _, v = _as_tuple(v, "v", "jvp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, inputs, is_inputs_tuple) + else: + if len(inputs) != 1 or inputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the input to " + "the user-provided function is a single Tensor " + "with a single element." + ) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jvp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + # The backward is linear so the value of grad_outputs is not important as + # it won't appear in the double backward graph. We only need to ensure that + # it does not contain inf or nan. + grad_outputs = tuple( + jt.nn.ComplexNumber(jt.zeros_like(out.value), is_concat_value=True) if isinstance(out, jt.nn.ComplexNumber) else jt.zeros_like(out) + for out in outputs + ) + + grad_inputs = _autograd_grad(outputs, inputs, grad_outputs=grad_outputs, create_graph=True) + _check_requires_grad(grad_inputs, "grad_inputs", strict=strict) + + if create_graph: + with jt.enable_grad(): + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + else: + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + jvp = _grad_postprocess(jvp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + jvp, is_outputs_tuple + ) diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index d601335c..2053d86c 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -10,6 +10,214 @@ # *************************************************************** import jittor as jt from functools import partial +from .nn import ComplexNumber + +def complex_inv(x:ComplexNumber): + r""" + calculate the inverse of x. + :param x (...,M,M): + :return:x^-1 (...,M,M). + + TODO: Faster Implementation; Check backward. 
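+
+    A minimal usage sketch (assuming a square input built from a real jt.Var,
+    as eig() does elsewhere in this module)::
+
+        >>> import jittor as jt
+        >>> from jittor import linalg, nn
+        >>> x = nn.ComplexNumber(jt.rand(3, 3))
+        >>> y = linalg.complex_inv(x)   # y.real / y.imag each have shape (3, 3)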
+ """ + assert isinstance(x, ComplexNumber), "complex_inv is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_inv" + + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + + a = _stack_to_complex(data["inputs"][0]) + m_a = data["outputs"][0] + t_a = np.linalg.inv(a) + np.copyto(m_a, _complex_to_stack(t_a)) + + + def backward_code(np, data): + def T(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = _stack_to_complex(data["dout"]) + out = data["outputs"][0] + mx = _stack_to_complex(data["f_outputs"][0]) + t = -_dot(_dot(T(mx), dout), T(mx)) + np.copyto(out, _complex_to_stack(t)) + + lmx = jt.numpy_code( + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + + return ComplexNumber(lmx, is_concat_value=True) + +def complex_eig(x:ComplexNumber): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return:w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + assert isinstance(x, ComplexNumber), "complex_eig is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_eig" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + w, v = data["outputs"] + tw, tv = np.linalg.eig(a) + np.copyto(w, _complex_to_stack(tw)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + sw = x.shape[:-2] + x.shape[-1:] + (2,) + sv = x.value.shape + w, v = jt.numpy_code( + [sw, sv], + [x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(w, is_concat_value=True), ComplexNumber(v, is_concat_value=True) + +def complex_qr(x): + r""" + do the qr factorization of x in the below formula: + x = QR where Q is orthogonal matrix and R is upper-triangle matrix. + :param x (...,M,M): + :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). 
+ """ + assert isinstance(x, ComplexNumber), "linalg_qr is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for linalg_qr" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + qr = data["outputs"][0] + Q, R = np.linalg.qr(a) + QR = np.stack([Q, R], axis=0) + np.copyto(qr, _complex_to_stack(QR)) + + def backward_code(np, data): + # reference: https://github.com/tencent-quantum-lab/tensorcircuit/blob/master/tensorcircuit/backends/pytorch_ops.py + def H(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _TriangularSolve(x, r): + return H(np.linalg.solve(r, H(x))) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + _diag = partial(np.einsum, '...ii->...i') + + dout = data["dout"] + out = data["outputs"][0] + qr = data["f_outputs"][0] + dout = _stack_to_complex(dout) + dq, dr = dout[0], dout[1] + qr = _stack_to_complex(qr) + q, r = qr[0], qr[1] + + + qdq = _dot(H(q), dq) + qdq_ = qdq - H(qdq) + rdr = _dot(r, H(dr)) + rdr_ = rdr - H(rdr) + tril = np.tril(qdq_ + rdr_) + + grad_a = _dot(q, dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - _dot(q, qdq), r) + ret = grad_a + grad_b + + m = rdr - H(qdq) + eyem = np.zeros_like(m) + _diag(eyem)[:] = _diag(m) + correction = eyem - np.real(eyem) + ret = ret + _TriangularSolve(_dot(q, H(correction)), r) + + ret = _complex_to_stack(ret) + np.copyto(out,ret) + + qr = jt.numpy_code( + (2,) + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + q, r = qr[0], qr[1] + return ComplexNumber(q, is_concat_value=True), ComplexNumber(r, is_concat_value=True) + +def complex_svd(x:ComplexNumber): + r''' + calculate the Singular Value Decomposition of x.It follows the below fomula: + x = usv* + only support full matrices == False ver now, which means: + x's shape (...,M,K) + u's shape (...,M,K) + s's shape (...,K) + v's shape (...,K,N) + where K is min(M,N). + :param x: + :return:u,s,v. + ''' + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + u, s, v = data["outputs"] + #TODO:remove copyto + tu, ts, tv = np.linalg.svd(a, full_matrices=0) + np.copyto(u, _complex_to_stack(tu)) + np.copyto(s, _complex_to_stack(ts)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + m, n = x.shape[-2:] + k = min(m, n) + s1 = list(x.shape) + s1[-1] = k + s2 = list(x.shape) + s2[-2] = k + s3 = list(x.shape)[:-2] + s3.append(k) + s1.append(2) + s2.append(2) + s3.append(2) + u, s, v = jt.numpy_code( + [s1, s3, s2], + [x.value.dtype, x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(u, is_concat_value=True), \ + ComplexNumber(s, is_concat_value=True), \ + ComplexNumber(v, is_concat_value=True) #TODO:full_matrices=1 def svd(x): @@ -25,6 +233,8 @@ def svd(x): :param x: :return:u,s,v. 
''' + if isinstance(x, ComplexNumber): + return complex_svd(x) def forward_code(np, data): a = data["inputs"][0] u, s, v = data["outputs"] @@ -92,6 +302,17 @@ def T(x): ) return u, s, v +def eig(x): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return (ComplexNumber):w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + if isinstance(x, ComplexNumber): + return complex_eig(x) + return complex_eig(ComplexNumber(x)) def eigh(x): r""" @@ -147,6 +368,8 @@ def inv(x): :param x (...,M,M): :return:x^-1 (...,M,M). """ + if isinstance(x, ComplexNumber): + return complex_inv(x) def forward_code(np, data): a = data["inputs"][0] m_a = data["outputs"][0] @@ -387,6 +610,8 @@ def qr(x): :param x (...,M,M): :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). """ + if isinstance(x, ComplexNumber): + return complex_qr(x) def forward_code(np, data): a = data["inputs"][0] q, r = data["outputs"] diff --git a/python/jittor/misc.py b/python/jittor/misc.py index fc528d62..8561393e 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -294,9 +294,10 @@ def expand(x, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)): shape = shape[0] shape = list(shape) - for i in range(len(shape)): - if shape[i] == -1: - shape[i] = x.shape[i] + offset = len(shape) - len(x.shape) + for i in range(len(x.shape)): + if shape[offset + i] == -1: + shape[offset + i] = x.shape[i] return x.broadcast(shape) jt.Var.expand = expand @@ -622,6 +623,10 @@ def unique( #include #include + #include + #include + #include + #include #include ''', @@ -704,6 +709,11 @@ def unique( #include #include #include + + #include + #include + #include + #include #include @@ -881,7 +891,7 @@ def split(d, split_size, dim=0): ans = [] last = 0 s_last = len(split_size)-1 - gopt_disable = jt.flags.gopt_disable + gopt_disable = jt.flags.gopt_disable or jt.flags.use_acl for j, i in enumerate(split_size): if i==0: shape = list(d.shape) @@ -922,6 +932,48 @@ def diag(x,diagonal=0): output_shape = (x.shape[0]-d,) return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0']) +# reference: https://github.com/pytorch/pytorch/blob/25d5a815f74db80ef19a3f714709b55b05675245/torch/_refs/__init__.py +def diagonal(x, offset=0, dim1=0, dim2=1): + def __normalize_dim(d, rank): + if d < 0: + d += rank + if d < 0 or d >= rank: + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {d})" + raise IndexError(msg) + return d + assert x.ndim >= 2, f"diagonal dimensions requires ndim larger than 2, but got {x.ndim}" + dim1 = __normalize_dim(dim1, x.ndim) + dim2 = __normalize_dim(dim2, x.ndim) + assert dim1 != dim2, f"diagonal dimensions cannot be identical {dim1}, {dim2}" + + if offset >= 0: + diag_size = max(min(x.shape[dim1], x.shape[dim2] - offset), 0) + else: + diag_size = max(min(x.shape[dim1] + offset, x.shape[dim2]), 0) + + sizes = [] + indices = [] + lsizes = 0 + dim_diag = x.ndim - 2 + abs_offset = offset if offset >= 0 else -offset + for i, s in enumerate(x.shape): + if i == dim1: + if offset >= 0: + indices.append(f"i{dim_diag}") + else: + indices.append(f"i{dim_diag}+{abs_offset}") + elif i == dim2: + if offset >= 0: + indices.append(f"i{dim_diag}+{abs_offset}") + else: + indices.append(f"i{dim_diag}") + else: + indices.append(f"i{lsizes}") + sizes.append(s) + lsizes += 1 + out_shape = tuple(sizes + [diag_size]) + return x.reindex(out_shape, indices) + 
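+# Illustrative behaviour of the reindex-based diagonal above (3x3 sketch):
+#   x = jt.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+#   diagonal(x, offset=0)  -> [0, 4, 8]
+#   diagonal(x, offset=1)  -> [1, 5]
+#   diagonal(x, offset=-1) -> [3, 7]
+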
jt.Var.diag = diag @@ -2010,6 +2062,7 @@ def contiguous(x): return x.clone() def cpu(x): return x.clone() jt.Var.cpu = cpu def to(x, *args, **kargs): + args += tuple(kargs.values()) if len(args) >= 1: s = args[0] if isinstance(s, jt.NanoString) or callable(s): @@ -2234,3 +2287,16 @@ def cuda(x): def expm1(x): return jt.exp(x) - 1 + + +def isin(elements, test_elements, assume_unique=False, invert=False): + + elements = elements.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + comparison = elements == test_elements + result = comparison.any(dim=-1) + + if invert: + result = jt.logical_not(result) + + return result \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py index a8190c79..46ca78ec 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -22,6 +22,7 @@ from jittor.optim import * from jittor.misc import _pair, _triple from jittor_utils import LOG +from functools import partial def matmul_transpose(a, b): @@ -385,7 +386,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction= if ignore_index is not None: target_weight = jt.ternary( target==ignore_index, - jt.array(0).broadcast(target_weight), + jt.array(0).broadcast(target_weight).type_as(target_weight), target_weight ) @@ -493,6 +494,9 @@ def execute(self, output, target): return l1_loss(output, target) def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): + if not (target.shape == output.shape): + raise ValueError(f"Target size ({target.shape}) must be the same as output size ({output.shape})") + max_val = jt.clamp(-output,min_v=0) if pos_weight is not None: log_weight = (pos_weight-1)*target + 1 @@ -588,6 +592,8 @@ def __init__(self, p=0.5, is_train=False): #TODO: test model.train() to change self.is_train def execute(self, input): output = input + if (input.dim() != 4) and (input.dim() != 3): + raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to Dropout2d, but got input of size: {input.shape}') shape = input.shape[:-2] if self.p > 0 and self.is_train: if self.p == 1: @@ -925,6 +931,42 @@ class Conv(Module): >>> output = conv(input) ''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + if in_channels <= 0: + raise ValueError(f"in_channels must be greater than zero, got {in_channels}") + if out_channels <= 0: + raise ValueError(f"out_channels must be greater than zero, got {out_channels}") + if groups <= 0: + raise ValueError(f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + else: + if kernel_size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + else: + if stride <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 
0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") + else: + if dilation <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) @@ -935,8 +977,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: self.depthwise_conv = DepthwiseConv(stride, padding, dilation) - assert in_channels % groups == 0, 'in_channels must be divisible by groups' - assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw = self.kernel_size # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") @@ -1053,6 +1093,8 @@ class Conv1d(Module): >>> output = conv(input) ''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + assert in_channels > 0, 'in_channels must be positive' + assert out_channels > 0, 'out_channels must be positive' self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = (kernel_size, 1) @@ -1061,6 +1103,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.dilation = (dilation, 1) self.groups = groups self.bias = bias + if groups <= 0: + raise ValueError("groups must be a positive integer") assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' # using list to escape module dfs @@ -1069,6 +1113,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.bias = self._conv[0].bias def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") N,C,D = x.shape assert C==self.in_channels self._conv[0].weight = self.weight.unsqueeze(-1) @@ -1121,6 +1167,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) self.groups = groups + if groups <= 0: + raise ValueError("groups must be a positive integer") assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw, Kd = self.kernel_size @@ -1144,10 +1192,14 @@ def execute(self, x): class Conv1d_sp(Linear): def __init__(self, inchannels, outchannels, kernel_size=1, bias=True): + assert inchannels > 0, 'in_channels must be positive' + assert outchannels > 0, 'out_channels must be positive' super().__init__(inchannels, outchannels, bias=bias) assert kernel_size == 1 def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") x = x.transpose(0, 2, 1) x = super().execute(x) x = x.transpose(0, 2, 1) @@ -1187,7 +1239,8 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): stride = _pair(stride) dilation = _pair(dilation) out_channels = weight.shape[0] - + if groups <= 0: + raise ValueError("groups must be a positive integer") if groups == 1: N,C,H,W = x.shape Kh, Kw = weight.shape[-2:] @@ -1276,7 +1329,8 @@ def conv3d(x, weight, 
bias=None, stride=1, padding=0, dilation=1, groups=1): stride = _triple(stride) dilation = _triple(dilation) out_channels = weight.shape[0] - + if groups <= 0: + raise ValueError("groups must be a positive integer") if jt.flags.use_cuda and jt.cudnn: y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups) elif groups == 1: @@ -1352,6 +1406,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.stride[0] > 0 and self.stride[1] > 0,"stride must be positive" + assert self.padding[0] >= 0 and self.padding[1] >= 0,"padding must be non-negative" assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" @@ -1369,6 +1425,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.bias = None def execute(self, x): + if x.dim() != 4: + raise RuntimeError(f'Expected 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}') if self.groups == 1: N,C,H,W = x.shape i,o,h,w = self.weight.shape @@ -1476,10 +1534,14 @@ def execute(self, x): def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): if groups == 1: x = input + if x.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {x.shape}') N,C,H,W = x.shape i,o,h,w = weight.shape assert C==i stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) @@ -1511,6 +1573,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding assert not bias, "Bias should be none or jittor var" return y else: + if input.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {input.shape}') N,C,H,W = input.shape i,o,h,w = weight.shape G = groups @@ -1519,6 +1583,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding assert C % G == 0 assert C==i, (C, i) stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) @@ -1562,11 +1628,15 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): x = input + if x.dim() != 5: + raise RuntimeError(f'Expected 5D input to conv_transpose3d, but got input of size: {x.shape}') N,C,D,H,W = x.shape i,o,d,h,w = weight.shape assert C==i assert groups==1, "Group conv not supported yet." 
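# --- Editor's sketch (not part of the patch): what the new conv_transpose
# validation is meant to catch. Assumes a working CPU build of Jittor; the
# shapes below are made up for illustration only.
import jittor as jt

x = jt.randn(1, 2, 4, 4)                 # 4D (N, C, H, W) input
w = jt.randn(2, 3, 3, 3)                 # (in_channels, out_channels, kh, kw)
y = jt.nn.conv_transpose(x, w, stride=2) # fine: 4D input, positive stride
print(y.shape)                           # (1, 3, 9, 9) by the usual transpose-conv arithmetic

try:
    jt.nn.conv_transpose(x, w, stride=0) # rejected by the new stride check
except RuntimeError as e:
    print("caught:", e)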
stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0: + raise RuntimeError("non-positive stride is not supported") dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding, padding) @@ -1631,6 +1701,8 @@ def pad(x,padding, mode='constant', value=0): class ReflectionPad2d(Module): def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") self.padding = padding if isinstance(self.padding, int): self.pl = self.padding @@ -1641,6 +1713,8 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): n,c,h,w = x.shape @@ -1669,8 +1743,12 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") n,c,h,w = x.shape return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"]) @@ -1687,6 +1765,8 @@ def __init__(self, padding, value): else: raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}") self.value = value + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): assert len(x.shape) >= 2 @@ -1701,6 +1781,8 @@ def execute(self, x): class ReplicationPad2d(Module): def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") self.padding = padding if isinstance(self.padding, int): self.pl = self.padding @@ -1711,8 +1793,12 @@ def __init__(self, padding): self.pl, self.pr, self.pt, self.pb = self.padding else: raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") n,c,h,w = x.shape oh=h+self.pt+self.pb ow=w+self.pl+self.pr @@ -1760,13 +1846,14 @@ def embedding(input, weight): class PixelShuffle(Module): def __init__(self, upscale_factor): + assert upscale_factor > 0,f"upscale_factor must be greater than zero,got {upscale_factor}" self.upscale_factor = upscale_factor def execute(self, x): n,c,h,w = x.shape r = self.upscale_factor assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" - if r<0: + if r<=0: raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}") return x.reindex([n,int(c/r**2),h*r,w*r], [ "i0", @@ -1815,6 +1902,15 @@ def execute(self, x): class Resize(Module): def __init__(self, size, mode="nearest", align_corners=False): super().__init__() + if isinstance(size,int): + if size <= 0: + raise ValueError(f"sizes must be positive, got {size}") + elif isinstance(size,tuple) or isinstance(size,list): + for item in size: + if item <= 0: + 
raise ValueError(f"sizes must be positive, got {item}") + else: + raise ValueError(f"size must be int or tuple") self.size = size self.mode = mode self.align_corners = align_corners @@ -1869,8 +1965,12 @@ def _interpolate(img, x, y, ids, mode): # TODO: tf_mode to another function def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): + if img.dim() != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") n, c, h, w = img.shape H, W = size + if h <= 0 or w <= 0 or H <= 0 or W <= 0: + raise RuntimeError(f"Input and output sizes should be greater than 0, but got input (H: {h}, W: {w}) output (H: {H}, W: {W})") nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: x = hid * ((h - 1) / max(1, H - 1)) @@ -2159,16 +2259,30 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner class Upsample(Module): - def __init__(self, scale_factor=None, mode='nearest'): - self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor) + def __init__(self, scale_factor=None, mode='nearest', align_corners=False): + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None self.mode = mode - + self.align_corners = align_corners + def execute(self, x): - return upsample(x, - size=( - int(x.shape[2]*self.scale_factor[0]), - int(x.shape[3]*self.scale_factor[1])), - mode=self.mode) + if self.scale_factor is None: + raise ValueError("scale_factor should be defined") + elif isinstance(self.scale_factor, float): + return upsample(x, + size=(int(x.shape[2]*self.scale_factor), + int(x.shape[3]*self.scale_factor)), + mode=self.mode, + align_corners=self.align_corners) + else: + return upsample(x, + size=( + int(x.shape[2]*self.scale_factor[0]), + int(x.shape[3]*self.scale_factor[1])), + mode=self.mode, + align_corners=self.align_cornerss) class UpsamplingBilinear2d(Upsample): def __init__(self, scale_factor=None): @@ -2234,11 +2348,20 @@ def __len__(self): def named_children(self,): return list(self.layers.items()) + + def __setattr__(self, key, value) -> None: + if isinstance(key, str) and key.isdigit(): + if int(key) 0 and kernel_size[1] > 0, "kernel size must be positive" if not isinstance(dilation, tuple): dilation = (dilation, dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" if not isinstance(padding, tuple): padding = (padding, padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" if not isinstance(stride, tuple): stride = (stride, stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" n, c, h, w = X.shape shape = X.shape area = kernel_size[0] * kernel_size[1] @@ -2351,14 +2478,19 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1): def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1): assert X.ndim==3 + assert output_size[0] > 0 and output_size[1] > 0, "output size must be positive." 
if not isinstance(kernel_size,tuple): kernel_size = (kernel_size,kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" if not isinstance(dilation,tuple): dilation = (dilation,dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" if not isinstance(padding,tuple): padding = (padding,padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" if not isinstance(stride,tuple): stride = (stride,stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" n,cl,num = X.shape area = kernel_size[0] * kernel_size[1] block_nums = [] @@ -2923,6 +3055,10 @@ def call_rnn_cell(self, input, hidden, suffix): return h, h def bilinear(in1, in2, weight, bias): + if weight.shape[1] != in1.shape[1]: + raise RuntimeError(f"bilinear(): input1 size deos not match weight size: got {in1.shape[1]} but expected {weight.shape[1]}") + if weight.shape[2] != in2.shape[1]: + raise RuntimeError(f"bilinear(): input2 size deos not match weight size: got {in2.shape[1]} but expected {weight.shape[2]}") w = weight.transpose((1,0,2)) w = w.reshape((w.shape[0], -1)) x = jt.matmul(in1, w) @@ -3010,6 +3146,10 @@ def __init__(self, real: jt.Var, imag: jt.Var=None, is_concat_value=False): assert real.dtype == imag.dtype self.value = jt.stack([real, imag], dim=-1) + @property + def requires_grad(self): + return self.value.requires_grad + @property def real(self): return self.value[..., 0] @@ -3022,6 +3162,10 @@ def imag(self): def shape(self): return self.value.shape[:-1] + @property + def dtype(self): + return "complex64" + def norm(self): return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag)) @@ -3049,6 +3193,12 @@ def permute(self, *axes): def transpose(self, *axes): return ComplexNumber(jt.transpose(self.real, *axes), jt.transpose(self.imag, *axes)) + def broadcast(self, shape, dims): + return ComplexNumber(self.real.broadcast(shape, dims), self.imag.broadcast(shape, dims)) + + def sum(self, dims, keepdims: bool=False): + return ComplexNumber(self.real.sum(dims, keepdims=keepdims), self.imag.sum(dims, keepdims=keepdims)) + def exp(self): er = jt.exp(self.real) return ComplexNumber(er * jt.cos(self.imag), er * jt.sin(self.imag)) @@ -3059,7 +3209,7 @@ def conj(self): def __add__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real + other.real, self.imag + other.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(self.real + other, self.imag) else: raise NotImplementedError @@ -3067,7 +3217,7 @@ def __add__(self, other): def __radd__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real + self.real, other.imag + self.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(other + self.real, self.imag) else: raise NotImplementedError @@ -3075,7 +3225,7 @@ def __radd__(self, other): def __sub__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real - other.real, self.imag - other.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return ComplexNumber(self.real - other, self.imag) else: raise NotImplementedError @@ -3083,7 +3233,7 @@ def __sub__(self, other): def __rsub__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real - self.real, other.imag - self.imag) - elif isinstance(other, (jt.Var, int, float)): + elif isinstance(other, (int, float)): return 
ComplexNumber(other - self.real, self.imag) else: raise NotImplementedError @@ -3094,8 +3244,6 @@ def __mul__(self, other): self.real * other.imag + self.imag * other.real) elif isinstance(other, (int, float)): return ComplexNumber(self.value * other, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real * other, self.imag * other) else: raise NotImplementedError @@ -3105,8 +3253,6 @@ def __rmul__(self, other): other.imag * self.real + other.real * self.imag) elif isinstance(other, (int, float)): return ComplexNumber(other * self.value, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(other * self.real, other * self.imag) else: raise NotImplementedError @@ -3117,8 +3263,6 @@ def __truediv__(self, other): (self.imag * other.real - self.real * other.imag) / norm) elif isinstance(other, (int, float)): return ComplexNumber(self.value / other, is_concat_value=True) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real / other, self.imag / other) else: raise NotImplementedError @@ -3127,7 +3271,7 @@ def __rtruediv__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber((other.real * self.real + other.imag * self.imag) / norm, (other.imag * self.real - other.real * self.imag) / norm) - elif isinstance(other, (int, float, jt.Var)): + elif isinstance(other, (int, float)): return ComplexNumber(other * self.real / norm, - other * self.imag / norm) else: raise NotImplementedError @@ -3136,8 +3280,6 @@ def __matmul__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(self.real @ other.real - self.imag @ other.imag, self.real @ other.imag + self.imag @ other.real) - elif isinstance(other, jt.Var): - return ComplexNumber(self.real @ other, self.imag @ other) else: raise NotImplementedError @@ -3145,8 +3287,6 @@ def __imatmul__(self, other): if isinstance(other, ComplexNumber): return ComplexNumber(other.real @ self.real - other.imag @ self.imag, other.imag @ self.real + other.real @ self.imag) - elif isinstance(other, jt.Var): - return ComplexNumber(other @ self.real, other @ self.imag) else: raise NotImplementedError @@ -3160,6 +3300,140 @@ def ifft2(self): return ComplexNumber(_fft2(self.value, inverse=True), is_concat_value=True) +def polar(abs:jt.Var, angle: jt.Var) -> ComplexNumber: + assert abs.shape == angle.shape + return ComplexNumber(abs * angle.cos(),abs * angle.sin()) + +def view_as_complex(x: jt.Var) -> ComplexNumber: + assert x.shape[-1] == 2 + return ComplexNumber(x[...,0],x[...,1]) + +def view_as_real(x: ComplexNumber) -> jt.Var: + return jt.stack([x.value[...,0],x.value[...,1]],dim=-1) + +# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/functional.py#L1258 +def tensordot(a, b, dims=2): + r"""Returns a contraction of a and b over multiple dimensions. + + :attr:`tensordot` implements a generalized matrix product. + + Args: + a (Tensor): Left tensor to contract + b (Tensor): Right tensor to contract + dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to + contract or explicit lists of dimensions for :attr:`a` and + :attr:`b` respectively + + When called with a non-negative integer argument :attr:`dims` = :math:`d`, and + the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, + respectively, :func:`tensordot` computes + + .. 
math:: + r_{i_0,...,i_{m-d}, i_d,...,i_n} + = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}. + + When called with :attr:`dims` of the list form, the given dimensions will be contracted + in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes + in these dimensions must match. + + """ + if not isinstance(dims, (tuple, list, int)): + raise RuntimeError( + "tensordot expects dims to be int or " + + "Tuple[List[int], List[int]] or " + + "List[List[int]] containing two lists, but got " + + f"dims={dims}" + ) + + dims_a, dims_b = [], [] + + if isinstance(dims, (tuple, list)): + dims_a, dims_b = dims + + if isinstance(dims, (int)): + if dims < 0: + raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") + if dims > min(len(a.shape), len(b.shape)): + raise RuntimeError( + f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}" + ) + dims_a = list(range(len(a.shape)-dims, len(a.shape))) + dims_b = list(range(dims)) + + # reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/aten/src/ATen/native/Linear.cpp#L769 + def __tensordot_native(input1:jt.Var, input2:jt.Var, dims1, dims2): + if not isinstance(dims1, (list, tuple)): + raise RuntimeError("tensordot expects dims1 to be List[Int], but got dims={}".format(dims1)) + if not isinstance(dims2, (list, tuple)): + raise RuntimeError("tensordot expects dims2 to be List[Int], but got dims={}".format(dims2)) + dims1 = list(dims1) + dims2 = list(dims2) + if len(dims1) != len(dims2): + raise RuntimeError("both dimension lists should have the same length") + if input1.dtype != input2.dtype: + raise RuntimeError("both inputs should have the same dtype") + t1 = input1 + t2 = input2 + csize = 1 + input1_bitmap = np.zeros(len(input1.shape), dtype='bool') + input2_bitmap = np.zeros(len(input2.shape), dtype='bool') + for i in range(len(dims1)): + s1 = input1.shape[dims1[i]] + s2 = input2.shape[dims2[i]] + input1_bitmap[dims1] = True + input2_bitmap[dims2] = True + if s2 == 1: #broadcasted dimensions can be summed right away + t1 = t1.sum(dims1[i], keepdims=True) + elif s1 == 1: + t2 = t2.sum(dims2[i], keepdims=True) + else: + if s1 != s2: + raise RuntimeError("contracted dimensions need to match, but first has size {}, in dim {}, and second has size {}".format(s1, i, s2)) + csize *= s1 + + p1, p2 = [], [] # p1, p2: input permutations + rsizes = [] + size1, size2 = 1, 1 # number of non-contracted elements + for i in range(len(input1.shape)): + if not input1_bitmap[i]: + p1.append(i) + size1 *= t1.shape[i] + rsizes.append(t1.shape[i]) + p1 += dims1 + p2 += dims2 + for i in range(len(input2.shape)): + if not input2_bitmap[i]: + p2.append(i) + size2 *= t2.shape[i] + rsizes.append(t2.shape[i]) + + # permute and reshape for matrix multiplication + t1 = t1.permute(p1).reshape((size1, csize)) + t2 = t2.permute(p2).reshape((csize, size2)) + # multiply and reshape to target size + return jt.matmul(t1, t2).reshape(rsizes) + + return __tensordot_native(a, b, dims_a, dims_b) + +# reference: https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/aten/src/ATen/native/LinearAlgebra.cpp#L3375 +def kron(a:jt.Var, b:jt.Var): + a_dim, b_dim = len(a.shape), len(b.shape) + max_dim = max(a_dim, b_dim) + pad_a, pad_b = max_dim-a_dim, max_dim-b_dim + a_reshape, b_reshape = [], [] + result_reshape = [] + for i in range(max_dim): + a_2i_shape = a.shape[i - pad_a] if i >= pad_a else 1 + b_2i1_shape = b.shape[i - 
pad_b] if i >= pad_b else 1 + a_reshape.append(a_2i_shape) + a_reshape.append(1) + b_reshape.append(1) + b_reshape.append(b_2i1_shape) + result_reshape.append(a_2i_shape * b_2i1_shape) + a = a.reshape(a_reshape) + b = b.reshape(b_reshape) + return (a * b).reshape(result_reshape) + def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var: ''' Returns the one_hot encoding of inputs. diff --git a/python/jittor/pool.py b/python/jittor/pool.py index f9ad9a68..7e9e808a 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -29,9 +29,20 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + for item in self.kernel_size: + if item <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {item}") + for item in self.stride: + if item <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {item}") + for item in self.padding: + if item < 0: + raise RuntimeError(f"padding must be non-negative, but got {item}") def execute(self, x): N,C,H,W = x.shape + if H <= self.kernel_size[0] or W <= self.kernel_size[1]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 @@ -205,9 +216,17 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in self.padding = _triple(padding) self.ceil_mode = ceil_mode self.count_include_pad = count_include_pad and padding != 0 + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0: + raise RuntimeError(f"padding must be non-negative, but got {padding}") def execute(self, x): N,C,D,H,W = x.shape + if D <= self.kernel_size[0] or H <= self.kernel_size[1] or W <= self.kernel_size[2]: + raise RuntimeError(f"size of var should be larger than kernel_size") if self.ceil_mode == False: d = (D+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 h = (H+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 @@ -506,7 +525,7 @@ def execute(self, x): f"i3*{self.sh}+i6", # Hid f"i4*{self.sw}+i7", # Wid ]) - return xx.reduce("maximun", [5,6,7]) + return xx.reduce("maximum", [5,6,7]) def pool(x, kernel_size, op, padding=0, stride=None): return Pool(kernel_size, stride, padding, op=op)(x) diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.cc b/python/jittor/src/mem/allocator/sfrl_allocator.cc index ce5460fd..add3aef9 100644 --- a/python/jittor/src/mem/allocator/sfrl_allocator.cc +++ b/python/jittor/src/mem/allocator/sfrl_allocator.cc @@ -242,7 +242,12 @@ std::mutex sfrl_allocator_mutex; void* SFRLAllocator::alloc(size_t size, size_t& allocation) { std::unique_lock lock(sfrl_allocator_mutex); + #ifdef IS_ACL + // output of acl op need additional 32 bytes + size = align_size(size+32); + #else size = align_size(size); + #endif CachingBlockPool* blocks = get_blocks(size); //search cached block CachingBlock* block = blocks->pop_block(size); diff --git a/python/jittor/src/misc/cuda_atomic.h 
b/python/jittor/src/misc/cuda_atomic.h index 6348befb..15b39412 100644 --- a/python/jittor/src/misc/cuda_atomic.h +++ b/python/jittor/src/misc/cuda_atomic.h @@ -6,7 +6,9 @@ // *************************************************************** #pragma once #include +#ifndef IS_ROCM #include +#endif #include "common.h" namespace jittor { diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc index bbec4ccb..00fb34aa 100644 --- a/python/jittor/src/misc/nan_checker.cc +++ b/python/jittor/src/misc/nan_checker.cc @@ -11,7 +11,9 @@ #include "misc/cuda_flags.h" #include #include +#ifndef IS_ROCM #include +#endif #include "helper_cuda.h" #endif #include "mem/allocator.h" @@ -22,7 +24,9 @@ namespace jittor { #ifdef IS_CUDA EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num); +#ifndef IS_ROCM EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num); +#endif EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num); EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num); #endif @@ -33,7 +37,9 @@ void dump_var(Var* v, string name) { name = ss.str(); LOGe << "dump" << v << "to" << name; char* buffer = new char[v->size]; - #ifdef HAS_CUDA + #ifdef IS_ROCM + hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault); + #elif IS_CUDA cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault); #else std::memcpy(buffer, v->mem_ptr, v->size); @@ -57,9 +63,11 @@ bool check_nan(Var* v, Op* op) { if (v->dtype() == ns_float16) { nan_index = check_nan_float16((__half*)v->mem_ptr, v->num); } + #ifndef IS_ROCM if (v->dtype() == ns_bfloat16) { nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num); } + #endif if (v->dtype() == ns_float32) { nan_index = check_nan_float32((float32*)v->mem_ptr, v->num); } else @@ -104,14 +112,16 @@ bool check_nan(Var* v, Op* op) { auto* ptr = input->ptr<__half>(); __half value; cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost); - LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; + // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #ifndef IS_ROCM if (input->dtype() == ns_bfloat16) { auto* ptr = input->ptr<__nv_bfloat16>(); __nv_bfloat16 value; cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; } else + #endif if (input->dtype() == ns_float32) { auto* ptr = input->ptr(); float32 value; diff --git a/python/jittor/src/misc/nan_checker.cu b/python/jittor/src/misc/nan_checker.cu index ec7998ba..b2e631f1 100644 --- a/python/jittor/src/misc/nan_checker.cu +++ b/python/jittor/src/misc/nan_checker.cu @@ -7,9 +7,13 @@ #include "misc/cuda_flags.h" #include #include -#include + #include "helper_cuda.h" #include +//TODO:FIX in ROCM +#ifndef IS_ROCM +#include +#endif namespace jittor { @@ -37,7 +41,7 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt print_nan(float(ptr[i]), i, cnt); } } - +#ifndef IS_ROCM __global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) { int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; if (i check_nan_float16(__half* ptr, int64 num) { _check_nan_float16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#ifndef IS_ROCM vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) { int block_num = std::max((int64)1, (num-1)/1024+1); int thread_num = 
std::min((int64)1024, num); _check_nan_bfloat16<<>>(ptr, num, check_nan_get_device_ptr()); return report_nan(); } - +#endif #endif } \ No newline at end of file diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc index 01a8ef2d..848d40a4 100644 --- a/python/jittor/src/ops/binary_op.cc +++ b/python/jittor/src/ops/binary_op.cc @@ -440,6 +440,17 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) { return; } + #ifdef IS_ACL + if (x->dtype() != y->dtype()) { + auto dtype = binary_dtype_infer(ns_add, x->ns, y->ns, 0, 0); + auto xp = make_unary(x, dtype); + auto yp = make_unary(y, dtype); + auto zp = make_binary(xp, yp, op); + forward(zp); + return; + } + #endif + flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::element); diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h index 56f70789..f5c85d3b 100644 --- a/python/jittor/src/type/fp16_compute.h +++ b/python/jittor/src/type/fp16_compute.h @@ -11,12 +11,17 @@ #include #include +#ifndef IS_ROCM #include +#endif namespace jittor { typedef __half float16; +#ifndef IS_ROCM typedef __nv_bfloat16 bfloat16; +#endif + #if CUDA_ARCH >= 800 inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); } @@ -32,7 +37,7 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)= 800 inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); } inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); } @@ -45,7 +50,7 @@ inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a) __device__ inline typename std::enable_if::type diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc index 092eddd2..ed4d4c0a 100644 --- a/python/jittor/src/var_holder.cc +++ b/python/jittor/src/var_holder.cc @@ -215,11 +215,14 @@ inline static void cast_item_data(ItemData& data) { auto* fp16 = (float16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(fp16[0]); - } else if (data.dtype == ns_bfloat16) { + } + #ifndef IS_ROCM + else if (data.dtype == ns_bfloat16) { auto* bf16 = (bfloat16*)&data; auto* fp32 = (float32*)&data; fp32[0] = float32(bf16[0]); } + #endif data.dtype = ns_float32; } diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py new file mode 100644 index 00000000..41a6d748 --- /dev/null +++ b/python/jittor/test/test_aclop.py @@ -0,0 +1,463 @@ +import unittest +import jittor as jt +from .test_core import expect_error +import numpy as np +from jittor import init, Module +import numpy as np + + +@unittest.skipIf(not jt.compiler.has_acl, "No ACL found") +class TestACL(unittest.TestCase): + + @jt.flag_scope(use_acl=1) + def test_getitem(self): + a = jt.ones(100, 2) + b = a[0:2, 0:2] + np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]]) + print("test getitem success") + + @jt.flag_scope(use_acl=1) + def test_getitem_neg(self): + a = jt.ones(2, 3, 2) + b = a[0:1,0:-2] + np.testing.assert_allclose(b.numpy(), [[[1,1]]]) + print("test getitem neg success") + + @jt.flag_scope(use_acl=1) + def test_setitem(self): + a = jt.ones(2, 2) + b = jt.Var(0) + a[0:1, 0:1] = b + np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) + print("test setitem success") + + @jt.flag_scope(use_acl=1) + def test_setitem_neg(self): + a = jt.ones(2, 3, 2) + b = jt.Var(0) + a[0:1, 0:-2] = b + np.testing.assert_allclose(a.numpy(), [[[0,0],[1,1],[1,1]],[[1,1],[1,1],[1,1]]]) + print("test setitem neg success") + + @jt.flag_scope(use_acl=1) + def 
test_getitem_grad(self): + a = jt.ones(2, 2) + b = a[0:1, 0:1] + optimizer = jt.optim.SGD([a], 0.1) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[1, 0], [0, 0]]) + print("test getitem grad success") + + @jt.flag_scope(use_acl=1) + def test_setitem_grad(self): + a = jt.ones(3, 3) + b = jt.ones(2, 2) + a[0:2, 0:2] = b * 2 + optimizer = jt.optim.SGD([a, b], 0.1) + loss = a.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), + [[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) + print("test setitem grad success") + + @jt.flag_scope(use_acl=1) + def test_concat(self): + a = jt.ones(2, 2) + b = jt.ones(2, 2) + c = jt.concat([a, b], 0) + np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]]) + print("test concat success") + + @jt.flag_scope(use_acl=1) + def test_concat_neg(self): + a = jt.ones(2, 2) + b = jt.ones(2, 2) + c = jt.concat([a, b], -1) + np.testing.assert_allclose(c.numpy(), [[1,1,1,1],[1,1,1,1]]) + print("test concat neg success") + + @jt.flag_scope(use_acl=1) + def test_concat_zero_dim(self): + a = jt.ones([]) + b = jt.zeros([]) + c = jt.concat([a, b], 0) + np.testing.assert_allclose(c.numpy(), [1,0]) + print("test concat zero dim success") + + @jt.flag_scope(use_acl=1) + def test_maxpool_grad(self): + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) + max_pool = jt.nn.Pool(2, op='maximum') + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]]]) + print("test maxpool grad success") + + @jt.flag_scope(use_acl=1) + def test_triu(self): + a = jt.ones(3, 3) + b = jt.triu_(a, 0) + c = jt.triu_(a, 1) + d = jt.triu_(a, -1) + np.testing.assert_allclose(b.numpy(), + [[1, 1, 1], [0, 1, 1], [0, 0, 1]]) + np.testing.assert_allclose(c.numpy(), + [[0, 1, 1], [0, 0, 1], [0, 0, 0]]) + np.testing.assert_allclose(d.numpy(), + [[1, 1, 1], [1, 1, 1], [0, 1, 1]]) + print("test triu success") + + @jt.flag_scope(use_acl=1) + def test_bmm(self): + a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) + b = jt.bmm(a, a) + np.testing.assert_allclose( + b.numpy(), [[[7, 10], [15, 22]], [[8, 5], [20, 13]], [[9, 8], [16, 17]]]) + print("test bmm success") + + @jt.flag_scope(use_acl=1) + def test_matmul(self): + a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) + b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) + c = jt.matmul(a, b) + np.testing.assert_allclose(c.numpy(), + [[[10, 10], [26, 26], [42, 42], [58, 58]]]) + print("test matmul success") + + @jt.flag_scope(use_acl=1) + def test_maxpool(self): + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) + max_pool = jt.nn.Pool(2, op='maximum') + np.testing.assert_allclose(max_pool(a).numpy(), [[[[3, 4], [4, 3]]]]) + print("test maxpool success") + + @jt.flag_scope(use_acl=1) + def test_transpose(self): + a = jt.float32([[[1,2],[3,4]]]) + b = a.transpose(0, 2) + np.testing.assert_allclose(b.numpy(), [[[1], [3]], [[2], [4]]]) + print("test transpose success") + + @jt.flag_scope(use_acl=1) + def test_transpose_neg(self): + a = jt.float32([[[1,2],[3,4]]]) + b = a.transpose(1, -1) + 
np.testing.assert_allclose(b.numpy(), [[[1,3], [2,4]]]) + print("test transpose neg success") + + @jt.flag_scope(use_acl=1) + def test_matmul_grad(self): + a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) + b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = jt.matmul(a, b).sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]]) + np.testing.assert_allclose(res_b.numpy(), [[28, 28], [32, 32], [36, 36], [40, 40]]) + print("test matmul grad success") + + @jt.flag_scope(use_acl=1) + def test_bmm_grad(self): + a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) + optimizer = jt.optim.SGD([a], 0.1) + c = jt.bmm(a, a) + loss = c.sum() + + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[7, 11], [9, 13]], [[9, 13], [7, 11]], [[8, 12], [8, 12]]]) + print("test bmm grad success") + + @jt.flag_scope(use_acl=1) + def test_avgpool(self): + a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) + avg_pool = jt.nn.Pool(2, op='mean') + b = avg_pool(a) + np.testing.assert_allclose(b.numpy(), [[[[2, 3], [3, 2]]]]) + print("test avgpool success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool_1 = jt.nn.AdaptiveMaxPool2d((2, 2)) + pool_2 = jt.nn.AdaptiveMaxPool2d((3, 4)) + b = pool_1(a) + c = pool_2(a) + np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) + np.testing.assert_allclose(c.numpy(), [[[[5,6,7,8],[9,10,11,12],[13,14,15,16]]]]) + print("test adaptive_maxpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d_grad_1(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) + print("test adaptive_maxpool2d_1 grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d_grad_2(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + max_pool = jt.nn.AdaptiveMaxPool2d((1, 3)) + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]]]) + print("test adaptive_maxpool2d_2 grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool_1 = jt.nn.AdaptiveAvgPool2d((2, 2)) + pool_2 = jt.nn.AdaptiveAvgPool2d((1, 3)) + b = pool_1(a) + c = pool_2(a) + np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) + np.testing.assert_allclose(c.numpy(), [[[[7.5, 8.5, 9.5]]]]) + print("test adaptive_avgpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d_grad(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 
16]]]]) + avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = avg_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) + print("test adaptive_avgpool2d grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d_grad_2(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + avg_pool = jt.nn.AdaptiveAvgPool2d((1, 3)) + optimizer = jt.optim.SGD([a], 0.1) + b = avg_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125], + [0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125]]]]) + print("test adaptive_avgpool2d_2 grad success") + + @jt.flag_scope(use_acl=1) + def test_index(self): + a = jt.rand(2, 3) + [s1, s2] = jt.index(a.shape) + np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]]) + np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]]) + print("test index success") + + @jt.flag_scope(use_acl=1) + def test_gather(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) + b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 2], [3, 2]]) + b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) + print("test gather success") + + @jt.flag_scope(use_acl=1) + def test_gather_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[1, 2], [1, 0]]) + print("test gather grad success") + + @jt.flag_scope(use_acl=1) + def test_gather_grad_neg(self): + a = jt.float32([[4, 3], [2, 1]]) + optimizer = jt.optim.SGD([a], 0.1) + b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) + print("test gather grad neg success") + + @jt.flag_scope(use_acl=1) + def test_scatter_add(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[0, 0], [0, 0]]) + b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]]) + print("test scatter add success") + + @jt.flag_scope(use_acl=1) + def test_scatter_multi(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + b = jt.scatter(b, 0, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") + np.testing.assert_allclose(b.numpy(), [[5, 48], [21, 8]]) + print("test scatter multiply success") + + @jt.flag_scope(use_acl=1) + def test_scatter_add_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.float32([[0, 0], [0, 0]]) + optimizer = jt.optim.SGD([a, b], 0.1) + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + loss = c.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + 
np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) + print("test scatter add grad success") + + @jt.flag_scope(use_acl=1) + def test_scatter_mult_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.float32([[5, 6], [7, 8]]) + optimizer = jt.optim.SGD([a, b], 0.1) + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") + loss = c.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 6], [0, 6]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 8], [0, 0]]) + print("test scatter mult grad success") + + @jt.flag_scope(use_acl=1) + def test_where(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.ones(2, 2) + c = jt.where(a > 2, a, b) + np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]]) + print("test where success") + + @jt.flag_scope(use_acl=1) + def test_where_2(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + cond = jt.array([[1, 0], [0, 1]]) + c = jt.where(cond, a, b) + np.testing.assert_allclose(c.numpy(), [[1, 6], [7, 4]]) + print("test where_2 success") + + @jt.flag_scope(use_acl=1) + def test_where_grad(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[5, 6], [7, 8]]) + cond = jt.array([[1, 0], [0, 1]]) + c = jt.where(cond, a, b) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = c.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) + print("test where grad success") + + @jt.flag_scope(use_acl=1) + def test_where_grad_2(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.array([[2., 2.], [2., 2.]]) + c = jt.where(a > 2, a, b) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = c.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) + print("test where grad 2 success") + + @jt.flag_scope(use_acl=1) + def test_flip(self): + a = jt.array([[1., 2.], [3., 4.]]) + b = a.flip((0, 1)) + np.testing.assert_allclose(b.numpy(), [[4, 3], [2, 1]]) + print("test flip success") + + @jt.flag_scope(use_acl=1) + def test_flip_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = a.flip((0, 1)) + loss = b.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]]) + print("test flip grad success") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py new file mode 100644 index 00000000..4c1a6c53 --- /dev/null +++ b/python/jittor/test/test_complex.py @@ -0,0 +1,485 @@ +import jittor as jt +from jittor.nn import ComplexNumber +import unittest +import numpy as np +from functools import partial + +__skip_torch_test = False +try: + import torch +except: + __skip_torch_test = True + +class TestResultAndGrad: + def flatten_list(self, list_like): + results = [] + if isinstance(list_like, (list, tuple)): + for x in list_like: + results.extend(self.flatten_list(x)) + return results + else: + return [list_like] + + def 
check_results(self, rlist1, rlist2): + assert len(rlist1) == len(rlist2) + for r1, r2 in zip(rlist1, rlist2): + assert r1.shape == r2.shape + assert np.allclose(r1, r2, rtol=1e-3, atol=1e-3) + + def grad_jittor(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + if isinstance(i, ComplexNumber): + g = jt.grad(loss, i.value, retain_graph=True) + grads.append(g[..., 0].numpy() + 1j * g[..., 1].numpy()) + else: + g = jt.grad(loss, i, retain_graph=True) + grads.append(g.numpy()) + return grads + + def grad_torch(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + g = torch.autograd.grad(loss, i, retain_graph=True)[0] + grads.append(g.detach().cpu().numpy()) + return grads + + def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs): + def _np_to_jittor(x): + if isinstance(x, np.ndarray): + if x.dtype == np.complex64 or x.dtype == np.complex128: + nx = np.stack([np.real(x), np.imag(x)], axis=-1) + return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True) + elif x.dtype == np.float32 or x.dtype == np.float64: + return jt.array(x, dtype=jt.float32) + else: + assert False + elif isinstance(x, (list, tuple)): + nx = [_np_to_jittor(vx) for vx in x] + if isinstance(x, tuple): + return tuple(nx) + return nx + else: + assert False + def _jittor_to_np(x): + if isinstance(x, jt.Var): + return x.numpy() + elif isinstance(x, ComplexNumber): + return x.real.numpy() + 1j * x.imag.numpy() + assert False + ninput_list = [_np_to_jittor(x) for x in input_list] + + if key_names != None: + assert len(ninput_list) == len(key_names) + nkwargs = kwargs.copy() + for k, v in zip(key_names, ninput_list): + nkwargs[k] = v + output_list = op(**nkwargs) + else: + output_list = op(*ninput_list, **kwargs) + if isinstance(output_list, (jt.Var, ComplexNumber)): + output_list = [output_list] + output_list = self.flatten_list(output_list) + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = o.value if isinstance(o, ComplexNumber) else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(jt.sum(no * jt.array(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = o.value if isinstance(o, ComplexNumber) else o + assert w.shape == no.shape + losses.append(jt.sum(no * jt.array(w))) + output_list = [_jittor_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs): + def _np_to_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x).requires_grad_(True) + elif isinstance(x, (list, tuple)): + nx = [_np_to_torch(vx) for vx in x] + if isinstance(x, tuple): + return tuple(nx) + return nx + else: + assert False + def _torch_to_np(x:torch.Tensor) -> np.ndarray: + return x.detach().cpu().numpy() + ninput_list = [_np_to_torch(x) for x in input_list] + if key_names != None: + assert len(ninput_list) == len(key_names) + nkwargs = kwargs.copy() + for k, v in zip(key_names, ninput_list): + nkwargs[k] = v + output_list = op(**nkwargs) + else: + output_list = op(*ninput_list, **kwargs) + if isinstance(output_list, torch.Tensor): + output_list = [output_list] + output_list = self.flatten_list(output_list) + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(torch.sum(no * 
torch.from_numpy(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + assert w.shape == no.shape + losses.append(torch.sum(no * torch.from_numpy(w))) + output_list = [_torch_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs): + weights = None + jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs) + torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs) + self.check_results(jittor_output, torch_output) + + if check_grad: + jittor_grads = self.grad_jittor(jittor_input, jittor_losses) + torch_grads = self.grad_torch(torch_input, torch_losses) + self.check_results(jittor_grads, torch_grads) + + def check_op_with_numpy(self, jittor_op, numpy_op, input_list): + _, jittor_output, _, _ = self.run_jittor_op(jittor_op, input_list, None) + numpy_output = numpy_op(*input_list) + if isinstance(numpy_output, np.ndarray): + numpy_output = [numpy_output] + + self.check_results(jittor_output, numpy_output) + +@unittest.skipIf(__skip_torch_test, "No Torch found") +class TestComplexLinalg(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def test_complex_matmul(self): + s1 = (50, 200) + s2 = (200, 50) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_matmul_batch(self): + s1 = (10, 50, 30) + s2 = (10, 30, 40) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_inv(self): + s1 = (200, 200) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_inv_batch(self): + s1 = (10, 50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_eig(self): + # Unstable + s1 = (20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_eig_batch(self): + # Unstable + s1 = (5, 10, 10) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_qr(self): + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_qr_batch(self): + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_svd(self): + # Unstable + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + + def test_complex_svd_batch(self): + # Unstable + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + +class TestTensordot(unittest.TestCase, 
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestTensordot(unittest.TestCase, TestResultAndGrad):
+    def random_complex_matrix(self, shape):
+        r = np.random.randn(*shape)
+        i = np.random.randn(*shape)
+        return r + 1j * i
+
+    def random_real_matrix(self, shape):
+        return np.random.randn(*shape)
+
+    def test_complex_tensordot_numberdim(self):
+        s1 = (3, 4, 5)
+        s2 = (4, 5, 6)
+        dims = 2
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)
+
+    def test_complex_tensordot_tupledim(self):
+        s1 = (3, 5, 4, 6)
+        s2 = (6, 4, 5, 3)
+        dims = ([2, 1, 3], [1, 2, 0])
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)
+
+    def test_real_tensordot_numberdim(self):
+        s1 = (3, 4, 5)
+        s2 = (4, 5, 6)
+        dims = 2
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)
+
+    def test_real_tensordot_tupledim(self):
+        s1 = (3, 5, 4, 6)
+        s2 = (6, 4, 5, 3)
+        dims = ([2, 1, 3], [1, 2, 0])
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestKron(unittest.TestCase, TestResultAndGrad):
+    def random_complex_matrix(self, shape):
+        r = np.random.randn(*shape)
+        i = np.random.randn(*shape)
+        return r + 1j * i
+
+    def random_real_matrix(self, shape):
+        return np.random.randn(*shape)
+
+    def test_complex_firstlarge(self):
+        s1 = (2, 3, 4)
+        s2 = (5, 2)
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+    def test_complex_second_large(self):
+        s1 = (2, 3)
+        s2 = (5, 2, 4)
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+    def test_real_firstlarge(self):
+        s1 = (2, 3, 4)
+        s2 = (5, 2)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+    def test_real_second_large(self):
+        s1 = (2, 3)
+        s2 = (5, 2, 4)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
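In the tuple-dims tests above, dims=([2, 1, 3], [1, 2, 0]) pairs axis 2/1/3 of the first operand with axis 1/2/0 of the second. NumPy's tensordot follows the same convention, so the expected output shape can be checked independently (illustrative only, not part of the patch):

    import numpy as np

    a = np.random.randn(3, 5, 4, 6)
    b = np.random.randn(6, 4, 5, 3)
    out = np.tensordot(a, b, axes=([2, 1, 3], [1, 2, 0]))
    print(out.shape)   # (3, 3): the uncontracted axis of a, then of b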
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
+    def random_complex_matrix(self, shape):
+        r = np.random.randn(*shape)
+        i = np.random.randn(*shape)
+        return r + 1j * i
+
+    def random_real_matrix(self, shape):
+        # note: this returns an all-ones matrix of the given shape
+        return np.random.randn(*shape) * 0.0 + 1.0
+
+    def test_real_jvp_exp(self):
+        def exp_reducer(x):
+            return x.exp().sum(dim=1)
+        s1 = (5, 6)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s1)
+        inputs = [m1, m2]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
+            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False)
+
+    def test_complex_jvp_exp(self):
+        def exp_reducer(x):
+            return x.exp().sum(1)
+        s1 = (5, 6)
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s1)
+        inputs = [m1, m2]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
+            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+    def test_real_jvp_add(self):
+        w1, w2 = np.random.rand(), np.random.rand()
+        def adder(x, y):
+            return w1 * x + w2 * y
+        s1 = (5, 6)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s1)
+        m3 = self.random_real_matrix(s1)
+        m4 = self.random_real_matrix(s1)
+        inputs = [(m1, m2), (m3, m4)]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
+            partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+    def test_complex_jvp_add(self):
+        w1r, w1i = np.random.rand(), np.random.rand()
+        w2r, w2i = np.random.rand(), np.random.rand()
+        def adder_pt(x, y):
+            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
+        def adder_jt(x, y):
+            w1 = ComplexNumber(real=jt.array(w1r).reshape(1, 1), imag=jt.array(w1i).reshape(1, 1))
+            w2 = ComplexNumber(real=jt.array(w2r).reshape(1, 1), imag=jt.array(w2i).reshape(1, 1))
+            return w1 * x + w2 * y
+        s1 = (5, 6)
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s1)
+        m3 = self.random_complex_matrix(s1)
+        m4 = self.random_complex_matrix(s1)
+        inputs = [(m1, m2), (m3, m4)]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
+            partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+    def test_real_vjp_exp(self):
+        def exp_reducer(x):
+            return x.exp().sum(dim=1)
+        s1 = (5, 6)
+        s2 = (5,)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.vjp, func=exp_reducer),
+            partial(torch.autograd.functional.vjp, func=exp_reducer),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False)
+
+    def test_complex_vjp_exp(self):
+        def exp_reducer(x):
+            return x.exp().sum(1)
+        s1 = (5, 6)
+        s2 = (5,)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s2)
+        inputs = [m1, m2]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
+            partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+    def test_real_vjp_add(self):
+        w1, w2 = np.random.rand(), np.random.rand()
+        def adder(x, y):
+            return w1 * x + w2 * y
+        s1 = (5, 6)
+        m1 = self.random_real_matrix(s1)
+        m2 = self.random_real_matrix(s1)
+        m3 = self.random_real_matrix(s1)
+        inputs = [(m1, m2), m3]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
+            partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+    def test_complex_vjp_add(self):
+        w1r, w1i = np.random.rand(), np.random.rand()
+        w2r, w2i = np.random.rand(), np.random.rand()
+        def adder_pt(x, y):
+            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
+        def adder_jt(x, y):
+            w1 = ComplexNumber(real=jt.array(w1r).reshape(1, 1), imag=jt.array(w1i).reshape(1, 1))
+            w2 = ComplexNumber(real=jt.array(w2r).reshape(1, 1), imag=jt.array(w2i).reshape(1, 1))
+            return w1 * x + w2 * y
+        s1 = (5, 6)
+        m1 = self.random_complex_matrix(s1)
+        m2 = self.random_complex_matrix(s1)
+        m3 = self.random_complex_matrix(s1)
+        inputs = [(m1, m2), m3]
+        self.check_op_with_torch(
+            partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
+            partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
+            inputs,
+            jittor_knames=['inputs', 'v'],
+            torch_knames=['inputs', 'v'],
+            check_grad=False,
+        )
+
+if __name__ == "__main__":
+    unittest.main()
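Note the shape convention exercised above: the jvp tests pass a tangent v with the shape of the inputs, while the vjp tests pass a cotangent v with the shape of the output (here (5,) after the sum over dim 1), matching torch.autograd.functional. Assuming jt.gradfunctional mirrors that API and its (output, grad-product) return convention, as the comparisons above imply, a usage sketch looks like this (illustrative only, not part of the patch):

    import jittor as jt

    def exp_reducer(x):
        return x.exp().sum(1)   # maps shape (5, 6) -> (5,)

    x = jt.randn(5, 6)
    out, jvp_val = jt.gradfunctional.jvp(func=exp_reducer, inputs=x, v=jt.randn(5, 6), create_graph=True)  # tangent: input shape
    out2, vjp_val = jt.gradfunctional.vjp(func=exp_reducer, inputs=x, v=jt.randn(5))                       # cotangent: output shape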
diff --git a/python/jittor/test/test_var.py b/python/jittor/test/test_var.py
index 116df6c7..c18b041b 100644
--- a/python/jittor/test/test_var.py
+++ b/python/jittor/test/test_var.py
@@ -46,6 +46,33 @@ def test_norm(self):
         np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6)
         np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6)
 
+    def test_std_with_dim(self):
+        x = np.random.randn(100, 1000).astype(np.float32)
+        jt_x = jt.array(x)
+        tc_x = torch.from_numpy(x)
+        np.testing.assert_allclose(jt_x.std(dim=-1).numpy(), tc_x.std(dim=-1).numpy(), 1e-4)
+        np.testing.assert_allclose(jt_x.std(dim=0, keepdim=True).numpy(), tc_x.std(dim=0, keepdim=True).numpy(), 1e-4)
+
+    def test_diagonal(self):
+        x = np.reshape(np.arange(5*6*7*8), (5,6,7,8))
+        jt_x = jt.array(x)
+        tc_x = torch.from_numpy(x)
+        def __assert_equal(a: np.ndarray, b: np.ndarray, rtol=1e-6, atol=1e-6):
+            assert a.shape == b.shape, f"{a.shape}!={b.shape}"
+            np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
+        __assert_equal(jt.misc.diagonal(jt_x, 0, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=0, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, -1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-1, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, -2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-2, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, -6, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-6, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=2, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 7, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=7, dim1=1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=-2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-2, dim2=-1).numpy(), tc_x.diagonal(offset=1, dim1=-2, dim2=-1).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=0, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=0, dim2=-2).numpy())
+        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=2, dim2=1).numpy(), tc_x.diagonal(offset=1, dim1=2, dim2=1).numpy())
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/setup.py b/setup.py
index 741165d2..beec50f6 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@
     package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']},
     # include_package_data=True,
     install_requires=[
-        "numpy",
+        "numpy<2.0",
         "tqdm",
         "pillow",
         "astunparse",