Skip to content

Commit 8bbe6f6

Browse files
authored
feat: support beam search kernel for decode-only model. (#269)
1 parent 018af48 commit 8bbe6f6

25 files changed

+537
-13
lines changed

third_party/xllm_ops

Submodule xllm_ops updated from a8e8b15 to 2cda9bf

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,4 +377,9 @@ DEFINE_int64(cache_size_per_token,
377377

378378
DEFINE_int64(buffer_size_per_seq,
379379
0,
380-
"Buffer size per sequence in bytes, default 0.");
380+
"Buffer size per sequence in bytes, default 0.");
381+
382+
// --- beam search config ---
// Switch for the beam search kernel path (off by default). When set,
// BatchInputBuilder::process_single_sequence collects an accumulated logprob
// for beam-search sequences that have already generated tokens, as input for
// the kernel.
383+
DEFINE_bool(enable_beam_search_kernel,
384+
false,
385+
"Whether to enable beam search kernel.");

xllm/core/common/global_flags.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,6 @@ DECLARE_int64(granularity_size);
197197

198198
DECLARE_int64(cache_size_per_token);
199199

200-
DECLARE_int64(buffer_size_per_seq);
200+
DECLARE_int64(buffer_size_per_seq);
201+
202+
// Beam search kernel switch; the flag itself is defined in global_flags.cpp.
DECLARE_bool(enable_beam_search_kernel);

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,20 @@ void WorkerService::ExecuteModel(
346346
batched_fwd_inputs.micro_inputs[i].sampling_params);
347347
}
348348

349+
// concat acc_logprob here for beam search together
350+
if (micro_batches_num > 1) {
351+
std::vector<torch::Tensor> acc_logprob_vec;
352+
acc_logprob_vec.reserve(micro_batches_num);
353+
for (auto i = 0; i < micro_batches_num; ++i) {
354+
acc_logprob_vec.push_back(
355+
batched_fwd_inputs.micro_inputs[i].acc_logprob);
356+
}
357+
batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1);
358+
} else {
359+
batched_fwd_inputs.acc_logprob =
360+
batched_fwd_inputs.micro_inputs[0].acc_logprob;
361+
}
362+
349363
// model output
350364
torch::Tensor next_tokens;
351365
torch::Tensor logprobs;
@@ -354,6 +368,10 @@ void WorkerService::ExecuteModel(
354368
torch::Tensor embeddings;
355369
torch::Tensor expert_load_data;
356370
int32_t prepared_layer_id = -1;
371+
// beam search kernel output
372+
torch::Tensor src_seq_idxes;
373+
torch::Tensor out_tokens;
374+
torch::Tensor out_logprobs;
357375

358376
// execute model
359377
auto future = worker_->step_async(batched_fwd_inputs);
@@ -364,6 +382,8 @@ void WorkerService::ExecuteModel(
364382
if (forward_outputs) {
365383
DCHECK(forward_outputs.has_value()) << "Failed to execute model";
366384
const auto& sample_output = forward_outputs.value().sample_output;
385+
const auto& beam_search_output =
386+
forward_outputs.value().beam_search_output;
367387
expert_load_data = safe_to(
368388
forward_outputs.value().expert_load_data, torch::kCPU, true);
369389
prepared_layer_id = forward_outputs.value().prepared_layer_id;
@@ -382,11 +402,32 @@ void WorkerService::ExecuteModel(
382402
if (next_tokens.defined()) {
383403
// [num_seq]
384404
logprobs = safe_to(sample_output.logprobs, torch::kCPU, true);
385-
// [num_seq, topk]
386-
top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true);
387-
// [num_seq, topk]
388-
top_logprobs =
389-
safe_to(sample_output.top_logprobs, torch::kCPU, true);
405+
406+
if (!beam_search_output.src_seq_idxes.defined()) {
407+
// beam search kernel will provide final tokens/logprobs in beam
408+
// search output, so keep top_tokens/top_logprobs undefined to
409+
// avoid returning them.
410+
// [num_seq, topk]
411+
top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true);
412+
// [num_seq, topk]
413+
top_logprobs =
414+
safe_to(sample_output.top_logprobs, torch::kCPU, true);
415+
}
416+
}
417+
418+
// beam search output
419+
// [num_seq]
420+
src_seq_idxes =
421+
safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true);
422+
if (src_seq_idxes.defined()) {
423+
// [num_seq]
424+
out_tokens =
425+
safe_to(beam_search_output.out_tokens, torch::kCPU, true);
426+
// [num_seq]
427+
out_logprobs =
428+
safe_to(beam_search_output.out_logprobs,
429+
torch::dtype(torch::kFloat32).device(torch::kCPU),
430+
true);
390431
}
391432
auto ret = stream_->synchronize();
392433
}
@@ -419,6 +460,9 @@ void WorkerService::ExecuteModel(
419460
embeddings,
420461
expert_load_data,
421462
prepared_layer_id,
463+
src_seq_idxes,
464+
out_tokens,
465+
out_logprobs,
422466
pb_forward_output);
423467
COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
424468
});
@@ -441,6 +485,8 @@ void WorkerService::GetLastStepResult(
441485
const auto& expert_load_data = safe_to(
442486
forward_outputs.value().expert_load_data, torch::kCPU, true);
443487
int32_t prepared_layer_id = forward_outputs.value().prepared_layer_id;
488+
const auto& beam_search_output =
489+
forward_outputs.value().beam_search_output;
444490
c10::StreamGuard streamGuard = stream_->set_stream_guard();
445491
// [num_seq, ..., embed_dim]
446492
auto embeddings =
@@ -460,6 +506,17 @@ void WorkerService::GetLastStepResult(
460506
// [num_seq, topk]
461507
const auto& top_logprobs =
462508
safe_to(sample_output.top_logprobs, torch::kCPU, true);
509+
// [num_seq]
510+
const auto& src_seq_idxes =
511+
safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true);
512+
// [num_seq]
513+
const auto& out_tokens =
514+
safe_to(beam_search_output.out_tokens, torch::kCPU, true);
515+
// [num_seq]
516+
const auto& out_logprobs =
517+
safe_to(beam_search_output.out_logprobs,
518+
torch::dtype(torch::kFloat32).device(torch::kCPU),
519+
true);
463520
auto ret = stream_->synchronize();
464521

465522
forward_output_to_proto(next_tokens,
@@ -469,6 +526,9 @@ void WorkerService::GetLastStepResult(
469526
embeddings,
470527
expert_load_data,
471528
prepared_layer_id,
529+
src_seq_idxes,
530+
out_tokens,
531+
out_logprobs,
472532
pb_forward_output);
473533
}
474534
}

xllm/core/framework/batch/batch.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,78 @@ void Batch::process_beam_search() {
264264
sequence_group->process_beam_search();
265265
}
266266
}
267+
268+
void Batch::process_beam_search_output(const RawForwardOutput& raw_output,
269+
bool replace_fake_token) {
270+
const int32_t beam_width = sequences_[0]->sampling_param()->beam_width;
271+
if (beam_width <= 1) {
272+
return;
273+
}
274+
275+
CHECK_EQ(raw_output.src_seq_idxes.size(), sequences_.size());
276+
CHECK_EQ(raw_output.out_tokens.size(), sequences_.size());
277+
CHECK_EQ(raw_output.out_logprobs.size(), sequences_.size());
278+
279+
auto update_for_sequence_group = [&](size_t sequence_group_id) {
280+
std::unordered_set<int32_t> seq_idx_set;
281+
std::vector<float> src_acc_logprob_vec;
282+
std::vector<std::vector<int32_t>> src_token_ids;
283+
std::vector<std::vector<std::optional<float>>> src_logprobs;
284+
src_acc_logprob_vec.resize(beam_width);
285+
src_token_ids.resize(beam_width);
286+
src_logprobs.resize(beam_width);
287+
288+
for (size_t i = 0; i < beam_width; i++) {
289+
size_t task_id = sequence_group_id * beam_width + i;
290+
int32_t src_seq_idx = raw_output.src_seq_idxes[task_id];
291+
CHECK_LE(src_seq_idx, sequences_.size());
292+
auto src_seq = sequences_[src_seq_idx];
293+
src_acc_logprob_vec[i] =
294+
src_seq->get_average_logprob() * src_seq->num_generated_tokens();
295+
src_token_ids[i] = std::vector<int32_t>(src_seq->tokens());
296+
src_logprobs[i] = src_seq->logprob_state()->get_logprobs();
297+
}
298+
299+
for (size_t i = 0; i < beam_width; i++) {
300+
size_t task_id = sequence_group_id * beam_width + i;
301+
int32_t src_seq_idx = raw_output.src_seq_idxes[task_id];
302+
CHECK_LE(src_seq_idx, sequences_.size());
303+
auto& base_seq = sequences_[task_id];
304+
auto& src_seq = sequences_[src_seq_idx];
305+
306+
for (size_t token_idx = base_seq->num_prompt_tokens();
307+
token_idx < base_seq->num_tokens();
308+
token_idx++) {
309+
Token new_token(src_token_ids[i][token_idx]);
310+
new_token.logprob = src_logprobs[i][token_idx];
311+
base_seq->update_token(token_idx, new_token);
312+
}
313+
314+
Token new_token(raw_output.out_tokens[task_id]);
315+
new_token.logprob =
316+
raw_output.out_logprobs[task_id] - src_acc_logprob_vec[i];
317+
append_token_for_sequence(base_seq, new_token, 0, replace_fake_token);
318+
319+
base_seq->logprob_state()->set_acc_logprob(
320+
raw_output.out_logprobs[task_id]);
321+
base_seq->logprob_state()->set_last_acc_token_idx(base_seq->num_tokens());
322+
323+
bool need_swap = false;
324+
if (seq_idx_set.find(src_seq_idx) != seq_idx_set.end()) {
325+
need_swap = true;
326+
} else {
327+
seq_idx_set.insert(src_seq_idx);
328+
}
329+
330+
auto src_blocks = src_seq->kv_state().kv_blocks();
331+
base_seq->kv_state().set_src_blocks(src_blocks, need_swap);
332+
}
333+
};
334+
335+
for (size_t sequence_group_id = 0;
336+
sequence_group_id < sequence_groups_.size();
337+
sequence_group_id++) {
338+
update_for_sequence_group(sequence_group_id);
339+
}
340+
}
267341
} // namespace xllm

xllm/core/framework/batch/batch.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class Batch {
9494
void process_sample_output(const RawForwardOutput& raw_output,
9595
bool replace_fake_token);
9696

97+
// process output for beam search kernel
98+
void process_beam_search_output(const RawForwardOutput& raw_output,
99+
bool replace_fake_token);
100+
97101
// process the accepted output embedding
98102
void process_embedding_output(const torch::Tensor& embedding);
99103

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ BatchInputBuilder::BatchInputBuilder(
7878
state_.flatten_positions_vec.reserve(1000);
7979
state_.mrope_positions_vec.reserve(sequences.size());
8080
state_.block_tables_vec.reserve(sequences.size());
81+
state_.acc_logprob_vec.reserve(sequences.size());
8182
if (args_ != nullptr) {
8283
use_mrope_ = (args_->rope_scaling_rope_type() == "mrope");
8384
}
@@ -179,6 +180,9 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
179180
state_.block_tables_vec.insert(state_.block_tables_vec.end(),
180181
state.block_tables_vec.begin(),
181182
state.block_tables_vec.end());
183+
state_.acc_logprob_vec.insert(state_.acc_logprob_vec.end(),
184+
state.acc_logprob_vec.begin(),
185+
state.acc_logprob_vec.end());
182186
// selected_token_idxes and sample_idxes need offset
183187
int32_t selected_token_idxes_offset =
184188
static_cast<int32_t>(state_.flatten_tokens_vec.size()) -
@@ -308,6 +312,13 @@ void BatchInputBuilder::process_single_sequence(
308312
if (sequence->is_prefill_stage()) {
309313
state.prefill_seq_len++;
310314
}
315+
316+
// Input for beam search kernel
317+
if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
318+
sequence->num_generated_tokens() > 0) {
319+
state.acc_logprob_vec.push_back(sequence->get_average_logprob() *
320+
sequence->num_generated_tokens());
321+
}
311322
}
312323

313324
void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
@@ -625,6 +636,10 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
625636

626637
raw_forward_input.embedding_ids = std::move(state_.embedding_ids);
627638
raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids);
639+
// beam search kernel input
640+
if (state_.acc_logprob_vec.size() > 0) {
641+
raw_forward_input.acc_logprob_vec = std::move(state_.acc_logprob_vec);
642+
}
628643

629644
if (FLAGS_enable_continuous_kvcache) {
630645
raw_forward_input.new_cache_slot_offsets =

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class BatchInputBuilder {
9595
std::vector<int32_t> new_token_slot_ids;
9696
std::vector<std::vector<int32_t>> block_tables_vec;
9797

98+
// beam search kernel input
99+
std::vector<float> acc_logprob_vec;
100+
98101
// Additional data
99102
std::vector<int32_t> embedding_ids;
100103
std::vector<int32_t> extra_token_ids;

xllm/core/framework/sampling/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ cc_library(
99
logits_utils.h
1010
rejection_sampler.h
1111
sampler.h
12+
beam_searcher.h
1213
SRCS
1314
sampling_params.cpp
1415
logits_utils.cpp
1516
rejection_sampler.cpp
1617
sampler.cpp
18+
beam_searcher.cpp
1719
DEPS
1820
glog::glog
1921
torch
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "beam_searcher.h"
17+
18+
namespace xllm {
19+
// Runs the beam search kernel (NPU builds only).
//
// Allocates three [num_seq, 1] output tensors using the options of `logprobs`
// (int32 for tokens and source indices, float32 for logprobs), where num_seq
// is logprobs.numel(), lets xllm_ops::beam_search fill them, and returns them
// flattened to 1-D.
//
// @param logprobs      passed to the kernel reshaped to [-1, 1].
// @param top_tokens    passed to the kernel converted to int32.
// @param top_logprobs  passed to the kernel unchanged.
// @return              src_seq_idxes / out_logprobs / out_tokens, each 1-D.
//
// On non-NPU builds this logs a fatal error.
BeamSearchOutput BeamSearcher::forward(
    const torch::Tensor& logprobs,
    const torch::Tensor& top_tokens,
    const torch::Tensor& top_logprobs) const {
#if defined(USE_NPU)
  const int64_t num_seq = logprobs.numel();
  const auto int32_options = logprobs.options().dtype(torch::kInt32);
  const auto float32_options = logprobs.options().dtype(torch::kFloat32);

  BeamSearchOutput output;
  output.out_tokens = torch::empty({num_seq, 1}, int32_options);
  output.out_logprobs = torch::empty({num_seq, 1}, float32_options);
  output.src_seq_idxes = torch::empty({num_seq, 1}, int32_options);

  xllm_ops::beam_search(logprobs.reshape({-1, 1}),
                        top_tokens.to(torch::kInt32),
                        top_logprobs,
                        output.src_seq_idxes,
                        output.out_logprobs,
                        output.out_tokens);

  // Callers consume the results as flat per-sequence vectors.
  output.src_seq_idxes = output.src_seq_idxes.reshape({-1});
  output.out_logprobs = output.out_logprobs.reshape({-1});
  output.out_tokens = output.out_tokens.reshape({-1});
  return output;
#else
  LOG(FATAL) << "BeamSearcher is only implemented for NPU backend.";
  // Not expected to be reached; keeps every path of this value-returning
  // function ending in a return statement.
  return BeamSearchOutput{};
#endif
}
47+
} // namespace xllm

0 commit comments

Comments
 (0)