@@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
   return PackOpaque(desc);
 }
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-                                             size_t wkspace_size, size_t barrier_size,
-                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
-                                             DType barrier_dtype, DType dgamma_part_dtype,
-                                             DType dbeta_part_dtype, bool zero_centered_gamma,
-                                             float eps, int sm_margin) {
-  return PackOpaque(CustomCallNormDescriptor{
-      batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
-      x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
-      zero_centered_gamma, eps, sm_margin});
+pybind11::bytes PackCustomCallNormDescriptor(
+    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
+  CustomCallNormDescriptor desc;
+  desc.batch_size = batch_size;
+  desc.hidden_size = hidden_size;
+  desc.wkspace_size = wkspace_size;
+  desc.barrier_size = barrier_size;
+  desc.dgamma_part_shape.from_vector(dgamma_part_shape);
+  desc.dbeta_part_shape.from_vector(dbeta_part_shape);
+  desc.x_dtype = x_dtype;
+  desc.w_dtype = w_dtype;
+  desc.wkspace_dtype = wkspace_dtype;
+  desc.barrier_dtype = barrier_dtype;
+  desc.dgamma_part_dtype = dgamma_part_dtype;
+  desc.dbeta_part_dtype = dbeta_part_dtype;
+  desc.zero_centered_gamma = zero_centered_gamma;
+  desc.eps = eps;
+  desc.sm_margin = sm_margin;
+  return PackOpaque(desc);
 }
 
 pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
@@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
 }
 
 void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
-                           size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
                            bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                            void *weight, DType w_dtype, void *ograd, void *workspace,
                            DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
@@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
   auto barrier_shape = std::vector<size_t>{barrier_size};
   auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-  auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
-  auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
+  auto dgamma_part_tensor =
+      TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
 
   if (is_layer_norm) {
     auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
     auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
-    auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
-    auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
+    auto dbeta_part_tensor =
+        TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
 
     layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                        rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
@@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto hidden_size = desc.hidden_size;
   auto wkspace_size = desc.wkspace_size;
   auto barrier_size = desc.barrier_size;
-  auto *dgamma_part_sizes = desc.dgamma_part_sizes;
-  auto *dbeta_part_sizes = desc.dbeta_part_sizes;
+  auto dgamma_part_shape = desc.dgamma_part_shape;
+  auto dbeta_part_shape = desc.dbeta_part_shape;
   auto in_dtype = desc.x_dtype;
   auto w_dtype = desc.w_dtype;
   auto wkspace_dtype = desc.wkspace_dtype;
@@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto *dgamma_part = buffers[10];
   auto *dbeta_part = buffers[11];
 
-  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                        dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                        dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                         w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                         rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                         dbeta_part_dtype, stream);
@@ -786,8 +797,9 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
   auto hidden_size = desc.hidden_size;
   auto wkspace_size = desc.wkspace_size;
   auto barrier_size = desc.barrier_size;
-  auto dgamma_part_sizes = desc.dgamma_part_sizes;
-  size_t dbeta_part_sizes[2] = {0, 0};
+  auto dgamma_part_shape = desc.dgamma_part_shape;
+  Shape dbeta_part_shape;
+  dbeta_part_shape.from_vector({0, 0});
   auto in_dtype = desc.x_dtype;
   auto w_dtype = desc.w_dtype;
   auto wkspace_dtype = desc.wkspace_dtype;
@@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
 
-  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                        dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                        dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                         w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                         rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                         dbeta_part_dtype, stream);
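
These hunks replace the raw `size_t *dgamma_part_sizes` / `size_t *dbeta_part_sizes` pointers with a `Shape` value stored directly in `CustomCallNormDescriptor`, converted from and to `std::vector<size_t>` via `from_vector()` and `to_vector()`. The `Shape` type itself is defined outside this diff; the following is a minimal sketch of what such a helper could look like, assuming a fixed-capacity, trivially copyable layout so the descriptor can still be serialized byte-wise by `PackOpaque`. The member names and the `kMaxDims` capacity are illustrative assumptions, not taken from this diff.

// Sketch only: a fixed-capacity, trivially copyable shape holder (assumed layout).
#include <cstddef>
#include <vector>

struct Shape {
  static constexpr std::size_t kMaxDims = 8;  // assumed maximum rank
  std::size_t num_dims = 0;
  std::size_t dims[kMaxDims] = {};

  // Copy a runtime vector into the fixed-size buffer.
  void from_vector(const std::vector<std::size_t> &v) {
    num_dims = v.size();
    for (std::size_t i = 0; i < num_dims; ++i) dims[i] = v[i];
  }

  // Rebuild a vector, e.g. for TensorWrapper's shape argument.
  std::vector<std::size_t> to_vector() const {
    return std::vector<std::size_t>(dims, dims + num_dims);
  }
};

Storing shapes by value like this keeps the descriptor a plain, self-contained struct that can be copied into the opaque byte string, which appears to be why the diff moves away from raw pointers that would dangle once the descriptor is serialized.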