From 20b09fb21cbf2b56ea4023775a250b3cb3829f97 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Thu, 10 Aug 2023 00:20:22 -0700 Subject: [PATCH] Fixes img/sec/core. Before this fix, only `lt0` but not `lstep` was updated after computing an evaluation, which led to a img/sec/core computation that was too high. Tis fix simply adds the time needed to eval / checkpoint to `lt0`, correcting the `dt` term in the `dsteps/dt` computation. PiperOrigin-RevId: 555396130 --- .github/workflows/build.yml | 2 +- vit_jax/train.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 95b9b7f..0e1191c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.8.0 diff --git a/vit_jax/train.py b/vit_jax/train.py index a1a53dc..8dfca4e 100644 --- a/vit_jax/train.py +++ b/vit_jax/train.py @@ -212,7 +212,7 @@ def init_model(): (step == total_steps)): accuracies = [] - lt0 = time.time() + tt0 = time.time() for test_batch in input_pipeline.prefetch(ds_test, config.prefetch): logits = infer_fn_repl( dict(params=params_repl), test_batch['image']) @@ -223,8 +223,7 @@ def init_model(): accuracy_test = np.mean(accuracies) img_sec_core_test = ( config.batch_eval * ds_test.cardinality().numpy() / - (time.time() - lt0) / jax.device_count()) - lt0 = time.time() + (time.time() - tt0) / jax.device_count()) lr = float(lr_fn(step)) logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation @@ -237,14 +236,17 @@ def init_model(): accuracy_test=accuracy_test, lr=lr, img_sec_core_test=img_sec_core_test)) + lt0 += time.time() - tt0 # Store checkpoint. if ((config.checkpoint_every and step % config.eval_every == 0) or step == total_steps): + tt0 = time.time() checkpoint_path = flax_checkpoints.save_checkpoint( workdir, (flax.jax_utils.unreplicate(params_repl), flax.jax_utils.unreplicate(opt_state_repl), step), step) logging.info('Stored checkpoint at step %d to "%s"', step, checkpoint_path) + lt0 += time.time() - tt0 return flax.jax_utils.unreplicate(params_repl)