PyTorch Profiler crashes when used with pytorch-lightning. I am trying to profile some experiments, but keep getting errors like the ones shown below. I've searched the forum and GitHub issues and I'm aware of the following:
issue (not relevant -> different cause of error, as suggested by the message)
issue (not relevant -> different cause of error, as suggested by the message)
forum post (not relevant -> profiler runs, but output is not in TensorBoard)
Judging from the error message, I suspect the problem is related to context management in the profiler, so I've tried 2 ways of launching it, v1 -> distinct-context-per-stage and v2 -> single-context-for-experiment, but neither has succeeded. The remaining parts of the experiment, like the dataloaders, model, etc., are provided in the environment and have worked correctly so far (an example setup is listed at the very end of this issue, as it's quite a lot of code). The expected behaviour is obviously no crash, with the relevant profiling information returned / written instead.
Will be grateful for any ideas / debugging tips 🙂
What version are you seeing the problem on?
v2.3
How to reproduce the bug
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
import pytorch_lightning as pl
from torch.profiler import profile, record_function, ProfilerActivity


# Define a simple SimCLR model
class SimCLRModel(pl.LightningModule):
    def __init__(self, hidden_dim=128, lr=1e-3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim),
        )
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.lr = lr

    def forward(self, x):
        h = self.encoder(x.view(x.size(0), -1))
        z = self.projection(h)
        return z

    def training_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, _ = batch
        z = self(x)
        # Dummy loss for demonstration purposes
        loss = torch.mean(z)
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


# Define a simple dataset (using MNIST for simplicity)
class ContrastiveMNIST(Dataset):
    def __init__(self, root, train=True, transform=ToTensor(), download=True):
        self.mnist = MNIST(root, train=train, transform=transform, download=download)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        img, target = self.mnist[idx]
        # Create a dummy second view for contrastive learning (same as first for simplicity)
        img_pair = img
        return img, img_pair


# --- Setup ---
# Define hyperparameters
max_epochs = 3
batch_size = 64
learning_rate = 1e-3
hidden_dimension = 128
accelerator = "gpu"  # "cpu" or "cuda"

# Create data loaders
data_dir = os.getcwd()  # Use current directory to store MNIST
train_dataset = ContrastiveMNIST(data_dir, train=True, download=True)
val_dataset = ContrastiveMNIST(data_dir, train=False, download=True)
test_dataset = ContrastiveMNIST(data_dir, train=False, download=True)
dataloader_train_simclr = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
dataloader_val_simclr = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
dataloader_test_simclr = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Initialize the model
model = SimCLRModel(hidden_dim=hidden_dimension, lr=learning_rate)
And then I've tried these 2 options:
Code snippet v1:
trainer = pl.Trainer(
    log_every_n_steps=100,
    max_epochs=max_epochs,
    devices=1,
    accelerator=accelerator,
    enable_checkpointing=False,
    num_sanity_val_steps=0,  # to avoid adding unnecessary item to validation_epoch_embedding_norms
)

###########################
# Pre-training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("pretraining-validation"):
        # perform extra 'validation' epoch to see if untrained model does anything useful
        trainer.validate(model, dataloader_val_simclr)

    ###########################
    # Training
    ###########################
    with record_function("training-phase"):
        trainer.fit(
            model=model,
            train_dataloaders=dataloader_train_simclr,
            val_dataloaders=dataloader_val_simclr,
        )

    ###########################
    # Testing
    ###########################
    with record_function("testing-final"):
        trainer.test(
            model,
            dataloaders=dataloader_test_simclr,
        )
Code snippet v2:
trainer = pl.Trainer(
    log_every_n_steps=100,
    max_epochs=max_epochs,
    devices=1,
    accelerator=accelerator,
    enable_checkpointing=False,
    num_sanity_val_steps=0,  # to avoid adding unnecessary item to validation_epoch_embedding_norms
)

###########################
# Pre-training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("pretraining-validation"):
        # perform extra 'validation' epoch to see if untrained model does anything useful
        trainer.validate(model, dataloader_val_simclr)

###########################
# Training
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("training-phase"):
        trainer.fit(
            model=model,
            train_dataloaders=dataloader_train_simclr,
            val_dataloaders=dataloader_val_simclr,
        )

###########################
# Testing
###########################
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("testing-final"):
        trainer.test(
            model,
            dataloaders=dataloader_test_simclr,
        )
Error messages and logs
Stack traces:
RuntimeError Traceback (most recent call last)
Cell In[4], line 107
96 trainer = pl.Trainer(
97 log_every_n_steps=100,
98 max_epochs=max_epochs,
(...)
101 num_sanity_val_steps=0, # to avoid adding unnecessary item to validation_epoch_embedding_norms
102 )
104 ###########################
105 # Pre-training (Validation before training)
106 ###########################
--> 107 with profile(
108 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
109 record_shapes=True,
110 profile_memory=True,
111 with_stack=True,
112 ) as prof:
113 with record_function("pretraining-validation"):
114 # perform extra 'validation' epoch to see if untrained model does anything useful
115 trainer.validate(model, dataloader_val_simclr)
File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:699, in profile.__exit__(self, exc_type, exc_val, exc_tb)
698 def __exit__(self, exc_type, exc_val, exc_tb):
--> 699 self.stop()
700 prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
701 if self.execution_trace_observer:
File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:715, in profile.stop(self)
713 if self.record_steps and self.step_rec_fn:
714 self.step_rec_fn.__exit__(None, None, None)
--> 715 self._transit_action(self.current_action, None)
File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:744, in profile._transit_action(self, prev_action, current_action)
742 if action_list:
743 for action in action_list:
--> 744 action()
File d:\{repository_path}\venv\Lib\site-packages\torch\profiler\profiler.py:199, in _KinetoProfile.stop_trace(self)
197 self.execution_trace_observer.stop()
198 assert self.profiler is not None
--> 199 self.profiler.__exit__(None, None, None)
File d:\{repository_path}\venv\Lib\site-packages\torch\autograd\profiler.py:296, in profile.__exit__(self, exc_type, exc_val, exc_tb)
294 if self.use_cuda:
295 torch.cuda.synchronize()
--> 296 self.kineto_results = _disable_profiler()
297 _run_on_profiler_stop()
298 parsed_results = self._parse_kineto_results(self.kineto_results)
RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "..\\torch\\csrc\\autograd\\profiler_python.cpp":969, please report a bug to PyTorch. Python replay stack is empty.
Sometimes (seemingly at random), I get this error instead:
RuntimeError Traceback (most recent call last)
Cell In[28], line 208
189 trainer = pl.Trainer(
190 log_every_n_steps=100,
191 max_epochs=max_epochs,
(...)
202 ],
203 )
205 ###########################
206 # Pre-training
207 ###########################
--> 208 with profile(
209 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
210 record_shapes=True,
211 profile_memory=True,
212 with_stack=True,
213 ) as prof:
214 with record_function("pretraining-validation"):
215 # perform extra 'validation' epoch to see if untrained model does anything useful
216 trainer.validate(model, dataloader_val_simclr)
File d:\__repos\masters_bacter_private\venv\Lib\site-packages\torch\profiler\profiler.py:695, in profile.__enter__(self)
694 def __enter__(self):
--> 695 self.start()
696 return self
File d:\__repos\masters_bacter_private\venv\Lib\site-packages\torch\profiler\profiler.py:705, in profile.start(self)
704 def start(self):
--> 705 self._transit_action(ProfilerAction.NONE, self.current_action)
706 if self.record_steps:
707 self.step_rec_fn = prof.record_function(
708 "ProfilerStep#" + str(self.step_num)
709 )
File d:\__repos\masters_bacter_private\venv\Lib\site-packages\torch\profiler\profiler.py:744, in profile._transit_action(self, prev_action, current_action)
742 if action_list:
743 for action in action_list:
--> 744 action()
File d:\__repos\masters_bacter_private\venv\Lib\site-packages\torch\profiler\profiler.py:155, in _KinetoProfile.prepare_trace(self)
141 def prepare_trace(self):
142 self.profiler = prof.profile(
143 use_cuda=(ProfilerActivity.CUDA in self.activities),
144 use_cpu=(ProfilerActivity.CPU in self.activities),
(...)
153 experimental_config=self.experimental_config,
154 )
--> 155 self.profiler._prepare_trace()
File d:\__repos\masters_bacter_private\venv\Lib\site-packages\torch\autograd\profiler.py:284, in profile._prepare_trace(self)
282 def _prepare_trace(self):
283 self.entered = True
--> 284 _prepare_profiler(self.config(), self.kineto_activities)
RuntimeError: Can't disable Kineto profiler when it's not running
Environment
PyTorch version: 2.3.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Education (10.0.19045 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.3 (tags/v3.12.3:f6650f9, Apr 9 2024, 14:05:25) [MSC v.1938 64 bit (AMD64)]
(64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
From what I can think of, you can try the following:
Use Lightning's Built-In Profiler
This is the only robust, supported way. Lightning integrates the profiler correctly within its training loop.
trainer = pl.Trainer(
    ...,
    profiler="pytorch",  # or profiler="advanced"
)
This triggers profiling at the right spots, per-batch, and avoids context mismatches.
Manual Profiling: Only Profile Inside a Model Step
If you need to profile something custom, do it INSIDE a LightningModule method (like training_step), not outside the trainer.
def training_step(self, batch, batch_idx):
    with torch.profiler.profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        # Forward pass and loss (here: the dummy loss from the model in the issue)
        x, _ = batch
        loss = torch.mean(self(x))
    # Print the table only after the profiler context has closed
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    return loss
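Opening a fresh profiler and printing a table on every step can get expensive. As a rough sketch of the same idea outside Lightning, `torch.profiler.schedule` can limit recording to a few steps of a plain training loop. The tiny model, the random batches, and the `train_step` label below are stand-ins invented for this example, not the setup from the issue, and only CPU activity is recorded so that it runs without a GPU.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function, schedule

# Tiny stand-in model and data (hypothetical; not the SimCLR model from the issue).
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batches = [torch.randn(16, 32) for _ in range(6)]

# Skip 1 step, warm up for 1 step, then record 2 steps, for a single cycle.
prof_schedule = schedule(wait=1, warmup=1, active=2, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when running on a GPU
    schedule=prof_schedule,
    record_shapes=True,
) as prof:
    for x in batches:
        with record_function("train_step"):
            loss = model(x).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        prof.step()  # tell the profiler that one step has finished

# Aggregate the recorded events once, after the loop.
averages = prof.key_averages()
print(averages.table(sort_by="self_cpu_time_total", row_limit=10))
```

The profiler context is entered and exited exactly once here, and `prof.step()` drives the schedule, which keeps the overhead confined to the recorded steps.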
Environment
CPU:
Name: AMD Ryzen 7 5800X 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3801
MaxClockSpeed: 3801
L2CacheSize: 4096
L2CacheSpeed: None
Revision: 8448
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==2.3.3
[pip3] torch==2.3.1+cu118
[pip3] torch-tb-profiler==0.4.3
[pip3] torchmetrics==1.4.0.post0
[pip3] torchvision==0.18.1+cu118
[conda] Could not collect
More info
No response