Description
🐛 Describe the bug
I found that running a torchvision model under the MPS backend was extremely slow compared to CPU. I ran the profiler and found that the vast majority of that time came from a small number of calls to `aten::nonzero`. Using the repro below, the `cpu` device takes ~1s to run, but switching to `mps` increases this to ~75s, most of which is spent in `aten::nonzero`. I wonder if this might be related to #122916.
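As a quick sanity check, here is a minimal sketch that isolates the op (the `[1000]` boolean input and 4 iterations mirror the Input Shapes and call counts in the profiles below; treat it as illustrative rather than a measured result from this report):

```python
import time

import torch

def time_nonzero(device: str, n_iters: int = 4) -> float:
    # a [1000] boolean mask, mirroring the Input Shapes column in the profiles
    mask = torch.randn(1000, device=device) > 0
    if device == "mps":
        torch.mps.synchronize()  # flush pending GPU work so it doesn't skew the timing
    start = time.perf_counter()
    for _ in range(n_iters):
        torch.nonzero(mask)
    if device == "mps":
        torch.mps.synchronize()
    return time.perf_counter() - start

print("cpu:", time_nonzero("cpu"))
if torch.backends.mps.is_available():
    print("mps:", time_nonzero("mps"))
```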
Repro
```python
import torch
import torchvision.models as models
from torch.profiler import ProfilerActivity, profile, record_function

device = "mps"
inputs = torch.randn(3, 224, 224).to(device)
transform = models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
detection_model = models.detection.fasterrcnn_resnet50_fpn_v2(
    weights=models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT,
).to(device)
detection_model.eval()

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        out = detection_model([transform(inputs)])

print(
    prof.key_averages(group_by_input_shape=True).table(
        sort_by="cpu_time_total", row_limit=10
    )
)
prof.export_chrome_trace(f"trace_test_{device}.json")
```
CPU profile results
```
------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls Input Shapes
------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
model_inference 2.25% 26.482ms 100.00% 1.176s 1.176s 1 []
aten::conv2d 0.00% 14.000us 38.19% 449.105ms 112.276ms 4 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], []]
aten::convolution 0.00% 30.000us 38.19% 449.091ms 112.273ms 4 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], [], [], []]
aten::_convolution 0.00% 33.000us 38.19% 449.061ms 112.265ms 4 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], [], [], [], [], [], [], []
aten::_nnpack_spatial_convolution 38.18% 448.984ms 38.19% 449.028ms 112.257ms 4 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], []]
aten::conv2d 0.00% 6.000us 9.21% 108.300ms 54.150ms 2 [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], []]
aten::convolution 0.00% 15.000us 9.21% 108.294ms 54.147ms 2 [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], [], [], []]
aten::_convolution 0.00% 13.000us 9.21% 108.279ms 54.139ms 2 [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], [], [], [], [], [], []
aten::thnn_conv2d 0.00% 4.000us 9.21% 108.266ms 54.133ms 2 [[1, 256, 200, 200], [256, 256, 3, 3], [], [256], [], []]
aten::_slow_conv2d_forward 8.96% 105.342ms 9.21% 108.262ms 54.131ms 2 [[1, 256, 200, 200], [256, 256, 3, 3], [], [256], [], []]
------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Self CPU time total: 1.176s
```
MPS profile results
```
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls Input Shapes
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
model_inference -0.00% -2260.000us 100.00% 74.706s 74.706s 1 []
aten::where -0.00% -17.000us 99.25% 74.143s 18.536s 4 [[1000]]
aten::nonzero_numpy 0.00% 38.000us 99.25% 74.143s 18.536s 4 [[1000]]
aten::nonzero 99.24% 74.138s 99.25% 74.143s 18.536s 4 [[1000]]
aten::cat 0.06% 41.956ms 0.06% 41.956ms 1.998ms 21 [[], []]
aten::sub 0.05% 34.111ms 0.05% 34.111ms 34.111ms 1 [[3, 224, 224], [3, 1, 1], []]
aten::to -0.00% -88.000us 0.04% 32.683ms 1.421ms 23 [[], [], [], [], [], []]
aten::_to_copy 0.00% 676.000us 0.04% 32.681ms 2.723ms 12 [[], [], [], [], [], [], []]
aten::copy_ 0.04% 32.618ms 0.04% 32.645ms 2.720ms 12 [[], [], []]
aten::stack 0.00% 64.000us 0.03% 23.270ms 2.327ms 10 [[], []]
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Self CPU time total: 74.706s
```
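For context on where these calls likely originate (my reading of the trace, not verified against the torchvision source at this exact version): single-argument `torch.where(condition)` is documented as equivalent to `torch.nonzero(condition, as_tuple=True)`, which matches the `aten::where` → `aten::nonzero_numpy` → `aten::nonzero` chain above. Detection postprocessing filters boxes with exactly this pattern, e.g. something like:

```python
import torch

# hypothetical illustration; the 0.05 score threshold is made up
scores = torch.rand(1000, device="mps")
# single-arg torch.where is sugar for nonzero(as_tuple=True),
# so this hits aten::nonzero under the hood
keep = torch.where(scores > 0.05)[0]
```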
P.S. I think running the `mps` repro above may have hard-crashed my laptop (it happened while I was writing this issue for the first time), but I don't have access to another machine to verify that this isn't a problem with my machine.
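In case it's useful to others hitting this, a possible mitigation sketch is to route just the slow op through the CPU (untested against the full model, and whether the transfer cost is acceptable depends on tensor size):

```python
import torch

def nonzero_via_cpu(mask: torch.Tensor) -> torch.Tensor:
    # work around the slow MPS kernel by computing nonzero on CPU,
    # then moving the resulting indices back to the original device
    return torch.nonzero(mask.cpu()).to(mask.device)
```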
Versions
```
Collecting environment information...
PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] torchvision==0.17.2
[conda] numpy 1.26.4 py312h7f4fdc5_0
[conda] numpy-base 1.26.4 py312he047099_0
[conda] pytorch 2.2.2 py3.12_0 pytorch
[conda] torchvision 0.17.2 py312_cpu pytorch
```
cc @ezyang @gchanan @zou3519 @kadeng @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen