Skip to content

Commit 03f1f30

Browse files
authored
Add files via upload
1 parent d8c4bc0 commit 03f1f30

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

__init__.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (C) 2025 Arcee AI
2+
# SPDX-License-Identifier: BUSL-1.1
3+
4+
import logging
5+
from functools import lru_cache
6+
from typing import TYPE_CHECKING, Optional
7+
8+
from transformers import PretrainedConfig
9+
10+
from mergekit.architecture.auto import infer_architecture_info
11+
from mergekit.architecture.base import (
12+
ConfiguredModelArchitecture,
13+
ConfiguredModuleArchitecture,
14+
ModelArchitecture,
15+
ModuleArchitecture,
16+
ModuleDefinition,
17+
WeightInfo,
18+
)
19+
from mergekit.architecture.json_definitions import NAME_TO_ARCH
20+
from mergekit.architecture.moe_defs import (
21+
Ernie4_5_MoeModuleArchitecture,
22+
GptOssModuleArchitecture,
23+
MixtralModuleArchitecture,
24+
Qwen3MoeModuleArchitecture,
25+
)
26+
from mergekit.options import MergeOptions
27+
28+
if TYPE_CHECKING:
29+
from mergekit.config import MergeConfiguration
30+
31+
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
32+
33+
# Architecture names that arch_info_for_config() has already warned about,
# so the "No JSON architecture found" message is logged only once per name.
WARNED_ARCHITECTURE_NAMES = set()
34+
35+
36+
def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]:
    """Look up the known architecture definition for a model config.

    The config must name exactly one architecture. That name is first checked
    against the MoE module architectures imported from
    ``mergekit.architecture.moe_defs``, then against the ``NAME_TO_ARCH``
    table. When ``NAME_TO_ARCH`` holds several candidates for the name, the
    one whose ``expected_model_type`` equals ``config.model_type`` is chosen.

    Args:
        config: The model configuration to resolve.

    Returns:
        The matching ``ModelArchitecture``, or ``None`` when no definition is
        found or when several candidates exist and none matches the config's
        model type.

    Raises:
        RuntimeError: If ``config.architectures`` is missing, empty, or lists
            more than one architecture.
    """
    architectures = config.architectures
    if not architectures:
        # ``architectures`` may be unset (None) or empty; report that clearly
        # instead of failing inside ``len()`` or blaming "more than one".
        raise RuntimeError("No architecture listed in config")
    if len(architectures) != 1:
        raise RuntimeError("More than one architecture in config?")
    arch_name = architectures[0]

    # MoE module architectures paired with the model_type recorded on the
    # resulting ModelArchitecture. Checked in order, before NAME_TO_ARCH.
    moe_module_architectures = (
        (MixtralModuleArchitecture, "mixtral"),
        (Qwen3MoeModuleArchitecture, "qwen3_moe"),
        (Ernie4_5_MoeModuleArchitecture, "ernie4_5_moe"),
        (GptOssModuleArchitecture, "gpt_oss"),
    )
    for module_cls, model_type in moe_module_architectures:
        if arch_name == module_cls.ARCHITECTURE_NAME:
            module = module_cls.from_config(config)
            return ModelArchitecture(
                modules={"default": ModuleDefinition(architecture=module)},
                architectures=[arch_name],
                model_type=model_type,
            )

    if arch_name in NAME_TO_ARCH:
        candidates = list(NAME_TO_ARCH[arch_name])
        if len(candidates) == 1:
            return candidates[0]

        for candidate in candidates:
            if candidate.expected_model_type == config.model_type:
                return candidate

        LOG.warning(
            "Multiple architectures for %s, none match model type %s",
            arch_name,
            config.model_type,
        )
        # Definitions do exist for this name, so the generic "No JSON
        # architecture found" message below would be wrong here. Still mark
        # the name as reported so the module-level set ends up the same.
        WARNED_ARCHITECTURE_NAMES.add(arch_name)
        return None

    # Warn only once per unknown architecture name.
    if arch_name not in WARNED_ARCHITECTURE_NAMES:
        LOG.warning("No JSON architecture found for %s", arch_name)
        WARNED_ARCHITECTURE_NAMES.add(arch_name)
    return None
85+
86+
87+
def get_architecture_info(
    config: "MergeConfiguration", options: MergeOptions
) -> ModelArchitecture:
    """Determine the model architecture to use for a merge configuration.

    Each model referenced by ``config`` is resolved with
    ``arch_info_for_config``. If every model resolves, the first model's
    architecture is returned; otherwise the result of
    ``infer_architecture_info`` over all referenced models is returned.

    Args:
        config: The merge configuration whose referenced models are examined.
        options: Merge options; ``trust_remote_code`` is forwarded when
            loading each model's config and ``allow_crimes`` permits mixing
            differing architectures.

    Returns:
        The resolved or inferred ``ModelArchitecture``.

    Raises:
        ValueError: If the configuration references no models.
        RuntimeError: If the resolved architectures differ and
            ``options.allow_crimes`` is not set.
    """
    referenced = config.referenced_models()
    if not referenced:
        raise ValueError("No models referenced in config")

    resolved = []
    for model in referenced:
        model_config = model.config(trust_remote_code=options.trust_remote_code)
        resolved.append(arch_info_for_config(model_config))

    if any(info is None for info in resolved):
        # try to infer from all models
        return infer_architecture_info(tuple(referenced), config.base_model, options)

    reference = resolved[0]
    if not options.allow_crimes:
        for info in resolved:
            if info != reference:
                raise RuntimeError(
                    "Must specify --allow-crimes to attempt to mix different architectures"
                )
    return reference
109+
110+
111+
# Public API of this package.
__all__ = [
    # Types imported from mergekit.architecture.base
    "ModelArchitecture",
    "ModuleArchitecture",
    "ModuleDefinition",
    "ConfiguredModuleArchitecture",
    "ConfiguredModelArchitecture",
    "WeightInfo",
    # Lookup helpers defined in this module
    "get_architecture_info",
    "arch_info_for_config",
]

0 commit comments

Comments
 (0)