diff --git a/README.md b/README.md
index 414eb868..90601b43 100644
--- a/README.md
+++ b/README.md
@@ -382,10 +382,10 @@ Email: jittor@qq.com
File an issue: https://github.com/Jittor/jittor/issues
-QQ Group: 761222083
+QQ Group: 836860279
-
+
## The Team
diff --git a/doc/source/jittor.nn.md b/doc/source/jittor.nn.md
index 0ef2b09f..0e5cd35d 100644
--- a/doc/source/jittor.nn.md
+++ b/doc/source/jittor.nn.md
@@ -10,7 +10,7 @@ jittor.nn
.. automodule:: jittor.nn
:imported-members:
- :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool2d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d
+ :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool3d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d
:undoc-members:
.. autoclass:: jittor.nn.ReLU
diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index cfd39c42..c2df0aa7 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
-__version__ = '1.3.9.6'
+__version__ = '1.3.9.10'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@@ -428,7 +428,9 @@ def random(shape, dtype="float32", type="uniform"):
jt.Var([[0.96788853 0.28334728 0.30482838]
[0.46107793 0.62798643 0.03457401]], dtype=float32)
'''
-
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
ret = ops.random(shape, "float32", type)
## TODO: move those code to core
#if dtype in ["float16", "bfloat16"]:
@@ -484,6 +486,9 @@ def ones(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(1, dtype).broadcast(shape)
def new_ones(x, size):
@@ -515,6 +520,9 @@ def zeros(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(0, dtype).broadcast(shape)
def new_zeros(x, size):
@@ -547,6 +555,9 @@ def full(shape,val,dtype="float32"):
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(val, dtype).broadcast(shape)
def new_full(x, size, val):
@@ -641,14 +652,22 @@ def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
return sqr
Var.var = var
-def std(x):
- matsize=1
- for i in x.shape:
- matsize *= i
- out=(x-x.mean()).sqr().sum()
- out=out/(matsize-1)
- out=out.maximum(1e-6).sqrt()
- return out
+def std(x, dim=None, keepdim=False):
+ if dim is None:
+ matsize=1
+ for i in x.shape:
+ matsize *= i
+ out=(x-x.mean()).sqr().sum()
+ out=out/(matsize-1)
+ out=out.maximum(1e-6).sqrt()
+ return out
+ else:
+ dimsize=x.size(dim)
+ mean=jt.mean(x, dim, keepdim=True)
+ out=(x - mean).sqr().sum(dim=dim, keepdim=keepdim)
+ out=out/(dimsize-1)
+ out=out.maximum(1e-6).sqrt()
+ return out
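+# Example (sketch): jt.randn(4, 3).std() reduces over all elements, while
+# jt.randn(4, 3).std(dim=1) gives a Var of shape (4,), or (4, 1) with keepdim=True.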
Var.std = std
def norm(x, p=2, dim=-1, keepdims=False, eps=1e-30, keepdim=False):
@@ -687,6 +706,8 @@ def flatten(input, start_dim=0, end_dim=-1):
start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function"
+ if len(in_shape) <= end_dim:
+ raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})")
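+    # e.g. flatten(jt.randn(2, 3, 4), start_dim=1) collapses dims 1..-1 into one, giving shape (2, 12)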
out_shape = []
for i in range(0,start_dim,1): out_shape.append(in_shape[i])
dims = 1
@@ -917,6 +938,9 @@ def randn(*size, dtype="float32", requires_grad=True) -> Var:
[-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
+ for dim in size:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}")
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
@@ -1013,6 +1037,9 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
[1 1 1]], dtype=int32)
'''
if high is None: low, high = 0, low
+ for dim in shape:
+ if dim < 0:
+ raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor_int(v)
return v.astype(dtype)
@@ -1437,9 +1464,17 @@ def requires_grad_(self, requires_grad=True):
def __hooked_call__(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
- self.__fhook2__(self, args, kw)
+ args_kw_result = self.__fhook2__(self, args, kw)
else:
- self.__fhook2__(self, args)
+ args_kw_result = self.__fhook2__(self, args)
+ if args_kw_result is not None:
+ if isinstance(args_kw_result, tuple) and len(args_kw_result) == 2:
+ args, kw = args_kw_result
+ else:
+ raise RuntimeError(
+ "forward pre-hook must return None or a tuple "
+ f"of (new_args, new_kwargs), but got {args_kw_result}."
+ )
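+        # Sketch: a pre-hook stored in __fhook2__ may rewrite the call by returning
+        # (new_args, new_kwargs), e.g.
+        #     def prehook(mod, args, kw=None):
+        #         return (args[0].float32(),) + args[1:], (kw or {})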
if hasattr(self, "__bihook__"):
if len(kw):
LOG.w("backward hook not support kw")
@@ -1458,9 +1493,11 @@ def __hooked_call__(self, *args, **kw):
ret = grad_hooker(ret, self.__bohook__)
if hasattr(self, "__fhook__"):
if len(kw):
- self.__fhook__(self, args, ret, kw)
+ res = self.__fhook__(self, args, ret, kw)
else:
- self.__fhook__(self, args, ret)
+ res = self.__fhook__(self, args, ret)
+ if res is not None:
+ ret = res
return ret
def _place_hooker(self):
@@ -1595,6 +1632,8 @@ def load_parameters(self, params):
else:
if hasattr(v, k):
v = getattr(v, k)
+ if v is None:
+ continue
assert isinstance(v, (Module, Var)), \
f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
else:
@@ -2119,6 +2158,7 @@ def is_var(v):
from . import optim
from . import dataset
from . import init
+from . import gradfunctional
dtype = NanoString
@@ -2152,3 +2192,7 @@ def inplace_wrapper(new_k, prev_func):
from . import math_util
from .math_util import *
from . import distributions
+
+if jt.compiler.has_acl:
+ from jittor.extern.acl.acl_compiler import change_function
+ change_function()
diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi
index b849af4c..6cfce692 100644
--- a/python/jittor/__init__.pyi
+++ b/python/jittor/__init__.pyi
@@ -1,7 +1,7 @@
from jittor_core import *
from jittor_core.ops import *
from .misc import *
-from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
+from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat
diff --git a/python/jittor/attention.py b/python/jittor/attention.py
index a8a486cb..5ae59c1b 100644
--- a/python/jittor/attention.py
+++ b/python/jittor/attention.py
@@ -9,168 +9,575 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
-import jittor as jt
-from jittor import init, Module, nn
-import numpy as np
+from typing import Optional, Tuple, List
+import warnings
import math
+import jittor as jt
+from jittor import Var
+from jittor.nn import Module, Linear, softmax, pad, linear, dropout
+from jittor.init import xavier_uniform_, xavier_gauss_, constant_
+
+def _canonical_mask(
+ mask: Optional[Var],
+ mask_name: str,
+ other_type,
+ other_name: str,
+ target_type,
+ check_other: bool = True,
+) -> Optional[Var]:
+
+ if mask is not None:
+ _mask_dtype = mask.dtype
+ _mask_is_float = mask.dtype == jt.float16 or mask.dtype == jt.float32 or mask.dtype == jt.float64
+ if _mask_dtype != jt.bool and not _mask_is_float:
+ raise AssertionError(
+ f"only bool and floating types of {mask_name} are supported")
+ if check_other and other_type is not None:
+ if _mask_dtype != other_type:
+ warnings.warn(
+ f"Support for mismatched {mask_name} and {other_name} "
+ "is deprecated. Use same type for both instead."
+ )
+ if not _mask_is_float:
+ # WARNING(514flowey): Check Here
+ new_mask = jt.zeros_like(mask, dtype=target_type)
+ new_mask[mask] = float("-inf")
+ mask = new_mask
+ return mask
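+# e.g. a bool mask [[False, True]] is converted to the float mask [[0., -inf]] in
+# target_type, so it can be added directly to attention scores.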
+
+def _none_or_dtype(input: Optional[Var]):
+ if input is None:
+ return None
+ elif isinstance(input, jt.Var):
+ return input.dtype
+    raise RuntimeError("input to _none_or_dtype() must be None or jt.Var")
+
+def baddbmm(input_var:jt.Var, batch1:jt.Var, batch2:jt.Var, beta=1, alpha=1) -> jt.Var:
+ # WARNING(514flowey): Check here
+ return beta * input_var + alpha * (batch1 @ batch2)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> jt.Var:
+ # Efficient implementation equivalent to the following:
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = jt.zeros(L, S, dtype=query.dtype)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = jt.ones(L, S, dtype=jt.bool).tril(diagonal=0)
+ attn_bias[jt.logical_not(temp_mask)] = float("-inf")
+ # attn_bias.to(query.dtype)
+ attn_bias = jt.array(attn_bias, query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == jt.bool:
+            attn_bias[jt.logical_not(attn_mask)] = float("-inf")
+ else:
+ attn_bias += attn_mask
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = softmax(attn_weight, dim=-1)
+ attn_weight = dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
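+# Shape sketch: query (..., L, E) with key/value (..., S, E) yields (..., L, E);
+# e.g. q = jt.randn(2, 4, 5, 8) and k = v = jt.randn(2, 4, 7, 8) give an output of shape (2, 4, 5, 8).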
+
+def _mha_shape_check(query: Var, key: Var, value: Var,
+ key_padding_mask: Optional[Var], attn_mask: Optional[Var], num_heads: int):
+ if query.dim() == 3:
+ is_batched = True
+ assert key.dim() == 3 and value.dim() == 3, \
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 2, \
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ elif query.dim() == 2:
+ is_batched = False
+ assert key.dim() == 2 and value.dim() == 2, \
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 1, \
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ if attn_mask.dim() == 3:
+ expected_shape = (num_heads, query.shape[0], key.shape[0])
+ assert attn_mask.shape == expected_shape, \
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+ else:
+ raise AssertionError(
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+
+ return is_batched
+
+def _in_projection_packed(
+ q: Var,
+ k: Var,
+ v: Var,
+ w: Var,
+ b: Optional[Var] = None,
+) -> List[Var]:
+ E = q.size(-1)
+ if k is v:
+ if q is k:
+ # self-attention
+ proj = linear(q, w, b)
+            # reshaping to (3, E) rather than (E, 3) is deliberate for better memory coalescing and to keep the same order as chunk()
+ # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
+ nshape = proj.shape[:-1] + (3, E)
+ proj = proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ return proj[0], proj[1], proj[2]
+ else:
+ # encoder-decoder attention
+ w_q, w_kv = w.split([E, E * 2])
+ if b is None:
+ b_q = b_kv = None
+ else:
+ b_q, b_kv = b.split([E, E * 2])
+ q_proj = linear(q, w_q, b_q)
+ kv_proj = linear(k, w_kv, b_kv)
+            # reshaping to (2, E) rather than (E, 2) is deliberate for better memory coalescing and to keep the same order as chunk()
+ # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
+ nshape = kv_proj.shape[:-1] + (2, E)
+ kv_proj = kv_proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ return (q_proj, kv_proj[0], kv_proj[1])
+ else:
+ w_q, w_k, w_v = w.chunk(3)
+ if b is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = b.chunk(3)
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+def _in_projection(
+ q: Var,
+ k: Var,
+ v: Var,
+ w_q: Var,
+ w_k: Var,
+ w_v: Var,
+ b_q: Optional[Var] = None,
+ b_k: Optional[Var] = None,
+ b_v: Optional[Var] = None,
+) -> Tuple[Var, Var, Var]:
+ Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
+ assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
+ assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
+ assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
+ assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
+ assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
+ assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+def multi_head_attention_forward(
+ query: Var,
+ key: Var,
+ value: Var,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[Var],
+ in_proj_bias: Optional[Var],
+ bias_k: Optional[Var],
+ bias_v: Optional[Var],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Var,
+ out_proj_bias: Optional[Var],
+ training: bool = True,
+ key_padding_mask: Optional[Var] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Var] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Var] = None,
+ k_proj_weight: Optional[Var] = None,
+ v_proj_weight: Optional[Var] = None,
+ static_k: Optional[Var] = None,
+ static_v: Optional[Var] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
+) -> Tuple[Var, Optional[Var]]:
+
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
+
+    # For unbatched input, we unsqueeze at the expected batch dim to pretend that the
+    # input is batched, run the computation, and squeeze the batch dimension before
+    # returning so that the output doesn't carry this temporary batch dimension.
+ if not is_batched:
+ # unsqueeze if the input is unbatched
+ query = query.unsqueeze(1)
+ key = key.unsqueeze(1)
+ value = value.unsqueeze(1)
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.unsqueeze(0)
+
+ # set up shape vars
+ tgt_len, bsz, embed_dim = query.shape
+ src_len, _, _ = key.shape
+
+ key_padding_mask = _canonical_mask(
+ mask=key_padding_mask,
+ mask_name="key_padding_mask",
+ other_type=_none_or_dtype(attn_mask),
+ other_name="attn_mask",
+ target_type=query.dtype
+ )
+
+ if is_causal and attn_mask is None:
+ raise RuntimeError(
+ "Need attn_mask if specifying the is_causal hint. "
+ "You may use the Transformer module method "
+ "`generate_square_subsequent_mask` to create this mask."
+ )
+
+ if is_causal and key_padding_mask is None and not need_weights:
+        # When we have a key_padding_mask or need weights, we need attn_mask;
+        # otherwise, the is_causal hint is passed through as the is_causal
+        # indicator to SDPA.
+ attn_mask = None
+ else:
+ attn_mask = _canonical_mask(
+ mask=attn_mask,
+ mask_name="attn_mask",
+ other_type=None,
+ other_name="",
+ target_type=query.dtype,
+ check_other=False,
+ )
+
+ if key_padding_mask is not None:
+ # We have the attn_mask, and use that to merge kpm into it.
+ # Turn off use of is_causal hint, as the merged mask is no
+ # longer causal.
+ is_causal = False
+
+ assert embed_dim == embed_dim_to_check, \
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ if isinstance(embed_dim, jt.Var):
+ # embed_dim can be a tensor when JIT tracing
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+ else:
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+ if use_separate_proj_weight:
+ # allow MHA to have different embedding dimensions when separate projection weights are used
+ assert key.shape[:2] == value.shape[:2], \
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ else:
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
+ #
+ # compute in-projection
+ #
+ if not use_separate_proj_weight:
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+ else:
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
+ if in_proj_bias is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+
+ # prep attention mask
+
+ if attn_mask is not None:
+ # ensure attn_mask's dim is 3
+ if attn_mask.dim() == 2:
+ correct_2d_size = (tgt_len, src_len)
+ if attn_mask.shape != correct_2d_size:
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ attn_mask = attn_mask.unsqueeze(0)
+ elif attn_mask.dim() == 3:
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
+ if attn_mask.shape != correct_3d_size:
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ else:
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
+ # add bias along batch dimension (currently second)
+ if bias_k is not None and bias_v is not None:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ k = jt.concat([k, bias_k.repeat(1, bsz, 1)])
+ v = jt.concat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ #
+ # reshape q, k, v for multihead attention and make em batch first
+ #
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if static_k is None:
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_k.size(0) == bsz * num_heads, \
+ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+ assert static_k.size(2) == head_dim, \
+ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+ k = static_k
+ if static_v is None:
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_v.size(0) == bsz * num_heads, \
+ f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+ assert static_v.size(2) == head_dim, \
+ f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+ v = static_v
+
+ # add zero attention along batch dimension (now first)
+ if add_zero_attn:
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
+ k = jt.concat([k, jt.zeros(zero_attn_shape, dtype=k.dtype)], dim=1)
+ v = jt.concat([v, jt.zeros(zero_attn_shape, dtype=v.dtype)], dim=1)
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+
+ # update source sequence length after adjustments
+ src_len = k.size(1)
+
+ # merge key padding and attention masks
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (bsz, src_len), \
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+ if attn_mask is None:
+ attn_mask = key_padding_mask
+ else:
+ attn_mask = attn_mask + key_padding_mask
+
+ # adjust dropout probability
+ if not training:
+ dropout_p = 0.0
+
+ #
+ # (deep breath) calculate attention and out projection
+ #
+
+ if need_weights:
+ B, Nt, E = q.shape
+ q_scaled = q / math.sqrt(E)
+
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
+
+ if attn_mask is not None:
+ attn_output_weights = baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
+ else:
+ attn_output_weights = jt.bmm(q_scaled, k.transpose(-2, -1))
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
+ if dropout_p > 0.0:
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
+
+ attn_output = jt.bmm(attn_output_weights, v)
+
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+
+ # optionally average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ if average_attn_weights:
+ attn_output_weights = attn_output_weights.mean(dim=1)
+
+ if not is_batched:
+ # squeeze the output if input was unbatched
+ attn_output = attn_output.squeeze(1)
+ attn_output_weights = attn_output_weights.squeeze(0)
+ return attn_output, attn_output_weights
+ else:
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
+ # in order to match the input for SDPA of (N, num_heads, L, S)
+ if attn_mask is not None:
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
+ attn_mask = attn_mask.unsqueeze(0)
+ else:
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
+
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
+ k = k.view(bsz, num_heads, src_len, head_dim)
+ v = v.view(bsz, num_heads, src_len, head_dim)
+
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
+
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+ if not is_batched:
+ # squeeze the output if input was unbatched
+ attn_output = attn_output.squeeze(1)
+ return attn_output, None
+
+
class MultiheadAttention(Module):
- def __init__(
- self,
- embed_dim,
- num_heads,
- kdim=None,
- vdim=None,
- dropout=0.0,
- bias=True,
- add_bias_kv=False,
- add_zero_attn=False,
- self_attention=False,
- encoder_decoder_attention=False,
- q_noise=0.0,
- qn_block_size=8,
- ):
+ __constants__ = ['batch_first']
+ bias_k: Optional[jt.Var]
+ bias_v: Optional[jt.Var]
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
+ kdim=None, vdim=None, batch_first=False, dtype=jt.float32) -> None:
+ if embed_dim <= 0 or num_heads <= 0:
+ raise ValueError(
+ f"embed_dim and num_heads must be greater than 0,"
+ f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
+ )
+ factory_kwargs = {'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
- assert dropout==0, "TODO: dropout>0"
-
+ self.dropout = dropout
+ self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
- assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
- self.scaling = self.head_dim ** -0.5
-
- self.self_attention = self_attention
- self.encoder_decoder_attention = encoder_decoder_attention
-
- assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size")
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
- #TODO: quant_noise
- self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
- self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ if not self._qkv_same_embed_dim:
+ self.q_proj_weight = jt.empty((embed_dim, embed_dim), **factory_kwargs)
+ self.k_proj_weight = jt.empty((embed_dim, self.kdim), **factory_kwargs)
+ self.v_proj_weight = jt.empty((embed_dim, self.vdim), **factory_kwargs)
+ self.in_proj_weight = None
+ else:
+ self.q_proj_weight = None
+ self.k_proj_weight = None
+ self.v_proj_weight = None
+ self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim), **factory_kwargs)
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ if bias:
+ self.in_proj_bias = jt.empty(3 * embed_dim, **factory_kwargs)
+ else:
+ self.in_proj_bias = None
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
- assert not add_bias_kv, "TODO: add_bias_kv=True"
- self.bias_k = self.bias_v = None
+ if add_bias_kv:
+ self.bias_k = jt.empty((1, 1, embed_dim), **factory_kwargs)
+ self.bias_v = jt.empty((1, 1, embed_dim), **factory_kwargs)
+ else:
+ self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
- self.reset_parameters()
-
- self.onnx_trace = False
- self.tpu = False
-
- def reset_parameters(self):
- if self.qkv_same_dim:
- # Empirically observed the convergence to be much better with
- # the scaled initialization
- init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
- init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
- init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
else:
- init.xavier_uniform_(self.k_proj.weight)
- init.xavier_uniform_(self.v_proj.weight)
- init.xavier_uniform_(self.q_proj.weight)
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
- # init.xavier_uniform_(self.out_proj.weight)
- if self.out_proj.bias is not None:
- init.constant_(self.out_proj.bias, 0.)
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
- init.xavier_normal_(self.bias_k)
+ xavier_gauss_(self.bias_k)
if self.bias_v is not None:
- init.xavier_normal_(self.bias_v)
+ xavier_gauss_(self.bias_v)
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if '_qkv_same_embed_dim' not in state:
+ state['_qkv_same_embed_dim'] = True
+
+ super().__setstate__(state)
def execute(
- self,
- query,
- key = None,
- value = None,
- key_padding_mask = None,
- incremental_state = None,
- need_weights = True,
- static_kv = False,
- attn_mask = None,
- before_softmax = False,
- need_head_weights = False,
- ):
- if need_head_weights:
- need_weights = True
-
- tgt_len, bsz, embed_dim = query.shape
- assert embed_dim == self.embed_dim
- assert list(query.shape) == [tgt_len, bsz, embed_dim]
-
- assert incremental_state is None, "TODO: incremental_state is not None"
- saved_state = None
-
- if self.self_attention:
- q = self.q_proj(query)
- k = self.k_proj(query)
- v = self.v_proj(query)
- elif self.encoder_decoder_attention:
- # encoder-decoder attention
- q = self.q_proj(query)
- if key is None:
- assert value is None
- k = v = None
+ self,
+ query: Var,
+ key: Var,
+ value: Var,
+ key_padding_mask: Optional[Var] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Var] = None,
+ average_attn_weights: bool = True,
+ is_causal : bool = False) -> Tuple[Var, Optional[Var]]:
+
+ #####
+ # Fast Path is not Supported.
+ #####
+
+ is_batched = query.dim() == 3
+
+ key_padding_mask = _canonical_mask(
+ mask=key_padding_mask,
+ mask_name="key_padding_mask",
+ other_type=_none_or_dtype(attn_mask),
+ other_name="attn_mask",
+ target_type=query.dtype
+ )
+
+ attn_mask = _canonical_mask(
+ mask=attn_mask,
+ mask_name="attn_mask",
+ other_type=None,
+ other_name="",
+ target_type=query.dtype,
+ check_other=False,
+ )
+
+ if self.batch_first and is_batched:
+ # make sure that the transpose op does not affect the "is" property
+ if key is value:
+ if query is key:
+ query = key = value = query.transpose(1, 0)
+ else:
+ query, key = (x.transpose(1, 0) for x in (query, key))
+ value = key
else:
- k = self.k_proj(key)
- v = self.v_proj(key)
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
+
+ if not self._qkv_same_embed_dim:
+ attn_output, attn_output_weights = multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.is_training(),
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight,
+ average_attn_weights=average_attn_weights,
+ is_causal=is_causal)
else:
- assert key is not None and value is not None
- q = self.q_proj(query)
- k = self.k_proj(key)
- v = self.v_proj(value)
- q = q*self.scaling
-
- assert self.bias_k is None, "TODO: self.bias_k is not None:"
-
- q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
- if k is not None:
- k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
- if v is not None:
- v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
-
- assert saved_state is None, "TODO: saved_state is not None"
- assert k is not None
- src_len = k.shape[1]
-
- assert key_padding_mask is None, "TODO: key_padding_mask is not None"
- assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"
-
- attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
-
- assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
-
- assert attn_mask is None, "TODO: attn_mask is not None"
- assert key_padding_mask is None, "TODO: key_padding_mask is not None"
-
- if before_softmax:
- return attn_weights, v
-
- attn_weights_float = nn.softmax(attn_weights, dim=-1)
- attn_weights = attn_weights_float.type_as(attn_weights)
-
- assert v is not None
- attn = nn.bmm(attn_weights, v)
- assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
- if self.onnx_trace and attn.shape[1] == 1:
- # when ONNX tracing a single decoder step (sequence length == 1)
- # the transpose is a no-op copy before view, thus unnecessary
- attn = attn.view(tgt_len, bsz, embed_dim)
+ attn_output, attn_output_weights = multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.is_training(),
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ average_attn_weights=average_attn_weights,
+ is_causal=is_causal)
+ if self.batch_first and is_batched:
+ return attn_output.transpose(1, 0), attn_output_weights
else:
- attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
- attn = self.out_proj(attn)
- attn_weights = None
- if need_weights:
- attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
- if not need_head_weights:
- # average attention weights over heads
- attn_weights = attn_weights.mean(dims=[0])
-
- return attn, attn_weights
+ return attn_output, attn_output_weights
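+
+# Usage sketch (defaults, batch_first=False, so inputs are (seq, batch, embed)):
+#   mha = MultiheadAttention(embed_dim=8, num_heads=2)
+#   q = k = v = jt.randn(5, 3, 8)
+#   out, w = mha(q, k, v)   # out: (5, 3, 8); w: (3, 5, 5) with average_attn_weights=True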
diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py
deleted file mode 100644
index 6f88fad5..00000000
--- a/python/jittor/compatibility/__init__.py
+++ /dev/null
@@ -1,430 +0,0 @@
-import os
-os.environ["FIX_TORCH_ERROR"] = "0"
-
-import jittor as jt
-from jittor import *
-from typing import Tuple
-
-org_int = int = type(1)
-org_float = float = type(1.0)
-org_bool = bool = type(True)
-
-import jtorch.compiler
-
-import jtorch_core
-from jtorch_core import *
-
-device.__reduce__ = lambda self: (device, (self.type,))
-device.__module__ = "jtorch"
-jt.jittor_core.device = device
-
-def handle_dtype(args, kw, dtype):
- def convert(x):
- if isinstance(x, jt.Var):
- return x.cast(dtype)
- return x
- if dtype is not None:
- if args is not None:
- if isinstance(args, (tuple,list)):
- args = [ convert(a) for a in args ]
- else:
- args = convert(x)
- if kw is not None:
- kw = { k:convert(v) for k,v in kw.items() }
- return args, kw
-
-def get_args_names(func):
- import inspect
- spec = inspect.getfullargspec(func)
- return spec[0] + spec[4]
-
-def wrapper(func):
- has_dtype = False
- if hasattr(func, "__code__"):
- has_dtype = "dtype" in get_args_names(func)
- def inner(*args, **kw):
- requires_grad = None
- dtype = None
- if "requires_grad" in kw:
- requires_grad = kw["requires_grad"]
- del kw["requires_grad"]
- if not has_dtype and "dtype" in kw:
- dtype = kw["dtype"]
- del kw["dtype"]
- if "device" in kw:
- del kw["device"]
- if 'pin_memory' in kw:
- del kw['pin_memory']
- args, kw = handle_dtype(args, kw, dtype)
- ret = func(*args, **kw)
- if isinstance(ret, jt.Var):
- if requires_grad is not None:
- ret.requires_grad = requires_grad
- if dtype is not None:
- ret.astype(dtype)
- return ret
- return inner
-
-
-import inspect
-_wrapper_keys = set(["shape", "start", "size"])
-_wrapper_keys.add("x")
-for k,v in list(globals().items()):
- if callable(v) and not isinstance(v, type):
- try:
- spec = inspect.getfullargspec(v)
- args_name = spec[0]
- if len(args_name) and args_name[0] in _wrapper_keys:
- globals()[k] = wrapper(v)
- elif spec.varargs in _wrapper_keys:
- globals()[k] = wrapper(v)
- except:
- pass
-
-def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
- if len(size) == 1 and not isinstance(size[0], org_int):
- size = size[0]
- return jt.empty(size, dtype)
-
-Tensor = Var
-
-Tensor.backward = lambda x: jtorch_core.backward(x)
-Tensor.grad = property(grad_get, grad_set, grad_del)
-Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
-def retain_grad(x:Tensor, value:bool=True):
- x.retains_grad = value
- return value
-Tensor.retain_grad = retain_grad
-
-Tensor.dim = lambda self: self.ndim
-Tensor.ndimension = lambda self: self.ndim
-Tensor.nelement = lambda self: self.numel()
-Tensor.cuda = lambda self: self
-def device_get(x:Tensor):
- return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
-Tensor.device = property(device_get)
-
-def argmax(x: Var, dim=None, keepdim: bool = False):
- return jt.argmax(x, dim, keepdim)[0]
-Tensor.argmax = argmax
-
-def tensor_type(x: Var, dtype=None, **kwargs):
- if dtype:
- return x.astype(dtype)
- else:
- return x.dtype
-Tensor.type = tensor_type
-
-def is_floating_point(x: Var):
- return "float" in str(x.dtype)
-Tensor.is_floating_point = is_floating_point
-
-from . import autograd
-from .autograd import *
-
-def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
- if isinstance(data,list):
- data_list = []
- check = True
- for p in data:
- if isinstance(p, Tensor) and p.numel()==1:
- data_list.append(p.item())
- elif isinstance(p, (org_int,org_float)):
- data_list.append(p)
- else:
- check = False
- break
- if check:
- data = data_list
- return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)
-
-# tensor = wrapper(array)
-from_numpy = wrapper(array)
-strided = None
-
-def mod_zero_grad(self):
- for p in self.parameters():
- p.grad = None
-Module.zero_grad = mod_zero_grad
-
-class ModuleMisc:
- def parameters(self):
- return iter(super().parameters())
-
- def load_state_dict(self, state_dict, strict=False):
- return super().load_state_dict(state_dict)
-
- def to(self, device=None,dtype=None):
- ''' do nothing but return its self'''
- return self
- def register_parameter(self,name,data):
- self.name = data
-
- def buffers(self):
- for _, buf in self.named_buffers():
- yield buf
-
-
-def make_module(cls):
- class TMod(ModuleMisc, cls):
- def __init__(self, *args, **kw):
- dtype = None
- if "dtype" in kw:
- dtype = kw["dtype"]
- del kw["dtype"]
- self._dtype = dtype
- with jt.flag_scope(th_mode=0):
- if "device" in kw:
- del kw["device"]
- super().__init__(*args, **kw)
- for k,v in self.__dict__.items():
- if not k.startswith("_") and isinstance(v, Var) \
- and v.requires_grad:
- v.retain_grad()
- if dtype is not None and isinstance(v, Var):
- v.assign(v.cast(dtype))
- def __call__(self, *args, **kw):
- args, kw = handle_dtype(args, kw, self._dtype)
- # if forward is override by user, call forward
- if self.__class__.forward is not TMod.forward:
- return self.forward(*args, **kw)
- return self.execute(*args, **kw)
- def forward(self, *args, **kw):
- args, kw = handle_dtype(args, kw, self._dtype)
- return self.execute(*args, **kw)
-
- @property
- def training(self):
- if not hasattr(self, "is_train"):
- self.is_train = True
- return self.is_train
- @training.setter
- def training(self, value):
- self.is_train = value
-
- TMod.__name__ = cls.__name__
- return TMod
-
-import jtorch.cuda
-import jtorch.nn
-from jtorch.nn import Module, Parameter
-import jtorch.optim
-
-from jtorch.utils.dtype import Dtype, get_string_dtype
-
-def frombuffer(buffer: bytearray,
- *,
- dtype: Dtype,
- count: int = -1,
- offset: int = 0,
- requires_grad: bool = True) -> Tensor:
- dtype = get_string_dtype(dtype)
- tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
- if requires_grad and tensor.dtype.is_float():
- tensor.requires_grad = True
- return tensor
-
-def conflict_wrapper(origin_func, new_func):
- def wrapper(*args, **kw):
- if jt.flags.th_mode:
- return new_func(*args, **kw)
- else:
- return origin_func(*args, **kw)
- return wrapper
-
-def min(*args, **kw):
- dim = None
- if len(args) >= 2 and isinstance(args[1], org_int):
- dim = args[1]
- elif "dim" in kw and isinstance(kw["dim"], org_int):
- dim = kw["dim"]
- if dim is not None:
- k, v = jt.argmin(*args, **kw)
- return v, k
- elif len(args) == 2 and isinstance(args[1], jt.Var):
- return jt.minimum(args[0], args[1])
- else:
- return jt.min(*args, **kw)
-Tensor.min = conflict_wrapper(jt.min, min)
-
-def max(*args, **kw):
- dim = None
- if "dim" in kw:
- x = kw["dim"]
- if len(args) >= 2 and isinstance(args[1], org_int):
- dim = args[1]
- elif "dim" in kw and isinstance(kw["dim"], org_int):
- dim = kw["dim"]
- if dim is not None:
- k, v = jt.argmax(*args, **kw)
- return v, k
- elif len(args) == 2 and isinstance(args[1], jt.Var):
- return jt.maximum(args[0], args[1])
- else:
- return jt.max(*args, **kw)
-Tensor.max = conflict_wrapper(jt.max, max)
-
-def argsort(*args, **kw):
- k, v = jt.argsort(*args, **kw)
- return k
-Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
-
-LongTensor = jt.int64
-FloatTensor = jt.float
-HalfTensor = jt.float16
-BoolTensor = jt.bool
-IntTensor = jt.int32
-
-class JDType:
- def __init__(self, func, str):
- self.func = func
- self.str = str
- self.__name__ = str.split(".")[-1]
- def __call__(self, *args, **kw):
- return self.func(*args, **kw)
- def __str__(self):
- return self.str
- def is_floating_point(self):
- return "float" in str(self.str)
-
-int8 = JDType(jt.int8, "torch.int8")
-int16 = JDType(jt.int16, "torch.int16")
-int = int32 = JDType(jt.int32, "torch.int32")
-long = int64 = JDType(jt.int64, "torch.int64")
-
-half = float16 = JDType(jt.float16, "torch.float16")
-float = float32 = JDType(jt.float32, "torch.float32")
-double = float64 = JDType(jt.float64, "torch.float64")
-bfloat16 = "bfloat16" # TODO
-complex64 = "complex64" # TODO
-complex128 = "complex128" # TODO
-def get_JDtype(dtype):
- if dtype=='float32' or dtype == jt.float32:
- return float32
- elif dtype=='float64' or dtype == jt.float64:
- return float64
- elif dtype=='float16' or dtype == jt.float16:
- return float16
- elif dtype=='int32' or dtype == jt.int32:
- return int32
- elif dtype=='int64' or dtype == jt.int64:
- return int64
- elif dtype=='int16' or dtype == jt.int16:
- return int16
- elif dtype=='int8' or dtype == jt.int8:
- return int8
- else:
- raise Exception("dtype {} not supported".format(dtype))
-
-def load(path,**kwargs):
- def _to_jittor(data):
- if isinstance(data,dict):
- return {k:_to_jittor(d) for k,d in data.items()}
- if isinstance(data,list):
- return [_to_jittor(d) for d in data]
- if isinstance(data,np.ndarray):
- return jt.array(data)
- return data
- data = jt.load(path)
-
- return _to_jittor(data)
-
-def is_tensor(x):
- return isinstance(x, Tensor)
-
-manual_seed = jt.set_global_seed
-jt.flags.amp_level = 3
-Size = jt.NanoVector
-
-class Generator:
- def __init__(self,*args,**kw) -> None:
- self.seed = None
- def manual_seed(self,seed):
- self.seed = seed
-
-
-
-from . import fx
-
-
-_default_type = "float32"
-
-def get_default_dtype():
- return _default_type
-def set_default_dtype(dtype):
- global _default_type
- _default_type = dtype
-
-dtype = JDType
-
-def div(x,y,rounding_mode="floor"):
- assert rounding_mode == "floor"
- z = (x / y)
- if rounding_mode == "floor":
- z = z.floor()
- if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"):
- z = z.int32()
- return z
-
-
-def randn(*args,**kw):
- wrap_randn = wrapper(jt.randn)
- generator = kw.get('generator',None)
- kw.pop('generator',None)
- if 'layout' in kw:
- del kw['layout']
- if generator is not None and generator.seed is not None:
- jt.set_global_seed(generator.seed)
- return wrap_randn(*args,**kw)
-
-def rand(*args,**kw):
- print("rand")
- wrap_rand = wrapper(jt.rand)
- generator = kw.get('generator',None)
- kw.pop('generator',None)
- if 'layout' in kw:
- del kw['layout']
- if generator is not None and generator.seed is not None:
- jt.set_global_seed(generator.seed)
- return wrap_rand(*args,**kw)
-
-
-
-def set_default_tensor_type(t: type or str):
- if isinstance(t, str):
- info = t.split(".")
- if len(info) == 3 and info[1] == 'cuda':
- jt.flags.use_cuda = 1
- #TODO: type
-
-
-def clamp(x, min=None, max=None):
- return jt.clamp(x, min, max)
-
-
-def to(x,*args,**kw):
- device = None
- if len(args) == 1:
- device = args[0]
- if isinstance(device, jt.NanoString) or callable(device):
- return jt.to(x,*args,**kw)
- if 'cpu' in str(device):
- args = []
- device = kw.get("device",None)
- if 'cpu' in str(device):
- kw.pop('device',None)
- print("to cpu")
- # print(kw)
- return jt.to(x,*args,**kw)
-Tensor.to = conflict_wrapper(jt.to, to)
-
-mm = wrapper(jt.matmul)
-
-def _data_get(x):
- return x
-
-def _data_set(x, value):
- x.assign(value)
-
-Tensor.data = property(_data_get, _data_set)
-Tensor.layout = None
\ No newline at end of file
diff --git a/python/jittor/compatibility/autograd.py b/python/jittor/compatibility/autograd.py
deleted file mode 100644
index 5ed88dde..00000000
--- a/python/jittor/compatibility/autograd.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import jittor as jt
-from jittor import Var
-from collections.abc import Sequence, Mapping
-
-Variable = Var
-
-class FunctionContext:
- def save_for_backward(self, *args):
- self.saved_tensors = args
-
-class Function:
- ''' Function Module for customized backward operations
-
-Example 1 (Function can have multiple input and multiple output, and user
-can store value for backward computation)::
-
- import jtorch
- from jtorch import Function
-
- class MyFunc(Function):
- @staticmethod
- def forward(self, x, y):
- self.x = x
- self.y = y
- return x*y, x/y
-
- @staticmethod
- def backward(self, grad0, grad1):
- return grad0 * self.y, grad1 * self.x
-
- a = jtorch.array(3.0)
- a.requires_grad = True
- b = jtorch.array(4.0)
- b.requires_grad = True
- func = MyFunc.apply
- c,d = func(a, b)
- (c+d*3).backward()
- assert a.grad.data == 4
- assert b.grad.data == 9
-
-Example 2(Function can return None for no gradiant, and gradiant
-can also be None)::
-
- import jtorch
- from jtorch import Function
-
- class MyFunc(Function):
- @staticmethod
- def forward(self, x, y):
- self.x = x
- self.y = y
- return x*y, x/y
-
- @staticmethod
- def backward(self, grad0, grad1):
- assert grad1 is None
- return grad0 * self.y, None
- a = jt.array(3.0)
- a.requires_grad = True
- b = jt.array(4.0)
- b.requires_grad = True
- func = MyFunc.apply
- c,d = func(a, b)
- d.stop_grad()
- da, db = jt.grad(c+d*3, [a, b])
- assert da.data == 4
- assert db.data == 0
-
- '''
- def __call__(self, *args):
- backup = args
- args = list(args)
- taped_inputs = []
- taped_outputs = []
- input_mask = [-1] * len(args)
- for i,v in enumerate(args):
- if isinstance(v, Var):
- if v.is_stop_grad():
- # -2 in input_mask represents it is stop_grad
- input_mask[i] = -2
- continue
- v = v.tape()
- input_mask[i] = len(taped_inputs)
- args[i] = v
- taped_inputs.append(v)
- ctx = FunctionContext()
- ori_res = self.forward(ctx, *args)
- # ori_res = self.execute(*args)
- if not isinstance(ori_res, Sequence):
- res = [ori_res]
- else:
- res = list(ori_res)
- output_mask = [-1] * len(res)
- for i,v in enumerate(res):
- if isinstance(v, Var):
- v = v.tape()
- output_mask[i] = len(taped_outputs)
- res[i] = v
- taped_outputs.append(v)
- ctx.input_mask = input_mask
- ctx.output_mask = output_mask
- # tape output and input together so
- # backward treat them as one operator
- jt.tape_together(taped_inputs, taped_outputs,
- lambda *args: self._grad(ctx, self, *args))
- if isinstance(ori_res, Sequence):
- return res
- else:
- return res[0]
-
- @staticmethod
- def _grad(ctx, func, *args):
- new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask )
- ret = func.backward(ctx, *new_args)
- if not isinstance(ret, Sequence):
- ret = (ret,)
- new_ret = []
- for i, r in enumerate(ret):
- j = ctx.input_mask[i]
- if j<0:
- # -2 in input_mask represents it is stop_grad
- assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
- "because the input value is not jittor variable."
- else:
- new_ret.append(r)
- return new_ret
-
- def dfs(self, parents, k, callback, callback_leave=None):
- pass
-
- @classmethod
- def apply(cls, *args, **kw):
- func = cls()
- return func(*args, **kw)
diff --git a/python/jittor/compatibility/compiler.py b/python/jittor/compatibility/compiler.py
deleted file mode 100644
index 77bab138..00000000
--- a/python/jittor/compatibility/compiler.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import jittor as jt
-import jittor_utils
-import glob
-import os
-from jittor import pyjt_compiler
-import sys
-from jittor_utils import lock
-
-
-jtorch_path = os.path.dirname(__file__)
-cache_path = os.path.join(jt.compiler.cache_path, "jtorch")
-# os.makedirs(cache_path, exist_ok=True)
-os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True)
-
-with lock.lock_scope():
- pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path)
-
-ext_args = 'c[cu]' if jt.has_cuda else 'cc'
-files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True)
-files += pyjt_gen_src
-cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" "
-if os.environ.get("use_data_o", "1") == "1":
- files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True)
- files = [f for f in files if "__data__" not in f]
-
-
-with lock.lock_scope():
- jt.compiler.compile(
- jt.compiler.cc_path,
- jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags,
- files,
- "jtorch_core"+jt.compiler.extension_suffix,
- obj_dirname="jtorch_objs")
-
-
-with jittor_utils.import_scope(jt.compiler.import_flags):
- import jtorch_core as core
-
-jt.flags.th_mode = 1
diff --git a/python/jittor/compatibility/cuda.py b/python/jittor/compatibility/cuda.py
deleted file mode 100644
index 75665c7c..00000000
--- a/python/jittor/compatibility/cuda.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import jittor as jt
-import jtorch
-
-def is_available():
- return jt.has_cuda
-
-def device_count():
- return int(jt.has_cuda)
-
-def set_device(device=None):
- pass
-
-def get_rng_state(device=None):
- pass
-
-def current_device():
- return jtorch.device("cuda")
-
-def mem_get_info(i):
- return ("75GB",)
-
-
-class Generator:
- def __init__(self):
- pass
-
- def set_state(self, state):
- self.state = state
-
-default_generators = [Generator()]
-_lazy_call = lambda func: func()
-device = None
-
-LongTensor = jt.int64
-FloatTensor = jt.float
-HalfTensor = jt.float16
-BoolTensor = jt.bool
-
-manual_seed = jt.set_global_seed
-manual_seed_all = jt.set_global_seed
-
-def synchronize():
- jt.sync_all(True)
-
-class Event:
- pass
-
-class Stream:
- pass
-
-from typing import Any
-
-from .gradscaler import GradScaler
-
-class autocast:
- def __init__(self,**kwargs):
- pass
-
- def __enter__(self,):
- pass
-
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
- pass
-
diff --git a/python/jittor/compatibility/distributed.py b/python/jittor/compatibility/distributed.py
deleted file mode 100644
index e39f559a..00000000
--- a/python/jittor/compatibility/distributed.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import datetime
-from enum import Enum
-import jittor as jt
-
-
-class DistributedDataParallel:
- def __new__(cls, model):
- return model
-
-def is_initialized():
- return True
-
-def get_rank(group=None):
- return 0
-
-def get_world_size(group=None):
- return 1
-
-def get_backend(group=None):
- return "nccl"
-
-def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
- return 1
-
-def barrier():
- pass
-
-def is_available():
- return True
-
-def is_built():
- return True
-
-class ReduceOp:
- SUM = 0
-
-class GroupMember:
- WORLD = 0
-
-class ProcessGroup:
- pass
-
-class Join:
- pass
-
-dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
-_backend = dist_backend.NCCL
-
-def is_mpi_available():
- return jt.in_mpi
-
-def DistributedDataParallel(model, *args, **kw):
- return model
diff --git a/python/jittor/compatibility/distributions.py b/python/jittor/compatibility/distributions.py
deleted file mode 100644
index a98dfe29..00000000
--- a/python/jittor/compatibility/distributions.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import jittor as jt
-
-class RelaxedBernoulli:
- def __init__(self, temperature, probs=None, logits=None):
- self.temperature = temperature
- self.probs = probs
- self.logits = logits
-
- def rsample(self):
- noise = jt.rand_like(self.logits)
- eps = 1e-20
- noise = jt.clamp(noise, eps, 1.0 - eps)
- logit_noise = jt.log(noise) - jt.log(1 - noise)
- sample = (self.logits + logit_noise) / self.temperature
- return jt.sigmoid(sample)
diff --git a/python/jittor/compatibility/fft/__init__.py b/python/jittor/compatibility/fft/__init__.py
deleted file mode 100644
index 7a89fc9c..00000000
--- a/python/jittor/compatibility/fft/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#TODO: Implement FFT and IFFT
-fftn = None
-fftshift = None
-ifftn = None
-ifftshift = None
\ No newline at end of file
diff --git a/python/jittor/compatibility/fx.py b/python/jittor/compatibility/fx.py
deleted file mode 100644
index 0f0eb4f8..00000000
--- a/python/jittor/compatibility/fx.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Proxy:
- pass
\ No newline at end of file
diff --git a/python/jittor/compatibility/gradscaler.py b/python/jittor/compatibility/gradscaler.py
deleted file mode 100644
index 087d6bb2..00000000
--- a/python/jittor/compatibility/gradscaler.py
+++ /dev/null
@@ -1,519 +0,0 @@
-from collections import defaultdict, abc
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, cast
-import inspect
-import warnings
-
-import jittor as jt
-# import torch
-
-def _refresh_per_optimizer_state():
- return {}
-
-
-class GradScaler:
- _scale: Optional[jt.Var]
- _grows_tracker: Optional[jt.Var]
- _per_optimizer_states: Dict[int, Dict[str, Any]]
- """
- An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
- conveniently.
-
- * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
- * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
- * ``scaler.update()`` updates ``scaler``'s scale factor.
-
- Example::
-
- # Creates a GradScaler once at the beginning of training.
- scaler = GradScaler()
-
- for epoch in epochs:
- for input, target in data:
- optimizer.zero_grad()
- output = model(input)
- loss = loss_fn(output, target)
-
- # Scales loss. Calls backward() on scaled loss to create scaled gradients.
- scaler.scale(loss).backward()
-
- # scaler.step() first unscales gradients of the optimizer's params.
- # If gradients don't contain infs/NaNs, optimizer.step() is then called,
- # otherwise, optimizer.step() is skipped.
- scaler.step(optimizer)
-
- # Updates the scale for next iteration.
- scaler.update()
-
- See the :ref:`Automatic Mixed Precision examples` for usage
- (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
- and multiple losses/optimizers.
-
- ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
- a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
- the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
- without incurring inf or NaN gradient values.
- ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
- ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
-
- * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
- themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
-
- * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
- If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
- ``growth_factor``.
-
- The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
- value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
- iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
-
- Args:
- init_scale (float, optional, default=2.**16): Initial scale factor.
- growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
- :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
- backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
- :meth:`update` if inf/NaN gradients occur in an iteration.
- growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
- that must occur for the scale to be multiplied by ``growth_factor``.
- enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
- invokes the underlying ``optimizer.step()``, and other methods become no-ops.
- Default: ``True``
- """
- def __init__(self,
- init_scale=2.**16,
- growth_factor=2.0,
- backoff_factor=0.5,
- growth_interval=2000,
- enabled=True):
- self._enabled = enabled
-
- if self._enabled:
- assert growth_factor > 1.0, "The growth factor must be > 1.0."
- assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
-
- self._init_scale = init_scale
- # self._scale will be lazily initialized during the first call to scale()
- self._scale = None
- self._growth_factor = growth_factor
- self._backoff_factor = backoff_factor
- self._growth_interval = growth_interval
- self._init_growth_tracker = 0
- # self._growth_tracker will be lazily initialized during the first call to scale()
- self._growth_tracker = None
- self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
-
- def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
- fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
- assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
- assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
- return (self._scale, self._growth_tracker)
-
- def _lazy_init_scale_growth_tracker(self):
- assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
- self._scale = self._init_scale
- self._growth_tracker = self._init_growth_tracker
-
- def scale(self, outputs):
- """
- Multiplies ('scales') a tensor or list of tensors by the scale factor.
-
- Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
- unmodified.
-
- Args:
- outputs (Tensor or iterable of Tensors): Outputs to scale.
- """
- if not self._enabled:
- return outputs
-
-
- # Short-circuit for the common case.
- if isinstance(outputs, jt.Var):
- assert jt.flags.use_cuda == 1
- if self._scale is None:
- self._lazy_init_scale_growth_tracker()
- assert self._scale is not None
- return outputs * self._scale
-
- def apply_scale(val):
- if isinstance(val, jt.Var):
- assert jt.flags.use_cuda == 1
- if self._scale is None:
- self._lazy_init_scale_growth_tracker()
- assert self._scale is not None
- return val * self._scale
- elif isinstance(val, abc.Iterable):
- iterable = map(apply_scale, val)
- if isinstance(val, (list, tuple)):
- return type(val)(iterable)
- else:
- return iterable
- else:
- raise ValueError("outputs must be a Tensor or an iterable of Tensors")
-
- return apply_scale(outputs)
-
- def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
- with jt.no_grad():
- optimizer.pre_step()
- for group in optimizer.param_groups:
- for to_unscale in group["grads"]:
- if to_unscale is None or isinstance(to_unscale,(int,float)):
- continue
- if (not allow_fp16) and str(to_unscale.dtype) == "float16":
- raise ValueError("Attempting to unscale FP16 gradients.")
-
- if not (to_unscale.isinf().any()):
- if inv_scale != 1.0:
- to_unscale.update(to_unscale*inv_scale)
- else:
- found_inf = 1.0
-
- return found_inf
-
- def unscale_(self, optimizer):
- """
- Divides ("unscales") the optimizer's gradient tensors by the scale factor.
-
- :meth:`unscale_` is optional, serving cases where you need to
- :ref:`modify or inspect gradients`
- between the backward pass(es) and :meth:`step`.
- If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
-
- Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
-
- ...
- scaler.scale(loss).backward()
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
- scaler.step(optimizer)
- scaler.update()
-
- Args:
- optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
-
- .. note::
- :meth:`unscale_` does not incur a CPU-GPU sync.
-
- .. warning::
- :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
- and only after all gradients for that optimizer's assigned parameters have been accumulated.
- Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
-
- .. warning::
- :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
- """
- if not self._enabled:
- return
-
- self._check_scale_growth_tracker("unscale_")
-
- optimizer_state = self._per_optimizer_states[id(optimizer)]
-
- if hasattr(optimizer,"get_find_inf"):
- return
- # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
- assert self._scale is not None
- inv_scale = 1.0 / self._scale
- found_inf = 0.0
- optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
-
-
- def step(self, optimizer, *args, **kwargs):
- """
- :meth:`step` carries out the following two operations:
-
- 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
- earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
- 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
- gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
-
- ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
-
- Returns the return value of ``optimizer.step(*args, **kwargs)``.
-
- Args:
- optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
- args: Any arguments.
- kwargs: Any keyword arguments.
-
- .. warning::
- Closure use is not currently supported.
- """
- if (not self._enabled):
- return optimizer.step(*args, **kwargs)
-
- if "closure" in kwargs:
- raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
-
- self._check_scale_growth_tracker("step")
-
- optimizer_state = self._per_optimizer_states[id(optimizer)]
- retval = None
-
- if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
- # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
- # The contract with custom optimizers is that their step() should accept an additional,
- # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
- # it can query its own state, invoke unscale_ on itself, etc
- # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
- # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
- # and `found_inf` to the passed optimizer so that the optimizer can utilize those
- # to skip the parameter updates or unscale gradients before updating parameters in
- # the fused kernel, e.g. `FusedAdamMathFunctor`.
- # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
- # while the method is expected to be called from the user's side, i.e. by their optimizers.
- kwargs_ = kwargs
- has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
- if has_grad_scaler_kwarg:
- warnings.warn(
- "GradScaler is going to stop passing itself as a keyword argument to the passed "
- "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
- "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
- FutureWarning)
- kwargs_.update({"grad_scaler": self})
- else:
- if optimizer_state["stage"] is OptState.READY:
- self._check_inf_per_device(optimizer)
- scaler = self._get_scale_async()
- found_inf = cast(
- jt.Var,
- sum([
- t for t in optimizer_state["found_inf_per_device"].values()
- ])
- )
- optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
- optimizer.found_inf = found_inf
- retval = optimizer.step(*args, **kwargs_)
- optimizer_state["stage"] = OptState.STEPPED
- if not has_grad_scaler_kwarg:
- del optimizer.grad_scale
- del optimizer.found_inf
- return retval
-
- if hasattr(optimizer,"get_find_inf"):
- optimizer.set_grad_scale(self._scale)
- optimizer.step()
- optimizer_state["found_inf_per_device"] = optimizer.get_find_inf()
- return
-
- retval = None
- if not optimizer_state["found_inf_per_device"]:
- retval = optimizer.step(*args, **kwargs)
- else:
- optimizer.post_step()
-
- return retval
-
-
- def update(self, new_scale=None):
- """
- Updates the scale factor.
-
- If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
- to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
- the scale is multiplied by ``growth_factor`` to increase it.
-
- Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
- used directly, it's used to fill GradScaler's internal scale tensor. So if
- ``new_scale`` was a tensor, later in-place changes to that tensor will not further
- affect the scale GradScaler uses internally.)
-
- Args:
- new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
-
- .. warning::
- :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
- been invoked for all optimizers used this iteration.
- """
- if not self._enabled:
- return
-
- _scale, _growth_tracker = self._check_scale_growth_tracker("update")
-
- if new_scale is not None:
- # Accept a new user-defined scale.
- if isinstance(new_scale, float):
- # ``_scale`` is kept as a plain Python float in this port, so assign it directly.
- self._scale = new_scale
- else:
- reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False."
- assert isinstance(new_scale, jt.Var), reason
- assert new_scale.numel() == 1, reason
- assert not new_scale.requires_grad, reason
- self._scale = float(new_scale.item())
- else:
- # Consume shared inf/nan data collected from optimizers to update the scale.
- # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
- found_infs = [state["found_inf_per_device"]
- for state in self._per_optimizer_states.values()
- ]
-
- assert len(found_infs) > 0, "No inf checks were recorded prior to update."
-
- found_inf_combined = found_infs[0]
- if len(found_infs) > 1:
- for i in range(1, len(found_infs)):
- found_inf_combined += found_infs[i]
-
-
- current_scale = _scale
- if found_inf_combined:
- current_scale *=self._backoff_factor
- _growth_tracker = 0
- else:
- successful = _growth_tracker+1
- if successful == self._growth_interval:
- new_scale = current_scale*self._growth_factor
- if new_scale < 1e9:
- current_scale = new_scale
- _growth_tracker = 0
- else:
- _growth_tracker = successful
-
- self._scale, self._growth_tracker = current_scale,_growth_tracker
-
- # To prepare for next iteration, clear the data collected from optimizers this iteration.
- self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
-
- def _get_scale_async(self):
- return self._scale
-
- def get_scale(self):
- """
- Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
-
- .. warning::
- :meth:`get_scale` incurs a CPU-GPU sync.
- """
- if self._enabled:
- return self._init_scale if self._scale is None else self._get_scale_async()
- else:
- return 1.0
-
- def get_growth_factor(self):
- r"""
- Returns a Python float containing the scale growth factor.
- """
- return self._growth_factor
-
- def set_growth_factor(self, new_factor):
- r"""
- Args:
- new_factor (float): Value to use as the new scale growth factor.
- """
- self._growth_factor = new_factor
-
- def get_backoff_factor(self):
- r"""
- Returns a Python float containing the scale backoff factor.
- """
- return self._backoff_factor
-
- def set_backoff_factor(self, new_factor):
- r"""
- Args:
- new_factor (float): Value to use as the new scale backoff factor.
- """
- self._backoff_factor = new_factor
-
- def get_growth_interval(self):
- r"""
- Returns a Python int containing the growth interval.
- """
- return self._growth_interval
-
- def set_growth_interval(self, new_interval):
- r"""
- Args:
- new_interval (int): Value to use as the new growth interval.
- """
- self._growth_interval = new_interval
-
- def _get_growth_tracker(self):
- if self._enabled:
- return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker  # kept as a plain int in this port
- else:
- return 0
-
- def is_enabled(self):
- r"""
- Returns a bool indicating whether this instance is enabled.
- """
- return self._enabled
-
- def state_dict(self):
- r"""
- Returns the state of the scaler as a :class:`dict`. It contains five entries:
-
- * ``"scale"`` - a Python float containing the current scale
- * ``"growth_factor"`` - a Python float containing the current growth factor
- * ``"backoff_factor"`` - a Python float containing the current backoff factor
- * ``"growth_interval"`` - a Python int containing the current growth interval
- * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
-
- If this instance is not enabled, returns an empty dict.
-
- .. note::
- If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
- should be called after :meth:`update`.
- """
- return {"scale": self.get_scale(),
- "growth_factor": self._growth_factor,
- "backoff_factor": self._backoff_factor,
- "growth_interval": self._growth_interval,
- "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
-
- def load_state_dict(self, state_dict):
- r"""
- Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
-
- Args:
- state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
- """
- if not self._enabled:
- return
-
- if len(state_dict) == 0:
- raise RuntimeError("The source state dict is empty, possibly because it was saved "
- "from a disabled instance of GradScaler.")
-
- self._init_scale = state_dict["scale"]
- if self._scale is not None:
- self._scale = state_dict["scale"]  # _scale is a plain Python float in this port
- self._growth_factor = state_dict["growth_factor"]
- self._backoff_factor = state_dict["backoff_factor"]
- self._growth_interval = state_dict["growth_interval"]
- self._init_growth_tracker = state_dict["_growth_tracker"]
- if self._growth_tracker is not None:
- self._growth_tracker = state_dict["_growth_tracker"]
-
- def __getstate__(self):
- state = self.__dict__.copy()
- if self._enabled:
- assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
- "of an iteration, or at the end after scaler.update()."
- # Pickling _scale and _growth_tracker Tensors directly triggers
- # "warnings.warn("pickle support for Storage will be removed in 1.5..."
- # so instead, we set the unpickled instance up to reinitialize them lazily.
- state['_init_scale'] = self.get_scale()
- state['_init_growth_tracker'] = self._get_growth_tracker()
- state['_scale'] = None
- state['_growth_tracker'] = None
- return state
-
- def __setstate__(self, state):
- self.__dict__.update(state)
-
- def _check_inf_per_device(self, optimizer):
- _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
-
- dummy_inv_scale = 1.0
- found_inf = 0.0
-
- self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
- self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
-
- return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
-
- def _found_inf_per_device(self, optimizer):
- return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
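The ``unscale_`` docstring above illustrates gradient clipping with the PyTorch helper; against this compatibility layer the same pattern looks roughly as follows. This is a minimal sketch: ``model``, ``loss_fn`` and ``loader`` are placeholders, ``.backward()`` on a Var assumes the torch-style autograd wrapper from this layer, and ``clip_grad_norm(max_norm, norm_type)`` is the method defined on the Jittor-side ``Optimizer``::

    scaler = GradScaler()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()        # backward on the scaled loss -> scaled gradients
        scaler.unscale_(optimizer)           # divide grads by the scale before clipping
        optimizer.clip_grad_norm(1.0, 2)     # threshold now applies to the true gradients
        scaler.step(optimizer)               # skipped internally if infs/NaNs were found
        scaler.update()                      # back off or grow the scale factor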
diff --git a/python/jittor/compatibility/gradscaler_old.py b/python/jittor/compatibility/gradscaler_old.py
deleted file mode 100644
index 389be2cf..00000000
--- a/python/jittor/compatibility/gradscaler_old.py
+++ /dev/null
@@ -1,556 +0,0 @@
-from collections import defaultdict, abc
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, cast
-import inspect
-import warnings
-
-import jittor as jt
-# import torch
-
-
-__all__ = ["OptState", "GradScaler"]
-
-
-# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
-# as well as associated "enum" values. Prefers defining these at top level because
-# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
-# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
-# causes a circular reference, which we'd rather avoid.
-class OptState(Enum):
- READY = 0
- UNSCALED = 1
- STEPPED = 2
-
-
-def _refresh_per_optimizer_state():
- return {"stage": OptState.READY, "found_inf_per_device": {}}
-
-
-class GradScaler:
- _scale: Optional[jt.Var]
- _growth_tracker: Optional[jt.Var]
- _per_optimizer_states: Dict[int, Dict[str, Any]]
- """
- An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
- conveniently.
-
- * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
- * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
- * ``scaler.update()`` updates ``scaler``'s scale factor.
-
- Example::
-
- # Creates a GradScaler once at the beginning of training.
- scaler = GradScaler()
-
- for epoch in epochs:
- for input, target in data:
- optimizer.zero_grad()
- output = model(input)
- loss = loss_fn(output, target)
-
- # Scales loss. Calls backward() on scaled loss to create scaled gradients.
- scaler.scale(loss).backward()
-
- # scaler.step() first unscales gradients of the optimizer's params.
- # If gradients don't contain infs/NaNs, optimizer.step() is then called,
- # otherwise, optimizer.step() is skipped.
- scaler.step(optimizer)
-
- # Updates the scale for next iteration.
- scaler.update()
-
- See the :ref:`Automatic Mixed Precision examples` for usage
- (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
- and multiple losses/optimizers.
-
- ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
- a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
- the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
- without incurring inf or NaN gradient values.
- ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
- ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
-
- * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
- themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
-
- * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
- If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
- ``growth_factor``.
-
- The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
- value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
- iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
-
- Args:
- init_scale (float, optional, default=2.**16): Initial scale factor.
- growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
- :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
- backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
- :meth:`update` if inf/NaN gradients occur in an iteration.
- growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
- that must occur for the scale to be multiplied by ``growth_factor``.
- enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
- invokes the underlying ``optimizer.step()``, and other methods become no-ops.
- Default: ``True``
- """
- def __init__(self,
- init_scale=2.**16,
- growth_factor=2.0,
- backoff_factor=0.5,
- growth_interval=2000,
- enabled=True):
- self._enabled = enabled
-
- if self._enabled:
- assert growth_factor > 1.0, "The growth factor must be > 1.0."
- assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
-
- self._init_scale = init_scale
- # self._scale will be lazily initialized during the first call to scale()
- self._scale = None
- self._growth_factor = growth_factor
- self._backoff_factor = backoff_factor
- self._growth_interval = growth_interval
- self._init_growth_tracker = 0
- # self._growth_tracker will be lazily initialized during the first call to scale()
- self._growth_tracker = None
- self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
-
- def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
- fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
- assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
- assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
- return (self._scale, self._growth_tracker)
-
- def _lazy_init_scale_growth_tracker(self):
- assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
- self._scale = self._init_scale
- self._growth_tracker = self._init_growth_tracker
-
- def scale(self, outputs):
- """
- Multiplies ('scales') a tensor or list of tensors by the scale factor.
-
- Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
- unmodified.
-
- Args:
- outputs (Tensor or iterable of Tensors): Outputs to scale.
- """
- print("scale")
- if not self._enabled:
- return outputs
-
-
- # Short-circuit for the common case.
- if isinstance(outputs, jt.Var):
- assert jt.flags.use_cuda == 1
- if self._scale is None:
- self._lazy_init_scale_growth_tracker()
- assert self._scale is not None
- return outputs * self._scale
-
- def apply_scale(val):
- if isinstance(val, jt.Var):
- assert jt.flags.use_cuda == 1
- if self._scale is None:
- self._lazy_init_scale_growth_tracker()
- assert self._scale is not None
- return val * self._scale
- elif isinstance(val, abc.Iterable):
- iterable = map(apply_scale, val)
- if isinstance(val, (list, tuple)):
- return type(val)(iterable)
- else:
- return iterable
- else:
- raise ValueError("outputs must be a Tensor or an iterable of Tensors")
-
- return apply_scale(outputs)
-
- def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
-
- # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
- # There could be hundreds of grads, so we'd like to iterate through them just once.
- # However, we don't know their devices or dtypes in advance.
-
- # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
- # Google says mypy struggles with defaultdicts type annotations.
- with jt.no_grad():
- optimizer.pre_step()
- for group in optimizer.param_groups:
- for to_unscale in group["grads"]:
- if to_unscale is None or isinstance(to_unscale,(int,float)):
- continue
- if (not allow_fp16) and str(to_unscale.dtype) == "float16":
- raise ValueError("Attempting to unscale FP16 gradients.")
-
- if not (to_unscale.isinf().any()):
- if inv_scale != 1.0:
- to_unscale.update(to_unscale*inv_scale)
- else:
- found_inf = 1.0
-
- return found_inf
-
- def unscale_(self, optimizer):
- """
- Divides ("unscales") the optimizer's gradient tensors by the scale factor.
-
- :meth:`unscale_` is optional, serving cases where you need to
- :ref:`modify or inspect gradients`
- between the backward pass(es) and :meth:`step`.
- If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
-
- Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
-
- ...
- scaler.scale(loss).backward()
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
- scaler.step(optimizer)
- scaler.update()
-
- Args:
- optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
-
- .. note::
- :meth:`unscale_` does not incur a CPU-GPU sync.
-
- .. warning::
- :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
- and only after all gradients for that optimizer's assigned parameters have been accumulated.
- Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
-
- .. warning::
- :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
- """
- if not self._enabled:
- return
-
- self._check_scale_growth_tracker("unscale_")
-
- optimizer_state = self._per_optimizer_states[id(optimizer)]
-
- if optimizer_state["stage"] is OptState.UNSCALED:
- raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
- elif optimizer_state["stage"] is OptState.STEPPED:
- raise RuntimeError("unscale_() is being called after step().")
-
-
- # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
- assert self._scale is not None
- inv_scale = 1.0 / self._scale
- found_inf = 0.0
- optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
- optimizer_state["stage"] = OptState.UNSCALED
-
- def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
- retval = None
- if not optimizer_state["found_inf_per_device"]:
- retval = optimizer.step(*args, **kwargs)
- else:
- optimizer.post_step()
-
- return retval
-
- def step(self, optimizer, *args, **kwargs):
- """
- :meth:`step` carries out the following two operations:
-
- 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
- earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
- 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
- gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
-
- ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
-
- Returns the return value of ``optimizer.step(*args, **kwargs)``.
-
- Args:
- optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
- args: Any arguments.
- kwargs: Any keyword arguments.
-
- .. warning::
- Closure use is not currently supported.
- """
- if (not self._enabled):
- return optimizer.step(*args, **kwargs)
-
- if "closure" in kwargs:
- raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
-
- self._check_scale_growth_tracker("step")
-
- optimizer_state = self._per_optimizer_states[id(optimizer)]
-
- if optimizer_state["stage"] is OptState.STEPPED:
- raise RuntimeError("step() has already been called since the last update().")
-
- retval = None
-
- if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
- # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
- # The contract with custom optimizers is that their step() should accept an additional,
- # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
- # it can query its own state, invoke unscale_ on itself, etc
- # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
- # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
- # and `found_inf` to the passed optimizer so that the optimizer can utilize those
- # to skip the parameter updates or unscale gradients before updating parameters in
- # the fused kernel, e.g. `FusedAdamMathFunctor`.
- # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
- # while the method is expected to be called from the user's side, i.e. by their optimizers.
- kwargs_ = kwargs
- has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
- if has_grad_scaler_kwarg:
- warnings.warn(
- "GradScaler is going to stop passing itself as a keyword argument to the passed "
- "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
- "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
- FutureWarning)
- kwargs_.update({"grad_scaler": self})
- else:
- if optimizer_state["stage"] is OptState.READY:
- self._check_inf_per_device(optimizer)
- scaler = self._get_scale_async()
- found_inf = cast(
- jt.Var,
- sum([
- t for t in optimizer_state["found_inf_per_device"].values()
- ])
- )
- optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
- optimizer.found_inf = found_inf
- retval = optimizer.step(*args, **kwargs_)
- optimizer_state["stage"] = OptState.STEPPED
- if not has_grad_scaler_kwarg:
- del optimizer.grad_scale
- del optimizer.found_inf
- return retval
-
-
- if optimizer_state["stage"] is OptState.READY:
- self.unscale_(optimizer)
-
- assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."
-
- retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
-
- optimizer_state["stage"] = OptState.STEPPED
-
- return retval
-
- def update(self, new_scale=None):
- """
- Updates the scale factor.
-
- If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
- to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
- the scale is multiplied by ``growth_factor`` to increase it.
-
- Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
- used directly, it's used to fill GradScaler's internal scale tensor. So if
- ``new_scale`` was a tensor, later in-place changes to that tensor will not further
- affect the scale GradScaler uses internally.)
-
- Args:
- new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
-
- .. warning::
- :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
- been invoked for all optimizers used this iteration.
- """
- if not self._enabled:
- return
-
- _scale, _growth_tracker = self._check_scale_growth_tracker("update")
-
- if new_scale is not None:
- # Accept a new user-defined scale.
- if isinstance(new_scale, float):
- self._scale.fill_(new_scale) # type: ignore[union-attr]
- else:
- reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
- assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
- assert new_scale.numel() == 1, reason
- assert new_scale.requires_grad is False, reason
- self._scale.copy_(new_scale) # type: ignore[union-attr]
- else:
- # Consume shared inf/nan data collected from optimizers to update the scale.
- # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
- found_infs = [state["found_inf_per_device"]
- for state in self._per_optimizer_states.values()
- ]
-
- assert len(found_infs) > 0, "No inf checks were recorded prior to update."
-
- found_inf_combined = found_infs[0]
- if len(found_infs) > 1:
- for i in range(1, len(found_infs)):
- found_inf_combined += found_infs[i]
-
-
- current_scale = _scale
- if found_inf_combined:
- current_scale *=self._backoff_factor
- _growth_tracker = 0
- else:
- successful = _growth_tracker+1
- if successful == self._growth_interval:
- new_scale = current_scale*self._growth_factor
- if new_scale < 1e9:
- current_scale = new_scale
- _growth_tracker = 0
- else:
- _growth_tracker = successful
-
- self._scale, self._growth_tracker = current_scale,_growth_tracker
-
- # To prepare for next iteration, clear the data collected from optimizers this iteration.
- self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
-
- def _get_scale_async(self):
- return self._scale
-
- def get_scale(self):
- """
- Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
-
- .. warning::
- :meth:`get_scale` incurs a CPU-GPU sync.
- """
- if self._enabled:
- return self._init_scale if self._scale is None else self._get_scale_async()
- else:
- return 1.0
-
- def get_growth_factor(self):
- r"""
- Returns a Python float containing the scale growth factor.
- """
- return self._growth_factor
-
- def set_growth_factor(self, new_factor):
- r"""
- Args:
- new_factor (float): Value to use as the new scale growth factor.
- """
- self._growth_factor = new_factor
-
- def get_backoff_factor(self):
- r"""
- Returns a Python float containing the scale backoff factor.
- """
- return self._backoff_factor
-
- def set_backoff_factor(self, new_factor):
- r"""
- Args:
- new_factor (float): Value to use as the new scale backoff factor.
- """
- self._backoff_factor = new_factor
-
- def get_growth_interval(self):
- r"""
- Returns a Python int containing the growth interval.
- """
- return self._growth_interval
-
- def set_growth_interval(self, new_interval):
- r"""
- Args:
- new_interval (int): Value to use as the new growth interval.
- """
- self._growth_interval = new_interval
-
- def _get_growth_tracker(self):
- if self._enabled:
- return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
- else:
- return 0
-
- def is_enabled(self):
- r"""
- Returns a bool indicating whether this instance is enabled.
- """
- return self._enabled
-
- def state_dict(self):
- r"""
- Returns the state of the scaler as a :class:`dict`. It contains five entries:
-
- * ``"scale"`` - a Python float containing the current scale
- * ``"growth_factor"`` - a Python float containing the current growth factor
- * ``"backoff_factor"`` - a Python float containing the current backoff factor
- * ``"growth_interval"`` - a Python int containing the current growth interval
- * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
-
- If this instance is not enabled, returns an empty dict.
-
- .. note::
- If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
- should be called after :meth:`update`.
- """
- return {"scale": self.get_scale(),
- "growth_factor": self._growth_factor,
- "backoff_factor": self._backoff_factor,
- "growth_interval": self._growth_interval,
- "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
-
- def load_state_dict(self, state_dict):
- r"""
- Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
-
- Args:
- state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
- """
- if not self._enabled:
- return
-
- if len(state_dict) == 0:
- raise RuntimeError("The source state dict is empty, possibly because it was saved "
- "from a disabled instance of GradScaler.")
-
- self._init_scale = state_dict["scale"]
- if self._scale is not None:
- self._scale.fill_(state_dict["scale"])
- self._growth_factor = state_dict["growth_factor"]
- self._backoff_factor = state_dict["backoff_factor"]
- self._growth_interval = state_dict["growth_interval"]
- self._init_growth_tracker = state_dict["_growth_tracker"]
- if self._growth_tracker is not None:
- self._growth_tracker.fill_(state_dict["_growth_tracker"])
-
- def __getstate__(self):
- state = self.__dict__.copy()
- if self._enabled:
- assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
- "of an iteration, or at the end after scaler.update()."
- # Pickling _scale and _growth_tracker Tensors directly triggers
- # "warnings.warn("pickle support for Storage will be removed in 1.5..."
- # so instead, we set the unpickled instance up to reinitialize them lazily.
- state['_init_scale'] = self.get_scale()
- state['_init_growth_tracker'] = self._get_growth_tracker()
- state['_scale'] = None
- state['_growth_tracker'] = None
- return state
-
- def __setstate__(self, state):
- self.__dict__.update(state)
-
- def _check_inf_per_device(self, optimizer):
- _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
-
- dummy_inv_scale = 1.0
- found_inf = 0.0
-
- self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
- self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
-
- return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
-
- def _found_inf_per_device(self, optimizer):
- return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
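Because ``state_dict``/``load_state_dict`` above only move plain Python floats and ints, the scaler can be checkpointed next to the model. A small sketch, assuming hypothetical ``model`` and ``scaler`` objects and Jittor's ``jt.save``/``jt.load`` for serialization::

    import jittor as jt

    # save after scaler.update() so the latest scale is captured
    jt.save({"model": model.state_dict(),
             "scaler": scaler.state_dict()}, "checkpoint.pkl")

    # restore
    ckpt = jt.load("checkpoint.pkl")
    model.load_state_dict(ckpt["model"])
    scaler = GradScaler()
    scaler.load_state_dict(ckpt["scaler"])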
diff --git a/python/jittor/compatibility/misc.py b/python/jittor/compatibility/misc.py
deleted file mode 100644
index 8e9ed20d..00000000
--- a/python/jittor/compatibility/misc.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import math
-
-def _jit_set_profiling_mode(x): pass
-def _jit_set_profiling_executor(x): pass
-def _jit_override_can_fuse_on_cpu(x): pass
-def _jit_override_can_fuse_on_gpu(x): pass
-
-def script(func):
- return func
-
-inf = math.inf
-nan = math.nan
\ No newline at end of file
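The module above only provides no-op stand-ins: the ``_jit_*`` hooks do nothing and ``script`` returns its argument unchanged, so code written against ``torch.jit`` keeps importing but simply runs eagerly. For illustration, using the names defined in the removed module::

    import math

    @script                     # identity decorator: no JIT compilation happens
    def add(a, b):
        return a + b

    assert add(1, 2) == 3
    assert math.isinf(inf) and math.isnan(nan)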
diff --git a/python/jittor/compatibility/nn/__init__.py b/python/jittor/compatibility/nn/__init__.py
deleted file mode 100644
index ae0ff3ae..00000000
--- a/python/jittor/compatibility/nn/__init__.py
+++ /dev/null
@@ -1,281 +0,0 @@
-import jtorch
-from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict
-from typing_extensions import Self
-import jittor as jt
-from jtorch import make_module, Tensor, ModuleMisc, wrapper
-#from . import init
-from jittor import Function
-import operator
-import warnings
-
-for k,v in jt.nn.__dict__.items():
- if callable(v):
- globals()[k] = wrapper(v)
-
-for k,v in jt.nn.__dict__.items():
- if isinstance(v, type) and issubclass(v, jt.Module):
- globals()[k] = make_module(v)
-
-from collections import OrderedDict
-from collections import abc as container_abcs
-
-class Module(ModuleMisc, jt.Module):
-
- def __call__(self, *args, **kw):
- return self.execute(*args, **kw)
-
- def execute(self, *args, **kw):
- return self.forward(*args, **kw)
-
- def get_submodule(self, target: str):
- if target == "":
- return self
-
- atoms: List[str] = target.split(".")
- mod: jt.nn.Module = self
-
- for item in atoms:
- if not hasattr(mod, item):
- raise AttributeError(mod._get_name() + " has no "
- "attribute `" + item + "`")
-
- mod = getattr(mod, item)
-
- if not isinstance(mod, jt.nn.Module):
- raise AttributeError("`" + item + "` is not "
- "an nn.Module")
- return mod
-
-
-
-def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor:
- x = x.clone()
- x.requires_grad = requires_grad
- x.retains_grad = requires_grad
- return x
-
-def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False):
- return jt.nn.embedding(input, weight)
-
-def dropout(x, p=0.5, training=False):
- return jt.nn.dropout(x, p, training)
-
-
-class Flatten(Module):
- ''' Flattens the contiguous range of dimensions in a Var.
- :param start_dim: the first dimension to be flattened. Defaults: 1.
- :type start_dim: int
- :param end_dim: the last dimension to be flattened. Defaults: -1.
- :type end_dim: int
- '''
- def __init__(self, start_dim=1, end_dim=-1):
- self.start_dim = start_dim
- self.end_dim = end_dim
-
- def forward(self, x) -> jt.Var:
- return x.flatten(self.start_dim, self.end_dim)
-
-class _IncompatibleKeys:
- def __init__(self, missing_keys, unexpected_keys):
- self.missing_keys = missing_keys
- self.unexpected_keys = unexpected_keys
-
-_BatchNorm = None
-
-#from . import utils
-normalize = wrapper(jt.normalize)
-
-T = TypeVar('T', bound=Module)
-
-class ModuleDict(Module):
- _modules: Dict[str, Module] # type: ignore[assignment]
-
- def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
- super().__init__()
- if modules is not None:
- self.update(modules)
-
- def __getitem__(self, key: str) -> Module:
- return self._modules[key]
-
- def __setitem__(self, key: str, module: Module) -> None:
- self.add_module(key, module)
-
- def __delitem__(self, key: str) -> None:
- del self._modules[key]
-
- def __len__(self) -> int:
- return len(self._modules)
-
- def __iter__(self) -> Iterator[str]:
- return iter(self._modules)
-
- def __contains__(self, key: str) -> bool:
- return key in self._modules
-
- def clear(self) -> None:
- """Remove all items from the ModuleDict."""
- self._modules.clear()
-
- def pop(self, key: str) -> Module:
- r"""Remove key from the ModuleDict and return its module.
-
- Args:
- key (str): key to pop from the ModuleDict
- """
- v = self[key]
- del self[key]
- return v
-
- def keys(self) -> Iterable[str]:
- r"""Return an iterable of the ModuleDict keys."""
- return self._modules.keys()
-
- def items(self) -> Iterable[Tuple[str, Module]]:
- r"""Return an iterable of the ModuleDict key/value pairs."""
- return self._modules.items()
-
- def values(self) -> Iterable[Module]:
- r"""Return an iterable of the ModuleDict values."""
- return self._modules.values()
-
- def update(self, modules: Mapping[str, Module]) -> None:
- r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
-
- .. note::
- If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
- an iterable of key-value pairs, the order of new elements in it is preserved.
-
- Args:
- modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
- or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
- """
- if not isinstance(modules, container_abcs.Iterable):
- raise TypeError("ModuleDict.update should be called with an "
- "iterable of key/value pairs, but got " +
- type(modules).__name__)
-
- if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
- for key, module in modules.items():
- self[key] = module
- else:
- # modules here can be a list with two items
- for j, m in enumerate(modules):
- if not isinstance(m, container_abcs.Iterable):
- raise TypeError("ModuleDict update sequence element "
- "#" + str(j) + " should be Iterable; is" +
- type(m).__name__)
- if not len(m) == 2:
- raise ValueError("ModuleDict update sequence element "
- "#" + str(j) + " has length " + str(len(m)) +
- "; 2 is required")
- # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
- # that's too cumbersome to type correctly with overloads, so we add an ignore here
- self[m[0]] = m[1] # type: ignore[assignment]
-
- # remove forward altogether to fall back on Module's _forward_unimplemented
-
-
-class ParameterList(Module):
-
- def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
- super().__init__()
- self._size = 0
- if values is not None:
- self += values
-
- def _get_abs_string_index(self, idx):
- """Get the absolute index for the list of modules."""
- idx = operator.index(idx)
- if not (-len(self) <= idx < len(self)):
- raise IndexError(f'index {idx} is out of range')
- if idx < 0:
- idx += len(self)
- return str(idx)
-
- @overload
- def __getitem__(self, idx: int) -> Any:
- ...
-
- @overload
- def __getitem__(self: T, idx: slice) -> T:
- ...
-
- def __getitem__(self, idx):
- if isinstance(idx, slice):
- start, stop, step = idx.indices(len(self))
- out = self.__class__()
- for i in range(start, stop, step):
- out.append(self[i])
- return out
- else:
- idx = self._get_abs_string_index(idx)
- return getattr(self, str(idx))
-
- def __setitem__(self, idx: int, param: Any) -> None:
- # Note that all other function that add an entry to the list part of
- # the ParameterList end up here. So this is the only place where we need
- # to wrap things into Parameter if needed.
- # Objects added via setattr() are not in the list part and thus won't
- # call into this function.
- idx = self._get_abs_string_index(idx)
- # ``Parameter`` is a factory function here rather than a class, so it cannot be used
- # with isinstance(); check the ``retains_grad`` flag it sets on wrapped Vars instead.
- if isinstance(param, jt.Var) and not getattr(param, "retains_grad", False):
- param = Parameter(param)
- return setattr(self, str(idx), param)
-
- def __len__(self) -> int:
- return self._size
-
- def __iter__(self) -> Iterator[Any]:
- return iter(self[i] for i in range(len(self)))
-
- def __iadd__(self, parameters: Iterable[Any]) -> Self:
- return self.extend(parameters)
-
- def __dir__(self):
- keys = super().__dir__()
- keys = [key for key in keys if not key.isdigit()]
- return keys
-
- def append(self, value: Any) -> 'ParameterList':
- """Append a given value at the end of the list.
-
- Args:
- value (Any): value to append
- """
- new_idx = len(self)
- self._size += 1
- self[new_idx] = value
- return self
-
- def extend(self, values: Iterable[Any]) -> Self:
- """Append values from a Python iterable to the end of the list.
-
- Args:
- values (iterable): iterable of values to append
- """
- # Tensor is an iterable but we never want to unpack it here
- if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var):
- raise TypeError("ParameterList.extend should be called with an "
- "iterable, but got " + type(values).__name__)
- for value in values:
- self.append(value)
- return self
-
- def extra_repr(self) -> str:
- child_lines = []
- for k, p in enumerate(self):
- if isinstance(p, jt.Var):
- size_str = 'x'.join(str(size) for size in p.size())
- parastr = '{} containing: [{} of size {}{}]'.format(
- "Parameter" if isinstance(p, Parameter) else "Tensor",
- p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu")
- child_lines.append(' (' + str(k) + '): ' + parastr)
- else:
- child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
-
- tmpstr = '\n'.join(child_lines)
- return tmpstr
-
- def __call__(self, *args, **kwargs):
- raise RuntimeError('ParameterList should not be called.')
\ No newline at end of file
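``get_submodule`` above walks a dotted path through nested modules, which is handy for inspecting or freezing one specific layer. A sketch of the intended usage, assuming the jtorch-provided base classes behave like their torch counterparts (``Block`` and ``Net`` are made-up classes)::

    import jittor as jt

    class Block(Module):
        def __init__(self):
            super().__init__()
            self.fc = jt.nn.Linear(8, 8)
        def forward(self, x):
            return self.fc(x)

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.block = Block()
        def forward(self, x):
            return self.block(x)

    net = Net()
    fc = net.get_submodule("block.fc")   # same object as net.block.fc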
diff --git a/python/jittor/compatibility/nn/init.py b/python/jittor/compatibility/nn/init.py
deleted file mode 100644
index 3b9f0907..00000000
--- a/python/jittor/compatibility/nn/init.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import jittor as jt
-
-for k,v in jt.nn.init.__dict__.items():
- if callable(v):
- globals()[k] = v
-
-
-normal = gauss
-normal_ = gauss_
-xavier_normal = xavier_gauss
-xavier_normal_ = xavier_gauss_
-zeros_ = zero_
-
-
-jt.Var.normal_ = normal_
-
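The aliases above expose Jittor's initializers under their torch-style names (``gauss_`` as ``normal_``, ``zero_`` as ``zeros_``, and so on). A short illustration, assuming the usual ``gauss_(var, mean, std)`` and ``zero_(var)`` signatures from ``jt.nn.init``::

    import jittor as jt

    layer = jt.nn.Linear(4, 4)
    normal_(layer.weight, 0.0, 0.02)   # forwards to jt.nn.init.gauss_
    zeros_(layer.bias)                 # forwards to jt.nn.init.zero_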
diff --git a/python/jittor/compatibility/nn/utils/__init__.py b/python/jittor/compatibility/nn/utils/__init__.py
deleted file mode 100644
index 83409f5f..00000000
--- a/python/jittor/compatibility/nn/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import rnn
\ No newline at end of file
diff --git a/python/jittor/compatibility/nn/utils/rnn.py b/python/jittor/compatibility/nn/utils/rnn.py
deleted file mode 100644
index b32da8c3..00000000
--- a/python/jittor/compatibility/nn/utils/rnn.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import jittor as jt
-
-PackedSequence = None
-
-def pad_sequence(sequences,batch_first=False,padding_value=0.0):
- max_f = max([len(s) for s in sequences])
- # max_f = 512
- b = len(sequences)
- if batch_first:
- ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value)
- for i,s in enumerate(sequences):
- ret[i,:len(s)] = s
- else:
- ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value)
- for i,s in enumerate(sequences):
- ret[:len(s),i] = s
- # print(ret.shape)
- # ret = ret[:,:406]
- return ret
-
\ No newline at end of file
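``pad_sequence`` right-pads a list of variable-length Vars to the longest one, returning either a (batch, time, ...) or (time, batch, ...) layout. A quick sketch of the expected shapes, assuming ``jt.Var.new_full`` is available as in recent Jittor releases (the numbers are illustrative)::

    import jittor as jt

    seqs = [jt.array([1., 2., 3.]), jt.array([4., 5.])]

    padded = pad_sequence(seqs, batch_first=True)   # shape [2, 3] -> [[1, 2, 3], [4, 5, 0]]
    padded_t = pad_sequence(seqs)                   # time-major, shape [3, 2]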
diff --git a/python/jittor/compatibility/optim.py b/python/jittor/compatibility/optim.py
deleted file mode 100644
index 2410917f..00000000
--- a/python/jittor/compatibility/optim.py
+++ /dev/null
@@ -1,1854 +0,0 @@
-import jittor as jt
-import math
-from jittor.optim import *
-from functools import partial
-
-class Optimizer(jt.optim.Optimizer):
- def pre_step(self, loss=None, retain_graph=False):
- jt.flags.node_order = 1
- params_has_grad = []
- for pg in self.param_groups:
- pg["grads"] = [ jt.zeros_like(p) if p.grad is None else p.grad#.float32()
- for p in pg["params"] ]
- for p in pg["params"]:
- if p.requires_grad:
- params_has_grad.append(p)
- jt.sync(params_has_grad)
- self.n_step += 1
-
- def zero_grad(self):
- for pg in self.param_groups:
- pg["grads"] = [ None for p in pg["params"] ]
- for p in pg["params"]: p.grad = None
-
- def post_step(self):
- jt.flags.node_order = 0
-
- def clip_grad_norm(self, max_norm:float, norm_type:int=2):
- r"""Clips gradient norm of this optimizer.
- The norm is computed over all gradients together.
-
- Args:
- max_norm (float or int): max norm of the gradients
- norm_type (int): 1-norm or 2-norm
-
- Example::
-
- a = jt.ones(2)
- opt = jt.optim.SGD([a], 0.1)
-
- loss = a*a
- opt.zero_grad()
- opt.backward(loss)
-
- print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
- opt.clip_grad_norm(0.01, 2)
- print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01
-
- opt.step()
-
- """
- self.pre_step(None)
- grads = []
- for pg in self.param_groups:
- for p, g in zip(pg["params"], pg["grads"]):
- if p.is_stop_grad(): continue
- grads.append(g.flatten())
- if len(grads) == 0: return
- total_norm = jt.norm(jt.concat(grads), norm_type)
- clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
- for pg in self.param_groups:
- for p, g in zip(pg["params"], pg["grads"]):
- if p.is_stop_grad(): continue
- g.update(g*clip_coef)
-
-
-class AdamW(Optimizer):
- def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0,use_fp32=True):
- print("lr:", lr)
- super().__init__(params, lr)
- self.eps = eps
- self.betas = betas
- self.weight_decay = weight_decay
-
- self.use_fp32 = use_fp32
- # assert weight_decay==0, "weight_decay is not supported yet"
-
- # initialize required arguments for each param_groups
- for pg in self.param_groups:
- values = pg["values"] = []
- m = pg["m"] = []
- mp = pg['masterparams'] = []
- for p in pg["params"]:
- values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad())
- m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad())
- if self.use_fp32:
- mp.append(p.detach().clone().stop_grad())
-
- def add_param_group(self, group):
- values = group["values"] = []
- m = group["m"] = []
- mp = group['masterparams'] = []
- for p in group["params"]:
- values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad())
- m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad())
- if self.use_fp32:
- mp.append(p.detach().clone().stop_grad())
- self.param_groups.append(group)
-
- def step(self, loss=None, retain_graph=False):
- self.pre_step(loss, retain_graph)
- if loss is None:
- self.n_step += 1
- n = float(self.n_step)
- for pg in self.param_groups:
- # get arguments from each param_groups
- lr = pg.get("lr", self.lr)
- eps = pg.get("eps", self.eps)
- weight_decay = pg.get("weight_decay", self.weight_decay)
- b0, b1 = pg.get("betas", self.betas)
-
- for p, g, v, m,mp in zip(pg["params"], pg["grads"], pg["values"], pg["m"],pg['masterparams']):
- if p.is_stop_grad(): continue
- #if g.abs().sum().item() < 1e-8: continue
- #import pdb; pdb.set_trace()
- c_p = (mp * (1 - lr * weight_decay))
- mp.update(c_p)
- if self.use_fp32:
- g = g.float32()
- bias_correction1 = 1 - b0 ** n
- bias_correction2 = 1 - b1 ** n
- m.update(b0 * m + (1-b0) * g) #exp_avg
- v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq
- denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
- step_size = lr / bias_correction1
- new_p = (mp - step_size * m / denom)
- mp.update(new_p)
- p.update(mp.cast(p.dtype))
- self.post_step()
-
-for k,v in jt.optim.__dict__.items():
- if k == "AdamW":continue
- if isinstance(v, type) and issubclass(v, jt.optim.Optimizer) and \
- not v is jt.optim.Optimizer:
- class OptimWrap(v, Optimizer):
- pass
- globals()[k] = OptimWrap
-
-
-class Adagrad(Optimizer):
- pass
-
-
-
-import types
-import math
-from functools import wraps
-import warnings
-import weakref
-from collections import Counter
-from bisect import bisect_right
-
-
-class LRScheduler:
-
- def __init__(self, optimizer, last_epoch=-1, verbose=False):
-
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
-
- # Initialize epoch and base learning rates
- if last_epoch == -1:
- for group in optimizer.param_groups:
- group.setdefault('initial_lr', group.get("lr",optimizer.lr))
- else:
- for i, group in enumerate(optimizer.param_groups):
- if 'initial_lr' not in group:
- raise KeyError("param 'initial_lr' is not specified "
- "in param_groups[{}] when resuming an optimizer".format(i))
- self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
- self.last_epoch = last_epoch
-
- # Following https://github.com/pytorch/pytorch/issues/20124
- # We would like to ensure that `lr_scheduler.step()` is called after
- # `optimizer.step()`
- def with_counter(method):
- if getattr(method, '_with_counter', False):
- # `optimizer.step()` has already been replaced, return.
- return method
-
- # Keep a weak reference to the optimizer instance to prevent
- # cyclic references.
- instance_ref = weakref.ref(method.__self__)
- # Get the unbound method for the same purpose.
- func = method.__func__
- cls = instance_ref().__class__
- del method
-
- @wraps(func)
- def wrapper(*args, **kwargs):
- instance = instance_ref()
- instance._step_count += 1
- wrapped = func.__get__(instance, cls)
- return wrapped(*args, **kwargs)
-
- # Note that the returned function here is no longer a bound method,
- # so attributes like `__func__` and `__self__` no longer exist.
- wrapper._with_counter = True
- return wrapper
-
- self.optimizer.step = with_counter(self.optimizer.step)
- self.verbose = verbose
-
- self._initial_step()
-
- def _initial_step(self):
- """Initialize step counts and performs a step"""
- self.optimizer._step_count = 0
- self._step_count = 0
- self.step()
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- """
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.__dict__.update(state_dict)
-
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- return self._last_lr
-
- def get_lr(self):
- # Compute learning rate using chainable form of the scheduler
- raise NotImplementedError
-
- def print_lr(self, is_verbose, group, lr, epoch=None):
- """Display the current learning rate.
- """
- if is_verbose:
- if epoch is None:
- print('Adjusting learning rate'
- ' of group {} to {:.4e}.'.format(group, lr))
- else:
- epoch_str = ("%.2f" if isinstance(epoch, float) else
- "%.5d") % epoch
- print('Epoch {}: adjusting learning rate'
- ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
-
-
- def step(self, epoch=None):
- # Raise a warning if old pattern is detected
- # https://github.com/pytorch/pytorch/issues/20124
- if self._step_count == 1:
- if not hasattr(self.optimizer.step, "_with_counter"):
- warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
- "initialization. Please, make sure to call `optimizer.step()` before "
- "`lr_scheduler.step()`. See more details at "
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
-
- # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
- elif self.optimizer._step_count < 1:
- warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
- "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
- "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
- "will result in PyTorch skipping the first value of the learning rate schedule. "
- "See more details at "
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
- self._step_count += 1
-
- with _enable_get_lr_call(self):
- if epoch is None:
- self.last_epoch += 1
- values = self.get_lr()
- else:
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
- self.last_epoch = epoch
- if hasattr(self, "_get_closed_form_lr"):
- values = self._get_closed_form_lr()
- else:
- values = self.get_lr()
-
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
- param_group, lr = data
- param_group['lr'] = lr
- self.print_lr(self.verbose, i, lr, epoch)
-
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
-# Including _LRScheduler for backwards compatibility
-# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
-class _LRScheduler(LRScheduler):
- pass
-
-
-class _enable_get_lr_call:
-
- def __init__(self, o):
- self.o = o
-
- def __enter__(self):
- self.o._get_lr_called_within_step = True
- return self
-
- def __exit__(self, type, value, traceback):
- self.o._get_lr_called_within_step = False
-
-
-class LambdaLR(LRScheduler):
- """Sets the learning rate of each parameter group to the initial lr
- times a given function. When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- lr_lambda (function or list): A function which computes a multiplicative
- factor given an integer parameter epoch, or a list of such
- functions, one for each group in optimizer.param_groups.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer has two groups.
- >>> lambda1 = lambda epoch: epoch // 30
- >>> lambda2 = lambda epoch: 0.95 ** epoch
- >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
- self.optimizer = optimizer
-
- if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
- self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
- else:
- if len(lr_lambda) != len(optimizer.param_groups):
- raise ValueError("Expected {} lr_lambdas, but got {}".format(
- len(optimizer.param_groups), len(lr_lambda)))
- self.lr_lambdas = list(lr_lambda)
- super().__init__(optimizer, last_epoch, verbose)
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The learning rate lambda functions will only be saved if they are callable objects
- and not if they are functions or lambdas.
-
- When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
- """
-
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
- state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
-
- for idx, fn in enumerate(self.lr_lambdas):
- if not isinstance(fn, types.FunctionType):
- state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
-
- return state_dict
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
-
- lr_lambdas = state_dict.pop('lr_lambdas')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['lr_lambdas'] = lr_lambdas
-
- for idx, fn in enumerate(lr_lambdas):
- if fn is not None:
- self.lr_lambdas[idx].__dict__.update(fn)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.")
- return [base_lr * lmbda(self.last_epoch)
- for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
-
-
-class MultiplicativeLR(LRScheduler):
- """Multiply the learning rate of each parameter group by the factor given
- in the specified function. When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- lr_lambda (function or list): A function which computes a multiplicative
- factor given an integer parameter epoch, or a list of such
- functions, one for each group in optimizer.param_groups.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> lmbda = lambda epoch: 0.95
- >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
- self.optimizer = optimizer
-
- if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
- self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
- else:
- if len(lr_lambda) != len(optimizer.param_groups):
- raise ValueError("Expected {} lr_lambdas, but got {}".format(
- len(optimizer.param_groups), len(lr_lambda)))
- self.lr_lambdas = list(lr_lambda)
- super().__init__(optimizer, last_epoch, verbose)
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The learning rate lambda functions will only be saved if they are callable objects
- and not if they are functions or lambdas.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
- state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
-
- for idx, fn in enumerate(self.lr_lambdas):
- if not isinstance(fn, types.FunctionType):
- state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
-
- return state_dict
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- lr_lambdas = state_dict.pop('lr_lambdas')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['lr_lambdas'] = lr_lambdas
-
- for idx, fn in enumerate(lr_lambdas):
- if fn is not None:
- self.lr_lambdas[idx].__dict__.update(fn)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch > 0:
- return [group['lr'] * lmbda(self.last_epoch)
- for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
- else:
- return [group['lr'] for group in self.optimizer.param_groups]
-
-
-class StepLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma every
- step_size epochs. Notice that such decay can happen simultaneously with
- other changes to the learning rate from outside this scheduler. When
- last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- step_size (int): Period of learning rate decay.
- gamma (float): Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.05 if epoch < 30
- >>> # lr = 0.005 if 30 <= epoch < 60
- >>> # lr = 0.0005 if 60 <= epoch < 90
- >>> # ...
- >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
- self.step_size = step_size
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma
- for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
- for base_lr in self.base_lrs]
-
-
-class MultiStepLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma once the
- number of epochs reaches one of the milestones. Notice that such decay can
- happen simultaneously with other changes to the learning rate from outside
- this scheduler. When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- milestones (list): List of epoch indices. Must be increasing.
- gamma (float): Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.05 if epoch < 30
- >>> # lr = 0.005 if 30 <= epoch < 80
- >>> # lr = 0.0005 if epoch >= 80
- >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
- self.milestones = Counter(milestones)
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch not in self.milestones:
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
- for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- milestones = sorted(self.milestones.elements())
- return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
- for base_lr in self.base_lrs]
-
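For reference, the closed form above reduces to multiplying the base lr by `gamma ** bisect_right(milestones, last_epoch)`. A minimal standalone sketch (plain Python, no optimizer involved; values taken from the docstring example):

```python
from bisect import bisect_right

base_lr, gamma, milestones = 0.05, 0.1, [30, 80]
for epoch in (0, 29, 30, 79, 80, 100):
    # the number of milestones already passed decides the exponent
    lr = base_lr * gamma ** bisect_right(milestones, epoch)
    print(epoch, lr)
# 0.05 before epoch 30, 0.005 for epochs 30-79, 0.0005 from epoch 80 on
```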
-
-class ConstantLR(LRScheduler):
- """Decays the learning rate of each parameter group by a small constant factor until the
- number of epochs reaches a pre-defined milestone: total_iters. Notice that such decay can
- happen simultaneously with other changes to the learning rate from outside this scheduler.
- When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
- total_iters (int): The number of steps that the scheduler decays the learning rate.
- Default: 5.
- last_epoch (int): The index of the last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.025 if epoch == 0
- >>> # lr = 0.025 if epoch == 1
- >>> # lr = 0.025 if epoch == 2
- >>> # lr = 0.025 if epoch == 3
- >>> # lr = 0.05 if epoch >= 4
- >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
- if factor > 1.0 or factor < 0:
- raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
-
- self.factor = factor
- self.total_iters = total_iters
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch == 0:
- return [group['lr'] * self.factor for group in self.optimizer.param_groups]
-
- if self.last_epoch != self.total_iters:
- return [group['lr'] for group in self.optimizer.param_groups]
-
- return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
- for base_lr in self.base_lrs]
-
-
-class LinearLR(LRScheduler):
- """Decays the learning rate of each parameter group by linearly changing small
- multiplicative factor until the number of epochs reaches a pre-defined milestone: total_iters.
- Notice that such decay can happen simultaneously with other changes to the learning rate
- from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- start_factor (float): The number we multiply learning rate in the first epoch.
- The multiplication factor changes towards end_factor in the following epochs.
- Default: 1./3.
- end_factor (float): The number we multiply learning rate at the end of linear changing
- process. Default: 1.0.
- total_iters (int): The number of iterations that multiplicative factor reaches to 1.
- Default: 5.
- last_epoch (int): The index of the last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.025 if epoch == 0
- >>> # lr = 0.03125 if epoch == 1
- >>> # lr = 0.0375 if epoch == 2
- >>> # lr = 0.04375 if epoch == 3
- >>> # lr = 0.05 if epoch >= 4
- >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
- verbose=False):
- if start_factor > 1.0 or start_factor <= 0:
- raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
-
- if end_factor > 1.0 or end_factor < 0:
- raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
-
- self.start_factor = start_factor
- self.end_factor = end_factor
- self.total_iters = total_iters
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch == 0:
- return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
-
- if self.last_epoch > self.total_iters:
- return [group['lr'] for group in self.optimizer.param_groups]
-
- return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
- (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
- for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [base_lr * (self.start_factor +
- (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
- for base_lr in self.base_lrs]
-
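A quick numeric check of `_get_closed_form_lr` above, using the values from the LinearLR docstring example (base lr 0.05, start_factor=0.5, total_iters=4); plain Python only, no optimizer involved:

```python
base_lr, start, end, total = 0.05, 0.5, 1.0, 4
for epoch in range(6):
    # linearly interpolate the multiplicative factor, then clamp at total_iters
    factor = start + (end - start) * min(total, epoch) / total
    print(epoch, round(base_lr * factor, 5))
# 0 -> 0.025, 1 -> 0.03125, 2 -> 0.0375, 3 -> 0.04375, 4 and later -> 0.05
```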
-
-class ExponentialLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma every epoch.
- When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- gamma (float): Multiplicative factor of learning rate decay.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- """
-
- def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch == 0:
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma
- for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [base_lr * self.gamma ** self.last_epoch
- for base_lr in self.base_lrs]
-
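ExponentialLR carries no docstring example, so here is a short illustrative trace; the closed form is simply `base_lr * gamma ** epoch` (the numbers below are assumptions for illustration, not values from the class above):

```python
base_lr, gamma = 0.1, 0.9
lrs = [base_lr * gamma ** epoch for epoch in range(5)]
print([round(lr, 5) for lr in lrs])  # [0.1, 0.09, 0.081, 0.0729, 0.06561]
```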
-
-class SequentialLR(LRScheduler):
- """Receives the list of schedulers that is expected to be called sequentially during
- optimization process and milestone points that provides exact intervals to reflect
- which scheduler is supposed to be called at a given epoch.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- schedulers (list): List of chained schedulers.
- milestones (list): List of integers that reflects milestone points.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): Does nothing.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 1. for all groups
- >>> # lr = 0.1 if epoch == 0
- >>> # lr = 0.1 if epoch == 1
- >>> # lr = 0.9 if epoch == 2
- >>> # lr = 0.81 if epoch == 3
- >>> # lr = 0.729 if epoch == 4
- >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
- >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
- >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
- for scheduler_idx in range(len(schedulers)):
- if schedulers[scheduler_idx].optimizer != optimizer:
- raise ValueError(
- "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
- f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
- )
-
- if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
- raise ValueError(
- "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
- f"got schedulers at index {0} and {scheduler_idx} to be different."
- )
- if (len(milestones) != len(schedulers) - 1):
- raise ValueError(
- "Sequential Schedulers expects number of schedulers provided to be one more "
- "than the number of milestone points, but got number of schedulers {} and the "
- "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
- )
- self._schedulers = schedulers
- self._milestones = milestones
- self.last_epoch = last_epoch + 1
- self.optimizer = optimizer
-
- # Reset learning rates back to initial values
- for group in self.optimizer.param_groups:
- group["lr"] = group["initial_lr"]
-
- # "Undo" the step performed by other schedulers
- for scheduler in self._schedulers:
- scheduler.last_epoch -= 1
-
- # Perform the initial step for only the first scheduler
- self._schedulers[0]._initial_step()
-
- self._last_lr = schedulers[0].get_last_lr()
-
- def step(self):
- self.last_epoch += 1
- idx = bisect_right(self._milestones, self.last_epoch)
- scheduler = self._schedulers[idx]
- if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
- scheduler.step(0)
- else:
- scheduler.step()
-
- self._last_lr = scheduler.get_last_lr()
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The wrapped scheduler states will also be saved.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
- state_dict['_schedulers'] = [None] * len(self._schedulers)
-
- for idx, s in enumerate(self._schedulers):
- state_dict['_schedulers'][idx] = s.state_dict()
-
- return state_dict
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- _schedulers = state_dict.pop('_schedulers')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['_schedulers'] = _schedulers
-
- for idx, s in enumerate(_schedulers):
- self._schedulers[idx].load_state_dict(s)
-
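The scheduler selection in `SequentialLR.step` above is a plain `bisect_right` over the milestone list. A standalone sketch of just that lookup (the scheduler entries here are descriptive labels, not real instances):

```python
from bisect import bisect_right

milestones = [2]   # switch schedulers after 2 steps, as in the docstring example
schedulers = ["ConstantLR(factor=0.1, total_iters=2)", "ExponentialLR(gamma=0.9)"]
for last_epoch in range(5):
    idx = bisect_right(milestones, last_epoch)
    print(last_epoch, "->", schedulers[idx])
# epochs 0-1 use the first scheduler, epochs 2 and later use the second
```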
-
-class PolynomialLR(LRScheduler):
- """Decays the learning rate of each parameter group using a polynomial function
- in the given total_iters. When last_epoch=-1, sets initial lr as lr.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
- power (float): The power of the polynomial. Default: 1.0.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP("undefined vars")
- >>> # Assuming optimizer uses lr = 0.001 for all groups
- >>> # lr = 0.001 if epoch == 0
- >>> # lr = 0.00075 if epoch == 1
- >>> # lr = 0.00050 if epoch == 2
- >>> # lr = 0.00025 if epoch == 3
- >>> # lr = 0.0 if epoch >= 4
- >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
- self.total_iters = total_iters
- self.power = power
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch == 0 or self.last_epoch > self.total_iters:
- return [group["lr"] for group in self.optimizer.param_groups]
-
- decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
- return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [
- (
- base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
- )
- for base_lr in self.base_lrs
- ]
-
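A small check that the recursive `decay_factor` update in `get_lr` above reproduces `_get_closed_form_lr`, using the values from the PolynomialLR docstring example (base lr 0.001, total_iters=4, power=1.0):

```python
base_lr, total_iters, power = 0.001, 4, 1.0

lr = base_lr
for epoch in range(1, 5):
    decay = ((1.0 - epoch / total_iters) / (1.0 - (epoch - 1) / total_iters)) ** power
    lr *= decay                                   # recursive (chainable) form
    closed = base_lr * (1.0 - min(total_iters, epoch) / total_iters) ** power
    print(epoch, round(lr, 6), round(closed, 6))  # both columns agree
# 0.00075, 0.0005, 0.00025, 0.0 for epochs 1-4
```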
-
-class CosineAnnealingLR(LRScheduler):
- r"""Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr and
- :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-
- .. math::
- \begin{aligned}
- \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
- + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
- & T_{cur} \neq (2k+1)T_{max}; \\
- \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
- \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
- & T_{cur} = (2k+1)T_{max}.
- \end{aligned}
-
- When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
- is defined recursively, the learning rate can be simultaneously modified
- outside this scheduler by other operators. If the learning rate is set
- solely by this scheduler, the learning rate at each step becomes:
-
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
- implements the cosine annealing part of SGDR, and not the restarts.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- T_max (int): Maximum number of iterations.
- eta_min (float): Minimum learning rate. Default: 0.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
- """
-
- def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
- self.T_max = T_max
- self.eta_min = eta_min
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- if self.last_epoch == 0:
- return [group['lr'] for group in self.optimizer.param_groups]
- elif self._step_count == 1 and self.last_epoch > 0:
- return [self.eta_min + (base_lr - self.eta_min) *
- (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
- for base_lr, group in
- zip(self.base_lrs, self.optimizer.param_groups)]
- elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
- return [group['lr'] + (base_lr - self.eta_min) *
- (1 - math.cos(math.pi / self.T_max)) / 2
- for base_lr, group in
- zip(self.base_lrs, self.optimizer.param_groups)]
- return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
- (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
- (group['lr'] - self.eta_min) + self.eta_min
- for group in self.optimizer.param_groups]
-
- def _get_closed_form_lr(self):
- return [self.eta_min + (base_lr - self.eta_min) *
- (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
- for base_lr in self.base_lrs]
-
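For intuition, the closed form above starts at the base lr, passes through the midpoint at T_max/2, and reaches eta_min at T_max. A standalone trace with illustrative values (base lr 0.1, eta_min=0, T_max=10):

```python
import math

base_lr, eta_min, T_max = 0.1, 0.0, 10
for epoch in (0, 5, 10):
    lr = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
    print(epoch, round(lr, 6))
# 0 -> 0.1 (peak), 5 -> 0.05 (halfway), 10 -> 0.0 (eta_min)
```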
-
-class ChainedScheduler(LRScheduler):
- """Chains list of learning rate schedulers. It takes a list of chainable learning
- rate schedulers and performs consecutive step() functions belonging to them by just
- one call.
-
- Args:
- schedulers (list): List of chained schedulers.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 1. for all groups
- >>> # lr = 0.09 if epoch == 0
- >>> # lr = 0.081 if epoch == 1
- >>> # lr = 0.729 if epoch == 2
- >>> # lr = 0.6561 if epoch == 3
- >>> # lr = 0.59049 if epoch >= 4
- >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
- >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
- >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
-
- def __init__(self, schedulers):
- for scheduler_idx in range(1, len(schedulers)):
- if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
- raise ValueError(
- "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
- "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
- )
- self._schedulers = list(schedulers)
- self.optimizer = schedulers[0].optimizer
- self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
-
- def step(self):
- for scheduler in self._schedulers:
- scheduler.step()
- self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The wrapped scheduler states will also be saved.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
- state_dict['_schedulers'] = [None] * len(self._schedulers)
-
- for idx, s in enumerate(self._schedulers):
- state_dict['_schedulers'][idx] = s.state_dict()
-
- return state_dict
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- _schedulers = state_dict.pop('_schedulers')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['_schedulers'] = _schedulers
-
- for idx, s in enumerate(_schedulers):
- self._schedulers[idx].load_state_dict(s)
-
-
-class ReduceLROnPlateau:
- """Reduce learning rate when a metric has stopped improving.
- Models often benefit from reducing the learning rate by a factor
- of 2-10 once learning stagnates. This scheduler reads a metric
- quantity and, if no improvement is seen for a 'patience' number
- of epochs, the learning rate is reduced.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- mode (str): One of `min`, `max`. In `min` mode, lr will
- be reduced when the quantity monitored has stopped
- decreasing; in `max` mode it will be reduced when the
- quantity monitored has stopped increasing. Default: 'min'.
- factor (float): Factor by which the learning rate will be
- reduced. new_lr = lr * factor. Default: 0.1.
- patience (int): Number of epochs with no improvement after
- which learning rate will be reduced. For example, if
- `patience = 2`, then we will ignore the first 2 epochs
- with no improvement, and will only decrease the LR after the
- 3rd epoch if the loss still hasn't improved then.
- Default: 10.
- threshold (float): Threshold for measuring the new optimum,
- to only focus on significant changes. Default: 1e-4.
- threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
- dynamic_threshold = best * ( 1 + threshold ) in 'max'
- mode or best * ( 1 - threshold ) in `min` mode.
- In `abs` mode, dynamic_threshold = best + threshold in
- `max` mode or best - threshold in `min` mode. Default: 'rel'.
- cooldown (int): Number of epochs to wait before resuming
- normal operation after lr has been reduced. Default: 0.
- min_lr (float or list): A scalar or a list of scalars. A
- lower bound on the learning rate of all param groups
- or each group respectively. Default: 0.
- eps (float): Minimal decay applied to lr. If the difference
- between new and old lr is smaller than eps, the update is
- ignored. Default: 1e-8.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
- >>> for epoch in range(10):
- >>> train(...)
- >>> val_loss = validate(...)
- >>> # Note that step should be called after validate()
- >>> scheduler.step(val_loss)
- """
-
- def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
- threshold=1e-4, threshold_mode='rel', cooldown=0,
- min_lr=0, eps=1e-8, verbose=False):
-
- if factor >= 1.0:
- raise ValueError('Factor should be < 1.0.')
- self.factor = factor
-
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
-
- if isinstance(min_lr, (list, tuple)):
- if len(min_lr) != len(optimizer.param_groups):
- raise ValueError("expected {} min_lrs, got {}".format(
- len(optimizer.param_groups), len(min_lr)))
- self.min_lrs = list(min_lr)
- else:
- self.min_lrs = [min_lr] * len(optimizer.param_groups)
-
- self.patience = patience
- self.verbose = verbose
- self.cooldown = cooldown
- self.cooldown_counter = 0
- self.mode = mode
- self.threshold = threshold
- self.threshold_mode = threshold_mode
- self.best = None
- self.num_bad_epochs = None
- self.mode_worse = None # the worse value for the chosen mode
- self.eps = eps
- self.last_epoch = 0
- self._init_is_better(mode=mode, threshold=threshold,
- threshold_mode=threshold_mode)
- self._reset()
-
- def _reset(self):
- """Resets num_bad_epochs counter and cooldown counter."""
- self.best = self.mode_worse
- self.cooldown_counter = 0
- self.num_bad_epochs = 0
-
- def step(self, metrics, epoch=None):
- # convert `metrics` to float, in case it's a zero-dim Tensor
- current = float(metrics)
- if epoch is None:
- epoch = self.last_epoch + 1
- else:
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
- self.last_epoch = epoch
-
- if self.is_better(current, self.best):
- self.best = current
- self.num_bad_epochs = 0
- else:
- self.num_bad_epochs += 1
-
- if self.in_cooldown:
- self.cooldown_counter -= 1
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
-
- if self.num_bad_epochs > self.patience:
- self._reduce_lr(epoch)
- self.cooldown_counter = self.cooldown
- self.num_bad_epochs = 0
-
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
- def _reduce_lr(self, epoch):
- for i, param_group in enumerate(self.optimizer.param_groups):
- old_lr = float(param_group['lr'])
- new_lr = max(old_lr * self.factor, self.min_lrs[i])
- if old_lr - new_lr > self.eps:
- param_group['lr'] = new_lr
- if self.verbose:
- epoch_str = ("%.2f" if isinstance(epoch, float) else
- "%.5d") % epoch
- print('Epoch {}: reducing learning rate'
- ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
-
- @property
- def in_cooldown(self):
- return self.cooldown_counter > 0
-
- def is_better(self, a, best):
- if self.mode == 'min' and self.threshold_mode == 'rel':
- rel_epsilon = 1. - self.threshold
- return a < best * rel_epsilon
-
- elif self.mode == 'min' and self.threshold_mode == 'abs':
- return a < best - self.threshold
-
- elif self.mode == 'max' and self.threshold_mode == 'rel':
- rel_epsilon = self.threshold + 1.
- return a > best * rel_epsilon
-
- else: # mode == 'max' and threshold_mode == 'abs':
- return a > best + self.threshold
-
- def _init_is_better(self, mode, threshold, threshold_mode):
- if mode not in {'min', 'max'}:
- raise ValueError('mode ' + mode + ' is unknown!')
- if threshold_mode not in {'rel', 'abs'}:
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
-
- if mode == 'min':
- self.mode_worse = inf
- else: # mode == 'max':
- self.mode_worse = -inf
-
- self.mode = mode
- self.threshold = threshold
- self.threshold_mode = threshold_mode
-
- def state_dict(self):
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
-
- def load_state_dict(self, state_dict):
- self.__dict__.update(state_dict)
- self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
-
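The `rel`/`abs` threshold modes above boil down to a one-line comparison. Here is a standalone restatement of the `min`-mode branches of `is_better`, with a tiny worked example (values are illustrative):

```python
def is_better(current, best, threshold=1e-4, threshold_mode='rel'):
    # mirrors the mode='min' branches of ReduceLROnPlateau.is_better above
    if threshold_mode == 'rel':
        return current < best * (1.0 - threshold)   # relative improvement
    return current < best - threshold               # absolute improvement

best = 1.0
print(is_better(0.99990, best))  # False: not below 1.0 * (1 - 1e-4) = 0.9999
print(is_better(0.99980, best))  # True: counts as a significant improvement
```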
-
-class CyclicLR(LRScheduler):
- r"""Sets the learning rate of each parameter group according to
- cyclical learning rate policy (CLR). The policy cycles the learning
- rate between two boundaries with a constant frequency, as detailed in
- the paper `Cyclical Learning Rates for Training Neural Networks`_.
- The distance between the two boundaries can be scaled on a per-iteration
- or per-cycle basis.
-
- Cyclical learning rate policy changes the learning rate after every batch.
- `step` should be called after a batch has been used for training.
-
- This class has three built-in policies, as put forth in the paper:
-
- * "triangular": A basic triangular cycle without amplitude scaling.
- * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
- * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
- at each cycle iteration.
-
- This implementation was adapted from the github repo: `bckenstler/CLR`_
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- base_lr (float or list): Initial learning rate which is the
- lower boundary in the cycle for each parameter group.
- max_lr (float or list): Upper learning rate boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (max_lr - base_lr).
- The lr at any cycle is the sum of base_lr
- and some scaling of the amplitude; therefore
- max_lr may not actually be reached depending on
- scaling function.
- step_size_up (int): Number of training iterations in the
- increasing half of a cycle. Default: 2000
- step_size_down (int): Number of training iterations in the
- decreasing half of a cycle. If step_size_down is None,
- it is set to step_size_up. Default: None
- mode (str): One of {triangular, triangular2, exp_range}.
- Values correspond to policies detailed above.
- If scale_fn is not None, this argument is ignored.
- Default: 'triangular'
- gamma (float): Constant in 'exp_range' scaling function:
- gamma**(cycle iterations)
- Default: 1.0
- scale_fn (function): Custom scaling policy defined by a single
- argument lambda function, where
- 0 <= scale_fn(x) <= 1 for all x >= 0.
- If specified, then 'mode' is ignored.
- Default: None
- scale_mode (str): {'cycle', 'iterations'}.
- Defines whether scale_fn is evaluated on
- cycle number or cycle iterations (training
- iterations since start of cycle).
- Default: 'cycle'
- cycle_momentum (bool): If ``True``, momentum is cycled inversely
- to learning rate between 'base_momentum' and 'max_momentum'.
- Default: True
- base_momentum (float or list): Lower momentum boundaries in the cycle
- for each parameter group. Note that momentum is cycled inversely
- to learning rate; at the peak of a cycle, momentum is
- 'base_momentum' and learning rate is 'max_lr'.
- Default: 0.8
- max_momentum (float or list): Upper momentum boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (max_momentum - base_momentum).
- The momentum at any cycle is the difference of max_momentum
- and some scaling of the amplitude; therefore
- base_momentum may not actually be reached depending on
- scaling function. Note that momentum is cycled inversely
- to learning rate; at the start of a cycle, momentum is 'max_momentum'
- and learning rate is 'base_lr'
- Default: 0.9
- last_epoch (int): The index of the last batch. This parameter is used when
- resuming a training job. Since `step()` should be invoked after each
- batch instead of after each epoch, this number represents the total
- number of *batches* computed, not the total number of epochs computed.
- When last_epoch=-1, the schedule is started from the beginning.
- Default: -1
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
-
-
- .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
- .. _bckenstler/CLR: https://github.com/bckenstler/CLR
- """
-
- def __init__(self,
- optimizer,
- base_lr,
- max_lr,
- step_size_up=2000,
- step_size_down=None,
- mode='triangular',
- gamma=1.,
- scale_fn=None,
- scale_mode='cycle',
- cycle_momentum=True,
- base_momentum=0.8,
- max_momentum=0.9,
- last_epoch=-1,
- verbose=False):
-
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
-
- base_lrs = self._format_param('base_lr', optimizer, base_lr)
- if last_epoch == -1:
- for lr, group in zip(base_lrs, optimizer.param_groups):
- group['lr'] = lr
-
- self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
-
- step_size_up = float(step_size_up)
- step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
- self.total_size = step_size_up + step_size_down
- self.step_ratio = step_size_up / self.total_size
-
- if mode not in ['triangular', 'triangular2', 'exp_range'] \
- and scale_fn is None:
- raise ValueError('mode is invalid and scale_fn is None')
-
- self.mode = mode
- self.gamma = gamma
-
- self._scale_fn_ref = None
- self._scale_fn_custom = scale_fn
- self.scale_mode = scale_mode
- self._init_scale_fn()
-
- self.cycle_momentum = cycle_momentum
- if cycle_momentum:
- if 'momentum' not in optimizer.defaults:
- raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
-
- base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
- if last_epoch == -1:
- for momentum, group in zip(base_momentums, optimizer.param_groups):
- group['momentum'] = momentum
- self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
- self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
-
- super().__init__(optimizer, last_epoch, verbose)
- self.base_lrs = base_lrs
-
- def _init_scale_fn(self):
- if self._scale_fn_custom is not None:
- return
- if self.mode == 'triangular':
- self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn)
- self.scale_mode = 'cycle'
- elif self.mode == 'triangular2':
- self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn)
- self.scale_mode = 'cycle'
- elif self.mode == 'exp_range':
- self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn)
- self.scale_mode = 'iterations'
-
- def _format_param(self, name, optimizer, param):
- """Return correctly formatted lr/momentum for each param group."""
- if isinstance(param, (list, tuple)):
- if len(param) != len(optimizer.param_groups):
- raise ValueError("expected {} values for {}, got {}".format(
- len(optimizer.param_groups), name, len(param)))
- return param
- else:
- return [param] * len(optimizer.param_groups)
-
- def scale_fn(self, x):
- if self._scale_fn_custom is not None:
- return self._scale_fn_custom(x)
-
- else:
- return self._scale_fn_ref()(x)
-
- def _triangular_scale_fn(self, x):
- return 1.
-
- def _triangular2_scale_fn(self, x):
- return 1 / (2. ** (x - 1))
-
- def _exp_range_scale_fn(self, x):
- return self.gamma**(x)
-
- def get_lr(self):
- """Calculates the learning rate at batch index. This function treats
- `self.last_epoch` as the last batch index.
-
- If `self.cycle_momentum` is ``True``, this function has a side effect of
- updating the optimizer's momentum.
- """
-
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- cycle = math.floor(1 + self.last_epoch / self.total_size)
- x = 1. + self.last_epoch / self.total_size - cycle
- if x <= self.step_ratio:
- scale_factor = x / self.step_ratio
- else:
- scale_factor = (x - 1) / (self.step_ratio - 1)
-
- lrs = []
- for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
- base_height = (max_lr - base_lr) * scale_factor
- if self.scale_mode == 'cycle':
- lr = base_lr + base_height * self.scale_fn(cycle)
- else:
- lr = base_lr + base_height * self.scale_fn(self.last_epoch)
- lrs.append(lr)
-
- if self.cycle_momentum:
- momentums = []
- for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
- base_height = (max_momentum - base_momentum) * scale_factor
- if self.scale_mode == 'cycle':
- momentum = max_momentum - base_height * self.scale_fn(cycle)
- else:
- momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
- momentums.append(momentum)
- for param_group, momentum in zip(self.optimizer.param_groups, momentums):
- param_group['momentum'] = momentum
-
- return lrs
-
- def state_dict(self):
- state = super().state_dict()
- # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled
- state.pop("_scale_fn_ref")
- return state
-
- def load_state_dict(self, state_dict):
- super().load_state_dict(state_dict)
- self._init_scale_fn()
-
-
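The `cycle` / `x` / `scale_factor` arithmetic in `get_lr` above produces a triangular wave between base_lr and max_lr. A standalone trace of the plain 'triangular' mode (scale_fn returns 1), with illustrative step sizes:

```python
import math

base_lr, max_lr = 0.01, 0.1
step_up = step_down = 4                  # illustrative; the class defaults to 2000
total_size = step_up + step_down
step_ratio = step_up / total_size

for it in range(9):
    cycle = math.floor(1 + it / total_size)
    x = 1.0 + it / total_size - cycle
    scale = x / step_ratio if x <= step_ratio else (x - 1) / (step_ratio - 1)
    print(it, round(base_lr + (max_lr - base_lr) * scale, 4))
# rises 0.01 -> 0.1 over 4 steps, then falls back to 0.01 over the next 4
```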
-
-class CosineAnnealingWarmRestarts(LRScheduler):
- r"""Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
- is the number of epochs since the last restart and :math:`T_{i}` is the number
- of epochs between two warm restarts in SGDR:
-
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-
- When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
- When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- T_0 (int): Number of iterations for the first restart.
- T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
- eta_min (float, optional): Minimum learning rate. Default: 0.
- last_epoch (int, optional): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
- """
-
- def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
- if T_0 <= 0 or not isinstance(T_0, int):
- raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
- if T_mult < 1 or not isinstance(T_mult, int):
- raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
- self.T_0 = T_0
- self.T_i = T_0
- self.T_mult = T_mult
- self.eta_min = eta_min
- self.T_cur = last_epoch
- super().__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
- for base_lr in self.base_lrs]
-
- def step(self, epoch=None):
- """Step could be called after every batch update
-
- Example:
- >>> # xdoctest: +SKIP("Undefined vars")
- >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
- >>> iters = len(dataloader)
- >>> for epoch in range(20):
- >>> for i, sample in enumerate(dataloader):
- >>> inputs, labels = sample['inputs'], sample['labels']
- >>> optimizer.zero_grad()
- >>> outputs = net(inputs)
- >>> loss = criterion(outputs, labels)
- >>> loss.backward()
- >>> optimizer.step()
- >>> scheduler.step(epoch + i / iters)
-
- This function can be called in an interleaved way.
-
- Example:
- >>> # xdoctest: +SKIP("Undefined vars")
- >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
- >>> for epoch in range(20):
- >>> scheduler.step()
- >>> scheduler.step(26)
- >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
- """
-
- if epoch is None and self.last_epoch < 0:
- epoch = 0
-
- if epoch is None:
- epoch = self.last_epoch + 1
- self.T_cur = self.T_cur + 1
- if self.T_cur >= self.T_i:
- self.T_cur = self.T_cur - self.T_i
- self.T_i = self.T_i * self.T_mult
- else:
- if epoch < 0:
- raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
- if epoch >= self.T_0:
- if self.T_mult == 1:
- self.T_cur = epoch % self.T_0
- else:
- n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
- self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
- self.T_i = self.T_0 * self.T_mult ** (n)
- else:
- self.T_i = self.T_0
- self.T_cur = epoch
- self.last_epoch = math.floor(epoch)
-
- class _enable_get_lr_call:
-
- def __init__(self, o):
- self.o = o
-
- def __enter__(self):
- self.o._get_lr_called_within_step = True
- return self
-
- def __exit__(self, type, value, traceback):
- self.o._get_lr_called_within_step = False
-
- with _enable_get_lr_call(self):
- for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
- param_group, lr = data
- param_group['lr'] = lr
- self.print_lr(self.verbose, i, lr, epoch)
-
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-
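To see how `T_i` and `T_cur` evolve in the `epoch=None` branch of `step()` above, here is a standalone trace of just that bookkeeping (T_0=2 and T_mult=2 are illustrative values):

```python
T_0, T_mult = 2, 2
T_i, T_cur = T_0, 0
for epoch in range(15):
    print(epoch, "T_i =", T_i, "T_cur =", T_cur, "(restart)" if T_cur == 0 else "")
    T_cur += 1                # the same update step() performs when epoch is None
    if T_cur >= T_i:
        T_cur -= T_i
        T_i *= T_mult
# T_cur returns to 0 after 2, then 4, then 8 epochs: each period is T_mult times longer
```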
-
-class OneCycleLR(LRScheduler):
- r"""Sets the learning rate of each parameter group according to the
- 1cycle learning rate policy. The 1cycle policy anneals the learning
- rate from an initial learning rate to some maximum learning rate and then
- from that maximum learning rate to some minimum learning rate much lower
- than the initial learning rate.
- This policy was initially described in the paper `Super-Convergence:
- Very Fast Training of Neural Networks Using Large Learning Rates`_.
-
- The 1cycle learning rate policy changes the learning rate after every batch.
- `step` should be called after a batch has been used for training.
-
- This scheduler is not chainable.
-
- Note also that the total number of steps in the cycle can be determined in one
- of two ways (listed in order of precedence):
-
- #. A value for total_steps is explicitly provided.
- #. A number of epochs (epochs) and a number of steps per epoch
- (steps_per_epoch) are provided.
- In this case, the number of total steps is inferred by
- total_steps = epochs * steps_per_epoch
-
- You must either provide a value for total_steps or provide a value for both
- epochs and steps_per_epoch.
-
- The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
- claims that "unpublished work has shown even better results by using only two phases". To
- mimic the behaviour of the original paper instead, set ``three_phase=True``.
-
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- max_lr (float or list): Upper learning rate boundaries in the cycle
- for each parameter group.
- total_steps (int): The total number of steps in the cycle. Note that
- if a value is not provided here, then it must be inferred by providing
- a value for epochs and steps_per_epoch.
- Default: None
- epochs (int): The number of epochs to train for. This is used along
- with steps_per_epoch in order to infer the total number of steps in the cycle
- if a value for total_steps is not provided.
- Default: None
- steps_per_epoch (int): The number of steps per epoch to train for. This is
- used along with epochs in order to infer the total number of steps in the
- cycle if a value for total_steps is not provided.
- Default: None
- pct_start (float): The percentage of the cycle (in number of steps) spent
- increasing the learning rate.
- Default: 0.3
- anneal_strategy (str): {'cos', 'linear'}
- Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
- linear annealing.
- Default: 'cos'
- cycle_momentum (bool): If ``True``, momentum is cycled inversely
- to learning rate between 'base_momentum' and 'max_momentum'.
- Default: True
- base_momentum (float or list): Lower momentum boundaries in the cycle
- for each parameter group. Note that momentum is cycled inversely
- to learning rate; at the peak of a cycle, momentum is
- 'base_momentum' and learning rate is 'max_lr'.
- Default: 0.85
- max_momentum (float or list): Upper momentum boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (max_momentum - base_momentum).
- Note that momentum is cycled inversely
- to learning rate; at the start of a cycle, momentum is 'max_momentum'
- and learning rate is 'base_lr'
- Default: 0.95
- div_factor (float): Determines the initial learning rate via
- initial_lr = max_lr/div_factor
- Default: 25
- final_div_factor (float): Determines the minimum learning rate via
- min_lr = initial_lr/final_div_factor
- Default: 1e4
- three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
- learning rate according to 'final_div_factor' instead of modifying the second
- phase (the first two phases will be symmetrical about the step indicated by
- 'pct_start').
- last_epoch (int): The index of the last batch. This parameter is used when
- resuming a training job. Since `step()` should be invoked after each
- batch instead of after each epoch, this number represents the total
- number of *batches* computed, not the total number of epochs computed.
- When last_epoch=-1, the schedule is started from the beginning.
- Default: -1
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> optimizer.step()
- >>> scheduler.step()
-
-
- .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
- https://arxiv.org/abs/1708.07120
- """
- def __init__(self,
- optimizer,
- max_lr,
- total_steps=None,
- epochs=None,
- steps_per_epoch=None,
- pct_start=0.3,
- anneal_strategy='cos',
- cycle_momentum=True,
- base_momentum=0.85,
- max_momentum=0.95,
- div_factor=25.,
- final_div_factor=1e4,
- three_phase=False,
- last_epoch=-1,
- verbose=False):
-
- # Validate optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
-
- # Validate total_steps
- if total_steps is None and epochs is None and steps_per_epoch is None:
- raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
- elif total_steps is not None:
- if total_steps <= 0 or not isinstance(total_steps, int):
- raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
- self.total_steps = total_steps
- else:
- if epochs <= 0 or not isinstance(epochs, int):
- raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
- if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
- raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
- self.total_steps = epochs * steps_per_epoch
-
- if three_phase:
- self._schedule_phases = [
- {
- 'end_step': float(pct_start * self.total_steps) - 1,
- 'start_lr': 'initial_lr',
- 'end_lr': 'max_lr',
- 'start_momentum': 'max_momentum',
- 'end_momentum': 'base_momentum',
- },
- {
- 'end_step': float(2 * pct_start * self.total_steps) - 2,
- 'start_lr': 'max_lr',
- 'end_lr': 'initial_lr',
- 'start_momentum': 'base_momentum',
- 'end_momentum': 'max_momentum',
- },
- {
- 'end_step': self.total_steps - 1,
- 'start_lr': 'initial_lr',
- 'end_lr': 'min_lr',
- 'start_momentum': 'max_momentum',
- 'end_momentum': 'max_momentum',
- },
- ]
- else:
- self._schedule_phases = [
- {
- 'end_step': float(pct_start * self.total_steps) - 1,
- 'start_lr': 'initial_lr',
- 'end_lr': 'max_lr',
- 'start_momentum': 'max_momentum',
- 'end_momentum': 'base_momentum',
- },
- {
- 'end_step': self.total_steps - 1,
- 'start_lr': 'max_lr',
- 'end_lr': 'min_lr',
- 'start_momentum': 'base_momentum',
- 'end_momentum': 'max_momentum',
- },
- ]
-
- # Validate pct_start
- if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
- raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
-
- # Validate anneal_strategy
- if anneal_strategy not in ['cos', 'linear']:
- raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
- elif anneal_strategy == 'cos':
- self.anneal_func = self._annealing_cos
- elif anneal_strategy == 'linear':
- self.anneal_func = self._annealing_linear
-
- # Initialize learning rate variables
- max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
- if last_epoch == -1:
- for idx, group in enumerate(self.optimizer.param_groups):
- group['initial_lr'] = max_lrs[idx] / div_factor
- group['max_lr'] = max_lrs[idx]
- group['min_lr'] = group['initial_lr'] / final_div_factor
-
- # Initialize momentum variables
- self.cycle_momentum = cycle_momentum
- if self.cycle_momentum:
- if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
- raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
- self.use_beta1 = 'betas' in self.optimizer.defaults
- max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
- base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
- if last_epoch == -1:
- for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
- if self.use_beta1:
- group['betas'] = (m_momentum, *group['betas'][1:])
- else:
- group['momentum'] = m_momentum
- group['max_momentum'] = m_momentum
- group['base_momentum'] = b_momentum
-
- super().__init__(optimizer, last_epoch, verbose)
-
- def _format_param(self, name, optimizer, param):
- """Return correctly formatted lr/momentum for each param group."""
- if isinstance(param, (list, tuple)):
- if len(param) != len(optimizer.param_groups):
- raise ValueError("expected {} values for {}, got {}".format(
- len(optimizer.param_groups), name, len(param)))
- return param
- else:
- return [param] * len(optimizer.param_groups)
-
- def _annealing_cos(self, start, end, pct):
- "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
- cos_out = math.cos(math.pi * pct) + 1
- return end + (start - end) / 2.0 * cos_out
-
- def _annealing_linear(self, start, end, pct):
- "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
- return (end - start) * pct + start
-
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
-
- lrs = []
- step_num = self.last_epoch
-
- if step_num > self.total_steps:
- raise ValueError("Tried to step {} times. The specified number of total steps is {}"
- .format(step_num, self.total_steps))
-
- for group in self.optimizer.param_groups:
- start_step = 0
- for i, phase in enumerate(self._schedule_phases):
- end_step = phase['end_step']
- if step_num <= end_step or i == len(self._schedule_phases) - 1:
- pct = (step_num - start_step) / (end_step - start_step)
- computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
- if self.cycle_momentum:
- computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
- break
- start_step = phase['end_step']
-
- lrs.append(computed_lr)
- if self.cycle_momentum:
- if self.use_beta1:
- group['betas'] = (computed_momentum, *group['betas'][1:])
- else:
- group['momentum'] = computed_momentum
-
- return lrs
\ No newline at end of file
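
The `_annealing_cos` / `_annealing_linear` helpers removed above interpolate a value between a phase's start and end as `pct` runs from 0 to 1. A minimal standalone sketch of that interpolation (the `max_lr`/`min_lr` numbers below are made up for illustration):

```python
import math

def annealing_cos(start, end, pct):
    # Cosine interpolation: pct=0 returns start, pct=1 returns end.
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2.0 * cos_out

def annealing_linear(start, end, pct):
    # Straight-line interpolation between start and end.
    return (end - start) * pct + start

max_lr, min_lr = 0.1, 1e-4  # hypothetical phase endpoints
for pct in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"pct={pct:.2f}  cos={annealing_cos(max_lr, min_lr, pct):.5f}  "
          f"linear={annealing_linear(max_lr, min_lr, pct):.5f}")
```

The cosine variant spends more steps near the endpoints, which is why it is the default `anneal_strategy` in the scheduler above.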
diff --git a/python/jittor/compatibility/src/jtorch_core.cc b/python/jittor/compatibility/src/jtorch_core.cc
deleted file mode 100644
index 1102b107..00000000
--- a/python/jittor/compatibility/src/jtorch_core.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-
-#include "pyjt/py_obj_holder.h"
-#include "utils/str_utils.h"
-#include "jtorch_core.h"
-#include "graph.h"
-#include "grad.h"
-#include "ops/op_register.h"
-
-namespace jittor {
-
-void pyjt_def_all(PyObject* m);
-
-EXTERN_LIB void setter_use_cuda(int value);
-
-Device::Device(const string& name, int ordinal) : name(name) {
- if (startswith(name, "cpu"))
- setter_use_cuda(0);
- else
- setter_use_cuda(1);
-}
-
-unordered_map<int64, VarPtr> grad_backup;
-EXTERN_LIB void (*_var_free_hook)(Var*);
-EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;
-
-void jtorch_var_free_hook(Var* v) {
- auto iter = grad_backup.find(v->id);
- if (iter != grad_backup.end()) {
- grad_backup.erase(iter);
- }
-}
-
-void jtorch_init() {
- _var_free_hook = &jtorch_var_free_hook;
- _grad_backup_ptr = &grad_backup;
-}
-
-inline static VarPtr& get_grad(Var* v) {
- return grad_backup[v->id];
-}
-static auto make_binary = get_op_info("binary")
- .get_constructor<VarPtr, Var*, Var*, NanoString>();
-
-inline static void add_grad(VarPtr& a, VarPtr&& b) {
- if (!a) a = move(b);
- else {
- a = make_binary(a, b, ns_add);
- }
-}
-
-
-void grad_set(VarHolder* x, Maybe<VarHolder> v) {
- if (!v) {
- grad_del(x);
- return;
- }
- grad_backup[x->var->id] = v.ptr->var;
-}
-
-Maybe<VarHolder> grad_get(VarHolder* x) {
- auto iter = grad_backup.find(x->var->id);
- if (iter != grad_backup.end()) {
- if (!iter->second.ptr) return nullptr;
- return new VarHolder(iter->second.ptr);
- }
- return nullptr;
-}
-
-void grad_del(VarHolder* x) {
- auto iter = grad_backup.find(x->var->id);
- if (iter != grad_backup.end())
- grad_backup.erase(iter);
-}
-
-void backward(VarHolder* x) {
- vector<Node*> gnodes({x->var});
- bfs_backward(gnodes, [&](Node* node) {
- if (node->is_stop_grad())
- return false;
- return true;
- });
- vector<Var*> targets;
- for (auto* node : gnodes) {
- if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
- targets.push_back(node->var());
- }
- auto grads = grad(x->var, targets);
- for (int i=0; i<targets.size(); i++)
- add_grad(get_grad(targets[i]), move(grads[i]));
-}
-
-} // namespace jittor
-
-static void init_module(PyModuleDef* mdef, PyObject* m) {
- mdef->m_doc = "Inner c++ core of jtorch";
- jittor::pyjt_def_all(m);
-}
-PYJT_MODULE_INIT(jtorch_core);
diff --git a/python/jittor/compatibility/src/jtorch_core.h b/python/jittor/compatibility/src/jtorch_core.h
deleted file mode 100644
index 36de6522..00000000
--- a/python/jittor/compatibility/src/jtorch_core.h
+++ /dev/null
@@ -1,40 +0,0 @@
-#pragma once
-#include "common.h"
-#include "var_holder.h"
-#include "misc/fast_shared_ptr.h"
-
-namespace jittor {
-
-// @pyjt(device)
-// @attrs(heaptype)
-struct Device {
- string name;
-
- // @pyjt(__init__)
- Device(const string& name, int ordinal=0);
- // @pyjt(__get__type, __str__)
- inline string get_type() {return name;}
- // @pyjt(__get__index)
- inline int index() {return 0;}
-};
-
-// @pyjt(backward)
-void backward(VarHolder* x);
-
-// @pyjt(grad_set)
-void grad_set(VarHolder* x, Maybe<VarHolder> v);
-// @pyjt(grad_get)
-Maybe<VarHolder> grad_get(VarHolder* x);
-// @pyjt(grad_del)
-void grad_del(VarHolder* x);
-
-// @pyjt(retain_grad_set)
-inline void retain_grad_set(VarHolder* x, bool v) {
- x->var->flags.set(NodeFlags::_th_require_grad, v);
-}
-// @pyjt(retain_grad_get)
-inline bool retain_grad_get(VarHolder* x) {
- return x->var->flags.get(NodeFlags::_th_require_grad);
-}
-
-}
\ No newline at end of file
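
The deleted `jtorch_core.cc`/`.h` pair keeps each variable's gradient in a map keyed by the variable id, accumulates into it during `backward()`, and drops the entry when the variable is freed. A toy Python sketch of that bookkeeping, purely to illustrate the idea (not the real jtorch API):

```python
# Per-variable gradient storage keyed by an integer id, mirroring grad_backup.
grad_backup = {}

def add_grad(var_id, g):
    # Mirrors add_grad(): the first gradient is stored, later ones are summed.
    if var_id not in grad_backup:
        grad_backup[var_id] = g
    else:
        grad_backup[var_id] = grad_backup[var_id] + g

def var_free_hook(var_id):
    # Mirrors jtorch_var_free_hook(): forget the grad once the var dies.
    grad_backup.pop(var_id, None)

add_grad(1, 2.0)
add_grad(1, 3.0)
print(grad_backup[1])  # 5.0
var_free_hook(1)
print(grad_backup)     # {}
```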
diff --git a/python/jittor/compatibility/test/test_conflict_func.py b/python/jittor/compatibility/test/test_conflict_func.py
deleted file mode 100644
index 97bd7d8f..00000000
--- a/python/jittor/compatibility/test/test_conflict_func.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import unittest
-import numpy as np
-import torch
-import jittor as jt
-
-class TestConflictFunc(unittest.TestCase):
- def test_max(self):
- a = torch.Tensor([1,4,2])
- assert a.max() == 4
- v, k = a.max(dim=0)
- assert v==4 and k==1
-
- def test_argsort(self):
- a = torch.Tensor([1,4,2])
- k = a.argsort()
- assert jt.all_equal(k, [0,2,1])
-
- with jt.flag_scope(th_mode=0):
- k, v = a.argsort()
- assert jt.all_equal(k, [0,2,1])
-
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/python/jittor/compatibility/test/test_function.py b/python/jittor/compatibility/test/test_function.py
deleted file mode 100644
index 9959dbae..00000000
--- a/python/jittor/compatibility/test/test_function.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import unittest
-import numpy as np
-import torch
-
-class TestFunction(unittest.TestCase):
- def test_example1(self):
- import jtorch
- from jtorch import Function
-
- class MyFunc(Function):
- @staticmethod
- def forward(self, x, y):
- self.x = x
- self.y = y
- return x*y, x/y
-
- @staticmethod
- def backward(self, grad0, grad1):
- return grad0 * self.y, grad1 * self.x
-
- a = jtorch.array(3.0)
- a.requires_grad = True
- b = jtorch.array(4.0)
- b.requires_grad = True
- func = MyFunc.apply
- c,d = func(a, b)
- (c+d*3).backward()
- assert a.grad.data == 4
- assert b.grad.data == 9
-
- def test_example2(self):
- import jtorch as jt
- from jtorch import Function
-
- class MyFunc(Function):
- @staticmethod
- def forward(self, x, y):
- self.x = x
- self.y = y
- return x*y, x/y
-
- @staticmethod
- def backward(self, grad0, grad1):
- assert grad1 is None
- return grad0 * self.y, None
- a = jt.array(3.0)
- a.requires_grad = True
- b = jt.array(4.0)
- b.requires_grad = True
- func = MyFunc.apply
- c,d = func(a, b)
- d.stop_grad()
- da, db = jt.grad(c+d*3, [a, b])
- assert da.data == 4
- assert db.data == 0
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/python/jittor/compatibility/test/test_misc.py b/python/jittor/compatibility/test/test_misc.py
deleted file mode 100644
index 00bf1b70..00000000
--- a/python/jittor/compatibility/test/test_misc.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import unittest
-import numpy as np
-import torch
-
-class TestMisc(unittest.TestCase):
- def test_update_grad(self):
- class Net(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0]))
- net = Net()
- assert(net.a.requires_grad)
- net.load_state_dict({"a": torch.Tensor([3.0, 4.0])})
- assert(net.a.requires_grad)
-
- def test_reshape(self):
- a = torch.ones(3,3)
- a.requires_grad = True
- b = torch.reshape(a, [9])
- assert b.requires_grad == True
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/python/jittor/compatibility/test/test_tutorial.py b/python/jittor/compatibility/test/test_tutorial.py
deleted file mode 100644
index 92c087c7..00000000
--- a/python/jittor/compatibility/test/test_tutorial.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import unittest
-import numpy as np
-import os
-import subprocess as sp
-import sys
-
-def check_two(cmd, parser=None, checker=None):
- jtorch_out = sp.getoutput(cmd)
- print("=========JTORCH OUT==========")
- print(jtorch_out)
- torch_out = sp.getoutput("PYTHONPATH= "+cmd)
- print("=========TORCH OUT==========")
- print(torch_out)
- if parser:
- torch_out = parser(torch_out)
- jtorch_out = parser(jtorch_out)
- if checker:
- checker(torch_out, jtorch_out)
- else:
- assert torch_out == jtorch_out
- return jtorch_out, torch_out
-
-jtorch_path = os.path.join(os.path.dirname(__file__), "..")
-# comes from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
-class TestTutorial(unittest.TestCase):
- def test_auto_grad1(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py",
- parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
- def test_auto_grad2(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py",
- parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
- def test_auto_grad3(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py",
- parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
- def test_auto_grad4(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py",
- parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
- def test_auto_grad5(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py",
- parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
- def test_auto_grad6(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py",
- parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
- def test_auto_grad7(self):
- check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py",
- parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float),
- checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
-
-if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
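
The removed harness runs each tutorial twice: once with the jtorch compatibility layer on `PYTHONPATH` (so `import torch` resolves to jtorch) and once with `PYTHONPATH` cleared (so the real torch is used), then compares the last printed coefficients. A condensed sketch of that mechanism, with a hypothetical script path:

```python
import subprocess as sp
import sys
import numpy as np

def check_two(cmd, parser, checker):
    # Same command run twice: current PYTHONPATH (jtorch shadows torch)
    # versus an emptied PYTHONPATH (real torch).
    out_jtorch = sp.getoutput(cmd)
    out_torch = sp.getoutput("PYTHONPATH= " + cmd)
    checker(parser(out_torch), parser(out_jtorch))

check_two(
    f"{sys.executable} tutorial/auto_grad1.py",   # hypothetical path
    parser=lambda s: np.array(s.split())[[-10, -8, -5, -2]].astype(float),
    checker=lambda a, b: np.testing.assert_allclose(a, b, atol=1e-4),
)
```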
diff --git a/python/jittor/compatibility/tutorial/auto_grad1.py b/python/jittor/compatibility/tutorial/auto_grad1.py
deleted file mode 100644
index 60a090ad..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad1.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-import math
-
-dtype = torch.float
-device = torch.device("cpu")
-# device = torch.device("cuda:0") # Uncomment this to run on GPU
-
-# Create random input and output data
-x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
-y = torch.sin(x)
-
-# Randomly initialize weights
-a = torch.randn((), device=device, dtype=dtype)
-b = torch.randn((), device=device, dtype=dtype)
-c = torch.randn((), device=device, dtype=dtype)
-d = torch.randn((), device=device, dtype=dtype)
-
-learning_rate = 1e-6
-for t in range(20000):
- # Forward pass: compute predicted y
- y_pred = a + b * x + c * x ** 2 + d * x ** 3
-
- # Compute and print loss
- loss = (y_pred - y).pow(2).sum().item()
- if t % 1000 == 999:
- print(t, loss)
-
- # Backprop to compute gradients of a, b, c, d with respect to loss
- grad_y_pred = 2.0 * (y_pred - y)
- grad_a = grad_y_pred.sum()
- grad_b = (grad_y_pred * x).sum()
- grad_c = (grad_y_pred * x ** 2).sum()
- grad_d = (grad_y_pred * x ** 3).sum()
-
- # Update weights using gradient descent
- a -= learning_rate * grad_a
- b -= learning_rate * grad_b
- c -= learning_rate * grad_c
- d -= learning_rate * grad_d
- # print(t, torch.liveness_info())
- # torch.sync_all()
-
-
-print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad2.py b/python/jittor/compatibility/tutorial/auto_grad2.py
deleted file mode 100644
index a3bbc9a8..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad2.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-import math
-
-dtype = torch.float
-device = torch.device("cpu")
-# device = torch.device("cuda:0") # Uncomment this to run on GPU
-
-# Create Tensors to hold input and outputs.
-# By default, requires_grad=False, which indicates that we do not need to
-# compute gradients with respect to these Tensors during the backward pass.
-x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
-y = torch.sin(x)
-
-# Create random Tensors for weights. For a third order polynomial, we need
-# 4 weights: y = a + b x + c x^2 + d x^3
-# Setting requires_grad=True indicates that we want to compute gradients with
-# respect to these Tensors during the backward pass.
-a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-
-learning_rate = 1e-6
-for t in range(20000):
- # Forward pass: compute predicted y using operations on Tensors.
- y_pred = a + b * x + c * x ** 2 + d * x ** 3
- # print(y_pred.requires_grad)
- # y_pred.requires_grad = False
-
- # Compute and print loss using operations on Tensors.
- # Now loss is a Tensor of shape (1,)
- # loss.item() gets the scalar value held in the loss.
- loss = (y_pred - y).pow(2).sum()
- if t % 1000 == 990:
- print(t, loss.item())
-
- # Use autograd to compute the backward pass. This call will compute the
- # gradient of loss with respect to all Tensors with requires_grad=True.
- # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding
- # the gradient of the loss with respect to a, b, c, d respectively.
- # torch.backward(loss)
- loss.backward()
-
- # Manually update weights using gradient descent. Wrap in torch.no_grad()
- # because weights have requires_grad=True, but we don't need to track this
- # in autograd.
- with torch.no_grad():
- a -= learning_rate * a.grad
- b -= learning_rate * b.grad
- c -= learning_rate * c.grad
- d -= learning_rate * d.grad
-
- # Manually zero the gradients after updating weights
- a.grad = None
- b.grad = None
- c.grad = None
- d.grad = None
-
-print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad3.py b/python/jittor/compatibility/tutorial/auto_grad3.py
deleted file mode 100644
index 654ec447..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad3.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-import math
-
-
-class LegendrePolynomial3(torch.autograd.Function):
- """
- We can implement our own custom autograd Functions by subclassing
- torch.autograd.Function and implementing the forward and backward passes
- which operate on Tensors.
- """
-
- @staticmethod
- def forward(ctx, input):
- """
- In the forward pass we receive a Tensor containing the input and return
- a Tensor containing the output. ctx is a context object that can be used
- to stash information for backward computation. You can cache arbitrary
- objects for use in the backward pass using the ctx.save_for_backward method.
- """
- ctx.save_for_backward(input)
- return 0.5 * (5 * input ** 3 - 3 * input)
-
- @staticmethod
- def backward(ctx, grad_output):
- """
- In the backward pass we receive a Tensor containing the gradient of the loss
- with respect to the output, and we need to compute the gradient of the loss
- with respect to the input.
- """
- input, = ctx.saved_tensors
- return grad_output * 1.5 * (5 * input ** 2 - 1)
-
-
-dtype = torch.float
-device = torch.device("cpu")
-# device = torch.device("cuda:0") # Uncomment this to run on GPU
-
-# Create Tensors to hold input and outputs.
-# By default, requires_grad=False, which indicates that we do not need to
-# compute gradients with respect to these Tensors during the backward pass.
-x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
-y = torch.sin(x)
-
-# Create random Tensors for weights. For this example, we need
-# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
-# not too far from the correct result to ensure convergence.
-# Setting requires_grad=True indicates that we want to compute gradients with
-# respect to these Tensors during the backward pass.
-a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
-b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
-c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
-d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
-
-learning_rate = 5e-6
-for t in range(2000):
- # To apply our Function, we use Function.apply method. We alias this as 'P3'.
- P3 = LegendrePolynomial3.apply
-
- # Forward pass: compute predicted y using operations; we compute
- # P3 using our custom autograd operation.
- y_pred = a + b * P3(c + d * x)
-
- # Compute and print loss
- loss = (y_pred - y).pow(2).sum()
- if t % 100 == 99:
- print(t, loss.item())
-
- # Use autograd to compute the backward pass.
- loss.backward()
-
- # Update weights using gradient descent
- with torch.no_grad():
- a -= learning_rate * a.grad
- b -= learning_rate * b.grad
- c -= learning_rate * c.grad
- d -= learning_rate * d.grad
-
- # Manually zero the gradients after updating weights
- a.grad = None
- b.grad = None
- c.grad = None
- d.grad = None
-
-print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad4.py b/python/jittor/compatibility/tutorial/auto_grad4.py
deleted file mode 100644
index 062d0b0e..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad4.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-import math
-
-
-# Create Tensors to hold input and outputs.
-x = torch.linspace(-math.pi, math.pi, 2000)
-y = torch.sin(x)
-
-# For this example, the output y is a linear function of (x, x^2, x^3), so
-# we can consider it as a linear layer neural network. Let's prepare the
-# tensor (x, x^2, x^3).
-p = torch.tensor([1, 2, 3])
-xx = x.unsqueeze(-1).pow(p)
-
-# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
-# (3,), for this case, broadcasting semantics will apply to obtain a tensor
-# of shape (2000, 3)
-
-# Use the nn package to define our model as a sequence of layers. nn.Sequential
-# is a Module which contains other Modules, and applies them in sequence to
-# produce its output. The Linear Module computes output from input using a
-# linear function, and holds internal Tensors for its weight and bias.
-# The Flatten layer flattens the output of the linear layer to a 1D tensor,
-# to match the shape of `y`.
-model = torch.nn.Sequential(
- torch.nn.Linear(3, 1),
- torch.nn.Flatten(0, 1)
-)
-
-# The nn package also contains definitions of popular loss functions; in this
-# case we will use Mean Squared Error (MSE) as our loss function.
-loss_fn = torch.nn.MSELoss(reduction='sum')
-# print(model[0].weight.requires_grad)
-
-learning_rate = 1e-6
-for t in range(8000):
-
- # Forward pass: compute predicted y by passing x to the model. Module objects
- # override the __call__ operator so you can call them like functions. When
- # doing so you pass a Tensor of input data to the Module and it produces
- # a Tensor of output data.
- y_pred = model(xx)
-
- # Compute and print loss. We pass Tensors containing the predicted and true
- # values of y, and the loss function returns a Tensor containing the
- # loss.
- loss = loss_fn(y_pred, y)
- if t % 1000 == 999:
- print(t, loss.item())
-
- # Zero the gradients before running the backward pass.
- model.zero_grad()
-
- # Backward pass: compute gradient of the loss with respect to all the learnable
- # parameters of the model. Internally, the parameters of each Module are stored
- # in Tensors with requires_grad=True, so this call will compute gradients for
- # all learnable parameters in the model.
- loss.backward()
-
- # Update the weights using gradient descent. Each parameter is a Tensor, so
- # we can access its gradients like we did before.
- with torch.no_grad():
- for param in model.parameters():
- param -= learning_rate * param.grad
-
-# You can access the first layer of `model` like accessing the first item of a list
-linear_layer = model[0]
-
-# For linear layer, its parameters are stored as `weight` and `bias`.
-print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad5_optim.py b/python/jittor/compatibility/tutorial/auto_grad5_optim.py
deleted file mode 100644
index 04949320..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad5_optim.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-import math
-
-
-# Create Tensors to hold input and outputs.
-x = torch.linspace(-math.pi, math.pi, 2000)
-y = torch.sin(x)
-
-# Prepare the input tensor (x, x^2, x^3).
-p = torch.tensor([1, 2, 3])
-xx = x.unsqueeze(-1).pow(p)
-
-# Use the nn package to define our model and loss function.
-model = torch.nn.Sequential(
- torch.nn.Linear(3, 1),
- torch.nn.Flatten(0, 1)
-)
-loss_fn = torch.nn.MSELoss(reduction='sum')
-
-# Use the optim package to define an Optimizer that will update the weights of
-# the model for us. Here we will use RMSprop; the optim package contains many other
-# optimization algorithms. The first argument to the RMSprop constructor tells the
-# optimizer which Tensors it should update.
-learning_rate = 1e-3
-optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
-for t in range(8000):
- # Forward pass: compute predicted y by passing x to the model.
- y_pred = model(xx)
-
- # Compute and print loss.
- loss = loss_fn(y_pred, y)
- if t % 1000 == 999:
- print(t, loss.item())
-
- # Before the backward pass, use the optimizer object to zero all of the
- # gradients for the variables it will update (which are the learnable
- # weights of the model). This is because by default, gradients are
- # accumulated in buffers( i.e, not overwritten) whenever .backward()
- # is called. Checkout docs of torch.autograd.backward for more details.
- optimizer.zero_grad()
-
- # Backward pass: compute gradient of the loss with respect to model
- # parameters
- loss.backward()
-
- # Calling the step function on an Optimizer makes an update to its
- # parameters
- optimizer.step()
-
-
-linear_layer = model[0]
-print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad6_module.py b/python/jittor/compatibility/tutorial/auto_grad6_module.py
deleted file mode 100644
index a240e2b5..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad6_module.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# -*- coding: utf-8 -*-
-import torch
-import math
-
-
-class Polynomial3(torch.nn.Module):
- def __init__(self):
- """
- In the constructor we instantiate four parameters and assign them as
- member parameters.
- """
- super().__init__()
- self.a = torch.nn.Parameter(torch.randn(()))
- self.b = torch.nn.Parameter(torch.randn(()))
- self.c = torch.nn.Parameter(torch.randn(()))
- self.d = torch.nn.Parameter(torch.randn(()))
-
- def forward(self, x):
- """
- In the forward function we accept a Tensor of input data and we must return
- a Tensor of output data. We can use Modules defined in the constructor as
- well as arbitrary operators on Tensors.
- """
- return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
-
- def string(self):
- """
- Just like any class in Python, you can also define custom method on PyTorch modules
- """
- return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'
-
-
-# Create Tensors to hold input and outputs.
-x = torch.linspace(-math.pi, math.pi, 2000)
-y = torch.sin(x)
-
-# Construct our model by instantiating the class defined above
-model = Polynomial3()
-
-# Construct our loss function and an Optimizer. The call to model.parameters()
-# in the SGD constructor will contain the learnable parameters (defined
-# with torch.nn.Parameter) which are members of the model.
-criterion = torch.nn.MSELoss(reduction='sum')
-optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
-for t in range(8000):
- # Forward pass: Compute predicted y by passing x to the model
- y_pred = model(x)
-
- # Compute and print loss
- loss = criterion(y_pred, y)
- if t % 1000 == 999:
- print(t, loss.item())
-
- # Zero gradients, perform a backward pass, and update the weights.
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
-print(f'Result: {model.string()}')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py
deleted file mode 100644
index fa954771..00000000
--- a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- coding: utf-8 -*-
-import random
-import torch
-import math
-
-
-class DynamicNet(torch.nn.Module):
- def __init__(self):
- """
- In the constructor we instantiate five parameters and assign them as members.
- """
- super().__init__()
- self.a = torch.nn.Parameter(torch.randn(()))
- self.b = torch.nn.Parameter(torch.randn(()))
- self.c = torch.nn.Parameter(torch.randn(()))
- self.d = torch.nn.Parameter(torch.randn(()))
- self.e = torch.nn.Parameter(torch.randn(()))
-
- def forward(self, x):
- """
- For the forward pass of the model, we randomly choose either 4, 5
- and reuse the e parameter to compute the contribution of these orders.
-
- Since each forward pass builds a dynamic computation graph, we can use normal
- Python control-flow operators like loops or conditional statements when
- defining the forward pass of the model.
-
- Here we also see that it is perfectly safe to reuse the same parameter many
- times when defining a computational graph.
- """
- y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
- for exp in range(4, random.randint(4, 6)):
- y = y + self.e * x ** exp
- return y
-
- def string(self):
- """
- Just like any class in Python, you can also define custom method on PyTorch modules
- """
- return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
-
-
-# Create Tensors to hold input and outputs.
-x = torch.linspace(-math.pi, math.pi, 2000)
-y = torch.sin(x)
-
-# Construct our model by instantiating the class defined above
-model = DynamicNet()
-
-# Construct our loss function and an Optimizer. Training this strange model with
-# vanilla stochastic gradient descent is tough, so we use momentum
-criterion = torch.nn.MSELoss(reduction='sum')
-optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
-for t in range(60000):
- # Forward pass: Compute predicted y by passing x to the model
- y_pred = model(x)
-
- # Compute and print loss
- loss = criterion(y_pred, y)
- if t % 2000 == 1999:
- print(t, loss.item())
-
- # Zero gradients, perform a backward pass, and update the weights.
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- # print(torch.liveness_info())
-
-print(f'Result: {model.string()}')
\ No newline at end of file
diff --git a/python/jittor/compatibility/tutorial/quickstart.py b/python/jittor/compatibility/tutorial/quickstart.py
deleted file mode 100644
index f0401a9b..00000000
--- a/python/jittor/compatibility/tutorial/quickstart.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import torch
-from torch import nn
-# from jtorch.utils import DataLoader
-from torch.utils.data import DataLoader
-from torchvision import datasets
-from torchvision.transforms import ToTensor
-
-# Download training data from open datasets.
-training_data = datasets.FashionMNIST(
- root="data",
- train=True,
- download=True,
- transform=ToTensor(),
-)
-
-# Download test data from open datasets.
-test_data = datasets.FashionMNIST(
- root="data",
- train=False,
- download=True,
- transform=ToTensor(),
-)
-
-batch_size = 64
-
-# Create data loaders.
-train_dataloader = DataLoader(training_data, batch_size=batch_size)
-test_dataloader = DataLoader(test_data, batch_size=batch_size)
-
-print(len(train_dataloader))
-for X, y in test_dataloader:
- print(f"Shape of X [N, C, H, W]: {X.shape}")
- print(f"Shape of y: {y.shape} {y.dtype}")
- break
-
-# Get cpu or gpu device for training.
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using {device} device")
-
-# Define model
-class NeuralNetwork(nn.Module):
- def __init__(self):
- super(NeuralNetwork, self).__init__()
- self.flatten = nn.Flatten()
- self.linear_relu_stack = nn.Sequential(
- nn.Linear(28*28, 512),
- nn.ReLU(),
- nn.Linear(512, 512),
- nn.ReLU(),
- nn.Linear(512, 10)
- )
-
- def forward(self, x):
- x = self.flatten(x)
- logits = self.linear_relu_stack(x)
- return logits
-
-model = NeuralNetwork().to(device)
-print(model)
-
-
-loss_fn = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-
-def train(dataloader, model, loss_fn, optimizer):
- size = len(dataloader.dataset)
- model.train()
- for batch, (X, y) in enumerate(dataloader):
- X, y = X.to(device), y.to(device)
-
- # Compute prediction error
- pred = model(X)
- loss = loss_fn(pred, y)
-
- # Backpropagation
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- if batch % 100 == 0:
- loss, current = loss.item(), batch * len(X)
- print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
-
-def test(dataloader, model, loss_fn):
- size = len(dataloader.dataset)
- num_batches = len(dataloader)
- model.eval()
- test_loss, correct = 0, 0
- with torch.no_grad():
- for X, y in dataloader:
- X, y = X.to(device), y.to(device)
- pred = model(X)
- test_loss += loss_fn(pred, y).item()
- correct += (pred.argmax(1) == y).type(torch.float).sum().item()
- test_loss /= num_batches
- correct /= size
- print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
-
-
-epochs = 5
-test(test_dataloader, model, loss_fn)
-for t in range(epochs):
- print(f"Epoch {t+1}\n-------------------------------")
- train(train_dataloader, model, loss_fn, optimizer)
- test(test_dataloader, model, loss_fn)
-print("Done!")
\ No newline at end of file
diff --git a/python/jittor/compatibility/utils/__init__.py b/python/jittor/compatibility/utils/__init__.py
deleted file mode 100644
index ac2c2bd8..00000000
--- a/python/jittor/compatibility/utils/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-cpp_extension = None
-_flatten_dense_tensors = None
-_unflatten_dense_tensors = None
-
-tensorboard = None
\ No newline at end of file
diff --git a/python/jittor/compatibility/utils/_pytree.py b/python/jittor/compatibility/utils/_pytree.py
deleted file mode 100644
index c3118964..00000000
--- a/python/jittor/compatibility/utils/_pytree.py
+++ /dev/null
@@ -1,3 +0,0 @@
-#TODO: Implement this
-_register_pytree_node = None
-_dict_flatten = None
\ No newline at end of file
diff --git a/python/jittor/compatibility/utils/checkpoint.py b/python/jittor/compatibility/utils/checkpoint.py
deleted file mode 100644
index ba3c3e8e..00000000
--- a/python/jittor/compatibility/utils/checkpoint.py
+++ /dev/null
@@ -1,8 +0,0 @@
-detach_variable = None
-
-
-def checkpoint(
- *args,
- **kwargs
-):
- pass
diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py
deleted file mode 100644
index 71946a23..00000000
--- a/python/jittor/compatibility/utils/data.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import jittor as jt
-import jittor.dataset
-from jittor.dataset import Dataset as JDataset
-
-from collections import namedtuple
-from typing import Any, Callable, Iterable, Optional, Sequence, Union
-
-
-class Dataset:
- def __getitem__(self, index):
- raise NotImplementedError
-
-class IterableDataset:
- def __iter__(self):
- raise NotImplementedError
-
-
-class DataLoader(JDataset):
- def __init__(self, dataset,
- batch_size: Optional[int] = 1,
- shuffle: Optional[bool] = False,
- sampler = None,
- batch_sampler = None,
- num_workers: int = 0,
- collate_fn = None,
- pin_memory: bool = False,
- drop_last: bool = False,
- timeout: float = 0,
- worker_init_fn = None,
- multiprocessing_context=None,
- generator=None,
- *, prefetch_factor: int = 2,
- persistent_workers: bool = False,
- pin_memory_device: str = "") -> None:
- super().__init__(batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- drop_last=drop_last)
-
- unsupported_kwargs = {
- "batch_sampler": batch_sampler,
- "pin_memory": pin_memory,
- "timeout": timeout,
- "worker_init_fn": worker_init_fn,
- "multiprocessing_context": multiprocessing_context,
- "generator": generator,
- "persistent_workers": persistent_workers,
- "pin_memory_device": pin_memory_device
- }
- for kwarg, value in unsupported_kwargs.items():
- if value:
- jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")
-
- self.dataset = dataset
- self.collate_fn = collate_fn
- self.sampler = sampler
-
- if not isinstance(dataset, IterableDataset):
- self.total_len = len(dataset)
- else:
- # TODO: support multiple worker for iterable dataset
- assert(num_workers == 0)
-
- def collate_batch(self, batch):
- if self.collate_fn is not None:
- return self.collate_fn(batch)
- else:
- return super().collate_batch(batch)
-
- def __getitem__(self, i):
- return self.dataset[i]
-
- def __iter__(self):
- if isinstance(self.dataset, IterableDataset):
- return self.inner_iter()
- else:
- return super().__iter__()
-
- def inner_iter(self):
- current_batch = []
-
- if jt.world_size > 1:
- assert self.batch_size % jt.world_size == 0, \
- f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}"
- real_batch_size = int(self.batch_size / jt.world_size)
- else:
- real_batch_size = self.batch_size
-
- for element in self.dataset:
- current_batch.append(element)
-
- if len(current_batch) == real_batch_size:
- current_batch = self.collate_batch(current_batch)
- current_batch = self.to_jittor(current_batch)
- yield current_batch
- current_batch = []
-
- if not self.drop_last and len(current_batch) > 0:
- current_batch = self.collate_batch(current_batch)
- yield self.to_jittor(current_batch)
-
-def get_worker_info():
- # always return the fake worker info
- return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
-
-class RandomSampler(jt.dataset.RandomSampler):
- def __init__(self, dataset, generator=None, **kwargs):
- super().__init__(dataset, **kwargs)
-
- def __iter__(self):
- if getattr(self.dataset, "support_random_access", True):
- return super().__iter__()
- else:
- self.dataset.shuffle()
- return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
-
-class DistributedSampler(jt.dataset.Sampler):
- def __init__(self, sampler: RandomSampler):
- assert(isinstance(sampler, RandomSampler))
- self.sampler = sampler
-
- def set_epoch(self, epoch: int):
- ### do nothing, let jittor's inner dataset handle
- pass
-
- def __iter__(self):
- return self.sampler.__iter__()
-
- def __len__(self):
- return self.sampler.__len__()
-
-BatchSampler = jt.dataset.BatchSampler
-Sampler = jt.dataset.Sampler
-SequentialSampler = jt.dataset.SequentialSampler
-SubsetRandomSampler = jt.dataset.SubsetRandomSampler
-
-TensorDataset = Dataset
diff --git a/python/jittor/compatibility/utils/dtype.py b/python/jittor/compatibility/utils/dtype.py
deleted file mode 100644
index 41728383..00000000
--- a/python/jittor/compatibility/utils/dtype.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from typing import Callable, Union
-Dtype = Union[Callable, str]
-
-def get_string_dtype(dtype):
- if callable(dtype):
- dtype = dtype.__name__
- if not isinstance(dtype, str):
- raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
- return dtype
\ No newline at end of file
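
For reference, a short usage sketch of the helper above (the import path is an assumption based on the file location):

```python
from torch.utils.dtype import get_string_dtype  # hypothetical import path

assert get_string_dtype(float) == "float"        # callables reduce to __name__
assert get_string_dtype("float32") == "float32"  # strings pass through unchanged
```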
diff --git a/python/jittor/compatibility/utils/hooks.py b/python/jittor/compatibility/utils/hooks.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/python/jittor/compatibility/utils/pip_publish.py b/python/jittor/compatibility/utils/pip_publish.py
deleted file mode 100644
index 72ff245f..00000000
--- a/python/jittor/compatibility/utils/pip_publish.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import os
-import glob
-import shutil
-import sys
-
-home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..")
-home_path = os.path.abspath(home_path)
-
-def callback(func, path, exc_info):
- print(f"remove \"{path}\" failed.")
-
-def rmtree(path):
- if os.path.isdir(path):
- print(f"remove \"{path}\" recursive.")
- shutil.rmtree(path, onerror=callback)
-
-def remove_tmpfile():
- dist_file = home_path+"/dist"
- egg_file = glob.glob(home_path+"/**/*egg-info")
- rmtree(dist_file)
- for e in egg_file:
- rmtree(e)
-
-def run_cmd(cmd):
- print("[CMD]", cmd)
- assert os.system(cmd)==0
-
-os.chdir(home_path)
-remove_tmpfile()
-
-run_cmd(f"{sys.executable} ./setup.py sdist")
-run_cmd(f"{sys.executable} -m twine upload dist/*")
-
-remove_tmpfile()
\ No newline at end of file
diff --git a/python/jittor/compatibility/vision/_internally_replaced_utils.py b/python/jittor/compatibility/vision/_internally_replaced_utils.py
deleted file mode 100644
index 748fa2ea..00000000
--- a/python/jittor/compatibility/vision/_internally_replaced_utils.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import importlib.machinery
-import os
-
-
-def _download_file_from_remote_location(fpath: str, url: str) -> None:
- pass
-
-
-def _is_remote_location_available() -> bool:
- return False
-
-
-def _get_extension_path(lib_name):
-
- lib_dir = os.path.dirname(__file__)
- if os.name == "nt":
- # Register the main torchvision library location on the default DLL path
- import ctypes
- import sys
-
- kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
- with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
- prev_error_mode = kernel32.SetErrorMode(0x0001)
-
- if with_load_library_flags:
- kernel32.AddDllDirectory.restype = ctypes.c_void_p
-
- if sys.version_info >= (3, 8):
- os.add_dll_directory(lib_dir)
- elif with_load_library_flags:
- res = kernel32.AddDllDirectory(lib_dir)
- if res is None:
- err = ctypes.WinError(ctypes.get_last_error())
- err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
- raise err
-
- kernel32.SetErrorMode(prev_error_mode)
-
- loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
-
- extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
- ext_specs = extfinder.find_spec(lib_name)
- if ext_specs is None:
- raise ImportError
-
- return ext_specs.origin
\ No newline at end of file
diff --git a/python/jittor/compatibility/vision/datasets/__init__.py b/python/jittor/compatibility/vision/datasets/__init__.py
deleted file mode 100644
index d04187f1..00000000
--- a/python/jittor/compatibility/vision/datasets/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
-
-__all__ = (
- "EMNIST",
- "FashionMNIST",
- "QMNIST",
- "MNIST",
- "KMNIST",
-)
\ No newline at end of file
diff --git a/python/jittor/compatibility/vision/datasets/mnist.py b/python/jittor/compatibility/vision/datasets/mnist.py
deleted file mode 100644
index dfc3787b..00000000
--- a/python/jittor/compatibility/vision/datasets/mnist.py
+++ /dev/null
@@ -1,558 +0,0 @@
-import codecs
-import os
-import os.path
-import shutil
-import string
-import sys
-import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple
-from urllib.error import URLError
-
-import numpy as np
-import torch
-from PIL import Image
-
-from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
-from .vision import VisionDataset
-
-
-class MNIST(VisionDataset):
- """`MNIST `_ Dataset.
-
- Args:
- root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
- and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
- train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
- otherwise from ``t10k-images-idx3-ubyte``.
- download (bool, optional): If True, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- """
-
- mirrors = [
- "http://yann.lecun.com/exdb/mnist/",
- "https://ossci-datasets.s3.amazonaws.com/mnist/",
- ]
-
- resources = [
- ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
- ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
- ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
- ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
- ]
-
- training_file = "training.pt"
- test_file = "test.pt"
- classes = [
- "0 - zero",
- "1 - one",
- "2 - two",
- "3 - three",
- "4 - four",
- "5 - five",
- "6 - six",
- "7 - seven",
- "8 - eight",
- "9 - nine",
- ]
-
- @property
- def train_labels(self):
- warnings.warn("train_labels has been renamed targets")
- return self.targets
-
- @property
- def test_labels(self):
- warnings.warn("test_labels has been renamed targets")
- return self.targets
-
- @property
- def train_data(self):
- warnings.warn("train_data has been renamed data")
- return self.data
-
- @property
- def test_data(self):
- warnings.warn("test_data has been renamed data")
- return self.data
-
- def __init__(
- self,
- root: str,
- train: bool = True,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- download: bool = False,
- ) -> None:
- super().__init__(root, transform=transform, target_transform=target_transform)
- self.train = train # training set or test set
-
- if self._check_legacy_exist():
- self.data, self.targets = self._load_legacy_data()
- return
-
- if download:
- self.download()
-
- if not self._check_exists():
- raise RuntimeError("Dataset not found. You can use download=True to download it")
-
- self.data, self.targets = self._load_data()
-
- def _check_legacy_exist(self):
- processed_folder_exists = os.path.exists(self.processed_folder)
- if not processed_folder_exists:
- return False
-
- return all(
- check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
- )
-
- def _load_legacy_data(self):
- # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
- # directly.
- data_file = self.training_file if self.train else self.test_file
- return torch.load(os.path.join(self.processed_folder, data_file))
-
- def _load_data(self):
- image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
- data = read_image_file(os.path.join(self.raw_folder, image_file))
-
- label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
- targets = read_label_file(os.path.join(self.raw_folder, label_file))
-
- return data, targets
-
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
- """
- Args:
- index (int): Index
-
- Returns:
- tuple: (image, target) where target is index of the target class.
- """
- img, target = self.data[index], int(self.targets[index])
-
- # doing this so that it is consistent with all other datasets
- # to return a PIL Image
- img = Image.fromarray(img.numpy(), mode="L")
-
- if self.transform is not None:
- img = self.transform(img)
-
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return img, target
-
- def __len__(self) -> int:
- return len(self.data)
-
- @property
- def raw_folder(self) -> str:
- return os.path.join(self.root, self.__class__.__name__, "raw")
-
- @property
- def processed_folder(self) -> str:
- return os.path.join(self.root, self.__class__.__name__, "processed")
-
- @property
- def class_to_idx(self) -> Dict[str, int]:
- return {_class: i for i, _class in enumerate(self.classes)}
-
- def _check_exists(self) -> bool:
- return all(
- check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
- for url, _ in self.resources
- )
-
- def download(self) -> None:
- """Download the MNIST data if it doesn't exist already."""
-
- if self._check_exists():
- return
-
- os.makedirs(self.raw_folder, exist_ok=True)
-
- # download files
- for filename, md5 in self.resources:
- for mirror in self.mirrors:
- url = f"{mirror}{filename}"
- try:
- print(f"Downloading {url}")
- download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
- except URLError as error:
- print(f"Failed to download (trying next):\n{error}")
- continue
- finally:
- print()
- break
- else:
- raise RuntimeError(f"Error downloading {filename}")
-
- def extra_repr(self) -> str:
- split = "Train" if self.train is True else "Test"
- return f"Split: {split}"
-
-
-class FashionMNIST(MNIST):
- """`Fashion-MNIST `_ Dataset.
-
- Args:
- root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
- and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
- train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
- otherwise from ``t10k-images-idx3-ubyte``.
- download (bool, optional): If True, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- """
-
- mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
-
- resources = [
- ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
- ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
- ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
- ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
- ]
- classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
-
-
-class KMNIST(MNIST):
- """`Kuzushiji-MNIST `_ Dataset.
-
- Args:
- root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
- and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
- train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
- otherwise from ``t10k-images-idx3-ubyte``.
- download (bool, optional): If True, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- """
-
- mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
-
- resources = [
- ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
- ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
- ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
- ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
- ]
- classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
-
-
-class EMNIST(MNIST):
- """`EMNIST `_ Dataset.
-
- Args:
- root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
- and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
- split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
- ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
- which one to use.
- train (bool, optional): If True, creates dataset from ``training.pt``,
- otherwise from ``test.pt``.
- download (bool, optional): If True, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- """
-
- url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
- md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
- splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
- # Merged Classes assumes Same structure for both uppercase and lowercase version
- _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
- _all_classes = set(string.digits + string.ascii_letters)
- classes_split_dict = {
- "byclass": sorted(list(_all_classes)),
- "bymerge": sorted(list(_all_classes - _merged_classes)),
- "balanced": sorted(list(_all_classes - _merged_classes)),
- "letters": ["N/A"] + list(string.ascii_lowercase),
- "digits": list(string.digits),
- "mnist": list(string.digits),
- }
-
- def __init__(self, root: str, split: str, **kwargs: Any) -> None:
- self.split = verify_str_arg(split, "split", self.splits)
- self.training_file = self._training_file(split)
- self.test_file = self._test_file(split)
- super().__init__(root, **kwargs)
- self.classes = self.classes_split_dict[self.split]
-
- @staticmethod
- def _training_file(split) -> str:
- return f"training_{split}.pt"
-
- @staticmethod
- def _test_file(split) -> str:
- return f"test_{split}.pt"
-
- @property
- def _file_prefix(self) -> str:
- return f"emnist-{self.split}-{'train' if self.train else 'test'}"
-
- @property
- def images_file(self) -> str:
- return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
-
- @property
- def labels_file(self) -> str:
- return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
-
- def _load_data(self):
- return read_image_file(self.images_file), read_label_file(self.labels_file)
-
- def _check_exists(self) -> bool:
- return all(check_integrity(file) for file in (self.images_file, self.labels_file))
-
- def download(self) -> None:
- """Download the EMNIST data if it doesn't exist already."""
-
- if self._check_exists():
- return
-
- os.makedirs(self.raw_folder, exist_ok=True)
-
- download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
- gzip_folder = os.path.join(self.raw_folder, "gzip")
- for gzip_file in os.listdir(gzip_folder):
- if gzip_file.endswith(".gz"):
- extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
- shutil.rmtree(gzip_folder)
-
-
-class QMNIST(MNIST):
- """`QMNIST `_ Dataset.
-
- Args:
- root (string): Root directory of dataset whose ``raw``
- subdir contains binary files of the datasets.
- what (string,optional): Can be 'train', 'test', 'test10k',
- 'test50k', or 'nist' for respectively the mnist compatible
- training set, the 60k qmnist testing set, the 10k qmnist
- examples that match the mnist testing set, the 50k
- remaining qmnist testing examples, or all the nist
- digits. The default is to select 'train' or 'test'
- according to the compatibility argument 'train'.
- compat (bool,optional): A boolean that says whether the target
- for each example is class number (for compatibility with
- the MNIST dataloader) or a torch vector containing the
- full qmnist information. Default=True.
- download (bool, optional): If True, downloads the dataset from
- the internet and puts it in root directory. If dataset is
- already downloaded, it is not downloaded again.
- transform (callable, optional): A function/transform that
- takes in an PIL image and returns a transformed
- version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform
- that takes in the target and transforms it.
- train (bool,optional,compatibility): When argument 'what' is
- not specified, this boolean decides whether to load the
- training set or the testing set. Default: True.
- """
-
- subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
- resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
- "train": [
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
- "ed72d4157d28c017586c42bc6afe6370",
- ),
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
- "0058f8dd561b90ffdd0f734c6a30e5e4",
- ),
- ],
- "test": [
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
- "1394631089c404de565df7b7aeaf9412",
- ),
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
- "5b5b05890a5e13444e108efe57b788aa",
- ),
- ],
- "nist": [
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
- "7f124b3b8ab81486c9d8c2749c17f834",
- ),
- (
- "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
- "5ed0e788978e45d4a8bd4b7caec3d79d",
- ),
- ],
- }
- classes = [
- "0 - zero",
- "1 - one",
- "2 - two",
- "3 - three",
- "4 - four",
- "5 - five",
- "6 - six",
- "7 - seven",
- "8 - eight",
- "9 - nine",
- ]
-
- def __init__(
- self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
- ) -> None:
- if what is None:
- what = "train" if train else "test"
- self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
- self.compat = compat
- self.data_file = what + ".pt"
- self.training_file = self.data_file
- self.test_file = self.data_file
- super().__init__(root, train, **kwargs)
-
- @property
- def images_file(self) -> str:
- (url, _), _ = self.resources[self.subsets[self.what]]
- return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
-
- @property
- def labels_file(self) -> str:
- _, (url, _) = self.resources[self.subsets[self.what]]
- return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
-
- def _check_exists(self) -> bool:
- return all(check_integrity(file) for file in (self.images_file, self.labels_file))
-
- def _load_data(self):
- data = read_sn3_pascalvincent_tensor(self.images_file)
- if data.dtype != torch.uint8:
- raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
- if data.ndimension() != 3:
- raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
-
- targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
- if targets.ndimension() != 2:
- raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
-
- if self.what == "test10k":
- data = data[0:10000, :, :].clone()
- targets = targets[0:10000, :].clone()
- elif self.what == "test50k":
- data = data[10000:, :, :].clone()
- targets = targets[10000:, :].clone()
-
- return data, targets
-
- def download(self) -> None:
- """Download the QMNIST data if it doesn't exist already.
- Note that we only download what has been asked for (argument 'what').
- """
- if self._check_exists():
- return
-
- os.makedirs(self.raw_folder, exist_ok=True)
- split = self.resources[self.subsets[self.what]]
-
- for url, md5 in split:
- download_and_extract_archive(url, self.raw_folder, md5=md5)
-
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
- # redefined to handle the compat flag
- img, target = self.data[index], self.targets[index]
- img = Image.fromarray(img.numpy(), mode="L")
- if self.transform is not None:
- img = self.transform(img)
- if self.compat:
- target = int(target[0])
- if self.target_transform is not None:
- target = self.target_transform(target)
- return img, target
-
- def extra_repr(self) -> str:
- return f"Split: {self.what}"
-
-
-def get_int(b: bytes) -> int:
- return int(codecs.encode(b, "hex"), 16)
-
-
-SN3_PASCALVINCENT_BITSMAP = {
- 8: torch.uint8,
- 9: torch.int8,
- 11: torch.int16,
- 12: torch.int32,
- 13: torch.float32,
- 14: torch.float64,
-}
-
-TORCH_TYPE_BITS = {
- torch.uint8: 8,
- torch.int8: 8,
- torch.int16: 16,
- torch.int32: 32,
- torch.float32: 32,
- torch.float64: 64,
-}
-
-
-def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
- """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
- Argument may be a filename, compressed filename, or file object.
- """
- # read
- with open(path, "rb") as f:
- data = f.read()
- # parse
- magic = get_int(data[0:4])
- nd = magic % 256
- ty = magic // 256
- assert 1 <= nd <= 3
- assert 8 <= ty <= 14
- torch_type = SN3_PASCALVINCENT_BITSMAP[ty]
- s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
-
- num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8
- # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
- # we need to reverse the bytes before we can read them with torch.frombuffer().
- needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
- parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
- if needs_byte_reversal:
- parsed = parsed.flip(0)
-
- assert parsed.shape[0] == np.prod(s) or not strict
- return parsed.view(*s)
-
-
-def read_label_file(path: str) -> torch.Tensor:
- x = read_sn3_pascalvincent_tensor(path, strict=False)
- if x.dtype != torch.uint8:
- raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
- if x.ndimension() != 1:
- raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
- return x.long()
-
-
-def read_image_file(path: str) -> torch.Tensor:
- x = read_sn3_pascalvincent_tensor(path, strict=False)
- if x.dtype != torch.uint8:
- raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
- if x.ndimension() != 3:
- raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
- return x
\ No newline at end of file
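For reviewers unfamiliar with the SN3/IDX header that read_sn3_pascalvincent_tensor above decodes, a minimal sketch of the magic-number arithmetic (not part of the patch; the constants follow the deleted SN3_PASCALVINCENT_BITSMAP table):

    # Standard MNIST image files begin with magic 0x00000803.
    magic = 0x00000803
    nd = magic % 256    # low byte: number of dimensions -> 3 (count, rows, cols)
    ty = magic // 256   # type code -> 8, which SN3_PASCALVINCENT_BITSMAP maps to torch.uint8
    assert (nd, ty) == (3, 8)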
diff --git a/python/jittor/compatibility/vision/datasets/utils.py b/python/jittor/compatibility/vision/datasets/utils.py
deleted file mode 100644
index f9ae1a89..00000000
--- a/python/jittor/compatibility/vision/datasets/utils.py
+++ /dev/null
@@ -1,522 +0,0 @@
-import bz2
-import contextlib
-import gzip
-import hashlib
-import itertools
-import lzma
-import os
-import os.path
-import pathlib
-import re
-import sys
-import tarfile
-import urllib
-import urllib.error
-import urllib.request
-import warnings
-import zipfile
-from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
-from urllib.parse import urlparse
-
-import numpy as np
-import requests
-import torch
-from tqdm import tqdm
-
-from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
-
-USER_AGENT = "pytorch/vision"
-
-
-def _save_response_content(
- content: Iterator[bytes],
- destination: str,
- length: Optional[int] = None,
-) -> None:
- with open(destination, "wb") as fh, tqdm(total=length) as pbar:
- for chunk in content:
- # filter out keep-alive new chunks
- if not chunk:
- continue
-
- fh.write(chunk)
- pbar.update(len(chunk))
-
-
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
- with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
- _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
-
-
-def gen_bar_updater() -> Callable[[int, int, int], None]:
- warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
- pbar = tqdm(total=None)
-
- def bar_update(count, block_size, total_size):
- if pbar.total is None and total_size:
- pbar.total = total_size
- progress_bytes = count * block_size
- pbar.update(progress_bytes - pbar.n)
-
- return bar_update
-
-
-def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
- # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
- # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
- # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
- if sys.version_info >= (3, 9):
- md5 = hashlib.md5(usedforsecurity=False)
- else:
- md5 = hashlib.md5()
- with open(fpath, "rb") as f:
- for chunk in iter(lambda: f.read(chunk_size), b""):
- md5.update(chunk)
- return md5.hexdigest()
-
-
-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
- return md5 == calculate_md5(fpath, **kwargs)
-
-
-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
- if not os.path.isfile(fpath):
- return False
- if md5 is None:
- return True
- return check_md5(fpath, md5)
-
-
-def _get_redirect_url(url: str, max_hops: int = 3) -> str:
- initial_url = url
- headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
-
- for _ in range(max_hops + 1):
- with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
- if response.url == url or response.url is None:
- return url
-
- url = response.url
- else:
- raise RecursionError(
- f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
- )
-
-
-def _get_google_drive_file_id(url: str) -> Optional[str]:
- parts = urlparse(url)
-
- if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
- return None
-
- match = re.match(r"/file/d/(?P[^/]*)", parts.path)
- if match is None:
- return None
-
- return match.group("id")
-
-
-def download_url(
- url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
-) -> None:
- """Download a file from a url and place it in root.
-
- Args:
- url (str): URL to download file from
- root (str): Directory to place downloaded file in
- filename (str, optional): Name to save the file under. If None, use the basename of the URL
- md5 (str, optional): MD5 checksum of the download. If None, do not check
- max_redirect_hops (int, optional): Maximum number of redirect hops allowed
- """
- root = os.path.expanduser(root)
- if not filename:
- filename = os.path.basename(url)
- fpath = os.path.join(root, filename)
-
- os.makedirs(root, exist_ok=True)
-
- # check if file is already present locally
- if check_integrity(fpath, md5):
- print("Using downloaded and verified file: " + fpath)
- return
-
- if _is_remote_location_available():
- _download_file_from_remote_location(fpath, url)
- else:
- # expand redirect chain if needed
- url = _get_redirect_url(url, max_hops=max_redirect_hops)
-
- # check if file is located on Google Drive
- file_id = _get_google_drive_file_id(url)
- if file_id is not None:
- return download_file_from_google_drive(file_id, root, filename, md5)
-
- # download the file
- try:
- print("Downloading " + url + " to " + fpath)
- _urlretrieve(url, fpath)
- except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
- if url[:5] == "https":
- url = url.replace("https:", "http:")
- print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
- _urlretrieve(url, fpath)
- else:
- raise e
-
- # check integrity of downloaded file
- if not check_integrity(fpath, md5):
- raise RuntimeError("File not found or corrupted.")
-
-
-def list_dir(root: str, prefix: bool = False) -> List[str]:
- """List all directories at a given root
-
- Args:
- root (str): Path to directory whose folders need to be listed
- prefix (bool, optional): If true, prepends the path to each result, otherwise
- only returns the name of the directories found
- """
- root = os.path.expanduser(root)
- directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
- if prefix is True:
- directories = [os.path.join(root, d) for d in directories]
- return directories
-
-
-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
- """List all files ending with a suffix at a given root
-
- Args:
- root (str): Path to directory whose folders need to be listed
- suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
- It uses the Python "str.endswith" method and is passed directly
- prefix (bool, optional): If true, prepends the path to each result, otherwise
- only returns the name of the files found
- """
- root = os.path.expanduser(root)
- files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
- if prefix is True:
- files = [os.path.join(root, d) for d in files]
- return files
-
-
-def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
- content = response.iter_content(chunk_size)
- first_chunk = None
- # filter out keep-alive new chunks
- while not first_chunk:
- first_chunk = next(content)
- content = itertools.chain([first_chunk], content)
-
- try:
- match = re.search("Google Drive - (?P.+?)", first_chunk.decode())
- api_response = match["api_response"] if match is not None else None
- except UnicodeDecodeError:
- api_response = None
- return api_response, content
-
-
-def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
- """Download a Google Drive file from and place it in root.
-
- Args:
- file_id (str): id of file to be downloaded
- root (str): Directory to place downloaded file in
- filename (str, optional): Name to save the file under. If None, use the id of the file.
- md5 (str, optional): MD5 checksum of the download. If None, do not check
- """
- # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-
- root = os.path.expanduser(root)
- if not filename:
- filename = file_id
- fpath = os.path.join(root, filename)
-
- os.makedirs(root, exist_ok=True)
-
- if check_integrity(fpath, md5):
- print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
- return
-
- url = "https://drive.google.com/uc"
- params = dict(id=file_id, export="download")
- with requests.Session() as session:
- response = session.get(url, params=params, stream=True)
-
- for key, value in response.cookies.items():
- if key.startswith("download_warning"):
- token = value
- break
- else:
- api_response, content = _extract_gdrive_api_response(response)
- token = "t" if api_response == "Virus scan warning" else None
-
- if token is not None:
- response = session.get(url, params=dict(params, confirm=token), stream=True)
- api_response, content = _extract_gdrive_api_response(response)
-
- if api_response == "Quota exceeded":
- raise RuntimeError(
- f"The daily quota of the file {filename} is exceeded and it "
- f"can't be downloaded. This is a limitation of Google Drive "
- f"and can only be overcome by trying again later."
- )
-
- _save_response_content(content, fpath)
-
- # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
- if os.stat(fpath).st_size < 10 * 1024:
- with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
- text = fh.read()
- # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
- if re.search(r"?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
- warnings.warn(
- f"We detected some HTML elements in the downloaded file. "
- f"This most likely means that the download triggered an unhandled API response by GDrive. "
- f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
- f"the response:\n\n{text}"
- )
-
- if md5 and not check_md5(fpath, md5):
- raise RuntimeError(
- f"The MD5 checksum of the download file {fpath} does not match the one on record."
- f"Please delete the file and try again. "
- f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
- )
-
-
-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
- with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
- tar.extractall(to_path)
-
-
-_ZIP_COMPRESSION_MAP: Dict[str, int] = {
- ".bz2": zipfile.ZIP_BZIP2,
- ".xz": zipfile.ZIP_LZMA,
-}
-
-
-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
- with zipfile.ZipFile(
- from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
- ) as zip:
- zip.extractall(to_path)
-
-
-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
- ".tar": _extract_tar,
- ".zip": _extract_zip,
-}
-_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
- ".bz2": bz2.open,
- ".gz": gzip.open,
- ".xz": lzma.open,
-}
-_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
- ".tbz": (".tar", ".bz2"),
- ".tbz2": (".tar", ".bz2"),
- ".tgz": (".tar", ".gz"),
-}
-
-
-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
- """Detect the archive type and/or compression of a file.
-
- Args:
- file (str): the filename
-
- Returns:
- (tuple): tuple of suffix, archive type, and compression
-
- Raises:
- RuntimeError: if file has no suffix or suffix is not supported
- """
- suffixes = pathlib.Path(file).suffixes
- if not suffixes:
- raise RuntimeError(
- f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
- )
- suffix = suffixes[-1]
-
- # check if the suffix is a known alias
- if suffix in _FILE_TYPE_ALIASES:
- return (suffix, *_FILE_TYPE_ALIASES[suffix])
-
- # check if the suffix is an archive type
- if suffix in _ARCHIVE_EXTRACTORS:
- return suffix, suffix, None
-
- # check if the suffix is a compression
- if suffix in _COMPRESSED_FILE_OPENERS:
- # check for suffix hierarchy
- if len(suffixes) > 1:
- suffix2 = suffixes[-2]
-
- # check if the suffix2 is an archive type
- if suffix2 in _ARCHIVE_EXTRACTORS:
- return suffix2 + suffix, suffix2, suffix
-
- return suffix, None, suffix
-
- valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
- raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
-
-
-def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
- r"""Decompress a file.
-
- The compression is automatically detected from the file name.
-
- Args:
- from_path (str): Path to the file to be decompressed.
- to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
- remove_finished (bool): If ``True``, remove the file after the extraction.
-
- Returns:
- (str): Path to the decompressed file.
- """
- suffix, archive_type, compression = _detect_file_type(from_path)
- if not compression:
- raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
-
- if to_path is None:
- to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
-
- # We don't need to check for a missing key here, since this was already done in _detect_file_type()
- compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
-
- with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
- wfh.write(rfh.read())
-
- if remove_finished:
- os.remove(from_path)
-
- return to_path
-
-
-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
- """Extract an archive.
-
- The archive type and a possible compression is automatically detected from the file name. If the file is compressed
- but not an archive the call is dispatched to :func:`decompress`.
-
- Args:
- from_path (str): Path to the file to be extracted.
- to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
- used.
- remove_finished (bool): If ``True``, remove the file after the extraction.
-
- Returns:
- (str): Path to the directory the file was extracted to.
- """
- if to_path is None:
- to_path = os.path.dirname(from_path)
-
- suffix, archive_type, compression = _detect_file_type(from_path)
- if not archive_type:
- return _decompress(
- from_path,
- os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
- remove_finished=remove_finished,
- )
-
- # We don't need to check for a missing key here, since this was already done in _detect_file_type()
- extractor = _ARCHIVE_EXTRACTORS[archive_type]
-
- extractor(from_path, to_path, compression)
- if remove_finished:
- os.remove(from_path)
-
- return to_path
-
-
-def download_and_extract_archive(
- url: str,
- download_root: str,
- extract_root: Optional[str] = None,
- filename: Optional[str] = None,
- md5: Optional[str] = None,
- remove_finished: bool = False,
-) -> None:
- download_root = os.path.expanduser(download_root)
- if extract_root is None:
- extract_root = download_root
- if not filename:
- filename = os.path.basename(url)
-
- download_url(url, download_root, filename, md5)
-
- archive = os.path.join(download_root, filename)
- print(f"Extracting {archive} to {extract_root}")
- extract_archive(archive, extract_root, remove_finished)
-
-
-def iterable_to_str(iterable: Iterable) -> str:
- return "'" + "', '".join([str(item) for item in iterable]) + "'"
-
-
-T = TypeVar("T", str, bytes)
-
-
-def verify_str_arg(
- value: T,
- arg: Optional[str] = None,
- valid_values: Optional[Iterable[T]] = None,
- custom_msg: Optional[str] = None,
-) -> T:
- if not isinstance(value, torch._six.string_classes):
- if arg is None:
- msg = "Expected type str, but got type {type}."
- else:
- msg = "Expected type str for argument {arg}, but got type {type}."
- msg = msg.format(type=type(value), arg=arg)
- raise ValueError(msg)
-
- if valid_values is None:
- return value
-
- if value not in valid_values:
- if custom_msg is not None:
- msg = custom_msg
- else:
- msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
- msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
- raise ValueError(msg)
-
- return value
-
-
-def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
- """Read file in .pfm format. Might contain either 1 or 3 channels of data.
-
- Args:
- file_name (str): Path to the file.
- slice_channels (int): Number of channels to slice out of the file.
- Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
- """
-
- with open(file_name, "rb") as f:
- header = f.readline().rstrip()
- if header not in [b"PF", b"Pf"]:
- raise ValueError("Invalid PFM file")
-
- dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
- if not dim_match:
- raise Exception("Malformed PFM header.")
- w, h = (int(dim) for dim in dim_match.groups())
-
- scale = float(f.readline().rstrip())
- if scale < 0: # little-endian
- endian = "<"
- scale = -scale
- else:
- endian = ">" # big-endian
-
- data = np.fromfile(f, dtype=endian + "f")
-
- pfm_channels = 3 if header == b"PF" else 1
-
- data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
- data = np.flip(data, axis=1) # flip on h dimension
- data = data[:slice_channels, :, :]
- return data.astype(np.float32)
\ No newline at end of file
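As a quick reference for the archive helpers removed above, the suffix resolution performed by _detect_file_type works out as follows for common names (illustrative expectations only, derived from _FILE_TYPE_ALIASES, _ARCHIVE_EXTRACTORS and _COMPRESSED_FILE_OPENERS):

    # _detect_file_type("data.tar.gz") -> (".tar.gz", ".tar", ".gz")   # compressed tar archive
    # _detect_file_type("data.tgz")    -> (".tgz",    ".tar", ".gz")   # alias expansion
    # _detect_file_type("data.zip")    -> (".zip",    ".zip", None)    # archive, no extra compression
    # _detect_file_type("data.xz")     -> (".xz",     None,   ".xz")   # compression only; extract_archive falls back to _decompress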
diff --git a/python/jittor/compatibility/vision/datasets/vision.py b/python/jittor/compatibility/vision/datasets/vision.py
deleted file mode 100644
index d71dc2a5..00000000
--- a/python/jittor/compatibility/vision/datasets/vision.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import os
-from typing import Any, Callable, List, Optional, Tuple
-
-import torch
-import torch.utils.data as data
-
-from ..utils import _log_api_usage_once
-
-
-class VisionDataset(data.Dataset):
- """
- Base Class For making datasets which are compatible with torchvision.
- It is necessary to override the ``__getitem__`` and ``__len__`` method.
- Args:
- root (string): Root directory of dataset.
- transforms (callable, optional): A function/transforms that takes in
- an image and a label and returns the transformed versions of both.
- transform (callable, optional): A function/transform that takes in an PIL image
- and returns a transformed version. E.g, ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- .. note::
- :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
- """
-
- _repr_indent = 4
-
- def __init__(
- self,
- root: str,
- transforms: Optional[Callable] = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- ) -> None:
- self.root = root
-
- has_transforms = transforms is not None
- has_separate_transform = transform is not None or target_transform is not None
- if has_transforms and has_separate_transform:
- raise ValueError("Only transforms or transform/target_transform can be passed as argument")
-
- # for backwards-compatibility
- self.transform = transform
- self.target_transform = target_transform
-
- if has_separate_transform:
- transforms = StandardTransform(transform, target_transform)
- self.transforms = transforms
-
- def __getitem__(self, index: int) -> Any:
- """
- Args:
- index (int): Index
- Returns:
- (Any): Sample and meta data, optionally transformed by the respective transforms.
- """
- raise NotImplementedError
-
- def __len__(self) -> int:
- raise NotImplementedError
-
- def __repr__(self) -> str:
- head = "Dataset " + self.__class__.__name__
- body = [f"Number of datapoints: {self.__len__()}"]
- if self.root is not None:
- body.append(f"Root location: {self.root}")
- body += self.extra_repr().splitlines()
- if hasattr(self, "transforms") and self.transforms is not None:
- body += [repr(self.transforms)]
- lines = [head] + [" " * self._repr_indent + line for line in body]
- return "\n".join(lines)
-
- def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
- lines = transform.__repr__().splitlines()
- return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
-
- def extra_repr(self) -> str:
- return ""
-
-
-class StandardTransform:
- def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
- self.transform = transform
- self.target_transform = target_transform
-
- def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
- if self.transform is not None:
- input = self.transform(input)
- if self.target_transform is not None:
- target = self.target_transform(target)
- return input, target
-
- def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
- lines = transform.__repr__().splitlines()
- return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
-
- def __repr__(self) -> str:
- body = [self.__class__.__name__]
- if self.transform is not None:
- body += self._format_transform_repr(self.transform, "Transform: ")
- if self.target_transform is not None:
- body += self._format_transform_repr(self.target_transform, "Target transform: ")
-
- return "\n".join(body)
\ No newline at end of file
diff --git a/python/jittor/compatibility/vision/transforms.py b/python/jittor/compatibility/vision/transforms.py
deleted file mode 100644
index 416057c7..00000000
--- a/python/jittor/compatibility/vision/transforms.py
+++ /dev/null
@@ -1 +0,0 @@
-from jittor.transform import *
\ No newline at end of file
diff --git a/python/jittor/compatibility/vision/utils.py b/python/jittor/compatibility/vision/utils.py
deleted file mode 100644
index 4be36c64..00000000
--- a/python/jittor/compatibility/vision/utils.py
+++ /dev/null
@@ -1,582 +0,0 @@
-import collections
-import math
-import pathlib
-import warnings
-from itertools import repeat
-from types import FunctionType
-from typing import Any, BinaryIO, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from PIL import Image, ImageColor, ImageDraw, ImageFont
-
-__all__ = [
- "make_grid",
- "save_image",
- "draw_bounding_boxes",
- "draw_segmentation_masks",
- "draw_keypoints",
- "flow_to_image",
-]
-
-
-@torch.no_grad()
-def make_grid(
- tensor: Union[torch.Tensor, List[torch.Tensor]],
- nrow: int = 8,
- padding: int = 2,
- normalize: bool = False,
- value_range: Optional[Tuple[int, int]] = None,
- scale_each: bool = False,
- pad_value: float = 0.0,
- **kwargs,
-) -> torch.Tensor:
- """
- Make a grid of images.
-
- Args:
- tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
- or a list of images all of the same size.
- nrow (int, optional): Number of images displayed in each row of the grid.
- The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
- padding (int, optional): amount of padding. Default: ``2``.
- normalize (bool, optional): If True, shift the image to the range (0, 1),
- by the min and max values specified by ``value_range``. Default: ``False``.
- value_range (tuple, optional): tuple (min, max) where min and max are numbers,
- then these numbers are used to normalize the image. By default, min and max
- are computed from the tensor.
- range (tuple, optional):
- .. warning::
- This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
- instead.
- scale_each (bool, optional): If ``True``, scale each image in the batch of
- images separately rather than the (min, max) over all images. Default: ``False``.
- pad_value (float, optional): Value for the padded pixels. Default: ``0``.
-
- Returns:
- grid (Tensor): the tensor containing grid of images.
- """
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(make_grid)
- if not torch.is_tensor(tensor):
- if isinstance(tensor, list):
- for t in tensor:
- if not torch.is_tensor(t):
- raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
- else:
- raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
-
- if "range" in kwargs.keys():
- warnings.warn(
- "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
- "Please use 'value_range' instead."
- )
- value_range = kwargs["range"]
-
- # if list of tensors, convert to a 4D mini-batch Tensor
- if isinstance(tensor, list):
- tensor = torch.stack(tensor, dim=0)
-
- if tensor.dim() == 2: # single image H x W
- tensor = tensor.unsqueeze(0)
- if tensor.dim() == 3: # single image
- if tensor.size(0) == 1: # if single-channel, convert to 3-channel
- tensor = torch.cat((tensor, tensor, tensor), 0)
- tensor = tensor.unsqueeze(0)
-
- if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
- tensor = torch.cat((tensor, tensor, tensor), 1)
-
- if normalize is True:
- tensor = tensor.clone() # avoid modifying tensor in-place
- if value_range is not None and not isinstance(value_range, tuple):
- raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
-
- def norm_ip(img, low, high):
- img.clamp_(min=low, max=high)
- img.sub_(low).div_(max(high - low, 1e-5))
-
- def norm_range(t, value_range):
- if value_range is not None:
- norm_ip(t, value_range[0], value_range[1])
- else:
- norm_ip(t, float(t.min()), float(t.max()))
-
- if scale_each is True:
- for t in tensor: # loop over mini-batch dimension
- norm_range(t, value_range)
- else:
- norm_range(tensor, value_range)
-
- if not isinstance(tensor, torch.Tensor):
- raise TypeError("tensor should be of type torch.Tensor")
- if tensor.size(0) == 1:
- return tensor.squeeze(0)
-
- # make the mini-batch of images into a grid
- nmaps = tensor.size(0)
- xmaps = min(nrow, nmaps)
- ymaps = int(math.ceil(float(nmaps) / xmaps))
- height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
- num_channels = tensor.size(1)
- grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
- k = 0
- for y in range(ymaps):
- for x in range(xmaps):
- if k >= nmaps:
- break
- # Tensor.copy_() is a valid method but seems to be missing from the stubs
- # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
- grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
- 2, x * width + padding, width - padding
- ).copy_(tensor[k])
- k = k + 1
- return grid
-
-
-@torch.no_grad()
-def save_image(
- tensor: Union[torch.Tensor, List[torch.Tensor]],
- fp: Union[str, pathlib.Path, BinaryIO],
- format: Optional[str] = None,
- **kwargs,
-) -> None:
- """
- Save a given Tensor into an image file.
-
- Args:
- tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
- saves the tensor as a grid of images by calling ``make_grid``.
- fp (string or file object): A filename or a file object
- format(Optional): If omitted, the format to use is determined from the filename extension.
- If a file object was used instead of a filename, this parameter should always be used.
- **kwargs: Other arguments are documented in ``make_grid``.
- """
-
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(save_image)
- grid = make_grid(tensor, **kwargs)
- # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
- ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
- im = Image.fromarray(ndarr)
- im.save(fp, format=format)
-
-
-@torch.no_grad()
-def draw_bounding_boxes(
- image: torch.Tensor,
- boxes: torch.Tensor,
- labels: Optional[List[str]] = None,
- colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
- fill: Optional[bool] = False,
- width: int = 1,
- font: Optional[str] = None,
- font_size: Optional[int] = None,
-) -> torch.Tensor:
-
- """
- Draws bounding boxes on given image.
- The values of the input image should be uint8 between 0 and 255.
- If fill is True, Resulting Tensor should be saved as PNG image.
-
- Args:
- image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
- boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
- the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
- `0 <= ymin < ymax < H`.
- labels (List[str]): List containing the labels of bounding boxes.
- colors (color or list of colors, optional): List containing the colors
- of the boxes or single color for all boxes. The color can be represented as
- PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
- By default, random colors are generated for boxes.
- fill (bool): If `True` fills the bounding box with specified color.
- width (int): Width of bounding box.
- font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
- also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
- `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
- font_size (int): The requested font size in points.
-
- Returns:
- img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
- """
-
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(draw_bounding_boxes)
- if not isinstance(image, torch.Tensor):
- raise TypeError(f"Tensor expected, got {type(image)}")
- elif image.dtype != torch.uint8:
- raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
- elif image.dim() != 3:
- raise ValueError("Pass individual images, not batches")
- elif image.size(0) not in {1, 3}:
- raise ValueError("Only grayscale and RGB images are supported")
- elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
- raise ValueError(
- "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
- )
-
- num_boxes = boxes.shape[0]
-
- if num_boxes == 0:
- warnings.warn("boxes doesn't contain any box. No box was drawn")
- return image
-
- if labels is None:
- labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
- elif len(labels) != num_boxes:
- raise ValueError(
- f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
- )
-
- if colors is None:
- colors = _generate_color_palette(num_boxes)
- elif isinstance(colors, list):
- if len(colors) < num_boxes:
- raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
- else: # colors specifies a single color for all boxes
- colors = [colors] * num_boxes
-
- colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
-
- if font is None:
- if font_size is not None:
- warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
- txt_font = ImageFont.load_default()
- else:
- txt_font = ImageFont.truetype(font=font, size=font_size or 10)
-
- # Handle Grayscale images
- if image.size(0) == 1:
- image = torch.tile(image, (3, 1, 1))
-
- ndarr = image.permute(1, 2, 0).cpu().numpy()
- img_to_draw = Image.fromarray(ndarr)
- img_boxes = boxes.to(torch.int64).tolist()
-
- if fill:
- draw = ImageDraw.Draw(img_to_draw, "RGBA")
- else:
- draw = ImageDraw.Draw(img_to_draw)
-
- for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
- if fill:
- fill_color = color + (100,)
- draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
- else:
- draw.rectangle(bbox, width=width, outline=color)
-
- if label is not None:
- margin = width + 1
- draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
-
- return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
-
-
-@torch.no_grad()
-def draw_segmentation_masks(
- image: torch.Tensor,
- masks: torch.Tensor,
- alpha: float = 0.8,
- colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
-) -> torch.Tensor:
-
- """
- Draws segmentation masks on given RGB image.
- The values of the input image should be uint8 between 0 and 255.
-
- Args:
- image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
- masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
- alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
- 0 means full transparency, 1 means no transparency.
- colors (color or list of colors, optional): List containing the colors
- of the masks or single color for all masks. The color can be represented as
- PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
- By default, random colors are generated for each mask.
-
- Returns:
- img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
- """
-
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(draw_segmentation_masks)
- if not isinstance(image, torch.Tensor):
- raise TypeError(f"The image must be a tensor, got {type(image)}")
- elif image.dtype != torch.uint8:
- raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
- elif image.dim() != 3:
- raise ValueError("Pass individual images, not batches")
- elif image.size()[0] != 3:
- raise ValueError("Pass an RGB image. Other Image formats are not supported")
- if masks.ndim == 2:
- masks = masks[None, :, :]
- if masks.ndim != 3:
- raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
- if masks.dtype != torch.bool:
- raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
- if masks.shape[-2:] != image.shape[-2:]:
- raise ValueError("The image and the masks must have the same height and width")
-
- num_masks = masks.size()[0]
- if colors is not None and num_masks > len(colors):
- raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
-
- if num_masks == 0:
- warnings.warn("masks doesn't contain any mask. No mask was drawn")
- return image
-
- if colors is None:
- colors = _generate_color_palette(num_masks)
-
- if not isinstance(colors, list):
- colors = [colors]
- if not isinstance(colors[0], (tuple, str)):
- raise ValueError("colors must be a tuple or a string, or a list thereof")
- if isinstance(colors[0], tuple) and len(colors[0]) != 3:
- raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
-
- out_dtype = torch.uint8
-
- colors_ = []
- for color in colors:
- if isinstance(color, str):
- color = ImageColor.getrgb(color)
- colors_.append(torch.tensor(color, dtype=out_dtype))
-
- img_to_draw = image.detach().clone()
- # TODO: There might be a way to vectorize this
- for mask, color in zip(masks, colors_):
- img_to_draw[:, mask] = color[:, None]
-
- out = image * (1 - alpha) + img_to_draw * alpha
- return out.to(out_dtype)
-
-
-@torch.no_grad()
-def draw_keypoints(
- image: torch.Tensor,
- keypoints: torch.Tensor,
- connectivity: Optional[List[Tuple[int, int]]] = None,
- colors: Optional[Union[str, Tuple[int, int, int]]] = None,
- radius: int = 2,
- width: int = 3,
-) -> torch.Tensor:
-
- """
- Draws Keypoints on given RGB image.
- The values of the input image should be uint8 between 0 and 255.
-
- Args:
- image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
- keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
- in the format [x, y].
- connectivity (List[Tuple[int, int]]): A list of tuples, where each
- tuple contains a pair of keypoints to be connected.
- colors (str, Tuple): The color can be represented as
- PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
- radius (int): Integer denoting radius of keypoint.
- width (int): Integer denoting width of line connecting keypoints.
-
- Returns:
- img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
- """
-
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(draw_keypoints)
- if not isinstance(image, torch.Tensor):
- raise TypeError(f"The image must be a tensor, got {type(image)}")
- elif image.dtype != torch.uint8:
- raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
- elif image.dim() != 3:
- raise ValueError("Pass individual images, not batches")
- elif image.size()[0] != 3:
- raise ValueError("Pass an RGB image. Other Image formats are not supported")
-
- if keypoints.ndim != 3:
- raise ValueError("keypoints must be of shape (num_instances, K, 2)")
-
- ndarr = image.permute(1, 2, 0).cpu().numpy()
- img_to_draw = Image.fromarray(ndarr)
- draw = ImageDraw.Draw(img_to_draw)
- img_kpts = keypoints.to(torch.int64).tolist()
-
- for kpt_id, kpt_inst in enumerate(img_kpts):
- for inst_id, kpt in enumerate(kpt_inst):
- x1 = kpt[0] - radius
- x2 = kpt[0] + radius
- y1 = kpt[1] - radius
- y2 = kpt[1] + radius
- draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
-
- if connectivity:
- for connection in connectivity:
- start_pt_x = kpt_inst[connection[0]][0]
- start_pt_y = kpt_inst[connection[0]][1]
-
- end_pt_x = kpt_inst[connection[1]][0]
- end_pt_y = kpt_inst[connection[1]][1]
-
- draw.line(
- ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
- width=width,
- )
-
- return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
-
-
-# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
-@torch.no_grad()
-def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
-
- """
- Converts a flow to an RGB image.
-
- Args:
- flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
-
- Returns:
- img (Tensor): Image Tensor of dtype uint8 where each color corresponds
- to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
- """
-
- if flow.dtype != torch.float:
- raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
-
- orig_shape = flow.shape
- if flow.ndim == 3:
- flow = flow[None] # Add batch dim
-
- if flow.ndim != 4 or flow.shape[1] != 2:
- raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
-
- max_norm = torch.sum(flow**2, dim=1).sqrt().max()
- epsilon = torch.finfo((flow).dtype).eps
- normalized_flow = flow / (max_norm + epsilon)
- img = _normalized_flow_to_image(normalized_flow)
-
- if len(orig_shape) == 3:
- img = img[0] # Remove batch dim
- return img
-
-
-@torch.no_grad()
-def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
-
- """
- Converts a batch of normalized flow to an RGB image.
-
- Args:
- normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
- Returns:
- img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
- """
-
- N, _, H, W = normalized_flow.shape
- device = normalized_flow.device
- flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
- colorwheel = _make_colorwheel().to(device) # shape [55x3]
- num_cols = colorwheel.shape[0]
- norm = torch.sum(normalized_flow**2, dim=1).sqrt()
- a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
- fk = (a + 1) / 2 * (num_cols - 1)
- k0 = torch.floor(fk).to(torch.long)
- k1 = k0 + 1
- k1[k1 == num_cols] = 0
- f = fk - k0
-
- for c in range(colorwheel.shape[1]):
- tmp = colorwheel[:, c]
- col0 = tmp[k0] / 255.0
- col1 = tmp[k1] / 255.0
- col = (1 - f) * col0 + f * col1
- col = 1 - norm * (1 - col)
- flow_image[:, c, :, :] = torch.floor(255 * col)
- return flow_image
-
-
-def _make_colorwheel() -> torch.Tensor:
- """
- Generates a color wheel for optical flow visualization as presented in:
- Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
- URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
-
- Returns:
- colorwheel (Tensor[55, 3]): Colorwheel Tensor.
- """
-
- RY = 15
- YG = 6
- GC = 4
- CB = 11
- BM = 13
- MR = 6
-
- ncols = RY + YG + GC + CB + BM + MR
- colorwheel = torch.zeros((ncols, 3))
- col = 0
-
- # RY
- colorwheel[0:RY, 0] = 255
- colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
- col = col + RY
- # YG
- colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
- colorwheel[col : col + YG, 1] = 255
- col = col + YG
- # GC
- colorwheel[col : col + GC, 1] = 255
- colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
- col = col + GC
- # CB
- colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
- colorwheel[col : col + CB, 2] = 255
- col = col + CB
- # BM
- colorwheel[col : col + BM, 2] = 255
- colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
- col = col + BM
- # MR
- colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
- colorwheel[col : col + MR, 0] = 255
- return colorwheel
-
-
-def _generate_color_palette(num_objects: int):
- palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
- return [tuple((i * palette) % 255) for i in range(num_objects)]
-
-
-def _log_api_usage_once(obj: Any) -> None:
-
- """
- Logs API usage(module and name) within an organization.
- In a large ecosystem, it's often useful to track the PyTorch and
- TorchVision APIs usage. This API provides the similar functionality to the
- logging module in the Python stdlib. It can be used for debugging purpose
- to log which methods are used and by default it is inactive, unless the user
- manually subscribes a logger via the `SetAPIUsageLogger method `_.
- Please note it is triggered only once for the same API call within a process.
- It does not collect any data from open-source users since it is no-op by default.
- For more information, please refer to
- * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
- * Logging policy: https://github.com/pytorch/vision/issues/5052;
-
- Args:
- obj (class instance or method): an object to extract info from.
- """
- pass
-
-
-def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
- """
- Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
- Otherwise we will make a tuple of length n, all with value of x.
- reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
-
- Args:
- x (Any): input value
- n (int): length of the resulting tuple
- """
- if isinstance(x, collections.abc.Iterable):
- return tuple(x)
- return tuple(repeat(x, n))
\ No newline at end of file
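One detail of the deleted flow-visualisation helpers worth keeping in mind: _make_colorwheel produces 15 + 6 + 4 + 11 + 13 + 6 = 55 colors, and _normalized_flow_to_image turns the flow angle into a fractional index into that wheel. A rough single-vector sketch of the same mapping (illustrative only):

    import math
    num_cols = 15 + 6 + 4 + 11 + 13 + 6        # 55 colorwheel entries
    u, v = 0.0, -1.0                           # a normalized flow vector
    a = math.atan2(-v, -u) / math.pi           # angle scaled to [-1, 1]
    fk = (a + 1) / 2 * (num_cols - 1)          # fractional colorwheel index
    k0, f = int(fk), fk - int(fk)              # blend colorwheel[k0] and colorwheel[k0 + 1] by f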
diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py
index 4df000b4..f6c93b54 100644
--- a/python/jittor/compile_extern.py
+++ b/python/jittor/compile_extern.py
@@ -629,6 +629,7 @@ def setup_mpi():
mpi_ops = None
mpi = None
has_mpi = False
+ if not use_mpi: return
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
if mpicc_path == "":
# LOG.i("mpicc not found, distribution disabled.")
@@ -711,4 +712,4 @@ def inner(self, *args, **kw):
# install backend extern library
for mod in jit_utils.backends:
if mod.install_extern():
- break
\ No newline at end of file
+ break
diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py
index c6be01bb..a5dcd136 100644
--- a/python/jittor/compiler.py
+++ b/python/jittor/compiler.py
@@ -1002,6 +1002,8 @@ def check_debug_flags():
r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc")
if r==0:
s = sorted(list(set(s.strip().split())))
+ if len(s)==0:
+ LOG.e("No GPU Device Found!")
cu += "_sm_" + "_".join(s)
if "cuda_arch" not in os.environ:
os.environ["cuda_arch"] = " ".join(cu)
diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py
index 31dd72de..c4026eee 100644
--- a/python/jittor/contrib.py
+++ b/python/jittor/contrib.py
@@ -15,6 +15,8 @@
def argmax_pool(x, size, stride, padding=0):
+ if stride<=0:
+ raise RuntimeError(f"stride must be > 0, but got {stride}")
return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
@@ -241,9 +243,17 @@ def concat(arr, dim=0):
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
- if dim < 0: dim += len(arr[0].shape)
+ base_dim = len(arr[0].shape)
+ if dim < 0: dim += base_dim
+ if dim < 0 or dim >= base_dim:
+ raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})")
dtypes = []
for a in arr:
+ if len(a.shape) != base_dim:
+ raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}")
+ for i in range(base_dim):
+ if i != dim and a.shape[i] != arr[0].shape[i]:
+ raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.")
total_dim += a.shape[dim]
dtypes.append(str(a.dtype))
cdim = 0
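The shape checks added to concat above turn silent shape mismatches into explicit errors. A small sketch of the intended behaviour, assuming jt.concat dispatches to this contrib.concat:

    import jittor as jt
    a = jt.zeros((2, 3))
    b = jt.zeros((2, 4))
    jt.concat([a, b], dim=1)     # ok: shapes differ only in the concat dimension
    # jt.concat([a, b], dim=0)   # RuntimeError: sizes must match except in dimension 0
    # jt.concat([a, b], dim=2)   # IndexError: dimension out of range for 2-D inputs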
diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py
index 7aa1d883..f0945f94 100644
--- a/python/jittor/dataset/mnist.py
+++ b/python/jittor/dataset/mnist.py
@@ -26,7 +26,7 @@ class MNIST(Dataset):
[in] data_root(str): your data root.
[in] train(bool): choose model train or val.
- [in] download(bool): Download data automatically if download is Ture.
+ [in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
@@ -106,7 +106,7 @@ class EMNIST(Dataset):
[in] data_root(str): your data root.
[in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
[in] train(bool): choose model train or val.
- [in] download(bool): Download data automatically if download is Ture.
+ [in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py
index fea9cfce..22e366eb 100644
--- a/python/jittor/extern/acl/acl_compiler.py
+++ b/python/jittor/extern/acl/acl_compiler.py
@@ -1,6 +1,6 @@
# ***************************************************************
-# Copyright (c) 2023 Jittor. All Rights Reserved.
-# Maintainers: Dun Liang .
+# Copyright (c) 2023 Jittor. All Rights Reserved.
+# Maintainers: Dun Liang .
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@@ -10,6 +10,7 @@
import ctypes
import glob
import jittor.compiler as compiler
+import jittor as jt
has_acl = 0
cc_flags = ""
@@ -34,16 +35,18 @@
# export DUMP_GRAPH_LEVEL=1
# build pytorch-npu
-# bash ./ci/build.sh
-# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
+# bash ./ci/build.sh
+# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark
# python3 ./mm_bench_pt_npu.py
+
def install():
import jittor.compiler as compiler
global has_acl, cc_flags
acl_compiler_home = os.path.dirname(__file__)
- cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
+ cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc",
+ recursive=True))
cc_files2 = []
for name in cc_files:
if "acl_op_exec" in name:
@@ -51,21 +54,24 @@ def install():
else:
cc_files2.append(name)
cc_files = cc_files2
+ ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
cc_flags += f" -DHAS_CUDA -DIS_ACL \
- -I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \
- -L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \
+ -I{ascend_toolkit_home}/include/ \
+ -L{ascend_toolkit_home}/lib64/ \
-I{acl_compiler_home} -lascendcl -lacl_op_compiler "
+
ctypes.CDLL("libascendcl.so", dlopen_flags)
- '''
+ f'''
-ltikc_runtime
- -I/usr/local/Ascend/driver/include \
- -L/usr/local/Ascend/compiler/lib64 \
- -L/usr/local/Ascend/runtime/lib64 \
+ -I/usr/local/Ascend/driver/include/ \
+ -L{ascend_toolkit_home}/compiler/lib64/ \
+ -L{ascend_toolkit_home}/runtime/lib64/ \
'''
jittor_utils.LOG.i("ACL detected")
global mod
- mod = jittor_utils.compile_module('''
+ mod = jittor_utils.compile_module(
+ '''
#include "common.h"
namespace jittor {
// @pyjt(process)
@@ -98,9 +104,10 @@ def check():
if not has_acl: return False
compiler.cc_flags += cc_flags
compiler.nvcc_path = tikcc_path
- compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
+ compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "")
return True
+
def post_process():
if has_acl:
from jittor import pool
@@ -108,5 +115,1240 @@ def post_process():
import jittor as jt
jt.flags.use_cuda_host_allocator = 1
jt.flags.use_parallel_op_compiler = 0
- jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
- mod.init_acl_ops()
\ No newline at end of file
+ jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
+ mod.init_acl_ops()
+
+
+def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list,
+ attr: dict):
+ nchw_op = ['MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2']
+ attr_op = [
+ 'MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2',
+ 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2'
+ ]
+
+ input_code = ''
+ for i in range(len(inputs)):
+ if name in nchw_op:
+ input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n"
+ else:
+ input_code += f"op.add(in{i}, true);\n"
+
+ output_code = ''
+ for i in range(len(output_dtypes)):
+ if name in nchw_op:
+ output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n"
+ else:
+ output_code += f"op.add(out{i}, false);\n"
+
+ # add attr to op
+ attr_code = ''
+ if name in attr_op:
+ for k, v in attr.items():
+ if isinstance(v, bool):
+ if v == True:
+ attr_code += f"op.set_attr(\"{k}\", 1, 1);\n"
+ else:
+ attr_code += f"op.set_attr(\"{k}\", 1, 0);\n"
+ elif isinstance(v, str):
+ attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n"
+ elif k == 'divisor_override_value':
+ attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n"
+ else:
+ v = str(v).replace('[', '{').replace(']', '}')
+ attr_code += f"op.set_attr(\"{k}\", vector{v});\n"
+ else:
+ for k, v in attr.items():
+ if isinstance(v, bool):
+ if v == True:
+ attr_code += f"op.set_attr(\"{k}\", 1, 1);\n"
+ else:
+ attr_code += f"op.set_attr(\"{k}\", 1, 0);\n"
+ elif isinstance(v, str):
+ attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n"
+ else:
+ attr_code += f"op.set_attr(\"{k}\", int({v}));\n"
+
+ #print("input_code",input_code)
+ #print("attr_code",attr_code)
+ import jittor as jt
+ return jt.code(output_shapes,
+ output_dtypes,
+ inputs,
+ cuda_header="""
+ #include <acl/acl.h>
+ #include <acl/acl_op_compiler.h>
+ #include <Python.h>
+ #include <pystate.h>
+
+ namespace jittor {
+
+ void printDeviceData(const vector<aclTensorDesc*>& output_desc, const vector<aclDataBuffer*>& output_data, const string& name = "", bool input=true) {
+ LOGir << "name: " << name;
+ if(input)
+ LOGir << "is input";
+ else
+ LOGir << "is ouput";
+ for (size_t i = 0; i < output_desc.size(); ++i) {
+ void* base_addr = aclGetDataBufferAddr(output_data[i]);
+ LOGir << "addr of data[" << i << "] :" << base_addr;
+ size_t num_dims = aclGetTensorDescNumDims(output_desc[i]);
+ size_t total_size = 1;
+ std::vector<int64_t> dims(num_dims);
+
+ std::cout << "shape of data: ";
+ for (size_t j = 0; j < num_dims; ++j) {
+ aclGetTensorDescDimV2(output_desc[i], j, &dims[j]);
+ total_size *= dims[j];
+ std::cout << dims[j] << ", ";
+ }
+ int evey_batch_size = total_size/dims[0];
+ std::cout << std::endl;
+
+ // for(int i= 0; i < dims[0]; i++) {
+ // evey_batch_size = 16;
+ // std::vector host_buffer(evey_batch_size);
+ // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float);
+ // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+ // std::cout << "batch[" << i << "]:";
+ // for (size_t k = 0; k < evey_batch_size; ++k) {
+ // std::cout << host_buffer[k] << ", ";
+ // }
+ // std::cout << std::endl;
+ // }
+ }
+ }
+
+ struct AclOpRunner {
+ string name;
+ vector<aclTensorDesc*> input_desc;
+ vector<aclTensorDesc*> output_desc;
+ vector<aclDataBuffer*> input_data;
+ vector<aclDataBuffer*> output_data;
+ aclopAttr *attr;
+ vector<vector<uint64>> input_host;
+ vector<vector<uint32>> input_host_32;
+
+ AclOpRunner(const string& name) : name(name) {
+ attr = aclopCreateAttr();
+ }
+
+ ~AclOpRunner() {
+ for (auto i : input_desc) aclDestroyTensorDesc(i);
+ for (auto i : output_desc) aclDestroyTensorDesc(i);
+ for (auto i : input_data) aclDestroyDataBuffer(i);
+ for (auto i : output_data) aclDestroyDataBuffer(i);
+ aclopDestroyAttr(attr);
+ }
+
+ aclDataType get_dtype(NanoString s) {
+ if (s == ns_float32) return ACL_FLOAT;
+ if (s == ns_float16) return ACL_FLOAT16;
+ if (s == ns_int64) return ACL_INT64;
+ if (s == ns_int32) return ACL_INT32;
+ if (s == ns_int8) return ACL_INT8;
+ if (s == ns_int16) return ACL_INT16;
+ if (s == ns_uint8) return ACL_UINT8;
+ if (s == ns_uint16) return ACL_UINT16;
+ if (s == ns_uint32) return ACL_UINT32;
+ if (s == ns_bool) return ACL_BOOL;
+ LOGf << "Not supported dtype: " << s;
+ return ACL_FLOAT;
+ }
+
+ void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) {
+ int64_t shape[v->shape.size()];
+ for (int i=0; i<v->shape.size(); i++) shape[i] = v->shape[i];
+
+ auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format);
+ aclSetTensorFormat(desc, (aclFormat)format);
+ aclSetTensorShape(desc, v->shape.size(), &shape[0]);
+ LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format;
+ auto data = aclCreateDataBuffer(v->mem_ptr, v->size);
+ LOGv << "aclCreateDataBuffer" << v->mem_ptr << v->size;
+ ASSERT(desc && data);
+ if (is_input) {
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ } else {
+ output_desc.push_back(desc);
+ output_data.push_back(data);
+ }
+ }
+
+ void add_input_host(vector<uint64> v, int dtype=ACL_UINT64) {
+ int64_t shape[1];
+ shape[0] = v.size();
+ auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND);
+ aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT);
+ LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND;
+ auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64));
+ ASSERT(desc && data);
+ LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64);
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ input_host.emplace_back(move(v));
+ LOGv << "move" << input_host.back().data();
+ }
+
+ void add_input_host_scalar(vector<uint64> v, int dtype=ACL_UINT32) {
+ int64_t shape[1];
+ shape[0] = v.size();
+ auto x = (int*)&v[0];
+ x[0] = (int32)v[0];
+ auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND);
+ aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT);
+ LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND;
+ auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32));
+ ASSERT(desc && data);
+ LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32);
+ input_desc.push_back(desc);
+ input_data.push_back(data);
+ input_host.emplace_back(move(v));
+ }
+
+ void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) {
+ vector<uint64> v(nv.size());
+ for (int i=0; i<nv.size(); i++) v[i] = nv[i];
+ add_input_host(move(v), dtype);
+ }
+
+ // ... additional NanoVector input helpers and scalar set_attr overloads ...
+
+ void set_attr(const string& key, vector<int64> value) {
+ // LOGir << "string vector" << "set_attr" << key << value;
+ CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0);
+ }
+ void set_attr(const string& key, string value) {
+ // LOGir << "string string" << "set_attr" << key << value;
+ CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0);
+ }
+ void set_attr(const char* key, const char* value) {
+ // LOGir << "char" << "set_attr" << key << value;
+ CHECK(aclopSetAttrString(attr, key, value)==0);
+ }
+
+ void run() {
+ // printDeviceData(input_desc, input_data, name);
+
+ LOGv << "run" << name << input_desc.size() << output_desc.size();
+ if (!PyGILState_Check()) {
+ ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream));
+ } else {
+ int ret;
+ Py_BEGIN_ALLOW_THREADS
+ ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream);
+ Py_END_ALLOW_THREADS
+ if (ret != 0)
+ LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret;
+ }
+ ASSERT(0==aclrtSynchronizeDevice());
+
+ // printDeviceData(output_desc, output_data, name, false);
+ }
+ };
+
+ }
+ """,
+ cuda_src=f"""
+ // aclop
+ AclOpRunner op("{name}");
+ {input_code}
+ {output_code}
+ {attr_code}
+ op.run();""")
+
+
+def change_function():
+ import jittor as jt
+ from jittor import Function
+
+ class IndexACL(Function):
+
+ def __init__(self):
+ super(IndexACL, self).__init__()
+
+ def execute(self, inshape: list, dim, dtype="int32"):
+ # zeros a tensor, shape is inshape, dtype is dtype
+ dim_input = dim
+ if dim == None:
+ dim = [i for i in range(len(inshape))]
+ elif type(dim) == int:
+ dim = [dim]
+ results = []
+ for d in dim:
+ max_len = inshape[d]
+ tmp = jt.zeros(max_len, dtype=dtype)
+ result = acl_cmd(
+ "Range", [jt.Var(0), jt.Var(max_len),
+ jt.Var(1)],
+ output_dtypes=[tmp.dtype],
+ output_shapes=[tmp.shape],
+ attr={})[0]
+ broadcast_dim = []
+ for i in range(len(inshape)):
+ if i != d:
+ broadcast_dim.append(i)
+ result = jt.broadcast(result,
+ shape=inshape,
+ dims=broadcast_dim)
+ results.append(result)
+ if len(results) != 1 or dim_input == None:
+ return tuple(results)
+ else:
+ return results[0]
+
+ def grad(self, grad_output):
+ return grad_output
+
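+ # Usage sketch (illustrative, assumes jt.flags.use_acl is enabled): once
+ # change_function() below installs IndexACL, jt.index is expected to behave as
+ #
+ #     rows = jt.index([2, 3], 0)      # [[0, 0, 0], [1, 1, 1]], built via the ACL "Range" op
+ #     r, c = jt.index([2, 3])         # dim=None returns one index Var per dimension
+ #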
+ class PoolACL(Function):
+
+ def get_paddings(self):
+ pad_top = self.padding[0]
+ pad_left = self.padding[1]
+ H = self.input.shape[-2]
+ W = self.input.shape[-1]
+
+ totalH = H + 2 * self.padding[0] - self.kernel_size[0]
+ totalW = W + 2 * self.padding[1] - self.kernel_size[1]
+
+ kH = (totalH + self.stride[0] -
+ 1) // self.stride[0] + 1 if self.attr[
+ 'ceil_mode'] else totalH // self.stride[0] + 1
+ kW = (totalW + self.stride[1] -
+ 1) // self.stride[1] + 1 if self.attr[
+ 'ceil_mode'] else totalW // self.stride[1] + 1
+
+ if self.attr['ceil_mode']:
+ if (kH - 1) * self.stride[0] >= H + self.padding[0]:
+ kH -= 1
+ need_pad_h = (kH -
+ 1) * self.stride[0] + self.kernel_size[0] - H
+ pad_top = need_pad_h - self.padding[0]
+ if (kW - 1) * self.stride[1] >= W + self.padding[1]:
+ kW -= 1
+ need_pad_w = (kW -
+ 1) * self.stride[1] + self.kernel_size[1] - W
+ pad_left = need_pad_w - self.padding[1]
+
+ pads = [self.padding[0], pad_top, self.padding[1], pad_left]
+ return pads
+
+ def __init__(self,
+ kernel_size,
+ stride=None,
+ padding=0,
+ dilation=None,
+ return_indices=None,
+ ceil_mode=False,
+ count_include_pad=True,
+ op='maximum'):
+ super(PoolACL, self).__init__()
+ # set attr
+ self.kernel_size = kernel_size if isinstance(
+ kernel_size, tuple) else (kernel_size, kernel_size)
+ stride = stride if stride else kernel_size
+ self.stride = stride if isinstance(stride, tuple) else (stride,
+ stride)
+ self.padding = padding if isinstance(padding, tuple) else (padding,
+ padding)
+ dilation = dilation if dilation else 1
+ self.dilation = dilation if isinstance(
+ dilation, tuple) else (dilation, dilation)
+ attr = {}
+
+ self.return_indices = return_indices
+ self.uint16 = jt.Var(1).int32().dtype
+ self.op = op
+
+ if op == 'mean':
+ attr['exclusive'] = not count_include_pad
+ attr['global_pooling'] = False
+ attr['divisor_override_value'] = 0
+ attr['ksize'] = [
+ 1, 1, self.kernel_size[0], self.kernel_size[1]
+ ]
+ attr['strides'] = [1, 1, self.stride[0], self.stride[1]]
+ attr['ceil_mode'] = ceil_mode
+ attr['padding_mode'] = 'CALCULATED'
+ attr['data_format'] = 'NCHW'
+ elif op == 'maximum':
+ attr['ksize'] = [
+ 1, self.kernel_size[0], self.kernel_size[1], 1
+ ]
+ attr['strides'] = [1, self.stride[0], self.stride[1], 1]
+ attr['pads'] = [1, self.padding[0], self.padding[1], 1]
+ attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1]
+ # attr['ceil_mode'] = ceil_mode
+
+ self.attr = attr
+
+ def execute(self, input):
+
+ # create input
+ input_shape = input.shape
+ input_dtype = input.dtype
+
+ self.input = input
+ # create output
+ output_shape = [
+ input_shape[0], input_shape[1],
+ (input_shape[2] + 2 * self.padding[0] - self.dilation[0] *
+ (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1,
+ (input_shape[3] + 2 * self.padding[1] - self.dilation[1] *
+ (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
+ ]
+ output_dtype = input_dtype
+
+ if self.op == 'mean':
+ self.attr['pads'] = self.get_paddings()
+ result = acl_cmd("AvgPoolV2", [input],
+ output_dtypes=[output_dtype],
+ output_shapes=[output_shape],
+ attr=self.attr)
+ elif self.op == 'maximum':
+ result = acl_cmd("MaxPoolWithArgmaxV1", [input],
+ output_dtypes=[output_dtype, self.uint16],
+ output_shapes=[output_shape, output_shape],
+ attr=self.attr)
+ else:
+ raise ValueError(f'unsupported pool type: {self.op}')
+
+ if self.op == 'maximum':
+ self.index = result[1]
+
+ if self.return_indices:
+ return result[0], result[1]
+ else:
+ return result[0]
+
+ def grad(self, grad_output):
+ if self.op == 'maximum':
+ grad_input = acl_cmd("MaxPoolGradWithArgmaxV1",
+ [self.input, grad_output, self.index],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.input.shape],
+ attr=self.attr)[0]
+ elif self.op == 'mean':
+ grad_input = acl_cmd("AvgPoolV2",
+ [self.input, grad_output, self.index],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.input.shape],
+ attr=self.attr)[0]
+ else:
+ grad_input = None
+ return grad_input
+
+ class BmmACL(Function):
+
+ def __init__(self, adj_x1=False, adj_x2=False):
+ super(BmmACL, self).__init__()
+ self.adj_x1 = adj_x1
+ self.adj_x2 = adj_x2
+
+ def execute(self, x1, x2):
+ self.input = [x1, x2]
+ result = acl_cmd("BatchMatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ x1, x2 = self.input
+ grad_x1 = acl_cmd(
+ "BatchMatMul", [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd(
+ "BatchMatMul", [x1.transpose(-2, -1), grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ return grad_x1, grad_x2
+
+ class MatmulACL(Function):
+
+ def __init__(self, adj_x1=False, adj_x2=False):
+ super(MatmulACL, self).__init__()
+ self.adj_x1 = adj_x1
+ self.adj_x2 = adj_x2
+
+ def execute(self, x1, x2):
+ self.input = [x1, x2]
+ if len(x1.shape) > 2 or len(x2.shape) > 2:
+ result = acl_cmd("BatchMatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ else:
+ result = acl_cmd("MatMul", [x1, x2],
+ output_dtypes=[x1.dtype],
+ output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ x1, x2 = self.input
+ if len(x1.shape) > 2 or len(x2.shape) > 2:
+ grad_x1 = acl_cmd(
+ "BatchMatMul",
+ [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd(
+ "BatchMatMul", [x1.transpose(-2, -1), grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ else:
+ grad_x1 = acl_cmd(
+ "MatMul", [grad_output, x2.transpose(-2, -1)],
+ output_dtypes=[x1.dtype],
+ output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
+ attr={})[0]
+ grad_x2 = acl_cmd(
+ "MatMul", [x1.transpose(-2, -1), grad_output],
+ output_dtypes=[x2.dtype],
+ output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
+ attr={})[0]
+ return grad_x1, grad_x2
+
+ class GetItem(Function):
+
+ def __init__(self):
+ super(GetItem, self).__init__()
+ self.type_ = 'index'
+
+ def stride(self, x, dim):
+ stride = 1
+ for i in range(dim + 1, len(x.shape)):
+ stride *= x.shape[i]
+ return stride
+
+ def execute(self, x, slices, return_x=None):
+ if isinstance(slices, jt.Var) or isinstance(slices, tuple):
+ if isinstance(slices, jt.Var):
+ slices = (slices, )
+ if isinstance(slices[0], jt.Var):
+ slices_len = len(slices)
+ masks = jt.ones(slices_len, dtype=jt.int64)
+ output = slices[0].shape
+ output += x.shape[slices_len:]
+ input_ = [x, masks, jt.Var(list(output)).int64()]
+ for i in range(slices_len):
+ input_.append(slices[i].int32())
+ result = acl_cmd("Index",
+ input_,
+ output_dtypes=[x.dtype],
+ output_shapes=[output],
+ attr={})[0]
+ self.shape = x.shape
+ self.sizes = list(output)
+ self.type_ = 'index'
+ self.slices = slices
+ # self.strides
+ return result
+
+ # use AsStrided operator to implement the getitem function
+ # get the shape and stride of the input tensor
+ x_dim = len(x.shape)
+ # int type
+ if not isinstance(slices, tuple):
+ slices = (slices, )
+
+ if len(slices) < x_dim:
+ slices += (slice(None, None, None), ) * (x_dim - len(slices))
+
+ self.inputs = [x, slices]
+
+ sizes = []
+ strides = []
+ offset = 0
+
+ for dim, s in enumerate(slices):
+ if isinstance(s, int):
+ if s < 0: # Handle negative indices.
+ s += x.shape[dim]
+ offset += s * self.stride(x, dim)
+ elif isinstance(s, slice):
+ # Unpack the slice
+ start, stop, step = s.indices(x.size(dim))
+ size = (stop - start - 1) // step + 1
+ stride = self.stride(x, dim) * step
+ offset += start * self.stride(x, dim)
+ sizes.append(size)
+ strides.append(stride)
+ else:
+ raise ValueError("Invalid slice type")
+
+ if not sizes:
+ sizes = [1]
+ strides = [0]
+ # AsStrided same with as_strided of pytorch
+ self.sizes = sizes
+ self.strides = strides
+ self.offset = offset
+ self.shape = x.shape
+ self.type_ = 'as_strided'
+ result = acl_cmd(
+ "AsStrided",
+ [x, jt.Var(sizes),
+ jt.Var(strides),
+ jt.Var(offset)],
+ output_dtypes=[x.dtype],
+ output_shapes=[jt.empty(sizes).shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ if self.type_ == 'as_strided':
+ result = jt.zeros(self.shape, dtype=grad_output.dtype)
+ sizes = list(grad_output.shape)
+ strides = [
+ self.stride(grad_output, dim)
+ for dim in range(len(grad_output.shape))
+ ]
+ result = acl_cmd("ViewCopy", [
+ result,
+ jt.Var(self.sizes),
+ jt.Var(self.strides),
+ jt.Var(self.offset), grad_output,
+ jt.Var(sizes),
+ jt.Var(strides),
+ jt.Var(0)
+ ],
+ output_dtypes=[result.dtype],
+ output_shapes=[result.shape],
+ attr={})[0]
+ elif self.type_ == 'index':
+ #TODO: use IndexPutV2 to implement the grad function
+ assert len(self.slices) == 1
+ index = self.slices[0]
+ input = jt.zeros(self.shape, dtype=grad_output.dtype)
+ input_flatten = input.reshape(input.shape[0], -1)
+ index_flatten = index.reshape(-1).unsqueeze(-1).repeat(
+ 1, input_flatten.shape[1])
+ grad_output_flatten = grad_output.reshape(index.numel(), -1)
+ result = acl_cmd(
+ "ScatterElements",
+ [input_flatten, index_flatten, grad_output_flatten],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={
+ 'axis': 0,
+ 'reduction': 'add'
+ })[0]
+ result = result.reshape(self.shape)
+ # result = jt.zeros(self.shape, dtype=grad_output.dtype)
+ # # masks = jt.ones(len(self.slices), dtype=jt.int64)
+ # masks = jt.array([1,1], dtype=jt.int64)
+ # expand_masks = jt.array([1,1], dtype=jt.int64)
+ # inputs_ = [result,grad_output,masks,expand_masks]
+ # slices_len = len(self.slices)
+ # for i in range(slices_len):
+ # inputs_.append(self.slices[i].int64())
+ # # breakpoint()
+ # jt.sync_all(True)
+ # print(inputs_)
+ # result_ = acl_cmd("IndexPutV2", inputs_,
+ # output_dtypes=[result.dtype],
+ # output_shapes=[result.shape],
+ # attr={"accumulate":True})[0]
+ # result = result_
+ else:
+ raise ValueError("Invalid slice type")
+ result.sync()
+ return result, None
+
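+ # Worked example (illustrative) of the AsStrided mapping computed in
+ # GetItem.execute above, for a contiguous x of shape (4, 6) and x[1:3, ::2]:
+ #   stride(x, 0) = 6, stride(x, 1) = 1
+ #   dim 0: start=1, stop=3, step=1 -> size=2, stride=6, offset += 1*6
+ #   dim 1: start=0, stop=6, step=2 -> size=3, stride=2
+ # so "AsStrided" is invoked with sizes=[2, 3], strides=[6, 2], offset=6.
+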
+ class ConcatACL(Function):
+
+ def __init__(self):
+ super(ConcatACL, self).__init__()
+
+ def execute(self, input_tensors, dim=0):
+ self.input = input_tensors
+ self.dim = dim
+ for i in range(len(input_tensors)):
+ if input_tensors[i].dtype != input_tensors[0].dtype:
+ raise ValueError(
+ "All input tensors must have the same dtype")
+ if input_tensors[i].shape[:dim] != input_tensors[
+ 0].shape[:dim] or input_tensors[i].shape[
+ dim + 1:] != input_tensors[0].shape[dim + 1:]:
+ raise ValueError(
+ "All input tensors must have the same shape")
+ result = acl_cmd(
+ "ConcatD",
+ input_tensors,
+ output_dtypes=[input_tensors[0].dtype],
+ output_shapes=[
+ jt.empty(self.calculate_output_shape(input_tensors,
+ dim)).shape
+ ],
+ attr={
+ "N": len(input_tensors),
+ "concat_dim": dim
+ })[0]
+ return result
+
+ def grad(self, grad_output):
+ grad_inputs = self.split_grad(grad_output, self.input, self.dim)
+ return grad_inputs
+
+ def calculate_output_shape(self, input_tensors, axis):
+ shape = list(input_tensors[0].shape)
+ for tensor in input_tensors[1:]:
+ shape[axis] += tensor.shape[axis]
+ return tuple(shape)
+
+ def split_grad(self, grad_output, input_tensors, axis):
+ offset = 0
+ grad_inputs = []
+ for tensor in input_tensors:
+ grad_input = acl_cmd("Slice", [
+ grad_output, [0] * axis + [offset] + [0] *
+ (len(tensor.shape) - axis - 1), tensor.shape
+ ])
+ grad_inputs.append(grad_input)
+ offset += tensor.shape[axis]
+ return grad_inputs
+
+ class SetItemACL(Function):
+
+ def __init__(self):
+ super(SetItemACL, self).__init__()
+
+ def stride(self, x, dim):
+ # compute the stride of the given dimension
+ stride = 1
+ for i in range(dim + 1, len(x.shape)):
+ stride *= x.shape[i]
+ return stride
+
+ def execute(self, x, slices, value, reduce='void'):
+ self.is_tensor = type(value) == jt.Var
+ if type(value) != jt.Var:
+ value = jt.array(value)
+ x_dim = len(x.shape)
+
+ # make sure slices is a tuple
+ if not isinstance(slices, tuple):
+ slices = (slices, )
+
+ # pad slices so its length matches the number of dimensions of x
+ if len(slices) < x_dim:
+ slices += (slice(None, None, None), ) * (x_dim - len(slices))
+
+ self.inputs = [x, slices, value]
+
+ target_sizes = []
+ target_strides = []
+ offset = 0
+
+ for dim, s in enumerate(slices):
+ if isinstance(s, int):
+ if s < 0:
+ s += x.shape[dim]
+ s = slice(s, s + 1, None)
+ if isinstance(s, slice):
+ # unpack the slice
+ start, stop, step = s.indices(x.shape[dim])
+ size = (stop - start - 1) // step + 1
+ stride = self.stride(x, dim) * step
+ offset += start * self.stride(x, dim)
+ target_sizes.append(size)
+ target_strides.append(stride)
+ else:
+ print("slices: ", s, type(s))
+ raise ValueError("Invalid slice type")
+
+ # compute the size, stride and offset of value
+ value_sizes = list(value.shape)
+ value_strides = [
+ self.stride(value, dim) for dim in range(len(value.shape))
+ ]
+
+ self.target_sizes = target_sizes
+ self.target_strides = target_strides
+ self.offset = offset
+ self.value_sizes = value_sizes
+ self.value_strides = value_strides
+
+ #import pdb; pdb.set_trace()
+ result = acl_cmd("ViewCopy", [
+ x,
+ jt.Var(target_sizes),
+ jt.Var(target_strides),
+ jt.Var(offset), value,
+ jt.Var(value_sizes),
+ jt.Var(value_strides),
+ jt.Var(0)
+ ],
+ output_dtypes=[x.dtype],
+ output_shapes=[x.shape],
+ attr={})[0]
+ result.sync()
+ return result
+
+ def grad(self, grad_output):
+ result = acl_cmd("AsStrided", [
+ grad_output,
+ jt.Var(self.target_sizes),
+ jt.Var(self.target_strides),
+ jt.Var(self.offset)
+ ],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[jt.empty(self.target_sizes).shape],
+ attr={})[0]
+ # copy grad_output to new_grad_output
+ new_grad_output = acl_cmd("Copy", [grad_output],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={"N": 1})[0]
+ new_grad_output = acl_cmd("ViewCopy", [
+ new_grad_output,
+ jt.Var(self.target_sizes),
+ jt.Var(self.target_strides),
+ jt.Var(self.offset),
+ jt.zeros(self.value_sizes, dtype=grad_output.dtype),
+ jt.Var(self.value_sizes),
+ jt.Var(self.value_strides),
+ jt.Var(0)
+ ],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ new_grad_output.sync()
+ return new_grad_output, None, result if self.is_tensor else None
+
+ class TriuACL(Function):
+
+ def __init__(self):
+ super(TriuACL, self).__init__()
+
+ def execute(self, input, k):
+ self.input = input
+ result = acl_cmd("Triu", [input],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={'diagonal': k})[0]
+ return result
+
+ def grad(self, grad_output):
+ return grad_output
+
+ class TransposeACL(Function):
+
+ def __init__(self):
+ super(TransposeACL, self).__init__()
+
+ def execute(self, input, perm):
+ self.input = input
+
+ output_shape = input.shape[perm[0]:perm[0] + 1]
+
+ for i in range(1, len(perm)):
+ output_shape += input.shape[perm[i]:perm[i] + 1]
+ result = acl_cmd("Transpose", [input, jt.Var(perm)],
+ output_dtypes=[input.dtype],
+ output_shapes=[output_shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ return grad_output
+
+ class AdaptiveMaxPool2dACL(Function):
+
+ def __init__(
+ self,
+ output_size,
+ return_indices=False,
+ ):
+ super(AdaptiveMaxPool2dACL, self).__init__()
+ self.output_size = (output_size, output_size) if isinstance(
+ output_size, int) else output_size
+
+ self.return_indices = return_indices
+ self.uint16 = jt.Var(1).int32().dtype
+
+ attr = {}
+ attr['ceil_mode'] = False
+ attr['dilations'] = [1, 1, 1, 1]
+ self.attr = attr
+
+ def execute(self, input):
+ input_shape = input.shape
+ input_dtype = input.dtype
+
+ output_shape = [
+ input_shape[0], input_shape[1], self.output_size[0],
+ self.output_size[1]
+ ]
+ output_dtype = input_dtype
+ self.input = input
+
+ stride_h = input_shape[2] // output_shape[2]
+ stride_w = input_shape[3] // output_shape[3]
+ kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h
+ kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w
+
+ stride = [0, 0]
+ kernel_size = [0, 0]
+ padding = [0, 0]
+
+ stride[0] = stride_h
+ stride[1] = stride_w
+ kernel_size[0] = kernel_size_h
+ kernel_size[1] = kernel_size_w
+ padding[0] = padding[1] = 0
+ kernel_sizes = [1, kernel_size[0], kernel_size[1], 1]
+ strides_size = [1, stride[0], stride[1], 1]
+ paddings = [1, padding[0], padding[1], 1]
+
+ self.attr['ksize'] = kernel_sizes
+ self.attr['strides'] = strides_size
+ self.attr['pads'] = paddings
+
+ result = acl_cmd("MaxPoolWithArgmaxV1", [input],
+ output_dtypes=[output_dtype, self.uint16],
+ output_shapes=[output_shape, output_shape],
+ attr=self.attr)
+
+ self.index = result[1]
+
+ if self.return_indices:
+ return result[0], result[1]
+ else:
+ return result[0]
+
+ def grad(self, grad_output):
+ grad_input = acl_cmd("MaxPoolGradWithArgmaxV1",
+ [self.input, grad_output, self.index],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.input.shape],
+ attr=self.attr)[0]
+ return grad_input
+
+ class AdaptiveAvgPool2dACL(Function):
+
+ def __init__(self, output_size):
+ super(AdaptiveAvgPool2dACL, self).__init__()
+ self.output_size = (output_size, output_size) if isinstance(
+ output_size, int) else output_size
+
+ attr = {}
+ if isinstance(output_size, tuple):
+ output_size = [output_size[0], output_size[1]]
+ attr['output_size'] = output_size
+ self.attr = attr
+
+ def execute(self, input):
+ input_shape = input.shape
+ input_dtype = input.dtype
+
+ self.original_shape = input_shape
+
+ output_shape = [
+ input_shape[0], input_shape[1], self.attr['output_size'][0],
+ self.attr['output_size'][1]
+ ]
+ output_dtype = input_dtype
+ self.input = input
+
+ result = acl_cmd("AdaptiveAvgPool2d", [input],
+ output_dtypes=[output_dtype],
+ output_shapes=[output_shape],
+ attr=self.attr)
+
+ return result[0]
+
+ def grad(self, grad_output):
+ attr = {}
+ attr['orig_input_shape'] = list(self.original_shape)
+ grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.original_shape],
+ attr=attr)[0]
+ return grad_input
+
+ class CumsumACL(Function):
+
+ def __init__(self):
+ super(CumsumACL, self).__init__()
+
+ def execute(self, input, dim=-1):
+ self.input = input
+ self.dim = dim
+ result = acl_cmd("Cumsum", [input, jt.Var(dim)],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ flipped_grad_output = acl_cmd(
+ "ReverseV2", [grad_output, jt.Var([self.dim])],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ cumulative_grad = acl_cmd(
+ "Cumsum",
+ [flipped_grad_output, jt.Var(self.dim)],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ grad_input = acl_cmd(
+ "ReverseV2",
+ [cumulative_grad, jt.Var([self.dim])],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ return grad_input
+
+ class GatherACL(Function):
+
+ def __init__(self):
+ super(GatherACL, self).__init__()
+
+ def execute(self, input, dim, index):
+ self.input = input
+ self.dim = dim
+ self.index = index
+
+ result = acl_cmd("GatherElements", [input, index],
+ output_dtypes=[input.dtype],
+ output_shapes=[index.shape],
+ attr={'dim': dim})[0]
+ return result
+
+ def grad(self, grad_output):
+ tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
+ grad_input = acl_cmd("ScatterElements",
+ [tmp, self.index, grad_output],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[tmp.shape],
+ attr={
+ 'axis': self.dim,
+ 'reduction': "add"
+ })[0]
+ return grad_input
+
+ class ScatterACL(Function):
+
+ def __init__(self):
+ super(ScatterACL, self).__init__()
+
+ def execute(self, input, dim, index, src, reduce='void'):
+ self.input = input
+ self.dim = dim
+ self.index = index
+ self.reduce = reduce
+ result = acl_cmd("ScatterElements", [input, self.index, src],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={
+ 'axis': self.dim,
+ 'reduction': reduce
+ })[0]
+ return result
+
+ def grad(self, grad_output):
+ grad_input = acl_cmd("GatherElements", [grad_output, self.index],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[self.index.shape],
+ attr={'dim': self.dim})[0]
+ return grad_output, None, None, grad_input
+
+ class WhereACL(Function):
+
+ def __init__(self):
+ super(WhereACL, self).__init__()
+
+ def execute(self, condition, x, y):
+ self.condition = condition
+
+ if x.dtype != y.dtype:
+ if x.dtype == jt.float32:
+ y = y.float32()
+ elif y.dtype == jt.float32:
+ x = x.float32()
+ else:
+ x = x.to(y.dtype)
+
+ self.x = x
+ self.y = y
+
+ result = acl_cmd("Select", [condition, x, y],
+ output_dtypes=[x.dtype],
+ output_shapes=[x.shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
+ grad_x = acl_cmd("Select", [self.condition, grad_output, tmp],
+ output_dtypes=[self.x.dtype],
+ output_shapes=[self.x.shape],
+ attr={})[0]
+
+ grad_y = acl_cmd("Select", [self.condition, tmp, grad_output],
+ output_dtypes=[self.y.dtype],
+ output_shapes=[self.y.shape],
+ attr={})[0]
+ return grad_output, grad_x, grad_y
+
+ class FlipACL(Function):
+
+ def __init__(self):
+ super(FlipACL, self).__init__()
+
+ def execute(self, input, dim):
+ self.input = input
+ #if isinstance(dim_vector, tuple):
+ dim_vector = jt.Var(list(dim))
+ #print(dim_vector.dtype)
+ self.dim_vector = dim_vector
+ #print(input, dim_vector)
+ result = acl_cmd("ReverseV2", [input, dim_vector],
+ output_dtypes=[input.dtype],
+ output_shapes=[input.shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ #print(grad_output)
+ grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector],
+ output_dtypes=[grad_output.dtype],
+ output_shapes=[grad_output.shape],
+ attr={})[0]
+ return grad_input
+
+ class FloorIntACL(Function):
+
+ def __init__(self):
+ super(FloorIntACL, self).__init__()
+
+ def execute(self, input):
+ self.input = input
+ self.shape = input.shape
+ result = acl_cmd("Floor", [input],
+ output_dtypes=[jt.int],
+ output_shapes=[input.shape],
+ attr={})[0]
+ return result
+
+ def grad(self, grad_output):
+ return jt.zeros(self.shape, dtype=grad_output.dtype)
+
+ def warp(origin_func, new_func):
+
+ def warpper(*args, **kwargs):
+ if origin_func == jt.index:
+ if len(args) == 2 and args[1] == None:
+ args = tuple(list(args[0:1]))
+ if jt.flags.use_acl:
+ if isinstance(new_func, IndexACL):
+ if len(args) == 1:
+ args = (args[0], None)
+ if isinstance(new_func, CumsumACL):
+ args = (args[0], kwargs.get('dim', -1))
+ kwargs = {}
+ if isinstance(new_func,
+ ScatterACL) and kwargs.get('reduce') is not None:
+ args = (args[0], args[1], args[2], args[3],
+ kwargs.get('reduce', 'void'))
+ kwargs = {}
+
+ return new_func(*args, **kwargs)
+ return origin_func(*args, **kwargs)
+
+ return warpper
+
+ jt.index = warp(jt.index, IndexACL())
+ jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim)
+ jt.nn.Pool = warp(jt.nn.Pool, PoolACL)
+ jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d,
+ AdaptiveMaxPool2dACL)
+ jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d,
+ AdaptiveAvgPool2dACL)
+
+ jt.triu = warp(jt.triu, TriuACL())
+ jt.triu_ = warp(jt.triu, TriuACL())
+ jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x)
+ jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x)
+
+ jt.getitem = warp(jt.getitem, GetItem())
+ jt.Var.getitem = lambda x, slices, return_x=None: warp(
+ jt.getitem, GetItem())(x, slices)
+
+ jt.setitem = warp(jt.setitem, SetItemACL())
+ jt.Var.setitem = lambda x, slices, value, reduce='void': warp(
+ jt.setitem, SetItemACL())(x, slices, value, reduce)
+
+ jt.misc.flip = warp(jt.misc.flip, FlipACL())
+ jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())(
+ x, dim_vector)
+ jt.cumsum = warp(jt.cumsum, CumsumACL())
+ jt.gather = warp(jt.gather, GatherACL())
+ jt.Var.gather = lambda x, dim, index: warp(jt.gather, GatherACL())(x, dim,
+ index)
+ jt.scatter = warp(jt.scatter, ScatterACL())
+ jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp(
+ jt.scatter, ScatterACL())(x, dim, index, src, reduce)
+ jt.where = warp(jt.where, WhereACL())
+ jt.floor_int = warp(jt.floor_int, FloorIntACL())
+ jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x)
+
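+ # Dispatch sketch (illustrative): the warp() wrappers above only route calls to
+ # the ACL implementations when jt.flags.use_acl is set, e.g.
+ #
+ #     jt.flags.use_acl = 1
+ #     x = jt.rand(2, 3)
+ #     y = jt.gather(x, 1, jt.zeros((2, 3), dtype="int32"))   # -> GatherACL
+ #     z = jt.cumsum(x, dim=1)                                # -> CumsumACL
+ #
+ # otherwise the original jittor implementations are called unchanged.
+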
+ # jt.nn.bmm = warp(jt.nn.bmm, BmmACL())
+ # jt.bmm = warp(jt.bmm, BmmACL())
+ # jt.nn.matmul = warp(jt.matmul, MatmulACL())
+ # jt.matmul = warp(jt.matmul, MatmulACL())
+ # jt.transpose = warp(jt.transpose, TransposeACL())
+ # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm)
+ # jt.concat = warp(jt.concat, ConcatACL())
diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc
index dc99887f..be0c17f4 100644
--- a/python/jittor/extern/acl/acl_jittor.cc
+++ b/python/jittor/extern/acl/acl_jittor.cc
@@ -16,6 +16,7 @@ namespace jittor {
uint64_t acl_jittor_tid;
int acl_jittor_thread_running=0;
aclrtContext acl_jittor_context;
+aclrtStream aclstream;
#define CHECK_ACL(x) ASSERTop(x,==,0)
@@ -28,7 +29,7 @@ static void* acl_jittor_process_callback(void*) {
// LOGir << "acl_jittor_process_callback";
auto ret = aclrtProcessReport(1000);
if (ret) {
- if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT)
+ if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE)
LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
break;
}
@@ -59,10 +60,9 @@ acl_jittor_initer() {
CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0));
// simple callback test
- // aclrtStream stream;
- // CHECK_ACL(aclrtCreateStream(&stream));
- // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream));
- // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream));
+ CHECK_ACL(aclrtCreateStream(&aclstream));
+ // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,aclstream));
+ // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, aclstream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0));
}
@@ -87,7 +87,7 @@ string process_acl(const string& src, const string& name, const map fake_class = {
"cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t",
@@ -117,6 +117,7 @@ string process_acl(const string& src, const string& name, const map()", "-1e30");
new_src = token_replace_all(new_src, "::numeric_max()", "1e30");
// TODO: support max
diff --git a/python/jittor/extern/acl/acl_jittor.h b/python/jittor/extern/acl/acl_jittor.h
index 2fc6613b..0ef90b40 100644
--- a/python/jittor/extern/acl/acl_jittor.h
+++ b/python/jittor/extern/acl/acl_jittor.h
@@ -13,6 +13,7 @@ std::string acl_error_to_string(aclError error);
namespace jittor {
EXTERN_LIB uint64_t acl_jittor_tid;
+EXTERN_LIB aclrtStream aclstream;
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags);
diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc
index 0eb130d3..07b35145 100644
--- a/python/jittor/extern/acl/acl_op_exec.cc
+++ b/python/jittor/extern/acl/acl_op_exec.cc
@@ -16,7 +16,9 @@
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
#include "ops/array_op.h"
+#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
@@ -32,6 +34,44 @@ namespace jittor {
using std::swap;
+void printDeviceData(const vector<aclTensorDesc*>& output_desc, const vector<aclDataBuffer*>& output_data, const string& name = "", bool input=true) {
+ LOGir << "name: " << name;
+ if(input)
+ LOGir << "is input";
+ else
+ LOGir << "is ouput";
+ for (size_t i = 0; i < output_desc.size(); ++i) {
+ void* base_addr = aclGetDataBufferAddr(output_data[i]);
+ LOGir << "addr of data[" << i << "] :" << base_addr;
+ size_t num_dims = aclGetTensorDescNumDims(output_desc[i]);
+ size_t total_size = 1;
+ std::vector<int64_t> dims(num_dims);
+
+ std::cout << "shape of data: ";
+ for (size_t j = 0; j < num_dims; ++j) {
+ aclGetTensorDescDimV2(output_desc[i], j, &dims[j]);
+ total_size *= dims[j];
+ std::cout << dims[j] << ", ";
+ }
+ int every_batch_size = total_size/dims[0];
+ std::cout << std::endl;
+
+ // for(int i= 0; i < dims[0]; i++) {
+ // every_batch_size = 16;
+ // std::vector<float> host_buffer(every_batch_size);
+ // void* offset_addr = static_cast<char*>(base_addr) + i * every_batch_size * sizeof(float);
+ // aclrtMemcpy(host_buffer.data(), every_batch_size * sizeof(float), offset_addr, every_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+ // std::cout << "batch[" << i << "]:";
+ // for (size_t k = 0; k < every_batch_size; ++k) {
+ // std::cout << host_buffer[k] << ", ";
+ // }
+ // std::cout << std::endl;
+ // if(i >= 3)
+ // break;
+ // }
+ }
+}
+
struct AclOpRunner {
string name;
vector<aclTensorDesc*> input_desc;
@@ -40,6 +80,7 @@ struct AclOpRunner {
vector<aclDataBuffer*> output_data;
aclopAttr *attr;
vector<vector<uint64>> input_host;
+ vector<vector<int>> input_host_32;
AclOpRunner(const string& name) : name(name) {
attr = aclopCreateAttr();
@@ -56,9 +97,14 @@ struct AclOpRunner {
aclDataType get_dtype(NanoString s) {
if (s == ns_float32) return ACL_FLOAT;
if (s == ns_float16) return ACL_FLOAT16;
+ if (s == ns_int64) return ACL_INT64;
if (s == ns_int32) return ACL_INT32;
if (s == ns_int8) return ACL_INT8;
+ if (s == ns_int16) return ACL_INT16;
if (s == ns_uint8) return ACL_UINT8;
+ if (s == ns_uint16) return ACL_UINT16;
+ if (s == ns_uint32) return ACL_UINT32;
+ if (s == ns_bool) return ACL_BOOL;
LOGf << "Not supported dtype: " << s;
return ACL_FLOAT;
}
@@ -138,44 +184,63 @@ struct AclOpRunner {
auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(int));
input_desc.push_back(desc);
input_data.push_back(data);
- // input_host.emplace_back(move(v));
+ input_host_32.emplace_back(move(v));
}
void set_attr(const string& key, bool value) {
+ // LOGir << "string bool" << "set_attr" << key << value;
CHECK(aclopSetAttrBool(attr, key.c_str(), value)==0);
}
+ void set_attr(const string& key, int value, int is_bool) {
+ // LOGir << "string bool" << "set_attr" << key << value << is_bool;
+ CHECK(aclopSetAttrBool(attr, key.c_str(), value==is_bool)==0);
+ }
void set_attr(const string& key, float value) {
+ // LOGir << "string float" <<"set_attr" << key << value;
CHECK(aclopSetAttrFloat(attr, key.c_str(), value)==0);
}
void set_attr(const string& key, int64_t value) {
+ // LOGir << "string int64" << "set_attr" << key << value;
+ CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0);
+ }
+ void set_attr(const string& key, int64_t value, int placeholder) {
+ // LOGir << "string int64" << "set_attr" << key << value;
CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0);
}
void set_attr(const string& key, int32 value) {
+ // LOGir << "string int32" << "set_attr" << key << value;
CHECK(aclopSetAttrInt(attr, key.c_str(), value)==0);
}
void set_attr(const string& key, vector<int64_t> value) {
+ // LOGir << "string vector" << "set_attr" << key << value;
CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0);
}
void set_attr(const string& key, string value) {
+ // LOGir << "string string" << "set_attr" << key << value;
CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0);
}
void set_attr(const char* key, const char* value) {
+ // LOGir << "char" << "set_attr" << key << value;
CHECK(aclopSetAttrString(attr, key, value)==0);
}
void run() {
+ // printDeviceData(input_desc, input_data, name);
+
LOGv << "run" << name << input_desc.size() << output_desc.size();
if (!PyGILState_Check()) {
- ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL));
+ ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream));
} else {
int ret;
Py_BEGIN_ALLOW_THREADS
- ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, NULL);
+ ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream);
Py_END_ALLOW_THREADS
if (ret != 0)
LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret;
}
- // ASSERT(0==aclrtSynchronizeDevice());
+ ASSERT(0==aclrtSynchronizeDevice());
+
+ // printDeviceData(output_desc, output_data, name, false);
}
};
@@ -318,6 +383,17 @@ void try_exec_and_fallback_cpu(Op* op) {
auto iter = opname_map.find(bop->ns);
ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
op.name = iter->second;
+ if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
+ {
+ // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
+ if (bop->ns == ns_bitwise_or) {
+ op.name = "LogicalOr";
+ } else if (bop->ns == ns_bitwise_and) {
+ op.name = "LogicalAnd";
+ } else if (bop->ns == ns_bitwise_xor) {
+ op.name = "LogicalXor";
+ }
+ }
op.run();
} else
if (op->name() == string("ternary")) {
@@ -345,7 +421,7 @@ void try_exec_and_fallback_cpu(Op* op) {
else if (rop->ns == ns_minimum)
op.name = "ReduceMin";
else if (rop->ns == ns_mean)
- op.name = "Reduce";
+ op.name = "ReduceMean";
else
LOGf << "op " << rop->ns << " not supported";
op.add(rop->x, true);
@@ -381,6 +457,19 @@ void try_exec_and_fallback_cpu(Op* op) {
op.add_input_host_nv(zshape, ACL_INT64);
op.add(bop->z, false);
op.run();
+ }
+ else
+ if (op->name() == string("fuse_transpose")) {
+ // replace fuse_transpose with transpose
+ auto top = (TransposeOp*)op;
+ AclOpRunner op("Transpose");
+ op.add(top->x, true);
+ op.add(top->y, false);
+ vector<uint64> axes;
+ for (int i=0; i<top->axes.size(); i++)
+ axes.push_back(top->axes[i]);
+ op.add_input_host(axes, ACL_INT64);
+ op.run();
} else
{
LOGf << "op " << op->name() << " not supported";
@@ -388,6 +477,7 @@ void try_exec_and_fallback_cpu(Op* op) {
}
} catch (std::exception& e) {
fallback = 1;
+ LOGir << "fallback cpu" << e.what();
}
for (auto v : new_alloced) {
free_var_mem(v);
@@ -401,7 +491,7 @@ extern int current_seed;
extern int64 current_offset;
static unordered_map<string, std::function<void(Op*)>> acl_ops = {
-{"curand_random", [&](Op* op) {
+{"curand_random", [¤t_seed, ¤t_offset](Op* op) {
auto _op = (RandomOp*)op;
AclOpRunner runner(_op->type == ns_uniform ? "StatelessRandomUniformV2" : "StatelessRandomNormalV2");
auto out = op->output(0);
@@ -429,7 +519,21 @@ static unordered_map> acl_ops = {
runner.set_attr("transpose_x2", _op->trans_b);
runner.run();
}},
-{"cudnn_conv", [&](Op* op) {
+{"cublas_batched_matmul", [&](Op* op) {
+ struct BatchedMatmulOp : Op {
+ Var* a, *b, *c;
+ bool adj_x1, adj_x2;
+ };
+ auto _op = (BatchedMatmulOp*)op;
+ AclOpRunner runner("BatchMatMul");
+ runner.add(_op->a, true);
+ runner.add(_op->b, true);
+ runner.add(_op->c, false);
+ runner.set_attr("adj_x1", _op->adj_x1);
+ runner.set_attr("adj_x2", _op->adj_x2);
+ runner.run();
+}},
+{"cudnn_conv", [](Op* op) {
struct ConvOp : Op {
Var* x, * w, * y;
int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -451,7 +555,7 @@ static unordered_map> acl_ops = {
auto _op = (ConvOp*)op;
_op->run_acl();
}},
-{"cudnn_conv_backward_x", [&](Op* op) {
+{"cudnn_conv_backward_x", [](Op* op) {
struct ConvBackwardXOp : Op {
Var* w, * dy, * dx;
int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -480,7 +584,7 @@ static unordered_map> acl_ops = {
auto _op = (ConvBackwardXOp*)op;
_op->run_acl();
}},
-{"cudnn_conv_backward_w", [&](Op* op) {
+{"cudnn_conv_backward_w", [](Op* op) {
struct ConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
@@ -488,7 +592,7 @@ static unordered_map> acl_ops = {
void run_acl() {
AclOpRunner runner("Conv2DBackpropFilter");
runner.add(x, true, ACL_FORMAT_NCHW);
- runner.add_input_host_nv32(x->shape);
+ runner.add_input_host_nv32(dw->shape);
runner.add(dy, true, ACL_FORMAT_NCHW);
runner.add(dw, false, ACL_FORMAT_NCHW);
runner.set_attr("strides", vector{1,1,strideh,stridew});
@@ -538,7 +642,7 @@ static jit_op_entry_t acl_do_compile(Op* op) {
return oc.compile(op->get_jit_key(get_jk()), *src);
}
if (op->name() == string("fused")) {
- FusedOp* fop = (FusedOp*)op;
+ FusedOp* fop = (FusedOp*)op;
// if is a relayed op
if (fop->context->vrm.relay_groups.size()) {
LOGv << "relay fused op";
@@ -546,7 +650,17 @@ static jit_op_entry_t acl_do_compile(Op* op) {
} else {
return &try_exec_and_fallback_cpu;
}
- } else {
+ } else
+ if (op->name() == string("code")) {
+ CodeOp* cop = (CodeOp*)op;
+ if (cop->cuda_src.find("acl") != string::npos) {
+ LOGv << "compile acl op";
+ return oc.compile(op->get_jit_key(get_jk()), *src);
+ } else {
+ return &exec_mapped_acl_ops;
+ }
+ } else
+ {
LOGv << "compile finish" << op;
return &exec_mapped_acl_ops;
}
diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
index be01348f..7667e495 100644
--- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
+++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h
@@ -25,7 +25,9 @@ static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
if (dtype == ns_float64) return CUDA_R_64F;
if (dtype == ns_float16) return CUDA_R_16F;
+ #ifndef IS_ROCM
if (dtype == ns_bfloat16) return CUDA_R_16BF;
+ #endif
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}
diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
index 2109a03a..73fb3233 100644
--- a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
+++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h
@@ -8,8 +8,9 @@
#include
#include
#include
+#ifndef IS_ROCM
#include <cuda_bf16.h>
-
+#endif
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
@@ -31,6 +32,8 @@ void set_max_workspace_ratio(float64 ratio);
template <class T> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
+#ifndef IS_ROCM
template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; }
+#endif
} // jittor
diff --git a/python/jittor/gradfunctional/__init__.py b/python/jittor/gradfunctional/__init__.py
new file mode 100644
index 00000000..259897e1
--- /dev/null
+++ b/python/jittor/gradfunctional/__init__.py
@@ -0,0 +1,2 @@
+from .functional import jvp, vjp
+
diff --git a/python/jittor/gradfunctional/functional.py b/python/jittor/gradfunctional/functional.py
new file mode 100644
index 00000000..df183cc1
--- /dev/null
+++ b/python/jittor/gradfunctional/functional.py
@@ -0,0 +1,420 @@
+# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/autograd/functional.py
+import jittor as jt
+
+__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
+
+# Utility functions
+def _as_tuple_nocheck(x):
+ if isinstance(x, tuple):
+ return x
+ elif isinstance(x, list):
+ return tuple(x)
+ else:
+ return (x,)
+
+def _as_tuple(inp, arg_name=None, fn_name=None):
+ # Ensures that inp is a tuple of Tensors
+ # Returns whether or not the original inp was a tuple and the tupled version of the input
+ if arg_name is None and fn_name is None:
+ return _as_tuple_nocheck(inp)
+
+ is_inp_tuple = True
+ if not isinstance(inp, tuple):
+ inp = (inp,)
+ is_inp_tuple = False
+
+ for i, el in enumerate(inp):
+ if not isinstance(el, (jt.Var, jt.nn.ComplexNumber)):
+ if is_inp_tuple:
+ raise TypeError(
+ f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
+ f" value at index {i} has type {type(el)}."
+ )
+ else:
+ raise TypeError(
+ f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
+ f" given {arg_name} has type {type(el)}."
+ )
+
+ return is_inp_tuple, inp
+
+
+def _tuple_postprocess(res, to_unpack):
+ # Unpacks a potentially nested tuple of Tensors
+ # to_unpack should be a single boolean or a tuple of two booleans.
+ # It is used to:
+ # - invert _as_tuple when res should match the inp given to _as_tuple
+ # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
+ if isinstance(to_unpack, tuple):
+ assert len(to_unpack) == 2
+ if not to_unpack[1]:
+ res = tuple(el[0] for el in res)
+ if not to_unpack[0]:
+ res = res[0]
+ else:
+ if not to_unpack:
+ res = res[0]
+ return res
+
+
+def _grad_preprocess(inputs, create_graph, need_graph):
+ # Preprocess the inputs to make sure they require gradient
+ # inputs is a tuple of Tensors to preprocess
+ # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
+ # need_graph specifies if we internally want gradients to flow back to the Tensors in res
+ # Note that we *always* create a new Tensor object to be able to see the difference between
+ # inputs given as arguments and the same Tensors automatically captured by the user function.
+ # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
+ res = []
+ for inp in inputs:
+ if create_graph and inp.requires_grad:
+ # Create at least a new Tensor object in a differentiable way
+ # Use .reshape() to get a shallow copy
+ res.append(inp.reshape(inp.shape))
+ else:
+ if need_graph:
+ ninp = inp.detach().start_grad()
+ else:
+ ninp = inp.detach().stop_grad()
+ res.append(ninp)
+ return tuple(res)
+
+
+def _grad_postprocess(inputs, create_graph):
+ # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
+ # request it.
+ if isinstance(inputs[0], (jt.Var, jt.nn.ComplexNumber)):
+ if not create_graph:
+ return tuple(inp.detach() for inp in inputs)
+ else:
+ return inputs
+ else:
+ return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
+
+
+def _validate_v(v, other, is_other_tuple):
+ # This assumes that other is the correct shape, and v should match
+ # Both are assumed to be tuples of Tensors
+ if len(other) != len(v):
+ if is_other_tuple:
+ raise RuntimeError(
+ f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
+ )
+ else:
+ raise RuntimeError("The given v should contain a single Tensor.")
+
+ for idx, (el_v, el_other) in enumerate(zip(v, other)):
+ if el_v.shape != el_other.shape:
+ prepend = ""
+ if is_other_tuple:
+ prepend = f"Entry {idx} in "
+ raise RuntimeError(
+ f"{prepend}v has invalid size: should be {el_other.shape} but got {el_v.shape}."
+ )
+
+
+def _check_requires_grad(inputs, input_type, strict):
+ # Used to make all the necessary checks to raise nice errors in strict mode.
+ if not strict:
+ return
+
+ if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
+ raise RuntimeError("Invalid input_type to _check_requires_grad")
+ for i, inp in enumerate(inputs):
+ if inp is None:
+ # This can only be reached for grad_inputs.
+ raise RuntimeError(
+ f"The output of the user-provided function is independent of input {i}."
+ " This is not allowed in strict mode."
+ )
+ if not inp.requires_grad:
+ if input_type == "hessian":
+ raise RuntimeError(
+ f"The hessian of the user-provided function with respect to input {i}"
+ " is independent of the input. This is not allowed in strict mode."
+ " You should ensure that your function is thrice differentiable and that"
+ " the hessian depends on the inputs."
+ )
+ elif input_type == "jacobian":
+ raise RuntimeError(
+ "While computing the hessian, found that the jacobian of the user-provided"
+ f" function with respect to input {i} is independent of the input. This is not"
+ " allowed in strict mode. You should ensure that your function is twice"
+ " differentiable and that the jacobian depends on the inputs (this would be"
+ " violated by a linear function for example)."
+ )
+ elif input_type == "grad_inputs":
+ raise RuntimeError(
+ f"The gradient with respect to input {i} is independent of the inputs of the"
+ " user-provided function. This is not allowed in strict mode."
+ )
+ else:
+ raise RuntimeError(
+ f"Output {i} of the user-provided function does not require gradients."
+ " The outputs must be computed in a differentiable manner from the input"
+ " when running in strict mode."
+ )
+
+
+def _autograd_grad(
+ outputs,
+ inputs,
+ grad_outputs=None,
+ create_graph=True,
+):
+ # Version of grad that accepts `None` in outputs and does not compute gradients for them.
+ # This has the extra constraint that inputs has to be a tuple
+ assert isinstance(outputs, tuple)
+ if grad_outputs is None:
+ grad_outputs = (None,) * len(outputs)
+ assert isinstance(grad_outputs, tuple)
+ assert len(outputs) == len(grad_outputs)
+
+ new_outputs = ()
+ new_grad_outputs = ()
+ for out, grad_out in zip(outputs, grad_outputs):
+ if out is not None and out.requires_grad:
+ new_outputs += (out,)
+ new_grad_outputs += (grad_out,)
+
+ if len(new_outputs) == 0:
+ # No differentiable output, we don't need to call the autograd engine
+ return (None,) * len(inputs)
+ else:
+ acc_loss = None
+ for new_output, grad_output in zip(new_outputs, grad_outputs):
+ if isinstance(new_output, jt.nn.ComplexNumber):
+ if grad_output is not None:
+ loss = (new_output.value * grad_output.value).sum()
+ else:
+ loss = new_output.value.sum()
+ else:
+ if grad_output is not None:
+ new_output = new_output * grad_output
+ loss = new_output.sum()
+ if acc_loss is None:
+ acc_loss = loss
+ else:
+ acc_loss += loss
+
+ complex_inds = []
+ var_inputs = []
+ for idx, inp in enumerate(inputs):
+ if isinstance(inp, jt.nn.ComplexNumber):
+ var_inputs.append(inp.value)
+ complex_inds.append(idx)
+ else:
+ var_inputs.append(inp)
+
+ grads = jt.grad(acc_loss, var_inputs, retain_graph=create_graph)
+ for complex_ind in complex_inds:
+ grads[complex_ind] = jt.nn.ComplexNumber(grads[complex_ind], is_concat_value=True)
+ return tuple(grads)
+
+
+def _fill_in_zeros(grads, refs, strict, create_graph, stage):
+ # Used to detect None in the grads and depending on the flags, either replace them
+ # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
+ # strict and create graph allow us to detect when it is appropriate to raise an error
+ # stage gives us information of which backward call we consider to give good error message
+ if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
+ raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
+
+ res = ()
+ for i, grads_i in enumerate(grads):
+ if grads_i is None:
+ if strict:
+ if stage == "back":
+ raise RuntimeError(
+ "The output of the user-provided function is independent of "
+ f"input {i}. This is not allowed in strict mode."
+ )
+ elif stage == "back_trick":
+ raise RuntimeError(
+ f"The gradient with respect to the input is independent of entry {i}"
+ " in the grad_outputs when using the double backward trick to compute"
+ " forward mode gradients. This is not allowed in strict mode."
+ )
+ elif stage == "double_back":
+ raise RuntimeError(
+ "The jacobian of the user-provided function is independent of "
+ f"input {i}. This is not allowed in strict mode."
+ )
+ else:
+ raise RuntimeError(
+ "The hessian of the user-provided function is independent of "
+ f"entry {i} in the grad_jacobian. This is not allowed in strict "
+ "mode as it prevents from using the double backward trick to "
+ "replace forward mode AD."
+ )
+
+ refs_i = refs[i]
+ if isinstance(refs_i, jt.nn.ComplexNumber):
+ grads_i = jt.nn.ComplexNumber(jt.zeros_like(refs_i.value), is_concat_value=True)
+ else:
+ grads_i = jt.zeros_like(refs_i)
+ else:
+ if strict and create_graph and not grads_i.requires_grad:
+ if "double" not in stage:
+ raise RuntimeError(
+ "The jacobian of the user-provided function is independent of "
+ f"input {i}. This is not allowed in strict mode when create_graph=True."
+ )
+ else:
+ raise RuntimeError(
+ "The hessian of the user-provided function is independent of "
+ f"input {i}. This is not allowed in strict mode when create_graph=True."
+ )
+
+ res += (grads_i,)
+
+ return res
+
+
+# Public API
+
+def vjp(func, inputs, v=None, create_graph=False, strict=False):
+ r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.
+
+ Args:
+ func (function): a Python function that takes Tensor inputs and returns
+ a tuple of Tensors or a Tensor.
+ inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+ v (tuple of Tensors or Tensor): The vector for which the vector
+ Jacobian product is computed. Must be the same size as the output
+ of ``func``. This argument is optional when the output of ``func``
+ contains a single element and (if it is not provided) will be set
+ as a Tensor containing a single ``1``.
+ create_graph (bool, optional): If ``True``, both the output and result
+ will be computed in a differentiable way. Note that when ``strict``
+ is ``False``, the result can not require gradients or be
+ disconnected from the inputs. Defaults to ``False``.
+ strict (bool, optional): If ``True``, an error will be raised when we
+ detect that there exists an input such that all the outputs are
+ independent of it. If ``False``, we return a Tensor of zeros as the
+ vjp for said inputs, which is the expected mathematical value.
+ Defaults to ``False``.
+
+ Returns:
+ output (tuple): tuple with:
+ func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+ vjp (tuple of Tensors or Tensor): result of the dot product with
+ the same shape as the inputs.
+ """
+ with jt.enable_grad():
+ is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
+ inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+ outputs = func(*inputs)
+ is_outputs_tuple, outputs = _as_tuple(
+ outputs, "outputs of the user-provided function", "vjp"
+ )
+ _check_requires_grad(outputs, "outputs", strict=strict)
+
+ if v is not None:
+ _, v = _as_tuple(v, "v", "vjp")
+ v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+ _validate_v(v, outputs, is_outputs_tuple)
+ else:
+ if len(outputs) != 1 or outputs[0].nelement() != 1:
+ raise RuntimeError(
+ "The vector v can only be None if the "
+ "user-provided function returns "
+ "a single Tensor with a single element."
+ )
+
+ with jt.enable_grad():
+ grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
+ vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
+
+ # Cleanup objects and return them to the user
+ outputs = _grad_postprocess(outputs, create_graph)
+ vjp = _grad_postprocess(vjp, create_graph)
+
+ return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+ vjp, is_inputs_tuple
+ )
+
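+# Usage sketch for vjp (illustrative):
+#
+#     import jittor as jt
+#     from jittor.gradfunctional import vjp
+#
+#     def f(x):
+#         return (x ** 2).sum()
+#
+#     x = jt.array([1.0, 2.0, 3.0])
+#     out, grad = vjp(f, x)      # v may be omitted since f returns a single element
+#     # out == 14, grad == 2 * x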
+
+def jvp(func, inputs, v=None, create_graph=False, strict=False):
+ r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.
+
+ Args:
+ func (function): a Python function that takes Tensor inputs and returns
+ a tuple of Tensors or a Tensor.
+ inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+ v (tuple of Tensors or Tensor): The vector for which the Jacobian
+ vector product is computed. Must be the same size as the input of
+ ``func``. This argument is optional when the input to ``func``
+ contains a single element and (if it is not provided) will be set
+ as a Tensor containing a single ``1``.
+ create_graph (bool, optional): If ``True``, both the output and result
+ will be computed in a differentiable way. Note that when ``strict``
+ is ``False``, the result can not require gradients or be
+ disconnected from the inputs. Defaults to ``False``.
+ strict (bool, optional): If ``True``, an error will be raised when we
+ detect that there exists an input such that all the outputs are
+ independent of it. If ``False``, we return a Tensor of zeros as the
+ jvp for said inputs, which is the expected mathematical value.
+ Defaults to ``False``.
+
+ Returns:
+ output (tuple): tuple with:
+ func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+ jvp (tuple of Tensors or Tensor): result of the dot product with
+ the same shape as the output.
+
+ """
+ with jt.enable_grad():
+ is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
+ inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+ if v is not None:
+ _, v = _as_tuple(v, "v", "jvp")
+ v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+ _validate_v(v, inputs, is_inputs_tuple)
+ else:
+ if len(inputs) != 1 or inputs[0].nelement() != 1:
+ raise RuntimeError(
+ "The vector v can only be None if the input to "
+ "the user-provided function is a single Tensor "
+ "with a single element."
+ )
+
+ outputs = func(*inputs)
+ is_outputs_tuple, outputs = _as_tuple(
+ outputs, "outputs of the user-provided function", "jvp"
+ )
+ _check_requires_grad(outputs, "outputs", strict=strict)
+ # The backward is linear so the value of grad_outputs is not important as
+ # it won't appear in the double backward graph. We only need to ensure that
+ # it does not contain inf or nan.
+ grad_outputs = tuple(
+ jt.nn.ComplexNumber(jt.zeros_like(out.value), is_concat_value=True) if isinstance(out, jt.nn.ComplexNumber) else jt.zeros_like(out)
+ for out in outputs
+ )
+
+ grad_inputs = _autograd_grad(outputs, inputs, grad_outputs=grad_outputs, create_graph=True)
+ _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
+
+ if create_graph:
+ with jt.enable_grad():
+ grad_res = _autograd_grad(
+ grad_inputs, grad_outputs, v, create_graph=create_graph
+ )
+ jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
+ else:
+ grad_res = _autograd_grad(
+ grad_inputs, grad_outputs, v, create_graph=create_graph
+ )
+ jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
+
+ # Cleanup objects and return them to the user
+ outputs = _grad_postprocess(outputs, create_graph)
+ jvp = _grad_postprocess(jvp, create_graph)
+
+ return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+ jvp, is_outputs_tuple
+ )
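+
+# Usage sketch for jvp (illustrative): computes the Jacobian-vector product via the
+# double-backward trick implemented above.
+#
+#     import jittor as jt
+#     from jittor.gradfunctional import jvp
+#
+#     def f(x):
+#         return x ** 3
+#
+#     x = jt.array([1.0, 2.0])
+#     v = jt.array([1.0, 1.0])
+#     out, jv = jvp(f, x, v)
+#     # out == x ** 3, jv == 3 * x ** 2 * v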
diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py
index d601335c..2053d86c 100644
--- a/python/jittor/linalg.py
+++ b/python/jittor/linalg.py
@@ -10,6 +10,214 @@
# ***************************************************************
import jittor as jt
from functools import partial
+from .nn import ComplexNumber
+
+def complex_inv(x:ComplexNumber):
+ r"""
+ Calculate the inverse of x.
+ :param x (...,M,M):
+ :return: x^-1 (...,M,M).
+
+ TODO: Faster Implementation; Check backward.
+ """
+ assert isinstance(x, ComplexNumber), "complex_inv is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_inv"
+
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+
+ a = _stack_to_complex(data["inputs"][0])
+ m_a = data["outputs"][0]
+ t_a = np.linalg.inv(a)
+ np.copyto(m_a, _complex_to_stack(t_a))
+
+
+ def backward_code(np, data):
+ def T(x):
+ return np.conj(np.swapaxes(x, -1, -2))
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ _dot = partial(np.einsum, '...ij,...jk->...ik')
+ dout = _stack_to_complex(data["dout"])
+ out = data["outputs"][0]
+ mx = _stack_to_complex(data["f_outputs"][0])
+ t = -_dot(_dot(T(mx), dout), T(mx))
+ np.copyto(out, _complex_to_stack(t))
+
+ lmx = jt.numpy_code(
+ x.value.shape,
+ x.value.dtype,
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+
+ return ComplexNumber(lmx, is_concat_value=True)
+
+def complex_eig(x:ComplexNumber):
+ r"""
+ calculate the eigenvalues and eigenvectors of x.
+ :param x (...,M,M):
+ :return:w, v.
+ w (...,M) : the eigenvalues.
+ v (...,M,M) : normalized eigenvectors.
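+
+ Example (illustrative sketch; forward only, the backward pass is not implemented):
+
+ >>> a = ComplexNumber(jt.rand(4, 4), jt.rand(4, 4))
+ >>> w, v = complex_eig(a) # w: (4,) eigenvalues, v: (4, 4) eigenvectors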
+ """
+ assert isinstance(x, ComplexNumber), "complex_eig is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_eig"
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ w, v = data["outputs"]
+ tw, tv = np.linalg.eig(a)
+ np.copyto(w, _complex_to_stack(tw))
+ np.copyto(v, _complex_to_stack(tv))
+
+ def backward_code(np, data):
+ raise NotImplementedError
+
+ sw = x.shape[:-2] + x.shape[-1:] + (2,)
+ sv = x.value.shape
+ w, v = jt.numpy_code(
+ [sw, sv],
+ [x.value.dtype, x.value.dtype],
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ return ComplexNumber(w, is_concat_value=True), ComplexNumber(v, is_concat_value=True)
+
+def complex_qr(x):
+ r"""
+ do the qr factorization of x in the below formula:
+ x = QR where Q is a unitary matrix and R is an upper-triangular matrix.
+ :param x (...,M,M):
+ :return: q, r as the result of the qr factorization. They are both in the shape of (...,M,M).
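+
+ Example (illustrative sketch):
+
+ >>> a = ComplexNumber(jt.rand(3, 3), jt.rand(3, 3))
+ >>> q, r = complex_qr(a) # q is unitary, r is upper-triangular, and q @ r is close to a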
+ """
+ assert isinstance(x, ComplexNumber), "complex_qr is implemented for nn.ComplexNumber"
+ assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32"
+ assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_qr"
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ qr = data["outputs"][0]
+ Q, R = np.linalg.qr(a)
+ QR = np.stack([Q, R], axis=0)
+ np.copyto(qr, _complex_to_stack(QR))
+
+ def backward_code(np, data):
+ # reference: https://github.com/tencent-quantum-lab/tensorcircuit/blob/master/tensorcircuit/backends/pytorch_ops.py
+ def H(x):
+ return np.conj(np.swapaxes(x, -1, -2))
+ def _TriangularSolve(x, r):
+ return H(np.linalg.solve(r, H(x)))
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ _dot = partial(np.einsum, '...ij,...jk->...ik')
+ _diag = partial(np.einsum, '...ii->...i')
+
+ dout = data["dout"]
+ out = data["outputs"][0]
+ qr = data["f_outputs"][0]
+ dout = _stack_to_complex(dout)
+ dq, dr = dout[0], dout[1]
+ qr = _stack_to_complex(qr)
+ q, r = qr[0], qr[1]
+
+
+ qdq = _dot(H(q), dq)
+ qdq_ = qdq - H(qdq)
+ rdr = _dot(r, H(dr))
+ rdr_ = rdr - H(rdr)
+ tril = np.tril(qdq_ + rdr_)
+
+ grad_a = _dot(q, dr + _TriangularSolve(tril, r))
+ grad_b = _TriangularSolve(dq - _dot(q, qdq), r)
+ ret = grad_a + grad_b
+
+ m = rdr - H(qdq)
+ eyem = np.zeros_like(m)
+ _diag(eyem)[:] = _diag(m)
+ correction = eyem - np.real(eyem)
+ ret = ret + _TriangularSolve(_dot(q, H(correction)), r)
+
+ ret = _complex_to_stack(ret)
+ np.copyto(out,ret)
+
+ qr = jt.numpy_code(
+ (2,) + x.value.shape,
+ x.value.dtype,
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ q, r = qr[0], qr[1]
+ return ComplexNumber(q, is_concat_value=True), ComplexNumber(r, is_concat_value=True)
+
+def complex_svd(x:ComplexNumber):
+ r'''
+ calculate the Singular Value Decomposition of x. It follows the formula:
+ x = usv*
+ only full_matrices == False is supported for now, which means:
+ x's shape (...,M,N)
+ u's shape (...,M,K)
+ s's shape (...,K)
+ v's shape (...,K,N)
+ where K is min(M,N).
+ :param x:
+ :return: u, s, v.
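+
+ Example (illustrative sketch; forward only, the backward pass is not implemented):
+
+ >>> a = ComplexNumber(jt.rand(5, 3), jt.rand(5, 3))
+ >>> u, s, v = complex_svd(a) # u: (5, 3), s: (3,), v: (3, 3), with K = min(5, 3) = 3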
+ '''
+ def forward_code(np, data):
+ def _stack_to_complex(x):
+ return x[..., 0] + 1j * x[..., 1]
+ def _complex_to_stack(x):
+ return np.stack([np.real(x), np.imag(x)], axis=-1)
+ a = _stack_to_complex(data["inputs"][0])
+ u, s, v = data["outputs"]
+ #TODO:remove copyto
+ tu, ts, tv = np.linalg.svd(a, full_matrices=0)
+ np.copyto(u, _complex_to_stack(tu))
+ np.copyto(s, _complex_to_stack(ts))
+ np.copyto(v, _complex_to_stack(tv))
+
+ def backward_code(np, data):
+ raise NotImplementedError
+
+ m, n = x.shape[-2:]
+ k = min(m, n)
+ s1 = list(x.shape)
+ s1[-1] = k
+ s2 = list(x.shape)
+ s2[-2] = k
+ s3 = list(x.shape)[:-2]
+ s3.append(k)
+ s1.append(2)
+ s2.append(2)
+ s3.append(2)
+ u, s, v = jt.numpy_code(
+ [s1, s3, s2],
+ [x.value.dtype, x.value.dtype, x.value.dtype],
+ [x.value],
+ forward_code,
+ [backward_code],
+ )
+ return ComplexNumber(u, is_concat_value=True), \
+ ComplexNumber(s, is_concat_value=True), \
+ ComplexNumber(v, is_concat_value=True)
#TODO:full_matrices=1
def svd(x):
@@ -25,6 +233,8 @@ def svd(x):
:param x:
:return:u,s,v.
'''
+ if isinstance(x, ComplexNumber):
+ return complex_svd(x)
def forward_code(np, data):
a = data["inputs"][0]
u, s, v = data["outputs"]
@@ -92,6 +302,17 @@ def T(x):
)
return u, s, v
+def eig(x):
+ r"""
+ calculate the eigenvalues and eigenvectors of x.
+ :param x (...,M,M):
+ :return (ComplexNumber):w, v.
+ w (...,M) : the eigenvalues.
+ v (...,M,M) : normalized eigenvectors.
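+
+ Example (illustrative sketch):
+
+ >>> w, v = eig(jt.rand(4, 4)) # a real Var is promoted to a ComplexNumber internally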
+ """
+ if isinstance(x, ComplexNumber):
+ return complex_eig(x)
+ return complex_eig(ComplexNumber(x))
def eigh(x):
r"""
@@ -147,6 +368,8 @@ def inv(x):
:param x (...,M,M):
:return:x^-1 (...,M,M).
"""
+ if isinstance(x, ComplexNumber):
+ return complex_inv(x)
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
@@ -387,6 +610,8 @@ def qr(x):
:param x (...,M,M):
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
"""
+ if isinstance(x, ComplexNumber):
+ return complex_qr(x)
def forward_code(np, data):
a = data["inputs"][0]
q, r = data["outputs"]
diff --git a/python/jittor/misc.py b/python/jittor/misc.py
index fc528d62..8561393e 100644
--- a/python/jittor/misc.py
+++ b/python/jittor/misc.py
@@ -294,9 +294,10 @@ def expand(x, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)):
shape = shape[0]
shape = list(shape)
- for i in range(len(shape)):
- if shape[i] == -1:
- shape[i] = x.shape[i]
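+ # -1 entries take the corresponding size from x; when the requested shape has more
+ # dims than x, x's dims are matched right-aligned (mirrors PyTorch's expand semantics)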
+ offset = len(shape) - len(x.shape)
+ for i in range(len(x.shape)):
+ if shape[offset + i] == -1:
+ shape[offset + i] = x.shape[i]
return x.broadcast(shape)
jt.Var.expand = expand
@@ -622,6 +623,10 @@ def unique(
#include
#include
+ #include
+ #include
+ #include
+
#include
#include
''',
@@ -704,6 +709,11 @@ def unique(
#include
#include
#include
+
+ #include
+ #include
+ #include
+
#include
#include
@@ -881,7 +891,7 @@ def split(d, split_size, dim=0):
ans = []
last = 0
s_last = len(split_size)-1
- gopt_disable = jt.flags.gopt_disable
+ gopt_disable = jt.flags.gopt_disable or jt.flags.use_acl
for j, i in enumerate(split_size):
if i==0:
shape = list(d.shape)
@@ -922,6 +932,48 @@ def diag(x,diagonal=0):
output_shape = (x.shape[0]-d,)
return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0'])
+# reference: https://github.com/pytorch/pytorch/blob/25d5a815f74db80ef19a3f714709b55b05675245/torch/_refs/__init__.py
+def diagonal(x, offset=0, dim1=0, dim2=1):
+ def __normalize_dim(d, rank):
+ if d < 0:
+ d += rank
+ if d < 0 or d >= rank:
+ msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {d})"
+ raise IndexError(msg)
+ return d
+ assert x.ndim >= 2, f"diagonal requires an input with at least 2 dimensions, but got {x.ndim}"
+ dim1 = __normalize_dim(dim1, x.ndim)
+ dim2 = __normalize_dim(dim2, x.ndim)
+ assert dim1 != dim2, f"diagonal dimensions cannot be identical {dim1}, {dim2}"
+
+ if offset >= 0:
+ diag_size = max(min(x.shape[dim1], x.shape[dim2] - offset), 0)
+ else:
+ diag_size = max(min(x.shape[dim1] + offset, x.shape[dim2]), 0)
+
+ sizes = []
+ indices = []
+ lsizes = 0
+ dim_diag = x.ndim - 2
+ abs_offset = offset if offset >= 0 else -offset
+ for i, s in enumerate(x.shape):
+ if i == dim1:
+ if offset >= 0:
+ indices.append(f"i{dim_diag}")
+ else:
+ indices.append(f"i{dim_diag}+{abs_offset}")
+ elif i == dim2:
+ if offset >= 0:
+ indices.append(f"i{dim_diag}+{abs_offset}")
+ else:
+ indices.append(f"i{dim_diag}")
+ else:
+ indices.append(f"i{lsizes}")
+ sizes.append(s)
+ lsizes += 1
+ out_shape = tuple(sizes + [diag_size])
+ return x.reindex(out_shape, indices)
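+
+ # Usage sketch (illustrative): for x of shape (3, 4, 5),
+ # diagonal(x, offset=1, dim1=-2, dim2=-1) has shape (3, 4): the non-diagonal dim is kept
+ # and the offset-1 diagonal of each trailing 4x5 matrix is appended as the last dim.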
+
jt.Var.diag = diag
@@ -2010,6 +2062,7 @@ def contiguous(x): return x.clone()
def cpu(x): return x.clone()
jt.Var.cpu = cpu
def to(x, *args, **kargs):
+ args += tuple(kargs.values())
if len(args) >= 1:
s = args[0]
if isinstance(s, jt.NanoString) or callable(s):
@@ -2234,3 +2287,16 @@ def cuda(x):
def expm1(x):
return jt.exp(x) - 1
+
+
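+# Elementwise membership test (illustrative note): broadcasts `elements` against
+# `test_elements` and reports, for each entry of `elements`, whether it occurs in
+# `test_elements`; e.g. isin(jt.array([1, 2, 3]), jt.array([2])) -> [False, True, False]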
+def isin(elements, test_elements, assume_unique=False, invert=False):
+
+ elements = elements.unsqueeze(-1)
+ test_elements = test_elements.unsqueeze(0)
+ comparison = elements == test_elements
+ result = comparison.any(dim=-1)
+
+ if invert:
+ result = jt.logical_not(result)
+
+ return result
\ No newline at end of file
diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index a8190c79..46ca78ec 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -22,6 +22,7 @@
from jittor.optim import *
from jittor.misc import _pair, _triple
from jittor_utils import LOG
+from functools import partial
def matmul_transpose(a, b):
@@ -385,7 +386,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction=
if ignore_index is not None:
target_weight = jt.ternary(
target==ignore_index,
- jt.array(0).broadcast(target_weight),
+ jt.array(0).broadcast(target_weight).type_as(target_weight),
target_weight
)
@@ -493,6 +494,9 @@ def execute(self, output, target):
return l1_loss(output, target)
def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True):
+ if not (target.shape == output.shape):
+ raise ValueError(f"Target size ({target.shape}) must be the same as output size ({output.shape})")
+
max_val = jt.clamp(-output,min_v=0)
if pos_weight is not None:
log_weight = (pos_weight-1)*target + 1
@@ -588,6 +592,8 @@ def __init__(self, p=0.5, is_train=False):
#TODO: test model.train() to change self.is_train
def execute(self, input):
output = input
+ if (input.dim() != 4) and (input.dim() != 3):
+ raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to Dropout2d, but got input of size: {input.shape}')
shape = input.shape[:-2]
if self.p > 0 and self.is_train:
if self.p == 1:
@@ -925,6 +931,42 @@ class Conv(Module):
>>> output = conv(input)
'''
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ if in_channels <= 0:
+ raise ValueError(f"in_channels must be greater than zero, got {in_channels}")
+ if out_channels <= 0:
+ raise ValueError(f"out_channels must be greater than zero, got {out_channels}")
+ if groups <= 0:
+ raise ValueError(f"groups must must be greater than zero, got {groups}")
+ assert in_channels % groups == 0, 'in_channels must be divisible by groups'
+ assert out_channels % groups == 0, 'out_channels must be divisible by groups'
+ if isinstance(kernel_size, tuple):
+ for size in kernel_size:
+ if size <= 0:
+ raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}")
+ else:
+ if kernel_size <= 0:
+ raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}")
+ if isinstance(stride, tuple):
+ for size in stride:
+ if size <= 0:
+ raise ValueError(f"stride must be greater than zero, got {stride}")
+ else:
+ if stride <= 0:
+ raise ValueError(f"stride must be greater than zero, got {stride}")
+ if isinstance(padding, tuple):
+ for size in padding:
+ if size < 0:
+ raise ValueError(f"padding must be nonnegative, got {padding}")
+ else:
+ if padding < 0:
+ raise ValueError(f"padding must be nonnegative, got {padding}")
+ if isinstance(dilation, tuple):
+ for size in dilation:
+ if size <= 0:
+ raise ValueError(f"dilation must be greater than zero, got {dilation}")
+ else:
+ if dilation <= 0:
+ raise ValueError(f"dilation must be greater than zero, got {dilation}")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
@@ -935,8 +977,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
- assert in_channels % groups == 0, 'in_channels must be divisible by groups'
- assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw = self.kernel_size
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
@@ -1053,6 +1093,8 @@ class Conv1d(Module):
>>> output = conv(input)
'''
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ assert in_channels > 0, 'in_channels must be positive'
+ assert out_channels > 0, 'out_channels must be positive'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = (kernel_size, 1)
@@ -1061,6 +1103,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.dilation = (dilation, 1)
self.groups = groups
self.bias = bias
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
# using list to escape module dfs
@@ -1069,6 +1113,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.bias = self._conv[0].bias
def execute(self, x):
+ if x.dim() != 3:
+ raise ValueError("Input shape must be `(N, C, L)`!")
N,C,D = x.shape
assert C==self.in_channels
self._conv[0].weight = self.weight.unsqueeze(-1)
@@ -1121,6 +1167,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
self.groups = groups
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw, Kd = self.kernel_size
@@ -1144,10 +1192,14 @@ def execute(self, x):
class Conv1d_sp(Linear):
def __init__(self, inchannels, outchannels, kernel_size=1, bias=True):
+ assert inchannels > 0, 'in_channels must be positive'
+ assert outchannels > 0, 'out_channels must be positive'
super().__init__(inchannels, outchannels, bias=bias)
assert kernel_size == 1
def execute(self, x):
+ if x.dim() != 3:
+ raise ValueError("Input shape must be `(N, C, L)`!")
x = x.transpose(0, 2, 1)
x = super().execute(x)
x = x.transpose(0, 2, 1)
@@ -1187,7 +1239,8 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
-
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
if groups == 1:
N,C,H,W = x.shape
Kh, Kw = weight.shape[-2:]
@@ -1276,7 +1329,8 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
stride = _triple(stride)
dilation = _triple(dilation)
out_channels = weight.shape[0]
-
+ if groups <= 0:
+ raise ValueError("groups must be a positive integer")
if jt.flags.use_cuda and jt.cudnn:
y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
elif groups == 1:
@@ -1352,6 +1406,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
+ assert self.stride[0] > 0 and self.stride[1] > 0,"stride must be positive"
+ assert self.padding[0] >= 0 and self.padding[1] >= 0,"padding must be non-negative"
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
@@ -1369,6 +1425,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
self.bias = None
def execute(self, x):
+ if x.dim() != 4:
+ raise RuntimeError(f'Expected 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}')
if self.groups == 1:
N,C,H,W = x.shape
i,o,h,w = self.weight.shape
@@ -1476,10 +1534,14 @@ def execute(self, x):
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
if groups == 1:
x = input
+ if x.dim() != 4:
+ raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {x.shape}')
N,C,H,W = x.shape
i,o,h,w = weight.shape
assert C==i
stride = stride if isinstance(stride, tuple) else (stride, stride)
+ if stride[0] <= 0 or stride[1] <= 0:
+ raise RuntimeError("non-positive stride is not supported")
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
# added
padding = padding if isinstance(padding, tuple) else (padding, padding)
@@ -1511,6 +1573,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
assert not bias, "Bias should be none or jittor var"
return y
else:
+ if input.dim() != 4:
+ raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {input.shape}')
N,C,H,W = input.shape
i,o,h,w = weight.shape
G = groups
@@ -1519,6 +1583,8 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
assert C % G == 0
assert C==i, (C, i)
stride = stride if isinstance(stride, tuple) else (stride, stride)
+ if stride[0] <= 0 or stride[1] <= 0:
+ raise RuntimeError("non-positive stride is not supported")
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
# added
padding = padding if isinstance(padding, tuple) else (padding, padding)
@@ -1562,11 +1628,15 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x = input
+ if x.dim() != 5:
+ raise RuntimeError(f'Expected 5D input to conv_transpose3d, but got input of size: {x.shape}')
N,C,D,H,W = x.shape
i,o,d,h,w = weight.shape
assert C==i
assert groups==1, "Group conv not supported yet."
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
+ raise RuntimeError("non-positive stride is not supported")
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
# added
padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
@@ -1631,6 +1701,8 @@ def pad(x,padding, mode='constant', value=0):
class ReflectionPad2d(Module):
def __init__(self, padding):
+ if isinstance(padding, int) and padding < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
@@ -1641,6 +1713,8 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
n,c,h,w = x.shape
@@ -1669,8 +1743,12 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
+ if x.dim() != 4:
+ raise RuntimeError("Input shape must be `(N, C, H, W)`!")
n,c,h,w = x.shape
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"])
@@ -1687,6 +1765,8 @@ def __init__(self, padding, value):
else:
raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}")
self.value = value
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
assert len(x.shape) >= 2
@@ -1701,6 +1781,8 @@ def execute(self, x):
class ReplicationPad2d(Module):
def __init__(self, padding):
+ if isinstance(padding, int) and padding < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
@@ -1711,8 +1793,12 @@ def __init__(self, padding):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}")
+ if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0:
+ raise ValueError(f"padding must be non-negative")
def execute(self, x):
+ if x.dim() != 4:
+ raise RuntimeError("Input shape must be `(N, C, H, W)`!")
n,c,h,w = x.shape
oh=h+self.pt+self.pb
ow=w+self.pl+self.pr
@@ -1760,13 +1846,14 @@ def embedding(input, weight):
class PixelShuffle(Module):
def __init__(self, upscale_factor):
+ assert upscale_factor > 0, f"upscale_factor must be greater than zero, got {upscale_factor}"
self.upscale_factor = upscale_factor
def execute(self, x):
n,c,h,w = x.shape
r = self.upscale_factor
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
- if r<0:
+ if r<=0:
raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}")
return x.reindex([n,int(c/r**2),h*r,w*r], [
"i0",
@@ -1815,6 +1902,15 @@ def execute(self, x):
class Resize(Module):
def __init__(self, size, mode="nearest", align_corners=False):
super().__init__()
+ if isinstance(size,int):
+ if size <= 0:
+ raise ValueError(f"sizes must be positive, got {size}")
+ elif isinstance(size,tuple) or isinstance(size,list):
+ for item in size:
+ if item <= 0:
+ raise ValueError(f"sizes must be positive, got {item}")
+ else:
+ raise ValueError(f"size must be int or tuple")
self.size = size
self.mode = mode
self.align_corners = align_corners
@@ -1869,8 +1965,12 @@ def _interpolate(img, x, y, ids, mode):
# TODO: tf_mode to another function
def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
+ if img.dim() != 4:
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
n, c, h, w = img.shape
H, W = size
+ if h <= 0 or w <= 0 or H <= 0 or W <= 0:
+ raise RuntimeError(f"Input and output sizes should be greater than 0, but got input (H: {h}, W: {w}) output (H: {H}, W: {W})")
nid, cid, hid, wid = jt.index((n, c, H, W))
if align_corners:
x = hid * ((h - 1) / max(1, H - 1))
@@ -2159,16 +2259,30 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner
class Upsample(Module):
- def __init__(self, scale_factor=None, mode='nearest'):
- self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
+ def __init__(self, scale_factor=None, mode='nearest', align_corners=False):
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
-
+ self.align_corners = align_corners
+
def execute(self, x):
- return upsample(x,
- size=(
- int(x.shape[2]*self.scale_factor[0]),
- int(x.shape[3]*self.scale_factor[1])),
- mode=self.mode)
+ if self.scale_factor is None:
+ raise ValueError("scale_factor should be defined")
+ elif isinstance(self.scale_factor, float):
+ return upsample(x,
+ size=(int(x.shape[2]*self.scale_factor),
+ int(x.shape[3]*self.scale_factor)),
+ mode=self.mode,
+ align_corners=self.align_corners)
+ else:
+ return upsample(x,
+ size=(
+ int(x.shape[2]*self.scale_factor[0]),
+ int(x.shape[3]*self.scale_factor[1])),
+ mode=self.mode,
+ align_corners=self.align_corners)
class UpsamplingBilinear2d(Upsample):
def __init__(self, scale_factor=None):
@@ -2234,11 +2348,20 @@ def __len__(self):
def named_children(self,):
return list(self.layers.items())
+
+ def __setattr__(self, key, value) -> None:
+ if isinstance(key, str) and key.isdigit():
+ if int(key) 0 and kernel_size[1] > 0, "kernel size must be positive"
if not isinstance(dilation, tuple):
dilation = (dilation, dilation)
+ assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive"
if not isinstance(padding, tuple):
padding = (padding, padding)
+ assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative"
if not isinstance(stride, tuple):
stride = (stride, stride)
+ assert stride[0] > 0 and stride[1] > 0, "stride must be positive"
n, c, h, w = X.shape
shape = X.shape
area = kernel_size[0] * kernel_size[1]
@@ -2351,14 +2478,19 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
assert X.ndim==3
+ assert output_size[0] > 0 and output_size[1] > 0, "output size must be positive."
if not isinstance(kernel_size,tuple):
kernel_size = (kernel_size,kernel_size)
+ assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive"
if not isinstance(dilation,tuple):
dilation = (dilation,dilation)
+ assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive"
if not isinstance(padding,tuple):
padding = (padding,padding)
+ assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative"
if not isinstance(stride,tuple):
stride = (stride,stride)
+ assert stride[0] > 0 and stride[1] > 0, "stride must be positive"
n,cl,num = X.shape
area = kernel_size[0] * kernel_size[1]
block_nums = []
@@ -2923,6 +3055,10 @@ def call_rnn_cell(self, input, hidden, suffix):
return h, h
def bilinear(in1, in2, weight, bias):
+ if weight.shape[1] != in1.shape[1]:
+ raise RuntimeError(f"bilinear(): input1 size does not match weight size: got {in1.shape[1]} but expected {weight.shape[1]}")
+ if weight.shape[2] != in2.shape[1]:
+ raise RuntimeError(f"bilinear(): input2 size does not match weight size: got {in2.shape[1]} but expected {weight.shape[2]}")
w = weight.transpose((1,0,2))
w = w.reshape((w.shape[0], -1))
x = jt.matmul(in1, w)
@@ -3010,6 +3146,10 @@ def __init__(self, real: jt.Var, imag: jt.Var=None, is_concat_value=False):
assert real.dtype == imag.dtype
self.value = jt.stack([real, imag], dim=-1)
+ @property
+ def requires_grad(self):
+ return self.value.requires_grad
+
@property
def real(self):
return self.value[..., 0]
@@ -3022,6 +3162,10 @@ def imag(self):
def shape(self):
return self.value.shape[:-1]
+ @property
+ def dtype(self):
+ return "complex64"
+
def norm(self):
return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag))
@@ -3049,6 +3193,12 @@ def permute(self, *axes):
def transpose(self, *axes):
return ComplexNumber(jt.transpose(self.real, *axes), jt.transpose(self.imag, *axes))
+ def broadcast(self, shape, dims):
+ return ComplexNumber(self.real.broadcast(shape, dims), self.imag.broadcast(shape, dims))
+
+ def sum(self, dims, keepdims: bool=False):
+ return ComplexNumber(self.real.sum(dims, keepdims=keepdims), self.imag.sum(dims, keepdims=keepdims))
+
def exp(self):
er = jt.exp(self.real)
return ComplexNumber(er * jt.cos(self.imag), er * jt.sin(self.imag))
@@ -3059,7 +3209,7 @@ def conj(self):
def __add__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real + other.real, self.imag + other.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(self.real + other, self.imag)
else:
raise NotImplementedError
@@ -3067,7 +3217,7 @@ def __add__(self, other):
def __radd__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real + self.real, other.imag + self.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other + self.real, self.imag)
else:
raise NotImplementedError
@@ -3075,7 +3225,7 @@ def __radd__(self, other):
def __sub__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real - other.real, self.imag - other.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(self.real - other, self.imag)
else:
raise NotImplementedError
@@ -3083,7 +3233,7 @@ def __sub__(self, other):
def __rsub__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real - self.real, other.imag - self.imag)
- elif isinstance(other, (jt.Var, int, float)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other - self.real, self.imag)
else:
raise NotImplementedError
@@ -3094,8 +3244,6 @@ def __mul__(self, other):
self.real * other.imag + self.imag * other.real)
elif isinstance(other, (int, float)):
return ComplexNumber(self.value * other, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real * other, self.imag * other)
else:
raise NotImplementedError
@@ -3105,8 +3253,6 @@ def __rmul__(self, other):
other.imag * self.real + other.real * self.imag)
elif isinstance(other, (int, float)):
return ComplexNumber(other * self.value, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(other * self.real, other * self.imag)
else:
raise NotImplementedError
@@ -3117,8 +3263,6 @@ def __truediv__(self, other):
(self.imag * other.real - self.real * other.imag) / norm)
elif isinstance(other, (int, float)):
return ComplexNumber(self.value / other, is_concat_value=True)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real / other, self.imag / other)
else:
raise NotImplementedError
@@ -3127,7 +3271,7 @@ def __rtruediv__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber((other.real * self.real + other.imag * self.imag) / norm,
(other.imag * self.real - other.real * self.imag) / norm)
- elif isinstance(other, (int, float, jt.Var)):
+ elif isinstance(other, (int, float)):
return ComplexNumber(other * self.real / norm, - other * self.imag / norm)
else:
raise NotImplementedError
@@ -3136,8 +3280,6 @@ def __matmul__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(self.real @ other.real - self.imag @ other.imag,
self.real @ other.imag + self.imag @ other.real)
- elif isinstance(other, jt.Var):
- return ComplexNumber(self.real @ other, self.imag @ other)
else:
raise NotImplementedError
@@ -3145,8 +3287,6 @@ def __imatmul__(self, other):
if isinstance(other, ComplexNumber):
return ComplexNumber(other.real @ self.real - other.imag @ self.imag,
other.imag @ self.real + other.real @ self.imag)
- elif isinstance(other, jt.Var):
- return ComplexNumber(other @ self.real, other @ self.imag)
else:
raise NotImplementedError
@@ -3160,6 +3300,140 @@ def ifft2(self):
return ComplexNumber(_fft2(self.value, inverse=True), is_concat_value=True)
+def polar(abs:jt.Var, angle: jt.Var) -> ComplexNumber:
+ assert abs.shape == angle.shape
+ return ComplexNumber(abs * angle.cos(),abs * angle.sin())
+
+def view_as_complex(x: jt.Var) -> ComplexNumber:
+ assert x.shape[-1] == 2
+ return ComplexNumber(x[...,0],x[...,1])
+
+def view_as_real(x: ComplexNumber) -> jt.Var:
+ return jt.stack([x.value[...,0],x.value[...,1]],dim=-1)
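+
+# Usage sketch (illustrative): for r, theta of the same shape,
+# z = polar(r, theta) gives r*cos(theta) + i*r*sin(theta) as a ComplexNumber,
+# view_as_real(z) yields a (..., 2) Var, and view_as_complex of that Var round-trips back.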
+
+# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/functional.py#L1258
+def tensordot(a, b, dims=2):
+ r"""Returns a contraction of a and b over multiple dimensions.
+
+ :attr:`tensordot` implements a generalized matrix product.
+
+ Args:
+ a (Tensor): Left tensor to contract
+ b (Tensor): Right tensor to contract
+ dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
+ contract or explicit lists of dimensions for :attr:`a` and
+ :attr:`b` respectively
+
+ When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
+ the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
+ respectively, :func:`tensordot` computes
+
+ .. math::
+ r_{i_0,...,i_{m-d}, i_d,...,i_n}
+ = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
+
+ When called with :attr:`dims` of the list form, the given dimensions will be contracted
+ in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
+ in these dimensions must match.
+
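+ Example (illustrative sketch):
+
+ >>> a = jt.rand(4, 3, 2)
+ >>> b = jt.rand(2, 3, 5)
+ >>> c = tensordot(a, b, dims=([1, 2], [1, 0])) # contracts a's dims (1, 2) with b's (1, 0); c has shape (4, 5)
+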
+ """
+ if not isinstance(dims, (tuple, list, int)):
+ raise RuntimeError(
+ "tensordot expects dims to be int or "
+ + "Tuple[List[int], List[int]] or "
+ + "List[List[int]] containing two lists, but got "
+ + f"dims={dims}"
+ )
+
+ dims_a, dims_b = [], []
+
+ if isinstance(dims, (tuple, list)):
+ dims_a, dims_b = dims
+
+ if isinstance(dims, (int)):
+ if dims < 0:
+ raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
+ if dims > min(len(a.shape), len(b.shape)):
+ raise RuntimeError(
+ f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
+ )
+ dims_a = list(range(len(a.shape)-dims, len(a.shape)))
+ dims_b = list(range(dims))
+
+ # reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/aten/src/ATen/native/Linear.cpp#L769
+ def __tensordot_native(input1:jt.Var, input2:jt.Var, dims1, dims2):
+ if not isinstance(dims1, (list, tuple)):
+ raise RuntimeError("tensordot expects dims1 to be List[Int], but got dims={}".format(dims1))
+ if not isinstance(dims2, (list, tuple)):
+ raise RuntimeError("tensordot expects dims2 to be List[Int], but got dims={}".format(dims2))
+ dims1 = list(dims1)
+ dims2 = list(dims2)
+ if len(dims1) != len(dims2):
+ raise RuntimeError("both dimension lists should have the same length")
+ if input1.dtype != input2.dtype:
+ raise RuntimeError("both inputs should have the same dtype")
+ t1 = input1
+ t2 = input2
+ csize = 1
+ input1_bitmap = np.zeros(len(input1.shape), dtype='bool')
+ input2_bitmap = np.zeros(len(input2.shape), dtype='bool')
+ for i in range(len(dims1)):
+ s1 = input1.shape[dims1[i]]
+ s2 = input2.shape[dims2[i]]
+ input1_bitmap[dims1] = True
+ input2_bitmap[dims2] = True
+ if s2 == 1: #broadcasted dimensions can be summed right away
+ t1 = t1.sum(dims1[i], keepdims=True)
+ elif s1 == 1:
+ t2 = t2.sum(dims2[i], keepdims=True)
+ else:
+ if s1 != s2:
+ raise RuntimeError("contracted dimensions need to match, but first has size {}, in dim {}, and second has size {}".format(s1, i, s2))
+ csize *= s1
+
+ p1, p2 = [], [] # p1, p2: input permutations
+ rsizes = []
+ size1, size2 = 1, 1 # number of non-contracted elements
+ for i in range(len(input1.shape)):
+ if not input1_bitmap[i]:
+ p1.append(i)
+ size1 *= t1.shape[i]
+ rsizes.append(t1.shape[i])
+ p1 += dims1
+ p2 += dims2
+ for i in range(len(input2.shape)):
+ if not input2_bitmap[i]:
+ p2.append(i)
+ size2 *= t2.shape[i]
+ rsizes.append(t2.shape[i])
+
+ # permute and reshape for matrix multiplication
+ t1 = t1.permute(p1).reshape((size1, csize))
+ t2 = t2.permute(p2).reshape((csize, size2))
+ # multiply and reshape to target size
+ return jt.matmul(t1, t2).reshape(rsizes)
+
+ return __tensordot_native(a, b, dims_a, dims_b)
+
+# reference: https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/aten/src/ATen/native/LinearAlgebra.cpp#L3375
+def kron(a:jt.Var, b:jt.Var):
+ a_dim, b_dim = len(a.shape), len(b.shape)
+ max_dim = max(a_dim, b_dim)
+ pad_a, pad_b = max_dim-a_dim, max_dim-b_dim
+ a_reshape, b_reshape = [], []
+ result_reshape = []
+ for i in range(max_dim):
+ a_2i_shape = a.shape[i - pad_a] if i >= pad_a else 1
+ b_2i1_shape = b.shape[i - pad_b] if i >= pad_b else 1
+ a_reshape.append(a_2i_shape)
+ a_reshape.append(1)
+ b_reshape.append(1)
+ b_reshape.append(b_2i1_shape)
+ result_reshape.append(a_2i_shape * b_2i1_shape)
+ a = a.reshape(a_reshape)
+ b = b.reshape(b_reshape)
+ return (a * b).reshape(result_reshape)
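+
+# Usage sketch (illustrative): for a of shape (2, 3) and b of shape (4, 5),
+# kron(a, b) has shape (8, 15); block (i, j) equals a[i, j] * b.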
+
def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var:
''' Returns the one_hot encoding of inputs.
diff --git a/python/jittor/pool.py b/python/jittor/pool.py
index f9ad9a68..7e9e808a 100644
--- a/python/jittor/pool.py
+++ b/python/jittor/pool.py
@@ -29,9 +29,20 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
+ for item in self.kernel_size:
+ if item <= 0:
+ raise RuntimeError(f"kernel_size must be greater than zero, but got {item}")
+ for item in self.stride:
+ if item <= 0:
+ raise RuntimeError(f"stride must be greater than zero, but got {item}")
+ for item in self.padding:
+ if item < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {item}")
def execute(self, x):
N,C,H,W = x.shape
+ if H <= self.kernel_size[0] or W <= self.kernel_size[1]:
+ raise RuntimeError(f"size of var should be larger than kernel_size")
if self.ceil_mode == False:
h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
@@ -205,9 +216,17 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
self.padding = _triple(padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
+ if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0:
+ raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
+ if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0:
+ raise RuntimeError(f"stride must be greater than zero, but got {stride}")
+ if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0:
+ raise RuntimeError(f"padding must be non-negative, but got {padding}")
def execute(self, x):
N,C,D,H,W = x.shape
+ if D <= self.kernel_size[0] or H <= self.kernel_size[1] or W <= self.kernel_size[2]:
+ raise RuntimeError(f"size of var should be larger than kernel_size")
if self.ceil_mode == False:
d = (D+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
h = (H+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
@@ -506,7 +525,7 @@ def execute(self, x):
f"i3*{self.sh}+i6", # Hid
f"i4*{self.sw}+i7", # Wid
])
- return xx.reduce("maximun", [5,6,7])
+ return xx.reduce("maximum", [5,6,7])
def pool(x, kernel_size, op, padding=0, stride=None):
return Pool(kernel_size, stride, padding, op=op)(x)
diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.cc b/python/jittor/src/mem/allocator/sfrl_allocator.cc
index ce5460fd..add3aef9 100644
--- a/python/jittor/src/mem/allocator/sfrl_allocator.cc
+++ b/python/jittor/src/mem/allocator/sfrl_allocator.cc
@@ -242,7 +242,12 @@ std::mutex sfrl_allocator_mutex;
void* SFRLAllocator::alloc(size_t size, size_t& allocation) {
std::unique_lock lock(sfrl_allocator_mutex);
+ #ifdef IS_ACL
+ // outputs of ACL ops need an additional 32 bytes
+ size = align_size(size+32);
+ #else
size = align_size(size);
+ #endif
CachingBlockPool* blocks = get_blocks(size);
//search cached block
CachingBlock* block = blocks->pop_block(size);
diff --git a/python/jittor/src/misc/cuda_atomic.h b/python/jittor/src/misc/cuda_atomic.h
index 6348befb..15b39412 100644
--- a/python/jittor/src/misc/cuda_atomic.h
+++ b/python/jittor/src/misc/cuda_atomic.h
@@ -6,7 +6,9 @@
// ***************************************************************
#pragma once
#include
+#ifndef IS_ROCM
#include
+#endif
#include "common.h"
namespace jittor {
diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc
index bbec4ccb..00fb34aa 100644
--- a/python/jittor/src/misc/nan_checker.cc
+++ b/python/jittor/src/misc/nan_checker.cc
@@ -11,7 +11,9 @@
#include "misc/cuda_flags.h"
#include
#include
+#ifndef IS_ROCM
#include
+#endif
#include "helper_cuda.h"
#endif
#include "mem/allocator.h"
@@ -22,7 +24,9 @@ namespace jittor {
#ifdef IS_CUDA
EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num);
+#ifndef IS_ROCM
EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num);
+#endif
EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num);
EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num);
#endif
@@ -33,7 +37,9 @@ void dump_var(Var* v, string name) {
name = ss.str();
LOGe << "dump" << v << "to" << name;
char* buffer = new char[v->size];
- #ifdef HAS_CUDA
+ #ifdef IS_ROCM
+ hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault);
+ #elif IS_CUDA
cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault);
#else
std::memcpy(buffer, v->mem_ptr, v->size);
@@ -57,9 +63,11 @@ bool check_nan(Var* v, Op* op) {
if (v->dtype() == ns_float16) {
nan_index = check_nan_float16((__half*)v->mem_ptr, v->num);
}
+ #ifndef IS_ROCM
if (v->dtype() == ns_bfloat16) {
nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num);
}
+ #endif
if (v->dtype() == ns_float32) {
nan_index = check_nan_float32((float32*)v->mem_ptr, v->num);
} else
@@ -104,14 +112,16 @@ bool check_nan(Var* v, Op* op) {
auto* ptr = input->ptr<__half>();
__half value;
cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost);
- LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
+ // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
+ #ifndef IS_ROCM
if (input->dtype() == ns_bfloat16) {
auto* ptr = input->ptr<__nv_bfloat16>();
__nv_bfloat16 value;
cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
+ #endif
if (input->dtype() == ns_float32) {
auto* ptr = input->ptr();
float32 value;
diff --git a/python/jittor/src/misc/nan_checker.cu b/python/jittor/src/misc/nan_checker.cu
index ec7998ba..b2e631f1 100644
--- a/python/jittor/src/misc/nan_checker.cu
+++ b/python/jittor/src/misc/nan_checker.cu
@@ -7,9 +7,13 @@
#include "misc/cuda_flags.h"
#include
#include
-#include
+
#include "helper_cuda.h"
#include
+//TODO:FIX in ROCM
+#ifndef IS_ROCM
+#include
+#endif
namespace jittor {
@@ -37,7 +41,7 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt
print_nan(float(ptr[i]), i, cnt);
}
}
-
+#ifndef IS_ROCM
__global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i check_nan_float16(__half* ptr, int64 num) {
_check_nan_float16<<>>(ptr, num, check_nan_get_device_ptr());
return report_nan();
}
-
+#ifndef IS_ROCM
vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) {
int block_num = std::max((int64)1, (num-1)/1024+1);
int thread_num = std::min((int64)1024, num);
_check_nan_bfloat16<<>>(ptr, num, check_nan_get_device_ptr());
return report_nan();
}
-
+#endif
#endif
}
\ No newline at end of file
diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc
index 01a8ef2d..848d40a4 100644
--- a/python/jittor/src/ops/binary_op.cc
+++ b/python/jittor/src/ops/binary_op.cc
@@ -440,6 +440,17 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
return;
}
+ #ifdef IS_ACL
+ if (x->dtype() != y->dtype()) {
+ auto dtype = binary_dtype_infer(ns_add, x->ns, y->ns, 0, 0);
+ auto xp = make_unary(x, dtype);
+ auto yp = make_unary(y, dtype);
+ auto zp = make_binary(xp, yp, op);
+ forward(zp);
+ return;
+ }
+ #endif
+
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::element);
diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h
index 56f70789..f5c85d3b 100644
--- a/python/jittor/src/type/fp16_compute.h
+++ b/python/jittor/src/type/fp16_compute.h
@@ -11,12 +11,17 @@
#include
#include
+#ifndef IS_ROCM
#include
+#endif
namespace jittor {
typedef __half float16;
+#ifndef IS_ROCM
typedef __nv_bfloat16 bfloat16;
+#endif
+
#if CUDA_ARCH >= 800
inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); }
@@ -32,7 +37,7 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)= 800
inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); }
inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); }
@@ -45,7 +50,7 @@ inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a)
__device__ inline
typename std::enable_if::type
diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc
index 092eddd2..ed4d4c0a 100644
--- a/python/jittor/src/var_holder.cc
+++ b/python/jittor/src/var_holder.cc
@@ -215,11 +215,14 @@ inline static void cast_item_data(ItemData& data) {
auto* fp16 = (float16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(fp16[0]);
- } else if (data.dtype == ns_bfloat16) {
+ }
+ #ifndef IS_ROCM
+ else if (data.dtype == ns_bfloat16) {
auto* bf16 = (bfloat16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(bf16[0]);
}
+ #endif
data.dtype = ns_float32;
}
diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py
new file mode 100644
index 00000000..41a6d748
--- /dev/null
+++ b/python/jittor/test/test_aclop.py
@@ -0,0 +1,463 @@
+import unittest
+import jittor as jt
+from .test_core import expect_error
+import numpy as np
+from jittor import init, Module
+
+
+@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
+class TestACL(unittest.TestCase):
+
+ @jt.flag_scope(use_acl=1)
+ def test_getitem(self):
+ a = jt.ones(100, 2)
+ b = a[0:2, 0:2]
+ np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]])
+ print("test getitem success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_getitem_neg(self):
+ a = jt.ones(2, 3, 2)
+ b = a[0:1,0:-2]
+ np.testing.assert_allclose(b.numpy(), [[[1,1]]])
+ print("test getitem neg success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_setitem(self):
+ a = jt.ones(2, 2)
+ b = jt.Var(0)
+ a[0:1, 0:1] = b
+ np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]])
+ print("test setitem success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_setitem_neg(self):
+ a = jt.ones(2, 3, 2)
+ b = jt.Var(0)
+ a[0:1, 0:-2] = b
+ np.testing.assert_allclose(a.numpy(), [[[0,0],[1,1],[1,1]],[[1,1],[1,1],[1,1]]])
+ print("test setitem neg success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_getitem_grad(self):
+ a = jt.ones(2, 2)
+ b = a[0:1, 0:1]
+ optimizer = jt.optim.SGD([a], 0.1)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(res.numpy(), [[1, 0], [0, 0]])
+ print("test getitem grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_setitem_grad(self):
+ a = jt.ones(3, 3)
+ b = jt.ones(2, 2)
+ a[0:2, 0:2] = b * 2
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ loss = a.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(),
+ [[1, 1, 1], [1, 1, 1], [1, 1, 1]])
+ np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]])
+ print("test setitem grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_concat(self):
+ a = jt.ones(2, 2)
+ b = jt.ones(2, 2)
+ c = jt.concat([a, b], 0)
+ np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]])
+ print("test concat success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_concat_neg(self):
+ a = jt.ones(2, 2)
+ b = jt.ones(2, 2)
+ c = jt.concat([a, b], -1)
+ np.testing.assert_allclose(c.numpy(), [[1,1,1,1],[1,1,1,1]])
+ print("test concat neg success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_concat_zero_dim(self):
+ a = jt.ones([])
+ b = jt.zeros([])
+ c = jt.concat([a, b], 0)
+ np.testing.assert_allclose(c.numpy(), [1,0])
+ print("test concat zero dim success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_maxpool_grad(self):
+ a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]])
+ max_pool = jt.nn.Pool(2, op='maximum')
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = max_pool(a)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[[0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]]])
+ print("test maxpool grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_triu(self):
+ a = jt.ones(3, 3)
+ b = jt.triu_(a, 0)
+ c = jt.triu_(a, 1)
+ d = jt.triu_(a, -1)
+ np.testing.assert_allclose(b.numpy(),
+ [[1, 1, 1], [0, 1, 1], [0, 0, 1]])
+ np.testing.assert_allclose(c.numpy(),
+ [[0, 1, 1], [0, 0, 1], [0, 0, 0]])
+ np.testing.assert_allclose(d.numpy(),
+ [[1, 1, 1], [1, 1, 1], [0, 1, 1]])
+ print("test triu success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_bmm(self):
+ a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]])
+ b = jt.bmm(a, a)
+ np.testing.assert_allclose(
+ b.numpy(), [[[7, 10], [15, 22]], [[8, 5], [20, 13]], [[9, 8], [16, 17]]])
+ print("test bmm success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_matmul(self):
+ a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]])
+ b = jt.float32([[1,1],[1,1],[1,1],[1,1]])
+ c = jt.matmul(a, b)
+ np.testing.assert_allclose(c.numpy(),
+ [[[10, 10], [26, 26], [42, 42], [58, 58]]])
+ print("test matmul success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_maxpool(self):
+ a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]])
+ max_pool = jt.nn.Pool(2, op='maximum')
+ np.testing.assert_allclose(max_pool(a).numpy(), [[[[3, 4], [4, 3]]]])
+ print("test maxpool success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_transpose(self):
+ a = jt.float32([[[1,2],[3,4]]])
+ b = a.transpose(0, 2)
+ np.testing.assert_allclose(b.numpy(), [[[1], [3]], [[2], [4]]])
+ print("test transpose success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_transpose_neg(self):
+ a = jt.float32([[[1,2],[3,4]]])
+ b = a.transpose(1, -1)
+ np.testing.assert_allclose(b.numpy(), [[[1,3], [2,4]]])
+ print("test transpose neg success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_matmul_grad(self):
+ a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]])
+ b = jt.float32([[1,1],[1,1],[1,1],[1,1]])
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ loss = jt.matmul(a, b).sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(), [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
+ np.testing.assert_allclose(res_b.numpy(), [[28, 28], [32, 32], [36, 36], [40, 40]])
+ print("test matmul grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_bmm_grad(self):
+ a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]])
+ optimizer = jt.optim.SGD([a], 0.1)
+ c = jt.bmm(a, a)
+ loss = c.sum()
+
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[7, 11], [9, 13]], [[9, 13], [7, 11]], [[8, 12], [8, 12]]])
+ print("test bmm grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_avgpool(self):
+ a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]])
+ avg_pool = jt.nn.Pool(2, op='mean')
+ b = avg_pool(a)
+ np.testing.assert_allclose(b.numpy(), [[[[2, 3], [3, 2]]]])
+ print("test avgpool success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_maxpool2d(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ pool_1 = jt.nn.AdaptiveMaxPool2d((2, 2))
+ pool_2 = jt.nn.AdaptiveMaxPool2d((3, 4))
+ b = pool_1(a)
+ c = pool_2(a)
+ np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]])
+ np.testing.assert_allclose(c.numpy(), [[[[5,6,7,8],[9,10,11,12],[13,14,15,16]]]])
+ print("test adaptive_maxpool2d success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_maxpool2d_grad_1(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ max_pool = jt.nn.AdaptiveMaxPool2d((2, 2))
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = max_pool(a)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]])
+ print("test adaptive_maxpool2d_1 grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_maxpool2d_grad_2(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ max_pool = jt.nn.AdaptiveMaxPool2d((1, 3))
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = max_pool(a)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]]])
+ print("test adaptive_maxpool2d_2 grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_avgpool2d(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ pool_1 = jt.nn.AdaptiveAvgPool2d((2, 2))
+ pool_2 = jt.nn.AdaptiveAvgPool2d((1, 3))
+ b = pool_1(a)
+ c = pool_2(a)
+ np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]])
+ np.testing.assert_allclose(c.numpy(), [[[[7.5, 8.5, 9.5]]]])
+ print("test adaptive_avgpool2d success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_avgpool2d_grad(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2))
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = avg_pool(a)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25],
+ [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]])
+ print("test adaptive_avgpool2d grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_adaptive_avgpool2d_grad_2(self):
+ a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ avg_pool = jt.nn.AdaptiveAvgPool2d((1, 3))
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = avg_pool(a)
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(
+ res.numpy(),
+ [[[[0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125],
+ [0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125]]]])
+ print("test adaptive_avgpool2d_2 grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_index(self):
+ a = jt.rand(2, 3)
+ [s1, s2] = jt.index(a.shape)
+ np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]])
+ np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]])
+ print("test index success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_gather(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]]))
+ np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]])
+ b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]]))
+ np.testing.assert_allclose(b.numpy(), [[1, 2], [3, 2]])
+ b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]]))
+ np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]])
+ print("test gather success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_gather_grad(self):
+ a = jt.float32([[1, 2], [3, 4]])
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]]))
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(res.numpy(), [[1, 2], [1, 0]])
+ print("test gather grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_gather_grad_neg(self):
+ a = jt.float32([[4, 3], [2, 1]])
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]]))
+ loss = b.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]])
+ print("test gather grad neg success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_scatter_add(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.array([[0, 0], [0, 0]])
+ b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add")
+ np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]])
+ print("test scatter add success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_scatter_multi(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.array([[5, 6], [7, 8]])
+ b = jt.scatter(b, 0, jt.array([[0, 0], [1, 0]]), a, reduce="multiply")
+ np.testing.assert_allclose(b.numpy(), [[5, 48], [21, 8]])
+ print("test scatter multiply success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_scatter_add_grad(self):
+ a = jt.float32([[1, 2], [3, 4]])
+ b = jt.float32([[0, 0], [0, 0]])
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add")
+ loss = c.max()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]])
+ np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]])
+ print("test scatter add grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_scatter_mult_grad(self):
+ a = jt.float32([[1, 2], [3, 4]])
+ b = jt.float32([[5, 6], [7, 8]])
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="multiply")
+ loss = c.max()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(), [[0, 6], [0, 6]])
+ np.testing.assert_allclose(res_b.numpy(), [[0, 8], [0, 0]])
+ print("test scatter mult grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_where(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.ones(2, 2)
+ c = jt.where(a > 2, a, b)
+ np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]])
+ print("test where success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_where_2(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.array([[5, 6], [7, 8]])
+ cond = jt.array([[1, 0], [0, 1]])
+ c = jt.where(cond, a, b)
+ np.testing.assert_allclose(c.numpy(), [[1, 6], [7, 4]])
+ print("test where_2 success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_where_grad(self):
+ a = jt.array([[1, 2], [3, 4]])
+ b = jt.array([[5, 6], [7, 8]])
+ cond = jt.array([[1, 0], [0, 1]])
+ c = jt.where(cond, a, b)
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ loss = c.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]])
+ np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]])
+ print("test where grad success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_where_grad_2(self):
+ a = jt.float32([[1, 2], [3, 4]])
+ b = jt.array([[2., 2.], [2., 2.]])
+ c = jt.where(a > 2, a, b)
+ optimizer = jt.optim.SGD([a, b], 0.1)
+ loss = c.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res_a = a.opt_grad(optimizer)
+ res_b = b.opt_grad(optimizer)
+ np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]])
+ np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]])
+ print("test where grad 2 success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_flip(self):
+ a = jt.array([[1., 2.], [3., 4.]])
+ b = a.flip((0, 1))
+ np.testing.assert_allclose(b.numpy(), [[4, 3], [2, 1]])
+ print("test flip success")
+
+ @jt.flag_scope(use_acl=1)
+ def test_flip_grad(self):
+ a = jt.float32([[1, 2], [3, 4]])
+ optimizer = jt.optim.SGD([a], 0.1)
+ b = a.flip((0, 1))
+ loss = b.max()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+ res = a.opt_grad(optimizer)
+ np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]])
+ print("test flip grad success")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py
new file mode 100644
index 00000000..4c1a6c53
--- /dev/null
+++ b/python/jittor/test/test_complex.py
@@ -0,0 +1,485 @@
+import jittor as jt
+from jittor.nn import ComplexNumber
+import unittest
+import numpy as np
+from functools import partial
+
+__skip_torch_test = False
+try:
+ import torch
+except Exception:
+ __skip_torch_test = True
+
+class TestResultAndGrad:
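+    # Helper mixin: run the same op through Jittor and a reference
+    # implementation (PyTorch or NumPy), then compare outputs and, when
+    # requested, input gradients of shared randomly weighted scalar losses.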
+ def flatten_list(self, list_like):
+ results = []
+ if isinstance(list_like, (list, tuple)):
+ for x in list_like:
+ results.extend(self.flatten_list(x))
+ return results
+ else:
+ return [list_like]
+
+ def check_results(self, rlist1, rlist2):
+ assert len(rlist1) == len(rlist2)
+ for r1, r2 in zip(rlist1, rlist2):
+ assert r1.shape == r2.shape
+ assert np.allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+ def grad_jittor(self, inputs, losses):
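+        # For ComplexNumber inputs, differentiate the stacked (real, imag)
+        # value and recombine the two channels into a complex ndarray.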
+ grads = []
+ for i in inputs:
+ for loss in losses:
+ if isinstance(i, ComplexNumber):
+ g = jt.grad(loss, i.value, retain_graph=True)
+ grads.append(g[..., 0].numpy() + 1j * g[..., 1].numpy())
+ else:
+ g = jt.grad(loss, i, retain_graph=True)
+ grads.append(g.numpy())
+ return grads
+
+ def grad_torch(self, inputs, losses):
+ grads = []
+ for i in inputs:
+ for loss in losses:
+ g = torch.autograd.grad(loss, i, retain_graph=True)[0]
+ grads.append(g.detach().cpu().numpy())
+ return grads
+
+ def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
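+        # Convert numpy inputs to jt.Var / ComplexNumber, run the op, and
+        # build scalar losses as weighted sums of the flattened outputs;
+        # the generated weights are returned so the torch run can reuse them.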
+ def _np_to_jittor(x):
+ if isinstance(x, np.ndarray):
+ if x.dtype == np.complex64 or x.dtype == np.complex128:
+ nx = np.stack([np.real(x), np.imag(x)], axis=-1)
+ return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
+ elif x.dtype == np.float32 or x.dtype == np.float64:
+ return jt.array(x, dtype=jt.float32)
+ else:
+ assert False
+ elif isinstance(x, (list, tuple)):
+ nx = [_np_to_jittor(vx) for vx in x]
+ if isinstance(x, tuple):
+ return tuple(nx)
+ return nx
+ else:
+ assert False
+ def _jittor_to_np(x):
+ if isinstance(x, jt.Var):
+ return x.numpy()
+ elif isinstance(x, ComplexNumber):
+ return x.real.numpy() + 1j * x.imag.numpy()
+ assert False
+ ninput_list = [_np_to_jittor(x) for x in input_list]
+
+        if key_names is not None:
+ assert len(ninput_list) == len(key_names)
+ nkwargs = kwargs.copy()
+ for k, v in zip(key_names, ninput_list):
+ nkwargs[k] = v
+ output_list = op(**nkwargs)
+ else:
+ output_list = op(*ninput_list, **kwargs)
+ if isinstance(output_list, (jt.Var, ComplexNumber)):
+ output_list = [output_list]
+ output_list = self.flatten_list(output_list)
+ losses = []
+ if weights is None:
+ weights = []
+ for o in output_list:
+ no = o.value if isinstance(o, ComplexNumber) else o
+ w = np.random.randn(*no.shape)
+ weights.append(w)
+ losses.append(jt.sum(no * jt.array(w)))
+ else:
+ assert len(output_list) == len(weights)
+ for o, w in zip(output_list, weights):
+ no = o.value if isinstance(o, ComplexNumber) else o
+ assert w.shape == no.shape
+ losses.append(jt.sum(no * jt.array(w)))
+ output_list = [_jittor_to_np(x) for x in output_list]
+ return ninput_list, output_list, losses, weights
+
+ def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs):
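+        # Torch counterpart of run_jittor_op: complex outputs are stacked
+        # into (real, imag) pairs so the same loss weights apply.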
+ def _np_to_torch(x):
+ if isinstance(x, np.ndarray):
+ return torch.from_numpy(x).requires_grad_(True)
+ elif isinstance(x, (list, tuple)):
+ nx = [_np_to_torch(vx) for vx in x]
+ if isinstance(x, tuple):
+ return tuple(nx)
+ return nx
+ else:
+ assert False
+ def _torch_to_np(x:torch.Tensor) -> np.ndarray:
+ return x.detach().cpu().numpy()
+ ninput_list = [_np_to_torch(x) for x in input_list]
+        if key_names is not None:
+ assert len(ninput_list) == len(key_names)
+ nkwargs = kwargs.copy()
+ for k, v in zip(key_names, ninput_list):
+ nkwargs[k] = v
+ output_list = op(**nkwargs)
+ else:
+ output_list = op(*ninput_list, **kwargs)
+ if isinstance(output_list, torch.Tensor):
+ output_list = [output_list]
+ output_list = self.flatten_list(output_list)
+ losses = []
+ if weights is None:
+ weights = []
+ for o in output_list:
+ no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o
+ w = np.random.randn(*no.shape)
+ weights.append(w)
+ losses.append(torch.sum(no * torch.from_numpy(w)))
+ else:
+ assert len(output_list) == len(weights)
+ for o, w in zip(output_list, weights):
+ no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o
+ assert w.shape == no.shape
+ losses.append(torch.sum(no * torch.from_numpy(w)))
+ output_list = [_torch_to_np(x) for x in output_list]
+ return ninput_list, output_list, losses, weights
+
+ def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs):
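+        # Run Jittor first to generate the loss weights, reuse them for the
+        # torch run, then compare outputs (and gradients if check_grad).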
+ weights = None
+ jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs)
+ torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs)
+ self.check_results(jittor_output, torch_output)
+
+ if check_grad:
+ jittor_grads = self.grad_jittor(jittor_input, jittor_losses)
+ torch_grads = self.grad_torch(torch_input, torch_losses)
+ self.check_results(jittor_grads, torch_grads)
+
+ def check_op_with_numpy(self, jittor_op, numpy_op, input_list):
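+        # Forward-only check against a NumPy reference (no gradient comparison).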
+ _, jittor_output, _, _ = self.run_jittor_op(jittor_op, input_list, None)
+ numpy_output = numpy_op(*input_list)
+ if isinstance(numpy_output, np.ndarray):
+ numpy_output = [numpy_output]
+
+ self.check_results(jittor_output, numpy_output)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
+ def random_complex_matrix(self, shape):
+ r = np.random.randn(*shape)
+ i = np.random.randn(*shape)
+ return r + 1j * i
+
+ def test_complex_matmul(self):
+ s1 = (50, 200)
+ s2 = (200, 50)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.matmul, torch.matmul, inputs)
+
+ def test_complex_matmul_batch(self):
+ s1 = (10, 50, 30)
+ s2 = (10, 30, 40)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.matmul, torch.matmul, inputs)
+
+ def test_complex_inv(self):
+ s1 = (200, 200)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs)
+
+ def test_complex_inv_batch(self):
+ s1 = (10, 50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs)
+
+ def test_complex_eig(self):
+ # Unstable
+ s1 = (20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs)
+
+ def test_complex_eig_batch(self):
+ # Unstable
+ s1 = (5, 10, 10)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs)
+
+ def test_complex_qr(self):
+ s1 = (50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs)
+
+ def test_complex_qr_batch(self):
+ s1 = (10, 20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs)
+
+ def test_complex_svd(self):
+ # Unstable
+ s1 = (50, 50)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
+
+ def test_complex_svd_batch(self):
+ # Unstable
+ s1 = (10, 20, 20)
+ m1 = self.random_complex_matrix(s1)
+ inputs = [m1]
+ self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestTensordot(unittest.TestCase, TestResultAndGrad):
+ def random_complex_matrix(self, shape):
+ r = np.random.randn(*shape)
+ i = np.random.randn(*shape)
+ return r + 1j * i
+
+ def random_real_matrix(self, shape):
+ return np.random.randn(*shape)
+
+ def test_complex_tensordot_numberdim(self):
+ s1 = (3, 4, 5)
+ s2 = (4, 5, 6)
+ dims = 2
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
+
+ def test_complex_tensordot_tupledim(self):
+ s1 = (3, 5, 4, 6)
+ s2 = (6, 4, 5, 3)
+ dims = ([2, 1, 3], [1, 2, 0])
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
+
+ def test_real_tensordot_numberdim(self):
+ s1 = (3, 4, 5)
+ s2 = (4, 5, 6)
+ dims = 2
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
+
+ def test_real_tensordot_tupledim(self):
+ s1 = (3, 5, 4, 6)
+ s2 = (6, 4, 5, 3)
+ dims = ([2, 1, 3], [1, 2, 0])
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestKron(unittest.TestCase, TestResultAndGrad):
+ def random_complex_matrix(self, shape):
+ r = np.random.randn(*shape)
+ i = np.random.randn(*shape)
+ return r + 1j * i
+
+ def random_real_matrix(self, shape):
+ return np.random.randn(*shape)
+
+    def test_complex_first_large(self):
+ s1 = (2, 3, 4)
+ s2 = (5, 2)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+ def test_complex_second_large(self):
+ s1 = (2, 3)
+ s2 = (5, 2, 4)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+    def test_real_first_large(self):
+ s1 = (2, 3, 4)
+ s2 = (5, 2)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+ def test_real_second_large(self):
+ s1 = (2, 3)
+ s2 = (5, 2, 4)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
+
+@unittest.skipIf(__skip_torch_test, "No Torch found")
+class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
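+    # Compare jt.gradfunctional.jvp / vjp against torch.autograd.functional
+    # for real and complex inputs; forward outputs only (check_grad=False).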
+ def random_complex_matrix(self, shape):
+ r = np.random.randn(*shape)
+ i = np.random.randn(*shape)
+ return r + 1j * i
+
+ def random_real_matrix(self, shape):
+ return np.random.randn(*shape) * 0.0 + 1.0
+
+ def test_real_jvp_exp(self):
+ def exp_reducer(x):
+ return x.exp().sum(dim=1)
+ s1 = (5, 6)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s1)
+ inputs = [m1, m2]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
+ partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False)
+
+ def test_complex_jvp_exp(self):
+ def exp_reducer(x):
+ return x.exp().sum(1)
+ s1 = (5, 6)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s1)
+ inputs = [m1, m2]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
+ partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+ def test_real_jvp_add(self):
+ w1, w2 = np.random.rand(), np.random.rand()
+ def adder(x, y):
+ return w1 * x + w2 * y
+ s1 = (5, 6)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s1)
+ m3 = self.random_real_matrix(s1)
+ m4 = self.random_real_matrix(s1)
+ inputs = [(m1, m2), (m3, m4)]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
+ partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+ def test_complex_jvp_add(self):
+ w1r, w1i = np.random.rand(), np.random.rand()
+ w2r, w2i = np.random.rand(), np.random.rand()
+ def adder_pt(x, y):
+ return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
+ def adder_jt(x, y):
+ w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
+ w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
+ return w1 * x + w2 * y
+ s1 = (5, 6)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s1)
+ m3 = self.random_complex_matrix(s1)
+ m4 = self.random_complex_matrix(s1)
+ inputs = [(m1, m2), (m3, m4)]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
+ partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+ def test_real_vjp_exp(self):
+ def exp_reducer(x):
+ return x.exp().sum(dim=1)
+ s1 = (5, 6)
+ s2 = (5,)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.vjp, func=exp_reducer),
+ partial(torch.autograd.functional.vjp, func=exp_reducer),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False)
+
+ def test_complex_vjp_exp(self):
+ def exp_reducer(x):
+ return x.exp().sum(1)
+ s1 = (5, 6)
+ s2 = (5,)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s2)
+ inputs = [m1, m2]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
+ partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+ def test_real_vjp_add(self):
+ w1, w2 = np.random.rand(), np.random.rand()
+ def adder(x, y):
+ return w1 * x + w2 * y
+ s1 = (5, 6)
+ m1 = self.random_real_matrix(s1)
+ m2 = self.random_real_matrix(s1)
+ m3 = self.random_real_matrix(s1)
+ inputs = [(m1, m2), m3]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
+ partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+ def test_complex_vjp_add(self):
+ w1r, w1i = np.random.rand(), np.random.rand()
+ w2r, w2i = np.random.rand(), np.random.rand()
+ def adder_pt(x, y):
+ return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
+ def adder_jt(x, y):
+ w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
+ w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
+ return w1 * x + w2 * y
+ s1 = (5, 6)
+ m1 = self.random_complex_matrix(s1)
+ m2 = self.random_complex_matrix(s1)
+ m3 = self.random_complex_matrix(s1)
+        inputs = [(m1, m2), m3]
+ self.check_op_with_torch(
+ partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
+ partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
+ inputs,
+ jittor_knames = ['inputs', 'v'],
+ torch_knames = ['inputs', 'v'],
+ check_grad=False,
+ )
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/jittor/test/test_var.py b/python/jittor/test/test_var.py
index 116df6c7..c18b041b 100644
--- a/python/jittor/test/test_var.py
+++ b/python/jittor/test/test_var.py
@@ -46,6 +46,33 @@ def test_norm(self):
np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6)
np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6)
+ def test_std_with_dim(self):
+ x=np.random.randn(100, 1000).astype(np.float32)
+ jt_x = jt.array(x)
+ tc_x = torch.from_numpy(x)
+ np.testing.assert_allclose(jt_x.std(dim=-1).numpy(), tc_x.std(dim=-1).numpy(), 1e-4)
+ np.testing.assert_allclose(jt_x.std(dim=0, keepdim=True).numpy(), tc_x.std(dim=0, keepdim=True).numpy(), 1e-4)
+
+ def test_diagonal(self):
+ x = np.reshape(np.arange(5*6*7*8), (5,6,7,8))
+ jt_x = jt.array(x)
+ tc_x = torch.from_numpy(x)
+ def __assert_equal(a:np.ndarray, b:np.ndarray, rtol=1e-6, atol=1e-6):
+ assert a.shape == b.shape, f"{a.shape}!={b.shape}"
+ np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
+ __assert_equal(jt.misc.diagonal(jt_x, 0, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=0, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, -1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-1, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, -2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-2, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, -6, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-6, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=2, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 7, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=7, dim1=1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=-2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-2, dim2=-1).numpy(), tc_x.diagonal(offset=1, dim1=-2, dim2=-1).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=0, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=0, dim2=-2).numpy())
+ __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=2, dim2=1).numpy(), tc_x.diagonal(offset=1, dim1=2, dim2=1).numpy())
+
if __name__ == "__main__":
unittest.main()
diff --git a/setup.py b/setup.py
index 741165d2..beec50f6 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@
package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']},
# include_package_data=True,
install_requires=[
- "numpy",
+ "numpy<2.0",
"tqdm",
"pillow",
"astunparse",