diff --git a/examples/weather/corrdiff/conf/base/generation/base_all.yaml b/examples/weather/corrdiff/conf/base/generation/base_all.yaml index 98fac900ed..2cbca0647e 100644 --- a/examples/weather/corrdiff/conf/base/generation/base_all.yaml +++ b/examples/weather/corrdiff/conf/base/generation/base_all.yaml @@ -46,6 +46,6 @@ perf: # Use Apex GroupNorm (optimized normalization for performance with channelslast layout) profile_mode: false # Enable NVTX annotations for performance profiling - io_syncronous: true + io_synchronous: true # Synchronize I/O operations for writing inference results diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index 21e1233b1d..5e1e699427 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -355,7 +355,7 @@ def generate_fn(): has_lead_time=has_lead_time, ) - if cfg.generation.perf.io_syncronous: + if cfg.generation.perf.io_synchronous: writer_executor = ThreadPoolExecutor( max_workers=cfg.generation.perf.num_writer_workers ) @@ -381,8 +381,9 @@ def elapsed_time(self, _): start = end = DummyEvent() times = dataset.time() - for index, (image_tar, image_lr, *lead_time_label) in enumerate( - iter(data_loader) + for dataset_index, (image_tar, image_lr, *lead_time_label) in zip( + sampler, + iter(data_loader), ): time_index += 1 if dist.rank == 0: @@ -405,7 +406,7 @@ def elapsed_time(self, _): image_out = generate_fn() if dist.rank == 0: batch_size = image_out.shape[0] - if cfg.generation.perf.io_syncronous: + if cfg.generation.perf.io_synchronous: # write out data in a seperate thread so we don't hold up inferencing writer_threads.append( writer_executor.submit( @@ -417,8 +418,7 @@ def elapsed_time(self, _): image_tar.cpu(), image_lr.cpu(), time_index, - index, - has_lead_time, + dataset_index, ) ) else: @@ -430,8 +430,7 @@ def elapsed_time(self, _): image_tar.cpu(), image_lr.cpu(), time_index, - index, - has_lead_time, + dataset_index, ) end.record() end.synchronize() @@ -449,7 +448,7 @@ def elapsed_time(self, _): ) # make sure all the workers are done writing - if dist.rank == 0 and cfg.generation.perf.io_syncronous: + if dist.rank == 0 and cfg.generation.perf.io_synchronous: for thread in list(writer_threads): thread.result() writer_threads.remove(thread) diff --git a/examples/weather/corrdiff/helpers/generate_helpers.py b/examples/weather/corrdiff/helpers/generate_helpers.py index e970bb7aed..abfa4fee87 100644 --- a/examples/weather/corrdiff/helpers/generate_helpers.py +++ b/examples/weather/corrdiff/helpers/generate_helpers.py @@ -51,8 +51,7 @@ def save_images( image_tar, image_lr, time_index, - t_index, - has_lead_time, + dataset_index, ): """ Saves inferencing result along with the baseline @@ -71,7 +70,7 @@ def save_images( image_tar (torch.Tensor): Ground truth data image_lr (torch.Tensor): Low resolution input data time_index (int): Epoch number - t_index (int): index where times are located + dataset_index (int): index where times are located """ # weather sub-plot image_lr2 = image_lr[0].unsqueeze(0) @@ -95,7 +94,7 @@ def save_images( image_out2 = image_out2.cpu().numpy() image_out2 = dataset.denormalize_output(image_out2) - time = times[t_index] + time = times[dataset_index] writer.write_time(time_index, time) for channel_idx in range(image_out2.shape[1]): info = dataset.output_channels()[channel_idx] @@ -107,10 +106,10 @@ def save_images( channel_name, time_index, idx, image_out2[0, channel_idx] ) - input_channel_info = dataset.input_channels() - for channel_idx in range(len(input_channel_info)): - info = input_channel_info[channel_idx] - channel_name = info.name + info.level - writer.write_input(channel_name, time_index, image_lr2[0, channel_idx]) - if channel_idx == image_lr2.shape[1] - 1: - break + input_channel_info = dataset.input_channels() + for channel_idx in range(len(input_channel_info)): + info = input_channel_info[channel_idx] + channel_name = info.name + info.level + writer.write_input(channel_name, time_index, image_lr2[0, channel_idx]) + if channel_idx == image_lr2.shape[1] - 1: + break diff --git a/examples/weather/corrdiff/train.py b/examples/weather/corrdiff/train.py index bc1f838ee1..d56c367cc1 100644 --- a/examples/weather/corrdiff/train.py +++ b/examples/weather/corrdiff/train.py @@ -828,13 +828,15 @@ def main(cfg: DictConfig) -> None: epoch=cur_nimg, ) - # Retain only the recent n checkpoints, if desired - if cfg.training.io.save_n_recent_checkpoints > 0: - for suffix in [".mdlus", ".pt"]: - ckpts = checkpoint_list(checkpoint_dir, suffix=suffix) - while len(ckpts) > cfg.training.io.save_n_recent_checkpoints: - os.remove(os.path.join(checkpoint_dir, ckpts[0])) - ckpts = ckpts[1:] + # Retain only the recent n checkpoints, if desired + if cfg.training.io.save_n_recent_checkpoints > 0: + for suffix in [".mdlus", ".pt"]: + ckpts = checkpoint_list(checkpoint_dir, suffix=suffix) + while ( + len(ckpts) > cfg.training.io.save_n_recent_checkpoints + ): + os.remove(os.path.join(checkpoint_dir, ckpts[0])) + ckpts = ckpts[1:] # Done. logger0.info("Training Completed.")