Feat/text encoder #27

Open · wants to merge 4 commits into main
2 changes: 1 addition & 1 deletion docs/unlearn/configs/uce.md
@@ -43,7 +43,7 @@ class UnifiedConceptEditingConfig(BaseConfig):
self.add_prompts = False # Whether to add additional prompts

# Preserver concepts (comma-separated if multiple)
self.preserver_concepts = (
self.preserve_concepts = (
"A Lion image" # Comma-separated string of preserver concepts
)

19 changes: 19 additions & 0 deletions mu/algorithms/erase_diff/configs/train_config.py
@@ -1,4 +1,5 @@
import os
import re
from mu.core.base_config import BaseConfig
from pathlib import Path

@@ -30,6 +31,8 @@ def __init__(self, **kwargs):
self.use_sample = True
self.num_workers = 4
self.pin_memory = True
self.encoder_component = ["all", "ffn", "attn"]
self.norm_layer = False

for key, value in kwargs.items():
setattr(self, key, value)
@@ -46,6 +49,22 @@ def validate_config(self):
"notime",
"xlayer",
"selflayer",
"text_encoder_layer_full",
"text_encoder_layer0",
"text_encoder_layer01",
"text_encoder_layer012",
"text_encoder_layer0123",
"text_encoder_layer01234",
"text_encoder_layer012345",
"text_encoder_layer0123456",
"text_encoder_layer01234567",
"text_encoder_layer012345678",
"text_encoder_layer0123456789",
"text_encoder_layer012345678910",
"text_encoder_layer01234567891011",
"text_encoder_layer0_11",
"text_encoder_layer01_1011",
"text_encoder_layer012_91011",
]
if self.epochs <= 0:
raise ValueError("epochs should be a positive integer.")
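Note: the new `text_encoder_layer*` entries extend the allowed `train_method` values so that unlearning can be restricted to specific CLIP text-encoder layers rather than the U-Net. A minimal usage sketch (the `EraseDiffConfig` class name is assumed from the file path and is not shown in this diff):

```python
# Sketch only: exercising the extended train_method options.
# `EraseDiffConfig` is an assumed class name; adjust to the real one.
from mu.algorithms.erase_diff.configs.train_config import EraseDiffConfig

config = EraseDiffConfig(
    train_method="text_encoder_layer012",  # assumed to target text-encoder layers 0, 1 and 2
    epochs=5,
)
config.validate_config()  # expected to raise ValueError for an unknown train_method or epochs <= 0
```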
157 changes: 113 additions & 44 deletions mu/algorithms/erase_diff/trainer.py
@@ -9,31 +9,35 @@
import logging

from timm.utils import AverageMeter
import logging
import logging

from mu.core import BaseTrainer
from mu.algorithms.erase_diff.model import EraseDiffModel
from mu.helpers.utils import param_choices


class EraseDiffTrainer(BaseTrainer):
"""
Trainer for the EraseDiff algorithm.
Handles the training loop, loss computation, and optimization.
"""

def __init__(self, model: EraseDiffModel, config: dict, device: str, data_handler, **kwargs):
def __init__(
self, model: EraseDiffModel, config: dict, device: str, data_handler, **kwargs
):
"""
Initialize the EraseDiffTrainer.

Args:
model (EraseDiffModel): Instance of EraseDiffModel.
config (dict): Configuration dictionary.
device (str): Device to perform training on.
device_orig (str): Device for the original (frozen) model.
data_handler (EraseDiffDataHandler): Instance of EraseDiffDataHandler.
data_handler: Instance of EraseDiffDataHandler.
**kwargs: Additional keyword arguments.
"""
super().__init__(model, config, **kwargs)
self.device = device
# Note: the passed model is an EraseDiffModel, whose attribute "model" is the actual network.
self.model = model.model
self.sampler = None
self.sampler_orig = None
@@ -47,50 +51,84 @@ def setup_optimizer(self):
Setup the optimizer based on the training method.
"""
# Select parameters to train based on train_method
train_method = self.config.get('train_method', 'xattn')
train_method = self.config.get("train_method", "xattn")
parameters = []
for name, param in self.model.model.named_parameters():
if train_method == 'full':
parameters.append(param)
elif train_method == 'xattn' and 'attn2' in name:
parameters.append(param)
elif train_method == 'selfattn' and 'attn1' in name:
if train_method == "full":
for name, param in self.model.model.named_parameters():
parameters.append(param)
elif train_method == 'noxattn':
if not (name.startswith('out.') or 'attn2' in name or 'time_embed' in name):
elif train_method == "xattn":
for name, param in self.model.model.named_parameters():
if "attn2" in name:
parameters.append(param)
elif train_method == "selfattn":
for name, param in self.model.model.named_parameters():
if "attn1" in name:
parameters.append(param)
elif train_method == 'xlayer':
if 'attn2' in name and ('output_blocks.6.' in name or 'output_blocks.8.' in name):
elif train_method == "noxattn":
for name, param in self.model.model.named_parameters():
if not (
name.startswith("out.") or "attn2" in name or "time_embed" in name
):
parameters.append(param)
elif train_method == 'selflayer':
if 'attn1' in name and ('input_blocks.4.' in name or 'input_blocks.7.' in name):
elif train_method == "xlayer":
for name, param in self.model.model.named_parameters():
if "attn2" in name and (
"output_blocks.6." in name or "output_blocks.8." in name
):
parameters.append(param)

self.optimizer = torch.optim.Adam(parameters, lr=self.config.get('lr', 1e-5))
elif train_method == "selflayer":
for name, param in self.model.model.named_parameters():
if "attn1" in name and (
"input_blocks.4." in name or "input_blocks.7." in name
):
parameters.append(param)
elif train_method.startswith("text_encoder"):
if hasattr(self.model, "cond_stage_model"):
text_encoder_module = self.model.cond_stage_model
elif hasattr(self.model, "text_encoder"):
text_encoder_module = self.model.text_encoder.text_model
else:
self.logger.error("Model does not have a text encoder attribute.")
text_encoder_module = None

if text_encoder_module is not None:
parameters = param_choices(
model=text_encoder_module, train_method=train_method
)
self.optimizer = torch.optim.Adam(parameters, lr=self.config.get("lr", 1e-5))

def train(self):
"""
Execute the training loop.
"""
epochs = self.config.get('epochs', 1)
K_steps = self.config.get('K_steps', 2)
alpha = self.config.get('alpha', 0.1)
epochs = self.config.get("epochs", 1)
K_steps = self.config.get("K_steps", 2)
alpha = self.config.get("alpha", 0.1)

# Retrieve data loaders
data_loaders = self.data_handler.get_data_loaders()
forget_dl = data_loaders.get('forget')
remain_dl = data_loaders.get('remain')
forget_dl = data_loaders.get("forget")
remain_dl = data_loaders.get("remain")

self.logger.info(f"Number of forget samples: {len(forget_dl.dataset)}")
self.logger.info(f"Number of remain samples: {len(remain_dl.dataset)}")

train_method = self.config.get("train_method", "xattn")

for epoch in range(epochs):
self.logger.info(f"Starting Epoch {epoch+1}/{epochs}")
with tqdm(total=len(forget_dl), desc=f'Epoch {epoch+1}/{epochs}') as pbar:
self.model.train()
with tqdm(total=len(forget_dl), desc=f"Epoch {epoch+1}/{epochs}") as pbar:
# NEW: For text_encoder training, freeze the overall model and set only text_encoder to train.
if train_method.startswith("text_encoder"):
self.model.model.eval()
if hasattr(self.model.model, "text_encoder"):
self.model.model.text_encoder.train()
else:
self.model.train()

param_i = self.get_param()

for step in range(K_steps):
for step in range(K_steps):
unl_losses = AverageMeter()
for forget_batch in forget_dl:
self.optimizer.zero_grad()
@@ -107,7 +145,9 @@ def train(self):
pseudo_prompts = remain_prompts

# Forget stage
forget_loss = self.compute_forget_loss(forget_images, forget_prompts, pseudo_prompts)
forget_loss = self.compute_forget_loss(
forget_images, forget_prompts, pseudo_prompts
)

forget_loss.backward()
self.optimizer.step()
@@ -118,7 +158,13 @@

# Remain stage
for remain_batch in remain_dl:
self.model.train()
if train_method.startswith("text_encoder"):
self.model.model.eval()
if hasattr(self.model.model, "text_encoder"):
self.model.model.text_encoder.train()
else:
self.model.train()

self.optimizer.zero_grad()

remain_images, remain_prompts = remain_batch
@@ -134,13 +180,15 @@

remain_btch = {
"edited": remain_images.to(self.device),
"edit": {"c_crossattn": remain_prompts}
"edit": {"c_crossattn": remain_prompts},
}
# Remain loss
remain_loss = self.model.shared_step(remain_btch)[0]

# Forget loss within remain stage
unlearn_loss = self.compute_unlearn_loss(forget_images, forget_prompts, pseudo_prompts)
unlearn_loss = self.compute_unlearn_loss(
forget_images, forget_prompts, pseudo_prompts
)

q_loss = unlearn_loss - unl_losses.avg.detach()

@@ -150,7 +198,12 @@

# Logging
wandb.log({"loss": total_loss.item()})
pbar.set_postfix({"loss": total_loss.item() / self.config.get('batch_size', 4)})
pbar.set_postfix(
{
"loss": total_loss.item()
/ self.config.get("batch_size", 4)
}
)
pbar.update(1)

self.logger.info(f"Epoch {epoch+1}/{epochs} completed.")
@@ -187,7 +240,9 @@ def set_param(self, old_param: list):
torch.cuda.empty_cache()
torch.manual_seed(0)

def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list) -> torch.Tensor:
def compute_forget_loss(
self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list
) -> torch.Tensor:
"""
Compute the forget loss.

@@ -201,17 +256,23 @@ def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list,
"""
forget_batch = {
"edited": forget_images.to(self.device),
"edit": {"c_crossattn": forget_prompts}
"edit": {"c_crossattn": forget_prompts},
}
pseudo_batch = {
"edited": forget_images.to(self.device),
"edit": {"c_crossattn": pseudo_prompts}
"edit": {"c_crossattn": pseudo_prompts},
}

forget_input, forget_emb = self.model.get_input(forget_batch, self.model.first_stage_key)
pseudo_input, pseudo_emb = self.model.get_input(pseudo_batch, self.model.first_stage_key)
forget_input, forget_emb = self.model.get_input(
forget_batch, self.model.first_stage_key
)
pseudo_input, pseudo_emb = self.model.get_input(
pseudo_batch, self.model.first_stage_key
)

t = torch.randint(0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device).long()
t = torch.randint(
0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device
).long()
noise = torch.randn_like(forget_input, device=self.device)

forget_noisy = self.model.q_sample(x_start=forget_input, t=t, noise=noise)
@@ -223,7 +284,9 @@ def compute_forget_loss(self, forget_images: torch.Tensor, forget_prompts: list,
forget_loss = self.criteria(forget_out, pseudo_out)
return forget_loss

def compute_unlearn_loss(self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list) -> torch.Tensor:
def compute_unlearn_loss(
self, forget_images: torch.Tensor, forget_prompts: list, pseudo_prompts: list
) -> torch.Tensor:
"""
Compute the unlearn loss within the remain stage.

@@ -237,17 +300,23 @@ def compute_unlearn_loss(self, forget_images: torch.Tensor, forget_prompts: list
"""
forget_batch = {
"edited": forget_images.to(self.device),
"edit": {"c_crossattn": forget_prompts}
"edit": {"c_crossattn": forget_prompts},
}
pseudo_batch = {
"edited": forget_images.to(self.device),
"edit": {"c_crossattn": pseudo_prompts}
"edit": {"c_crossattn": pseudo_prompts},
}

forget_input, forget_emb = self.model.get_input(forget_batch, self.model.first_stage_key)
pseudo_input, pseudo_emb = self.model.get_input(pseudo_batch, self.model.first_stage_key)
forget_input, forget_emb = self.model.get_input(
forget_batch, self.model.first_stage_key
)
pseudo_input, pseudo_emb = self.model.get_input(
pseudo_batch, self.model.first_stage_key
)

t = torch.randint(0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device).long()
t = torch.randint(
0, self.model.num_timesteps, (forget_input.shape[0],), device=self.device
).long()
noise = torch.randn_like(forget_input, device=self.device)

forget_noisy = self.model.q_sample(x_start=forget_input, t=t, noise=noise)
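For reference, the text-encoder branch of `setup_optimizer` delegates parameter selection to `param_choices` from `mu.helpers.utils`, which is not part of this diff. A rough sketch of the kind of layer-wise selection it is assumed to perform (HuggingFace-style parameter names are an assumption, and the real helper may parse `train_method` differently):

```python
import torch.nn as nn


def select_text_encoder_params(text_encoder: nn.Module, layer_indices) -> list:
    """Collect parameters from the chosen CLIP text-encoder layers.

    Sketch only: assumes HuggingFace-style parameter names such as
    "encoder.layers.3.self_attn.q_proj.weight". `param_choices` is assumed to
    map strings like "text_encoder_layer012" to an index set such as {0, 1, 2}.
    """
    selected = []
    for name, param in text_encoder.named_parameters():
        # Keep a parameter if it belongs to one of the requested layers.
        if any(f"encoder.layers.{i}." in name for i in layer_indices):
            selected.append(param)
    return selected


# Example usage (assuming a diffusers-style pipeline object named `model`):
# params = select_text_encoder_params(model.text_encoder.text_model, layer_indices=[0, 1, 2])
# optimizer = torch.optim.Adam(params, lr=1e-5)
```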
8 changes: 4 additions & 4 deletions mu/algorithms/unified_concept_editing/configs/train_config.py
@@ -45,7 +45,7 @@ def __init__(self, **kwargs):
self.add_prompts = False # Whether to add additional prompts

# Preserver concepts (comma-separated if multiple)
self.preserver_concepts = (
self.preserve_concepts = (
"A Lion image" # Comma-separated string of preserver concepts
)

@@ -96,9 +96,9 @@ def validate_config(self):
if self.lamb < 0:
raise ValueError("lamb should be a positive value.")

# Validate preserver_concepts
if not isinstance(self.preserver_concepts, str):
raise ValueError("preserver_concepts should be a string.")
# Validate preserve_concepts
if not isinstance(self.preserve_concepts, str):
raise ValueError("preserve_concepts should be a string.")

# Validate base model
if self.base not in ["stable-diffusion-v1-4"]:
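Side note: `preserve_concepts` remains a single comma-separated string, so downstream code is presumably expected to split it into individual concept prompts along these lines (illustration only, not code from this PR):

```python
preserve_concepts = "A Lion image, A Dog image"

# Split the comma-separated string into a clean list of concept prompts.
concepts = [c.strip() for c in preserve_concepts.split(",") if c.strip()]
# -> ["A Lion image", "A Dog image"]
```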
@@ -35,5 +35,5 @@ add_prompts: false


# Additional parameters
preserver_concepts: "A Lion image" # Comma-separated string if multiple concepts, e.g., "concept1,concept2"
preserve_concepts: "A Lion image" # Comma-separated string if multiple concepts, e.g., "concept1,concept2"
base: "stable-diffusion-v1-4" # Base version of stable diffusion
2 changes: 1 addition & 1 deletion mu/algorithms/unified_concept_editing/scripts/train.py
@@ -54,7 +54,7 @@ def main():
parser.add_argument("--erase_scale", type=float, help="Scale for erasure")
parser.add_argument("--lamb", type=float, help="Lambda parameter")
parser.add_argument("--add_prompts", type=bool, help="Whether to add prompts")
parser.add_argument("--preserver_concepts", type=str, help="Concepts to preserve")
parser.add_argument("--preserve_concepts", type=str, help="Concepts to preserve")
parser.add_argument("--base", type=str, help="Base version of stable diffusion")

args = parser.parse_args()