diff --git a/vit_jax/train.py b/vit_jax/train.py index 9159fb7..3fecaf9 100644 --- a/vit_jax/train.py +++ b/vit_jax/train.py @@ -239,7 +239,7 @@ def init_model(): img_sec_core_test=img_sec_core_test)) # Store checkpoint. - if ((config.checkpoint_every and step % config.eval_every == 0) or + if ((config.checkpoint_every and step % config.checkpoint_every == 0) or step == total_steps): checkpoint_path = flax_checkpoints.save_checkpoint( workdir, (flax.jax_utils.unreplicate(params_repl),