From 74a989b70880fdfa5d5c663c8177a0657bdfad9e Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 22 Oct 2025 22:21:51 +0800 Subject: [PATCH 01/25] feat(deepseek-ocr): deepseek-ocr support - Implement Conv2DOp and LayerNorm2DOp for CPU backend - Add ARM NEON optimized kernels for conv2d and layernorm2d - Integrate new ops into IR and op registration system - Extend Tensor API with size() and flatten() methods - Update RTTI and OpTypes to support new operations --- mllm/backends/cpu/CPUBackend.cpp | 5 +- mllm/backends/cpu/kernels/Kernels.hpp | 2 + mllm/backends/cpu/kernels/arm/conv2d.cpp | 100 +++ mllm/backends/cpu/kernels/arm/conv2d.hpp | 34 + mllm/backends/cpu/kernels/arm/layernorm2d.cpp | 89 +++ mllm/backends/cpu/kernels/arm/layernorm2d.hpp | 18 + mllm/backends/cpu/ops/Conv2DOp.cpp | 144 ++++ mllm/backends/cpu/ops/Conv2DOp.hpp | 27 + mllm/backends/cpu/ops/LayerNorm2DOp.cpp | 38 + mllm/backends/cpu/ops/LayerNorm2DOp.hpp | 25 + mllm/compile/ir/GeneratedRTTIKind.hpp | 3 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 5 +- mllm/compile/ir/linalg/Op.cpp | 2 + mllm/compile/ir/linalg/Op.hpp | 3 + mllm/compile/ir/rtti_kind_gen.py | 1 + mllm/core/OpTypes.hpp | 2 + mllm/core/Tensor.cpp | 32 + mllm/core/Tensor.hpp | 16 + mllm/core/aops/Conv2DOp.cpp | 113 +++ mllm/core/aops/Conv2DOp.hpp | 72 ++ mllm/core/aops/LayerNorm2DOp.cpp | 56 ++ mllm/core/aops/LayerNorm2DOp.hpp | 44 ++ mllm/models/deepseek_ocr/README.md | 55 ++ .../configuration_deepseek_ocr.hpp | 187 +++++ mllm/models/deepseek_ocr/conversation.hpp | 2 + mllm/models/deepseek_ocr/deepencoder.hpp | 732 ++++++++++++++++++ .../deepseek_ocr/modeling_deepseek_ocr.hpp | 2 + .../tokenization_deepseek_ocr.hpp | 51 ++ mllm/nn/Functional.cpp | 9 + mllm/nn/Functional.hpp | 3 + mllm/nn/Module.hpp | 4 +- mllm/nn/Nn.hpp | 2 + mllm/nn/layers/Conv2D.cpp | 29 + mllm/nn/layers/Conv2D.hpp | 27 + mllm/nn/layers/LayerNorm2D.cpp | 16 + mllm/nn/layers/LayerNorm2D.hpp | 23 + tests/cpu/Conv2DKernelTest.hpp | 119 +++ tests/cpu/FlashAttentionKernelTest.hpp | 1 - tests/cpu/KernelTest.cpp | 49 ++ 39 files changed, 2136 insertions(+), 6 deletions(-) create mode 100644 mllm/backends/cpu/kernels/arm/conv2d.cpp create mode 100644 mllm/backends/cpu/kernels/arm/conv2d.hpp create mode 100644 mllm/backends/cpu/kernels/arm/layernorm2d.cpp create mode 100644 mllm/backends/cpu/kernels/arm/layernorm2d.hpp create mode 100644 mllm/backends/cpu/ops/Conv2DOp.cpp create mode 100644 mllm/backends/cpu/ops/Conv2DOp.hpp create mode 100644 mllm/backends/cpu/ops/LayerNorm2DOp.cpp create mode 100644 mllm/backends/cpu/ops/LayerNorm2DOp.hpp create mode 100644 mllm/core/aops/Conv2DOp.cpp create mode 100644 mllm/core/aops/Conv2DOp.hpp create mode 100644 mllm/core/aops/LayerNorm2DOp.cpp create mode 100644 mllm/core/aops/LayerNorm2DOp.hpp create mode 100644 mllm/models/deepseek_ocr/README.md create mode 100644 mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp create mode 100644 mllm/models/deepseek_ocr/conversation.hpp create mode 100644 mllm/models/deepseek_ocr/deepencoder.hpp create mode 100644 mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp create mode 100644 mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp create mode 100644 mllm/nn/layers/Conv2D.cpp create mode 100644 mllm/nn/layers/Conv2D.hpp create mode 100644 mllm/nn/layers/LayerNorm2D.cpp create mode 100644 mllm/nn/layers/LayerNorm2D.hpp create mode 100644 tests/cpu/Conv2DKernelTest.hpp diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index c94871b90..dd8549804 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -10,6 +10,7 @@ #include "mllm/backends/cpu/ops/ConcatOp.hpp" #include "mllm/backends/cpu/ops/ContiguousOp.hpp" #include "mllm/backends/cpu/ops/Conv1DOp.hpp" +#include "mllm/backends/cpu/ops/Conv2DOp.hpp" #include "mllm/backends/cpu/ops/Conv3DOp.hpp" #include "mllm/backends/cpu/ops/CopyOp.hpp" #include "mllm/backends/cpu/ops/ElewiseOps.hpp" @@ -17,6 +18,7 @@ #include "mllm/backends/cpu/ops/FillOp.hpp" #include "mllm/backends/cpu/ops/FlashAttention2Op.hpp" #include "mllm/backends/cpu/ops/GELUOp.hpp" +#include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp" #include "mllm/backends/cpu/ops/RadixAttnOp.hpp" #include "mllm/backends/cpu/ops/ReLUOp.hpp" #include "mllm/backends/cpu/ops/GraphOps.hpp" @@ -60,7 +62,8 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, - CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory>(); + CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, + CPUConv2DOpFactory, CPULayerNorm2DOpFactory>(); } std::shared_ptr createCPUBackend() { return std::make_shared(); } diff --git a/mllm/backends/cpu/kernels/Kernels.hpp b/mllm/backends/cpu/kernels/Kernels.hpp index e68b70976..026cc84a8 100644 --- a/mllm/backends/cpu/kernels/Kernels.hpp +++ b/mllm/backends/cpu/kernels/Kernels.hpp @@ -30,6 +30,8 @@ #include "mllm/backends/cpu/kernels/arm/conv3d.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/linear/kai.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/relu.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/arm/layernorm2d.hpp" // IWYU pragma: export #include "mllm/backends/cpu/kernels/arm/mllm_blas/mllm_blas_sgemm.hpp" // IWYU pragma: export #else #include "mllm/backends/cpu/kernels/common/gelu-inl.hpp" // IWYU pragma: export diff --git a/mllm/backends/cpu/kernels/arm/conv2d.cpp b/mllm/backends/cpu/kernels/arm/conv2d.cpp new file mode 100644 index 000000000..d7ceddf18 --- /dev/null +++ b/mllm/backends/cpu/kernels/arm/conv2d.cpp @@ -0,0 +1,100 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" + +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + +#include + +namespace mllm::cpu::arm { + +void conv2d_fp32_im2col_input(const float* input_data, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, float* col_data) { + const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + const int channel_size = height * width; + + const float32x4_t vzero = vdupq_n_f32(0.0f); + + for (int channel = 0; channel < channels; ++channel) { + for (int kernel_y = 0; kernel_y < kernel_h; ++kernel_y) { + for (int kernel_x = 0; kernel_x < kernel_w; ++kernel_x) { + const int input_start_y = -pad_h + kernel_y * dilation_h; + const int input_start_x = -pad_w + kernel_x * dilation_w; + + for (int out_y = 0; out_y < output_h; ++out_y) { + const int cur_input_y = input_start_y + out_y * stride_h; + + if (static_cast(cur_input_y) >= static_cast(height)) { + for (int out_x = 0; out_x < output_w; out_x += 4) { + if (out_x + 3 < output_w) { + vst1q_f32(col_data, vzero); + col_data += 4; + } else { + for (int i = 0; i < output_w - out_x; ++i) { *col_data++ = 0.0f; } + } + } + } else { + int out_x = 0; + for (; out_x + 3 < output_w; out_x += 4) { + const int input_x0 = input_start_x + (out_x + 0) * stride_w; + const int input_x1 = input_start_x + (out_x + 1) * stride_w; + const int input_x2 = input_start_x + (out_x + 2) * stride_w; + const int input_x3 = input_start_x + (out_x + 3) * stride_w; + + const float val0 = (static_cast(input_x0) < static_cast(width)) + ? input_data[cur_input_y * width + input_x0] + : 0.0f; + const float val1 = (static_cast(input_x1) < static_cast(width)) + ? input_data[cur_input_y * width + input_x1] + : 0.0f; + const float val2 = (static_cast(input_x2) < static_cast(width)) + ? input_data[cur_input_y * width + input_x2] + : 0.0f; + const float val3 = (static_cast(input_x3) < static_cast(width)) + ? input_data[cur_input_y * width + input_x3] + : 0.0f; + + float32x4_t v_data = {val0, val1, val2, val3}; + vst1q_f32(col_data, v_data); + col_data += 4; + } + + for (; out_x < output_w; ++out_x) { + const int cur_input_x = input_start_x + out_x * stride_w; + if (static_cast(cur_input_x) < static_cast(width)) { + *col_data++ = input_data[cur_input_y * width + cur_input_x]; + } else { + *col_data++ = 0.0f; + } + } + } + } + } + } + input_data += channel_size; + } +} + +void conv2d_fp32_im2col_weight(const float* src_weight, float* packed_weight, int out_channels, int in_channels, int kernel_h, + int kernel_w) { + int M = out_channels; + int K = in_channels * kernel_h * kernel_w; + + for (int o = 0; o < out_channels; ++o) { + for (int i = 0; i < in_channels; ++i) { + for (int h = 0; h < kernel_h; ++h) { + for (int w = 0; w < kernel_w; ++w) { + int src_idx = h * (kernel_w * in_channels * out_channels) + w * (in_channels * out_channels) + i * (out_channels) + o; + int dst_idx = o * (in_channels * kernel_h * kernel_w) + i * (kernel_h * kernel_w) + h * (kernel_w) + w; + packed_weight[dst_idx] = src_weight[src_idx]; + } + } + } + } +} + +} // namespace mllm::cpu::arm + +#endif diff --git a/mllm/backends/cpu/kernels/arm/conv2d.hpp b/mllm/backends/cpu/kernels/arm/conv2d.hpp new file mode 100644 index 000000000..129e272fa --- /dev/null +++ b/mllm/backends/cpu/kernels/arm/conv2d.hpp @@ -0,0 +1,34 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/utils/CPUArchHelper.hpp" + +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + +#include + +namespace mllm::cpu::arm { + +//===----------------------------------------------------------------------===// +// Im2col. +// +// Reformat your inputs to im2col's input +// Reformat your weights to im2col's weight +// After those 2 parts, do gemm(weight, input) +//===----------------------------------------------------------------------===// +void conv2d_fp32_im2col_input(const float* input_data, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, float* col_data); + +// Inputs weight format should in [Out_Channels, In_Channels, Kernel_H, Kernel_W] +// Output weight format should in [M x K] +// +// +// This kernel is not performance sensitive !!! We only need to pack weight once ! +void conv2d_fp32_im2col_weight(const float* src_weight, float* packed_weight, int out_channels, int in_channels, int kernel_h, + int kernel_w); + +} // namespace mllm::cpu::arm +#endif diff --git a/mllm/backends/cpu/kernels/arm/layernorm2d.cpp b/mllm/backends/cpu/kernels/arm/layernorm2d.cpp new file mode 100644 index 000000000..be52d715e --- /dev/null +++ b/mllm/backends/cpu/kernels/arm/layernorm2d.cpp @@ -0,0 +1,89 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" + +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + +#include +#include + +namespace mllm::cpu::arm { +void layernorm2d_fp32(const float* x, const float* weight, const float* bias, float* y, int N, int C, int H, int W, float eps) { + const int spatial_dim = H * W; + + for (int n = 0; n < N; ++n) { + for (int i = 0; i < spatial_dim; ++i) { + const float* x_ptr = x + n * C * spatial_dim + i; + float* y_ptr = y + n * C * spatial_dim + i; + + float sum = 0.0f; +#if defined(__ARM_NEON) + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int c = 0; + for (; c <= C - 4; c += 4) { + float32x4_t x_vec = {x_ptr[c * spatial_dim], x_ptr[(c + 1) * spatial_dim], x_ptr[(c + 2) * spatial_dim], + x_ptr[(c + 3) * spatial_dim]}; + sum_vec = vaddq_f32(sum_vec, x_vec); + } + sum = vaddvq_f32(sum_vec); + for (; c < C; ++c) { sum += x_ptr[c * spatial_dim]; } +#else + for (int c = 0; c < C; ++c) { sum += x_ptr[c * spatial_dim]; } +#endif + const float mean = sum / C; + + float sq_sum = 0.0f; +#if defined(__ARM_NEON) + float32x4_t sq_sum_vec = vdupq_n_f32(0.0f); + float32x4_t mean_vec = vdupq_n_f32(mean); + c = 0; + for (; c <= C - 4; c += 4) { + float32x4_t x_vec = {x_ptr[c * spatial_dim], x_ptr[(c + 1) * spatial_dim], x_ptr[(c + 2) * spatial_dim], + x_ptr[(c + 3) * spatial_dim]}; + float32x4_t diff = vsubq_f32(x_vec, mean_vec); + sq_sum_vec = vmlaq_f32(sq_sum_vec, diff, diff); // Fused multiply-accumulate: sq_sum_vec += diff * diff + } + sq_sum = vaddvq_f32(sq_sum_vec); + for (; c < C; ++c) { + float diff = x_ptr[c * spatial_dim] - mean; + sq_sum += diff * diff; + } +#else + for (int c = 0; c < C; ++c) { + float diff = x_ptr[c * spatial_dim] - mean; + sq_sum += diff * diff; + } +#endif + const float variance = sq_sum / C; + const float inv_std = 1.0f / std::sqrt(variance + eps); + +#if defined(__ARM_NEON) + float32x4_t inv_std_vec = vdupq_n_f32(inv_std); + c = 0; + for (; c <= C - 4; c += 4) { + float32x4_t x_vec = {x_ptr[c * spatial_dim], x_ptr[(c + 1) * spatial_dim], x_ptr[(c + 2) * spatial_dim], + x_ptr[(c + 3) * spatial_dim]}; + float32x4_t weight_vec = vld1q_f32(weight + c); + float32x4_t bias_vec = vld1q_f32(bias + c); + + // y = (x - mean) * inv_std + float32x4_t norm_val = vmulq_f32(vsubq_f32(x_vec, mean_vec), inv_std_vec); + // y = y * weight + bias + float32x4_t out_vec = vmlaq_f32(bias_vec, norm_val, weight_vec); + + y_ptr[c * spatial_dim] = vgetq_lane_f32(out_vec, 0); + y_ptr[(c + 1) * spatial_dim] = vgetq_lane_f32(out_vec, 1); + y_ptr[(c + 2) * spatial_dim] = vgetq_lane_f32(out_vec, 2); + y_ptr[(c + 3) * spatial_dim] = vgetq_lane_f32(out_vec, 3); + } + for (; c < C; ++c) { y_ptr[c * spatial_dim] = (x_ptr[c * spatial_dim] - mean) * inv_std * weight[c] + bias[c]; } +#else + for (int c = 0; c < C; ++c) { y_ptr[c * spatial_dim] = (x_ptr[c * spatial_dim] - mean) * inv_std * weight[c] + bias[c]; } +#endif + } + } +} + +} // namespace mllm::cpu::arm + +#endif diff --git a/mllm/backends/cpu/kernels/arm/layernorm2d.hpp b/mllm/backends/cpu/kernels/arm/layernorm2d.hpp new file mode 100644 index 000000000..157829057 --- /dev/null +++ b/mllm/backends/cpu/kernels/arm/layernorm2d.hpp @@ -0,0 +1,18 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/utils/CPUArchHelper.hpp" + +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + +#include + +namespace mllm::cpu::arm { + +// For NCHW +void layernorm2d_fp32(const float* x, const float* weight, const float* bias, float* y, int N, int C, int H, int W, float eps); + +} // namespace mllm::cpu::arm +#endif diff --git a/mllm/backends/cpu/ops/Conv2DOp.cpp b/mllm/backends/cpu/ops/Conv2DOp.cpp new file mode 100644 index 000000000..aef3c83fa --- /dev/null +++ b/mllm/backends/cpu/ops/Conv2DOp.cpp @@ -0,0 +1,144 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include "mllm/core/aops/MatMulOp.hpp" +#include "mllm/backends/cpu/ops/Conv2DOp.hpp" +#include "mllm/backends/cpu/kernels/Kernels.hpp" + +namespace mllm::cpu { + +CPUConv2DOp::CPUConv2DOp(const aops::Conv2DOpOptions& options) : aops::Conv2DOp(options) {} + +void CPUConv2DOp::load(const ParameterFile::ptr_t& ploader) { + switch (ploader->version()) { + case ModelFileVersion::kV1: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + weight_ = weight_.view({ + options_.out_channels, + options_.in_channels, + options_.kernel_size[0], + options_.kernel_size[1], + }); + if (options_.bias) { bias_ = bias_.view({options_.out_channels}); } + break; + } + case ModelFileVersion::kUserTemporary: + case ModelFileVersion::kV2: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + break; + } + default: NYI("Unsupported model file version") + } + + auto& kernel_size = options_.kernel_size; + auto& stride = options_.stride; + auto& padding = options_.padding; + auto& dilation = options_.dilation; + + // Pack data + switch (options_.impl_type) { + case aops::Conv2DOpImplType::kDefault: { + // We will do im2col algorithm when using default impl. We will packing weight here. + MLLM_INFO("Packing Conv2D weight to im2col format. kh={}, kw={}, pw={}, ph={}, dw={}, dh={}, sw={}, sh={}", + kernel_size[0], kernel_size[1], padding[0], padding[1], dilation[0], dilation[1], stride[0], stride[1]); + auto packed_weight = Tensor::empty( + { + options_.out_channels, + options_.in_channels * kernel_size[0] * kernel_size[1], + + }, + weight_.dtype(), weight_.device()) + .alloc(); +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + arm::conv2d_fp32_im2col_weight(weight_.ptr(), packed_weight.ptr(), options_.out_channels, + options_.in_channels, kernel_size[0], kernel_size[1]); +#else + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported architecture for packing conv2d weight into im2col format."); +#endif + weight_ = packed_weight; + break; + } + default: { + NYI("Unsupported impl type") + } + } +} + +void CPUConv2DOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto& output = outputs[0]; + auto& kernel_size = options_.kernel_size; + auto& stride = options_.stride; + auto& padding = options_.padding; + auto& dilation = options_.dilation; + + switch (input.dtype()) { + case kFloat32: { + switch (options_.impl_type) { + case aops::Conv2DOpImplType::kDefault: { + // Weight is M x K (out_channels x (in_channels * kernel_h * kernel_w)) + // Input is K x N ((in_channels * kernel_h * kernel_w) x (out_h * out_w)) + // Output is M x N (out_channels x (out_h * out_w)) + + auto mt = aops::MatMulOpType::kDefault; + if (mt == aops::MatMulOpType::kDefault) { +#if defined(MLLM_USE_BLAS) + mt = aops::MatMulOpType::kBLAS; +#else + mt = aops::MatMulOpType::kMllmBlas; +#endif + } + int MATMUL_M = options_.out_channels; + int MATMUL_K = options_.in_channels * kernel_size[0] * kernel_size[1]; + int MATMUL_N = output.shape()[2] * output.shape()[3]; + +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + // step 1. im2col inputs to tmp + auto packed_inputs = Tensor::empty({MATMUL_K, MATMUL_N}, input.dtype(), input.device()).alloc(); + arm::conv2d_fp32_im2col_input(input.ptr(), options_.in_channels, input.shape()[1], input.shape()[2], + kernel_size[0], kernel_size[1], padding[0], padding[1], stride[0], stride[1], + dilation[0], dilation[1], packed_inputs.ptr()); + // step 2. Do matmul + switch (mt) { // NOLINT + case aops::MatMulOpType::kBLAS: { +#if defined(MLLM_USE_BLAS) + blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), output.ptr(), + options_.bias ? bias_.ptr() : nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); +#else + NYI("BLAS not supported. Pls set MLLM_USE_BLAS=ON to enable BLAS supports in cmake."); +#endif + break; + } + case aops::MatMulOpType::kMllmBlas: { + auto thread_count = options_.getThreads(); +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + arm::mllm_blas_matmul_fp32(MATMUL_M, MATMUL_K, MATMUL_N, output.ptr(), weight_.ptr(), + packed_inputs.ptr(), options_.bias ? bias_.ptr() : nullptr, + false, false, thread_count); +#else + NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") +#endif + break; + } + } +#else + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported architecture for perform im2col conv2d."); +#endif + break; + } + default: { + NYI("Unsupported impl type"); + } + } + break; + } + default: { + NYI("Unsupported data type"); + } + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/Conv2DOp.hpp b/mllm/backends/cpu/ops/Conv2DOp.hpp new file mode 100644 index 000000000..b005391cc --- /dev/null +++ b/mllm/backends/cpu/ops/Conv2DOp.hpp @@ -0,0 +1,27 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/Conv2DOp.hpp" + +namespace mllm::cpu { + +class CPUConv2DOp final : public aops::Conv2DOp { + public: + explicit CPUConv2DOp(const aops::Conv2DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUConv2DOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::Conv2DOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/LayerNorm2DOp.cpp b/mllm/backends/cpu/ops/LayerNorm2DOp.cpp new file mode 100644 index 000000000..6d9e9ebb9 --- /dev/null +++ b/mllm/backends/cpu/ops/LayerNorm2DOp.cpp @@ -0,0 +1,38 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp" +#include "mllm/backends/cpu/kernels/Kernels.hpp" + +namespace mllm::cpu { + +CPULayerNorm2DOp::CPULayerNorm2DOp(const aops::LayerNorm2DOpOptions& options) : aops::LayerNorm2DOp(options) {} + +void CPULayerNorm2DOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& i = inputs[0]; + auto& o = outputs[0]; + + auto i_shape = i.shape(); + auto N = i_shape[0]; + auto C = i_shape[1]; + auto H = i_shape[2]; + auto W = i_shape[3]; + + switch (i.dtype()) { + case kFloat32: { +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("Not impl for x86 64"); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + arm::layernorm2d_fp32(i.ptr(), weight_.ptr(), bias_.ptr(), o.ptr(), N, + C, H, W, options_.eps); +#endif + break; + } + default: { + NYI("Not support data type"); + } + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/LayerNorm2DOp.hpp b/mllm/backends/cpu/ops/LayerNorm2DOp.hpp new file mode 100644 index 000000000..b6d9a9171 --- /dev/null +++ b/mllm/backends/cpu/ops/LayerNorm2DOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/LayerNorm2DOp.hpp" + +namespace mllm::cpu { + +class CPULayerNorm2DOp final : public aops::LayerNorm2DOp { + public: + explicit CPULayerNorm2DOp(const aops::LayerNorm2DOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPULayerNorm2DOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::LayerNorm2DOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 0ec14089a..10457d0a2 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-09-09 15:27:37 +// Auto generated: 2025-10-22 14:54:07 // do not modify this file #pragma once @@ -71,6 +71,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_SinOp, RK_Op_LinalgIROp_CosOp, RK_Op_LinalgIROp_PagedAttnOp, + RK_Op_LinalgIROp_LayerNorm2DOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, RK_Op_GraphIROp_SubGraphOp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 2e3a4e85a..61e49abfe 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-09-09 15:27:37 +// Auto generated: 2025-10-22 14:54:07 // do not modify this file #pragma once namespace mllm::ir { @@ -183,6 +183,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_PAGEDATTNOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_PagedAttnOp && (v)->getKind() <= RK_Op_LinalgIROp_PagedAttnOp +#define RTTI_RK_OP_LINALGIROP_LAYERNORM2DOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_LayerNorm2DOp && (v)->getKind() <= RK_Op_LinalgIROp_LayerNorm2DOp + #define RTTI_RK_OP_GRAPHIROP_IMPL(v) return (v)->getKind() >= RK_Op_GraphIROp && (v)->getKind() <= RK_Op_GraphIROp_Last #define RTTI_RK_OP_GRAPHIROP_SUBGRAPHOP_IMPL(v) \ diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index 9f550341b..4653a6763 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -101,4 +101,6 @@ LINALG_AOPS_DECL(OpTypes::kMean, MeanOp); LINALG_AOPS_DECL(OpTypes::kClip, ClipOp); LINALG_AOPS_DECL(OpTypes::kPagedAttn, PagedAttnOp); +LINALG_AOPS_DECL(OpTypes::kLayerNorm2D, LayerNorm2DOp); + } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index e9a33641c..2fec76940 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -64,6 +64,7 @@ class ExpOp; class SinOp; class CosOp; class PagedAttnOp; +class LayerNorm2DOp; } // namespace mllm #define LINALG_AOPS_DEFINE(class_name, rtti_name) \ @@ -213,4 +214,6 @@ LINALG_AOPS_DEFINE(MeanOp, MEANOP); LINALG_AOPS_DEFINE(ClipOp, CLIPOP); LINALG_AOPS_DEFINE(PagedAttnOp, PAGEDATTNOP); +LINALG_AOPS_DEFINE(LayerNorm2DOp, LAYERNORM2DOP); + } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index 3eeaf5ab3..f5d0e670d 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -272,6 +272,7 @@ def define_lianlg_ir(ir: dict): op.derive(Cls("SinOp")) op.derive(Cls("CosOp")) op.derive(Cls("PagedAttnOp")) + op.derive(Cls("LayerNorm2DOp")) # value diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 4adc215e5..26bd18eda 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -75,6 +75,7 @@ enum class OpTypes : int32_t { kPagedAttn = 57, kRadixAttn = 58, kScatter2Shards = 59, + kLayerNorm2D = 60, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -143,6 +144,7 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kGraphEnd: return "GraphEnd"; case OpTypes::kPagedAttn: return "PagedAttn"; case OpTypes::kScatter2Shards: return "Scatter2Shards"; + case OpTypes::kLayerNorm2D: return "LayerNorm2D"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; } diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index 73f80bf23..192b527ce 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -345,6 +345,12 @@ bool Tensor::isContiguous() const { return impl()->isContiguous(); } bool Tensor::isContiguousN(int n) const { return impl()->isContiguousN(n); } +int32_t Tensor::size(int32_t id) const { + auto nid = id; + if (id < 0) { nid = rank() + id; } + return shape()[nid]; +} + Tensor Tensor::contiguous() { return Context::instance().buildOpAndSubmitTask(OpTypes::kContiguous, aops::ContiguousOpOptions{}, {*this})[0]; } @@ -404,6 +410,32 @@ Tensor Tensor::squeeze(int32_t dim) { } } +Tensor Tensor::flatten(int32_t dim) { + const auto old_shape = shape(); + const int32_t ndim = static_cast(old_shape.size()); + + if (dim == 0x7fffffff) { + int32_t total = 1; + for (auto s : old_shape) total *= s; + return view({total}); + } + + if (ndim == 0) return view({1}); + + if (dim < 0) dim += ndim; + if (dim < 0 || dim >= ndim) throw std::out_of_range("flatten dim out of range"); + + std::vector new_shape; + new_shape.reserve(dim + 1); + + for (int32_t i = 0; i < dim; ++i) new_shape.push_back(old_shape[i]); + + int32_t flatten_size = 1; + for (int32_t i = dim; i < ndim; ++i) flatten_size *= old_shape[i]; + new_shape.push_back(flatten_size); + + return view(new_shape); +} Tensor Tensor::clone() { return Context::instance().buildOpAndSubmitTask(OpTypes::kClone, aops::CloneOpOptions{}, {*this})[0]; } void Tensor::copy2(const Tensor& src) { diff --git a/mllm/core/Tensor.hpp b/mllm/core/Tensor.hpp index f042bc558..4ef1d8a15 100644 --- a/mllm/core/Tensor.hpp +++ b/mllm/core/Tensor.hpp @@ -476,6 +476,14 @@ class Tensor { */ [[nodiscard]] bool isContiguousN(int n) const; + /** + * @brief + * + * @param id + * @return int32_t + */ + [[nodiscard]] int32_t size(int32_t id) const; + /** * @brief Creates contiguous copy if non-contiguous. * @return Contiguous tensor (may be a view or copy). @@ -522,6 +530,14 @@ class Tensor { */ Tensor squeeze(int32_t dim = 0x7fffffff); + /** + * @brief + * + * @param dim + * @return Tensor + */ + Tensor flatten(int32_t dim = 0x7fffffff); + /** * @brief clone a tensor * diff --git a/mllm/core/aops/Conv2DOp.cpp b/mllm/core/aops/Conv2DOp.cpp new file mode 100644 index 000000000..645249ad2 --- /dev/null +++ b/mllm/core/aops/Conv2DOp.cpp @@ -0,0 +1,113 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/Conv2DOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/tensor/Op.hpp" + +namespace mllm::aops { + +Conv2DOp::Conv2DOp(const Conv2DOpOptions& options) : BaseOp(OpTypes::kConv2D), options_(options) {} + +void Conv2DOp::load(const ParameterFile::ptr_t& ploader) { + switch (ploader->version()) { + case ModelFileVersion::kV1: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + weight_ = weight_.view({ + options_.out_channels, + options_.in_channels, + options_.kernel_size[0], + options_.kernel_size[1], + }); + if (options_.bias) { bias_ = bias_.view({options_.out_channels}); } + break; + } + case ModelFileVersion::kUserTemporary: + case ModelFileVersion::kV2: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + break; + } + default: NYI("Unsupported model file version") + } +} + +void Conv2DOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + + // Register Params + if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { + ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir_ctx->create(ir_ctx->create(weight_)); + if (options_.bias) { ir_ctx->create(ir_ctx->create(bias_)); } + } + + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + auto _op = ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void Conv2DOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("Conv2DOp::forward not implemented in aops base."); +} + +void Conv2DOp::reshape(const std::vector& inputs, std::vector& outputs) { + const auto& i = inputs[0]; + const auto& ishape = i.shape(); + + // Input must be 4D: [batch, channels, height, width] + if (ishape.size() != 4) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Conv2DOp expects 4D input, got {} D", ishape.size()); + outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device())); + return; + } + + const int batch = ishape[0]; + const int in_channels = ishape[1]; // channel axis + const int in_height = ishape[2]; // height axis + const int in_width = ishape[3]; // width axis + + // Current only support single batch + MLLM_RT_ASSERT_EQ(batch, 1); + + MLLM_RT_ASSERT_EQ(in_channels, options_.in_channels); + + // Retrieve convolution parameters from options_ + // For Conv2D, kernel_size should be [kh, kw] + const auto& kernel = options_.kernel_size; + const auto& stride = options_.stride; // [sh, sw] + const auto& padding = options_.padding; // [ph, pw] if available + const auto& dilation = options_.dilation; // [dh, dw] if available + const int out_channels = options_.out_channels; + + // Output shape calculation for Conv2D + auto out_shape = [](int dim_size, int kernel_size, int stride_size, int padding_size, int dilation_size) -> int32_t { + const int dilated_kernel_size = dilation_size * (kernel_size - 1) + 1; + return ((dim_size + 2 * padding_size - dilated_kernel_size) / stride_size) + 1; + }; + + // Calculate output height and width + auto h_out = out_shape(in_height, kernel[0], stride[0], padding[0], dilation[0]); + auto w_out = out_shape(in_width, kernel[1], stride[1], padding[1], dilation[1]); + + // Output shape: [batch, out_channels, h_out, w_out] + auto new_shape = std::vector{batch, out_channels, h_out, w_out}; + + outputs.emplace_back(Tensor::empty(new_shape, i.dtype(), i.device())); +} + +void Conv2DOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +ParameterFile::ptr_t Conv2DOp::getParams() { + auto p = ParameterFile::create(); + p->push(getName() + ".weight", weight_); + if (options_.bias) { p->push(getName() + ".bias", bias_); } + return p; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/Conv2DOp.hpp b/mllm/core/aops/Conv2DOp.hpp new file mode 100644 index 000000000..0904e75bd --- /dev/null +++ b/mllm/core/aops/Conv2DOp.hpp @@ -0,0 +1,72 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +enum class Conv2DOpImplType { + kDefault = 0, +}; + +struct Conv2DOpOptions : public BaseOpOptions { + int32_t in_channels; + int32_t out_channels; + std::vector kernel_size; + std::vector stride; + std::vector padding; + std::vector dilation; + bool bias = true; + Conv2DOpImplType impl_type = Conv2DOpImplType::kDefault; +}; + +inline Conv2DOpImplType str2Conv2DOpImplType(const std::string& str) { + static const std::unordered_map map = {{"Default", Conv2DOpImplType::kDefault}}; + + auto it = map.find(str); + if (it != map.end()) { return it->second; } + + // Return default if not found + return Conv2DOpImplType::kDefault; +} + +inline std::string Conv2DOpImplType2Str(Conv2DOpImplType type) { + static const std::unordered_map map = {{Conv2DOpImplType::kDefault, "Default"}}; + + auto it = map.find(type); + if (it != map.end()) return it->second; + return "Default"; +} + +class Conv2DOp : public BaseOp { + public: + explicit Conv2DOp(const Conv2DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + ParameterFile::ptr_t getParams() override; + + inline Tensor& weight() { return weight_; } + + inline Tensor& bias() { return bias_; } + + inline Conv2DOpOptions& options() { return options_; } + + protected: + Tensor weight_; + Tensor bias_; + Conv2DOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/LayerNorm2DOp.cpp b/mllm/core/aops/LayerNorm2DOp.cpp new file mode 100644 index 000000000..28997dda7 --- /dev/null +++ b/mllm/core/aops/LayerNorm2DOp.cpp @@ -0,0 +1,56 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/LayerNorm2DOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/tensor/Op.hpp" + +namespace mllm::aops { + +LayerNorm2DOp::LayerNorm2DOp(const LayerNorm2DOpOptions& options) : BaseOp(OpTypes::kLayerNorm2D), options_(options) {} + +void LayerNorm2DOp::load(const ParameterFile::ptr_t& ploader) { + weight_ = ploader->pull(getName() + ".weight"); + weight_ = weight_.view({options_.num_channels}); + bias_ = ploader->pull(getName() + ".bias"); + bias_ = bias_.view({options_.num_channels}); +} + +void LayerNorm2DOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + + // Register Params + if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { + ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir_ctx->create(ir_ctx->create(weight_)); + ir_ctx->create(ir_ctx->create(bias_)); + } + + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void LayerNorm2DOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("LayerNorm2DOp::forward not implemented in aops base."); +} + +void LayerNorm2DOp::reshape(const std::vector& inputs, std::vector& outputs) { + const auto& i = inputs[0]; + outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device())); +} + +void LayerNorm2DOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +ParameterFile::ptr_t LayerNorm2DOp::getParams() { + auto p = ParameterFile::create(); + p->push(getName() + ".weight", weight_); + p->push(getName() + ".bias", bias_); + return p; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/LayerNorm2DOp.hpp b/mllm/core/aops/LayerNorm2DOp.hpp new file mode 100644 index 000000000..9ff05d134 --- /dev/null +++ b/mllm/core/aops/LayerNorm2DOp.hpp @@ -0,0 +1,44 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct LayerNorm2DOpOptions : public BaseOpOptions { + int32_t num_channels; + float eps = 1e-6; +}; + +class LayerNorm2DOp : public BaseOp { + public: + explicit LayerNorm2DOp(const LayerNorm2DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + ParameterFile::ptr_t getParams() override; + + inline Tensor& weight() { return weight_; } + + inline Tensor& bias() { return bias_; } + + inline LayerNorm2DOpOptions& options() { return options_; } + + protected: + Tensor weight_; + Tensor bias_; + LayerNorm2DOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/models/deepseek_ocr/README.md b/mllm/models/deepseek_ocr/README.md new file mode 100644 index 000000000..802ed47c3 --- /dev/null +++ b/mllm/models/deepseek_ocr/README.md @@ -0,0 +1,55 @@ +--- +pipeline_tag: image-text-to-text +language: +- multilingual +tags: +- deepseek +- vision-language +- ocr +- custom_code +license: mit +--- +
+ DeepSeek AI +
+
+ + +
+ + + Discord + + + Twitter Follow + + +
+ + + +

+ 🌟 Github | + 📥 Model Download | + 📄 Paper Link | + 📄 Arxiv Paper Link | +

+

+

+ DeepSeek-OCR: Contexts Optical Compression +

+

+

+ +

+

+Explore the boundaries of visual-text compression. +

diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp new file mode 100644 index 000000000..f13946025 --- /dev/null +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -0,0 +1,187 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" +#include + +namespace mllm::models::deepseek_ocr { + +struct DpskOcrConfig : protected ConfigFile { + DpskOcrConfig() = default; + + explicit DpskOcrConfig(const std::string& file_path) : ConfigFile(file_path) { + // Init all + _name_or_path = data()["_name_or_path"]; + + // Parse candidate_resolutions + if (data().contains("candidate_resolutions") && data()["candidate_resolutions"].is_array()) { + for (const auto& res : data()["candidate_resolutions"]) { + if (res.is_array() && res.size() == 2) { candidate_resolutions.push_back({res[0], res[1]}); } + } + } + + global_view_pos = data()["global_view_pos"]; + model_type = data()["model_type"]; + tile_tag = data()["tile_tag"]; + transformers_version = data()["transformers_version"]; + + // Language config + language_config.bos_token_id = data()["language_config"]["bos_token_id"]; + language_config.eos_token_id = data()["language_config"]["eos_token_id"]; + language_config.first_k_dense_replace = data()["language_config"]["first_k_dense_replace"]; + language_config.hidden_size = data()["language_config"]["hidden_size"]; + language_config.intermediate_size = data()["language_config"]["intermediate_size"]; + language_config.kv_lora_rank = data()["language_config"]["kv_lora_rank"].is_null() + ? -1 + : static_cast(data()["language_config"]["kv_lora_rank"]); + language_config.lm_head = data()["language_config"]["lm_head"]; + language_config.max_position_embeddings = data()["language_config"]["max_position_embeddings"]; + language_config.moe_intermediate_size = data()["language_config"]["moe_intermediate_size"]; + language_config.n_group = data()["language_config"]["n_group"]; + language_config.n_routed_experts = data()["language_config"]["n_routed_experts"]; + language_config.n_shared_experts = data()["language_config"]["n_shared_experts"]; + language_config.num_attention_heads = data()["language_config"]["num_attention_heads"]; + language_config.num_experts_per_tok = data()["language_config"]["num_experts_per_tok"]; + language_config.num_hidden_layers = data()["language_config"]["num_hidden_layers"]; + language_config.num_key_value_heads = data()["language_config"]["num_key_value_heads"]; + language_config.q_lora_rank = data()["language_config"]["q_lora_rank"].is_null() + ? -1 + : static_cast(data()["language_config"]["q_lora_rank"]); + language_config.qk_nope_head_dim = data()["language_config"]["qk_nope_head_dim"]; + language_config.qk_rope_head_dim = data()["language_config"]["qk_rope_head_dim"]; + language_config.rm_head = data()["language_config"]["rm_head"]; + language_config.topk_group = data()["language_config"]["topk_group"]; + language_config.topk_method = data()["language_config"]["topk_method"]; + language_config.use_mla = data()["language_config"]["use_mla"]; + language_config.v_head_dim = data()["language_config"]["v_head_dim"]; + language_config.vocab_size = data()["language_config"]["vocab_size"]; + + // Projector config + projector_config.input_dim = data()["projector_config"]["input_dim"]; + projector_config.model_type = data()["projector_config"]["model_type"]; + projector_config.n_embed = data()["projector_config"]["n_embed"]; + projector_config.projector_type = data()["projector_config"]["projector_type"]; + + // Vision config + vision_config.image_size = data()["vision_config"]["image_size"]; + vision_config.mlp_ratio = data()["vision_config"]["mlp_ratio"]; + vision_config.model_name = data()["vision_config"]["model_name"]; + vision_config.model_type = data()["vision_config"]["model_type"]; + + // Main config values + bos_token_id = data()["bos_token_id"]; + eos_token_id = data()["eos_token_id"]; + first_k_dense_replace = data()["first_k_dense_replace"]; + hidden_size = data()["hidden_size"]; + intermediate_size = data()["intermediate_size"]; + kv_lora_rank = data()["kv_lora_rank"].is_null() ? -1 : static_cast(data()["kv_lora_rank"]); + lm_head = data()["lm_head"]; + max_position_embeddings = data()["max_position_embeddings"]; + moe_intermediate_size = data()["moe_intermediate_size"]; + n_group = data()["n_group"]; + n_routed_experts = data()["n_routed_experts"]; + n_shared_experts = data()["n_shared_experts"]; + num_attention_heads = data()["num_attention_heads"]; + num_experts_per_tok = data()["num_experts_per_tok"]; + num_hidden_layers = data()["num_hidden_layers"]; + num_key_value_heads = data()["num_key_value_heads"]; + q_lora_rank = data()["q_lora_rank"].is_null() ? -1 : static_cast(data()["q_lora_rank"]); + qk_nope_head_dim = data()["qk_nope_head_dim"]; + qk_rope_head_dim = data()["qk_rope_head_dim"]; + rm_head = data()["rm_head"]; + topk_group = data()["topk_group"]; + topk_method = data()["topk_method"]; + use_mla = data()["use_mla"]; + v_head_dim = data()["v_head_dim"]; + vocab_size = data()["vocab_size"]; + } + + // Nested structs for complex configuration + struct LanguageConfig { + int64_t bos_token_id = 0; + int64_t eos_token_id = 1; + int32_t first_k_dense_replace = 1; + int32_t hidden_size = 1280; + int32_t intermediate_size = 6848; + int32_t kv_lora_rank = -1; // null in JSON + bool lm_head = true; + int32_t max_position_embeddings = 8192; + int32_t moe_intermediate_size = 896; + int32_t n_group = 1; + int32_t n_routed_experts = 64; + int32_t n_shared_experts = 2; + int32_t num_attention_heads = 10; + int32_t num_experts_per_tok = 6; + int32_t num_hidden_layers = 12; + int32_t num_key_value_heads = 10; + int32_t q_lora_rank = -1; // null in JSON + int32_t qk_nope_head_dim = 0; + int32_t qk_rope_head_dim = 0; + bool rm_head = false; + int32_t topk_group = 1; + std::string topk_method = "greedy"; + bool use_mla = false; + int32_t v_head_dim = 0; + int32_t vocab_size = 129280; + }; + + struct ProjectorConfig { + int32_t input_dim = 2048; + std::string model_type = "mlp_projector"; + int32_t n_embed = 1280; + std::string projector_type = "linear"; + }; + + struct VisionConfig { + int32_t image_size = 1024; + float mlp_ratio = 3.7362; + std::string model_name = "deeplip_b_l"; + std::string model_type = "vision"; + }; + + std::string _name_or_path = "deepseek-ai/DeepSeek-OCR"; + std::vector> candidate_resolutions = {{1024, 1024}}; + std::string global_view_pos = "head"; + std::string model_type = "deepseek_vl_v2"; + std::string tile_tag = "2D"; + std::string transformers_version = "4.46.3"; + + LanguageConfig language_config; + ProjectorConfig projector_config; + VisionConfig vision_config; + + // Main config values + int64_t bos_token_id = 0; + int64_t eos_token_id = 1; + int32_t first_k_dense_replace = 1; + int32_t hidden_size = 1280; + int32_t intermediate_size = 6848; + int32_t kv_lora_rank = -1; // null in JSON + bool lm_head = true; + int32_t max_position_embeddings = 8192; + int32_t moe_intermediate_size = 896; + int32_t n_group = 1; + int32_t n_routed_experts = 64; + int32_t n_shared_experts = 2; + int32_t num_attention_heads = 10; + int32_t num_experts_per_tok = 6; + int32_t num_hidden_layers = 12; + int32_t num_key_value_heads = 10; + int32_t q_lora_rank = -1; // null in JSON + int32_t qk_nope_head_dim = 0; + int32_t qk_rope_head_dim = 0; + bool rm_head = false; + int32_t topk_group = 1; + std::string topk_method = "greedy"; + bool use_mla = false; + int32_t v_head_dim = 0; + int32_t vocab_size = 129280; + + // MLLM Related Stuff + aops::LinearImplTypes clip_linear_impl_type; + aops::LinearImplTypes sam_linear_impl_type; +}; + +} // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/conversation.hpp b/mllm/models/deepseek_ocr/conversation.hpp new file mode 100644 index 000000000..8a948dbc3 --- /dev/null +++ b/mllm/models/deepseek_ocr/conversation.hpp @@ -0,0 +1,2 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp new file mode 100644 index 000000000..12baca5b3 --- /dev/null +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -0,0 +1,732 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/layers/Param.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp" + +namespace mllm::models::deepseek_ocr { + +//===----------------------------------------------------------------------===// +// CLIP +// +// CLIP params is hard coded. Just like what deepseek official model does. +// +// vit_model_cfg = adict( +// num_layers=24, +// hidden_size=1024, +// num_heads = 16, +// num_attention_heads=16, +// ffn_hidden_size=4096, +// seq_length=256, +// max_position_embeddings=256, +// use_flash_attn=False, +// understand_projector_stride=2, +// hidden_dropout = 0.0, +// attention_dropout = 0.0, +// no_persist_layer_norm = False, +// layernorm_epsilon = 1e-5, +// pre_layernorm_epsilon = 1e-5, +// image_size = 224, +// patch_size = 14, +// recompute_list = [] +// ) + +// def build_clip_l(): +// return VitModel( +// cfg=vit_model_cfg, +// freeze_embed=False, +// freeze_pre_norm=False, +// ) +//===----------------------------------------------------------------------===// +class CLIPVisionEmbeddings final : public nn::Module { + int embed_dim_; + int image_size_; + int patch_size_; + nn::Param class_embedding_; + nn::Conv2D patch_embedding_; + int num_patches_; + int num_positions_; + nn::Embedding position_embedding_; + + public: + CLIPVisionEmbeddings() = default; + + CLIPVisionEmbeddings(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + embed_dim_ = 1024; + image_size_ = 224; + patch_size_ = 14; + num_patches_ = (image_size_ / patch_size_) * (image_size_ / patch_size_); + num_positions_ = num_patches_ + 1; + + // [embed_dim], aka [1024] + class_embedding_ = reg("class_embedding"); + patch_embedding_ = reg("patch_embedding", 3, embed_dim_, Tensor::shape_t{14, 14}, Tensor::shape_t{14, 14}, + Tensor::shape_t{0, 0}, Tensor::shape_t{1, 1}, false); + position_embedding_ = reg("position_embedding", num_positions_, embed_dim_); + + // Register a buffer + registerBuffer("position_ids", Tensor::arange(0, num_positions_, 1, kInt64, kCPU)); + } + + Tensor getAbsPos(Tensor abs_pos, int32_t tgt_size) { + // abs_pos : L, C + // tgt_size : M + // return : M, C + + auto dim = abs_pos.size(-1); + auto abs_pos_new = abs_pos.squeeze(0); + auto cls_token = abs_pos[{{kAll, 1}, kAll}].contiguous(); + auto old_pos_embed = abs_pos[{{1, kAll}, kAll}].contiguous(); + + auto src_size = int(std::sqrt(abs_pos_new.shape()[0] - 1)); + tgt_size = int(std::sqrt(tgt_size)); + auto dtype = abs_pos.dtype(); + + if (src_size != tgt_size) { + old_pos_embed = old_pos_embed.view({1, src_size, src_size, dim}).permute({0, 3, 1, 2}); + old_pos_embed = old_pos_embed.to(kFloat32); + + auto new_pos_embed = Tensor::empty({tgt_size, tgt_size}, kFloat32, kCPU).alloc(); + + // F.interpolate here. + { + const int channels = old_pos_embed.shape()[1]; + + auto old_pos_embed_ptr = old_pos_embed.ptr(); + auto new_pos_embed_ptr = new_pos_embed.ptr(); + + auto cubic_kernel = [](float x) -> float { + constexpr float a = -0.5f; + x = std::abs(x); + if (x < 1.0f) { + return (a + 2.0f) * x * x * x - (a + 3.0f) * x * x + 1.0f; + } else if (x < 2.0f) { + return a * x * x * x - 5.0f * a * x * x + 8.0f * a * x - 4.0f * a; + } else { + return 0.0f; + } + }; + + const float scale_y = static_cast(src_size) / tgt_size; + const float scale_x = static_cast(src_size) / tgt_size; + + for (int c = 0; c < channels; ++c) { + const float* src_channel_ptr = old_pos_embed_ptr + c * src_size * src_size; + float* dst_channel_ptr = new_pos_embed_ptr + c * tgt_size * tgt_size; + + for (int j = 0; j < tgt_size; ++j) { + for (int i = 0; i < tgt_size; ++i) { + float src_y = (static_cast(j) + 0.5f) * scale_y - 0.5f; + float src_x = (static_cast(i) + 0.5f) * scale_x - 0.5f; + + int y0 = static_cast(std::floor(src_y)) - 1; + int x0 = static_cast(std::floor(src_x)) - 1; + + float total_weight = 0.0f; + + for (int m = 0; m < 4; ++m) { + for (int n = 0; n < 4; ++n) { + int cur_y = y0 + m; + int cur_x = x0 + n; + cur_y = std::max(0, std::min(src_size - 1, cur_y)); + cur_x = std::max(0, std::min(src_size - 1, cur_x)); + float weight_y = cubic_kernel(src_y - (y0 + m)); + float weight_x = cubic_kernel(src_x - (x0 + n)); + total_weight += src_channel_ptr[cur_y * src_size + cur_x] * weight_y * weight_x; + } + } + dst_channel_ptr[j * tgt_size + i] = total_weight; + } + } + } + } + new_pos_embed = new_pos_embed.permute({0, 2, 3, 1}); + new_pos_embed = new_pos_embed.view({tgt_size * tgt_size, dim}); + auto vision_pos_embed = nn::functional::concat({cls_token, new_pos_embed}, 0); + vision_pos_embed = vision_pos_embed.view({1, tgt_size * tgt_size + 1, dim}); + return vision_pos_embed; + } else { + return abs_pos; + } + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto pixel_values = Tensor::nil(); + auto patch_embeds = Tensor::nil(); + + auto batch_size = pixel_values.shape()[0]; + + if (inputs.size() == 1) { + pixel_values = inputs[0]; + } else if (inputs.size() == 2) { + pixel_values = inputs[0]; + patch_embeds = inputs[1]; + } + + if (!patch_embeds) { patch_embeds = patch_embedding_(pixel_values); } + + // Flatten and transpose. + // patch_embeds original shape is [batch(1), out_channel, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2); // [batch(1), width * grid * grid, out_channel] + + // Assume batch is always 1 + MLLM_RT_ASSERT_EQ(batch_size, 1); + // [batch(1), 1, 1024] + auto class_embeds = class_embedding_.weight().view({1, 1, 1024}); + + auto embeddings = nn::functional::concat({class_embeds, patch_embeds}, 1); + embeddings = embeddings + getAbsPos(position_embedding_(getBuffer("position_ids")), embeddings.size(1)); + + return {embeddings}; + } +}; + +class NoTPFeedForward final : public nn::Module { + nn::Linear fc1_; + nn::Linear fc2_; + nn::QuickGELU act_; + + public: + NoTPFeedForward() = default; + + NoTPFeedForward(const std::string& name, int32_t dim, int32_t hidden_dim, const DpskOcrConfig& config) : nn::Module(name) { + fc1_ = reg("fc1", dim, hidden_dim, true, config.clip_linear_impl_type); + fc2_ = reg("fc2", hidden_dim, dim, true, config.clip_linear_impl_type); + act_ = reg("act"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {fc2_(act_(fc1_(inputs[0])))}; + } +}; + +class NoTPAttention final : public nn::Module { + int num_heads_; + int n_local_heads_; + int head_dim_; + int max_seq_len_; + nn::Linear qkv_proj_; + nn::Linear out_proj_; + + public: + NoTPAttention() = default; + + NoTPAttention(const std::string& name, const DpskOcrConfig& config) { + num_heads_ = 16; + n_local_heads_ = 16; + head_dim_ = 1024 / 16; + max_seq_len_ = 256; + + qkv_proj_ = reg("qkv_proj", 1024, 1024 * 3, true, config.clip_linear_impl_type); + out_proj_ = reg("out_proj", 1024, 1024, true, config.clip_linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // TODO + auto& x = inputs[0]; + auto bsz = x.size(0); + auto seqlen = x.size(1); + + auto xqkv = qkv_proj_(x); + xqkv = xqkv.view({bsz, seqlen, 3, num_heads_, head_dim_}); + auto [xq, xk, xv] = nn::functional::split<3>(xqkv, 2); + + // TODO need squeeze ? + + // TODO permute ? + + // TODO FA without mask + + // TODO outproj. + + return {}; + } +}; + +class NoTPTransformerBlock final : public nn::Module { + int n_heads_; + int dim_; + int head_dim_; + NoTPAttention self_attn_; + NoTPFeedForward mlp_; + nn::LayerNorm layer_norm1_; + nn::LayerNorm layer_norm2_; + + public: + int layer_id_; + + NoTPTransformerBlock() = default; + + NoTPTransformerBlock(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + n_heads_ = 16; + dim_ = 1024; + head_dim_ = 1024 / 16; + self_attn_ = reg("self_attn", config); + mlp_ = reg("mlp", 1024, 4096, config); + layer_norm1_ = reg("layer_norm1", Tensor::shape_t{1024}, true, true, 1e-5); + layer_norm2_ = reg("layer_norm2", Tensor::shape_t{1024}, true, true, 1e-5); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto residual = self_attn_(layer_norm1_(x))[0]; + auto h = x + residual; + auto out = h + mlp_(layer_norm2_(h))[0]; + return {out}; + } +}; + +class NoTPTransformer final : public nn::Module { + int num_layers_; + nn::ModuleList layers_; + + public: + NoTPTransformer() = default; + + NoTPTransformer(const std::string& name, const DpskOcrConfig& config) { + num_layers_ = 24; + layers_ = reg>("layers", num_layers_, config); + for (auto [idx, layer] : enumerate(layers_.list())) { layer.layer_id_ = idx; } + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + for (auto [idx, layer] : enumerate(layers_.list())) { hidden_states = layer(hidden_states)[0]; } + + return {hidden_states}; + } +}; + +class VitModel final : public nn::Module { + CLIPVisionEmbeddings embeddings_; + NoTPTransformer transformer_; + nn::LayerNorm pre_layernorm_; ///< input must in fp32 dtype. + + public: + VitModel() = default; + + VitModel(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + embeddings_ = reg("embeddings", config); + transformer_ = reg("transformer", config); + + // NOTE: + // Yes!!!, Its pre_layrnorm! Deepseek Typo!. + pre_layernorm_ = reg("pre_layrnorm", Tensor::shape_t{1024}, true, true, 1e-5); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto patch_embeds = inputs[1]; + + auto output = embeddings_(x, patch_embeds)[0]; + output = pre_layernorm_(output); + output = transformer_(output)[0]; + return {output}; + } +}; + +//===----------------------------------------------------------------------===// +// SAM +//===----------------------------------------------------------------------===// +class PatchEmbed final : public nn::Module { + nn::Conv2D proj_; + + public: + PatchEmbed() = default; + + PatchEmbed(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + proj_ = reg("proj", 3, 768, Tensor::shape_t{16, 16}, Tensor::shape_t{16, 16}, Tensor::shape_t{0, 0}, + Tensor::shape_t{1, 1}, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + // B C H W -> B H W C + return {proj_(x).permute({0, 2, 3, 1})}; + } +}; + +class MLPBlock final : public nn::Module { + nn::Linear lin1_; + nn::Linear lin2_; + nn::GELU act_; + + public: + MLPBlock() = default; + + MLPBlock(const std::string& name, int embedding_dim, int mlp_dim, const DpskOcrConfig& config) : nn::Module(name) { + lin1_ = reg("lin1", embedding_dim, mlp_dim, true, config.sam_linear_impl_type); + lin2_ = reg("lin2", mlp_dim, embedding_dim, true, config.sam_linear_impl_type); + act_ = reg("act"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {lin2_(act_(lin1_(inputs[0])))}; + } +}; + +class Attention final : public nn::Module { + int num_heads_; + bool use_rel_pos_; + + nn::Linear qkv_; + nn::Linear proj_; + nn::Param rel_pos_h_; + nn::Param rel_pos_w_; + + public: + Attention() = default; + + Attention(const std::string& name, int dim, int num_heads, bool qkv_bias, bool use_rel_pos, + std::optional> input_size, const DpskOcrConfig& config) + : nn::Module(name) { + num_heads_ = num_heads; + use_rel_pos_ = use_rel_pos; + + qkv_ = reg("qkv", dim, dim * 3, qkv_bias, config.sam_linear_impl_type); + proj_ = reg("proj", dim, dim, true, config.sam_linear_impl_type); + if (use_rel_pos) { + rel_pos_h_ = reg("rel_pos_h"); + rel_pos_w_ = reg("rel_pos_w"); + } + } + + Tensor __interpolateLinear1d(const Tensor& input, int output_size) { + auto output = Tensor::empty({output_size, input.size(1)}).alloc(); + int input_size = input.size(0); + float scale_factor = static_cast(input_size - 1) / (output_size - 1); + + for (int i = 0; i < output_size; ++i) { + float in_x = i * scale_factor; + int x0 = static_cast(floor(in_x)); + int x1 = std::min(x0 + 1, input_size - 1); + float w1 = in_x - x0; + float w0 = 1.0f - w1; + + for (int c = 0; c < input.size(1); ++c) { + float val = + w0 * input.ptr()[x0 * input.size(1) + c] + w1 * input.ptr()[x1 * input.size(1) + c]; + *output.offsettedPtr({i, c}) = val; + } + } + return output; + } + + // Get relative positional embeddings according to the relative positions of query and key sizes. + Tensor getRelPos(int q_size, int k_size, const Tensor& rel_pos) { + auto max_rel_dist = 2 * std::max(q_size, k_size) - 1; + Tensor rel_pos_resized = Tensor::nil(); + + if (rel_pos.size(0) != max_rel_dist) { + rel_pos_resized = __interpolateLinear1d(rel_pos, max_rel_dist); + } else { + rel_pos_resized = rel_pos; + } + + std::vector q_coords(q_size); + std::vector k_coords(k_size); + + float q_scale = std::max((float)k_size / q_size, 1.0f); + float k_scale = std::max((float)q_size / k_size, 1.0f); + + for (int i = 0; i < q_size; ++i) { q_coords[i] = i * q_scale; } + + for (int i = 0; i < k_size; ++i) { k_coords[i] = i * k_scale; } + + float offset = (k_size - 1) * k_scale; + int embedding_dim = rel_pos_resized.size(1); + auto out = Tensor::empty({q_size, k_size, embedding_dim}).alloc(); + + for (int i = 0; i < q_size; ++i) { + for (int j = 0; j < k_size; ++j) { + float relative_coord_float = (q_coords[i] - k_coords[j]) + offset; + int64_t relative_coord_long = static_cast(std::round(relative_coord_float)); + + if (relative_coord_long < 0) relative_coord_long = 0; + if (relative_coord_long >= max_rel_dist) relative_coord_long = max_rel_dist - 1; + + for (int d = 0; d < embedding_dim; ++d) { + out.ptr()[i * k_size * embedding_dim + j * embedding_dim + d] = + *rel_pos_resized.offsettedPtr({(int32_t)relative_coord_long, d}); + } + } + } + + return out; + } + + // Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + // https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + std::tuple addDecomposedRelPos(Tensor q, const Tensor& rel_pos_h, const Tensor& rel_pos_w, + std::tuple q_size, std::tuple k_size) { + auto [q_h, q_w] = q_size; + auto [k_h, k_w] = k_size; + + auto Rh = getRelPos(q_h, k_h, rel_pos_h); + auto Rw = getRelPos(q_w, k_w, rel_pos_w); + + auto B = q.size(0); + auto dim = q.size(2); + + auto r_q = q.view({B, q_h, q_w, dim}); + + // Einsum + // 1. bhwc,hkc->bhwk + auto rel_h = Tensor::empty({B, q_h, q_w, k_h}).alloc(); + { + auto* r_q_ptr = r_q.ptr(); + auto* Rh_ptr = Rh.ptr(); + auto* rel_h_ptr = rel_h.ptr(); + // rel_h[b, h, w, k] = sum over c from 0 to dim-1 ( r_q[b, h, w, c] * Rh[h, k, c] ) + for (int b = 0; b < B; ++b) { + for (int h = 0; h < q_h; ++h) { + for (int w = 0; w < q_w; ++w) { + for (int k = 0; k < k_h; ++k) { + float sum = 0.0f; + const auto* p_r_q = r_q_ptr + (b * q_h * q_w * dim) + (h * q_w * dim) + (w * dim); + const auto* p_Rh = Rh_ptr + (h * k_h * dim) + (k * dim); + for (int c = 0; c < dim; ++c) { sum += p_r_q[c] * p_Rh[c]; } + rel_h_ptr[(b * q_h * q_w * k_h) + (h * q_w * k_h) + (w * k_h) + k] = sum; + } + } + } + } + } + // 2. bhwc,wkc->bhwk + auto rel_w = Tensor::empty({B, q_h, q_w, k_w}).alloc(); + { + auto* r_q_ptr = r_q.ptr(); + auto* Rw_ptr = Rw.ptr(); + auto* rel_w_ptr = rel_w.ptr(); + + // rel_w[b, h, w, k] = sum over c from 0 to dim-1 ( r_q[b, h, w, c] * Rw[w, k, c] ) + for (int b = 0; b < B; ++b) { + for (int h = 0; h < q_h; ++h) { + for (int w = 0; w < q_w; ++w) { + for (int k = 0; k < k_w; ++k) { + float sum = 0.0f; + const auto* p_r_q = r_q_ptr + (b * q_h * q_w * dim) + (h * q_w * dim) + (w * dim); + const auto* p_Rw = Rw_ptr + (w * k_w * dim) + (k * dim); + for (int c = 0; c < dim; ++c) { sum += p_r_q[c] * p_Rw[c]; } + rel_w_ptr[(b * q_h * q_w * k_w) + (h * q_w * k_w) + (w * k_w) + k] = sum; + } + } + } + } + } + + rel_h = rel_h.unsqueeze(-1); + rel_w = rel_w.unsqueeze(-2); + rel_h = rel_h.view({B, q_h * q_w, k_h, 1}); + rel_w = rel_w.view({B, q_h * q_w, 1, k_w}); + return {rel_h, rel_w}; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto B = x.size(0); + auto H = x.size(1); + auto W = x.size(2); + + // qkv with shape (3, B, nHead, H * W, C) + auto qkv = qkv_(x) + .view({ + B, + H * W, + 3, + num_heads_, + -1, + }) + .permute({2, 0, 3, 1, 4}); + qkv = qkv.view({3, B * num_heads_, H * W, -1}); + auto [q, k, v] = nn::functional::split<3>(qkv, 0); + + auto rel_h = Tensor::nil(); + auto rel_w = Tensor::nil(); + + if (use_rel_pos_) { + std::tie(rel_h, rel_w) = addDecomposedRelPos(q, rel_pos_h_.weight(), rel_pos_w_.weight(), {H, W}, {W, H}); + } + + q = q.view({B, num_heads_, H * W, -1}); + k = k.view({B, num_heads_, H * W, -1}); + v = v.view({B, num_heads_, H * W, -1}); + + if (use_rel_pos_) { + rel_h = rel_h.view({B, num_heads_, rel_h.size(1), rel_h.size(2), rel_h.size(3)}); + rel_w = rel_w.view({B, num_heads_, rel_w.size(1), rel_w.size(2), rel_w.size(3)}); + auto attn_bias = (rel_h + rel_w).view({B, num_heads_, rel_h.size(2), rel_h.size(3) * rel_w.size(4)}); + x = nn::functional::scaledDotProductAttention(q, k, v, attn_bias); + } else { + x = nn::functional::scaledDotProductAttention(q, k, v); + } + + x = x.view({B, num_heads_, H, W, -1}).permute({0, 2, 3, 1, 4}).view({B, H, W, -1}); + x = proj_(x); + return {x}; + } +}; + +class Block final : public nn::Module { + nn::LayerNorm norm1_; + nn::LayerNorm norm2_; + Attention attn_; + MLPBlock mlp_; + int window_size_; + + public: + Block() = default; + + Block(const std::string& name, int dim, int num_heads, float mlp_ratio, bool qkv_bias, bool use_rel_pos, int window_size, + std::optional> input_size, const DpskOcrConfig& config) + : nn::Module(name) { + norm1_ = reg("norm1", Tensor::shape_t{dim}); + attn_ = + reg("attn", dim, num_heads, qkv_bias, use_rel_pos, + window_size == 0 ? input_size : std::make_optional(std::make_tuple(window_size, window_size)), config); + norm2_ = reg("norm2", Tensor::shape_t{dim}); + mlp_ = reg("mlp", dim, (int)(dim * mlp_ratio), config); + window_size_ = window_size; + } + + std::tuple> windowPartition(Tensor x, int window_size) { + auto B = x.size(0); + auto H = x.size(1); + auto W = x.size(2); + auto C = x.size(3); + + auto pad_h = (window_size - H % window_size) % window_size; + auto pad_w = (window_size - W % window_size) % window_size; + + if (pad_h > 0 || pad_w > 0) { + // TODO do pad + // x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + } + auto Hp = H + pad_h; + auto Wp = W + pad_w; + + x = x.view({B, Hp / window_size, window_size, Wp / window_size, window_size, C}); + auto window = x.permute({0, 1, 3, 2, 4, 5}).view({-1, window_size, window_size, C}); + return {window, {Hp, Wp}}; + } + + Tensor windowUnpartition(Tensor windows, int window_size, std::tuple pad_wh, std::tuple hw) { + auto [Hp, Wp] = pad_wh; + auto [H, W] = hw; + auto B = windows.size(0) / (Hp * Wp / window_size / window_size); + auto x = windows.view({B, Hp / window_size, Wp / window_size, window_size, window_size, -1}); + x = x.permute({0, 1, 3, 2, 4, 5}).view({B, Hp, Wp, -1}); + + if (Hp > H || Wp > W) { x = x[{kAll, {kAll, H}, {kAll, W}, kAll}].contiguous(); } + + return x; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto shortcut = x; + x = norm1_(x); + + // Window partition + int H = 0; + int W = 0; + std::tuple pad_hw; + if (window_size_ > 0) { + H = x.size(1); + W = x.size(2); + std::tie(x, pad_hw) = windowPartition(x, window_size_); + } + + x = attn_(x)[0]; + + // Reverse window partition + if (window_size_ > 0) { x = windowUnpartition(x, window_size_, pad_hw, {H, W}); } + + x = shortcut + x; + x = x + mlp_(norm2_(x))[0]; + + return {x}; + } +}; + +class Blocks final : public nn::Module { + std::vector blocks_; + + public: + Blocks() = default; + + Blocks(const std::string& name, int nums, const std::vector& global_attn_indexes, const DpskOcrConfig& config) + : nn::Module(name) { + for (int i = 0; i < nums; ++i) { + bool is_in = std::find(global_attn_indexes.begin(), global_attn_indexes.end(), i) != global_attn_indexes.end(); + auto this_block_window_size = is_in ? 14 : 0; + blocks_.emplace_back(reg(std::to_string(i), 768, 12, 4.0, true, true, this_block_window_size, + std::make_optional(std::make_tuple(1024 / 16, 1024 / 16)), config)); + } + }; + + std::vector& list() { return blocks_; } +}; + +class ImageEncoderViT final : public nn::Module { + PatchEmbed patch_embed_; + nn::Param pos_embed_; + Blocks blocks_; + nn::Sequential neck_; + nn::Conv2D net_2_; + nn::Conv2D net_3_; + + public: + ImageEncoderViT() = default; + + ImageEncoderViT(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + patch_embed_ = reg("patch_embed", config); + pos_embed_ = reg("pos_embed"); + + // block_nums = 12 + // embed_dim = 768 + // num_heads = 12 + // mlp_ratio = 4.f + // qkv_bias = true + // use_rel_pos = true + // window_size = 14 + blocks_ = reg("blocks", 12, std::vector{2, 5, 8, 11}, config); + + neck_ = reg("neck") + .add(768, 12, Tensor::shape_t{1, 1}, Tensor::shape_t{1, 1}, Tensor::shape_t{0, 0}, + Tensor::shape_t{1, 1}, false) + .add(256) + .add(256, 256, Tensor::shape_t{3, 3}, Tensor::shape_t{1, 1}, Tensor::shape_t{1, 1}, + Tensor::shape_t{1, 1}, false) + .add(256); + + net_2_ = reg("net_2", 256, 512, Tensor::shape_t{3, 3}, Tensor::shape_t{2, 2}, Tensor::shape_t{1, 1}, + Tensor::shape_t{1, 1}, false); + net_3_ = reg("net_3", 512, 1024, Tensor::shape_t{3, 3}, Tensor::shape_t{2, 2}, Tensor::shape_t{1, 1}, + Tensor::shape_t{1, 1}, false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = patch_embed_(x)[0]; + // TODO x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) + for (auto& blk : blocks_.list()) { x = blk(x)[0]; } + + x = neck_(x.permute({0, 3, 1, 2}))[0]; + x = net_2_(x); + x = net_3_(x); + return {x}; + } +}; + +} // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp new file mode 100644 index 000000000..8a948dbc3 --- /dev/null +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -0,0 +1,2 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp new file mode 100644 index 000000000..76e332fcb --- /dev/null +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -0,0 +1,51 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py +// and +// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py + +// LlamaTokenizerFast +#pragma once + +#include +#include + +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::deepseek_ocr { + +// Actually is LlamaTokenizer +class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { + explicit DpskOcrTokenizer(const std::string& file_path) { preprocessor::initLocal(); } + + std::vector _tokenize(const std::string& str) override { + // TODO + return {}; + } + + std::vector tokenize(const std::string& str) override { + // TODO + return {}; + } + + std::wstring _detokenize(int64_t pos_idx) override { + // TODO + return L""; + } + + std::wstring detokenize(int64_t pos_idx) override { + // TODO + return _detokenize(pos_idx); + } + + Tensor convert2Ids(const std::vector& strs) override { return Tensor::nil(); } + + private: + // For text + preprocessor::BPE bpe_; +}; +} // namespace mllm::models::deepseek_ocr diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index c6803af9c..2f447ddd0 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -118,4 +118,13 @@ void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim {src, shards_pointer}); } +Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& V, const Tensor& mask) { + auto scale = Q.size(-1); + scale = (1.f / sqrtf(scale)); + auto attn_weight = matmul(Q, K, false, true) * scale; + if (mask) { attn_weight = attn_weight + mask; } + attn_weight = softmax(attn_weight, -1); + return matmul(attn_weight, V); +} + } // namespace mllm::nn::functional diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index 9efad0ab2..b61c564b9 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -129,4 +129,7 @@ Tensor silu_(const Tensor& x); void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim); +// If you want causal mask attention. Use Flash attention instead. +Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& V, const Tensor& mask = Tensor()); + } // namespace mllm::nn::functional diff --git a/mllm/nn/Module.hpp b/mllm/nn/Module.hpp index 8adadae57..5de909e37 100644 --- a/mllm/nn/Module.hpp +++ b/mllm/nn/Module.hpp @@ -42,10 +42,10 @@ class ModuleImpl : public AbstractNnNode { }; template -class ModuleLists; +class ModuleList; template -class ModuleListsSuffix; +class ModuleListSuffix; class Module { public: diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index 7bab5ebaa..bb4fa54d9 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -17,6 +17,7 @@ #include "mllm/nn/layers/LayerNorm.hpp" // IWYU pragma: export #include "mllm/nn/layers/Softmax.hpp" // IWYU pragma: export #include "mllm/nn/layers/VisionRoPE.hpp" // IWYU pragma: export +#include "mllm/nn/layers/Conv2D.hpp" // IWYU pragma: export #include "mllm/nn/layers/Conv3D.hpp" // IWYU pragma: export #include "mllm/nn/layers/CausalMask.hpp" // IWYU pragma: export #include "mllm/nn/layers/RoPE.hpp" // IWYU pragma: export @@ -27,3 +28,4 @@ #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export #include "mllm/nn/layers/RadixAttn.hpp" // IWYU pragma: export +#include "mllm/nn/layers/LayerNorm2D.hpp" // IWYU pragma: export diff --git a/mllm/nn/layers/Conv2D.cpp b/mllm/nn/layers/Conv2D.cpp new file mode 100644 index 000000000..6be072084 --- /dev/null +++ b/mllm/nn/layers/Conv2D.cpp @@ -0,0 +1,29 @@ +#include "mllm/nn/layers/Conv2D.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/Conv2DOp.hpp" + +namespace mllm::nn { + +Conv2D::Conv2D() : Layer(OpTypes::kConv2D, aops::Conv2DOpOptions{}) {} + +Conv2D::Conv2D(int32_t in_channels, int32_t out_channels, const std::vector& kernel_size, + const std::vector& stride_size, const std::vector& padding_size, + const std::vector& dilation_size, bool bias, aops::Conv2DOpImplType impl_type) + : Layer(OpTypes::kConv2D, aops::Conv2DOpOptions{ + .in_channels = in_channels, + .out_channels = out_channels, + .kernel_size = kernel_size, + .stride = stride_size, + .padding = padding_size, + .dilation = dilation_size, + .bias = bias, + .impl_type = impl_type, + }) {} + +Conv2D::Conv2D(const aops::Conv2DOpOptions& options) : Layer(OpTypes::kConv2D, options) {} + +Tensor Conv2D::weight() const { return std::static_pointer_cast(impl()->getInstancedOp())->weight(); } + +Tensor Conv2D::bias() const { return std::static_pointer_cast(impl()->getInstancedOp())->bias(); } + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Conv2D.hpp b/mllm/nn/layers/Conv2D.hpp new file mode 100644 index 000000000..ff67918bf --- /dev/null +++ b/mllm/nn/layers/Conv2D.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/Conv2DOp.hpp" + +namespace mllm::nn { + +class Conv2D : public Layer { + public: + Conv2D(); + + Conv2D(int32_t in_channels, int32_t out_channels, const std::vector& kernel_size, + const std::vector& stride_size, const std::vector& padding_size, + const std::vector& dilation_size, bool bias = true, + aops::Conv2DOpImplType impl_type = aops::Conv2DOpImplType::kDefault); + + explicit Conv2D(const aops::Conv2DOpOptions& options); + + [[nodiscard]] Tensor weight() const; + + [[nodiscard]] Tensor bias() const; + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD + MLLM_LAYER_ENABLE_REDIRECT_ATTRIBUTE(Conv2D) +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/LayerNorm2D.cpp b/mllm/nn/layers/LayerNorm2D.cpp new file mode 100644 index 000000000..628d6d64f --- /dev/null +++ b/mllm/nn/layers/LayerNorm2D.cpp @@ -0,0 +1,16 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/LayerNorm2DOp.hpp" +#include "mllm/nn/layers/LayerNorm2D.hpp" + +namespace mllm::nn { + +LayerNorm2D::LayerNorm2D() : Layer(OpTypes::kLayerNorm2D, aops::LayerNorm2DOpOptions{}) {} + +LayerNorm2D::LayerNorm2D(const aops::LayerNorm2DOpOptions& options) : Layer(OpTypes::kLayerNorm2D, options) {} + +LayerNorm2D::LayerNorm2D(const int32_t num_channels, float eps) + : Layer(OpTypes::kLayerNorm2D, aops::LayerNorm2DOpOptions{.num_channels = num_channels, .eps = eps}) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/LayerNorm2D.hpp b/mllm/nn/layers/LayerNorm2D.hpp new file mode 100644 index 000000000..ebe29d0ca --- /dev/null +++ b/mllm/nn/layers/LayerNorm2D.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/LayerNorm2DOp.hpp" + +namespace mllm::nn { + +class LayerNorm2D : public Layer { + public: + LayerNorm2D(); + + explicit LayerNorm2D(const aops::LayerNorm2DOpOptions& options); + + explicit LayerNorm2D(const int32_t num_channels, float eps = 1e-6); + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD + MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE(LayerNorm2D) +}; + +} // namespace mllm::nn diff --git a/tests/cpu/Conv2DKernelTest.hpp b/tests/cpu/Conv2DKernelTest.hpp new file mode 100644 index 000000000..67c6d89a2 --- /dev/null +++ b/tests/cpu/Conv2DKernelTest.hpp @@ -0,0 +1,119 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/core/ParameterFile.hpp" + +#include "KernelTestHelper.hpp" + +using namespace mllm; // NOLINT + +void naive_conv2d(const float* input_data, const float* weight_data, const float* bias_data, float* output_data, + int in_channels, int in_h, int in_w, int out_channels, int kernel_h, int kernel_w, int pad_h, int pad_w, + int stride_h, int stride_w, int dilation_h, int dilation_w) { + const int out_h = (in_h + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int out_w = (in_w + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + for (int oc = 0; oc < out_channels; ++oc) { + for (int oy = 0; oy < out_h; ++oy) { + for (int ox = 0; ox < out_w; ++ox) { + float accumulated_value = 0.0f; + for (int ic = 0; ic < in_channels; ++ic) { + for (int ky = 0; ky < kernel_h; ++ky) { + for (int kx = 0; kx < kernel_w; ++kx) { + const int iy = oy * stride_h + ky * dilation_h - pad_h; + const int ix = ox * stride_w + kx * dilation_w - pad_w; + + if (iy >= 0 && iy < in_h && ix >= 0 && ix < in_w) { + int input_idx = ic * (in_h * in_w) + iy * in_w + ix; + int weight_idx = oc * (in_channels * kernel_h * kernel_w) + ic * (kernel_h * kernel_w) + ky * kernel_w + kx; + + accumulated_value += input_data[input_idx] * weight_data[weight_idx]; + } + } + } + } + + if (bias_data != nullptr) { accumulated_value += bias_data[oc]; } + + int output_idx = oc * (out_h * out_w) + oy * out_w + ox; + output_data[output_idx] = accumulated_value; + } + } + } +} + +class Conv2DModule : public nn::Module { + nn::Conv2D conv2d_; + + public: + Conv2DModule() = default; + + Conv2DModule(int in_channel, int out_channel, int K_H, int K_W, int S_H, int S_W, int P_H, int P_W, bool bias) + : nn::Module() { + conv2d_ = reg("emb", in_channel, out_channel, std::vector{K_H, K_W}, std::vector{S_H, S_W}, + std::vector{P_H, P_W}, std::vector{1, 1}, bias); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // inputs is Q, K_indices, V_indices + return {conv2d_(inputs[0])}; + } +}; + +class Conv2DKernelTest : public KernelTest { + public: + Conv2DKernelTest() = default; + ~Conv2DKernelTest() override = default; + + bool testConv2DOnce(const std::unordered_map& cfg) { + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto I_H = cfg.at("I_H"); + auto I_W = cfg.at("I_W"); + auto K_H = cfg.at("K_H"); + auto K_W = cfg.at("K_W"); + auto S_H = cfg.at("S_H"); + auto S_W = cfg.at("S_W"); + auto P_H = cfg.at("P_H"); + auto P_W = cfg.at("P_W"); + auto bias = cfg.at("bias"); + + auto module = Conv2DModule(in_channel, out_channel, K_H, K_W, S_H, S_W, P_H, P_W, bias); + + // Make fake data + auto weight_param = Tensor::random({out_channel, in_channel, K_H, K_W}, -1, 1, kFloat32, kCPU); + auto bias_param = Tensor::random({out_channel}, -1, 1, kFloat32, kCPU); + weight_param.setName("emb.weight"); + bias_param.setName("emb.bias"); + auto param = ParameterFile::create(); + param->push("emb.weight", weight_param); + param->push("emb.bias", bias_param); + module.load(param); + + auto input = Tensor::random({1, in_channel, I_H, I_W}, -1, 1, kFloat32, kCPU); + auto predict = module(input)[0]; + + auto GT = Tensor::empty(predict.shape(), kFloat32, kCPU).alloc(); + + // Naive impl to check correctness. + naive_conv2d(input.ptr(), weight_param.ptr(), bias ? bias_param.ptr() : nullptr, GT.ptr(), + in_channel, I_H, I_W, out_channel, K_H, K_W, P_H, P_W, S_H, S_W, 1, 1); + + auto result = test::allClose(GT, predict, 1e-2f, 1e-2f); + if (!result) { + print(result); + return false; + } + + return true; + } + + bool testConv2D(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testConv2DOnce(cfg)) { return false; } + } + return true; + } +}; diff --git a/tests/cpu/FlashAttentionKernelTest.hpp b/tests/cpu/FlashAttentionKernelTest.hpp index 571cf41ff..27a5e0d10 100644 --- a/tests/cpu/FlashAttentionKernelTest.hpp +++ b/tests/cpu/FlashAttentionKernelTest.hpp @@ -7,7 +7,6 @@ #include "mllm/mllm.hpp" #include "mllm/nn/Nn.hpp" #include "mllm/nn/Functional.hpp" -#include "mllm/nn/lmcache/PrefixCache.hpp" #include "KernelTestHelper.hpp" diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index a0fe4d7b8..d06bc68fa 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -838,6 +838,55 @@ TEST_F(FlashAttn2KernelTest, fwd_bshd) { } #endif +//===----------------------------------------------------------------------===// +// Conv2D Test +// +// auto in_channel = cfg.at("in_channel"); +// auto out_channel = cfg.at("out_channel"); +// auto I_H = cfg.at("I_H"); +// auto I_W = cfg.at("I_W"); +// auto K_H = cfg.at("K_H"); +// auto K_W = cfg.at("K_W"); +// auto S_H = cfg.at("S_H"); +// auto S_W = cfg.at("S_W"); +// auto P_H = cfg.at("P_H"); +// auto P_W = cfg.at("P_W"); +// auto bias = cfg.at("bias"); +// +// In deepseek-ocr we have +// CASE 1: +// in_channel = 3 +// out_channel = 1024 +// I_H = 224 +// I_W = 224 +// K_H = 14 +// K_W = 14 +// S_H = 14 +// S_W = 14 +// P_H = 0 +// P_W = 0 +// bias = false +// CASE 2: +// +//===----------------------------------------------------------------------===// +#include "Conv2DKernelTest.hpp" +TEST_F(Conv2DKernelTest, im2col) { + EXPECT_EQ(testConv2D({{ + {"in_channel", 3}, + {"out_channel", 1024}, + {"I_H", 224}, + {"I_W", 224}, + {"K_H", 14}, + {"K_W", 14}, + {"S_H", 14}, + {"S_W", 14}, + {"P_H", 0}, + {"P_W", 0}, + {"bias", 0}, + }}), + true); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); mllm::initializeContext(); From 8eebf54dd8c6f0b0b4081786603b7424e45b5aba Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 10:53:54 +0800 Subject: [PATCH 02/25] feat(deepseek_ocr): implement conversation management and preprocessing utilities - Add `Conversation` class to manage prompt templates and conversation history - Support multiple separator styles: DeepSeek, DeepSeekV2, PLAIN, ALIGNMENT - Implement methods for generating prompts, converting to Gradio/OpenAI formats - Add global registry for conversation templates with initialization function - Include image loading functionality from conversation messages - Refactor NoTPAttention module with proper tensor operations and attention mechanism - Implement padding logic in Block class for window-based processing - Add cubic interpolation and absolute position embedding resizing in ImageEncoderViT --- mllm/models/deepseek_ocr/conversation.hpp | 2 - .../deepseek_ocr/conversation_preprocess.hpp | 319 ++++++++++++++++++ mllm/models/deepseek_ocr/deepencoder.hpp | 115 ++++++- 3 files changed, 422 insertions(+), 14 deletions(-) delete mode 100644 mllm/models/deepseek_ocr/conversation.hpp create mode 100644 mllm/models/deepseek_ocr/conversation_preprocess.hpp diff --git a/mllm/models/deepseek_ocr/conversation.hpp b/mllm/models/deepseek_ocr/conversation.hpp deleted file mode 100644 index 8a948dbc3..000000000 --- a/mllm/models/deepseek_ocr/conversation.hpp +++ /dev/null @@ -1,2 +0,0 @@ -// Copyright (c) MLLM Team. -// Licensed under the MIT License. diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp new file mode 100644 index 000000000..34f52bc9b --- /dev/null +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -0,0 +1,319 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/preprocessor/visual/Image.hpp" + +namespace mllm::models::deepseek_ocr { + +/** + * Separator styles for different conversation formats + */ +enum class SeparatorStyle { + DeepSeek, + DeepSeekV2, + PLAIN, + ALIGNMENT, +}; + +/** + * A class that manages prompt templates and keeps all conversation history + */ +class Conversation { + public: + // Constructor with required and optional parameters + explicit Conversation(const std::string& name, const std::string& system_template = "{system_message}", + const std::string& system_message = "", const std::vector& roles = {"USER", "ASSISTANT"}, + const std::vector>& messages = {}, int offset = 0, + SeparatorStyle sep_style = SeparatorStyle::DeepSeek, const std::string& sep = "\n", + const std::optional& sep2 = std::nullopt, + const std::optional& stop_str = std::nullopt, + const std::optional>& stop_token_ids = std::nullopt) + : name_(name), + system_template_(system_template), + system_message_(system_message), + roles_(roles), + messages_(messages), + offset_(offset), + sep_style_(sep_style), + sep_(sep), + sep2_(sep2), + stop_str_(stop_str), + stop_token_ids_(stop_token_ids) {} + + /** + * Get the prompt for generation + */ + [[nodiscard]] std::string getPrompt() const { + std::string system_prompt = formatSystemTemplate(); + + if (sep_style_ == SeparatorStyle::DeepSeek) { + std::vector seps = {sep_, sep2_.value_or("")}; + std::string ret; + + if (!system_prompt.empty()) { ret = system_prompt + seps[0]; } + + for (size_t i = 0; i < messages_.size(); ++i) { + const auto& [role, message] = std::make_pair(messages_[i][0], messages_[i][1]); + if (!message.empty()) { + ret += role + ": " + message + seps[i % 2]; // NOLINT + } else { + ret += role + ":"; + } + } + return ret; + } else if (sep_style_ == SeparatorStyle::DeepSeekV2) { + std::vector seps = {sep_, sep2_.value_or("")}; + std::string ret; + + if (!system_prompt.empty()) { ret = system_prompt + seps[0]; } + + for (const auto& i : messages_) { + const auto& [role, message] = std::make_pair(i[0], i[1]); + if (!message.empty()) { + if (role == "User") { + ret += "<|sft▁begin|>\n" + message + sep_; + } else { + ret += message + sep2_.value_or(""); + } + } + } + return ret; + } else if (sep_style_ == SeparatorStyle::PLAIN) { + std::vector seps = {sep_, sep2_.value_or("")}; + std::string ret; + + for (size_t i = 0; i < messages_.size(); ++i) { + const auto& [role, message] = std::make_pair(messages_[i][0], messages_[i][1]); + if (!message.empty()) { ret += message + seps[i % 2]; } + } + return ret; + } else if (sep_style_ == SeparatorStyle::ALIGNMENT) { + std::vector seps = {sep_, sep2_.value_or("")}; + std::string ret; + + for (size_t i = 0; i < messages_.size(); ++i) { + const auto& [role, message] = std::make_pair(messages_[i][0], messages_[i][1]); + if (!message.empty()) { + if (i % 2 == 0) { + ret += "\n" + seps[i % 2]; + } else { + ret += message + seps[i % 2]; + } + } + } + return ret; + } else { + throw std::invalid_argument("Invalid separator style"); + } + } + + /** + * Set the system message + */ + void setSystemMessage(const std::string& system_message) { system_message_ = system_message; } + + /** + * Append a new message + */ + void appendMessage(const std::string& role, const std::string& message) { messages_.push_back({role, message}); } + + /** + * Update the last output + * The last message is typically set to be empty when constructing the prompt, + * so we need to update it in-place after getting the response from a model. + */ + void updateLastMessage(const std::string& message) { + if (!messages_.empty()) { messages_.back()[1] = message; } + } + + /** + * Reset messages + */ + void resetMessages() { messages_.clear(); } + + /** + * Convert the conversation to gradio chatbot format + */ + [[nodiscard]] std::vector>> toGradioChatbot() const { + std::vector>> ret; + + for (size_t i = offset_; i < messages_.size(); ++i) { + const auto& [role, msg] = std::make_pair(messages_[i][0], messages_[i][1]); + if (i % 2 == 0) { + ret.push_back({msg, std::nullopt}); + } else if (!ret.empty()) { + ret.back()[1] = msg; + } + } + + return ret; + } + + /** + * Convert the conversation to OpenAI chat completion format + */ + [[nodiscard]] std::vector> toOpenAIApiMessages() const { + std::string system_prompt = formatSystemTemplate(); + std::vector> ret; + + ret.push_back({{"role", "system"}, {"content", system_prompt}}); + + for (size_t i = offset_; i < messages_.size(); ++i) { + const auto& [_, msg] = std::make_pair(messages_[i][0], messages_[i][1]); + if (i % 2 == 0) { + ret.push_back({{"role", "user"}, {"content", msg}}); + } else if (!msg.empty()) { + ret.push_back({{"role", "assistant"}, {"content", msg}}); + } + } + + return ret; + } + + /** + * Create a copy of the conversation + */ + [[nodiscard]] std::shared_ptr copy() const { + return std::make_shared(name_, system_template_, system_message_, roles_, messages_, offset_, sep_style_, + sep_, sep2_, stop_str_, stop_token_ids_); + } + + /** + * Convert the conversation to a dictionary + */ + [[nodiscard]] std::map, std::vector>, int>> + toDict() const { + return {{"template_name", name_}, + {"system_message", system_message_}, + {"roles", roles_}, + {"messages", messages_}, + {"offset", offset_}}; + } + + // Getters + [[nodiscard]] const std::string& getName() const { return name_; } + [[nodiscard]] const std::vector& getRoles() const { return roles_; } + + private: + [[nodiscard]] std::string formatSystemTemplate() const { + std::string result = system_template_; + size_t pos = result.find("{system_message}"); + if (pos != std::string::npos) { result.replace(pos, 16, system_message_); } + return result; + } + + private: + std::string name_; + std::string system_template_; + std::string system_message_; + std::vector roles_; + std::vector> messages_; + int offset_; + SeparatorStyle sep_style_; + std::string sep_; + std::optional sep2_; + std::optional stop_str_; + std::optional> stop_token_ids_; +}; + +// A global registry for all conversation templates +static std::map> conv_templates; + +/** + * Register a new conversation template + */ +void registerConvTemplate(const std::shared_ptr& template_ptr, bool override = false) { + const std::string& name = template_ptr->getName(); + if (!override) { assert(conv_templates.find(name) == conv_templates.end() && (name + " has been registered.").c_str()); } + conv_templates[name] = template_ptr; +} + +/** + * Get a conversation template + */ +std::shared_ptr getConvTemplate(const std::string& name) { + auto it = conv_templates.find(name); + if (it == conv_templates.end()) { throw std::runtime_error("Template not found: " + name); } + return it->second->copy(); +} + +// Initialize templates +void initializeTemplates() { + // DeepSeek template + auto deepseek = std::make_shared( + "deepseek", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, + std::vector>{}, 0, SeparatorStyle::DeepSeek, "\n\n", "<|end▁of▁sentence|>", + std::vector{"User:", "<|end▁of▁sentence|>"}, std::vector{100001}); + registerConvTemplate(deepseek); + + // DeepSeekV2 template + auto deepseekv2 = std::make_shared( + "deepseekv2", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, + std::vector>{}, 0, SeparatorStyle::DeepSeek, "", "<|end▁of▁sentence|>", + std::vector{"User:", "<|end▁of▁sentence|>"}, std::vector{100001}); + registerConvTemplate(deepseekv2); + + // Plain template + auto plain = std::make_shared("plain", "", "", std::vector{"", ""}, + std::vector>{}, 0, SeparatorStyle::PLAIN, "", "", + std::vector{""}, std::vector{100001}); + registerConvTemplate(plain); + + // Alignment template + auto alignment = std::make_shared("alignment", "", "", std::vector{"", ""}, + std::vector>{}, 0, SeparatorStyle::ALIGNMENT, "", "", + std::vector{""}, std::vector{100001}); + registerConvTemplate(alignment); +} + +//===----------------------------------------------------------------------===// +// For Image processing +//===----------------------------------------------------------------------===// + +/** + * Loads images from conversation messages + * + * @param conversations JSON array of conversation messages + * An example is: + * [ + * { + * "role": "User", + * "content": "\nExtract all information from this image and convert them into markdown format.", + * "images": ["./examples/table_datasets.png"] + * }, + * {"role": "Assistant", "content": ""} + * ] + * + * @return Vector of Image objects + */ +std::vector loadImages(const nlohmann::json& conversations) { + std::vector ret; + // Iterate through each conversation message + for (const auto& message : conversations) { + // Skip if message doesn't contain "images" field + if (!message.contains("images")) { continue; } + + // Process each image path in the "images" array + for (const auto& image_path : message["images"]) { + // Load the image using Image::open + Image img = Image::open(image_path); + // Add to result vector + ret.push_back(img); + } + } + return ret; +} + +} // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 12baca5b3..0ad56e8b7 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -221,7 +221,7 @@ class NoTPAttention final : public nn::Module { public: NoTPAttention() = default; - NoTPAttention(const std::string& name, const DpskOcrConfig& config) { + NoTPAttention(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { num_heads_ = 16; n_local_heads_ = 16; head_dim_ = 1024 / 16; @@ -232,7 +232,6 @@ class NoTPAttention final : public nn::Module { } std::vector forward(const std::vector& inputs, const std::vector& args) override { - // TODO auto& x = inputs[0]; auto bsz = x.size(0); auto seqlen = x.size(1); @@ -241,15 +240,18 @@ class NoTPAttention final : public nn::Module { xqkv = xqkv.view({bsz, seqlen, 3, num_heads_, head_dim_}); auto [xq, xk, xv] = nn::functional::split<3>(xqkv, 2); - // TODO need squeeze ? + xq = xq.squeeze(2); + xk = xk.squeeze(2); + xv = xv.squeeze(2); - // TODO permute ? + xq = xq.permute({0, 2, 1, 3}); + xk = xk.permute({0, 2, 1, 3}); + xv = xv.permute({0, 2, 1, 3}); - // TODO FA without mask - - // TODO outproj. - - return {}; + auto output = nn::functional::scaledDotProductAttention(xq, xk, xv); + output = output.permute({0, 2, 1, 3}).reshape({bsz, seqlen, -1}); + output = out_proj_(output); + return {output}; } }; @@ -609,9 +611,32 @@ class Block final : public nn::Module { auto pad_w = (window_size - W % window_size) % window_size; if (pad_h > 0 || pad_w > 0) { - // TODO do pad - // x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + // Do x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + const auto& x_shape = x.shape(); + const int in_c = x_shape[0]; + const int in_h = x_shape[1]; + const int in_w = x_shape[2]; + + const int out_h = in_h + pad_h; + const int out_w = in_w + pad_w; + std::vector out_shape = {in_c, out_h, out_w}; + + Tensor padded_x = Tensor::empty(out_shape).alloc(); + + mllm_fp32_t* x_data_ptr = x.ptr(); + mllm_fp32_t* padded_x_data_ptr = padded_x.ptr(); + + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < in_h; ++h) { + for (int w = 0; w < in_w; ++w) { + int src_idx = c * (in_h * in_w) + h * in_w + w; + int dest_idx = c * (out_h * out_w) + h * out_w + w; + padded_x_data_ptr[dest_idx] = x_data_ptr[src_idx]; + } + } + } } + auto Hp = H + pad_h; auto Wp = W + pad_w; @@ -716,10 +741,76 @@ class ImageEncoderViT final : public nn::Module { Tensor::shape_t{1, 1}, false); } + template + T __cubicInterpolate(T p0, T p1, T p2, T p3, float t) { + return p1 + 0.5f * t * (p2 - p0 + t * (2.0f * p0 - 5.0f * p1 + 4.0f * p2 - p3 + t * (3.0f * (p1 - p2) + p3 - p0))); + } + + Tensor getAbsPosSam(Tensor abs_pos, int tgt_size) { + auto dtype = abs_pos.dtype(); + auto src_size = abs_pos.size(1); + + if (src_size != tgt_size) { + auto old_pos_embed = abs_pos.permute({0, 3, 1, 2}); + old_pos_embed = old_pos_embed.to(kFloat32); + + const int batch_size = old_pos_embed.size(0); + const int channels = old_pos_embed.size(1); + const int src_h = old_pos_embed.size(2); + const int src_w = old_pos_embed.size(3); + const int tgt_h = tgt_size; + const int tgt_w = tgt_size; + + auto new_pos_embed = Tensor::empty({batch_size, channels, tgt_h, tgt_w}, kFloat32).alloc(); + + const float* src_data = old_pos_embed.ptr(); + float* dst_data = new_pos_embed.ptr(); + + const float height_scale = static_cast(src_h) / tgt_h; + const float width_scale = static_cast(src_w) / tgt_w; + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < channels; ++c) { + const float* current_src_channel = src_data + (b * channels + c) * src_h * src_w; + float* current_dst_channel = dst_data + (b * channels + c) * tgt_h * tgt_w; + + for (int y_tgt = 0; y_tgt < tgt_h; ++y_tgt) { + for (int x_tgt = 0; x_tgt < tgt_w; ++x_tgt) { + float y_src = (static_cast(y_tgt) + 0.5f) * height_scale - 0.5f; + float x_src = (static_cast(x_tgt) + 0.5f) * width_scale - 0.5f; + + int y_floor = static_cast(std::floor(y_src)); + int x_floor = static_cast(std::floor(x_src)); + float y_frac = y_src - y_floor; + float x_frac = x_src - x_floor; + + float p[4][4]; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + int y_coord = std::max(0, std::min(src_h - 1, y_floor - 1 + i)); + int x_coord = std::max(0, std::min(src_w - 1, x_floor - 1 + j)); + p[i][j] = current_src_channel[y_coord * src_w + x_coord]; + } + } + float col[4]; + for (int i = 0; i < 4; ++i) { col[i] = __cubicInterpolate(p[i][0], p[i][1], p[i][2], p[i][3], x_frac); } + float value = __cubicInterpolate(col[0], col[1], col[2], col[3], y_frac); + current_dst_channel[y_tgt * tgt_w + x_tgt] = value; + } + } + } + } + new_pos_embed = new_pos_embed.to(dtype); + new_pos_embed = new_pos_embed.permute({0, 2, 3, 1}); + return new_pos_embed; + } else { + return abs_pos; + } + } + std::vector forward(const std::vector& inputs, const std::vector& args) override { auto x = inputs[0]; x = patch_embed_(x)[0]; - // TODO x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) + x = x + getAbsPosSam(pos_embed_.weight(), x.size(1)); for (auto& blk : blocks_.list()) { x = blk(x)[0]; } x = neck_(x.permute({0, 3, 1, 2}))[0]; From 60f6f92342784b55e7451c1cbe4bb01fb16bc444 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 14:52:25 +0800 Subject: [PATCH 03/25] feat(cpu): add interpolate and pad operations with full interpolation modes and padding strategies Added comprehensive interpolation and padding support for CPU backend: - InterpolateOp with nearest, linear, bilinear, bicubic, and trilinear modes - PadOp with constant, reflect, replicate, and circular padding modes - Updated CPU backend to register new operations - Fixed conv2d weight packing and bias addition logic - Added documentation for new functional APIs The implementation supports various tensor dimensions and aligns with PyTorch-style behavior for consistent results. --- docs/api/functional.rst | 31 ++ mllm/backends/cpu/CPUBackend.cpp | 4 +- mllm/backends/cpu/kernels/arm/conv2d.cpp | 9 +- mllm/backends/cpu/ops/Conv2DOp.cpp | 26 +- mllm/backends/cpu/ops/EinsumOp.cpp | 0 mllm/backends/cpu/ops/EinsumOp.hpp | 0 mllm/backends/cpu/ops/InterpolateOp.cpp | 438 +++++++++++++++++++++++ mllm/backends/cpu/ops/InterpolateOp.hpp | 25 ++ mllm/backends/cpu/ops/PadOp.cpp | 139 +++++++ mllm/backends/cpu/ops/PadOp.hpp | 25 ++ mllm/compile/ir/GeneratedRTTIKind.hpp | 5 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 11 +- mllm/compile/ir/linalg/Op.cpp | 2 + mllm/compile/ir/linalg/Op.hpp | 4 + mllm/compile/ir/rtti_kind_gen.py | 3 + mllm/core/OpTypes.hpp | 4 + mllm/core/aops/EinsumOp.cpp | 0 mllm/core/aops/EinsumOp.hpp | 0 mllm/core/aops/InterpolateOp.cpp | 81 +++++ mllm/core/aops/InterpolateOp.hpp | 47 +++ mllm/core/aops/PadOp.cpp | 28 ++ mllm/core/aops/PadOp.hpp | 44 +++ mllm/nn/Functional.cpp | 25 ++ mllm/nn/Functional.hpp | 14 + mllm/nn/Module.hpp | 2 +- tests/cpu/Conv2DKernelTest.hpp | 16 +- tests/cpu/KernelTest.cpp | 99 ++++- 27 files changed, 1055 insertions(+), 27 deletions(-) create mode 100644 mllm/backends/cpu/ops/EinsumOp.cpp create mode 100644 mllm/backends/cpu/ops/EinsumOp.hpp create mode 100644 mllm/backends/cpu/ops/InterpolateOp.cpp create mode 100644 mllm/backends/cpu/ops/InterpolateOp.hpp create mode 100644 mllm/backends/cpu/ops/PadOp.cpp create mode 100644 mllm/backends/cpu/ops/PadOp.hpp create mode 100644 mllm/core/aops/EinsumOp.cpp create mode 100644 mllm/core/aops/EinsumOp.hpp create mode 100644 mllm/core/aops/InterpolateOp.cpp create mode 100644 mllm/core/aops/InterpolateOp.hpp diff --git a/docs/api/functional.rst b/docs/api/functional.rst index 5fca137d2..03f8bd2e8 100644 --- a/docs/api/functional.rst +++ b/docs/api/functional.rst @@ -78,6 +78,37 @@ Shape Operations :param dim: Dimension along which to concatenate :return: Concatenated tensor +.. cpp:function:: Tensor mllm::nn::functional::pad(const Tensor& x, const std::vector& pad, aops::PadMode mode = aops::PadMode::kConstant, float value = 0.0f) + + Pad a tensor along the last N dimensions as specified. + + :param x: Input tensor + :param pad: Padding sizes ordered from the last dimension to the first, e.g. [last_left, last_right, ..., first_left, first_right] + :param mode: Padding mode (kConstant, kReflect, kReplicate, kCircular). Default: kConstant + :param value: Constant value used when mode is kConstant. Default: 0.0 + :return: Padded tensor + +.. cpp:function:: Tensor mllm::nn::functional::interpolate(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, bool keep_aspect_ratio = false) + + Resize a tensor to the target spatial size. + + :param x: Input tensor (supports 1D/2D/3D spatial resizing depending on mode) + :param size: Target spatial size (e.g., [H_out, W_out] for 2D) + :param mode: Interpolation mode (kNearest, kLinear, kBilinear, kBicubic, kTrilinear). Default: kNearest + :param align_corners: Align corners for linear/bilinear/trilinear interpolation. Default: false + :param keep_aspect_ratio: Keep aspect ratio when size is provided (handled by AOP). Default: false + :return: Resized tensor + +.. cpp:function:: Tensor mllm::nn::functional::interpolate(const Tensor& x, const std::vector& scale_factor, aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false) + + Resize a tensor by scale factors per spatial dimension. + + :param x: Input tensor (supports 1D/2D/3D spatial resizing depending on mode) + :param scale_factor: Scale factors per spatial dimension (e.g., [sh, sw] for 2D) + :param mode: Interpolation mode (kNearest, kLinear, kBilinear, kBicubic, kTrilinear). Default: kNearest + :param align_corners: Align corners for linear/bilinear/trilinear interpolation. Default: false + :return: Resized tensor + Attention Operations -------------------- diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index dd8549804..8a607c5ae 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -18,7 +18,9 @@ #include "mllm/backends/cpu/ops/FillOp.hpp" #include "mllm/backends/cpu/ops/FlashAttention2Op.hpp" #include "mllm/backends/cpu/ops/GELUOp.hpp" +#include "mllm/backends/cpu/ops/InterpolateOp.hpp" #include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp" +#include "mllm/backends/cpu/ops/PadOp.hpp" #include "mllm/backends/cpu/ops/RadixAttnOp.hpp" #include "mllm/backends/cpu/ops/ReLUOp.hpp" #include "mllm/backends/cpu/ops/GraphOps.hpp" @@ -63,7 +65,7 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, - CPUConv2DOpFactory, CPULayerNorm2DOpFactory>(); + CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory>(); } std::shared_ptr createCPUBackend() { return std::make_shared(); } diff --git a/mllm/backends/cpu/kernels/arm/conv2d.cpp b/mllm/backends/cpu/kernels/arm/conv2d.cpp index d7ceddf18..27bd3ae46 100644 --- a/mllm/backends/cpu/kernels/arm/conv2d.cpp +++ b/mllm/backends/cpu/kernels/arm/conv2d.cpp @@ -79,15 +79,14 @@ void conv2d_fp32_im2col_input(const float* input_data, const int channels, const void conv2d_fp32_im2col_weight(const float* src_weight, float* packed_weight, int out_channels, int in_channels, int kernel_h, int kernel_w) { - int M = out_channels; - int K = in_channels * kernel_h * kernel_w; - + // Original Weight: [Out, In, Kh, Kw] + // Packed Weight: [Out, In*Kh*Kw] for (int o = 0; o < out_channels; ++o) { for (int i = 0; i < in_channels; ++i) { for (int h = 0; h < kernel_h; ++h) { for (int w = 0; w < kernel_w; ++w) { - int src_idx = h * (kernel_w * in_channels * out_channels) + w * (in_channels * out_channels) + i * (out_channels) + o; - int dst_idx = o * (in_channels * kernel_h * kernel_w) + i * (kernel_h * kernel_w) + h * (kernel_w) + w; + int src_idx = o * (in_channels * kernel_h * kernel_w) + i * (kernel_h * kernel_w) + h * kernel_w + w; + int dst_idx = o * (in_channels * kernel_h * kernel_w) + i * (kernel_h * kernel_w) + h * kernel_w + w; packed_weight[dst_idx] = src_weight[src_idx]; } } diff --git a/mllm/backends/cpu/ops/Conv2DOp.cpp b/mllm/backends/cpu/ops/Conv2DOp.cpp index aef3c83fa..ae07b2c1d 100644 --- a/mllm/backends/cpu/ops/Conv2DOp.cpp +++ b/mllm/backends/cpu/ops/Conv2DOp.cpp @@ -98,7 +98,7 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // step 1. im2col inputs to tmp auto packed_inputs = Tensor::empty({MATMUL_K, MATMUL_N}, input.dtype(), input.device()).alloc(); - arm::conv2d_fp32_im2col_input(input.ptr(), options_.in_channels, input.shape()[1], input.shape()[2], + arm::conv2d_fp32_im2col_input(input.ptr(), options_.in_channels, input.shape()[2], input.shape()[3], kernel_size[0], kernel_size[1], padding[0], padding[1], stride[0], stride[1], dilation[0], dilation[1], packed_inputs.ptr()); // step 2. Do matmul @@ -106,7 +106,17 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector case aops::MatMulOpType::kBLAS: { #if defined(MLLM_USE_BLAS) blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), output.ptr(), - options_.bias ? bias_.ptr() : nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); + nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); + + // Add Bias + if (options_.bias) { + auto out_ptr = output.ptr(); + const auto bias_ptr = bias_.ptr(); + for (int m = 0; m < MATMUL_M; ++m) { + const float b = bias_ptr[m]; + for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + } + } #else NYI("BLAS not supported. Pls set MLLM_USE_BLAS=ON to enable BLAS supports in cmake."); #endif @@ -116,8 +126,16 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector auto thread_count = options_.getThreads(); #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::mllm_blas_matmul_fp32(MATMUL_M, MATMUL_K, MATMUL_N, output.ptr(), weight_.ptr(), - packed_inputs.ptr(), options_.bias ? bias_.ptr() : nullptr, - false, false, thread_count); + packed_inputs.ptr(), nullptr, false, false, thread_count); + // Add Bias + if (options_.bias) { + auto out_ptr = output.ptr(); + const auto bias_ptr = bias_.ptr(); + for (int m = 0; m < MATMUL_M; ++m) { + const float b = bias_ptr[m]; + for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + } + } #else NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") #endif diff --git a/mllm/backends/cpu/ops/EinsumOp.cpp b/mllm/backends/cpu/ops/EinsumOp.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/cpu/ops/EinsumOp.hpp b/mllm/backends/cpu/ops/EinsumOp.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/cpu/ops/InterpolateOp.cpp b/mllm/backends/cpu/ops/InterpolateOp.cpp new file mode 100644 index 000000000..0450c021b --- /dev/null +++ b/mllm/backends/cpu/ops/InterpolateOp.cpp @@ -0,0 +1,438 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include "mllm/backends/cpu/ops/InterpolateOp.hpp" + +namespace mllm::cpu { + +CPUInterpolateOp::CPUInterpolateOp(const aops::InterpolateOpOptions& options) : aops::InterpolateOp(options) {} + +// Helper function to compute scale factors +static void compute_scale_factors(const std::vector& input_size, const std::vector& output_size, + std::vector& scale_factors, bool align_corners) { + scale_factors.resize(input_size.size()); + for (size_t i = 0; i < input_size.size(); ++i) { + if (output_size[i] == 1) { + scale_factors[i] = 0.0f; + } else if (align_corners) { + scale_factors[i] = static_cast(input_size[i] - 1) / (output_size[i] - 1); + } else { + scale_factors[i] = static_cast(input_size[i]) / output_size[i]; + } + } +} + +// Nearest neighbor interpolation for 2D data (4D tensor with NCHW layout) +template +static void nearest_interpolate_2d(const T* input_data, T* output_data, const std::vector& input_shape, + const std::vector& output_shape, bool align_corners) { + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_height = input_shape[2]; + const int input_width = input_shape[3]; + const int output_height = output_shape[2]; + const int output_width = output_shape[3]; + + std::vector scale_factors; + compute_scale_factors({input_height, input_width}, {output_height, output_width}, scale_factors, align_corners); + + const float height_scale = scale_factors[0]; + const float width_scale = scale_factors[1]; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int oh = 0; oh < output_height; ++oh) { + for (int ow = 0; ow < output_width; ++ow) { + // Compute source indices + float ih_f = align_corners ? oh * height_scale : (oh + 0.5f) * height_scale - 0.5f; + float iw_f = align_corners ? ow * width_scale : (ow + 0.5f) * width_scale - 0.5f; + + // Round to nearest neighbor + int ih = std::min(static_cast(std::round(ih_f)), input_height - 1); + int iw = std::min(static_cast(std::round(iw_f)), input_width - 1); + + // Handle edge cases + ih = std::max(0, ih); + iw = std::max(0, iw); + + // Copy value + const int input_idx = ((n * channels + c) * input_height + ih) * input_width + iw; + const int output_idx = ((n * channels + c) * output_height + oh) * output_width + ow; + output_data[output_idx] = input_data[input_idx]; + } + } + } + } +} + +// Linear interpolation for 1D data (3D tensor with NCL layout) +template +static void linear_interpolate_1d(const T* input_data, T* output_data, const std::vector& input_shape, + const std::vector& output_shape, bool align_corners) { + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_length = input_shape[2]; + const int output_length = output_shape[2]; + + std::vector scale_factors; + compute_scale_factors({input_length}, {output_length}, scale_factors, align_corners); + + const float length_scale = scale_factors[0]; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int ol = 0; ol < output_length; ++ol) { + // Compute source position + float il_f = align_corners ? ol * length_scale : (ol + 0.5f) * length_scale - 0.5f; + + // Get the neighboring indices + int il_low = static_cast(std::floor(il_f)); + int il_high = il_low + 1; + + // Compute weights + float w_high = il_f - il_low; + float w_low = 1.0f - w_high; + + // Handle boundary conditions + il_low = std::max(0, std::min(il_low, input_length - 1)); + il_high = std::max(0, std::min(il_high, input_length - 1)); + + // Compute indices + const int input_idx_low = (n * channels + c) * input_length + il_low; + const int input_idx_high = (n * channels + c) * input_length + il_high; + const int output_idx = (n * channels + c) * output_length + ol; + + // Linear interpolation + output_data[output_idx] = static_cast(w_low * input_data[input_idx_low] + w_high * input_data[input_idx_high]); + } + } + } +} + +// Bilinear interpolation for 2D data (4D tensor with NCHW layout) +template +static void bilinear_interpolate_2d(const T* input_data, T* output_data, const std::vector& input_shape, + const std::vector& output_shape, bool align_corners) { + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_height = input_shape[2]; + const int input_width = input_shape[3]; + const int output_height = output_shape[2]; + const int output_width = output_shape[3]; + + std::vector scale_factors; + compute_scale_factors({input_height, input_width}, {output_height, output_width}, scale_factors, align_corners); + + const float height_scale = scale_factors[0]; + const float width_scale = scale_factors[1]; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int oh = 0; oh < output_height; ++oh) { + for (int ow = 0; ow < output_width; ++ow) { + // Compute source position + float ih_f = align_corners ? oh * height_scale : (oh + 0.5f) * height_scale - 0.5f; + float iw_f = align_corners ? ow * width_scale : (ow + 0.5f) * width_scale - 0.5f; + + // Get the four neighboring pixels + int ih_low = static_cast(std::floor(ih_f)); + int iw_low = static_cast(std::floor(iw_f)); + int ih_high = ih_low + 1; + int iw_high = iw_low + 1; + + // Compute weights + float h_weight_high = ih_f - ih_low; + float w_weight_high = iw_f - iw_low; + float h_weight_low = 1.0f - h_weight_high; + float w_weight_low = 1.0f - w_weight_high; + + // Handle boundary conditions + ih_low = std::max(0, std::min(ih_low, input_height - 1)); + ih_high = std::max(0, std::min(ih_high, input_height - 1)); + iw_low = std::max(0, std::min(iw_low, input_width - 1)); + iw_high = std::max(0, std::min(iw_high, input_width - 1)); + + // Compute indices for the four corners + const int idx_top_left = ((n * channels + c) * input_height + ih_low) * input_width + iw_low; + const int idx_top_right = ((n * channels + c) * input_height + ih_low) * input_width + iw_high; + const int idx_bottom_left = ((n * channels + c) * input_height + ih_high) * input_width + iw_low; + const int idx_bottom_right = ((n * channels + c) * input_height + ih_high) * input_width + iw_high; + + // Compute output index + const int output_idx = ((n * channels + c) * output_height + oh) * output_width + ow; + + // Bilinear interpolation + output_data[output_idx] = static_cast(h_weight_low * w_weight_low * input_data[idx_top_left] + + h_weight_low * w_weight_high * input_data[idx_top_right] + + h_weight_high * w_weight_low * input_data[idx_bottom_left] + + h_weight_high * w_weight_high * input_data[idx_bottom_right]); + } + } + } + } +} + +// Bicubic interpolation helper function +static float cubic_interp1d(float x0, float x1, float x2, float x3, float t) { + float a = -0.5f * x0 + 1.5f * x1 - 1.5f * x2 + 0.5f * x3; + float b = x0 - 2.5f * x1 + 2.0f * x2 - 0.5f * x3; + float c = -0.5f * x0 + 0.5f * x2; + float d = x1; + + return ((a * t + b) * t + c) * t + d; +} + +// Bicubic interpolation for 2D data (4D tensor with NCHW layout) +template +static void bicubic_interpolate_2d(const T* input_data, T* output_data, const std::vector& input_shape, + const std::vector& output_shape, bool align_corners) { + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_height = input_shape[2]; + const int input_width = input_shape[3]; + const int output_height = output_shape[2]; + const int output_width = output_shape[3]; + + std::vector scale_factors; + compute_scale_factors({input_height, input_width}, {output_height, output_width}, scale_factors, align_corners); + + const float height_scale = scale_factors[0]; + const float width_scale = scale_factors[1]; + + auto get_value_bounded = [&](int n, int c, int h, int w) -> T { + h = std::max(0, std::min(h, input_height - 1)); + w = std::max(0, std::min(w, input_width - 1)); + return input_data[((n * channels + c) * input_height + h) * input_width + w]; + }; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int oh = 0; oh < output_height; ++oh) { + for (int ow = 0; ow < output_width; ++ow) { + // Compute source position + float ih_f = align_corners ? oh * height_scale : (oh + 0.5f) * height_scale - 0.5f; + float iw_f = align_corners ? ow * width_scale : (ow + 0.5f) * width_scale - 0.5f; + + // Get the integer part + int ih = static_cast(std::floor(ih_f)); + int iw = static_cast(std::floor(iw_f)); + + // Get fractional part + float h_frac = ih_f - ih; + float w_frac = iw_f - iw; + + // Compute output index + const int output_idx = ((n * channels + c) * output_height + oh) * output_width + ow; + + // Perform bicubic interpolation + float coeffs[4]; + + // Interpolate along each row + for (int i = 0; i < 4; ++i) { + float row_values[4]; + for (int j = 0; j < 4; ++j) { row_values[j] = static_cast(get_value_bounded(n, c, ih + i - 1, iw + j - 1)); } + coeffs[i] = cubic_interp1d(row_values[0], row_values[1], row_values[2], row_values[3], w_frac); + } + + // Interpolate along column + float result = cubic_interp1d(coeffs[0], coeffs[1], coeffs[2], coeffs[3], h_frac); + + // Clamp result to avoid overshoot/undershoot + output_data[output_idx] = static_cast(result); + } + } + } + } +} + +// Trilinear interpolation for 3D data (5D tensor with NCDHW layout) +template +static void trilinear_interpolate_3d(const T* input_data, T* output_data, const std::vector& input_shape, + const std::vector& output_shape, bool align_corners) { + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_depth = input_shape[2]; + const int input_height = input_shape[3]; + const int input_width = input_shape[4]; + const int output_depth = output_shape[2]; + const int output_height = output_shape[3]; + const int output_width = output_shape[4]; + + std::vector scale_factors; + compute_scale_factors({input_depth, input_height, input_width}, {output_depth, output_height, output_width}, scale_factors, + align_corners); + + const float depth_scale = scale_factors[0]; + const float height_scale = scale_factors[1]; + const float width_scale = scale_factors[2]; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int od = 0; od < output_depth; ++od) { + for (int oh = 0; oh < output_height; ++oh) { + for (int ow = 0; ow < output_width; ++ow) { + // Compute source position + float id_f = align_corners ? od * depth_scale : (od + 0.5f) * depth_scale - 0.5f; + float ih_f = align_corners ? oh * height_scale : (oh + 0.5f) * height_scale - 0.5f; + float iw_f = align_corners ? ow * width_scale : (ow + 0.5f) * width_scale - 0.5f; + + // Get the eight neighboring voxels + int id_low = static_cast(std::floor(id_f)); + int ih_low = static_cast(std::floor(ih_f)); + int iw_low = static_cast(std::floor(iw_f)); + int id_high = id_low + 1; + int ih_high = ih_low + 1; + int iw_high = iw_low + 1; + + // Compute weights + float d_weight_high = id_f - id_low; + float h_weight_high = ih_f - ih_low; + float w_weight_high = iw_f - iw_low; + float d_weight_low = 1.0f - d_weight_high; + float h_weight_low = 1.0f - h_weight_high; + float w_weight_low = 1.0f - w_weight_high; + + // Handle boundary conditions + id_low = std::max(0, std::min(id_low, input_depth - 1)); + id_high = std::max(0, std::min(id_high, input_depth - 1)); + ih_low = std::max(0, std::min(ih_low, input_height - 1)); + ih_high = std::max(0, std::min(ih_high, input_height - 1)); + iw_low = std::max(0, std::min(iw_low, input_width - 1)); + iw_high = std::max(0, std::min(iw_high, input_width - 1)); + + // Compute indices for the eight corners + const int idx_d0_h0_w0 = + (((n * channels + c) * input_depth + id_low) * input_height + ih_low) * input_width + iw_low; + const int idx_d0_h0_w1 = + (((n * channels + c) * input_depth + id_low) * input_height + ih_low) * input_width + iw_high; + const int idx_d0_h1_w0 = + (((n * channels + c) * input_depth + id_low) * input_height + ih_high) * input_width + iw_low; + const int idx_d0_h1_w1 = + (((n * channels + c) * input_depth + id_low) * input_height + ih_high) * input_width + iw_high; + const int idx_d1_h0_w0 = + (((n * channels + c) * input_depth + id_high) * input_height + ih_low) * input_width + iw_low; + const int idx_d1_h0_w1 = + (((n * channels + c) * input_depth + id_high) * input_height + ih_low) * input_width + iw_high; + const int idx_d1_h1_w0 = + (((n * channels + c) * input_depth + id_high) * input_height + ih_high) * input_width + iw_low; + const int idx_d1_h1_w1 = + (((n * channels + c) * input_depth + id_high) * input_height + ih_high) * input_width + iw_high; + + // Compute output index + const int output_idx = (((n * channels + c) * output_depth + od) * output_height + oh) * output_width + ow; + + // Trilinear interpolation + output_data[output_idx] = + static_cast(d_weight_low * h_weight_low * w_weight_low * input_data[idx_d0_h0_w0] + + d_weight_low * h_weight_low * w_weight_high * input_data[idx_d0_h0_w1] + + d_weight_low * h_weight_high * w_weight_low * input_data[idx_d0_h1_w0] + + d_weight_low * h_weight_high * w_weight_high * input_data[idx_d0_h1_w1] + + d_weight_high * h_weight_low * w_weight_low * input_data[idx_d1_h0_w0] + + d_weight_high * h_weight_low * w_weight_high * input_data[idx_d1_h0_w1] + + d_weight_high * h_weight_high * w_weight_low * input_data[idx_d1_h1_w0] + + d_weight_high * h_weight_high * w_weight_high * input_data[idx_d1_h1_w1]); + } + } + } + } + } +} + +void CPUInterpolateOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& X = inputs[0]; + auto& Y = outputs[0]; + + // Get shapes + const auto& input_shape = X.shape(); + const auto& output_shape = Y.shape(); + const int input_dim = static_cast(input_shape.size()); + + // Get options + const auto& mode = options().mode; + const bool align_corners = options().align_corners; + + // Allocate output tensor if not already allocated + if (Y.isNil() || Y.numel() == 0) { Y = Tensor::empty(output_shape, X.dtype(), X.device()).alloc(); } + + switch (X.dtype()) { + case kFloat32: { + const float* input_data = X.ptr(); + float* output_data = Y.ptr(); + + // Choose interpolation method based on mode and input dimensions + if (mode == aops::InterpolateOpMode::kNearest) { + if (input_dim == 3) { // NCL format + nearest_interpolate_2d(input_data, output_data, {input_shape[0], input_shape[1], 1, input_shape[2]}, + {output_shape[0], output_shape[1], 1, output_shape[2]}, align_corners); + } else if (input_dim == 4) { // NCHW format + nearest_interpolate_2d(input_data, output_data, input_shape, output_shape, align_corners); + } else if (input_dim == 5) { // NCDHW format + // For 3D data, we handle each depth slice separately using 2D nearest neighbor + const int batch_size = input_shape[0]; + const int channels = input_shape[1]; + const int input_depth = input_shape[2]; + const int output_depth = output_shape[2]; + + std::vector scale_factors; + compute_scale_factors({input_depth}, {output_depth}, scale_factors, align_corners); + const float depth_scale = scale_factors[0]; + + for (int od = 0; od < output_depth; ++od) { + // Compute source depth index + float id_f = align_corners ? od * depth_scale : (od + 0.5f) * depth_scale - 0.5f; + int id = std::min(static_cast(std::round(id_f)), input_depth - 1); + id = std::max(0, id); + + // Process each 2D slice + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + const float* input_slice = + input_data + (((n * channels + c) * input_depth + id) * input_shape[3] * input_shape[4]); + float* output_slice = + output_data + (((n * channels + c) * output_depth + od) * output_shape[3] * output_shape[4]); + + nearest_interpolate_2d(input_slice, output_slice, {1, 1, input_shape[3], input_shape[4]}, + {1, 1, output_shape[3], output_shape[4]}, align_corners); + } + } + } + } else { + NYI("CPUInterpolateOp::forward nearest mode not support input dim {}", input_dim); + } + } else if (mode == aops::InterpolateOpMode::kLinear) { + if (input_dim == 3) { // NCL format + linear_interpolate_1d(input_data, output_data, input_shape, output_shape, align_corners); + } else { + NYI("CPUInterpolateOp::forward linear mode only supports 3D input (NCL format)"); + } + } else if (mode == aops::InterpolateOpMode::kBilinear) { + if (input_dim == 4) { // NCHW format + bilinear_interpolate_2d(input_data, output_data, input_shape, output_shape, align_corners); + } else { + NYI("CPUInterpolateOp::forward bilinear mode only supports 4D input (NCHW format)"); + } + } else if (mode == aops::InterpolateOpMode::kBicubic) { + if (input_dim == 4) { // NCHW format + bicubic_interpolate_2d(input_data, output_data, input_shape, output_shape, align_corners); + } else { + NYI("CPUInterpolateOp::forward bicubic mode only supports 4D input (NCHW format)"); + } + } else if (mode == aops::InterpolateOpMode::kTrilinear) { + if (input_dim == 5) { // NCDHW format + trilinear_interpolate_3d(input_data, output_data, input_shape, output_shape, align_corners); + } else { + NYI("CPUInterpolateOp::forward trilinear mode only supports 5D input (NCDHW format)"); + } + } else { + NYI("CPUInterpolateOp::forward unknown interpolation mode"); + } + break; + } + default: NYI("CPUInterpolateOp::forward not support dtype {}", nameOfType(X.dtype())); break; + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/InterpolateOp.hpp b/mllm/backends/cpu/ops/InterpolateOp.hpp new file mode 100644 index 000000000..24940eb25 --- /dev/null +++ b/mllm/backends/cpu/ops/InterpolateOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/InterpolateOp.hpp" + +namespace mllm::cpu { + +class CPUInterpolateOp final : public aops::InterpolateOp { + public: + explicit CPUInterpolateOp(const aops::InterpolateOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUInterpolateOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::InterpolateOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/PadOp.cpp b/mllm/backends/cpu/ops/PadOp.cpp new file mode 100644 index 000000000..ba363c85b --- /dev/null +++ b/mllm/backends/cpu/ops/PadOp.cpp @@ -0,0 +1,139 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include "mllm/backends/cpu/ops/PadOp.hpp" + +namespace mllm::cpu { + +CPUPadOp::CPUPadOp(const aops::PadOpOptions& options) : aops::PadOp(options) {} + +// Helper: compute output shape from input shape and pad vector (starting from last dimension) +static std::vector compute_padded_shape(const std::vector& in_shape, const std::vector& pad) { + const int D = static_cast(in_shape.size()); + std::vector out_shape = in_shape; + const int pairs = static_cast(pad.size()) / 2; + for (int i = 0; i < D; ++i) { + int base = 2 * (D - 1 - i); + int32_t before = (base < pad.size()) ? pad[base] : 0; + int32_t after = (base + 1 < pad.size()) ? pad[base + 1] : 0; + out_shape[i] = in_shape[i] + before + after; + } + return out_shape; +} + +// Helper: reflect index to [0, size-1] without repeating edge (PyTorch-like reflect) +static inline int32_t reflect_index(int32_t x, int32_t size) { + if (size <= 1) return 0; + int32_t m = size - 1; + // Map x to the range [-(size-1), size-1] + int32_t p = std::abs(x); + int32_t period = 2 * m; + int32_t r = p % period; + if (r >= size) { r = period - r; } + return r; +} + +// Helper: replicate index (clamp) +static inline int32_t replicate_index(int32_t x, int32_t size) { + if (size <= 1) return 0; + return std::max(0, std::min(x, size - 1)); +} + +// Helper: circular index (wrap) +static inline int32_t circular_index(int32_t x, int32_t size) { + if (size <= 1) return 0; + int32_t r = x % size; + if (r < 0) r += size; + return r; +} + +void CPUPadOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& X = inputs[0]; + auto& Y = outputs[0]; + + const auto& in_shape = X.shape(); + const auto& opts = options(); + const auto& pad = opts.pad; // [last_dim_left, last_dim_right, ..., first_dim_left, first_dim_right] + + // Compute output shape and allocate Y if needed + std::vector out_shape = Y.isNil() ? compute_padded_shape(in_shape, pad) : Y.shape(); + if (Y.isNil() || Y.numel() == 0) { Y = Tensor::empty(out_shape, X.dtype(), X.device()).alloc(); } + + // Precompute pad_before/after per dimension in order + const int D = static_cast(in_shape.size()); + std::vector pad_before(D, 0), pad_after(D, 0); + for (int i = 0; i < D; ++i) { + int base = 2 * (D - 1 - i); + if (base < pad.size()) pad_before[i] = pad[base]; + if (base + 1 < pad.size()) pad_after[i] = pad[base + 1]; + } + + // Only implement float32 for now + switch (X.dtype()) { + case kFloat32: { + const float* in = X.ptr(); + float* out = Y.ptr(); + + // Compute input/output strides for index mapping + std::vector in_stride(D, 1), out_stride(D, 1); + for (int i = D - 2; i >= 0; --i) { in_stride[i] = in_stride[i + 1] * in_shape[i + 1]; } + const auto& actual_out_shape = Y.shape(); + for (int i = D - 2; i >= 0; --i) { out_stride[i] = out_stride[i + 1] * actual_out_shape[i + 1]; } + + const int64_t out_numel = Y.numel(); + const auto mode = opts.mode; + const float constant_val = opts.value; + + // Iterate over all output elements + for (int64_t idx = 0; idx < out_numel; ++idx) { + // Decode idx into coordinates + int64_t t = idx; + std::vector oc(D, 0); + for (int i = 0; i < D; ++i) { + int64_t s = out_stride[i]; + oc[i] = (i == D - 1) ? static_cast(t) : static_cast(t / s); + if (i != D - 1) t %= s; + } + + // Map to input coordinates + bool oob = false; + std::vector ic(D, 0); + for (int i = 0; i < D; ++i) { + int32_t xi = oc[i] - pad_before[i]; + int32_t s = in_shape[i]; + switch (mode) { + case aops::PadMode::kConstant: + if (xi < 0 || xi >= s) { + oob = true; + ic[i] = 0; + } else { + ic[i] = xi; + } + break; + case aops::PadMode::kReflect: ic[i] = reflect_index(xi, s); break; + case aops::PadMode::kReplicate: ic[i] = replicate_index(xi, s); break; + case aops::PadMode::kCircular: ic[i] = circular_index(xi, s); break; + default: ic[i] = replicate_index(xi, s); break; + } + } + + if (mode == aops::PadMode::kConstant && oob) { + out[idx] = constant_val; + } else { + // Compute input linear index and copy value + int64_t in_idx = 0; + for (int i = 0; i < D; ++i) { in_idx += static_cast(ic[i]) * in_stride[i]; } + out[idx] = in[in_idx]; + } + } + + break; + } + default: NYI("CPUPadOp::forward not support dtype {}", nameOfType(X.dtype())); break; + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/PadOp.hpp b/mllm/backends/cpu/ops/PadOp.hpp new file mode 100644 index 000000000..637ad54ec --- /dev/null +++ b/mllm/backends/cpu/ops/PadOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/PadOp.hpp" + +namespace mllm::cpu { + +class CPUPadOp final : public aops::PadOp { + public: + explicit CPUPadOp(const aops::PadOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUPadOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::PadOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 10457d0a2..0706c7554 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-22 14:54:07 +// Auto generated: 2025-10-23 14:44:05 // do not modify this file #pragma once @@ -72,6 +72,9 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_CosOp, RK_Op_LinalgIROp_PagedAttnOp, RK_Op_LinalgIROp_LayerNorm2DOp, + RK_Op_LinalgIROp_PadOp, + RK_Op_LinalgIROp_InterpolateOp, + RK_Op_LinalgIROp_EinsumOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, RK_Op_GraphIROp_SubGraphOp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 61e49abfe..e22b28c04 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-22 14:54:07 +// Auto generated: 2025-10-23 14:44:05 // do not modify this file #pragma once namespace mllm::ir { @@ -186,6 +186,15 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_LAYERNORM2DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_LayerNorm2DOp && (v)->getKind() <= RK_Op_LinalgIROp_LayerNorm2DOp +#define RTTI_RK_OP_LINALGIROP_PADOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_PadOp && (v)->getKind() <= RK_Op_LinalgIROp_PadOp + +#define RTTI_RK_OP_LINALGIROP_INTERPOLATEOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_InterpolateOp && (v)->getKind() <= RK_Op_LinalgIROp_InterpolateOp + +#define RTTI_RK_OP_LINALGIROP_EINSUMOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_EinsumOp && (v)->getKind() <= RK_Op_LinalgIROp_EinsumOp + #define RTTI_RK_OP_GRAPHIROP_IMPL(v) return (v)->getKind() >= RK_Op_GraphIROp && (v)->getKind() <= RK_Op_GraphIROp_Last #define RTTI_RK_OP_GRAPHIROP_SUBGRAPHOP_IMPL(v) \ diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index 4653a6763..b2e412b72 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -102,5 +102,7 @@ LINALG_AOPS_DECL(OpTypes::kClip, ClipOp); LINALG_AOPS_DECL(OpTypes::kPagedAttn, PagedAttnOp); LINALG_AOPS_DECL(OpTypes::kLayerNorm2D, LayerNorm2DOp); +LINALG_AOPS_DECL(OpTypes::kPad, PadOp); +LINALG_AOPS_DECL(OpTypes::kInterpolate, InterpolateOp); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index 2fec76940..ecb5532b5 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -65,6 +65,8 @@ class SinOp; class CosOp; class PagedAttnOp; class LayerNorm2DOp; +class PadOp; +class InterpolateOp; } // namespace mllm #define LINALG_AOPS_DEFINE(class_name, rtti_name) \ @@ -215,5 +217,7 @@ LINALG_AOPS_DEFINE(ClipOp, CLIPOP); LINALG_AOPS_DEFINE(PagedAttnOp, PAGEDATTNOP); LINALG_AOPS_DEFINE(LayerNorm2DOp, LAYERNORM2DOP); +LINALG_AOPS_DEFINE(PadOp, PADOP); +LINALG_AOPS_DEFINE(InterpolateOp, INTERPOLATEOP); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index f5d0e670d..5a8f43539 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -273,6 +273,9 @@ def define_lianlg_ir(ir: dict): op.derive(Cls("CosOp")) op.derive(Cls("PagedAttnOp")) op.derive(Cls("LayerNorm2DOp")) + op.derive(Cls("PadOp")) + op.derive(Cls("InterpolateOp")) + op.derive(Cls("EinsumOp")) # value diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 26bd18eda..fc1b11e80 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -77,6 +77,10 @@ enum class OpTypes : int32_t { kScatter2Shards = 59, kLayerNorm2D = 60, + // Padding Op + kPad = 61, + kInterpolate = 62, + // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, diff --git a/mllm/core/aops/EinsumOp.cpp b/mllm/core/aops/EinsumOp.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/core/aops/EinsumOp.hpp b/mllm/core/aops/EinsumOp.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/core/aops/InterpolateOp.cpp b/mllm/core/aops/InterpolateOp.cpp new file mode 100644 index 000000000..2e2e51790 --- /dev/null +++ b/mllm/core/aops/InterpolateOp.cpp @@ -0,0 +1,81 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/InterpolateOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +InterpolateOp::InterpolateOp(const InterpolateOpOptions& options) : BaseOp(OpTypes::kSiLU), options_(options) {} + +void InterpolateOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void InterpolateOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void InterpolateOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("InterpolateOp::forward not implemented in aops base."); +} + +void InterpolateOp::reshape(const std::vector& inputs, std::vector& outputs) { + // Get the input tensor + const auto& input = inputs[0]; + + // Skip if input is nil + if (input.isNil()) { return; } + + // Get input shape + auto input_shape = input.shape(); + const int input_dim = static_cast(input_shape.size()); + + // Calculate output shape based on options + std::vector output_shape = input_shape; + + // If size is specified, use it to determine output dimensions + if (!options_.size.empty()) { + // Ensure size vector has correct dimensions + MLLM_RT_ASSERT(options_.size.size() <= input_dim); + + // Apply size to the last N dimensions where N is the size of options_.size + const int offset = input_dim - static_cast(options_.size.size()); + for (size_t i = 0; i < options_.size.size(); ++i) { output_shape[offset + i] = options_.size[i]; } + } + // If scale_factor is specified, use it to scale dimensions + else if (!options_.scale_factor.empty()) { + // Ensure scale_factor vector has correct dimensions + MLLM_RT_ASSERT(options_.scale_factor.size() <= input_dim); + + // Apply scale factor to the last N dimensions where N is the size of options_.scale_factor + const int offset = input_dim - static_cast(options_.scale_factor.size()); + for (size_t i = 0; i < options_.scale_factor.size(); ++i) { + output_shape[offset + i] = static_cast(input_shape[offset + i] * options_.scale_factor[i]); + } + } + + // If keep_aspect_ratio is true, adjust dimensions to maintain aspect ratio + if (options_.keep_aspect_ratio && !options_.size.empty() && options_.size.size() >= 2) { + // This is typically used for image resizing where we want to maintain aspect ratio + // We'll implement a simple version that scales based on the smaller dimension + const int offset = input_dim - static_cast(options_.size.size()); + float h_scale = static_cast(options_.size[0]) / input_shape[offset]; + float w_scale = static_cast(options_.size[1]) / input_shape[offset + 1]; + + float scale = std::min(h_scale, w_scale); + output_shape[offset] = static_cast(input_shape[offset] * scale); + output_shape[offset + 1] = static_cast(input_shape[offset + 1] * scale); + } + + // Create output tensor with the calculated shape + outputs.emplace_back(Tensor::empty(output_shape, input.dtype(), input.device())); +} + +void InterpolateOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/InterpolateOp.hpp b/mllm/core/aops/InterpolateOp.hpp new file mode 100644 index 000000000..7d8bbc3f7 --- /dev/null +++ b/mllm/core/aops/InterpolateOp.hpp @@ -0,0 +1,47 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +enum class InterpolateOpMode { + kNearest, + kLinear, + kBilinear, + kBicubic, + kTrilinear, +}; + +struct InterpolateOpOptions : public BaseOpOptions { + std::vector size; + std::vector scale_factor; + InterpolateOpMode mode = InterpolateOpMode::kNearest; + bool align_corners = false; + bool keep_aspect_ratio = false; +}; + +class InterpolateOp : public BaseOp { + public: + explicit InterpolateOp(const InterpolateOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline InterpolateOpOptions& options() { return options_; } + + protected: + InterpolateOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/PadOp.cpp b/mllm/core/aops/PadOp.cpp index e69de29bb..923c5b905 100644 --- a/mllm/core/aops/PadOp.cpp +++ b/mllm/core/aops/PadOp.cpp @@ -0,0 +1,28 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/PadOp.hpp" + +namespace mllm::aops { + +PadOp::PadOp(const PadOpOptions& options) : BaseOp(OpTypes::kPad), options_(options) {} + +void PadOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void PadOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + // TODO +} + +void PadOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("PadOp::forward is not implemented in the base class"); +} + +void PadOp::reshape(const std::vector& inputs, std::vector& outputs) { + // TODO +} + +void PadOp::setup(const std::vector& inputs, std::vector& outputs) { + // TODO +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/PadOp.hpp b/mllm/core/aops/PadOp.hpp index e69de29bb..e4c5c555b 100644 --- a/mllm/core/aops/PadOp.hpp +++ b/mllm/core/aops/PadOp.hpp @@ -0,0 +1,44 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +enum class PadMode : uint8_t { + kConstant = 0, + kReflect = 1, + kReplicate = 2, + kCircular = 3, +}; + +struct PadOpOptions : public BaseOpOptions { + std::vector pad; // padding sizes, starting from the last dimension + PadMode mode{PadMode::kConstant}; // padding mode + float value{0.0f}; // padding value for constant mode +}; + +class PadOp : public BaseOp { + public: + explicit PadOp(const PadOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline const PadOpOptions& options() const { return options_; } + + protected: + PadOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index 2f447ddd0..071d48abf 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -14,6 +14,8 @@ #include "mllm/core/aops/ViewOp.hpp" #include "mllm/core/aops/TopKOp.hpp" #include "mllm/core/aops/SiLUOp.hpp" +#include "mllm/core/aops/PadOp.hpp" +#include "mllm/core/aops/InterpolateOp.hpp" #include "mllm/engine/Context.hpp" namespace mllm::nn::functional { @@ -127,4 +129,27 @@ Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& return matmul(attn_weight, V); } +Tensor pad(const Tensor& x, const std::vector& pad, aops::PadMode mode, float value) { + return Context::instance().buildOpAndSubmitTask(OpTypes::kPad, aops::PadOpOptions{.pad = pad, .mode = mode, .value = value}, + {x})[0]; +} + +Tensor interpolate(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode, bool align_corners, + bool keep_aspect_ratio) { + aops::InterpolateOpOptions opts{}; + opts.size.assign(size.begin(), size.end()); + opts.mode = mode; + opts.align_corners = align_corners; + opts.keep_aspect_ratio = keep_aspect_ratio; + return Context::instance().buildOpAndSubmitTask(OpTypes::kInterpolate, opts, {x})[0]; +} + +Tensor interpolate(const Tensor& x, const std::vector& scale_factor, aops::InterpolateOpMode mode, bool align_corners) { + aops::InterpolateOpOptions opts{}; + opts.scale_factor = scale_factor; + opts.mode = mode; + opts.align_corners = align_corners; + return Context::instance().buildOpAndSubmitTask(OpTypes::kInterpolate, opts, {x})[0]; +} + } // namespace mllm::nn::functional diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index b61c564b9..30c127667 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -9,6 +9,8 @@ #include "mllm/core/Tensor.hpp" #include "mllm/core/aops/MatMulOp.hpp" #include "mllm/core/aops/SplitOp.hpp" +#include "mllm/core/aops/PadOp.hpp" +#include "mllm/core/aops/InterpolateOp.hpp" #include "mllm/engine/Context.hpp" namespace mllm::nn::functional { @@ -132,4 +134,16 @@ void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim // If you want causal mask attention. Use Flash attention instead. Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& V, const Tensor& mask = Tensor()); +// Pad: apply N-D padding. pad is ordered from the last to first dimension. +Tensor pad(const Tensor& x, const std::vector& pad, aops::PadMode mode = aops::PadMode::kConstant, float value = 0.0f); + +// Interpolate by target size +Tensor interpolate(const Tensor& x, const std::vector& size, + aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, + bool keep_aspect_ratio = false); + +// Interpolate by scale factor +Tensor interpolate(const Tensor& x, const std::vector& scale_factor, + aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false); + } // namespace mllm::nn::functional diff --git a/mllm/nn/Module.hpp b/mllm/nn/Module.hpp index 5de909e37..061bedc9b 100644 --- a/mllm/nn/Module.hpp +++ b/mllm/nn/Module.hpp @@ -45,7 +45,7 @@ template class ModuleList; template -class ModuleListSuffix; +class ModuleListSuffixed; class Module { public: diff --git a/tests/cpu/Conv2DKernelTest.hpp b/tests/cpu/Conv2DKernelTest.hpp index 67c6d89a2..f52ca045e 100644 --- a/tests/cpu/Conv2DKernelTest.hpp +++ b/tests/cpu/Conv2DKernelTest.hpp @@ -112,7 +112,21 @@ class Conv2DKernelTest : public KernelTest { bool testConv2D(const std::vector>& cfgs) { for (auto& cfg : cfgs) { - if (!testConv2DOnce(cfg)) { return false; } + if (!testConv2DOnce(cfg)) { + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto I_H = cfg.at("I_H"); + auto I_W = cfg.at("I_W"); + auto K_H = cfg.at("K_H"); + auto K_W = cfg.at("K_W"); + auto S_H = cfg.at("S_H"); + auto S_W = cfg.at("S_W"); + auto P_H = cfg.at("P_H"); + auto P_W = cfg.at("P_W"); + auto bias = cfg.at("bias"); + print(in_channel, out_channel, I_H, I_W, K_H, K_W, S_H, S_W, P_H, P_W, bias); + return false; + } } return true; } diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index d06bc68fa..fb595c79d 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -871,19 +871,92 @@ TEST_F(FlashAttn2KernelTest, fwd_bshd) { //===----------------------------------------------------------------------===// #include "Conv2DKernelTest.hpp" TEST_F(Conv2DKernelTest, im2col) { - EXPECT_EQ(testConv2D({{ - {"in_channel", 3}, - {"out_channel", 1024}, - {"I_H", 224}, - {"I_W", 224}, - {"K_H", 14}, - {"K_W", 14}, - {"S_H", 14}, - {"S_W", 14}, - {"P_H", 0}, - {"P_W", 0}, - {"bias", 0}, - }}), + EXPECT_EQ(testConv2D({ + // CLIP patch embedding + { + {"in_channel", 3}, + {"out_channel", 1024}, + {"I_H", 224}, + {"I_W", 224}, + {"K_H", 14}, + {"K_W", 14}, + {"S_H", 14}, + {"S_W", 14}, + {"P_H", 0}, + {"P_W", 0}, + {"bias", 0}, + }, + // SAM PatchEmbed.proj + { + {"in_channel", 3}, + {"out_channel", 768}, + {"I_H", 1024}, + {"I_W", 1024}, + {"K_H", 16}, + {"K_W", 16}, + {"S_H", 16}, + {"S_W", 16}, + {"P_H", 0}, + {"P_W", 0}, + {"bias", 1}, + }, + // neck: Conv2D(768 -> 12, 1x1, stride=1, pad=0, bias=false) + { + {"in_channel", 768}, + {"out_channel", 12}, + {"I_H", 64}, + {"I_W", 64}, + {"K_H", 1}, + {"K_W", 1}, + {"S_H", 1}, + {"S_W", 1}, + {"P_H", 0}, + {"P_W", 0}, + {"bias", 0}, + }, + // neck: Conv2D(256 -> 256, 3x3, stride=1, pad=1, bias=false) + { + {"in_channel", 256}, + {"out_channel", 256}, + {"I_H", 64}, + {"I_W", 64}, + {"K_H", 3}, + {"K_W", 3}, + {"S_H", 1}, + {"S_W", 1}, + {"P_H", 1}, + {"P_W", 1}, + {"bias", 0}, + }, + // net_2_: Conv2D(256 -> 512, 3x3, stride=2, pad=1, bias=false) + { + {"in_channel", 256}, + {"out_channel", 512}, + {"I_H", 64}, + {"I_W", 64}, + {"K_H", 3}, + {"K_W", 3}, + {"S_H", 2}, + {"S_W", 2}, + {"P_H", 1}, + {"P_W", 1}, + {"bias", 0}, + }, + // net_3_: Conv2D(512 -> 1024, 3x3, stride=2, pad=1, bias=false) + { + {"in_channel", 512}, + {"out_channel", 1024}, + {"I_H", 32}, + {"I_W", 32}, + {"K_H", 3}, + {"K_W", 3}, + {"S_H", 2}, + {"S_W", 2}, + {"P_H", 1}, + {"P_W", 1}, + {"bias", 0}, + }, + }), true); } From bf82e6b85541079787cf17a0a1e29a49064009fc Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 16:23:43 +0800 Subject: [PATCH 04/25] feat(image): add dynamic image preprocessing and cropping support - Implement `dynamic_preprocess` to generate image tiles with aspect ratio handling - Add `Image::crop` with PIL-style out-of-bounds padding - Introduce `ImageTransform` system with composable transforms (Resize, CenterCrop, etc.) - Support common preprocessing pipelines via `BasicImageTransform` - Enable thumbnail generation and grid-based tiling for OCR models feat(tokenizer): remove unused header include - Remove unnecessary `ARGeneration.hpp` include in tokenization header chore(deepseek_ocr): add missing `` header - Include `` for `std::numeric_limits` usage in aspect ratio calculation --- .../deepseek_ocr/conversation_preprocess.hpp | 97 +++++++++ .../tokenization_deepseek_ocr.hpp | 1 - mllm/preprocessor/visual/Image.cpp | 44 ++++ mllm/preprocessor/visual/Image.hpp | 4 + mllm/preprocessor/visual/ImageTransform.cpp | 155 ++++++++++++++ mllm/preprocessor/visual/ImageTransform.hpp | 192 ++++++++++++++++++ 6 files changed, 492 insertions(+), 1 deletion(-) create mode 100644 mllm/preprocessor/visual/ImageTransform.cpp create mode 100644 mllm/preprocessor/visual/ImageTransform.hpp diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp index 34f52bc9b..9327115a6 100644 --- a/mllm/models/deepseek_ocr/conversation_preprocess.hpp +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "mllm/preprocessor/visual/Image.hpp" @@ -316,4 +317,100 @@ std::vector loadImages(const nlohmann::json& conversations) { return ret; } +std::pair findClosestAspectRatio(double aspect_ratio, const std::vector>& target_ratios, + int width, int height, int image_size) { + double best_ratio_diff = std::numeric_limits::infinity(); + std::pair best_ratio = {1, 1}; + const double area = static_cast(width) * static_cast(height); + + for (const auto& ratio : target_ratios) { + const double target_aspect_ratio = static_cast(ratio.first) / static_cast(ratio.second); + const double ratio_diff = std::abs(aspect_ratio - target_aspect_ratio); + + if (ratio_diff < best_ratio_diff) { + best_ratio_diff = ratio_diff; + best_ratio = ratio; + } else if (ratio_diff == best_ratio_diff) { + if (area > 0.5 * static_cast(image_size) * static_cast(image_size) * static_cast(ratio.first) + * static_cast(ratio.second)) { + best_ratio = ratio; + } + } + } + + return best_ratio; +} + +/** + * Dynamic preprocess that crops and resizes an image into tiles matching + * a target aspect ratio grid, similar to DeepSeek-OCR implementation. + * + * - Selects closest aspect ratio from predefined candidates. + * - Generates non-overlapping tiles along the longer dimension. + * - Pads out-of-bounds areas during crop and resizes tiles to target grid. + * + * @param image Input RGB image + * @param image_size Base size used to construct target grid (e.g., 448) + * @param max_num Max number of tiles to generate (default 6) + * @param use_thumbnail Whether to add a square thumbnail (image_size x image_size) as the first tile + * @return Vector of processed Image tiles + */ +inline std::vector dynamic_preprocess(const Image& image, int image_size, int max_num = 6, bool use_thumbnail = false) { + // Mirror Python logic: generate a grid of image_size x image_size crops. + // Grid shape is derived from ceil(width/size) and ceil(height/size), + // and then capped to max_num by reducing along the longer dimension. + + Image src = image; // non-const copy to call non-const methods + const int w = src.w(); + const int h = src.h(); + if (w <= 0 || h <= 0) { return {}; } + + // Integer ceil for positive numbers: ceil(x / y) == (x + y - 1) / y + int grid_w = (w + image_size - 1) / image_size; + int grid_h = (h + image_size - 1) / image_size; + if (grid_w < 1) grid_w = 1; + if (grid_h < 1) grid_h = 1; + + // Cap total tiles to max_num while preserving aspect tendency + while (grid_w * grid_h > max_num) { + if (grid_w >= grid_h && grid_w > 1) { + --grid_w; + } else if (grid_h > 1) { + --grid_h; + } else { + break; + } + } + + const int target_width = grid_w * image_size; + const int target_height = grid_h * image_size; + + std::vector out; + out.reserve(static_cast(max_num)); + + // Optional thumbnail first + if (use_thumbnail) { + out.push_back(src.resize(image_size, image_size)); + if (static_cast(out.size()) >= max_num) { return out; } + } + + const int total_tiles = grid_w * grid_h; + for (int i = 0; i < total_tiles; ++i) { + // Python equivalent: + // x = (i % (target_width // image_size)) * image_size + // y = (i // (target_width // image_size)) * image_size + const int cols = target_width / image_size; // == grid_w + const int x = (i % cols) * image_size; + const int y = (i / cols) * image_size; + + // PIL-style crop with zero padding beyond bounds + Image tile = src.crop(x, y, x + image_size, y + image_size); + out.push_back(tile); + + if (static_cast(out.size()) >= max_num) { break; } + } + + return out; +} + } // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index 76e332fcb..ca39f30df 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -11,7 +11,6 @@ #include #include -#include "mllm/models/ARGeneration.hpp" #include "mllm/preprocessor/tokenizers/BPE.hpp" #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" diff --git a/mllm/preprocessor/visual/Image.cpp b/mllm/preprocessor/visual/Image.cpp index 22429b4ce..5269352e0 100644 --- a/mllm/preprocessor/visual/Image.cpp +++ b/mllm/preprocessor/visual/Image.cpp @@ -101,4 +101,48 @@ int Image::h() { return h_; } int Image::c() { return c_; } +Image Image::crop(int left, int upper, int right, int lower) { + // Validate input dimensions and ensure source image is loaded + MLLM_RT_ASSERT(image_ptr_ != nullptr); + MLLM_RT_ASSERT(right > left && lower > upper); + + const int crop_w = right - left; + const int crop_h = lower - upper; + + Image new_img; + new_img.w_ = crop_w; + new_img.h_ = crop_h; + new_img.c_ = 3; // Force RGB, consistent with Image::open + + // Allocate output buffer; stbi_image_free uses free, so malloc is compatible + unsigned char* output = static_cast(malloc(static_cast(crop_w) * crop_h * new_img.c_)); + MLLM_RT_ASSERT(output != nullptr); + + const unsigned char* src = static_cast(image_ptr_->ptr_); + + // PIL-style crop: pad out-of-bounds with zeros + for (int y = 0; y < crop_h; ++y) { + const int sy = upper + y; + for (int x = 0; x < crop_w; ++x) { + const int sx = left + x; + unsigned char* dst_px = output + (static_cast(y) * crop_w + x) * new_img.c_; + if (sx >= 0 && sx < w_ && sy >= 0 && sy < h_) { + const unsigned char* src_px = src + (static_cast(sy) * w_ + sx) * c_; + dst_px[0] = src_px[0]; + dst_px[1] = src_px[1]; + dst_px[2] = src_px[2]; + } else { + dst_px[0] = 0; + dst_px[1] = 0; + dst_px[2] = 0; + } + } + } + + new_img.image_ptr_ = std::make_shared<_ImagePtr>(); + new_img.image_ptr_->ptr_ = output; + + return new_img; +} + } // namespace mllm \ No newline at end of file diff --git a/mllm/preprocessor/visual/Image.hpp b/mllm/preprocessor/visual/Image.hpp index c47546a68..f7bb76283 100644 --- a/mllm/preprocessor/visual/Image.hpp +++ b/mllm/preprocessor/visual/Image.hpp @@ -30,6 +30,10 @@ class Image { Image resize(int new_w, int new_h); + // Crop the image with PIL-style box (left, upper, right, lower). + // Out-of-bounds areas are padded with zeros. Returns a new Image. + Image crop(int left, int upper, int right, int lower); + void save(const std::string& fp); Tensor tensor(); diff --git a/mllm/preprocessor/visual/ImageTransform.cpp b/mllm/preprocessor/visual/ImageTransform.cpp new file mode 100644 index 000000000..d3aefc710 --- /dev/null +++ b/mllm/preprocessor/visual/ImageTransform.cpp @@ -0,0 +1,155 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include + +#include "mllm/preprocessor/visual/ImageTransform.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm { + +// ========================= ComposeImageTransforms ========================= +Image ComposeImageTransforms::operator()(const Image& img) const { + Image current = img; // make a mutable copy for non-const member calls + for (const auto& t : transforms_) { current = t->apply(current); } + return current; +} + +// ========================= ComposeTensorTransforms ========================= +Tensor ComposeTensorTransforms::operator()(const Tensor& t) const { + Tensor current = t; // make a mutable copy for non-const member calls + for (const auto& tr : transforms_) { current = tr->apply(current); } + return current; +} + +// ========================= Resize ========================= +Resize::Resize(int size_shorter) : size_shorter_(size_shorter) {} +Resize::Resize(int target_h, int target_w) : size_hw_(std::make_pair(target_h, target_w)) {} + +Image Resize::apply(const Image& input) const { + Image src = input; // mutable copy + if (size_hw_.has_value()) { + // Explicit resize to (H, W) + const int new_h = size_hw_->first; + const int new_w = size_hw_->second; + return src.resize(new_w, new_h); // Image::resize expects (w, h) + } + + // Shorter-side resize with aspect ratio preserved + MLLM_RT_ASSERT(size_shorter_.has_value()); + const int target_shorter = *size_shorter_; + const int h = src.h(); + const int w = src.w(); + + const int shorter = std::min(h, w); + const double scale = static_cast(target_shorter) / static_cast(shorter); + const int new_h = static_cast(std::round(h * scale)); + const int new_w = static_cast(std::round(w * scale)); + return src.resize(new_w, new_h); +} + +// ========================= CenterCrop ========================= +CenterCrop::CenterCrop(int crop_size) : crop_h_(crop_size), crop_w_(crop_size) {} +CenterCrop::CenterCrop(int crop_h, int crop_w) : crop_h_(crop_h), crop_w_(crop_w) {} + +Image CenterCrop::apply(const Image& input) const { + Image src = input; // mutable copy + const int h = src.h(); + const int w = src.w(); + + // Compute centered crop box, PIL-style + const int left = (w - crop_w_) / 2; + const int upper = (h - crop_h_) / 2; + const int right = left + crop_w_; + const int lower = upper + crop_h_; + + return src.crop(left, upper, right, lower); +} + +// ========================= ToTensor ========================= +Tensor ToTensor::apply(const Image& input) const { + Image src = input; // mutable copy + // Image::tensor returns HWC float32 tensor with values in [0, 255] + Tensor t = src.tensor(); + + // Reorder to CHW to match torchvision semantics + t = t.permute({2, 0, 1}); + + // Scale to [0, 1] + t = t / 255.0f; + return t; +} + +// ========================= Normalize ========================= +Normalize::Normalize(const std::vector& mean, const std::vector& std) : mean_(mean), std_(std) { + MLLM_RT_ASSERT_EQ(mean_.size(), std_.size()); +} + +Tensor Normalize::apply(const Tensor& input) const { + Tensor src = input; // mutable copy + // Expect src in CHW layout + MLLM_RT_ASSERT(src.rank() == 3); + const int c = src.size(0); + const int h = src.size(1); + const int w = src.size(2); + MLLM_RT_ASSERT_EQ(static_cast(mean_.size()), c); + MLLM_RT_ASSERT_EQ(static_cast(std_.size()), c); + + // Work on a contiguous clone to simplify indexing + Tensor out = src.clone().contiguous(); + float* ptr = out.ptr(); + const size_t plane = static_cast(h) * static_cast(w); + + for (int ch = 0; ch < c; ++ch) { + const float m = mean_[ch]; + const float s = std_[ch]; + + float* base = ptr + static_cast(ch) * plane; + for (size_t i = 0; i < plane; ++i) { base[i] = (base[i] - m) / s; } + } + + return out; +} + +// ========================= BasicImageTransform ========================= +BasicImageTransform::BasicImageTransform(std::optional resize_shorter, std::optional> resize_hw, + std::optional> center_crop, + std::optional> norm_mean, + std::optional> norm_std) { + // Build image pipeline + if (resize_shorter.has_value()) { + image_pipeline_.add(std::make_shared(*resize_shorter)); + } else if (resize_hw.has_value()) { + image_pipeline_.add(std::make_shared(resize_hw->first, resize_hw->second)); + } + + if (center_crop.has_value()) { image_pipeline_.add(std::make_shared(center_crop->first, center_crop->second)); } + + // Build tensor pipeline (Normalize optional) + if (norm_mean.has_value() && norm_std.has_value()) { + tensor_pipeline_.add(std::make_shared(*norm_mean, *norm_std)); + } +} + +BasicImageTransform::BasicImageTransform(std::optional resize_shorter, std::optional center_crop_square, + std::optional> norm_mean, + std::optional> norm_std) { + if (resize_shorter.has_value()) { image_pipeline_.add(std::make_shared(*resize_shorter)); } + if (center_crop_square.has_value()) { image_pipeline_.add(std::make_shared(*center_crop_square)); } + if (norm_mean.has_value() && norm_std.has_value()) { + tensor_pipeline_.add(std::make_shared(*norm_mean, *norm_std)); + } +} + +Tensor BasicImageTransform::operator()(const Image& img) const { + // 1) Run image-level transforms + Image processed = image_pipeline_(img); + // 2) Convert to tensor (CHW, [0,1]) + Tensor t = to_tensor_.apply(processed); + // 3) Run tensor-level transforms + t = tensor_pipeline_(t); + return t; +} + +} // namespace mllm diff --git a/mllm/preprocessor/visual/ImageTransform.hpp b/mllm/preprocessor/visual/ImageTransform.hpp new file mode 100644 index 000000000..de8979ca7 --- /dev/null +++ b/mllm/preprocessor/visual/ImageTransform.hpp @@ -0,0 +1,192 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +/** + * ImageTransform.hpp + * + * This header provides a small, extensible transform system tailored for mllm::Image, + * modeled after the design philosophy of torchvision.transforms. It enables users to + * compose image processing operations and common post-processing steps used in CV/ML pipelines. + * + * Key Concepts (parallels to torchvision): + * - Transform Operators: Small classes that implement a single, well-defined transformation. + * For example, Resize, CenterCrop operate on Image; Normalize operates on Tensor. + * - Composition: Compose objects allow chaining multiple transforms in order. + * - Type Flow: Similar to torchvision, some image transforms keep data as image-like objects, + * then a "ToTensor" step converts the image to a tensor. Later steps like "Normalize" operate on Tensor. + * + * Design Notes: + * - Image-level transforms implement IImageTransform and return a new Image. + * - Tensor-level transforms implement ITensorTransform and return a new Tensor. + * - ToTensor is a bridging transform converting Image(H,W,C, uint8-like) to Tensor(C,H,W, float32) + * with values scaled to [0, 1], matching torchvision.transforms.ToTensor semantics. + * - Normalize implements channel-wise normalization: (x - mean) / std for each channel, + * where x is a float32 Tensor in CHW layout. + * - BasicImageTransform is a convenience pipeline assembling common steps: resize -> optional crop -> to_tensor -> optional + * normalize. + * + * Example (usage, mirrors torchvision style): + * // Build a preprocessing pipeline + * // mllm::BasicImageTransform tf( + * // std::optional(512), // resize shorter side to 512 + * // std::optional(448), // center-crop square 448x448 + * // std::vector{0.485f, 0.456f, 0.406f}, // mean + * // std::vector{0.229f, 0.224f, 0.225f} // std + * // ); + * + * // Apply to an image + * // mllm::Image img = mllm::Image::open("/path/to/img.jpg"); + * // mllm::Tensor input = tf(img); // CHW, float32, normalized + */ + +#pragma once + +#include +#include +#include +#include + +#include "mllm/preprocessor/visual/Image.hpp" +#include "mllm/core/Tensor.hpp" + +namespace mllm { + +// Interface for transforms that take an Image and return an Image. +class IImageTransform { + public: + virtual ~IImageTransform() = default; + [[nodiscard]] virtual Image apply(const Image& input) const = 0; +}; + +// Interface for transforms that take a Tensor and return a Tensor. +class ITensorTransform { + public: + virtual ~ITensorTransform() = default; + [[nodiscard]] virtual Tensor apply(const Tensor& input) const = 0; +}; + +// Compose multiple image transforms (executed in order). Returns the final Image. +class ComposeImageTransforms { + public: + ComposeImageTransforms() = default; + + explicit ComposeImageTransforms(const std::vector>& transforms) : transforms_(transforms) {} + + ComposeImageTransforms& add(const std::shared_ptr& t) { + transforms_.push_back(t); + return *this; + } + + [[nodiscard]] Image operator()(const Image& img) const; + + private: + std::vector> transforms_; +}; + +// Compose multiple tensor transforms (executed in order). Returns the final Tensor. +class ComposeTensorTransforms { + public: + ComposeTensorTransforms() = default; + + explicit ComposeTensorTransforms(const std::vector>& transforms) + : transforms_(transforms) {} + + ComposeTensorTransforms& add(const std::shared_ptr& t) { + transforms_.push_back(t); + return *this; + } + + [[nodiscard]] Tensor operator()(const Tensor& t) const; + + private: + std::vector> transforms_; +}; + +// Resize transform. +// TorchVision semantics: +// - If constructed with a single integer `size`, resize so that the shorter side == size, +// preserving aspect ratio. The longer side is scaled accordingly. +// - If constructed with (height, width), resize to exactly that spatial size. +class Resize : public IImageTransform { + public: + // Preserve aspect ratio: shorter side == size. + explicit Resize(int size_shorter); + // Explicit target (height, width). + Resize(int target_h, int target_w); + + [[nodiscard]] Image apply(const Image& input) const override; + + private: + std::optional size_shorter_; + std::optional> size_hw_; +}; + +// CenterCrop transform. +// TorchVision semantics: +// - Crop a Region of size (crop_h, crop_w) at the image center. +// - If the crop extends beyond boundaries, out-of-bounds is zero-padded (PIL-style); our Image::crop already supports this. +class CenterCrop : public IImageTransform { + public: + explicit CenterCrop(int crop_size); + CenterCrop(int crop_h, int crop_w); + + [[nodiscard]] Image apply(const Image& input) const override; + + private: + int crop_h_; + int crop_w_; +}; + +// ToTensor bridging transform: Image -> Tensor. +// TorchVision semantics: +// - Convert PIL-like image to a float32 tensor with shape (C, H, W). +// - Scale values from [0, 255] to [0, 1]. +class ToTensor { + public: + [[nodiscard]] Tensor apply(const Image& input) const; +}; + +// Normalize transform: Tensor -> Tensor. +// TorchVision semantics: +// - Input tensor expected to be float32 in CHW layout. +// - For each channel c: out[c, :, :] = (in[c, :, :] - mean[c]) / std[c]. +class Normalize : public ITensorTransform { + public: + Normalize(const std::vector& mean, const std::vector& std); + + [[nodiscard]] Tensor apply(const Tensor& input) const override; + + private: + std::vector mean_; + std::vector std_; +}; + +// Convenience pipeline resembling common torchvision usage: +// transforms = Compose([Resize, (optional) CenterCrop, ToTensor, (optional) Normalize]) +// - Users provide parameters commonly used in OCR/vision preprocessing. +// - Returns a Tensor ready for model input. +class BasicImageTransform { + public: + // Build a pipeline: + // - If `resize_shorter` is set, resize by shorter side; otherwise if `resize_hw` is set, resize to (h, w). + // - If `center_crop` is set, apply center crop. + // - Always apply ToTensor. + // - If `norm_mean` and `norm_std` are provided, apply Normalize. + BasicImageTransform(std::optional resize_shorter, std::optional> resize_hw, + std::optional> center_crop, std::optional> norm_mean, + std::optional> norm_std); + + // Convenience ctor: shorter-side resize, optional square crop, optional normalize. + BasicImageTransform(std::optional resize_shorter, std::optional center_crop_square, + std::optional> norm_mean, std::optional> norm_std); + + // Apply the pipeline to an input image and return the final tensor. + [[nodiscard]] Tensor operator()(const Image& img) const; + + private: + ComposeImageTransforms image_pipeline_; + ToTensor to_tensor_; + ComposeTensorTransforms tensor_pipeline_; +}; + +} // namespace mllm From 6770f875da54e7503c28981455b77ab383c3f8c0 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 17:09:40 +0800 Subject: [PATCH 05/25] feat(interpolate): add antialias support and remove keep_aspect_ratio - Replace `keep_aspect_ratio` option with `antialias` in interpolate operations - Implement Gaussian blur utility for antialiasing in bilinear and bicubic modes - Update documentation to reflect the new `antialias` parameter - Refactor internal interpolate functions to use new API with antialiasing - Fix incorrect OpType in InterpolateOp constructor - Simplify position embedding interpolation using new functional API - Remove custom padding logic in favor of `nn::functional::pad` --- docs/api/functional.rst | 5 +- mllm/backends/cpu/ops/InterpolateOp.cpp | 84 ++++++++++- mllm/core/aops/InterpolateOp.cpp | 15 +- mllm/core/aops/InterpolateOp.hpp | 2 +- mllm/models/deepseek_ocr/deepencoder.hpp | 176 +++-------------------- mllm/nn/Functional.cpp | 10 +- mllm/nn/Functional.hpp | 11 +- 7 files changed, 117 insertions(+), 186 deletions(-) diff --git a/docs/api/functional.rst b/docs/api/functional.rst index 03f8bd2e8..080d20914 100644 --- a/docs/api/functional.rst +++ b/docs/api/functional.rst @@ -88,7 +88,7 @@ Shape Operations :param value: Constant value used when mode is kConstant. Default: 0.0 :return: Padded tensor -.. cpp:function:: Tensor mllm::nn::functional::interpolate(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, bool keep_aspect_ratio = false) +.. cpp:function:: Tensor mllm::nn::functional::interpolate(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, bool antialias = false) Resize a tensor to the target spatial size. @@ -96,8 +96,7 @@ Shape Operations :param size: Target spatial size (e.g., [H_out, W_out] for 2D) :param mode: Interpolation mode (kNearest, kLinear, kBilinear, kBicubic, kTrilinear). Default: kNearest :param align_corners: Align corners for linear/bilinear/trilinear interpolation. Default: false - :param keep_aspect_ratio: Keep aspect ratio when size is provided (handled by AOP). Default: false - :return: Resized tensor + :return: Resized tensor .. cpp:function:: Tensor mllm::nn::functional::interpolate(const Tensor& x, const std::vector& scale_factor, aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false) diff --git a/mllm/backends/cpu/ops/InterpolateOp.cpp b/mllm/backends/cpu/ops/InterpolateOp.cpp index 0450c021b..66fdc3d7a 100644 --- a/mllm/backends/cpu/ops/InterpolateOp.cpp +++ b/mllm/backends/cpu/ops/InterpolateOp.cpp @@ -353,6 +353,7 @@ void CPUInterpolateOp::forward(const std::vector& inputs, std::vector& inputs, std::vector(); float* output_data = Y.ptr(); + // Gaussian blur utility (separable) for NCHW + auto gaussian_blur_nchw = [&](const float* src, int N, int C, int H, int W, float sigma) { + int radius = std::max(1, static_cast(std::ceil(3.f * sigma))); + std::vector kernel(2 * radius + 1); + float sumw = 0.f; + for (int i = -radius; i <= radius; ++i) { + float w = std::exp(-(i * i) / (2.f * sigma * sigma)); + kernel[i + radius] = w; + sumw += w; + } + for (float& w : kernel) { w /= sumw; } + + size_t numel = static_cast(N) * C * H * W; + std::vector tmp(numel, 0.f); + std::vector dst(numel, 0.f); + + // Horizontal pass + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + float acc = 0.f; + for (int k = -radius; k <= radius; ++k) { + int x = std::min(std::max(w + k, 0), W - 1); + size_t idx = ((static_cast(n) * C + c) * H + h) * W + x; + acc += kernel[k + radius] * src[idx]; + } + size_t oidx = ((static_cast(n) * C + c) * H + h) * W + w; + tmp[oidx] = acc; + } + } + } + } + // Vertical pass + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + float acc = 0.f; + for (int k = -radius; k <= radius; ++k) { + int y = std::min(std::max(h + k, 0), H - 1); + size_t idx = ((static_cast(n) * C + c) * H + y) * W + w; + acc += kernel[k + radius] * tmp[idx]; + } + size_t oidx = ((static_cast(n) * C + c) * H + h) * W + w; + dst[oidx] = acc; + } + } + } + } + return dst; + }; + // Choose interpolation method based on mode and input dimensions if (mode == aops::InterpolateOpMode::kNearest) { if (input_dim == 3) { // NCL format @@ -410,13 +464,39 @@ void CPUInterpolateOp::forward(const std::vector& inputs, std::vector(input_data, output_data, input_shape, output_shape, align_corners); + // Antialias for downsampling + std::vector scale_factors; + compute_scale_factors({input_shape[2], input_shape[3]}, {output_shape[2], output_shape[3]}, scale_factors, + align_corners); + const float h_scale = scale_factors[0]; + const float w_scale = scale_factors[1]; + const float* src_ptr = input_data; + std::vector blurred; + if (antialias && (h_scale > 1.f || w_scale > 1.f)) { + float sigma = 0.5f * std::max(h_scale, w_scale); + blurred = gaussian_blur_nchw(input_data, input_shape[0], input_shape[1], input_shape[2], input_shape[3], sigma); + src_ptr = blurred.data(); + } + bilinear_interpolate_2d(src_ptr, output_data, input_shape, output_shape, align_corners); } else { NYI("CPUInterpolateOp::forward bilinear mode only supports 4D input (NCHW format)"); } } else if (mode == aops::InterpolateOpMode::kBicubic) { if (input_dim == 4) { // NCHW format - bicubic_interpolate_2d(input_data, output_data, input_shape, output_shape, align_corners); + // Antialias for downsampling + std::vector scale_factors; + compute_scale_factors({input_shape[2], input_shape[3]}, {output_shape[2], output_shape[3]}, scale_factors, + align_corners); + const float h_scale = scale_factors[0]; + const float w_scale = scale_factors[1]; + const float* src_ptr = input_data; + std::vector blurred; + if (antialias && (h_scale > 1.f || w_scale > 1.f)) { + float sigma = 0.5f * std::max(h_scale, w_scale); + blurred = gaussian_blur_nchw(input_data, input_shape[0], input_shape[1], input_shape[2], input_shape[3], sigma); + src_ptr = blurred.data(); + } + bicubic_interpolate_2d(src_ptr, output_data, input_shape, output_shape, align_corners); } else { NYI("CPUInterpolateOp::forward bicubic mode only supports 4D input (NCHW format)"); } diff --git a/mllm/core/aops/InterpolateOp.cpp b/mllm/core/aops/InterpolateOp.cpp index 2e2e51790..2c160ade3 100644 --- a/mllm/core/aops/InterpolateOp.cpp +++ b/mllm/core/aops/InterpolateOp.cpp @@ -9,7 +9,7 @@ namespace mllm::aops { -InterpolateOp::InterpolateOp(const InterpolateOpOptions& options) : BaseOp(OpTypes::kSiLU), options_(options) {} +InterpolateOp::InterpolateOp(const InterpolateOpOptions& options) : BaseOp(OpTypes::kInterpolate), options_(options) {} void InterpolateOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } @@ -59,19 +59,6 @@ void InterpolateOp::reshape(const std::vector& inputs, std::vector= 2) { - // This is typically used for image resizing where we want to maintain aspect ratio - // We'll implement a simple version that scales based on the smaller dimension - const int offset = input_dim - static_cast(options_.size.size()); - float h_scale = static_cast(options_.size[0]) / input_shape[offset]; - float w_scale = static_cast(options_.size[1]) / input_shape[offset + 1]; - - float scale = std::min(h_scale, w_scale); - output_shape[offset] = static_cast(input_shape[offset] * scale); - output_shape[offset + 1] = static_cast(input_shape[offset + 1] * scale); - } - // Create output tensor with the calculated shape outputs.emplace_back(Tensor::empty(output_shape, input.dtype(), input.device())); } diff --git a/mllm/core/aops/InterpolateOp.hpp b/mllm/core/aops/InterpolateOp.hpp index 7d8bbc3f7..1000573fe 100644 --- a/mllm/core/aops/InterpolateOp.hpp +++ b/mllm/core/aops/InterpolateOp.hpp @@ -21,7 +21,7 @@ struct InterpolateOpOptions : public BaseOpOptions { std::vector scale_factor; InterpolateOpMode mode = InterpolateOpMode::kNearest; bool align_corners = false; - bool keep_aspect_ratio = false; + bool antialias = false; }; class InterpolateOp : public BaseOp { diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 0ad56e8b7..89a11dd5e 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -16,6 +16,11 @@ namespace mllm::models::deepseek_ocr { +//===----------------------------------------------------------------------===// +// MLP Projector For Mapping Visual Tokens to Text Token Space +//===----------------------------------------------------------------------===// +// TODO + //===----------------------------------------------------------------------===// // CLIP // @@ -95,61 +100,8 @@ class CLIPVisionEmbeddings final : public nn::Module { if (src_size != tgt_size) { old_pos_embed = old_pos_embed.view({1, src_size, src_size, dim}).permute({0, 3, 1, 2}); old_pos_embed = old_pos_embed.to(kFloat32); - - auto new_pos_embed = Tensor::empty({tgt_size, tgt_size}, kFloat32, kCPU).alloc(); - - // F.interpolate here. - { - const int channels = old_pos_embed.shape()[1]; - - auto old_pos_embed_ptr = old_pos_embed.ptr(); - auto new_pos_embed_ptr = new_pos_embed.ptr(); - - auto cubic_kernel = [](float x) -> float { - constexpr float a = -0.5f; - x = std::abs(x); - if (x < 1.0f) { - return (a + 2.0f) * x * x * x - (a + 3.0f) * x * x + 1.0f; - } else if (x < 2.0f) { - return a * x * x * x - 5.0f * a * x * x + 8.0f * a * x - 4.0f * a; - } else { - return 0.0f; - } - }; - - const float scale_y = static_cast(src_size) / tgt_size; - const float scale_x = static_cast(src_size) / tgt_size; - - for (int c = 0; c < channels; ++c) { - const float* src_channel_ptr = old_pos_embed_ptr + c * src_size * src_size; - float* dst_channel_ptr = new_pos_embed_ptr + c * tgt_size * tgt_size; - - for (int j = 0; j < tgt_size; ++j) { - for (int i = 0; i < tgt_size; ++i) { - float src_y = (static_cast(j) + 0.5f) * scale_y - 0.5f; - float src_x = (static_cast(i) + 0.5f) * scale_x - 0.5f; - - int y0 = static_cast(std::floor(src_y)) - 1; - int x0 = static_cast(std::floor(src_x)) - 1; - - float total_weight = 0.0f; - - for (int m = 0; m < 4; ++m) { - for (int n = 0; n < 4; ++n) { - int cur_y = y0 + m; - int cur_x = x0 + n; - cur_y = std::max(0, std::min(src_size - 1, cur_y)); - cur_x = std::max(0, std::min(src_size - 1, cur_x)); - float weight_y = cubic_kernel(src_y - (y0 + m)); - float weight_x = cubic_kernel(src_x - (x0 + n)); - total_weight += src_channel_ptr[cur_y * src_size + cur_x] * weight_y * weight_x; - } - } - dst_channel_ptr[j * tgt_size + i] = total_weight; - } - } - } - } + auto new_pos_embed = nn::functional::interpolateBySize(old_pos_embed, {tgt_size, tgt_size}, + aops::InterpolateOpMode::kBicubic, false, true); new_pos_embed = new_pos_embed.permute({0, 2, 3, 1}); new_pos_embed = new_pos_embed.view({tgt_size * tgt_size, dim}); auto vision_pos_embed = nn::functional::concat({cls_token, new_pos_embed}, 0); @@ -403,34 +355,19 @@ class Attention final : public nn::Module { } } - Tensor __interpolateLinear1d(const Tensor& input, int output_size) { - auto output = Tensor::empty({output_size, input.size(1)}).alloc(); - int input_size = input.size(0); - float scale_factor = static_cast(input_size - 1) / (output_size - 1); - - for (int i = 0; i < output_size; ++i) { - float in_x = i * scale_factor; - int x0 = static_cast(floor(in_x)); - int x1 = std::min(x0 + 1, input_size - 1); - float w1 = in_x - x0; - float w0 = 1.0f - w1; - - for (int c = 0; c < input.size(1); ++c) { - float val = - w0 * input.ptr()[x0 * input.size(1) + c] + w1 * input.ptr()[x1 * input.size(1) + c]; - *output.offsettedPtr({i, c}) = val; - } - } - return output; - } - // Get relative positional embeddings according to the relative positions of query and key sizes. - Tensor getRelPos(int q_size, int k_size, const Tensor& rel_pos) { + Tensor getRelPos(int q_size, int k_size, const Tensor& rel_pos_) { + auto rel_pos = rel_pos_; auto max_rel_dist = 2 * std::max(q_size, k_size) - 1; Tensor rel_pos_resized = Tensor::nil(); if (rel_pos.size(0) != max_rel_dist) { - rel_pos_resized = __interpolateLinear1d(rel_pos, max_rel_dist); + auto dtype = rel_pos.dtype(); + rel_pos = rel_pos.to(kFloat32); + rel_pos_resized = nn::functional::interpolateBySize(rel_pos.view({1, rel_pos.size(0), -1}).permute({0, 2, 1}), + {max_rel_dist}, aops::InterpolateOpMode::kLinear) + .to(dtype); + rel_pos_resized = rel_pos_resized.view({-1, max_rel_dist}).permute({1, 0}); } else { rel_pos_resized = rel_pos; } @@ -442,7 +379,6 @@ class Attention final : public nn::Module { float k_scale = std::max((float)q_size / k_size, 1.0f); for (int i = 0; i < q_size; ++i) { q_coords[i] = i * q_scale; } - for (int i = 0; i < k_size; ++i) { k_coords[i] = i * k_scale; } float offset = (k_size - 1) * k_scale; @@ -610,32 +546,7 @@ class Block final : public nn::Module { auto pad_h = (window_size - H % window_size) % window_size; auto pad_w = (window_size - W % window_size) % window_size; - if (pad_h > 0 || pad_w > 0) { - // Do x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) - const auto& x_shape = x.shape(); - const int in_c = x_shape[0]; - const int in_h = x_shape[1]; - const int in_w = x_shape[2]; - - const int out_h = in_h + pad_h; - const int out_w = in_w + pad_w; - std::vector out_shape = {in_c, out_h, out_w}; - - Tensor padded_x = Tensor::empty(out_shape).alloc(); - - mllm_fp32_t* x_data_ptr = x.ptr(); - mllm_fp32_t* padded_x_data_ptr = padded_x.ptr(); - - for (int c = 0; c < in_c; ++c) { - for (int h = 0; h < in_h; ++h) { - for (int w = 0; w < in_w; ++w) { - int src_idx = c * (in_h * in_w) + h * in_w + w; - int dest_idx = c * (out_h * out_w) + h * out_w + w; - padded_x_data_ptr[dest_idx] = x_data_ptr[src_idx]; - } - } - } - } + if (pad_h > 0 || pad_w > 0) { x = nn::functional::pad(x, {0, 0, 0, pad_w, 0, pad_h}); } auto Hp = H + pad_h; auto Wp = W + pad_w; @@ -741,11 +652,6 @@ class ImageEncoderViT final : public nn::Module { Tensor::shape_t{1, 1}, false); } - template - T __cubicInterpolate(T p0, T p1, T p2, T p3, float t) { - return p1 + 0.5f * t * (p2 - p0 + t * (2.0f * p0 - 5.0f * p1 + 4.0f * p2 - p3 + t * (3.0f * (p1 - p2) + p3 - p0))); - } - Tensor getAbsPosSam(Tensor abs_pos, int tgt_size) { auto dtype = abs_pos.dtype(); auto src_size = abs_pos.size(1); @@ -753,53 +659,9 @@ class ImageEncoderViT final : public nn::Module { if (src_size != tgt_size) { auto old_pos_embed = abs_pos.permute({0, 3, 1, 2}); old_pos_embed = old_pos_embed.to(kFloat32); - - const int batch_size = old_pos_embed.size(0); - const int channels = old_pos_embed.size(1); - const int src_h = old_pos_embed.size(2); - const int src_w = old_pos_embed.size(3); - const int tgt_h = tgt_size; - const int tgt_w = tgt_size; - - auto new_pos_embed = Tensor::empty({batch_size, channels, tgt_h, tgt_w}, kFloat32).alloc(); - - const float* src_data = old_pos_embed.ptr(); - float* dst_data = new_pos_embed.ptr(); - - const float height_scale = static_cast(src_h) / tgt_h; - const float width_scale = static_cast(src_w) / tgt_w; - for (int b = 0; b < batch_size; ++b) { - for (int c = 0; c < channels; ++c) { - const float* current_src_channel = src_data + (b * channels + c) * src_h * src_w; - float* current_dst_channel = dst_data + (b * channels + c) * tgt_h * tgt_w; - - for (int y_tgt = 0; y_tgt < tgt_h; ++y_tgt) { - for (int x_tgt = 0; x_tgt < tgt_w; ++x_tgt) { - float y_src = (static_cast(y_tgt) + 0.5f) * height_scale - 0.5f; - float x_src = (static_cast(x_tgt) + 0.5f) * width_scale - 0.5f; - - int y_floor = static_cast(std::floor(y_src)); - int x_floor = static_cast(std::floor(x_src)); - float y_frac = y_src - y_floor; - float x_frac = x_src - x_floor; - - float p[4][4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - int y_coord = std::max(0, std::min(src_h - 1, y_floor - 1 + i)); - int x_coord = std::max(0, std::min(src_w - 1, x_floor - 1 + j)); - p[i][j] = current_src_channel[y_coord * src_w + x_coord]; - } - } - float col[4]; - for (int i = 0; i < 4; ++i) { col[i] = __cubicInterpolate(p[i][0], p[i][1], p[i][2], p[i][3], x_frac); } - float value = __cubicInterpolate(col[0], col[1], col[2], col[3], y_frac); - current_dst_channel[y_tgt * tgt_w + x_tgt] = value; - } - } - } - } - new_pos_embed = new_pos_embed.to(dtype); + // clang-format off + auto new_pos_embed = nn::functional::interpolateBySize(old_pos_embed, {tgt_size, tgt_size}, aops::InterpolateOpMode::kBicubic, false, true).to(dtype); + // clang-format on new_pos_embed = new_pos_embed.permute({0, 2, 3, 1}); return new_pos_embed; } else { diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index 071d48abf..b6b6a5c71 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -134,21 +134,23 @@ Tensor pad(const Tensor& x, const std::vector& pad, aops::PadMode mode, {x})[0]; } -Tensor interpolate(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode, bool align_corners, - bool keep_aspect_ratio) { +Tensor interpolateBySize(const Tensor& x, const std::vector& size, aops::InterpolateOpMode mode, bool align_corners, + bool antialias) { aops::InterpolateOpOptions opts{}; opts.size.assign(size.begin(), size.end()); opts.mode = mode; opts.align_corners = align_corners; - opts.keep_aspect_ratio = keep_aspect_ratio; + opts.antialias = antialias; return Context::instance().buildOpAndSubmitTask(OpTypes::kInterpolate, opts, {x})[0]; } -Tensor interpolate(const Tensor& x, const std::vector& scale_factor, aops::InterpolateOpMode mode, bool align_corners) { +Tensor interpolateByScale(const Tensor& x, const std::vector& scale_factor, aops::InterpolateOpMode mode, + bool align_corners, bool antialias) { aops::InterpolateOpOptions opts{}; opts.scale_factor = scale_factor; opts.mode = mode; opts.align_corners = align_corners; + opts.antialias = antialias; return Context::instance().buildOpAndSubmitTask(OpTypes::kInterpolate, opts, {x})[0]; } diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index 30c127667..9654b6663 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -138,12 +138,13 @@ Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& Tensor pad(const Tensor& x, const std::vector& pad, aops::PadMode mode = aops::PadMode::kConstant, float value = 0.0f); // Interpolate by target size -Tensor interpolate(const Tensor& x, const std::vector& size, - aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, - bool keep_aspect_ratio = false); +Tensor interpolateBySize(const Tensor& x, const std::vector& size, + aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, + bool antialias = false); // Interpolate by scale factor -Tensor interpolate(const Tensor& x, const std::vector& scale_factor, - aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false); +Tensor interpolateByScale(const Tensor& x, const std::vector& scale_factor, + aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, + bool antialias = false); } // namespace mllm::nn::functional From 1515cdbc12fda6d4d49bedd486ddce27fb1a0786 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 17:29:15 +0800 Subject: [PATCH 06/25] feat(deepseek_ocr): add mlp projector linear impl type configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `mlp_projector_linear_impl_type` to `DpskOcrConfig` for configuring the linear implementation type in MLP projector - Update conversation preprocessing templates: - Replace `<|sft▁begin|>` with `<|sft begin|>` - Adjust separator tags and role markers in deepseek and deepseekv2 templates - Refactor `dynamic_preprocess` to `dynamicPreprocess`: - Return both processed images and target aspect ratio - Improve tiling logic using candidate aspect ratios within min/max limits - Support optional thumbnail generation based on block count - Introduce `MlpProjector` module for mapping visual tokens to text token space - Include necessary headers (``, ``, ``) in preprocessing header --- .../configuration_deepseek_ocr.hpp | 1 + .../deepseek_ocr/conversation_preprocess.hpp | 104 +++++++++--------- mllm/models/deepseek_ocr/deepencoder.hpp | 15 ++- 3 files changed, 66 insertions(+), 54 deletions(-) diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp index f13946025..8b7edae1e 100644 --- a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -182,6 +182,7 @@ struct DpskOcrConfig : protected ConfigFile { // MLLM Related Stuff aops::LinearImplTypes clip_linear_impl_type; aops::LinearImplTypes sam_linear_impl_type; + aops::LinearImplTypes mlp_projector_linear_impl_type; }; } // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp index 9327115a6..8206a48e7 100644 --- a/mllm/models/deepseek_ocr/conversation_preprocess.hpp +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -12,6 +12,9 @@ #include #include #include +#include +#include +#include #include "mllm/preprocessor/visual/Image.hpp" @@ -83,7 +86,7 @@ class Conversation { const auto& [role, message] = std::make_pair(i[0], i[1]); if (!message.empty()) { if (role == "User") { - ret += "<|sft▁begin|>\n" + message + sep_; + ret += "<|sft begin|>\n" + message + sep_; } else { ret += message + sep2_.value_or(""); } @@ -254,16 +257,16 @@ std::shared_ptr getConvTemplate(const std::string& name) { void initializeTemplates() { // DeepSeek template auto deepseek = std::make_shared( - "deepseek", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, - std::vector>{}, 0, SeparatorStyle::DeepSeek, "\n\n", "<|end▁of▁sentence|>", - std::vector{"User:", "<|end▁of▁sentence|>"}, std::vector{100001}); + "deepseek", "{system_message}", "", std::vector{"<|User|>", " outputId="}, + std::vector>{}, 0, SeparatorStyle::DeepSeek, "\n\n", "<|end of sentence|>", + std::vector{"User:", "<|end of sentence|>"}, std::vector{100001}); registerConvTemplate(deepseek); // DeepSeekV2 template auto deepseekv2 = std::make_shared( "deepseekv2", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, - std::vector>{}, 0, SeparatorStyle::DeepSeek, "", "<|end▁of▁sentence|>", - std::vector{"User:", "<|end▁of▁sentence|>"}, std::vector{100001}); + std::vector>{}, 0, SeparatorStyle::DeepSeek, "", "<|end of sentence|>", + std::vector{"User:", "<|end of sentence|>"}, std::vector{100001}); registerConvTemplate(deepseekv2); // Plain template @@ -355,62 +358,57 @@ std::pair findClosestAspectRatio(double aspect_ratio, const std::vecto * @param use_thumbnail Whether to add a square thumbnail (image_size x image_size) as the first tile * @return Vector of processed Image tiles */ -inline std::vector dynamic_preprocess(const Image& image, int image_size, int max_num = 6, bool use_thumbnail = false) { - // Mirror Python logic: generate a grid of image_size x image_size crops. - // Grid shape is derived from ceil(width/size) and ceil(height/size), - // and then capped to max_num by reducing along the longer dimension. - - Image src = image; // non-const copy to call non-const methods +inline std::pair, std::pair> dynamicPreprocess(const Image& image, int min_num = 2, + int max_num = 9, int image_size = 640, + bool use_thumbnail = false) { + Image src = image; const int w = src.w(); const int h = src.h(); - if (w <= 0 || h <= 0) { return {}; } - - // Integer ceil for positive numbers: ceil(x / y) == (x + y - 1) / y - int grid_w = (w + image_size - 1) / image_size; - int grid_h = (h + image_size - 1) / image_size; - if (grid_w < 1) grid_w = 1; - if (grid_h < 1) grid_h = 1; - - // Cap total tiles to max_num while preserving aspect tendency - while (grid_w * grid_h > max_num) { - if (grid_w >= grid_h && grid_w > 1) { - --grid_w; - } else if (grid_h > 1) { - --grid_h; - } else { - break; - } - } - - const int target_width = grid_w * image_size; - const int target_height = grid_h * image_size; + if (w <= 0 || h <= 0) { return {{}, {1, 1}}; } - std::vector out; - out.reserve(static_cast(max_num)); + const double aspect_ratio = static_cast(w) / static_cast(h); - // Optional thumbnail first - if (use_thumbnail) { - out.push_back(src.resize(image_size, image_size)); - if (static_cast(out.size()) >= max_num) { return out; } + // Build candidate ratios: all pairs (i, j) with min_num <= i*j <= max_num + std::set> ratio_set; + for (int n = min_num; n <= max_num; ++n) { + for (int i = 1; i <= n; ++i) { + for (int j = 1; j <= n; ++j) { + const int blocks = i * j; + if (blocks >= min_num && blocks <= max_num) { ratio_set.insert({i, j}); } + } + } + } + std::vector> target_ratios(ratio_set.begin(), ratio_set.end()); + std::sort(target_ratios.begin(), target_ratios.end(), + [](const auto& a, const auto& b) { return (a.first * a.second) < (b.first * b.second); }); + + const auto target_aspect_ratio = findClosestAspectRatio(aspect_ratio, target_ratios, w, h, image_size); + + const int target_width = image_size * target_aspect_ratio.first; + const int target_height = image_size * target_aspect_ratio.second; + const int blocks = target_aspect_ratio.first * target_aspect_ratio.second; + + Image resized_img = src.resize(target_width, target_height); + std::vector processed_images; + processed_images.reserve(static_cast(blocks)); + + for (int i = 0; i < blocks; ++i) { + const int cols = target_width / image_size; // equals target_aspect_ratio.first + const int x0 = (i % cols) * image_size; + const int y0 = (i / cols) * image_size; + const int x1 = x0 + image_size; + const int y1 = y0 + image_size; + Image split_img = resized_img.crop(x0, y0, x1, y1); + processed_images.push_back(split_img); } - const int total_tiles = grid_w * grid_h; - for (int i = 0; i < total_tiles; ++i) { - // Python equivalent: - // x = (i % (target_width // image_size)) * image_size - // y = (i // (target_width // image_size)) * image_size - const int cols = target_width / image_size; // == grid_w - const int x = (i % cols) * image_size; - const int y = (i / cols) * image_size; - - // PIL-style crop with zero padding beyond bounds - Image tile = src.crop(x, y, x + image_size, y + image_size); - out.push_back(tile); + assert(static_cast(processed_images.size()) == blocks); - if (static_cast(out.size()) >= max_num) { break; } + if (use_thumbnail && static_cast(processed_images.size()) != 1) { + processed_images.push_back(src.resize(image_size, image_size)); } - return out; + return {processed_images, target_aspect_ratio}; } } // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 89a11dd5e..154bb8d47 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -19,7 +19,20 @@ namespace mllm::models::deepseek_ocr { //===----------------------------------------------------------------------===// // MLP Projector For Mapping Visual Tokens to Text Token Space //===----------------------------------------------------------------------===// -// TODO +class MlpProjector final : public nn::Module { + nn::Linear layers_; + + public: + MlpProjector() = default; + + MlpProjector(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + layers_ = reg("layers", 2048, 1280, true, config.mlp_projector_linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {layers_(inputs[0])}; + } +}; //===----------------------------------------------------------------------===// // CLIP From e858d25d51272bbabcc8947133f00290cc97701a Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 22:14:19 +0800 Subject: [PATCH 07/25] feat(deepseek_ocr): add message formatting and model inference support - Implement `formatMessages` function to process conversation JSON with trimming and role/content extraction - Add `DeepseekOCRForCausalLM` class with basic inference logic and image handling - Update tokenizer implementation to support special tokens and proper tokenization - Remove unused `` include in tokenizer header --- .../deepseek_ocr/conversation_preprocess.hpp | 36 ++++++++++ .../deepseek_ocr/modeling_deepseek_ocr.hpp | 65 +++++++++++++++++++ .../tokenization_deepseek_ocr.hpp | 46 ++++++++++--- 3 files changed, 138 insertions(+), 9 deletions(-) diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp index 8206a48e7..7119e478e 100644 --- a/mllm/models/deepseek_ocr/conversation_preprocess.hpp +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -282,6 +282,42 @@ void initializeTemplates() { registerConvTemplate(alignment); } +inline std::string formatMessages(const nlohmann::json& conversations, const std::string& sft_format = "deepseek", + const std::string& system_prompt = "") { + auto conv = getConvTemplate(sft_format); + + // Helper trim function to mimic Python's .strip() + auto trim = [](const std::string& s) -> std::string { + const char* ws = " \t\n\r\f\v"; + const auto start = s.find_first_not_of(ws); + if (start == std::string::npos) return ""; + const auto end = s.find_last_not_of(ws); + return s.substr(start, end - start + 1); + }; + + conv->setSystemMessage(system_prompt); + for (const auto& message : conversations) { + std::string role; + std::string content; + + if (message.contains("role") && message["role"].is_string()) { + role = message["role"].get(); + } else { + role = ""; + } + + if (message.contains("content") && message["content"].is_string()) { + content = trim(message["content"].get()); + } else { + content = ""; + } + + conv->appendMessage(role, content); + } + + return trim(conv->getPrompt()); +} + //===----------------------------------------------------------------------===// // For Image processing //===----------------------------------------------------------------------===// diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 8a948dbc3..c5fa378f6 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -1,2 +1,67 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. +#pragma once + +#include + +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/visual/ImageTransform.hpp" +#include "mllm/models/deepseek_ocr/conversation_preprocess.hpp" +#include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" +#include "mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp" + +namespace mllm::models::deepseek_ocr { + +class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { + public: + DeepseekOCRForCausalLM() = default; + + explicit DeepseekOCRForCausalLM(const DpskOcrConfig& config) {} + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + void infer(DpskOcrTokenizer& tokenizer, const std::string& prompt, const std::string& image_fp, + const std::string& output_path, int base_size = 1024, int image_size = 640, bool crop_mode = true) { + namespace fs = std::filesystem; + fs::path out_path(output_path); + fs::create_directories(out_path); + fs::create_directories(out_path / "images"); + + nlohmann::json conversations; + if (!prompt.empty() && !image_fp.empty()) { + conversations = nlohmann::json::array(); + conversations.push_back({{"role", "<|User|>"}, {"content", prompt}, {"images", nlohmann::json::array({image_fp})}}); + conversations.push_back({{"role", "<|Assistant|>"}, {"content", ""}}); + } else if (!prompt.empty()) { + conversations = nlohmann::json::array(); + conversations.push_back({{"role", "<|User|>"}, {"content", prompt}}); + conversations.push_back({{"role", "<|Assistant|>"}, {"content", ""}}); + } else { + // Prompt should not be empty + MLLM_RT_ASSERT_EQ(prompt.empty(), false); + } + + auto processed_prompt = formatMessages(conversations, "plain", ""); + + // Global constant define + const int PATCH_SIZE = 16; + const int DOWN_SAMPLE_RATIO = 4; + const std::string IMAGE_TOKEN = ""; + const int64_t IMAGE_TOKEN_ID = 128815; + + // Load image + auto images = loadImages(conversations); + + // Image transform infra + auto image_transform = BasicImageTransform(std::nullopt, std::nullopt, /*mean=*/std::vector{0.5, 0.5, 0.5}, + /*std=*/std::vector{0.5, 0.5, 0.5}); + + // Split text with IMAGE_TOKEN + // TODO + } +}; + +} // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index ca39f30df..88cc13965 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -9,7 +9,6 @@ #pragma once #include -#include #include "mllm/preprocessor/tokenizers/BPE.hpp" #include "mllm/preprocessor/tokenizers/Unicode.hpp" @@ -19,18 +18,35 @@ namespace mllm::models::deepseek_ocr { // Actually is LlamaTokenizer class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { - explicit DpskOcrTokenizer(const std::string& file_path) { preprocessor::initLocal(); } + public: + explicit DpskOcrTokenizer(const std::string& file_path) { + // Init + preprocessor::initLocal(); - std::vector _tokenize(const std::string& str) override { - // TODO - return {}; + // Load bpe files + bpe_.initFromSentencePieceJson(file_path); + + // Add special tokens to trie + special_tokens_trie_.add(L"<|User|>"); + special_tokens_trie_.add(L"<|Assistant|>"); + special_tokens_trie_.add(L"<|begin▁of▁sentence|>"); + special_tokens_trie_.add(L"<|end▁of▁sentence|>"); + special_tokens_trie_.add(L"<|▁pad▁|>"); } - std::vector tokenize(const std::string& str) override { - // TODO - return {}; + std::vector _tokenize(const std::string& str) override { + std::wstring text = preprocessor::utf8string2WideString(str); + std::replace(text.begin(), text.end(), L' ', SPIECE_UNDERLINE[0]); + auto tokens = bpe_._bpe(text); + + if (tokens.size() > 1 && tokens[0] == SPIECE_UNDERLINE && special_tokens_trie_.isSpecialToken(tokens[1])) { + tokens.erase(tokens.begin()); + } + return tokens; } + std::vector tokenize(const std::string& str) override { return _tokenize(str); } + std::wstring _detokenize(int64_t pos_idx) override { // TODO return L""; @@ -41,10 +57,22 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { return _detokenize(pos_idx); } - Tensor convert2Ids(const std::vector& strs) override { return Tensor::nil(); } + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("llama-tokenizer-i0") + .alloc(); + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } private: // For text preprocessor::BPE bpe_; + std::wstring SPIECE_UNDERLINE = L"▁"; }; } // namespace mllm::models::deepseek_ocr From 4ca7a0732c80579f1a93216ed4e352b2e54f3cd5 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 23 Oct 2025 22:39:26 +0800 Subject: [PATCH 08/25] feat(ext): add tokenizers-cpp and opencv-mobile as optional extensions - Add tokenizers-cpp as a submodule under mllm/ext/vendors - Add opencv-mobile as a submodule under mllm/ext/vendors - Update .gitmodules to include the new submodules - Set update policy to 'none' for tokenizers-cpp - Add README.md to document the purpose of mllm extension - Add CMakeLists.txt for deepseek_ocr example - Register deepseek_ocr example in examples/CMakeLists.txt --- .gitmodules | 4 ++++ examples/CMakeLists.txt | 1 + examples/deepseek_ocr/CMakeLists.txt | 3 +++ examples/deepseek_ocr/main.cpp | 0 mllm/ext/README.md | 6 ++++++ mllm/ext/vendors/opencv-mobile | 1 + mllm/ext/vendors/tokenizers-cpp | 1 + 7 files changed, 16 insertions(+) create mode 100644 examples/deepseek_ocr/CMakeLists.txt create mode 100644 examples/deepseek_ocr/main.cpp create mode 100644 mllm/ext/README.md create mode 100644 mllm/ext/vendors/opencv-mobile create mode 160000 mllm/ext/vendors/tokenizers-cpp diff --git a/.gitmodules b/.gitmodules index c532fbc7d..8f2c8d283 100644 --- a/.gitmodules +++ b/.gitmodules @@ -21,3 +21,7 @@ [submodule "mllm/ffi/vendors/tvm-ffi"] path = mllm/ffi/vendors/tvm-ffi url = https://github.com/apache/tvm-ffi +[submodule "mllm/ext/vendors/tokenizers-cpp"] + path = mllm/ext/vendors/tokenizers-cpp + url = https://github.com/mlc-ai/tokenizers-cpp + update = none diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f193f34c6..51dfd2e17 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(qwen3) add_subdirectory(qwen3_service) +add_subdirectory(deepseek_ocr) diff --git a/examples/deepseek_ocr/CMakeLists.txt b/examples/deepseek_ocr/CMakeLists.txt new file mode 100644 index 000000000..8a80e34de --- /dev/null +++ b/examples/deepseek_ocr/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-deepseek-ocr-runner main.cpp) +target_link_libraries(mllm-deepseek-ocr-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-deepseek-ocr-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/ext/README.md b/mllm/ext/README.md new file mode 100644 index 000000000..20ca6e222 --- /dev/null +++ b/mllm/ext/README.md @@ -0,0 +1,6 @@ +# MLLM Extension + +Mllm extension contains some third-party packages and it's warper to mllm. Those extensions are OPTIONAL for mllm main lib. + +- tokenizer-cpp +- mobile opencv diff --git a/mllm/ext/vendors/opencv-mobile b/mllm/ext/vendors/opencv-mobile new file mode 100644 index 000000000..4c3f0b435 --- /dev/null +++ b/mllm/ext/vendors/opencv-mobile @@ -0,0 +1 @@ +https://github.com/nihui/opencv-mobile diff --git a/mllm/ext/vendors/tokenizers-cpp b/mllm/ext/vendors/tokenizers-cpp new file mode 160000 index 000000000..55d53aa38 --- /dev/null +++ b/mllm/ext/vendors/tokenizers-cpp @@ -0,0 +1 @@ +Subproject commit 55d53aa38dc8df7d9c8bd9ed50907e82ae83ce66 From 0eabe59b8507905a886e81261d1b8cc17d843546 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 24 Oct 2025 10:15:21 +0800 Subject: [PATCH 09/25] feat(ocr): add llvm-project submodule and update deepseek ocr model - Add llvm-project as a submodule with none update strategy - Update conversation preprocessing with corrected special tokens - Register conversation templates with optional end tokens - Initialize templates in model inference and log processed prompt --- .gitmodules | 4 ++++ examples/deepseek_ocr/main.cpp | 13 ++++++++++++ mllm/ext/CMakeLists.txt | 0 mllm/ext/vendors/llvm-project | 1 + .../deepseek_ocr/conversation_preprocess.hpp | 21 ++++++++++--------- .../deepseek_ocr/modeling_deepseek_ocr.hpp | 5 +++++ 6 files changed, 34 insertions(+), 10 deletions(-) create mode 100644 mllm/ext/CMakeLists.txt create mode 160000 mllm/ext/vendors/llvm-project diff --git a/.gitmodules b/.gitmodules index 8f2c8d283..e2523ea21 100644 --- a/.gitmodules +++ b/.gitmodules @@ -25,3 +25,7 @@ path = mllm/ext/vendors/tokenizers-cpp url = https://github.com/mlc-ai/tokenizers-cpp update = none +[submodule "mllm/ext/vendors/llvm-project"] + path = mllm/ext/vendors/llvm-project + url = https://github.com/llvm/llvm-project + update = none diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index e69de29bb..881bc053b 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -0,0 +1,13 @@ +#include +#include +#include +#include "mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp" +#include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(); + auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/hf-models/DeepSeek-OCR/tokenizer.json"); + model.infer(tokenizer, "hello world", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); +}); diff --git a/mllm/ext/CMakeLists.txt b/mllm/ext/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/ext/vendors/llvm-project b/mllm/ext/vendors/llvm-project new file mode 160000 index 000000000..6a0f392bb --- /dev/null +++ b/mllm/ext/vendors/llvm-project @@ -0,0 +1 @@ +Subproject commit 6a0f392bb50d890f13cb961a911be28f965ed4f2 diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp index 7119e478e..7191ac301 100644 --- a/mllm/models/deepseek_ocr/conversation_preprocess.hpp +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -86,7 +86,7 @@ class Conversation { const auto& [role, message] = std::make_pair(i[0], i[1]); if (!message.empty()) { if (role == "User") { - ret += "<|sft begin|>\n" + message + sep_; + ret += "<|sft▁begin|>\n" + message + sep_; } else { ret += message + sep2_.value_or(""); } @@ -257,28 +257,29 @@ std::shared_ptr getConvTemplate(const std::string& name) { void initializeTemplates() { // DeepSeek template auto deepseek = std::make_shared( - "deepseek", "{system_message}", "", std::vector{"<|User|>", " outputId="}, - std::vector>{}, 0, SeparatorStyle::DeepSeek, "\n\n", "<|end of sentence|>", - std::vector{"User:", "<|end of sentence|>"}, std::vector{100001}); + "deepseek", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, + std::vector>{}, 0, SeparatorStyle::DeepSeek, "\n\n", "<|end▁of▁sentence|>", + std::optional{"<|end▁of▁sentence|>"}, std::optional>{std::vector{100001}}); registerConvTemplate(deepseek); // DeepSeekV2 template auto deepseekv2 = std::make_shared( "deepseekv2", "{system_message}", "", std::vector{"<|User|>", "<|Assistant|>"}, - std::vector>{}, 0, SeparatorStyle::DeepSeek, "", "<|end of sentence|>", - std::vector{"User:", "<|end of sentence|>"}, std::vector{100001}); + std::vector>{}, 0, SeparatorStyle::DeepSeekV2, "", "<|end▁of▁sentence|>", + std::optional{"<|end▁of▁sentence|>"}, std::optional>{std::vector{100001}}); registerConvTemplate(deepseekv2); // Plain template - auto plain = std::make_shared("plain", "", "", std::vector{"", ""}, - std::vector>{}, 0, SeparatorStyle::PLAIN, "", "", - std::vector{""}, std::vector{100001}); + auto plain = std::make_shared( + "plain", "", "", std::vector{"", ""}, std::vector>{}, 0, SeparatorStyle::PLAIN, "", + "", std::optional{""}, std::optional>{std::vector{100001}}); registerConvTemplate(plain); // Alignment template auto alignment = std::make_shared("alignment", "", "", std::vector{"", ""}, std::vector>{}, 0, SeparatorStyle::ALIGNMENT, "", "", - std::vector{""}, std::vector{100001}); + std::optional{""}, + std::optional>{std::vector{100001}}); registerConvTemplate(alignment); } diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index c5fa378f6..bc16a477e 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -25,6 +25,9 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { void infer(DpskOcrTokenizer& tokenizer, const std::string& prompt, const std::string& image_fp, const std::string& output_path, int base_size = 1024, int image_size = 640, bool crop_mode = true) { + // Initialize template + initializeTemplates(); + namespace fs = std::filesystem; fs::path out_path(output_path); fs::create_directories(out_path); @@ -46,6 +49,8 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { auto processed_prompt = formatMessages(conversations, "plain", ""); + MLLM_INFO("processed_prompt: {}", processed_prompt); + // Global constant define const int PATCH_SIZE = 16; const int DOWN_SAMPLE_RATIO = 4; From 78a17fb9c95c8cb2cebdcc22eda7e9e80c62bc8a Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 24 Oct 2025 17:16:03 +0800 Subject: [PATCH 10/25] feat(cpu): add StackOp implementation and integrate into DeepSeek-OCR model - Implement `StackOp` for CPU backend with optimized memory copy for contiguous tensors and fallback for non-contiguous cases - Register `StackOp` in IR and core op system with proper RTTI and factory integration - Update DeepSeek-OCR model preprocessing to use `stack` for image tensor batching - Add `stack` functional API in `nn::functional` - Extend tokenizer utilities and string splitting helpers for better text processing - Introduce image padding utility that mirrors PIL's ImageOps.pad behavior - Add extension options in CMake for future LLVM and tokenizers-cpp integrations - Update example inference call with grounding prompt for OCR task --- CMakeLists.txt | 4 + examples/deepseek_ocr/main.cpp | 3 +- mllm/backends/cpu/CPUBackend.cpp | 9 +- mllm/backends/cpu/ops/StackOp.cpp | 117 ++++++++++++++ mllm/backends/cpu/ops/StackOp.hpp | 25 +++ mllm/compile/ir/GeneratedRTTIKind.hpp | 3 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 5 +- mllm/compile/ir/linalg/Op.cpp | 2 + mllm/compile/ir/linalg/Op.hpp | 4 + mllm/compile/ir/rtti_kind_gen.py | 1 + mllm/core/OpTypes.hpp | 6 + mllm/core/aops/StackOp.cpp | 63 ++++++++ mllm/core/aops/StackOp.hpp | 35 ++++ .../deepseek_ocr/modeling_deepseek_ocr.hpp | 150 +++++++++++++++++- .../tokenization_deepseek_ocr.hpp | 140 +++++++++++++++- mllm/nn/Functional.cpp | 5 + mllm/nn/Functional.hpp | 2 + mllm/preprocessor/visual/Image.cpp | 62 ++++++++ mllm/preprocessor/visual/Image.hpp | 4 + mllm/preprocessor/visual/ImageTransform.cpp | 4 +- mllm/utils/StringHelper.hpp | 60 +++++++ 21 files changed, 685 insertions(+), 19 deletions(-) create mode 100644 mllm/backends/cpu/ops/StackOp.cpp create mode 100644 mllm/backends/cpu/ops/StackOp.hpp create mode 100644 mllm/core/aops/StackOp.cpp create mode 100644 mllm/core/aops/StackOp.hpp create mode 100644 mllm/utils/StringHelper.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b3c939de1..a7167a537 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,10 @@ option(MLLM_BUILD_QNN_BACKEND "Enable MLLM QNN backend" OFF) option(MLLM_BUILD_SDK_C_BINDING "Enable MLLM C SDK binding" OFF) option(MLLM_BUILD_EXPERIMENTS "Enable MLLM experiments" OFF) +# Extension Enable +option(MLLM_EXT_ENABLE_LLVM_PROJECT OFF) +option(MLLM_EXT_ENABLE_TOKENIZERS_CPP OFF) + # CPU Backend: BLAS option(MLLM_USE_BLAS "Enable BLAS" OFF) option(MLLM_BLAS_VENDOR_ACCELERATE "Enable Accelerate BLAS on OSX" OFF) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index 881bc053b..fb3368b4b 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -9,5 +9,6 @@ using mllm::Argparse; MLLM_MAIN({ auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(); auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/hf-models/DeepSeek-OCR/tokenizer.json"); - model.infer(tokenizer, "hello world", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); + model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", + "/Volumes/D/mllm/.tmp/dpsk-ocr"); }); diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index 8a607c5ae..cc63857cc 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -50,6 +50,7 @@ #include "mllm/backends/cpu/ops/ViewOp.hpp" #include "mllm/backends/cpu/ops/VisionRoPEOp.hpp" #include "mllm/backends/cpu/ops/X2XOp.hpp" +#include "mllm/backends/cpu/ops/StackOp.hpp" namespace mllm::cpu { @@ -58,10 +59,10 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUSubOpFactory, CPUMulOpFactory, CPUDivOpFactory, CPUNegOpFactory, CPUAbsOpFactory, CPULogOpFactory, CPUExpOpFactory, CPUSinOpFactory, CPUCosOpFactory, CPUReduceMaxOpFactory, CPUReduceMinOpFactory, CPUReduceSumOpFactory, CPUTransposeOpFactory, CPUPermuteOpFactory, CPUCastTypeOpFactory, CPUConcatOpFactory, - CPUContiguousOpFactory, CPUCopyOpFactory, CPUEmbeddingOpFactory, CPUSplitOpFactory, CPUViewOpFactory, - CPULayerNormOpFactory, CPURepeatOpFactory, CPUX2XOpFactory, CPUSoftmaxOpFactory, CPUSiLUOpFactory, - CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, CPUReLUOpFactory, CPUMatMulOpFactory, - CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, + CPUStackOpFactory, CPUContiguousOpFactory, CPUCopyOpFactory, CPUEmbeddingOpFactory, CPUSplitOpFactory, + CPUViewOpFactory, CPULayerNormOpFactory, CPURepeatOpFactory, CPUX2XOpFactory, CPUSoftmaxOpFactory, + CPUSiLUOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, CPUReLUOpFactory, + CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, diff --git a/mllm/backends/cpu/ops/StackOp.cpp b/mllm/backends/cpu/ops/StackOp.cpp new file mode 100644 index 000000000..54a40671e --- /dev/null +++ b/mllm/backends/cpu/ops/StackOp.cpp @@ -0,0 +1,117 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include + +#include "mllm/backends/cpu/ops/StackOp.hpp" + +namespace mllm::cpu { + +CPUStackOp::CPUStackOp(const aops::StackOpOptions& options) : aops::StackOp(options) {} + +void CPUStackOp::forward(const std::vector& inputs, std::vector& outputs) { + bool is_all_contiguous = true; + for (auto& input : inputs) { is_all_contiguous &= input.isContiguous(); } + + int stack_dim = options_.dim; + const int input_rank = inputs[0].rank(); + if (stack_dim < 0) { stack_dim += (input_rank + 1); } + + const int N = static_cast(inputs.size()); + + if (is_all_contiguous) { + // Elements before stack_dim + int num_slices = 1; + for (int i = 0; i < stack_dim; ++i) { num_slices *= inputs[0].shape()[i]; } + + // Elements after stack_dim in input (inner block size) + int inner_size = 1; + for (int i = stack_dim; i < input_rank; ++i) { inner_size *= inputs[0].shape()[i]; } + + switch (outputs[0].dtype()) { + case kFloat32: { + mllm_fp32_t* out_ptr = outputs[0].ptr(); + for (int k = 0; k < N; ++k) { + const mllm_fp32_t* in_ptr = inputs[k].ptr(); + for (int slice = 0; slice < num_slices; ++slice) { + const mllm_fp32_t* src = in_ptr + slice * inner_size; + mllm_fp32_t* dst = out_ptr + (slice * N + k) * inner_size; + std::memcpy(dst, src, inner_size * sizeof(mllm_fp32_t)); + } + } + break; + } + case kFloat16: { + mllm_fp16_t* out_ptr = outputs[0].ptr(); + for (int k = 0; k < N; ++k) { + const mllm_fp16_t* in_ptr = inputs[k].ptr(); + for (int slice = 0; slice < num_slices; ++slice) { + const mllm_fp16_t* src = in_ptr + slice * inner_size; + mllm_fp16_t* dst = out_ptr + (slice * N + k) * inner_size; + std::memcpy(dst, src, inner_size * sizeof(mllm_fp16_t)); + } + } + break; + } + default: NYI("Type not supported in stack op"); + } + } else { + MLLM_WARN("Stack op has weak performance for non-contiguous inputs."); + + switch (outputs[0].dtype()) { + case kFloat32: { + for (int k = 0; k < N; ++k) { + auto input = inputs[k]; + std::vector input_shape = input.shape(); + + for (int64_t j = 0; j < input.numel(); ++j) { + std::vector input_index(input_shape.size(), 0); + int64_t temp = j; + for (int d = input_shape.size() - 1; d >= 0; --d) { + input_index[d] = temp % input_shape[d]; + temp /= input_shape[d]; + } + + std::vector output_index; + output_index.reserve(input_shape.size() + 1); + for (int d = 0; d < stack_dim; ++d) { output_index.push_back(input_index[d]); } + output_index.push_back(k); + for (int d = stack_dim; d < input_shape.size(); ++d) { output_index.push_back(input_index[d]); } + + mllm_fp32_t value = input.at(input_index); + outputs[0].at(output_index) = value; + } + } + break; + } + case kFloat16: { + for (int k = 0; k < N; ++k) { + auto input = inputs[k]; + std::vector input_shape = input.shape(); + + for (int64_t j = 0; j < input.numel(); ++j) { + std::vector input_index(input_shape.size(), 0); + int64_t temp = j; + for (int d = input_shape.size() - 1; d >= 0; --d) { + input_index[d] = temp % input_shape[d]; + temp /= input_shape[d]; + } + + std::vector output_index; + output_index.reserve(input_shape.size() + 1); + for (int d = 0; d < stack_dim; ++d) { output_index.push_back(input_index[d]); } + output_index.push_back(k); + for (int d = stack_dim; d < input_shape.size(); ++d) { output_index.push_back(input_index[d]); } + + mllm_fp16_t value = input.at(input_index); + outputs[0].at(output_index) = value; + } + } + break; + } + default: NYI("Type not supported in stack op"); + } + } +} + +} // namespace mllm::cpu \ No newline at end of file diff --git a/mllm/backends/cpu/ops/StackOp.hpp b/mllm/backends/cpu/ops/StackOp.hpp new file mode 100644 index 000000000..cce1f84c7 --- /dev/null +++ b/mllm/backends/cpu/ops/StackOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/StackOp.hpp" + +namespace mllm::cpu { + +class CPUStackOp final : public aops::StackOp { + public: + explicit CPUStackOp(const aops::StackOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUStackOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::StackOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu \ No newline at end of file diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 0706c7554..4c67ecea6 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-23 14:44:05 +// Auto generated: 2025-10-24 14:21:08 // do not modify this file #pragma once @@ -75,6 +75,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_PadOp, RK_Op_LinalgIROp_InterpolateOp, RK_Op_LinalgIROp_EinsumOp, + RK_Op_LinalgIROp_StackOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, RK_Op_GraphIROp_SubGraphOp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index e22b28c04..4b60ab383 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-23 14:44:05 +// Auto generated: 2025-10-24 14:21:08 // do not modify this file #pragma once namespace mllm::ir { @@ -195,6 +195,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_EINSUMOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_EinsumOp && (v)->getKind() <= RK_Op_LinalgIROp_EinsumOp +#define RTTI_RK_OP_LINALGIROP_STACKOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_StackOp && (v)->getKind() <= RK_Op_LinalgIROp_StackOp + #define RTTI_RK_OP_GRAPHIROP_IMPL(v) return (v)->getKind() >= RK_Op_GraphIROp && (v)->getKind() <= RK_Op_GraphIROp_Last #define RTTI_RK_OP_GRAPHIROP_SUBGRAPHOP_IMPL(v) \ diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index b2e412b72..8b9021a95 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -104,5 +104,7 @@ LINALG_AOPS_DECL(OpTypes::kPagedAttn, PagedAttnOp); LINALG_AOPS_DECL(OpTypes::kLayerNorm2D, LayerNorm2DOp); LINALG_AOPS_DECL(OpTypes::kPad, PadOp); LINALG_AOPS_DECL(OpTypes::kInterpolate, InterpolateOp); +LINALG_AOPS_DECL(OpTypes::kEinsum, EinsumOp); +LINALG_AOPS_DECL(OpTypes::kStack, StackOp); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index ecb5532b5..fc3c1d4d4 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -67,6 +67,8 @@ class PagedAttnOp; class LayerNorm2DOp; class PadOp; class InterpolateOp; +class EinsumOp; +class StackOp; } // namespace mllm #define LINALG_AOPS_DEFINE(class_name, rtti_name) \ @@ -219,5 +221,7 @@ LINALG_AOPS_DEFINE(PagedAttnOp, PAGEDATTNOP); LINALG_AOPS_DEFINE(LayerNorm2DOp, LAYERNORM2DOP); LINALG_AOPS_DEFINE(PadOp, PADOP); LINALG_AOPS_DEFINE(InterpolateOp, INTERPOLATEOP); +LINALG_AOPS_DEFINE(EinsumOp, EINSUMOP); +LINALG_AOPS_DEFINE(StackOp, STACKOP); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index 5a8f43539..f84540a4b 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -276,6 +276,7 @@ def define_lianlg_ir(ir: dict): op.derive(Cls("PadOp")) op.derive(Cls("InterpolateOp")) op.derive(Cls("EinsumOp")) + op.derive(Cls("StackOp")) # value diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index fc1b11e80..03040ba29 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -80,6 +80,8 @@ enum class OpTypes : int32_t { // Padding Op kPad = 61, kInterpolate = 62, + kEinsum = 63, + kStack = 64, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -147,8 +149,12 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kGraphBegin: return "GraphBegin"; case OpTypes::kGraphEnd: return "GraphEnd"; case OpTypes::kPagedAttn: return "PagedAttn"; + case OpTypes::kRadixAttn: return "RadixAttn"; case OpTypes::kScatter2Shards: return "Scatter2Shards"; case OpTypes::kLayerNorm2D: return "LayerNorm2D"; + case OpTypes::kPad: return "Pad"; + case OpTypes::kInterpolate: return "Interpolate"; + case OpTypes::kStack: return "Stack"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; } diff --git a/mllm/core/aops/StackOp.cpp b/mllm/core/aops/StackOp.cpp new file mode 100644 index 000000000..858bf2732 --- /dev/null +++ b/mllm/core/aops/StackOp.cpp @@ -0,0 +1,63 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/StackOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +StackOp::StackOp(const StackOpOptions& options) : BaseOp(OpTypes::kStack), options_(options) {} + +void StackOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void StackOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void StackOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("StackOp::forward not implemented in aops base."); +} + +void StackOp::reshape(const std::vector& inputs, std::vector& outputs) { + if (inputs.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "StackOp: no inputs"); + return; + } + + const int n_dims = inputs[0].shape().size(); + int at_dim = options_.dim; + + // Normalize dim into [0, n_dims] + if (at_dim < 0) { at_dim += (n_dims + 1); } + if (at_dim < 0 || at_dim > n_dims) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "StackOp: dim {} out of range [0, {}]", at_dim, n_dims); + return; + } + + // Check all input shapes equal + for (size_t i = 1; i < inputs.size(); ++i) { + if (inputs[i].shape() != inputs[0].shape()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "StackOp: input shape mismatch"); + return; + } + } + + // Build new shape by inserting a new dimension of size inputs.size() + std::vector new_shape; + new_shape.reserve(n_dims + 1); + for (int d = 0; d < at_dim; ++d) { new_shape.push_back(inputs[0].shape()[d]); } + new_shape.push_back(static_cast(inputs.size())); + for (int d = at_dim; d < n_dims; ++d) { new_shape.push_back(inputs[0].shape()[d]); } + + outputs.emplace_back(Tensor::empty(new_shape, inputs[0].dtype(), inputs[0].device())); +} + +void StackOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops \ No newline at end of file diff --git a/mllm/core/aops/StackOp.hpp b/mllm/core/aops/StackOp.hpp new file mode 100644 index 000000000..01c0ba4c5 --- /dev/null +++ b/mllm/core/aops/StackOp.hpp @@ -0,0 +1,35 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct StackOpOptions : public BaseOpOptions { + int32_t dim; +}; + +class StackOp : public BaseOp { + public: + explicit StackOp(const StackOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline StackOpOptions& options() { return options_; } + + protected: + StackOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index bc16a477e..315c9d831 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -2,11 +2,12 @@ // Licensed under the MIT License. #pragma once +#include #include - #include #include "mllm/mllm.hpp" +#include "mllm/utils/StringHelper.hpp" #include "mllm/models/ARGeneration.hpp" #include "mllm/preprocessor/visual/ImageTransform.hpp" #include "mllm/models/deepseek_ocr/conversation_preprocess.hpp" @@ -49,14 +50,16 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { auto processed_prompt = formatMessages(conversations, "plain", ""); - MLLM_INFO("processed_prompt: {}", processed_prompt); - // Global constant define const int PATCH_SIZE = 16; const int DOWN_SAMPLE_RATIO = 4; const std::string IMAGE_TOKEN = ""; const int64_t IMAGE_TOKEN_ID = 128815; + // Global states + int valid_img_tokens = 0; + float ratio = 1.f; + // Load image auto images = loadImages(conversations); @@ -65,6 +68,147 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { /*std=*/std::vector{0.5, 0.5, 0.5}); // Split text with IMAGE_TOKEN + // Like what python does: text_splits = prompt.split(image_token) + auto text_splits = mllm::splitString(processed_prompt, IMAGE_TOKEN); + + // Processed states + std::vector tokenized_str; + std::vector images_seq_mask; + std::vector images_list; + std::vector images_crop_list; + std::vector> images_spatial_crop; + + // text_splits's length should be greater than images' length. + // text_splits.size() - images.size() = 1 + for (int idx = 0; idx < std::min(images.size(), text_splits.size()); ++idx) { + auto tokenized_sep = tokenizer.convert2VectorIds(tokenizer.tokenize(text_splits[idx])); + tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); + for (int _i = 0; _i < tokenized_sep.size(); ++_i) { + images_seq_mask.emplace_back(0); // emplace_back(false) + } + + // Get image in this loop + auto image = images[idx]; + std::tuple crop_ratio; + std::vector images_crop_raw; + + // Processing Image + if (crop_mode) { + if (image.h() <= 640 && image.w() <= 640) { + crop_ratio = {1, 1}; + } else { + if (crop_mode) { + auto p = dynamicPreprocess(image); + images_crop_raw = p.first; + crop_ratio = p.second; + } else { + crop_ratio = {1, 1}; + } + } + + // color=tuple(int(x * 255) for x in image_transform.mean + auto global_view = image.pad(base_size, base_size, (int)(255 * 0.5), (int)(255 * 0.5), (int)(255 * 0.5)); + + if (base_size == 1024) { + valid_img_tokens += (int)(256 * ratio); + } else if (base_size == 1280) { + valid_img_tokens += (int)(400 * ratio); + } else { + MLLM_RT_ASSERT(false); + } + + images_list.emplace_back(image_transform(global_view)); + + auto [width_crop_num, height_crop_num] = crop_ratio; + images_spatial_crop.emplace_back(width_crop_num, height_crop_num); + + // Processing crops + if (width_crop_num > 1 || height_crop_num > 1) { + for (const auto& _i : images_crop_raw) { images_crop_list.emplace_back(image_transform(_i)); } + } + + // Check if image_size is 640 + valid_img_tokens += images_crop_list.size() * 100; + + // Compute query + auto num_queries = std::ceil((image_size / PATCH_SIZE) / DOWN_SAMPLE_RATIO); + auto num_queries_base = std::ceil((base_size / PATCH_SIZE) / DOWN_SAMPLE_RATIO); + + // Do python logic below: + // tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base + // tokenized_image += [image_token_id] + std::vector tokenized_image; + tokenized_image.reserve((num_queries_base + 1) * num_queries_base + 1); + for (int i = 0; i < num_queries_base; ++i) { + tokenized_image.insert(tokenized_image.end(), num_queries_base, IMAGE_TOKEN_ID); + tokenized_image.push_back(IMAGE_TOKEN_ID); + } + tokenized_image.push_back(IMAGE_TOKEN_ID); + + if (width_crop_num > 1 || height_crop_num > 1) { + for (int h = 0; h < num_queries * height_crop_num; ++h) { + tokenized_image.insert(tokenized_image.end(), num_queries * width_crop_num, IMAGE_TOKEN_ID); + tokenized_image.push_back(IMAGE_TOKEN_ID); + } + } + + tokenized_str.insert(tokenized_str.end(), tokenized_image.begin(), tokenized_image.end()); + for (int _i = 0; _i < tokenized_image.size(); ++_i) { images_seq_mask.emplace_back(true); } + } else { + NYI("crop_mode = false is not supported yet."); + } + } + + // Processing last text split + auto tokenized_sep = tokenizer.convert2VectorIds(tokenizer.tokenize(text_splits.back())); + tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); + images_seq_mask.insert(images_seq_mask.end(), tokenized_sep.size(), false); + + // Add bos token + // bos_id = 0 + // tokenized_str = [bos_id] + tokenized_str + // images_seq_mask = [False] + images_seq_mask + tokenized_str.insert(tokenized_str.begin(), 0); + images_seq_mask.insert(images_seq_mask.begin(), false); + + // Prepare Tensor to DeepSeek-OCR Model + auto input_ids = Tensor::fromVector(tokenized_str, {1, (int32_t)tokenized_str.size()}, kInt64); + auto images_seq_mask_tensor = Tensor::fromVector(images_seq_mask, {1, (int32_t)images_seq_mask.size()}, kFloat32); + auto images_ori_tensor = Tensor::nil(); + auto images_spatial_crop_tensor = Tensor::nil(); + auto images_crop_tensor = Tensor::nil(); + if (images_list.empty()) { + images_ori_tensor = Tensor::zeros({1, 3, image_size, image_size}); + images_spatial_crop_tensor = Tensor::zeros({1, 2}, kInt64); + images_crop_tensor = Tensor::zeros({1, 3, base_size, base_size}); + } else { + images_ori_tensor = nn::functional::stack(images_list, 0); + images_spatial_crop_tensor = Tensor::zeros({(int32_t)images_spatial_crop.size(), 2}, kInt64); + auto _ptr = images_spatial_crop_tensor.ptr(); + for (int _i = 0; _i < images_spatial_crop.size(); ++_i) { + auto [l, h] = images_spatial_crop[_i]; + _ptr[2 * _i + 0] = l; + _ptr[2 * _i + 1] = h; + } + if (!images_crop_list.empty()) { + images_crop_tensor = nn::functional::stack(images_crop_list, 0); + } else { + images_crop_tensor = Tensor::zeros({1, 3, base_size, base_size}); + } + } + + MLLM_INFO("BRAVO! U R HERE"); + print(input_ids.shape()); + print(input_ids); + print(images_seq_mask_tensor); + print(images_ori_tensor); + print(images_spatial_crop_tensor); + print(images_crop_tensor); + + // Run model. Use generate + // TODO + + // Post process data // TODO } }; diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index 88cc13965..28ff2f045 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -9,6 +9,9 @@ #pragma once #include +#include +#include +#include #include "mllm/preprocessor/tokenizers/BPE.hpp" #include "mllm/preprocessor/tokenizers/Unicode.hpp" @@ -35,14 +38,13 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { } std::vector _tokenize(const std::string& str) override { - std::wstring text = preprocessor::utf8string2WideString(str); - std::replace(text.begin(), text.end(), L' ', SPIECE_UNDERLINE[0]); - auto tokens = bpe_._bpe(text); + // Replace spaces with SentencePiece underline before processing + // TODO - if (tokens.size() > 1 && tokens[0] == SPIECE_UNDERLINE && special_tokens_trie_.isSpecialToken(tokens[1])) { - tokens.erase(tokens.begin()); - } - return tokens; + auto processed_tokens = preTokenize(preprocessor::utf8string2WideString(str)); + std::vector ret; + + return ret; } std::vector tokenize(const std::string& str) override { return _tokenize(str); } @@ -57,6 +59,13 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { return _detokenize(pos_idx); } + std::vector convert2VectorIds(const std::vector& strs) { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + return ids; + } + Tensor convert2Ids(const std::vector& strs) override { std::vector ids; ids.reserve(strs.size()); @@ -71,6 +80,123 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { } private: + // "pre_tokenizer": { + // "type": "Sequence", + // "pretokenizers": [ + // { + // "type": "Split", + // "pattern": { + // "Regex": "\\p{N}{1,3}" + // }, + // "behavior": "Isolated", + // "invert": false + // }, + // { + // "type": "Split", + // "pattern": { + // "Regex": "[一-龥぀-ゟ゠-ヿ]+" + // }, + // "behavior": "Isolated", + // "invert": false + // }, + // { + // "type": "Split", + // "pattern": { + // "Regex": "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| + // ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" + // }, + // "behavior": "Isolated", + // "invert": false + // }, + // { + // "type": "ByteLevel", + // "add_prefix_space": false, + // "trim_offsets": true, + // "use_regex": false + // } + // ] + // } + std::vector preTokenize(const std::wstring& str) { + std::vector result; + size_t pos = 0; + + while (pos < str.size()) { + std::wstring matched; + bool found_match = false; + + // Pattern 1: Match 1-3 consecutive digits (\p{N}{1,3}) + if (preprocessor::isDigit(str[pos])) { + size_t start = pos; + size_t count = 0; + while (pos < str.size() && preprocessor::isDigit(str[pos]) && count < 3) { + ++pos; + ++count; + } + matched = str.substr(start, count); + found_match = true; + } + // Pattern 2: Match CJK characters ([一-龥぀-ゟ゠-ヿ]+) + else if ((str[pos] >= L'一' && str[pos] <= L'龥') || // Chinese characters + (str[pos] >= L'぀' && str[pos] <= L'ゟ') || // Hiragana + (str[pos] >= L'゠' && str[pos] <= L'ヿ')) { // Katakana + size_t start = pos; + while (pos < str.size() + && ((str[pos] >= L'一' && str[pos] <= L'龥') || (str[pos] >= L'぀' && str[pos] <= L'ゟ') + || (str[pos] >= L'゠' && str[pos] <= L'ヿ'))) { + ++pos; + } + matched = str.substr(start, pos - start); + found_match = true; + } + // Pattern 3: Complex pattern for other characters + // [!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| + // ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+ + else { + // Handle punctuation followed by letters + if ((str[pos] >= L'!' && str[pos] <= L'/') || (str[pos] >= L':' && str[pos] <= L'@') + || (str[pos] >= L'[' && str[pos] <= L'`') || (str[pos] >= L'{' && str[pos] <= L'~')) { + size_t start = pos; + ++pos; // consume the punctuation + // Check if followed by letters + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + while (pos < str.size() && preprocessor::isLetter(str[pos])) { ++pos; } + matched = str.substr(start, pos - start); + found_match = true; + } else { + pos = start + 1; // just consume the punctuation character + matched = str.substr(start, 1); + found_match = true; + } + } + // Handle letters with optional prefix + else if (preprocessor::isLetter(str[pos])) { + size_t start = pos; + while (pos < str.size() && preprocessor::isLetter(str[pos])) { ++pos; } + matched = str.substr(start, pos - start); + found_match = true; + } + // Handle whitespace + else if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) { ++pos; } + matched = str.substr(start, pos - start); + found_match = true; + } + // Handle any other character + else { + matched = str.substr(pos, 1); + ++pos; + found_match = true; + } + } + + // Add matched string to result + if (found_match) { result.push_back(matched); } + } + + return result; + } + // For text preprocessor::BPE bpe_; std::wstring SPIECE_UNDERLINE = L"▁"; diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index b6b6a5c71..72947719a 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -16,6 +16,7 @@ #include "mllm/core/aops/SiLUOp.hpp" #include "mllm/core/aops/PadOp.hpp" #include "mllm/core/aops/InterpolateOp.hpp" +#include "mllm/core/aops/StackOp.hpp" #include "mllm/engine/Context.hpp" namespace mllm::nn::functional { @@ -44,6 +45,10 @@ Tensor concat(const std::vector& ins, int32_t dim) { return Context::instance().buildOpAndSubmitTask(OpTypes::kConcat, aops::ConcatOpOptions{.dim = dim}, ins)[0]; } +Tensor stack(const std::vector& ins, int32_t dim) { + return Context::instance().buildOpAndSubmitTask(OpTypes::kStack, aops::StackOpOptions{.dim = dim}, ins)[0]; +} + Tensor flashAttention2(const Tensor& Q, const Tensor& K, const Tensor& V) { // Inputs is all BSHD format. diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index 9654b6663..1c4a7dd05 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -104,6 +104,8 @@ inline std::vector chunk(int32_t num, const Tensor& x, int32_t dim) { Tensor concat(const std::vector& ins, int32_t dim); +Tensor stack(const std::vector& ins, int32_t dim); + Tensor flashAttention2(const Tensor& Q, const Tensor& K, const Tensor& V); Tensor softmax(const Tensor& x, int32_t dim); diff --git a/mllm/preprocessor/visual/Image.cpp b/mllm/preprocessor/visual/Image.cpp index 5269352e0..3c06ae1a4 100644 --- a/mllm/preprocessor/visual/Image.cpp +++ b/mllm/preprocessor/visual/Image.cpp @@ -145,4 +145,66 @@ Image Image::crop(int left, int upper, int right, int lower) { return new_img; } +// Pad the image to target size with given RGB color. +// Semantics mirror PIL ImageOps.pad: resize to fit within target (keeping aspect ratio) +// then center the resized image on a canvas of target size filled with color. +Image Image::pad(int target_w, int target_h, unsigned char r, unsigned char g, unsigned char b) { + MLLM_RT_ASSERT(image_ptr_ != nullptr); + MLLM_RT_ASSERT_EQ(c_, 3); + MLLM_RT_ASSERT(target_w > 0 && target_h > 0); + + // Compute scale to fit within target while preserving aspect ratio + const double scale_w = static_cast(target_w) / static_cast(w_); + const double scale_h = static_cast(target_h) / static_cast(h_); + const double scale = std::min(scale_w, scale_h); + + int new_w = static_cast(std::round(static_cast(w_) * scale)); + int new_h = static_cast(std::round(static_cast(h_) * scale)); + new_w = std::max(1, new_w); + new_h = std::max(1, new_h); + + // Resize current image to the computed size + Image resized = this->resize(new_w, new_h); + + // Prepare output canvas filled with color + Image out; + out.w_ = target_w; + out.h_ = target_h; + out.c_ = 3; + unsigned char* canvas = static_cast(malloc(static_cast(target_w) * target_h * out.c_)); + MLLM_RT_ASSERT(canvas != nullptr); + + for (int y = 0; y < target_h; ++y) { + for (int x = 0; x < target_w; ++x) { + unsigned char* dst_px = canvas + (static_cast(y) * target_w + x) * out.c_; + dst_px[0] = r; + dst_px[1] = g; + dst_px[2] = b; + } + } + + // Compute offsets to center the resized image + const int offset_x = (target_w - new_w) / 2; + const int offset_y = (target_h - new_h) / 2; + + const unsigned char* src = static_cast(resized.image_ptr_->ptr_); + + // Blit resized image onto the canvas + for (int y = 0; y < new_h; ++y) { + const int dy = offset_y + y; + for (int x = 0; x < new_w; ++x) { + const int dx = offset_x + x; + unsigned char* dst_px = canvas + (static_cast(dy) * target_w + dx) * out.c_; + const unsigned char* src_px = src + (static_cast(y) * new_w + x) * resized.c_; + dst_px[0] = src_px[0]; + dst_px[1] = src_px[1]; + dst_px[2] = src_px[2]; + } + } + + out.image_ptr_ = std::make_shared<_ImagePtr>(); + out.image_ptr_->ptr_ = canvas; + return out; +} + } // namespace mllm \ No newline at end of file diff --git a/mllm/preprocessor/visual/Image.hpp b/mllm/preprocessor/visual/Image.hpp index f7bb76283..98e08c72b 100644 --- a/mllm/preprocessor/visual/Image.hpp +++ b/mllm/preprocessor/visual/Image.hpp @@ -34,6 +34,10 @@ class Image { // Out-of-bounds areas are padded with zeros. Returns a new Image. Image crop(int left, int upper, int right, int lower); + // Pad the image to target size (target_w, target_h) with RGB color. + // Mirrors PIL ImageOps.pad: scale to fit, then center-pad with color. + Image pad(int target_w, int target_h, unsigned char r, unsigned char g, unsigned char b); + void save(const std::string& fp); Tensor tensor(); diff --git a/mllm/preprocessor/visual/ImageTransform.cpp b/mllm/preprocessor/visual/ImageTransform.cpp index d3aefc710..6462cac38 100644 --- a/mllm/preprocessor/visual/ImageTransform.cpp +++ b/mllm/preprocessor/visual/ImageTransform.cpp @@ -87,7 +87,7 @@ Normalize::Normalize(const std::vector& mean, const std::vector& s } Tensor Normalize::apply(const Tensor& input) const { - Tensor src = input; // mutable copy + const Tensor& src = input; // Expect src in CHW layout MLLM_RT_ASSERT(src.rank() == 3); const int c = src.size(0); @@ -97,7 +97,7 @@ Tensor Normalize::apply(const Tensor& input) const { MLLM_RT_ASSERT_EQ(static_cast(std_.size()), c); // Work on a contiguous clone to simplify indexing - Tensor out = src.clone().contiguous(); + Tensor out = Tensor::empty(src.shape(), src.dtype(), src.device()).alloc(); float* ptr = out.ptr(); const size_t plane = static_cast(h) * static_cast(w); diff --git a/mllm/utils/StringHelper.hpp b/mllm/utils/StringHelper.hpp new file mode 100644 index 000000000..c30185bd0 --- /dev/null +++ b/mllm/utils/StringHelper.hpp @@ -0,0 +1,60 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +namespace mllm { +inline std::vector splitString(const std::string& s, const std::string& sep = "", int maxsplit = -1) { + std::vector out; + if (maxsplit == 0) { + out.push_back(s); + return out; + } + + const char* p = s.data(); + const char* end = p + s.size(); + + if (sep.empty()) { + auto skip_space = [&]() { + while (p != end && std::isspace(static_cast(*p))) ++p; + }; + skip_space(); + while (p != end) { + const char* start = p; + while (p != end && !std::isspace(static_cast(*p))) ++p; + out.emplace_back(start, p); + if (maxsplit >= 0 && --maxsplit == 0) { + out.emplace_back(p, end); + return out; + } + skip_space(); + } + return out; + } + + if (sep.size() == 1) { + const unsigned char needle = static_cast(sep[0]); + while (maxsplit != 0) { + const char* pos = reinterpret_cast(std::memchr(p, needle, end - p)); + if (!pos) break; + out.emplace_back(p, pos); + p = pos + 1; + if (maxsplit > 0) --maxsplit; + } + } else { + const auto n = sep.size(); + while (maxsplit != 0) { + const char* pos = std::search(p, end, sep.begin(), sep.end()); + if (pos == end) break; + out.emplace_back(p, pos); + p = pos + n; + if (maxsplit > 0) --maxsplit; + } + } + out.emplace_back(p, end); + return out; +} +} // namespace mllm From 593258ede0b3683abdd6ac4b851251b72b3e88a5 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 24 Oct 2025 22:15:18 +0800 Subject: [PATCH 11/25] feat(tokenizer): implement UTF-8 support for DeepSeek OCR tokenizer - Integrate utfcpp library for proper UTF-8 handling - Replace wide string-based tokenization with UTF-8 compatible logic - Update tokenizer to inherit from AutoTokenizerUTF8 - Add BPEUTF8 class for byte-level BPE tokenization with UTF-8 support - Modify preprocessing logic to handle Unicode characters correctly - Update CMakeLists.txt to include utfcpp headers and install rules This change enables the DeepSeek OCR model to correctly process Unicode text inputs, including CJK characters and other multibyte UTF-8 sequences, improving internationalization support. --- CMakeLists.txt | 8 + .../deepseek_ocr/modeling_deepseek_ocr.hpp | 4 +- .../tokenization_deepseek_ocr.hpp | 241 +++++---- .../preprocessor/tokenizers/AutoTokenizer.hpp | 16 +- mllm/preprocessor/tokenizers/BPEUTF8.cpp | 163 ++++++ mllm/preprocessor/tokenizers/BPEUTF8.hpp | 50 ++ third_party/utfcpp/include/utfcpp/utf8.h | 46 ++ .../utfcpp/include/utfcpp/utf8/checked.h | 359 +++++++++++++ third_party/utfcpp/include/utfcpp/utf8/core.h | 502 ++++++++++++++++++ .../utfcpp/include/utfcpp/utf8/cpp11.h | 70 +++ .../utfcpp/include/utfcpp/utf8/cpp17.h | 96 ++++ .../utfcpp/include/utfcpp/utf8/cpp20.h | 124 +++++ .../utfcpp/include/utfcpp/utf8/unchecked.h | 286 ++++++++++ 13 files changed, 1856 insertions(+), 109 deletions(-) create mode 100644 mllm/preprocessor/tokenizers/BPEUTF8.cpp create mode 100644 mllm/preprocessor/tokenizers/BPEUTF8.hpp create mode 100644 third_party/utfcpp/include/utfcpp/utf8.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/checked.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/core.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/cpp11.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/cpp17.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/cpp20.h create mode 100644 third_party/utfcpp/include/utfcpp/utf8/unchecked.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a7167a537..6a2f4d1bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -210,6 +210,7 @@ set(MLLM_INCLUDE_DIR $ $ $ + $ $ $) set(MLLM_JSON_INCLUDE_DIR @@ -318,6 +319,13 @@ install( PATTERN "*.h" PATTERN "*.hpp") +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/third_party/utfcpp/include/utfcpp/ + DESTINATION include/utfcpp/ + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.hpp") + install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/third_party/xxHash/include/xxHash/ DESTINATION include/xxHash/ diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 315c9d831..557d9a8d6 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -81,7 +81,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { // text_splits's length should be greater than images' length. // text_splits.size() - images.size() = 1 for (int idx = 0; idx < std::min(images.size(), text_splits.size()); ++idx) { - auto tokenized_sep = tokenizer.convert2VectorIds(tokenizer.tokenize(text_splits[idx])); + auto tokenized_sep = tokenizer.tokenize(text_splits[idx]); tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); for (int _i = 0; _i < tokenized_sep.size(); ++_i) { images_seq_mask.emplace_back(0); // emplace_back(false) @@ -160,7 +160,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { } // Processing last text split - auto tokenized_sep = tokenizer.convert2VectorIds(tokenizer.tokenize(text_splits.back())); + auto tokenized_sep = tokenizer.tokenize(text_splits.back()); tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); images_seq_mask.insert(images_seq_mask.end(), tokenized_sep.size(), false); diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index 28ff2f045..e745effc9 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -13,14 +13,14 @@ #include #include -#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/BPEUTF8.hpp" #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" namespace mllm::models::deepseek_ocr { // Actually is LlamaTokenizer -class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { +class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { public: explicit DpskOcrTokenizer(const std::string& file_path) { // Init @@ -37,46 +37,14 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { special_tokens_trie_.add(L"<|▁pad▁|>"); } - std::vector _tokenize(const std::string& str) override { - // Replace spaces with SentencePiece underline before processing + std::vector tokenize(const std::string& str) override { // TODO - - auto processed_tokens = preTokenize(preprocessor::utf8string2WideString(str)); - std::vector ret; - - return ret; - } - - std::vector tokenize(const std::string& str) override { return _tokenize(str); } - - std::wstring _detokenize(int64_t pos_idx) override { - // TODO - return L""; + return {}; } - std::wstring detokenize(int64_t pos_idx) override { + std::string detokenize(int64_t pos_idx) override { // TODO - return _detokenize(pos_idx); - } - - std::vector convert2VectorIds(const std::vector& strs) { - std::vector ids; - ids.reserve(strs.size()); - for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } - return ids; - } - - Tensor convert2Ids(const std::vector& strs) override { - std::vector ids; - ids.reserve(strs.size()); - for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } - Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) - .setMemType(kExtraInput) - .setName("llama-tokenizer-i0") - .alloc(); - auto ptr = ret.ptr(); - for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } - return ret; + return ""; } private: @@ -116,89 +84,150 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizer { // } // ] // } - std::vector preTokenize(const std::wstring& str) { - std::vector result; - size_t pos = 0; - - while (pos < str.size()) { - std::wstring matched; - bool found_match = false; - - // Pattern 1: Match 1-3 consecutive digits (\p{N}{1,3}) - if (preprocessor::isDigit(str[pos])) { - size_t start = pos; - size_t count = 0; - while (pos < str.size() && preprocessor::isDigit(str[pos]) && count < 3) { - ++pos; - ++count; + std::vector preprocessToken(const std::string& token) { + std::vector out; + auto it = token.begin(); + auto end = token.end(); + + while (it != end) { + auto seg_start = it; + int digit_cnt = 0; + auto tmp = it; + while (digit_cnt < 3) { + uint32_t cp = 0; + auto next = tmp; + utf8::next(next, end); + if (next == tmp) break; + cp = utf8::peek_next(tmp, end); + if (!is_digit(cp)) break; + tmp = next; + ++digit_cnt; + } + if (digit_cnt > 0) { + out.emplace_back(seg_start, tmp); + it = tmp; + continue; + } + + uint32_t cp = utf8::peek_next(it, end); + if (is_cjk(cp)) { + auto tmp2 = it; + while (tmp2 != end) { + uint32_t nxt = utf8::peek_next(tmp2, end); + if (!is_cjk(nxt)) break; + utf8::next(tmp2, end); + } + out.emplace_back(seg_start, tmp2); + it = tmp2; + continue; + } + + if (is_punct_symbol(cp)) { + auto tmp3 = it; + utf8::next(tmp3, end); + if (tmp3 != end && is_letter(utf8::peek_next(tmp3, end))) { + utf8::next(tmp3, end); + out.emplace_back(seg_start, tmp3); + it = tmp3; + continue; } - matched = str.substr(start, count); - found_match = true; } - // Pattern 2: Match CJK characters ([一-龥぀-ゟ゠-ヿ]+) - else if ((str[pos] >= L'一' && str[pos] <= L'龥') || // Chinese characters - (str[pos] >= L'぀' && str[pos] <= L'ゟ') || // Hiragana - (str[pos] >= L'゠' && str[pos] <= L'ヿ')) { // Katakana - size_t start = pos; - while (pos < str.size() - && ((str[pos] >= L'一' && str[pos] <= L'龥') || (str[pos] >= L'぀' && str[pos] <= L'ゟ') - || (str[pos] >= L'゠' && str[pos] <= L'ヿ'))) { - ++pos; + + if (!is_letter(cp) && !is_space(cp) && !is_punct_symbol(cp)) { + auto tmp3 = it; + utf8::next(tmp3, end); + if (tmp3 != end && is_letter(utf8::peek_next(tmp3, end))) { + while (tmp3 != end) { + uint32_t nxt = utf8::peek_next(tmp3, end); + if (!is_letter(nxt)) break; + utf8::next(tmp3, end); + } + out.emplace_back(seg_start, tmp3); + it = tmp3; + continue; + } + } + + if (is_punct_symbol(cp)) { + auto tmp3 = it; + while (tmp3 != end) { + uint32_t nxt = utf8::peek_next(tmp3, end); + if (!is_punct_symbol(nxt)) break; + utf8::next(tmp3, end); + } + + while (tmp3 != end) { + uint32_t nxt = utf8::peek_next(tmp3, end); + if (nxt != 0x0A && nxt != 0x0D) break; + utf8::next(tmp3, end); } - matched = str.substr(start, pos - start); - found_match = true; + out.emplace_back(seg_start, tmp3); + it = tmp3; + continue; } - // Pattern 3: Complex pattern for other characters - // [!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| - // ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+ - else { - // Handle punctuation followed by letters - if ((str[pos] >= L'!' && str[pos] <= L'/') || (str[pos] >= L':' && str[pos] <= L'@') - || (str[pos] >= L'[' && str[pos] <= L'`') || (str[pos] >= L'{' && str[pos] <= L'~')) { - size_t start = pos; - ++pos; // consume the punctuation - // Check if followed by letters - if (pos < str.size() && preprocessor::isLetter(str[pos])) { - while (pos < str.size() && preprocessor::isLetter(str[pos])) { ++pos; } - matched = str.substr(start, pos - start); - found_match = true; + + if (is_space(cp)) { + auto tmp3 = it; + bool has_nl = false; + while (tmp3 != end) { + uint32_t nxt = utf8::peek_next(tmp3, end); + if (nxt == 0x0A || nxt == 0x0D) { + has_nl = true; + utf8::next(tmp3, end); + } else if (is_space(nxt)) { + utf8::next(tmp3, end); } else { - pos = start + 1; // just consume the punctuation character - matched = str.substr(start, 1); - found_match = true; + break; } } - // Handle letters with optional prefix - else if (preprocessor::isLetter(str[pos])) { - size_t start = pos; - while (pos < str.size() && preprocessor::isLetter(str[pos])) { ++pos; } - matched = str.substr(start, pos - start); - found_match = true; + if (has_nl) { + out.emplace_back(seg_start, tmp3); + it = tmp3; + continue; } - // Handle whitespace - else if (std::iswspace(str[pos])) { - size_t start = pos; - while (pos < str.size() && std::iswspace(str[pos])) { ++pos; } - matched = str.substr(start, pos - start); - found_match = true; + auto tmp4 = tmp3; + while (tmp4 != end && is_space(utf8::peek_next(tmp4, end))) utf8::next(tmp4, end); + if (tmp4 == end) { + out.emplace_back(seg_start, tmp4); + it = tmp4; + continue; } - // Handle any other character - else { - matched = str.substr(pos, 1); - ++pos; - found_match = true; + while (tmp3 != end) { + uint32_t nxt = utf8::peek_next(tmp3, end); + if (!is_space(nxt)) break; + utf8::next(tmp3, end); } + out.emplace_back(seg_start, tmp3); + it = tmp3; + continue; } - - // Add matched string to result - if (found_match) { result.push_back(matched); } + utf8::next(it, end); + out.emplace_back(seg_start, it); } - return result; + return out; + } + + static inline bool is_digit(uint32_t cp) { return cp >= 0x30 && cp <= 0x39; } + + static inline bool is_cjk(uint32_t cp) { + return (cp >= 0x4E00 && cp <= 0x9FFF) || // CJK Unified Ideographs + (cp >= 0x3400 && cp <= 0x4DBF) || // CJK Extension A + (cp >= 0xF900 && cp <= 0xFAFF) || // CJK Compatibility + (cp >= 0x3040 && cp <= 0x309F) || // Hiragana + (cp >= 0x30A0 && cp <= 0x30FF); // Katakana } + static inline bool is_letter(uint32_t cp) { return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A); } + + static inline bool is_punct_symbol(uint32_t cp) { + return (cp >= 0x21 && cp <= 0x2F) || (cp >= 0x3A && cp <= 0x40) || (cp >= 0x5B && cp <= 0x60) || (cp >= 0x7B && cp <= 0x7E); + } + + static inline bool is_space(uint32_t cp) { return cp == 0x20 || cp == 0x09 || cp == 0x0A || cp == 0x0D; } + // For text - preprocessor::BPE bpe_; - std::wstring SPIECE_UNDERLINE = L"▁"; + preprocessor::BPEUTF8 bpe_; + std::string SPIECE_UNDERLINE = "▁"; }; } // namespace mllm::models::deepseek_ocr diff --git a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp index e3f76b884..7dbe6af2c 100644 --- a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp +++ b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp @@ -16,6 +16,8 @@ #include using json = nlohmann::json; +#include + #include "mllm/core/Tensor.hpp" #include @@ -70,4 +72,16 @@ class AutoTokenizer { Trie special_tokens_trie_; }; -} // namespace mllm::preprocessor \ No newline at end of file +class AutoTokenizerUTF8 { + public: + void addSpecialToken(const std::string& special_token); + + virtual std::vector tokenize(const std::string& str) = 0; + + virtual std::string detokenize(int64_t pos_idx) = 0; + + protected: + Trie special_tokens_trie_; +}; + +} // namespace mllm::preprocessor diff --git a/mllm/preprocessor/tokenizers/BPEUTF8.cpp b/mllm/preprocessor/tokenizers/BPEUTF8.cpp new file mode 100644 index 000000000..da9e53637 --- /dev/null +++ b/mllm/preprocessor/tokenizers/BPEUTF8.cpp @@ -0,0 +1,163 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include +#include + +#include "mllm/utils/Common.hpp" +#include "mllm/preprocessor/tokenizers/BPEUTF8.hpp" + +namespace mllm::preprocessor { +bool BPEUTF8::initFromSentencePieceJson(const std::string& file_path) { + std::ifstream f(file_path); + if (!f.is_open()) { + MLLM_ERROR("BPEUTF8 Cannot open file {}", file_path); + return false; + } + auto json_data = nlohmann::json::parse(f); + + if (!json_data.contains("model") || !json_data["model"].contains("vocab") || !json_data["model"].contains("merges")) { + MLLM_ERROR("BPEUTF8 initFromSentencePieceJson need sentence piece json, but get {}", file_path); + return false; + } + + for (const auto& [key, value] : json_data["model"]["vocab"].items()) { + vocab_.insert({ + key, + value, + }); + vocab_inverse_.insert({ + value, + key, + }); + } + + for (const auto& add_token : json_data["added_tokens"].items()) { + int64_t id = add_token.value()["id"]; + std::string content = add_token.value()["content"]; + vocab_.insert({ + content, + id, + }); + vocab_inverse_.insert({ + id, + content, + }); + } + + int64_t cnt = 0; + for (auto& merge_item : json_data["model"]["merges"]) { + if (merge_item.is_string()) { + std::string wide_merge_item = merge_item; + + // 0x20 will only represent space in utf8. we can use this unsafe method to speed up. + auto blank_pos = wide_merge_item.find(' '); + auto first = wide_merge_item.substr(0, blank_pos); + auto second = wide_merge_item.substr(blank_pos + 1); + bpe_ranks_.insert({{first, second}, cnt++}); + } else if (merge_item.is_array()) { + bpe_ranks_.insert({{merge_item[0], merge_item[1]}, cnt++}); + } + } + + return true; +} + +// ByteLevel BPE +std::vector BPEUTF8::_bpe(const std::string& token) { + // Treats token as a sequence of bytes + std::vector word; + for (const auto& w : token) word.emplace_back(1, w); + + auto pairs = _get_pairs(word); + if (pairs.empty()) return {token}; + + while (true) { + bool has_bigram = false; + int64_t rank_bigram = std::numeric_limits::max(); + std::pair bigram; + + for (const auto& p : pairs) { + if (bpe_ranks_.count(p)) { + auto rank = bpe_ranks_.at(p); + if (rank < rank_bigram) { + rank_bigram = rank; + bigram = p; + has_bigram = true; + } + } + } + + if (!has_bigram) { break; } + + auto [first, second] = bigram; + std::vector new_word; + int i = 0; + + while (i < word.size()) { + // Find the next occurrence of 'first' starting at i + int j = i; + while (j < word.size() && word[j] != first) { j++; } + + // Add elements from i to j-1 (if any) + if (j > i) { new_word.insert(new_word.end(), word.begin() + i, word.begin() + j); } + + // Check if we can merge at position j + if (j < word.size() - 1 && word[j] == first && word[j + 1] == second) { + new_word.push_back(first + second); + i = j + 2; // Skip both merged elements + } else if (j < word.size()) { + new_word.push_back(word[j]); + i = j + 1; + } else { + i = j; // j == word.size() + } + } + + word = std::move(new_word); + if (word.size() == 1) { + break; + } else { + pairs = _get_pairs(word); + } + } + + return word; +} + +int64_t BPEUTF8::_lookup_vocab(const std::string& token) { + if (vocab_.find(token) != vocab_.end()) { + return vocab_[token]; + } else { + MLLM_WARN("Cannot find token: {} in BPEUTF8 vocab", token); + return 0; + } +} + +std::string BPEUTF8::_lookup_inverse_vocab(int64_t idx) { + if (vocab_inverse_.find(idx) != vocab_inverse_.end()) { + return vocab_inverse_[idx]; + } else { + MLLM_WARN("Cannot find token in BPEUTF8 vocab. When doing _lookup_inverse_vocab"); + return {}; + } +} + +std::unordered_set, BPEUTF8PairHash> BPEUTF8::_get_pairs( + const std::vector& word) { + std::unordered_set, BPEUTF8PairHash> pairs; + if (word.size() < 2) return pairs; + auto prev_char = word[0]; + for (size_t i = 1; i < word.size(); ++i) { + pairs.insert({prev_char, word[i]}); + prev_char = word[i]; + } + return pairs; +} +} // namespace mllm::preprocessor diff --git a/mllm/preprocessor/tokenizers/BPEUTF8.hpp b/mllm/preprocessor/tokenizers/BPEUTF8.hpp new file mode 100644 index 000000000..4a2f24457 --- /dev/null +++ b/mllm/preprocessor/tokenizers/BPEUTF8.hpp @@ -0,0 +1,50 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +// TODO +// Documents +// This Byte-Level BPE(BBPE) works as an fully correct byte-level BPE tokenizer. + +#include +#include +#include +// CPP's support of UTF-8 is weak. We use the utfcpp library to handle UTF-8 strings. +#include +#include + +// Remember: +// utfcpp use +// std::string to represent UTF-8 strings. +// std::u16string to represent UTF-16 strings. +// std::u32string to represent UTF-32 strings. + +namespace mllm::preprocessor { + +struct BPEUTF8PairHash { + std::size_t operator()(const std::pair& key) const { + std::size_t h1 = std::hash{}(key.first + key.second); + return h1; + } +}; + +class BPEUTF8 { + public: + // BPE can accept sentence piece's json foramt. + bool initFromSentencePieceJson(const std::string& file_path); + + std::vector _bpe(const std::string& token); + + int64_t _lookup_vocab(const std::string& token); + + std::string _lookup_inverse_vocab(int64_t idx); + + private: + std::unordered_set, BPEUTF8PairHash> _get_pairs(const std::vector& word); + + std::unordered_map vocab_; + std::unordered_map vocab_inverse_; + std::unordered_map, int64_t, BPEUTF8PairHash> bpe_ranks_; +}; + +} // namespace mllm::preprocessor diff --git a/third_party/utfcpp/include/utfcpp/utf8.h b/third_party/utfcpp/include/utfcpp/utf8.h new file mode 100644 index 000000000..b51353093 --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8.h @@ -0,0 +1,46 @@ +// Copyright 2006 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_2675DCD0_9480_4c0c_B92A_CC14C027B731 +#define UTF8_FOR_CPP_2675DCD0_9480_4c0c_B92A_CC14C027B731 + +/* +To control the C++ language version used by the library, you can define UTF_CPP_CPLUSPLUS macro +and set it to one of the values used by the __cplusplus predefined macro. + +For instance, + #define UTF_CPP_CPLUSPLUS 199711L +will cause the UTF-8 CPP library to use only types and language features available in the C++ 98 standard. +Some library features will be disabled. + +If you leave UTF_CPP_CPLUSPLUS undefined, it will be internally assigned to __cplusplus. +*/ + +#include "utf8/checked.h" +#include "utf8/unchecked.h" + +#endif // header guard diff --git a/third_party/utfcpp/include/utfcpp/utf8/checked.h b/third_party/utfcpp/include/utfcpp/utf8/checked.h new file mode 100644 index 000000000..96ceb4d50 --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/checked.h @@ -0,0 +1,359 @@ +// Copyright 2006-2016 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_CHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 +#define UTF8_FOR_CPP_CHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 + +#include "core.h" +#include + +namespace utf8 +{ + // Base for the exceptions that may be thrown from the library + class exception : public ::std::exception { + }; + + // Exceptions that may be thrown from the library functions. + class invalid_code_point : public exception { + utfchar32_t cp; + public: + invalid_code_point(utfchar32_t codepoint) : cp(codepoint) {} + virtual const char* what() const UTF_CPP_NOEXCEPT UTF_CPP_OVERRIDE { return "Invalid code point"; } + utfchar32_t code_point() const {return cp;} + }; + + class invalid_utf8 : public exception { + utfchar8_t u8; + public: + invalid_utf8 (utfchar8_t u) : u8(u) {} + invalid_utf8 (char c) : u8(static_cast(c)) {} + virtual const char* what() const UTF_CPP_NOEXCEPT UTF_CPP_OVERRIDE { return "Invalid UTF-8"; } + utfchar8_t utf8_octet() const {return u8;} + }; + + class invalid_utf16 : public exception { + utfchar16_t u16; + public: + invalid_utf16 (utfchar16_t u) : u16(u) {} + virtual const char* what() const UTF_CPP_NOEXCEPT UTF_CPP_OVERRIDE { return "Invalid UTF-16"; } + utfchar16_t utf16_word() const {return u16;} + }; + + class not_enough_room : public exception { + public: + virtual const char* what() const UTF_CPP_NOEXCEPT UTF_CPP_OVERRIDE { return "Not enough space"; } + }; + + /// The library API - functions intended to be called by the users + + template + octet_iterator append(utfchar32_t cp, octet_iterator result) + { + if (!utf8::internal::is_code_point_valid(cp)) + throw invalid_code_point(cp); + + return internal::append(cp, result); + } + + inline void append(utfchar32_t cp, std::string& s) + { + append(cp, std::back_inserter(s)); + } + + template + word_iterator append16(utfchar32_t cp, word_iterator result) + { + if (!utf8::internal::is_code_point_valid(cp)) + throw invalid_code_point(cp); + + return internal::append16(cp, result); + } + + template + output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out, utfchar32_t replacement) + { + while (start != end) { + octet_iterator sequence_start = start; + internal::utf_error err_code = utf8::internal::validate_next(start, end); + switch (err_code) { + case internal::UTF8_OK : + for (octet_iterator it = sequence_start; it != start; ++it) + *out++ = *it; + break; + case internal::NOT_ENOUGH_ROOM: + out = utf8::append (replacement, out); + start = end; + break; + case internal::INVALID_LEAD: + out = utf8::append (replacement, out); + ++start; + break; + case internal::INCOMPLETE_SEQUENCE: + case internal::OVERLONG_SEQUENCE: + case internal::INVALID_CODE_POINT: + out = utf8::append (replacement, out); + ++start; + // just one replacement mark for the sequence + while (start != end && utf8::internal::is_trail(*start)) + ++start; + break; + } + } + return out; + } + + template + inline output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out) + { + static const utfchar32_t replacement_marker = static_cast(utf8::internal::mask16(0xfffd)); + return utf8::replace_invalid(start, end, out, replacement_marker); + } + + inline std::string replace_invalid(const std::string& s, utfchar32_t replacement) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result), replacement); + return result; + } + + inline std::string replace_invalid(const std::string& s) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + template + utfchar32_t next(octet_iterator& it, octet_iterator end) + { + utfchar32_t cp = 0; + internal::utf_error err_code = utf8::internal::validate_next(it, end, cp); + switch (err_code) { + case internal::UTF8_OK : + break; + case internal::NOT_ENOUGH_ROOM : + throw not_enough_room(); + case internal::INVALID_LEAD : + case internal::INCOMPLETE_SEQUENCE : + case internal::OVERLONG_SEQUENCE : + throw invalid_utf8(static_cast(*it)); + case internal::INVALID_CODE_POINT : + throw invalid_code_point(cp); + } + return cp; + } + + template + utfchar32_t next16(word_iterator& it, word_iterator end) + { + utfchar32_t cp = 0; + internal::utf_error err_code = utf8::internal::validate_next16(it, end, cp); + if (err_code == internal::NOT_ENOUGH_ROOM) + throw not_enough_room(); + return cp; + } + + template + utfchar32_t peek_next(octet_iterator it, octet_iterator end) + { + return utf8::next(it, end); + } + + template + utfchar32_t prior(octet_iterator& it, octet_iterator start) + { + // can't do much if it == start + if (it == start) + throw not_enough_room(); + + octet_iterator end = it; + // Go back until we hit either a lead octet or start + while (utf8::internal::is_trail(*(--it))) + if (it == start) + throw invalid_utf8(*it); // error - no lead byte in the sequence + return utf8::peek_next(it, end); + } + + template + void advance (octet_iterator& it, distance_type n, octet_iterator end) + { + const distance_type zero(0); + if (n < zero) { + // backward + for (distance_type i = n; i < zero; ++i) + utf8::prior(it, end); + } else { + // forward + for (distance_type i = zero; i < n; ++i) + utf8::next(it, end); + } + } + + template + typename std::iterator_traits::difference_type + distance (octet_iterator first, octet_iterator last) + { + typename std::iterator_traits::difference_type dist; + for (dist = 0; first < last; ++dist) + utf8::next(first, last); + return dist; + } + + template + octet_iterator utf16to8 (u16bit_iterator start, u16bit_iterator end, octet_iterator result) + { + while (start != end) { + utfchar32_t cp = static_cast(utf8::internal::mask16(*start++)); + // Take care of surrogate pairs first + if (utf8::internal::is_lead_surrogate(cp)) { + if (start != end) { + const utfchar32_t trail_surrogate = static_cast(utf8::internal::mask16(*start++)); + if (utf8::internal::is_trail_surrogate(trail_surrogate)) + cp = (cp << 10) + trail_surrogate + internal::SURROGATE_OFFSET; + else + throw invalid_utf16(static_cast(trail_surrogate)); + } + else + throw invalid_utf16(static_cast(cp)); + + } + // Lone trail surrogate + else if (utf8::internal::is_trail_surrogate(cp)) + throw invalid_utf16(static_cast(cp)); + + result = utf8::append(cp, result); + } + return result; + } + + template + u16bit_iterator utf8to16 (octet_iterator start, octet_iterator end, u16bit_iterator result) + { + while (start < end) { + const utfchar32_t cp = utf8::next(start, end); + if (cp > 0xffff) { //make a surrogate pair + *result++ = static_cast((cp >> 10) + internal::LEAD_OFFSET); + *result++ = static_cast((cp & 0x3ff) + internal::TRAIL_SURROGATE_MIN); + } + else + *result++ = static_cast(cp); + } + return result; + } + + template + octet_iterator utf32to8 (u32bit_iterator start, u32bit_iterator end, octet_iterator result) + { + while (start != end) + result = utf8::append(*(start++), result); + + return result; + } + + template + u32bit_iterator utf8to32 (octet_iterator start, octet_iterator end, u32bit_iterator result) + { + while (start < end) + (*result++) = utf8::next(start, end); + + return result; + } + + // The iterator class + template + class iterator { + octet_iterator it; + octet_iterator range_start; + octet_iterator range_end; + public: + typedef utfchar32_t value_type; + typedef utfchar32_t* pointer; + typedef utfchar32_t& reference; + typedef std::ptrdiff_t difference_type; + typedef std::bidirectional_iterator_tag iterator_category; + iterator () {} + explicit iterator (const octet_iterator& octet_it, + const octet_iterator& rangestart, + const octet_iterator& rangeend) : + it(octet_it), range_start(rangestart), range_end(rangeend) + { + if (it < range_start || it > range_end) + throw std::out_of_range("Invalid utf-8 iterator position"); + } + // the default "big three" are OK + octet_iterator base () const { return it; } + utfchar32_t operator * () const + { + octet_iterator temp = it; + return utf8::next(temp, range_end); + } + bool operator == (const iterator& rhs) const + { + if (range_start != rhs.range_start || range_end != rhs.range_end) + throw std::logic_error("Comparing utf-8 iterators defined with different ranges"); + return (it == rhs.it); + } + bool operator != (const iterator& rhs) const + { + return !(operator == (rhs)); + } + iterator& operator ++ () + { + utf8::next(it, range_end); + return *this; + } + iterator operator ++ (int) + { + iterator temp = *this; + utf8::next(it, range_end); + return temp; + } + iterator& operator -- () + { + utf8::prior(it, range_start); + return *this; + } + iterator operator -- (int) + { + iterator temp = *this; + utf8::prior(it, range_start); + return temp; + } + }; // class iterator + +} // namespace utf8 + +#if UTF_CPP_CPLUSPLUS >= 202002L // C++ 20 or later +#include "cpp20.h" +#elif UTF_CPP_CPLUSPLUS >= 201703L // C++ 17 or later +#include "cpp17.h" +#elif UTF_CPP_CPLUSPLUS >= 201103L // C++ 11 or later +#include "cpp11.h" +#endif // C++ 11 or later + +#endif //header guard + diff --git a/third_party/utfcpp/include/utfcpp/utf8/core.h b/third_party/utfcpp/include/utfcpp/utf8/core.h new file mode 100644 index 000000000..064f838fd --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/core.h @@ -0,0 +1,502 @@ +// Copyright 2006 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_CORE_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 +#define UTF8_FOR_CPP_CORE_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 + +#include +#include +#include + +// Determine the C++ standard version. +// If the user defines UTF_CPP_CPLUSPLUS, use that. +// Otherwise, trust the unreliable predefined macro __cplusplus + +#if !defined UTF_CPP_CPLUSPLUS + #define UTF_CPP_CPLUSPLUS __cplusplus +#endif + +#if UTF_CPP_CPLUSPLUS >= 201103L // C++ 11 or later + #define UTF_CPP_OVERRIDE override + #define UTF_CPP_NOEXCEPT noexcept + #define UTF_CPP_STATIC_ASSERT(condition) static_assert(condition, "UTFCPP static assert"); +#else // C++ 98/03 + #define UTF_CPP_OVERRIDE + #define UTF_CPP_NOEXCEPT throw() + // Simulate static_assert: + template struct StaticAssert {static void utf8_static_assert() {char static_assert_impl[Condition ? 1 : 0]; } }; + template <> struct StaticAssert {static void utf8_static_assert() {}}; + #define UTF_CPP_STATIC_ASSERT(condition) StaticAssert::utf8_static_assert(); +#endif // C++ 11 or later + + +namespace utf8 +{ +// The typedefs for 8-bit, 16-bit and 32-bit code units +#if UTF_CPP_CPLUSPLUS >= 201103L // C++ 11 or later + #if UTF_CPP_CPLUSPLUS >= 202002L // C++ 20 or later + typedef char8_t utfchar8_t; + #else // C++ 11/14/17 + typedef unsigned char utfchar8_t; + #endif + typedef char16_t utfchar16_t; + typedef char32_t utfchar32_t; +#else // C++ 98/03 + typedef unsigned char utfchar8_t; + typedef unsigned short utfchar16_t; + typedef unsigned int utfchar32_t; +#endif // C++ 11 or later + +// Helper code - not intended to be directly called by the library users. May be changed at any time +namespace internal +{ + // Unicode constants + // Leading (high) surrogates: 0xd800 - 0xdbff + // Trailing (low) surrogates: 0xdc00 - 0xdfff + const utfchar16_t LEAD_SURROGATE_MIN = 0xd800u; + const utfchar16_t LEAD_SURROGATE_MAX = 0xdbffu; + const utfchar16_t TRAIL_SURROGATE_MIN = 0xdc00u; + const utfchar16_t TRAIL_SURROGATE_MAX = 0xdfffu; + const utfchar16_t LEAD_OFFSET = 0xd7c0u; // LEAD_SURROGATE_MIN - (0x10000 >> 10) + const utfchar32_t SURROGATE_OFFSET = 0xfca02400u; // 0x10000u - (LEAD_SURROGATE_MIN << 10) - TRAIL_SURROGATE_MIN + + // Maximum valid value for a Unicode code point + const utfchar32_t CODE_POINT_MAX = 0x0010ffffu; + + template + inline utfchar8_t mask8(octet_type oc) + { + return static_cast(0xff & oc); + } + + template + inline utfchar16_t mask16(u16_type oc) + { + return static_cast(0xffff & oc); + } + + template + inline bool is_trail(octet_type oc) + { + return ((utf8::internal::mask8(oc) >> 6) == 0x2); + } + + inline bool is_lead_surrogate(utfchar32_t cp) + { + return (cp >= static_cast(LEAD_SURROGATE_MIN) && cp <= static_cast(LEAD_SURROGATE_MAX)); + } + + inline bool is_trail_surrogate(utfchar32_t cp) + { + return (cp >= static_cast(TRAIL_SURROGATE_MIN) && cp <= static_cast(TRAIL_SURROGATE_MAX)); + } + + inline bool is_surrogate(utfchar32_t cp) + { + return (cp >= static_cast(LEAD_SURROGATE_MIN) && cp <= static_cast(TRAIL_SURROGATE_MAX)); + } + + inline bool is_code_point_valid(utfchar32_t cp) + { + return (cp <= CODE_POINT_MAX && !utf8::internal::is_surrogate(cp)); + } + + inline bool is_in_bmp(utfchar32_t cp) + { + return cp < utfchar32_t(0x10000); + } + + template + int sequence_length(octet_iterator lead_it) + { + const utfchar8_t lead = utf8::internal::mask8(*lead_it); + if (lead < 0x80) + return 1; + else if ((lead >> 5) == 0x6) + return 2; + else if ((lead >> 4) == 0xe) + return 3; + else if ((lead >> 3) == 0x1e) + return 4; + else + return 0; + } + + inline bool is_overlong_sequence(utfchar32_t cp, int length) + { + if (cp < 0x80) { + if (length != 1) + return true; + } + else if (cp < 0x800) { + if (length != 2) + return true; + } + else if (cp < 0x10000) { + if (length != 3) + return true; + } + return false; + } + + enum utf_error {UTF8_OK, NOT_ENOUGH_ROOM, INVALID_LEAD, INCOMPLETE_SEQUENCE, OVERLONG_SEQUENCE, INVALID_CODE_POINT}; + + /// Helper for get_sequence_x + template + utf_error increase_safely(octet_iterator& it, const octet_iterator end) + { + if (++it == end) + return NOT_ENOUGH_ROOM; + + if (!utf8::internal::is_trail(*it)) + return INCOMPLETE_SEQUENCE; + + return UTF8_OK; + } + + #define UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(IT, END) {utf_error ret = increase_safely(IT, END); if (ret != UTF8_OK) return ret;} + + /// get_sequence_x functions decode utf-8 sequences of the length x + template + utf_error get_sequence_1(octet_iterator& it, octet_iterator end, utfchar32_t& code_point) + { + if (it == end) + return NOT_ENOUGH_ROOM; + + code_point = static_cast(utf8::internal::mask8(*it)); + + return UTF8_OK; + } + + template + utf_error get_sequence_2(octet_iterator& it, octet_iterator end, utfchar32_t& code_point) + { + if (it == end) + return NOT_ENOUGH_ROOM; + + code_point = static_cast(utf8::internal::mask8(*it)); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = ((code_point << 6) & 0x7ff) + ((*it) & 0x3f); + + return UTF8_OK; + } + + template + utf_error get_sequence_3(octet_iterator& it, octet_iterator end, utfchar32_t& code_point) + { + if (it == end) + return NOT_ENOUGH_ROOM; + + code_point = static_cast(utf8::internal::mask8(*it)); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = ((code_point << 12) & 0xffff) + ((utf8::internal::mask8(*it) << 6) & 0xfff); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = static_cast(code_point + ((*it) & 0x3f)); + + return UTF8_OK; + } + + template + utf_error get_sequence_4(octet_iterator& it, octet_iterator end, utfchar32_t& code_point) + { + if (it == end) + return NOT_ENOUGH_ROOM; + + code_point = static_cast(utf8::internal::mask8(*it)); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = ((code_point << 18) & 0x1fffff) + ((utf8::internal::mask8(*it) << 12) & 0x3ffff); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = static_cast(code_point + ((utf8::internal::mask8(*it) << 6) & 0xfff)); + + UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end) + + code_point = static_cast(code_point + ((*it) & 0x3f)); + + return UTF8_OK; + } + + #undef UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR + + template + utf_error validate_next(octet_iterator& it, octet_iterator end, utfchar32_t& code_point) + { + if (it == end) + return NOT_ENOUGH_ROOM; + + // Save the original value of it so we can go back in case of failure + // Of course, it does not make much sense with i.e. stream iterators + octet_iterator original_it = it; + + utfchar32_t cp = 0; + // Determine the sequence length based on the lead octet + const int length = utf8::internal::sequence_length(it); + + // Get trail octets and calculate the code point + utf_error err = UTF8_OK; + switch (length) { + case 0: + return INVALID_LEAD; + case 1: + err = utf8::internal::get_sequence_1(it, end, cp); + break; + case 2: + err = utf8::internal::get_sequence_2(it, end, cp); + break; + case 3: + err = utf8::internal::get_sequence_3(it, end, cp); + break; + case 4: + err = utf8::internal::get_sequence_4(it, end, cp); + break; + } + + if (err == UTF8_OK) { + // Decoding succeeded. Now, security checks... + if (utf8::internal::is_code_point_valid(cp)) { + if (!utf8::internal::is_overlong_sequence(cp, length)){ + // Passed! Return here. + code_point = cp; + ++it; + return UTF8_OK; + } + else + err = OVERLONG_SEQUENCE; + } + else + err = INVALID_CODE_POINT; + } + + // Failure branch - restore the original value of the iterator + it = original_it; + return err; + } + + template + inline utf_error validate_next(octet_iterator& it, octet_iterator end) { + utfchar32_t ignored; + return utf8::internal::validate_next(it, end, ignored); + } + + template + utf_error validate_next16(word_iterator& it, word_iterator end, utfchar32_t& code_point) + { + // Make sure the iterator dereferences a large enough type + typedef typename std::iterator_traits::value_type word_type; + UTF_CPP_STATIC_ASSERT(sizeof(word_type) >= sizeof(utfchar16_t)); + // Check the edge case: + if (it == end) + return NOT_ENOUGH_ROOM; + // Save the original value of it so we can go back in case of failure + // Of course, it does not make much sense with i.e. stream iterators + word_iterator original_it = it; + + utf_error err = UTF8_OK; + + const utfchar16_t first_word = *it++; + if (!is_surrogate(first_word)) { + code_point = first_word; + return UTF8_OK; + } + else { + if (it == end) + err = NOT_ENOUGH_ROOM; + else if (is_lead_surrogate(first_word)) { + const utfchar16_t second_word = *it++; + if (is_trail_surrogate(static_cast(second_word))) { + code_point = static_cast(first_word << 10) + static_cast(second_word) + SURROGATE_OFFSET; + return UTF8_OK; + } else + err = INCOMPLETE_SEQUENCE; + + } else { + err = INVALID_LEAD; + } + } + // error branch + it = original_it; + return err; + } + + // Internal implementation of both checked and unchecked append() function + // This function will be invoked by the overloads below, as they will know + // the octet_type. + template + octet_iterator append(utfchar32_t cp, octet_iterator result) { + if (cp < 0x80) // one octet + *(result++) = static_cast(cp); + else if (cp < 0x800) { // two octets + *(result++) = static_cast((cp >> 6) | 0xc0); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + else if (cp < 0x10000) { // three octets + *(result++) = static_cast((cp >> 12) | 0xe0); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + else { // four octets + *(result++) = static_cast((cp >> 18) | 0xf0); + *(result++) = static_cast(((cp >> 12) & 0x3f)| 0x80); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + return result; + } + + // One of the following overloads will be invoked from the API calls + + // A simple (but dangerous) case: the caller appends byte(s) to a char array + inline char* append(utfchar32_t cp, char* result) { + return append(cp, result); + } + + // Hopefully, most common case: the caller uses back_inserter + // i.e. append(cp, std::back_inserter(str)); + template + std::back_insert_iterator append + (utfchar32_t cp, std::back_insert_iterator result) { + return append, + typename container_type::value_type>(cp, result); + } + + // The caller uses some other kind of output operator - not covered above + // Note that in this case we are not able to determine octet_type + // so we assume it's utfchar8_t; that can cause a conversion warning if we are wrong. + template + octet_iterator append(utfchar32_t cp, octet_iterator result) { + return append(cp, result); + } + + // Internal implementation of both checked and unchecked append16() function + // This function will be invoked by the overloads below, as they will know + // the word_type. + template + word_iterator append16(utfchar32_t cp, word_iterator result) { + UTF_CPP_STATIC_ASSERT(sizeof(word_type) >= sizeof(utfchar16_t)); + if (is_in_bmp(cp)) + *(result++) = static_cast(cp); + else { + // Code points from the supplementary planes are encoded via surrogate pairs + *(result++) = static_cast(LEAD_OFFSET + (cp >> 10)); + *(result++) = static_cast(TRAIL_SURROGATE_MIN + (cp & 0x3FF)); + } + return result; + } + + // Hopefully, most common case: the caller uses back_inserter + // i.e. append16(cp, std::back_inserter(str)); + template + std::back_insert_iterator append16 + (utfchar32_t cp, std::back_insert_iterator result) { + return append16, + typename container_type::value_type>(cp, result); + } + + // The caller uses some other kind of output operator - not covered above + // Note that in this case we are not able to determine word_type + // so we assume it's utfchar16_t; that can cause a conversion warning if we are wrong. + template + word_iterator append16(utfchar32_t cp, word_iterator result) { + return append16(cp, result); + } + +} // namespace internal + + /// The library API - functions intended to be called by the users + + // Byte order mark + const utfchar8_t bom[] = {0xef, 0xbb, 0xbf}; + + template + octet_iterator find_invalid(octet_iterator start, octet_iterator end) + { + octet_iterator result = start; + while (result != end) { + utf8::internal::utf_error err_code = utf8::internal::validate_next(result, end); + if (err_code != internal::UTF8_OK) + return result; + } + return result; + } + + inline const char* find_invalid(const char* str) + { + const char* end = str + std::strlen(str); + return find_invalid(str, end); + } + + inline std::size_t find_invalid(const std::string& s) + { + std::string::const_iterator invalid = find_invalid(s.begin(), s.end()); + return (invalid == s.end()) ? std::string::npos : static_cast(invalid - s.begin()); + } + + template + inline bool is_valid(octet_iterator start, octet_iterator end) + { + return (utf8::find_invalid(start, end) == end); + } + + inline bool is_valid(const char* str) + { + return (*(utf8::find_invalid(str)) == '\0'); + } + + inline bool is_valid(const std::string& s) + { + return is_valid(s.begin(), s.end()); + } + + + + template + inline bool starts_with_bom (octet_iterator it, octet_iterator end) + { + return ( + ((it != end) && (utf8::internal::mask8(*it++)) == bom[0]) && + ((it != end) && (utf8::internal::mask8(*it++)) == bom[1]) && + ((it != end) && (utf8::internal::mask8(*it)) == bom[2]) + ); + } + + inline bool starts_with_bom(const std::string& s) + { + return starts_with_bom(s.begin(), s.end()); + } +} // namespace utf8 + +#endif // header guard + diff --git a/third_party/utfcpp/include/utfcpp/utf8/cpp11.h b/third_party/utfcpp/include/utfcpp/utf8/cpp11.h new file mode 100644 index 000000000..691633c84 --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/cpp11.h @@ -0,0 +1,70 @@ +// Copyright 2018 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_a184c22c_d012_11e8_a8d5_f2801f1b9fd1 +#define UTF8_FOR_CPP_a184c22c_d012_11e8_a8d5_f2801f1b9fd1 + +#include "checked.h" + +namespace utf8 +{ + inline void append16(utfchar32_t cp, std::u16string& s) + { + append16(cp, std::back_inserter(s)); + } + + inline std::string utf16to8(const std::u16string& s) + { + std::string result; + utf16to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u16string utf8to16(const std::string& s) + { + std::u16string result; + utf8to16(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::string utf32to8(const std::u32string& s) + { + std::string result; + utf32to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u32string utf8to32(const std::string& s) + { + std::u32string result; + utf8to32(s.begin(), s.end(), std::back_inserter(result)); + return result; + } +} // namespace utf8 + +#endif // header guard + diff --git a/third_party/utfcpp/include/utfcpp/utf8/cpp17.h b/third_party/utfcpp/include/utfcpp/utf8/cpp17.h new file mode 100644 index 000000000..075873003 --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/cpp17.h @@ -0,0 +1,96 @@ +// Copyright 2018 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_7e906c01_03a3_4daf_b420_ea7ea952b3c9 +#define UTF8_FOR_CPP_7e906c01_03a3_4daf_b420_ea7ea952b3c9 + +#include "cpp11.h" + +namespace utf8 +{ + inline std::string utf16to8(std::u16string_view s) + { + std::string result; + utf16to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u16string utf8to16(std::string_view s) + { + std::u16string result; + utf8to16(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::string utf32to8(std::u32string_view s) + { + std::string result; + utf32to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u32string utf8to32(std::string_view s) + { + std::u32string result; + utf8to32(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::size_t find_invalid(std::string_view s) + { + std::string_view::const_iterator invalid = find_invalid(s.begin(), s.end()); + return (invalid == s.end()) ? std::string_view::npos : static_cast(invalid - s.begin()); + } + + inline bool is_valid(std::string_view s) + { + return is_valid(s.begin(), s.end()); + } + + inline std::string replace_invalid(std::string_view s, char32_t replacement) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result), replacement); + return result; + } + + inline std::string replace_invalid(std::string_view s) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline bool starts_with_bom(std::string_view s) + { + return starts_with_bom(s.begin(), s.end()); + } + +} // namespace utf8 + +#endif // header guard + diff --git a/third_party/utfcpp/include/utfcpp/utf8/cpp20.h b/third_party/utfcpp/include/utfcpp/utf8/cpp20.h new file mode 100644 index 000000000..07b61d0fb --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/cpp20.h @@ -0,0 +1,124 @@ +// Copyright 2022 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_207e906c01_03a3_4daf_b420_ea7ea952b3c9 +#define UTF8_FOR_CPP_207e906c01_03a3_4daf_b420_ea7ea952b3c9 + +#include "cpp17.h" + +namespace utf8 +{ + inline std::u8string utf16tou8(const std::u16string& s) + { + std::u8string result; + utf16to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u8string utf16tou8(std::u16string_view s) + { + std::u8string result; + utf16to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u16string utf8to16(const std::u8string& s) + { + std::u16string result; + utf8to16(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u16string utf8to16(const std::u8string_view& s) + { + std::u16string result; + utf8to16(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u8string utf32tou8(const std::u32string& s) + { + std::u8string result; + utf32to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u8string utf32tou8(const std::u32string_view& s) + { + std::u8string result; + utf32to8(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u32string utf8to32(const std::u8string& s) + { + std::u32string result; + utf8to32(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::u32string utf8to32(const std::u8string_view& s) + { + std::u32string result; + utf8to32(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline std::size_t find_invalid(const std::u8string& s) + { + std::u8string::const_iterator invalid = find_invalid(s.begin(), s.end()); + return (invalid == s.end()) ? std::string_view::npos : static_cast(invalid - s.begin()); + } + + inline bool is_valid(const std::u8string& s) + { + return is_valid(s.begin(), s.end()); + } + + inline std::u8string replace_invalid(const std::u8string& s, char32_t replacement) + { + std::u8string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result), replacement); + return result; + } + + inline std::u8string replace_invalid(const std::u8string& s) + { + std::u8string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + inline bool starts_with_bom(const std::u8string& s) + { + return starts_with_bom(s.begin(), s.end()); + } + +} // namespace utf8 + +#endif // header guard + diff --git a/third_party/utfcpp/include/utfcpp/utf8/unchecked.h b/third_party/utfcpp/include/utfcpp/utf8/unchecked.h new file mode 100644 index 000000000..173d0302e --- /dev/null +++ b/third_party/utfcpp/include/utfcpp/utf8/unchecked.h @@ -0,0 +1,286 @@ +// Copyright 2006 Nemanja Trifunovic + +/* +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +*/ + + +#ifndef UTF8_FOR_CPP_UNCHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 +#define UTF8_FOR_CPP_UNCHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731 + +#include "core.h" + +namespace utf8 +{ + namespace unchecked + { + template + octet_iterator append(utfchar32_t cp, octet_iterator result) + { + return internal::append(cp, result); + } + + template + word_iterator append16(utfchar32_t cp, word_iterator result) + { + return internal::append16(cp, result); + } + + template + output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out, utfchar32_t replacement) + { + while (start != end) { + octet_iterator sequence_start = start; + internal::utf_error err_code = utf8::internal::validate_next(start, end); + switch (err_code) { + case internal::UTF8_OK : + for (octet_iterator it = sequence_start; it != start; ++it) + *out++ = *it; + break; + case internal::NOT_ENOUGH_ROOM: + out = utf8::unchecked::append(replacement, out); + start = end; + break; + case internal::INVALID_LEAD: + out = utf8::unchecked::append(replacement, out); + ++start; + break; + case internal::INCOMPLETE_SEQUENCE: + case internal::OVERLONG_SEQUENCE: + case internal::INVALID_CODE_POINT: + out = utf8::unchecked::append(replacement, out); + ++start; + // just one replacement mark for the sequence + while (start != end && utf8::internal::is_trail(*start)) + ++start; + break; + } + } + return out; + } + + template + inline output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out) + { + static const utfchar32_t replacement_marker = static_cast(utf8::internal::mask16(0xfffd)); + return utf8::unchecked::replace_invalid(start, end, out, replacement_marker); + } + + inline std::string replace_invalid(const std::string& s, utfchar32_t replacement) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result), replacement); + return result; + } + + inline std::string replace_invalid(const std::string& s) + { + std::string result; + replace_invalid(s.begin(), s.end(), std::back_inserter(result)); + return result; + } + + template + utfchar32_t next(octet_iterator& it) + { + utfchar32_t cp = utf8::internal::mask8(*it); + switch (utf8::internal::sequence_length(it)) { + case 1: + break; + case 2: + ++it; + cp = ((cp << 6) & 0x7ff) + ((*it) & 0x3f); + break; + case 3: + ++it; + cp = ((cp << 12) & 0xffff) + ((utf8::internal::mask8(*it) << 6) & 0xfff); + ++it; + cp = static_cast(cp + ((*it) & 0x3f)); + break; + case 4: + ++it; + cp = ((cp << 18) & 0x1fffff) + ((utf8::internal::mask8(*it) << 12) & 0x3ffff); + ++it; + cp = static_cast(cp + ((utf8::internal::mask8(*it) << 6) & 0xfff)); + ++it; + cp = static_cast(cp + ((*it) & 0x3f)); + break; + } + ++it; + return cp; + } + + template + utfchar32_t peek_next(octet_iterator it) + { + return utf8::unchecked::next(it); + } + + template + utfchar32_t next16(word_iterator& it) + { + utfchar32_t cp = utf8::internal::mask16(*it++); + if (utf8::internal::is_lead_surrogate(cp)) + return (cp << 10) + *it++ + utf8::internal::SURROGATE_OFFSET; + return cp; + } + + template + utfchar32_t prior(octet_iterator& it) + { + while (utf8::internal::is_trail(*(--it))) ; + octet_iterator temp = it; + return utf8::unchecked::next(temp); + } + + template + void advance(octet_iterator& it, distance_type n) + { + const distance_type zero(0); + if (n < zero) { + // backward + for (distance_type i = n; i < zero; ++i) + utf8::unchecked::prior(it); + } else { + // forward + for (distance_type i = zero; i < n; ++i) + utf8::unchecked::next(it); + } + } + + template + typename std::iterator_traits::difference_type + distance(octet_iterator first, octet_iterator last) + { + typename std::iterator_traits::difference_type dist; + for (dist = 0; first < last; ++dist) + utf8::unchecked::next(first); + return dist; + } + + template + octet_iterator utf16to8(u16bit_iterator start, u16bit_iterator end, octet_iterator result) + { + while (start != end) { + utfchar32_t cp = utf8::internal::mask16(*start++); + // Take care of surrogate pairs first + if (utf8::internal::is_lead_surrogate(cp)) { + if (start == end) + return result; + utfchar32_t trail_surrogate = utf8::internal::mask16(*start++); + cp = (cp << 10) + trail_surrogate + internal::SURROGATE_OFFSET; + } + result = utf8::unchecked::append(cp, result); + } + return result; + } + + template + u16bit_iterator utf8to16(octet_iterator start, octet_iterator end, u16bit_iterator result) + { + while (start < end) { + utfchar32_t cp = utf8::unchecked::next(start); + if (cp > 0xffff) { //make a surrogate pair + *result++ = static_cast((cp >> 10) + internal::LEAD_OFFSET); + *result++ = static_cast((cp & 0x3ff) + internal::TRAIL_SURROGATE_MIN); + } + else + *result++ = static_cast(cp); + } + return result; + } + + template + octet_iterator utf32to8(u32bit_iterator start, u32bit_iterator end, octet_iterator result) + { + while (start != end) + result = utf8::unchecked::append(*(start++), result); + + return result; + } + + template + u32bit_iterator utf8to32(octet_iterator start, octet_iterator end, u32bit_iterator result) + { + while (start < end) + (*result++) = utf8::unchecked::next(start); + + return result; + } + + // The iterator class + template + class iterator { + octet_iterator it; + public: + typedef utfchar32_t value_type; + typedef utfchar32_t* pointer; + typedef utfchar32_t& reference; + typedef std::ptrdiff_t difference_type; + typedef std::bidirectional_iterator_tag iterator_category; + iterator () {} + explicit iterator (const octet_iterator& octet_it): it(octet_it) {} + // the default "big three" are OK + octet_iterator base () const { return it; } + utfchar32_t operator * () const + { + octet_iterator temp = it; + return utf8::unchecked::next(temp); + } + bool operator == (const iterator& rhs) const + { + return (it == rhs.it); + } + bool operator != (const iterator& rhs) const + { + return !(operator == (rhs)); + } + iterator& operator ++ () + { + ::std::advance(it, utf8::internal::sequence_length(it)); + return *this; + } + iterator operator ++ (int) + { + iterator temp = *this; + ::std::advance(it, utf8::internal::sequence_length(it)); + return temp; + } + iterator& operator -- () + { + utf8::unchecked::prior(it); + return *this; + } + iterator operator -- (int) + { + iterator temp = *this; + utf8::unchecked::prior(it); + return temp; + } + }; // class iterator + + } // namespace utf8::unchecked +} // namespace utf8 + +#endif // header guard + From f76bce9ebdf6ade30cfbd323aa39d43650a2004e Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 24 Oct 2025 22:19:57 +0800 Subject: [PATCH 12/25] refactor(ext): replace tokenizers-cpp with tokenizers submodule - Remove tokenizers-cpp submodule from .gitmodules - Add new tokenizers submodule pointing to meta-pytorch/tokenizers - Update README.md to reflect the change in tokenizer dependency - Update subproject commit reference for the new tokenizers module --- .gitmodules | 8 ++++---- mllm/ext/README.md | 2 +- mllm/ext/vendors/tokenizers | 1 + mllm/ext/vendors/tokenizers-cpp | 1 - 4 files changed, 6 insertions(+), 6 deletions(-) create mode 160000 mllm/ext/vendors/tokenizers delete mode 160000 mllm/ext/vendors/tokenizers-cpp diff --git a/.gitmodules b/.gitmodules index e2523ea21..e05419d65 100644 --- a/.gitmodules +++ b/.gitmodules @@ -21,11 +21,11 @@ [submodule "mllm/ffi/vendors/tvm-ffi"] path = mllm/ffi/vendors/tvm-ffi url = https://github.com/apache/tvm-ffi -[submodule "mllm/ext/vendors/tokenizers-cpp"] - path = mllm/ext/vendors/tokenizers-cpp - url = https://github.com/mlc-ai/tokenizers-cpp - update = none [submodule "mllm/ext/vendors/llvm-project"] path = mllm/ext/vendors/llvm-project url = https://github.com/llvm/llvm-project update = none +[submodule "mllm/ext/vendors/tokenizers"] + path = mllm/ext/vendors/tokenizers + url = https://github.com/meta-pytorch/tokenizers.git + update = none diff --git a/mllm/ext/README.md b/mllm/ext/README.md index 20ca6e222..919c7ea9f 100644 --- a/mllm/ext/README.md +++ b/mllm/ext/README.md @@ -2,5 +2,5 @@ Mllm extension contains some third-party packages and it's warper to mllm. Those extensions are OPTIONAL for mllm main lib. -- tokenizer-cpp +- tokenizers: [tokenizers](https://github.com/meta-pytorch/tokenizers.git) - mobile opencv diff --git a/mllm/ext/vendors/tokenizers b/mllm/ext/vendors/tokenizers new file mode 160000 index 000000000..0bcd9f532 --- /dev/null +++ b/mllm/ext/vendors/tokenizers @@ -0,0 +1 @@ +Subproject commit 0bcd9f5325c42d2e05765d584131900dbc8835c8 diff --git a/mllm/ext/vendors/tokenizers-cpp b/mllm/ext/vendors/tokenizers-cpp deleted file mode 160000 index 55d53aa38..000000000 --- a/mllm/ext/vendors/tokenizers-cpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 55d53aa38dc8df7d9c8bd9ed50907e82ae83ce66 From 73b74f21637ce1d91bd8de8d6353107656e9f30d Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 25 Oct 2025 21:38:41 +0800 Subject: [PATCH 13/25] feat(tokenizer): implement deepseek-ocr tokenizer with BPE and UTF-8 support - Add `DpskOcrTokenizer` with byte-level BPE tokenization - Support encoding/decoding using UTF-8 code points - Integrate regex-based pre-tokenizer from GPT-2 standard - Update `AutoTokenizerUTF8` interface to include encode/decode methods - Refactor `BPEUTF8` to work with UTF-32 code point vectors - Add unicode processing utilities from llama.cpp - Enable meta-torch-tokenizers extension in CMake fix(tokenizer): replace tokenize with encode in deepseek-ocr model - Use `encode` instead of `tokenize` for proper id conversion - Adjust sequence mask generation accordingly build(cmake): add extension support and tokenizer options - Introduce `MLLM_EXT_ENABLE` option - Rename `MLLM_EXT_ENABLE_TOKENIZERS_CPP` to `MLLM_EXT_ENABLE_META_TORCH_TOKENIZERS` - Add subdirectory for external components when enabled test(tokenizer): add print statements for tokenizer debugging - Print tokenization and encoding results for various inputs - Include space, newline, and unicode characters chore(fmt): add formatter for vector and vector - Enable formatted printing of string and int64_t vectors - Support cleaner debug output in examples --- CMakeLists.txt | 3 +- examples/deepseek_ocr/main.cpp | 13 +- mllm/CMakeLists.txt | 5 + mllm/ext/CMakeLists.txt | 5 + mllm/mllm.inl | 38 + .../deepseek_ocr/modeling_deepseek_ocr.hpp | 4 +- .../tokenization_deepseek_ocr.hpp | 44 +- .../preprocessor/tokenizers/AutoTokenizer.hpp | 8 +- mllm/preprocessor/tokenizers/BPEUTF8.cpp | 48 +- mllm/preprocessor/tokenizers/BPEUTF8.hpp | 58 +- .../tokenizers/llama_cpp_unicode/README.md | 8 + .../llama_cpp_unicode/unicode-data.cpp | 1617 +++++++++++++++++ .../llama_cpp_unicode/unicode-data.h | 54 + .../tokenizers/llama_cpp_unicode/unicode.cpp | 1068 +++++++++++ .../tokenizers/llama_cpp_unicode/unicode.h | 90 + 15 files changed, 3025 insertions(+), 38 deletions(-) create mode 100644 mllm/preprocessor/tokenizers/llama_cpp_unicode/README.md create mode 100644 mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.cpp create mode 100644 mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.h create mode 100644 mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.cpp create mode 100644 mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a2f4d1bc..1e85ecf7f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,8 +23,9 @@ option(MLLM_BUILD_SDK_C_BINDING "Enable MLLM C SDK binding" OFF) option(MLLM_BUILD_EXPERIMENTS "Enable MLLM experiments" OFF) # Extension Enable +option(MLLM_EXT_ENABLE OFF) option(MLLM_EXT_ENABLE_LLVM_PROJECT OFF) -option(MLLM_EXT_ENABLE_TOKENIZERS_CPP OFF) +option(MLLM_EXT_ENABLE_META_TORCH_TOKENIZERS OFF) # CPU Backend: BLAS option(MLLM_USE_BLAS "Enable BLAS" OFF) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index fb3368b4b..9fb8a6771 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include "mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp" #include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" @@ -9,6 +7,17 @@ using mllm::Argparse; MLLM_MAIN({ auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(); auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/hf-models/DeepSeek-OCR/tokenizer.json"); + + mllm::print(tokenizer.tokenize(" ")); + mllm::print(tokenizer.tokenize("▁")); + mllm::print(tokenizer.tokenize("\n")); + mllm::print(tokenizer.tokenize("你好啊!")); + + mllm::print(tokenizer.encode(" ")); + mllm::print(tokenizer.encode("▁")); + mllm::print(tokenizer.encode("\n")); + mllm::print(tokenizer.encode("你好啊!")); + model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); }); diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index f28afdce2..d4728bfe0 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -105,3 +105,8 @@ if(MLLM_BUILD_CUDA_BACKEND) MLLM_CUDA_BACKEND ) endif() + +# Extension +if(MLLM_EXT_ENABLE) + add_subdirectory(ext) +endif() diff --git a/mllm/ext/CMakeLists.txt b/mllm/ext/CMakeLists.txt index e69de29bb..de1e27a24 100644 --- a/mllm/ext/CMakeLists.txt +++ b/mllm/ext/CMakeLists.txt @@ -0,0 +1,5 @@ +if (MLLM_EXT_ENABLE_META_TORCH_TOKENIZERS) + add_subdirectory(vendors/tokenizers) +endif() + +# LLVM MLIR Stuff. diff --git a/mllm/mllm.inl b/mllm/mllm.inl index 19410166a..cc9afbaec 100644 --- a/mllm/mllm.inl +++ b/mllm/mllm.inl @@ -111,6 +111,44 @@ struct formatter> { } }; +template<> +struct formatter> { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + template + auto format(const std::vector& vec, FormatContext& ctx) const { + auto out = ctx.out(); + *out++ = '['; + for (size_t i = 0; i < vec.size(); ++i) { + if (i > 0) { + *out++ = ','; + *out++ = ' '; + } + out = fmt::format_to(out, "\"{}\"", vec[i]); + } + *out++ = ']'; + return out; + } +}; + +template<> +struct formatter> { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + template + auto format(const std::vector& vec, FormatContext& ctx) const { + auto out = ctx.out(); + *out++ = '['; + for (size_t i = 0; i < vec.size(); ++i) { + if (i > 0) { + *out++ = ','; + *out++ = ' '; + } + out = fmt::format_to(out, "\"{}\"", vec[i]); + } + *out++ = ']'; + return out; + } +}; + template<> struct formatter { constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 557d9a8d6..6464263be 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -81,7 +81,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { // text_splits's length should be greater than images' length. // text_splits.size() - images.size() = 1 for (int idx = 0; idx < std::min(images.size(), text_splits.size()); ++idx) { - auto tokenized_sep = tokenizer.tokenize(text_splits[idx]); + auto tokenized_sep = tokenizer.encode(text_splits[idx]); tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); for (int _i = 0; _i < tokenized_sep.size(); ++_i) { images_seq_mask.emplace_back(0); // emplace_back(false) @@ -160,7 +160,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { } // Processing last text split - auto tokenized_sep = tokenizer.tokenize(text_splits.back()); + auto tokenized_sep = tokenizer.encode(text_splits.back()); tokenized_str.insert(tokenized_str.end(), tokenized_sep.begin(), tokenized_sep.end()); images_seq_mask.insert(images_seq_mask.end(), tokenized_sep.size(), false); diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index e745effc9..3358c079a 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -11,14 +11,22 @@ #include #include #include -#include #include "mllm/preprocessor/tokenizers/BPEUTF8.hpp" #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.h" namespace mllm::models::deepseek_ocr { +namespace details { + +// Standard GPT2 regex +// https://github.com/openai/gpt-2/blob/master/src/encoder.py#L53 +constexpr char GPT2_EXPR[] = R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)"; + +} // namespace details + // Actually is LlamaTokenizer class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { public: @@ -37,17 +45,43 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { special_tokens_trie_.add(L"<|▁pad▁|>"); } - std::vector tokenize(const std::string& str) override { + std::vector encode(const std::string& str) override { + auto sub_tokens = tokenize(str); + auto ret = std::vector{}; + for (auto& token : sub_tokens) { ret.emplace_back(bpe_._lookup_vocab(token)); } + return ret; + } + + std::string decode(const std::vector& ids) override { // TODO return {}; } - std::string detokenize(int64_t pos_idx) override { + std::vector tokenize(const std::string& str) override { + auto after_regex_process = regexPreTokenizer(str); + std::vector ret; + for (auto& ss : after_regex_process) { + auto after_bytes_process = byteLevelPreTokenizer(ss); + + // Perform BPE algorithm on each sub-token + for (auto& bbpe_str : after_bytes_process) { + auto bbpe_str_sub_tokens = bpe_._bpe(bbpe_str); + ret.insert(ret.end(), bbpe_str_sub_tokens.begin(), bbpe_str_sub_tokens.end()); + } + } + return ret; + } + + std::string detokenize(const std::vector& tokenized_str) override { // TODO - return ""; + return {}; } private: + std::vector byteLevelPreTokenizer(const std::string& str) { + return unicode_regex_split(str, {std::string{details::GPT2_EXPR}}); + } + // "pre_tokenizer": { // "type": "Sequence", // "pretokenizers": [ @@ -84,7 +118,7 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { // } // ] // } - std::vector preprocessToken(const std::string& token) { + std::vector regexPreTokenizer(const std::string& token) { std::vector out; auto it = token.begin(); auto end = token.end(); diff --git a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp index 7dbe6af2c..cc542f3ff 100644 --- a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp +++ b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp @@ -76,9 +76,13 @@ class AutoTokenizerUTF8 { public: void addSpecialToken(const std::string& special_token); - virtual std::vector tokenize(const std::string& str) = 0; + virtual std::vector encode(const std::string& str) = 0; - virtual std::string detokenize(int64_t pos_idx) = 0; + virtual std::string decode(const std::vector& ids) = 0; + + virtual std::vector tokenize(const std::string& str) = 0; + + virtual std::string detokenize(const std::vector& tokenized_str) = 0; protected: Trie special_tokens_trie_; diff --git a/mllm/preprocessor/tokenizers/BPEUTF8.cpp b/mllm/preprocessor/tokenizers/BPEUTF8.cpp index da9e53637..bd6be2a87 100644 --- a/mllm/preprocessor/tokenizers/BPEUTF8.cpp +++ b/mllm/preprocessor/tokenizers/BPEUTF8.cpp @@ -29,12 +29,12 @@ bool BPEUTF8::initFromSentencePieceJson(const std::string& file_path) { for (const auto& [key, value] : json_data["model"]["vocab"].items()) { vocab_.insert({ - key, + utf8String2Cpts(key), value, }); vocab_inverse_.insert({ value, - key, + utf8String2Cpts(key), }); } @@ -42,12 +42,12 @@ bool BPEUTF8::initFromSentencePieceJson(const std::string& file_path) { int64_t id = add_token.value()["id"]; std::string content = add_token.value()["content"]; vocab_.insert({ - content, + utf8String2Cpts(content), id, }); vocab_inverse_.insert({ id, - content, + utf8String2Cpts(content), }); } @@ -60,9 +60,9 @@ bool BPEUTF8::initFromSentencePieceJson(const std::string& file_path) { auto blank_pos = wide_merge_item.find(' '); auto first = wide_merge_item.substr(0, blank_pos); auto second = wide_merge_item.substr(blank_pos + 1); - bpe_ranks_.insert({{first, second}, cnt++}); + bpe_ranks_.insert({{utf8String2Cpts(first), utf8String2Cpts(second)}, cnt++}); } else if (merge_item.is_array()) { - bpe_ranks_.insert({{merge_item[0], merge_item[1]}, cnt++}); + bpe_ranks_.insert({{utf8String2Cpts(merge_item[0]), utf8String2Cpts(merge_item[1])}, cnt++}); } } @@ -71,9 +71,12 @@ bool BPEUTF8::initFromSentencePieceJson(const std::string& file_path) { // ByteLevel BPE std::vector BPEUTF8::_bpe(const std::string& token) { - // Treats token as a sequence of bytes - std::vector word; - for (const auto& w : token) word.emplace_back(1, w); + // Slice all tokens to word + std::vector word; + { + auto cpts = utf8String2Cpts(token); + for (auto cpt : cpts) { word.push_back(cpt_string_t{cpt}); } + } auto pairs = _get_pairs(word); if (pairs.empty()) return {token}; @@ -81,7 +84,7 @@ std::vector BPEUTF8::_bpe(const std::string& token) { while (true) { bool has_bigram = false; int64_t rank_bigram = std::numeric_limits::max(); - std::pair bigram; + std::pair bigram; for (const auto& p : pairs) { if (bpe_ranks_.count(p)) { @@ -97,7 +100,7 @@ std::vector BPEUTF8::_bpe(const std::string& token) { if (!has_bigram) { break; } auto [first, second] = bigram; - std::vector new_word; + std::vector new_word; int i = 0; while (i < word.size()) { @@ -110,7 +113,9 @@ std::vector BPEUTF8::_bpe(const std::string& token) { // Check if we can merge at position j if (j < word.size() - 1 && word[j] == first && word[j + 1] == second) { - new_word.push_back(first + second); + auto __merged = first; + __merged.insert(std::end(__merged), std::begin(second), std::end(second)); + new_word.push_back(__merged); i = j + 2; // Skip both merged elements } else if (j < word.size()) { new_word.push_back(word[j]); @@ -128,12 +133,17 @@ std::vector BPEUTF8::_bpe(const std::string& token) { } } - return word; + std::vector ret; + ret.reserve(word.size()); + for (auto& cpt : word) { ret.push_back(cpts2Utf8String(cpt)); } + + return ret; } int64_t BPEUTF8::_lookup_vocab(const std::string& token) { - if (vocab_.find(token) != vocab_.end()) { - return vocab_[token]; + auto cpts = utf8String2Cpts(token); + if (vocab_.find(cpts) != vocab_.end()) { + return vocab_[cpts]; } else { MLLM_WARN("Cannot find token: {} in BPEUTF8 vocab", token); return 0; @@ -142,16 +152,16 @@ int64_t BPEUTF8::_lookup_vocab(const std::string& token) { std::string BPEUTF8::_lookup_inverse_vocab(int64_t idx) { if (vocab_inverse_.find(idx) != vocab_inverse_.end()) { - return vocab_inverse_[idx]; + return cpts2Utf8String(vocab_inverse_[idx]); } else { MLLM_WARN("Cannot find token in BPEUTF8 vocab. When doing _lookup_inverse_vocab"); return {}; } } -std::unordered_set, BPEUTF8PairHash> BPEUTF8::_get_pairs( - const std::vector& word) { - std::unordered_set, BPEUTF8PairHash> pairs; +std::unordered_set, BPEUTF8PairHash> BPEUTF8::_get_pairs( + const std::vector& word) { + std::unordered_set, BPEUTF8PairHash> pairs; if (word.size() < 2) return pairs; auto prev_char = word[0]; for (size_t i = 1; i < word.size(); ++i) { diff --git a/mllm/preprocessor/tokenizers/BPEUTF8.hpp b/mllm/preprocessor/tokenizers/BPEUTF8.hpp index 4a2f24457..fc772a8c4 100644 --- a/mllm/preprocessor/tokenizers/BPEUTF8.hpp +++ b/mllm/preprocessor/tokenizers/BPEUTF8.hpp @@ -9,10 +9,14 @@ #include #include #include + // CPP's support of UTF-8 is weak. We use the utfcpp library to handle UTF-8 strings. #include #include +// Use XXHash +#include + // Remember: // utfcpp use // std::string to represent UTF-8 strings. @@ -21,15 +25,44 @@ namespace mllm::preprocessor { +namespace details { + +struct VectorUint32Hash { + std::size_t operator()(const std::vector& v) const noexcept { + if (v.empty()) return 0; + return static_cast(XXH64(v.data(), v.size() * sizeof(uint32_t), /*seed=*/0)); + } +}; + +} // namespace details + struct BPEUTF8PairHash { - std::size_t operator()(const std::pair& key) const { - std::size_t h1 = std::hash{}(key.first + key.second); - return h1; + std::size_t operator()(const std::pair, std::vector>& key) const noexcept { + const auto& a = key.first; + const auto& b = key.second; + + const std::size_t bytes_a = a.size() * sizeof(uint32_t); + const std::size_t bytes_b = b.size() * sizeof(uint32_t); + + if (bytes_a == 0 && bytes_b == 0) return 0; + + XXH64_state_t* state = XXH64_createState(); + if (!state) return 0; + XXH64_reset(state, /*seed=*/0); + + if (!a.empty()) XXH64_update(state, a.data(), bytes_a); + if (!b.empty()) XXH64_update(state, b.data(), bytes_b); + + std::size_t h = static_cast(XXH64_digest(state)); + XXH64_freeState(state); + return h; } }; class BPEUTF8 { public: + using cpt_string_t = std::vector; + // BPE can accept sentence piece's json foramt. bool initFromSentencePieceJson(const std::string& file_path); @@ -40,11 +73,22 @@ class BPEUTF8 { std::string _lookup_inverse_vocab(int64_t idx); private: - std::unordered_set, BPEUTF8PairHash> _get_pairs(const std::vector& word); + inline std::vector utf8String2Cpts(const std::string& str) { + std::vector word32; + utf8::utf8to32(str.begin(), str.end(), std::back_inserter(word32)); + return word32; + } + + inline std::string cpts2Utf8String(const std::vector& cpts) { + std::string str; + utf8::utf32to8(cpts.begin(), cpts.end(), std::back_inserter(str)); + return str; + } - std::unordered_map vocab_; - std::unordered_map vocab_inverse_; - std::unordered_map, int64_t, BPEUTF8PairHash> bpe_ranks_; + std::unordered_set, BPEUTF8PairHash> _get_pairs(const std::vector& word); + std::unordered_map vocab_; + std::unordered_map vocab_inverse_; + std::unordered_map, int64_t, BPEUTF8PairHash> bpe_ranks_; }; } // namespace mllm::preprocessor diff --git a/mllm/preprocessor/tokenizers/llama_cpp_unicode/README.md b/mllm/preprocessor/tokenizers/llama_cpp_unicode/README.md new file mode 100644 index 000000000..b7f405921 --- /dev/null +++ b/mllm/preprocessor/tokenizers/llama_cpp_unicode/README.md @@ -0,0 +1,8 @@ +# llama.cpp Unicode + +This is a vendored copy of the `unicode.h` and `unicode-data.h` modules from [llama.cpp](https://github.com/ggerganov/llama.cpp), along with their corresponding source files. The modules are held as vendored source rather than submodules since they are a small subset of the overall `llama.cpp` project. + +## Latest Update + +llama.cpp - commit 54ef9cfc +https://github.com/ggerganov/llama.cpp \ No newline at end of file diff --git a/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.cpp b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.cpp new file mode 100644 index 000000000..32b39afdb --- /dev/null +++ b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.cpp @@ -0,0 +1,1617 @@ +/* +llama.cpp - commit 54ef9cfc +https://github.com/ggerganov/llama.cpp + +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +// generated with scripts/gen-unicode-data.py + +#include "unicode-data.h" + +#include +#include +#include +#include + +const std::initializer_list> unicode_ranges_flags = { + // start, flags // last=next_start-1 + {0x000000, 0x0080}, {0x000020, 0x0008}, {0x000021, 0x0020}, {0x000024, 0x0040}, {0x000025, 0x0020}, {0x00002B, 0x0040}, + {0x00002C, 0x0020}, {0x000030, 0x0002}, {0x00003A, 0x0020}, {0x00003C, 0x0040}, {0x00003F, 0x0020}, {0x000041, 0x0004}, + {0x00005B, 0x0020}, {0x00005E, 0x0040}, {0x00005F, 0x0020}, {0x000060, 0x0040}, {0x000061, 0x0004}, {0x00007B, 0x0020}, + {0x00007C, 0x0040}, {0x00007D, 0x0020}, {0x00007E, 0x0040}, {0x00007F, 0x0080}, {0x0000A0, 0x0008}, {0x0000A1, 0x0020}, + {0x0000A2, 0x0040}, {0x0000A7, 0x0020}, {0x0000A8, 0x0040}, {0x0000AA, 0x0004}, {0x0000AB, 0x0020}, {0x0000AC, 0x0040}, + {0x0000AD, 0x0080}, {0x0000AE, 0x0040}, {0x0000B2, 0x0002}, {0x0000B4, 0x0040}, {0x0000B5, 0x0004}, {0x0000B6, 0x0020}, + {0x0000B8, 0x0040}, {0x0000B9, 0x0002}, {0x0000BA, 0x0004}, {0x0000BB, 0x0020}, {0x0000BC, 0x0002}, {0x0000BF, 0x0020}, + {0x0000C0, 0x0004}, {0x0000D7, 0x0040}, {0x0000D8, 0x0004}, {0x0000F7, 0x0040}, {0x0000F8, 0x0004}, {0x0002C2, 0x0040}, + {0x0002C6, 0x0004}, {0x0002D2, 0x0040}, {0x0002E0, 0x0004}, {0x0002E5, 0x0040}, {0x0002EC, 0x0004}, {0x0002ED, 0x0040}, + {0x0002EE, 0x0004}, {0x0002EF, 0x0040}, {0x000300, 0x0010}, {0x000370, 0x0004}, {0x000375, 0x0040}, {0x000376, 0x0004}, + {0x000378, 0x0001}, {0x00037A, 0x0004}, {0x00037E, 0x0020}, {0x00037F, 0x0004}, {0x000380, 0x0001}, {0x000384, 0x0040}, + {0x000386, 0x0004}, {0x000387, 0x0020}, {0x000388, 0x0004}, {0x00038B, 0x0001}, {0x00038C, 0x0004}, {0x00038D, 0x0001}, + {0x00038E, 0x0004}, {0x0003A2, 0x0001}, {0x0003A3, 0x0004}, {0x0003F6, 0x0040}, {0x0003F7, 0x0004}, {0x000482, 0x0040}, + {0x000483, 0x0010}, {0x00048A, 0x0004}, {0x000530, 0x0001}, {0x000531, 0x0004}, {0x000557, 0x0001}, {0x000559, 0x0004}, + {0x00055A, 0x0020}, {0x000560, 0x0004}, {0x000589, 0x0020}, {0x00058B, 0x0001}, {0x00058D, 0x0040}, {0x000590, 0x0001}, + {0x000591, 0x0010}, {0x0005BE, 0x0020}, {0x0005BF, 0x0010}, {0x0005C0, 0x0020}, {0x0005C1, 0x0010}, {0x0005C3, 0x0020}, + {0x0005C4, 0x0010}, {0x0005C6, 0x0020}, {0x0005C7, 0x0010}, {0x0005C8, 0x0001}, {0x0005D0, 0x0004}, {0x0005EB, 0x0001}, + {0x0005EF, 0x0004}, {0x0005F3, 0x0020}, {0x0005F5, 0x0001}, {0x000600, 0x0080}, {0x000606, 0x0040}, {0x000609, 0x0020}, + {0x00060B, 0x0040}, {0x00060C, 0x0020}, {0x00060E, 0x0040}, {0x000610, 0x0010}, {0x00061B, 0x0020}, {0x00061C, 0x0080}, + {0x00061D, 0x0020}, {0x000620, 0x0004}, {0x00064B, 0x0010}, {0x000660, 0x0002}, {0x00066A, 0x0020}, {0x00066E, 0x0004}, + {0x000670, 0x0010}, {0x000671, 0x0004}, {0x0006D4, 0x0020}, {0x0006D5, 0x0004}, {0x0006D6, 0x0010}, {0x0006DD, 0x0080}, + {0x0006DE, 0x0040}, {0x0006DF, 0x0010}, {0x0006E5, 0x0004}, {0x0006E7, 0x0010}, {0x0006E9, 0x0040}, {0x0006EA, 0x0010}, + {0x0006EE, 0x0004}, {0x0006F0, 0x0002}, {0x0006FA, 0x0004}, {0x0006FD, 0x0040}, {0x0006FF, 0x0004}, {0x000700, 0x0020}, + {0x00070E, 0x0001}, {0x00070F, 0x0080}, {0x000710, 0x0004}, {0x000711, 0x0010}, {0x000712, 0x0004}, {0x000730, 0x0010}, + {0x00074B, 0x0001}, {0x00074D, 0x0004}, {0x0007A6, 0x0010}, {0x0007B1, 0x0004}, {0x0007B2, 0x0001}, {0x0007C0, 0x0002}, + {0x0007CA, 0x0004}, {0x0007EB, 0x0010}, {0x0007F4, 0x0004}, {0x0007F6, 0x0040}, {0x0007F7, 0x0020}, {0x0007FA, 0x0004}, + {0x0007FB, 0x0001}, {0x0007FD, 0x0010}, {0x0007FE, 0x0040}, {0x000800, 0x0004}, {0x000816, 0x0010}, {0x00081A, 0x0004}, + {0x00081B, 0x0010}, {0x000824, 0x0004}, {0x000825, 0x0010}, {0x000828, 0x0004}, {0x000829, 0x0010}, {0x00082E, 0x0001}, + {0x000830, 0x0020}, {0x00083F, 0x0001}, {0x000840, 0x0004}, {0x000859, 0x0010}, {0x00085C, 0x0001}, {0x00085E, 0x0020}, + {0x00085F, 0x0001}, {0x000860, 0x0004}, {0x00086B, 0x0001}, {0x000870, 0x0004}, {0x000888, 0x0040}, {0x000889, 0x0004}, + {0x00088F, 0x0001}, {0x000890, 0x0080}, {0x000892, 0x0001}, {0x000898, 0x0010}, {0x0008A0, 0x0004}, {0x0008CA, 0x0010}, + {0x0008E2, 0x0080}, {0x0008E3, 0x0010}, {0x000904, 0x0004}, {0x00093A, 0x0010}, {0x00093D, 0x0004}, {0x00093E, 0x0010}, + {0x000950, 0x0004}, {0x000951, 0x0010}, {0x000958, 0x0004}, {0x000962, 0x0010}, {0x000964, 0x0020}, {0x000966, 0x0002}, + {0x000970, 0x0020}, {0x000971, 0x0004}, {0x000981, 0x0010}, {0x000984, 0x0001}, {0x000985, 0x0004}, {0x00098D, 0x0001}, + {0x00098F, 0x0004}, {0x000991, 0x0001}, {0x000993, 0x0004}, {0x0009A9, 0x0001}, {0x0009AA, 0x0004}, {0x0009B1, 0x0001}, + {0x0009B2, 0x0004}, {0x0009B3, 0x0001}, {0x0009B6, 0x0004}, {0x0009BA, 0x0001}, {0x0009BC, 0x0010}, {0x0009BD, 0x0004}, + {0x0009BE, 0x0010}, {0x0009C5, 0x0001}, {0x0009C7, 0x0010}, {0x0009C9, 0x0001}, {0x0009CB, 0x0010}, {0x0009CE, 0x0004}, + {0x0009CF, 0x0001}, {0x0009D7, 0x0010}, {0x0009D8, 0x0001}, {0x0009DC, 0x0004}, {0x0009DE, 0x0001}, {0x0009DF, 0x0004}, + {0x0009E2, 0x0010}, {0x0009E4, 0x0001}, {0x0009E6, 0x0002}, {0x0009F0, 0x0004}, {0x0009F2, 0x0040}, {0x0009F4, 0x0002}, + {0x0009FA, 0x0040}, {0x0009FC, 0x0004}, {0x0009FD, 0x0020}, {0x0009FE, 0x0010}, {0x0009FF, 0x0001}, {0x000A01, 0x0010}, + {0x000A04, 0x0001}, {0x000A05, 0x0004}, {0x000A0B, 0x0001}, {0x000A0F, 0x0004}, {0x000A11, 0x0001}, {0x000A13, 0x0004}, + {0x000A29, 0x0001}, {0x000A2A, 0x0004}, {0x000A31, 0x0001}, {0x000A32, 0x0004}, {0x000A34, 0x0001}, {0x000A35, 0x0004}, + {0x000A37, 0x0001}, {0x000A38, 0x0004}, {0x000A3A, 0x0001}, {0x000A3C, 0x0010}, {0x000A3D, 0x0001}, {0x000A3E, 0x0010}, + {0x000A43, 0x0001}, {0x000A47, 0x0010}, {0x000A49, 0x0001}, {0x000A4B, 0x0010}, {0x000A4E, 0x0001}, {0x000A51, 0x0010}, + {0x000A52, 0x0001}, {0x000A59, 0x0004}, {0x000A5D, 0x0001}, {0x000A5E, 0x0004}, {0x000A5F, 0x0001}, {0x000A66, 0x0002}, + {0x000A70, 0x0010}, {0x000A72, 0x0004}, {0x000A75, 0x0010}, {0x000A76, 0x0020}, {0x000A77, 0x0001}, {0x000A81, 0x0010}, + {0x000A84, 0x0001}, {0x000A85, 0x0004}, {0x000A8E, 0x0001}, {0x000A8F, 0x0004}, {0x000A92, 0x0001}, {0x000A93, 0x0004}, + {0x000AA9, 0x0001}, {0x000AAA, 0x0004}, {0x000AB1, 0x0001}, {0x000AB2, 0x0004}, {0x000AB4, 0x0001}, {0x000AB5, 0x0004}, + {0x000ABA, 0x0001}, {0x000ABC, 0x0010}, {0x000ABD, 0x0004}, {0x000ABE, 0x0010}, {0x000AC6, 0x0001}, {0x000AC7, 0x0010}, + {0x000ACA, 0x0001}, {0x000ACB, 0x0010}, {0x000ACE, 0x0001}, {0x000AD0, 0x0004}, {0x000AD1, 0x0001}, {0x000AE0, 0x0004}, + {0x000AE2, 0x0010}, {0x000AE4, 0x0001}, {0x000AE6, 0x0002}, {0x000AF0, 0x0020}, {0x000AF1, 0x0040}, {0x000AF2, 0x0001}, + {0x000AF9, 0x0004}, {0x000AFA, 0x0010}, {0x000B00, 0x0001}, {0x000B01, 0x0010}, {0x000B04, 0x0001}, {0x000B05, 0x0004}, + {0x000B0D, 0x0001}, {0x000B0F, 0x0004}, {0x000B11, 0x0001}, {0x000B13, 0x0004}, {0x000B29, 0x0001}, {0x000B2A, 0x0004}, + {0x000B31, 0x0001}, {0x000B32, 0x0004}, {0x000B34, 0x0001}, {0x000B35, 0x0004}, {0x000B3A, 0x0001}, {0x000B3C, 0x0010}, + {0x000B3D, 0x0004}, {0x000B3E, 0x0010}, {0x000B45, 0x0001}, {0x000B47, 0x0010}, {0x000B49, 0x0001}, {0x000B4B, 0x0010}, + {0x000B4E, 0x0001}, {0x000B55, 0x0010}, {0x000B58, 0x0001}, {0x000B5C, 0x0004}, {0x000B5E, 0x0001}, {0x000B5F, 0x0004}, + {0x000B62, 0x0010}, {0x000B64, 0x0001}, {0x000B66, 0x0002}, {0x000B70, 0x0040}, {0x000B71, 0x0004}, {0x000B72, 0x0002}, + {0x000B78, 0x0001}, {0x000B82, 0x0010}, {0x000B83, 0x0004}, {0x000B84, 0x0001}, {0x000B85, 0x0004}, {0x000B8B, 0x0001}, + {0x000B8E, 0x0004}, {0x000B91, 0x0001}, {0x000B92, 0x0004}, {0x000B96, 0x0001}, {0x000B99, 0x0004}, {0x000B9B, 0x0001}, + {0x000B9C, 0x0004}, {0x000B9D, 0x0001}, {0x000B9E, 0x0004}, {0x000BA0, 0x0001}, {0x000BA3, 0x0004}, {0x000BA5, 0x0001}, + {0x000BA8, 0x0004}, {0x000BAB, 0x0001}, {0x000BAE, 0x0004}, {0x000BBA, 0x0001}, {0x000BBE, 0x0010}, {0x000BC3, 0x0001}, + {0x000BC6, 0x0010}, {0x000BC9, 0x0001}, {0x000BCA, 0x0010}, {0x000BCE, 0x0001}, {0x000BD0, 0x0004}, {0x000BD1, 0x0001}, + {0x000BD7, 0x0010}, {0x000BD8, 0x0001}, {0x000BE6, 0x0002}, {0x000BF3, 0x0040}, {0x000BFB, 0x0001}, {0x000C00, 0x0010}, + {0x000C05, 0x0004}, {0x000C0D, 0x0001}, {0x000C0E, 0x0004}, {0x000C11, 0x0001}, {0x000C12, 0x0004}, {0x000C29, 0x0001}, + {0x000C2A, 0x0004}, {0x000C3A, 0x0001}, {0x000C3C, 0x0010}, {0x000C3D, 0x0004}, {0x000C3E, 0x0010}, {0x000C45, 0x0001}, + {0x000C46, 0x0010}, {0x000C49, 0x0001}, {0x000C4A, 0x0010}, {0x000C4E, 0x0001}, {0x000C55, 0x0010}, {0x000C57, 0x0001}, + {0x000C58, 0x0004}, {0x000C5B, 0x0001}, {0x000C5D, 0x0004}, {0x000C5E, 0x0001}, {0x000C60, 0x0004}, {0x000C62, 0x0010}, + {0x000C64, 0x0001}, {0x000C66, 0x0002}, {0x000C70, 0x0001}, {0x000C77, 0x0020}, {0x000C78, 0x0002}, {0x000C7F, 0x0040}, + {0x000C80, 0x0004}, {0x000C81, 0x0010}, {0x000C84, 0x0020}, {0x000C85, 0x0004}, {0x000C8D, 0x0001}, {0x000C8E, 0x0004}, + {0x000C91, 0x0001}, {0x000C92, 0x0004}, {0x000CA9, 0x0001}, {0x000CAA, 0x0004}, {0x000CB4, 0x0001}, {0x000CB5, 0x0004}, + {0x000CBA, 0x0001}, {0x000CBC, 0x0010}, {0x000CBD, 0x0004}, {0x000CBE, 0x0010}, {0x000CC5, 0x0001}, {0x000CC6, 0x0010}, + {0x000CC9, 0x0001}, {0x000CCA, 0x0010}, {0x000CCE, 0x0001}, {0x000CD5, 0x0010}, {0x000CD7, 0x0001}, {0x000CDD, 0x0004}, + {0x000CDF, 0x0001}, {0x000CE0, 0x0004}, {0x000CE2, 0x0010}, {0x000CE4, 0x0001}, {0x000CE6, 0x0002}, {0x000CF0, 0x0001}, + {0x000CF1, 0x0004}, {0x000CF3, 0x0010}, {0x000CF4, 0x0001}, {0x000D00, 0x0010}, {0x000D04, 0x0004}, {0x000D0D, 0x0001}, + {0x000D0E, 0x0004}, {0x000D11, 0x0001}, {0x000D12, 0x0004}, {0x000D3B, 0x0010}, {0x000D3D, 0x0004}, {0x000D3E, 0x0010}, + {0x000D45, 0x0001}, {0x000D46, 0x0010}, {0x000D49, 0x0001}, {0x000D4A, 0x0010}, {0x000D4E, 0x0004}, {0x000D4F, 0x0040}, + {0x000D50, 0x0001}, {0x000D54, 0x0004}, {0x000D57, 0x0010}, {0x000D58, 0x0002}, {0x000D5F, 0x0004}, {0x000D62, 0x0010}, + {0x000D64, 0x0001}, {0x000D66, 0x0002}, {0x000D79, 0x0040}, {0x000D7A, 0x0004}, {0x000D80, 0x0001}, {0x000D81, 0x0010}, + {0x000D84, 0x0001}, {0x000D85, 0x0004}, {0x000D97, 0x0001}, {0x000D9A, 0x0004}, {0x000DB2, 0x0001}, {0x000DB3, 0x0004}, + {0x000DBC, 0x0001}, {0x000DBD, 0x0004}, {0x000DBE, 0x0001}, {0x000DC0, 0x0004}, {0x000DC7, 0x0001}, {0x000DCA, 0x0010}, + {0x000DCB, 0x0001}, {0x000DCF, 0x0010}, {0x000DD5, 0x0001}, {0x000DD6, 0x0010}, {0x000DD7, 0x0001}, {0x000DD8, 0x0010}, + {0x000DE0, 0x0001}, {0x000DE6, 0x0002}, {0x000DF0, 0x0001}, {0x000DF2, 0x0010}, {0x000DF4, 0x0020}, {0x000DF5, 0x0001}, + {0x000E01, 0x0004}, {0x000E31, 0x0010}, {0x000E32, 0x0004}, {0x000E34, 0x0010}, {0x000E3B, 0x0001}, {0x000E3F, 0x0040}, + {0x000E40, 0x0004}, {0x000E47, 0x0010}, {0x000E4F, 0x0020}, {0x000E50, 0x0002}, {0x000E5A, 0x0020}, {0x000E5C, 0x0001}, + {0x000E81, 0x0004}, {0x000E83, 0x0001}, {0x000E84, 0x0004}, {0x000E85, 0x0001}, {0x000E86, 0x0004}, {0x000E8B, 0x0001}, + {0x000E8C, 0x0004}, {0x000EA4, 0x0001}, {0x000EA5, 0x0004}, {0x000EA6, 0x0001}, {0x000EA7, 0x0004}, {0x000EB1, 0x0010}, + {0x000EB2, 0x0004}, {0x000EB4, 0x0010}, {0x000EBD, 0x0004}, {0x000EBE, 0x0001}, {0x000EC0, 0x0004}, {0x000EC5, 0x0001}, + {0x000EC6, 0x0004}, {0x000EC7, 0x0001}, {0x000EC8, 0x0010}, {0x000ECF, 0x0001}, {0x000ED0, 0x0002}, {0x000EDA, 0x0001}, + {0x000EDC, 0x0004}, {0x000EE0, 0x0001}, {0x000F00, 0x0004}, {0x000F01, 0x0040}, {0x000F04, 0x0020}, {0x000F13, 0x0040}, + {0x000F14, 0x0020}, {0x000F15, 0x0040}, {0x000F18, 0x0010}, {0x000F1A, 0x0040}, {0x000F20, 0x0002}, {0x000F34, 0x0040}, + {0x000F35, 0x0010}, {0x000F36, 0x0040}, {0x000F37, 0x0010}, {0x000F38, 0x0040}, {0x000F39, 0x0010}, {0x000F3A, 0x0020}, + {0x000F3E, 0x0010}, {0x000F40, 0x0004}, {0x000F48, 0x0001}, {0x000F49, 0x0004}, {0x000F6D, 0x0001}, {0x000F71, 0x0010}, + {0x000F85, 0x0020}, {0x000F86, 0x0010}, {0x000F88, 0x0004}, {0x000F8D, 0x0010}, {0x000F98, 0x0001}, {0x000F99, 0x0010}, + {0x000FBD, 0x0001}, {0x000FBE, 0x0040}, {0x000FC6, 0x0010}, {0x000FC7, 0x0040}, {0x000FCD, 0x0001}, {0x000FCE, 0x0040}, + {0x000FD0, 0x0020}, {0x000FD5, 0x0040}, {0x000FD9, 0x0020}, {0x000FDB, 0x0001}, {0x001000, 0x0004}, {0x00102B, 0x0010}, + {0x00103F, 0x0004}, {0x001040, 0x0002}, {0x00104A, 0x0020}, {0x001050, 0x0004}, {0x001056, 0x0010}, {0x00105A, 0x0004}, + {0x00105E, 0x0010}, {0x001061, 0x0004}, {0x001062, 0x0010}, {0x001065, 0x0004}, {0x001067, 0x0010}, {0x00106E, 0x0004}, + {0x001071, 0x0010}, {0x001075, 0x0004}, {0x001082, 0x0010}, {0x00108E, 0x0004}, {0x00108F, 0x0010}, {0x001090, 0x0002}, + {0x00109A, 0x0010}, {0x00109E, 0x0040}, {0x0010A0, 0x0004}, {0x0010C6, 0x0001}, {0x0010C7, 0x0004}, {0x0010C8, 0x0001}, + {0x0010CD, 0x0004}, {0x0010CE, 0x0001}, {0x0010D0, 0x0004}, {0x0010FB, 0x0020}, {0x0010FC, 0x0004}, {0x001249, 0x0001}, + {0x00124A, 0x0004}, {0x00124E, 0x0001}, {0x001250, 0x0004}, {0x001257, 0x0001}, {0x001258, 0x0004}, {0x001259, 0x0001}, + {0x00125A, 0x0004}, {0x00125E, 0x0001}, {0x001260, 0x0004}, {0x001289, 0x0001}, {0x00128A, 0x0004}, {0x00128E, 0x0001}, + {0x001290, 0x0004}, {0x0012B1, 0x0001}, {0x0012B2, 0x0004}, {0x0012B6, 0x0001}, {0x0012B8, 0x0004}, {0x0012BF, 0x0001}, + {0x0012C0, 0x0004}, {0x0012C1, 0x0001}, {0x0012C2, 0x0004}, {0x0012C6, 0x0001}, {0x0012C8, 0x0004}, {0x0012D7, 0x0001}, + {0x0012D8, 0x0004}, {0x001311, 0x0001}, {0x001312, 0x0004}, {0x001316, 0x0001}, {0x001318, 0x0004}, {0x00135B, 0x0001}, + {0x00135D, 0x0010}, {0x001360, 0x0020}, {0x001369, 0x0002}, {0x00137D, 0x0001}, {0x001380, 0x0004}, {0x001390, 0x0040}, + {0x00139A, 0x0001}, {0x0013A0, 0x0004}, {0x0013F6, 0x0001}, {0x0013F8, 0x0004}, {0x0013FE, 0x0001}, {0x001400, 0x0020}, + {0x001401, 0x0004}, {0x00166D, 0x0040}, {0x00166E, 0x0020}, {0x00166F, 0x0004}, {0x001680, 0x0008}, {0x001681, 0x0004}, + {0x00169B, 0x0020}, {0x00169D, 0x0001}, {0x0016A0, 0x0004}, {0x0016EB, 0x0020}, {0x0016EE, 0x0002}, {0x0016F1, 0x0004}, + {0x0016F9, 0x0001}, {0x001700, 0x0004}, {0x001712, 0x0010}, {0x001716, 0x0001}, {0x00171F, 0x0004}, {0x001732, 0x0010}, + {0x001735, 0x0020}, {0x001737, 0x0001}, {0x001740, 0x0004}, {0x001752, 0x0010}, {0x001754, 0x0001}, {0x001760, 0x0004}, + {0x00176D, 0x0001}, {0x00176E, 0x0004}, {0x001771, 0x0001}, {0x001772, 0x0010}, {0x001774, 0x0001}, {0x001780, 0x0004}, + {0x0017B4, 0x0010}, {0x0017D4, 0x0020}, {0x0017D7, 0x0004}, {0x0017D8, 0x0020}, {0x0017DB, 0x0040}, {0x0017DC, 0x0004}, + {0x0017DD, 0x0010}, {0x0017DE, 0x0001}, {0x0017E0, 0x0002}, {0x0017EA, 0x0001}, {0x0017F0, 0x0002}, {0x0017FA, 0x0001}, + {0x001800, 0x0020}, {0x00180B, 0x0010}, {0x00180E, 0x0080}, {0x00180F, 0x0010}, {0x001810, 0x0002}, {0x00181A, 0x0001}, + {0x001820, 0x0004}, {0x001879, 0x0001}, {0x001880, 0x0004}, {0x001885, 0x0010}, {0x001887, 0x0004}, {0x0018A9, 0x0010}, + {0x0018AA, 0x0004}, {0x0018AB, 0x0001}, {0x0018B0, 0x0004}, {0x0018F6, 0x0001}, {0x001900, 0x0004}, {0x00191F, 0x0001}, + {0x001920, 0x0010}, {0x00192C, 0x0001}, {0x001930, 0x0010}, {0x00193C, 0x0001}, {0x001940, 0x0040}, {0x001941, 0x0001}, + {0x001944, 0x0020}, {0x001946, 0x0002}, {0x001950, 0x0004}, {0x00196E, 0x0001}, {0x001970, 0x0004}, {0x001975, 0x0001}, + {0x001980, 0x0004}, {0x0019AC, 0x0001}, {0x0019B0, 0x0004}, {0x0019CA, 0x0001}, {0x0019D0, 0x0002}, {0x0019DB, 0x0001}, + {0x0019DE, 0x0040}, {0x001A00, 0x0004}, {0x001A17, 0x0010}, {0x001A1C, 0x0001}, {0x001A1E, 0x0020}, {0x001A20, 0x0004}, + {0x001A55, 0x0010}, {0x001A5F, 0x0001}, {0x001A60, 0x0010}, {0x001A7D, 0x0001}, {0x001A7F, 0x0010}, {0x001A80, 0x0002}, + {0x001A8A, 0x0001}, {0x001A90, 0x0002}, {0x001A9A, 0x0001}, {0x001AA0, 0x0020}, {0x001AA7, 0x0004}, {0x001AA8, 0x0020}, + {0x001AAE, 0x0001}, {0x001AB0, 0x0010}, {0x001ACF, 0x0001}, {0x001B00, 0x0010}, {0x001B05, 0x0004}, {0x001B34, 0x0010}, + {0x001B45, 0x0004}, {0x001B4D, 0x0001}, {0x001B50, 0x0002}, {0x001B5A, 0x0020}, {0x001B61, 0x0040}, {0x001B6B, 0x0010}, + {0x001B74, 0x0040}, {0x001B7D, 0x0020}, {0x001B7F, 0x0001}, {0x001B80, 0x0010}, {0x001B83, 0x0004}, {0x001BA1, 0x0010}, + {0x001BAE, 0x0004}, {0x001BB0, 0x0002}, {0x001BBA, 0x0004}, {0x001BE6, 0x0010}, {0x001BF4, 0x0001}, {0x001BFC, 0x0020}, + {0x001C00, 0x0004}, {0x001C24, 0x0010}, {0x001C38, 0x0001}, {0x001C3B, 0x0020}, {0x001C40, 0x0002}, {0x001C4A, 0x0001}, + {0x001C4D, 0x0004}, {0x001C50, 0x0002}, {0x001C5A, 0x0004}, {0x001C7E, 0x0020}, {0x001C80, 0x0004}, {0x001C89, 0x0001}, + {0x001C90, 0x0004}, {0x001CBB, 0x0001}, {0x001CBD, 0x0004}, {0x001CC0, 0x0020}, {0x001CC8, 0x0001}, {0x001CD0, 0x0010}, + {0x001CD3, 0x0020}, {0x001CD4, 0x0010}, {0x001CE9, 0x0004}, {0x001CED, 0x0010}, {0x001CEE, 0x0004}, {0x001CF4, 0x0010}, + {0x001CF5, 0x0004}, {0x001CF7, 0x0010}, {0x001CFA, 0x0004}, {0x001CFB, 0x0001}, {0x001D00, 0x0004}, {0x001DC0, 0x0010}, + {0x001E00, 0x0004}, {0x001F16, 0x0001}, {0x001F18, 0x0004}, {0x001F1E, 0x0001}, {0x001F20, 0x0004}, {0x001F46, 0x0001}, + {0x001F48, 0x0004}, {0x001F4E, 0x0001}, {0x001F50, 0x0004}, {0x001F58, 0x0001}, {0x001F59, 0x0004}, {0x001F5A, 0x0001}, + {0x001F5B, 0x0004}, {0x001F5C, 0x0001}, {0x001F5D, 0x0004}, {0x001F5E, 0x0001}, {0x001F5F, 0x0004}, {0x001F7E, 0x0001}, + {0x001F80, 0x0004}, {0x001FB5, 0x0001}, {0x001FB6, 0x0004}, {0x001FBD, 0x0040}, {0x001FBE, 0x0004}, {0x001FBF, 0x0040}, + {0x001FC2, 0x0004}, {0x001FC5, 0x0001}, {0x001FC6, 0x0004}, {0x001FCD, 0x0040}, {0x001FD0, 0x0004}, {0x001FD4, 0x0001}, + {0x001FD6, 0x0004}, {0x001FDC, 0x0001}, {0x001FDD, 0x0040}, {0x001FE0, 0x0004}, {0x001FED, 0x0040}, {0x001FF0, 0x0001}, + {0x001FF2, 0x0004}, {0x001FF5, 0x0001}, {0x001FF6, 0x0004}, {0x001FFD, 0x0040}, {0x001FFF, 0x0001}, {0x002000, 0x0008}, + {0x00200B, 0x0080}, {0x002010, 0x0020}, {0x002028, 0x0008}, {0x00202A, 0x0080}, {0x00202F, 0x0008}, {0x002030, 0x0020}, + {0x002044, 0x0040}, {0x002045, 0x0020}, {0x002052, 0x0040}, {0x002053, 0x0020}, {0x00205F, 0x0008}, {0x002060, 0x0080}, + {0x002065, 0x0001}, {0x002066, 0x0080}, {0x002070, 0x0002}, {0x002071, 0x0004}, {0x002072, 0x0001}, {0x002074, 0x0002}, + {0x00207A, 0x0040}, {0x00207D, 0x0020}, {0x00207F, 0x0004}, {0x002080, 0x0002}, {0x00208A, 0x0040}, {0x00208D, 0x0020}, + {0x00208F, 0x0001}, {0x002090, 0x0004}, {0x00209D, 0x0001}, {0x0020A0, 0x0040}, {0x0020C1, 0x0001}, {0x0020D0, 0x0010}, + {0x0020F1, 0x0001}, {0x002100, 0x0040}, {0x002102, 0x0004}, {0x002103, 0x0040}, {0x002107, 0x0004}, {0x002108, 0x0040}, + {0x00210A, 0x0004}, {0x002114, 0x0040}, {0x002115, 0x0004}, {0x002116, 0x0040}, {0x002119, 0x0004}, {0x00211E, 0x0040}, + {0x002124, 0x0004}, {0x002125, 0x0040}, {0x002126, 0x0004}, {0x002127, 0x0040}, {0x002128, 0x0004}, {0x002129, 0x0040}, + {0x00212A, 0x0004}, {0x00212E, 0x0040}, {0x00212F, 0x0004}, {0x00213A, 0x0040}, {0x00213C, 0x0004}, {0x002140, 0x0040}, + {0x002145, 0x0004}, {0x00214A, 0x0040}, {0x00214E, 0x0004}, {0x00214F, 0x0040}, {0x002150, 0x0002}, {0x002183, 0x0004}, + {0x002185, 0x0002}, {0x00218A, 0x0040}, {0x00218C, 0x0001}, {0x002190, 0x0040}, {0x002308, 0x0020}, {0x00230C, 0x0040}, + {0x002329, 0x0020}, {0x00232B, 0x0040}, {0x002427, 0x0001}, {0x002440, 0x0040}, {0x00244B, 0x0001}, {0x002460, 0x0002}, + {0x00249C, 0x0040}, {0x0024EA, 0x0002}, {0x002500, 0x0040}, {0x002768, 0x0020}, {0x002776, 0x0002}, {0x002794, 0x0040}, + {0x0027C5, 0x0020}, {0x0027C7, 0x0040}, {0x0027E6, 0x0020}, {0x0027F0, 0x0040}, {0x002983, 0x0020}, {0x002999, 0x0040}, + {0x0029D8, 0x0020}, {0x0029DC, 0x0040}, {0x0029FC, 0x0020}, {0x0029FE, 0x0040}, {0x002B74, 0x0001}, {0x002B76, 0x0040}, + {0x002B96, 0x0001}, {0x002B97, 0x0040}, {0x002C00, 0x0004}, {0x002CE5, 0x0040}, {0x002CEB, 0x0004}, {0x002CEF, 0x0010}, + {0x002CF2, 0x0004}, {0x002CF4, 0x0001}, {0x002CF9, 0x0020}, {0x002CFD, 0x0002}, {0x002CFE, 0x0020}, {0x002D00, 0x0004}, + {0x002D26, 0x0001}, {0x002D27, 0x0004}, {0x002D28, 0x0001}, {0x002D2D, 0x0004}, {0x002D2E, 0x0001}, {0x002D30, 0x0004}, + {0x002D68, 0x0001}, {0x002D6F, 0x0004}, {0x002D70, 0x0020}, {0x002D71, 0x0001}, {0x002D7F, 0x0010}, {0x002D80, 0x0004}, + {0x002D97, 0x0001}, {0x002DA0, 0x0004}, {0x002DA7, 0x0001}, {0x002DA8, 0x0004}, {0x002DAF, 0x0001}, {0x002DB0, 0x0004}, + {0x002DB7, 0x0001}, {0x002DB8, 0x0004}, {0x002DBF, 0x0001}, {0x002DC0, 0x0004}, {0x002DC7, 0x0001}, {0x002DC8, 0x0004}, + {0x002DCF, 0x0001}, {0x002DD0, 0x0004}, {0x002DD7, 0x0001}, {0x002DD8, 0x0004}, {0x002DDF, 0x0001}, {0x002DE0, 0x0010}, + {0x002E00, 0x0020}, {0x002E2F, 0x0004}, {0x002E30, 0x0020}, {0x002E50, 0x0040}, {0x002E52, 0x0020}, {0x002E5E, 0x0001}, + {0x002E80, 0x0040}, {0x002E9A, 0x0001}, {0x002E9B, 0x0040}, {0x002EF4, 0x0001}, {0x002F00, 0x0040}, {0x002FD6, 0x0001}, + {0x002FF0, 0x0040}, {0x003000, 0x0008}, {0x003001, 0x0020}, {0x003004, 0x0040}, {0x003005, 0x0004}, {0x003007, 0x0002}, + {0x003008, 0x0020}, {0x003012, 0x0040}, {0x003014, 0x0020}, {0x003020, 0x0040}, {0x003021, 0x0002}, {0x00302A, 0x0010}, + {0x003030, 0x0020}, {0x003031, 0x0004}, {0x003036, 0x0040}, {0x003038, 0x0002}, {0x00303B, 0x0004}, {0x00303D, 0x0020}, + {0x00303E, 0x0040}, {0x003040, 0x0001}, {0x003041, 0x0004}, {0x003097, 0x0001}, {0x003099, 0x0010}, {0x00309B, 0x0040}, + {0x00309D, 0x0004}, {0x0030A0, 0x0020}, {0x0030A1, 0x0004}, {0x0030FB, 0x0020}, {0x0030FC, 0x0004}, {0x003100, 0x0001}, + {0x003105, 0x0004}, {0x003130, 0x0001}, {0x003131, 0x0004}, {0x00318F, 0x0001}, {0x003190, 0x0040}, {0x003192, 0x0002}, + {0x003196, 0x0040}, {0x0031A0, 0x0004}, {0x0031C0, 0x0040}, {0x0031E4, 0x0001}, {0x0031EF, 0x0040}, {0x0031F0, 0x0004}, + {0x003200, 0x0040}, {0x00321F, 0x0001}, {0x003220, 0x0002}, {0x00322A, 0x0040}, {0x003248, 0x0002}, {0x003250, 0x0040}, + {0x003251, 0x0002}, {0x003260, 0x0040}, {0x003280, 0x0002}, {0x00328A, 0x0040}, {0x0032B1, 0x0002}, {0x0032C0, 0x0040}, + {0x003400, 0x0004}, {0x004DC0, 0x0040}, {0x004E00, 0x0004}, {0x00A48D, 0x0001}, {0x00A490, 0x0040}, {0x00A4C7, 0x0001}, + {0x00A4D0, 0x0004}, {0x00A4FE, 0x0020}, {0x00A500, 0x0004}, {0x00A60D, 0x0020}, {0x00A610, 0x0004}, {0x00A620, 0x0002}, + {0x00A62A, 0x0004}, {0x00A62C, 0x0001}, {0x00A640, 0x0004}, {0x00A66F, 0x0010}, {0x00A673, 0x0020}, {0x00A674, 0x0010}, + {0x00A67E, 0x0020}, {0x00A67F, 0x0004}, {0x00A69E, 0x0010}, {0x00A6A0, 0x0004}, {0x00A6E6, 0x0002}, {0x00A6F0, 0x0010}, + {0x00A6F2, 0x0020}, {0x00A6F8, 0x0001}, {0x00A700, 0x0040}, {0x00A717, 0x0004}, {0x00A720, 0x0040}, {0x00A722, 0x0004}, + {0x00A789, 0x0040}, {0x00A78B, 0x0004}, {0x00A7CB, 0x0001}, {0x00A7D0, 0x0004}, {0x00A7D2, 0x0001}, {0x00A7D3, 0x0004}, + {0x00A7D4, 0x0001}, {0x00A7D5, 0x0004}, {0x00A7DA, 0x0001}, {0x00A7F2, 0x0004}, {0x00A802, 0x0010}, {0x00A803, 0x0004}, + {0x00A806, 0x0010}, {0x00A807, 0x0004}, {0x00A80B, 0x0010}, {0x00A80C, 0x0004}, {0x00A823, 0x0010}, {0x00A828, 0x0040}, + {0x00A82C, 0x0010}, {0x00A82D, 0x0001}, {0x00A830, 0x0002}, {0x00A836, 0x0040}, {0x00A83A, 0x0001}, {0x00A840, 0x0004}, + {0x00A874, 0x0020}, {0x00A878, 0x0001}, {0x00A880, 0x0010}, {0x00A882, 0x0004}, {0x00A8B4, 0x0010}, {0x00A8C6, 0x0001}, + {0x00A8CE, 0x0020}, {0x00A8D0, 0x0002}, {0x00A8DA, 0x0001}, {0x00A8E0, 0x0010}, {0x00A8F2, 0x0004}, {0x00A8F8, 0x0020}, + {0x00A8FB, 0x0004}, {0x00A8FC, 0x0020}, {0x00A8FD, 0x0004}, {0x00A8FF, 0x0010}, {0x00A900, 0x0002}, {0x00A90A, 0x0004}, + {0x00A926, 0x0010}, {0x00A92E, 0x0020}, {0x00A930, 0x0004}, {0x00A947, 0x0010}, {0x00A954, 0x0001}, {0x00A95F, 0x0020}, + {0x00A960, 0x0004}, {0x00A97D, 0x0001}, {0x00A980, 0x0010}, {0x00A984, 0x0004}, {0x00A9B3, 0x0010}, {0x00A9C1, 0x0020}, + {0x00A9CE, 0x0001}, {0x00A9CF, 0x0004}, {0x00A9D0, 0x0002}, {0x00A9DA, 0x0001}, {0x00A9DE, 0x0020}, {0x00A9E0, 0x0004}, + {0x00A9E5, 0x0010}, {0x00A9E6, 0x0004}, {0x00A9F0, 0x0002}, {0x00A9FA, 0x0004}, {0x00A9FF, 0x0001}, {0x00AA00, 0x0004}, + {0x00AA29, 0x0010}, {0x00AA37, 0x0001}, {0x00AA40, 0x0004}, {0x00AA43, 0x0010}, {0x00AA44, 0x0004}, {0x00AA4C, 0x0010}, + {0x00AA4E, 0x0001}, {0x00AA50, 0x0002}, {0x00AA5A, 0x0001}, {0x00AA5C, 0x0020}, {0x00AA60, 0x0004}, {0x00AA77, 0x0040}, + {0x00AA7A, 0x0004}, {0x00AA7B, 0x0010}, {0x00AA7E, 0x0004}, {0x00AAB0, 0x0010}, {0x00AAB1, 0x0004}, {0x00AAB2, 0x0010}, + {0x00AAB5, 0x0004}, {0x00AAB7, 0x0010}, {0x00AAB9, 0x0004}, {0x00AABE, 0x0010}, {0x00AAC0, 0x0004}, {0x00AAC1, 0x0010}, + {0x00AAC2, 0x0004}, {0x00AAC3, 0x0001}, {0x00AADB, 0x0004}, {0x00AADE, 0x0020}, {0x00AAE0, 0x0004}, {0x00AAEB, 0x0010}, + {0x00AAF0, 0x0020}, {0x00AAF2, 0x0004}, {0x00AAF5, 0x0010}, {0x00AAF7, 0x0001}, {0x00AB01, 0x0004}, {0x00AB07, 0x0001}, + {0x00AB09, 0x0004}, {0x00AB0F, 0x0001}, {0x00AB11, 0x0004}, {0x00AB17, 0x0001}, {0x00AB20, 0x0004}, {0x00AB27, 0x0001}, + {0x00AB28, 0x0004}, {0x00AB2F, 0x0001}, {0x00AB30, 0x0004}, {0x00AB5B, 0x0040}, {0x00AB5C, 0x0004}, {0x00AB6A, 0x0040}, + {0x00AB6C, 0x0001}, {0x00AB70, 0x0004}, {0x00ABE3, 0x0010}, {0x00ABEB, 0x0020}, {0x00ABEC, 0x0010}, {0x00ABEE, 0x0001}, + {0x00ABF0, 0x0002}, {0x00ABFA, 0x0001}, {0x00AC00, 0x0004}, {0x00D7A4, 0x0001}, {0x00D7B0, 0x0004}, {0x00D7C7, 0x0001}, + {0x00D7CB, 0x0004}, {0x00D7FC, 0x0001}, {0x00D800, 0x0080}, {0x00F900, 0x0004}, {0x00FA6E, 0x0001}, {0x00FA70, 0x0004}, + {0x00FADA, 0x0001}, {0x00FB00, 0x0004}, {0x00FB07, 0x0001}, {0x00FB13, 0x0004}, {0x00FB18, 0x0001}, {0x00FB1D, 0x0004}, + {0x00FB1E, 0x0010}, {0x00FB1F, 0x0004}, {0x00FB29, 0x0040}, {0x00FB2A, 0x0004}, {0x00FB37, 0x0001}, {0x00FB38, 0x0004}, + {0x00FB3D, 0x0001}, {0x00FB3E, 0x0004}, {0x00FB3F, 0x0001}, {0x00FB40, 0x0004}, {0x00FB42, 0x0001}, {0x00FB43, 0x0004}, + {0x00FB45, 0x0001}, {0x00FB46, 0x0004}, {0x00FBB2, 0x0040}, {0x00FBC3, 0x0001}, {0x00FBD3, 0x0004}, {0x00FD3E, 0x0020}, + {0x00FD40, 0x0040}, {0x00FD50, 0x0004}, {0x00FD90, 0x0001}, {0x00FD92, 0x0004}, {0x00FDC8, 0x0001}, {0x00FDCF, 0x0040}, + {0x00FDD0, 0x0001}, {0x00FDF0, 0x0004}, {0x00FDFC, 0x0040}, {0x00FE00, 0x0010}, {0x00FE10, 0x0020}, {0x00FE1A, 0x0001}, + {0x00FE20, 0x0010}, {0x00FE30, 0x0020}, {0x00FE53, 0x0001}, {0x00FE54, 0x0020}, {0x00FE62, 0x0040}, {0x00FE63, 0x0020}, + {0x00FE64, 0x0040}, {0x00FE67, 0x0001}, {0x00FE68, 0x0020}, {0x00FE69, 0x0040}, {0x00FE6A, 0x0020}, {0x00FE6C, 0x0001}, + {0x00FE70, 0x0004}, {0x00FE75, 0x0001}, {0x00FE76, 0x0004}, {0x00FEFD, 0x0001}, {0x00FEFF, 0x0080}, {0x00FF00, 0x0001}, + {0x00FF01, 0x0020}, {0x00FF04, 0x0040}, {0x00FF05, 0x0020}, {0x00FF0B, 0x0040}, {0x00FF0C, 0x0020}, {0x00FF10, 0x0002}, + {0x00FF1A, 0x0020}, {0x00FF1C, 0x0040}, {0x00FF1F, 0x0020}, {0x00FF21, 0x0004}, {0x00FF3B, 0x0020}, {0x00FF3E, 0x0040}, + {0x00FF3F, 0x0020}, {0x00FF40, 0x0040}, {0x00FF41, 0x0004}, {0x00FF5B, 0x0020}, {0x00FF5C, 0x0040}, {0x00FF5D, 0x0020}, + {0x00FF5E, 0x0040}, {0x00FF5F, 0x0020}, {0x00FF66, 0x0004}, {0x00FFBF, 0x0001}, {0x00FFC2, 0x0004}, {0x00FFC8, 0x0001}, + {0x00FFCA, 0x0004}, {0x00FFD0, 0x0001}, {0x00FFD2, 0x0004}, {0x00FFD8, 0x0001}, {0x00FFDA, 0x0004}, {0x00FFDD, 0x0001}, + {0x00FFE0, 0x0040}, {0x00FFE7, 0x0001}, {0x00FFE8, 0x0040}, {0x00FFEF, 0x0001}, {0x00FFF9, 0x0080}, {0x00FFFC, 0x0040}, + {0x00FFFE, 0x0001}, {0x010000, 0x0004}, {0x01000C, 0x0001}, {0x01000D, 0x0004}, {0x010027, 0x0001}, {0x010028, 0x0004}, + {0x01003B, 0x0001}, {0x01003C, 0x0004}, {0x01003E, 0x0001}, {0x01003F, 0x0004}, {0x01004E, 0x0001}, {0x010050, 0x0004}, + {0x01005E, 0x0001}, {0x010080, 0x0004}, {0x0100FB, 0x0001}, {0x010100, 0x0020}, {0x010103, 0x0001}, {0x010107, 0x0002}, + {0x010134, 0x0001}, {0x010137, 0x0040}, {0x010140, 0x0002}, {0x010179, 0x0040}, {0x01018A, 0x0002}, {0x01018C, 0x0040}, + {0x01018F, 0x0001}, {0x010190, 0x0040}, {0x01019D, 0x0001}, {0x0101A0, 0x0040}, {0x0101A1, 0x0001}, {0x0101D0, 0x0040}, + {0x0101FD, 0x0010}, {0x0101FE, 0x0001}, {0x010280, 0x0004}, {0x01029D, 0x0001}, {0x0102A0, 0x0004}, {0x0102D1, 0x0001}, + {0x0102E0, 0x0010}, {0x0102E1, 0x0002}, {0x0102FC, 0x0001}, {0x010300, 0x0004}, {0x010320, 0x0002}, {0x010324, 0x0001}, + {0x01032D, 0x0004}, {0x010341, 0x0002}, {0x010342, 0x0004}, {0x01034A, 0x0002}, {0x01034B, 0x0001}, {0x010350, 0x0004}, + {0x010376, 0x0010}, {0x01037B, 0x0001}, {0x010380, 0x0004}, {0x01039E, 0x0001}, {0x01039F, 0x0020}, {0x0103A0, 0x0004}, + {0x0103C4, 0x0001}, {0x0103C8, 0x0004}, {0x0103D0, 0x0020}, {0x0103D1, 0x0002}, {0x0103D6, 0x0001}, {0x010400, 0x0004}, + {0x01049E, 0x0001}, {0x0104A0, 0x0002}, {0x0104AA, 0x0001}, {0x0104B0, 0x0004}, {0x0104D4, 0x0001}, {0x0104D8, 0x0004}, + {0x0104FC, 0x0001}, {0x010500, 0x0004}, {0x010528, 0x0001}, {0x010530, 0x0004}, {0x010564, 0x0001}, {0x01056F, 0x0020}, + {0x010570, 0x0004}, {0x01057B, 0x0001}, {0x01057C, 0x0004}, {0x01058B, 0x0001}, {0x01058C, 0x0004}, {0x010593, 0x0001}, + {0x010594, 0x0004}, {0x010596, 0x0001}, {0x010597, 0x0004}, {0x0105A2, 0x0001}, {0x0105A3, 0x0004}, {0x0105B2, 0x0001}, + {0x0105B3, 0x0004}, {0x0105BA, 0x0001}, {0x0105BB, 0x0004}, {0x0105BD, 0x0001}, {0x010600, 0x0004}, {0x010737, 0x0001}, + {0x010740, 0x0004}, {0x010756, 0x0001}, {0x010760, 0x0004}, {0x010768, 0x0001}, {0x010780, 0x0004}, {0x010786, 0x0001}, + {0x010787, 0x0004}, {0x0107B1, 0x0001}, {0x0107B2, 0x0004}, {0x0107BB, 0x0001}, {0x010800, 0x0004}, {0x010806, 0x0001}, + {0x010808, 0x0004}, {0x010809, 0x0001}, {0x01080A, 0x0004}, {0x010836, 0x0001}, {0x010837, 0x0004}, {0x010839, 0x0001}, + {0x01083C, 0x0004}, {0x01083D, 0x0001}, {0x01083F, 0x0004}, {0x010856, 0x0001}, {0x010857, 0x0020}, {0x010858, 0x0002}, + {0x010860, 0x0004}, {0x010877, 0x0040}, {0x010879, 0x0002}, {0x010880, 0x0004}, {0x01089F, 0x0001}, {0x0108A7, 0x0002}, + {0x0108B0, 0x0001}, {0x0108E0, 0x0004}, {0x0108F3, 0x0001}, {0x0108F4, 0x0004}, {0x0108F6, 0x0001}, {0x0108FB, 0x0002}, + {0x010900, 0x0004}, {0x010916, 0x0002}, {0x01091C, 0x0001}, {0x01091F, 0x0020}, {0x010920, 0x0004}, {0x01093A, 0x0001}, + {0x01093F, 0x0020}, {0x010940, 0x0001}, {0x010980, 0x0004}, {0x0109B8, 0x0001}, {0x0109BC, 0x0002}, {0x0109BE, 0x0004}, + {0x0109C0, 0x0002}, {0x0109D0, 0x0001}, {0x0109D2, 0x0002}, {0x010A00, 0x0004}, {0x010A01, 0x0010}, {0x010A04, 0x0001}, + {0x010A05, 0x0010}, {0x010A07, 0x0001}, {0x010A0C, 0x0010}, {0x010A10, 0x0004}, {0x010A14, 0x0001}, {0x010A15, 0x0004}, + {0x010A18, 0x0001}, {0x010A19, 0x0004}, {0x010A36, 0x0001}, {0x010A38, 0x0010}, {0x010A3B, 0x0001}, {0x010A3F, 0x0010}, + {0x010A40, 0x0002}, {0x010A49, 0x0001}, {0x010A50, 0x0020}, {0x010A59, 0x0001}, {0x010A60, 0x0004}, {0x010A7D, 0x0002}, + {0x010A7F, 0x0020}, {0x010A80, 0x0004}, {0x010A9D, 0x0002}, {0x010AA0, 0x0001}, {0x010AC0, 0x0004}, {0x010AC8, 0x0040}, + {0x010AC9, 0x0004}, {0x010AE5, 0x0010}, {0x010AE7, 0x0001}, {0x010AEB, 0x0002}, {0x010AF0, 0x0020}, {0x010AF7, 0x0001}, + {0x010B00, 0x0004}, {0x010B36, 0x0001}, {0x010B39, 0x0020}, {0x010B40, 0x0004}, {0x010B56, 0x0001}, {0x010B58, 0x0002}, + {0x010B60, 0x0004}, {0x010B73, 0x0001}, {0x010B78, 0x0002}, {0x010B80, 0x0004}, {0x010B92, 0x0001}, {0x010B99, 0x0020}, + {0x010B9D, 0x0001}, {0x010BA9, 0x0002}, {0x010BB0, 0x0001}, {0x010C00, 0x0004}, {0x010C49, 0x0001}, {0x010C80, 0x0004}, + {0x010CB3, 0x0001}, {0x010CC0, 0x0004}, {0x010CF3, 0x0001}, {0x010CFA, 0x0002}, {0x010D00, 0x0004}, {0x010D24, 0x0010}, + {0x010D28, 0x0001}, {0x010D30, 0x0002}, {0x010D3A, 0x0001}, {0x010E60, 0x0002}, {0x010E7F, 0x0001}, {0x010E80, 0x0004}, + {0x010EAA, 0x0001}, {0x010EAB, 0x0010}, {0x010EAD, 0x0020}, {0x010EAE, 0x0001}, {0x010EB0, 0x0004}, {0x010EB2, 0x0001}, + {0x010EFD, 0x0010}, {0x010F00, 0x0004}, {0x010F1D, 0x0002}, {0x010F27, 0x0004}, {0x010F28, 0x0001}, {0x010F30, 0x0004}, + {0x010F46, 0x0010}, {0x010F51, 0x0002}, {0x010F55, 0x0020}, {0x010F5A, 0x0001}, {0x010F70, 0x0004}, {0x010F82, 0x0010}, + {0x010F86, 0x0020}, {0x010F8A, 0x0001}, {0x010FB0, 0x0004}, {0x010FC5, 0x0002}, {0x010FCC, 0x0001}, {0x010FE0, 0x0004}, + {0x010FF7, 0x0001}, {0x011000, 0x0010}, {0x011003, 0x0004}, {0x011038, 0x0010}, {0x011047, 0x0020}, {0x01104E, 0x0001}, + {0x011052, 0x0002}, {0x011070, 0x0010}, {0x011071, 0x0004}, {0x011073, 0x0010}, {0x011075, 0x0004}, {0x011076, 0x0001}, + {0x01107F, 0x0010}, {0x011083, 0x0004}, {0x0110B0, 0x0010}, {0x0110BB, 0x0020}, {0x0110BD, 0x0080}, {0x0110BE, 0x0020}, + {0x0110C2, 0x0010}, {0x0110C3, 0x0001}, {0x0110CD, 0x0080}, {0x0110CE, 0x0001}, {0x0110D0, 0x0004}, {0x0110E9, 0x0001}, + {0x0110F0, 0x0002}, {0x0110FA, 0x0001}, {0x011100, 0x0010}, {0x011103, 0x0004}, {0x011127, 0x0010}, {0x011135, 0x0001}, + {0x011136, 0x0002}, {0x011140, 0x0020}, {0x011144, 0x0004}, {0x011145, 0x0010}, {0x011147, 0x0004}, {0x011148, 0x0001}, + {0x011150, 0x0004}, {0x011173, 0x0010}, {0x011174, 0x0020}, {0x011176, 0x0004}, {0x011177, 0x0001}, {0x011180, 0x0010}, + {0x011183, 0x0004}, {0x0111B3, 0x0010}, {0x0111C1, 0x0004}, {0x0111C5, 0x0020}, {0x0111C9, 0x0010}, {0x0111CD, 0x0020}, + {0x0111CE, 0x0010}, {0x0111D0, 0x0002}, {0x0111DA, 0x0004}, {0x0111DB, 0x0020}, {0x0111DC, 0x0004}, {0x0111DD, 0x0020}, + {0x0111E0, 0x0001}, {0x0111E1, 0x0002}, {0x0111F5, 0x0001}, {0x011200, 0x0004}, {0x011212, 0x0001}, {0x011213, 0x0004}, + {0x01122C, 0x0010}, {0x011238, 0x0020}, {0x01123E, 0x0010}, {0x01123F, 0x0004}, {0x011241, 0x0010}, {0x011242, 0x0001}, + {0x011280, 0x0004}, {0x011287, 0x0001}, {0x011288, 0x0004}, {0x011289, 0x0001}, {0x01128A, 0x0004}, {0x01128E, 0x0001}, + {0x01128F, 0x0004}, {0x01129E, 0x0001}, {0x01129F, 0x0004}, {0x0112A9, 0x0020}, {0x0112AA, 0x0001}, {0x0112B0, 0x0004}, + {0x0112DF, 0x0010}, {0x0112EB, 0x0001}, {0x0112F0, 0x0002}, {0x0112FA, 0x0001}, {0x011300, 0x0010}, {0x011304, 0x0001}, + {0x011305, 0x0004}, {0x01130D, 0x0001}, {0x01130F, 0x0004}, {0x011311, 0x0001}, {0x011313, 0x0004}, {0x011329, 0x0001}, + {0x01132A, 0x0004}, {0x011331, 0x0001}, {0x011332, 0x0004}, {0x011334, 0x0001}, {0x011335, 0x0004}, {0x01133A, 0x0001}, + {0x01133B, 0x0010}, {0x01133D, 0x0004}, {0x01133E, 0x0010}, {0x011345, 0x0001}, {0x011347, 0x0010}, {0x011349, 0x0001}, + {0x01134B, 0x0010}, {0x01134E, 0x0001}, {0x011350, 0x0004}, {0x011351, 0x0001}, {0x011357, 0x0010}, {0x011358, 0x0001}, + {0x01135D, 0x0004}, {0x011362, 0x0010}, {0x011364, 0x0001}, {0x011366, 0x0010}, {0x01136D, 0x0001}, {0x011370, 0x0010}, + {0x011375, 0x0001}, {0x011400, 0x0004}, {0x011435, 0x0010}, {0x011447, 0x0004}, {0x01144B, 0x0020}, {0x011450, 0x0002}, + {0x01145A, 0x0020}, {0x01145C, 0x0001}, {0x01145D, 0x0020}, {0x01145E, 0x0010}, {0x01145F, 0x0004}, {0x011462, 0x0001}, + {0x011480, 0x0004}, {0x0114B0, 0x0010}, {0x0114C4, 0x0004}, {0x0114C6, 0x0020}, {0x0114C7, 0x0004}, {0x0114C8, 0x0001}, + {0x0114D0, 0x0002}, {0x0114DA, 0x0001}, {0x011580, 0x0004}, {0x0115AF, 0x0010}, {0x0115B6, 0x0001}, {0x0115B8, 0x0010}, + {0x0115C1, 0x0020}, {0x0115D8, 0x0004}, {0x0115DC, 0x0010}, {0x0115DE, 0x0001}, {0x011600, 0x0004}, {0x011630, 0x0010}, + {0x011641, 0x0020}, {0x011644, 0x0004}, {0x011645, 0x0001}, {0x011650, 0x0002}, {0x01165A, 0x0001}, {0x011660, 0x0020}, + {0x01166D, 0x0001}, {0x011680, 0x0004}, {0x0116AB, 0x0010}, {0x0116B8, 0x0004}, {0x0116B9, 0x0020}, {0x0116BA, 0x0001}, + {0x0116C0, 0x0002}, {0x0116CA, 0x0001}, {0x011700, 0x0004}, {0x01171B, 0x0001}, {0x01171D, 0x0010}, {0x01172C, 0x0001}, + {0x011730, 0x0002}, {0x01173C, 0x0020}, {0x01173F, 0x0040}, {0x011740, 0x0004}, {0x011747, 0x0001}, {0x011800, 0x0004}, + {0x01182C, 0x0010}, {0x01183B, 0x0020}, {0x01183C, 0x0001}, {0x0118A0, 0x0004}, {0x0118E0, 0x0002}, {0x0118F3, 0x0001}, + {0x0118FF, 0x0004}, {0x011907, 0x0001}, {0x011909, 0x0004}, {0x01190A, 0x0001}, {0x01190C, 0x0004}, {0x011914, 0x0001}, + {0x011915, 0x0004}, {0x011917, 0x0001}, {0x011918, 0x0004}, {0x011930, 0x0010}, {0x011936, 0x0001}, {0x011937, 0x0010}, + {0x011939, 0x0001}, {0x01193B, 0x0010}, {0x01193F, 0x0004}, {0x011940, 0x0010}, {0x011941, 0x0004}, {0x011942, 0x0010}, + {0x011944, 0x0020}, {0x011947, 0x0001}, {0x011950, 0x0002}, {0x01195A, 0x0001}, {0x0119A0, 0x0004}, {0x0119A8, 0x0001}, + {0x0119AA, 0x0004}, {0x0119D1, 0x0010}, {0x0119D8, 0x0001}, {0x0119DA, 0x0010}, {0x0119E1, 0x0004}, {0x0119E2, 0x0020}, + {0x0119E3, 0x0004}, {0x0119E4, 0x0010}, {0x0119E5, 0x0001}, {0x011A00, 0x0004}, {0x011A01, 0x0010}, {0x011A0B, 0x0004}, + {0x011A33, 0x0010}, {0x011A3A, 0x0004}, {0x011A3B, 0x0010}, {0x011A3F, 0x0020}, {0x011A47, 0x0010}, {0x011A48, 0x0001}, + {0x011A50, 0x0004}, {0x011A51, 0x0010}, {0x011A5C, 0x0004}, {0x011A8A, 0x0010}, {0x011A9A, 0x0020}, {0x011A9D, 0x0004}, + {0x011A9E, 0x0020}, {0x011AA3, 0x0001}, {0x011AB0, 0x0004}, {0x011AF9, 0x0001}, {0x011B00, 0x0020}, {0x011B0A, 0x0001}, + {0x011C00, 0x0004}, {0x011C09, 0x0001}, {0x011C0A, 0x0004}, {0x011C2F, 0x0010}, {0x011C37, 0x0001}, {0x011C38, 0x0010}, + {0x011C40, 0x0004}, {0x011C41, 0x0020}, {0x011C46, 0x0001}, {0x011C50, 0x0002}, {0x011C6D, 0x0001}, {0x011C70, 0x0020}, + {0x011C72, 0x0004}, {0x011C90, 0x0001}, {0x011C92, 0x0010}, {0x011CA8, 0x0001}, {0x011CA9, 0x0010}, {0x011CB7, 0x0001}, + {0x011D00, 0x0004}, {0x011D07, 0x0001}, {0x011D08, 0x0004}, {0x011D0A, 0x0001}, {0x011D0B, 0x0004}, {0x011D31, 0x0010}, + {0x011D37, 0x0001}, {0x011D3A, 0x0010}, {0x011D3B, 0x0001}, {0x011D3C, 0x0010}, {0x011D3E, 0x0001}, {0x011D3F, 0x0010}, + {0x011D46, 0x0004}, {0x011D47, 0x0010}, {0x011D48, 0x0001}, {0x011D50, 0x0002}, {0x011D5A, 0x0001}, {0x011D60, 0x0004}, + {0x011D66, 0x0001}, {0x011D67, 0x0004}, {0x011D69, 0x0001}, {0x011D6A, 0x0004}, {0x011D8A, 0x0010}, {0x011D8F, 0x0001}, + {0x011D90, 0x0010}, {0x011D92, 0x0001}, {0x011D93, 0x0010}, {0x011D98, 0x0004}, {0x011D99, 0x0001}, {0x011DA0, 0x0002}, + {0x011DAA, 0x0001}, {0x011EE0, 0x0004}, {0x011EF3, 0x0010}, {0x011EF7, 0x0020}, {0x011EF9, 0x0001}, {0x011F00, 0x0010}, + {0x011F02, 0x0004}, {0x011F03, 0x0010}, {0x011F04, 0x0004}, {0x011F11, 0x0001}, {0x011F12, 0x0004}, {0x011F34, 0x0010}, + {0x011F3B, 0x0001}, {0x011F3E, 0x0010}, {0x011F43, 0x0020}, {0x011F50, 0x0002}, {0x011F5A, 0x0001}, {0x011FB0, 0x0004}, + {0x011FB1, 0x0001}, {0x011FC0, 0x0002}, {0x011FD5, 0x0040}, {0x011FF2, 0x0001}, {0x011FFF, 0x0020}, {0x012000, 0x0004}, + {0x01239A, 0x0001}, {0x012400, 0x0002}, {0x01246F, 0x0001}, {0x012470, 0x0020}, {0x012475, 0x0001}, {0x012480, 0x0004}, + {0x012544, 0x0001}, {0x012F90, 0x0004}, {0x012FF1, 0x0020}, {0x012FF3, 0x0001}, {0x013000, 0x0004}, {0x013430, 0x0080}, + {0x013440, 0x0010}, {0x013441, 0x0004}, {0x013447, 0x0010}, {0x013456, 0x0001}, {0x014400, 0x0004}, {0x014647, 0x0001}, + {0x016800, 0x0004}, {0x016A39, 0x0001}, {0x016A40, 0x0004}, {0x016A5F, 0x0001}, {0x016A60, 0x0002}, {0x016A6A, 0x0001}, + {0x016A6E, 0x0020}, {0x016A70, 0x0004}, {0x016ABF, 0x0001}, {0x016AC0, 0x0002}, {0x016ACA, 0x0001}, {0x016AD0, 0x0004}, + {0x016AEE, 0x0001}, {0x016AF0, 0x0010}, {0x016AF5, 0x0020}, {0x016AF6, 0x0001}, {0x016B00, 0x0004}, {0x016B30, 0x0010}, + {0x016B37, 0x0020}, {0x016B3C, 0x0040}, {0x016B40, 0x0004}, {0x016B44, 0x0020}, {0x016B45, 0x0040}, {0x016B46, 0x0001}, + {0x016B50, 0x0002}, {0x016B5A, 0x0001}, {0x016B5B, 0x0002}, {0x016B62, 0x0001}, {0x016B63, 0x0004}, {0x016B78, 0x0001}, + {0x016B7D, 0x0004}, {0x016B90, 0x0001}, {0x016E40, 0x0004}, {0x016E80, 0x0002}, {0x016E97, 0x0020}, {0x016E9B, 0x0001}, + {0x016F00, 0x0004}, {0x016F4B, 0x0001}, {0x016F4F, 0x0010}, {0x016F50, 0x0004}, {0x016F51, 0x0010}, {0x016F88, 0x0001}, + {0x016F8F, 0x0010}, {0x016F93, 0x0004}, {0x016FA0, 0x0001}, {0x016FE0, 0x0004}, {0x016FE2, 0x0020}, {0x016FE3, 0x0004}, + {0x016FE4, 0x0010}, {0x016FE5, 0x0001}, {0x016FF0, 0x0010}, {0x016FF2, 0x0001}, {0x017000, 0x0004}, {0x0187F8, 0x0001}, + {0x018800, 0x0004}, {0x018CD6, 0x0001}, {0x018D00, 0x0004}, {0x018D09, 0x0001}, {0x01AFF0, 0x0004}, {0x01AFF4, 0x0001}, + {0x01AFF5, 0x0004}, {0x01AFFC, 0x0001}, {0x01AFFD, 0x0004}, {0x01AFFF, 0x0001}, {0x01B000, 0x0004}, {0x01B123, 0x0001}, + {0x01B132, 0x0004}, {0x01B133, 0x0001}, {0x01B150, 0x0004}, {0x01B153, 0x0001}, {0x01B155, 0x0004}, {0x01B156, 0x0001}, + {0x01B164, 0x0004}, {0x01B168, 0x0001}, {0x01B170, 0x0004}, {0x01B2FC, 0x0001}, {0x01BC00, 0x0004}, {0x01BC6B, 0x0001}, + {0x01BC70, 0x0004}, {0x01BC7D, 0x0001}, {0x01BC80, 0x0004}, {0x01BC89, 0x0001}, {0x01BC90, 0x0004}, {0x01BC9A, 0x0001}, + {0x01BC9C, 0x0040}, {0x01BC9D, 0x0010}, {0x01BC9F, 0x0020}, {0x01BCA0, 0x0080}, {0x01BCA4, 0x0001}, {0x01CF00, 0x0010}, + {0x01CF2E, 0x0001}, {0x01CF30, 0x0010}, {0x01CF47, 0x0001}, {0x01CF50, 0x0040}, {0x01CFC4, 0x0001}, {0x01D000, 0x0040}, + {0x01D0F6, 0x0001}, {0x01D100, 0x0040}, {0x01D127, 0x0001}, {0x01D129, 0x0040}, {0x01D165, 0x0010}, {0x01D16A, 0x0040}, + {0x01D16D, 0x0010}, {0x01D173, 0x0080}, {0x01D17B, 0x0010}, {0x01D183, 0x0040}, {0x01D185, 0x0010}, {0x01D18C, 0x0040}, + {0x01D1AA, 0x0010}, {0x01D1AE, 0x0040}, {0x01D1EB, 0x0001}, {0x01D200, 0x0040}, {0x01D242, 0x0010}, {0x01D245, 0x0040}, + {0x01D246, 0x0001}, {0x01D2C0, 0x0002}, {0x01D2D4, 0x0001}, {0x01D2E0, 0x0002}, {0x01D2F4, 0x0001}, {0x01D300, 0x0040}, + {0x01D357, 0x0001}, {0x01D360, 0x0002}, {0x01D379, 0x0001}, {0x01D400, 0x0004}, {0x01D455, 0x0001}, {0x01D456, 0x0004}, + {0x01D49D, 0x0001}, {0x01D49E, 0x0004}, {0x01D4A0, 0x0001}, {0x01D4A2, 0x0004}, {0x01D4A3, 0x0001}, {0x01D4A5, 0x0004}, + {0x01D4A7, 0x0001}, {0x01D4A9, 0x0004}, {0x01D4AD, 0x0001}, {0x01D4AE, 0x0004}, {0x01D4BA, 0x0001}, {0x01D4BB, 0x0004}, + {0x01D4BC, 0x0001}, {0x01D4BD, 0x0004}, {0x01D4C4, 0x0001}, {0x01D4C5, 0x0004}, {0x01D506, 0x0001}, {0x01D507, 0x0004}, + {0x01D50B, 0x0001}, {0x01D50D, 0x0004}, {0x01D515, 0x0001}, {0x01D516, 0x0004}, {0x01D51D, 0x0001}, {0x01D51E, 0x0004}, + {0x01D53A, 0x0001}, {0x01D53B, 0x0004}, {0x01D53F, 0x0001}, {0x01D540, 0x0004}, {0x01D545, 0x0001}, {0x01D546, 0x0004}, + {0x01D547, 0x0001}, {0x01D54A, 0x0004}, {0x01D551, 0x0001}, {0x01D552, 0x0004}, {0x01D6A6, 0x0001}, {0x01D6A8, 0x0004}, + {0x01D6C1, 0x0040}, {0x01D6C2, 0x0004}, {0x01D6DB, 0x0040}, {0x01D6DC, 0x0004}, {0x01D6FB, 0x0040}, {0x01D6FC, 0x0004}, + {0x01D715, 0x0040}, {0x01D716, 0x0004}, {0x01D735, 0x0040}, {0x01D736, 0x0004}, {0x01D74F, 0x0040}, {0x01D750, 0x0004}, + {0x01D76F, 0x0040}, {0x01D770, 0x0004}, {0x01D789, 0x0040}, {0x01D78A, 0x0004}, {0x01D7A9, 0x0040}, {0x01D7AA, 0x0004}, + {0x01D7C3, 0x0040}, {0x01D7C4, 0x0004}, {0x01D7CC, 0x0001}, {0x01D7CE, 0x0002}, {0x01D800, 0x0040}, {0x01DA00, 0x0010}, + {0x01DA37, 0x0040}, {0x01DA3B, 0x0010}, {0x01DA6D, 0x0040}, {0x01DA75, 0x0010}, {0x01DA76, 0x0040}, {0x01DA84, 0x0010}, + {0x01DA85, 0x0040}, {0x01DA87, 0x0020}, {0x01DA8C, 0x0001}, {0x01DA9B, 0x0010}, {0x01DAA0, 0x0001}, {0x01DAA1, 0x0010}, + {0x01DAB0, 0x0001}, {0x01DF00, 0x0004}, {0x01DF1F, 0x0001}, {0x01DF25, 0x0004}, {0x01DF2B, 0x0001}, {0x01E000, 0x0010}, + {0x01E007, 0x0001}, {0x01E008, 0x0010}, {0x01E019, 0x0001}, {0x01E01B, 0x0010}, {0x01E022, 0x0001}, {0x01E023, 0x0010}, + {0x01E025, 0x0001}, {0x01E026, 0x0010}, {0x01E02B, 0x0001}, {0x01E030, 0x0004}, {0x01E06E, 0x0001}, {0x01E08F, 0x0010}, + {0x01E090, 0x0001}, {0x01E100, 0x0004}, {0x01E12D, 0x0001}, {0x01E130, 0x0010}, {0x01E137, 0x0004}, {0x01E13E, 0x0001}, + {0x01E140, 0x0002}, {0x01E14A, 0x0001}, {0x01E14E, 0x0004}, {0x01E14F, 0x0040}, {0x01E150, 0x0001}, {0x01E290, 0x0004}, + {0x01E2AE, 0x0010}, {0x01E2AF, 0x0001}, {0x01E2C0, 0x0004}, {0x01E2EC, 0x0010}, {0x01E2F0, 0x0002}, {0x01E2FA, 0x0001}, + {0x01E2FF, 0x0040}, {0x01E300, 0x0001}, {0x01E4D0, 0x0004}, {0x01E4EC, 0x0010}, {0x01E4F0, 0x0002}, {0x01E4FA, 0x0001}, + {0x01E7E0, 0x0004}, {0x01E7E7, 0x0001}, {0x01E7E8, 0x0004}, {0x01E7EC, 0x0001}, {0x01E7ED, 0x0004}, {0x01E7EF, 0x0001}, + {0x01E7F0, 0x0004}, {0x01E7FF, 0x0001}, {0x01E800, 0x0004}, {0x01E8C5, 0x0001}, {0x01E8C7, 0x0002}, {0x01E8D0, 0x0010}, + {0x01E8D7, 0x0001}, {0x01E900, 0x0004}, {0x01E944, 0x0010}, {0x01E94B, 0x0004}, {0x01E94C, 0x0001}, {0x01E950, 0x0002}, + {0x01E95A, 0x0001}, {0x01E95E, 0x0020}, {0x01E960, 0x0001}, {0x01EC71, 0x0002}, {0x01ECAC, 0x0040}, {0x01ECAD, 0x0002}, + {0x01ECB0, 0x0040}, {0x01ECB1, 0x0002}, {0x01ECB5, 0x0001}, {0x01ED01, 0x0002}, {0x01ED2E, 0x0040}, {0x01ED2F, 0x0002}, + {0x01ED3E, 0x0001}, {0x01EE00, 0x0004}, {0x01EE04, 0x0001}, {0x01EE05, 0x0004}, {0x01EE20, 0x0001}, {0x01EE21, 0x0004}, + {0x01EE23, 0x0001}, {0x01EE24, 0x0004}, {0x01EE25, 0x0001}, {0x01EE27, 0x0004}, {0x01EE28, 0x0001}, {0x01EE29, 0x0004}, + {0x01EE33, 0x0001}, {0x01EE34, 0x0004}, {0x01EE38, 0x0001}, {0x01EE39, 0x0004}, {0x01EE3A, 0x0001}, {0x01EE3B, 0x0004}, + {0x01EE3C, 0x0001}, {0x01EE42, 0x0004}, {0x01EE43, 0x0001}, {0x01EE47, 0x0004}, {0x01EE48, 0x0001}, {0x01EE49, 0x0004}, + {0x01EE4A, 0x0001}, {0x01EE4B, 0x0004}, {0x01EE4C, 0x0001}, {0x01EE4D, 0x0004}, {0x01EE50, 0x0001}, {0x01EE51, 0x0004}, + {0x01EE53, 0x0001}, {0x01EE54, 0x0004}, {0x01EE55, 0x0001}, {0x01EE57, 0x0004}, {0x01EE58, 0x0001}, {0x01EE59, 0x0004}, + {0x01EE5A, 0x0001}, {0x01EE5B, 0x0004}, {0x01EE5C, 0x0001}, {0x01EE5D, 0x0004}, {0x01EE5E, 0x0001}, {0x01EE5F, 0x0004}, + {0x01EE60, 0x0001}, {0x01EE61, 0x0004}, {0x01EE63, 0x0001}, {0x01EE64, 0x0004}, {0x01EE65, 0x0001}, {0x01EE67, 0x0004}, + {0x01EE6B, 0x0001}, {0x01EE6C, 0x0004}, {0x01EE73, 0x0001}, {0x01EE74, 0x0004}, {0x01EE78, 0x0001}, {0x01EE79, 0x0004}, + {0x01EE7D, 0x0001}, {0x01EE7E, 0x0004}, {0x01EE7F, 0x0001}, {0x01EE80, 0x0004}, {0x01EE8A, 0x0001}, {0x01EE8B, 0x0004}, + {0x01EE9C, 0x0001}, {0x01EEA1, 0x0004}, {0x01EEA4, 0x0001}, {0x01EEA5, 0x0004}, {0x01EEAA, 0x0001}, {0x01EEAB, 0x0004}, + {0x01EEBC, 0x0001}, {0x01EEF0, 0x0040}, {0x01EEF2, 0x0001}, {0x01F000, 0x0040}, {0x01F02C, 0x0001}, {0x01F030, 0x0040}, + {0x01F094, 0x0001}, {0x01F0A0, 0x0040}, {0x01F0AF, 0x0001}, {0x01F0B1, 0x0040}, {0x01F0C0, 0x0001}, {0x01F0C1, 0x0040}, + {0x01F0D0, 0x0001}, {0x01F0D1, 0x0040}, {0x01F0F6, 0x0001}, {0x01F100, 0x0002}, {0x01F10D, 0x0040}, {0x01F1AE, 0x0001}, + {0x01F1E6, 0x0040}, {0x01F203, 0x0001}, {0x01F210, 0x0040}, {0x01F23C, 0x0001}, {0x01F240, 0x0040}, {0x01F249, 0x0001}, + {0x01F250, 0x0040}, {0x01F252, 0x0001}, {0x01F260, 0x0040}, {0x01F266, 0x0001}, {0x01F300, 0x0040}, {0x01F6D8, 0x0001}, + {0x01F6DC, 0x0040}, {0x01F6ED, 0x0001}, {0x01F6F0, 0x0040}, {0x01F6FD, 0x0001}, {0x01F700, 0x0040}, {0x01F777, 0x0001}, + {0x01F77B, 0x0040}, {0x01F7DA, 0x0001}, {0x01F7E0, 0x0040}, {0x01F7EC, 0x0001}, {0x01F7F0, 0x0040}, {0x01F7F1, 0x0001}, + {0x01F800, 0x0040}, {0x01F80C, 0x0001}, {0x01F810, 0x0040}, {0x01F848, 0x0001}, {0x01F850, 0x0040}, {0x01F85A, 0x0001}, + {0x01F860, 0x0040}, {0x01F888, 0x0001}, {0x01F890, 0x0040}, {0x01F8AE, 0x0001}, {0x01F8B0, 0x0040}, {0x01F8B2, 0x0001}, + {0x01F900, 0x0040}, {0x01FA54, 0x0001}, {0x01FA60, 0x0040}, {0x01FA6E, 0x0001}, {0x01FA70, 0x0040}, {0x01FA7D, 0x0001}, + {0x01FA80, 0x0040}, {0x01FA89, 0x0001}, {0x01FA90, 0x0040}, {0x01FABE, 0x0001}, {0x01FABF, 0x0040}, {0x01FAC6, 0x0001}, + {0x01FACE, 0x0040}, {0x01FADC, 0x0001}, {0x01FAE0, 0x0040}, {0x01FAE9, 0x0001}, {0x01FAF0, 0x0040}, {0x01FAF9, 0x0001}, + {0x01FB00, 0x0040}, {0x01FB93, 0x0001}, {0x01FB94, 0x0040}, {0x01FBCB, 0x0001}, {0x01FBF0, 0x0002}, {0x01FBFA, 0x0001}, + {0x020000, 0x0004}, {0x02A6E0, 0x0001}, {0x02A700, 0x0004}, {0x02B73A, 0x0001}, {0x02B740, 0x0004}, {0x02B81E, 0x0001}, + {0x02B820, 0x0004}, {0x02CEA2, 0x0001}, {0x02CEB0, 0x0004}, {0x02EBE1, 0x0001}, {0x02EBF0, 0x0004}, {0x02EE5E, 0x0001}, + {0x02F800, 0x0004}, {0x02FA1E, 0x0001}, {0x030000, 0x0004}, {0x03134B, 0x0001}, {0x031350, 0x0004}, {0x0323B0, 0x0001}, + {0x0E0001, 0x0080}, {0x0E0002, 0x0001}, {0x0E0020, 0x0080}, {0x0E0080, 0x0001}, {0x0E0100, 0x0010}, {0x0E01F0, 0x0001}, + {0x0F0000, 0x0080}, {0x0FFFFE, 0x0001}, {0x100000, 0x0080}, {0x10FFFE, 0x0001}, {0x110000, 0x0000}, +}; + +// list is always in ascending order, to enable binary search +const std::initializer_list> unicode_map_lowercase = { + {0x000041, 0x000061}, {0x000042, 0x000062}, {0x000043, 0x000063}, {0x000044, 0x000064}, {0x000045, 0x000065}, + {0x000046, 0x000066}, {0x000047, 0x000067}, {0x000048, 0x000068}, {0x000049, 0x000069}, {0x00004A, 0x00006A}, + {0x00004B, 0x00006B}, {0x00004C, 0x00006C}, {0x00004D, 0x00006D}, {0x00004E, 0x00006E}, {0x00004F, 0x00006F}, + {0x000050, 0x000070}, {0x000051, 0x000071}, {0x000052, 0x000072}, {0x000053, 0x000073}, {0x000054, 0x000074}, + {0x000055, 0x000075}, {0x000056, 0x000076}, {0x000057, 0x000077}, {0x000058, 0x000078}, {0x000059, 0x000079}, + {0x00005A, 0x00007A}, {0x0000C0, 0x0000E0}, {0x0000C1, 0x0000E1}, {0x0000C2, 0x0000E2}, {0x0000C3, 0x0000E3}, + {0x0000C4, 0x0000E4}, {0x0000C5, 0x0000E5}, {0x0000C6, 0x0000E6}, {0x0000C7, 0x0000E7}, {0x0000C8, 0x0000E8}, + {0x0000C9, 0x0000E9}, {0x0000CA, 0x0000EA}, {0x0000CB, 0x0000EB}, {0x0000CC, 0x0000EC}, {0x0000CD, 0x0000ED}, + {0x0000CE, 0x0000EE}, {0x0000CF, 0x0000EF}, {0x0000D0, 0x0000F0}, {0x0000D1, 0x0000F1}, {0x0000D2, 0x0000F2}, + {0x0000D3, 0x0000F3}, {0x0000D4, 0x0000F4}, {0x0000D5, 0x0000F5}, {0x0000D6, 0x0000F6}, {0x0000D8, 0x0000F8}, + {0x0000D9, 0x0000F9}, {0x0000DA, 0x0000FA}, {0x0000DB, 0x0000FB}, {0x0000DC, 0x0000FC}, {0x0000DD, 0x0000FD}, + {0x0000DE, 0x0000FE}, {0x000100, 0x000101}, {0x000102, 0x000103}, {0x000104, 0x000105}, {0x000106, 0x000107}, + {0x000108, 0x000109}, {0x00010A, 0x00010B}, {0x00010C, 0x00010D}, {0x00010E, 0x00010F}, {0x000110, 0x000111}, + {0x000112, 0x000113}, {0x000114, 0x000115}, {0x000116, 0x000117}, {0x000118, 0x000119}, {0x00011A, 0x00011B}, + {0x00011C, 0x00011D}, {0x00011E, 0x00011F}, {0x000120, 0x000121}, {0x000122, 0x000123}, {0x000124, 0x000125}, + {0x000126, 0x000127}, {0x000128, 0x000129}, {0x00012A, 0x00012B}, {0x00012C, 0x00012D}, {0x00012E, 0x00012F}, + {0x000130, 0x000069}, {0x000132, 0x000133}, {0x000134, 0x000135}, {0x000136, 0x000137}, {0x000139, 0x00013A}, + {0x00013B, 0x00013C}, {0x00013D, 0x00013E}, {0x00013F, 0x000140}, {0x000141, 0x000142}, {0x000143, 0x000144}, + {0x000145, 0x000146}, {0x000147, 0x000148}, {0x00014A, 0x00014B}, {0x00014C, 0x00014D}, {0x00014E, 0x00014F}, + {0x000150, 0x000151}, {0x000152, 0x000153}, {0x000154, 0x000155}, {0x000156, 0x000157}, {0x000158, 0x000159}, + {0x00015A, 0x00015B}, {0x00015C, 0x00015D}, {0x00015E, 0x00015F}, {0x000160, 0x000161}, {0x000162, 0x000163}, + {0x000164, 0x000165}, {0x000166, 0x000167}, {0x000168, 0x000169}, {0x00016A, 0x00016B}, {0x00016C, 0x00016D}, + {0x00016E, 0x00016F}, {0x000170, 0x000171}, {0x000172, 0x000173}, {0x000174, 0x000175}, {0x000176, 0x000177}, + {0x000178, 0x0000FF}, {0x000179, 0x00017A}, {0x00017B, 0x00017C}, {0x00017D, 0x00017E}, {0x000181, 0x000253}, + {0x000182, 0x000183}, {0x000184, 0x000185}, {0x000186, 0x000254}, {0x000187, 0x000188}, {0x000189, 0x000256}, + {0x00018A, 0x000257}, {0x00018B, 0x00018C}, {0x00018E, 0x0001DD}, {0x00018F, 0x000259}, {0x000190, 0x00025B}, + {0x000191, 0x000192}, {0x000193, 0x000260}, {0x000194, 0x000263}, {0x000196, 0x000269}, {0x000197, 0x000268}, + {0x000198, 0x000199}, {0x00019C, 0x00026F}, {0x00019D, 0x000272}, {0x00019F, 0x000275}, {0x0001A0, 0x0001A1}, + {0x0001A2, 0x0001A3}, {0x0001A4, 0x0001A5}, {0x0001A6, 0x000280}, {0x0001A7, 0x0001A8}, {0x0001A9, 0x000283}, + {0x0001AC, 0x0001AD}, {0x0001AE, 0x000288}, {0x0001AF, 0x0001B0}, {0x0001B1, 0x00028A}, {0x0001B2, 0x00028B}, + {0x0001B3, 0x0001B4}, {0x0001B5, 0x0001B6}, {0x0001B7, 0x000292}, {0x0001B8, 0x0001B9}, {0x0001BC, 0x0001BD}, + {0x0001C4, 0x0001C6}, {0x0001C5, 0x0001C6}, {0x0001C7, 0x0001C9}, {0x0001C8, 0x0001C9}, {0x0001CA, 0x0001CC}, + {0x0001CB, 0x0001CC}, {0x0001CD, 0x0001CE}, {0x0001CF, 0x0001D0}, {0x0001D1, 0x0001D2}, {0x0001D3, 0x0001D4}, + {0x0001D5, 0x0001D6}, {0x0001D7, 0x0001D8}, {0x0001D9, 0x0001DA}, {0x0001DB, 0x0001DC}, {0x0001DE, 0x0001DF}, + {0x0001E0, 0x0001E1}, {0x0001E2, 0x0001E3}, {0x0001E4, 0x0001E5}, {0x0001E6, 0x0001E7}, {0x0001E8, 0x0001E9}, + {0x0001EA, 0x0001EB}, {0x0001EC, 0x0001ED}, {0x0001EE, 0x0001EF}, {0x0001F1, 0x0001F3}, {0x0001F2, 0x0001F3}, + {0x0001F4, 0x0001F5}, {0x0001F6, 0x000195}, {0x0001F7, 0x0001BF}, {0x0001F8, 0x0001F9}, {0x0001FA, 0x0001FB}, + {0x0001FC, 0x0001FD}, {0x0001FE, 0x0001FF}, {0x000200, 0x000201}, {0x000202, 0x000203}, {0x000204, 0x000205}, + {0x000206, 0x000207}, {0x000208, 0x000209}, {0x00020A, 0x00020B}, {0x00020C, 0x00020D}, {0x00020E, 0x00020F}, + {0x000210, 0x000211}, {0x000212, 0x000213}, {0x000214, 0x000215}, {0x000216, 0x000217}, {0x000218, 0x000219}, + {0x00021A, 0x00021B}, {0x00021C, 0x00021D}, {0x00021E, 0x00021F}, {0x000220, 0x00019E}, {0x000222, 0x000223}, + {0x000224, 0x000225}, {0x000226, 0x000227}, {0x000228, 0x000229}, {0x00022A, 0x00022B}, {0x00022C, 0x00022D}, + {0x00022E, 0x00022F}, {0x000230, 0x000231}, {0x000232, 0x000233}, {0x00023A, 0x002C65}, {0x00023B, 0x00023C}, + {0x00023D, 0x00019A}, {0x00023E, 0x002C66}, {0x000241, 0x000242}, {0x000243, 0x000180}, {0x000244, 0x000289}, + {0x000245, 0x00028C}, {0x000246, 0x000247}, {0x000248, 0x000249}, {0x00024A, 0x00024B}, {0x00024C, 0x00024D}, + {0x00024E, 0x00024F}, {0x000370, 0x000371}, {0x000372, 0x000373}, {0x000376, 0x000377}, {0x00037F, 0x0003F3}, + {0x000386, 0x0003AC}, {0x000388, 0x0003AD}, {0x000389, 0x0003AE}, {0x00038A, 0x0003AF}, {0x00038C, 0x0003CC}, + {0x00038E, 0x0003CD}, {0x00038F, 0x0003CE}, {0x000391, 0x0003B1}, {0x000392, 0x0003B2}, {0x000393, 0x0003B3}, + {0x000394, 0x0003B4}, {0x000395, 0x0003B5}, {0x000396, 0x0003B6}, {0x000397, 0x0003B7}, {0x000398, 0x0003B8}, + {0x000399, 0x0003B9}, {0x00039A, 0x0003BA}, {0x00039B, 0x0003BB}, {0x00039C, 0x0003BC}, {0x00039D, 0x0003BD}, + {0x00039E, 0x0003BE}, {0x00039F, 0x0003BF}, {0x0003A0, 0x0003C0}, {0x0003A1, 0x0003C1}, {0x0003A3, 0x0003C3}, + {0x0003A4, 0x0003C4}, {0x0003A5, 0x0003C5}, {0x0003A6, 0x0003C6}, {0x0003A7, 0x0003C7}, {0x0003A8, 0x0003C8}, + {0x0003A9, 0x0003C9}, {0x0003AA, 0x0003CA}, {0x0003AB, 0x0003CB}, {0x0003CF, 0x0003D7}, {0x0003D8, 0x0003D9}, + {0x0003DA, 0x0003DB}, {0x0003DC, 0x0003DD}, {0x0003DE, 0x0003DF}, {0x0003E0, 0x0003E1}, {0x0003E2, 0x0003E3}, + {0x0003E4, 0x0003E5}, {0x0003E6, 0x0003E7}, {0x0003E8, 0x0003E9}, {0x0003EA, 0x0003EB}, {0x0003EC, 0x0003ED}, + {0x0003EE, 0x0003EF}, {0x0003F4, 0x0003B8}, {0x0003F7, 0x0003F8}, {0x0003F9, 0x0003F2}, {0x0003FA, 0x0003FB}, + {0x0003FD, 0x00037B}, {0x0003FE, 0x00037C}, {0x0003FF, 0x00037D}, {0x000400, 0x000450}, {0x000401, 0x000451}, + {0x000402, 0x000452}, {0x000403, 0x000453}, {0x000404, 0x000454}, {0x000405, 0x000455}, {0x000406, 0x000456}, + {0x000407, 0x000457}, {0x000408, 0x000458}, {0x000409, 0x000459}, {0x00040A, 0x00045A}, {0x00040B, 0x00045B}, + {0x00040C, 0x00045C}, {0x00040D, 0x00045D}, {0x00040E, 0x00045E}, {0x00040F, 0x00045F}, {0x000410, 0x000430}, + {0x000411, 0x000431}, {0x000412, 0x000432}, {0x000413, 0x000433}, {0x000414, 0x000434}, {0x000415, 0x000435}, + {0x000416, 0x000436}, {0x000417, 0x000437}, {0x000418, 0x000438}, {0x000419, 0x000439}, {0x00041A, 0x00043A}, + {0x00041B, 0x00043B}, {0x00041C, 0x00043C}, {0x00041D, 0x00043D}, {0x00041E, 0x00043E}, {0x00041F, 0x00043F}, + {0x000420, 0x000440}, {0x000421, 0x000441}, {0x000422, 0x000442}, {0x000423, 0x000443}, {0x000424, 0x000444}, + {0x000425, 0x000445}, {0x000426, 0x000446}, {0x000427, 0x000447}, {0x000428, 0x000448}, {0x000429, 0x000449}, + {0x00042A, 0x00044A}, {0x00042B, 0x00044B}, {0x00042C, 0x00044C}, {0x00042D, 0x00044D}, {0x00042E, 0x00044E}, + {0x00042F, 0x00044F}, {0x000460, 0x000461}, {0x000462, 0x000463}, {0x000464, 0x000465}, {0x000466, 0x000467}, + {0x000468, 0x000469}, {0x00046A, 0x00046B}, {0x00046C, 0x00046D}, {0x00046E, 0x00046F}, {0x000470, 0x000471}, + {0x000472, 0x000473}, {0x000474, 0x000475}, {0x000476, 0x000477}, {0x000478, 0x000479}, {0x00047A, 0x00047B}, + {0x00047C, 0x00047D}, {0x00047E, 0x00047F}, {0x000480, 0x000481}, {0x00048A, 0x00048B}, {0x00048C, 0x00048D}, + {0x00048E, 0x00048F}, {0x000490, 0x000491}, {0x000492, 0x000493}, {0x000494, 0x000495}, {0x000496, 0x000497}, + {0x000498, 0x000499}, {0x00049A, 0x00049B}, {0x00049C, 0x00049D}, {0x00049E, 0x00049F}, {0x0004A0, 0x0004A1}, + {0x0004A2, 0x0004A3}, {0x0004A4, 0x0004A5}, {0x0004A6, 0x0004A7}, {0x0004A8, 0x0004A9}, {0x0004AA, 0x0004AB}, + {0x0004AC, 0x0004AD}, {0x0004AE, 0x0004AF}, {0x0004B0, 0x0004B1}, {0x0004B2, 0x0004B3}, {0x0004B4, 0x0004B5}, + {0x0004B6, 0x0004B7}, {0x0004B8, 0x0004B9}, {0x0004BA, 0x0004BB}, {0x0004BC, 0x0004BD}, {0x0004BE, 0x0004BF}, + {0x0004C0, 0x0004CF}, {0x0004C1, 0x0004C2}, {0x0004C3, 0x0004C4}, {0x0004C5, 0x0004C6}, {0x0004C7, 0x0004C8}, + {0x0004C9, 0x0004CA}, {0x0004CB, 0x0004CC}, {0x0004CD, 0x0004CE}, {0x0004D0, 0x0004D1}, {0x0004D2, 0x0004D3}, + {0x0004D4, 0x0004D5}, {0x0004D6, 0x0004D7}, {0x0004D8, 0x0004D9}, {0x0004DA, 0x0004DB}, {0x0004DC, 0x0004DD}, + {0x0004DE, 0x0004DF}, {0x0004E0, 0x0004E1}, {0x0004E2, 0x0004E3}, {0x0004E4, 0x0004E5}, {0x0004E6, 0x0004E7}, + {0x0004E8, 0x0004E9}, {0x0004EA, 0x0004EB}, {0x0004EC, 0x0004ED}, {0x0004EE, 0x0004EF}, {0x0004F0, 0x0004F1}, + {0x0004F2, 0x0004F3}, {0x0004F4, 0x0004F5}, {0x0004F6, 0x0004F7}, {0x0004F8, 0x0004F9}, {0x0004FA, 0x0004FB}, + {0x0004FC, 0x0004FD}, {0x0004FE, 0x0004FF}, {0x000500, 0x000501}, {0x000502, 0x000503}, {0x000504, 0x000505}, + {0x000506, 0x000507}, {0x000508, 0x000509}, {0x00050A, 0x00050B}, {0x00050C, 0x00050D}, {0x00050E, 0x00050F}, + {0x000510, 0x000511}, {0x000512, 0x000513}, {0x000514, 0x000515}, {0x000516, 0x000517}, {0x000518, 0x000519}, + {0x00051A, 0x00051B}, {0x00051C, 0x00051D}, {0x00051E, 0x00051F}, {0x000520, 0x000521}, {0x000522, 0x000523}, + {0x000524, 0x000525}, {0x000526, 0x000527}, {0x000528, 0x000529}, {0x00052A, 0x00052B}, {0x00052C, 0x00052D}, + {0x00052E, 0x00052F}, {0x000531, 0x000561}, {0x000532, 0x000562}, {0x000533, 0x000563}, {0x000534, 0x000564}, + {0x000535, 0x000565}, {0x000536, 0x000566}, {0x000537, 0x000567}, {0x000538, 0x000568}, {0x000539, 0x000569}, + {0x00053A, 0x00056A}, {0x00053B, 0x00056B}, {0x00053C, 0x00056C}, {0x00053D, 0x00056D}, {0x00053E, 0x00056E}, + {0x00053F, 0x00056F}, {0x000540, 0x000570}, {0x000541, 0x000571}, {0x000542, 0x000572}, {0x000543, 0x000573}, + {0x000544, 0x000574}, {0x000545, 0x000575}, {0x000546, 0x000576}, {0x000547, 0x000577}, {0x000548, 0x000578}, + {0x000549, 0x000579}, {0x00054A, 0x00057A}, {0x00054B, 0x00057B}, {0x00054C, 0x00057C}, {0x00054D, 0x00057D}, + {0x00054E, 0x00057E}, {0x00054F, 0x00057F}, {0x000550, 0x000580}, {0x000551, 0x000581}, {0x000552, 0x000582}, + {0x000553, 0x000583}, {0x000554, 0x000584}, {0x000555, 0x000585}, {0x000556, 0x000586}, {0x0010A0, 0x002D00}, + {0x0010A1, 0x002D01}, {0x0010A2, 0x002D02}, {0x0010A3, 0x002D03}, {0x0010A4, 0x002D04}, {0x0010A5, 0x002D05}, + {0x0010A6, 0x002D06}, {0x0010A7, 0x002D07}, {0x0010A8, 0x002D08}, {0x0010A9, 0x002D09}, {0x0010AA, 0x002D0A}, + {0x0010AB, 0x002D0B}, {0x0010AC, 0x002D0C}, {0x0010AD, 0x002D0D}, {0x0010AE, 0x002D0E}, {0x0010AF, 0x002D0F}, + {0x0010B0, 0x002D10}, {0x0010B1, 0x002D11}, {0x0010B2, 0x002D12}, {0x0010B3, 0x002D13}, {0x0010B4, 0x002D14}, + {0x0010B5, 0x002D15}, {0x0010B6, 0x002D16}, {0x0010B7, 0x002D17}, {0x0010B8, 0x002D18}, {0x0010B9, 0x002D19}, + {0x0010BA, 0x002D1A}, {0x0010BB, 0x002D1B}, {0x0010BC, 0x002D1C}, {0x0010BD, 0x002D1D}, {0x0010BE, 0x002D1E}, + {0x0010BF, 0x002D1F}, {0x0010C0, 0x002D20}, {0x0010C1, 0x002D21}, {0x0010C2, 0x002D22}, {0x0010C3, 0x002D23}, + {0x0010C4, 0x002D24}, {0x0010C5, 0x002D25}, {0x0010C7, 0x002D27}, {0x0010CD, 0x002D2D}, {0x0013A0, 0x00AB70}, + {0x0013A1, 0x00AB71}, {0x0013A2, 0x00AB72}, {0x0013A3, 0x00AB73}, {0x0013A4, 0x00AB74}, {0x0013A5, 0x00AB75}, + {0x0013A6, 0x00AB76}, {0x0013A7, 0x00AB77}, {0x0013A8, 0x00AB78}, {0x0013A9, 0x00AB79}, {0x0013AA, 0x00AB7A}, + {0x0013AB, 0x00AB7B}, {0x0013AC, 0x00AB7C}, {0x0013AD, 0x00AB7D}, {0x0013AE, 0x00AB7E}, {0x0013AF, 0x00AB7F}, + {0x0013B0, 0x00AB80}, {0x0013B1, 0x00AB81}, {0x0013B2, 0x00AB82}, {0x0013B3, 0x00AB83}, {0x0013B4, 0x00AB84}, + {0x0013B5, 0x00AB85}, {0x0013B6, 0x00AB86}, {0x0013B7, 0x00AB87}, {0x0013B8, 0x00AB88}, {0x0013B9, 0x00AB89}, + {0x0013BA, 0x00AB8A}, {0x0013BB, 0x00AB8B}, {0x0013BC, 0x00AB8C}, {0x0013BD, 0x00AB8D}, {0x0013BE, 0x00AB8E}, + {0x0013BF, 0x00AB8F}, {0x0013C0, 0x00AB90}, {0x0013C1, 0x00AB91}, {0x0013C2, 0x00AB92}, {0x0013C3, 0x00AB93}, + {0x0013C4, 0x00AB94}, {0x0013C5, 0x00AB95}, {0x0013C6, 0x00AB96}, {0x0013C7, 0x00AB97}, {0x0013C8, 0x00AB98}, + {0x0013C9, 0x00AB99}, {0x0013CA, 0x00AB9A}, {0x0013CB, 0x00AB9B}, {0x0013CC, 0x00AB9C}, {0x0013CD, 0x00AB9D}, + {0x0013CE, 0x00AB9E}, {0x0013CF, 0x00AB9F}, {0x0013D0, 0x00ABA0}, {0x0013D1, 0x00ABA1}, {0x0013D2, 0x00ABA2}, + {0x0013D3, 0x00ABA3}, {0x0013D4, 0x00ABA4}, {0x0013D5, 0x00ABA5}, {0x0013D6, 0x00ABA6}, {0x0013D7, 0x00ABA7}, + {0x0013D8, 0x00ABA8}, {0x0013D9, 0x00ABA9}, {0x0013DA, 0x00ABAA}, {0x0013DB, 0x00ABAB}, {0x0013DC, 0x00ABAC}, + {0x0013DD, 0x00ABAD}, {0x0013DE, 0x00ABAE}, {0x0013DF, 0x00ABAF}, {0x0013E0, 0x00ABB0}, {0x0013E1, 0x00ABB1}, + {0x0013E2, 0x00ABB2}, {0x0013E3, 0x00ABB3}, {0x0013E4, 0x00ABB4}, {0x0013E5, 0x00ABB5}, {0x0013E6, 0x00ABB6}, + {0x0013E7, 0x00ABB7}, {0x0013E8, 0x00ABB8}, {0x0013E9, 0x00ABB9}, {0x0013EA, 0x00ABBA}, {0x0013EB, 0x00ABBB}, + {0x0013EC, 0x00ABBC}, {0x0013ED, 0x00ABBD}, {0x0013EE, 0x00ABBE}, {0x0013EF, 0x00ABBF}, {0x0013F0, 0x0013F8}, + {0x0013F1, 0x0013F9}, {0x0013F2, 0x0013FA}, {0x0013F3, 0x0013FB}, {0x0013F4, 0x0013FC}, {0x0013F5, 0x0013FD}, + {0x001C90, 0x0010D0}, {0x001C91, 0x0010D1}, {0x001C92, 0x0010D2}, {0x001C93, 0x0010D3}, {0x001C94, 0x0010D4}, + {0x001C95, 0x0010D5}, {0x001C96, 0x0010D6}, {0x001C97, 0x0010D7}, {0x001C98, 0x0010D8}, {0x001C99, 0x0010D9}, + {0x001C9A, 0x0010DA}, {0x001C9B, 0x0010DB}, {0x001C9C, 0x0010DC}, {0x001C9D, 0x0010DD}, {0x001C9E, 0x0010DE}, + {0x001C9F, 0x0010DF}, {0x001CA0, 0x0010E0}, {0x001CA1, 0x0010E1}, {0x001CA2, 0x0010E2}, {0x001CA3, 0x0010E3}, + {0x001CA4, 0x0010E4}, {0x001CA5, 0x0010E5}, {0x001CA6, 0x0010E6}, {0x001CA7, 0x0010E7}, {0x001CA8, 0x0010E8}, + {0x001CA9, 0x0010E9}, {0x001CAA, 0x0010EA}, {0x001CAB, 0x0010EB}, {0x001CAC, 0x0010EC}, {0x001CAD, 0x0010ED}, + {0x001CAE, 0x0010EE}, {0x001CAF, 0x0010EF}, {0x001CB0, 0x0010F0}, {0x001CB1, 0x0010F1}, {0x001CB2, 0x0010F2}, + {0x001CB3, 0x0010F3}, {0x001CB4, 0x0010F4}, {0x001CB5, 0x0010F5}, {0x001CB6, 0x0010F6}, {0x001CB7, 0x0010F7}, + {0x001CB8, 0x0010F8}, {0x001CB9, 0x0010F9}, {0x001CBA, 0x0010FA}, {0x001CBD, 0x0010FD}, {0x001CBE, 0x0010FE}, + {0x001CBF, 0x0010FF}, {0x001E00, 0x001E01}, {0x001E02, 0x001E03}, {0x001E04, 0x001E05}, {0x001E06, 0x001E07}, + {0x001E08, 0x001E09}, {0x001E0A, 0x001E0B}, {0x001E0C, 0x001E0D}, {0x001E0E, 0x001E0F}, {0x001E10, 0x001E11}, + {0x001E12, 0x001E13}, {0x001E14, 0x001E15}, {0x001E16, 0x001E17}, {0x001E18, 0x001E19}, {0x001E1A, 0x001E1B}, + {0x001E1C, 0x001E1D}, {0x001E1E, 0x001E1F}, {0x001E20, 0x001E21}, {0x001E22, 0x001E23}, {0x001E24, 0x001E25}, + {0x001E26, 0x001E27}, {0x001E28, 0x001E29}, {0x001E2A, 0x001E2B}, {0x001E2C, 0x001E2D}, {0x001E2E, 0x001E2F}, + {0x001E30, 0x001E31}, {0x001E32, 0x001E33}, {0x001E34, 0x001E35}, {0x001E36, 0x001E37}, {0x001E38, 0x001E39}, + {0x001E3A, 0x001E3B}, {0x001E3C, 0x001E3D}, {0x001E3E, 0x001E3F}, {0x001E40, 0x001E41}, {0x001E42, 0x001E43}, + {0x001E44, 0x001E45}, {0x001E46, 0x001E47}, {0x001E48, 0x001E49}, {0x001E4A, 0x001E4B}, {0x001E4C, 0x001E4D}, + {0x001E4E, 0x001E4F}, {0x001E50, 0x001E51}, {0x001E52, 0x001E53}, {0x001E54, 0x001E55}, {0x001E56, 0x001E57}, + {0x001E58, 0x001E59}, {0x001E5A, 0x001E5B}, {0x001E5C, 0x001E5D}, {0x001E5E, 0x001E5F}, {0x001E60, 0x001E61}, + {0x001E62, 0x001E63}, {0x001E64, 0x001E65}, {0x001E66, 0x001E67}, {0x001E68, 0x001E69}, {0x001E6A, 0x001E6B}, + {0x001E6C, 0x001E6D}, {0x001E6E, 0x001E6F}, {0x001E70, 0x001E71}, {0x001E72, 0x001E73}, {0x001E74, 0x001E75}, + {0x001E76, 0x001E77}, {0x001E78, 0x001E79}, {0x001E7A, 0x001E7B}, {0x001E7C, 0x001E7D}, {0x001E7E, 0x001E7F}, + {0x001E80, 0x001E81}, {0x001E82, 0x001E83}, {0x001E84, 0x001E85}, {0x001E86, 0x001E87}, {0x001E88, 0x001E89}, + {0x001E8A, 0x001E8B}, {0x001E8C, 0x001E8D}, {0x001E8E, 0x001E8F}, {0x001E90, 0x001E91}, {0x001E92, 0x001E93}, + {0x001E94, 0x001E95}, {0x001E9E, 0x0000DF}, {0x001EA0, 0x001EA1}, {0x001EA2, 0x001EA3}, {0x001EA4, 0x001EA5}, + {0x001EA6, 0x001EA7}, {0x001EA8, 0x001EA9}, {0x001EAA, 0x001EAB}, {0x001EAC, 0x001EAD}, {0x001EAE, 0x001EAF}, + {0x001EB0, 0x001EB1}, {0x001EB2, 0x001EB3}, {0x001EB4, 0x001EB5}, {0x001EB6, 0x001EB7}, {0x001EB8, 0x001EB9}, + {0x001EBA, 0x001EBB}, {0x001EBC, 0x001EBD}, {0x001EBE, 0x001EBF}, {0x001EC0, 0x001EC1}, {0x001EC2, 0x001EC3}, + {0x001EC4, 0x001EC5}, {0x001EC6, 0x001EC7}, {0x001EC8, 0x001EC9}, {0x001ECA, 0x001ECB}, {0x001ECC, 0x001ECD}, + {0x001ECE, 0x001ECF}, {0x001ED0, 0x001ED1}, {0x001ED2, 0x001ED3}, {0x001ED4, 0x001ED5}, {0x001ED6, 0x001ED7}, + {0x001ED8, 0x001ED9}, {0x001EDA, 0x001EDB}, {0x001EDC, 0x001EDD}, {0x001EDE, 0x001EDF}, {0x001EE0, 0x001EE1}, + {0x001EE2, 0x001EE3}, {0x001EE4, 0x001EE5}, {0x001EE6, 0x001EE7}, {0x001EE8, 0x001EE9}, {0x001EEA, 0x001EEB}, + {0x001EEC, 0x001EED}, {0x001EEE, 0x001EEF}, {0x001EF0, 0x001EF1}, {0x001EF2, 0x001EF3}, {0x001EF4, 0x001EF5}, + {0x001EF6, 0x001EF7}, {0x001EF8, 0x001EF9}, {0x001EFA, 0x001EFB}, {0x001EFC, 0x001EFD}, {0x001EFE, 0x001EFF}, + {0x001F08, 0x001F00}, {0x001F09, 0x001F01}, {0x001F0A, 0x001F02}, {0x001F0B, 0x001F03}, {0x001F0C, 0x001F04}, + {0x001F0D, 0x001F05}, {0x001F0E, 0x001F06}, {0x001F0F, 0x001F07}, {0x001F18, 0x001F10}, {0x001F19, 0x001F11}, + {0x001F1A, 0x001F12}, {0x001F1B, 0x001F13}, {0x001F1C, 0x001F14}, {0x001F1D, 0x001F15}, {0x001F28, 0x001F20}, + {0x001F29, 0x001F21}, {0x001F2A, 0x001F22}, {0x001F2B, 0x001F23}, {0x001F2C, 0x001F24}, {0x001F2D, 0x001F25}, + {0x001F2E, 0x001F26}, {0x001F2F, 0x001F27}, {0x001F38, 0x001F30}, {0x001F39, 0x001F31}, {0x001F3A, 0x001F32}, + {0x001F3B, 0x001F33}, {0x001F3C, 0x001F34}, {0x001F3D, 0x001F35}, {0x001F3E, 0x001F36}, {0x001F3F, 0x001F37}, + {0x001F48, 0x001F40}, {0x001F49, 0x001F41}, {0x001F4A, 0x001F42}, {0x001F4B, 0x001F43}, {0x001F4C, 0x001F44}, + {0x001F4D, 0x001F45}, {0x001F59, 0x001F51}, {0x001F5B, 0x001F53}, {0x001F5D, 0x001F55}, {0x001F5F, 0x001F57}, + {0x001F68, 0x001F60}, {0x001F69, 0x001F61}, {0x001F6A, 0x001F62}, {0x001F6B, 0x001F63}, {0x001F6C, 0x001F64}, + {0x001F6D, 0x001F65}, {0x001F6E, 0x001F66}, {0x001F6F, 0x001F67}, {0x001F88, 0x001F80}, {0x001F89, 0x001F81}, + {0x001F8A, 0x001F82}, {0x001F8B, 0x001F83}, {0x001F8C, 0x001F84}, {0x001F8D, 0x001F85}, {0x001F8E, 0x001F86}, + {0x001F8F, 0x001F87}, {0x001F98, 0x001F90}, {0x001F99, 0x001F91}, {0x001F9A, 0x001F92}, {0x001F9B, 0x001F93}, + {0x001F9C, 0x001F94}, {0x001F9D, 0x001F95}, {0x001F9E, 0x001F96}, {0x001F9F, 0x001F97}, {0x001FA8, 0x001FA0}, + {0x001FA9, 0x001FA1}, {0x001FAA, 0x001FA2}, {0x001FAB, 0x001FA3}, {0x001FAC, 0x001FA4}, {0x001FAD, 0x001FA5}, + {0x001FAE, 0x001FA6}, {0x001FAF, 0x001FA7}, {0x001FB8, 0x001FB0}, {0x001FB9, 0x001FB1}, {0x001FBA, 0x001F70}, + {0x001FBB, 0x001F71}, {0x001FBC, 0x001FB3}, {0x001FC8, 0x001F72}, {0x001FC9, 0x001F73}, {0x001FCA, 0x001F74}, + {0x001FCB, 0x001F75}, {0x001FCC, 0x001FC3}, {0x001FD8, 0x001FD0}, {0x001FD9, 0x001FD1}, {0x001FDA, 0x001F76}, + {0x001FDB, 0x001F77}, {0x001FE8, 0x001FE0}, {0x001FE9, 0x001FE1}, {0x001FEA, 0x001F7A}, {0x001FEB, 0x001F7B}, + {0x001FEC, 0x001FE5}, {0x001FF8, 0x001F78}, {0x001FF9, 0x001F79}, {0x001FFA, 0x001F7C}, {0x001FFB, 0x001F7D}, + {0x001FFC, 0x001FF3}, {0x002126, 0x0003C9}, {0x00212A, 0x00006B}, {0x00212B, 0x0000E5}, {0x002132, 0x00214E}, + {0x002160, 0x002170}, {0x002161, 0x002171}, {0x002162, 0x002172}, {0x002163, 0x002173}, {0x002164, 0x002174}, + {0x002165, 0x002175}, {0x002166, 0x002176}, {0x002167, 0x002177}, {0x002168, 0x002178}, {0x002169, 0x002179}, + {0x00216A, 0x00217A}, {0x00216B, 0x00217B}, {0x00216C, 0x00217C}, {0x00216D, 0x00217D}, {0x00216E, 0x00217E}, + {0x00216F, 0x00217F}, {0x002183, 0x002184}, {0x0024B6, 0x0024D0}, {0x0024B7, 0x0024D1}, {0x0024B8, 0x0024D2}, + {0x0024B9, 0x0024D3}, {0x0024BA, 0x0024D4}, {0x0024BB, 0x0024D5}, {0x0024BC, 0x0024D6}, {0x0024BD, 0x0024D7}, + {0x0024BE, 0x0024D8}, {0x0024BF, 0x0024D9}, {0x0024C0, 0x0024DA}, {0x0024C1, 0x0024DB}, {0x0024C2, 0x0024DC}, + {0x0024C3, 0x0024DD}, {0x0024C4, 0x0024DE}, {0x0024C5, 0x0024DF}, {0x0024C6, 0x0024E0}, {0x0024C7, 0x0024E1}, + {0x0024C8, 0x0024E2}, {0x0024C9, 0x0024E3}, {0x0024CA, 0x0024E4}, {0x0024CB, 0x0024E5}, {0x0024CC, 0x0024E6}, + {0x0024CD, 0x0024E7}, {0x0024CE, 0x0024E8}, {0x0024CF, 0x0024E9}, {0x002C00, 0x002C30}, {0x002C01, 0x002C31}, + {0x002C02, 0x002C32}, {0x002C03, 0x002C33}, {0x002C04, 0x002C34}, {0x002C05, 0x002C35}, {0x002C06, 0x002C36}, + {0x002C07, 0x002C37}, {0x002C08, 0x002C38}, {0x002C09, 0x002C39}, {0x002C0A, 0x002C3A}, {0x002C0B, 0x002C3B}, + {0x002C0C, 0x002C3C}, {0x002C0D, 0x002C3D}, {0x002C0E, 0x002C3E}, {0x002C0F, 0x002C3F}, {0x002C10, 0x002C40}, + {0x002C11, 0x002C41}, {0x002C12, 0x002C42}, {0x002C13, 0x002C43}, {0x002C14, 0x002C44}, {0x002C15, 0x002C45}, + {0x002C16, 0x002C46}, {0x002C17, 0x002C47}, {0x002C18, 0x002C48}, {0x002C19, 0x002C49}, {0x002C1A, 0x002C4A}, + {0x002C1B, 0x002C4B}, {0x002C1C, 0x002C4C}, {0x002C1D, 0x002C4D}, {0x002C1E, 0x002C4E}, {0x002C1F, 0x002C4F}, + {0x002C20, 0x002C50}, {0x002C21, 0x002C51}, {0x002C22, 0x002C52}, {0x002C23, 0x002C53}, {0x002C24, 0x002C54}, + {0x002C25, 0x002C55}, {0x002C26, 0x002C56}, {0x002C27, 0x002C57}, {0x002C28, 0x002C58}, {0x002C29, 0x002C59}, + {0x002C2A, 0x002C5A}, {0x002C2B, 0x002C5B}, {0x002C2C, 0x002C5C}, {0x002C2D, 0x002C5D}, {0x002C2E, 0x002C5E}, + {0x002C2F, 0x002C5F}, {0x002C60, 0x002C61}, {0x002C62, 0x00026B}, {0x002C63, 0x001D7D}, {0x002C64, 0x00027D}, + {0x002C67, 0x002C68}, {0x002C69, 0x002C6A}, {0x002C6B, 0x002C6C}, {0x002C6D, 0x000251}, {0x002C6E, 0x000271}, + {0x002C6F, 0x000250}, {0x002C70, 0x000252}, {0x002C72, 0x002C73}, {0x002C75, 0x002C76}, {0x002C7E, 0x00023F}, + {0x002C7F, 0x000240}, {0x002C80, 0x002C81}, {0x002C82, 0x002C83}, {0x002C84, 0x002C85}, {0x002C86, 0x002C87}, + {0x002C88, 0x002C89}, {0x002C8A, 0x002C8B}, {0x002C8C, 0x002C8D}, {0x002C8E, 0x002C8F}, {0x002C90, 0x002C91}, + {0x002C92, 0x002C93}, {0x002C94, 0x002C95}, {0x002C96, 0x002C97}, {0x002C98, 0x002C99}, {0x002C9A, 0x002C9B}, + {0x002C9C, 0x002C9D}, {0x002C9E, 0x002C9F}, {0x002CA0, 0x002CA1}, {0x002CA2, 0x002CA3}, {0x002CA4, 0x002CA5}, + {0x002CA6, 0x002CA7}, {0x002CA8, 0x002CA9}, {0x002CAA, 0x002CAB}, {0x002CAC, 0x002CAD}, {0x002CAE, 0x002CAF}, + {0x002CB0, 0x002CB1}, {0x002CB2, 0x002CB3}, {0x002CB4, 0x002CB5}, {0x002CB6, 0x002CB7}, {0x002CB8, 0x002CB9}, + {0x002CBA, 0x002CBB}, {0x002CBC, 0x002CBD}, {0x002CBE, 0x002CBF}, {0x002CC0, 0x002CC1}, {0x002CC2, 0x002CC3}, + {0x002CC4, 0x002CC5}, {0x002CC6, 0x002CC7}, {0x002CC8, 0x002CC9}, {0x002CCA, 0x002CCB}, {0x002CCC, 0x002CCD}, + {0x002CCE, 0x002CCF}, {0x002CD0, 0x002CD1}, {0x002CD2, 0x002CD3}, {0x002CD4, 0x002CD5}, {0x002CD6, 0x002CD7}, + {0x002CD8, 0x002CD9}, {0x002CDA, 0x002CDB}, {0x002CDC, 0x002CDD}, {0x002CDE, 0x002CDF}, {0x002CE0, 0x002CE1}, + {0x002CE2, 0x002CE3}, {0x002CEB, 0x002CEC}, {0x002CED, 0x002CEE}, {0x002CF2, 0x002CF3}, {0x00A640, 0x00A641}, + {0x00A642, 0x00A643}, {0x00A644, 0x00A645}, {0x00A646, 0x00A647}, {0x00A648, 0x00A649}, {0x00A64A, 0x00A64B}, + {0x00A64C, 0x00A64D}, {0x00A64E, 0x00A64F}, {0x00A650, 0x00A651}, {0x00A652, 0x00A653}, {0x00A654, 0x00A655}, + {0x00A656, 0x00A657}, {0x00A658, 0x00A659}, {0x00A65A, 0x00A65B}, {0x00A65C, 0x00A65D}, {0x00A65E, 0x00A65F}, + {0x00A660, 0x00A661}, {0x00A662, 0x00A663}, {0x00A664, 0x00A665}, {0x00A666, 0x00A667}, {0x00A668, 0x00A669}, + {0x00A66A, 0x00A66B}, {0x00A66C, 0x00A66D}, {0x00A680, 0x00A681}, {0x00A682, 0x00A683}, {0x00A684, 0x00A685}, + {0x00A686, 0x00A687}, {0x00A688, 0x00A689}, {0x00A68A, 0x00A68B}, {0x00A68C, 0x00A68D}, {0x00A68E, 0x00A68F}, + {0x00A690, 0x00A691}, {0x00A692, 0x00A693}, {0x00A694, 0x00A695}, {0x00A696, 0x00A697}, {0x00A698, 0x00A699}, + {0x00A69A, 0x00A69B}, {0x00A722, 0x00A723}, {0x00A724, 0x00A725}, {0x00A726, 0x00A727}, {0x00A728, 0x00A729}, + {0x00A72A, 0x00A72B}, {0x00A72C, 0x00A72D}, {0x00A72E, 0x00A72F}, {0x00A732, 0x00A733}, {0x00A734, 0x00A735}, + {0x00A736, 0x00A737}, {0x00A738, 0x00A739}, {0x00A73A, 0x00A73B}, {0x00A73C, 0x00A73D}, {0x00A73E, 0x00A73F}, + {0x00A740, 0x00A741}, {0x00A742, 0x00A743}, {0x00A744, 0x00A745}, {0x00A746, 0x00A747}, {0x00A748, 0x00A749}, + {0x00A74A, 0x00A74B}, {0x00A74C, 0x00A74D}, {0x00A74E, 0x00A74F}, {0x00A750, 0x00A751}, {0x00A752, 0x00A753}, + {0x00A754, 0x00A755}, {0x00A756, 0x00A757}, {0x00A758, 0x00A759}, {0x00A75A, 0x00A75B}, {0x00A75C, 0x00A75D}, + {0x00A75E, 0x00A75F}, {0x00A760, 0x00A761}, {0x00A762, 0x00A763}, {0x00A764, 0x00A765}, {0x00A766, 0x00A767}, + {0x00A768, 0x00A769}, {0x00A76A, 0x00A76B}, {0x00A76C, 0x00A76D}, {0x00A76E, 0x00A76F}, {0x00A779, 0x00A77A}, + {0x00A77B, 0x00A77C}, {0x00A77D, 0x001D79}, {0x00A77E, 0x00A77F}, {0x00A780, 0x00A781}, {0x00A782, 0x00A783}, + {0x00A784, 0x00A785}, {0x00A786, 0x00A787}, {0x00A78B, 0x00A78C}, {0x00A78D, 0x000265}, {0x00A790, 0x00A791}, + {0x00A792, 0x00A793}, {0x00A796, 0x00A797}, {0x00A798, 0x00A799}, {0x00A79A, 0x00A79B}, {0x00A79C, 0x00A79D}, + {0x00A79E, 0x00A79F}, {0x00A7A0, 0x00A7A1}, {0x00A7A2, 0x00A7A3}, {0x00A7A4, 0x00A7A5}, {0x00A7A6, 0x00A7A7}, + {0x00A7A8, 0x00A7A9}, {0x00A7AA, 0x000266}, {0x00A7AB, 0x00025C}, {0x00A7AC, 0x000261}, {0x00A7AD, 0x00026C}, + {0x00A7AE, 0x00026A}, {0x00A7B0, 0x00029E}, {0x00A7B1, 0x000287}, {0x00A7B2, 0x00029D}, {0x00A7B3, 0x00AB53}, + {0x00A7B4, 0x00A7B5}, {0x00A7B6, 0x00A7B7}, {0x00A7B8, 0x00A7B9}, {0x00A7BA, 0x00A7BB}, {0x00A7BC, 0x00A7BD}, + {0x00A7BE, 0x00A7BF}, {0x00A7C0, 0x00A7C1}, {0x00A7C2, 0x00A7C3}, {0x00A7C4, 0x00A794}, {0x00A7C5, 0x000282}, + {0x00A7C6, 0x001D8E}, {0x00A7C7, 0x00A7C8}, {0x00A7C9, 0x00A7CA}, {0x00A7D0, 0x00A7D1}, {0x00A7D6, 0x00A7D7}, + {0x00A7D8, 0x00A7D9}, {0x00A7F5, 0x00A7F6}, {0x00FF21, 0x00FF41}, {0x00FF22, 0x00FF42}, {0x00FF23, 0x00FF43}, + {0x00FF24, 0x00FF44}, {0x00FF25, 0x00FF45}, {0x00FF26, 0x00FF46}, {0x00FF27, 0x00FF47}, {0x00FF28, 0x00FF48}, + {0x00FF29, 0x00FF49}, {0x00FF2A, 0x00FF4A}, {0x00FF2B, 0x00FF4B}, {0x00FF2C, 0x00FF4C}, {0x00FF2D, 0x00FF4D}, + {0x00FF2E, 0x00FF4E}, {0x00FF2F, 0x00FF4F}, {0x00FF30, 0x00FF50}, {0x00FF31, 0x00FF51}, {0x00FF32, 0x00FF52}, + {0x00FF33, 0x00FF53}, {0x00FF34, 0x00FF54}, {0x00FF35, 0x00FF55}, {0x00FF36, 0x00FF56}, {0x00FF37, 0x00FF57}, + {0x00FF38, 0x00FF58}, {0x00FF39, 0x00FF59}, {0x00FF3A, 0x00FF5A}, {0x010400, 0x010428}, {0x010401, 0x010429}, + {0x010402, 0x01042A}, {0x010403, 0x01042B}, {0x010404, 0x01042C}, {0x010405, 0x01042D}, {0x010406, 0x01042E}, + {0x010407, 0x01042F}, {0x010408, 0x010430}, {0x010409, 0x010431}, {0x01040A, 0x010432}, {0x01040B, 0x010433}, + {0x01040C, 0x010434}, {0x01040D, 0x010435}, {0x01040E, 0x010436}, {0x01040F, 0x010437}, {0x010410, 0x010438}, + {0x010411, 0x010439}, {0x010412, 0x01043A}, {0x010413, 0x01043B}, {0x010414, 0x01043C}, {0x010415, 0x01043D}, + {0x010416, 0x01043E}, {0x010417, 0x01043F}, {0x010418, 0x010440}, {0x010419, 0x010441}, {0x01041A, 0x010442}, + {0x01041B, 0x010443}, {0x01041C, 0x010444}, {0x01041D, 0x010445}, {0x01041E, 0x010446}, {0x01041F, 0x010447}, + {0x010420, 0x010448}, {0x010421, 0x010449}, {0x010422, 0x01044A}, {0x010423, 0x01044B}, {0x010424, 0x01044C}, + {0x010425, 0x01044D}, {0x010426, 0x01044E}, {0x010427, 0x01044F}, {0x0104B0, 0x0104D8}, {0x0104B1, 0x0104D9}, + {0x0104B2, 0x0104DA}, {0x0104B3, 0x0104DB}, {0x0104B4, 0x0104DC}, {0x0104B5, 0x0104DD}, {0x0104B6, 0x0104DE}, + {0x0104B7, 0x0104DF}, {0x0104B8, 0x0104E0}, {0x0104B9, 0x0104E1}, {0x0104BA, 0x0104E2}, {0x0104BB, 0x0104E3}, + {0x0104BC, 0x0104E4}, {0x0104BD, 0x0104E5}, {0x0104BE, 0x0104E6}, {0x0104BF, 0x0104E7}, {0x0104C0, 0x0104E8}, + {0x0104C1, 0x0104E9}, {0x0104C2, 0x0104EA}, {0x0104C3, 0x0104EB}, {0x0104C4, 0x0104EC}, {0x0104C5, 0x0104ED}, + {0x0104C6, 0x0104EE}, {0x0104C7, 0x0104EF}, {0x0104C8, 0x0104F0}, {0x0104C9, 0x0104F1}, {0x0104CA, 0x0104F2}, + {0x0104CB, 0x0104F3}, {0x0104CC, 0x0104F4}, {0x0104CD, 0x0104F5}, {0x0104CE, 0x0104F6}, {0x0104CF, 0x0104F7}, + {0x0104D0, 0x0104F8}, {0x0104D1, 0x0104F9}, {0x0104D2, 0x0104FA}, {0x0104D3, 0x0104FB}, {0x010570, 0x010597}, + {0x010571, 0x010598}, {0x010572, 0x010599}, {0x010573, 0x01059A}, {0x010574, 0x01059B}, {0x010575, 0x01059C}, + {0x010576, 0x01059D}, {0x010577, 0x01059E}, {0x010578, 0x01059F}, {0x010579, 0x0105A0}, {0x01057A, 0x0105A1}, + {0x01057C, 0x0105A3}, {0x01057D, 0x0105A4}, {0x01057E, 0x0105A5}, {0x01057F, 0x0105A6}, {0x010580, 0x0105A7}, + {0x010581, 0x0105A8}, {0x010582, 0x0105A9}, {0x010583, 0x0105AA}, {0x010584, 0x0105AB}, {0x010585, 0x0105AC}, + {0x010586, 0x0105AD}, {0x010587, 0x0105AE}, {0x010588, 0x0105AF}, {0x010589, 0x0105B0}, {0x01058A, 0x0105B1}, + {0x01058C, 0x0105B3}, {0x01058D, 0x0105B4}, {0x01058E, 0x0105B5}, {0x01058F, 0x0105B6}, {0x010590, 0x0105B7}, + {0x010591, 0x0105B8}, {0x010592, 0x0105B9}, {0x010594, 0x0105BB}, {0x010595, 0x0105BC}, {0x010C80, 0x010CC0}, + {0x010C81, 0x010CC1}, {0x010C82, 0x010CC2}, {0x010C83, 0x010CC3}, {0x010C84, 0x010CC4}, {0x010C85, 0x010CC5}, + {0x010C86, 0x010CC6}, {0x010C87, 0x010CC7}, {0x010C88, 0x010CC8}, {0x010C89, 0x010CC9}, {0x010C8A, 0x010CCA}, + {0x010C8B, 0x010CCB}, {0x010C8C, 0x010CCC}, {0x010C8D, 0x010CCD}, {0x010C8E, 0x010CCE}, {0x010C8F, 0x010CCF}, + {0x010C90, 0x010CD0}, {0x010C91, 0x010CD1}, {0x010C92, 0x010CD2}, {0x010C93, 0x010CD3}, {0x010C94, 0x010CD4}, + {0x010C95, 0x010CD5}, {0x010C96, 0x010CD6}, {0x010C97, 0x010CD7}, {0x010C98, 0x010CD8}, {0x010C99, 0x010CD9}, + {0x010C9A, 0x010CDA}, {0x010C9B, 0x010CDB}, {0x010C9C, 0x010CDC}, {0x010C9D, 0x010CDD}, {0x010C9E, 0x010CDE}, + {0x010C9F, 0x010CDF}, {0x010CA0, 0x010CE0}, {0x010CA1, 0x010CE1}, {0x010CA2, 0x010CE2}, {0x010CA3, 0x010CE3}, + {0x010CA4, 0x010CE4}, {0x010CA5, 0x010CE5}, {0x010CA6, 0x010CE6}, {0x010CA7, 0x010CE7}, {0x010CA8, 0x010CE8}, + {0x010CA9, 0x010CE9}, {0x010CAA, 0x010CEA}, {0x010CAB, 0x010CEB}, {0x010CAC, 0x010CEC}, {0x010CAD, 0x010CED}, + {0x010CAE, 0x010CEE}, {0x010CAF, 0x010CEF}, {0x010CB0, 0x010CF0}, {0x010CB1, 0x010CF1}, {0x010CB2, 0x010CF2}, + {0x0118A0, 0x0118C0}, {0x0118A1, 0x0118C1}, {0x0118A2, 0x0118C2}, {0x0118A3, 0x0118C3}, {0x0118A4, 0x0118C4}, + {0x0118A5, 0x0118C5}, {0x0118A6, 0x0118C6}, {0x0118A7, 0x0118C7}, {0x0118A8, 0x0118C8}, {0x0118A9, 0x0118C9}, + {0x0118AA, 0x0118CA}, {0x0118AB, 0x0118CB}, {0x0118AC, 0x0118CC}, {0x0118AD, 0x0118CD}, {0x0118AE, 0x0118CE}, + {0x0118AF, 0x0118CF}, {0x0118B0, 0x0118D0}, {0x0118B1, 0x0118D1}, {0x0118B2, 0x0118D2}, {0x0118B3, 0x0118D3}, + {0x0118B4, 0x0118D4}, {0x0118B5, 0x0118D5}, {0x0118B6, 0x0118D6}, {0x0118B7, 0x0118D7}, {0x0118B8, 0x0118D8}, + {0x0118B9, 0x0118D9}, {0x0118BA, 0x0118DA}, {0x0118BB, 0x0118DB}, {0x0118BC, 0x0118DC}, {0x0118BD, 0x0118DD}, + {0x0118BE, 0x0118DE}, {0x0118BF, 0x0118DF}, {0x016E40, 0x016E60}, {0x016E41, 0x016E61}, {0x016E42, 0x016E62}, + {0x016E43, 0x016E63}, {0x016E44, 0x016E64}, {0x016E45, 0x016E65}, {0x016E46, 0x016E66}, {0x016E47, 0x016E67}, + {0x016E48, 0x016E68}, {0x016E49, 0x016E69}, {0x016E4A, 0x016E6A}, {0x016E4B, 0x016E6B}, {0x016E4C, 0x016E6C}, + {0x016E4D, 0x016E6D}, {0x016E4E, 0x016E6E}, {0x016E4F, 0x016E6F}, {0x016E50, 0x016E70}, {0x016E51, 0x016E71}, + {0x016E52, 0x016E72}, {0x016E53, 0x016E73}, {0x016E54, 0x016E74}, {0x016E55, 0x016E75}, {0x016E56, 0x016E76}, + {0x016E57, 0x016E77}, {0x016E58, 0x016E78}, {0x016E59, 0x016E79}, {0x016E5A, 0x016E7A}, {0x016E5B, 0x016E7B}, + {0x016E5C, 0x016E7C}, {0x016E5D, 0x016E7D}, {0x016E5E, 0x016E7E}, {0x016E5F, 0x016E7F}, {0x01E900, 0x01E922}, + {0x01E901, 0x01E923}, {0x01E902, 0x01E924}, {0x01E903, 0x01E925}, {0x01E904, 0x01E926}, {0x01E905, 0x01E927}, + {0x01E906, 0x01E928}, {0x01E907, 0x01E929}, {0x01E908, 0x01E92A}, {0x01E909, 0x01E92B}, {0x01E90A, 0x01E92C}, + {0x01E90B, 0x01E92D}, {0x01E90C, 0x01E92E}, {0x01E90D, 0x01E92F}, {0x01E90E, 0x01E930}, {0x01E90F, 0x01E931}, + {0x01E910, 0x01E932}, {0x01E911, 0x01E933}, {0x01E912, 0x01E934}, {0x01E913, 0x01E935}, {0x01E914, 0x01E936}, + {0x01E915, 0x01E937}, {0x01E916, 0x01E938}, {0x01E917, 0x01E939}, {0x01E918, 0x01E93A}, {0x01E919, 0x01E93B}, + {0x01E91A, 0x01E93C}, {0x01E91B, 0x01E93D}, {0x01E91C, 0x01E93E}, {0x01E91D, 0x01E93F}, {0x01E91E, 0x01E940}, + {0x01E91F, 0x01E941}, {0x01E920, 0x01E942}, {0x01E921, 0x01E943}, +}; + +// list is always in ascending order, to enable binary search +const std::initializer_list> unicode_map_uppercase = { + {0x000061, 0x000041}, {0x000062, 0x000042}, {0x000063, 0x000043}, {0x000064, 0x000044}, {0x000065, 0x000045}, + {0x000066, 0x000046}, {0x000067, 0x000047}, {0x000068, 0x000048}, {0x000069, 0x000049}, {0x00006A, 0x00004A}, + {0x00006B, 0x00004B}, {0x00006C, 0x00004C}, {0x00006D, 0x00004D}, {0x00006E, 0x00004E}, {0x00006F, 0x00004F}, + {0x000070, 0x000050}, {0x000071, 0x000051}, {0x000072, 0x000052}, {0x000073, 0x000053}, {0x000074, 0x000054}, + {0x000075, 0x000055}, {0x000076, 0x000056}, {0x000077, 0x000057}, {0x000078, 0x000058}, {0x000079, 0x000059}, + {0x00007A, 0x00005A}, {0x0000B5, 0x00039C}, {0x0000E0, 0x0000C0}, {0x0000E1, 0x0000C1}, {0x0000E2, 0x0000C2}, + {0x0000E3, 0x0000C3}, {0x0000E4, 0x0000C4}, {0x0000E5, 0x0000C5}, {0x0000E6, 0x0000C6}, {0x0000E7, 0x0000C7}, + {0x0000E8, 0x0000C8}, {0x0000E9, 0x0000C9}, {0x0000EA, 0x0000CA}, {0x0000EB, 0x0000CB}, {0x0000EC, 0x0000CC}, + {0x0000ED, 0x0000CD}, {0x0000EE, 0x0000CE}, {0x0000EF, 0x0000CF}, {0x0000F0, 0x0000D0}, {0x0000F1, 0x0000D1}, + {0x0000F2, 0x0000D2}, {0x0000F3, 0x0000D3}, {0x0000F4, 0x0000D4}, {0x0000F5, 0x0000D5}, {0x0000F6, 0x0000D6}, + {0x0000F8, 0x0000D8}, {0x0000F9, 0x0000D9}, {0x0000FA, 0x0000DA}, {0x0000FB, 0x0000DB}, {0x0000FC, 0x0000DC}, + {0x0000FD, 0x0000DD}, {0x0000FE, 0x0000DE}, {0x0000FF, 0x000178}, {0x000101, 0x000100}, {0x000103, 0x000102}, + {0x000105, 0x000104}, {0x000107, 0x000106}, {0x000109, 0x000108}, {0x00010B, 0x00010A}, {0x00010D, 0x00010C}, + {0x00010F, 0x00010E}, {0x000111, 0x000110}, {0x000113, 0x000112}, {0x000115, 0x000114}, {0x000117, 0x000116}, + {0x000119, 0x000118}, {0x00011B, 0x00011A}, {0x00011D, 0x00011C}, {0x00011F, 0x00011E}, {0x000121, 0x000120}, + {0x000123, 0x000122}, {0x000125, 0x000124}, {0x000127, 0x000126}, {0x000129, 0x000128}, {0x00012B, 0x00012A}, + {0x00012D, 0x00012C}, {0x00012F, 0x00012E}, {0x000131, 0x000049}, {0x000133, 0x000132}, {0x000135, 0x000134}, + {0x000137, 0x000136}, {0x00013A, 0x000139}, {0x00013C, 0x00013B}, {0x00013E, 0x00013D}, {0x000140, 0x00013F}, + {0x000142, 0x000141}, {0x000144, 0x000143}, {0x000146, 0x000145}, {0x000148, 0x000147}, {0x00014B, 0x00014A}, + {0x00014D, 0x00014C}, {0x00014F, 0x00014E}, {0x000151, 0x000150}, {0x000153, 0x000152}, {0x000155, 0x000154}, + {0x000157, 0x000156}, {0x000159, 0x000158}, {0x00015B, 0x00015A}, {0x00015D, 0x00015C}, {0x00015F, 0x00015E}, + {0x000161, 0x000160}, {0x000163, 0x000162}, {0x000165, 0x000164}, {0x000167, 0x000166}, {0x000169, 0x000168}, + {0x00016B, 0x00016A}, {0x00016D, 0x00016C}, {0x00016F, 0x00016E}, {0x000171, 0x000170}, {0x000173, 0x000172}, + {0x000175, 0x000174}, {0x000177, 0x000176}, {0x00017A, 0x000179}, {0x00017C, 0x00017B}, {0x00017E, 0x00017D}, + {0x00017F, 0x000053}, {0x000180, 0x000243}, {0x000183, 0x000182}, {0x000185, 0x000184}, {0x000188, 0x000187}, + {0x00018C, 0x00018B}, {0x000192, 0x000191}, {0x000195, 0x0001F6}, {0x000199, 0x000198}, {0x00019A, 0x00023D}, + {0x00019E, 0x000220}, {0x0001A1, 0x0001A0}, {0x0001A3, 0x0001A2}, {0x0001A5, 0x0001A4}, {0x0001A8, 0x0001A7}, + {0x0001AD, 0x0001AC}, {0x0001B0, 0x0001AF}, {0x0001B4, 0x0001B3}, {0x0001B6, 0x0001B5}, {0x0001B9, 0x0001B8}, + {0x0001BD, 0x0001BC}, {0x0001BF, 0x0001F7}, {0x0001C5, 0x0001C4}, {0x0001C6, 0x0001C4}, {0x0001C8, 0x0001C7}, + {0x0001C9, 0x0001C7}, {0x0001CB, 0x0001CA}, {0x0001CC, 0x0001CA}, {0x0001CE, 0x0001CD}, {0x0001D0, 0x0001CF}, + {0x0001D2, 0x0001D1}, {0x0001D4, 0x0001D3}, {0x0001D6, 0x0001D5}, {0x0001D8, 0x0001D7}, {0x0001DA, 0x0001D9}, + {0x0001DC, 0x0001DB}, {0x0001DD, 0x00018E}, {0x0001DF, 0x0001DE}, {0x0001E1, 0x0001E0}, {0x0001E3, 0x0001E2}, + {0x0001E5, 0x0001E4}, {0x0001E7, 0x0001E6}, {0x0001E9, 0x0001E8}, {0x0001EB, 0x0001EA}, {0x0001ED, 0x0001EC}, + {0x0001EF, 0x0001EE}, {0x0001F2, 0x0001F1}, {0x0001F3, 0x0001F1}, {0x0001F5, 0x0001F4}, {0x0001F9, 0x0001F8}, + {0x0001FB, 0x0001FA}, {0x0001FD, 0x0001FC}, {0x0001FF, 0x0001FE}, {0x000201, 0x000200}, {0x000203, 0x000202}, + {0x000205, 0x000204}, {0x000207, 0x000206}, {0x000209, 0x000208}, {0x00020B, 0x00020A}, {0x00020D, 0x00020C}, + {0x00020F, 0x00020E}, {0x000211, 0x000210}, {0x000213, 0x000212}, {0x000215, 0x000214}, {0x000217, 0x000216}, + {0x000219, 0x000218}, {0x00021B, 0x00021A}, {0x00021D, 0x00021C}, {0x00021F, 0x00021E}, {0x000223, 0x000222}, + {0x000225, 0x000224}, {0x000227, 0x000226}, {0x000229, 0x000228}, {0x00022B, 0x00022A}, {0x00022D, 0x00022C}, + {0x00022F, 0x00022E}, {0x000231, 0x000230}, {0x000233, 0x000232}, {0x00023C, 0x00023B}, {0x00023F, 0x002C7E}, + {0x000240, 0x002C7F}, {0x000242, 0x000241}, {0x000247, 0x000246}, {0x000249, 0x000248}, {0x00024B, 0x00024A}, + {0x00024D, 0x00024C}, {0x00024F, 0x00024E}, {0x000250, 0x002C6F}, {0x000251, 0x002C6D}, {0x000252, 0x002C70}, + {0x000253, 0x000181}, {0x000254, 0x000186}, {0x000256, 0x000189}, {0x000257, 0x00018A}, {0x000259, 0x00018F}, + {0x00025B, 0x000190}, {0x00025C, 0x00A7AB}, {0x000260, 0x000193}, {0x000261, 0x00A7AC}, {0x000263, 0x000194}, + {0x000265, 0x00A78D}, {0x000266, 0x00A7AA}, {0x000268, 0x000197}, {0x000269, 0x000196}, {0x00026A, 0x00A7AE}, + {0x00026B, 0x002C62}, {0x00026C, 0x00A7AD}, {0x00026F, 0x00019C}, {0x000271, 0x002C6E}, {0x000272, 0x00019D}, + {0x000275, 0x00019F}, {0x00027D, 0x002C64}, {0x000280, 0x0001A6}, {0x000282, 0x00A7C5}, {0x000283, 0x0001A9}, + {0x000287, 0x00A7B1}, {0x000288, 0x0001AE}, {0x000289, 0x000244}, {0x00028A, 0x0001B1}, {0x00028B, 0x0001B2}, + {0x00028C, 0x000245}, {0x000292, 0x0001B7}, {0x00029D, 0x00A7B2}, {0x00029E, 0x00A7B0}, {0x000345, 0x000399}, + {0x000371, 0x000370}, {0x000373, 0x000372}, {0x000377, 0x000376}, {0x00037B, 0x0003FD}, {0x00037C, 0x0003FE}, + {0x00037D, 0x0003FF}, {0x0003AC, 0x000386}, {0x0003AD, 0x000388}, {0x0003AE, 0x000389}, {0x0003AF, 0x00038A}, + {0x0003B1, 0x000391}, {0x0003B2, 0x000392}, {0x0003B3, 0x000393}, {0x0003B4, 0x000394}, {0x0003B5, 0x000395}, + {0x0003B6, 0x000396}, {0x0003B7, 0x000397}, {0x0003B8, 0x000398}, {0x0003B9, 0x000399}, {0x0003BA, 0x00039A}, + {0x0003BB, 0x00039B}, {0x0003BC, 0x00039C}, {0x0003BD, 0x00039D}, {0x0003BE, 0x00039E}, {0x0003BF, 0x00039F}, + {0x0003C0, 0x0003A0}, {0x0003C1, 0x0003A1}, {0x0003C2, 0x0003A3}, {0x0003C3, 0x0003A3}, {0x0003C4, 0x0003A4}, + {0x0003C5, 0x0003A5}, {0x0003C6, 0x0003A6}, {0x0003C7, 0x0003A7}, {0x0003C8, 0x0003A8}, {0x0003C9, 0x0003A9}, + {0x0003CA, 0x0003AA}, {0x0003CB, 0x0003AB}, {0x0003CC, 0x00038C}, {0x0003CD, 0x00038E}, {0x0003CE, 0x00038F}, + {0x0003D0, 0x000392}, {0x0003D1, 0x000398}, {0x0003D5, 0x0003A6}, {0x0003D6, 0x0003A0}, {0x0003D7, 0x0003CF}, + {0x0003D9, 0x0003D8}, {0x0003DB, 0x0003DA}, {0x0003DD, 0x0003DC}, {0x0003DF, 0x0003DE}, {0x0003E1, 0x0003E0}, + {0x0003E3, 0x0003E2}, {0x0003E5, 0x0003E4}, {0x0003E7, 0x0003E6}, {0x0003E9, 0x0003E8}, {0x0003EB, 0x0003EA}, + {0x0003ED, 0x0003EC}, {0x0003EF, 0x0003EE}, {0x0003F0, 0x00039A}, {0x0003F1, 0x0003A1}, {0x0003F2, 0x0003F9}, + {0x0003F3, 0x00037F}, {0x0003F5, 0x000395}, {0x0003F8, 0x0003F7}, {0x0003FB, 0x0003FA}, {0x000430, 0x000410}, + {0x000431, 0x000411}, {0x000432, 0x000412}, {0x000433, 0x000413}, {0x000434, 0x000414}, {0x000435, 0x000415}, + {0x000436, 0x000416}, {0x000437, 0x000417}, {0x000438, 0x000418}, {0x000439, 0x000419}, {0x00043A, 0x00041A}, + {0x00043B, 0x00041B}, {0x00043C, 0x00041C}, {0x00043D, 0x00041D}, {0x00043E, 0x00041E}, {0x00043F, 0x00041F}, + {0x000440, 0x000420}, {0x000441, 0x000421}, {0x000442, 0x000422}, {0x000443, 0x000423}, {0x000444, 0x000424}, + {0x000445, 0x000425}, {0x000446, 0x000426}, {0x000447, 0x000427}, {0x000448, 0x000428}, {0x000449, 0x000429}, + {0x00044A, 0x00042A}, {0x00044B, 0x00042B}, {0x00044C, 0x00042C}, {0x00044D, 0x00042D}, {0x00044E, 0x00042E}, + {0x00044F, 0x00042F}, {0x000450, 0x000400}, {0x000451, 0x000401}, {0x000452, 0x000402}, {0x000453, 0x000403}, + {0x000454, 0x000404}, {0x000455, 0x000405}, {0x000456, 0x000406}, {0x000457, 0x000407}, {0x000458, 0x000408}, + {0x000459, 0x000409}, {0x00045A, 0x00040A}, {0x00045B, 0x00040B}, {0x00045C, 0x00040C}, {0x00045D, 0x00040D}, + {0x00045E, 0x00040E}, {0x00045F, 0x00040F}, {0x000461, 0x000460}, {0x000463, 0x000462}, {0x000465, 0x000464}, + {0x000467, 0x000466}, {0x000469, 0x000468}, {0x00046B, 0x00046A}, {0x00046D, 0x00046C}, {0x00046F, 0x00046E}, + {0x000471, 0x000470}, {0x000473, 0x000472}, {0x000475, 0x000474}, {0x000477, 0x000476}, {0x000479, 0x000478}, + {0x00047B, 0x00047A}, {0x00047D, 0x00047C}, {0x00047F, 0x00047E}, {0x000481, 0x000480}, {0x00048B, 0x00048A}, + {0x00048D, 0x00048C}, {0x00048F, 0x00048E}, {0x000491, 0x000490}, {0x000493, 0x000492}, {0x000495, 0x000494}, + {0x000497, 0x000496}, {0x000499, 0x000498}, {0x00049B, 0x00049A}, {0x00049D, 0x00049C}, {0x00049F, 0x00049E}, + {0x0004A1, 0x0004A0}, {0x0004A3, 0x0004A2}, {0x0004A5, 0x0004A4}, {0x0004A7, 0x0004A6}, {0x0004A9, 0x0004A8}, + {0x0004AB, 0x0004AA}, {0x0004AD, 0x0004AC}, {0x0004AF, 0x0004AE}, {0x0004B1, 0x0004B0}, {0x0004B3, 0x0004B2}, + {0x0004B5, 0x0004B4}, {0x0004B7, 0x0004B6}, {0x0004B9, 0x0004B8}, {0x0004BB, 0x0004BA}, {0x0004BD, 0x0004BC}, + {0x0004BF, 0x0004BE}, {0x0004C2, 0x0004C1}, {0x0004C4, 0x0004C3}, {0x0004C6, 0x0004C5}, {0x0004C8, 0x0004C7}, + {0x0004CA, 0x0004C9}, {0x0004CC, 0x0004CB}, {0x0004CE, 0x0004CD}, {0x0004CF, 0x0004C0}, {0x0004D1, 0x0004D0}, + {0x0004D3, 0x0004D2}, {0x0004D5, 0x0004D4}, {0x0004D7, 0x0004D6}, {0x0004D9, 0x0004D8}, {0x0004DB, 0x0004DA}, + {0x0004DD, 0x0004DC}, {0x0004DF, 0x0004DE}, {0x0004E1, 0x0004E0}, {0x0004E3, 0x0004E2}, {0x0004E5, 0x0004E4}, + {0x0004E7, 0x0004E6}, {0x0004E9, 0x0004E8}, {0x0004EB, 0x0004EA}, {0x0004ED, 0x0004EC}, {0x0004EF, 0x0004EE}, + {0x0004F1, 0x0004F0}, {0x0004F3, 0x0004F2}, {0x0004F5, 0x0004F4}, {0x0004F7, 0x0004F6}, {0x0004F9, 0x0004F8}, + {0x0004FB, 0x0004FA}, {0x0004FD, 0x0004FC}, {0x0004FF, 0x0004FE}, {0x000501, 0x000500}, {0x000503, 0x000502}, + {0x000505, 0x000504}, {0x000507, 0x000506}, {0x000509, 0x000508}, {0x00050B, 0x00050A}, {0x00050D, 0x00050C}, + {0x00050F, 0x00050E}, {0x000511, 0x000510}, {0x000513, 0x000512}, {0x000515, 0x000514}, {0x000517, 0x000516}, + {0x000519, 0x000518}, {0x00051B, 0x00051A}, {0x00051D, 0x00051C}, {0x00051F, 0x00051E}, {0x000521, 0x000520}, + {0x000523, 0x000522}, {0x000525, 0x000524}, {0x000527, 0x000526}, {0x000529, 0x000528}, {0x00052B, 0x00052A}, + {0x00052D, 0x00052C}, {0x00052F, 0x00052E}, {0x000561, 0x000531}, {0x000562, 0x000532}, {0x000563, 0x000533}, + {0x000564, 0x000534}, {0x000565, 0x000535}, {0x000566, 0x000536}, {0x000567, 0x000537}, {0x000568, 0x000538}, + {0x000569, 0x000539}, {0x00056A, 0x00053A}, {0x00056B, 0x00053B}, {0x00056C, 0x00053C}, {0x00056D, 0x00053D}, + {0x00056E, 0x00053E}, {0x00056F, 0x00053F}, {0x000570, 0x000540}, {0x000571, 0x000541}, {0x000572, 0x000542}, + {0x000573, 0x000543}, {0x000574, 0x000544}, {0x000575, 0x000545}, {0x000576, 0x000546}, {0x000577, 0x000547}, + {0x000578, 0x000548}, {0x000579, 0x000549}, {0x00057A, 0x00054A}, {0x00057B, 0x00054B}, {0x00057C, 0x00054C}, + {0x00057D, 0x00054D}, {0x00057E, 0x00054E}, {0x00057F, 0x00054F}, {0x000580, 0x000550}, {0x000581, 0x000551}, + {0x000582, 0x000552}, {0x000583, 0x000553}, {0x000584, 0x000554}, {0x000585, 0x000555}, {0x000586, 0x000556}, + {0x0010D0, 0x001C90}, {0x0010D1, 0x001C91}, {0x0010D2, 0x001C92}, {0x0010D3, 0x001C93}, {0x0010D4, 0x001C94}, + {0x0010D5, 0x001C95}, {0x0010D6, 0x001C96}, {0x0010D7, 0x001C97}, {0x0010D8, 0x001C98}, {0x0010D9, 0x001C99}, + {0x0010DA, 0x001C9A}, {0x0010DB, 0x001C9B}, {0x0010DC, 0x001C9C}, {0x0010DD, 0x001C9D}, {0x0010DE, 0x001C9E}, + {0x0010DF, 0x001C9F}, {0x0010E0, 0x001CA0}, {0x0010E1, 0x001CA1}, {0x0010E2, 0x001CA2}, {0x0010E3, 0x001CA3}, + {0x0010E4, 0x001CA4}, {0x0010E5, 0x001CA5}, {0x0010E6, 0x001CA6}, {0x0010E7, 0x001CA7}, {0x0010E8, 0x001CA8}, + {0x0010E9, 0x001CA9}, {0x0010EA, 0x001CAA}, {0x0010EB, 0x001CAB}, {0x0010EC, 0x001CAC}, {0x0010ED, 0x001CAD}, + {0x0010EE, 0x001CAE}, {0x0010EF, 0x001CAF}, {0x0010F0, 0x001CB0}, {0x0010F1, 0x001CB1}, {0x0010F2, 0x001CB2}, + {0x0010F3, 0x001CB3}, {0x0010F4, 0x001CB4}, {0x0010F5, 0x001CB5}, {0x0010F6, 0x001CB6}, {0x0010F7, 0x001CB7}, + {0x0010F8, 0x001CB8}, {0x0010F9, 0x001CB9}, {0x0010FA, 0x001CBA}, {0x0010FD, 0x001CBD}, {0x0010FE, 0x001CBE}, + {0x0010FF, 0x001CBF}, {0x0013F8, 0x0013F0}, {0x0013F9, 0x0013F1}, {0x0013FA, 0x0013F2}, {0x0013FB, 0x0013F3}, + {0x0013FC, 0x0013F4}, {0x0013FD, 0x0013F5}, {0x001C80, 0x000412}, {0x001C81, 0x000414}, {0x001C82, 0x00041E}, + {0x001C83, 0x000421}, {0x001C84, 0x000422}, {0x001C85, 0x000422}, {0x001C86, 0x00042A}, {0x001C87, 0x000462}, + {0x001C88, 0x00A64A}, {0x001D79, 0x00A77D}, {0x001D7D, 0x002C63}, {0x001D8E, 0x00A7C6}, {0x001E01, 0x001E00}, + {0x001E03, 0x001E02}, {0x001E05, 0x001E04}, {0x001E07, 0x001E06}, {0x001E09, 0x001E08}, {0x001E0B, 0x001E0A}, + {0x001E0D, 0x001E0C}, {0x001E0F, 0x001E0E}, {0x001E11, 0x001E10}, {0x001E13, 0x001E12}, {0x001E15, 0x001E14}, + {0x001E17, 0x001E16}, {0x001E19, 0x001E18}, {0x001E1B, 0x001E1A}, {0x001E1D, 0x001E1C}, {0x001E1F, 0x001E1E}, + {0x001E21, 0x001E20}, {0x001E23, 0x001E22}, {0x001E25, 0x001E24}, {0x001E27, 0x001E26}, {0x001E29, 0x001E28}, + {0x001E2B, 0x001E2A}, {0x001E2D, 0x001E2C}, {0x001E2F, 0x001E2E}, {0x001E31, 0x001E30}, {0x001E33, 0x001E32}, + {0x001E35, 0x001E34}, {0x001E37, 0x001E36}, {0x001E39, 0x001E38}, {0x001E3B, 0x001E3A}, {0x001E3D, 0x001E3C}, + {0x001E3F, 0x001E3E}, {0x001E41, 0x001E40}, {0x001E43, 0x001E42}, {0x001E45, 0x001E44}, {0x001E47, 0x001E46}, + {0x001E49, 0x001E48}, {0x001E4B, 0x001E4A}, {0x001E4D, 0x001E4C}, {0x001E4F, 0x001E4E}, {0x001E51, 0x001E50}, + {0x001E53, 0x001E52}, {0x001E55, 0x001E54}, {0x001E57, 0x001E56}, {0x001E59, 0x001E58}, {0x001E5B, 0x001E5A}, + {0x001E5D, 0x001E5C}, {0x001E5F, 0x001E5E}, {0x001E61, 0x001E60}, {0x001E63, 0x001E62}, {0x001E65, 0x001E64}, + {0x001E67, 0x001E66}, {0x001E69, 0x001E68}, {0x001E6B, 0x001E6A}, {0x001E6D, 0x001E6C}, {0x001E6F, 0x001E6E}, + {0x001E71, 0x001E70}, {0x001E73, 0x001E72}, {0x001E75, 0x001E74}, {0x001E77, 0x001E76}, {0x001E79, 0x001E78}, + {0x001E7B, 0x001E7A}, {0x001E7D, 0x001E7C}, {0x001E7F, 0x001E7E}, {0x001E81, 0x001E80}, {0x001E83, 0x001E82}, + {0x001E85, 0x001E84}, {0x001E87, 0x001E86}, {0x001E89, 0x001E88}, {0x001E8B, 0x001E8A}, {0x001E8D, 0x001E8C}, + {0x001E8F, 0x001E8E}, {0x001E91, 0x001E90}, {0x001E93, 0x001E92}, {0x001E95, 0x001E94}, {0x001E9B, 0x001E60}, + {0x001EA1, 0x001EA0}, {0x001EA3, 0x001EA2}, {0x001EA5, 0x001EA4}, {0x001EA7, 0x001EA6}, {0x001EA9, 0x001EA8}, + {0x001EAB, 0x001EAA}, {0x001EAD, 0x001EAC}, {0x001EAF, 0x001EAE}, {0x001EB1, 0x001EB0}, {0x001EB3, 0x001EB2}, + {0x001EB5, 0x001EB4}, {0x001EB7, 0x001EB6}, {0x001EB9, 0x001EB8}, {0x001EBB, 0x001EBA}, {0x001EBD, 0x001EBC}, + {0x001EBF, 0x001EBE}, {0x001EC1, 0x001EC0}, {0x001EC3, 0x001EC2}, {0x001EC5, 0x001EC4}, {0x001EC7, 0x001EC6}, + {0x001EC9, 0x001EC8}, {0x001ECB, 0x001ECA}, {0x001ECD, 0x001ECC}, {0x001ECF, 0x001ECE}, {0x001ED1, 0x001ED0}, + {0x001ED3, 0x001ED2}, {0x001ED5, 0x001ED4}, {0x001ED7, 0x001ED6}, {0x001ED9, 0x001ED8}, {0x001EDB, 0x001EDA}, + {0x001EDD, 0x001EDC}, {0x001EDF, 0x001EDE}, {0x001EE1, 0x001EE0}, {0x001EE3, 0x001EE2}, {0x001EE5, 0x001EE4}, + {0x001EE7, 0x001EE6}, {0x001EE9, 0x001EE8}, {0x001EEB, 0x001EEA}, {0x001EED, 0x001EEC}, {0x001EEF, 0x001EEE}, + {0x001EF1, 0x001EF0}, {0x001EF3, 0x001EF2}, {0x001EF5, 0x001EF4}, {0x001EF7, 0x001EF6}, {0x001EF9, 0x001EF8}, + {0x001EFB, 0x001EFA}, {0x001EFD, 0x001EFC}, {0x001EFF, 0x001EFE}, {0x001F00, 0x001F08}, {0x001F01, 0x001F09}, + {0x001F02, 0x001F0A}, {0x001F03, 0x001F0B}, {0x001F04, 0x001F0C}, {0x001F05, 0x001F0D}, {0x001F06, 0x001F0E}, + {0x001F07, 0x001F0F}, {0x001F10, 0x001F18}, {0x001F11, 0x001F19}, {0x001F12, 0x001F1A}, {0x001F13, 0x001F1B}, + {0x001F14, 0x001F1C}, {0x001F15, 0x001F1D}, {0x001F20, 0x001F28}, {0x001F21, 0x001F29}, {0x001F22, 0x001F2A}, + {0x001F23, 0x001F2B}, {0x001F24, 0x001F2C}, {0x001F25, 0x001F2D}, {0x001F26, 0x001F2E}, {0x001F27, 0x001F2F}, + {0x001F30, 0x001F38}, {0x001F31, 0x001F39}, {0x001F32, 0x001F3A}, {0x001F33, 0x001F3B}, {0x001F34, 0x001F3C}, + {0x001F35, 0x001F3D}, {0x001F36, 0x001F3E}, {0x001F37, 0x001F3F}, {0x001F40, 0x001F48}, {0x001F41, 0x001F49}, + {0x001F42, 0x001F4A}, {0x001F43, 0x001F4B}, {0x001F44, 0x001F4C}, {0x001F45, 0x001F4D}, {0x001F51, 0x001F59}, + {0x001F53, 0x001F5B}, {0x001F55, 0x001F5D}, {0x001F57, 0x001F5F}, {0x001F60, 0x001F68}, {0x001F61, 0x001F69}, + {0x001F62, 0x001F6A}, {0x001F63, 0x001F6B}, {0x001F64, 0x001F6C}, {0x001F65, 0x001F6D}, {0x001F66, 0x001F6E}, + {0x001F67, 0x001F6F}, {0x001F70, 0x001FBA}, {0x001F71, 0x001FBB}, {0x001F72, 0x001FC8}, {0x001F73, 0x001FC9}, + {0x001F74, 0x001FCA}, {0x001F75, 0x001FCB}, {0x001F76, 0x001FDA}, {0x001F77, 0x001FDB}, {0x001F78, 0x001FF8}, + {0x001F79, 0x001FF9}, {0x001F7A, 0x001FEA}, {0x001F7B, 0x001FEB}, {0x001F7C, 0x001FFA}, {0x001F7D, 0x001FFB}, + {0x001F80, 0x001F88}, {0x001F81, 0x001F89}, {0x001F82, 0x001F8A}, {0x001F83, 0x001F8B}, {0x001F84, 0x001F8C}, + {0x001F85, 0x001F8D}, {0x001F86, 0x001F8E}, {0x001F87, 0x001F8F}, {0x001F90, 0x001F98}, {0x001F91, 0x001F99}, + {0x001F92, 0x001F9A}, {0x001F93, 0x001F9B}, {0x001F94, 0x001F9C}, {0x001F95, 0x001F9D}, {0x001F96, 0x001F9E}, + {0x001F97, 0x001F9F}, {0x001FA0, 0x001FA8}, {0x001FA1, 0x001FA9}, {0x001FA2, 0x001FAA}, {0x001FA3, 0x001FAB}, + {0x001FA4, 0x001FAC}, {0x001FA5, 0x001FAD}, {0x001FA6, 0x001FAE}, {0x001FA7, 0x001FAF}, {0x001FB0, 0x001FB8}, + {0x001FB1, 0x001FB9}, {0x001FB3, 0x001FBC}, {0x001FBE, 0x000399}, {0x001FC3, 0x001FCC}, {0x001FD0, 0x001FD8}, + {0x001FD1, 0x001FD9}, {0x001FE0, 0x001FE8}, {0x001FE1, 0x001FE9}, {0x001FE5, 0x001FEC}, {0x001FF3, 0x001FFC}, + {0x00214E, 0x002132}, {0x002170, 0x002160}, {0x002171, 0x002161}, {0x002172, 0x002162}, {0x002173, 0x002163}, + {0x002174, 0x002164}, {0x002175, 0x002165}, {0x002176, 0x002166}, {0x002177, 0x002167}, {0x002178, 0x002168}, + {0x002179, 0x002169}, {0x00217A, 0x00216A}, {0x00217B, 0x00216B}, {0x00217C, 0x00216C}, {0x00217D, 0x00216D}, + {0x00217E, 0x00216E}, {0x00217F, 0x00216F}, {0x002184, 0x002183}, {0x0024D0, 0x0024B6}, {0x0024D1, 0x0024B7}, + {0x0024D2, 0x0024B8}, {0x0024D3, 0x0024B9}, {0x0024D4, 0x0024BA}, {0x0024D5, 0x0024BB}, {0x0024D6, 0x0024BC}, + {0x0024D7, 0x0024BD}, {0x0024D8, 0x0024BE}, {0x0024D9, 0x0024BF}, {0x0024DA, 0x0024C0}, {0x0024DB, 0x0024C1}, + {0x0024DC, 0x0024C2}, {0x0024DD, 0x0024C3}, {0x0024DE, 0x0024C4}, {0x0024DF, 0x0024C5}, {0x0024E0, 0x0024C6}, + {0x0024E1, 0x0024C7}, {0x0024E2, 0x0024C8}, {0x0024E3, 0x0024C9}, {0x0024E4, 0x0024CA}, {0x0024E5, 0x0024CB}, + {0x0024E6, 0x0024CC}, {0x0024E7, 0x0024CD}, {0x0024E8, 0x0024CE}, {0x0024E9, 0x0024CF}, {0x002C30, 0x002C00}, + {0x002C31, 0x002C01}, {0x002C32, 0x002C02}, {0x002C33, 0x002C03}, {0x002C34, 0x002C04}, {0x002C35, 0x002C05}, + {0x002C36, 0x002C06}, {0x002C37, 0x002C07}, {0x002C38, 0x002C08}, {0x002C39, 0x002C09}, {0x002C3A, 0x002C0A}, + {0x002C3B, 0x002C0B}, {0x002C3C, 0x002C0C}, {0x002C3D, 0x002C0D}, {0x002C3E, 0x002C0E}, {0x002C3F, 0x002C0F}, + {0x002C40, 0x002C10}, {0x002C41, 0x002C11}, {0x002C42, 0x002C12}, {0x002C43, 0x002C13}, {0x002C44, 0x002C14}, + {0x002C45, 0x002C15}, {0x002C46, 0x002C16}, {0x002C47, 0x002C17}, {0x002C48, 0x002C18}, {0x002C49, 0x002C19}, + {0x002C4A, 0x002C1A}, {0x002C4B, 0x002C1B}, {0x002C4C, 0x002C1C}, {0x002C4D, 0x002C1D}, {0x002C4E, 0x002C1E}, + {0x002C4F, 0x002C1F}, {0x002C50, 0x002C20}, {0x002C51, 0x002C21}, {0x002C52, 0x002C22}, {0x002C53, 0x002C23}, + {0x002C54, 0x002C24}, {0x002C55, 0x002C25}, {0x002C56, 0x002C26}, {0x002C57, 0x002C27}, {0x002C58, 0x002C28}, + {0x002C59, 0x002C29}, {0x002C5A, 0x002C2A}, {0x002C5B, 0x002C2B}, {0x002C5C, 0x002C2C}, {0x002C5D, 0x002C2D}, + {0x002C5E, 0x002C2E}, {0x002C5F, 0x002C2F}, {0x002C61, 0x002C60}, {0x002C65, 0x00023A}, {0x002C66, 0x00023E}, + {0x002C68, 0x002C67}, {0x002C6A, 0x002C69}, {0x002C6C, 0x002C6B}, {0x002C73, 0x002C72}, {0x002C76, 0x002C75}, + {0x002C81, 0x002C80}, {0x002C83, 0x002C82}, {0x002C85, 0x002C84}, {0x002C87, 0x002C86}, {0x002C89, 0x002C88}, + {0x002C8B, 0x002C8A}, {0x002C8D, 0x002C8C}, {0x002C8F, 0x002C8E}, {0x002C91, 0x002C90}, {0x002C93, 0x002C92}, + {0x002C95, 0x002C94}, {0x002C97, 0x002C96}, {0x002C99, 0x002C98}, {0x002C9B, 0x002C9A}, {0x002C9D, 0x002C9C}, + {0x002C9F, 0x002C9E}, {0x002CA1, 0x002CA0}, {0x002CA3, 0x002CA2}, {0x002CA5, 0x002CA4}, {0x002CA7, 0x002CA6}, + {0x002CA9, 0x002CA8}, {0x002CAB, 0x002CAA}, {0x002CAD, 0x002CAC}, {0x002CAF, 0x002CAE}, {0x002CB1, 0x002CB0}, + {0x002CB3, 0x002CB2}, {0x002CB5, 0x002CB4}, {0x002CB7, 0x002CB6}, {0x002CB9, 0x002CB8}, {0x002CBB, 0x002CBA}, + {0x002CBD, 0x002CBC}, {0x002CBF, 0x002CBE}, {0x002CC1, 0x002CC0}, {0x002CC3, 0x002CC2}, {0x002CC5, 0x002CC4}, + {0x002CC7, 0x002CC6}, {0x002CC9, 0x002CC8}, {0x002CCB, 0x002CCA}, {0x002CCD, 0x002CCC}, {0x002CCF, 0x002CCE}, + {0x002CD1, 0x002CD0}, {0x002CD3, 0x002CD2}, {0x002CD5, 0x002CD4}, {0x002CD7, 0x002CD6}, {0x002CD9, 0x002CD8}, + {0x002CDB, 0x002CDA}, {0x002CDD, 0x002CDC}, {0x002CDF, 0x002CDE}, {0x002CE1, 0x002CE0}, {0x002CE3, 0x002CE2}, + {0x002CEC, 0x002CEB}, {0x002CEE, 0x002CED}, {0x002CF3, 0x002CF2}, {0x002D00, 0x0010A0}, {0x002D01, 0x0010A1}, + {0x002D02, 0x0010A2}, {0x002D03, 0x0010A3}, {0x002D04, 0x0010A4}, {0x002D05, 0x0010A5}, {0x002D06, 0x0010A6}, + {0x002D07, 0x0010A7}, {0x002D08, 0x0010A8}, {0x002D09, 0x0010A9}, {0x002D0A, 0x0010AA}, {0x002D0B, 0x0010AB}, + {0x002D0C, 0x0010AC}, {0x002D0D, 0x0010AD}, {0x002D0E, 0x0010AE}, {0x002D0F, 0x0010AF}, {0x002D10, 0x0010B0}, + {0x002D11, 0x0010B1}, {0x002D12, 0x0010B2}, {0x002D13, 0x0010B3}, {0x002D14, 0x0010B4}, {0x002D15, 0x0010B5}, + {0x002D16, 0x0010B6}, {0x002D17, 0x0010B7}, {0x002D18, 0x0010B8}, {0x002D19, 0x0010B9}, {0x002D1A, 0x0010BA}, + {0x002D1B, 0x0010BB}, {0x002D1C, 0x0010BC}, {0x002D1D, 0x0010BD}, {0x002D1E, 0x0010BE}, {0x002D1F, 0x0010BF}, + {0x002D20, 0x0010C0}, {0x002D21, 0x0010C1}, {0x002D22, 0x0010C2}, {0x002D23, 0x0010C3}, {0x002D24, 0x0010C4}, + {0x002D25, 0x0010C5}, {0x002D27, 0x0010C7}, {0x002D2D, 0x0010CD}, {0x00A641, 0x00A640}, {0x00A643, 0x00A642}, + {0x00A645, 0x00A644}, {0x00A647, 0x00A646}, {0x00A649, 0x00A648}, {0x00A64B, 0x00A64A}, {0x00A64D, 0x00A64C}, + {0x00A64F, 0x00A64E}, {0x00A651, 0x00A650}, {0x00A653, 0x00A652}, {0x00A655, 0x00A654}, {0x00A657, 0x00A656}, + {0x00A659, 0x00A658}, {0x00A65B, 0x00A65A}, {0x00A65D, 0x00A65C}, {0x00A65F, 0x00A65E}, {0x00A661, 0x00A660}, + {0x00A663, 0x00A662}, {0x00A665, 0x00A664}, {0x00A667, 0x00A666}, {0x00A669, 0x00A668}, {0x00A66B, 0x00A66A}, + {0x00A66D, 0x00A66C}, {0x00A681, 0x00A680}, {0x00A683, 0x00A682}, {0x00A685, 0x00A684}, {0x00A687, 0x00A686}, + {0x00A689, 0x00A688}, {0x00A68B, 0x00A68A}, {0x00A68D, 0x00A68C}, {0x00A68F, 0x00A68E}, {0x00A691, 0x00A690}, + {0x00A693, 0x00A692}, {0x00A695, 0x00A694}, {0x00A697, 0x00A696}, {0x00A699, 0x00A698}, {0x00A69B, 0x00A69A}, + {0x00A723, 0x00A722}, {0x00A725, 0x00A724}, {0x00A727, 0x00A726}, {0x00A729, 0x00A728}, {0x00A72B, 0x00A72A}, + {0x00A72D, 0x00A72C}, {0x00A72F, 0x00A72E}, {0x00A733, 0x00A732}, {0x00A735, 0x00A734}, {0x00A737, 0x00A736}, + {0x00A739, 0x00A738}, {0x00A73B, 0x00A73A}, {0x00A73D, 0x00A73C}, {0x00A73F, 0x00A73E}, {0x00A741, 0x00A740}, + {0x00A743, 0x00A742}, {0x00A745, 0x00A744}, {0x00A747, 0x00A746}, {0x00A749, 0x00A748}, {0x00A74B, 0x00A74A}, + {0x00A74D, 0x00A74C}, {0x00A74F, 0x00A74E}, {0x00A751, 0x00A750}, {0x00A753, 0x00A752}, {0x00A755, 0x00A754}, + {0x00A757, 0x00A756}, {0x00A759, 0x00A758}, {0x00A75B, 0x00A75A}, {0x00A75D, 0x00A75C}, {0x00A75F, 0x00A75E}, + {0x00A761, 0x00A760}, {0x00A763, 0x00A762}, {0x00A765, 0x00A764}, {0x00A767, 0x00A766}, {0x00A769, 0x00A768}, + {0x00A76B, 0x00A76A}, {0x00A76D, 0x00A76C}, {0x00A76F, 0x00A76E}, {0x00A77A, 0x00A779}, {0x00A77C, 0x00A77B}, + {0x00A77F, 0x00A77E}, {0x00A781, 0x00A780}, {0x00A783, 0x00A782}, {0x00A785, 0x00A784}, {0x00A787, 0x00A786}, + {0x00A78C, 0x00A78B}, {0x00A791, 0x00A790}, {0x00A793, 0x00A792}, {0x00A794, 0x00A7C4}, {0x00A797, 0x00A796}, + {0x00A799, 0x00A798}, {0x00A79B, 0x00A79A}, {0x00A79D, 0x00A79C}, {0x00A79F, 0x00A79E}, {0x00A7A1, 0x00A7A0}, + {0x00A7A3, 0x00A7A2}, {0x00A7A5, 0x00A7A4}, {0x00A7A7, 0x00A7A6}, {0x00A7A9, 0x00A7A8}, {0x00A7B5, 0x00A7B4}, + {0x00A7B7, 0x00A7B6}, {0x00A7B9, 0x00A7B8}, {0x00A7BB, 0x00A7BA}, {0x00A7BD, 0x00A7BC}, {0x00A7BF, 0x00A7BE}, + {0x00A7C1, 0x00A7C0}, {0x00A7C3, 0x00A7C2}, {0x00A7C8, 0x00A7C7}, {0x00A7CA, 0x00A7C9}, {0x00A7D1, 0x00A7D0}, + {0x00A7D7, 0x00A7D6}, {0x00A7D9, 0x00A7D8}, {0x00A7F6, 0x00A7F5}, {0x00AB53, 0x00A7B3}, {0x00AB70, 0x0013A0}, + {0x00AB71, 0x0013A1}, {0x00AB72, 0x0013A2}, {0x00AB73, 0x0013A3}, {0x00AB74, 0x0013A4}, {0x00AB75, 0x0013A5}, + {0x00AB76, 0x0013A6}, {0x00AB77, 0x0013A7}, {0x00AB78, 0x0013A8}, {0x00AB79, 0x0013A9}, {0x00AB7A, 0x0013AA}, + {0x00AB7B, 0x0013AB}, {0x00AB7C, 0x0013AC}, {0x00AB7D, 0x0013AD}, {0x00AB7E, 0x0013AE}, {0x00AB7F, 0x0013AF}, + {0x00AB80, 0x0013B0}, {0x00AB81, 0x0013B1}, {0x00AB82, 0x0013B2}, {0x00AB83, 0x0013B3}, {0x00AB84, 0x0013B4}, + {0x00AB85, 0x0013B5}, {0x00AB86, 0x0013B6}, {0x00AB87, 0x0013B7}, {0x00AB88, 0x0013B8}, {0x00AB89, 0x0013B9}, + {0x00AB8A, 0x0013BA}, {0x00AB8B, 0x0013BB}, {0x00AB8C, 0x0013BC}, {0x00AB8D, 0x0013BD}, {0x00AB8E, 0x0013BE}, + {0x00AB8F, 0x0013BF}, {0x00AB90, 0x0013C0}, {0x00AB91, 0x0013C1}, {0x00AB92, 0x0013C2}, {0x00AB93, 0x0013C3}, + {0x00AB94, 0x0013C4}, {0x00AB95, 0x0013C5}, {0x00AB96, 0x0013C6}, {0x00AB97, 0x0013C7}, {0x00AB98, 0x0013C8}, + {0x00AB99, 0x0013C9}, {0x00AB9A, 0x0013CA}, {0x00AB9B, 0x0013CB}, {0x00AB9C, 0x0013CC}, {0x00AB9D, 0x0013CD}, + {0x00AB9E, 0x0013CE}, {0x00AB9F, 0x0013CF}, {0x00ABA0, 0x0013D0}, {0x00ABA1, 0x0013D1}, {0x00ABA2, 0x0013D2}, + {0x00ABA3, 0x0013D3}, {0x00ABA4, 0x0013D4}, {0x00ABA5, 0x0013D5}, {0x00ABA6, 0x0013D6}, {0x00ABA7, 0x0013D7}, + {0x00ABA8, 0x0013D8}, {0x00ABA9, 0x0013D9}, {0x00ABAA, 0x0013DA}, {0x00ABAB, 0x0013DB}, {0x00ABAC, 0x0013DC}, + {0x00ABAD, 0x0013DD}, {0x00ABAE, 0x0013DE}, {0x00ABAF, 0x0013DF}, {0x00ABB0, 0x0013E0}, {0x00ABB1, 0x0013E1}, + {0x00ABB2, 0x0013E2}, {0x00ABB3, 0x0013E3}, {0x00ABB4, 0x0013E4}, {0x00ABB5, 0x0013E5}, {0x00ABB6, 0x0013E6}, + {0x00ABB7, 0x0013E7}, {0x00ABB8, 0x0013E8}, {0x00ABB9, 0x0013E9}, {0x00ABBA, 0x0013EA}, {0x00ABBB, 0x0013EB}, + {0x00ABBC, 0x0013EC}, {0x00ABBD, 0x0013ED}, {0x00ABBE, 0x0013EE}, {0x00ABBF, 0x0013EF}, {0x00FF41, 0x00FF21}, + {0x00FF42, 0x00FF22}, {0x00FF43, 0x00FF23}, {0x00FF44, 0x00FF24}, {0x00FF45, 0x00FF25}, {0x00FF46, 0x00FF26}, + {0x00FF47, 0x00FF27}, {0x00FF48, 0x00FF28}, {0x00FF49, 0x00FF29}, {0x00FF4A, 0x00FF2A}, {0x00FF4B, 0x00FF2B}, + {0x00FF4C, 0x00FF2C}, {0x00FF4D, 0x00FF2D}, {0x00FF4E, 0x00FF2E}, {0x00FF4F, 0x00FF2F}, {0x00FF50, 0x00FF30}, + {0x00FF51, 0x00FF31}, {0x00FF52, 0x00FF32}, {0x00FF53, 0x00FF33}, {0x00FF54, 0x00FF34}, {0x00FF55, 0x00FF35}, + {0x00FF56, 0x00FF36}, {0x00FF57, 0x00FF37}, {0x00FF58, 0x00FF38}, {0x00FF59, 0x00FF39}, {0x00FF5A, 0x00FF3A}, + {0x010428, 0x010400}, {0x010429, 0x010401}, {0x01042A, 0x010402}, {0x01042B, 0x010403}, {0x01042C, 0x010404}, + {0x01042D, 0x010405}, {0x01042E, 0x010406}, {0x01042F, 0x010407}, {0x010430, 0x010408}, {0x010431, 0x010409}, + {0x010432, 0x01040A}, {0x010433, 0x01040B}, {0x010434, 0x01040C}, {0x010435, 0x01040D}, {0x010436, 0x01040E}, + {0x010437, 0x01040F}, {0x010438, 0x010410}, {0x010439, 0x010411}, {0x01043A, 0x010412}, {0x01043B, 0x010413}, + {0x01043C, 0x010414}, {0x01043D, 0x010415}, {0x01043E, 0x010416}, {0x01043F, 0x010417}, {0x010440, 0x010418}, + {0x010441, 0x010419}, {0x010442, 0x01041A}, {0x010443, 0x01041B}, {0x010444, 0x01041C}, {0x010445, 0x01041D}, + {0x010446, 0x01041E}, {0x010447, 0x01041F}, {0x010448, 0x010420}, {0x010449, 0x010421}, {0x01044A, 0x010422}, + {0x01044B, 0x010423}, {0x01044C, 0x010424}, {0x01044D, 0x010425}, {0x01044E, 0x010426}, {0x01044F, 0x010427}, + {0x0104D8, 0x0104B0}, {0x0104D9, 0x0104B1}, {0x0104DA, 0x0104B2}, {0x0104DB, 0x0104B3}, {0x0104DC, 0x0104B4}, + {0x0104DD, 0x0104B5}, {0x0104DE, 0x0104B6}, {0x0104DF, 0x0104B7}, {0x0104E0, 0x0104B8}, {0x0104E1, 0x0104B9}, + {0x0104E2, 0x0104BA}, {0x0104E3, 0x0104BB}, {0x0104E4, 0x0104BC}, {0x0104E5, 0x0104BD}, {0x0104E6, 0x0104BE}, + {0x0104E7, 0x0104BF}, {0x0104E8, 0x0104C0}, {0x0104E9, 0x0104C1}, {0x0104EA, 0x0104C2}, {0x0104EB, 0x0104C3}, + {0x0104EC, 0x0104C4}, {0x0104ED, 0x0104C5}, {0x0104EE, 0x0104C6}, {0x0104EF, 0x0104C7}, {0x0104F0, 0x0104C8}, + {0x0104F1, 0x0104C9}, {0x0104F2, 0x0104CA}, {0x0104F3, 0x0104CB}, {0x0104F4, 0x0104CC}, {0x0104F5, 0x0104CD}, + {0x0104F6, 0x0104CE}, {0x0104F7, 0x0104CF}, {0x0104F8, 0x0104D0}, {0x0104F9, 0x0104D1}, {0x0104FA, 0x0104D2}, + {0x0104FB, 0x0104D3}, {0x010597, 0x010570}, {0x010598, 0x010571}, {0x010599, 0x010572}, {0x01059A, 0x010573}, + {0x01059B, 0x010574}, {0x01059C, 0x010575}, {0x01059D, 0x010576}, {0x01059E, 0x010577}, {0x01059F, 0x010578}, + {0x0105A0, 0x010579}, {0x0105A1, 0x01057A}, {0x0105A3, 0x01057C}, {0x0105A4, 0x01057D}, {0x0105A5, 0x01057E}, + {0x0105A6, 0x01057F}, {0x0105A7, 0x010580}, {0x0105A8, 0x010581}, {0x0105A9, 0x010582}, {0x0105AA, 0x010583}, + {0x0105AB, 0x010584}, {0x0105AC, 0x010585}, {0x0105AD, 0x010586}, {0x0105AE, 0x010587}, {0x0105AF, 0x010588}, + {0x0105B0, 0x010589}, {0x0105B1, 0x01058A}, {0x0105B3, 0x01058C}, {0x0105B4, 0x01058D}, {0x0105B5, 0x01058E}, + {0x0105B6, 0x01058F}, {0x0105B7, 0x010590}, {0x0105B8, 0x010591}, {0x0105B9, 0x010592}, {0x0105BB, 0x010594}, + {0x0105BC, 0x010595}, {0x010CC0, 0x010C80}, {0x010CC1, 0x010C81}, {0x010CC2, 0x010C82}, {0x010CC3, 0x010C83}, + {0x010CC4, 0x010C84}, {0x010CC5, 0x010C85}, {0x010CC6, 0x010C86}, {0x010CC7, 0x010C87}, {0x010CC8, 0x010C88}, + {0x010CC9, 0x010C89}, {0x010CCA, 0x010C8A}, {0x010CCB, 0x010C8B}, {0x010CCC, 0x010C8C}, {0x010CCD, 0x010C8D}, + {0x010CCE, 0x010C8E}, {0x010CCF, 0x010C8F}, {0x010CD0, 0x010C90}, {0x010CD1, 0x010C91}, {0x010CD2, 0x010C92}, + {0x010CD3, 0x010C93}, {0x010CD4, 0x010C94}, {0x010CD5, 0x010C95}, {0x010CD6, 0x010C96}, {0x010CD7, 0x010C97}, + {0x010CD8, 0x010C98}, {0x010CD9, 0x010C99}, {0x010CDA, 0x010C9A}, {0x010CDB, 0x010C9B}, {0x010CDC, 0x010C9C}, + {0x010CDD, 0x010C9D}, {0x010CDE, 0x010C9E}, {0x010CDF, 0x010C9F}, {0x010CE0, 0x010CA0}, {0x010CE1, 0x010CA1}, + {0x010CE2, 0x010CA2}, {0x010CE3, 0x010CA3}, {0x010CE4, 0x010CA4}, {0x010CE5, 0x010CA5}, {0x010CE6, 0x010CA6}, + {0x010CE7, 0x010CA7}, {0x010CE8, 0x010CA8}, {0x010CE9, 0x010CA9}, {0x010CEA, 0x010CAA}, {0x010CEB, 0x010CAB}, + {0x010CEC, 0x010CAC}, {0x010CED, 0x010CAD}, {0x010CEE, 0x010CAE}, {0x010CEF, 0x010CAF}, {0x010CF0, 0x010CB0}, + {0x010CF1, 0x010CB1}, {0x010CF2, 0x010CB2}, {0x0118C0, 0x0118A0}, {0x0118C1, 0x0118A1}, {0x0118C2, 0x0118A2}, + {0x0118C3, 0x0118A3}, {0x0118C4, 0x0118A4}, {0x0118C5, 0x0118A5}, {0x0118C6, 0x0118A6}, {0x0118C7, 0x0118A7}, + {0x0118C8, 0x0118A8}, {0x0118C9, 0x0118A9}, {0x0118CA, 0x0118AA}, {0x0118CB, 0x0118AB}, {0x0118CC, 0x0118AC}, + {0x0118CD, 0x0118AD}, {0x0118CE, 0x0118AE}, {0x0118CF, 0x0118AF}, {0x0118D0, 0x0118B0}, {0x0118D1, 0x0118B1}, + {0x0118D2, 0x0118B2}, {0x0118D3, 0x0118B3}, {0x0118D4, 0x0118B4}, {0x0118D5, 0x0118B5}, {0x0118D6, 0x0118B6}, + {0x0118D7, 0x0118B7}, {0x0118D8, 0x0118B8}, {0x0118D9, 0x0118B9}, {0x0118DA, 0x0118BA}, {0x0118DB, 0x0118BB}, + {0x0118DC, 0x0118BC}, {0x0118DD, 0x0118BD}, {0x0118DE, 0x0118BE}, {0x0118DF, 0x0118BF}, {0x016E60, 0x016E40}, + {0x016E61, 0x016E41}, {0x016E62, 0x016E42}, {0x016E63, 0x016E43}, {0x016E64, 0x016E44}, {0x016E65, 0x016E45}, + {0x016E66, 0x016E46}, {0x016E67, 0x016E47}, {0x016E68, 0x016E48}, {0x016E69, 0x016E49}, {0x016E6A, 0x016E4A}, + {0x016E6B, 0x016E4B}, {0x016E6C, 0x016E4C}, {0x016E6D, 0x016E4D}, {0x016E6E, 0x016E4E}, {0x016E6F, 0x016E4F}, + {0x016E70, 0x016E50}, {0x016E71, 0x016E51}, {0x016E72, 0x016E52}, {0x016E73, 0x016E53}, {0x016E74, 0x016E54}, + {0x016E75, 0x016E55}, {0x016E76, 0x016E56}, {0x016E77, 0x016E57}, {0x016E78, 0x016E58}, {0x016E79, 0x016E59}, + {0x016E7A, 0x016E5A}, {0x016E7B, 0x016E5B}, {0x016E7C, 0x016E5C}, {0x016E7D, 0x016E5D}, {0x016E7E, 0x016E5E}, + {0x016E7F, 0x016E5F}, {0x01E922, 0x01E900}, {0x01E923, 0x01E901}, {0x01E924, 0x01E902}, {0x01E925, 0x01E903}, + {0x01E926, 0x01E904}, {0x01E927, 0x01E905}, {0x01E928, 0x01E906}, {0x01E929, 0x01E907}, {0x01E92A, 0x01E908}, + {0x01E92B, 0x01E909}, {0x01E92C, 0x01E90A}, {0x01E92D, 0x01E90B}, {0x01E92E, 0x01E90C}, {0x01E92F, 0x01E90D}, + {0x01E930, 0x01E90E}, {0x01E931, 0x01E90F}, {0x01E932, 0x01E910}, {0x01E933, 0x01E911}, {0x01E934, 0x01E912}, + {0x01E935, 0x01E913}, {0x01E936, 0x01E914}, {0x01E937, 0x01E915}, {0x01E938, 0x01E916}, {0x01E939, 0x01E917}, + {0x01E93A, 0x01E918}, {0x01E93B, 0x01E919}, {0x01E93C, 0x01E91A}, {0x01E93D, 0x01E91B}, {0x01E93E, 0x01E91C}, + {0x01E93F, 0x01E91D}, {0x01E940, 0x01E91E}, {0x01E941, 0x01E91F}, {0x01E942, 0x01E920}, {0x01E943, 0x01E921}, +}; + +const std::initializer_list unicode_ranges_nfd = { + // start, last, nfd + {0x000000, 0x000000, 0x000000}, {0x0000C0, 0x0000C5, 0x000041}, {0x0000C7, 0x0000C7, 0x000043}, + {0x0000C8, 0x0000CB, 0x000045}, {0x0000CC, 0x0000CF, 0x000049}, {0x0000D1, 0x0000D1, 0x00004E}, + {0x0000D2, 0x0000D6, 0x00004F}, {0x0000D9, 0x0000DC, 0x000055}, {0x0000DD, 0x0000DD, 0x000059}, + {0x0000E0, 0x0000E5, 0x000061}, {0x0000E7, 0x0000E7, 0x000063}, {0x0000E8, 0x0000EB, 0x000065}, + {0x0000EC, 0x0000EF, 0x000069}, {0x0000F1, 0x0000F1, 0x00006E}, {0x0000F2, 0x0000F6, 0x00006F}, + {0x0000F9, 0x0000FC, 0x000075}, {0x0000FD, 0x0000FD, 0x000079}, {0x0000FF, 0x0000FF, 0x000079}, + {0x000100, 0x000100, 0x000041}, {0x000101, 0x000101, 0x000061}, {0x000102, 0x000102, 0x000041}, + {0x000103, 0x000103, 0x000061}, {0x000104, 0x000104, 0x000041}, {0x000105, 0x000105, 0x000061}, + {0x000106, 0x000106, 0x000043}, {0x000107, 0x000107, 0x000063}, {0x000108, 0x000108, 0x000043}, + {0x000109, 0x000109, 0x000063}, {0x00010A, 0x00010A, 0x000043}, {0x00010B, 0x00010B, 0x000063}, + {0x00010C, 0x00010C, 0x000043}, {0x00010D, 0x00010D, 0x000063}, {0x00010E, 0x00010E, 0x000044}, + {0x00010F, 0x00010F, 0x000064}, {0x000112, 0x000112, 0x000045}, {0x000113, 0x000113, 0x000065}, + {0x000114, 0x000114, 0x000045}, {0x000115, 0x000115, 0x000065}, {0x000116, 0x000116, 0x000045}, + {0x000117, 0x000117, 0x000065}, {0x000118, 0x000118, 0x000045}, {0x000119, 0x000119, 0x000065}, + {0x00011A, 0x00011A, 0x000045}, {0x00011B, 0x00011B, 0x000065}, {0x00011C, 0x00011C, 0x000047}, + {0x00011D, 0x00011D, 0x000067}, {0x00011E, 0x00011E, 0x000047}, {0x00011F, 0x00011F, 0x000067}, + {0x000120, 0x000120, 0x000047}, {0x000121, 0x000121, 0x000067}, {0x000122, 0x000122, 0x000047}, + {0x000123, 0x000123, 0x000067}, {0x000124, 0x000124, 0x000048}, {0x000125, 0x000125, 0x000068}, + {0x000128, 0x000128, 0x000049}, {0x000129, 0x000129, 0x000069}, {0x00012A, 0x00012A, 0x000049}, + {0x00012B, 0x00012B, 0x000069}, {0x00012C, 0x00012C, 0x000049}, {0x00012D, 0x00012D, 0x000069}, + {0x00012E, 0x00012E, 0x000049}, {0x00012F, 0x00012F, 0x000069}, {0x000130, 0x000130, 0x000049}, + {0x000134, 0x000134, 0x00004A}, {0x000135, 0x000135, 0x00006A}, {0x000136, 0x000136, 0x00004B}, + {0x000137, 0x000137, 0x00006B}, {0x000139, 0x000139, 0x00004C}, {0x00013A, 0x00013A, 0x00006C}, + {0x00013B, 0x00013B, 0x00004C}, {0x00013C, 0x00013C, 0x00006C}, {0x00013D, 0x00013D, 0x00004C}, + {0x00013E, 0x00013E, 0x00006C}, {0x000143, 0x000143, 0x00004E}, {0x000144, 0x000144, 0x00006E}, + {0x000145, 0x000145, 0x00004E}, {0x000146, 0x000146, 0x00006E}, {0x000147, 0x000147, 0x00004E}, + {0x000148, 0x000148, 0x00006E}, {0x00014C, 0x00014C, 0x00004F}, {0x00014D, 0x00014D, 0x00006F}, + {0x00014E, 0x00014E, 0x00004F}, {0x00014F, 0x00014F, 0x00006F}, {0x000150, 0x000150, 0x00004F}, + {0x000151, 0x000151, 0x00006F}, {0x000154, 0x000154, 0x000052}, {0x000155, 0x000155, 0x000072}, + {0x000156, 0x000156, 0x000052}, {0x000157, 0x000157, 0x000072}, {0x000158, 0x000158, 0x000052}, + {0x000159, 0x000159, 0x000072}, {0x00015A, 0x00015A, 0x000053}, {0x00015B, 0x00015B, 0x000073}, + {0x00015C, 0x00015C, 0x000053}, {0x00015D, 0x00015D, 0x000073}, {0x00015E, 0x00015E, 0x000053}, + {0x00015F, 0x00015F, 0x000073}, {0x000160, 0x000160, 0x000053}, {0x000161, 0x000161, 0x000073}, + {0x000162, 0x000162, 0x000054}, {0x000163, 0x000163, 0x000074}, {0x000164, 0x000164, 0x000054}, + {0x000165, 0x000165, 0x000074}, {0x000168, 0x000168, 0x000055}, {0x000169, 0x000169, 0x000075}, + {0x00016A, 0x00016A, 0x000055}, {0x00016B, 0x00016B, 0x000075}, {0x00016C, 0x00016C, 0x000055}, + {0x00016D, 0x00016D, 0x000075}, {0x00016E, 0x00016E, 0x000055}, {0x00016F, 0x00016F, 0x000075}, + {0x000170, 0x000170, 0x000055}, {0x000171, 0x000171, 0x000075}, {0x000172, 0x000172, 0x000055}, + {0x000173, 0x000173, 0x000075}, {0x000174, 0x000174, 0x000057}, {0x000175, 0x000175, 0x000077}, + {0x000176, 0x000176, 0x000059}, {0x000177, 0x000177, 0x000079}, {0x000178, 0x000178, 0x000059}, + {0x000179, 0x000179, 0x00005A}, {0x00017A, 0x00017A, 0x00007A}, {0x00017B, 0x00017B, 0x00005A}, + {0x00017C, 0x00017C, 0x00007A}, {0x00017D, 0x00017D, 0x00005A}, {0x00017E, 0x00017E, 0x00007A}, + {0x0001A0, 0x0001A0, 0x00004F}, {0x0001A1, 0x0001A1, 0x00006F}, {0x0001AF, 0x0001AF, 0x000055}, + {0x0001B0, 0x0001B0, 0x000075}, {0x0001CD, 0x0001CD, 0x000041}, {0x0001CE, 0x0001CE, 0x000061}, + {0x0001CF, 0x0001CF, 0x000049}, {0x0001D0, 0x0001D0, 0x000069}, {0x0001D1, 0x0001D1, 0x00004F}, + {0x0001D2, 0x0001D2, 0x00006F}, {0x0001D3, 0x0001D3, 0x000055}, {0x0001D4, 0x0001D4, 0x000075}, + {0x0001D5, 0x0001D5, 0x000055}, {0x0001D6, 0x0001D6, 0x000075}, {0x0001D7, 0x0001D7, 0x000055}, + {0x0001D8, 0x0001D8, 0x000075}, {0x0001D9, 0x0001D9, 0x000055}, {0x0001DA, 0x0001DA, 0x000075}, + {0x0001DB, 0x0001DB, 0x000055}, {0x0001DC, 0x0001DC, 0x000075}, {0x0001DE, 0x0001DE, 0x000041}, + {0x0001DF, 0x0001DF, 0x000061}, {0x0001E0, 0x0001E0, 0x000041}, {0x0001E1, 0x0001E1, 0x000061}, + {0x0001E2, 0x0001E2, 0x0000C6}, {0x0001E3, 0x0001E3, 0x0000E6}, {0x0001E6, 0x0001E6, 0x000047}, + {0x0001E7, 0x0001E7, 0x000067}, {0x0001E8, 0x0001E8, 0x00004B}, {0x0001E9, 0x0001E9, 0x00006B}, + {0x0001EA, 0x0001EA, 0x00004F}, {0x0001EB, 0x0001EB, 0x00006F}, {0x0001EC, 0x0001EC, 0x00004F}, + {0x0001ED, 0x0001ED, 0x00006F}, {0x0001EE, 0x0001EE, 0x0001B7}, {0x0001EF, 0x0001EF, 0x000292}, + {0x0001F0, 0x0001F0, 0x00006A}, {0x0001F4, 0x0001F4, 0x000047}, {0x0001F5, 0x0001F5, 0x000067}, + {0x0001F8, 0x0001F8, 0x00004E}, {0x0001F9, 0x0001F9, 0x00006E}, {0x0001FA, 0x0001FA, 0x000041}, + {0x0001FB, 0x0001FB, 0x000061}, {0x0001FC, 0x0001FC, 0x0000C6}, {0x0001FD, 0x0001FD, 0x0000E6}, + {0x0001FE, 0x0001FE, 0x0000D8}, {0x0001FF, 0x0001FF, 0x0000F8}, {0x000200, 0x000200, 0x000041}, + {0x000201, 0x000201, 0x000061}, {0x000202, 0x000202, 0x000041}, {0x000203, 0x000203, 0x000061}, + {0x000204, 0x000204, 0x000045}, {0x000205, 0x000205, 0x000065}, {0x000206, 0x000206, 0x000045}, + {0x000207, 0x000207, 0x000065}, {0x000208, 0x000208, 0x000049}, {0x000209, 0x000209, 0x000069}, + {0x00020A, 0x00020A, 0x000049}, {0x00020B, 0x00020B, 0x000069}, {0x00020C, 0x00020C, 0x00004F}, + {0x00020D, 0x00020D, 0x00006F}, {0x00020E, 0x00020E, 0x00004F}, {0x00020F, 0x00020F, 0x00006F}, + {0x000210, 0x000210, 0x000052}, {0x000211, 0x000211, 0x000072}, {0x000212, 0x000212, 0x000052}, + {0x000213, 0x000213, 0x000072}, {0x000214, 0x000214, 0x000055}, {0x000215, 0x000215, 0x000075}, + {0x000216, 0x000216, 0x000055}, {0x000217, 0x000217, 0x000075}, {0x000218, 0x000218, 0x000053}, + {0x000219, 0x000219, 0x000073}, {0x00021A, 0x00021A, 0x000054}, {0x00021B, 0x00021B, 0x000074}, + {0x00021E, 0x00021E, 0x000048}, {0x00021F, 0x00021F, 0x000068}, {0x000226, 0x000226, 0x000041}, + {0x000227, 0x000227, 0x000061}, {0x000228, 0x000228, 0x000045}, {0x000229, 0x000229, 0x000065}, + {0x00022A, 0x00022A, 0x00004F}, {0x00022B, 0x00022B, 0x00006F}, {0x00022C, 0x00022C, 0x00004F}, + {0x00022D, 0x00022D, 0x00006F}, {0x00022E, 0x00022E, 0x00004F}, {0x00022F, 0x00022F, 0x00006F}, + {0x000230, 0x000230, 0x00004F}, {0x000231, 0x000231, 0x00006F}, {0x000232, 0x000232, 0x000059}, + {0x000233, 0x000233, 0x000079}, {0x000340, 0x000340, 0x000300}, {0x000341, 0x000341, 0x000301}, + {0x000343, 0x000343, 0x000313}, {0x000344, 0x000344, 0x000308}, {0x000374, 0x000374, 0x0002B9}, + {0x00037E, 0x00037E, 0x00003B}, {0x000385, 0x000385, 0x0000A8}, {0x000386, 0x000386, 0x000391}, + {0x000387, 0x000387, 0x0000B7}, {0x000388, 0x000388, 0x000395}, {0x000389, 0x000389, 0x000397}, + {0x00038A, 0x00038A, 0x000399}, {0x00038C, 0x00038C, 0x00039F}, {0x00038E, 0x00038E, 0x0003A5}, + {0x00038F, 0x00038F, 0x0003A9}, {0x000390, 0x000390, 0x0003B9}, {0x0003AA, 0x0003AA, 0x000399}, + {0x0003AB, 0x0003AB, 0x0003A5}, {0x0003AC, 0x0003AC, 0x0003B1}, {0x0003AD, 0x0003AD, 0x0003B5}, + {0x0003AE, 0x0003AE, 0x0003B7}, {0x0003AF, 0x0003AF, 0x0003B9}, {0x0003B0, 0x0003B0, 0x0003C5}, + {0x0003CA, 0x0003CA, 0x0003B9}, {0x0003CB, 0x0003CB, 0x0003C5}, {0x0003CC, 0x0003CC, 0x0003BF}, + {0x0003CD, 0x0003CD, 0x0003C5}, {0x0003CE, 0x0003CE, 0x0003C9}, {0x0003D3, 0x0003D4, 0x0003D2}, + {0x000400, 0x000401, 0x000415}, {0x000403, 0x000403, 0x000413}, {0x000407, 0x000407, 0x000406}, + {0x00040C, 0x00040C, 0x00041A}, {0x00040D, 0x00040D, 0x000418}, {0x00040E, 0x00040E, 0x000423}, + {0x000419, 0x000419, 0x000418}, {0x000439, 0x000439, 0x000438}, {0x000450, 0x000451, 0x000435}, + {0x000453, 0x000453, 0x000433}, {0x000457, 0x000457, 0x000456}, {0x00045C, 0x00045C, 0x00043A}, + {0x00045D, 0x00045D, 0x000438}, {0x00045E, 0x00045E, 0x000443}, {0x000476, 0x000476, 0x000474}, + {0x000477, 0x000477, 0x000475}, {0x0004C1, 0x0004C1, 0x000416}, {0x0004C2, 0x0004C2, 0x000436}, + {0x0004D0, 0x0004D0, 0x000410}, {0x0004D1, 0x0004D1, 0x000430}, {0x0004D2, 0x0004D2, 0x000410}, + {0x0004D3, 0x0004D3, 0x000430}, {0x0004D6, 0x0004D6, 0x000415}, {0x0004D7, 0x0004D7, 0x000435}, + {0x0004DA, 0x0004DA, 0x0004D8}, {0x0004DB, 0x0004DB, 0x0004D9}, {0x0004DC, 0x0004DC, 0x000416}, + {0x0004DD, 0x0004DD, 0x000436}, {0x0004DE, 0x0004DE, 0x000417}, {0x0004DF, 0x0004DF, 0x000437}, + {0x0004E2, 0x0004E2, 0x000418}, {0x0004E3, 0x0004E3, 0x000438}, {0x0004E4, 0x0004E4, 0x000418}, + {0x0004E5, 0x0004E5, 0x000438}, {0x0004E6, 0x0004E6, 0x00041E}, {0x0004E7, 0x0004E7, 0x00043E}, + {0x0004EA, 0x0004EA, 0x0004E8}, {0x0004EB, 0x0004EB, 0x0004E9}, {0x0004EC, 0x0004EC, 0x00042D}, + {0x0004ED, 0x0004ED, 0x00044D}, {0x0004EE, 0x0004EE, 0x000423}, {0x0004EF, 0x0004EF, 0x000443}, + {0x0004F0, 0x0004F0, 0x000423}, {0x0004F1, 0x0004F1, 0x000443}, {0x0004F2, 0x0004F2, 0x000423}, + {0x0004F3, 0x0004F3, 0x000443}, {0x0004F4, 0x0004F4, 0x000427}, {0x0004F5, 0x0004F5, 0x000447}, + {0x0004F8, 0x0004F8, 0x00042B}, {0x0004F9, 0x0004F9, 0x00044B}, {0x000622, 0x000623, 0x000627}, + {0x000624, 0x000624, 0x000648}, {0x000625, 0x000625, 0x000627}, {0x000626, 0x000626, 0x00064A}, + {0x0006C0, 0x0006C0, 0x0006D5}, {0x0006C2, 0x0006C2, 0x0006C1}, {0x0006D3, 0x0006D3, 0x0006D2}, + {0x000929, 0x000929, 0x000928}, {0x000931, 0x000931, 0x000930}, {0x000934, 0x000934, 0x000933}, + {0x000958, 0x000958, 0x000915}, {0x000959, 0x000959, 0x000916}, {0x00095A, 0x00095A, 0x000917}, + {0x00095B, 0x00095B, 0x00091C}, {0x00095C, 0x00095C, 0x000921}, {0x00095D, 0x00095D, 0x000922}, + {0x00095E, 0x00095E, 0x00092B}, {0x00095F, 0x00095F, 0x00092F}, {0x0009CB, 0x0009CC, 0x0009C7}, + {0x0009DC, 0x0009DC, 0x0009A1}, {0x0009DD, 0x0009DD, 0x0009A2}, {0x0009DF, 0x0009DF, 0x0009AF}, + {0x000A33, 0x000A33, 0x000A32}, {0x000A36, 0x000A36, 0x000A38}, {0x000A59, 0x000A59, 0x000A16}, + {0x000A5A, 0x000A5A, 0x000A17}, {0x000A5B, 0x000A5B, 0x000A1C}, {0x000A5E, 0x000A5E, 0x000A2B}, + {0x000B48, 0x000B48, 0x000B47}, {0x000B4B, 0x000B4C, 0x000B47}, {0x000B5C, 0x000B5C, 0x000B21}, + {0x000B5D, 0x000B5D, 0x000B22}, {0x000B94, 0x000B94, 0x000B92}, {0x000BCA, 0x000BCA, 0x000BC6}, + {0x000BCB, 0x000BCB, 0x000BC7}, {0x000BCC, 0x000BCC, 0x000BC6}, {0x000C48, 0x000C48, 0x000C46}, + {0x000CC0, 0x000CC0, 0x000CBF}, {0x000CC7, 0x000CC8, 0x000CC6}, {0x000CCA, 0x000CCB, 0x000CC6}, + {0x000D4A, 0x000D4A, 0x000D46}, {0x000D4B, 0x000D4B, 0x000D47}, {0x000D4C, 0x000D4C, 0x000D46}, + {0x000DDA, 0x000DDA, 0x000DD9}, {0x000DDC, 0x000DDE, 0x000DD9}, {0x000F43, 0x000F43, 0x000F42}, + {0x000F4D, 0x000F4D, 0x000F4C}, {0x000F52, 0x000F52, 0x000F51}, {0x000F57, 0x000F57, 0x000F56}, + {0x000F5C, 0x000F5C, 0x000F5B}, {0x000F69, 0x000F69, 0x000F40}, {0x000F73, 0x000F73, 0x000F71}, + {0x000F75, 0x000F75, 0x000F71}, {0x000F76, 0x000F76, 0x000FB2}, {0x000F78, 0x000F78, 0x000FB3}, + {0x000F81, 0x000F81, 0x000F71}, {0x000F93, 0x000F93, 0x000F92}, {0x000F9D, 0x000F9D, 0x000F9C}, + {0x000FA2, 0x000FA2, 0x000FA1}, {0x000FA7, 0x000FA7, 0x000FA6}, {0x000FAC, 0x000FAC, 0x000FAB}, + {0x000FB9, 0x000FB9, 0x000F90}, {0x001026, 0x001026, 0x001025}, {0x001B06, 0x001B06, 0x001B05}, + {0x001B08, 0x001B08, 0x001B07}, {0x001B0A, 0x001B0A, 0x001B09}, {0x001B0C, 0x001B0C, 0x001B0B}, + {0x001B0E, 0x001B0E, 0x001B0D}, {0x001B12, 0x001B12, 0x001B11}, {0x001B3B, 0x001B3B, 0x001B3A}, + {0x001B3D, 0x001B3D, 0x001B3C}, {0x001B40, 0x001B40, 0x001B3E}, {0x001B41, 0x001B41, 0x001B3F}, + {0x001B43, 0x001B43, 0x001B42}, {0x001E00, 0x001E00, 0x000041}, {0x001E01, 0x001E01, 0x000061}, + {0x001E02, 0x001E02, 0x000042}, {0x001E03, 0x001E03, 0x000062}, {0x001E04, 0x001E04, 0x000042}, + {0x001E05, 0x001E05, 0x000062}, {0x001E06, 0x001E06, 0x000042}, {0x001E07, 0x001E07, 0x000062}, + {0x001E08, 0x001E08, 0x000043}, {0x001E09, 0x001E09, 0x000063}, {0x001E0A, 0x001E0A, 0x000044}, + {0x001E0B, 0x001E0B, 0x000064}, {0x001E0C, 0x001E0C, 0x000044}, {0x001E0D, 0x001E0D, 0x000064}, + {0x001E0E, 0x001E0E, 0x000044}, {0x001E0F, 0x001E0F, 0x000064}, {0x001E10, 0x001E10, 0x000044}, + {0x001E11, 0x001E11, 0x000064}, {0x001E12, 0x001E12, 0x000044}, {0x001E13, 0x001E13, 0x000064}, + {0x001E14, 0x001E14, 0x000045}, {0x001E15, 0x001E15, 0x000065}, {0x001E16, 0x001E16, 0x000045}, + {0x001E17, 0x001E17, 0x000065}, {0x001E18, 0x001E18, 0x000045}, {0x001E19, 0x001E19, 0x000065}, + {0x001E1A, 0x001E1A, 0x000045}, {0x001E1B, 0x001E1B, 0x000065}, {0x001E1C, 0x001E1C, 0x000045}, + {0x001E1D, 0x001E1D, 0x000065}, {0x001E1E, 0x001E1E, 0x000046}, {0x001E1F, 0x001E1F, 0x000066}, + {0x001E20, 0x001E20, 0x000047}, {0x001E21, 0x001E21, 0x000067}, {0x001E22, 0x001E22, 0x000048}, + {0x001E23, 0x001E23, 0x000068}, {0x001E24, 0x001E24, 0x000048}, {0x001E25, 0x001E25, 0x000068}, + {0x001E26, 0x001E26, 0x000048}, {0x001E27, 0x001E27, 0x000068}, {0x001E28, 0x001E28, 0x000048}, + {0x001E29, 0x001E29, 0x000068}, {0x001E2A, 0x001E2A, 0x000048}, {0x001E2B, 0x001E2B, 0x000068}, + {0x001E2C, 0x001E2C, 0x000049}, {0x001E2D, 0x001E2D, 0x000069}, {0x001E2E, 0x001E2E, 0x000049}, + {0x001E2F, 0x001E2F, 0x000069}, {0x001E30, 0x001E30, 0x00004B}, {0x001E31, 0x001E31, 0x00006B}, + {0x001E32, 0x001E32, 0x00004B}, {0x001E33, 0x001E33, 0x00006B}, {0x001E34, 0x001E34, 0x00004B}, + {0x001E35, 0x001E35, 0x00006B}, {0x001E36, 0x001E36, 0x00004C}, {0x001E37, 0x001E37, 0x00006C}, + {0x001E38, 0x001E38, 0x00004C}, {0x001E39, 0x001E39, 0x00006C}, {0x001E3A, 0x001E3A, 0x00004C}, + {0x001E3B, 0x001E3B, 0x00006C}, {0x001E3C, 0x001E3C, 0x00004C}, {0x001E3D, 0x001E3D, 0x00006C}, + {0x001E3E, 0x001E3E, 0x00004D}, {0x001E3F, 0x001E3F, 0x00006D}, {0x001E40, 0x001E40, 0x00004D}, + {0x001E41, 0x001E41, 0x00006D}, {0x001E42, 0x001E42, 0x00004D}, {0x001E43, 0x001E43, 0x00006D}, + {0x001E44, 0x001E44, 0x00004E}, {0x001E45, 0x001E45, 0x00006E}, {0x001E46, 0x001E46, 0x00004E}, + {0x001E47, 0x001E47, 0x00006E}, {0x001E48, 0x001E48, 0x00004E}, {0x001E49, 0x001E49, 0x00006E}, + {0x001E4A, 0x001E4A, 0x00004E}, {0x001E4B, 0x001E4B, 0x00006E}, {0x001E4C, 0x001E4C, 0x00004F}, + {0x001E4D, 0x001E4D, 0x00006F}, {0x001E4E, 0x001E4E, 0x00004F}, {0x001E4F, 0x001E4F, 0x00006F}, + {0x001E50, 0x001E50, 0x00004F}, {0x001E51, 0x001E51, 0x00006F}, {0x001E52, 0x001E52, 0x00004F}, + {0x001E53, 0x001E53, 0x00006F}, {0x001E54, 0x001E54, 0x000050}, {0x001E55, 0x001E55, 0x000070}, + {0x001E56, 0x001E56, 0x000050}, {0x001E57, 0x001E57, 0x000070}, {0x001E58, 0x001E58, 0x000052}, + {0x001E59, 0x001E59, 0x000072}, {0x001E5A, 0x001E5A, 0x000052}, {0x001E5B, 0x001E5B, 0x000072}, + {0x001E5C, 0x001E5C, 0x000052}, {0x001E5D, 0x001E5D, 0x000072}, {0x001E5E, 0x001E5E, 0x000052}, + {0x001E5F, 0x001E5F, 0x000072}, {0x001E60, 0x001E60, 0x000053}, {0x001E61, 0x001E61, 0x000073}, + {0x001E62, 0x001E62, 0x000053}, {0x001E63, 0x001E63, 0x000073}, {0x001E64, 0x001E64, 0x000053}, + {0x001E65, 0x001E65, 0x000073}, {0x001E66, 0x001E66, 0x000053}, {0x001E67, 0x001E67, 0x000073}, + {0x001E68, 0x001E68, 0x000053}, {0x001E69, 0x001E69, 0x000073}, {0x001E6A, 0x001E6A, 0x000054}, + {0x001E6B, 0x001E6B, 0x000074}, {0x001E6C, 0x001E6C, 0x000054}, {0x001E6D, 0x001E6D, 0x000074}, + {0x001E6E, 0x001E6E, 0x000054}, {0x001E6F, 0x001E6F, 0x000074}, {0x001E70, 0x001E70, 0x000054}, + {0x001E71, 0x001E71, 0x000074}, {0x001E72, 0x001E72, 0x000055}, {0x001E73, 0x001E73, 0x000075}, + {0x001E74, 0x001E74, 0x000055}, {0x001E75, 0x001E75, 0x000075}, {0x001E76, 0x001E76, 0x000055}, + {0x001E77, 0x001E77, 0x000075}, {0x001E78, 0x001E78, 0x000055}, {0x001E79, 0x001E79, 0x000075}, + {0x001E7A, 0x001E7A, 0x000055}, {0x001E7B, 0x001E7B, 0x000075}, {0x001E7C, 0x001E7C, 0x000056}, + {0x001E7D, 0x001E7D, 0x000076}, {0x001E7E, 0x001E7E, 0x000056}, {0x001E7F, 0x001E7F, 0x000076}, + {0x001E80, 0x001E80, 0x000057}, {0x001E81, 0x001E81, 0x000077}, {0x001E82, 0x001E82, 0x000057}, + {0x001E83, 0x001E83, 0x000077}, {0x001E84, 0x001E84, 0x000057}, {0x001E85, 0x001E85, 0x000077}, + {0x001E86, 0x001E86, 0x000057}, {0x001E87, 0x001E87, 0x000077}, {0x001E88, 0x001E88, 0x000057}, + {0x001E89, 0x001E89, 0x000077}, {0x001E8A, 0x001E8A, 0x000058}, {0x001E8B, 0x001E8B, 0x000078}, + {0x001E8C, 0x001E8C, 0x000058}, {0x001E8D, 0x001E8D, 0x000078}, {0x001E8E, 0x001E8E, 0x000059}, + {0x001E8F, 0x001E8F, 0x000079}, {0x001E90, 0x001E90, 0x00005A}, {0x001E91, 0x001E91, 0x00007A}, + {0x001E92, 0x001E92, 0x00005A}, {0x001E93, 0x001E93, 0x00007A}, {0x001E94, 0x001E94, 0x00005A}, + {0x001E95, 0x001E95, 0x00007A}, {0x001E96, 0x001E96, 0x000068}, {0x001E97, 0x001E97, 0x000074}, + {0x001E98, 0x001E98, 0x000077}, {0x001E99, 0x001E99, 0x000079}, {0x001E9B, 0x001E9B, 0x00017F}, + {0x001EA0, 0x001EA0, 0x000041}, {0x001EA1, 0x001EA1, 0x000061}, {0x001EA2, 0x001EA2, 0x000041}, + {0x001EA3, 0x001EA3, 0x000061}, {0x001EA4, 0x001EA4, 0x000041}, {0x001EA5, 0x001EA5, 0x000061}, + {0x001EA6, 0x001EA6, 0x000041}, {0x001EA7, 0x001EA7, 0x000061}, {0x001EA8, 0x001EA8, 0x000041}, + {0x001EA9, 0x001EA9, 0x000061}, {0x001EAA, 0x001EAA, 0x000041}, {0x001EAB, 0x001EAB, 0x000061}, + {0x001EAC, 0x001EAC, 0x000041}, {0x001EAD, 0x001EAD, 0x000061}, {0x001EAE, 0x001EAE, 0x000041}, + {0x001EAF, 0x001EAF, 0x000061}, {0x001EB0, 0x001EB0, 0x000041}, {0x001EB1, 0x001EB1, 0x000061}, + {0x001EB2, 0x001EB2, 0x000041}, {0x001EB3, 0x001EB3, 0x000061}, {0x001EB4, 0x001EB4, 0x000041}, + {0x001EB5, 0x001EB5, 0x000061}, {0x001EB6, 0x001EB6, 0x000041}, {0x001EB7, 0x001EB7, 0x000061}, + {0x001EB8, 0x001EB8, 0x000045}, {0x001EB9, 0x001EB9, 0x000065}, {0x001EBA, 0x001EBA, 0x000045}, + {0x001EBB, 0x001EBB, 0x000065}, {0x001EBC, 0x001EBC, 0x000045}, {0x001EBD, 0x001EBD, 0x000065}, + {0x001EBE, 0x001EBE, 0x000045}, {0x001EBF, 0x001EBF, 0x000065}, {0x001EC0, 0x001EC0, 0x000045}, + {0x001EC1, 0x001EC1, 0x000065}, {0x001EC2, 0x001EC2, 0x000045}, {0x001EC3, 0x001EC3, 0x000065}, + {0x001EC4, 0x001EC4, 0x000045}, {0x001EC5, 0x001EC5, 0x000065}, {0x001EC6, 0x001EC6, 0x000045}, + {0x001EC7, 0x001EC7, 0x000065}, {0x001EC8, 0x001EC8, 0x000049}, {0x001EC9, 0x001EC9, 0x000069}, + {0x001ECA, 0x001ECA, 0x000049}, {0x001ECB, 0x001ECB, 0x000069}, {0x001ECC, 0x001ECC, 0x00004F}, + {0x001ECD, 0x001ECD, 0x00006F}, {0x001ECE, 0x001ECE, 0x00004F}, {0x001ECF, 0x001ECF, 0x00006F}, + {0x001ED0, 0x001ED0, 0x00004F}, {0x001ED1, 0x001ED1, 0x00006F}, {0x001ED2, 0x001ED2, 0x00004F}, + {0x001ED3, 0x001ED3, 0x00006F}, {0x001ED4, 0x001ED4, 0x00004F}, {0x001ED5, 0x001ED5, 0x00006F}, + {0x001ED6, 0x001ED6, 0x00004F}, {0x001ED7, 0x001ED7, 0x00006F}, {0x001ED8, 0x001ED8, 0x00004F}, + {0x001ED9, 0x001ED9, 0x00006F}, {0x001EDA, 0x001EDA, 0x00004F}, {0x001EDB, 0x001EDB, 0x00006F}, + {0x001EDC, 0x001EDC, 0x00004F}, {0x001EDD, 0x001EDD, 0x00006F}, {0x001EDE, 0x001EDE, 0x00004F}, + {0x001EDF, 0x001EDF, 0x00006F}, {0x001EE0, 0x001EE0, 0x00004F}, {0x001EE1, 0x001EE1, 0x00006F}, + {0x001EE2, 0x001EE2, 0x00004F}, {0x001EE3, 0x001EE3, 0x00006F}, {0x001EE4, 0x001EE4, 0x000055}, + {0x001EE5, 0x001EE5, 0x000075}, {0x001EE6, 0x001EE6, 0x000055}, {0x001EE7, 0x001EE7, 0x000075}, + {0x001EE8, 0x001EE8, 0x000055}, {0x001EE9, 0x001EE9, 0x000075}, {0x001EEA, 0x001EEA, 0x000055}, + {0x001EEB, 0x001EEB, 0x000075}, {0x001EEC, 0x001EEC, 0x000055}, {0x001EED, 0x001EED, 0x000075}, + {0x001EEE, 0x001EEE, 0x000055}, {0x001EEF, 0x001EEF, 0x000075}, {0x001EF0, 0x001EF0, 0x000055}, + {0x001EF1, 0x001EF1, 0x000075}, {0x001EF2, 0x001EF2, 0x000059}, {0x001EF3, 0x001EF3, 0x000079}, + {0x001EF4, 0x001EF4, 0x000059}, {0x001EF5, 0x001EF5, 0x000079}, {0x001EF6, 0x001EF6, 0x000059}, + {0x001EF7, 0x001EF7, 0x000079}, {0x001EF8, 0x001EF8, 0x000059}, {0x001EF9, 0x001EF9, 0x000079}, + {0x001F00, 0x001F07, 0x0003B1}, {0x001F08, 0x001F0F, 0x000391}, {0x001F10, 0x001F15, 0x0003B5}, + {0x001F18, 0x001F1D, 0x000395}, {0x001F20, 0x001F27, 0x0003B7}, {0x001F28, 0x001F2F, 0x000397}, + {0x001F30, 0x001F37, 0x0003B9}, {0x001F38, 0x001F3F, 0x000399}, {0x001F40, 0x001F45, 0x0003BF}, + {0x001F48, 0x001F4D, 0x00039F}, {0x001F50, 0x001F57, 0x0003C5}, {0x001F59, 0x001F59, 0x0003A5}, + {0x001F5B, 0x001F5B, 0x0003A5}, {0x001F5D, 0x001F5D, 0x0003A5}, {0x001F5F, 0x001F5F, 0x0003A5}, + {0x001F60, 0x001F67, 0x0003C9}, {0x001F68, 0x001F6F, 0x0003A9}, {0x001F70, 0x001F71, 0x0003B1}, + {0x001F72, 0x001F73, 0x0003B5}, {0x001F74, 0x001F75, 0x0003B7}, {0x001F76, 0x001F77, 0x0003B9}, + {0x001F78, 0x001F79, 0x0003BF}, {0x001F7A, 0x001F7B, 0x0003C5}, {0x001F7C, 0x001F7D, 0x0003C9}, + {0x001F80, 0x001F87, 0x0003B1}, {0x001F88, 0x001F8F, 0x000391}, {0x001F90, 0x001F97, 0x0003B7}, + {0x001F98, 0x001F9F, 0x000397}, {0x001FA0, 0x001FA7, 0x0003C9}, {0x001FA8, 0x001FAF, 0x0003A9}, + {0x001FB0, 0x001FB4, 0x0003B1}, {0x001FB6, 0x001FB7, 0x0003B1}, {0x001FB8, 0x001FBC, 0x000391}, + {0x001FBE, 0x001FBE, 0x0003B9}, {0x001FC1, 0x001FC1, 0x0000A8}, {0x001FC2, 0x001FC4, 0x0003B7}, + {0x001FC6, 0x001FC7, 0x0003B7}, {0x001FC8, 0x001FC9, 0x000395}, {0x001FCA, 0x001FCC, 0x000397}, + {0x001FCD, 0x001FCF, 0x001FBF}, {0x001FD0, 0x001FD3, 0x0003B9}, {0x001FD6, 0x001FD7, 0x0003B9}, + {0x001FD8, 0x001FDB, 0x000399}, {0x001FDD, 0x001FDF, 0x001FFE}, {0x001FE0, 0x001FE3, 0x0003C5}, + {0x001FE4, 0x001FE5, 0x0003C1}, {0x001FE6, 0x001FE7, 0x0003C5}, {0x001FE8, 0x001FEB, 0x0003A5}, + {0x001FEC, 0x001FEC, 0x0003A1}, {0x001FED, 0x001FEE, 0x0000A8}, {0x001FEF, 0x001FEF, 0x000060}, + {0x001FF2, 0x001FF4, 0x0003C9}, {0x001FF6, 0x001FF7, 0x0003C9}, {0x001FF8, 0x001FF9, 0x00039F}, + {0x001FFA, 0x001FFC, 0x0003A9}, {0x001FFD, 0x001FFD, 0x0000B4}, {0x002000, 0x002000, 0x002002}, + {0x002001, 0x002001, 0x002003}, {0x002126, 0x002126, 0x0003A9}, {0x00212A, 0x00212A, 0x00004B}, + {0x00212B, 0x00212B, 0x000041}, {0x00219A, 0x00219A, 0x002190}, {0x00219B, 0x00219B, 0x002192}, + {0x0021AE, 0x0021AE, 0x002194}, {0x0021CD, 0x0021CD, 0x0021D0}, {0x0021CE, 0x0021CE, 0x0021D4}, + {0x0021CF, 0x0021CF, 0x0021D2}, {0x002204, 0x002204, 0x002203}, {0x002209, 0x002209, 0x002208}, + {0x00220C, 0x00220C, 0x00220B}, {0x002224, 0x002224, 0x002223}, {0x002226, 0x002226, 0x002225}, + {0x002241, 0x002241, 0x00223C}, {0x002244, 0x002244, 0x002243}, {0x002247, 0x002247, 0x002245}, + {0x002249, 0x002249, 0x002248}, {0x002260, 0x002260, 0x00003D}, {0x002262, 0x002262, 0x002261}, + {0x00226D, 0x00226D, 0x00224D}, {0x00226E, 0x00226E, 0x00003C}, {0x00226F, 0x00226F, 0x00003E}, + {0x002270, 0x002270, 0x002264}, {0x002271, 0x002271, 0x002265}, {0x002274, 0x002274, 0x002272}, + {0x002275, 0x002275, 0x002273}, {0x002278, 0x002278, 0x002276}, {0x002279, 0x002279, 0x002277}, + {0x002280, 0x002280, 0x00227A}, {0x002281, 0x002281, 0x00227B}, {0x002284, 0x002284, 0x002282}, + {0x002285, 0x002285, 0x002283}, {0x002288, 0x002288, 0x002286}, {0x002289, 0x002289, 0x002287}, + {0x0022AC, 0x0022AC, 0x0022A2}, {0x0022AD, 0x0022AD, 0x0022A8}, {0x0022AE, 0x0022AE, 0x0022A9}, + {0x0022AF, 0x0022AF, 0x0022AB}, {0x0022E0, 0x0022E0, 0x00227C}, {0x0022E1, 0x0022E1, 0x00227D}, + {0x0022E2, 0x0022E2, 0x002291}, {0x0022E3, 0x0022E3, 0x002292}, {0x0022EA, 0x0022EA, 0x0022B2}, + {0x0022EB, 0x0022EB, 0x0022B3}, {0x0022EC, 0x0022EC, 0x0022B4}, {0x0022ED, 0x0022ED, 0x0022B5}, + {0x002329, 0x002329, 0x003008}, {0x00232A, 0x00232A, 0x003009}, {0x002ADC, 0x002ADC, 0x002ADD}, + {0x00304C, 0x00304C, 0x00304B}, {0x00304E, 0x00304E, 0x00304D}, {0x003050, 0x003050, 0x00304F}, + {0x003052, 0x003052, 0x003051}, {0x003054, 0x003054, 0x003053}, {0x003056, 0x003056, 0x003055}, + {0x003058, 0x003058, 0x003057}, {0x00305A, 0x00305A, 0x003059}, {0x00305C, 0x00305C, 0x00305B}, + {0x00305E, 0x00305E, 0x00305D}, {0x003060, 0x003060, 0x00305F}, {0x003062, 0x003062, 0x003061}, + {0x003065, 0x003065, 0x003064}, {0x003067, 0x003067, 0x003066}, {0x003069, 0x003069, 0x003068}, + {0x003070, 0x003071, 0x00306F}, {0x003073, 0x003074, 0x003072}, {0x003076, 0x003077, 0x003075}, + {0x003079, 0x00307A, 0x003078}, {0x00307C, 0x00307D, 0x00307B}, {0x003094, 0x003094, 0x003046}, + {0x00309E, 0x00309E, 0x00309D}, {0x0030AC, 0x0030AC, 0x0030AB}, {0x0030AE, 0x0030AE, 0x0030AD}, + {0x0030B0, 0x0030B0, 0x0030AF}, {0x0030B2, 0x0030B2, 0x0030B1}, {0x0030B4, 0x0030B4, 0x0030B3}, + {0x0030B6, 0x0030B6, 0x0030B5}, {0x0030B8, 0x0030B8, 0x0030B7}, {0x0030BA, 0x0030BA, 0x0030B9}, + {0x0030BC, 0x0030BC, 0x0030BB}, {0x0030BE, 0x0030BE, 0x0030BD}, {0x0030C0, 0x0030C0, 0x0030BF}, + {0x0030C2, 0x0030C2, 0x0030C1}, {0x0030C5, 0x0030C5, 0x0030C4}, {0x0030C7, 0x0030C7, 0x0030C6}, + {0x0030C9, 0x0030C9, 0x0030C8}, {0x0030D0, 0x0030D1, 0x0030CF}, {0x0030D3, 0x0030D4, 0x0030D2}, + {0x0030D6, 0x0030D7, 0x0030D5}, {0x0030D9, 0x0030DA, 0x0030D8}, {0x0030DC, 0x0030DD, 0x0030DB}, + {0x0030F4, 0x0030F4, 0x0030A6}, {0x0030F7, 0x0030F7, 0x0030EF}, {0x0030F8, 0x0030F8, 0x0030F0}, + {0x0030F9, 0x0030F9, 0x0030F1}, {0x0030FA, 0x0030FA, 0x0030F2}, {0x0030FE, 0x0030FE, 0x0030FD}, + {0x00AC00, 0x00AE4B, 0x001100}, {0x00AE4C, 0x00B097, 0x001101}, {0x00B098, 0x00B2E3, 0x001102}, + {0x00B2E4, 0x00B52F, 0x001103}, {0x00B530, 0x00B77B, 0x001104}, {0x00B77C, 0x00B9C7, 0x001105}, + {0x00B9C8, 0x00BC13, 0x001106}, {0x00BC14, 0x00BE5F, 0x001107}, {0x00BE60, 0x00C0AB, 0x001108}, + {0x00C0AC, 0x00C2F7, 0x001109}, {0x00C2F8, 0x00C543, 0x00110A}, {0x00C544, 0x00C78F, 0x00110B}, + {0x00C790, 0x00C9DB, 0x00110C}, {0x00C9DC, 0x00CC27, 0x00110D}, {0x00CC28, 0x00CE73, 0x00110E}, + {0x00CE74, 0x00D0BF, 0x00110F}, {0x00D0C0, 0x00D30B, 0x001110}, {0x00D30C, 0x00D557, 0x001111}, + {0x00D558, 0x00D7A3, 0x001112}, {0x00F900, 0x00F900, 0x008C48}, {0x00F901, 0x00F901, 0x0066F4}, + {0x00F902, 0x00F902, 0x008ECA}, {0x00F903, 0x00F903, 0x008CC8}, {0x00F904, 0x00F904, 0x006ED1}, + {0x00F905, 0x00F905, 0x004E32}, {0x00F906, 0x00F906, 0x0053E5}, {0x00F907, 0x00F908, 0x009F9C}, + {0x00F909, 0x00F909, 0x005951}, {0x00F90A, 0x00F90A, 0x0091D1}, {0x00F90B, 0x00F90B, 0x005587}, + {0x00F90C, 0x00F90C, 0x005948}, {0x00F90D, 0x00F90D, 0x0061F6}, {0x00F90E, 0x00F90E, 0x007669}, + {0x00F90F, 0x00F90F, 0x007F85}, {0x00F910, 0x00F910, 0x00863F}, {0x00F911, 0x00F911, 0x0087BA}, + {0x00F912, 0x00F912, 0x0088F8}, {0x00F913, 0x00F913, 0x00908F}, {0x00F914, 0x00F914, 0x006A02}, + {0x00F915, 0x00F915, 0x006D1B}, {0x00F916, 0x00F916, 0x0070D9}, {0x00F917, 0x00F917, 0x0073DE}, + {0x00F918, 0x00F918, 0x00843D}, {0x00F919, 0x00F919, 0x00916A}, {0x00F91A, 0x00F91A, 0x0099F1}, + {0x00F91B, 0x00F91B, 0x004E82}, {0x00F91C, 0x00F91C, 0x005375}, {0x00F91D, 0x00F91D, 0x006B04}, + {0x00F91E, 0x00F91E, 0x00721B}, {0x00F91F, 0x00F91F, 0x00862D}, {0x00F920, 0x00F920, 0x009E1E}, + {0x00F921, 0x00F921, 0x005D50}, {0x00F922, 0x00F922, 0x006FEB}, {0x00F923, 0x00F923, 0x0085CD}, + {0x00F924, 0x00F924, 0x008964}, {0x00F925, 0x00F925, 0x0062C9}, {0x00F926, 0x00F926, 0x0081D8}, + {0x00F927, 0x00F927, 0x00881F}, {0x00F928, 0x00F928, 0x005ECA}, {0x00F929, 0x00F929, 0x006717}, + {0x00F92A, 0x00F92A, 0x006D6A}, {0x00F92B, 0x00F92B, 0x0072FC}, {0x00F92C, 0x00F92C, 0x0090CE}, + {0x00F92D, 0x00F92D, 0x004F86}, {0x00F92E, 0x00F92E, 0x0051B7}, {0x00F92F, 0x00F92F, 0x0052DE}, + {0x00F930, 0x00F930, 0x0064C4}, {0x00F931, 0x00F931, 0x006AD3}, {0x00F932, 0x00F932, 0x007210}, + {0x00F933, 0x00F933, 0x0076E7}, {0x00F934, 0x00F934, 0x008001}, {0x00F935, 0x00F935, 0x008606}, + {0x00F936, 0x00F936, 0x00865C}, {0x00F937, 0x00F937, 0x008DEF}, {0x00F938, 0x00F938, 0x009732}, + {0x00F939, 0x00F939, 0x009B6F}, {0x00F93A, 0x00F93A, 0x009DFA}, {0x00F93B, 0x00F93B, 0x00788C}, + {0x00F93C, 0x00F93C, 0x00797F}, {0x00F93D, 0x00F93D, 0x007DA0}, {0x00F93E, 0x00F93E, 0x0083C9}, + {0x00F93F, 0x00F93F, 0x009304}, {0x00F940, 0x00F940, 0x009E7F}, {0x00F941, 0x00F941, 0x008AD6}, + {0x00F942, 0x00F942, 0x0058DF}, {0x00F943, 0x00F943, 0x005F04}, {0x00F944, 0x00F944, 0x007C60}, + {0x00F945, 0x00F945, 0x00807E}, {0x00F946, 0x00F946, 0x007262}, {0x00F947, 0x00F947, 0x0078CA}, + {0x00F948, 0x00F948, 0x008CC2}, {0x00F949, 0x00F949, 0x0096F7}, {0x00F94A, 0x00F94A, 0x0058D8}, + {0x00F94B, 0x00F94B, 0x005C62}, {0x00F94C, 0x00F94C, 0x006A13}, {0x00F94D, 0x00F94D, 0x006DDA}, + {0x00F94E, 0x00F94E, 0x006F0F}, {0x00F94F, 0x00F94F, 0x007D2F}, {0x00F950, 0x00F950, 0x007E37}, + {0x00F951, 0x00F951, 0x00964B}, {0x00F952, 0x00F952, 0x0052D2}, {0x00F953, 0x00F953, 0x00808B}, + {0x00F954, 0x00F954, 0x0051DC}, {0x00F955, 0x00F955, 0x0051CC}, {0x00F956, 0x00F956, 0x007A1C}, + {0x00F957, 0x00F957, 0x007DBE}, {0x00F958, 0x00F958, 0x0083F1}, {0x00F959, 0x00F959, 0x009675}, + {0x00F95A, 0x00F95A, 0x008B80}, {0x00F95B, 0x00F95B, 0x0062CF}, {0x00F95C, 0x00F95C, 0x006A02}, + {0x00F95D, 0x00F95D, 0x008AFE}, {0x00F95E, 0x00F95E, 0x004E39}, {0x00F95F, 0x00F95F, 0x005BE7}, + {0x00F960, 0x00F960, 0x006012}, {0x00F961, 0x00F961, 0x007387}, {0x00F962, 0x00F962, 0x007570}, + {0x00F963, 0x00F963, 0x005317}, {0x00F964, 0x00F964, 0x0078FB}, {0x00F965, 0x00F965, 0x004FBF}, + {0x00F966, 0x00F966, 0x005FA9}, {0x00F967, 0x00F967, 0x004E0D}, {0x00F968, 0x00F968, 0x006CCC}, + {0x00F969, 0x00F969, 0x006578}, {0x00F96A, 0x00F96A, 0x007D22}, {0x00F96B, 0x00F96B, 0x0053C3}, + {0x00F96C, 0x00F96C, 0x00585E}, {0x00F96D, 0x00F96D, 0x007701}, {0x00F96E, 0x00F96E, 0x008449}, + {0x00F96F, 0x00F96F, 0x008AAA}, {0x00F970, 0x00F970, 0x006BBA}, {0x00F971, 0x00F971, 0x008FB0}, + {0x00F972, 0x00F972, 0x006C88}, {0x00F973, 0x00F973, 0x0062FE}, {0x00F974, 0x00F974, 0x0082E5}, + {0x00F975, 0x00F975, 0x0063A0}, {0x00F976, 0x00F976, 0x007565}, {0x00F977, 0x00F977, 0x004EAE}, + {0x00F978, 0x00F978, 0x005169}, {0x00F979, 0x00F979, 0x0051C9}, {0x00F97A, 0x00F97A, 0x006881}, + {0x00F97B, 0x00F97B, 0x007CE7}, {0x00F97C, 0x00F97C, 0x00826F}, {0x00F97D, 0x00F97D, 0x008AD2}, + {0x00F97E, 0x00F97E, 0x0091CF}, {0x00F97F, 0x00F97F, 0x0052F5}, {0x00F980, 0x00F980, 0x005442}, + {0x00F981, 0x00F981, 0x005973}, {0x00F982, 0x00F982, 0x005EEC}, {0x00F983, 0x00F983, 0x0065C5}, + {0x00F984, 0x00F984, 0x006FFE}, {0x00F985, 0x00F985, 0x00792A}, {0x00F986, 0x00F986, 0x0095AD}, + {0x00F987, 0x00F987, 0x009A6A}, {0x00F988, 0x00F988, 0x009E97}, {0x00F989, 0x00F989, 0x009ECE}, + {0x00F98A, 0x00F98A, 0x00529B}, {0x00F98B, 0x00F98B, 0x0066C6}, {0x00F98C, 0x00F98C, 0x006B77}, + {0x00F98D, 0x00F98D, 0x008F62}, {0x00F98E, 0x00F98E, 0x005E74}, {0x00F98F, 0x00F98F, 0x006190}, + {0x00F990, 0x00F990, 0x006200}, {0x00F991, 0x00F991, 0x00649A}, {0x00F992, 0x00F992, 0x006F23}, + {0x00F993, 0x00F993, 0x007149}, {0x00F994, 0x00F994, 0x007489}, {0x00F995, 0x00F995, 0x0079CA}, + {0x00F996, 0x00F996, 0x007DF4}, {0x00F997, 0x00F997, 0x00806F}, {0x00F998, 0x00F998, 0x008F26}, + {0x00F999, 0x00F999, 0x0084EE}, {0x00F99A, 0x00F99A, 0x009023}, {0x00F99B, 0x00F99B, 0x00934A}, + {0x00F99C, 0x00F99C, 0x005217}, {0x00F99D, 0x00F99D, 0x0052A3}, {0x00F99E, 0x00F99E, 0x0054BD}, + {0x00F99F, 0x00F99F, 0x0070C8}, {0x00F9A0, 0x00F9A0, 0x0088C2}, {0x00F9A1, 0x00F9A1, 0x008AAA}, + {0x00F9A2, 0x00F9A2, 0x005EC9}, {0x00F9A3, 0x00F9A3, 0x005FF5}, {0x00F9A4, 0x00F9A4, 0x00637B}, + {0x00F9A5, 0x00F9A5, 0x006BAE}, {0x00F9A6, 0x00F9A6, 0x007C3E}, {0x00F9A7, 0x00F9A7, 0x007375}, + {0x00F9A8, 0x00F9A8, 0x004EE4}, {0x00F9A9, 0x00F9A9, 0x0056F9}, {0x00F9AA, 0x00F9AA, 0x005BE7}, + {0x00F9AB, 0x00F9AB, 0x005DBA}, {0x00F9AC, 0x00F9AC, 0x00601C}, {0x00F9AD, 0x00F9AD, 0x0073B2}, + {0x00F9AE, 0x00F9AE, 0x007469}, {0x00F9AF, 0x00F9AF, 0x007F9A}, {0x00F9B0, 0x00F9B0, 0x008046}, + {0x00F9B1, 0x00F9B1, 0x009234}, {0x00F9B2, 0x00F9B2, 0x0096F6}, {0x00F9B3, 0x00F9B3, 0x009748}, + {0x00F9B4, 0x00F9B4, 0x009818}, {0x00F9B5, 0x00F9B5, 0x004F8B}, {0x00F9B6, 0x00F9B6, 0x0079AE}, + {0x00F9B7, 0x00F9B7, 0x0091B4}, {0x00F9B8, 0x00F9B8, 0x0096B8}, {0x00F9B9, 0x00F9B9, 0x0060E1}, + {0x00F9BA, 0x00F9BA, 0x004E86}, {0x00F9BB, 0x00F9BB, 0x0050DA}, {0x00F9BC, 0x00F9BC, 0x005BEE}, + {0x00F9BD, 0x00F9BD, 0x005C3F}, {0x00F9BE, 0x00F9BE, 0x006599}, {0x00F9BF, 0x00F9BF, 0x006A02}, + {0x00F9C0, 0x00F9C0, 0x0071CE}, {0x00F9C1, 0x00F9C1, 0x007642}, {0x00F9C2, 0x00F9C2, 0x0084FC}, + {0x00F9C3, 0x00F9C3, 0x00907C}, {0x00F9C4, 0x00F9C4, 0x009F8D}, {0x00F9C5, 0x00F9C5, 0x006688}, + {0x00F9C6, 0x00F9C6, 0x00962E}, {0x00F9C7, 0x00F9C7, 0x005289}, {0x00F9C8, 0x00F9C8, 0x00677B}, + {0x00F9C9, 0x00F9C9, 0x0067F3}, {0x00F9CA, 0x00F9CA, 0x006D41}, {0x00F9CB, 0x00F9CB, 0x006E9C}, + {0x00F9CC, 0x00F9CC, 0x007409}, {0x00F9CD, 0x00F9CD, 0x007559}, {0x00F9CE, 0x00F9CE, 0x00786B}, + {0x00F9CF, 0x00F9CF, 0x007D10}, {0x00F9D0, 0x00F9D0, 0x00985E}, {0x00F9D1, 0x00F9D1, 0x00516D}, + {0x00F9D2, 0x00F9D2, 0x00622E}, {0x00F9D3, 0x00F9D3, 0x009678}, {0x00F9D4, 0x00F9D4, 0x00502B}, + {0x00F9D5, 0x00F9D5, 0x005D19}, {0x00F9D6, 0x00F9D6, 0x006DEA}, {0x00F9D7, 0x00F9D7, 0x008F2A}, + {0x00F9D8, 0x00F9D8, 0x005F8B}, {0x00F9D9, 0x00F9D9, 0x006144}, {0x00F9DA, 0x00F9DA, 0x006817}, + {0x00F9DB, 0x00F9DB, 0x007387}, {0x00F9DC, 0x00F9DC, 0x009686}, {0x00F9DD, 0x00F9DD, 0x005229}, + {0x00F9DE, 0x00F9DE, 0x00540F}, {0x00F9DF, 0x00F9DF, 0x005C65}, {0x00F9E0, 0x00F9E0, 0x006613}, + {0x00F9E1, 0x00F9E1, 0x00674E}, {0x00F9E2, 0x00F9E2, 0x0068A8}, {0x00F9E3, 0x00F9E3, 0x006CE5}, + {0x00F9E4, 0x00F9E4, 0x007406}, {0x00F9E5, 0x00F9E5, 0x0075E2}, {0x00F9E6, 0x00F9E6, 0x007F79}, + {0x00F9E7, 0x00F9E7, 0x0088CF}, {0x00F9E8, 0x00F9E8, 0x0088E1}, {0x00F9E9, 0x00F9E9, 0x0091CC}, + {0x00F9EA, 0x00F9EA, 0x0096E2}, {0x00F9EB, 0x00F9EB, 0x00533F}, {0x00F9EC, 0x00F9EC, 0x006EBA}, + {0x00F9ED, 0x00F9ED, 0x00541D}, {0x00F9EE, 0x00F9EE, 0x0071D0}, {0x00F9EF, 0x00F9EF, 0x007498}, + {0x00F9F0, 0x00F9F0, 0x0085FA}, {0x00F9F1, 0x00F9F1, 0x0096A3}, {0x00F9F2, 0x00F9F2, 0x009C57}, + {0x00F9F3, 0x00F9F3, 0x009E9F}, {0x00F9F4, 0x00F9F4, 0x006797}, {0x00F9F5, 0x00F9F5, 0x006DCB}, + {0x00F9F6, 0x00F9F6, 0x0081E8}, {0x00F9F7, 0x00F9F7, 0x007ACB}, {0x00F9F8, 0x00F9F8, 0x007B20}, + {0x00F9F9, 0x00F9F9, 0x007C92}, {0x00F9FA, 0x00F9FA, 0x0072C0}, {0x00F9FB, 0x00F9FB, 0x007099}, + {0x00F9FC, 0x00F9FC, 0x008B58}, {0x00F9FD, 0x00F9FD, 0x004EC0}, {0x00F9FE, 0x00F9FE, 0x008336}, + {0x00F9FF, 0x00F9FF, 0x00523A}, {0x00FA00, 0x00FA00, 0x005207}, {0x00FA01, 0x00FA01, 0x005EA6}, + {0x00FA02, 0x00FA02, 0x0062D3}, {0x00FA03, 0x00FA03, 0x007CD6}, {0x00FA04, 0x00FA04, 0x005B85}, + {0x00FA05, 0x00FA05, 0x006D1E}, {0x00FA06, 0x00FA06, 0x0066B4}, {0x00FA07, 0x00FA07, 0x008F3B}, + {0x00FA08, 0x00FA08, 0x00884C}, {0x00FA09, 0x00FA09, 0x00964D}, {0x00FA0A, 0x00FA0A, 0x00898B}, + {0x00FA0B, 0x00FA0B, 0x005ED3}, {0x00FA0C, 0x00FA0C, 0x005140}, {0x00FA0D, 0x00FA0D, 0x0055C0}, + {0x00FA10, 0x00FA10, 0x00585A}, {0x00FA12, 0x00FA12, 0x006674}, {0x00FA15, 0x00FA15, 0x0051DE}, + {0x00FA16, 0x00FA16, 0x00732A}, {0x00FA17, 0x00FA17, 0x0076CA}, {0x00FA18, 0x00FA18, 0x00793C}, + {0x00FA19, 0x00FA19, 0x00795E}, {0x00FA1A, 0x00FA1A, 0x007965}, {0x00FA1B, 0x00FA1B, 0x00798F}, + {0x00FA1C, 0x00FA1C, 0x009756}, {0x00FA1D, 0x00FA1D, 0x007CBE}, {0x00FA1E, 0x00FA1E, 0x007FBD}, + {0x00FA20, 0x00FA20, 0x008612}, {0x00FA22, 0x00FA22, 0x008AF8}, {0x00FA25, 0x00FA25, 0x009038}, + {0x00FA26, 0x00FA26, 0x0090FD}, {0x00FA2A, 0x00FA2A, 0x0098EF}, {0x00FA2B, 0x00FA2B, 0x0098FC}, + {0x00FA2C, 0x00FA2C, 0x009928}, {0x00FA2D, 0x00FA2D, 0x009DB4}, {0x00FA2E, 0x00FA2E, 0x0090DE}, + {0x00FA2F, 0x00FA2F, 0x0096B7}, {0x00FA30, 0x00FA30, 0x004FAE}, {0x00FA31, 0x00FA31, 0x0050E7}, + {0x00FA32, 0x00FA32, 0x00514D}, {0x00FA33, 0x00FA33, 0x0052C9}, {0x00FA34, 0x00FA34, 0x0052E4}, + {0x00FA35, 0x00FA35, 0x005351}, {0x00FA36, 0x00FA36, 0x00559D}, {0x00FA37, 0x00FA37, 0x005606}, + {0x00FA38, 0x00FA38, 0x005668}, {0x00FA39, 0x00FA39, 0x005840}, {0x00FA3A, 0x00FA3A, 0x0058A8}, + {0x00FA3B, 0x00FA3B, 0x005C64}, {0x00FA3C, 0x00FA3C, 0x005C6E}, {0x00FA3D, 0x00FA3D, 0x006094}, + {0x00FA3E, 0x00FA3E, 0x006168}, {0x00FA3F, 0x00FA3F, 0x00618E}, {0x00FA40, 0x00FA40, 0x0061F2}, + {0x00FA41, 0x00FA41, 0x00654F}, {0x00FA42, 0x00FA42, 0x0065E2}, {0x00FA43, 0x00FA43, 0x006691}, + {0x00FA44, 0x00FA44, 0x006885}, {0x00FA45, 0x00FA45, 0x006D77}, {0x00FA46, 0x00FA46, 0x006E1A}, + {0x00FA47, 0x00FA47, 0x006F22}, {0x00FA48, 0x00FA48, 0x00716E}, {0x00FA49, 0x00FA49, 0x00722B}, + {0x00FA4A, 0x00FA4A, 0x007422}, {0x00FA4B, 0x00FA4B, 0x007891}, {0x00FA4C, 0x00FA4C, 0x00793E}, + {0x00FA4D, 0x00FA4D, 0x007949}, {0x00FA4E, 0x00FA4E, 0x007948}, {0x00FA4F, 0x00FA4F, 0x007950}, + {0x00FA50, 0x00FA50, 0x007956}, {0x00FA51, 0x00FA51, 0x00795D}, {0x00FA52, 0x00FA52, 0x00798D}, + {0x00FA53, 0x00FA53, 0x00798E}, {0x00FA54, 0x00FA54, 0x007A40}, {0x00FA55, 0x00FA55, 0x007A81}, + {0x00FA56, 0x00FA56, 0x007BC0}, {0x00FA57, 0x00FA57, 0x007DF4}, {0x00FA58, 0x00FA58, 0x007E09}, + {0x00FA59, 0x00FA59, 0x007E41}, {0x00FA5A, 0x00FA5A, 0x007F72}, {0x00FA5B, 0x00FA5B, 0x008005}, + {0x00FA5C, 0x00FA5C, 0x0081ED}, {0x00FA5D, 0x00FA5E, 0x008279}, {0x00FA5F, 0x00FA5F, 0x008457}, + {0x00FA60, 0x00FA60, 0x008910}, {0x00FA61, 0x00FA61, 0x008996}, {0x00FA62, 0x00FA62, 0x008B01}, + {0x00FA63, 0x00FA63, 0x008B39}, {0x00FA64, 0x00FA64, 0x008CD3}, {0x00FA65, 0x00FA65, 0x008D08}, + {0x00FA66, 0x00FA66, 0x008FB6}, {0x00FA67, 0x00FA67, 0x009038}, {0x00FA68, 0x00FA68, 0x0096E3}, + {0x00FA69, 0x00FA69, 0x0097FF}, {0x00FA6A, 0x00FA6A, 0x00983B}, {0x00FA6B, 0x00FA6B, 0x006075}, + {0x00FA6C, 0x00FA6C, 0x0242EE}, {0x00FA6D, 0x00FA6D, 0x008218}, {0x00FA70, 0x00FA70, 0x004E26}, + {0x00FA71, 0x00FA71, 0x0051B5}, {0x00FA72, 0x00FA72, 0x005168}, {0x00FA73, 0x00FA73, 0x004F80}, + {0x00FA74, 0x00FA74, 0x005145}, {0x00FA75, 0x00FA75, 0x005180}, {0x00FA76, 0x00FA76, 0x0052C7}, + {0x00FA77, 0x00FA77, 0x0052FA}, {0x00FA78, 0x00FA78, 0x00559D}, {0x00FA79, 0x00FA79, 0x005555}, + {0x00FA7A, 0x00FA7A, 0x005599}, {0x00FA7B, 0x00FA7B, 0x0055E2}, {0x00FA7C, 0x00FA7C, 0x00585A}, + {0x00FA7D, 0x00FA7D, 0x0058B3}, {0x00FA7E, 0x00FA7E, 0x005944}, {0x00FA7F, 0x00FA7F, 0x005954}, + {0x00FA80, 0x00FA80, 0x005A62}, {0x00FA81, 0x00FA81, 0x005B28}, {0x00FA82, 0x00FA82, 0x005ED2}, + {0x00FA83, 0x00FA83, 0x005ED9}, {0x00FA84, 0x00FA84, 0x005F69}, {0x00FA85, 0x00FA85, 0x005FAD}, + {0x00FA86, 0x00FA86, 0x0060D8}, {0x00FA87, 0x00FA87, 0x00614E}, {0x00FA88, 0x00FA88, 0x006108}, + {0x00FA89, 0x00FA89, 0x00618E}, {0x00FA8A, 0x00FA8A, 0x006160}, {0x00FA8B, 0x00FA8B, 0x0061F2}, + {0x00FA8C, 0x00FA8C, 0x006234}, {0x00FA8D, 0x00FA8D, 0x0063C4}, {0x00FA8E, 0x00FA8E, 0x00641C}, + {0x00FA8F, 0x00FA8F, 0x006452}, {0x00FA90, 0x00FA90, 0x006556}, {0x00FA91, 0x00FA91, 0x006674}, + {0x00FA92, 0x00FA92, 0x006717}, {0x00FA93, 0x00FA93, 0x00671B}, {0x00FA94, 0x00FA94, 0x006756}, + {0x00FA95, 0x00FA95, 0x006B79}, {0x00FA96, 0x00FA96, 0x006BBA}, {0x00FA97, 0x00FA97, 0x006D41}, + {0x00FA98, 0x00FA98, 0x006EDB}, {0x00FA99, 0x00FA99, 0x006ECB}, {0x00FA9A, 0x00FA9A, 0x006F22}, + {0x00FA9B, 0x00FA9B, 0x00701E}, {0x00FA9C, 0x00FA9C, 0x00716E}, {0x00FA9D, 0x00FA9D, 0x0077A7}, + {0x00FA9E, 0x00FA9E, 0x007235}, {0x00FA9F, 0x00FA9F, 0x0072AF}, {0x00FAA0, 0x00FAA0, 0x00732A}, + {0x00FAA1, 0x00FAA1, 0x007471}, {0x00FAA2, 0x00FAA2, 0x007506}, {0x00FAA3, 0x00FAA3, 0x00753B}, + {0x00FAA4, 0x00FAA4, 0x00761D}, {0x00FAA5, 0x00FAA5, 0x00761F}, {0x00FAA6, 0x00FAA6, 0x0076CA}, + {0x00FAA7, 0x00FAA7, 0x0076DB}, {0x00FAA8, 0x00FAA8, 0x0076F4}, {0x00FAA9, 0x00FAA9, 0x00774A}, + {0x00FAAA, 0x00FAAA, 0x007740}, {0x00FAAB, 0x00FAAB, 0x0078CC}, {0x00FAAC, 0x00FAAC, 0x007AB1}, + {0x00FAAD, 0x00FAAD, 0x007BC0}, {0x00FAAE, 0x00FAAE, 0x007C7B}, {0x00FAAF, 0x00FAAF, 0x007D5B}, + {0x00FAB0, 0x00FAB0, 0x007DF4}, {0x00FAB1, 0x00FAB1, 0x007F3E}, {0x00FAB2, 0x00FAB2, 0x008005}, + {0x00FAB3, 0x00FAB3, 0x008352}, {0x00FAB4, 0x00FAB4, 0x0083EF}, {0x00FAB5, 0x00FAB5, 0x008779}, + {0x00FAB6, 0x00FAB6, 0x008941}, {0x00FAB7, 0x00FAB7, 0x008986}, {0x00FAB8, 0x00FAB8, 0x008996}, + {0x00FAB9, 0x00FAB9, 0x008ABF}, {0x00FABA, 0x00FABA, 0x008AF8}, {0x00FABB, 0x00FABB, 0x008ACB}, + {0x00FABC, 0x00FABC, 0x008B01}, {0x00FABD, 0x00FABD, 0x008AFE}, {0x00FABE, 0x00FABE, 0x008AED}, + {0x00FABF, 0x00FABF, 0x008B39}, {0x00FAC0, 0x00FAC0, 0x008B8A}, {0x00FAC1, 0x00FAC1, 0x008D08}, + {0x00FAC2, 0x00FAC2, 0x008F38}, {0x00FAC3, 0x00FAC3, 0x009072}, {0x00FAC4, 0x00FAC4, 0x009199}, + {0x00FAC5, 0x00FAC5, 0x009276}, {0x00FAC6, 0x00FAC6, 0x00967C}, {0x00FAC7, 0x00FAC7, 0x0096E3}, + {0x00FAC8, 0x00FAC8, 0x009756}, {0x00FAC9, 0x00FAC9, 0x0097DB}, {0x00FACA, 0x00FACA, 0x0097FF}, + {0x00FACB, 0x00FACB, 0x00980B}, {0x00FACC, 0x00FACC, 0x00983B}, {0x00FACD, 0x00FACD, 0x009B12}, + {0x00FACE, 0x00FACE, 0x009F9C}, {0x00FACF, 0x00FACF, 0x02284A}, {0x00FAD0, 0x00FAD0, 0x022844}, + {0x00FAD1, 0x00FAD1, 0x0233D5}, {0x00FAD2, 0x00FAD2, 0x003B9D}, {0x00FAD3, 0x00FAD3, 0x004018}, + {0x00FAD4, 0x00FAD4, 0x004039}, {0x00FAD5, 0x00FAD5, 0x025249}, {0x00FAD6, 0x00FAD6, 0x025CD0}, + {0x00FAD7, 0x00FAD7, 0x027ED3}, {0x00FAD8, 0x00FAD8, 0x009F43}, {0x00FAD9, 0x00FAD9, 0x009F8E}, + {0x00FB1D, 0x00FB1D, 0x0005D9}, {0x00FB1F, 0x00FB1F, 0x0005F2}, {0x00FB2A, 0x00FB2D, 0x0005E9}, + {0x00FB2E, 0x00FB30, 0x0005D0}, {0x00FB31, 0x00FB31, 0x0005D1}, {0x00FB32, 0x00FB32, 0x0005D2}, + {0x00FB33, 0x00FB33, 0x0005D3}, {0x00FB34, 0x00FB34, 0x0005D4}, {0x00FB35, 0x00FB35, 0x0005D5}, + {0x00FB36, 0x00FB36, 0x0005D6}, {0x00FB38, 0x00FB38, 0x0005D8}, {0x00FB39, 0x00FB39, 0x0005D9}, + {0x00FB3A, 0x00FB3A, 0x0005DA}, {0x00FB3B, 0x00FB3B, 0x0005DB}, {0x00FB3C, 0x00FB3C, 0x0005DC}, + {0x00FB3E, 0x00FB3E, 0x0005DE}, {0x00FB40, 0x00FB40, 0x0005E0}, {0x00FB41, 0x00FB41, 0x0005E1}, + {0x00FB43, 0x00FB43, 0x0005E3}, {0x00FB44, 0x00FB44, 0x0005E4}, {0x00FB46, 0x00FB46, 0x0005E6}, + {0x00FB47, 0x00FB47, 0x0005E7}, {0x00FB48, 0x00FB48, 0x0005E8}, {0x00FB49, 0x00FB49, 0x0005E9}, + {0x00FB4A, 0x00FB4A, 0x0005EA}, {0x00FB4B, 0x00FB4B, 0x0005D5}, {0x00FB4C, 0x00FB4C, 0x0005D1}, + {0x00FB4D, 0x00FB4D, 0x0005DB}, {0x00FB4E, 0x00FB4E, 0x0005E4}, {0x01109A, 0x01109A, 0x011099}, + {0x01109C, 0x01109C, 0x01109B}, {0x0110AB, 0x0110AB, 0x0110A5}, {0x01112E, 0x01112E, 0x011131}, + {0x01112F, 0x01112F, 0x011132}, {0x01134B, 0x01134C, 0x011347}, {0x0114BB, 0x0114BC, 0x0114B9}, + {0x0114BE, 0x0114BE, 0x0114B9}, {0x0115BA, 0x0115BA, 0x0115B8}, {0x0115BB, 0x0115BB, 0x0115B9}, + {0x011938, 0x011938, 0x011935}, {0x01D15E, 0x01D15E, 0x01D157}, {0x01D15F, 0x01D164, 0x01D158}, + {0x01D1BB, 0x01D1BB, 0x01D1B9}, {0x01D1BC, 0x01D1BC, 0x01D1BA}, {0x01D1BD, 0x01D1BD, 0x01D1B9}, + {0x01D1BE, 0x01D1BE, 0x01D1BA}, {0x01D1BF, 0x01D1BF, 0x01D1B9}, {0x01D1C0, 0x01D1C0, 0x01D1BA}, + {0x02F800, 0x02F800, 0x004E3D}, {0x02F801, 0x02F801, 0x004E38}, {0x02F802, 0x02F802, 0x004E41}, + {0x02F803, 0x02F803, 0x020122}, {0x02F804, 0x02F804, 0x004F60}, {0x02F805, 0x02F805, 0x004FAE}, + {0x02F806, 0x02F806, 0x004FBB}, {0x02F807, 0x02F807, 0x005002}, {0x02F808, 0x02F808, 0x00507A}, + {0x02F809, 0x02F809, 0x005099}, {0x02F80A, 0x02F80A, 0x0050E7}, {0x02F80B, 0x02F80B, 0x0050CF}, + {0x02F80C, 0x02F80C, 0x00349E}, {0x02F80D, 0x02F80D, 0x02063A}, {0x02F80E, 0x02F80E, 0x00514D}, + {0x02F80F, 0x02F80F, 0x005154}, {0x02F810, 0x02F810, 0x005164}, {0x02F811, 0x02F811, 0x005177}, + {0x02F812, 0x02F812, 0x02051C}, {0x02F813, 0x02F813, 0x0034B9}, {0x02F814, 0x02F814, 0x005167}, + {0x02F815, 0x02F815, 0x00518D}, {0x02F816, 0x02F816, 0x02054B}, {0x02F817, 0x02F817, 0x005197}, + {0x02F818, 0x02F818, 0x0051A4}, {0x02F819, 0x02F819, 0x004ECC}, {0x02F81A, 0x02F81A, 0x0051AC}, + {0x02F81B, 0x02F81B, 0x0051B5}, {0x02F81C, 0x02F81C, 0x0291DF}, {0x02F81D, 0x02F81D, 0x0051F5}, + {0x02F81E, 0x02F81E, 0x005203}, {0x02F81F, 0x02F81F, 0x0034DF}, {0x02F820, 0x02F820, 0x00523B}, + {0x02F821, 0x02F821, 0x005246}, {0x02F822, 0x02F822, 0x005272}, {0x02F823, 0x02F823, 0x005277}, + {0x02F824, 0x02F824, 0x003515}, {0x02F825, 0x02F825, 0x0052C7}, {0x02F826, 0x02F826, 0x0052C9}, + {0x02F827, 0x02F827, 0x0052E4}, {0x02F828, 0x02F828, 0x0052FA}, {0x02F829, 0x02F829, 0x005305}, + {0x02F82A, 0x02F82A, 0x005306}, {0x02F82B, 0x02F82B, 0x005317}, {0x02F82C, 0x02F82C, 0x005349}, + {0x02F82D, 0x02F82D, 0x005351}, {0x02F82E, 0x02F82E, 0x00535A}, {0x02F82F, 0x02F82F, 0x005373}, + {0x02F830, 0x02F830, 0x00537D}, {0x02F831, 0x02F833, 0x00537F}, {0x02F834, 0x02F834, 0x020A2C}, + {0x02F835, 0x02F835, 0x007070}, {0x02F836, 0x02F836, 0x0053CA}, {0x02F837, 0x02F837, 0x0053DF}, + {0x02F838, 0x02F838, 0x020B63}, {0x02F839, 0x02F839, 0x0053EB}, {0x02F83A, 0x02F83A, 0x0053F1}, + {0x02F83B, 0x02F83B, 0x005406}, {0x02F83C, 0x02F83C, 0x00549E}, {0x02F83D, 0x02F83D, 0x005438}, + {0x02F83E, 0x02F83E, 0x005448}, {0x02F83F, 0x02F83F, 0x005468}, {0x02F840, 0x02F840, 0x0054A2}, + {0x02F841, 0x02F841, 0x0054F6}, {0x02F842, 0x02F842, 0x005510}, {0x02F843, 0x02F843, 0x005553}, + {0x02F844, 0x02F844, 0x005563}, {0x02F845, 0x02F846, 0x005584}, {0x02F847, 0x02F847, 0x005599}, + {0x02F848, 0x02F848, 0x0055AB}, {0x02F849, 0x02F849, 0x0055B3}, {0x02F84A, 0x02F84A, 0x0055C2}, + {0x02F84B, 0x02F84B, 0x005716}, {0x02F84C, 0x02F84C, 0x005606}, {0x02F84D, 0x02F84D, 0x005717}, + {0x02F84E, 0x02F84E, 0x005651}, {0x02F84F, 0x02F84F, 0x005674}, {0x02F850, 0x02F850, 0x005207}, + {0x02F851, 0x02F851, 0x0058EE}, {0x02F852, 0x02F852, 0x0057CE}, {0x02F853, 0x02F853, 0x0057F4}, + {0x02F854, 0x02F854, 0x00580D}, {0x02F855, 0x02F855, 0x00578B}, {0x02F856, 0x02F856, 0x005832}, + {0x02F857, 0x02F857, 0x005831}, {0x02F858, 0x02F858, 0x0058AC}, {0x02F859, 0x02F859, 0x0214E4}, + {0x02F85A, 0x02F85A, 0x0058F2}, {0x02F85B, 0x02F85B, 0x0058F7}, {0x02F85C, 0x02F85C, 0x005906}, + {0x02F85D, 0x02F85D, 0x00591A}, {0x02F85E, 0x02F85E, 0x005922}, {0x02F85F, 0x02F85F, 0x005962}, + {0x02F860, 0x02F860, 0x0216A8}, {0x02F861, 0x02F861, 0x0216EA}, {0x02F862, 0x02F862, 0x0059EC}, + {0x02F863, 0x02F863, 0x005A1B}, {0x02F864, 0x02F864, 0x005A27}, {0x02F865, 0x02F865, 0x0059D8}, + {0x02F866, 0x02F866, 0x005A66}, {0x02F867, 0x02F867, 0x0036EE}, {0x02F868, 0x02F868, 0x0036FC}, + {0x02F869, 0x02F869, 0x005B08}, {0x02F86A, 0x02F86B, 0x005B3E}, {0x02F86C, 0x02F86C, 0x0219C8}, + {0x02F86D, 0x02F86D, 0x005BC3}, {0x02F86E, 0x02F86E, 0x005BD8}, {0x02F86F, 0x02F86F, 0x005BE7}, + {0x02F870, 0x02F870, 0x005BF3}, {0x02F871, 0x02F871, 0x021B18}, {0x02F872, 0x02F872, 0x005BFF}, + {0x02F873, 0x02F873, 0x005C06}, {0x02F874, 0x02F874, 0x005F53}, {0x02F875, 0x02F875, 0x005C22}, + {0x02F876, 0x02F876, 0x003781}, {0x02F877, 0x02F877, 0x005C60}, {0x02F878, 0x02F878, 0x005C6E}, + {0x02F879, 0x02F879, 0x005CC0}, {0x02F87A, 0x02F87A, 0x005C8D}, {0x02F87B, 0x02F87B, 0x021DE4}, + {0x02F87C, 0x02F87C, 0x005D43}, {0x02F87D, 0x02F87D, 0x021DE6}, {0x02F87E, 0x02F87E, 0x005D6E}, + {0x02F87F, 0x02F87F, 0x005D6B}, {0x02F880, 0x02F880, 0x005D7C}, {0x02F881, 0x02F881, 0x005DE1}, + {0x02F882, 0x02F882, 0x005DE2}, {0x02F883, 0x02F883, 0x00382F}, {0x02F884, 0x02F884, 0x005DFD}, + {0x02F885, 0x02F885, 0x005E28}, {0x02F886, 0x02F886, 0x005E3D}, {0x02F887, 0x02F887, 0x005E69}, + {0x02F888, 0x02F888, 0x003862}, {0x02F889, 0x02F889, 0x022183}, {0x02F88A, 0x02F88A, 0x00387C}, + {0x02F88B, 0x02F88B, 0x005EB0}, {0x02F88C, 0x02F88C, 0x005EB3}, {0x02F88D, 0x02F88D, 0x005EB6}, + {0x02F88E, 0x02F88E, 0x005ECA}, {0x02F88F, 0x02F88F, 0x02A392}, {0x02F890, 0x02F890, 0x005EFE}, + {0x02F891, 0x02F892, 0x022331}, {0x02F893, 0x02F893, 0x008201}, {0x02F894, 0x02F895, 0x005F22}, + {0x02F896, 0x02F896, 0x0038C7}, {0x02F897, 0x02F897, 0x0232B8}, {0x02F898, 0x02F898, 0x0261DA}, + {0x02F899, 0x02F899, 0x005F62}, {0x02F89A, 0x02F89A, 0x005F6B}, {0x02F89B, 0x02F89B, 0x0038E3}, + {0x02F89C, 0x02F89C, 0x005F9A}, {0x02F89D, 0x02F89D, 0x005FCD}, {0x02F89E, 0x02F89E, 0x005FD7}, + {0x02F89F, 0x02F89F, 0x005FF9}, {0x02F8A0, 0x02F8A0, 0x006081}, {0x02F8A1, 0x02F8A1, 0x00393A}, + {0x02F8A2, 0x02F8A2, 0x00391C}, {0x02F8A3, 0x02F8A3, 0x006094}, {0x02F8A4, 0x02F8A4, 0x0226D4}, + {0x02F8A5, 0x02F8A5, 0x0060C7}, {0x02F8A6, 0x02F8A6, 0x006148}, {0x02F8A7, 0x02F8A7, 0x00614C}, + {0x02F8A8, 0x02F8A8, 0x00614E}, {0x02F8A9, 0x02F8A9, 0x00614C}, {0x02F8AA, 0x02F8AA, 0x00617A}, + {0x02F8AB, 0x02F8AB, 0x00618E}, {0x02F8AC, 0x02F8AC, 0x0061B2}, {0x02F8AD, 0x02F8AD, 0x0061A4}, + {0x02F8AE, 0x02F8AE, 0x0061AF}, {0x02F8AF, 0x02F8AF, 0x0061DE}, {0x02F8B0, 0x02F8B0, 0x0061F2}, + {0x02F8B1, 0x02F8B1, 0x0061F6}, {0x02F8B2, 0x02F8B2, 0x006210}, {0x02F8B3, 0x02F8B3, 0x00621B}, + {0x02F8B4, 0x02F8B4, 0x00625D}, {0x02F8B5, 0x02F8B5, 0x0062B1}, {0x02F8B6, 0x02F8B6, 0x0062D4}, + {0x02F8B7, 0x02F8B7, 0x006350}, {0x02F8B8, 0x02F8B8, 0x022B0C}, {0x02F8B9, 0x02F8B9, 0x00633D}, + {0x02F8BA, 0x02F8BA, 0x0062FC}, {0x02F8BB, 0x02F8BB, 0x006368}, {0x02F8BC, 0x02F8BC, 0x006383}, + {0x02F8BD, 0x02F8BD, 0x0063E4}, {0x02F8BE, 0x02F8BE, 0x022BF1}, {0x02F8BF, 0x02F8BF, 0x006422}, + {0x02F8C0, 0x02F8C0, 0x0063C5}, {0x02F8C1, 0x02F8C1, 0x0063A9}, {0x02F8C2, 0x02F8C2, 0x003A2E}, + {0x02F8C3, 0x02F8C3, 0x006469}, {0x02F8C4, 0x02F8C4, 0x00647E}, {0x02F8C5, 0x02F8C5, 0x00649D}, + {0x02F8C6, 0x02F8C6, 0x006477}, {0x02F8C7, 0x02F8C7, 0x003A6C}, {0x02F8C8, 0x02F8C8, 0x00654F}, + {0x02F8C9, 0x02F8C9, 0x00656C}, {0x02F8CA, 0x02F8CA, 0x02300A}, {0x02F8CB, 0x02F8CB, 0x0065E3}, + {0x02F8CC, 0x02F8CC, 0x0066F8}, {0x02F8CD, 0x02F8CD, 0x006649}, {0x02F8CE, 0x02F8CE, 0x003B19}, + {0x02F8CF, 0x02F8CF, 0x006691}, {0x02F8D0, 0x02F8D0, 0x003B08}, {0x02F8D1, 0x02F8D1, 0x003AE4}, + {0x02F8D2, 0x02F8D2, 0x005192}, {0x02F8D3, 0x02F8D3, 0x005195}, {0x02F8D4, 0x02F8D4, 0x006700}, + {0x02F8D5, 0x02F8D5, 0x00669C}, {0x02F8D6, 0x02F8D6, 0x0080AD}, {0x02F8D7, 0x02F8D7, 0x0043D9}, + {0x02F8D8, 0x02F8D8, 0x006717}, {0x02F8D9, 0x02F8D9, 0x00671B}, {0x02F8DA, 0x02F8DA, 0x006721}, + {0x02F8DB, 0x02F8DB, 0x00675E}, {0x02F8DC, 0x02F8DC, 0x006753}, {0x02F8DD, 0x02F8DD, 0x0233C3}, + {0x02F8DE, 0x02F8DE, 0x003B49}, {0x02F8DF, 0x02F8DF, 0x0067FA}, {0x02F8E0, 0x02F8E0, 0x006785}, + {0x02F8E1, 0x02F8E1, 0x006852}, {0x02F8E2, 0x02F8E2, 0x006885}, {0x02F8E3, 0x02F8E3, 0x02346D}, + {0x02F8E4, 0x02F8E4, 0x00688E}, {0x02F8E5, 0x02F8E5, 0x00681F}, {0x02F8E6, 0x02F8E6, 0x006914}, + {0x02F8E7, 0x02F8E7, 0x003B9D}, {0x02F8E8, 0x02F8E8, 0x006942}, {0x02F8E9, 0x02F8E9, 0x0069A3}, + {0x02F8EA, 0x02F8EA, 0x0069EA}, {0x02F8EB, 0x02F8EB, 0x006AA8}, {0x02F8EC, 0x02F8EC, 0x0236A3}, + {0x02F8ED, 0x02F8ED, 0x006ADB}, {0x02F8EE, 0x02F8EE, 0x003C18}, {0x02F8EF, 0x02F8EF, 0x006B21}, + {0x02F8F0, 0x02F8F0, 0x0238A7}, {0x02F8F1, 0x02F8F1, 0x006B54}, {0x02F8F2, 0x02F8F2, 0x003C4E}, + {0x02F8F3, 0x02F8F3, 0x006B72}, {0x02F8F4, 0x02F8F4, 0x006B9F}, {0x02F8F5, 0x02F8F5, 0x006BBA}, + {0x02F8F6, 0x02F8F6, 0x006BBB}, {0x02F8F7, 0x02F8F7, 0x023A8D}, {0x02F8F8, 0x02F8F8, 0x021D0B}, + {0x02F8F9, 0x02F8F9, 0x023AFA}, {0x02F8FA, 0x02F8FA, 0x006C4E}, {0x02F8FB, 0x02F8FB, 0x023CBC}, + {0x02F8FC, 0x02F8FC, 0x006CBF}, {0x02F8FD, 0x02F8FD, 0x006CCD}, {0x02F8FE, 0x02F8FE, 0x006C67}, + {0x02F8FF, 0x02F8FF, 0x006D16}, {0x02F900, 0x02F900, 0x006D3E}, {0x02F901, 0x02F901, 0x006D77}, + {0x02F902, 0x02F902, 0x006D41}, {0x02F903, 0x02F903, 0x006D69}, {0x02F904, 0x02F904, 0x006D78}, + {0x02F905, 0x02F905, 0x006D85}, {0x02F906, 0x02F906, 0x023D1E}, {0x02F907, 0x02F907, 0x006D34}, + {0x02F908, 0x02F908, 0x006E2F}, {0x02F909, 0x02F909, 0x006E6E}, {0x02F90A, 0x02F90A, 0x003D33}, + {0x02F90B, 0x02F90B, 0x006ECB}, {0x02F90C, 0x02F90C, 0x006EC7}, {0x02F90D, 0x02F90D, 0x023ED1}, + {0x02F90E, 0x02F90E, 0x006DF9}, {0x02F90F, 0x02F90F, 0x006F6E}, {0x02F910, 0x02F910, 0x023F5E}, + {0x02F911, 0x02F911, 0x023F8E}, {0x02F912, 0x02F912, 0x006FC6}, {0x02F913, 0x02F913, 0x007039}, + {0x02F914, 0x02F914, 0x00701E}, {0x02F915, 0x02F915, 0x00701B}, {0x02F916, 0x02F916, 0x003D96}, + {0x02F917, 0x02F917, 0x00704A}, {0x02F918, 0x02F918, 0x00707D}, {0x02F919, 0x02F919, 0x007077}, + {0x02F91A, 0x02F91A, 0x0070AD}, {0x02F91B, 0x02F91B, 0x020525}, {0x02F91C, 0x02F91C, 0x007145}, + {0x02F91D, 0x02F91D, 0x024263}, {0x02F91E, 0x02F91E, 0x00719C}, {0x02F91F, 0x02F91F, 0x0243AB}, + {0x02F920, 0x02F920, 0x007228}, {0x02F921, 0x02F921, 0x007235}, {0x02F922, 0x02F922, 0x007250}, + {0x02F923, 0x02F923, 0x024608}, {0x02F924, 0x02F924, 0x007280}, {0x02F925, 0x02F925, 0x007295}, + {0x02F926, 0x02F926, 0x024735}, {0x02F927, 0x02F927, 0x024814}, {0x02F928, 0x02F928, 0x00737A}, + {0x02F929, 0x02F929, 0x00738B}, {0x02F92A, 0x02F92A, 0x003EAC}, {0x02F92B, 0x02F92B, 0x0073A5}, + {0x02F92C, 0x02F92D, 0x003EB8}, {0x02F92E, 0x02F92E, 0x007447}, {0x02F92F, 0x02F92F, 0x00745C}, + {0x02F930, 0x02F930, 0x007471}, {0x02F931, 0x02F931, 0x007485}, {0x02F932, 0x02F932, 0x0074CA}, + {0x02F933, 0x02F933, 0x003F1B}, {0x02F934, 0x02F934, 0x007524}, {0x02F935, 0x02F935, 0x024C36}, + {0x02F936, 0x02F936, 0x00753E}, {0x02F937, 0x02F937, 0x024C92}, {0x02F938, 0x02F938, 0x007570}, + {0x02F939, 0x02F939, 0x02219F}, {0x02F93A, 0x02F93A, 0x007610}, {0x02F93B, 0x02F93B, 0x024FA1}, + {0x02F93C, 0x02F93C, 0x024FB8}, {0x02F93D, 0x02F93D, 0x025044}, {0x02F93E, 0x02F93E, 0x003FFC}, + {0x02F93F, 0x02F93F, 0x004008}, {0x02F940, 0x02F940, 0x0076F4}, {0x02F941, 0x02F941, 0x0250F3}, + {0x02F942, 0x02F942, 0x0250F2}, {0x02F943, 0x02F943, 0x025119}, {0x02F944, 0x02F944, 0x025133}, + {0x02F945, 0x02F945, 0x00771E}, {0x02F946, 0x02F947, 0x00771F}, {0x02F948, 0x02F948, 0x00774A}, + {0x02F949, 0x02F949, 0x004039}, {0x02F94A, 0x02F94A, 0x00778B}, {0x02F94B, 0x02F94B, 0x004046}, + {0x02F94C, 0x02F94C, 0x004096}, {0x02F94D, 0x02F94D, 0x02541D}, {0x02F94E, 0x02F94E, 0x00784E}, + {0x02F94F, 0x02F94F, 0x00788C}, {0x02F950, 0x02F950, 0x0078CC}, {0x02F951, 0x02F951, 0x0040E3}, + {0x02F952, 0x02F952, 0x025626}, {0x02F953, 0x02F953, 0x007956}, {0x02F954, 0x02F954, 0x02569A}, + {0x02F955, 0x02F955, 0x0256C5}, {0x02F956, 0x02F956, 0x00798F}, {0x02F957, 0x02F957, 0x0079EB}, + {0x02F958, 0x02F958, 0x00412F}, {0x02F959, 0x02F959, 0x007A40}, {0x02F95A, 0x02F95A, 0x007A4A}, + {0x02F95B, 0x02F95B, 0x007A4F}, {0x02F95C, 0x02F95C, 0x02597C}, {0x02F95D, 0x02F95E, 0x025AA7}, + {0x02F95F, 0x02F95F, 0x007AEE}, {0x02F960, 0x02F960, 0x004202}, {0x02F961, 0x02F961, 0x025BAB}, + {0x02F962, 0x02F962, 0x007BC6}, {0x02F963, 0x02F963, 0x007BC9}, {0x02F964, 0x02F964, 0x004227}, + {0x02F965, 0x02F965, 0x025C80}, {0x02F966, 0x02F966, 0x007CD2}, {0x02F967, 0x02F967, 0x0042A0}, + {0x02F968, 0x02F968, 0x007CE8}, {0x02F969, 0x02F969, 0x007CE3}, {0x02F96A, 0x02F96A, 0x007D00}, + {0x02F96B, 0x02F96B, 0x025F86}, {0x02F96C, 0x02F96C, 0x007D63}, {0x02F96D, 0x02F96D, 0x004301}, + {0x02F96E, 0x02F96E, 0x007DC7}, {0x02F96F, 0x02F96F, 0x007E02}, {0x02F970, 0x02F970, 0x007E45}, + {0x02F971, 0x02F971, 0x004334}, {0x02F972, 0x02F972, 0x026228}, {0x02F973, 0x02F973, 0x026247}, + {0x02F974, 0x02F974, 0x004359}, {0x02F975, 0x02F975, 0x0262D9}, {0x02F976, 0x02F976, 0x007F7A}, + {0x02F977, 0x02F977, 0x02633E}, {0x02F978, 0x02F978, 0x007F95}, {0x02F979, 0x02F979, 0x007FFA}, + {0x02F97A, 0x02F97A, 0x008005}, {0x02F97B, 0x02F97B, 0x0264DA}, {0x02F97C, 0x02F97C, 0x026523}, + {0x02F97D, 0x02F97D, 0x008060}, {0x02F97E, 0x02F97E, 0x0265A8}, {0x02F97F, 0x02F97F, 0x008070}, + {0x02F980, 0x02F980, 0x02335F}, {0x02F981, 0x02F981, 0x0043D5}, {0x02F982, 0x02F982, 0x0080B2}, + {0x02F983, 0x02F983, 0x008103}, {0x02F984, 0x02F984, 0x00440B}, {0x02F985, 0x02F985, 0x00813E}, + {0x02F986, 0x02F986, 0x005AB5}, {0x02F987, 0x02F987, 0x0267A7}, {0x02F988, 0x02F988, 0x0267B5}, + {0x02F989, 0x02F989, 0x023393}, {0x02F98A, 0x02F98A, 0x02339C}, {0x02F98B, 0x02F98B, 0x008201}, + {0x02F98C, 0x02F98C, 0x008204}, {0x02F98D, 0x02F98D, 0x008F9E}, {0x02F98E, 0x02F98E, 0x00446B}, + {0x02F98F, 0x02F98F, 0x008291}, {0x02F990, 0x02F990, 0x00828B}, {0x02F991, 0x02F991, 0x00829D}, + {0x02F992, 0x02F992, 0x0052B3}, {0x02F993, 0x02F993, 0x0082B1}, {0x02F994, 0x02F994, 0x0082B3}, + {0x02F995, 0x02F995, 0x0082BD}, {0x02F996, 0x02F996, 0x0082E6}, {0x02F997, 0x02F997, 0x026B3C}, + {0x02F998, 0x02F998, 0x0082E5}, {0x02F999, 0x02F999, 0x00831D}, {0x02F99A, 0x02F99A, 0x008363}, + {0x02F99B, 0x02F99B, 0x0083AD}, {0x02F99C, 0x02F99C, 0x008323}, {0x02F99D, 0x02F99D, 0x0083BD}, + {0x02F99E, 0x02F99E, 0x0083E7}, {0x02F99F, 0x02F99F, 0x008457}, {0x02F9A0, 0x02F9A0, 0x008353}, + {0x02F9A1, 0x02F9A1, 0x0083CA}, {0x02F9A2, 0x02F9A2, 0x0083CC}, {0x02F9A3, 0x02F9A3, 0x0083DC}, + {0x02F9A4, 0x02F9A4, 0x026C36}, {0x02F9A5, 0x02F9A5, 0x026D6B}, {0x02F9A6, 0x02F9A6, 0x026CD5}, + {0x02F9A7, 0x02F9A7, 0x00452B}, {0x02F9A8, 0x02F9A8, 0x0084F1}, {0x02F9A9, 0x02F9A9, 0x0084F3}, + {0x02F9AA, 0x02F9AA, 0x008516}, {0x02F9AB, 0x02F9AB, 0x0273CA}, {0x02F9AC, 0x02F9AC, 0x008564}, + {0x02F9AD, 0x02F9AD, 0x026F2C}, {0x02F9AE, 0x02F9AE, 0x00455D}, {0x02F9AF, 0x02F9AF, 0x004561}, + {0x02F9B0, 0x02F9B0, 0x026FB1}, {0x02F9B1, 0x02F9B1, 0x0270D2}, {0x02F9B2, 0x02F9B2, 0x00456B}, + {0x02F9B3, 0x02F9B3, 0x008650}, {0x02F9B4, 0x02F9B4, 0x00865C}, {0x02F9B5, 0x02F9B5, 0x008667}, + {0x02F9B6, 0x02F9B6, 0x008669}, {0x02F9B7, 0x02F9B7, 0x0086A9}, {0x02F9B8, 0x02F9B8, 0x008688}, + {0x02F9B9, 0x02F9B9, 0x00870E}, {0x02F9BA, 0x02F9BA, 0x0086E2}, {0x02F9BB, 0x02F9BB, 0x008779}, + {0x02F9BC, 0x02F9BC, 0x008728}, {0x02F9BD, 0x02F9BD, 0x00876B}, {0x02F9BE, 0x02F9BE, 0x008786}, + {0x02F9BF, 0x02F9BF, 0x0045D7}, {0x02F9C0, 0x02F9C0, 0x0087E1}, {0x02F9C1, 0x02F9C1, 0x008801}, + {0x02F9C2, 0x02F9C2, 0x0045F9}, {0x02F9C3, 0x02F9C3, 0x008860}, {0x02F9C4, 0x02F9C4, 0x008863}, + {0x02F9C5, 0x02F9C5, 0x027667}, {0x02F9C6, 0x02F9C6, 0x0088D7}, {0x02F9C7, 0x02F9C7, 0x0088DE}, + {0x02F9C8, 0x02F9C8, 0x004635}, {0x02F9C9, 0x02F9C9, 0x0088FA}, {0x02F9CA, 0x02F9CA, 0x0034BB}, + {0x02F9CB, 0x02F9CB, 0x0278AE}, {0x02F9CC, 0x02F9CC, 0x027966}, {0x02F9CD, 0x02F9CD, 0x0046BE}, + {0x02F9CE, 0x02F9CE, 0x0046C7}, {0x02F9CF, 0x02F9CF, 0x008AA0}, {0x02F9D0, 0x02F9D0, 0x008AED}, + {0x02F9D1, 0x02F9D1, 0x008B8A}, {0x02F9D2, 0x02F9D2, 0x008C55}, {0x02F9D3, 0x02F9D3, 0x027CA8}, + {0x02F9D4, 0x02F9D4, 0x008CAB}, {0x02F9D5, 0x02F9D5, 0x008CC1}, {0x02F9D6, 0x02F9D6, 0x008D1B}, + {0x02F9D7, 0x02F9D7, 0x008D77}, {0x02F9D8, 0x02F9D8, 0x027F2F}, {0x02F9D9, 0x02F9D9, 0x020804}, + {0x02F9DA, 0x02F9DA, 0x008DCB}, {0x02F9DB, 0x02F9DB, 0x008DBC}, {0x02F9DC, 0x02F9DC, 0x008DF0}, + {0x02F9DD, 0x02F9DD, 0x0208DE}, {0x02F9DE, 0x02F9DE, 0x008ED4}, {0x02F9DF, 0x02F9DF, 0x008F38}, + {0x02F9E0, 0x02F9E0, 0x0285D2}, {0x02F9E1, 0x02F9E1, 0x0285ED}, {0x02F9E2, 0x02F9E2, 0x009094}, + {0x02F9E3, 0x02F9E3, 0x0090F1}, {0x02F9E4, 0x02F9E4, 0x009111}, {0x02F9E5, 0x02F9E5, 0x02872E}, + {0x02F9E6, 0x02F9E6, 0x00911B}, {0x02F9E7, 0x02F9E7, 0x009238}, {0x02F9E8, 0x02F9E8, 0x0092D7}, + {0x02F9E9, 0x02F9E9, 0x0092D8}, {0x02F9EA, 0x02F9EA, 0x00927C}, {0x02F9EB, 0x02F9EB, 0x0093F9}, + {0x02F9EC, 0x02F9EC, 0x009415}, {0x02F9ED, 0x02F9ED, 0x028BFA}, {0x02F9EE, 0x02F9EE, 0x00958B}, + {0x02F9EF, 0x02F9EF, 0x004995}, {0x02F9F0, 0x02F9F0, 0x0095B7}, {0x02F9F1, 0x02F9F1, 0x028D77}, + {0x02F9F2, 0x02F9F2, 0x0049E6}, {0x02F9F3, 0x02F9F3, 0x0096C3}, {0x02F9F4, 0x02F9F4, 0x005DB2}, + {0x02F9F5, 0x02F9F5, 0x009723}, {0x02F9F6, 0x02F9F6, 0x029145}, {0x02F9F7, 0x02F9F7, 0x02921A}, + {0x02F9F8, 0x02F9F8, 0x004A6E}, {0x02F9F9, 0x02F9F9, 0x004A76}, {0x02F9FA, 0x02F9FA, 0x0097E0}, + {0x02F9FB, 0x02F9FB, 0x02940A}, {0x02F9FC, 0x02F9FC, 0x004AB2}, {0x02F9FD, 0x02F9FD, 0x029496}, + {0x02F9FE, 0x02F9FF, 0x00980B}, {0x02FA00, 0x02FA00, 0x009829}, {0x02FA01, 0x02FA01, 0x0295B6}, + {0x02FA02, 0x02FA02, 0x0098E2}, {0x02FA03, 0x02FA03, 0x004B33}, {0x02FA04, 0x02FA04, 0x009929}, + {0x02FA05, 0x02FA05, 0x0099A7}, {0x02FA06, 0x02FA06, 0x0099C2}, {0x02FA07, 0x02FA07, 0x0099FE}, + {0x02FA08, 0x02FA08, 0x004BCE}, {0x02FA09, 0x02FA09, 0x029B30}, {0x02FA0A, 0x02FA0A, 0x009B12}, + {0x02FA0B, 0x02FA0B, 0x009C40}, {0x02FA0C, 0x02FA0C, 0x009CFD}, {0x02FA0D, 0x02FA0D, 0x004CCE}, + {0x02FA0E, 0x02FA0E, 0x004CED}, {0x02FA0F, 0x02FA0F, 0x009D67}, {0x02FA10, 0x02FA10, 0x02A0CE}, + {0x02FA11, 0x02FA11, 0x004CF8}, {0x02FA12, 0x02FA12, 0x02A105}, {0x02FA13, 0x02FA13, 0x02A20E}, + {0x02FA14, 0x02FA14, 0x02A291}, {0x02FA15, 0x02FA15, 0x009EBB}, {0x02FA16, 0x02FA16, 0x004D56}, + {0x02FA17, 0x02FA17, 0x009EF9}, {0x02FA18, 0x02FA18, 0x009EFE}, {0x02FA19, 0x02FA19, 0x009F05}, + {0x02FA1A, 0x02FA1A, 0x009F0F}, {0x02FA1B, 0x02FA1B, 0x009F16}, {0x02FA1C, 0x02FA1C, 0x009F3B}, + {0x02FA1D, 0x02FA1D, 0x02A600}, +}; diff --git a/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.h b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.h new file mode 100644 index 000000000..90191cedd --- /dev/null +++ b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode-data.h @@ -0,0 +1,54 @@ +/* +llama.cpp - commit 54ef9cfc +https://github.com/ggerganov/llama.cpp + +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#pragma once + +#include +#include +#include +#include +#include + +struct range_nfd { + uint32_t first; + uint32_t last; + uint32_t nfd; +}; + +static const uint32_t MAX_CODEPOINTS = 0x110000; + +extern const std::initializer_list> unicode_ranges_flags; + +constexpr std::array unicode_set_whitespace = { + 0x000009, 0x00000A, 0x00000B, 0x00000C, 0x00000D, 0x000020, 0x000085, 0x0000A0, 0x001680, + 0x002000, 0x002001, 0x002002, 0x002003, 0x002004, 0x002005, 0x002006, 0x002007, 0x002008, + 0x002009, 0x00200A, 0x002028, 0x002029, 0x00202F, 0x00205F, 0x003000, +}; + +extern const std::initializer_list> unicode_map_lowercase; +extern const std::initializer_list> unicode_map_uppercase; +extern const std::initializer_list unicode_ranges_nfd; diff --git a/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.cpp b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.cpp new file mode 100644 index 000000000..da8a79e54 --- /dev/null +++ b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.cpp @@ -0,0 +1,1068 @@ +/* +llama.cpp - commit 54ef9cfc +https://github.com/ggerganov/llama.cpp + +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + +#include "unicode.h" +#include "unicode-data.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Hash function for std::pair used in composition table +namespace std { +template<> +struct hash> { + std::size_t operator()(const std::pair& p) const { + return std::hash{}(((uint64_t)p.first << 32) | p.second); + } +}; +} // namespace std + +size_t unicode_len_utf8(char src) { + const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} + +// Unused function +// static std::string unicode_cpts_to_utf8(const std::vector &cps) { +// std::string result; +// for (size_t i = 0; i < cps.size(); ++i) { +// result.append(unicode_cpt_to_utf8(cps[i])); +// } +// return result; +// } + +uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) { + assert(offset < utf8.size()); + if (!(utf8[offset + 0] & 0x80)) { + auto result = utf8[offset + 0]; + offset += 1; + return result; + } + if (!(utf8[offset + 0] & 0x40)) { throw std::invalid_argument("invalid character"); } + if (!(utf8[offset + 0] & 0x20)) { + if (offset + 1 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80)) { throw std::invalid_argument("invalid character"); } + auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); + offset += 2; + return result; + } + if (!(utf8[offset + 0] & 0x10)) { + if (offset + 2 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80)) { + throw std::invalid_argument("invalid character"); + } + auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); + offset += 3; + return result; + } + if (!(utf8[offset + 0] & 0x08)) { + if (offset + 3 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80) + || !((utf8[offset + 3] & 0xc0) == 0x80)) { + throw std::invalid_argument("invalid character"); + } + auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) + | (utf8[offset + 3] & 0x3f); + offset += 4; + return result; + } + throw std::invalid_argument("failed to convert utf8 to codepoint"); +} + +// static std::vector unicode_cpt_to_utf16(uint32_t cp) { +// std::vector result; +// if (/* 0x0000 <= cp && */ cp <= 0xffff) { +// result.emplace_back(cp); +// return result; +// } +// if (0x10000 <= cp && cp <= 0x10ffff) { +// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); +// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); +// return result; +// } +// throw std::invalid_argument("failed to convert codepoint to utf16"); +// } + +// static std::vector unicode_cpts_to_utf16(const +// std::vector & cps) { +// std::vector result; +// for (size_t i = 0; i < cps.size(); ++i) { +// auto temp = unicode_cpt_to_utf16(cps[i]); +// result.insert(result.end(), temp.begin(), temp.end()); +// } +// return result; +// } + +// static uint32_t unicode_cpt_from_utf16(const std::vector & utf16, +// size_t & offset) { +// assert(offset < utf16.size()); +// if (((utf16[0] >> 10) << 10) != 0xd800) { +// auto result = utf16[offset + 0]; +// offset += 1; +// return result; +// } +// +// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { +// throw std::invalid_argument("invalid character"); +// } +// +// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & +// 0x03ff)); offset += 2; return result; +// } + +// static std::vector unicode_cpts_from_utf16(const +// std::vector & utf16) { +// std::vector result; +// size_t offset = 0; +// while (offset < utf16.size()) { +// result.push_back(unicode_cpt_from_utf16(utf16, offset)); +// } +// return result; +// } + +static std::vector unicode_cpt_flags_array() { + std::vector cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); + + assert(unicode_ranges_flags.begin()[0].first == 0); + assert(unicode_ranges_flags.begin()[unicode_ranges_flags.size() - 1].first == MAX_CODEPOINTS); + for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) { + const auto range_ini = unicode_ranges_flags.begin()[i - 1]; // codepoint_ini, flags + const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags + for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) { cpt_flags[cpt] = range_ini.second; } + } + + for (auto cpt : unicode_set_whitespace) { cpt_flags[cpt].is_whitespace = true; } + + for (auto p : unicode_map_lowercase) { cpt_flags[p.second].is_lowercase = true; } + + for (auto p : unicode_map_uppercase) { cpt_flags[p.second].is_uppercase = true; } + + for (auto& range : unicode_ranges_nfd) { // start, last, nfd + cpt_flags[range.nfd].is_nfd = true; + } + + return cpt_flags; +} + +static std::unordered_map unicode_byte_to_utf8_map() { + std::unordered_map map; + for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~' + assert(0 <= ch && ch < 256); + map[ch] = unicode_cpt_to_utf8(ch); + } + for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬' + assert(0 <= ch && ch < 256); + map[ch] = unicode_cpt_to_utf8(ch); + } + for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ' + assert(0 <= ch && ch < 256); + map[ch] = unicode_cpt_to_utf8(ch); + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(ch) == map.end()) { + map[ch] = unicode_cpt_to_utf8(256 + n); + ++n; + } + } + return map; +} + +static std::unordered_map unicode_utf8_to_byte_map() { + std::unordered_map map; + for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~' + assert(0 <= ch && ch < 256); + map[unicode_cpt_to_utf8(ch)] = ch; + } + for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬' + assert(0 <= ch && ch < 256); + map[unicode_cpt_to_utf8(ch)] = ch; + } + for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ' + assert(0 <= ch && ch < 256); + map[unicode_cpt_to_utf8(ch)] = ch; + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) { + map[unicode_cpt_to_utf8(256 + n)] = ch; + ++n; + } + } + return map; +} + +static inline std::wstring unicode_wstring_from_utf8(const std::string& s) { + std::wstring_convert> conv; + return conv.from_bytes(s); +} + +static std::vector unicode_byte_encoding_process(const std::vector& bpe_words) { + std::vector bpe_encoded_words; + for (const auto& word : bpe_words) { + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { text_utf += unicode_cpt_to_utf8(utf_word[i]); } + + std::string encoded_token; + for (char& c : text_utf) { encoded_token += unicode_byte_to_utf8(c); } + bpe_encoded_words.emplace_back(encoded_token); + } + return bpe_encoded_words; +} + +// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| +// ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_gpt2(const std::string& text, const std::vector& offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&](const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&](const size_t pos) -> codepoint_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&](const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { bpe_offsets.push_back(len); } + _prev_end = end; + // if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + // } + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: 's|'t|'re|'ve|'m|'ll|'d + if (cpt == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = _get_cpt(pos + 1); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos + 2); + continue; + } + if (pos + 2 < offset_end) { + uint32_t cpt_next_next = _get_cpt(pos + 2); + if ((cpt_next == 'r' && cpt_next_next == 'e') || (cpt_next == 'v' && cpt_next_next == 'e') + || (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos + 3); + continue; + } + } + } + + auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); + // regex: ?\p{L}+ + if (flags2.is_letter) { + pos += (cpt == ' '); + while (flags2.is_letter) { flags2 = _get_flags(++pos); } + _add_token(pos); + continue; + } + // regex: ?\p{N}+ + if (flags2.is_number) { + pos += (cpt == ' '); + while (flags2.is_number) { flags2 = _get_flags(++pos); } + _add_token(pos); + continue; + } + // regex: ?[^\s\p{L}\p{N}]+ + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { num_whitespaces++; } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + +// LLAMA3 system regex: +// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| +// ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" +static std::vector unicode_regex_split_custom_llama3(const std::string& text, const std::vector& offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&](const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&](const size_t pos) -> codepoint_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&](const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { bpe_offsets.push_back(len); } + _prev_end = end; + // if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + // } + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos + 2); + continue; + } + if (pos + 2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || (cpt_next == 'v' && cpt_next_next == 'e') + || (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos + 3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || _get_flags(pos + 1).is_letter) { // one or more letters + pos++; + while (_get_flags(pos).is_letter) { pos++; } + _add_token(pos); + continue; + } + } + + // regex: \p{N}{1,3} + if (flags.is_number) { + size_t ini = pos; + while (_get_flags(pos).is_number) { + if (++pos - ini >= 3) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { cpt2 = _get_cpt(++pos); } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos + num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { last_end_r_or_n = pos + num_whitespaces + 1; } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + +// use std::wregex to split the text +static std::vector unicode_regex_split_stl(const std::wstring& wtext, const std::wstring& regex_expr, + const std::vector& offsets) { + std::wregex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); + std::wcregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::wcmatch match = *it; + if (match.position() > start_idx) { bpe_offsets.emplace_back(match.position() - start_idx); } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t)offset) { bpe_offsets.emplace_back(offset - start_idx); } + start += offset; + } + + return bpe_offsets; +} + +// use std::regex to split the text +static std::vector unicode_regex_split_stl(const std::string& text, const std::string& regex_expr, + const std::vector& offsets) { + std::regex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); + std::cregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::cmatch match = *it; + if (match.position() > start_idx) { bpe_offsets.emplace_back(match.position() - start_idx); } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t)offset) { bpe_offsets.emplace_back(offset - start_idx); } + start += offset; + } + + return bpe_offsets; +} + +static std::vector unicode_regex_split_custom(const std::string& text, const std::string& regex_expr, + const std::vector& offsets) { + std::vector bpe_offsets; + + if (regex_expr + == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| " + "?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { + bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); + } else if (regex_expr + == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1," + "3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + || regex_expr + == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^" + "\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| " + "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } + + return bpe_offsets; +} + +// +// interface +// + +std::string unicode_cpt_to_utf8(uint32_t cp) { + std::string result; + + if (/* 0x00 <= cp && */ cp <= 0x7f) { + result.push_back(cp); + return result; + } + if (0x80 <= cp && cp <= 0x7ff) { + result.push_back(0xc0 | ((cp >> 6) & 0x1f)); + result.push_back(0x80 | (cp & 0x3f)); + return result; + } + if (0x800 <= cp && cp <= 0xffff) { + result.push_back(0xe0 | ((cp >> 12) & 0x0f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + return result; + } + if (0x10000 <= cp && cp <= 0x10ffff) { + result.push_back(0xf0 | ((cp >> 18) & 0x07)); + result.push_back(0x80 | ((cp >> 12) & 0x3f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + return result; + } + + throw std::invalid_argument("invalid codepoint"); +} + +std::vector unicode_cpts_normalize_nfd(const std::vector& cpts) { + auto comp = [](const uint32_t cpt, const range_nfd& range) { return cpt < range.first; }; + std::vector result(cpts.size()); + for (size_t i = 0; i < cpts.size(); ++i) { + const uint32_t cpt = cpts[i]; + auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1; + result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt; + } + return result; +} + +std::vector unicode_cpts_from_utf8(const std::string& utf8) { + std::vector result; + result.reserve(utf8.size()); + size_t offset = 0; + while (offset < utf8.size()) { result.push_back(unicode_cpt_from_utf8(utf8, offset)); } + return result; +} + +codepoint_flags unicode_cpt_flags(const uint32_t cp) { + static const codepoint_flags undef(codepoint_flags::UNDEFINED); + static const auto cpt_flags = unicode_cpt_flags_array(); + return cp < cpt_flags.size() ? cpt_flags[cp] : undef; +} + +codepoint_flags unicode_cpt_flags(const std::string& utf8) { + static const codepoint_flags undef(codepoint_flags::UNDEFINED); + if (utf8.empty()) { + return undef; // undefined + } + size_t offset = 0; + return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset)); +} + +std::string unicode_byte_to_utf8(uint8_t byte) { + static std::unordered_map map = unicode_byte_to_utf8_map(); + return map.at(byte); +} + +uint8_t unicode_utf8_to_byte(const std::string& utf8) { + static std::unordered_map map = unicode_utf8_to_byte_map(); + return map.at(utf8); +} + +uint32_t unicode_tolower(uint32_t cp) { + // binary search + auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp, + [](const std::pair& pair, uint32_t value) { return pair.first < value; }); + if (it != unicode_map_lowercase.end() && it->first == cp) { return it->second; } + return cp; // Return the original code point if no lowercase mapping is found +} + +std::vector unicode_regex_split(const std::string& text, const std::vector& regex_exprs) { + // unicode categories + static const std::map k_ucat_enum = { + {"\\p{N}", codepoint_flags::NUMBER}, + {"\\p{L}", codepoint_flags::LETTER}, + {"\\p{P}", codepoint_flags::PUNCTUATION}, + }; + + static const std::map k_ucat_cpt = { + {codepoint_flags::NUMBER, 0xD1}, + {codepoint_flags::LETTER, 0xD2}, + {codepoint_flags::PUNCTUATION, 0xD3}, + }; + + static const std::map k_ucat_map = { + {codepoint_flags::NUMBER, "\x30-\x39"}, // 0-9 + {codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A"}, // A-Za-z + {codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-" + "\\\x5D\x5F\\\x7B\\\x7D"}, // !-#%-*,-/:-;?-@\[-\]_\{\} + }; + + // compute collapsed codepoints only if needed by at least one regex + bool need_collapse = false; + for (auto& regex_expr : regex_exprs) { + // search for unicode categories + for (const auto& ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + need_collapse = true; + break; + } + } + } + + const auto cpts = unicode_cpts_from_utf8(text); + + // generate a "collapsed" representation of the text, where all codepoints are + // replaced by a single byte ref: + // https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + std::string text_collapsed; + if (need_collapse) { + // collapse all unicode categories + text_collapsed.resize(cpts.size()); + + for (size_t i = 0; i < cpts.size(); ++i) { + // keep single-byte codepoints as is + if (cpts[i] < 128) { + text_collapsed[i] = cpts[i]; + continue; + } + + const auto flags = unicode_cpt_flags(cpts[i]); + + if (flags.is_whitespace) { + // NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex + // does. text_collapsed[i] = (char) 0x85; // as whitespace + // fallback + text_collapsed[i] = (char)0x0B; // as whitespace fallback + } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) { + text_collapsed[i] = k_ucat_cpt.at(flags.category_flag()); + } else { + text_collapsed[i] = (char)0xD0; // fallback + } + } + } + + std::vector bpe_offsets = {cpts.size()}; + + for (auto& regex_expr : regex_exprs) { + // first, see if we have an efficient custom regex implementation + auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); + + if (!tmp.empty()) { + bpe_offsets = std::move(tmp); + continue; + } + + // fallback to general-purpose std::regex / std::wregex + try { + // if a unicode category is used in the regex, we use the collapsed text + // and replace the unicode category with the corresponding collapsed + // representation + bool use_collapsed = false; + for (auto& ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + use_collapsed = true; + break; + } + } + + if (use_collapsed) { + // sanity-check that the original regex does not contain any non-ASCII + // characters + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + for (size_t i = 0; i < cpts_regex.size(); ++i) { + if (cpts_regex[i] >= 128) { + throw std::runtime_error("Regex includes both unicode categories and non-ASCII " + "characters - not supported"); + } + } + + // generate a collapsed representation of the regex + std::string regex_expr_collapsed; + + // track if we are inside [], because nested [] are not allowed + bool inside = false; + for (size_t i = 0; i < regex_expr.size(); ++i) { + if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { + regex_expr_collapsed += '['; + inside = true; + continue; + } + + if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { + regex_expr_collapsed += ']'; + inside = false; + continue; + } + + if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && regex_expr[i + 1] == 'p' && regex_expr[i + 2] == '{' + && regex_expr[i + 4] == '}') { + const std::string pat = regex_expr.substr(i, 5); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { regex_expr_collapsed += '['; } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { regex_expr_collapsed += ']'; } + i += 4; + continue; + } + } + + regex_expr_collapsed += regex_expr[i]; + } + + // printf("text_collapsed: %s\n", text_collapsed.c_str()); + // printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); + bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); + } else { + // no unicode category used, we can use std::wregex directly + const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + + // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as + // fallback + std::wstring wtext(cpts.begin(), cpts.end()); + for (size_t i = 0; i < wtext.size(); ++i) { + if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) { wtext[i] = 0x0B; } + } + + // printf("text: %s\n", text.c_str()); + // printf("regex_expr: %s\n", regex_expr.c_str()); + bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); + } + } catch (std::regex_error& e) { + fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); + fprintf(stderr, "Regex error: %s\n", e.what()); + throw std::runtime_error("Failed to process regex"); + } + } + + std::vector bpe_words; + bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size + + size_t start = 0; + for (size_t& offset : bpe_offsets) { + bpe_words.emplace_back(); + for (size_t i = start; i < start + offset; ++i) { bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); } + start += offset; + } + + return unicode_byte_encoding_process(bpe_words); +} + +// Get canonical combining class for a codepoint using existing flags data +static uint8_t get_combining_class(uint32_t cpt) { + codepoint_flags flags = unicode_cpt_flags(cpt); + + // Use the existing flag system to determine combining class + if (flags.is_accent_mark) { + // Most combining marks have class 230, but some have different classes + // This is a simplified mapping based on common Unicode patterns + if (cpt >= 0x0591 && cpt <= 0x05BD) return 220; // Hebrew accents + if (cpt >= 0x05BF && cpt <= 0x05C7) return 230; // Hebrew points + if (cpt >= 0x0610 && cpt <= 0x061A) return 230; // Arabic marks + if (cpt >= 0x064B && cpt <= 0x065F) return 30; // Arabic vowels + if (cpt >= 0x0670 && cpt <= 0x0670) return 35; // Arabic superscript alef + if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks + if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks + if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks + + // Default combining class for most combining marks + return 230; + } + + return 0; // Non-combining character (starter) +} + +// Apply canonical ordering using bubble sort (simple but correct) +static void canonical_order(std::vector& cpts) { + for (size_t i = 1; i < cpts.size(); ++i) { + for (size_t j = i; j > 0; --j) { + uint8_t cc1 = get_combining_class(cpts[j - 1]); + uint8_t cc2 = get_combining_class(cpts[j]); + + // Only reorder if both have non-zero combining class and are out of order + if (cc1 > cc2 && cc2 != 0) { + std::swap(cpts[j - 1], cpts[j]); + } else { + break; + } + } + } +} + +// Build composition table by reverse-engineering the NFD data +static std::unordered_map, uint32_t> build_composition_table() { + std::unordered_map, uint32_t> composition_map; + + // Iterate through all NFD mappings to build reverse composition table + for (const auto& range : unicode_ranges_nfd) { + for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) { + uint32_t base = range.nfd; + + // For NFC, we need to figure out what combining character was removed + // This is a simplified approach that works for the most common cases + + // Common diacritic mappings based on the composed character + uint32_t combining = 0; + + // Determine combining character based on the composed character + // This is derived from common Unicode patterns + switch (cpt) { + // Grave accent (0x0300) + case 0x00C0: + case 0x00E0: // À à + case 0x00C8: + case 0x00E8: // È è + case 0x00CC: + case 0x00EC: // Ì ì + case 0x00D2: + case 0x00F2: // Ò ò + case 0x00D9: + case 0x00F9: // Ù ù + case 0x01CD: + case 0x01CE: // Ǎ ǎ + case 0x01CF: + case 0x01D0: // Ǐ ǐ + case 0x01D1: + case 0x01D2: // Ǒ ǒ + case 0x01D3: + case 0x01D4: // Ǔ ǔ + combining = 0x0300; + break; + + // Acute accent (0x0301) + case 0x00C1: + case 0x00E1: // Á á + case 0x00C9: + case 0x00E9: // É é + case 0x00CD: + case 0x00ED: // Í í + case 0x00D3: + case 0x00F3: // Ó ó + case 0x00DA: + case 0x00FA: // Ú ú + case 0x00DD: + case 0x00FD: // Ý ý + combining = 0x0301; + break; + + // Circumflex (0x0302) + case 0x00C2: + case 0x00E2: // Â â + case 0x00CA: + case 0x00EA: // Ê ê + case 0x00CE: + case 0x00EE: // Î î + case 0x00D4: + case 0x00F4: // Ô ô + case 0x00DB: + case 0x00FB: // Û û + combining = 0x0302; + break; + + // Tilde (0x0303) + case 0x00C3: + case 0x00E3: // Ã ã + case 0x00D1: + case 0x00F1: // Ñ ñ + case 0x00D5: + case 0x00F5: // Õ õ + combining = 0x0303; + break; + + // Diaeresis (0x0308) + case 0x00C4: + case 0x00E4: // Ä ä + case 0x00CB: + case 0x00EB: // Ë ë + case 0x00CF: + case 0x00EF: // Ï ï + case 0x00D6: + case 0x00F6: // Ö ö + case 0x00DC: + case 0x00FC: // Ü ü + case 0x00FF: // ÿ + combining = 0x0308; + break; + + // Ring above (0x030A) + case 0x00C5: + case 0x00E5: // Å å + combining = 0x030A; + break; + + // Cedilla (0x0327) + case 0x00C7: + case 0x00E7: // Ç ç + combining = 0x0327; + break; + + default: + // For other characters, try to infer from Unicode blocks + if (cpt >= 0x0100 && cpt <= 0x017F) { + // Extended Latin A - try common patterns + if ((cpt & 1) == 0) { // Even codepoints (uppercase) + if (cpt >= 0x0100 && cpt <= 0x0105) + combining = 0x0304; // macron + else if (cpt >= 0x0102 && cpt <= 0x0107) + combining = 0x0306; // breve + else if (cpt >= 0x0104 && cpt <= 0x0119) + combining = 0x0328; // ogonek + else if (cpt >= 0x0106 && cpt <= 0x010D) + combining = 0x0301; // acute + else if (cpt >= 0x0108 && cpt <= 0x010F) + combining = 0x0302; // circumflex + else if (cpt >= 0x010A && cpt <= 0x0111) + combining = 0x0307; // dot above + else if (cpt >= 0x010C && cpt <= 0x0165) + combining = 0x030C; // caron + } + } + break; + } + + // Only add to composition table if we identified a combining character + if (combining != 0) { composition_map[{base, combining}] = cpt; } + } + } + + return composition_map; +} + +// Get the composition table (built once, cached) +static const std::unordered_map, uint32_t>& get_composition_table() { + static const auto composition_table = build_composition_table(); + return composition_table; +} + +std::vector unicode_cpts_normalize_nfc(const std::vector& cpts) { + // Step 1: Apply NFD (canonical decomposition) using existing implementation + std::vector nfd_result = unicode_cpts_normalize_nfd(cpts); + + // Step 2: Apply canonical ordering + canonical_order(nfd_result); + + // Step 3: Apply canonical composition + const auto& composition_table = get_composition_table(); + std::vector result; + result.reserve(nfd_result.size()); + + size_t i = 0; + while (i < nfd_result.size()) { + uint32_t starter = nfd_result[i]; + result.push_back(starter); + + // Only try to compose if this is a starter (combining class 0) + if (get_combining_class(starter) == 0) { + size_t last_starter_pos = result.size() - 1; + + // Look for composable combining marks after this starter + size_t j = i + 1; + while (j < nfd_result.size()) { + uint32_t combining = nfd_result[j]; + uint8_t cc = get_combining_class(combining); + + // If we hit another starter, stop + if (cc == 0) break; + + // Try to compose with the last starter + auto key = std::make_pair(result[last_starter_pos], combining); + auto it = composition_table.find(key); + + if (it != composition_table.end()) { + // Compose: replace starter with composed character + result[last_starter_pos] = it->second; + // Skip this combining character + ++j; + continue; + } + + // No composition possible, add the combining character + result.push_back(combining); + ++j; + } + i = j; + } else { + ++i; + } + } + + return result; +} diff --git a/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.h b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.h new file mode 100644 index 000000000..da4e2ea9f --- /dev/null +++ b/mllm/preprocessor/tokenizers/llama_cpp_unicode/unicode.h @@ -0,0 +1,90 @@ +/* +llama.cpp - commit 54ef9cfc +https://github.com/ggerganov/llama.cpp + +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#pragma once + +#include +#include +#include + +// TODO: prefix all symbols with "llama_" + +struct codepoint_flags { + enum { + UNDEFINED = 0x0001, + NUMBER = 0x0002, // regex: \p{N} + LETTER = 0x0004, // regex: \p{L} + SEPARATOR = 0x0008, // regex: \p{Z} + ACCENT_MARK = 0x0010, // regex: \p{M} + PUNCTUATION = 0x0020, // regex: \p{P} + SYMBOL = 0x0040, // regex: \p{S} + CONTROL = 0x0080, // regex: \p{C} + MASK_CATEGORIES = 0x00FF, + }; + + // codepoint type + uint16_t is_undefined : 1; + uint16_t is_number : 1; // regex: \p{N} + uint16_t is_letter : 1; // regex: \p{L} + uint16_t is_separator : 1; // regex: \p{Z} + uint16_t is_accent_mark : 1; // regex: \p{M} + uint16_t is_punctuation : 1; // regex: \p{P} + uint16_t is_symbol : 1; // regex: \p{S} + uint16_t is_control : 1; // regex: \p{C} + // helper flags + uint16_t is_whitespace : 1; // regex: \s + uint16_t is_lowercase : 1; + uint16_t is_uppercase : 1; + uint16_t is_nfd : 1; + + // decode from uint16 + inline codepoint_flags(const uint16_t flags = 0) { *reinterpret_cast(this) = flags; } + + inline uint16_t as_uint() const { return *reinterpret_cast(this); } + + inline uint16_t category_flag() const { return this->as_uint() & MASK_CATEGORIES; } +}; + +size_t unicode_len_utf8(char src); + +std::string unicode_cpt_to_utf8(uint32_t cp); +uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset); +std::vector unicode_cpts_from_utf8(const std::string& utf8); + +std::vector unicode_cpts_normalize_nfd(const std::vector& cpts); + +std::vector unicode_cpts_normalize_nfc(const std::vector& cpts); + +codepoint_flags unicode_cpt_flags(const uint32_t cp); +codepoint_flags unicode_cpt_flags(const std::string& utf8); + +std::string unicode_byte_to_utf8(uint8_t byte); +uint8_t unicode_utf8_to_byte(const std::string& utf8); + +uint32_t unicode_tolower(uint32_t cp); + +std::vector unicode_regex_split(const std::string& text, const std::vector& regex_exprs); From a556ad26d5cd42a0431f8a90c3416dc061a11ee1 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 25 Oct 2025 22:40:27 +0800 Subject: [PATCH 14/25] feat(deepseek_ocr): improve tokenizer and add support for new special tokens - Add new special tokens including ``, `<|grounding|>`, ``, and `` to `DpskOcrTokenizer` - Implement `TrieUTF8` to properly handle UTF-8 encoded special tokens in tokenization - Update `tokenize()` and `decode()` methods to correctly process and detokenize strings with special tokens - Remove outdated debug print statements from example code - Modify `fmt::format_to` to avoid extra quotes around vector elements - Refactor tokenization logic to integrate regex and byte-level processing within special token handling This change enhances the DeepSeek OCR tokenizer's ability to work with complex input sequences involving special tokens and improves overall string handling. --- examples/deepseek_ocr/main.cpp | 13 +- mllm/mllm.inl | 2 +- .../tokenization_deepseek_ocr.hpp | 65 +++++++--- .../preprocessor/tokenizers/AutoTokenizer.cpp | 114 +++++++++++++++++- .../preprocessor/tokenizers/AutoTokenizer.hpp | 46 ++++++- 5 files changed, 212 insertions(+), 28 deletions(-) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index 9fb8a6771..6ded77b73 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -8,15 +8,10 @@ MLLM_MAIN({ auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(); auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/hf-models/DeepSeek-OCR/tokenizer.json"); - mllm::print(tokenizer.tokenize(" ")); - mllm::print(tokenizer.tokenize("▁")); - mllm::print(tokenizer.tokenize("\n")); - mllm::print(tokenizer.tokenize("你好啊!")); - - mllm::print(tokenizer.encode(" ")); - mllm::print(tokenizer.encode("▁")); - mllm::print(tokenizer.encode("\n")); - mllm::print(tokenizer.encode("你好啊!")); + mllm::print(tokenizer.tokenize("\n<|grounding|>Convert the document to markdown. ")); + mllm::print(tokenizer.encode("\n<|grounding|>Convert the document to markdown. ")); + mllm::print(tokenizer.decode({128815, 201, 128820, 21842, 270, 4940, 304, 2121, 7919, 16, 223})); + exit(0); model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); diff --git a/mllm/mllm.inl b/mllm/mllm.inl index cc9afbaec..c697687c2 100644 --- a/mllm/mllm.inl +++ b/mllm/mllm.inl @@ -142,7 +142,7 @@ struct formatter> { *out++ = ','; *out++ = ' '; } - out = fmt::format_to(out, "\"{}\"", vec[i]); + out = fmt::format_to(out, "{}", vec[i]); } *out++ = ']'; return out; diff --git a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp index 3358c079a..5bdfc1dd7 100644 --- a/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp @@ -38,11 +38,15 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { bpe_.initFromSentencePieceJson(file_path); // Add special tokens to trie - special_tokens_trie_.add(L"<|User|>"); - special_tokens_trie_.add(L"<|Assistant|>"); - special_tokens_trie_.add(L"<|begin▁of▁sentence|>"); - special_tokens_trie_.add(L"<|end▁of▁sentence|>"); - special_tokens_trie_.add(L"<|▁pad▁|>"); + special_tokens_trie_.add("<|User|>"); + special_tokens_trie_.add("<|Assistant|>"); + special_tokens_trie_.add("<|begin▁of▁sentence|>"); + special_tokens_trie_.add("<|end▁of▁sentence|>"); + special_tokens_trie_.add("<|▁pad▁|>"); + special_tokens_trie_.add(""); + special_tokens_trie_.add("<|grounding|>"); + special_tokens_trie_.add(""); + special_tokens_trie_.add(""); } std::vector encode(const std::string& str) override { @@ -53,28 +57,57 @@ class DpskOcrTokenizer final : public mllm::preprocessor::AutoTokenizerUTF8 { } std::string decode(const std::vector& ids) override { - // TODO - return {}; + std::vector after_bpe_check; + for (auto& each_id : ids) { + auto each_str = bpe_._lookup_inverse_vocab(each_id); + after_bpe_check.emplace_back(each_str); + } + return detokenize(after_bpe_check); } std::vector tokenize(const std::string& str) override { - auto after_regex_process = regexPreTokenizer(str); + // Replace all blank token to underscore + std::vector ret; - for (auto& ss : after_regex_process) { - auto after_bytes_process = byteLevelPreTokenizer(ss); + for (auto& each_str : special_tokens_trie_.split(str)) { + if (special_tokens_trie_.isSpecialToken(each_str)) { + ret.emplace_back(each_str); + continue; + } + + // FIXME Should Regex: + auto after_regex_process = {each_str}; + + for (auto& ss : after_regex_process) { + auto after_bytes_process = byteLevelPreTokenizer(ss); - // Perform BPE algorithm on each sub-token - for (auto& bbpe_str : after_bytes_process) { - auto bbpe_str_sub_tokens = bpe_._bpe(bbpe_str); - ret.insert(ret.end(), bbpe_str_sub_tokens.begin(), bbpe_str_sub_tokens.end()); + // Perform BPE algorithm on each sub-token + for (auto& bbpe_str : after_bytes_process) { + auto bbpe_str_sub_tokens = bpe_._bpe(bbpe_str); + ret.insert(ret.end(), bbpe_str_sub_tokens.begin(), bbpe_str_sub_tokens.end()); + } } } return ret; } std::string detokenize(const std::vector& tokenized_str) override { - // TODO - return {}; + std::string ret; + for (auto& each_str : tokenized_str) { + if (special_tokens_trie_.isSpecialToken(each_str)) { + ret += each_str; + continue; + } + // Loop utf8 string + utf8::iterator it(each_str.begin(), each_str.begin(), each_str.end()); + utf8::iterator end_it(each_str.end(), each_str.begin(), each_str.end()); + for (; it != end_it; ++it) { + char32_t cp = *it; + auto b = unicode_utf8_to_byte(unicode_cpt_to_utf8(cp)); + ret.push_back(b); + } + } + return ret; } private: diff --git a/mllm/preprocessor/tokenizers/AutoTokenizer.cpp b/mllm/preprocessor/tokenizers/AutoTokenizer.cpp index 1d958f3f2..ae9f4f9ef 100644 --- a/mllm/preprocessor/tokenizers/AutoTokenizer.cpp +++ b/mllm/preprocessor/tokenizers/AutoTokenizer.cpp @@ -110,6 +110,118 @@ std::vector Trie::split(const std::wstring& text) { bool Trie::isSpecialToken(const std::wstring& token) { return special_tokens_.count(token); } +void TrieUTF8::add(const std::string& word_utf8) { + auto word = utf8String2Cpts(word_utf8); + if (word.empty()) return; + special_tokens_.insert(word); + + TrieNode* current = root_.get(); + + for (const auto& c : word) { + if (!current->children.count(c)) { current->children[c] = std::make_unique(); } + current = current->children[c].get(); + } + + current->is_end = true; +} + +void TrieUTF8::update(const std::vector& words) { + for (const auto& word : words) { add(word); } +} + +// I use FSA to implement the split function. +std::vector TrieUTF8::split(const std::string& text_utf8) { + auto text = utf8String2Cpts(text_utf8); + + std::map states; + std::vector offsets = {0}; + size_t skip = 0; + + for (size_t current = 0; current < text.size(); ++current) { + if (skip > current) continue; + + std::unordered_set to_remove; + bool reset = false; + + wchar_t current_char = text[current]; + + for (auto& [_start, node] : states) { + auto start = _start; + if (node->is_end) { + // trying to find the longest match + size_t max_end = current; + + for (auto& [look_start, look_node] : states) { + if (look_start > start) break; + + size_t lookahead = (look_start < start) ? current + 1 : current; + size_t end = lookahead; + TrieNode* ptr = look_node; + + while (lookahead < text.size()) { + wchar_t ch = text[lookahead]; + + if (!ptr->children.count(ch)) break; + + ptr = ptr->children[ch].get(); + lookahead++; + + if (ptr->is_end) { + start = look_start; + end = lookahead; + skip = lookahead; + } + } + + if (ptr->is_end && end > max_end) { max_end = end; } + } + offsets.push_back(start); + offsets.push_back(max_end); + reset = true; + break; + } + if (node->children.count(current_char)) { + states[start] = node->children[current_char].get(); + } else { + to_remove.insert(start); + } + } + if (reset) { + states.clear(); + } else { + for (auto start : to_remove) { states.erase(start); } + } + if (current >= skip && root_->children.count(current_char)) { states[current] = root_->children[current_char].get(); } + } + for (auto& [start, node] : states) { + if (node->is_end) { + offsets.push_back(start); + offsets.push_back(text.size()); + break; + } + } + + sort(offsets.begin(), offsets.end()); + std::vector result; + + for (size_t i = 1; i < offsets.size(); ++i) { + if (offsets[i - 1] != offsets[i]) { + auto cpts_str = cpts_string_t{}; + for (int __idx = offsets[i - 1]; __idx < offsets[i]; __idx++) { cpts_str.push_back(text[__idx]); } + result.push_back(cpts2Utf8String(cpts_str)); + } + } + if (offsets[offsets.size() - 1] != text.size()) { + auto cpts_str = cpts_string_t{}; + for (int __idx = offsets[offsets.size() - 1]; __idx < text.size(); __idx++) { cpts_str.push_back(text[__idx]); } + result.push_back(cpts2Utf8String(cpts_str)); + } + + return result; +} + +bool TrieUTF8::isSpecialToken(const std::string& token) { return special_tokens_.count(utf8String2Cpts(token)); } + void AutoTokenizer::addSpecialToken(const std::wstring& special_token) { special_tokens_trie_.add(special_token); } -} // namespace mllm::preprocessor \ No newline at end of file +} // namespace mllm::preprocessor diff --git a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp index cc542f3ff..10721e231 100644 --- a/mllm/preprocessor/tokenizers/AutoTokenizer.hpp +++ b/mllm/preprocessor/tokenizers/AutoTokenizer.hpp @@ -16,6 +16,8 @@ #include using json = nlohmann::json; +#include + #include #include "mllm/core/Tensor.hpp" @@ -54,6 +56,48 @@ class Trie { std::unordered_set special_tokens_; }; +class TrieUTF8 { + using cpts_string_t = std::vector; + + struct TrieNode { + std::unordered_map> children; + bool is_end = false; + }; + + struct VectorUint32Hash { + std::size_t operator()(const std::vector& v) const noexcept { + if (v.empty()) return 0; + return static_cast(XXH64(v.data(), v.size() * sizeof(uint32_t), /*seed=*/0)); + } + }; + + public: + void add(const std::string& word); + + void update(const std::vector& words); + + // I use FSA to implement the split function. + std::vector split(const std::string& text); + + bool isSpecialToken(const std::string& token); + + inline std::vector utf8String2Cpts(const std::string& str) { + std::vector word32; + utf8::utf8to32(str.begin(), str.end(), std::back_inserter(word32)); + return word32; + } + + inline std::string cpts2Utf8String(const std::vector& cpts) { + std::string str; + utf8::utf32to8(cpts.begin(), cpts.end(), std::back_inserter(str)); + return str; + } + + private: + std::unique_ptr root_ = std::make_unique(); + std::unordered_set special_tokens_; +}; + class AutoTokenizer { public: void addSpecialToken(const std::wstring& special_token); @@ -85,7 +129,7 @@ class AutoTokenizerUTF8 { virtual std::string detokenize(const std::vector& tokenized_str) = 0; protected: - Trie special_tokens_trie_; + TrieUTF8 special_tokens_trie_; }; } // namespace mllm::preprocessor From 7b2f501a06b32a84db9e9224348a19ee3d3bbc9b Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sun, 26 Oct 2025 16:55:13 +0800 Subject: [PATCH 15/25] feat(deepseek_ocr): refactor model loading and initialization with config - Introduce `DpskOcrConfig` for model configuration management - Update `DeepseekOCRForCausalLM` constructor to accept config object - Load model weights using `mllm::load` with version specification - Simplify example main function by removing debug prints and exits Also includes: - Add `Tensor::item()` method for scalar access - Improve formatting for vector and tuple types in logging - Set default values for linear implementation types in config - Register module parameters with full names for better tracing - Fix neck Conv2D output channels from 12 to 256 - Add cast2fp32 quantization pipeline support - Handle None quant_cfg in QuantizeSolver methods - Enable streaming quantization when no config path is provided --- examples/deepseek_ocr/main.cpp | 12 +- mllm/core/Tensor.hpp | 6 + mllm/mllm.inl | 13 +- .../configuration_deepseek_ocr.hpp | 7 +- mllm/models/deepseek_ocr/deepencoder.hpp | 12 +- .../deepseek_ocr/modeling_deepseek_ocr.hpp | 223 +++++++++++++++++- mllm/preprocessor/visual/ImageTransform.cpp | 7 +- pymllm/quantize/pipeline.py | 10 +- pymllm/quantize/solver.py | 6 + pymllm/utils/mllm_convertor.py | 23 +- 10 files changed, 286 insertions(+), 33 deletions(-) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index 6ded77b73..2bbcc657c 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -5,13 +5,13 @@ using mllm::Argparse; MLLM_MAIN({ - auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(); - auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/hf-models/DeepSeek-OCR/tokenizer.json"); + auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/config.json"); + auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(config); + auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/tokenizer.json"); + model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/model.mllm", mllm::ModelFileVersion::kV2)); - mllm::print(tokenizer.tokenize("\n<|grounding|>Convert the document to markdown. ")); - mllm::print(tokenizer.encode("\n<|grounding|>Convert the document to markdown. ")); - mllm::print(tokenizer.decode({128815, 201, 128820, 21842, 270, 4940, 304, 2121, 7919, 16, 223})); - exit(0); + mllm::print(model); + return 0; model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); diff --git a/mllm/core/Tensor.hpp b/mllm/core/Tensor.hpp index 4ef1d8a15..0621916a1 100644 --- a/mllm/core/Tensor.hpp +++ b/mllm/core/Tensor.hpp @@ -634,6 +634,12 @@ class Tensor { return *(offsettedPtr(offsets)); } + template + T item() const { + MLLM_RT_ASSERT_EQ(numel(), 1); + return *(ptr()); + }; + /** * @brief Accesses a tensor element at specified coordinates (const version). * @tparam T Expected data type (must match tensor dtype). diff --git a/mllm/mllm.inl b/mllm/mllm.inl index c697687c2..5ef1ddac8 100644 --- a/mllm/mllm.inl +++ b/mllm/mllm.inl @@ -123,7 +123,7 @@ struct formatter> { *out++ = ','; *out++ = ' '; } - out = fmt::format_to(out, "\"{}\"", vec[i]); + out = fmt::format_to(out, "{:?}", vec[i]); } *out++ = ']'; return out; @@ -149,6 +149,17 @@ struct formatter> { } }; +template<> +struct formatter> { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + template + auto format(const std::tuple& tuple, FormatContext& ctx) const { + auto out = ctx.out(); + out = fmt::format_to(out, "tuple[{}, {}]", std::get<0>(tuple), std::get<1>(tuple)); + return out; + } +}; + template<> struct formatter { constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp index 8b7edae1e..e19dfb8b4 100644 --- a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -180,9 +180,10 @@ struct DpskOcrConfig : protected ConfigFile { int32_t vocab_size = 129280; // MLLM Related Stuff - aops::LinearImplTypes clip_linear_impl_type; - aops::LinearImplTypes sam_linear_impl_type; - aops::LinearImplTypes mlp_projector_linear_impl_type; + aops::LinearImplTypes clip_linear_impl_type = aops::LinearImplTypes::kDefault; + aops::LinearImplTypes sam_linear_impl_type = aops::LinearImplTypes::kDefault; + aops::LinearImplTypes mlp_projector_linear_impl_type = aops::LinearImplTypes::kDefault; + aops::LinearImplTypes lm_head_linear_impl_type = aops::LinearImplTypes::kDefault; }; } // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 154bb8d47..d32d74684 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -87,7 +87,7 @@ class CLIPVisionEmbeddings final : public nn::Module { num_positions_ = num_patches_ + 1; // [embed_dim], aka [1024] - class_embedding_ = reg("class_embedding"); + class_embedding_ = reg("class_embedding", getModuleName() + ".class_embedding"); patch_embedding_ = reg("patch_embedding", 3, embed_dim_, Tensor::shape_t{14, 14}, Tensor::shape_t{14, 14}, Tensor::shape_t{0, 0}, Tensor::shape_t{1, 1}, false); position_embedding_ = reg("position_embedding", num_positions_, embed_dim_); @@ -260,7 +260,7 @@ class NoTPTransformer final : public nn::Module { public: NoTPTransformer() = default; - NoTPTransformer(const std::string& name, const DpskOcrConfig& config) { + NoTPTransformer(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { num_layers_ = 24; layers_ = reg>("layers", num_layers_, config); for (auto [idx, layer] : enumerate(layers_.list())) { layer.layer_id_ = idx; } @@ -363,8 +363,8 @@ class Attention final : public nn::Module { qkv_ = reg("qkv", dim, dim * 3, qkv_bias, config.sam_linear_impl_type); proj_ = reg("proj", dim, dim, true, config.sam_linear_impl_type); if (use_rel_pos) { - rel_pos_h_ = reg("rel_pos_h"); - rel_pos_w_ = reg("rel_pos_w"); + rel_pos_h_ = reg("rel_pos_h", getModuleName() + ".rel_pos_h"); + rel_pos_w_ = reg("rel_pos_w", getModuleName() + ".rel_pos_w"); } } @@ -640,7 +640,7 @@ class ImageEncoderViT final : public nn::Module { ImageEncoderViT(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { patch_embed_ = reg("patch_embed", config); - pos_embed_ = reg("pos_embed"); + pos_embed_ = reg("pos_embed", getModuleName() + ".pos_embed"); // block_nums = 12 // embed_dim = 768 @@ -652,7 +652,7 @@ class ImageEncoderViT final : public nn::Module { blocks_ = reg("blocks", 12, std::vector{2, 5, 8, 11}, config); neck_ = reg("neck") - .add(768, 12, Tensor::shape_t{1, 1}, Tensor::shape_t{1, 1}, Tensor::shape_t{0, 0}, + .add(768, 256, Tensor::shape_t{1, 1}, Tensor::shape_t{1, 1}, Tensor::shape_t{0, 0}, Tensor::shape_t{1, 1}, false) .add(256) .add(256, 256, Tensor::shape_t{3, 3}, Tensor::shape_t{1, 1}, Tensor::shape_t{1, 1}, diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 6464263be..dc0b71923 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -7,20 +7,229 @@ #include #include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" #include "mllm/utils/StringHelper.hpp" #include "mllm/models/ARGeneration.hpp" #include "mllm/preprocessor/visual/ImageTransform.hpp" + +#include "mllm/models/deepseek_ocr/deepencoder.hpp" #include "mllm/models/deepseek_ocr/conversation_preprocess.hpp" #include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" #include "mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp" namespace mllm::models::deepseek_ocr { +class DeepSeekV2Model : public nn::Module { + protected: + nn::Embedding embed_tokens_; + + public: + DeepSeekV2Model() = default; + + explicit DeepSeekV2Model(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + embed_tokens_ = reg("embed_tokens", config.vocab_size, config.hidden_size); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // TODO + return {}; + } +}; + +class DeepseekOCRModel final : public DeepSeekV2Model { + VitModel vision_model_; + ImageEncoderViT sam_model_; + MlpProjector projector_; + nn::Param image_newline_; + nn::Param view_separator_; + int n_embed = 1280; + + public: + DeepseekOCRModel() = default; + + explicit DeepseekOCRModel(const std::string& name, const DpskOcrConfig& config) : DeepSeekV2Model(name, config) { + sam_model_ = reg("sam_model", config); + vision_model_ = reg("vision_model", config); + projector_ = reg("projector", config); + image_newline_ = reg("image_newline", getModuleName() + ".image_newline"); + view_separator_ = reg("view_seperator", getModuleName() + ".view_seperator"); ///< DeepSeek's typo. + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // FIXME: Just support one image right now. + // Inputs: should be [input_ids, optional[image_crop], optional[image_ori], optional[images_spatial_crop]] + auto& input_ids = inputs[0]; + auto patches = inputs.size() > 1 ? inputs[1] : Tensor::nil(); + auto image_ori = inputs.size() > 2 ? inputs[2] : Tensor::nil(); + auto images_spatial_crop = inputs.size() > 3 ? inputs[3] : Tensor::nil(); + auto images_seq_mask = inputs.size() > 4 ? inputs[4] : Tensor::nil(); + + // Embedding + auto inputs_embeds = embed_tokens_(input_ids); + + // We need to process image + auto images_in_this_batch = Tensor::nil(); + if (patches && image_ori && images_spatial_crop && images_seq_mask) { + if (nn::functional::sum(patches).item() != 0) { + // Local features + auto local_features_1 = sam_model_(patches)[0]; + auto local_features_2 = vision_model_(patches, local_features_1)[0]; + auto local_features = nn::functional::concat( + { + local_features_2[{kAll, {1, kAll}}], + local_features_1.flatten(2).permute({0, 2, 1}), + }, + -1); + local_features = projector_(local_features)[0]; + + // Global features + auto global_features_1 = sam_model_(image_ori)[0]; + auto global_features_2 = vision_model_(image_ori, global_features_1)[0]; + auto global_features = nn::functional::concat( + { + global_features_2[{kAll, {1, kAll}}], + global_features_1.flatten(2).permute({0, 2, 1}), + }, + -1); + global_features = projector_(global_features)[0]; + + print("====================="); + print("BASE: ", global_features.shape()); + print("PATCHES: ", local_features.shape()); + print("====================="); + + auto hw = global_features.size(1); + auto n_dim = global_features.size(2); + auto h = (int)std::sqrt(hw); + auto w = h; + + auto hw2 = local_features.size(1); + auto n_dim2 = local_features.size(2); + auto h2 = (int)std::sqrt(hw2); + auto w2 = h2; + + MLLM_RT_ASSERT_EQ(images_spatial_crop.dtype(), kInt64); + int width_crop_num = images_spatial_crop.at({0, 0}); + int height_crop_num = images_spatial_crop.at({0, 0}); + + global_features = global_features.view({h, w, n_dim}); + global_features = nn::functional::concat( + { + global_features, + + // FIXME: This line is in-efficient. + // pytorch logic: self.image_newline[None, None, :].expand(h, 1, n_dim) + // + // Use pytorch like expand instead. Expand will only modified stride, no memory copy involved. + // But many kernels in mllm's arm backend not use stride as loop step, but calculate itself, so we need to + // refact it. + image_newline_.weight().view({1, 1, -1}).repeat(h, 0), + }, + 1); + + global_features = global_features.view({-1, n_dim}); + + local_features = local_features.view({height_crop_num, width_crop_num, h2, w2, n_dim2}) + .permute({0, 2, 1, 3, 4}) + .view({height_crop_num * h2, width_crop_num * w2, n_dim2}); + local_features = nn::functional::concat( + { + local_features, + + // FIXME: This line is in-efficient. + // pytorch logic: self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2) + // + // Use pytorch like expand instead. Expand will only modified stride, no memory copy involved. + // But many kernels in mllm's arm backend not use stride as loop step, but calculate itself, so we need to + // refact it. + image_newline_.weight().view({1, 1, -1}).repeat(height_crop_num * h2, 0), + }, + 1); + + local_features = local_features.view({-1, n_dim2}); + auto global_local_features = nn::functional::concat( + { + local_features, + global_features, + + // pytorch logic: self.view_seperator[None, :] + view_separator_.weight().view({1, -1}), + }, + 0); + images_in_this_batch = global_local_features; + } else { + auto global_features_1 = sam_model_(image_ori)[0]; + auto global_features_2 = vision_model_(image_ori, global_features_1)[0]; + auto global_features = nn::functional::concat( + { + global_features_2[{kAll, {1, kAll}}], + global_features_1.flatten(2).permute({0, 2, 1}), + }, + -1); + + global_features = projector_(global_features)[0]; + + print("====================="); + print("BASE: ", global_features.shape()); + print("NO PATCHES"); + print("====================="); + + auto hw = global_features.size(1); + auto n_dim = global_features.size(2); + auto h = (int)std::sqrt(hw); + auto w = h; + + global_features = global_features.view({h, w, n_dim}); + global_features = nn::functional::concat( + { + global_features, + + // FIXME: This line is in-efficient. + // pytorch logic: self.image_newline[None, None, :].expand(h, 1, n_dim) + // + // Use pytorch like expand instead. Expand will only modified stride, no memory copy involved. + // But many kernels in mllm's arm backend not use stride as loop step, but calculate itself, so we need to + // refact it. + image_newline_.weight().view({1, 1, -1}).repeat(h, 0), + }, + 1); + + global_features = global_features.view({-1, n_dim}); + + auto global_local_features = nn::functional::concat( + { + global_features, + view_separator_.weight().view({1, -1}), + }, + 0); + + images_in_this_batch = global_local_features; + } + } + + if (images_in_this_batch) { + // TODO + // inputs_embeds[idx].masked_scatter_(...) + } + + // Normal forward with text and embedded image + // TODO + + return {}; + } +}; + class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { + DeepseekOCRModel model_; + nn::Linear lm_head_; + public: DeepseekOCRForCausalLM() = default; - explicit DeepseekOCRForCausalLM(const DpskOcrConfig& config) {} + explicit DeepseekOCRForCausalLM(const DpskOcrConfig& config) { + model_ = reg("model", config); + lm_head_ = reg("lm_head", config.hidden_size, config.vocab_size, false, config.lm_head_linear_impl_type); + } ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } @@ -63,6 +272,10 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { // Load image auto images = loadImages(conversations); + auto w = images[0].w(); + auto h = images[0].h(); + ratio = 1 - (float)((std::max(w, h) - std::min(w, h)) / (float)(std::max(w, h))); + // Image transform infra auto image_transform = BasicImageTransform(std::nullopt, std::nullopt, /*mean=*/std::vector{0.5, 0.5, 0.5}, /*std=*/std::vector{0.5, 0.5, 0.5}); @@ -197,14 +410,6 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { } } - MLLM_INFO("BRAVO! U R HERE"); - print(input_ids.shape()); - print(input_ids); - print(images_seq_mask_tensor); - print(images_ori_tensor); - print(images_spatial_crop_tensor); - print(images_crop_tensor); - // Run model. Use generate // TODO diff --git a/mllm/preprocessor/visual/ImageTransform.cpp b/mllm/preprocessor/visual/ImageTransform.cpp index 6462cac38..d71a0072a 100644 --- a/mllm/preprocessor/visual/ImageTransform.cpp +++ b/mllm/preprocessor/visual/ImageTransform.cpp @@ -96,9 +96,8 @@ Tensor Normalize::apply(const Tensor& input) const { MLLM_RT_ASSERT_EQ(static_cast(mean_.size()), c); MLLM_RT_ASSERT_EQ(static_cast(std_.size()), c); - // Work on a contiguous clone to simplify indexing - Tensor out = Tensor::empty(src.shape(), src.dtype(), src.device()).alloc(); - float* ptr = out.ptr(); + // Asuming Work on a contiguous clone to simplify indexing + float* ptr = input.ptr(); const size_t plane = static_cast(h) * static_cast(w); for (int ch = 0; ch < c; ++ch) { @@ -109,7 +108,7 @@ Tensor Normalize::apply(const Tensor& input) const { for (size_t i = 0; i < plane; ++i) { base[i] = (base[i] - m) / s; } } - return out; + return input; } // ========================= BasicImageTransform ========================= diff --git a/pymllm/quantize/pipeline.py b/pymllm/quantize/pipeline.py index 3443491aa..71da013c6 100644 --- a/pymllm/quantize/pipeline.py +++ b/pymllm/quantize/pipeline.py @@ -15,7 +15,15 @@ def build_w4a32_kai_pipeline() -> QuantizeSolver: return ret -BUILTIN_QUANTIZE_PIPELINE: Dict = {"w4a32_kai_pipeline": build_w4a32_kai_pipeline} +def build_cast2fp32_pipeline() -> QuantizeSolver: + ret = QuantizeSolver() + return ret + + +BUILTIN_QUANTIZE_PIPELINE: Dict = { + "w4a32_kai_pipeline": build_w4a32_kai_pipeline, + "cast2fp32_pipeline": build_cast2fp32_pipeline, +} BUILTIN_QUANTIZE_PASS: Dict = { "w4a32_kai": W4A32KAIQuantizePass, "cast2fp32": Cast2Fp32QuantizePass, diff --git a/pymllm/quantize/solver.py b/pymllm/quantize/solver.py index dd5fec398..c9669e48d 100644 --- a/pymllm/quantize/solver.py +++ b/pymllm/quantize/solver.py @@ -25,6 +25,9 @@ def _stream_quantize_write_v2(self, tensor_dict: Dict, writer: ModelFileV2) -> b def stream_quantize_params_size( self, quant_cfg, tensor_dict: Dict, **kwargs ) -> int: + if quant_cfg is None: + quant_cfg = {} + param_groups: Dict[str, List[Any, Dict]] = {} for k, v in quant_cfg.items(): sub_group: Dict[str, QuantizePlanPayload] = {} @@ -79,6 +82,9 @@ def stream_quantize( "stream_quantize only support type: ModelFileV2 currently." ) + if quant_cfg is None: + quant_cfg = {} + # Planning param_groups: Dict[str, List[Any, Dict]] = {} for k, v in quant_cfg.items(): diff --git a/pymllm/utils/mllm_convertor.py b/pymllm/utils/mllm_convertor.py index 8aa9b4af9..7b1aabfb0 100644 --- a/pymllm/utils/mllm_convertor.py +++ b/pymllm/utils/mllm_convertor.py @@ -48,9 +48,24 @@ def main(): params = convertor.load_model(args.input_path) # Build pipeline - if args.cfg_path is None: - # TODO just convert to mllm file - pass + if args.cfg_path is None and args.pipeline is not None and args.format == "v2": + cfg = None + pipeline: QuantizeSolver = BUILTIN_QUANTIZE_PIPELINE[args.pipeline]() + old_param_size = len(params) + new_param_size = pipeline.stream_quantize_params_size(cfg, params) + print(f"Params Num: Before: {old_param_size}, After: {new_param_size}") + pipeline.stream_quantize( + cfg, + params, + writer=ModelFileV2( + args.output_path, + args.model_name, + "Streaming", + max_params_descriptor_buffer_num=new_param_size, + ), + cast_left_2_fp32=True, + verbose=args.verbose, + ) elif ( args.cfg_path is not None and args.pipeline is not None and args.format == "v2" ): @@ -76,3 +91,5 @@ def main(): cast_left_2_fp32=True, verbose=args.verbose, ) + else: + print("No pipeline specified") From 6732314d3da2eed8a35eef75c3ab91ee22dcf1bd Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sun, 26 Oct 2025 17:22:11 +0800 Subject: [PATCH 16/25] feat(cpu): implement MaskedScatterOp for CPU backend - Add CPUMaskedScatterOp and its factory to CPUBackend - Implement forward logic with broadcasting support for masked scatter - Register the new op in the CPU op factory list - Update RTTI and IR definitions to include MaskedScatterOp - Add necessary headers and update model usage in DeepseekOCR - Expose maskedScatter function in nn::functional API --- mllm/backends/cpu/CPUBackend.cpp | 25 +- mllm/backends/cpu/ops/MaskedScatterOp.cpp | 219 ++++++++++++++++++ mllm/backends/cpu/ops/MaskedScatterOp.hpp | 25 ++ mllm/compile/ir/GeneratedRTTIKind.hpp | 3 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 5 +- mllm/compile/ir/linalg/Op.cpp | 1 + mllm/compile/ir/linalg/Op.hpp | 2 + mllm/compile/ir/rtti_kind_gen.py | 1 + mllm/core/OpTypes.hpp | 4 + mllm/core/aops/MaskedScatterOp.cpp | 36 +++ mllm/core/aops/MaskedScatterOp.hpp | 33 +++ .../deepseek_ocr/modeling_deepseek_ocr.hpp | 6 +- mllm/nn/Functional.cpp | 5 + mllm/nn/Functional.hpp | 2 + 14 files changed, 349 insertions(+), 18 deletions(-) create mode 100644 mllm/backends/cpu/ops/MaskedScatterOp.cpp create mode 100644 mllm/backends/cpu/ops/MaskedScatterOp.hpp create mode 100644 mllm/core/aops/MaskedScatterOp.cpp create mode 100644 mllm/core/aops/MaskedScatterOp.hpp diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index cc63857cc..396a0f5b9 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -20,6 +20,7 @@ #include "mllm/backends/cpu/ops/GELUOp.hpp" #include "mllm/backends/cpu/ops/InterpolateOp.hpp" #include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp" +#include "mllm/backends/cpu/ops/MaskedScatterOp.hpp" #include "mllm/backends/cpu/ops/PadOp.hpp" #include "mllm/backends/cpu/ops/RadixAttnOp.hpp" #include "mllm/backends/cpu/ops/ReLUOp.hpp" @@ -55,18 +56,18 @@ namespace mllm::cpu { CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { - regOpFactory(); + regOpFactory< + CPULinearOpFactory, CPUFillOpFactory, CPUGraphBeginOpFactory, CPUGraphEndOpFactory, CPUAddOpFactory, CPUSubOpFactory, + CPUMulOpFactory, CPUDivOpFactory, CPUNegOpFactory, CPUAbsOpFactory, CPULogOpFactory, CPUExpOpFactory, CPUSinOpFactory, + CPUCosOpFactory, CPUReduceMaxOpFactory, CPUReduceMinOpFactory, CPUReduceSumOpFactory, CPUTransposeOpFactory, + CPUPermuteOpFactory, CPUCastTypeOpFactory, CPUConcatOpFactory, CPUStackOpFactory, CPUContiguousOpFactory, + CPUCopyOpFactory, CPUEmbeddingOpFactory, CPUSplitOpFactory, CPUViewOpFactory, CPULayerNormOpFactory, CPURepeatOpFactory, + CPUX2XOpFactory, CPUSoftmaxOpFactory, CPUSiLUOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, + CPUReLUOpFactory, CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, + CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, + CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, + CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, + CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory>(); } std::shared_ptr createCPUBackend() { return std::make_shared(); } diff --git a/mllm/backends/cpu/ops/MaskedScatterOp.cpp b/mllm/backends/cpu/ops/MaskedScatterOp.cpp new file mode 100644 index 000000000..e41c568ee --- /dev/null +++ b/mllm/backends/cpu/ops/MaskedScatterOp.cpp @@ -0,0 +1,219 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/backends/cpu/ops/MaskedScatterOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/core/OpTypes.hpp" +#include +#include +#include + +namespace mllm::cpu { + +CPUMaskedScatterOp::CPUMaskedScatterOp(const aops::MaskedScatterOpOptions& options) : aops::MaskedScatterOp(options) {} + +// Helper function to calculate broadcast shape +std::vector calculateBroadcastShape(const std::vector& shape_a, const std::vector& shape_b) { + // Determine the maximum number of dimensions + size_t max_ndim = std::max(shape_a.size(), shape_b.size()); + std::vector broadcast_shape(max_ndim); + + // Pad shapes to the same number of dimensions + std::vector padded_a(max_ndim, 1); + std::vector padded_b(max_ndim, 1); + + // Copy original shapes to the end (right-aligned) + std::copy(shape_a.begin(), shape_a.end(), padded_a.begin() + (max_ndim - shape_a.size())); + std::copy(shape_b.begin(), shape_b.end(), padded_b.begin() + (max_ndim - shape_b.size())); + + // Calculate broadcast shape + for (size_t i = 0; i < max_ndim; ++i) { + if (padded_a[i] == padded_b[i]) { + broadcast_shape[i] = padded_a[i]; + } else if (padded_a[i] == 1) { + broadcast_shape[i] = padded_b[i]; + } else if (padded_b[i] == 1) { + broadcast_shape[i] = padded_a[i]; + } else { + // Cannot broadcast, should not happen in valid cases + broadcast_shape[i] = std::max(padded_a[i], padded_b[i]); + } + } + + return broadcast_shape; +} + +// Helper function to calculate strides for broadcasting +std::vector calculateStrides(const std::vector& shape) { + std::vector strides(shape.size(), 1); + for (int i = shape.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } + return strides; +} + +// Helper function to convert multi-dimensional index to linear index +int32_t getLinearIndex(const std::vector& indices, const std::vector& strides) { + int32_t linear_index = 0; + for (size_t i = 0; i < indices.size(); ++i) { linear_index += indices[i] * strides[i]; } + return linear_index; +} + +// Helper function to convert linear index to multi-dimensional indices +std::vector getMultiDimIndices(int32_t linear_index, const std::vector& shape) { + std::vector indices(shape.size()); + for (int i = shape.size() - 1; i >= 0; --i) { + indices[i] = linear_index % shape[i]; + linear_index /= shape[i]; + } + return indices; +} + +void CPUMaskedScatterOp::forward(const std::vector& inputs, std::vector& outputs) { + // dst, mask, src + auto& dst = inputs[0]; + auto& mask = inputs[1]; + auto& src = inputs[2]; + + MLLM_RT_ASSERT_EQ(mask.dtype(), kInt8); + + // dst and output should be the same tensor (in-place operation) + // But we still need to ensure output has the correct shape + auto& output = outputs[0]; + + // Get shapes + auto dst_shape = dst.shape(); + auto mask_shape = mask.shape(); + auto src_shape = src.shape(); + + // Calculate broadcast shape for all tensors + auto broadcast_shape = calculateBroadcastShape(dst_shape, mask_shape); + broadcast_shape = calculateBroadcastShape(broadcast_shape, src_shape); + + // Calculate strides for broadcasting + auto dst_strides = calculateStrides(dst_shape); + auto mask_strides = calculateStrides(mask_shape); + auto src_strides = calculateStrides(src_shape); + auto broadcast_strides = calculateStrides(broadcast_shape); + + // Calculate total elements + int32_t total_elements = std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<>()); + + // Check data types + MLLM_RT_ASSERT(dst.dtype() == src.dtype()); + + // Handle different data types + if (dst.dtype() == MLLM_TYPE_F32) { + float* dst_ptr = dst.ptr(); + float* src_ptr = src.ptr(); + uint8_t* mask_ptr = mask.ptr(); + + for (int32_t i = 0; i < total_elements; ++i) { + // Convert linear index to multi-dimensional indices + auto indices = getMultiDimIndices(i, broadcast_shape); + + // Calculate index for mask with broadcasting + int32_t mask_linear_index = 0; + if (mask_shape.size() == 1 && mask_shape[0] == 1) { + // Scalar mask + mask_linear_index = 0; + } else { + // Multi-dimensional mask + std::vector mask_indices(mask_shape.size()); + int offset = broadcast_shape.size() - mask_shape.size(); + for (size_t j = 0; j < mask_shape.size(); ++j) { mask_indices[j] = indices[j + offset] % mask_shape[j]; } + mask_linear_index = getLinearIndex(mask_indices, mask_strides); + } + + // If mask is true, copy from src to dst + if (mask_ptr[mask_linear_index] != 0) { + // Calculate index for dst + int32_t dst_linear_index = 0; + if (dst_shape.size() == 1 && dst_shape[0] == 1) { + // Scalar dst + dst_linear_index = 0; + } else { + // Multi-dimensional dst + std::vector dst_indices(dst_shape.size()); + int offset = broadcast_shape.size() - dst_shape.size(); + for (size_t j = 0; j < dst_shape.size(); ++j) { dst_indices[j] = indices[j + offset] % dst_shape[j]; } + dst_linear_index = getLinearIndex(dst_indices, dst_strides); + } + + // Calculate index for src with broadcasting + int32_t src_linear_index = 0; + if (src_shape.size() == 1 && src_shape[0] == 1) { + // Scalar src + src_linear_index = 0; + } else { + // Multi-dimensional src + std::vector src_indices(src_shape.size()); + int offset = broadcast_shape.size() - src_shape.size(); + for (size_t j = 0; j < src_shape.size(); ++j) { src_indices[j] = indices[j + offset] % src_shape[j]; } + src_linear_index = getLinearIndex(src_indices, src_strides); + } + + dst_ptr[dst_linear_index] = src_ptr[src_linear_index]; + } + } + } else if (dst.dtype() == MLLM_TYPE_F16) { + mllm_fp16_t* dst_ptr = dst.ptr(); + mllm_fp16_t* src_ptr = src.ptr(); + uint8_t* mask_ptr = mask.ptr(); + + for (int32_t i = 0; i < total_elements; ++i) { + // Convert linear index to multi-dimensional indices + auto indices = getMultiDimIndices(i, broadcast_shape); + + // Calculate index for mask with broadcasting + int32_t mask_linear_index = 0; + if (mask_shape.size() == 1 && mask_shape[0] == 1) { + // Scalar mask + mask_linear_index = 0; + } else { + // Multi-dimensional mask + std::vector mask_indices(mask_shape.size()); + int offset = broadcast_shape.size() - mask_shape.size(); + for (size_t j = 0; j < mask_shape.size(); ++j) { mask_indices[j] = indices[j + offset] % mask_shape[j]; } + mask_linear_index = getLinearIndex(mask_indices, mask_strides); + } + + // If mask is true, copy from src to dst + if (mask_ptr[mask_linear_index] != 0) { + // Calculate index for dst + int32_t dst_linear_index = 0; + if (dst_shape.size() == 1 && dst_shape[0] == 1) { + // Scalar dst + dst_linear_index = 0; + } else { + // Multi-dimensional dst + std::vector dst_indices(dst_shape.size()); + int offset = broadcast_shape.size() - dst_shape.size(); + for (size_t j = 0; j < dst_shape.size(); ++j) { dst_indices[j] = indices[j + offset] % dst_shape[j]; } + dst_linear_index = getLinearIndex(dst_indices, dst_strides); + } + + // Calculate index for src with broadcasting + int32_t src_linear_index = 0; + if (src_shape.size() == 1 && src_shape[0] == 1) { + // Scalar src + src_linear_index = 0; + } else { + // Multi-dimensional src + std::vector src_indices(src_shape.size()); + int offset = broadcast_shape.size() - src_shape.size(); + for (size_t j = 0; j < src_shape.size(); ++j) { src_indices[j] = indices[j + offset] % src_shape[j]; } + src_linear_index = getLinearIndex(src_indices, src_strides); + } + + dst_ptr[dst_linear_index] = src_ptr[src_linear_index]; + } + } + } else { + // For other data types, we could add support similarly + NYI("Unsupported data type for MaskedScatter operation"); + } + + // Copy result to output + if (output.ptr() != dst.ptr()) { std::memcpy(output.ptr(), dst.ptr(), dst.bytes()); } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/MaskedScatterOp.hpp b/mllm/backends/cpu/ops/MaskedScatterOp.hpp new file mode 100644 index 000000000..d61fa90bc --- /dev/null +++ b/mllm/backends/cpu/ops/MaskedScatterOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/MaskedScatterOp.hpp" + +namespace mllm::cpu { + +class CPUMaskedScatterOp final : public aops::MaskedScatterOp { + public: + explicit CPUMaskedScatterOp(const aops::MaskedScatterOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUMaskedScatterOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::MaskedScatterOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 4c67ecea6..8af92f0bc 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-24 14:21:08 +// Auto generated: 2025-10-26 17:03:51 // do not modify this file #pragma once @@ -76,6 +76,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_InterpolateOp, RK_Op_LinalgIROp_EinsumOp, RK_Op_LinalgIROp_StackOp, + RK_Op_LinalgIROp_MaskedScatterOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, RK_Op_GraphIROp_SubGraphOp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 4b60ab383..7fb5d1b55 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-24 14:21:08 +// Auto generated: 2025-10-26 17:03:51 // do not modify this file #pragma once namespace mllm::ir { @@ -198,6 +198,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_STACKOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_StackOp && (v)->getKind() <= RK_Op_LinalgIROp_StackOp +#define RTTI_RK_OP_LINALGIROP_MASKEDSCATTEROP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_MaskedScatterOp && (v)->getKind() <= RK_Op_LinalgIROp_MaskedScatterOp + #define RTTI_RK_OP_GRAPHIROP_IMPL(v) return (v)->getKind() >= RK_Op_GraphIROp && (v)->getKind() <= RK_Op_GraphIROp_Last #define RTTI_RK_OP_GRAPHIROP_SUBGRAPHOP_IMPL(v) \ diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index 8b9021a95..4a6a1f8cf 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -106,5 +106,6 @@ LINALG_AOPS_DECL(OpTypes::kPad, PadOp); LINALG_AOPS_DECL(OpTypes::kInterpolate, InterpolateOp); LINALG_AOPS_DECL(OpTypes::kEinsum, EinsumOp); LINALG_AOPS_DECL(OpTypes::kStack, StackOp); +LINALG_AOPS_DECL(OpTypes::kMaskedScatter, MaskedScatterOp); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index fc3c1d4d4..265d919fc 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -69,6 +69,7 @@ class PadOp; class InterpolateOp; class EinsumOp; class StackOp; +class MaskedScatterOp; } // namespace mllm #define LINALG_AOPS_DEFINE(class_name, rtti_name) \ @@ -223,5 +224,6 @@ LINALG_AOPS_DEFINE(PadOp, PADOP); LINALG_AOPS_DEFINE(InterpolateOp, INTERPOLATEOP); LINALG_AOPS_DEFINE(EinsumOp, EINSUMOP); LINALG_AOPS_DEFINE(StackOp, STACKOP); +LINALG_AOPS_DEFINE(MaskedScatterOp, MASKEDSCATTEROP); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index f84540a4b..ae4629ad6 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -277,6 +277,7 @@ def define_lianlg_ir(ir: dict): op.derive(Cls("InterpolateOp")) op.derive(Cls("EinsumOp")) op.derive(Cls("StackOp")) + op.derive(Cls("MaskedScatterOp")) # value diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 03040ba29..dfac01e04 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -82,6 +82,7 @@ enum class OpTypes : int32_t { kInterpolate = 62, kEinsum = 63, kStack = 64, + kMaskedScatter = 65, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -155,6 +156,9 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kPad: return "Pad"; case OpTypes::kInterpolate: return "Interpolate"; case OpTypes::kStack: return "Stack"; + case OpTypes::kEinsum: return "Einsum"; + case OpTypes::kMaskedScatter: return "MaskedScatter"; + case OpTypes::kDynamicOp_Start: return "DynamicOp_Start"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; } diff --git a/mllm/core/aops/MaskedScatterOp.cpp b/mllm/core/aops/MaskedScatterOp.cpp new file mode 100644 index 000000000..9502dda09 --- /dev/null +++ b/mllm/core/aops/MaskedScatterOp.cpp @@ -0,0 +1,36 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/MaskedScatterOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +MaskedScatterOp::MaskedScatterOp(const MaskedScatterOpOptions& options) : BaseOp(OpTypes::kMaskedScatter), options_(options) {} + +void MaskedScatterOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void MaskedScatterOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void MaskedScatterOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("MaskedScatterOp::forward not implemented in aops base."); +} + +void MaskedScatterOp::reshape(const std::vector& inputs, std::vector& outputs) { + outputs.emplace_back(inputs[0]); +} + +void MaskedScatterOp::setup(const std::vector& inputs, std::vector& outputs) { + // Do nothing. + MLLM_EMPTY_SCOPE; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/MaskedScatterOp.hpp b/mllm/core/aops/MaskedScatterOp.hpp new file mode 100644 index 000000000..9e452e6af --- /dev/null +++ b/mllm/core/aops/MaskedScatterOp.hpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct MaskedScatterOpOptions : public BaseOpOptions {}; + +class MaskedScatterOp : public BaseOp { + public: + explicit MaskedScatterOp(const MaskedScatterOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline MaskedScatterOpOptions& options() { return options_; } + + protected: + MaskedScatterOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index dc0b71923..9eae8f7bc 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -207,10 +207,8 @@ class DeepseekOCRModel final : public DeepSeekV2Model { } } - if (images_in_this_batch) { - // TODO - // inputs_embeds[idx].masked_scatter_(...) - } + // Scatter copy. + if (images_in_this_batch) { nn::functional::maskedScatter(inputs_embeds, images_seq_mask, images_in_this_batch); } // Normal forward with text and embedded image // TODO diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index 72947719a..9f8e981b9 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -15,6 +15,7 @@ #include "mllm/core/aops/TopKOp.hpp" #include "mllm/core/aops/SiLUOp.hpp" #include "mllm/core/aops/PadOp.hpp" +#include "mllm/core/aops/MaskedScatterOp.hpp" #include "mllm/core/aops/InterpolateOp.hpp" #include "mllm/core/aops/StackOp.hpp" #include "mllm/engine/Context.hpp" @@ -159,4 +160,8 @@ Tensor interpolateByScale(const Tensor& x, const std::vector& scale_facto return Context::instance().buildOpAndSubmitTask(OpTypes::kInterpolate, opts, {x})[0]; } +void maskedScatter(const Tensor& dst, const Tensor& mask, const Tensor& src) { + Context::instance().buildOpAndSubmitTask(OpTypes::kMaskedScatter, aops::MaskedScatterOpOptions{}, {dst, mask, src}); +} + } // namespace mllm::nn::functional diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index 1c4a7dd05..08f5fbf8f 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -149,4 +149,6 @@ Tensor interpolateByScale(const Tensor& x, const std::vector& scale_facto aops::InterpolateOpMode mode = aops::InterpolateOpMode::kNearest, bool align_corners = false, bool antialias = false); +void maskedScatter(const Tensor& dst, const Tensor& mask, const Tensor& src); + } // namespace mllm::nn::functional From e849dbccfd8a1d09f023fa50b6df386248aa6e41 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sun, 26 Oct 2025 21:36:10 +0800 Subject: [PATCH 17/25] feat(deepseek_ocr): add DeepseekV2MLP, MoEGate, and DeepseekV2MoE modules - Introduce `DeepseekV2MLP` for feed-forward operations with configurable linear implementations - Add `MoEGate` for computing gating scores in mixture-of-experts layers - Implement `DeepseekV2MoE` to support routing inputs to multiple expert networks - Update configuration to include `llm_mlp_linear_impl_type` for MLP layer customization - Include `` header for optional parameters in module constructors fix(cpu): replace incorrect conv2d header with CPUArchHelper - Corrected misplaced header include in layernorm2d.cpp to properly support ARM architecture checks --- mllm/backends/cpu/kernels/arm/layernorm2d.cpp | 2 +- .../configuration_deepseek_ocr.hpp | 1 + .../deepseek_ocr/modeling_deepseek_ocr.hpp | 149 ++++++++++++++++++ 3 files changed, 151 insertions(+), 1 deletion(-) diff --git a/mllm/backends/cpu/kernels/arm/layernorm2d.cpp b/mllm/backends/cpu/kernels/arm/layernorm2d.cpp index be52d715e..b75257d71 100644 --- a/mllm/backends/cpu/kernels/arm/layernorm2d.cpp +++ b/mllm/backends/cpu/kernels/arm/layernorm2d.cpp @@ -1,6 +1,6 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. -#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" +#include "mllm/utils/CPUArchHelper.hpp" #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp index e19dfb8b4..eb007d988 100644 --- a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -184,6 +184,7 @@ struct DpskOcrConfig : protected ConfigFile { aops::LinearImplTypes sam_linear_impl_type = aops::LinearImplTypes::kDefault; aops::LinearImplTypes mlp_projector_linear_impl_type = aops::LinearImplTypes::kDefault; aops::LinearImplTypes lm_head_linear_impl_type = aops::LinearImplTypes::kDefault; + aops::LinearImplTypes llm_mlp_linear_impl_type = aops::LinearImplTypes::kDefault; }; } // namespace mllm::models::deepseek_ocr diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 9eae8f7bc..8c8494227 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include @@ -19,6 +20,154 @@ namespace mllm::models::deepseek_ocr { +class DeepseekV2MLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU act_; + + int hidden_size_; + int intermediate_size_; + + public: + DeepseekV2MLP() = default; + + explicit DeepseekV2MLP(const std::string& name, const DpskOcrConfig& config, + const std::optional& hidden_size = std::nullopt, + const std::optional& intermediate_size = std::nullopt) + : nn::Module(name) { + hidden_size_ = hidden_size.value_or(config.hidden_size); + intermediate_size_ = intermediate_size.value_or(config.intermediate_size); + + // clang-format off + gate_proj_ = reg("gate_proj", hidden_size_, intermediate_size_, false, config.llm_mlp_linear_impl_type); + up_proj_ = reg("up_proj", hidden_size_, intermediate_size_, false, config.llm_mlp_linear_impl_type); + down_proj_ = reg("down_proj", intermediate_size_, hidden_size_, false, config.llm_mlp_linear_impl_type); + act_ = reg("act"); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {down_proj_(act_(gate_proj_(inputs[0])) * up_proj_(inputs[0]))}; + } +}; + +class MoEGate final : public nn::Module { + // FIXME: We may need to support more types + std::string scoring_func_ = "softmax"; + std::string topk_method_ = "greedy"; + + int top_k_; + int n_routed_experts_; + float routed_scaling_factor_; + int n_group_; + int topk_group_; + bool norm_topk_prob_; + + nn::Param weight_; + + public: + MoEGate() = default; + + MoEGate(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + top_k_ = config.num_experts_per_tok; + n_routed_experts_ = config.n_routed_experts; + + // FIXME: Read from config.json instead of hard-coding + routed_scaling_factor_ = 1.f; + norm_topk_prob_ = false; + + n_group_ = config.n_group; + topk_group_ = config.topk_group; + + weight_ = reg("weight", getModuleName() + ".weight"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto bsz = hidden_states.size(0); + auto seq_len = hidden_states.size(1); + auto h = hidden_states.size(2); + + // Compute gating score + hidden_states = hidden_states.view({-1, h}); + // hidden_states and weight must in fp32 to keep precision !!! + auto logits = nn::functional::matmul(hidden_states, weight_.weight(), false, true); + auto scores = nn::functional::softmax(logits, -1); + auto [topk_weight, topk_idx] = nn::functional::topk(scores, top_k_, -1, true, false); + + // FIXME: Someone may need to Norm gate to sum 1. + // FIXME: Someone may need rescale topk_weight by routed_scaling_factor_, but here is hard-code to 1.f + + return {topk_idx, topk_weight}; + } +}; + +class DeepseekV2MoE final : public nn::Module { + int num_experts_per_tok_; + + // FIXME: Should not hard-code + int ep_size_ = 1; + int experts_per_rank_; + int n_shared_experts_ = 0; + + nn::ModuleList experts_; + MoEGate gate_; + nn::ModuleList shared_experts_; + + public: + DeepseekV2MoE() = default; + + DeepseekV2MoE(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + num_experts_per_tok_ = config.num_experts_per_tok; + experts_per_rank_ = config.n_routed_experts; + n_shared_experts_ = config.n_shared_experts; + + // Init experts + experts_ = reg>("experts", config.n_routed_experts, config, std::nullopt, + config.moe_intermediate_size); + gate_ = reg("gate", config); + + if (n_shared_experts_ > 0) { + auto intermediate_size = config.moe_intermediate_size * config.n_shared_experts; + shared_experts_ = + reg>("shared_experts", n_shared_experts_, config, std::nullopt, intermediate_size); + } + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto identity = hidden_states; + auto orig_shape = hidden_states.shape(); + auto topk_idx = Tensor::nil(); + auto topk_weight = Tensor::nil(); + auto gated_ret = gate_(hidden_states); + topk_idx = gated_ret[0]; + topk_weight = gated_ret[1]; + hidden_states = hidden_states.view({-1, hidden_states.size(-1)}); + auto flat_topk_idx = topk_idx.view({-1}); + auto y = moeInfer(hidden_states, topk_idx, topk_weight).view(orig_shape); + if (n_shared_experts_ > 0) { y = y + shared_experts_(identity)[0]; } + return {y}; + } + + private: + Tensor moeInfer(const Tensor& x, const Tensor& topk_ids, const Tensor& topk_weights) { + // TODO + return Tensor::nil(); + } +}; + +class DeepseekV2Attention final : public nn::Module { + public: + // TODO +}; + +class DeepseekV2DecoderLayer final : public nn::Module { + public: + // TODO +}; + class DeepSeekV2Model : public nn::Module { protected: nn::Embedding embed_tokens_; From a80a973b5afd099e8e99d480717c1c3cfa578fc2 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sun, 26 Oct 2025 21:53:45 +0800 Subject: [PATCH 18/25] fix(Tensor): cast rank to int32_t for negative index handling fix(Conv2DOp): add missing semicolon and improve bias registration logic feat(Conv2DOp): add runtime assertions for kernel, stride, padding, and dilation sizes fix(LayerNorm2DOp): improve parameter registration with shared init region fix(deepseek_ocr): move batch_size declaration after pixel_values assignment fix(deepseek_ocr): correct height_crop_num indexing in spatial crop usage feat(Tensor): implement clone method for tensor copying --- mllm/core/Tensor.cpp | 3 ++- mllm/core/aops/Conv2DOp.cpp | 14 +++++++++++--- mllm/core/aops/LayerNorm2DOp.cpp | 6 +++++- mllm/models/deepseek_ocr/deepencoder.hpp | 4 ++-- mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp | 2 +- 5 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index 192b527ce..6aed89153 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -347,7 +347,7 @@ bool Tensor::isContiguousN(int n) const { return impl()->isContiguousN(n); } int32_t Tensor::size(int32_t id) const { auto nid = id; - if (id < 0) { nid = rank() + id; } + if (id < 0) { nid = static_cast(rank()) + id; } return shape()[nid]; } @@ -436,6 +436,7 @@ Tensor Tensor::flatten(int32_t dim) { return view(new_shape); } + Tensor Tensor::clone() { return Context::instance().buildOpAndSubmitTask(OpTypes::kClone, aops::CloneOpOptions{}, {*this})[0]; } void Tensor::copy2(const Tensor& src) { diff --git a/mllm/core/aops/Conv2DOp.cpp b/mllm/core/aops/Conv2DOp.cpp index 645249ad2..46628b137 100644 --- a/mllm/core/aops/Conv2DOp.cpp +++ b/mllm/core/aops/Conv2DOp.cpp @@ -33,7 +33,7 @@ void Conv2DOp::load(const ParameterFile::ptr_t& ploader) { if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } break; } - default: NYI("Unsupported model file version") + default: NYI("Unsupported model file version"); } } @@ -41,10 +41,14 @@ void Conv2DOp::trace(void* trace_context, const std::vector& inputs, std auto ir_ctx = (ir::IRContext*)trace_context; // Register Params + auto init_region = ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion(); if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { - ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir::IRWriterGuard guard(ir_ctx, init_region); ir_ctx->create(ir_ctx->create(weight_)); - if (options_.bias) { ir_ctx->create(ir_ctx->create(bias_)); } + } + if (options_.bias && bias_ && !ir_ctx->lookupSymbolTable(getName() + ".bias")) { + ir::IRWriterGuard guard(ir_ctx, init_region); + ir_ctx->create(ir_ctx->create(bias_)); } auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); @@ -84,6 +88,10 @@ void Conv2DOp::reshape(const std::vector& inputs, std::vector& o const auto& padding = options_.padding; // [ph, pw] if available const auto& dilation = options_.dilation; // [dh, dw] if available const int out_channels = options_.out_channels; + MLLM_RT_ASSERT_EQ(kernel.size(), 2); + MLLM_RT_ASSERT_EQ(stride.size(), 2); + MLLM_RT_ASSERT_EQ(padding.size(), 2); + MLLM_RT_ASSERT_EQ(dilation.size(), 2); // Output shape calculation for Conv2D auto out_shape = [](int dim_size, int kernel_size, int stride_size, int padding_size, int dilation_size) -> int32_t { diff --git a/mllm/core/aops/LayerNorm2DOp.cpp b/mllm/core/aops/LayerNorm2DOp.cpp index 28997dda7..3ce487b2b 100644 --- a/mllm/core/aops/LayerNorm2DOp.cpp +++ b/mllm/core/aops/LayerNorm2DOp.cpp @@ -24,9 +24,13 @@ void LayerNorm2DOp::trace(void* trace_context, const std::vector& inputs auto ir_ctx = (ir::IRContext*)trace_context; // Register Params + auto init_region = ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion(); if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { - ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir::IRWriterGuard guard(ir_ctx, init_region); ir_ctx->create(ir_ctx->create(weight_)); + } + if (bias_ && !ir_ctx->lookupSymbolTable(getName() + ".bias")) { + ir::IRWriterGuard guard(ir_ctx, init_region); ir_ctx->create(ir_ctx->create(bias_)); } diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index d32d74684..ba42d55e3 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -129,8 +129,6 @@ class CLIPVisionEmbeddings final : public nn::Module { auto pixel_values = Tensor::nil(); auto patch_embeds = Tensor::nil(); - auto batch_size = pixel_values.shape()[0]; - if (inputs.size() == 1) { pixel_values = inputs[0]; } else if (inputs.size() == 2) { @@ -138,6 +136,8 @@ class CLIPVisionEmbeddings final : public nn::Module { patch_embeds = inputs[1]; } + auto batch_size = pixel_values.shape()[0]; + if (!patch_embeds) { patch_embeds = patch_embedding_(pixel_values); } // Flatten and transpose. diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 8c8494227..a58def5c0 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -259,7 +259,7 @@ class DeepseekOCRModel final : public DeepSeekV2Model { MLLM_RT_ASSERT_EQ(images_spatial_crop.dtype(), kInt64); int width_crop_num = images_spatial_crop.at({0, 0}); - int height_crop_num = images_spatial_crop.at({0, 0}); + int height_crop_num = images_spatial_crop.at({0, 1}); global_features = global_features.view({h, w, n_dim}); global_features = nn::functional::concat( From 7e3940112319d5e1c18b19b147a3d5078c1b66ca Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 27 Oct 2025 16:20:40 +0800 Subject: [PATCH 19/25] docs(contribute): rename guidelines.md to guidelines.rst Rename the contribution guidelines file from markdown to reStructuredText format. Also move and rename ArgSortOp files to contribute documentation. docs(cpu_backend): add fa2_radix_paged to index Include fa2_radix_paged documentation in the CPU backend index. Move related files to appropriate locations. feat(examples): remove debug print statements Remove unnecessary model printing and early return in deepseek_ocr example. Enable actual inference execution in the example. --- .../{guidelines.md => guidelines.rst} | 0 docs/contribute/index.rst | 1 + docs/contribute/model_supports.rst | 2 + docs/cpu_backend/fa2_radix_paged.rst | 2 + docs/cpu_backend/index.rst | 1 + examples/deepseek_ocr/main.cpp | 3 - examples/deepseek_ocr/net_info.text | 4159 +++++++++++++++++ mllm/backends/cpu/CPUBackend.cpp | 26 +- mllm/backends/cpu/ops/ArgsortOp.cpp | 100 + mllm/backends/cpu/ops/ArgsortOp.hpp | 24 + mllm/backends/cpu/ops/Conv2DOp.cpp | 93 +- mllm/compile/ir/GeneratedRTTIKind.hpp | 5 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 11 +- mllm/compile/ir/linalg/Op.cpp | 4 + mllm/compile/ir/linalg/Op.hpp | 6 + mllm/compile/ir/rtti_kind_gen.py | 3 + mllm/core/OpTypes.hpp | 6 + mllm/core/Tensor.cpp | 18 + mllm/core/Tensor.hpp | 19 + mllm/core/aops/ArgSortOp.cpp | 0 mllm/core/aops/ArgSortOp.hpp | 0 mllm/core/aops/ArgsortOp.cpp | 40 + mllm/core/aops/ArgsortOp.hpp | 36 + mllm/core/aops/Conv2DOp.cpp | 3 - mllm/core/aops/ElewiseOps.cpp | 9 +- mllm/core/aops/PadOp.cpp | 26 +- .../configuration_deepseek_ocr.hpp | 1 + .../deepseek_ocr/conversation_preprocess.hpp | 14 +- mllm/models/deepseek_ocr/deepencoder.hpp | 9 +- .../deepseek_ocr/modeling_deepseek_ocr.hpp | 373 +- mllm/nn/Module.hpp | 26 + 31 files changed, 4927 insertions(+), 93 deletions(-) rename docs/contribute/{guidelines.md => guidelines.rst} (100%) create mode 100644 docs/contribute/model_supports.rst create mode 100644 docs/cpu_backend/fa2_radix_paged.rst create mode 100644 examples/deepseek_ocr/net_info.text create mode 100644 mllm/backends/cpu/ops/ArgsortOp.cpp create mode 100644 mllm/backends/cpu/ops/ArgsortOp.hpp delete mode 100644 mllm/core/aops/ArgSortOp.cpp delete mode 100644 mllm/core/aops/ArgSortOp.hpp create mode 100644 mllm/core/aops/ArgsortOp.cpp create mode 100644 mllm/core/aops/ArgsortOp.hpp diff --git a/docs/contribute/guidelines.md b/docs/contribute/guidelines.rst similarity index 100% rename from docs/contribute/guidelines.md rename to docs/contribute/guidelines.rst diff --git a/docs/contribute/index.rst b/docs/contribute/index.rst index 4e969b0bc..b26dc3663 100644 --- a/docs/contribute/index.rst +++ b/docs/contribute/index.rst @@ -6,3 +6,4 @@ Contribute roadmap guidelines + model_supports diff --git a/docs/contribute/model_supports.rst b/docs/contribute/model_supports.rst new file mode 100644 index 000000000..fda32d04b --- /dev/null +++ b/docs/contribute/model_supports.rst @@ -0,0 +1,2 @@ +Model Supports +================= diff --git a/docs/cpu_backend/fa2_radix_paged.rst b/docs/cpu_backend/fa2_radix_paged.rst new file mode 100644 index 000000000..a678ad263 --- /dev/null +++ b/docs/cpu_backend/fa2_radix_paged.rst @@ -0,0 +1,2 @@ +FA2, Radix, Paged +==================== diff --git a/docs/cpu_backend/index.rst b/docs/cpu_backend/index.rst index 3527029b3..6aa138007 100644 --- a/docs/cpu_backend/index.rst +++ b/docs/cpu_backend/index.rst @@ -5,5 +5,6 @@ CPU Backend :maxdepth: 2 threads + fa2_radix_paged arm/index x86/index diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index 2bbcc657c..45e03e24a 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -10,9 +10,6 @@ MLLM_MAIN({ auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/tokenizer.json"); model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/model.mllm", mllm::ModelFileVersion::kV2)); - mllm::print(model); - return 0; - model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); }); diff --git a/examples/deepseek_ocr/net_info.text b/examples/deepseek_ocr/net_info.text new file mode 100644 index 000000000..e1ec22dc1 --- /dev/null +++ b/examples/deepseek_ocr/net_info.text @@ -0,0 +1,4159 @@ +Module: , device: CPU + Module: model, device: CPU + model.embed_tokens, device: CPU + Module: model.layers, device: CPU + Module: model.layers.0, device: CPU + Module: model.layers.0.self_attn, device: CPU + model.layers.0.self_attn.q_proj, device: CPU + model.layers.0.self_attn.k_proj, device: CPU + model.layers.0.self_attn.v_proj, device: CPU + model.layers.0.self_attn.o_proj, device: CPU + model.layers.0.self_attn.q_rope, device: CPU + model.layers.0.self_attn.k_rope, device: CPU + Module: model.layers.0.mlp, device: CPU + model.layers.0.mlp.gate_proj, device: CPU + model.layers.0.mlp.up_proj, device: CPU + model.layers.0.mlp.down_proj, device: CPU + model.layers.0.mlp.act, device: CPU + model.layers.0.input_layernorm, device: CPU + model.layers.0.post_attention_layernorm, device: CPU + Module: model.layers.1, device: CPU + Module: model.layers.1.self_attn, device: CPU + model.layers.1.self_attn.q_proj, device: CPU + model.layers.1.self_attn.k_proj, device: CPU + model.layers.1.self_attn.v_proj, device: CPU + model.layers.1.self_attn.o_proj, device: CPU + model.layers.1.self_attn.q_rope, device: CPU + model.layers.1.self_attn.k_rope, device: CPU + Module: model.layers.1.mlp, device: CPU + Module: model.layers.1.mlp.experts, device: CPU + Module: model.layers.1.mlp.experts.0, device: CPU + model.layers.1.mlp.experts.0.gate_proj, device: CPU + model.layers.1.mlp.experts.0.up_proj, device: CPU + model.layers.1.mlp.experts.0.down_proj, device: CPU + model.layers.1.mlp.experts.0.act, device: CPU + Module: model.layers.1.mlp.experts.1, device: CPU + model.layers.1.mlp.experts.1.gate_proj, device: CPU + model.layers.1.mlp.experts.1.up_proj, device: CPU + model.layers.1.mlp.experts.1.down_proj, device: CPU + model.layers.1.mlp.experts.1.act, device: CPU + Module: model.layers.1.mlp.experts.2, device: CPU + model.layers.1.mlp.experts.2.gate_proj, device: CPU + model.layers.1.mlp.experts.2.up_proj, device: CPU + model.layers.1.mlp.experts.2.down_proj, device: CPU + model.layers.1.mlp.experts.2.act, device: CPU + Module: model.layers.1.mlp.experts.3, device: CPU + model.layers.1.mlp.experts.3.gate_proj, device: CPU + model.layers.1.mlp.experts.3.up_proj, device: CPU + model.layers.1.mlp.experts.3.down_proj, device: CPU + model.layers.1.mlp.experts.3.act, device: CPU + Module: model.layers.1.mlp.experts.4, device: CPU + model.layers.1.mlp.experts.4.gate_proj, device: CPU + model.layers.1.mlp.experts.4.up_proj, device: CPU + model.layers.1.mlp.experts.4.down_proj, device: CPU + model.layers.1.mlp.experts.4.act, device: CPU + Module: model.layers.1.mlp.experts.5, device: CPU + model.layers.1.mlp.experts.5.gate_proj, device: CPU + model.layers.1.mlp.experts.5.up_proj, device: CPU + model.layers.1.mlp.experts.5.down_proj, device: CPU + model.layers.1.mlp.experts.5.act, device: CPU + Module: model.layers.1.mlp.experts.6, device: CPU + model.layers.1.mlp.experts.6.gate_proj, device: CPU + model.layers.1.mlp.experts.6.up_proj, device: CPU + model.layers.1.mlp.experts.6.down_proj, device: CPU + model.layers.1.mlp.experts.6.act, device: CPU + Module: model.layers.1.mlp.experts.7, device: CPU + model.layers.1.mlp.experts.7.gate_proj, device: CPU + model.layers.1.mlp.experts.7.up_proj, device: CPU + model.layers.1.mlp.experts.7.down_proj, device: CPU + model.layers.1.mlp.experts.7.act, device: CPU + Module: model.layers.1.mlp.experts.8, device: CPU + model.layers.1.mlp.experts.8.gate_proj, device: CPU + model.layers.1.mlp.experts.8.up_proj, device: CPU + model.layers.1.mlp.experts.8.down_proj, device: CPU + model.layers.1.mlp.experts.8.act, device: CPU + Module: model.layers.1.mlp.experts.9, device: CPU + model.layers.1.mlp.experts.9.gate_proj, device: CPU + model.layers.1.mlp.experts.9.up_proj, device: CPU + model.layers.1.mlp.experts.9.down_proj, device: CPU + model.layers.1.mlp.experts.9.act, device: CPU + Module: model.layers.1.mlp.experts.10, device: CPU + model.layers.1.mlp.experts.10.gate_proj, device: CPU + model.layers.1.mlp.experts.10.up_proj, device: CPU + model.layers.1.mlp.experts.10.down_proj, device: CPU + model.layers.1.mlp.experts.10.act, device: CPU + Module: model.layers.1.mlp.experts.11, device: CPU + model.layers.1.mlp.experts.11.gate_proj, device: CPU + model.layers.1.mlp.experts.11.up_proj, device: CPU + model.layers.1.mlp.experts.11.down_proj, device: CPU + model.layers.1.mlp.experts.11.act, device: CPU + Module: model.layers.1.mlp.experts.12, device: CPU + model.layers.1.mlp.experts.12.gate_proj, device: CPU + model.layers.1.mlp.experts.12.up_proj, device: CPU + model.layers.1.mlp.experts.12.down_proj, device: CPU + model.layers.1.mlp.experts.12.act, device: CPU + Module: model.layers.1.mlp.experts.13, device: CPU + model.layers.1.mlp.experts.13.gate_proj, device: CPU + model.layers.1.mlp.experts.13.up_proj, device: CPU + model.layers.1.mlp.experts.13.down_proj, device: CPU + model.layers.1.mlp.experts.13.act, device: CPU + Module: model.layers.1.mlp.experts.14, device: CPU + model.layers.1.mlp.experts.14.gate_proj, device: CPU + model.layers.1.mlp.experts.14.up_proj, device: CPU + model.layers.1.mlp.experts.14.down_proj, device: CPU + model.layers.1.mlp.experts.14.act, device: CPU + Module: model.layers.1.mlp.experts.15, device: CPU + model.layers.1.mlp.experts.15.gate_proj, device: CPU + model.layers.1.mlp.experts.15.up_proj, device: CPU + model.layers.1.mlp.experts.15.down_proj, device: CPU + model.layers.1.mlp.experts.15.act, device: CPU + Module: model.layers.1.mlp.experts.16, device: CPU + model.layers.1.mlp.experts.16.gate_proj, device: CPU + model.layers.1.mlp.experts.16.up_proj, device: CPU + model.layers.1.mlp.experts.16.down_proj, device: CPU + model.layers.1.mlp.experts.16.act, device: CPU + Module: model.layers.1.mlp.experts.17, device: CPU + model.layers.1.mlp.experts.17.gate_proj, device: CPU + model.layers.1.mlp.experts.17.up_proj, device: CPU + model.layers.1.mlp.experts.17.down_proj, device: CPU + model.layers.1.mlp.experts.17.act, device: CPU + Module: model.layers.1.mlp.experts.18, device: CPU + model.layers.1.mlp.experts.18.gate_proj, device: CPU + model.layers.1.mlp.experts.18.up_proj, device: CPU + model.layers.1.mlp.experts.18.down_proj, device: CPU + model.layers.1.mlp.experts.18.act, device: CPU + Module: model.layers.1.mlp.experts.19, device: CPU + model.layers.1.mlp.experts.19.gate_proj, device: CPU + model.layers.1.mlp.experts.19.up_proj, device: CPU + model.layers.1.mlp.experts.19.down_proj, device: CPU + model.layers.1.mlp.experts.19.act, device: CPU + Module: model.layers.1.mlp.experts.20, device: CPU + model.layers.1.mlp.experts.20.gate_proj, device: CPU + model.layers.1.mlp.experts.20.up_proj, device: CPU + model.layers.1.mlp.experts.20.down_proj, device: CPU + model.layers.1.mlp.experts.20.act, device: CPU + Module: model.layers.1.mlp.experts.21, device: CPU + model.layers.1.mlp.experts.21.gate_proj, device: CPU + model.layers.1.mlp.experts.21.up_proj, device: CPU + model.layers.1.mlp.experts.21.down_proj, device: CPU + model.layers.1.mlp.experts.21.act, device: CPU + Module: model.layers.1.mlp.experts.22, device: CPU + model.layers.1.mlp.experts.22.gate_proj, device: CPU + model.layers.1.mlp.experts.22.up_proj, device: CPU + model.layers.1.mlp.experts.22.down_proj, device: CPU + model.layers.1.mlp.experts.22.act, device: CPU + Module: model.layers.1.mlp.experts.23, device: CPU + model.layers.1.mlp.experts.23.gate_proj, device: CPU + model.layers.1.mlp.experts.23.up_proj, device: CPU + model.layers.1.mlp.experts.23.down_proj, device: CPU + model.layers.1.mlp.experts.23.act, device: CPU + Module: model.layers.1.mlp.experts.24, device: CPU + model.layers.1.mlp.experts.24.gate_proj, device: CPU + model.layers.1.mlp.experts.24.up_proj, device: CPU + model.layers.1.mlp.experts.24.down_proj, device: CPU + model.layers.1.mlp.experts.24.act, device: CPU + Module: model.layers.1.mlp.experts.25, device: CPU + model.layers.1.mlp.experts.25.gate_proj, device: CPU + model.layers.1.mlp.experts.25.up_proj, device: CPU + model.layers.1.mlp.experts.25.down_proj, device: CPU + model.layers.1.mlp.experts.25.act, device: CPU + Module: model.layers.1.mlp.experts.26, device: CPU + model.layers.1.mlp.experts.26.gate_proj, device: CPU + model.layers.1.mlp.experts.26.up_proj, device: CPU + model.layers.1.mlp.experts.26.down_proj, device: CPU + model.layers.1.mlp.experts.26.act, device: CPU + Module: model.layers.1.mlp.experts.27, device: CPU + model.layers.1.mlp.experts.27.gate_proj, device: CPU + model.layers.1.mlp.experts.27.up_proj, device: CPU + model.layers.1.mlp.experts.27.down_proj, device: CPU + model.layers.1.mlp.experts.27.act, device: CPU + Module: model.layers.1.mlp.experts.28, device: CPU + model.layers.1.mlp.experts.28.gate_proj, device: CPU + model.layers.1.mlp.experts.28.up_proj, device: CPU + model.layers.1.mlp.experts.28.down_proj, device: CPU + model.layers.1.mlp.experts.28.act, device: CPU + Module: model.layers.1.mlp.experts.29, device: CPU + model.layers.1.mlp.experts.29.gate_proj, device: CPU + model.layers.1.mlp.experts.29.up_proj, device: CPU + model.layers.1.mlp.experts.29.down_proj, device: CPU + model.layers.1.mlp.experts.29.act, device: CPU + Module: model.layers.1.mlp.experts.30, device: CPU + model.layers.1.mlp.experts.30.gate_proj, device: CPU + model.layers.1.mlp.experts.30.up_proj, device: CPU + model.layers.1.mlp.experts.30.down_proj, device: CPU + model.layers.1.mlp.experts.30.act, device: CPU + Module: model.layers.1.mlp.experts.31, device: CPU + model.layers.1.mlp.experts.31.gate_proj, device: CPU + model.layers.1.mlp.experts.31.up_proj, device: CPU + model.layers.1.mlp.experts.31.down_proj, device: CPU + model.layers.1.mlp.experts.31.act, device: CPU + Module: model.layers.1.mlp.experts.32, device: CPU + model.layers.1.mlp.experts.32.gate_proj, device: CPU + model.layers.1.mlp.experts.32.up_proj, device: CPU + model.layers.1.mlp.experts.32.down_proj, device: CPU + model.layers.1.mlp.experts.32.act, device: CPU + Module: model.layers.1.mlp.experts.33, device: CPU + model.layers.1.mlp.experts.33.gate_proj, device: CPU + model.layers.1.mlp.experts.33.up_proj, device: CPU + model.layers.1.mlp.experts.33.down_proj, device: CPU + model.layers.1.mlp.experts.33.act, device: CPU + Module: model.layers.1.mlp.experts.34, device: CPU + model.layers.1.mlp.experts.34.gate_proj, device: CPU + model.layers.1.mlp.experts.34.up_proj, device: CPU + model.layers.1.mlp.experts.34.down_proj, device: CPU + model.layers.1.mlp.experts.34.act, device: CPU + Module: model.layers.1.mlp.experts.35, device: CPU + model.layers.1.mlp.experts.35.gate_proj, device: CPU + model.layers.1.mlp.experts.35.up_proj, device: CPU + model.layers.1.mlp.experts.35.down_proj, device: CPU + model.layers.1.mlp.experts.35.act, device: CPU + Module: model.layers.1.mlp.experts.36, device: CPU + model.layers.1.mlp.experts.36.gate_proj, device: CPU + model.layers.1.mlp.experts.36.up_proj, device: CPU + model.layers.1.mlp.experts.36.down_proj, device: CPU + model.layers.1.mlp.experts.36.act, device: CPU + Module: model.layers.1.mlp.experts.37, device: CPU + model.layers.1.mlp.experts.37.gate_proj, device: CPU + model.layers.1.mlp.experts.37.up_proj, device: CPU + model.layers.1.mlp.experts.37.down_proj, device: CPU + model.layers.1.mlp.experts.37.act, device: CPU + Module: model.layers.1.mlp.experts.38, device: CPU + model.layers.1.mlp.experts.38.gate_proj, device: CPU + model.layers.1.mlp.experts.38.up_proj, device: CPU + model.layers.1.mlp.experts.38.down_proj, device: CPU + model.layers.1.mlp.experts.38.act, device: CPU + Module: model.layers.1.mlp.experts.39, device: CPU + model.layers.1.mlp.experts.39.gate_proj, device: CPU + model.layers.1.mlp.experts.39.up_proj, device: CPU + model.layers.1.mlp.experts.39.down_proj, device: CPU + model.layers.1.mlp.experts.39.act, device: CPU + Module: model.layers.1.mlp.experts.40, device: CPU + model.layers.1.mlp.experts.40.gate_proj, device: CPU + model.layers.1.mlp.experts.40.up_proj, device: CPU + model.layers.1.mlp.experts.40.down_proj, device: CPU + model.layers.1.mlp.experts.40.act, device: CPU + Module: model.layers.1.mlp.experts.41, device: CPU + model.layers.1.mlp.experts.41.gate_proj, device: CPU + model.layers.1.mlp.experts.41.up_proj, device: CPU + model.layers.1.mlp.experts.41.down_proj, device: CPU + model.layers.1.mlp.experts.41.act, device: CPU + Module: model.layers.1.mlp.experts.42, device: CPU + model.layers.1.mlp.experts.42.gate_proj, device: CPU + model.layers.1.mlp.experts.42.up_proj, device: CPU + model.layers.1.mlp.experts.42.down_proj, device: CPU + model.layers.1.mlp.experts.42.act, device: CPU + Module: model.layers.1.mlp.experts.43, device: CPU + model.layers.1.mlp.experts.43.gate_proj, device: CPU + model.layers.1.mlp.experts.43.up_proj, device: CPU + model.layers.1.mlp.experts.43.down_proj, device: CPU + model.layers.1.mlp.experts.43.act, device: CPU + Module: model.layers.1.mlp.experts.44, device: CPU + model.layers.1.mlp.experts.44.gate_proj, device: CPU + model.layers.1.mlp.experts.44.up_proj, device: CPU + model.layers.1.mlp.experts.44.down_proj, device: CPU + model.layers.1.mlp.experts.44.act, device: CPU + Module: model.layers.1.mlp.experts.45, device: CPU + model.layers.1.mlp.experts.45.gate_proj, device: CPU + model.layers.1.mlp.experts.45.up_proj, device: CPU + model.layers.1.mlp.experts.45.down_proj, device: CPU + model.layers.1.mlp.experts.45.act, device: CPU + Module: model.layers.1.mlp.experts.46, device: CPU + model.layers.1.mlp.experts.46.gate_proj, device: CPU + model.layers.1.mlp.experts.46.up_proj, device: CPU + model.layers.1.mlp.experts.46.down_proj, device: CPU + model.layers.1.mlp.experts.46.act, device: CPU + Module: model.layers.1.mlp.experts.47, device: CPU + model.layers.1.mlp.experts.47.gate_proj, device: CPU + model.layers.1.mlp.experts.47.up_proj, device: CPU + model.layers.1.mlp.experts.47.down_proj, device: CPU + model.layers.1.mlp.experts.47.act, device: CPU + Module: model.layers.1.mlp.experts.48, device: CPU + model.layers.1.mlp.experts.48.gate_proj, device: CPU + model.layers.1.mlp.experts.48.up_proj, device: CPU + model.layers.1.mlp.experts.48.down_proj, device: CPU + model.layers.1.mlp.experts.48.act, device: CPU + Module: model.layers.1.mlp.experts.49, device: CPU + model.layers.1.mlp.experts.49.gate_proj, device: CPU + model.layers.1.mlp.experts.49.up_proj, device: CPU + model.layers.1.mlp.experts.49.down_proj, device: CPU + model.layers.1.mlp.experts.49.act, device: CPU + Module: model.layers.1.mlp.experts.50, device: CPU + model.layers.1.mlp.experts.50.gate_proj, device: CPU + model.layers.1.mlp.experts.50.up_proj, device: CPU + model.layers.1.mlp.experts.50.down_proj, device: CPU + model.layers.1.mlp.experts.50.act, device: CPU + Module: model.layers.1.mlp.experts.51, device: CPU + model.layers.1.mlp.experts.51.gate_proj, device: CPU + model.layers.1.mlp.experts.51.up_proj, device: CPU + model.layers.1.mlp.experts.51.down_proj, device: CPU + model.layers.1.mlp.experts.51.act, device: CPU + Module: model.layers.1.mlp.experts.52, device: CPU + model.layers.1.mlp.experts.52.gate_proj, device: CPU + model.layers.1.mlp.experts.52.up_proj, device: CPU + model.layers.1.mlp.experts.52.down_proj, device: CPU + model.layers.1.mlp.experts.52.act, device: CPU + Module: model.layers.1.mlp.experts.53, device: CPU + model.layers.1.mlp.experts.53.gate_proj, device: CPU + model.layers.1.mlp.experts.53.up_proj, device: CPU + model.layers.1.mlp.experts.53.down_proj, device: CPU + model.layers.1.mlp.experts.53.act, device: CPU + Module: model.layers.1.mlp.experts.54, device: CPU + model.layers.1.mlp.experts.54.gate_proj, device: CPU + model.layers.1.mlp.experts.54.up_proj, device: CPU + model.layers.1.mlp.experts.54.down_proj, device: CPU + model.layers.1.mlp.experts.54.act, device: CPU + Module: model.layers.1.mlp.experts.55, device: CPU + model.layers.1.mlp.experts.55.gate_proj, device: CPU + model.layers.1.mlp.experts.55.up_proj, device: CPU + model.layers.1.mlp.experts.55.down_proj, device: CPU + model.layers.1.mlp.experts.55.act, device: CPU + Module: model.layers.1.mlp.experts.56, device: CPU + model.layers.1.mlp.experts.56.gate_proj, device: CPU + model.layers.1.mlp.experts.56.up_proj, device: CPU + model.layers.1.mlp.experts.56.down_proj, device: CPU + model.layers.1.mlp.experts.56.act, device: CPU + Module: model.layers.1.mlp.experts.57, device: CPU + model.layers.1.mlp.experts.57.gate_proj, device: CPU + model.layers.1.mlp.experts.57.up_proj, device: CPU + model.layers.1.mlp.experts.57.down_proj, device: CPU + model.layers.1.mlp.experts.57.act, device: CPU + Module: model.layers.1.mlp.experts.58, device: CPU + model.layers.1.mlp.experts.58.gate_proj, device: CPU + model.layers.1.mlp.experts.58.up_proj, device: CPU + model.layers.1.mlp.experts.58.down_proj, device: CPU + model.layers.1.mlp.experts.58.act, device: CPU + Module: model.layers.1.mlp.experts.59, device: CPU + model.layers.1.mlp.experts.59.gate_proj, device: CPU + model.layers.1.mlp.experts.59.up_proj, device: CPU + model.layers.1.mlp.experts.59.down_proj, device: CPU + model.layers.1.mlp.experts.59.act, device: CPU + Module: model.layers.1.mlp.experts.60, device: CPU + model.layers.1.mlp.experts.60.gate_proj, device: CPU + model.layers.1.mlp.experts.60.up_proj, device: CPU + model.layers.1.mlp.experts.60.down_proj, device: CPU + model.layers.1.mlp.experts.60.act, device: CPU + Module: model.layers.1.mlp.experts.61, device: CPU + model.layers.1.mlp.experts.61.gate_proj, device: CPU + model.layers.1.mlp.experts.61.up_proj, device: CPU + model.layers.1.mlp.experts.61.down_proj, device: CPU + model.layers.1.mlp.experts.61.act, device: CPU + Module: model.layers.1.mlp.experts.62, device: CPU + model.layers.1.mlp.experts.62.gate_proj, device: CPU + model.layers.1.mlp.experts.62.up_proj, device: CPU + model.layers.1.mlp.experts.62.down_proj, device: CPU + model.layers.1.mlp.experts.62.act, device: CPU + Module: model.layers.1.mlp.experts.63, device: CPU + model.layers.1.mlp.experts.63.gate_proj, device: CPU + model.layers.1.mlp.experts.63.up_proj, device: CPU + model.layers.1.mlp.experts.63.down_proj, device: CPU + model.layers.1.mlp.experts.63.act, device: CPU + Module: model.layers.1.mlp.gate, device: CPU + model.layers.1.mlp.gate.weight, device: CPU + Module: model.layers.1.mlp.shared_experts, device: CPU + model.layers.1.mlp.shared_experts.gate_proj, device: CPU + model.layers.1.mlp.shared_experts.up_proj, device: CPU + model.layers.1.mlp.shared_experts.down_proj, device: CPU + model.layers.1.mlp.shared_experts.act, device: CPU + model.layers.1.input_layernorm, device: CPU + model.layers.1.post_attention_layernorm, device: CPU + Module: model.layers.2, device: CPU + Module: model.layers.2.self_attn, device: CPU + model.layers.2.self_attn.q_proj, device: CPU + model.layers.2.self_attn.k_proj, device: CPU + model.layers.2.self_attn.v_proj, device: CPU + model.layers.2.self_attn.o_proj, device: CPU + model.layers.2.self_attn.q_rope, device: CPU + model.layers.2.self_attn.k_rope, device: CPU + Module: model.layers.2.mlp, device: CPU + Module: model.layers.2.mlp.experts, device: CPU + Module: model.layers.2.mlp.experts.0, device: CPU + model.layers.2.mlp.experts.0.gate_proj, device: CPU + model.layers.2.mlp.experts.0.up_proj, device: CPU + model.layers.2.mlp.experts.0.down_proj, device: CPU + model.layers.2.mlp.experts.0.act, device: CPU + Module: model.layers.2.mlp.experts.1, device: CPU + model.layers.2.mlp.experts.1.gate_proj, device: CPU + model.layers.2.mlp.experts.1.up_proj, device: CPU + model.layers.2.mlp.experts.1.down_proj, device: CPU + model.layers.2.mlp.experts.1.act, device: CPU + Module: model.layers.2.mlp.experts.2, device: CPU + model.layers.2.mlp.experts.2.gate_proj, device: CPU + model.layers.2.mlp.experts.2.up_proj, device: CPU + model.layers.2.mlp.experts.2.down_proj, device: CPU + model.layers.2.mlp.experts.2.act, device: CPU + Module: model.layers.2.mlp.experts.3, device: CPU + model.layers.2.mlp.experts.3.gate_proj, device: CPU + model.layers.2.mlp.experts.3.up_proj, device: CPU + model.layers.2.mlp.experts.3.down_proj, device: CPU + model.layers.2.mlp.experts.3.act, device: CPU + Module: model.layers.2.mlp.experts.4, device: CPU + model.layers.2.mlp.experts.4.gate_proj, device: CPU + model.layers.2.mlp.experts.4.up_proj, device: CPU + model.layers.2.mlp.experts.4.down_proj, device: CPU + model.layers.2.mlp.experts.4.act, device: CPU + Module: model.layers.2.mlp.experts.5, device: CPU + model.layers.2.mlp.experts.5.gate_proj, device: CPU + model.layers.2.mlp.experts.5.up_proj, device: CPU + model.layers.2.mlp.experts.5.down_proj, device: CPU + model.layers.2.mlp.experts.5.act, device: CPU + Module: model.layers.2.mlp.experts.6, device: CPU + model.layers.2.mlp.experts.6.gate_proj, device: CPU + model.layers.2.mlp.experts.6.up_proj, device: CPU + model.layers.2.mlp.experts.6.down_proj, device: CPU + model.layers.2.mlp.experts.6.act, device: CPU + Module: model.layers.2.mlp.experts.7, device: CPU + model.layers.2.mlp.experts.7.gate_proj, device: CPU + model.layers.2.mlp.experts.7.up_proj, device: CPU + model.layers.2.mlp.experts.7.down_proj, device: CPU + model.layers.2.mlp.experts.7.act, device: CPU + Module: model.layers.2.mlp.experts.8, device: CPU + model.layers.2.mlp.experts.8.gate_proj, device: CPU + model.layers.2.mlp.experts.8.up_proj, device: CPU + model.layers.2.mlp.experts.8.down_proj, device: CPU + model.layers.2.mlp.experts.8.act, device: CPU + Module: model.layers.2.mlp.experts.9, device: CPU + model.layers.2.mlp.experts.9.gate_proj, device: CPU + model.layers.2.mlp.experts.9.up_proj, device: CPU + model.layers.2.mlp.experts.9.down_proj, device: CPU + model.layers.2.mlp.experts.9.act, device: CPU + Module: model.layers.2.mlp.experts.10, device: CPU + model.layers.2.mlp.experts.10.gate_proj, device: CPU + model.layers.2.mlp.experts.10.up_proj, device: CPU + model.layers.2.mlp.experts.10.down_proj, device: CPU + model.layers.2.mlp.experts.10.act, device: CPU + Module: model.layers.2.mlp.experts.11, device: CPU + model.layers.2.mlp.experts.11.gate_proj, device: CPU + model.layers.2.mlp.experts.11.up_proj, device: CPU + model.layers.2.mlp.experts.11.down_proj, device: CPU + model.layers.2.mlp.experts.11.act, device: CPU + Module: model.layers.2.mlp.experts.12, device: CPU + model.layers.2.mlp.experts.12.gate_proj, device: CPU + model.layers.2.mlp.experts.12.up_proj, device: CPU + model.layers.2.mlp.experts.12.down_proj, device: CPU + model.layers.2.mlp.experts.12.act, device: CPU + Module: model.layers.2.mlp.experts.13, device: CPU + model.layers.2.mlp.experts.13.gate_proj, device: CPU + model.layers.2.mlp.experts.13.up_proj, device: CPU + model.layers.2.mlp.experts.13.down_proj, device: CPU + model.layers.2.mlp.experts.13.act, device: CPU + Module: model.layers.2.mlp.experts.14, device: CPU + model.layers.2.mlp.experts.14.gate_proj, device: CPU + model.layers.2.mlp.experts.14.up_proj, device: CPU + model.layers.2.mlp.experts.14.down_proj, device: CPU + model.layers.2.mlp.experts.14.act, device: CPU + Module: model.layers.2.mlp.experts.15, device: CPU + model.layers.2.mlp.experts.15.gate_proj, device: CPU + model.layers.2.mlp.experts.15.up_proj, device: CPU + model.layers.2.mlp.experts.15.down_proj, device: CPU + model.layers.2.mlp.experts.15.act, device: CPU + Module: model.layers.2.mlp.experts.16, device: CPU + model.layers.2.mlp.experts.16.gate_proj, device: CPU + model.layers.2.mlp.experts.16.up_proj, device: CPU + model.layers.2.mlp.experts.16.down_proj, device: CPU + model.layers.2.mlp.experts.16.act, device: CPU + Module: model.layers.2.mlp.experts.17, device: CPU + model.layers.2.mlp.experts.17.gate_proj, device: CPU + model.layers.2.mlp.experts.17.up_proj, device: CPU + model.layers.2.mlp.experts.17.down_proj, device: CPU + model.layers.2.mlp.experts.17.act, device: CPU + Module: model.layers.2.mlp.experts.18, device: CPU + model.layers.2.mlp.experts.18.gate_proj, device: CPU + model.layers.2.mlp.experts.18.up_proj, device: CPU + model.layers.2.mlp.experts.18.down_proj, device: CPU + model.layers.2.mlp.experts.18.act, device: CPU + Module: model.layers.2.mlp.experts.19, device: CPU + model.layers.2.mlp.experts.19.gate_proj, device: CPU + model.layers.2.mlp.experts.19.up_proj, device: CPU + model.layers.2.mlp.experts.19.down_proj, device: CPU + model.layers.2.mlp.experts.19.act, device: CPU + Module: model.layers.2.mlp.experts.20, device: CPU + model.layers.2.mlp.experts.20.gate_proj, device: CPU + model.layers.2.mlp.experts.20.up_proj, device: CPU + model.layers.2.mlp.experts.20.down_proj, device: CPU + model.layers.2.mlp.experts.20.act, device: CPU + Module: model.layers.2.mlp.experts.21, device: CPU + model.layers.2.mlp.experts.21.gate_proj, device: CPU + model.layers.2.mlp.experts.21.up_proj, device: CPU + model.layers.2.mlp.experts.21.down_proj, device: CPU + model.layers.2.mlp.experts.21.act, device: CPU + Module: model.layers.2.mlp.experts.22, device: CPU + model.layers.2.mlp.experts.22.gate_proj, device: CPU + model.layers.2.mlp.experts.22.up_proj, device: CPU + model.layers.2.mlp.experts.22.down_proj, device: CPU + model.layers.2.mlp.experts.22.act, device: CPU + Module: model.layers.2.mlp.experts.23, device: CPU + model.layers.2.mlp.experts.23.gate_proj, device: CPU + model.layers.2.mlp.experts.23.up_proj, device: CPU + model.layers.2.mlp.experts.23.down_proj, device: CPU + model.layers.2.mlp.experts.23.act, device: CPU + Module: model.layers.2.mlp.experts.24, device: CPU + model.layers.2.mlp.experts.24.gate_proj, device: CPU + model.layers.2.mlp.experts.24.up_proj, device: CPU + model.layers.2.mlp.experts.24.down_proj, device: CPU + model.layers.2.mlp.experts.24.act, device: CPU + Module: model.layers.2.mlp.experts.25, device: CPU + model.layers.2.mlp.experts.25.gate_proj, device: CPU + model.layers.2.mlp.experts.25.up_proj, device: CPU + model.layers.2.mlp.experts.25.down_proj, device: CPU + model.layers.2.mlp.experts.25.act, device: CPU + Module: model.layers.2.mlp.experts.26, device: CPU + model.layers.2.mlp.experts.26.gate_proj, device: CPU + model.layers.2.mlp.experts.26.up_proj, device: CPU + model.layers.2.mlp.experts.26.down_proj, device: CPU + model.layers.2.mlp.experts.26.act, device: CPU + Module: model.layers.2.mlp.experts.27, device: CPU + model.layers.2.mlp.experts.27.gate_proj, device: CPU + model.layers.2.mlp.experts.27.up_proj, device: CPU + model.layers.2.mlp.experts.27.down_proj, device: CPU + model.layers.2.mlp.experts.27.act, device: CPU + Module: model.layers.2.mlp.experts.28, device: CPU + model.layers.2.mlp.experts.28.gate_proj, device: CPU + model.layers.2.mlp.experts.28.up_proj, device: CPU + model.layers.2.mlp.experts.28.down_proj, device: CPU + model.layers.2.mlp.experts.28.act, device: CPU + Module: model.layers.2.mlp.experts.29, device: CPU + model.layers.2.mlp.experts.29.gate_proj, device: CPU + model.layers.2.mlp.experts.29.up_proj, device: CPU + model.layers.2.mlp.experts.29.down_proj, device: CPU + model.layers.2.mlp.experts.29.act, device: CPU + Module: model.layers.2.mlp.experts.30, device: CPU + model.layers.2.mlp.experts.30.gate_proj, device: CPU + model.layers.2.mlp.experts.30.up_proj, device: CPU + model.layers.2.mlp.experts.30.down_proj, device: CPU + model.layers.2.mlp.experts.30.act, device: CPU + Module: model.layers.2.mlp.experts.31, device: CPU + model.layers.2.mlp.experts.31.gate_proj, device: CPU + model.layers.2.mlp.experts.31.up_proj, device: CPU + model.layers.2.mlp.experts.31.down_proj, device: CPU + model.layers.2.mlp.experts.31.act, device: CPU + Module: model.layers.2.mlp.experts.32, device: CPU + model.layers.2.mlp.experts.32.gate_proj, device: CPU + model.layers.2.mlp.experts.32.up_proj, device: CPU + model.layers.2.mlp.experts.32.down_proj, device: CPU + model.layers.2.mlp.experts.32.act, device: CPU + Module: model.layers.2.mlp.experts.33, device: CPU + model.layers.2.mlp.experts.33.gate_proj, device: CPU + model.layers.2.mlp.experts.33.up_proj, device: CPU + model.layers.2.mlp.experts.33.down_proj, device: CPU + model.layers.2.mlp.experts.33.act, device: CPU + Module: model.layers.2.mlp.experts.34, device: CPU + model.layers.2.mlp.experts.34.gate_proj, device: CPU + model.layers.2.mlp.experts.34.up_proj, device: CPU + model.layers.2.mlp.experts.34.down_proj, device: CPU + model.layers.2.mlp.experts.34.act, device: CPU + Module: model.layers.2.mlp.experts.35, device: CPU + model.layers.2.mlp.experts.35.gate_proj, device: CPU + model.layers.2.mlp.experts.35.up_proj, device: CPU + model.layers.2.mlp.experts.35.down_proj, device: CPU + model.layers.2.mlp.experts.35.act, device: CPU + Module: model.layers.2.mlp.experts.36, device: CPU + model.layers.2.mlp.experts.36.gate_proj, device: CPU + model.layers.2.mlp.experts.36.up_proj, device: CPU + model.layers.2.mlp.experts.36.down_proj, device: CPU + model.layers.2.mlp.experts.36.act, device: CPU + Module: model.layers.2.mlp.experts.37, device: CPU + model.layers.2.mlp.experts.37.gate_proj, device: CPU + model.layers.2.mlp.experts.37.up_proj, device: CPU + model.layers.2.mlp.experts.37.down_proj, device: CPU + model.layers.2.mlp.experts.37.act, device: CPU + Module: model.layers.2.mlp.experts.38, device: CPU + model.layers.2.mlp.experts.38.gate_proj, device: CPU + model.layers.2.mlp.experts.38.up_proj, device: CPU + model.layers.2.mlp.experts.38.down_proj, device: CPU + model.layers.2.mlp.experts.38.act, device: CPU + Module: model.layers.2.mlp.experts.39, device: CPU + model.layers.2.mlp.experts.39.gate_proj, device: CPU + model.layers.2.mlp.experts.39.up_proj, device: CPU + model.layers.2.mlp.experts.39.down_proj, device: CPU + model.layers.2.mlp.experts.39.act, device: CPU + Module: model.layers.2.mlp.experts.40, device: CPU + model.layers.2.mlp.experts.40.gate_proj, device: CPU + model.layers.2.mlp.experts.40.up_proj, device: CPU + model.layers.2.mlp.experts.40.down_proj, device: CPU + model.layers.2.mlp.experts.40.act, device: CPU + Module: model.layers.2.mlp.experts.41, device: CPU + model.layers.2.mlp.experts.41.gate_proj, device: CPU + model.layers.2.mlp.experts.41.up_proj, device: CPU + model.layers.2.mlp.experts.41.down_proj, device: CPU + model.layers.2.mlp.experts.41.act, device: CPU + Module: model.layers.2.mlp.experts.42, device: CPU + model.layers.2.mlp.experts.42.gate_proj, device: CPU + model.layers.2.mlp.experts.42.up_proj, device: CPU + model.layers.2.mlp.experts.42.down_proj, device: CPU + model.layers.2.mlp.experts.42.act, device: CPU + Module: model.layers.2.mlp.experts.43, device: CPU + model.layers.2.mlp.experts.43.gate_proj, device: CPU + model.layers.2.mlp.experts.43.up_proj, device: CPU + model.layers.2.mlp.experts.43.down_proj, device: CPU + model.layers.2.mlp.experts.43.act, device: CPU + Module: model.layers.2.mlp.experts.44, device: CPU + model.layers.2.mlp.experts.44.gate_proj, device: CPU + model.layers.2.mlp.experts.44.up_proj, device: CPU + model.layers.2.mlp.experts.44.down_proj, device: CPU + model.layers.2.mlp.experts.44.act, device: CPU + Module: model.layers.2.mlp.experts.45, device: CPU + model.layers.2.mlp.experts.45.gate_proj, device: CPU + model.layers.2.mlp.experts.45.up_proj, device: CPU + model.layers.2.mlp.experts.45.down_proj, device: CPU + model.layers.2.mlp.experts.45.act, device: CPU + Module: model.layers.2.mlp.experts.46, device: CPU + model.layers.2.mlp.experts.46.gate_proj, device: CPU + model.layers.2.mlp.experts.46.up_proj, device: CPU + model.layers.2.mlp.experts.46.down_proj, device: CPU + model.layers.2.mlp.experts.46.act, device: CPU + Module: model.layers.2.mlp.experts.47, device: CPU + model.layers.2.mlp.experts.47.gate_proj, device: CPU + model.layers.2.mlp.experts.47.up_proj, device: CPU + model.layers.2.mlp.experts.47.down_proj, device: CPU + model.layers.2.mlp.experts.47.act, device: CPU + Module: model.layers.2.mlp.experts.48, device: CPU + model.layers.2.mlp.experts.48.gate_proj, device: CPU + model.layers.2.mlp.experts.48.up_proj, device: CPU + model.layers.2.mlp.experts.48.down_proj, device: CPU + model.layers.2.mlp.experts.48.act, device: CPU + Module: model.layers.2.mlp.experts.49, device: CPU + model.layers.2.mlp.experts.49.gate_proj, device: CPU + model.layers.2.mlp.experts.49.up_proj, device: CPU + model.layers.2.mlp.experts.49.down_proj, device: CPU + model.layers.2.mlp.experts.49.act, device: CPU + Module: model.layers.2.mlp.experts.50, device: CPU + model.layers.2.mlp.experts.50.gate_proj, device: CPU + model.layers.2.mlp.experts.50.up_proj, device: CPU + model.layers.2.mlp.experts.50.down_proj, device: CPU + model.layers.2.mlp.experts.50.act, device: CPU + Module: model.layers.2.mlp.experts.51, device: CPU + model.layers.2.mlp.experts.51.gate_proj, device: CPU + model.layers.2.mlp.experts.51.up_proj, device: CPU + model.layers.2.mlp.experts.51.down_proj, device: CPU + model.layers.2.mlp.experts.51.act, device: CPU + Module: model.layers.2.mlp.experts.52, device: CPU + model.layers.2.mlp.experts.52.gate_proj, device: CPU + model.layers.2.mlp.experts.52.up_proj, device: CPU + model.layers.2.mlp.experts.52.down_proj, device: CPU + model.layers.2.mlp.experts.52.act, device: CPU + Module: model.layers.2.mlp.experts.53, device: CPU + model.layers.2.mlp.experts.53.gate_proj, device: CPU + model.layers.2.mlp.experts.53.up_proj, device: CPU + model.layers.2.mlp.experts.53.down_proj, device: CPU + model.layers.2.mlp.experts.53.act, device: CPU + Module: model.layers.2.mlp.experts.54, device: CPU + model.layers.2.mlp.experts.54.gate_proj, device: CPU + model.layers.2.mlp.experts.54.up_proj, device: CPU + model.layers.2.mlp.experts.54.down_proj, device: CPU + model.layers.2.mlp.experts.54.act, device: CPU + Module: model.layers.2.mlp.experts.55, device: CPU + model.layers.2.mlp.experts.55.gate_proj, device: CPU + model.layers.2.mlp.experts.55.up_proj, device: CPU + model.layers.2.mlp.experts.55.down_proj, device: CPU + model.layers.2.mlp.experts.55.act, device: CPU + Module: model.layers.2.mlp.experts.56, device: CPU + model.layers.2.mlp.experts.56.gate_proj, device: CPU + model.layers.2.mlp.experts.56.up_proj, device: CPU + model.layers.2.mlp.experts.56.down_proj, device: CPU + model.layers.2.mlp.experts.56.act, device: CPU + Module: model.layers.2.mlp.experts.57, device: CPU + model.layers.2.mlp.experts.57.gate_proj, device: CPU + model.layers.2.mlp.experts.57.up_proj, device: CPU + model.layers.2.mlp.experts.57.down_proj, device: CPU + model.layers.2.mlp.experts.57.act, device: CPU + Module: model.layers.2.mlp.experts.58, device: CPU + model.layers.2.mlp.experts.58.gate_proj, device: CPU + model.layers.2.mlp.experts.58.up_proj, device: CPU + model.layers.2.mlp.experts.58.down_proj, device: CPU + model.layers.2.mlp.experts.58.act, device: CPU + Module: model.layers.2.mlp.experts.59, device: CPU + model.layers.2.mlp.experts.59.gate_proj, device: CPU + model.layers.2.mlp.experts.59.up_proj, device: CPU + model.layers.2.mlp.experts.59.down_proj, device: CPU + model.layers.2.mlp.experts.59.act, device: CPU + Module: model.layers.2.mlp.experts.60, device: CPU + model.layers.2.mlp.experts.60.gate_proj, device: CPU + model.layers.2.mlp.experts.60.up_proj, device: CPU + model.layers.2.mlp.experts.60.down_proj, device: CPU + model.layers.2.mlp.experts.60.act, device: CPU + Module: model.layers.2.mlp.experts.61, device: CPU + model.layers.2.mlp.experts.61.gate_proj, device: CPU + model.layers.2.mlp.experts.61.up_proj, device: CPU + model.layers.2.mlp.experts.61.down_proj, device: CPU + model.layers.2.mlp.experts.61.act, device: CPU + Module: model.layers.2.mlp.experts.62, device: CPU + model.layers.2.mlp.experts.62.gate_proj, device: CPU + model.layers.2.mlp.experts.62.up_proj, device: CPU + model.layers.2.mlp.experts.62.down_proj, device: CPU + model.layers.2.mlp.experts.62.act, device: CPU + Module: model.layers.2.mlp.experts.63, device: CPU + model.layers.2.mlp.experts.63.gate_proj, device: CPU + model.layers.2.mlp.experts.63.up_proj, device: CPU + model.layers.2.mlp.experts.63.down_proj, device: CPU + model.layers.2.mlp.experts.63.act, device: CPU + Module: model.layers.2.mlp.gate, device: CPU + model.layers.2.mlp.gate.weight, device: CPU + Module: model.layers.2.mlp.shared_experts, device: CPU + model.layers.2.mlp.shared_experts.gate_proj, device: CPU + model.layers.2.mlp.shared_experts.up_proj, device: CPU + model.layers.2.mlp.shared_experts.down_proj, device: CPU + model.layers.2.mlp.shared_experts.act, device: CPU + model.layers.2.input_layernorm, device: CPU + model.layers.2.post_attention_layernorm, device: CPU + Module: model.layers.3, device: CPU + Module: model.layers.3.self_attn, device: CPU + model.layers.3.self_attn.q_proj, device: CPU + model.layers.3.self_attn.k_proj, device: CPU + model.layers.3.self_attn.v_proj, device: CPU + model.layers.3.self_attn.o_proj, device: CPU + model.layers.3.self_attn.q_rope, device: CPU + model.layers.3.self_attn.k_rope, device: CPU + Module: model.layers.3.mlp, device: CPU + Module: model.layers.3.mlp.experts, device: CPU + Module: model.layers.3.mlp.experts.0, device: CPU + model.layers.3.mlp.experts.0.gate_proj, device: CPU + model.layers.3.mlp.experts.0.up_proj, device: CPU + model.layers.3.mlp.experts.0.down_proj, device: CPU + model.layers.3.mlp.experts.0.act, device: CPU + Module: model.layers.3.mlp.experts.1, device: CPU + model.layers.3.mlp.experts.1.gate_proj, device: CPU + model.layers.3.mlp.experts.1.up_proj, device: CPU + model.layers.3.mlp.experts.1.down_proj, device: CPU + model.layers.3.mlp.experts.1.act, device: CPU + Module: model.layers.3.mlp.experts.2, device: CPU + model.layers.3.mlp.experts.2.gate_proj, device: CPU + model.layers.3.mlp.experts.2.up_proj, device: CPU + model.layers.3.mlp.experts.2.down_proj, device: CPU + model.layers.3.mlp.experts.2.act, device: CPU + Module: model.layers.3.mlp.experts.3, device: CPU + model.layers.3.mlp.experts.3.gate_proj, device: CPU + model.layers.3.mlp.experts.3.up_proj, device: CPU + model.layers.3.mlp.experts.3.down_proj, device: CPU + model.layers.3.mlp.experts.3.act, device: CPU + Module: model.layers.3.mlp.experts.4, device: CPU + model.layers.3.mlp.experts.4.gate_proj, device: CPU + model.layers.3.mlp.experts.4.up_proj, device: CPU + model.layers.3.mlp.experts.4.down_proj, device: CPU + model.layers.3.mlp.experts.4.act, device: CPU + Module: model.layers.3.mlp.experts.5, device: CPU + model.layers.3.mlp.experts.5.gate_proj, device: CPU + model.layers.3.mlp.experts.5.up_proj, device: CPU + model.layers.3.mlp.experts.5.down_proj, device: CPU + model.layers.3.mlp.experts.5.act, device: CPU + Module: model.layers.3.mlp.experts.6, device: CPU + model.layers.3.mlp.experts.6.gate_proj, device: CPU + model.layers.3.mlp.experts.6.up_proj, device: CPU + model.layers.3.mlp.experts.6.down_proj, device: CPU + model.layers.3.mlp.experts.6.act, device: CPU + Module: model.layers.3.mlp.experts.7, device: CPU + model.layers.3.mlp.experts.7.gate_proj, device: CPU + model.layers.3.mlp.experts.7.up_proj, device: CPU + model.layers.3.mlp.experts.7.down_proj, device: CPU + model.layers.3.mlp.experts.7.act, device: CPU + Module: model.layers.3.mlp.experts.8, device: CPU + model.layers.3.mlp.experts.8.gate_proj, device: CPU + model.layers.3.mlp.experts.8.up_proj, device: CPU + model.layers.3.mlp.experts.8.down_proj, device: CPU + model.layers.3.mlp.experts.8.act, device: CPU + Module: model.layers.3.mlp.experts.9, device: CPU + model.layers.3.mlp.experts.9.gate_proj, device: CPU + model.layers.3.mlp.experts.9.up_proj, device: CPU + model.layers.3.mlp.experts.9.down_proj, device: CPU + model.layers.3.mlp.experts.9.act, device: CPU + Module: model.layers.3.mlp.experts.10, device: CPU + model.layers.3.mlp.experts.10.gate_proj, device: CPU + model.layers.3.mlp.experts.10.up_proj, device: CPU + model.layers.3.mlp.experts.10.down_proj, device: CPU + model.layers.3.mlp.experts.10.act, device: CPU + Module: model.layers.3.mlp.experts.11, device: CPU + model.layers.3.mlp.experts.11.gate_proj, device: CPU + model.layers.3.mlp.experts.11.up_proj, device: CPU + model.layers.3.mlp.experts.11.down_proj, device: CPU + model.layers.3.mlp.experts.11.act, device: CPU + Module: model.layers.3.mlp.experts.12, device: CPU + model.layers.3.mlp.experts.12.gate_proj, device: CPU + model.layers.3.mlp.experts.12.up_proj, device: CPU + model.layers.3.mlp.experts.12.down_proj, device: CPU + model.layers.3.mlp.experts.12.act, device: CPU + Module: model.layers.3.mlp.experts.13, device: CPU + model.layers.3.mlp.experts.13.gate_proj, device: CPU + model.layers.3.mlp.experts.13.up_proj, device: CPU + model.layers.3.mlp.experts.13.down_proj, device: CPU + model.layers.3.mlp.experts.13.act, device: CPU + Module: model.layers.3.mlp.experts.14, device: CPU + model.layers.3.mlp.experts.14.gate_proj, device: CPU + model.layers.3.mlp.experts.14.up_proj, device: CPU + model.layers.3.mlp.experts.14.down_proj, device: CPU + model.layers.3.mlp.experts.14.act, device: CPU + Module: model.layers.3.mlp.experts.15, device: CPU + model.layers.3.mlp.experts.15.gate_proj, device: CPU + model.layers.3.mlp.experts.15.up_proj, device: CPU + model.layers.3.mlp.experts.15.down_proj, device: CPU + model.layers.3.mlp.experts.15.act, device: CPU + Module: model.layers.3.mlp.experts.16, device: CPU + model.layers.3.mlp.experts.16.gate_proj, device: CPU + model.layers.3.mlp.experts.16.up_proj, device: CPU + model.layers.3.mlp.experts.16.down_proj, device: CPU + model.layers.3.mlp.experts.16.act, device: CPU + Module: model.layers.3.mlp.experts.17, device: CPU + model.layers.3.mlp.experts.17.gate_proj, device: CPU + model.layers.3.mlp.experts.17.up_proj, device: CPU + model.layers.3.mlp.experts.17.down_proj, device: CPU + model.layers.3.mlp.experts.17.act, device: CPU + Module: model.layers.3.mlp.experts.18, device: CPU + model.layers.3.mlp.experts.18.gate_proj, device: CPU + model.layers.3.mlp.experts.18.up_proj, device: CPU + model.layers.3.mlp.experts.18.down_proj, device: CPU + model.layers.3.mlp.experts.18.act, device: CPU + Module: model.layers.3.mlp.experts.19, device: CPU + model.layers.3.mlp.experts.19.gate_proj, device: CPU + model.layers.3.mlp.experts.19.up_proj, device: CPU + model.layers.3.mlp.experts.19.down_proj, device: CPU + model.layers.3.mlp.experts.19.act, device: CPU + Module: model.layers.3.mlp.experts.20, device: CPU + model.layers.3.mlp.experts.20.gate_proj, device: CPU + model.layers.3.mlp.experts.20.up_proj, device: CPU + model.layers.3.mlp.experts.20.down_proj, device: CPU + model.layers.3.mlp.experts.20.act, device: CPU + Module: model.layers.3.mlp.experts.21, device: CPU + model.layers.3.mlp.experts.21.gate_proj, device: CPU + model.layers.3.mlp.experts.21.up_proj, device: CPU + model.layers.3.mlp.experts.21.down_proj, device: CPU + model.layers.3.mlp.experts.21.act, device: CPU + Module: model.layers.3.mlp.experts.22, device: CPU + model.layers.3.mlp.experts.22.gate_proj, device: CPU + model.layers.3.mlp.experts.22.up_proj, device: CPU + model.layers.3.mlp.experts.22.down_proj, device: CPU + model.layers.3.mlp.experts.22.act, device: CPU + Module: model.layers.3.mlp.experts.23, device: CPU + model.layers.3.mlp.experts.23.gate_proj, device: CPU + model.layers.3.mlp.experts.23.up_proj, device: CPU + model.layers.3.mlp.experts.23.down_proj, device: CPU + model.layers.3.mlp.experts.23.act, device: CPU + Module: model.layers.3.mlp.experts.24, device: CPU + model.layers.3.mlp.experts.24.gate_proj, device: CPU + model.layers.3.mlp.experts.24.up_proj, device: CPU + model.layers.3.mlp.experts.24.down_proj, device: CPU + model.layers.3.mlp.experts.24.act, device: CPU + Module: model.layers.3.mlp.experts.25, device: CPU + model.layers.3.mlp.experts.25.gate_proj, device: CPU + model.layers.3.mlp.experts.25.up_proj, device: CPU + model.layers.3.mlp.experts.25.down_proj, device: CPU + model.layers.3.mlp.experts.25.act, device: CPU + Module: model.layers.3.mlp.experts.26, device: CPU + model.layers.3.mlp.experts.26.gate_proj, device: CPU + model.layers.3.mlp.experts.26.up_proj, device: CPU + model.layers.3.mlp.experts.26.down_proj, device: CPU + model.layers.3.mlp.experts.26.act, device: CPU + Module: model.layers.3.mlp.experts.27, device: CPU + model.layers.3.mlp.experts.27.gate_proj, device: CPU + model.layers.3.mlp.experts.27.up_proj, device: CPU + model.layers.3.mlp.experts.27.down_proj, device: CPU + model.layers.3.mlp.experts.27.act, device: CPU + Module: model.layers.3.mlp.experts.28, device: CPU + model.layers.3.mlp.experts.28.gate_proj, device: CPU + model.layers.3.mlp.experts.28.up_proj, device: CPU + model.layers.3.mlp.experts.28.down_proj, device: CPU + model.layers.3.mlp.experts.28.act, device: CPU + Module: model.layers.3.mlp.experts.29, device: CPU + model.layers.3.mlp.experts.29.gate_proj, device: CPU + model.layers.3.mlp.experts.29.up_proj, device: CPU + model.layers.3.mlp.experts.29.down_proj, device: CPU + model.layers.3.mlp.experts.29.act, device: CPU + Module: model.layers.3.mlp.experts.30, device: CPU + model.layers.3.mlp.experts.30.gate_proj, device: CPU + model.layers.3.mlp.experts.30.up_proj, device: CPU + model.layers.3.mlp.experts.30.down_proj, device: CPU + model.layers.3.mlp.experts.30.act, device: CPU + Module: model.layers.3.mlp.experts.31, device: CPU + model.layers.3.mlp.experts.31.gate_proj, device: CPU + model.layers.3.mlp.experts.31.up_proj, device: CPU + model.layers.3.mlp.experts.31.down_proj, device: CPU + model.layers.3.mlp.experts.31.act, device: CPU + Module: model.layers.3.mlp.experts.32, device: CPU + model.layers.3.mlp.experts.32.gate_proj, device: CPU + model.layers.3.mlp.experts.32.up_proj, device: CPU + model.layers.3.mlp.experts.32.down_proj, device: CPU + model.layers.3.mlp.experts.32.act, device: CPU + Module: model.layers.3.mlp.experts.33, device: CPU + model.layers.3.mlp.experts.33.gate_proj, device: CPU + model.layers.3.mlp.experts.33.up_proj, device: CPU + model.layers.3.mlp.experts.33.down_proj, device: CPU + model.layers.3.mlp.experts.33.act, device: CPU + Module: model.layers.3.mlp.experts.34, device: CPU + model.layers.3.mlp.experts.34.gate_proj, device: CPU + model.layers.3.mlp.experts.34.up_proj, device: CPU + model.layers.3.mlp.experts.34.down_proj, device: CPU + model.layers.3.mlp.experts.34.act, device: CPU + Module: model.layers.3.mlp.experts.35, device: CPU + model.layers.3.mlp.experts.35.gate_proj, device: CPU + model.layers.3.mlp.experts.35.up_proj, device: CPU + model.layers.3.mlp.experts.35.down_proj, device: CPU + model.layers.3.mlp.experts.35.act, device: CPU + Module: model.layers.3.mlp.experts.36, device: CPU + model.layers.3.mlp.experts.36.gate_proj, device: CPU + model.layers.3.mlp.experts.36.up_proj, device: CPU + model.layers.3.mlp.experts.36.down_proj, device: CPU + model.layers.3.mlp.experts.36.act, device: CPU + Module: model.layers.3.mlp.experts.37, device: CPU + model.layers.3.mlp.experts.37.gate_proj, device: CPU + model.layers.3.mlp.experts.37.up_proj, device: CPU + model.layers.3.mlp.experts.37.down_proj, device: CPU + model.layers.3.mlp.experts.37.act, device: CPU + Module: model.layers.3.mlp.experts.38, device: CPU + model.layers.3.mlp.experts.38.gate_proj, device: CPU + model.layers.3.mlp.experts.38.up_proj, device: CPU + model.layers.3.mlp.experts.38.down_proj, device: CPU + model.layers.3.mlp.experts.38.act, device: CPU + Module: model.layers.3.mlp.experts.39, device: CPU + model.layers.3.mlp.experts.39.gate_proj, device: CPU + model.layers.3.mlp.experts.39.up_proj, device: CPU + model.layers.3.mlp.experts.39.down_proj, device: CPU + model.layers.3.mlp.experts.39.act, device: CPU + Module: model.layers.3.mlp.experts.40, device: CPU + model.layers.3.mlp.experts.40.gate_proj, device: CPU + model.layers.3.mlp.experts.40.up_proj, device: CPU + model.layers.3.mlp.experts.40.down_proj, device: CPU + model.layers.3.mlp.experts.40.act, device: CPU + Module: model.layers.3.mlp.experts.41, device: CPU + model.layers.3.mlp.experts.41.gate_proj, device: CPU + model.layers.3.mlp.experts.41.up_proj, device: CPU + model.layers.3.mlp.experts.41.down_proj, device: CPU + model.layers.3.mlp.experts.41.act, device: CPU + Module: model.layers.3.mlp.experts.42, device: CPU + model.layers.3.mlp.experts.42.gate_proj, device: CPU + model.layers.3.mlp.experts.42.up_proj, device: CPU + model.layers.3.mlp.experts.42.down_proj, device: CPU + model.layers.3.mlp.experts.42.act, device: CPU + Module: model.layers.3.mlp.experts.43, device: CPU + model.layers.3.mlp.experts.43.gate_proj, device: CPU + model.layers.3.mlp.experts.43.up_proj, device: CPU + model.layers.3.mlp.experts.43.down_proj, device: CPU + model.layers.3.mlp.experts.43.act, device: CPU + Module: model.layers.3.mlp.experts.44, device: CPU + model.layers.3.mlp.experts.44.gate_proj, device: CPU + model.layers.3.mlp.experts.44.up_proj, device: CPU + model.layers.3.mlp.experts.44.down_proj, device: CPU + model.layers.3.mlp.experts.44.act, device: CPU + Module: model.layers.3.mlp.experts.45, device: CPU + model.layers.3.mlp.experts.45.gate_proj, device: CPU + model.layers.3.mlp.experts.45.up_proj, device: CPU + model.layers.3.mlp.experts.45.down_proj, device: CPU + model.layers.3.mlp.experts.45.act, device: CPU + Module: model.layers.3.mlp.experts.46, device: CPU + model.layers.3.mlp.experts.46.gate_proj, device: CPU + model.layers.3.mlp.experts.46.up_proj, device: CPU + model.layers.3.mlp.experts.46.down_proj, device: CPU + model.layers.3.mlp.experts.46.act, device: CPU + Module: model.layers.3.mlp.experts.47, device: CPU + model.layers.3.mlp.experts.47.gate_proj, device: CPU + model.layers.3.mlp.experts.47.up_proj, device: CPU + model.layers.3.mlp.experts.47.down_proj, device: CPU + model.layers.3.mlp.experts.47.act, device: CPU + Module: model.layers.3.mlp.experts.48, device: CPU + model.layers.3.mlp.experts.48.gate_proj, device: CPU + model.layers.3.mlp.experts.48.up_proj, device: CPU + model.layers.3.mlp.experts.48.down_proj, device: CPU + model.layers.3.mlp.experts.48.act, device: CPU + Module: model.layers.3.mlp.experts.49, device: CPU + model.layers.3.mlp.experts.49.gate_proj, device: CPU + model.layers.3.mlp.experts.49.up_proj, device: CPU + model.layers.3.mlp.experts.49.down_proj, device: CPU + model.layers.3.mlp.experts.49.act, device: CPU + Module: model.layers.3.mlp.experts.50, device: CPU + model.layers.3.mlp.experts.50.gate_proj, device: CPU + model.layers.3.mlp.experts.50.up_proj, device: CPU + model.layers.3.mlp.experts.50.down_proj, device: CPU + model.layers.3.mlp.experts.50.act, device: CPU + Module: model.layers.3.mlp.experts.51, device: CPU + model.layers.3.mlp.experts.51.gate_proj, device: CPU + model.layers.3.mlp.experts.51.up_proj, device: CPU + model.layers.3.mlp.experts.51.down_proj, device: CPU + model.layers.3.mlp.experts.51.act, device: CPU + Module: model.layers.3.mlp.experts.52, device: CPU + model.layers.3.mlp.experts.52.gate_proj, device: CPU + model.layers.3.mlp.experts.52.up_proj, device: CPU + model.layers.3.mlp.experts.52.down_proj, device: CPU + model.layers.3.mlp.experts.52.act, device: CPU + Module: model.layers.3.mlp.experts.53, device: CPU + model.layers.3.mlp.experts.53.gate_proj, device: CPU + model.layers.3.mlp.experts.53.up_proj, device: CPU + model.layers.3.mlp.experts.53.down_proj, device: CPU + model.layers.3.mlp.experts.53.act, device: CPU + Module: model.layers.3.mlp.experts.54, device: CPU + model.layers.3.mlp.experts.54.gate_proj, device: CPU + model.layers.3.mlp.experts.54.up_proj, device: CPU + model.layers.3.mlp.experts.54.down_proj, device: CPU + model.layers.3.mlp.experts.54.act, device: CPU + Module: model.layers.3.mlp.experts.55, device: CPU + model.layers.3.mlp.experts.55.gate_proj, device: CPU + model.layers.3.mlp.experts.55.up_proj, device: CPU + model.layers.3.mlp.experts.55.down_proj, device: CPU + model.layers.3.mlp.experts.55.act, device: CPU + Module: model.layers.3.mlp.experts.56, device: CPU + model.layers.3.mlp.experts.56.gate_proj, device: CPU + model.layers.3.mlp.experts.56.up_proj, device: CPU + model.layers.3.mlp.experts.56.down_proj, device: CPU + model.layers.3.mlp.experts.56.act, device: CPU + Module: model.layers.3.mlp.experts.57, device: CPU + model.layers.3.mlp.experts.57.gate_proj, device: CPU + model.layers.3.mlp.experts.57.up_proj, device: CPU + model.layers.3.mlp.experts.57.down_proj, device: CPU + model.layers.3.mlp.experts.57.act, device: CPU + Module: model.layers.3.mlp.experts.58, device: CPU + model.layers.3.mlp.experts.58.gate_proj, device: CPU + model.layers.3.mlp.experts.58.up_proj, device: CPU + model.layers.3.mlp.experts.58.down_proj, device: CPU + model.layers.3.mlp.experts.58.act, device: CPU + Module: model.layers.3.mlp.experts.59, device: CPU + model.layers.3.mlp.experts.59.gate_proj, device: CPU + model.layers.3.mlp.experts.59.up_proj, device: CPU + model.layers.3.mlp.experts.59.down_proj, device: CPU + model.layers.3.mlp.experts.59.act, device: CPU + Module: model.layers.3.mlp.experts.60, device: CPU + model.layers.3.mlp.experts.60.gate_proj, device: CPU + model.layers.3.mlp.experts.60.up_proj, device: CPU + model.layers.3.mlp.experts.60.down_proj, device: CPU + model.layers.3.mlp.experts.60.act, device: CPU + Module: model.layers.3.mlp.experts.61, device: CPU + model.layers.3.mlp.experts.61.gate_proj, device: CPU + model.layers.3.mlp.experts.61.up_proj, device: CPU + model.layers.3.mlp.experts.61.down_proj, device: CPU + model.layers.3.mlp.experts.61.act, device: CPU + Module: model.layers.3.mlp.experts.62, device: CPU + model.layers.3.mlp.experts.62.gate_proj, device: CPU + model.layers.3.mlp.experts.62.up_proj, device: CPU + model.layers.3.mlp.experts.62.down_proj, device: CPU + model.layers.3.mlp.experts.62.act, device: CPU + Module: model.layers.3.mlp.experts.63, device: CPU + model.layers.3.mlp.experts.63.gate_proj, device: CPU + model.layers.3.mlp.experts.63.up_proj, device: CPU + model.layers.3.mlp.experts.63.down_proj, device: CPU + model.layers.3.mlp.experts.63.act, device: CPU + Module: model.layers.3.mlp.gate, device: CPU + model.layers.3.mlp.gate.weight, device: CPU + Module: model.layers.3.mlp.shared_experts, device: CPU + model.layers.3.mlp.shared_experts.gate_proj, device: CPU + model.layers.3.mlp.shared_experts.up_proj, device: CPU + model.layers.3.mlp.shared_experts.down_proj, device: CPU + model.layers.3.mlp.shared_experts.act, device: CPU + model.layers.3.input_layernorm, device: CPU + model.layers.3.post_attention_layernorm, device: CPU + Module: model.layers.4, device: CPU + Module: model.layers.4.self_attn, device: CPU + model.layers.4.self_attn.q_proj, device: CPU + model.layers.4.self_attn.k_proj, device: CPU + model.layers.4.self_attn.v_proj, device: CPU + model.layers.4.self_attn.o_proj, device: CPU + model.layers.4.self_attn.q_rope, device: CPU + model.layers.4.self_attn.k_rope, device: CPU + Module: model.layers.4.mlp, device: CPU + Module: model.layers.4.mlp.experts, device: CPU + Module: model.layers.4.mlp.experts.0, device: CPU + model.layers.4.mlp.experts.0.gate_proj, device: CPU + model.layers.4.mlp.experts.0.up_proj, device: CPU + model.layers.4.mlp.experts.0.down_proj, device: CPU + model.layers.4.mlp.experts.0.act, device: CPU + Module: model.layers.4.mlp.experts.1, device: CPU + model.layers.4.mlp.experts.1.gate_proj, device: CPU + model.layers.4.mlp.experts.1.up_proj, device: CPU + model.layers.4.mlp.experts.1.down_proj, device: CPU + model.layers.4.mlp.experts.1.act, device: CPU + Module: model.layers.4.mlp.experts.2, device: CPU + model.layers.4.mlp.experts.2.gate_proj, device: CPU + model.layers.4.mlp.experts.2.up_proj, device: CPU + model.layers.4.mlp.experts.2.down_proj, device: CPU + model.layers.4.mlp.experts.2.act, device: CPU + Module: model.layers.4.mlp.experts.3, device: CPU + model.layers.4.mlp.experts.3.gate_proj, device: CPU + model.layers.4.mlp.experts.3.up_proj, device: CPU + model.layers.4.mlp.experts.3.down_proj, device: CPU + model.layers.4.mlp.experts.3.act, device: CPU + Module: model.layers.4.mlp.experts.4, device: CPU + model.layers.4.mlp.experts.4.gate_proj, device: CPU + model.layers.4.mlp.experts.4.up_proj, device: CPU + model.layers.4.mlp.experts.4.down_proj, device: CPU + model.layers.4.mlp.experts.4.act, device: CPU + Module: model.layers.4.mlp.experts.5, device: CPU + model.layers.4.mlp.experts.5.gate_proj, device: CPU + model.layers.4.mlp.experts.5.up_proj, device: CPU + model.layers.4.mlp.experts.5.down_proj, device: CPU + model.layers.4.mlp.experts.5.act, device: CPU + Module: model.layers.4.mlp.experts.6, device: CPU + model.layers.4.mlp.experts.6.gate_proj, device: CPU + model.layers.4.mlp.experts.6.up_proj, device: CPU + model.layers.4.mlp.experts.6.down_proj, device: CPU + model.layers.4.mlp.experts.6.act, device: CPU + Module: model.layers.4.mlp.experts.7, device: CPU + model.layers.4.mlp.experts.7.gate_proj, device: CPU + model.layers.4.mlp.experts.7.up_proj, device: CPU + model.layers.4.mlp.experts.7.down_proj, device: CPU + model.layers.4.mlp.experts.7.act, device: CPU + Module: model.layers.4.mlp.experts.8, device: CPU + model.layers.4.mlp.experts.8.gate_proj, device: CPU + model.layers.4.mlp.experts.8.up_proj, device: CPU + model.layers.4.mlp.experts.8.down_proj, device: CPU + model.layers.4.mlp.experts.8.act, device: CPU + Module: model.layers.4.mlp.experts.9, device: CPU + model.layers.4.mlp.experts.9.gate_proj, device: CPU + model.layers.4.mlp.experts.9.up_proj, device: CPU + model.layers.4.mlp.experts.9.down_proj, device: CPU + model.layers.4.mlp.experts.9.act, device: CPU + Module: model.layers.4.mlp.experts.10, device: CPU + model.layers.4.mlp.experts.10.gate_proj, device: CPU + model.layers.4.mlp.experts.10.up_proj, device: CPU + model.layers.4.mlp.experts.10.down_proj, device: CPU + model.layers.4.mlp.experts.10.act, device: CPU + Module: model.layers.4.mlp.experts.11, device: CPU + model.layers.4.mlp.experts.11.gate_proj, device: CPU + model.layers.4.mlp.experts.11.up_proj, device: CPU + model.layers.4.mlp.experts.11.down_proj, device: CPU + model.layers.4.mlp.experts.11.act, device: CPU + Module: model.layers.4.mlp.experts.12, device: CPU + model.layers.4.mlp.experts.12.gate_proj, device: CPU + model.layers.4.mlp.experts.12.up_proj, device: CPU + model.layers.4.mlp.experts.12.down_proj, device: CPU + model.layers.4.mlp.experts.12.act, device: CPU + Module: model.layers.4.mlp.experts.13, device: CPU + model.layers.4.mlp.experts.13.gate_proj, device: CPU + model.layers.4.mlp.experts.13.up_proj, device: CPU + model.layers.4.mlp.experts.13.down_proj, device: CPU + model.layers.4.mlp.experts.13.act, device: CPU + Module: model.layers.4.mlp.experts.14, device: CPU + model.layers.4.mlp.experts.14.gate_proj, device: CPU + model.layers.4.mlp.experts.14.up_proj, device: CPU + model.layers.4.mlp.experts.14.down_proj, device: CPU + model.layers.4.mlp.experts.14.act, device: CPU + Module: model.layers.4.mlp.experts.15, device: CPU + model.layers.4.mlp.experts.15.gate_proj, device: CPU + model.layers.4.mlp.experts.15.up_proj, device: CPU + model.layers.4.mlp.experts.15.down_proj, device: CPU + model.layers.4.mlp.experts.15.act, device: CPU + Module: model.layers.4.mlp.experts.16, device: CPU + model.layers.4.mlp.experts.16.gate_proj, device: CPU + model.layers.4.mlp.experts.16.up_proj, device: CPU + model.layers.4.mlp.experts.16.down_proj, device: CPU + model.layers.4.mlp.experts.16.act, device: CPU + Module: model.layers.4.mlp.experts.17, device: CPU + model.layers.4.mlp.experts.17.gate_proj, device: CPU + model.layers.4.mlp.experts.17.up_proj, device: CPU + model.layers.4.mlp.experts.17.down_proj, device: CPU + model.layers.4.mlp.experts.17.act, device: CPU + Module: model.layers.4.mlp.experts.18, device: CPU + model.layers.4.mlp.experts.18.gate_proj, device: CPU + model.layers.4.mlp.experts.18.up_proj, device: CPU + model.layers.4.mlp.experts.18.down_proj, device: CPU + model.layers.4.mlp.experts.18.act, device: CPU + Module: model.layers.4.mlp.experts.19, device: CPU + model.layers.4.mlp.experts.19.gate_proj, device: CPU + model.layers.4.mlp.experts.19.up_proj, device: CPU + model.layers.4.mlp.experts.19.down_proj, device: CPU + model.layers.4.mlp.experts.19.act, device: CPU + Module: model.layers.4.mlp.experts.20, device: CPU + model.layers.4.mlp.experts.20.gate_proj, device: CPU + model.layers.4.mlp.experts.20.up_proj, device: CPU + model.layers.4.mlp.experts.20.down_proj, device: CPU + model.layers.4.mlp.experts.20.act, device: CPU + Module: model.layers.4.mlp.experts.21, device: CPU + model.layers.4.mlp.experts.21.gate_proj, device: CPU + model.layers.4.mlp.experts.21.up_proj, device: CPU + model.layers.4.mlp.experts.21.down_proj, device: CPU + model.layers.4.mlp.experts.21.act, device: CPU + Module: model.layers.4.mlp.experts.22, device: CPU + model.layers.4.mlp.experts.22.gate_proj, device: CPU + model.layers.4.mlp.experts.22.up_proj, device: CPU + model.layers.4.mlp.experts.22.down_proj, device: CPU + model.layers.4.mlp.experts.22.act, device: CPU + Module: model.layers.4.mlp.experts.23, device: CPU + model.layers.4.mlp.experts.23.gate_proj, device: CPU + model.layers.4.mlp.experts.23.up_proj, device: CPU + model.layers.4.mlp.experts.23.down_proj, device: CPU + model.layers.4.mlp.experts.23.act, device: CPU + Module: model.layers.4.mlp.experts.24, device: CPU + model.layers.4.mlp.experts.24.gate_proj, device: CPU + model.layers.4.mlp.experts.24.up_proj, device: CPU + model.layers.4.mlp.experts.24.down_proj, device: CPU + model.layers.4.mlp.experts.24.act, device: CPU + Module: model.layers.4.mlp.experts.25, device: CPU + model.layers.4.mlp.experts.25.gate_proj, device: CPU + model.layers.4.mlp.experts.25.up_proj, device: CPU + model.layers.4.mlp.experts.25.down_proj, device: CPU + model.layers.4.mlp.experts.25.act, device: CPU + Module: model.layers.4.mlp.experts.26, device: CPU + model.layers.4.mlp.experts.26.gate_proj, device: CPU + model.layers.4.mlp.experts.26.up_proj, device: CPU + model.layers.4.mlp.experts.26.down_proj, device: CPU + model.layers.4.mlp.experts.26.act, device: CPU + Module: model.layers.4.mlp.experts.27, device: CPU + model.layers.4.mlp.experts.27.gate_proj, device: CPU + model.layers.4.mlp.experts.27.up_proj, device: CPU + model.layers.4.mlp.experts.27.down_proj, device: CPU + model.layers.4.mlp.experts.27.act, device: CPU + Module: model.layers.4.mlp.experts.28, device: CPU + model.layers.4.mlp.experts.28.gate_proj, device: CPU + model.layers.4.mlp.experts.28.up_proj, device: CPU + model.layers.4.mlp.experts.28.down_proj, device: CPU + model.layers.4.mlp.experts.28.act, device: CPU + Module: model.layers.4.mlp.experts.29, device: CPU + model.layers.4.mlp.experts.29.gate_proj, device: CPU + model.layers.4.mlp.experts.29.up_proj, device: CPU + model.layers.4.mlp.experts.29.down_proj, device: CPU + model.layers.4.mlp.experts.29.act, device: CPU + Module: model.layers.4.mlp.experts.30, device: CPU + model.layers.4.mlp.experts.30.gate_proj, device: CPU + model.layers.4.mlp.experts.30.up_proj, device: CPU + model.layers.4.mlp.experts.30.down_proj, device: CPU + model.layers.4.mlp.experts.30.act, device: CPU + Module: model.layers.4.mlp.experts.31, device: CPU + model.layers.4.mlp.experts.31.gate_proj, device: CPU + model.layers.4.mlp.experts.31.up_proj, device: CPU + model.layers.4.mlp.experts.31.down_proj, device: CPU + model.layers.4.mlp.experts.31.act, device: CPU + Module: model.layers.4.mlp.experts.32, device: CPU + model.layers.4.mlp.experts.32.gate_proj, device: CPU + model.layers.4.mlp.experts.32.up_proj, device: CPU + model.layers.4.mlp.experts.32.down_proj, device: CPU + model.layers.4.mlp.experts.32.act, device: CPU + Module: model.layers.4.mlp.experts.33, device: CPU + model.layers.4.mlp.experts.33.gate_proj, device: CPU + model.layers.4.mlp.experts.33.up_proj, device: CPU + model.layers.4.mlp.experts.33.down_proj, device: CPU + model.layers.4.mlp.experts.33.act, device: CPU + Module: model.layers.4.mlp.experts.34, device: CPU + model.layers.4.mlp.experts.34.gate_proj, device: CPU + model.layers.4.mlp.experts.34.up_proj, device: CPU + model.layers.4.mlp.experts.34.down_proj, device: CPU + model.layers.4.mlp.experts.34.act, device: CPU + Module: model.layers.4.mlp.experts.35, device: CPU + model.layers.4.mlp.experts.35.gate_proj, device: CPU + model.layers.4.mlp.experts.35.up_proj, device: CPU + model.layers.4.mlp.experts.35.down_proj, device: CPU + model.layers.4.mlp.experts.35.act, device: CPU + Module: model.layers.4.mlp.experts.36, device: CPU + model.layers.4.mlp.experts.36.gate_proj, device: CPU + model.layers.4.mlp.experts.36.up_proj, device: CPU + model.layers.4.mlp.experts.36.down_proj, device: CPU + model.layers.4.mlp.experts.36.act, device: CPU + Module: model.layers.4.mlp.experts.37, device: CPU + model.layers.4.mlp.experts.37.gate_proj, device: CPU + model.layers.4.mlp.experts.37.up_proj, device: CPU + model.layers.4.mlp.experts.37.down_proj, device: CPU + model.layers.4.mlp.experts.37.act, device: CPU + Module: model.layers.4.mlp.experts.38, device: CPU + model.layers.4.mlp.experts.38.gate_proj, device: CPU + model.layers.4.mlp.experts.38.up_proj, device: CPU + model.layers.4.mlp.experts.38.down_proj, device: CPU + model.layers.4.mlp.experts.38.act, device: CPU + Module: model.layers.4.mlp.experts.39, device: CPU + model.layers.4.mlp.experts.39.gate_proj, device: CPU + model.layers.4.mlp.experts.39.up_proj, device: CPU + model.layers.4.mlp.experts.39.down_proj, device: CPU + model.layers.4.mlp.experts.39.act, device: CPU + Module: model.layers.4.mlp.experts.40, device: CPU + model.layers.4.mlp.experts.40.gate_proj, device: CPU + model.layers.4.mlp.experts.40.up_proj, device: CPU + model.layers.4.mlp.experts.40.down_proj, device: CPU + model.layers.4.mlp.experts.40.act, device: CPU + Module: model.layers.4.mlp.experts.41, device: CPU + model.layers.4.mlp.experts.41.gate_proj, device: CPU + model.layers.4.mlp.experts.41.up_proj, device: CPU + model.layers.4.mlp.experts.41.down_proj, device: CPU + model.layers.4.mlp.experts.41.act, device: CPU + Module: model.layers.4.mlp.experts.42, device: CPU + model.layers.4.mlp.experts.42.gate_proj, device: CPU + model.layers.4.mlp.experts.42.up_proj, device: CPU + model.layers.4.mlp.experts.42.down_proj, device: CPU + model.layers.4.mlp.experts.42.act, device: CPU + Module: model.layers.4.mlp.experts.43, device: CPU + model.layers.4.mlp.experts.43.gate_proj, device: CPU + model.layers.4.mlp.experts.43.up_proj, device: CPU + model.layers.4.mlp.experts.43.down_proj, device: CPU + model.layers.4.mlp.experts.43.act, device: CPU + Module: model.layers.4.mlp.experts.44, device: CPU + model.layers.4.mlp.experts.44.gate_proj, device: CPU + model.layers.4.mlp.experts.44.up_proj, device: CPU + model.layers.4.mlp.experts.44.down_proj, device: CPU + model.layers.4.mlp.experts.44.act, device: CPU + Module: model.layers.4.mlp.experts.45, device: CPU + model.layers.4.mlp.experts.45.gate_proj, device: CPU + model.layers.4.mlp.experts.45.up_proj, device: CPU + model.layers.4.mlp.experts.45.down_proj, device: CPU + model.layers.4.mlp.experts.45.act, device: CPU + Module: model.layers.4.mlp.experts.46, device: CPU + model.layers.4.mlp.experts.46.gate_proj, device: CPU + model.layers.4.mlp.experts.46.up_proj, device: CPU + model.layers.4.mlp.experts.46.down_proj, device: CPU + model.layers.4.mlp.experts.46.act, device: CPU + Module: model.layers.4.mlp.experts.47, device: CPU + model.layers.4.mlp.experts.47.gate_proj, device: CPU + model.layers.4.mlp.experts.47.up_proj, device: CPU + model.layers.4.mlp.experts.47.down_proj, device: CPU + model.layers.4.mlp.experts.47.act, device: CPU + Module: model.layers.4.mlp.experts.48, device: CPU + model.layers.4.mlp.experts.48.gate_proj, device: CPU + model.layers.4.mlp.experts.48.up_proj, device: CPU + model.layers.4.mlp.experts.48.down_proj, device: CPU + model.layers.4.mlp.experts.48.act, device: CPU + Module: model.layers.4.mlp.experts.49, device: CPU + model.layers.4.mlp.experts.49.gate_proj, device: CPU + model.layers.4.mlp.experts.49.up_proj, device: CPU + model.layers.4.mlp.experts.49.down_proj, device: CPU + model.layers.4.mlp.experts.49.act, device: CPU + Module: model.layers.4.mlp.experts.50, device: CPU + model.layers.4.mlp.experts.50.gate_proj, device: CPU + model.layers.4.mlp.experts.50.up_proj, device: CPU + model.layers.4.mlp.experts.50.down_proj, device: CPU + model.layers.4.mlp.experts.50.act, device: CPU + Module: model.layers.4.mlp.experts.51, device: CPU + model.layers.4.mlp.experts.51.gate_proj, device: CPU + model.layers.4.mlp.experts.51.up_proj, device: CPU + model.layers.4.mlp.experts.51.down_proj, device: CPU + model.layers.4.mlp.experts.51.act, device: CPU + Module: model.layers.4.mlp.experts.52, device: CPU + model.layers.4.mlp.experts.52.gate_proj, device: CPU + model.layers.4.mlp.experts.52.up_proj, device: CPU + model.layers.4.mlp.experts.52.down_proj, device: CPU + model.layers.4.mlp.experts.52.act, device: CPU + Module: model.layers.4.mlp.experts.53, device: CPU + model.layers.4.mlp.experts.53.gate_proj, device: CPU + model.layers.4.mlp.experts.53.up_proj, device: CPU + model.layers.4.mlp.experts.53.down_proj, device: CPU + model.layers.4.mlp.experts.53.act, device: CPU + Module: model.layers.4.mlp.experts.54, device: CPU + model.layers.4.mlp.experts.54.gate_proj, device: CPU + model.layers.4.mlp.experts.54.up_proj, device: CPU + model.layers.4.mlp.experts.54.down_proj, device: CPU + model.layers.4.mlp.experts.54.act, device: CPU + Module: model.layers.4.mlp.experts.55, device: CPU + model.layers.4.mlp.experts.55.gate_proj, device: CPU + model.layers.4.mlp.experts.55.up_proj, device: CPU + model.layers.4.mlp.experts.55.down_proj, device: CPU + model.layers.4.mlp.experts.55.act, device: CPU + Module: model.layers.4.mlp.experts.56, device: CPU + model.layers.4.mlp.experts.56.gate_proj, device: CPU + model.layers.4.mlp.experts.56.up_proj, device: CPU + model.layers.4.mlp.experts.56.down_proj, device: CPU + model.layers.4.mlp.experts.56.act, device: CPU + Module: model.layers.4.mlp.experts.57, device: CPU + model.layers.4.mlp.experts.57.gate_proj, device: CPU + model.layers.4.mlp.experts.57.up_proj, device: CPU + model.layers.4.mlp.experts.57.down_proj, device: CPU + model.layers.4.mlp.experts.57.act, device: CPU + Module: model.layers.4.mlp.experts.58, device: CPU + model.layers.4.mlp.experts.58.gate_proj, device: CPU + model.layers.4.mlp.experts.58.up_proj, device: CPU + model.layers.4.mlp.experts.58.down_proj, device: CPU + model.layers.4.mlp.experts.58.act, device: CPU + Module: model.layers.4.mlp.experts.59, device: CPU + model.layers.4.mlp.experts.59.gate_proj, device: CPU + model.layers.4.mlp.experts.59.up_proj, device: CPU + model.layers.4.mlp.experts.59.down_proj, device: CPU + model.layers.4.mlp.experts.59.act, device: CPU + Module: model.layers.4.mlp.experts.60, device: CPU + model.layers.4.mlp.experts.60.gate_proj, device: CPU + model.layers.4.mlp.experts.60.up_proj, device: CPU + model.layers.4.mlp.experts.60.down_proj, device: CPU + model.layers.4.mlp.experts.60.act, device: CPU + Module: model.layers.4.mlp.experts.61, device: CPU + model.layers.4.mlp.experts.61.gate_proj, device: CPU + model.layers.4.mlp.experts.61.up_proj, device: CPU + model.layers.4.mlp.experts.61.down_proj, device: CPU + model.layers.4.mlp.experts.61.act, device: CPU + Module: model.layers.4.mlp.experts.62, device: CPU + model.layers.4.mlp.experts.62.gate_proj, device: CPU + model.layers.4.mlp.experts.62.up_proj, device: CPU + model.layers.4.mlp.experts.62.down_proj, device: CPU + model.layers.4.mlp.experts.62.act, device: CPU + Module: model.layers.4.mlp.experts.63, device: CPU + model.layers.4.mlp.experts.63.gate_proj, device: CPU + model.layers.4.mlp.experts.63.up_proj, device: CPU + model.layers.4.mlp.experts.63.down_proj, device: CPU + model.layers.4.mlp.experts.63.act, device: CPU + Module: model.layers.4.mlp.gate, device: CPU + model.layers.4.mlp.gate.weight, device: CPU + Module: model.layers.4.mlp.shared_experts, device: CPU + model.layers.4.mlp.shared_experts.gate_proj, device: CPU + model.layers.4.mlp.shared_experts.up_proj, device: CPU + model.layers.4.mlp.shared_experts.down_proj, device: CPU + model.layers.4.mlp.shared_experts.act, device: CPU + model.layers.4.input_layernorm, device: CPU + model.layers.4.post_attention_layernorm, device: CPU + Module: model.layers.5, device: CPU + Module: model.layers.5.self_attn, device: CPU + model.layers.5.self_attn.q_proj, device: CPU + model.layers.5.self_attn.k_proj, device: CPU + model.layers.5.self_attn.v_proj, device: CPU + model.layers.5.self_attn.o_proj, device: CPU + model.layers.5.self_attn.q_rope, device: CPU + model.layers.5.self_attn.k_rope, device: CPU + Module: model.layers.5.mlp, device: CPU + Module: model.layers.5.mlp.experts, device: CPU + Module: model.layers.5.mlp.experts.0, device: CPU + model.layers.5.mlp.experts.0.gate_proj, device: CPU + model.layers.5.mlp.experts.0.up_proj, device: CPU + model.layers.5.mlp.experts.0.down_proj, device: CPU + model.layers.5.mlp.experts.0.act, device: CPU + Module: model.layers.5.mlp.experts.1, device: CPU + model.layers.5.mlp.experts.1.gate_proj, device: CPU + model.layers.5.mlp.experts.1.up_proj, device: CPU + model.layers.5.mlp.experts.1.down_proj, device: CPU + model.layers.5.mlp.experts.1.act, device: CPU + Module: model.layers.5.mlp.experts.2, device: CPU + model.layers.5.mlp.experts.2.gate_proj, device: CPU + model.layers.5.mlp.experts.2.up_proj, device: CPU + model.layers.5.mlp.experts.2.down_proj, device: CPU + model.layers.5.mlp.experts.2.act, device: CPU + Module: model.layers.5.mlp.experts.3, device: CPU + model.layers.5.mlp.experts.3.gate_proj, device: CPU + model.layers.5.mlp.experts.3.up_proj, device: CPU + model.layers.5.mlp.experts.3.down_proj, device: CPU + model.layers.5.mlp.experts.3.act, device: CPU + Module: model.layers.5.mlp.experts.4, device: CPU + model.layers.5.mlp.experts.4.gate_proj, device: CPU + model.layers.5.mlp.experts.4.up_proj, device: CPU + model.layers.5.mlp.experts.4.down_proj, device: CPU + model.layers.5.mlp.experts.4.act, device: CPU + Module: model.layers.5.mlp.experts.5, device: CPU + model.layers.5.mlp.experts.5.gate_proj, device: CPU + model.layers.5.mlp.experts.5.up_proj, device: CPU + model.layers.5.mlp.experts.5.down_proj, device: CPU + model.layers.5.mlp.experts.5.act, device: CPU + Module: model.layers.5.mlp.experts.6, device: CPU + model.layers.5.mlp.experts.6.gate_proj, device: CPU + model.layers.5.mlp.experts.6.up_proj, device: CPU + model.layers.5.mlp.experts.6.down_proj, device: CPU + model.layers.5.mlp.experts.6.act, device: CPU + Module: model.layers.5.mlp.experts.7, device: CPU + model.layers.5.mlp.experts.7.gate_proj, device: CPU + model.layers.5.mlp.experts.7.up_proj, device: CPU + model.layers.5.mlp.experts.7.down_proj, device: CPU + model.layers.5.mlp.experts.7.act, device: CPU + Module: model.layers.5.mlp.experts.8, device: CPU + model.layers.5.mlp.experts.8.gate_proj, device: CPU + model.layers.5.mlp.experts.8.up_proj, device: CPU + model.layers.5.mlp.experts.8.down_proj, device: CPU + model.layers.5.mlp.experts.8.act, device: CPU + Module: model.layers.5.mlp.experts.9, device: CPU + model.layers.5.mlp.experts.9.gate_proj, device: CPU + model.layers.5.mlp.experts.9.up_proj, device: CPU + model.layers.5.mlp.experts.9.down_proj, device: CPU + model.layers.5.mlp.experts.9.act, device: CPU + Module: model.layers.5.mlp.experts.10, device: CPU + model.layers.5.mlp.experts.10.gate_proj, device: CPU + model.layers.5.mlp.experts.10.up_proj, device: CPU + model.layers.5.mlp.experts.10.down_proj, device: CPU + model.layers.5.mlp.experts.10.act, device: CPU + Module: model.layers.5.mlp.experts.11, device: CPU + model.layers.5.mlp.experts.11.gate_proj, device: CPU + model.layers.5.mlp.experts.11.up_proj, device: CPU + model.layers.5.mlp.experts.11.down_proj, device: CPU + model.layers.5.mlp.experts.11.act, device: CPU + Module: model.layers.5.mlp.experts.12, device: CPU + model.layers.5.mlp.experts.12.gate_proj, device: CPU + model.layers.5.mlp.experts.12.up_proj, device: CPU + model.layers.5.mlp.experts.12.down_proj, device: CPU + model.layers.5.mlp.experts.12.act, device: CPU + Module: model.layers.5.mlp.experts.13, device: CPU + model.layers.5.mlp.experts.13.gate_proj, device: CPU + model.layers.5.mlp.experts.13.up_proj, device: CPU + model.layers.5.mlp.experts.13.down_proj, device: CPU + model.layers.5.mlp.experts.13.act, device: CPU + Module: model.layers.5.mlp.experts.14, device: CPU + model.layers.5.mlp.experts.14.gate_proj, device: CPU + model.layers.5.mlp.experts.14.up_proj, device: CPU + model.layers.5.mlp.experts.14.down_proj, device: CPU + model.layers.5.mlp.experts.14.act, device: CPU + Module: model.layers.5.mlp.experts.15, device: CPU + model.layers.5.mlp.experts.15.gate_proj, device: CPU + model.layers.5.mlp.experts.15.up_proj, device: CPU + model.layers.5.mlp.experts.15.down_proj, device: CPU + model.layers.5.mlp.experts.15.act, device: CPU + Module: model.layers.5.mlp.experts.16, device: CPU + model.layers.5.mlp.experts.16.gate_proj, device: CPU + model.layers.5.mlp.experts.16.up_proj, device: CPU + model.layers.5.mlp.experts.16.down_proj, device: CPU + model.layers.5.mlp.experts.16.act, device: CPU + Module: model.layers.5.mlp.experts.17, device: CPU + model.layers.5.mlp.experts.17.gate_proj, device: CPU + model.layers.5.mlp.experts.17.up_proj, device: CPU + model.layers.5.mlp.experts.17.down_proj, device: CPU + model.layers.5.mlp.experts.17.act, device: CPU + Module: model.layers.5.mlp.experts.18, device: CPU + model.layers.5.mlp.experts.18.gate_proj, device: CPU + model.layers.5.mlp.experts.18.up_proj, device: CPU + model.layers.5.mlp.experts.18.down_proj, device: CPU + model.layers.5.mlp.experts.18.act, device: CPU + Module: model.layers.5.mlp.experts.19, device: CPU + model.layers.5.mlp.experts.19.gate_proj, device: CPU + model.layers.5.mlp.experts.19.up_proj, device: CPU + model.layers.5.mlp.experts.19.down_proj, device: CPU + model.layers.5.mlp.experts.19.act, device: CPU + Module: model.layers.5.mlp.experts.20, device: CPU + model.layers.5.mlp.experts.20.gate_proj, device: CPU + model.layers.5.mlp.experts.20.up_proj, device: CPU + model.layers.5.mlp.experts.20.down_proj, device: CPU + model.layers.5.mlp.experts.20.act, device: CPU + Module: model.layers.5.mlp.experts.21, device: CPU + model.layers.5.mlp.experts.21.gate_proj, device: CPU + model.layers.5.mlp.experts.21.up_proj, device: CPU + model.layers.5.mlp.experts.21.down_proj, device: CPU + model.layers.5.mlp.experts.21.act, device: CPU + Module: model.layers.5.mlp.experts.22, device: CPU + model.layers.5.mlp.experts.22.gate_proj, device: CPU + model.layers.5.mlp.experts.22.up_proj, device: CPU + model.layers.5.mlp.experts.22.down_proj, device: CPU + model.layers.5.mlp.experts.22.act, device: CPU + Module: model.layers.5.mlp.experts.23, device: CPU + model.layers.5.mlp.experts.23.gate_proj, device: CPU + model.layers.5.mlp.experts.23.up_proj, device: CPU + model.layers.5.mlp.experts.23.down_proj, device: CPU + model.layers.5.mlp.experts.23.act, device: CPU + Module: model.layers.5.mlp.experts.24, device: CPU + model.layers.5.mlp.experts.24.gate_proj, device: CPU + model.layers.5.mlp.experts.24.up_proj, device: CPU + model.layers.5.mlp.experts.24.down_proj, device: CPU + model.layers.5.mlp.experts.24.act, device: CPU + Module: model.layers.5.mlp.experts.25, device: CPU + model.layers.5.mlp.experts.25.gate_proj, device: CPU + model.layers.5.mlp.experts.25.up_proj, device: CPU + model.layers.5.mlp.experts.25.down_proj, device: CPU + model.layers.5.mlp.experts.25.act, device: CPU + Module: model.layers.5.mlp.experts.26, device: CPU + model.layers.5.mlp.experts.26.gate_proj, device: CPU + model.layers.5.mlp.experts.26.up_proj, device: CPU + model.layers.5.mlp.experts.26.down_proj, device: CPU + model.layers.5.mlp.experts.26.act, device: CPU + Module: model.layers.5.mlp.experts.27, device: CPU + model.layers.5.mlp.experts.27.gate_proj, device: CPU + model.layers.5.mlp.experts.27.up_proj, device: CPU + model.layers.5.mlp.experts.27.down_proj, device: CPU + model.layers.5.mlp.experts.27.act, device: CPU + Module: model.layers.5.mlp.experts.28, device: CPU + model.layers.5.mlp.experts.28.gate_proj, device: CPU + model.layers.5.mlp.experts.28.up_proj, device: CPU + model.layers.5.mlp.experts.28.down_proj, device: CPU + model.layers.5.mlp.experts.28.act, device: CPU + Module: model.layers.5.mlp.experts.29, device: CPU + model.layers.5.mlp.experts.29.gate_proj, device: CPU + model.layers.5.mlp.experts.29.up_proj, device: CPU + model.layers.5.mlp.experts.29.down_proj, device: CPU + model.layers.5.mlp.experts.29.act, device: CPU + Module: model.layers.5.mlp.experts.30, device: CPU + model.layers.5.mlp.experts.30.gate_proj, device: CPU + model.layers.5.mlp.experts.30.up_proj, device: CPU + model.layers.5.mlp.experts.30.down_proj, device: CPU + model.layers.5.mlp.experts.30.act, device: CPU + Module: model.layers.5.mlp.experts.31, device: CPU + model.layers.5.mlp.experts.31.gate_proj, device: CPU + model.layers.5.mlp.experts.31.up_proj, device: CPU + model.layers.5.mlp.experts.31.down_proj, device: CPU + model.layers.5.mlp.experts.31.act, device: CPU + Module: model.layers.5.mlp.experts.32, device: CPU + model.layers.5.mlp.experts.32.gate_proj, device: CPU + model.layers.5.mlp.experts.32.up_proj, device: CPU + model.layers.5.mlp.experts.32.down_proj, device: CPU + model.layers.5.mlp.experts.32.act, device: CPU + Module: model.layers.5.mlp.experts.33, device: CPU + model.layers.5.mlp.experts.33.gate_proj, device: CPU + model.layers.5.mlp.experts.33.up_proj, device: CPU + model.layers.5.mlp.experts.33.down_proj, device: CPU + model.layers.5.mlp.experts.33.act, device: CPU + Module: model.layers.5.mlp.experts.34, device: CPU + model.layers.5.mlp.experts.34.gate_proj, device: CPU + model.layers.5.mlp.experts.34.up_proj, device: CPU + model.layers.5.mlp.experts.34.down_proj, device: CPU + model.layers.5.mlp.experts.34.act, device: CPU + Module: model.layers.5.mlp.experts.35, device: CPU + model.layers.5.mlp.experts.35.gate_proj, device: CPU + model.layers.5.mlp.experts.35.up_proj, device: CPU + model.layers.5.mlp.experts.35.down_proj, device: CPU + model.layers.5.mlp.experts.35.act, device: CPU + Module: model.layers.5.mlp.experts.36, device: CPU + model.layers.5.mlp.experts.36.gate_proj, device: CPU + model.layers.5.mlp.experts.36.up_proj, device: CPU + model.layers.5.mlp.experts.36.down_proj, device: CPU + model.layers.5.mlp.experts.36.act, device: CPU + Module: model.layers.5.mlp.experts.37, device: CPU + model.layers.5.mlp.experts.37.gate_proj, device: CPU + model.layers.5.mlp.experts.37.up_proj, device: CPU + model.layers.5.mlp.experts.37.down_proj, device: CPU + model.layers.5.mlp.experts.37.act, device: CPU + Module: model.layers.5.mlp.experts.38, device: CPU + model.layers.5.mlp.experts.38.gate_proj, device: CPU + model.layers.5.mlp.experts.38.up_proj, device: CPU + model.layers.5.mlp.experts.38.down_proj, device: CPU + model.layers.5.mlp.experts.38.act, device: CPU + Module: model.layers.5.mlp.experts.39, device: CPU + model.layers.5.mlp.experts.39.gate_proj, device: CPU + model.layers.5.mlp.experts.39.up_proj, device: CPU + model.layers.5.mlp.experts.39.down_proj, device: CPU + model.layers.5.mlp.experts.39.act, device: CPU + Module: model.layers.5.mlp.experts.40, device: CPU + model.layers.5.mlp.experts.40.gate_proj, device: CPU + model.layers.5.mlp.experts.40.up_proj, device: CPU + model.layers.5.mlp.experts.40.down_proj, device: CPU + model.layers.5.mlp.experts.40.act, device: CPU + Module: model.layers.5.mlp.experts.41, device: CPU + model.layers.5.mlp.experts.41.gate_proj, device: CPU + model.layers.5.mlp.experts.41.up_proj, device: CPU + model.layers.5.mlp.experts.41.down_proj, device: CPU + model.layers.5.mlp.experts.41.act, device: CPU + Module: model.layers.5.mlp.experts.42, device: CPU + model.layers.5.mlp.experts.42.gate_proj, device: CPU + model.layers.5.mlp.experts.42.up_proj, device: CPU + model.layers.5.mlp.experts.42.down_proj, device: CPU + model.layers.5.mlp.experts.42.act, device: CPU + Module: model.layers.5.mlp.experts.43, device: CPU + model.layers.5.mlp.experts.43.gate_proj, device: CPU + model.layers.5.mlp.experts.43.up_proj, device: CPU + model.layers.5.mlp.experts.43.down_proj, device: CPU + model.layers.5.mlp.experts.43.act, device: CPU + Module: model.layers.5.mlp.experts.44, device: CPU + model.layers.5.mlp.experts.44.gate_proj, device: CPU + model.layers.5.mlp.experts.44.up_proj, device: CPU + model.layers.5.mlp.experts.44.down_proj, device: CPU + model.layers.5.mlp.experts.44.act, device: CPU + Module: model.layers.5.mlp.experts.45, device: CPU + model.layers.5.mlp.experts.45.gate_proj, device: CPU + model.layers.5.mlp.experts.45.up_proj, device: CPU + model.layers.5.mlp.experts.45.down_proj, device: CPU + model.layers.5.mlp.experts.45.act, device: CPU + Module: model.layers.5.mlp.experts.46, device: CPU + model.layers.5.mlp.experts.46.gate_proj, device: CPU + model.layers.5.mlp.experts.46.up_proj, device: CPU + model.layers.5.mlp.experts.46.down_proj, device: CPU + model.layers.5.mlp.experts.46.act, device: CPU + Module: model.layers.5.mlp.experts.47, device: CPU + model.layers.5.mlp.experts.47.gate_proj, device: CPU + model.layers.5.mlp.experts.47.up_proj, device: CPU + model.layers.5.mlp.experts.47.down_proj, device: CPU + model.layers.5.mlp.experts.47.act, device: CPU + Module: model.layers.5.mlp.experts.48, device: CPU + model.layers.5.mlp.experts.48.gate_proj, device: CPU + model.layers.5.mlp.experts.48.up_proj, device: CPU + model.layers.5.mlp.experts.48.down_proj, device: CPU + model.layers.5.mlp.experts.48.act, device: CPU + Module: model.layers.5.mlp.experts.49, device: CPU + model.layers.5.mlp.experts.49.gate_proj, device: CPU + model.layers.5.mlp.experts.49.up_proj, device: CPU + model.layers.5.mlp.experts.49.down_proj, device: CPU + model.layers.5.mlp.experts.49.act, device: CPU + Module: model.layers.5.mlp.experts.50, device: CPU + model.layers.5.mlp.experts.50.gate_proj, device: CPU + model.layers.5.mlp.experts.50.up_proj, device: CPU + model.layers.5.mlp.experts.50.down_proj, device: CPU + model.layers.5.mlp.experts.50.act, device: CPU + Module: model.layers.5.mlp.experts.51, device: CPU + model.layers.5.mlp.experts.51.gate_proj, device: CPU + model.layers.5.mlp.experts.51.up_proj, device: CPU + model.layers.5.mlp.experts.51.down_proj, device: CPU + model.layers.5.mlp.experts.51.act, device: CPU + Module: model.layers.5.mlp.experts.52, device: CPU + model.layers.5.mlp.experts.52.gate_proj, device: CPU + model.layers.5.mlp.experts.52.up_proj, device: CPU + model.layers.5.mlp.experts.52.down_proj, device: CPU + model.layers.5.mlp.experts.52.act, device: CPU + Module: model.layers.5.mlp.experts.53, device: CPU + model.layers.5.mlp.experts.53.gate_proj, device: CPU + model.layers.5.mlp.experts.53.up_proj, device: CPU + model.layers.5.mlp.experts.53.down_proj, device: CPU + model.layers.5.mlp.experts.53.act, device: CPU + Module: model.layers.5.mlp.experts.54, device: CPU + model.layers.5.mlp.experts.54.gate_proj, device: CPU + model.layers.5.mlp.experts.54.up_proj, device: CPU + model.layers.5.mlp.experts.54.down_proj, device: CPU + model.layers.5.mlp.experts.54.act, device: CPU + Module: model.layers.5.mlp.experts.55, device: CPU + model.layers.5.mlp.experts.55.gate_proj, device: CPU + model.layers.5.mlp.experts.55.up_proj, device: CPU + model.layers.5.mlp.experts.55.down_proj, device: CPU + model.layers.5.mlp.experts.55.act, device: CPU + Module: model.layers.5.mlp.experts.56, device: CPU + model.layers.5.mlp.experts.56.gate_proj, device: CPU + model.layers.5.mlp.experts.56.up_proj, device: CPU + model.layers.5.mlp.experts.56.down_proj, device: CPU + model.layers.5.mlp.experts.56.act, device: CPU + Module: model.layers.5.mlp.experts.57, device: CPU + model.layers.5.mlp.experts.57.gate_proj, device: CPU + model.layers.5.mlp.experts.57.up_proj, device: CPU + model.layers.5.mlp.experts.57.down_proj, device: CPU + model.layers.5.mlp.experts.57.act, device: CPU + Module: model.layers.5.mlp.experts.58, device: CPU + model.layers.5.mlp.experts.58.gate_proj, device: CPU + model.layers.5.mlp.experts.58.up_proj, device: CPU + model.layers.5.mlp.experts.58.down_proj, device: CPU + model.layers.5.mlp.experts.58.act, device: CPU + Module: model.layers.5.mlp.experts.59, device: CPU + model.layers.5.mlp.experts.59.gate_proj, device: CPU + model.layers.5.mlp.experts.59.up_proj, device: CPU + model.layers.5.mlp.experts.59.down_proj, device: CPU + model.layers.5.mlp.experts.59.act, device: CPU + Module: model.layers.5.mlp.experts.60, device: CPU + model.layers.5.mlp.experts.60.gate_proj, device: CPU + model.layers.5.mlp.experts.60.up_proj, device: CPU + model.layers.5.mlp.experts.60.down_proj, device: CPU + model.layers.5.mlp.experts.60.act, device: CPU + Module: model.layers.5.mlp.experts.61, device: CPU + model.layers.5.mlp.experts.61.gate_proj, device: CPU + model.layers.5.mlp.experts.61.up_proj, device: CPU + model.layers.5.mlp.experts.61.down_proj, device: CPU + model.layers.5.mlp.experts.61.act, device: CPU + Module: model.layers.5.mlp.experts.62, device: CPU + model.layers.5.mlp.experts.62.gate_proj, device: CPU + model.layers.5.mlp.experts.62.up_proj, device: CPU + model.layers.5.mlp.experts.62.down_proj, device: CPU + model.layers.5.mlp.experts.62.act, device: CPU + Module: model.layers.5.mlp.experts.63, device: CPU + model.layers.5.mlp.experts.63.gate_proj, device: CPU + model.layers.5.mlp.experts.63.up_proj, device: CPU + model.layers.5.mlp.experts.63.down_proj, device: CPU + model.layers.5.mlp.experts.63.act, device: CPU + Module: model.layers.5.mlp.gate, device: CPU + model.layers.5.mlp.gate.weight, device: CPU + Module: model.layers.5.mlp.shared_experts, device: CPU + model.layers.5.mlp.shared_experts.gate_proj, device: CPU + model.layers.5.mlp.shared_experts.up_proj, device: CPU + model.layers.5.mlp.shared_experts.down_proj, device: CPU + model.layers.5.mlp.shared_experts.act, device: CPU + model.layers.5.input_layernorm, device: CPU + model.layers.5.post_attention_layernorm, device: CPU + Module: model.layers.6, device: CPU + Module: model.layers.6.self_attn, device: CPU + model.layers.6.self_attn.q_proj, device: CPU + model.layers.6.self_attn.k_proj, device: CPU + model.layers.6.self_attn.v_proj, device: CPU + model.layers.6.self_attn.o_proj, device: CPU + model.layers.6.self_attn.q_rope, device: CPU + model.layers.6.self_attn.k_rope, device: CPU + Module: model.layers.6.mlp, device: CPU + Module: model.layers.6.mlp.experts, device: CPU + Module: model.layers.6.mlp.experts.0, device: CPU + model.layers.6.mlp.experts.0.gate_proj, device: CPU + model.layers.6.mlp.experts.0.up_proj, device: CPU + model.layers.6.mlp.experts.0.down_proj, device: CPU + model.layers.6.mlp.experts.0.act, device: CPU + Module: model.layers.6.mlp.experts.1, device: CPU + model.layers.6.mlp.experts.1.gate_proj, device: CPU + model.layers.6.mlp.experts.1.up_proj, device: CPU + model.layers.6.mlp.experts.1.down_proj, device: CPU + model.layers.6.mlp.experts.1.act, device: CPU + Module: model.layers.6.mlp.experts.2, device: CPU + model.layers.6.mlp.experts.2.gate_proj, device: CPU + model.layers.6.mlp.experts.2.up_proj, device: CPU + model.layers.6.mlp.experts.2.down_proj, device: CPU + model.layers.6.mlp.experts.2.act, device: CPU + Module: model.layers.6.mlp.experts.3, device: CPU + model.layers.6.mlp.experts.3.gate_proj, device: CPU + model.layers.6.mlp.experts.3.up_proj, device: CPU + model.layers.6.mlp.experts.3.down_proj, device: CPU + model.layers.6.mlp.experts.3.act, device: CPU + Module: model.layers.6.mlp.experts.4, device: CPU + model.layers.6.mlp.experts.4.gate_proj, device: CPU + model.layers.6.mlp.experts.4.up_proj, device: CPU + model.layers.6.mlp.experts.4.down_proj, device: CPU + model.layers.6.mlp.experts.4.act, device: CPU + Module: model.layers.6.mlp.experts.5, device: CPU + model.layers.6.mlp.experts.5.gate_proj, device: CPU + model.layers.6.mlp.experts.5.up_proj, device: CPU + model.layers.6.mlp.experts.5.down_proj, device: CPU + model.layers.6.mlp.experts.5.act, device: CPU + Module: model.layers.6.mlp.experts.6, device: CPU + model.layers.6.mlp.experts.6.gate_proj, device: CPU + model.layers.6.mlp.experts.6.up_proj, device: CPU + model.layers.6.mlp.experts.6.down_proj, device: CPU + model.layers.6.mlp.experts.6.act, device: CPU + Module: model.layers.6.mlp.experts.7, device: CPU + model.layers.6.mlp.experts.7.gate_proj, device: CPU + model.layers.6.mlp.experts.7.up_proj, device: CPU + model.layers.6.mlp.experts.7.down_proj, device: CPU + model.layers.6.mlp.experts.7.act, device: CPU + Module: model.layers.6.mlp.experts.8, device: CPU + model.layers.6.mlp.experts.8.gate_proj, device: CPU + model.layers.6.mlp.experts.8.up_proj, device: CPU + model.layers.6.mlp.experts.8.down_proj, device: CPU + model.layers.6.mlp.experts.8.act, device: CPU + Module: model.layers.6.mlp.experts.9, device: CPU + model.layers.6.mlp.experts.9.gate_proj, device: CPU + model.layers.6.mlp.experts.9.up_proj, device: CPU + model.layers.6.mlp.experts.9.down_proj, device: CPU + model.layers.6.mlp.experts.9.act, device: CPU + Module: model.layers.6.mlp.experts.10, device: CPU + model.layers.6.mlp.experts.10.gate_proj, device: CPU + model.layers.6.mlp.experts.10.up_proj, device: CPU + model.layers.6.mlp.experts.10.down_proj, device: CPU + model.layers.6.mlp.experts.10.act, device: CPU + Module: model.layers.6.mlp.experts.11, device: CPU + model.layers.6.mlp.experts.11.gate_proj, device: CPU + model.layers.6.mlp.experts.11.up_proj, device: CPU + model.layers.6.mlp.experts.11.down_proj, device: CPU + model.layers.6.mlp.experts.11.act, device: CPU + Module: model.layers.6.mlp.experts.12, device: CPU + model.layers.6.mlp.experts.12.gate_proj, device: CPU + model.layers.6.mlp.experts.12.up_proj, device: CPU + model.layers.6.mlp.experts.12.down_proj, device: CPU + model.layers.6.mlp.experts.12.act, device: CPU + Module: model.layers.6.mlp.experts.13, device: CPU + model.layers.6.mlp.experts.13.gate_proj, device: CPU + model.layers.6.mlp.experts.13.up_proj, device: CPU + model.layers.6.mlp.experts.13.down_proj, device: CPU + model.layers.6.mlp.experts.13.act, device: CPU + Module: model.layers.6.mlp.experts.14, device: CPU + model.layers.6.mlp.experts.14.gate_proj, device: CPU + model.layers.6.mlp.experts.14.up_proj, device: CPU + model.layers.6.mlp.experts.14.down_proj, device: CPU + model.layers.6.mlp.experts.14.act, device: CPU + Module: model.layers.6.mlp.experts.15, device: CPU + model.layers.6.mlp.experts.15.gate_proj, device: CPU + model.layers.6.mlp.experts.15.up_proj, device: CPU + model.layers.6.mlp.experts.15.down_proj, device: CPU + model.layers.6.mlp.experts.15.act, device: CPU + Module: model.layers.6.mlp.experts.16, device: CPU + model.layers.6.mlp.experts.16.gate_proj, device: CPU + model.layers.6.mlp.experts.16.up_proj, device: CPU + model.layers.6.mlp.experts.16.down_proj, device: CPU + model.layers.6.mlp.experts.16.act, device: CPU + Module: model.layers.6.mlp.experts.17, device: CPU + model.layers.6.mlp.experts.17.gate_proj, device: CPU + model.layers.6.mlp.experts.17.up_proj, device: CPU + model.layers.6.mlp.experts.17.down_proj, device: CPU + model.layers.6.mlp.experts.17.act, device: CPU + Module: model.layers.6.mlp.experts.18, device: CPU + model.layers.6.mlp.experts.18.gate_proj, device: CPU + model.layers.6.mlp.experts.18.up_proj, device: CPU + model.layers.6.mlp.experts.18.down_proj, device: CPU + model.layers.6.mlp.experts.18.act, device: CPU + Module: model.layers.6.mlp.experts.19, device: CPU + model.layers.6.mlp.experts.19.gate_proj, device: CPU + model.layers.6.mlp.experts.19.up_proj, device: CPU + model.layers.6.mlp.experts.19.down_proj, device: CPU + model.layers.6.mlp.experts.19.act, device: CPU + Module: model.layers.6.mlp.experts.20, device: CPU + model.layers.6.mlp.experts.20.gate_proj, device: CPU + model.layers.6.mlp.experts.20.up_proj, device: CPU + model.layers.6.mlp.experts.20.down_proj, device: CPU + model.layers.6.mlp.experts.20.act, device: CPU + Module: model.layers.6.mlp.experts.21, device: CPU + model.layers.6.mlp.experts.21.gate_proj, device: CPU + model.layers.6.mlp.experts.21.up_proj, device: CPU + model.layers.6.mlp.experts.21.down_proj, device: CPU + model.layers.6.mlp.experts.21.act, device: CPU + Module: model.layers.6.mlp.experts.22, device: CPU + model.layers.6.mlp.experts.22.gate_proj, device: CPU + model.layers.6.mlp.experts.22.up_proj, device: CPU + model.layers.6.mlp.experts.22.down_proj, device: CPU + model.layers.6.mlp.experts.22.act, device: CPU + Module: model.layers.6.mlp.experts.23, device: CPU + model.layers.6.mlp.experts.23.gate_proj, device: CPU + model.layers.6.mlp.experts.23.up_proj, device: CPU + model.layers.6.mlp.experts.23.down_proj, device: CPU + model.layers.6.mlp.experts.23.act, device: CPU + Module: model.layers.6.mlp.experts.24, device: CPU + model.layers.6.mlp.experts.24.gate_proj, device: CPU + model.layers.6.mlp.experts.24.up_proj, device: CPU + model.layers.6.mlp.experts.24.down_proj, device: CPU + model.layers.6.mlp.experts.24.act, device: CPU + Module: model.layers.6.mlp.experts.25, device: CPU + model.layers.6.mlp.experts.25.gate_proj, device: CPU + model.layers.6.mlp.experts.25.up_proj, device: CPU + model.layers.6.mlp.experts.25.down_proj, device: CPU + model.layers.6.mlp.experts.25.act, device: CPU + Module: model.layers.6.mlp.experts.26, device: CPU + model.layers.6.mlp.experts.26.gate_proj, device: CPU + model.layers.6.mlp.experts.26.up_proj, device: CPU + model.layers.6.mlp.experts.26.down_proj, device: CPU + model.layers.6.mlp.experts.26.act, device: CPU + Module: model.layers.6.mlp.experts.27, device: CPU + model.layers.6.mlp.experts.27.gate_proj, device: CPU + model.layers.6.mlp.experts.27.up_proj, device: CPU + model.layers.6.mlp.experts.27.down_proj, device: CPU + model.layers.6.mlp.experts.27.act, device: CPU + Module: model.layers.6.mlp.experts.28, device: CPU + model.layers.6.mlp.experts.28.gate_proj, device: CPU + model.layers.6.mlp.experts.28.up_proj, device: CPU + model.layers.6.mlp.experts.28.down_proj, device: CPU + model.layers.6.mlp.experts.28.act, device: CPU + Module: model.layers.6.mlp.experts.29, device: CPU + model.layers.6.mlp.experts.29.gate_proj, device: CPU + model.layers.6.mlp.experts.29.up_proj, device: CPU + model.layers.6.mlp.experts.29.down_proj, device: CPU + model.layers.6.mlp.experts.29.act, device: CPU + Module: model.layers.6.mlp.experts.30, device: CPU + model.layers.6.mlp.experts.30.gate_proj, device: CPU + model.layers.6.mlp.experts.30.up_proj, device: CPU + model.layers.6.mlp.experts.30.down_proj, device: CPU + model.layers.6.mlp.experts.30.act, device: CPU + Module: model.layers.6.mlp.experts.31, device: CPU + model.layers.6.mlp.experts.31.gate_proj, device: CPU + model.layers.6.mlp.experts.31.up_proj, device: CPU + model.layers.6.mlp.experts.31.down_proj, device: CPU + model.layers.6.mlp.experts.31.act, device: CPU + Module: model.layers.6.mlp.experts.32, device: CPU + model.layers.6.mlp.experts.32.gate_proj, device: CPU + model.layers.6.mlp.experts.32.up_proj, device: CPU + model.layers.6.mlp.experts.32.down_proj, device: CPU + model.layers.6.mlp.experts.32.act, device: CPU + Module: model.layers.6.mlp.experts.33, device: CPU + model.layers.6.mlp.experts.33.gate_proj, device: CPU + model.layers.6.mlp.experts.33.up_proj, device: CPU + model.layers.6.mlp.experts.33.down_proj, device: CPU + model.layers.6.mlp.experts.33.act, device: CPU + Module: model.layers.6.mlp.experts.34, device: CPU + model.layers.6.mlp.experts.34.gate_proj, device: CPU + model.layers.6.mlp.experts.34.up_proj, device: CPU + model.layers.6.mlp.experts.34.down_proj, device: CPU + model.layers.6.mlp.experts.34.act, device: CPU + Module: model.layers.6.mlp.experts.35, device: CPU + model.layers.6.mlp.experts.35.gate_proj, device: CPU + model.layers.6.mlp.experts.35.up_proj, device: CPU + model.layers.6.mlp.experts.35.down_proj, device: CPU + model.layers.6.mlp.experts.35.act, device: CPU + Module: model.layers.6.mlp.experts.36, device: CPU + model.layers.6.mlp.experts.36.gate_proj, device: CPU + model.layers.6.mlp.experts.36.up_proj, device: CPU + model.layers.6.mlp.experts.36.down_proj, device: CPU + model.layers.6.mlp.experts.36.act, device: CPU + Module: model.layers.6.mlp.experts.37, device: CPU + model.layers.6.mlp.experts.37.gate_proj, device: CPU + model.layers.6.mlp.experts.37.up_proj, device: CPU + model.layers.6.mlp.experts.37.down_proj, device: CPU + model.layers.6.mlp.experts.37.act, device: CPU + Module: model.layers.6.mlp.experts.38, device: CPU + model.layers.6.mlp.experts.38.gate_proj, device: CPU + model.layers.6.mlp.experts.38.up_proj, device: CPU + model.layers.6.mlp.experts.38.down_proj, device: CPU + model.layers.6.mlp.experts.38.act, device: CPU + Module: model.layers.6.mlp.experts.39, device: CPU + model.layers.6.mlp.experts.39.gate_proj, device: CPU + model.layers.6.mlp.experts.39.up_proj, device: CPU + model.layers.6.mlp.experts.39.down_proj, device: CPU + model.layers.6.mlp.experts.39.act, device: CPU + Module: model.layers.6.mlp.experts.40, device: CPU + model.layers.6.mlp.experts.40.gate_proj, device: CPU + model.layers.6.mlp.experts.40.up_proj, device: CPU + model.layers.6.mlp.experts.40.down_proj, device: CPU + model.layers.6.mlp.experts.40.act, device: CPU + Module: model.layers.6.mlp.experts.41, device: CPU + model.layers.6.mlp.experts.41.gate_proj, device: CPU + model.layers.6.mlp.experts.41.up_proj, device: CPU + model.layers.6.mlp.experts.41.down_proj, device: CPU + model.layers.6.mlp.experts.41.act, device: CPU + Module: model.layers.6.mlp.experts.42, device: CPU + model.layers.6.mlp.experts.42.gate_proj, device: CPU + model.layers.6.mlp.experts.42.up_proj, device: CPU + model.layers.6.mlp.experts.42.down_proj, device: CPU + model.layers.6.mlp.experts.42.act, device: CPU + Module: model.layers.6.mlp.experts.43, device: CPU + model.layers.6.mlp.experts.43.gate_proj, device: CPU + model.layers.6.mlp.experts.43.up_proj, device: CPU + model.layers.6.mlp.experts.43.down_proj, device: CPU + model.layers.6.mlp.experts.43.act, device: CPU + Module: model.layers.6.mlp.experts.44, device: CPU + model.layers.6.mlp.experts.44.gate_proj, device: CPU + model.layers.6.mlp.experts.44.up_proj, device: CPU + model.layers.6.mlp.experts.44.down_proj, device: CPU + model.layers.6.mlp.experts.44.act, device: CPU + Module: model.layers.6.mlp.experts.45, device: CPU + model.layers.6.mlp.experts.45.gate_proj, device: CPU + model.layers.6.mlp.experts.45.up_proj, device: CPU + model.layers.6.mlp.experts.45.down_proj, device: CPU + model.layers.6.mlp.experts.45.act, device: CPU + Module: model.layers.6.mlp.experts.46, device: CPU + model.layers.6.mlp.experts.46.gate_proj, device: CPU + model.layers.6.mlp.experts.46.up_proj, device: CPU + model.layers.6.mlp.experts.46.down_proj, device: CPU + model.layers.6.mlp.experts.46.act, device: CPU + Module: model.layers.6.mlp.experts.47, device: CPU + model.layers.6.mlp.experts.47.gate_proj, device: CPU + model.layers.6.mlp.experts.47.up_proj, device: CPU + model.layers.6.mlp.experts.47.down_proj, device: CPU + model.layers.6.mlp.experts.47.act, device: CPU + Module: model.layers.6.mlp.experts.48, device: CPU + model.layers.6.mlp.experts.48.gate_proj, device: CPU + model.layers.6.mlp.experts.48.up_proj, device: CPU + model.layers.6.mlp.experts.48.down_proj, device: CPU + model.layers.6.mlp.experts.48.act, device: CPU + Module: model.layers.6.mlp.experts.49, device: CPU + model.layers.6.mlp.experts.49.gate_proj, device: CPU + model.layers.6.mlp.experts.49.up_proj, device: CPU + model.layers.6.mlp.experts.49.down_proj, device: CPU + model.layers.6.mlp.experts.49.act, device: CPU + Module: model.layers.6.mlp.experts.50, device: CPU + model.layers.6.mlp.experts.50.gate_proj, device: CPU + model.layers.6.mlp.experts.50.up_proj, device: CPU + model.layers.6.mlp.experts.50.down_proj, device: CPU + model.layers.6.mlp.experts.50.act, device: CPU + Module: model.layers.6.mlp.experts.51, device: CPU + model.layers.6.mlp.experts.51.gate_proj, device: CPU + model.layers.6.mlp.experts.51.up_proj, device: CPU + model.layers.6.mlp.experts.51.down_proj, device: CPU + model.layers.6.mlp.experts.51.act, device: CPU + Module: model.layers.6.mlp.experts.52, device: CPU + model.layers.6.mlp.experts.52.gate_proj, device: CPU + model.layers.6.mlp.experts.52.up_proj, device: CPU + model.layers.6.mlp.experts.52.down_proj, device: CPU + model.layers.6.mlp.experts.52.act, device: CPU + Module: model.layers.6.mlp.experts.53, device: CPU + model.layers.6.mlp.experts.53.gate_proj, device: CPU + model.layers.6.mlp.experts.53.up_proj, device: CPU + model.layers.6.mlp.experts.53.down_proj, device: CPU + model.layers.6.mlp.experts.53.act, device: CPU + Module: model.layers.6.mlp.experts.54, device: CPU + model.layers.6.mlp.experts.54.gate_proj, device: CPU + model.layers.6.mlp.experts.54.up_proj, device: CPU + model.layers.6.mlp.experts.54.down_proj, device: CPU + model.layers.6.mlp.experts.54.act, device: CPU + Module: model.layers.6.mlp.experts.55, device: CPU + model.layers.6.mlp.experts.55.gate_proj, device: CPU + model.layers.6.mlp.experts.55.up_proj, device: CPU + model.layers.6.mlp.experts.55.down_proj, device: CPU + model.layers.6.mlp.experts.55.act, device: CPU + Module: model.layers.6.mlp.experts.56, device: CPU + model.layers.6.mlp.experts.56.gate_proj, device: CPU + model.layers.6.mlp.experts.56.up_proj, device: CPU + model.layers.6.mlp.experts.56.down_proj, device: CPU + model.layers.6.mlp.experts.56.act, device: CPU + Module: model.layers.6.mlp.experts.57, device: CPU + model.layers.6.mlp.experts.57.gate_proj, device: CPU + model.layers.6.mlp.experts.57.up_proj, device: CPU + model.layers.6.mlp.experts.57.down_proj, device: CPU + model.layers.6.mlp.experts.57.act, device: CPU + Module: model.layers.6.mlp.experts.58, device: CPU + model.layers.6.mlp.experts.58.gate_proj, device: CPU + model.layers.6.mlp.experts.58.up_proj, device: CPU + model.layers.6.mlp.experts.58.down_proj, device: CPU + model.layers.6.mlp.experts.58.act, device: CPU + Module: model.layers.6.mlp.experts.59, device: CPU + model.layers.6.mlp.experts.59.gate_proj, device: CPU + model.layers.6.mlp.experts.59.up_proj, device: CPU + model.layers.6.mlp.experts.59.down_proj, device: CPU + model.layers.6.mlp.experts.59.act, device: CPU + Module: model.layers.6.mlp.experts.60, device: CPU + model.layers.6.mlp.experts.60.gate_proj, device: CPU + model.layers.6.mlp.experts.60.up_proj, device: CPU + model.layers.6.mlp.experts.60.down_proj, device: CPU + model.layers.6.mlp.experts.60.act, device: CPU + Module: model.layers.6.mlp.experts.61, device: CPU + model.layers.6.mlp.experts.61.gate_proj, device: CPU + model.layers.6.mlp.experts.61.up_proj, device: CPU + model.layers.6.mlp.experts.61.down_proj, device: CPU + model.layers.6.mlp.experts.61.act, device: CPU + Module: model.layers.6.mlp.experts.62, device: CPU + model.layers.6.mlp.experts.62.gate_proj, device: CPU + model.layers.6.mlp.experts.62.up_proj, device: CPU + model.layers.6.mlp.experts.62.down_proj, device: CPU + model.layers.6.mlp.experts.62.act, device: CPU + Module: model.layers.6.mlp.experts.63, device: CPU + model.layers.6.mlp.experts.63.gate_proj, device: CPU + model.layers.6.mlp.experts.63.up_proj, device: CPU + model.layers.6.mlp.experts.63.down_proj, device: CPU + model.layers.6.mlp.experts.63.act, device: CPU + Module: model.layers.6.mlp.gate, device: CPU + model.layers.6.mlp.gate.weight, device: CPU + Module: model.layers.6.mlp.shared_experts, device: CPU + model.layers.6.mlp.shared_experts.gate_proj, device: CPU + model.layers.6.mlp.shared_experts.up_proj, device: CPU + model.layers.6.mlp.shared_experts.down_proj, device: CPU + model.layers.6.mlp.shared_experts.act, device: CPU + model.layers.6.input_layernorm, device: CPU + model.layers.6.post_attention_layernorm, device: CPU + Module: model.layers.7, device: CPU + Module: model.layers.7.self_attn, device: CPU + model.layers.7.self_attn.q_proj, device: CPU + model.layers.7.self_attn.k_proj, device: CPU + model.layers.7.self_attn.v_proj, device: CPU + model.layers.7.self_attn.o_proj, device: CPU + model.layers.7.self_attn.q_rope, device: CPU + model.layers.7.self_attn.k_rope, device: CPU + Module: model.layers.7.mlp, device: CPU + Module: model.layers.7.mlp.experts, device: CPU + Module: model.layers.7.mlp.experts.0, device: CPU + model.layers.7.mlp.experts.0.gate_proj, device: CPU + model.layers.7.mlp.experts.0.up_proj, device: CPU + model.layers.7.mlp.experts.0.down_proj, device: CPU + model.layers.7.mlp.experts.0.act, device: CPU + Module: model.layers.7.mlp.experts.1, device: CPU + model.layers.7.mlp.experts.1.gate_proj, device: CPU + model.layers.7.mlp.experts.1.up_proj, device: CPU + model.layers.7.mlp.experts.1.down_proj, device: CPU + model.layers.7.mlp.experts.1.act, device: CPU + Module: model.layers.7.mlp.experts.2, device: CPU + model.layers.7.mlp.experts.2.gate_proj, device: CPU + model.layers.7.mlp.experts.2.up_proj, device: CPU + model.layers.7.mlp.experts.2.down_proj, device: CPU + model.layers.7.mlp.experts.2.act, device: CPU + Module: model.layers.7.mlp.experts.3, device: CPU + model.layers.7.mlp.experts.3.gate_proj, device: CPU + model.layers.7.mlp.experts.3.up_proj, device: CPU + model.layers.7.mlp.experts.3.down_proj, device: CPU + model.layers.7.mlp.experts.3.act, device: CPU + Module: model.layers.7.mlp.experts.4, device: CPU + model.layers.7.mlp.experts.4.gate_proj, device: CPU + model.layers.7.mlp.experts.4.up_proj, device: CPU + model.layers.7.mlp.experts.4.down_proj, device: CPU + model.layers.7.mlp.experts.4.act, device: CPU + Module: model.layers.7.mlp.experts.5, device: CPU + model.layers.7.mlp.experts.5.gate_proj, device: CPU + model.layers.7.mlp.experts.5.up_proj, device: CPU + model.layers.7.mlp.experts.5.down_proj, device: CPU + model.layers.7.mlp.experts.5.act, device: CPU + Module: model.layers.7.mlp.experts.6, device: CPU + model.layers.7.mlp.experts.6.gate_proj, device: CPU + model.layers.7.mlp.experts.6.up_proj, device: CPU + model.layers.7.mlp.experts.6.down_proj, device: CPU + model.layers.7.mlp.experts.6.act, device: CPU + Module: model.layers.7.mlp.experts.7, device: CPU + model.layers.7.mlp.experts.7.gate_proj, device: CPU + model.layers.7.mlp.experts.7.up_proj, device: CPU + model.layers.7.mlp.experts.7.down_proj, device: CPU + model.layers.7.mlp.experts.7.act, device: CPU + Module: model.layers.7.mlp.experts.8, device: CPU + model.layers.7.mlp.experts.8.gate_proj, device: CPU + model.layers.7.mlp.experts.8.up_proj, device: CPU + model.layers.7.mlp.experts.8.down_proj, device: CPU + model.layers.7.mlp.experts.8.act, device: CPU + Module: model.layers.7.mlp.experts.9, device: CPU + model.layers.7.mlp.experts.9.gate_proj, device: CPU + model.layers.7.mlp.experts.9.up_proj, device: CPU + model.layers.7.mlp.experts.9.down_proj, device: CPU + model.layers.7.mlp.experts.9.act, device: CPU + Module: model.layers.7.mlp.experts.10, device: CPU + model.layers.7.mlp.experts.10.gate_proj, device: CPU + model.layers.7.mlp.experts.10.up_proj, device: CPU + model.layers.7.mlp.experts.10.down_proj, device: CPU + model.layers.7.mlp.experts.10.act, device: CPU + Module: model.layers.7.mlp.experts.11, device: CPU + model.layers.7.mlp.experts.11.gate_proj, device: CPU + model.layers.7.mlp.experts.11.up_proj, device: CPU + model.layers.7.mlp.experts.11.down_proj, device: CPU + model.layers.7.mlp.experts.11.act, device: CPU + Module: model.layers.7.mlp.experts.12, device: CPU + model.layers.7.mlp.experts.12.gate_proj, device: CPU + model.layers.7.mlp.experts.12.up_proj, device: CPU + model.layers.7.mlp.experts.12.down_proj, device: CPU + model.layers.7.mlp.experts.12.act, device: CPU + Module: model.layers.7.mlp.experts.13, device: CPU + model.layers.7.mlp.experts.13.gate_proj, device: CPU + model.layers.7.mlp.experts.13.up_proj, device: CPU + model.layers.7.mlp.experts.13.down_proj, device: CPU + model.layers.7.mlp.experts.13.act, device: CPU + Module: model.layers.7.mlp.experts.14, device: CPU + model.layers.7.mlp.experts.14.gate_proj, device: CPU + model.layers.7.mlp.experts.14.up_proj, device: CPU + model.layers.7.mlp.experts.14.down_proj, device: CPU + model.layers.7.mlp.experts.14.act, device: CPU + Module: model.layers.7.mlp.experts.15, device: CPU + model.layers.7.mlp.experts.15.gate_proj, device: CPU + model.layers.7.mlp.experts.15.up_proj, device: CPU + model.layers.7.mlp.experts.15.down_proj, device: CPU + model.layers.7.mlp.experts.15.act, device: CPU + Module: model.layers.7.mlp.experts.16, device: CPU + model.layers.7.mlp.experts.16.gate_proj, device: CPU + model.layers.7.mlp.experts.16.up_proj, device: CPU + model.layers.7.mlp.experts.16.down_proj, device: CPU + model.layers.7.mlp.experts.16.act, device: CPU + Module: model.layers.7.mlp.experts.17, device: CPU + model.layers.7.mlp.experts.17.gate_proj, device: CPU + model.layers.7.mlp.experts.17.up_proj, device: CPU + model.layers.7.mlp.experts.17.down_proj, device: CPU + model.layers.7.mlp.experts.17.act, device: CPU + Module: model.layers.7.mlp.experts.18, device: CPU + model.layers.7.mlp.experts.18.gate_proj, device: CPU + model.layers.7.mlp.experts.18.up_proj, device: CPU + model.layers.7.mlp.experts.18.down_proj, device: CPU + model.layers.7.mlp.experts.18.act, device: CPU + Module: model.layers.7.mlp.experts.19, device: CPU + model.layers.7.mlp.experts.19.gate_proj, device: CPU + model.layers.7.mlp.experts.19.up_proj, device: CPU + model.layers.7.mlp.experts.19.down_proj, device: CPU + model.layers.7.mlp.experts.19.act, device: CPU + Module: model.layers.7.mlp.experts.20, device: CPU + model.layers.7.mlp.experts.20.gate_proj, device: CPU + model.layers.7.mlp.experts.20.up_proj, device: CPU + model.layers.7.mlp.experts.20.down_proj, device: CPU + model.layers.7.mlp.experts.20.act, device: CPU + Module: model.layers.7.mlp.experts.21, device: CPU + model.layers.7.mlp.experts.21.gate_proj, device: CPU + model.layers.7.mlp.experts.21.up_proj, device: CPU + model.layers.7.mlp.experts.21.down_proj, device: CPU + model.layers.7.mlp.experts.21.act, device: CPU + Module: model.layers.7.mlp.experts.22, device: CPU + model.layers.7.mlp.experts.22.gate_proj, device: CPU + model.layers.7.mlp.experts.22.up_proj, device: CPU + model.layers.7.mlp.experts.22.down_proj, device: CPU + model.layers.7.mlp.experts.22.act, device: CPU + Module: model.layers.7.mlp.experts.23, device: CPU + model.layers.7.mlp.experts.23.gate_proj, device: CPU + model.layers.7.mlp.experts.23.up_proj, device: CPU + model.layers.7.mlp.experts.23.down_proj, device: CPU + model.layers.7.mlp.experts.23.act, device: CPU + Module: model.layers.7.mlp.experts.24, device: CPU + model.layers.7.mlp.experts.24.gate_proj, device: CPU + model.layers.7.mlp.experts.24.up_proj, device: CPU + model.layers.7.mlp.experts.24.down_proj, device: CPU + model.layers.7.mlp.experts.24.act, device: CPU + Module: model.layers.7.mlp.experts.25, device: CPU + model.layers.7.mlp.experts.25.gate_proj, device: CPU + model.layers.7.mlp.experts.25.up_proj, device: CPU + model.layers.7.mlp.experts.25.down_proj, device: CPU + model.layers.7.mlp.experts.25.act, device: CPU + Module: model.layers.7.mlp.experts.26, device: CPU + model.layers.7.mlp.experts.26.gate_proj, device: CPU + model.layers.7.mlp.experts.26.up_proj, device: CPU + model.layers.7.mlp.experts.26.down_proj, device: CPU + model.layers.7.mlp.experts.26.act, device: CPU + Module: model.layers.7.mlp.experts.27, device: CPU + model.layers.7.mlp.experts.27.gate_proj, device: CPU + model.layers.7.mlp.experts.27.up_proj, device: CPU + model.layers.7.mlp.experts.27.down_proj, device: CPU + model.layers.7.mlp.experts.27.act, device: CPU + Module: model.layers.7.mlp.experts.28, device: CPU + model.layers.7.mlp.experts.28.gate_proj, device: CPU + model.layers.7.mlp.experts.28.up_proj, device: CPU + model.layers.7.mlp.experts.28.down_proj, device: CPU + model.layers.7.mlp.experts.28.act, device: CPU + Module: model.layers.7.mlp.experts.29, device: CPU + model.layers.7.mlp.experts.29.gate_proj, device: CPU + model.layers.7.mlp.experts.29.up_proj, device: CPU + model.layers.7.mlp.experts.29.down_proj, device: CPU + model.layers.7.mlp.experts.29.act, device: CPU + Module: model.layers.7.mlp.experts.30, device: CPU + model.layers.7.mlp.experts.30.gate_proj, device: CPU + model.layers.7.mlp.experts.30.up_proj, device: CPU + model.layers.7.mlp.experts.30.down_proj, device: CPU + model.layers.7.mlp.experts.30.act, device: CPU + Module: model.layers.7.mlp.experts.31, device: CPU + model.layers.7.mlp.experts.31.gate_proj, device: CPU + model.layers.7.mlp.experts.31.up_proj, device: CPU + model.layers.7.mlp.experts.31.down_proj, device: CPU + model.layers.7.mlp.experts.31.act, device: CPU + Module: model.layers.7.mlp.experts.32, device: CPU + model.layers.7.mlp.experts.32.gate_proj, device: CPU + model.layers.7.mlp.experts.32.up_proj, device: CPU + model.layers.7.mlp.experts.32.down_proj, device: CPU + model.layers.7.mlp.experts.32.act, device: CPU + Module: model.layers.7.mlp.experts.33, device: CPU + model.layers.7.mlp.experts.33.gate_proj, device: CPU + model.layers.7.mlp.experts.33.up_proj, device: CPU + model.layers.7.mlp.experts.33.down_proj, device: CPU + model.layers.7.mlp.experts.33.act, device: CPU + Module: model.layers.7.mlp.experts.34, device: CPU + model.layers.7.mlp.experts.34.gate_proj, device: CPU + model.layers.7.mlp.experts.34.up_proj, device: CPU + model.layers.7.mlp.experts.34.down_proj, device: CPU + model.layers.7.mlp.experts.34.act, device: CPU + Module: model.layers.7.mlp.experts.35, device: CPU + model.layers.7.mlp.experts.35.gate_proj, device: CPU + model.layers.7.mlp.experts.35.up_proj, device: CPU + model.layers.7.mlp.experts.35.down_proj, device: CPU + model.layers.7.mlp.experts.35.act, device: CPU + Module: model.layers.7.mlp.experts.36, device: CPU + model.layers.7.mlp.experts.36.gate_proj, device: CPU + model.layers.7.mlp.experts.36.up_proj, device: CPU + model.layers.7.mlp.experts.36.down_proj, device: CPU + model.layers.7.mlp.experts.36.act, device: CPU + Module: model.layers.7.mlp.experts.37, device: CPU + model.layers.7.mlp.experts.37.gate_proj, device: CPU + model.layers.7.mlp.experts.37.up_proj, device: CPU + model.layers.7.mlp.experts.37.down_proj, device: CPU + model.layers.7.mlp.experts.37.act, device: CPU + Module: model.layers.7.mlp.experts.38, device: CPU + model.layers.7.mlp.experts.38.gate_proj, device: CPU + model.layers.7.mlp.experts.38.up_proj, device: CPU + model.layers.7.mlp.experts.38.down_proj, device: CPU + model.layers.7.mlp.experts.38.act, device: CPU + Module: model.layers.7.mlp.experts.39, device: CPU + model.layers.7.mlp.experts.39.gate_proj, device: CPU + model.layers.7.mlp.experts.39.up_proj, device: CPU + model.layers.7.mlp.experts.39.down_proj, device: CPU + model.layers.7.mlp.experts.39.act, device: CPU + Module: model.layers.7.mlp.experts.40, device: CPU + model.layers.7.mlp.experts.40.gate_proj, device: CPU + model.layers.7.mlp.experts.40.up_proj, device: CPU + model.layers.7.mlp.experts.40.down_proj, device: CPU + model.layers.7.mlp.experts.40.act, device: CPU + Module: model.layers.7.mlp.experts.41, device: CPU + model.layers.7.mlp.experts.41.gate_proj, device: CPU + model.layers.7.mlp.experts.41.up_proj, device: CPU + model.layers.7.mlp.experts.41.down_proj, device: CPU + model.layers.7.mlp.experts.41.act, device: CPU + Module: model.layers.7.mlp.experts.42, device: CPU + model.layers.7.mlp.experts.42.gate_proj, device: CPU + model.layers.7.mlp.experts.42.up_proj, device: CPU + model.layers.7.mlp.experts.42.down_proj, device: CPU + model.layers.7.mlp.experts.42.act, device: CPU + Module: model.layers.7.mlp.experts.43, device: CPU + model.layers.7.mlp.experts.43.gate_proj, device: CPU + model.layers.7.mlp.experts.43.up_proj, device: CPU + model.layers.7.mlp.experts.43.down_proj, device: CPU + model.layers.7.mlp.experts.43.act, device: CPU + Module: model.layers.7.mlp.experts.44, device: CPU + model.layers.7.mlp.experts.44.gate_proj, device: CPU + model.layers.7.mlp.experts.44.up_proj, device: CPU + model.layers.7.mlp.experts.44.down_proj, device: CPU + model.layers.7.mlp.experts.44.act, device: CPU + Module: model.layers.7.mlp.experts.45, device: CPU + model.layers.7.mlp.experts.45.gate_proj, device: CPU + model.layers.7.mlp.experts.45.up_proj, device: CPU + model.layers.7.mlp.experts.45.down_proj, device: CPU + model.layers.7.mlp.experts.45.act, device: CPU + Module: model.layers.7.mlp.experts.46, device: CPU + model.layers.7.mlp.experts.46.gate_proj, device: CPU + model.layers.7.mlp.experts.46.up_proj, device: CPU + model.layers.7.mlp.experts.46.down_proj, device: CPU + model.layers.7.mlp.experts.46.act, device: CPU + Module: model.layers.7.mlp.experts.47, device: CPU + model.layers.7.mlp.experts.47.gate_proj, device: CPU + model.layers.7.mlp.experts.47.up_proj, device: CPU + model.layers.7.mlp.experts.47.down_proj, device: CPU + model.layers.7.mlp.experts.47.act, device: CPU + Module: model.layers.7.mlp.experts.48, device: CPU + model.layers.7.mlp.experts.48.gate_proj, device: CPU + model.layers.7.mlp.experts.48.up_proj, device: CPU + model.layers.7.mlp.experts.48.down_proj, device: CPU + model.layers.7.mlp.experts.48.act, device: CPU + Module: model.layers.7.mlp.experts.49, device: CPU + model.layers.7.mlp.experts.49.gate_proj, device: CPU + model.layers.7.mlp.experts.49.up_proj, device: CPU + model.layers.7.mlp.experts.49.down_proj, device: CPU + model.layers.7.mlp.experts.49.act, device: CPU + Module: model.layers.7.mlp.experts.50, device: CPU + model.layers.7.mlp.experts.50.gate_proj, device: CPU + model.layers.7.mlp.experts.50.up_proj, device: CPU + model.layers.7.mlp.experts.50.down_proj, device: CPU + model.layers.7.mlp.experts.50.act, device: CPU + Module: model.layers.7.mlp.experts.51, device: CPU + model.layers.7.mlp.experts.51.gate_proj, device: CPU + model.layers.7.mlp.experts.51.up_proj, device: CPU + model.layers.7.mlp.experts.51.down_proj, device: CPU + model.layers.7.mlp.experts.51.act, device: CPU + Module: model.layers.7.mlp.experts.52, device: CPU + model.layers.7.mlp.experts.52.gate_proj, device: CPU + model.layers.7.mlp.experts.52.up_proj, device: CPU + model.layers.7.mlp.experts.52.down_proj, device: CPU + model.layers.7.mlp.experts.52.act, device: CPU + Module: model.layers.7.mlp.experts.53, device: CPU + model.layers.7.mlp.experts.53.gate_proj, device: CPU + model.layers.7.mlp.experts.53.up_proj, device: CPU + model.layers.7.mlp.experts.53.down_proj, device: CPU + model.layers.7.mlp.experts.53.act, device: CPU + Module: model.layers.7.mlp.experts.54, device: CPU + model.layers.7.mlp.experts.54.gate_proj, device: CPU + model.layers.7.mlp.experts.54.up_proj, device: CPU + model.layers.7.mlp.experts.54.down_proj, device: CPU + model.layers.7.mlp.experts.54.act, device: CPU + Module: model.layers.7.mlp.experts.55, device: CPU + model.layers.7.mlp.experts.55.gate_proj, device: CPU + model.layers.7.mlp.experts.55.up_proj, device: CPU + model.layers.7.mlp.experts.55.down_proj, device: CPU + model.layers.7.mlp.experts.55.act, device: CPU + Module: model.layers.7.mlp.experts.56, device: CPU + model.layers.7.mlp.experts.56.gate_proj, device: CPU + model.layers.7.mlp.experts.56.up_proj, device: CPU + model.layers.7.mlp.experts.56.down_proj, device: CPU + model.layers.7.mlp.experts.56.act, device: CPU + Module: model.layers.7.mlp.experts.57, device: CPU + model.layers.7.mlp.experts.57.gate_proj, device: CPU + model.layers.7.mlp.experts.57.up_proj, device: CPU + model.layers.7.mlp.experts.57.down_proj, device: CPU + model.layers.7.mlp.experts.57.act, device: CPU + Module: model.layers.7.mlp.experts.58, device: CPU + model.layers.7.mlp.experts.58.gate_proj, device: CPU + model.layers.7.mlp.experts.58.up_proj, device: CPU + model.layers.7.mlp.experts.58.down_proj, device: CPU + model.layers.7.mlp.experts.58.act, device: CPU + Module: model.layers.7.mlp.experts.59, device: CPU + model.layers.7.mlp.experts.59.gate_proj, device: CPU + model.layers.7.mlp.experts.59.up_proj, device: CPU + model.layers.7.mlp.experts.59.down_proj, device: CPU + model.layers.7.mlp.experts.59.act, device: CPU + Module: model.layers.7.mlp.experts.60, device: CPU + model.layers.7.mlp.experts.60.gate_proj, device: CPU + model.layers.7.mlp.experts.60.up_proj, device: CPU + model.layers.7.mlp.experts.60.down_proj, device: CPU + model.layers.7.mlp.experts.60.act, device: CPU + Module: model.layers.7.mlp.experts.61, device: CPU + model.layers.7.mlp.experts.61.gate_proj, device: CPU + model.layers.7.mlp.experts.61.up_proj, device: CPU + model.layers.7.mlp.experts.61.down_proj, device: CPU + model.layers.7.mlp.experts.61.act, device: CPU + Module: model.layers.7.mlp.experts.62, device: CPU + model.layers.7.mlp.experts.62.gate_proj, device: CPU + model.layers.7.mlp.experts.62.up_proj, device: CPU + model.layers.7.mlp.experts.62.down_proj, device: CPU + model.layers.7.mlp.experts.62.act, device: CPU + Module: model.layers.7.mlp.experts.63, device: CPU + model.layers.7.mlp.experts.63.gate_proj, device: CPU + model.layers.7.mlp.experts.63.up_proj, device: CPU + model.layers.7.mlp.experts.63.down_proj, device: CPU + model.layers.7.mlp.experts.63.act, device: CPU + Module: model.layers.7.mlp.gate, device: CPU + model.layers.7.mlp.gate.weight, device: CPU + Module: model.layers.7.mlp.shared_experts, device: CPU + model.layers.7.mlp.shared_experts.gate_proj, device: CPU + model.layers.7.mlp.shared_experts.up_proj, device: CPU + model.layers.7.mlp.shared_experts.down_proj, device: CPU + model.layers.7.mlp.shared_experts.act, device: CPU + model.layers.7.input_layernorm, device: CPU + model.layers.7.post_attention_layernorm, device: CPU + Module: model.layers.8, device: CPU + Module: model.layers.8.self_attn, device: CPU + model.layers.8.self_attn.q_proj, device: CPU + model.layers.8.self_attn.k_proj, device: CPU + model.layers.8.self_attn.v_proj, device: CPU + model.layers.8.self_attn.o_proj, device: CPU + model.layers.8.self_attn.q_rope, device: CPU + model.layers.8.self_attn.k_rope, device: CPU + Module: model.layers.8.mlp, device: CPU + Module: model.layers.8.mlp.experts, device: CPU + Module: model.layers.8.mlp.experts.0, device: CPU + model.layers.8.mlp.experts.0.gate_proj, device: CPU + model.layers.8.mlp.experts.0.up_proj, device: CPU + model.layers.8.mlp.experts.0.down_proj, device: CPU + model.layers.8.mlp.experts.0.act, device: CPU + Module: model.layers.8.mlp.experts.1, device: CPU + model.layers.8.mlp.experts.1.gate_proj, device: CPU + model.layers.8.mlp.experts.1.up_proj, device: CPU + model.layers.8.mlp.experts.1.down_proj, device: CPU + model.layers.8.mlp.experts.1.act, device: CPU + Module: model.layers.8.mlp.experts.2, device: CPU + model.layers.8.mlp.experts.2.gate_proj, device: CPU + model.layers.8.mlp.experts.2.up_proj, device: CPU + model.layers.8.mlp.experts.2.down_proj, device: CPU + model.layers.8.mlp.experts.2.act, device: CPU + Module: model.layers.8.mlp.experts.3, device: CPU + model.layers.8.mlp.experts.3.gate_proj, device: CPU + model.layers.8.mlp.experts.3.up_proj, device: CPU + model.layers.8.mlp.experts.3.down_proj, device: CPU + model.layers.8.mlp.experts.3.act, device: CPU + Module: model.layers.8.mlp.experts.4, device: CPU + model.layers.8.mlp.experts.4.gate_proj, device: CPU + model.layers.8.mlp.experts.4.up_proj, device: CPU + model.layers.8.mlp.experts.4.down_proj, device: CPU + model.layers.8.mlp.experts.4.act, device: CPU + Module: model.layers.8.mlp.experts.5, device: CPU + model.layers.8.mlp.experts.5.gate_proj, device: CPU + model.layers.8.mlp.experts.5.up_proj, device: CPU + model.layers.8.mlp.experts.5.down_proj, device: CPU + model.layers.8.mlp.experts.5.act, device: CPU + Module: model.layers.8.mlp.experts.6, device: CPU + model.layers.8.mlp.experts.6.gate_proj, device: CPU + model.layers.8.mlp.experts.6.up_proj, device: CPU + model.layers.8.mlp.experts.6.down_proj, device: CPU + model.layers.8.mlp.experts.6.act, device: CPU + Module: model.layers.8.mlp.experts.7, device: CPU + model.layers.8.mlp.experts.7.gate_proj, device: CPU + model.layers.8.mlp.experts.7.up_proj, device: CPU + model.layers.8.mlp.experts.7.down_proj, device: CPU + model.layers.8.mlp.experts.7.act, device: CPU + Module: model.layers.8.mlp.experts.8, device: CPU + model.layers.8.mlp.experts.8.gate_proj, device: CPU + model.layers.8.mlp.experts.8.up_proj, device: CPU + model.layers.8.mlp.experts.8.down_proj, device: CPU + model.layers.8.mlp.experts.8.act, device: CPU + Module: model.layers.8.mlp.experts.9, device: CPU + model.layers.8.mlp.experts.9.gate_proj, device: CPU + model.layers.8.mlp.experts.9.up_proj, device: CPU + model.layers.8.mlp.experts.9.down_proj, device: CPU + model.layers.8.mlp.experts.9.act, device: CPU + Module: model.layers.8.mlp.experts.10, device: CPU + model.layers.8.mlp.experts.10.gate_proj, device: CPU + model.layers.8.mlp.experts.10.up_proj, device: CPU + model.layers.8.mlp.experts.10.down_proj, device: CPU + model.layers.8.mlp.experts.10.act, device: CPU + Module: model.layers.8.mlp.experts.11, device: CPU + model.layers.8.mlp.experts.11.gate_proj, device: CPU + model.layers.8.mlp.experts.11.up_proj, device: CPU + model.layers.8.mlp.experts.11.down_proj, device: CPU + model.layers.8.mlp.experts.11.act, device: CPU + Module: model.layers.8.mlp.experts.12, device: CPU + model.layers.8.mlp.experts.12.gate_proj, device: CPU + model.layers.8.mlp.experts.12.up_proj, device: CPU + model.layers.8.mlp.experts.12.down_proj, device: CPU + model.layers.8.mlp.experts.12.act, device: CPU + Module: model.layers.8.mlp.experts.13, device: CPU + model.layers.8.mlp.experts.13.gate_proj, device: CPU + model.layers.8.mlp.experts.13.up_proj, device: CPU + model.layers.8.mlp.experts.13.down_proj, device: CPU + model.layers.8.mlp.experts.13.act, device: CPU + Module: model.layers.8.mlp.experts.14, device: CPU + model.layers.8.mlp.experts.14.gate_proj, device: CPU + model.layers.8.mlp.experts.14.up_proj, device: CPU + model.layers.8.mlp.experts.14.down_proj, device: CPU + model.layers.8.mlp.experts.14.act, device: CPU + Module: model.layers.8.mlp.experts.15, device: CPU + model.layers.8.mlp.experts.15.gate_proj, device: CPU + model.layers.8.mlp.experts.15.up_proj, device: CPU + model.layers.8.mlp.experts.15.down_proj, device: CPU + model.layers.8.mlp.experts.15.act, device: CPU + Module: model.layers.8.mlp.experts.16, device: CPU + model.layers.8.mlp.experts.16.gate_proj, device: CPU + model.layers.8.mlp.experts.16.up_proj, device: CPU + model.layers.8.mlp.experts.16.down_proj, device: CPU + model.layers.8.mlp.experts.16.act, device: CPU + Module: model.layers.8.mlp.experts.17, device: CPU + model.layers.8.mlp.experts.17.gate_proj, device: CPU + model.layers.8.mlp.experts.17.up_proj, device: CPU + model.layers.8.mlp.experts.17.down_proj, device: CPU + model.layers.8.mlp.experts.17.act, device: CPU + Module: model.layers.8.mlp.experts.18, device: CPU + model.layers.8.mlp.experts.18.gate_proj, device: CPU + model.layers.8.mlp.experts.18.up_proj, device: CPU + model.layers.8.mlp.experts.18.down_proj, device: CPU + model.layers.8.mlp.experts.18.act, device: CPU + Module: model.layers.8.mlp.experts.19, device: CPU + model.layers.8.mlp.experts.19.gate_proj, device: CPU + model.layers.8.mlp.experts.19.up_proj, device: CPU + model.layers.8.mlp.experts.19.down_proj, device: CPU + model.layers.8.mlp.experts.19.act, device: CPU + Module: model.layers.8.mlp.experts.20, device: CPU + model.layers.8.mlp.experts.20.gate_proj, device: CPU + model.layers.8.mlp.experts.20.up_proj, device: CPU + model.layers.8.mlp.experts.20.down_proj, device: CPU + model.layers.8.mlp.experts.20.act, device: CPU + Module: model.layers.8.mlp.experts.21, device: CPU + model.layers.8.mlp.experts.21.gate_proj, device: CPU + model.layers.8.mlp.experts.21.up_proj, device: CPU + model.layers.8.mlp.experts.21.down_proj, device: CPU + model.layers.8.mlp.experts.21.act, device: CPU + Module: model.layers.8.mlp.experts.22, device: CPU + model.layers.8.mlp.experts.22.gate_proj, device: CPU + model.layers.8.mlp.experts.22.up_proj, device: CPU + model.layers.8.mlp.experts.22.down_proj, device: CPU + model.layers.8.mlp.experts.22.act, device: CPU + Module: model.layers.8.mlp.experts.23, device: CPU + model.layers.8.mlp.experts.23.gate_proj, device: CPU + model.layers.8.mlp.experts.23.up_proj, device: CPU + model.layers.8.mlp.experts.23.down_proj, device: CPU + model.layers.8.mlp.experts.23.act, device: CPU + Module: model.layers.8.mlp.experts.24, device: CPU + model.layers.8.mlp.experts.24.gate_proj, device: CPU + model.layers.8.mlp.experts.24.up_proj, device: CPU + model.layers.8.mlp.experts.24.down_proj, device: CPU + model.layers.8.mlp.experts.24.act, device: CPU + Module: model.layers.8.mlp.experts.25, device: CPU + model.layers.8.mlp.experts.25.gate_proj, device: CPU + model.layers.8.mlp.experts.25.up_proj, device: CPU + model.layers.8.mlp.experts.25.down_proj, device: CPU + model.layers.8.mlp.experts.25.act, device: CPU + Module: model.layers.8.mlp.experts.26, device: CPU + model.layers.8.mlp.experts.26.gate_proj, device: CPU + model.layers.8.mlp.experts.26.up_proj, device: CPU + model.layers.8.mlp.experts.26.down_proj, device: CPU + model.layers.8.mlp.experts.26.act, device: CPU + Module: model.layers.8.mlp.experts.27, device: CPU + model.layers.8.mlp.experts.27.gate_proj, device: CPU + model.layers.8.mlp.experts.27.up_proj, device: CPU + model.layers.8.mlp.experts.27.down_proj, device: CPU + model.layers.8.mlp.experts.27.act, device: CPU + Module: model.layers.8.mlp.experts.28, device: CPU + model.layers.8.mlp.experts.28.gate_proj, device: CPU + model.layers.8.mlp.experts.28.up_proj, device: CPU + model.layers.8.mlp.experts.28.down_proj, device: CPU + model.layers.8.mlp.experts.28.act, device: CPU + Module: model.layers.8.mlp.experts.29, device: CPU + model.layers.8.mlp.experts.29.gate_proj, device: CPU + model.layers.8.mlp.experts.29.up_proj, device: CPU + model.layers.8.mlp.experts.29.down_proj, device: CPU + model.layers.8.mlp.experts.29.act, device: CPU + Module: model.layers.8.mlp.experts.30, device: CPU + model.layers.8.mlp.experts.30.gate_proj, device: CPU + model.layers.8.mlp.experts.30.up_proj, device: CPU + model.layers.8.mlp.experts.30.down_proj, device: CPU + model.layers.8.mlp.experts.30.act, device: CPU + Module: model.layers.8.mlp.experts.31, device: CPU + model.layers.8.mlp.experts.31.gate_proj, device: CPU + model.layers.8.mlp.experts.31.up_proj, device: CPU + model.layers.8.mlp.experts.31.down_proj, device: CPU + model.layers.8.mlp.experts.31.act, device: CPU + Module: model.layers.8.mlp.experts.32, device: CPU + model.layers.8.mlp.experts.32.gate_proj, device: CPU + model.layers.8.mlp.experts.32.up_proj, device: CPU + model.layers.8.mlp.experts.32.down_proj, device: CPU + model.layers.8.mlp.experts.32.act, device: CPU + Module: model.layers.8.mlp.experts.33, device: CPU + model.layers.8.mlp.experts.33.gate_proj, device: CPU + model.layers.8.mlp.experts.33.up_proj, device: CPU + model.layers.8.mlp.experts.33.down_proj, device: CPU + model.layers.8.mlp.experts.33.act, device: CPU + Module: model.layers.8.mlp.experts.34, device: CPU + model.layers.8.mlp.experts.34.gate_proj, device: CPU + model.layers.8.mlp.experts.34.up_proj, device: CPU + model.layers.8.mlp.experts.34.down_proj, device: CPU + model.layers.8.mlp.experts.34.act, device: CPU + Module: model.layers.8.mlp.experts.35, device: CPU + model.layers.8.mlp.experts.35.gate_proj, device: CPU + model.layers.8.mlp.experts.35.up_proj, device: CPU + model.layers.8.mlp.experts.35.down_proj, device: CPU + model.layers.8.mlp.experts.35.act, device: CPU + Module: model.layers.8.mlp.experts.36, device: CPU + model.layers.8.mlp.experts.36.gate_proj, device: CPU + model.layers.8.mlp.experts.36.up_proj, device: CPU + model.layers.8.mlp.experts.36.down_proj, device: CPU + model.layers.8.mlp.experts.36.act, device: CPU + Module: model.layers.8.mlp.experts.37, device: CPU + model.layers.8.mlp.experts.37.gate_proj, device: CPU + model.layers.8.mlp.experts.37.up_proj, device: CPU + model.layers.8.mlp.experts.37.down_proj, device: CPU + model.layers.8.mlp.experts.37.act, device: CPU + Module: model.layers.8.mlp.experts.38, device: CPU + model.layers.8.mlp.experts.38.gate_proj, device: CPU + model.layers.8.mlp.experts.38.up_proj, device: CPU + model.layers.8.mlp.experts.38.down_proj, device: CPU + model.layers.8.mlp.experts.38.act, device: CPU + Module: model.layers.8.mlp.experts.39, device: CPU + model.layers.8.mlp.experts.39.gate_proj, device: CPU + model.layers.8.mlp.experts.39.up_proj, device: CPU + model.layers.8.mlp.experts.39.down_proj, device: CPU + model.layers.8.mlp.experts.39.act, device: CPU + Module: model.layers.8.mlp.experts.40, device: CPU + model.layers.8.mlp.experts.40.gate_proj, device: CPU + model.layers.8.mlp.experts.40.up_proj, device: CPU + model.layers.8.mlp.experts.40.down_proj, device: CPU + model.layers.8.mlp.experts.40.act, device: CPU + Module: model.layers.8.mlp.experts.41, device: CPU + model.layers.8.mlp.experts.41.gate_proj, device: CPU + model.layers.8.mlp.experts.41.up_proj, device: CPU + model.layers.8.mlp.experts.41.down_proj, device: CPU + model.layers.8.mlp.experts.41.act, device: CPU + Module: model.layers.8.mlp.experts.42, device: CPU + model.layers.8.mlp.experts.42.gate_proj, device: CPU + model.layers.8.mlp.experts.42.up_proj, device: CPU + model.layers.8.mlp.experts.42.down_proj, device: CPU + model.layers.8.mlp.experts.42.act, device: CPU + Module: model.layers.8.mlp.experts.43, device: CPU + model.layers.8.mlp.experts.43.gate_proj, device: CPU + model.layers.8.mlp.experts.43.up_proj, device: CPU + model.layers.8.mlp.experts.43.down_proj, device: CPU + model.layers.8.mlp.experts.43.act, device: CPU + Module: model.layers.8.mlp.experts.44, device: CPU + model.layers.8.mlp.experts.44.gate_proj, device: CPU + model.layers.8.mlp.experts.44.up_proj, device: CPU + model.layers.8.mlp.experts.44.down_proj, device: CPU + model.layers.8.mlp.experts.44.act, device: CPU + Module: model.layers.8.mlp.experts.45, device: CPU + model.layers.8.mlp.experts.45.gate_proj, device: CPU + model.layers.8.mlp.experts.45.up_proj, device: CPU + model.layers.8.mlp.experts.45.down_proj, device: CPU + model.layers.8.mlp.experts.45.act, device: CPU + Module: model.layers.8.mlp.experts.46, device: CPU + model.layers.8.mlp.experts.46.gate_proj, device: CPU + model.layers.8.mlp.experts.46.up_proj, device: CPU + model.layers.8.mlp.experts.46.down_proj, device: CPU + model.layers.8.mlp.experts.46.act, device: CPU + Module: model.layers.8.mlp.experts.47, device: CPU + model.layers.8.mlp.experts.47.gate_proj, device: CPU + model.layers.8.mlp.experts.47.up_proj, device: CPU + model.layers.8.mlp.experts.47.down_proj, device: CPU + model.layers.8.mlp.experts.47.act, device: CPU + Module: model.layers.8.mlp.experts.48, device: CPU + model.layers.8.mlp.experts.48.gate_proj, device: CPU + model.layers.8.mlp.experts.48.up_proj, device: CPU + model.layers.8.mlp.experts.48.down_proj, device: CPU + model.layers.8.mlp.experts.48.act, device: CPU + Module: model.layers.8.mlp.experts.49, device: CPU + model.layers.8.mlp.experts.49.gate_proj, device: CPU + model.layers.8.mlp.experts.49.up_proj, device: CPU + model.layers.8.mlp.experts.49.down_proj, device: CPU + model.layers.8.mlp.experts.49.act, device: CPU + Module: model.layers.8.mlp.experts.50, device: CPU + model.layers.8.mlp.experts.50.gate_proj, device: CPU + model.layers.8.mlp.experts.50.up_proj, device: CPU + model.layers.8.mlp.experts.50.down_proj, device: CPU + model.layers.8.mlp.experts.50.act, device: CPU + Module: model.layers.8.mlp.experts.51, device: CPU + model.layers.8.mlp.experts.51.gate_proj, device: CPU + model.layers.8.mlp.experts.51.up_proj, device: CPU + model.layers.8.mlp.experts.51.down_proj, device: CPU + model.layers.8.mlp.experts.51.act, device: CPU + Module: model.layers.8.mlp.experts.52, device: CPU + model.layers.8.mlp.experts.52.gate_proj, device: CPU + model.layers.8.mlp.experts.52.up_proj, device: CPU + model.layers.8.mlp.experts.52.down_proj, device: CPU + model.layers.8.mlp.experts.52.act, device: CPU + Module: model.layers.8.mlp.experts.53, device: CPU + model.layers.8.mlp.experts.53.gate_proj, device: CPU + model.layers.8.mlp.experts.53.up_proj, device: CPU + model.layers.8.mlp.experts.53.down_proj, device: CPU + model.layers.8.mlp.experts.53.act, device: CPU + Module: model.layers.8.mlp.experts.54, device: CPU + model.layers.8.mlp.experts.54.gate_proj, device: CPU + model.layers.8.mlp.experts.54.up_proj, device: CPU + model.layers.8.mlp.experts.54.down_proj, device: CPU + model.layers.8.mlp.experts.54.act, device: CPU + Module: model.layers.8.mlp.experts.55, device: CPU + model.layers.8.mlp.experts.55.gate_proj, device: CPU + model.layers.8.mlp.experts.55.up_proj, device: CPU + model.layers.8.mlp.experts.55.down_proj, device: CPU + model.layers.8.mlp.experts.55.act, device: CPU + Module: model.layers.8.mlp.experts.56, device: CPU + model.layers.8.mlp.experts.56.gate_proj, device: CPU + model.layers.8.mlp.experts.56.up_proj, device: CPU + model.layers.8.mlp.experts.56.down_proj, device: CPU + model.layers.8.mlp.experts.56.act, device: CPU + Module: model.layers.8.mlp.experts.57, device: CPU + model.layers.8.mlp.experts.57.gate_proj, device: CPU + model.layers.8.mlp.experts.57.up_proj, device: CPU + model.layers.8.mlp.experts.57.down_proj, device: CPU + model.layers.8.mlp.experts.57.act, device: CPU + Module: model.layers.8.mlp.experts.58, device: CPU + model.layers.8.mlp.experts.58.gate_proj, device: CPU + model.layers.8.mlp.experts.58.up_proj, device: CPU + model.layers.8.mlp.experts.58.down_proj, device: CPU + model.layers.8.mlp.experts.58.act, device: CPU + Module: model.layers.8.mlp.experts.59, device: CPU + model.layers.8.mlp.experts.59.gate_proj, device: CPU + model.layers.8.mlp.experts.59.up_proj, device: CPU + model.layers.8.mlp.experts.59.down_proj, device: CPU + model.layers.8.mlp.experts.59.act, device: CPU + Module: model.layers.8.mlp.experts.60, device: CPU + model.layers.8.mlp.experts.60.gate_proj, device: CPU + model.layers.8.mlp.experts.60.up_proj, device: CPU + model.layers.8.mlp.experts.60.down_proj, device: CPU + model.layers.8.mlp.experts.60.act, device: CPU + Module: model.layers.8.mlp.experts.61, device: CPU + model.layers.8.mlp.experts.61.gate_proj, device: CPU + model.layers.8.mlp.experts.61.up_proj, device: CPU + model.layers.8.mlp.experts.61.down_proj, device: CPU + model.layers.8.mlp.experts.61.act, device: CPU + Module: model.layers.8.mlp.experts.62, device: CPU + model.layers.8.mlp.experts.62.gate_proj, device: CPU + model.layers.8.mlp.experts.62.up_proj, device: CPU + model.layers.8.mlp.experts.62.down_proj, device: CPU + model.layers.8.mlp.experts.62.act, device: CPU + Module: model.layers.8.mlp.experts.63, device: CPU + model.layers.8.mlp.experts.63.gate_proj, device: CPU + model.layers.8.mlp.experts.63.up_proj, device: CPU + model.layers.8.mlp.experts.63.down_proj, device: CPU + model.layers.8.mlp.experts.63.act, device: CPU + Module: model.layers.8.mlp.gate, device: CPU + model.layers.8.mlp.gate.weight, device: CPU + Module: model.layers.8.mlp.shared_experts, device: CPU + model.layers.8.mlp.shared_experts.gate_proj, device: CPU + model.layers.8.mlp.shared_experts.up_proj, device: CPU + model.layers.8.mlp.shared_experts.down_proj, device: CPU + model.layers.8.mlp.shared_experts.act, device: CPU + model.layers.8.input_layernorm, device: CPU + model.layers.8.post_attention_layernorm, device: CPU + Module: model.layers.9, device: CPU + Module: model.layers.9.self_attn, device: CPU + model.layers.9.self_attn.q_proj, device: CPU + model.layers.9.self_attn.k_proj, device: CPU + model.layers.9.self_attn.v_proj, device: CPU + model.layers.9.self_attn.o_proj, device: CPU + model.layers.9.self_attn.q_rope, device: CPU + model.layers.9.self_attn.k_rope, device: CPU + Module: model.layers.9.mlp, device: CPU + Module: model.layers.9.mlp.experts, device: CPU + Module: model.layers.9.mlp.experts.0, device: CPU + model.layers.9.mlp.experts.0.gate_proj, device: CPU + model.layers.9.mlp.experts.0.up_proj, device: CPU + model.layers.9.mlp.experts.0.down_proj, device: CPU + model.layers.9.mlp.experts.0.act, device: CPU + Module: model.layers.9.mlp.experts.1, device: CPU + model.layers.9.mlp.experts.1.gate_proj, device: CPU + model.layers.9.mlp.experts.1.up_proj, device: CPU + model.layers.9.mlp.experts.1.down_proj, device: CPU + model.layers.9.mlp.experts.1.act, device: CPU + Module: model.layers.9.mlp.experts.2, device: CPU + model.layers.9.mlp.experts.2.gate_proj, device: CPU + model.layers.9.mlp.experts.2.up_proj, device: CPU + model.layers.9.mlp.experts.2.down_proj, device: CPU + model.layers.9.mlp.experts.2.act, device: CPU + Module: model.layers.9.mlp.experts.3, device: CPU + model.layers.9.mlp.experts.3.gate_proj, device: CPU + model.layers.9.mlp.experts.3.up_proj, device: CPU + model.layers.9.mlp.experts.3.down_proj, device: CPU + model.layers.9.mlp.experts.3.act, device: CPU + Module: model.layers.9.mlp.experts.4, device: CPU + model.layers.9.mlp.experts.4.gate_proj, device: CPU + model.layers.9.mlp.experts.4.up_proj, device: CPU + model.layers.9.mlp.experts.4.down_proj, device: CPU + model.layers.9.mlp.experts.4.act, device: CPU + Module: model.layers.9.mlp.experts.5, device: CPU + model.layers.9.mlp.experts.5.gate_proj, device: CPU + model.layers.9.mlp.experts.5.up_proj, device: CPU + model.layers.9.mlp.experts.5.down_proj, device: CPU + model.layers.9.mlp.experts.5.act, device: CPU + Module: model.layers.9.mlp.experts.6, device: CPU + model.layers.9.mlp.experts.6.gate_proj, device: CPU + model.layers.9.mlp.experts.6.up_proj, device: CPU + model.layers.9.mlp.experts.6.down_proj, device: CPU + model.layers.9.mlp.experts.6.act, device: CPU + Module: model.layers.9.mlp.experts.7, device: CPU + model.layers.9.mlp.experts.7.gate_proj, device: CPU + model.layers.9.mlp.experts.7.up_proj, device: CPU + model.layers.9.mlp.experts.7.down_proj, device: CPU + model.layers.9.mlp.experts.7.act, device: CPU + Module: model.layers.9.mlp.experts.8, device: CPU + model.layers.9.mlp.experts.8.gate_proj, device: CPU + model.layers.9.mlp.experts.8.up_proj, device: CPU + model.layers.9.mlp.experts.8.down_proj, device: CPU + model.layers.9.mlp.experts.8.act, device: CPU + Module: model.layers.9.mlp.experts.9, device: CPU + model.layers.9.mlp.experts.9.gate_proj, device: CPU + model.layers.9.mlp.experts.9.up_proj, device: CPU + model.layers.9.mlp.experts.9.down_proj, device: CPU + model.layers.9.mlp.experts.9.act, device: CPU + Module: model.layers.9.mlp.experts.10, device: CPU + model.layers.9.mlp.experts.10.gate_proj, device: CPU + model.layers.9.mlp.experts.10.up_proj, device: CPU + model.layers.9.mlp.experts.10.down_proj, device: CPU + model.layers.9.mlp.experts.10.act, device: CPU + Module: model.layers.9.mlp.experts.11, device: CPU + model.layers.9.mlp.experts.11.gate_proj, device: CPU + model.layers.9.mlp.experts.11.up_proj, device: CPU + model.layers.9.mlp.experts.11.down_proj, device: CPU + model.layers.9.mlp.experts.11.act, device: CPU + Module: model.layers.9.mlp.experts.12, device: CPU + model.layers.9.mlp.experts.12.gate_proj, device: CPU + model.layers.9.mlp.experts.12.up_proj, device: CPU + model.layers.9.mlp.experts.12.down_proj, device: CPU + model.layers.9.mlp.experts.12.act, device: CPU + Module: model.layers.9.mlp.experts.13, device: CPU + model.layers.9.mlp.experts.13.gate_proj, device: CPU + model.layers.9.mlp.experts.13.up_proj, device: CPU + model.layers.9.mlp.experts.13.down_proj, device: CPU + model.layers.9.mlp.experts.13.act, device: CPU + Module: model.layers.9.mlp.experts.14, device: CPU + model.layers.9.mlp.experts.14.gate_proj, device: CPU + model.layers.9.mlp.experts.14.up_proj, device: CPU + model.layers.9.mlp.experts.14.down_proj, device: CPU + model.layers.9.mlp.experts.14.act, device: CPU + Module: model.layers.9.mlp.experts.15, device: CPU + model.layers.9.mlp.experts.15.gate_proj, device: CPU + model.layers.9.mlp.experts.15.up_proj, device: CPU + model.layers.9.mlp.experts.15.down_proj, device: CPU + model.layers.9.mlp.experts.15.act, device: CPU + Module: model.layers.9.mlp.experts.16, device: CPU + model.layers.9.mlp.experts.16.gate_proj, device: CPU + model.layers.9.mlp.experts.16.up_proj, device: CPU + model.layers.9.mlp.experts.16.down_proj, device: CPU + model.layers.9.mlp.experts.16.act, device: CPU + Module: model.layers.9.mlp.experts.17, device: CPU + model.layers.9.mlp.experts.17.gate_proj, device: CPU + model.layers.9.mlp.experts.17.up_proj, device: CPU + model.layers.9.mlp.experts.17.down_proj, device: CPU + model.layers.9.mlp.experts.17.act, device: CPU + Module: model.layers.9.mlp.experts.18, device: CPU + model.layers.9.mlp.experts.18.gate_proj, device: CPU + model.layers.9.mlp.experts.18.up_proj, device: CPU + model.layers.9.mlp.experts.18.down_proj, device: CPU + model.layers.9.mlp.experts.18.act, device: CPU + Module: model.layers.9.mlp.experts.19, device: CPU + model.layers.9.mlp.experts.19.gate_proj, device: CPU + model.layers.9.mlp.experts.19.up_proj, device: CPU + model.layers.9.mlp.experts.19.down_proj, device: CPU + model.layers.9.mlp.experts.19.act, device: CPU + Module: model.layers.9.mlp.experts.20, device: CPU + model.layers.9.mlp.experts.20.gate_proj, device: CPU + model.layers.9.mlp.experts.20.up_proj, device: CPU + model.layers.9.mlp.experts.20.down_proj, device: CPU + model.layers.9.mlp.experts.20.act, device: CPU + Module: model.layers.9.mlp.experts.21, device: CPU + model.layers.9.mlp.experts.21.gate_proj, device: CPU + model.layers.9.mlp.experts.21.up_proj, device: CPU + model.layers.9.mlp.experts.21.down_proj, device: CPU + model.layers.9.mlp.experts.21.act, device: CPU + Module: model.layers.9.mlp.experts.22, device: CPU + model.layers.9.mlp.experts.22.gate_proj, device: CPU + model.layers.9.mlp.experts.22.up_proj, device: CPU + model.layers.9.mlp.experts.22.down_proj, device: CPU + model.layers.9.mlp.experts.22.act, device: CPU + Module: model.layers.9.mlp.experts.23, device: CPU + model.layers.9.mlp.experts.23.gate_proj, device: CPU + model.layers.9.mlp.experts.23.up_proj, device: CPU + model.layers.9.mlp.experts.23.down_proj, device: CPU + model.layers.9.mlp.experts.23.act, device: CPU + Module: model.layers.9.mlp.experts.24, device: CPU + model.layers.9.mlp.experts.24.gate_proj, device: CPU + model.layers.9.mlp.experts.24.up_proj, device: CPU + model.layers.9.mlp.experts.24.down_proj, device: CPU + model.layers.9.mlp.experts.24.act, device: CPU + Module: model.layers.9.mlp.experts.25, device: CPU + model.layers.9.mlp.experts.25.gate_proj, device: CPU + model.layers.9.mlp.experts.25.up_proj, device: CPU + model.layers.9.mlp.experts.25.down_proj, device: CPU + model.layers.9.mlp.experts.25.act, device: CPU + Module: model.layers.9.mlp.experts.26, device: CPU + model.layers.9.mlp.experts.26.gate_proj, device: CPU + model.layers.9.mlp.experts.26.up_proj, device: CPU + model.layers.9.mlp.experts.26.down_proj, device: CPU + model.layers.9.mlp.experts.26.act, device: CPU + Module: model.layers.9.mlp.experts.27, device: CPU + model.layers.9.mlp.experts.27.gate_proj, device: CPU + model.layers.9.mlp.experts.27.up_proj, device: CPU + model.layers.9.mlp.experts.27.down_proj, device: CPU + model.layers.9.mlp.experts.27.act, device: CPU + Module: model.layers.9.mlp.experts.28, device: CPU + model.layers.9.mlp.experts.28.gate_proj, device: CPU + model.layers.9.mlp.experts.28.up_proj, device: CPU + model.layers.9.mlp.experts.28.down_proj, device: CPU + model.layers.9.mlp.experts.28.act, device: CPU + Module: model.layers.9.mlp.experts.29, device: CPU + model.layers.9.mlp.experts.29.gate_proj, device: CPU + model.layers.9.mlp.experts.29.up_proj, device: CPU + model.layers.9.mlp.experts.29.down_proj, device: CPU + model.layers.9.mlp.experts.29.act, device: CPU + Module: model.layers.9.mlp.experts.30, device: CPU + model.layers.9.mlp.experts.30.gate_proj, device: CPU + model.layers.9.mlp.experts.30.up_proj, device: CPU + model.layers.9.mlp.experts.30.down_proj, device: CPU + model.layers.9.mlp.experts.30.act, device: CPU + Module: model.layers.9.mlp.experts.31, device: CPU + model.layers.9.mlp.experts.31.gate_proj, device: CPU + model.layers.9.mlp.experts.31.up_proj, device: CPU + model.layers.9.mlp.experts.31.down_proj, device: CPU + model.layers.9.mlp.experts.31.act, device: CPU + Module: model.layers.9.mlp.experts.32, device: CPU + model.layers.9.mlp.experts.32.gate_proj, device: CPU + model.layers.9.mlp.experts.32.up_proj, device: CPU + model.layers.9.mlp.experts.32.down_proj, device: CPU + model.layers.9.mlp.experts.32.act, device: CPU + Module: model.layers.9.mlp.experts.33, device: CPU + model.layers.9.mlp.experts.33.gate_proj, device: CPU + model.layers.9.mlp.experts.33.up_proj, device: CPU + model.layers.9.mlp.experts.33.down_proj, device: CPU + model.layers.9.mlp.experts.33.act, device: CPU + Module: model.layers.9.mlp.experts.34, device: CPU + model.layers.9.mlp.experts.34.gate_proj, device: CPU + model.layers.9.mlp.experts.34.up_proj, device: CPU + model.layers.9.mlp.experts.34.down_proj, device: CPU + model.layers.9.mlp.experts.34.act, device: CPU + Module: model.layers.9.mlp.experts.35, device: CPU + model.layers.9.mlp.experts.35.gate_proj, device: CPU + model.layers.9.mlp.experts.35.up_proj, device: CPU + model.layers.9.mlp.experts.35.down_proj, device: CPU + model.layers.9.mlp.experts.35.act, device: CPU + Module: model.layers.9.mlp.experts.36, device: CPU + model.layers.9.mlp.experts.36.gate_proj, device: CPU + model.layers.9.mlp.experts.36.up_proj, device: CPU + model.layers.9.mlp.experts.36.down_proj, device: CPU + model.layers.9.mlp.experts.36.act, device: CPU + Module: model.layers.9.mlp.experts.37, device: CPU + model.layers.9.mlp.experts.37.gate_proj, device: CPU + model.layers.9.mlp.experts.37.up_proj, device: CPU + model.layers.9.mlp.experts.37.down_proj, device: CPU + model.layers.9.mlp.experts.37.act, device: CPU + Module: model.layers.9.mlp.experts.38, device: CPU + model.layers.9.mlp.experts.38.gate_proj, device: CPU + model.layers.9.mlp.experts.38.up_proj, device: CPU + model.layers.9.mlp.experts.38.down_proj, device: CPU + model.layers.9.mlp.experts.38.act, device: CPU + Module: model.layers.9.mlp.experts.39, device: CPU + model.layers.9.mlp.experts.39.gate_proj, device: CPU + model.layers.9.mlp.experts.39.up_proj, device: CPU + model.layers.9.mlp.experts.39.down_proj, device: CPU + model.layers.9.mlp.experts.39.act, device: CPU + Module: model.layers.9.mlp.experts.40, device: CPU + model.layers.9.mlp.experts.40.gate_proj, device: CPU + model.layers.9.mlp.experts.40.up_proj, device: CPU + model.layers.9.mlp.experts.40.down_proj, device: CPU + model.layers.9.mlp.experts.40.act, device: CPU + Module: model.layers.9.mlp.experts.41, device: CPU + model.layers.9.mlp.experts.41.gate_proj, device: CPU + model.layers.9.mlp.experts.41.up_proj, device: CPU + model.layers.9.mlp.experts.41.down_proj, device: CPU + model.layers.9.mlp.experts.41.act, device: CPU + Module: model.layers.9.mlp.experts.42, device: CPU + model.layers.9.mlp.experts.42.gate_proj, device: CPU + model.layers.9.mlp.experts.42.up_proj, device: CPU + model.layers.9.mlp.experts.42.down_proj, device: CPU + model.layers.9.mlp.experts.42.act, device: CPU + Module: model.layers.9.mlp.experts.43, device: CPU + model.layers.9.mlp.experts.43.gate_proj, device: CPU + model.layers.9.mlp.experts.43.up_proj, device: CPU + model.layers.9.mlp.experts.43.down_proj, device: CPU + model.layers.9.mlp.experts.43.act, device: CPU + Module: model.layers.9.mlp.experts.44, device: CPU + model.layers.9.mlp.experts.44.gate_proj, device: CPU + model.layers.9.mlp.experts.44.up_proj, device: CPU + model.layers.9.mlp.experts.44.down_proj, device: CPU + model.layers.9.mlp.experts.44.act, device: CPU + Module: model.layers.9.mlp.experts.45, device: CPU + model.layers.9.mlp.experts.45.gate_proj, device: CPU + model.layers.9.mlp.experts.45.up_proj, device: CPU + model.layers.9.mlp.experts.45.down_proj, device: CPU + model.layers.9.mlp.experts.45.act, device: CPU + Module: model.layers.9.mlp.experts.46, device: CPU + model.layers.9.mlp.experts.46.gate_proj, device: CPU + model.layers.9.mlp.experts.46.up_proj, device: CPU + model.layers.9.mlp.experts.46.down_proj, device: CPU + model.layers.9.mlp.experts.46.act, device: CPU + Module: model.layers.9.mlp.experts.47, device: CPU + model.layers.9.mlp.experts.47.gate_proj, device: CPU + model.layers.9.mlp.experts.47.up_proj, device: CPU + model.layers.9.mlp.experts.47.down_proj, device: CPU + model.layers.9.mlp.experts.47.act, device: CPU + Module: model.layers.9.mlp.experts.48, device: CPU + model.layers.9.mlp.experts.48.gate_proj, device: CPU + model.layers.9.mlp.experts.48.up_proj, device: CPU + model.layers.9.mlp.experts.48.down_proj, device: CPU + model.layers.9.mlp.experts.48.act, device: CPU + Module: model.layers.9.mlp.experts.49, device: CPU + model.layers.9.mlp.experts.49.gate_proj, device: CPU + model.layers.9.mlp.experts.49.up_proj, device: CPU + model.layers.9.mlp.experts.49.down_proj, device: CPU + model.layers.9.mlp.experts.49.act, device: CPU + Module: model.layers.9.mlp.experts.50, device: CPU + model.layers.9.mlp.experts.50.gate_proj, device: CPU + model.layers.9.mlp.experts.50.up_proj, device: CPU + model.layers.9.mlp.experts.50.down_proj, device: CPU + model.layers.9.mlp.experts.50.act, device: CPU + Module: model.layers.9.mlp.experts.51, device: CPU + model.layers.9.mlp.experts.51.gate_proj, device: CPU + model.layers.9.mlp.experts.51.up_proj, device: CPU + model.layers.9.mlp.experts.51.down_proj, device: CPU + model.layers.9.mlp.experts.51.act, device: CPU + Module: model.layers.9.mlp.experts.52, device: CPU + model.layers.9.mlp.experts.52.gate_proj, device: CPU + model.layers.9.mlp.experts.52.up_proj, device: CPU + model.layers.9.mlp.experts.52.down_proj, device: CPU + model.layers.9.mlp.experts.52.act, device: CPU + Module: model.layers.9.mlp.experts.53, device: CPU + model.layers.9.mlp.experts.53.gate_proj, device: CPU + model.layers.9.mlp.experts.53.up_proj, device: CPU + model.layers.9.mlp.experts.53.down_proj, device: CPU + model.layers.9.mlp.experts.53.act, device: CPU + Module: model.layers.9.mlp.experts.54, device: CPU + model.layers.9.mlp.experts.54.gate_proj, device: CPU + model.layers.9.mlp.experts.54.up_proj, device: CPU + model.layers.9.mlp.experts.54.down_proj, device: CPU + model.layers.9.mlp.experts.54.act, device: CPU + Module: model.layers.9.mlp.experts.55, device: CPU + model.layers.9.mlp.experts.55.gate_proj, device: CPU + model.layers.9.mlp.experts.55.up_proj, device: CPU + model.layers.9.mlp.experts.55.down_proj, device: CPU + model.layers.9.mlp.experts.55.act, device: CPU + Module: model.layers.9.mlp.experts.56, device: CPU + model.layers.9.mlp.experts.56.gate_proj, device: CPU + model.layers.9.mlp.experts.56.up_proj, device: CPU + model.layers.9.mlp.experts.56.down_proj, device: CPU + model.layers.9.mlp.experts.56.act, device: CPU + Module: model.layers.9.mlp.experts.57, device: CPU + model.layers.9.mlp.experts.57.gate_proj, device: CPU + model.layers.9.mlp.experts.57.up_proj, device: CPU + model.layers.9.mlp.experts.57.down_proj, device: CPU + model.layers.9.mlp.experts.57.act, device: CPU + Module: model.layers.9.mlp.experts.58, device: CPU + model.layers.9.mlp.experts.58.gate_proj, device: CPU + model.layers.9.mlp.experts.58.up_proj, device: CPU + model.layers.9.mlp.experts.58.down_proj, device: CPU + model.layers.9.mlp.experts.58.act, device: CPU + Module: model.layers.9.mlp.experts.59, device: CPU + model.layers.9.mlp.experts.59.gate_proj, device: CPU + model.layers.9.mlp.experts.59.up_proj, device: CPU + model.layers.9.mlp.experts.59.down_proj, device: CPU + model.layers.9.mlp.experts.59.act, device: CPU + Module: model.layers.9.mlp.experts.60, device: CPU + model.layers.9.mlp.experts.60.gate_proj, device: CPU + model.layers.9.mlp.experts.60.up_proj, device: CPU + model.layers.9.mlp.experts.60.down_proj, device: CPU + model.layers.9.mlp.experts.60.act, device: CPU + Module: model.layers.9.mlp.experts.61, device: CPU + model.layers.9.mlp.experts.61.gate_proj, device: CPU + model.layers.9.mlp.experts.61.up_proj, device: CPU + model.layers.9.mlp.experts.61.down_proj, device: CPU + model.layers.9.mlp.experts.61.act, device: CPU + Module: model.layers.9.mlp.experts.62, device: CPU + model.layers.9.mlp.experts.62.gate_proj, device: CPU + model.layers.9.mlp.experts.62.up_proj, device: CPU + model.layers.9.mlp.experts.62.down_proj, device: CPU + model.layers.9.mlp.experts.62.act, device: CPU + Module: model.layers.9.mlp.experts.63, device: CPU + model.layers.9.mlp.experts.63.gate_proj, device: CPU + model.layers.9.mlp.experts.63.up_proj, device: CPU + model.layers.9.mlp.experts.63.down_proj, device: CPU + model.layers.9.mlp.experts.63.act, device: CPU + Module: model.layers.9.mlp.gate, device: CPU + model.layers.9.mlp.gate.weight, device: CPU + Module: model.layers.9.mlp.shared_experts, device: CPU + model.layers.9.mlp.shared_experts.gate_proj, device: CPU + model.layers.9.mlp.shared_experts.up_proj, device: CPU + model.layers.9.mlp.shared_experts.down_proj, device: CPU + model.layers.9.mlp.shared_experts.act, device: CPU + model.layers.9.input_layernorm, device: CPU + model.layers.9.post_attention_layernorm, device: CPU + Module: model.layers.10, device: CPU + Module: model.layers.10.self_attn, device: CPU + model.layers.10.self_attn.q_proj, device: CPU + model.layers.10.self_attn.k_proj, device: CPU + model.layers.10.self_attn.v_proj, device: CPU + model.layers.10.self_attn.o_proj, device: CPU + model.layers.10.self_attn.q_rope, device: CPU + model.layers.10.self_attn.k_rope, device: CPU + Module: model.layers.10.mlp, device: CPU + Module: model.layers.10.mlp.experts, device: CPU + Module: model.layers.10.mlp.experts.0, device: CPU + model.layers.10.mlp.experts.0.gate_proj, device: CPU + model.layers.10.mlp.experts.0.up_proj, device: CPU + model.layers.10.mlp.experts.0.down_proj, device: CPU + model.layers.10.mlp.experts.0.act, device: CPU + Module: model.layers.10.mlp.experts.1, device: CPU + model.layers.10.mlp.experts.1.gate_proj, device: CPU + model.layers.10.mlp.experts.1.up_proj, device: CPU + model.layers.10.mlp.experts.1.down_proj, device: CPU + model.layers.10.mlp.experts.1.act, device: CPU + Module: model.layers.10.mlp.experts.2, device: CPU + model.layers.10.mlp.experts.2.gate_proj, device: CPU + model.layers.10.mlp.experts.2.up_proj, device: CPU + model.layers.10.mlp.experts.2.down_proj, device: CPU + model.layers.10.mlp.experts.2.act, device: CPU + Module: model.layers.10.mlp.experts.3, device: CPU + model.layers.10.mlp.experts.3.gate_proj, device: CPU + model.layers.10.mlp.experts.3.up_proj, device: CPU + model.layers.10.mlp.experts.3.down_proj, device: CPU + model.layers.10.mlp.experts.3.act, device: CPU + Module: model.layers.10.mlp.experts.4, device: CPU + model.layers.10.mlp.experts.4.gate_proj, device: CPU + model.layers.10.mlp.experts.4.up_proj, device: CPU + model.layers.10.mlp.experts.4.down_proj, device: CPU + model.layers.10.mlp.experts.4.act, device: CPU + Module: model.layers.10.mlp.experts.5, device: CPU + model.layers.10.mlp.experts.5.gate_proj, device: CPU + model.layers.10.mlp.experts.5.up_proj, device: CPU + model.layers.10.mlp.experts.5.down_proj, device: CPU + model.layers.10.mlp.experts.5.act, device: CPU + Module: model.layers.10.mlp.experts.6, device: CPU + model.layers.10.mlp.experts.6.gate_proj, device: CPU + model.layers.10.mlp.experts.6.up_proj, device: CPU + model.layers.10.mlp.experts.6.down_proj, device: CPU + model.layers.10.mlp.experts.6.act, device: CPU + Module: model.layers.10.mlp.experts.7, device: CPU + model.layers.10.mlp.experts.7.gate_proj, device: CPU + model.layers.10.mlp.experts.7.up_proj, device: CPU + model.layers.10.mlp.experts.7.down_proj, device: CPU + model.layers.10.mlp.experts.7.act, device: CPU + Module: model.layers.10.mlp.experts.8, device: CPU + model.layers.10.mlp.experts.8.gate_proj, device: CPU + model.layers.10.mlp.experts.8.up_proj, device: CPU + model.layers.10.mlp.experts.8.down_proj, device: CPU + model.layers.10.mlp.experts.8.act, device: CPU + Module: model.layers.10.mlp.experts.9, device: CPU + model.layers.10.mlp.experts.9.gate_proj, device: CPU + model.layers.10.mlp.experts.9.up_proj, device: CPU + model.layers.10.mlp.experts.9.down_proj, device: CPU + model.layers.10.mlp.experts.9.act, device: CPU + Module: model.layers.10.mlp.experts.10, device: CPU + model.layers.10.mlp.experts.10.gate_proj, device: CPU + model.layers.10.mlp.experts.10.up_proj, device: CPU + model.layers.10.mlp.experts.10.down_proj, device: CPU + model.layers.10.mlp.experts.10.act, device: CPU + Module: model.layers.10.mlp.experts.11, device: CPU + model.layers.10.mlp.experts.11.gate_proj, device: CPU + model.layers.10.mlp.experts.11.up_proj, device: CPU + model.layers.10.mlp.experts.11.down_proj, device: CPU + model.layers.10.mlp.experts.11.act, device: CPU + Module: model.layers.10.mlp.experts.12, device: CPU + model.layers.10.mlp.experts.12.gate_proj, device: CPU + model.layers.10.mlp.experts.12.up_proj, device: CPU + model.layers.10.mlp.experts.12.down_proj, device: CPU + model.layers.10.mlp.experts.12.act, device: CPU + Module: model.layers.10.mlp.experts.13, device: CPU + model.layers.10.mlp.experts.13.gate_proj, device: CPU + model.layers.10.mlp.experts.13.up_proj, device: CPU + model.layers.10.mlp.experts.13.down_proj, device: CPU + model.layers.10.mlp.experts.13.act, device: CPU + Module: model.layers.10.mlp.experts.14, device: CPU + model.layers.10.mlp.experts.14.gate_proj, device: CPU + model.layers.10.mlp.experts.14.up_proj, device: CPU + model.layers.10.mlp.experts.14.down_proj, device: CPU + model.layers.10.mlp.experts.14.act, device: CPU + Module: model.layers.10.mlp.experts.15, device: CPU + model.layers.10.mlp.experts.15.gate_proj, device: CPU + model.layers.10.mlp.experts.15.up_proj, device: CPU + model.layers.10.mlp.experts.15.down_proj, device: CPU + model.layers.10.mlp.experts.15.act, device: CPU + Module: model.layers.10.mlp.experts.16, device: CPU + model.layers.10.mlp.experts.16.gate_proj, device: CPU + model.layers.10.mlp.experts.16.up_proj, device: CPU + model.layers.10.mlp.experts.16.down_proj, device: CPU + model.layers.10.mlp.experts.16.act, device: CPU + Module: model.layers.10.mlp.experts.17, device: CPU + model.layers.10.mlp.experts.17.gate_proj, device: CPU + model.layers.10.mlp.experts.17.up_proj, device: CPU + model.layers.10.mlp.experts.17.down_proj, device: CPU + model.layers.10.mlp.experts.17.act, device: CPU + Module: model.layers.10.mlp.experts.18, device: CPU + model.layers.10.mlp.experts.18.gate_proj, device: CPU + model.layers.10.mlp.experts.18.up_proj, device: CPU + model.layers.10.mlp.experts.18.down_proj, device: CPU + model.layers.10.mlp.experts.18.act, device: CPU + Module: model.layers.10.mlp.experts.19, device: CPU + model.layers.10.mlp.experts.19.gate_proj, device: CPU + model.layers.10.mlp.experts.19.up_proj, device: CPU + model.layers.10.mlp.experts.19.down_proj, device: CPU + model.layers.10.mlp.experts.19.act, device: CPU + Module: model.layers.10.mlp.experts.20, device: CPU + model.layers.10.mlp.experts.20.gate_proj, device: CPU + model.layers.10.mlp.experts.20.up_proj, device: CPU + model.layers.10.mlp.experts.20.down_proj, device: CPU + model.layers.10.mlp.experts.20.act, device: CPU + Module: model.layers.10.mlp.experts.21, device: CPU + model.layers.10.mlp.experts.21.gate_proj, device: CPU + model.layers.10.mlp.experts.21.up_proj, device: CPU + model.layers.10.mlp.experts.21.down_proj, device: CPU + model.layers.10.mlp.experts.21.act, device: CPU + Module: model.layers.10.mlp.experts.22, device: CPU + model.layers.10.mlp.experts.22.gate_proj, device: CPU + model.layers.10.mlp.experts.22.up_proj, device: CPU + model.layers.10.mlp.experts.22.down_proj, device: CPU + model.layers.10.mlp.experts.22.act, device: CPU + Module: model.layers.10.mlp.experts.23, device: CPU + model.layers.10.mlp.experts.23.gate_proj, device: CPU + model.layers.10.mlp.experts.23.up_proj, device: CPU + model.layers.10.mlp.experts.23.down_proj, device: CPU + model.layers.10.mlp.experts.23.act, device: CPU + Module: model.layers.10.mlp.experts.24, device: CPU + model.layers.10.mlp.experts.24.gate_proj, device: CPU + model.layers.10.mlp.experts.24.up_proj, device: CPU + model.layers.10.mlp.experts.24.down_proj, device: CPU + model.layers.10.mlp.experts.24.act, device: CPU + Module: model.layers.10.mlp.experts.25, device: CPU + model.layers.10.mlp.experts.25.gate_proj, device: CPU + model.layers.10.mlp.experts.25.up_proj, device: CPU + model.layers.10.mlp.experts.25.down_proj, device: CPU + model.layers.10.mlp.experts.25.act, device: CPU + Module: model.layers.10.mlp.experts.26, device: CPU + model.layers.10.mlp.experts.26.gate_proj, device: CPU + model.layers.10.mlp.experts.26.up_proj, device: CPU + model.layers.10.mlp.experts.26.down_proj, device: CPU + model.layers.10.mlp.experts.26.act, device: CPU + Module: model.layers.10.mlp.experts.27, device: CPU + model.layers.10.mlp.experts.27.gate_proj, device: CPU + model.layers.10.mlp.experts.27.up_proj, device: CPU + model.layers.10.mlp.experts.27.down_proj, device: CPU + model.layers.10.mlp.experts.27.act, device: CPU + Module: model.layers.10.mlp.experts.28, device: CPU + model.layers.10.mlp.experts.28.gate_proj, device: CPU + model.layers.10.mlp.experts.28.up_proj, device: CPU + model.layers.10.mlp.experts.28.down_proj, device: CPU + model.layers.10.mlp.experts.28.act, device: CPU + Module: model.layers.10.mlp.experts.29, device: CPU + model.layers.10.mlp.experts.29.gate_proj, device: CPU + model.layers.10.mlp.experts.29.up_proj, device: CPU + model.layers.10.mlp.experts.29.down_proj, device: CPU + model.layers.10.mlp.experts.29.act, device: CPU + Module: model.layers.10.mlp.experts.30, device: CPU + model.layers.10.mlp.experts.30.gate_proj, device: CPU + model.layers.10.mlp.experts.30.up_proj, device: CPU + model.layers.10.mlp.experts.30.down_proj, device: CPU + model.layers.10.mlp.experts.30.act, device: CPU + Module: model.layers.10.mlp.experts.31, device: CPU + model.layers.10.mlp.experts.31.gate_proj, device: CPU + model.layers.10.mlp.experts.31.up_proj, device: CPU + model.layers.10.mlp.experts.31.down_proj, device: CPU + model.layers.10.mlp.experts.31.act, device: CPU + Module: model.layers.10.mlp.experts.32, device: CPU + model.layers.10.mlp.experts.32.gate_proj, device: CPU + model.layers.10.mlp.experts.32.up_proj, device: CPU + model.layers.10.mlp.experts.32.down_proj, device: CPU + model.layers.10.mlp.experts.32.act, device: CPU + Module: model.layers.10.mlp.experts.33, device: CPU + model.layers.10.mlp.experts.33.gate_proj, device: CPU + model.layers.10.mlp.experts.33.up_proj, device: CPU + model.layers.10.mlp.experts.33.down_proj, device: CPU + model.layers.10.mlp.experts.33.act, device: CPU + Module: model.layers.10.mlp.experts.34, device: CPU + model.layers.10.mlp.experts.34.gate_proj, device: CPU + model.layers.10.mlp.experts.34.up_proj, device: CPU + model.layers.10.mlp.experts.34.down_proj, device: CPU + model.layers.10.mlp.experts.34.act, device: CPU + Module: model.layers.10.mlp.experts.35, device: CPU + model.layers.10.mlp.experts.35.gate_proj, device: CPU + model.layers.10.mlp.experts.35.up_proj, device: CPU + model.layers.10.mlp.experts.35.down_proj, device: CPU + model.layers.10.mlp.experts.35.act, device: CPU + Module: model.layers.10.mlp.experts.36, device: CPU + model.layers.10.mlp.experts.36.gate_proj, device: CPU + model.layers.10.mlp.experts.36.up_proj, device: CPU + model.layers.10.mlp.experts.36.down_proj, device: CPU + model.layers.10.mlp.experts.36.act, device: CPU + Module: model.layers.10.mlp.experts.37, device: CPU + model.layers.10.mlp.experts.37.gate_proj, device: CPU + model.layers.10.mlp.experts.37.up_proj, device: CPU + model.layers.10.mlp.experts.37.down_proj, device: CPU + model.layers.10.mlp.experts.37.act, device: CPU + Module: model.layers.10.mlp.experts.38, device: CPU + model.layers.10.mlp.experts.38.gate_proj, device: CPU + model.layers.10.mlp.experts.38.up_proj, device: CPU + model.layers.10.mlp.experts.38.down_proj, device: CPU + model.layers.10.mlp.experts.38.act, device: CPU + Module: model.layers.10.mlp.experts.39, device: CPU + model.layers.10.mlp.experts.39.gate_proj, device: CPU + model.layers.10.mlp.experts.39.up_proj, device: CPU + model.layers.10.mlp.experts.39.down_proj, device: CPU + model.layers.10.mlp.experts.39.act, device: CPU + Module: model.layers.10.mlp.experts.40, device: CPU + model.layers.10.mlp.experts.40.gate_proj, device: CPU + model.layers.10.mlp.experts.40.up_proj, device: CPU + model.layers.10.mlp.experts.40.down_proj, device: CPU + model.layers.10.mlp.experts.40.act, device: CPU + Module: model.layers.10.mlp.experts.41, device: CPU + model.layers.10.mlp.experts.41.gate_proj, device: CPU + model.layers.10.mlp.experts.41.up_proj, device: CPU + model.layers.10.mlp.experts.41.down_proj, device: CPU + model.layers.10.mlp.experts.41.act, device: CPU + Module: model.layers.10.mlp.experts.42, device: CPU + model.layers.10.mlp.experts.42.gate_proj, device: CPU + model.layers.10.mlp.experts.42.up_proj, device: CPU + model.layers.10.mlp.experts.42.down_proj, device: CPU + model.layers.10.mlp.experts.42.act, device: CPU + Module: model.layers.10.mlp.experts.43, device: CPU + model.layers.10.mlp.experts.43.gate_proj, device: CPU + model.layers.10.mlp.experts.43.up_proj, device: CPU + model.layers.10.mlp.experts.43.down_proj, device: CPU + model.layers.10.mlp.experts.43.act, device: CPU + Module: model.layers.10.mlp.experts.44, device: CPU + model.layers.10.mlp.experts.44.gate_proj, device: CPU + model.layers.10.mlp.experts.44.up_proj, device: CPU + model.layers.10.mlp.experts.44.down_proj, device: CPU + model.layers.10.mlp.experts.44.act, device: CPU + Module: model.layers.10.mlp.experts.45, device: CPU + model.layers.10.mlp.experts.45.gate_proj, device: CPU + model.layers.10.mlp.experts.45.up_proj, device: CPU + model.layers.10.mlp.experts.45.down_proj, device: CPU + model.layers.10.mlp.experts.45.act, device: CPU + Module: model.layers.10.mlp.experts.46, device: CPU + model.layers.10.mlp.experts.46.gate_proj, device: CPU + model.layers.10.mlp.experts.46.up_proj, device: CPU + model.layers.10.mlp.experts.46.down_proj, device: CPU + model.layers.10.mlp.experts.46.act, device: CPU + Module: model.layers.10.mlp.experts.47, device: CPU + model.layers.10.mlp.experts.47.gate_proj, device: CPU + model.layers.10.mlp.experts.47.up_proj, device: CPU + model.layers.10.mlp.experts.47.down_proj, device: CPU + model.layers.10.mlp.experts.47.act, device: CPU + Module: model.layers.10.mlp.experts.48, device: CPU + model.layers.10.mlp.experts.48.gate_proj, device: CPU + model.layers.10.mlp.experts.48.up_proj, device: CPU + model.layers.10.mlp.experts.48.down_proj, device: CPU + model.layers.10.mlp.experts.48.act, device: CPU + Module: model.layers.10.mlp.experts.49, device: CPU + model.layers.10.mlp.experts.49.gate_proj, device: CPU + model.layers.10.mlp.experts.49.up_proj, device: CPU + model.layers.10.mlp.experts.49.down_proj, device: CPU + model.layers.10.mlp.experts.49.act, device: CPU + Module: model.layers.10.mlp.experts.50, device: CPU + model.layers.10.mlp.experts.50.gate_proj, device: CPU + model.layers.10.mlp.experts.50.up_proj, device: CPU + model.layers.10.mlp.experts.50.down_proj, device: CPU + model.layers.10.mlp.experts.50.act, device: CPU + Module: model.layers.10.mlp.experts.51, device: CPU + model.layers.10.mlp.experts.51.gate_proj, device: CPU + model.layers.10.mlp.experts.51.up_proj, device: CPU + model.layers.10.mlp.experts.51.down_proj, device: CPU + model.layers.10.mlp.experts.51.act, device: CPU + Module: model.layers.10.mlp.experts.52, device: CPU + model.layers.10.mlp.experts.52.gate_proj, device: CPU + model.layers.10.mlp.experts.52.up_proj, device: CPU + model.layers.10.mlp.experts.52.down_proj, device: CPU + model.layers.10.mlp.experts.52.act, device: CPU + Module: model.layers.10.mlp.experts.53, device: CPU + model.layers.10.mlp.experts.53.gate_proj, device: CPU + model.layers.10.mlp.experts.53.up_proj, device: CPU + model.layers.10.mlp.experts.53.down_proj, device: CPU + model.layers.10.mlp.experts.53.act, device: CPU + Module: model.layers.10.mlp.experts.54, device: CPU + model.layers.10.mlp.experts.54.gate_proj, device: CPU + model.layers.10.mlp.experts.54.up_proj, device: CPU + model.layers.10.mlp.experts.54.down_proj, device: CPU + model.layers.10.mlp.experts.54.act, device: CPU + Module: model.layers.10.mlp.experts.55, device: CPU + model.layers.10.mlp.experts.55.gate_proj, device: CPU + model.layers.10.mlp.experts.55.up_proj, device: CPU + model.layers.10.mlp.experts.55.down_proj, device: CPU + model.layers.10.mlp.experts.55.act, device: CPU + Module: model.layers.10.mlp.experts.56, device: CPU + model.layers.10.mlp.experts.56.gate_proj, device: CPU + model.layers.10.mlp.experts.56.up_proj, device: CPU + model.layers.10.mlp.experts.56.down_proj, device: CPU + model.layers.10.mlp.experts.56.act, device: CPU + Module: model.layers.10.mlp.experts.57, device: CPU + model.layers.10.mlp.experts.57.gate_proj, device: CPU + model.layers.10.mlp.experts.57.up_proj, device: CPU + model.layers.10.mlp.experts.57.down_proj, device: CPU + model.layers.10.mlp.experts.57.act, device: CPU + Module: model.layers.10.mlp.experts.58, device: CPU + model.layers.10.mlp.experts.58.gate_proj, device: CPU + model.layers.10.mlp.experts.58.up_proj, device: CPU + model.layers.10.mlp.experts.58.down_proj, device: CPU + model.layers.10.mlp.experts.58.act, device: CPU + Module: model.layers.10.mlp.experts.59, device: CPU + model.layers.10.mlp.experts.59.gate_proj, device: CPU + model.layers.10.mlp.experts.59.up_proj, device: CPU + model.layers.10.mlp.experts.59.down_proj, device: CPU + model.layers.10.mlp.experts.59.act, device: CPU + Module: model.layers.10.mlp.experts.60, device: CPU + model.layers.10.mlp.experts.60.gate_proj, device: CPU + model.layers.10.mlp.experts.60.up_proj, device: CPU + model.layers.10.mlp.experts.60.down_proj, device: CPU + model.layers.10.mlp.experts.60.act, device: CPU + Module: model.layers.10.mlp.experts.61, device: CPU + model.layers.10.mlp.experts.61.gate_proj, device: CPU + model.layers.10.mlp.experts.61.up_proj, device: CPU + model.layers.10.mlp.experts.61.down_proj, device: CPU + model.layers.10.mlp.experts.61.act, device: CPU + Module: model.layers.10.mlp.experts.62, device: CPU + model.layers.10.mlp.experts.62.gate_proj, device: CPU + model.layers.10.mlp.experts.62.up_proj, device: CPU + model.layers.10.mlp.experts.62.down_proj, device: CPU + model.layers.10.mlp.experts.62.act, device: CPU + Module: model.layers.10.mlp.experts.63, device: CPU + model.layers.10.mlp.experts.63.gate_proj, device: CPU + model.layers.10.mlp.experts.63.up_proj, device: CPU + model.layers.10.mlp.experts.63.down_proj, device: CPU + model.layers.10.mlp.experts.63.act, device: CPU + Module: model.layers.10.mlp.gate, device: CPU + model.layers.10.mlp.gate.weight, device: CPU + Module: model.layers.10.mlp.shared_experts, device: CPU + model.layers.10.mlp.shared_experts.gate_proj, device: CPU + model.layers.10.mlp.shared_experts.up_proj, device: CPU + model.layers.10.mlp.shared_experts.down_proj, device: CPU + model.layers.10.mlp.shared_experts.act, device: CPU + model.layers.10.input_layernorm, device: CPU + model.layers.10.post_attention_layernorm, device: CPU + Module: model.layers.11, device: CPU + Module: model.layers.11.self_attn, device: CPU + model.layers.11.self_attn.q_proj, device: CPU + model.layers.11.self_attn.k_proj, device: CPU + model.layers.11.self_attn.v_proj, device: CPU + model.layers.11.self_attn.o_proj, device: CPU + model.layers.11.self_attn.q_rope, device: CPU + model.layers.11.self_attn.k_rope, device: CPU + Module: model.layers.11.mlp, device: CPU + Module: model.layers.11.mlp.experts, device: CPU + Module: model.layers.11.mlp.experts.0, device: CPU + model.layers.11.mlp.experts.0.gate_proj, device: CPU + model.layers.11.mlp.experts.0.up_proj, device: CPU + model.layers.11.mlp.experts.0.down_proj, device: CPU + model.layers.11.mlp.experts.0.act, device: CPU + Module: model.layers.11.mlp.experts.1, device: CPU + model.layers.11.mlp.experts.1.gate_proj, device: CPU + model.layers.11.mlp.experts.1.up_proj, device: CPU + model.layers.11.mlp.experts.1.down_proj, device: CPU + model.layers.11.mlp.experts.1.act, device: CPU + Module: model.layers.11.mlp.experts.2, device: CPU + model.layers.11.mlp.experts.2.gate_proj, device: CPU + model.layers.11.mlp.experts.2.up_proj, device: CPU + model.layers.11.mlp.experts.2.down_proj, device: CPU + model.layers.11.mlp.experts.2.act, device: CPU + Module: model.layers.11.mlp.experts.3, device: CPU + model.layers.11.mlp.experts.3.gate_proj, device: CPU + model.layers.11.mlp.experts.3.up_proj, device: CPU + model.layers.11.mlp.experts.3.down_proj, device: CPU + model.layers.11.mlp.experts.3.act, device: CPU + Module: model.layers.11.mlp.experts.4, device: CPU + model.layers.11.mlp.experts.4.gate_proj, device: CPU + model.layers.11.mlp.experts.4.up_proj, device: CPU + model.layers.11.mlp.experts.4.down_proj, device: CPU + model.layers.11.mlp.experts.4.act, device: CPU + Module: model.layers.11.mlp.experts.5, device: CPU + model.layers.11.mlp.experts.5.gate_proj, device: CPU + model.layers.11.mlp.experts.5.up_proj, device: CPU + model.layers.11.mlp.experts.5.down_proj, device: CPU + model.layers.11.mlp.experts.5.act, device: CPU + Module: model.layers.11.mlp.experts.6, device: CPU + model.layers.11.mlp.experts.6.gate_proj, device: CPU + model.layers.11.mlp.experts.6.up_proj, device: CPU + model.layers.11.mlp.experts.6.down_proj, device: CPU + model.layers.11.mlp.experts.6.act, device: CPU + Module: model.layers.11.mlp.experts.7, device: CPU + model.layers.11.mlp.experts.7.gate_proj, device: CPU + model.layers.11.mlp.experts.7.up_proj, device: CPU + model.layers.11.mlp.experts.7.down_proj, device: CPU + model.layers.11.mlp.experts.7.act, device: CPU + Module: model.layers.11.mlp.experts.8, device: CPU + model.layers.11.mlp.experts.8.gate_proj, device: CPU + model.layers.11.mlp.experts.8.up_proj, device: CPU + model.layers.11.mlp.experts.8.down_proj, device: CPU + model.layers.11.mlp.experts.8.act, device: CPU + Module: model.layers.11.mlp.experts.9, device: CPU + model.layers.11.mlp.experts.9.gate_proj, device: CPU + model.layers.11.mlp.experts.9.up_proj, device: CPU + model.layers.11.mlp.experts.9.down_proj, device: CPU + model.layers.11.mlp.experts.9.act, device: CPU + Module: model.layers.11.mlp.experts.10, device: CPU + model.layers.11.mlp.experts.10.gate_proj, device: CPU + model.layers.11.mlp.experts.10.up_proj, device: CPU + model.layers.11.mlp.experts.10.down_proj, device: CPU + model.layers.11.mlp.experts.10.act, device: CPU + Module: model.layers.11.mlp.experts.11, device: CPU + model.layers.11.mlp.experts.11.gate_proj, device: CPU + model.layers.11.mlp.experts.11.up_proj, device: CPU + model.layers.11.mlp.experts.11.down_proj, device: CPU + model.layers.11.mlp.experts.11.act, device: CPU + Module: model.layers.11.mlp.experts.12, device: CPU + model.layers.11.mlp.experts.12.gate_proj, device: CPU + model.layers.11.mlp.experts.12.up_proj, device: CPU + model.layers.11.mlp.experts.12.down_proj, device: CPU + model.layers.11.mlp.experts.12.act, device: CPU + Module: model.layers.11.mlp.experts.13, device: CPU + model.layers.11.mlp.experts.13.gate_proj, device: CPU + model.layers.11.mlp.experts.13.up_proj, device: CPU + model.layers.11.mlp.experts.13.down_proj, device: CPU + model.layers.11.mlp.experts.13.act, device: CPU + Module: model.layers.11.mlp.experts.14, device: CPU + model.layers.11.mlp.experts.14.gate_proj, device: CPU + model.layers.11.mlp.experts.14.up_proj, device: CPU + model.layers.11.mlp.experts.14.down_proj, device: CPU + model.layers.11.mlp.experts.14.act, device: CPU + Module: model.layers.11.mlp.experts.15, device: CPU + model.layers.11.mlp.experts.15.gate_proj, device: CPU + model.layers.11.mlp.experts.15.up_proj, device: CPU + model.layers.11.mlp.experts.15.down_proj, device: CPU + model.layers.11.mlp.experts.15.act, device: CPU + Module: model.layers.11.mlp.experts.16, device: CPU + model.layers.11.mlp.experts.16.gate_proj, device: CPU + model.layers.11.mlp.experts.16.up_proj, device: CPU + model.layers.11.mlp.experts.16.down_proj, device: CPU + model.layers.11.mlp.experts.16.act, device: CPU + Module: model.layers.11.mlp.experts.17, device: CPU + model.layers.11.mlp.experts.17.gate_proj, device: CPU + model.layers.11.mlp.experts.17.up_proj, device: CPU + model.layers.11.mlp.experts.17.down_proj, device: CPU + model.layers.11.mlp.experts.17.act, device: CPU + Module: model.layers.11.mlp.experts.18, device: CPU + model.layers.11.mlp.experts.18.gate_proj, device: CPU + model.layers.11.mlp.experts.18.up_proj, device: CPU + model.layers.11.mlp.experts.18.down_proj, device: CPU + model.layers.11.mlp.experts.18.act, device: CPU + Module: model.layers.11.mlp.experts.19, device: CPU + model.layers.11.mlp.experts.19.gate_proj, device: CPU + model.layers.11.mlp.experts.19.up_proj, device: CPU + model.layers.11.mlp.experts.19.down_proj, device: CPU + model.layers.11.mlp.experts.19.act, device: CPU + Module: model.layers.11.mlp.experts.20, device: CPU + model.layers.11.mlp.experts.20.gate_proj, device: CPU + model.layers.11.mlp.experts.20.up_proj, device: CPU + model.layers.11.mlp.experts.20.down_proj, device: CPU + model.layers.11.mlp.experts.20.act, device: CPU + Module: model.layers.11.mlp.experts.21, device: CPU + model.layers.11.mlp.experts.21.gate_proj, device: CPU + model.layers.11.mlp.experts.21.up_proj, device: CPU + model.layers.11.mlp.experts.21.down_proj, device: CPU + model.layers.11.mlp.experts.21.act, device: CPU + Module: model.layers.11.mlp.experts.22, device: CPU + model.layers.11.mlp.experts.22.gate_proj, device: CPU + model.layers.11.mlp.experts.22.up_proj, device: CPU + model.layers.11.mlp.experts.22.down_proj, device: CPU + model.layers.11.mlp.experts.22.act, device: CPU + Module: model.layers.11.mlp.experts.23, device: CPU + model.layers.11.mlp.experts.23.gate_proj, device: CPU + model.layers.11.mlp.experts.23.up_proj, device: CPU + model.layers.11.mlp.experts.23.down_proj, device: CPU + model.layers.11.mlp.experts.23.act, device: CPU + Module: model.layers.11.mlp.experts.24, device: CPU + model.layers.11.mlp.experts.24.gate_proj, device: CPU + model.layers.11.mlp.experts.24.up_proj, device: CPU + model.layers.11.mlp.experts.24.down_proj, device: CPU + model.layers.11.mlp.experts.24.act, device: CPU + Module: model.layers.11.mlp.experts.25, device: CPU + model.layers.11.mlp.experts.25.gate_proj, device: CPU + model.layers.11.mlp.experts.25.up_proj, device: CPU + model.layers.11.mlp.experts.25.down_proj, device: CPU + model.layers.11.mlp.experts.25.act, device: CPU + Module: model.layers.11.mlp.experts.26, device: CPU + model.layers.11.mlp.experts.26.gate_proj, device: CPU + model.layers.11.mlp.experts.26.up_proj, device: CPU + model.layers.11.mlp.experts.26.down_proj, device: CPU + model.layers.11.mlp.experts.26.act, device: CPU + Module: model.layers.11.mlp.experts.27, device: CPU + model.layers.11.mlp.experts.27.gate_proj, device: CPU + model.layers.11.mlp.experts.27.up_proj, device: CPU + model.layers.11.mlp.experts.27.down_proj, device: CPU + model.layers.11.mlp.experts.27.act, device: CPU + Module: model.layers.11.mlp.experts.28, device: CPU + model.layers.11.mlp.experts.28.gate_proj, device: CPU + model.layers.11.mlp.experts.28.up_proj, device: CPU + model.layers.11.mlp.experts.28.down_proj, device: CPU + model.layers.11.mlp.experts.28.act, device: CPU + Module: model.layers.11.mlp.experts.29, device: CPU + model.layers.11.mlp.experts.29.gate_proj, device: CPU + model.layers.11.mlp.experts.29.up_proj, device: CPU + model.layers.11.mlp.experts.29.down_proj, device: CPU + model.layers.11.mlp.experts.29.act, device: CPU + Module: model.layers.11.mlp.experts.30, device: CPU + model.layers.11.mlp.experts.30.gate_proj, device: CPU + model.layers.11.mlp.experts.30.up_proj, device: CPU + model.layers.11.mlp.experts.30.down_proj, device: CPU + model.layers.11.mlp.experts.30.act, device: CPU + Module: model.layers.11.mlp.experts.31, device: CPU + model.layers.11.mlp.experts.31.gate_proj, device: CPU + model.layers.11.mlp.experts.31.up_proj, device: CPU + model.layers.11.mlp.experts.31.down_proj, device: CPU + model.layers.11.mlp.experts.31.act, device: CPU + Module: model.layers.11.mlp.experts.32, device: CPU + model.layers.11.mlp.experts.32.gate_proj, device: CPU + model.layers.11.mlp.experts.32.up_proj, device: CPU + model.layers.11.mlp.experts.32.down_proj, device: CPU + model.layers.11.mlp.experts.32.act, device: CPU + Module: model.layers.11.mlp.experts.33, device: CPU + model.layers.11.mlp.experts.33.gate_proj, device: CPU + model.layers.11.mlp.experts.33.up_proj, device: CPU + model.layers.11.mlp.experts.33.down_proj, device: CPU + model.layers.11.mlp.experts.33.act, device: CPU + Module: model.layers.11.mlp.experts.34, device: CPU + model.layers.11.mlp.experts.34.gate_proj, device: CPU + model.layers.11.mlp.experts.34.up_proj, device: CPU + model.layers.11.mlp.experts.34.down_proj, device: CPU + model.layers.11.mlp.experts.34.act, device: CPU + Module: model.layers.11.mlp.experts.35, device: CPU + model.layers.11.mlp.experts.35.gate_proj, device: CPU + model.layers.11.mlp.experts.35.up_proj, device: CPU + model.layers.11.mlp.experts.35.down_proj, device: CPU + model.layers.11.mlp.experts.35.act, device: CPU + Module: model.layers.11.mlp.experts.36, device: CPU + model.layers.11.mlp.experts.36.gate_proj, device: CPU + model.layers.11.mlp.experts.36.up_proj, device: CPU + model.layers.11.mlp.experts.36.down_proj, device: CPU + model.layers.11.mlp.experts.36.act, device: CPU + Module: model.layers.11.mlp.experts.37, device: CPU + model.layers.11.mlp.experts.37.gate_proj, device: CPU + model.layers.11.mlp.experts.37.up_proj, device: CPU + model.layers.11.mlp.experts.37.down_proj, device: CPU + model.layers.11.mlp.experts.37.act, device: CPU + Module: model.layers.11.mlp.experts.38, device: CPU + model.layers.11.mlp.experts.38.gate_proj, device: CPU + model.layers.11.mlp.experts.38.up_proj, device: CPU + model.layers.11.mlp.experts.38.down_proj, device: CPU + model.layers.11.mlp.experts.38.act, device: CPU + Module: model.layers.11.mlp.experts.39, device: CPU + model.layers.11.mlp.experts.39.gate_proj, device: CPU + model.layers.11.mlp.experts.39.up_proj, device: CPU + model.layers.11.mlp.experts.39.down_proj, device: CPU + model.layers.11.mlp.experts.39.act, device: CPU + Module: model.layers.11.mlp.experts.40, device: CPU + model.layers.11.mlp.experts.40.gate_proj, device: CPU + model.layers.11.mlp.experts.40.up_proj, device: CPU + model.layers.11.mlp.experts.40.down_proj, device: CPU + model.layers.11.mlp.experts.40.act, device: CPU + Module: model.layers.11.mlp.experts.41, device: CPU + model.layers.11.mlp.experts.41.gate_proj, device: CPU + model.layers.11.mlp.experts.41.up_proj, device: CPU + model.layers.11.mlp.experts.41.down_proj, device: CPU + model.layers.11.mlp.experts.41.act, device: CPU + Module: model.layers.11.mlp.experts.42, device: CPU + model.layers.11.mlp.experts.42.gate_proj, device: CPU + model.layers.11.mlp.experts.42.up_proj, device: CPU + model.layers.11.mlp.experts.42.down_proj, device: CPU + model.layers.11.mlp.experts.42.act, device: CPU + Module: model.layers.11.mlp.experts.43, device: CPU + model.layers.11.mlp.experts.43.gate_proj, device: CPU + model.layers.11.mlp.experts.43.up_proj, device: CPU + model.layers.11.mlp.experts.43.down_proj, device: CPU + model.layers.11.mlp.experts.43.act, device: CPU + Module: model.layers.11.mlp.experts.44, device: CPU + model.layers.11.mlp.experts.44.gate_proj, device: CPU + model.layers.11.mlp.experts.44.up_proj, device: CPU + model.layers.11.mlp.experts.44.down_proj, device: CPU + model.layers.11.mlp.experts.44.act, device: CPU + Module: model.layers.11.mlp.experts.45, device: CPU + model.layers.11.mlp.experts.45.gate_proj, device: CPU + model.layers.11.mlp.experts.45.up_proj, device: CPU + model.layers.11.mlp.experts.45.down_proj, device: CPU + model.layers.11.mlp.experts.45.act, device: CPU + Module: model.layers.11.mlp.experts.46, device: CPU + model.layers.11.mlp.experts.46.gate_proj, device: CPU + model.layers.11.mlp.experts.46.up_proj, device: CPU + model.layers.11.mlp.experts.46.down_proj, device: CPU + model.layers.11.mlp.experts.46.act, device: CPU + Module: model.layers.11.mlp.experts.47, device: CPU + model.layers.11.mlp.experts.47.gate_proj, device: CPU + model.layers.11.mlp.experts.47.up_proj, device: CPU + model.layers.11.mlp.experts.47.down_proj, device: CPU + model.layers.11.mlp.experts.47.act, device: CPU + Module: model.layers.11.mlp.experts.48, device: CPU + model.layers.11.mlp.experts.48.gate_proj, device: CPU + model.layers.11.mlp.experts.48.up_proj, device: CPU + model.layers.11.mlp.experts.48.down_proj, device: CPU + model.layers.11.mlp.experts.48.act, device: CPU + Module: model.layers.11.mlp.experts.49, device: CPU + model.layers.11.mlp.experts.49.gate_proj, device: CPU + model.layers.11.mlp.experts.49.up_proj, device: CPU + model.layers.11.mlp.experts.49.down_proj, device: CPU + model.layers.11.mlp.experts.49.act, device: CPU + Module: model.layers.11.mlp.experts.50, device: CPU + model.layers.11.mlp.experts.50.gate_proj, device: CPU + model.layers.11.mlp.experts.50.up_proj, device: CPU + model.layers.11.mlp.experts.50.down_proj, device: CPU + model.layers.11.mlp.experts.50.act, device: CPU + Module: model.layers.11.mlp.experts.51, device: CPU + model.layers.11.mlp.experts.51.gate_proj, device: CPU + model.layers.11.mlp.experts.51.up_proj, device: CPU + model.layers.11.mlp.experts.51.down_proj, device: CPU + model.layers.11.mlp.experts.51.act, device: CPU + Module: model.layers.11.mlp.experts.52, device: CPU + model.layers.11.mlp.experts.52.gate_proj, device: CPU + model.layers.11.mlp.experts.52.up_proj, device: CPU + model.layers.11.mlp.experts.52.down_proj, device: CPU + model.layers.11.mlp.experts.52.act, device: CPU + Module: model.layers.11.mlp.experts.53, device: CPU + model.layers.11.mlp.experts.53.gate_proj, device: CPU + model.layers.11.mlp.experts.53.up_proj, device: CPU + model.layers.11.mlp.experts.53.down_proj, device: CPU + model.layers.11.mlp.experts.53.act, device: CPU + Module: model.layers.11.mlp.experts.54, device: CPU + model.layers.11.mlp.experts.54.gate_proj, device: CPU + model.layers.11.mlp.experts.54.up_proj, device: CPU + model.layers.11.mlp.experts.54.down_proj, device: CPU + model.layers.11.mlp.experts.54.act, device: CPU + Module: model.layers.11.mlp.experts.55, device: CPU + model.layers.11.mlp.experts.55.gate_proj, device: CPU + model.layers.11.mlp.experts.55.up_proj, device: CPU + model.layers.11.mlp.experts.55.down_proj, device: CPU + model.layers.11.mlp.experts.55.act, device: CPU + Module: model.layers.11.mlp.experts.56, device: CPU + model.layers.11.mlp.experts.56.gate_proj, device: CPU + model.layers.11.mlp.experts.56.up_proj, device: CPU + model.layers.11.mlp.experts.56.down_proj, device: CPU + model.layers.11.mlp.experts.56.act, device: CPU + Module: model.layers.11.mlp.experts.57, device: CPU + model.layers.11.mlp.experts.57.gate_proj, device: CPU + model.layers.11.mlp.experts.57.up_proj, device: CPU + model.layers.11.mlp.experts.57.down_proj, device: CPU + model.layers.11.mlp.experts.57.act, device: CPU + Module: model.layers.11.mlp.experts.58, device: CPU + model.layers.11.mlp.experts.58.gate_proj, device: CPU + model.layers.11.mlp.experts.58.up_proj, device: CPU + model.layers.11.mlp.experts.58.down_proj, device: CPU + model.layers.11.mlp.experts.58.act, device: CPU + Module: model.layers.11.mlp.experts.59, device: CPU + model.layers.11.mlp.experts.59.gate_proj, device: CPU + model.layers.11.mlp.experts.59.up_proj, device: CPU + model.layers.11.mlp.experts.59.down_proj, device: CPU + model.layers.11.mlp.experts.59.act, device: CPU + Module: model.layers.11.mlp.experts.60, device: CPU + model.layers.11.mlp.experts.60.gate_proj, device: CPU + model.layers.11.mlp.experts.60.up_proj, device: CPU + model.layers.11.mlp.experts.60.down_proj, device: CPU + model.layers.11.mlp.experts.60.act, device: CPU + Module: model.layers.11.mlp.experts.61, device: CPU + model.layers.11.mlp.experts.61.gate_proj, device: CPU + model.layers.11.mlp.experts.61.up_proj, device: CPU + model.layers.11.mlp.experts.61.down_proj, device: CPU + model.layers.11.mlp.experts.61.act, device: CPU + Module: model.layers.11.mlp.experts.62, device: CPU + model.layers.11.mlp.experts.62.gate_proj, device: CPU + model.layers.11.mlp.experts.62.up_proj, device: CPU + model.layers.11.mlp.experts.62.down_proj, device: CPU + model.layers.11.mlp.experts.62.act, device: CPU + Module: model.layers.11.mlp.experts.63, device: CPU + model.layers.11.mlp.experts.63.gate_proj, device: CPU + model.layers.11.mlp.experts.63.up_proj, device: CPU + model.layers.11.mlp.experts.63.down_proj, device: CPU + model.layers.11.mlp.experts.63.act, device: CPU + Module: model.layers.11.mlp.gate, device: CPU + model.layers.11.mlp.gate.weight, device: CPU + Module: model.layers.11.mlp.shared_experts, device: CPU + model.layers.11.mlp.shared_experts.gate_proj, device: CPU + model.layers.11.mlp.shared_experts.up_proj, device: CPU + model.layers.11.mlp.shared_experts.down_proj, device: CPU + model.layers.11.mlp.shared_experts.act, device: CPU + model.layers.11.input_layernorm, device: CPU + model.layers.11.post_attention_layernorm, device: CPU + model.norm, device: CPU + Module: model.sam_model, device: CPU + Module: model.sam_model.patch_embed, device: CPU + model.sam_model.patch_embed.proj, device: CPU + model.sam_model.pos_embed, device: CPU + Module: model.sam_model.blocks, device: CPU + Module: model.sam_model.blocks.0, device: CPU + model.sam_model.blocks.0.norm1, device: CPU + Module: model.sam_model.blocks.0.attn, device: CPU + model.sam_model.blocks.0.attn.qkv, device: CPU + model.sam_model.blocks.0.attn.proj, device: CPU + model.sam_model.blocks.0.attn.rel_pos_h, device: CPU + model.sam_model.blocks.0.attn.rel_pos_w, device: CPU + model.sam_model.blocks.0.norm2, device: CPU + Module: model.sam_model.blocks.0.mlp, device: CPU + model.sam_model.blocks.0.mlp.lin1, device: CPU + model.sam_model.blocks.0.mlp.lin2, device: CPU + model.sam_model.blocks.0.mlp.act, device: CPU + Module: model.sam_model.blocks.1, device: CPU + model.sam_model.blocks.1.norm1, device: CPU + Module: model.sam_model.blocks.1.attn, device: CPU + model.sam_model.blocks.1.attn.qkv, device: CPU + model.sam_model.blocks.1.attn.proj, device: CPU + model.sam_model.blocks.1.attn.rel_pos_h, device: CPU + model.sam_model.blocks.1.attn.rel_pos_w, device: CPU + model.sam_model.blocks.1.norm2, device: CPU + Module: model.sam_model.blocks.1.mlp, device: CPU + model.sam_model.blocks.1.mlp.lin1, device: CPU + model.sam_model.blocks.1.mlp.lin2, device: CPU + model.sam_model.blocks.1.mlp.act, device: CPU + Module: model.sam_model.blocks.2, device: CPU + model.sam_model.blocks.2.norm1, device: CPU + Module: model.sam_model.blocks.2.attn, device: CPU + model.sam_model.blocks.2.attn.qkv, device: CPU + model.sam_model.blocks.2.attn.proj, device: CPU + model.sam_model.blocks.2.attn.rel_pos_h, device: CPU + model.sam_model.blocks.2.attn.rel_pos_w, device: CPU + model.sam_model.blocks.2.norm2, device: CPU + Module: model.sam_model.blocks.2.mlp, device: CPU + model.sam_model.blocks.2.mlp.lin1, device: CPU + model.sam_model.blocks.2.mlp.lin2, device: CPU + model.sam_model.blocks.2.mlp.act, device: CPU + Module: model.sam_model.blocks.3, device: CPU + model.sam_model.blocks.3.norm1, device: CPU + Module: model.sam_model.blocks.3.attn, device: CPU + model.sam_model.blocks.3.attn.qkv, device: CPU + model.sam_model.blocks.3.attn.proj, device: CPU + model.sam_model.blocks.3.attn.rel_pos_h, device: CPU + model.sam_model.blocks.3.attn.rel_pos_w, device: CPU + model.sam_model.blocks.3.norm2, device: CPU + Module: model.sam_model.blocks.3.mlp, device: CPU + model.sam_model.blocks.3.mlp.lin1, device: CPU + model.sam_model.blocks.3.mlp.lin2, device: CPU + model.sam_model.blocks.3.mlp.act, device: CPU + Module: model.sam_model.blocks.4, device: CPU + model.sam_model.blocks.4.norm1, device: CPU + Module: model.sam_model.blocks.4.attn, device: CPU + model.sam_model.blocks.4.attn.qkv, device: CPU + model.sam_model.blocks.4.attn.proj, device: CPU + model.sam_model.blocks.4.attn.rel_pos_h, device: CPU + model.sam_model.blocks.4.attn.rel_pos_w, device: CPU + model.sam_model.blocks.4.norm2, device: CPU + Module: model.sam_model.blocks.4.mlp, device: CPU + model.sam_model.blocks.4.mlp.lin1, device: CPU + model.sam_model.blocks.4.mlp.lin2, device: CPU + model.sam_model.blocks.4.mlp.act, device: CPU + Module: model.sam_model.blocks.5, device: CPU + model.sam_model.blocks.5.norm1, device: CPU + Module: model.sam_model.blocks.5.attn, device: CPU + model.sam_model.blocks.5.attn.qkv, device: CPU + model.sam_model.blocks.5.attn.proj, device: CPU + model.sam_model.blocks.5.attn.rel_pos_h, device: CPU + model.sam_model.blocks.5.attn.rel_pos_w, device: CPU + model.sam_model.blocks.5.norm2, device: CPU + Module: model.sam_model.blocks.5.mlp, device: CPU + model.sam_model.blocks.5.mlp.lin1, device: CPU + model.sam_model.blocks.5.mlp.lin2, device: CPU + model.sam_model.blocks.5.mlp.act, device: CPU + Module: model.sam_model.blocks.6, device: CPU + model.sam_model.blocks.6.norm1, device: CPU + Module: model.sam_model.blocks.6.attn, device: CPU + model.sam_model.blocks.6.attn.qkv, device: CPU + model.sam_model.blocks.6.attn.proj, device: CPU + model.sam_model.blocks.6.attn.rel_pos_h, device: CPU + model.sam_model.blocks.6.attn.rel_pos_w, device: CPU + model.sam_model.blocks.6.norm2, device: CPU + Module: model.sam_model.blocks.6.mlp, device: CPU + model.sam_model.blocks.6.mlp.lin1, device: CPU + model.sam_model.blocks.6.mlp.lin2, device: CPU + model.sam_model.blocks.6.mlp.act, device: CPU + Module: model.sam_model.blocks.7, device: CPU + model.sam_model.blocks.7.norm1, device: CPU + Module: model.sam_model.blocks.7.attn, device: CPU + model.sam_model.blocks.7.attn.qkv, device: CPU + model.sam_model.blocks.7.attn.proj, device: CPU + model.sam_model.blocks.7.attn.rel_pos_h, device: CPU + model.sam_model.blocks.7.attn.rel_pos_w, device: CPU + model.sam_model.blocks.7.norm2, device: CPU + Module: model.sam_model.blocks.7.mlp, device: CPU + model.sam_model.blocks.7.mlp.lin1, device: CPU + model.sam_model.blocks.7.mlp.lin2, device: CPU + model.sam_model.blocks.7.mlp.act, device: CPU + Module: model.sam_model.blocks.8, device: CPU + model.sam_model.blocks.8.norm1, device: CPU + Module: model.sam_model.blocks.8.attn, device: CPU + model.sam_model.blocks.8.attn.qkv, device: CPU + model.sam_model.blocks.8.attn.proj, device: CPU + model.sam_model.blocks.8.attn.rel_pos_h, device: CPU + model.sam_model.blocks.8.attn.rel_pos_w, device: CPU + model.sam_model.blocks.8.norm2, device: CPU + Module: model.sam_model.blocks.8.mlp, device: CPU + model.sam_model.blocks.8.mlp.lin1, device: CPU + model.sam_model.blocks.8.mlp.lin2, device: CPU + model.sam_model.blocks.8.mlp.act, device: CPU + Module: model.sam_model.blocks.9, device: CPU + model.sam_model.blocks.9.norm1, device: CPU + Module: model.sam_model.blocks.9.attn, device: CPU + model.sam_model.blocks.9.attn.qkv, device: CPU + model.sam_model.blocks.9.attn.proj, device: CPU + model.sam_model.blocks.9.attn.rel_pos_h, device: CPU + model.sam_model.blocks.9.attn.rel_pos_w, device: CPU + model.sam_model.blocks.9.norm2, device: CPU + Module: model.sam_model.blocks.9.mlp, device: CPU + model.sam_model.blocks.9.mlp.lin1, device: CPU + model.sam_model.blocks.9.mlp.lin2, device: CPU + model.sam_model.blocks.9.mlp.act, device: CPU + Module: model.sam_model.blocks.10, device: CPU + model.sam_model.blocks.10.norm1, device: CPU + Module: model.sam_model.blocks.10.attn, device: CPU + model.sam_model.blocks.10.attn.qkv, device: CPU + model.sam_model.blocks.10.attn.proj, device: CPU + model.sam_model.blocks.10.attn.rel_pos_h, device: CPU + model.sam_model.blocks.10.attn.rel_pos_w, device: CPU + model.sam_model.blocks.10.norm2, device: CPU + Module: model.sam_model.blocks.10.mlp, device: CPU + model.sam_model.blocks.10.mlp.lin1, device: CPU + model.sam_model.blocks.10.mlp.lin2, device: CPU + model.sam_model.blocks.10.mlp.act, device: CPU + Module: model.sam_model.blocks.11, device: CPU + model.sam_model.blocks.11.norm1, device: CPU + Module: model.sam_model.blocks.11.attn, device: CPU + model.sam_model.blocks.11.attn.qkv, device: CPU + model.sam_model.blocks.11.attn.proj, device: CPU + model.sam_model.blocks.11.attn.rel_pos_h, device: CPU + model.sam_model.blocks.11.attn.rel_pos_w, device: CPU + model.sam_model.blocks.11.norm2, device: CPU + Module: model.sam_model.blocks.11.mlp, device: CPU + model.sam_model.blocks.11.mlp.lin1, device: CPU + model.sam_model.blocks.11.mlp.lin2, device: CPU + model.sam_model.blocks.11.mlp.act, device: CPU + Module: model.sam_model.neck, device: CPU + model.sam_model.neck.0, device: CPU + model.sam_model.neck.1, device: CPU + model.sam_model.neck.2, device: CPU + model.sam_model.neck.3, device: CPU + model.sam_model.net_2, device: CPU + model.sam_model.net_3, device: CPU + Module: model.vision_model, device: CPU + Module: model.vision_model.embeddings, device: CPU + model.vision_model.embeddings.class_embedding, device: CPU + model.vision_model.embeddings.patch_embedding, device: CPU + model.vision_model.embeddings.position_embedding, device: CPU + Module: model.vision_model.transformer, device: CPU + Module: model.vision_model.transformer.layers, device: CPU + Module: model.vision_model.transformer.layers.0, device: CPU + Module: model.vision_model.transformer.layers.0.self_attn, device: CPU + model.vision_model.transformer.layers.0.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.0.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.0.mlp, device: CPU + model.vision_model.transformer.layers.0.mlp.fc1, device: CPU + model.vision_model.transformer.layers.0.mlp.fc2, device: CPU + model.vision_model.transformer.layers.0.mlp.act, device: CPU + model.vision_model.transformer.layers.0.layer_norm1, device: CPU + model.vision_model.transformer.layers.0.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.1, device: CPU + Module: model.vision_model.transformer.layers.1.self_attn, device: CPU + model.vision_model.transformer.layers.1.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.1.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.1.mlp, device: CPU + model.vision_model.transformer.layers.1.mlp.fc1, device: CPU + model.vision_model.transformer.layers.1.mlp.fc2, device: CPU + model.vision_model.transformer.layers.1.mlp.act, device: CPU + model.vision_model.transformer.layers.1.layer_norm1, device: CPU + model.vision_model.transformer.layers.1.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.2, device: CPU + Module: model.vision_model.transformer.layers.2.self_attn, device: CPU + model.vision_model.transformer.layers.2.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.2.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.2.mlp, device: CPU + model.vision_model.transformer.layers.2.mlp.fc1, device: CPU + model.vision_model.transformer.layers.2.mlp.fc2, device: CPU + model.vision_model.transformer.layers.2.mlp.act, device: CPU + model.vision_model.transformer.layers.2.layer_norm1, device: CPU + model.vision_model.transformer.layers.2.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.3, device: CPU + Module: model.vision_model.transformer.layers.3.self_attn, device: CPU + model.vision_model.transformer.layers.3.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.3.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.3.mlp, device: CPU + model.vision_model.transformer.layers.3.mlp.fc1, device: CPU + model.vision_model.transformer.layers.3.mlp.fc2, device: CPU + model.vision_model.transformer.layers.3.mlp.act, device: CPU + model.vision_model.transformer.layers.3.layer_norm1, device: CPU + model.vision_model.transformer.layers.3.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.4, device: CPU + Module: model.vision_model.transformer.layers.4.self_attn, device: CPU + model.vision_model.transformer.layers.4.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.4.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.4.mlp, device: CPU + model.vision_model.transformer.layers.4.mlp.fc1, device: CPU + model.vision_model.transformer.layers.4.mlp.fc2, device: CPU + model.vision_model.transformer.layers.4.mlp.act, device: CPU + model.vision_model.transformer.layers.4.layer_norm1, device: CPU + model.vision_model.transformer.layers.4.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.5, device: CPU + Module: model.vision_model.transformer.layers.5.self_attn, device: CPU + model.vision_model.transformer.layers.5.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.5.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.5.mlp, device: CPU + model.vision_model.transformer.layers.5.mlp.fc1, device: CPU + model.vision_model.transformer.layers.5.mlp.fc2, device: CPU + model.vision_model.transformer.layers.5.mlp.act, device: CPU + model.vision_model.transformer.layers.5.layer_norm1, device: CPU + model.vision_model.transformer.layers.5.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.6, device: CPU + Module: model.vision_model.transformer.layers.6.self_attn, device: CPU + model.vision_model.transformer.layers.6.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.6.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.6.mlp, device: CPU + model.vision_model.transformer.layers.6.mlp.fc1, device: CPU + model.vision_model.transformer.layers.6.mlp.fc2, device: CPU + model.vision_model.transformer.layers.6.mlp.act, device: CPU + model.vision_model.transformer.layers.6.layer_norm1, device: CPU + model.vision_model.transformer.layers.6.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.7, device: CPU + Module: model.vision_model.transformer.layers.7.self_attn, device: CPU + model.vision_model.transformer.layers.7.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.7.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.7.mlp, device: CPU + model.vision_model.transformer.layers.7.mlp.fc1, device: CPU + model.vision_model.transformer.layers.7.mlp.fc2, device: CPU + model.vision_model.transformer.layers.7.mlp.act, device: CPU + model.vision_model.transformer.layers.7.layer_norm1, device: CPU + model.vision_model.transformer.layers.7.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.8, device: CPU + Module: model.vision_model.transformer.layers.8.self_attn, device: CPU + model.vision_model.transformer.layers.8.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.8.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.8.mlp, device: CPU + model.vision_model.transformer.layers.8.mlp.fc1, device: CPU + model.vision_model.transformer.layers.8.mlp.fc2, device: CPU + model.vision_model.transformer.layers.8.mlp.act, device: CPU + model.vision_model.transformer.layers.8.layer_norm1, device: CPU + model.vision_model.transformer.layers.8.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.9, device: CPU + Module: model.vision_model.transformer.layers.9.self_attn, device: CPU + model.vision_model.transformer.layers.9.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.9.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.9.mlp, device: CPU + model.vision_model.transformer.layers.9.mlp.fc1, device: CPU + model.vision_model.transformer.layers.9.mlp.fc2, device: CPU + model.vision_model.transformer.layers.9.mlp.act, device: CPU + model.vision_model.transformer.layers.9.layer_norm1, device: CPU + model.vision_model.transformer.layers.9.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.10, device: CPU + Module: model.vision_model.transformer.layers.10.self_attn, device: CPU + model.vision_model.transformer.layers.10.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.10.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.10.mlp, device: CPU + model.vision_model.transformer.layers.10.mlp.fc1, device: CPU + model.vision_model.transformer.layers.10.mlp.fc2, device: CPU + model.vision_model.transformer.layers.10.mlp.act, device: CPU + model.vision_model.transformer.layers.10.layer_norm1, device: CPU + model.vision_model.transformer.layers.10.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.11, device: CPU + Module: model.vision_model.transformer.layers.11.self_attn, device: CPU + model.vision_model.transformer.layers.11.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.11.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.11.mlp, device: CPU + model.vision_model.transformer.layers.11.mlp.fc1, device: CPU + model.vision_model.transformer.layers.11.mlp.fc2, device: CPU + model.vision_model.transformer.layers.11.mlp.act, device: CPU + model.vision_model.transformer.layers.11.layer_norm1, device: CPU + model.vision_model.transformer.layers.11.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.12, device: CPU + Module: model.vision_model.transformer.layers.12.self_attn, device: CPU + model.vision_model.transformer.layers.12.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.12.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.12.mlp, device: CPU + model.vision_model.transformer.layers.12.mlp.fc1, device: CPU + model.vision_model.transformer.layers.12.mlp.fc2, device: CPU + model.vision_model.transformer.layers.12.mlp.act, device: CPU + model.vision_model.transformer.layers.12.layer_norm1, device: CPU + model.vision_model.transformer.layers.12.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.13, device: CPU + Module: model.vision_model.transformer.layers.13.self_attn, device: CPU + model.vision_model.transformer.layers.13.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.13.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.13.mlp, device: CPU + model.vision_model.transformer.layers.13.mlp.fc1, device: CPU + model.vision_model.transformer.layers.13.mlp.fc2, device: CPU + model.vision_model.transformer.layers.13.mlp.act, device: CPU + model.vision_model.transformer.layers.13.layer_norm1, device: CPU + model.vision_model.transformer.layers.13.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.14, device: CPU + Module: model.vision_model.transformer.layers.14.self_attn, device: CPU + model.vision_model.transformer.layers.14.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.14.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.14.mlp, device: CPU + model.vision_model.transformer.layers.14.mlp.fc1, device: CPU + model.vision_model.transformer.layers.14.mlp.fc2, device: CPU + model.vision_model.transformer.layers.14.mlp.act, device: CPU + model.vision_model.transformer.layers.14.layer_norm1, device: CPU + model.vision_model.transformer.layers.14.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.15, device: CPU + Module: model.vision_model.transformer.layers.15.self_attn, device: CPU + model.vision_model.transformer.layers.15.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.15.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.15.mlp, device: CPU + model.vision_model.transformer.layers.15.mlp.fc1, device: CPU + model.vision_model.transformer.layers.15.mlp.fc2, device: CPU + model.vision_model.transformer.layers.15.mlp.act, device: CPU + model.vision_model.transformer.layers.15.layer_norm1, device: CPU + model.vision_model.transformer.layers.15.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.16, device: CPU + Module: model.vision_model.transformer.layers.16.self_attn, device: CPU + model.vision_model.transformer.layers.16.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.16.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.16.mlp, device: CPU + model.vision_model.transformer.layers.16.mlp.fc1, device: CPU + model.vision_model.transformer.layers.16.mlp.fc2, device: CPU + model.vision_model.transformer.layers.16.mlp.act, device: CPU + model.vision_model.transformer.layers.16.layer_norm1, device: CPU + model.vision_model.transformer.layers.16.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.17, device: CPU + Module: model.vision_model.transformer.layers.17.self_attn, device: CPU + model.vision_model.transformer.layers.17.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.17.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.17.mlp, device: CPU + model.vision_model.transformer.layers.17.mlp.fc1, device: CPU + model.vision_model.transformer.layers.17.mlp.fc2, device: CPU + model.vision_model.transformer.layers.17.mlp.act, device: CPU + model.vision_model.transformer.layers.17.layer_norm1, device: CPU + model.vision_model.transformer.layers.17.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.18, device: CPU + Module: model.vision_model.transformer.layers.18.self_attn, device: CPU + model.vision_model.transformer.layers.18.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.18.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.18.mlp, device: CPU + model.vision_model.transformer.layers.18.mlp.fc1, device: CPU + model.vision_model.transformer.layers.18.mlp.fc2, device: CPU + model.vision_model.transformer.layers.18.mlp.act, device: CPU + model.vision_model.transformer.layers.18.layer_norm1, device: CPU + model.vision_model.transformer.layers.18.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.19, device: CPU + Module: model.vision_model.transformer.layers.19.self_attn, device: CPU + model.vision_model.transformer.layers.19.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.19.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.19.mlp, device: CPU + model.vision_model.transformer.layers.19.mlp.fc1, device: CPU + model.vision_model.transformer.layers.19.mlp.fc2, device: CPU + model.vision_model.transformer.layers.19.mlp.act, device: CPU + model.vision_model.transformer.layers.19.layer_norm1, device: CPU + model.vision_model.transformer.layers.19.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.20, device: CPU + Module: model.vision_model.transformer.layers.20.self_attn, device: CPU + model.vision_model.transformer.layers.20.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.20.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.20.mlp, device: CPU + model.vision_model.transformer.layers.20.mlp.fc1, device: CPU + model.vision_model.transformer.layers.20.mlp.fc2, device: CPU + model.vision_model.transformer.layers.20.mlp.act, device: CPU + model.vision_model.transformer.layers.20.layer_norm1, device: CPU + model.vision_model.transformer.layers.20.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.21, device: CPU + Module: model.vision_model.transformer.layers.21.self_attn, device: CPU + model.vision_model.transformer.layers.21.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.21.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.21.mlp, device: CPU + model.vision_model.transformer.layers.21.mlp.fc1, device: CPU + model.vision_model.transformer.layers.21.mlp.fc2, device: CPU + model.vision_model.transformer.layers.21.mlp.act, device: CPU + model.vision_model.transformer.layers.21.layer_norm1, device: CPU + model.vision_model.transformer.layers.21.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.22, device: CPU + Module: model.vision_model.transformer.layers.22.self_attn, device: CPU + model.vision_model.transformer.layers.22.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.22.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.22.mlp, device: CPU + model.vision_model.transformer.layers.22.mlp.fc1, device: CPU + model.vision_model.transformer.layers.22.mlp.fc2, device: CPU + model.vision_model.transformer.layers.22.mlp.act, device: CPU + model.vision_model.transformer.layers.22.layer_norm1, device: CPU + model.vision_model.transformer.layers.22.layer_norm2, device: CPU + Module: model.vision_model.transformer.layers.23, device: CPU + Module: model.vision_model.transformer.layers.23.self_attn, device: CPU + model.vision_model.transformer.layers.23.self_attn.qkv_proj, device: CPU + model.vision_model.transformer.layers.23.self_attn.out_proj, device: CPU + Module: model.vision_model.transformer.layers.23.mlp, device: CPU + model.vision_model.transformer.layers.23.mlp.fc1, device: CPU + model.vision_model.transformer.layers.23.mlp.fc2, device: CPU + model.vision_model.transformer.layers.23.mlp.act, device: CPU + model.vision_model.transformer.layers.23.layer_norm1, device: CPU + model.vision_model.transformer.layers.23.layer_norm2, device: CPU + model.vision_model.pre_layrnorm, device: CPU + Module: model.projector, device: CPU + model.projector.layers, device: CPU + model.image_newline, device: CPU + model.view_seperator, device: CPU + lm_head, device: CPU + diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index 396a0f5b9..508c381eb 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -5,6 +5,7 @@ #include "mllm/backends/cpu/CPUAllocator.hpp" // Ops +#include "mllm/backends/cpu/ops/ArgsortOp.hpp" #include "mllm/backends/cpu/ops/CastTypeOp.hpp" #include "mllm/backends/cpu/ops/CausalMaskOp.hpp" #include "mllm/backends/cpu/ops/ConcatOp.hpp" @@ -56,18 +57,19 @@ namespace mllm::cpu { CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { - regOpFactory< - CPULinearOpFactory, CPUFillOpFactory, CPUGraphBeginOpFactory, CPUGraphEndOpFactory, CPUAddOpFactory, CPUSubOpFactory, - CPUMulOpFactory, CPUDivOpFactory, CPUNegOpFactory, CPUAbsOpFactory, CPULogOpFactory, CPUExpOpFactory, CPUSinOpFactory, - CPUCosOpFactory, CPUReduceMaxOpFactory, CPUReduceMinOpFactory, CPUReduceSumOpFactory, CPUTransposeOpFactory, - CPUPermuteOpFactory, CPUCastTypeOpFactory, CPUConcatOpFactory, CPUStackOpFactory, CPUContiguousOpFactory, - CPUCopyOpFactory, CPUEmbeddingOpFactory, CPUSplitOpFactory, CPUViewOpFactory, CPULayerNormOpFactory, CPURepeatOpFactory, - CPUX2XOpFactory, CPUSoftmaxOpFactory, CPUSiLUOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, - CPUReLUOpFactory, CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, - CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, - CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, - CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, - CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory>(); + regOpFactory(); } std::shared_ptr createCPUBackend() { return std::make_shared(); } diff --git a/mllm/backends/cpu/ops/ArgsortOp.cpp b/mllm/backends/cpu/ops/ArgsortOp.cpp new file mode 100644 index 000000000..646499c58 --- /dev/null +++ b/mllm/backends/cpu/ops/ArgsortOp.cpp @@ -0,0 +1,100 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include + +#include "mllm/backends/cpu/ops/ArgsortOp.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/UnsafeMacros.hpp" + +namespace mllm::cpu { + +CPUArgsortOp::CPUArgsortOp(const aops::ArgsortOpOptions& options) : aops::ArgsortOp(options) {} + +namespace MLLM_ANONYMOUS_NAMESPACE { + +template +void argsort_impl(const T* input_data, int32_t* indices_data, int outer_size, int axis_size, int inner_size, int dim, + bool descending) { + for (int out = 0; out < outer_size; ++out) { + for (int in = 0; in < inner_size; ++in) { + // Create pairs of (value, index) for sorting + std::vector> data_pairs(axis_size); + for (int i = 0; i < axis_size; ++i) { + int index = out * axis_size * inner_size + i * inner_size + in; + data_pairs[i] = {input_data[index], static_cast(i)}; + } + + // Sort based on values + if (descending) { + std::sort(data_pairs.begin(), data_pairs.end(), + [](const std::pair& a, const std::pair& b) { return a.first > b.first; }); + } else { + std::sort(data_pairs.begin(), data_pairs.end(), + [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); + } + + // Store sorted indices + for (int i = 0; i < axis_size; ++i) { + int out_index = out * axis_size * inner_size + i * inner_size + in; + indices_data[out_index] = data_pairs[i].second; + } + } + } +} +} // namespace MLLM_ANONYMOUS_NAMESPACE + +__MLLM_UNSAFE_OPT_BEGIN_O3_FAST_MATH +void CPUArgsortOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto& indices = outputs[0]; + + auto dtype = input.dtype(); + int dim = options_.dim; + bool descending = options_.descending; + + // Handle negative dimension index + if (dim < 0) { dim += input.shape().size(); } + + // Calculate sizes + int outer_size = 1; + int inner_size = 1; + int axis_size = input.shape()[dim]; + + for (int i = 0; i < dim; ++i) { outer_size *= input.shape()[i]; } + for (int i = dim + 1; i < input.shape().size(); ++i) { inner_size *= input.shape()[i]; } + + switch (dtype) { + case kFloat32: { + argsort_impl(input.ptr(), indices.ptr(), outer_size, axis_size, inner_size, dim, descending); + break; + } + case kFloat16: { + argsort_impl(input.ptr(), indices.ptr(), outer_size, axis_size, inner_size, dim, + descending); + break; + } + case kInt32: { + argsort_impl(input.ptr(), indices.ptr(), outer_size, axis_size, inner_size, dim, + descending); + break; + } + case kInt16: { + argsort_impl(input.ptr(), indices.ptr(), outer_size, axis_size, inner_size, dim, + descending); + break; + } + case kInt8: { + argsort_impl(input.ptr(), indices.ptr(), outer_size, axis_size, inner_size, dim, + descending); + break; + } + default: NYI("Unsupported data type for ArgsortOp"); + } +} +__MLLM_UNSAFE_OPT_END + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/ArgsortOp.hpp b/mllm/backends/cpu/ops/ArgsortOp.hpp new file mode 100644 index 000000000..3649b8bce --- /dev/null +++ b/mllm/backends/cpu/ops/ArgsortOp.hpp @@ -0,0 +1,24 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/aops/ArgsortOp.hpp" + +namespace mllm::cpu { + +class CPUArgsortOp final : public aops::ArgsortOp { + public: + explicit CPUArgsortOp(const aops::ArgsortOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUArgsortOpFactory final : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::ArgsortOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu \ No newline at end of file diff --git a/mllm/backends/cpu/ops/Conv2DOp.cpp b/mllm/backends/cpu/ops/Conv2DOp.cpp index ae07b2c1d..0c1c58789 100644 --- a/mllm/backends/cpu/ops/Conv2DOp.cpp +++ b/mllm/backends/cpu/ops/Conv2DOp.cpp @@ -75,6 +75,12 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector auto& padding = options_.padding; auto& dilation = options_.dilation; + MLLM_RT_ASSERT_EQ(input.rank(), 4); + auto batch_size = input.size(0); + auto _1 = input.size(1); + auto _2 = input.size(2); + auto _3 = input.size(3); + switch (input.dtype()) { case kFloat32: { switch (options_.impl_type) { @@ -95,60 +101,65 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector int MATMUL_K = options_.in_channels * kernel_size[0] * kernel_size[1]; int MATMUL_N = output.shape()[2] * output.shape()[3]; -#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // step 1. im2col inputs to tmp auto packed_inputs = Tensor::empty({MATMUL_K, MATMUL_N}, input.dtype(), input.device()).alloc(); - arm::conv2d_fp32_im2col_input(input.ptr(), options_.in_channels, input.shape()[2], input.shape()[3], - kernel_size[0], kernel_size[1], padding[0], padding[1], stride[0], stride[1], - dilation[0], dilation[1], packed_inputs.ptr()); - // step 2. Do matmul - switch (mt) { // NOLINT - case aops::MatMulOpType::kBLAS: { + + for (int _b_idx = 0; _b_idx < batch_size; ++_b_idx) { +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + + arm::conv2d_fp32_im2col_input(input.ptr() + _b_idx * (_1 * _2 * _3), options_.in_channels, + input.shape()[2], input.shape()[3], kernel_size[0], kernel_size[1], padding[0], + padding[1], stride[0], stride[1], dilation[0], dilation[1], + packed_inputs.ptr()); + // step 2. Do matmul + switch (mt) { // NOLINT + case aops::MatMulOpType::kBLAS: { #if defined(MLLM_USE_BLAS) - blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), output.ptr(), - nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); - - // Add Bias - if (options_.bias) { - auto out_ptr = output.ptr(); - const auto bias_ptr = bias_.ptr(); - for (int m = 0; m < MATMUL_M; ++m) { - const float b = bias_ptr[m]; - for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), output.ptr(), + nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); + + // Add Bias + if (options_.bias) { + auto out_ptr = output.ptr(); + const auto bias_ptr = bias_.ptr(); + for (int m = 0; m < MATMUL_M; ++m) { + const float b = bias_ptr[m]; + for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + } } - } #else - NYI("BLAS not supported. Pls set MLLM_USE_BLAS=ON to enable BLAS supports in cmake."); + NYI("BLAS not supported. Pls set MLLM_USE_BLAS=ON to enable BLAS supports in cmake."); #endif - break; - } - case aops::MatMulOpType::kMllmBlas: { - auto thread_count = options_.getThreads(); + break; + } + case aops::MatMulOpType::kMllmBlas: { + auto thread_count = options_.getThreads(); #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - arm::mllm_blas_matmul_fp32(MATMUL_M, MATMUL_K, MATMUL_N, output.ptr(), weight_.ptr(), - packed_inputs.ptr(), nullptr, false, false, thread_count); - // Add Bias - if (options_.bias) { - auto out_ptr = output.ptr(); - const auto bias_ptr = bias_.ptr(); - for (int m = 0; m < MATMUL_M; ++m) { - const float b = bias_ptr[m]; - for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + arm::mllm_blas_matmul_fp32(MATMUL_M, MATMUL_K, MATMUL_N, output.ptr(), weight_.ptr(), + packed_inputs.ptr(), nullptr, false, false, thread_count); + // Add Bias + if (options_.bias) { + auto out_ptr = output.ptr(); + const auto bias_ptr = bias_.ptr(); + for (int m = 0; m < MATMUL_M; ++m) { + const float b = bias_ptr[m]; + for (int n = 0; n < MATMUL_N; ++n) { out_ptr[m * MATMUL_N + n] += b; } + } } - } #else - NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") + NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") #endif - break; + break; + } + default: { + NYI("Unsupported matmul type"); + } } - } #else - MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported architecture for perform im2col conv2d."); + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported architecture for perform im2col conv2d."); #endif - break; - } - default: { - NYI("Unsupported impl type"); + break; + } } } break; diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 8af92f0bc..f097306f3 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-26 17:03:51 +// Auto generated: 2025-10-27 11:05:42 // do not modify this file #pragma once @@ -77,6 +77,9 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_EinsumOp, RK_Op_LinalgIROp_StackOp, RK_Op_LinalgIROp_MaskedScatterOp, + RK_Op_LinalgIROp_ScatterOp, + RK_Op_LinalgIROp_GatherOp, + RK_Op_LinalgIROp_ArgsortOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, RK_Op_GraphIROp_SubGraphOp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 7fb5d1b55..cd750459b 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-10-26 17:03:51 +// Auto generated: 2025-10-27 11:05:42 // do not modify this file #pragma once namespace mllm::ir { @@ -201,6 +201,15 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_MASKEDSCATTEROP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_MaskedScatterOp && (v)->getKind() <= RK_Op_LinalgIROp_MaskedScatterOp +#define RTTI_RK_OP_LINALGIROP_SCATTEROP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_ScatterOp && (v)->getKind() <= RK_Op_LinalgIROp_ScatterOp + +#define RTTI_RK_OP_LINALGIROP_GATHEROP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_GatherOp && (v)->getKind() <= RK_Op_LinalgIROp_GatherOp + +#define RTTI_RK_OP_LINALGIROP_ARGSORTOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_ArgsortOp && (v)->getKind() <= RK_Op_LinalgIROp_ArgsortOp + #define RTTI_RK_OP_GRAPHIROP_IMPL(v) return (v)->getKind() >= RK_Op_GraphIROp && (v)->getKind() <= RK_Op_GraphIROp_Last #define RTTI_RK_OP_GRAPHIROP_SUBGRAPHOP_IMPL(v) \ diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index 4a6a1f8cf..68fa7b02b 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -108,4 +108,8 @@ LINALG_AOPS_DECL(OpTypes::kEinsum, EinsumOp); LINALG_AOPS_DECL(OpTypes::kStack, StackOp); LINALG_AOPS_DECL(OpTypes::kMaskedScatter, MaskedScatterOp); +LINALG_AOPS_DECL(OpTypes::kScatter, ScatterOp); +LINALG_AOPS_DECL(OpTypes::kGather, GatherOp); +LINALG_AOPS_DECL(OpTypes::kArgsort, ArgsortOp); + } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index 265d919fc..6ea04d9aa 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -70,6 +70,9 @@ class InterpolateOp; class EinsumOp; class StackOp; class MaskedScatterOp; +class ScatterOp; +class GatherOp; +class ArgsortOp; } // namespace mllm #define LINALG_AOPS_DEFINE(class_name, rtti_name) \ @@ -225,5 +228,8 @@ LINALG_AOPS_DEFINE(InterpolateOp, INTERPOLATEOP); LINALG_AOPS_DEFINE(EinsumOp, EINSUMOP); LINALG_AOPS_DEFINE(StackOp, STACKOP); LINALG_AOPS_DEFINE(MaskedScatterOp, MASKEDSCATTEROP); +LINALG_AOPS_DEFINE(ScatterOp, SCATTEROP); +LINALG_AOPS_DEFINE(GatherOp, GATHEROP); +LINALG_AOPS_DEFINE(ArgsortOp, ARGSORTOP); } // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index ae4629ad6..d81d1e8cb 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -278,6 +278,9 @@ def define_lianlg_ir(ir: dict): op.derive(Cls("EinsumOp")) op.derive(Cls("StackOp")) op.derive(Cls("MaskedScatterOp")) + op.derive(Cls("ScatterOp")) + op.derive(Cls("GatherOp")) + op.derive(Cls("ArgsortOp")) # value diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index dfac01e04..d072055bb 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -83,6 +83,9 @@ enum class OpTypes : int32_t { kEinsum = 63, kStack = 64, kMaskedScatter = 65, + kScatter = 66, + kGather = 67, + kArgsort = 68, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -158,6 +161,9 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kStack: return "Stack"; case OpTypes::kEinsum: return "Einsum"; case OpTypes::kMaskedScatter: return "MaskedScatter"; + case OpTypes::kScatter: return "Scatter"; + case OpTypes::kGather: return "Gather"; + case OpTypes::kArgsort: return "Argsort"; case OpTypes::kDynamicOp_Start: return "DynamicOp_Start"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index 6aed89153..3da641f72 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -21,6 +21,7 @@ #include "mllm/core/aops/TransposeOp.hpp" #include "mllm/core/aops/ViewOp.hpp" #include "mllm/core/aops/X2XOp.hpp" +#include "mllm/core/aops/ArgsortOp.hpp" #include "mllm/engine/Context.hpp" namespace mllm { @@ -70,6 +71,11 @@ Tensor Tensor::empty(const std::vector& shape, DataTypes dtype, DeviceT return Tensor(impl); } +Tensor Tensor::emptyLike(const Tensor& liked_tensor) { + auto ret = Tensor::empty(liked_tensor.shape(), liked_tensor.dtype(), liked_tensor.device()); + return ret; +} + Tensor& Tensor::allocExtraTensorView(const std::string& extra_tensor_name, const std::vector& shape, DataTypes dtype, DeviceTypes device) { MLLM_RT_ASSERT_EQ(attached_views_.count(extra_tensor_name), 0); @@ -129,6 +135,12 @@ Tensor Tensor::operator/(const Tensor& rhs) { return Context::instance().buildOpAndSubmitTask(OpTypes::kDiv, aops::DivOpOptions{}, {*this, rhs})[0]; } +Tensor Tensor::mul_(const Tensor& rhs) { + auto opts = aops::MulOpOptions{}; + opts.setInplace(true); + return Context::instance().buildOpAndSubmitTask(OpTypes::kMul, opts, {*this, rhs})[0]; +} + Tensor Tensor::operator+(float rhs) { auto rhs_tensor = Tensor::empty({1}, dtype(), device()).alloc(); switch (dtype()) { @@ -243,6 +255,11 @@ Tensor Tensor::operator/(std::complex rhs) { Tensor Tensor::abs() { return Context::instance().buildOpAndSubmitTask(OpTypes::kAbs, aops::AbsOpOptions{}, {*this})[0]; } +Tensor Tensor::argsort(int dim, bool descending) { + return Context::instance().buildOpAndSubmitTask(OpTypes::kArgsort, + aops::ArgsortOpOptions{.dim = dim, .descending = descending}, {*this})[0]; +} + Tensor Tensor::clip(float min_val, float max_val) { return Context::instance().buildOpAndSubmitTask(OpTypes::kClip, aops::ClipOpOptions{.min_val = min_val, .max_val = max_val}, {*this})[0]; @@ -370,6 +387,7 @@ Tensor Tensor::repeat(int32_t multiplier, int32_t dim) { } Tensor Tensor::unsqueeze(int32_t dim) { + if (dim < 0) { dim = static_cast(rank()) + dim; } auto this_shape = shape(); this_shape.insert(this_shape.begin() + dim, 1); return view(this_shape); diff --git a/mllm/core/Tensor.hpp b/mllm/core/Tensor.hpp index 0621916a1..7f04b211a 100644 --- a/mllm/core/Tensor.hpp +++ b/mllm/core/Tensor.hpp @@ -175,6 +175,14 @@ class Tensor { */ static Tensor empty(const std::vector& shape, DataTypes dtype = kFloat32, DeviceTypes device = kCPU); + /** + * @brief Creates an uninitialized tensor with the same shape and attributes as another tensor. + * + * @param liked_tensor + * @return Tensor + */ + static Tensor emptyLike(const Tensor& liked_tensor); + /** * @brief If this tensor is not initialized * @@ -267,6 +275,8 @@ class Tensor { Tensor operator/(const Tensor& rhs); /// @} + Tensor mul_(const Tensor& rhs); + /// @name Scalar Operations /// Element-wise operations with scalar values. /// @{ @@ -296,6 +306,15 @@ class Tensor { */ Tensor abs(); + /** + * @brief Argsort + * + * @param dim + * @param descending + * @return Tensor + */ + Tensor argsort(int dim = -1, bool descending = false); + /** * @brief Clips (limits) the values in a tensor. * @param min_val Minimum value diff --git a/mllm/core/aops/ArgSortOp.cpp b/mllm/core/aops/ArgSortOp.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/mllm/core/aops/ArgSortOp.hpp b/mllm/core/aops/ArgSortOp.hpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/mllm/core/aops/ArgsortOp.cpp b/mllm/core/aops/ArgsortOp.cpp new file mode 100644 index 000000000..35e2f4d61 --- /dev/null +++ b/mllm/core/aops/ArgsortOp.cpp @@ -0,0 +1,40 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/ArgsortOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +ArgsortOp::ArgsortOp(const ArgsortOpOptions& options) : BaseOp(OpTypes::kArgsort), options_(options) {} + +void ArgsortOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void ArgsortOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void ArgsortOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("ArgsortOp::forward not implemented in aops base."); +} + +void ArgsortOp::reshape(const std::vector& inputs, std::vector& outputs) { + // Define output tensor shapes based on input shapes + auto& input = inputs[0]; + + if (!input.isNil()) { + auto input_shape = input.shape(); + // Output indices tensor has the same shape as input + outputs.emplace_back(Tensor::empty(input_shape, kInt32, input.device())); + } +} + +void ArgsortOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/ArgsortOp.hpp b/mllm/core/aops/ArgsortOp.hpp new file mode 100644 index 000000000..4db898736 --- /dev/null +++ b/mllm/core/aops/ArgsortOp.hpp @@ -0,0 +1,36 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct ArgsortOpOptions : public BaseOpOptions { + int dim = -1; + bool descending = false; +}; + +class ArgsortOp : public BaseOp { + public: + explicit ArgsortOp(const ArgsortOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline ArgsortOpOptions& options() { return options_; } + + protected: + ArgsortOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/Conv2DOp.cpp b/mllm/core/aops/Conv2DOp.cpp index 46628b137..4b20fea8b 100644 --- a/mllm/core/aops/Conv2DOp.cpp +++ b/mllm/core/aops/Conv2DOp.cpp @@ -76,9 +76,6 @@ void Conv2DOp::reshape(const std::vector& inputs, std::vector& o const int in_height = ishape[2]; // height axis const int in_width = ishape[3]; // width axis - // Current only support single batch - MLLM_RT_ASSERT_EQ(batch, 1); - MLLM_RT_ASSERT_EQ(in_channels, options_.in_channels); // Retrieve convolution parameters from options_ diff --git a/mllm/core/aops/ElewiseOps.cpp b/mllm/core/aops/ElewiseOps.cpp index 06319f589..55e00189d 100644 --- a/mllm/core/aops/ElewiseOps.cpp +++ b/mllm/core/aops/ElewiseOps.cpp @@ -47,6 +47,10 @@ static std::vector broadcastShapes(const std::vector>& sha MLLM_WARN(#name "::forward is not implemented"); \ } \ void name::reshape(const std::vector& inputs, std::vector& outputs) { \ + if (options_.isInplace()) { \ + outputs.emplace_back(inputs[0]); \ + return; \ + } \ std::vector> input_shapes; \ input_shapes.reserve(inputs.size()); \ for (const auto& input : inputs) { input_shapes.push_back(input.shape()); } \ @@ -58,7 +62,10 @@ static std::vector broadcastShapes(const std::vector>& sha } \ outputs.emplace_back(output_0); \ } \ - void name::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + void name::setup(const std::vector& inputs, std::vector& outputs) { \ + if (options_.isInplace()) { return; } \ + BaseOp::setup(inputs, outputs); \ + } // for unary ops, reshape don't consider shape broadcast, and dtype of output needs to be handled for Abs op #define __MLLM_ELEWISE_UNARY_OP_IMPL(types, name) \ diff --git a/mllm/core/aops/PadOp.cpp b/mllm/core/aops/PadOp.cpp index 923c5b905..b6d05e323 100644 --- a/mllm/core/aops/PadOp.cpp +++ b/mllm/core/aops/PadOp.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "mllm/core/aops/PadOp.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" namespace mllm::aops { @@ -10,7 +11,10 @@ PadOp::PadOp(const PadOpOptions& options) : BaseOp(OpTypes::kPad), options_(opti void PadOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } void PadOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { - // TODO + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); } void PadOp::forward(const std::vector& inputs, std::vector& outputs) { @@ -18,11 +22,23 @@ void PadOp::forward(const std::vector& inputs, std::vector& outp } void PadOp::reshape(const std::vector& inputs, std::vector& outputs) { - // TODO -} + if (inputs.empty()) { throw std::invalid_argument("PadOp::reshape: inputs empty"); } + + const auto& i = inputs[0]; + const int in_dims = i.shape().size(); + const auto& pad = options_.pad; // [dim_last_low, dim_last_high, ...] -void PadOp::setup(const std::vector& inputs, std::vector& outputs) { - // TODO + std::vector out_shape(i.shape().begin(), i.shape().end()); + for (int d = 0; d < in_dims; ++d) { + int idx = (in_dims - 1 - d) * 2; + int32_t low = (idx < pad.size()) ? pad[idx] : 0; + int32_t high = (idx + 1 < pad.size()) ? pad[idx + 1] : 0; + out_shape[d] += low + high; + } + + outputs.emplace_back(Tensor::empty(out_shape, i.dtype(), i.device())); } +void PadOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + } // namespace mllm::aops diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp index eb007d988..883be51a1 100644 --- a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -180,6 +180,7 @@ struct DpskOcrConfig : protected ConfigFile { int32_t vocab_size = 129280; // MLLM Related Stuff + int32_t max_cache_length = 2048; aops::LinearImplTypes clip_linear_impl_type = aops::LinearImplTypes::kDefault; aops::LinearImplTypes sam_linear_impl_type = aops::LinearImplTypes::kDefault; aops::LinearImplTypes mlp_projector_linear_impl_type = aops::LinearImplTypes::kDefault; diff --git a/mllm/models/deepseek_ocr/conversation_preprocess.hpp b/mllm/models/deepseek_ocr/conversation_preprocess.hpp index 7191ac301..df54e4f7b 100644 --- a/mllm/models/deepseek_ocr/conversation_preprocess.hpp +++ b/mllm/models/deepseek_ocr/conversation_preprocess.hpp @@ -2,19 +2,21 @@ // Licensed under the MIT License. #pragma once +#include +#include #include #include -#include #include -#include +#include #include #include #include -#include -#include -#include -#include +#include + #include +#include + +#include #include "mllm/preprocessor/visual/Image.hpp" diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index ba42d55e3..16f7cd27a 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -506,7 +506,7 @@ class Attention final : public nn::Module { auto rel_w = Tensor::nil(); if (use_rel_pos_) { - std::tie(rel_h, rel_w) = addDecomposedRelPos(q, rel_pos_h_.weight(), rel_pos_w_.weight(), {H, W}, {W, H}); + std::tie(rel_h, rel_w) = addDecomposedRelPos(q, rel_pos_h_.weight(), rel_pos_w_.weight(), {H, W}, {H, W}); } q = q.view({B, num_heads_, H * W, -1}); @@ -586,6 +586,8 @@ class Block final : public nn::Module { auto shortcut = x; x = norm1_(x); + print(x.shape(), window_size_); + // Window partition int H = 0; int W = 0; @@ -596,6 +598,9 @@ class Block final : public nn::Module { std::tie(x, pad_hw) = windowPartition(x, window_size_); } + print(x.shape(), H, W, pad_hw); + exit(0); + x = attn_(x)[0]; // Reverse window partition @@ -617,7 +622,7 @@ class Blocks final : public nn::Module { Blocks(const std::string& name, int nums, const std::vector& global_attn_indexes, const DpskOcrConfig& config) : nn::Module(name) { for (int i = 0; i < nums; ++i) { - bool is_in = std::find(global_attn_indexes.begin(), global_attn_indexes.end(), i) != global_attn_indexes.end(); + bool is_in = std::find(global_attn_indexes.begin(), global_attn_indexes.end(), i) == global_attn_indexes.end(); auto this_block_window_size = is_in ? 14 : 0; blocks_.emplace_back(reg(std::to_string(i), 768, 12, 4.0, true, true, this_block_window_size, std::make_optional(std::make_tuple(1024 / 16, 1024 / 16)), config)); diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index a58def5c0..0a2c33ce5 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -11,6 +11,7 @@ #include "mllm/nn/Nn.hpp" #include "mllm/utils/StringHelper.hpp" #include "mllm/models/ARGeneration.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" #include "mllm/preprocessor/visual/ImageTransform.hpp" #include "mllm/models/deepseek_ocr/deepencoder.hpp" @@ -20,6 +21,62 @@ namespace mllm::models::deepseek_ocr { +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, + float attention_scaling = 1.0f) -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + // Create freqs tensor: position_ids @ inv_freq + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + // Compute freqs = position_ids[:, :, None] @ inv_freq[None, :] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + // Create sin and cos tensors with shape [batch_size, seq_len, dim] + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + // Compute sin and cos embeddings: emb = [freqs, freqs] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + + // Store the same values in both halves: [freqs, freqs] + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + class DeepseekV2MLP final : public nn::Module { nn::Linear gate_proj_; nn::Linear up_proj_; @@ -113,7 +170,7 @@ class DeepseekV2MoE final : public nn::Module { nn::ModuleList experts_; MoEGate gate_; - nn::ModuleList shared_experts_; + DeepseekV2MLP shared_experts_; public: DeepseekV2MoE() = default; @@ -130,8 +187,7 @@ class DeepseekV2MoE final : public nn::Module { if (n_shared_experts_ > 0) { auto intermediate_size = config.moe_intermediate_size * config.n_shared_experts; - shared_experts_ = - reg>("shared_experts", n_shared_experts_, config, std::nullopt, intermediate_size); + shared_experts_ = reg("shared_experts", config, std::nullopt, intermediate_size); } } @@ -152,36 +208,233 @@ class DeepseekV2MoE final : public nn::Module { } private: - Tensor moeInfer(const Tensor& x, const Tensor& topk_ids, const Tensor& topk_weights) { - // TODO - return Tensor::nil(); + Tensor moeInfer(const Tensor& x, Tensor& topk_ids, Tensor& topk_weights) { + // x shape is [batch_size * seq, hidden_dim] + + auto cnts = Tensor::zeros({topk_ids.size(0), (int32_t)experts_.list().size()}); + // Do scatter_ operation + { + const int32_t* idx_ptr = topk_ids.ptr(); + float* cnt_ptr = cnts.ptr(); + const int batch = topk_ids.size(0); + const int k = topk_ids.size(1); + const int n_exp = cnts.size(1); + for (int b = 0; b < batch; ++b) { + for (int j = 0; j < k; ++j) { + int32_t e = idx_ptr[b * k + j]; + MLLM_RT_ASSERT(e >= 0 && e < n_exp); + cnt_ptr[b * n_exp + e] += 1.f; // +1 + } + } + } + auto tokens_per_expert = cnts.sum(0); + auto idxs = topk_ids.view({-1}).argsort(); + + // TODO this line maybe error + auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}]; + + std::vector outputs; + int start_idx = 0; + + // tokens_per_expert shape is [num_experts] + // Loop through each expert + for (int i = 0; i < experts_.list().size(); ++i) { + auto num_tokens = tokens_per_expert.ptr()[i]; + auto end_idx = start_idx + (int32_t)num_tokens; + if (num_tokens == 0) { continue; } + auto& expert = experts_.list()[i]; + auto tokens_for_this_expert = sorted_tokens[{{start_idx, end_idx}, kAll}]; + auto expert_out = expert(tokens_for_this_expert)[0]; + outputs.push_back(expert_out); + start_idx = end_idx; + } + + auto outs = nn::functional::concat(outputs, 0); + auto new_x = Tensor::emptyLike(outs).alloc(); + + // indexed_write + // python logic: new_x[idxs] = outs + { + const int32_t* idx_ptr = idxs.ptr(); + float* outs_ptr = outs.ptr(); + float* new_x_ptr = new_x.ptr(); + MLLM_RT_ASSERT_EQ(new_x.rank(), 2); + MLLM_RT_ASSERT_EQ(new_x.size(0), idxs.size(0)); + auto dim = new_x.size(1); + for (int i = 0; i < idxs.size(0); ++i) { + int32_t idx = idx_ptr[i]; + std::memcpy(new_x_ptr + idx * dim, outs_ptr + i * dim, dim * sizeof(float)); + } + } + + auto final_out_shape = topk_ids.shape(); + final_out_shape.emplace_back(-1); + auto final_out = + new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1).sum(1).to(new_x.dtype())); + + return final_out; } }; +// Deepseek OCR's attention not used MLA. It's same with LlamaFlashAttention2 class DeepseekV2Attention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + nn::Linear o_proj_; + int hidden_size_; + int num_head_; + int head_dim_; + int num_key_value_heads_; + public: - // TODO + int layer_idx_; + + DeepseekV2Attention() = default; + + DeepseekV2Attention(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { + hidden_size_ = config.hidden_size; + num_head_ = config.num_attention_heads; + head_dim_ = config.hidden_size / config.num_attention_heads; + num_key_value_heads_ = config.num_key_value_heads; + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, num_head_ * head_dim_, false, config.llm_mlp_linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, num_key_value_heads_ * head_dim_, false, config.llm_mlp_linear_impl_type).redirect(); + v_proj_ = reg("v_proj", hidden_size_, num_key_value_heads_ * head_dim_, false, config.llm_mlp_linear_impl_type).redirect(); + o_proj_ = reg("o_proj", num_head_ * head_dim_, hidden_size_, false, config.llm_mlp_linear_impl_type); + q_rope_ = reg("q_rope", 10000.0, config.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + k_rope_ = reg("k_rope", 10000.0, config.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto bsz = hidden_states.size(0); + auto q_len = hidden_states.size(1); + + // Get KV cache for Key and Value first. + // [B, S, H * D] + auto [key_states_redirect, value_states_redirect] = past_kv_cache->preGetKVWriteLocation(layer_idx_, q_len); + + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states, key_states_redirect); + auto value_states = v_proj_(hidden_states, value_states_redirect); + + // [B, S, H, D] + query_states = query_states.view({bsz, q_len, num_head_, head_dim_}); + key_states = key_states.view({bsz, q_len, num_key_value_heads_, head_dim_}); + + // [B, S, H, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // Get KV + auto [K, V] = past_kv_cache->getKVCache(layer_idx_); + + // [B, S, H, D] FA2 + auto output = o_proj_(nn::functional::flashAttention2(query_states, K, V).view({bsz, q_len, num_head_ * head_dim_})); + + return {output}; + } }; class DeepseekV2DecoderLayer final : public nn::Module { + // Use llama2 attention impl in deepseek-ocr model + DeepseekV2Attention self_attn_; + + // FIXME: Do not use hard-code + int first_k_dense_replace_ = 1; + int moe_layer_freq_ = 1; + + nn::RMSNorm input_layernorm_; + nn::RMSNorm post_attention_layernorm_; + + std::optional mlp_opt0_ = std::nullopt; + std::optional mlp_opt1_ = std::nullopt; + public: - // TODO + int layer_idx_; + + DeepseekV2DecoderLayer() = default; + + DeepseekV2DecoderLayer(const std::string& name, const DpskOcrConfig& config, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + first_k_dense_replace_ = config.first_k_dense_replace; + + self_attn_ = reg("self_attn", config); + self_attn_.layer_idx_ = layer_idx; + + if (config.n_routed_experts > 0 && layer_idx_ >= config.first_k_dense_replace && layer_idx_ % moe_layer_freq_ == 0) { + mlp_opt0_ = reg("mlp", config); + } else { + mlp_opt1_ = reg("mlp", config); + } + + input_layernorm_ = reg("input_layernorm", 1e-6); + post_attention_layernorm_ = reg("post_attention_layernorm", 1e-6); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto rope_pos_embed_sin = inputs[1]; + auto rope_pos_embed_cos = inputs[2]; + auto kv_cache = args[0]; + auto residual = hidden_states; + + hidden_states = input_layernorm_(hidden_states); + hidden_states = self_attn_(hidden_states, rope_pos_embed_sin, rope_pos_embed_cos, kv_cache)[0]; + hidden_states = residual + hidden_states; + + residual = hidden_states; + hidden_states = post_attention_layernorm_(hidden_states); + if (mlp_opt0_) { + hidden_states = mlp_opt0_.value()(hidden_states)[0]; + } else { + hidden_states = mlp_opt1_.value()(hidden_states)[0]; + } + hidden_states = residual + hidden_states; + + return {hidden_states}; + } }; class DeepSeekV2Model : public nn::Module { protected: nn::Embedding embed_tokens_; + nn::ModuleListWithIdx layers_; + nn::RMSNorm norm_; public: DeepSeekV2Model() = default; explicit DeepSeekV2Model(const std::string& name, const DpskOcrConfig& config) : nn::Module(name) { embed_tokens_ = reg("embed_tokens", config.vocab_size, config.hidden_size); + layers_ = reg>("layers", config.num_hidden_layers, config); + norm_ = reg("norm", 1e-6); } std::vector forward(const std::vector& inputs, const std::vector& args) override { - // TODO - return {}; + auto& input_embeddings = inputs[0]; + auto rope_embedding_sin = inputs[1]; + auto rope_embedding_cos = inputs[2]; + auto kv_cache = args[0]; + + auto hidden_states = input_embeddings; + + for (auto& layer : layers_.list()) { + hidden_states = layer(hidden_states, rope_embedding_sin, rope_embedding_cos, kv_cache)[0]; + } + + hidden_states = norm_(hidden_states); + + return {hidden_states}; } }; @@ -212,6 +465,8 @@ class DeepseekOCRModel final : public DeepSeekV2Model { auto image_ori = inputs.size() > 2 ? inputs[2] : Tensor::nil(); auto images_spatial_crop = inputs.size() > 3 ? inputs[3] : Tensor::nil(); auto images_seq_mask = inputs.size() > 4 ? inputs[4] : Tensor::nil(); + auto rope_embedding_sin = inputs.size() > 5 ? inputs[5] : Tensor::nil(); + auto rope_embedding_cos = inputs.size() > 6 ? inputs[6] : Tensor::nil(); // Embedding auto inputs_embeds = embed_tokens_(input_ids); @@ -359,16 +614,16 @@ class DeepseekOCRModel final : public DeepSeekV2Model { // Scatter copy. if (images_in_this_batch) { nn::functional::maskedScatter(inputs_embeds, images_seq_mask, images_in_this_batch); } - // Normal forward with text and embedded image - // TODO + auto sequence = DeepSeekV2Model::forward({inputs_embeds, rope_embedding_sin, rope_embedding_cos}, args)[0]; - return {}; + return {sequence}; } }; class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { DeepseekOCRModel model_; nn::Linear lm_head_; + nn::StaticCache kv_cache_; public: DeepseekOCRForCausalLM() = default; @@ -376,9 +631,75 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { explicit DeepseekOCRForCausalLM(const DpskOcrConfig& config) { model_ = reg("model", config); lm_head_ = reg("lm_head", config.hidden_size, config.vocab_size, false, config.lm_head_linear_impl_type); + + // Init inv freq + auto inv = makeRoPEInvFreq(config.hidden_size / config.num_attention_heads, 10000.0); + registerBuffer("inv_freq", inv); + + // kv_cache_ + kv_cache_ = nn::StaticCache(config.max_cache_length, config.num_hidden_layers, + config.num_attention_heads, // q_heads + config.num_key_value_heads, // kv_heads + config.hidden_size / config.num_attention_heads, // kv_dim + kFloat32, // k_dtype + kFloat32, // v_dtype + kCPU, // device_type + true // use_fa2 + ); } - ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto patches = input.count("patches") ? input.at("patches") : Tensor::nil(); + auto image_ori = input.count("image_ori") ? input.at("image_ori") : Tensor::nil(); + auto images_spatial_crop = input.count("images_spatial_crop") ? input.at("images_spatial_crop") : Tensor::nil(); + auto images_seq_mask = input.count("images_seq_mask") ? input.at("images_seq_mask") : Tensor::nil(); + + auto sequence = input.at("sequence"); + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + auto position_ids = Tensor::nil(); + auto rope_embedding_sin = Tensor::nil(); + auto rope_embedding_cos = Tensor::nil(); + auto kv_cache = args.at("kv_cache"); + + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + rope_embedding_sin = llm_embedding_sin; + rope_embedding_cos = llm_embedding_cos; + sequence = model_(sequence, patches, image_ori, images_spatial_crop, images_seq_mask, rope_embedding_sin, + rope_embedding_cos, kv_cache)[0]; + // clip x to one seq length + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + sequence = lm_head_(sequence); + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } void infer(DpskOcrTokenizer& tokenizer, const std::string& prompt, const std::string& image_fp, const std::string& output_path, int base_size = 1024, int image_size = 640, bool crop_mode = true) { @@ -433,7 +754,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { // Processed states std::vector tokenized_str; - std::vector images_seq_mask; + std::vector images_seq_mask; std::vector images_list; std::vector images_crop_list; std::vector> images_spatial_crop; @@ -533,7 +854,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { // Prepare Tensor to DeepSeek-OCR Model auto input_ids = Tensor::fromVector(tokenized_str, {1, (int32_t)tokenized_str.size()}, kInt64); - auto images_seq_mask_tensor = Tensor::fromVector(images_seq_mask, {1, (int32_t)images_seq_mask.size()}, kFloat32); + auto images_seq_mask_tensor = Tensor::fromVector(images_seq_mask, {1, (int32_t)images_seq_mask.size()}, kInt8); auto images_ori_tensor = Tensor::nil(); auto images_spatial_crop_tensor = Tensor::nil(); auto images_crop_tensor = Tensor::nil(); @@ -557,8 +878,24 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { } } - // Run model. Use generate - // TODO + std::stringstream result; + streamGenerate( + { + {"patches", images_crop_tensor}, + {"image_ori", images_ori_tensor}, + {"images_spatial_crop", images_spatial_crop_tensor}, + {"images_seq_mask", images_seq_mask_tensor}, + {"sequence", input_ids}, + }, + { + {"kv_cache", mllm::AnyValue(&kv_cache_)}, + }, + [&](int64_t token_id) { + auto decode = tokenizer.decode({token_id}); + result << decode; + fmt::print("{}", decode); + }); + print("\n"); ///< flush // Post process data // TODO diff --git a/mllm/nn/Module.hpp b/mllm/nn/Module.hpp index 061bedc9b..cedefd799 100644 --- a/mllm/nn/Module.hpp +++ b/mllm/nn/Module.hpp @@ -44,6 +44,9 @@ class ModuleImpl : public AbstractNnNode { template class ModuleList; +template +class ModuleListWithIdx; + template class ModuleListSuffixed; @@ -198,6 +201,29 @@ class ModuleList final : public Module { std::vector& list() { return layers_; } }; +template +class ModuleListWithIdx final : public Module { + std::vector layers_; + + public: + ModuleListWithIdx() = default; + + template + ModuleListWithIdx(const std::string& name, int nums, Args&&... args) : Module(name) { + for (int i = 0; i < nums; ++i) { + layers_.emplace_back(reg(/*name*/ std::to_string(i), /*args*/ std::forward(args)..., /*index*/ i)); + } + }; + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + std::vector o = inputs; + for (auto& layer : layers_) { o = layer.forward(o, args); } + return o; + } + + std::vector& list() { return layers_; } +}; + template class ModuleListSuffixed final : public Module { std::vector layers_; From 9b843b6992fa74dc657f980facdd9c2e31a0e984 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 27 Oct 2025 17:52:45 +0800 Subject: [PATCH 20/25] feat(tensor): support negative dim in repeat operation feat(pad): simplify shape copy in PadOp feat(deepseek_ocr): add layer index to Attention and Block modules feat(deepseek_ocr): reshape qkv tensors explicitly in Attention feat(deepseek_ocr): enable rel_pos tensor repeat for cpu backend compatibility feat(deepseek_ocr): initialize layer_idx_ in Blocks module refactor(deepseek_ocr): remove debug print statements in Block module fix(deepseek_ocr): annotate batch size assumption as TODO in CLIPVisionEmbeddings --- .../deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json | 0 mllm/core/Tensor.cpp | 1 + mllm/core/aops/PadOp.cpp | 2 +- mllm/models/deepseek_ocr/deepencoder.hpp | 28 ++++++++++++++----- 4 files changed, 23 insertions(+), 8 deletions(-) create mode 100644 examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json diff --git a/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json b/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index 3da641f72..a9ac5bb44 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -382,6 +382,7 @@ Tensor Tensor::view(const Tensor::shape_t& indicies) { Tensor Tensor::repeat(int32_t multiplier, int32_t dim) { if (multiplier == 1) { return *this; } + if (dim < 0) { dim = static_cast(rank()) + dim; } return Context::instance().buildOpAndSubmitTask(OpTypes::kRepeat, aops::RepeatOpOptions{.dim = dim, .repeat_times = multiplier}, {*this})[0]; } diff --git a/mllm/core/aops/PadOp.cpp b/mllm/core/aops/PadOp.cpp index b6d05e323..5421b8198 100644 --- a/mllm/core/aops/PadOp.cpp +++ b/mllm/core/aops/PadOp.cpp @@ -28,7 +28,7 @@ void PadOp::reshape(const std::vector& inputs, std::vector& outp const int in_dims = i.shape().size(); const auto& pad = options_.pad; // [dim_last_low, dim_last_high, ...] - std::vector out_shape(i.shape().begin(), i.shape().end()); + std::vector out_shape = i.shape(); for (int d = 0; d < in_dims; ++d) { int idx = (in_dims - 1 - d) * 2; int32_t low = (idx < pad.size()) ? pad[idx] : 0; diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 16f7cd27a..1eff89665 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -144,7 +144,7 @@ class CLIPVisionEmbeddings final : public nn::Module { // patch_embeds original shape is [batch(1), out_channel, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2); // [batch(1), width * grid * grid, out_channel] - // Assume batch is always 1 + // TODO bugs. Assume batch is always 1 MLLM_RT_ASSERT_EQ(batch_size, 1); // [batch(1), 1, 1024] auto class_embeds = class_embedding_.weight().view({1, 1, 1024}); @@ -352,6 +352,8 @@ class Attention final : public nn::Module { nn::Param rel_pos_w_; public: + int layer_idx_; + Attention() = default; Attention(const std::string& name, int dim, int num_heads, bool qkv_bias, bool use_rel_pos, @@ -502,6 +504,10 @@ class Attention final : public nn::Module { qkv = qkv.view({3, B * num_heads_, H * W, -1}); auto [q, k, v] = nn::functional::split<3>(qkv, 0); + q = q.view({B * num_heads_, H * W, -1}); + k = k.view({B * num_heads_, H * W, -1}); + v = v.view({B * num_heads_, H * W, -1}); + auto rel_h = Tensor::nil(); auto rel_w = Tensor::nil(); @@ -516,6 +522,15 @@ class Attention final : public nn::Module { if (use_rel_pos_) { rel_h = rel_h.view({B, num_heads_, rel_h.size(1), rel_h.size(2), rel_h.size(3)}); rel_w = rel_w.view({B, num_heads_, rel_w.size(1), rel_w.size(2), rel_w.size(3)}); + + // Dual broadcast is not supported in cpu backend. So we need to repeat rel_h and rel_w + // torch.Size([54, 12, 196, 14, 1]) + // torch.Size([54, 12, 196, 1, 14]) + auto _dim_neg_1 = rel_w.size(4); + auto _dim_neg_2 = rel_h.size(3); + rel_h = rel_h.repeat(_dim_neg_1, -1); + rel_w = rel_w.repeat(_dim_neg_2, -2); + MLLM_RT_ASSERT_EQ(rel_h.shape(), rel_w.shape()); auto attn_bias = (rel_h + rel_w).view({B, num_heads_, rel_h.size(2), rel_h.size(3) * rel_w.size(4)}); x = nn::functional::scaledDotProductAttention(q, k, v, attn_bias); } else { @@ -531,11 +546,13 @@ class Attention final : public nn::Module { class Block final : public nn::Module { nn::LayerNorm norm1_; nn::LayerNorm norm2_; - Attention attn_; MLPBlock mlp_; int window_size_; public: + Attention attn_; + int layer_idx_; + Block() = default; Block(const std::string& name, int dim, int num_heads, float mlp_ratio, bool qkv_bias, bool use_rel_pos, int window_size, @@ -586,8 +603,6 @@ class Block final : public nn::Module { auto shortcut = x; x = norm1_(x); - print(x.shape(), window_size_); - // Window partition int H = 0; int W = 0; @@ -598,9 +613,6 @@ class Block final : public nn::Module { std::tie(x, pad_hw) = windowPartition(x, window_size_); } - print(x.shape(), H, W, pad_hw); - exit(0); - x = attn_(x)[0]; // Reverse window partition @@ -626,6 +638,8 @@ class Blocks final : public nn::Module { auto this_block_window_size = is_in ? 14 : 0; blocks_.emplace_back(reg(std::to_string(i), 768, 12, 4.0, true, true, this_block_window_size, std::make_optional(std::make_tuple(1024 / 16, 1024 / 16)), config)); + blocks_[i].layer_idx_ = i; + blocks_[i].attn_.layer_idx_ = i; } }; From 324cd50b81789038deb568b271d4734d89182191 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 27 Oct 2025 22:02:07 +0800 Subject: [PATCH 21/25] feat(cpu): implement optimized softmax for last dimension cases - Add fast path for softmax when axis is -1 (last dimension) - Support both fp32 and fp16 data types with architecture-specific implementations - Enable parallel execution based on thread options - Handle negative axis indexing correctly fix(tensor): correct unsqueeze dimension calculation - Adjust negative dimension indexing logic in unsqueeze operation - Add 1 to dim adjustment to match expected behavior feat(deepseek-ocr): improve vision embeddings and attention mechanisms - Update position_ids registration with proper view shape - Fix tensor slicing in CLIPVisionEmbeddings by using squeezed tensor - Remove batch size assertion and support dynamic batch sizes - Replace class embedding expansion with repeat operation - Add contiguous operations in attention modules for performance - Use view instead of reshape in attention output processing - Optimize MoE final output computation by adjusting sum operation order - Add temporary contiguous calls to work around performance issues in concat operations - Clean up unnecessary comments and fix tensor indexing in feature concatenation --- mllm/backends/cpu/ops/SoftmaxOp.cpp | 43 +++++++++++++++++++ mllm/core/Tensor.cpp | 2 +- mllm/models/deepseek_ocr/deepencoder.hpp | 26 +++++------ .../deepseek_ocr/modeling_deepseek_ocr.hpp | 11 ++--- 4 files changed, 63 insertions(+), 19 deletions(-) diff --git a/mllm/backends/cpu/ops/SoftmaxOp.cpp b/mllm/backends/cpu/ops/SoftmaxOp.cpp index 09bf01aa4..2f412be2c 100644 --- a/mllm/backends/cpu/ops/SoftmaxOp.cpp +++ b/mllm/backends/cpu/ops/SoftmaxOp.cpp @@ -14,6 +14,49 @@ void CPUSoftmaxOp::forward(const std::vector& inputs, std::vector 1, options_.getThreads(), loop_idx, 0, loop_times, 1, { +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + x86::softmax_v1_fp32(X.ptr() + loop_idx * loop_dims, Y.ptr() + +loop_idx * loop_dims, + loop_dims, 1, options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + arm::softmax_v1_fp32(X.ptr() + loop_idx * loop_dims, Y.ptr() + +loop_idx * loop_dims, + loop_dims, 1, options_.getThreads()); +#endif + }); + break; + } + case kFloat16: { + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, options_.getThreads(), loop_idx, 0, loop_times, 1, { +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("CPUSoftmaxOp::forward not support dtype {}", nameOfType(X.dtype())); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + arm::softmax_v1_fp16(X.ptr() + loop_idx * loop_dims, Y.ptr() + +loop_idx * loop_dims, + loop_dims, 1, options_.getThreads()); +#endif + }); + break; + } + default: { + NYI("CPUSoftmaxOp::forward not support dtype {}", nameOfType(X.dtype())); + break; + } + } + return; + } + + // CASE 2. FOR NORMAL CASE in LLM (BSHD, BHSD, etc) MLLM_RT_ASSERT_EQ(X.shape().size(), 4); MLLM_RT_ASSERT(options_.axis == -1 || options_.axis == 3); diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index a9ac5bb44..a1fc90c73 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -388,7 +388,7 @@ Tensor Tensor::repeat(int32_t multiplier, int32_t dim) { } Tensor Tensor::unsqueeze(int32_t dim) { - if (dim < 0) { dim = static_cast(rank()) + dim; } + if (dim < 0) { dim = static_cast(rank()) + dim + 1; } auto this_shape = shape(); this_shape.insert(this_shape.begin() + dim, 1); return view(this_shape); diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index 1eff89665..c454c0b13 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -93,7 +93,7 @@ class CLIPVisionEmbeddings final : public nn::Module { position_embedding_ = reg("position_embedding", num_positions_, embed_dim_); // Register a buffer - registerBuffer("position_ids", Tensor::arange(0, num_positions_, 1, kInt64, kCPU)); + registerBuffer("position_ids", Tensor::arange(0, num_positions_, 1, kInt64, kCPU).view({1, -1})); } Tensor getAbsPos(Tensor abs_pos, int32_t tgt_size) { @@ -103,8 +103,8 @@ class CLIPVisionEmbeddings final : public nn::Module { auto dim = abs_pos.size(-1); auto abs_pos_new = abs_pos.squeeze(0); - auto cls_token = abs_pos[{{kAll, 1}, kAll}].contiguous(); - auto old_pos_embed = abs_pos[{{1, kAll}, kAll}].contiguous(); + auto cls_token = abs_pos_new[{{kAll, 1}, kAll}].contiguous(); + auto old_pos_embed = abs_pos_new[{{1, kAll}, kAll}].contiguous(); auto src_size = int(std::sqrt(abs_pos_new.shape()[0] - 1)); tgt_size = int(std::sqrt(tgt_size)); @@ -141,13 +141,12 @@ class CLIPVisionEmbeddings final : public nn::Module { if (!patch_embeds) { patch_embeds = patch_embedding_(pixel_values); } // Flatten and transpose. - // patch_embeds original shape is [batch(1), out_channel, width, grid, grid] - patch_embeds = patch_embeds.flatten(2).transpose(1, 2); // [batch(1), width * grid * grid, out_channel] + // patch_embeds original shape is [batch, out_channel, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2); // [batch, width * grid * grid, out_channel] - // TODO bugs. Assume batch is always 1 - MLLM_RT_ASSERT_EQ(batch_size, 1); - // [batch(1), 1, 1024] - auto class_embeds = class_embedding_.weight().view({1, 1, 1024}); + // [batch, 1, 1024] + // Same as expand(batch_size, 1, -1) + auto class_embeds = class_embedding_.weight().view({1, 1, -1}).repeat(batch_size, 0); auto embeddings = nn::functional::concat({class_embeds, patch_embeds}, 1); embeddings = embeddings + getAbsPos(position_embedding_(getBuffer("position_ids")), embeddings.size(1)); @@ -205,16 +204,17 @@ class NoTPAttention final : public nn::Module { xqkv = xqkv.view({bsz, seqlen, 3, num_heads_, head_dim_}); auto [xq, xk, xv] = nn::functional::split<3>(xqkv, 2); - xq = xq.squeeze(2); - xk = xk.squeeze(2); - xv = xv.squeeze(2); + // FIXME: contiguous is not needed, actually. + xq = xq.contiguous().squeeze(2); + xk = xk.contiguous().squeeze(2); + xv = xv.contiguous().squeeze(2); xq = xq.permute({0, 2, 1, 3}); xk = xk.permute({0, 2, 1, 3}); xv = xv.permute({0, 2, 1, 3}); auto output = nn::functional::scaledDotProductAttention(xq, xk, xv); - output = output.permute({0, 2, 1, 3}).reshape({bsz, seqlen, -1}); + output = output.permute({0, 2, 1, 3}).view({bsz, seqlen, -1}); output = out_proj_(output); return {output}; } diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 0a2c33ce5..d76310922 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -270,8 +270,7 @@ class DeepseekV2MoE final : public nn::Module { auto final_out_shape = topk_ids.shape(); final_out_shape.emplace_back(-1); auto final_out = - new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1).sum(1).to(new_x.dtype())); - + new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1)).sum(1).to(new_x.dtype()); return final_out; } }; @@ -480,7 +479,8 @@ class DeepseekOCRModel final : public DeepSeekV2Model { auto local_features_2 = vision_model_(patches, local_features_1)[0]; auto local_features = nn::functional::concat( { - local_features_2[{kAll, {1, kAll}}], + // FIXME: contiguous is not needed. We use contiguous because mllm has weak performance in this case. + local_features_2[{kAll, {1, kAll}, kAll}].contiguous(), local_features_1.flatten(2).permute({0, 2, 1}), }, -1); @@ -491,7 +491,8 @@ class DeepseekOCRModel final : public DeepSeekV2Model { auto global_features_2 = vision_model_(image_ori, global_features_1)[0]; auto global_features = nn::functional::concat( { - global_features_2[{kAll, {1, kAll}}], + // FIXME: contiguous is not needed. We use contiguous because mllm has weak performance in this case. + global_features_2[{kAll, {1, kAll}, kAll}].contiguous(), global_features_1.flatten(2).permute({0, 2, 1}), }, -1); @@ -566,7 +567,7 @@ class DeepseekOCRModel final : public DeepSeekV2Model { auto global_features_2 = vision_model_(image_ori, global_features_1)[0]; auto global_features = nn::functional::concat( { - global_features_2[{kAll, {1, kAll}}], + global_features_2[{kAll, {1, kAll}, kAll}], global_features_1.flatten(2).permute({0, 2, 1}), }, -1); From 00d787a0daaa2cbaf490cd4a3f3c69c149de3e4e Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Tue, 28 Oct 2025 15:31:53 +0800 Subject: [PATCH 22/25] feat(cpu): add Tracy profiler option and update quantization config - Added `MLLM_TRACY_ENABLE` option in CMakeLists.txt for enabling Tracy profiler - Introduced new quantization configuration file for DeepSeek OCR model using KAI quantization methods - Updated Conv2DOp to correctly handle output tensor indexing during BLAS operations - Added support for configurable linear implementation types in DeepSeek OCR model - Improved image resizing with bicubic interpolation and better bounds checking - Fixed scaling calculation in scaled dot-product attention using standard sqrtf - Minor code cleanup and debugging prints added in DeepSeek OCR encoder The changes span across build configuration, model quantization, CPU backend ops, model definitions, and image preprocessing utilities. The addition of Tracy profiler support provides enhanced performance analysis capabilities. The quantization config enables more efficient model inference through k-ai quantization techniques. CPU convolution fixes ensure proper tensor handling. Model updates allow flexible linear layer implementations. Image processing improvements offer better quality and safety. --- CMakeLists.txt | 1 + .../deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json | 301 ++++++++++++++++++ mllm/backends/cpu/ops/Conv2DOp.cpp | 19 +- .../configuration_deepseek_ocr.hpp | 5 + mllm/models/deepseek_ocr/deepencoder.hpp | 3 + .../deepseek_ocr/modeling_deepseek_ocr.hpp | 3 + mllm/nn/Functional.cpp | 4 +- mllm/preprocessor/visual/Image.cpp | 23 +- mllm/preprocessor/visual/Image.hpp | 4 +- 9 files changed, 344 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e85ecf7f..b8e209453 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(MLLM_KERNEL_THREADS_VENDOR_APPLE_GCD "Enable Apple GCD Threads" OFF) # Performance components option(MLLM_PERFETTO_ENABLE "Enable perfetto" OFF) +option(MLLM_TRACY_ENABLE "Enable Tracy. A more advanced profiler" OFF) message(STATUS "CXX Compiler=${CMAKE_CXX_COMPILER_ID}") message(STATUS "CXX Compiler version=${CMAKE_CXX_COMPILER_VERSION}") diff --git a/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json b/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json index e69de29bb..4d58d855b 100644 --- a/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json +++ b/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json @@ -0,0 +1,301 @@ +{ + "lm_head.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 129280, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.down_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 6848 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.gate_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 6848, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.up_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 6848, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.k_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.q_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.v_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.o_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.down_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 896 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.gate_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 896, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.up_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 896, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.shared_experts\\.down_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 1792 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.shared_experts\\.gate_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1792, + 1280 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.shared_experts\\.up_proj\\.weight": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1792, + 1280 + ], + "replace": true + } + }, + "^model\\.projector\\.layers.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1280, + 2048 + ], + "replace": true + } + }, + "^model\\.sam_model\\.blocks\\.\\d+\\.attn\\.proj.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^model\\.sam_model\\.blocks\\.\\d+\\.attn\\.qkv.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 2304, + 768 + ], + "replace": true + } + }, + "^model\\.sam_model\\.blocks\\.\\d+\\.mlp\\.lin1.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 3072, + 768 + ], + "replace": true + } + }, + "^model\\.sam_model\\.blocks\\.\\d+\\.mlp\\.lin2.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 768, + 3072 + ], + "replace": true + } + }, + "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.mlp\\.fc1.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 4096, + 1024 + ], + "replace": true + } + }, + "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.mlp\\.fc2.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1024, + 4096 + ], + "replace": true + } + }, + "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.self_attn\\.out_proj.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 1024, + 1024 + ], + "replace": true + } + }, + "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.self_attn\\.qkv_proj.(bias|weight)": { + "hints": { + "quant_method": "kai", + "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", + "kai_matmul_layout": "mxk_nxk", + "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", + "shape": [ + 3072, + 1024 + ], + "replace": true + } + } +} diff --git a/mllm/backends/cpu/ops/Conv2DOp.cpp b/mllm/backends/cpu/ops/Conv2DOp.cpp index 0c1c58789..5770afb33 100644 --- a/mllm/backends/cpu/ops/Conv2DOp.cpp +++ b/mllm/backends/cpu/ops/Conv2DOp.cpp @@ -76,10 +76,14 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector auto& dilation = options_.dilation; MLLM_RT_ASSERT_EQ(input.rank(), 4); + MLLM_RT_ASSERT_EQ(output.rank(), 4); auto batch_size = input.size(0); auto _1 = input.size(1); auto _2 = input.size(2); auto _3 = input.size(3); + auto _out_1 = output.size(1); + auto _out_2 = output.size(2); + auto _out_3 = output.size(3); switch (input.dtype()) { case kFloat32: { @@ -115,12 +119,13 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector switch (mt) { // NOLINT case aops::MatMulOpType::kBLAS: { #if defined(MLLM_USE_BLAS) - blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), output.ptr(), - nullptr, MATMUL_M, MATMUL_N, MATMUL_K, false, false); + blas::matmul_fp32(weight_.ptr(), packed_inputs.ptr(), + output.ptr() + _b_idx * (_out_1 * _out_2 * _out_3), nullptr, MATMUL_M, MATMUL_N, + MATMUL_K, false, false); // Add Bias if (options_.bias) { - auto out_ptr = output.ptr(); + auto out_ptr = output.ptr() + _b_idx * (_out_1 * _out_2 * _out_3); const auto bias_ptr = bias_.ptr(); for (int m = 0; m < MATMUL_M; ++m) { const float b = bias_ptr[m]; @@ -135,11 +140,12 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector case aops::MatMulOpType::kMllmBlas: { auto thread_count = options_.getThreads(); #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - arm::mllm_blas_matmul_fp32(MATMUL_M, MATMUL_K, MATMUL_N, output.ptr(), weight_.ptr(), - packed_inputs.ptr(), nullptr, false, false, thread_count); + arm::mllm_blas_matmul_fp32( + MATMUL_M, MATMUL_K, MATMUL_N, output.ptr() + _b_idx * (_out_1 * _out_2 * _out_3), + weight_.ptr(), packed_inputs.ptr(), nullptr, false, false, thread_count); // Add Bias if (options_.bias) { - auto out_ptr = output.ptr(); + auto out_ptr = output.ptr() + _b_idx * (_out_1 * _out_2 * _out_3); const auto bias_ptr = bias_.ptr(); for (int m = 0; m < MATMUL_M; ++m) { const float b = bias_ptr[m]; @@ -158,7 +164,6 @@ void CPUConv2DOp::forward(const std::vector& inputs, std::vector #else MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported architecture for perform im2col conv2d."); #endif - break; } } } diff --git a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp index 883be51a1..ddc6a47cd 100644 --- a/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp @@ -96,6 +96,11 @@ struct DpskOcrConfig : protected ConfigFile { use_mla = data()["use_mla"]; v_head_dim = data()["v_head_dim"]; vocab_size = data()["vocab_size"]; + clip_linear_impl_type = aops::str2LinearImplTypes(data()["clip_linear_impl_type"]); + llm_mlp_linear_impl_type = aops::str2LinearImplTypes(data()["llm_mlp_linear_impl_type"]); + lm_head_linear_impl_type = aops::str2LinearImplTypes(data()["lm_head_linear_impl_type"]); + mlp_projector_linear_impl_type = aops::str2LinearImplTypes(data()["mlp_projector_linear_impl_type"]); + sam_linear_impl_type = aops::str2LinearImplTypes(data()["sam_linear_impl_type"]); } // Nested structs for complex configuration diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index c454c0b13..bd7fa6dec 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -707,6 +707,9 @@ class ImageEncoderViT final : public nn::Module { x = x + getAbsPosSam(pos_embed_.weight(), x.size(1)); for (auto& blk : blocks_.list()) { x = blk(x)[0]; } + print(x); + exit(0); + x = neck_(x.permute({0, 3, 1, 2}))[0]; x = net_2_(x); x = net_3_(x); diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index d76310922..b19d151e8 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -476,6 +476,9 @@ class DeepseekOCRModel final : public DeepSeekV2Model { if (nn::functional::sum(patches).item() != 0) { // Local features auto local_features_1 = sam_model_(patches)[0]; + print(local_features_1.shape()); + print(local_features_1); + exit(0); auto local_features_2 = vision_model_(patches, local_features_1)[0]; auto local_features = nn::functional::concat( { diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index 9f8e981b9..3931f5d66 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -127,8 +127,8 @@ void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim } Tensor scaledDotProductAttention(const Tensor& Q, const Tensor& K, const Tensor& V, const Tensor& mask) { - auto scale = Q.size(-1); - scale = (1.f / sqrtf(scale)); + float scale = Q.size(-1); + scale = (1.f / std::sqrtf((float)scale)); auto attn_weight = matmul(Q, K, false, true) * scale; if (mask) { attn_weight = attn_weight + mask; } attn_weight = softmax(attn_weight, -1); diff --git a/mllm/preprocessor/visual/Image.cpp b/mllm/preprocessor/visual/Image.cpp index 3c06ae1a4..a6733480a 100644 --- a/mllm/preprocessor/visual/Image.cpp +++ b/mllm/preprocessor/visual/Image.cpp @@ -39,7 +39,7 @@ Image Image::open(const std::string& fp) { return ret_image; } -Image Image::resize(int new_w, int new_h) { +Image Image::resize(int new_w, int new_h, const std::string& method) { Image new_img; new_img.w_ = new_w; new_img.h_ = new_h; @@ -47,9 +47,14 @@ Image Image::resize(int new_w, int new_h) { unsigned char* output_data = nullptr; - // stb will alloc memory for us - output_data = stbir_resize_uint8_linear(static_cast(image_ptr_->ptr_), w_, h_, 0, output_data, new_w, new_h, - 0, STBIR_RGB); + if (method == "bilinear") { + // stb will alloc memory for us + output_data = stbir_resize_uint8_linear(static_cast(image_ptr_->ptr_), w_, h_, 0, output_data, new_w, new_h, + 0, STBIR_RGB); + } else if (method == "bicubic") { + output_data = (unsigned char*)stbir_resize(static_cast(image_ptr_->ptr_), w_, h_, 0, (void*)output_data, new_w, + new_h, 0, STBIR_RGB, STBIR_TYPE_UINT8, STBIR_EDGE_CLAMP, STBIR_FILTER_MITCHELL); + } new_img.image_ptr_ = std::make_shared<_ImagePtr>(); new_img.image_ptr_->ptr_ = output_data; @@ -164,7 +169,7 @@ Image Image::pad(int target_w, int target_h, unsigned char r, unsigned char g, u new_h = std::max(1, new_h); // Resize current image to the computed size - Image resized = this->resize(new_w, new_h); + Image resized = this->resize(new_w, new_h, "bicubic"); // Prepare output canvas filled with color Image out; @@ -192,10 +197,12 @@ Image Image::pad(int target_w, int target_h, unsigned char r, unsigned char g, u // Blit resized image onto the canvas for (int y = 0; y < new_h; ++y) { const int dy = offset_y + y; + if (dy < 0 || dy >= target_h) continue; for (int x = 0; x < new_w; ++x) { const int dx = offset_x + x; - unsigned char* dst_px = canvas + (static_cast(dy) * target_w + dx) * out.c_; - const unsigned char* src_px = src + (static_cast(y) * new_w + x) * resized.c_; + if (dx < 0 || dx >= target_w) continue; + unsigned char* dst_px = canvas + (static_cast(dy) * target_w + dx) * 3; + const unsigned char* src_px = src + (static_cast(y) * new_w + x) * 3; dst_px[0] = src_px[0]; dst_px[1] = src_px[1]; dst_px[2] = src_px[2]; @@ -207,4 +214,4 @@ Image Image::pad(int target_w, int target_h, unsigned char r, unsigned char g, u return out; } -} // namespace mllm \ No newline at end of file +} // namespace mllm diff --git a/mllm/preprocessor/visual/Image.hpp b/mllm/preprocessor/visual/Image.hpp index 98e08c72b..2f3de4e03 100644 --- a/mllm/preprocessor/visual/Image.hpp +++ b/mllm/preprocessor/visual/Image.hpp @@ -28,7 +28,7 @@ class Image { public: static Image open(const std::string& fp); - Image resize(int new_w, int new_h); + Image resize(int new_w, int new_h, const std::string& method = "bilinear"); // Crop the image with PIL-style box (left, upper, right, lower). // Out-of-bounds areas are padded with zeros. Returns a new Image. @@ -57,4 +57,4 @@ class Image { std::shared_ptr<_ImagePtr> image_ptr_ = nullptr; }; -} // namespace mllm \ No newline at end of file +} // namespace mllm From e861deb10c6ad1d498448b05cae6a5df05a9fc2e Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Tue, 28 Oct 2025 22:43:20 +0800 Subject: [PATCH 23/25] feat(deepseek_ocr): update model paths and quantization config - Switch model loading path from w32a32 to w4a8-i8mm-kai variant - Remove debug print statements and exit calls in encoder and model files - Adjust maskedScatter call to unsqueeze images_seq_mask for proper tensor shape - Replace fmt::print with std::cout for token decoding output - Clean up unused kai quantization rules in quantization config file --- examples/deepseek_ocr/main.cpp | 12 +- .../deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json | 117 ------------------ mllm/models/deepseek_ocr/deepencoder.hpp | 3 - .../deepseek_ocr/modeling_deepseek_ocr.hpp | 10 +- 4 files changed, 14 insertions(+), 128 deletions(-) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index 45e03e24a..cf819be08 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -5,10 +5,16 @@ using mllm::Argparse; MLLM_MAIN({ - auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/config.json"); + // auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/config.json"); + // auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(config); + // auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/tokenizer.json"); + // model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/model.mllm", mllm::ModelFileVersion::kV2)); + mllm::setLogLevel(mllm::LogLevel::kError); + auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/config.json"); auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(config); - auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/tokenizer.json"); - model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/model.mllm", mllm::ModelFileVersion::kV2)); + auto tokenizer = + mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/tokenizer.json"); + model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/model.mllm", mllm::ModelFileVersion::kV2)); model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", "/Volumes/D/mllm/.tmp/dpsk-ocr"); diff --git a/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json b/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json index 4d58d855b..a35346d45 100644 --- a/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json +++ b/examples/deepseek_ocr/quant_cfg_w4a8_kai_i8mm.json @@ -180,122 +180,5 @@ ], "replace": true } - }, - "^model\\.projector\\.layers.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 1280, - 2048 - ], - "replace": true - } - }, - "^model\\.sam_model\\.blocks\\.\\d+\\.attn\\.proj.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 768, - 768 - ], - "replace": true - } - }, - "^model\\.sam_model\\.blocks\\.\\d+\\.attn\\.qkv.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 2304, - 768 - ], - "replace": true - } - }, - "^model\\.sam_model\\.blocks\\.\\d+\\.mlp\\.lin1.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 3072, - 768 - ], - "replace": true - } - }, - "^model\\.sam_model\\.blocks\\.\\d+\\.mlp\\.lin2.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 768, - 3072 - ], - "replace": true - } - }, - "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.mlp\\.fc1.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 4096, - 1024 - ], - "replace": true - } - }, - "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.mlp\\.fc2.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 1024, - 4096 - ], - "replace": true - } - }, - "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.self_attn\\.out_proj.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 1024, - 1024 - ], - "replace": true - } - }, - "^model\\.vision_model\\.transformer\\.layers\\.\\d+\\.self_attn\\.qkv_proj.(bias|weight)": { - "hints": { - "quant_method": "kai", - "kai_matmul_triplet": "f32_qai8dxp_qsi4c32p", - "kai_matmul_layout": "mxk_nxk", - "kai_matmul_tile_cfg": "qai8dxp4x8_qsi4c32p8x8_4x8x32", - "shape": [ - 3072, - 1024 - ], - "replace": true - } } } diff --git a/mllm/models/deepseek_ocr/deepencoder.hpp b/mllm/models/deepseek_ocr/deepencoder.hpp index bd7fa6dec..c454c0b13 100644 --- a/mllm/models/deepseek_ocr/deepencoder.hpp +++ b/mllm/models/deepseek_ocr/deepencoder.hpp @@ -707,9 +707,6 @@ class ImageEncoderViT final : public nn::Module { x = x + getAbsPosSam(pos_embed_.weight(), x.size(1)); for (auto& blk : blocks_.list()) { x = blk(x)[0]; } - print(x); - exit(0); - x = neck_(x.permute({0, 3, 1, 2}))[0]; x = net_2_(x); x = net_3_(x); diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index b19d151e8..1a1ed955a 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include @@ -476,9 +477,6 @@ class DeepseekOCRModel final : public DeepSeekV2Model { if (nn::functional::sum(patches).item() != 0) { // Local features auto local_features_1 = sam_model_(patches)[0]; - print(local_features_1.shape()); - print(local_features_1); - exit(0); auto local_features_2 = vision_model_(patches, local_features_1)[0]; auto local_features = nn::functional::concat( { @@ -616,7 +614,9 @@ class DeepseekOCRModel final : public DeepSeekV2Model { } // Scatter copy. - if (images_in_this_batch) { nn::functional::maskedScatter(inputs_embeds, images_seq_mask, images_in_this_batch); } + if (images_in_this_batch) { + nn::functional::maskedScatter(inputs_embeds, images_seq_mask.unsqueeze(-1), images_in_this_batch); + } auto sequence = DeepSeekV2Model::forward({inputs_embeds, rope_embedding_sin, rope_embedding_cos}, args)[0]; @@ -897,7 +897,7 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { [&](int64_t token_id) { auto decode = tokenizer.decode({token_id}); result << decode; - fmt::print("{}", decode); + std::cout << decode << std::flush; }); print("\n"); ///< flush From 05811484038996b7ef5207b420eb4055edbcc273 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 29 Oct 2025 09:57:04 +0800 Subject: [PATCH 24/25] feat(examples): update deepseek_ocr example to use argparse for model paths - Replace hardcoded paths with command-line arguments for model, tokenizer, and config - Update include paths to use angle brackets instead of quotes - Adjust image and output paths for better portability - Add help option to display usage information --- examples/deepseek_ocr/main.cpp | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/deepseek_ocr/main.cpp b/examples/deepseek_ocr/main.cpp index cf819be08..86408f68c 100644 --- a/examples/deepseek_ocr/main.cpp +++ b/examples/deepseek_ocr/main.cpp @@ -1,21 +1,22 @@ #include -#include "mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp" -#include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" +#include +#include using mllm::Argparse; MLLM_MAIN({ - // auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/config.json"); - // auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(config); - // auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/tokenizer.json"); - // model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w32a32/model.mllm", mllm::ModelFileVersion::kV2)); + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); mllm::setLogLevel(mllm::LogLevel::kError); - auto config = mllm::models::deepseek_ocr::DpskOcrConfig("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/config.json"); + + auto config = mllm::models::deepseek_ocr::DpskOcrConfig(config_path.get()); auto model = mllm::models::deepseek_ocr::DeepseekOCRForCausalLM(config); - auto tokenizer = - mllm::models::deepseek_ocr::DpskOcrTokenizer("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/tokenizer.json"); - model.load(mllm::load("/Volumes/D/mllm-models/DeepSeek-OCR-w4a8-i8mm-kai/model.mllm", mllm::ModelFileVersion::kV2)); + auto tokenizer = mllm::models::deepseek_ocr::DpskOcrTokenizer(tokenizer_path.get()); + model.load(mllm::load(model_path.get(), mllm::ModelFileVersion::kV2)); - model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "/Volumes/D/mllm/.tmp/dpsk-ocr-pr.png", - "/Volumes/D/mllm/.tmp/dpsk-ocr"); + model.infer(tokenizer, "\n<|grounding|>Convert the document to markdown. ", "dpsk-ocr-pr.png", "."); }); From 960b9ebd4d8113b44f7142141c56462b91488594 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 29 Oct 2025 10:05:38 +0800 Subject: [PATCH 25/25] refactor(deepseek_ocr): reformat function signature and add eos token initialization - Reformat `makeRotaryPosEmbedding` function signature to improve readability - Initialize `eos_token_id_` from config in DeepseekOCRForCausalLM constructor --- mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index 1a1ed955a..487ba38f2 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -650,6 +650,9 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { kCPU, // device_type true // use_fa2 ); + + // eos + eos_token_id_ = config.eos_token_id; } ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override {