@@ -50,13 +50,14 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
 
   if (!options_.sliding_window) {
     // Standard causal mask
-#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
+#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
     const __m256 mask_val = _mm256_set1_ps(-1e10f);
     for (size_t r = 0; r < S; ++r) {
+      const size_t row_offset = r * D;
       const size_t copy_count = D - S + r + 1;
       const size_t fill_count = std::max(D - copy_count, (size_t)0);
 
-      memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float));
+      memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));
 
       float* fill_start = o_ptr + row_offset + copy_count;
       size_t avx_iters = fill_count / 8;
@@ -81,6 +82,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       for (size_t i = 0; i < neon_iters; ++i) { vst1q_f32(fill_start + i * 4, mask_val); }
       for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 4 + i] = -1e10f; }
     }
+#else
+    for (size_t r = 0; r < S; ++r) {
+      const size_t row_offset = r * D;
+      const size_t copy_count = D - S + r + 1;
+      const size_t fill_count = std::max(D - copy_count, (size_t)0);
+
+      memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));
+
+      float* fill_start = o_ptr + row_offset + copy_count;
+      for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -1e10f; }
+    }
 #endif
   } else {
     // Sliding window causal mask
@@ -98,7 +110,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       const size_t suffix_fill_start_idx = s + 1;
       const size_t suffix_fill_count = S - suffix_fill_start_idx;
 
-#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
+#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
       const __m256 mask_val = _mm256_set1_ps(-1e10f);
       // Fill prefix
       float* prefix_fill_start = o_ptr + row_offset;
@@ -118,6 +130,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
       for (size_t i = 0; i < suffix_fill_count / 4; ++i) vst1q_f32(suffix_fill_start + i * 4, mask_val);
       for (size_t i = (suffix_fill_count / 4) * 4; i < suffix_fill_count; ++i) suffix_fill_start[i] = -1e10f;
+#else
+      float* prefix_fill_start = o_ptr + row_offset;
+      for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -1e10f; }
+      float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
+      for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -1e10f; }
 #endif
     }
   }
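For reference, the FP32 hunks above boil down to the following self-contained scalar routine, which can be used to cross-check the AVX2/NEON paths. This is a sketch only: the free-function wrapper and the explicit S/D parameters mirror the names in the diff (row-major S x D buffers, D >= S) but are not part of the op itself.

#include <cstddef>
#include <cstring>

// Standard causal mask, scalar reference: copy the allowed prefix of each
// row and overwrite the masked-out tail with a large negative value.
static void causal_mask_fp32_reference(const float* i_ptr, float* o_ptr,
                                       std::size_t S, std::size_t D) {
  for (std::size_t r = 0; r < S; ++r) {
    const std::size_t row_offset = r * D;
    // Row r may attend to the first D - S + r + 1 key positions.
    const std::size_t copy_count = D - S + r + 1;
    const std::size_t fill_count = D - copy_count;  // == S - r - 1

    std::memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));

    float* fill_start = o_ptr + row_offset + copy_count;
    for (std::size_t i = 0; i < fill_count; ++i) { fill_start[i] = -1e10f; }
  }
}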
@@ -143,7 +160,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
 
   if (!options_.sliding_window) {
     // Standard causal mask
-#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
+#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
     const __m256 mask_ps = _mm256_set1_ps(-65500.f);
     const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);
     for (size_t s = 0; s < S; ++s) {
@@ -178,6 +195,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       for (size_t i = 0; i < neon_iters; ++i) { vst1q_f16(fill_start + i * 8, mask_val); }
       for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 8 + i] = -65500.f; }
     }
+#else
+    for (size_t s = 0; s < S; ++s) {
+      const size_t row_offset = s * S;
+      const size_t copy_count = s + 1;
+      const size_t fill_count = S - copy_count;
+
+      if (copy_count > 0) { memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(mllm_fp16_t)); }
+
+      mllm_fp16_t* fill_start = o_ptr + row_offset + copy_count;
+      for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -65500.f; }
+    }
 #endif
   } else {
     // Sliding window causal mask
@@ -196,7 +224,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       const size_t suffix_fill_start_idx = s + 1;
       const size_t suffix_fill_count = S - suffix_fill_start_idx;
 
-#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
+#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
       const __m256 mask_ps = _mm256_set1_ps(-65500.f);
       const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);
 
@@ -222,6 +250,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
       mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
       for (size_t i = 0; i < suffix_fill_count / 8; ++i) vst1q_f16(suffix_fill_start + i * 8, mask_val);
       for (size_t i = (suffix_fill_count / 8) * 8; i < suffix_fill_count; ++i) suffix_fill_start[i] = -65500.f;
+#else
+      mllm_fp16_t* prefix_fill_start = o_ptr + row_offset;
+      for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -65500.f; }
+      mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
+      for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -65500.f; }
 #endif
     }
   }
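The added parentheses in the #if lines are the substantive fix here: in preprocessor expressions && binds tighter than ||, so the old condition enabled the intrinsics path on any x86 build even when __AVX2__ (or __F16C__) was absent. A minimal sketch of the difference, with illustrative macro names rather than the real ones:

// && binds tighter than || in #if expressions, exactly as in C++.
// HAVE_X86 stands in for the architecture macros, HAVE_AVX2 for __AVX2__.
#define HAVE_X86 1
// HAVE_AVX2 deliberately left undefined.

#if defined(HAVE_X86) || defined(HAVE_OTHER_X86) && defined(HAVE_AVX2)
// Old form: parsed as X86 || (OTHER_X86 && AVX2), so this branch is taken
// even though HAVE_AVX2 is missing and the intrinsics would be compiled.
#endif

#if (defined(HAVE_X86) || defined(HAVE_OTHER_X86)) && defined(HAVE_AVX2)
// Fixed form: AVX2 is required on top of either architecture macro, so this
// branch is skipped and the new scalar #else fallback is used instead.
#endif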