
Batched generation: generate_batch() with ragged-batch S3Gen #12

Merged
TroyHernandez merged 5 commits into main from batch-generation
Jun 12, 2026

@TroyHernandez

Adds generate_batch(model, texts, voice, ...): T3 token generation runs per text (autoregressive generation doesn't batch — lengths and EOS differ), then ALL utterances synthesize in one batched S3Gen pass — a single CFM solve and a single vocoder call over the padded batch, trimmed per row. This is the upstream 0.1.7 speech_token_lens batching, not pipelining.
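As a rough illustration of "padded batch, trimmed per row", the base-R sketch below pads ragged token vectors into a (B, T) matrix with a length vector, then trims each row back. `pad_rows` and `trim_rows` are hypothetical helpers written for this note, not functions from the package, and the pad value of 0 is an assumption.

```r
# Hypothetical helpers, not the package's own code.
# Pad a list of integer token vectors into a (B, T) matrix plus lengths.
pad_rows <- function(token_list, pad_value = 0L) {
  lens <- lengths(token_list)
  padded <- matrix(pad_value, nrow = length(token_list), ncol = max(lens))
  for (i in seq_along(token_list)) {
    padded[i, seq_len(lens[i])] <- token_list[[i]]
  }
  list(tokens = padded, lens = lens)
}

# Recover each row's valid prefix from the padded matrix.
trim_rows <- function(padded, lens) {
  lapply(seq_along(lens), function(i) padded[i, seq_len(lens[i])])
}

tokens <- list(c(5L, 9L, 2L), 7L, c(4L, 4L))
batch <- pad_rows(tokens)                      # batch$tokens is 3 x 3
restored <- trim_rows(batch$tokens, batch$lens)
```

The length vector plays the role that speech_token_lens plays in the PR: it is the only record of where each row's real content ends.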

Plumbing changes:

  • s3gen$inference() accepts (B, T) tokens + speech_token_lens for ragged batches
  • solve_euler CFG generalized from a hardcoded batch of 2 to 2B (the traced path falls back to the non-traced path for B > 1)
  • flow builds per-row mel masks and expands single-voice conditioning across rows
  • generate()'s T3 stage extracted to .t3_text_to_tokens, shared by both entry points
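The CFG generalization can be pictured with a small base-R sketch. Everything here is assumed for illustration: `toy_estimator` stands in for the real CFM estimator, zero conditioning stands in for the unconditional branch, and the combination rule `(1 + w) * v_cond - w * v_uncond` is the common classifier-free-guidance form rather than something this PR states.

```r
# Toy stand-in for the CFM estimator (hypothetical). Rows are batch entries.
toy_estimator <- function(x, cond) -x + cond

# Stack conditional and unconditional inputs into one 2B-row batch,
# run the estimator once, then split and combine the two halves.
cfg_velocity <- function(x, cond, cfg_rate) {
  B <- nrow(x)
  x_in <- rbind(x, x)                                  # (2B, D)
  cond_in <- rbind(cond, matrix(0, B, ncol(cond)))     # zeros = unconditional half
  v <- toy_estimator(x_in, cond_in)
  v_cond <- v[seq_len(B), , drop = FALSE]
  v_uncond <- v[B + seq_len(B), , drop = FALSE]
  (1 + cfg_rate) * v_cond - cfg_rate * v_uncond
}

set.seed(1)
x <- matrix(rnorm(6), 3, 2)
cond <- matrix(rnorm(6), 3, 2)
v_batch <- cfg_velocity(x, cond, cfg_rate = 0.7)       # B = 3, estimator sees 6 rows
v_single <- cfg_velocity(x[2, , drop = FALSE], cond[2, , drop = FALSE], 0.7)
```

Splitting by `seq_len(B)` instead of fixed row indices 1 and 2 is the kind of change "hardcoded batch-2 to 2B" suggests; with it, row 2 of the batched result matches the B = 1 run.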

Three padded-batch leaks found and fixed (each invisible at B = 1):

  1. ref_dict's (1, 1)-shaped prompt_token_len broadcast a (B) length vector to (1, B), collapsing the batch in make_pad_mask
  2. the CFM estimator's transformer blocks ran unmasked — padded tails bled into valid frames through bidirectional attention; a key-padding mask is now built when padding exists
  3. the conformer pre-lookahead conv reads rightward and pulled nonzero embedded-padding into the last valid frames; the tail is zeroed after embed
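Leak 2 can be reproduced in miniature. The base-R sketch below is a single-head attention written for this note, not the estimator's code. It shows that without a key-padding mask a padded frame changes the output at valid positions, and that masking the padded key restores the unpadded result. The pad content of 1 is arbitrary.

```r
# Minimal single-head attention with an optional key-padding mask.
# key_valid[j] == FALSE means key/value row j is padding.
masked_attention <- function(q, k, v, key_valid) {
  scores <- q %*% t(k) / sqrt(ncol(q))          # (Tq, Tk)
  scores[, !key_valid] <- -Inf                  # padded keys get zero weight
  w <- exp(scores - apply(scores, 1, max))      # row-wise stable softmax
  w <- w / rowSums(w)
  w %*% v
}

set.seed(1)
d <- 4
q <- matrix(rnorm(3 * d), 3, d)
k <- matrix(rnorm(3 * d), 3, d)
v <- matrix(rnorm(3 * d), 3, d)
ref <- masked_attention(q, k, v, rep(TRUE, 3))  # unpadded reference

# Append one padded frame with nonzero content.
k_pad <- rbind(k, rep(1, d))
v_pad <- rbind(v, rep(1, d))
leaky <- masked_attention(q, k_pad, v_pad, rep(TRUE, 4))
clean <- masked_attention(q, k_pad, v_pad, c(TRUE, TRUE, TRUE, FALSE))
```

`clean` equals `ref` while `leaky` does not, which is why the problem is invisible at B = 1: with a single row there is no padding to attend to.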

Validation (GPU, identical token sequences), batch row vs. single run: encoder ≤ 2e-4; mel through the full 10-step CFM 0.003–0.005 (the single-run FP envelope; the Python-parity bar is 0.03). End-to-end batch of 3 texts: all rows reach EOS, durations are correct, samples in ~/Sync.
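A parity check of this kind might look like the base-R sketch below. It assumes the reported numbers are largest absolute differences over the valid region, which the description does not state outright. The values are made-up toy data, not measurements from this PR.

```r
# Assumption: parity is scored as the largest absolute difference
# between a batch row (trimmed to the single run's length) and the single run.
max_abs_diff <- function(batch_row, single_run) {
  n <- length(single_run)
  max(abs(batch_row[seq_len(n)] - single_run))
}

# Toy values for illustration only.
single_run <- c(0.10, -0.25, 0.40)
batch_row <- c(0.101, -0.252, 0.398, 0, 0)   # same row, padded to the batch width
err <- max_abs_diff(batch_row, single_run)
within_bar <- err <= 0.03                    # 0.03 is the Python-parity bar quoted above
```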

A CPU regression test covers the mask/broadcast math (test_batch_masks.R); the full suite passes. The version bump to 0.1.0.8 and the NEWS entry are in a separate commit.

- generate_batch(model, texts, voice): T3 per text (autoregressive),
  then ONE batched S3Gen pass (one CFM solve, one vocoder call) over
  the padded batch; per-row trimming by token count.
- s3gen$inference accepts (B, T) tokens + speech_token_lens; flow
  builds per-row mel masks, expands single-voice conditioning, and
  solve_euler generalizes CFG from hardcoded batch-2 to 2B.
- Three padded-batch leaks found and fixed: (1, 1) prompt_token_len
  broadcasting collapsed the batch in make_pad_mask; the CFM
  estimator's transformers ran unmasked (key-padding mask added); the
  conformer pre-lookahead conv read nonzero embedded padding (tail now
  zeroed first).
- Batch-vs-single parity on identical tokens: encoder <= 2e-4, mel
  0.003-0.005 (single-run FP envelope; Python-parity bar is 0.03).
- generate() T3 stage extracted to .t3_text_to_tokens, shared by both.
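
The broadcasting collapse in the first leak follows from the standard right-aligned broadcasting rule. The base-R sketch below only computes result shapes. `broadcast_shape` is a hypothetical helper written for this note, and the "flatten to shape (1)" fix is an assumption about one way to avoid the collapse, not necessarily what the commit does.

```r
# Result shape of broadcasting two shapes, aligned from the right
# (the NumPy/torch convention). Hypothetical helper for illustration.
broadcast_shape <- function(a, b) {
  n <- max(length(a), length(b))
  a <- c(rep(1L, n - length(a)), a)
  b <- c(rep(1L, n - length(b)), b)
  stopifnot(all(a == b | a == 1L | b == 1L))
  pmax(a, b)
}

B <- 3L
collapsed <- broadcast_shape(c(1L, 1L), c(B))  # (1, 1) with (B) gives (1, B)
flattened <- broadcast_shape(c(1L), c(B))      # (1) with (B) stays (B)
```

A (1, B) length tensor reads as one batch row with B entries, which is how the batch dimension can disappear before make_pad_mask sees it. At B = 1 both shapes hold a single element, so the bug has no visible effect.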
At padded positions the masked estimator leaves dphi = 0, so the tail
of the generated region stayed as raw initial Gaussian noise, and
HiFi-GAN's convolutional context smeared that noise into the end of
shorter rows. The tail is now zeroed when speech_token_lens is given
(matching the zero padding that a single run's convolutions see past
the end of the sequence). Also documents that traced/autocast apply to
the T3 stage only in generate_batch().
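
The tail zeroing can be sketched in base R as below. `zero_tail` is a hypothetical helper, and the mel is reduced to a (B, T) matrix for brevity; real mels carry an extra channel axis, and the per-row frame counts are assumed to be derived from speech_token_lens.

```r
# Zero every frame past each row's valid length so the vocoder's
# convolutions see zeros, not leftover noise, beyond the sequence end.
zero_tail <- function(mel, mel_lens) {
  keep <- outer(mel_lens, seq_len(ncol(mel)), function(len, t) t <= len)
  mel * keep                                   # logical mask multiplies as 0/1
}

set.seed(1)
mel <- matrix(rnorm(8), nrow = 2, ncol = 4)    # row 2 is the shorter utterance
out <- zero_tail(mel, mel_lens = c(4L, 2L))
```

Row 1 is returned unchanged, and row 2 keeps its first two frames and has zeros after them.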
@TroyHernandez TroyHernandez merged commit cfb5785 into main Jun 12, 2026
4 checks passed
@TroyHernandez TroyHernandez deleted the batch-generation branch June 12, 2026 21:16