diff --git a/flue-core/Cargo.toml b/flue-core/Cargo.toml index d2d12da..32723e4 100644 --- a/flue-core/Cargo.toml +++ b/flue-core/Cargo.toml @@ -30,7 +30,7 @@ tokio = { workspace = true } serde_plain = { workspace = true } [features] -cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +cuda = ["candle-core/cuda", "candle-nn/cuda"] cudnn = ["candle-core/cudnn"] metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] flash-attn = ["candle-flash-attn"] diff --git a/flue-core/src/flux/model.rs b/flue-core/src/flux/model.rs index fbd59bd..36ed295 100644 --- a/flue-core/src/flux/model.rs +++ b/flue-core/src/flux/model.rs @@ -563,7 +563,7 @@ pub struct Flux { time_in: MlpEmbedder, vector_in: MlpEmbedder, guidance_in: Option, - pe_embedder: EmbedNd, + pub pe_embedder: EmbedNd, double_blocks: Vec, single_blocks: Vec, final_layer: LastLayer, @@ -614,9 +614,8 @@ impl Flux { pub fn forward( &self, img: &Tensor, - img_ids: &Tensor, txt: &Tensor, - txt_ids: &Tensor, + pe: &Tensor, timesteps: &Tensor, y: &Tensor, guidance: Option<&Tensor>, @@ -628,10 +627,6 @@ impl Flux { candle_core::bail!("unexpected shape for img {:?}", img.shape()) } let dtype = img.dtype(); - let pe = { - let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; - ids.apply(&self.pe_embedder)? - }; let mut txt = txt.apply(&self.txt_in)?; let mut img = img.apply(&self.img_in)?; let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; @@ -645,12 +640,12 @@ impl Flux { // Double blocks for block in self.double_blocks.iter() { - (img, txt) = block.forward(&img, &txt, &vec_, &pe)? + (img, txt) = block.forward(&img, &txt, &vec_, pe)? } // Single blocks let mut img = Tensor::cat(&[&txt, &img], 1)?; for block in self.single_blocks.iter() { - img = block.forward(&img, &vec_, &pe)?; + img = block.forward(&img, &vec_, pe)?; } let img = img.i((.., txt.dim(1)?..))?; self.final_layer.forward(&img, &vec_) diff --git a/flue-core/src/flux/sampling.rs b/flue-core/src/flux/sampling.rs index 540ca57..6446941 100644 --- a/flue-core/src/flux/sampling.rs +++ b/flue-core/src/flux/sampling.rs @@ -107,14 +107,21 @@ pub fn denoise( let b_sz = img.dim(0)?; let dev = img.device(); let guidance = Tensor::full(guidance as f32, b_sz, dev)?; + let t_vec_one = Tensor::full(1f32, b_sz, dev)?; let mut img = img.clone(); + + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&model.pe_embedder)? + }; + for window in timesteps.windows(2) { let (t_curr, t_prev) = match window { [a, b] => (a, b), _ => continue, }; - let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?; - let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?; + let t_vec = (&t_vec_one * { *t_curr })?; + let pred = model.forward(&img, txt, &pe, &t_vec, vec_, Some(&guidance))?; img = (img + pred * (t_prev - t_curr))? } Ok(img) diff --git a/flue-flash-attn-v2/cutlass b/flue-flash-attn-v2/cutlass index afa1772..62750a2 160000 --- a/flue-flash-attn-v2/cutlass +++ b/flue-flash-attn-v2/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 62750a2b75c802660e4894434dc55e839f322277 diff --git a/flue-flash-attn-v3/cutlass b/flue-flash-attn-v3/cutlass index 4c42f73..62750a2 160000 --- a/flue-flash-attn-v3/cutlass +++ b/flue-flash-attn-v3/cutlass @@ -1 +1 @@ -Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d +Subproject commit 62750a2b75c802660e4894434dc55e839f322277