
Pytorch Profiler crashes while using it with Pytorch Lightning modules #20779


Closed
MKaczkow opened this issue May 1, 2025 · 2 comments · May be fixed by #20864
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.3.x

Comments

@MKaczkow

MKaczkow commented May 1, 2025

Bug description

PyTorch Profiler crashes when used with pytorch-lightning. I am attempting to profile some experiments, but keep getting errors like those shown below. I've searched the forum and GitHub issues and I'm aware of the following:

  • issue (not relevant -> different cause of error, as suggested by the message)
  • issue (not relevant -> different cause of error, as suggested by the message)
  • forum post (not relevant -> the profiler runs, but the output does not show up in TensorBoard)

Suspecting, judging from the error message, that the problem is related to context management in the profiler, I've tried two ways of launching it: v1 -> distinct-context-per-stage and v2 -> single-context-for-experiment, but neither has succeeded. The remaining parts of the experiment, like the dataloaders and the model, are provided in the environment and have so far worked correctly (an example setup is listed further down in this issue, as it's quite a lot of code). The expected behaviour is, obviously, no crash, with the relevant profiling information returned / written instead.

I'd be grateful for any ideas / debugging tips 🙂

What version are you seeing the problem on?

v2.3

How to reproduce the bug

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from torch.profiler import profile, record_function, ProfilerActivity

# Define a simple SimCLR model
class SimCLRModel(pl.LightningModule):
    def __init__(self, hidden_dim=128, lr=1e-3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim),
        )
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.lr = lr

    def forward(self, x):
        h = self.encoder(x.view(x.size(0), -1))
        z = self.projection(h)
        return z

    def training_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Define a simple dataset (using MNIST for simplicity)
class ContrastiveMNIST(Dataset):
    def __init__(self, root, train=True, transform=ToTensor(), download=True):
        self.mnist = MNIST(root, train=train, transform=transform, download=download)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        img, target = self.mnist[idx]
        # Create a dummy second view for contrastive learning (same as first for simplicity)
        img_pair = img
        return img, img_pair

# --- Setup ---
# Define hyperparameters
max_epochs = 3
batch_size = 64
learning_rate = 1e-3
hidden_dimension = 128
accelerator = "gpu"  # "cpu" or "cuda"

# Create data loaders
data_dir = os.getcwd()  # Use current directory to store MNIST
train_dataset = ContrastiveMNIST(data_dir, train=True, download=True)
val_dataset = ContrastiveMNIST(data_dir, train=False, download=True)
test_dataset = ContrastiveMNIST(data_dir, train=False, download=True)

dataloader_train_simclr = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
dataloader_val_simclr = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
dataloader_test_simclr = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Initialize the model
model = SimCLRModel(hidden_dim=hidden_dimension, lr=learning_rate)


Then I tried these two options. Code snippet v1:

trainer = pl.Trainer(
    log_every_n_steps=100,
    max_epochs=max_epochs,
    devices=1,
    accelerator=accelerator,
    enable_checkpointing=False,
    num_sanity_val_steps=0,  # to avoid adding unnecessary item to validation_epoch_embedding_norms
)

###########################
# Pre-training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("pretraining-validation"):
        # perform extra 'validation' epoch to see if untrained model does anything useful
        trainer.validate(model, dataloader_val_simclr)

###########################
# Training
###########################
    with record_function("training-phase"):
        trainer.fit(
            model=model,
            train_dataloaders=dataloader_train_simclr,
            val_dataloaders=dataloader_val_simclr,
        )

###########################
# Testing
###########################
    with record_function("testing-final"):
        trainer.test(
            model,
            dataloaders=dataloader_test_simclr,
        )

Code snippet v2:

trainer = pl.Trainer(
    log_every_n_steps=100,
    max_epochs=max_epochs,
    devices=1,
    accelerator=accelerator,
    enable_checkpointing=False,
    num_sanity_val_steps=0,  # to avoid adding unnecessary item to validation_epoch_embedding_norms
)

###########################
# Pre-training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("pretraining-validation"):
        # perform extra 'validation' epoch to see if untrained model does anything useful
        trainer.validate(model, dataloader_val_simclr)

###########################
# Training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("training-phase"):
        trainer.fit(
            model=model,
            train_dataloaders=dataloader_train_simclr,
            val_dataloaders=dataloader_val_simclr,
        )

###########################
# Testing
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("testing-final"):
        trainer.test(
            model,
            dataloaders=dataloader_test_simclr,
        )

Error messages and logs

Stack traces:

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 107
     96 trainer = pl.Trainer(
     97     log_every_n_steps=100,
     98     max_epochs=max_epochs,
   (...)
    101     num_sanity_val_steps=0,  # to avoid adding unnecessary item to validation_epoch_embedding_norms
    102 )
    104 ###########################
    105 # Pre-training (Validation before training)
    106 ###########################
--> 107 with profile(
    108     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    109     record_shapes=True,
    110     profile_memory=True,
    111     with_stack=True,
    112 ) as prof:
    113     with record_function("pretraining-validation"):
    114         # perform extra 'validation' epoch to see if untrained model does anything useful
    115         trainer.validate(model, dataloader_val_simclr)

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:699, in profile.__exit__(self, exc_type, exc_val, exc_tb)
    698 def __exit__(self, exc_type, exc_val, exc_tb):
--> 699     self.stop()
    700     prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
    701     if self.execution_trace_observer:

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:715, in profile.stop(self)
    713 if self.record_steps and self.step_rec_fn:
    714     self.step_rec_fn.__exit__(None, None, None)
--> 715 self._transit_action(self.current_action, None)

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:744, in profile._transit_action(self, prev_action, current_action)
    742 if action_list:
    743     for action in action_list:
--> 744         action()

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:199, in _KinetoProfile.stop_trace(self)
    197     self.execution_trace_observer.stop()
    198 assert self.profiler is not None
--> 199 self.profiler.__exit__(None, None, None)

File d:\{repository_path}\venv\Lib\site-packages\torch\autograd\profiler.py:296, in profile.__exit__(self, exc_type, exc_val, exc_tb)
    294 if self.use_cuda:
    295     torch.cuda.synchronize()
--> 296 self.kineto_results = _disable_profiler()
    297 _run_on_profiler_stop()
    298 parsed_results = self._parse_kineto_results(self.kineto_results)

RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "..\\torch\\csrc\\autograd\\profiler_python.cpp":969, please report a bug to PyTorch. Python replay stack is empty.

Sometimes (it seems random), I get this error instead:

RuntimeError                              Traceback (most recent call last)
Cell In[28], line 208
    189 trainer = pl.Trainer(
    190     log_every_n_steps=100,
    191     max_epochs=max_epochs,
   (...)
    202     ],
    203 )
    205 ###########################
    206 # Pre-training
    207 ###########################
--> 208 with profile(
    209     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    210     record_shapes=True,
    211     profile_memory=True,
    212     with_stack=True,
    213 ) as prof:
    214     with record_function("pretraining-validation"):
    215         # perform extra 'validation' epoch to see if untrained model does anything useful
    216         trainer.validate(model, dataloader_val_simclr)

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:695, in profile.__enter__(self)
    694 def __enter__(self):
--> 695     self.start()
    696     return self

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:705, in profile.start(self)
    704 def start(self):
--> 705     self._transit_action(ProfilerAction.NONE, self.current_action)
    706     if self.record_steps:
    707         self.step_rec_fn = prof.record_function(
    708             "ProfilerStep#" + str(self.step_num)
    709         )

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:744, in profile._transit_action(self, prev_action, current_action)
    742 if action_list:
    743     for action in action_list:
--> 744         action()

File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:155, in _KinetoProfile.prepare_trace(self)
    141 def prepare_trace(self):
    142     self.profiler = prof.profile(
    143         use_cuda=(ProfilerActivity.CUDA in self.activities),
    144         use_cpu=(ProfilerActivity.CPU in self.activities),
   (...)
    153         experimental_config=self.experimental_config,
    154     )
--> 155     self.profiler._prepare_trace()

File d:\{repository_path}\venv\Lib\site-packages\torch\autograd\profiler.py:284, in profile._prepare_trace(self)
    282 def _prepare_trace(self):
    283     self.entered = True
--> 284     _prepare_profiler(self.config(), self.kineto_activities)

RuntimeError: Can't disable Kineto profiler when it's not running

Environment

PyTorch version: 2.3.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Education (10.0.19045 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.3 (tags/v3.12.3:f6650f9, Apr 9 2024, 14:05:25) [MSC v.1938 64 bit (AMD64)]
(64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 7 5800X 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3801
MaxClockSpeed: 3801
L2CacheSize: 4096
L2CacheSpeed: None
Revision: 8448

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==2.3.3
[pip3] torch==2.3.1+cu118
[pip3] torch-tb-profiler==0.4.3
[pip3] torchmetrics==1.4.0.post0
[pip3] torchvision==0.18.1+cu118
[conda] Could not collect

More info

No response

@MKaczkow MKaczkow added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on May 1, 2025
@KAVYANSHTYAGI
Contributor

From what I can think of, you can try the following:

Use Lightning's built-in profiler

This is the only robust, supported way. Lightning integrates the profiler correctly within its training loop.

trainer = pl.Trainer(
    ...,
    profiler="pytorch",  # or profiler="advanced"
)

This triggers profiling at the right spots, per-batch, and avoids context mismatches.
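
If the goal is also to get a trace that torch-tb-profiler / TensorBoard can read, a configured profiler object can be passed instead of the string shortcut. This is a minimal sketch, assuming PyTorchProfiler from pytorch_lightning.profilers forwards its extra keyword arguments to torch.profiler.profile; the dirpath and filename values are placeholders, and max_epochs / accelerator come from the setup code above:

import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler

# Configured profiler instance instead of profiler="pytorch";
# the extra keyword arguments are passed through to torch.profiler.profile.
profiler = PyTorchProfiler(
    dirpath="profiler_logs",    # placeholder output directory
    filename="simclr_trace",    # placeholder trace file name
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=1,
    accelerator=accelerator,
    profiler=profiler,
)

Lightning then owns starting and stopping the profiler, so there is no manually managed profiler context wrapped around trainer.fit / validate / test.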

Manual Profiling: Only Profile Inside a Model Step

If you need to profile something custom, do it INSIDE a LightningModule method (like training_step), not outside the trainer.

def training_step(self, batch, batch_idx):
    x, _ = batch
    with torch.profiler.profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        # forward & loss logic (dummy loss, as in the model above)
        z = self(x)
        loss = torch.mean(z)
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    return loss
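
If the end goal is viewing the trace in TensorBoard (torch-tb-profiler is listed in the environment above), the manual context can also export a trace via on_trace_ready instead of printing a table. This is a minimal sketch, assuming torch.profiler.tensorboard_trace_handler; "tb_logs/profiler" is a placeholder directory:

from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler

# Export a TensorBoard-compatible trace into the placeholder directory
# "tb_logs/profiler"; point the TensorBoard profiler plugin at it afterwards.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
    on_trace_ready=tensorboard_trace_handler("tb_logs/profiler"),
) as prof:
    with record_function("profiled-region"):
        pass  # ... the forward & loss logic being profiled ...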

@MKaczkow
Copy link
Author

MKaczkow commented Jun 2, 2025

OK, good to know, thanks.

Since this is no longer relevant for me, I'm closing the issue; perhaps someone will find it useful later.

@MKaczkow MKaczkow closed this as completed Jun 2, 2025