Skip to content

Commit b617631

Browse files
RyanJDickbrandonrising
authored andcommitted
Update HF download logic to work for black-forest-labs/FLUX.1-schnell.
1 parent 562c2cc commit b617631

File tree

2 files changed

+99
-2
lines changed

2 files changed

+99
-2
lines changed

invokeai/backend/model_manager/util/select_hf_files.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def filter_files(
5454
"lora_weights.safetensors",
5555
"weights.pb",
5656
"onnx_data",
57+
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
5758
)
5859
):
5960
paths.append(file)
@@ -62,7 +63,7 @@ def filter_files(
6263
# downloading random checkpoints that might also be in the repo. However there is no guarantee
6364
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
6465
# will adhere to this naming convention, so this is an area to be careful of.
65-
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
66+
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
6667
paths.append(file)
6768

6869
# limit search to subfolder if requested
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
9798
if variant == ModelRepoVariant.Flax:
9899
result.add(path)
99100

100-
elif path.suffix in [".json", ".txt"]:
101+
# Note: '.model' was added to support:
102+
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
103+
elif path.suffix in [".json", ".txt", ".model"]:
101104
result.add(path)
102105

103106
elif variant in [
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
140143
continue
141144

142145
for candidate_list in subfolder_weights.values():
146+
# Check if at least one of the files has the explicit fp16 variant.
147+
at_least_one_fp16 = False
148+
for candidate in candidate_list:
149+
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
150+
at_least_one_fp16 = True
151+
break
152+
153+
if not at_least_one_fp16:
154+
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
155+
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
156+
# we'll simply keep all the candidates. An example of a model that hits this case is
157+
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
158+
for candidate in candidate_list:
159+
result.add(candidate.path)
160+
161+
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
162+
# candidate.
143163
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
144164
if highest_score_candidate:
145165
result.add(highest_score_candidate.path)

tests/backend/model_manager/util/test_hf_model_select.py

+77
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,80 @@ def test_select_multiple_weights(
326326
) -> None:
327327
filtered_files = filter_files(sd15_test_files, variant)
328328
assert set(filtered_files) == {Path(f) for f in expected_files}
329+
330+
331+
@pytest.fixture
332+
def flux_schnell_test_files() -> list[Path]:
333+
return [
334+
Path(f)
335+
for f in [
336+
"FLUX.1-schnell/.gitattributes",
337+
"FLUX.1-schnell/README.md",
338+
"FLUX.1-schnell/ae.safetensors",
339+
"FLUX.1-schnell/flux1-schnell.safetensors",
340+
"FLUX.1-schnell/model_index.json",
341+
"FLUX.1-schnell/scheduler/scheduler_config.json",
342+
"FLUX.1-schnell/schnell_grid.jpeg",
343+
"FLUX.1-schnell/text_encoder/config.json",
344+
"FLUX.1-schnell/text_encoder/model.safetensors",
345+
"FLUX.1-schnell/text_encoder_2/config.json",
346+
"FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
347+
"FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
348+
"FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
349+
"FLUX.1-schnell/tokenizer/merges.txt",
350+
"FLUX.1-schnell/tokenizer/special_tokens_map.json",
351+
"FLUX.1-schnell/tokenizer/tokenizer_config.json",
352+
"FLUX.1-schnell/tokenizer/vocab.json",
353+
"FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
354+
"FLUX.1-schnell/tokenizer_2/spiece.model",
355+
"FLUX.1-schnell/tokenizer_2/tokenizer.json",
356+
"FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
357+
"FLUX.1-schnell/transformer/config.json",
358+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
359+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
360+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
361+
"FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
362+
"FLUX.1-schnell/vae/config.json",
363+
"FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
364+
]
365+
]
366+
367+
368+
@pytest.mark.parametrize(
369+
["variant", "expected_files"],
370+
[
371+
(
372+
ModelRepoVariant.Default,
373+
[
374+
"FLUX.1-schnell/model_index.json",
375+
"FLUX.1-schnell/scheduler/scheduler_config.json",
376+
"FLUX.1-schnell/text_encoder/config.json",
377+
"FLUX.1-schnell/text_encoder/model.safetensors",
378+
"FLUX.1-schnell/text_encoder_2/config.json",
379+
"FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
380+
"FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
381+
"FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
382+
"FLUX.1-schnell/tokenizer/merges.txt",
383+
"FLUX.1-schnell/tokenizer/special_tokens_map.json",
384+
"FLUX.1-schnell/tokenizer/tokenizer_config.json",
385+
"FLUX.1-schnell/tokenizer/vocab.json",
386+
"FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
387+
"FLUX.1-schnell/tokenizer_2/spiece.model",
388+
"FLUX.1-schnell/tokenizer_2/tokenizer.json",
389+
"FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
390+
"FLUX.1-schnell/transformer/config.json",
391+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
392+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
393+
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
394+
"FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
395+
"FLUX.1-schnell/vae/config.json",
396+
"FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
397+
],
398+
),
399+
],
400+
)
401+
def test_select_flux_schnell_files(
402+
flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
403+
) -> None:
404+
filtered_files = filter_files(flux_schnell_test_files, variant)
405+
assert set(filtered_files) == {Path(f) for f in expected_files}

0 commit comments

Comments
 (0)