diff --git a/.devops/s390x.Dockerfile b/.devops/s390x.Dockerfile index 3df1a2b0defe0..b7c9457680b08 100644 --- a/.devops/s390x.Dockerfile +++ b/.devops/s390x.Dockerfile @@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DLLAMA_BUILD_TESTS=OFF \ - -DGGML_BACKEND_DL=OFF \ -DGGML_NATIVE=OFF \ + -DGGML_BACKEND_DL=ON \ + -DGGML_CPU_ALL_VARIANTS=ON \ -DGGML_BLAS=ON \ -DGGML_BLAS_VENDOR=OpenBLAS && \ cmake --build build --config Release -j $(nproc) && \ @@ -103,6 +104,7 @@ FROM base AS light WORKDIR /llama.cpp/bin # Copy llama.cpp binaries and libraries +COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ] @@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0 WORKDIR /llama.cpp/bin # Copy llama.cpp binaries and libraries +COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin EXPOSE 8080 diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 937306f7afae7..36201281f0059 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -4,49 +4,49 @@ on: workflow_call: jobs: - ubuntu-24-riscv64-cpu-cross: - runs-on: ubuntu-24.04 + # ubuntu-24-riscv64-cpu-cross: + # runs-on: ubuntu-24.04 - steps: - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo dpkg --add-architecture riscv64 + # steps: + # - uses: actions/checkout@v4 + # - name: Setup Riscv + # run: | + # sudo dpkg --add-architecture riscv64 - # Add arch-specific repositories for non-amd64 architectures - cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - EOF + # # Add arch-specific repositories for non-amd64 architectures + # cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + # EOF - sudo apt-get update || true ;# Prevent failure due to missing URLs. + # sudo apt-get update || true ;# Prevent failure due to missing URLs. 
- sudo apt-get install -y --no-install-recommends \ - build-essential \ - gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # gcc-14-riscv64-linux-gnu \ + # g++-14-riscv64-linux-gnu - - name: Build - run: | - cmake -B build -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + # - name: Build + # run: | + # cmake -B build -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - cmake --build build --config Release -j $(nproc) + # cmake --build build --config Release -j $(nproc) # ubuntu-24-riscv64-vulkan-cross: # runs-on: ubuntu-24.04 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cab3ba9e68ee4..e72caa423ba0f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -134,8 +134,8 @@ jobs: include: - build: 'x64' os: ubuntu-22.04 - - build: 's390x-z15' # z15 because our CI runners are on z15 - os: ubuntu-22.04-s390x + - build: 's390x' + os: ubuntu-24.04-s390x # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm # - build: 'arm64' # os: ubuntu-22.04-arm diff --git a/common/chat.cpp b/common/chat.cpp index 63583fb22489d..938872e82ee1d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -313,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg } if (!msg.reasoning_content.empty()) { jmsg["reasoning_content"] = msg.reasoning_content; - jmsg["thinking"] = msg.reasoning_content; // gpt-oss } if (!msg.tool_name.empty()) { jmsg["name"] = msg.tool_name; @@ -1810,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - auto prompt = apply(tmpl, inputs); + + // Copy reasoning to the "thinking" field as expected by the gpt-oss template + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["thinking"] = msg.at("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + + auto prompt = apply(tmpl, inputs, /* messages_override= 
*/ adjusted_messages); // Check if we need to replace the return token with end token during // inference and without generation prompt. For more details see: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f186c2167d7d0..c6f5ba6a04c54 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9802,6 +9802,113 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + +@ModelBase.register("JanusForConditionalGeneration") +class JanusProModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA # reuse Llama arch + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision, aligner, and generation tensors + skip_prefixes = ( + 'model.vision_model.', + 'model.aligner.', + 'model.vqmodel.', + 'model.generation_embeddings.', + 'model.generation_aligner.', + 'model.generation_head.', + ) + if name.startswith(skip_prefixes): + return [] + + if name.startswith('model.language_model.'): + name = name.replace('model.language_model.', 'model.') + elif name.startswith('language_model.'): + name = name.replace('language_model.', '') + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("JanusForConditionalGeneration") +class JanusProVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + if "intermediate_size" not in self.hparams_vision: + mlp_ratio = self.hparams_vision.get("mlp_ratio") + hidden_size = self.hparams_vision.get("hidden_size") + if mlp_ratio is not None and hidden_size is not None: + self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO) + + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6)) + + hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower() + if hidden_act == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + elif hidden_act == "silu": + self.gguf_writer.add_vision_use_silu(True) + + def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]: + """Map aligner tensors to projector format""" + suffix = ".bias" if name.endswith(".bias") else ".weight" + + if name.startswith("model.aligner."): + local_name = name[len("model.aligner."):] + elif name.startswith("aligner."): + local_name = name[len("aligner."):] + else: + raise ValueError(f"Unsupported Janus aligner prefix: {name}") + + if local_name.startswith("fc1."): + mm_index = 0 + elif local_name.startswith("hidden_layers."): + parts = local_name.split(".", 2) + if len(parts) < 3: + raise ValueError(f"Unexpected Janus aligner tensor name: {name}") + mm_index = int(parts[1]) + 1 + else: + raise ValueError(f"Unsupported Janus aligner tensor: {name}") + + tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix) + return [(tensor_name, data_torch)] + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Skip language model tensors as they will be handled by `JanusProModel` + if name.startswith(('model.language_model.', 'language_model.')): + return [] + + # Skip generation-related components + skip_generation_prefixes = ( + 
'model.vqmodel.', + 'vqmodel.', + 'model.generation_embeddings.', + 'generation_embeddings.', + 'model.generation_aligner.', + 'generation_aligner.', + 'model.generation_head.', + 'generation_head.', + ) + if name.startswith(skip_generation_prefixes): + return [] + + # Handle aligner tensors + if name.startswith(('model.aligner.', 'aligner.')): + return list(self._map_aligner_tensor(data_torch, name)) + + # Handle vision tensors + if name.startswith(('model.vision_model.', 'vision_model.')): + return [(self.map_tensor_name(name), data_torch)] + + return [] + + ###### CONVERSION LOGIC ###### diff --git a/docs/docker.md b/docs/docker.md index bfabf2425a7d6..98502a0c50598 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -7,9 +7,9 @@ ## Images We have three Docker images available for this project: -1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`) -2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`) -3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`) +1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) +2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) +3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`) Additionally, there the following images, similar to the above: diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ba281b8e6d17a..f30e4ac9020fa 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -308,6 +308,10 @@ function(ggml_add_cpu_backend_variant tag_name) set(GGML_INTERNAL_${feat} ON) endforeach() elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") + foreach (feat VXE2 NNPA) + set(GGML_INTERNAL_${feat} OFF) + endforeach() + foreach (feat ${ARGN}) set(GGML_INTERNAL_${feat} ON) endforeach() @@ -377,9 +381,8 @@ if (GGML_CPU_ALL_VARIANTS) endif() elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") if (CMAKE_SYSTEM_NAME MATCHES "Linux") - ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE) - # ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE) - # ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE) + ggml_add_cpu_backend_variant(z15 Z15 VXE2) + ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA) else() message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}") endif() diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 34323afa0762a..23ec8bb08a732 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -504,11 +504,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endforeach() endif() - if (GGML_VXE OR GGML_INTERNAL_VXE) - message(STATUS "VX/VXE/VXE2 enabled") + if (GGML_VXE OR GGML_INTERNAL_VXE2) + message(STATUS "VXE2 enabled") list(APPEND ARCH_FLAGS -mvx -mzvector) - list(APPEND ARCH_DEFINITIONS GGML_VXE) + list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2) endif() + + if (GGML_INTERNAL_NNPA) + message(STATUS "NNPA enabled") + list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA) + endif() + + 
ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS}) elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") message(STATUS "Wasm detected") list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) diff --git a/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp new file mode 100644 index 0000000000000..5f4405a7f308b --- /dev/null +++ b/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp @@ -0,0 +1,50 @@ +#include "ggml-backend-impl.h" + +#if defined(__s390x__) +#include + +// find hwcap bits in asm/elf.h +#ifndef HWCAP_VXRS_EXT2 +#define HWCAP_VXRS_EXT2 (1 << 15) +#endif + +#ifndef HWCAP_NNPA +#define HWCAP_NNPA (1 << 20) +#endif + +struct s390x_features { + bool has_vxe2 = false; + bool has_nnpa = false; + + s390x_features() { + uint32_t hwcap = getauxval(AT_HWCAP); + // NOTE: use hwcap2 with DFLT for z17 and later + // uint32_t hwcap2 = getauxval(AT_HWCAP2); + + has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2); + has_nnpa = !!(hwcap & HWCAP_NNPA); + } +}; + +static int ggml_backend_cpu_s390x_score() { + int score = 1; + s390x_features sf; + +// IBM z15 / LinuxONE 3 +#ifdef GGML_USE_VXE2 + if (!sf.has_vxe2) { return 0; } + score += 1 << 1; +#endif + +// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5 +#ifdef GGML_USE_NNPA + if (!sf.has_nnpa) { return 0; } + score += 1 << 2; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score) + +#endif // __s390x__ diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 61a8f1df87de1..5667ec0c4d709 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_XIELU: ggml_cuda_op_xielu(ctx, dst); break; + case GGML_UNARY_OP_FLOOR: + ggml_cuda_op_floor(ctx, dst); + break; + case GGML_UNARY_OP_CEIL: + ggml_cuda_op_ceil(ctx, dst); + break; + case GGML_UNARY_OP_ROUND: + ggml_cuda_op_round(ctx, dst); + break; + case GGML_UNARY_OP_TRUNC: + ggml_cuda_op_trunc(ctx, dst); + break; default: return false; } @@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous(op->src[0]); default: return false; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5f0d3a6726aef..c1dc6ddbf8f81 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? 
x : expm1f(x); } +static __device__ __forceinline__ float op_floor(float x) { + return floorf(x); +} + +static __device__ __forceinline__ float op_ceil(float x) { + return ceilf(x); +} + +static __device__ __forceinline__ float op_round(float x) { + return round(x); +} + +static __device__ __forceinline__ float op_trunc(float x) { + return trunc(x); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } + +void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 6c738cefecfd2..2800c75ba3f7a 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 360fbe19f0fb6..0cadd19a30fe9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te if (op->src[0]->ne[0] != 32 && op->src[0]->ne[0] != 40 && op->src[0]->ne[0] != 64 && + op->src[0]->ne[0] != 72 && op->src[0]->ne[0] != 80 && op->src[0]->ne[0] != 96 && op->src[0]->ne[0] != 112 && diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index fa839a1df6e30..424c400f24b9b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5374,6 +5375,7 @@ template 
[[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5424,6 +5429,7 @@ template 
[[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0d5afa01edf84..77e3b0650ff0b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3186,6 +3186,7 @@ class VisionProjectorType: KIMIVL = "kimivl" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" + JANUS_PRO = "janus_pro" # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index cef5acec7581f..929406687610c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1183,6 +1183,7 @@ class TensorNameMap: "model.mm_projector.mlp.mlp.{bid}", "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 "mlp1.{bid}", # InternVL + "model.aligner.fc1.hidden_layers.{bid}", # Janus Pro ), MODEL_TENSOR.V_MMPROJ_PEG: ( @@ -1291,6 +1292,7 @@ class TensorNameMap: "model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1 
"vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral diff --git a/include/llama.h b/include/llama.h index 08e2ffa34c9f6..98bed9d6150a0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -461,7 +461,10 @@ extern "C" { LLAMA_API bool llama_supports_gpu_offload(void); LLAMA_API bool llama_supports_rpc (void); + // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions + // In some cases the requested values via llama_context_params may differ from the actual values used by the context LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); @@ -585,7 +588,7 @@ extern "C" { LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); // Manually free a LoRA adapter - // Note: loaded adapters will be free when the associated model is deleted + // NOTE: loaded adapters will be free when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); // Get the invocation tokens if the current lora is an alora @@ -1111,8 +1114,6 @@ extern "C" { // // sample from the logits of the last token in the batch // const llama_token id = llama_sampler_sample(smpl, ctx, -1); // - // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.) - // llama_sampler_accept(smpl, id); // ... 
// } // diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f6192a36e0ee5..2b39366271ff9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -112,11 +112,24 @@ llama_context::llama_context( } } - const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + if (cparams.kv_unified) { + cparams.n_ctx_seq = cparams.n_ctx; + } else { + cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; + + if (cparams.n_ctx_seq == 0) { + throw std::runtime_error("n_ctx_seq == 0"); + } + + if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { + cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; + LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx); + } + } LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq); + LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); @@ -125,14 +138,14 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); - if (n_ctx_per_seq < hparams.n_ctx_train) { - LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", - __func__, n_ctx_per_seq, hparams.n_ctx_train); + if (cparams.n_ctx_seq < hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", + __func__, cparams.n_ctx_seq, hparams.n_ctx_train); } - if (n_ctx_per_seq > hparams.n_ctx_train) { - LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", - __func__, n_ctx_per_seq, hparams.n_ctx_train); + if (cparams.n_ctx_seq > hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", + __func__, cparams.n_ctx_seq, hparams.n_ctx_train); } if (!hparams.vocab_only) { @@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } -uint32_t llama_context::n_ctx_per_seq() const { - return cparams.n_ctx / cparams.n_seq_max; +uint32_t llama_context::n_ctx_seq() const { + return cparams.n_ctx_seq; } uint32_t llama_context::n_batch() const { @@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) { return ctx->n_ctx(); } +uint32_t llama_n_ctx_seq(const llama_context * ctx) { + return ctx->n_ctx_seq(); +} + uint32_t llama_n_batch(const llama_context * ctx) { return ctx->n_batch(); } diff --git a/src/llama-context.h b/src/llama-context.h index ed6d82cb396f9..20cbd78955412 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -43,11 +43,11 @@ struct llama_context { ggml_backend_sched_t get_sched() const; - uint32_t n_ctx() const; - uint32_t n_ctx_per_seq() const; - uint32_t n_batch() const; - uint32_t n_ubatch() const; - uint32_t n_seq_max() const; + uint32_t n_ctx() const; + uint32_t n_ctx_seq() const; + uint32_t n_batch() const; + uint32_t n_ubatch() const; + uint32_t n_seq_max() const; uint32_t n_threads() const; uint32_t n_threads_batch() const; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index eae7b839f4857..fcef8fa976038 100644 --- a/src/llama-cparams.h +++ 
b/src/llama-cparams.h @@ -8,6 +8,7 @@ struct llama_cparams { uint32_t n_ctx; // context size used during inference + uint32_t n_ctx_seq; // context for a single sequence uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 04239181c7765..896725466ce24 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co } ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { - const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + const uint32_t n_ctx_seq = cparams.n_ctx_seq; // choose long/short freq factors based on the context size if (layers[il].rope_freqs != nullptr) { return layers[il].rope_freqs; } - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + if (n_ctx_seq > hparams.n_ctx_orig_yarn) { return layers[il].rope_long; } @@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_attn */ std::move(filter_attn), /* filter_recr */ std::move(filter_recr)); } else { - uint32_t n_ctx_per_stream = cparams.n_ctx; - - if (!cparams.kv_unified) { - n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; - } - llama_memory_i::layer_reuse_cb reuse = nullptr; if (arch == LLM_ARCH_GEMMA3N) { @@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, params.swa_full, cparams.kv_unified, - n_ctx_per_stream, + cparams.n_ctx_seq, cparams.n_seq_max, cparams.n_ubatch, 1, @@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, cparams.offload_kqv, cparams.kv_unified, - n_ctx_per_stream, + cparams.n_ctx_seq, cparams.n_seq_max, 1, hparams.n_swa, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 04fa1b62d3b4d..967a53c63d86d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1454,6 +1454,8 @@ struct test_case { ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); + gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false); + ggml_tensor * out = build_graph(ctx.get()); current_op_name = op_desc(out); @@ -7225,8 +7227,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); } - for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) { - for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) { + for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) { + for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 576 && hsk != hsv) continue; if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA @@ -7569,6 +7571,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op if (mode == MODE_SUPPORT) { auto test_cases = make_test_cases_eval(); filter_test_cases(test_cases, params_filter); + + // Filter out fusion cases + test_cases.erase( + std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr & tc) { + return tc->run_whole_graph(); + }), + test_cases.end() + ); + for (auto & test : test_cases) { test->eval_support(backend, op_names_filter, output_printer); } @@ -7619,6 +7630,14 @@ static void show_test_coverage() { all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i)); } auto test_cases = 
make_test_cases_eval(); + // Filter out fusion cases + test_cases.erase( + std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr & tc) { + return tc->run_whole_graph(); + }), + test_cases.end() + ); + std::set tested_ops; ggml_init_params params = { diff --git a/tests/test-thread-safety.cpp b/tests/test-thread-safety.cpp index e5158fb5062f0..bcb86c35e6652 100644 --- a/tests/test-thread-safety.cpp +++ b/tests/test-thread-safety.cpp @@ -131,7 +131,14 @@ int main(int argc, char ** argv) { } batch = llama_batch_get_one(&token, 1); - if (llama_decode(ctx.get(), batch)) { + + int ret = llama_decode(ctx.get(), batch); + if (ret == 1 && i > 0) { + LOG_INF("Context full, stopping generation.\n"); + break; + } + + if (ret != 0) { LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts); failed.store(true); return; diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index c7e9498349c1b..722b1a4948d6f 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -155,6 +155,7 @@ enum projector_type { PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, + PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_UNKNOWN, }; @@ -180,6 +181,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_KIMIVL, "kimivl"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, + { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index dcfdb49600b6c..0784e69fcdf93 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -6,7 +6,6 @@ #include "clip-impl.h" #include "ggml.h" #include "ggml-cpp.h" -#include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" #include "gguf.h" @@ -17,17 +16,15 @@ #include #include #include -#include #include #include #include -#include #include #include #include -#include #include +// TODO: allow to pass callback from user code struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; enum ffn_op_type { @@ -426,12 +423,14 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; + clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; // for debugging bool debug_graph = false; std::vector debug_print_tensors; clip_ctx(clip_context_params & ctx_params) { + flash_attn_type = ctx_params.flash_attn_type; debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr; backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); if (!backend_cpu) { @@ -589,6 +588,15 @@ struct clip_graph { cur = ggml_gelu(ctx0, cur); cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); cur = ggml_add(ctx0, cur, model.mm_2_b); + + } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) { + cur = build_ffn(cur, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + hparams.ffn_op, + -1); + } else { GGML_ABORT("SigLIP: Unsupported projector type"); } @@ -1730,7 +1738,6 @@ struct clip_graph { return gf; } - // whisper encoder with custom projector ggml_cgraph * build_whisper_enc() { const int n_frames = img.nx; @@ -2260,17 +2267,25 @@ struct clip_graph { ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); //cb(k, "k", il); - ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); - v = ggml_cont(ctx0, v); - //cb(k, "v", il); - ggml_tensor * cur; - // TODO @ngxson : support flash attention - { + if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { 
+ ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + + k = ggml_cast(ctx0, k, GGML_TYPE_F16); + v = ggml_cast(ctx0, v, GGML_TYPE_F16); + + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + + } else { + ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); + v = ggml_cont(ctx0, v); + const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; - // const auto n_kv = k->ne[1]; // for flash attention ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // F32 may not needed for vision encoders? @@ -2450,6 +2465,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_kimivl(); } break; + case PROJECTOR_TYPE_JANUS_PRO: + { + res = graph.build_siglip(); + } break; case PROJECTOR_TYPE_COGVLM: { res = graph.build_cogvlm(); @@ -3151,6 +3170,13 @@ struct clip_model_loader { model.mm_boi = get_tensor(TN_TOK_BOI); model.mm_eoi = get_tensor(TN_TOK_EOI); } break; + case PROJECTOR_TYPE_JANUS_PRO: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -3192,7 +3218,87 @@ struct clip_model_loader { } } - void alloc_compute_meta(clip_ctx & ctx_clip) { + struct support_info_op { + ggml_tensor * op; + + // true if the op runs on the accelerated ctx_clip.backend + bool is_accel = true; + }; + + struct support_info_graph { + // whether the clip_ctx.backend supports flash attention + bool fattn = true; + ggml_tensor * fattn_op = nullptr; // for debugging + + std::vector ops; + }; + + static void warmup(clip_ctx & ctx_clip) { + support_info_graph info; + + if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) { + // try to enable flash attention to see if it's supported + ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED; + info = alloc_compute_meta(ctx_clip); + if (!info.fattn && info.fattn_op) { + auto op = info.fattn_op; + LOG_WRN("%s: *****************************************************************\n", __func__); + LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend)); + LOG_WRN("%s: op params: \n", __func__); + static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) { + LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn, + name, ggml_type_name(t->type), + t->ne[0], t->ne[1], t->ne[2], t->ne[3], + t->nb[0], t->nb[1], t->nb[2], t->nb[3]); + }; + print_shape(__func__, " dst", op); + print_shape(__func__, "src0", op->src[0]); + print_shape(__func__, "src1", op->src[1]); + print_shape(__func__, "src2", op->src[2]); + LOG_WRN("%s: please report this on github as an issue\n", __func__); + LOG_WRN("%s: *****************************************************************\n", __func__); + ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED; + alloc_compute_meta(ctx_clip); + } + } else { + info = alloc_compute_meta(ctx_clip); + if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); + } + } + + 
LOG_INF("%s: flash attention is %s\n", __func__, + (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled"); + + // print ops that are not supported by the GPU backend (if there is one) + if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) { + std::vector unsupported_ops; + for (const auto & op : info.ops) { + if (!op.is_accel) { + unsupported_ops.push_back(op); + } + } + if (!unsupported_ops.empty()) { + LOG_WRN("%s: *****************************************************************\n", __func__); + LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__); + LOG_WRN("%s: the performance will be suboptimal \n", __func__); + LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend)); + for (const auto & op : unsupported_ops) { + LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__, + ggml_op_name(op.op->op), + ggml_type_name(op.op->type), + op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]); + } + LOG_WRN("%s: flash attention is %s\n", __func__, + (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled"); + LOG_WRN("%s: please report this on github as an issue\n", __func__); + LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__); + LOG_WRN("%s: *****************************************************************\n", __func__); + } + } + } + + static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) { const auto & hparams = ctx_clip.model.hparams; ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); @@ -3223,57 +3329,95 @@ struct clip_model_loader { size / 1024.0 / 1024.0); } } + + const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get()); + const int n_nodes = ggml_graph_n_nodes(gf); + + LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes); + + support_info_graph res { + /*.fattn = */ true, + /*.fattn_op = */ nullptr, + /*.ops = */ {}, + }; + + // check op support + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * node = ggml_graph_node(gf, i); + res.ops.push_back({node, true}); + if (!ggml_backend_supports_op(ctx_clip.backend, node)) { + res.ops.back().is_accel = false; + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + res.fattn = false; + res.fattn_op = node; + } + } + } + + return res; } - void get_bool(const std::string & key, bool & output, bool required = true) { + void get_bool(const std::string & key, bool & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } output = gguf_get_val_bool(ctx_gguf.get(), i); } - void get_i32(const std::string & key, int & output, bool required = true) { + void get_i32(const std::string & key, int & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } output = gguf_get_val_i32(ctx_gguf.get(), i); } - void get_u32(const std::string & key, int & output, bool required = true) { + void get_u32(const std::string & key, int & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { 
- if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } output = gguf_get_val_u32(ctx_gguf.get(), i); } - void get_f32(const std::string & key, float & output, bool required = true) { + void get_f32(const std::string & key, float & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } output = gguf_get_val_f32(ctx_gguf.get(), i); } - void get_string(const std::string & key, std::string & output, bool required = true) { + void get_string(const std::string & key, std::string & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } output = std::string(gguf_get_val_str(ctx_gguf.get(), i)); } - void get_arr_int(const std::string & key, std::vector & output, bool required = true) { + void get_arr_int(const std::string & key, std::vector & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); + if (required) { + throw std::runtime_error("Key not found: " + key); + } return; } int n = gguf_get_arr_n(ctx_gguf.get(), i); @@ -3284,7 +3428,7 @@ struct clip_model_loader { } } - void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) { + static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) { auto & hparams = model.hparams; for (int x = 1; x <= max_patches_per_side; x++) { for (int y = 1; y <= max_patches_per_side; y++) { @@ -3312,24 +3456,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_vision = new clip_ctx(ctx_params); loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); loader.load_tensors(*ctx_vision); - loader.alloc_compute_meta(*ctx_vision); + loader.warmup(*ctx_vision); } if (loader.has_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); - loader.alloc_compute_meta(*ctx_audio); + loader.warmup(*ctx_audio); } } catch (const std::exception & e) { LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what()); - if (ctx_vision) { - delete ctx_vision; - } - if (ctx_audio) { - delete ctx_audio; - } + + delete ctx_vision; + delete ctx_audio; + return {nullptr, nullptr}; } @@ -3367,10 +3509,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) { } delete load_image_size; } -void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; } -void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; } -void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; } -void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; } +void clip_image_u8_free(struct clip_image_u8 * img) { delete img; } +void clip_image_f32_free(struct clip_image_f32 * img) { delete img; } +void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; } +void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; } size_t 
clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) { return batch->entries.size(); @@ -4096,6 +4238,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_JANUS_PRO: + { + // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384 + const std::array pad_color = {127, 127, 127}; + clip_image_u8 resized_image; + int sz = params.image_size; + img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; + case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { @@ -4272,6 +4426,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im switch (proj) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: + case PROJECTOR_TYPE_JANUS_PRO: { // do nothing } break; @@ -4782,6 +4937,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_COGVLM: { // do nothing @@ -4870,6 +5026,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_JANUS_PRO: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 3387cdbd36955..6384e2adaf775 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" + #include #include @@ -22,9 +23,16 @@ enum clip_modality { CLIP_MODALITY_AUDIO, }; +enum clip_flash_attn_type { + CLIP_FLASH_ATTN_TYPE_AUTO = -1, + CLIP_FLASH_ATTN_TYPE_DISABLED = 0, + CLIP_FLASH_ATTN_TYPE_ENABLED = 1, +}; + struct clip_context_params { bool use_gpu; enum ggml_log_level verbosity; + enum clip_flash_attn_type flash_attn_type; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index fd1fb6581b163..17aea1472b3c6 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -136,6 +136,7 @@ struct mtmd_cli_context { mparams.print_timings = true; mparams.n_threads = params.cpuparams.n_threads; mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mparams.flash_attn_type = params.flash_attn_type; ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 196641dd95ef4..297eef437ab91 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include // represents raw image data, layout is RGBRGBRGB... 
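
The clip.cpp changes above probe flash-attention support by building the vision graph once with it enabled and falling back when the backend rejects the GGML_OP_FLASH_ATTN_EXT node (see `warmup()` and the `support_info_graph` it inspects). Below is a minimal, self-contained sketch of that negotiate-then-fall-back pattern; all names and types are illustrative stand-ins, not the real clip/ggml API.

```cpp
// Sketch of the flash-attention negotiation added to clip.cpp: when the
// requested mode is AUTO, try ENABLED once and fall back to DISABLED if the
// backend cannot run the flash-attention node. Hypothetical names throughout.
#include <cstdio>

enum class FlashAttn { Auto = -1, Disabled = 0, Enabled = 1 };

struct ProbeResult {
    bool fattn_supported; // plays the role of support_info_graph::fattn
};

// stand-in for alloc_compute_meta(): pretend the backend rejects flash attention
static ProbeResult probe_graph(FlashAttn mode) {
    (void) mode;
    return { /*fattn_supported=*/false };
}

static FlashAttn negotiate_flash_attn(FlashAttn requested) {
    if (requested == FlashAttn::Auto) {
        // optimistically enable, then check whether the graph is actually supported
        ProbeResult info = probe_graph(FlashAttn::Enabled);
        if (!info.fattn_supported) {
            std::printf("flash attention not supported by the backend - disabling\n");
            probe_graph(FlashAttn::Disabled); // re-plan the graph without it
            return FlashAttn::Disabled;
        }
        return FlashAttn::Enabled;
    }
    // explicit request: keep it, but warn if the probe fails
    ProbeResult info = probe_graph(requested);
    if (requested == FlashAttn::Enabled && !info.fattn_supported) {
        std::printf("flash attention requested but unsupported - expect a slow fallback\n");
    }
    return requested;
}

int main() {
    FlashAttn chosen = negotiate_flash_attn(FlashAttn::Auto);
    std::printf("flash attention is %s\n",
                chosen == FlashAttn::Enabled ? "enabled" : "disabled");
    return 0;
}
```
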
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() { return "<__media__>"; } +static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) { + switch (flash_attn_type) { + case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO; + case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED; + case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED; + } + return CLIP_FLASH_ATTN_TYPE_AUTO; +} + mtmd_context_params mtmd_context_params_default() { mtmd_context_params params; params.use_gpu = true; @@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() { params.verbosity = GGML_LOG_LEVEL_INFO; params.image_marker = MTMD_DEFAULT_IMAGE_MARKER; params.media_marker = mtmd_default_marker(); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; return params; } @@ -164,6 +173,7 @@ struct mtmd_context { clip_context_params ctx_clip_params; ctx_clip_params.use_gpu = ctx_params.use_gpu; ctx_clip_params.verbosity = ctx_params.verbosity; + ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type); auto res = clip_init(mmproj_fname, ctx_clip_params); ctx_v = res.ctx_v; ctx_a = res.ctx_a; @@ -378,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname, } void mtmd_free(mtmd_context * ctx) { - if (ctx) { - delete ctx; - } + delete ctx; } struct mtmd_tokenizer { diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 0b5d2ba0c7634..4ae1925bcdfb6 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -82,6 +82,7 @@ struct mtmd_context_params { enum ggml_log_level verbosity; const char * image_marker; // deprecated, use media_marker instead const char * media_marker; + enum llama_flash_attn_type flash_attn_type; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 92d30664e41f4..a9bef35189b3a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2407,7 +2407,7 @@ struct server_context { params_dft.devices = params_base.speculative.devices; params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; params_dft.cache_type_k = params_base.speculative.cache_type_k; @@ -2456,6 +2456,7 @@ struct server_context { mparams.print_timings = false; mparams.n_threads = params_base.cpuparams.n_threads; mparams.verbosity = params_base.verbosity > 0 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mparams.flash_attn_type = params_base.flash_attn_type; mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); @@ -2495,10 +2496,16 @@ struct server_context { } void init() { - const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + const int n_ctx_train = llama_model_n_ctx_train(model); + + int n_ctx_slot = llama_n_ctx_seq(ctx); + if (n_ctx_slot > n_ctx_train) { + SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); + n_ctx_slot = n_ctx_train; + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -2527,7 +2534,7 @@ struct server_context { } } - SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -2699,6 +2706,39 @@ struct server_context { return ret; } + // return true if at least one slot has been purged + // TODO: improve logic + // - smarter decision which slot to purge (LRU or longest prompt?) + // - move slot to level 2 cache instead of removing? + // - instead of purging, try to store and resume later? + bool try_purge_idle_slots() { + bool res = false; + + if (!params_base.kv_unified) { + return res; + } + + for (auto & slot : slots) { + if (slot.is_processing()) { + continue; + } + + if (slot.prompt.n_tokens() > 0) { + SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); + + res = true; + + // purge slots one by one + break; + } + } + + return res; + } + bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); @@ -3635,9 +3675,10 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // next, batch any pending prompts without exceeding n_batch - float alora_scale = -1.0f; + float alora_scale = -1.0f; size_t alora_disabled_id = 0; + + // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one @@ -3914,8 +3955,11 @@ struct server_context { // truncate any tokens that are beyond n_past for this slot const llama_pos p0 = slot.prompt.tokens.pos_next(); + + SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0); + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left @@ -3924,8 +3968,6 @@ struct server_context { slot.prompt.tokens.clear(); } - SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - // check if we should process the image if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image @@ -4126,6 +4168,8 @@ struct server_context { std::string err; if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to 
                    err = "Context size has been exceeded.";
                 }
 
@@ -4141,17 +4185,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
+
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4391,6 +4441,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();
 
     // struct that contains llama context and inference
@@ -4944,7 +5003,7 @@ int main(int argc, char ** argv) {
            // Everything else, including multimodal completions.
            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
-        const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+        const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
         tasks.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             auto n_prompt_tokens = inputs[i].size();
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index d56d3d5f178b8..392e0efecdbbd 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
 def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index 00ba78cf67c09..3c0ce98973f4b 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -368,6 +368,37 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "n_ctx,n_slots,n_predict_vals,expected_success",
+    [
+        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
+        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
+        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
+        (256, 4, [90, 90, 40, 75], [True, True, True, True]),
+    ],
+)
+def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
+    global server
+    server.n_slots = n_slots
+    server.kv_unified = True
+    server.n_ctx = n_ctx
+    server.start()
+    prompt = "A"
+    tasks = []
+    for n_predict in n_predict_vals:
+        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
+    results = parallel_function_calls(tasks)
+    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
+        if expect_ok:
+            assert res.status_code == 200
+            assert "content" in res.body
+            if "timings" in res.body:
+                assert res.body["timings"]["predicted_n"] == n_predict
+        else:
+            assert res.status_code == 500
+            assert "content" not in res.body
+
+
 @pytest.mark.parametrize(
     "prompt,n_predict,response_fields",
     [
diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py
index 73dacdae812b8..cd1a391b4adbc 100644
--- a/tools/server/tests/unit/test_infill.py
+++ b/tools/server/tests/unit/test_infill.py
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index 4ba3d43c33044..da703c4c51a15 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -78,6 +78,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    kv_unified: bool | None = False
     server_slots: bool | None = False
     pooling: str | None = None
     draft: int | None = None
@@ -159,6 +160,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.kv_unified:
+            server_args.append("--kv-unified")
         if self.server_slots:
             server_args.append("--slots")
         else:
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index b6198edfc487c..2bce2f4a47af9 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1212,7 +1212,7 @@ struct server_tokens {
             for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
                 auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1244,6 +1244,7 @@ struct server_tokens {
     }
 
     void clear() {
+        map_idx_to_media.clear();
         tokens.clear();
     }