From 8aa241c93d7224b2f6773cd06e8926dcf663690b Mon Sep 17 00:00:00 2001 From: TroyHernandez Date: Fri, 12 Jun 2026 15:37:28 -0500 Subject: [PATCH 1/5] Batched generation: generate_batch() with ragged-batch S3Gen - generate_batch(model, texts, voice): T3 per text (autoregressive), then ONE batched S3Gen pass (one CFM solve, one vocoder call) over the padded batch; per-row trimming by token count. - s3gen$inference accepts (B, T) tokens + speech_token_lens; flow builds per-row mel masks, expands single-voice conditioning, and solve_euler generalizes CFG from hardcoded batch-2 to 2B. - Three padded-batch leaks found and fixed: (1, 1) prompt_token_len broadcasting collapsed the batch in make_pad_mask; the CFM estimator's transformers ran unmasked (key-padding mask added); the conformer pre-lookahead conv read nonzero embedded padding (tail now zeroed first). - Batch-vs-single parity on identical tokens: encoder <= 2e-4, mel 0.003-0.005 (single-run FP envelope; Python-parity bar is 0.03). - generate() T3 stage extracted to .t3_text_to_tokens, shared by both. --- NAMESPACE | 1 + R/conformer.R | 6 + R/s3gen.R | 124 ++++++++---- R/tts.R | 312 ++++++++++++++++++++++--------- inst/tinytest/test_batch_masks.R | 17 ++ man/generate_batch.Rd | 31 +++ 6 files changed, 364 insertions(+), 127 deletions(-) create mode 100644 inst/tinytest/test_batch_masks.R create mode 100644 man/generate_batch.Rd diff --git a/NAMESPACE b/NAMESPACE index cdfa788..2524c59 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ export(decode_tokens) export(download_chatterbox_models) export(download_chatterbox_turbo_models) export(generate) +export(generate_batch) export(get_model_paths) export(get_turbo_model_paths) export(integrated_loudness) diff --git a/R/conformer.R b/R/conformer.R index 09de7f1..200f515 100644 --- a/R/conformer.R +++ b/R/conformer.R @@ -533,6 +533,12 @@ upsample_conformer_encoder_full <- torch::nn_module( pos_emb <- embed_result[[2]] masks <- embed_result[[3]] + # Zero the padded tail first: embed's LayerNorm makes pad frames + # nonzero, and the pre-lookahead conv reads rightward, so a padded + # row would otherwise differ from the same row run alone (where the + # conv right-pads true zeros) + xs <- xs$masked_fill(!masks$transpose(2L, 3L), 0.0) + # Pre-lookahead layer xs <- self$pre_lookahead_layer$forward(xs) diff --git a/R/s3gen.R b/R/s3gen.R index 24c0241..85e7d67 100644 --- a/R/s3gen.R +++ b/R/s3gen.R @@ -518,10 +518,16 @@ cfm_estimator <- torch::nn_module( }, compute_attn_mask = function(x, mask) { - # For now, return NULL (no causal masking in inference) - # The Python code uses chunk-based attention masks for streaming, - # but for full sequence inference, we can use standard attention - NULL + # No causal masking in (non-streaming) inference. Single utterances + # have no padding, so the transformers run unmasked. Padded batches + # need a key-padding mask or the tail bleeds into valid frames + # through bidirectional attention. + if (is.null(mask) || !as.logical((mask < 0.5)$any()$item())) { + return(NULL) + } + # (B, 1, T) float -> (B, 1, 1, T) bool, TRUE = may attend; + # broadcasts over heads and query positions in SDPA + (mask > 0.5)$unsqueeze(2) } ) @@ -610,6 +616,11 @@ causal_cfm <- torch::nn_module( # fresh noise. Avoids advancing the RNG and a wasted allocation. if (!use_meanflow && is.null(noised_mels) && !is.null(self$rand_noise)) { z <- self$rand_noise[,, 1:seq_len]$to(device = device)$to(dtype = mu$dtype) * temperature + if (mu$size(1) > 1L) { + # Batched: every row gets the same fixed noise, so row i + # matches what a single-utterance run would draw + z <- z$expand(c(mu$size(1), -1L, -1L))$contiguous() + } } else { z <- torch::torch_randn_like(mu) if (!is.null(noised_mels)) { @@ -663,6 +674,12 @@ causal_cfm <- torch::nn_module( device <- x$device dtype <- x$dtype + # Traced graph is fixed at CFG batch 2; fall back for real batches + if (traced && batch_size > 1L) { + warning("traced CFM supports batch size 1, falling back to non-traced") + traced <- FALSE + } + # For traced mode, pad to fixed max length if (traced && seq_len <= CFM_MAX_SEQ_LEN) { traced_est <- self$get_traced_estimator(device) @@ -699,40 +716,45 @@ causal_cfm <- torch::nn_module( work_len <- if (traced && seq_len <= CFM_MAX_SEQ_LEN) CFM_MAX_SEQ_LEN else seq_len - # Pre-allocate tensors for CFG (batch size 2) - x_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype) - mask_in <- torch::torch_zeros(c(2L, 1L, work_len), device = device, dtype = dtype) - mu_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype) - t_in <- torch::torch_zeros(2L, device = device, dtype = dtype) - spks_in <- torch::torch_zeros(c(2L, 80L), device = device, dtype = dtype) - cond_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype) + # Pre-allocate tensors for CFG: rows 1:B conditional, (B+1):2B + # unconditional (zero mu/spks/cond) + B <- batch_size + x_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype) + mask_in <- torch::torch_zeros(c(2L * B, 1L, work_len), device = device, dtype = dtype) + mu_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype) + t_in <- torch::torch_zeros(2L * B, device = device, dtype = dtype) + spks_in <- torch::torch_zeros(c(2L * B, 80L), device = device, dtype = dtype) + cond_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype) for (step in 2:length(t_span)) { # Classifier-Free Guidance: conditional and unconditional paths # Use padded tensors in traced mode if (traced && seq_len <= CFM_MAX_SEQ_LEN) { - x_in[1:2,,] <- x_padded - mask_in[1:2,,] <- mask_padded - mu_in[1,,] <- mu_padded - cond_in[1,,] <- cond_padded + x_in[1:B,,] <- x_padded + x_in[(B + 1):(2 * B),,] <- x_padded + mask_in[1:B,,] <- mask_padded + mask_in[(B + 1):(2 * B),,] <- mask_padded + mu_in[1:B,,] <- mu_padded + cond_in[1:B,,] <- cond_padded } else { - x_in[1:2,,] <- x - mask_in[1:2,,] <- mask - mu_in[1,,] <- mu - cond_in[1,,] <- cond + x_in[1:B,,] <- x + x_in[(B + 1):(2 * B),,] <- x + mask_in[1:B,,] <- mask + mask_in[(B + 1):(2 * B),,] <- mask + mu_in[1:B,,] <- mu + cond_in[1:B,,] <- cond } - # mu_in[2] stays zero (unconditional) - t_in[1:2] <- t - spks_in[1,] <- spks - # spks_in[2] stays zero - # cond_in[2] stays zero + # mu_in[(B+1):(2B)] stays zero (unconditional) + t_in[1:(2 * B)] <- t + spks_in[1:B,] <- spks + # spks_in/cond_in upper half stays zero # Forward through estimator (traced or normal) dphi_dt <- estimator_fn(x_in, mask_in, mu_in, t_in, spks_in, cond_in) # CFG combination - dphi_cond <- dphi_dt[1,,]$unsqueeze(1) - dphi_uncond <- dphi_dt[2,,]$unsqueeze(1) + dphi_cond <- dphi_dt[1:B,,, drop = FALSE] + dphi_uncond <- dphi_dt[(B + 1):(2 * B),,, drop = FALSE] dphi_dt <- (1.0 + self$inference_cfg_rate) * dphi_cond - self$inference_cfg_rate * dphi_uncond # Euler step (only on actual sequence length) @@ -819,14 +841,32 @@ causal_masked_diff_xvec <- torch::nn_module( meanflow = NULL ) { device <- token$device + B <- token$size(1) # Normalize and project speaker embedding embedding <- torch::nnf_normalize(embedding, dim = 2) embedding <- self$spk_embed_affine_layer$forward(embedding) - # Concatenate prompt and speech tokens + # Batched generation shares one voice: expand the single-row + # conditioning to every utterance row + if (B > 1L) { + if (embedding$size(1) == 1L) { + embedding <- embedding$expand(c(B, -1L))$contiguous() + } + if (prompt_token$size(1) == 1L) { + prompt_token <- prompt_token$expand(c(B, -1L))$contiguous() + } + if (prompt_feat$size(1) == 1L) { + prompt_feat <- prompt_feat$expand(c(B, -1L, -1L))$contiguous() + } + } + + # Concatenate prompt and speech tokens. Lengths flatten to (B): + # ref_dict's prompt_token_len is (1, 1), which would broadcast a + # (B) token_len up to (1, B) and collapse the batch in make_pad_mask token <- torch::torch_cat(list(prompt_token, token), dim = 2) - token_len <- prompt_token_len + token_len + token_len <- (prompt_token_len$view(-1) + token_len$view(-1))$to( + dtype = torch::torch_long()) # Create mask mask <- (!make_pad_mask(token_len))$unsqueeze(3)$to(dtype = embedding$dtype, device = device) @@ -858,15 +898,18 @@ causal_masked_diff_xvec <- torch::nn_module( h <- self$encoder_proj$forward(h) # Prepare conditioning: prompt mel fills the first mel_len1 frames - conds <- torch::torch_zeros(c(1, mel_len1 + mel_len2, self$output_size), + conds <- torch::torch_zeros(c(B, mel_len1 + mel_len2, self$output_size), device = device, dtype = h$dtype) if (mel_len1 > 0) { - conds[1, 1:mel_len1,] <- prompt_feat[1,,] + conds[, 1:mel_len1,] <- prompt_feat } conds <- conds$transpose(2, 3) - # Create mask for decoder - dec_mask <- torch::torch_ones(c(1, 1, mel_len1 + mel_len2), device = device, dtype = h$dtype) + # Create mask for decoder: per-row valid mel length (ragged batches + # have padded tails). Single-row stays all-ones as before. + mel_lens <- mel_len1 + + (token_len - prompt_token_len$view(-1)) * self$token_mel_ratio + dec_mask <- (!make_pad_mask(mel_lens, max_len = mel_len1 + mel_len2))$unsqueeze(2)$to(dtype = h$dtype, device = device) # Run decoder h <- h$transpose(2, 3) @@ -1004,7 +1047,10 @@ s3gen <- torch::nn_module( }, #' Run inference (tokens -> mel -> audio; - #' skip_vocoder = TRUE stops at the mel) + #' skip_vocoder = TRUE stops at the mel). + #' speech_tokens may be (B, T) with + #' speech_token_lens giving per-row valid + #' lengths for ragged batches. inference = function( speech_tokens, ref_wav = NULL, @@ -1013,7 +1059,8 @@ s3gen <- torch::nn_module( finalize = TRUE, traced = FALSE, n_cfm_timesteps = NULL, - skip_vocoder = FALSE + skip_vocoder = FALSE, + speech_token_lens = NULL ) { # Get reference dict if (is.null(ref_dict)) { @@ -1030,7 +1077,12 @@ s3gen <- torch::nn_module( speech_tokens <- speech_tokens$unsqueeze(1) } speech_tokens <- speech_tokens$to(device = device) - speech_token_len <- torch::torch_tensor(speech_tokens$size(2), device = device) + speech_token_len <- if (is.null(speech_token_lens)) { + torch::torch_tensor(rep(speech_tokens$size(2), speech_tokens$size(1)), + device = device) + } else { + torch::torch_tensor(as.integer(speech_token_lens), device = device) + } # Determine timesteps and noise for meanflow n_steps <- n_cfm_timesteps @@ -1039,7 +1091,7 @@ s3gen <- torch::nn_module( if (is.null(n_steps)) n_steps <- 2L # MeanFlow uses random noise per call instead of pre-computed buffer noised_mels <- torch::torch_randn( - c(1L, 80L, as.integer(speech_tokens$size(2)) * 2L), + c(speech_tokens$size(1), 80L, as.integer(speech_tokens$size(2)) * 2L), dtype = torch::torch_float32(), device = device ) } diff --git a/R/tts.R b/R/tts.R index ab0d61d..5f0108c 100644 --- a/R/tts.R +++ b/R/tts.R @@ -411,97 +411,13 @@ generate <- function(model, text, voice, exaggeration = 0.5, stop("voice must be a voice_embedding object or path to reference audio") } - # Tokenize text - if (is_turbo) { - # GPT-2 tokenizer - text_ids <- tokenize_text_gpt2(model$tokenizer, text) - text_tokens <- torch::torch_tensor(text_ids, - dtype = torch::torch_long())$unsqueeze(1L)$to(device = device) - } else { - text_tokens <- tokenize_text(model$tokenizer, text) - text_tokens <- torch::torch_tensor(text_tokens, dtype = torch::torch_long())$unsqueeze(1L)$to(device = device) - } - - # Create T3 conditioning - cond <- t3_cond( - speaker_emb = voice$ve_embedding, - cond_prompt_speech_tokens = voice$cond_prompt_speech_tokens, - emotion_adv = if (is_turbo) NULL else exaggeration - ) - - # Generate speech tokens with T3 - message("Generating speech tokens...") - - if (is_turbo) { - # Turbo inference: no CFG, no min_p, uses top_k - if (use_autocast) { - torch::with_autocast(device_type = "cuda", { - torch::with_no_grad({ - speech_tokens <- t3_inference_turbo( - model = model$t3, - cond = cond, - text_tokens = text_tokens, - temperature = temperature, - top_k = top_k, - top_p = top_p, - repetition_penalty = repetition_penalty - ) - }) - }) - } else { - torch::with_no_grad({ - speech_tokens <- t3_inference_turbo( - model = model$t3, - cond = cond, - text_tokens = text_tokens, - temperature = temperature, - top_k = top_k, - top_p = top_p, - repetition_penalty = repetition_penalty - ) - }) - } - } else { - # Standard inference with CFG - backend <- match.arg(backend) - if (backend == "jit") { - inference_fn <- t3_inference_jit - } else if (traced) { - inference_fn <- t3_inference_traced - } else { - inference_fn <- t3_inference - } - - inf_args <- list( - model = model$t3, - cond = cond, - text_tokens = text_tokens, - cfg_weight = cfg_weight, - temperature = temperature, - top_p = top_p, - min_p = min_p, - repetition_penalty = repetition_penalty, - max_new_tokens = max_new_tokens - ) - # Cache sizing only applies to the pre-allocated-cache backends; - # cpp auto-sizes when NULL, traced keeps its 350 default (a new - # size means a fresh ~50s JIT trace) - if (!is.null(max_cache_len) && (backend == "jit" || traced)) { - inf_args$max_cache_len <- max_cache_len - } - - if (use_autocast) { - torch::with_autocast(device_type = "cuda", { - torch::with_no_grad({ - speech_tokens <- do.call(inference_fn, inf_args) - }) - }) - } else { - torch::with_no_grad({ - speech_tokens <- do.call(inference_fn, inf_args) - }) - } - } + speech_tokens <- .t3_text_to_tokens(model, text, voice, + exaggeration = exaggeration, cfg_weight = cfg_weight, + temperature = temperature, top_p = top_p, min_p = min_p, + traced = traced, backend = backend, top_k = top_k, + repetition_penalty = repetition_penalty, + max_new_tokens = max_new_tokens, max_cache_len = max_cache_len, + use_autocast = use_autocast) # Capture EOS status before drop_invalid_tokens strips the attribute eos_found <- isTRUE(attr(speech_tokens, "eos_found")) @@ -600,6 +516,220 @@ generate <- function(model, text, voice, exaggeration = 0.5, ) } +#' T3 stage shared by generate() and generate_batch(): tokenize one +#' text and run the configured inference backend. \code{voice} must +#' already be a voice_embedding. Returns the speech tokens with the +#' eos_found attribute intact. +#' +#' @noRd +.t3_text_to_tokens <- function (model, text, voice, exaggeration, + cfg_weight, temperature, top_p, min_p, + traced, backend, top_k, + repetition_penalty, max_new_tokens, + max_cache_len, use_autocast) { + device <- model$device + is_turbo <- isTRUE(model$turbo) + + # Tokenize text + if (is_turbo) { + text_ids <- tokenize_text_gpt2(model$tokenizer, text) + text_tokens <- torch::torch_tensor(text_ids, + dtype = torch::torch_long())$unsqueeze(1L)$to(device = device) + } else { + text_tokens <- tokenize_text(model$tokenizer, text) + text_tokens <- torch::torch_tensor(text_tokens, + dtype = torch::torch_long())$unsqueeze(1L)$to(device = device) + } + + # Create T3 conditioning + cond <- t3_cond( + speaker_emb = voice$ve_embedding, + cond_prompt_speech_tokens = voice$cond_prompt_speech_tokens, + emotion_adv = if (is_turbo) NULL else exaggeration + ) + + message("Generating speech tokens...") + + if (is_turbo) { + # Turbo inference: no CFG, no min_p, uses top_k + inf_args <- list( + model = model$t3, + cond = cond, + text_tokens = text_tokens, + temperature = temperature, + top_k = top_k, + top_p = top_p, + repetition_penalty = repetition_penalty + ) + inference_fn <- t3_inference_turbo + } else { + # Standard inference with CFG + backend <- match.arg(backend, c("r", "jit")) + if (backend == "jit") { + inference_fn <- t3_inference_jit + } else if (traced) { + inference_fn <- t3_inference_traced + } else { + inference_fn <- t3_inference + } + + inf_args <- list( + model = model$t3, + cond = cond, + text_tokens = text_tokens, + cfg_weight = cfg_weight, + temperature = temperature, + top_p = top_p, + min_p = min_p, + repetition_penalty = repetition_penalty, + max_new_tokens = max_new_tokens + ) + # Cache sizing only applies to the pre-allocated-cache backends; + # jit auto-sizes when NULL, traced keeps its 350 default (a new + # size means a fresh ~50s JIT trace) + if (!is.null(max_cache_len) && (backend == "jit" || traced)) { + inf_args$max_cache_len <- max_cache_len + } + } + + if (use_autocast) { + torch::with_autocast(device_type = "cuda", { + torch::with_no_grad({ + speech_tokens <- do.call(inference_fn, inf_args) + }) + }) + } else { + torch::with_no_grad({ + speech_tokens <- do.call(inference_fn, inf_args) + }) + } + speech_tokens +} + +#' Generate speech for several texts with one batched synthesis pass +#' +#' Runs T3 token generation per text (autoregressive, sequential), then +#' synthesizes ALL utterances in a single batched S3Gen pass (one CFM +#' solve and one vocoder call over the padded batch). Per-utterance +#' results match single \code{\link{generate}} calls up to CFM noise +#' handling - the fixed noise buffer means row i sees the same initial +#' noise it would alone. Standard model only. +#' +#' @param model Loaded chatterbox model (standard, not turbo) +#' @param texts Character vector of texts to synthesize +#' @param voice Shared voice: voice_embedding or reference audio path +#' @param ... Arguments passed through to the T3 stage, as in +#' \code{\link{generate}} (exaggeration, cfg_weight, temperature, +#' top_p, min_p, backend, repetition_penalty, normalize_text, +#' max_new_tokens, max_cache_len) +#' @return List with one \code{\link{generate}}-style result per text +#' (audio, sample_rate, eos_found, n_tokens, audio_sec) +#' @export +generate_batch <- function (model, texts, voice, ...) { + if (!is_loaded(model)) { + stop("Model not loaded. Call load_chatterbox() first.") + } + if (isTRUE(model$turbo)) { + stop("generate_batch supports the standard model only") + } + if (!is.character(texts) || length(texts) == 0) { + stop("texts must be a non-empty character vector") + } + + args <- list(...) + arg_or <- function (name, default) args[[name]] %||% default + normalize_text <- isTRUE(arg_or("normalize_text", TRUE)) + use_autocast <- isTRUE(arg_or("autocast", FALSE)) && + grepl("^cuda", model$device) + + if (is.character(voice)) { + voice <- create_voice_embedding(model, voice) + } else if (!inherits(voice, "voice_embedding")) { + stop("voice must be a voice_embedding object or path to ", + "reference audio") + } + + # T3 per text (autoregressive generation does not batch; lengths + # and EOS differ per utterance) + token_vecs <- vector("list", length(texts)) + eos <- logical(length(texts)) + for (i in seq_along(texts)) { + txt <- texts[i] + if (normalize_text) { + txt <- normalize_tts_text(txt) + } + txt <- punc_norm(txt) + tokens <- .t3_text_to_tokens(model, txt, voice, + exaggeration = arg_or("exaggeration", 0.5), + cfg_weight = arg_or("cfg_weight", 0.5), + temperature = arg_or("temperature", 0.8), + top_p = arg_or("top_p", 1.0), + min_p = arg_or("min_p", 0.05), + traced = isTRUE(arg_or("traced", FALSE)), + backend = arg_or("backend", "r"), + top_k = arg_or("top_k", 1000L), + repetition_penalty = arg_or("repetition_penalty", 1.2), + max_new_tokens = arg_or("max_new_tokens", 1000L), + max_cache_len = arg_or("max_cache_len", NULL), + use_autocast = use_autocast) + eos[i] <- isTRUE(attr(tokens, "eos_found")) + token_vecs[[i]] <- as.integer(drop_invalid_tokens(tokens)) + if (!eos[i]) { + warning("Text ", i, " hit the token cap without end-of-speech (", + length(token_vecs[[i]]), " tokens). Output may be garbage.", + call. = FALSE) + } + } + + lens <- vapply(token_vecs, length, integer(1)) + results <- vector("list", length(texts)) + empty <- lens == 0L + for (i in which(empty)) { + warning("No valid speech tokens for text ", i, call. = FALSE) + results[[i]] <- list(audio = numeric(0), sample_rate = S3GEN_SR, + eos_found = eos[i], n_tokens = 0L, audio_sec = 0) + } + live <- which(!empty) + if (length(live) == 0) { + return(results) + } + + # Pad to a (B, Tmax) batch; padded tail is masked by + # speech_token_lens all the way through CFM and trimmed after the + # vocoder + t_max <- max(lens[live]) + mat <- t(vapply(token_vecs[live], + function (v) c(v, rep(0L, t_max - length(v))), integer(t_max))) + speech_tokens <- torch::torch_tensor(mat, + dtype = torch::torch_long())$to(device = model$device) + + message("Synthesizing ", length(live), " waveforms in one batch...") + torch::with_no_grad({ + out <- model$s3gen$inference( + speech_tokens = speech_tokens, + ref_dict = voice$ref_dict, + finalize = TRUE, + speech_token_lens = lens[live] + ) + }) + wavs <- out[[1]]$cpu() + + # 2 mel frames per token, 480 samples per mel frame + for (k in seq_along(live)) { + i <- live[k] + n_samples <- lens[i] * 2L * 480L + audio <- as.numeric(wavs[k, 1:min(n_samples, wavs$size(2))]) + results[[i]] <- list( + audio = audio, + sample_rate = S3GEN_SR, + eos_found = eos[i], + n_tokens = lens[i], + audio_sec = length(audio) / S3GEN_SR + ) + } + results +} + #' Generate speech and save to file #' #' @param model Chatterbox model diff --git a/inst/tinytest/test_batch_masks.R b/inst/tinytest/test_batch_masks.R new file mode 100644 index 0000000..6a84d5c --- /dev/null +++ b/inst/tinytest/test_batch_masks.R @@ -0,0 +1,17 @@ +# Batched s3gen length/mask math (no weights needed) + +if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) { + # make_pad_mask over a ragged batch + lens <- torch::torch_tensor(c(3L, 5L, 1L)) + m <- chatterbox:::make_pad_mask(lens) + expect_identical(dim(m), c(3L, 5L)) + expect_identical(as.logical(m[1, ]), c(FALSE, FALSE, FALSE, TRUE, TRUE)) + expect_identical(as.logical(m[3, ]), c(FALSE, TRUE, TRUE, TRUE, TRUE)) + + # the (1,1) prompt_token_len broadcast bug: flattened sum keeps (B) + prompt_len <- torch::torch_tensor(matrix(250L, 1, 1)) # ref_dict shape + tok_len <- torch::torch_tensor(c(60L, 80L, 25L)) + total <- prompt_len$view(-1) + tok_len$view(-1) + expect_identical(dim(total), 3L) + expect_identical(as.integer(total), c(310L, 330L, 275L)) +} diff --git a/man/generate_batch.Rd b/man/generate_batch.Rd new file mode 100644 index 0000000..3267725 --- /dev/null +++ b/man/generate_batch.Rd @@ -0,0 +1,31 @@ +% tinyrox says don't edit this manually, but it can't stop you! +\name{generate_batch} +\alias{generate_batch} +\title{Generate speech for several texts with one batched synthesis pass} +\usage{ +generate_batch(model, texts, voice, ...) +} +\arguments{ +\item{model}{Loaded chatterbox model (standard, not turbo)} + +\item{texts}{Character vector of texts to synthesize} + +\item{voice}{Shared voice: voice_embedding or reference audio path} + +\item{...}{Arguments passed through to the T3 stage, as in +\code{\link{generate}} (exaggeration, cfg_weight, temperature, +top_p, min_p, backend, repetition_penalty, normalize_text, +max_new_tokens, max_cache_len)} +} +\value{ +List with one \code{\link{generate}}-style result per text + (audio, sample_rate, eos_found, n_tokens, audio_sec) +} +\description{ +Runs T3 token generation per text (autoregressive, sequential), then +synthesizes ALL utterances in a single batched S3Gen pass (one CFM +solve and one vocoder call over the padded batch). Per-utterance +results match single \code{\link{generate}} calls up to CFM noise +handling - the fixed noise buffer means row i sees the same initial +noise it would alone. Standard model only. +} From 687d1e63407d4688f98f5ebd2e94125faf993442 Mon Sep 17 00:00:00 2001 From: TroyHernandez Date: Fri, 12 Jun 2026 15:37:52 -0500 Subject: [PATCH 2/5] rformat + document --- R/s3gen.R | 4 +- R/tts.R | 117 +++++++++++++++++++++++++++--------------------------- 2 files changed, 60 insertions(+), 61 deletions(-) diff --git a/R/s3gen.R b/R/s3gen.R index 85e7d67..eaa36cc 100644 --- a/R/s3gen.R +++ b/R/s3gen.R @@ -908,7 +908,7 @@ causal_masked_diff_xvec <- torch::nn_module( # Create mask for decoder: per-row valid mel length (ragged batches # have padded tails). Single-row stays all-ones as before. mel_lens <- mel_len1 + - (token_len - prompt_token_len$view(-1)) * self$token_mel_ratio + (token_len - prompt_token_len$view(-1)) * self$token_mel_ratio dec_mask <- (!make_pad_mask(mel_lens, max_len = mel_len1 + mel_len2))$unsqueeze(2)$to(dtype = h$dtype, device = device) # Run decoder @@ -1079,7 +1079,7 @@ s3gen <- torch::nn_module( speech_tokens <- speech_tokens$to(device = device) speech_token_len <- if (is.null(speech_token_lens)) { torch::torch_tensor(rep(speech_tokens$size(2), speech_tokens$size(1)), - device = device) + device = device) } else { torch::torch_tensor(as.integer(speech_token_lens), device = device) } diff --git a/R/tts.R b/R/tts.R index 5f0108c..eba0518 100644 --- a/R/tts.R +++ b/R/tts.R @@ -412,12 +412,12 @@ generate <- function(model, text, voice, exaggeration = 0.5, } speech_tokens <- .t3_text_to_tokens(model, text, voice, - exaggeration = exaggeration, cfg_weight = cfg_weight, - temperature = temperature, top_p = top_p, min_p = min_p, - traced = traced, backend = backend, top_k = top_k, - repetition_penalty = repetition_penalty, - max_new_tokens = max_new_tokens, max_cache_len = max_cache_len, - use_autocast = use_autocast) + exaggeration = exaggeration, cfg_weight = cfg_weight, + temperature = temperature, top_p = top_p, min_p = min_p, + traced = traced, backend = backend, top_k = top_k, + repetition_penalty = repetition_penalty, + max_new_tokens = max_new_tokens, max_cache_len = max_cache_len, + use_autocast = use_autocast) # Capture EOS status before drop_invalid_tokens strips the attribute eos_found <- isTRUE(attr(speech_tokens, "eos_found")) @@ -522,11 +522,10 @@ generate <- function(model, text, voice, exaggeration = 0.5, #' eos_found attribute intact. #' #' @noRd -.t3_text_to_tokens <- function (model, text, voice, exaggeration, - cfg_weight, temperature, top_p, min_p, - traced, backend, top_k, - repetition_penalty, max_new_tokens, - max_cache_len, use_autocast) { +.t3_text_to_tokens <- function(model, text, voice, exaggeration, cfg_weight, + temperature, top_p, min_p, traced, backend, + top_k, repetition_penalty, max_new_tokens, + max_cache_len, use_autocast) { device <- model$device is_turbo <- isTRUE(model$turbo) @@ -543,9 +542,9 @@ generate <- function(model, text, voice, exaggeration = 0.5, # Create T3 conditioning cond <- t3_cond( - speaker_emb = voice$ve_embedding, - cond_prompt_speech_tokens = voice$cond_prompt_speech_tokens, - emotion_adv = if (is_turbo) NULL else exaggeration + speaker_emb = voice$ve_embedding, + cond_prompt_speech_tokens = voice$cond_prompt_speech_tokens, + emotion_adv = if (is_turbo) NULL else exaggeration ) message("Generating speech tokens...") @@ -553,13 +552,13 @@ generate <- function(model, text, voice, exaggeration = 0.5, if (is_turbo) { # Turbo inference: no CFG, no min_p, uses top_k inf_args <- list( - model = model$t3, - cond = cond, - text_tokens = text_tokens, - temperature = temperature, - top_k = top_k, - top_p = top_p, - repetition_penalty = repetition_penalty + model = model$t3, + cond = cond, + text_tokens = text_tokens, + temperature = temperature, + top_k = top_k, + top_p = top_p, + repetition_penalty = repetition_penalty ) inference_fn <- t3_inference_turbo } else { @@ -574,15 +573,15 @@ generate <- function(model, text, voice, exaggeration = 0.5, } inf_args <- list( - model = model$t3, - cond = cond, - text_tokens = text_tokens, - cfg_weight = cfg_weight, - temperature = temperature, - top_p = top_p, - min_p = min_p, - repetition_penalty = repetition_penalty, - max_new_tokens = max_new_tokens + model = model$t3, + cond = cond, + text_tokens = text_tokens, + cfg_weight = cfg_weight, + temperature = temperature, + top_p = top_p, + min_p = min_p, + repetition_penalty = repetition_penalty, + max_new_tokens = max_new_tokens ) # Cache sizing only applies to the pre-allocated-cache backends; # jit auto-sizes when NULL, traced keeps its 350 default (a new @@ -625,7 +624,7 @@ generate <- function(model, text, voice, exaggeration = 0.5, #' @return List with one \code{\link{generate}}-style result per text #' (audio, sample_rate, eos_found, n_tokens, audio_sec) #' @export -generate_batch <- function (model, texts, voice, ...) { +generate_batch <- function(model, texts, voice, ...) { if (!is_loaded(model)) { stop("Model not loaded. Call load_chatterbox() first.") } @@ -637,16 +636,16 @@ generate_batch <- function (model, texts, voice, ...) { } args <- list(...) - arg_or <- function (name, default) args[[name]] %||% default + arg_or <- function(name, default) args[[name]] %||% default normalize_text <- isTRUE(arg_or("normalize_text", TRUE)) use_autocast <- isTRUE(arg_or("autocast", FALSE)) && - grepl("^cuda", model$device) + grepl("^cuda", model$device) if (is.character(voice)) { voice <- create_voice_embedding(model, voice) } else if (!inherits(voice, "voice_embedding")) { stop("voice must be a voice_embedding object or path to ", - "reference audio") + "reference audio") } # T3 per text (autoregressive generation does not batch; lengths @@ -660,24 +659,24 @@ generate_batch <- function (model, texts, voice, ...) { } txt <- punc_norm(txt) tokens <- .t3_text_to_tokens(model, txt, voice, - exaggeration = arg_or("exaggeration", 0.5), - cfg_weight = arg_or("cfg_weight", 0.5), - temperature = arg_or("temperature", 0.8), - top_p = arg_or("top_p", 1.0), - min_p = arg_or("min_p", 0.05), - traced = isTRUE(arg_or("traced", FALSE)), - backend = arg_or("backend", "r"), - top_k = arg_or("top_k", 1000L), - repetition_penalty = arg_or("repetition_penalty", 1.2), - max_new_tokens = arg_or("max_new_tokens", 1000L), - max_cache_len = arg_or("max_cache_len", NULL), - use_autocast = use_autocast) + exaggeration = arg_or("exaggeration", 0.5), + cfg_weight = arg_or("cfg_weight", 0.5), + temperature = arg_or("temperature", 0.8), + top_p = arg_or("top_p", 1.0), + min_p = arg_or("min_p", 0.05), + traced = isTRUE(arg_or("traced", FALSE)), + backend = arg_or("backend", "r"), + top_k = arg_or("top_k", 1000L), + repetition_penalty = arg_or("repetition_penalty", 1.2), + max_new_tokens = arg_or("max_new_tokens", 1000L), + max_cache_len = arg_or("max_cache_len", NULL), + use_autocast = use_autocast) eos[i] <- isTRUE(attr(tokens, "eos_found")) token_vecs[[i]] <- as.integer(drop_invalid_tokens(tokens)) if (!eos[i]) { warning("Text ", i, " hit the token cap without end-of-speech (", - length(token_vecs[[i]]), " tokens). Output may be garbage.", - call. = FALSE) + length(token_vecs[[i]]), " tokens). Output may be garbage.", + call. = FALSE) } } @@ -687,7 +686,7 @@ generate_batch <- function (model, texts, voice, ...) { for (i in which(empty)) { warning("No valid speech tokens for text ", i, call. = FALSE) results[[i]] <- list(audio = numeric(0), sample_rate = S3GEN_SR, - eos_found = eos[i], n_tokens = 0L, audio_sec = 0) + eos_found = eos[i], n_tokens = 0L, audio_sec = 0) } live <- which(!empty) if (length(live) == 0) { @@ -699,17 +698,17 @@ generate_batch <- function (model, texts, voice, ...) { # vocoder t_max <- max(lens[live]) mat <- t(vapply(token_vecs[live], - function (v) c(v, rep(0L, t_max - length(v))), integer(t_max))) + function(v) c(v, rep(0L, t_max - length(v))), integer(t_max))) speech_tokens <- torch::torch_tensor(mat, dtype = torch::torch_long())$to(device = model$device) message("Synthesizing ", length(live), " waveforms in one batch...") torch::with_no_grad({ out <- model$s3gen$inference( - speech_tokens = speech_tokens, - ref_dict = voice$ref_dict, - finalize = TRUE, - speech_token_lens = lens[live] + speech_tokens = speech_tokens, + ref_dict = voice$ref_dict, + finalize = TRUE, + speech_token_lens = lens[live] ) }) wavs <- out[[1]]$cpu() @@ -720,11 +719,11 @@ generate_batch <- function (model, texts, voice, ...) { n_samples <- lens[i] * 2L * 480L audio <- as.numeric(wavs[k, 1:min(n_samples, wavs$size(2))]) results[[i]] <- list( - audio = audio, - sample_rate = S3GEN_SR, - eos_found = eos[i], - n_tokens = lens[i], - audio_sec = length(audio) / S3GEN_SR + audio = audio, + sample_rate = S3GEN_SR, + eos_found = eos[i], + n_tokens = lens[i], + audio_sec = length(audio) / S3GEN_SR ) } results From 39e1a0f260f9e10e642dfcc3b575364d967f6407 Mon Sep 17 00:00:00 2001 From: TroyHernandez Date: Fri, 12 Jun 2026 15:37:52 -0500 Subject: [PATCH 3/5] Bump version to 0.1.0.8 --- DESCRIPTION | 2 +- NEWS.md | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 34bfb9b..6966fee 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: chatterbox Title: Text-to-Speech Using Chatterbox TTS Engine -Version: 0.1.0.7 +Version: 0.1.0.8 Authors@R: c(person("Troy", "Hernandez", role = c("aut", "cre"), email = "troy@cornball.ai", diff --git a/NEWS.md b/NEWS.md index 40356bd..8b20134 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# chatterbox 0.1.0.8 (development) + +- New `generate_batch()`: several texts, one batched S3Gen synthesis + pass; padded rows validated to match single runs (mel diff <= 0.005). +- `s3gen$inference()` accepts ragged batches via `speech_token_lens`. + # chatterbox 0.1.0.7 (development) - New `voice_convert()`: speech-to-speech voice conversion (port of From 949fbdc1dcd06efdf10d1f67e2d3f1b883e95f45 Mon Sep 17 00:00:00 2001 From: TroyHernandez Date: Fri, 12 Jun 2026 16:00:23 -0500 Subject: [PATCH 4/5] Review fix: zero padded mel tails before the vocoder At padded positions the masked estimator leaves dphi = 0, so the generated-region tail was raw initial Gaussian noise; HiFi-GAN's convolutional context smeared it into the end of shorter rows. Tail is now zeroed when speech_token_lens is given (matching the zero padding a single run's convs see past sequence end). Also documents that traced/ autocast apply to the T3 stage only in generate_batch(). --- R/s3gen.R | 14 ++++++++++++++ R/tts.R | 4 +++- man/generate_batch.Rd | 4 +++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/R/s3gen.R b/R/s3gen.R index eaa36cc..a848e6f 100644 --- a/R/s3gen.R +++ b/R/s3gen.R @@ -1113,6 +1113,20 @@ s3gen <- torch::nn_module( ) output_mels <- result[[1]] + # Ragged batch: zero each row's padded mel tail. At padded + # positions the masked estimator leaves dphi = 0, so the "mel" + # there is the untouched initial Gaussian noise - the vocoder's + # convolutional context would smear it into the end of shorter + # utterances. Zeros match what a single run's convs see (zero + # padding past the sequence end). + if (!is.null(speech_token_lens)) { + gen_mel_lens <- (speech_token_len * 2L)$to(dtype = torch::torch_long()) + gen_mask <- (!make_pad_mask(gen_mel_lens, + max_len = output_mels$size(3)))$unsqueeze(2)$to( + dtype = output_mels$dtype, device = output_mels$device) + output_mels <- output_mels * gen_mask + } + # Vocoder (mel -> audio) if (!skip_vocoder && !is.null(self$mel2wav)) { vocoder_result <- self$mel2wav$inference(output_mels) diff --git a/R/tts.R b/R/tts.R index eba0518..ae21b36 100644 --- a/R/tts.R +++ b/R/tts.R @@ -620,7 +620,9 @@ generate <- function(model, text, voice, exaggeration = 0.5, #' @param ... Arguments passed through to the T3 stage, as in #' \code{\link{generate}} (exaggeration, cfg_weight, temperature, #' top_p, min_p, backend, repetition_penalty, normalize_text, -#' max_new_tokens, max_cache_len) +#' max_new_tokens, max_cache_len). \code{traced} and \code{autocast} +#' affect the T3 stage only: the batched S3Gen synthesis always runs +#' eager float32 (traced CFM is fixed at batch 1). #' @return List with one \code{\link{generate}}-style result per text #' (audio, sample_rate, eos_found, n_tokens, audio_sec) #' @export diff --git a/man/generate_batch.Rd b/man/generate_batch.Rd index 3267725..688548f 100644 --- a/man/generate_batch.Rd +++ b/man/generate_batch.Rd @@ -15,7 +15,9 @@ generate_batch(model, texts, voice, ...) \item{...}{Arguments passed through to the T3 stage, as in \code{\link{generate}} (exaggeration, cfg_weight, temperature, top_p, min_p, backend, repetition_penalty, normalize_text, -max_new_tokens, max_cache_len)} +max_new_tokens, max_cache_len). \code{traced} and \code{autocast} +affect the T3 stage only: the batched S3Gen synthesis always runs +eager float32 (traced CFM is fixed at batch 1).} } \value{ List with one \code{\link{generate}}-style result per text From 7d93a63b8ecb9974ffaa54ad8e66612db8c9151e Mon Sep 17 00:00:00 2001 From: TroyHernandez Date: Fri, 12 Jun 2026 16:15:48 -0500 Subject: [PATCH 5/5] generate_batch(): reject unknown ... arguments --- R/tts.R | 9 +++++++++ inst/tinytest/test_batch_masks.R | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/R/tts.R b/R/tts.R index ae21b36..e2f23aa 100644 --- a/R/tts.R +++ b/R/tts.R @@ -638,6 +638,15 @@ generate_batch <- function(model, texts, voice, ...) { } args <- list(...) + known <- c("exaggeration", "cfg_weight", "temperature", "top_p", + "min_p", "traced", "backend", "top_k", "repetition_penalty", + "normalize_text", "max_new_tokens", "max_cache_len", + "autocast") + unknown <- setdiff(names(args), known) + if (length(unknown) > 0) { + stop("Unsupported arguments: ", paste(unknown, collapse = ", "), + ". generate_batch() accepts: ", paste(known, collapse = ", ")) + } arg_or <- function(name, default) args[[name]] %||% default normalize_text <- isTRUE(arg_or("normalize_text", TRUE)) use_autocast <- isTRUE(arg_or("autocast", FALSE)) && diff --git a/inst/tinytest/test_batch_masks.R b/inst/tinytest/test_batch_masks.R index 6a84d5c..0ccb763 100644 --- a/inst/tinytest/test_batch_masks.R +++ b/inst/tinytest/test_batch_masks.R @@ -15,3 +15,12 @@ if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) { expect_identical(dim(total), 3L) expect_identical(as.integer(total), c(310L, 330L, 275L)) } + +# generate_batch rejects unknown arguments instead of swallowing them +if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) { + fake <- structure(list(loaded = TRUE, turbo = FALSE), + class = "chatterbox") + expect_error( + chatterbox::generate_batch(fake, "hi", "v.wav", skip_vocoder = TRUE), + "Unsupported arguments: skip_vocoder") +}