Skip to content

Commit 243afad

Browse files
committed
Fix max iters issue and add tests
1 parent 670bbee commit 243afad

File tree

5 files changed

+940
-59
lines changed

5 files changed

+940
-59
lines changed

ignite/base/mixins.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from collections import OrderedDict
22
from collections.abc import Mapping
3-
from typing import Tuple
3+
from typing import List, Tuple
44

55

66
class Serializable:
7-
_state_dict_all_req_keys: Tuple = ()
8-
_state_dict_one_of_opt_keys: Tuple = ()
7+
_state_dict_all_req_keys: Tuple[str, ...] = ()
8+
_state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
9+
10+
def __init__(self) -> None:
11+
self._state_dict_user_keys: List[str] = []
12+
13+
@property
14+
def state_dict_user_keys(self) -> List:
15+
return self._state_dict_user_keys
916

1017
def state_dict(self) -> OrderedDict:
1118
raise NotImplementedError
@@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926
raise ValueError(
2027
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
2128
)
22-
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
23-
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
24-
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
29+
30+
# Handle groups of one-of optional keys
31+
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
32+
if len(one_of_opt_keys) > 0:
33+
opts = [k in state_dict for k in one_of_opt_keys]
34+
num_present = sum(opts)
35+
if num_present == 0:
36+
raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys")
37+
if num_present > 1:
38+
raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
39+
40+
# Check user keys
41+
if hasattr(self, "_state_dict_user_keys") and isinstance(self._state_dict_user_keys, list):
42+
for k in self._state_dict_user_keys:
43+
if k not in state_dict:
44+
raise ValueError(
45+
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
46+
)

ignite/engine/engine.py

Lines changed: 177 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,14 @@ def compute_mean_std(engine, batch):
128128
129129
"""
130130

131-
_state_dict_all_req_keys = ("epoch_length", "max_epochs")
132-
_state_dict_one_of_opt_keys = ("iteration", "epoch")
131+
_state_dict_all_req_keys = ("epoch_length",)
132+
_state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters"))
133133

134134
# Flag to disable engine._internal_run as generator feature for BC
135135
interrupt_resume_enabled = True
136136

137137
def __init__(self, process_function: Callable[["Engine", Any], Any]):
138+
super(Engine, self).__init__()
138139
self._event_handlers: Dict[Any, List] = defaultdict(list)
139140
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
140141
self._process_function = process_function
@@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
147148
self.should_terminate_single_epoch: Union[bool, str] = False
148149
self.should_interrupt = False
149150
self.state = State()
150-
self._state_dict_user_keys: List[str] = []
151151
self._allowed_events: List[EventEnum] = []
152152

153153
self._dataloader_iter: Optional[Iterator[Any]] = None
@@ -691,14 +691,20 @@ def save_engine(_):
691691
a dictionary containing engine's state
692692
693693
"""
694-
keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
694+
keys: Tuple[str, ...] = self._state_dict_all_req_keys
695+
keys += ("iteration",)
696+
# Include max_iters when it is set on the state; otherwise fall back to max_epochs
697+
if self.state.max_iters is not None:
698+
keys += ("max_iters",)
699+
else:
700+
keys += ("max_epochs",)
695701
keys += tuple(self._state_dict_user_keys)
696702
return OrderedDict([(k, getattr(self.state, k)) for k in keys])
697703

698704
def load_state_dict(self, state_dict: Mapping) -> None:
699705
"""Setups engine from `state_dict`.
700706
701-
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
707+
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`.
702708
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
703709
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
704710
@@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None:
709715
710716
.. code-block:: python
711717
712-
# Restore from the 4rd epoch
718+
# Restore from the 4th epoch
713719
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
714720
# or 500th iteration
715721
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}
722+
# or with max_iters
723+
# state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)}
716724
717725
trainer = Engine(...)
718726
trainer.load_state_dict(state_dict)
@@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None:
721729
"""
722730
super(Engine, self).load_state_dict(state_dict)
723731

724-
for k in self._state_dict_user_keys:
725-
if k not in state_dict:
726-
raise ValueError(
727-
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
728-
)
729-
self.state.max_epochs = state_dict["max_epochs"]
732+
# Set epoch_length
730733
self.state.epoch_length = state_dict["epoch_length"]
734+
735+
# Set user keys
731736
for k in self._state_dict_user_keys:
732737
setattr(self.state, k, state_dict[k])
733738

739+
# Set iteration or epoch
734740
if "iteration" in state_dict:
735741
self.state.iteration = state_dict["iteration"]
736742
self.state.epoch = 0
737-
if self.state.epoch_length is not None:
743+
if self.state.epoch_length is not None and self.state.epoch_length > 0:
738744
self.state.epoch = self.state.iteration // self.state.epoch_length
739-
elif "epoch" in state_dict:
745+
else: # epoch is in state_dict
740746
self.state.epoch = state_dict["epoch"]
741747
if self.state.epoch_length is None:
742748
raise ValueError(
@@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None:
745751
)
746752
self.state.iteration = self.state.epoch_length * self.state.epoch
747753

754+
# Set max_epochs or max_iters with validation
755+
max_epochs_value = state_dict.get("max_epochs", None)
756+
max_iters_value = state_dict.get("max_iters", None)
757+
758+
# Validate max_epochs if present
759+
if max_epochs_value is not None:
760+
if max_epochs_value < 1:
761+
raise ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value")
762+
if max_epochs_value < self.state.epoch:
763+
raise ValueError(
764+
"max_epochs in state_dict should be larger than or equal to the current epoch "
765+
f"defined in the state: {max_epochs_value} vs {self.state.epoch}. "
766+
)
767+
self.state.max_epochs = max_epochs_value
768+
else:
769+
self.state.max_epochs = None
770+
771+
# Validate max_iters if present
772+
if max_iters_value is not None:
773+
if max_iters_value < 1:
774+
raise ValueError("max_iters in state_dict is invalid. Please, set a correct max_iters positive value")
775+
if max_iters_value < self.state.iteration:
776+
raise ValueError(
777+
"max_iters in state_dict should be larger than or equal to the current iteration "
778+
f"defined in the state: {max_iters_value} vs {self.state.iteration}. "
779+
)
780+
self.state.max_iters = max_iters_value
781+
else:
782+
self.state.max_iters = None
783+
748784
@staticmethod
749785
def _is_done(state: State) -> bool:
750786
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
@@ -756,6 +792,59 @@ def _is_done(state: State) -> bool:
756792
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
757793
return is_done_iters or is_done_count or is_done_epochs
758794

795+
def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
796+
"""Validate and set max_epochs with proper checks."""
797+
if max_epochs is not None:
798+
if max_epochs < 1:
799+
raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
800+
# Compare against the current epoch only when the state already has max_epochs set (resuming a previous run)
801+
if self.state.max_epochs is not None and max_epochs < self.state.epoch:
802+
raise ValueError(
803+
"Argument max_epochs should be greater than or equal to the start "
804+
f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
805+
"Please, set engine.state.max_epochs = None "
806+
"before calling engine.run() in order to restart the training from the beginning."
807+
)
808+
self.state.max_epochs = max_epochs
809+
810+
def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
811+
"""Validate and set max_iters with proper checks."""
812+
if max_iters is not None:
813+
if max_iters < 1:
814+
raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
815+
# Compare against the current iteration only when the state already has max_iters set (resuming a previous run)
816+
if (self.state.max_iters is not None) and max_iters < self.state.iteration:
817+
raise ValueError(
818+
"Argument max_iters should be greater than or equal to the start "
819+
f"iteration defined in the state: {max_iters} vs {self.state.iteration}. "
820+
"Please, set engine.state.max_iters = None "
821+
"before calling engine.run() in order to restart the training from the beginning."
822+
)
823+
self.state.max_iters = max_iters
824+
825+
def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None:
826+
"""Validate and set epoch_length."""
827+
# Check if we can redefine epoch_length
828+
if self.state.epoch_length is not None:
829+
if epoch_length is not None:
830+
if epoch_length != self.state.epoch_length:
831+
raise ValueError(
832+
"Argument epoch_length should be same as in the state, "
833+
f"but given {epoch_length} vs {self.state.epoch_length}"
834+
)
835+
else:
836+
if epoch_length is None:
837+
if data is not None:
838+
epoch_length = self._get_data_length(data)
839+
840+
if epoch_length is not None:
841+
if epoch_length < 1:
842+
raise ValueError(
843+
"Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
844+
"check if input data has non-zero size."
845+
)
846+
self.state.epoch_length = epoch_length
847+
759848
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
760849
"""Method to set data. After calling the method the next batch passed to `processing_function` is
761850
from newly provided data. Please, note that epoch length is not modified.
@@ -854,59 +943,98 @@ def switch_batch(engine):
854943
if data is not None and not isinstance(data, Iterable):
855944
raise TypeError("Argument data should be iterable")
856945

857-
if self.state.max_epochs is not None:
858-
# Check and apply overridden parameters
859-
if max_epochs is not None:
860-
if max_epochs < self.state.epoch:
861-
raise ValueError(
862-
"Argument max_epochs should be greater than or equal to the start "
863-
f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
864-
"Please, set engine.state.max_epochs = None "
865-
"before calling engine.run() in order to restart the training from the beginning."
866-
)
867-
self.state.max_epochs = max_epochs
868-
if epoch_length is not None:
869-
if epoch_length != self.state.epoch_length:
870-
raise ValueError(
871-
"Argument epoch_length should be same as in the state, "
872-
f"but given {epoch_length} vs {self.state.epoch_length}"
873-
)
946+
if max_epochs is not None and max_iters is not None:
947+
raise ValueError(
948+
"Arguments max_iters and max_epochs are mutually exclusive."
949+
"Please provide only max_epochs or max_iters."
950+
)
874951

875-
if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
876-
# Create new state
877-
if epoch_length is None:
878-
if data is None:
879-
raise ValueError("epoch_length should be provided if data is None")
952+
# Check if we need to create new state or resume
953+
# Create new state if:
954+
# 1. No termination params set (first run), OR
955+
# 2. Training is done AND generator is None AND no new params provided, OR
956+
# 3. Training is done AND generator is None AND same termination params provided (restart case)
957+
should_create_new_state = (
958+
(self.state.max_epochs is None and self.state.max_iters is None)
959+
or (
960+
self._is_done(self.state)
961+
and self._internal_run_generator is None
962+
and max_epochs is None
963+
and max_iters is None
964+
)
965+
or (
966+
self._is_done(self.state)
967+
and self._internal_run_generator is None
968+
and (
969+
(max_epochs is not None and max_epochs == self.state.max_epochs)
970+
or (max_iters is not None and max_iters == self.state.max_iters)
971+
)
972+
)
973+
)
880974

881-
epoch_length = self._get_data_length(data)
882-
if epoch_length is not None and epoch_length < 1:
883-
raise ValueError("Input data has zero size. Please provide non-empty data")
975+
if should_create_new_state:
976+
# Create new state
977+
if data is None and epoch_length is None and self.state.epoch_length is None:
978+
raise ValueError("epoch_length should be provided if data is None")
884979

980+
# Set epoch_length for new state
981+
if epoch_length is None:
982+
# Try to get from data first, then fall back to existing state
983+
if data is not None:
984+
epoch_length = self._get_data_length(data)
985+
if epoch_length is None and self.state.epoch_length is not None:
986+
epoch_length = self.state.epoch_length
987+
if epoch_length is not None and epoch_length < 1:
988+
raise ValueError("Input data has zero size. Please provide non-empty data")
989+
990+
# Determine max_epochs/max_iters
885991
if max_iters is None:
886992
if max_epochs is None:
887993
max_epochs = 1
888994
else:
889-
if max_epochs is not None:
890-
raise ValueError(
891-
"Arguments max_iters and max_epochs are mutually exclusive."
892-
"Please provide only max_epochs or max_iters."
893-
)
894995
if epoch_length is not None:
895996
max_epochs = math.ceil(max_iters / epoch_length)
896997

998+
# Initialize new state
897999
self.state.iteration = 0
8981000
self.state.epoch = 0
8991001
self.state.max_epochs = max_epochs
9001002
self.state.max_iters = max_iters
9011003
self.state.epoch_length = epoch_length
9021004
# Reset generator if previously used
9031005
self._internal_run_generator = None
904-
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")
1006+
1007+
# Log start message
1008+
if self.state.max_epochs is not None:
1009+
self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.")
1010+
else:
1011+
self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.")
9051012
else:
906-
self.logger.info(
907-
f"Engine run resuming from iteration {self.state.iteration}, "
908-
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
909-
)
1013+
# Resume from existing state
1014+
# Apply overridden parameters using helper methods
1015+
self._check_and_set_max_epochs(max_epochs)
1016+
self._check_and_set_max_iters(max_iters)
1017+
1018+
# Handle epoch_length validation (simplified from original)
1019+
if epoch_length is not None:
1020+
if epoch_length != self.state.epoch_length:
1021+
raise ValueError(
1022+
"Argument epoch_length should be same as in the state, "
1023+
f"but given {epoch_length} vs {self.state.epoch_length}"
1024+
)
1025+
1026+
# Log resuming message
1027+
if self.state.max_epochs is not None:
1028+
self.logger.info(
1029+
f"Engine run resuming from iteration {self.state.iteration}, "
1030+
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
1031+
)
1032+
else:
1033+
self.logger.info(
1034+
f"Engine run resuming from iteration {self.state.iteration}, "
1035+
f"epoch {self.state.epoch} until {self.state.max_iters} iterations"
1036+
)
1037+
9101038
if self.state.epoch_length is None and data is None:
9111039
raise ValueError("epoch_length should be provided if data is None")
9121040

0 commit comments

Comments
 (0)