diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ba8038d668..b0c95b01df 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,24 +1,25 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/pre-commit/pre-commit
- rev: v4.0.0
+ rev: v4.3.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
- rev: 7.1.1
+ rev: 7.3.0
hooks:
- id: flake8
+ args: [--max-line-length=90]
- repo: https://github.com/PyCQA/isort
- rev: 5.11.5
+ rev: 7.0.0
hooks:
- id: isort
- - repo: https://github.com/pre-commit/mirrors-yapf
- rev: v0.32.0
+ - repo: https://github.com/google/yapf
+ rev: v0.43.0
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@@ -26,8 +27,6 @@ repos:
- id: requirements-txt-fixer
- id: double-quote-string-fixer
- id: check-merge-conflict
- - id: fix-encoding-pragma
- args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/executablebooks/mdformat
@@ -40,12 +39,12 @@ repos:
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/myint/docformatter
- rev: 06907d0
+ rev: v1.7.7
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
- rev: v3.0.0
+ rev: v3.21.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]
@@ -56,7 +55,7 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.2.0
+ rev: v1.18.2
hooks:
- id: mypy
exclude: |-
@@ -67,6 +66,6 @@ repos:
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
- rev: 0.9.5
+ rev: 0.9.7
hooks:
- id: uv-lock
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 618f68835f..b55796d2ab 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -75,7 +75,7 @@ pre-commit run --all-files
If the installation process is interrupted, you can repeatedly run `pre-commit run ... ` to continue the installation.
-If the code does not conform to the code style specification, pre-commit will raise a warning and fixes some of the errors automatically.
+If the code does not conform to the code style specification, pre-commit will raise a warning and fix some of the errors automatically.
diff --git a/docs/en/advanced_tutorials/basedataset.md b/docs/en/advanced_tutorials/basedataset.md
index 3423e6c505..7844267cca 100644
--- a/docs/en/advanced_tutorials/basedataset.md
+++ b/docs/en/advanced_tutorials/basedataset.md
@@ -291,7 +291,7 @@ The above is not fully initialized by setting `lazy_init=True`, and then complet
In the specific process of reading data, the dataloader will usually prefetch data from multiple dataloader workers, and multiple workers have complete dataset object backup, so there will be multiple copies of the same `data_list` in the memory. In order to save this part of memory consumption, The `BaseDataset` can serialize `data_list` into memory in advance, so that multiple workers can share the same copy of `data_list`, so as to save memory.
-By default, the BaseDataset stores the serialization of `data_list` into memory. It is also possible to control whether the data will be serialized into memory ahead of time by using the `serialize_data` argument (default is `True`) :
+By default, the `BaseDataset` serializes `data_list` into memory. Whether the data is serialized ahead of time can be controlled with the `serialize_data` argument (default: `True`):
```python
pipeline = [
@@ -374,7 +374,7 @@ MMEngine provides `ClassBalancedDataset` wrapper to repeatedly sample the corres
**Notice:**
-The `ClassBalancedDataset` wrapper assumes that the wrapped dataset class supports the `get_cat_ids(idx)` method, which returns a list. The list contains the categories of `data_info` given by 'idx'. The usage is as follows:
+The `ClassBalancedDataset` wrapper assumes that the wrapped dataset class supports the `get_cat_ids(idx)` method, which returns a list containing the categories of the `data_info` given by `idx`. The usage is as follows:
```python
from mmengine.dataset import BaseDataset, ClassBalancedDataset
diff --git a/docs/en/advanced_tutorials/cross_library.md b/docs/en/advanced_tutorials/cross_library.md
index 0335aa563f..b201a69503 100644
--- a/docs/en/advanced_tutorials/cross_library.md
+++ b/docs/en/advanced_tutorials/cross_library.md
@@ -59,7 +59,7 @@ train_pipeline=[
Using an algorithm from another library is a little bit complex.
-An algorithm contains multiple submodules. Each submodule needs to add a prefix to its `type`. Take using MMDetection's YOLOX in MMTracking as an example:
+An algorithm contains multiple submodules, and a prefix needs to be added to each submodule's `type`. Take using MMDetection's YOLOX in MMTracking as an example:
```python
# Use custom_imports to register mmdet models to the registry
diff --git a/docs/en/advanced_tutorials/initialize.md b/docs/en/advanced_tutorials/initialize.md
index 503d848df2..54be2e1525 100644
--- a/docs/en/advanced_tutorials/initialize.md
+++ b/docs/en/advanced_tutorials/initialize.md
@@ -239,7 +239,7 @@ Although the `init_cfg` could control the initialization method for different mo
Assuming we've defined the following modules:
-- `ToyConv` inherit from `nn.Module`, implements `init_weights`which initialize `custom_weight`(`parameter` of `ToyConv`) with 1 and initialize `custom_bias` with 0
+- `ToyConv` inherits from `nn.Module` and implements `init_weights`, which initializes `custom_weight` (a parameter of `ToyConv`) with 1 and `custom_bias` with 0
- `ToyNet` defines a `ToyConv` submodule.
@@ -353,7 +353,7 @@ from mmengine.model import normal_init
normal_init(model, mean=0, std=0.01, bias=0)
```
-Similarly, we could also use [Kaiming](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) initialization and [Xavier](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) initialization:
+Similarly, we could also use [Kaiming](https://arxiv.org/abs/1502.01852) initialization and [Xavier](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) initialization:
```python
from mmengine.model import kaiming_init, xavier_init
diff --git a/docs/en/advanced_tutorials/model_analysis.md b/docs/en/advanced_tutorials/model_analysis.md
index 767fc18871..1d8ac497ef 100644
--- a/docs/en/advanced_tutorials/model_analysis.md
+++ b/docs/en/advanced_tutorials/model_analysis.md
@@ -1,6 +1,6 @@
# Model Complexity Analysis
-We provide a tool to help with the complexity analysis for the network. We borrow the idea from the implementation of [fvcore](https://github.com/facebookresearch/fvcore) to build this tool, and plan to support more custom operators in the future. Currently, it provides the interfaces to compute "FLOPs", "Activations" and "Parameters", of the given model, and supports printing the related information layer-by-layer in terms of network structure or table. The analysis tool provides both operator-level and module-level flop counts simultaneously. Please refer to [Flop Count](https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md) for implementation details of how to accurately measure the flops of one operator if interested.
+We provide a tool for analyzing the complexity of a network. We borrow the idea from the implementation of [fvcore](https://github.com/facebookresearch/fvcore) to build this tool, and plan to support more custom operators in the future. Currently, it provides the interfaces to compute "FLOPs", "Activations", and "Parameters" of the given model, and supports printing the related information layer by layer, in terms of network structure or as a table. The analysis tool provides both operator-level and module-level flop counts simultaneously. Please refer to [Flop Count](https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md) for implementation details of how to accurately measure the flops of one operator if interested.
## Definition
@@ -8,7 +8,7 @@ The model complexity has three indicators, namely floating-point operations (FLO
- FLOPs
- Floating-point operations (FLOPs) is not a clearly defined indicator. Here, we refer to the description in [detectron2](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis), which defines a set of multiply-accumulate operations as 1 FLOP.
+ Floating-point operations (FLOPs) is not a clearly defined indicator. Here, we refer to the description in [detectron2](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis), which defines a set of multiply-accumulate operations as 1 FLOP.
- Activations
@@ -180,4 +180,4 @@ We provide more options to support custom output
- `input_shape`: (tuple) the shape of the input, e.g., (3, 224, 224)
- `inputs`: (optional: torch.Tensor), if given, `input_shape` will be ignored
- `show_table`: (bool) whether return the statistics in the form of table, default: True
-- `show_arch`: (bool) whether return the statistics by network layers, default: True
+- `show_arch`: (bool) whether to return the statistics by network layers, default: True
diff --git a/docs/en/advanced_tutorials/registry.md b/docs/en/advanced_tutorials/registry.md
index a562c0a72d..b4df62ac35 100644
--- a/docs/en/advanced_tutorials/registry.md
+++ b/docs/en/advanced_tutorials/registry.md
@@ -274,7 +274,7 @@ print(output)
### How does the parent node know about child registry?
-When working in our `MMAlpha` it might be necessary to use the `Runner` class defined in MMENGINE. This class is in charge of building most of the objects. If these objects are added to the child registry (`MMAlpha`), how is `MMEngine` able to find them? It cannot, `MMEngine` needs to switch to the Registry from `MMEngine` to `MMAlpha` according to the scope which is defined in default_runtime.py for searching the target class.
+When working in our `MMAlpha`, it might be necessary to use the `Runner` class defined in MMEngine. This class is in charge of building most of the objects. If these objects are added to the child registry (`MMAlpha`), how is `MMEngine` able to find them? It cannot; `MMEngine` needs to switch from its own registry to the one of `MMAlpha`, according to the scope defined in default_runtime.py, to search for the target class.
We can also init the scope accordingly, see example below:
diff --git a/docs/en/design/evaluation.md b/docs/en/design/evaluation.md
index 6d54e0a4ea..d6097742f5 100644
--- a/docs/en/design/evaluation.md
+++ b/docs/en/design/evaluation.md
@@ -73,7 +73,7 @@ Usually, the process of model accuracy evaluation is shown in the figure below.
**Online evaluation**: The test data is usually divided into batches. Through a loop, each batch is fed into the model in turn, yielding corresponding predictions, and the test data and model predictions are passed to the evaluator. The evaluator calls the `process()` method of the `Metric` to process the data and prediction results. When the loop ends, the evaluator calls the `evaluate()` method of the metrics to calculate the model accuracy of the corresponding metrics.
-**Offline evaluation**: Similar to the online evaluation process, the difference is that the pre-saved model predictions are read directly to perform the evaluation. The evaluator provides the `offline_evaluate` interface for calling the `Metric`s to calculate the model accuracy in an offline way. In order to avoid memory overflow caused by processing a large amount of data at the same time, the offline evaluation divides the test data and prediction results into chunks for processing, similar to the batches in online evaluation.
+**Offline evaluation**: Similar to online evaluation, except that pre-saved model predictions are read directly to perform the evaluation. The evaluator provides the `offline_evaluate` interface for calling the `Metric`s to calculate the model accuracy in an offline way. To avoid memory overflow caused by processing a large amount of data at once, offline evaluation divides the test data and prediction results into chunks for processing, similar to the batches in online evaluation.
diff --git a/docs/en/design/infer.md b/docs/en/design/infer.md
index fa7582819f..c4ac474aff 100644
--- a/docs/en/design/infer.md
+++ b/docs/en/design/infer.md
@@ -92,7 +92,7 @@ When performing inference, the following steps are typically executed:
3. visualize: Visualization of predicted results.
4. postprocess: Post-processing of predicted results, including result format conversion, exporting predicted results, etc.
-To improve the user experience of the inferencer, we do not want users to have to configure parameters for each step when performing inference. In other words, we hope that users can simply configure parameters for the `__call__` interface without being aware of the above process and complete the inference.
+To improve the user experience of the inferencer, we do not want users to have to configure parameters for each step when performing inference. In other words, we hope that users can complete the inference by simply configuring parameters for the `__call__` interface, without being aware of the above process.
The `__call__` interface will execute the aforementioned steps in order, but it is not aware of which step the parameters provided by the user should be assigned to. Therefore, when developing a `CustomInferencer`, developers need to define four class attributes: `preprocess_kwargs`, `forward_kwargs`, `visualize_kwargs`, and `postprocess_kwargs`. Each attribute is a set of strings that are used to specify which step the parameters in the `__call__` interface correspond to:
diff --git a/docs/en/design/logging.md b/docs/en/design/logging.md
index 68a976bfc1..9b68eb8ece 100644
--- a/docs/en/design/logging.md
+++ b/docs/en/design/logging.md
@@ -10,7 +10,7 @@

-Each scalar (losses, learning rates, etc.) during training is encapsulated by HistoryBuffer, managed by MessageHub in key-value pairs, formatted by LogProcessor and then exported to various visualization backends by [LoggerHook](mmengine.hooks.LoggerHook). **In most cases, statistical methods of these scalars can be configured through the LogProcessor without understanding the data flow.** Before diving into the design of the logging system, please read through [logging tutorial](../advanced_tutorials/logging.md) first for familiarizing basic use cases.
+Each scalar (losses, learning rates, etc.) during training is encapsulated by HistoryBuffer, managed by MessageHub in key-value pairs, formatted by LogProcessor and then exported to various visualization backends by [LoggerHook](mmengine.hooks.LoggerHook). **In most cases, statistical methods of these scalars can be configured through the LogProcessor without understanding the data flow.** Before diving into the design of the logging system, please read through the [logging tutorial](../advanced_tutorials/logging.md) first to familiarize yourself with the basic use cases.
## HistoryBuffer
diff --git a/docs/en/design/visualization.md b/docs/en/design/visualization.md
index f9161d1447..836c3cf3ee 100644
--- a/docs/en/design/visualization.md
+++ b/docs/en/design/visualization.md
@@ -11,7 +11,7 @@ Visualization provides an intuitive explanation of the training and testing proc
Based on the above requirements, we proposed the `Visualizer` and various `VisBackend` such as `LocalVisBackend`, `WandbVisBackend`, and `TensorboardVisBackend` in OpenMMLab 2.0. The visualizer could not only visualize the image data, but also things like configurations, scalars, and model structure.
-- For convenience, the APIs provided by the `Visualizer` implement the drawing and storage functions. As an internal property of `Visualizer`, `VisBackend` will be called by `Visualizer` to write data to different backends.
+- For convenience, the APIs provided by the `Visualizer` implement the drawing and storage functions. As an internal property of `Visualizer`, `VisBackend` will be called by `Visualizer` to write data to different backends.
- Considering that you may want to write data to multiple backends after drawing, `Visualizer` can be configured with multiple backends. When the user calls the storage API of the `Visualizer`, it will traverse and call all the specified APIs of `VisBackend` internally.
The UML diagram of the two is as follows.
diff --git a/docs/en/examples/train_a_gan.md b/docs/en/examples/train_a_gan.md
index 2617fda767..d84be9d3d5 100644
--- a/docs/en/examples/train_a_gan.md
+++ b/docs/en/examples/train_a_gan.md
@@ -148,7 +148,7 @@ from mmengine.model import ImgDataPreprocessor
data_preprocessor = ImgDataPreprocessor(mean=([127.5]), std=([127.5]))
```
-The following code implements the basic algorithm of GAN. To implement the algorithm using MMEngine, you need to inherit from the [BaseModel](mmengine.model.BaseModel) and implement the training process in the train_step. GAN requires alternating training of the generator and discriminator, which are implemented by train_discriminator and train_generator and implement disc_loss and gen_loss to calculate the discriminator loss function and generator loss function.
+The following code implements the basic GAN algorithm. To implement it using MMEngine, you need to inherit from [BaseModel](mmengine.model.BaseModel) and implement the training process in `train_step`. GAN requires alternating training of the generator and the discriminator, implemented in `train_generator` and `train_discriminator`; `gen_loss` and `disc_loss` calculate the generator and discriminator loss functions.
More details about BaseModel, refer to [Model tutorial](../tutorials/model.md).
```python
diff --git a/docs/en/migration/param_scheduler.md b/docs/en/migration/param_scheduler.md
index 64867252e6..1c77ab8be4 100644
--- a/docs/en/migration/param_scheduler.md
+++ b/docs/en/migration/param_scheduler.md
@@ -435,7 +435,7 @@ param_scheduler = [
-Notice: `by_epoch` defaults to `False` in MMCV. It now defaults to `True` in MMEngine.
+Notice: `by_epoch` defaults to `False` in MMCV. It now defaults to `True` in MMEngine.
### LinearAnnealingLrUpdaterHook migration
diff --git a/docs/en/notes/changelog.md b/docs/en/notes/changelog.md
index 0aa64bfdbc..27062237e9 100644
--- a/docs/en/notes/changelog.md
+++ b/docs/en/notes/changelog.md
@@ -117,7 +117,7 @@ A total of 3 developers contributed to this release. Thanks [@HIT-cwh](https://g
### Contributors
-A total of 9 developers contributed to this release. Thanks [@POI-WX](https://github.com/POI-WX), [@whlook](https://github.com/whlook), [@jonbakerfish](https://github.com/jonbakerfish), [@LZHgrla](https://github.com/LZHgrla), [@Ben-Louis](https://github.com/Ben-Louis), [@YiyaoYang1](https://github.com/YiyaoYang1), [@fanqiNO1](https://github.com/fanqiNO1), [@HAOCHENYE](https://github.com/HAOCHENYE), [@zhouzaida](https://github.com/zhouzaida)
+A total of 9 developers contributed to this release. Thanks [@POI-WX](https://github.com/POI-WX), [@whlook](https://github.com/whlook), [@jonbakerfish](https://github.com/jonbakerfish), [@LZHgrla](https://github.com/LZHgrla), [@Ben-Louis](https://github.com/Ben-Louis), [@YiyaoYang1](https://github.com/YiyaoYang1), [@fanqiNO1](https://github.com/fanqiNO1), [@HAOCHENYE](https://github.com/HAOCHENYE), [@zhouzaida](https://github.com/zhouzaida)
## v0.9.0 (10/10/2023)
@@ -345,7 +345,7 @@ A total of 9 developers contributed to this release. Thanks [@evdcush](https://g
### Contributors
-A total of 19 developers contributed to this release. Thanks [@Hongru-Xiao](https://github.com/Hongru-Xiao) [@i-aki-y](https://github.com/i-aki-y) [@Bomsw](https://github.com/Bomsw) [@KickCellarDoor](https://github.com/KickCellarDoor) [@zhouzaida](https://github.com/zhouzaida) [@YQisme](https://github.com/YQisme) [@gachiemchiep](https://github.com/gachiemchiep) [@CescMessi](https://github.com/CescMessi) [@W-ZN](https://github.com/W-ZN) [@Ginray](https://github.com/Ginray) [@adrianjoshua-strutt](https://github.com/adrianjoshua-strutt) [@CokeDong](https://github.com/CokeDong) [@xin-li-67](https://github.com/xin-li-67) [@Xiangxu-0103](https://github.com/Xiangxu-0103) [@HAOCHENYE](https://github.com/HAOCHENYE) [@Shiyang980713](https://github.com/Shiyang980713) [@TankNee](https://github.com/TankNee) [@zimonitrome](https://github.com/zimonitrome) [@gy-7](https://github.com/gy-7)
+A total of 19 developers contributed to this release. Thanks [@Hongru-Xiao](https://github.com/Hongru-Xiao) [@i-aki-y](https://github.com/i-aki-y) [@Bomsw](https://github.com/Bomsw) [@KickCellarDoor](https://github.com/KickCellarDoor) [@zhouzaida](https://github.com/zhouzaida) [@YQisme](https://github.com/YQisme) [@gachiemchiep](https://github.com/gachiemchiep) [@CescMessi](https://github.com/CescMessi) [@W-ZN](https://github.com/W-ZN) [@Ginray](https://github.com/Ginray) [@adrianjoshua-strutt](https://github.com/adrianjoshua-strutt) [@CokeDong](https://github.com/CokeDong) [@xin-li-67](https://github.com/xin-li-67) [@Xiangxu-0103](https://github.com/Xiangxu-0103) [@HAOCHENYE](https://github.com/HAOCHENYE) [@Shiyang980713](https://github.com/Shiyang980713) [@TankNee](https://github.com/TankNee) [@zimonitrome](https://github.com/zimonitrome) [@gy-7](https://github.com/gy-7)
## v0.7.3 (04/28/2023)
@@ -369,7 +369,7 @@ A total of 19 developers contributed to this release. Thanks [@Hongru-Xiao](http
- Enhance the support for MLU device by [@josh6688](https://github.com/josh6688) in https://github.com/open-mmlab/mmengine/pull/1075
- Support configuring synchronization directory for BaseMetric by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1074
- Support accepting multiple `input_shape` for `get_model_complexity_info` by [@sjiang95](https://github.com/sjiang95) in https://github.com/open-mmlab/mmengine/pull/1065
-- Enhance docstring and error catching in `MessageHub` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1098
+- Enhance docstring and error catching in `MessageHub` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1098
- Enhance the efficiency of Visualizer.show by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1015
- Update repo list by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1108
- Enhance error message during custom import by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1102
@@ -613,10 +613,10 @@ A total of 8 developers contributed to this release. Thanks [@LEFTeyex](https://
- Fix typos in EN `contributing.md` by [@RangeKing](https://github.com/RangeKing) in https://github.com/open-mmlab/mmengine/pull/792
- Translate data transform docs. by [@mzr1996](https://github.com/mzr1996) in https://github.com/open-mmlab/mmengine/pull/737
- Replace markdown table with html table by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/800
-- Fix wrong example in `Visualizer.draw_polygons` by [@lyviva](https://github.com/lyviva) in https://github.com/open-mmlab/mmengine/pull/798
+- Fix wrong example in `Visualizer.draw_polygons` by [@lyviva](https://github.com/lyviva) in https://github.com/open-mmlab/mmengine/pull/798
- Fix docstring format and rescale the images by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/802
- Fix failed link in registry by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/811
-- Fix typos by [@shanmo](https://github.com/shanmo) in https://github.com/open-mmlab/mmengine/pull/814
+- Fix typos by [@shanmo](https://github.com/shanmo) in https://github.com/open-mmlab/mmengine/pull/814
- Fix wrong links and typos in docs by [@shanmo](https://github.com/shanmo) in https://github.com/open-mmlab/mmengine/pull/815
- Translate `save_gpu_memory.md` by [@xin-li-67](https://github.com/xin-li-67) in https://github.com/open-mmlab/mmengine/pull/803
- Translate the documentation of hook design by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/780
@@ -691,7 +691,7 @@ A total of 16 developers contributed to this release. Thanks [@BayMaxBHL](https:
- Add documents for `clip_grad`, and support clip grad by value. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/513
- Add ROCm info when collecting env by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/633
- Add a function to mark the deprecated function. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/609
-- Call `register_all_modules` in `Registry.get()` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/541
+- Call `register_all_modules` in `Registry.get()` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/541
- Deprecate `_save_to_state_dict` implemented in mmengine by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/610
- Add `ignore_keys` in ConcatDataset by [@BIGWangYuDong](https://github.com/BIGWangYuDong) in https://github.com/open-mmlab/mmengine/pull/556
@@ -768,7 +768,7 @@ A total of 16 developers contributed to this release. Thanks [@BayMaxBHL](https:
- Fix uploading image in wandb backend [@okotaku](https://github.com/okotaku) in https://github.com/open-mmlab/mmengine/pull/510
- Fix loading state dictionary in `EMAHook` by [@okotaku](https://github.com/okotaku) in https://github.com/open-mmlab/mmengine/pull/507
- Fix circle import in `EMAHook` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/523
-- Fix unit test could fail caused by `MultiProcessTestCase` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/535
+- Fix unit test could fail caused by `MultiProcessTestCase` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/535
- Remove unnecessary "if statement" in `Registry` by [@MambaWong](https://github.com/MambaWong) in https://github.com/open-mmlab/mmengine/pull/536
- Fix `_save_to_state_dict` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/542
- Support comparing NumPy array dataset meta in `Runner.resume` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/511
diff --git a/docs/en/notes/contributing.md b/docs/en/notes/contributing.md
index deb398f02f..291cda46c4 100644
--- a/docs/en/notes/contributing.md
+++ b/docs/en/notes/contributing.md
@@ -77,7 +77,7 @@ pre-commit run --all-files
If the installation process is interrupted, you can repeatedly run `pre-commit run ... ` to continue the installation.
-If the code does not conform to the code style specification, pre-commit will raise a warning and fixes some of the errors automatically.
+If the code does not conform to the code style specification, pre-commit will raise a warning and fix some of the errors automatically.
diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py
index 3f89ff760d..3d945a6a54 100644
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -310,10 +310,10 @@ def __init__(
self.config.setdefault('gradient_accumulation_steps', 1)
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
- assert (exclude_frozen_parameters is None or
- digit_version(deepspeed.__version__) >= digit_version('0.13.2')
- ), ('DeepSpeed >= 0.13.2 is required to enable '
- 'exclude_frozen_parameters')
+ assert (exclude_frozen_parameters is None or digit_version(
+ deepspeed.__version__) >= digit_version('0.13.2')), (
+ 'DeepSpeed >= 0.13.2 is required to enable '
+ 'exclude_frozen_parameters')
self.exclude_frozen_parameters = exclude_frozen_parameters
register_deepspeed_optimizers()
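The reformatted assert above gates `exclude_frozen_parameters` behind a DeepSpeed version check via mmengine's `digit_version` helper. A minimal standalone sketch of the same pattern; `digit_version` below is a simplified stand-in (the real helper also parses pre-release suffixes), and `check_exclude_frozen` is a hypothetical wrapper for illustration:

```python
# Minimal sketch of the exclude_frozen_parameters version gate.
# digit_version here is a simplified stand-in for
# mmengine.utils.digit_version (the real helper also handles
# suffixes such as "0.13.2rc1").

def digit_version(version: str) -> tuple:
    """Convert a dotted version string to a comparable tuple of ints."""
    return tuple(int(part) for part in version.split('.'))

def check_exclude_frozen(exclude_frozen_parameters, deepspeed_version: str):
    # Feature may be unset (None) on any version; if requested,
    # DeepSpeed must be at least 0.13.2.
    assert (exclude_frozen_parameters is None
            or digit_version(deepspeed_version) >= digit_version('0.13.2')), (
                'DeepSpeed >= 0.13.2 is required to enable '
                'exclude_frozen_parameters')

check_exclude_frozen(None, '0.12.0')  # passes: feature not requested
check_exclude_frozen(True, '0.14.0')  # passes: version is new enough
```

Tuple comparison makes `(0, 14, 0) >= (0, 13, 2)` behave as expected, which a plain string comparison would not (e.g. `'0.9' > '0.13'` lexicographically).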
diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 7df8dcc52c..183138eea9 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -46,9 +46,10 @@
def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
- return dict_type(
- {k: _lazy2string(v, dict_type)
- for k, v in dict.items(cfg_dict)})
+ return dict_type({
+ k: _lazy2string(v, dict_type)
+ for k, v in dict.items(cfg_dict)
+ })
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
@@ -271,13 +272,15 @@ def __reduce_ex__(self, proto):
# called by CPython interpreter during pickling. See more details in
# https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501
if digit_version(platform.python_version()) < digit_version('3.8'):
- return (self.__class__, ({k: v
- for k, v in super().items()}, ), None,
- None, None)
+ return (self.__class__, ({
+ k: v
+ for k, v in super().items()
+ }, ), None, None, None)
else:
- return (self.__class__, ({k: v
- for k, v in super().items()}, ), None,
- None, None, None)
+ return (self.__class__, ({
+ k: v
+ for k, v in super().items()
+ }, ), None, None, None, None)
def __eq__(self, other):
if isinstance(other, ConfigDict):
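The `_lazy2string` change above is purely a layout reformat; the recursion itself walks a nested config while preserving the outermost dict's concrete type. A simplified sketch under the assumption that lazy objects are out of scope (`to_string` is a hypothetical stand-in with the `LazyAttr`/`LazyObject` branch omitted):

```python
# Simplified sketch of the _lazy2string recursion: walk a nested
# config, rebuilding every dict with the outermost dict's type and
# every tuple/list with its own type. Lazy-object handling omitted.
from collections import OrderedDict

def to_string(cfg, dict_type=None):
    if isinstance(cfg, dict):
        dict_type = dict_type or type(cfg)
        return dict_type({k: to_string(v, dict_type) for k, v in cfg.items()})
    elif isinstance(cfg, (tuple, list)):
        return type(cfg)(to_string(v, dict_type) for v in cfg)
    else:
        return cfg  # leaf value: returned unchanged

cfg = OrderedDict(model=dict(depth=50), stages=[dict(planes=64)])
out = to_string(cfg)
```

Note that `dict_type` is threaded through the recursion, so nested plain dicts are rebuilt as the outer type (`OrderedDict` here) rather than keeping their own.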
diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py
index 2c9cf96497..d140cc8dc4 100644
--- a/mmengine/dataset/utils.py
+++ b/mmengine/dataset/utils.py
@@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any:
return [default_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
- key: default_collate([d[key] for d in data_batch])
+ key:
+ default_collate([d[key] for d in data_batch])
for key in data_item
})
else:
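The reformatted Mapping branch of `default_collate` above collates a batch of dicts key by key into a single dict of batched values. A minimal self-contained sketch, with `collate` as a hypothetical stand-in and plain Python values instead of tensors:

```python
# Sketch of default_collate's Mapping branch: a batch of dicts is
# collated key-by-key, preserving the mapping's concrete type.
from collections.abc import Mapping, Sequence

def collate(batch):
    item = batch[0]
    if isinstance(item, Mapping):
        return type(item)({
            key: collate([d[key] for d in batch])
            for key in item
        })
    elif isinstance(item, Sequence) and not isinstance(item, str):
        # Transpose the batch and collate each column.
        return [collate(list(samples)) for samples in zip(*batch)]
    else:
        return batch  # leaf values stay as the per-sample list

batch = [{'img': 1, 'label': 0}, {'img': 2, 'label': 1}]
out = collate(batch)
```

In the real implementation the leaf branch stacks tensors; here leaves simply stay as lists, which is enough to show how the keys are regrouped.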
diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py
index c7d5f04621..d60c539400 100644
--- a/mmengine/fileio/backends/local_backend.py
+++ b/mmengine/fileio/backends/local_backend.py
@@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
"""
return osp.isfile(filepath)
- def join_path(self, filepath: Union[str, Path],
- *filepaths: Union[str, Path]) -> str:
+ def join_path(self, filepath: Union[str, Path], *filepaths:
+ Union[str, Path]) -> Union[str, Path]:
r"""Concatenate all file paths.
Join one or more filepath components intelligently. The return value
@@ -167,7 +167,8 @@ def join_path(self, filepath: Union[str, Path],
filepath (str or Path): Path to be concatenated.
Returns:
- str: The result of concatenation.
+ str or Path: The result of concatenation,
+ with the same type as filepath.
Examples:
>>> backend = LocalBackend()
@@ -177,8 +178,8 @@ def join_path(self, filepath: Union[str, Path],
>>> backend.join_path(filepath1, filepath2, filepath3)
'/path/of/dir/dir2/path/of/file'
"""
- # TODO, if filepath or filepaths are Path, should return Path
- return osp.join(filepath, *filepaths)
+ result = osp.join(filepath, *filepaths)
+ return type(filepath)(result)
@contextmanager
def get_local_path(
@@ -203,7 +204,7 @@ def copyfile(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Copy a file src to dst and return the destination file.
src and dst should have the same prefix. If dst specifies a directory,
@@ -215,7 +216,7 @@ def copyfile(
dst (str or Path): Copy file to dst.
Returns:
- str: The destination file.
+ str or Path: The destination file, with the same type as dst.
Raises:
SameFileError: If src and dst are the same file, a SameFileError
@@ -236,13 +237,14 @@ def copyfile(
>>> backend.copyfile(src, dst)
'/path1/of/dir/file'
"""
- return shutil.copy(src, dst)
+ result = shutil.copy(str(src), str(dst))
+ return type(dst)(result)
def copytree(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Recursively copy an entire directory tree rooted at src to a
directory named dst and return the destination directory.
@@ -255,7 +257,7 @@ def copytree(
dst (str or Path): Copy directory to dst.
Returns:
- str: The destination directory.
+ str or Path: The destination directory, with the same type as dst.
Raises:
FileExistsError: If dst had already existed, a FileExistsError will
@@ -268,13 +270,14 @@ def copytree(
>>> backend.copytree(src, dst)
'/path/of/dir2'
"""
- return shutil.copytree(src, dst)
+ result = shutil.copytree(str(src), str(dst))
+ return type(dst)(result)
def copyfile_from_local(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Copy a local file src to dst and return the destination file. Same
as :meth:`copyfile`.
@@ -283,8 +286,8 @@ def copyfile_from_local(
dst (str or Path): Copy file to dst.
Returns:
- str: If dst specifies a directory, the file will be copied into dst
- using the base filename from src.
+ str or Path: If dst specifies a directory, the file will be copied into dst
+ using the base filename from src, with the same type as dst.
Raises:
SameFileError: If src and dst are the same file, a SameFileError
@@ -311,7 +314,7 @@ def copytree_from_local(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Recursively copy an entire directory tree rooted at src to a
directory named dst and return the destination directory. Same as
:meth:`copytree`.
@@ -321,7 +324,7 @@ def copytree_from_local(
dst (str or Path): Copy directory to dst.
Returns:
- str: The destination directory.
+ str or Path: The destination directory, with the same type as dst.
Examples:
>>> backend = LocalBackend()
@@ -336,7 +339,7 @@ def copyfile_to_local(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Copy the file src to local dst and return the destination file. Same
as :meth:`copyfile`.
@@ -349,8 +352,8 @@ def copyfile_to_local(
dst (str or Path): Copy file to local dst.
Returns:
- str: If dst specifies a directory, the file will be copied into dst
- using the base filename from src.
+ str or Path: If dst specifies a directory, the file will be copied into dst
+ using the base filename from src, with the same type as dst.
Examples:
>>> backend = LocalBackend()
@@ -373,7 +376,7 @@ def copytree_to_local(
self,
src: Union[str, Path],
dst: Union[str, Path],
- ) -> str:
+ ) -> Union[str, Path]:
"""Recursively copy an entire directory tree rooted at src to a local
directory named dst and return the destination directory.
@@ -384,7 +387,7 @@ def copytree_to_local(
prefix of uri corresponding backend. Defaults to None.
Returns:
- str: The destination directory.
+ str or Path: The destination directory, with the same type as dst.
Examples:
>>> backend = LocalBackend()
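The `local_backend.py` changes above all follow one pattern: do the work on `str`, then cast the result back to the caller's path type via `type(dst)(result)`. A minimal standalone sketch of that pattern (a hypothetical `join_path` helper, not the mmengine class itself):

```python
import os.path as osp
from pathlib import Path


def join_path(filepath, *filepaths):
    # osp.join always returns str, so cast the result back to the
    # input's type (str -> str, Path -> Path), mirroring the
    # `type(filepath)(result)` idiom in the diff.
    result = osp.join(str(filepath), *map(str, filepaths))
    return type(filepath)(result)
```

The same cast works for `shutil.copy`/`shutil.copytree` results because `str` and `Path` constructors both accept a path string.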
diff --git a/mmengine/fileio/backends/registry_utils.py b/mmengine/fileio/backends/registry_utils.py
index 4578a4ca76..e2f41fb248 100644
--- a/mmengine/fileio/backends/registry_utils.py
+++ b/mmengine/fileio/backends/registry_utils.py
@@ -28,8 +28,6 @@ def _register_backend(name: str,
prefixes (str or list[str] or tuple[str], optional): The prefix
of the registered storage backend. Defaults to None.
"""
- global backends, prefix_to_backends
-
if not isinstance(name, str):
raise TypeError('the backend name should be a string, '
f'but got {type(name)}')
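The removed `global backends, prefix_to_backends` statement was dead code: `global` is only required to *rebind* a module-level name, not to mutate the object it refers to. A small sketch of the distinction (toy `backends` dict, not the real registry):

```python
backends = {}


def register_mutate(name, backend):
    # In-place mutation of the module-level dict: no `global` needed.
    backends[name] = backend


def register_rebind(name, backend):
    # Rebinding the name to a brand-new dict does require `global`.
    global backends
    backends = {**backends, name: backend}


register_mutate('local', object)
register_rebind('petrel', object)
```

The same reasoning applies to the `global backend_instances` removed from `mmengine/fileio/io.py` below.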
diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py
index bbb81b3dfc..603c501e3b 100644
--- a/mmengine/fileio/file_client.py
+++ b/mmengine/fileio/file_client.py
@@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
"""
return self.client.isfile(filepath)
- def join_path(self, filepath: Union[str, Path],
- *filepaths: Union[str, Path]) -> str:
+ def join_path(self, filepath: Union[str, Path], *filepaths:
+ Union[str, Path]) -> Union[str, Path]:
r"""Concatenate all file paths.
Join one or more filepath components intelligently. The return value
@@ -396,7 +396,7 @@ def join_path(self, filepath: Union[str, Path],
filepath (str or Path): Path to be concatenated.
Returns:
- str: The result of concatenation.
+ str or Path: The result of concatenation.
"""
return self.client.join_path(filepath, *filepaths)
diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py
index fdeb4dc6df..d849abf658 100644
--- a/mmengine/fileio/io.py
+++ b/mmengine/fileio/io.py
@@ -128,8 +128,6 @@ def get_file_backend(
>>> # backend name has a higher priority if 'backend' in backend_args
>>> backend = get_file_backend(uri, backend_args={'backend': 'petrel'})
"""
- global backend_instances
-
if backend_args is None:
backend_args = {}
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 92a4867bb9..3adb78c7dc 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -196,10 +196,10 @@ def __init__(self,
self.save_best = save_best
# rule logic
- assert (isinstance(rule, str) or is_list_of(rule, str)
- or (rule is None)), (
- '"rule" should be a str or list of str or None, '
- f'but got {type(rule)}')
+ assert (isinstance(rule, str) or is_list_of(rule, str) or
+ (rule
+ is None)), ('"rule" should be a str or list of str or None, '
+ f'but got {type(rule)}')
if isinstance(rule, list):
# check the length of rule list
assert len(rule) in [
diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py
index c623eec8bc..2f19248c2c 100644
--- a/mmengine/model/test_time_aug.py
+++ b/mmengine/model/test_time_aug.py
@@ -124,9 +124,10 @@ def test_step(self, data):
data_list: Union[List[dict], List[list]]
if isinstance(data, dict):
num_augs = len(data[next(iter(data))])
- data_list = [{key: value[idx]
- for key, value in data.items()}
- for idx in range(num_augs)]
+ data_list = [{
+ key: value[idx]
+ for key, value in data.items()
+ } for idx in range(num_augs)]
elif isinstance(data, (tuple, list)):
num_augs = len(data[0])
data_list = [[_data[idx] for _data in data]
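The reformatted comprehension in `test_step` transposes a batch dict whose values are per-augmentation lists into one dict per augmentation. In isolation, with hypothetical toy data:

```python
# Each value holds one entry per test-time augmentation.
data = {'inputs': ['img0', 'img1'], 'data_samples': ['s0', 's1']}

num_augs = len(data[next(iter(data))])
data_list = [{
    key: value[idx]
    for key, value in data.items()
} for idx in range(num_augs)]
```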
diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py
index 5160a5cfb0..eff4d28ca2 100644
--- a/mmengine/runner/_flexible_runner.py
+++ b/mmengine/runner/_flexible_runner.py
@@ -863,8 +863,8 @@ def build_dataloader(
worker_init_fn_type = worker_init_fn_cfg.pop('type')
worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
assert callable(worker_init_fn)
- init_fn = partial(worker_init_fn,
- **worker_init_fn_cfg) # type: ignore
+ init_fn = partial( # type: ignore
+ worker_init_fn, **worker_init_fn_cfg)
else:
if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
@@ -874,7 +874,7 @@ def build_dataloader(
f'{type(disable_subprocess_warning)}')
init_fn = partial(
default_worker_init_fn,
- num_workers=dataloader_cfg.get('num_workers'),
+ num_workers=dataloader_cfg.get('num_workers', 0),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
@@ -1611,7 +1611,7 @@ def callback(checkpoint):
self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
self.strategy.save_checkpoint(
- filename=filepath,
+ filename=str(filepath),
save_optimizer=save_optimizer,
save_param_scheduler=save_param_scheduler,
extra_ckpt=checkpoint,
diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index b9c62a8f70..a5809013a1 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -661,9 +661,10 @@ def _load_checkpoint_to_model(model,
# strip prefix of state_dict
metadata = getattr(state_dict, '_metadata', OrderedDict())
for p, r in revise_keys:
- state_dict = OrderedDict(
- {re.sub(p, r, k): v
- for k, v in state_dict.items()})
+ state_dict = OrderedDict({
+ re.sub(p, r, k): v
+ for k, v in state_dict.items()
+ })
# Keep metadata in state_dict
state_dict._metadata = metadata
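For reference, the reformatted loop strips checkpoint key prefixes with `re.sub`. A sketch assuming the default `revise_keys=[(r'^module\.', '')]` used for `DataParallel`-wrapped checkpoints:

```python
import re
from collections import OrderedDict

state_dict = OrderedDict({'module.conv.weight': 1, 'module.bn.bias': 2})
for p, r in [(r'^module\.', '')]:
    # Rewrite every key; re.sub leaves non-matching keys unchanged.
    state_dict = OrderedDict({
        re.sub(p, r, k): v
        for k, v in state_dict.items()
    })
```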
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index ba8039d476..9b7b64e2c1 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1425,8 +1425,8 @@ def build_dataloader(dataloader: Union[DataLoader, Dict],
'type of worker_init_fn should be string or callable '
f'object, but got {type(worker_init_fn_type)}')
assert callable(worker_init_fn)
- init_fn = partial(worker_init_fn,
- **worker_init_fn_cfg) # type: ignore
+ init_fn = partial( # type: ignore
+ worker_init_fn, **worker_init_fn_cfg)
else:
if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
@@ -1436,7 +1436,7 @@ def build_dataloader(dataloader: Union[DataLoader, Dict],
f'{type(disable_subprocess_warning)}')
init_fn = partial(
default_worker_init_fn,
- num_workers=dataloader_cfg.get('num_workers'),
+ num_workers=dataloader_cfg.get('num_workers', 0),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
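The `num_workers` default matters because single-argument `dict.get` returns `None` for a missing key, which would then be forwarded to the worker-init function; the explicit `0` matches `DataLoader`'s own default:

```python
dataloader_cfg = {'batch_size': 2}  # hypothetical cfg without num_workers

without_default = dataloader_cfg.get('num_workers')  # -> None
with_default = dataloader_cfg.get('num_workers', 0)  # -> 0
```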
diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py
index a083b5b505..9da5231cf1 100644
--- a/mmengine/structures/instance_data.py
+++ b/mmengine/structures/instance_data.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
-from typing import Any, List, Union
+from typing import Any, List, Union, cast, get_args
import numpy as np
import torch
@@ -9,8 +9,8 @@
from mmengine.device import get_device
from .base_data_element import BaseDataElement
-BoolTypeTensor: Union[Any]
-LongTypeTensor: Union[Any]
+BoolTypeTensor: Any
+LongTypeTensor: Any
if get_device() == 'npu':
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
@@ -25,8 +25,8 @@
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
-IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
- BoolTypeTensor, np.ndarray]
+IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor,
+ np.ndarray]
# Modified from
@@ -172,7 +172,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData':
Returns:
:obj:`InstanceData`: Corresponding values.
"""
- assert isinstance(item, IndexType.__args__)
+ assert isinstance(item, get_args(IndexType))
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
@@ -284,7 +284,7 @@ def cat(instances_list: List['InstanceData']) -> 'InstanceData':
new_data = instances_list[0].__class__(
metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys():
- values = [results[k] for results in instances_list]
+ values = [cast(Any, results[k]) for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
new_values = torch.cat(values, dim=0)
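`typing.get_args` is the documented replacement for the private `__args__` attribute: it accepts any alias and returns `()` for plain classes instead of raising `AttributeError`. A sketch with a simplified stand-in for `IndexType` (torch-free for brevity):

```python
from typing import Union, get_args

# Simplified stand-in for mmengine's IndexType alias.
IndexType = Union[str, slice, int, list]

# get_args returns the Union members as a tuple, which can be fed
# straight into isinstance, as the diff does in __getitem__.
members = get_args(IndexType)
```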
diff --git a/mmengine/utils/dl_utils/torch_ops.py b/mmengine/utils/dl_utils/torch_ops.py
index 2550ae6986..85dc3100d2 100644
--- a/mmengine/utils/dl_utils/torch_ops.py
+++ b/mmengine/utils/dl_utils/torch_ops.py
@@ -4,9 +4,9 @@
from ..version_utils import digit_version
from .parrots_wrapper import TORCH_VERSION
-_torch_version_meshgrid_indexing = (
- 'parrots' not in TORCH_VERSION
- and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
+_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION)
+ >= digit_version('1.10.0a0'))
def torch_meshgrid(*tensors):
diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py
index 79188247df..41bc86fdf1 100644
--- a/mmengine/utils/package_utils.py
+++ b/mmengine/utils/package_utils.py
@@ -68,7 +68,7 @@ def get_installed_path(package: str) -> str:
top_level = dist.read_text('top_level.txt')
if top_level:
module_name = top_level.split('\n')[0].strip()
- possible_path = osp.join(dist.locate_file(''), module_name)
+ possible_path = osp.join(str(dist.locate_file('')), module_name)
if osp.exists(possible_path):
return possible_path
diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py
index 6979395aca..58e2293013 100644
--- a/mmengine/visualization/visualizer.py
+++ b/mmengine/visualization/visualizer.py
@@ -754,8 +754,8 @@ def draw_bboxes(
assert bboxes.shape[-1] == 4, (
f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}')
- assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <=
- bboxes[:, 3]).all()
+ assert ((bboxes[:, 0] <= bboxes[:, 2]).all()) and (
+ (bboxes[:, 1] <= bboxes[:, 3]).all())
if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))):
warnings.warn(
'Warning: The bbox is out of bounds,'
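The reparenthesized assertion checks, per box, that `x1 <= x2` and `y1 <= y2` for `(N, 4)` boxes in `(x1, y1, x2, y2)` order. The same check in isolation, with toy boxes:

```python
import numpy as np

bboxes = np.array([[0, 0, 10, 10], [2, 3, 8, 9]])
assert bboxes.shape[-1] == 4
# Every box must satisfy x1 <= x2 and y1 <= y2.
valid = ((bboxes[:, 0] <= bboxes[:, 2]).all()
         and (bboxes[:, 1] <= bboxes[:, 3]).all())
```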
diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py
index be10309d0f..4b1dfaf595 100644
--- a/tests/test_analysis/test_jit_analysis.py
+++ b/tests/test_analysis/test_jit_analysis.py
@@ -634,9 +634,10 @@ def dummy_ops_handle(inputs: List[Any],
dummy_flops = {}
for name, counts in model.flops.items():
- dummy_flops[name] = Counter(
- {op: flop
- for op, flop in counts.items() if op != self.lin_op})
+ dummy_flops[name] = Counter({
+ op: flop
+ for op, flop in counts.items() if op != self.lin_op
+ })
dummy_flops[''][dummy_name] = 2 * dummy_out
dummy_flops['fc'][dummy_name] = dummy_out
dummy_flops['submod'][dummy_name] = dummy_out
diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py
index f4ec815ec2..48bba665fe 100644
--- a/tests/test_dataset/test_base_dataset.py
+++ b/tests/test_dataset/test_base_dataset.py
@@ -733,13 +733,13 @@ def test_length(self):
def test_getitem(self):
assert (
self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
- assert (self.cat_datasets[0]['imgs'] !=
- self.dataset_b[0]['imgs']).all()
+ assert (self.cat_datasets[0]['imgs']
+ != self.dataset_b[0]['imgs']).all()
assert (
self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
- assert (self.cat_datasets[-1]['imgs'] !=
- self.dataset_a[-1]['imgs']).all()
+ assert (self.cat_datasets[-1]['imgs']
+ != self.dataset_a[-1]['imgs']).all()
def test_get_data_info(self):
assert self.cat_datasets.get_data_info(
diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py
index 427ebf789a..286510a48d 100644
--- a/tests/test_fileio/test_backends/test_local_backend.py
+++ b/tests/test_fileio/test_backends/test_local_backend.py
@@ -148,13 +148,12 @@ def test_join_path(self, path_type):
backend = LocalBackend()
filepath = backend.join_path(
path_type(self.test_data_dir), path_type('file'))
- expected = osp.join(path_type(self.test_data_dir), path_type('file'))
+ expected = path_type(osp.join(str(self.test_data_dir), 'file'))
self.assertEqual(filepath, expected)
filepath = backend.join_path(
path_type(self.test_data_dir), path_type('dir'), path_type('file'))
- expected = osp.join(
- path_type(self.test_data_dir), path_type('dir'), path_type('file'))
+ expected = path_type(osp.join(str(self.test_data_dir), 'dir', 'file'))
self.assertEqual(filepath, expected)
@parameterized.expand([[Path], [str]])
diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py
index c34af47e0b..f77a605270 100644
--- a/tests/test_fileio/test_io.py
+++ b/tests/test_fileio/test_io.py
@@ -245,11 +245,11 @@ def test_isfile():
def test_join_path():
# test LocalBackend
filepath = fileio.join_path(test_data_dir, 'file')
- expected = osp.join(test_data_dir, 'file')
+ expected = test_data_dir / 'file'
assert filepath == expected
filepath = fileio.join_path(test_data_dir, 'dir', 'file')
- expected = osp.join(test_data_dir, 'dir', 'file')
+ expected = test_data_dir / 'dir' / 'file'
assert filepath == expected
diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py
index d30972d360..024699e44d 100644
--- a/tests/test_hooks/test_empty_cache_hook.py
+++ b/tests/test_hooks/test_empty_cache_hook.py
@@ -15,7 +15,7 @@ def test_with_runner(self):
with patch('torch.cuda.empty_cache') as mock_empty_cache:
cfg = self.epoch_based_cfg
cfg.custom_hooks = [dict(type='EmptyCacheHook')]
- cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501
+ cfg.train_cfg.val_begin = 1e6 # disable validation during training # noqa: E501
runner = self.build_runner(cfg)
runner.train()
diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
index ba4ca77d11..763c2d054c 100644
--- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
+++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -455,8 +455,8 @@ def test_init(self):
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_step(self, dtype):
- if dtype is not None and (digit_version(TORCH_VERSION) <
- digit_version('1.10.0')):
+ if dtype is not None and (digit_version(TORCH_VERSION)
+ < digit_version('1.10.0')):
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
'support `dtype` argument in autocast')
if dtype == 'bfloat16' and not bf16_supported():
@@ -478,8 +478,8 @@ def test_step(self, dtype):
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_backward(self, dtype):
- if dtype is not None and (digit_version(TORCH_VERSION) <
- digit_version('1.10.0')):
+ if dtype is not None and (digit_version(TORCH_VERSION)
+ < digit_version('1.10.0')):
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
'support `dtype` argument in autocast')
if dtype == 'bfloat16' and not bf16_supported():
@@ -539,8 +539,8 @@ def test_load_state_dict(self):
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_optim_context(self, dtype, target_dtype):
- if dtype is not None and (digit_version(TORCH_VERSION) <
- digit_version('1.10.0')):
+ if dtype is not None and (digit_version(TORCH_VERSION)
+ < digit_version('1.10.0')):
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
'support `dtype` argument in autocast')
if dtype == 'bfloat16' and not bf16_supported():