nvFuser using more memory than inductor for HF CausalLMLoss #1654
Just to add: interestingly enough, if cross entropy is claimed by the apex executor, the memory footprint remains the same. It is therefore important for nvFuser to be able to make a single fusion containing both the reshapes and the cross entropy.

nvFuser + apex trace, peak memory 6228.60 MB:

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(l_logits_, l_labels_):
  # l_logits_: "cuda:0 bf16[1, 4096, 152064]"
  # l_labels_: "cuda:0 i64[1, 4096]"
  [shift_logits_1, shift_labels_1] = nvFusion0(l_logits_, l_labels_)
    # logits = prims.convert_element_type(l_logits_, dtypes.float32)  # logits: "cuda:0 f32[1, 4096, 152064]"
    # getitem = prims.slice_prim(logits, [0, 0, 0], [1, 4095, 152064], [1, 1, 1])  # getitem: "cuda:0 f32[1, 4095, 152064]"
    # shift_logits = prims.stride_order(getitem, (2, 1, 0))  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
    # getitem_1 = prims.slice_prim(l_labels_, [0, 1], [1, 4096], [1, 1])  # getitem_1: "cuda:0 i64[1, 4095]"
    # shift_labels = prims.stride_order(getitem_1, (1, 0))  # shift_labels: "cuda:0 i64[1, 4095]"
    # shift_logits_1 = prims.reshape(shift_logits, (4095, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
    # shift_labels_1 = prims.reshape(shift_labels, (4095,))  # shift_labels_1: "cuda:0 i64[4095]"
  # <eval_with_key>.4:14: loss = torch.nn.functional.cross_entropy(shift_logits_1, shift_labels_2, ignore_index = -100, reduction = 'mean'); shift_logits_1 = shift_labels_2 = None
  (loss, _) = apex_cross_entropy(shift_logits_1, shift_labels_1, 'mean', 0.0)
  return {'output': (loss,), 'flat_args': [l_logits_, l_labels_], 'flat_output': (loss,)}, ((shift_labels_1, shift_logits_1), ())
```
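For reference, a minimal sketch of how the apex variant above can be produced, assuming apex with its fused cross-entropy (xentropy) extension is installed and using Thunder's string-based executor names; the loss function here is a simplified stand-in, not the exact code from the repro:

```python
import torch
import thunder

def loss_fn(logits, labels):
    # simplified stand-in for the HF causal-LM loss traced above
    return torch.nn.functional.cross_entropy(
        logits.float().view(-1, logits.size(-1)), labels.view(-1)
    )

# Listing "apex" before "nvfuser" lets the apex executor claim cross_entropy,
# while nvFuser still fuses the cast/slice/reshape prologue into nvFusion0.
jfn = thunder.jit(loss_fn, executors=["apex", "nvfuser"])
```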
Adding Inductor's generated code here for debugging:

```python
# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# Topologically Sorted Source Nodes: [labels_1], Original ATen: [aten.constant_pad_nd]
# Source node to ATen node mapping:
# labels_1 => constant_pad_nd
# Graph fragment:
# %constant_pad_nd : [num_users=2] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%primals_2, [0, 1], -100.0), kwargs = {})
triton_poi_fused_constant_pad_nd_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8192],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4097
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 4096, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0), tmp2 & xmask, other=-100)
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten._log_softmax]
# Source node to ATen node mapping:
# loss => amax, exp, log, sub, sum_1
# Graph fragment:
# %amax : [num_users=2] = call_function[target=torch.ops.aten.amax.default](args = (%view, [1], True), kwargs = {})
# %sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %amax), kwargs = {})
# %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
# %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
# %log : [num_users=2] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
triton_red_fused__log_softmax_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[4096, 262144],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_1', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 152064
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (152064*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = triton_helpers.maximum(_tmp3, tmp2)
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
    tmp3 = triton_helpers.max2(_tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp3, None)
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp5 = tl.load(in_ptr0 + (r1 + (152064*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tmp6 - tmp3
        tmp8 = tl_math.exp(tmp7)
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp12 = tl_math.log(tmp10)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward]
# Source node to ATen node mapping:
# loss => convert_element_type_1, div, full_default_1, ne, neg, sum_2, sum_3, where_1
# Graph fragment:
# %ne : [num_users=3] = call_function[target=torch.ops.aten.ne.Scalar](args = (%view_1, -100), kwargs = {})
# %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
# %full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %neg, %full_default_1), kwargs = {})
# %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%ne,), kwargs = {})
# %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.float32), kwargs = {})
# %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {})
# %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_3, %convert_element_type_1), kwargs = {})
triton_red_fused_nll_loss_forward_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*bf16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_forward_2', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.int64)
    _tmp24 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (1 + r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp16 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp18 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tmp2.to(tl.int64)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
        tmp7 = tl.full([1, 1], 0, tl.int64)
        tmp8 = tl.where(tmp2, tmp0, tmp7)
        tmp9 = tl.full([XBLOCK, RBLOCK], 152064, tl.int32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp8 < 0
        tmp12 = tl.where(tmp11, tmp10, tmp8)
        tl.device_assert(((0 <= tmp12) & (tmp12 < 152064)) | ~(rmask), "index out of bounds: 0 <= tmp12 < 152064")
        tmp14 = tl.load(in_ptr1 + (tmp12 + (152064*r0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp17 = tmp15 - tmp16
        tmp19 = tmp17 - tmp18
        tmp20 = -tmp19
        tmp21 = 0.0
        tmp22 = tl.where(tmp2, tmp20, tmp21)
        tmp23 = tl.broadcast_to(tmp22, [XBLOCK, RBLOCK])
        tmp25 = _tmp24 + tmp23
        _tmp24 = tl.where(rmask, tmp25, _tmp24)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tmp24 = tl.sum(_tmp24, 1)[:, None]
    tmp26 = tmp5.to(tl.float32)
    tmp27 = tmp24 / tmp26
    tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (1, 4096, 152064), (622854144, 152064, 1))
    assert_size_stride(primals_2, (1, 4096), (4096, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 4097), (4112, 1), torch.int64)
        # Topologically Sorted Source Nodes: [labels_1], Original ATen: [aten.constant_pad_nd]
        stream0 = get_raw_stream(0)
        triton_poi_fused_constant_pad_nd_0.run(primals_2, buf0, 4097, grid=grid(4097), stream=stream0)
        del primals_2
        buf1 = empty_strided_cuda((4096, 1), (1, 1), torch.float32)
        buf2 = empty_strided_cuda((4096, 1), (1, 4096), torch.float32)
        buf3 = reinterpret_tensor(buf2, (4096, 1), (1, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [loss], Original ATen: [aten._log_softmax]
        triton_red_fused__log_softmax_1.run(buf3, primals_1, buf1, 4096, 152064, grid=grid(4096), stream=stream0)
        buf6 = empty_strided_cuda((), (), torch.float32)
        buf5 = empty_strided_cuda((), (), torch.float32)
        buf7 = buf6; del buf6  # reuse
        # Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward]
        triton_red_fused_nll_loss_forward_2.run(buf7, buf0, primals_1, buf1, buf3, buf5, 1, 4096, grid=grid(1), stream=stream0)
    return (buf7, primals_1, buf0, buf1, buf3, buf5, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((1, 4096, 152064), (622854144, 152064, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_2 = rand_strided((1, 4096), (4096, 1), device='cuda:0', dtype=torch.int64)
    fn = lambda: call([primals_1, primals_2])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
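Note what `call` returns for backward (its last line): only the bf16 logits (`primals_1`), the padded labels (`buf0`), the per-row max and log-sum-exp statistics (`buf1`, `buf3`), and two scalars (the valid-token count `buf5` and the loss). No float32 copy of the full logits is ever materialized, which plausibly explains why Inductor's peak stays near the size of the input.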
Will close when this PR merges: [WIP] Add cross_entropy to torchcompile_cat executor #1655
🐛 Bug
It looks like nvFuser uses 5x more memory than Inductor (Triton) and slightly more than eager to compute the forward cross_entropy loss. In the current state, with `take_along_axis` commented out (lightning-thunder `thunder/executors/nvfuserex_impl.py`, lines 1326 to 1335 at cb5a230), nvFuser is not able to capture the full function, and re-enabling the op does not solve the issue either: it results in errors when executing the fusion (#3781). Ideally we would want `take_along_axis` to be captured by nvFuser so that it fully captures the loss function's computation graph.

- Inductor's fusion, peak mem 1245.77 MB
- nvFuser fusion without `take_along_axis`, peak mem 6228.60 MB
- nvFuser fusion with `take_along_axis`, peak mem 6228.60 MB
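For what it's worth, the numbers are consistent with nvFuser keeping the bf16 input plus two float32 copies of the logits alive at once: the bf16 input is 1 × 4096 × 152064 × 2 B ≈ 1245.7 MB, a full float32 cast is ≈ 2491.4 MB, and the shifted float32 copy saved for backward (see the trace above) is 4095 × 152064 × 4 B ≈ 2490.8 MB, together ≈ 6228 MB, essentially the reported peak. Inductor's 1245.77 MB is just the bf16 input itself, since its kernels only add per-row float32 statistics (two 4096-element buffers).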
To Reproduce
Get HF transformers with `pip install transformers==4.47.1` and then run the following; a sketch of the script is given below. To test capturing `take_along_axis`, please uncomment the lines shown above in `thunder/executors/nvfuserex_impl.py`.
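A minimal sketch of the reproduction, assuming the shapes from the traces above (batch 1, sequence length 4096, vocab size 152064) and an inline reimplementation of HF's causal-LM loss (upcast to fp32, shift by one token, mean cross entropy); the exact details of the original script may differ:

```python
import torch
import thunder

def causal_lm_loss(logits, labels):
    # mirrors HF CausalLMLoss: fp32 upcast, shift, cross entropy
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="mean",
    )

logits = torch.randn(1, 4096, 152064, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, 152064, (1, 4096), device="cuda")

jfn = thunder.jit(causal_lm_loss)  # nvFuser is in the default executor list
torch.cuda.reset_peak_memory_stats()
loss = jfn(logits, labels)
torch.cuda.synchronize()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")
```

For the Inductor comparison, the same function can be wrapped in `torch.compile(causal_lm_loss)` instead.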
Expected behavior
Similar performance and memory consumption as Inductor.
Environment
P.S. This issue is connected to #1552; a temporary workaround would be to let Inductor capture the cross_entropy instead of nvFuser.
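Reusing `causal_lm_loss` from the sketch above, a hypothetical version of that workaround on the Thunder side; the executor name and import path follow the linked PR #1655 and are assumptions:

```python
import thunder
# Hypothetical import path for the torchcompile_cat executor from #1655.
from thunder.executors.torch_compile import torch_compile_cat_ex

# Place the torch.compile-backed executor before nvFuser so that it claims
# cross_entropy (and the neighboring reshapes) instead of nvFuser.
jfn = thunder.jit(causal_lm_loss, executors=[torch_compile_cat_ex, "nvfuser"])
```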
cc @tfogal