3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -176,6 +176,9 @@
# Whether to enable dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to overlap exponential random generation with model execution.
"VLLM_ASCEND_ENABLE_ASYNC_EXPONENTIAL":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_ASYNC_EXPONENTIAL", '0'))),
}

# end-env-vars-definition
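For context, the new flag is parsed with bool(int(...)), so "1"/"0" are the expected values, and the overlap path is only set up when the existing VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION flag is also on (see the check added in model_runner_v1.py below). A minimal sketch of switching it on, assuming the flags are set before the engine process reads envs.py:

import os

# Sketch: both flags must be truthy for the side stream to be created.
os.environ["VLLM_ASCEND_ENABLE_ASYNC_EXPONENTIAL"] = "1"
os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "1"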
14 changes: 14 additions & 0 deletions vllm_ascend/sample/sampler.py
@@ -15,9 +15,16 @@ def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
super().__init__(logprobs_mode=logprobs_mode)
self.topk_topp_sampler = AscendTopKTopPSampler()

def set_q_event(self, q, event):
self.topk_topp_sampler.set_q_event(q, event)


class AscendTopKTopPSampler(TopKTopPSampler):

def set_q_event(self, q, event):
self.q = q
self.event = event

def _apply_top_k_top_p(
self,
logits: torch.Tensor,
@@ -72,4 +79,11 @@ def forward_native(self, logits, generators, k, p):
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

probs = logits.softmax(dim=-1, dtype=torch.float32)
if getattr(self, "q", None) is not None:
return self.random_sample(probs, self.q,
self.event), logits_to_return
return random_sample(probs, generators), logits_to_return

def random_sample(self, probs, q, event):
event.synchronize()
return probs.div_(q).argmax(dim=-1).view(-1)
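The probs.div_(q).argmax(dim=-1) rule is the exponential-race formulation of categorical sampling: with q drawn i.i.d. from Exp(1), argmax(probs / q) equals argmin(q / probs), and the minimum of independent exponentials with rates probs[i] falls on index i with probability probs[i]. A small standalone check of that equivalence (a sketch, not part of this diff):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty(3).exponential_()   # q ~ Exp(1), same noise as _do_async_exponential
    counts[probs.div(q).argmax()] += 1  # the sample rule used in random_sample above
print(counts / counts.sum())            # ~ tensor([0.1, 0.2, 0.7])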
27 changes: 27 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -625,6 +625,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

self.transfer_event = torch.npu.Event()

if envs_ascend.VLLM_ASCEND_ENABLE_ASYNC_EXPONENTIAL and envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
logger.info("Enabling async exponential generation overlapped with model execution.")
self._async_exponential_stream = torch.npu.Stream()
self._async_exponential_event = torch.npu.Event()

def _set_up_drafter(self):
# Set up speculative decoding.
self.spec_attn_mask = None
@@ -2311,6 +2316,11 @@ def execute_model(
aclgraph_runtime_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)

if envs_ascend.VLLM_ASCEND_ENABLE_ASYNC_EXPONENTIAL and envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
default_stream = torch.npu.current_stream()
self._do_async_exponential(default_stream=default_stream,
logits_indices=logits_indices)

# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
with set_ascend_forward_context(
@@ -4458,3 +4468,20 @@ def _generate_pcp_mtp_input(
self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)

def _do_async_exponential(self, default_stream, logits_indices):
# Generate the exponential randoms for sampling on a separate stream
# so the work overlaps with model execution.
with torch.npu.stream(self._async_exponential_stream):
self._async_exponential_stream.wait_stream(default_stream)
b_s = logits_indices.shape[0]
head_dim = self.model_config.get_vocab_size()
q = torch.empty((b_s, head_dim), device="npu", dtype=torch.float32)
generators = self.input_batch.sampling_metadata.generators
if len(generators) != q.shape[0]:
q.exponential_()
if generators:
Comment on lines +4481 to +4483

Collaborator: what if both len(generators) != q.shape[0] and generators are True?

Author: If both len(generators) != q.shape[0] and generators is non-empty, we just do q.exponential_() first, then overwrite each q[i] with q[i].exponential_(generator=generator). This simply re-uses the same logic as vLLM's random_sample. Hope this information is helpful!
for i, generator in generators.items():
q[i].exponential_(generator=generator)
self._async_exponential_event.record()
self.sampler.set_q_event(q, self._async_exponential_event)
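The overlap itself is the standard side-stream pattern: the helper stream waits on the default stream, fills q while the forward pass runs, and records an event that the sampler synchronizes on before dividing probs by q. A condensed sketch of that pattern, illustrated here with torch.cuda (the torch.npu Stream/Event API used above mirrors it); names like precompute_exponential are hypothetical:

import torch

def precompute_exponential(num_logits, vocab_size, side_stream, event, generators=None):
    # Pre-generate Exp(1) noise for sampling on a helper stream so it runs
    # concurrently with the model forward pass on the default stream.
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(side_stream):
        # Make sure the helper stream sees everything already enqueued.
        side_stream.wait_stream(default_stream)
        q = torch.empty((num_logits, vocab_size), device="cuda")
        q.exponential_()                      # bulk fill for seedless requests
        for i, gen in (generators or {}).items():
            q[i].exponential_(generator=gen)  # seeded per-request override
        event.record()                        # sampler synchronizes on this
    return q

The PR version additionally skips the bulk fill when every request carries its own generator, i.e. when len(generators) == q.shape[0]. Because AscendTopKTopPSampler.random_sample only calls event.synchronize() right before probs.div_(q), the host only blocks at that point if the noise generation has not finished yet.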