@@ -128,13 +128,14 @@ def compute_mean_std(engine, batch):
    """

-    _state_dict_all_req_keys = ("epoch_length", "max_epochs")
-    _state_dict_one_of_opt_keys = ("iteration", "epoch")
+    _state_dict_all_req_keys = ("epoch_length",)
+    _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters"))

    # Flag to disable engine._internal_run as generator feature for BC
    interrupt_resume_enabled = True

    def __init__(self, process_function: Callable[["Engine", Any], Any]):
+        super(Engine, self).__init__()
        self._event_handlers: Dict[Any, List] = defaultdict(list)
        self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        self._process_function = process_function
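The nested `_state_dict_one_of_opt_keys` now describes two either/or groups: a checkpoint must provide `iteration` or `epoch`, and `max_epochs` or `max_iters`. A hypothetical sketch of that kind of grouped check follows; it is illustrative only, not the actual `Serializable` base-class code:

    from typing import Mapping, Tuple

    def check_one_of_groups(state_dict: Mapping, groups: Tuple[Tuple[str, ...], ...]) -> None:
        # Hypothetical helper: require exactly one key from each either/or group.
        for group in groups:
            present = [k for k in group if k in state_dict]
            if len(present) != 1:
                raise ValueError(f"state_dict should contain exactly one of {group}, found {present}")

    # Passes: exactly one key from each group is present.
    check_one_of_groups(
        {"iteration": 499, "max_iters": 1000, "epoch_length": 100},
        (("iteration", "epoch"), ("max_epochs", "max_iters")),
    )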
@@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
        self.should_terminate_single_epoch: Union[bool, str] = False
        self.should_interrupt = False
        self.state = State()
-        self._state_dict_user_keys: List[str] = []
        self._allowed_events: List[EventEnum] = []

        self._dataloader_iter: Optional[Iterator[Any]] = None
@@ -691,14 +691,20 @@ def save_engine(_):
            a dictionary containing engine's state

        """
-        keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
+        keys: Tuple[str, ...] = self._state_dict_all_req_keys
+        keys += ("iteration",)
+        # Include either max_epochs or max_iters based on which was originally set
+        if self.state.max_iters is not None:
+            keys += ("max_iters",)
+        else:
+            keys += ("max_epochs",)
        keys += tuple(self._state_dict_user_keys)
        return OrderedDict([(k, getattr(self.state, k)) for k in keys])

    def load_state_dict(self, state_dict: Mapping) -> None:
        """Setups engine from `state_dict`.

-        State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
+        State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`.
        If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
        Iteration and epoch values are 0-based: the first iteration or epoch is zero.
@@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None:
        .. code-block:: python

-            # Restore from the 4rd epoch
+            # Restore from the 4th epoch
            state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
            # or 500th iteration
            # state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}
+            # or with max_iters
+            # state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)}

            trainer = Engine(...)
            trainer.load_state_dict(state_dict)
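Concretely, the counter that is not supplied is derived from the one that is. A worked sketch, assuming `epoch_length` is 100:

    epoch_length = 100

    # Restoring from {"epoch": 3, ...}: iteration is recomputed from the epoch.
    epoch = 3
    iteration = epoch * epoch_length  # 300 iterations completed, training resumes at the 4th epoch

    # Restoring from {"iteration": 499, ...}: epoch is recomputed from the iteration.
    iteration = 499
    epoch = iteration // epoch_length  # 499 // 100 == 4 completed epochs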
@@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None:
        """
        super(Engine, self).load_state_dict(state_dict)

-        for k in self._state_dict_user_keys:
-            if k not in state_dict:
-                raise ValueError(
-                    f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
-                )
-        self.state.max_epochs = state_dict["max_epochs"]
+        # Set epoch_length
        self.state.epoch_length = state_dict["epoch_length"]
+
+        # Set user keys
        for k in self._state_dict_user_keys:
            setattr(self.state, k, state_dict[k])

+        # Set iteration or epoch
        if "iteration" in state_dict:
            self.state.iteration = state_dict["iteration"]
            self.state.epoch = 0
-            if self.state.epoch_length is not None:
+            if self.state.epoch_length is not None and self.state.epoch_length > 0:
                self.state.epoch = self.state.iteration // self.state.epoch_length
-        elif "epoch" in state_dict:
+        else:  # epoch is in state_dict
            self.state.epoch = state_dict["epoch"]
            if self.state.epoch_length is None:
                raise ValueError(
@@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                )
            self.state.iteration = self.state.epoch_length * self.state.epoch

+        # Set max_epochs or max_iters with validation
+        max_epochs_value = state_dict.get("max_epochs", None)
+        max_iters_value = state_dict.get("max_iters", None)
+
+        # Validate max_epochs if present
+        if max_epochs_value is not None:
+            if max_epochs_value < 1:
+                raise ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value")
+            if max_epochs_value < self.state.epoch:
+                raise ValueError(
+                    "max_epochs in state_dict should be larger than or equal to the current epoch "
+                    f"defined in the state: {max_epochs_value} vs {self.state.epoch}."
+                )
+            self.state.max_epochs = max_epochs_value
+        else:
+            self.state.max_epochs = None
+
+        # Validate max_iters if present
+        if max_iters_value is not None:
+            if max_iters_value < 1:
+                raise ValueError("max_iters in state_dict is invalid. Please, set a correct max_iters positive value")
+            if max_iters_value < self.state.iteration:
+                raise ValueError(
+                    "max_iters in state_dict should be larger than or equal to the current iteration "
+                    f"defined in the state: {max_iters_value} vs {self.state.iteration}."
+                )
+            self.state.max_iters = max_iters_value
+        else:
+            self.state.max_iters = None
+
    @staticmethod
    def _is_done(state: State) -> bool:
        is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
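Taken together, `state_dict` and `load_state_dict` now round-trip whichever termination key was used. A minimal sketch of the intended usage, assuming the behavior added in this PR (the `train_step` stub and the `data` iterable are placeholders):

    from ignite.engine import Engine

    def train_step(engine, batch):
        pass  # placeholder processing function

    data = range(100)

    trainer = Engine(train_step)
    trainer.run(data, max_iters=1000, epoch_length=100)

    saved = trainer.state_dict()
    # With max_iters set, the dict carries "max_iters" instead of "max_epochs", e.g.
    # OrderedDict([("epoch_length", 100), ("iteration", 1000), ("max_iters", 1000)])

    new_trainer = Engine(train_step)
    new_trainer.load_state_dict(saved)  # restores epoch_length, iteration and max_iters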
@@ -756,6 +792,59 @@ def _is_done(state: State) -> bool:
        is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
        return is_done_iters or is_done_count or is_done_epochs

+    def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
+        """Validate and set max_epochs with proper checks."""
+        if max_epochs is not None:
+            if max_epochs < 1:
+                raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
+            # Only validate if training is actually done - allow resuming interrupted training
+            if self.state.max_epochs is not None and max_epochs < self.state.epoch:
+                raise ValueError(
+                    "Argument max_epochs should be greater than or equal to the start "
+                    f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
+                    "Please, set engine.state.max_epochs = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_epochs = max_epochs
+
+    def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
+        """Validate and set max_iters with proper checks."""
+        if max_iters is not None:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            # Only validate if training is actually done - allow resuming interrupted training
+            if (self.state.max_iters is not None) and max_iters < self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be greater than or equal to the start "
+                    f"iteration defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
+
+    def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None:
+        """Validate and set epoch_length."""
+        # Check if we can redefine epoch_length
+        if self.state.epoch_length is not None:
+            if epoch_length is not None:
+                if epoch_length != self.state.epoch_length:
+                    raise ValueError(
+                        "Argument epoch_length should be same as in the state, "
+                        f"but given {epoch_length} vs {self.state.epoch_length}"
+                    )
+        else:
+            if epoch_length is None:
+                if data is not None:
+                    epoch_length = self._get_data_length(data)
+
+            if epoch_length is not None:
+                if epoch_length < 1:
+                    raise ValueError(
+                        "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
+                        "check if input data has non-zero size."
+                    )
+                self.state.epoch_length = epoch_length
+
    def set_data(self, data: Union[Iterable, DataLoader]) -> None:
        """Method to set data. After calling the method the next batch passed to `processing_function` is
        from newly provided data. Please, note that epoch length is not modified.
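These helpers centralize checks that `run` previously performed inline. A short sketch of the resume-time validation they are meant to express, assuming the behavior added in this PR (the lambda step and the data iterable are placeholders):

    from ignite.engine import Engine

    trainer = Engine(lambda engine, batch: None)  # placeholder processing function
    data = range(100)

    trainer.run(data, max_epochs=5)

    # Requesting fewer epochs than the engine has already completed would raise:
    #   trainer.run(data, max_epochs=3)
    #   ValueError: Argument max_epochs should be greater than or equal to the start epoch ...

    # A larger budget resumes from epoch 5 and trains 5 more epochs:
    trainer.run(data, max_epochs=10)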
@@ -854,59 +943,98 @@ def switch_batch(engine):
        if data is not None and not isinstance(data, Iterable):
            raise TypeError("Argument data should be iterable")

-        if self.state.max_epochs is not None:
-            # Check and apply overridden parameters
-            if max_epochs is not None:
-                if max_epochs < self.state.epoch:
-                    raise ValueError(
-                        "Argument max_epochs should be greater than or equal to the start "
-                        f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
-                        "Please, set engine.state.max_epochs = None "
-                        "before calling engine.run() in order to restart the training from the beginning."
-                    )
-                self.state.max_epochs = max_epochs
-            if epoch_length is not None:
-                if epoch_length != self.state.epoch_length:
-                    raise ValueError(
-                        "Argument epoch_length should be same as in the state, "
-                        f"but given {epoch_length} vs {self.state.epoch_length}"
-                    )
+        if max_epochs is not None and max_iters is not None:
+            raise ValueError(
+                "Arguments max_iters and max_epochs are mutually exclusive."
+                "Please provide only max_epochs or max_iters."
+            )

-        if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
-            # Create new state
-            if epoch_length is None:
-                if data is None:
-                    raise ValueError("epoch_length should be provided if data is None")
+        # Check if we need to create new state or resume
+        # Create new state if:
+        # 1. No termination params set (first run), OR
+        # 2. Training is done AND generator is None AND no new params provided
+        # 3. Training is done AND same termination params provided (restart case)
+        should_create_new_state = (
+            (self.state.max_epochs is None and self.state.max_iters is None)
+            or (
+                self._is_done(self.state)
+                and self._internal_run_generator is None
+                and max_epochs is None
+                and max_iters is None
+            )
+            or (
+                self._is_done(self.state)
+                and self._internal_run_generator is None
+                and (
+                    (max_epochs is not None and max_epochs == self.state.max_epochs)
+                    or (max_iters is not None and max_iters == self.state.max_iters)
+                )
+            )
+        )

-                epoch_length = self._get_data_length(data)
-                if epoch_length is not None and epoch_length < 1:
-                    raise ValueError("Input data has zero size. Please provide non-empty data")
+        if should_create_new_state:
+            # Create new state
+            if data is None and epoch_length is None and self.state.epoch_length is None:
+                raise ValueError("epoch_length should be provided if data is None")

+            # Set epoch_length for new state
+            if epoch_length is None:
+                # Try to get from data first, then fall back to existing state
+                if data is not None:
+                    epoch_length = self._get_data_length(data)
+                if epoch_length is None and self.state.epoch_length is not None:
+                    epoch_length = self.state.epoch_length
+            if epoch_length is not None and epoch_length < 1:
+                raise ValueError("Input data has zero size. Please provide non-empty data")
+
+            # Determine max_epochs/max_iters
            if max_iters is None:
                if max_epochs is None:
                    max_epochs = 1
            else:
-                if max_epochs is not None:
-                    raise ValueError(
-                        "Arguments max_iters and max_epochs are mutually exclusive."
-                        "Please provide only max_epochs or max_iters."
-                    )
                if epoch_length is not None:
                    max_epochs = math.ceil(max_iters / epoch_length)

+            # Initialize new state
            self.state.iteration = 0
            self.state.epoch = 0
            self.state.max_epochs = max_epochs
            self.state.max_iters = max_iters
            self.state.epoch_length = epoch_length
            # Reset generator if previously used
            self._internal_run_generator = None
-            self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")
+
+            # Log start message
+            if self.state.max_epochs is not None:
+                self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.")
+            else:
+                self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.")
        else:
-            self.logger.info(
-                f"Engine run resuming from iteration {self.state.iteration}, "
-                f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
-            )
+            # Resume from existing state
+            # Apply overridden parameters using helper methods
+            self._check_and_set_max_epochs(max_epochs)
+            self._check_and_set_max_iters(max_iters)
+
+            # Handle epoch_length validation (simplified from original)
+            if epoch_length is not None:
+                if epoch_length != self.state.epoch_length:
+                    raise ValueError(
+                        "Argument epoch_length should be same as in the state, "
+                        f"but given {epoch_length} vs {self.state.epoch_length}"
+                    )
+
+            # Log resuming message
+            if self.state.max_epochs is not None:
+                self.logger.info(
+                    f"Engine run resuming from iteration {self.state.iteration}, "
+                    f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
+                )
+            else:
+                self.logger.info(
+                    f"Engine run resuming from iteration {self.state.iteration}, "
+                    f"epoch {self.state.epoch} until {self.state.max_iters} iterations"
+                )
+
        if self.state.epoch_length is None and data is None:
            raise ValueError("epoch_length should be provided if data is None")
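The `should_create_new_state` expression encodes when a second `run` call restarts versus resumes. A rough walk-through of the intended outcomes, assuming this PR's behavior (placeholder step function and data):

    from ignite.engine import Engine

    trainer = Engine(lambda engine, batch: None)  # placeholder processing function
    data = range(100)

    trainer.run(data, max_epochs=2)  # no termination params in state yet -> create new state
    trainer.run(data, max_epochs=2)  # done + same max_epochs -> restart from scratch
    trainer.run(data, max_epochs=4)  # done + larger max_epochs -> resume for 2 more epochs
    trainer.run(data)                # done + no params -> new state, defaults to max_epochs=1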