diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py
index b7c07d17a..b8b535647 100644
--- a/test/benchmark/static_inference/model_infer.py
+++ b/test/benchmark/static_inference/model_infer.py
@@ -185,10 +185,14 @@ def prefill(
     b_ready_cache_len,
 ):
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
     model_input = ModelInput(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=max_len_in_batch,
+        max_kv_seq_len=max_len_in_batch,
+        max_cache_len=0,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
@@ -196,6 +200,8 @@ def prefill(
         mem_indexes_cpu=mem_indexes,
         is_prefill=True,
         b_ready_cache_len=b_ready_cache_len,  # b_ready_cache_len
+        b_prefill_start_loc=b_prefill_start_loc,
+        prefix_total_token_num=0,  # the default kvcache len is zero.
     )

     model_output = model_part.forward(model_input)
@@ -209,6 +215,8 @@ def decode(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=1,
+        max_kv_seq_len=max_len_in_batch,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
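
For reference, the new b_prefill_start_loc value is the exclusive prefix sum of b_seq_len, i.e. the offset at which each request's prompt tokens start in the packed input_ids tensor. A minimal sketch of that computation (the example lengths below are illustrative and not part of the patch):

    # Sketch only: shows what b_prefill_start_loc encodes for a packed prefill batch.
    import torch

    # Assumed example: per-request prompt lengths in the batch.
    b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

    # Exclusive prefix sum: start offset of each request's tokens in input_ids.
    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
    print(b_prefill_start_loc)  # tensor([0, 3, 8], dtype=torch.int32)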