1818import abc
1919from collections .abc import Mapping
2020import dataclasses
21+ import functools
2122import gc
2223import os
2324import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
9697 model: The Keras model constructed by `create_model`.
9798 model_dir: The model directory passed to the trainer.
9899 """
99- model .save (os .path .join (model_dir , core .KERAS_MODEL_SAVEFILE ))
100100
101101
102102class KerasTrainer (core .Trainer [KerasTask ]):
@@ -118,6 +118,7 @@ def __init__(
118118 max_checkpoints_to_keep : int = 5 ,
119119 checkpoint_save_interval_epochs : int = 1 ,
120120 rng_seed : int = core .DEFAULT_RNG_SEED ,
121+ legacy_checkpoint_format : bool = True ,
121122 ):
122123 """Initializes the instance."""
123124
@@ -143,60 +144,77 @@ def __init__(
143144 self ._steps_per_eval = steps_per_eval
144145 self ._continuous_eval_timeout = continuous_eval_timeout
145146 self ._steps_per_loop = steps_per_loop
146- self ._checkpoint_manager = None
147147 self ._marker_path = os .path .join (
148148 model_dir , core .TRAINING_COMPLETE_MARKER_FILE
149149 )
150150 self ._checkpoint_dir = os .path .join (model_dir , core .CHECKPOINT_DIR )
151+ self ._max_checkpoints_to_keep = max_checkpoints_to_keep
152+ self ._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
153+ self ._legacy_checkpoint_format = legacy_checkpoint_format
151154
155+ @functools .cached_property
156+ def train_callbacks (self ) -> list [keras .callbacks .Callback ]:
157+ """Returns the training callbacks."""
152158 if keras .backend .backend () == "jax" :
153- self ._checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
154- checkpoint_dir = self ._checkpoint_dir ,
155- max_to_keep = max_checkpoints_to_keep ,
156- save_interval_epochs = checkpoint_save_interval_epochs ,
157- )
158- self ._train_callbacks = [
159+ if self ._legacy_checkpoint_format :
160+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
161+ checkpoint_dir = self ._checkpoint_dir ,
162+ max_to_keep = self ._max_checkpoints_to_keep ,
163+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
164+ )
165+ else :
166+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManagerV2 (
167+ checkpoint_dir = self ._checkpoint_dir ,
168+ max_to_keep = self ._max_checkpoints_to_keep ,
169+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
170+ )
171+ return [
159172 keras_utils .EpochSummaryCallback (
160- log_dir = os .path .join (model_dir , core .LOG_DIR ),
161- steps_per_epoch = steps_per_loop ,
173+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
174+ steps_per_epoch = self . _steps_per_loop ,
162175 write_steps_per_second = True ,
163176 ),
164177 keras_utils .EpochOrbaxCheckpointAndRestoreCallback (
165- checkpoint_manager = self . _checkpoint_manager ,
178+ checkpoint_manager = checkpoint_manager ,
166179 marker_path = self ._marker_path ,
167180 ),
168181 ]
169- self ._eval_callbacks = [
182+ return [
183+ keras .callbacks .TensorBoard (
184+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
185+ write_steps_per_second = True ,
186+ ),
187+ keras .callbacks .BackupAndRestore (
188+ backup_dir = os .path .join (self ._model_dir , core .BACKUP_DIR ),
189+ ),
190+ keras .callbacks .ModelCheckpoint (
191+ filepath = os .path .join (
192+ self ._model_dir ,
193+ core .CHECKPOINT_DIR ,
194+ "ckpt-{epoch:d}.weights.h5" ,
195+ ),
196+ save_weights_only = True ,
197+ verbose = 1 ,
198+ ),
199+ ]
200+
201+ @functools .cached_property
202+ def eval_callbacks (self ) -> list [keras .callbacks .Callback ]:
203+ """Returns the evaluation callbacks."""
204+ if keras .backend .backend () == "jax" :
205+ return [
170206 keras_utils .EpochSummaryCallback (
171- log_dir = os .path .join (model_dir , core .LOG_DIR ),
172- steps_per_epoch = steps_per_loop ,
207+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
208+ steps_per_epoch = self . _steps_per_loop ,
173209 write_steps_per_second = False ,
174210 ),
175211 ]
176- else :
177- self ._checkpoint_manager = None
178- self ._train_callbacks = [
179- keras .callbacks .TensorBoard (
180- log_dir = os .path .join (model_dir , core .LOG_DIR ),
181- write_steps_per_second = True ,
182- ),
183- keras .callbacks .BackupAndRestore (
184- backup_dir = os .path .join (model_dir , core .BACKUP_DIR ),
185- ),
186- keras .callbacks .ModelCheckpoint (
187- filepath = os .path .join (
188- model_dir , core .CHECKPOINT_DIR , "ckpt-{epoch:d}.weights.h5"
189- ),
190- save_weights_only = True ,
191- verbose = 1 ,
192- ),
193- ]
194- self ._eval_callbacks = [
195- keras .callbacks .TensorBoard (
196- log_dir = os .path .join (model_dir , core .LOG_DIR ),
197- write_steps_per_second = True ,
198- ),
199- ]
212+ return [
213+ keras .callbacks .TensorBoard (
214+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
215+ write_steps_per_second = True ,
216+ ),
217+ ]
200218
201219 def _maybe_get_model_kws (
202220 self , task : KerasTask , dataset : tf .data .Dataset
@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
218236 dataset ,
219237 epochs = self ._train_epochs ,
220238 steps_per_epoch = self ._steps_per_loop ,
221- callbacks = self ._train_callbacks ,
239+ callbacks = self .train_callbacks ,
222240 )
223241 model .summary (print_fn = logging .info )
224242
@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
237255 if keras .backend .backend () == "jax" :
238256 [tb_cbk ] = [
239257 cbk
240- for cbk in self ._eval_callbacks
258+ for cbk in self .eval_callbacks
241259 if isinstance (cbk , keras_utils .EpochSummaryCallback )
242260 ]
243261 epoch_start_time = time .time ()
244262 history = model .evaluate (
245263 dataset ,
246264 steps = self ._steps_per_eval ,
247- callbacks = self ._eval_callbacks ,
265+ callbacks = self .eval_callbacks ,
248266 return_dict = True ,
249267 )
250268 epoch_dt = time .time () - epoch_start_time
@@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
257275 return model .evaluate (
258276 dataset ,
259277 steps = self ._steps_per_eval ,
260- callbacks = self ._eval_callbacks ,
278+ callbacks = self .eval_callbacks ,
261279 )
262280
263281 def train_and_evaluate (self , task : KerasTask ) -> core .Logs :
@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
277295 steps_per_epoch = self ._steps_per_loop ,
278296 # Explicitly set to None for deterministic evaluation.
279297 validation_steps = None ,
280- callbacks = self ._train_callbacks ,
298+ callbacks = self .train_callbacks ,
281299 )
282300 model .summary (print_fn = logging .info )
283301
@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
308326 else :
309327 steps_msg = "running complete evaluation..."
310328
329+ use_legacy_checkpoint_format = self ._legacy_checkpoint_format
330+
311331 class _RestoreCallback (keras .callbacks .Callback ):
332+ """Callback for restoring the model from the latest checkpoint."""
312333
313334 def __init__ (
314335 self ,
@@ -319,9 +340,14 @@ def __init__(
319340 self ._epoch = epoch
320341
321342 def on_test_begin (self , logs : Mapping [str , Any ] | None = None ):
322- keras_utils .restore_keras_model (
323- model , self ._checkpoint_dir , step = self ._epoch
324- )
343+ if use_legacy_checkpoint_format :
344+ keras_utils .restore_keras_model (
345+ model , self ._checkpoint_dir , step = self ._epoch
346+ )
347+ else :
348+ keras_utils .restore_keras_checkpoint (
349+ self ._checkpoint_dir , model = model , epoch = self ._epoch
350+ )
325351
326352 history = None
327353 for epoch in ocp .checkpoint_utils .checkpoints_iterator (
@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
332358 restore_callback = _RestoreCallback (self ._checkpoint_dir , epoch )
333359 [tb_cbk ] = [
334360 cbk
335- for cbk in self ._eval_callbacks
361+ for cbk in self .eval_callbacks
336362 if isinstance (cbk , keras_utils .EpochSummaryCallback )
337363 ]
338364 try :
@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
346372 history = model .evaluate (
347373 eval_dataset ,
348374 steps = self ._steps_per_eval ,
349- callbacks = [restore_callback ] + self ._eval_callbacks ,
375+ callbacks = [restore_callback ] + self .eval_callbacks ,
350376 return_dict = True ,
351377 )
352378
0 commit comments