diff --git a/docs/unlearn/configs/uce.md b/docs/unlearn/configs/uce.md index 5283f66d..681a40e0 100644 --- a/docs/unlearn/configs/uce.md +++ b/docs/unlearn/configs/uce.md @@ -43,7 +43,7 @@ class UnifiedConceptEditingConfig(BaseConfig): self.add_prompts = False # Whether to add additional prompts # Preserver concepts (comma-separated if multiple) - self.preserver_concepts = ( + self.preserve_concepts = ( "A Lion image" # Comma-separated string of preserver concepts ) diff --git a/mu/algorithms/erase_diff/configs/train_config.py b/mu/algorithms/erase_diff/configs/train_config.py index 5d27044c..686ae854 100644 --- a/mu/algorithms/erase_diff/configs/train_config.py +++ b/mu/algorithms/erase_diff/configs/train_config.py @@ -1,4 +1,5 @@ import os +import re from mu.core.base_config import BaseConfig from pathlib import Path @@ -30,6 +31,8 @@ def __init__(self, **kwargs): self.use_sample = True self.num_workers = 4 self.pin_memory = True + self.encoder_component = ["all", "ffn", "attn"] + self.norm_layer = False for key, value in kwargs.items(): setattr(self, key, value) @@ -46,6 +49,22 @@ def validate_config(self): "notime", "xlayer", "selflayer", + "text_encoder_layer_full", + "text_encoder_layer0", + "text_encoder_layer01", + "text_encoder_layer012", + "text_encoder_layer0123", + "text_encoder_layer01234", + "text_encoder_layer012345", + "text_encoder_layer0123456", + "text_encoder_layer01234567", + "text_encoder_layer012345678", + "text_encoder_layer0123456789", + "text_encoder_layer012345678910", + "text_encoder_layer01234567891011", + "text_encoder_layer0_11", + "text_encoder_layer01_1011", + "text_encoder_layer012_91011", ] if self.epochs <= 0: raise ValueError("epochs should be a positive integer.") diff --git a/mu/algorithms/erase_diff/trainer.py b/mu/algorithms/erase_diff/trainer.py index 517d38e7..7d5c4576 100644 --- a/mu/algorithms/erase_diff/trainer.py +++ b/mu/algorithms/erase_diff/trainer.py @@ -9,10 +9,12 @@ import logging from timm.utils import AverageMeter -import logging +import logging from mu.core import BaseTrainer from mu.algorithms.erase_diff.model import EraseDiffModel +from mu.helpers.utils import param_choices + class EraseDiffTrainer(BaseTrainer): """ @@ -20,7 +22,9 @@ class EraseDiffTrainer(BaseTrainer): Handles the training loop, loss computation, and optimization. """ - def __init__(self, model: EraseDiffModel, config: dict, device: str, data_handler, **kwargs): + def __init__( + self, model: EraseDiffModel, config: dict, device: str, data_handler, **kwargs + ): """ Initialize the EraseDiffTrainer. @@ -28,12 +32,12 @@ def __init__(self, model: EraseDiffModel, config: dict, device: str, data_handl model (EraseDiffModel): Instance of EraseDiffModel. config (dict): Configuration dictionary. device (str): Device to perform training on. - device_orig (str): Device for the original (frozen) model. - data_handler (EraseDiffDataHandler): Instance of EraseDiffDataHandler. + data_handler: Instance of EraseDiffDataHandler. **kwargs: Additional keyword arguments. """ super().__init__(model, config, **kwargs) self.device = device + # Note: the passed model is an EraseDiffModel, whose attribute "model" is the actual network. self.model = model.model self.sampler = None self.sampler_orig = None @@ -47,50 +51,84 @@ def setup_optimizer(self): Setup the optimizer based on the training method. 
""" # Select parameters to train based on train_method - train_method = self.config.get('train_method', 'xattn') + train_method = self.config.get("train_method", "xattn") parameters = [] - for name, param in self.model.model.named_parameters(): - if train_method == 'full': - parameters.append(param) - elif train_method == 'xattn' and 'attn2' in name: - parameters.append(param) - elif train_method == 'selfattn' and 'attn1' in name: + if train_method == "full": + for name, param in self.model.model.named_parameters(): parameters.append(param) - elif train_method == 'noxattn': - if not (name.startswith('out.') or 'attn2' in name or 'time_embed' in name): + elif train_method == "xattn": + for name, param in self.model.model.named_parameters(): + if "attn2" in name: + parameters.append(param) + elif train_method == "selfattn": + for name, param in self.model.model.named_parameters(): + if "attn1" in name: parameters.append(param) - elif train_method == 'xlayer': - if 'attn2' in name and ('output_blocks.6.' in name or 'output_blocks.8.' in name): + elif train_method == "noxattn": + for name, param in self.model.model.named_parameters(): + if not ( + name.startswith("out.") or "attn2" in name or "time_embed" in name + ): parameters.append(param) - elif train_method == 'selflayer': - if 'attn1' in name and ('input_blocks.4.' in name or 'input_blocks.7.' in name): + elif train_method == "xlayer": + for name, param in self.model.model.named_parameters(): + if "attn2" in name and ( + "output_blocks.6." in name or "output_blocks.8." in name + ): parameters.append(param) - - self.optimizer = torch.optim.Adam(parameters, lr=self.config.get('lr', 1e-5)) + elif train_method == "selflayer": + for name, param in self.model.model.named_parameters(): + if "attn1" in name and ( + "input_blocks.4." in name or "input_blocks.7." in name + ): + parameters.append(param) + elif train_method.startswith("text_encoder"): + if hasattr(self.model, "cond_stage_model"): + text_encoder_module = self.model.cond_stage_model + elif hasattr(self.model, "text_encoder"): + text_encoder_module = self.model.text_encoder.text_model + else: + self.logger.error("Model does not have a text encoder attribute.") + text_encoder_module = None + + if text_encoder_module is not None: + parameters = param_choices( + model=text_encoder_module, train_method=train_method + ) + self.optimizer = torch.optim.Adam(parameters, lr=self.config.get("lr", 1e-5)) def train(self): """ Execute the training loop. """ - epochs = self.config.get('epochs', 1) - K_steps = self.config.get('K_steps', 2) - alpha = self.config.get('alpha', 0.1) + epochs = self.config.get("epochs", 1) + K_steps = self.config.get("K_steps", 2) + alpha = self.config.get("alpha", 0.1) # Retrieve data loaders data_loaders = self.data_handler.get_data_loaders() - forget_dl = data_loaders.get('forget') - remain_dl = data_loaders.get('remain') + forget_dl = data_loaders.get("forget") + remain_dl = data_loaders.get("remain") self.logger.info(f"Number of forget samples: {len(forget_dl.dataset)}") self.logger.info(f"Number of remain samples: {len(remain_dl.dataset)}") + train_method = self.config.get("train_method", "xattn") + for epoch in range(epochs): self.logger.info(f"Starting Epoch {epoch+1}/{epochs}") - with tqdm(total=len(forget_dl), desc=f'Epoch {epoch+1}/{epochs}') as pbar: - self.model.train() + with tqdm(total=len(forget_dl), desc=f"Epoch {epoch+1}/{epochs}") as pbar: + # NEW: For text_encoder training, freeze the overall model and set only text_encoder to train. 
+ if train_method.startswith("text_encoder"): + self.model.model.eval() + if hasattr(self.model.model, "text_encoder"): + self.model.model.text_encoder.train() + else: + self.model.train() + param_i = self.get_param() - for step in range(K_steps): + for step in range(K_steps): unl_losses = AverageMeter() for forget_batch in forget_dl: self.optimizer.zero_grad() @@ -107,7 +145,9 @@ def train(self): pseudo_prompts = remain_prompts # Forget stage - forget_loss = self.compute_forget_loss(forget_images, forget_prompts, pseudo_prompts) + forget_loss = self.compute_forget_loss( + forget_images, forget_prompts, pseudo_prompts + ) forget_loss.backward() self.optimizer.step() @@ -118,7 +158,13 @@ def train(self): # Remain stage for remain_batch in remain_dl: - self.model.train() + if train_method.startswith("text_encoder"): + self.model.model.eval() + if hasattr(self.model.model, "text_encoder"): + self.model.model.text_encoder.train() + else: + self.model.train() + self.optimizer.zero_grad() remain_images, remain_prompts = remain_batch @@ -134,13 +180,15 @@ def train(self): remain_btch = { "edited": remain_images.to(self.device), - "edit": {"c_crossattn": remain_prompts} + "edit": {"c_crossattn": remain_prompts}, } # Remain loss remain_loss = self.model.shared_step(remain_btch)[0] # Forget loss within remain stage - unlearn_loss = self.compute_unlearn_loss(forget_images, forget_prompts, pseudo_prompts) + unlearn_loss = self.compute_unlearn_loss( + forget_images, forget_prompts, pseudo_prompts + ) q_loss = unlearn_loss - unl_losses.avg.detach() @@ -150,7 +198,12 @@ def train(self): # Logging wandb.log({"loss": total_loss.item()}) - pbar.set_postfix({"loss": total_loss.item() / self.config.get('batch_size', 4)}) + pbar.set_postfix( + { + "loss": total_loss.item() + / self.config.get("batch_size", 4) + } + ) pbar.update(1) self.logger.info(f"Epoch {epoch+1}/{epochs} completed.") @@ -187,7 +240,9 @@ def set_param(self, old_param: list): torch.cuda.empty_cache() torch.manual_seed(0) - def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list) -> torch.Tensor: + def compute_forget_loss( + self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list + ) -> torch.Tensor: """ Compute the forget loss. 
@@ -201,17 +256,23 @@ def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list, """ forget_batch = { "edited": forget_images.to(self.device), - "edit": {"c_crossattn": forget_prompts} + "edit": {"c_crossattn": forget_prompts}, } pseudo_batch = { "edited": forget_images.to(self.device), - "edit": {"c_crossattn": pseudo_prompts} + "edit": {"c_crossattn": pseudo_prompts}, } - forget_input, forget_emb = self.model.get_input(forget_batch, self.model.first_stage_key) - pseudo_input, pseudo_emb = self.model.get_input(pseudo_batch, self.model.first_stage_key) + forget_input, forget_emb = self.model.get_input( + forget_batch, self.model.first_stage_key + ) + pseudo_input, pseudo_emb = self.model.get_input( + pseudo_batch, self.model.first_stage_key + ) - t = torch.randint(0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device).long() + t = torch.randint( + 0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device + ).long() noise = torch.randn_like(forget_input, device=self.device) forget_noisy = self.model.q_sample(x_start=forget_input, t=t, noise=noise) @@ -223,7 +284,9 @@ def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list, forget_loss = self.criteria(forget_out, pseudo_out) return forget_loss - def compute_unlearn_loss(self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list) -> torch.Tensor: + def compute_unlearn_loss( + self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list + ) -> torch.Tensor: """ Compute the unlearn loss within the remain stage. @@ -237,17 +300,23 @@ def compute_unlearn_loss(self, forget_images: torch.Tensor, forget_prompts: list """ forget_batch = { "edited": forget_images.to(self.device), - "edit": {"c_crossattn": forget_prompts} + "edit": {"c_crossattn": forget_prompts}, } pseudo_batch = { "edited": forget_images.to(self.device), - "edit": {"c_crossattn": pseudo_prompts} + "edit": {"c_crossattn": pseudo_prompts}, } - forget_input, forget_emb = self.model.get_input(forget_batch, self.model.first_stage_key) - pseudo_input, pseudo_emb = self.model.get_input(pseudo_batch, self.model.first_stage_key) + forget_input, forget_emb = self.model.get_input( + forget_batch, self.model.first_stage_key + ) + pseudo_input, pseudo_emb = self.model.get_input( + pseudo_batch, self.model.first_stage_key + ) - t = torch.randint(0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device).long() + t = torch.randint( + 0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device + ).long() noise = torch.randn_like(forget_input, device=self.device) forget_noisy = self.model.q_sample(x_start=forget_input, t=t, noise=noise) diff --git a/mu/algorithms/unified_concept_editing/configs/train_config.py b/mu/algorithms/unified_concept_editing/configs/train_config.py index c51eb2ad..ef389577 100644 --- a/mu/algorithms/unified_concept_editing/configs/train_config.py +++ b/mu/algorithms/unified_concept_editing/configs/train_config.py @@ -45,7 +45,7 @@ def __init__(self, **kwargs): self.add_prompts = False # Whether to add additional prompts # Preserver concepts (comma-separated if multiple) - self.preserver_concepts = ( + self.preserve_concepts = ( "A Lion image" # Comma-separated string of preserver concepts ) @@ -96,9 +96,9 @@ def validate_config(self): if self.lamb < 0: raise ValueError("lamb should be a positive value.") - # Validate preserver_concepts - if not isinstance(self.preserver_concepts, str): - raise ValueError("preserver_concepts should 
be a string.") + # Validate preserve_concepts + if not isinstance(self.preserve_concepts, str): + raise ValueError("preserve_concepts should be a string.") # Validate base model if self.base not in ["stable-diffusion-v1-4"]: diff --git a/mu/algorithms/unified_concept_editing/configs/train_config.yaml b/mu/algorithms/unified_concept_editing/configs/train_config.yaml index 9316f43e..2ed8ed5a 100755 --- a/mu/algorithms/unified_concept_editing/configs/train_config.yaml +++ b/mu/algorithms/unified_concept_editing/configs/train_config.yaml @@ -35,5 +35,5 @@ add_prompts: false # Additional parameters -preserver_concepts: "A Lion image" # Comma-separated string if multiple concepts, e.g., "concept1,concept2" +preserve_concepts: "A Lion image" # Comma-separated string if multiple concepts, e.g., "concept1,concept2" base: "stable-diffusion-v1-4" # Base version of stable diffusion \ No newline at end of file diff --git a/mu/algorithms/unified_concept_editing/scripts/train.py b/mu/algorithms/unified_concept_editing/scripts/train.py index b70ddee8..abaf1237 100644 --- a/mu/algorithms/unified_concept_editing/scripts/train.py +++ b/mu/algorithms/unified_concept_editing/scripts/train.py @@ -54,7 +54,7 @@ def main(): parser.add_argument("--erase_scale", type=float, help="Scale for erasure") parser.add_argument("--lamb", type=float, help="Lambda parameter") parser.add_argument("--add_prompts", type=bool, help="Whether to add prompts") - parser.add_argument("--preserver_concepts", type=str, help="Concepts to preserve") + parser.add_argument("--preserve_concepts", type=str, help="Concepts to preserve") parser.add_argument("--base", type=str, help="Base version of stable diffusion") args = parser.parse_args() diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 54e8ab19..3fa0c4d3 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -174,5 +174,528 @@ def to_cuda(elements): return elements +def param_choices(model, train_method, component="all", final_layer_norm=False): + # choose parameters to train based on train_method + model = model.transformer + parameters = [] + # Text Encoder FUll Weight Tuning + if train_method == "text_encoder_full": + for name, param in model.text_model.named_parameters(): + # Final Layer Norm + if name.startswith("final_layer_norm"): + if component == "all" or final_layer_norm == True: + print(name) + parameters.append(param) + else: + pass + # Transformer layers + elif name.startswith("encoder"): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + # Embedding layers + else: + pass + + # Text Encoder Layer 0 Tuning + elif train_method == "text_encoder_layer0": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith("encoder.layers.0"): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer01": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith("encoder.layers.0") or name.startswith( + 
"encoder.layers.1" + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer012": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer0123": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer01234": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer012345": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer0123456": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or 
name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer01234567": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + or name.startswith("encoder.layers.7") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer012345678": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + or name.startswith("encoder.layers.7") + or name.startswith("encoder.layers.8") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer0123456789": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + or name.startswith("encoder.layers.7") + or name.startswith("encoder.layers.8") + or name.startswith("encoder.layers.9") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer012345678910": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or 
name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + or name.startswith("encoder.layers.7") + or name.startswith("encoder.layers.8") + or name.startswith("encoder.layers.9") + or name.startswith("encoder.layers.10") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer01234567891011": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.3") + or name.startswith("encoder.layers.4") + or name.startswith("encoder.layers.5") + or name.startswith("encoder.layers.6") + or name.startswith("encoder.layers.7") + or name.startswith("encoder.layers.8") + or name.startswith("encoder.layers.9") + or name.startswith("encoder.layers.10") + or name.startswith("encoder.layers.11") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer0_11": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith("encoder.layers.0") or name.startswith( + "encoder.layers.11" + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer01_1011": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.10") + or name.startswith("encoder.layers.11") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == "text_encoder_layer012_91011": + for name, param in model.text_model.named_parameters(): + # Encoder Layer 0 + if ( + name.startswith("encoder.layers.0") + or name.startswith("encoder.layers.1") + or name.startswith("encoder.layers.2") + or name.startswith("encoder.layers.9") + or name.startswith("encoder.layers.10") + or name.startswith("encoder.layers.11") + ): + if component == "ffn" and "mlp" in name: + print(name) + parameters.append(param) + elif 
component == "attn" and "self_attn" in name: + print(name) + parameters.append(param) + elif component == "all": + print(name) + parameters.append(param) + else: + pass + + elif name.startswith("final_layer_norm") and final_layer_norm == True: + print(name) + parameters.append(param) + + else: + pass + + # UNet Model Tuning + else: + for name, param in model.model.diffusion_model.named_parameters(): + # train all layers except x-attns and time_embed layers + if train_method == "noxattn": + if name.startswith("out.") or "attn2" in name or "time_embed" in name: + pass + else: + print(name) + parameters.append(param) + + # train only self attention layers + if train_method == "selfattn": + if "attn1" in name: + print(name) + parameters.append(param) + + # train only x attention layers + if train_method == "xattn": + if "attn2" in name: + print(name) + parameters.append(param) + + # train all layers + if train_method == "full": + print(name) + parameters.append(param) + + # train all layers except time embed layers + if train_method == "notime": + if not (name.startswith("out.") or "time_embed" in name): + print(name) + parameters.append(param) + if train_method == "xlayer": + if "attn2" in name: + if "output_blocks.6." in name or "output_blocks.8." in name: + print(name) + parameters.append(param) + if train_method == "selflayer": + if "attn1" in name: + if "input_blocks.4." in name or "input_blocks.7." in name: + print(name) + parameters.append(param) + + return parameters
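Follow-up note on the rename: `preserver_concepts` → `preserve_concepts` is applied consistently across the docs snippet, the config class, the YAML defaults, and the CLI flag (`--preserve_concepts`). A minimal sketch of how the renamed field is set and validated, assuming the config class in `mu/algorithms/unified_concept_editing/configs/train_config.py` is named `UnifiedConceptEditingConfig` (as in the docs excerpt) and accepts keyword overrides in `__init__`:

```python
# Sketch only: the class name and keyword-override behaviour are assumptions
# taken from the docs snippet above, not confirmed by this hunk.
from mu.algorithms.unified_concept_editing.configs.train_config import (
    UnifiedConceptEditingConfig,
)

config = UnifiedConceptEditingConfig(
    preserve_concepts="A Lion image,A Tiger image",  # comma-separated string
)
config.validate_config()  # raises ValueError if preserve_concepts is not a str

# Downstream code can split the comma-separated string into individual concepts:
concepts = [c.strip() for c in config.preserve_concepts.split(",")]
```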
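The long chain of `text_encoder_layer*` branches in `param_choices` follows one pattern: keep parameters whose names start with `encoder.layers.<i>` for the layer indices encoded in the method name (e.g. `text_encoder_layer01_1011` → layers 0, 1, 10, 11), filter them by `component` (`all`, `ffn`, `attn`), and optionally include `final_layer_norm`. A reference-only sketch of that selection logic (not part of the patch); the trailing dot in the prefix keeps `encoder.layers.1` from also matching layers 10 and 11:

```python
def select_text_encoder_params(text_model, layer_indices, component="all",
                               final_layer_norm=False):
    """Reference-only equivalent of the per-layer branches in param_choices.

    `text_model` is assumed to expose HuggingFace CLIP-style parameter names
    such as "encoder.layers.0.mlp.fc1.weight" and "final_layer_norm.weight".
    """
    # Trailing dot so "encoder.layers.1." does not also match layers 10/11.
    prefixes = tuple(f"encoder.layers.{i}." for i in layer_indices)
    params = []
    for name, param in text_model.named_parameters():
        if name.startswith(prefixes):
            if (component == "all"
                    or (component == "ffn" and "mlp" in name)
                    or (component == "attn" and "self_attn" in name)):
                params.append(param)
        elif name.startswith("final_layer_norm") and final_layer_norm:
            params.append(param)
    return params
```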
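Usage sketch for the new `param_choices` helper as the trainer invokes it for `text_encoder_*` methods. The function dereferences `model.transformer` and then walks `text_model.named_parameters()`, so the snippet wraps a HuggingFace `CLIPTextModel` to mimic the CompVis-style `cond_stage_model` passed in `setup_optimizer`; the wrapper class and checkpoint name are illustrative assumptions, not part of the patch.

```python
import torch
from transformers import CLIPTextModel

from mu.helpers.utils import param_choices


class FrozenCLIPWrapper(torch.nn.Module):
    """Hypothetical stand-in for the cond_stage_model handed to param_choices."""

    def __init__(self, name="openai/clip-vit-large-patch14"):
        super().__init__()
        # param_choices expects `.transformer` to expose a `.text_model` attribute.
        self.transformer = CLIPTextModel.from_pretrained(name)


text_encoder = FrozenCLIPWrapper()

# Tune only the MLP ("ffn") blocks of encoder layers 0 and 1; the final
# layer norm stays frozen because final_layer_norm defaults to False.
params = param_choices(
    model=text_encoder,
    train_method="text_encoder_layer01",
    component="ffn",
    final_layer_norm=False,
)
optimizer = torch.optim.Adam(params, lr=1e-5)
```

In the `setup_optimizer` hunk above, only `model` and `train_method` are passed, so `component` and `final_layer_norm` take their defaults (`"all"` and `False`) unless the new `encoder_component` / `norm_layer` config fields are forwarded to `param_choices` elsewhere.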