
div with nvfuser returns incorrect dtype #1723

Open
beverlylytle opened this issue Jan 30, 2025 · 4 comments · May be fixed by #1734
beverlylytle commented Jan 30, 2025

test_core_vs_torch_consistency_div_nvfuser_cuda_thunder.dtypes.float32, which tests elementwise division of float32 tensors with nvfuser, fails: the torch result has dtype float32 while the Thunder result has dtype float64.

Thunder should also return float32.
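
For reference, a minimal sketch (not part of the original report) of the eager PyTorch behavior the consistency test compares against:

import torch

# Dividing two float32 tensors in eager PyTorch keeps the result in float32.
x = torch.randn(4, 2, 3, device="cuda", dtype=torch.float32)
y = torch.randn(4, 2, 3, device="cuda", dtype=torch.float32)
print(torch.div(x, y).dtype)  # torch.float32, the dtype Thunder should also produce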

cc @tfogal

@kshitij12345

This seems to happen only for rounding_mode="floor".

Minimal Repro

import thunder
import thunder.examine
import torch

def fn(x, y):
    return torch.div(x, y, rounding_mode="floor")

x = torch.randn(4, 2, 3, device="cuda")
y = torch.randn(4, 2, 3, device="cuda")

# Keep the fusion inputs around so an nvFuser repro can be extracted later.
jfn = thunder.jit(fn, nv_store_fusion_inputs=True)

o = jfn(x, y)

# Grab the final execution trace and its nvFuser fusion symbol.
trc = thunder.last_traces(jfn)[-1]
print(trc)

fusion_syms = thunder.examine.get_fusion_symbols(trc)
bsym = fusion_syms[0]

# Print a standalone nvFuser repro for that fusion.
repro = thunder.examine.get_nvfuser_repro(trc, bsym.sym.name)
print(repro)

Generated Trace (this looks fine in terms of dtypes)

@torch.no_grad()
@no_autocast
def computation(x, y):
  # x: "cuda:0 f32[4, 2, 3]"
  # y: "cuda:0 f32[4, 2, 3]"
  [t21] = nvFusion0(x, y)
    # t0 = prims.fmod(x, y)  # t0: "cuda:0 f32[4, 2, 3]"
    # t1 = prims.sub(x, t0)  # t1: "cuda:0 f32[4, 2, 3]"
    # t2 = prims.div(t1, y)  # t2: "cuda:0 f32[4, 2, 3]"
    # t3 = prims.lt(x, 0.0)  # t3: "cuda:0 b8[4, 2, 3]"
    # t4 = prims.lt(y, 0.0)  # t4: "cuda:0 b8[4, 2, 3]"
    # t5 = prims.bitwise_xor(t3, t4)  # t5: "cuda:0 b8[4, 2, 3]"
    # t6 = prims.ne(t0, 0.0)  # t6: "cuda:0 b8[4, 2, 3]"
    # t7 = prims.bitwise_and(t6, t5)  # t7: "cuda:0 b8[4, 2, 3]"
    # t8 = prims.sub(t2, 1.0)  # t8: "cuda:0 f32[4, 2, 3]"
    # t9 = prims.where(t7, t8, t2)  # t9: "cuda:0 f32[4, 2, 3]"
    # t10 = prims.floor(t9)  # t10: "cuda:0 f32[4, 2, 3]"
    # t11 = prims.sub(t9, t10)  # t11: "cuda:0 f32[4, 2, 3]"
    # t12 = prims.gt(t11, 0.5)  # t12: "cuda:0 b8[4, 2, 3]"
    # t13 = prims.add(t10, 1.0)  # t13: "cuda:0 f32[4, 2, 3]"
    # t14 = prims.where(t12, t13, t10)  # t14: "cuda:0 f32[4, 2, 3]"
    # t15 = prims.div(x, y)  # t15: "cuda:0 f32[4, 2, 3]"
    # t16 = prims.ne(t9, 0.0)  # t16: "cuda:0 b8[4, 2, 3]"
    # t17 = prims.signbit(t15)  # t17: "cuda:0 b8[4, 2, 3]"
    # t18 = prims.where(t17, -0.0, 0.0)  # t18: "cuda:0 f32[4, 2, 3]"
    # t19 = prims.where(t16, t14, t18)  # t19: "cuda:0 f32[4, 2, 3]"
    # t20 = prims.eq(y, 0.0)  # t20: "cuda:0 b8[4, 2, 3]"
    # t21 = prims.where(t20, t15, t19)  # t21: "cuda:0 f32[4, 2, 3]"
  return (t21,)

nvFuser Definition (maybe the double scalars lead to type promotion here?)

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git81ecf98
# cuda version: 12.3
# nvfuser version: 0.2.23+git769a4d2
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.ops.fmod(T0, T1)
    T3 = fd.ops.sub(T0, T2)
    T4 = fd.ops.reciprocal(T1)
    T5 = fd.ops.mul(T3, T4)
    S6 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T7 = fd.ops.lt(T0, S6)
    S8 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T9 = fd.ops.lt(T1, S8)
    T10 = fd.ops.bitwise_xor(T7, T9)
    S11 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T12 = fd.ops.ne(T2, S11)
    T13 = fd.ops.bitwise_and(T12, T10)
    S14 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T15 = fd.ops.sub(T5, S14)
    T16 = fd.ops.where(T13, T15, T5)
    T17 = fd.ops.floor(T16)
    T18 = fd.ops.sub(T16, T17)
    S19 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T20 = fd.ops.gt(T18, S19)
    S21 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T22 = fd.ops.add(T17, S21)
    T23 = fd.ops.where(T20, T22, T17)
    T24 = fd.ops.reciprocal(T1)
    T25 = fd.ops.mul(T0, T24)
    S26 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T27 = fd.ops.ne(T16, S26)
    T28 = fd.ops.signbit(T25)
    S29 = fd.define_scalar(-0.00000, dtype=DataType.Double)
    S30 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T31 = fd.ops.where(T28, S29, S30)
    T32 = fd.ops.where(T27, T23, T31)
    S33 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T34 = fd.ops.eq(T1, S33)
    T35 = fd.ops.where(T34, T25, T32)
    fd.add_output(T35)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((4, 2, 3), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((4, 2, 3), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
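
Capturing the outputs and checking their dtype (a small addition, not part of the generated repro) surfaces the mismatch described above:

outs = fd.execute(inputs)
# Reported as torch.float64 in this issue, while eager torch.div returns torch.float32.
print(outs[0].dtype)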


tfogal commented Jan 31, 2025

Thanks for the issue, and extra thanks for the small reproducer.

cc @jjsjann123: is this an issue in the nvfuser executor or nvfuser itself?

jjsjann123 self-assigned this Feb 1, 2025
@jjsjann123

Sorry for the delayed response. This looks like an nvfuser issue.

It went wrong here:

  # t18 = prims.where(t17, -0.0, 0.0)  # t18: "cuda:0 f32[4, 2, 3]"

This is translated to nvfuser as T31 = fd.ops.where(T28, S29, S30). Our promotion logic checks the scalar value dtypes, which are double: https://github.com/NVIDIA/Fuser/blob/93b68e00340578534e54fd813e2168f522bf2b8f/csrc/ops/arith.cpp#L1863

I think we can explicitly add a cast at the end of where.
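
To illustrate, a hypothetical sketch (using the same nvFuser Python frontend as the repro above; the shape and names are made up) that isolates the suspected promotion and shows what an explicit cast back to Float would give:

import torch
from nvfuser import FusionDefinition, DataType

def where_promotion(fd: FusionDefinition) -> None:
    # Boolean predicate tensor, mirroring t17/T28 from the trace above.
    T0 = fd.define_tensor(shape=[4], contiguity=[True], dtype=DataType.Bool, is_cpu=False, stride_order=[0])
    S1 = fd.define_scalar(-0.0, dtype=DataType.Double)
    S2 = fd.define_scalar(0.0, dtype=DataType.Double)
    # Promotion only sees the Double scalars here, so the where output becomes Double.
    T3 = fd.ops.where(T0, S1, S2)
    # An explicit cast back to Float, in the spirit of the fix suggested above.
    T4 = fd.ops.cast(T3, DataType.Float)
    fd.add_output(T3)
    fd.add_output(T4)

with FusionDefinition() as fd:
    where_promotion(fd)

pred = torch.tensor([True, False, True, False], device="cuda:0")
promoted, casted = fd.execute([pred])
print(promoted.dtype, casted.dtype)  # expected: torch.float64 torch.float32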

@jjsjann123

Trying it out in #1734.

jjsjann123 linked a pull request Feb 1, 2025 that will close this issue