@@ -2149,3 +2149,51 @@ def check_module(denoiser):
2149
2149
2150
2150
_ , _ , inputs = self .get_dummy_inputs (with_generator = False )
2151
2151
pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2152
    def test_inference_load_delete_load_adapters(self):
        """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."""
        # NOTE(review): `scheduler_classes`, `pipeline_class`, `unet_kwargs`, and the
        # `get_dummy_*` helpers come from the enclosing test mixin (not visible here).
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached; a fixed seed keeps all
            # subsequent runs comparable via np.allclose.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Attach a LoRA adapter to every LoRA-loadable module the pipeline declares.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is either a transformer or a UNet depending on the pipeline.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Output with the freshly added adapter(s) active.
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Serialize the current LoRA weights so they can be re-loaded after deletion.
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

                # First, delete adapter and compare.
                # NOTE(review): only the first active adapter name is deleted — presumably
                # all modules share a single adapter name here; verify if multiple
                # adapters can be active in this test setup.
                pipe.delete_adapters(pipe.get_active_adapters()[0])
                output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
                # After deletion the output must differ from the adapter output and
                # match the no-LoRA baseline again.
                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))

                # Then load adapter and compare.
                # Re-loading the saved weights must reproduce the original adapter output.
                pipe.load_lora_weights(tmpdirname)
                output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
0 commit comments