diff --git a/mu/algorithms/esd/.gitignore b/mu/algorithms/esd/.gitignore
new file mode 100644
index 00000000..aa850f42
--- /dev/null
+++ b/mu/algorithms/esd/.gitignore
@@ -0,0 +1 @@
+src/*
\ No newline at end of file
diff --git a/mu/algorithms/esd/__init__.py b/mu/algorithms/esd/__init__.py
index 46d2b31e..89300efe 100644
--- a/mu/algorithms/esd/__init__.py
+++ b/mu/algorithms/esd/__init__.py
@@ -2,10 +2,12 @@
 from .model import ESDModel
 from .sampler import ESDSampler
 from .trainer import ESDTrainer
+from .evalutator import ESDEvaluator
 
 __all__ = [
     'ESDAlgorithm',
     'ESDModel',
     'ESDSampler',
     'ESDTrainer',
+    'ESDEvaluator'
 ]
\ No newline at end of file
diff --git a/mu/algorithms/esd/configs/evaluation_config.yaml b/mu/algorithms/esd/configs/evaluation_config.yaml
new file mode 100644
index 00000000..7f2d300a
--- /dev/null
+++ b/mu/algorithms/esd/configs/evaluation_config.yaml
@@ -0,0 +1,20 @@
+model_config: "mu/algorithms/esd/configs/generate_sd.yaml"
+ckpt_path: "outputs/esd/finetuned_models/esd_Abstractionism_model.pth"
+theme: "Abstractionism"
+cfg_text: 9.0
+seed: 188
+task: "class"
+ddim_steps: 100
+image_height: 512
+image_width: 512
+ddim_eta: 0.0
+sampler_output_dir: "outputs/eval_results/mu_results/esd/"
+seed_list: ["188"]
+# classification_model: "vit_large_patch16_224.augreg_in21k"
+classification_model: "vit_large_patch16_224"
+classification_ckpt: "outputs/esd/finetuned_models/esd_Abstractionism_model.pth"
+eval_output_dir: "outputs/eval_results/mu_results/esd/"
+reference_dir: "/home/ubuntu/Projects/msu_unlearningalgorithm/data/quick-canvas-dataset/sample/"
+forget_theme: "Abstractionism"
+fid_output_path: "outputs/eval_results/mu_results/esd/"
+multiprocessing: "False"
\ No newline at end of file
diff --git a/mu/algorithms/esd/configs/generate_sd.yaml b/mu/algorithms/esd/configs/generate_sd.yaml
new file mode 100644
index 00000000..cf7f8131
--- /dev/null
+++ b/mu/algorithms/esd/configs/generate_sd.yaml
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: stable_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/mu/algorithms/esd/configs/train_config.yaml b/mu/algorithms/esd/configs/train_config.yaml
index 86642dca..e800b72b 100644
--- a/mu/algorithms/esd/configs/train_config.yaml
+++ b/mu/algorithms/esd/configs/train_config.yaml
@@ -9,10 +9,12 @@ ddim_steps: 50 # Optional: DDIM steps of inference
 
 # Model configuration
 model_config_path: "mu/algorithms/esd/configs/model_config.yaml"
-ckpt_path: "models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion
+ckpt_path: "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion
 
 # Dataset directories
-raw_dataset_dir: "data/quick-canvas-dataset/sample"
+# raw_dataset_dir: "data/quick-canvas-dataset/sample"
+
+raw_dataset_dir: "/home/ubuntu/Projects/msu_unlearningalgorithm/data/quick-canvas-dataset/sample"
 processed_dataset_dir: "mu/algorithms/esd/data"
 dataset_type: "unlearncanvas" # Choices: ['unlearncanvas', 'i2p']
 template: "style" # Choices: ['object', 'style', 'i2p']
diff --git a/mu/algorithms/esd/environment.yaml b/mu/algorithms/esd/environment.yaml
index fc426eed..a4d31f84 100644
--- a/mu/algorithms/esd/environment.yaml
+++ b/mu/algorithms/esd/environment.yaml
@@ -45,7 +45,7 @@ dependencies:
     - h5py # This is for ICL_Inpainting
     - xtcocotools # This is for ICL_Inpainting
     - natsort # This is for ICL_Inpainting
-    - timm==0.3.2 # This is for ICL_Inpainting
+    - timm==0.6.7 # This is for ICL_Inpainting
     - git+https://github.com/cocodataset/panopticapi.git # This is for ICL_Inpainting
     - fairscale # This is for ICL_Inpainting
     - git+https://github.com/facebookresearch/detectron2.git # This is for ICL_Inpainting
diff --git a/mu/algorithms/esd/evalutator.py b/mu/algorithms/esd/evalutator.py
new file mode 100644
index 00000000..407d02c8
--- /dev/null
+++ b/mu/algorithms/esd/evalutator.py
@@ -0,0 +1,386 @@
+import os
+import logging
+import torch
+import numpy as np
+from PIL import Image
+from torch import autocast
+from pytorch_lightning import seed_everything
+from mu.core.base_sampler import BaseSampler
+from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler
+from stable_diffusion.constants.const import theme_available, class_available
+from mu.helpers import load_config
+from mu.helpers.utils import load_ckpt_from_config, load_style_generated_images, load_style_ref_images, calculate_fid
+import timm
+from tqdm import tqdm
+from typing import Any, Dict
+from torchvision import transforms
+from torch.nn import functional as F
+from mu.core.base_evaluator import BaseEvaluator
+
+# theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
+# class_available = ['Architectures', 'Bears', 'Birds']
+
+
+class ESDSampler(BaseSampler):
+    """ESD image sampler used for evaluation, extending BaseSampler."""
+
+    def __init__(self, config: dict, **kwargs):
+        """
+        Initialize the ESDSampler with a YAML config (or dict).
+
+        Args:
+            config (Dict[str, Any]): Dictionary of hyperparams / settings.
+            **kwargs: Additional keyword arguments that can override config entries.
+        """
+        super().__init__()
+
+        self.config = config
+        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
+        self.model = None
+        self.sampler = None
+        self.logger = logging.getLogger(__name__)
+
+    def load_model(self) -> None:
+        """
+        Load the model using `config` and initialize the sampler.
+        This example shows a direct load of a single model + DDIMSampler.
+        In practice, if you have an ESDModel with two models, adapt accordingly.
+        """
+        self.logger.info("Loading model...")
+        model_ckpt_path = self.config["ckpt_path"]
+        model_config = load_config(self.config["model_config"])
+        self.model = load_ckpt_from_config(model_config, model_ckpt_path, verbose=True)
+        self.model.to(self.device)
+        self.model.eval()
+        self.sampler = DDIMSampler(self.model)
+        self.logger.info("Model loaded and sampler initialized successfully.")
+
+    def sample(self) -> None:
+        """
+        Sample (generate) images using the loaded model and sampler, based on the config.
+        """
+        steps = self.config["ddim_steps"]
+        theme = self.config["theme"]
+        cfg_text = self.config["cfg_text"]
+        seed = self.config["seed"]
+        H = self.config["image_height"]
+        W = self.config["image_width"]
+        ddim_eta = self.config["ddim_eta"]
+        output_dir = self.config["sampler_output_dir"]
+
+        os.makedirs(os.path.join(output_dir, theme), exist_ok=True)
+        self.logger.info(f"Generating images and saving to {output_dir}")
+
+        seed_everything(seed)
+
+        for test_theme in theme_available:
+            for object_class in class_available:
+                prompt = f"A {object_class} image in {test_theme.replace('_', ' ')} style."
+                self.logger.info(f"Sampling prompt: {prompt}")
+
+                with torch.no_grad():
+                    with autocast(self.device):
+                        with self.model.ema_scope():
+                            uc = self.model.get_learned_conditioning([""])
+                            c = self.model.get_learned_conditioning(prompt)
+                            shape = [4, H // 8, W // 8]
+                            # Generate samples
+                            samples_ddim, _ = self.sampler.sample(
+                                S=steps,
+                                conditioning=c,
+                                batch_size=1,
+                                shape=shape,
+                                verbose=False,
+                                unconditional_guidance_scale=cfg_text,
+                                unconditional_conditioning=uc,
+                                eta=ddim_eta,
+                                x_T=None
+                            )
+
+                            # Convert to numpy image
+                            x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                            x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+                            assert len(x_samples_ddim) == 1
+
+                            # Convert to uint8 image
+                            x_sample = x_samples_ddim[0]
+
+                            # x_sample = (255. * x_sample.numpy()).round()
+                            if isinstance(x_sample, torch.Tensor):
+                                x_sample = (255. * x_sample.cpu().detach().numpy()).round()
+                            else:
+                                x_sample = (255. * x_sample).round()
+                            x_sample = x_sample.astype(np.uint8)
+                            img = Image.fromarray(x_sample)
+
+                            # Save image
+                            filename = f"{test_theme}_{object_class}_seed_{seed}.jpg"
+                            outpath = os.path.join(output_dir, theme, filename)
+                            self.save_image(img, outpath)
+
+        self.logger.info("Image generation completed.")
+
+    def save_image(self, image: Image.Image, file_path: str) -> None:
+        """
+        Save an image to the specified path.
+        """
+        image.save(file_path)
+        self.logger.info(f"Image saved at: {file_path}")
+ """ + image.save(file_path) + self.logger.info(f"Image saved at: {file_path}") + + + +class ESDEvaluator(BaseEvaluator): + """ + Example evaluator that calculates classification accuracy on generated images. + Inherits from the abstract BaseEvaluator. + """ + + def __init__(self,config: Dict[str, Any], **kwargs): + """ + Args: + sampler (Any): An instance of a BaseSampler-derived class (e.g., ESDSampler). + config (Dict[str, Any]): A dict of hyperparameters / evaluation settings. + **kwargs: Additional overrides for config. + """ + super().__init__(config, **kwargs) + self.logger = logging.getLogger(__name__) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.sampler = ESDSampler(config) + self.classification_model = None + self.results = {} + + def load_model(self, *args, **kwargs): + """ + Load the classification model for evaluation, using 'timm' + or any approach you prefer. + We assume your config has 'classification_ckpt' and 'task' keys, etc. + """ + self.logger.info("Loading classification model...") + model = self.config.get("classification_model") + self.classification_model = timm.create_model( + model, + pretrained=True + ).to(self.device) + task = self.config['task'] # "style" or "class" + num_classes = len(theme_available) if task == "style" else len(class_available) + self.classification_model.head = torch.nn.Linear(1024, num_classes).to(self.device) + + # Load checkpoint + ckpt_path = self.config["classification_ckpt"] + self.logger.info(f"Loading classification checkpoint from: {ckpt_path}") + #NOTE: changed model_state_dict to state_dict as it was not present and added strict=False + self.classification_model.load_state_dict(torch.load(ckpt_path, map_location=self.device)["state_dict"],strict=False) + self.classification_model.eval() + + self.logger.info("Classification model loaded successfully.") + + def preprocess_image(self, image: Image.Image): + """ + Preprocess the input PIL image before feeding into the classifier. + Replicates the transforms from your accuracy.py script. + """ + image_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + return image_transform(image).unsqueeze(0).to(self.device) + + def calculate_accuracy(self, *args, **kwargs): + """ + Calculate accuracy of the classification model on generated images. + Mirrors the logic from your accuracy.py but integrated into a single method. 
+ """ + self.logger.info("Starting accuracy calculation...") + + # Pull relevant config + theme = self.config.get("theme", None) + input_dir = self.config['sampler_output_dir'] + output_dir = self.config["eval_output_dir"] + seed_list = self.config.get("seed_list", [188, 288, 588, 688, 888]) + dry_run = self.config.get("dry_run", False) + task = self.config['task'] + + if theme is not None: + input_dir = os.path.join(input_dir, theme) + else: + input_dir = os.path.join(input_dir) + + os.makedirs(output_dir, exist_ok=True) + output_path = (os.path.join(output_dir, f"{theme}.pth") + if theme is not None + else os.path.join(output_dir, "result.pth")) + + # Initialize results dictionary + self.results = { + "test_theme": theme if theme is not None else "sd", + "input_dir": input_dir, + } + + if task == "style": + self.results["loss"] = {th: 0.0 for th in theme_available} + self.results["acc"] = {th: 0.0 for th in theme_available} + self.results["pred_loss"] = {th: 0.0 for th in theme_available} + self.results["misclassified"] = { + th: {oth: 0 for oth in theme_available} + for th in theme_available + } + else: # task == "class" + self.results["loss"] = {cls_: 0.0 for cls_ in class_available} + self.results["acc"] = {cls_: 0.0 for cls_ in class_available} + self.results["pred_loss"] = {cls_: 0.0 for cls_ in class_available} + self.results["misclassified"] = { + cls_: {other_cls: 0 for other_cls in class_available} + for cls_ in class_available + } + + # Evaluate + if task == "style": + for idx, test_theme in tqdm(enumerate(theme_available), total=len(theme_available)): + theme_label = idx + for seed in seed_list: + for object_class in class_available: + img_file = f"{test_theme}_{object_class}_seed{seed}.jpg" + img_path = os.path.join(input_dir, img_file) + if not os.path.exists(img_path): + self.logger.warning(f"Image not found: {img_path}") + continue + + # Preprocess + image = Image.open(img_path) + tensor_img = self.preprocess_image(image) + label = torch.tensor([theme_label]).to(self.device) + + # Forward pass + with torch.no_grad(): + res = self.classification_model(tensor_img) + + # Compute losses + loss = F.cross_entropy(res, label) + res_softmax = F.softmax(res, dim=1) + pred_loss_val = res_softmax[0][theme_label].item() + pred_label = torch.argmax(res).item() + pred_success = (pred_label == theme_label) + + # Accumulate stats + self.results["loss"][test_theme] += loss.item() + self.results["pred_loss"][test_theme] += pred_loss_val + # Probability of success is 1 if pred_success else 0, + # but for your code, you were dividing by total. 
+
+                        misclassified_as = theme_available[pred_label]
+                        self.results["misclassified"][test_theme][misclassified_as] += 1
+
+                if not dry_run:
+                    self.save_results(self.results, output_path)
+
+        else:  # task == "class"
+            for test_theme in tqdm(theme_available, total=len(theme_available)):
+                for seed in seed_list:
+                    for idx, object_class in enumerate(class_available):
+                        label_val = idx
+                        img_file = f"{test_theme}_{object_class}_seed_{seed}.jpg"
+                        img_path = os.path.join(input_dir, img_file)
+                        if not os.path.exists(img_path):
+                            self.logger.warning(f"Image not found: {img_path}")
+                            continue
+
+                        # Preprocess
+                        image = Image.open(img_path)
+                        tensor_img = self.preprocess_image(image)
+                        label = torch.tensor([label_val]).to(self.device)
+
+                        with torch.no_grad():
+                            res = self.classification_model(tensor_img)
+
+                        loss = F.cross_entropy(res, label)
+                        res_softmax = F.softmax(res, dim=1)
+                        pred_loss_val = res_softmax[0][label_val].item()
+                        pred_label = torch.argmax(res).item()
+                        pred_success = (pred_label == label_val)
+
+                        self.results["loss"][object_class] += loss.item()
+                        self.results["pred_loss"][object_class] += pred_loss_val
+                        self.results["acc"][object_class] += (1 if pred_success else 0)
+
+                        misclassified_as = class_available[pred_label]
+                        self.results["misclassified"][object_class][misclassified_as] += 1
+
+                if not dry_run:
+                    self.save_results(self.results, output_path)
+
+        self.logger.info("Accuracy calculation completed.")
+
+    def calculate_fid_score(self, *args, **kwargs):
+        """
+        Calculate the Fréchet Inception Distance (FID) score using the images
+        generated by ESDSampler vs. a set of reference images.
+        """
+        self.logger.info("Starting FID calculation...")
+
+        generated_path = self.config["sampler_output_dir"]
+        reference_path = self.config["reference_dir"]
+        forget_theme = self.config.get("forget_theme", None)
+        use_multiprocessing = self.config.get("multiprocessing", False)
+        batch_size = self.config.get("batch_size", 64)
+        output_dir = self.config["fid_output_path"]
+        os.makedirs(output_dir, exist_ok=True)
+
+        images_generated = load_style_generated_images(
+            path=generated_path,
+            exclude=forget_theme,
+            seed=self.config.get("seed_list", [188, 288, 588, 688, 888])
+        )
+        images_reference = load_style_ref_images(
+            path=reference_path,
+            exclude=forget_theme
+        )
+
+        fid_value = calculate_fid(
+            images1=images_reference,
+            images2=images_generated,
+            use_multiprocessing=use_multiprocessing,
+            batch_size=batch_size
+        )
+        self.logger.info(f"Calculated FID: {fid_value}")
+        self.results["FID"] = fid_value
+        fid_path = os.path.join(output_dir, "fid_value.pth")
+        torch.save({"FID": fid_value}, fid_path)
+        self.logger.info(f"FID results saved to: {fid_path}")
+
+    def save_results(self, results: dict, output_path: str):
+        """
+        Save evaluation results to a file. JSON or CSV would also work if desired.
+        """
+ """ + torch.save(results, output_path) + self.logger.info(f"Results saved to: {output_path}") + + def run(self, *args, **kwargs): + """ + Run the complete evaluation process: + 1) Load the classification model + 2) Generate images (using sampler) + 3) Calculate accuracy + 4) Calculate FID + 5) Save final results + """ + + # Call the sample method to generate images + self.sampler.load_model() + self.sampler.sample() + + # Load the classification model + self.load_model() + + # Proceed with accuracy and FID calculations + self.calculate_accuracy() + self.calculate_fid_score() + + # Save results + self.save_results(self.results, os.path.join(self.config["eval_output_dir"], "final_results.pth")) + + self.logger.info("Evaluation run completed.") + diff --git a/mu/algorithms/esd/scripts/evaluate.py b/mu/algorithms/esd/scripts/evaluate.py new file mode 100644 index 00000000..b490c1cf --- /dev/null +++ b/mu/algorithms/esd/scripts/evaluate.py @@ -0,0 +1,70 @@ +from argparse import ArgumentParser +from mu.helpers import load_config +from mu.algorithms.esd import ESDEvaluator + +def main(): + """Main entry point for running the entire pipeline.""" + parser = ArgumentParser(description="Unified ESD Evaluation and Sampling") + parser.add_argument('--config_path', required=True, help="Path to the YAML config file.") + + # Below: optional overrides for your config dictionary + parser.add_argument('--model_config', type=str, help="Path for model_config") + parser.add_argument('--ckpt_path', type=str, help="checkpoint path") + parser.add_argument('--theme', type=str, help="theme") + parser.add_argument('--cfg_text', type=float, help="(guidance scale)") + parser.add_argument('--seed', type=int, help="seed") + parser.add_argument('--ddim_steps', type=int, help="number of ddim_steps") + parser.add_argument('--image_height', type=int, help="image height") + parser.add_argument('--image_width', type=int, help="image width") + parser.add_argument('--ddim_eta', type=float, help="DDIM eta") + parser.add_argument('--sampler_output_dir', type=str, help="output directory for sampler") + parser.add_argument('--classification_model', type=str, help="classification model name") + parser.add_argument('--eval_output_dir', type=str, help="evaluation output directory") + parser.add_argument('--reference_dir', type=str, help="reference images directory") + parser.add_argument('--forget_theme', type=str, help="forget_theme setting") + parser.add_argument('--multiprocessing', type=str, help="multiprocessing flag (True/False)") + parser.add_argument('--batch_size', type=int, help="FID batch_size") + + args = parser.parse_args() + + config = load_config(args.config_path) + + # Override config fields if CLI arguments are provided + if args.model_config is not None: + config["model_config"] = args.model_config + if args.ckpt_path is not None: + config["ckpt_path"] = args.ckpt_path + if args.theme is not None: + config["theme"] = args.theme + if args.cfg_text is not None: + config["cfg_text"] = args.cfg_text + if args.seed is not None: + config["seed"] = args.seed + if args.ddim_steps is not None: + config["ddim_steps"] = args.ddim_steps + if args.image_height is not None: + config["image_height"] = args.image_height + if args.image_width is not None: + config["image_width"] = args.image_width + if args.ddim_eta is not None: + config["ddim_eta"] = args.ddim_eta + if args.sampler_output_dir is not None: + config["sampler_output_dir"] = args.sampler_output_dir + if args.classification_model is not None: + 
config["classification_model"] = args.classification_model + if args.eval_output_dir is not None: + config["eval_output_dir"] = args.eval_output_dir + if args.reference_dir is not None: + config["reference_dir"] = args.reference_dir + if args.forget_theme is not None: + config["forget_theme"] = args.forget_theme + if args.multiprocessing is not None: + config["multiprocessing"] = args.multiprocessing + if args.batch_size is not None: + config["batch_size"] = args.batch_size + + evaluator = ESDEvaluator(config) + evaluator.run() + +if __name__ == "__main__": + main() diff --git a/mu/core/base_evaluator.py b/mu/core/base_evaluator.py new file mode 100644 index 00000000..53aee0ad --- /dev/null +++ b/mu/core/base_evaluator.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +class BaseEvaluator(ABC): + """Abstract base class for evaluating image generation models.""" + + def __init__(self,config: Dict[str, Any], **kwargs): + # self.sampler =sampler + self.config = config + + @abstractmethod + def load_model(self, *args, **kwargs): + """Load the model for evaluation.""" + pass + + @abstractmethod + def preprocess_image(self, *args, **kwargs): + """Preprocess images before evaluation.""" + pass + + @abstractmethod + def calculate_accuracy(self, *args, **kwargs): + """Calculate accuracy of the model.""" + pass + + @abstractmethod + def calculate_fid_score(self, *args, **kwargs): + """Calculate the Fréchet Inception Distance (FID) score.""" + pass + + @abstractmethod + def save_results(self, *args, **kwargs): + """Save evaluation results to a file.""" + pass + + @abstractmethod + def run(self, *args, **kwargs): + """Run the evaluation process.""" + pass \ No newline at end of file diff --git a/mu/core/base_sampler.py b/mu/core/base_sampler.py index aef5c267..771bc1ee 100644 --- a/mu/core/base_sampler.py +++ b/mu/core/base_sampler.py @@ -18,3 +18,12 @@ def sample(self, **kwargs) -> Any: Any: Generated samples. 
""" pass + + def load_model(self, *args, **kwargs): + """Load an image.""" + pass + + def save_image(self, *args, **kwargs): + """Save an image.""" + pass + diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 3bff1472..ecbbd36b 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -7,10 +7,36 @@ # from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only from pathlib import Path - +import multiprocessing from stable_diffusion.ldm.util import instantiate_from_config +import torch +import cv2 +import numpy as np +from torchvision.models import inception_v3 +from torch import nn +from scipy import linalg +from stable_diffusion.constants.const import theme_available, class_available +import tqdm +import os +import cv2 +import torch +import warnings +import numpy as np +from tqdm import tqdm +from scipy import linalg +import multiprocessing +from torch import nn +from torchvision.models import inception_v3 + +from stable_diffusion.constants.const import theme_available, class_available + +# theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] +# # class_available = ['Architectures', 'Bears', 'Birds'] +# class_available = ['Architectures'] + + def str2bool(v): if isinstance(v, bool): return v @@ -129,4 +155,266 @@ def load_config_from_yaml(config_path): @rank_zero_only def rank_zero_print(*args): - print(*args) \ No newline at end of file + print(*args) + + +def load_ckpt_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config['model']) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + + +def to_cuda(elements): + """Transfers elements to CUDA if GPU is available.""" + if torch.cuda.is_available(): + return elements.to("cuda") + return elements + +class PartialInceptionNetwork(nn.Module): + """ + A modified InceptionV3 network used for feature extraction. + Captures activations from the Mixed_7c layer and outputs shape (N, 2048). + """ + def __init__(self, transform_input=True): + super().__init__() + self.inception_network = inception_v3(pretrained=True) + # Register a forward hook to capture activations from Mixed_7c + self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + def output_hook(self, module, input, output): + self.mixed_7c_output = output # shape (N, 2048, 8, 8) + + def forward(self, x): + """ + x: (N, 3, 299, 299) float32 in [0,1] + Returns: (N, 2048) float32 + """ + assert x.shape[1:] == (3, 299, 299), f"Expected (N,3,299,299), got {x.shape}" + # Shift to [-1, 1] + x = x * 2 - 1 + # Trigger output hook + _ = self.inception_network(x) + # Collect the activations + activations = self.mixed_7c_output # (N, 2048, 8, 8) + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + +def get_activations(images, batch_size=64): + """ + Calculates activations for the last pool layer for all images using PartialInceptionNetwork. 
+    images: shape (N, 3, 299, 299), float32 in [0,1]
+    Returns: np.array shape (N, 2048)
+    """
+    assert images.shape[1:] == (3, 299, 299)
+    num_images = images.shape[0]
+    inception_net = PartialInceptionNetwork().eval()
+    inception_net = to_cuda(inception_net)
+
+    n_batches = int(np.ceil(num_images / batch_size))
+    inception_activations = np.zeros((num_images, 2048), dtype=np.float32)
+
+    idx = 0
+    for _ in range(n_batches):
+        start = idx
+        end = min(start + batch_size, num_images)
+        ims = images[start:end]
+        ims = to_cuda(ims)
+        with torch.no_grad():
+            batch_activations = inception_net(ims)
+        inception_activations[start:end, :] = batch_activations.cpu().numpy()
+        idx = end
+    return inception_activations
+
+
+def calculate_activation_statistics(images, batch_size=64):
+    """
+    Calculates the mean (mu) and covariance matrix (sigma) for Inception activations.
+    images: shape (N, 3, 299, 299)
+    Returns: (mu, sigma)
+    """
+    act = get_activations(images, batch_size=batch_size)
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """
+    Computes the Frechet Distance between two multivariate Gaussians described by
+    (mu1, sigma1) and (mu2, sigma2).
+    """
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        warnings.warn("FID calculation produced singular product; adding offset to covariances.")
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
+
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError(f"Imaginary component in sqrtm: {m}")
+        covmean = covmean.real
+
+    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
+
+
+def preprocess_image(im):
+    """
+    Resizes to 299x299, converts dtype to float32 in [0,1], and rearranges shape to (3, 299, 299).
+    """
+    # If im is uint8, scale to float32
+    if im.dtype == np.uint8:
+        im = im.astype(np.float32) / 255.0
+    im = cv2.resize(im, (299, 299))
+    im = np.rollaxis(im, 2, 0)  # (H, W, 3) -> (3, H, W)
+    im = torch.from_numpy(im)  # shape (3, 299, 299)
+    return im
+
+
+def preprocess_images(images, use_multiprocessing=False):
+    """
+    Applies `preprocess_image` to a batch of images.
+    images: (N, H, W, 3)
+    Returns: torch.Tensor shape (N, 3, 299, 299)
+    """
+    if str(use_multiprocessing).lower() == "true":
+        use_multiprocessing = True
+    else:
+        use_multiprocessing = False
+    if use_multiprocessing:
+        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
+            jobs = [pool.apply_async(preprocess_image, (im,)) for im in images]
+            final_images = torch.zeros(len(images), 3, 299, 299, dtype=torch.float32)
+            for idx, job in enumerate(jobs):
+                final_images[idx] = job.get()
+    else:
+        final_images = torch.stack([preprocess_image(im) for im in images], dim=0)
+
+    return final_images
+
+
+def calculate_fid(images1, images2, use_multiprocessing=False, batch_size=64):
+    """
+    Calculate FID between two sets of images.
+    images1, images2: np.array shape (N, H, W, 3)
+    Returns: FID (float)
+    """
+    # Preprocess to shape (N, 3, 299, 299), float32 in [0,1]
+    images1 = preprocess_images(images1, use_multiprocessing)
+    images2 = preprocess_images(images2, use_multiprocessing)
+
+    # Compute mu, sigma
+    mu1, sigma1 = calculate_activation_statistics(images1, batch_size=batch_size)
+    mu2, sigma2 = calculate_activation_statistics(images2, batch_size=batch_size)
+
+    # Compute Frechet distance
+    fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
+    return fid
+
+
+def load_style_generated_images(path, exclude="Abstractionism", seed=[188, 288, 588, 688, 888]):
+    """Loads all .png or .jpg images from a given path.
+    Warnings: Expects all images to be of same dtype and shape.
+    Args:
+        path: relative path to directory
+    Returns:
+        final_images: np.array of image dtype and shape.
+    """
+    image_paths = []
+
+    if exclude is not None:
+        if exclude in theme_available:
+            theme_tested = [x for x in theme_available]
+            theme_tested.remove(exclude)
+            class_tested = class_available
+        else:  # exclude is a class
+            theme_tested = theme_available
+            class_tested = [x for x in class_available]
+            class_tested.remove(exclude)
+    else:
+        theme_tested = theme_available
+        class_tested = class_available
+    for theme in theme_tested:
+        for object_class in class_tested:
+            for individual in seed:
+                image_paths.append(os.path.join(path, theme, f"{theme}_{object_class}_seed_{individual}.jpg"))
+    if not os.path.isfile(image_paths[0]):
+        raise FileNotFoundError(f"Could not find {image_paths[0]}")
+
+    first_image = cv2.imread(image_paths[0])
+    W, H = 512, 512
+    image_paths.sort()
+    final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype)
+    for idx, impath in tqdm(enumerate(image_paths)):
+        im = cv2.imread(impath)
+        im = cv2.resize(im, (W, H))  # Resize image to 512x512
+        im = im[:, :, ::-1]  # Convert from BGR to RGB
+        assert im.dtype == final_images.dtype
+        final_images[idx] = im
+    return final_images
+
+
+def load_style_ref_images(path, exclude="Seed_Images"):
+    """Loads all .png or .jpg images from a given path.
+    Warnings: Expects all images to be of same dtype and shape.
+    Args:
+        path: relative path to directory
+    Returns:
+        final_images: np.array of image dtype and shape.
+    """
+    image_paths = []
+
+    if exclude is not None:
+        # assert exclude in theme_available, f"{exclude} not in {theme_available}"
+        if exclude in theme_available:
+            theme_tested = [x for x in theme_available]
+            theme_tested.remove(exclude)
+            class_tested = class_available
+        else:  # exclude is a class
+            theme_tested = theme_available
+            class_tested = [x for x in class_available]
+            class_tested.remove(exclude)
+    else:
+        theme_tested = theme_available
+        class_tested = class_available
+
+    for theme in theme_tested:
+        for object_class in class_tested:
+            for idx in range(1, 6):
+                image_paths.append(os.path.join(path, theme, object_class, str(idx) + ".jpg"))
+
+    first_image = cv2.imread(image_paths[0])
+    W, H = 512, 512
+    image_paths.sort()
+    final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype)
+    for idx, impath in tqdm(enumerate(image_paths)):
+        im = cv2.imread(impath)
+        im = cv2.resize(im, (W, H))  # Resize image to 512x512
+        im = im[:, :, ::-1]  # Convert from BGR to RGB
+        assert im.dtype == final_images.dtype
+        final_images[idx] = im
+    return final_images
\ No newline at end of file
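
The evaluation pipeline added above is normally driven through the new scripts/evaluate.py CLI, but it can also be called directly from Python. A minimal sketch, assuming the modules introduced in this diff are importable and using the evaluation_config.yaml added above; the override values shown are illustrative only:

    # Programmatic equivalent of:
    #   python mu/algorithms/esd/scripts/evaluate.py --config_path mu/algorithms/esd/configs/evaluation_config.yaml
    from mu.helpers import load_config
    from mu.algorithms.esd import ESDEvaluator

    config = load_config("mu/algorithms/esd/configs/evaluation_config.yaml")
    config["theme"] = "Abstractionism"   # optional override, mirrors the --theme CLI flag
    config["cfg_text"] = 9.0             # optional override, mirrors the --cfg_text CLI flag

    evaluator = ESDEvaluator(config)
    evaluator.run()  # sample -> load classifier -> accuracy -> FID -> save final_results.pth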