1818import abc
1919from collections .abc import Mapping
2020import dataclasses
21+ import functools
2122import gc
2223import os
2324import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
9697 model: The Keras model constructed by `create_model`.
9798 model_dir: The model directory passed to the trainer.
9899 """
99- model .save (os .path .join (model_dir , core .KERAS_MODEL_SAVEFILE ))
100100
101101
102102class KerasTrainer (core .Trainer [KerasTask ]):
@@ -118,6 +118,7 @@ def __init__(
118118 max_checkpoints_to_keep : int = 5 ,
119119 checkpoint_save_interval_epochs : int = 1 ,
120120 rng_seed : int = core .DEFAULT_RNG_SEED ,
121+ legacy_checkpoint_format : bool = True ,
121122 ):
122123 """Initializes the instance."""
123124
@@ -143,60 +144,77 @@ def __init__(
143144 self ._steps_per_eval = steps_per_eval
144145 self ._continuous_eval_timeout = continuous_eval_timeout
145146 self ._steps_per_loop = steps_per_loop
146- self ._checkpoint_manager = None
147147 self ._marker_path = os .path .join (
148148 model_dir , core .TRAINING_COMPLETE_MARKER_FILE
149149 )
150150 self ._checkpoint_dir = os .path .join (model_dir , core .CHECKPOINT_DIR )
151+ self ._max_checkpoints_to_keep = max_checkpoints_to_keep
152+ self ._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
153+ self ._legacy_checkpoint_format = legacy_checkpoint_format
151154
155+ @functools .cached_property
156+ def train_callbacks (self ) -> list [keras .callbacks .Callback ]:
157+ """Returns the training callbacks."""
152158 if keras .backend .backend () == "jax" :
153- self ._checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
154- checkpoint_dir = self ._checkpoint_dir ,
155- max_to_keep = max_checkpoints_to_keep ,
156- save_interval_epochs = checkpoint_save_interval_epochs ,
157- )
158- self ._train_callbacks = [
159+ if self ._legacy_checkpoint_format :
160+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
161+ checkpoint_dir = self ._checkpoint_dir ,
162+ max_to_keep = self ._max_checkpoints_to_keep ,
163+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
164+ )
165+ else :
166+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManagerV2 (
167+ checkpoint_dir = self ._checkpoint_dir ,
168+ max_to_keep = self ._max_checkpoints_to_keep ,
169+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
170+ )
171+ return [
159172 keras_utils .EpochSummaryCallback (
160- log_dir = os .path .join (model_dir , core .LOG_DIR ),
161- steps_per_epoch = steps_per_loop ,
173+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
174+ steps_per_epoch = self . _steps_per_loop ,
162175 write_steps_per_second = True ,
163176 ),
164177 keras_utils .EpochOrbaxCheckpointAndRestoreCallback (
165- checkpoint_manager = self . _checkpoint_manager ,
178+ checkpoint_manager = checkpoint_manager ,
166179 marker_path = self ._marker_path ,
167180 ),
168181 ]
169- self ._eval_callbacks = [
182+ return [
183+ keras .callbacks .TensorBoard (
184+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
185+ write_steps_per_second = True ,
186+ ),
187+ keras .callbacks .BackupAndRestore (
188+ backup_dir = os .path .join (self ._model_dir , core .BACKUP_DIR ),
189+ ),
190+ keras .callbacks .ModelCheckpoint (
191+ filepath = os .path .join (
192+ self ._model_dir ,
193+ core .CHECKPOINT_DIR ,
194+ "ckpt-{epoch:d}.weights.h5" ,
195+ ),
196+ save_weights_only = True ,
197+ verbose = 1 ,
198+ ),
199+ ]
200+
201+ @functools .cached_property
202+ def eval_callbacks (self ) -> list [keras .callbacks .Callback ]:
203+ """Returns the evaluation callbacks."""
204+ if keras .backend .backend () == "jax" :
205+ return [
170206 keras_utils .EpochSummaryCallback (
171- log_dir = os .path .join (model_dir , core .LOG_DIR ),
172- steps_per_epoch = steps_per_loop ,
207+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
208+ steps_per_epoch = self . _steps_per_loop ,
173209 write_steps_per_second = False ,
174210 ),
175211 ]
176- else :
177- self ._checkpoint_manager = None
178- self ._train_callbacks = [
179- keras .callbacks .TensorBoard (
180- log_dir = os .path .join (model_dir , core .LOG_DIR ),
181- write_steps_per_second = True ,
182- ),
183- keras .callbacks .BackupAndRestore (
184- backup_dir = os .path .join (model_dir , core .BACKUP_DIR ),
185- ),
186- keras .callbacks .ModelCheckpoint (
187- filepath = os .path .join (
188- model_dir , core .CHECKPOINT_DIR , "ckpt-{epoch:d}.weights.h5"
189- ),
190- save_weights_only = True ,
191- verbose = 1 ,
192- ),
193- ]
194- self ._eval_callbacks = [
195- keras .callbacks .TensorBoard (
196- log_dir = os .path .join (model_dir , core .LOG_DIR ),
197- write_steps_per_second = True ,
198- ),
199- ]
212+ return [
213+ keras .callbacks .TensorBoard (
214+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
215+ write_steps_per_second = True ,
216+ ),
217+ ]
200218
201219 def _maybe_get_model_kws (
202220 self , task : KerasTask , dataset : tf .data .Dataset
@@ -216,9 +234,11 @@ def train(self, task: KerasTask) -> core.Logs:
216234
217235 history = model .fit (
218236 dataset ,
219- epochs = self ._train_epochs ,
237+ epochs = self ._train_epochs + 1 ,
220238 steps_per_epoch = self ._steps_per_loop ,
221- callbacks = self ._train_callbacks ,
239+ callbacks = self .train_callbacks ,
240+ initial_epoch = 1 ,
241+ verbose = 0 , # Disable progbar for better TPU utilization
222242 )
223243 model .summary (print_fn = logging .info )
224244
@@ -237,15 +257,16 @@ def evaluate(self, task: KerasTask) -> core.Logs:
237257 if keras .backend .backend () == "jax" :
238258 [tb_cbk ] = [
239259 cbk
240- for cbk in self ._eval_callbacks
260+ for cbk in self .eval_callbacks
241261 if isinstance (cbk , keras_utils .EpochSummaryCallback )
242262 ]
243263 epoch_start_time = time .time ()
244264 history = model .evaluate (
245265 dataset ,
246266 steps = self ._steps_per_eval ,
247- callbacks = self ._eval_callbacks ,
267+ callbacks = self .eval_callbacks ,
248268 return_dict = True ,
269+ verbose = 0 , # Disable progbar for better TPU utilization
249270 )
250271 epoch_dt = time .time () - epoch_start_time
251272 steps_per_second = self ._steps_per_eval / epoch_dt
@@ -257,7 +278,8 @@ def evaluate(self, task: KerasTask) -> core.Logs:
257278 return model .evaluate (
258279 dataset ,
259280 steps = self ._steps_per_eval ,
260- callbacks = self ._eval_callbacks ,
281+ callbacks = self .eval_callbacks ,
282+ verbose = 0 , # Disable progbar for better TPU utilization
261283 )
262284
263285 def train_and_evaluate (self , task : KerasTask ) -> core .Logs :
@@ -273,11 +295,13 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
273295 history = model .fit (
274296 train_dataset ,
275297 validation_data = eval_dataset ,
276- epochs = self ._train_epochs ,
298+ epochs = self ._train_epochs + 1 ,
277299 steps_per_epoch = self ._steps_per_loop ,
278300 # Explicitly set to None for deterministic evaluation.
279301 validation_steps = None ,
280- callbacks = self ._train_callbacks ,
302+ callbacks = self .train_callbacks ,
303+ initial_epoch = 1 ,
304+ verbose = 0 , # Disable progbar for better TPU utilization
281305 )
282306 model .summary (print_fn = logging .info )
283307
@@ -308,7 +332,10 @@ def timeout_fn() -> bool:
308332 else :
309333 steps_msg = "running complete evaluation..."
310334
335+ use_legacy_checkpoint_format = self ._legacy_checkpoint_format
336+
311337 class _RestoreCallback (keras .callbacks .Callback ):
338+ """Callback for restoring the model from the latest checkpoint."""
312339
313340 def __init__ (
314341 self ,
@@ -319,9 +346,14 @@ def __init__(
319346 self ._epoch = epoch
320347
321348 def on_test_begin (self , logs : Mapping [str , Any ] | None = None ):
322- keras_utils .restore_keras_model (
323- model , self ._checkpoint_dir , step = self ._epoch
324- )
349+ if use_legacy_checkpoint_format :
350+ keras_utils .restore_keras_model (
351+ model , self ._checkpoint_dir , step = self ._epoch
352+ )
353+ else :
354+ keras_utils .restore_keras_checkpoint (
355+ self ._checkpoint_dir , model = model , epoch = self ._epoch
356+ )
325357
326358 history = None
327359 for epoch in ocp .checkpoint_utils .checkpoints_iterator (
@@ -332,7 +364,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
332364 restore_callback = _RestoreCallback (self ._checkpoint_dir , epoch )
333365 [tb_cbk ] = [
334366 cbk
335- for cbk in self ._eval_callbacks
367+ for cbk in self .eval_callbacks
336368 if isinstance (cbk , keras_utils .EpochSummaryCallback )
337369 ]
338370 try :
@@ -346,8 +378,9 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
346378 history = model .evaluate (
347379 eval_dataset ,
348380 steps = self ._steps_per_eval ,
349- callbacks = [restore_callback ] + self ._eval_callbacks ,
381+ callbacks = [restore_callback ] + self .eval_callbacks ,
350382 return_dict = True ,
383+ verbose = 0 , # Disable progbar for better TPU utilization
351384 )
352385
353386 logging .info (
0 commit comments