Commit bac2247 (parent: fcdb7c0)

Change validation to use StatefulDataLoader
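For reference, torchdata's StatefulDataLoader is a drop-in replacement for torch.utils.data.DataLoader that adds state_dict()/load_state_dict(), so a recipe can checkpoint and resume iteration without holding a separate sampler handle. A minimal sketch of the shape this commit implies (illustrative only; the actual _setup_data in these recipes takes the recipe's cfg objects, and its internals are not shown in this diff):

    # Sketch, assuming torchdata is installed; parameter names are illustrative.
    from torch.utils.data.distributed import DistributedSampler
    from torchdata.stateful_dataloader import StatefulDataLoader

    def _setup_data(dataset, batch_size, collate_fn, shuffle=False):
        # The sampler still shards data across ranks, but it is owned by
        # the dataloader instead of being stored on the recipe object.
        sampler = DistributedSampler(dataset, shuffle=shuffle)
        return StatefulDataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=collate_fn,
        )

Returning a single value is what lets the recipes below drop the self._val_sampler attribute.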

2 files changed (+22, -15)

recipes/full_finetune_distributed.py (+11, -7)

@@ -361,9 +361,9 @@ def setup(self, cfg: DictConfig) -> None:
         )
 
         # Setup validation dataloader if validation dataset is provided
-        self._val_sampler, self._val_dataloader = None, None
+        self._val_dataloader = None
         if cfg.get("dataset_validation") is not None:
-            self._val_sampler, self._val_dataloader = self._setup_data(
+            self._val_dataloader = self._setup_data(
                 cfg_dataset=cfg.dataset_validation,
                 batch_size=cfg.batch_size,
                 collate_fn=collate_name,
@@ -788,8 +788,10 @@ def validate(self) -> float:
 
         with torch.no_grad():
             for batch_idx, batch in enumerate(self._val_dataloader):
-                if (self._max_validation_batches > 0 and
-                        batch_idx >= self._max_validation_batches):
+                if (
+                    self._max_validation_batches > 0
+                    and batch_idx >= self._max_validation_batches
+                ):
                     break
 
                 utils.batch_to_device(batch, self._device)
@@ -813,7 +815,7 @@ def validate(self) -> float:
         avg_val_loss = (
             (total_val_loss / total_val_tokens).item()
             if total_val_tokens > 0
-            else float('inf')
+            else float("inf")
         )
 
         self._model.train()
@@ -965,8 +967,10 @@ def train(self) -> None:
                 self._profiler.step()
 
                 # Run validation after gradient update
-                if (self._run_val_every_n_steps is not None and
-                        self.global_step % self._run_val_every_n_steps == 0):
+                if (
+                    self._run_val_every_n_steps is not None
+                    and self.global_step % self._run_val_every_n_steps == 0
+                ):
                     pbar.refresh()
                     val_loss = self.validate()
                     if self._is_rank_zero:

recipes/lora_finetune_distributed.py (+11, -8)

@@ -321,9 +321,9 @@ def setup(self, cfg: DictConfig) -> None:
             ),
         )
 
-        self._val_sampler, self._val_dataloader = None, None
+        self._val_dataloader = None
         if cfg.get("dataset_validation") is not None:
-            self._val_sampler, self._val_dataloader = self._setup_data(
+            self._val_dataloader = self._setup_data(
                 cfg_dataset=cfg.dataset_validation,
                 batch_size=cfg.batch_size,
                 collate_fn=collate_name,
@@ -897,8 +897,10 @@ def train(self) -> None:
                 self._profiler.step()
 
                 # Run validation after gradient update
-                if (self._run_val_every_n_steps is not None and
-                        self.global_step % self._run_val_every_n_steps == 0):
+                if (
+                    self._run_val_every_n_steps is not None
+                    and self.global_step % self._run_val_every_n_steps == 0
+                ):
                     pbar.refresh()
                     val_loss = self.validate()
                     if self._is_rank_zero:
@@ -931,8 +933,10 @@ def validate(self) -> float:
 
         with torch.no_grad():
             for batch_idx, batch in enumerate(self._val_dataloader):
-                if (self._max_validation_batches > 0 and
-                        batch_idx >= self._max_validation_batches):
+                if (
+                    self._max_validation_batches > 0
+                    and batch_idx >= self._max_validation_batches
+                ):
                     break
 
                 utils.batch_to_device(batch, self._device)
@@ -942,7 +946,6 @@ def validate(self) -> float:
                     batch["labels"] != self._loss_fn.ignore_index
                 ).sum()
 
-
                 labels = batch.pop("labels")
 
                 with self.activations_handling_ctx:
@@ -972,7 +975,7 @@ def validate(self) -> float:
         avg_val_loss = (
             (total_val_loss / total_val_tokens).item()
             if total_val_tokens > 0
-            else float('inf')
+            else float("inf")
        )
 
         self._model.train()
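The property that presumably motivates the switch (context, not part of this diff): a StatefulDataLoader can be snapshotted mid-epoch and resumed, which a bare sampler/dataloader pair cannot. A standalone usage sketch, assuming torchdata is installed:

    from torchdata.stateful_dataloader import StatefulDataLoader

    loader = StatefulDataLoader(range(100), batch_size=10)
    it = iter(loader)
    next(it)                     # consume the first batch
    state = loader.state_dict()  # captures progress within the epoch

    # After a restart, restore the state and iteration continues from
    # the second batch instead of the start of the epoch.
    resumed = StatefulDataLoader(range(100), batch_size=10)
    resumed.load_state_dict(state)
    for batch in resumed:
        print(batch)  # first value printed is tensor([10, ..., 19])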
