From cfee32df7237e24558b27640a2f5ab24a378db98 Mon Sep 17 00:00:00 2001 From: William Pan Date: Tue, 4 Nov 2025 20:49:40 -0800 Subject: [PATCH 1/9] Converted RND1 model to GGUF weights --- convert_hf_to_gguf.py | 29 +++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 19 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c6f5ba6a04c..75baacf2322 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4042,6 +4042,33 @@ def set_vocab(self): super().set_vocab() +@ModelBase.register("RND1") +class RND1Model(Qwen3MoeModel): + model_arch = gguf.MODEL_ARCH.RND1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + hparams = ModelBase.load_hparams(self.dir_model, False) + self.origin_hf_arch = hparams.get('architectures', [None])[0] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # RND1 specific parameters + # RND1 uses bidirectional attention + self.gguf_writer.add_causal_attention(False) + + mask_token_id = self.hparams.get("mask_token_id") + if mask_token_id is not None: + self.gguf_writer.add_mask_token_id(mask_token_id) + + def set_vocab(self): + # deal with intern-s1 + if self.origin_hf_arch == 'InternS1ForConditionalGeneration': + self._set_vocab_interns1() + return + + super().set_vocab() @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): @@ -10192,6 +10219,8 @@ def main() -> None: model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT hparams = ModelBase.load_hparams(dir_model, is_mistral_format) if not is_mistral_format: + print(hparams) + print(model_type) model_architecture = get_model_architecture(hparams, model_type) logger.info(f"Model architecture: {model_architecture}") try: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 77e3b0650ff..134c1875e3a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -426,6 +426,7 @@ class MODEL_ARCH(IntEnum): APERTUS = auto() COGVLM = auto() MINIMAXM2 = auto() + RND1 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -793,6 +794,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.APERTUS: "apertus", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", + MODEL_ARCH.RND1: "rnd1", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2958,6 +2960,23 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.VISEXP_UP, MODEL_TENSOR.VISEXP_DOWN, ], + MODEL_ARCH.RND1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } From 15d938c500664d9b9d15dddb33e0d1b858ab743c Mon Sep 17 00:00:00 2001 From: William Pan Date: Wed, 5 Nov 2025 09:29:29 -0800 Subject: [PATCH 2/9] RND1 llama.cpp support v1 --- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 22 ++++++++ src/llama-arch.h | 1 + src/llama-model.cpp | 20 ++++++- src/models/models.h | 4 ++ src/models/rnd1.cpp | 125 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 src/models/rnd1.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 832b58e315d..e438bc3d574 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ 
-113,6 +113,7 @@ add_library(llama models/qwen3vl-moe.cpp models/qwen3moe.cpp models/refact.cpp + models/rnd1.cpp models/rwkv6-base.cpp models/rwkv6.cpp models/rwkv6qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7c7953b83dd..16763d9c5d5 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -107,6 +107,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, + { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2397,6 +2398,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, }, }, + { + LLM_ARCH_RND1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2672,6 +2693,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 3f893a2dc69..15f3eea7f71 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -111,6 +111,7 @@ enum llm_arch { LLM_ARCH_APERTUS, LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, + LLM_ARCH_RND1, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 896725466ce..438885ecbf7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1004,6 +1004,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_RND1: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -3370,6 +3380,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_RND1: { // for model loading, the weights only have the main embd // so we need to divide by the number of deepstack layers + 1 @@ -6581,7 +6592,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -6743,6 +6754,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { res = nullptr; } break; @@ -6936,6 +6948,11 @@ ggml_cgraph * llama_model::build_graph(const 
llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_RND1: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -7444,6 +7461,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: diff --git a/src/models/models.h b/src/models/models.h index af203343a4d..534bcda2e07 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -145,6 +145,10 @@ struct llm_build_dream : public llm_graph_context { llm_build_dream(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_rnd1 : public llm_graph_context { + llm_build_rnd1(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_ernie4_5 : public llm_graph_context { llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp new file mode 100644 index 00000000000..d1854583c40 --- /dev/null +++ b/src/models/rnd1.cpp @@ -0,0 +1,125 @@ +#include "models.h" + +llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Non-causal attention for diffusion + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + 
cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} From 0911acfacfa8e7e180f96a8085b97229a07f8ae7 Mon Sep 17 00:00:00 2001 From: William Pan Date: Thu, 6 Nov 2025 19:38:10 +0000 Subject: [PATCH 3/9] RND1 llama.cpp support v2 non causal bug --- src/llama-model.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 438885ecbf7..0f7d3a9062e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1013,6 +1013,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; } break; case LLM_ARCH_QWEN2MOE: { From 5f36f0a38dbfbf883181e7f4592512c2b4f61270 Mon Sep 17 00:00:00 2001 From: William Pan Date: Fri, 7 Nov 2025 04:37:46 +0000 Subject: [PATCH 4/9] RND1 llama.cpp support v3 doccumentation --- examples/diffusion/README.md | 49 ++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index 26de5668aa8..f0ca4558d98 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -6,8 +6,53 @@ More Info: - https://github.com/ggml-org/llama.cpp/pull/14644 - https://github.com/ggml-org/llama.cpp/pull/14771 +## Parameters +The diffusion CLI supports various parameters to control the generation process: -Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual` +### Core Diffusion Parameters +- `--diffusion-steps`: Number of diffusion steps (default: 256) +- `--diffusion-algorithm`: Algorithm for token selection + - `0`: ORIGIN - Original algorithm + - `1`: ENTROPY_BASED - Entropy-based selection (recommended) + - `2`: MARGIN_BASED - Margin-based selection + - `3`: RANDOM - Random selection + - `4`: CONFIDENCE_BASED - Confidence-based selection +- `--diffusion-visual`: Enable live visualization during generation -Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual` +### Scheduling Parameters +Choose one of the following scheduling methods: +**Timestep-based scheduling:** +- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001) + +**Block-based scheduling:** +- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32) + +### Sampling Parameters +- `--temp`: 
Temperature for sampling (0.0 = greedy/deterministic, higher = more random) +- `--top-k`: Top-k filtering for sampling +- `--top-p`: Top-p (nucleus) filtering for sampling +- `--seed`: Random seed for reproducibility + +### Model Parameters +- `-m`: Path to the GGUF model file +- `-p`: Input prompt text +- `-ub`: Maximum sequence length (ubatch size) +- `-c`: Context size +- `-b`: Batch size + +### Examples +#### Dream architechture: +``` +llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual +``` + +#### LLaDA architechture: +``` +llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual +``` + +#### RND1 architecture: +``` +llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 +``` \ No newline at end of file From d960ace571ca3c056a5e4ba557efd7f293a02bf9 Mon Sep 17 00:00:00 2001 From: William Pan Date: Fri, 7 Nov 2025 04:48:05 +0000 Subject: [PATCH 5/9] RND1 llama.cpp support v4 clean code --- convert_hf_to_gguf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 75baacf2322..19d7bfeaadd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10219,8 +10219,6 @@ def main() -> None: model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT hparams = ModelBase.load_hparams(dir_model, is_mistral_format) if not is_mistral_format: - print(hparams) - print(model_type) model_architecture = get_model_architecture(hparams, model_type) logger.info(f"Model architecture: {model_architecture}") try: From e02174b7176023e428f6891f140783da0e4a85cf Mon Sep 17 00:00:00 2001 From: William Pan Date: Fri, 21 Nov 2025 10:43:00 -0800 Subject: [PATCH 6/9] linting issues --- convert_hf_to_gguf.py | 2 +- examples/diffusion/README.md | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 64493724a67..ee7a99bf927 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4197,7 +4197,7 @@ def set_gguf_parameters(self): # RND1 specific parameters # RND1 uses bidirectional attention self.gguf_writer.add_causal_attention(False) - + mask_token_id = self.hparams.get("mask_token_id") if mask_token_id is not None: self.gguf_writer.add_mask_token_id(mask_token_id) diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index f0ca4558d98..dafc3580600 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -42,17 +42,17 @@ Choose one of the following scheduling methods: - `-b`: Batch size ### Examples -#### Dream architechture: +#### Dream architechture: ``` llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual ``` -#### LLaDA architechture: +#### LLaDA architechture: ``` llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual ``` -#### RND1 architecture: +#### RND1 architecture: ``` llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 
-``` \ No newline at end of file +``` From 53a517b8b4ada75622a87c044a347af2b9ca24a3 Mon Sep 17 00:00:00 2001 From: William Pan Date: Fri, 21 Nov 2025 15:13:35 -0800 Subject: [PATCH 7/9] RND1 pr fixes v1 --- convert_hf_to_gguf.py | 18 ++---------------- src/models/models.h | 8 ++++---- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ee7a99bf927..2122d686291 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4183,14 +4183,9 @@ def set_vocab(self): super().set_vocab() @ModelBase.register("RND1") -class RND1Model(Qwen3MoeModel): +class RND1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.RND1 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - hparams = ModelBase.load_hparams(self.dir_model, False) - self.origin_hf_arch = hparams.get('architectures', [None])[0] - def set_gguf_parameters(self): super().set_gguf_parameters() @@ -4198,18 +4193,9 @@ def set_gguf_parameters(self): # RND1 uses bidirectional attention self.gguf_writer.add_causal_attention(False) - mask_token_id = self.hparams.get("mask_token_id") - if mask_token_id is not None: + if (mask_token_id := self.hparams.get("mask_token_id")) is not None: self.gguf_writer.add_mask_token_id(mask_token_id) - def set_vocab(self): - # deal with intern-s1 - if self.origin_hf_arch == 'InternS1ForConditionalGeneration': - self._set_vocab_interns1() - return - - super().set_vocab() - @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): diff --git a/src/models/models.h b/src/models/models.h index 97f2cd9909b..5f019c59be8 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -149,10 +149,6 @@ struct llm_build_dream : public llm_graph_context { llm_build_dream(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_rnd1 : public llm_graph_context { - llm_build_rnd1(const llama_model & model, const llm_graph_params & params); -}; - struct llm_build_ernie4_5 : public llm_graph_context { llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); }; @@ -435,6 +431,10 @@ struct llm_build_refact : public llm_graph_context { llm_build_refact(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_rnd1 : public llm_graph_context { + llm_build_rnd1(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_rwkv6 : public llm_build_rwkv6_base { llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); }; From a877fe34091f65073d1b66960d87611324cdd665 Mon Sep 17 00:00:00 2001 From: william pan <61359596+wp4032@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:37:56 -0800 Subject: [PATCH 8/9] RND1 pr fixes v2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2122d686291..6cbaee03dfd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4182,6 +4182,7 @@ def set_vocab(self): super().set_vocab() + @ModelBase.register("RND1") class RND1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.RND1 @@ -4196,6 +4197,7 @@ def set_gguf_parameters(self): if (mask_token_id := self.hparams.get("mask_token_id")) is not None: self.gguf_writer.add_mask_token_id(mask_token_id) + 
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): From bf6d002095c11731a544f544e7409162444f4026 Mon Sep 17 00:00:00 2001 From: William Pan Date: Sat, 22 Nov 2025 08:14:31 -0800 Subject: [PATCH 9/9] Diffusion documentation edits --- examples/diffusion/README.md | 7 ++++--- src/models/rnd1.cpp | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index dafc3580600..f71d2413193 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -12,11 +12,12 @@ The diffusion CLI supports various parameters to control the generation process: ### Core Diffusion Parameters - `--diffusion-steps`: Number of diffusion steps (default: 256) - `--diffusion-algorithm`: Algorithm for token selection - - `0`: ORIGIN - Original algorithm - - `1`: ENTROPY_BASED - Entropy-based selection (recommended) + - `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006. + - `1`: ENTROPY_BASED - Entropy-based selection - `2`: MARGIN_BASED - Margin-based selection - `3`: RANDOM - Random selection - - `4`: CONFIDENCE_BASED - Confidence-based selection + - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - More documentation here https://github.com/DreamLM/Dream - `--diffusion-visual`: Enable live visualization during generation ### Scheduling Parameters diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp index d1854583c40..46b3dc3efca 100644 --- a/src/models/rnd1.cpp +++ b/src/models/rnd1.cpp @@ -1,5 +1,6 @@ #include "models.h" +// RND1 is a Qwen3Moe AR model converted to diffusion model. llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v;