
Add acceleration support for FLUX models #1066

Open
oreasono opened this issue Aug 6, 2024 · 11 comments
Comments

@oreasono

oreasono commented Aug 6, 2024

🚀 The feature, motivation and pitch

https://blackforestlabs.ai/#get-flux
FLUX models are the new SOTA open-source text-to-image models. I am wondering whether this slightly different architecture can still benefit from the onediff framework.

Alternatives

The diffusers library has FluxPipeline support: https://github.com/black-forest-labs/flux?tab=readme-ov-file#diffusers-integration

Additional context

No response

@Swarzox

Swarzox commented Aug 6, 2024

Bump, I am also looking for this.

@strint
Collaborator

strint commented Aug 6, 2024

We are working on FLUX-related optimization.

If you want to try it now, you can use this; it should work: https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler

from onediff.infer_compiler import compile

# module is the model you want to compile
options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)
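
For reference, a minimal end-to-end sketch of how this snippet could be wired into the diffusers FluxPipeline (the checkpoint id and generation parameters are taken from the test script later in this thread, not from onediff documentation):

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from onediff.infer_compiler import compile

# Load the FLUX pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Compile only the transformer with the nexfort backend.
options = '{"mode": "O3"}'  # mode can be O2 or O3
pipe.transformer = compile(pipe.transformer, backend="nexfort", options=options)

# The first call triggers compilation; subsequent calls reuse the optimized graph.
image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
image.save("flux-schnell.png")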

@Swarzox

Swarzox commented Aug 6, 2024

We are working on FLUX-related optimization.

If you want to try it now, you can use this; it should work: https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler

from onediff.infer_compiler import compile

# module is the model you want to compile
options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

I keep having this error: RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16

@strint
Collaborator

strint commented Aug 6, 2024

I keep having this error: RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16

Update nexfort, and then set this environment variable:

    os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
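
A minimal sketch of where that variable goes relative to the compile call (mirroring the snippet above; flux_pipe is a placeholder for your FluxPipeline):

import os

# Disable the fused timestep-embedding kernel before compiling; it is the
# source of the bfloat16 timesteps error above.
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'

from onediff.infer_compiler import compile

flux_pipe = ...  # your FluxPipeline
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options='{"mode": "O3"}')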

@antferdom

antferdom commented Aug 13, 2024

Hi, I haven't been able to run the FLUX diffusers pipeline with the nexfort compiler backend. I'm using the settings discussed in this thread, and also based on the official configurations. @strint any suggestions?

Environment

Torch version: 2.3.0
CUDA version: 12.1.0
GPU: NVIDIA A100-SXM4-80GB
Nexfort version: 0.1.dev264
Onediff/onediffx: 1.2.1.dev18+g6b53a83b (built from source for dev)
Diffusers version: 0.30.0

Test Script

# %%
import os
from onediffx import compile_pipe
from onediff.infer_compiler import compile
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

torch.set_default_device('cuda')
# RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16
# ref: https://github.com/siliconflow/onediff/issues/1066#issuecomment-2271523799
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
# %%
model_id: str = "black-forest-labs/FLUX.1-schnell"

pipe = FluxPipeline.from_pretrained(model_id, 
                                    torch_dtype=torch.bfloat16, 
                                    )
pipe.to("cuda")
# %%
"""
options = '{"mode": "O3"}' 
pipe.transformer = compile(pipe.transformer, backend="nexfort", options=options)
"""
# compiler optimization options: '{"mode": "O3", "memory_format": "channels_last"}'
options = '{"mode": "O2"}'
pipe = compile_pipe(pipe, backend="nexfort", options=options, fuse_qkv_projections=True)
# %%
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-schnell.png")
# %%
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3/ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")


get_flops_achieved(lambda: pipe(
                "A tree in the forest",
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(0)
))
# %%
def benchmark_torch_function(iters, f, *args, **kwargs):
    f(*args, **kwargs)
    f(*args, **kwargs)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time has a resolution of 0.5 microseconds:
    # but returns milliseconds, so we need to multiply it to increase resolution
    return start_event.elapsed_time(end_event) * 1000 / iters, *f(*args, **kwargs)
# %%
with torch.inference_mode():
    time_nextfort_flux_fwd, _ = benchmark_torch_function(10,
                                                        pipe,
                                                        "A tree in the forest",
                                                        guidance_scale=0.0,
                                                        num_inference_steps=4,
                                                        max_sequence_length=256,
                                                        )
print(time_nextfort_flux_fwd)

Error Result

   8913         del arg704_1
   8914         del arg705_1
   8915         # Source Nodes: [hidden_states_600, key_98, query_98], Original ATen: [aten._scaled_dot_product_flash_attention, aten._to_copy]
   8916         buf960 = extern_kernels.nexfort_cuda_cudnn_scaled_dot_product_attention(buf958, buf959, reinterpret_tensor(buf957, (1, 24, 4352, 128), (13369344, 128, 3072, 1), 0), dropout_p=0.0, is_causal=False, scale=0.08838834764831843, attn_mask=None)
-> 8917         assert_size_stride(buf960, (1, 24, 4352, 128), (13369344, 557056, 128, 1))
   8918         del buf957
   8919         # Source Nodes: [emb_45], Original ATen: [nexfort_inductor.linear_epilogue]
   8920         buf961 = extern_kernels.nexfort_cuda_constrained_linear_epilogue(buf8, reinterpret_tensor(arg708_1, (3072, 9216), (1, 3072), 0), reinterpret_tensor(arg709_1, (1, 9216), (0, 1), 0), epilogue_ops=['add'], epilogue_tensor_args=[True], epilogue_scalar_args=[None], with_bias=True)

AssertionError: expected size 24==24, stride 128==557056 at dim=1

@strint
Collaborator

strint commented Aug 13, 2024

Update nexfort to the latest version and set:

os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
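
A sketch of how both workarounds fit around the onediffx compile_pipe call used in the test script above (pipe is a placeholder for the loaded FluxPipeline):

import os

# Set both workarounds before the pipeline is compiled, as in the test script above.
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'

from onediffx import compile_pipe

pipe = ...  # the FluxPipeline loaded as in the test script above
pipe = compile_pipe(pipe, backend="nexfort", options='{"mode": "O2"}', fuse_qkv_projections=True)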

@antferdom

@antferdom

antferdom commented Aug 13, 2024

@strint would it be possible to use fp8? e.g.

@strint
Collaborator

strint commented Aug 13, 2024

@strint would it be possible to use fp8? e.g.

Not yet. We are doing some work on flux.

@badhri-suresh

badhri-suresh commented Aug 31, 2024

from onediff.infer_compiler import compile


options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

@strint Will this also work for the FLUX Dev model?

@strint
Collaborator

strint commented Aug 31, 2024

from onediff.infer_compiler import compile


options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

@strint Will this also work for the FLUX Dev model?

Yes
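
For the dev checkpoint, a minimal sketch under the assumption that the same compile call applies unchanged (the guidance scale and step count are typical FLUX.1-dev settings, not values from this thread):

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from onediff.infer_compiler import compile

# Same pattern as for FLUX.1-schnell, only with the dev checkpoint.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

options = '{"mode": "O3"}'  # mode can be O2 or O3
pipe.transformer = compile(pipe.transformer, backend="nexfort", options=options)

# Unlike schnell, dev uses a non-zero guidance scale and more steps;
# adjust these to your needs.
image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("flux-dev.png")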

@Ph0rk0z

Ph0rk0z commented Oct 12, 2024

Has anyone gotten it working in ComfyUI?

backend='nexfort' raised:
CompilationError: at 47:12:
tmp24 = tmp21 * tmp7
tmp25 = tmp24 + tmp9
tmp26 = 47 + ((-1)*x2)
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp27 * tmp7
tmp29 = 47.0
tmp30 = tmp29 - tmp28
tmp31 = tl.where(tmp23, tmp25, tmp30)
tmp32 = tmp31.to(tl.float32)
tmp33 = tl.where(tmp19, tmp32, tmp9)
tmp34 = tl.where(tmp2, tmp17, tmp33)
tmp36 = tmp35.to(tl.float32)

import torch
from onediff.infer_compiler import compile
import os

class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                             "backend": (["inductor", "cudagraphs"],),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
        os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
        options = '{"mode": "O3"}'  # mode can be O2 or O3
        m = model.clone()
        #m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        m.add_object_patch("diffusion_model", compile(m.get_model_object("diffusion_model"), backend="nexfort", options=options))
        return (m, )

NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
}
