compiled model run in v100 GPU is slower #2008

Closed
stephen-youn opened this issue Dec 21, 2022 · 6 comments
Labels
bad-ux (Slow compilation, High memory footprint), bug (Something isn't working)

Comments

@stephen-youn commented Dec 21, 2022

🐛 Describe the bug

Hi,
I tried the BERT and ResNet examples from the tutorial https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/,
but they ran slower with "torch.compile" on a V100 under the Ubuntu environment I have (i.e., Linux GCRHYP3C148 4.15.0-193-generic #204-Ubuntu SMP).
Isn't it supposed to be faster?
Thanks

Error logs

No response

Minified repro

"""
resnet
"""

import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")

# time a single forward pass of the eager model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
model(torch.randn(1, 3, 64, 64))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")

# time a single forward pass of the compiled model (this is its first call)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
opt_model(torch.randn(1, 3, 64, 64))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")

This runs as follows, and the compiled model is roughly 74x slower, as shown below:

~/project/sandbox$ python hello_torchdynamo4.py
Using cache found in /home/styoun/.cache/torch/hub/pytorch_vision_v0.10.0
estimated_ms=223.81260681152344
estimated_ms=16573.572265625

It's similar for the following BERT example from the tutorial: it's 14.7x slower with the extra line "model = torch.compile(model)".

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
model = torch.compile(model)  # This is the only line of code that we changed
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")

# time a single forward pass of the compiled model (its first call)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
output = model(**encoded_input)
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")

@stephen-youn added the bug (Something isn't working) label Dec 21, 2022
@williamwen42 (Member)

Different torch.compile modes may result in different performance results (e.g. torch.compile(model, mode="max-autotune")).

Also, torch.compile will generally take longer on the first pass since it needs to compile, but future passes are expected to be faster than baseline.
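For example, a minimal sketch (not from the original thread; it reuses the ResNet-18 from the repro above) of trying a non-default mode and timing only calls made after the first, already-compiled one:

import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, mode="max-autotune")

x = torch.randn(1, 3, 64, 64)
opt_model(x)  # first call: compilation happens here, so it is slow
opt_model(x)  # later calls run the already-compiled code and should be faster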

@williamwen42 added the bad-ux (Slow compilation, High memory footprint) label Dec 22, 2022
@stephen-youn (Author)

I tried running it twice, but it was still slower. Is there any suggestion for debugging this (e.g., passing a particular option to compile, or enabling trace or verbose output)?

@anijain2305 (Contributor) commented Dec 30, 2022

@stephen-youn Thanks for trying out torch.compile. PyTorch 2.0 compilers are JIT compilers, i.e., they compile the model on the first iteration. In your script, you are measuring the first-iteration latency, which is why you are observing the high latency. I modified your script and am observing better numbers on an A100 GPU (the numbers are not stable, probably because we are measuring just one iteration, but the speedup is evident).

Script

import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")


# warmup
for _ in range(3):
    model(torch.randn(1,3,64,64))

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
model(torch.randn(1,3,64,64))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")


# warmup (the first call compiles the model)
for _ in range(3):
    opt_model(torch.randn(1,3,64,64))

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
opt_model(torch.randn(1,3,64,64))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")

Output

estimated_ms=1222.14990234375
estimated_ms=326.3006286621094

Please let me know if you have any other questions, and feel free to close the issue if your question is answered.
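To stabilize the numbers, a minimal sketch (not from the original thread; it assumes a CUDA-capable machine as in the thread, since the CUDA events require one) that averages over several timed iterations after warmup:

import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")
x = torch.randn(1, 3, 64, 64)

# warm up both models so compilation cost is excluded from the measurement
for _ in range(3):
    model(x)
    opt_model(x)

def bench(fn, iters=20):
    # average over several iterations to reduce run-to-run noise
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        fn(x)
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / iters

print(f"eager   : {bench(model):.2f} ms/iter")
print(f"compiled: {bench(opt_model):.2f} ms/iter")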

@stephen-youn (Author) commented Dec 30, 2022

Yes, I also modified the code similarly and got a perf gain on V100 too.
One follow-up question: what is the difference between torch.compile(model, passes={"triton-autotune": True}) and torch.compile(model, backend="inductor")?
Does one use Triton for matmul and the other not?
What is the default matmul kernel in Inductor, isn't it Triton?
It seems the default mm is set to "aten", not "triton" (link).
How can I make sure I use Triton for matmuls?

@anijain2305 (Contributor) commented Dec 30, 2022

@stephen-youn

  • backend="inductor" uses the TorchInductor backend. This is also the default backend, so torch.compile(model, passes={"triton-autotune": True}) is equivalent to torch.compile(model, backend="inductor", passes={"triton-autotune": True}).
  • The passes argument can be used to set up TorchInductor flags. The triton_autotune flag is already set to True by default. triton_autotune is not used for tuning matmul operations; it is used for tuning the fused kernels (pointwise, reduction, scatter, etc.). So all of these are exactly the same:
    • torch.compile(mod)
    • torch.compile(mod, backend="inductor")
    • torch.compile(model, passes={"triton-autotune": True})

Reading between the lines, it seems you are interested in mm operators. For those:

  • By default, Inductor falls back to the aten implementation for mm/bmm ops. We do not use Triton to generate the code for these matmul ops.
  • If you want to use Triton for matmul, you could use passes={'triton-mm': True, 'triton-bmm': True} (see the sketch below). This part is not super heavily tested, so please be gentle, and do open issues if you run into problems.
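A minimal sketch (not from the original thread) of opting into the Triton matmul path via the passes flags described above; the small nn.Sequential model is hypothetical and only for illustration, it assumes a CUDA-capable machine, and the flags follow the comment's description rather than a tested recipe:

import torch
import torch.nn as nn

# hypothetical matmul-heavy model, for illustration only
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()

# opt into Triton-generated mm/bmm kernels via the passes flags described above
opt_model = torch.compile(model, backend="inductor",
                          passes={"triton-mm": True, "triton-bmm": True})

x = torch.randn(8, 512, device="cuda")
for _ in range(3):
    opt_model(x)  # warm up so compilation cost is excluded from any later timing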

@stephen-youn (Author)

I tried "opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True})"
but it crashed.
so i opened an issue here (link)
