
thunder.jit has a relatively high CPU overhead when processing small graphs with small inputs. #1657

Open
kiya00 opened this issue Jan 17, 2025 · 4 comments

Comments

@kiya00
Collaborator

kiya00 commented Jan 17, 2025

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

thunder.jit has a relatively high CPU overhead when processing small graphs with small inputs.

To Reproduce

import thunder
import torch
import time

class DynamoModule(torch.nn.Module):
    def forward(self, l_x_ : torch.Tensor):
        x = torch.sin(l_x_);  l_x_ = None
        sum_1 = x.sum()
        gt = sum_1 > 0;  sum_1 = None
        return (x, gt)

inputs = [
    torch.testing.make_tensor((319999,), dtype=torch.int64,  device='cuda:0', requires_grad=False, low=3, high=9,).as_strided((400, 400), (800, 2)),
]
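# Note: inputs[0] is a non-contiguous (400, 400) int64 view (strides (800, 2))
# over a flat 319999-element buffer.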

mod = DynamoModule()

jitted = thunder.jit(mod) # ,nv_store_fusion_inputs=True

def measure(f, name):
    torch.cuda.synchronize()
    # warmup
    for i in range(100):
        f(*inputs)

    torch.cuda.synchronize()
    st = time.time()
    # torch.cuda.cudart().cudaProfilerStart()
    for i in range(10000):
        # torch.cuda.nvtx.range_push("fforward")
        f(*inputs)
        # torch.cuda.nvtx.range_pop()
    torch.cuda.synchronize()
    # torch.cuda.cudart().cudaProfilerStop()
    print(f"{name}:", (time.time()-st)/10000*1000*1000)


measure(jitted, "thunder")
measure(mod, "eager")

# trc = thunder.last_traces(jitted)[-1]
# # from thunder.examine import get_nvfuser_repro
# # print(get_nvfuser_repro(trc, "nvFusion0"))
# print(trc)

### script obtained from `get_nvfuser_repro`
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.7.0a0+gitc3b2849
# cuda version: 12.8
# nvfuser version: 0.2.24+git9ce2112
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[400, 400], contiguity=[True, False], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.cast(T0, dtype=DataType.Float)
    T2 = fd.ops.sin(T1)
    T3 = fd.ops.sum(T2, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    S4 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T5 = fd.ops.gt(T3, S4)
    fd.add_output(T2)
    fd.add_output(T5)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# inputs = [
#     torch.randint(0, 10, (319999,), dtype=torch.int64, device='cuda:0').as_strided((400, 400), (800, 2)),
# ]

for i in range(10):
    fd.execute(inputs)

torch.cuda.synchronize()
st = time.time()
for i in range(10000):
    fd.execute(inputs)
torch.cuda.synchronize()
print("nvfuser api:", (time.time()-st)/10000*1000*1000)

Output (average time per call, in microseconds):

thunder: 81.38272762298584
eager: 12.05604076385498
nvfuser api: 17.391753196716305

Note that the nvFuser fusion above was obtained from get_nvfuser_repro, so it should be identical to what Thunder executes through its nvfuser backend. Even so, the thunder.jit call takes significantly longer (~81 µs) than running the same fusion directly through the nvFuser API (~17 µs).
The trace produced by thunder.jit:

@torch.no_grad()
@no_autocast
def computation(l_x_):
  # l_x_: "cuda:0 i64[400, 400]"
  [x, gt] = nvFusion0(l_x_)
    # t0 = prims.convert_element_type(l_x_, dtypes.float32_)  # t0: "cuda:0 f32[400, 400]"
    # x = prims.sin(t0)  # x: "cuda:0 f32[400, 400]"
    # sum_1 = prims.sum(x, (0, 1))  # sum_1: "cuda:0 f32[]"
    # gt = prims.gt(sum_1, 0.0)  # gt: "cuda:0 b8[]"
  return (x, gt)

I believe the difference is due to CPU overhead in thunder.jit. Since the kernel itself finishes very quickly in this case, this is likely not an issue for real models. Please feel free to close this if you don't consider it a problem.
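For reference, one rough way to see where the host-side time goes is to wrap the jitted call in the PyTorch profiler. This is only a sketch (it reuses jitted and inputs from the reproducer above), not something required to reproduce the issue:

# Sketch: break down CPU time inside the thunder.jit call.
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        jitted(*inputs)

# The table shows how much wall-clock time is spent on the Python/host side
# (prologue, dispatch) versus the actual kernel launches.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))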

cc: @mruberry

@mruberry
Collaborator

Thanks for filing this issue, @kiya00! In the future maybe we can just use the entire original reproducer script with the pytest benchmarking so you don't have to edit it?

@kiya00
Collaborator Author

kiya00 commented Jan 20, 2025

In the future maybe we can just use the entire original reproducer script with the pytest benchmarking so you don't have to edit it?

Do you mean using the reproducer script to get the performance numbers without editing it? Yes, we can do that now; roughly, it could look like the sketch below.
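
A minimal sketch of what that could look like with pytest-benchmark (the file name and test are hypothetical, and CUDA synchronization is not handled here, so this mostly measures host-side time):

# test_small_graph_overhead.py  (hypothetical file name, for illustration only)
# Run with: pytest test_small_graph_overhead.py --benchmark-only
import pytest
import torch
import thunder


class DynamoModule(torch.nn.Module):
    def forward(self, l_x_: torch.Tensor):
        x = torch.sin(l_x_)
        return (x, x.sum() > 0)


def make_inputs():
    return [
        torch.testing.make_tensor(
            (319999,), dtype=torch.int64, device="cuda:0",
            requires_grad=False, low=3, high=9,
        ).as_strided((400, 400), (800, 2)),
    ]


@pytest.mark.parametrize("executor", ["thunder", "eager"])
def test_small_graph_forward(benchmark, executor):
    mod = DynamoModule()
    fn = thunder.jit(mod) if executor == "thunder" else mod
    inputs = make_inputs()
    # pytest-benchmark handles warmup and repetition itself.
    benchmark(fn, *inputs)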

@kiya00
Collaborator Author

kiya00 commented Jan 20, 2025

Since it's a known issue that Thunder has relatively noticeable overhead when GPU kernels finish very fast (in this case the graph is small, only 3 ops, and the inputs are small), maybe we can close this issue for now? WDYT @mruberry?

@mruberry
Collaborator

Since it's a known issue that Thunder has relatively noticeable overhead when GPU kernels finish very fast (in this case the graph is small, only 3 ops, and the inputs are small), maybe we can close this issue for now? WDYT @mruberry?

I think we can leave this open unless it's a duplicate of another issue.
