diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 95b9b7f..0e1191c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.8.0 diff --git a/vit_jax/train.py b/vit_jax/train.py index a1a53dc..f67b6a5 100644 --- a/vit_jax/train.py +++ b/vit_jax/train.py @@ -113,8 +113,6 @@ def init_model(): filename = best.filename logging.info('Selected fillename="%s" for "%s" with final_val=%.3f', filename, model_or_filename, best.final_val) - pretrained_path = os.path.join(config.pretrained_dir, - f'{config.model.model_name}.npz') else: # ViT / Mixer papers filename = config.model.model_name @@ -140,7 +138,7 @@ def init_model(): optax.sgd( learning_rate=lr_fn, momentum=0.9, - accumulator_dtype='bfloat16', + accumulator_dtype=config.optim_dtype, ), ) @@ -212,7 +210,7 @@ def init_model(): (step == total_steps)): accuracies = [] - lt0 = time.time() + tt0 = time.time() for test_batch in input_pipeline.prefetch(ds_test, config.prefetch): logits = infer_fn_repl( dict(params=params_repl), test_batch['image']) @@ -223,8 +221,7 @@ def init_model(): accuracy_test = np.mean(accuracies) img_sec_core_test = ( config.batch_eval * ds_test.cardinality().numpy() / - (time.time() - lt0) / jax.device_count()) - lt0 = time.time() + (time.time() - tt0) / jax.device_count()) lr = float(lr_fn(step)) logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation @@ -237,14 +234,17 @@ def init_model(): accuracy_test=accuracy_test, lr=lr, img_sec_core_test=img_sec_core_test)) + lt0 += time.time() - tt0 # Store checkpoint. if ((config.checkpoint_every and step % config.eval_every == 0) or step == total_steps): + tt0 = time.time() checkpoint_path = flax_checkpoints.save_checkpoint( workdir, (flax.jax_utils.unreplicate(params_repl), flax.jax_utils.unreplicate(opt_state_repl), step), step) logging.info('Stored checkpoint at step %d to "%s"', step, checkpoint_path) + lt0 += time.time() - tt0 return flax.jax_utils.unreplicate(params_repl)