aten::nonzero calls taking a huge amount of time when using MPS backend vs CPU #124850

Open
pytorch/vision
#9100
@theo-costain-arondite

🐛 Describe the bug

I found that running a torchvision model on the MPS backend was extremely slow compared to CPU.
I ran the profiler and found that the vast majority of that time came from a small number of calls to aten::nonzero.
With the repro below, the cpu device takes ~1s to run, but switching to mps increases this to ~75s, nearly all of which (99.25% in the profile) is spent in 4 calls to aten::nonzero.
I wonder if this might be related to #122916.

Repro

import torch
import torchvision.models as models
from torch.profiler import ProfilerActivity, profile, record_function

device = "mps"  # set to "cpu" for the baseline run

inputs = torch.randn(3, 224, 224).to(device)

weights = models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
transform = weights.transforms()

detection_model = models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights).to(device)
detection_model.eval()

# Profile a single forward pass; only CPU-side activity is recorded.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        out = detection_model([transform(inputs)])

print(
    prof.key_averages(group_by_input_shape=True).table(
        sort_by="cpu_time_total", row_limit=10
    )
)
prof.export_chrome_trace(f"trace_test_{device}.json")

CPU profile results

-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes  
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                      model_inference         2.25%      26.482ms       100.00%        1.176s        1.176s             1                                                                                []  
                         aten::conv2d         0.00%      14.000us        38.19%     449.105ms     112.276ms             4                         [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], []]  
                    aten::convolution         0.00%      30.000us        38.19%     449.091ms     112.273ms             4                 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], [], [], []]  
                   aten::_convolution         0.00%      33.000us        38.19%     449.061ms     112.265ms             4  [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], [], [], [], [], [], [], [], [], []  
    aten::_nnpack_spatial_convolution        38.18%     448.984ms        38.19%     449.028ms     112.257ms             4                                 [[1000, 256, 7, 7], [256, 256, 3, 3], [], [], []]  
                         aten::conv2d         0.00%       6.000us         9.21%     108.300ms      54.150ms             2                     [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], []]  
                    aten::convolution         0.00%      15.000us         9.21%     108.294ms      54.147ms             2             [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], [], [], []]  
                   aten::_convolution         0.00%      13.000us         9.21%     108.279ms      54.139ms             2  [[1, 256, 200, 200], [256, 256, 3, 3], [256], [], [], [], [], [], [], [], [], []  
                    aten::thnn_conv2d         0.00%       4.000us         9.21%     108.266ms      54.133ms             2                         [[1, 256, 200, 200], [256, 256, 3, 3], [], [256], [], []]  
           aten::_slow_conv2d_forward         8.96%     105.342ms         9.21%     108.262ms      54.131ms             2                         [[1, 256, 200, 200], [256, 256, 3, 3], [], [256], [], []]  
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
Self CPU time total: 1.176s

MPS profile results

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                 model_inference        -0.00%   -2260.000us       100.00%       74.706s       74.706s             1                                                                                []  
                     aten::where        -0.00%     -17.000us        99.25%       74.143s       18.536s             4                                                                          [[1000]]  
             aten::nonzero_numpy         0.00%      38.000us        99.25%       74.143s       18.536s             4                                                                          [[1000]]  
                   aten::nonzero        99.24%       74.138s        99.25%       74.143s       18.536s             4                                                                          [[1000]]  
                       aten::cat         0.06%      41.956ms         0.06%      41.956ms       1.998ms            21                                                                          [[], []]  
                       aten::sub         0.05%      34.111ms         0.05%      34.111ms      34.111ms             1                                                    [[3, 224, 224], [3, 1, 1], []]  
                        aten::to        -0.00%     -88.000us         0.04%      32.683ms       1.421ms            23                                                          [[], [], [], [], [], []]  
                  aten::_to_copy         0.00%     676.000us         0.04%      32.681ms       2.723ms            12                                                      [[], [], [], [], [], [], []]  
                     aten::copy_         0.04%      32.618ms         0.04%      32.645ms       2.720ms            12                                                                      [[], [], []]  
                     aten::stack         0.00%      64.000us         0.03%      23.270ms       2.327ms            10                                                                          [[], []]  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
Self CPU time total: 74.706s
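
The MPS table attributes the time to aten::nonzero on inputs of shape [1000]. To check whether the op itself is slow, or whether it is only where the CPU ends up waiting for previously queued GPU work, it may help to time it in isolation. The following is a minimal sketch, not part of the original repro: the helper name time_nonzero, the random boolean mask, and the repeat count are illustrative assumptions. It synchronizes the MPS device before starting the clock and falls back to CPU only when MPS is unavailable.

```python
import time

import torch


def time_nonzero(device, n=1000, repeats=5):
    """Return mean wall-clock seconds per torch.nonzero call on a length-n mask."""
    # Hypothetical input: a random boolean mask with the shape seen in the profile.
    mask = torch.rand(n, device=device) > 0.5
    if device == "mps":
        # Drain already-queued GPU work so it is not billed to nonzero.
        torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        idx = torch.nonzero(mask)
    idx.cpu()  # make sure the last result has actually reached the host
    return (time.perf_counter() - start) / repeats


devices = ["cpu"]
if torch.backends.mps.is_available():
    devices.append("mps")

timings = {d: time_nonzero(d) for d in devices}
for d, t in timings.items():
    print(f"{d}: {t * 1e3:.3f} ms per nonzero call")
```

If the isolated call is fast on mps, the ~18.5s per call in the profile is more likely time spent waiting on earlier kernels at the point where nonzero has to read its result size back; if it is slow here too, the op itself is the bottleneck.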

P.S. I think running the mps repro above might have hard crashed my laptop (it happened while I was writing this issue for the first time), but I don't have access to another machine to check whether this is an issue with my machine.

Versions

Collecting environment information...
PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] torchvision==0.17.2
[conda] numpy 1.26.4 py312h7f4fdc5_0
[conda] numpy-base 1.26.4 py312he047099_0
[conda] pytorch 2.2.2 py3.12_0 pytorch
[conda] torchvision 0.17.2 py312_cpu pytorch

cc @ezyang @gchanan @zou3519 @kadeng @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels

high priority
module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
module: mps (Related to Apple Metal Performance Shaders framework)
module: performance (Issues related to performance, either of kernel code or framework glue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
