@@ -2305,14 +2305,17 @@ struct ggml_tensor * ggml_repeat(
 struct ggml_tensor * ggml_repeat_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
+        struct ggml_tensor  * b,
+        bool                  adjacent) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
 
     result->op     = GGML_OP_REPEAT_BACK;
     result->src[0] = a;
 
+    result->op_params[0] = adjacent ? 1 : 0;
+
     return result;
 }
 
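Note: ggml_repeat_back() sums repeated copies in a back into the smaller shape of b. As I read this patch, the new adjacent flag, stored in op_params[0], tells the compute kernels how those repetitions are laid out: false for the tiled layout produced by ggml_repeat() and the broadcasting binary ops, true for the block layout used when ggml_mul_mat()/ggml_out_prod() broadcast their batch dimensions. A minimal sketch of how a backend kernel could read the flag back (the helper name is hypothetical, not part of this patch):

    #include <stdbool.h>
    #include <stdint.h>
    #include "ggml.h"

    // Hypothetical helper: recover the flag written into op_params[0] above.
    static bool ggml_repeat_back_is_adjacent(const struct ggml_tensor * dst) {
        return ((const int32_t *) dst->op_params)[0] != 0;
    }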
@@ -5299,7 +5302,7 @@ static void ggml_compute_backward(
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = grad;
                 if (!ggml_are_same_shape(src0, src1)) {
-                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                    tmp = ggml_repeat_back(ctx, tmp, src1, false);
                 }
                 ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
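For a broadcast add, the gradient with respect to the smaller operand is the incoming gradient summed over the repeated positions; passing adjacent = false at the existing call sites keeps the previous ggml_repeat()-style reduction. An illustrative reference for the simplest 2-D case (not the actual kernel):

    #include <stdint.h>

    // Illustrative only: if src1 of shape [ne0] was broadcast along dim 1 to
    // [ne0, ne1] in the forward pass, its gradient is the incoming gradient
    // summed over the repeated dimension.
    static void repeat_back_2d_ref(const float * grad, float * dst,
                                   int64_t ne0, int64_t ne1) {
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            float sum = 0.0f;
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                sum += grad[i1*ne0 + i0];
            }
            dst[i0] = sum;
        }
    }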
@@ -5339,12 +5342,12 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
                 if (!ggml_are_same_shape(src0, src1)) {
-                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                    tmp = ggml_repeat_back(ctx, tmp, src1, false);
                 }
                 ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
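For GGML_OP_MUL the operand order of ggml_mul() is swapped in the src0 branch. The product is the same elementwise, but ggml's binary ops broadcast the second operand into the first, so when src1 is smaller than src0 the gradient, which has src0's full shape, has to be the first argument; presumably that is the reason for the swap. The src1 branch reduces over broadcast dimensions exactly as in the GGML_OP_ADD case above, again with adjacent = false.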
@@ -5399,7 +5402,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_REPEAT: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0, false));
             }
         } break;
         case GGML_OP_REPEAT_BACK: {
@@ -5431,21 +5434,18 @@ static void ggml_compute_backward(
             // src1.shape   [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    tmp = ggml_repeat_back(ctx, tmp, src0, true);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
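The GGML_OP_MUL_MAT case is where adjacent = true is used. The old code compared the batch dimensions of the out_prod result against src0 and summed broadcast repetitions with the old ggml_repeat_back(); the new code asserts that grad's batch dimensions match src1's and, whenever the out_prod result does not already have src0's shape (i.e. src0's batch dimensions were broadcast in the forward matmul), reduces it with adjacent = true. As I read it, the flag is needed because ggml_mul_mat() maps each src0 batch to a contiguous block of output batches, whereas ggml_repeat() tiles the source periodically, so the backward reduction must sum different slices in the two cases. A sketch of the two index mappings (illustrative only):

    #include <stdint.h>

    // Illustrative only: the two broadcast layouts the flag distinguishes, for a
    // source batch dimension of size q1 broadcast to size qq (qq % q1 == 0).
    // ggml_repeat() tiles the source periodically, while ggml_mul_mat() assigns
    // each source batch to a contiguous ("adjacent") block of output batches.
    static int64_t src_batch_tiled   (int64_t i2, int64_t q1)             { return i2 % q1; }        // adjacent == false
    static int64_t src_batch_adjacent(int64_t i2, int64_t q1, int64_t qq) { return i2 / (qq / q1); } // adjacent == true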
@@ -5514,7 +5514,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
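This last hunk appears to be the GGML_OP_CONT case. Previously the gradient was accumulated into src0 directly, which assumes the op did not change the logical shape; with ggml_cont_1d()/ggml_cont_2d()/ggml_cont_3d()/ggml_cont_4d() the element count stays the same but the shape can change, so the patch asserts equal element counts and reshapes the gradient back to src0's shape when the shapes differ. An illustrative example of that case, with hypothetical sizes:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 8); // src0: shape [3, 8]
        struct ggml_tensor * y = ggml_cont_2d(ctx, x, 6, 4);                   // same 24 elements, shape [6, 4]
        (void) y; // in the backward pass, y's gradient is reshaped from [6, 4] back to [3, 8]

        ggml_free(ctx);
        return 0;
    }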