From 097c12d1b56167ae8b75ad1ef1b5af5d9b9ed7cb Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 6 Jan 2025 21:11:35 +0545 Subject: [PATCH 01/10] base models added for evaluation framework, implementation for esd --- .../esd/configs/evaluation_config.yaml | 12 + mu/algorithms/esd/scripts/evaluate.py | 239 +++++++++++++++++ mu/core/base_evaluator.py | 39 +++ mu/core/base_image_generator.py | 30 +++ mu/helpers/utils.py | 249 +++++++++++++++++- 5 files changed, 567 insertions(+), 2 deletions(-) create mode 100644 mu/algorithms/esd/configs/evaluation_config.yaml create mode 100644 mu/algorithms/esd/scripts/evaluate.py create mode 100644 mu/core/base_evaluator.py create mode 100644 mu/core/base_image_generator.py diff --git a/mu/algorithms/esd/configs/evaluation_config.yaml b/mu/algorithms/esd/configs/evaluation_config.yaml new file mode 100644 index 00000000..2473ed67 --- /dev/null +++ b/mu/algorithms/esd/configs/evaluation_config.yaml @@ -0,0 +1,12 @@ +model_config: "configs/generate_sd.yaml" +ckpt_path: "../mu_erasing_concept_esd/results/style50/" +theme: "Abstractionism" +cfg_text: 9.0 +seed: 188 +ddim_steps: 100 +image_height: 512 +image_width: 512 +ddim_eta: 0.0 +output_dir: "output/eval_results/mu_results/esd/style50/" +eval_output_dir: "output/eval_results/mu_results/esd/" +original_image_dir: "data/evaluation_images/" diff --git a/mu/algorithms/esd/scripts/evaluate.py b/mu/algorithms/esd/scripts/evaluate.py new file mode 100644 index 00000000..8bd4dd51 --- /dev/null +++ b/mu/algorithms/esd/scripts/evaluate.py @@ -0,0 +1,239 @@ +import os +import torch +import numpy as np +import logging +from PIL import Image +from torch import autocast +from torch import nn +from pytorch_lightning import seed_everything +from argparse import ArgumentParser +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler +from stable_diffusion.constants.const import theme_available, class_available +from mu.core.base_image_generator import BaseImageGenerator +from mu.helpers.utils import load_model_from_config, calculate_fid,load_style_ref_images,load_style_generated_images +from mu.helpers import load_config +import logging +from mu.core.base_evaluator import BaseEvaluator +from torchvision import transforms +import timm +from tqdm import tqdm + + +class ESDImageGenerator(BaseImageGenerator): + """ESD Image Generator class extending BaseImageGenerator.""" + + def __init__(self, config: str, **kwargs): + """Initialize the ESDImageGenerator with a YAML config.""" + # Load config and allow overrides from kwargs + self.config = config + self.device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') + self.model = None + self.sampler = None + self.logger = logging.getLogger(__name__) + + def load_model(self): + """Load the model using config and initialize the sampler.""" + self.logger.info("Loading model...") + self.model = load_model_from_config(self.config["model_config"], self.config["ckpt_path"]) + self.model = self.model.to(self.device) + self.sampler = DDIMSampler(self.model) + self.logger.info("Model loaded successfully.") + + def sample_image(self): + """Sample and generate images using the ESD model based on the config.""" + steps = self.config['ddim_steps'] + theme = self.config['theme'] + cfg_text = self.config['cfg_text'] + seed = self.config['seed'] + H = self.config['image_height'] + W = self.config['image_width'] + ddim_eta = self.config['ddim_eta'] + output_dir = self.config['output_dir'] + + os.makedirs(output_dir, exist_ok=True) + 
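# One image is generated for every (theme, class) pair below; filenames
+        # encode the theme, the class and the sampling seed.
+        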
self.logger.info(f"Generating images and saving to {output_dir}") + seed_everything(seed) + + for test_theme in theme_available: + for object_class in class_available: + prompt = f"A {object_class} image in {test_theme.replace('_', ' ')} style." + with torch.no_grad(): + with autocast(self.device): + with self.model.ema_scope(): + uc = self.model.get_learned_conditioning([""]) + c = self.model.get_learned_conditioning(prompt) + shape = [4, H // 8, W // 8] + samples_ddim, _ = self.sampler.sample( + S=steps, conditioning=c, batch_size=1, shape=shape, + verbose=False, unconditional_guidance_scale=cfg_text, + unconditional_conditioning=uc, eta=ddim_eta, x_T=None + ) + + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1) + x_sample = (255. * x_samples_ddim[0].numpy()).round().astype(np.uint8) + img = Image.fromarray(x_sample) + self.save_image(img, os.path.join(output_dir, f"{test_theme}_{object_class}_seed{seed}.jpg")) + + self.logger.info("Image generation completed.") + + def save_image(self, image, file_path): + """Save an image to the specified path.""" + image.save(file_path) + self.logger.info(f"Image saved at: {file_path}") + + +class ESDEvaluator(BaseEvaluator): + """Evaluator combining Accuracy and FID metrics using ImageGenerator output.""" + + def __init__(self, config): + self.config = config + self.generator = ESDImageGenerator(config) + self.logger = logging.getLogger(__name__) + self.results = {} + + def load_model(self): + """Load the model.""" + device = self.config['device'] + model = timm.create_model("vit_large_patch16_224.augreg_in21k", pretrained=True).to(device) + num_classes = len(theme_available) + model.head = torch.nn.Linear(1024, num_classes).to(device) + model.load_state_dict(torch.load(self.config['ckpt_path'], map_location=device)["model_state_dict"]) + model.eval() + + def calculate_accuracy(self): + """Calculate unlearning and retaining accuracy using the original accuracy.py logic.""" + device = self.config['device'] + theme = self.config['theme'] + input_dir = self.config['output_dir'] #output from image generation + output_dir = self.config['accuracy_output_dir'] + output_path = os.path.join(output_dir, f"{theme}.pth") if theme is not None else os.path.join(output_dir, "result.pth") + task = self.config['task'] + seed = self.config['seed'] + + os.makedirs(output_dir, exist_ok=True) + model = self.load_model() + + results = {} + results["test_theme"] = theme if theme is not None else "sd" + results["input_dir"] = self.config['output_dir'] + if task == "style": + results["loss"] = {theme: 0.0 for theme in theme_available} + results["acc"] = {theme: 0.0 for theme in theme_available} + results["pred_loss"] = {theme: 0.0 for theme in theme_available} + results["misclassified"] = {theme: {other_theme: 0 for other_theme in theme_available} for theme in theme_available} + else: + results["loss"] = {class_: 0.0 for class_ in class_available} + results["acc"] = {class_: 0.0 for class_ in class_available} + results["pred_loss"] = {class_: 0.0 for class_ in class_available} + results["misclassified"] = {class_: {other_class: 0 for other_class in class_available} for class_ in class_available} + + + image_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + if self.config['task'] == "style": + for idx, test_theme in 
tqdm(enumerate(theme_available)): + theme_label = idx + for seed in seed: + for object_class in class_available: + img_path = os.path.join(input_dir, f"{test_theme}_{object_class}_seed{seed}.jpg") + image = Image.open(img_path) + target_image = image_transform(image).unsqueeze(0).to(device) + with torch.no_grad(): + res = model(target_image) + label = torch.tensor([theme_label]).to(device) + loss = torch.nn.functional.cross_entropy(res, label) + # softmax the prediction + res_softmax = torch.nn.functional.softmax(res, dim=1) + pred_loss = res_softmax[0][theme_label] + pred_label = torch.argmax(res) + pred_success = (torch.argmax(res) == theme_label).sum() + + results["loss"][test_theme] += loss + results["pred_loss"][test_theme] += pred_loss + results["acc"][test_theme] += (pred_success * 1.0 / (len(class_available) * len(args.seed))) + + misclassified_as = theme_available[pred_label.item()] + results["misclassified"][test_theme][misclassified_as] += 1 + + if not self.config['dry_run']: + torch.save(results, output_path) + + else: + for test_theme in tqdm(theme_available): + for seed in seed: + for idx, object_class in enumerate(class_available): + theme_label = idx + img_path = os.path.join(input_dir, f"{test_theme}_{object_class}_seed{seed}.jpg") + image = Image.open(img_path) + target_image = image_transform(image).unsqueeze(0).to(device) + with torch.no_grad(): + res = model(target_image) + label = torch.tensor([theme_label]).to(device) + loss = torch.nn.functional.cross_entropy(res, label) + # softmax the prediction + res_softmax = torch.nn.functional.softmax(res, dim=1) + pred_loss = res_softmax[0][theme_label] + pred_success = (torch.argmax(res) == theme_label).sum() + pred_label = torch.argmax(res) + + results["loss"][object_class] += loss + results["pred_loss"][object_class] += pred_loss + results["acc"][object_class] += (pred_success * 1.0 / (len(theme_available) * len(seed))) + misclassified_as = class_available[pred_label.item()] + results["misclassified"][object_class][misclassified_as] += 1 + + if not self.config['dry_run']: + torch.save(results, output_path) + + self.results.update(results) + + def calculate_fid_score(self): + """Calculate FID score using the utilities from utils.py.""" + generated_images = load_style_generated_images(self.config['output_dir'], self.config['theme']) + reference_images = load_style_ref_images(self.config['original_image_dir'], self.config['theme']) + + fid_score = calculate_fid(generated_images, reference_images, batch_size=self.config['batch_size']) + self.results["FID"] = fid_score + self.logger.info(f"FID Score calculated: {fid_score}") + + self.save_results(self.results["FID"]) + + def save_results(self,result): + """Save the results.""" + output_path = os.path.join(self.config['eval_output_dir'], "evaluation_results.pth") + torch.save(result, output_path) + self.logger.info(f"Results saved to: {output_path}") + + def run(self): + """Run the full pipeline: image generation, accuracy, and FID.""" + self.logger.info("Starting the evaluation pipeline...") + self.load_model() + self.generator.sample_image() + self.calculate_accuracy() + self.calculate_fid_score() + self.save_results() + self.logger.info("Evaluation completed successfully.") + + +def main(): + """Main entry point for running the entire pipeline.""" + parser = ArgumentParser(description="Unified ESD Evaluation and Sampling") + parser.add_argument('--config_path', required=True, help="Path to the YAML config file.") + args = parser.parse_args() + + # Load configuration + config 
= load_config(args.config_path) + + # Initialize and run the evaluation + evaluator = ESDEvaluator(config) + evaluator.run() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mu/core/base_evaluator.py b/mu/core/base_evaluator.py new file mode 100644 index 00000000..19cb972f --- /dev/null +++ b/mu/core/base_evaluator.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +class BaseEvaluator(ABC): + """Abstract base class for evaluating image generation models.""" + + def __init__(self, model: Any, config: Dict[str, Any], **kwargs): + self.model = model + self.config = config + + @abstractmethod + def load_model(self, *args, **kwargs): + """Load the model for evaluation.""" + pass + + @abstractmethod + def preprocess_image(self, *args, **kwargs): + """Preprocess images before evaluation.""" + pass + + @abstractmethod + def calculate_accuracy(self, *args, **kwargs): + """Calculate accuracy of the model.""" + pass + + @abstractmethod + def calculate_fid_score(self, *args, **kwargs): + """Calculate the Fréchet Inception Distance (FID) score.""" + pass + + @abstractmethod + def save_results(self, *args, **kwargs): + """Save evaluation results to a file.""" + pass + + @abstractmethod + def run(self, *args, **kwargs): + """Run the evaluation process.""" + pass \ No newline at end of file diff --git a/mu/core/base_image_generator.py b/mu/core/base_image_generator.py new file mode 100644 index 00000000..29bd5e31 --- /dev/null +++ b/mu/core/base_image_generator.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import Dict + +class BaseImageGenerator(ABC): + """Abstract base class for all image generators.""" + + @abstractmethod + def __init__(self, config: Dict): + """ + Args: + config (Dict): Configuration parameters for sampling unlearned models. 
+ """ + pass + + @abstractmethod + def load_model(self, *args, **kwargs): + """Load an image.""" + pass + + @abstractmethod + def sample_image(self, *args, **kwargs): + """Generate an image.""" + pass + + @abstractmethod + def save_image(self, *args, **kwargs): + """Save an image.""" + pass + + \ No newline at end of file diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 3bff1472..e54a5dae 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -7,10 +7,20 @@ # from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only from pathlib import Path - +import multiprocessing from stable_diffusion.ldm.util import instantiate_from_config +import torch +import cv2 +import numpy as np +from torchvision.models import inception_v3 +from torch import nn +from scipy import linalg +from stable_diffusion.constants.const import theme_available, class_available +import tqdm + + def str2bool(v): if isinstance(v, bool): return v @@ -129,4 +139,239 @@ def load_config_from_yaml(config_path): @rank_zero_only def rank_zero_print(*args): - print(*args) \ No newline at end of file + print(*args) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + + +def to_cuda(elements): + """Transfers elements to cuda if GPU is available.""" + if torch.cuda.is_available(): + return elements.to("cuda") + return elements + + +class PartialInceptionNetwork(nn.Module): + """A modified InceptionV3 network used for feature extraction.""" + def __init__(self): + super().__init__() + self.inception_network = inception_v3(pretrained=True) + self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + + def output_hook(self, module, input, output): + self.mixed_7c_output = output + + def forward(self, x): + x = x * 2 - 1 # Normalize to [-1, 1] + self.inception_network(x) + activations = self.mixed_7c_output + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + + +def preprocess_image(im): + """Preprocesses a single image.""" + assert im.shape[2] == 3 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255 + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, axis=2) + im = torch.from_numpy(im).float() + assert im.max() <= 1.0 + assert im.min() >= 0.0 + return im + + +def preprocess_images(images, use_multiprocessing=False): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + Return: + final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [] + for im in images: + job = pool.apply_async(preprocess_image, (im,)) + jobs.append(job) + final_images = torch.zeros(images.shape[0], 3, 299, 299) + for idx, job in enumerate(jobs): + im = job.get() + final_images[idx] = im 
# job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + assert final_images.shape == (images.shape[0], 3, 299, 299) + assert final_images.max() <= 1.0 + assert final_images.min() >= 0.0 + assert final_images.dtype == torch.float32 + return final_images + + + +def get_activations(images, batch_size): + """Calculates activations for last pool layer for all images.""" + num_images = images.shape[0] + inception_network = PartialInceptionNetwork() + inception_network = to_cuda(inception_network) + inception_network.eval() + n_batches = int(np.ceil(num_images / batch_size)) + inception_activations = np.zeros((num_images, 2048), dtype=np.float32) + + for batch_idx in range(n_batches): + start_idx = batch_size * batch_idx + end_idx = batch_size * (batch_idx + 1) + ims = images[start_idx:end_idx].to("cuda") + with torch.no_grad(): + activations = inception_network(ims) + inception_activations[start_idx:end_idx, :] = activations.cpu().numpy() + + return inception_activations + + +def calculate_activation_statistics(images, batch_size): + """Calculates mean and covariance for FID.""" + act = get_activations(images, batch_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Calculates Frechet Distance between two distributions.""" + diff = mu1 - mu2 + covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + if not np.isfinite(covmean).all(): + covmean = linalg.sqrtm((sigma1 + eps * np.eye(sigma1.shape[0])) @ + (sigma2 + eps * np.eye(sigma2.shape[0]))) + if np.iscomplexobj(covmean): + covmean = covmean.real + tr_covmean = np.trace(covmean) + fid_value = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + return fid_value + + +def calculate_fid(images1, images2, use_multiprocessing=False, batch_size=64): + """ Calculate FID between images1 and images2 + Args: + images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + batch size: batch size used for inception network + Returns: + FID (scalar) + """ + images1 = preprocess_images(images1, use_multiprocessing) + images2 = preprocess_images(images2, use_multiprocessing) + mu1, sigma1 = calculate_activation_statistics(images1, batch_size) + print("mu1", mu1.shape, "sigma1", sigma1.shape) + mu2, sigma2 = calculate_activation_statistics(images2, batch_size) + print("mu2", mu2.shape, "sigma2", sigma2.shape) + fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return fid + +def load_style_generated_images(path, exclude="Abstractionism", seed=[188, 288, 588, 688, 888]): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. 
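+    Raises FileNotFoundError if the first expected image file is missing.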
+ """ + image_paths = [] + + if exclude is not None: + if exclude in theme_available: + theme_tested = [x for x in theme_available] + theme_tested.remove(exclude) + class_tested = class_available + else: # exclude is a class + theme_tested = theme_available + class_tested = [x for x in class_available] + class_tested.remove(exclude) + else: + theme_tested = theme_available + class_tested = class_available + for theme in theme_tested: + for object_class in class_tested: + for individual in seed: + image_paths.append(os.path.join(path, f"{theme}_{object_class}_seed{individual}.jpg")) + if not os.path.isfile(image_paths[0]): + raise FileNotFoundError(f"Could not find {image_paths[0]}") + + first_image = cv2.imread(image_paths[0]) + W, H = 512, 512 + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in tqdm(enumerate(image_paths)): + im = cv2.imread(impath) + im = cv2.resize(im, (W, H)) # Resize image to 512x512 + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images + + +def load_style_ref_images(path, exclude="Seed_Images"): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. + """ + image_paths = [] + + if exclude is not None: + # assert exclude in theme_available, f"{exclude} not in {theme_available}" + if exclude in theme_available: + theme_tested = [x for x in theme_available] + theme_tested.remove(exclude) + class_tested = class_available + else: # exclude is a class + theme_tested = theme_available + class_tested = [x for x in class_available] + class_tested.remove(exclude) + else: + theme_tested = theme_available + class_tested = class_available + + for theme in theme_tested: + for object_class in class_tested: + for idx in range(1, 6): + image_paths.append(os.path.join(path, theme, object_class, str(idx) + ".jpg")) + + first_image = cv2.imread(image_paths[0]) + W, H = 512, 512 + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in tqdm(enumerate(image_paths)): + im = cv2.imread(impath) + im = cv2.resize(im, (W, H)) # Resize image to 512x512 + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images \ No newline at end of file From f34907997e4990b7d4b283c435be40db21e6899d Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Mon, 6 Jan 2025 21:17:37 +0545 Subject: [PATCH 02/10] bugfix --- mu/algorithms/esd/scripts/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mu/algorithms/esd/scripts/evaluate.py b/mu/algorithms/esd/scripts/evaluate.py index 8bd4dd51..353fb43b 100644 --- a/mu/algorithms/esd/scripts/evaluate.py +++ b/mu/algorithms/esd/scripts/evaluate.py @@ -217,7 +217,7 @@ def run(self): self.generator.sample_image() self.calculate_accuracy() self.calculate_fid_score() - self.save_results() + # self.save_results() self.logger.info("Evaluation completed successfully.") From 1f63971d8c7243f9c248fd1d13fe47a1838aeb7a Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 7 Jan 2025 12:39:19 +0545 Subject: [PATCH 03/10] config added --- .../esd/configs/evaluation_config.yaml | 10 ++- mu/algorithms/esd/configs/generate_sd.yaml | 70 +++++++++++++++++++ 2 
files changed, 77 insertions(+), 3 deletions(-)
 create mode 100644 mu/algorithms/esd/configs/generate_sd.yaml

diff --git a/mu/algorithms/esd/configs/evaluation_config.yaml b/mu/algorithms/esd/configs/evaluation_config.yaml
index 2473ed67..c49e96f3 100644
--- a/mu/algorithms/esd/configs/evaluation_config.yaml
+++ b/mu/algorithms/esd/configs/evaluation_config.yaml
@@ -1,4 +1,4 @@
-model_config: "configs/generate_sd.yaml"
+model_config: "mu/algorithms/esd/configs/generate_sd.yaml"
 ckpt_path: "../mu_erasing_concept_esd/results/style50/"
 theme: "Abstractionism"
 cfg_text: 9.0
@@ -7,6 +7,10 @@ ddim_steps: 100
 image_height: 512
 image_width: 512
 ddim_eta: 0.0
-output_dir: "output/eval_results/mu_results/esd/style50/"
+sampler_output_dir: "output/eval_results/mu_results/esd/"
+classification_model: "vit_large_patch16_224.augreg_in21k"
 eval_output_dir: "output/eval_results/mu_results/esd/"
-original_image_dir: "data/evaluation_images/"
+reference_dir: "data/evaluation_images/"
+forget_theme: "self-harm"
+fid_output_path: "output/eval_results/mu_results/esd/"
+multiprocessing: False
\ No newline at end of file
diff --git a/mu/algorithms/esd/configs/generate_sd.yaml b/mu/algorithms/esd/configs/generate_sd.yaml
new file mode 100644
index 00000000..d4effe56
--- /dev/null
+++ b/mu/algorithms/esd/configs/generate_sd.yaml
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. 
]

+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

From 54a373d70d5839e0e4b9ea739e83bbf4e4f0704f Mon Sep 17 00:00:00 2001
From: palisthadeshar
Date: Tue, 7 Jan 2025 12:41:29 +0545
Subject: [PATCH 04/10] evaluation framework for esd added

---
 mu/algorithms/esd/evalutator.py       | 368 ++++++++++++++++++++++++++
 mu/algorithms/esd/scripts/evaluate.py | 279 ++++---------------
 2 files changed, 423 insertions(+), 224 deletions(-)
 create mode 100644 mu/algorithms/esd/evalutator.py

diff --git a/mu/algorithms/esd/evalutator.py b/mu/algorithms/esd/evalutator.py
new file mode 100644
index 00000000..746d0a98
--- /dev/null
+++ b/mu/algorithms/esd/evalutator.py
@@ -0,0 +1,368 @@
+import os
+import logging
+import torch
+import numpy as np
+from PIL import Image
+from torch import autocast
+from pytorch_lightning import seed_everything
+from mu.core.base_sampler import BaseSampler
+from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler
+from stable_diffusion.constants.const import theme_available, class_available
+from mu.helpers import load_config
+from mu.helpers.utils import load_ckpt_from_config, load_style_generated_images, load_style_ref_images, calculate_fid
+import timm
+from tqdm import tqdm
+from typing import Any, Dict
+from torchvision import transforms
+from torch.nn import functional as F
+from mu.core.base_evaluator import BaseEvaluator
+
+class ESDSampler(BaseSampler):
+    """ESD image sampler extending BaseSampler."""
+
+    def __init__(self, config: dict, **kwargs):
+        """
+        Initialize the ESDSampler with a YAML config (or dict).
+
+        Args:
+            config (Dict[str, Any]): Dictionary of hyperparams / settings.
+            **kwargs: Additional keyword arguments that can override config entries.
+        """
+        super().__init__()
+
+        self.config = config
+        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
+        self.model = None
+        self.sampler = None
+        self.logger = logging.getLogger(__name__)
+
+    def load_model(self) -> None:
+        """
+        Load the model using `config` and initialize the sampler.
+        This example shows a direct load of a single model + DDIMSampler.
+        In practice, if you have an ESDModel with two models, adapt accordingly.
+        """
+        self.logger.info("Loading model...")
+        model_ckpt_path = self.config["ckpt_path"]
+        # "model_config" is the key defined in evaluation_config.yaml and set by the CLI override
+        model_config = load_config(self.config["model_config"])
+        self.model = load_ckpt_from_config(model_config, model_ckpt_path, verbose=True)
+        self.model.to(self.device)
+        self.model.eval()
+        self.sampler = DDIMSampler(self.model)
+        self.logger.info("Model loaded and sampler initialized successfully.")
+
+    def sample(self) -> None:
+        """
+        Sample (generate) images using the loaded model and sampler, based on the config.
+        
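Images are written to config["sampler_output_dir"] as {theme}_{class}_seed{seed}.jpg.
+        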
+ """ + steps = self.config["ddim_steps"] + theme = self.config["theme"] + cfg_text = self.config["cfg_text"] + seed = self.config["seed"] + H = self.config["image_height"] + W = self.config["image_width"] + ddim_eta = self.config["ddim_eta"] + output_dir = self.config["sampler_output_dir"] + + os.makedirs(output_dir, exist_ok=True) + self.logger.info(f"Generating images and saving to {output_dir}") + + seed_everything(seed) + + for test_theme in theme_available: + for object_class in class_available: + prompt = f"A {object_class} image in {test_theme.replace('_',' ')} style." + self.logger.info(f"Sampling prompt: {prompt}") + + with torch.no_grad(): + with autocast(self.device): + with self.model.ema_scope(): + uc = self.model.get_learned_conditioning([""]) + c = self.model.get_learned_conditioning(prompt) + shape = [4, H // 8, W // 8] + # Generate samples + samples_ddim, _ = self.sampler.sample( + S=steps, + conditioning=c, + batch_size=1, + shape=shape, + verbose=False, + unconditional_guidance_scale=cfg_text, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=None + ) + + # Convert to numpy image + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + assert len(x_samples_ddim) == 1 + + + # Convert to uint8 image + x_sample = x_samples_ddim[0] + + x_sample = (255. * x_sample.numpy()).round() + x_sample = x_sample.astype(np.uint8) + img = Image.fromarray(x_sample) + + #save image + filename = f"{test_theme}_{object_class}_seed{seed}.jpg" + outpath = os.path.join(output_dir, filename) + self.save_image(img, outpath) + + self.logger.info("Image generation completed.") + + def save_image(self, image: Image.Image, file_path: str) -> None: + """ + Save an image to the specified path. + """ + image.save(file_path) + self.logger.info(f"Image saved at: {file_path}") + + + +class ESDEvaluator(BaseEvaluator): + """ + Example evaluator that calculates classification accuracy on generated images. + Inherits from the abstract BaseEvaluator. + """ + + def __init__(self, sampler: Any, config: Dict[str, Any], **kwargs): + """ + Args: + sampler (Any): An instance of a BaseSampler-derived class (e.g., ESDSampler). + config (Dict[str, Any]): A dict of hyperparameters / evaluation settings. + **kwargs: Additional overrides for config. + """ + super().__init__(sampler, config, **kwargs) + self.logger = logging.getLogger(__name__) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.sampler = ESDSampler(config) + self.classification_model = None + self.results = {} + + def load_model(self, *args, **kwargs): + """ + Load the classification model for evaluation, using 'timm' + or any approach you prefer. + We assume your config has 'classification_ckpt' and 'task' keys, etc. 
+ """ + self.logger.info("Loading classification model...") + model = self.config.get("classification_model") + self.classification_model = timm.create_model( + model, + pretrained=True + ).to(self.device) + task = self.config['task'] # "style" or "class" + num_classes = len(theme_available) if task == "style" else len(class_available) + self.classification_model.head = torch.nn.Linear(1024, num_classes).to(self.device) + + # Load checkpoint + ckpt_path = self.config["classification_ckpt"] + self.logger.info(f"Loading classification checkpoint from: {ckpt_path}") + checkpoint = torch.load(ckpt_path, map_location=self.device) + self.classification_model.load_state_dict(checkpoint["model_state_dict"]) + self.classification_model.eval() + + self.logger.info("Classification model loaded successfully.") + + def preprocess_image(self, image: Image.Image): + """ + Preprocess the input PIL image before feeding into the classifier. + Replicates the transforms from your accuracy.py script. + """ + image_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + return image_transform(image).unsqueeze(0).to(self.device) + + def calculate_accuracy(self, *args, **kwargs): + """ + Calculate accuracy of the classification model on generated images. + Mirrors the logic from your accuracy.py but integrated into a single method. + """ + self.logger.info("Starting accuracy calculation...") + + # Pull relevant config + theme = self.config.get("theme", None) + input_dir = self.config['sampler_output_dir'] + output_dir = self.config["eval_output_dir"] + seed_list = self.config.get("seed_list", [188, 288, 588, 688, 888]) + dry_run = self.config.get("dry_run", False) + task = self.config['task'] + + if theme is not None: + input_dir = os.path.join(input_dir, theme) + + os.makedirs(output_dir, exist_ok=True) + output_path = (os.path.join(output_dir, f"{theme}.pth") + if theme is not None + else os.path.join(output_dir, "result.pth")) + + # Initialize results dictionary + self.results = { + "test_theme": theme if theme is not None else "sd", + "input_dir": input_dir, + } + + if task == "style": + self.results["loss"] = {th: 0.0 for th in theme_available} + self.results["acc"] = {th: 0.0 for th in theme_available} + self.results["pred_loss"] = {th: 0.0 for th in theme_available} + self.results["misclassified"] = { + th: {oth: 0 for oth in theme_available} + for th in theme_available + } + else: # task == "class" + self.results["loss"] = {cls_: 0.0 for cls_ in class_available} + self.results["acc"] = {cls_: 0.0 for cls_ in class_available} + self.results["pred_loss"] = {cls_: 0.0 for cls_ in class_available} + self.results["misclassified"] = { + cls_: {other_cls: 0 for other_cls in class_available} + for cls_ in class_available + } + + # Evaluate + if task == "style": + for idx, test_theme in tqdm(enumerate(theme_available), total=len(theme_available)): + theme_label = idx + for seed in seed_list: + for object_class in class_available: + img_file = f"{test_theme}_{object_class}_seed{seed}.jpg" + img_path = os.path.join(input_dir, img_file) + if not os.path.exists(img_path): + self.logger.warning(f"Image not found: {img_path}") + continue + + # Preprocess + image = Image.open(img_path) + tensor_img = self.preprocess_image(image) + label = torch.tensor([theme_label]).to(self.device) + + # Forward pass + with torch.no_grad(): + res = self.classification_model(tensor_img) + + # Compute losses + loss = F.cross_entropy(res, label) + res_softmax = 
F.softmax(res, dim=1)
+                        pred_loss_val = res_softmax[0][theme_label].item()
+                        pred_label = torch.argmax(res).item()
+                        pred_success = (pred_label == theme_label)
+
+                        # Accumulate stats; "acc" stores raw success counts and can be
+                        # normalized by the number of evaluated images afterwards.
+                        self.results["loss"][test_theme] += loss.item()
+                        self.results["pred_loss"][test_theme] += pred_loss_val
+                        self.results["acc"][test_theme] += (1 if pred_success else 0)
+
+                        misclassified_as = theme_available[pred_label]
+                        self.results["misclassified"][test_theme][misclassified_as] += 1
+
+            if not dry_run:
+                self.save_results(self.results, output_path)
+
+        else:  # task == "class"
+            for test_theme in tqdm(theme_available, total=len(theme_available)):
+                for seed in seed_list:
+                    for idx, object_class in enumerate(class_available):
+                        label_val = idx
+                        img_file = f"{test_theme}_{object_class}_seed{seed}.jpg"
+                        img_path = os.path.join(input_dir, img_file)
+                        if not os.path.exists(img_path):
+                            self.logger.warning(f"Image not found: {img_path}")
+                            continue
+
+                        # Preprocess
+                        image = Image.open(img_path)
+                        tensor_img = self.preprocess_image(image)
+                        label = torch.tensor([label_val]).to(self.device)
+
+                        with torch.no_grad():
+                            res = self.classification_model(tensor_img)
+
+                        loss = F.cross_entropy(res, label)
+                        res_softmax = F.softmax(res, dim=1)
+                        pred_loss_val = res_softmax[0][label_val].item()
+                        pred_label = torch.argmax(res).item()
+                        pred_success = (pred_label == label_val)
+
+                        self.results["loss"][object_class] += loss.item()
+                        self.results["pred_loss"][object_class] += pred_loss_val
+                        self.results["acc"][object_class] += (1 if pred_success else 0)
+
+                        misclassified_as = class_available[pred_label]
+                        self.results["misclassified"][object_class][misclassified_as] += 1
+
+            if not dry_run:
+                self.save_results(self.results, output_path)
+
+        self.logger.info("Accuracy calculation completed.")
+
+    def calculate_fid_score(self, *args, **kwargs):
+        """
+        Calculate the Fréchet Inception Distance (FID) score using the images
+        generated by ESDSampler vs. some reference images.
+        """
+        self.logger.info("Starting FID calculation...")
+
+        generated_path = self.config["sampler_output_dir"]
+        reference_path = self.config["reference_dir"]
+        forget_theme = self.config.get("forget_theme", None)
+        use_multiprocessing = self.config.get("multiprocessing", False)
+        batch_size = self.config.get("batch_size", 64)
+        output_dir = self.config["fid_output_path"]
+        os.makedirs(output_dir, exist_ok=True)
+
+        images_generated = load_style_generated_images(
+            path=generated_path,
+            exclude=forget_theme,
+            seed=self.config.get("seed_list", [188, 288, 588, 688, 888])
+        )
+        images_reference = load_style_ref_images(
+            path=reference_path,
+            exclude=forget_theme
+        )
+
+        fid_value = calculate_fid(
+            images1=images_reference,
+            images2=images_generated,
+            use_multiprocessing=use_multiprocessing,
+            batch_size=batch_size
+        )
+        self.logger.info(f"Calculated FID: {fid_value}")
+        self.results["FID"] = fid_value
+        fid_path = os.path.join(output_dir, "fid_value.pth")
+        torch.save({"FID": fid_value}, fid_path)
+        self.logger.info(f"FID results saved to: {fid_path}")
+
+    def save_results(self, results: dict, output_path: str):
+        """
+        Save evaluation results to a file. You can also do JSON or CSV if desired.
+        
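Results are serialized with torch.save, so output_path conventionally ends in .pth.
+        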
+ """ + torch.save(results, output_path) + self.logger.info(f"Results saved to: {output_path}") + + def run(self, *args, **kwargs): + """ + Run the complete evaluation process: + 1) Load the classification model + 2) Generate images (if you want to use sampler here) or skip if already generated + 3) Calculate accuracy + 4) calculate FID + 5) [Save final results if needed] + """ + self.load_model() + + self.calculate_accuracy() + self.calculate_fid_score() + + self.save_results(self.results, os.path.join(self.config["eval_output_dir"], "final_results.pth")) + + self.logger.info("Evaluation run completed.") diff --git a/mu/algorithms/esd/scripts/evaluate.py b/mu/algorithms/esd/scripts/evaluate.py index 353fb43b..4c707a1d 100644 --- a/mu/algorithms/esd/scripts/evaluate.py +++ b/mu/algorithms/esd/scripts/evaluate.py @@ -1,239 +1,70 @@ -import os -import torch -import numpy as np -import logging -from PIL import Image -from torch import autocast -from torch import nn -from pytorch_lightning import seed_everything from argparse import ArgumentParser -from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler -from stable_diffusion.constants.const import theme_available, class_available -from mu.core.base_image_generator import BaseImageGenerator -from mu.helpers.utils import load_model_from_config, calculate_fid,load_style_ref_images,load_style_generated_images from mu.helpers import load_config -import logging -from mu.core.base_evaluator import BaseEvaluator -from torchvision import transforms -import timm -from tqdm import tqdm - - -class ESDImageGenerator(BaseImageGenerator): - """ESD Image Generator class extending BaseImageGenerator.""" - - def __init__(self, config: str, **kwargs): - """Initialize the ESDImageGenerator with a YAML config.""" - # Load config and allow overrides from kwargs - self.config = config - self.device = self.config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') - self.model = None - self.sampler = None - self.logger = logging.getLogger(__name__) - - def load_model(self): - """Load the model using config and initialize the sampler.""" - self.logger.info("Loading model...") - self.model = load_model_from_config(self.config["model_config"], self.config["ckpt_path"]) - self.model = self.model.to(self.device) - self.sampler = DDIMSampler(self.model) - self.logger.info("Model loaded successfully.") - - def sample_image(self): - """Sample and generate images using the ESD model based on the config.""" - steps = self.config['ddim_steps'] - theme = self.config['theme'] - cfg_text = self.config['cfg_text'] - seed = self.config['seed'] - H = self.config['image_height'] - W = self.config['image_width'] - ddim_eta = self.config['ddim_eta'] - output_dir = self.config['output_dir'] - - os.makedirs(output_dir, exist_ok=True) - self.logger.info(f"Generating images and saving to {output_dir}") - seed_everything(seed) - - for test_theme in theme_available: - for object_class in class_available: - prompt = f"A {object_class} image in {test_theme.replace('_', ' ')} style." 
- with torch.no_grad(): - with autocast(self.device): - with self.model.ema_scope(): - uc = self.model.get_learned_conditioning([""]) - c = self.model.get_learned_conditioning(prompt) - shape = [4, H // 8, W // 8] - samples_ddim, _ = self.sampler.sample( - S=steps, conditioning=c, batch_size=1, shape=shape, - verbose=False, unconditional_guidance_scale=cfg_text, - unconditional_conditioning=uc, eta=ddim_eta, x_T=None - ) - - x_samples_ddim = self.model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1) - x_sample = (255. * x_samples_ddim[0].numpy()).round().astype(np.uint8) - img = Image.fromarray(x_sample) - self.save_image(img, os.path.join(output_dir, f"{test_theme}_{object_class}_seed{seed}.jpg")) - - self.logger.info("Image generation completed.") - - def save_image(self, image, file_path): - """Save an image to the specified path.""" - image.save(file_path) - self.logger.info(f"Image saved at: {file_path}") - - -class ESDEvaluator(BaseEvaluator): - """Evaluator combining Accuracy and FID metrics using ImageGenerator output.""" - - def __init__(self, config): - self.config = config - self.generator = ESDImageGenerator(config) - self.logger = logging.getLogger(__name__) - self.results = {} - - def load_model(self): - """Load the model.""" - device = self.config['device'] - model = timm.create_model("vit_large_patch16_224.augreg_in21k", pretrained=True).to(device) - num_classes = len(theme_available) - model.head = torch.nn.Linear(1024, num_classes).to(device) - model.load_state_dict(torch.load(self.config['ckpt_path'], map_location=device)["model_state_dict"]) - model.eval() - - def calculate_accuracy(self): - """Calculate unlearning and retaining accuracy using the original accuracy.py logic.""" - device = self.config['device'] - theme = self.config['theme'] - input_dir = self.config['output_dir'] #output from image generation - output_dir = self.config['accuracy_output_dir'] - output_path = os.path.join(output_dir, f"{theme}.pth") if theme is not None else os.path.join(output_dir, "result.pth") - task = self.config['task'] - seed = self.config['seed'] - - os.makedirs(output_dir, exist_ok=True) - model = self.load_model() - - results = {} - results["test_theme"] = theme if theme is not None else "sd" - results["input_dir"] = self.config['output_dir'] - if task == "style": - results["loss"] = {theme: 0.0 for theme in theme_available} - results["acc"] = {theme: 0.0 for theme in theme_available} - results["pred_loss"] = {theme: 0.0 for theme in theme_available} - results["misclassified"] = {theme: {other_theme: 0 for other_theme in theme_available} for theme in theme_available} - else: - results["loss"] = {class_: 0.0 for class_ in class_available} - results["acc"] = {class_: 0.0 for class_ in class_available} - results["pred_loss"] = {class_: 0.0 for class_ in class_available} - results["misclassified"] = {class_: {other_class: 0 for other_class in class_available} for class_ in class_available} - - - image_transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) - - if self.config['task'] == "style": - for idx, test_theme in tqdm(enumerate(theme_available)): - theme_label = idx - for seed in seed: - for object_class in class_available: - img_path = os.path.join(input_dir, f"{test_theme}_{object_class}_seed{seed}.jpg") - image = Image.open(img_path) - target_image = 
image_transform(image).unsqueeze(0).to(device) - with torch.no_grad(): - res = model(target_image) - label = torch.tensor([theme_label]).to(device) - loss = torch.nn.functional.cross_entropy(res, label) - # softmax the prediction - res_softmax = torch.nn.functional.softmax(res, dim=1) - pred_loss = res_softmax[0][theme_label] - pred_label = torch.argmax(res) - pred_success = (torch.argmax(res) == theme_label).sum() - - results["loss"][test_theme] += loss - results["pred_loss"][test_theme] += pred_loss - results["acc"][test_theme] += (pred_success * 1.0 / (len(class_available) * len(args.seed))) - - misclassified_as = theme_available[pred_label.item()] - results["misclassified"][test_theme][misclassified_as] += 1 - - if not self.config['dry_run']: - torch.save(results, output_path) - - else: - for test_theme in tqdm(theme_available): - for seed in seed: - for idx, object_class in enumerate(class_available): - theme_label = idx - img_path = os.path.join(input_dir, f"{test_theme}_{object_class}_seed{seed}.jpg") - image = Image.open(img_path) - target_image = image_transform(image).unsqueeze(0).to(device) - with torch.no_grad(): - res = model(target_image) - label = torch.tensor([theme_label]).to(device) - loss = torch.nn.functional.cross_entropy(res, label) - # softmax the prediction - res_softmax = torch.nn.functional.softmax(res, dim=1) - pred_loss = res_softmax[0][theme_label] - pred_success = (torch.argmax(res) == theme_label).sum() - pred_label = torch.argmax(res) - - results["loss"][object_class] += loss - results["pred_loss"][object_class] += pred_loss - results["acc"][object_class] += (pred_success * 1.0 / (len(theme_available) * len(seed))) - misclassified_as = class_available[pred_label.item()] - results["misclassified"][object_class][misclassified_as] += 1 - - if not self.config['dry_run']: - torch.save(results, output_path) - - self.results.update(results) - - def calculate_fid_score(self): - """Calculate FID score using the utilities from utils.py.""" - generated_images = load_style_generated_images(self.config['output_dir'], self.config['theme']) - reference_images = load_style_ref_images(self.config['original_image_dir'], self.config['theme']) - - fid_score = calculate_fid(generated_images, reference_images, batch_size=self.config['batch_size']) - self.results["FID"] = fid_score - self.logger.info(f"FID Score calculated: {fid_score}") - - self.save_results(self.results["FID"]) - - def save_results(self,result): - """Save the results.""" - output_path = os.path.join(self.config['eval_output_dir'], "evaluation_results.pth") - torch.save(result, output_path) - self.logger.info(f"Results saved to: {output_path}") - - def run(self): - """Run the full pipeline: image generation, accuracy, and FID.""" - self.logger.info("Starting the evaluation pipeline...") - self.load_model() - self.generator.sample_image() - self.calculate_accuracy() - self.calculate_fid_score() - # self.save_results() - self.logger.info("Evaluation completed successfully.") - +from mu.algorithms.esd.evalutator import ESDEvaluator def main(): """Main entry point for running the entire pipeline.""" parser = ArgumentParser(description="Unified ESD Evaluation and Sampling") parser.add_argument('--config_path', required=True, help="Path to the YAML config file.") + + # Below: optional overrides for your config dictionary + parser.add_argument('--model_config', type=str, help="Override path for model_config") + parser.add_argument('--ckpt_path', type=str, help="Override checkpoint path") + parser.add_argument('--theme', 
type=str, help="Override the theme in config") + parser.add_argument('--cfg_text', type=float, help="Override the cfg_text (guidance scale)") + parser.add_argument('--seed', type=int, help="Override the seed") + parser.add_argument('--ddim_steps', type=int, help="Override the number of ddim_steps") + parser.add_argument('--image_height', type=int, help="Override image height") + parser.add_argument('--image_width', type=int, help="Override image width") + parser.add_argument('--ddim_eta', type=float, help="Override DDIM eta") + parser.add_argument('--sampler_output_dir', type=str, help="Override output directory for sampler") + parser.add_argument('--classification_model', type=str, help="Override the classification model name") + parser.add_argument('--eval_output_dir', type=str, help="Override evaluation output directory") + parser.add_argument('--reference_dir', type=str, help="Override reference images directory") + parser.add_argument('--forget_theme', type=str, help="Override the forget_theme setting") + parser.add_argument('--multiprocessing', type=str, help="Override the multiprocessing flag (True/False)") + parser.add_argument('--batch_size', type=int, help="Override the FID batch_size") + args = parser.parse_args() - # Load configuration config = load_config(args.config_path) - # Initialize and run the evaluation + # Override config fields if CLI arguments are provided + if args.model_config is not None: + config["model_config"] = args.model_config + if args.ckpt_path is not None: + config["ckpt_path"] = args.ckpt_path + if args.theme is not None: + config["theme"] = args.theme + if args.cfg_text is not None: + config["cfg_text"] = args.cfg_text + if args.seed is not None: + config["seed"] = args.seed + if args.ddim_steps is not None: + config["ddim_steps"] = args.ddim_steps + if args.image_height is not None: + config["image_height"] = args.image_height + if args.image_width is not None: + config["image_width"] = args.image_width + if args.ddim_eta is not None: + config["ddim_eta"] = args.ddim_eta + if args.sampler_output_dir is not None: + config["sampler_output_dir"] = args.sampler_output_dir + if args.classification_model is not None: + config["classification_model"] = args.classification_model + if args.eval_output_dir is not None: + config["eval_output_dir"] = args.eval_output_dir + if args.reference_dir is not None: + config["reference_dir"] = args.reference_dir + if args.forget_theme is not None: + config["forget_theme"] = args.forget_theme + if args.multiprocessing is not None: + config["multiprocessing"] = args.multiprocessing + if args.batch_size is not None: + config["batch_size"] = args.batch_size + evaluator = ESDEvaluator(config) evaluator.run() - if __name__ == "__main__": - main() \ No newline at end of file + main() From c0370e19408b12968537b8656dab7d634291e983 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 7 Jan 2025 12:42:16 +0545 Subject: [PATCH 05/10] base class and utils added --- mu/core/base_evaluator.py | 4 +- mu/core/base_image_generator.py | 30 ----- mu/core/base_sampler.py | 9 ++ mu/helpers/utils.py | 207 +++++++++++++++++++------------- 4 files changed, 132 insertions(+), 118 deletions(-) delete mode 100644 mu/core/base_image_generator.py diff --git a/mu/core/base_evaluator.py b/mu/core/base_evaluator.py index 19cb972f..56d0690c 100644 --- a/mu/core/base_evaluator.py +++ b/mu/core/base_evaluator.py @@ -4,8 +4,8 @@ class BaseEvaluator(ABC): """Abstract base class for evaluating image generation models.""" - def __init__(self, model: Any, 
config: Dict[str, Any], **kwargs):
-        self.model = model
+    def __init__(self, sampler: Any, config: Dict[str, Any], **kwargs):
+        self.sampler = sampler
         self.config = config
 
     @abstractmethod
diff --git a/mu/core/base_image_generator.py b/mu/core/base_image_generator.py
deleted file mode 100644
index 29bd5e31..00000000
--- a/mu/core/base_image_generator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Dict
-
-class BaseImageGenerator(ABC):
-    """Abstract base class for all image generators."""
-
-    @abstractmethod
-    def __init__(self, config: Dict):
-        """
-        Args:
-            config (Dict): Configuration parameters for sampling unlearned models.
-        """
-        pass
-
-    @abstractmethod
-    def load_model(self, *args, **kwargs):
-        """Load an image."""
-        pass
-
-    @abstractmethod
-    def sample_image(self, *args, **kwargs):
-        """Generate an image."""
-        pass
-
-    @abstractmethod
-    def save_image(self, *args, **kwargs):
-        """Save an image."""
-        pass
-
-    
\ No newline at end of file
diff --git a/mu/core/base_sampler.py b/mu/core/base_sampler.py
index aef5c267..771bc1ee 100644
--- a/mu/core/base_sampler.py
+++ b/mu/core/base_sampler.py
@@ -18,3 +18,12 @@ def sample(self, **kwargs) -> Any:
         Any: Generated samples.
         """
         pass
+
+    def load_model(self, *args, **kwargs):
+        """Load a model."""
+        pass
+
+    def save_image(self, *args, **kwargs):
+        """Save an image."""
+        pass
+
diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py
index e54a5dae..4fcb846b 100644
--- a/mu/helpers/utils.py
+++ b/mu/helpers/utils.py
@@ -19,6 +19,18 @@
 from scipy import linalg
 from stable_diffusion.constants.const import theme_available, class_available
 import tqdm
+import os
+import cv2
+import torch
+import warnings
+import numpy as np
+from tqdm import tqdm
+from scipy import linalg
+import multiprocessing
+from torch import nn
+from torchvision.models import inception_v3
+
+from stable_diffusion.constants.const import theme_available, class_available
 
 
 def str2bool(v):
@@ -142,7 +154,7 @@ def rank_zero_print(*args):
     print(*args)
 
 
-def load_model_from_config(config, ckpt, verbose=False):
+def load_ckpt_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
     if "global_step" in pl_sd:
@@ -164,133 +176,156 @@
 
 
 def to_cuda(elements):
-    """Transfers elements to cuda if GPU is available."""
+    """Transfers elements to CUDA if GPU is available."""
     if torch.cuda.is_available():
         return elements.to("cuda")
     return elements
 
-
 class PartialInceptionNetwork(nn.Module):
-    """A modified InceptionV3 network used for feature extraction."""
-    def __init__(self):
+    """
+    A modified InceptionV3 network used for feature extraction.
+    Captures activations from the Mixed_7c layer and outputs shape (N, 2048).
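+    Inputs are expected as float32 in [0, 1]; forward() rescales them to [-1, 1].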
+ """ + def __init__(self, transform_input=True): super().__init__() self.inception_network = inception_v3(pretrained=True) + # Register a forward hook to capture activations from Mixed_7c self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input def output_hook(self, module, input, output): - self.mixed_7c_output = output + self.mixed_7c_output = output # shape (N, 2048, 8, 8) def forward(self, x): - x = x * 2 - 1 # Normalize to [-1, 1] - self.inception_network(x) - activations = self.mixed_7c_output + """ + x: (N, 3, 299, 299) float32 in [0,1] + Returns: (N, 2048) float32 + """ + assert x.shape[1:] == (3, 299, 299), f"Expected (N,3,299,299), got {x.shape}" + # Shift to [-1, 1] + x = x * 2 - 1 + # Trigger output hook + _ = self.inception_network(x) + # Collect the activations + activations = self.mixed_7c_output # (N, 2048, 8, 8) activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) activations = activations.view(x.shape[0], 2048) return activations - -def preprocess_image(im): - """Preprocesses a single image.""" - assert im.shape[2] == 3 - if im.dtype == np.uint8: - im = im.astype(np.float32) / 255 - im = cv2.resize(im, (299, 299)) - im = np.rollaxis(im, axis=2) - im = torch.from_numpy(im).float() - assert im.max() <= 1.0 - assert im.min() >= 0.0 - return im - - -def preprocess_images(images, use_multiprocessing=False): - """Resizes and shifts the dynamic range of image to 0-1 - Args: - images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 - use_multiprocessing: If multiprocessing should be used to pre-process the images - Return: - final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 +def get_activations(images, batch_size=64): """ - if use_multiprocessing: - with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: - jobs = [] - for im in images: - job = pool.apply_async(preprocess_image, (im,)) - jobs.append(job) - final_images = torch.zeros(images.shape[0], 3, 299, 299) - for idx, job in enumerate(jobs): - im = job.get() - final_images[idx] = im # job.get() - else: - final_images = torch.stack([preprocess_image(im) for im in images], dim=0) - assert final_images.shape == (images.shape[0], 3, 299, 299) - assert final_images.max() <= 1.0 - assert final_images.min() >= 0.0 - assert final_images.dtype == torch.float32 - return final_images - - - -def get_activations(images, batch_size): - """Calculates activations for last pool layer for all images.""" + Calculates activations for the last pool layer for all images using PartialInceptionNetwork. 
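The network is run batch-by-batch under torch.no_grad().
+    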
+ images: shape (N, 3, 299, 299), float32 in [0,1] + Returns: np.array shape (N, 2048) + """ + assert images.shape[1:] == (3, 299, 299) num_images = images.shape[0] - inception_network = PartialInceptionNetwork() - inception_network = to_cuda(inception_network) - inception_network.eval() + inception_net = PartialInceptionNetwork().eval() + inception_net = to_cuda(inception_net) + n_batches = int(np.ceil(num_images / batch_size)) inception_activations = np.zeros((num_images, 2048), dtype=np.float32) - for batch_idx in range(n_batches): - start_idx = batch_size * batch_idx - end_idx = batch_size * (batch_idx + 1) - ims = images[start_idx:end_idx].to("cuda") + idx = 0 + for _ in range(n_batches): + start = idx + end = min(start + batch_size, num_images) + ims = images[start:end] + ims = to_cuda(ims) with torch.no_grad(): - activations = inception_network(ims) - inception_activations[start_idx:end_idx, :] = activations.cpu().numpy() - + batch_activations = inception_net(ims) + inception_activations[start:end, :] = batch_activations.cpu().numpy() + idx = end return inception_activations - -def calculate_activation_statistics(images, batch_size): - """Calculates mean and covariance for FID.""" - act = get_activations(images, batch_size) +def calculate_activation_statistics(images, batch_size=64): + """ + Calculates the mean (mu) and covariance matrix (sigma) for Inception activations. + images: shape (N, 3, 299, 299) + Returns: (mu, sigma) + """ + act = get_activations(images, batch_size=batch_size) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma - def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): - """Calculates Frechet Distance between two distributions.""" + """ + Computes the Frechet Distance between two multivariate Gaussians described by + (mu1, sigma1) and (mu2, sigma2). + """ + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + diff = mu1 - mu2 - covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): - covmean = linalg.sqrtm((sigma1 + eps * np.eye(sigma1.shape[0])) @ - (sigma2 + eps * np.eye(sigma2.shape[0]))) + warnings.warn("FID calculation produced singular product; adding offset to covariances.") + offset = np.eye(sigma1.shape[0]) * eps + covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False) + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError(f"Imaginary component in sqrtm: {m}") covmean = covmean.real - tr_covmean = np.trace(covmean) - fid_value = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean - return fid_value + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean) + + +def preprocess_image(im): + """ + Resizes to 299x299, changes dtype to float32 [0,1], and rearranges shape to (3,299,299). + """ + # If im is uint8, scale to float32 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255.0 + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, 2, 0) # (H, W, 3) -> (3, H, W) + im = torch.from_numpy(im) # shape (3, 299, 299) + return im + +def preprocess_images(images, use_multiprocessing=False): + """ + Applies `preprocess_image` to a batch of images. 
+ images: (N, H, W, 3) + Returns: torch.Tensor shape (N, 3, 299, 299) + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [pool.apply_async(preprocess_image, (im,)) for im in images] + final_images = torch.zeros(len(images), 3, 299, 299, dtype=torch.float32) + for idx, job in enumerate(jobs): + final_images[idx] = job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + + return final_images def calculate_fid(images1, images2, use_multiprocessing=False, batch_size=64): - """ Calculate FID between images1 and images2 - Args: - images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 - images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 - use_multiprocessing: If multiprocessing should be used to pre-process the images - batch size: batch size used for inception network - Returns: - FID (scalar) """ + Calculate FID between two sets of images. + images1, images2: np.array shape (N, H, W, 3) + Returns: FID (float) + """ + # Preprocess to shape (N,3,299,299), float32 in [0,1] images1 = preprocess_images(images1, use_multiprocessing) images2 = preprocess_images(images2, use_multiprocessing) - mu1, sigma1 = calculate_activation_statistics(images1, batch_size) - print("mu1", mu1.shape, "sigma1", sigma1.shape) - mu2, sigma2 = calculate_activation_statistics(images2, batch_size) - print("mu2", mu2.shape, "sigma2", sigma2.shape) + + # Compute mu, sigma + mu1, sigma1 = calculate_activation_statistics(images1, batch_size=batch_size) + mu2, sigma2 = calculate_activation_statistics(images2, batch_size=batch_size) + + # Compute Frechet distance fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) return fid + def load_style_generated_images(path, exclude="Abstractionism", seed=[188, 288, 588, 688, 888]): """ Loads all .png or .jpg images from a given path Warnings: Expects all images to be of same dtype and shape. 
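A note on the FID helpers introduced above: calculate_frechet_distance implements the closed-form distance between two Gaussians,

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2))

A quick sanity check on hand-sized toy statistics (a sketch only: it assumes the repository and its stable_diffusion dependency import cleanly, and real Inception activations are 2048-dimensional rather than the 4 dimensions used here):

import numpy as np
from mu.helpers.utils import calculate_frechet_distance

# Identical Gaussians: every term cancels and the distance is 0
mu1, sigma1 = np.zeros(4), np.eye(4)
assert np.isclose(calculate_frechet_distance(mu1, sigma1, mu1, sigma1), 0.0)

# Shifted mean, doubled covariance:
#   ||mu1 - mu2||^2                     = 4 * 0.5^2         = 1.0
#   Tr(s1) + Tr(s2) - 2*Tr(sqrt(s1@s2)) = 4 + 8 - 8*sqrt(2) ~= 0.686
mu2, sigma2 = np.full(4, 0.5), 2.0 * np.eye(4)
print(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))  # ~1.686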
From 6f816eace3e4df28cd6b9678f61526d6a36d6264 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 7 Jan 2025 14:07:36 +0000 Subject: [PATCH 06/10] bugfix in evaluation framework --- mu/algorithms/esd/.gitignore | 1 + mu/algorithms/esd/__init__.py | 2 + .../esd/configs/evaluation_config.yaml | 18 ++++--- mu/algorithms/esd/configs/generate_sd.yaml | 10 ++-- mu/algorithms/esd/configs/train_config.yaml | 6 ++- mu/algorithms/esd/evalutator.py | 52 +++++++++++++------ mu/algorithms/esd/scripts/evaluate.py | 34 ++++++------ 7 files changed, 75 insertions(+), 48 deletions(-) create mode 100644 mu/algorithms/esd/.gitignore diff --git a/mu/algorithms/esd/.gitignore b/mu/algorithms/esd/.gitignore new file mode 100644 index 00000000..aa850f42 --- /dev/null +++ b/mu/algorithms/esd/.gitignore @@ -0,0 +1 @@ +src/* \ No newline at end of file diff --git a/mu/algorithms/esd/__init__.py b/mu/algorithms/esd/__init__.py index 46d2b31e..89300efe 100644 --- a/mu/algorithms/esd/__init__.py +++ b/mu/algorithms/esd/__init__.py @@ -2,10 +2,12 @@ from .model import ESDModel from .sampler import ESDSampler from .trainer import ESDTrainer +from .evalutator import ESDEvaluator __all__ = [ 'ESDAlgorithm', 'ESDModel', 'ESDSampler', 'ESDTrainer', + 'ESDEvaluator' ] \ No newline at end of file diff --git a/mu/algorithms/esd/configs/evaluation_config.yaml b/mu/algorithms/esd/configs/evaluation_config.yaml index c49e96f3..7f2d300a 100644 --- a/mu/algorithms/esd/configs/evaluation_config.yaml +++ b/mu/algorithms/esd/configs/evaluation_config.yaml @@ -1,16 +1,20 @@ model_config: "mu/algorithms/esd/configs/generate_sd.yaml" -ckpt_path: "../mu_erasing_concept_esd/results/style50/" +ckpt_path: "outputs/esd/finetuned_models/esd_Abstractionism_model.pth" theme: "Abstractionism" cfg_text: 9.0 seed: 188 +task: "class" ddim_steps: 100 image_height: 512 image_width: 512 ddim_eta: 0.0 -sampler_output_dir: "output/eval_results/mu_results/esd/" -classification_model: "vit_large_patch16_224.augreg_in21k" -eval_output_dir: "output/eval_results/mu_results/esd/" -reference_dir: "data/evaluation_images/" -forget_theme: "self-harm" -fid_output_path: "output/eval_results/mu_results/esd/" +sampler_output_dir: "outputs/eval_results/mu_results/esd/" +seed_list: ["188"] +# classification_model: "vit_large_patch16_224.augreg_in21k" +classification_model: "vit_large_patch16_224" +classification_ckpt: "outputs/esd/finetuned_models/esd_Abstractionism_model.pth" +eval_output_dir: "outputs/eval_results/mu_results/esd/" +reference_dir: "/home/ubuntu/Projects/msu_unlearningalgorithm/data/quick-canvas-dataset/sample/" +forget_theme: "Abstractionism" +fid_output_path: "outputs/eval_results/mu_results/esd/" multiprocessing: "False" \ No newline at end of file diff --git a/mu/algorithms/esd/configs/generate_sd.yaml b/mu/algorithms/esd/configs/generate_sd.yaml index d4effe56..cf7f8131 100644 --- a/mu/algorithms/esd/configs/generate_sd.yaml +++ b/mu/algorithms/esd/configs/generate_sd.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: stable_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -18,7 +18,7 @@ model: use_ema: False scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler + target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 10000 ] cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -27,7 +27,7 @@ model: f_min: [ 1. 
] unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel params: image_size: 32 # unused in_channels: 4 @@ -44,7 +44,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -67,4 +67,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/mu/algorithms/esd/configs/train_config.yaml b/mu/algorithms/esd/configs/train_config.yaml index 86642dca..e800b72b 100644 --- a/mu/algorithms/esd/configs/train_config.yaml +++ b/mu/algorithms/esd/configs/train_config.yaml @@ -9,10 +9,12 @@ ddim_steps: 50 # Optional: DDIM steps of inference # Model configuration model_config_path: "mu/algorithms/esd/configs/model_config.yaml" -ckpt_path: "models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion +ckpt_path: "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion # Dataset directories -raw_dataset_dir: "data/quick-canvas-dataset/sample" +# raw_dataset_dir: "data/quick-canvas-dataset/sample" + +raw_dataset_dir: "/home/ubuntu/Projects/msu_unlearningalgorithm/data/quick-canvas-dataset/sample" processed_dataset_dir: "mu/algorithms/esd/data" dataset_type: "unlearncanvas" # Choices: ['unlearncanvas', 'i2p'] template: "style" # Choices: ['object', 'style', 'i2p'] diff --git a/mu/algorithms/esd/evalutator.py b/mu/algorithms/esd/evalutator.py index 746d0a98..b1a48718 100644 --- a/mu/algorithms/esd/evalutator.py +++ b/mu/algorithms/esd/evalutator.py @@ -17,6 +17,9 @@ from torch.nn import functional as F from mu.core.base_evaluator import BaseEvaluator +theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] +class_available = ['Architectures', 'Bears', 'Birds'] + class ESDSampler(BaseSampler): """ESD Image Generator class extending a hypothetical BaseImageGenerator.""" @@ -44,7 +47,7 @@ def load_model(self) -> None: """ self.logger.info("Loading model...") model_ckpt_path = self.config["ckpt_path"] - model_config = load_config(self.config["model_config_path"]) + model_config = load_config(self.config["model_config"]) self.model = load_ckpt_from_config(model_config, model_ckpt_path, verbose=True) self.model.to(self.device) self.model.eval() @@ -68,7 +71,7 @@ def sample(self) -> None: self.logger.info(f"Generating images and saving to {output_dir}") seed_everything(seed) - + for test_theme in theme_available: for object_class in class_available: prompt = f"A {object_class} image in {test_theme.replace('_',' ')} style." @@ -104,13 +107,17 @@ def sample(self) -> None: # Convert to uint8 image x_sample = x_samples_ddim[0] - x_sample = (255. * x_sample.numpy()).round() + # x_sample = (255. * x_sample.numpy()).round() + if isinstance(x_sample, torch.Tensor): + x_sample = (255. * x_sample.cpu().detach().numpy()).round() + else: + x_sample = (255. 
* x_sample).round() x_sample = x_sample.astype(np.uint8) img = Image.fromarray(x_sample) #save image - filename = f"{test_theme}_{object_class}_seed{seed}.jpg" - outpath = os.path.join(output_dir, filename) + filename = f"{test_theme}_{object_class}_seed_{seed}.jpg" + outpath = os.path.join(output_dir,theme, filename) self.save_image(img, outpath) self.logger.info("Image generation completed.") @@ -130,14 +137,14 @@ class ESDEvaluator(BaseEvaluator): Inherits from the abstract BaseEvaluator. """ - def __init__(self, sampler: Any, config: Dict[str, Any], **kwargs): + def __init__(self,config: Dict[str, Any], **kwargs): """ Args: sampler (Any): An instance of a BaseSampler-derived class (e.g., ESDSampler). config (Dict[str, Any]): A dict of hyperparameters / evaluation settings. **kwargs: Additional overrides for config. """ - super().__init__(sampler, config, **kwargs) + super().__init__(config, **kwargs) self.logger = logging.getLogger(__name__) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.sampler = ESDSampler(config) @@ -163,10 +170,10 @@ def load_model(self, *args, **kwargs): # Load checkpoint ckpt_path = self.config["classification_ckpt"] self.logger.info(f"Loading classification checkpoint from: {ckpt_path}") - checkpoint = torch.load(ckpt_path, map_location=self.device) - self.classification_model.load_state_dict(checkpoint["model_state_dict"]) + #NOTE: changed model_state_dict to state_dict as it was not present and added strict=False + self.classification_model.load_state_dict(torch.load(ckpt_path, map_location=self.device)["state_dict"],strict=False) self.classification_model.eval() - + self.logger.info("Classification model loaded successfully.") def preprocess_image(self, image: Image.Image): @@ -198,6 +205,8 @@ def calculate_accuracy(self, *args, **kwargs): if theme is not None: input_dir = os.path.join(input_dir, theme) + else: + input_dir = os.path.join(input_dir) os.makedirs(output_dir, exist_ok=True) output_path = (os.path.join(output_dir, f"{theme}.pth") @@ -273,7 +282,7 @@ def calculate_accuracy(self, *args, **kwargs): for seed in seed_list: for idx, object_class in enumerate(class_available): label_val = idx - img_file = f"{test_theme}_{object_class}_seed{seed}.jpg" + img_file = f"{test_theme}_{object_class}_seed_{seed}.jpg" img_path = os.path.join(input_dir, img_file) if not os.path.exists(img_path): self.logger.warning(f"Image not found: {img_path}") @@ -352,17 +361,26 @@ def save_results(self, results: dict, output_path: str): def run(self, *args, **kwargs): """ Run the complete evaluation process: - 1) Load the classification model - 2) Generate images (if you want to use sampler here) or skip if already generated - 3) Calculate accuracy - 4) calculate FID - 5) [Save final results if needed] + 1) Load the classification model + 2) Generate images (using sampler) + 3) Calculate accuracy + 4) Calculate FID + 5) Save final results """ + + # Call the sample method to generate images + # self.sampler.load_model() + # self.sampler.sample() + + # Load the classification model self.load_model() - + + # Proceed with accuracy and FID calculations self.calculate_accuracy() self.calculate_fid_score() + # Save results self.save_results(self.results, os.path.join(self.config["eval_output_dir"], "final_results.pth")) self.logger.info("Evaluation run completed.") + diff --git a/mu/algorithms/esd/scripts/evaluate.py b/mu/algorithms/esd/scripts/evaluate.py index 4c707a1d..b490c1cf 100644 --- a/mu/algorithms/esd/scripts/evaluate.py +++ 
b/mu/algorithms/esd/scripts/evaluate.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser from mu.helpers import load_config -from mu.algorithms.esd.evalutator import ESDEvaluator +from mu.algorithms.esd import ESDEvaluator def main(): """Main entry point for running the entire pipeline.""" @@ -8,22 +8,22 @@ def main(): parser.add_argument('--config_path', required=True, help="Path to the YAML config file.") # Below: optional overrides for your config dictionary - parser.add_argument('--model_config', type=str, help="Override path for model_config") - parser.add_argument('--ckpt_path', type=str, help="Override checkpoint path") - parser.add_argument('--theme', type=str, help="Override the theme in config") - parser.add_argument('--cfg_text', type=float, help="Override the cfg_text (guidance scale)") - parser.add_argument('--seed', type=int, help="Override the seed") - parser.add_argument('--ddim_steps', type=int, help="Override the number of ddim_steps") - parser.add_argument('--image_height', type=int, help="Override image height") - parser.add_argument('--image_width', type=int, help="Override image width") - parser.add_argument('--ddim_eta', type=float, help="Override DDIM eta") - parser.add_argument('--sampler_output_dir', type=str, help="Override output directory for sampler") - parser.add_argument('--classification_model', type=str, help="Override the classification model name") - parser.add_argument('--eval_output_dir', type=str, help="Override evaluation output directory") - parser.add_argument('--reference_dir', type=str, help="Override reference images directory") - parser.add_argument('--forget_theme', type=str, help="Override the forget_theme setting") - parser.add_argument('--multiprocessing', type=str, help="Override the multiprocessing flag (True/False)") - parser.add_argument('--batch_size', type=int, help="Override the FID batch_size") + parser.add_argument('--model_config', type=str, help="Path for model_config") + parser.add_argument('--ckpt_path', type=str, help="checkpoint path") + parser.add_argument('--theme', type=str, help="theme") + parser.add_argument('--cfg_text', type=float, help="(guidance scale)") + parser.add_argument('--seed', type=int, help="seed") + parser.add_argument('--ddim_steps', type=int, help="number of ddim_steps") + parser.add_argument('--image_height', type=int, help="image height") + parser.add_argument('--image_width', type=int, help="image width") + parser.add_argument('--ddim_eta', type=float, help="DDIM eta") + parser.add_argument('--sampler_output_dir', type=str, help="output directory for sampler") + parser.add_argument('--classification_model', type=str, help="classification model name") + parser.add_argument('--eval_output_dir', type=str, help="evaluation output directory") + parser.add_argument('--reference_dir', type=str, help="reference images directory") + parser.add_argument('--forget_theme', type=str, help="forget_theme setting") + parser.add_argument('--multiprocessing', type=str, help="multiprocessing flag (True/False)") + parser.add_argument('--batch_size', type=int, help="FID batch_size") args = parser.parse_args() From 9698e256d4e6fe88e56fb950eba95eeb6a95fdae Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Tue, 7 Jan 2025 14:18:24 +0000 Subject: [PATCH 07/10] bugfix --- mu/algorithms/esd/environment.yaml | 2 +- mu/core/base_evaluator.py | 4 ++-- mu/helpers/utils.py | 16 ++++++++++++---- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mu/algorithms/esd/environment.yaml b/mu/algorithms/esd/environment.yaml 
index fc426eed..a4d31f84 100644 --- a/mu/algorithms/esd/environment.yaml +++ b/mu/algorithms/esd/environment.yaml @@ -45,7 +45,7 @@ dependencies: - h5py # This is for ICL_Inpainting - xtcocotools # This is for ICL_Inpainting - natsort # This is for ICL_Inpainting - - timm==0.3.2 # This is for ICL_Inpainting + - timm==0.6.7 # This is for ICL_Inpainting - git+https://github.com/cocodataset/panopticapi.git # This is for ICL_Inpainting - fairscale # This is for ICL_Inpainting - git+https://github.com/facebookresearch/detectron2.git # This is for ICL_Inpainting diff --git a/mu/core/base_evaluator.py b/mu/core/base_evaluator.py index 56d0690c..53aee0ad 100644 --- a/mu/core/base_evaluator.py +++ b/mu/core/base_evaluator.py @@ -4,8 +4,8 @@ class BaseEvaluator(ABC): """Abstract base class for evaluating image generation models.""" - def __init__(self, sampler: Any, config: Dict[str, Any], **kwargs): - self.sampler =sampler + def __init__(self,config: Dict[str, Any], **kwargs): + # self.sampler =sampler self.config = config @abstractmethod diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 4fcb846b..8d090472 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -30,7 +30,11 @@ from torch import nn from torchvision.models import inception_v3 -from constants.const import theme_available, class_available +from stable_diffusion.constants.const import theme_available, class_available + +theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] +# class_available = ['Architectures', 'Bears', 'Birds'] +class_available = ['Architectures'] def str2bool(v): @@ -160,7 +164,7 @@ def load_ckpt_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) + model = instantiate_from_config(config['model']) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") @@ -295,6 +299,10 @@ def preprocess_images(images, use_multiprocessing=False): images: (N, H, W, 3) Returns: torch.Tensor shape (N, 3, 299, 299) """ + if str(use_multiprocessing).lower() == "true": + use_multiprocessing = True + else: + use_multiprocessing = False if use_multiprocessing: with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: jobs = [pool.apply_async(preprocess_image, (im,)) for im in images] @@ -351,7 +359,7 @@ def load_style_generated_images(path, exclude="Abstractionism", seed=[188, 288, for theme in theme_tested: for object_class in class_tested: for individual in seed: - image_paths.append(os.path.join(path, f"{theme}_{object_class}_seed{individual}.jpg")) + image_paths.append(os.path.join(path,"Abstractionism", f"{theme}_{object_class}_seed_{individual}.jpg")) if not os.path.isfile(image_paths[0]): raise FileNotFoundError(f"Could not find {image_paths[0]}") @@ -396,7 +404,7 @@ def load_style_ref_images(path, exclude="Seed_Images"): for theme in theme_tested: for object_class in class_tested: for idx in range(1, 6): - image_paths.append(os.path.join(path, theme, object_class, str(idx) + ".jpg")) + image_paths.append(os.path.join(path, "Abstractionism", object_class, str(idx) + ".jpg")) first_image = cv2.imread(image_paths[0]) W, H = 512, 512 From db7e034c9f65f7908558ad1d924d3978bebc1447 Mon Sep 17 00:00:00 2001 From: palisthadeshar Date: Wed, 8 Jan 2025 11:52:46 +0545 Subject: [PATCH 08/10] bugfix --- mu/algorithms/esd/evalutator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mu/algorithms/esd/evalutator.py 
b/mu/algorithms/esd/evalutator.py index d5f40d57..14ca1dd6 100644
--- a/mu/algorithms/esd/evalutator.py
+++ b/mu/algorithms/esd/evalutator.py
@@ -369,8 +369,8 @@ def run(self, *args, **kwargs):
"""

# Call the sample method to generate images
- # self.sampler.load_model()
- # self.sampler.sample()
+ self.sampler.load_model()
+ self.sampler.sample()

# Load the classification model
self.load_model()

From 0622957ab6ed79dc857c49945d556ac793ffeb61 Mon Sep 17 00:00:00 2001
From: palisthadeshar
Date: Wed, 8 Jan 2025 11:59:00 +0545
Subject: [PATCH 09/10] path fix

---
mu/algorithms/esd/evalutator.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mu/algorithms/esd/evalutator.py b/mu/algorithms/esd/evalutator.py
index d5f40d57..14ca1dd6 100644
--- a/mu/algorithms/esd/evalutator.py
+++ b/mu/algorithms/esd/evalutator.py
@@ -67,7 +67,7 @@ def sample(self) -> None:
ddim_eta = self.config["ddim_eta"]
output_dir = self.config["sampler_output_dir"]

- os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(os.path.join(output_dir, theme), exist_ok=True)
self.logger.info(f"Generating images and saving to {output_dir}")
seed_everything(seed)

From abf40cd99b31e355cf6f5623952e58bff3c33856 Mon Sep 17 00:00:00 2001
From: palisthadeshar
Date: Wed, 8 Jan 2025 12:16:55 +0545
Subject: [PATCH 10/10] output path fix

---
mu/algorithms/esd/evalutator.py | 4 ++--
mu/helpers/utils.py | 10 +++++-----
2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mu/algorithms/esd/evalutator.py b/mu/algorithms/esd/evalutator.py
index 14ca1dd6..407d02c8 100644
--- a/mu/algorithms/esd/evalutator.py
+++ b/mu/algorithms/esd/evalutator.py
@@ -17,8 +17,8 @@ from torch.nn import functional as F
from mu.core.base_evaluator import BaseEvaluator

-theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
-class_available = ['Architectures', 'Bears', 'Birds']
+# theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
+# class_available = ['Architectures', 'Bears', 'Birds']

class ESDSampler(BaseSampler):
"""ESD Image Generator class extending a hypothetical BaseImageGenerator."""
diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py
index 8d090472..ecbbd36b 100644
--- a/mu/helpers/utils.py
+++ b/mu/helpers/utils.py
@@ -32,9 +32,9 @@
from stable_diffusion.constants.const import theme_available, class_available

-theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
-# class_available = ['Architectures', 'Bears', 'Birds']
-class_available = ['Architectures']
+# theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
+# # class_available = ['Architectures', 'Bears', 'Birds']
+# class_available = ['Architectures']

def str2bool(v):
@@ -359,7 +359,7 @@ def load_style_generated_images(path, exclude="Abstractionism", seed=[188, 288,
for theme in theme_tested:
for object_class in class_tested:
for individual in seed:
- image_paths.append(os.path.join(path,"Abstractionism", f"{theme}_{object_class}_seed_{individual}.jpg"))
+ image_paths.append(os.path.join(path, theme, f"{theme}_{object_class}_seed_{individual}.jpg"))

if not os.path.isfile(image_paths[0]):
raise FileNotFoundError(f"Could not find {image_paths[0]}")
@@ -404,7 +404,7 @@ def load_style_ref_images(path, exclude="Seed_Images"):
for theme in theme_tested:
for object_class in class_tested:
for idx in range(1, 6):
- image_paths.append(os.path.join(path, "Abstractionism", object_class, str(idx) + ".jpg"))
+ image_paths.append(os.path.join(path, theme, object_class, str(idx) + ".jpg"))

first_image = cv2.imread(image_paths[0])
W, H = 512, 512
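With the path fixes above in place, the evaluation pipeline can be exercised end to end. A minimal sketch of driving it directly, assuming the ckpt_path, classification_ckpt, and reference_dir entries in evaluation_config.yaml point at real artifacts on disk; this mirrors what mu/algorithms/esd/scripts/evaluate.py wires up behind its --config_path argument:

from mu.helpers import load_config
from mu.algorithms.esd import ESDEvaluator

# Parse the YAML evaluation settings
config = load_config("mu/algorithms/esd/configs/evaluation_config.yaml")

# run() samples images into sampler_output_dir/<theme> via ESDSampler,
# scores unlearning/retaining accuracy with the timm ViT classifier,
# computes FID against reference_dir, and saves final_results.pth
# under eval_output_dir.
evaluator = ESDEvaluator(config)
evaluator.run()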