@@ -501,7 +501,6 @@ def index_put_converter(
501501 F = [i for i in range (rank ) if indices [i ] is None ] # Free dimensions
502502 I = [i for i in range (rank ) if indices [i ] is not None ] # Indexed dimensions
503503 K = len (I )
504-
505504 # Determine the maximum size 'N' among the index tensors
506505 if K > 0 :
507506 index_shapes = [tensor .shape [0 ] for tensor in indices if tensor is not None ]
@@ -684,16 +683,6 @@ def index_put_converter(
684683 values_reshaped = impl .shuffle .reshape (
685684 ctx , target , source_ir , f"{ name } _reshape_scalar" , values , (1 ,)
686685 )
687- num_dims = len (expected_shape )
688- ones_shape = tuple ([1 ] * num_dims )
689- values_reshaped = impl .shuffle .reshape (
690- ctx ,
691- target ,
692- source_ir ,
693- f"{ name } _reshape_to_ones" ,
694- values_reshaped ,
695- ones_shape ,
696- )
697686 values_expanded = impl .slice .expand (
698687 ctx ,
699688 target ,
@@ -704,40 +693,79 @@ def index_put_converter(
704693 )
705694 else : # Non-scalar case
706695 values_shape = list (values .shape )
707-
708- # Pad dimensions if necessary
709- if len (values_shape ) < len (expected_shape ):
710- values_shape = [1 ] * (
711- len (expected_shape ) - len (values_shape )
712- ) + values_shape
713-
714- # Calculate a broadcastable shape
715- broadcast_shape = []
716- for exp_dim , val_dim in zip (expected_shape , values_shape ):
717- if val_dim == 1 :
718- broadcast_shape .append (exp_dim )
719- elif val_dim == exp_dim :
720- broadcast_shape .append (val_dim )
696+ if K > 0 and N in values_shape :
697+ n_idx = values_shape .index (N )
698+ permute_order = [n_idx ] + [
699+ i for i in range (len (values_shape )) if i != n_idx
700+ ]
701+ values_permuted = impl .permutation .permute (
702+ ctx , target , source_ir , f"{ name } _permute_values" , values , permute_order
703+ )
704+ remaining_shape = [
705+ values_shape [i ] for i in range (len (values_shape )) if i != n_idx
706+ ]
707+ target_f_dims = len (F )
708+ current_f_dims = len (remaining_shape )
709+ if current_f_dims < target_f_dims :
710+ values_expanded_shape = (
711+ [N ] + [1 ] * (target_f_dims - current_f_dims ) + remaining_shape
712+ )
721713 else :
722- raise ValueError (f"Cannot broadcast { values_shape } to { expected_shape } " )
723-
724- # Reshape and then expand
725- values_reshaped = impl .shuffle .reshape (
726- ctx ,
727- target ,
728- source_ir ,
729- f"{ name } _reshape_values" ,
730- values ,
731- tuple (broadcast_shape ),
732- )
733- values_expanded = impl .slice .expand (
734- ctx ,
735- target ,
736- source_ir ,
737- f"{ name } _expand_values" ,
738- values_reshaped ,
739- expected_shape ,
740- )
714+ values_expanded_shape = [N ] + remaining_shape [:target_f_dims ]
715+ values_expanded = impl .shuffle .reshape (
716+ ctx ,
717+ target ,
718+ source_ir ,
719+ f"{ name } _unsqueeze_values" ,
720+ values_permuted ,
721+ tuple (values_expanded_shape ),
722+ )
723+ broadcast_shape = []
724+ for exp_dim , val_dim in zip (expected_shape , values_expanded_shape ):
725+ if val_dim == 1 :
726+ broadcast_shape .append (exp_dim )
727+ elif val_dim == exp_dim :
728+ broadcast_shape .append (val_dim )
729+ else :
730+ raise ValueError (
731+ f"Cannot broadcast { values_expanded_shape } to { expected_shape } "
732+ )
733+ values_expanded = impl .slice .expand (
734+ ctx ,
735+ target ,
736+ source_ir ,
737+ f"{ name } _expand_values" ,
738+ values_expanded ,
739+ tuple (broadcast_shape ),
740+ )
741+ else :
742+ values_shape_padded = [1 ] * (
743+ len (expected_shape ) - len (values .shape )
744+ ) + list (values .shape )
745+ broadcast_shape = []
746+ for exp_dim , val_dim in zip (expected_shape , values_shape_padded ):
747+ if val_dim == 1 or exp_dim == val_dim :
748+ broadcast_shape .append (exp_dim )
749+ else :
750+ raise ValueError (
751+ f"Cannot broadcast { values .shape } to { expected_shape } "
752+ )
753+ values_reshaped = impl .shuffle .reshape (
754+ ctx ,
755+ target ,
756+ source_ir ,
757+ f"{ name } _reshape_values" ,
758+ values ,
759+ tuple (broadcast_shape ),
760+ )
761+ values_expanded = impl .slice .expand (
762+ ctx ,
763+ target ,
764+ source_ir ,
765+ f"{ name } _expand_values" ,
766+ values_reshaped ,
767+ expected_shape ,
768+ )
741769
742770 # Flatten values to (N * F_volume,)
743771 flattened_values = impl .shuffle .reshape (
@@ -749,6 +777,7 @@ def index_put_converter(
749777 (N * F_volume ,),
750778 )
751779
780+ indices_cat = cast_trt_tensor (ctx , indices_cat , trt .int32 , f"{ name } _idx_int32" )
752781 # Perform Scatter ND operation
753782 scatter_layer = ctx .net .add_scatter (
754783 input_tensor ,
0 commit comments