From 8c487e98303f0405618a4b831b362a9dc3a12811 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 3 Feb 2025 12:01:51 +0000 Subject: [PATCH 01/22] advunlearn refactor --- mu_attack/attackers/soft_prompt.py | 184 ++++++++++ mu_attack/execs/adv_attack.py | 372 +++++++++++++++++++ mu_attack/helpers/utils.py | 556 +++++++++++++++++++++++++++++ 3 files changed, 1112 insertions(+) create mode 100644 mu_attack/attackers/soft_prompt.py create mode 100644 mu_attack/execs/adv_attack.py diff --git a/mu_attack/attackers/soft_prompt.py b/mu_attack/attackers/soft_prompt.py new file mode 100644 index 00000000..8f27cd35 --- /dev/null +++ b/mu_attack/attackers/soft_prompt.py @@ -0,0 +1,184 @@ +import torch +import wandb +from mu.helpers import sample_model +from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id + + +class SoftPromptAttack: + """ + A class to perform a soft prompt attack on the ESD model. + + Attributes: + model: The ESD model. + model_orig: The frozen (original) model. + tokenizer: The tokenizer. + text_encoder: The text encoder. + sampler: The sampler. + emb_0: Unconditional embedding. + emb_p: Conditional embedding. + start_guidance: Guidance scale for sampling. + devices: List of devices to use. + ddim_steps: Number of DDIM steps. + ddim_eta: The eta parameter for DDIM. + image_size: The size (width and height) for generated images. + criteria: The loss criteria function. + k: Number of tokens (or a related parameter for the prompt). + all_embeddings: The preloaded word embeddings. + """ + + def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, + emb_0, emb_p, start_guidance, devices, ddim_steps, ddim_eta, + image_size, criteria, k, all_embeddings): + self.model = model + self.model_orig = model_orig + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.sampler = sampler + self.emb_0 = emb_0 + self.emb_p = emb_p + self.start_guidance = start_guidance + self.devices = devices + self.ddim_steps = ddim_steps + self.ddim_eta = ddim_eta + self.image_size = image_size + self.criteria = criteria + self.k = k + self.all_embeddings = all_embeddings + + def attack(self, global_step, word, attack_round, attack_type, + attack_embd_type, attack_step, attack_lr, + attack_init=None, attack_init_embd=None, attack_method='pgd'): + """ + Perform soft prompt attack on the ESD model. + + Args: + global_step (int): The current global training step. + word (str): The input prompt. + attack_round (int): The current attack round. + attack_type (str): Type of attack ("add" or "insert"). + attack_embd_type (str): Type of adversarial embedding ("condition_embd" or "word_embd"). + attack_step (int): Number of steps to run the attack. + attack_lr (float): Learning rate for the adversarial optimization. + attack_init (str, optional): Initialization method ("latest" or "random"). + attack_init_embd (torch.Tensor, optional): Initial adversarial embedding. + attack_method (str, optional): Attack method to use ("pgd" or "fast_at"). + + Returns: + tuple: Depending on attack_embd_type, returns a tuple (embedding, input_ids) + where the embedding is either a conditional or word embedding. + """ + orig_prompt_len = len(word.split()) + if attack_type == 'add': + # When using "add", update k to match the prompt length. + self.k = orig_prompt_len + + # A helper lambda to sample an image until a given time step. + quick_sample_till_t = lambda x, s, code, t: sample_model( + self.model, self.sampler, x, self.image_size, self.image_size, + self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False + ) + + # --- Tokenization and Embedding --- + text_input = self.tokenizer( + word, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + sot_id, mid_id, replace_id, eot_id = split_id( + text_input.input_ids.to(self.devices[0]), self.k, orig_prompt_len + ) + + text_embeddings = id2embedding( + self.tokenizer, self.all_embeddings, + text_input.input_ids.to(self.devices[0]), self.devices[0] + ) + sot_embd, mid_embd, _, eot_embd = split_embd(text_embeddings, self.k, orig_prompt_len) + + # --- Initialize the adversarial embedding --- + if attack_init == 'latest': + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1, attack_init_embd) + elif attack_init == 'random': + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1) + else: + # Default initialization if no method is provided + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1) + + attack_opt = torch.optim.Adam([adv_embedding], lr=attack_lr) + + # For the condition_embd attack type, construct the initial adversarial condition embedding. + if attack_embd_type == 'condition_embd': + input_adv_condition_embedding = construct_embd( + self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd + ) + adv_input_ids = construct_id( + self.k, replace_id, attack_type, sot_id, eot_id, mid_id + ) + + print(f'[{attack_type}] Starting {attack_method} attack on "{word}"') + + # --- Attack Loop --- + for i in range(attack_step): + # Randomly sample a time step for the attack. + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc) / self.ddim_steps) * 1000) + og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + start_code = torch.randn((1, 4, 64, 64)).to(self.devices[0]) + + with torch.no_grad(): + # Generate an image with the concept from the frozen model. + z = quick_sample_till_t( + self.emb_p.to(self.devices[0]), self.start_guidance, start_code, int(t_enc) + ) + e_0 = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) + ) + e_p = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) + ) + + # For word_embd attack type, update the adversarial condition embedding using the text encoder. + if attack_embd_type == 'word_embd': + input_adv_word_embedding = construct_embd( + self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd + ) + adv_input_ids = construct_id( + self.k, replace_id, attack_type, sot_id, eot_id, mid_id + ) + input_adv_condition_embedding = self.text_encoder( + input_ids=adv_input_ids.to(self.devices[0]), + inputs_embeds=input_adv_word_embedding + )[0] + + # Get the conditional score from the ESD model with the adversarial condition embedding. + e_n = self.model.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), + input_adv_condition_embedding.to(self.devices[0]) + ) + e_0.requires_grad = False + e_p.requires_grad = False + + # Compute the loss between the adversarial output and the target. + loss = self.criteria(e_n.to(self.devices[0]), e_p.to(self.devices[0])) + loss.backward() + + if attack_method == 'pgd': + attack_opt.step() + elif attack_method == 'fast_at': + adv_embedding.grad.sign_() + attack_opt.step() + else: + raise ValueError('attack_method must be either pgd or fast_at') + + wandb.log({'Attack_Loss': loss.item()}, step=global_step + i) + wandb.log({'Train_Loss': 0.0}, step=global_step + i) + + # --- Return the adversarial embeddings and input IDs --- + if attack_embd_type == 'condition_embd': + return input_adv_condition_embedding, adv_input_ids + elif attack_embd_type == 'word_embd': + return input_adv_word_embedding, adv_input_ids + else: + raise ValueError('attack_embd_type must be either condition_embd or word_embd') diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py new file mode 100644 index 00000000..3d59e288 --- /dev/null +++ b/mu_attack/execs/adv_attack.py @@ -0,0 +1,372 @@ + +from mu.helpers import sample_model +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_attack.helpers.utils import id2embedding, param_choices, get_models +from mu_attack.attackers.soft_prompt import SoftPromptAttack +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL +import torch +from tqdm import tqdm +import random +import argparse +import wandb +from pathlib import Path +import os + + +class AdvUnlearn: + """ + Class for adversarial unlearning training. + + This class wraps the full training pipeline including prompt cleaning, + attack (adversarial prompt generation), and retention-based regularized training. + """ + def __init__( + self, + prompt, + dataset_retain, + retain_batch, + retain_train, + retain_step, + retain_loss_w, + attack_method, + train_method, + norm_layer, + component, + start_guidance, + negative_guidance, + iterations, + save_interval, + lr, + config_path, + ckpt_path, + diffusers_config_path, + output_dir, + devices, + seperator=None, + image_size=512, + ddim_steps=50, + adv_prompt_num=3, + attack_embd_type='word_embd', + attack_type='prefix_k', + attack_init='latest', + warmup_iter=200, + attack_step=30, + attack_lr=1e-2, + adv_prompt_update_step=20 + ): + # General training and attack settings + self.prompt = prompt + self.dataset_retain = dataset_retain + self.retain_batch = retain_batch + self.retain_train = retain_train + self.retain_step = retain_step + self.retain_loss_w = retain_loss_w + self.attack_method = attack_method + self.train_method = train_method + self.norm_layer = norm_layer + self.component = component + self.start_guidance = start_guidance + self.negative_guidance = negative_guidance + self.iterations = iterations + self.save_interval = save_interval + self.lr = lr + self.config_path = config_path + self.ckpt_path = ckpt_path + self.diffusers_config_path = diffusers_config_path + self.output_dir = output_dir + self.devices = devices + self.seperator = seperator + self.image_size = image_size + self.ddim_steps = ddim_steps + self.adv_prompt_num = adv_prompt_num + self.attack_embd_type = attack_embd_type + self.attack_type = attack_type + self.attack_init = attack_init + self.warmup_iter = warmup_iter + self.attack_step = attack_step + self.attack_lr = attack_lr + self.adv_prompt_update_step = adv_prompt_update_step + + # Will be set during training. + self.words = None + self.retain_dataset = None + self.tokenizer = None + self.text_encoder = None + self.custom_text_encoder = None + self.all_embeddings = None + self.vae = None + self.model_orig = None + self.sampler_orig = None + self.model = None + self.sampler = None + self.parameters = None + self.opt = None + self.criteria = torch.nn.MSELoss() + + # For adversarial prompt update + self.adv_word_embd = None + self.adv_condition_embd = None + self.adv_input_ids = None + + def setup(self): + """Stage 0 & 1: Prompt cleaning and training setup.""" + # --- Prompt cleaning --- + word_print = self.prompt.replace(' ', '') + # Special cases for certain prompts + if self.prompt == 'allartist': + self.prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng" + if self.prompt == 'i2p': + self.prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood" + if self.prompt == "artifact": + self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") + + if self.seperator is not None: + self.words = [w.strip() for w in self.prompt.split(self.seperator)] + else: + self.words = [self.prompt] + print(f'The Concept Prompt to be unlearned: {self.words}') + + # Create a retaining dataset (assumed to be a prompt dataset) + self.retain_dataset = retain_prompt(self.dataset_retain) + + # --- Training Setup --- + ddim_eta = 0 # constant value for training + + model_name_or_path = "CompVis/stable-diffusion-v1-4" + cache_path = ".cache" + # Load the VAE + self.vae = AutoencoderKL.from_pretrained(model_name_or_path, subfolder="vae", cache_dir=cache_path).to(self.devices[0]) + # Load tokenizer and text encoder + self.tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path, subfolder="tokenizer", cache_dir=cache_path) + self.text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="text_encoder", cache_dir=cache_path).to(self.devices[0]) + self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) + self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) + + # Load models using your helper function (assumed to be defined in utils) + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models(self.config_path, self.ckpt_path, self.devices) + self.model_orig.eval() + + # Setup trainable parameters based on train_method + if 'text_encoder' in self.train_method: + self.parameters = param_choices(model=self.custom_text_encoder, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) + else: + self.parameters = param_choices(model=self.model, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) + + self.opt = torch.optim.Adam(self.parameters, lr=self.lr) + + return word_print # For later use in saving history + + def train(self): + """Stage 2: Training loop.""" + word_print = self.setup() + ddim_eta = 0 # As used in training + + # A lambda function to sample until a given time step. + quick_sample_till_t = lambda x, s, code, batch, t: sample_model( + self.model, self.sampler, + x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False + ) + + losses = [] + history = [] + global_step = 0 + attack_round = 0 + + # Create a tqdm progress bar + pbar = tqdm(range(self.iterations)) + for i in pbar: + # --- Update adversarial prompt every adv_prompt_update_step iterations --- + if i % self.adv_prompt_update_step == 0: + # Reset the retaining dataset if needed + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + + # Randomly choose one prompt from the list + word = random.sample(self.words, 1)[0] + text_input = self.tokenizer( + word, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, text_input.input_ids.to(self.devices[0]), self.devices[0]) + + # Get conditional embeddings from the frozen model + emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + + # --- Attack Step: Get adversarial prompt --- + if i >= self.warmup_iter: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.eval() + + if attack_round == 0: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + None, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + None, self.attack_method + ) + else: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + self.adv_word_embd, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + self.adv_condition_embd, self.attack_method + ) + global_step += self.attack_step + attack_round += 1 + + # --- Set models to training/eval modes based on training method --- + if 'text_encoder' in self.train_method: + self.custom_text_encoder.text_encoder.train() + self.custom_text_encoder.text_encoder.requires_grad_(True) + self.model.eval() + else: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.train() + self.opt.zero_grad() + + # --- Retaining prompts for retention regularization --- + if self.retain_train == 'reg': + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + retain_text_input = self.tokenizer( + retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) + # Reshape to [batch, 77, embedding_dim] + retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) + retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] + else: + retain_text_input = None + retain_text_embeddings = None + retain_emb_p = None + retain_emb_n = None + + # --- Compute training loss --- + if i < self.warmup_iter: + # Warmup training uses the original prompt embeddings. + input_ids = text_input.input_ids.to(self.devices[0]) + emb_n = self.custom_text_encoder(input_ids=input_ids, inputs_embeds=text_embeddings)[0] + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, input_ids, self.attack_embd_type + ) + else: + if self.attack_embd_type == 'word_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_word_embd + ) + elif self.attack_embd_type == 'condition_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_condition_embd + ) + + # Backpropagate loss and update weights. + loss.backward() + losses.append(loss.item()) + pbar.set_postfix({"loss": loss.item()}) + history.append(loss.item()) + wandb.log({'Train_Loss': loss.item()}, step=global_step) + wandb.log({'Attack_Loss': 0.0}, step=global_step) + global_step += 1 + self.opt.step() + + # --- Additional Retention Training (if using iterative retention) --- + if self.retain_train == 'iter': + for r in range(self.retain_step): + print(f'==== Retain Training at step {r} ====') + self.opt.zero_grad() + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc) / self.ddim_steps) * 1000) + og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) + + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_z = quick_sample_till_t(retain_emb_p.to(self.devices[0]), self.start_guidance, retain_start_code, self.retain_batch, int(t_enc)) + retain_e_p = self.model_orig.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_p.to(self.devices[0])) + + retain_text_input = self.tokenizer( + retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) + retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) + retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] + retain_e_n = self.model.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_n.to(self.devices[0])) + + retain_loss = self.criteria(retain_e_n.to(self.devices[0]), retain_e_p.to(self.devices[0])) + retain_loss.backward() + self.opt.step() + + # --- Checkpointing and saving history --- + if (i + 1) % self.save_interval == 0 and (i + 1) != self.iterations and (i + 1) >= self.save_interval: + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, + save_diffusers=True, compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path) + if i % 1 == 0: + save_history(self.output_dir, losses, word_print) + + # --- Stage 3: Save final model and loss curve --- + self.model.eval() + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, + save_diffusers=True, compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path) + save_history(self.output_dir, losses, word_print) \ No newline at end of file diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 2d9d89bc..6ca998f9 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -6,6 +6,9 @@ import json from typing import Optional, Tuple, Union +from mu.helpers.utils import load_model_from_config +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler +import torch.nn.functional as F from torchvision.transforms.functional import InterpolationMode from transformers.modeling_outputs import BaseModelOutputWithPooling import torchvision.transforms as torch_transforms @@ -67,3 +70,556 @@ def convert_time(time_str): total_minutes_direct = hours * 60 + minutes + seconds_microseconds / 60 return total_minutes_direct +def id2embedding(tokenizer, all_embeddings, input_ids, device): + input_one_hot = F.one_hot(input_ids.view(-1), num_classes = len(tokenizer.get_vocab())).float() + input_one_hot = torch.unsqueeze(input_one_hot,0).to(device) + input_embeds = input_one_hot @ all_embeddings + return input_embeds + +def split_id(input_ids, k, orig_prompt_len): + sot_id, mid_id, replace_id, eot_id = torch.split(input_ids, [1, orig_prompt_len, k, 76-orig_prompt_len-k], dim=1) + return sot_id, mid_id, replace_id, eot_id + +def split_embd(input_embed, k, orig_prompt_len): + sot_embd, mid_embd, replace_embd, eot_embd = torch.split(input_embed, [1, orig_prompt_len, k, 76-orig_prompt_len-k ], dim=1) + return sot_embd, mid_embd, replace_embd, eot_embd + +def init_adv(k, tokenizer, all_embeddings, attack_type, device, batch = 1, attack_init_embd = None): + # Different attack types have different initializations (Attack types: add, insert) + adv_embedding = torch.nn.Parameter(torch.randn([batch, k, 768])).to(device) + + if attack_init_embd is not None: + # Use the provided initial adversarial embedding + adv_embedding.data = attack_init_embd[:,1:1+k].data + else: + # Random sample k words from the vocabulary as the initial adversarial words + tmp_ids = torch.randint(0,len(tokenizer),(batch, k)).to(device) + tmp_embeddings = id2embedding(tokenizer, all_embeddings, tmp_ids, device) + tmp_embeddings = tmp_embeddings.reshape(batch, k, 768) + adv_embedding.data = tmp_embeddings.data + adv_embedding = adv_embedding.detach().requires_grad_(True) + + return adv_embedding + +def construct_embd(k, adv_embedding, insertion_location, sot_embd, mid_embd, eot_embd): + if insertion_location == 'prefix_k': # Prepend k words before the original prompt + embedding = torch.cat([sot_embd,adv_embedding,mid_embd,eot_embd],dim=1) + elif insertion_location == 'replace_k': # Replace k words in the original prompt + replace_embd = eot_embd[:,0,:].repeat(1,mid_embd.shape[1],1) + embedding = torch.cat([sot_embd,adv_embedding,replace_embd,eot_embd],dim=1) + elif insertion_location == 'add': # Add perturbation to the original prompt + replace_embd = eot_embd[:,0,:].repeat(1,k,1) + embedding = torch.cat([sot_embd,adv_embedding+mid_embd,replace_embd,eot_embd],dim=1) + elif insertion_location == 'suffix_k': # Append k words after the original prompt + embedding = torch.cat([sot_embd,mid_embd,adv_embedding,eot_embd],dim=1) + elif insertion_location == 'mid_k': # Insert k words in the middle of the original prompt + embedding = [sot_embd,] + total_num = mid_embd.size(1) + embedding.append(mid_embd[:,:total_num//2,:]) + embedding.append(adv_embedding) + embedding.append(mid_embd[:,total_num//2:,:]) + embedding.append(eot_embd) + embedding = torch.cat(embedding,dim=1) + elif insertion_location == 'insert_k': # seperate k words into the original prompt with equal intervals + embedding = [sot_embd,] + total_num = mid_embd.size(1) + internals = total_num // (k+1) + for i in range(k): + embedding.append(mid_embd[:,internals*i:internals*(i+1),:]) + embedding.append(adv_embedding[:,i,:].unsqueeze(1)) + embedding.append(mid_embd[:,internals*(i+1):,:]) + embedding.append(eot_embd) + embedding = torch.cat(embedding,dim=1) + + elif insertion_location == 'per_k_words': + embedding = [sot_embd,] + for i in range(adv_embedding.size(1) - 1): + embedding.append(adv_embedding[:,i,:].unsqueeze(1)) + embedding.append(mid_embd[:,3*i:3*(i+1),:]) + embedding.append(adv_embedding[:,-1,:].unsqueeze(1)) + embedding.append(mid_embd[:,3*(i+1):,:]) + embedding.append(eot_embd) + embedding = torch.cat(embedding,dim=1) + return embedding + +def construct_id(k, adv_id, insertion_location,sot_id,eot_id,mid_id): + if insertion_location == 'prefix_k': + input_ids = torch.cat([sot_id,adv_id,mid_id,eot_id],dim=1) + + elif insertion_location == 'replace_k': + replace_id = eot_id[:,0].repeat(1,mid_id.shape[1]) + input_ids = torch.cat([sot_id,adv_id,replace_id,eot_id],dim=1) + + elif insertion_location == 'add': + replace_id = eot_id[:,0].repeat(1,k) + input_ids = torch.cat([sot_id,mid_id,replace_id,eot_id],dim=1) + + elif insertion_location == 'suffix_k': + input_ids = torch.cat([sot_id,mid_id,adv_id,eot_id],dim=1) + + elif insertion_location == 'mid_k': + input_ids = [sot_id,] + total_num = mid_id.size(1) + input_ids.append(mid_id[:,:total_num//2]) + input_ids.append(adv_id) + input_ids.append(mid_id[:,total_num//2:]) + input_ids.append(eot_id) + input_ids = torch.cat(input_ids,dim=1) + + elif insertion_location == 'insert_k': + input_ids = [sot_id,] + total_num = mid_id.size(1) + internals = total_num // (k+1) + for i in range(k): + input_ids.append(mid_id[:,internals*i:internals*(i+1)]) + input_ids.append(adv_id[:,i].unsqueeze(1)) + input_ids.append(mid_id[:,internals*(i+1):]) + input_ids.append(eot_id) + input_ids = torch.cat(input_ids,dim=1) + + elif insertion_location == 'per_k_words': + input_ids = [sot_id,] + for i in range(adv_id.size(1) - 1): + input_ids.append(adv_id[:,i].unsqueeze(1)) + input_ids.append(mid_id[:,3*i:3*(i+1)]) + input_ids.append(adv_id[:,-1].unsqueeze(1)) + input_ids.append(mid_id[:,3*(i+1):]) + input_ids.append(eot_id) + input_ids = torch.cat(input_ids,dim=1) + return input_ids + + +def param_choices(model, train_method, component='all', final_layer_norm=False): + # choose parameters to train based on train_method + parameters = [] + + # Text Encoder FUll Weight Tuning + if train_method == 'text_encoder_full': + for name, param in model.text_encoder.text_model.named_parameters(): + # Final Layer Norm + if name.startswith('final_layer_norm'): + if component == 'all' or final_layer_norm==True: + print(name) + parameters.append(param) + else: + pass + + # Transformer layers + elif name.startswith('encoder'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + # Embedding layers + else: + pass + + # Text Encoder Layer 0 Tuning + elif train_method == 'text_encoder_layer0': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456789': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678910': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567891011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0_11': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + + elif train_method == 'text_encoder_layer01_1011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012_91011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + # UNet Model Tuning + else: + for name, param in model.model.diffusion_model.named_parameters(): + # train all layers except x-attns and time_embed layers + if train_method == 'noxattn': + if name.startswith('out.') or 'attn2' in name or 'time_embed' in name: + pass + else: + print(name) + parameters.append(param) + + # train only self attention layers + if train_method == 'selfattn': + if 'attn1' in name: + print(name) + parameters.append(param) + + # train only x attention layers + if train_method == 'xattn': + if 'attn2' in name: + print(name) + parameters.append(param) + + # train all layers + if train_method == 'full': + print(name) + parameters.append(param) + + # train all layers except time embed layers + if train_method == 'notime': + if not (name.startswith('out.') or 'time_embed' in name): + print(name) + parameters.append(param) + if train_method == 'xlayer': + if 'attn2' in name: + if 'output_blocks.6.' in name or 'output_blocks.8.' in name: + print(name) + parameters.append(param) + if train_method == 'selflayer': + if 'attn1' in name: + if 'input_blocks.4.' in name or 'input_blocks.7.' in name: + print(name) + parameters.append(param) + + return parameters + + +def get_models(config_path, ckpt_path, devices): + model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) + sampler_orig = DDIMSampler(model_orig) + + model = load_model_from_config(config_path, ckpt_path, devices[0]) + sampler = DDIMSampler(model) + + return model_orig, sampler_orig, model, sampler \ No newline at end of file From 11147e2fac3bc15353f17dc84d9a1fd1193b5c1b Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 3 Feb 2025 19:50:07 +0545 Subject: [PATCH 02/22] adv unlearn refactor --- mu_attack/configs/adv_unlearn/__init__.py | 1 + .../configs/adv_unlearn/adv_unlearn_config.py | 80 +++ .../configs/adv_unlearn/model_config.yaml | 70 ++ mu_attack/execs/adv_attack.py | 143 ++-- mu_attack/helpers/utils.py | 670 +++++++++++++++++- 5 files changed, 877 insertions(+), 87 deletions(-) create mode 100644 mu_attack/configs/adv_unlearn/__init__.py create mode 100644 mu_attack/configs/adv_unlearn/adv_unlearn_config.py create mode 100644 mu_attack/configs/adv_unlearn/model_config.yaml diff --git a/mu_attack/configs/adv_unlearn/__init__.py b/mu_attack/configs/adv_unlearn/__init__.py new file mode 100644 index 00000000..cc01a1ec --- /dev/null +++ b/mu_attack/configs/adv_unlearn/__init__.py @@ -0,0 +1 @@ +from .adv_unlearn_config import AdvUnlearnConfig, adv_unlearn_config \ No newline at end of file diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py new file mode 100644 index 00000000..99dc1172 --- /dev/null +++ b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py @@ -0,0 +1,80 @@ +import os +from pathlib import Path +from mu.core.base_config import BaseConfig + +current_dir = Path(__file__).parent + +class AdvUnlearnConfig(BaseConfig): + def __init__(self, **kwargs): + # Inference & Model Paths + self.config_path = current_dir / "configs/stable-diffusion/v1-inference.yaml" + self.ckpt_path = "models/sd-v1-4-full-ema.ckpt" + self.diffusers_config_path = current_dir / "diffusers_unet_config.json" + self.model_name_or_path = "CompVis/stable-diffusion-v1-4" + self.cache_path = ".cache" + + # Devices & IO + self.devices = "0,0" # You can later parse this string into a list if needed. + self.seperator = None + self.output_dir = "outputs/adv_unlearn" + + # Image & Diffusion Sampling + self.image_size = 512 + self.ddim_steps = 50 + self.start_guidance = 3.0 + self.negative_guidance = 1.0 + + # Training Setup + self.prompt = "nudity" + self.dataset_retain = "coco" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' + self.retain_batch = 5 + self.retain_train = "iter" # Options: 'iter' or 'reg' + self.retain_step = 1 + self.retain_loss_w = 1.0 + self.ddim_eta = 0 + + self.train_method = "text_encoder_full" #choices: text_encoder_full', 'text_encoder_layer0', 'text_encoder_layer01', 'text_encoder_layer012', 'text_encoder_layer0123', 'text_encoder_layer01234', 'text_encoder_layer012345', 'text_encoder_layer0123456', 'text_encoder_layer01234567', 'text_encoder_layer012345678', 'text_encoder_layer0123456789', 'text_encoder_layer012345678910', 'text_encoder_layer01234567891011', 'text_encoder_layer0_11','text_encoder_layer01_1011', 'text_encoder_layer012_91011', 'noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer + self.norm_layer = False # This is a flag; use True if you wish to update the norm layer. + self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' + self.component = "all" # Choices: 'all', 'ffn', 'attn' + self.iterations = 1000 + self.save_interval = 200 + self.lr = 1e-5 + + # Adversarial Attack Hyperparameters + self.adv_prompt_num = 1 + self.attack_embd_type = "word_embd" # Choices: 'word_embd', 'condition_embd' + self.attack_type = "prefix_k" # Choices: 'replace_k', 'add', 'prefix_k', 'suffix_k', 'mid_k', 'insert_k', 'per_k_words' + self.attack_init = "latest" # Choices: 'random', 'latest' + self.attack_step = 30 + self.adv_prompt_update_step = 1 + self.attack_lr = 1e-3 + self.warmup_iter = 200 + + # Override default values with any provided keyword arguments. + for key, value in kwargs.items(): + setattr(self, key, value) + + def validate_config(self): + """ + Perform basic validation on the config parameters. + """ + if self.retain_batch <= 0: + raise ValueError("retain_batch should be a positive integer.") + if self.lr <= 0: + raise ValueError("Learning rate (lr) should be positive.") + if self.image_size <= 0: + raise ValueError("Image size should be a positive integer.") + if self.iterations <= 0: + raise ValueError("Iterations must be a positive integer.") + if not os.path.exists(self.config_path): + raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") + if not os.path.exists(self.ckpt_path): + raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.") + if not os.path.exists(self.diffusers_config_path): + raise FileNotFoundError(f"Diffusers config file {self.diffusers_config_path} does not exist.") + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + +adv_unlearn_config = AdvUnlearnConfig() + diff --git a/mu_attack/configs/adv_unlearn/model_config.yaml b/mu_attack/configs/adv_unlearn/model_config.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/mu_attack/configs/adv_unlearn/model_config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index 3d59e288..7ca4241c 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -1,17 +1,19 @@ -from mu.helpers import sample_model -from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -from mu_attack.helpers.utils import id2embedding, param_choices, get_models -from mu_attack.attackers.soft_prompt import SoftPromptAttack -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKL + import torch from tqdm import tqdm import random -import argparse import wandb -from pathlib import Path -import os + +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL + +from mu_attack.configs.adv_unlearn import AdvUnlearnConfig +from mu.helpers import sample_model +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_attack.attackers.soft_prompt import SoftPromptAttack +from mu_attack.helpers.utils import id2embedding, param_choices, get_models, retain_prompt, get_train_loss_retain,save_text_encoder, save_model, save_history + class AdvUnlearn: @@ -23,70 +25,50 @@ class AdvUnlearn: """ def __init__( self, - prompt, - dataset_retain, - retain_batch, - retain_train, - retain_step, - retain_loss_w, - attack_method, - train_method, - norm_layer, - component, - start_guidance, - negative_guidance, - iterations, - save_interval, - lr, - config_path, - ckpt_path, - diffusers_config_path, - output_dir, - devices, - seperator=None, - image_size=512, - ddim_steps=50, - adv_prompt_num=3, - attack_embd_type='word_embd', - attack_type='prefix_k', - attack_init='latest', - warmup_iter=200, - attack_step=30, - attack_lr=1e-2, - adv_prompt_update_step=20 + config: AdvUnlearnConfig, + **kwargs ): - # General training and attack settings - self.prompt = prompt - self.dataset_retain = dataset_retain - self.retain_batch = retain_batch - self.retain_train = retain_train - self.retain_step = retain_step - self.retain_loss_w = retain_loss_w - self.attack_method = attack_method - self.train_method = train_method - self.norm_layer = norm_layer - self.component = component - self.start_guidance = start_guidance - self.negative_guidance = negative_guidance - self.iterations = iterations - self.save_interval = save_interval - self.lr = lr - self.config_path = config_path - self.ckpt_path = ckpt_path - self.diffusers_config_path = diffusers_config_path - self.output_dir = output_dir - self.devices = devices - self.seperator = seperator - self.image_size = image_size - self.ddim_steps = ddim_steps - self.adv_prompt_num = adv_prompt_num - self.attack_embd_type = attack_embd_type - self.attack_type = attack_type - self.attack_init = attack_init - self.warmup_iter = warmup_iter - self.attack_step = attack_step - self.attack_lr = attack_lr - self.adv_prompt_update_step = adv_prompt_update_step + self.config = config.__dict__ + for key, value in kwargs.items(): + setattr(config, key, value) + + config.validate_config() + + self.config = config + self.prompt = config.prompt + self.dataset_retain = config.dataset_retain + self.retain_batch = config.retain_batch + self.retain_train = config.retain_train + self.retain_step = config.retain_step + self.retain_loss_w = config.retain_loss_w + self.attack_method = config.attack_method + self.train_method = config.train_method + self.norm_layer = config.norm_layer + self.component = config.component + self.model_name_or_path = config.model_name_or_path + self.start_guidance = config.start_guidance + self.negative_guidance = config.negative_guidance + self.iterations = config.iterations + self.save_interval = config.save_interval + self.lr = config.lr + self.config_path = config.config_path + self.ckpt_path = config.ckpt_path + self.diffusers_config_path = config.diffusers_config_path + self.output_dir = config.output_dir + self.devices = config.devices + self.seperator = config.seperator + self.image_size = config.image_size + self.ddim_steps = config.ddim_steps + self.adv_prompt_num = config.adv_prompt_num + self.attack_embd_type = config.attack_embd_type + self.attack_type = config.attack_type + self.attack_init = config.attack_init + self.warmup_iter = config.warmup_iter + self.attack_step = config.attack_step + self.attack_lr = config.attack_lr + self.adv_prompt_update_step = config.adv_prompt_update_step + self.ddim_eta = config.ddim_eta + self.cache_path = config.cache_path # Will be set during training. self.words = None @@ -133,15 +115,15 @@ def setup(self): self.retain_dataset = retain_prompt(self.dataset_retain) # --- Training Setup --- - ddim_eta = 0 # constant value for training + ddim_eta = self.ddim_eta # constant value for training - model_name_or_path = "CompVis/stable-diffusion-v1-4" - cache_path = ".cache" + + # Load the VAE - self.vae = AutoencoderKL.from_pretrained(model_name_or_path, subfolder="vae", cache_dir=cache_path).to(self.devices[0]) + self.vae = AutoencoderKL.from_pretrained(self.model_name_or_path, subfolder="vae", cache_dir=self.cache_path).to(self.devices[0]) # Load tokenizer and text encoder - self.tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path, subfolder="tokenizer", cache_dir=cache_path) - self.text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="text_encoder", cache_dir=cache_path).to(self.devices[0]) + self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path) + self.text_encoder = CLIPTextModel.from_pretrained(self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path).to(self.devices[0]) self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) @@ -162,7 +144,7 @@ def setup(self): def train(self): """Stage 2: Training loop.""" word_print = self.setup() - ddim_eta = 0 # As used in training + ddim_eta = self.ddim_eta # As used in training # A lambda function to sample until a given time step. quick_sample_till_t = lambda x, s, code, batch, t: sample_model( @@ -369,4 +351,5 @@ def train(self): save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, save_diffusers=True, compvis_config_file=self.config_path, diffusers_config_file=self.diffusers_config_path) - save_history(self.output_dir, losses, word_print) \ No newline at end of file + save_history(self.output_dir, losses, word_print) + diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 6ca998f9..0b6e7d93 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -1,19 +1,82 @@ import os -import torch -from PIL import Image import pandas as pd +import random import yaml -import json -from typing import Optional, Tuple, Union +import numpy as np +import matplotlib.pyplot as plt +from omegaconf import OmegaConf -from mu.helpers.utils import load_model_from_config -from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + +import torch import torch.nn.functional as F from torchvision.transforms.functional import InterpolationMode from transformers.modeling_outputs import BaseModelOutputWithPooling import torchvision.transforms as torch_transforms +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LDMTextToImagePipeline, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from mu.helpers import sample_model +from mu.helpers.utils import load_model_from_config +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + + + +class PromptDataset: + def __init__(self, csv_file): + self.data = pd.read_csv(csv_file) + self.unseen_indices = list(self.data.index) # 保存所有未见过的索引 + def get_random_prompts(self, num_prompts=1): + # Ensure that the number of prompts requested is not greater than the number of unseen prompts + num_prompts = min(num_prompts, len(self.unseen_indices)) + + # Randomly select num_prompts indices from the list of unseen indices + selected_indices = random.sample(self.unseen_indices, num_prompts) + + # Remove the selected indices from the list of unseen indices + for index in selected_indices: + self.unseen_indices.remove(index) + + # return the prompts corresponding to the selected indices + return self.data.loc[selected_indices, 'prompt'].tolist() + + def has_unseen_prompts(self): + # check if there are any unseen prompts + return len(self.unseen_indices) > 0 + + def reset(self): + self.unseen_indices = list(self.data.index) + + def check_unseen_prompt_count(self): + return len(self.unseen_indices) + + +def retain_prompt(dataset_retain): + # Prompt Dataset to be retained + + if dataset_retain == 'imagenet243': + retain_dataset = PromptDataset('./data/prompts/train/imagenet243_retain.csv') + elif dataset_retain == 'imagenet243_no_filter': + retain_dataset = PromptDataset('./data/prompts/train/imagenet243_no_filter_retain.csv') + elif dataset_retain == 'coco_object': + retain_dataset = PromptDataset('./data/prompts/train/coco_object_retain.csv') + elif dataset_retain == 'coco_object_no_filter': + retain_dataset = PromptDataset('./data/prompts/train/coco_object_no_filter_retain.csv') + else: + raise ValueError('Invalid dataset for retaining prompts') + + return retain_dataset def load_config(yaml_path): """Loads the configuration from a YAML file.""" @@ -622,4 +685,597 @@ def get_models(config_path, ckpt_path, devices): model = load_model_from_config(config_path, ckpt_path, devices[0]) sampler = DDIMSampler(model) - return model_orig, sampler_orig, model, sampler \ No newline at end of file + return model_orig, sampler_orig, model, sampler + +def get_train_loss_retain( retain_batch, retain_train, retain_loss_w, model, model_orig, text_encoder, sampler, emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, start_guidance, negative_guidance, devices, ddim_steps, ddim_eta, image_size, criteria, adv_input_ids, attack_embd_type, adv_embd=None): + """_summary_ + + Args: + model: ESD model + model_orig: frozen DDPM model + sampler: DDIMSampler for DDPM model + + emb_0: unconditional embedding + emb_p: conditional embedding (for ground truth concept) + emb_n: conditional embedding (for modified concept) + + start_guidance: unconditional guidance for ESD model + negative_guidance: negative guidance for ESD model + + devices: list of devices for ESD and DDPM models + ddim_steps: number of steps for DDIMSampler + ddim_eta: eta for DDIMSampler + image_size: image size for DDIMSampler + + criteria: loss function for ESD model + + adv_input_ids: input_ids for adversarial word embedding + adv_emb_n: adversarial conditional embedding + adv_word_emb_n: adversarial word embedding + + Returns: + loss: training loss for ESD model + """ + quick_sample_till_t = lambda x, s, code, batch, t: sample_model(model, sampler, + x, image_size, image_size, ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False) + + + t_enc = torch.randint(ddim_steps, (1,), device=devices[0]) + # time step from 1000 to 0 (0 being good) + og_num = round((int(t_enc)/ddim_steps)*1000) + og_num_lim = round((int(t_enc+1)/ddim_steps)*1000) + + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0]) + + start_code = torch.randn((1, 4, 64, 64)).to(devices[0]) + if retain_train == 'reg': + retain_start_code = torch.randn((retain_batch, 4, 64, 64)).to(devices[0]) + + with torch.no_grad(): + # generate an image with the concept from ESD model + z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, 1, int(t_enc)) # emb_p seems to work better instead of emb_0 + # get conditional and unconditional scores from frozen model at time step t and image z + e_0 = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_0.to(devices[0])) + e_p = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_p.to(devices[0])) + + if retain_train == 'reg': + retain_z = quick_sample_till_t(retain_emb_p.to(devices[0]), start_guidance, retain_start_code, retain_batch, int(t_enc)) # emb_p seems to work better instead of emb_0 + # retain_e_0 = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_0.to(devices[0])) + retain_e_p = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_p.to(devices[0])) + + if adv_embd is None: + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0])) + else: + if attack_embd_type == 'condition_embd': + # Train with adversarial conditional embedding + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_embd.to(devices[0])) + elif attack_embd_type == 'word_embd': + # Train with adversarial word embedding + print('====== Training with adversarial word embedding =====') + adv_emb_n = text_encoder(input_ids = adv_input_ids.to(devices[0]), inputs_embeds=adv_embd.to(devices[0]))[0] + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_emb_n.to(devices[0])) + else: + raise ValueError('attack_embd_type must be either condition_embd or word_embd') + + e_0.requires_grad = False + e_p.requires_grad = False + + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + # loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + # return loss + + if retain_train == 'reg': + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + print('====== Training with retain batch =====') + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + retain_e_n = model.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_n.to(devices[0])) + + # retain_e_0.requires_grad = False + retain_e_p.requires_grad = False + retain_loss = criteria(retain_e_n.to(devices[0]), retain_e_p.to(devices[0])) + + loss = unlearn_loss + retain_loss_w * retain_loss + return loss + + else: + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + return unlearn_loss + +def save_text_encoder(folder_path, model, name, num): + # SAVE MODEL + + # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' + folder_path = f'{folder_path}/models' + os.makedirs(folder_path, exist_ok=True) + if num is not None: + path = f'{folder_path}/TextEncoder-{name}-epoch_{num}.pt' + else: + path = f'{folder_path}/TextEncoder-{name}.pt' + + torch.save(model.state_dict(), path) + + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + unet_params = original_config.model.params.unet_config.params + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + config = dict( + sample_size=image_size // vae_scale_factor, + in_channels=unet_params.in_channels, + out_channels=unet_params.out_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_res_blocks, + cross_attention_dim=unet_params.context_dim, + attention_head_dim=head_dim, + use_linear_projection=use_linear_projection, + ) + + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device='cpu'): + checkpoint_path = path + + original_config_file = compvis_config_file + config_file = diffusers_config_file + num_in_channels = 4 + scheduler_type = 'ddim' + pipeline_type = None + image_size = 512 + prediction_type = 'epsilon' + extract_ema = False + dump_path = path.replace('Compvis','Diffusers') + upcast_attention = False + + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + print("global_step key not found in model") + global_step = None + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + upcast_attention = upcast_attention + if original_config_file is None: + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + + if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: + if not os.path.isfile("v2-inference-v.yaml"): + # model_type = "v2" + os.system( + "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + " -O v2-inference-v.yaml" + ) + original_config_file = "./v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + else: + if not os.path.isfile("v1-inference.yaml"): + # model_type = "v1" + os.system( + "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + " -O v1-inference.yaml" + ) + original_config_file = "./v1-inference.yaml" + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + num_train_timesteps = original_config.model.params.timesteps + beta_start = original_config.model.params.linear_start + beta_end = original_config.model.params.linear_end + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = False + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + torch.save(converted_unet_checkpoint, dump_path) + + + +def save_model(folder_path, model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True): + # SAVE MODEL + + # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' + folder_path = f'{folder_path}/models' + os.makedirs(folder_path, exist_ok=True) + if num is not None: + path = f'{folder_path}/Compvis-UNet-{name}-epoch_{num}.pt' + else: + path = f'{folder_path}/Compvis-UNet-{name}.pt' + if save_compvis: + torch.save(model.state_dict(), path) + + if save_diffusers: + print('Saving Model in Diffusers Format') + savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device=device ) + + +def moving_average(a, n=3) : + ret = np.cumsum(a, dtype=float) + ret[n:] = ret[n:] - ret[:-n] + return ret[n - 1:] / n + +def plot_loss(losses, path,word, n=100): + v = moving_average(losses, n) + plt.plot(v, label=f'{word}_loss') + plt.legend(loc="upper left") + plt.title('Average loss in trainings', fontsize=20) + plt.xlabel('Data point', fontsize=16) + plt.ylabel('Loss value', fontsize=16) + plt.savefig(path) + +def save_history(folder_path, losses, word_print): + folder_path = f'{folder_path}/logs' + os.makedirs(folder_path, exist_ok=True) + with open(f'{folder_path}/loss.txt', 'w') as f: + f.writelines([str(i) for i in losses]) + plot_loss(losses,f'{folder_path}/loss.png' , word_print, n=3) \ No newline at end of file From e55c861db1d55b7159aaaa422987223361d68d57 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 4 Feb 2025 10:33:11 +0545 Subject: [PATCH 03/22] prompt dataset added --- data/prompts/coco_object_no_filter_retain.csv | 244 ++++++++++++++++++ data/prompts/coco_object_retain.csv | 244 ++++++++++++++++++ data/prompts/imagenet243_no_filter_retain.csv | 244 ++++++++++++++++++ data/prompts/imagenet243_retain.csv | 244 ++++++++++++++++++ data/prompts/small_imagenet_prompts.csv | 101 ++++++++ mu_attack/adv_unlearn_environment.yaml | 34 +++ mu_attack/execs/adv_attack.py | 4 +- 7 files changed, 1113 insertions(+), 2 deletions(-) create mode 100644 data/prompts/coco_object_no_filter_retain.csv create mode 100644 data/prompts/coco_object_retain.csv create mode 100644 data/prompts/imagenet243_no_filter_retain.csv create mode 100644 data/prompts/imagenet243_retain.csv create mode 100644 data/prompts/small_imagenet_prompts.csv create mode 100644 mu_attack/adv_unlearn_environment.yaml diff --git a/data/prompts/coco_object_no_filter_retain.csv b/data/prompts/coco_object_no_filter_retain.csv new file mode 100644 index 00000000..1c15a6f5 --- /dev/null +++ b/data/prompts/coco_object_no_filter_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,coco_object,a photo of chair +2,coco_object,a photo of fridge +3,coco_object,a photo of banana +4,coco_object,a photo of street sign +5,coco_object,a photo of headlights +6,coco_object,a photo of shorts +7,coco_object,a photo of handbag +8,coco_object,a photo of skis +9,coco_object,a photo of skateboard +10,coco_object,a photo of chopping board +11,coco_object,a photo of goat +12,coco_object,a photo of playing cards +13,coco_object,a photo of underpants +14,coco_object,a photo of toy cars +15,coco_object,a photo of super hero costume +16,coco_object,a photo of pasta +17,coco_object,a photo of moon +18,coco_object,a photo of basketball +19,coco_object,a photo of radio +20,coco_object,a photo of ipad +21,coco_object,a photo of goldfish +22,coco_object,a photo of jetpack +23,coco_object,a photo of pajamas +24,coco_object,a photo of couch +25,coco_object,a photo of microwave +26,coco_object,a photo of bread +27,coco_object,a photo of umbrella +28,coco_object,a photo of window +29,coco_object,a photo of teddy bear +30,coco_object,a photo of pans +31,coco_object,a photo of hot dog +32,coco_object,a photo of snowboard +33,coco_object,a photo of helicopter +34,coco_object,a photo of washer +35,coco_object,a photo of magazine +36,coco_object,a photo of shirt +37,coco_object,a photo of phone +38,coco_object,a photo of towel +39,coco_object,a photo of necklace +40,coco_object,a photo of bracelet +41,coco_object,a photo of platypus +42,coco_object,a photo of feet +43,coco_object,a photo of road +44,coco_object,a photo of telephone +45,coco_object,a photo of fences +46,coco_object,a photo of aardvark +47,coco_object,a photo of iphone +48,coco_object,a photo of robot +49,coco_object,a photo of car +50,coco_object,a photo of potted plant +51,coco_object,a photo of sink +52,coco_object,a photo of apple +53,coco_object,a photo of scissors +54,coco_object,a photo of legs +55,coco_object,a photo of desk +56,coco_object,a photo of tie +57,coco_object,a photo of stapler +58,coco_object,a photo of table +59,coco_object,a photo of armpits +60,coco_object,a photo of tomato +61,coco_object,a photo of lion +62,coco_object,a photo of key +63,coco_object,a photo of Pig +64,coco_object,a photo of hyppo +65,coco_object,a photo of tablet +66,coco_object,a photo of arms +67,coco_object,a photo of pancake +68,coco_object,a photo of shark +69,coco_object,a photo of fountain +70,coco_object,a photo of movie +71,coco_object,a photo of goal net +72,coco_object,a photo of dinosaur +73,coco_object,a photo of hoop +74,coco_object,a photo of crusher +75,coco_object,a photo of motorcycle +76,coco_object,a photo of tv +77,coco_object,a photo of torso +78,coco_object,a photo of book +79,coco_object,a photo of short sleeve shirt +80,coco_object,a photo of fire hydrant +81,coco_object,a photo of computer +82,coco_object,a photo of stop sign +83,coco_object,a photo of sports ball +84,coco_object,a photo of basketball +85,coco_object,a photo of hoop +86,coco_object,a photo of pants +87,coco_object,a photo of tree +88,coco_object,a photo of bunny +89,coco_object,a photo of frame +90,coco_object,a photo of strawberries +91,coco_object,a photo of fingers +92,coco_object,a photo of corn +93,coco_object,a photo of balloon +94,coco_object,a photo of back +95,coco_object,a photo of swan +96,coco_object,a photo of fax machine +97,coco_object,a photo of head +98,coco_object,a photo of toys +99,coco_object,a photo of unicycle +100,coco_object,a photo of hen +101,coco_object,a photo of animal crackers +102,coco_object,a photo of bird +103,coco_object,a photo of cow +104,coco_object,a photo of toaster +105,coco_object,a photo of boat +106,coco_object,a photo of backpack +107,coco_object,a photo of traffic light +108,coco_object,a photo of hand +109,coco_object,a photo of refrigerator +110,coco_object,a photo of surfboard +111,coco_object,a photo of broccoli +112,coco_object,a photo of mouth +113,coco_object,a photo of door handle +114,coco_object,a photo of hair brush +115,coco_object,a photo of cupcake +116,coco_object,a photo of pumpkin +117,coco_object,a photo of dollar bill +118,coco_object,a photo of ladder +119,coco_object,a photo of ears +120,coco_object,a photo of whale +121,coco_object,a photo of bat +122,coco_object,a photo of goose +123,coco_object,a photo of engine +124,coco_object,a photo of nose +125,coco_object,a photo of basketball court +126,coco_object,a photo of cat +127,coco_object,a photo of airplane +128,coco_object,a photo of bus +129,coco_object,a photo of plate +130,coco_object,a photo of steering wheel +131,coco_object,a photo of eyeglasses +132,coco_object,a photo of teapot +133,coco_object,a photo of pizza +134,coco_object,a photo of sandwich +135,coco_object,a photo of suitcase +136,coco_object,a photo of vase +137,coco_object,a photo of power +138,coco_object,a photo of face +139,coco_object,a photo of pillow +140,coco_object,a photo of light switch +141,coco_object,a photo of eye +142,coco_object,a photo of van +143,coco_object,a photo of doll +144,coco_object,a photo of pineapple +145,coco_object,a photo of milk +146,coco_object,a photo of dryer +147,coco_object,a photo of towel +148,coco_object,a photo of hot air balloon +149,coco_object,a photo of soccer ball +150,coco_object,a photo of legos +151,coco_object,a photo of table cloth +152,coco_object,a photo of horn +153,coco_object,a photo of dog +154,coco_object,a photo of hat +155,coco_object,a photo of train +156,coco_object,a photo of cell phone +157,coco_object,a photo of wine glass +158,coco_object,a photo of cup +159,coco_object,a photo of fork +160,coco_object,a photo of squirrel +161,coco_object,a photo of pen +162,coco_object,a photo of carrot +163,coco_object,a photo of baseball bat +164,coco_object,a photo of tennis racket +165,coco_object,a photo of frogs +166,coco_object,a photo of kangaroo +167,coco_object,a photo of soup +168,coco_object,a photo of candle +169,coco_object,a photo of side table +170,coco_object,a photo of cereal +171,coco_object,a photo of field goal posts +172,coco_object,a photo of fly +173,coco_object,a photo of soccer nets +174,coco_object,a photo of firefly +175,coco_object,a photo of horse +176,coco_object,a photo of license plate +177,coco_object,a photo of mirror +178,coco_object,a photo of mouse +179,coco_object,a photo of chicken +180,coco_object,a photo of blender +181,coco_object,a photo of knife +182,coco_object,a photo of duck +183,coco_object,a photo of kite +184,coco_object,a photo of chandelier +185,coco_object,a photo of baseball glove +186,coco_object,a photo of tiger +187,coco_object,a photo of cake +188,coco_object,a photo of rhinoceros +189,coco_object,a photo of meat +190,coco_object,a photo of desktop +191,coco_object,a photo of wheelchair +192,coco_object,a photo of lizard +193,coco_object,a photo of gate +194,coco_object,a photo of seahorse +195,coco_object,a photo of raft +196,coco_object,a photo of roof +197,coco_object,a photo of turkey +198,coco_object,a photo of sheep +199,coco_object,a photo of bed +200,coco_object,a photo of dining table +201,coco_object,a photo of remote +202,coco_object,a photo of zebra +203,coco_object,a photo of hair drier +204,coco_object,a photo of spoon +205,coco_object,a photo of frisbee +206,coco_object,a photo of orange +207,coco_object,a photo of parking meter +208,coco_object,a photo of giraffe +209,coco_object,a photo of table +210,coco_object,a photo of house +211,coco_object,a photo of owl +212,coco_object,a photo of sailboat +213,coco_object,a photo of window +214,coco_object,a photo of carpet +215,coco_object,a photo of building +216,coco_object,a photo of beans +217,coco_object,a photo of rocket +218,coco_object,a photo of rooster +219,coco_object,a photo of tennis net +220,coco_object,a photo of baseball +221,coco_object,a photo of nectar +222,coco_object,a photo of bottle +223,coco_object,a photo of laptop +224,coco_object,a photo of elephant +225,coco_object,a photo of clock +226,coco_object,a photo of wheel +227,coco_object,a photo of bear +228,coco_object,a photo of guitar +229,coco_object,a photo of toothbrush +230,coco_object,a photo of fish +231,coco_object,a photo of jacket +232,coco_object,a photo of coffee table +233,coco_object,a photo of bench +234,coco_object,a photo of cheese +235,coco_object,a photo of scarf +236,coco_object,a photo of deer +237,coco_object,a photo of muffins +238,coco_object,a photo of cookie +239,coco_object,a photo of bacon +240,coco_object,a photo of cabinets +241,coco_object,a photo of copier +242,coco_object,a photo of seats +243,coco_object,a photo of mat diff --git a/data/prompts/coco_object_retain.csv b/data/prompts/coco_object_retain.csv new file mode 100644 index 00000000..5e05a2a9 --- /dev/null +++ b/data/prompts/coco_object_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,coco_object,a photo of chair +2,coco_object,a photo of fridge +3,coco_object,a photo of banana +4,coco_object,a photo of street sign +5,coco_object,a photo of headlights +6,coco_object,a photo of printer +7,coco_object,a photo of handbag +8,coco_object,a photo of skis +9,coco_object,a photo of skateboard +10,coco_object,a photo of chopping board +11,coco_object,a photo of goat +12,coco_object,a photo of playing cards +13,coco_object,a photo of tire +14,coco_object,a photo of toy cars +15,coco_object,a photo of box +16,coco_object,a photo of pasta +17,coco_object,a photo of moon +18,coco_object,a photo of basketball +19,coco_object,a photo of radio +20,coco_object,a photo of ipad +21,coco_object,a photo of goldfish +22,coco_object,a photo of jetpack +23,coco_object,a photo of bicycle +24,coco_object,a photo of couch +25,coco_object,a photo of microwave +26,coco_object,a photo of bread +27,coco_object,a photo of umbrella +28,coco_object,a photo of window +29,coco_object,a photo of teddy bear +30,coco_object,a photo of pans +31,coco_object,a photo of hot dog +32,coco_object,a photo of snowboard +33,coco_object,a photo of helicopter +34,coco_object,a photo of washer +35,coco_object,a photo of magazine +36,coco_object,a photo of home +37,coco_object,a photo of phone +38,coco_object,a photo of towel +39,coco_object,a photo of necklace +40,coco_object,a photo of bracelet +41,coco_object,a photo of platypus +42,coco_object,a photo of grapes +43,coco_object,a photo of road +44,coco_object,a photo of telephone +45,coco_object,a photo of fences +46,coco_object,a photo of aardvark +47,coco_object,a photo of iphone +48,coco_object,a photo of robot +49,coco_object,a photo of car +50,coco_object,a photo of potted plant +51,coco_object,a photo of sink +52,coco_object,a photo of apple +53,coco_object,a photo of scissors +54,coco_object,a photo of door +55,coco_object,a photo of desk +56,coco_object,a photo of tie +57,coco_object,a photo of stapler +58,coco_object,a photo of table +59,coco_object,a photo of lamp +60,coco_object,a photo of tomato +61,coco_object,a photo of lion +62,coco_object,a photo of key +63,coco_object,a photo of Pig +64,coco_object,a photo of hyppo +65,coco_object,a photo of tablet +66,coco_object,a photo of bat +67,coco_object,a photo of pancake +68,coco_object,a photo of shark +69,coco_object,a photo of fountain +70,coco_object,a photo of movie +71,coco_object,a photo of goal net +72,coco_object,a photo of dinosaur +73,coco_object,a photo of hoop +74,coco_object,a photo of crusher +75,coco_object,a photo of motorcycle +76,coco_object,a photo of tv +77,coco_object,a photo of oven +78,coco_object,a photo of book +79,coco_object,a photo of keyboard +80,coco_object,a photo of fire hydrant +81,coco_object,a photo of computer +82,coco_object,a photo of stop sign +83,coco_object,a photo of sports ball +84,coco_object,a photo of basketball +85,coco_object,a photo of hoop +86,coco_object,a photo of egg +87,coco_object,a photo of tree +88,coco_object,a photo of monkey +89,coco_object,a photo of frame +90,coco_object,a photo of strawberries +91,coco_object,a photo of can +92,coco_object,a photo of corn +93,coco_object,a photo of balloon +94,coco_object,a photo of cabinet +95,coco_object,a photo of swan +96,coco_object,a photo of fax machine +97,coco_object,a photo of football +98,coco_object,a photo of toys +99,coco_object,a photo of unicycle +100,coco_object,a photo of hen +101,coco_object,a photo of animal crackers +102,coco_object,a photo of bird +103,coco_object,a photo of cow +104,coco_object,a photo of toaster +105,coco_object,a photo of boat +106,coco_object,a photo of backpack +107,coco_object,a photo of traffic light +108,coco_object,a photo of bowl +109,coco_object,a photo of refrigerator +110,coco_object,a photo of surfboard +111,coco_object,a photo of broccoli +112,coco_object,a photo of donut +113,coco_object,a photo of door handle +114,coco_object,a photo of hair brush +115,coco_object,a photo of cupcake +116,coco_object,a photo of pumpkin +117,coco_object,a photo of dollar bill +118,coco_object,a photo of ladder +119,coco_object,a photo of gloves +120,coco_object,a photo of whale +121,coco_object,a photo of bat +122,coco_object,a photo of goose +123,coco_object,a photo of engine +124,coco_object,a photo of honey +125,coco_object,a photo of basketball court +126,coco_object,a photo of cat +127,coco_object,a photo of airplane +128,coco_object,a photo of bus +129,coco_object,a photo of plate +130,coco_object,a photo of steering wheel +131,coco_object,a photo of eyeglasses +132,coco_object,a photo of teapot +133,coco_object,a photo of pizza +134,coco_object,a photo of sandwich +135,coco_object,a photo of suitcase +136,coco_object,a photo of vase +137,coco_object,a photo of power +138,coco_object,a photo of outlet +139,coco_object,a photo of pillow +140,coco_object,a photo of light switch +141,coco_object,a photo of fan +142,coco_object,a photo of van +143,coco_object,a photo of doll +144,coco_object,a photo of pineapple +145,coco_object,a photo of milk +146,coco_object,a photo of dryer +147,coco_object,a photo of towel +148,coco_object,a photo of hot air balloon +149,coco_object,a photo of soccer ball +150,coco_object,a photo of legos +151,coco_object,a photo of table cloth +152,coco_object,a photo of horn +153,coco_object,a photo of dog +154,coco_object,a photo of hat +155,coco_object,a photo of train +156,coco_object,a photo of cell phone +157,coco_object,a photo of wine glass +158,coco_object,a photo of cup +159,coco_object,a photo of fork +160,coco_object,a photo of squirrel +161,coco_object,a photo of pen +162,coco_object,a photo of carrot +163,coco_object,a photo of baseball bat +164,coco_object,a photo of tennis racket +165,coco_object,a photo of frogs +166,coco_object,a photo of kangaroo +167,coco_object,a photo of soup +168,coco_object,a photo of candle +169,coco_object,a photo of side table +170,coco_object,a photo of cereal +171,coco_object,a photo of field goal posts +172,coco_object,a photo of fly +173,coco_object,a photo of soccer nets +174,coco_object,a photo of firefly +175,coco_object,a photo of horse +176,coco_object,a photo of license plate +177,coco_object,a photo of mirror +178,coco_object,a photo of mouse +179,coco_object,a photo of chicken +180,coco_object,a photo of blender +181,coco_object,a photo of knife +182,coco_object,a photo of duck +183,coco_object,a photo of kite +184,coco_object,a photo of chandelier +185,coco_object,a photo of baseball glove +186,coco_object,a photo of tiger +187,coco_object,a photo of cake +188,coco_object,a photo of rhinoceros +189,coco_object,a photo of meat +190,coco_object,a photo of desktop +191,coco_object,a photo of wheelchair +192,coco_object,a photo of lizard +193,coco_object,a photo of gate +194,coco_object,a photo of seahorse +195,coco_object,a photo of raft +196,coco_object,a photo of roof +197,coco_object,a photo of turkey +198,coco_object,a photo of sheep +199,coco_object,a photo of bed +200,coco_object,a photo of dining table +201,coco_object,a photo of remote +202,coco_object,a photo of zebra +203,coco_object,a photo of hair drier +204,coco_object,a photo of spoon +205,coco_object,a photo of frisbee +206,coco_object,a photo of orange +207,coco_object,a photo of parking meter +208,coco_object,a photo of giraffe +209,coco_object,a photo of table +210,coco_object,a photo of house +211,coco_object,a photo of owl +212,coco_object,a photo of sailboat +213,coco_object,a photo of window +214,coco_object,a photo of carpet +215,coco_object,a photo of building +216,coco_object,a photo of beans +217,coco_object,a photo of rocket +218,coco_object,a photo of rooster +219,coco_object,a photo of tennis net +220,coco_object,a photo of baseball +221,coco_object,a photo of nectar +222,coco_object,a photo of bottle +223,coco_object,a photo of laptop +224,coco_object,a photo of elephant +225,coco_object,a photo of clock +226,coco_object,a photo of wheel +227,coco_object,a photo of bear +228,coco_object,a photo of guitar +229,coco_object,a photo of toothbrush +230,coco_object,a photo of fish +231,coco_object,a photo of jacket +232,coco_object,a photo of coffee table +233,coco_object,a photo of bench +234,coco_object,a photo of cheese +235,coco_object,a photo of scarf +236,coco_object,a photo of deer +237,coco_object,a photo of muffins +238,coco_object,a photo of cookie +239,coco_object,a photo of bacon +240,coco_object,a photo of cabinets +241,coco_object,a photo of copier +242,coco_object,a photo of seats +243,coco_object,a photo of mat diff --git a/data/prompts/imagenet243_no_filter_retain.csv b/data/prompts/imagenet243_no_filter_retain.csv new file mode 100644 index 00000000..d638bacd --- /dev/null +++ b/data/prompts/imagenet243_no_filter_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,imagenet,a photo of strawberry +2,imagenet,a photo of pedestal +3,imagenet,a photo of scoreboard +4,imagenet,a photo of jaguar +5,imagenet,a photo of ear +6,imagenet,a photo of hummingbird +7,imagenet,a photo of tobacco shop +8,imagenet,a photo of Greater Swiss Mountain dog +9,imagenet,a photo of wine bottle +10,imagenet,a photo of yellow lady-slipper +11,imagenet,a photo of ballpoint +12,imagenet,a photo of Irish water spaniel +13,imagenet,a photo of barn +14,imagenet,a photo of home theater +15,imagenet,a photo of walking stick +16,imagenet,a photo of notebook +17,imagenet,a photo of syringe +18,imagenet,a photo of mask +19,imagenet,a photo of nipple +20,imagenet,a photo of volleyball +21,imagenet,a photo of vulture +22,imagenet,a photo of cloak +23,imagenet,a photo of whiskey jug +24,imagenet,a photo of church +25,imagenet,a photo of bolo tie +26,imagenet,a photo of toy terrier +27,imagenet,a photo of lionfish +28,imagenet,a photo of Bouvier des Flandres +29,imagenet,a photo of photocopier +30,imagenet,a photo of teddy +31,imagenet,a photo of lighter +32,imagenet,a photo of horizontal bar +33,imagenet,a photo of magpie +34,imagenet,a photo of tiger shark +35,imagenet,a photo of wall clock +36,imagenet,a photo of leaf beetle +37,imagenet,a photo of stole +38,imagenet,a photo of basenji +39,imagenet,a photo of tricycle +40,imagenet,a photo of sports car +41,imagenet,a photo of green mamba +42,imagenet,a photo of shopping cart +43,imagenet,a photo of dining table +44,imagenet,a photo of custard apple +45,imagenet,a photo of jackfruit +46,imagenet,a photo of cellular telephone +47,imagenet,a photo of sleeping bag +48,imagenet,a photo of reflex camera +49,imagenet,a photo of beacon +50,imagenet,a photo of bikini +51,imagenet,a photo of dowitcher +52,imagenet,a photo of abacus +53,imagenet,a photo of miniskirt +54,imagenet,a photo of coil +55,imagenet,a photo of lacewing +56,imagenet,a photo of lumbermill +57,imagenet,a photo of white stork +58,imagenet,a photo of parallel bars +59,imagenet,a photo of sliding door +60,imagenet,a photo of lawn mower +61,imagenet,a photo of scuba diver +62,imagenet,a photo of cardigan +63,imagenet,a photo of American coot +64,imagenet,a photo of Border terrier +65,imagenet,a photo of purse +66,imagenet,a photo of gown +67,imagenet,a photo of megalith +68,imagenet,a photo of Polaroid camera +69,imagenet,a photo of green snake +70,imagenet,a photo of guillotine +71,imagenet,a photo of cricket +72,imagenet,a photo of academic gown +73,imagenet,a photo of can opener +74,imagenet,a photo of colobus +75,imagenet,a photo of hip +76,imagenet,a photo of bathtub +77,imagenet,a photo of Norwich terrier +78,imagenet,a photo of Arabian camel +79,imagenet,a photo of Labrador retriever +80,imagenet,a photo of hognose snake +81,imagenet,a photo of overskirt +82,imagenet,a photo of garter snake +83,imagenet,a photo of giant panda +84,imagenet,a photo of Lhasa +85,imagenet,a photo of folding chair +86,imagenet,a photo of lycaenid +87,imagenet,a photo of swimsuit +88,imagenet,a photo of crayfish +89,imagenet,a photo of balance beam +90,imagenet,a photo of junco +91,imagenet,a photo of Christmas stocking +92,imagenet,a photo of quill +93,imagenet,a photo of conch +94,imagenet,a photo of shield +95,imagenet,a photo of trailer truck +96,imagenet,a photo of wooden spoon +97,imagenet,a photo of mountain tent +98,imagenet,a photo of guinea pig +99,imagenet,a photo of tow truck +100,imagenet,a photo of bloodhound +101,imagenet,a photo of rifle +102,imagenet,a photo of grand piano +103,imagenet,a photo of schooner +104,imagenet,a photo of prison +105,imagenet,a photo of Great Pyrenees +106,imagenet,a photo of brain coral +107,imagenet,a photo of nail +108,imagenet,a photo of meat loaf +109,imagenet,a photo of Bedlington terrier +110,imagenet,a photo of steam locomotive +111,imagenet,a photo of crutch +112,imagenet,a photo of Sussex spaniel +113,imagenet,a photo of Great Dane +114,imagenet,a photo of frying pan +115,imagenet,a photo of Tibetan terrier +116,imagenet,a photo of ostrich +117,imagenet,a photo of lampshade +118,imagenet,a photo of standard poodle +119,imagenet,a photo of rock python +120,imagenet,a photo of sunglass +121,imagenet,a photo of plow +122,imagenet,a photo of great grey owl +123,imagenet,a photo of macaque +124,imagenet,a photo of spoonbill +125,imagenet,a photo of jay +126,imagenet,a photo of bookshop +127,imagenet,a photo of quail +128,imagenet,a photo of hyena +129,imagenet,a photo of bee eater +130,imagenet,a photo of croquet ball +131,imagenet,a photo of cabbage butterfly +132,imagenet,a photo of electric fan +133,imagenet,a photo of slug +134,imagenet,a photo of rapeseed +135,imagenet,a photo of worm fence +136,imagenet,a photo of chambered nautilus +137,imagenet,a photo of Windsor tie +138,imagenet,a photo of paintbrush +139,imagenet,a photo of marimba +140,imagenet,a photo of common iguana +141,imagenet,a photo of dial telephone +142,imagenet,a photo of space shuttle +143,imagenet,a photo of hippopotamus +144,imagenet,a photo of cinema +145,imagenet,a photo of cockroach +146,imagenet,a photo of accordion +147,imagenet,a photo of cello +148,imagenet,a photo of water bottle +149,imagenet,a photo of honeycomb +150,imagenet,a photo of bagel +151,imagenet,a photo of lipstick +152,imagenet,a photo of black stork +153,imagenet,a photo of eggnog +154,imagenet,a photo of lorikeet +155,imagenet,a photo of flatworm +156,imagenet,a photo of container ship +157,imagenet,a photo of Egyptian cat +158,imagenet,a photo of miniature pinscher +159,imagenet,a photo of minibus +160,imagenet,a photo of suspension bridge +161,imagenet,a photo of house finch +162,imagenet,a photo of safety pin +163,imagenet,a photo of malamute +164,imagenet,a photo of gibbon +165,imagenet,a photo of lesser panda +166,imagenet,a photo of plunger +167,imagenet,a photo of greenhouse +168,imagenet,a photo of black grouse +169,imagenet,a photo of disk brake +170,imagenet,a photo of tennis ball +171,imagenet,a photo of digital clock +172,imagenet,a photo of cassette +173,imagenet,a photo of streetcar +174,imagenet,a photo of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of bra +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of ballet shoe +211,imagenet,a photo of cannon +212,imagenet,a photo of ballet skirt +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of maillot +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of hand-held computer +226,imagenet,a photo of otter +227,imagenet,a photo of bullet train +228,imagenet,a photo of kit fox +229,imagenet,a photo of typewriter keyboard +230,imagenet,a photo of catamaran +231,imagenet,a photo of ashcan +232,imagenet,a photo of scale +233,imagenet,a photo of pineapple +234,imagenet,a photo of dishrag +235,imagenet,a photo of fountain pen +236,imagenet,a photo of comic book +237,imagenet,a photo of piggy bank +238,imagenet,a photo of water jug +239,imagenet,a photo of electric locomotive +240,imagenet,a photo of gorilla +241,imagenet,a photo of racket +242,imagenet,a photo of binoculars +243,imagenet,a photo of holster diff --git a/data/prompts/imagenet243_retain.csv b/data/prompts/imagenet243_retain.csv new file mode 100644 index 00000000..912e619b --- /dev/null +++ b/data/prompts/imagenet243_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,imagenet,a photo of strawberry +2,imagenet,a photo of pedestal +3,imagenet,a photo of scoreboard +4,imagenet,a photo of jaguar +5,imagenet,a photo of stove +6,imagenet,a photo of hummingbird +7,imagenet,a photo of tobacco shop +8,imagenet,a photo of Greater Swiss Mountain dog +9,imagenet,a photo of wine bottle +10,imagenet,a photo of yellow lady-slipper +11,imagenet,a photo of ballpoint +12,imagenet,a photo of Irish water spaniel +13,imagenet,a photo of barn +14,imagenet,a photo of home theater +15,imagenet,a photo of walking stick +16,imagenet,a photo of notebook +17,imagenet,a photo of syringe +18,imagenet,a photo of mask +19,imagenet,a photo of nipple +20,imagenet,a photo of volleyball +21,imagenet,a photo of vulture +22,imagenet,a photo of cloak +23,imagenet,a photo of whiskey jug +24,imagenet,a photo of church +25,imagenet,a photo of bolo tie +26,imagenet,a photo of toy terrier +27,imagenet,a photo of lionfish +28,imagenet,a photo of Bouvier des Flandres +29,imagenet,a photo of photocopier +30,imagenet,a photo of teddy +31,imagenet,a photo of lighter +32,imagenet,a photo of horizontal bar +33,imagenet,a photo of magpie +34,imagenet,a photo of tiger shark +35,imagenet,a photo of wall clock +36,imagenet,a photo of leaf beetle +37,imagenet,a photo of stole +38,imagenet,a photo of basenji +39,imagenet,a photo of tricycle +40,imagenet,a photo of sports car +41,imagenet,a photo of green mamba +42,imagenet,a photo of shopping cart +43,imagenet,a photo of dining table +44,imagenet,a photo of custard apple +45,imagenet,a photo of jackfruit +46,imagenet,a photo of cellular telephone +47,imagenet,a photo of sleeping bag +48,imagenet,a photo of reflex camera +49,imagenet,a photo of beacon +50,imagenet,a photo of safe +51,imagenet,a photo of dowitcher +52,imagenet,a photo of abacus +53,imagenet,a photo of koala +54,imagenet,a photo of coil +55,imagenet,a photo of lacewing +56,imagenet,a photo of lumbermill +57,imagenet,a photo of white stork +58,imagenet,a photo of parallel bars +59,imagenet,a photo of sliding door +60,imagenet,a photo of lawn mower +61,imagenet,a photo of wolf spider +62,imagenet,a photo of cardigan +63,imagenet,a photo of American coot +64,imagenet,a photo of Border terrier +65,imagenet,a photo of purse +66,imagenet,a photo of hotdog +67,imagenet,a photo of megalith +68,imagenet,a photo of Polaroid camera +69,imagenet,a photo of green snake +70,imagenet,a photo of guillotine +71,imagenet,a photo of cricket +72,imagenet,a photo of academic gown +73,imagenet,a photo of can opener +74,imagenet,a photo of colobus +75,imagenet,a photo of tree frog +76,imagenet,a photo of bathtub +77,imagenet,a photo of Norwich terrier +78,imagenet,a photo of Arabian camel +79,imagenet,a photo of Labrador retriever +80,imagenet,a photo of hognose snake +81,imagenet,a photo of overskirt +82,imagenet,a photo of garter snake +83,imagenet,a photo of giant panda +84,imagenet,a photo of Lhasa +85,imagenet,a photo of folding chair +86,imagenet,a photo of lycaenid +87,imagenet,a photo of plate +88,imagenet,a photo of crayfish +89,imagenet,a photo of balance beam +90,imagenet,a photo of junco +91,imagenet,a photo of Christmas stocking +92,imagenet,a photo of quill +93,imagenet,a photo of conch +94,imagenet,a photo of shield +95,imagenet,a photo of trailer truck +96,imagenet,a photo of wooden spoon +97,imagenet,a photo of mountain tent +98,imagenet,a photo of guinea pig +99,imagenet,a photo of tow truck +100,imagenet,a photo of bloodhound +101,imagenet,a photo of rifle +102,imagenet,a photo of grand piano +103,imagenet,a photo of schooner +104,imagenet,a photo of prison +105,imagenet,a photo of Great Pyrenees +106,imagenet,a photo of brain coral +107,imagenet,a photo of snail +108,imagenet,a photo of meat loaf +109,imagenet,a photo of Bedlington terrier +110,imagenet,a photo of steam locomotive +111,imagenet,a photo of crutch +112,imagenet,a photo of Sussex spaniel +113,imagenet,a photo of Great Dane +114,imagenet,a photo of frying pan +115,imagenet,a photo of Tibetan terrier +116,imagenet,a photo of ostrich +117,imagenet,a photo of lampshade +118,imagenet,a photo of standard poodle +119,imagenet,a photo of rock python +120,imagenet,a photo of sunglass +121,imagenet,a photo of plow +122,imagenet,a photo of great grey owl +123,imagenet,a photo of macaque +124,imagenet,a photo of spoonbill +125,imagenet,a photo of jay +126,imagenet,a photo of bookshop +127,imagenet,a photo of quail +128,imagenet,a photo of hyena +129,imagenet,a photo of bee eater +130,imagenet,a photo of croquet ball +131,imagenet,a photo of cabbage butterfly +132,imagenet,a photo of electric fan +133,imagenet,a photo of slug +134,imagenet,a photo of rapeseed +135,imagenet,a photo of worm fence +136,imagenet,a photo of chambered nautilus +137,imagenet,a photo of Windsor tie +138,imagenet,a photo of paintbrush +139,imagenet,a photo of marimba +140,imagenet,a photo of common iguana +141,imagenet,a photo of dial telephone +142,imagenet,a photo of space shuttle +143,imagenet,a photo of hippopotamus +144,imagenet,a photo of cinema +145,imagenet,a photo of cockroach +146,imagenet,a photo of accordion +147,imagenet,a photo of cello +148,imagenet,a photo of water bottle +149,imagenet,a photo of honeycomb +150,imagenet,a photo of bagel +151,imagenet,a photo of vase +152,imagenet,a photo of black stork +153,imagenet,a photo of eggnog +154,imagenet,a photo of lorikeet +155,imagenet,a photo of flatworm +156,imagenet,a photo of container ship +157,imagenet,a photo of Egyptian cat +158,imagenet,a photo of miniature pinscher +159,imagenet,a photo of minibus +160,imagenet,a photo of suspension bridge +161,imagenet,a photo of house finch +162,imagenet,a photo of safety pin +163,imagenet,a photo of malamute +164,imagenet,a photo of gibbon +165,imagenet,a photo of lesser panda +166,imagenet,a photo of plunger +167,imagenet,a photo of greenhouse +168,imagenet,a photo of black grouse +169,imagenet,a photo of disk brake +170,imagenet,a photo of jeep +171,imagenet,a photo of digital clock +172,imagenet,a photo of cassette +173,imagenet,a photo of streetcar +174,imagenet,a photo of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of pot +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of pillow +211,imagenet,a photo of cannon +212,imagenet,a photo of jean +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of washer +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of hand-held computer +226,imagenet,a photo of otter +227,imagenet,a photo of bullet train +228,imagenet,a photo of kit fox +229,imagenet,a photo of typewriter keyboard +230,imagenet,a photo of catamaran +231,imagenet,a photo of ashcan +232,imagenet,a photo of scale +233,imagenet,a photo of pineapple +234,imagenet,a photo of dishrag +235,imagenet,a photo of fountain pen +236,imagenet,a photo of comic book +237,imagenet,a photo of piggy bank +238,imagenet,a photo of water jug +239,imagenet,a photo of electric locomotive +240,imagenet,a photo of gorilla +241,imagenet,a photo of racket +242,imagenet,a photo of binoculars +243,imagenet,a photo of holster diff --git a/data/prompts/small_imagenet_prompts.csv b/data/prompts/small_imagenet_prompts.csv new file mode 100644 index 00000000..c715e1f4 --- /dev/null +++ b/data/prompts/small_imagenet_prompts.csv @@ -0,0 +1,101 @@ +,case_number,prompt,evaluation_seed,class +0,0,Image of cassette player,4068,cassette player +1,1,Image of cassette player,4667,cassette player +2,2,Image of cassette player,3410,cassette player +3,3,Image of cassette player,3703,cassette player +4,4,Image of cassette player,4937,cassette player +5,5,Image of cassette player,4001,cassette player +6,6,Image of cassette player,2228,cassette player +7,7,Image of cassette player,1217,cassette player +8,8,Image of cassette player,624,cassette player +9,9,Image of cassette player,4697,cassette player +10,10,Image of chain saw,4373,chain saw +11,11,Image of chain saw,2268,chain saw +12,12,Image of chain saw,104,chain saw +13,13,Image of chain saw,1216,chain saw +14,14,Image of chain saw,643,chain saw +15,15,Image of chain saw,3070,chain saw +16,16,Image of chain saw,2426,chain saw +17,17,Image of chain saw,2158,chain saw +18,18,Image of chain saw,2486,chain saw +19,19,Image of chain saw,1434,chain saw +20,20,Image of church,987,church +21,21,Image of church,682,church +22,22,Image of church,4092,church +23,23,Image of church,4096,church +24,24,Image of church,1467,church +25,25,Image of church,474,church +26,26,Image of church,640,church +27,27,Image of church,3395,church +28,28,Image of church,2373,church +29,29,Image of church,3178,church +30,30,Image of gas pump,432,gas pump +31,31,Image of gas pump,4975,gas pump +32,32,Image of gas pump,4745,gas pump +33,33,Image of gas pump,1790,gas pump +34,34,Image of gas pump,4392,gas pump +35,35,Image of gas pump,1527,gas pump +36,36,Image of gas pump,4490,gas pump +37,37,Image of gas pump,1951,gas pump +38,38,Image of gas pump,3013,gas pump +39,39,Image of gas pump,1887,gas pump +40,40,Image of tench,4889,tench +41,41,Image of tench,2747,tench +42,42,Image of tench,3723,tench +43,43,Image of tench,4717,tench +44,44,Image of tench,3199,tench +45,45,Image of tench,3499,tench +46,46,Image of tench,3710,tench +47,47,Image of tench,3682,tench +48,48,Image of tench,3405,tench +49,49,Image of tench,3726,tench +50,50,Image of garbage truck,4264,garbage truck +51,51,Image of garbage truck,4434,garbage truck +52,52,Image of garbage truck,2925,garbage truck +53,53,Image of garbage truck,1441,garbage truck +54,54,Image of garbage truck,3035,garbage truck +55,55,Image of garbage truck,1590,garbage truck +56,56,Image of garbage truck,4153,garbage truck +57,57,Image of garbage truck,1363,garbage truck +58,58,Image of garbage truck,207,garbage truck +59,59,Image of garbage truck,126,garbage truck +60,60,Image of english springer,4782,english springer +61,61,Image of english springer,1026,english springer +62,62,Image of english springer,4423,english springer +63,63,Image of english springer,639,english springer +64,64,Image of english springer,1316,english springer +65,65,Image of english springer,1780,english springer +66,66,Image of english springer,1330,english springer +67,67,Image of english springer,3695,english springer +68,68,Image of english springer,3010,english springer +69,69,Image of english springer,4249,english springer +70,70,Image of golf ball,1912,golf ball +71,71,Image of golf ball,1761,golf ball +72,72,Image of golf ball,529,golf ball +73,73,Image of golf ball,1905,golf ball +74,74,Image of golf ball,55,golf ball +75,75,Image of golf ball,1513,golf ball +76,76,Image of golf ball,2151,golf ball +77,77,Image of golf ball,3368,golf ball +78,78,Image of golf ball,4837,golf ball +79,79,Image of golf ball,289,golf ball +80,80,Image of parachute,1945,parachute +81,81,Image of parachute,841,parachute +82,82,Image of parachute,3651,parachute +83,83,Image of parachute,404,parachute +84,84,Image of parachute,4071,parachute +85,85,Image of parachute,4829,parachute +86,86,Image of parachute,1322,parachute +87,87,Image of parachute,4084,parachute +88,88,Image of parachute,3242,parachute +89,89,Image of parachute,623,parachute +90,90,Image of french horn,1562,french horn +91,91,Image of french horn,2179,french horn +92,92,Image of french horn,3982,french horn +93,93,Image of french horn,4753,french horn +94,94,Image of french horn,2985,french horn +95,95,Image of french horn,3018,french horn +96,96,Image of french horn,1500,french horn +97,97,Image of french horn,488,french horn +98,98,Image of french horn,371,french horn +99,99,Image of french horn,2387,french horn diff --git a/mu_attack/adv_unlearn_environment.yaml b/mu_attack/adv_unlearn_environment.yaml new file mode 100644 index 00000000..1f659d9d --- /dev/null +++ b/mu_attack/adv_unlearn_environment.yaml @@ -0,0 +1,34 @@ +name: AdvUnlearn +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.19.2 + - pip: + - albumentations==0.4.3 + - diffusers==0.12.1 + - opencv-python==4.1.2.30 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.25.1 + - torchmetrics==0.6.0 + - kornia==0.6 + - matplotlib + - wandb + - tabulate + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - -e . diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index 7ca4241c..0d846513 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -116,8 +116,6 @@ def setup(self): # --- Training Setup --- ddim_eta = self.ddim_eta # constant value for training - - # Load the VAE self.vae = AutoencoderKL.from_pretrained(self.model_name_or_path, subfolder="vae", cache_dir=self.cache_path).to(self.devices[0]) @@ -231,6 +229,7 @@ def train(self): self.custom_text_encoder.text_encoder.train() self.custom_text_encoder.text_encoder.requires_grad_(True) self.model.eval() + # print('==== Train text_encoder ====') else: self.custom_text_encoder.text_encoder.eval() self.custom_text_encoder.text_encoder.requires_grad_(False) @@ -254,6 +253,7 @@ def train(self): else: retain_text_input = None retain_text_embeddings = None + # retain_emb_0 = None retain_emb_p = None retain_emb_n = None From 3913de4b32a7660d68b9cbaadda9bc70a39b755a Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 4 Feb 2025 08:57:14 +0000 Subject: [PATCH 04/22] migrated soft prompt attack --- mu_attack/adv_unlearn_environment.yaml | 3 +- mu_attack/attackers/soft_prompt.py | 7 + .../configs/adv_unlearn/adv_unlearn_config.py | 46 +- .../configs/adv_unlearn/model_config.yaml | 10 +- mu_attack/execs/adv_attack.py | 402 ++----- mu_attack/helpers/utils.py | 1071 +---------------- mu_attack/src/clip | 1 + mu_attack/src/taming-transformers | 1 + 8 files changed, 148 insertions(+), 1393 deletions(-) create mode 160000 mu_attack/src/clip create mode 160000 mu_attack/src/taming-transformers diff --git a/mu_attack/adv_unlearn_environment.yaml b/mu_attack/adv_unlearn_environment.yaml index 1f659d9d..8c55a203 100644 --- a/mu_attack/adv_unlearn_environment.yaml +++ b/mu_attack/adv_unlearn_environment.yaml @@ -17,6 +17,7 @@ dependencies: - invisible-watermark - imageio==2.9.0 - imageio-ffmpeg==0.4.2 + - huggingface_hub==0.10.1 - pytorch-lightning==1.4.2 - omegaconf==2.1.1 - test-tube>=0.7.5 @@ -31,4 +32,4 @@ dependencies: - tabulate - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/openai/CLIP.git@main#egg=clip - - -e . + # - -e . diff --git a/mu_attack/attackers/soft_prompt.py b/mu_attack/attackers/soft_prompt.py index 8f27cd35..eb16ad9a 100644 --- a/mu_attack/attackers/soft_prompt.py +++ b/mu_attack/attackers/soft_prompt.py @@ -1,5 +1,8 @@ +# mu_attack/attackers/soft_prompt.py + import torch import wandb + from mu.helpers import sample_model from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id @@ -174,6 +177,10 @@ def attack(self, global_step, word, attack_round, attack_type, wandb.log({'Attack_Loss': loss.item()}, step=global_step + i) wandb.log({'Train_Loss': 0.0}, step=global_step + i) + print(f'Step: {global_step + i}, Attack_Loss: {loss.item()}') + print(f'Step: {global_step + i}, Train_Loss: 0.0') + + # --- Return the adversarial embeddings and input IDs --- if attack_embd_type == 'condition_embd': diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py index 99dc1172..f683af32 100644 --- a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py +++ b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py @@ -7,49 +7,39 @@ class AdvUnlearnConfig(BaseConfig): def __init__(self, **kwargs): # Inference & Model Paths - self.config_path = current_dir / "configs/stable-diffusion/v1-inference.yaml" + self.config_path = current_dir / "model_config.yaml" self.ckpt_path = "models/sd-v1-4-full-ema.ckpt" - self.diffusers_config_path = current_dir / "diffusers_unet_config.json" self.model_name_or_path = "CompVis/stable-diffusion-v1-4" - self.cache_path = ".cache" # Devices & IO self.devices = "0,0" # You can later parse this string into a list if needed. self.seperator = None - self.output_dir = "outputs/adv_unlearn" + self.cache_path = ".cache" # Image & Diffusion Sampling - self.image_size = 512 - self.ddim_steps = 50 self.start_guidance = 3.0 - self.negative_guidance = 1.0 + self.ddim_steps = 50 + # Training Setup + self.image_size = 512 self.prompt = "nudity" - self.dataset_retain = "coco" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' - self.retain_batch = 5 - self.retain_train = "iter" # Options: 'iter' or 'reg' - self.retain_step = 1 - self.retain_loss_w = 1.0 - self.ddim_eta = 0 - - self.train_method = "text_encoder_full" #choices: text_encoder_full', 'text_encoder_layer0', 'text_encoder_layer01', 'text_encoder_layer012', 'text_encoder_layer0123', 'text_encoder_layer01234', 'text_encoder_layer012345', 'text_encoder_layer0123456', 'text_encoder_layer01234567', 'text_encoder_layer012345678', 'text_encoder_layer0123456789', 'text_encoder_layer012345678910', 'text_encoder_layer01234567891011', 'text_encoder_layer0_11','text_encoder_layer01_1011', 'text_encoder_layer012_91011', 'noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer - self.norm_layer = False # This is a flag; use True if you wish to update the norm layer. self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' - self.component = "all" # Choices: 'all', 'ffn', 'attn' - self.iterations = 1000 - self.save_interval = 200 - self.lr = 1e-5 + self.ddim_eta = 0 # Adversarial Attack Hyperparameters self.adv_prompt_num = 1 + self.attack_init_embd = None self.attack_embd_type = "word_embd" # Choices: 'word_embd', 'condition_embd' self.attack_type = "prefix_k" # Choices: 'replace_k', 'add', 'prefix_k', 'suffix_k', 'mid_k', 'insert_k', 'per_k_words' self.attack_init = "latest" # Choices: 'random', 'latest' self.attack_step = 30 - self.adv_prompt_update_step = 1 self.attack_lr = 1e-3 - self.warmup_iter = 200 + + #wandb configs + self.project_name = "quick-canvas-machine-unlearning" + self.experiment_name = f'AdvUnlearn-{self.prompt}-method_Attack_{self.attack_method}' + # Override default values with any provided keyword arguments. for key, value in kwargs.items(): @@ -59,22 +49,10 @@ def validate_config(self): """ Perform basic validation on the config parameters. """ - if self.retain_batch <= 0: - raise ValueError("retain_batch should be a positive integer.") - if self.lr <= 0: - raise ValueError("Learning rate (lr) should be positive.") - if self.image_size <= 0: - raise ValueError("Image size should be a positive integer.") - if self.iterations <= 0: - raise ValueError("Iterations must be a positive integer.") if not os.path.exists(self.config_path): raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") if not os.path.exists(self.ckpt_path): raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.") - if not os.path.exists(self.diffusers_config_path): - raise FileNotFoundError(f"Diffusers config file {self.diffusers_config_path} does not exist.") - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) adv_unlearn_config = AdvUnlearnConfig() diff --git a/mu_attack/configs/adv_unlearn/model_config.yaml b/mu_attack/configs/adv_unlearn/model_config.yaml index d4effe56..cf7f8131 100644 --- a/mu_attack/configs/adv_unlearn/model_config.yaml +++ b/mu_attack/configs/adv_unlearn/model_config.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: stable_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -18,7 +18,7 @@ model: use_ema: False scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler + target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 10000 ] cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -27,7 +27,7 @@ model: f_min: [ 1. ] unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel params: image_size: 32 # unused in_channels: 4 @@ -44,7 +44,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -67,4 +67,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index 0d846513..50459d27 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -1,355 +1,129 @@ - +# mu_attack/execs/adv_attack.py import torch -from tqdm import tqdm import random import wandb from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKL from mu_attack.configs.adv_unlearn import AdvUnlearnConfig -from mu.helpers import sample_model -from mu_attack.tasks.utils.text_encoder import CustomTextEncoder from mu_attack.attackers.soft_prompt import SoftPromptAttack -from mu_attack.helpers.utils import id2embedding, param_choices, get_models, retain_prompt, get_train_loss_retain,save_text_encoder, save_model, save_history - +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_attack.helpers.utils import get_models class AdvUnlearn: """ Class for adversarial unlearning training. - This class wraps the full training pipeline including prompt cleaning, - attack (adversarial prompt generation), and retention-based regularized training. + This class wraps the full training pipeline including adversarial attack + and model handling. """ - def __init__( - self, - config: AdvUnlearnConfig, - **kwargs - ): + def __init__(self, config: AdvUnlearnConfig, **kwargs): self.config = config.__dict__ for key, value in kwargs.items(): setattr(config, key, value) config.validate_config() - self.config = config self.prompt = config.prompt - self.dataset_retain = config.dataset_retain - self.retain_batch = config.retain_batch - self.retain_train = config.retain_train - self.retain_step = config.retain_step - self.retain_loss_w = config.retain_loss_w - self.attack_method = config.attack_method - self.train_method = config.train_method - self.norm_layer = config.norm_layer - self.component = config.component self.model_name_or_path = config.model_name_or_path - self.start_guidance = config.start_guidance - self.negative_guidance = config.negative_guidance - self.iterations = config.iterations - self.save_interval = config.save_interval - self.lr = config.lr - self.config_path = config.config_path - self.ckpt_path = config.ckpt_path - self.diffusers_config_path = config.diffusers_config_path - self.output_dir = config.output_dir - self.devices = config.devices - self.seperator = config.seperator - self.image_size = config.image_size - self.ddim_steps = config.ddim_steps - self.adv_prompt_num = config.adv_prompt_num - self.attack_embd_type = config.attack_embd_type + self.cache_path = config.cache_path + self.devices = [f'cuda:{int(d.strip())}' for d in config.devices.split(',')] self.attack_type = config.attack_type - self.attack_init = config.attack_init - self.warmup_iter = config.warmup_iter + self.attack_embd_type = config.attack_embd_type self.attack_step = config.attack_step self.attack_lr = config.attack_lr - self.adv_prompt_update_step = config.adv_prompt_update_step + self.attack_init = config.attack_init + self.attack_init_embd = config.attack_init_embd + self.attack_method = config.attack_method + self.ddim_steps = config.ddim_steps self.ddim_eta = config.ddim_eta - self.cache_path = config.cache_path - - # Will be set during training. - self.words = None - self.retain_dataset = None - self.tokenizer = None - self.text_encoder = None - self.custom_text_encoder = None - self.all_embeddings = None - self.vae = None - self.model_orig = None - self.sampler_orig = None - self.model = None - self.sampler = None - self.parameters = None - self.opt = None + self.image_size = config.image_size + self.adv_prompt_num = config.adv_prompt_num + self.start_guidance = config.start_guidance + self.config_path = config.config_path + self.ckpt_path = config.ckpt_path self.criteria = torch.nn.MSELoss() - # For adversarial prompt update - self.adv_word_embd = None - self.adv_condition_embd = None - self.adv_input_ids = None + # Initialize wandb + wandb.init( + project=config.project_name, + name=config.experiment_name, + reinit=True + ) - def setup(self): - """Stage 0 & 1: Prompt cleaning and training setup.""" - # --- Prompt cleaning --- - word_print = self.prompt.replace(' ', '') - # Special cases for certain prompts - if self.prompt == 'allartist': - self.prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng" - if self.prompt == 'i2p': - self.prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood" - if self.prompt == "artifact": - self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " - "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " - "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") - - if self.seperator is not None: - self.words = [w.strip() for w in self.prompt.split(self.seperator)] - else: - self.words = [self.prompt] - print(f'The Concept Prompt to be unlearned: {self.words}') - - # Create a retaining dataset (assumed to be a prompt dataset) - self.retain_dataset = retain_prompt(self.dataset_retain) - - # --- Training Setup --- - ddim_eta = self.ddim_eta # constant value for training - - # Load the VAE - self.vae = AutoencoderKL.from_pretrained(self.model_name_or_path, subfolder="vae", cache_dir=self.cache_path).to(self.devices[0]) - # Load tokenizer and text encoder - self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path) - self.text_encoder = CLIPTextModel.from_pretrained(self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path).to(self.devices[0]) + # Load models + self.load_models() + + def load_models(self): + """Loads the tokenizer, text encoder, and models.""" + self.tokenizer = CLIPTokenizer.from_pretrained( + self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path + ) + self.text_encoder = CLIPTextModel.from_pretrained( + self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path + ).to(self.devices[0]) self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) - - # Load models using your helper function (assumed to be defined in utils) - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models(self.config_path, self.ckpt_path, self.devices) - self.model_orig.eval() - # Setup trainable parameters based on train_method - if 'text_encoder' in self.train_method: - self.parameters = param_choices(model=self.custom_text_encoder, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) + # Load base models + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models( + self.config_path, self.ckpt_path, self.devices + ) + + def attack(self): + """Performs the adversarial attack.""" + # Ensure words are in list format + if isinstance(self.prompt, str): + self.words = [self.prompt] + elif isinstance(self.prompt, list): + self.words = self.prompt else: - self.parameters = param_choices(model=self.model, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) - - self.opt = torch.optim.Adam(self.parameters, lr=self.lr) - - return word_print # For later use in saving history + raise ValueError("Prompt must be a string or a list of strings.") + + # Select a random word from the prompt list + word = random.choice(self.words) + + # Get learned condition embeddings + emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + + # Initialize attack class + sp_attack = SoftPromptAttack( + model=self.model, + model_orig=self.model_orig, + tokenizer=self.tokenizer, + text_encoder=self.custom_text_encoder, + sampler=self.sampler, + emb_0=emb_0, + emb_p=emb_p, + start_guidance=self.start_guidance, + devices=self.devices, + ddim_steps=self.ddim_steps, + ddim_eta=self.ddim_eta, + image_size=self.image_size, + criteria=self.criteria, + k=self.adv_prompt_num, + all_embeddings=self.all_embeddings + ) - def train(self): - """Stage 2: Training loop.""" - word_print = self.setup() - ddim_eta = self.ddim_eta # As used in training - - # A lambda function to sample until a given time step. - quick_sample_till_t = lambda x, s, code, batch, t: sample_model( - self.model, self.sampler, - x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, - start_code=code, n_samples=batch, till_T=t, verbose=False + + self.adv_word_embd, self.adv_input_ids = sp_attack.attack( + global_step=0, + word=word, + attack_round=0, + attack_type=self.attack_type, + attack_embd_type=self.attack_embd_type, + attack_step=self.attack_step, + attack_lr=self.attack_lr, + attack_init=self.attack_init, + attack_init_embd=self.attack_init_embd, + attack_method=self.attack_method ) - - losses = [] - history = [] - global_step = 0 - attack_round = 0 - # Create a tqdm progress bar - pbar = tqdm(range(self.iterations)) - for i in pbar: - # --- Update adversarial prompt every adv_prompt_update_step iterations --- - if i % self.adv_prompt_update_step == 0: - # Reset the retaining dataset if needed - if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: - self.retain_dataset.reset() - - # Randomly choose one prompt from the list - word = random.sample(self.words, 1)[0] - text_input = self.tokenizer( - word, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, text_input.input_ids.to(self.devices[0]), self.devices[0]) - - # Get conditional embeddings from the frozen model - emb_0 = self.model_orig.get_learned_conditioning(['']) - emb_p = self.model_orig.get_learned_conditioning([word]) - - # --- Attack Step: Get adversarial prompt --- - if i >= self.warmup_iter: - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - self.model.eval() - - if attack_round == 0: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - None, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - None, self.attack_method - ) - else: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - self.adv_word_embd, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - self.adv_condition_embd, self.attack_method - ) - global_step += self.attack_step - attack_round += 1 - # --- Set models to training/eval modes based on training method --- - if 'text_encoder' in self.train_method: - self.custom_text_encoder.text_encoder.train() - self.custom_text_encoder.text_encoder.requires_grad_(True) - self.model.eval() - # print('==== Train text_encoder ====') - else: - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - self.model.train() - self.opt.zero_grad() - - # --- Retaining prompts for retention regularization --- - if self.retain_train == 'reg': - retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) - retain_text_input = self.tokenizer( - retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) - - retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) - retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) - # Reshape to [batch, 77, embedding_dim] - retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) - retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] - else: - retain_text_input = None - retain_text_embeddings = None - # retain_emb_0 = None - retain_emb_p = None - retain_emb_n = None + return self.adv_word_embd, self.adv_input_ids - # --- Compute training loss --- - if i < self.warmup_iter: - # Warmup training uses the original prompt embeddings. - input_ids = text_input.input_ids.to(self.devices[0]) - emb_n = self.custom_text_encoder(input_ids=input_ids, inputs_embeds=text_embeddings)[0] - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, input_ids, self.attack_embd_type - ) - else: - if self.attack_embd_type == 'word_embd': - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_word_embd - ) - elif self.attack_embd_type == 'condition_embd': - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_condition_embd - ) - - # Backpropagate loss and update weights. - loss.backward() - losses.append(loss.item()) - pbar.set_postfix({"loss": loss.item()}) - history.append(loss.item()) - wandb.log({'Train_Loss': loss.item()}, step=global_step) - wandb.log({'Attack_Loss': 0.0}, step=global_step) - global_step += 1 - self.opt.step() - - # --- Additional Retention Training (if using iterative retention) --- - if self.retain_train == 'iter': - for r in range(self.retain_step): - print(f'==== Retain Training at step {r} ====') - self.opt.zero_grad() - if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: - self.retain_dataset.reset() - retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) - - t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) - og_num = round((int(t_enc) / self.ddim_steps) * 1000) - og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) - t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) - retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) - - retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) - retain_z = quick_sample_till_t(retain_emb_p.to(self.devices[0]), self.start_guidance, retain_start_code, self.retain_batch, int(t_enc)) - retain_e_p = self.model_orig.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_p.to(self.devices[0])) - - retain_text_input = self.tokenizer( - retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) - retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) - retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) - retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] - retain_e_n = self.model.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_n.to(self.devices[0])) - - retain_loss = self.criteria(retain_e_n.to(self.devices[0]), retain_e_p.to(self.devices[0])) - retain_loss.backward() - self.opt.step() - - # --- Checkpointing and saving history --- - if (i + 1) % self.save_interval == 0 and (i + 1) != self.iterations and (i + 1) >= self.save_interval: - if 'text_encoder' in self.train_method: - save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) - else: - save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, - save_diffusers=True, compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path) - if i % 1 == 0: - save_history(self.output_dir, losses, word_print) - - # --- Stage 3: Save final model and loss curve --- - self.model.eval() - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - if 'text_encoder' in self.train_method: - save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) - else: - save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, - save_diffusers=True, compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path) - save_history(self.output_dir, losses, word_print) + \ No newline at end of file diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 0b6e7d93..02e3d48a 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -2,31 +2,13 @@ import pandas as pd import random import yaml -import numpy as np -import matplotlib.pyplot as plt -from omegaconf import OmegaConf - import torch import torch.nn.functional as F from torchvision.transforms.functional import InterpolationMode -from transformers.modeling_outputs import BaseModelOutputWithPooling import torchvision.transforms as torch_transforms -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LDMTextToImagePipeline, - LMSDiscreteScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from mu.helpers import sample_model + from mu.helpers.utils import load_model_from_config from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler @@ -66,13 +48,13 @@ def retain_prompt(dataset_retain): # Prompt Dataset to be retained if dataset_retain == 'imagenet243': - retain_dataset = PromptDataset('./data/prompts/train/imagenet243_retain.csv') + retain_dataset = PromptDataset('data/prompts/train/imagenet243_retain.csv') elif dataset_retain == 'imagenet243_no_filter': - retain_dataset = PromptDataset('./data/prompts/train/imagenet243_no_filter_retain.csv') + retain_dataset = PromptDataset('data/prompts/train/imagenet243_no_filter_retain.csv') elif dataset_retain == 'coco_object': - retain_dataset = PromptDataset('./data/prompts/train/coco_object_retain.csv') + retain_dataset = PromptDataset('data/prompts/train/coco_object_retain.csv') elif dataset_retain == 'coco_object_no_filter': - retain_dataset = PromptDataset('./data/prompts/train/coco_object_no_filter_retain.csv') + retain_dataset = PromptDataset('data/prompts/train/coco_object_no_filter_retain.csv') else: raise ValueError('Invalid dataset for retaining prompts') @@ -252,431 +234,6 @@ def construct_id(k, adv_id, insertion_location,sot_id,eot_id,mid_id): return input_ids -def param_choices(model, train_method, component='all', final_layer_norm=False): - # choose parameters to train based on train_method - parameters = [] - - # Text Encoder FUll Weight Tuning - if train_method == 'text_encoder_full': - for name, param in model.text_encoder.text_model.named_parameters(): - # Final Layer Norm - if name.startswith('final_layer_norm'): - if component == 'all' or final_layer_norm==True: - print(name) - parameters.append(param) - else: - pass - - # Transformer layers - elif name.startswith('encoder'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - # Embedding layers - else: - pass - - # Text Encoder Layer 0 Tuning - elif train_method == 'text_encoder_layer0': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer01': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer012': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer0123': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer01234': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer012345': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer0123456': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer01234567': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer012345678': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer0123456789': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer012345678910': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer01234567891011': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer0_11': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.11'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - - elif train_method == 'text_encoder_layer01_1011': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - elif train_method == 'text_encoder_layer012_91011': - for name, param in model.text_encoder.text_model.named_parameters(): - # Encoder Layer 0 - if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): - if component == 'ffn' and 'mlp' in name: - print(name) - parameters.append(param) - elif component == 'attn' and 'self_attn' in name: - print(name) - parameters.append(param) - elif component == 'all': - print(name) - parameters.append(param) - else: - pass - - elif name.startswith('final_layer_norm') and final_layer_norm==True: - print(name) - parameters.append(param) - - else: - pass - - # UNet Model Tuning - else: - for name, param in model.model.diffusion_model.named_parameters(): - # train all layers except x-attns and time_embed layers - if train_method == 'noxattn': - if name.startswith('out.') or 'attn2' in name or 'time_embed' in name: - pass - else: - print(name) - parameters.append(param) - - # train only self attention layers - if train_method == 'selfattn': - if 'attn1' in name: - print(name) - parameters.append(param) - - # train only x attention layers - if train_method == 'xattn': - if 'attn2' in name: - print(name) - parameters.append(param) - - # train all layers - if train_method == 'full': - print(name) - parameters.append(param) - - # train all layers except time embed layers - if train_method == 'notime': - if not (name.startswith('out.') or 'time_embed' in name): - print(name) - parameters.append(param) - if train_method == 'xlayer': - if 'attn2' in name: - if 'output_blocks.6.' in name or 'output_blocks.8.' in name: - print(name) - parameters.append(param) - if train_method == 'selflayer': - if 'attn1' in name: - if 'input_blocks.4.' in name or 'input_blocks.7.' in name: - print(name) - parameters.append(param) - - return parameters - def get_models(config_path, ckpt_path, devices): model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) @@ -687,595 +244,31 @@ def get_models(config_path, ckpt_path, devices): return model_orig, sampler_orig, model, sampler -def get_train_loss_retain( retain_batch, retain_train, retain_loss_w, model, model_orig, text_encoder, sampler, emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, start_guidance, negative_guidance, devices, ddim_steps, ddim_eta, image_size, criteria, adv_input_ids, attack_embd_type, adv_embd=None): - """_summary_ - - Args: - model: ESD model - model_orig: frozen DDPM model - sampler: DDIMSampler for DDPM model - - emb_0: unconditional embedding - emb_p: conditional embedding (for ground truth concept) - emb_n: conditional embedding (for modified concept) - - start_guidance: unconditional guidance for ESD model - negative_guidance: negative guidance for ESD model - - devices: list of devices for ESD and DDPM models - ddim_steps: number of steps for DDIMSampler - ddim_eta: eta for DDIMSampler - image_size: image size for DDIMSampler - - criteria: loss function for ESD model - - adv_input_ids: input_ids for adversarial word embedding - adv_emb_n: adversarial conditional embedding - adv_word_emb_n: adversarial word embedding - - Returns: - loss: training loss for ESD model - """ - quick_sample_till_t = lambda x, s, code, batch, t: sample_model(model, sampler, - x, image_size, image_size, ddim_steps, s, ddim_eta, - start_code=code, n_samples=batch, till_T=t, verbose=False) - - - t_enc = torch.randint(ddim_steps, (1,), device=devices[0]) - # time step from 1000 to 0 (0 being good) - og_num = round((int(t_enc)/ddim_steps)*1000) - og_num_lim = round((int(t_enc+1)/ddim_steps)*1000) - - t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0]) - - start_code = torch.randn((1, 4, 64, 64)).to(devices[0]) - if retain_train == 'reg': - retain_start_code = torch.randn((retain_batch, 4, 64, 64)).to(devices[0]) - - with torch.no_grad(): - # generate an image with the concept from ESD model - z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, 1, int(t_enc)) # emb_p seems to work better instead of emb_0 - # get conditional and unconditional scores from frozen model at time step t and image z - e_0 = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_0.to(devices[0])) - e_p = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_p.to(devices[0])) - - if retain_train == 'reg': - retain_z = quick_sample_till_t(retain_emb_p.to(devices[0]), start_guidance, retain_start_code, retain_batch, int(t_enc)) # emb_p seems to work better instead of emb_0 - # retain_e_0 = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_0.to(devices[0])) - retain_e_p = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_p.to(devices[0])) - - if adv_embd is None: - e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0])) - else: - if attack_embd_type == 'condition_embd': - # Train with adversarial conditional embedding - e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_embd.to(devices[0])) - elif attack_embd_type == 'word_embd': - # Train with adversarial word embedding - print('====== Training with adversarial word embedding =====') - adv_emb_n = text_encoder(input_ids = adv_input_ids.to(devices[0]), inputs_embeds=adv_embd.to(devices[0]))[0] - e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_emb_n.to(devices[0])) - else: - raise ValueError('attack_embd_type must be either condition_embd or word_embd') - - e_0.requires_grad = False - e_p.requires_grad = False - - # reconstruction loss for ESD objective from frozen model and conditional score of ESD model - # loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) - - # return loss - - if retain_train == 'reg': - # reconstruction loss for ESD objective from frozen model and conditional score of ESD model - print('====== Training with retain batch =====') - unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) - - retain_e_n = model.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_n.to(devices[0])) - - # retain_e_0.requires_grad = False - retain_e_p.requires_grad = False - retain_loss = criteria(retain_e_n.to(devices[0]), retain_e_p.to(devices[0])) - - loss = unlearn_loss + retain_loss_w * retain_loss - return loss - - else: - # reconstruction loss for ESD objective from frozen model and conditional score of ESD model - unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) - return unlearn_loss - -def save_text_encoder(folder_path, model, name, num): - # SAVE MODEL - - # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' - folder_path = f'{folder_path}/models' - os.makedirs(folder_path, exist_ok=True) - if num is not None: - path = f'{folder_path}/TextEncoder-{name}-epoch_{num}.pt' - else: - path = f'{folder_path}/TextEncoder-{name}.pt' - - torch.save(model.state_dict(), path) - - - -def create_unet_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - unet_params = original_config.model.params.unet_config.params - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim = [5, 10, 20, 20] - - config = dict( - sample_size=image_size // vae_scale_factor, - in_channels=unet_params.in_channels, - out_channels=unet_params.out_channels, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=unet_params.num_res_blocks, - cross_attention_dim=unet_params.context_dim, - attention_head_dim=head_dim, - use_linear_projection=use_linear_projection, - ) - - return config - - -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - unet_key = "model.diffusion_model." - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - return new_checkpoint - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device='cpu'): - checkpoint_path = path - - original_config_file = compvis_config_file - config_file = diffusers_config_file - num_in_channels = 4 - scheduler_type = 'ddim' - pipeline_type = None - image_size = 512 - prediction_type = 'epsilon' - extract_ema = False - dump_path = path.replace('Compvis','Diffusers') - upcast_attention = False - - - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - print("global_step key not found in model") - global_step = None - - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - upcast_attention = upcast_attention - if original_config_file is None: - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - if not os.path.isfile("v2-inference-v.yaml"): - # model_type = "v2" - os.system( - "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - " -O v2-inference-v.yaml" - ) - original_config_file = "./v2-inference-v.yaml" - - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - else: - if not os.path.isfile("v1-inference.yaml"): - # model_type = "v1" - os.system( - "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - " -O v1-inference.yaml" - ) - original_config_file = "./v1-inference.yaml" - - original_config = OmegaConf.load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = False - unet = UNet2DConditionModel(**unet_config) - - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema - ) - torch.save(converted_unet_checkpoint, dump_path) - - - -def save_model(folder_path, model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True): - # SAVE MODEL - - # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' - folder_path = f'{folder_path}/models' - os.makedirs(folder_path, exist_ok=True) - if num is not None: - path = f'{folder_path}/Compvis-UNet-{name}-epoch_{num}.pt' - else: - path = f'{folder_path}/Compvis-UNet-{name}.pt' - if save_compvis: - torch.save(model.state_dict(), path) - - if save_diffusers: - print('Saving Model in Diffusers Format') - savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device=device ) - - -def moving_average(a, n=3) : - ret = np.cumsum(a, dtype=float) - ret[n:] = ret[n:] - ret[:-n] - return ret[n - 1:] / n - -def plot_loss(losses, path,word, n=100): - v = moving_average(losses, n) - plt.plot(v, label=f'{word}_loss') - plt.legend(loc="upper left") - plt.title('Average loss in trainings', fontsize=20) - plt.xlabel('Data point', fontsize=16) - plt.ylabel('Loss value', fontsize=16) - plt.savefig(path) +@torch.no_grad() +def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True): + """Sample the model""" + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(n_samples * [""]) + log_t = 100 + if log_every_t is not None: + log_t = log_every_t + shape = [4, h // 8, w // 8] + samples_ddim, inters = sampler.sample(S=ddim_steps, + conditioning=c, + batch_size=n_samples, + shape=shape, + verbose=False, + x_T=start_code, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + verbose_iter = verbose, + t_start=t_start, + log_every_t = log_t, + till_T = till_T + ) + if log_every_t is not None: + return samples_ddim, inters + return samples_ddim -def save_history(folder_path, losses, word_print): - folder_path = f'{folder_path}/logs' - os.makedirs(folder_path, exist_ok=True) - with open(f'{folder_path}/loss.txt', 'w') as f: - f.writelines([str(i) for i in losses]) - plot_loss(losses,f'{folder_path}/loss.png' , word_print, n=3) \ No newline at end of file diff --git a/mu_attack/src/clip b/mu_attack/src/clip new file mode 160000 index 00000000..dcba3cb2 --- /dev/null +++ b/mu_attack/src/clip @@ -0,0 +1 @@ +Subproject commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 diff --git a/mu_attack/src/taming-transformers b/mu_attack/src/taming-transformers new file mode 160000 index 00000000..3ba01b24 --- /dev/null +++ b/mu_attack/src/taming-transformers @@ -0,0 +1 @@ +Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318 From 274eea35c61805ec15301737632a2cdb1ebf1b57 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 05:54:32 +0000 Subject: [PATCH 05/22] refactored for compvis --- .../configs/adv_unlearn/adv_unlearn_config.py | 5 ++ mu_attack/execs/adv_attack.py | 47 ++++++++++++--- mu_attack/helpers/utils.py | 57 ++++++++++++++++++- 3 files changed, 100 insertions(+), 9 deletions(-) diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py index f683af32..7ee7599a 100644 --- a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py +++ b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py @@ -10,6 +10,8 @@ def __init__(self, **kwargs): self.config_path = current_dir / "model_config.yaml" self.ckpt_path = "models/sd-v1-4-full-ema.ckpt" self.model_name_or_path = "CompVis/stable-diffusion-v1-4" + self.target_ckpt = None + self.diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50" # Devices & IO self.devices = "0,0" # You can later parse this string into a list if needed. @@ -36,6 +38,9 @@ def __init__(self, **kwargs): self.attack_step = 30 self.attack_lr = 1e-3 + #backend + self.backend = "diffusers" + #wandb configs self.project_name = "quick-canvas-machine-unlearning" self.experiment_name = f'AdvUnlearn-{self.prompt}-method_Attack_{self.attack_method}' diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index 50459d27..a7d0ba87 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -5,14 +5,16 @@ import wandb from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import StableDiffusionPipeline +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler from mu_attack.configs.adv_unlearn import AdvUnlearnConfig from mu_attack.attackers.soft_prompt import SoftPromptAttack from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -from mu_attack.helpers.utils import get_models +from mu_attack.helpers.utils import get_models_for_compvis, get_models_for_diffusers -class AdvUnlearn: +class AdvAttack: """ Class for adversarial unlearning training. @@ -44,6 +46,9 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): self.start_guidance = config.start_guidance self.config_path = config.config_path self.ckpt_path = config.ckpt_path + self.backend = config.backend + self.diffusers_model_name_or_path = config.diffusers_model_name_or_path + self.target_ckpt = config.target_ckpt self.criteria = torch.nn.MSELoss() # Initialize wandb @@ -56,6 +61,21 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): # Load models self.load_models() + def encode_text(self, text): + """Encodes text into a latent space using CLIP from Diffusers.""" + text_inputs = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt" + ).to(self.devices[0]) # Move to correct device + + with torch.no_grad(): + text_embeddings = self.text_encoder(text_inputs.input_ids)[0] # Take the first output (hidden states) + + return text_embeddings + def load_models(self): """Loads the tokenizer, text encoder, and models.""" self.tokenizer = CLIPTokenizer.from_pretrained( @@ -68,9 +88,15 @@ def load_models(self): self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) # Load base models - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models( - self.config_path, self.ckpt_path, self.devices - ) + if self.backend == "compvis": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( + self.config_path, self.ckpt_path, self.devices + ) + elif self.backend == "diffusers": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( + self.diffusers_model_name_or_path, self.target_ckpt, self.devices + ) + def attack(self): """Performs the adversarial attack.""" @@ -85,9 +111,14 @@ def attack(self): # Select a random word from the prompt list word = random.choice(self.words) - # Get learned condition embeddings - emb_0 = self.model_orig.get_learned_conditioning(['']) - emb_p = self.model_orig.get_learned_conditioning([word]) + if self.backend == "compvis": + # CompVis uses `get_learned_conditioning` + emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + elif self.backend == "diffusers": + # Diffusers requires explicit encoding via CLIP + emb_0 = self.encode_text("") + emb_p = self.encode_text(word) # Initialize attack class sp_attack = SoftPromptAttack( diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 02e3d48a..054570de 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -8,6 +8,8 @@ from torchvision.transforms.functional import InterpolationMode import torchvision.transforms as torch_transforms +from diffusers import UNet2DConditionModel, DDIMScheduler + from mu.helpers.utils import load_model_from_config from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler @@ -235,7 +237,7 @@ def construct_id(k, adv_id, insertion_location,sot_id,eot_id,mid_id): -def get_models(config_path, ckpt_path, devices): +def get_models_for_compvis(config_path, ckpt_path, devices): model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) sampler_orig = DDIMSampler(model_orig) @@ -244,6 +246,59 @@ def get_models(config_path, ckpt_path, devices): return model_orig, sampler_orig, model, sampler +def get_models_for_diffusers(model_name_or_path, target_ckpt, devices, cache_path=None): + """ + Loads two copies of a Diffusers UNet model along with their DDIM schedulers. + + Args: + model_name_or_path (str): The Hugging Face model identifier or local path. + target_ckpt (str or None): Path to a target checkpoint to load into the primary model (on devices[0]). + If None, no state dict is loaded. + devices (list or tuple): A list/tuple of two devices, e.g. [device0, device1]. + cache_path (str or None): Optional cache directory for pretrained weights. + + Returns: + model_orig: The UNet loaded on devices[1]. + sampler_orig: The DDIM scheduler corresponding to model_orig. + model: The UNet loaded on devices[0] (optionally updated with target_ckpt). + sampler: The DDIM scheduler corresponding to model. + """ + + # Load the original model (used for e.g. computing loss, etc.) on devices[1] + model_orig = UNet2DConditionModel.from_pretrained( + model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[1]) + + # Create a DDIM scheduler for model_orig. (Note: diffusers DDIMScheduler is used here; + # adjust the subfolder or configuration if your scheduler is stored elsewhere.) + sampler_orig = DDIMScheduler.from_pretrained( + model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + # Load the second copy of the model on devices[0] + model = UNet2DConditionModel.from_pretrained( + model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[0]) + + # Optionally load a target checkpoint into model + if target_ckpt is not None: + state_dict = torch.load(target_ckpt, map_location=devices[0]) + model.load_state_dict(state_dict) + + sampler = DDIMScheduler.from_pretrained( + model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + return model_orig, sampler_orig, model, sampler + @torch.no_grad() def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True): """Sample the model""" From e8b416325a97d1b76caaa56386b2bb942f001724 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 08:26:51 +0000 Subject: [PATCH 06/22] advattack code compatible for diffusers model --- mu_attack/attackers/soft_prompt.py | 266 ++++++++++++++++++++++++++--- mu_attack/execs/adv_attack.py | 3 +- mu_attack/helpers/utils.py | 64 +++++++ 3 files changed, 313 insertions(+), 20 deletions(-) diff --git a/mu_attack/attackers/soft_prompt.py b/mu_attack/attackers/soft_prompt.py index eb16ad9a..ed7d4bf5 100644 --- a/mu_attack/attackers/soft_prompt.py +++ b/mu_attack/attackers/soft_prompt.py @@ -1,10 +1,202 @@ +# # mu_attack/attackers/soft_prompt.py + +# import torch +# import wandb + +# from mu.helpers import sample_model +# from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id + + +# class SoftPromptAttack: +# """ +# A class to perform a soft prompt attack on the ESD model. + +# Attributes: +# model: The ESD model. +# model_orig: The frozen (original) model. +# tokenizer: The tokenizer. +# text_encoder: The text encoder. +# sampler: The sampler. +# emb_0: Unconditional embedding. +# emb_p: Conditional embedding. +# start_guidance: Guidance scale for sampling. +# devices: List of devices to use. +# ddim_steps: Number of DDIM steps. +# ddim_eta: The eta parameter for DDIM. +# image_size: The size (width and height) for generated images. +# criteria: The loss criteria function. +# k: Number of tokens (or a related parameter for the prompt). +# all_embeddings: The preloaded word embeddings. +# """ + +# def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, +# emb_0, emb_p, start_guidance, devices, ddim_steps, ddim_eta, +# image_size, criteria, k, all_embeddings): +# self.model = model +# self.model_orig = model_orig +# self.tokenizer = tokenizer +# self.text_encoder = text_encoder +# self.sampler = sampler +# self.emb_0 = emb_0 +# self.emb_p = emb_p +# self.start_guidance = start_guidance +# self.devices = devices +# self.ddim_steps = ddim_steps +# self.ddim_eta = ddim_eta +# self.image_size = image_size +# self.criteria = criteria +# self.k = k +# self.all_embeddings = all_embeddings + +# def attack(self, global_step, word, attack_round, attack_type, +# attack_embd_type, attack_step, attack_lr, +# attack_init=None, attack_init_embd=None, attack_method='pgd'): +# """ +# Perform soft prompt attack on the ESD model. + +# Args: +# global_step (int): The current global training step. +# word (str): The input prompt. +# attack_round (int): The current attack round. +# attack_type (str): Type of attack ("add" or "insert"). +# attack_embd_type (str): Type of adversarial embedding ("condition_embd" or "word_embd"). +# attack_step (int): Number of steps to run the attack. +# attack_lr (float): Learning rate for the adversarial optimization. +# attack_init (str, optional): Initialization method ("latest" or "random"). +# attack_init_embd (torch.Tensor, optional): Initial adversarial embedding. +# attack_method (str, optional): Attack method to use ("pgd" or "fast_at"). + +# Returns: +# tuple: Depending on attack_embd_type, returns a tuple (embedding, input_ids) +# where the embedding is either a conditional or word embedding. +# """ +# orig_prompt_len = len(word.split()) +# if attack_type == 'add': +# # When using "add", update k to match the prompt length. +# self.k = orig_prompt_len + +# # A helper lambda to sample an image until a given time step. +# quick_sample_till_t = lambda x, s, code, t: sample_model( +# self.model, self.sampler, x, self.image_size, self.image_size, +# self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False +# ) + +# # --- Tokenization and Embedding --- +# text_input = self.tokenizer( +# word, padding="max_length", max_length=self.tokenizer.model_max_length, +# return_tensors="pt", truncation=True +# ) +# sot_id, mid_id, replace_id, eot_id = split_id( +# text_input.input_ids.to(self.devices[0]), self.k, orig_prompt_len +# ) + +# text_embeddings = id2embedding( +# self.tokenizer, self.all_embeddings, +# text_input.input_ids.to(self.devices[0]), self.devices[0] +# ) +# sot_embd, mid_embd, _, eot_embd = split_embd(text_embeddings, self.k, orig_prompt_len) + +# # --- Initialize the adversarial embedding --- +# if attack_init == 'latest': +# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, +# attack_type, self.devices[0], 1, attack_init_embd) +# elif attack_init == 'random': +# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, +# attack_type, self.devices[0], 1) +# else: +# # Default initialization if no method is provided +# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, +# attack_type, self.devices[0], 1) + +# attack_opt = torch.optim.Adam([adv_embedding], lr=attack_lr) + +# # For the condition_embd attack type, construct the initial adversarial condition embedding. +# if attack_embd_type == 'condition_embd': +# input_adv_condition_embedding = construct_embd( +# self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd +# ) +# adv_input_ids = construct_id( +# self.k, replace_id, attack_type, sot_id, eot_id, mid_id +# ) + +# print(f'[{attack_type}] Starting {attack_method} attack on "{word}"') + +# # --- Attack Loop --- +# for i in range(attack_step): +# # Randomly sample a time step for the attack. +# t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) +# og_num = round((int(t_enc) / self.ddim_steps) * 1000) +# og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) +# t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) +# start_code = torch.randn((1, 4, 64, 64)).to(self.devices[0]) + +# with torch.no_grad(): +# # Generate an image with the concept from the frozen model. +# z = quick_sample_till_t( +# self.emb_p.to(self.devices[0]), self.start_guidance, start_code, int(t_enc) +# ) +# e_0 = self.model_orig.apply_model( +# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) +# ) +# e_p = self.model_orig.apply_model( +# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) +# ) + +# # For word_embd attack type, update the adversarial condition embedding using the text encoder. +# if attack_embd_type == 'word_embd': +# input_adv_word_embedding = construct_embd( +# self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd +# ) +# adv_input_ids = construct_id( +# self.k, replace_id, attack_type, sot_id, eot_id, mid_id +# ) +# input_adv_condition_embedding = self.text_encoder( +# input_ids=adv_input_ids.to(self.devices[0]), +# inputs_embeds=input_adv_word_embedding +# )[0] + +# # Get the conditional score from the ESD model with the adversarial condition embedding. +# e_n = self.model.apply_model( +# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), +# input_adv_condition_embedding.to(self.devices[0]) +# ) +# e_0.requires_grad = False +# e_p.requires_grad = False + +# # Compute the loss between the adversarial output and the target. +# loss = self.criteria(e_n.to(self.devices[0]), e_p.to(self.devices[0])) +# loss.backward() + +# if attack_method == 'pgd': +# attack_opt.step() +# elif attack_method == 'fast_at': +# adv_embedding.grad.sign_() +# attack_opt.step() +# else: +# raise ValueError('attack_method must be either pgd or fast_at') + +# wandb.log({'Attack_Loss': loss.item()}, step=global_step + i) +# wandb.log({'Train_Loss': 0.0}, step=global_step + i) +# print(f'Step: {global_step + i}, Attack_Loss: {loss.item()}') +# print(f'Step: {global_step + i}, Train_Loss: 0.0') + + + +# # --- Return the adversarial embeddings and input IDs --- +# if attack_embd_type == 'condition_embd': +# return input_adv_condition_embedding, adv_input_ids +# elif attack_embd_type == 'word_embd': +# return input_adv_word_embedding, adv_input_ids +# else: +# raise ValueError('attack_embd_type must be either condition_embd or word_embd') + + # mu_attack/attackers/soft_prompt.py import torch import wandb -from mu.helpers import sample_model -from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id +from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id, sample_model, sample_model_for_diffuser class SoftPromptAttack: @@ -16,7 +208,7 @@ class SoftPromptAttack: model_orig: The frozen (original) model. tokenizer: The tokenizer. text_encoder: The text encoder. - sampler: The sampler. + sampler: The sampler (or scheduler) used for diffusion. emb_0: Unconditional embedding. emb_p: Conditional embedding. start_guidance: Guidance scale for sampling. @@ -27,11 +219,12 @@ class SoftPromptAttack: criteria: The loss criteria function. k: Number of tokens (or a related parameter for the prompt). all_embeddings: The preloaded word embeddings. + backend: String indicating which backend is used ("compvis" or "diffusers"). """ def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, emb_0, emb_p, start_guidance, devices, ddim_steps, ddim_eta, - image_size, criteria, k, all_embeddings): + image_size, criteria, k, all_embeddings, backend="compvis"): self.model = model self.model_orig = model_orig self.tokenizer = tokenizer @@ -47,6 +240,7 @@ def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, self.criteria = criteria self.k = k self.all_embeddings = all_embeddings + self.backend = backend def attack(self, global_step, word, attack_round, attack_type, attack_embd_type, attack_step, attack_lr, @@ -76,7 +270,13 @@ def attack(self, global_step, word, attack_round, attack_type, self.k = orig_prompt_len # A helper lambda to sample an image until a given time step. - quick_sample_till_t = lambda x, s, code, t: sample_model( + if self.backend == "compvis": + quick_sample_till_t = lambda x, s, code, t: sample_model( + self.model, self.sampler, x, self.image_size, self.image_size, + self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False + ) + elif self.backend == "diffusers": + quick_sample_till_t = lambda x, s, code, t: sample_model_for_diffuser( self.model, self.sampler, x, self.image_size, self.image_size, self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False ) @@ -131,16 +331,34 @@ def attack(self, global_step, word, attack_round, attack_type, start_code = torch.randn((1, 4, 64, 64)).to(self.devices[0]) with torch.no_grad(): - # Generate an image with the concept from the frozen model. + # Sample a latent z using the conditional embedding. z = quick_sample_till_t( self.emb_p.to(self.devices[0]), self.start_guidance, start_code, int(t_enc) ) - e_0 = self.model_orig.apply_model( - z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) - ) - e_p = self.model_orig.apply_model( - z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) - ) + if self.backend == "compvis": + # For compvis, use apply_model to get the noise predictions. + e_0 = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) + ) + e_p = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) + ) + elif self.backend == "diffusers": + # For diffusers, call the UNet directly with encoder_hidden_states. + out_0 = self.model_orig( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=self.emb_0.to(self.devices[0]) + ) + e_0 = out_0.sample if hasattr(out_0, "sample") else out_0 + out_p = self.model_orig( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=self.emb_p.to(self.devices[0]) + ) + e_p = out_p.sample if hasattr(out_p, "sample") else out_p + else: + raise ValueError(f"Unknown backend: {self.backend}") # For word_embd attack type, update the adversarial condition embedding using the text encoder. if attack_embd_type == 'word_embd': @@ -155,11 +373,23 @@ def attack(self, global_step, word, attack_round, attack_type, inputs_embeds=input_adv_word_embedding )[0] - # Get the conditional score from the ESD model with the adversarial condition embedding. - e_n = self.model.apply_model( - z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), - input_adv_condition_embedding.to(self.devices[0]) - ) + # Get the conditional score from the model with the adversarial condition embedding. + if self.backend == "compvis": + e_n = self.model.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), + input_adv_condition_embedding.to(self.devices[0]) + ) + elif self.backend == "diffusers": + out_n = self.model( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=input_adv_condition_embedding.to(self.devices[0]) + ) + e_n = out_n.sample if hasattr(out_n, "sample") else out_n + else: + raise ValueError(f"Unknown backend: {self.backend}") + + # Prevent gradients on the frozen branch. e_0.requires_grad = False e_p.requires_grad = False @@ -179,8 +409,6 @@ def attack(self, global_step, word, attack_round, attack_type, wandb.log({'Train_Loss': 0.0}, step=global_step + i) print(f'Step: {global_step + i}, Attack_Loss: {loss.item()}') print(f'Step: {global_step + i}, Train_Loss: 0.0') - - # --- Return the adversarial embeddings and input IDs --- if attack_embd_type == 'condition_embd': diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index a7d0ba87..e6f1cb38 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -136,7 +136,8 @@ def attack(self): image_size=self.image_size, criteria=self.criteria, k=self.adv_prompt_num, - all_embeddings=self.all_embeddings + all_embeddings=self.all_embeddings, + backend = self.backend ) diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 054570de..861cb85c 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -327,3 +327,67 @@ def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_cod return samples_ddim, inters return samples_ddim +@torch.no_grad() +def sample_model_for_diffuser(model, scheduler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, + n_samples=1, t_start=-1, log_every_t=None, till_T=None, verbose=True): + """ + Diffusers-compatible sampling function. + + Args: + model: The UNet model (from diffusers). + scheduler: A DDIMScheduler (or similar) instance. + c (torch.Tensor): The conditional encoder_hidden_states. + h (int): Image height. + w (int): Image width. + ddim_steps (int): Number of diffusion steps. + scale (float): Guidance scale. If not 1.0, classifier-free guidance is applied. + ddim_eta (float): The eta parameter for DDIM (unused in this basic implementation). + start_code (torch.Tensor, optional): Starting latent code. If None, random noise is used. + n_samples (int): Number of samples to generate. + t_start, log_every_t, till_T, verbose: Additional parameters (not used in this diffusers implementation). + + Returns: + torch.Tensor: The generated latent sample. + """ + device = c.device + + # If no starting code is provided, sample random noise. + if start_code is None: + start_code = torch.randn((n_samples, 4, h // 8, w // 8), device=device) + latents = start_code + + # Set the number of timesteps in the scheduler. + scheduler.set_timesteps(ddim_steps) + + # If using classifier-free guidance, prepare unconditional embeddings. + if scale != 1.0: + # In a full implementation you would obtain these from your text encoder + # For this example, we simply create a tensor of zeros with the same shape as c. + uc = torch.zeros_like(c) + # Duplicate latents and conditioning for guidance. + latents = torch.cat([latents, latents], dim=0) + c_in = torch.cat([uc, c], dim=0) + else: + c_in = c + + # Diffusion sampling loop. + for t in scheduler.timesteps: + # Scale the latents as required by the scheduler. + latent_model_input = scheduler.scale_model_input(latents, t) + model_output = model(latent_model_input, t, encoder_hidden_states=c_in) + # Assume model_output is a ModelOutput with a 'sample' attribute. + if scale != 1.0: + # Split the batch into unconditional and conditional parts. + noise_pred_uncond, noise_pred_text = model_output.sample.chunk(2) + # Apply classifier-free guidance. + noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = model_output.sample + + # Step the scheduler. + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # If guidance was used, return only the second half of the batch. + if scale != 1.0: + latents = latents[n_samples:] + return latents \ No newline at end of file From 94e567f0a434700fb54a68ea924db06c9d361e28 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 08:28:04 +0000 Subject: [PATCH 07/22] soft prompt attack for diffusers --- mu_attack/attackers/soft_prompt.py | 193 ----------------------------- 1 file changed, 193 deletions(-) diff --git a/mu_attack/attackers/soft_prompt.py b/mu_attack/attackers/soft_prompt.py index ed7d4bf5..aa8590c1 100644 --- a/mu_attack/attackers/soft_prompt.py +++ b/mu_attack/attackers/soft_prompt.py @@ -1,196 +1,3 @@ -# # mu_attack/attackers/soft_prompt.py - -# import torch -# import wandb - -# from mu.helpers import sample_model -# from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id - - -# class SoftPromptAttack: -# """ -# A class to perform a soft prompt attack on the ESD model. - -# Attributes: -# model: The ESD model. -# model_orig: The frozen (original) model. -# tokenizer: The tokenizer. -# text_encoder: The text encoder. -# sampler: The sampler. -# emb_0: Unconditional embedding. -# emb_p: Conditional embedding. -# start_guidance: Guidance scale for sampling. -# devices: List of devices to use. -# ddim_steps: Number of DDIM steps. -# ddim_eta: The eta parameter for DDIM. -# image_size: The size (width and height) for generated images. -# criteria: The loss criteria function. -# k: Number of tokens (or a related parameter for the prompt). -# all_embeddings: The preloaded word embeddings. -# """ - -# def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, -# emb_0, emb_p, start_guidance, devices, ddim_steps, ddim_eta, -# image_size, criteria, k, all_embeddings): -# self.model = model -# self.model_orig = model_orig -# self.tokenizer = tokenizer -# self.text_encoder = text_encoder -# self.sampler = sampler -# self.emb_0 = emb_0 -# self.emb_p = emb_p -# self.start_guidance = start_guidance -# self.devices = devices -# self.ddim_steps = ddim_steps -# self.ddim_eta = ddim_eta -# self.image_size = image_size -# self.criteria = criteria -# self.k = k -# self.all_embeddings = all_embeddings - -# def attack(self, global_step, word, attack_round, attack_type, -# attack_embd_type, attack_step, attack_lr, -# attack_init=None, attack_init_embd=None, attack_method='pgd'): -# """ -# Perform soft prompt attack on the ESD model. - -# Args: -# global_step (int): The current global training step. -# word (str): The input prompt. -# attack_round (int): The current attack round. -# attack_type (str): Type of attack ("add" or "insert"). -# attack_embd_type (str): Type of adversarial embedding ("condition_embd" or "word_embd"). -# attack_step (int): Number of steps to run the attack. -# attack_lr (float): Learning rate for the adversarial optimization. -# attack_init (str, optional): Initialization method ("latest" or "random"). -# attack_init_embd (torch.Tensor, optional): Initial adversarial embedding. -# attack_method (str, optional): Attack method to use ("pgd" or "fast_at"). - -# Returns: -# tuple: Depending on attack_embd_type, returns a tuple (embedding, input_ids) -# where the embedding is either a conditional or word embedding. -# """ -# orig_prompt_len = len(word.split()) -# if attack_type == 'add': -# # When using "add", update k to match the prompt length. -# self.k = orig_prompt_len - -# # A helper lambda to sample an image until a given time step. -# quick_sample_till_t = lambda x, s, code, t: sample_model( -# self.model, self.sampler, x, self.image_size, self.image_size, -# self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False -# ) - -# # --- Tokenization and Embedding --- -# text_input = self.tokenizer( -# word, padding="max_length", max_length=self.tokenizer.model_max_length, -# return_tensors="pt", truncation=True -# ) -# sot_id, mid_id, replace_id, eot_id = split_id( -# text_input.input_ids.to(self.devices[0]), self.k, orig_prompt_len -# ) - -# text_embeddings = id2embedding( -# self.tokenizer, self.all_embeddings, -# text_input.input_ids.to(self.devices[0]), self.devices[0] -# ) -# sot_embd, mid_embd, _, eot_embd = split_embd(text_embeddings, self.k, orig_prompt_len) - -# # --- Initialize the adversarial embedding --- -# if attack_init == 'latest': -# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, -# attack_type, self.devices[0], 1, attack_init_embd) -# elif attack_init == 'random': -# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, -# attack_type, self.devices[0], 1) -# else: -# # Default initialization if no method is provided -# adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, -# attack_type, self.devices[0], 1) - -# attack_opt = torch.optim.Adam([adv_embedding], lr=attack_lr) - -# # For the condition_embd attack type, construct the initial adversarial condition embedding. -# if attack_embd_type == 'condition_embd': -# input_adv_condition_embedding = construct_embd( -# self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd -# ) -# adv_input_ids = construct_id( -# self.k, replace_id, attack_type, sot_id, eot_id, mid_id -# ) - -# print(f'[{attack_type}] Starting {attack_method} attack on "{word}"') - -# # --- Attack Loop --- -# for i in range(attack_step): -# # Randomly sample a time step for the attack. -# t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) -# og_num = round((int(t_enc) / self.ddim_steps) * 1000) -# og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) -# t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) -# start_code = torch.randn((1, 4, 64, 64)).to(self.devices[0]) - -# with torch.no_grad(): -# # Generate an image with the concept from the frozen model. -# z = quick_sample_till_t( -# self.emb_p.to(self.devices[0]), self.start_guidance, start_code, int(t_enc) -# ) -# e_0 = self.model_orig.apply_model( -# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) -# ) -# e_p = self.model_orig.apply_model( -# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) -# ) - -# # For word_embd attack type, update the adversarial condition embedding using the text encoder. -# if attack_embd_type == 'word_embd': -# input_adv_word_embedding = construct_embd( -# self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd -# ) -# adv_input_ids = construct_id( -# self.k, replace_id, attack_type, sot_id, eot_id, mid_id -# ) -# input_adv_condition_embedding = self.text_encoder( -# input_ids=adv_input_ids.to(self.devices[0]), -# inputs_embeds=input_adv_word_embedding -# )[0] - -# # Get the conditional score from the ESD model with the adversarial condition embedding. -# e_n = self.model.apply_model( -# z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), -# input_adv_condition_embedding.to(self.devices[0]) -# ) -# e_0.requires_grad = False -# e_p.requires_grad = False - -# # Compute the loss between the adversarial output and the target. -# loss = self.criteria(e_n.to(self.devices[0]), e_p.to(self.devices[0])) -# loss.backward() - -# if attack_method == 'pgd': -# attack_opt.step() -# elif attack_method == 'fast_at': -# adv_embedding.grad.sign_() -# attack_opt.step() -# else: -# raise ValueError('attack_method must be either pgd or fast_at') - -# wandb.log({'Attack_Loss': loss.item()}, step=global_step + i) -# wandb.log({'Train_Loss': 0.0}, step=global_step + i) -# print(f'Step: {global_step + i}, Attack_Loss: {loss.item()}') -# print(f'Step: {global_step + i}, Train_Loss: 0.0') - - - -# # --- Return the adversarial embeddings and input IDs --- -# if attack_embd_type == 'condition_embd': -# return input_adv_condition_embedding, adv_input_ids -# elif attack_embd_type == 'word_embd': -# return input_adv_word_embedding, adv_input_ids -# else: -# raise ValueError('attack_embd_type must be either condition_embd or word_embd') - - # mu_attack/attackers/soft_prompt.py import torch From 1b5d8c927635b3271022efe611dfddfab52b1d66 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 16:17:05 +0545 Subject: [PATCH 08/22] docs for soft prompt added --- docs/mu_attack/attack/hard_prompt.md | 11 +- docs/mu_attack/attack/no_attack.md | 7 +- docs/mu_attack/attack/random.md | 8 +- docs/mu_attack/attack/seed_search.md | 10 +- docs/mu_attack/attack/soft_prompt.md | 215 +++++++++++++++++++++++++++ docs/mu_attack/attack/text_grad.md | 10 +- 6 files changed, 254 insertions(+), 7 deletions(-) create mode 100644 docs/mu_attack/attack/soft_prompt.md diff --git a/docs/mu_attack/attack/hard_prompt.md b/docs/mu_attack/attack/hard_prompt.md index cb15bcd4..af6071b0 100644 --- a/docs/mu_attack/attack/hard_prompt.md +++ b/docs/mu_attack/attack/hard_prompt.md @@ -5,9 +5,18 @@ This repository contains the implementation of UnlearnDiffAttack for hard prompt ### Create Environment + +``` +create_env +``` +eg: ```create_env mu_attack``` + ``` -conda env create -f environment.yaml +conda activate ``` +eg: ```conda activate mu_attack``` + + ### Generate Dataset ``` diff --git a/docs/mu_attack/attack/no_attack.md b/docs/mu_attack/attack/no_attack.md index 955dcdb1..e3d4a619 100644 --- a/docs/mu_attack/attack/no_attack.md +++ b/docs/mu_attack/attack/no_attack.md @@ -6,9 +6,14 @@ This repository contains the implementation of UnlearnDiffAttack for No-attack, ### Create Environment ``` -conda env create -f environment.yaml +create_env ``` +eg: ```create_env mu_attack``` +``` +conda activate +``` +eg: ```conda activate mu_attack``` ### Generate Dataset ``` python -m scripts.generate_dataset --prompts_path data/prompts/prompts.csv --concept i2p_nude --save_path outputs/dataset diff --git a/docs/mu_attack/attack/random.md b/docs/mu_attack/attack/random.md index 2ad43fc3..4b01bf45 100644 --- a/docs/mu_attack/attack/random.md +++ b/docs/mu_attack/attack/random.md @@ -6,8 +6,14 @@ This repository contains the implementation of UnlearnDiffAttack for random, a f ### Create Environment ``` -conda env create -f environment.yaml +create_env ``` +eg: ```create_env mu_attack``` + +``` +conda activate +``` +eg: ```conda activate mu_attack``` ### Generate Dataset ``` diff --git a/docs/mu_attack/attack/seed_search.md b/docs/mu_attack/attack/seed_search.md index 56d586db..8eca0948 100644 --- a/docs/mu_attack/attack/seed_search.md +++ b/docs/mu_attack/attack/seed_search.md @@ -5,9 +5,15 @@ This repository contains the implementation of UnlearnDiffAttack for seed search ### Create Environment -```bash -conda env create -f environment.yaml ``` +create_env +``` +eg: ```create_env mu_attack``` + +``` +conda activate +``` +eg: ```conda activate mu_attack``` ### Generate Dataset diff --git a/docs/mu_attack/attack/soft_prompt.md b/docs/mu_attack/attack/soft_prompt.md new file mode 100644 index 00000000..cee51b9e --- /dev/null +++ b/docs/mu_attack/attack/soft_prompt.md @@ -0,0 +1,215 @@ + +## UnlearnDiffAttak + +This project implements a novel adversarial unlearning framework designed to perform soft prompt attacks on diffusion models. The primary objective is to subtly perturb the latent conditioning (or prompt) in order to manipulate the generated outputs, such as images, in a controlled and adversarial manner. + + +### Create Environment +``` +create_env +``` +eg: ```create_env mu_attack``` + +``` +conda activate +``` +eg: ```conda activate mu_attack``` + +### Run Soft Prompt Attack +1. **Soft Prompt Attack - compvis** + +```python + +from mu_attack.execs.adv_attack import AdvAttack +from mu_attack.configs.adv_unlearn import adv_unlearn_config +from mu.algorithms.esd.configs import esd_train_mu + + +def mu_defense(): + adv_unlearn = AdvAttack( + config=adv_unlearn_config, + ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/models/sd-v1-4-full-ema.ckpt", + attack_step = 2, + backend = "compvis", + config_path = esd_train_mu.model_config_path + + ) + adv_unlearn.attack() + +if __name__ == "__main__": + mu_defense() + +``` + + +2. **Soft Prompt Attack - diffuser** + +```python +from mu_attack.execs.adv_attack import AdvAttack +from mu_attack.configs.adv_unlearn import adv_unlearn_config + + +def mu_defense(): + + adv_unlearn = AdvAttack( + config=adv_unlearn_config, + ckpt_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50", + attack_step = 2, + backend = "diffusers" + + ) + adv_unlearn.attack() + +if __name__ == "__main__": + mu_defense() + +``` + + +**Code Explanation & Important Notes** + +* from mu_attack.configs.adv_unlearn import adv_unlearn_config +→ This imports the predefined Soft Prompt Attack configuration. It sets up the attack parameters and methodologies. + + +**How It Works** +* Default Values: The script first loads default values from the train config file as in configs section. + +* Parameter Overrides: Any parameters passed directly to the algorithm, overrides these configs. + +* Final Configuration: The script merges the configs and convert them into dictionary to proceed with the training. + + +### Description of fields in soft prompt attack config + +1. Model setup + +* config_path : Path to the inference configuration file for Stable Diffusion v1.4. + + * Type: str + * Default: "model_config.yaml" + +* compvis_ckpt_path : Path to the Stable Diffusion v1.4 checkpoint file. + + * Type: str + * Default: "models/sd-v1-4-full-ema.ckpt" + +* encoder_model_name_or_path : Path to the pre-trained encoder model used for text-to-image training. + + * Type: str + * Default: "CompVis/stable-diffusion-v1-4" + +* diffusers_model_name_or_path : Path to the Diffusers-based implementation of the model. + + * Type: str + * Default: "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50" + +* target_ckpt : Checkpoint path for sampling. If None, it uses the default model. + + * Type: str + * Default: None + +2. Devices & I/O + +* devices : Specifies the CUDA devices used for training. + + * Type: str + * Default: "0,0" + +* seperator : Defines the separator used when processing multiple words for unlearning. + + * Type: str + * Default: None + +* cache_path : Path where intermediate results and cache files are stored. + + * Type: str + * Default: ".cache" + + +3. Image & Diffusion Sampling + +* start_guidance : Guidance scale used for generating the initial image. + + * Type: float + * Default: 3.0 + +* ddim_steps : Number of DDIM sampling steps used for inference. + + * Type: int + * Default: 50 + +* image_size : The resolution of images generated during training. + + * Type: int + * Default: 512 + +* ddim_eta : Noise scaling factor for DDIM inference. + + * Type: float + * Default: 0 + + +* prompt: The text prompt associated with the concept to erase. + + * Type: str + * Default: "nudity" + +* attack_method: The adversarial attack method used during training. + + * Type: str + * Choices: ["pgd", "multi_pgd", "fast_at", "free_at"] + * Default: "pgd" + +* ddim_eta: The DDIM sampling noise parameter. + + * Type: float + * Default: 0 + +5. Adversarial Attack Hyperparameters + +* adv_prompt_num: Number of prompt tokens used for adversarial learning. + + * Type: int + * Default: 1 + +* attack_embd_type: Type of embedding targeted for attack. + + * Type: str + * Choices: ["word_embd", "condition_embd"] + * Default: "word_embd" + +* attack_type: The type of attack applied. + + * Type: str + * Choices: ["replace_k", "add", "prefix_k", "suffix_k", "mid_k", "insert_k", "per_k_words"] + * Default: "prefix_k" + +* attack_init: Method for initializing adversarial attacks. + + * Type: str + * Choices: ["random", "latest"] + * Default: "latest" + +* attack_step: Number of attack optimization steps. + + * Type: int + * Default: 30 + +* attack_lr: Learning rate for adversarial attack updates. + + * Type: float + * Default: 1e-3 + + +6. Backend & Logging + +* backend: Specifies the backend for diffusion-based training. + + * Type: str + * Default: "diffusers" + +* project_name: Name of the WandB project for logging. + + * Type: str + * Default: "quick-canvas-machine-unlearning" diff --git a/docs/mu_attack/attack/text_grad.md b/docs/mu_attack/attack/text_grad.md index 97d54429..6ea87b10 100644 --- a/docs/mu_attack/attack/text_grad.md +++ b/docs/mu_attack/attack/text_grad.md @@ -6,9 +6,15 @@ This repository contains the implementation of UnlearnDiffAttack for text grad, ### Create Environment -```bash -conda env create -f environment.yaml ``` +create_env +``` +eg: ```create_env mu_attack``` + +``` +conda activate +``` +eg: ```conda activate mu_attack``` ### Generate Dataset From 59f1634023eb99ac2e18976a1e6d31e1349a567a Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 16:17:30 +0545 Subject: [PATCH 09/22] config variable name fixes --- .../configs/adv_unlearn/adv_unlearn_config.py | 26 +++++-- .../configs/adv_unlearn/model_config.yaml | 70 ------------------- mu_attack/execs/adv_attack.py | 10 +-- 3 files changed, 24 insertions(+), 82 deletions(-) delete mode 100644 mu_attack/configs/adv_unlearn/model_config.yaml diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py index 7ee7599a..af8458c8 100644 --- a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py +++ b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py @@ -6,11 +6,17 @@ class AdvUnlearnConfig(BaseConfig): def __init__(self, **kwargs): - # Inference & Model Paths + # Inference & Model Paths for compvis self.config_path = current_dir / "model_config.yaml" - self.ckpt_path = "models/sd-v1-4-full-ema.ckpt" - self.model_name_or_path = "CompVis/stable-diffusion-v1-4" + self.compvis_ckpt_path = "models/sd-v1-4-full-ema.ckpt" + + #model path for custom encoder + self.encoder_model_name_or_path = "CompVis/stable-diffusion-v1-4" + + #for samlping self.target_ckpt = None + + # Model Paths for diffusers self.diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50" # Devices & IO @@ -54,10 +60,16 @@ def validate_config(self): """ Perform basic validation on the config parameters. """ - if not os.path.exists(self.config_path): - raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") - if not os.path.exists(self.ckpt_path): - raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.") + if self.backend not in ["compvis", "diffusers"]: + raise ValueError(f"Backend must be either 'compvis' or 'diffusers'. Got {self.backend}.") + if self.backend == "compvis": + if not os.path.exists(self.config_path): + raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") + if not os.path.exists(self.compvis_ckpt_path): + raise FileNotFoundError(f"Checkpoint file {self.compvis_ckpt_path} does not exist.") + elif self.backend == "diffusers": + if not os.path.exists(self.diffusers_model_name_or_path): + raise FileNotFoundError(f"Diffusers model {self.diffusers_model_name_or_path} does not exist.") adv_unlearn_config = AdvUnlearnConfig() diff --git a/mu_attack/configs/adv_unlearn/model_config.yaml b/mu_attack/configs/adv_unlearn/model_config.yaml deleted file mode 100644 index cf7f8131..00000000 --- a/mu_attack/configs/adv_unlearn/model_config.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: stable_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index e6f1cb38..dc2946dc 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -29,7 +29,7 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): config.validate_config() self.prompt = config.prompt - self.model_name_or_path = config.model_name_or_path + self.encoder_model_name_or_path = config.encoder_model_name_or_path self.cache_path = config.cache_path self.devices = [f'cuda:{int(d.strip())}' for d in config.devices.split(',')] self.attack_type = config.attack_type @@ -45,7 +45,7 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): self.adv_prompt_num = config.adv_prompt_num self.start_guidance = config.start_guidance self.config_path = config.config_path - self.ckpt_path = config.ckpt_path + self.compvis_ckpt_path = config.compvis_ckpt_path self.backend = config.backend self.diffusers_model_name_or_path = config.diffusers_model_name_or_path self.target_ckpt = config.target_ckpt @@ -79,10 +79,10 @@ def encode_text(self, text): def load_models(self): """Loads the tokenizer, text encoder, and models.""" self.tokenizer = CLIPTokenizer.from_pretrained( - self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path + self.encoder_model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path ) self.text_encoder = CLIPTextModel.from_pretrained( - self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path + self.encoder_model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path ).to(self.devices[0]) self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) @@ -90,7 +90,7 @@ def load_models(self): # Load base models if self.backend == "compvis": self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( - self.config_path, self.ckpt_path, self.devices + self.config_path, self.compvis_ckpt_path, self.devices ) elif self.backend == "diffusers": self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( From ffb8b309112fc7c05b3ae28dc02a5b6d81752842 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 16:18:46 +0545 Subject: [PATCH 10/22] removed src from mu_attack --- mu_attack/src/clip | 1 - mu_attack/src/taming-transformers | 1 - 2 files changed, 2 deletions(-) delete mode 160000 mu_attack/src/clip delete mode 160000 mu_attack/src/taming-transformers diff --git a/mu_attack/src/clip b/mu_attack/src/clip deleted file mode 160000 index dcba3cb2..00000000 --- a/mu_attack/src/clip +++ /dev/null @@ -1 +0,0 @@ -Subproject commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 diff --git a/mu_attack/src/taming-transformers b/mu_attack/src/taming-transformers deleted file mode 160000 index 3ba01b24..00000000 --- a/mu_attack/src/taming-transformers +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318 From 0c572efdef4e3bc5509ceff6a9e470de4091a23f Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 5 Feb 2025 12:33:31 +0000 Subject: [PATCH 11/22] common env created --- docs/mu_attack/attack/soft_prompt.md | 4 +-- mu_attack/.gitignore | 1 + mu_attack/adv_unlearn_environment.yaml | 35 ------------------- .../configs/adv_unlearn/adv_unlearn_config.py | 2 -- mu_attack/environment.yaml | 1 + 5 files changed, 4 insertions(+), 39 deletions(-) create mode 100644 mu_attack/.gitignore delete mode 100644 mu_attack/adv_unlearn_environment.yaml diff --git a/docs/mu_attack/attack/soft_prompt.md b/docs/mu_attack/attack/soft_prompt.md index cee51b9e..c3bdf2bc 100644 --- a/docs/mu_attack/attack/soft_prompt.md +++ b/docs/mu_attack/attack/soft_prompt.md @@ -28,7 +28,7 @@ from mu.algorithms.esd.configs import esd_train_mu def mu_defense(): adv_unlearn = AdvAttack( config=adv_unlearn_config, - ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/models/sd-v1-4-full-ema.ckpt", + compvis_ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/models/sd-v1-4-full-ema.ckpt", attack_step = 2, backend = "compvis", config_path = esd_train_mu.model_config_path @@ -53,7 +53,7 @@ def mu_defense(): adv_unlearn = AdvAttack( config=adv_unlearn_config, - ckpt_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50", + diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50", attack_step = 2, backend = "diffusers" diff --git a/mu_attack/.gitignore b/mu_attack/.gitignore new file mode 100644 index 00000000..aa850f42 --- /dev/null +++ b/mu_attack/.gitignore @@ -0,0 +1 @@ +src/* \ No newline at end of file diff --git a/mu_attack/adv_unlearn_environment.yaml b/mu_attack/adv_unlearn_environment.yaml deleted file mode 100644 index 8c55a203..00000000 --- a/mu_attack/adv_unlearn_environment.yaml +++ /dev/null @@ -1,35 +0,0 @@ -name: AdvUnlearn -channels: - - pytorch - - defaults -dependencies: - - python=3.8.5 - - pip=20.3 - - cudatoolkit=11.3 - - pytorch=1.11.0 - - torchvision=0.12.0 - - numpy=1.19.2 - - pip: - - albumentations==0.4.3 - - diffusers==0.12.1 - - opencv-python==4.1.2.30 - - pudb==2019.2 - - invisible-watermark - - imageio==2.9.0 - - imageio-ffmpeg==0.4.2 - - huggingface_hub==0.10.1 - - pytorch-lightning==1.4.2 - - omegaconf==2.1.1 - - test-tube>=0.7.5 - - streamlit>=0.73.1 - - einops==0.3.0 - - torch-fidelity==0.3.0 - - transformers==4.25.1 - - torchmetrics==0.6.0 - - kornia==0.6 - - matplotlib - - wandb - - tabulate - - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - - -e git+https://github.com/openai/CLIP.git@main#egg=clip - # - -e . diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py index af8458c8..b8ef2017 100644 --- a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py +++ b/mu_attack/configs/adv_unlearn/adv_unlearn_config.py @@ -63,8 +63,6 @@ def validate_config(self): if self.backend not in ["compvis", "diffusers"]: raise ValueError(f"Backend must be either 'compvis' or 'diffusers'. Got {self.backend}.") if self.backend == "compvis": - if not os.path.exists(self.config_path): - raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") if not os.path.exists(self.compvis_ckpt_path): raise FileNotFoundError(f"Checkpoint file {self.compvis_ckpt_path} does not exist.") elif self.backend == "diffusers": diff --git a/mu_attack/environment.yaml b/mu_attack/environment.yaml index 7eb48fc8..b4cfa0b8 100644 --- a/mu_attack/environment.yaml +++ b/mu_attack/environment.yaml @@ -29,6 +29,7 @@ dependencies: - taming-transformers-rom1504 - kornia==0.6 - pydantic==2.10.6 + - wandb==0.19.5 - git+https://github.com/Phoveran/fastargs.git@main#egg=fastargs - git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - git+https://github.com/openai/CLIP.git@main#egg=clip \ No newline at end of file From 604e29da46837a9e38325daba4706b68620fd9df Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Thu, 6 Feb 2025 05:42:15 +0000 Subject: [PATCH 12/22] common env created --- docs/mu_attack/attack/soft_prompt.md | 6 ++++++ mu_attack/environment.yaml | 1 + 2 files changed, 7 insertions(+) diff --git a/docs/mu_attack/attack/soft_prompt.md b/docs/mu_attack/attack/soft_prompt.md index c3bdf2bc..1311c5ea 100644 --- a/docs/mu_attack/attack/soft_prompt.md +++ b/docs/mu_attack/attack/soft_prompt.md @@ -65,6 +65,12 @@ if __name__ == "__main__": ``` +**Run the python file in offline mode** + +```bash +WANDB_MODE=offline python_file.py +``` + **Code Explanation & Important Notes** diff --git a/mu_attack/environment.yaml b/mu_attack/environment.yaml index b4cfa0b8..4195608c 100644 --- a/mu_attack/environment.yaml +++ b/mu_attack/environment.yaml @@ -26,6 +26,7 @@ dependencies: - transformers==4.33.2 - opencv-python-headless==4.8.0.76 - einops==0.8.0 + - timm==0.6.7 - taming-transformers-rom1504 - kornia==0.6 - pydantic==2.10.6 From 1f0aeb043c3d519f5ab953525cb64e2f376d991b Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Thu, 6 Feb 2025 11:38:47 +0545 Subject: [PATCH 13/22] renamed files --- docs/mu_attack/attack/soft_prompt.md | 8 ++++---- mu_attack/configs/adv_unlearn/__init__.py | 2 +- .../{adv_unlearn_config.py => adv_attack_config.py} | 6 ++++-- mu_attack/execs/adv_attack.py | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) rename mu_attack/configs/adv_unlearn/{adv_unlearn_config.py => adv_attack_config.py} (95%) diff --git a/docs/mu_attack/attack/soft_prompt.md b/docs/mu_attack/attack/soft_prompt.md index 1311c5ea..1b062e3d 100644 --- a/docs/mu_attack/attack/soft_prompt.md +++ b/docs/mu_attack/attack/soft_prompt.md @@ -21,13 +21,13 @@ eg: ```conda activate mu_attack``` ```python from mu_attack.execs.adv_attack import AdvAttack -from mu_attack.configs.adv_unlearn import adv_unlearn_config +from mu_attack.configs.adv_unlearn import adv_attack_config from mu.algorithms.esd.configs import esd_train_mu def mu_defense(): adv_unlearn = AdvAttack( - config=adv_unlearn_config, + config=adv_attack_config, compvis_ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/models/sd-v1-4-full-ema.ckpt", attack_step = 2, backend = "compvis", @@ -46,13 +46,13 @@ if __name__ == "__main__": ```python from mu_attack.execs.adv_attack import AdvAttack -from mu_attack.configs.adv_unlearn import adv_unlearn_config +from mu_attack.configs.adv_unlearn import adv_attack_config def mu_defense(): adv_unlearn = AdvAttack( - config=adv_unlearn_config, + config=adv_attack_config, diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50", attack_step = 2, backend = "diffusers" diff --git a/mu_attack/configs/adv_unlearn/__init__.py b/mu_attack/configs/adv_unlearn/__init__.py index cc01a1ec..9c229329 100644 --- a/mu_attack/configs/adv_unlearn/__init__.py +++ b/mu_attack/configs/adv_unlearn/__init__.py @@ -1 +1 @@ -from .adv_unlearn_config import AdvUnlearnConfig, adv_unlearn_config \ No newline at end of file +from .adv_attack_config import AdvAttackConfig, adv_attack_config \ No newline at end of file diff --git a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py b/mu_attack/configs/adv_unlearn/adv_attack_config.py similarity index 95% rename from mu_attack/configs/adv_unlearn/adv_unlearn_config.py rename to mu_attack/configs/adv_unlearn/adv_attack_config.py index b8ef2017..222217d7 100644 --- a/mu_attack/configs/adv_unlearn/adv_unlearn_config.py +++ b/mu_attack/configs/adv_unlearn/adv_attack_config.py @@ -1,10 +1,12 @@ +#mu_attack/configs/adv_unlearn/adv_attack_config.py + import os from pathlib import Path from mu.core.base_config import BaseConfig current_dir = Path(__file__).parent -class AdvUnlearnConfig(BaseConfig): +class AdvAttackConfig(BaseConfig): def __init__(self, **kwargs): # Inference & Model Paths for compvis self.config_path = current_dir / "model_config.yaml" @@ -69,5 +71,5 @@ def validate_config(self): if not os.path.exists(self.diffusers_model_name_or_path): raise FileNotFoundError(f"Diffusers model {self.diffusers_model_name_or_path} does not exist.") -adv_unlearn_config = AdvUnlearnConfig() +adv_attack_config = AdvAttackConfig() diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index dc2946dc..42ed5d76 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -8,7 +8,7 @@ from diffusers import StableDiffusionPipeline from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler -from mu_attack.configs.adv_unlearn import AdvUnlearnConfig +from mu_attack.configs.adv_unlearn import AdvAttackConfig from mu_attack.attackers.soft_prompt import SoftPromptAttack from mu_attack.tasks.utils.text_encoder import CustomTextEncoder from mu_attack.helpers.utils import get_models_for_compvis, get_models_for_diffusers @@ -21,7 +21,7 @@ class AdvAttack: This class wraps the full training pipeline including adversarial attack and model handling. """ - def __init__(self, config: AdvUnlearnConfig, **kwargs): + def __init__(self, config: AdvAttackConfig, **kwargs): self.config = config.__dict__ for key, value in kwargs.items(): setattr(config, key, value) From 6795a9c22bea3f7ee76d3a470805e12bef664fad Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Thu, 6 Feb 2025 12:43:44 +0545 Subject: [PATCH 14/22] renamed variables --- mu_attack/helpers/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py index 861cb85c..a3e2d16e 100644 --- a/mu_attack/helpers/utils.py +++ b/mu_attack/helpers/utils.py @@ -237,16 +237,16 @@ def construct_id(k, adv_id, insertion_location,sot_id,eot_id,mid_id): -def get_models_for_compvis(config_path, ckpt_path, devices): - model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) +def get_models_for_compvis(config_path, compvis_ckpt_path, devices): + model_orig = load_model_from_config(config_path, compvis_ckpt_path, devices[1]) sampler_orig = DDIMSampler(model_orig) - model = load_model_from_config(config_path, ckpt_path, devices[0]) + model = load_model_from_config(config_path, compvis_ckpt_path, devices[0]) sampler = DDIMSampler(model) return model_orig, sampler_orig, model, sampler -def get_models_for_diffusers(model_name_or_path, target_ckpt, devices, cache_path=None): +def get_models_for_diffusers(diffuser_model_name_or_path, target_ckpt, devices, cache_path=None): """ Loads two copies of a Diffusers UNet model along with their DDIM schedulers. @@ -266,7 +266,7 @@ def get_models_for_diffusers(model_name_or_path, target_ckpt, devices, cache_pat # Load the original model (used for e.g. computing loss, etc.) on devices[1] model_orig = UNet2DConditionModel.from_pretrained( - model_name_or_path, + diffuser_model_name_or_path, subfolder="unet", cache_dir=cache_path ).to(devices[1]) @@ -274,14 +274,14 @@ def get_models_for_diffusers(model_name_or_path, target_ckpt, devices, cache_pat # Create a DDIM scheduler for model_orig. (Note: diffusers DDIMScheduler is used here; # adjust the subfolder or configuration if your scheduler is stored elsewhere.) sampler_orig = DDIMScheduler.from_pretrained( - model_name_or_path, + diffuser_model_name_or_path, subfolder="scheduler", cache_dir=cache_path ) # Load the second copy of the model on devices[0] model = UNet2DConditionModel.from_pretrained( - model_name_or_path, + diffuser_model_name_or_path, subfolder="unet", cache_dir=cache_path ).to(devices[0]) @@ -292,7 +292,7 @@ def get_models_for_diffusers(model_name_or_path, target_ckpt, devices, cache_pat model.load_state_dict(state_dict) sampler = DDIMScheduler.from_pretrained( - model_name_or_path, + diffuser_model_name_or_path, subfolder="scheduler", cache_dir=cache_path ) From 615c9e3ad56c96a087f69c81654df0514d53b77d Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 10 Feb 2025 12:22:18 +0545 Subject: [PATCH 15/22] mu defense added --- mu_defense/algorithms/adv_unlearn/README.md | 0 mu_defense/algorithms/adv_unlearn/__init__.py | 1 + .../algorithms/adv_unlearn/algorithm.py | 354 +++++++++++ .../algorithms/adv_unlearn/compvis_trainer.py | 420 +++++++++++++ .../algorithms/adv_unlearn/evaluator.py | 0 mu_defense/algorithms/adv_unlearn/model.py | 109 ++++ mu_defense/algorithms/adv_unlearn/trainer.py | 0 mu_defense/algorithms/adv_unlearn/utils.py | 595 ++++++++++++++++++ mu_defense/core/__init__.py | 9 + mu_defense/core/base_algorithm.py | 45 ++ mu_defense/core/base_config.py | 0 mu_defense/core/base_model.py | 17 + mu_defense/core/base_trainer.py | 23 + 13 files changed, 1573 insertions(+) create mode 100644 mu_defense/algorithms/adv_unlearn/README.md create mode 100644 mu_defense/algorithms/adv_unlearn/__init__.py create mode 100644 mu_defense/algorithms/adv_unlearn/algorithm.py create mode 100644 mu_defense/algorithms/adv_unlearn/compvis_trainer.py create mode 100644 mu_defense/algorithms/adv_unlearn/evaluator.py create mode 100644 mu_defense/algorithms/adv_unlearn/model.py create mode 100644 mu_defense/algorithms/adv_unlearn/trainer.py create mode 100644 mu_defense/algorithms/adv_unlearn/utils.py create mode 100644 mu_defense/core/__init__.py create mode 100644 mu_defense/core/base_algorithm.py create mode 100644 mu_defense/core/base_config.py create mode 100644 mu_defense/core/base_model.py create mode 100644 mu_defense/core/base_trainer.py diff --git a/mu_defense/algorithms/adv_unlearn/README.md b/mu_defense/algorithms/adv_unlearn/README.md new file mode 100644 index 00000000..e69de29b diff --git a/mu_defense/algorithms/adv_unlearn/__init__.py b/mu_defense/algorithms/adv_unlearn/__init__.py new file mode 100644 index 00000000..90f60fdd --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/algorithm.py b/mu_defense/algorithms/adv_unlearn/algorithm.py new file mode 100644 index 00000000..901be612 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/algorithm.py @@ -0,0 +1,354 @@ + + +import torch +from tqdm import tqdm +import random +import wandb + +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL + +from mu_attack.configs.adv_unlearn import AdvUnlearnConfig +from mu.helpers import sample_model +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_attack.attackers.soft_prompt import SoftPromptAttack +from mu_attack.helpers.utils import id2embedding, param_choices, get_models, retain_prompt, get_train_loss_retain,save_text_encoder, save_model, save_history + + + +class AdvUnlearn: + """ + Class for adversarial unlearning training. + + This class wraps the full training pipeline including prompt cleaning, + attack (adversarial prompt generation), and retention-based regularized training. + """ + def __init__( + self, + config: AdvUnlearnConfig, + **kwargs + ): + self.config = config.__dict__ + for key, value in kwargs.items(): + setattr(config, key, value) + + config.validate_config() + + self.config = config + self.prompt = config.prompt + self.dataset_retain = config.dataset_retain + self.retain_batch = config.retain_batch + self.retain_train = config.retain_train + self.retain_step = config.retain_step + self.retain_loss_w = config.retain_loss_w + self.attack_method = config.attack_method + self.train_method = config.train_method + self.norm_layer = config.norm_layer + self.component = config.component + self.model_name_or_path = config.model_name_or_path + self.start_guidance = config.start_guidance + self.negative_guidance = config.negative_guidance + self.iterations = config.iterations + self.save_interval = config.save_interval + self.lr = config.lr + self.config_path = config.config_path + self.ckpt_path = config.ckpt_path + self.diffusers_config_path = config.diffusers_config_path + self.output_dir = config.output_dir + self.devices = config.devices + self.seperator = config.seperator + self.image_size = config.image_size + self.ddim_steps = config.ddim_steps + self.adv_prompt_num = config.adv_prompt_num + self.attack_embd_type = config.attack_embd_type + self.attack_type = config.attack_type + self.attack_init = config.attack_init + self.warmup_iter = config.warmup_iter + self.attack_step = config.attack_step + self.attack_lr = config.attack_lr + self.adv_prompt_update_step = config.adv_prompt_update_step + self.ddim_eta = config.ddim_eta + self.cache_path = config.cache_path + + # Will be set during training. + self.words = None + self.retain_dataset = None + self.tokenizer = None + self.text_encoder = None + self.custom_text_encoder = None + self.all_embeddings = None + self.vae = None + self.model_orig = None + self.sampler_orig = None + self.model = None + self.sampler = None + self.parameters = None + self.opt = None + self.criteria = torch.nn.MSELoss() + + # For adversarial prompt update + self.adv_word_embd = None + self.adv_condition_embd = None + self.adv_input_ids = None + + def setup(self): + """Stage 0 & 1: Prompt cleaning and training setup.""" + # --- Prompt cleaning --- + word_print = self.prompt.replace(' ', '') + # Special cases for certain prompts + if self.prompt == 'allartist': + self.prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng" + if self.prompt == 'i2p': + self.prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood" + if self.prompt == "artifact": + self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") + + if self.seperator is not None: + self.words = [w.strip() for w in self.prompt.split(self.seperator)] + else: + self.words = [self.prompt] + print(f'The Concept Prompt to be unlearned: {self.words}') + + # Create a retaining dataset (assumed to be a prompt dataset) + self.retain_dataset = retain_prompt(self.dataset_retain) + + # --- Training Setup --- + ddim_eta = self.ddim_eta # constant value for training + + + + # Load the VAE + self.vae = AutoencoderKL.from_pretrained(self.model_name_or_path, subfolder="vae", cache_dir=self.cache_path).to(self.devices[0]) + # Load tokenizer and text encoder + self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path) + self.text_encoder = CLIPTextModel.from_pretrained(self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path).to(self.devices[0]) + self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) + self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) + + # Load models using your helper function (assumed to be defined in utils) + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models(self.config_path, self.ckpt_path, self.devices) + self.model_orig.eval() + + # Setup trainable parameters based on train_method + if 'text_encoder' in self.train_method: + self.parameters = param_choices(model=self.custom_text_encoder, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) + else: + self.parameters = param_choices(model=self.model, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) + + self.opt = torch.optim.Adam(self.parameters, lr=self.lr) + + return word_print # For later use in saving history + + def train(self): + """Stage 2: Training loop.""" + word_print = self.setup() + ddim_eta = self.ddim_eta # As used in training + + # A lambda function to sample until a given time step. + quick_sample_till_t = lambda x, s, code, batch, t: sample_model( + self.model, self.sampler, + x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False + ) + + losses = [] + history = [] + global_step = 0 + attack_round = 0 + + # Create a tqdm progress bar + pbar = tqdm(range(self.iterations)) + for i in pbar: + # --- Update adversarial prompt every adv_prompt_update_step iterations --- + if i % self.adv_prompt_update_step == 0: + # Reset the retaining dataset if needed + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + + # Randomly choose one prompt from the list + word = random.sample(self.words, 1)[0] + text_input = self.tokenizer( + word, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, text_input.input_ids.to(self.devices[0]), self.devices[0]) + + # Get conditional embeddings from the frozen model + emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + + # --- Attack Step: Get adversarial prompt --- + if i >= self.warmup_iter: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.eval() + + if attack_round == 0: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + None, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + None, self.attack_method + ) + else: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + self.adv_word_embd, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, + self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, + self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, + self.adv_condition_embd, self.attack_method + ) + global_step += self.attack_step + attack_round += 1 + + # --- Set models to training/eval modes based on training method --- + if 'text_encoder' in self.train_method: + self.custom_text_encoder.text_encoder.train() + self.custom_text_encoder.text_encoder.requires_grad_(True) + self.model.eval() + else: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.train() + self.opt.zero_grad() + + # --- Retaining prompts for retention regularization --- + if self.retain_train == 'reg': + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + retain_text_input = self.tokenizer( + retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) + # Reshape to [batch, 77, embedding_dim] + retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) + retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] + else: + retain_text_input = None + retain_text_embeddings = None + retain_emb_p = None + retain_emb_n = None + + # --- Compute training loss --- + if i < self.warmup_iter: + # Warmup training uses the original prompt embeddings. + input_ids = text_input.input_ids.to(self.devices[0]) + emb_n = self.custom_text_encoder(input_ids=input_ids, inputs_embeds=text_embeddings)[0] + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, input_ids, self.attack_embd_type + ) + else: + if self.attack_embd_type == 'word_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_word_embd + ) + elif self.attack_embd_type == 'condition_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_condition_embd + ) + + # Backpropagate loss and update weights. + loss.backward() + losses.append(loss.item()) + pbar.set_postfix({"loss": loss.item()}) + history.append(loss.item()) + wandb.log({'Train_Loss': loss.item()}, step=global_step) + wandb.log({'Attack_Loss': 0.0}, step=global_step) + global_step += 1 + self.opt.step() + + # --- Additional Retention Training (if using iterative retention) --- + if self.retain_train == 'iter': + for r in range(self.retain_step): + print(f'==== Retain Training at step {r} ====') + self.opt.zero_grad() + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc) / self.ddim_steps) * 1000) + og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) + + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_z = quick_sample_till_t(retain_emb_p.to(self.devices[0]), self.start_guidance, retain_start_code, self.retain_batch, int(t_enc)) + retain_e_p = self.model_orig.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_p.to(self.devices[0])) + + retain_text_input = self.tokenizer( + retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) + retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) + retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] + retain_e_n = self.model.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_n.to(self.devices[0])) + + retain_loss = self.criteria(retain_e_n.to(self.devices[0]), retain_e_p.to(self.devices[0])) + retain_loss.backward() + self.opt.step() + + # --- Checkpointing and saving history --- + if (i + 1) % self.save_interval == 0 and (i + 1) != self.iterations and (i + 1) >= self.save_interval: + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, + save_diffusers=True, compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path) + if i % 1 == 0: + save_history(self.output_dir, losses, word_print) + + # --- Stage 3: Save final model and loss curve --- + self.model.eval() + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, + save_diffusers=True, compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path) + save_history(self.output_dir, losses, word_print) diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py new file mode 100644 index 00000000..2b810b96 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -0,0 +1,420 @@ +import torch +from tqdm import tqdm +import random +import wandb +from torch.nn import MSELoss + +from mu.helpers import sample_model +from mu.core import BaseTrainer +from mu_defense.algorithms.adv_unlearn import ( + id2embedding, + param_choices, + retain_prompt, + get_train_loss_retain, + save_text_encoder, + save_model, + save_history +) +from mu_attack.attackers.soft_prompt import SoftPromptAttack +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder + +class AdvUnlearnTrainer(BaseTrainer): + """ + Trainer for adversarial unlearning. + + This trainer performs the adversarial prompt update and retention-based + regularized training loop. + """ + def __init__(self, model, config: dict, devices: list, **kwargs): + """ + Initialize the AdvUnlearnTrainer. + + Args: + model: A model loader instance that contains the following attributes: + - model_orig: the frozen diffusion model, + - sampler_orig: sampler for the frozen model, + - model: the trainable diffusion model, + - sampler: sampler for the trainable model, + - tokenizer: the tokenizer, + - custom_text_encoder: the custom text encoder wrapping the CLIP text encoder, + - all_embeddings: the complete text embedding matrix, + - vae: the VAE. + config (dict): Configuration dictionary with all training hyperparameters. + devices (list): List of device strings (e.g., ['cuda:0']). + """ + super().__init__(model, config, **kwargs) + self.devices = devices + + # Unpack models and samplers from the provided model loader. + self.model = model.model # trainable diffusion model + self.model_orig = model.model_orig # frozen diffusion model (set to eval) + self.sampler = model.sampler + self.sampler_orig = model.sampler_orig + + # Other loaded components. + self.tokenizer = model.tokenizer + self.custom_text_encoder = model.custom_text_encoder + self.all_embeddings = model.all_embeddings + self.vae = model.vae + + # Loss criterion. + self.criteria = MSELoss() + + # Save configuration parameters. + self.config = config + self.prompt = self.config['prompt'] + self.seperator = self.config.get('seperator') + self.iterations = self.config['iterations'] + self.ddim_steps = self.config['ddim_steps'] + self.start_guidance = self.config['start_guidance'] + self.negative_guidance = self.config['negative_guidance'] + self.image_size = self.config['image_size'] + self.lr = self.config['lr'] + self.config_path = self.config['config_path'] + self.diffusers_config_path = self.config['diffusers_config_path'] + self.output_dir = self.config['output_dir'] + + # Retention and attack parameters. + self.dataset_retain = self.config['dataset_retain'] + self.retain_batch = self.config['retain_batch'] + self.retain_train = self.config['retain_train'] + self.retain_step = self.config['retain_step'] + self.retain_loss_w = self.config['retain_loss_w'] + self.attack_method = self.config['attack_method'] + self.train_method = self.config['train_method'] + self.norm_layer = self.config['norm_layer'] + self.component = self.config['component'] + self.adv_prompt_num = self.config['adv_prompt_num'] + self.attack_embd_type = self.config['attack_embd_type'] + self.attack_type = self.config['attack_type'] + self.attack_init = self.config['attack_init'] + self.warmup_iter = self.config['warmup_iter'] + self.attack_step = self.config['attack_step'] + self.attack_lr = self.config['attack_lr'] + self.adv_prompt_update_step = self.config['adv_prompt_update_step'] + self.ddim_eta = self.config['ddim_eta'] + + # Setup prompt cleaning and retaining dataset. + self._setup_prompt_and_dataset() + + # Initialize adversarial prompt variables. + self.adv_word_embd = None + self.adv_condition_embd = None + self.adv_input_ids = None + + # Setup trainable parameters and optimizer. + self._setup_optimizer() + + def _setup_prompt_and_dataset(self): + """ + Perform prompt cleaning and create the retaining dataset. + """ + # Special-case prompt cleaning. + if self.prompt == 'allartist': + self.prompt = ("Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, " + "Alena Aenami, Tyler Edlin, Kilian Eng") + elif self.prompt == 'i2p': + self.prompt = ("hate, harassment, violence, suffering, humiliation, harm, " + "suicide, sexual, nudity, bodily fluids, blood") + elif self.prompt == "artifact": + self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") + if self.seperator: + self.words = [w.strip() for w in self.prompt.split(self.seperator)] + else: + self.words = [self.prompt] + self.word_print = self.prompt.replace(" ", "") + print(f"The Concept Prompt to be unlearned: {self.words}") + + # Create a retaining dataset using your helper function. + self.retain_dataset = retain_prompt(self.dataset_retain) + + def _setup_optimizer(self): + """ + Set up the optimizer based on the training method. + """ + if 'text_encoder' in self.train_method: + self.parameters = param_choices( + model=self.custom_text_encoder, + train_method=self.train_method, + component=self.component, + final_layer_norm=self.norm_layer + ) + else: + self.parameters = param_choices( + model=self.model, + train_method=self.train_method, + component=self.component, + final_layer_norm=self.norm_layer + ) + self.optimizer = torch.optim.Adam(self.parameters, lr=float(self.lr)) + + def train(self): + """ + Execute the adversarial unlearning training loop. + """ + ddim_eta = self.ddim_eta + # Lambda to sample until a given time step. + quick_sample_till_t = lambda x, s, code, batch, t: sample_model( + self.model, self.sampler, + x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False + ) + losses = [] + history = [] + global_step = 0 + attack_round = 0 + + pbar = tqdm(range(self.iterations)) + for i in pbar: + # --- Update adversarial prompt every adv_prompt_update_step iterations --- + if i % self.adv_prompt_update_step == 0: + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + word = random.choice(self.words) + text_input = self.tokenizer( + word, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + # Get conditioning from the frozen model. + emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + + if i >= self.warmup_iter: + # Update adversarial prompt using SoftPromptAttack. + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.eval() + if attack_round == 0: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, + self.start_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_prompt_num, + self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, + self.attack_init, None, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, + self.start_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_prompt_num, + self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, + self.attack_init, None, self.attack_method + ) + else: + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, + self.start_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_prompt_num, + self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, + self.attack_init, self.adv_word_embd, self.attack_method + ) + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( + global_step, word, self.model, self.model_orig, self.tokenizer, + self.custom_text_encoder, self.sampler, emb_0, emb_p, + self.start_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_prompt_num, + self.all_embeddings, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, + self.attack_init, self.adv_condition_embd, self.attack_method + ) + global_step += self.attack_step + attack_round += 1 + + # --- Set models to training/eval modes based on train_method --- + if 'text_encoder' in self.train_method: + self.custom_text_encoder.text_encoder.train() + self.custom_text_encoder.text_encoder.requires_grad_(True) + self.model.eval() + else: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.train() + + self.optimizer.zero_grad() + + # --- Retaining prompts for retention regularization (if configured) --- + if self.retain_train == 'reg': + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + retain_text_input = self.tokenizer( + retain_words, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + retain_text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + retain_text_embeddings = retain_text_embeddings.reshape( + self.retain_batch, -1, retain_text_embeddings.shape[-1] + ) + retain_emb_n = self.custom_text_encoder( + input_ids=retain_input_ids, + inputs_embeds=retain_text_embeddings + )[0] + else: + retain_emb_p = None + retain_emb_n = None + + # --- Compute training loss --- + if i < self.warmup_iter: + input_ids = text_input.input_ids.to(self.devices[0]) + emb_n = self.custom_text_encoder( + input_ids=input_ids, + inputs_embeds=text_embeddings + )[0] + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, input_ids, self.attack_embd_type + ) + else: + if self.attack_embd_type == 'word_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, + self.adv_word_embd + ) + elif self.attack_embd_type == 'condition_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, + self.adv_condition_embd + ) + loss.backward() + losses.append(loss.item()) + pbar.set_postfix({"loss": loss.item()}) + history.append(loss.item()) + wandb.log({'Train_Loss': loss.item()}, step=global_step) + wandb.log({'Attack_Loss': 0.0}, step=global_step) + global_step += 1 + self.optimizer.step() + + # --- Additional Retention Training (for iterative retention) --- + if self.retain_train == 'iter': + for r in range(self.retain_step): + self.optimizer.zero_grad() + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc.item()) / self.ddim_steps) * 1000) + og_num_lim = round(((int(t_enc.item()) + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_z = quick_sample_till_t( + retain_emb_p.to(self.devices[0]), + self.start_guidance, + retain_start_code, + self.retain_batch, + int(t_enc.item()) + ) + retain_e_p = self.model_orig.apply_model( + retain_z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + retain_emb_p.to(self.devices[0]) + ) + retain_text_input = self.tokenizer( + retain_words, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + retain_text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + retain_text_embeddings = retain_text_embeddings.reshape( + self.retain_batch, -1, retain_text_embeddings.shape[-1] + ) + retain_emb_n = self.custom_text_encoder( + input_ids=retain_input_ids, + inputs_embeds=retain_text_embeddings + )[0] + retain_e_n = self.model.apply_model( + retain_z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + retain_emb_n.to(self.devices[0]) + ) + retain_loss = self.criteria( + retain_e_n.to(self.devices[0]), + retain_e_p.to(self.devices[0]) + ) + retain_loss.backward() + self.optimizer.step() + + # --- Checkpointing and saving history --- + if (i + 1) % self.config['save_interval'] == 0 and (i + 1) != self.iterations and (i + 1) >= self.config['save_interval']: + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model( + self.output_dir, + self.model, + self.train_method, + i, + save_compvis=True, + save_diffusers=True, + compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path + ) + if i % 1 == 0: + save_history(self.output_dir, losses, self.word_print) + + # --- Final checkpointing --- + self.model.eval() + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + save_model( + self.output_dir, + self.model, + self.train_method, + i, + save_compvis=True, + save_diffusers=True, + compvis_config_file=self.config_path, + diffusers_config_file=self.diffusers_config_path + ) + save_history(self.output_dir, losses, self.word_print) + return self.model diff --git a/mu_defense/algorithms/adv_unlearn/evaluator.py b/mu_defense/algorithms/adv_unlearn/evaluator.py new file mode 100644 index 00000000..e69de29b diff --git a/mu_defense/algorithms/adv_unlearn/model.py b/mu_defense/algorithms/adv_unlearn/model.py new file mode 100644 index 00000000..04db4cb7 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/model.py @@ -0,0 +1,109 @@ +import torch +from diffusers import AutoencoderKL +from transformers import CLIPTextModel, CLIPTokenizer + +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_defense.algorithms.adv_unlearn import get_models + +from mu_defense.core import BaseModel + +class AdvUnlearnModel(BaseModel): + """ + AdvUnlearnModel handles loading of the components for adversarial unlearning. + This includes: + - The VAE from the pretrained model. + - The tokenizer. + - The text encoder (and its custom wrapper). + - The diffusion models (trainable and frozen versions) along with their samplers. + """ + def __init__( + self, + model_name_or_path: str, + config_path: str, + ckpt_path: str, + cache_path: str, + devices: list + ): + """ + Initialize the AdvUnlearnModel loader. + + Args: + model_name_or_path (str): Path or identifier of the pretrained model. + config_path (str): Path to the model configuration file. + ckpt_path (str): Path to the model checkpoint. + cache_path (str): Directory for caching downloaded models. + devices (list): List of device strings (e.g., ['cuda:0', 'cuda:1']) for model placement. + """ + super().__init__() + self.model_name_or_path = model_name_or_path + self.config_path = config_path + self.ckpt_path = ckpt_path + self.cache_path = cache_path + self.devices = devices + + # Load the VAE. + self.vae = AutoencoderKL.from_pretrained( + self.model_name_or_path, + subfolder="vae", + cache_dir=self.cache_path + ).to(self.devices[0]) + + # Load the tokenizer. + self.tokenizer = CLIPTokenizer.from_pretrained( + self.model_name_or_path, + subfolder="tokenizer", + cache_dir=self.cache_path + ) + + # Load the text encoder and wrap it with your custom encoder. + self.text_encoder = CLIPTextModel.from_pretrained( + self.model_name_or_path, + subfolder="text_encoder", + cache_dir=self.cache_path + ).to(self.devices[0]) + self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) + self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) + + # Load diffusion models using your helper function. + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models( + self.config_path, + self.ckpt_path, + self.devices + ) + self.model_orig.eval() # Set the frozen model to evaluation mode. + + def save_model(self, model: torch.nn.Module, output_path: str) -> None: + """ + Save the model's state dictionary. + + Args: + model (torch.nn.Module): The model to be saved. + output_path (str): The file path where the model checkpoint will be stored. + """ + torch.save({"state_dict": model.state_dict()}, output_path) + + def get_learned_conditioning(self, prompts: list): + """ + Obtain the learned conditioning for the given prompts using the trainable diffusion model. + + Args: + prompts (list): A list of prompt strings. + + Returns: + The conditioning tensors produced by the model. + """ + return self.model.get_learned_conditioning(prompts) + + def apply_model(self, z: torch.Tensor, t: torch.Tensor, c): + """ + Apply the diffusion model to produce an output. + + Args: + z (torch.Tensor): Noisy latent vectors. + t (torch.Tensor): Timestep tensor. + c: Conditioning tensors. + + Returns: + torch.Tensor: The output of the diffusion model. + """ + return self.model.apply_model(z, t, c) diff --git a/mu_defense/algorithms/adv_unlearn/trainer.py b/mu_defense/algorithms/adv_unlearn/trainer.py new file mode 100644 index 00000000..e69de29b diff --git a/mu_defense/algorithms/adv_unlearn/utils.py b/mu_defense/algorithms/adv_unlearn/utils.py new file mode 100644 index 00000000..2a0fe941 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/utils.py @@ -0,0 +1,595 @@ + +import random +import pandas as pd + +import torch +import torch.nn.functional as F + +from mu.helpers import load_model_from_config, sample_model +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + + +class PromptDataset: + def __init__(self, csv_file): + self.data = pd.read_csv(csv_file) + self.unseen_indices = list(self.data.index) # 保存所有未见过的索引 + + def get_random_prompts(self, num_prompts=1): + # Ensure that the number of prompts requested is not greater than the number of unseen prompts + num_prompts = min(num_prompts, len(self.unseen_indices)) + + # Randomly select num_prompts indices from the list of unseen indices + selected_indices = random.sample(self.unseen_indices, num_prompts) + + # Remove the selected indices from the list of unseen indices + for index in selected_indices: + self.unseen_indices.remove(index) + + # return the prompts corresponding to the selected indices + return self.data.loc[selected_indices, 'prompt'].tolist() + + def has_unseen_prompts(self): + # check if there are any unseen prompts + return len(self.unseen_indices) > 0 + + def reset(self): + self.unseen_indices = list(self.data.index) + + def check_unseen_prompt_count(self): + return len(self.unseen_indices) + +def retain_prompt(dataset_retain): + # Prompt Dataset to be retained + + if dataset_retain == 'imagenet243': + retain_dataset = PromptDataset('data/prompts/train/imagenet243_retain.csv') + elif dataset_retain == 'imagenet243_no_filter': + retain_dataset = PromptDataset('data/prompts/train/imagenet243_no_filter_retain.csv') + elif dataset_retain == 'coco_object': + retain_dataset = PromptDataset('data/prompts/train/coco_object_retain.csv') + elif dataset_retain == 'coco_object_no_filter': + retain_dataset = PromptDataset('data/prompts/train/coco_object_no_filter_retain.csv') + else: + raise ValueError('Invalid dataset for retaining prompts') + + return retain_dataset + +def id2embedding(tokenizer, all_embeddings, input_ids, device): + input_one_hot = F.one_hot(input_ids.view(-1), num_classes = len(tokenizer.get_vocab())).float() + input_one_hot = torch.unsqueeze(input_one_hot,0).to(device) + input_embeds = input_one_hot @ all_embeddings + return input_embeds + +def get_models(config_path, ckpt_path, devices): + model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) + sampler_orig = DDIMSampler(model_orig) + + model = load_model_from_config(config_path, ckpt_path, devices[0]) + sampler = DDIMSampler(model) + + return model_orig, sampler_orig, model, sampler + + +def get_train_loss_retain( retain_batch, retain_train, retain_loss_w, model, model_orig, text_encoder, sampler, emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, start_guidance, negative_guidance, devices, ddim_steps, ddim_eta, image_size, criteria, adv_input_ids, attack_embd_type, adv_embd=None): + """_summary_ + + Args: + model: ESD model + model_orig: frozen DDPM model + sampler: DDIMSampler for DDPM model + + emb_0: unconditional embedding + emb_p: conditional embedding (for ground truth concept) + emb_n: conditional embedding (for modified concept) + + start_guidance: unconditional guidance for ESD model + negative_guidance: negative guidance for ESD model + + devices: list of devices for ESD and DDPM models + ddim_steps: number of steps for DDIMSampler + ddim_eta: eta for DDIMSampler + image_size: image size for DDIMSampler + + criteria: loss function for ESD model + + adv_input_ids: input_ids for adversarial word embedding + adv_emb_n: adversarial conditional embedding + adv_word_emb_n: adversarial word embedding + + Returns: + loss: training loss for ESD model + """ + quick_sample_till_t = lambda x, s, code, batch, t: sample_model(model, sampler, + x, image_size, image_size, ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False) + + + t_enc = torch.randint(ddim_steps, (1,), device=devices[0]) + # time step from 1000 to 0 (0 being good) + og_num = round((int(t_enc)/ddim_steps)*1000) + og_num_lim = round((int(t_enc+1)/ddim_steps)*1000) + + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0]) + + start_code = torch.randn((1, 4, 64, 64)).to(devices[0]) + if retain_train == 'reg': + retain_start_code = torch.randn((retain_batch, 4, 64, 64)).to(devices[0]) + + with torch.no_grad(): + # generate an image with the concept from ESD model + z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, 1, int(t_enc)) # emb_p seems to work better instead of emb_0 + # get conditional and unconditional scores from frozen model at time step t and image z + e_0 = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_0.to(devices[0])) + e_p = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_p.to(devices[0])) + + if retain_train == 'reg': + retain_z = quick_sample_till_t(retain_emb_p.to(devices[0]), start_guidance, retain_start_code, retain_batch, int(t_enc)) # emb_p seems to work better instead of emb_0 + # retain_e_0 = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_0.to(devices[0])) + retain_e_p = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_p.to(devices[0])) + + if adv_embd is None: + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0])) + else: + if attack_embd_type == 'condition_embd': + # Train with adversarial conditional embedding + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_embd.to(devices[0])) + elif attack_embd_type == 'word_embd': + # Train with adversarial word embedding + print('====== Training with adversarial word embedding =====') + adv_emb_n = text_encoder(input_ids = adv_input_ids.to(devices[0]), inputs_embeds=adv_embd.to(devices[0]))[0] + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_emb_n.to(devices[0])) + else: + raise ValueError('attack_embd_type must be either condition_embd or word_embd') + + e_0.requires_grad = False + e_p.requires_grad = False + + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + # loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + # return loss + + if retain_train == 'reg': + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + print('====== Training with retain batch =====') + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + retain_e_n = model.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_n.to(devices[0])) + + # retain_e_0.requires_grad = False + retain_e_p.requires_grad = False + retain_loss = criteria(retain_e_n.to(devices[0]), retain_e_p.to(devices[0])) + + loss = unlearn_loss + retain_loss_w * retain_loss + return loss + + else: + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + return unlearn_loss + + +def param_choices(model, train_method, component='all', final_layer_norm=False): + # choose parameters to train based on train_method + parameters = [] + + # Text Encoder FUll Weight Tuning + if train_method == 'text_encoder_full': + for name, param in model.text_encoder.text_model.named_parameters(): + # Final Layer Norm + if name.startswith('final_layer_norm'): + if component == 'all' or final_layer_norm==True: + print(name) + parameters.append(param) + else: + pass + + # Transformer layers + elif name.startswith('encoder'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + # Embedding layers + else: + pass + + # Text Encoder Layer 0 Tuning + elif train_method == 'text_encoder_layer0': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456789': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678910': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567891011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0_11': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + + elif train_method == 'text_encoder_layer01_1011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012_91011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + # UNet Model Tuning + else: + for name, param in model.model.diffusion_model.named_parameters(): + # train all layers except x-attns and time_embed layers + if train_method == 'noxattn': + if name.startswith('out.') or 'attn2' in name or 'time_embed' in name: + pass + else: + print(name) + parameters.append(param) + + # train only self attention layers + if train_method == 'selfattn': + if 'attn1' in name: + print(name) + parameters.append(param) + + # train only x attention layers + if train_method == 'xattn': + if 'attn2' in name: + print(name) + parameters.append(param) + + # train all layers + if train_method == 'full': + print(name) + parameters.append(param) + + # train all layers except time embed layers + if train_method == 'notime': + if not (name.startswith('out.') or 'time_embed' in name): + print(name) + parameters.append(param) + if train_method == 'xlayer': + if 'attn2' in name: + if 'output_blocks.6.' in name or 'output_blocks.8.' in name: + print(name) + parameters.append(param) + if train_method == 'selflayer': + if 'attn1' in name: + if 'input_blocks.4.' in name or 'input_blocks.7.' in name: + print(name) + parameters.append(param) + + return parameters \ No newline at end of file diff --git a/mu_defense/core/__init__.py b/mu_defense/core/__init__.py new file mode 100644 index 00000000..32b461cc --- /dev/null +++ b/mu_defense/core/__init__.py @@ -0,0 +1,9 @@ +from .base_algorithm import BaseAlgorithm +from .base_model import BaseModel +from .base_trainer import BaseTrainer + +__all__ = [ + "BaseAlgorithm", + "BaseModel", + "BaseTrainer" + ] \ No newline at end of file diff --git a/mu_defense/core/base_algorithm.py b/mu_defense/core/base_algorithm.py new file mode 100644 index 00000000..e3e804c9 --- /dev/null +++ b/mu_defense/core/base_algorithm.py @@ -0,0 +1,45 @@ +# mu_defense/core/base_algorithm.py + +from abc import ABC, abstractmethod +from typing import Dict + + +class BaseAlgorithm(ABC): + """ + Abstract base class for the overall unlearning algorithm, combining the model, trainer, and sampler. + All algorithms must inherit from this class and implement its methods. + """ + + @abstractmethod + def __init__(self, config: Dict): + """ + Initialize the unlearning algorithm. + + Args: + config (Dict): Configuration parameters for the algorithm. + """ + self.config = config + + def _parse_config(self): + """ + Parse the configuration parameters for the algorithm. + """ + # Parse devices + devices = [ + f"cuda:{int(d.strip())}" for d in self.config.get("devices", "0").split(",") + ] + self.config["devices"] = devices + + @abstractmethod + def _setup_components(self): + """ + Set up the components of the unlearning algorithm, including the model, trainer, and sampler. + """ + pass + + @abstractmethod + def run(self): + """ + Run the unlearning algorithm. + """ + pass diff --git a/mu_defense/core/base_config.py b/mu_defense/core/base_config.py new file mode 100644 index 00000000..e69de29b diff --git a/mu_defense/core/base_model.py b/mu_defense/core/base_model.py new file mode 100644 index 00000000..ea397217 --- /dev/null +++ b/mu_defense/core/base_model.py @@ -0,0 +1,17 @@ +# mu_defense/core/base_model.py + +from abc import ABC, abstractmethod +import torch.nn as nn + +class BaseModel(nn.Module, ABC): + """Abstract base class for all unlearning models.""" + + @abstractmethod + def load_model(self, *args, **kwargs): + """Load the model.""" + pass + + @abstractmethod + def save_model(self, *args, **kwargs): + """Save the model.""" + pass diff --git a/mu_defense/core/base_trainer.py b/mu_defense/core/base_trainer.py new file mode 100644 index 00000000..8567a063 --- /dev/null +++ b/mu_defense/core/base_trainer.py @@ -0,0 +1,23 @@ +# mu_defense/core/base_trainer.py + +from abc import ABC +from typing import Any + +class BaseTrainer(ABC): + """Abstract base class for training unlearning models.""" + + def __init__(self, model: Any, config: dict, **kwargs): + self.model = model + self.config = config + + + # @abstractmethod + def setup_optimizer(self, *args, **kwargs): + """Set up the optimizers for training.""" + pass + + # @abstractmethod + def train(self, *args, **kwargs): + """Train the model.""" + pass + From 3c233c9d2cd941276ebde5b528a94930d16b7c08 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 10 Feb 2025 13:16:57 +0545 Subject: [PATCH 16/22] adv unlearn for compvis added --- mu_defense/algorithms/adv_unlearn/__init__.py | 3 +- .../algorithms/adv_unlearn/algorithm.py | 429 +++----------- .../algorithms/adv_unlearn/compvis_trainer.py | 17 +- .../adv_unlearn/configs/__init__.py | 1 + .../adv_unlearn/configs/adv_unlearn_config.py | 84 +++ mu_defense/algorithms/adv_unlearn/model.py | 8 +- mu_defense/algorithms/adv_unlearn/utils.py | 522 +++++++++++++++++- mu_defense/core/__init__.py | 4 +- mu_defense/core/base_config.py | 33 ++ 9 files changed, 747 insertions(+), 354 deletions(-) create mode 100644 mu_defense/algorithms/adv_unlearn/configs/__init__.py create mode 100644 mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py diff --git a/mu_defense/algorithms/adv_unlearn/__init__.py b/mu_defense/algorithms/adv_unlearn/__init__.py index 90f60fdd..eb40cc9a 100644 --- a/mu_defense/algorithms/adv_unlearn/__init__.py +++ b/mu_defense/algorithms/adv_unlearn/__init__.py @@ -1 +1,2 @@ -from .utils import * \ No newline at end of file +from .utils import * +from .compvis_trainer import AdvUnlearnCompvisTrainer \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/algorithm.py b/mu_defense/algorithms/adv_unlearn/algorithm.py index 901be612..90dc293e 100644 --- a/mu_defense/algorithms/adv_unlearn/algorithm.py +++ b/mu_defense/algorithms/adv_unlearn/algorithm.py @@ -1,354 +1,103 @@ +# mu/algorithms/adv_unlearn/algorithm.py - +from mu.core.base_config import BaseConfig import torch -from tqdm import tqdm -import random import wandb +from typing import Dict +import logging +from pathlib import Path -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKL - -from mu_attack.configs.adv_unlearn import AdvUnlearnConfig -from mu.helpers import sample_model -from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -from mu_attack.attackers.soft_prompt import SoftPromptAttack -from mu_attack.helpers.utils import id2embedding, param_choices, get_models, retain_prompt, get_train_loss_retain,save_text_encoder, save_model, save_history - +from mu.core import BaseAlgorithm +from mu_defense.algorithms.adv_unlearn.model import AdvUnlearnModel +from mu_defense.algorithms.adv_unlearn import AdvUnlearnCompvisTrainer +from mu_defense.algorithms.adv_unlearn.configs import AdvUnlearnConfig -class AdvUnlearn: +class AdvUnlearnAlgorithm(BaseAlgorithm): """ - Class for adversarial unlearning training. - - This class wraps the full training pipeline including prompt cleaning, - attack (adversarial prompt generation), and retention-based regularized training. + AdvUnlearnAlgorithm orchestrates the adversarial unlearning training process. + It sets up the model and trainer components and then runs the training loop. """ - def __init__( - self, - config: AdvUnlearnConfig, - **kwargs - ): - self.config = config.__dict__ - for key, value in kwargs.items(): - setattr(config, key, value) - - config.validate_config() - self.config = config - self.prompt = config.prompt - self.dataset_retain = config.dataset_retain - self.retain_batch = config.retain_batch - self.retain_train = config.retain_train - self.retain_step = config.retain_step - self.retain_loss_w = config.retain_loss_w - self.attack_method = config.attack_method - self.train_method = config.train_method - self.norm_layer = config.norm_layer - self.component = config.component - self.model_name_or_path = config.model_name_or_path - self.start_guidance = config.start_guidance - self.negative_guidance = config.negative_guidance - self.iterations = config.iterations - self.save_interval = config.save_interval - self.lr = config.lr - self.config_path = config.config_path - self.ckpt_path = config.ckpt_path - self.diffusers_config_path = config.diffusers_config_path - self.output_dir = config.output_dir - self.devices = config.devices - self.seperator = config.seperator - self.image_size = config.image_size - self.ddim_steps = config.ddim_steps - self.adv_prompt_num = config.adv_prompt_num - self.attack_embd_type = config.attack_embd_type - self.attack_type = config.attack_type - self.attack_init = config.attack_init - self.warmup_iter = config.warmup_iter - self.attack_step = config.attack_step - self.attack_lr = config.attack_lr - self.adv_prompt_update_step = config.adv_prompt_update_step - self.ddim_eta = config.ddim_eta - self.cache_path = config.cache_path + def __init__(self, config: AdvUnlearnConfig, **kwargs): + # Update configuration with additional kwargs. + for key, value in kwargs.items(): + if not hasattr(config, key): + setattr(config, key, value) + continue + config_attr = getattr(config, key) + if isinstance(config_attr, BaseConfig) and isinstance(value, dict): + for sub_key, sub_val in value.items(): + setattr(config_attr, sub_key, sub_val) + elif isinstance(config_attr, dict) and isinstance(value, dict): + config_attr.update(value) + else: + setattr(config, key, value) + self.config = config.to_dict() - # Will be set during training. - self.words = None - self.retain_dataset = None - self.tokenizer = None - self.text_encoder = None - self.custom_text_encoder = None - self.all_embeddings = None - self.vae = None - self.model_orig = None - self.sampler_orig = None + # Validate and update config. + config.validate_config() + self.config = config.to_dict() + self.config_path = self.config.get("config_path") self.model = None - self.sampler = None - self.parameters = None - self.opt = None - self.criteria = torch.nn.MSELoss() - - # For adversarial prompt update - self.adv_word_embd = None - self.adv_condition_embd = None - self.adv_input_ids = None - - def setup(self): - """Stage 0 & 1: Prompt cleaning and training setup.""" - # --- Prompt cleaning --- - word_print = self.prompt.replace(' ', '') - # Special cases for certain prompts - if self.prompt == 'allartist': - self.prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng" - if self.prompt == 'i2p': - self.prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood" - if self.prompt == "artifact": - self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " - "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " - "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") - - if self.seperator is not None: - self.words = [w.strip() for w in self.prompt.split(self.seperator)] - else: - self.words = [self.prompt] - print(f'The Concept Prompt to be unlearned: {self.words}') - - # Create a retaining dataset (assumed to be a prompt dataset) - self.retain_dataset = retain_prompt(self.dataset_retain) - - # --- Training Setup --- - ddim_eta = self.ddim_eta # constant value for training - - - - # Load the VAE - self.vae = AutoencoderKL.from_pretrained(self.model_name_or_path, subfolder="vae", cache_dir=self.cache_path).to(self.devices[0]) - # Load tokenizer and text encoder - self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path) - self.text_encoder = CLIPTextModel.from_pretrained(self.model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path).to(self.devices[0]) - self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) - self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) - - # Load models using your helper function (assumed to be defined in utils) - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models(self.config_path, self.ckpt_path, self.devices) - self.model_orig.eval() - - # Setup trainable parameters based on train_method - if 'text_encoder' in self.train_method: - self.parameters = param_choices(model=self.custom_text_encoder, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) - else: - self.parameters = param_choices(model=self.model, train_method=self.train_method, component=self.component, final_layer_norm=self.norm_layer) - - self.opt = torch.optim.Adam(self.parameters, lr=self.lr) - - return word_print # For later use in saving history - - def train(self): - """Stage 2: Training loop.""" - word_print = self.setup() - ddim_eta = self.ddim_eta # As used in training - - # A lambda function to sample until a given time step. - quick_sample_till_t = lambda x, s, code, batch, t: sample_model( - self.model, self.sampler, - x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, - start_code=code, n_samples=batch, till_T=t, verbose=False + self.trainer = None + self.device = self.config.get("devices")[0] + self.logger = logging.getLogger(__name__) + self._setup_components() + + def _setup_components(self): + """ + Setup model and trainer components. + """ + self.logger.info("Setting up components for adversarial unlearning training...") + + # Initialize Model + self.model = AdvUnlearnModel( + model_name_or_path=self.config.get("model_name_or_path"), + model_config_path=self.config.get("config_path"), + compvis_ckpt_path=self.config.get("compvis_ckpt_path"), + cache_path=self.config.get("cache_path"), + devices=self.config.get("devices"), ) - - losses = [] - history = [] - global_step = 0 - attack_round = 0 - - # Create a tqdm progress bar - pbar = tqdm(range(self.iterations)) - for i in pbar: - # --- Update adversarial prompt every adv_prompt_update_step iterations --- - if i % self.adv_prompt_update_step == 0: - # Reset the retaining dataset if needed - if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: - self.retain_dataset.reset() - - # Randomly choose one prompt from the list - word = random.sample(self.words, 1)[0] - text_input = self.tokenizer( - word, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, text_input.input_ids.to(self.devices[0]), self.devices[0]) - - # Get conditional embeddings from the frozen model - emb_0 = self.model_orig.get_learned_conditioning(['']) - emb_p = self.model_orig.get_learned_conditioning([word]) - - # --- Attack Step: Get adversarial prompt --- - if i >= self.warmup_iter: - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - self.model.eval() - - if attack_round == 0: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - None, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - None, self.attack_method - ) - else: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - self.adv_word_embd, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, self.start_guidance, - self.devices, self.ddim_steps, ddim_eta, self.image_size, self.criteria, - self.adv_prompt_num, self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, self.attack_init, - self.adv_condition_embd, self.attack_method - ) - global_step += self.attack_step - attack_round += 1 - # --- Set models to training/eval modes based on training method --- - if 'text_encoder' in self.train_method: - self.custom_text_encoder.text_encoder.train() - self.custom_text_encoder.text_encoder.requires_grad_(True) - self.model.eval() - else: - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - self.model.train() - self.opt.zero_grad() - - # --- Retaining prompts for retention regularization --- - if self.retain_train == 'reg': - retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) - retain_text_input = self.tokenizer( - retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) - - retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) - retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) - # Reshape to [batch, 77, embedding_dim] - retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) - retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] - else: - retain_text_input = None - retain_text_embeddings = None - retain_emb_p = None - retain_emb_n = None + # Initialize Trainer + self.trainer = AdvUnlearnCompvisTrainer( + model=self.model, + config=self.config, + devices=self.config.get("devices"), + ) - # --- Compute training loss --- - if i < self.warmup_iter: - # Warmup training uses the original prompt embeddings. - input_ids = text_input.input_ids.to(self.devices[0]) - emb_n = self.custom_text_encoder(input_ids=input_ids, inputs_embeds=text_embeddings)[0] - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, input_ids, self.attack_embd_type - ) - else: - if self.attack_embd_type == 'word_embd': - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_word_embd - ) - elif self.attack_embd_type == 'condition_embd': - loss = get_train_loss_retain( - self.retain_batch, self.retain_train, self.retain_loss_w, - self.model, self.model_orig, self.custom_text_encoder, self.sampler, - emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, - self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, self.adv_condition_embd - ) - - # Backpropagate loss and update weights. - loss.backward() - losses.append(loss.item()) - pbar.set_postfix({"loss": loss.item()}) - history.append(loss.item()) - wandb.log({'Train_Loss': loss.item()}, step=global_step) - wandb.log({'Attack_Loss': 0.0}, step=global_step) - global_step += 1 - self.opt.step() - - # --- Additional Retention Training (if using iterative retention) --- - if self.retain_train == 'iter': - for r in range(self.retain_step): - print(f'==== Retain Training at step {r} ====') - self.opt.zero_grad() - if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: - self.retain_dataset.reset() - retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) - - t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) - og_num = round((int(t_enc) / self.ddim_steps) * 1000) - og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) - t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) - retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) - - retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) - retain_z = quick_sample_till_t(retain_emb_p.to(self.devices[0]), self.start_guidance, retain_start_code, self.retain_batch, int(t_enc)) - retain_e_p = self.model_orig.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_p.to(self.devices[0])) - - retain_text_input = self.tokenizer( - retain_words, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", truncation=True - ) - retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) - retain_text_embeddings = id2embedding(self.tokenizer, self.all_embeddings, retain_text_input.input_ids.to(self.devices[0]), self.devices[0]) - retain_text_embeddings = retain_text_embeddings.reshape(self.retain_batch, -1, retain_text_embeddings.shape[-1]) - retain_emb_n = self.custom_text_encoder(input_ids=retain_input_ids, inputs_embeds=retain_text_embeddings)[0] - retain_e_n = self.model.apply_model(retain_z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), retain_emb_n.to(self.devices[0])) - - retain_loss = self.criteria(retain_e_n.to(self.devices[0]), retain_e_p.to(self.devices[0])) - retain_loss.backward() - self.opt.step() - - # --- Checkpointing and saving history --- - if (i + 1) % self.save_interval == 0 and (i + 1) != self.iterations and (i + 1) >= self.save_interval: - if 'text_encoder' in self.train_method: - save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) - else: - save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, - save_diffusers=True, compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path) - if i % 1 == 0: - save_history(self.output_dir, losses, word_print) - - # --- Stage 3: Save final model and loss curve --- - self.model.eval() - self.custom_text_encoder.text_encoder.eval() - self.custom_text_encoder.text_encoder.requires_grad_(False) - if 'text_encoder' in self.train_method: - save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) - else: - save_model(self.output_dir, self.model, self.train_method, i, save_compvis=True, - save_diffusers=True, compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path) - save_history(self.output_dir, losses, word_print) + def run(self): + """ + Execute the training process. + """ + try: + # Initialize WandB with configurable project/run names. + wandb_config = { + "project": self.config.get("wandb_project", "adv-unlearn-project"), + "name": self.config.get("wandb_run", "Adv Unlearn Training"), + "config": self.config, + } + wandb.init(**wandb_config) + self.logger.info("Initialized WandB for logging.") + + # Create output directory if it doesn't exist. + output_dir = Path(self.config.get("output_dir", "./outputs")) + output_dir.mkdir(parents=True, exist_ok=True) + + try: + # Start training. + self.trainer.train() + except Exception as e: + self.logger.error(f"Error during training: {str(e)}") + raise + + except Exception as e: + self.logger.error(f"Failed to initialize training: {str(e)}") + raise + + finally: + # Ensure WandB always finishes. + if wandb.run is not None: + wandb.finish() + self.logger.info("Training complete. WandB logging finished.") diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py index 2b810b96..1ad290db 100644 --- a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -2,6 +2,7 @@ from tqdm import tqdm import random import wandb +import logging from torch.nn import MSELoss from mu.helpers import sample_model @@ -16,14 +17,14 @@ save_history ) from mu_attack.attackers.soft_prompt import SoftPromptAttack -from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -class AdvUnlearnTrainer(BaseTrainer): + +class AdvUnlearnCompvisTrainer(BaseTrainer): """ Trainer for adversarial unlearning. - + This trainer performs the adversarial prompt update and retention-based - regularized training loop. + regularized training loop for CompVis/Diffusers models. """ def __init__(self, model, config: dict, devices: list, **kwargs): """ @@ -94,6 +95,8 @@ def __init__(self, model, config: dict, devices: list, **kwargs): self.adv_prompt_update_step = self.config['adv_prompt_update_step'] self.ddim_eta = self.config['ddim_eta'] + self.logger = logging.getLogger(__name__) + # Setup prompt cleaning and retaining dataset. self._setup_prompt_and_dataset() @@ -125,7 +128,7 @@ def _setup_prompt_and_dataset(self): else: self.words = [self.prompt] self.word_print = self.prompt.replace(" ", "") - print(f"The Concept Prompt to be unlearned: {self.words}") + self.logger.info(f"The Concept Prompt to be unlearned: {self.words}") # Create a retaining dataset using your helper function. self.retain_dataset = retain_prompt(self.dataset_retain) @@ -396,8 +399,8 @@ def train(self): compvis_config_file=self.config_path, diffusers_config_file=self.diffusers_config_path ) - if i % 1 == 0: - save_history(self.output_dir, losses, self.word_print) + if i % 1 == 0: + save_history(self.output_dir, losses, self.word_print) # --- Final checkpointing --- self.model.eval() diff --git a/mu_defense/algorithms/adv_unlearn/configs/__init__.py b/mu_defense/algorithms/adv_unlearn/configs/__init__.py new file mode 100644 index 00000000..cc01a1ec --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/configs/__init__.py @@ -0,0 +1 @@ +from .adv_unlearn_config import AdvUnlearnConfig, adv_unlearn_config \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py new file mode 100644 index 00000000..914810b5 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py @@ -0,0 +1,84 @@ +#mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py + +import os +from pathlib import Path +from mu_defense.core.base_config import BaseConfig + +current_dir = Path(__file__).parent + +class AdvUnlearnConfig(BaseConfig): + def __init__(self, **kwargs): + # Inference & Model Paths + self.config_path = current_dir / "configs/stable-diffusion/v1-inference.yaml" #for compvis + self.compvis_ckpt_path = "models/sd-v1-4-full-ema.ckpt" + self.diffusers_config_path = current_dir / "diffusers_unet_config.json" + self.encoder_model_name_or_path = "CompVis/stable-diffusion-v1-4" + self.cache_path = ".cache" + + # Devices & IO + self.devices = "0,0" # You can later parse this string into a list if needed. + self.seperator = None + self.output_dir = "outputs/adv_unlearn" + + # Image & Diffusion Sampling + self.image_size = 512 + self.ddim_steps = 50 + self.start_guidance = 3.0 + self.negative_guidance = 1.0 + + # Training Setup + self.prompt = "nudity" + self.dataset_retain = "coco" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' + self.retain_batch = 5 + self.retain_train = "iter" # Options: 'iter' or 'reg' + self.retain_step = 1 + self.retain_loss_w = 1.0 + self.ddim_eta = 0 + + self.train_method = "text_encoder_full" #choices: text_encoder_full', 'text_encoder_layer0', 'text_encoder_layer01', 'text_encoder_layer012', 'text_encoder_layer0123', 'text_encoder_layer01234', 'text_encoder_layer012345', 'text_encoder_layer0123456', 'text_encoder_layer01234567', 'text_encoder_layer012345678', 'text_encoder_layer0123456789', 'text_encoder_layer012345678910', 'text_encoder_layer01234567891011', 'text_encoder_layer0_11','text_encoder_layer01_1011', 'text_encoder_layer012_91011', 'noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer + self.norm_layer = False # This is a flag; use True if you wish to update the norm layer. + self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' + self.component = "all" # Choices: 'all', 'ffn', 'attn' + self.iterations = 1000 + self.save_interval = 200 + self.lr = 1e-5 + + # Adversarial Attack Hyperparameters + self.adv_prompt_num = 1 + self.attack_embd_type = "word_embd" # Choices: 'word_embd', 'condition_embd' + self.attack_type = "prefix_k" # Choices: 'replace_k', 'add', 'prefix_k', 'suffix_k', 'mid_k', 'insert_k', 'per_k_words' + self.attack_init = "latest" # Choices: 'random', 'latest' + self.attack_step = 30 + self.adv_prompt_update_step = 1 + self.attack_lr = 1e-3 + self.warmup_iter = 200 + + #backend + self.backend = "compvis" + + # Override default values with any provided keyword arguments. + for key, value in kwargs.items(): + setattr(self, key, value) + + def validate_config(self): + """ + Perform basic validation on the config parameters. + """ + if self.retain_batch <= 0: + raise ValueError("retain_batch should be a positive integer.") + if self.lr <= 0: + raise ValueError("Learning rate (lr) should be positive.") + if self.image_size <= 0: + raise ValueError("Image size should be a positive integer.") + if self.iterations <= 0: + raise ValueError("Iterations must be a positive integer.") + if not os.path.exists(self.config_path): + raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") + if not os.path.exists(self.ckpt_path): + raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.") + if not os.path.exists(self.diffusers_config_path): + raise FileNotFoundError(f"Diffusers config file {self.diffusers_config_path} does not exist.") + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + +adv_unlearn_config = AdvUnlearnConfig() \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/model.py b/mu_defense/algorithms/adv_unlearn/model.py index 04db4cb7..f6bae9e1 100644 --- a/mu_defense/algorithms/adv_unlearn/model.py +++ b/mu_defense/algorithms/adv_unlearn/model.py @@ -20,7 +20,7 @@ def __init__( self, model_name_or_path: str, config_path: str, - ckpt_path: str, + compvis_ckpt_path: str, cache_path: str, devices: list ): @@ -30,14 +30,14 @@ def __init__( Args: model_name_or_path (str): Path or identifier of the pretrained model. config_path (str): Path to the model configuration file. - ckpt_path (str): Path to the model checkpoint. + compvis_ckpt_path (str): Path to the model checkpoint. cache_path (str): Directory for caching downloaded models. devices (list): List of device strings (e.g., ['cuda:0', 'cuda:1']) for model placement. """ super().__init__() self.model_name_or_path = model_name_or_path self.config_path = config_path - self.ckpt_path = ckpt_path + self.ckpt_path = compvis_ckpt_path self.cache_path = cache_path self.devices = devices @@ -67,7 +67,7 @@ def __init__( # Load diffusion models using your helper function. self.model_orig, self.sampler_orig, self.model, self.sampler = get_models( self.config_path, - self.ckpt_path, + self.compvis_ckpt_path, self.devices ) self.model_orig.eval() # Set the frozen model to evaluation mode. diff --git a/mu_defense/algorithms/adv_unlearn/utils.py b/mu_defense/algorithms/adv_unlearn/utils.py index 2a0fe941..befc5820 100644 --- a/mu_defense/algorithms/adv_unlearn/utils.py +++ b/mu_defense/algorithms/adv_unlearn/utils.py @@ -1,10 +1,26 @@ +import os import random import pandas as pd +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn.functional as F +import OmegaConf +from diffusers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) + + from mu.helpers import load_model_from_config, sample_model from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler @@ -592,4 +608,508 @@ def param_choices(model, train_method, component='all', final_layer_norm=False): print(name) parameters.append(param) - return parameters \ No newline at end of file + return parameters + +def save_text_encoder(folder_path, model, name, num): + # SAVE MODEL + + # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' + folder_path = f'{folder_path}/models' + os.makedirs(folder_path, exist_ok=True) + if num is not None: + path = f'{folder_path}/TextEncoder-{name}-epoch_{num}.pt' + else: + path = f'{folder_path}/TextEncoder-{name}.pt' + + torch.save(model.state_dict(), path) + + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + unet_params = original_config.model.params.unet_config.params + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + config = dict( + sample_size=image_size // vae_scale_factor, + in_channels=unet_params.in_channels, + out_channels=unet_params.out_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_res_blocks, + cross_attention_dim=unet_params.context_dim, + attention_head_dim=head_dim, + use_linear_projection=use_linear_projection, + ) + + return config + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + +def savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device='cpu'): + checkpoint_path = path + + original_config_file = compvis_config_file + config_file = diffusers_config_file + num_in_channels = 4 + scheduler_type = 'ddim' + pipeline_type = None + image_size = 512 + prediction_type = 'epsilon' + extract_ema = False + dump_path = path.replace('Compvis','Diffusers') + upcast_attention = False + + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + print("global_step key not found in model") + global_step = None + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + upcast_attention = upcast_attention + if original_config_file is None: + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + + if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: + if not os.path.isfile("v2-inference-v.yaml"): + # model_type = "v2" + os.system( + "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + " -O v2-inference-v.yaml" + ) + original_config_file = "./v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + else: + if not os.path.isfile("v1-inference.yaml"): + # model_type = "v1" + os.system( + "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + " -O v1-inference.yaml" + ) + original_config_file = "./v1-inference.yaml" + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + num_train_timesteps = original_config.model.params.timesteps + beta_start = original_config.model.params.linear_start + beta_end = original_config.model.params.linear_end + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = False + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + torch.save(converted_unet_checkpoint, dump_path) + + + +def save_model(folder_path, model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True): + # SAVE MODEL + + # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' + folder_path = f'{folder_path}/models' + os.makedirs(folder_path, exist_ok=True) + if num is not None: + path = f'{folder_path}/Compvis-UNet-{name}-epoch_{num}.pt' + else: + path = f'{folder_path}/Compvis-UNet-{name}.pt' + if save_compvis: + torch.save(model.state_dict(), path) + + if save_diffusers: + print('Saving Model in Diffusers Format') + savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device=device ) + + + +def moving_average(a, n=3) : + ret = np.cumsum(a, dtype=float) + ret[n:] = ret[n:] - ret[:-n] + return ret[n - 1:] / n + +def plot_loss(losses, path,word, n=100): + v = moving_average(losses, n) + plt.plot(v, label=f'{word}_loss') + plt.legend(loc="upper left") + plt.title('Average loss in trainings', fontsize=20) + plt.xlabel('Data point', fontsize=16) + plt.ylabel('Loss value', fontsize=16) + plt.savefig(path) + +def save_history(folder_path, losses, word_print): + folder_path = f'{folder_path}/logs' + os.makedirs(folder_path, exist_ok=True) + with open(f'{folder_path}/loss.txt', 'w') as f: + f.writelines([str(i) for i in losses]) + plot_loss(losses,f'{folder_path}/loss.png' , word_print, n=3) \ No newline at end of file diff --git a/mu_defense/core/__init__.py b/mu_defense/core/__init__.py index 32b461cc..dee0b006 100644 --- a/mu_defense/core/__init__.py +++ b/mu_defense/core/__init__.py @@ -1,9 +1,11 @@ from .base_algorithm import BaseAlgorithm from .base_model import BaseModel from .base_trainer import BaseTrainer +from .base_config import BaseConfig __all__ = [ "BaseAlgorithm", "BaseModel", - "BaseTrainer" + "BaseTrainer", + "BaseConfig" ] \ No newline at end of file diff --git a/mu_defense/core/base_config.py b/mu_defense/core/base_config.py index e69de29b..660e9066 100644 --- a/mu_defense/core/base_config.py +++ b/mu_defense/core/base_config.py @@ -0,0 +1,33 @@ + +# mu_defense/core/base_config.py + +from abc import ABC, abstractmethod + + +class BaseConfig(ABC): + + @abstractmethod + def __init__(self): + pass + + def validate_config(self): + pass + + def to_dict(self): + result = {} + for attr_name, attr_value in self.__dict__.items(): + if hasattr(attr_value, "to_dict") and callable(attr_value.to_dict): + result[attr_name] = attr_value.to_dict() + elif isinstance(attr_value, list): + result[attr_name] = [ + item.to_dict() if hasattr(item, "to_dict") else item + for item in attr_value + ] + elif isinstance(attr_value, dict): + dict_val = {} + for k, v in attr_value.items(): + dict_val[k] = v.to_dict() if hasattr(v, "to_dict") else v + result[attr_name] = dict_val + else: + result[attr_name] = attr_value + return result From 5609274ec53895249054a6d89799c8e83aa36564 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 10 Feb 2025 12:20:21 +0000 Subject: [PATCH 17/22] refactored for mu_defense compvis --- mu_defense/.gitignore | 1 + mu_defense/algorithms/adv_unlearn/README.md | 25 ++ mu_defense/algorithms/adv_unlearn/__init__.py | 13 +- .../algorithms/adv_unlearn/algorithm.py | 22 +- .../algorithms/adv_unlearn/compvis_trainer.py | 102 ++------ .../adv_unlearn/configs/adv_unlearn_config.py | 17 +- .../algorithms/adv_unlearn/dataset_handler.py | 50 ++++ mu_defense/algorithms/adv_unlearn/model.py | 112 ++++---- mu_defense/algorithms/adv_unlearn/trainer.py | 37 +++ mu_defense/algorithms/adv_unlearn/utils.py | 242 +++++++----------- mu_defense/core/__init__.py | 10 +- mu_defense/core/base_compvis_trainer.py | 23 ++ mu_defense/core/base_config.py | 1 + mu_defense/core/base_data_handler.py | 29 +++ mu_defense/core/base_trainer.py | 34 ++- mu_defense/environment.yaml | 36 +++ 16 files changed, 406 insertions(+), 348 deletions(-) create mode 100644 mu_defense/.gitignore create mode 100644 mu_defense/algorithms/adv_unlearn/dataset_handler.py create mode 100644 mu_defense/core/base_compvis_trainer.py create mode 100644 mu_defense/core/base_data_handler.py create mode 100644 mu_defense/environment.yaml diff --git a/mu_defense/.gitignore b/mu_defense/.gitignore new file mode 100644 index 00000000..aa850f42 --- /dev/null +++ b/mu_defense/.gitignore @@ -0,0 +1 @@ +src/* \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/README.md b/mu_defense/algorithms/adv_unlearn/README.md index e69de29b..873e2458 100644 --- a/mu_defense/algorithms/adv_unlearn/README.md +++ b/mu_defense/algorithms/adv_unlearn/README.md @@ -0,0 +1,25 @@ +```python +from mu_defense.algorithms.adv_unlearn.algorithm import AdvUnlearnAlgorithm +from mu_defense.algorithms.adv_unlearn.configs import adv_unlearn_config +from mu.algorithms.erase_diff.configs import erase_diff_train_mu + + +def mu_defense(): + + mu_defense = AdvUnlearnAlgorithm( + config=adv_unlearn_config, + compvis_ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/outputs/erase_diff/erase_diff_Abstractionism_model.pth", + # diffusers_model_name_or_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/outputs/forget_me_not/finetuned_models/Abstractionism", + attack_step = 2, + backend = "compvis", + attack_method = "fast_at", + model_config_path = erase_diff_train_mu.model_config_path + + + ) + mu_defense.run() + +if __name__ == "__main__": + mu_defense() + +``` \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/__init__.py b/mu_defense/algorithms/adv_unlearn/__init__.py index eb40cc9a..0d25e5d7 100644 --- a/mu_defense/algorithms/adv_unlearn/__init__.py +++ b/mu_defense/algorithms/adv_unlearn/__init__.py @@ -1,2 +1,13 @@ from .utils import * -from .compvis_trainer import AdvUnlearnCompvisTrainer \ No newline at end of file +from .model import AdvUnlearnModel +from .dataset_handler import AdvUnlearnDatasetHandler +from .compvis_trainer import AdvUnlearnCompvisTrainer +# from .algorithm import AdvUnlearnAlgorithm +# from .trainer import AdvUnlearnTrainer + +__all__ = ["AdvUnlearnModel", + "AdvUnlearnDatasetHandler", + "AdvUnlearnCompvisTrainer", + # "AdvUnlearnAlgorithm", + # "AdvUnlearnTrainer" +] \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/algorithm.py b/mu_defense/algorithms/adv_unlearn/algorithm.py index 90dc293e..3798fc5b 100644 --- a/mu_defense/algorithms/adv_unlearn/algorithm.py +++ b/mu_defense/algorithms/adv_unlearn/algorithm.py @@ -1,15 +1,13 @@ # mu/algorithms/adv_unlearn/algorithm.py from mu.core.base_config import BaseConfig -import torch import wandb -from typing import Dict import logging from pathlib import Path from mu.core import BaseAlgorithm -from mu_defense.algorithms.adv_unlearn.model import AdvUnlearnModel -from mu_defense.algorithms.adv_unlearn import AdvUnlearnCompvisTrainer +from mu_defense.algorithms.adv_unlearn import AdvUnlearnModel +from mu_defense.algorithms.adv_unlearn.trainer import AdvUnlearnTrainer from mu_defense.algorithms.adv_unlearn.configs import AdvUnlearnConfig @@ -38,10 +36,10 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): # Validate and update config. config.validate_config() self.config = config.to_dict() - self.config_path = self.config.get("config_path") self.model = None self.trainer = None - self.device = self.config.get("devices")[0] + self.devices = self.config.get("devices") + self.devices = [f'cuda:{int(d.strip())}' for d in self.devices.split(',')] self.logger = logging.getLogger(__name__) self._setup_components() @@ -53,18 +51,14 @@ def _setup_components(self): # Initialize Model self.model = AdvUnlearnModel( - model_name_or_path=self.config.get("model_name_or_path"), - model_config_path=self.config.get("config_path"), - compvis_ckpt_path=self.config.get("compvis_ckpt_path"), - cache_path=self.config.get("cache_path"), - devices=self.config.get("devices"), + config=self.config ) # Initialize Trainer - self.trainer = AdvUnlearnCompvisTrainer( + self.trainer = AdvUnlearnTrainer( model=self.model, config=self.config, - devices=self.config.get("devices"), + devices=self.devices, ) def run(self): @@ -87,7 +81,7 @@ def run(self): try: # Start training. - self.trainer.train() + self.trainer.run() except Exception as e: self.logger.error(f"Error during training: {str(e)}") raise diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py index 1ad290db..ea232c02 100644 --- a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -1,3 +1,5 @@ +# mu_defense/algorithms/adv_unlearn/compvis_trainer.py + import torch from tqdm import tqdm import random @@ -5,18 +7,17 @@ import logging from torch.nn import MSELoss -from mu.helpers import sample_model from mu.core import BaseTrainer from mu_defense.algorithms.adv_unlearn import ( id2embedding, param_choices, - retain_prompt, get_train_loss_retain, save_text_encoder, - save_model, - save_history + save_history, + sample_model ) from mu_attack.attackers.soft_prompt import SoftPromptAttack +from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler class AdvUnlearnCompvisTrainer(BaseTrainer): @@ -28,20 +29,7 @@ class AdvUnlearnCompvisTrainer(BaseTrainer): """ def __init__(self, model, config: dict, devices: list, **kwargs): """ - Initialize the AdvUnlearnTrainer. - - Args: - model: A model loader instance that contains the following attributes: - - model_orig: the frozen diffusion model, - - sampler_orig: sampler for the frozen model, - - model: the trainable diffusion model, - - sampler: sampler for the trainable model, - - tokenizer: the tokenizer, - - custom_text_encoder: the custom text encoder wrapping the CLIP text encoder, - - all_embeddings: the complete text embedding matrix, - - vae: the VAE. - config (dict): Configuration dictionary with all training hyperparameters. - devices (list): List of device strings (e.g., ['cuda:0']). + Initialize the AdvUnlearnCompvisTrainer. """ super().__init__(model, config, **kwargs) self.devices = devices @@ -51,12 +39,12 @@ def __init__(self, model, config: dict, devices: list, **kwargs): self.model_orig = model.model_orig # frozen diffusion model (set to eval) self.sampler = model.sampler self.sampler_orig = model.sampler_orig + self.model_loader = model # Other loaded components. self.tokenizer = model.tokenizer self.custom_text_encoder = model.custom_text_encoder self.all_embeddings = model.all_embeddings - self.vae = model.vae # Loss criterion. self.criteria = MSELoss() @@ -65,14 +53,13 @@ def __init__(self, model, config: dict, devices: list, **kwargs): self.config = config self.prompt = self.config['prompt'] self.seperator = self.config.get('seperator') - self.iterations = self.config['iterations'] + self.iterations = self.config.get('iterations') self.ddim_steps = self.config['ddim_steps'] self.start_guidance = self.config['start_guidance'] self.negative_guidance = self.config['negative_guidance'] self.image_size = self.config['image_size'] self.lr = self.config['lr'] - self.config_path = self.config['config_path'] - self.diffusers_config_path = self.config['diffusers_config_path'] + self.model_config_path = self.config['model_config_path'] self.output_dir = self.config['output_dir'] # Retention and attack parameters. @@ -97,8 +84,14 @@ def __init__(self, model, config: dict, devices: list, **kwargs): self.logger = logging.getLogger(__name__) - # Setup prompt cleaning and retaining dataset. - self._setup_prompt_and_dataset() + # Setup the dataset handler and prompt cleaning. + self.dataset_handler = AdvUnlearnDatasetHandler( + prompt=self.prompt, + seperator=self.seperator, + dataset_retain=self.dataset_retain + ) + self.words, self.word_print = self.dataset_handler.setup_prompt() + self.retain_dataset = self.dataset_handler.setup_dataset() # Initialize adversarial prompt variables. self.adv_word_embd = None @@ -108,31 +101,6 @@ def __init__(self, model, config: dict, devices: list, **kwargs): # Setup trainable parameters and optimizer. self._setup_optimizer() - def _setup_prompt_and_dataset(self): - """ - Perform prompt cleaning and create the retaining dataset. - """ - # Special-case prompt cleaning. - if self.prompt == 'allartist': - self.prompt = ("Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, " - "Alena Aenami, Tyler Edlin, Kilian Eng") - elif self.prompt == 'i2p': - self.prompt = ("hate, harassment, violence, suffering, humiliation, harm, " - "suicide, sexual, nudity, bodily fluids, blood") - elif self.prompt == "artifact": - self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " - "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " - "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") - if self.seperator: - self.words = [w.strip() for w in self.prompt.split(self.seperator)] - else: - self.words = [self.prompt] - self.word_print = self.prompt.replace(" ", "") - self.logger.info(f"The Concept Prompt to be unlearned: {self.words}") - - # Create a retaining dataset using your helper function. - self.retain_dataset = retain_prompt(self.dataset_retain) - def _setup_optimizer(self): """ Set up the optimizer based on the training method. @@ -158,7 +126,6 @@ def train(self): Execute the adversarial unlearning training loop. """ ddim_eta = self.ddim_eta - # Lambda to sample until a given time step. quick_sample_till_t = lambda x, s, code, batch, t: sample_model( self.model, self.sampler, x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, @@ -171,7 +138,6 @@ def train(self): pbar = tqdm(range(self.iterations)) for i in pbar: - # --- Update adversarial prompt every adv_prompt_update_step iterations --- if i % self.adv_prompt_update_step == 0: if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: self.retain_dataset.reset() @@ -189,12 +155,10 @@ def train(self): text_input.input_ids.to(self.devices[0]), self.devices[0] ) - # Get conditioning from the frozen model. emb_0 = self.model_orig.get_learned_conditioning(['']) emb_p = self.model_orig.get_learned_conditioning([word]) if i >= self.warmup_iter: - # Update adversarial prompt using SoftPromptAttack. self.custom_text_encoder.text_encoder.eval() self.custom_text_encoder.text_encoder.requires_grad_(False) self.model.eval() @@ -243,7 +207,6 @@ def train(self): global_step += self.attack_step attack_round += 1 - # --- Set models to training/eval modes based on train_method --- if 'text_encoder' in self.train_method: self.custom_text_encoder.text_encoder.train() self.custom_text_encoder.text_encoder.requires_grad_(True) @@ -255,7 +218,6 @@ def train(self): self.optimizer.zero_grad() - # --- Retaining prompts for retention regularization (if configured) --- if self.retain_train == 'reg': retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) retain_text_input = self.tokenizer( @@ -284,7 +246,6 @@ def train(self): retain_emb_p = None retain_emb_n = None - # --- Compute training loss --- if i < self.warmup_iter: input_ids = text_input.input_ids.to(self.devices[0]) emb_n = self.custom_text_encoder( @@ -326,7 +287,6 @@ def train(self): global_step += 1 self.optimizer.step() - # --- Additional Retention Training (for iterative retention) --- if self.retain_train == 'iter': for r in range(self.retain_step): self.optimizer.zero_grad() @@ -384,40 +344,22 @@ def train(self): retain_loss.backward() self.optimizer.step() - # --- Checkpointing and saving history --- if (i + 1) % self.config['save_interval'] == 0 and (i + 1) != self.iterations and (i + 1) >= self.config['save_interval']: if 'text_encoder' in self.train_method: save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) else: - save_model( - self.output_dir, - self.model, - self.train_method, - i, - save_compvis=True, - save_diffusers=True, - compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path - ) + output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" + self.model_loader.save_model(self.model, output_path) if i % 1 == 0: save_history(self.output_dir, losses, self.word_print) - # --- Final checkpointing --- self.model.eval() self.custom_text_encoder.text_encoder.eval() self.custom_text_encoder.text_encoder.requires_grad_(False) if 'text_encoder' in self.train_method: save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) - else: - save_model( - self.output_dir, - self.model, - self.train_method, - i, - save_compvis=True, - save_diffusers=True, - compvis_config_file=self.config_path, - diffusers_config_file=self.diffusers_config_path - ) + else: + output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" + self.model_loader.save_model(self.model, output_path) save_history(self.output_dir, losses, self.word_print) return self.model diff --git a/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py index 914810b5..daa3211d 100644 --- a/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py +++ b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py @@ -4,16 +4,17 @@ from pathlib import Path from mu_defense.core.base_config import BaseConfig -current_dir = Path(__file__).parent class AdvUnlearnConfig(BaseConfig): def __init__(self, **kwargs): # Inference & Model Paths - self.config_path = current_dir / "configs/stable-diffusion/v1-inference.yaml" #for compvis + self.model_config_path = "configs/stable-diffusion/v1-inference.yaml" #for compvis self.compvis_ckpt_path = "models/sd-v1-4-full-ema.ckpt" - self.diffusers_config_path = current_dir / "diffusers_unet_config.json" self.encoder_model_name_or_path = "CompVis/stable-diffusion-v1-4" self.cache_path = ".cache" + + self.diffusers_model_name_or_path = "" + self.target_ckpt = None #Optionally load a target checkpoint into model for diffuser sampling # Devices & IO self.devices = "0,0" # You can later parse this string into a list if needed. @@ -28,7 +29,7 @@ def __init__(self, **kwargs): # Training Setup self.prompt = "nudity" - self.dataset_retain = "coco" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' + self.dataset_retain = "coco_object" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' self.retain_batch = 5 self.retain_train = "iter" # Options: 'iter' or 'reg' self.retain_step = 1 @@ -39,7 +40,7 @@ def __init__(self, **kwargs): self.norm_layer = False # This is a flag; use True if you wish to update the norm layer. self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' self.component = "all" # Choices: 'all', 'ffn', 'attn' - self.iterations = 1000 + self.iterations = 10 self.save_interval = 200 self.lr = 1e-5 @@ -72,12 +73,6 @@ def validate_config(self): raise ValueError("Image size should be a positive integer.") if self.iterations <= 0: raise ValueError("Iterations must be a positive integer.") - if not os.path.exists(self.config_path): - raise FileNotFoundError(f"Model config file {self.config_path} does not exist.") - if not os.path.exists(self.ckpt_path): - raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.") - if not os.path.exists(self.diffusers_config_path): - raise FileNotFoundError(f"Diffusers config file {self.diffusers_config_path} does not exist.") if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) diff --git a/mu_defense/algorithms/adv_unlearn/dataset_handler.py b/mu_defense/algorithms/adv_unlearn/dataset_handler.py new file mode 100644 index 00000000..f49f4b4d --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/dataset_handler.py @@ -0,0 +1,50 @@ + +# mu_defense/algorithms/adv_unlearn/dataset_handler.py + +import logging + +from mu_defense.core import BaseDatasetHandler +from mu_defense.algorithms.adv_unlearn import retain_prompt + + +class AdvUnlearnDatasetHandler(BaseDatasetHandler): + """ + AdvUnlearnDatasetHandler handles prompt cleaning and retaining dataset creation + specifically for adversarial unlearning. + """ + def __init__(self,prompt, seperator, dataset_retain): + self.prompt = prompt + self.seperator = seperator + self.dataset_retain = dataset_retain + + self.logger = logging.getLogger(__name__) + + + def setup_prompt(self): + # Special-case prompt cleaning. + if self.prompt == 'allartist': + self.prompt = ("Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, " + "Alena Aenami, Tyler Edlin, Kilian Eng") + elif self.prompt == 'i2p': + self.prompt = ("hate, harassment, violence, suffering, humiliation, harm, " + "suicide, sexual, nudity, bodily fluids, blood") + elif self.prompt == "artifact": + self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") + + if self.seperator: + self.words = [w.strip() for w in self.prompt.split(self.seperator)] + else: + self.words = [self.prompt] + + self.word_print = self.prompt.replace(" ", "") + self.logger.info(f"The Concept Prompt to be unlearned: {self.words}") + return self.words, self.word_print + + def setup_dataset(self): + """ + Create and return the retaining dataset using the helper function. + """ + dataset = retain_prompt(self.dataset_retain) + return dataset diff --git a/mu_defense/algorithms/adv_unlearn/model.py b/mu_defense/algorithms/adv_unlearn/model.py index f6bae9e1..d997c4f0 100644 --- a/mu_defense/algorithms/adv_unlearn/model.py +++ b/mu_defense/algorithms/adv_unlearn/model.py @@ -1,76 +1,65 @@ +# mu_defense/algorithms/adv_unlearn/model.py + import torch -from diffusers import AutoencoderKL from transformers import CLIPTextModel, CLIPTokenizer +from mu_defense.core import BaseModel from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -from mu_defense.algorithms.adv_unlearn import get_models +from mu_defense.algorithms.adv_unlearn import get_models_for_compvis, get_models_for_diffusers -from mu_defense.core import BaseModel class AdvUnlearnModel(BaseModel): - """ - AdvUnlearnModel handles loading of the components for adversarial unlearning. - This includes: - - The VAE from the pretrained model. - - The tokenizer. - - The text encoder (and its custom wrapper). - - The diffusion models (trainable and frozen versions) along with their samplers. - """ - def __init__( - self, - model_name_or_path: str, - config_path: str, - compvis_ckpt_path: str, - cache_path: str, - devices: list - ): - """ - Initialize the AdvUnlearnModel loader. - - Args: - model_name_or_path (str): Path or identifier of the pretrained model. - config_path (str): Path to the model configuration file. - compvis_ckpt_path (str): Path to the model checkpoint. - cache_path (str): Directory for caching downloaded models. - devices (list): List of device strings (e.g., ['cuda:0', 'cuda:1']) for model placement. - """ + def __init__(self, config: dict): super().__init__() - self.model_name_or_path = model_name_or_path - self.config_path = config_path - self.ckpt_path = compvis_ckpt_path - self.cache_path = cache_path - self.devices = devices - - # Load the VAE. - self.vae = AutoencoderKL.from_pretrained( - self.model_name_or_path, - subfolder="vae", - cache_dir=self.cache_path - ).to(self.devices[0]) - - # Load the tokenizer. + self.encoder_model_name_or_path = config.get("encoder_model_name_or_path") + self.model_config_path = config.get("model_config_path") + self.compvis_ckpt_path = config.get("compvis_ckpt_path") + + self.diffusers_model_name_or_path = config.get("diffusers_model_name_or_path") + self.target_ckpt = config.get("target_ckpt") + + self.cache_path = config.get("cache_path") + devices = config.get("devices") + if isinstance(devices, str): + self.devices = [f'cuda:{int(d.strip())}' for d in devices.split(',')] + elif isinstance(devices, list): + self.devices = devices + else: + raise ValueError("devices must be a comma-separated string or a list") + + self.backend = config.get("backend") + + self.load_model() + + def load_model(self): + # Load tokenizer self.tokenizer = CLIPTokenizer.from_pretrained( - self.model_name_or_path, + self.encoder_model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path ) - - # Load the text encoder and wrap it with your custom encoder. + # Load text encoder and wrap it self.text_encoder = CLIPTextModel.from_pretrained( - self.model_name_or_path, + self.encoder_model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path ).to(self.devices[0]) self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) - # Load diffusion models using your helper function. - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models( - self.config_path, - self.compvis_ckpt_path, - self.devices - ) - self.model_orig.eval() # Set the frozen model to evaluation mode. + # Load diffusion models + if self.backend == "compvis": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( + self.model_config_path, + self.compvis_ckpt_path, + self.devices + ) + + elif self.backend == "diffusers": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( + self.diffusers_model_name_or_path, self.devices, self.target_ckpt + ) + def save_model(self, model: torch.nn.Module, output_path: str) -> None: """ @@ -80,19 +69,12 @@ def save_model(self, model: torch.nn.Module, output_path: str) -> None: model (torch.nn.Module): The model to be saved. output_path (str): The file path where the model checkpoint will be stored. """ - torch.save({"state_dict": model.state_dict()}, output_path) + if self.backend == "compvis": + torch.save({"state_dict": model.state_dict()}, output_path) - def get_learned_conditioning(self, prompts: list): - """ - Obtain the learned conditioning for the given prompts using the trainable diffusion model. + elif self.backend == "diffusers": + model.save_pretrained(output_path) - Args: - prompts (list): A list of prompt strings. - - Returns: - The conditioning tensors produced by the model. - """ - return self.model.get_learned_conditioning(prompts) def apply_model(self, z: torch.Tensor, t: torch.Tensor, c): """ diff --git a/mu_defense/algorithms/adv_unlearn/trainer.py b/mu_defense/algorithms/adv_unlearn/trainer.py index e69de29b..b7146c11 100644 --- a/mu_defense/algorithms/adv_unlearn/trainer.py +++ b/mu_defense/algorithms/adv_unlearn/trainer.py @@ -0,0 +1,37 @@ +# mu_defense/algorithms/adv_unlearn/trainer.py + +import logging + +from mu_defense.core import BaseTrainer +from mu_defense.algorithms.adv_unlearn import AdvUnlearnCompvisTrainer + +class AdvUnlearnTrainer(BaseTrainer): + """ + Trainer class orchestrates the adversarial unlearning training process. + It instantiates the model and trainer components based on the provided configuration, + and then runs the training loop. + """ + def __init__(self, config: dict, model, devices): + + self.backend = config.get("backend") + self.logger = logging.getLogger(__name__) + + # Setup components based on the backend. + if self.backend == "compvis": + self.logger.info("Using Compvis backend for adversarial unlearning.") + + # Create the CompVis trainer. + self.trainer = AdvUnlearnCompvisTrainer(model, config, devices) + if self.backend == "diffusers": + pass + + + def run(self): + """ + Run the training loop. + """ + self.logger.info("Starting training...") + self.trainer.train() + self.logger.info("Training complete.") + + diff --git a/mu_defense/algorithms/adv_unlearn/utils.py b/mu_defense/algorithms/adv_unlearn/utils.py index befc5820..3ec80824 100644 --- a/mu_defense/algorithms/adv_unlearn/utils.py +++ b/mu_defense/algorithms/adv_unlearn/utils.py @@ -1,4 +1,6 @@ +# mu_defense/algorithms/adv_unlearn/utils.py + import os import random import pandas as pd @@ -8,20 +10,12 @@ import torch import torch.nn.functional as F -import OmegaConf from diffusers import ( DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, UNet2DConditionModel, ) - -from mu.helpers import load_model_from_config, sample_model +from mu.helpers import load_model_from_config from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler @@ -76,16 +70,97 @@ def id2embedding(tokenizer, all_embeddings, input_ids, device): input_embeds = input_one_hot @ all_embeddings return input_embeds -def get_models(config_path, ckpt_path, devices): - model_orig = load_model_from_config(config_path, ckpt_path, devices[1]) +def get_models_for_diffusers(diffuser_model_name_or_path,devices, target_ckpt=None, cache_path=None): + """ + Loads two copies of a Diffusers UNet model along with their DDIM schedulers. + + Args: + model_name_or_path (str): The Hugging Face model identifier or local path. + target_ckpt (str or None): Path to a target checkpoint to load into the primary model (on devices[0]). + If None, no state dict is loaded. + devices (list or tuple): A list/tuple of two devices, e.g. [device0, device1]. + cache_path (str or None): Optional cache directory for pretrained weights. + + Returns: + model_orig: The UNet loaded on devices[1]. + sampler_orig: The DDIM scheduler corresponding to model_orig. + model: The UNet loaded on devices[0] (optionally updated with target_ckpt). + sampler: The DDIM scheduler corresponding to model. + """ + + # Load the original model (used for e.g. computing loss, etc.) on devices[1] + model_orig = UNet2DConditionModel.from_pretrained( + diffuser_model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[1]) + + # Create a DDIM scheduler for model_orig. (Note: diffusers DDIMScheduler is used here; + # adjust the subfolder or configuration if your scheduler is stored elsewhere.) + sampler_orig = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + # Load the second copy of the model on devices[0] + model = UNet2DConditionModel.from_pretrained( + diffuser_model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[0]) + + # Optionally load a target checkpoint into model + if target_ckpt is not None: + state_dict = torch.load(target_ckpt, map_location=devices[0]) + model.load_state_dict(state_dict) + + sampler = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + return model_orig, sampler_orig, model, sampler + +def get_models_for_compvis(config_path, compvis_ckpt_path, devices): + model_orig = load_model_from_config(config_path, compvis_ckpt_path, devices[1]) sampler_orig = DDIMSampler(model_orig) - model = load_model_from_config(config_path, ckpt_path, devices[0]) + model = load_model_from_config(config_path, compvis_ckpt_path, devices[0]) sampler = DDIMSampler(model) return model_orig, sampler_orig, model, sampler +@torch.no_grad() +def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True): + """Sample the model""" + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(n_samples * [""]) + log_t = 100 + if log_every_t is not None: + log_t = log_every_t + shape = [4, h // 8, w // 8] + samples_ddim, inters = sampler.sample(S=ddim_steps, + conditioning=c, + batch_size=n_samples, + shape=shape, + verbose=False, + x_T=start_code, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + verbose_iter = verbose, + t_start=t_start, + log_every_t = log_t, + till_T = till_T + ) + if log_every_t is not None: + return samples_ddim, inters + return samples_ddim + def get_train_loss_retain( retain_batch, retain_train, retain_loss_w, model, model_orig, text_encoder, sampler, emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, start_guidance, negative_guidance, devices, ddim_steps, ddim_eta, image_size, criteria, adv_input_ids, attack_embd_type, adv_embd=None): """_summary_ @@ -949,149 +1024,6 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False return new_checkpoint -def savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device='cpu'): - checkpoint_path = path - - original_config_file = compvis_config_file - config_file = diffusers_config_file - num_in_channels = 4 - scheduler_type = 'ddim' - pipeline_type = None - image_size = 512 - prediction_type = 'epsilon' - extract_ema = False - dump_path = path.replace('Compvis','Diffusers') - upcast_attention = False - - - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - print("global_step key not found in model") - global_step = None - - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - upcast_attention = upcast_attention - if original_config_file is None: - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - if not os.path.isfile("v2-inference-v.yaml"): - # model_type = "v2" - os.system( - "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - " -O v2-inference-v.yaml" - ) - original_config_file = "./v2-inference-v.yaml" - - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - else: - if not os.path.isfile("v1-inference.yaml"): - # model_type = "v1" - os.system( - "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - " -O v1-inference.yaml" - ) - original_config_file = "./v1-inference.yaml" - - original_config = OmegaConf.load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = False - unet = UNet2DConditionModel(**unet_config) - - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema - ) - torch.save(converted_unet_checkpoint, dump_path) - - - -def save_model(folder_path, model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True): - # SAVE MODEL - - # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' - folder_path = f'{folder_path}/models' - os.makedirs(folder_path, exist_ok=True) - if num is not None: - path = f'{folder_path}/Compvis-UNet-{name}-epoch_{num}.pt' - else: - path = f'{folder_path}/Compvis-UNet-{name}.pt' - if save_compvis: - torch.save(model.state_dict(), path) - - if save_diffusers: - print('Saving Model in Diffusers Format') - savemodelDiffusers(path, name, compvis_config_file, diffusers_config_file, device=device ) - - def moving_average(a, n=3) : ret = np.cumsum(a, dtype=float) diff --git a/mu_defense/core/__init__.py b/mu_defense/core/__init__.py index dee0b006..22020222 100644 --- a/mu_defense/core/__init__.py +++ b/mu_defense/core/__init__.py @@ -1,11 +1,15 @@ from .base_algorithm import BaseAlgorithm from .base_model import BaseModel -from .base_trainer import BaseTrainer +from .base_compvis_trainer import BaseCompvisTrainer from .base_config import BaseConfig +from .base_data_handler import BaseDatasetHandler +from .base_trainer import BaseTrainer __all__ = [ "BaseAlgorithm", "BaseModel", "BaseTrainer", - "BaseConfig" - ] \ No newline at end of file + "BaseConfig", + "BaseDatasetHandler", + "BaseCompvisTrainer" + ] diff --git a/mu_defense/core/base_compvis_trainer.py b/mu_defense/core/base_compvis_trainer.py new file mode 100644 index 00000000..b4fea970 --- /dev/null +++ b/mu_defense/core/base_compvis_trainer.py @@ -0,0 +1,23 @@ +# mu_defense/core/base_trainer.py + +from abc import ABC +from typing import Any + +class BaseCompvisTrainer(ABC): + """Abstract base class for training unlearning models.""" + + def __init__(self, model: Any, config: dict, **kwargs): + self.model = model + self.config = config + + + # @abstractmethod + def setup_optimizer(self, *args, **kwargs): + """Set up the optimizers for training.""" + pass + + # @abstractmethod + def train(self, *args, **kwargs): + """Train the model.""" + pass + diff --git a/mu_defense/core/base_config.py b/mu_defense/core/base_config.py index 660e9066..3b920123 100644 --- a/mu_defense/core/base_config.py +++ b/mu_defense/core/base_config.py @@ -31,3 +31,4 @@ def to_dict(self): else: result[attr_name] = attr_value return result + diff --git a/mu_defense/core/base_data_handler.py b/mu_defense/core/base_data_handler.py new file mode 100644 index 00000000..41668f22 --- /dev/null +++ b/mu_defense/core/base_data_handler.py @@ -0,0 +1,29 @@ +# mu_defense/core/base_dataset_handler.py + +from abc import ABC, abstractmethod + +class BaseDatasetHandler(ABC): + """ + BaseDatasetHandler provides a blueprint for handling dataset-related tasks, + including prompt cleaning and creation of a retaining dataset. + """ + def __init__(self, prompt: str, seperator: str = None, dataset_retain=None): + self.prompt = prompt + self.seperator = seperator + self.dataset_retain = dataset_retain + self.words = [] + self.word_print = "" + + @abstractmethod + def setup_prompt(self): + """ + Set up and return the cleaned prompt and the printable version. + """ + pass + + @abstractmethod + def setup_dataset(self): + """ + Create and return the retaining dataset. + """ + pass \ No newline at end of file diff --git a/mu_defense/core/base_trainer.py b/mu_defense/core/base_trainer.py index 8567a063..62776409 100644 --- a/mu_defense/core/base_trainer.py +++ b/mu_defense/core/base_trainer.py @@ -1,23 +1,19 @@ -# mu_defense/core/base_trainer.py - -from abc import ABC -from typing import Any +from abc import ABC, abstractmethod +import logging class BaseTrainer(ABC): - """Abstract base class for training unlearning models.""" - - def __init__(self, model: Any, config: dict, **kwargs): - self.model = model + """ + BaseTrainerRunner is an abstract base class for high-level training orchestrators. + It defines the interface and common properties for running a training process. + """ + def __init__(self, config: dict): self.config = config - - - # @abstractmethod - def setup_optimizer(self, *args, **kwargs): - """Set up the optimizers for training.""" + self.devices = config.get("devices", ["cuda:0"]) + self.logger = logging.getLogger(__name__) + + @abstractmethod + def run(self): + """ + Run the training loop. Must be implemented by subclasses. + """ pass - - # @abstractmethod - def train(self, *args, **kwargs): - """Train the model.""" - pass - diff --git a/mu_defense/environment.yaml b/mu_defense/environment.yaml new file mode 100644 index 00000000..13036586 --- /dev/null +++ b/mu_defense/environment.yaml @@ -0,0 +1,36 @@ +name: AdvUnlearn +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.23.5 + - huggingface_hub==0.25.1 + - pip: + - albumentations==0.4.3 + - diffusers==0.12.1 + - opencv-python==4.1.2.30 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.25.1 + - torchmetrics==0.6.0 + - kornia==0.6 + - timm==1.0.11 + - matplotlib + - wandb + - tabulate + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + From 5b162dc361aa4733406f79c5412e40d72ec6a5fc Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 11 Feb 2025 11:39:03 +0000 Subject: [PATCH 18/22] fixes in adv unlearn for compvis --- .../algorithms/adv_unlearn/compvis_trainer.py | 458 ++++++++++++++++-- .../algorithms/adv_unlearn/dataset_handler.py | 2 +- mu_defense/algorithms/adv_unlearn/model.py | 2 +- mu_defense/core/__init__.py | 4 +- mu_defense/core/base_compvis_trainer.py | 23 - mu_defense/core/base_trainer.py | 3 + 6 files changed, 417 insertions(+), 75 deletions(-) delete mode 100644 mu_defense/core/base_compvis_trainer.py diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py index ea232c02..8eb79bd4 100644 --- a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -1,3 +1,372 @@ +# # mu_defense/algorithms/adv_unlearn/compvis_trainer.py + +# import torch +# from tqdm import tqdm +# import random +# import wandb +# import logging +# from torch.nn import MSELoss + +# from mu.core import BaseTrainer +# from mu_defense.algorithms.adv_unlearn import ( +# id2embedding, +# param_choices, +# get_train_loss_retain, +# save_text_encoder, +# save_history, +# sample_model +# ) +# from mu_attack.attackers.soft_prompt import SoftPromptAttack +# from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler + + +# class AdvUnlearnCompvisTrainer(BaseTrainer): +# """ +# Trainer for adversarial unlearning. + +# This trainer performs the adversarial prompt update and retention-based +# regularized training loop for CompVis/Diffusers models. +# """ +# def __init__(self, model, config: dict, devices: list, **kwargs): +# """ +# Initialize the AdvUnlearnCompvisTrainer. +# """ +# super().__init__(model, config, **kwargs) +# self.devices = devices + +# # Unpack models and samplers from the provided model loader. +# self.model = model.model # trainable diffusion model +# self.model_orig = model.model_orig # frozen diffusion model (set to eval) +# self.sampler = model.sampler +# self.sampler_orig = model.sampler_orig +# self.model_loader = model + +# # Other loaded components. +# self.tokenizer = model.tokenizer +# self.custom_text_encoder = model.custom_text_encoder +# self.all_embeddings = model.all_embeddings + +# # Loss criterion. +# self.criteria = MSELoss() + +# # Save configuration parameters. +# self.config = config +# self.prompt = self.config['prompt'] +# self.seperator = self.config.get('seperator') +# self.iterations = self.config.get('iterations') +# self.ddim_steps = self.config['ddim_steps'] +# self.start_guidance = self.config['start_guidance'] +# self.negative_guidance = self.config['negative_guidance'] +# self.image_size = self.config['image_size'] +# self.lr = self.config['lr'] +# self.model_config_path = self.config['model_config_path'] +# self.output_dir = self.config['output_dir'] + +# # Retention and attack parameters. +# self.dataset_retain = self.config['dataset_retain'] +# self.retain_batch = self.config['retain_batch'] +# self.retain_train = self.config['retain_train'] +# self.retain_step = self.config['retain_step'] +# self.retain_loss_w = self.config['retain_loss_w'] +# self.attack_method = self.config['attack_method'] +# self.train_method = self.config['train_method'] +# self.norm_layer = self.config['norm_layer'] +# self.component = self.config['component'] +# self.adv_prompt_num = self.config['adv_prompt_num'] +# self.attack_embd_type = self.config['attack_embd_type'] +# self.attack_type = self.config['attack_type'] +# self.attack_init = self.config['attack_init'] +# self.warmup_iter = self.config['warmup_iter'] +# self.attack_step = self.config['attack_step'] +# self.attack_lr = self.config['attack_lr'] +# self.adv_prompt_update_step = self.config['adv_prompt_update_step'] +# self.ddim_eta = self.config['ddim_eta'] + +# self.logger = logging.getLogger(__name__) + +# # Setup the dataset handler and prompt cleaning. +# self.dataset_handler = AdvUnlearnDatasetHandler( +# prompt=self.prompt, +# seperator=self.seperator, +# dataset_retain=self.dataset_retain +# ) +# self.words, self.word_print = self.dataset_handler.setup_prompt() +# self.retain_dataset = self.dataset_handler.setup_dataset() + +# # Initialize adversarial prompt variables. +# self.adv_word_embd = None +# self.adv_condition_embd = None +# self.adv_input_ids = None + +# # Setup trainable parameters and optimizer. +# self._setup_optimizer() + +# def _setup_optimizer(self): +# """ +# Set up the optimizer based on the training method. +# """ +# if 'text_encoder' in self.train_method: +# self.parameters = param_choices( +# model=self.custom_text_encoder, +# train_method=self.train_method, +# component=self.component, +# final_layer_norm=self.norm_layer +# ) +# else: +# self.parameters = param_choices( +# model=self.model, +# train_method=self.train_method, +# component=self.component, +# final_layer_norm=self.norm_layer +# ) +# self.optimizer = torch.optim.Adam(self.parameters, lr=float(self.lr)) + +# def train(self): +# """ +# Execute the adversarial unlearning training loop. +# """ +# ddim_eta = self.ddim_eta +# quick_sample_till_t = lambda x, s, code, batch, t: sample_model( +# self.model, self.sampler, +# x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, +# start_code=code, n_samples=batch, till_T=t, verbose=False +# ) +# losses = [] +# history = [] +# global_step = 0 +# attack_round = 0 + +# pbar = tqdm(range(self.iterations)) +# for i in pbar: +# if i % self.adv_prompt_update_step == 0: +# if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: +# self.retain_dataset.reset() +# word = random.choice(self.words) +# text_input = self.tokenizer( +# word, +# padding="max_length", +# max_length=self.tokenizer.model_max_length, +# return_tensors="pt", +# truncation=True +# ) +# text_embeddings = id2embedding( +# self.tokenizer, +# self.all_embeddings, +# text_input.input_ids.to(self.devices[0]), +# self.devices[0] +# ) +# emb_0 = self.model_orig.get_learned_conditioning(['']) +# emb_p = self.model_orig.get_learned_conditioning([word]) + +# if i >= self.warmup_iter: +# self.custom_text_encoder.text_encoder.eval() +# self.custom_text_encoder.text_encoder.requires_grad_(False) +# self.model.eval() +# if attack_round == 0: +# if self.attack_embd_type == 'word_embd': +# self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( +# global_step, word, self.model, self.model_orig, self.tokenizer, +# self.custom_text_encoder, self.sampler, emb_0, emb_p, +# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_prompt_num, +# self.all_embeddings, attack_round, self.attack_type, +# self.attack_embd_type, self.attack_step, self.attack_lr, +# self.attack_init, None, self.attack_method +# ) +# elif self.attack_embd_type == 'condition_embd': +# self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( +# global_step, word, self.model, self.model_orig, self.tokenizer, +# self.custom_text_encoder, self.sampler, emb_0, emb_p, +# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_prompt_num, +# self.all_embeddings, attack_round, self.attack_type, +# self.attack_embd_type, self.attack_step, self.attack_lr, +# self.attack_init, None, self.attack_method +# ) +# else: +# if self.attack_embd_type == 'word_embd': +# self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( +# global_step, word, self.model, self.model_orig, self.tokenizer, +# self.custom_text_encoder, self.sampler, emb_0, emb_p, +# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_prompt_num, +# self.all_embeddings, attack_round, self.attack_type, +# self.attack_embd_type, self.attack_step, self.attack_lr, +# self.attack_init, self.adv_word_embd, self.attack_method +# ) +# elif self.attack_embd_type == 'condition_embd': +# self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( +# global_step, word, self.model, self.model_orig, self.tokenizer, +# self.custom_text_encoder, self.sampler, emb_0, emb_p, +# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_prompt_num, +# self.all_embeddings, attack_round, self.attack_type, +# self.attack_embd_type, self.attack_step, self.attack_lr, +# self.attack_init, self.adv_condition_embd, self.attack_method +# ) +# global_step += self.attack_step +# attack_round += 1 + +# if 'text_encoder' in self.train_method: +# self.custom_text_encoder.text_encoder.train() +# self.custom_text_encoder.text_encoder.requires_grad_(True) +# self.model.eval() +# else: +# self.custom_text_encoder.text_encoder.eval() +# self.custom_text_encoder.text_encoder.requires_grad_(False) +# self.model.train() + +# self.optimizer.zero_grad() + +# if self.retain_train == 'reg': +# retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) +# retain_text_input = self.tokenizer( +# retain_words, +# padding="max_length", +# max_length=self.tokenizer.model_max_length, +# return_tensors="pt", +# truncation=True +# ) +# retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) +# retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) +# retain_text_embeddings = id2embedding( +# self.tokenizer, +# self.all_embeddings, +# retain_text_input.input_ids.to(self.devices[0]), +# self.devices[0] +# ) +# retain_text_embeddings = retain_text_embeddings.reshape( +# self.retain_batch, -1, retain_text_embeddings.shape[-1] +# ) +# retain_emb_n = self.custom_text_encoder( +# input_ids=retain_input_ids, +# inputs_embeds=retain_text_embeddings +# )[0] +# else: +# retain_emb_p = None +# retain_emb_n = None + +# if i < self.warmup_iter: +# input_ids = text_input.input_ids.to(self.devices[0]) +# emb_n = self.custom_text_encoder( +# input_ids=input_ids, +# inputs_embeds=text_embeddings +# )[0] +# loss = get_train_loss_retain( +# self.retain_batch, self.retain_train, self.retain_loss_w, +# self.model, self.model_orig, self.custom_text_encoder, self.sampler, +# emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, +# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, input_ids, self.attack_embd_type +# ) +# else: +# if self.attack_embd_type == 'word_embd': +# loss = get_train_loss_retain( +# self.retain_batch, self.retain_train, self.retain_loss_w, +# self.model, self.model_orig, self.custom_text_encoder, self.sampler, +# emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, +# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, +# self.adv_word_embd +# ) +# elif self.attack_embd_type == 'condition_embd': +# loss = get_train_loss_retain( +# self.retain_batch, self.retain_train, self.retain_loss_w, +# self.model, self.model_orig, self.custom_text_encoder, self.sampler, +# emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, +# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, +# self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, +# self.adv_condition_embd +# ) +# loss.backward() +# losses.append(loss.item()) +# pbar.set_postfix({"loss": loss.item()}) +# history.append(loss.item()) +# wandb.log({'Train_Loss': loss.item()}, step=global_step) +# wandb.log({'Attack_Loss': 0.0}, step=global_step) +# global_step += 1 +# self.optimizer.step() + +# if self.retain_train == 'iter': +# for r in range(self.retain_step): +# self.optimizer.zero_grad() +# if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: +# self.retain_dataset.reset() +# retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) +# t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) +# og_num = round((int(t_enc.item()) / self.ddim_steps) * 1000) +# og_num_lim = round(((int(t_enc.item()) + 1) / self.ddim_steps) * 1000) +# t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) +# retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) +# retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) +# retain_z = quick_sample_till_t( +# retain_emb_p.to(self.devices[0]), +# self.start_guidance, +# retain_start_code, +# self.retain_batch, +# int(t_enc.item()) +# ) +# retain_e_p = self.model_orig.apply_model( +# retain_z.to(self.devices[0]), +# t_enc_ddpm.to(self.devices[0]), +# retain_emb_p.to(self.devices[0]) +# ) +# retain_text_input = self.tokenizer( +# retain_words, +# padding="max_length", +# max_length=self.tokenizer.model_max_length, +# return_tensors="pt", +# truncation=True +# ) +# retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) +# retain_text_embeddings = id2embedding( +# self.tokenizer, +# self.all_embeddings, +# retain_text_input.input_ids.to(self.devices[0]), +# self.devices[0] +# ) +# retain_text_embeddings = retain_text_embeddings.reshape( +# self.retain_batch, -1, retain_text_embeddings.shape[-1] +# ) +# retain_emb_n = self.custom_text_encoder( +# input_ids=retain_input_ids, +# inputs_embeds=retain_text_embeddings +# )[0] +# retain_e_n = self.model.apply_model( +# retain_z.to(self.devices[0]), +# t_enc_ddpm.to(self.devices[0]), +# retain_emb_n.to(self.devices[0]) +# ) +# retain_loss = self.criteria( +# retain_e_n.to(self.devices[0]), +# retain_e_p.to(self.devices[0]) +# ) +# retain_loss.backward() +# self.optimizer.step() + +# if (i + 1) % self.config['save_interval'] == 0 and (i + 1) != self.iterations and (i + 1) >= self.config['save_interval']: +# if 'text_encoder' in self.train_method: +# save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) +# else: +# output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" +# self.model_loader.save_model(self.model, output_path) +# if i % 1 == 0: +# save_history(self.output_dir, losses, self.word_print) + +# self.model.eval() +# self.custom_text_encoder.text_encoder.eval() +# self.custom_text_encoder.text_encoder.requires_grad_(False) +# if 'text_encoder' in self.train_method: +# save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) +# else: +# output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" +# self.model_loader.save_model(self.model, output_path) +# save_history(self.output_dir, losses, self.word_print) +# return self.model + + +# mu_defense/algorithms/adv_unlearn/compvis_trainer.py + # mu_defense/algorithms/adv_unlearn/compvis_trainer.py import torch @@ -16,7 +385,9 @@ save_history, sample_model ) -from mu_attack.attackers.soft_prompt import SoftPromptAttack +# Note: We no longer import SoftPromptAttack directly. +from mu_attack.execs.adv_attack import AdvAttack +from mu_attack.configs.adv_unlearn import AdvAttackConfig from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler @@ -134,7 +505,7 @@ def train(self): losses = [] history = [] global_step = 0 - attack_round = 0 + attack_round = 0 pbar = tqdm(range(self.iterations)) for i in pbar: @@ -155,6 +526,7 @@ def train(self): text_input.input_ids.to(self.devices[0]), self.devices[0] ) + # Obtain the unconditional and conditional embeddings via the original model. emb_0 = self.model_orig.get_learned_conditioning(['']) emb_p = self.model_orig.get_learned_conditioning([word]) @@ -162,50 +534,43 @@ def train(self): self.custom_text_encoder.text_encoder.eval() self.custom_text_encoder.text_encoder.requires_grad_(False) self.model.eval() - if attack_round == 0: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, - self.start_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_prompt_num, - self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, - self.attack_init, None, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, - self.start_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_prompt_num, - self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, - self.attack_init, None, self.attack_method - ) - else: - if self.attack_embd_type == 'word_embd': - self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, - self.start_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_prompt_num, - self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, - self.attack_init, self.adv_word_embd, self.attack_method - ) - elif self.attack_embd_type == 'condition_embd': - self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( - global_step, word, self.model, self.model_orig, self.tokenizer, - self.custom_text_encoder, self.sampler, emb_0, emb_p, - self.start_guidance, self.devices, self.ddim_steps, ddim_eta, - self.image_size, self.criteria, self.adv_prompt_num, - self.all_embeddings, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, - self.attack_init, self.adv_condition_embd, self.attack_method - ) + + # Build the attack configuration. (Adjust fields if necessary.) + attack_config = AdvAttackConfig( + prompt=word, + encoder_model_name_or_path=self.tokenizer.name_or_path, + cache_path=self.config.get("cache_path", "./cache"), + devices=",".join([d.strip() for d in self.config.get("devices", "cuda:0").split(',')]), + attack_type=self.attack_type, + attack_embd_type=self.attack_embd_type, + attack_step=self.attack_step, + attack_lr=self.attack_lr, + attack_init=self.attack_init, + attack_init_embd=None, # Set as needed + attack_method=self.attack_method, + ddim_steps=self.ddim_steps, + ddim_eta=ddim_eta, + image_size=self.image_size, + adv_prompt_num=self.adv_prompt_num, + start_guidance=self.start_guidance, + config_path=self.model_config_path, + compvis_ckpt_path=self.config.get("compvis_ckpt_path", ""), + backend="compvis", + diffusers_model_name_or_path="", + target_ckpt="", + project=self.config.get("project_name", "default_project"), + experiment_name=self.config.get("experiment_name", "default_experiment") + ) + adv_attack = AdvAttack(attack_config) + adv_word_embd, adv_input_ids = adv_attack.attack() + + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = adv_word_embd, adv_input_ids + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = adv_word_embd, adv_input_ids + global_step += self.attack_step - attack_round += 1 + attack_round += 1 if 'text_encoder' in self.train_method: self.custom_text_encoder.text_encoder.train() @@ -350,8 +715,7 @@ def train(self): else: output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" self.model_loader.save_model(self.model, output_path) - if i % 1 == 0: - save_history(self.output_dir, losses, self.word_print) + save_history(self.output_dir, losses, self.word_print) self.model.eval() self.custom_text_encoder.text_encoder.eval() diff --git a/mu_defense/algorithms/adv_unlearn/dataset_handler.py b/mu_defense/algorithms/adv_unlearn/dataset_handler.py index f49f4b4d..fb51bf70 100644 --- a/mu_defense/algorithms/adv_unlearn/dataset_handler.py +++ b/mu_defense/algorithms/adv_unlearn/dataset_handler.py @@ -4,7 +4,7 @@ import logging from mu_defense.core import BaseDatasetHandler -from mu_defense.algorithms.adv_unlearn import retain_prompt +from mu_defense.algorithms.adv_unlearn.utils import retain_prompt class AdvUnlearnDatasetHandler(BaseDatasetHandler): diff --git a/mu_defense/algorithms/adv_unlearn/model.py b/mu_defense/algorithms/adv_unlearn/model.py index d997c4f0..f156f17a 100644 --- a/mu_defense/algorithms/adv_unlearn/model.py +++ b/mu_defense/algorithms/adv_unlearn/model.py @@ -5,7 +5,7 @@ from mu_defense.core import BaseModel from mu_attack.tasks.utils.text_encoder import CustomTextEncoder -from mu_defense.algorithms.adv_unlearn import get_models_for_compvis, get_models_for_diffusers +from mu_defense.algorithms.adv_unlearn.utils import get_models_for_compvis, get_models_for_diffusers class AdvUnlearnModel(BaseModel): diff --git a/mu_defense/core/__init__.py b/mu_defense/core/__init__.py index 22020222..5a5ddc20 100644 --- a/mu_defense/core/__init__.py +++ b/mu_defense/core/__init__.py @@ -1,6 +1,5 @@ from .base_algorithm import BaseAlgorithm from .base_model import BaseModel -from .base_compvis_trainer import BaseCompvisTrainer from .base_config import BaseConfig from .base_data_handler import BaseDatasetHandler from .base_trainer import BaseTrainer @@ -10,6 +9,5 @@ "BaseModel", "BaseTrainer", "BaseConfig", - "BaseDatasetHandler", - "BaseCompvisTrainer" + "BaseDatasetHandler" ] diff --git a/mu_defense/core/base_compvis_trainer.py b/mu_defense/core/base_compvis_trainer.py deleted file mode 100644 index b4fea970..00000000 --- a/mu_defense/core/base_compvis_trainer.py +++ /dev/null @@ -1,23 +0,0 @@ -# mu_defense/core/base_trainer.py - -from abc import ABC -from typing import Any - -class BaseCompvisTrainer(ABC): - """Abstract base class for training unlearning models.""" - - def __init__(self, model: Any, config: dict, **kwargs): - self.model = model - self.config = config - - - # @abstractmethod - def setup_optimizer(self, *args, **kwargs): - """Set up the optimizers for training.""" - pass - - # @abstractmethod - def train(self, *args, **kwargs): - """Train the model.""" - pass - diff --git a/mu_defense/core/base_trainer.py b/mu_defense/core/base_trainer.py index 62776409..714402bc 100644 --- a/mu_defense/core/base_trainer.py +++ b/mu_defense/core/base_trainer.py @@ -10,6 +10,9 @@ def __init__(self, config: dict): self.config = config self.devices = config.get("devices", ["cuda:0"]) self.logger = logging.getLogger(__name__) + + def train(self): + pass @abstractmethod def run(self): From 148a64d08e0fbb8b05a3be4743ca97ea94a147b5 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 11 Feb 2025 11:39:19 +0000 Subject: [PATCH 19/22] #TODOs removed --- mu/algorithms/concept_ablation/sampler.py | 4 +--- mu/algorithms/erase_diff/sampler.py | 3 --- mu/algorithms/esd/sampler.py | 4 ---- mu/algorithms/forget_me_not/sampler.py | 4 ---- mu/algorithms/saliency_unlearning/sampler.py | 3 --- mu/algorithms/scissorhands/sampler.py | 4 ---- mu/algorithms/semipermeable_membrane/model.py | 2 +- mu/algorithms/semipermeable_membrane/sampler.py | 3 --- mu/algorithms/unified_concept_editing/sampler.py | 3 --- mu/helpers/utils.py | 3 +-- 10 files changed, 3 insertions(+), 30 deletions(-) diff --git a/mu/algorithms/concept_ablation/sampler.py b/mu/algorithms/concept_ablation/sampler.py index 081f6a3f..767e5b8d 100644 --- a/mu/algorithms/concept_ablation/sampler.py +++ b/mu/algorithms/concept_ablation/sampler.py @@ -17,9 +17,7 @@ -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] + class ConceptAblationSampler(BaseSampler): """Concept Ablation Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/erase_diff/sampler.py b/mu/algorithms/erase_diff/sampler.py index 3ea4e3c5..3fe19ea7 100644 --- a/mu/algorithms/erase_diff/sampler.py +++ b/mu/algorithms/erase_diff/sampler.py @@ -16,9 +16,6 @@ from mu.helpers import load_config from mu.helpers.utils import load_ckpt_from_config -#TODO to remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] class EraseDiffSampler(BaseSampler): """EraseDiff Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/esd/sampler.py b/mu/algorithms/esd/sampler.py index 4490c067..86f744aa 100644 --- a/mu/algorithms/esd/sampler.py +++ b/mu/algorithms/esd/sampler.py @@ -18,10 +18,6 @@ from mu.helpers.utils import load_ckpt_from_config -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] - class ESDSampler(BaseSampler): """Sampler for the ESD algorithm.""" diff --git a/mu/algorithms/forget_me_not/sampler.py b/mu/algorithms/forget_me_not/sampler.py index c6889669..cf9511b6 100644 --- a/mu/algorithms/forget_me_not/sampler.py +++ b/mu/algorithms/forget_me_not/sampler.py @@ -11,10 +11,6 @@ from mu.core.base_sampler import BaseSampler from stable_diffusion.constants.const import theme_available, class_available -#TODO to remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] - class ForgetMeNotSampler(BaseSampler): """ForgetMeNot Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/saliency_unlearning/sampler.py b/mu/algorithms/saliency_unlearning/sampler.py index 7af3edec..98ac55b3 100644 --- a/mu/algorithms/saliency_unlearning/sampler.py +++ b/mu/algorithms/saliency_unlearning/sampler.py @@ -15,9 +15,6 @@ from mu.helpers import load_config from mu.helpers.utils import load_ckpt_from_config,load_style_generated_images,load_style_ref_images,calculate_fid -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] class SaliencyUnlearningSampler(BaseSampler): """Saliency Unlearning Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/scissorhands/sampler.py b/mu/algorithms/scissorhands/sampler.py index 68ff2ea1..75e7427b 100644 --- a/mu/algorithms/scissorhands/sampler.py +++ b/mu/algorithms/scissorhands/sampler.py @@ -17,10 +17,6 @@ -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] - class ScissorHandsSampler(BaseSampler): """ScissorHands Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/semipermeable_membrane/model.py b/mu/algorithms/semipermeable_membrane/model.py index 6e166379..216e1141 100644 --- a/mu/algorithms/semipermeable_membrane/model.py +++ b/mu/algorithms/semipermeable_membrane/model.py @@ -94,7 +94,7 @@ def save_model(self, model, output_path: str, dtype, metadata, *args, **kwargs): """ Save the model weights to the output path """ - #TODO + self.logger.info(f"Saving model to {output_path}") # Save the SPM network weights model.save_weights( diff --git a/mu/algorithms/semipermeable_membrane/sampler.py b/mu/algorithms/semipermeable_membrane/sampler.py index 13934ca2..2654cece 100644 --- a/mu/algorithms/semipermeable_membrane/sampler.py +++ b/mu/algorithms/semipermeable_membrane/sampler.py @@ -20,9 +20,6 @@ from mu.algorithms.semipermeable_membrane.src.models.merge_spm import load_state_dict -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] MATCHING_METRICS = Literal[ "clipcos", diff --git a/mu/algorithms/unified_concept_editing/sampler.py b/mu/algorithms/unified_concept_editing/sampler.py index 1224318a..dc6c5e8b 100644 --- a/mu/algorithms/unified_concept_editing/sampler.py +++ b/mu/algorithms/unified_concept_editing/sampler.py @@ -13,9 +13,6 @@ from mu.core.base_sampler import BaseSampler from stable_diffusion.constants.const import theme_available, class_available -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] class UnifiedConceptEditingSampler(BaseSampler): """Unified Concept editing Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 54e8ab19..ee199e23 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -54,10 +54,9 @@ def load_model_from_config( model.cond_stage_model.device = device return model - @torch.no_grad() def sample_model( - model, + model, sampler, c, h, From cdf6c3cc4a547389ec4fb77c3988b94d383a782f Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 11 Feb 2025 12:02:33 +0000 Subject: [PATCH 20/22] bugfixes --- mu_attack/execs/adv_attack.py | 141 ++---- .../algorithms/adv_unlearn/compvis_trainer.py | 434 ++---------------- 2 files changed, 85 insertions(+), 490 deletions(-) diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index 42ed5d76..f8758b4f 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -1,34 +1,17 @@ # mu_attack/execs/adv_attack.py import torch -import random import wandb -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import StableDiffusionPipeline - -from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler from mu_attack.configs.adv_unlearn import AdvAttackConfig from mu_attack.attackers.soft_prompt import SoftPromptAttack -from mu_attack.tasks.utils.text_encoder import CustomTextEncoder from mu_attack.helpers.utils import get_models_for_compvis, get_models_for_diffusers class AdvAttack: - """ - Class for adversarial unlearning training. - - This class wraps the full training pipeline including adversarial attack - and model handling. - """ - def __init__(self, config: AdvAttackConfig, **kwargs): + def __init__(self, config: AdvAttackConfig): self.config = config.__dict__ - for key, value in kwargs.items(): - setattr(config, key, value) - - config.validate_config() - - self.prompt = config.prompt + # Do not set self.prompt from the config; remove the dependency. self.encoder_model_name_or_path = config.encoder_model_name_or_path self.cache_path = config.cache_path self.devices = [f'cuda:{int(d.strip())}' for d in config.devices.split(',')] @@ -51,43 +34,16 @@ def __init__(self, config: AdvAttackConfig, **kwargs): self.target_ckpt = config.target_ckpt self.criteria = torch.nn.MSELoss() - # Initialize wandb + # Initialize wandb (if needed) wandb.init( project=config.project_name, name=config.experiment_name, reinit=True ) - # Load models self.load_models() - def encode_text(self, text): - """Encodes text into a latent space using CLIP from Diffusers.""" - text_inputs = self.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=77, - return_tensors="pt" - ).to(self.devices[0]) # Move to correct device - - with torch.no_grad(): - text_embeddings = self.text_encoder(text_inputs.input_ids)[0] # Take the first output (hidden states) - - return text_embeddings - def load_models(self): - """Loads the tokenizer, text encoder, and models.""" - self.tokenizer = CLIPTokenizer.from_pretrained( - self.encoder_model_name_or_path, subfolder="tokenizer", cache_dir=self.cache_path - ) - self.text_encoder = CLIPTextModel.from_pretrained( - self.encoder_model_name_or_path, subfolder="text_encoder", cache_dir=self.cache_path - ).to(self.devices[0]) - self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) - self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) - - # Load base models if self.backend == "compvis": self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( self.config_path, self.compvis_ckpt_path, self.devices @@ -96,39 +52,29 @@ def load_models(self): self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( self.diffusers_model_name_or_path, self.target_ckpt, self.devices ) - - - def attack(self): - """Performs the adversarial attack.""" - # Ensure words are in list format - if isinstance(self.prompt, str): - self.words = [self.prompt] - elif isinstance(self.prompt, list): - self.words = self.prompt - else: - raise ValueError("Prompt must be a string or a list of strings.") - # Select a random word from the prompt list - word = random.choice(self.words) - - if self.backend == "compvis": - # CompVis uses `get_learned_conditioning` - emb_0 = self.model_orig.get_learned_conditioning(['']) - emb_p = self.model_orig.get_learned_conditioning([word]) - elif self.backend == "diffusers": - # Diffusers requires explicit encoding via CLIP - emb_0 = self.encode_text("") - emb_p = self.encode_text(word) - - # Initialize attack class + def attack(self, word, global_step, attack_round): + """ + Perform the adversarial attack using the given word. + + Args: + word (str): The current prompt to attack. + global_step (int): The current global training step. + attack_round (int): The current attack round. + + Returns: + tuple: (adversarial embedding, input_ids) + """ + # Now, use the passed `word` for the attack instead of self.prompt. + # (Everything else in this method remains the same.) sp_attack = SoftPromptAttack( model=self.model, model_orig=self.model_orig, - tokenizer=self.tokenizer, - text_encoder=self.custom_text_encoder, + tokenizer=self.tokenizer, + text_encoder=self.custom_text_encoder, sampler=self.sampler, - emb_0=emb_0, - emb_p=emb_p, + emb_0=self._get_emb_0(), + emb_p=self._get_emb_p(word), start_guidance=self.start_guidance, devices=self.devices, ddim_steps=self.ddim_steps, @@ -137,25 +83,36 @@ def attack(self): criteria=self.criteria, k=self.adv_prompt_num, all_embeddings=self.all_embeddings, - backend = self.backend + backend=self.backend ) + return sp_attack.attack(global_step, word, attack_round, self.attack_type, + self.attack_embd_type, self.attack_step, self.attack_lr, + self.attack_init, self.attack_init_embd, self.attack_method) + # Example helper methods to get embeddings from model_orig. + def _get_emb_0(self): + if self.backend == "compvis": + return self.model_orig.get_learned_conditioning(['']) + else: + # For diffusers, you need to define your own method (e.g., using self.encode_text("")) + return self.encode_text("") + + def _get_emb_p(self, word): + if self.backend == "compvis": + return self.model_orig.get_learned_conditioning([word]) + else: + return self.encode_text(word) - self.adv_word_embd, self.adv_input_ids = sp_attack.attack( - global_step=0, - word=word, - attack_round=0, - attack_type=self.attack_type, - attack_embd_type=self.attack_embd_type, - attack_step=self.attack_step, - attack_lr=self.attack_lr, - attack_init=self.attack_init, - attack_init_embd=self.attack_init_embd, - attack_method=self.attack_method - ) - - - return self.adv_word_embd, self.adv_input_ids + def encode_text(self, text): + text_inputs = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt" + ).to(self.devices[0]) + with torch.no_grad(): + text_embeddings = self.text_encoder(text_inputs.input_ids)[0] + return text_embeddings - \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py index 8eb79bd4..1621b3fd 100644 --- a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -1,372 +1,3 @@ -# # mu_defense/algorithms/adv_unlearn/compvis_trainer.py - -# import torch -# from tqdm import tqdm -# import random -# import wandb -# import logging -# from torch.nn import MSELoss - -# from mu.core import BaseTrainer -# from mu_defense.algorithms.adv_unlearn import ( -# id2embedding, -# param_choices, -# get_train_loss_retain, -# save_text_encoder, -# save_history, -# sample_model -# ) -# from mu_attack.attackers.soft_prompt import SoftPromptAttack -# from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler - - -# class AdvUnlearnCompvisTrainer(BaseTrainer): -# """ -# Trainer for adversarial unlearning. - -# This trainer performs the adversarial prompt update and retention-based -# regularized training loop for CompVis/Diffusers models. -# """ -# def __init__(self, model, config: dict, devices: list, **kwargs): -# """ -# Initialize the AdvUnlearnCompvisTrainer. -# """ -# super().__init__(model, config, **kwargs) -# self.devices = devices - -# # Unpack models and samplers from the provided model loader. -# self.model = model.model # trainable diffusion model -# self.model_orig = model.model_orig # frozen diffusion model (set to eval) -# self.sampler = model.sampler -# self.sampler_orig = model.sampler_orig -# self.model_loader = model - -# # Other loaded components. -# self.tokenizer = model.tokenizer -# self.custom_text_encoder = model.custom_text_encoder -# self.all_embeddings = model.all_embeddings - -# # Loss criterion. -# self.criteria = MSELoss() - -# # Save configuration parameters. -# self.config = config -# self.prompt = self.config['prompt'] -# self.seperator = self.config.get('seperator') -# self.iterations = self.config.get('iterations') -# self.ddim_steps = self.config['ddim_steps'] -# self.start_guidance = self.config['start_guidance'] -# self.negative_guidance = self.config['negative_guidance'] -# self.image_size = self.config['image_size'] -# self.lr = self.config['lr'] -# self.model_config_path = self.config['model_config_path'] -# self.output_dir = self.config['output_dir'] - -# # Retention and attack parameters. -# self.dataset_retain = self.config['dataset_retain'] -# self.retain_batch = self.config['retain_batch'] -# self.retain_train = self.config['retain_train'] -# self.retain_step = self.config['retain_step'] -# self.retain_loss_w = self.config['retain_loss_w'] -# self.attack_method = self.config['attack_method'] -# self.train_method = self.config['train_method'] -# self.norm_layer = self.config['norm_layer'] -# self.component = self.config['component'] -# self.adv_prompt_num = self.config['adv_prompt_num'] -# self.attack_embd_type = self.config['attack_embd_type'] -# self.attack_type = self.config['attack_type'] -# self.attack_init = self.config['attack_init'] -# self.warmup_iter = self.config['warmup_iter'] -# self.attack_step = self.config['attack_step'] -# self.attack_lr = self.config['attack_lr'] -# self.adv_prompt_update_step = self.config['adv_prompt_update_step'] -# self.ddim_eta = self.config['ddim_eta'] - -# self.logger = logging.getLogger(__name__) - -# # Setup the dataset handler and prompt cleaning. -# self.dataset_handler = AdvUnlearnDatasetHandler( -# prompt=self.prompt, -# seperator=self.seperator, -# dataset_retain=self.dataset_retain -# ) -# self.words, self.word_print = self.dataset_handler.setup_prompt() -# self.retain_dataset = self.dataset_handler.setup_dataset() - -# # Initialize adversarial prompt variables. -# self.adv_word_embd = None -# self.adv_condition_embd = None -# self.adv_input_ids = None - -# # Setup trainable parameters and optimizer. -# self._setup_optimizer() - -# def _setup_optimizer(self): -# """ -# Set up the optimizer based on the training method. -# """ -# if 'text_encoder' in self.train_method: -# self.parameters = param_choices( -# model=self.custom_text_encoder, -# train_method=self.train_method, -# component=self.component, -# final_layer_norm=self.norm_layer -# ) -# else: -# self.parameters = param_choices( -# model=self.model, -# train_method=self.train_method, -# component=self.component, -# final_layer_norm=self.norm_layer -# ) -# self.optimizer = torch.optim.Adam(self.parameters, lr=float(self.lr)) - -# def train(self): -# """ -# Execute the adversarial unlearning training loop. -# """ -# ddim_eta = self.ddim_eta -# quick_sample_till_t = lambda x, s, code, batch, t: sample_model( -# self.model, self.sampler, -# x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, -# start_code=code, n_samples=batch, till_T=t, verbose=False -# ) -# losses = [] -# history = [] -# global_step = 0 -# attack_round = 0 - -# pbar = tqdm(range(self.iterations)) -# for i in pbar: -# if i % self.adv_prompt_update_step == 0: -# if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: -# self.retain_dataset.reset() -# word = random.choice(self.words) -# text_input = self.tokenizer( -# word, -# padding="max_length", -# max_length=self.tokenizer.model_max_length, -# return_tensors="pt", -# truncation=True -# ) -# text_embeddings = id2embedding( -# self.tokenizer, -# self.all_embeddings, -# text_input.input_ids.to(self.devices[0]), -# self.devices[0] -# ) -# emb_0 = self.model_orig.get_learned_conditioning(['']) -# emb_p = self.model_orig.get_learned_conditioning([word]) - -# if i >= self.warmup_iter: -# self.custom_text_encoder.text_encoder.eval() -# self.custom_text_encoder.text_encoder.requires_grad_(False) -# self.model.eval() -# if attack_round == 0: -# if self.attack_embd_type == 'word_embd': -# self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( -# global_step, word, self.model, self.model_orig, self.tokenizer, -# self.custom_text_encoder, self.sampler, emb_0, emb_p, -# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_prompt_num, -# self.all_embeddings, attack_round, self.attack_type, -# self.attack_embd_type, self.attack_step, self.attack_lr, -# self.attack_init, None, self.attack_method -# ) -# elif self.attack_embd_type == 'condition_embd': -# self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( -# global_step, word, self.model, self.model_orig, self.tokenizer, -# self.custom_text_encoder, self.sampler, emb_0, emb_p, -# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_prompt_num, -# self.all_embeddings, attack_round, self.attack_type, -# self.attack_embd_type, self.attack_step, self.attack_lr, -# self.attack_init, None, self.attack_method -# ) -# else: -# if self.attack_embd_type == 'word_embd': -# self.adv_word_embd, self.adv_input_ids = SoftPromptAttack.attack( -# global_step, word, self.model, self.model_orig, self.tokenizer, -# self.custom_text_encoder, self.sampler, emb_0, emb_p, -# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_prompt_num, -# self.all_embeddings, attack_round, self.attack_type, -# self.attack_embd_type, self.attack_step, self.attack_lr, -# self.attack_init, self.adv_word_embd, self.attack_method -# ) -# elif self.attack_embd_type == 'condition_embd': -# self.adv_condition_embd, self.adv_input_ids = SoftPromptAttack.attack( -# global_step, word, self.model, self.model_orig, self.tokenizer, -# self.custom_text_encoder, self.sampler, emb_0, emb_p, -# self.start_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_prompt_num, -# self.all_embeddings, attack_round, self.attack_type, -# self.attack_embd_type, self.attack_step, self.attack_lr, -# self.attack_init, self.adv_condition_embd, self.attack_method -# ) -# global_step += self.attack_step -# attack_round += 1 - -# if 'text_encoder' in self.train_method: -# self.custom_text_encoder.text_encoder.train() -# self.custom_text_encoder.text_encoder.requires_grad_(True) -# self.model.eval() -# else: -# self.custom_text_encoder.text_encoder.eval() -# self.custom_text_encoder.text_encoder.requires_grad_(False) -# self.model.train() - -# self.optimizer.zero_grad() - -# if self.retain_train == 'reg': -# retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) -# retain_text_input = self.tokenizer( -# retain_words, -# padding="max_length", -# max_length=self.tokenizer.model_max_length, -# return_tensors="pt", -# truncation=True -# ) -# retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) -# retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) -# retain_text_embeddings = id2embedding( -# self.tokenizer, -# self.all_embeddings, -# retain_text_input.input_ids.to(self.devices[0]), -# self.devices[0] -# ) -# retain_text_embeddings = retain_text_embeddings.reshape( -# self.retain_batch, -1, retain_text_embeddings.shape[-1] -# ) -# retain_emb_n = self.custom_text_encoder( -# input_ids=retain_input_ids, -# inputs_embeds=retain_text_embeddings -# )[0] -# else: -# retain_emb_p = None -# retain_emb_n = None - -# if i < self.warmup_iter: -# input_ids = text_input.input_ids.to(self.devices[0]) -# emb_n = self.custom_text_encoder( -# input_ids=input_ids, -# inputs_embeds=text_embeddings -# )[0] -# loss = get_train_loss_retain( -# self.retain_batch, self.retain_train, self.retain_loss_w, -# self.model, self.model_orig, self.custom_text_encoder, self.sampler, -# emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, -# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, input_ids, self.attack_embd_type -# ) -# else: -# if self.attack_embd_type == 'word_embd': -# loss = get_train_loss_retain( -# self.retain_batch, self.retain_train, self.retain_loss_w, -# self.model, self.model_orig, self.custom_text_encoder, self.sampler, -# emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, -# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, -# self.adv_word_embd -# ) -# elif self.attack_embd_type == 'condition_embd': -# loss = get_train_loss_retain( -# self.retain_batch, self.retain_train, self.retain_loss_w, -# self.model, self.model_orig, self.custom_text_encoder, self.sampler, -# emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, -# self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, -# self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, -# self.adv_condition_embd -# ) -# loss.backward() -# losses.append(loss.item()) -# pbar.set_postfix({"loss": loss.item()}) -# history.append(loss.item()) -# wandb.log({'Train_Loss': loss.item()}, step=global_step) -# wandb.log({'Attack_Loss': 0.0}, step=global_step) -# global_step += 1 -# self.optimizer.step() - -# if self.retain_train == 'iter': -# for r in range(self.retain_step): -# self.optimizer.zero_grad() -# if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: -# self.retain_dataset.reset() -# retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) -# t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) -# og_num = round((int(t_enc.item()) / self.ddim_steps) * 1000) -# og_num_lim = round(((int(t_enc.item()) + 1) / self.ddim_steps) * 1000) -# t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) -# retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) -# retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) -# retain_z = quick_sample_till_t( -# retain_emb_p.to(self.devices[0]), -# self.start_guidance, -# retain_start_code, -# self.retain_batch, -# int(t_enc.item()) -# ) -# retain_e_p = self.model_orig.apply_model( -# retain_z.to(self.devices[0]), -# t_enc_ddpm.to(self.devices[0]), -# retain_emb_p.to(self.devices[0]) -# ) -# retain_text_input = self.tokenizer( -# retain_words, -# padding="max_length", -# max_length=self.tokenizer.model_max_length, -# return_tensors="pt", -# truncation=True -# ) -# retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) -# retain_text_embeddings = id2embedding( -# self.tokenizer, -# self.all_embeddings, -# retain_text_input.input_ids.to(self.devices[0]), -# self.devices[0] -# ) -# retain_text_embeddings = retain_text_embeddings.reshape( -# self.retain_batch, -1, retain_text_embeddings.shape[-1] -# ) -# retain_emb_n = self.custom_text_encoder( -# input_ids=retain_input_ids, -# inputs_embeds=retain_text_embeddings -# )[0] -# retain_e_n = self.model.apply_model( -# retain_z.to(self.devices[0]), -# t_enc_ddpm.to(self.devices[0]), -# retain_emb_n.to(self.devices[0]) -# ) -# retain_loss = self.criteria( -# retain_e_n.to(self.devices[0]), -# retain_e_p.to(self.devices[0]) -# ) -# retain_loss.backward() -# self.optimizer.step() - -# if (i + 1) % self.config['save_interval'] == 0 and (i + 1) != self.iterations and (i + 1) >= self.config['save_interval']: -# if 'text_encoder' in self.train_method: -# save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) -# else: -# output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" -# self.model_loader.save_model(self.model, output_path) -# if i % 1 == 0: -# save_history(self.output_dir, losses, self.word_print) - -# self.model.eval() -# self.custom_text_encoder.text_encoder.eval() -# self.custom_text_encoder.text_encoder.requires_grad_(False) -# if 'text_encoder' in self.train_method: -# save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) -# else: -# output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" -# self.model_loader.save_model(self.model, output_path) -# save_history(self.output_dir, losses, self.word_print) -# return self.model - - -# mu_defense/algorithms/adv_unlearn/compvis_trainer.py - # mu_defense/algorithms/adv_unlearn/compvis_trainer.py import torch @@ -385,7 +16,8 @@ save_history, sample_model ) -# Note: We no longer import SoftPromptAttack directly. + + from mu_attack.execs.adv_attack import AdvAttack from mu_attack.configs.adv_unlearn import AdvAttackConfig from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler @@ -455,6 +87,39 @@ def __init__(self, model, config: dict, devices: list, **kwargs): self.logger = logging.getLogger(__name__) + attack_config = AdvAttackConfig( + prompt="", # prompt is no longer used in __init__ + encoder_model_name_or_path=self.tokenizer.name_or_path, + cache_path=config.get("cache_path", "./cache"), + devices=",".join([d.strip() for d in config.get("devices", "cuda:0").split(',')]), + attack_type=config['attack_type'], + attack_embd_type=config['attack_embd_type'], + attack_step=config['attack_step'], + attack_lr=config['attack_lr'], + attack_init=config['attack_init'], + attack_init_embd=None, # adjust as needed + attack_method=config['attack_method'], + ddim_steps=config['ddim_steps'], + ddim_eta=config['ddim_eta'], + image_size=config['image_size'], + adv_prompt_num=config['adv_prompt_num'], + start_guidance=config['start_guidance'], + config_path=config['model_config_path'], + compvis_ckpt_path=config.get("compvis_ckpt_path", ""), + backend="compvis", + diffusers_model_name_or_path="", + target_ckpt="", + project=config.get("project_name", "default_project"), + experiment_name=config.get("experiment_name", "default_experiment") + ) + self.adv_attack = AdvAttack(attack_config) + # Inject the preloaded objects + self.adv_attack.tokenizer = self.tokenizer + self.adv_attack.text_encoder = self.custom_text_encoder.text_encoder + self.adv_attack.custom_text_encoder = self.custom_text_encoder + self.adv_attack.all_embeddings = self.all_embeddings + + # Setup the dataset handler and prompt cleaning. self.dataset_handler = AdvUnlearnDatasetHandler( prompt=self.prompt, @@ -535,34 +200,7 @@ def train(self): self.custom_text_encoder.text_encoder.requires_grad_(False) self.model.eval() - # Build the attack configuration. (Adjust fields if necessary.) - attack_config = AdvAttackConfig( - prompt=word, - encoder_model_name_or_path=self.tokenizer.name_or_path, - cache_path=self.config.get("cache_path", "./cache"), - devices=",".join([d.strip() for d in self.config.get("devices", "cuda:0").split(',')]), - attack_type=self.attack_type, - attack_embd_type=self.attack_embd_type, - attack_step=self.attack_step, - attack_lr=self.attack_lr, - attack_init=self.attack_init, - attack_init_embd=None, # Set as needed - attack_method=self.attack_method, - ddim_steps=self.ddim_steps, - ddim_eta=ddim_eta, - image_size=self.image_size, - adv_prompt_num=self.adv_prompt_num, - start_guidance=self.start_guidance, - config_path=self.model_config_path, - compvis_ckpt_path=self.config.get("compvis_ckpt_path", ""), - backend="compvis", - diffusers_model_name_or_path="", - target_ckpt="", - project=self.config.get("project_name", "default_project"), - experiment_name=self.config.get("experiment_name", "default_experiment") - ) - adv_attack = AdvAttack(attack_config) - adv_word_embd, adv_input_ids = adv_attack.attack() + adv_word_embd, adv_input_ids = self.adv_attack.attack(word, global_step, attack_round) if self.attack_embd_type == 'word_embd': self.adv_word_embd, self.adv_input_ids = adv_word_embd, adv_input_ids From 400f48dbf20995d140b5ee2e04db45e966e1982d Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 12 Feb 2025 06:02:39 +0000 Subject: [PATCH 21/22] prompt data added --- .../train/coco_object_no_filter_retain.csv | 244 ++++++++++++++++++ data/prompts/train/coco_object_retain.csv | 244 ++++++++++++++++++ .../train/imagenet243_no_filter_retain.csv | 244 ++++++++++++++++++ data/prompts/train/imagenet243_retain.csv | 244 ++++++++++++++++++ 4 files changed, 976 insertions(+) create mode 100644 data/prompts/train/coco_object_no_filter_retain.csv create mode 100644 data/prompts/train/coco_object_retain.csv create mode 100644 data/prompts/train/imagenet243_no_filter_retain.csv create mode 100644 data/prompts/train/imagenet243_retain.csv diff --git a/data/prompts/train/coco_object_no_filter_retain.csv b/data/prompts/train/coco_object_no_filter_retain.csv new file mode 100644 index 00000000..1c15a6f5 --- /dev/null +++ b/data/prompts/train/coco_object_no_filter_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,coco_object,a photo of chair +2,coco_object,a photo of fridge +3,coco_object,a photo of banana +4,coco_object,a photo of street sign +5,coco_object,a photo of headlights +6,coco_object,a photo of shorts +7,coco_object,a photo of handbag +8,coco_object,a photo of skis +9,coco_object,a photo of skateboard +10,coco_object,a photo of chopping board +11,coco_object,a photo of goat +12,coco_object,a photo of playing cards +13,coco_object,a photo of underpants +14,coco_object,a photo of toy cars +15,coco_object,a photo of super hero costume +16,coco_object,a photo of pasta +17,coco_object,a photo of moon +18,coco_object,a photo of basketball +19,coco_object,a photo of radio +20,coco_object,a photo of ipad +21,coco_object,a photo of goldfish +22,coco_object,a photo of jetpack +23,coco_object,a photo of pajamas +24,coco_object,a photo of couch +25,coco_object,a photo of microwave +26,coco_object,a photo of bread +27,coco_object,a photo of umbrella +28,coco_object,a photo of window +29,coco_object,a photo of teddy bear +30,coco_object,a photo of pans +31,coco_object,a photo of hot dog +32,coco_object,a photo of snowboard +33,coco_object,a photo of helicopter +34,coco_object,a photo of washer +35,coco_object,a photo of magazine +36,coco_object,a photo of shirt +37,coco_object,a photo of phone +38,coco_object,a photo of towel +39,coco_object,a photo of necklace +40,coco_object,a photo of bracelet +41,coco_object,a photo of platypus +42,coco_object,a photo of feet +43,coco_object,a photo of road +44,coco_object,a photo of telephone +45,coco_object,a photo of fences +46,coco_object,a photo of aardvark +47,coco_object,a photo of iphone +48,coco_object,a photo of robot +49,coco_object,a photo of car +50,coco_object,a photo of potted plant +51,coco_object,a photo of sink +52,coco_object,a photo of apple +53,coco_object,a photo of scissors +54,coco_object,a photo of legs +55,coco_object,a photo of desk +56,coco_object,a photo of tie +57,coco_object,a photo of stapler +58,coco_object,a photo of table +59,coco_object,a photo of armpits +60,coco_object,a photo of tomato +61,coco_object,a photo of lion +62,coco_object,a photo of key +63,coco_object,a photo of Pig +64,coco_object,a photo of hyppo +65,coco_object,a photo of tablet +66,coco_object,a photo of arms +67,coco_object,a photo of pancake +68,coco_object,a photo of shark +69,coco_object,a photo of fountain +70,coco_object,a photo of movie +71,coco_object,a photo of goal net +72,coco_object,a photo of dinosaur +73,coco_object,a photo of hoop +74,coco_object,a photo of crusher +75,coco_object,a photo of motorcycle +76,coco_object,a photo of tv +77,coco_object,a photo of torso +78,coco_object,a photo of book +79,coco_object,a photo of short sleeve shirt +80,coco_object,a photo of fire hydrant +81,coco_object,a photo of computer +82,coco_object,a photo of stop sign +83,coco_object,a photo of sports ball +84,coco_object,a photo of basketball +85,coco_object,a photo of hoop +86,coco_object,a photo of pants +87,coco_object,a photo of tree +88,coco_object,a photo of bunny +89,coco_object,a photo of frame +90,coco_object,a photo of strawberries +91,coco_object,a photo of fingers +92,coco_object,a photo of corn +93,coco_object,a photo of balloon +94,coco_object,a photo of back +95,coco_object,a photo of swan +96,coco_object,a photo of fax machine +97,coco_object,a photo of head +98,coco_object,a photo of toys +99,coco_object,a photo of unicycle +100,coco_object,a photo of hen +101,coco_object,a photo of animal crackers +102,coco_object,a photo of bird +103,coco_object,a photo of cow +104,coco_object,a photo of toaster +105,coco_object,a photo of boat +106,coco_object,a photo of backpack +107,coco_object,a photo of traffic light +108,coco_object,a photo of hand +109,coco_object,a photo of refrigerator +110,coco_object,a photo of surfboard +111,coco_object,a photo of broccoli +112,coco_object,a photo of mouth +113,coco_object,a photo of door handle +114,coco_object,a photo of hair brush +115,coco_object,a photo of cupcake +116,coco_object,a photo of pumpkin +117,coco_object,a photo of dollar bill +118,coco_object,a photo of ladder +119,coco_object,a photo of ears +120,coco_object,a photo of whale +121,coco_object,a photo of bat +122,coco_object,a photo of goose +123,coco_object,a photo of engine +124,coco_object,a photo of nose +125,coco_object,a photo of basketball court +126,coco_object,a photo of cat +127,coco_object,a photo of airplane +128,coco_object,a photo of bus +129,coco_object,a photo of plate +130,coco_object,a photo of steering wheel +131,coco_object,a photo of eyeglasses +132,coco_object,a photo of teapot +133,coco_object,a photo of pizza +134,coco_object,a photo of sandwich +135,coco_object,a photo of suitcase +136,coco_object,a photo of vase +137,coco_object,a photo of power +138,coco_object,a photo of face +139,coco_object,a photo of pillow +140,coco_object,a photo of light switch +141,coco_object,a photo of eye +142,coco_object,a photo of van +143,coco_object,a photo of doll +144,coco_object,a photo of pineapple +145,coco_object,a photo of milk +146,coco_object,a photo of dryer +147,coco_object,a photo of towel +148,coco_object,a photo of hot air balloon +149,coco_object,a photo of soccer ball +150,coco_object,a photo of legos +151,coco_object,a photo of table cloth +152,coco_object,a photo of horn +153,coco_object,a photo of dog +154,coco_object,a photo of hat +155,coco_object,a photo of train +156,coco_object,a photo of cell phone +157,coco_object,a photo of wine glass +158,coco_object,a photo of cup +159,coco_object,a photo of fork +160,coco_object,a photo of squirrel +161,coco_object,a photo of pen +162,coco_object,a photo of carrot +163,coco_object,a photo of baseball bat +164,coco_object,a photo of tennis racket +165,coco_object,a photo of frogs +166,coco_object,a photo of kangaroo +167,coco_object,a photo of soup +168,coco_object,a photo of candle +169,coco_object,a photo of side table +170,coco_object,a photo of cereal +171,coco_object,a photo of field goal posts +172,coco_object,a photo of fly +173,coco_object,a photo of soccer nets +174,coco_object,a photo of firefly +175,coco_object,a photo of horse +176,coco_object,a photo of license plate +177,coco_object,a photo of mirror +178,coco_object,a photo of mouse +179,coco_object,a photo of chicken +180,coco_object,a photo of blender +181,coco_object,a photo of knife +182,coco_object,a photo of duck +183,coco_object,a photo of kite +184,coco_object,a photo of chandelier +185,coco_object,a photo of baseball glove +186,coco_object,a photo of tiger +187,coco_object,a photo of cake +188,coco_object,a photo of rhinoceros +189,coco_object,a photo of meat +190,coco_object,a photo of desktop +191,coco_object,a photo of wheelchair +192,coco_object,a photo of lizard +193,coco_object,a photo of gate +194,coco_object,a photo of seahorse +195,coco_object,a photo of raft +196,coco_object,a photo of roof +197,coco_object,a photo of turkey +198,coco_object,a photo of sheep +199,coco_object,a photo of bed +200,coco_object,a photo of dining table +201,coco_object,a photo of remote +202,coco_object,a photo of zebra +203,coco_object,a photo of hair drier +204,coco_object,a photo of spoon +205,coco_object,a photo of frisbee +206,coco_object,a photo of orange +207,coco_object,a photo of parking meter +208,coco_object,a photo of giraffe +209,coco_object,a photo of table +210,coco_object,a photo of house +211,coco_object,a photo of owl +212,coco_object,a photo of sailboat +213,coco_object,a photo of window +214,coco_object,a photo of carpet +215,coco_object,a photo of building +216,coco_object,a photo of beans +217,coco_object,a photo of rocket +218,coco_object,a photo of rooster +219,coco_object,a photo of tennis net +220,coco_object,a photo of baseball +221,coco_object,a photo of nectar +222,coco_object,a photo of bottle +223,coco_object,a photo of laptop +224,coco_object,a photo of elephant +225,coco_object,a photo of clock +226,coco_object,a photo of wheel +227,coco_object,a photo of bear +228,coco_object,a photo of guitar +229,coco_object,a photo of toothbrush +230,coco_object,a photo of fish +231,coco_object,a photo of jacket +232,coco_object,a photo of coffee table +233,coco_object,a photo of bench +234,coco_object,a photo of cheese +235,coco_object,a photo of scarf +236,coco_object,a photo of deer +237,coco_object,a photo of muffins +238,coco_object,a photo of cookie +239,coco_object,a photo of bacon +240,coco_object,a photo of cabinets +241,coco_object,a photo of copier +242,coco_object,a photo of seats +243,coco_object,a photo of mat diff --git a/data/prompts/train/coco_object_retain.csv b/data/prompts/train/coco_object_retain.csv new file mode 100644 index 00000000..5e05a2a9 --- /dev/null +++ b/data/prompts/train/coco_object_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,coco_object,a photo of chair +2,coco_object,a photo of fridge +3,coco_object,a photo of banana +4,coco_object,a photo of street sign +5,coco_object,a photo of headlights +6,coco_object,a photo of printer +7,coco_object,a photo of handbag +8,coco_object,a photo of skis +9,coco_object,a photo of skateboard +10,coco_object,a photo of chopping board +11,coco_object,a photo of goat +12,coco_object,a photo of playing cards +13,coco_object,a photo of tire +14,coco_object,a photo of toy cars +15,coco_object,a photo of box +16,coco_object,a photo of pasta +17,coco_object,a photo of moon +18,coco_object,a photo of basketball +19,coco_object,a photo of radio +20,coco_object,a photo of ipad +21,coco_object,a photo of goldfish +22,coco_object,a photo of jetpack +23,coco_object,a photo of bicycle +24,coco_object,a photo of couch +25,coco_object,a photo of microwave +26,coco_object,a photo of bread +27,coco_object,a photo of umbrella +28,coco_object,a photo of window +29,coco_object,a photo of teddy bear +30,coco_object,a photo of pans +31,coco_object,a photo of hot dog +32,coco_object,a photo of snowboard +33,coco_object,a photo of helicopter +34,coco_object,a photo of washer +35,coco_object,a photo of magazine +36,coco_object,a photo of home +37,coco_object,a photo of phone +38,coco_object,a photo of towel +39,coco_object,a photo of necklace +40,coco_object,a photo of bracelet +41,coco_object,a photo of platypus +42,coco_object,a photo of grapes +43,coco_object,a photo of road +44,coco_object,a photo of telephone +45,coco_object,a photo of fences +46,coco_object,a photo of aardvark +47,coco_object,a photo of iphone +48,coco_object,a photo of robot +49,coco_object,a photo of car +50,coco_object,a photo of potted plant +51,coco_object,a photo of sink +52,coco_object,a photo of apple +53,coco_object,a photo of scissors +54,coco_object,a photo of door +55,coco_object,a photo of desk +56,coco_object,a photo of tie +57,coco_object,a photo of stapler +58,coco_object,a photo of table +59,coco_object,a photo of lamp +60,coco_object,a photo of tomato +61,coco_object,a photo of lion +62,coco_object,a photo of key +63,coco_object,a photo of Pig +64,coco_object,a photo of hyppo +65,coco_object,a photo of tablet +66,coco_object,a photo of bat +67,coco_object,a photo of pancake +68,coco_object,a photo of shark +69,coco_object,a photo of fountain +70,coco_object,a photo of movie +71,coco_object,a photo of goal net +72,coco_object,a photo of dinosaur +73,coco_object,a photo of hoop +74,coco_object,a photo of crusher +75,coco_object,a photo of motorcycle +76,coco_object,a photo of tv +77,coco_object,a photo of oven +78,coco_object,a photo of book +79,coco_object,a photo of keyboard +80,coco_object,a photo of fire hydrant +81,coco_object,a photo of computer +82,coco_object,a photo of stop sign +83,coco_object,a photo of sports ball +84,coco_object,a photo of basketball +85,coco_object,a photo of hoop +86,coco_object,a photo of egg +87,coco_object,a photo of tree +88,coco_object,a photo of monkey +89,coco_object,a photo of frame +90,coco_object,a photo of strawberries +91,coco_object,a photo of can +92,coco_object,a photo of corn +93,coco_object,a photo of balloon +94,coco_object,a photo of cabinet +95,coco_object,a photo of swan +96,coco_object,a photo of fax machine +97,coco_object,a photo of football +98,coco_object,a photo of toys +99,coco_object,a photo of unicycle +100,coco_object,a photo of hen +101,coco_object,a photo of animal crackers +102,coco_object,a photo of bird +103,coco_object,a photo of cow +104,coco_object,a photo of toaster +105,coco_object,a photo of boat +106,coco_object,a photo of backpack +107,coco_object,a photo of traffic light +108,coco_object,a photo of bowl +109,coco_object,a photo of refrigerator +110,coco_object,a photo of surfboard +111,coco_object,a photo of broccoli +112,coco_object,a photo of donut +113,coco_object,a photo of door handle +114,coco_object,a photo of hair brush +115,coco_object,a photo of cupcake +116,coco_object,a photo of pumpkin +117,coco_object,a photo of dollar bill +118,coco_object,a photo of ladder +119,coco_object,a photo of gloves +120,coco_object,a photo of whale +121,coco_object,a photo of bat +122,coco_object,a photo of goose +123,coco_object,a photo of engine +124,coco_object,a photo of honey +125,coco_object,a photo of basketball court +126,coco_object,a photo of cat +127,coco_object,a photo of airplane +128,coco_object,a photo of bus +129,coco_object,a photo of plate +130,coco_object,a photo of steering wheel +131,coco_object,a photo of eyeglasses +132,coco_object,a photo of teapot +133,coco_object,a photo of pizza +134,coco_object,a photo of sandwich +135,coco_object,a photo of suitcase +136,coco_object,a photo of vase +137,coco_object,a photo of power +138,coco_object,a photo of outlet +139,coco_object,a photo of pillow +140,coco_object,a photo of light switch +141,coco_object,a photo of fan +142,coco_object,a photo of van +143,coco_object,a photo of doll +144,coco_object,a photo of pineapple +145,coco_object,a photo of milk +146,coco_object,a photo of dryer +147,coco_object,a photo of towel +148,coco_object,a photo of hot air balloon +149,coco_object,a photo of soccer ball +150,coco_object,a photo of legos +151,coco_object,a photo of table cloth +152,coco_object,a photo of horn +153,coco_object,a photo of dog +154,coco_object,a photo of hat +155,coco_object,a photo of train +156,coco_object,a photo of cell phone +157,coco_object,a photo of wine glass +158,coco_object,a photo of cup +159,coco_object,a photo of fork +160,coco_object,a photo of squirrel +161,coco_object,a photo of pen +162,coco_object,a photo of carrot +163,coco_object,a photo of baseball bat +164,coco_object,a photo of tennis racket +165,coco_object,a photo of frogs +166,coco_object,a photo of kangaroo +167,coco_object,a photo of soup +168,coco_object,a photo of candle +169,coco_object,a photo of side table +170,coco_object,a photo of cereal +171,coco_object,a photo of field goal posts +172,coco_object,a photo of fly +173,coco_object,a photo of soccer nets +174,coco_object,a photo of firefly +175,coco_object,a photo of horse +176,coco_object,a photo of license plate +177,coco_object,a photo of mirror +178,coco_object,a photo of mouse +179,coco_object,a photo of chicken +180,coco_object,a photo of blender +181,coco_object,a photo of knife +182,coco_object,a photo of duck +183,coco_object,a photo of kite +184,coco_object,a photo of chandelier +185,coco_object,a photo of baseball glove +186,coco_object,a photo of tiger +187,coco_object,a photo of cake +188,coco_object,a photo of rhinoceros +189,coco_object,a photo of meat +190,coco_object,a photo of desktop +191,coco_object,a photo of wheelchair +192,coco_object,a photo of lizard +193,coco_object,a photo of gate +194,coco_object,a photo of seahorse +195,coco_object,a photo of raft +196,coco_object,a photo of roof +197,coco_object,a photo of turkey +198,coco_object,a photo of sheep +199,coco_object,a photo of bed +200,coco_object,a photo of dining table +201,coco_object,a photo of remote +202,coco_object,a photo of zebra +203,coco_object,a photo of hair drier +204,coco_object,a photo of spoon +205,coco_object,a photo of frisbee +206,coco_object,a photo of orange +207,coco_object,a photo of parking meter +208,coco_object,a photo of giraffe +209,coco_object,a photo of table +210,coco_object,a photo of house +211,coco_object,a photo of owl +212,coco_object,a photo of sailboat +213,coco_object,a photo of window +214,coco_object,a photo of carpet +215,coco_object,a photo of building +216,coco_object,a photo of beans +217,coco_object,a photo of rocket +218,coco_object,a photo of rooster +219,coco_object,a photo of tennis net +220,coco_object,a photo of baseball +221,coco_object,a photo of nectar +222,coco_object,a photo of bottle +223,coco_object,a photo of laptop +224,coco_object,a photo of elephant +225,coco_object,a photo of clock +226,coco_object,a photo of wheel +227,coco_object,a photo of bear +228,coco_object,a photo of guitar +229,coco_object,a photo of toothbrush +230,coco_object,a photo of fish +231,coco_object,a photo of jacket +232,coco_object,a photo of coffee table +233,coco_object,a photo of bench +234,coco_object,a photo of cheese +235,coco_object,a photo of scarf +236,coco_object,a photo of deer +237,coco_object,a photo of muffins +238,coco_object,a photo of cookie +239,coco_object,a photo of bacon +240,coco_object,a photo of cabinets +241,coco_object,a photo of copier +242,coco_object,a photo of seats +243,coco_object,a photo of mat diff --git a/data/prompts/train/imagenet243_no_filter_retain.csv b/data/prompts/train/imagenet243_no_filter_retain.csv new file mode 100644 index 00000000..d638bacd --- /dev/null +++ b/data/prompts/train/imagenet243_no_filter_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,imagenet,a photo of strawberry +2,imagenet,a photo of pedestal +3,imagenet,a photo of scoreboard +4,imagenet,a photo of jaguar +5,imagenet,a photo of ear +6,imagenet,a photo of hummingbird +7,imagenet,a photo of tobacco shop +8,imagenet,a photo of Greater Swiss Mountain dog +9,imagenet,a photo of wine bottle +10,imagenet,a photo of yellow lady-slipper +11,imagenet,a photo of ballpoint +12,imagenet,a photo of Irish water spaniel +13,imagenet,a photo of barn +14,imagenet,a photo of home theater +15,imagenet,a photo of walking stick +16,imagenet,a photo of notebook +17,imagenet,a photo of syringe +18,imagenet,a photo of mask +19,imagenet,a photo of nipple +20,imagenet,a photo of volleyball +21,imagenet,a photo of vulture +22,imagenet,a photo of cloak +23,imagenet,a photo of whiskey jug +24,imagenet,a photo of church +25,imagenet,a photo of bolo tie +26,imagenet,a photo of toy terrier +27,imagenet,a photo of lionfish +28,imagenet,a photo of Bouvier des Flandres +29,imagenet,a photo of photocopier +30,imagenet,a photo of teddy +31,imagenet,a photo of lighter +32,imagenet,a photo of horizontal bar +33,imagenet,a photo of magpie +34,imagenet,a photo of tiger shark +35,imagenet,a photo of wall clock +36,imagenet,a photo of leaf beetle +37,imagenet,a photo of stole +38,imagenet,a photo of basenji +39,imagenet,a photo of tricycle +40,imagenet,a photo of sports car +41,imagenet,a photo of green mamba +42,imagenet,a photo of shopping cart +43,imagenet,a photo of dining table +44,imagenet,a photo of custard apple +45,imagenet,a photo of jackfruit +46,imagenet,a photo of cellular telephone +47,imagenet,a photo of sleeping bag +48,imagenet,a photo of reflex camera +49,imagenet,a photo of beacon +50,imagenet,a photo of bikini +51,imagenet,a photo of dowitcher +52,imagenet,a photo of abacus +53,imagenet,a photo of miniskirt +54,imagenet,a photo of coil +55,imagenet,a photo of lacewing +56,imagenet,a photo of lumbermill +57,imagenet,a photo of white stork +58,imagenet,a photo of parallel bars +59,imagenet,a photo of sliding door +60,imagenet,a photo of lawn mower +61,imagenet,a photo of scuba diver +62,imagenet,a photo of cardigan +63,imagenet,a photo of American coot +64,imagenet,a photo of Border terrier +65,imagenet,a photo of purse +66,imagenet,a photo of gown +67,imagenet,a photo of megalith +68,imagenet,a photo of Polaroid camera +69,imagenet,a photo of green snake +70,imagenet,a photo of guillotine +71,imagenet,a photo of cricket +72,imagenet,a photo of academic gown +73,imagenet,a photo of can opener +74,imagenet,a photo of colobus +75,imagenet,a photo of hip +76,imagenet,a photo of bathtub +77,imagenet,a photo of Norwich terrier +78,imagenet,a photo of Arabian camel +79,imagenet,a photo of Labrador retriever +80,imagenet,a photo of hognose snake +81,imagenet,a photo of overskirt +82,imagenet,a photo of garter snake +83,imagenet,a photo of giant panda +84,imagenet,a photo of Lhasa +85,imagenet,a photo of folding chair +86,imagenet,a photo of lycaenid +87,imagenet,a photo of swimsuit +88,imagenet,a photo of crayfish +89,imagenet,a photo of balance beam +90,imagenet,a photo of junco +91,imagenet,a photo of Christmas stocking +92,imagenet,a photo of quill +93,imagenet,a photo of conch +94,imagenet,a photo of shield +95,imagenet,a photo of trailer truck +96,imagenet,a photo of wooden spoon +97,imagenet,a photo of mountain tent +98,imagenet,a photo of guinea pig +99,imagenet,a photo of tow truck +100,imagenet,a photo of bloodhound +101,imagenet,a photo of rifle +102,imagenet,a photo of grand piano +103,imagenet,a photo of schooner +104,imagenet,a photo of prison +105,imagenet,a photo of Great Pyrenees +106,imagenet,a photo of brain coral +107,imagenet,a photo of nail +108,imagenet,a photo of meat loaf +109,imagenet,a photo of Bedlington terrier +110,imagenet,a photo of steam locomotive +111,imagenet,a photo of crutch +112,imagenet,a photo of Sussex spaniel +113,imagenet,a photo of Great Dane +114,imagenet,a photo of frying pan +115,imagenet,a photo of Tibetan terrier +116,imagenet,a photo of ostrich +117,imagenet,a photo of lampshade +118,imagenet,a photo of standard poodle +119,imagenet,a photo of rock python +120,imagenet,a photo of sunglass +121,imagenet,a photo of plow +122,imagenet,a photo of great grey owl +123,imagenet,a photo of macaque +124,imagenet,a photo of spoonbill +125,imagenet,a photo of jay +126,imagenet,a photo of bookshop +127,imagenet,a photo of quail +128,imagenet,a photo of hyena +129,imagenet,a photo of bee eater +130,imagenet,a photo of croquet ball +131,imagenet,a photo of cabbage butterfly +132,imagenet,a photo of electric fan +133,imagenet,a photo of slug +134,imagenet,a photo of rapeseed +135,imagenet,a photo of worm fence +136,imagenet,a photo of chambered nautilus +137,imagenet,a photo of Windsor tie +138,imagenet,a photo of paintbrush +139,imagenet,a photo of marimba +140,imagenet,a photo of common iguana +141,imagenet,a photo of dial telephone +142,imagenet,a photo of space shuttle +143,imagenet,a photo of hippopotamus +144,imagenet,a photo of cinema +145,imagenet,a photo of cockroach +146,imagenet,a photo of accordion +147,imagenet,a photo of cello +148,imagenet,a photo of water bottle +149,imagenet,a photo of honeycomb +150,imagenet,a photo of bagel +151,imagenet,a photo of lipstick +152,imagenet,a photo of black stork +153,imagenet,a photo of eggnog +154,imagenet,a photo of lorikeet +155,imagenet,a photo of flatworm +156,imagenet,a photo of container ship +157,imagenet,a photo of Egyptian cat +158,imagenet,a photo of miniature pinscher +159,imagenet,a photo of minibus +160,imagenet,a photo of suspension bridge +161,imagenet,a photo of house finch +162,imagenet,a photo of safety pin +163,imagenet,a photo of malamute +164,imagenet,a photo of gibbon +165,imagenet,a photo of lesser panda +166,imagenet,a photo of plunger +167,imagenet,a photo of greenhouse +168,imagenet,a photo of black grouse +169,imagenet,a photo of disk brake +170,imagenet,a photo of tennis ball +171,imagenet,a photo of digital clock +172,imagenet,a photo of cassette +173,imagenet,a photo of streetcar +174,imagenet,a photo of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of bra +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of ballet shoe +211,imagenet,a photo of cannon +212,imagenet,a photo of ballet skirt +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of maillot +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of hand-held computer +226,imagenet,a photo of otter +227,imagenet,a photo of bullet train +228,imagenet,a photo of kit fox +229,imagenet,a photo of typewriter keyboard +230,imagenet,a photo of catamaran +231,imagenet,a photo of ashcan +232,imagenet,a photo of scale +233,imagenet,a photo of pineapple +234,imagenet,a photo of dishrag +235,imagenet,a photo of fountain pen +236,imagenet,a photo of comic book +237,imagenet,a photo of piggy bank +238,imagenet,a photo of water jug +239,imagenet,a photo of electric locomotive +240,imagenet,a photo of gorilla +241,imagenet,a photo of racket +242,imagenet,a photo of binoculars +243,imagenet,a photo of holster diff --git a/data/prompts/train/imagenet243_retain.csv b/data/prompts/train/imagenet243_retain.csv new file mode 100644 index 00000000..912e619b --- /dev/null +++ b/data/prompts/train/imagenet243_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,imagenet,a photo of strawberry +2,imagenet,a photo of pedestal +3,imagenet,a photo of scoreboard +4,imagenet,a photo of jaguar +5,imagenet,a photo of stove +6,imagenet,a photo of hummingbird +7,imagenet,a photo of tobacco shop +8,imagenet,a photo of Greater Swiss Mountain dog +9,imagenet,a photo of wine bottle +10,imagenet,a photo of yellow lady-slipper +11,imagenet,a photo of ballpoint +12,imagenet,a photo of Irish water spaniel +13,imagenet,a photo of barn +14,imagenet,a photo of home theater +15,imagenet,a photo of walking stick +16,imagenet,a photo of notebook +17,imagenet,a photo of syringe +18,imagenet,a photo of mask +19,imagenet,a photo of nipple +20,imagenet,a photo of volleyball +21,imagenet,a photo of vulture +22,imagenet,a photo of cloak +23,imagenet,a photo of whiskey jug +24,imagenet,a photo of church +25,imagenet,a photo of bolo tie +26,imagenet,a photo of toy terrier +27,imagenet,a photo of lionfish +28,imagenet,a photo of Bouvier des Flandres +29,imagenet,a photo of photocopier +30,imagenet,a photo of teddy +31,imagenet,a photo of lighter +32,imagenet,a photo of horizontal bar +33,imagenet,a photo of magpie +34,imagenet,a photo of tiger shark +35,imagenet,a photo of wall clock +36,imagenet,a photo of leaf beetle +37,imagenet,a photo of stole +38,imagenet,a photo of basenji +39,imagenet,a photo of tricycle +40,imagenet,a photo of sports car +41,imagenet,a photo of green mamba +42,imagenet,a photo of shopping cart +43,imagenet,a photo of dining table +44,imagenet,a photo of custard apple +45,imagenet,a photo of jackfruit +46,imagenet,a photo of cellular telephone +47,imagenet,a photo of sleeping bag +48,imagenet,a photo of reflex camera +49,imagenet,a photo of beacon +50,imagenet,a photo of safe +51,imagenet,a photo of dowitcher +52,imagenet,a photo of abacus +53,imagenet,a photo of koala +54,imagenet,a photo of coil +55,imagenet,a photo of lacewing +56,imagenet,a photo of lumbermill +57,imagenet,a photo of white stork +58,imagenet,a photo of parallel bars +59,imagenet,a photo of sliding door +60,imagenet,a photo of lawn mower +61,imagenet,a photo of wolf spider +62,imagenet,a photo of cardigan +63,imagenet,a photo of American coot +64,imagenet,a photo of Border terrier +65,imagenet,a photo of purse +66,imagenet,a photo of hotdog +67,imagenet,a photo of megalith +68,imagenet,a photo of Polaroid camera +69,imagenet,a photo of green snake +70,imagenet,a photo of guillotine +71,imagenet,a photo of cricket +72,imagenet,a photo of academic gown +73,imagenet,a photo of can opener +74,imagenet,a photo of colobus +75,imagenet,a photo of tree frog +76,imagenet,a photo of bathtub +77,imagenet,a photo of Norwich terrier +78,imagenet,a photo of Arabian camel +79,imagenet,a photo of Labrador retriever +80,imagenet,a photo of hognose snake +81,imagenet,a photo of overskirt +82,imagenet,a photo of garter snake +83,imagenet,a photo of giant panda +84,imagenet,a photo of Lhasa +85,imagenet,a photo of folding chair +86,imagenet,a photo of lycaenid +87,imagenet,a photo of plate +88,imagenet,a photo of crayfish +89,imagenet,a photo of balance beam +90,imagenet,a photo of junco +91,imagenet,a photo of Christmas stocking +92,imagenet,a photo of quill +93,imagenet,a photo of conch +94,imagenet,a photo of shield +95,imagenet,a photo of trailer truck +96,imagenet,a photo of wooden spoon +97,imagenet,a photo of mountain tent +98,imagenet,a photo of guinea pig +99,imagenet,a photo of tow truck +100,imagenet,a photo of bloodhound +101,imagenet,a photo of rifle +102,imagenet,a photo of grand piano +103,imagenet,a photo of schooner +104,imagenet,a photo of prison +105,imagenet,a photo of Great Pyrenees +106,imagenet,a photo of brain coral +107,imagenet,a photo of snail +108,imagenet,a photo of meat loaf +109,imagenet,a photo of Bedlington terrier +110,imagenet,a photo of steam locomotive +111,imagenet,a photo of crutch +112,imagenet,a photo of Sussex spaniel +113,imagenet,a photo of Great Dane +114,imagenet,a photo of frying pan +115,imagenet,a photo of Tibetan terrier +116,imagenet,a photo of ostrich +117,imagenet,a photo of lampshade +118,imagenet,a photo of standard poodle +119,imagenet,a photo of rock python +120,imagenet,a photo of sunglass +121,imagenet,a photo of plow +122,imagenet,a photo of great grey owl +123,imagenet,a photo of macaque +124,imagenet,a photo of spoonbill +125,imagenet,a photo of jay +126,imagenet,a photo of bookshop +127,imagenet,a photo of quail +128,imagenet,a photo of hyena +129,imagenet,a photo of bee eater +130,imagenet,a photo of croquet ball +131,imagenet,a photo of cabbage butterfly +132,imagenet,a photo of electric fan +133,imagenet,a photo of slug +134,imagenet,a photo of rapeseed +135,imagenet,a photo of worm fence +136,imagenet,a photo of chambered nautilus +137,imagenet,a photo of Windsor tie +138,imagenet,a photo of paintbrush +139,imagenet,a photo of marimba +140,imagenet,a photo of common iguana +141,imagenet,a photo of dial telephone +142,imagenet,a photo of space shuttle +143,imagenet,a photo of hippopotamus +144,imagenet,a photo of cinema +145,imagenet,a photo of cockroach +146,imagenet,a photo of accordion +147,imagenet,a photo of cello +148,imagenet,a photo of water bottle +149,imagenet,a photo of honeycomb +150,imagenet,a photo of bagel +151,imagenet,a photo of vase +152,imagenet,a photo of black stork +153,imagenet,a photo of eggnog +154,imagenet,a photo of lorikeet +155,imagenet,a photo of flatworm +156,imagenet,a photo of container ship +157,imagenet,a photo of Egyptian cat +158,imagenet,a photo of miniature pinscher +159,imagenet,a photo of minibus +160,imagenet,a photo of suspension bridge +161,imagenet,a photo of house finch +162,imagenet,a photo of safety pin +163,imagenet,a photo of malamute +164,imagenet,a photo of gibbon +165,imagenet,a photo of lesser panda +166,imagenet,a photo of plunger +167,imagenet,a photo of greenhouse +168,imagenet,a photo of black grouse +169,imagenet,a photo of disk brake +170,imagenet,a photo of jeep +171,imagenet,a photo of digital clock +172,imagenet,a photo of cassette +173,imagenet,a photo of streetcar +174,imagenet,a photo of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of pot +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of pillow +211,imagenet,a photo of cannon +212,imagenet,a photo of jean +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of washer +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of hand-held computer +226,imagenet,a photo of otter +227,imagenet,a photo of bullet train +228,imagenet,a photo of kit fox +229,imagenet,a photo of typewriter keyboard +230,imagenet,a photo of catamaran +231,imagenet,a photo of ashcan +232,imagenet,a photo of scale +233,imagenet,a photo of pineapple +234,imagenet,a photo of dishrag +235,imagenet,a photo of fountain pen +236,imagenet,a photo of comic book +237,imagenet,a photo of piggy bank +238,imagenet,a photo of water jug +239,imagenet,a photo of electric locomotive +240,imagenet,a photo of gorilla +241,imagenet,a photo of racket +242,imagenet,a photo of binoculars +243,imagenet,a photo of holster From 634b875e6c3de4519f2b82b34a7acc260a4cb4ef Mon Sep 17 00:00:00 2001 From: Nebula Anish Date: Wed, 12 Feb 2025 08:08:37 +0000 Subject: [PATCH 22/22] fix: use already loaded models in trainer instead of loading again --- mu_attack/execs/adv_attack.py | 57 +++++++++++-------- .../algorithms/adv_unlearn/algorithm.py | 10 ++-- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py index f8758b4f..bb32924a 100644 --- a/mu_attack/execs/adv_attack.py +++ b/mu_attack/execs/adv_attack.py @@ -14,7 +14,7 @@ def __init__(self, config: AdvAttackConfig): # Do not set self.prompt from the config; remove the dependency. self.encoder_model_name_or_path = config.encoder_model_name_or_path self.cache_path = config.cache_path - self.devices = [f'cuda:{int(d.strip())}' for d in config.devices.split(',')] + self.devices = [f"cuda:{int(d.strip())}" for d in config.devices.split(",")] self.attack_type = config.attack_type self.attack_embd_type = config.attack_embd_type self.attack_step = config.attack_step @@ -36,32 +36,34 @@ def __init__(self, config: AdvAttackConfig): # Initialize wandb (if needed) wandb.init( - project=config.project_name, - name=config.experiment_name, - reinit=True + project=config.project_name, name=config.experiment_name, reinit=True ) - self.load_models() + # self.load_models() def load_models(self): if self.backend == "compvis": - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( - self.config_path, self.compvis_ckpt_path, self.devices + self.model_orig, self.sampler_orig, self.model, self.sampler = ( + get_models_for_compvis( + self.config_path, self.compvis_ckpt_path, self.devices + ) ) elif self.backend == "diffusers": - self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( - self.diffusers_model_name_or_path, self.target_ckpt, self.devices + self.model_orig, self.sampler_orig, self.model, self.sampler = ( + get_models_for_diffusers( + self.diffusers_model_name_or_path, self.target_ckpt, self.devices + ) ) def attack(self, word, global_step, attack_round): """ Perform the adversarial attack using the given word. - + Args: word (str): The current prompt to attack. global_step (int): The current global training step. attack_round (int): The current attack round. - + Returns: tuple: (adversarial embedding, input_ids) """ @@ -70,11 +72,11 @@ def attack(self, word, global_step, attack_round): sp_attack = SoftPromptAttack( model=self.model, model_orig=self.model_orig, - tokenizer=self.tokenizer, - text_encoder=self.custom_text_encoder, + tokenizer=self.tokenizer, + text_encoder=self.custom_text_encoder, sampler=self.sampler, - emb_0=self._get_emb_0(), - emb_p=self._get_emb_p(word), + emb_0=self._get_emb_0(), + emb_p=self._get_emb_p(word), start_guidance=self.start_guidance, devices=self.devices, ddim_steps=self.ddim_steps, @@ -83,20 +85,29 @@ def attack(self, word, global_step, attack_round): criteria=self.criteria, k=self.adv_prompt_num, all_embeddings=self.all_embeddings, - backend=self.backend + backend=self.backend, + ) + return sp_attack.attack( + global_step, + word, + attack_round, + self.attack_type, + self.attack_embd_type, + self.attack_step, + self.attack_lr, + self.attack_init, + self.attack_init_embd, + self.attack_method, ) - return sp_attack.attack(global_step, word, attack_round, self.attack_type, - self.attack_embd_type, self.attack_step, self.attack_lr, - self.attack_init, self.attack_init_embd, self.attack_method) # Example helper methods to get embeddings from model_orig. def _get_emb_0(self): if self.backend == "compvis": - return self.model_orig.get_learned_conditioning(['']) + return self.model_orig.get_learned_conditioning([""]) else: # For diffusers, you need to define your own method (e.g., using self.encode_text("")) return self.encode_text("") - + def _get_emb_p(self, word): if self.backend == "compvis": return self.model_orig.get_learned_conditioning([word]) @@ -109,10 +120,8 @@ def encode_text(self, text): padding="max_length", truncation=True, max_length=77, - return_tensors="pt" + return_tensors="pt", ).to(self.devices[0]) with torch.no_grad(): text_embeddings = self.text_encoder(text_inputs.input_ids)[0] return text_embeddings - - diff --git a/mu_defense/algorithms/adv_unlearn/algorithm.py b/mu_defense/algorithms/adv_unlearn/algorithm.py index 3798fc5b..09ef54ba 100644 --- a/mu_defense/algorithms/adv_unlearn/algorithm.py +++ b/mu_defense/algorithms/adv_unlearn/algorithm.py @@ -39,7 +39,7 @@ def __init__(self, config: AdvUnlearnConfig, **kwargs): self.model = None self.trainer = None self.devices = self.config.get("devices") - self.devices = [f'cuda:{int(d.strip())}' for d in self.devices.split(',')] + self.devices = [f"cuda:{int(d.strip())}" for d in self.devices.split(",")] self.logger = logging.getLogger(__name__) self._setup_components() @@ -50,9 +50,7 @@ def _setup_components(self): self.logger.info("Setting up components for adversarial unlearning training...") # Initialize Model - self.model = AdvUnlearnModel( - config=self.config - ) + self.model = AdvUnlearnModel(config=self.config) # Initialize Trainer self.trainer = AdvUnlearnTrainer( @@ -60,6 +58,10 @@ def _setup_components(self): config=self.config, devices=self.devices, ) + self.trainer.trainer.adv_attack.model_orig = self.model.model_orig + self.trainer.trainer.adv_attack.sampler_orig = self.model.sampler_orig + self.trainer.trainer.adv_attack.model = self.model.model + self.trainer.trainer.adv_attack.sampler = self.model.sampler def run(self): """