|
| 1 | +# Copyright (C) 2025 Arcee AI |
| 2 | +# SPDX-License-Identifier: BUSL-1.1 |
| 3 | + |
| 4 | +import logging |
| 5 | +from functools import lru_cache |
| 6 | +from typing import TYPE_CHECKING, Optional |
| 7 | + |
| 8 | +from transformers import PretrainedConfig |
| 9 | + |
| 10 | +from mergekit.architecture.auto import infer_architecture_info |
| 11 | +from mergekit.architecture.base import ( |
| 12 | + ConfiguredModelArchitecture, |
| 13 | + ConfiguredModuleArchitecture, |
| 14 | + ModelArchitecture, |
| 15 | + ModuleArchitecture, |
| 16 | + ModuleDefinition, |
| 17 | + WeightInfo, |
| 18 | +) |
| 19 | +from mergekit.architecture.json_definitions import NAME_TO_ARCH |
| 20 | +from mergekit.architecture.moe_defs import ( |
| 21 | + Ernie4_5_MoeModuleArchitecture, |
| 22 | + GptOssModuleArchitecture, |
| 23 | + MixtralModuleArchitecture, |
| 24 | + Qwen3MoeModuleArchitecture, |
| 25 | +) |
| 26 | +from mergekit.options import MergeOptions |
| 27 | + |
| 28 | +if TYPE_CHECKING: |
| 29 | + from mergekit.config import MergeConfiguration |
| 30 | + |
| 31 | +LOG = logging.getLogger(__name__) |
| 32 | + |
| 33 | +WARNED_ARCHITECTURE_NAMES = set() |
| 34 | + |
| 35 | + |
| 36 | +def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]: |
| 37 | + if len(config.architectures) != 1: |
| 38 | + raise RuntimeError("More than one architecture in config?") |
| 39 | + arch_name = config.architectures[0] |
| 40 | + |
| 41 | + if arch_name == MixtralModuleArchitecture.ARCHITECTURE_NAME: |
| 42 | + module = MixtralModuleArchitecture.from_config(config) |
| 43 | + return ModelArchitecture( |
| 44 | + modules={"default": ModuleDefinition(architecture=module)}, |
| 45 | + architectures=[arch_name], |
| 46 | + model_type="mixtral", |
| 47 | + ) |
| 48 | + elif arch_name == Qwen3MoeModuleArchitecture.ARCHITECTURE_NAME: |
| 49 | + module = Qwen3MoeModuleArchitecture.from_config(config) |
| 50 | + return ModelArchitecture( |
| 51 | + modules={"default": ModuleDefinition(architecture=module)}, |
| 52 | + architectures=[arch_name], |
| 53 | + model_type="qwen3_moe", |
| 54 | + ) |
| 55 | + elif arch_name == Ernie4_5_MoeModuleArchitecture.ARCHITECTURE_NAME: |
| 56 | + module = Ernie4_5_MoeModuleArchitecture.from_config(config) |
| 57 | + return ModelArchitecture( |
| 58 | + modules={"default": ModuleDefinition(architecture=module)}, |
| 59 | + architectures=[arch_name], |
| 60 | + model_type="ernie4_5_moe", |
| 61 | + ) |
| 62 | + elif arch_name == GptOssModuleArchitecture.ARCHITECTURE_NAME: |
| 63 | + module = GptOssModuleArchitecture.from_config(config) |
| 64 | + return ModelArchitecture( |
| 65 | + modules={"default": ModuleDefinition(architecture=module)}, |
| 66 | + architectures=[arch_name], |
| 67 | + model_type="gpt_oss", |
| 68 | + ) |
| 69 | + elif arch_name in NAME_TO_ARCH: |
| 70 | + candidates = list(NAME_TO_ARCH[arch_name]) |
| 71 | + if len(candidates) == 1: |
| 72 | + return candidates[0] |
| 73 | + |
| 74 | + for c in candidates: |
| 75 | + if c.expected_model_type == config.model_type: |
| 76 | + return c |
| 77 | + LOG.warning( |
| 78 | + f"Multiple architectures for {arch_name}, none match model type {config.model_type}" |
| 79 | + ) |
| 80 | + |
| 81 | + if arch_name not in WARNED_ARCHITECTURE_NAMES: |
| 82 | + LOG.warning(f"No JSON architecture found for {arch_name}") |
| 83 | + WARNED_ARCHITECTURE_NAMES.add(arch_name) |
| 84 | + return None |
| 85 | + |
| 86 | + |
| 87 | +def get_architecture_info( |
| 88 | + config: "MergeConfiguration", options: MergeOptions |
| 89 | +) -> ModelArchitecture: |
| 90 | + models = config.referenced_models() |
| 91 | + if not models: |
| 92 | + raise ValueError("No models referenced in config") |
| 93 | + |
| 94 | + model_arch_info = [ |
| 95 | + arch_info_for_config(m.config(trust_remote_code=options.trust_remote_code)) |
| 96 | + for m in models |
| 97 | + ] |
| 98 | + if all(arch is not None for arch in model_arch_info): |
| 99 | + if not options.allow_crimes and any( |
| 100 | + arch != model_arch_info[0] for arch in model_arch_info |
| 101 | + ): |
| 102 | + raise RuntimeError( |
| 103 | + "Must specify --allow-crimes to attempt to mix different architectures" |
| 104 | + ) |
| 105 | + return model_arch_info[0] |
| 106 | + |
| 107 | + # try to infer from all models |
| 108 | + return infer_architecture_info(tuple(models), config.base_model, options) |
| 109 | + |
| 110 | + |
| 111 | +__all__ = [ |
| 112 | + "ModelArchitecture", |
| 113 | + "ModuleArchitecture", |
| 114 | + "ModuleDefinition", |
| 115 | + "ConfiguredModuleArchitecture", |
| 116 | + "ConfiguredModelArchitecture", |
| 117 | + "WeightInfo", |
| 118 | + "get_architecture_info", |
| 119 | + "arch_info_for_config", |
| 120 | +] |
0 commit comments