Skip to content

Commit 5da8ae3

Browse files
committed
tests : add option to permute the dst tensor
ggml-ci
1 parent 938c779 commit 5da8ae3

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

ggml/src/ggml-cpu/ggml-cpu.c

+25-22
Original file line numberDiff line numberDiff line change
@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
31103110
const int ith = params->ith; // thread index
31113111
const int nth = params->nth; // number of threads
31123112

3113-
// parallelize by elements
3114-
const int ne = ggml_nelements(src0)/ggml_blck_size(src0->type);
3115-
const int dr = (ne + nth - 1) / nth;
3116-
const int ie0 = dr * ith;
3117-
const int ie1 = MIN(ie0 + dr, ne);
3113+
// parallelize by blocks
3114+
const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
3115+
const int dr = (nk + nth - 1) / nth;
3116+
const int k0 = dr * ith;
3117+
const int k1 = MIN(k0 + dr, nk);
31183118

3119-
if (ie0 < ie1) {
3119+
if (k0 < k1) {
31203120
memcpy(
3121-
((char *) dst->data + ie0*nb0),
3122-
((char *) src0->data + ie0*nb0),
3123-
(ie1 - ie0) * nb0);
3121+
((char *) dst->data + k0*nb0),
3122+
((char *) src0->data + k0*nb0),
3123+
(k1 - k0) * nb0);
31243124
}
31253125
}
31263126

@@ -4140,19 +4140,22 @@ static void ggml_compute_forward_dup_bytes(
41404140

41414141
// dst counters
41424142

4143-
int64_t i10 = 0;
4143+
int64_t k10 = 0;
41444144
int64_t i11 = 0;
41454145
int64_t i12 = 0;
41464146
int64_t i13 = 0;
41474147

4148+
assert(ne0 == ne00);
4149+
41484150
// number of blocks in a row
4149-
const int64_t nb = ne00/ggml_blck_size(src0->type);
4151+
const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
4152+
const int64_t nk0 = nk00;
41504153

41514154
for (int64_t i03 = 0; i03 < ne03; i03++) {
41524155
for (int64_t i02 = 0; i02 < ne02; i02++) {
4153-
i10 += nb * ir0;
4154-
while (i10 >= ne0) {
4155-
i10 -= ne0;
4156+
k10 += nk00 * ir0;
4157+
while (k10 >= nk0) {
4158+
k10 -= nk0;
41564159
if (++i11 == ne1) {
41574160
i11 = 0;
41584161
if (++i12 == ne2) {
@@ -4164,14 +4167,14 @@ static void ggml_compute_forward_dup_bytes(
41644167
}
41654168
}
41664169
for (int64_t i01 = ir0; i01 < ir1; i01++) {
4167-
for (int64_t i00 = 0; i00 < nb; i00++) {
4168-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4169-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4170+
for (int64_t k00 = 0; k00 < nk00; k00++) {
4171+
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4172+
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
41704173

41714174
memcpy(dst_ptr, src0_ptr, type_size);
41724175

4173-
if (++i10 == ne0) {
4174-
i10 = 0;
4176+
if (++k10 == nk0) {
4177+
k10 = 0;
41754178
if (++i11 == ne1) {
41764179
i11 = 0;
41774180
if (++i12 == ne2) {
@@ -4184,9 +4187,9 @@ static void ggml_compute_forward_dup_bytes(
41844187
}
41854188
}
41864189
}
4187-
i10 += nb * (ne01 - ir1);
4188-
while (i10 >= ne0) {
4189-
i10 -= ne0;
4190+
k10 += nk00 * (ne01 - ir1);
4191+
while (k10 >= nk0) {
4192+
k10 -= nk0;
41904193
if (++i11 == ne1) {
41914194
i11 = 0;
41924195
if (++i12 == ne2) {

tests/test-backend-ops.cpp

+22-13
Original file line numberDiff line numberDiff line change
@@ -1459,11 +1459,13 @@ struct test_cpy : public test_case {
14591459
const ggml_type type_src;
14601460
const ggml_type type_dst;
14611461
const std::array<int64_t, 4> ne;
1462-
const std::array<int64_t, 4> permute;
1462+
const std::array<int64_t, 4> permute_src;
1463+
const std::array<int64_t, 4> permute_dst;
14631464
bool _src_use_permute;
1465+
bool _dst_use_permute;
14641466

14651467
std::string vars() override {
1466-
return VARS_TO_STR4(type_src, type_dst, ne, permute);
1468+
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
14671469
}
14681470

14691471
double max_nmse_err() override {
@@ -1476,23 +1478,30 @@ struct test_cpy : public test_case {
14761478

14771479
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
14781480
std::array<int64_t, 4> ne = {10, 10, 10, 1},
1479-
std::array<int64_t, 4> permute = {0, 0, 0, 0})
1480-
: type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
1481-
_src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
1481+
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
1482+
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
1483+
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
1484+
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
1485+
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
14821486

14831487
ggml_tensor * build_graph(ggml_context * ctx) override {
14841488
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
14851489
ggml_set_param(ctx, src);
14861490
ggml_set_name(src, "src");
14871491

14881492
if (_src_use_permute) {
1489-
src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
1493+
src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
14901494
ggml_set_name(src, "src_permuted");
14911495
}
14921496

1493-
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
1497+
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
14941498
ggml_set_name(dst, "dst");
14951499

1500+
if (_dst_use_permute) {
1501+
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
1502+
ggml_set_name(dst, "dst_permuted");
1503+
}
1504+
14961505
ggml_tensor * out = ggml_cpy(ctx, src, dst);
14971506
ggml_set_name(out, "out");
14981507

@@ -3930,13 +3939,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39303939
}
39313940

39323941
// same-type copy
3933-
for (int nb = 1; nb < 4; ++nb) {
3934-
for (ggml_type type : all_types) {
3935-
const auto neb = ggml_blck_size(type);
3942+
for (ggml_type type : all_types) {
3943+
const auto nk = ggml_blck_size(type);
39363944

3937-
test_cases.emplace_back(new test_cpy(type, type, {nb*neb, 2, 3, 4}, {0, 1, 2, 3}));
3938-
test_cases.emplace_back(new test_cpy(type, type, {nb*neb, 2, 3, 4}, {0, 2, 1, 3}));
3939-
test_cases.emplace_back(new test_cpy(type, type, {nb*neb, 2, 3, 4}, {0, 3, 1, 2}));
3945+
for (int k = 1; k < 4; ++k) {
3946+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
3947+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
3948+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
39403949
}
39413950
}
39423951

0 commit comments

Comments
 (0)