From 1c697fa487b91c2691ce0737e85a6802f7928789 Mon Sep 17 00:00:00 2001
From: KKkai <1640576073@qq.com>
Date: Thu, 30 Oct 2025 00:08:42 +0800
Subject: [PATCH 1/7] add minicpmo model support

---
 examples/CMakeLists.txt                       |  10 +-
 examples/minicpm_o/CMakeLists.txt             |  22 +-
 examples/minicpm_o/config_minicpm_o.json      |  37 +
 examples/minicpm_o/main.cpp                   | 173 +++-
 examples/minicpm_o/mainllm.cpp                | 100 ++
 mllm/backends/cpu/kernels/Kernels.hpp         |   1 +
 .../minicpm_o2_6/configuration_minicpmo.hpp   |  99 ++
 .../image_preprocessor_minicpmo.hpp           | 382 ++++++++
 .../models/minicpm_o2_6/modeling_minicpmo.hpp | 361 +++++++
 .../modeling_qwen2vl_for_minicpmo.hpp         | 925 ++++++++++++++++++
 .../minicpm_o2_6/modeling_resampler.hpp       | 407 ++++++++
 mllm/models/minicpm_o2_6/modeling_siglip.hpp  | 452 +++++++++
 .../minicpm_o2_6/tokenization_minicpmo.hpp    | 496 ++++++++++
 mllm/nn/Nn.hpp                                |   1 +
 tasks/build_osx_apple_silicon.yaml            |   3 +-
 tasks/build_osx_apple_silicon_dbg.yaml        |   4 +-
 16 files changed, 3446 insertions(+), 27 deletions(-)
 create mode 100644 examples/minicpm_o/config_minicpm_o.json
 create mode 100644 examples/minicpm_o/mainllm.cpp
 create mode 100644 mllm/models/minicpm_o2_6/configuration_minicpmo.hpp
 create mode 100644 mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp
 create mode 100644 mllm/models/minicpm_o2_6/modeling_minicpmo.hpp
 create mode 100644 mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp
 create mode 100644 mllm/models/minicpm_o2_6/modeling_resampler.hpp
 create mode 100644 mllm/models/minicpm_o2_6/modeling_siglip.hpp
 create mode 100644 mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 9f8fd0332..a0bef0240 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,8 +1,8 @@
-add_subdirectory(qwen2vl)
-add_subdirectory(qwen2vl_tracer)
-add_subdirectory(qwen2_5vl)
-add_subdirectory(qwen2_5vl_tracer)
-add_subdirectory(llama)
+# add_subdirectory(qwen2vl)
+# add_subdirectory(qwen2vl_tracer)
+# add_subdirectory(qwen2_5vl)
+# add_subdirectory(qwen2_5vl_tracer)
+# add_subdirectory(llama)
 add_subdirectory(minicpm_o)
 add_subdirectory(qwen3)
 add_subdirectory(qwen3_service)
diff --git a/examples/minicpm_o/CMakeLists.txt b/examples/minicpm_o/CMakeLists.txt
index 95d4b89c5..1f9f8e1a9 100644
--- a/examples/minicpm_o/CMakeLists.txt
+++ b/examples/minicpm_o/CMakeLists.txt
@@ -1,3 +1,19 @@
-add_executable(mllm-minicpm-o main.cpp)
-target_link_libraries(mllm-minicpm-o PRIVATE MllmRT MllmCPUBackend)
-target_include_directories(mllm-minicpm-o PRIVATE ${MLLM_INCLUDE_DIR})
+cmake_minimum_required(VERSION 3.10)
+include_directories($ENV{HOME}/local/include)
+link_directories($ENV{HOME}/local/lib)
+
+add_executable(main_minicpm_o main.cpp)
+target_link_libraries(main_minicpm_o PRIVATE MllmRT MllmCPUBackend cnpy z)
+target_include_directories(main_minicpm_o PRIVATE ${MLLM_INCLUDE_DIR})
+
+add_executable(main_minicpm_o2 mainllm.cpp)
+target_link_libraries(main_minicpm_o2 PRIVATE MllmRT MllmCPUBackend cnpy z)
+target_include_directories(main_minicpm_o2 PRIVATE ${MLLM_INCLUDE_DIR})
+
+add_executable(main_minicpm_dbg main_dbg.cpp)
+target_link_libraries(main_minicpm_dbg PRIVATE MllmRT MllmCPUBackend cnpy z)
+target_include_directories(main_minicpm_dbg PRIVATE ${MLLM_INCLUDE_DIR})
+
+add_executable(tokenizer_test tokenizer_test.cpp)
+target_link_libraries(tokenizer_test PRIVATE MllmRT MllmCPUBackend)
+target_include_directories(tokenizer_test PRIVATE ${MLLM_INCLUDE_DIR})
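+
+# NOTE: the two hard-coded $ENV{HOME}/local paths above assume cnpy and zlib
+# live under the user's home prefix. A more portable lookup could be sketched
+# as follows (variable names illustrative, not part of this patch):
+#   find_path(CNPY_INCLUDE_DIR cnpy.h)
+#   find_library(CNPY_LIBRARY cnpy)
+#   target_include_directories(main_minicpm_o PRIVATE ${CNPY_INCLUDE_DIR})
+#   target_link_libraries(main_minicpm_o PRIVATE ${CNPY_LIBRARY} z)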
diff --git a/examples/minicpm_o/config_minicpm_o.json b/examples/minicpm_o/config_minicpm_o.json
new file mode 100644
index 000000000..0064ae1ff
--- /dev/null
+++ b/examples/minicpm_o/config_minicpm_o.json
@@ -0,0 +1,37 @@
+{
+  "vision_hidden_size": 1152,
+  "vision_intermediate_size": 4304,
+  "vision_num_hidden_layers": 27,
+  "vision_num_attention_heads": 16,
+  "vision_num_channels": 3,
+  "vision_image_size": 980,
+  "vision_patch_size": 14,
+
+  "hidden_size": 3584,
+  "intermediate_size": 18944,
+  "num_attention_heads": 28,
+  "num_key_value_heads": 4,
+  "num_hidden_layers": 28,
+  "max_position_embeddings": 32768,
+  "rms_norm_eps": 1e-06,
+  "vocab_size": 151700,
+
+  "query_num": 64,
+
+  "audio_hidden_size": 1024,
+  "audio_num_hidden_layers": 24,
+  "audio_num_attention_heads": 16,
+  "audio_max_position_embeddings": 1500,
+  "audio_chunk_length": 1.0,
+  "audio_pool_step": 2,
+
+  "tts_llm_dim": 3584,
+
+  "max_cache_length": 8192,
+  "eos_token_id": 151645,
+  "bos_token_id": 151643,
+  "rope_theta": 1000000.0,
+  "tie_word_embeddings": true,
+
+  "linear_impl_type": "default"
+}
diff --git a/examples/minicpm_o/main.cpp b/examples/minicpm_o/main.cpp
index aec562995..a7d9fd32c 100644
--- a/examples/minicpm_o/main.cpp
+++ b/examples/minicpm_o/main.cpp
@@ -1,16 +1,159 @@
-
+#include <iostream>
 #include <string>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-#include <...>
-
-using namespace mllm;  // NOLINT
-
-MLLM_MAIN({NYI("leave empty")});
+#include "mllm/mllm.hpp"
+#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp"
+// #include "mllm/models/minicpm_o2_6/modeling_minicpmo.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp"
+#include "mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp"
+#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp"
+#include "mllm/utils/AnyValue.hpp"
+#include "mllm/preprocessor/visual/Image.hpp"
+
+using mllm::Argparse;
+
+MLLM_MAIN({
+  mllm::Logger::level() = mllm::LogLevel::kError;
+
+  auto& help = Argparse::add<bool>("-h|--help").help("Show help message");
+  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model path").required(true);
+  auto& model_version = Argparse::add<std::string>("-mv|--model_version").help("Model version").required(true);
+  auto& tokenizer_path = Argparse::add<std::string>("-t|--tokenizer_path").help("Tokenizer directory").required(true);
+  auto& config_path = Argparse::add<std::string>("-c|--config_path").help("Config path").required(true);
+  // RUN: ./main_minicpm_o -m ../../models/minicpm-o-2_6.mllm -mv v1 -t ../../tokenizer/MiniCPM-o-2_6/tokenizer.json -c ../../examples/minicpm_o/config_minicpm_o.json
+
+  Argparse::parse(argc, argv);
+
+#ifdef MLLM_PERFETTO_ENABLE
+  mllm::perf::start();
+#endif
+
+  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
+  if (model_version.get() == "v1") {
+    file_version = mllm::ModelFileVersion::kV1;
+  } else if (model_version.get() == "v2") {
+    file_version = mllm::ModelFileVersion::kV2;
+  }
+
+  if (help.isSet()) {
+    Argparse::printHelp();
+    mllm::shutdownContext();
+    return 0;
+  }
+  {
+    auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get());
+    auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get());
+
+    mllm::models::minicpmo::MiniCPMOMessage message;
+    message.prompt = "You are now a palace eunuch and this man is the emperor; you must flatter him sincerely.";
+    message.img_file_path = "/Users/kkkai/Desktop/pics.jpg";
+    auto output = minicpmo_tokenizer.convertMessage(message);
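+    // Sketch of what convertMessage() is expected to hand back here (the keys
+    // are the ones printed below; the shapes are assumptions based on the
+    // debug dumps later in this patch, e.g. 10 slices x 1036 patches):
+    //   input_ids:    [1, seq_len]      prompt with image/slice placeholder spans
+    //   pixel_values: per-slice patch strips, fed to SiglipVisionModel
+    //   tgt_sizes:    [num_slices, 2]   per-slice patch grid (h, w)
+    //   image_bounds: [num_slices, 2]   (start, end) token positions per span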
mllm::print(output["tgt_sizes"].shape()); + mllm::print(output["image_bounds"].shape()); + + auto param = mllm::load(model_path.get(), file_version); + auto siglip = mllm::models::minicpmo::SiglipVisionModel("vpm", minicpmo_cfg); + siglip.load(param); + auto res = siglip(output["pixel_values"], output["tgt_sizes"])[0]; + auto resampler = mllm::models::minicpmo::Resampler("resampler", 64, 3584, 28, 1152); + resampler.load(param); + auto res2 = resampler(res, output["tgt_sizes"])[0]; + + // auto minicpmo = mllm::models::minicpmo::MiniCPMOForCausalLM(minicpmo_cfg); + + // // Load model weights + // auto param = mllm::load(model_path.get(), file_version); + // minicpmo.load(param); + + // fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI "); + // fmt::print("Enter 'exit' or 'quit' to end the session\n"); + // fmt::print("Supported modes: text, image+text, audio+text, multimodal\n\n"); + + // while (true) { + // std::string mode; + // fmt::print("Mode (text/image/audio/multi) or 'exit': "); + // std::getline(std::cin, mode); + + // if (mode == "exit" || mode == "quit") { + // break; + // } + + // mllm::models::minicpmo::MiniCPMOInput input; + + // // Handle different input modes + // if (mode == "image" || mode == "multi") { + // std::string image_path; + // fmt::print("Image path: "); + // std::getline(std::cin, image_path); + // if (!image_path.empty()) { + // input.img_file_path = image_path; + // } + // } + + // if (mode == "audio" || mode == "multi") { + // std::string audio_path; + // fmt::print("Audio path: "); + // std::getline(std::cin, audio_path); + // if (!audio_path.empty()) { + // input.audio_file_path = audio_path; + // } + // } + + // std::string prompt_text; + // fmt::print("Prompt text: "); + // std::getline(std::cin, prompt_text); + // input.prompt = prompt_text; + + // try { + // fmt::print("Processing...\n"); + + // // Convert input to tokens + // auto input_tokens = minicpmo_tokenizer.convertMessage(input); + + // // Process images if provided + // auto image_tensors = minicpmo_tokenizer.processImages(input); + + // // Process audio if provided + // auto audio_tensors = minicpmo_tokenizer.processAudio(input); + + // fmt::print("\nResponse: "); + + // // TODO: Implement multimodal chat interface + // // For now, use text-only generation + // std::vector token_ids; + // auto input_ptr = input_tokens.ptr(); + // auto seq_len = input_tokens.shape()[1]; + // for (int i = 0; i < seq_len; ++i) { + // token_ids.push_back(input_ptr[i]); + // } + + // // Generate response + // for (auto& step : minicpmo.chat(token_ids)) { + // auto token_str = minicpmo_tokenizer.detokenize(step.cur_token_id); + // std::wcout << token_str << std::flush; + + // // TODO: Check for audio generation tokens + // if (minicpmo_tokenizer.isAudioToken(step.cur_token_id)) { + // fmt::print("\n🔊 [Audio generation triggered - feature not implemented yet]\n"); + // } + // } + + // fmt::print("\n\n"); + + // } catch (const std::exception& e) { + // fmt::print(" Error: {}\n", e.what()); + // } + // } + + // fmt::print("Success!\n"); + } + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::stop(); +#endif + + mllm::shutdownContext(); + return 0; +}) diff --git a/examples/minicpm_o/mainllm.cpp b/examples/minicpm_o/mainllm.cpp new file mode 100644 index 000000000..48712eab3 --- /dev/null +++ b/examples/minicpm_o/mainllm.cpp @@ -0,0 +1,100 @@ +#include +#include +#include "mllm/mllm.hpp" +#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/modeling_minicpmo.hpp" +#include 
"mllm/models/minicpm_o2_6/modeling_resampler.hpp" +#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp" +#include "mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp" +#include "mllm/utils/AnyValue.hpp" +#include "mllm/preprocessor/visual/Image.hpp" +#include "cnpy.h" + +using mllm::Argparse; + +MLLM_MAIN({ + + mllm::Logger::level() = mllm::LogLevel::kError; + //mllm::setPrintMaxElementsPerDim(1000); + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + // RUN: ./main_minicpm_o2 -m ../../models/minicpm-o-2_6.mllm -mv v1 -t ../../tokenizer/MiniCPM-o-2_6/tokenizer.json -c ../../examples/minicpm_o/config_minicpm_o.json + + Argparse::parse(argc, argv); + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::start(); +#endif + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + { + auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get()); + auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get()); + auto minicpmo = mllm::models::minicpmo::MiniCPMOForCausalLM(minicpmo_cfg); + + auto param = mllm::load(model_path.get(), file_version); + minicpmo.llm_.llm.load(param); + minicpmo.vpm_.load(param); + minicpmo.resampler_.load(param); + //minicpmo.audio_proj_.load(param); + //minicpmo.tts_proj_.load(param); + + fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n"); + + std::string image_path = "/Users/kkkai/Desktop/pics.jpg"; + std::string prompt_text = "描述图片中物体"; + mllm::models::minicpmo::MiniCPMOMessage message; + message.prompt = prompt_text; + message.img_file_path = image_path; + + // fmt::print("📷 Image path (or 'exit/quit'): "); + // std::getline(std::cin, image_path); + // if (image_path == "exit" || image_path == "quit") { return 0; } + // fmt::print("💬 Prompt text: "); + // std::getline(std::cin, prompt_text); + + fmt::print("Processing...\n"); + auto inputs = minicpmo_tokenizer.convertMessage(message); + + fmt::print("\nResponse: "); + + int token_count = 0; + for(auto& step : minicpmo.chat(inputs)){ + auto token_str = minicpmo_tokenizer.detokenize(step.cur_token_id); + std::wcout<< token_str << std::flush; + + token_count++; + if(token_count >= 50) break; // Limit output for debugging + } + + fmt::print("\n{}\n", std::string(60, '-')); + + + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::stop(); + mllm::perf::saveReport("minicpmo.perf"); +#endif + + mllm::memoryReport(); + mllm::shutdownContext(); + return 0; + } +}) diff --git a/mllm/backends/cpu/kernels/Kernels.hpp b/mllm/backends/cpu/kernels/Kernels.hpp index 026cc84a8..e99ffdd73 100644 --- a/mllm/backends/cpu/kernels/Kernels.hpp +++ b/mllm/backends/cpu/kernels/Kernels.hpp @@ -27,6 +27,7 @@ #include "mllm/backends/cpu/kernels/arm/softmax.hpp" // IWYU pragma: export #include 
"mllm/backends/cpu/kernels/arm/rmsnorm.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/gelu.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/conv3d.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/linear/kai.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/relu.hpp" // IWYU pragma: export diff --git a/mllm/models/minicpm_o2_6/configuration_minicpmo.hpp b/mllm/models/minicpm_o2_6/configuration_minicpmo.hpp new file mode 100644 index 000000000..841b2ec6e --- /dev/null +++ b/mllm/models/minicpm_o2_6/configuration_minicpmo.hpp @@ -0,0 +1,99 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::minicpmo { + +struct MiniCPMOConfig : protected ConfigFile { + MiniCPMOConfig() = default; + + explicit MiniCPMOConfig(const std::string& file_path) : ConfigFile(file_path) { + // Vision Config + vision_hidden_size = data()["vision_hidden_size"]; + vision_intermediate_size = data()["vision_intermediate_size"]; + vision_num_hidden_layers = data()["vision_num_hidden_layers"]; + vision_num_attention_heads = data()["vision_num_attention_heads"]; + vision_num_channels = data()["vision_num_channels"]; + vision_image_size = data()["vision_image_size"]; + vision_patch_size = data()["vision_patch_size"]; + + // LLM Config (Qwen2 based) + hidden_size = data()["hidden_size"]; + intermediate_size = data()["intermediate_size"]; + num_attention_heads = data()["num_attention_heads"]; + num_key_value_heads = data()["num_key_value_heads"]; + num_hidden_layers = data()["num_hidden_layers"]; + max_position_embeddings = data()["max_position_embeddings"]; + rms_norm_eps = data()["rms_norm_eps"]; + vocab_size = data()["vocab_size"]; + + // Resampler Config + query_num = data()["query_num"]; + + // Audio Config (Whisper based) + audio_hidden_size = data()["audio_hidden_size"]; + audio_num_hidden_layers = data()["audio_num_hidden_layers"]; + audio_num_attention_heads = data()["audio_num_attention_heads"]; + audio_max_position_embeddings = data()["audio_max_position_embeddings"]; + audio_chunk_length = data()["audio_chunk_length"]; + audio_pool_step = data()["audio_pool_step"]; + + // TTS Config + tts_llm_dim = data()["tts_llm_dim"]; + + // Common Config + max_cache_length = data()["max_cache_length"]; + eos_token_id = data()["eos_token_id"]; + rope_theta = data()["rope_theta"]; + tie_word_embeddings = data()["tie_word_embeddings"]; + + linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); + } + + // Vision Config (SigLIP) + int32_t vision_hidden_size = 1152; + int32_t vision_intermediate_size = 4304; + int32_t vision_num_hidden_layers = 27; + int32_t vision_num_attention_heads = 16; + int32_t vision_num_channels = 3; + int32_t vision_image_size = 980; + int32_t vision_patch_size = 14; + + // LLM Config (Qwen2.5-7B) + int32_t hidden_size = 3584; + int32_t intermediate_size = 18944; + int32_t num_attention_heads = 28; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 28; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06; + int32_t vocab_size = 151700; + + // Resampler Config + int32_t query_num = 64; + + // Audio Config (Whisper) + int32_t audio_hidden_size = 1024; + int32_t audio_num_hidden_layers = 24; + int32_t audio_num_attention_heads = 16; + int32_t audio_max_position_embeddings = 
+  float audio_chunk_length = 1.0;
+  int32_t audio_pool_step = 2;
+
+  // TTS Config (adjust as the actual TTS components are added)
+  int32_t tts_llm_dim = 3584;
+
+  // Common Config
+  int32_t max_cache_length = 8192;
+  int64_t eos_token_id = 151645;
+  int64_t bos_token_id = 151643;
+  float rope_theta = 1000000.0;
+  bool tie_word_embeddings = false;
+
+  aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault;
+};
+
+}  // namespace mllm::models::minicpmo
diff --git a/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp b/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp
new file mode 100644
index 000000000..2f161b654
--- /dev/null
+++ b/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp
@@ -0,0 +1,382 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "mllm/core/DataTypes.hpp"
+#include "mllm/mllm.hpp"
+#include "mllm/nn/Module.hpp"
+#include "mllm/core/Tensor.hpp"
+#include "mllm/utils/Common.hpp"
+#include "mllm/preprocessor/visual/Image.hpp"
+
+namespace mllm::models::minicpmo {
+
+// Utility functions for image slicing (similar to MiniCPMV)
+class ImageSliceProcessor {
+ public:
+  explicit ImageSliceProcessor(int max_slice_nums = 9, int scale_resolution = 448, int patch_size = 14)
+      : max_slice_nums_(max_slice_nums), scale_resolution_(scale_resolution), patch_size_(patch_size) {}
+
+  std::tuple<Image, std::vector<std::vector<Image>>, std::vector<int>> slice_image(Image img, bool never_split = false) {
+    int original_width = img.w();
+    int original_height = img.h();
+
+    auto best_grid = get_sliced_grid(original_width, original_height, never_split);
+    std::vector<std::vector<Image>> patches;
+    Image source_img;
+    if (best_grid.empty()) {
+      // No slicing needed, just upsample
+      auto best_size = find_best_resize(original_width, original_height, true);
+      source_img = img.resize(best_size.first, best_size.second);
+    } else {
+      // Source image: down-sample and ensure it divides by patch_size
+      auto best_resize = find_best_resize(original_width, original_height);
+      source_img = img.resize(best_resize.first, best_resize.second);
+      // Refine image for slicing
+      auto refine_size = get_refine_size(original_width, original_height, best_grid);
+      auto refine_image = img.resize(refine_size.first, refine_size.second);
+      patches = split_to_patches(refine_image, best_grid);
+    }
+
+    return std::make_tuple(source_img, patches, best_grid);
+  }
+
+  std::vector<std::vector<Image>> split_to_patches(Image& image, const std::vector<int>& grid) {
+    std::vector<std::vector<Image>> patches;
+    int width = image.w();
+    int height = image.h();
+    int grid_x = width / grid[0];
+    int grid_y = height / grid[1];
+
+    for (int i = 0; i < height; i += grid_y) {
+      std::vector<Image> row_patches;
+      for (int j = 0; j < width; j += grid_x) {
+        // Calculate crop region
+        int crop_width = std::min(grid_x, width - j);
+        int crop_height = std::min(grid_y, height - i);
+
+        // Create patch by cropping the region
+        auto patch = crop_image(image, j, i, crop_width, crop_height);
+        row_patches.push_back(patch);
+      }
+      patches.push_back(row_patches);
+    }
+
+    return patches;
+  }
+
+  std::vector<int> get_sliced_grid(int width, int height, bool never_split = false) {
+    float log_ratio = std::log((float)width / height);
+    float ratio = (float)(width * height) / (scale_resolution_ * scale_resolution_);
+    int multiple = std::min((int)std::ceil(ratio), max_slice_nums_);
+    if (multiple <= 1 || never_split) { return {}; }
+    std::vector<int> candidate_nums;
+    for (int i : {multiple - 1, multiple, multiple + 1}) {
+      if (i > 1 && i <= max_slice_nums_) { candidate_nums.push_back(i); }
+    }
+
+    std::vector<std::vector<int>> candidate_grids;
+    for (int split_num : candidate_nums) {
+      for (int m = 1; m <= split_num; ++m) {
+        if (split_num % m == 0) { candidate_grids.push_back({m, split_num / m}); }
+      }
+    }
+
+    std::vector<int> best_grid = {1, 1};
+    float min_error = INFINITY;
+    for (auto& grid : candidate_grids) {
+      float error = std::abs(log_ratio - std::log((float)grid[0] / grid[1]));
+      if (error < min_error) {
+        best_grid = grid;
+        min_error = error;
+      }
+    }
+
+    return best_grid;
+  }
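+
+  // Worked example (illustrative): for a 1200x800 input with scale_resolution
+  // 448, ratio = 960000 / 200704 ~= 4.78, so multiple = 5 and the candidate
+  // split counts are {4, 5, 6}. Of all factor-pair grids, {3, 2} gives
+  // log(3/2) == log(1200/800), i.e. zero aspect-ratio error, so
+  // get_sliced_grid() returns {3, 2}: 3 columns by 2 rows of slices.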
+
+  std::pair<int, int> find_best_resize(int width, int height, bool allow_upscale = false) {
+    if ((width * height > scale_resolution_ * scale_resolution_) || allow_upscale) {
+      float r = (float)width / height;
+      int new_height = (int)(scale_resolution_ / std::sqrt(r));
+      int new_width = (int)(new_height * r);
+      width = new_width;
+      height = new_height;
+    }
+    int best_width = ensure_divide(width, patch_size_);
+    int best_height = ensure_divide(height, patch_size_);
+    return {best_width, best_height};
+  }
+
+  int ensure_divide(int length, int divisor) {
+    return std::max(static_cast<int>(std::round(static_cast<float>(length) / divisor)) * divisor, divisor);
+  }
+
+  std::pair<int, int> get_refine_size(int width, int height, std::vector<int> grid) {
+    int grid_x = grid[0];
+    int grid_y = grid[1];
+    int refine_width = ensure_divide(width, grid_x);
+    int refine_height = ensure_divide(height, grid_y);
+    int grid_width = refine_width / grid_x;
+    int grid_height = refine_height / grid_y;
+    auto best_grid_size = find_best_resize(grid_width, grid_height, true);
+    return {best_grid_size.first * grid_x, best_grid_size.second * grid_y};
+  }
+
+  // Crop an image region (used by split_to_patches)
+  Image crop_image(Image& image, int x, int y, int crop_width, int crop_height) {
+    // Get source image properties
+    int src_width = image.w();
+    int src_height = image.h();
+    int src_channels = image.c();
+
+    // Ensure crop bounds are valid
+    x = std::max(0, std::min(x, src_width - 1));
+    y = std::max(0, std::min(y, src_height - 1));
+    crop_width = std::min(crop_width, src_width - x);
+    crop_height = std::min(crop_height, src_height - y);
+
+    unsigned char* src_data = image.ptr();
+    unsigned char* crop_data = new unsigned char[crop_width * crop_height * src_channels];
+
+    // Copy pixel data row by row
+    for (int row = 0; row < crop_height; ++row) {
+      int src_row_offset = ((y + row) * src_width + x) * src_channels;
+      int crop_row_offset = row * crop_width * src_channels;
+      std::memcpy(crop_data + crop_row_offset, src_data + src_row_offset, crop_width * src_channels);
+    }
+
+    // Round-trip through a temporary PNG to construct the cropped Image
+    std::string temp_path = "/tmp/crop_" + std::to_string(rand()) + ".png";
+    stbi_write_png(temp_path.c_str(), crop_width, crop_height, src_channels, crop_data, crop_width * src_channels);
+
+    Image cropped_image = Image::open(temp_path);
+
+    delete[] crop_data;
+    std::remove(temp_path.c_str());
+
+    return cropped_image;
+  }
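+
+  // Continuing the 1200x800 example: get_refine_size(1200, 800, {3, 2}) yields
+  // 400x400 cells; find_best_resize(400, 400, true) upscales each cell to
+  // 448x448 (already a multiple of patch_size 14), so the refined image is
+  // 1344x896 and split_to_patches() produces 2 rows of 3 crops of 448x448.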
+
+ private:
+  int max_slice_nums_;
+  int scale_resolution_;
+  int patch_size_;
+};
+
+class MiniCPMOImageProcessor {
+ public:
+  explicit MiniCPMOImageProcessor(int patch_size = 14, int image_size = 980,
+                                  float mean_0 = 0.5, float mean_1 = 0.5, float mean_2 = 0.5,
+                                  float std_0 = 0.5, float std_1 = 0.5, float std_2 = 0.5)
+      : image_size_(image_size),
+        patch_size_(patch_size),
+        mean_{mean_0, mean_1, mean_2},
+        std_{std_0, std_1, std_2},
+        image_slice_processor_(9, 448, patch_size) {}
+
+  std::string get_slice_image_placeholder(std::pair<int, int> image_size, const std::vector<int>& grid,
+                                          int image_idx = 0, bool use_image_id = true) {
+    std::string image_placeholder = "<image>";
+    for (int i = 0; i < image_feature_size; i++) { image_placeholder += "<unk>"; }
+    image_placeholder += "</image>";
+
+    std::string final_placeholder;
+    if (use_image_id) {
+      final_placeholder = "<image_id>" + std::to_string(image_idx) + "</image_id>" + image_placeholder;
+    } else {
+      final_placeholder = image_placeholder;
+    }
+
+    if (!grid.empty()) { final_placeholder += get_grid_placeholder(grid); }
+
+    return final_placeholder;
+  }
+
+  std::string get_grid_placeholder(const std::vector<int>& grid) {
+    std::string slice_image_placeholder = "<slice>";
+    for (int i = 0; i < image_feature_size; i++) { slice_image_placeholder += "<unk>"; }
+    slice_image_placeholder += "</slice>";
+
+    int cols = grid[0];
+    int rows = grid[1];
+    std::vector<std::string> slice;
+    for (int i = 0; i < rows; ++i) {
+      std::string line;
+      for (int j = 0; j < cols; ++j) { line += slice_image_placeholder; }
+      slice.push_back(line);
+    }
+
+    std::string result;
+    for (int i = 0; i < slice.size(); ++i) {
+      if (i > 0) { result += "\n"; }
+      result += slice[i];
+    }
+    return result;
+  }
+
+  std::pair<std::vector<int>, std::vector<std::vector<int>>> calc_bounds(const std::vector<int>& input_ids,
+                                                                         preprocessor::BPE& bpe, int max_length = 8192) {
+    std::vector<std::vector<int>> image_bounds;
+    // Get token IDs dynamically from BPE
+    int im_start_id = bpe._lookup_vocab(L"<image>");
+    int im_end_id = bpe._lookup_vocab(L"</image>");
+    int slice_start_id = bpe._lookup_vocab(L"<slice>");
+    int slice_end_id = bpe._lookup_vocab(L"</slice>");
+
+    std::vector<int> image_start_positions;
+    std::vector<int> image_end_positions;
+
+    int seq_len = input_ids.size();
+    for (int i = 0; i < seq_len; ++i) {
+      int token_id = input_ids[i];
+      if (token_id == im_start_id || token_id == slice_start_id) { image_start_positions.push_back(i); }
+      if (token_id == im_end_id || token_id == slice_end_id) { image_end_positions.push_back(i); }
+    }
+    int valid_image_nums = std::max(image_start_positions.size(), image_end_positions.size());
+
+    for (int i = 0; i < valid_image_nums && i < image_start_positions.size() && i < image_end_positions.size(); ++i) {
+      image_bounds.push_back({image_start_positions[i], image_end_positions[i]});
+    }
+
+    return {input_ids, image_bounds};
+  }
+
+  std::tuple<std::vector<Tensor>, std::vector<std::pair<int, int>>, std::vector<std::vector<int>>, std::vector<int>>
+  process(const std::string& image_path, int max_slice_nums = 9) {
+    auto img = Image::open(image_path);
+    std::pair<int, int> original_size = {img.w(), img.h()};
+    auto [source_image, patches, grid] = image_slice_processor_.slice_image(img);
+    std::vector<Image> slice_images;
+    slice_images.push_back(source_image);
+
+    if (!patches.empty()) {
+      for (const auto& patch_row : patches) {
+        for (const auto& patch : patch_row) { slice_images.push_back(patch); }
+      }
+    }
+    std::vector<Tensor> processed_tensors;
+    std::vector<std::vector<int>> tgt_sizes;
+
+    for (auto& slice_img : slice_images) {
+      auto tensor = slice_img.tensor();    // [H, W, C]
+      tensor = tensor.permute({2, 0, 1});  // [C, H, W]
+      tensor = tensor * (1.f / 255.f);
+      normalize_tensor(tensor);
+      auto reshaped_tensor = reshape_by_patch(tensor);
+      processed_tensors.push_back(reshaped_tensor);
+
+      // Calculate target size (patch dimensions)
+      int patch_h = tensor.shape()[1] / patch_size_;
+      int patch_w = tensor.shape()[2] / patch_size_;
+      tgt_sizes.push_back({patch_h, patch_w});
+    }
+    return std::make_tuple(processed_tensors, std::vector<std::pair<int, int>>{original_size}, tgt_sizes, grid);
+  }
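+
+  // Placeholder example (assuming the MiniCPM-V style special tokens looked up
+  // in calc_bounds() above): for image_idx 0, grid {3, 2} and
+  // image_feature_size 64, the text inserted into the prompt is
+  //   <image_id>0</image_id><image><unk>...x64...</image>
+  // followed by two "\n"-separated rows, each holding three
+  //   <slice><unk>...x64...</slice>
+  // blocks, one per 448x448 crop.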
+
+ private:
+  void normalize_tensor(Tensor& tensor) {
+    auto tensor_ptr = tensor.ptr<float>();
+    int channels = tensor.shape()[0];
+    int height = tensor.shape()[1];
+    int width = tensor.shape()[2];
+
+    for (int c = 0; c < channels; ++c) {
+      float mean = mean_[c];
+      float std = std_[c];
+
+      for (int h = 0; h < height; ++h) {
+        for (int w = 0; w < width; ++w) {
+          int idx = c * height * width + h * width + w;
+          tensor_ptr[idx] = (tensor_ptr[idx] - mean) / std;
+        }
+      }
+    }
+  }
+
+  Tensor reshape_by_patch(Tensor& input_tensor) {
+    // Input: [C, H, W], Output: [C, patch_size, total_patches * patch_size]
+    int channels = input_tensor.shape()[0];
+    int height = input_tensor.shape()[1];
+    int width = input_tensor.shape()[2];
+
+    int num_patches_h = height / patch_size_;
+    int num_patches_w = width / patch_size_;
+    int total_patches = num_patches_h * num_patches_w;
+
+    auto output = Tensor::empty({channels, patch_size_, total_patches * patch_size_}, kFloat32).alloc();
+
+    for (int c = 0; c < channels; ++c) {
+      for (int ph = 0; ph < num_patches_h; ++ph) {
+        for (int pw = 0; pw < num_patches_w; ++pw) {
+          int patch_idx = ph * num_patches_w + pw;
+          int start_h = ph * patch_size_;
+          int start_w = pw * patch_size_;
+
+          for (int kh = 0; kh < patch_size_; ++kh) {
+            for (int kw = 0; kw < patch_size_; ++kw) {
+              int img_h = start_h + kh;
+              int img_w = start_w + kw;
+              int output_col = patch_idx * patch_size_ + kw;
+
+              *output.offsettedPtr<float>({c, kh, output_col}) = input_tensor.at<float>({c, img_h, img_w});
+            }
+          }
+        }
+      }
+    }
+
+    return output;
+  }
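+
+  // Shape example: a 448x448 slice arrives as [3, 448, 448]; with patch_size
+  // 14 that is 32x32 = 1024 patches, so reshape_by_patch() returns
+  // [3, 14, 1024 * 14] = [3, 14, 14336], one 14x14 patch per column block.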
+
+ private:
+  int image_size_;
+  int patch_size_;
+  int image_feature_size = 64;  // corresponds to query_num
+  std::array<float, 3> mean_;
+  std::array<float, 3> std_;
+  ImageSliceProcessor image_slice_processor_;
+};
+
+}  // namespace mllm::models::minicpmo
diff --git a/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp b/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp
new file mode 100644
index 000000000..2233f34dd
--- /dev/null
+++ b/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp
@@ -0,0 +1,361 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include "mllm/mllm.hpp"
+#include "mllm/nn/Module.hpp"
+#include "mllm/nn/Nn.hpp"
+#include "mllm/nn/Functional.hpp"
+#include "mllm/nn/lmcache/StaticCache.hpp"
+#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp"
+#include "mllm/models/ARGeneration.hpp"
+#include "mllm/utils/Enumerate.hpp"
+#include "cnpy.h"
+
+namespace mllm::models::minicpmo {
+
+using namespace mllm::nn;
+
+// Audio projection layer: maps audio features into the text embedding space
+class AudioProjectionLayer : public Module {
+ public:
+  AudioProjectionLayer() = default;
+
+  AudioProjectionLayer(const std::string& name, int32_t input_dim, int32_t hidden_dim, int32_t output_dim)
+      : Module(name) {
+    linear1_ = reg<Linear>("linear1", input_dim, hidden_dim, true);
+    relu_ = reg<ReLU>("relu");
+    linear2_ = reg<Linear>("linear2", hidden_dim, output_dim, true);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto x = inputs[0];
+    x = linear1_(x);
+    x = relu_(x);
+    x = linear2_(x);
+    return {x};
+  }
+
+ private:
+  Linear linear1_;
+  ReLU relu_;
+  Linear linear2_;
+};
+
+// TTS feature projector
+class TTSProjector : public Module {
+ public:
+  TTSProjector() = default;
+
+  TTSProjector(const std::string& name, int32_t input_dim, int32_t hidden_dim, int32_t output_dim)
+      : Module(name) {
+    linear1_ = reg<Linear>("linear1", input_dim, hidden_dim, true);
+    relu_ = reg<ReLU>("relu");
+    linear2_ = reg<Linear>("linear2", hidden_dim, output_dim, true);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto x = inputs[0];
+    x = linear1_(x);
+    x = relu_(x);
+    x = linear2_(x);
+    return {x};
+  }
+
+ private:
+  Linear linear1_;
+  ReLU relu_;
+  Linear linear2_;
+};
+
+// Main MiniCPM-o model
+class MiniCPMOForCausalLM : public models::ARGeneration {
+ public:
+  explicit MiniCPMOForCausalLM(const MiniCPMOConfig& config)
+      : config_(config),
+        llm_(createLLMConfig(config)),
+        vpm_("vpm", config),
+        resampler_("resampler", config.query_num, config.hidden_size,
+                   config.num_attention_heads, config.vision_hidden_size) {
+    // audio_projection_layer_("audio_projection_layer", config.audio_hidden_size,
+    //                         config.hidden_size, config.hidden_size),
+    // tts_projector_("tts.projector", config.hidden_size,
+    //                config.tts_llm_dim, config.tts_llm_dim) {
+    // Initialize KV cache like Qwen2VL
+    kv_cache_ = nn::StaticCache(config.max_cache_length, config.num_hidden_layers,
+                                config.num_attention_heads,                       // q_heads
+                                config.num_key_value_heads,                       // kv_heads
+                                config.hidden_size / config.num_attention_heads,  // kv_dims
+                                kFloat32,                                         // k_dtype
+                                kFloat32,                                         // v_dtype
+                                kCPU,                                             // device_type
+                                false                                             // use_fa2
+    );
+
+    // Set ARGeneration parameters
+    eos_token_id_ = config.eos_token_id;
+    max_length_ = config.max_cache_length;
+  }
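+
+  // Cache sizing sketch: with config_minicpm_o.json above, this reserves room
+  // for 28 layers x 4 KV heads x head_dim 128 (3584 / 28) over at most 8192
+  // positions, stored as fp32 on CPU.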
inputs.at("position_ids") : Tensor::nil(); + bool is_decode_stage = !prev_position_ids.isNil(); + + + auto input_embeddings = llm_.llm.embedding_(input_ids); + + // Process vision inputs if provided - ONLY in prefill stage + if (!pixel_values.isNil() && !tgt_sizes.isNil() && !is_decode_stage) { + auto vision_outputs = vpm_(pixel_values, tgt_sizes)[0]; + // std::vector vision_outputs_vec; + // vision_outputs_vec.reserve(10*1036*1152); + // for(int i=0;i<10;i++){ + // for(int j=0;j<1036;j++){ + // for(int k=0;k<1152;k++){ + // vision_outputs_vec.push_back(vision_outputs.at({i,j,k})); + // } + // } + // } + // cnpy::npy_save("vision_outputs.npy", + // vision_outputs_vec.data(), + // {10, 1036, 1152}, + // "w"); + auto vision_embeddings = resampler_(vision_outputs, tgt_sizes)[0]; + std::vector vision_embeddings_vec; + vision_embeddings_vec.reserve(10*64*3584); + for(int i=0;i<10;i++){ + for(int j=0;j<64;j++){ + for(int k=0;k<3584;k++){ + vision_embeddings_vec.push_back(vision_embeddings.at({i,j,k})); + } + } + } + cnpy::npy_save("vision_embeddings.npy", + vision_embeddings_vec.data(), + {10, 64, 3584}, + "w"); + mllm::print(vision_embeddings.shape()); + mllm::print(vision_embeddings.at({0,0,0})); + mllm::print(vision_embeddings.at({0,14,175})); + mllm::print(vision_embeddings.at({1,28,2995})); + mllm::print(vision_embeddings.at({1,33,1365})); + mllm::print(vision_embeddings.at({2,8,764})); + mllm::print(vision_embeddings.at({3,49,2222})); + mllm::print(vision_embeddings.at({4,62,2003})); + mllm::print(vision_embeddings.at({5,55,1013})); + mllm::print(vision_embeddings.at({6,19,75})); + mllm::print(vision_embeddings.at({7,21,196})); + mllm::print(vision_embeddings.at({8,50,1997})); + mllm::print(vision_embeddings.at({9,33,2958})); + mllm::print(vision_embeddings.at({8,2,2598})); + mllm::print(vision_embeddings.at({7,5,338})); + mllm::print(vision_embeddings.at({6,41,1157})); + mllm::print(vision_embeddings.at({5,61,2075})); + mllm::print(vision_embeddings.at({4,55,312})); + + if (!image_bounds.isNil()) { + input_embeddings = merge_vision_text_embeddings(input_embeddings, vision_embeddings, image_bounds); + } + } + + // Process audio inputs if provided + // if (!audio_features.isNil()) { + // auto audio_embeddings = encode_audio(audio_features); + // // TODO: Similarly handle audio embedding insertion + // input_embeddings = merge_audio_text_embeddings(input_embeddings, audio_embeddings, sequence); + // } + + // Create position IDs based on stage + Tensor position_ids; + auto seq_len = input_embeddings.shape()[1]; + + if (is_decode_stage) { + // Decode stage: create [3, 1, 1] position_ids for next token + auto last_pos = *prev_position_ids.offsettedPtr({0, 0, prev_position_ids.shape()[2] - 1}); + position_ids = Tensor::empty({3, 1, 1}, kInt64).alloc(); + position_ids.at({0, 0, 0}) = last_pos + 1; + position_ids.at({1, 0, 0}) = last_pos + 1; + position_ids.at({2, 0, 0}) = last_pos + 1; + } else { + // Prefill stage: create [3, 1, seq_len] position_ids for full sequence + position_ids = Tensor::empty({3, 1, seq_len}, kInt64).alloc(); + // Simple sequential position IDs for all dimensions + for (int d = 0; d < 3; d++) { + for (int s = 0; s < seq_len; s++) { + position_ids.at({d, 0, s}) = s; + } + } + } + + + auto head_dim = config_.hidden_size / config_.num_attention_heads; + + auto inv_freq = llm_.llm.getBuffer("inv_freq"); + + std::vector empty_mrope_section; + + auto [llm_embedding_sin, llm_embedding_cos] = + qwen2vl::makeMultimodalPositionEmbedding(position_ids, inv_freq, + 
+
+    auto head_dim = config_.hidden_size / config_.num_attention_heads;
+    auto inv_freq = llm_.llm.getBuffer("inv_freq");
+    std::vector<int32_t> empty_mrope_section;
+
+    auto [llm_embedding_sin, llm_embedding_cos] =
+        qwen2vl::makeMultimodalPositionEmbedding(position_ids, inv_freq, config_.max_position_embeddings,
+                                                 head_dim, empty_mrope_section);
+
+    auto output = llm_.llm(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0];
+
+    ARGenerationOutputPast result = {
+        {"sequence", output},
+        {"position_ids", position_ids},
+    };
+
+    if (!pixel_values.isNil()) { result["pixel_values"] = pixel_values; }
+    if (!tgt_sizes.isNil()) { result["tgt_sizes"] = tgt_sizes; }
+    if (!image_bounds.isNil()) { result["image_bounds"] = image_bounds; }
+
+    return result;
+  }
+
+  // Audio encoding: audio_features -> projection -> text embedding space
+  // Tensor encode_audio(const Tensor& audio_features) {
+  //   // Project audio features to the text embedding space
+  //   auto projected_audio = audio_projection_layer_(audio_features)[0];
+  //   return projected_audio;
+  // }
+
+  // TTS feature generation for audio output
+  // Tensor generate_tts_features(const Tensor& text_hidden_states) {
+  //   // Project text hidden states to the TTS feature space
+  //   auto tts_features = tts_projector_(text_hidden_states)[0];
+  //   return tts_features;
+  // }
+
+  Tensor merge_vision_text_embeddings(Tensor& text_embeddings, Tensor& vision_embeddings, Tensor& image_bounds) {
+    mllm::print(text_embeddings.shape());
+    mllm::print(vision_embeddings.shape());
+    mllm::print(image_bounds);
+    auto batch_size = text_embeddings.shape()[0];        // text_embeddings:   [1, seq_len, embed_dim]
+    auto seq_len = text_embeddings.shape()[1];
+    auto embed_dim = text_embeddings.shape()[2];
+    auto vision_seq_len = vision_embeddings.shape()[1];  // vision_embeddings: [batch_size, query_num, embed_dim]
+
+    if (!image_bounds.isNil() && image_bounds.shape().size() >= 2) {
+      auto num_bounds = vision_embeddings.shape()[0];
+
+      for (int b = 0; b < batch_size; ++b) {
+        for (int bound_idx = 0; bound_idx < num_bounds; ++bound_idx) {
+          int vision_idx = 0;
+          auto start_pos = image_bounds.at<int64_t>({bound_idx, 0}) + 1;
+          auto end_pos = image_bounds.at<int64_t>({bound_idx, 1}) - 1;
+          for (int pos = start_pos; pos <= end_pos && vision_idx < vision_seq_len; ++pos, ++vision_idx) {
+            for (int d = 0; d < embed_dim; ++d) {
+              text_embeddings.at<float>({b, pos, d}) = vision_embeddings.at<float>({bound_idx, vision_idx, d});
+            }
+          }
+        }
+      }
+    }
+
+    mllm::print("finished merging!");
+
+    // Debug: load and substitute Python-saved embeddings
+    if (loadPythonEmbedding) {
+      cnpy::NpyArray arr = cnpy::npy_load("../../models/merged_input_embedding.npy");
+      float* data_ptr = arr.data<float>();
+      std::vector<float> vec(data_ptr, data_ptr + arr.num_vals);
+      auto tt = mllm::Tensor::fromVector(vec, {1, 699, 3584}, mllm::kFloat32);
+      mllm::print(tt.shape());
+      mllm::print(text_embeddings.shape());
+      text_embeddings = tt;
+      mllm::print("✅ Loaded Python embedding for debugging!");
+      return tt;
+    }
+
+    return text_embeddings;
+  }
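+
+  // Bounds example: if image_bounds holds {5, 70} for slice 0, the 64
+  // resampled vectors land in text positions 6..69, i.e. the span strictly
+  // between the image-start token (position 5) and the image-end token (70).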
+
+  Tensor merge_audio_text_embeddings(const Tensor& text_embeddings, const Tensor& audio_embeddings, const Tensor& sequence) {
+    // TODO: Similar to the vision embedding fusion
+    return text_embeddings;
+  }
+
+  Tensor generate_position_ids(const Tensor& embeddings) {
+    // Generate simple sequential position IDs
+    auto batch_size = embeddings.shape()[0];
+    auto seq_len = embeddings.shape()[1];
+    auto position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc();
+
+    auto pos_ptr = position_ids.ptr<int64_t>();
+    for (int b = 0; b < batch_size; ++b) {
+      for (int i = 0; i < seq_len; ++i) { pos_ptr[b * seq_len + i] = i; }
+    }
+    return position_ids;
+  }
+};
+
+}  // namespace mllm::models::minicpmo
diff --git a/mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp b/mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp
new file mode 100644
index 000000000..8d8e60ac5
--- /dev/null
+++ b/mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp
@@ -0,0 +1,925 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+
+#include "mllm/mllm.hpp"
+#include "mllm/nn/Module.hpp"
+#include "mllm/nn/Nn.hpp"
+#include "mllm/nn/Functional.hpp"
+#include "mllm/nn/lmcache/StaticCache.hpp"
+#include "mllm/models/qwen2vl/configuration_qwen2vl.hpp"
+#include "mllm/utils/Enumerate.hpp"
+#include "mllm/models/ARGeneration.hpp"
+
+namespace mllm::models::qwen2vl {
+
+inline Tensor makeVisualRoPEInvFreq(int32_t dims, float theta) {
+  const int half_dim = dims / (2 * 2);
+  Tensor inv_freq = Tensor::empty({half_dim}, kFloat32).alloc();
+  float* inv_freq_ptr = inv_freq.ptr<float>();
+  const float dims_inv = 1.0f / static_cast<float>(dims / 2);
+  for (int i = 0; i < half_dim; ++i) {
+    const float exponent = (2.0f * i) * dims_inv;
+    inv_freq_ptr[i] = 1.0f / std::pow(theta, exponent);
+  }
+  return inv_freq;
+}
+
+inline Tensor makeVisualRotaryPosEmbIds(Tensor& grid_thw, int32_t spatial_merge_size) {
+  MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2);
+
+  auto img_nums = grid_thw.shape()[0];
+
+  int total_positions = 0;
+  for (int row = 0; row < img_nums; ++row) {
+    const int* dims = grid_thw.offsettedPtr<int32_t>({row, 0});
+    const int t = dims[0];
+    const int h = dims[1];
+    const int w = dims[2];
+    total_positions += t * h * w;
+  }
+
+  Tensor out = Tensor::empty({total_positions, 2}, kInt32).alloc();
+  int* out_ptr = out.ptr<int32_t>();
+  int out_offset = 0;
+
+  for (int row = 0; row < img_nums; ++row) {
+    const int* dims = grid_thw.offsettedPtr<int32_t>({row, 0});
+
+    const int t = dims[0];
+    const int h = dims[1];
+    const int w = dims[2];
+
+    const int num_h_blocks = h / spatial_merge_size;
+    const int num_w_blocks = w / spatial_merge_size;
+    const int total_blocks = num_h_blocks * num_w_blocks;
+    const int block_area = spatial_merge_size * spatial_merge_size;
+    const int grid_size = h * w;
+
+    std::vector<int> flatten_hpos(grid_size);
+    std::vector<int> flatten_wpos(grid_size);
+
+    for (int block_idx = 0; block_idx < total_blocks; ++block_idx) {
+      const int i_h = block_idx / num_w_blocks;
+      const int i_w = block_idx % num_w_blocks;
+      const int start_idx = block_idx * block_area;
+
+      const int base_h = i_h * spatial_merge_size;
+      const int base_w = i_w * spatial_merge_size;
+
+      for (int j_h = 0; j_h < spatial_merge_size; ++j_h) {
+        const int global_h = base_h + j_h;
+        for (int j_w = 0; j_w < spatial_merge_size; ++j_w) {
+          const int global_w = base_w + j_w;
+          const int pos = start_idx + j_h * spatial_merge_size + j_w;
+          flatten_hpos[pos] = global_h;
+          flatten_wpos[pos] = global_w;
+        }
+      }
+    }
+
+    for (int frame = 0; frame < t; ++frame) {
+      for (int pos = 0; pos < grid_size; ++pos) {
+        const int out_idx = out_offset + (frame * grid_size + pos) * 2;
+        out_ptr[out_idx] = flatten_hpos[pos];
+        out_ptr[out_idx + 1] = flatten_wpos[pos];
+      }
+    }
+    out_offset += t * grid_size * 2;
+  }
+
+  return out;
+}
+
+inline Tensor makeVisualRotaryPosEmbFull(Tensor& inv_freq, int seq_len) {
+  MLLM_RT_ASSERT(seq_len > 0);
+  const int32_t dim = inv_freq.shape()[0];
+  Tensor freqs = Tensor::empty({seq_len, dim}, kFloat32, kCPU).alloc();
+  float* inv_freq_ptr = inv_freq.ptr<float>();
+  float* freqs_ptr = freqs.ptr<float>();
+  for (int i = 0; i < seq_len; ++i) {
+    const float i_val = static_cast<float>(i);
+    float* row_ptr = freqs_ptr + i * dim;
+    for (int j = 0; j < dim; ++j) { row_ptr[j] = i_val * inv_freq_ptr[j]; }
+  }
+  return freqs;
+}
+
+inline std::pair<Tensor, Tensor> makeVisualRotarySinCos(Tensor& rotary_pos_emb) {
+  auto seq = rotary_pos_emb.shape()[0];
+  auto dim = rotary_pos_emb.shape()[1];
+
+  auto rotary_pos_emb_ptr = rotary_pos_emb.ptr<float>();
+
+  Tensor sin_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc();
+  Tensor cos_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc();
+
+  auto sin_pos_emb_ptr = sin_pos_emb.ptr<float>();
+  auto cos_pos_emb_ptr = cos_pos_emb.ptr<float>();
+
+  for (int i = 0; i < seq; i++) {
+    for (int j = 0; j < dim; j++) {
+      sin_pos_emb_ptr[i * dim + j] = std::sin(rotary_pos_emb_ptr[i * dim + j]);
+      cos_pos_emb_ptr[i * dim + j] = std::cos(rotary_pos_emb_ptr[i * dim + j]);
+    }
+  }
+
+  return {sin_pos_emb, cos_pos_emb};
+}
+
+inline Tensor makeVisualRotaryPosEmb(Tensor& rotary_pos_emb_full, Tensor& pos_ids, Tensor& grid_thw) {
+  const int* grid_dims = grid_thw.offsettedPtr<int32_t>({0, 0});
+  const int t = grid_dims[0];
+  const int h = grid_dims[1];
+  const int w = grid_dims[2];
+
+  const int32_t num_positions = rotary_pos_emb_full.shape()[0];
+  const int32_t dim = rotary_pos_emb_full.shape()[1];
+  const int32_t batch_size = pos_ids.shape()[0];
+  const int32_t seq_len = pos_ids.shape()[1];
+
+  // [batch_size, dim]
+  Tensor out = Tensor::empty({batch_size, seq_len * dim}, kFloat32, kCPU).alloc();
+
+  auto rotary_pos_emb_full_ptr = rotary_pos_emb_full.ptr<float>();
+  auto pos_ids_ptr = pos_ids.ptr<int32_t>();
+  auto out_ptr = out.ptr<float>();
+
+  if (num_positions <= 0 || dim <= 0 || batch_size <= 0) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Invalid tensor dimensions"); }
+
+  if (t * h * w != batch_size) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Grid dimensions mismatch with batch size"); }
+
+  for (int i = 0; i < batch_size; ++i) {
+    for (int j = 0; j < seq_len; ++j) {
+      if ((*pos_ids.offsettedPtr<int32_t>({i, j})) < 0 || (*pos_ids.offsettedPtr<int32_t>({i, j})) >= num_positions) {
+        MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Position index out of bounds");
+      }
+    }
+  }
+
+  for (int i = 0; i < batch_size; ++i) {
+    auto batch_ptr = out.offsettedPtr<float>({i, 0});
+    size_t offset = 0;
+    for (int j = 0; j < seq_len; ++j) {
+      auto emb_ptr = rotary_pos_emb_full.offsettedPtr<float>({(*pos_ids.offsettedPtr<int32_t>({i, j})), 0});
+      std::copy(emb_ptr, emb_ptr + dim, batch_ptr + offset);
+      offset += dim;
+    }
+  }
+
+  return out;
+}
+
+inline auto makeMultimodalRoPEInvFreq(int output_dim, float rope_theta) -> Tensor {
+  auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc();
+  auto inv_freq_ptr = inv_freq.ptr<float>();
+  for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); }
+  return inv_freq;
+}
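+
+// Example: for output_dim 128 and rope_theta 1e6 this yields 64 frequencies
+// inv_freq[i] = 1 / 1e6^(2i/128), running from 1.0 down to roughly 1.2e-6.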
+
+inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, int seq_len, int output_dim,
+                                            const std::vector<int32_t>& mrope_section) -> std::pair<Tensor, Tensor> {
+  // Position ids shape is [3, 1, seq]
+  MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3);
+  MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1);  // Batch size is always 1.
+
+  // [3, seq, dim]
+  Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc();
+  Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc();
+
+  for (int b = 0; b < 3; ++b) {
+    for (int d = 0; d < inv_freq.shape()[0]; ++d) {
+      for (int s = 0; s < position_ids.shape()[2]; ++s) {
+        auto value = inv_freq.ptr<float>()[d] * (*position_ids.offsettedPtr<int64_t>({b, 0, s}));
+        *tmp_cos.offsettedPtr<float>({b, s, d}) = cosf(value);
+        *tmp_cos.offsettedPtr<float>({b, s, d + inv_freq.shape()[0]}) = cosf(value);
+        *tmp_sin.offsettedPtr<float>({b, s, d}) = sinf(value);
+        *tmp_sin.offsettedPtr<float>({b, s, d + inv_freq.shape()[0]}) = sinf(value);
+      }
+    }
+  }
+
+  Tensor sin = Tensor::nil();
+  Tensor cos = Tensor::nil();
+
+  // mrope_section is always [16, 24, 24]
+  if (!mrope_section.empty()) {
+    auto double_rope_section = mrope_section;
+    for (int i : mrope_section) { double_rope_section.push_back(i); }
+
+    int num_rows = tmp_sin.shape()[1];
+    int num_cols = tmp_sin.shape()[2];
+
+    sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc();
+    cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc();
+
+    std::vector<int> start_cols;
+    int current_start = 0;
+    start_cols.push_back(current_start);
+    for (int s : double_rope_section) {
+      current_start += s;
+      start_cols.push_back(current_start);
+    }
+
+    for (int j = 0; j < double_rope_section.size(); ++j) {
+      int layer = j % 3;
+      int s_j = double_rope_section[j];
+      int start_col_in = start_cols[j];
+      int start_col_out = start_cols[j];
+      for (int row = 0; row < num_rows; ++row) {
+        // Process cos
+        auto in_cos_row_ptr = tmp_cos.offsettedPtr<float>({layer, row, 0});
+        auto out_cos_row_ptr = cos.offsettedPtr<float>({row, 0});
+        for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; }
+
+        // Process sin
+        auto in_sin_row_ptr = tmp_sin.offsettedPtr<float>({layer, row, 0});
+        auto out_sin_row_ptr = sin.offsettedPtr<float>({row, 0});
+        for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; }
+      }
+    }
+  } else {
+    // When mrope_section is empty, we still need to reshape from 3D to 2D for
+    // MultimodalRoPE, which expects the [seq_len, dim] format.
+    int num_rows = tmp_sin.shape()[1];  // seq_len
+    int num_cols = tmp_sin.shape()[2];  // dim
+
+    sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc();
+    cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc();
+
+    // Copy plane 0 of tmp_sin/tmp_cos to the output. This is a simplified
+    // approach: all three planes are identical for plain sequential ids.
+    for (int row = 0; row < num_rows; ++row) {
+      auto src_sin_ptr = tmp_sin.offsettedPtr<float>({0, row, 0});
+      auto src_cos_ptr = tmp_cos.offsettedPtr<float>({0, row, 0});
+      auto dst_sin_ptr = sin.offsettedPtr<float>({row, 0});
+      auto dst_cos_ptr = cos.offsettedPtr<float>({row, 0});
+
+      for (int col = 0; col < num_cols; ++col) {
+        dst_sin_ptr[col] = src_sin_ptr[col];
+        dst_cos_ptr[col] = src_cos_ptr[col];
+      }
+    }
+  }
+
+  return {sin, cos};
+}
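+
+// mrope_section example: with the usual [16, 24, 24] section (doubled to
+// [16, 24, 24, 16, 24, 24]), sin/cos columns 0-15 come from plane 0
+// (temporal), 16-39 from plane 1 (height), 40-63 from plane 2 (width), and
+// the same pattern repeats for columns 64-127.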
+
+class PatchEmbed final : public nn::Module {
+  int32_t in_chans_;
+  int32_t embed_dim_;
+  int32_t patch_size_;
+  int32_t temporal_patch_size_;
+
+  nn::Conv3D proj_;
+
+ public:
+  PatchEmbed() = default;
+
+  inline PatchEmbed(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    in_chans_ = cfg.visual_in_chans;
+    embed_dim_ = cfg.visual_embed_dim;
+    patch_size_ = cfg.visual_patch_size;
+    temporal_patch_size_ = cfg.visual_temporal_patch_size;
+
+    proj_ = reg<nn::Conv3D>("proj", cfg.visual_in_chans, cfg.visual_embed_dim,
+                            std::vector<int32_t>{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size},
+                            std::vector<int32_t>{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size},
+                            false);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto hidden_states = inputs[0];
+
+    // [batch_size(x), in_channel(3), temporal_patch_size(2), patch_size(14), patch_size(14)]
+    hidden_states = hidden_states.view({-1, in_chans_, temporal_patch_size_, patch_size_, patch_size_});
+    hidden_states = proj_(hidden_states).view({-1, embed_dim_});
+
+    return {hidden_states};
+  }
+};
+
+class PatchMerger final : public nn::Module {
+  int32_t hidden_size_;
+  int32_t spatial_merge_size_;
+  int32_t context_dim_;
+
+  nn::LayerNorm ln_q_;
+  nn::Linear mlp_0_;
+  nn::Linear mlp_2_;
+  nn::GELU mlp_gelu_;
+
+ public:
+  PatchMerger() = default;
+
+  inline PatchMerger(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    context_dim_ = cfg.visual_embed_dim;
+    spatial_merge_size_ = cfg.visual_spatial_merge_size;
+    hidden_size_ = context_dim_ * spatial_merge_size_ * spatial_merge_size_;
+
+    ln_q_ = reg<nn::LayerNorm>("ln_q", std::vector<int32_t>{context_dim_}, true, true, 1e-6);
+    mlp_0_ = reg<nn::Linear>("mlp.0", hidden_size_, hidden_size_, true, cfg.linear_impl_type);
+    mlp_gelu_ = reg<nn::GELU>("mlp.gelu");
+    mlp_2_ = reg<nn::Linear>("mlp.2", hidden_size_, cfg.hidden_size, true, cfg.linear_impl_type);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto o = ln_q_(inputs[0]).view({-1, hidden_size_});
+    o = mlp_0_(o);
+    o = mlp_gelu_(o);
+    o = mlp_2_(o);
+    return {o};
+  }
+};
+
+class VisionMlp final : public nn::Module {
+  int32_t dim_;
+  int32_t hidden_dim_;
+
+  nn::QuickGELU act_;
+  nn::Linear fc_1_;
+  nn::Linear fc_2_;
+
+ public:
+  VisionMlp() = default;
+
+  inline VisionMlp(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    dim_ = cfg.visual_embed_dim;
+    hidden_dim_ = cfg.visual_embed_dim * cfg.visual_mlp_ratio;
+
+    fc_1_ = reg<nn::Linear>("fc1", dim_, hidden_dim_, true, cfg.linear_impl_type);
+    fc_2_ = reg<nn::Linear>("fc2", hidden_dim_, dim_, true, cfg.linear_impl_type);
+    act_ = reg<nn::QuickGELU>("act");
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    return {fc_2_(act_(fc_1_(inputs[0])))};
+  }
+};
+
+class VisionAttention final : public nn::Module {
+  int32_t dim_;
+  int32_t num_heads_;
+  int32_t head_dim_;
+  int32_t num_key_value_groups = 1;
+  float scaling = 0.f;
+
+  nn::Linear qkv_;
+  nn::Linear proj_;
+  nn::Softmax softmax_;
+  nn::VisionRoPE vision_rope_q_;
+  nn::VisionRoPE vision_rope_k_;
+
+ public:
+  VisionAttention() = default;
+
+  inline VisionAttention(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    dim_ = cfg.visual_embed_dim;
+    num_heads_ = cfg.visual_num_heads;
+    head_dim_ = dim_ / num_heads_;
+    scaling = std::sqrt(head_dim_);
+
+    qkv_ = reg<nn::Linear>("qkv", dim_, dim_ * 3, true, cfg.linear_impl_type);
+    proj_ = reg<nn::Linear>("proj", dim_, dim_, true, cfg.linear_impl_type);
+    softmax_ = reg<nn::Softmax>("softmax", -1);
+
+    vision_rope_q_ = reg<nn::VisionRoPE>("vision_rope_q", aops::VisionRoPEOpOptionsType::kQwen2VL,
+                                         aops::Qwen2VLRoPEOpOptions{
+                                             .dims = head_dim_,
+                                             .spatial_merge_size = cfg.visual_spatial_merge_size,
+                                             .theta = 10000.0,
+                                         });
+    vision_rope_k_ = reg<nn::VisionRoPE>("vision_rope_k", aops::VisionRoPEOpOptionsType::kQwen2VL,
+                                         aops::Qwen2VLRoPEOpOptions{
+                                             .dims = head_dim_,
+                                             .spatial_merge_size = cfg.visual_spatial_merge_size,
+                                             .theta = 10000.0,
+                                         });
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    // hidden_states shape is [seq_length, dim]
+    auto hidden_states = inputs[0];
+    auto visual_embedding_sin = inputs[1];
+    auto visual_embedding_cos = inputs[2];
+
+    auto seq_length = hidden_states.shape()[0];
+
+    auto [query_states, key_states, value_states] =
+        nn::functional::split<3>(qkv_(hidden_states).view({seq_length, 3, num_heads_, -1}).permute({1, 0, 2, 3}), 1, 0);
+
+    // Input to Vision RoPE must be BSHD format
+    // grid_thw shape is [n, 3], n is always 1 in this case.
+    query_states = vision_rope_q_(query_states, visual_embedding_sin, visual_embedding_cos);
+    key_states = vision_rope_k_(key_states, visual_embedding_sin, visual_embedding_cos);
+
+    // [B, H, S, D]
+    query_states = query_states.transpose(1, 2);
+    key_states = key_states.transpose(1, 2);
+    value_states = value_states.transpose(1, 2);
+
+    // attention weight
+    // [B=1, H, S, S]
+    auto attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_));
+    attn = softmax_(attn);
+
+    // attn output
+    // [B=1, H, S, S] @ [B=1, H, S, D] -> [B=1, H, S, D]
+    auto attn_output = nn::functional::matmul(attn, value_states);
+
+    // [B=1, H, S, D] -> [B=1, S, H, D] -> [S, H * D]
+    attn_output = attn_output.transpose(1, 2).view({seq_length, -1});
+    attn_output = proj_(attn_output);
+    return {
+        attn_output,
+    };
+  }
+};
+
+class Qwen2VLVisionBlock final : public nn::Module {
+  int mlp_hidden_dim_;
+
+  nn::LayerNorm norm1_;
+  nn::LayerNorm norm2_;
+
+  VisionAttention attn_;
+  VisionMlp mlp_;
+
+ public:
+  Qwen2VLVisionBlock() = default;
+
+  inline Qwen2VLVisionBlock(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    mlp_hidden_dim_ = cfg.visual_mlp_ratio * cfg.visual_embed_dim;
+    norm1_ = reg<nn::LayerNorm>("norm1", std::vector<int32_t>{cfg.visual_embed_dim}, true, true, 1e-6);
+    norm2_ = reg<nn::LayerNorm>("norm2", std::vector<int32_t>{cfg.visual_embed_dim}, true, true, 1e-6);
+    attn_ = reg<VisionAttention>("attn", cfg);
+    mlp_ = reg<VisionMlp>("mlp", cfg);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto hidden_states = inputs[0];
+    auto visual_embedding_sin = inputs[1];
+    auto visual_embedding_cos = inputs[2];
+
+    hidden_states = hidden_states + attn_(norm1_(hidden_states), visual_embedding_sin, visual_embedding_cos)[0];
+    hidden_states = hidden_states + mlp_(norm2_(hidden_states))[0];
+    return {hidden_states};
+  }
+};
+
+class Qwen2VisionTransformerPretrainedModel final : public nn::Module {
+  PatchEmbed patch_embed_;
+  PatchMerger patch_merger_;
+  nn::ModuleList<Qwen2VLVisionBlock> blocks_;
+
+ public:
+  Qwen2VisionTransformerPretrainedModel() = default;
+
+  Qwen2VisionTransformerPretrainedModel(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) {
+    patch_embed_ = reg<PatchEmbed>("patch_embed", cfg);
+    patch_merger_ = reg<PatchMerger>("merger", cfg);
+    blocks_ = reg<nn::ModuleList<Qwen2VLVisionBlock>>("blocks", cfg.visual_depth, cfg);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto hidden_states = inputs[0];
+    auto embedding_sin = inputs[1];
+    auto embedding_cos = inputs[2];
+
+    hidden_states = patch_embed_(hidden_states)[0];
+
+    for (auto& b : blocks_.list()) { hidden_states = b(hidden_states, embedding_sin, embedding_cos)[0]; }
+
+    hidden_states = patch_merger_(hidden_states)[0];
+
+    return {hidden_states};
+  }
+};
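+
+// Pipeline sketch: flattened patches [num_patches, 3 * 2 * 14 * 14] go through
+// PatchEmbed to [num_patches, embed_dim], then cfg.visual_depth attention
+// blocks, then PatchMerger, which (with spatial_merge_size 2) folds each 2x2
+// block into one hidden_size token, so the LLM sees num_patches / 4 tokens.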
nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2VLAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2VLAttention() = default; + + Qwen2VLAttention(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false, cfg.linear_impl_type); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + // [B, S, H * D] + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + // [B, S, H, D] + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + // [B, H, S, D] + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + // [B, H, S, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // [B, H, S, D] + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + // attention weight + // [B, H, S, S] + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + 
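+      // Note: scores are already scaled by 1/sqrt(head_dim) above; with
+      // head_dim = hidden_size / num_attention_heads = 3584 / 28 = 128 that
+      // scale is 1/sqrt(128) ~= 0.088. The kFloat16 branch below repeats the
+      // same mask-then-softmax sequence but upcasts Q/K to fp32 first, since
+      // softmax over large negative masked logits underflows in half precision.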
attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + // attn output + // [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D] + auto output = nn::functional::matmul(attn, value_states); + // [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D] + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_; +}; + +class Qwen2VLDecoder final : public nn::Module { + public: + Qwen2VLAttention self_attn_; + Qwen2VLMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2VLDecoder() = default; + + Qwen2VLDecoder(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2VLText final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + nn::Linear lm_head_; + bool tie_word_embeddings_; + + public: + Qwen2VLText() = default; + + Qwen2VLText(const std::string& name, const Qwen2VLConfig& cfg) : nn::Module(name) { + tie_word_embeddings_ = cfg.tie_word_embeddings; + + decode_blocks_ = reg>("model.layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("model.norm", cfg.rms_norm_eps); + embedding_ = reg("model.embed_tokens", cfg.vocab_size, cfg.hidden_size); + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + // Init inv freq + auto inv = makeMultimodalRoPEInvFreq(cfg.hidden_size / cfg.num_attention_heads, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + // clip x to one seq length + { + auto S = x.shape()[1]; + x = x[{kAll, {S - 1}, kAll}]; + } + + if (tie_word_embeddings_) { x = lm_head_(x); } + + return {x}; + } + + nn::Embedding embedding_; +}; + +class Qwen2VLForCausalLM : public ARGeneration { + public: + explicit Qwen2VLForCausalLM(const Qwen2VLConfig& cfg) : cfg(cfg), llm("llm", cfg), visual("visual", cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, // q_heads + cfg.num_key_value_heads, // kv_heads + cfg.hidden_size / cfg.num_attention_heads, // kv_dims + kFloat32, // 
k_dtype + kFloat32, // v_dtype + kCPU, // device_type + false // use_fa2 + ); + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + // Calculate the text embeddings + auto input_embeddings = llm.embedding_(sequence); + + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + + // process img + print("ViT Processing: ..."); + print("Image shape is:", img.shape()); + + auto v_len = img.shape()[0]; + auto inv_freq = makeVisualRoPEInvFreq(cfg.visual_embed_dim / cfg.visual_num_heads, 10000.0); + auto pos_ids = makeVisualRotaryPosEmbIds(grid_thw, cfg.visual_spatial_merge_size); + auto rotary_pos_emb_full = makeVisualRotaryPosEmbFull(inv_freq, v_len); + auto pos_emb = makeVisualRotaryPosEmb(rotary_pos_emb_full, pos_ids, grid_thw); + auto [visual_embedding_sin, visual_embedding_cos] = makeVisualRotarySinCos(pos_emb); + + auto start_time = std::chrono::high_resolution_clock::now(); + auto visual_embeddings = visual(img, visual_embedding_sin, visual_embedding_cos)[0]; + auto end_time = std::chrono::high_resolution_clock::now(); + auto all_time = std::chrono::duration_cast(end_time - start_time); + print("ViT Processing: done, time cost: {} seconds", all_time.count()); + + // Insert visual embeddings into llm's embedding + int32_t vision_pad_token_start = -1; + { + auto& input_ids = sequence; + auto S = input_ids.shape()[1]; + auto input_ids_ptr = input_ids.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg.vision_token_id) { + vision_pad_token_start = s; + break; + } + } + MLLM_RT_ASSERT(vision_pad_token_start != -1); + } + // input_embedding is [B, S, D] + auto D = input_embeddings.shape()[2]; + auto visual_sequence = visual_embeddings.shape()[0]; + visual_embeddings.copy2( + input_embeddings[{kAll, {vision_pad_token_start, vision_pad_token_start + visual_sequence}, kAll}]); + } + + auto position_ids = Tensor::nil(); + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids, cfg); + } else { + auto img = Tensor::nil(); + auto grid_thw = Tensor::nil(); + position_ids = input.at("position_ids"); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids, cfg); + } + + // Generate position ids and embedding sin and cos + auto [llm_embedding_sin, llm_embedding_cos] = + makeMultimodalPositionEmbedding(position_ids, llm.getBuffer("inv_freq"), cfg.max_position_embeddings, + cfg.hidden_size / cfg.num_attention_heads, cfg.mrope_section); + + sequence = llm(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } + + inline Tensor getPositionIds(Tensor& img, Tensor& grid_thw, Tensor& sequence, Tensor& position_ids, + const Qwen2VLConfig& cfg) { + // Input is [B, S, D] + if (!img.isNil()) { // Prefill + return getPositionIdsPrefill(sequence, grid_thw, cfg); + } else { // Decode + auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + 
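+    // For reference, the id layout this helper produces end-to-end (a worked
+    // example, assuming 3 prompt tokens, a merged visual grid of t=1, h=2,
+    // w=2, then 2 trailing text tokens):
+    //
+    //   dim 0 (t): 0 1 2 | 3 3 3 3 | 5 6
+    //   dim 1 (h): 0 1 2 | 3 3 4 4 | 5 6
+    //   dim 2 (w): 0 1 2 | 3 4 3 4 | 5 6
+    //
+    // Text advances all three dims in lockstep, the visual block spreads the
+    // h/w dims, trailing text resumes from max(all dims) + 1, and each decode
+    // step (the else-branch above) advances every dim by one.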
} + + inline Tensor getPositionIdsPrefill(Tensor& input_ids, Tensor& image_grid_thw, const Qwen2VLConfig& cfg) { + // Input is [B, S] + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + // image_grid_thw is [num_images, 3] + MLLM_RT_ASSERT_EQ(image_grid_thw.shape().size(), 2); + + auto B = input_ids.shape()[0]; + MLLM_RT_ASSERT_EQ(B, 1); + auto S = input_ids.shape()[1]; + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + + // Process text and visual + // 1. Find the place of the first image token + // Only one image is supported. + int32_t vision_pad_token_start = -1; + { + auto input_ids_ptr = input_ids.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg.vision_token_id) { + vision_pad_token_start = s; + break; + } + } + MLLM_RT_ASSERT(vision_pad_token_start != -1); + } + + // 2. Calculate grid dimensions + int img_t, img_h, img_w; + int inputs_t, inputs_h, inputs_w; + { + auto image_grid_thw_ptr = image_grid_thw.ptr(); + img_t = image_grid_thw_ptr[0]; + img_h = image_grid_thw_ptr[1]; + img_w = image_grid_thw_ptr[2]; + + inputs_t = img_t; + inputs_h = img_h / cfg.visual_spatial_merge_size; + inputs_w = img_w / cfg.visual_spatial_merge_size; + } + + // 3. We assume the inputs format is: T T T V V V T T T + int64_t current_max_position_id = 0; + // 3.1 Handle text (Sys token as usual). + { + int64_t start_idx = current_max_position_id; + for (int d = 0; d < 3; ++d) { + auto position_ids_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int64_t k = 0; k < vision_pad_token_start; ++k) { position_ids_ptr[k] = start_idx + k; } + } + current_max_position_id = vision_pad_token_start - 1; + } + // 3.2 Handle image + { + int _cnt = 0; + int64_t vision_start_id = current_max_position_id + 1; + for (int64_t ti = 0; ti < inputs_t; ++ti) { + for (int64_t hi = 0; hi < inputs_h; ++hi) { + for (int64_t wi = 0; wi < inputs_w; ++wi) { + *position_ids.offsettedPtr({0, 0, vision_pad_token_start + _cnt}) = vision_start_id + ti; + + *position_ids.offsettedPtr({1, 0, vision_pad_token_start + _cnt}) = vision_start_id + hi; + + *position_ids.offsettedPtr({2, 0, vision_pad_token_start + _cnt}) = vision_start_id + wi; + + _cnt++; + } + } + } + auto dim_0_tail = + *position_ids.offsettedPtr({0, 0, vision_pad_token_start + inputs_t * inputs_h * inputs_w - 1}); + auto dim_1_tail = + *position_ids.offsettedPtr({1, 0, vision_pad_token_start + inputs_t * inputs_h * inputs_w - 1}); + auto dim_2_tail = + *position_ids.offsettedPtr({2, 0, vision_pad_token_start + inputs_t * inputs_h * inputs_w - 1}); + current_max_position_id = std::max({dim_0_tail, dim_1_tail, dim_2_tail}); + } + // 3.3 Handle Prompt + { + const int64_t vision_token_count = inputs_t * inputs_h * inputs_w; + const int64_t trailing_text_start_seq = vision_pad_token_start + vision_token_count; + const int64_t trailing_text_count = S - trailing_text_start_seq; + + if (trailing_text_count > 0) { + int64_t start_id = current_max_position_id + 1; + for (int d = 0; d < 3; ++d) { + auto position_ids_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int64_t k = 0; k < trailing_text_count; ++k) { + const int64_t seq_idx = trailing_text_start_seq + k; + position_ids_ptr[seq_idx] = start_id + k; + } + } + } + } + + return position_ids; + } + + const Qwen2VLConfig& cfg; + Qwen2VLText llm; + Qwen2VisionTransformerPretrainedModel visual; + + private: + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen2vl diff --git a/mllm/models/minicpm_o2_6/modeling_resampler.hpp b/mllm/models/minicpm_o2_6/modeling_resampler.hpp 
new file mode 100644
index 000000000..e14c6d24a
--- /dev/null
+++ b/mllm/models/minicpm_o2_6/modeling_resampler.hpp
@@ -0,0 +1,407 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/mllm.hpp"
+#include "mllm/models/minicpm_o2_6/modeling_vector_quantize.hpp"
+#include "mllm/nn/Module.hpp"
+#include "mllm/nn/Nn.hpp"
+#include "mllm/nn/layers/LayerNorm.hpp"
+#include "mllm/nn/layers/Linear.hpp"
+#include "mllm/nn/layers/Param.hpp"
+#include "mllm/nn/layers/Softmax.hpp"
+#include "mllm/nn/Functional.hpp"
+#include "mllm/utils/Common.hpp"
+#include <cmath>
+#include <vector>
+#include <limits>
+#include "cnpy.h"
+
+namespace mllm::models::minicpmo {
+
+inline Tensor get2DSinCosPosEmbed(int32_t embed_dim, const std::vector<int32_t>& image_size) {
+  int32_t grid_h_size = image_size[0];
+  int32_t grid_w_size = (image_size.size() == 1) ? image_size[0] : image_size[1];
+
+  Tensor pos_embed = Tensor::empty({grid_h_size, grid_w_size, embed_dim}, kFloat32).alloc();
+
+  int32_t half_dim = embed_dim / 2;
+
+  for (int32_t h = 0; h < grid_h_size; ++h) {
+    for (int32_t w = 0; w < grid_w_size; ++w) {
+      // Width encoding: first half_dim dimensions (layout inferred from the reference Python output).
+      // [0:half_dim/2] = sin(w*omega), [half_dim/2:half_dim] = cos(w*omega)
+      for (int32_t i = 0; i < half_dim / 2; ++i) {
+        float omega = 1.0f / std::pow(10000.0f, 2.0f * i / half_dim);
+        *pos_embed.offsettedPtr<float>({h, w, i}) = std::sin(w * omega);
+        *pos_embed.offsettedPtr<float>({h, w, i + half_dim / 2}) = std::cos(w * omega);
+      }
+
+      // Height encoding: second half_dim dimensions
+      // [half_dim:half_dim+half_dim/2] = sin(h*omega), [half_dim+half_dim/2:embed_dim] = cos(h*omega)
+      for (int32_t i = 0; i < half_dim / 2; ++i) {
+        float omega = 1.0f / std::pow(10000.0f, 2.0f * i / half_dim);
+        *pos_embed.offsettedPtr<float>({h, w, half_dim + i}) = std::sin(h * omega);
+        *pos_embed.offsettedPtr<float>({h, w, half_dim + i + half_dim / 2}) = std::cos(h * omega);
+      }
+    }
+  }
+
+  return pos_embed;
+}
+
+class ResamplerAttention : public nn::Module {
+  int32_t embed_dim_;
+  int32_t num_heads_;
+  int32_t head_dim_;
+  nn::Param in_proj_weight_;
+  nn::Param in_proj_bias_;
+  nn::Linear out_proj_;
+
+ public:
+  ResamplerAttention() = default;
+
+  ResamplerAttention(const std::string& name, int32_t embed_dim, int32_t num_heads)
+      : nn::Module(name), embed_dim_(embed_dim), num_heads_(num_heads) {
+    head_dim_ = embed_dim_ / num_heads_;
+
+    // in_proj_weight [3*embed_dim, embed_dim]
+    in_proj_weight_ = reg<nn::Param>("in_proj_weight", getModuleName() + ".in_proj.weight",
+                                     Tensor::shape_t{3 * embed_dim_, embed_dim_});
+    in_proj_bias_ = reg<nn::Param>("in_proj_bias", getModuleName() + ".in_proj.bias",
+                                   Tensor::shape_t{3 * embed_dim_});
+    out_proj_ = reg<nn::Linear>("out_proj", embed_dim_, embed_dim_, true);
+  }
+
+  std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
+    auto query = inputs[0];  // [num_queries, embed_dim]
+    auto key = inputs[1];    // [seq_len, embed_dim]
+    auto value = inputs[2];  // [seq_len, embed_dim]
+
+    auto key_padding_mask = Tensor();
+    bool has_key_padding_mask = false;
+    if (inputs.size() > 3) {
+      key_padding_mask = inputs[3];  // [seq_len] optional
+      has_key_padding_mask = true;
+    }
+
+    auto num_queries = query.shape()[0];
+    auto seq_len = key.shape()[0];
+
+    // Perform packed in-projection: [query|key|value] = input @ in_proj_weight.T + in_proj_bias
+    // For cross-attention: q comes from query, k,v come from key_value
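+    // Layout note (a sketch, matching the torch.nn.MultiheadAttention packed
+    // convention this checkpoint appears to use): in_proj.weight is [3*E, E]
+    // with W_q, W_k, W_v stacked row-wise, and in_proj.bias is the matching
+    // [3*E] concatenation:
+    //
+    //   W_q = in_proj.weight[0 : E, :]      b_q = in_proj.bias[0 : E]
+    //   W_k = in_proj.weight[E : 2E, :]     b_k = in_proj.bias[E : 2E]
+    //   W_v = in_proj.weight[2E : 3E, :]    b_v = in_proj.bias[2E : 3E]
+    //
+    // The element-wise copies below slice this layout explicitly; a zero-copy
+    // row view would be equivalent if the Tensor API exposes one.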
+    auto q_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
+    auto k_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
+    auto v_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
+    for (int i = 0; i < embed_dim_; i++) {
+      for (int j = 0; j < embed_dim_; j++) {
+        *q_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({i, j});
+        *k_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({embed_dim_ + i, j});
+        *v_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({2 * embed_dim_ + i, j});
+      }
+    }
+
+    auto q_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
+    auto k_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
+    auto v_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
+    for (int i = 0; i < embed_dim_; i++) {
+      *q_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({i});
+      *k_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({embed_dim_ + i});
+      *v_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({2 * embed_dim_ + i});
+    }
+
+    auto q = nn::functional::matmul(query, q_weight, false, true);
+    auto k = nn::functional::matmul(key, k_weight, false, true);
+    auto v = nn::functional::matmul(value, v_weight, false, true);
+
+    // Add the projection biases row-wise.
+    for (int i = 0; i < num_queries; i++) {
+      for (int j = 0; j < embed_dim_; j++) { *q.offsettedPtr<float>({i, j}) += q_bias.at<float>({j}); }
+    }
+    for (int i = 0; i < seq_len; i++) {
+      for (int j = 0; j < embed_dim_; j++) {
+        *k.offsettedPtr<float>({i, j}) += k_bias.at<float>({j});
+        *v.offsettedPtr<float>({i, j}) += v_bias.at<float>({j});
+      }
+    }
+
+    // [num_queries, embed_dim] -> [num_heads, num_queries, head_dim]
+    auto q_reshaped = Tensor::empty({num_heads_, num_queries, head_dim_}, kFloat32).alloc();
+    for (int nq = 0; nq < num_queries; nq++) {
+      for (int h = 0; h < num_heads_; h++) {
+        for (int d = 0; d < head_dim_; d++) {
+          float val = q.at<float>({nq, h * head_dim_ + d});
+          *q_reshaped.offsettedPtr<float>({h, nq, d}) = val;
+        }
+      }
+    }
+    q = q_reshaped;  // [num_heads, num_queries, head_dim]
+    auto k_reshaped = Tensor::empty({num_heads_, seq_len, head_dim_}, kFloat32).alloc();
+    for (int s = 0; s < seq_len; s++) {
+      for (int h = 0; h < num_heads_; h++) {
+        for (int d = 0; d < head_dim_; d++) {
+          float val = k.at<float>({s, h * head_dim_ + d});
+          *k_reshaped.offsettedPtr<float>({h, s, d}) = val;
+        }
+      }
+    }
+    k = k_reshaped;
+    auto v_reshaped = Tensor::empty({num_heads_, seq_len, head_dim_}, kFloat32).alloc();
+    for (int s = 0; s < seq_len; s++) {
+      for (int h = 0; h < num_heads_; h++) {
+        for (int d = 0; d < head_dim_; d++) {
+          float val = v.at<float>({s, h * head_dim_ + d});
+          *v_reshaped.offsettedPtr<float>({h, s, d}) = val;
+        }
+      }
+    }
+    v = v_reshaped;
+
+    auto scale = 1.0f / std::sqrt(static_cast<float>(head_dim_));
+    auto attn_weights = nn::functional::matmul(q, k, false, true) * scale;  // [num_heads, num_queries, seq_len]
+
+    if (has_key_padding_mask && key_padding_mask.numel() > 0) {
+      auto mask_value = -std::numeric_limits<float>::infinity();
+      for (int32_t h = 0; h < num_heads_; ++h) {
+        for (int32_t q_idx = 0; q_idx < num_queries; ++q_idx) {
+          for (int32_t s = 0; s < seq_len; ++s) {
+            if (key_padding_mask.at<uint8_t>({s}) == 1) { *attn_weights.offsettedPtr<float>({h, q_idx, s}) = mask_value; }
+          }
+        }
+      }
+    }
+
+    attn_weights = nn::functional::softmax(attn_weights.unsqueeze(0), -1).squeeze(0);
+
+    auto attn_output = nn::functional::matmul(attn_weights, v);  // [num_heads, num_queries, head_dim]
+
+    // [num_heads, num_queries, head_dim] -> [num_queries, embed_dim]
+    auto attn_output_reshaped = Tensor::empty({num_queries, embed_dim_}, kFloat32).alloc();
+    for (int h = 0; h < num_heads_; h++) {
+      for (int nq = 0; nq < num_queries; nq++) {
+        for (int d = 0; d < head_dim_; d++) {
+          float val = attn_output.at<float>({h, nq, d});
+          *attn_output_reshaped.offsettedPtr<float>({nq, h * head_dim_ + d}) = val;
+        }
+      }
+    }
+    attn_output = attn_output_reshaped;
+
+    return
{out_proj_(attn_output)}; + } +}; + +class Resampler : public nn::Module { + int32_t num_queries_; + int32_t embed_dim_; + int32_t num_heads_; + int32_t kv_dim_; + std::vector max_size_; + + nn::Param query_; + nn::Linear kv_proj_; + ResamplerAttention attn_; + nn::LayerNorm ln_q_; + nn::LayerNorm ln_kv_; + nn::LayerNorm ln_post_; + nn::Param proj_; + +public: + Resampler() = default; + + Resampler(const std::string& name, int32_t num_queries, int32_t embed_dim, int32_t num_heads, + int32_t kv_dim = -1, const std::vector& max_size = {70, 70}) + : nn::Module(name), num_queries_(num_queries), embed_dim_(embed_dim), + num_heads_(num_heads), kv_dim_(kv_dim == -1 ? embed_dim : kv_dim), max_size_(max_size) { + + query_ = reg("query", getModuleName() + ".query", Tensor::shape_t{num_queries_, embed_dim_}); + proj_ = reg("proj", getModuleName() + ".proj", Tensor::shape_t{embed_dim_, embed_dim_}); + + // kv_proj: project from kv_dim (1152) to embed_dim (3584) + kv_proj_ = reg("kv_proj", kv_dim_, embed_dim_, false); // no bias + + attn_ = reg("attn", embed_dim_, num_heads_); + ln_q_ = reg("ln_q", std::vector{embed_dim_}, true, true, 1e-6); + ln_kv_ = reg("ln_kv", std::vector{embed_dim_}, true, true, 1e-6); + ln_post_ = reg("ln_post", std::vector{embed_dim_}, true, true, 1e-6); + auto pos_embed = get2DSinCosPosEmbed(embed_dim_, max_size_); + registerBuffer("pos_embed", pos_embed); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; // [batch_size, seq_len, kv_dim] or [seq_len, kv_dim] + auto tgt_sizes = inputs[1]; // Tensor with shape [batch_size, 2] each item is [h, w] + + auto batch_size = 1; + auto seq_len = x.shape()[0]; + + if (x.shape().size() == 3) { + batch_size = x.shape()[0]; + seq_len = x.shape()[1]; + } else { + x = x.unsqueeze(0); + } + + std::vector patch_len(batch_size); + int max_h = 0, max_w = 0, max_patch_len=0; + for(int i = 0; i < batch_size; i++){ + patch_len[i] = tgt_sizes.at({i,0}) * tgt_sizes.at({i,1}); + if(patch_len[i] > max_patch_len) max_patch_len = patch_len[i]; + if(tgt_sizes.at({i,0}) > max_h) max_h = tgt_sizes.at({i,0}); + if(tgt_sizes.at({i,1}) > max_w) max_w = tgt_sizes.at({i,1}); + } + + if(max_h > max_size_[0] || max_w > max_size_[1]){ + max_size_[0] = max_h; + max_size_[1] = max_w; + auto new_pos_embed = get2DSinCosPosEmbed(embed_dim_, max_size_); + registerBuffer("pos_embed", new_pos_embed); + } + + auto pos_embed = getBuffer("pos_embed"); // [max_h, max_w, embed_dim] + + auto key_padding_mask = Tensor::empty({batch_size, max_patch_len}, kUInt8).alloc(); + for(int i = 0; i < batch_size; i++){ + for(int j = 0; j < max_patch_len; j++){ + key_padding_mask.at({i,j}) = 1; + } + for(int j = 0; j < patch_len[i] && j < max_patch_len; j++){ + key_padding_mask.at({i,j}) = 0; + } + } + + + std::vector pos_embed_list; + + for(int i = 0; i < batch_size; i++){ + int32_t tgt_h = tgt_sizes.at({i, 0}); + int32_t tgt_w = tgt_sizes.at({i, 1}); + int32_t patch_count = tgt_h * tgt_w; + + Tensor pos_embed_i = Tensor::empty({patch_count, embed_dim_}, kFloat32).alloc(); + + int patch_idx = 0; + for(int h = 0; h < tgt_h; h++){ + for(int w = 0; w < tgt_w; w++){ + for(int d = 0; d < embed_dim_; d++){ + float value = pos_embed.at({h, w, d}); + *pos_embed_i.offsettedPtr({patch_idx, d}) = value; + } + patch_idx++; + } + } + + pos_embed_list.push_back(pos_embed_i); + } + + Tensor pos_embed_padded = Tensor::empty({batch_size, max_patch_len, embed_dim_}, kFloat32).alloc(); + for(int i = 0; i < batch_size; i++){ + auto& pos_embed_i = 
pos_embed_list[i]; + int actual_len = pos_embed_i.shape()[0]; + + for(int j = 0; j < actual_len && j < max_patch_len; j++){ + for(int k = 0; k < embed_dim_; k++){ + *pos_embed_padded.offsettedPtr({i, j, k}) = pos_embed_i.at({j, k}); + } + } + + for(int j = actual_len; j < max_patch_len; j++){ + for(int k = 0; k < embed_dim_; k++){ + *pos_embed_padded.offsettedPtr({i, j, k}) = 0.0f; + } + } + } + + x = kv_proj_(x); + + x = ln_kv_(x); + + auto q = ln_q_(query_.weight()); // [num_queries, embed_dim] + + std::vector outputs; + for (int32_t b = 0; b < batch_size; ++b) { + + // x for this batch + Tensor x_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc(); + for(int i = 0; i < seq_len; i++){ + for(int j = 0; j < embed_dim_; j++){ + x_b.at({i, j}) = x.at({b, i, j}); + } + } + + // pos_embed for this batch + Tensor pos_embed_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc(); + for(int i = 0; i < seq_len; i++){ + for(int j = 0; j < embed_dim_; j++){ + if(i < max_patch_len){ + pos_embed_b.at({i, j}) = pos_embed_padded.at({b, i, j}); + } else { + pos_embed_b.at({i, j}) = 0.0f; + } + } + } + + auto kv_input = x_b + pos_embed_b; + + // key_padding_mask for this batch + Tensor key_padding_mask_b = Tensor::empty({max_patch_len}, kUInt8).alloc(); + for(int i = 0; i < max_patch_len; i++){ + key_padding_mask_b.at({i}) = key_padding_mask.at({b, i}); + } + + bool has_padding = false; + for(int i = 0; i < seq_len; i++){ + if(key_padding_mask_b.at({i}) == 1){ + has_padding = true; + break; + } + } + + auto attn_output = has_padding + ? attn_(q, kv_input, x_b, key_padding_mask_b)[0] + : attn_(q, kv_input, x_b)[0]; + + outputs.push_back(attn_output); + } + + auto out_tensor = Tensor::empty({batch_size, num_queries_, embed_dim_}, kFloat32).alloc(); + for(int i = 0; i < batch_size; i++){ + auto& out_i = outputs[i]; + for(int j = 0; j < num_queries_; j++){ + for(int k = 0; k < embed_dim_; k++){ + *out_tensor.offsettedPtr({i, j, k}) = out_i.at({j, k}); + } + } + } + + out_tensor = ln_post_(out_tensor); + + auto original_shape = out_tensor.shape(); + auto reshaped = out_tensor.view({original_shape[0] * original_shape[1], original_shape[2]}); + reshaped = nn::functional::matmul(reshaped, proj_.weight()); + out_tensor = reshaped.view(original_shape); + + if (inputs[0].shape().size() == 2) { + out_tensor = out_tensor.squeeze(0); + } + + return {out_tensor};//[batch_size, num_queries, embed_dim] or [num_queries, embed_dim] + } +}; + +} // namespace mllm::models::minicpmo diff --git a/mllm/models/minicpm_o2_6/modeling_siglip.hpp b/mllm/models/minicpm_o2_6/modeling_siglip.hpp new file mode 100644 index 000000000..d425738e8 --- /dev/null +++ b/mllm/models/minicpm_o2_6/modeling_siglip.hpp @@ -0,0 +1,452 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
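+//
+// SigLIP vision tower for MiniCPM-o 2.6. The forward path implemented below:
+//   1. SiglipVisionEmbeddings: 14x14 conv patch embedding plus a learned
+//      position table indexed by bucketized (h, w) coordinates, so crops of
+//      varying aspect ratio reuse one fixed grid of position ids.
+//   2. SiglipVisionEncoder: a stack of pre-LayerNorm attention + MLP blocks.
+//   3. post_layernorm over the final hidden states.
+//
+// A sketch of the bucketization in step 1, assuming the configured
+// num_patches_per_side = 980 / 14 = 70 (boundaries are 1/70 .. 69/70,
+// right-bucketized as in torch.bucketize(..., right=True)):
+//
+//   bucket(c)    = #{ b in boundaries : b <= c }
+//   pos_id(h, w) = bucket(h / nb_patches_h) * 70 + bucket(w / nb_patches_w)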
+#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp" +#include +#include + +namespace mllm::models::minicpmo { + +// SigLIP Vision Embeddings +class SiglipVisionEmbeddings final : public nn::Module { + int32_t embed_dim_; + int32_t image_size_; + int32_t patch_size_; + int32_t num_patches_per_side_; + int32_t num_patches_; + + nn::Conv2D patch_embedding_; + nn::Embedding position_embedding_; + +public: + SiglipVisionEmbeddings() = default; + + inline SiglipVisionEmbeddings(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + embed_dim_ = config.vision_hidden_size; + image_size_ = config.vision_image_size; + patch_size_ = config.vision_patch_size; + num_patches_per_side_ = image_size_ / patch_size_; + num_patches_ = num_patches_per_side_ * num_patches_per_side_; + + patch_embedding_ = reg("patch_embedding", 3, embed_dim_, + std::vector{patch_size_, patch_size_}, + std::vector{patch_size_, patch_size_}, + std::vector{0, 0}, std::vector{1, 1}, 1, true); + position_embedding_ = reg("position_embedding", num_patches_, embed_dim_); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto pixel_values = inputs[0]; + auto tgt_sizes = inputs.size() > 1 ? inputs[1] : Tensor::nil(); + auto patch_attention_mask = inputs.size() > 2 ? inputs[2] : Tensor::nil(); + + auto batch_size = pixel_values.shape()[0]; + + // Patch embedding: [B, C, H, W] -> [B, embed_dim, 1, H*W] + auto patch_embeds = patch_embedding_(pixel_values); + // [B, embed_dim, 1, H*W] -> [B, H*W, embed_dim] + auto embeddings = patch_embeds.squeeze(2).transpose(1, 2); + + + // Create position embeddings + if (!tgt_sizes.isNil() && !patch_attention_mask.isNil()) { + auto max_im_h = pixel_values.shape()[2]; + auto max_im_w = pixel_values.shape()[3]; + auto max_nb_patches_h = max_im_h / patch_size_; + auto max_nb_patches_w = max_im_w / patch_size_; + + // Create boundaries like torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + std::vector boundaries; + float step = 1.0f / static_cast(num_patches_per_side_); + for (int i = 1; i < num_patches_per_side_; ++i) { + boundaries.push_back(i * step); + } + + // Create position_ids tensor - using the max_patches from patch_attention_mask shape + auto max_patches = patch_attention_mask.shape()[2]; + auto position_ids = Tensor::empty({batch_size, max_patches}, kInt64).alloc(); + // Initialize to zeros + for (int b = 0; b < batch_size; b++) { + for (int p = 0; p < max_patches; p++) { + position_ids.at({b, p}) = 0; + } + } + + // Fill position ids based on patch grid and attention mask + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + int nb_patches_h = max_nb_patches_h; + int nb_patches_w = max_nb_patches_w; + + if (tgt_sizes.shape().size() == 2 && batch_idx < tgt_sizes.shape()[0]) { + nb_patches_h = tgt_sizes.at({batch_idx, 0}); + nb_patches_w = tgt_sizes.at({batch_idx, 1}); + } + + // Create fractional coordinates like torch.arange(0, 1 - 1e-6, 1 / nb_patches_h/w) + std::vector fractional_coords_h; + std::vector fractional_coords_w; + + float step_h = 1.0f / static_cast(nb_patches_h); + float step_w = 1.0f / static_cast(nb_patches_w); + + for (int i = 0; i < nb_patches_h; ++i) { + fractional_coords_h.push_back(i * step_h); + } + for (int i = 0; i < nb_patches_w; ++i) { + fractional_coords_w.push_back(i * step_w); + } + + + + // Bucketize 
coordinates (equivalent to torch.bucketize with right=True) + std::vector bucket_coords_h(nb_patches_h); + std::vector bucket_coords_w(nb_patches_w); + + for (int h = 0; h < nb_patches_h; ++h) { + float coord = fractional_coords_h[h]; + int bucket = 0; + for (size_t i = 0; i < boundaries.size(); ++i) { + if (coord < boundaries[i]) { + bucket = static_cast(i); + break; + } + bucket = static_cast(i + 1); + } + bucket_coords_h[h] = bucket; + } + + for (int w = 0; w < nb_patches_w; ++w) { + float coord = fractional_coords_w[w]; + int bucket = 0; + for (size_t i = 0; i < boundaries.size(); ++i) { + if (coord < boundaries[i]) { + bucket = static_cast(i); + break; + } + bucket = static_cast(i + 1); + } + bucket_coords_w[w] = bucket; + } + + // Create pos_ids like Python: (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + std::vector pos_ids; + for (int h = 0; h < nb_patches_h; ++h) { + for (int w = 0; w < nb_patches_w; ++w) { + int pos_id = bucket_coords_h[h] * num_patches_per_side_ + bucket_coords_w[w]; + pos_ids.push_back(pos_id); + } + } + + // Apply pos_ids only where patch_attention_mask is True (now it's 1D) + // position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + int pos_ids_idx = 0; + for (int flat_idx = 0; flat_idx < max_patches; ++flat_idx) { + uint8_t mask_val = patch_attention_mask.at({batch_idx, 0, flat_idx}); + if (mask_val && pos_ids_idx < pos_ids.size()) { + position_ids.at({batch_idx, flat_idx}) = pos_ids[pos_ids_idx]; + pos_ids_idx++; + } + } + + } + + + auto pos_embeddings = position_embedding_(position_ids); + embeddings = embeddings + pos_embeddings; + } else { + auto seq_len = embeddings.shape()[1]; + auto position_ids = Tensor::arange(0, seq_len, kInt64).view({1, seq_len}); + auto pos_embeddings = position_embedding_(position_ids); + embeddings = embeddings + pos_embeddings; + } + + return {embeddings}; + } +}; + +// SigLIP MLP +class SiglipMLP final : public nn::Module { + int32_t hidden_size_; + int32_t intermediate_size_; + + nn::Linear fc1_; + nn::Linear fc2_; + nn::GELU activation_fn_; + +public: + SiglipMLP() = default; + + inline SiglipMLP(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + hidden_size_ = config.vision_hidden_size; + intermediate_size_ = config.vision_intermediate_size; + + fc1_ = reg("fc1", hidden_size_, intermediate_size_, true); + fc2_ = reg("fc2", intermediate_size_, hidden_size_, true); + activation_fn_ = reg("act"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + + auto x = fc1_(hidden_states); + x = activation_fn_(x); + x = fc2_(x); + + return {x}; + } +}; + +// SigLIP Multi-Head Attention +class SiglipAttention final : public nn::Module { + int32_t embed_dim_; + int32_t num_heads_; + int32_t head_dim_; + float scale_; + + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear q_proj_; + nn::Linear out_proj_; + +public: + SiglipAttention() = default; + + inline SiglipAttention(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + embed_dim_ = config.vision_hidden_size; + num_heads_ = config.vision_num_attention_heads; + head_dim_ = embed_dim_ / num_heads_; + scale_ = 1.0f / sqrtf(static_cast(head_dim_)); + + k_proj_ = reg("k_proj", embed_dim_, embed_dim_, true); + v_proj_ = reg("v_proj", embed_dim_, embed_dim_, true); + q_proj_ = reg("q_proj", embed_dim_, embed_dim_, true); + out_proj_ = reg("out_proj", embed_dim_, embed_dim_, true); + } + + std::vector forward(const 
std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto attention_mask = inputs.size() > 1 ? inputs[1] : Tensor::nil(); + + auto batch_size = hidden_states.shape()[0]; + auto seq_len = hidden_states.shape()[1]; + + // Apply projections + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states); + auto value_states = v_proj_(hidden_states); + + // Reshape for multi-head attention: [B, seq_len, embed_dim] -> [B, seq_len, num_heads, head_dim] -> [B, num_heads, seq_len, head_dim] + query_states = query_states.view({batch_size, seq_len, num_heads_, head_dim_}).transpose(1, 2); + key_states = key_states.view({batch_size, seq_len, num_heads_, head_dim_}).transpose(1, 2); + value_states = value_states.view({batch_size, seq_len, num_heads_, head_dim_}).transpose(1, 2); + + // Compute attention scores: [B, num_heads, seq_len, seq_len] + auto attn_weights = nn::functional::matmul(query_states, key_states.transpose(-2, -1)) * scale_; + + // Apply attention mask if provided + if (!attention_mask.isNil()) { + attn_weights = attn_weights + attention_mask; + } + + // Apply softmax + attn_weights = nn::functional::softmax(attn_weights, -1); + + // Apply attention to values + auto attn_output = nn::functional::matmul(attn_weights, value_states); + + // Reshape back: [B, num_heads, seq_len, head_dim] -> [B, seq_len, num_heads, head_dim] -> [B, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2).contiguous().view({batch_size, seq_len, embed_dim_}); + + // Apply output projection + auto output = out_proj_(attn_output); + + return {output}; + } +}; + +// SigLIP Encoder Layer +class SiglipEncoderLayer final : public nn::Module { + int32_t embed_dim_; + + SiglipAttention self_attn_; + nn::LayerNorm layer_norm1_; + SiglipMLP mlp_; + nn::LayerNorm layer_norm2_; + +public: + SiglipEncoderLayer() = default; + + inline SiglipEncoderLayer(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + embed_dim_ = config.vision_hidden_size; + + self_attn_ = reg("self_attn", config); + layer_norm1_ = reg("layer_norm1", std::vector{embed_dim_}, true, true, 1e-6); + mlp_ = reg("mlp", config); + layer_norm2_ = reg("layer_norm2", std::vector{embed_dim_}, true, true, 1e-6); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto attention_mask = inputs.size() > 1 ? inputs[1] : Tensor::nil(); + + // Self attention with residual connection + auto residual = hidden_states; + auto normed = layer_norm1_(hidden_states); + auto attn_output = self_attn_(normed, attention_mask)[0]; + auto after_attn = residual + attn_output; + + // MLP with residual connection + residual = after_attn; + normed = layer_norm2_(after_attn); + auto mlp_output = mlp_(normed)[0]; + auto output = residual + mlp_output; + + return {output}; + } +}; + +// SigLIP Vision Encoder +class SiglipVisionEncoder final : public nn::Module { + int32_t num_layers_; + std::vector layers_; + +public: + SiglipVisionEncoder() = default; + + inline SiglipVisionEncoder(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + num_layers_ = config.vision_num_hidden_layers; + + for (int i = 0; i < num_layers_; ++i) { + layers_.push_back(reg("layers." + std::to_string(i), config)); + } + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto inputs_embeds = inputs[0]; + auto attention_mask = inputs.size() > 1 ? 
inputs[1] : Tensor::nil(); + + auto hidden_states = inputs_embeds; + for (auto& layer : layers_) { + hidden_states = layer(hidden_states, attention_mask)[0]; + //break; // For testing, run only one layer + } + return {hidden_states}; + } +}; + +// Main SigLIP Vision Model +class SiglipVisionModel final : public nn::Module { + int32_t embed_dim_; + int32_t patch_size_; // Add patch_size_ to access in forward + + SiglipVisionEmbeddings embeddings_; + SiglipVisionEncoder encoder_; + nn::LayerNorm post_layernorm_; + +public: + SiglipVisionModel() = default; + + inline SiglipVisionModel(const std::string& name, const MiniCPMOConfig& config) : nn::Module(name) { + embed_dim_ = config.vision_hidden_size; + patch_size_ = config.vision_patch_size; // Initialize patch_size_ + + embeddings_ = reg("embeddings", config); + encoder_ = reg("encoder", config); + post_layernorm_ = reg("post_layernorm", std::vector{embed_dim_}, true, true, 1e-6); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto pixel_values = inputs[0]; + auto tgt_sizes = inputs.size() > 1 ? inputs[1] : Tensor::nil(); + + auto batch_size = pixel_values.shape()[0]; + int max_patches = 0; + // Calculate max_patches based on tgt_sizes + for(int i=0;i({i,0})>0 && tgt_sizes.at({i,1})>0){ + int patches = (tgt_sizes.at({i,0}) ) * (tgt_sizes.at({i,1}) ); + if (patches > max_patches) max_patches = patches; + } + } + auto patch_attention_mask = Tensor::empty({batch_size, 1, max_patches}, kUInt8).alloc(); + for(int i=0;i({i,0,j}) = 0; + } + if(!tgt_sizes.isNil() && i({i,0}); + int nb_patches_w = tgt_sizes.at({i,1}); + int valid_patches = nb_patches_h * nb_patches_w; + for(int j=0;j({i,0,j}) = 1; + } + } + } + std::vector hidden_states_result; + if (tgt_sizes.isNil()) { + hidden_states_result = embeddings_(pixel_values, Tensor::nil(), patch_attention_mask); + } else { + hidden_states_result = embeddings_(pixel_values, tgt_sizes, patch_attention_mask); + } + auto hidden_states = hidden_states_result[0]; // [B, num_patches, embed_dim] + + patch_attention_mask = patch_attention_mask.squeeze(1); // [B, max_patches] + + // Create attention mask for encoder (4D mask for multi-head attention) + Tensor attention_mask = Tensor::nil(); + if (!patch_attention_mask.isNil()) { + auto batch_size = patch_attention_mask.shape()[0]; + auto max_patches = patch_attention_mask.shape()[1]; + + bool all_valid = true; + for (int i = 0; i < batch_size && all_valid; i++) { + for (int j = 0; j < max_patches && all_valid; j++) { + uint8_t mask_val = patch_attention_mask.at({i, j}); + if (mask_val == 0) { + all_valid = false; + } + } + } + if (!all_valid){ + // Convert patch_attention_mask to float and create 4D attention mask + auto patch_mask_float = Tensor::empty({batch_size, max_patches}, kFloat32).alloc(); + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < max_patches; j++) { + uint8_t mask_val = patch_attention_mask.at({i, j}); + patch_mask_float.at({i, j}) = mask_val ? 1.0f : 0.0f; + } + } + + // Create 4D attention mask: [B, 1, max_patches, max_patches] + attention_mask = Tensor::empty({batch_size, 1, max_patches, max_patches}, kFloat32).alloc(); + for (int b = 0; b < batch_size; b++) { + for (int i = 0; i < max_patches; i++) { + for (int j = 0; j < max_patches; j++) { + float mask_i = patch_mask_float.at({b, i}); + float mask_j = patch_mask_float.at({b, j}); + // Both positions must be valid + float final_mask = (mask_i > 0.0f && mask_j > 0.0f) ? 
0.0f : -1e9f; + attention_mask.at({b, 0, i, j}) = final_mask; + } + } + } + } + + } + + auto encoder_outputs = encoder_(hidden_states, attention_mask)[0]; + + // Apply post layer norm + auto last_hidden_state = post_layernorm_(encoder_outputs); + + return {last_hidden_state}; + } + +}; + +} // namespace mllm::models::minicpmo diff --git a/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp b/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp new file mode 100644 index 000000000..571482283 --- /dev/null +++ b/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp @@ -0,0 +1,496 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp" + +namespace mllm::models::minicpmo { + +// same with qwen2 +// 参考: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen2_vl/tokenization_qwen2_vl.py +// (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| +// ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +inline bool miniCPMOTokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + // 1. Match contractions: "'s|'t|'re|'ve|'m|'ll|'d" + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + // 2. Match [^\r\n\p{L}\p{N}]?\p{L}+ (non-letter/digit followed by letters) + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + // Check optional non-letter/digit prefix (excluding \r\n) + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + // Require at least one letter + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else { + // Rollback if no letters after prefix + if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + } + + // 3. Match \p{N} (digits) + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + // 4. Match ?[^\s\p{L}\p{N}]+[\r\n]* (punctuation/symbols with optional space prefix) + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + // Optional space + if (str[pos] == L' ') { ++pos; } + + // Require at least one non-letter/digit/whitespace + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + // Capture from start (after optional space) to current pos + matched = str.substr(start, pos - start); + + // Capture trailing newlines + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + // Rollback if no symbols found + pos = original_pos; + } + } + + // 5. 
Match \s*[\r\n]+ (newlines with leading whitespace)
+  {
+    size_t start = pos;
+    while (pos < str.size() && std::iswspace(str[pos])) ++pos;
+    if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
+      while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos;
+      matched = str.substr(start, pos - start);
+      return true;
+    } else {
+      pos = start;
+    }
+  }
+
+  // 6. Match \s+(?!\S) (whitespace not followed by non-space)
+  if (std::iswspace(str[pos])) {
+    size_t start = pos;
+    while (pos < str.size() && std::iswspace(str[pos])) ++pos;
+    // Check if at end or followed by whitespace
+    if (pos >= str.size() || std::iswspace(str[pos])) {
+      matched = str.substr(start, pos - start);
+      return true;
+    } else {
+      pos = start;
+    }
+  }
+
+  // 7. Match remaining whitespace
+  if (std::iswspace(str[pos])) {
+    size_t start = pos;
+    while (pos < str.size() && std::iswspace(str[pos])) ++pos;
+    matched = str.substr(start, pos - start);
+    return true;
+  }
+
+  return false;
+}
+
+inline bool miniCPMORegex(const std::string& str, std::vector<std::wstring>& splitted) {
+  auto w_string = preprocessor::utf8string2WideString(str);
+  size_t pos = 0;
+  while (pos < w_string.size()) {
+    std::wstring matched;
+    if (miniCPMOTokenizerMatchPattern(w_string, pos, matched)) {
+      splitted.push_back(matched);
+    } else {
+      ++pos;
+    }
+  }
+  return true;
+}
+
+struct MiniCPMOMessage {
+  std::string prompt;
+  std::string img_file_path;
+  std::string audio_file_path;
+  std::string system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.";
+
+  // Format: <|im_start|>{role}\n{content}<|im_end|>\n
+  [[nodiscard]] std::string buildChatMessage() const {
+    std::string result = "";
+    // System message
+    if (!system_prompt.empty()) {
+      result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n";
+    }
+
+    result += "<|im_start|>user\n";
+
+    // Image placeholder
+    if (!img_file_path.empty()) {
+      result += "(./)";
+    }
+
+    // Audio placeholder
+    if (!audio_file_path.empty()) {
+      result += "()";
+    }
+
+    result += "\n" + prompt + "<|im_end|>\n";
+
+    // Assistant prompt start
+    result += "<|im_start|>assistant\n";
+
+    return result;
+  }
+};
+
+struct MiniCPMOInput {
+  std::string prompt;
+  std::string img_file_path = "";
+  std::string audio_file_path = "";
+  std::vector<std::string> image_paths = {};
+  std::vector<std::string> audio_paths = {};
+};
+
+class MiniCPMOTokenizer final : public mllm::preprocessor::AutoTokenizer {
+ public:
+  explicit MiniCPMOTokenizer(const std::string& file_path, int32_t patch_size = 14)
+  //: image_preprocessor_(patch_size)
+  {
+    preprocessor::initLocal();
+    preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_);
+    for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); }
+    bpe_.initFromSentencePieceJson(file_path);
+
+    special_tokens_trie_.add(L"<|endoftext|>");
+    special_tokens_trie_.add(L"<|im_start|>");
+    special_tokens_trie_.add(L"<|object_ref_start|>");
+    special_tokens_trie_.add(L"<|object_ref_end|>");
+    special_tokens_trie_.add(L"<|box_start|>");
+    special_tokens_trie_.add(L"<|box_end|>");
+    special_tokens_trie_.add(L"<|quad_start|>");
+    special_tokens_trie_.add(L"<|quad_end|>");
+    special_tokens_trie_.add(L"<|vision_start|>");
+    special_tokens_trie_.add(L"<|vision_end|>");
+    special_tokens_trie_.add(L"<|vision_pad|>");
+    special_tokens_trie_.add(L"<|image_pad|>");
+    special_tokens_trie_.add(L"<|video_pad|>");
+    special_tokens_trie_.add(L"");
+    special_tokens_trie_.add(L"");
+
+    // Add UNK token as special token
+    special_tokens_trie_.add(L"<unk>");
+
+    //
Image tokens + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + + // Audio tokens + special_tokens_trie_.add(L"<|audio_start|>"); + special_tokens_trie_.add(L"<|audio_end|>"); + special_tokens_trie_.add(L"<|spk_bos|>"); + special_tokens_trie_.add(L"<|spk_eos|>"); + special_tokens_trie_.add(L"<|tts_bos|>"); + special_tokens_trie_.add(L"<|tts_eos|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::minicpmo::miniCPMORegex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("minicpmo-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const MiniCPMOMessage& message) { + // 构建完整的聊天消息 + auto applied_string = message.buildChatMessage(); + // Process Image + if (!message.img_file_path.empty()) { + auto [img_tensors, original_size, tgt_sizes, grid] = image_preprocessor_.process(message.img_file_path); + //Checked with Python, all correct + + std::regex pattern(R"(\(\./\))"); + std::vector image_tags; + std::sregex_iterator iter(applied_string.begin(), applied_string.end(), pattern); + std::sregex_iterator end; + + for(; iter != end; ++iter){ + image_tags.push_back(iter->str()); + } + + std::vector text_chunks; + int32_t pos = 0; + for(const auto& tag : image_tags){ + auto found = applied_string.find(tag, pos); + if(found != std::string::npos){ + text_chunks.push_back(applied_string.substr(pos, found - pos)); + pos = found + tag.size(); + } + } + text_chunks.push_back(applied_string.substr(pos)); + std::string final_text = ""; + for(size_t i = 0; i < image_tags.size(); ++i){ + final_text += 
text_chunks[i];
+        final_text += image_preprocessor_.get_slice_image_placeholder(original_size[i], grid, i);
+      }
+      final_text += "\n";
+      final_text += text_chunks.back();
+      auto input_ids = tokenize(final_text);
+      std::vector<int64_t> input_ids_vec;
+      input_ids_vec.reserve(input_ids.size());
+      for (const auto& id_str : input_ids) {
+        input_ids_vec.emplace_back(bpe_._lookup_vocab(id_str));
+      }
+
+      auto [input_ids_new, image_bounds] = image_preprocessor_.calc_bounds(input_ids_vec, bpe_);
+      auto result = Convert2Tensors(input_ids_new, img_tensors, tgt_sizes, image_bounds);
+      return result;
+    } else {
+      // Handle a text-only message (no image input).
+      // // Handle the audio placeholder.
+      // if (!message.audio_file_path.empty()) {
+      //   size_t audio_placeholder_pos = applied_string.find("()");
+      //   if (audio_placeholder_pos != std::string::npos) {
+      //     // TODO: implement audio placeholder generation
+      //     std::string audio_placeholder = "<|audio_start|><|audio_end|>";  // simplified placeholder
+      //     applied_string.replace(audio_placeholder_pos, 17, audio_placeholder);
+      //   }
+      // }
+
+      // Tokenize the final string.
+      auto sequence_str = tokenize(applied_string);
+      std::vector<int64_t> ids;
+      ids.reserve(sequence_str.size());
+      for (const auto& str : sequence_str) {
+        ids.emplace_back(bpe_._lookup_vocab(str));
+      }
+
+      // Get sequence Tensor
+      Tensor sequence = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
+                            .setMemType(kNormal)
+                            .setName("minicpmo-tokenizer-i0")
+                            .alloc();
+
+      auto ptr = sequence.ptr<int64_t>();
+      for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; }
+
+      ARGenerationOutputPast result = {
+          {"input_ids", sequence},
+      };
+
+      return result;
+    }
+  }
+
+ private:
+  ARGenerationOutputPast Convert2Tensors(std::vector<int64_t>& input_ids_vec, std::vector<Tensor>& img_tensors,
+                                         std::vector<std::pair<int32_t, int32_t>>& tgt_sizes,
+                                         std::vector<std::pair<int32_t, int32_t>>& image_bounds) {
+    ARGenerationOutputPast result;
+
+    // Convert input_ids_vec (std::vector<int64_t>) to a Tensor.
+    if (!input_ids_vec.empty()) {
+      Tensor input_ids_tensor = Tensor::empty({1, (int32_t)input_ids_vec.size()}, kInt64, kCPU)
+                                    .setMemType(kExtraInput)
+                                    .setName("input_ids")
+                                    .alloc();
+      auto input_ids_ptr = input_ids_tensor.ptr<int64_t>();
+      for (size_t i = 0; i < input_ids_vec.size(); ++i) { input_ids_ptr[i] = input_ids_vec[i]; }
+      result["input_ids"] = input_ids_tensor;
+    }
+
+    // Convert img_tensors (std::vector<Tensor>) to a single Tensor, padding each
+    // image's patch dimension up to the batch maximum.
+    if (!img_tensors.empty()) {
+      if (img_tensors.size() == 1) {
+        result["pixel_values"] = img_tensors[0];
+      } else {
+        int channels = img_tensors[0].shape()[0];
+        int patch_size = img_tensors[0].shape()[1];
+        int HW_patch_size = img_tensors[0].shape()[2];
+        for (int i = 0; i < (int)img_tensors.size(); i++) {
+          if (img_tensors[i].shape()[2] > HW_patch_size) { HW_patch_size = img_tensors[i].shape()[2]; }
+        }
+        Tensor pixel_values = Tensor::empty({(int)img_tensors.size(), channels, patch_size, HW_patch_size}, kFloat32, kCPU)
+                                  .setMemType(kExtraInput)
+                                  .setName("pixel_values")
+                                  .alloc();
+        auto pixel_values_ptr = pixel_values.ptr<float>();
+        for (int b = 0; b < (int)img_tensors.size(); b++) {
+          for (int c = 0; c < channels; c++) {
+            for (int p = 0; p < patch_size; p++) {
+              for (int hw = 0; hw < HW_patch_size; hw++) {
+                int dst_idx = b * channels * patch_size * HW_patch_size + c * patch_size * HW_patch_size
+                              + p * HW_patch_size + hw;
+                if (hw >= img_tensors[b].shape()[2]) {
+                  pixel_values_ptr[dst_idx] = 0;
+                } else {
+                  pixel_values_ptr[dst_idx] = img_tensors[b].at<float>({c, p, hw});
+                }
+              }
+            }
+          }
+        }
+
+        result["pixel_values"] = pixel_values;
+      }
+    }
+
+    // Convert
tgt_sizes (std::vector>) to Tensor + if (!tgt_sizes.empty()) { + Tensor tgt_sizes_tensor = Tensor::empty({(int32_t)tgt_sizes.size(), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("tgt_sizes") + .alloc(); + auto tgt_sizes_ptr = tgt_sizes_tensor.ptr(); + for (size_t i = 0; i < tgt_sizes.size(); ++i) { + tgt_sizes_ptr[i * 2] = tgt_sizes[i].first; + tgt_sizes_ptr[i * 2 + 1] = tgt_sizes[i].second; + } + result["tgt_sizes"] = tgt_sizes_tensor; + } + + // Convert image_bounds (std::vector>) to Tensor + if (!image_bounds.empty()) { + Tensor image_bounds_tensor = Tensor::empty({(int32_t)image_bounds.size(), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("image_bounds") + .alloc(); + auto image_bounds_ptr = image_bounds_tensor.ptr(); + for (size_t i = 0; i < image_bounds.size(); ++i) { + image_bounds_ptr[i * 2] = image_bounds[i].first; + image_bounds_ptr[i * 2 + 1] = image_bounds[i].second; + } + result["image_bounds"] = image_bounds_tensor; + } + + return result; + } + + private: + // For image only. + MiniCPMOImageProcessor image_preprocessor_; + + // For text + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; + +}; + +} // namespace mllm::models::minicpmo diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index bb4fa54d9..5b826228c 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -25,6 +25,7 @@ #include "mllm/nn/layers/Param.hpp" // IWYU pragma: export #include "mllm/nn/layers/KVCache.hpp" // IWYU pragma: export #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export +#include "mllm/nn/layers/Conv2D.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export #include "mllm/nn/layers/RadixAttn.hpp" // IWYU pragma: export diff --git a/tasks/build_osx_apple_silicon.yaml b/tasks/build_osx_apple_silicon.yaml index 69dacb955..501cbd088 100644 --- a/tasks/build_osx_apple_silicon.yaml +++ b/tasks/build_osx_apple_silicon.yaml @@ -4,10 +4,11 @@ Tasks: cmake_build_type: "Release" cmake_extra_args: - "-DMLLM_BUILD_ARM_BACKEND=ON" - - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm+sme"' + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm"' - "-DMLLM_KERNEL_USE_THREADS=ON" - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=OFF" - "-DMLLM_KERNEL_THREADS_VENDOR_APPLE_GCD=ON" + - “-DCMAKE_CXX_FLAGS=-DMLLM_MACOS_BUILD” - CMakeBuildTask: cmake_cfg_path: "build-osx" diff --git a/tasks/build_osx_apple_silicon_dbg.yaml b/tasks/build_osx_apple_silicon_dbg.yaml index a99963c8e..a9d8e6e08 100644 --- a/tasks/build_osx_apple_silicon_dbg.yaml +++ b/tasks/build_osx_apple_silicon_dbg.yaml @@ -4,9 +4,7 @@ Tasks: cmake_build_type: "Debug" cmake_extra_args: - "-DMLLM_BUILD_ARM_BACKEND=ON" - - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm+sme"' - - "-DMLLM_USE_BLAS=ON" - - "-DMLLM_BLAS_VENDOR_ACCELERATE=ON" + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm"' - "-DMLLM_KERNEL_USE_THREADS=ON" - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=OFF" - "-DMLLM_KERNEL_THREADS_VENDOR_APPLE_GCD=ON" From 2aff5300658e6985295f0e7d4d8fba133a6b8bbc Mon Sep 17 00:00:00 2001 From: oreomaker <70836772+oreomaker@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:26:47 +0800 Subject: [PATCH 2/7] chore: restore omitted cmake demo directories --- examples/CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/CMakeLists.txt 
From 2aff5300658e6985295f0e7d4d8fba133a6b8bbc Mon Sep 17 00:00:00 2001
From: oreomaker <70836772+oreomaker@users.noreply.github.com>
Date: Thu, 30 Oct 2025 09:26:47 +0800
Subject: [PATCH 2/7] chore: restore omitted cmake demo directories

---
 examples/CMakeLists.txt | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index a0bef0240..9f8fd0332 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,8 +1,8 @@
-# add_subdirectory(qwen2vl)
-# add_subdirectory(qwen2vl_tracer)
-# add_subdirectory(qwen2_5vl)
-# add_subdirectory(qwen2_5vl_tracer)
-# add_subdirectory(llama)
+add_subdirectory(qwen2vl)
+add_subdirectory(qwen2vl_tracer)
+add_subdirectory(qwen2_5vl)
+add_subdirectory(qwen2_5vl_tracer)
+add_subdirectory(llama)
 add_subdirectory(minicpm_o)
 add_subdirectory(qwen3)
 add_subdirectory(qwen3_service)

From 284573f5ff60417793f3d92076a1fa3426deb52a Mon Sep 17 00:00:00 2001
From: oreomaker <70836772+oreomaker@users.noreply.github.com>
Date: Thu, 30 Oct 2025 09:29:40 +0800
Subject: [PATCH 3/7] chore: Update MiniCPMOMessage prompt in main.cpp

---
 examples/minicpm_o/main.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/minicpm_o/main.cpp b/examples/minicpm_o/main.cpp
index a7d9fd32c..b9a69c5a9 100644
--- a/examples/minicpm_o/main.cpp
+++ b/examples/minicpm_o/main.cpp
@@ -44,8 +44,9 @@ MLLM_MAIN({
     auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get());
     auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get());
 
-    mllm::models::minicpmo::MiniCPMOMessage message;
-    message.prompt = "现在你是太监,这个男子是皇上,你需要真心实意地奉承他";
+    mllm::models::minicpmo::MiniCPMOMessage message{
+      .prompt = "Introduce yourself"
+    };
     message.img_file_path = "/Users/kkkai/Desktop/pics.jpg";
     auto output = minicpmo_tokenizer.convertMessage(message);
     mllm::print(output["input_ids"].shape());

From 2e3840a6b2700601da63d48cb8b97e6b1b6aafda Mon Sep 17 00:00:00 2001
From: oreomaker <70836772+oreomaker@users.noreply.github.com>
Date: Thu, 30 Oct 2025 09:43:28 +0800
Subject: [PATCH 4/7] Update CPU backend compile options for ARM build

---
 tasks/build_osx_apple_silicon.yaml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tasks/build_osx_apple_silicon.yaml b/tasks/build_osx_apple_silicon.yaml
index 501cbd088..69dacb955 100644
--- a/tasks/build_osx_apple_silicon.yaml
+++ b/tasks/build_osx_apple_silicon.yaml
@@ -4,11 +4,10 @@ Tasks:
       cmake_build_type: "Release"
       cmake_extra_args:
         - "-DMLLM_BUILD_ARM_BACKEND=ON"
-        - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm"'
+        - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm+sme"'
         - "-DMLLM_KERNEL_USE_THREADS=ON"
         - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=OFF"
         - "-DMLLM_KERNEL_THREADS_VENDOR_APPLE_GCD=ON"
-        - “-DCMAKE_CXX_FLAGS=-DMLLM_MACOS_BUILD”
 
   - CMakeBuildTask:
       cmake_cfg_path: "build-osx"

From 811eee2e5ef4c854d5a76b6aad7f1854cbd14095 Mon Sep 17 00:00:00 2001
From: oreomaker <70836772+oreomaker@users.noreply.github.com>
Date: Thu, 30 Oct 2025 09:43:55 +0800
Subject: [PATCH 5/7] Update CPU backend compile options for ARM

---
 tasks/build_osx_apple_silicon_dbg.yaml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tasks/build_osx_apple_silicon_dbg.yaml b/tasks/build_osx_apple_silicon_dbg.yaml
index a9d8e6e08..a99963c8e 100644
--- a/tasks/build_osx_apple_silicon_dbg.yaml
+++ b/tasks/build_osx_apple_silicon_dbg.yaml
@@ -4,7 +4,9 @@ Tasks:
       cmake_build_type: "Debug"
      cmake_extra_args:
         - "-DMLLM_BUILD_ARM_BACKEND=ON"
-        - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm"'
+        - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native+fp16+fp16fml+dotprod+i8mm+sme"'
+        - "-DMLLM_USE_BLAS=ON"
+        - "-DMLLM_BLAS_VENDOR_ACCELERATE=ON"
         - "-DMLLM_KERNEL_USE_THREADS=ON"
         - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=OFF"
         - "-DMLLM_KERNEL_THREADS_VENDOR_APPLE_GCD=ON"
"-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=OFF" - "-DMLLM_KERNEL_THREADS_VENDOR_APPLE_GCD=ON" From e11ba5cbe7a8aeeee0c0df8a9872e9122b890916 Mon Sep 17 00:00:00 2001 From: KKkai <1640576073@qq.com> Date: Thu, 30 Oct 2025 17:34:49 +0800 Subject: [PATCH 6/7] fix: resolve code review issues --- examples/minicpm_o/CMakeLists.txt | 8 +- examples/minicpm_o/main.cpp | 164 ++++++------------ examples/minicpm_o/mainllm.cpp | 100 ----------- .../image_preprocessor_minicpmo.hpp | 1 - .../models/minicpm_o2_6/modeling_minicpmo.hpp | 65 +------ .../minicpm_o2_6/modeling_resampler.hpp | 15 +- .../minicpm_o2_6/tokenization_minicpmo.hpp | 2 +- 7 files changed, 61 insertions(+), 294 deletions(-) delete mode 100644 examples/minicpm_o/mainllm.cpp diff --git a/examples/minicpm_o/CMakeLists.txt b/examples/minicpm_o/CMakeLists.txt index 1f9f8e1a9..9c2a7881c 100644 --- a/examples/minicpm_o/CMakeLists.txt +++ b/examples/minicpm_o/CMakeLists.txt @@ -1,17 +1,15 @@ cmake_minimum_required(VERSION 3.10) -include_directories($ENV{HOME}/local/include) -link_directories($ENV{HOME}/local/lib) add_executable(main_minicpm_o main.cpp) -target_link_libraries(main_minicpm_o PRIVATE MllmRT MllmCPUBackend cnpy z) +target_link_libraries(main_minicpm_o PRIVATE MllmRT MllmCPUBackend) target_include_directories(main_minicpm_o PRIVATE ${MLLM_INCLUDE_DIR}) add_executable(main_minicpm_o2 mainllm.cpp) -target_link_libraries(main_minicpm_o2 PRIVATE MllmRT MllmCPUBackend cnpy z) +target_link_libraries(main_minicpm_o2 PRIVATE MllmRT MllmCPUBackend) target_include_directories(main_minicpm_o2 PRIVATE ${MLLM_INCLUDE_DIR}) add_executable(main_minicpm_dbg main_dbg.cpp) -target_link_libraries(main_minicpm_dbg PRIVATE MllmRT MllmCPUBackend cnpy z) +target_link_libraries(main_minicpm_dbg PRIVATE MllmRT MllmCPUBackend) target_include_directories(main_minicpm_dbg PRIVATE ${MLLM_INCLUDE_DIR}) add_executable(tokenizer_test tokenizer_test.cpp) diff --git a/examples/minicpm_o/main.cpp b/examples/minicpm_o/main.cpp index b9a69c5a9..06ae30964 100644 --- a/examples/minicpm_o/main.cpp +++ b/examples/minicpm_o/main.cpp @@ -2,7 +2,7 @@ #include #include "mllm/mllm.hpp" #include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp" -// #include "mllm/models/minicpm_o2_6/modeling_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/modeling_minicpmo.hpp" #include "mllm/models/minicpm_o2_6/modeling_resampler.hpp" #include "mllm/models/minicpm_o2_6/modeling_siglip.hpp" #include "mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp" @@ -14,13 +14,20 @@ using mllm::Argparse; MLLM_MAIN({ mllm::Logger::level() = mllm::LogLevel::kError; - + //mllm::setPrintMaxElementsPerDim(1000); // For debugging large tensors + auto& help = Argparse::add("-h|--help").help("Show help message"); auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); - // RUN: ./main_minicpm_o -m ../../models/minicpm-o-2_6.mllm -mv v1 -t ../../tokenizer/MiniCPM-o-2_6/tokenizer.json -c ../../examples/minicpm_o/config_minicpm_o.json + /* + FOR RUN(MacOS Apple Silicon): + python task.py tasks/build_osx_apple_silicon.yaml + cd build-osx/bin + ./main_minicpm_o -m ../../models/minicpm-o-2_6.mllm -mv v1 -t ../../tokenizer/MiniCPM-o-2_6/tokenizer.json -c ../../examples/minicpm_o/config_minicpm_o.json + 
+  (you need to get model.mllm and tokenizer.json first)
+  */
 
   Argparse::parse(argc, argv);
 
@@ -28,7 +35,7 @@ MLLM_MAIN({
   mllm::perf::start();
 #endif
 
-  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; 
+  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
   if (model_version.get() == "v1") {
     file_version = mllm::ModelFileVersion::kV1;
   } else if (model_version.get() == "v2") {
@@ -41,120 +48,51 @@ MLLM_MAIN({
     return 0;
   }
   {
-    auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get());
-    auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get());
-
-    mllm::models::minicpmo::MiniCPMOMessage message{
-      .prompt = "Introduce yourself"
-    };
-    message.img_file_path = "/Users/kkkai/Desktop/pics.jpg";
-    auto output = minicpmo_tokenizer.convertMessage(message);
-    mllm::print(output["input_ids"].shape());
-    mllm::print(output["pixel_values"].shape());
-    mllm::print(output["tgt_sizes"].shape());
-    mllm::print(output["image_bounds"].shape());
-
-    auto param = mllm::load(model_path.get(), file_version);
-    auto siglip = mllm::models::minicpmo::SiglipVisionModel("vpm", minicpmo_cfg);
-    siglip.load(param);
-    auto res = siglip(output["pixel_values"], output["tgt_sizes"])[0];
-    auto resampler = mllm::models::minicpmo::Resampler("resampler", 64, 3584, 28, 1152);
-    resampler.load(param);
-    auto res2 = resampler(res, output["tgt_sizes"])[0];
-
-    // auto minicpmo = mllm::models::minicpmo::MiniCPMOForCausalLM(minicpmo_cfg);
-
-    // // Load model weights
-    // auto param = mllm::load(model_path.get(), file_version);
-    // minicpmo.load(param);
-
-    // fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI ");
-    // fmt::print("Enter 'exit' or 'quit' to end the session\n");
-    // fmt::print("Supported modes: text, image+text, audio+text, multimodal\n\n");
-
-    // while (true) {
-    //   std::string mode;
-    //   fmt::print("Mode (text/image/audio/multi) or 'exit': ");
-    //   std::getline(std::cin, mode);
-
-    //   if (mode == "exit" || mode == "quit") {
-    //     break;
-    //   }
-
-    //   mllm::models::minicpmo::MiniCPMOInput input;
-
-    //   // Handle different input modes
-    //   if (mode == "image" || mode == "multi") {
-    //     std::string image_path;
-    //     fmt::print("Image path: ");
-    //     std::getline(std::cin, image_path);
-    //     if (!image_path.empty()) {
-    //       input.img_file_path = image_path;
-    //     }
-    //   }
-
-    //   if (mode == "audio" || mode == "multi") {
-    //     std::string audio_path;
-    //     fmt::print("Audio path: ");
-    //     std::getline(std::cin, audio_path);
-    //     if (!audio_path.empty()) {
-    //       input.audio_file_path = audio_path;
-    //     }
-    //   }
-
-    //   std::string prompt_text;
-    //   fmt::print("Prompt text: ");
-    //   std::getline(std::cin, prompt_text);
-    //   input.prompt = prompt_text;
-
-    //   try {
-    //     fmt::print("Processing...\n");
-
-    //     // Convert input to tokens
-    //     auto input_tokens = minicpmo_tokenizer.convertMessage(input);
-
-    //     // Process images if provided
-    //     auto image_tensors = minicpmo_tokenizer.processImages(input);
-
-    //     // Process audio if provided
-    //     auto audio_tensors = minicpmo_tokenizer.processAudio(input);
-
-    //     fmt::print("\nResponse: ");
-
-    //     // TODO: Implement multimodal chat interface
-    //     // For now, use text-only generation
-    //     std::vector token_ids;
-    //     auto input_ptr = input_tokens.ptr();
-    //     auto seq_len = input_tokens.shape()[1];
-    //     for (int i = 0; i < seq_len; ++i) {
-    //       token_ids.push_back(input_ptr[i]);
-    //     }
-
-    //     // Generate response
-    //     for (auto& step : minicpmo.chat(token_ids)) {
-    //       auto token_str = minicpmo_tokenizer.detokenize(step.cur_token_id);
-    //       std::wcout << token_str << std::flush;
-
-    //       // TODO: Check for audio generation tokens
-    //       if (minicpmo_tokenizer.isAudioToken(step.cur_token_id)) {
-    //         fmt::print("\n🔊 [Audio generation triggered - feature not implemented yet]\n");
-    //       }
-    //     }
-
-    //     fmt::print("\n\n");
-
-    //   } catch (const std::exception& e) {
-    //     fmt::print(" Error: {}\n", e.what());
-    //   }
-    // }
-
-    // fmt::print("Success!\n");
-  }
+    auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get());
+    auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get());
+    auto minicpmo = mllm::models::minicpmo::MiniCPMOForCausalLM(minicpmo_cfg);
+
+    auto param = mllm::load(model_path.get(), file_version);
+    minicpmo.llm_.llm.load(param);
+    minicpmo.vpm_.load(param);
+    minicpmo.resampler_.load(param);
+    // minicpmo.audio_proj_.load(param);
+    // minicpmo.tts_proj_.load(param);
+
+    fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI ");
+    fmt::print("Enter 'exit' or 'quit' to end the session\n");
+
+    std::string image_path = "path/to/your/image.jpg";
+    std::string prompt_text = "Describe the objects in the image";
+    mllm::models::minicpmo::MiniCPMOMessage message;
+    message.prompt = prompt_text;
+    message.img_file_path = image_path;
+
+    fmt::print("Processing...\n");
+    auto inputs = minicpmo_tokenizer.convertMessage(message);
+
+    fmt::print("\nResponse: ");
+
+    int token_count = 0;
+    for (auto& step : minicpmo.chat(inputs)) {
+      auto token_str = minicpmo_tokenizer.detokenize(step.cur_token_id);
+      std::wcout << token_str << std::flush;
+
+      token_count++;
+      if (token_count >= 50) break;  // Limit output for debugging
+    }
+
+    fmt::print("\n{}\n", std::string(60, '-'));
+
+
 #ifdef MLLM_PERFETTO_ENABLE
   mllm::perf::stop();
+  mllm::perf::saveReport("minicpmo.perf");
 #endif
 
+  mllm::memoryReport();
   mllm::shutdownContext();
   return 0;
+  }
 })
diff --git a/examples/minicpm_o/mainllm.cpp b/examples/minicpm_o/mainllm.cpp
deleted file mode 100644
index 48712eab3..000000000
--- a/examples/minicpm_o/mainllm.cpp
+++ /dev/null
@@ -1,100 +0,0 @@
-#include 
-#include 
-#include "mllm/mllm.hpp"
-#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp"
-#include "mllm/models/minicpm_o2_6/modeling_minicpmo.hpp"
-#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp"
-#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp"
-#include "mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp"
-#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp"
-#include "mllm/utils/AnyValue.hpp"
-#include "mllm/preprocessor/visual/Image.hpp"
-#include "cnpy.h"
-
-using mllm::Argparse;
-
-MLLM_MAIN({
-
-  mllm::Logger::level() = mllm::LogLevel::kError;
-  //mllm::setPrintMaxElementsPerDim(1000);
-
-  auto& help = Argparse::add("-h|--help").help("Show help message");
-  auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true);
-  auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true);
-  auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true);
-  auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true);
-  // RUN: ./main_minicpm_o2 -m ../../models/minicpm-o-2_6.mllm -mv v1 -t ../../tokenizer/MiniCPM-o-2_6/tokenizer.json -c ../../examples/minicpm_o/config_minicpm_o.json
-
-  Argparse::parse(argc, argv);
-
-#ifdef MLLM_PERFETTO_ENABLE
-  mllm::perf::start();
-#endif
-
-  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
(model_version.get() == "v1") { - file_version = mllm::ModelFileVersion::kV1; - } else if (model_version.get() == "v2") { - file_version = mllm::ModelFileVersion::kV2; - } - - if (help.isSet()) { - Argparse::printHelp(); - mllm::shutdownContext(); - return 0; - } - { - auto minicpmo_cfg = mllm::models::minicpmo::MiniCPMOConfig(config_path.get()); - auto minicpmo_tokenizer = mllm::models::minicpmo::MiniCPMOTokenizer(tokenizer_path.get()); - auto minicpmo = mllm::models::minicpmo::MiniCPMOForCausalLM(minicpmo_cfg); - - auto param = mllm::load(model_path.get(), file_version); - minicpmo.llm_.llm.load(param); - minicpmo.vpm_.load(param); - minicpmo.resampler_.load(param); - //minicpmo.audio_proj_.load(param); - //minicpmo.tts_proj_.load(param); - - fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI "); - fmt::print("Enter 'exit' or 'quit' to end the session\n"); - - std::string image_path = "/Users/kkkai/Desktop/pics.jpg"; - std::string prompt_text = "描述图片中物体"; - mllm::models::minicpmo::MiniCPMOMessage message; - message.prompt = prompt_text; - message.img_file_path = image_path; - - // fmt::print("📷 Image path (or 'exit/quit'): "); - // std::getline(std::cin, image_path); - // if (image_path == "exit" || image_path == "quit") { return 0; } - // fmt::print("💬 Prompt text: "); - // std::getline(std::cin, prompt_text); - - fmt::print("Processing...\n"); - auto inputs = minicpmo_tokenizer.convertMessage(message); - - fmt::print("\nResponse: "); - - int token_count = 0; - for(auto& step : minicpmo.chat(inputs)){ - auto token_str = minicpmo_tokenizer.detokenize(step.cur_token_id); - std::wcout<< token_str << std::flush; - - token_count++; - if(token_count >= 50) break; // Limit output for debugging - } - - fmt::print("\n{}\n", std::string(60, '-')); - - - -#ifdef MLLM_PERFETTO_ENABLE - mllm::perf::stop(); - mllm::perf::saveReport("minicpmo.perf"); -#endif - - mllm::memoryReport(); - mllm::shutdownContext(); - return 0; - } -}) diff --git a/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp b/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp index 2f161b654..5bccda7df 100644 --- a/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp +++ b/mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp @@ -19,7 +19,6 @@ namespace mllm::models::minicpmo { -// Utility functions for image slicing (similar to MiniCPMV) class ImageSliceProcessor { public: ImageSliceProcessor() = default; diff --git a/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp b/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp index 2233f34dd..ede2617dd 100644 --- a/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp +++ b/mllm/models/minicpm_o2_6/modeling_minicpmo.hpp @@ -109,9 +109,6 @@ class MiniCPMOForCausalLM : public models::ARGeneration { //AudioProjectionLayer audio_projection_layer_; //TTSProjector tts_projector_; - // Debug flag to control whether to load Python embeddings - bool loadPythonEmbedding = false; - private: nn::StaticCache kv_cache_; @@ -161,52 +158,7 @@ class MiniCPMOForCausalLM : public models::ARGeneration { // Process vision inputs if provided - ONLY in prefill stage if (!pixel_values.isNil() && !tgt_sizes.isNil() && !is_decode_stage) { auto vision_outputs = vpm_(pixel_values, tgt_sizes)[0]; - // std::vector vision_outputs_vec; - // vision_outputs_vec.reserve(10*1036*1152); - // for(int i=0;i<10;i++){ - // for(int j=0;j<1036;j++){ - // for(int k=0;k<1152;k++){ - // vision_outputs_vec.push_back(vision_outputs.at({i,j,k})); - // } - // } - // } - // cnpy::npy_save("vision_outputs.npy", - // 
-      //               vision_outputs_vec.data(),
-      //               {10, 1036, 1152},
-      //               "w");
       auto vision_embeddings = resampler_(vision_outputs, tgt_sizes)[0];
-      std::vector<float> vision_embeddings_vec;
-      vision_embeddings_vec.reserve(10*64*3584);
-      for(int i=0;i<10;i++){
-        for(int j=0;j<64;j++){
-          for(int k=0;k<3584;k++){
-            vision_embeddings_vec.push_back(vision_embeddings.at<float>({i,j,k}));
-          }
-        }
-      }
-      cnpy::npy_save("vision_embeddings.npy",
-                     vision_embeddings_vec.data(),
-                     {10, 64, 3584},
-                     "w");
-      mllm::print(vision_embeddings.shape());
-      mllm::print(vision_embeddings.at<float>({0,0,0}));
-      mllm::print(vision_embeddings.at<float>({0,14,175}));
-      mllm::print(vision_embeddings.at<float>({1,28,2995}));
-      mllm::print(vision_embeddings.at<float>({1,33,1365}));
-      mllm::print(vision_embeddings.at<float>({2,8,764}));
-      mllm::print(vision_embeddings.at<float>({3,49,2222}));
-      mllm::print(vision_embeddings.at<float>({4,62,2003}));
-      mllm::print(vision_embeddings.at<float>({5,55,1013}));
-      mllm::print(vision_embeddings.at<float>({6,19,75}));
-      mllm::print(vision_embeddings.at<float>({7,21,196}));
-      mllm::print(vision_embeddings.at<float>({8,50,1997}));
-      mllm::print(vision_embeddings.at<float>({9,33,2958}));
-      mllm::print(vision_embeddings.at<float>({8,2,2598}));
-      mllm::print(vision_embeddings.at<float>({7,5,338}));
-      mllm::print(vision_embeddings.at<float>({6,41,1157}));
-      mllm::print(vision_embeddings.at<float>({5,61,2075}));
-      mllm::print(vision_embeddings.at<float>({4,55,312}));
-
       if (!image_bounds.isNil()) {
         input_embeddings = merge_vision_text_embeddings(input_embeddings, vision_embeddings, image_bounds);
       }
@@ -308,6 +260,7 @@ class MiniCPMOForCausalLM : public models::ARGeneration {
         int vision_idx = 0;
         auto start_pos = image_bounds.at<int32_t>({bound_idx, 0}) + 1;
         auto end_pos = image_bounds.at<int32_t>({bound_idx, 1}) - 1;
+        // exactly replace the tokens between <image> and </image>
         for (int pos = start_pos; pos <= end_pos && vision_idx < vision_seq_len; ++pos, ++vision_idx) {
           for (int d = 0; d < embed_dim; ++d) {
             text_embeddings.at<float>({b, pos, d}) = vision_embeddings.at<float>({bound_idx, vision_idx, d});
@@ -316,22 +269,6 @@ class MiniCPMOForCausalLM : public models::ARGeneration {
         }
       }
     }
-
-    mllm::print("finished merging!");
-
-    // Debug: Load and replace with Python-saved embeddings
-    if (loadPythonEmbedding) {
-      cnpy::NpyArray arr = cnpy::npy_load("../../models/merged_input_embedding.npy");
-      float* data_ptr = arr.data<float>();
-      std::vector<float> vec(data_ptr, data_ptr + arr.num_vals);
-      auto tt = mllm::Tensor::fromVector(vec, {1,699,3584}, mllm::kFloat32);
-      mllm::print(tt.shape());
-      mllm::print(text_embeddings.shape());
-      text_embeddings = tt;
-      mllm::print("✅ Loaded Python embedding for debugging!");
-      return tt;
-    }
-
     return text_embeddings;
   }
diff --git a/mllm/models/minicpm_o2_6/modeling_resampler.hpp b/mllm/models/minicpm_o2_6/modeling_resampler.hpp
index e14c6d24a..55cb9f4d0 100644
--- a/mllm/models/minicpm_o2_6/modeling_resampler.hpp
+++ b/mllm/models/minicpm_o2_6/modeling_resampler.hpp
@@ -31,16 +31,12 @@ inline Tensor get2DSinCosPosEmbed(int32_t embed_dim, const std::vector<int32_t>&
 
   for (int32_t h = 0; h < grid_h_size; ++h) {
     for (int32_t w = 0; w < grid_w_size; ++w) {
-      // Width encoding: first half_dim dimensions (inferred from the Python output)
-      // [0:half_dim/2] = sin(w*omega), [half_dim/2:half_dim] = cos(w*omega)
       for (int32_t i = 0; i < half_dim / 2; ++i) {
         float omega = 1.0f / std::pow(10000.0f, 2.0f * i / half_dim);
         *pos_embed.offsettedPtr<float>({h, w, i}) = std::sin(w * omega);
         *pos_embed.offsettedPtr<float>({h, w, i + half_dim / 2}) = std::cos(w * omega);
       }
-      // Height encoding: second half_dim dimensions
-      // [half_dim:half_dim+half_dim/2] = sin(h*omega), [half_dim+half_dim/2:embed_dim] = cos(h*omega)
       for (int32_t i = 0; i < half_dim / 2; ++i) {
         float omega = 1.0f / std::pow(10000.0f, 2.0f * i / half_dim);
         *pos_embed.offsettedPtr<float>({h, w, half_dim + i}) = std::sin(h * omega);
@@ -172,7 +168,6 @@ class ResamplerAttention : public nn::Module {
     auto attn_weights = nn::functional::matmul(q, k, false, true) * scale;  // [num_heads, num_queries, seq_len]
 
     if (has_key_padding_mask && key_padding_mask.numel() > 0) {
-      mllm::print("Applying key padding mask in ResamplerAttention");
       auto mask_value = -std::numeric_limits<float>::infinity();
       for (int32_t h = 0; h < num_heads_; ++h) {
         for (int32_t q_idx = 0; q_idx < num_queries; ++q_idx) {
@@ -192,14 +187,14 @@ class ResamplerAttention : public nn::Module {
     auto attn_output = nn::functional::matmul(attn_weights, v);  // [num_heads, num_queries, head_dim]
 
     auto attn_output_reshaped = Tensor::empty({num_queries, embed_dim_}, kFloat32).alloc();
-for(int h = 0; h < num_heads_; h++) {
-  for(int nq = 0; nq < num_queries; nq++) {
-    for(int d = 0; d < head_dim_; d++) {
+    for(int h = 0; h < num_heads_; h++) {
+      for(int nq = 0; nq < num_queries; nq++) {
+        for(int d = 0; d < head_dim_; d++) {
           float val = attn_output.at<float>({h, nq, d});
           *attn_output_reshaped.offsettedPtr<float>({nq, h * head_dim_ + d}) = val;
+        }
+      }
     }
-  }
-}
 
     attn_output = attn_output_reshaped;
     return {out_proj_(attn_output)};
diff --git a/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp b/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp
index 571482283..f1503ee90 100644
--- a/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp
+++ b/mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp
@@ -155,8 +155,8 @@ struct MiniCPMOMessage {
   std::string audio_file_path;
   std::string system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.";
 
-  // Format: <|im_start|>{role}\n{content}<|im_end|>\n
   [[nodiscard]]std::string buildChatMessage() const {
+    // For now, one picture only
     std::string result="";
     // System message
     if (!system_prompt.empty()) {

From 5a6e66627a7f003654325ad9d864ca4b91adcca4 Mon Sep 17 00:00:00 2001
From: oreomaker <70836772+oreomaker@users.noreply.github.com>
Date: Fri, 31 Oct 2025 08:34:15 +0800
Subject: [PATCH 7/7] Clean up CMakeLists.txt by removing unused executables

Removed multiple executable definitions and their associated libraries and
include directories from CMakeLists.txt.
---
 examples/minicpm_o/CMakeLists.txt | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/examples/minicpm_o/CMakeLists.txt b/examples/minicpm_o/CMakeLists.txt
index 9c2a7881c..84f48d365 100644
--- a/examples/minicpm_o/CMakeLists.txt
+++ b/examples/minicpm_o/CMakeLists.txt
@@ -3,15 +3,3 @@ cmake_minimum_required(VERSION 3.10)
 add_executable(main_minicpm_o main.cpp)
 target_link_libraries(main_minicpm_o PRIVATE MllmRT MllmCPUBackend)
 target_include_directories(main_minicpm_o PRIVATE ${MLLM_INCLUDE_DIR})
-
-add_executable(main_minicpm_o2 mainllm.cpp)
-target_link_libraries(main_minicpm_o2 PRIVATE MllmRT MllmCPUBackend)
-target_include_directories(main_minicpm_o2 PRIVATE ${MLLM_INCLUDE_DIR})
-
-add_executable(main_minicpm_dbg main_dbg.cpp)
-target_link_libraries(main_minicpm_dbg PRIVATE MllmRT MllmCPUBackend)
-target_include_directories(main_minicpm_dbg PRIVATE ${MLLM_INCLUDE_DIR})
-
-add_executable(tokenizer_test tokenizer_test.cpp)
-target_link_libraries(tokenizer_test PRIVATE MllmRT MllmCPUBackend)
-target_include_directories(tokenizer_test PRIVATE ${MLLM_INCLUDE_DIR})
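
[Illustrative aside, not part of the patch series: PATCH 6 strips the comments
that documented the layout built by get2DSinCosPosEmbed in
modeling_resampler.hpp, so here is a minimal standalone restatement of that
layout (width encoded in the first half_dim channels, height in the second,
each split into sin/cos halves). It is a sketch using plain std::vector instead
of mllm::Tensor; the function name sinCos2DPosEmbed is hypothetical.]

    #include <cmath>
    #include <vector>

    // Layout per (h, w) cell, with half = dim / 2 and omega_i = 10000^(-2i/half):
    //   channels [0, half/2)           = sin(w * omega_i)
    //   channels [half/2, half)        = cos(w * omega_i)
    //   channels [half, half + half/2) = sin(h * omega_i)
    //   channels [half + half/2, dim)  = cos(h * omega_i)
    std::vector<float> sinCos2DPosEmbed(int dim, int grid_h, int grid_w) {
      int half = dim / 2;
      std::vector<float> out((size_t)grid_h * grid_w * dim, 0.0f);
      for (int h = 0; h < grid_h; ++h) {
        for (int w = 0; w < grid_w; ++w) {
          float* cell = out.data() + ((size_t)h * grid_w + w) * dim;
          for (int i = 0; i < half / 2; ++i) {
            float omega = 1.0f / std::pow(10000.0f, 2.0f * i / half);
            cell[i] = std::sin(w * omega);               // width, sin
            cell[i + half / 2] = std::cos(w * omega);    // width, cos
            cell[half + i] = std::sin(h * omega);        // height, sin
            cell[half + half / 2 + i] = std::cos(h * omega);  // height, cos
          }
        }
      }
      return out;
    }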