@@ -326,3 +326,80 @@ def test_select_multiple_weights(
326326) ->  None :
327327    filtered_files  =  filter_files (sd15_test_files , variant )
328328    assert  set (filtered_files ) ==  {Path (f ) for  f  in  expected_files }
329+ 
330+ 
331+ @pytest .fixture  
332+ def  flux_schnell_test_files () ->  list [Path ]:
333+     return  [
334+         Path (f )
335+         for  f  in  [
336+             "FLUX.1-schnell/.gitattributes" ,
337+             "FLUX.1-schnell/README.md" ,
338+             "FLUX.1-schnell/ae.safetensors" ,
339+             "FLUX.1-schnell/flux1-schnell.safetensors" ,
340+             "FLUX.1-schnell/model_index.json" ,
341+             "FLUX.1-schnell/scheduler/scheduler_config.json" ,
342+             "FLUX.1-schnell/schnell_grid.jpeg" ,
343+             "FLUX.1-schnell/text_encoder/config.json" ,
344+             "FLUX.1-schnell/text_encoder/model.safetensors" ,
345+             "FLUX.1-schnell/text_encoder_2/config.json" ,
346+             "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors" ,
347+             "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors" ,
348+             "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json" ,
349+             "FLUX.1-schnell/tokenizer/merges.txt" ,
350+             "FLUX.1-schnell/tokenizer/special_tokens_map.json" ,
351+             "FLUX.1-schnell/tokenizer/tokenizer_config.json" ,
352+             "FLUX.1-schnell/tokenizer/vocab.json" ,
353+             "FLUX.1-schnell/tokenizer_2/special_tokens_map.json" ,
354+             "FLUX.1-schnell/tokenizer_2/spiece.model" ,
355+             "FLUX.1-schnell/tokenizer_2/tokenizer.json" ,
356+             "FLUX.1-schnell/tokenizer_2/tokenizer_config.json" ,
357+             "FLUX.1-schnell/transformer/config.json" ,
358+             "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors" ,
359+             "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors" ,
360+             "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors" ,
361+             "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json" ,
362+             "FLUX.1-schnell/vae/config.json" ,
363+             "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors" ,
364+         ]
365+     ]
366+ 
367+ 
368+ @pytest .mark .parametrize ( 
369+     ["variant" , "expected_files" ], 
370+     [ 
371+         ( 
372+             ModelRepoVariant .Default , 
373+             [ 
374+                 "FLUX.1-schnell/model_index.json" , 
375+                 "FLUX.1-schnell/scheduler/scheduler_config.json" , 
376+                 "FLUX.1-schnell/text_encoder/config.json" , 
377+                 "FLUX.1-schnell/text_encoder/model.safetensors" , 
378+                 "FLUX.1-schnell/text_encoder_2/config.json" , 
379+                 "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors" , 
380+                 "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors" , 
381+                 "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json" , 
382+                 "FLUX.1-schnell/tokenizer/merges.txt" , 
383+                 "FLUX.1-schnell/tokenizer/special_tokens_map.json" , 
384+                 "FLUX.1-schnell/tokenizer/tokenizer_config.json" , 
385+                 "FLUX.1-schnell/tokenizer/vocab.json" , 
386+                 "FLUX.1-schnell/tokenizer_2/special_tokens_map.json" , 
387+                 "FLUX.1-schnell/tokenizer_2/spiece.model" , 
388+                 "FLUX.1-schnell/tokenizer_2/tokenizer.json" , 
389+                 "FLUX.1-schnell/tokenizer_2/tokenizer_config.json" , 
390+                 "FLUX.1-schnell/transformer/config.json" , 
391+                 "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors" , 
392+                 "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors" , 
393+                 "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors" , 
394+                 "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json" , 
395+                 "FLUX.1-schnell/vae/config.json" , 
396+                 "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors" , 
397+             ], 
398+         ), 
399+     ], 
400+ ) 
401+ def  test_select_flux_schnell_files (
402+     flux_schnell_test_files : list [Path ], variant : ModelRepoVariant , expected_files : list [str ]
403+ ) ->  None :
404+     filtered_files  =  filter_files (flux_schnell_test_files , variant )
405+     assert  set (filtered_files ) ==  {Path (f ) for  f  in  expected_files }
0 commit comments