Guidance for any agentic coding model working in this repo. Reflects the v2 overhaul
(branch feature/dimba-v2-overhaul, PR #18). For deeper detail see
docs/OVERHAUL_STATUS.md, docs/IMPROVEMENT_PLAN.md, and docs/RESEARCH_DIRECTIONS.md.
DIMBA is a non-autoregressive latent-diffusion language model: continuous Gaussian
diffusion runs in a learned latent space (VAE/projector over token embeddings; raw-embedding
diffusion is the degenerate latent_diffusion=False case), denoised by a bidirectional
Mamba backbone, generating whole sequences in parallel by iterative denoising.
- v1 = `paper/main.pdf`: an architectural concept (explicitly untested). Do not treat it as ground truth: it contains a prompt-conditioning leak (`C = PromptEncoder(X₀)`) and an MSE-only objective that the v2 code deliberately fixes.
- v2 = this repo: the implementation plus the overhaul below. This file describes v2.
`import torch` segfaults at interpreter teardown on the original dev box (Windows), and the bare `python` on PATH is a WindowsApps shim that hangs. Use the project venv: `venv\Scripts\python.exe` (Windows) / `venv/bin/python` (mac). For scripts that import torch, end with `os._exit(0)` after flushing to dodge the teardown crash.
- Validate without running torch: `python -m compileall src/dimba scripts tests` (syntax).
- Runtime smoke: `venv/bin/python .sisyphus/smoke_full.py` (end-to-end, uses `os._exit`).
- CI (`.github/workflows/ci.yml`) runs the real `pytest` suite on clean Linux runners (py3.10/3.12) with working torch; that is the source of truth for runtime tests.
- Apple Silicon (M1/M2/M3) training: use the PyTorch MPS path (the `backends/mlx/` port is a skeleton, not training-ready). Set `PYTORCH_ENABLE_MPS_FALLBACK=1`, use fp32, and `latent_diffusion=True`. `SimpleMamba2` uses the vectorized scan on MPS; keep `seq_len` ≤ ~256. See the small-model recipe in `docs/OVERHAUL_STATUS.md`.
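The flush-then-`os._exit(0)` workaround described above can be sketched as follows. This is a minimal illustration, not code from the repo: `flush_and_exit` and its `exit_fn` parameter are hypothetical names, and the injectable exit function exists only so the helper can be exercised without killing the interpreter.

```python
import os
import sys


def flush_and_exit(code=0, exit_fn=os._exit):
    """Flush stdio, then exit without running interpreter teardown.

    Hypothetical helper (not part of the repo). os._exit skips atexit
    handlers and does not flush buffered streams, so flush explicitly first.
    """
    for stream in (sys.stdout, sys.stderr):
        if stream is None:
            continue  # e.g. a detached console
        try:
            stream.flush()
        except (ValueError, OSError):
            pass  # stream already closed; nothing left to flush
    exit_fn(code)


def main():
    # import torch  # a real script would import and use torch here
    print("work done")
    return 0


# At the very end of a torch-importing script:
#     flush_and_exit(main())
```

Because `os._exit` also skips `finally` blocks and `atexit` handlers, any files or checkpoints should be closed before the helper is called.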
    input_ids ─► token_embed ─► encode_latent (×latent_scale → ~unit variance)
                                          │
    prompt (pooled, response-free) ─┐  add_noise (cosine, zero-terminal-SNR)
    timestep_embed τ(t) ────────────┤     │   (only response positions if prompt_mask)
                                    ▼     ▼
              Mamba denoiser (N bidirectional blocks, FiLM/additive cond)
                                       │   [+ self-conditioning: prev x̂₀ fused in]
                                       ▼
                        raw pred → x̂₀ latent (x0 or v param)
                                       ▼
        decode_latent (÷latent_scale → embedding) ─► output_head ─► logits
Key points that differ from v1 and must not be reintroduced as bugs:
- No conditioning leak. Conditioning is the prompt only: a pooled prompt summary, and (when `prompt_mask` is given) the prompt tokens kept clean in-sequence while only the response is noised, with loss on the response. Never condition on the clean target.
- `forward()` always returns the 3-tuple `(x_pred, noise, latent_info)`.
- `encode_latent` / `decode_latent` carry `latent_scale` and round-trip exactly. Anything that diffuses or samples must go through them (signal must be ~unit variance for a calibrated SNR). Call `model.calibrate_latent_scale(batch)` before training in latent mode.
- The model stores its full constructor config in `model.config` (used to build EMA/replicas).
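The "prompt kept clean, response noised, loss on the response" rule can be illustrated with a toy sketch on scalar latents. The function names are hypothetical, and the mask convention (`True` marks a prompt position) is an assumption; the repo's `prompt_mask` may use a different convention or shape.

```python
import math
import random


def noise_response_only(x0, prompt_mask, alpha_bar, rng):
    """Return (x_t, noise): prompt positions stay clean, response positions are noised.

    Assumes prompt_mask[i] is True for prompt tokens (hypothetical convention).
    """
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [x if is_prompt else a * x + s * n
           for x, n, is_prompt in zip(x0, noise, prompt_mask)]
    return x_t, noise


def response_mse(x_pred, x0, prompt_mask):
    """Mean squared error over response positions only; prompt errors are ignored."""
    errs = [(p - x) ** 2
            for p, x, is_prompt in zip(x_pred, x0, prompt_mask) if not is_prompt]
    return sum(errs) / max(len(errs), 1)


rng = random.Random(0)
x0 = [0.5, -1.0, 0.25, 2.0]
prompt_mask = [True, True, False, False]
x_t, _ = noise_response_only(x0, prompt_mask, alpha_bar=0.5, rng=rng)
```

Here the first two entries of `x_t` equal those of `x0`, so the denoiser sees the prompt verbatim but never sees the clean response.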
`DIMBA(...)` notable kwargs: `latent_diffusion`, `d_latent`, `use_vae_latent`,
`bidirectional=True`, `self_conditioning=False`, `prediction_type="x0"|"v"`,
`zero_terminal_snr=True`, `embed_init_std=0.02`, `latent_scale=None` (auto: `1/embed_init_std`
for embedding mode; `1.0` for latent mode, which should then be calibrated).
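As a rough illustration of why the scale matters, the sketch below reproduces the embedding-mode default (`1/embed_init_std`) and a data-driven calibration to a target standard deviation. The helper names are hypothetical and the real `calibrate_latent_scale` may compute its statistic differently; only the idea "scale so the diffused signal has roughly unit variance" comes from the text above.

```python
import random
import statistics


def default_embedding_scale(embed_init_std=0.02):
    """Embedding-mode default described above: 1 / embed_init_std."""
    return 1.0 / embed_init_std


def calibrate_scale(latents, target_std=1.0):
    """Pick a scale so the scaled latents have standard deviation target_std."""
    return target_std / statistics.pstdev(latents)


def encode(x, scale):
    """Multiply by the scale before diffusing."""
    return [v * scale for v in x]


def decode(z, scale):
    """Divide by the same scale after denoising."""
    return [v / scale for v in z]


rng = random.Random(0)
raw = [rng.gauss(0.0, 0.02) for _ in range(10_000)]  # toy embeddings at init std 0.02
scale = calibrate_scale(raw)
z = encode(raw, scale)
```

With an init std of 0.02 the default scale is 50, and the calibrated scale on freshly initialised toy embeddings lands close to that value.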
- `forward(input_ids, t, noise=None, prompt_mask=None, x_self_cond=None, drop_cond=False)`
- `predict_token_logits(input_ids, t)` → `[B,L,vocab]` (the discrete/masked track)
- `denoise_to_x0_latent(x_t, t, cond, x_self_cond=None)` / `denoise_step(...)` (inference)
- `conditioning_from_prompt(prompt_ids=None, batch_size, device, drop_cond=False)` → `[B,1,cond_dim]`
- `encode_latent` / `decode_latent` / `calibrate_latent_scale(batch, target_std=1.0)`

- Continuous latent (default): `GaussianEmbeddingCorruption`; the `forward()` path above.
- Discrete / masked (LLaDA/MDLM): `diffusion/corruption.py:AbsorbingMaskCorruption` + `diffusion/masked_sampling.py:masked_diffusion_sample(predict_logits, ...)`; model side is `predict_token_logits`. Needs a `[MASK]` token id (not in the tokenizer yet, so pass it explicitly).
- Hybrid (novel, experimental): `HybridCorruption` interpolates masked ↔ Gaussian per token.
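For the discrete track, absorbing-mask corruption can be sketched as independent per-token masking. This is a toy version under stated assumptions: the function name is hypothetical, the masking probability is taken to equal `t` (a linear schedule, as commonly used in LLaDA/MDLM-style models), and the mask id `99` in the usage lines is a placeholder, since the source notes the tokenizer has no `[MASK]` token yet.

```python
import random


def absorbing_mask_corrupt(token_ids, t, mask_id, rng):
    """Replace each token with mask_id independently with probability t in [0, 1].

    Returns (corrupted_ids, was_masked). Assumes mask probability == t.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    was_masked = [rng.random() < t for _ in token_ids]
    corrupted = [mask_id if m else tok for tok, m in zip(token_ids, was_masked)]
    return corrupted, was_masked


rng = random.Random(0)
ids = [5, 6, 7, 8]
clean, _ = absorbing_mask_corrupt(ids, 0.0, mask_id=99, rng=rng)   # nothing masked
masked, _ = absorbing_mask_corrupt(ids, 1.0, mask_id=99, rng=rng)  # everything masked
```

The returned `was_masked` flags are what a masked-diffusion loss would typically use to restrict cross-entropy to corrupted positions.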
Use `compute_dimba_losses(model, input_ids, t, *, ce_loss_weight=1.0, min_snr_gamma=5.0, prompt_mask=None)` → `(loss, parts)`. It combines:
- min-SNR-γ-weighted diffusion regression in latent space (x0 or v target),
- a cross-entropy / rounding anchor (trains the head/decoder, ties to real tokens),
- latent autoencoder consistency + optional VAE KL (latent mode).
`DIMBALightningModule` and `SimpleTrainer` both call it; both accept `ce_loss_weight` /
`min_snr_gamma`. The CDLM consistency loss (`compute_consistency_loss`) is de-leaked (it runs with null conditioning).
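The min-SNR-γ weighting mentioned above can be sketched per timestep as follows. This follows the commonly published form of the weights for x0 and v targets; whether `compute_dimba_losses` uses exactly these expressions (or an extra normalisation) is an assumption, and `min_snr_weight` is a hypothetical name.

```python
def min_snr_weight(snr, gamma=5.0, prediction_type="x0"):
    """Per-timestep loss weight with the SNR clipped at gamma.

    x0 target: min(snr, gamma)
    v target:  min(snr, gamma) / (snr + 1)
    """
    if snr < 0.0:
        raise ValueError("snr must be non-negative")
    clipped = min(snr, gamma)
    if prediction_type == "x0":
        return clipped
    if prediction_type == "v":
        return clipped / (snr + 1.0)
    raise ValueError(f"unknown prediction_type: {prediction_type!r}")
```

Clipping keeps low-noise timesteps (very high SNR) from dominating the regression loss, and with a zero-terminal-SNR schedule the final timestep receives weight 0 under both targets.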
Schedule helpers: `CosineNoiseSchedule(num_steps, zero_terminal_snr=True)` with `.add_noise`,
`.velocity`, `.predict_x0_from_v`, `.snr`, plus `enforce_zero_terminal_snr`.
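The schedule helpers listed above correspond to standard formulas, sketched here on scalars. This is a sketch under assumptions, not the repo's `CosineNoiseSchedule`: the offset `s=0.008`, the beta clipping at `0.999`, and the function signatures are illustrative choices, and the zero-terminal-SNR step uses the usual shift-and-rescale of `sqrt(alpha_bar)`.

```python
import math


def cosine_alpha_bar(num_steps, s=0.008, max_beta=0.999):
    """Cumulative signal level alpha_bar[t] for a cosine schedule with clipped betas."""
    def f(i):
        return math.cos((i / num_steps + s) / (1.0 + s) * math.pi / 2.0) ** 2

    alpha_bar, running = [], 1.0
    for i in range(num_steps):
        beta = min(1.0 - f(i + 1) / f(i), max_beta)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar


def enforce_zero_terminal_snr(alpha_bar):
    """Shift and rescale sqrt(alpha_bar) so the last step carries no signal."""
    root = [math.sqrt(a) for a in alpha_bar]
    first, last = root[0], root[-1]
    return [((r - last) * first / (first - last)) ** 2 for r in root]


def add_noise(x0, eps, ab):
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps


def velocity(x0, eps, ab):
    return math.sqrt(ab) * eps - math.sqrt(1.0 - ab) * x0


def predict_x0_from_v(x_t, v, ab):
    return math.sqrt(ab) * x_t - math.sqrt(1.0 - ab) * v


def snr(ab):
    return ab / (1.0 - ab)


ab = enforce_zero_terminal_snr(cosine_alpha_bar(100))
```

At the final step `ab[-1]` is exactly 0, so `add_noise` returns pure noise there and `snr` is 0, which is what makes sampling from a standard Gaussian consistent with training.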
- `sample_from_model(model, prompt_ids, seq_len, num_steps, temperature, top_k, top_p, guidance_scale=1.0, eta=0.0, clamp_to_tokens=False)`: correct x0-DDIM, CFG, self-cond carry, clean-prefix conditioning. `DDIMSampler` wraps it.
- `diffusion/rerank.py:best_of_k(generate_fn, score_fn, k)` + `diffusion_elbo_score(...)`: best-of-K reranking.
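The two core operations behind an x0-parameterised DDIM sampler with classifier-free guidance can be sketched on plain lists. The function names are hypothetical and this is the textbook `eta = 0` update, not the body of `sample_from_model`; in particular, which quantity the repo applies guidance to is an assumption.

```python
import math


def cfg_combine(pred_uncond, pred_cond, guidance_scale=1.0):
    """Classifier-free guidance: extrapolate from the unconditional prediction."""
    return [u + guidance_scale * (c - u) for u, c in zip(pred_uncond, pred_cond)]


def ddim_step_x0(x_t, x0_hat, alpha_bar_t, alpha_bar_prev):
    """Deterministic (eta = 0) DDIM update from an x0 prediction.

    Requires alpha_bar_t < 1 so the implied noise is well defined.
    """
    sqrt_ab_t = math.sqrt(alpha_bar_t)
    sigma_t = math.sqrt(1.0 - alpha_bar_t)
    sqrt_ab_prev = math.sqrt(alpha_bar_prev)
    sigma_prev = math.sqrt(1.0 - alpha_bar_prev)
    out = []
    for xt, x0 in zip(x_t, x0_hat):
        eps_hat = (xt - sqrt_ab_t * x0) / sigma_t  # noise implied by the x0 estimate
        out.append(sqrt_ab_prev * x0 + sigma_prev * eps_hat)
    return out
```

With `guidance_scale=1.0` the combination reduces to the conditional prediction, and stepping to `alpha_bar_prev=1.0` returns the x0 estimate itself. At a zero-SNR start (`alpha_bar_t=0`) the implied noise is simply `x_t`.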
- `finetune_sft.py`: SFT (leak-free per-position prompt cond; response-only CE via labels).
- `finetune_dpo.py` + `training/preference.py`: DPO/IPO/SimPO with a diffusion-ELBO/VRPO surrogate (diffusion log-likelihoods are intractable). Preferred for preference pairs.
- `finetune_grpo.py` + `training/rewards.py`: GRPO with a pluggable `--reward` (default `numeric`; `token_overlap` is kept but deprecated because it just teaches copying).
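To make the ELBO-surrogate idea concrete, here is a sketch of the standard DPO objective for a single preference pair, with ELBO estimates standing in for the intractable log-likelihoods. The function name, argument names, and `beta=0.1` are illustrative assumptions; how `training/preference.py` estimates the ELBO terms or reduces variance (VRPO) is not shown.

```python
import math


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one pair: -log(sigmoid(beta * margin)).

    Each argument is a scalar ELBO estimate used in place of a log-likelihood.
    """
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    z = -beta * margin
    # Numerically stable softplus(z) = log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))
```

The loss equals `log 2` when the policy and reference agree on the pair and falls toward 0 as the policy favours the chosen response more strongly than the reference does.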
- `selective_scan(dt, A, Bmat, C, x, *, stable=True, chunk_size=64)`: vectorized, numerically stable (chunked); `selective_scan_sequential` is the parity reference; `bidirectional_*` variants exist too. `SimpleMamba2` uses it and falls back to the sequential scan if the result is non-finite.
- `maybe_compile(module)`: `torch.compile` on CUDA only.
- `backends/mlx/`: MLX skeleton (WIP).
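A sequential reference scan of the kind used for parity checks can be sketched in plain Python for a single channel with a diagonal state. The function name, argument shapes, and the simplified discretisation (`exp(dt * A)` for the state, `dt * B` for the input) are assumptions for illustration and may not match `selective_scan_sequential`. The sketch rejects non-negative `A`, since a decaying state requires negative rates.

```python
import math


def selective_scan_reference(dt, A, B, C, x):
    """Sequential scan for one channel with an N-dimensional diagonal state.

    dt: list[T] of positive step sizes; A: list[N] of negative decay rates;
    B, C: list[T] of list[N]; x: list[T] of inputs. Returns list[T] of outputs.
    """
    if any(a >= 0.0 for a in A):
        raise ValueError("A must be strictly negative so the state decays")
    h = [0.0] * len(A)
    ys = []
    for dt_t, b_t, c_t, x_t in zip(dt, B, C, x):
        # h_t = exp(dt * A) * h_{t-1} + dt * B_t * x_t, elementwise over the state.
        h = [math.exp(dt_t * a) * h_n + dt_t * b_n * x_t
             for a, h_n, b_n in zip(A, h, b_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c_t, h)))
    return ys


def all_finite(values):
    """Check a fast path's output before trusting it; fall back to the reference if False."""
    return all(math.isfinite(v) for v in values)
```

A vectorized or chunked implementation can be validated by comparing its output with this loop on small random inputs and by checking `all_finite` on the result.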
src/dimba/{models,diffusion,training,data,tokenizers,evaluation,utils,backends}/
scripts/{train*,generate,evaluate,benchmark}.py scripts/finetuning/finetune_{sft,dpo,grpo,interactive}.py
configs/ tests/ notebooks/ docs/ paper/
- Python ≥3.9, black line-length 100, type hints, Google-style docstrings.
- Don't reintroduce the conditioning leak, the 2-tuple `forward`, positive-`A` SSM, or un-scaled latents. Run `compileall` + the smoke before claiming a change works.
- PR #18 open into `main`: `c2352ba` (overhaul) + `60f30eb` (latent scale-factor). Not merged. All known bugs fixed; `compileall` clean; runtime smoke 14/14 (venv python).
- Open follow-ups: first-class masked-mode training script + `[MASK]` token; an M1 quickstart config; train a real VAE to calibrate the latent against; cross-attention conditioning (stronger than pooled-global); real speed/quality benchmarks once compute lands.