
drop_last is not respected #442

@robmarkcole

Description


🐛 Bug

I create the loader as follows:

from litdata import StreamingDataLoader

StreamingDataLoader(
    dataset=dataset,
    batch_size=self.batch_size,
    shuffle=(split == "train"),
    num_workers=self.num_workers,
    collate_fn=self.collate_fn,
    prefetch_factor=self.prefetch_factor,
    persistent_workers=self.persistent_workers,
    multiprocessing_context=self.multiprocessing_context,
    drop_last=(split == "train"),
)

With batch_size=2, logging the size of each batch received during training shows that the last one has size 1, even though drop_last is True for the train split:

Epoch 0:  97%|...| 99/102 [00:12<00:00,  8.22it/s, v_num=749d]
train size of y:  torch.Size([2, 320, 320])
Epoch 0:  98%|...| 100/102 [00:12<00:00,  8.27it/s, v_num=749d]
train size of y:  torch.Size([1, 320, 320])
Traceback (most recent call last):
  File "/code/lightning_ai/cli.py", line 157, in <module>
    run(obj={})
  File "/usr/local/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/click/decorators.py", line 33, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "/code/lightning_ai/cli.py", line 135, in train
    run_trainer(lightning_cli.model, lightning_cli.trainer, lightning_cli.datamodule)
  File "/code/common/mlclient.py", line 712, in wrapper
    result = func(*args, **kwargs)
  File "/code/lightning_ai/cli.py", line 123, in run_trainer
    lightning_cli.trainer.fit(model, datamodule=datamodule)
  File "/usr/local/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 578, in safe_patch_function
    patch_function(call_original, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 251, in patch_with_managed_run
    result = patch_function(original, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/mlflow/pytorch/_lightning_autolog.py", line 537, in patched_fit
    result = original(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 559, in call_original
    return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
  File "/usr/local/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 494, in call_original_fn_with_event_logging
    original_fn_result = original_fn(*og_args, **og_kwargs)
  File "/usr/local/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 556, in _original_fn
    original_result = original(*_og_args, **_og_kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1307, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/optim/adam.py", line 202, in step
    loss = closure()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/code/common/model/base_lightning_net.py", line 320, in training_step
    logits = self.forward(x)
  File "/code/common/model/base_lightning_net.py", line 282, in forward
    return self.net(x)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py", line 30, in forward
    decoder_output = self.decoder(*features)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/segmentation_models_pytorch/decoders/deeplabv3/decoder.py", line 99, in forward
    aspp_features = self.aspp(features[-1])
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/segmentation_models_pytorch/decoders/deeplabv3/decoder.py", line 187, in forward
    res.append(conv(x))
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/segmentation_models_pytorch/decoders/deeplabv3/decoder.py", line 151, in forward
    x = mod(x)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 193, in forward
    return F.batch_norm(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/functional.py", line 2810, in batch_norm
    _verify_batch_size(input.size())
  File "/usr/local/lib/python3.10/site-packages/torch/nn/functional.py", line 2776, in _verify_batch_size
    raise ValueError(
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])
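In training mode, BatchNorm normalizes each channel with statistics computed over the batch and spatial dimensions, so it needs more than one value per channel. The tensor reaching it here is torch.Size([1, 256, 1, 1]): one sample with a 1×1 spatial map, which leaves exactly one value per channel. The sketch below restates that condition in plain Python. The function names are hypothetical and this is a simplified model of the check behind the error message, not PyTorch's own code.

```python
from math import prod
from typing import Sequence


def values_per_channel(size: Sequence[int]) -> int:
    """Values BatchNorm sees per channel for an input of shape (N, C, *spatial)."""
    # Statistics are taken over the batch dim and every spatial dim, not over channels.
    return size[0] * prod(size[2:])


def check_batch_norm_input(size: Sequence[int]) -> None:
    """Raise if a training-mode BatchNorm would have only one value per channel."""
    if values_per_channel(size) == 1:
        raise ValueError(
            "Expected more than 1 value per channel when training, "
            f"got input size {tuple(size)}"
        )
```

With the shapes from this report, (2, 256, 1, 1) gives 2 values per channel and passes, while (1, 256, 1, 1) gives 1 and raises. A batch of 1 would still pass with a larger spatial map such as (1, 256, 20, 20), which is why the failure surfaces at the 1×1 input.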

To Reproduce


Code sample
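A minimal sketch of the batch-size check described above: iterate over the loader once and collect every batch whose size differs from batch_size. The helper `find_short_batches` and its `size_of` argument are hypothetical; `size_of` has to be written for whatever structure the collate_fn returns.

```python
from typing import Any, Callable, Iterable, List, Tuple


def find_short_batches(
    batches: Iterable[Any],
    batch_size: int,
    size_of: Callable[[Any], int],
) -> List[Tuple[int, int]]:
    """Return (batch_index, size) for every batch whose size is not batch_size.

    With drop_last=True, this list is expected to be empty.
    """
    short = []
    for index, batch in enumerate(batches):
        size = size_of(batch)  # e.g. the leading dimension of the target tensor
        if size != batch_size:
            short.append((index, size))
    return short
```

For example, assuming (x, y) batches, `find_short_batches(loader, 2, lambda batch: batch[1].shape[0])` would be expected to return an empty list when drop_last=True; the log in this report corresponds to a non-empty result containing a size-1 batch.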

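Expected behavior

With drop_last=True (the train split here), the final incomplete batch should be dropped, so with batch_size=2 no training batch of size 1 should be produced.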
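As a reference point, the batch sizes a single-process loader should yield over a given number of samples can be sketched as follows. `expected_batch_sizes` is a hypothetical helper, and it does not model how litdata splits samples across workers.

```python
from typing import List


def expected_batch_sizes(num_samples: int, batch_size: int, drop_last: bool) -> List[int]:
    """Sizes of the batches produced when num_samples items are batched in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    # Only the final, incomplete batch is affected by drop_last.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes
```

For example, 5 samples with batch_size=2 give [2, 2, 1] with drop_last=False and [2, 2] with drop_last=True; a size-1 batch like the one in the log is exactly the remainder that drop_last=True should remove.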

Additional context

litdata==0.2.34

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working), help wanted (Extra attention is needed)
