@@ -501,7 +501,6 @@ def index_put_converter(
501
501
F = [i for i in range (rank ) if indices [i ] is None ] # Free dimensions
502
502
I = [i for i in range (rank ) if indices [i ] is not None ] # Indexed dimensions
503
503
K = len (I )
504
-
505
504
# Determine the maximum size 'N' among the index tensors
506
505
if K > 0 :
507
506
index_shapes = [tensor .shape [0 ] for tensor in indices if tensor is not None ]
@@ -684,16 +683,6 @@ def index_put_converter(
684
683
values_reshaped = impl .shuffle .reshape (
685
684
ctx , target , source_ir , f"{ name } _reshape_scalar" , values , (1 ,)
686
685
)
687
- num_dims = len (expected_shape )
688
- ones_shape = tuple ([1 ] * num_dims )
689
- values_reshaped = impl .shuffle .reshape (
690
- ctx ,
691
- target ,
692
- source_ir ,
693
- f"{ name } _reshape_to_ones" ,
694
- values_reshaped ,
695
- ones_shape ,
696
- )
697
686
values_expanded = impl .slice .expand (
698
687
ctx ,
699
688
target ,
@@ -704,40 +693,79 @@ def index_put_converter(
704
693
)
705
694
else : # Non-scalar case
706
695
values_shape = list (values .shape )
707
-
708
- # Pad dimensions if necessary
709
- if len (values_shape ) < len (expected_shape ):
710
- values_shape = [1 ] * (
711
- len (expected_shape ) - len (values_shape )
712
- ) + values_shape
713
-
714
- # Calculate a broadcastable shape
715
- broadcast_shape = []
716
- for exp_dim , val_dim in zip (expected_shape , values_shape ):
717
- if val_dim == 1 :
718
- broadcast_shape .append (exp_dim )
719
- elif val_dim == exp_dim :
720
- broadcast_shape .append (val_dim )
696
+ if K > 0 and N in values_shape :
697
+ n_idx = values_shape .index (N )
698
+ permute_order = [n_idx ] + [
699
+ i for i in range (len (values_shape )) if i != n_idx
700
+ ]
701
+ values_permuted = impl .permutation .permute (
702
+ ctx , target , source_ir , f"{ name } _permute_values" , values , permute_order
703
+ )
704
+ remaining_shape = [
705
+ values_shape [i ] for i in range (len (values_shape )) if i != n_idx
706
+ ]
707
+ target_f_dims = len (F )
708
+ current_f_dims = len (remaining_shape )
709
+ if current_f_dims < target_f_dims :
710
+ values_expanded_shape = (
711
+ [N ] + [1 ] * (target_f_dims - current_f_dims ) + remaining_shape
712
+ )
721
713
else :
722
- raise ValueError (f"Cannot broadcast { values_shape } to { expected_shape } " )
723
-
724
- # Reshape and then expand
725
- values_reshaped = impl .shuffle .reshape (
726
- ctx ,
727
- target ,
728
- source_ir ,
729
- f"{ name } _reshape_values" ,
730
- values ,
731
- tuple (broadcast_shape ),
732
- )
733
- values_expanded = impl .slice .expand (
734
- ctx ,
735
- target ,
736
- source_ir ,
737
- f"{ name } _expand_values" ,
738
- values_reshaped ,
739
- expected_shape ,
740
- )
714
+ values_expanded_shape = [N ] + remaining_shape [:target_f_dims ]
715
+ values_expanded = impl .shuffle .reshape (
716
+ ctx ,
717
+ target ,
718
+ source_ir ,
719
+ f"{ name } _unsqueeze_values" ,
720
+ values_permuted ,
721
+ tuple (values_expanded_shape ),
722
+ )
723
+ broadcast_shape = []
724
+ for exp_dim , val_dim in zip (expected_shape , values_expanded_shape ):
725
+ if val_dim == 1 :
726
+ broadcast_shape .append (exp_dim )
727
+ elif val_dim == exp_dim :
728
+ broadcast_shape .append (val_dim )
729
+ else :
730
+ raise ValueError (
731
+ f"Cannot broadcast { values_expanded_shape } to { expected_shape } "
732
+ )
733
+ values_expanded = impl .slice .expand (
734
+ ctx ,
735
+ target ,
736
+ source_ir ,
737
+ f"{ name } _expand_values" ,
738
+ values_expanded ,
739
+ tuple (broadcast_shape ),
740
+ )
741
+ else :
742
+ values_shape_padded = [1 ] * (
743
+ len (expected_shape ) - len (values .shape )
744
+ ) + list (values .shape )
745
+ broadcast_shape = []
746
+ for exp_dim , val_dim in zip (expected_shape , values_shape_padded ):
747
+ if val_dim == 1 or exp_dim == val_dim :
748
+ broadcast_shape .append (exp_dim )
749
+ else :
750
+ raise ValueError (
751
+ f"Cannot broadcast { values .shape } to { expected_shape } "
752
+ )
753
+ values_reshaped = impl .shuffle .reshape (
754
+ ctx ,
755
+ target ,
756
+ source_ir ,
757
+ f"{ name } _reshape_values" ,
758
+ values ,
759
+ tuple (broadcast_shape ),
760
+ )
761
+ values_expanded = impl .slice .expand (
762
+ ctx ,
763
+ target ,
764
+ source_ir ,
765
+ f"{ name } _expand_values" ,
766
+ values_reshaped ,
767
+ expected_shape ,
768
+ )
741
769
742
770
# Flatten values to (N * F_volume,)
743
771
flattened_values = impl .shuffle .reshape (
@@ -749,6 +777,7 @@ def index_put_converter(
749
777
(N * F_volume ,),
750
778
)
751
779
780
+ indices_cat = cast_trt_tensor (ctx , indices_cat , trt .int32 , f"{ name } _idx_int32" )
752
781
# Perform Scatter ND operation
753
782
scatter_layer = ctx .net .add_scatter (
754
783
input_tensor ,
0 commit comments