Skip to content

Commit fc1fcb5

Browse files
committed
Fix max iters issue and add tests
1 parent f41b1e3 commit fc1fcb5

File tree

10 files changed

+956
-67
lines changed

10 files changed

+956
-67
lines changed

ignite/base/mixins.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from collections import OrderedDict
22
from collections.abc import Mapping
3-
from typing import Tuple
3+
from typing import List, Tuple
44

55

66
class Serializable:
7-
_state_dict_all_req_keys: Tuple = ()
8-
_state_dict_one_of_opt_keys: Tuple = ()
7+
_state_dict_all_req_keys: Tuple[str, ...] = ()
8+
_state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
9+
10+
def __init__(self) -> None:
11+
self._state_dict_user_keys: List[str] = []
12+
13+
@property
14+
def state_dict_user_keys(self) -> List:
15+
return self._state_dict_user_keys
916

1017
def state_dict(self) -> OrderedDict:
1118
raise NotImplementedError
@@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926
raise ValueError(
2027
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
2128
)
22-
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
23-
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
24-
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
29+
30+
# Handle groups of one-of optional keys
31+
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
32+
if len(one_of_opt_keys) > 0:
33+
opts = [k in state_dict for k in one_of_opt_keys]
34+
num_present = sum(opts)
35+
if num_present == 0:
36+
raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys")
37+
if num_present > 1:
38+
raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
39+
40+
# Check user keys
41+
if hasattr(self, "_state_dict_user_keys") and isinstance(self._state_dict_user_keys, list):
42+
for k in self._state_dict_user_keys:
43+
if k not in state_dict:
44+
raise ValueError(
45+
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
46+
)

ignite/engine/engine.py

Lines changed: 177 additions & 49 deletions
Large diffs are not rendered by default.

ignite/metrics/mean_average_precision.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import ignite.distributed as idist
88
from ignite.distributed.utils import all_gather_tensors_with_shapes
9-
from ignite.metrics.metric import Metric, reinit__is_reduced
9+
from ignite.metrics.metric import reinit__is_reduced
1010
from ignite.metrics.precision import _BaseClassification
1111
from ignite.utils import to_onehot
1212

@@ -220,13 +220,16 @@ def __init__(
220220
.. versionadded:: 0.5.2
221221
"""
222222

223-
super(MeanAveragePrecision, self).__init__(
223+
# Initialize _BaseClassification first
224+
_BaseClassification.__init__(
225+
self,
224226
output_transform=output_transform,
225227
is_multilabel=is_multilabel,
226228
device=device,
227229
skip_unrolling=skip_unrolling,
228230
)
229-
super(Metric, self).__init__(rec_thresholds=rec_thresholds, class_mean=class_mean)
231+
# Then initialize _BaseAveragePrecision
232+
_BaseAveragePrecision.__init__(self, rec_thresholds=rec_thresholds, class_mean=class_mean)
230233

231234
@reinit__is_reduced
232235
def reset(self) -> None:

ignite/metrics/metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ def __init__(
361361
device: Union[str, torch.device] = torch.device("cpu"),
362362
skip_unrolling: bool = False,
363363
):
364+
super().__init__() # Initialize Serializable
364365
if not callable(output_transform):
365366
raise TypeError(f"Argument output_transform should be callable, got {type(output_transform)}")
366367
self._output_transform = output_transform

ignite/metrics/vision/object_detection_average_precision_recall.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,16 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
124124
self._area_range = area_range
125125
self._max_detections_per_image_per_class = max_detections_per_image_per_class
126126

127-
super(ObjectDetectionAvgPrecisionRecall, self).__init__(
127+
# Initialize Metric first
128+
Metric.__init__(
129+
self,
128130
output_transform=output_transform,
129131
device=device,
130132
skip_unrolling=skip_unrolling,
131133
)
132-
super(Metric, self).__init__(
134+
# Then initialize _BaseAveragePrecision
135+
_BaseAveragePrecision.__init__(
136+
self,
133137
rec_thresholds=rec_thresholds,
134138
class_mean=None,
135139
)

mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ warn_return_any = False
2222
; results in too many false positives, therefore set to False
2323
warn_unreachable = False
2424
warn_unused_configs = True
25-
warn_unused_ignores = True
25+
warn_unused_ignores = False
2626

2727
[mypy-apex.*]
2828
ignore_missing_imports = True

0 commit comments

Comments
 (0)