Skip to content

Commit

Permalink
[Fix] Fix lint (#1598)
Browse files Browse the repository at this point in the history
* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <[email protected]>

---------

Co-authored-by: Zaida Zhou <[email protected]>
  • Loading branch information
HAOCHENYE and zhouzaida authored Nov 2, 2024
1 parent c9b5996 commit cc3b74b
Show file tree
Hide file tree
Showing 46 changed files with 146 additions and 134 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.10.15
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: '3.10.15'
- name: Install pre-commit hook
run: |
pip install pre-commit
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pr_stage_test.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
name: pr_stage_test

env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true


on:
pull_request:
paths-ignore:
Expand Down
13 changes: 9 additions & 4 deletions .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort
Expand All @@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down Expand Up @@ -55,11 +59,12 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
(?x)(
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
17 changes: 9 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
Expand All @@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand All @@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: 06907d0
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
Expand All @@ -55,11 +55,12 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
(?x)(
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
2 changes: 1 addition & 1 deletion mmengine/_strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def build_optim_wrapper(
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
return OptimWrapperDict(**optim_wrappers) # type: ignore
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def resume(
map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None,
) -> dict:
"""override this method since colossalai resume optimizer from filename
"""Override this method since colossalai resume optimizer from filename
directly."""
self.logger.info(f'Resume checkpoint from {filename}')

Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _setup_distributed( # type: ignore
init_dist(launcher, backend, **kwargs)

def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
Args:
Expand Down
31 changes: 16 additions & 15 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class Config:

def __init__(
self,
cfg_dict: dict = None,
cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
Expand Down Expand Up @@ -1227,7 +1227,8 @@ def is_base_line(c):
if base_code is not None:
base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval'))
base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else:
base_files = []
elif file_format in ('.yml', '.yaml', '.json'):
Expand Down Expand Up @@ -1288,7 +1289,7 @@ def _get_cfg_path(cfg_path: str,
def _merge_a_into_b(a: dict,
b: dict,
allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace).
"""Merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Expand Down Expand Up @@ -1358,22 +1359,22 @@ def auto_argparser(description=None):

@property
def filename(self) -> str:
"""get file name of config."""
"""Get file name of config."""
return self._filename

@property
def text(self) -> str:
"""get config text."""
"""Get config text."""
return self._text

@property
def env_variables(self) -> dict:
"""get used environment variables."""
"""Get used environment variables."""
return self._env_variables

@property
def pretty_text(self) -> str:
"""get formatted python config text."""
"""Get formatted python config text."""

indent = 4

Expand Down Expand Up @@ -1727,17 +1728,17 @@ def to_dict(self, keep_imported: bool = False):


class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""Argparse action to split an argument into KEY=VALUE form on the first =
and append to a dictionary.
List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""

@staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string."""
"""Parse int/float/bool value in the string."""
try:
return int(val)
except ValueError:
Expand Down Expand Up @@ -1822,7 +1823,7 @@ def __call__(self,
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
option_string: str = None):
option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser.
Args:
Expand Down
2 changes: 1 addition & 1 deletion mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``.
"""
if out is not None:
if type(data) != type(out):
if type(data) is not type(out):
raise TypeError(
'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}')
Expand Down
6 changes: 3 additions & 3 deletions mmengine/evaluator/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ def __init__(self,
self.out_file_path = out_file_path

def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
"""Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))

def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file."""
"""Dump the prediction results to a pickle file."""
dump(results, self.out_file_path)
print_log(
f'Results has been saved to {self.out_file_path}.',
Expand All @@ -188,7 +188,7 @@ def compute_metrics(self, results: list) -> dict:


def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu."""
"""Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu')
elif isinstance(data, list):
Expand Down
2 changes: 1 addition & 1 deletion mmengine/hooks/profiler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def after_train_epoch(self, runner):
self._export_chrome_trace(runner)

def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed."""
"""Profiler will call `step` method if it is not closed."""
if not self._closed:
self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch:
Expand Down
2 changes: 1 addition & 1 deletion mmengine/logging/history_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _set_default_statistics(self) -> None:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history.
"""Update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer.
Expand Down
14 changes: 7 additions & 7 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,17 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
dict or list: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
_batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = []
for _batch_input in _batch_inputs:
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
_batch_input = _batch_input.float() # type: ignore
# Normalization.
if self._enable_normalize:
if self.mean.shape[0] == 3:
Expand Down Expand Up @@ -302,7 +302,7 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
else:
raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}')
data['inputs'] = batch_inputs
data.setdefault('data_samples', None)
return data
f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) # type: ignore
return data # type: ignore
14 changes: 7 additions & 7 deletions mmengine/model/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):


def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
"""Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init

Expand Down Expand Up @@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
r"""Fills the input Tensor with values drawn from a truncated normal
distribution. The values are effectively drawn from the normal distribution
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
:math:`[a, b]` redrawn until they are within the bounds. The method used
for generating the random values works best when :math:`a \leq \text{mean}
\leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Expand Down
7 changes: 4 additions & 3 deletions mmengine/model/wrappers/fully_sharded_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __init__(
auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True,
**kwargs,
):
Expand Down Expand Up @@ -362,7 +363,7 @@ def optim_state_dict(
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model)
return FullyShardedDataParallel._optim_state_dict_impl(
Expand All @@ -384,7 +385,7 @@ def set_state_dict_type(
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
Expand Down
3 changes: 1 addition & 2 deletions mmengine/optim/optimizer/apex_optimizer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None:
self._inner_count += 1

def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and
:attr:`apex_amp`.
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp".
Expand Down
4 changes: 2 additions & 2 deletions mmengine/optim/optimizer/default_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self,
self._validate_cfg()

def _validate_cfg(self) -> None:
"""verify the correctness of the config."""
"""Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
Expand All @@ -155,7 +155,7 @@ def _validate_cfg(self) -> None:
raise ValueError('base_wd should not be None')

def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`"""
"""Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
Expand Down
Loading

0 comments on commit cc3b74b

Please sign in to comment.