From 39cf69cbd2d4e1792478e9c0cc85637ed949ed39 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 1 Jul 2025 16:56:13 +0000
Subject: [PATCH 1/2] feat: Add bf16 support to cast converter

---
 .../dynamo/conversion/aten_ops_converters.py |  1 +
 tests/py/dynamo/conversion/harness.py        |  2 ++
 tests/py/dynamo/conversion/test_casts.py     | 15 +++++++++++++++
 3 files changed, 18 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index e542f1d417..fbdf6c861d 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1034,6 +1034,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
         torch.bool,
         torch.int8,
         torch.float16,
+        torch.bfloat16,
     }
 
     # Validate input node has convertible kwargs
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 79e656ef82..93ffc8b451 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -412,6 +412,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        use_explicit_typing=False,
     ):
         # TODO: lan to remove this and set use_dynamo_traccer to True by default
         # once all the converter test files are moved to use_dynamo_tracer
@@ -422,6 +423,7 @@ def run_test(
             enabled_precisions={dtype._from(precision)},
             truncate_double=True,
             immutable_weights=immutable_weights,
+            use_explicit_typing=use_explicit_typing,
         )
 
         mod = self.generate_graph(
diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py
index 88260ba771..997092d24b 100644
--- a/tests/py/dynamo/conversion/test_casts.py
+++ b/tests/py/dynamo/conversion/test_casts.py
@@ -64,6 +64,21 @@ def forward(self, x):
             precision=torch.float,
         )
 
+    def test_to_copy_bfloat16(self):
+        class ToCopyBFloat16(nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
+                y = y**2
+                return y
+
+        inputs = [torch.rand((1, 3, 10), dtype=torch.float32)]
+        self.run_test(
+            ToCopyBFloat16(),
+            inputs,
+            precision=torch.float,
+            use_explicit_typing=True,
+        )
+
     def test_to_copy_i64b(self):
         class ToCopy64Bit(nn.Module):
             def forward(self, x):

From 8f14ba4eaad6de1e37f5ab6fe69a6c26c102a3ac Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 3 Jul 2025 20:30:07 +0000
Subject: [PATCH 2/2] chore: fix for testcase

---
 .../dynamo/conversion/impl/elementwise/ops.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 17e5042ce7..1bfb8c7242 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -544,9 +544,16 @@ def pow(
     lhs_val: Union[TRTTensor, int, float],
     rhs_val: Union[TRTTensor, int, float],
 ) -> TRTTensor:
+
+    lhs_dtype = None
+    rhs_dtype = None
+    if isinstance(lhs_val, int):
+        lhs_dtype = torch.int32
+    if isinstance(rhs_val, int):
+        rhs_dtype = torch.int32
     # POW operation supports only float32 and int8 inputs
-    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.float32)
-    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.float32)
+    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", lhs_dtype)
+    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", rhs_dtype)
     out = convert_binary_elementwise(
         ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val
     )