Skip to content

Commit aac7eab

Browse files
authored
Cleanup and generation time (#18)
* Cleanup and generation time
* Rustfmt / Clippy fixes
1 parent 623a9cd commit aac7eab

3 files changed

Lines changed: 17 additions & 17 deletions

File tree

flue-core/src/flux/sampling.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ pub fn get_noise(
88
width: usize,
99
device: &Device,
1010
) -> Result<Tensor> {
11-
let height = (height + 15) / 16 * 2;
12-
let width = (width + 15) / 16 * 2;
11+
let height = height.div_ceil(16) * 2;
12+
let width = width.div_ceil(16) * 2;
1313
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
1414
}
1515

@@ -86,8 +86,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
8686

8787
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
8888
let (b, _h_w, c_ph_pw) = xs.dims3()?;
89-
let height = (height + 15) / 16;
90-
let width = (width + 15) / 16;
89+
let height = height.div_ceil(16);
90+
let width = width.div_ceil(16);
9191
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
9292
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
9393
.reshape((b, c_ph_pw / 4, height * 2, width * 2))

flue-server/Cargo.toml

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ license.workspace = true
1111
homepage.workspace = true
1212

1313
[dependencies]
14-
flue-core = { version = "0.1.0", path = "../flue-core" }
14+
flue-core = { path = "../flue-core" }
1515
anyhow = { workspace = true }
1616
axum = { workspace = true }
1717
base64 = { workspace = true }
@@ -21,12 +21,3 @@ image = { workspace = true }
2121
serde = { workspace = true }
2222
serde_json = { workspace = true }
2323
tokio = { workspace = true }
24-
25-
[features]
26-
cuda = ["flue-core/cuda"]
27-
cudnn = ["flue-core/cudnn"]
28-
metal = ["flue-core/metal"]
29-
flash-attn-v2 = ["cuda", "flue-core/flash-attn-v2"]
30-
flash-attn-v3 = ["cuda", "flue-core/flash-attn-v3"]
31-
accelerate = ["flue-core/accelerate"]
32-
mkl = ["flue-core/mkl"]

flue-server/src/main.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ fn image_to_base64_png(img: &DynamicImage) -> Result<String> {
4646
#[derive(Serialize)]
4747
struct GenerationResponse {
4848
image: String,
49+
gen_time: f64, // Time in seconds
4950
}
5051

5152
// Application state containing the preloaded models and device settings.
@@ -57,7 +58,11 @@ async fn generate_image_handler(
5758
Json(req): Json<GenerationRequest>,
5859
) -> impl IntoResponse {
5960
match generate_image(req, &state).await {
60-
Ok(img_base64) => Json(GenerationResponse { image: img_base64 }).into_response(),
61+
Ok((img_base64, gen_time)) => Json(GenerationResponse {
62+
image: img_base64,
63+
gen_time,
64+
})
65+
.into_response(),
6166
Err(e) => {
6267
eprintln!("Error generating image: {:?}", e);
6368
(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {:?}", e)).into_response()
@@ -66,9 +71,13 @@ async fn generate_image_handler(
6671
}
6772

6873
/// This function uses the preloaded models from `state` to generate an image (base64).
69-
async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<String> {
74+
/// Returns both the base64 image and the generation time in seconds.
75+
async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<(String, f64)> {
76+
let start_time = std::time::Instant::now();
7077
let image = state.0.run(params)?;
71-
image_to_base64_png(&image)
78+
let gen_time = start_time.elapsed().as_secs_f64();
79+
let base64_image = image_to_base64_png(&image)?;
80+
Ok((base64_image, gen_time))
7281
}
7382

7483
#[tokio::main]

0 commit comments

Comments
 (0)