
Commit 043df07

Committed Aug 19, 2024
Run ruff, setup initial text to image node
1 parent f4f5c46 commit 043df07

File tree: 15 files changed, 330 additions, 155 deletions

 

‎invokeai/app/invocations/flux_text_encoder.py

1 addition, 6 deletions

@@ -1,18 +1,13 @@
 import torch
-
-
-from einops import repeat
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 
 
 @invocation(
invokeai/app/invocations/flux_text_to_image.py

78 additions, 69 deletions

@@ -1,11 +1,6 @@
-from typing import Literal
-
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
 from PIL import Image
-from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import (
@@ -19,20 +14,11 @@
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
-from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
-
-TFluxModelKeys = Literal["flux-schnell"]
-FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
-
-
-class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
-    base_class = FluxTransformer2DModel
-
-
-class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
-    auto_class = AutoModelForTextEncoding
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -75,7 +61,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
         latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, latents)
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
@@ -86,42 +72,79 @@ def _run_diffusion(
         t5_embeddings: torch.Tensor,
     ):
         transformer_info = context.models.load(self.transformer.transformer)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
         with transformer_info as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
-
-            flux_pipeline_with_transformer = FluxPipeline(
-                scheduler=scheduler,
-                vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=transformer,
+            assert isinstance(transformer, Flux)
+
+            x = denoise(
+                model=transformer,
+                img=img,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                guidance=self.guidance,
             )
 
-            t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
-            clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
+        x = unpack(x.float(), self.height, self.width)
+
+        return x
 
-            latents = flux_pipeline_with_transformer(
-                height=self.height,
-                width=self.width,
-                num_inference_steps=self.num_steps,
-                guidance_scale=self.guidance,
-                generator=torch.Generator().manual_seed(self.seed),
-                prompt_embeds=t5_embeddings,
-                pooled_prompt_embeds=clip_embeddings,
-                output_type="latent",
-                return_dict=False,
-            )[0]
+    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Convert an input image in latent space to patches for diffusion.
 
-            assert isinstance(latents, torch.Tensor)
-            return latents
+        This implementation was extracted from:
+        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
+
+        Returns:
+            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
+        """
+        bs, c, h, w = latent_img.shape
+
+        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if img.shape[0] == 1 and bs > 1:
+            img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+        # Generate patch position ids.
+        img_ids = torch.zeros(h // 2, w // 2, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        return img, img_ids
 
     def _run_vae_decoding(
         self,
@@ -130,27 +153,13 @@ def _run_vae_decoding(
     ) -> Image.Image:
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
-            assert isinstance(vae, AutoencoderKL)
-
-            flux_pipeline_with_vae = FluxPipeline(
-                scheduler=None,
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)
 
-            latents = flux_pipeline_with_vae._unpack_latents(
-                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-            )
-            latents = (
-                latents / flux_pipeline_with_vae.vae.config.scaling_factor
-            ) + flux_pipeline_with_vae.vae.config.shift_factor
-            latents = latents.to(dtype=vae.dtype)
-            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
-
-            assert isinstance(image, Image.Image)
-            return image
+        img.clamp(-1, 1)
+        img = rearrange(img[0], "c h w -> h w c")
+        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
+
+        return img_pil
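The packing performed by _prepare_latent_img_patches above (and reversed by unpack from invokeai/backend/flux/sampling.py) is a plain einops reshape. A minimal standalone sketch of the round trip — not part of this commit, and assuming a 16-channel latent with an image size divisible by 16:

# Standalone sketch (not part of the commit): shape round trip for the 2x2 patch packing
# used by _prepare_latent_img_patches() and sampling.unpack(), assuming a 16-channel latent
# and an image size divisible by 16.
import torch
from einops import rearrange

height, width = 1024, 768                     # requested output size in pixels
h, w = 2 * (height // 16), 2 * (width // 16)  # latent spatial size, as produced by get_noise()
latent = torch.randn(1, 16, h, w)             # (b, c, h, w)

# Pack: pixel-unshuffle by a factor of 2 and flatten the spatial grid into a patch sequence.
patches = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(patches.shape)  # torch.Size([1, 3072, 64]) -> (h//2)*(w//2) patches of 16*2*2 values

# Unpack: the inverse pattern restores the original latent layout exactly.
restored = rearrange(patches, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h // 2, w=w // 2, ph=2, pw=2)
assert torch.equal(latent, restored)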

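The image conversion at the end of the new _run_vae_decoding maps the decoder output from [-1, 1] to 8-bit RGB. A small standalone sketch of that conversion, not part of the commit (note that Tensor.clamp() is not in-place, so the sketch uses clamp_):

# Standalone sketch (not part of the commit): converting a [-1, 1] CHW tensor to a PIL image,
# mirroring the tail of _run_vae_decoding(). Tensor.clamp() returns a new tensor, so the
# commit's bare img.clamp(-1, 1) does not modify img; clamp_() is used here instead.
import torch
from einops import rearrange
from PIL import Image

decoded = torch.rand(1, 3, 64, 64) * 2 - 1     # stand-in for vae.decode(latents), (b, c, h, w)
decoded.clamp_(-1, 1)
hwc = rearrange(decoded[0], "c h w -> h w c")  # drop the batch dim, channels last
img_pil = Image.fromarray((127.5 * (hwc + 1.0)).byte().cpu().numpy())
print(img_pil.size, img_pil.mode)              # (64, 64) RGB
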
‎invokeai/app/invocations/model.py

60 additions, 35 deletions

@@ -1,6 +1,6 @@
 import copy
 from time import sleep
-from typing import List, Optional, Literal, Dict
+from typing import Dict, List, Literal, Optional
 
 from pydantic import BaseModel, Field
 
@@ -12,10 +12,10 @@
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.model_records import ModelRecordChanges
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.app.services.model_records import ModelRecordChanges
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 
 
 class ModelIdentifierField(BaseModel):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
 
         return ModelIdentifierOutput(model=self.model)
 
-T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
+
+T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
 T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
     "base": {
-        "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
-        "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
-        "text_encoder_name": "FLUX.1-schnell_text_encoder_2",
-        "tokenizer_name": "FLUX.1-schnell_tokenizer_2",
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
+        "name": "t5_base_encoder",
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "text_encoder_repo": "hf_repo1",
-        "tokenizer_repo": "hf_repo1",
-        "text_encoder_name": "hf_repo1",
-        "tokenizer_name": "hf_repo1",
-        "format": ModelFormat.T5Encoder8b,
-    },
-    "4b_quantized": {
-        "text_encoder_repo": "hf_repo2",
-        "tokenizer_repo": "hf_repo2",
-        "text_encoder_name": "hf_repo2",
-        "tokenizer_name": "hf_repo2",
-        "format": ModelFormat.T5Encoder8b,
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
+        "name": "t5_8b_quantized_encoder",
+        "format": ModelFormat.T5Encoder,
     },
 }
 
+
 @invocation_output("flux_model_loader_output")
 class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         ui_type=UIType.FluxMainModel,
         input=Input.Direct,
     )
-    
+
     t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
 
     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
         clip_encoder = self._get_model(context, SubModelType.TextEncoder)
         t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
+        vae = self._install_model(
+            context,
+            SubModelType.VAE,
+            "FLUX.1-schnell_ae",
+            "black-forest-labs/FLUX.1-schnell::ae.safetensors",
+            ModelFormat.Checkpoint,
+            ModelType.VAE,
+            BaseModelType.Flux,
+        )
 
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
             vae=VAEField(vae=vae),
         )
 
-    def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
-        match(submodel):
+    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
+        match submodel:
             case SubModelType.Transformer:
                 return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
             case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
-            case SubModelType.TextEncoder2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
-            case SubModelType.Tokenizer2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
+                return self._install_model(
+                    context,
+                    submodel,
+                    "clip-vit-large-patch14",
+                    "openai/clip-vit-large-patch14",
+                    ModelFormat.Diffusers,
+                    ModelType.CLIPEmbed,
+                    BaseModelType.Any,
+                )
+            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
+                return self._install_model(
+                    context,
+                    submodel,
+                    T5_ENCODER_MAP[self.t5_encoder]["name"],
+                    T5_ENCODER_MAP[self.t5_encoder]["repo"],
+                    ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
+                    ModelType.T5Encoder,
+                    BaseModelType.Any,
+                )
             case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
-        if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
+                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
+
+    def _install_model(
+        self,
+        context: InvocationContext,
+        submodel: SubModelType,
+        name: str,
+        repo_id: str,
+        format: ModelFormat,
+        type: ModelType,
+        base: BaseModelType,
+    ):
+        if models := context.models.search_by_attrs(name=name, base=base, type=type):
            if len(models) != 1:
                raise Exception(f"Multiple models detected for selected model with name {name}")
            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
        else:
            model_path = context.models.download_and_cache_model(repo_id)
-            config = ModelRecordChanges(name = name, base = base, type=type, format=format)
+            config = ModelRecordChanges(name=name, base=base, type=type, format=format)
            model_install_job = context.models.import_local_model(model_path=model_path, config=config)
            while not model_install_job.in_terminal_state:
                sleep(0.01)
            if not model_install_job.config_out:
                raise Exception(f"Failed to install {name}")
-            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
+            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
+                update={"submodel_type": submodel}
+            )
+
 
 @invocation(
     "main_model_loader",

‎invokeai/app/services/model_records/model_records_sql.py

1 addition, 1 deletion

@@ -301,7 +301,7 @@ def search_by_attr(
         for row in result:
             try:
                 model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
-            except pydantic.ValidationError as e:
+            except pydantic.ValidationError:
                 # We catch this error so that the app can still run if there are invalid model configs in the database.
                 # One reason that an invalid model config might be in the database is if someone had to rollback from a
                 # newer version of the app that added a new model type.

‎invokeai/app/services/shared/invocation_context.py

8 additions, 6 deletions

@@ -465,18 +465,20 @@ def download_and_cache_model(
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
     def import_local_model(
-            self,
-            model_path: Path,
-            config: Optional[ModelRecordChanges] = None,
-            access_token: Optional[str] = None,
-            inplace: Optional[bool] = False,
+        self,
+        model_path: Path,
+        config: Optional[ModelRecordChanges] = None,
+        access_token: Optional[str] = None,
+        inplace: Optional[bool] = False,
     ):
         """
         TODO: Fill out description of this method
         """
         if not model_path.exists():
             raise Exception("Models provided to import_local_model must already exist on disk")
-        return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
+        return self._services.model_manager.install.heuristic_import(
+            str(model_path), config=config, access_token=access_token, inplace=inplace
+        )
 
     def load_local_model(
         self,

‎invokeai/backend/flux/math.py

1 addition, 1 deletion

@@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
\ No newline at end of file
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

‎invokeai/backend/flux/model.py

11 additions, 7 deletions

@@ -3,9 +3,15 @@
 import torch
 from torch import Tensor, nn
 
-from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
-                                                  MLPEmbedder, SingleStreamBlock,
-                                                  timestep_embedding)
+from invokeai.backend.flux.modules.layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+)
+
 
 @dataclass
 class FluxParams:
@@ -35,9 +41,7 @@ def __init__(self, params: FluxParams):
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
-            raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
-            )
+            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
         pe_dim = params.hidden_size // params.num_heads
         if sum(params.axes_dim) != pe_dim:
             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
@@ -108,4 +112,4 @@ def forward(
         img = img[:, txt.shape[1] :, ...]
 
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
-        return img
\ No newline at end of file
+        return img

‎invokeai/backend/flux/modules/autoencoder.py

1 addition, 1 deletion

@@ -309,4 +309,4 @@ def decode(self, z: Tensor) -> Tensor:
         return self.decoder(z)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.decode(self.encode(x))
\ No newline at end of file
+        return self.decode(self.encode(x))

‎invokeai/backend/flux/modules/conditioner.py

3 additions, 2 deletions

@@ -1,5 +1,6 @@
 from torch import Tensor, nn
-from transformers import (PreTrainedModel, PreTrainedTokenizer)
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
 
 class HFEncoder(nn.Module):
     def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
@@ -27,4 +28,4 @@ def forward(self, text: list[str]) -> Tensor:
             attention_mask=None,
             output_hidden_states=False,
         )
-        return outputs[self.output_key]
\ No newline at end of file
+        return outputs[self.output_key]

‎invokeai/backend/flux/modules/layers.py

2 additions, 4 deletions

@@ -5,7 +5,7 @@
 from einops import rearrange
 from torch import Tensor, nn
 
 from ..math import attention, rope
 
 
 class EmbedND(nn.Module):
@@ -36,9 +36,7 @@
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -250,4 +248,4 @@
         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
         x = self.linear(x)
-        return x
\ No newline at end of file
+        return x

‎invokeai/backend/flux/sampling.py

134 additions (new file)

@@ -0,0 +1,134 @@
+import math
+from typing import Callable
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+
+from .model import Flux
+from .modules.conditioner import HFEncoder
+
+
+def get_noise(
+    num_samples: int,
+    height: int,
+    width: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    seed: int,
+):
+    return torch.randn(
+        num_samples,
+        16,
+        # allow for packing
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=device,
+        dtype=dtype,
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
+
+
+def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return {
+        "img": img,
+        "img_ids": img_ids.to(img.device),
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "vec": vec.to(img.device),
+    }
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> list[float]:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # eastimate mu based on linear estimation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps.tolist()
+
+
+def denoise(
+    model: Flux,
+    # model input
+    img: Tensor,
+    img_ids: Tensor,
+    txt: Tensor,
+    txt_ids: Tensor,
+    vec: Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    guidance: float = 4.0,
+):
+    # this is ignored for schnell
+    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
+        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        pred = model(
+            img=img,
+            img_ids=img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec,
+        )
+
+        img = img + (t_prev - t_curr) * pred
+
+    return img
+
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+    return rearrange(
+        x,
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=math.ceil(height / 16),
+        w=math.ceil(width / 16),
+        ph=2,
+        pw=2,
+    )
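The shift branch of get_schedule bends the linear schedule toward high timesteps as the latent sequence gets longer. A standalone numeric sketch of that math, not part of the commit, using the same constants as get_lin_function:

# Standalone sketch (not part of the commit): the timestep shift from get_schedule(),
# redone with plain floats. time_shift(mu, 1.0, 1.0) is exactly 1.0, and t == 0 is guarded
# here because plain-float division by zero raises (the torch version just produces inf -> 0).
import math

def get_lin_function(x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * x + (y1 - m * x1)

def time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

num_steps, image_seq_len = 4, 4096
mu = get_lin_function()(image_seq_len)                          # 1.15 for a 4096-token latent sequence
linear = [1 - i / num_steps for i in range(num_steps)] + [0.0]  # torch.linspace(1, 0, num_steps + 1)
shifted = [time_shift(mu, 1.0, t) if t > 0 else 0.0 for t in linear]
print(linear)   # [1.0, 0.75, 0.5, 0.25, 0.0]
print(shifted)  # ~[1.0, 0.90, 0.76, 0.51, 0.0] -- more of the trajectory spent near t = 1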

‎invokeai/backend/model_manager/load/model_loaders/flux.py

15 additions, 16 deletions

@@ -1,14 +1,17 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
+from dataclasses import fields
 from pathlib import Path
-import yaml
+from typing import Any, Optional
 
-from dataclasses import fields
+import yaml
 from safetensors.torch import load_file
-from typing import Optional, Any
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
+from invokeai.app.services.config.config_default import get_config
+from invokeai.backend.flux.model import Flux, FluxParams
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -19,20 +22,15 @@
 )
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
-    MainCheckpointConfig,
     CLIPEmbedDiffusersConfig,
+    MainCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
-from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.flux.model import Flux, FluxParams
-from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
-from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
-                          T5Tokenizer)
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 app_config = get_config()
 
@@ -56,9 +54,9 @@ def _load_model(
                 flux_conf = yaml.safe_load(stream)
         except:
             raise
-        
+
         dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields}
+        filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
         params = AutoEncoderParams(**filtered_data)
 
         with SilenceWarnings():
@@ -92,6 +90,7 @@ def _load_model(
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
 
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
 class T5EncoderCheckpointModel(GenericDiffusersLoader):
     """Class to load main models."""
@@ -106,9 +105,9 @@ def _load_model(
 
         match submodel_type:
             case SubModelType.Tokenizer2:
-                return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
+                return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path))
+                return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
 
@@ -148,7 +147,7 @@ def _load_from_singlefile(
         params = None
         model_path = Path(config.path)
         dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
         params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
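Both loaders above build their params dataclasses by filtering the YAML config down to the fields the dataclass actually declares. A minimal standalone sketch of that pattern (not part of the commit; ExampleParams is a hypothetical stand-in for FluxParams / AutoEncoderParams):

# Standalone sketch (not part of the commit): filtering a loose config dict down to the
# fields a dataclass declares, as done for FluxParams and AutoEncoderParams above.
# ExampleParams is a hypothetical stand-in; the real field names come from the flux YAML.
from dataclasses import dataclass, fields

@dataclass
class ExampleParams:
    in_channels: int
    hidden_size: int

conf = {"in_channels": 64, "hidden_size": 3072, "unexpected_key": "dropped"}
dataclass_fields = {f.name for f in fields(ExampleParams)}
filtered_data = {k: v for k, v in conf.items() if k in dataclass_fields}
params = ExampleParams(**filtered_data)
print(params)  # ExampleParams(in_channels=64, hidden_size=3072)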

‎invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

6 additions, 2 deletions

@@ -39,11 +39,15 @@
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
+)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
+)
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""
 
‎invokeai/backend/model_manager/load/model_util.py

5 additions, 2 deletions

@@ -9,7 +9,7 @@
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
+from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
 
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -50,7 +50,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         return model.calc_size()
     elif isinstance(
         model,
-        (T5TokenizerFast,T5Tokenizer,),
+        (
+            T5TokenizerFast,
+            T5Tokenizer,
+        ),
     ):
         return len(model)
     else:

‎invokeai/backend/model_manager/probe.py

4 additions, 3 deletions

@@ -56,7 +56,7 @@
     },
     BaseModelType.StableDiffusionXLRefiner: {
         ModelVariantType.Normal: "sd_xl_refiner.yaml",
-    }
+    },
 }
 
 
@@ -132,7 +132,7 @@ def probe(
         fields = {}
 
         model_path = model_path.resolve()
-        
+
         format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
         model_info = None
         model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
@@ -323,7 +323,7 @@ def _get_checkpoint_config_path(
 
         if model_type is ModelType.Main:
            if base_type == BaseModelType.Flux:
-                config_file="flux/flux1-schnell.yaml"
+                config_file = "flux/flux1-schnell.yaml"
            else:
                config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
@@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
     def get_format(self) -> ModelFormat:
         return ModelFormat.T5Encoder
 
+
 class ONNXFolderProbe(PipelineFolderProbe):
     def get_base_type(self) -> BaseModelType:
         # Due to the way the installer is set up, the configuration file for safetensors
