Skip to content

Commit 075f57c

Browse files
committed
Fix max iters issue and add tests
1 parent 72fede0 commit 075f57c

22 files changed

+1122
-130
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
name: Code Style Checks
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
- "*.*.*"
8+
paths:
9+
- "**.py"
10+
- "setup.cfg"
11+
- "requirements-dev.txt"
12+
- "pyproject.toml"
13+
- ".pre-commit-config.yaml"
14+
- ".github/workflows/code-style-checks.yml"
15+
- "!assets/**"
16+
- "!docker/**"
17+
- "!docs/**"
18+
- "!conda.recipe"
19+
pull_request:
20+
paths:
21+
- "**.py"
22+
- "setup.cfg"
23+
- "requirements-dev.txt"
24+
- "pyproject.toml"
25+
- ".pre-commit-config.yaml"
26+
- ".github/workflows/code-style-checks.yml"
27+
- "!assets/**"
28+
- "!docker/**"
29+
- "!docs/**"
30+
- "!conda.recipe"
31+
workflow_dispatch:
32+
33+
concurrency:
34+
# <workflow_name>-<branch_name>-<true || commit_sha (if branch is protected)>
35+
group: code-style-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
36+
cancel-in-progress: true
37+
38+
jobs:
39+
code-style:
40+
runs-on: ubuntu-latest
41+
strategy:
42+
matrix:
43+
python-version: ["3.9", "3.13"]
44+
45+
steps:
46+
- uses: actions/checkout@v4
47+
48+
- uses: astral-sh/setup-uv@v6
49+
with:
50+
version: "latest"
51+
python-version: ${{ matrix.python-version }}
52+
activate-environment: true
53+
enable-cache: true
54+
cache-dependency-glob: |
55+
**/requirements-dev.txt
56+
**/pyproject.toml
57+
58+
- name: Install dependencies
59+
run: |
60+
uv pip install pre-commit
61+
uv pip install -r requirements-dev.txt
62+
uv pip install -e .
63+
64+
- name: Run pre-commit checks
65+
run: pre-commit run --all-files --show-diff-on-failure

.github/workflows/code-style.yml

Lines changed: 0 additions & 40 deletions
This file was deleted.

.github/workflows/gpu-hvd-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ jobs:
161161
- name: Upload coverage to Codecov
162162
uses: codecov/codecov-action@v5
163163
with:
164-
file: ${{ github.repository }}/coverage.xml
164+
files: ${{ github.repository }}/coverage.xml
165165
flags: gpu-2
166166
fail_ci_if_error: false
167167

.github/workflows/gpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
- name: Upload coverage to Codecov
132132
uses: codecov/codecov-action@v5
133133
with:
134-
file: ${{ github.repository }}/coverage.xml
134+
files: ${{ github.repository }}/coverage.xml
135135
flags: gpu-2
136136
fail_ci_if_error: false
137137

.github/workflows/hvd-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,6 @@ jobs:
9494
- name: Upload coverage to Codecov
9595
uses: codecov/codecov-action@v5
9696
with:
97-
file: ./coverage.xml
97+
files: ./coverage.xml
9898
flags: hvd-cpu
9999
fail_ci_if_error: false

.github/workflows/mps-tests.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ jobs:
7575
conda activate $CONDA_ENV
7676
pip install uv
7777
78+
- name: Install GPG for Codecov
79+
shell: bash -l {0}
80+
run: |
81+
# Install GPG which is required by codecov-action@v5
82+
brew install gnupg
83+
7884
- name: Install PyTorch
7985
if: ${{ matrix.pytorch-channel == 'pytorch' }}
8086
shell: bash -l {0}
@@ -130,7 +136,7 @@ jobs:
130136
- name: Upload coverage to Codecov
131137
uses: codecov/codecov-action@v5
132138
with:
133-
file: ${{ github.repository }}/coverage.xml
139+
files: ${{ github.repository }}/coverage.xml
134140
flags: mps
135141
fail_ci_if_error: false
136142

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,6 @@ jobs:
103103
- name: Upload coverage to Codecov
104104
uses: codecov/codecov-action@v5
105105
with:
106-
file: ./coverage.xml
106+
files: ./coverage.xml
107107
flags: tpu
108108
fail_ci_if_error: false
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name: Type Checking
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
- "*.*.*"
8+
paths:
9+
- "ignite/**"
10+
- "examples/**.py"
11+
- "tests/ignite/**"
12+
- "requirements-dev.txt"
13+
- "pyproject.toml"
14+
- "mypy.ini"
15+
- ".github/workflows/typing-checks.yml"
16+
pull_request:
17+
paths:
18+
- "ignite/**"
19+
- "examples/**.py"
20+
- "tests/ignite/**"
21+
- "requirements-dev.txt"
22+
- "pyproject.toml"
23+
- "mypy.ini"
24+
- ".github/workflows/typing-checks.yml"
25+
workflow_dispatch:
26+
27+
concurrency:
28+
# <workflow_name>-<branch_name>-<true || commit_sha (if branch is protected)>
29+
group: typing-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
30+
cancel-in-progress: true
31+
32+
jobs:
33+
mypy:
34+
runs-on: ubuntu-latest
35+
strategy:
36+
matrix:
37+
python-version: ["3.9", "3.13"]
38+
pytorch-channel: [pytorch]
39+
40+
steps:
41+
- uses: actions/checkout@v4
42+
43+
- name: Get year & week number
44+
id: get-date
45+
run: |
46+
echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT
47+
48+
- uses: astral-sh/setup-uv@v6
49+
with:
50+
version: "latest"
51+
python-version: ${{ matrix.python-version }}
52+
activate-environment: true
53+
enable-cache: true
54+
cache-suffix: "${{ steps.get-date.outputs.date }}-typing-${{ runner.os }}-${{ matrix.python-version }}"
55+
cache-dependency-glob: |
56+
**/requirements-dev.txt
57+
**/pyproject.toml
58+
59+
- name: Install PyTorch
60+
run: uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
61+
62+
- name: Install dependencies
63+
run: |
64+
uv pip install -r requirements-dev.txt
65+
uv pip install .
66+
uv pip install mypy
67+
68+
- name: Run MyPy type checking
69+
run: mypy

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,6 @@ jobs:
9393
uv pip install .
9494
uv pip list
9595
96-
- name: Check code formatting
97-
run: |
98-
pre-commit run -a
99-
100-
- name: Run Mypy
101-
# https://github.com/pytorch/ignite/pull/2780
102-
#
103-
if: ${{ matrix.os == 'ubuntu-latest' }}
104-
run: |
105-
mypy
106-
10796
# Download MNIST: https://github.com/pytorch/ignite/issues/1737
10897
# to "/tmp" for unit tests
10998
- name: Download MNIST
@@ -128,7 +117,7 @@ jobs:
128117
- name: Upload coverage to Codecov
129118
uses: codecov/codecov-action@v5
130119
with:
131-
file: ./coverage.xml
120+
files: ./coverage.xml
132121
flags: cpu
133122
fail_ci_if_error: false
134123

ignite/base/mixins.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from collections import OrderedDict
22
from collections.abc import Mapping
3-
from typing import Tuple
3+
from typing import List, Tuple
44

55

66
class Serializable:
7-
_state_dict_all_req_keys: Tuple = ()
8-
_state_dict_one_of_opt_keys: Tuple = ()
7+
_state_dict_all_req_keys: Tuple[str, ...] = ()
8+
_state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
9+
10+
def __init__(self) -> None:
11+
self._state_dict_user_keys: List[str] = []
12+
13+
@property
14+
def state_dict_user_keys(self) -> List:
15+
return self._state_dict_user_keys
916

1017
def state_dict(self) -> OrderedDict:
1118
raise NotImplementedError
@@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926
raise ValueError(
2027
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
2128
)
22-
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
23-
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
24-
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
29+
30+
# Handle groups of one-of optional keys
31+
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
32+
if len(one_of_opt_keys) > 0:
33+
opts = [k in state_dict for k in one_of_opt_keys]
34+
num_present = sum(opts)
35+
if num_present == 0:
36+
raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys")
37+
if num_present > 1:
38+
raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
39+
40+
# Check user keys
41+
if hasattr(self, "_state_dict_user_keys") and isinstance(self._state_dict_user_keys, list):
42+
for k in self._state_dict_user_keys:
43+
if k not in state_dict:
44+
raise ValueError(
45+
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
46+
)

0 commit comments

Comments
 (0)