div with nvfuser returns incorrect dtype #1723
This seems to happen only for `rounding_mode="floor"`.

**Minimal Repro**

```python
import thunder
import thunder.examine
import torch
def fn(x, y):
    return torch.div(x, y, rounding_mode="floor")
x = torch.randn(4, 2, 3, device="cuda")
y = torch.randn(4, 2, 3, device="cuda")
jfn = thunder.jit(fn, nv_store_fusion_inputs=True)
o = jfn(x, y)
trc = thunder.last_traces(jfn)[-1]
print(trc)
fusion_syms = thunder.examine.get_fusion_symbols(trc)
bsym = fusion_syms[0]
repro = thunder.examine.get_nvfuser_repro(trc, bsym.sym.name)
print(repro)
```

**Generated Trace** (this looks fine in terms of dtypes)

```python
@torch.no_grad()
@no_autocast
def computation(x, y):
# x: "cuda:0 f32[4, 2, 3]"
# y: "cuda:0 f32[4, 2, 3]"
[t21] = nvFusion0(x, y)
# t0 = prims.fmod(x, y) # t0: "cuda:0 f32[4, 2, 3]"
# t1 = prims.sub(x, t0) # t1: "cuda:0 f32[4, 2, 3]"
# t2 = prims.div(t1, y) # t2: "cuda:0 f32[4, 2, 3]"
# t3 = prims.lt(x, 0.0) # t3: "cuda:0 b8[4, 2, 3]"
# t4 = prims.lt(y, 0.0) # t4: "cuda:0 b8[4, 2, 3]"
# t5 = prims.bitwise_xor(t3, t4) # t5: "cuda:0 b8[4, 2, 3]"
# t6 = prims.ne(t0, 0.0) # t6: "cuda:0 b8[4, 2, 3]"
# t7 = prims.bitwise_and(t6, t5) # t7: "cuda:0 b8[4, 2, 3]"
# t8 = prims.sub(t2, 1.0) # t8: "cuda:0 f32[4, 2, 3]"
# t9 = prims.where(t7, t8, t2) # t9: "cuda:0 f32[4, 2, 3]"
# t10 = prims.floor(t9) # t10: "cuda:0 f32[4, 2, 3]"
# t11 = prims.sub(t9, t10) # t11: "cuda:0 f32[4, 2, 3]"
# t12 = prims.gt(t11, 0.5) # t12: "cuda:0 b8[4, 2, 3]"
# t13 = prims.add(t10, 1.0) # t13: "cuda:0 f32[4, 2, 3]"
# t14 = prims.where(t12, t13, t10) # t14: "cuda:0 f32[4, 2, 3]"
# t15 = prims.div(x, y) # t15: "cuda:0 f32[4, 2, 3]"
# t16 = prims.ne(t9, 0.0) # t16: "cuda:0 b8[4, 2, 3]"
# t17 = prims.signbit(t15) # t17: "cuda:0 b8[4, 2, 3]"
# t18 = prims.where(t17, -0.0, 0.0) # t18: "cuda:0 f32[4, 2, 3]"
# t19 = prims.where(t16, t14, t18) # t19: "cuda:0 f32[4, 2, 3]"
# t20 = prims.eq(y, 0.0) # t20: "cuda:0 b8[4, 2, 3]"
# t21 = prims.where(t20, t15, t19) # t21: "cuda:0 f32[4, 2, 3]"
return (t21,) nvFuser Definition (maybe scalar double leads to type promotion here?) # CUDA devices:
# 0: NVIDIA RTX 6000 Ada Generation
# 1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git81ecf98
# cuda version: 12.3
# nvfuser version: 0.2.23+git769a4d2
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.ops.fmod(T0, T1)
    T3 = fd.ops.sub(T0, T2)
    T4 = fd.ops.reciprocal(T1)
    T5 = fd.ops.mul(T3, T4)
    S6 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T7 = fd.ops.lt(T0, S6)
    S8 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T9 = fd.ops.lt(T1, S8)
    T10 = fd.ops.bitwise_xor(T7, T9)
    S11 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T12 = fd.ops.ne(T2, S11)
    T13 = fd.ops.bitwise_and(T12, T10)
    S14 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T15 = fd.ops.sub(T5, S14)
    T16 = fd.ops.where(T13, T15, T5)
    T17 = fd.ops.floor(T16)
    T18 = fd.ops.sub(T16, T17)
    S19 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T20 = fd.ops.gt(T18, S19)
    S21 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T22 = fd.ops.add(T17, S21)
    T23 = fd.ops.where(T20, T22, T17)
    T24 = fd.ops.reciprocal(T1)
    T25 = fd.ops.mul(T0, T24)
    S26 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T27 = fd.ops.ne(T16, S26)
    T28 = fd.ops.signbit(T25)
    S29 = fd.define_scalar(-0.00000, dtype=DataType.Double)
    S30 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T31 = fd.ops.where(T28, S29, S30)
    T32 = fd.ops.where(T27, T23, T31)
    S33 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T34 = fd.ops.eq(T1, S33)
    T35 = fd.ops.where(T34, T25, T32)
    fd.add_output(T35)
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    torch.testing.make_tensor((4, 2, 3), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((4, 2, 3), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
```
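
The "scalar double" comment above can be probed directly. Below is a minimal sketch (an assumption, not a confirmed reduction) that isolates just the `where` whose branch values are Double scalars, reusing only the frontend calls that already appear in the generated repro; if its output comes back as float64, that would support the promotion hypothesis.

```python
import torch
from nvfuser import FusionDefinition, DataType

# Hypothetical reduction (assumption): isolate the `where` whose branches are
# the Double scalars from the generated definition above and check the dtype
# of its output. float64 here would point at the scalar branches as the
# source of the promotion.
def isolated_where(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True],
                          dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.ops.signbit(T0)                                  # boolean predicate
    S2 = fd.define_scalar(-0.00000, dtype=DataType.Double)   # Double scalar branches,
    S3 = fd.define_scalar(0.00000, dtype=DataType.Double)    # as in the repro above
    T4 = fd.ops.where(T1, S2, S3)                            # dtype of T4 is the question
    fd.add_output(T4)

with FusionDefinition() as fd:
    isolated_where(fd)

out = fd.execute([torch.randn(4, 2, 3, device="cuda:0")])[0]
print(out.dtype)  # float64 here would support the promotion hypothesis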
Thanks for the issue, and extra thanks for the small reproducer. cc @jjsjann123: is this an issue in the nvfuser executor or in nvfuser itself?
Sorry for the delayed response; this felt like an nvfuser issue. It went wrong at the point where the operation is translated to nvfuser. I think we can explicitly add a cast at the end of `where`.
Trying it out in #1734.
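
For illustration only, here is a minimal sketch of what an explicit cast at the end of the `where` could look like at the fusion-definition level. It assumes the Python frontend exposes `fd.ops.cast` with a `dtype` keyword; the actual fix is whatever #1734 does in the Thunder-to-nvFuser translation, not this hand-written example.

```python
import torch
from nvfuser import FusionDefinition, DataType

# Sketch of the proposed direction (assumption): cast the result of the
# scalar-branch `where` back to Float before it becomes a fusion output,
# so the executor returns float32 rather than the promoted float64.
def where_with_cast(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[4, 2, 3], contiguity=[True, True, True],
                          dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.ops.signbit(T0)
    S2 = fd.define_scalar(-0.00000, dtype=DataType.Double)
    S3 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T4 = fd.ops.where(T1, S2, S3)                  # possibly promoted to Double
    T5 = fd.ops.cast(T4, dtype=DataType.Float)     # explicit cast back to Float (assumed API)
    fd.add_output(T5)

with FusionDefinition() as fd:
    where_with_cast(fd)

out = fd.execute([torch.randn(4, 2, 3, device="cuda:0")])[0]
print(out.dtype)  # expected: torch.float32
```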
The test `test_core_vs_torch_consistency_div_nvfuser_cuda_thunder.dtypes.float32`, which checks elementwise division of float32 tensors against torch with the nvfuser executor, fails: the torch result has dtype float32 while the Thunder result has dtype float64. Thunder should also return float32.
cc @tfogal
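
For completeness, the mismatch can also be observed directly without the test suite, reusing the function from the repro above; the assertion below is expected to fail while the bug is present.

```python
import torch
import thunder

def fn(x, y):
    return torch.div(x, y, rounding_mode="floor")

x = torch.randn(4, 2, 3, device="cuda")
y = torch.randn(4, 2, 3, device="cuda")

eager = fn(x, y)                 # plain PyTorch: float32
jitted = thunder.jit(fn)(x, y)   # thunder + nvFuser: reported here as float64

print(eager.dtype, jitted.dtype)
assert jitted.dtype == eager.dtype, f"expected {eager.dtype}, got {jitted.dtype}"
```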