Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ perf:
# Use Apex GroupNorm (optimized normalization for performance with channelslast layout)
profile_mode: false
# Enable NVTX annotations for performance profiling
io_syncronous: true
io_synchronous: true
# Synchronize I/O operations for writing inference results

17 changes: 8 additions & 9 deletions examples/weather/corrdiff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def generate_fn():
has_lead_time=has_lead_time,
)

if cfg.generation.perf.io_syncronous:
if cfg.generation.perf.io_synchronous:
writer_executor = ThreadPoolExecutor(
max_workers=cfg.generation.perf.num_writer_workers
)
Expand All @@ -383,8 +383,9 @@ def elapsed_time(self, _):
start = end = DummyEvent()

times = dataset.time()
for index, (image_tar, image_lr, *lead_time_label) in enumerate(
iter(data_loader)
for dataset_index, (image_tar, image_lr, *lead_time_label) in zip(
sampler,
iter(data_loader),
):
time_index += 1
if dist.rank == 0:
Expand All @@ -407,7 +408,7 @@ def elapsed_time(self, _):
image_out = generate_fn()
if dist.rank == 0:
batch_size = image_out.shape[0]
if cfg.generation.perf.io_syncronous:
if cfg.generation.perf.io_synchronous:
# write out data in a separate thread so we don't hold up inferencing
writer_threads.append(
writer_executor.submit(
Expand All @@ -419,8 +420,7 @@ def elapsed_time(self, _):
image_tar.cpu(),
image_lr.cpu(),
time_index,
index,
has_lead_time,
dataset_index,
)
)
else:
Expand All @@ -432,8 +432,7 @@ def elapsed_time(self, _):
image_tar.cpu(),
image_lr.cpu(),
time_index,
index,
has_lead_time,
dataset_index,
)
end.record()
end.synchronize()
Expand All @@ -451,7 +450,7 @@ def elapsed_time(self, _):
)

# make sure all the workers are done writing
if dist.rank == 0 and cfg.generation.perf.io_syncronous:
if dist.rank == 0 and cfg.generation.perf.io_synchronous:
for thread in list(writer_threads):
thread.result()
writer_threads.remove(thread)
Expand Down
21 changes: 10 additions & 11 deletions examples/weather/corrdiff/helpers/generate_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def save_images(
image_tar,
image_lr,
time_index,
t_index,
has_lead_time,
dataset_index,
):
"""
Saves inferencing result along with the baseline
Expand All @@ -71,7 +70,7 @@ def save_images(
image_tar (torch.Tensor): Ground truth data
image_lr (torch.Tensor): Low resolution input data
time_index (int): Epoch number
t_index (int): index where times are located
dataset_index (int): index into the dataset's time array for this sample
"""
# weather sub-plot
image_lr2 = image_lr[0].unsqueeze(0)
Expand All @@ -95,7 +94,7 @@ def save_images(
image_out2 = image_out2.cpu().numpy()
image_out2 = dataset.denormalize_output(image_out2)

time = times[t_index]
time = times[dataset_index]
writer.write_time(time_index, time)
for channel_idx in range(image_out2.shape[1]):
info = dataset.output_channels()[channel_idx]
Expand All @@ -107,10 +106,10 @@ def save_images(
channel_name, time_index, idx, image_out2[0, channel_idx]
)

input_channel_info = dataset.input_channels()
for channel_idx in range(len(input_channel_info)):
info = input_channel_info[channel_idx]
channel_name = info.name + info.level
writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])
if channel_idx == image_lr2.shape[1] - 1:
break
input_channel_info = dataset.input_channels()
for channel_idx in range(len(input_channel_info)):
info = input_channel_info[channel_idx]
channel_name = info.name + info.level
writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])
if channel_idx == image_lr2.shape[1] - 1:
break
14 changes: 7 additions & 7 deletions examples/weather/corrdiff/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,13 +830,13 @@ def main(cfg: DictConfig) -> None:
epoch=cur_nimg,
)

# Retain only the recent n checkpoints, if desired
if cfg.training.io.save_n_recent_checkpoints > 0:
for suffix in [".mdlus", ".pt"]:
ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
while len(ckpts) > cfg.training.io.save_n_recent_checkpoints:
os.remove(os.path.join(checkpoint_dir, ckpts[0]))
ckpts = ckpts[1:]
# Retain only the recent n checkpoints, if desired
if cfg.training.io.save_n_recent_checkpoints > 0:
for suffix in [".mdlus", ".pt"]:
ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
while len(ckpts) > cfg.training.io.save_n_recent_checkpoints:
os.remove(os.path.join(checkpoint_dir, ckpts[0]))
ckpts = ckpts[1:]

# Done.
logger0.info("Training Completed.")
Expand Down