@@ -321,9 +321,9 @@ def setup(self, cfg: DictConfig) -> None:
             ),
         )
 
-        self._val_sampler, self._val_dataloader = None, None
+        self._val_dataloader = None
         if cfg.get("dataset_validation") is not None:
-            self._val_sampler, self._val_dataloader = self._setup_data(
+            self._val_dataloader = self._setup_data(
                 cfg_dataset=cfg.dataset_validation,
                 batch_size=cfg.batch_size,
                 collate_fn=collate_name,
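A note on the gate above: the validation dataloader is only built when the config defines a `dataset_validation` entry. A minimal sketch of that lookup with an OmegaConf config (the dataset fields here are hypothetical, not taken from this PR):

```python
from omegaconf import OmegaConf

# Hypothetical config sketch: `dataset_validation` is an optional top-level
# key, so `cfg.get("dataset_validation")` returns None when it is omitted
# and the recipe skips building a validation dataloader entirely.
cfg = OmegaConf.create(
    {
        "batch_size": 4,
        "dataset": {"_component_": "torchtune.datasets.alpaca_dataset"},
        # "dataset_validation": {"_component_": "torchtune.datasets.alpaca_dataset"},
    }
)

assert cfg.get("dataset_validation") is None  # validation stays disabled
```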
@@ -897,8 +897,10 @@ def train(self) -> None:
                 self._profiler.step()
 
                 # Run validation after gradient update
-                if (self._run_val_every_n_steps is not None and
-                        self.global_step % self._run_val_every_n_steps == 0):
+                if (
+                    self._run_val_every_n_steps is not None
+                    and self.global_step % self._run_val_every_n_steps == 0
+                ):
                     pbar.refresh()
                     val_loss = self.validate()
                     if self._is_rank_zero:
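The reformatted condition reads more clearly, but behavior is unchanged: validation runs after the gradient update whenever `global_step` hits a multiple of `_run_val_every_n_steps`, and a `None` interval disables it. A standalone sketch of the same gate:

```python
from typing import Optional

# Minimal sketch of the gate above: validation fires every
# `run_val_every_n_steps` optimizer steps, and None disables it.
def should_validate(global_step: int, run_val_every_n_steps: Optional[int]) -> bool:
    return (
        run_val_every_n_steps is not None
        and global_step % run_val_every_n_steps == 0
    )

assert should_validate(100, 50)        # multiple of the interval
assert not should_validate(101, 50)    # off the boundary
assert not should_validate(100, None)  # validation disabled
```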
@@ -931,8 +933,10 @@ def validate(self) -> float:
 
         with torch.no_grad():
             for batch_idx, batch in enumerate(self._val_dataloader):
-                if (self._max_validation_batches > 0 and
-                        batch_idx >= self._max_validation_batches):
+                if (
+                    self._max_validation_batches > 0
+                    and batch_idx >= self._max_validation_batches
+                ):
                     break
 
                 utils.batch_to_device(batch, self._device)
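The same multi-line style is applied in `validate()`: `_max_validation_batches` caps how many batches are evaluated, and a non-positive value appears to mean "no cap". A small sketch of that loop shape (names hypothetical):

```python
from typing import Iterable, Iterator

# Sketch of the cap applied in `validate()` above: a non-positive
# `max_validation_batches` consumes the whole validation dataloader.
def capped_batches(batches: Iterable, max_validation_batches: int) -> Iterator:
    for batch_idx, batch in enumerate(batches):
        if max_validation_batches > 0 and batch_idx >= max_validation_batches:
            break
        yield batch

assert len(list(capped_batches(range(10), 3))) == 3    # capped at 3 batches
assert len(list(capped_batches(range(10), -1))) == 10  # -1 disables the cap
```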
@@ -942,7 +946,6 @@ def validate(self) -> float:
                     batch["labels"] != self._loss_fn.ignore_index
                 ).sum()
 
-
                 labels = batch.pop("labels")
 
                 with self.activations_handling_ctx:
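For context, the `.sum()` above counts only the label positions that participate in the loss. A tiny sketch of that masking, assuming the conventional `ignore_index` of -100:

```python
import torch

# Labels equal to the loss function's `ignore_index` (conventionally -100
# for padded/prompt positions) are excluded from the validation-loss
# normalizer; the example labels here are made up.
ignore_index = -100  # default for torch.nn.CrossEntropyLoss
labels = torch.tensor([5, 7, -100, -100, 2])
current_num_tokens = (labels != ignore_index).sum()
assert current_num_tokens.item() == 3
```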
@@ -972,7 +975,7 @@ def validate(self) -> float:
         avg_val_loss = (
             (total_val_loss / total_val_tokens).item()
             if total_val_tokens > 0
-            else float('inf')
+            else float("inf")
         )
 
         self._model.train()
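The `float("inf")` change only normalizes quote style. For reference, a sketch of the whole reduction, assuming (as the surrounding code suggests) that per-batch losses are accumulated as token-weighted sums:

```python
import torch

# Token-weighted batch losses are summed and normalized once at the end;
# float("inf") guards the case where no non-ignored tokens were seen.
# The (per-token loss, token count) pairs below are made up.
total_val_loss = torch.tensor(0.0)
total_val_tokens = torch.tensor(0)
for per_token_loss, num_tokens in [(2.0, 100), (3.0, 50)]:
    total_val_loss += per_token_loss * num_tokens  # re-weight by token count
    total_val_tokens += num_tokens

avg_val_loss = (
    (total_val_loss / total_val_tokens).item()
    if total_val_tokens > 0
    else float("inf")
)
print(f"{avg_val_loss:.4f}")  # 2.3333
```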