
Commit 0ab02c3

fix: index_put converter to handle multi-shape slicing with None

1 parent 79083b6 · commit 0ab02c3
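The failing pattern this fixes is an index_put where some dimensions are sliced with None, the index tensor lands on an inner dimension, and the values tensor has lower rank than the indexed region, so it must broadcast. In eager PyTorch terms, the new regression test (added below) corresponds roughly to the following; the variable names are illustrative:

import torch

# Shapes taken from the new test case in test_index_put_aten.py.
x = torch.zeros([1, 2, 5, 3], dtype=torch.float32)
idx = torch.tensor([0, 1, 2], dtype=torch.int64)
values = torch.randn([2, 3, 3], dtype=torch.float32)

# Equivalent to torch.ops.aten.index_put_(x, (None, None, idx), values):
# the [2, 3, 3] values broadcast into the [1, 2, 3, 3] indexed region.
x[:, :, idx] = values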

File tree

2 files changed: +79 -44 lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py (+73 -44)
@@ -501,7 +501,6 @@ def index_put_converter(
     F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
     I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
     K = len(I)
-
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
         index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
@@ -684,16 +683,6 @@ def index_put_converter(
             values_reshaped = impl.shuffle.reshape(
                 ctx, target, source_ir, f"{name}_reshape_scalar", values, (1,)
             )
-            num_dims = len(expected_shape)
-            ones_shape = tuple([1] * num_dims)
-            values_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_to_ones",
-                values_reshaped,
-                ones_shape,
-            )
             values_expanded = impl.slice.expand(
                 ctx,
                 target,
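Side cleanup: the deleted block above reshaped the already-(1,)-shaped scalar value to a rank-matched all-ones shape before expanding. Assuming impl.slice.expand follows PyTorch's expand semantics, where size-1 and missing leading dimensions broadcast automatically, that intermediate reshape was a no-op:

import torch

# A (1,)-shaped tensor expands directly to any target shape; no
# intermediate (1, 1, 1, 1) view is required first.
v = torch.tensor([7.0])
print(v.expand(3, 1, 2, 3).shape)  # torch.Size([3, 1, 2, 3])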
@@ -704,40 +693,79 @@ def index_put_converter(
             )
         else:  # Non-scalar case
             values_shape = list(values.shape)
-
-            # Pad dimensions if necessary
-            if len(values_shape) < len(expected_shape):
-                values_shape = [1] * (
-                    len(expected_shape) - len(values_shape)
-                ) + values_shape
-
-            # Calculate a broadcastable shape
-            broadcast_shape = []
-            for exp_dim, val_dim in zip(expected_shape, values_shape):
-                if val_dim == 1:
-                    broadcast_shape.append(exp_dim)
-                elif val_dim == exp_dim:
-                    broadcast_shape.append(val_dim)
+            if K > 0 and N in values_shape:
+                n_idx = values_shape.index(N)
+                permute_order = [n_idx] + [
+                    i for i in range(len(values_shape)) if i != n_idx
+                ]
+                values_permuted = impl.permutation.permute(
+                    ctx, target, source_ir, f"{name}_permute_values", values, permute_order
+                )
+                remaining_shape = [
+                    values_shape[i] for i in range(len(values_shape)) if i != n_idx
+                ]
+                target_f_dims = len(F)
+                current_f_dims = len(remaining_shape)
+                if current_f_dims < target_f_dims:
+                    values_expanded_shape = (
+                        [N] + [1] * (target_f_dims - current_f_dims) + remaining_shape
+                    )
                 else:
-                    raise ValueError(f"Cannot broadcast {values_shape} to {expected_shape}")
-
-            # Reshape and then expand
-            values_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_values",
-                values,
-                tuple(broadcast_shape),
-            )
-            values_expanded = impl.slice.expand(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_expand_values",
-                values_reshaped,
-                expected_shape,
-            )
+                    values_expanded_shape = [N] + remaining_shape[:target_f_dims]
+                values_expanded = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_unsqueeze_values",
+                    values_permuted,
+                    tuple(values_expanded_shape),
+                )
+                broadcast_shape = []
+                for exp_dim, val_dim in zip(expected_shape, values_expanded_shape):
+                    if val_dim == 1:
+                        broadcast_shape.append(exp_dim)
+                    elif val_dim == exp_dim:
+                        broadcast_shape.append(val_dim)
+                    else:
+                        raise ValueError(
+                            f"Cannot broadcast {values_expanded_shape} to {expected_shape}"
+                        )
+                values_expanded = impl.slice.expand(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_expand_values",
+                    values_expanded,
+                    tuple(broadcast_shape),
+                )
+            else:
+                values_shape_padded = [1] * (
+                    len(expected_shape) - len(values.shape)
+                ) + list(values.shape)
+                broadcast_shape = []
+                for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
+                    if val_dim == 1 or exp_dim == val_dim:
+                        broadcast_shape.append(exp_dim)
+                    else:
+                        raise ValueError(
+                            f"Cannot broadcast {values.shape} to {expected_shape}"
+                        )
+                values_reshaped = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_values",
+                    values,
+                    tuple(broadcast_shape),
+                )
+                values_expanded = impl.slice.expand(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_expand_values",
+                    values_reshaped,
+                    expected_shape,
+                )
 
         # Flatten values to (N * F_volume,)
         flattened_values = impl.shuffle.reshape(
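The core of the fix is the new branch taken when the index length N appears in the values shape: the N-sized dimension is permuted to the front, the remaining dimensions are aligned against the free dimensions F, and only then is the result broadcast to expected_shape. A minimal, TensorRT-free sketch of that shape bookkeeping (the helper name is illustrative, not part of the library):

def plan_values_broadcast(values_shape, num_free_dims, N):
    """Mirror the diff's bookkeeping: front-load the first N-sized dim,
    then left-pad the remaining dims with 1s to line up with the free dims."""
    n_idx = values_shape.index(N)  # first occurrence, as in the diff
    permute_order = [n_idx] + [i for i in range(len(values_shape)) if i != n_idx]
    remaining = [values_shape[i] for i in range(len(values_shape)) if i != n_idx]
    if len(remaining) < num_free_dims:
        expanded = [N] + [1] * (num_free_dims - len(remaining)) + remaining
    else:
        expanded = [N] + remaining[:num_free_dims]
    return permute_order, expanded

# For the new test case: values [2, 3, 3], N == 3, and three free dims.
# The reshape target [3, 1, 2, 3] then broadcasts against expected_shape.
print(plan_values_broadcast([2, 3, 3], num_free_dims=3, N=3))
# ([1, 0, 2], [3, 1, 2, 3])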
@@ -749,6 +777,7 @@ def index_put_converter(
         (N * F_volume,),
     )
 
+    indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
     # Perform Scatter ND operation
    scatter_layer = ctx.net.add_scatter(
         input_tensor,
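The other functional change is the cast_trt_tensor call added before add_scatter, which suggests the scatter layer's indices input must be int32 while PyTorch index tensors default to int64 (the new test constructs its index tensor with dtype=torch.int64 explicitly). For example:

import torch

# PyTorch index tensors default to int64; the converter now downcasts
# the concatenated indices to int32 before building the scatter layer.
idx = torch.tensor([0, 1, 2])
print(idx.dtype)        # torch.int64
print(idx.int().dtype)  # torch.int32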

tests/py/dynamo/conversion/test_index_put_aten.py (+6 -0)
@@ -194,6 +194,12 @@ class TestIndexPutConverter(DispatchTestCase):
                 dtype=torch.int32,
             ),
         ),
+        param(
+            test_name="4d_indices_none_none_multiple_idx_broadcast_error",
+            source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
+            indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
+            value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
+        ),
         # param(
         #     test_name="2d_indices_accumulate_True",
         #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
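Assuming the suite is driven by pytest and that parameterized embeds test_name in the generated test id, the new case can be selected with something like: python -m pytest tests/py/dynamo/conversion/test_index_put_aten.py -k "4d_indices_none_none_multiple_idx_broadcast_error".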
