@@ -346,6 +346,20 @@ void WorkerService::ExecuteModel(
346346 batched_fwd_inputs.micro_inputs [i].sampling_params );
347347 }
348348
349+ // Concatenate the per-micro-batch acc_logprob tensors so beam search sees them together.
350+ if (micro_batches_num > 1 ) {
351+ std::vector<torch::Tensor> acc_logprob_vec;
352+ acc_logprob_vec.reserve (micro_batches_num);
353+ for (auto i = 0 ; i < micro_batches_num; ++i) {
354+ acc_logprob_vec.push_back (
355+ batched_fwd_inputs.micro_inputs [i].acc_logprob );
356+ }
357+ batched_fwd_inputs.acc_logprob = torch::cat (acc_logprob_vec, /* dim=*/ -1 );
358+ } else {
359+ batched_fwd_inputs.acc_logprob =
360+ batched_fwd_inputs.micro_inputs [0 ].acc_logprob ;
361+ }
362+
349363 // model output
350364 torch::Tensor next_tokens;
351365 torch::Tensor logprobs;
@@ -354,6 +368,10 @@ void WorkerService::ExecuteModel(
354368 torch::Tensor embeddings;
355369 torch::Tensor expert_load_data;
356370 int32_t prepared_layer_id = -1 ;
371+ // beam search kernel output
372+ torch::Tensor src_seq_idxes;
373+ torch::Tensor out_tokens;
374+ torch::Tensor out_logprobs;
357375
358376 // execute model
359377 auto future = worker_->step_async (batched_fwd_inputs);
@@ -364,6 +382,8 @@ void WorkerService::ExecuteModel(
364382 if (forward_outputs) {
365383 DCHECK (forward_outputs.has_value ()) << " Failed to execute model" ;
366384 const auto & sample_output = forward_outputs.value ().sample_output ;
385+ const auto & beam_search_output =
386+ forward_outputs.value ().beam_search_output ;
367387 expert_load_data = safe_to (
368388 forward_outputs.value ().expert_load_data , torch::kCPU , true );
369389 prepared_layer_id = forward_outputs.value ().prepared_layer_id ;
@@ -382,11 +402,32 @@ void WorkerService::ExecuteModel(
382402 if (next_tokens.defined ()) {
383403 // [num_seq]
384404 logprobs = safe_to (sample_output.logprobs , torch::kCPU , true );
385- // [num_seq, topk]
386- top_tokens = safe_to (sample_output.top_tokens , torch::kCPU , true );
387- // [num_seq, topk]
388- top_logprobs =
389- safe_to (sample_output.top_logprobs , torch::kCPU , true );
405+
406+ if (!beam_search_output.src_seq_idxes .defined ()) {
407+            // Only populate top_tokens/top_logprobs when beam search is not
408+            // active (src_seq_idxes undefined); when the beam search kernel
409+            // runs, it returns final tokens/logprobs via beam_search_output.
410+ // [num_seq, topk]
411+ top_tokens = safe_to (sample_output.top_tokens , torch::kCPU , true );
412+ // [num_seq, topk]
413+ top_logprobs =
414+ safe_to (sample_output.top_logprobs , torch::kCPU , true );
415+ }
416+ }
417+
418+ // beam search output
419+ // [num_seq]
420+ src_seq_idxes =
421+ safe_to (beam_search_output.src_seq_idxes , torch::kCPU , true );
422+ if (src_seq_idxes.defined ()) {
423+ // [num_seq]
424+ out_tokens =
425+ safe_to (beam_search_output.out_tokens , torch::kCPU , true );
426+ // [num_seq]
427+ out_logprobs =
428+ safe_to (beam_search_output.out_logprobs ,
429+ torch::dtype (torch::kFloat32 ).device (torch::kCPU ),
430+ true );
390431 }
391432 auto ret = stream_->synchronize ();
392433 }
@@ -419,6 +460,9 @@ void WorkerService::ExecuteModel(
419460 embeddings,
420461 expert_load_data,
421462 prepared_layer_id,
463+ src_seq_idxes,
464+ out_tokens,
465+ out_logprobs,
422466 pb_forward_output);
423467 COUNTER_ADD (worker_service_latency_seconds, timer.elapsed_seconds ());
424468 });
@@ -441,6 +485,8 @@ void WorkerService::GetLastStepResult(
441485 const auto & expert_load_data = safe_to (
442486 forward_outputs.value ().expert_load_data , torch::kCPU , true );
443487 int32_t prepared_layer_id = forward_outputs.value ().prepared_layer_id ;
488+ const auto & beam_search_output =
489+ forward_outputs.value ().beam_search_output ;
444490 c10::StreamGuard streamGuard = stream_->set_stream_guard ();
445491 // [num_seq, ..., embed_dim]
446492 auto embeddings =
@@ -460,6 +506,17 @@ void WorkerService::GetLastStepResult(
460506 // [num_seq, topk]
461507 const auto & top_logprobs =
462508 safe_to (sample_output.top_logprobs , torch::kCPU , true );
509+ // [num_seq]
510+ const auto & src_seq_idxes =
511+ safe_to (beam_search_output.src_seq_idxes , torch::kCPU , true );
512+ // [num_seq]
513+ const auto & out_tokens =
514+ safe_to (beam_search_output.out_tokens , torch::kCPU , true );
515+ // [num_seq]
516+ const auto & out_logprobs =
517+ safe_to (beam_search_output.out_logprobs ,
518+ torch::dtype (torch::kFloat32 ).device (torch::kCPU ),
519+ true );
463520 auto ret = stream_->synchronize ();
464521
465522 forward_output_to_proto (next_tokens,
@@ -469,6 +526,9 @@ void WorkerService::GetLastStepResult(
469526 embeddings,
470527 expert_load_data,
471528 prepared_layer_id,
529+ src_seq_idxes,
530+ out_tokens,
531+ out_logprobs,
472532 pb_forward_output);
473533 }
474534 }
0 commit comments