6 changes: 5 additions & 1 deletion .gitignore
@@ -26,4 +26,8 @@ ktransformers/tests/chat_txt.txt
mmlu_result*
ktransformers/ktransformers_ext/cuda_musa/
test_prompt.txt
csrc/demo
csrc/demo
build*
CMakeFiles/
kvc2/
sched/
11 changes: 3 additions & 8 deletions README.md
@@ -23,19 +23,14 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

<h2 id="Updates">🔥 Updates</h2>

* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))
* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))

* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.

* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).

* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))

https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2




* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).

@@ -65,7 +60,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
</p>

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM ([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)).

- Prefill Speed (tokens/s):
- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
- Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.
@@ -131,7 +126,6 @@ we have already supported vendors:
- Kunpeng
- AMD


### 📥 Installation

To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
@@ -201,3 +195,4 @@ If you have any questions, feel free to open an issue. Alternatively, you can jo
<h2 id="FAQ">🙋 FAQ</h2>

Some common questions are answered in the [FAQ](doc/en/FAQ.md).

8 changes: 4 additions & 4 deletions csrc/balance_serve/CMakeLists.txt
@@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")
project(balance_serve VERSION 0.1.0)

set(CMAKE_CXX_STANDARD 20)
# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
set(CMAKE_BUILD_TYPE "Release")
set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
set(CMAKE_BUILD_TYPE "Debug")
# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
# set(CMAKE_BUILD_TYPE "Release")


if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
6 changes: 2 additions & 4 deletions csrc/balance_serve/sched/model_config.h
@@ -15,16 +15,14 @@ using ModelName = std::string;
class ModelConfig {
public:
DimSize hidden_size;
DimSize intermediate_size;
size_t max_position_embeddings;
std::string model_type;
size_t num_attention_heads;
size_t num_hidden_layers;
size_t num_key_value_heads;
size_t vocab_size;

NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
max_position_embeddings, model_type,
NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size,
max_position_embeddings,
num_attention_heads, num_hidden_layers,
num_key_value_heads, vocab_size);

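With `intermediate_size` dropped from both the member list and the `NLOHMANN_DEFINE_TYPE_INTRUSIVE` macro, the scheduler now only requires the remaining keys in a model's `config.json`. A minimal validation sketch in Python (key names are taken from the struct above; the path is a placeholder):

```python
import json

# Keys the sched ModelConfig still deserializes (see the struct above);
# intermediate_size is no longer required.
REQUIRED_KEYS = [
    "hidden_size",
    "max_position_embeddings",
    "model_type",
    "num_attention_heads",
    "num_hidden_layers",
    "num_key_value_heads",
    "vocab_size",
]

def check_model_config(path: str) -> list[str]:
    """Return the keys missing from a HuggingFace-style config.json."""
    with open(path) as f:
        cfg = json.load(f)
    return [k for k in REQUIRED_KEYS if k not in cfg]

# Example (placeholder path):
# print(check_model_config("/path/to/model/config.json"))
```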
8 changes: 4 additions & 4 deletions csrc/ktransformers_ext/ext_bindings.cpp
@@ -683,12 +683,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<MOEConfig>(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size, int stride, int group_min_len,
int group_max_len, intptr_t gate_proj,
int group_max_len, bool use_silu, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj, int gate_type,
int up_type, int down_type, int hidden_type) {
return MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size, stride, group_min_len,
group_max_len, (void *)gate_proj, (void *)up_proj,
group_max_len, use_silu, (void *)gate_proj, (void *)up_proj,
(void *)down_proj, (ggml_type)gate_type,
(ggml_type)up_type, (ggml_type)down_type,
(ggml_type)hidden_type);
@@ -703,11 +703,11 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
int max_len, intptr_t gate_proj,
int max_len, bool use_silu, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj) {
return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size,
max_len, (void *)gate_proj,
max_len, use_silu, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));

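On the Python side, `use_silu` becomes one extra positional argument in both constructors, placed just before the weight pointers. A rough construction sketch, assuming the submodule is exposed as `cpuinfer_ext.moe`; the shapes, dtype, and ggml type id below are placeholders, not values taken from this PR:

```python
import torch
import cpuinfer_ext  # built from csrc/ktransformers_ext

# Placeholder expert weights; real code passes pointers to the packed GGUF tensors.
experts, hidden, inter = 8, 1024, 2048
gate = torch.empty(experts, inter, hidden, dtype=torch.bfloat16)
up   = torch.empty(experts, inter, hidden, dtype=torch.bfloat16)
down = torch.empty(experts, hidden, inter, dtype=torch.bfloat16)

GGML_TYPE = 30  # placeholder; use the ggml_type enum value matching your weights

config = cpuinfer_ext.moe.MOEConfig(
    experts,   # expert_num
    2,         # routed_expert_num
    hidden,    # hidden_size
    inter,     # intermediate_size
    16,        # stride
    10,        # group_min_len
    1024,      # group_max_len
    True,      # use_silu: False selects the new ReLU gating path
    gate.data_ptr(), up.data_ptr(), down.data_ptr(),
    GGML_TYPE, GGML_TYPE, GGML_TYPE, GGML_TYPE,  # gate/up/down/hidden types
)
```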
50 changes: 37 additions & 13 deletions csrc/ktransformers_ext/operators/amx/moe.hpp
@@ -69,22 +69,29 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
return _mm512_mul_ps(act_val, up_val);
}

static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {
__m512 zero_vec = _mm512_setzero_ps();
__m512 act_val = _mm512_max_ps(zero_vec, gate_val);
return _mm512_mul_ps(act_val, up_val);
}

struct AMX_MOEConfig {
int expert_num;
int routed_expert_num;
int hidden_size;
int intermediate_size;
int max_len;
bool use_silu;
void *gate_proj;
void *up_proj;
void *down_proj;

AMX_MOEConfig() {}

AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu,
void *gate_proj, void *up_proj, void *down_proj)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj),
down_proj(down_proj) {}
};

@@ -336,18 +343,35 @@ template <class T> class AMX_MOE {
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = act_fn(gate_val0, up_val0);
__m512 result1 = act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
}
if (config_.use_silu) {
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = act_fn(gate_val0, up_val0);
__m512 result1 = act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
}
}
}
else {
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = relu_act_fn(gate_val0, up_val0);
__m512 result1 = relu_act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
}
}
}

},
nullptr);
backend->do_work_stealing_job(
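Read together, the two branches above compute the same gated intermediate and differ only in the gate nonlinearity; written out (with ⊙ the element-wise product over the bf16 gate and up projection outputs):

```math
h = f(x W_{\text{gate}}) \odot (x W_{\text{up}}),
\qquad
f(z) =
\begin{cases}
  \dfrac{z}{1 + e^{-z}} & \texttt{use\_silu} = \text{true (SiLU)} \\[6pt]
  \max(0, z)            & \texttt{use\_silu} = \text{false (ReLU)}
\end{cases}
```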
31 changes: 27 additions & 4 deletions csrc/ktransformers_ext/operators/llamafile/moe.cpp
@@ -10,6 +10,7 @@
#include "moe.h"
#include <iostream>
#include <cstdint>
#include <math.h>

#ifdef USE_NUMA
#include <numa.h>
@@ -134,6 +135,14 @@ static float act_fn(float x) {
return x / (1.0f + expf(-x));
}

static float act_fn_relu(float x) {
  return x > 0.0f ? x : 0.0f;
}

void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
const void* up_input_ptr;
@@ -182,8 +191,16 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c

float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
if(config_.use_silu){
// use silu as act fn
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
}
} else {
// use relu as act fn
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
}
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;
@@ -304,8 +321,14 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
for (int j = ith * stride; j < (ith + 1) * stride; j++) {
m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
if(config_.use_silu){
for (int j = ith * stride; j < (ith + 1) * stride; j++) {
m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
}
} else {
for (int j = ith * stride; j < (ith + 1) * stride; j++) {
m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
}
}
float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
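The scalar llamafile path applies the same `use_silu` switch per element. As a plain reference for what each branch computes (a minimal numpy sketch with illustrative names, not the production API):

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    # Matches act_fn above: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))

def relu(x: np.ndarray) -> np.ndarray:
    # Matches act_fn_relu above.
    return np.maximum(x, 0.0)

def gated_intermediate(gate_out: np.ndarray, up_out: np.ndarray, use_silu: bool) -> np.ndarray:
    """Activation on the gate projection output, multiplied element-wise
    by the up projection output, as in the loops above."""
    act = silu if use_silu else relu
    return act(gate_out) * up_out

gate_out = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
up_out   = np.array([ 1.0, 1.0, 1.0], dtype=np.float32)
print(gated_intermediate(gate_out, up_out, use_silu=True))   # SiLU gating
print(gated_intermediate(gate_out, up_out, use_silu=False))  # ReLU gating
```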
5 changes: 3 additions & 2 deletions csrc/ktransformers_ext/operators/llamafile/moe.h
@@ -32,6 +32,7 @@ struct MOEConfig {
int stride;
int group_min_len;
int group_max_len;
bool use_silu;
void* gate_proj;
void* up_proj;
void* down_proj;
@@ -42,8 +43,8 @@

MOEConfig() {}

MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
};

class MOE {
76 changes: 76 additions & 0 deletions doc/en/SmallThinker_and_Glm4moe.md
@@ -0,0 +1,76 @@
# GLM-4-MoE Support for KTransformers

## Introduction

### Overview
We are excited to announce that **KTransformers now supports GLM-4-MoE**.

- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~440 GB DRAM.
- **GLM-4-MoE 110B (AMX INT8)**: prefill ~309 TPS / decode ~16 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM.

### Model & Resource Links
- **GLM-4-MoE 110B**
- *(to be announced)*

## Installation Guide

### 1. Resource Requirements

| Model | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) |
| ------------------------- | --------- | ------- | ----------- | ------------------- | ------------------------------ |
| GLM-4-MoE 110B | bf16 | 128 | \~440 GB | 14 GB | \~11 TPS |
| GLM-4-MoE 110B (AMX INT8) | int8 | 128 | \~220 GB | 14 GB | prefill \~309 TPS / decode \~16 TPS |

\* Exact GPU memory depends on sequence length, batch size, and kernels used.

### 2. Prepare Models

```bash
# Example: download original safetensors (adjust to your paths/repos)
# (Fill in actual repos/filenames yourself)

# GLM-4-MoE 110B
huggingface-cli download --resume-download placeholder-org/Model-TBA \
--local-dir ./Model-TBA
```

### 3. Install KTransformers

Follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).

```bash
pip install ktransformers # or from source if you need bleeding-edge features
```

### 4. Run GLM-4-MoE 110B Inference Server

```bash
python ktransformers/server/main.py \
--port 10110 \
--model_name Glm4MoeForCausalLM \
--model_path /abs/path/to/GLM-4-MoE-110B-bf16 \
--optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \
--max_new_tokens 1024 \
--cache_lens 32768 \
--chunk_size 256 \
--max_batch_size 4 \
--backend_type balance_serve
```

### 5. Access Server

```bash
curl -X POST http://localhost:10110/v1/chat/completions \
-H "accept: application/json" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "hello"}
],
"model": "GLM-4-MoE-110B",
"temperature": 0.3,
"top_p": 1.0,
"stream": true
}'
```
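
The route follows the OpenAI chat-completions format (as the curl call suggests), so the standard Python client works as well. A sketch using the `openai` package (the port and model name mirror the example above and may differ in your deployment; the API key is unused locally):

```python
from openai import OpenAI

# Point the standard OpenAI client at the local KTransformers server.
client = OpenAI(base_url="http://localhost:10110/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="GLM-4-MoE-110B",
    messages=[{"role": "user", "content": "hello"}],
    temperature=0.3,
    top_p=1.0,
    stream=True,
)

for chunk in stream:
    delta = chunk.choices[0].delta
    if delta and delta.content:
        print(delta.content, end="", flush=True)
print()
```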

14 changes: 7 additions & 7 deletions ktransformers/configs/config.yaml
@@ -21,12 +21,12 @@ user:

model:
# type: transformers
# type: balance_serve
type: ktransformers
type: balance_serve
# type: ktransformers

name: DeepSeek-Coder-V2-Instruct
path: deepseek-ai/DeepSeek-V2-Lite-Chat
gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
name: SmallThinkerForCausalLM
path: /mnt/data/models/Smallthinker-21B
gguf_path: /mnt/data/models/Smallthinker-21B

device: cuda:0
cache_lens: 16384
@@ -67,7 +67,7 @@ attn:
page_size: 256
chunk_size: 256
kvc2:
gpu_only: false
gpu_only: true
utilization_percentage: 1.0
cpu_memory_size_GB: 500
disk_path: /mnt/data/kvc
disk_path: /home/wjh/kvc
1 change: 1 addition & 0 deletions ktransformers/ktransformers