nvFuser using more memory than inductor for HF CausalLMLoss #1654

Open

riccardofelluga opened this issue Jan 17, 2025 · 3 comments · May be fixed by #1655

riccardofelluga (Collaborator) commented Jan 17, 2025

🐛 Bug

It looks like nvFuser uses 5x more memory than inductor (Triton) and slightly more than eager to compute the forward cross_entropy loss. In the current state, with take_along_axis commented out in thunder/executors/nvfuserex_impl.py, nvFuser cannot capture the full function:

# TAKE_ALONG_AXIS is currently disabled
# There was an nvFuser bug that prevented this which is now fixed; we should
# investigate re-enabling take_along_axis.
# # TODO Check that the nvFuser version is >= 0.0.10 when this operator was added
# def take_along_axis(a: TensorProxy, /, index: TensorProxy, dim: int, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
# nv_a = getnv(a, fd, lc_to_nv_map)
# nv_index = getnv(index, fd, lc_to_nv_map)
# return fd.ops.take_along_axis(nv_a, nv_index, dim)
# register_supported(PrimIDs.TAKE_ALONG_AXIS, take_along_axis, _take_check)

Re-enabling the op does not solve the issue either; it results in errors when executing the fusion (see #3781). Ideally we would want take_along_axis to be captured by nvFuser so that it captures the full computation graph of the loss function.
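
For context, here is a rough back-of-the-envelope estimate of where the extra memory could be going, based on the tensor shapes in the traces below (a sketch; the exact accounting depends on which intermediates each backend materializes):

# Rough size estimate of the large buffers, in MB (1e6 bytes), to match
# torch.cuda.max_memory_allocated() / 1e6 as used in the repro below.
bf16_bytes, f32_bytes = 2, 4

logits_bf16 = 1 * 4096 * 152064 * bf16_bytes / 1e6  # input logits, ~1245.7 MB
big_f32     = 1 * 4095 * 152064 * f32_bytes / 1e6   # one full-size f32 buffer, ~2490.8 MB

# Inductor's peak (~1245.77 MB) is roughly just the bf16 input: the f32 upcast,
# log_softmax, and nll_loss are computed row-wise inside fused kernels, so the
# full-size f32 tensors are never materialized.
print("bf16 input only:        ", logits_bf16)

# The nvFuser traces materialize full-size f32 intermediates (e.g. t20 and the
# log-softmax temporaries), which is consistent with the observed ~6228.6 MB:
# roughly the bf16 input plus two f32 [4095, 152064] buffers.
print("bf16 input + 2 f32 bufs:", logits_bf16 + 2 * big_f32)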

Inductor's fusion, peak mem 1245.77 MB
# Constructed by Unwrap the actual return value --- 
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_logits_, l_labels_):
  # l_logits_: "cuda:0 bf16[1, 4096, 152064]"
  # l_labels_: "cuda:0 i64[1, 4096]"
  [loss] = TorchCompile0(l_logits_, l_labels_)
    # logits = prims.convert_element_type(l_logits_, dtypes.float32)  # logits: "cuda:0 f32[1, 4096, 152064]"
    # getitem = ltorch.getitem(logits, (..., slice(None,-1,None), slice(None,None,None)))  # getitem: "cuda:0 f32[1, 4095, 152064]"
      # getitem = prims.slice_prim(logits, [0, 0, 0], [1, 4095, 152064], [1, 1, 1])  # getitem: "cuda:0 f32[1, 4095, 152064]"
    # shift_logits = ltorch.contiguous(getitem, memory_format=_torch_memory_format_0)  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
      # shift_logits = prims.stride_order(getitem, (2, 1, 0))  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
    # getitem_1 = ltorch.getitem(l_labels_, (..., slice(1,None,None)))  # getitem_1: "cuda:0 i64[1, 4095]"
      # getitem_1 = prims.slice_prim(l_labels_, [0, 1], [1, 4096], [1, 1])  # getitem_1: "cuda:0 i64[1, 4095]"
    # shift_labels = ltorch.contiguous(getitem_1, memory_format=_torch_memory_format_0)  # shift_labels: "cuda:0 i64[1, 4095]"
      # shift_labels = prims.stride_order(getitem_1, (1, 0))  # shift_labels: "cuda:0 i64[1, 4095]"
    # shift_logits_1 = ltorch.view(shift_logits, -1, 152064)  # shift_logits_1: "cuda:0 f32[4095, 152064]"
      # shift_logits_1 = ltorch.reshape(shift_logits, (-1, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
        # shift_logits_1 = prims.reshape(shift_logits, (4095, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
    # shift_labels_1 = ltorch.view(shift_labels, -1)  # shift_labels_1: "cuda:0 i64[4095]"
      # shift_labels_1 = ltorch.reshape(shift_labels, (-1,))  # shift_labels_1: "cuda:0 i64[4095]"
        # shift_labels_1 = prims.reshape(shift_labels, (4095,))  # shift_labels_1: "cuda:0 i64[4095]"
    # loss = ltorch.cross_entropy(shift_logits_1, shift_labels_1, None, None, -100, None, 'mean', 0.0)  # loss: "cuda:0 f32[]"
      # t20 = ltorch.log_softmax(shift_logits_1, 1, dtype=None)  # t20: "cuda:0 f32[4095, 152064]"
        # t18 = ltorch.logsumexp(shift_logits_1, 1, True)  # t18: "cuda:0 f32[4095, 1]"
          # t8 = ltorch.amax(shift_logits_1, 1, True)  # t8: "cuda:0 f32[4095, 1]"
            # t7 = prims.amax(shift_logits_1, (1,))  # t7: "cuda:0 f32[4095]"
            # t8 = prims.broadcast_in_dim(t7, [4095, 1], [0])  # t8: "cuda:0 f32[4095, 1]"
          # t9 = ltorch.abs(t8)  # t9: "cuda:0 f32[4095, 1]"
            # t9 = prims.abs(t8)  # t9: "cuda:0 f32[4095, 1]"
          # t10 = ltorch.eq(t9, float('inf'))  # t10: "cuda:0 b8[4095, 1]"
            # t10 = prims.eq(t9, float('inf'))  # t10: "cuda:0 b8[4095, 1]"
          # t11 = ltorch.where(t10, 0, t8)  # t11: "cuda:0 f32[4095, 1]"
            # t11 = prims.where(t10, 0.0, t8)  # t11: "cuda:0 f32[4095, 1]"
          # t13 = ltorch.sub(shift_logits_1, t11, alpha=1)  # t13: "cuda:0 f32[4095, 152064]"
            # t12 = prims.broadcast_in_dim(t11, (4095, 152064), (0, 1))  # t12: "cuda:0 f32[4095, 152064]"
            # t13 = prims.sub(shift_logits_1, t12)  # t13: "cuda:0 f32[4095, 152064]"
          # t14 = ltorch.exp(t13)  # t14: "cuda:0 f32[4095, 152064]"
            # t14 = prims.exp(t13)  # t14: "cuda:0 f32[4095, 152064]"
          # t16 = ltorch.sum(t14, 1, True, dtype=None)  # t16: "cuda:0 f32[4095, 1]"
            # t15 = prims.sum(t14, (1,))  # t15: "cuda:0 f32[4095]"
            # t16 = prims.broadcast_in_dim(t15, [4095, 1], [0])  # t16: "cuda:0 f32[4095, 1]"
          # t17 = ltorch.log(t16)  # t17: "cuda:0 f32[4095, 1]"
            # t17 = prims.log(t16)  # t17: "cuda:0 f32[4095, 1]"
          # t18 = ltorch.add(t17, t11, alpha=1)  # t18: "cuda:0 f32[4095, 1]"
            # t18 = prims.add(t17, t11)  # t18: "cuda:0 f32[4095, 1]"
        # t20 = ltorch.sub(shift_logits_1, t18, alpha=1)  # t20: "cuda:0 f32[4095, 152064]"
          # t19 = prims.broadcast_in_dim(t18, (4095, 152064), (0, 1))  # t19: "cuda:0 f32[4095, 152064]"
          # t20 = prims.sub(shift_logits_1, t19)  # t20: "cuda:0 f32[4095, 152064]"
      # loss = ltorch.nll_loss(t20, shift_labels_1, None, -100, 'mean')  # loss: "cuda:0 f32[]"
        # t21 = ltorch.neg(t20)  # t21: "cuda:0 f32[4095, 152064]"
          # t21 = prims.neg(t20)  # t21: "cuda:0 f32[4095, 152064]"
        # t22 = ltorch.unsqueeze(shift_labels_1, 1)  # t22: "cuda:0 i64[4095, 1]"
          # t22 = prims.broadcast_in_dim(shift_labels_1, [4095, 1], [0])  # t22: "cuda:0 i64[4095, 1]"
        # t23 = ltorch.take_along_dim(t21, t22, 1)  # t23: "cuda:0 f32[4095, 1]"
          # t23 = prims.take_along_axis(t21, t22, 1)  # t23: "cuda:0 f32[4095, 1]"
        # t24 = ltorch.ne(t22, -100)  # t24: "cuda:0 b8[4095, 1]"
          # t24 = prims.ne(t22, -100)  # t24: "cuda:0 b8[4095, 1]"
        # t25 = ltorch.where(t24, t23, 0)  # t25: "cuda:0 f32[4095, 1]"
          # t25 = prims.where(t24, t23, 0.0)  # t25: "cuda:0 f32[4095, 1]"
        # t26 = ltorch.sum(t25, None, False, dtype=None)  # t26: "cuda:0 f32[]"
          # t26 = prims.sum(t25, (0, 1))  # t26: "cuda:0 f32[]"
        # t28 = ltorch.sum(t24, None, False, dtype=None)  # t28: "cuda:0 i64[]"
          # t27 = ltorch.to(t24, dtypes.int64, None, device=None, dtype=None, copy=False, memory_format=None)  # t27: "cuda:0 i64[4095, 1]"
            # t27 = prims.convert_element_type(t24, dtypes.int64)  # t27: "cuda:0 i64[4095, 1]"
          # t28 = prims.sum(t27, (0, 1))  # t28: "cuda:0 i64[]"
        # loss = ltorch.true_divide(t26, t28)  # loss: "cuda:0 f32[]"
          # t29 = prims.convert_element_type(t28, dtypes.float32)  # t29: "cuda:0 f32[]"
          # loss = prims.div(t26, t29)  # loss: "cuda:0 f32[]"
  return (loss,) 
nvFuser fusion without `take_along_axis`, peak mem 6228.60 MB
# Constructed by Unwrap the actual return value --- 6228.606976 
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_logits_, l_labels_):
  # l_logits_: "cuda:0 bf16[1, 4096, 152064]"
  # l_labels_: "cuda:0 i64[1, 4096]"
  [shift_labels_1, t20] = nvFusion0(l_logits_, l_labels_)
    # logits = prims.convert_element_type(l_logits_, dtypes.float32)  # logits: "cuda:0 f32[1, 4096, 152064]"
    # getitem = prims.slice_prim(logits, [0, 0, 0], [1, 4095, 152064], [1, 1, 1])  # getitem: "cuda:0 f32[1, 4095, 152064]"
    # shift_logits = prims.stride_order(getitem, (2, 1, 0))  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
    # getitem_1 = prims.slice_prim(l_labels_, [0, 1], [1, 4096], [1, 1])  # getitem_1: "cuda:0 i64[1, 4095]"
    # shift_labels = prims.stride_order(getitem_1, (1, 0))  # shift_labels: "cuda:0 i64[1, 4095]"
    # shift_logits_1 = prims.reshape(shift_logits, (4095, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
    # shift_labels_1 = prims.reshape(shift_labels, (4095,))  # shift_labels_1: "cuda:0 i64[4095]"
    # t7 = prims.amax(shift_logits_1, (1,))  # t7: "cuda:0 f32[4095]"
    # t8 = prims.broadcast_in_dim(t7, [4095, 1], [0])  # t8: "cuda:0 f32[4095, 1]"
    # t9 = prims.abs(t8)  # t9: "cuda:0 f32[4095, 1]"
    # t10 = prims.eq(t9, float('inf'))  # t10: "cuda:0 b8[4095, 1]"
    # t11 = prims.where(t10, 0.0, t8)  # t11: "cuda:0 f32[4095, 1]"
    # t12 = prims.broadcast_in_dim(t11, (4095, 152064), (0, 1))  # t12: "cuda:0 f32[4095, 152064]"
    # t13 = prims.sub(shift_logits_1, t12)  # t13: "cuda:0 f32[4095, 152064]"
    # t14 = prims.exp(t13)  # t14: "cuda:0 f32[4095, 152064]"
    # t15 = prims.sum(t14, (1,))  # t15: "cuda:0 f32[4095]"
    # t16 = prims.broadcast_in_dim(t15, [4095, 1], [0])  # t16: "cuda:0 f32[4095, 1]"
    # t17 = prims.log(t16)  # t17: "cuda:0 f32[4095, 1]"
    # t18 = prims.add(t17, t11)  # t18: "cuda:0 f32[4095, 1]"
    # t19 = prims.broadcast_in_dim(t18, (4095, 152064), (0, 1))  # t19: "cuda:0 f32[4095, 152064]"
    # t20 = prims.sub(shift_logits_1, t19)  # t20: "cuda:0 f32[4095, 152064]"

  # <eval_with_key>.4:13: 	    loss = torch.nn.functional.cross_entropy(shift_logits_1, shift_labels_2, ignore_index = -100, reduction = 'mean');  shift_logits_1 = shift_labels_2 = None
  t50 = torch.nn.functional.nll_loss(t20, target=shift_labels_1, weight=None, ignore_index=-100, reduction='mean')  # t50: "cuda:0 f32[]"
    # t50 = ltorch.nll_loss(t20, shift_labels_1, None, -100, 'mean')  # t50: "cuda:0 f32[]"
      # t41 = ltorch.neg(t20)  # t41: "cuda:0 f32[4095, 152064]"
        # t41 = prims.neg(t20)  # t41: "cuda:0 f32[4095, 152064]"
      # t42 = ltorch.unsqueeze(shift_labels_1, 1)  # t42: "cuda:0 i64[4095, 1]"
        # t42 = prims.broadcast_in_dim(shift_labels_1, [4095, 1], [0])  # t42: "cuda:0 i64[4095, 1]"
      # t43 = ltorch.take_along_dim(t41, t42, 1)  # t43: "cuda:0 f32[4095, 1]"
        # t43 = prims.take_along_axis(t41, t42, 1)  # t43: "cuda:0 f32[4095, 1]"
      # t44 = ltorch.ne(t42, -100)  # t44: "cuda:0 b8[4095, 1]"
        # t44 = prims.ne(t42, -100)  # t44: "cuda:0 b8[4095, 1]"
      # t45 = ltorch.where(t44, t43, 0)  # t45: "cuda:0 f32[4095, 1]"
        # t45 = prims.where(t44, t43, 0.0)  # t45: "cuda:0 f32[4095, 1]"
      # t46 = ltorch.sum(t45, None, False, dtype=None)  # t46: "cuda:0 f32[]"
        # t46 = prims.sum(t45, (0, 1))  # t46: "cuda:0 f32[]"
      # t48 = ltorch.sum(t44, None, False, dtype=None)  # t48: "cuda:0 i64[]"
        # t47 = ltorch.to(t44, dtypes.int64, None, device=None, dtype=None, copy=False, memory_format=None)  # t47: "cuda:0 i64[4095, 1]"
          # t47 = prims.convert_element_type(t44, dtypes.int64)  # t47: "cuda:0 i64[4095, 1]"
        # t48 = prims.sum(t47, (0, 1))  # t48: "cuda:0 i64[]"
      # t50 = ltorch.true_divide(t46, t48)  # t50: "cuda:0 f32[]"
        # t49 = prims.convert_element_type(t48, dtypes.float32)  # t49: "cuda:0 f32[]"
        # t50 = prims.div(t46, t49)  # t50: "cuda:0 f32[]"
  del t20, shift_labels_1
  return (t50,)
nvFuser fusion with `take_along_axis`, peak mem 6228.60 MB
# Constructed by Unwrap the actual return value 
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_logits_, l_labels_):
  # l_logits_: "cuda:0 bf16[1, 4096, 152064]"
  # l_labels_: "cuda:0 i64[1, 4096]"
  [loss] = nvFusion0(l_logits_, l_labels_)
    # logits = prims.convert_element_type(l_logits_, dtypes.float32)  # logits: "cuda:0 f32[1, 4096, 152064]"
    # getitem = prims.slice_prim(logits, [0, 0, 0], [1, 4095, 152064], [1, 1, 1])  # getitem: "cuda:0 f32[1, 4095, 152064]"
    # shift_logits = prims.stride_order(getitem, (2, 1, 0))  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
    # getitem_1 = prims.slice_prim(l_labels_, [0, 1], [1, 4096], [1, 1])  # getitem_1: "cuda:0 i64[1, 4095]"
    # shift_labels = prims.stride_order(getitem_1, (1, 0))  # shift_labels: "cuda:0 i64[1, 4095]"
    # shift_logits_1 = prims.reshape(shift_logits, (4095, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
    # shift_labels_1 = prims.reshape(shift_labels, (4095,))  # shift_labels_1: "cuda:0 i64[4095]"
    # t7 = prims.amax(shift_logits_1, (1,))  # t7: "cuda:0 f32[4095]"
    # t8 = prims.broadcast_in_dim(t7, [4095, 1], [0])  # t8: "cuda:0 f32[4095, 1]"
    # t9 = prims.abs(t8)  # t9: "cuda:0 f32[4095, 1]"
    # t10 = prims.eq(t9, float('inf'))  # t10: "cuda:0 b8[4095, 1]"
    # t11 = prims.where(t10, 0.0, t8)  # t11: "cuda:0 f32[4095, 1]"
    # t12 = prims.broadcast_in_dim(t11, (4095, 152064), (0, 1))  # t12: "cuda:0 f32[4095, 152064]"
    # t13 = prims.sub(shift_logits_1, t12)  # t13: "cuda:0 f32[4095, 152064]"
    # t14 = prims.exp(t13)  # t14: "cuda:0 f32[4095, 152064]"
    # t15 = prims.sum(t14, (1,))  # t15: "cuda:0 f32[4095]"
    # t16 = prims.broadcast_in_dim(t15, [4095, 1], [0])  # t16: "cuda:0 f32[4095, 1]"
    # t17 = prims.log(t16)  # t17: "cuda:0 f32[4095, 1]"
    # t18 = prims.add(t17, t11)  # t18: "cuda:0 f32[4095, 1]"
    # t19 = prims.broadcast_in_dim(t18, (4095, 152064), (0, 1))  # t19: "cuda:0 f32[4095, 152064]"
    # t20 = prims.sub(shift_logits_1, t19)  # t20: "cuda:0 f32[4095, 152064]"
    # t21 = prims.neg(t20)  # t21: "cuda:0 f32[4095, 152064]"
    # t22 = prims.broadcast_in_dim(shift_labels_1, [4095, 1], [0])  # t22: "cuda:0 i64[4095, 1]"
    # t23 = prims.take_along_axis(t21, t22, 1)  # t23: "cuda:0 f32[4095, 1]"
    # t24 = prims.ne(t22, -100)  # t24: "cuda:0 b8[4095, 1]"
    # t25 = prims.where(t24, t23, 0.0)  # t25: "cuda:0 f32[4095, 1]"
    # t26 = prims.sum(t25, (0, 1))  # t26: "cuda:0 f32[]"
    # t27 = prims.convert_element_type(t24, dtypes.int64)  # t27: "cuda:0 i64[4095, 1]"
    # t28 = prims.sum(t27, (0, 1))  # t28: "cuda:0 i64[]"
    # t29 = prims.convert_element_type(t28, dtypes.float32)  # t29: "cuda:0 f32[]"
    # loss = prims.div(t26, t29)  # loss: "cuda:0 f32[]"
  return (loss,)

To Reproduce

Get HF transformers with pip install transformers==4.47.1 and then run the following:

import torch
from thunder.dynamo import ThunderCompiler
from transformers.loss.loss_utils import ForCausalLMLoss 

logits = torch.randn(1, 4096, 152064, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.randint(0, 152064, (1, 4096), device="cuda")
vocab_size = 152064
ignore_index = -100

executors = ("nvfuser",)
# executors = ("torchcompile",)
backend = ThunderCompiler(executors=executors)
compiled_loss = torch.compile(ForCausalLMLoss, dynamic=False, backend=backend, fullgraph=True)

out = compiled_loss(logits, labels, vocab_size, ignore_index=ignore_index)

torch.cuda.synchronize()

print("Peak memory", torch.cuda.max_memory_allocated() / 1e6)

To test capturing take_along_axis, uncomment the lines shown above in thunder/executors/nvfuserex_impl.py.
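
For convenience, both backends can also be compared in a single run. A minimal sketch that reuses the repro above and resets the CUDA peak-memory counter between backends (not verified end-to-end):

import torch
from thunder.dynamo import ThunderCompiler
from transformers.loss.loss_utils import ForCausalLMLoss

logits = torch.randn(1, 4096, 152064, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.randint(0, 152064, (1, 4096), device="cuda")

for executors in (("nvfuser",), ("torchcompile",)):
    torch._dynamo.reset()                 # drop previously compiled graphs
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # measure each backend from a clean peak
    backend = ThunderCompiler(executors=executors)
    compiled_loss = torch.compile(ForCausalLMLoss, dynamic=False, backend=backend, fullgraph=True)
    out = compiled_loss(logits, labels, 152064, ignore_index=-100)
    torch.cuda.synchronize()
    print(executors, "peak memory", torch.cuda.max_memory_allocated() / 1e6, "MB")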

Expected behavior

Similar performance and memory consumption as inductor.

Environment

  • torch version: 2.5.1+cu124
  • cuda version: 12.4
  • nvfuser version: 0.2.24+gitac84633
  • transformers-4.47.1

ps.

This issue is connected to #1552; a temporary workaround would be to let inductor capture the cross_entropy instead of nvFuser.
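
A minimal sketch of that workaround, assuming the executor order passed to ThunderCompiler expresses claiming priority (so the torchcompile executor gets to claim cross_entropy before nvFuser sees it):

from thunder.dynamo import ThunderCompiler

# Assumption: earlier executors claim operations first, so listing torchcompile
# before nvfuser lets inductor take cross_entropy while nvFuser handles the rest.
backend = ThunderCompiler(executors=("torchcompile", "nvfuser"))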

cc @tfogal

riccardofelluga (Collaborator, Author) commented:

Just to add: interestingly enough, if cross entropy is claimed by the apex executor, the memory footprint remains the same. It is therefore important for nvFuser to be able to build a single fusion containing both the reshapes and cross entropy.
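
One way to produce a trace like the one below would be to list the apex executor ahead of nvFuser; a sketch, assuming the fused cross-entropy executor is registered under the name "apex":

import torch
from thunder.dynamo import ThunderCompiler
from transformers.loss.loss_utils import ForCausalLMLoss

# Assumption: "apex" is the registered name of thunder's apex cross-entropy
# executor; listing it first lets it claim cross_entropy, nvFuser takes the rest.
backend = ThunderCompiler(executors=("apex", "nvfuser"))
compiled_loss = torch.compile(ForCausalLMLoss, dynamic=False, backend=backend, fullgraph=True)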

nvFuser + apex trace, peak memory 6228.60 MB
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_logits_, l_labels_):
  # l_logits_: "cuda:0 bf16[1, 4096, 152064]"
  # l_labels_: "cuda:0 i64[1, 4096]"
  [shift_logits_1, shift_labels_1] = nvFusion0(l_logits_, l_labels_)
    # logits = prims.convert_element_type(l_logits_, dtypes.float32)  # logits: "cuda:0 f32[1, 4096, 152064]"
    # getitem = prims.slice_prim(logits, [0, 0, 0], [1, 4095, 152064], [1, 1, 1])  # getitem: "cuda:0 f32[1, 4095, 152064]"
    # shift_logits = prims.stride_order(getitem, (2, 1, 0))  # shift_logits: "cuda:0 f32[1, 4095, 152064]"
    # getitem_1 = prims.slice_prim(l_labels_, [0, 1], [1, 4096], [1, 1])  # getitem_1: "cuda:0 i64[1, 4095]"
    # shift_labels = prims.stride_order(getitem_1, (1, 0))  # shift_labels: "cuda:0 i64[1, 4095]"
    # shift_logits_1 = prims.reshape(shift_logits, (4095, 152064))  # shift_logits_1: "cuda:0 f32[4095, 152064]"
    # shift_labels_1 = prims.reshape(shift_labels, (4095,))  # shift_labels_1: "cuda:0 i64[4095]"

  # <eval_with_key>.4:14: 	    loss = torch.nn.functional.cross_entropy(shift_logits_1, shift_labels_2, ignore_index = -100, reduction = 'mean');  shift_logits_1 = shift_labels_2 = None
  (loss, _) = apex_cross_entropy(shift_logits_1, shift_labels_1, 'mean', 0.0)
  return {'output': (loss,), 'flat_args': [l_logits_, l_labels_], 'flat_output': (loss,)}, ((shift_labels_1, shift_logits_1), ())

riccardofelluga (Collaborator, Author) commented Jan 23, 2025

Adding Inductor's generated code here for debugging:

# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()


# Topologically Sorted Source Nodes: [labels_1], Original ATen: [aten.constant_pad_nd]
# Source node to ATen node mapping:
#   labels_1 => constant_pad_nd
# Graph fragment:
#   %constant_pad_nd : [num_users=2] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%primals_2, [0, 1], -100.0), kwargs = {})
triton_poi_fused_constant_pad_nd_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8192], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4097
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 4096, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0), tmp2 & xmask, other=-100)
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')


# Topologically Sorted Source Nodes: [loss], Original ATen: [aten._log_softmax]
# Source node to ATen node mapping:
#   loss => amax, exp, log, sub, sum_1
# Graph fragment:
#   %amax : [num_users=2] = call_function[target=torch.ops.aten.amax.default](args = (%view, [1], True), kwargs = {})
#   %sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %amax), kwargs = {})
#   %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
#   %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
#   %log : [num_users=2] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
triton_red_fused__log_softmax_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[4096, 262144],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_1', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 152064
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (152064*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = triton_helpers.maximum(_tmp3, tmp2)
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
    tmp3 = triton_helpers.max2(_tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp3, None)
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp5 = tl.load(in_ptr0 + (r1 + (152064*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tmp6 - tmp3
        tmp8 = tl_math.exp(tmp7)
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp12 = tl_math.log(tmp10)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')


# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward]
# Source node to ATen node mapping:
#   loss => convert_element_type_1, div, full_default_1, ne, neg, sum_2, sum_3, where_1
# Graph fragment:
#   %ne : [num_users=3] = call_function[target=torch.ops.aten.ne.Scalar](args = (%view_1, -100), kwargs = {})
#   %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
#   %full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %neg, %full_default_1), kwargs = {})
#   %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%ne,), kwargs = {})
#   %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.float32), kwargs = {})
#   %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {})
#   %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_3, %convert_element_type_1), kwargs = {})
triton_red_fused_nll_loss_forward_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*bf16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=142), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_forward_2', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': 'EE1FF1A692BB318A52D1FC9AAB9D4731B4C6A59706D9DFCBED88CECB030EA339', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.int64)
    _tmp24 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (1 + r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp16 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp18 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tmp2.to(tl.int64)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
        tmp7 = tl.full([1, 1], 0, tl.int64)
        tmp8 = tl.where(tmp2, tmp0, tmp7)
        tmp9 = tl.full([XBLOCK, RBLOCK], 152064, tl.int32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp8 < 0
        tmp12 = tl.where(tmp11, tmp10, tmp8)
        tl.device_assert(((0 <= tmp12) & (tmp12 < 152064)) | ~(rmask), "index out of bounds: 0 <= tmp12 < 152064")
        tmp14 = tl.load(in_ptr1 + (tmp12 + (152064*r0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp17 = tmp15 - tmp16
        tmp19 = tmp17 - tmp18
        tmp20 = -tmp19
        tmp21 = 0.0
        tmp22 = tl.where(tmp2, tmp20, tmp21)
        tmp23 = tl.broadcast_to(tmp22, [XBLOCK, RBLOCK])
        tmp25 = _tmp24 + tmp23
        _tmp24 = tl.where(rmask, tmp25, _tmp24)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tmp24 = tl.sum(_tmp24, 1)[:, None]
    tmp26 = tmp5.to(tl.float32)
    tmp27 = tmp24 / tmp26
    tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (1, 4096, 152064), (622854144, 152064, 1))
    assert_size_stride(primals_2, (1, 4096), (4096, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 4097), (4112, 1), torch.int64)
        # Topologically Sorted Source Nodes: [labels_1], Original ATen: [aten.constant_pad_nd]
        stream0 = get_raw_stream(0)
        triton_poi_fused_constant_pad_nd_0.run(primals_2, buf0, 4097, grid=grid(4097), stream=stream0)
        del primals_2
        buf1 = empty_strided_cuda((4096, 1), (1, 1), torch.float32)
        buf2 = empty_strided_cuda((4096, 1), (1, 4096), torch.float32)
        buf3 = reinterpret_tensor(buf2, (4096, 1), (1, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [loss], Original ATen: [aten._log_softmax]
        triton_red_fused__log_softmax_1.run(buf3, primals_1, buf1, 4096, 152064, grid=grid(4096), stream=stream0)
        buf6 = empty_strided_cuda((), (), torch.float32)
        buf5 = empty_strided_cuda((), (), torch.float32)
        buf7 = buf6; del buf6  # reuse
        # Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward]
        triton_red_fused_nll_loss_forward_2.run(buf7, buf0, primals_1, buf1, buf3, buf5, 1, 4096, grid=grid(1), stream=stream0)
    return (buf7, primals_1, buf0, buf1, buf3, buf5, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((1, 4096, 152064), (622854144, 152064, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_2 = rand_strided((1, 4096), (4096, 1), device='cuda:0', dtype=torch.int64)
    fn = lambda: call([primals_1, primals_2])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Inductor computation graph: [image attached]

nvMelissa (Collaborator) commented:

Will close when this PR merges: [WIP] Add cross_entropy to torchcompile_cat executor #1655
