13 changes: 13 additions & 0 deletions convert_hf_to_gguf.py
@@ -4182,6 +4182,19 @@ def set_vocab(self):

super().set_vocab()

@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1

def set_gguf_parameters(self):
super().set_gguf_parameters()

# RND1 specific parameters
# RND1 uses bidirectional attention
self.gguf_writer.add_causal_attention(False)

if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)

@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
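On top of the Qwen2MoE converter, RND1 only adds the non-causal attention flag and the optional mask token id. Below is a minimal sketch for sanity-checking a converted file with gguf-py; the key names and the output filename are assumptions, not taken from this diff.

```python
# Sketch: inspect the RND1-specific metadata written by the converter.
# Assumes gguf-py is installed; key names and file name are assumptions.
from gguf import GGUFReader

reader = GGUFReader("RND1-Base-0910.gguf")  # hypothetical converted model
for key in ("rnd1.attention.causal", "tokenizer.ggml.mask_token_id"):
    field = reader.get_field(key)
    value = field.parts[field.data[0]] if field is not None else None
    print(f"{key} = {value}")
```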
49 changes: 47 additions & 2 deletions examples/diffusion/README.md
@@ -6,8 +6,53 @@ More Info:
- https://github.com/ggml-org/llama.cpp/pull/14644
- https://github.com/ggml-org/llama.cpp/pull/14771

## Parameters
The diffusion CLI supports various parameters to control the generation process:

### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
- `0`: ORIGIN - Original algorithm
- `1`: ENTROPY_BASED - Entropy-based selection (recommended)
- `2`: MARGIN_BASED - Margin-based selection
- `3`: RANDOM - Random selection
- `4`: CONFIDENCE_BASED - Confidence-based selection
- `--diffusion-visual`: Enable live visualization during generation

### Scheduling Parameters
Choose one of the following scheduling methods:

**Timestep-based scheduling:**
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)

**Block-based scheduling:**
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)

### Sampling Parameters
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
- `--top-k`: Top-k filtering for sampling
- `--top-p`: Top-p (nucleus) filtering for sampling
- `--seed`: Random seed for reproducibility

### Model Parameters
- `-m`: Path to the GGUF model file
- `-p`: Input prompt text
- `-ub`: Maximum sequence length (ubatch size)
- `-c`: Context size
- `-b`: Batch size

### Examples
#### Dream architecture:
```
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
```

#### LLaDA architecture:
```
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
```

#### RND1 architecture:
```
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
```
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
@@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
APERTUS = auto()
COGVLM = auto()
MINIMAXM2 = auto()
RND1 = auto()
PANGU_EMBED = auto()


@@ -797,6 +798,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.APERTUS: "apertus",
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
}

@@ -2991,6 +2993,23 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.VISEXP_UP,
MODEL_TENSOR.VISEXP_DOWN,
],
MODEL_ARCH.RND1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.PANGU_EMBED: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
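The new architecture entry can be inspected from gguf-py once these constants are in place; a small sketch, assuming the updated package is importable as `gguf`:

```python
# Sketch: list the tensor names registered for the RND1 architecture.
import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.RND1])  # "rnd1"
for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.RND1]:
    print(gguf.TENSOR_NAMES[tensor])  # e.g. "blk.{bid}.ffn_gate_exps"
```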
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -115,6 +115,7 @@ add_library(llama
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp
models/rwkv6.cpp
models/rwkv6qwen2.cpp
22 changes: 22 additions & 0 deletions src/llama-arch.cpp
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -2446,6 +2447,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
},
},
{
LLM_ARCH_RND1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -2722,6 +2743,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_RND1:
return true;
default:
return false;
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -112,6 +112,7 @@ enum llm_arch {
LLM_ARCH_APERTUS,
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_UNKNOWN,
};
22 changes: 21 additions & 1 deletion src/llama-model.cpp
@@ -1036,6 +1036,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_RND1:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 48: type = LLM_TYPE_30B_A3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
// Set non-causal attention for diffusion models
hparams.causal_attn = false;
} break;
case LLM_ARCH_QWEN2MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
@@ -3402,6 +3414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_RND1:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -6720,7 +6733,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}

if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) {
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}

@@ -6882,6 +6895,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_RND1:
{
res = nullptr;
} break;
@@ -7075,6 +7089,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_llada_moe>(*this, params);
}
break;
case LLM_ARCH_RND1:
{
llm = std::make_unique<llm_build_rnd1>(*this, params);
}
break;
case LLM_ARCH_QWEN2VL:
{
llm = std::make_unique<llm_build_qwen2vl>(*this, params);
@@ -7595,6 +7614,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_RND1:
case LLM_ARCH_OLMO2:
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2:
4 changes: 4 additions & 0 deletions src/models/models.h
@@ -431,6 +431,10 @@ struct llm_build_refact : public llm_graph_context {
llm_build_refact(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_rnd1 : public llm_graph_context {
llm_build_rnd1(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_rwkv6 : public llm_build_rwkv6_base {
llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
};
125 changes: 125 additions & 0 deletions src/models/rnd1.cpp
@@ -0,0 +1,125 @@
#include "models.h"

llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

// Non-causal attention for diffusion
auto * inp_attn = build_attn_inp_no_cache();

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self_attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
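// on the last layer, keep only the rows needed for the requested outputs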
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
cur = moe_out;

cur = ggml_add(ctx0, cur, ffn_inp);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}
cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}