Skip to content

Commit 93378ac

Browse files
authored
Merge pull request #538 from jialilve/feature/qwen-npu-decoding
test: fix CausalMaskOp CPU coverage
2 parents 9d081db + 8fa9644 commit 93378ac

File tree

5 files changed

+128
-10
lines changed

5 files changed

+128
-10
lines changed

mllm/backends/cpu/ops/CausalMaskOp.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,14 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
5050

5151
if (!options_.sliding_window) {
5252
// Standard causal mask
53-
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
53+
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
5454
const __m256 mask_val = _mm256_set1_ps(-1e10f);
5555
for (size_t r = 0; r < S; ++r) {
56+
const size_t row_offset = r * D;
5657
const size_t copy_count = D - S + r + 1;
5758
const size_t fill_count = std::max(D - copy_count, (size_t)0);
5859

59-
memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float));
60+
memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));
6061

6162
float* fill_start = o_ptr + row_offset + copy_count;
6263
size_t avx_iters = fill_count / 8;
@@ -81,6 +82,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
8182
for (size_t i = 0; i < neon_iters; ++i) { vst1q_f32(fill_start + i * 4, mask_val); }
8283
for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 4 + i] = -1e10f; }
8384
}
85+
#else
86+
for (size_t r = 0; r < S; ++r) {
87+
const size_t row_offset = r * D;
88+
const size_t copy_count = D - S + r + 1;
89+
const size_t fill_count = std::max(D - copy_count, (size_t)0);
90+
91+
memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));
92+
93+
float* fill_start = o_ptr + row_offset + copy_count;
94+
for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -1e10f; }
95+
}
8496
#endif
8597
} else {
8698
// Sliding window causal mask
@@ -98,7 +110,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
98110
const size_t suffix_fill_start_idx = s + 1;
99111
const size_t suffix_fill_count = S - suffix_fill_start_idx;
100112

101-
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
113+
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
102114
const __m256 mask_val = _mm256_set1_ps(-1e10f);
103115
// Fill prefix
104116
float* prefix_fill_start = o_ptr + row_offset;
@@ -118,6 +130,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
118130
float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
119131
for (size_t i = 0; i < suffix_fill_count / 4; ++i) vst1q_f32(suffix_fill_start + i * 4, mask_val);
120132
for (size_t i = (suffix_fill_count / 4) * 4; i < suffix_fill_count; ++i) suffix_fill_start[i] = -1e10f;
133+
#else
134+
float* prefix_fill_start = o_ptr + row_offset;
135+
for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -1e10f; }
136+
float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
137+
for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -1e10f; }
121138
#endif
122139
}
123140
}
@@ -143,7 +160,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
143160

144161
if (!options_.sliding_window) {
145162
// Standard causal mask
146-
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
163+
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
147164
const __m256 mask_ps = _mm256_set1_ps(-65500.f);
148165
const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);
149166
for (size_t s = 0; s < S; ++s) {
@@ -178,6 +195,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
178195
for (size_t i = 0; i < neon_iters; ++i) { vst1q_f16(fill_start + i * 8, mask_val); }
179196
for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 8 + i] = -65500.f; }
180197
}
198+
#else
199+
for (size_t s = 0; s < S; ++s) {
200+
const size_t row_offset = s * S;
201+
const size_t copy_count = s + 1;
202+
const size_t fill_count = S - copy_count;
203+
204+
if (copy_count > 0) { memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(mllm_fp16_t)); }
205+
206+
mllm_fp16_t* fill_start = o_ptr + row_offset + copy_count;
207+
for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -65500.f; }
208+
}
181209
#endif
182210
} else {
183211
// Sliding window causal mask
@@ -196,7 +224,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
196224
const size_t suffix_fill_start_idx = s + 1;
197225
const size_t suffix_fill_count = S - suffix_fill_start_idx;
198226

199-
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
227+
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
200228
const __m256 mask_ps = _mm256_set1_ps(-65500.f);
201229
const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);
202230

@@ -222,6 +250,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
222250
mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
223251
for (size_t i = 0; i < suffix_fill_count / 8; ++i) vst1q_f16(suffix_fill_start + i * 8, mask_val);
224252
for (size_t i = (suffix_fill_count / 8) * 8; i < suffix_fill_count; ++i) suffix_fill_start[i] = -65500.f;
253+
#else
254+
mllm_fp16_t* prefix_fill_start = o_ptr + row_offset;
255+
for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -65500.f; }
256+
mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
257+
for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -65500.f; }
225258
#endif
226259
}
227260
}

mllm/engine/HpcThreadPool.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,21 @@ void HpcThreadPool::splitTask(HpcThreadPoolTask&& task, int task_slot_idx) {
9696
// 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
9797
if (tiles_num > thread_cnt_) {
9898
tasks_[task_slot_idx].first = {
99-
.start = 0,
100-
.end = thread_cnt_,
101-
.step = 1,
10299
.func =
103100
[tiles_num, &task, &true_idx, this](int thread_idx) {
104101
for (int v = thread_idx; v < tiles_num; v += thread_cnt_) { task.func(true_idx[v]); }
105102
},
103+
.start = 0,
104+
.end = thread_cnt_,
105+
.step = 1,
106106
};
107107
tiles_num = thread_cnt_;
108108
} else {
109109
tasks_[task_slot_idx].first = {
110+
.func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); },
110111
.start = 0,
111112
.end = tiles_num,
112113
.step = 1,
113-
.func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); },
114114
};
115115
}
116116
{

mllm/models/minicpm_o2_6/streaming_generation.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class StreamingGenerator {
9696
config_(config) {
9797
// Configure chunk generation
9898
models::ChunkGenerationConfig chunk_config{
99-
.chunk_size = 5, .max_new_tokens = 10, .do_sample = false, .save_first_chunk_hidden_states = true};
99+
.max_new_tokens = 10, .chunk_size = 5, .do_sample = false, .save_first_chunk_hidden_states = true};
100100

101101
// Add EOS tokens for MiniCPMO
102102
auto eos_ids = tokenizer_.convert2Ids({L"<|im_end|>"});

tests/cpu/CausalMaskOpTest.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
3+
#include <algorithm>
4+
5+
#include "KernelTestHelper.hpp"
6+
#include "mllm/mllm.hpp"
7+
#include "mllm/nn/layers/CausalMask.hpp"
8+
9+
class CausalMaskOpTest : public KernelTest {
10+
public:
11+
void SetUp() override {
12+
KernelTest::SetUp();
13+
mask_.to(mllm::kCPU);
14+
}
15+
16+
mllm::test::AllCloseResult runScenario(int B, int H, int S, int D) {
17+
using namespace mllm; // NOLINT
18+
const int64_t total = static_cast<int64_t>(B) * H * S * D;
19+
auto input = Tensor::arange(0, static_cast<float>(total), 1, kFloat32, kCPU).view({B, H, S, D});
20+
auto output = mask_(input);
21+
auto expected = buildExpectedTensor(input);
22+
auto result = test::allClose(expected, output);
23+
if (!result) {
24+
mllm::print(result);
25+
mllm::print(expected);
26+
mllm::print(output);
27+
}
28+
return result;
29+
}
30+
31+
private:
32+
static mllm::Tensor buildExpectedTensor(const mllm::Tensor& input) {
33+
using namespace mllm; // NOLINT
34+
auto shape = input.shape();
35+
const int B = shape[0];
36+
const int H = shape[1];
37+
const int S = shape[2];
38+
const int D = shape[3];
39+
auto expected = Tensor::zeros(shape, kFloat32, kCPU);
40+
41+
const float* in_ptr = input.ptr<float>();
42+
float* exp_ptr = expected.ptr<float>();
43+
const int context_offset = std::max(0, D - S);
44+
const float mask_value = -1e10f;
45+
46+
for (int b = 0; b < B; ++b) {
47+
for (int h = 0; h < H; ++h) {
48+
for (int s = 0; s < S; ++s) {
49+
const int allowed = std::min(D, context_offset + s + 1);
50+
for (int d = 0; d < D; ++d) {
51+
const int64_t idx = (((static_cast<int64_t>(b) * H) + h) * S + s) * D + d;
52+
if (d < allowed) {
53+
exp_ptr[idx] = in_ptr[idx];
54+
} else {
55+
exp_ptr[idx] = mask_value;
56+
}
57+
}
58+
}
59+
}
60+
}
61+
return expected;
62+
}
63+
64+
mllm::nn::CausalMask mask_;
65+
};
66+

tests/cpu/KernelTest.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,25 @@ TEST_F(ElementwiseKernelTest, DivScalarInt32) {
533533
true);
534534
}
535535

536+
//===----------------------------------------------------------------------===//
537+
// CausalMaskOp
538+
//===----------------------------------------------------------------------===//
539+
#include "CausalMaskOpTest.hpp"
540+
TEST_F(CausalMaskOpTest, PrefillScenario) {
541+
auto result = runScenario(1, 1, 4, 4);
542+
EXPECT_TRUE(result.is_close);
543+
}
544+
545+
TEST_F(CausalMaskOpTest, DecodeScenario) {
546+
auto result = runScenario(1, 1, 1, 6);
547+
EXPECT_TRUE(result.is_close);
548+
}
549+
550+
TEST_F(CausalMaskOpTest, AppendScenario) {
551+
auto result = runScenario(2, 3, 3, 7);
552+
EXPECT_TRUE(result.is_close);
553+
}
554+
536555
//===----------------------------------------------------------------------===//
537556
// GELU
538557
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)