Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: chatterbox
Title: Text-to-Speech Using Chatterbox TTS Engine
Version: 0.1.0.7
Version: 0.1.0.8
Authors@R:
c(person("Troy", "Hernandez", role = c("aut", "cre"),
email = "troy@cornball.ai",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export(decode_tokens)
export(download_chatterbox_models)
export(download_chatterbox_turbo_models)
export(generate)
export(generate_batch)
export(get_model_paths)
export(get_turbo_model_paths)
export(integrated_loudness)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# chatterbox 0.1.0.8 (development)

- New `generate_batch()`: several texts, one batched S3Gen synthesis
pass; padded rows validated to match single runs (mel diff <= 0.005).
- `s3gen$inference()` accepts ragged batches via `speech_token_lens`.

# chatterbox 0.1.0.7 (development)

- New `voice_convert()`: speech-to-speech voice conversion (port of
Expand Down
6 changes: 6 additions & 0 deletions R/conformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,12 @@ upsample_conformer_encoder_full <- torch::nn_module(
pos_emb <- embed_result[[2]]
masks <- embed_result[[3]]

# Zero the padded tail first: embed's LayerNorm makes pad frames
# nonzero, and the pre-lookahead conv reads rightward, so a padded
# row would otherwise differ from the same row run alone (where the
# conv right-pads true zeros)
xs <- xs$masked_fill(!masks$transpose(2L, 3L), 0.0)

# Pre-lookahead layer
xs <- self$pre_lookahead_layer$forward(xs)

Expand Down
138 changes: 102 additions & 36 deletions R/s3gen.R
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,16 @@ cfm_estimator <- torch::nn_module(
},

#' Build the key-padding attention mask for (non-streaming) inference.
#'
#' @param x Unused; kept so the method signature stays unchanged.
#' @param mask Float validity mask, (B, 1, T), where values above 0.5 mark
#'   valid frames; may be NULL.
#' @return NULL when `mask` is NULL or contains no padded position (no
#'   element below 0.5); otherwise a boolean (B, 1, 1, T) tensor where
#'   TRUE = may attend.
compute_attn_mask = function(x, mask) {
  # No causal masking in (non-streaming) inference. Single utterances
  # have no padding, so the transformers run unmasked.
  if (is.null(mask)) {
    return(NULL)
  }

  # Padded batches need a key-padding mask or the tail bleeds into valid
  # frames through bidirectional attention. Only build one when at least
  # one position is actually padded.
  has_padding <- as.logical((mask < 0.5)$any()$item())
  if (!has_padding) {
    return(NULL)
  }

  # (B, 1, T) float -> (B, 1, 1, T) bool, TRUE = may attend;
  # broadcasts over heads and query positions in SDPA
  (mask > 0.5)$unsqueeze(2)
}
)

Expand Down Expand Up @@ -610,6 +616,11 @@ causal_cfm <- torch::nn_module(
# fresh noise. Avoids advancing the RNG and a wasted allocation.
if (!use_meanflow && is.null(noised_mels) && !is.null(self$rand_noise)) {
z <- self$rand_noise[,, 1:seq_len]$to(device = device)$to(dtype = mu$dtype) * temperature
if (mu$size(1) > 1L) {
# Batched: every row gets the same fixed noise, so row i
# matches what a single-utterance run would draw
z <- z$expand(c(mu$size(1), -1L, -1L))$contiguous()
}
} else {
z <- torch::torch_randn_like(mu)
if (!is.null(noised_mels)) {
Expand Down Expand Up @@ -663,6 +674,12 @@ causal_cfm <- torch::nn_module(
device <- x$device
dtype <- x$dtype

# Traced graph is fixed at CFG batch 2; fall back for real batches
if (traced && batch_size > 1L) {
warning("traced CFM supports batch size 1, falling back to non-traced")
traced <- FALSE
}

# For traced mode, pad to fixed max length
if (traced && seq_len <= CFM_MAX_SEQ_LEN) {
traced_est <- self$get_traced_estimator(device)
Expand Down Expand Up @@ -699,40 +716,45 @@ causal_cfm <- torch::nn_module(
work_len <- if (traced &&
seq_len <= CFM_MAX_SEQ_LEN) CFM_MAX_SEQ_LEN else seq_len

# Pre-allocate tensors for CFG (batch size 2)
x_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype)
mask_in <- torch::torch_zeros(c(2L, 1L, work_len), device = device, dtype = dtype)
mu_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype)
t_in <- torch::torch_zeros(2L, device = device, dtype = dtype)
spks_in <- torch::torch_zeros(c(2L, 80L), device = device, dtype = dtype)
cond_in <- torch::torch_zeros(c(2L, 80L, work_len), device = device, dtype = dtype)
# Pre-allocate tensors for CFG: rows 1:B conditional, (B+1):2B
# unconditional (zero mu/spks/cond)
B <- batch_size
x_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype)
mask_in <- torch::torch_zeros(c(2L * B, 1L, work_len), device = device, dtype = dtype)
mu_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype)
t_in <- torch::torch_zeros(2L * B, device = device, dtype = dtype)
spks_in <- torch::torch_zeros(c(2L * B, 80L), device = device, dtype = dtype)
cond_in <- torch::torch_zeros(c(2L * B, 80L, work_len), device = device, dtype = dtype)

for (step in 2:length(t_span)) {
# Classifier-Free Guidance: conditional and unconditional paths
# Use padded tensors in traced mode
if (traced && seq_len <= CFM_MAX_SEQ_LEN) {
x_in[1:2,,] <- x_padded
mask_in[1:2,,] <- mask_padded
mu_in[1,,] <- mu_padded
cond_in[1,,] <- cond_padded
x_in[1:B,,] <- x_padded
x_in[(B + 1):(2 * B),,] <- x_padded
mask_in[1:B,,] <- mask_padded
mask_in[(B + 1):(2 * B),,] <- mask_padded
mu_in[1:B,,] <- mu_padded
cond_in[1:B,,] <- cond_padded
} else {
x_in[1:2,,] <- x
mask_in[1:2,,] <- mask
mu_in[1,,] <- mu
cond_in[1,,] <- cond
x_in[1:B,,] <- x
x_in[(B + 1):(2 * B),,] <- x
mask_in[1:B,,] <- mask
mask_in[(B + 1):(2 * B),,] <- mask
mu_in[1:B,,] <- mu
cond_in[1:B,,] <- cond
}
# mu_in[2] stays zero (unconditional)
t_in[1:2] <- t
spks_in[1,] <- spks
# spks_in[2] stays zero
# cond_in[2] stays zero
# mu_in[(B+1):(2B)] stays zero (unconditional)
t_in[1:(2 * B)] <- t
spks_in[1:B,] <- spks
# spks_in/cond_in upper half stays zero

# Forward through estimator (traced or normal)
dphi_dt <- estimator_fn(x_in, mask_in, mu_in, t_in, spks_in, cond_in)

# CFG combination
dphi_cond <- dphi_dt[1,,]$unsqueeze(1)
dphi_uncond <- dphi_dt[2,,]$unsqueeze(1)
dphi_cond <- dphi_dt[1:B,,, drop = FALSE]
dphi_uncond <- dphi_dt[(B + 1):(2 * B),,, drop = FALSE]
dphi_dt <- (1.0 + self$inference_cfg_rate) * dphi_cond - self$inference_cfg_rate * dphi_uncond

# Euler step (only on actual sequence length)
Expand Down Expand Up @@ -819,14 +841,32 @@ causal_masked_diff_xvec <- torch::nn_module(
meanflow = NULL
) {
device <- token$device
B <- token$size(1)

# Normalize and project speaker embedding
embedding <- torch::nnf_normalize(embedding, dim = 2)
embedding <- self$spk_embed_affine_layer$forward(embedding)

# Concatenate prompt and speech tokens
# Batched generation shares one voice: expand the single-row
# conditioning to every utterance row
if (B > 1L) {
if (embedding$size(1) == 1L) {
embedding <- embedding$expand(c(B, -1L))$contiguous()
}
if (prompt_token$size(1) == 1L) {
prompt_token <- prompt_token$expand(c(B, -1L))$contiguous()
}
if (prompt_feat$size(1) == 1L) {
prompt_feat <- prompt_feat$expand(c(B, -1L, -1L))$contiguous()
}
}

# Concatenate prompt and speech tokens. Lengths flatten to (B):
# ref_dict's prompt_token_len is (1, 1), which would broadcast a
# (B) token_len up to (1, B) and collapse the batch in make_pad_mask
token <- torch::torch_cat(list(prompt_token, token), dim = 2)
token_len <- prompt_token_len + token_len
token_len <- (prompt_token_len$view(-1) + token_len$view(-1))$to(
dtype = torch::torch_long())

# Create mask
mask <- (!make_pad_mask(token_len))$unsqueeze(3)$to(dtype = embedding$dtype, device = device)
Expand Down Expand Up @@ -858,15 +898,18 @@ causal_masked_diff_xvec <- torch::nn_module(
h <- self$encoder_proj$forward(h)

# Prepare conditioning: prompt mel fills the first mel_len1 frames
conds <- torch::torch_zeros(c(1, mel_len1 + mel_len2, self$output_size),
conds <- torch::torch_zeros(c(B, mel_len1 + mel_len2, self$output_size),
device = device, dtype = h$dtype)
if (mel_len1 > 0) {
conds[1, 1:mel_len1,] <- prompt_feat[1,,]
conds[, 1:mel_len1,] <- prompt_feat
}
conds <- conds$transpose(2, 3)

# Create mask for decoder
dec_mask <- torch::torch_ones(c(1, 1, mel_len1 + mel_len2), device = device, dtype = h$dtype)
# Create mask for decoder: per-row valid mel length (ragged batches
# have padded tails). Single-row stays all-ones as before.
mel_lens <- mel_len1 +
(token_len - prompt_token_len$view(-1)) * self$token_mel_ratio
dec_mask <- (!make_pad_mask(mel_lens, max_len = mel_len1 + mel_len2))$unsqueeze(2)$to(dtype = h$dtype, device = device)

# Run decoder
h <- h$transpose(2, 3)
Expand Down Expand Up @@ -1004,7 +1047,10 @@ s3gen <- torch::nn_module(
},

#' Run inference (tokens -> mel -> audio;
#' skip_vocoder = TRUE stops at the mel)
#' skip_vocoder = TRUE stops at the mel).
#' speech_tokens may be (B, T) with
#' speech_token_lens giving per-row valid
#' lengths for ragged batches.
inference = function(
speech_tokens,
ref_wav = NULL,
Expand All @@ -1013,7 +1059,8 @@ s3gen <- torch::nn_module(
finalize = TRUE,
traced = FALSE,
n_cfm_timesteps = NULL,
skip_vocoder = FALSE
skip_vocoder = FALSE,
speech_token_lens = NULL
) {
# Get reference dict
if (is.null(ref_dict)) {
Expand All @@ -1030,7 +1077,12 @@ s3gen <- torch::nn_module(
speech_tokens <- speech_tokens$unsqueeze(1)
}
speech_tokens <- speech_tokens$to(device = device)
speech_token_len <- torch::torch_tensor(speech_tokens$size(2), device = device)
speech_token_len <- if (is.null(speech_token_lens)) {
torch::torch_tensor(rep(speech_tokens$size(2), speech_tokens$size(1)),
device = device)
} else {
torch::torch_tensor(as.integer(speech_token_lens), device = device)
}

# Determine timesteps and noise for meanflow
n_steps <- n_cfm_timesteps
Expand All @@ -1039,7 +1091,7 @@ s3gen <- torch::nn_module(
if (is.null(n_steps)) n_steps <- 2L
# MeanFlow uses random noise per call instead of pre-computed buffer
noised_mels <- torch::torch_randn(
c(1L, 80L, as.integer(speech_tokens$size(2)) * 2L),
c(speech_tokens$size(1), 80L, as.integer(speech_tokens$size(2)) * 2L),
dtype = torch::torch_float32(), device = device
)
}
Expand All @@ -1061,6 +1113,20 @@ s3gen <- torch::nn_module(
)
output_mels <- result[[1]]

# Ragged batch: zero each row's padded mel tail. At padded
# positions the masked estimator leaves dphi = 0, so the "mel"
# there is the untouched initial Gaussian noise - the vocoder's
# convolutional context would smear it into the end of shorter
# utterances. Zeros match what a single run's convs see (zero
# padding past the sequence end).
if (!is.null(speech_token_lens)) {
gen_mel_lens <- (speech_token_len * 2L)$to(dtype = torch::torch_long())
gen_mask <- (!make_pad_mask(gen_mel_lens,
max_len = output_mels$size(3)))$unsqueeze(2)$to(
dtype = output_mels$dtype, device = output_mels$device)
output_mels <- output_mels * gen_mask
}

# Vocoder (mel -> audio)
if (!skip_vocoder && !is.null(self$mel2wav)) {
vocoder_result <- self$mel2wav$inference(output_mels)
Expand Down
Loading
Loading