We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 623a9cd commit aac7eabCopy full SHA for aac7eab
3 files changed
flue-core/src/flux/sampling.rs
@@ -8,8 +8,8 @@ pub fn get_noise(
8
width: usize,
9
device: &Device,
10
) -> Result<Tensor> {
11
- let height = (height + 15) / 16 * 2;
12
- let width = (width + 15) / 16 * 2;
+ let height = height.div_ceil(16) * 2;
+ let width = width.div_ceil(16) * 2;
13
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
14
}
15
@@ -86,8 +86,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
86
87
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
88
let (b, _h_w, c_ph_pw) = xs.dims3()?;
89
- let height = (height + 15) / 16;
90
- let width = (width + 15) / 16;
+ let height = height.div_ceil(16);
+ let width = width.div_ceil(16);
91
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
92
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
93
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
flue-server/Cargo.toml
@@ -11,7 +11,7 @@ license.workspace = true
homepage.workspace = true
[dependencies]
-flue-core = { version = "0.1.0", path = "../flue-core" }
+flue-core = { path = "../flue-core" }
anyhow = { workspace = true }
16
axum = { workspace = true }
17
base64 = { workspace = true }
@@ -21,12 +21,3 @@ image = { workspace = true }
21
serde = { workspace = true }
22
serde_json = { workspace = true }
23
tokio = { workspace = true }
24
-
25
-[features]
26
-cuda = ["flue-core/cuda"]
27
-cudnn = ["flue-core/cudnn"]
28
-metal = ["flue-core/metal"]
29
-flash-attn-v2 = ["cuda", "flue-core/flash-attn-v2"]
30
-flash-attn-v3 = ["cuda", "flue-core/flash-attn-v3"]
31
-accelerate = ["flue-core/accelerate"]
32
-mkl = ["flue-core/mkl"]
flue-server/src/main.rs
@@ -46,6 +46,7 @@ fn image_to_base64_png(img: &DynamicImage) -> Result<String> {
46
#[derive(Serialize)]
47
struct GenerationResponse {
48
image: String,
49
+ gen_time: f64, // Time in seconds
50
51
52
// Application state containing the preloaded models and device settings.
@@ -57,7 +58,11 @@ async fn generate_image_handler(
57
58
Json(req): Json<GenerationRequest>,
59
) -> impl IntoResponse {
60
match generate_image(req, &state).await {
- Ok(img_base64) => Json(GenerationResponse { image: img_base64 }).into_response(),
61
+ Ok((img_base64, gen_time)) => Json(GenerationResponse {
62
+ image: img_base64,
63
+ gen_time,
64
+ })
65
+ .into_response(),
66
Err(e) => {
67
eprintln!("Error generating image: {:?}", e);
68
(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {:?}", e)).into_response()
@@ -66,9 +71,13 @@ async fn generate_image_handler(
71
72
73
/// This function uses the preloaded models from `state` to generate an image (base64).
69
-async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<String> {
74
+/// Returns both the base64 image and the generation time in seconds.
75
+async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<(String, f64)> {
76
+ let start_time = std::time::Instant::now();
70
77
let image = state.0.run(params)?;
- image_to_base64_png(&image)
78
+ let gen_time = start_time.elapsed().as_secs_f64();
79
+ let base64_image = image_to_base64_png(&image)?;
80
+ Ok((base64_image, gen_time))
81
82
83
#[tokio::main]
0 commit comments