From 32180de7fddb1052bd4d19b79fe1bc21327df774 Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Mon, 12 May 2025 02:40:47 +0000 Subject: [PATCH 1/2] add qwen3 local chat --- ktransformers/local_chat.py | 10 +- ktransformers/models/modeling_qwen3_moe.py | 6 +- ktransformers/operators/experts.py | 106 ++++++- ktransformers/operators/models.py | 288 +++++++++++++++++- .../optimize_rules/Qwen3-30B-A3B.yaml | 83 +++++ setup.py | 24 +- test_speed.py | 138 +++++++++ 7 files changed, 638 insertions(+), 17 deletions(-) create mode 100644 ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml create mode 100644 test_speed.py diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 928de487..8c6f7718 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -25,6 +25,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -37,6 +38,7 @@ "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, + "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, } @@ -182,4 +184,10 @@ def local_chat( if __name__ == "__main__": - fire.Fire(local_chat) + # fire.Fire(local_chat) + local_chat( + model_path="/mnt/data/models/Qwen3-30B-A3B-250425/", + optimize_config_path="ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml", + gguf_path="/mnt/data/models/Qwen3-30B-A3B-GGUF/", + use_cuda_graph=False + ) diff --git a/ktransformers/models/modeling_qwen3_moe.py b/ktransformers/models/modeling_qwen3_moe.py index 175f88c6..100d75fd 100644 --- a/ktransformers/models/modeling_qwen3_moe.py +++ b/ktransformers/models/modeling_qwen3_moe.py @@ -185,9 +185,10 @@ def forward( hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - # **kwargs: Unpack[FlashAttentionKwargs], + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -196,7 +197,8 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = position_embeddings + # cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 34f0af09..163468ef 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -1494,4 +1494,108 @@ def moe_infer(self, x, topk_ids, topk_weight): .sum(dim=1) .type(new_x.dtype) ) - return final_out \ No newline at end of file + return final_out + +class 
KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + orig_shape = hidden_states.shape + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) + # shared_expert_output = self.shared_expert(hidden_states) + # shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + # y += shared_expert_output + y.resize_(*orig_shape) + return y, router_logits + + hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states.cpu() + selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu() + routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu() + + # shared_expert_output = self.shared_expert(hidden_states) + # shared_expert_output = ( + # F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + # ) + + if isinstance(self.experts, KExpertsBase): + y = ( + self.moe_kexperts( + hidden_states_expert, selected_experts_expert, routing_weights_expert + ) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + elif hidden_states_expert.size(0) > 10: + y = self.moe_infer( + hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape + ).to(device=hidden_states.device) + else: + y = self.moe_infer_simple( + hidden_states_expert, selected_experts_expert, routing_weights_expert + ).to(device=hidden_states.device) + # y += shared_expert_output + y.resize_(*orig_shape) + return y, router_logits + + @torch.no_grad() + def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = self.experts(x, topk_ids, topk_weight) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: + ''' + hidden_states_cpu: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + ''' + outs = torch.zeros_like(hidden_states_cpu) + for token_idx in range(selected_experts_cpu.size(0)): + for expert_idx in range(selected_experts_cpu.size(1)): + expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] + outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + + batch_size, 
sequence_length, hidden_dim = orig_shape + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype)) + + return final_hidden_states \ No newline at end of file diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index bbac29a3..f13387fd 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -65,7 +65,7 @@ LlamaRMSNorm, LlamaRotaryEmbedding, ) - +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRotaryEmbedding if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -1350,3 +1350,289 @@ def _update_causal_mask( ) return causal_mask + +class KQwen3MoeModel(BaseInjectedModule): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2MoeDecoderLayer`] + + Args: + config: Qwen2MoeConfig + """ + + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill + transfer_map: dict = None, + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, device, **kwargs + ) + self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold + self.transfer_map = transfer_map + self.stream_device_map = dict() + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + per_layer_prefill_intput_threshold: ( + int | None + ) = None, # if None or 0, close per-layer prefill + ) -> Union[Tuple, MoeModelOutputWithPast]: + # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}') + + if per_layer_prefill_intput_threshold is None: + per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold + per_layer_prefill_flag = False + seq_lenth = ( + inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) + ) + if ( + per_layer_prefill_intput_threshold + and per_layer_prefill_intput_threshold < seq_lenth + ): + per_layer_prefill_flag = True + for layer in self.layers: + self.load_layer_to(layer, InferenceState.UNLOAD) + else: + pass + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + # use_legacy_cache = False + # if use_cache and not isinstance(past_key_values, Cache): + # use_legacy_cache = True + # past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # logger.warning_once( + # "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + # "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + # ) + + if inputs_embeds is None: + input_ids = input_ids.to("cpu") + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.to("cuda") + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + # next_decoder_cache = None + + for i, decoder_layer in enumerate(self.layers): + # if self.transfer_map is not None and i in self.transfer_map: + # prev_stream = torch.cuda.current_stream() + # cur_device = self.transfer_map[i] + # if cur_device not in self.stream_device_map: + # self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + # torch.cuda.set_device(cur_device) + # self.stream_device_map[cur_device].wait_stream(prev_stream) + # torch.cuda.set_stream(self.stream_device_map[cur_device]) + # hidden_states = hidden_states.to( + # self.transfer_map[i], non_blocking=True + # ) + # causal_mask = ( + # causal_mask.to(self.transfer_map[i], non_blocking=True) + # if causal_mask is not None + # else None + # ) + # position_ids = ( + # position_ids.to(self.transfer_map[i], non_blocking=True) + # if position_ids is not None + # else None + # ) + # cache_position = ( + # cache_position.to(self.transfer_map[i], non_blocking=True) + # if cache_position is not None + # else None + # ) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + # position_embeddings, + ) + else: + if per_layer_prefill_flag: + # print(f"to gpu") + self.load_layer_to(decoder_layer, InferenceState.PREFILL) + torch.cuda.empty_cache() + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + # position_embeddings=position_embeddings, + ) + if per_layer_prefill_flag: + # print(f"to cpu") + self.load_layer_to(decoder_layer, InferenceState.UNLOAD) + torch.cuda.empty_cache() + hidden_states = layer_outputs[0] + # use_cache=False + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if per_layer_prefill_flag: + per_layer_prefill_flag = False + for layer in self.layers: + 
self.load_layer_to(layer, InferenceState.GENERATE) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # next_cache = None + # if use_cache: + # next_cache = ( + # next_decoder_cache.to_legacy_cache() + # if use_legacy_cache + # else next_decoder_cache + # ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState): + assert isinstance( + layer, Qwen2MoeDecoderLayer + ), "module should be nn.ModuleList of decoder layers" + + # TODO Support restore to original device, not only cuda + device = "cpu" if target == InferenceState.UNLOAD else "cuda" + + # attn + layer.self_attn.q_proj.set_inference_mode(target) + layer.self_attn.k_proj.set_inference_mode(target) + layer.self_attn.v_proj.set_inference_mode(target) + layer.self_attn.o_proj.set_inference_mode(target) + layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) + + # mlp + if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock): + layer.mlp.gate.set_inference_mode(target) + layer.mlp.experts.set_inference_mode(target) + layer.mlp.shared_expert.gate_proj.set_inference_mode(target) + layer.mlp.shared_expert.up_proj.set_inference_mode(target) + layer.mlp.shared_expert.down_proj.set_inference_mode(target) + layer.mlp.shared_expert.act_fn.to(device) + layer.mlp.shared_expert_gate.to(device) + else: + layer.mlp.gate_proj.set_inference_mode(target) + layer.mlp.up_proj.set_inference_mode(target) + layer.mlp.down_proj.set_inference_mode(target) + layer.mlp.act_fn.to(device) + # layer norm + layer.input_layernorm.to(device) + layer.post_attention_layernorm.to(device) \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml b/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml new file mode 100644 index 00000000..3ab73c72 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml @@ -0,0 +1,83 @@ +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +# - match: +# name: "^model\\.layers\\..*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" +# generate_op: "KLinearMarlin" +# prefill_op: "KLinearTorch" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + 
generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module + +# - match: +# name: "^model$" +# replace: +# class: "ktransformers.operators.models.KQwen3MoeModel" +# kwargs: +# per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" +# - match: +# name: "^model\\.layers\\..*\\." +# replace: +# class: "default" +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" \ No newline at end of file diff --git a/setup.py b/setup.py index c5bf1280..17ace1b9 100644 --- a/setup.py +++ b/setup.py @@ -625,18 +625,18 @@ def build_extension(self, ext) -> None: ext_modules = [ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), - ops_module, - CUDAExtension( - 'vLLMMarlin', [ - 'csrc/custom_marlin/binding.cpp', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', - 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', - ], - extra_compile_args={ - 'cxx': ['-O3'], - 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], - }, - ) + # ops_module, + # CUDAExtension( + # 'vLLMMarlin', [ + # 'csrc/custom_marlin/binding.cpp', + # 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', + # 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', + # ], + # extra_compile_args={ + # 'cxx': ['-O3'], + # 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], + # }, + # ) ] if with_balance: print("using balance_serve") diff --git a/test_speed.py b/test_speed.py new file mode 100644 index 00000000..c45c1f61 --- /dev/null +++ b/test_speed.py @@ -0,0 +1,138 @@ +import asyncio +import json +import sys +import aiohttp +import random +import argparse +import yaml +import os +import time +from time import sleep + +decodesz = 128 +# Server URL (replace with your server URL) +decodesz_list = [128] +prefill_speeds = [] +decode_speeds = [] +ktansformer_prompt="" +image_path="" +async def fetch_event_stream(session, request_id, prompt,image_path, max_tokens, model): + try: + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": image_path + }, + }, + ], + } + ], + "model": model, + "temperature": 0.3, + "top_p": 1.0, + "stream": True, + "return_speed": True, + "max_tokens": max_tokens, + } + + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json' + } + + async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response: + if response.status != 200: + print(f"[Request {request_id}] Error: Status {response.status}") + return + + buffer = "" + total_tokens = 0 + decode_start_time = None + decode_end_time = None + usage_info = None + + async for line in response.content: + try: + decoded_line = 
line.decode("utf-8").strip() + if not decoded_line or not decoded_line.startswith("data: "): + continue + + decoded_line = decoded_line[6:].strip() + if not decoded_line: + continue + + response_data = json.loads(decoded_line) + + if "usage" in response_data: + usage_info = response_data["usage"] + + choices = response_data.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content", "") + + if token: + if decode_start_time is None: + decode_start_time = time.time() + buffer += token + total_tokens += 1 + decode_end_time = time.time() + + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + print(f"[Request {request_id}] {line}") + + finish_reason = choices[0].get("finish_reason", None) + if finish_reason: + break + + except Exception as e: + print(f"[Request {request_id}] Stream Error: {e}") + + if buffer.strip(): + print(f"[Request {request_id}] {buffer.strip()}") + + if usage_info: + if "prefill_time" in usage_info: + prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] + decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] + prefill_speeds.append(prefill_speed) + decode_speeds.append(decode_speed) + print(f'[Request {request_id}] prefill speed: {prefill_speed}') + print(f'[Request {request_id}] decode speed: {decode_speed}') + + except Exception as e: + print(f"[Request {request_id}] Exception: {e}") + +async def main(concurrent_requests , prompt, image_path,max_tokens, model): + async with aiohttp.ClientSession() as session: + tasks = [fetch_event_stream(session, i , prompt, image_path, max_tokens, model) for i in range(concurrent_requests)] + await asyncio.gather(*tasks) + if len(prefill_speeds) != 0: + import numpy as np + print(f"concurrency: {len(prefill_speeds)}") + print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Event Stream Request Tester") + parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") + parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name") + parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") + parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") + parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") + + args = parser.parse_args() + SERVER_URL = args.api_url + max_tokens = args.max_tokens + model = args.model + prompt = ktansformer_prompt + asyncio.run(main(args.concurrent, prompt, image_path,max_tokens, model)) + From db3f5120fade052a428323c42025d34320412e70 Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Thu, 15 May 2025 02:27:45 +0000 Subject: [PATCH 2/2] bug fix --- ktransformers/local_chat.py | 8 +-- setup.py | 24 +++---- test_speed.py | 138 ------------------------------------ 3 files changed, 13 insertions(+), 157 deletions(-) delete mode 100644 test_speed.py diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 8c6f7718..973c4769 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -184,10 +184,4 @@ def local_chat( if __name__ == "__main__": - # fire.Fire(local_chat) - local_chat( - model_path="/mnt/data/models/Qwen3-30B-A3B-250425/", - optimize_config_path="ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml", - 
gguf_path="/mnt/data/models/Qwen3-30B-A3B-GGUF/", - use_cuda_graph=False - ) + fire.Fire(local_chat) diff --git a/setup.py b/setup.py index 17ace1b9..c5bf1280 100644 --- a/setup.py +++ b/setup.py @@ -625,18 +625,18 @@ def build_extension(self, ext) -> None: ext_modules = [ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), - # ops_module, - # CUDAExtension( - # 'vLLMMarlin', [ - # 'csrc/custom_marlin/binding.cpp', - # 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', - # 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', - # ], - # extra_compile_args={ - # 'cxx': ['-O3'], - # 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], - # }, - # ) + ops_module, + CUDAExtension( + 'vLLMMarlin', [ + 'csrc/custom_marlin/binding.cpp', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', + 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], + }, + ) ] if with_balance: print("using balance_serve") diff --git a/test_speed.py b/test_speed.py deleted file mode 100644 index c45c1f61..00000000 --- a/test_speed.py +++ /dev/null @@ -1,138 +0,0 @@ -import asyncio -import json -import sys -import aiohttp -import random -import argparse -import yaml -import os -import time -from time import sleep - -decodesz = 128 -# Server URL (replace with your server URL) -decodesz_list = [128] -prefill_speeds = [] -decode_speeds = [] -ktansformer_prompt="" -image_path="" -async def fetch_event_stream(session, request_id, prompt,image_path, max_tokens, model): - try: - payload = { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": { - "url": image_path - }, - }, - ], - } - ], - "model": model, - "temperature": 0.3, - "top_p": 1.0, - "stream": True, - "return_speed": True, - "max_tokens": max_tokens, - } - - headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json' - } - - async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response: - if response.status != 200: - print(f"[Request {request_id}] Error: Status {response.status}") - return - - buffer = "" - total_tokens = 0 - decode_start_time = None - decode_end_time = None - usage_info = None - - async for line in response.content: - try: - decoded_line = line.decode("utf-8").strip() - if not decoded_line or not decoded_line.startswith("data: "): - continue - - decoded_line = decoded_line[6:].strip() - if not decoded_line: - continue - - response_data = json.loads(decoded_line) - - if "usage" in response_data: - usage_info = response_data["usage"] - - choices = response_data.get("choices", []) - if not choices: - continue - - delta = choices[0].get("delta", {}) - token = delta.get("content", "") - - if token: - if decode_start_time is None: - decode_start_time = time.time() - buffer += token - total_tokens += 1 - decode_end_time = time.time() - - while "\n" in buffer: - line, buffer = buffer.split("\n", 1) - print(f"[Request {request_id}] {line}") - - finish_reason = choices[0].get("finish_reason", None) - if finish_reason: - break - - except Exception as e: - print(f"[Request {request_id}] Stream Error: {e}") - - if buffer.strip(): - print(f"[Request {request_id}] {buffer.strip()}") - - if usage_info: - if "prefill_time" in usage_info: - prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] - decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] - 
prefill_speeds.append(prefill_speed) - decode_speeds.append(decode_speed) - print(f'[Request {request_id}] prefill speed: {prefill_speed}') - print(f'[Request {request_id}] decode speed: {decode_speed}') - - except Exception as e: - print(f"[Request {request_id}] Exception: {e}") - -async def main(concurrent_requests , prompt, image_path,max_tokens, model): - async with aiohttp.ClientSession() as session: - tasks = [fetch_event_stream(session, i , prompt, image_path, max_tokens, model) for i in range(concurrent_requests)] - await asyncio.gather(*tasks) - if len(prefill_speeds) != 0: - import numpy as np - print(f"concurrency: {len(prefill_speeds)}") - print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}") - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Event Stream Request Tester") - parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") - parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name") - parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") - parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") - parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") - - args = parser.parse_args() - SERVER_URL = args.api_url - max_tokens = args.max_tokens - model = args.model - prompt = ktansformer_prompt - asyncio.run(main(args.concurrent, prompt, image_path,max_tokens, model)) -
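
With both patches applied, Qwen3-30B-A3B is served through the unchanged fire.Fire CLI in ktransformers/local_chat.py together with the Qwen3-30B-A3B.yaml optimize rule added in the first commit. A minimal usage sketch of an equivalent direct invocation is below; the keyword arguments mirror the call that appeared in the first commit, but the model and GGUF paths are placeholders, not part of the patch:

# Minimal sketch -- paths are placeholders; point them at your own model directories.
from ktransformers.local_chat import local_chat

local_chat(
    model_path="/path/to/Qwen3-30B-A3B/",             # HF config + tokenizer directory (placeholder)
    gguf_path="/path/to/Qwen3-30B-A3B-GGUF/",         # quantized GGUF weights (placeholder)
    optimize_config_path="ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml",
    use_cuda_graph=False,                             # setting used while this patch was developed
)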