CPU memory accumulation bug #20730

Open · wants to merge 8 commits into master
src/lightning/pytorch/loops/prediction_loop.py (16 additions, 21 deletions)
@@ -15,11 +15,9 @@
 from collections.abc import Iterator
 from typing import Any, Optional, Union

-import torch
 from lightning_utilities import WarningCache

 import lightning.pytorch as pl
-from lightning.fabric.utilities import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
@@ -247,32 +245,29 @@ def _predict_step(
         self.batch_progress.increment_started()

         # configure step_kwargs
-        step_args = (
-            self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
-            if not using_dataloader_iter
-            else (dataloader_iter,)
-        )
-        predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
-        if predictions is None:
-            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
+        step_args = self._build_step_args_from_hook_kwargs(hook_kwargs, "predict_step")
+        step_output = call._call_lightning_module_hook(trainer, "predict_step", *step_args)
Contributor: Why are you calling the lightning module hook directly without going through the strategy hook? After a couple of checks and the precision_plugin context, the strategy hook does call the lightning_module's predict_step.
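For context, the strategy-level call referenced here is roughly shaped like the sketch below (illustrative names only, not the actual Lightning source): the strategy dispatches to the LightningModule's predict_step inside the precision plugin's context, so calling the module hook directly skips that wrapping.

# Illustrative sketch only -- hypothetical class and attribute names, not the Lightning source.
from contextlib import contextmanager


class _SketchPrecisionPlugin:
    @contextmanager
    def predict_step_context(self):
        # a real precision plugin would enable autocast or other precision handling here
        yield


class _SketchStrategy:
    def __init__(self, lightning_module, precision_plugin=None):
        self.lightning_module = lightning_module
        self.precision_plugin = precision_plugin or _SketchPrecisionPlugin()

    def predict_step(self, *args, **kwargs):
        # after a couple of checks, dispatch to the LightningModule inside the precision context
        with self.precision_plugin.predict_step_context():
            return self.lightning_module.predict_step(*args, **kwargs)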


         self.batch_progress.increment_processed()

         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
             batch_idx = data_fetcher._batch_idx
             dataloader_idx = data_fetcher._dataloader_idx
             hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
+        # track batch indices for prediction writer
+        if not using_dataloader_iter and any_on_epoch:
+            self.current_batch_indices = self._get_batch_indices(data_fetcher.current_dataloader)
+
+        # track predictions if needed
+        if self.return_predictions:
+            self._predictions[dataloader_idx].append(step_output)
+        else:
+            # Clear memory if not returning predictions
+            import gc
+
+            gc.collect()
Contributor: Do you think it would be a good idea to have an argument such as collect_gc that users can toggle? As Adrian said, it might be expensive in certain situations.
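One way such a toggle could look, as a rough sketch (the collect_gc argument and the class below are hypothetical and not part of this PR or the Lightning API):

import gc


class _PredictionLoopSketch:
    """Hypothetical sketch of the suggestion: make the gc.collect() call opt-in via a flag."""

    def __init__(self, return_predictions: bool = False, collect_gc: bool = True) -> None:
        self.return_predictions = return_predictions
        self.collect_gc = collect_gc  # hypothetical user-facing toggle
        self._predictions: list = []

    def _track_or_release(self, step_output) -> None:
        if self.return_predictions:
            self._predictions.append(step_output)
        elif self.collect_gc:
            # pay the cost of a full collection only when the user opted in,
            # since gc.collect() can be expensive in some workloads
            gc.collect()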


-        call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
-        call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
+        call._call_callback_hooks(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())
+        call._call_lightning_module_hook(trainer, "on_predict_batch_end", step_output, *hook_kwargs.values())

         self.batch_progress.increment_completed()

-        if self._return_predictions or any_on_epoch:
-            self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))
     def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict:
         """Assembles the keyword arguments for the ``predict_step``
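For reviewers, a minimal user-level illustration of the behavior this diff targets (a sketch using the standard boring classes; not part of the changes):

from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

model = BoringModel()
dataloader = DataLoader(RandomDataset(32, 64), batch_size=2)

trainer = Trainer(accelerator="cpu", devices=1, logger=False, enable_checkpointing=False)
# With return_predictions=False nothing is appended to the loop's prediction list,
# so host memory should stay flat across batches; the call returns None.
preds = trainer.predict(model, dataloaders=dataloader, return_predictions=False)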
tests/tests_pytorch/conftest.py (9 additions, 0 deletions)
@@ -95,6 +95,15 @@ def restore_env_variables():
"TF_GRPC_DEFAULT_OPTIONS",
"XLA_FLAGS",
"TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile
# Memory leak test related
"PYTORCH_CUDA_ALLOC_CONF", # PyTorch memory allocator config
"CUDA_VISIBLE_DEVICES", # GPU visibility
"PYTORCH_NO_CUDA_MEMORY_CACHING", # Disable CUDA memory caching
# TensorFlow and TPU related
"ENABLE_RUNTIME_UPTIME_TELEMETRY", # TensorFlow telemetry
"TF2_BEHAVIOR", # TensorFlow 2.x behavior flag
"TPU_ML_PLATFORM", # TPU platform configuration
"TPU_ML_PLATFORM_VERSION", # TPU platform version
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
tests/tests_pytorch/trainer/test_memory_leak.py (81 additions, 0 deletions, new file)
@@ -0,0 +1,81 @@
import os

import psutil
import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(1000, 2) # Changed to match LargeDataset dim=1000

def forward(self, x):
return self.layer(x)


class LargeDataset(Dataset):
def __init__(self, size=1000, dim=1000):
self.data = torch.randn(size, dim)
self.targets = torch.randint(0, 10, (size,))

def __len__(self):
return len(self.data)

def __iter__(self):
for i in range(len(self)):
yield self[i]

def __getitem__(self, idx):
# During prediction, return only the input tensor
if hasattr(self, "prediction_mode") and self.prediction_mode:
return self.data[idx]
return self.data[idx], self.targets[idx]

def set_prediction_mode(self, mode=True):
self.prediction_mode = mode


def get_memory_usage():
process = psutil.Process(os.getpid())
return process.memory_info().rss / 1024 / 1024 # MB


@pytest.mark.parametrize("return_predictions", [True, False])
def test_prediction_memory_leak(tmp_path, return_predictions):
"""Test that memory usage doesn't grow during prediction when return_predictions=False."""
# Create a model and dataset
model = CustomModel()
dataset = LargeDataset()
dataset.set_prediction_mode(True) # Set prediction mode
dataloader = DataLoader(dataset, batch_size=32)

# Get initial memory usage
initial_memory = get_memory_usage()

# Run prediction
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cpu",
devices=1,
max_epochs=1,
)

trainer.predict(model, dataloaders=dataloader, return_predictions=return_predictions)

# Get final memory usage
final_memory = get_memory_usage()

# Calculate memory growth
memory_growth = final_memory - initial_memory

# When return_predictions=False, memory growth should be minimal
if not return_predictions:
assert memory_growth < 100, f"Memory growth {memory_growth}MB is too high when return_predictions=False"
else:
# When return_predictions=True, we expect some memory growth due to storing predictions
assert memory_growth > 0, "Expected memory growth when storing predictions"
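Assuming psutil is available in the test environment, the new test can be run on its own with a standard pytest invocation from the repository root, for example:

python -m pytest tests/tests_pytorch/trainer/test_memory_leak.py -v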