@@ -120,6 +120,7 @@ const char * llm_type_name(llm_type type) {
120120 case LLM_TYPE_30B_A3B: return "30B.A3B";
121121 case LLM_TYPE_100B_A6B: return "100B.A6B";
122122 case LLM_TYPE_106B_A12B: return "106B.A12B";
123+ case LLM_TYPE_230B_A10B: return "230B.A10B";
123124 case LLM_TYPE_235B_A22B: return "235B.A22B";
124125 case LLM_TYPE_300B_A47B: return "300B.A47B";
125126 case LLM_TYPE_355B_A32B: return "355B.A32B";
@@ -2155,6 +2156,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
21552156 default: type = LLM_TYPE_UNKNOWN;
21562157 }
21572158 } break;
2159+ case LLM_ARCH_MINIMAX_M2:
2160+ {
2161+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2162+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
2163+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
2164+
2165+ switch (hparams.n_layer) {
2166+ case 62: type = LLM_TYPE_230B_A10B; break;
2167+ default: type = LLM_TYPE_UNKNOWN;
2168+ }
2169+ } break;
21582170 case LLM_ARCH_COGVLM:
21592171 {
21602172 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -6185,6 +6197,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
61856197 layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
61866198 }
61876199 } break;
6200+ case LLM_ARCH_MINIMAX_M2:
6201+ {
6202+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
6203+
6204+ // output
6205+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
6206+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
6207+
6208+ for (int i = 0; i < n_layer; ++i) {
6209+ auto & layer = layers[i];
6210+
6211+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
6212+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
6213+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
6214+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
6215+
6216+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
6217+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0);
6218+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0);
6219+
6220+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
6221+
6222+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
6223+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
6224+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
6225+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
6226+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
6227+ }
6228+ } break;
61886229 case LLM_ARCH_COGVLM:
61896230 {
61906231 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -20024,6 +20065,130 @@ struct llm_build_apertus : public llm_graph_context {
2002420065 }
2002520066};
2002620067
20068+ struct llm_build_minimax_m2 : public llm_graph_context {
20069+ llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
20070+ const int64_t n_embd_head = hparams.n_embd_head_v;
20071+
20072+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
20073+ // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
20074+
20075+ ggml_tensor * cur;
20076+ ggml_tensor * inpL;
20077+
20078+ inpL = build_inp_embd(model.tok_embd);
20079+
20080+ ggml_tensor * inp_pos = build_inp_pos();
20081+ auto inp_attn = build_attn_inp_kv();
20082+ ggml_tensor * inp_out_ids = build_inp_out_ids();
20083+
20084+ for (int il = 0; il < n_layer; ++il) {
20085+ ggml_tensor * inpSA = inpL;
20086+
20087+ cur = inpL;
20088+
20089+ // self_attention
20090+ {
20091+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
20092+ cb(cur, "attn_norm", il);
20093+
20094+ // compute Q and K and RoPE them
20095+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
20096+ cb(Qcur, "Qcur", il);
20097+
20098+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
20099+ cb(Kcur, "Kcur", il);
20100+
20101+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
20102+ cb(Vcur, "Vcur", il);
20103+
20104+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
20105+ LLM_NORM_RMS, il);
20106+ cb(Qcur, "Qcur_normed", il);
20107+
20108+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
20109+ LLM_NORM_RMS, il);
20110+ cb(Kcur, "Kcur_normed", il);
20111+
20112+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
20113+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
20114+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
20115+
20116+ Qcur = ggml_rope_ext(
20117+ ctx0, Qcur, inp_pos, nullptr,
20118+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
20119+ ext_factor, attn_factor, beta_fast, beta_slow
20120+ );
20121+
20122+ Kcur = ggml_rope_ext(
20123+ ctx0, Kcur, inp_pos, nullptr,
20124+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
20125+ ext_factor, attn_factor, beta_fast, beta_slow
20126+ );
20127+
20128+ cb(Qcur, "Qcur", il);
20129+ cb(Kcur, "Kcur", il);
20130+ cb(Vcur, "Vcur", il);
20131+
20132+ cur = build_attn(inp_attn,
20133+ model.layers[il].wo, NULL,
20134+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
20135+ }
20136+
20137+ if (il == n_layer - 1 && inp_out_ids) {
20138+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
20139+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
20140+ }
20141+
20142+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
20143+ cb(ffn_inp, "ffn_inp", il);
20144+
20145+ // MoE branch
20146+ cur = build_norm(ffn_inp,
20147+ model.layers[il].ffn_norm, NULL,
20148+ LLM_NORM_RMS, il);
20149+ cb(cur, "ffn_norm", il);
20150+
20151+ cur = build_moe_ffn(cur,
20152+ model.layers[il].ffn_gate_inp,
20153+ model.layers[il].ffn_up_exps,
20154+ model.layers[il].ffn_gate_exps,
20155+ model.layers[il].ffn_down_exps,
20156+ model.layers[il].ffn_exp_probs_b,
20157+ n_expert, n_expert_used,
20158+ LLM_FFN_SILU, true,
20159+ false, 0.0,
20160+ (llama_expert_gating_func_type) hparams.expert_gating_func,
20161+ il);
20162+ cb(cur, "ffn_moe_out", il);
20163+
20164+ cur = ggml_add(ctx0, cur, ffn_inp);
20165+
20166+ cur = build_cvec(cur, il);
20167+ cb(cur, "l_out", il);
20168+
20169+ // input for next layer
20170+ inpL = cur;
20171+ }
20172+
20173+ cur = inpL;
20174+
20175+ cur = build_norm(cur,
20176+ model.output_norm, NULL,
20177+ LLM_NORM_RMS, -1);
20178+
20179+ cb(cur, "result_norm", -1);
20180+ res->t_embd = cur;
20181+
20182+ // lm_head
20183+ cur = build_lora_mm(model.output, cur);
20184+
20185+ cb(cur, "result_output", -1);
20186+ res->t_logits = cur;
20187+
20188+ ggml_build_forward_expand(gf, cur);
20189+ }
20190+ };
20191+
2002720192struct llm_build_cogvlm : public llm_graph_context {
2002820193 llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
2002920194 const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -20654,6 +20819,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
2065420819 {
2065520820 llm = std::make_unique<llm_build_apertus>(*this, params);
2065620821 } break;
20822+ case LLM_ARCH_MINIMAX_M2:
20823+ {
20824+ llm = std::make_unique<llm_build_minimax_m2>(*this, params);
20825+ } break;
2065720826 case LLM_ARCH_COGVLM:
2065820827 {
2065920828 llm = std::make_unique<llm_build_cogvlm>(*this, params);
@@ -20875,6 +21044,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
2087521044 case LLM_ARCH_SEED_OSS:
2087621045 case LLM_ARCH_GROVEMOE:
2087721046 case LLM_ARCH_APERTUS:
21047+ case LLM_ARCH_MINIMAX_M2:
2087821048 case LLM_ARCH_COGVLM:
2087921049 return LLAMA_ROPE_TYPE_NEOX;
2088021050
0 commit comments