diff --git a/.gitignore b/.gitignore index 41ddb2ca..9308137d 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ outputs/* models !models/.gitkeep -!mu/algorithms/erase_diff/.gitignore \ No newline at end of file +!mu/algorithms/erase_diff/.gitignore +.venv \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..7946afd5 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +known_third_party = IPython,PIL,accelerate,albumentations,algorithms,cleanfid,clip,core,cv2,datasets,diffusers,einops,fire,huggingface_hub,imwatermark,kornia,ldm,main,matplotlib,nitro,nudenet,numpy,omegaconf,pandas,prettytable,pydantic,pytorch_lightning,quadprog,requests,safetensors,scann,scipy,six,src,taming,tensorflow,tensorflow_gan,tensorflow_hub,timm,torch,torchvision,tqdm,transformers,wandb,yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..66dedef5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +default_language_version: + python: python3.12.7 +repos: + + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + + - repo: https://github.com/pre-commit/mirrors-isort + rev: v5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py index 286e4fec..29a77a36 100644 --- a/lora_diffusion/__init__.py +++ b/lora_diffusion/__init__.py @@ -1,5 +1,5 @@ -from .lora import * from .dataset import * -from .utils import * -from .preprocess_files import * +from .lora import * from .lora_manager import * +from .preprocess_files import * +from .utils import * diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py index fc7f7e4a..f9c2d148 100644 --- a/lora_diffusion/cli_lora_add.py +++ b/lora_diffusion/cli_lora_add.py @@ -1,17 +1,13 @@ -from typing import Literal, Union, Dict import os import shutil +from typing import Dict, Literal, Union + import fire +import torch from diffusers import StableDiffusionPipeline from safetensors.torch import safe_open, save_file -import torch -from .lora import ( - tune_lora_scale, - patch_pipe, - collapse_lora, - monkeypatch_remove_lora, -) +from .lora import collapse_lora, monkeypatch_remove_lora, patch_pipe, tune_lora_scale from .lora_manager import lora_join from .to_ckpt_v2 import convert_to_ckpt diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 7de4bae1..90fefea0 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -10,12 +10,14 @@ import random import re from pathlib import Path -from typing import Optional, List, Literal +from typing import List, Literal, Optional +import fire import torch import torch.nn.functional as F import torch.optim as optim import torch.utils.checkpoint +import wandb from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -29,20 +31,18 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -import wandb -import fire from lora_diffusion import ( + UNET_EXTENDED_TARGET_REPLACE, PivotalTuningDatasetCapation, + evaluate_pipe, extract_lora_ups_down, inject_trainable_lora, inject_trainable_lora_extended, inspect_lora, - save_lora_weight, - save_all, prepare_clip_model_sets, - evaluate_pipe, - UNET_EXTENDED_TARGET_REPLACE, + save_all, + save_lora_weight, ) diff --git a/lora_diffusion/cli_pt_to_safetensors.py b/lora_diffusion/cli_pt_to_safetensors.py index 9a4be40d..29aa6cd9 100644 --- a/lora_diffusion/cli_pt_to_safetensors.py +++ b/lora_diffusion/cli_pt_to_safetensors.py @@ -2,6 +2,7 @@ import fire import torch + from lora_diffusion import ( DEFAULT_TARGET_REPLACE, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, diff --git a/lora_diffusion/cli_svd.py b/lora_diffusion/cli_svd.py index cf52aa0b..ef423aa5 100644 --- a/lora_diffusion/cli_svd.py +++ b/lora_diffusion/cli_svd.py @@ -1,15 +1,15 @@ import fire -from diffusers import StableDiffusionPipeline import torch import torch.nn as nn +from diffusers import StableDiffusionPipeline from .lora import ( - save_all, - _find_modules, LoraInjectedConv2d, LoraInjectedLinear, + _find_modules, inject_trainable_lora, inject_trainable_lora_extended, + save_all, ) diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index f1c28fd7..98d816b3 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -1,3 +1,4 @@ +import glob import random from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -6,7 +7,7 @@ from torch import zeros_like from torch.utils.data import Dataset from torchvision import transforms -import glob + from .preprocess_files import face_mask_google_mediapipe OBJECT_TEMPLATE = [ diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py index 9d8306e4..ff646c50 100644 --- a/lora_diffusion/lora_manager.py +++ b/lora_diffusion/lora_manager.py @@ -1,12 +1,14 @@ from typing import List + import torch -from safetensors import safe_open from diffusers import StableDiffusionPipeline +from safetensors import safe_open + from .lora import ( - monkeypatch_or_replace_safeloras, apply_learned_embed_in_clip, - set_lora_diag, + monkeypatch_or_replace_safeloras, parse_safeloras_embeds, + set_lora_diag, ) diff --git a/lora_diffusion/patch_lora.py b/lora_diffusion/patch_lora.py index dc73ac83..69604194 100644 --- a/lora_diffusion/patch_lora.py +++ b/lora_diffusion/patch_lora.py @@ -1,5 +1,6 @@ +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + import torch -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any try: from safetensors.torch import safe_open diff --git a/lora_diffusion/preprocess_files.py b/lora_diffusion/preprocess_files.py index bedb89f5..e68869a2 100644 --- a/lora_diffusion/preprocess_files.py +++ b/lora_diffusion/preprocess_files.py @@ -2,15 +2,16 @@ # Have BLIP auto caption # Have CLIPSeg auto mask concept -from typing import List, Literal, Union, Optional, Tuple +import glob import os -from PIL import Image, ImageFilter -import torch -import numpy as np +from typing import List, Literal, Optional, Tuple, Union + import fire +import numpy as np +import torch +from PIL import Image, ImageFilter from tqdm import tqdm -import glob -from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation +from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor @torch.no_grad() @@ -133,7 +134,7 @@ def blip_captioning_dataset( Returns a list of captions for the given images """ - from transformers import BlipProcessor, BlipForConditionalGeneration + from transformers import BlipForConditionalGeneration, BlipProcessor processor = BlipProcessor.from_pretrained(model_id) model = BlipForConditionalGeneration.from_pretrained(model_id).to(device) diff --git a/lora_diffusion/to_ckpt_v2.py b/lora_diffusion/to_ckpt_v2.py index 15f39471..a02810d8 100644 --- a/lora_diffusion/to_ckpt_v2.py +++ b/lora_diffusion/to_ckpt_v2.py @@ -8,7 +8,6 @@ import torch - # =================# # UNet Conversion # # =================# diff --git a/lora_diffusion/utils.py b/lora_diffusion/utils.py index d8a3410d..a00cd16a 100644 --- a/lora_diffusion/utils.py +++ b/lora_diffusion/utils.py @@ -1,6 +1,10 @@ +import glob +import math +import os from typing import List, Union import torch +from diffusers import StableDiffusionPipeline from PIL import Image from transformers import ( CLIPProcessor, @@ -9,11 +13,7 @@ CLIPVisionModelWithProjection, ) -from diffusers import StableDiffusionPipeline -from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path -import os -import glob -import math +from .lora import _text_lora_path, _ti_lora_path, patch_pipe, tune_lora_scale EXAMPLE_PROMPTS = [ " swimming in a pool", diff --git a/mu/algorithms/concept_ablation/algorithm.py b/mu/algorithms/concept_ablation/algorithm.py index b45fb503..6ec382c1 100644 --- a/mu/algorithms/concept_ablation/algorithm.py +++ b/mu/algorithms/concept_ablation/algorithm.py @@ -1,12 +1,12 @@ import logging -import torch -import wandb from typing import Dict -from core.base_algorithm import BaseAlgorithm +import torch +import wandb +from algorithms.concept_ablation.data_handler import ConceptAblationDataHandler from algorithms.concept_ablation.model import ConceptAblationModel from algorithms.concept_ablation.trainer import ConceptAblationTrainer -from algorithms.concept_ablation.data_handler import ConceptAblationDataHandler +from core.base_algorithm import BaseAlgorithm class ConceptAblationAlgorithm(BaseAlgorithm): diff --git a/mu/algorithms/concept_ablation/datasets/dataset.py b/mu/algorithms/concept_ablation/datasets/dataset.py index 9994901c..4dbf0c30 100644 --- a/mu/algorithms/concept_ablation/datasets/dataset.py +++ b/mu/algorithms/concept_ablation/datasets/dataset.py @@ -1,11 +1,13 @@ import os -from typing import List, Tuple, Callable, Any from pathlib import Path -from PIL import Image +from typing import Any, Callable, List, Tuple + import torch -from torch.utils.data import Dataset +from PIL import Image from src import utils from src.utils import safe_dir +from torch.utils.data import Dataset + # Import your filtering logic if available # from src.filter import filter as filter_fn diff --git a/mu/algorithms/concept_ablation/handler.py b/mu/algorithms/concept_ablation/handler.py index 047e35b7..31049b68 100644 --- a/mu/algorithms/concept_ablation/handler.py +++ b/mu/algorithms/concept_ablation/handler.py @@ -1,10 +1,13 @@ -import os import logging +import os from typing import Dict -from torch.utils.data import DataLoader -from core.base_data_handler import BaseDataHandler -from mu.datasets.utils import get_transform, INTERPOLATIONS + from algorithms.concept_ablation.datasets.dataset import ConceptAblationDataset +from core.base_data_handler import BaseDataHandler +from torch.utils.data import DataLoader + +from mu.datasets.utils import INTERPOLATIONS, get_transform + class ConceptAblationDataHandler(BaseDataHandler): """ diff --git a/mu/algorithms/concept_ablation/model.py b/mu/algorithms/concept_ablation/model.py index 60a462ad..2e71af7f 100644 --- a/mu/algorithms/concept_ablation/model.py +++ b/mu/algorithms/concept_ablation/model.py @@ -1,10 +1,12 @@ -from core.base_model import BaseModel -from stable_diffusion.ldm.util import instantiate_from_config -from omegaconf import OmegaConf -import torch from pathlib import Path from typing import Any +import torch +from core.base_model import BaseModel +from omegaconf import OmegaConf + +from stable_diffusion.ldm.util import instantiate_from_config + class ConceptAblationModel(BaseModel): """ diff --git a/mu/algorithms/concept_ablation/scripts/train.py b/mu/algorithms/concept_ablation/scripts/train.py index e4c1c793..972d1ca1 100644 --- a/mu/algorithms/concept_ablation/scripts/train.py +++ b/mu/algorithms/concept_ablation/scripts/train.py @@ -1,11 +1,14 @@ import argparse -import os import logging +import os import sys # Adjust these imports according to your project structure from algorithms.concept_ablation.algorithm import ConceptAblationAlgorithm -from algorithms.erase_diff.logger import setup_logger # or create a similar logger for concept_ablation if needed +from algorithms.erase_diff.logger import ( + setup_logger, # or create a similar logger for concept_ablation if needed +) + def main(): parser = argparse.ArgumentParser( diff --git a/mu/algorithms/concept_ablation/trainer.py b/mu/algorithms/concept_ablation/trainer.py index 2c58c23e..75286887 100644 --- a/mu/algorithms/concept_ablation/trainer.py +++ b/mu/algorithms/concept_ablation/trainer.py @@ -1,12 +1,12 @@ +import logging +from typing import Dict + import torch import wandb +from core.base_trainer import BaseTrainer from torch.nn import MSELoss from torch.optim import Adam from tqdm import tqdm -import logging - -from core.base_trainer import BaseTrainer -from typing import Dict class ConceptAblationTrainer(BaseTrainer): diff --git a/mu/algorithms/erase_diff/algorithm.py b/mu/algorithms/erase_diff/algorithm.py index 4aca1ea8..c7c9f972 100644 --- a/mu/algorithms/erase_diff/algorithm.py +++ b/mu/algorithms/erase_diff/algorithm.py @@ -1,15 +1,17 @@ # mu/algorithms/erase_diff/algorithm.py -import torch -import wandb -from typing import Dict import logging from pathlib import Path +from typing import Dict -from mu.core import BaseAlgorithm +import torch +import wandb + +from mu.algorithms.erase_diff.data_handler import EraseDiffDataHandler from mu.algorithms.erase_diff.model import EraseDiffModel from mu.algorithms.erase_diff.trainer import EraseDiffTrainer -from mu.algorithms.erase_diff.data_handler import EraseDiffDataHandler +from mu.core import BaseAlgorithm + class EraseDiffAlgorithm(BaseAlgorithm): """ diff --git a/mu/algorithms/erase_diff/data_handler.py b/mu/algorithms/erase_diff/data_handler.py index cc2f457e..22937718 100644 --- a/mu/algorithms/erase_diff/data_handler.py +++ b/mu/algorithms/erase_diff/data_handler.py @@ -1,13 +1,13 @@ +import logging import os -import pandas as pd from typing import Any, Dict -from torch.utils.data import DataLoader -import logging +import pandas as pd +from torch.utils.data import DataLoader from mu.algorithms.erase_diff.datasets.erase_diff_dataset import EraseDiffDataset -from mu.datasets.constants import * from mu.core import BaseDataHandler +from mu.datasets.constants import * from mu.helpers import read_text_lines diff --git a/mu/algorithms/erase_diff/datasets/erase_diff_dataset.py b/mu/algorithms/erase_diff/datasets/erase_diff_dataset.py index 9fabbe41..db8c5403 100644 --- a/mu/algorithms/erase_diff/datasets/erase_diff_dataset.py +++ b/mu/algorithms/erase_diff/datasets/erase_diff_dataset.py @@ -1,10 +1,12 @@ import os -from typing import Any, Tuple, Dict +from typing import Any, Dict, Tuple + from torch.utils.data import DataLoader -from mu.datasets import UnlearnCanvasDataset, I2PDataset, BaseDataset +from mu.datasets import BaseDataset, I2PDataset, UnlearnCanvasDataset from mu.datasets.utils import INTERPOLATIONS, get_transform + class EraseDiffDataset(BaseDataset): """ Dataset class for the EraseDiff algorithm. diff --git a/mu/algorithms/erase_diff/model.py b/mu/algorithms/erase_diff/model.py index c409996c..52fc8121 100644 --- a/mu/algorithms/erase_diff/model.py +++ b/mu/algorithms/erase_diff/model.py @@ -1,13 +1,15 @@ # mu/algorithms/erase_diff/model.py -import torch +import logging from pathlib import Path from typing import Any -import logging + +import torch from mu.core import BaseModel from mu.helpers import load_model_from_config + class EraseDiffModel(BaseModel): """ EraseDiffModel handles loading, saving, and interacting with the Stable Diffusion model. diff --git a/mu/algorithms/erase_diff/scripts/train.py b/mu/algorithms/erase_diff/scripts/train.py index 0d8f58de..61b7910c 100644 --- a/mu/algorithms/erase_diff/scripts/train.py +++ b/mu/algorithms/erase_diff/scripts/train.py @@ -1,14 +1,15 @@ # mu/algorithms/erase_diff/scripts/train.py import argparse +import logging import os from pathlib import Path -import logging from mu.algorithms.erase_diff import EraseDiffAlgorithm -from mu.helpers import setup_logger, load_config +from mu.helpers import load_config, setup_logger from mu.helpers.path_setup import * + def main(): parser = argparse.ArgumentParser( prog='TrainEraseDiff', diff --git a/mu/algorithms/erase_diff/trainer.py b/mu/algorithms/erase_diff/trainer.py index e0831351..650186b2 100644 --- a/mu/algorithms/erase_diff/trainer.py +++ b/mu/algorithms/erase_diff/trainer.py @@ -1,18 +1,19 @@ # mu/algorithms/erase_diff/trainer.py -import torch import gc -from tqdm import tqdm -import random -from torch.nn import MSELoss -import wandb import logging +import random from pathlib import Path + +import torch +import wandb from timm.utils import AverageMeter -import logging +from torch.nn import MSELoss +from tqdm import tqdm -from mu.core import BaseTrainer from mu.algorithms.erase_diff.model import EraseDiffModel +from mu.core import BaseTrainer + class EraseDiffTrainer(BaseTrainer): """ diff --git a/mu/algorithms/esd/algorithm.py b/mu/algorithms/esd/algorithm.py index e8fc9b4d..1cac2788 100644 --- a/mu/algorithms/esd/algorithm.py +++ b/mu/algorithms/esd/algorithm.py @@ -1,13 +1,15 @@ +import logging +from pathlib import Path +from typing import Dict + import torch import wandb -from typing import Dict -from pathlib import Path -import logging -from mu.core import BaseAlgorithm from mu.algorithms.esd.model import ESDModel -from mu.algorithms.esd.trainer import ESDTrainer from mu.algorithms.esd.sampler import ESDSampler +from mu.algorithms.esd.trainer import ESDTrainer +from mu.core import BaseAlgorithm + class ESDAlgorithm(BaseAlgorithm): """ diff --git a/mu/algorithms/esd/model.py b/mu/algorithms/esd/model.py index 65e47611..3e2c684d 100644 --- a/mu/algorithms/esd/model.py +++ b/mu/algorithms/esd/model.py @@ -1,10 +1,12 @@ -import torch -from typing import Any from pathlib import Path +from typing import Any + +import torch from mu.core import BaseModel from mu.helpers import load_model_from_config + class ESDModel(BaseModel): """ ESDModel handles loading, saving, and interacting with the Stable Diffusion model. diff --git a/mu/algorithms/esd/sampler.py b/mu/algorithms/esd/sampler.py index a88ee526..e4775c4a 100644 --- a/mu/algorithms/esd/sampler.py +++ b/mu/algorithms/esd/sampler.py @@ -1,7 +1,8 @@ -from mu.core import BaseSampler -from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler from mu.algorithms.esd.algorithm import ESDModel +from mu.core import BaseSampler from mu.helpers import sample_model +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + class ESDSampler(BaseSampler): """Sampler for the ESD algorithm.""" diff --git a/mu/algorithms/esd/scripts/train.py b/mu/algorithms/esd/scripts/train.py index 2d92ccb5..3bbe8e01 100644 --- a/mu/algorithms/esd/scripts/train.py +++ b/mu/algorithms/esd/scripts/train.py @@ -1,12 +1,13 @@ # mu/algorithms/esd/scripts/train.py import argparse -import os import logging +import os from mu.algorithms.esd.algorithm import ESDAlgorithm -from mu.helpers import setup_logger, load_config, setup_logger -from mu.helpers.path_setup import * +from mu.helpers import load_config, setup_logger +from mu.helpers.path_setup import * + def main(): parser = argparse.ArgumentParser( diff --git a/mu/algorithms/esd/trainer.py b/mu/algorithms/esd/trainer.py index e147f8f2..64502895 100644 --- a/mu/algorithms/esd/trainer.py +++ b/mu/algorithms/esd/trainer.py @@ -1,13 +1,14 @@ -import torch -from tqdm import tqdm import random -from torch.nn import MSELoss +import torch +from torch.nn import MSELoss +from tqdm import tqdm -from mu.helpers import load_model_from_config, sample_model -from mu.core import BaseTrainer from mu.algorithms.esd.model import ESDModel from mu.algorithms.esd.sampler import ESDSampler +from mu.core import BaseTrainer +from mu.helpers import load_model_from_config, sample_model + class ESDTrainer(BaseTrainer): """Trainer for the ESD algorithm.""" diff --git a/mu/algorithms/forget_me_not/algorithm.py b/mu/algorithms/forget_me_not/algorithm.py index 134d6192..4be42ef5 100644 --- a/mu/algorithms/forget_me_not/algorithm.py +++ b/mu/algorithms/forget_me_not/algorithm.py @@ -1,14 +1,15 @@ # forget_me_not/algorithm.py import logging -import torch from typing import Dict -import wandb +import torch +import wandb from algorithms.forget_me_not.data_handler import ForgetMeNotDataHandler from algorithms.forget_me_not.model import ForgetMeNotModel from algorithms.forget_me_not.trainer import ForgetMeNotTrainer + class ForgetMeNotAlgorithm: """ Algorithm class orchestrating the Forget Me Not unlearning process. diff --git a/mu/algorithms/forget_me_not/data_handler.py b/mu/algorithms/forget_me_not/data_handler.py index dc4de535..dc317f48 100644 --- a/mu/algorithms/forget_me_not/data_handler.py +++ b/mu/algorithms/forget_me_not/data_handler.py @@ -1,11 +1,13 @@ # forget_me_not/data_handler.py +import os from typing import Any, Dict -from torch.utils.data import DataLoader -from core.base_data_handler import BaseDataHandler + from algorithms.forget_me_not.datasets.forget_me_not_dataset import ForgetMeNotDataset -from datasets.constants import * -import os +from core.base_data_handler import BaseDataHandler +from datasets.constants import * +from torch.utils.data import DataLoader + class ForgetMeNotDataHandler(BaseDataHandler): """ diff --git a/mu/algorithms/forget_me_not/datasets/forget_me_not_dataset.py b/mu/algorithms/forget_me_not/datasets/forget_me_not_dataset.py index 0a0aeb77..29a79621 100644 --- a/mu/algorithms/forget_me_not/datasets/forget_me_not_dataset.py +++ b/mu/algorithms/forget_me_not/datasets/forget_me_not_dataset.py @@ -6,8 +6,8 @@ import cv2 import numpy as np -from PIL import Image, ImageFilter from datasets.base_dataset import BaseDataset +from PIL import Image, ImageFilter from torchvision import transforms # Templates for object and style prompts diff --git a/mu/algorithms/forget_me_not/model.py b/mu/algorithms/forget_me_not/model.py index 3cb5e20d..e1ee8c4c 100644 --- a/mu/algorithms/forget_me_not/model.py +++ b/mu/algorithms/forget_me_not/model.py @@ -1,10 +1,22 @@ # forget_me_not/model.py import logging + import torch -from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) from transformers import CLIPTextModel, CLIPTokenizer -from lora_diffusion.patch_lora import safe_open, parse_safeloras_embeds, apply_learned_embed_in_clip + +from lora_diffusion.patch_lora import ( + apply_learned_embed_in_clip, + parse_safeloras_embeds, + safe_open, +) + class ForgetMeNotModel: """ diff --git a/mu/algorithms/forget_me_not/scripts/train_attn.py b/mu/algorithms/forget_me_not/scripts/train_attn.py index ea973ee4..2b11b0bb 100644 --- a/mu/algorithms/forget_me_not/scripts/train_attn.py +++ b/mu/algorithms/forget_me_not/scripts/train_attn.py @@ -2,8 +2,10 @@ import argparse import os + from algorithms.algorithms.forget_me_not.algorithm import ForgetMeNotAlgorithm + def main(): parser = argparse.ArgumentParser(description='Forget Me Not - Train Attention') parser.add_argument('--theme', type=str, required=True, help='Theme or concept to unlearn.') diff --git a/mu/algorithms/forget_me_not/scripts/train_ti.py b/mu/algorithms/forget_me_not/scripts/train_ti.py index 45ca1fb8..a0b47cb4 100644 --- a/mu/algorithms/forget_me_not/scripts/train_ti.py +++ b/mu/algorithms/forget_me_not/scripts/train_ti.py @@ -1,8 +1,10 @@ import argparse import os + import yaml from algorithms.forget_me_not.algorithm import ForgetMeNotAlgorithm + def load_config(yaml_path): """Loads the configuration from a YAML file.""" if os.path.exists(yaml_path): diff --git a/mu/algorithms/forget_me_not/trainer.py b/mu/algorithms/forget_me_not/trainer.py index b2029758..aa947c5a 100644 --- a/mu/algorithms/forget_me_not/trainer.py +++ b/mu/algorithms/forget_me_not/trainer.py @@ -1,16 +1,17 @@ # forget_me_not/trainer.py -import os -import math import logging +import math +import os from typing import Dict -from tqdm import tqdm + import torch +import torch.nn.functional as F +from accelerate.utils import set_seed +from diffusers.optimization import get_scheduler from torch.optim import AdamW +from tqdm import tqdm -from diffusers.optimization import get_scheduler -from accelerate.utils import set_seed -import torch.nn.functional as F class ForgetMeNotTrainer: """ diff --git a/mu/algorithms/saliency_unlearning/algorithm.py b/mu/algorithms/saliency_unlearning/algorithm.py index 8fb3376a..2ec5ea6c 100644 --- a/mu/algorithms/saliency_unlearning/algorithm.py +++ b/mu/algorithms/saliency_unlearning/algorithm.py @@ -1,14 +1,18 @@ +import logging import os -import torch -import wandb from typing import Dict -import logging -from core.base_algorithm import BaseAlgorithm +import torch +import wandb +from algorithms.saliency_unlearning.data_handler import SaliencyUnlearnDataHandler +from algorithms.saliency_unlearning.masking import ( + accumulate_gradients_for_mask, + save_mask, +) from algorithms.saliency_unlearning.model import SaliencyUnlearnModel from algorithms.saliency_unlearning.trainer import SaliencyUnlearnTrainer -from algorithms.saliency_unlearning.data_handler import SaliencyUnlearnDataHandler -from algorithms.saliency_unlearning.masking import accumulate_gradients_for_mask, save_mask +from core.base_algorithm import BaseAlgorithm + class SaliencyUnlearnAlgorithm(BaseAlgorithm): """ diff --git a/mu/algorithms/saliency_unlearning/data_handler.py b/mu/algorithms/saliency_unlearning/data_handler.py index 56de9d5e..cf412c91 100644 --- a/mu/algorithms/saliency_unlearning/data_handler.py +++ b/mu/algorithms/saliency_unlearning/data_handler.py @@ -2,14 +2,18 @@ import os from typing import Any, Dict, List -from torch.utils.data import DataLoader -from algorithms.saliency_unlearning.datasets.saliency_unlearn_dataset import SaliencyUnlearnDataset + +from algorithms.saliency_unlearning.datasets.saliency_unlearn_dataset import ( + SaliencyUnlearnDataset, +) from algorithms.saliency_unlearning.logger import setup_logger from core.base_data_handler import BaseDataHandler from datasets.constants import * +from torch.utils.data import DataLoader from mu.helpers import read_text_lines + class SaliencyUnlearnDataHandler(BaseDataHandler): """ Concrete data handler for the SaliencyUnlearn algorithm. diff --git a/mu/algorithms/saliency_unlearning/datasets/saliency_unlearn_dataset.py b/mu/algorithms/saliency_unlearning/datasets/saliency_unlearn_dataset.py index 9c0ded13..56b030b9 100644 --- a/mu/algorithms/saliency_unlearning/datasets/saliency_unlearn_dataset.py +++ b/mu/algorithms/saliency_unlearning/datasets/saliency_unlearn_dataset.py @@ -1,12 +1,15 @@ # algorithms/saliency_unlearning/datasets/saliency_unlearn_dataset.py import os -import torch -from typing import Any, Tuple, Dict -from torch.utils.data import DataLoader +from typing import Any, Dict, Tuple + +import torch from datasets.unlearn_canvas_dataset import UnlearnCanvasDataset +from torch.utils.data import DataLoader + from mu.datasets.utils import INTERPOLATIONS, get_transform + class SaliencyUnlearnDataset(UnlearnCanvasDataset): """ Dataset class for the SaliencyUnlearn algorithm. diff --git a/mu/algorithms/saliency_unlearning/logger.py b/mu/algorithms/saliency_unlearning/logger.py index 09f4a6a6..d1ad2cb8 100644 --- a/mu/algorithms/saliency_unlearning/logger.py +++ b/mu/algorithms/saliency_unlearning/logger.py @@ -1,6 +1,7 @@ import logging from pathlib import Path + def setup_logger(log_file: str = "erase_diff_training.log", level: int = logging.INFO) -> logging.Logger: """ Setup a logger for the training process. diff --git a/mu/algorithms/saliency_unlearning/masking.py b/mu/algorithms/saliency_unlearning/masking.py index b29f0560..d89b0b9b 100644 --- a/mu/algorithms/saliency_unlearning/masking.py +++ b/mu/algorithms/saliency_unlearning/masking.py @@ -1,8 +1,10 @@ +import logging +import os + import torch from torch.nn import MSELoss from tqdm import tqdm -import os -import logging + def accumulate_gradients_for_mask(model, forget_loader, prompt, c_guidance, device, lr=1e-5, num_timesteps=1000, threshold=0.5, batch_size=4): """ diff --git a/mu/algorithms/saliency_unlearning/model.py b/mu/algorithms/saliency_unlearning/model.py index 81d20348..4cdc2762 100644 --- a/mu/algorithms/saliency_unlearning/model.py +++ b/mu/algorithms/saliency_unlearning/model.py @@ -1,12 +1,15 @@ # algorithms/saliency_unlearning/model.py -from core.base_model import BaseModel -from stable_diffusion.ldm.util import instantiate_from_config -from omegaconf import OmegaConf -import torch from pathlib import Path from typing import Any, Dict +import torch +from core.base_model import BaseModel +from omegaconf import OmegaConf + +from stable_diffusion.ldm.util import instantiate_from_config + + class SaliencyUnlearnModel(BaseModel): """ SaliencyUnlearnModel handles loading, saving, and interacting with the Stable Diffusion model. diff --git a/mu/algorithms/saliency_unlearning/sampler.py b/mu/algorithms/saliency_unlearning/sampler.py index f113ecbb..b62dbdd1 100644 --- a/mu/algorithms/saliency_unlearning/sampler.py +++ b/mu/algorithms/saliency_unlearning/sampler.py @@ -1,8 +1,10 @@ # algorithms/saliency_unlearning/sampler.py -from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler from typing import Any +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + + class SaliencyUnlearnSampler(DDIMSampler): """ Sampler class for the SaliencyUnlearn algorithm. diff --git a/mu/algorithms/saliency_unlearning/scripts/generate_mask.py b/mu/algorithms/saliency_unlearning/scripts/generate_mask.py index 7a30c4e4..fc819a5c 100644 --- a/mu/algorithms/saliency_unlearning/scripts/generate_mask.py +++ b/mu/algorithms/saliency_unlearning/scripts/generate_mask.py @@ -1,10 +1,11 @@ +import argparse import os import sys -import argparse -import torch +import torch from algorithms.saliency_unlearning.algorithm import MaskingAlgorithm + def main(): parser = argparse.ArgumentParser(prog='GenerateMask', description='Generate saliency mask using MaskingAlgorithm.') diff --git a/mu/algorithms/saliency_unlearning/scripts/train.py b/mu/algorithms/saliency_unlearning/scripts/train.py index 65c08d33..8b9b121f 100644 --- a/mu/algorithms/saliency_unlearning/scripts/train.py +++ b/mu/algorithms/saliency_unlearning/scripts/train.py @@ -1,14 +1,14 @@ # saliency_unlearning/scripts/train.py -import os -import sys -import torch -from tqdm import tqdm import argparse import logging +import os +import sys +import torch from algorithms.saliency_unlearning.algorithm import SaliencyUnlearnAlgorithm from algorithms.saliency_unlearning.logger import setup_logger +from tqdm import tqdm if __name__ == '__main__': parser = argparse.ArgumentParser( diff --git a/mu/algorithms/saliency_unlearning/trainer.py b/mu/algorithms/saliency_unlearning/trainer.py index d0f94bc6..0df871af 100644 --- a/mu/algorithms/saliency_unlearning/trainer.py +++ b/mu/algorithms/saliency_unlearning/trainer.py @@ -1,21 +1,24 @@ # algorithms/saliency_unlearning/trainer.py -from core.base_trainer import BaseTrainer -from algorithms.saliency_unlearning.model import SaliencyUnlearnModel -import torch import gc -from tqdm import tqdm +import logging import random +from pathlib import Path +from typing import Dict + +import torch +import wandb +from algorithms.saliency_unlearning.model import SaliencyUnlearnModel from algorithms.saliency_unlearning.utils import load_model_from_config, sample_model +from core.base_trainer import BaseTrainer +from omegaconf import OmegaConf +from timm.utils import AverageMeter from torch.nn import MSELoss -import wandb +from tqdm import tqdm + from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler -import logging -from pathlib import Path from stable_diffusion.ldm.util import instantiate_from_config -from omegaconf import OmegaConf -from timm.utils import AverageMeter -from typing import Dict + class SaliencyUnlearnTrainer(BaseTrainer): """ diff --git a/mu/algorithms/saliency_unlearning/utils.py b/mu/algorithms/saliency_unlearning/utils.py index 759bdbfa..023a25ce 100644 --- a/mu/algorithms/saliency_unlearning/utils.py +++ b/mu/algorithms/saliency_unlearning/utils.py @@ -1,11 +1,14 @@ # algorithms/saliency_unlearning/utils.py -from omegaconf import OmegaConf -import torch -from typing import Any from pathlib import Path -from stable_diffusion.ldm.util import instantiate_from_config +from typing import Any + +import torch +from omegaconf import OmegaConf + from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler +from stable_diffusion.ldm.util import instantiate_from_config + def load_model_from_config(config_path: str, ckpt_path: str, device: str = "cpu") -> Any: """ diff --git a/mu/algorithms/scissorhands/algorithm.py b/mu/algorithms/scissorhands/algorithm.py index 6d94a70f..ad1f8bbb 100644 --- a/mu/algorithms/scissorhands/algorithm.py +++ b/mu/algorithms/scissorhands/algorithm.py @@ -1,14 +1,16 @@ -import torch -import wandb -from typing import Dict import logging from pathlib import Path +from typing import Dict -from mu.core import BaseAlgorithm +import torch +import wandb + +from mu.algorithms.scissorhands.data_handler import ScissorHandsDataHandler from mu.algorithms.scissorhands.model import ScissorHandsModel from mu.algorithms.scissorhands.trainer import ScissorHandsTrainer -from mu.algorithms.scissorhands.data_handler import ScissorHandsDataHandler +from mu.core import BaseAlgorithm + class ScissorHandsAlgorithm(BaseAlgorithm): """ diff --git a/mu/algorithms/scissorhands/data_handler.py b/mu/algorithms/scissorhands/data_handler.py index 1db06f57..4df1c160 100644 --- a/mu/algorithms/scissorhands/data_handler.py +++ b/mu/algorithms/scissorhands/data_handler.py @@ -1,11 +1,12 @@ +import logging import os from typing import Any, Dict, List + from torch.utils.data import DataLoader -import logging from mu.algorithms.scissorhands.datasets.scissorhands_dataset import ScissorHandsDataset -from mu.datasets.constants import * from mu.core import BaseDataHandler +from mu.datasets.constants import * from mu.helpers import read_text_lines diff --git a/mu/algorithms/scissorhands/datasets/scissorhands_dataset.py b/mu/algorithms/scissorhands/datasets/scissorhands_dataset.py index 1b701695..3861440a 100644 --- a/mu/algorithms/scissorhands/datasets/scissorhands_dataset.py +++ b/mu/algorithms/scissorhands/datasets/scissorhands_dataset.py @@ -1,12 +1,14 @@ # algorithms/scissorhands/datasets/erase_diff_dataset.py import os -from typing import Any, Tuple, Dict +from typing import Any, Dict, Tuple + from torch.utils.data import DataLoader -from mu.datasets import UnlearnCanvasDataset, I2PDataset, BaseDataset +from mu.datasets import BaseDataset, I2PDataset, UnlearnCanvasDataset from mu.datasets.utils import INTERPOLATIONS, get_transform + class ScissorHandsDataset(BaseDataset): """ Dataset class for the ScissorHands algorithm. diff --git a/mu/algorithms/scissorhands/model.py b/mu/algorithms/scissorhands/model.py index b37fb7e7..939c6025 100644 --- a/mu/algorithms/scissorhands/model.py +++ b/mu/algorithms/scissorhands/model.py @@ -1,10 +1,10 @@ # erase_diff/model.py -import torch +import logging from pathlib import Path from typing import Any -import logging +import torch from mu.core import BaseModel from mu.helpers import load_model_from_config diff --git a/mu/algorithms/scissorhands/scripts/train.py b/mu/algorithms/scissorhands/scripts/train.py index f6c9ee87..150fde02 100644 --- a/mu/algorithms/scissorhands/scripts/train.py +++ b/mu/algorithms/scissorhands/scripts/train.py @@ -1,14 +1,15 @@ # mu/algorithms/scissorhands/scripts/train.py import argparse +import logging import os from pathlib import Path -import logging from mu.algorithms.scissorhands.algorithm import ScissorHandsAlgorithm -from mu.helpers import setup_logger, load_config +from mu.helpers import load_config, setup_logger from mu.helpers.path_setup import * + def main(): parser = argparse.ArgumentParser( prog='TrainScissorHands', diff --git a/mu/algorithms/scissorhands/trainer.py b/mu/algorithms/scissorhands/trainer.py index 1eb7f83d..7ef0f03d 100644 --- a/mu/algorithms/scissorhands/trainer.py +++ b/mu/algorithms/scissorhands/trainer.py @@ -1,16 +1,16 @@ +import copy +import logging + import torch -from tqdm import tqdm from torch.nn import MSELoss -import logging -import copy +from tqdm import tqdm -from mu.core import BaseTrainer -from mu.algorithms.scissorhands.model import ScissorHandsModel from mu.algorithms.scissorhands.data_handler import EraseDiffDataHandler +from mu.algorithms.scissorhands.model import ScissorHandsModel +from mu.algorithms.scissorhands.utils import project2cone2, snip +from mu.core import BaseTrainer -from mu.algorithms.scissorhands.utils import snip, project2cone2 - class ScissorHandsTrainer(BaseTrainer): """ Trainer for the ScissorHands algorithm. diff --git a/mu/algorithms/scissorhands/utils.py b/mu/algorithms/scissorhands/utils.py index e9e98042..63a6a2c8 100644 --- a/mu/algorithms/scissorhands/utils.py +++ b/mu/algorithms/scissorhands/utils.py @@ -1,12 +1,13 @@ # mu/algorithms/scissorhands/utils.py -import torch -from pathlib import Path +import copy import gc +from pathlib import Path + import numpy as np -from timm.models.layers import trunc_normal_ -import copy import quadprog +import torch +from timm.models.layers import trunc_normal_ from torch.nn import MSELoss diff --git a/mu/algorithms/selective_amnesia/algorithm.py b/mu/algorithms/selective_amnesia/algorithm.py index 9cf1faf6..d9c07b9f 100644 --- a/mu/algorithms/selective_amnesia/algorithm.py +++ b/mu/algorithms/selective_amnesia/algorithm.py @@ -1,11 +1,12 @@ import logging from typing import Dict + import torch import wandb -from core.base_algorithm import BaseAlgorithm +from algorithms.selective_amnesia.data_handler import SelectiveAmnesiaDataHandler from algorithms.selective_amnesia.model import SelectiveAmnesiaModel from algorithms.selective_amnesia.trainer import SelectiveAmnesiaTrainer -from algorithms.selective_amnesia.data_handler import SelectiveAmnesiaDataHandler +from core.base_algorithm import BaseAlgorithm logger = logging.getLogger(__name__) diff --git a/mu/algorithms/selective_amnesia/data_handler.py b/mu/algorithms/selective_amnesia/data_handler.py index 4975b804..deedabe3 100644 --- a/mu/algorithms/selective_amnesia/data_handler.py +++ b/mu/algorithms/selective_amnesia/data_handler.py @@ -1,9 +1,11 @@ import logging from typing import Dict -from torch.utils.data import DataLoader -from core.base_data_handler import BaseDataHandler -from mu.datasets.utils import get_transform, INTERPOLATIONS + from algorithms.selective_amnesia.datasets.dataset import SelectiveAmnesiaDataset +from core.base_data_handler import BaseDataHandler +from torch.utils.data import DataLoader + +from mu.datasets.utils import INTERPOLATIONS, get_transform logger = logging.getLogger(__name__) diff --git a/mu/algorithms/selective_amnesia/datasets/selective_amnesia_dataset.py b/mu/algorithms/selective_amnesia/datasets/selective_amnesia_dataset.py index e17ca854..72818486 100644 --- a/mu/algorithms/selective_amnesia/datasets/selective_amnesia_dataset.py +++ b/mu/algorithms/selective_amnesia/datasets/selective_amnesia_dataset.py @@ -1,7 +1,9 @@ import os +from typing import Callable, List, Tuple + from PIL import Image from torch.utils.data import Dataset -from typing import List, Tuple, Callable + class SelectiveAmnesiaDataset(Dataset): """ diff --git a/mu/algorithms/selective_amnesia/model.py b/mu/algorithms/selective_amnesia/model.py index 336a6a86..0930cbed 100644 --- a/mu/algorithms/selective_amnesia/model.py +++ b/mu/algorithms/selective_amnesia/model.py @@ -1,12 +1,13 @@ -import torch -from omegaconf import OmegaConf -from stable_diffusion.ldm.util import instantiate_from_config -from core.base_model import BaseModel +import logging from pathlib import Path from typing import Any -import logging -from algorithms.selective_amnesia.utils import modify_weights, load_fim +import torch +from algorithms.selective_amnesia.utils import load_fim, modify_weights +from core.base_model import BaseModel +from omegaconf import OmegaConf + +from stable_diffusion.ldm.util import instantiate_from_config logger = logging.getLogger(__name__) diff --git a/mu/algorithms/selective_amnesia/scripts/train.py b/mu/algorithms/selective_amnesia/scripts/train.py index 1a692569..1efc3487 100644 --- a/mu/algorithms/selective_amnesia/scripts/train.py +++ b/mu/algorithms/selective_amnesia/scripts/train.py @@ -1,8 +1,12 @@ import argparse -import os import logging +import os + +from algorithms.erase_diff.logger import ( + setup_logger, # or adapt a similar logger if needed. +) from algorithms.selective_amnesia.algorithm import SelectiveAmnesiaAlgorithm -from algorithms.erase_diff.logger import setup_logger # or adapt a similar logger if needed. + def main(): parser = argparse.ArgumentParser(description='Train Selective Amnesia') diff --git a/mu/algorithms/selective_amnesia/trainer.py b/mu/algorithms/selective_amnesia/trainer.py index dbc5c5be..25278d67 100644 --- a/mu/algorithms/selective_amnesia/trainer.py +++ b/mu/algorithms/selective_amnesia/trainer.py @@ -1,11 +1,12 @@ +import logging +from typing import Dict + import torch +import wandb +from core.base_trainer import BaseTrainer from torch.nn import MSELoss from torch.optim import Adam -from core.base_trainer import BaseTrainer -import wandb -import logging from tqdm import tqdm -from typing import Dict logger = logging.getLogger(__name__) diff --git a/mu/algorithms/selective_amnesia/utils.py b/mu/algorithms/selective_amnesia/utils.py index c1eb3766..ec31070e 100644 --- a/mu/algorithms/selective_amnesia/utils.py +++ b/mu/algorithms/selective_amnesia/utils.py @@ -1,6 +1,7 @@ -import torch -import os import logging +import os + +import torch logger = logging.getLogger(__name__) diff --git a/mu/algorithms/semipermeable_membrane/algorithm.py b/mu/algorithms/semipermeable_membrane/algorithm.py index c032aba2..23f7b4ab 100644 --- a/mu/algorithms/semipermeable_membrane/algorithm.py +++ b/mu/algorithms/semipermeable_membrane/algorithm.py @@ -1,14 +1,17 @@ # semipermeable_membrane/algorithm.py import logging -import torch -import wandb from typing import Dict +import torch +import wandb +from algorithms.semipermeable_membrane.data_handler import ( + SemipermeableMembraneDataHandler, +) from algorithms.semipermeable_membrane.model import SemipermeableMembraneModel -from algorithms.semipermeable_membrane.data_handler import SemipermeableMembraneDataHandler from algorithms.semipermeable_membrane.trainer import SemipermeableMembraneTrainer + class SemipermeableMembraneAlgorithm: """ SemipermeableMembraneAlgorithm orchestrates the setup and training of the SPM method. diff --git a/mu/algorithms/semipermeable_membrane/data_handler.py b/mu/algorithms/semipermeable_membrane/data_handler.py index 4b48d33d..21214924 100644 --- a/mu/algorithms/semipermeable_membrane/data_handler.py +++ b/mu/algorithms/semipermeable_membrane/data_handler.py @@ -1,13 +1,14 @@ # semipermeable_membrane/data_handler.py -from typing import List, Optional, Tuple import logging -from core.base_data_handler import BaseDataHandler -from datasets.constants import * -import yaml from pathlib import Path +from typing import List, Optional, Tuple +import yaml from algorithms.semipermeable_membrane.src.configs.prompt import PromptSettings +from core.base_data_handler import BaseDataHandler +from datasets.constants import * + class SemipermeableMembraneDataHandler(BaseDataHandler): """ diff --git a/mu/algorithms/semipermeable_membrane/logger.py b/mu/algorithms/semipermeable_membrane/logger.py index 3406bd03..352b0dbe 100644 --- a/mu/algorithms/semipermeable_membrane/logger.py +++ b/mu/algorithms/semipermeable_membrane/logger.py @@ -2,6 +2,7 @@ import logging + def setup_logger(name: str, log_file: str = None, level=logging.INFO) -> logging.Logger: """ Setup a logger for the module. diff --git a/mu/algorithms/semipermeable_membrane/model.py b/mu/algorithms/semipermeable_membrane/model.py index c36e3661..a88f4939 100644 --- a/mu/algorithms/semipermeable_membrane/model.py +++ b/mu/algorithms/semipermeable_membrane/model.py @@ -1,12 +1,13 @@ # semipermeable_membrane/model.py import logging + import torch +from algorithms.semipermeable_membrane.src.models import model_util +from algorithms.semipermeable_membrane.src.models.spm import SPMLayer, SPMNetwork from diffusers import StableDiffusionPipeline from torch import nn -from algorithms.semipermeable_membrane.src.models.spm import SPMNetwork, SPMLayer -from algorithms.semipermeable_membrane.src.models import model_util class SemipermeableMembraneModel(nn.Module): """ diff --git a/mu/algorithms/semipermeable_membrane/scripts/train.py b/mu/algorithms/semipermeable_membrane/scripts/train.py index b3bb93a9..1c32cde7 100644 --- a/mu/algorithms/semipermeable_membrane/scripts/train.py +++ b/mu/algorithms/semipermeable_membrane/scripts/train.py @@ -1,13 +1,14 @@ # semipermeable_membrane/scripts/train.py import argparse -import os -import yaml import logging +import os +import yaml from algorithms.semipermeable_membrane.algorithm import SemipermeableMembraneAlgorithm from algorithms.semipermeable_membrane.logger import setup_logger + def main(): parser = argparse.ArgumentParser( description='Train Semipermeable Membrane Algorithm' diff --git a/mu/algorithms/semipermeable_membrane/src/configs/config.py b/mu/algorithms/semipermeable_membrane/src/configs/config.py index abc980aa..92632375 100644 --- a/mu/algorithms/semipermeable_membrane/src/configs/config.py +++ b/mu/algorithms/semipermeable_membrane/src/configs/config.py @@ -1,9 +1,8 @@ from typing import Literal, Optional +import torch import yaml - from pydantic import BaseModel -import torch PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"] diff --git a/mu/algorithms/semipermeable_membrane/src/configs/generation_config.py b/mu/algorithms/semipermeable_membrane/src/configs/generation_config.py index 50de982d..319cc28a 100644 --- a/mu/algorithms/semipermeable_membrane/src/configs/generation_config.py +++ b/mu/algorithms/semipermeable_membrane/src/configs/generation_config.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel import torch import yaml +from pydantic import BaseModel + class GenerationConfig(BaseModel): prompts: list[str] = [] diff --git a/mu/algorithms/semipermeable_membrane/src/configs/prompt.py b/mu/algorithms/semipermeable_membrane/src/configs/prompt.py index db61aaab..47670e1a 100644 --- a/mu/algorithms/semipermeable_membrane/src/configs/prompt.py +++ b/mu/algorithms/semipermeable_membrane/src/configs/prompt.py @@ -1,16 +1,14 @@ +import random +from pathlib import Path from typing import Literal, Optional, Union -import yaml -from pathlib import Path import pandas as pd -import random - -from pydantic import BaseModel, root_validator -from transformers import CLIPTextModel, CLIPTokenizer import torch - -from src.misc.clip_templates import imagenet_templates +import yaml +from pydantic import BaseModel, root_validator from src.engine.train_util import encode_prompts +from src.misc.clip_templates import imagenet_templates +from transformers import CLIPTextModel, CLIPTokenizer ACTION_TYPES = Literal[ "erase", diff --git a/mu/algorithms/semipermeable_membrane/src/engine/sampling.py b/mu/algorithms/semipermeable_membrane/src/engine/sampling.py index 5f203240..3a143a16 100644 --- a/mu/algorithms/semipermeable_membrane/src/engine/sampling.py +++ b/mu/algorithms/semipermeable_membrane/src/engine/sampling.py @@ -1,6 +1,6 @@ import random -import torch +import torch from src.configs.prompt import PromptEmbedsPair diff --git a/mu/algorithms/semipermeable_membrane/src/engine/train_util.py b/mu/algorithms/semipermeable_membrane/src/engine/train_util.py index f901022f..985dcd5b 100644 --- a/mu/algorithms/semipermeable_membrane/src/engine/train_util.py +++ b/mu/algorithms/semipermeable_membrane/src/engine/train_util.py @@ -1,20 +1,18 @@ # ref: # - https://github.com/p1atdev/LECO/blob/main/train_util.py -from typing import Optional, Union - import ast import importlib +from typing import Optional, Union + import torch -from torch.optim import Optimizer import transformers -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import UNet2DConditionModel, SchedulerMixin, DiffusionPipeline -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION - +from diffusers import DiffusionPipeline, SchedulerMixin, UNet2DConditionModel +from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType from src.models.model_util import SDXL_TEXT_ENCODER_TYPE - +from torch.optim import Optimizer from tqdm import tqdm +from transformers import CLIPTextModel, CLIPTokenizer UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/__init__.py b/mu/algorithms/semipermeable_membrane/src/evaluation/__init__.py index ced28bf0..c414b223 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/__init__.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/__init__.py @@ -1,6 +1,7 @@ -from .eval_util import * -from .evaluator import * -from .clip_evaluator import * from .artwork_evaluator import * +from .clip_evaluator import * + # from .i2p_evaluator import * from .coco_evaluator import * +from .eval_util import * +from .evaluator import * diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/artwork_evaluator.py b/mu/algorithms/semipermeable_membrane/src/evaluation/artwork_evaluator.py index 31d59a29..3dc82a85 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/artwork_evaluator.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/artwork_evaluator.py @@ -4,7 +4,6 @@ import pandas as pd from prettytable import PrettyTable - from src.configs.generation_config import GenerationConfig from .eval_util import clip_score diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/clip_evaluator.py b/mu/algorithms/semipermeable_membrane/src/evaluation/clip_evaluator.py index 46db0b21..010382a3 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/clip_evaluator.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/clip_evaluator.py @@ -2,10 +2,10 @@ import os import random from argparse import ArgumentParser -from prettytable import PrettyTable -from tqdm import tqdm +from prettytable import PrettyTable from src.configs.generation_config import GenerationConfig +from tqdm import tqdm from ..misc.clip_templates import anchor_templates, imagenet_templates from .eval_util import clip_eval_by_image diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/coco_evaluator.py b/mu/algorithms/semipermeable_membrane/src/evaluation/coco_evaluator.py index f170c251..35eed123 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/coco_evaluator.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/coco_evaluator.py @@ -3,9 +3,8 @@ from argparse import ArgumentParser import pandas as pd -from prettytable import PrettyTable from cleanfid import fid - +from prettytable import PrettyTable from src.configs.generation_config import GenerationConfig from .evaluator import Evaluator, GenerationDataset diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/eval_util.py b/mu/algorithms/semipermeable_membrane/src/evaluation/eval_util.py index 719fb736..b21f85c3 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/eval_util.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/eval_util.py @@ -2,19 +2,18 @@ # - https://github.com/jmhessel/clipscore/blob/main/clipscore.py # - https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb -import torch +import random +from typing import List, Union + import clip import numpy as np -from typing import List, Union +import torch +from diffusers.pipelines import DiffusionPipeline from PIL import Image -import random - -from src.engine.train_util import text2img from src.configs.config import RootConfig +from src.engine.train_util import text2img from src.misc.clip_templates import imagenet_templates - from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor -from diffusers.pipelines import DiffusionPipeline def get_clip_preprocess(n_px=224): diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/evaluator.py b/mu/algorithms/semipermeable_membrane/src/evaluation/evaluator.py index 4116c865..a7cac6a9 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/evaluator.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/evaluator.py @@ -4,7 +4,6 @@ from torch.utils.data import IterableDataset - class GenerationDataset(IterableDataset): """ Dataset for generate images. diff --git a/mu/algorithms/semipermeable_membrane/src/evaluation/i2p_evaluator.py b/mu/algorithms/semipermeable_membrane/src/evaluation/i2p_evaluator.py index a880a3c5..86a3266d 100644 --- a/mu/algorithms/semipermeable_membrane/src/evaluation/i2p_evaluator.py +++ b/mu/algorithms/semipermeable_membrane/src/evaluation/i2p_evaluator.py @@ -4,11 +4,10 @@ import pandas as pd import torch -from nudenet import NudeDetector, NudeClassifier +from nudenet import NudeClassifier, NudeDetector from prettytable import PrettyTable -from tqdm import tqdm - from src.configs.generation_config import GenerationConfig +from tqdm import tqdm from .evaluator import Evaluator, GenerationDataset diff --git a/mu/algorithms/semipermeable_membrane/src/misc/sld_pipeline.py b/mu/algorithms/semipermeable_membrane/src/misc/sld_pipeline.py index 33b2949f..2d617b45 100644 --- a/mu/algorithms/semipermeable_membrane/src/misc/sld_pipeline.py +++ b/mu/algorithms/semipermeable_membrane/src/misc/sld_pipeline.py @@ -1,24 +1,18 @@ import inspect from typing import Callable, List, Optional, Union -import torch - import numpy as np import PIL - -from diffusers.utils import is_accelerate_available -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - +import torch from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import ( - DDIMScheduler, - LMSDiscreteScheduler, - PNDMScheduler, +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, ) -from diffusers.utils import BaseOutput, deprecate, logging -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/mu/algorithms/semipermeable_membrane/src/models/merge_spm.py b/mu/algorithms/semipermeable_membrane/src/models/merge_spm.py index b58616e8..a19ef024 100644 --- a/mu/algorithms/semipermeable_membrane/src/models/merge_spm.py +++ b/mu/algorithms/semipermeable_membrane/src/models/merge_spm.py @@ -1,13 +1,14 @@ # modify from: # - https://github.com/bmaltais/kohya_ss/blob/master/networks/merge_lora.py -import math import argparse +import math import os -import torch + import safetensors -from safetensors.torch import load_file +import torch from diffusers import DiffusionPipeline +from safetensors.torch import load_file def load_state_dict(file_name, dtype): diff --git a/mu/algorithms/semipermeable_membrane/src/models/model_util.py b/mu/algorithms/semipermeable_membrane/src/models/model_util.py index 8b187a95..afab36de 100644 --- a/mu/algorithms/semipermeable_membrane/src/models/model_util.py +++ b/mu/algorithms/semipermeable_membrane/src/models/model_util.py @@ -1,21 +1,21 @@ -from typing import Literal, Union, Optional +from typing import Literal, Optional, Union import torch -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection from diffusers import ( - UNet2DConditionModel, + AltDiffusionPipeline, + DiffusionPipeline, SchedulerMixin, StableDiffusionPipeline, StableDiffusionXLPipeline, - AltDiffusionPipeline, - DiffusionPipeline, + UNet2DConditionModel, ) from diffusers.schedulers import ( DDIMScheduler, DDPMScheduler, - LMSDiscreteScheduler, EulerAncestralDiscreteScheduler, + LMSDiscreteScheduler, ) +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" diff --git a/mu/algorithms/semipermeable_membrane/src/models/spm.py b/mu/algorithms/semipermeable_membrane/src/models/spm.py index fe7cf553..425df02f 100644 --- a/mu/algorithms/semipermeable_membrane/src/models/spm.py +++ b/mu/algorithms/semipermeable_membrane/src/models/spm.py @@ -2,9 +2,9 @@ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py -import os import math -from typing import Optional, List +import os +from typing import List, Optional import torch import torch.nn as nn diff --git a/mu/algorithms/semipermeable_membrane/trainer.py b/mu/algorithms/semipermeable_membrane/trainer.py index f24b1e68..c5d093ab 100644 --- a/mu/algorithms/semipermeable_membrane/trainer.py +++ b/mu/algorithms/semipermeable_membrane/trainer.py @@ -1,19 +1,24 @@ # semipermeable_membrane/trainer.py import logging -import torch -from torch.optim import Adam -from torch.nn import MSELoss from typing import List, Optional -from algorithms.semipermeable_membrane.src.engine.sampling import sample import algorithms.semipermeable_membrane.src.engine.train_util as train_util -from algorithms.semipermeable_membrane.src.models import model_util -from algorithms.semipermeable_membrane.src.evaluation import eval_util +import torch from algorithms.semipermeable_membrane.src.configs import config as config_pkg from algorithms.semipermeable_membrane.src.configs import prompt as prompt_pkg from algorithms.semipermeable_membrane.src.configs.config import RootConfig -from algorithms.semipermeable_membrane.src.configs.prompt import PromptEmbedsCache, PromptEmbedsPair, PromptSettings +from algorithms.semipermeable_membrane.src.configs.prompt import ( + PromptEmbedsCache, + PromptEmbedsPair, + PromptSettings, +) +from algorithms.semipermeable_membrane.src.engine.sampling import sample +from algorithms.semipermeable_membrane.src.evaluation import eval_util +from algorithms.semipermeable_membrane.src.models import model_util +from torch.nn import MSELoss +from torch.optim import Adam + class SemipermeableMembraneTrainer: """ diff --git a/mu/algorithms/unified_concept_editing/algorithm.py b/mu/algorithms/unified_concept_editing/algorithm.py index 4b309093..bf1332ec 100644 --- a/mu/algorithms/unified_concept_editing/algorithm.py +++ b/mu/algorithms/unified_concept_editing/algorithm.py @@ -1,15 +1,18 @@ # mu/algorithms/unified_concept_editing/algorithm.py -import torch -import wandb import logging -from typing import Dict from pathlib import Path +from typing import Dict -from mu.core import BaseAlgorithm +import torch +import wandb + +from mu.algorithms.unified_concept_editing.data_handler import ( + UnifiedConceptEditingDataHandler, +) from mu.algorithms.unified_concept_editing.model import UnifiedConceptEditingModel from mu.algorithms.unified_concept_editing.trainer import UnifiedConceptEditingTrainer -from mu.algorithms.unified_concept_editing.data_handler import UnifiedConceptEditingDataHandler +from mu.core import BaseAlgorithm class UnifiedConceptEditingAlgorithm(BaseAlgorithm): diff --git a/mu/algorithms/unified_concept_editing/data_handler.py b/mu/algorithms/unified_concept_editing/data_handler.py index 708fae8d..ae18f491 100644 --- a/mu/algorithms/unified_concept_editing/data_handler.py +++ b/mu/algorithms/unified_concept_editing/data_handler.py @@ -1,11 +1,12 @@ # unified_concept_editing/data_handler.py -from typing import List, Optional, Tuple import logging +from typing import List, Optional, Tuple from mu.core import BaseDataHandler from mu.datasets.constants import * + class UnifiedConceptEditingDataHandler(BaseDataHandler): """ DataHandler for Unified Concept Editing. diff --git a/mu/algorithms/unified_concept_editing/model.py b/mu/algorithms/unified_concept_editing/model.py index 2cbe14b3..8faf91e2 100644 --- a/mu/algorithms/unified_concept_editing/model.py +++ b/mu/algorithms/unified_concept_editing/model.py @@ -1,16 +1,17 @@ # mu/algorithms/unified_concept_editing/model.py -import torch -from typing import Any, List, Optional -import logging -import copy -from tqdm import tqdm import ast -from diffusers import StableDiffusionPipeline +import copy +import logging +from typing import Any, List, Optional +import torch +from diffusers import StableDiffusionPipeline +from tqdm import tqdm from mu.core import BaseModel + class UnifiedConceptEditingModel(BaseModel): """ UnifiedConceptEditingModel handles loading, saving, and interacting with the Stable Diffusion model using diffusers. diff --git a/mu/algorithms/unified_concept_editing/scripts/train.py b/mu/algorithms/unified_concept_editing/scripts/train.py index 2d7a2a8e..3f5cfbcb 100644 --- a/mu/algorithms/unified_concept_editing/scripts/train.py +++ b/mu/algorithms/unified_concept_editing/scripts/train.py @@ -1,13 +1,16 @@ # mu/algorithms/unified_concept_editing/scripts/train.py import argparse -import os import logging +import os -from mu.algorithms.unified_concept_editing.algorithm import UnifiedConceptEditingAlgorithm -from mu.helpers import setup_logger, load_config +from mu.algorithms.unified_concept_editing.algorithm import ( + UnifiedConceptEditingAlgorithm, +) +from mu.helpers import load_config, setup_logger from mu.helpers.path_setup import * + def main(): parser = argparse.ArgumentParser( prog='TrainUnifiedConceptEditing', diff --git a/mu/algorithms/unified_concept_editing/trainer.py b/mu/algorithms/unified_concept_editing/trainer.py index 5142938d..a222251a 100644 --- a/mu/algorithms/unified_concept_editing/trainer.py +++ b/mu/algorithms/unified_concept_editing/trainer.py @@ -3,8 +3,10 @@ import logging from typing import Optional +from mu.algorithms.unified_concept_editing.data_handler import ( + UnifiedConceptEditingDataHandler, +) from mu.algorithms.unified_concept_editing.model import UnifiedConceptEditingModel -from mu.algorithms.unified_concept_editing.data_handler import UnifiedConceptEditingDataHandler from mu.core import BaseTrainer diff --git a/mu/core/base_algorithm.py b/mu/core/base_algorithm.py index b39934da..e7612ebe 100644 --- a/mu/core/base_algorithm.py +++ b/mu/core/base_algorithm.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Dict + class BaseAlgorithm(ABC): """ Abstract base class for the overall unlearning algorithm, combining the model, trainer, and sampler. diff --git a/mu/core/base_data_handler.py b/mu/core/base_data_handler.py index d8dafa1d..5900ea39 100644 --- a/mu/core/base_data_handler.py +++ b/mu/core/base_data_handler.py @@ -1,9 +1,11 @@ # mu/core/base_data_handler.py from abc import ABC, abstractmethod -from typing import Any, Tuple, Dict, Optional, List +from typing import Any, Dict, List, Optional, Tuple + from torch.utils.data import DataLoader, Dataset + class BaseDataHandler(ABC): """ Abstract base class for data handling and processing. diff --git a/mu/core/base_model.py b/mu/core/base_model.py index 282a6feb..f75ebf47 100644 --- a/mu/core/base_model.py +++ b/mu/core/base_model.py @@ -1,8 +1,10 @@ # mu/core/base_model.py from abc import ABC, abstractmethod + import torch.nn as nn + class BaseModel(nn.Module, ABC): """Abstract base class for all unlearning models.""" diff --git a/mu/core/base_sampler.py b/mu/core/base_sampler.py index aef5c267..3e2cb76d 100644 --- a/mu/core/base_sampler.py +++ b/mu/core/base_sampler.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any + class BaseSampler(ABC): """Abstract base class for sampling methods used in unlearning.""" diff --git a/mu/core/base_trainer.py b/mu/core/base_trainer.py index 528ae229..80a01007 100644 --- a/mu/core/base_trainer.py +++ b/mu/core/base_trainer.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any + class BaseTrainer(ABC): """Abstract base class for training unlearning models.""" diff --git a/mu/datasets/__init__.py b/mu/datasets/__init__.py index fea3bc99..9c43dfce 100644 --- a/mu/datasets/__init__.py +++ b/mu/datasets/__init__.py @@ -4,7 +4,6 @@ from .i2p_dataset import I2PDataset from .unlearn_canvas_dataset import UnlearnCanvasDataset - __all__ = [ "BaseDataset", "I2PDataset", diff --git a/mu/datasets/base_dataset.py b/mu/datasets/base_dataset.py index be1bcad2..5b546db8 100644 --- a/mu/datasets/base_dataset.py +++ b/mu/datasets/base_dataset.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod -from torch.utils.data import Dataset from typing import Any, Tuple +from torch.utils.data import Dataset + + class BaseDataset(Dataset, ABC): """ Abstract base class for all datasets. diff --git a/mu/datasets/constants/__init__.py b/mu/datasets/constants/__init__.py index 1c39374c..d68316e8 100644 --- a/mu/datasets/constants/__init__.py +++ b/mu/datasets/constants/__init__.py @@ -1,19 +1,15 @@ # mu/datasets/constants/__init__.py +from .i2p_const import i2p_categories, i2p_sample_categories from .uc_const import ( - uc_style_list, - uc_class_list, - uc_theme_available, uc_class_available, - uc_sample_theme_available, + uc_class_list, uc_sample_class_available, uc_sample_class_list, - uc_sample_style_list -) - -from .i2p_const import ( - i2p_sample_categories, - i2p_categories + uc_sample_style_list, + uc_sample_theme_available, + uc_style_list, + uc_theme_available, ) __all__ = [ diff --git a/mu/datasets/i2p_dataset.py b/mu/datasets/i2p_dataset.py index 761e3696..f917abe4 100644 --- a/mu/datasets/i2p_dataset.py +++ b/mu/datasets/i2p_dataset.py @@ -1,16 +1,18 @@ # mu/datasets/unlearn_canvas_dataset.py -from typing import Any, Tuple -from PIL import Image import os -import torch +from typing import Any, Tuple + import numpy as np +import torch from einops import rearrange +from PIL import Image from mu.datasets import BaseDataset -from mu.datasets.constants import * +from mu.datasets.constants import * from mu.helpers import read_text_lines + class I2PDataset(BaseDataset): """ I2P Dataset. diff --git a/mu/datasets/unlearn_canvas_dataset.py b/mu/datasets/unlearn_canvas_dataset.py index 26eefc47..8e8fa51b 100644 --- a/mu/datasets/unlearn_canvas_dataset.py +++ b/mu/datasets/unlearn_canvas_dataset.py @@ -1,17 +1,19 @@ #datasets/unlearn_canvas_dataset.py -from typing import Any, Tuple -from PIL import Image import os -import torch +from typing import Any, Tuple + import numpy as np +import torch from einops import rearrange -from mu.datasets import BaseDataset +from PIL import Image from torchvision import transforms -from mu.datasets.constants import * +from mu.datasets import BaseDataset +from mu.datasets.constants import * from mu.helpers import read_text_lines + class UnlearnCanvasDataset(BaseDataset): """ Dataset for UnlearnCanvas algorithm. diff --git a/mu/datasets/utils.py b/mu/datasets/utils.py index f502c5ca..e869945e 100644 --- a/mu/datasets/utils.py +++ b/mu/datasets/utils.py @@ -2,7 +2,6 @@ from torchvision.transforms import InterpolationMode from torchvision.transforms import functional as F - INTERPOLATIONS = { 'bilinear': InterpolationMode.BILINEAR, 'bicubic': InterpolationMode.BICUBIC, diff --git a/mu/helpers/__init__.py b/mu/helpers/__init__.py index b7fa004f..4f621372 100644 --- a/mu/helpers/__init__.py +++ b/mu/helpers/__init__.py @@ -1,3 +1,3 @@ from .config_loader import load_config from .logger import setup_logger -from .utils import read_text_lines, load_model_from_config, sample_model \ No newline at end of file +from .utils import load_model_from_config, read_text_lines, sample_model diff --git a/mu/helpers/config_loader.py b/mu/helpers/config_loader.py index 55cce519..bd1dd060 100644 --- a/mu/helpers/config_loader.py +++ b/mu/helpers/config_loader.py @@ -1,5 +1,6 @@ +import os + import yaml -import os def load_config(yaml_path): diff --git a/mu/helpers/logger.py b/mu/helpers/logger.py index 09f4a6a6..d1ad2cb8 100644 --- a/mu/helpers/logger.py +++ b/mu/helpers/logger.py @@ -1,6 +1,7 @@ import logging from pathlib import Path + def setup_logger(log_file: str = "erase_diff_training.log", level: int = logging.INFO) -> logging.Logger: """ Setup a logger for the training process. diff --git a/mu/helpers/path_setup.py b/mu/helpers/path_setup.py index 33088019..872ca756 100644 --- a/mu/helpers/path_setup.py +++ b/mu/helpers/path_setup.py @@ -1,5 +1,4 @@ -import os - +import os file_path = __file__ diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 6cb90ed7..70b79763 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -1,11 +1,11 @@ -from typing import List,Any +from pathlib import Path +from typing import Any, List -from omegaconf import OmegaConf import torch -from pathlib import Path +from omegaconf import OmegaConf -from stable_diffusion.ldm.util import instantiate_from_config from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler +from stable_diffusion.ldm.util import instantiate_from_config def read_text_lines(path: str) -> List[str]: diff --git a/scripts/generate_images_for_prompts.py b/scripts/generate_images_for_prompts.py index 070c1b78..ff438750 100644 --- a/scripts/generate_images_for_prompts.py +++ b/scripts/generate_images_for_prompts.py @@ -1,8 +1,9 @@ import os + import pandas as pd -from torch import autocast from diffusers import StableDiffusionPipeline from PIL import Image +from torch import autocast # Load the CSV file csv_path = "/home/ubuntu/Projects/msu_unlearningalgorithm/data/i2p-dataset/sample/i2p.csv" diff --git a/stable_diffusion/ldm/data/base.py b/stable_diffusion/ldm/data/base.py index fad77670..0f3a9cbe 100644 --- a/stable_diffusion/ldm/data/base.py +++ b/stable_diffusion/ldm/data/base.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + +from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset class Txt2ImgIterableBaseDataset(IterableDataset): @@ -23,6 +24,7 @@ def __iter__(self): pass import os + import numpy as np diff --git a/stable_diffusion/ldm/data/imagenet.py b/stable_diffusion/ldm/data/imagenet.py index 1c473f9c..3f7c6cc3 100644 --- a/stable_diffusion/ldm/data/imagenet.py +++ b/stable_diffusion/ldm/data/imagenet.py @@ -1,20 +1,29 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ( + ImagePaths, + download, + give_synsets_from_indices, + retrieve, + str_to_indices, +) from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): diff --git a/stable_diffusion/ldm/data/lsun.py b/stable_diffusion/ldm/data/lsun.py index 6256e457..7399c4ad 100644 --- a/stable_diffusion/ldm/data/lsun.py +++ b/stable_diffusion/ldm/data/lsun.py @@ -1,4 +1,5 @@ import os + import numpy as np import PIL from PIL import Image diff --git a/stable_diffusion/ldm/extras.py b/stable_diffusion/ldm/extras.py index c2e7f850..4330bbbf 100644 --- a/stable_diffusion/ldm/extras.py +++ b/stable_diffusion/ldm/extras.py @@ -1,14 +1,15 @@ +import sys from pathlib import Path -from omegaconf import OmegaConf + import torch -import sys +from omegaconf import OmegaConf + sys.path.append() -from ldm.util import instantiate_from_config import logging from contextlib import contextmanager -from contextlib import contextmanager -import logging +from ldm.util import instantiate_from_config + @contextmanager def all_logging_disabled(highest_level=logging.CRITICAL): diff --git a/stable_diffusion/ldm/guaidance.py b/stable_diffusion/ldm/guaidance.py index 512ad4b3..e3794e09 100644 --- a/stable_diffusion/ldm/guaidance.py +++ b/stable_diffusion/ldm/guaidance.py @@ -1,11 +1,12 @@ +import abc from typing import List, Tuple -from scipy import interpolate + +import matplotlib.pyplot as plt import numpy as np import torch -import matplotlib.pyplot as plt from IPython.display import clear_output -import abc +from scipy import interpolate class GuideModel(torch.nn.Module, abc.ABC): diff --git a/stable_diffusion/ldm/modules/attention.py b/stable_diffusion/ldm/modules/attention.py index 8c671b42..4bec8cfe 100644 --- a/stable_diffusion/ldm/modules/attention.py +++ b/stable_diffusion/ldm/modules/attention.py @@ -1,13 +1,15 @@ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). # See more details in LICENSE. -from inspect import isfunction import math +import sys +from inspect import isfunction + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -import sys +from torch import einsum, nn + sys.path.append('.') from stable_diffusion.ldm.modules.diffusionmodules.util import checkpoint diff --git a/stable_diffusion/ldm/modules/diffusionmodules/model.py b/stable_diffusion/ldm/modules/diffusionmodules/model.py index 224440d6..b4ad62fc 100644 --- a/stable_diffusion/ldm/modules/diffusionmodules/model.py +++ b/stable_diffusion/ldm/modules/diffusionmodules/model.py @@ -1,16 +1,19 @@ # pytorch_diffusion + derived encoder decoder import math +import sys + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import rearrange -import sys sys.path.append('.') -from stable_diffusion.ldm.util import instantiate_from_config from stable_diffusion.ldm.modules.attention import LinearAttention -from stable_diffusion.ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from stable_diffusion.ldm.modules.distributions.distributions import ( + DiagonalGaussianDistribution, +) +from stable_diffusion.ldm.util import instantiate_from_config def get_timestep_embedding(timesteps, embedding_dim): diff --git a/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py b/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py index e6f021cb..8f3f8071 100644 --- a/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,26 +1,27 @@ -from abc import abstractmethod import math +import sys +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F -import sys sys.path.append('.') +from stable_diffusion.ldm.modules.attention import SpatialTransformer +from stable_diffusion.ldm.modules.diffusionmodules.util import conv_nd # nn.Conv2d +from stable_diffusion.ldm.modules.diffusionmodules.util import linear # nn.Linear from stable_diffusion.ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, # nn.Conv2d - linear, # nn.Linear avg_pool_nd, - zero_module, + checkpoint, normalization, timestep_embedding, + zero_module, ) -from stable_diffusion.ldm.modules.attention import SpatialTransformer from stable_diffusion.ldm.util import exists + # dummy replace def convert_module_to_f16(x): pass diff --git a/stable_diffusion/ldm/modules/diffusionmodules/util.py b/stable_diffusion/ldm/modules/diffusionmodules/util.py index ac1d78d3..ec483c3f 100644 --- a/stable_diffusion/ldm/modules/diffusionmodules/util.py +++ b/stable_diffusion/ldm/modules/diffusionmodules/util.py @@ -8,14 +8,15 @@ # thanks! -import os import math +import os +import sys + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import repeat -import sys sys.path.append('.') from stable_diffusion.ldm.util import instantiate_from_config diff --git a/stable_diffusion/ldm/modules/distributions/distributions.py b/stable_diffusion/ldm/modules/distributions/distributions.py index 919fb852..1d6ef7f4 100644 --- a/stable_diffusion/ldm/modules/distributions/distributions.py +++ b/stable_diffusion/ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: diff --git a/stable_diffusion/ldm/modules/encoders/modules.py b/stable_diffusion/ldm/modules/encoders/modules.py index 0b825ee3..b7fa318d 100644 --- a/stable_diffusion/ldm/modules/encoders/modules.py +++ b/stable_diffusion/ldm/modules/encoders/modules.py @@ -1,19 +1,29 @@ -import torch -import torch.nn as nn -import numpy as np -from functools import partial -import kornia import sys +from functools import partial + import clip +import kornia +import numpy as np +import torch +import torch.nn as nn + sys.path.append(".") -from stable_diffusion.ldm.util import instantiate_from_config -from stable_diffusion.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from stable_diffusion.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test -from stable_diffusion.ldm.util import default -from stable_diffusion.ldm.thirdp.psp.id_loss import IDFeatures import kornia.augmentation as K +from stable_diffusion.ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, + noise_like, +) +from stable_diffusion.ldm.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + Encoder, + TransformerWrapper, +) +from stable_diffusion.ldm.thirdp.psp.id_loss import IDFeatures +from stable_diffusion.ldm.util import default, instantiate_from_config + + class AbstractEncoder(nn.Module): def __init__(self): super().__init__() @@ -112,7 +122,8 @@ def encode(self, text): return self(text) -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode @@ -217,6 +228,8 @@ def encode(self, text): import torch.nn.functional as F from transformers import CLIPVisionModel + + class ClipImageProjector(AbstractEncoder): """ Uses the CLIP image encoder. diff --git a/stable_diffusion/ldm/modules/evaluate/adm_evaluator.py b/stable_diffusion/ldm/modules/evaluate/adm_evaluator.py index 508cddf2..d69c6280 100644 --- a/stable_diffusion/ldm/modules/evaluate/adm_evaluator.py +++ b/stable_diffusion/ldm/modules/evaluate/adm_evaluator.py @@ -10,11 +10,11 @@ from multiprocessing import cpu_count from multiprocessing.pool import ThreadPool from typing import Iterable, Optional, Tuple -import yaml import numpy as np import requests import tensorflow.compat.v1 as tf +import yaml from scipy import linalg from tqdm.auto import tqdm diff --git a/stable_diffusion/ldm/modules/evaluate/evaluate_perceptualsim.py b/stable_diffusion/ldm/modules/evaluate/evaluate_perceptualsim.py index 8d5db33b..fe9c6202 100644 --- a/stable_diffusion/ldm/modules/evaluate/evaluate_perceptualsim.py +++ b/stable_diffusion/ldm/modules/evaluate/evaluate_perceptualsim.py @@ -1,20 +1,19 @@ import argparse import glob import os -from tqdm import tqdm +import sys from collections import namedtuple import numpy as np import torch import torchvision.transforms as transforms -from torchvision import models from PIL import Image +from torchvision import models +from tqdm import tqdm -import sys sys.path.append(".") from stable_diffusion.ldm.modules.evaluate.ssim import ssim - transform = transforms.Compose([transforms.ToTensor()]) def normalize_tensor(in_feat, eps=1e-10): diff --git a/stable_diffusion/ldm/modules/evaluate/frechet_video_distance.py b/stable_diffusion/ldm/modules/evaluate/frechet_video_distance.py index d9e13c41..3e06bb92 100644 --- a/stable_diffusion/ldm/modules/evaluate/frechet_video_distance.py +++ b/stable_diffusion/ldm/modules/evaluate/frechet_video_distance.py @@ -21,10 +21,7 @@ embedding to be better suitable for videos. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - +from __future__ import absolute_import, division, print_function import six import tensorflow.compat.v1 as tf diff --git a/stable_diffusion/ldm/modules/evaluate/ssim.py b/stable_diffusion/ldm/modules/evaluate/ssim.py index 4e8883cc..e04dbda8 100644 --- a/stable_diffusion/ldm/modules/evaluate/ssim.py +++ b/stable_diffusion/ldm/modules/evaluate/ssim.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch.autograd import Variable + def gaussian(window_size, sigma): gauss = torch.Tensor( [ diff --git a/stable_diffusion/ldm/modules/evaluate/torch_frechet_video_distance.py b/stable_diffusion/ldm/modules/evaluate/torch_frechet_video_distance.py index 04856b82..da216c10 100644 --- a/stable_diffusion/ldm/modules/evaluate/torch_frechet_video_distance.py +++ b/stable_diffusion/ldm/modules/evaluate/torch_frechet_video_distance.py @@ -1,26 +1,25 @@ # based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! -import os -import numpy as np +import glob +import hashlib +import html import io +import multiprocessing as mp +import os import re -import requests -import html -import hashlib import urllib import urllib.request -import scipy.linalg -import multiprocessing as mp -import glob - +from typing import Any, Callable, Dict, List, Tuple, Union +import numpy as np +import requests +import scipy.linalg +from einops import rearrange +from nitro.util import isvideo +from torchvision.io import read_video from tqdm import tqdm -from typing import Any, List, Tuple, Union, Dict, Callable -from torchvision.io import read_video import torch; torch.set_grad_enabled(False) -from einops import rearrange -from nitro.util import isvideo def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: print('Calculate frechet distance...') diff --git a/stable_diffusion/ldm/modules/image_degradation/__init__.py b/stable_diffusion/ldm/modules/image_degradation/__init__.py index 7836cada..c6b3b62e 100644 --- a/stable_diffusion/ldm/modules/image_degradation/__init__.py +++ b/stable_diffusion/ldm/modules/image_degradation/__init__.py @@ -1,2 +1,6 @@ -from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr -from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light +from ldm.modules.image_degradation.bsrgan import ( + degradation_bsrgan_variant as degradation_fn_bsr, +) +from ldm.modules.image_degradation.bsrgan_light import ( + degradation_bsrgan_variant as degradation_fn_bsr_light, +) diff --git a/stable_diffusion/ldm/modules/image_degradation/bsrgan.py b/stable_diffusion/ldm/modules/image_degradation/bsrgan.py index bca65a34..f7f1658e 100644 --- a/stable_diffusion/ldm/modules/image_degradation/bsrgan.py +++ b/stable_diffusion/ldm/modules/image_degradation/bsrgan.py @@ -10,20 +10,20 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +import sys +from functools import partial + +import albumentations +import cv2 +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations -import sys sys.path.append('.') import stable_diffusion.ldm.modules.image_degradation.utils_image as util diff --git a/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py b/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py index 368e392e..6a5272f9 100644 --- a/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +import sys +from functools import partial + +import albumentations +import cv2 +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations -import sys sys.path.append('.') import stable_diffusion.ldm.modules.image_degradation.utils_image as util diff --git a/stable_diffusion/ldm/modules/image_degradation/utils_image.py b/stable_diffusion/ldm/modules/image_degradation/utils_image.py index 0175f155..5d72b073 100644 --- a/stable_diffusion/ldm/modules/image_degradation/utils_image.py +++ b/stable_diffusion/ldm/modules/image_degradation/utils_image.py @@ -1,11 +1,13 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime + #import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py diff --git a/stable_diffusion/ldm/modules/losses/contperceptual.py b/stable_diffusion/ldm/modules/losses/contperceptual.py index 672c1e32..0c361a3f 100644 --- a/stable_diffusion/ldm/modules/losses/contperceptual.py +++ b/stable_diffusion/ldm/modules/losses/contperceptual.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn - from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? diff --git a/stable_diffusion/ldm/modules/losses/vqperceptual.py b/stable_diffusion/ldm/modules/losses/vqperceptual.py index f6998176..9f959a13 100644 --- a/stable_diffusion/ldm/modules/losses/vqperceptual.py +++ b/stable_diffusion/ldm/modules/losses/vqperceptual.py @@ -1,11 +1,10 @@ import torch -from torch import nn import torch.nn.functional as F from einops import repeat - from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss +from torch import nn def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): diff --git a/stable_diffusion/ldm/modules/x_transformer.py b/stable_diffusion/ldm/modules/x_transformer.py index 5fc15bf9..a3c9591b 100644 --- a/stable_diffusion/ldm/modules/x_transformer.py +++ b/stable_diffusion/ldm/modules/x_transformer.py @@ -1,11 +1,12 @@ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F +from collections import namedtuple from functools import partial from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import einsum, nn # constants diff --git a/stable_diffusion/ldm/thirdp/psp/helpers.py b/stable_diffusion/ldm/thirdp/psp/helpers.py index 983baaa5..5ed7e18c 100644 --- a/stable_diffusion/ldm/thirdp/psp/helpers.py +++ b/stable_diffusion/ldm/thirdp/psp/helpers.py @@ -1,8 +1,19 @@ # https://github.com/eladrich/pixel2style2pixel from collections import namedtuple + import torch -from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module +from torch.nn import ( + AdaptiveAvgPool2d, + BatchNorm2d, + Conv2d, + MaxPool2d, + Module, + PReLU, + ReLU, + Sequential, + Sigmoid, +) """ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) diff --git a/stable_diffusion/ldm/thirdp/psp/id_loss.py b/stable_diffusion/ldm/thirdp/psp/id_loss.py index bfbc4832..44b8cf18 100644 --- a/stable_diffusion/ldm/thirdp/psp/id_loss.py +++ b/stable_diffusion/ldm/thirdp/psp/id_loss.py @@ -1,7 +1,9 @@ # https://github.com/eladrich/pixel2style2pixel +import sys + import torch from torch import nn -import sys + sys.path.append('.') from stable_diffusion.ldm.thirdp.psp.model_irse import Backbone diff --git a/stable_diffusion/ldm/thirdp/psp/model_irse.py b/stable_diffusion/ldm/thirdp/psp/model_irse.py index aeab98ae..134e8158 100644 --- a/stable_diffusion/ldm/thirdp/psp/model_irse.py +++ b/stable_diffusion/ldm/thirdp/psp/model_irse.py @@ -1,8 +1,25 @@ # https://github.com/eladrich/pixel2style2pixel import sys + sys.path.append(".") -from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module -from stable_diffusion.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm +from torch.nn import ( + BatchNorm1d, + BatchNorm2d, + Conv2d, + Dropout, + Linear, + Module, + PReLU, + Sequential, +) + +from stable_diffusion.ldm.thirdp.psp.helpers import ( + Flatten, + bottleneck_IR, + bottleneck_IR_SE, + get_blocks, + l2_norm, +) """ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) diff --git a/stable_diffusion/ldm/util.py b/stable_diffusion/ldm/util.py index 8c09ca1c..fb6947ba 100644 --- a/stable_diffusion/ldm/util.py +++ b/stable_diffusion/ldm/util.py @@ -1,11 +1,10 @@ import importlib +from inspect import isfunction -import torch -from torch import optim import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont +from torch import optim def log_txt_as_img(wh, xc, size=10): diff --git a/stable_diffusion/scripts/img2img.py b/stable_diffusion/scripts/img2img.py index 421e2151..0cdbd2a8 100644 --- a/stable_diffusion/scripts/img2img.py +++ b/stable_diffusion/scripts/img2img.py @@ -1,23 +1,26 @@ """make variations of input image""" -import argparse, os, sys, glob +import argparse +import glob +import os +import sys +import time +from contextlib import nullcontext +from itertools import islice + +import numpy as np import PIL import torch -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice from einops import rearrange, repeat -from torchvision.utils import make_grid -from torch import autocast -from contextlib import nullcontext -import time -from pytorch_lightning import seed_everything - -from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning import seed_everything +from torch import autocast +from torchvision.utils import make_grid +from tqdm import tqdm, trange def chunk(it, size): diff --git a/stable_diffusion/scripts/inpaint.py b/stable_diffusion/scripts/inpaint.py index d6e6387a..a9dd9763 100644 --- a/stable_diffusion/scripts/inpaint.py +++ b/stable_diffusion/scripts/inpaint.py @@ -1,11 +1,15 @@ -import argparse, os, sys, glob -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm +import argparse +import glob +import os +import sys + import numpy as np import torch -from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from main import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm def make_batch(image, mask, device): diff --git a/stable_diffusion/scripts/knn2img.py b/stable_diffusion/scripts/knn2img.py index e6eaaeca..0b638192 100644 --- a/stable_diffusion/scripts/knn2img.py +++ b/stable_diffusion/scripts/knn2img.py @@ -1,22 +1,25 @@ -import argparse, os, sys, glob +import argparse +import glob +import os +import sys +import time +from itertools import islice +from multiprocessing import cpu_count + import clip +import numpy as np +import scann import torch import torch.nn as nn -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice from einops import rearrange, repeat -from torchvision.utils import make_grid -import scann -import time -from multiprocessing import cpu_count - -from ldm.util import instantiate_from_config, parallel_data_prefetch from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder +from ldm.util import instantiate_from_config, parallel_data_prefetch +from omegaconf import OmegaConf +from PIL import Image +from torchvision.utils import make_grid +from tqdm import tqdm, trange DATABASES = [ "openimages", diff --git a/stable_diffusion/scripts/sample_diffusion.py b/stable_diffusion/scripts/sample_diffusion.py index 876fe3c3..be27d3a7 100644 --- a/stable_diffusion/scripts/sample_diffusion.py +++ b/stable_diffusion/scripts/sample_diffusion.py @@ -1,14 +1,18 @@ -import argparse, os, sys, glob, datetime, yaml -import torch +import argparse +import datetime +import glob +import os +import sys import time -import numpy as np -from tqdm import trange - -from omegaconf import OmegaConf -from PIL import Image +import numpy as np +import torch +import yaml from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import trange rescale = lambda x: (x + 1.) / 2. diff --git a/stable_diffusion/scripts/train_searcher.py b/stable_diffusion/scripts/train_searcher.py index 1e790488..77df5354 100644 --- a/stable_diffusion/scripts/train_searcher.py +++ b/stable_diffusion/scripts/train_searcher.py @@ -1,12 +1,13 @@ -import os, sys -import numpy as np -import scann import argparse import glob +import os +import sys from multiprocessing import cpu_count -from tqdm import tqdm +import numpy as np +import scann from ldm.util import parallel_data_prefetch +from tqdm import tqdm def search_bruteforce(searcher): diff --git a/stable_diffusion/scripts/txt2img.py b/stable_diffusion/scripts/txt2img.py index bc386404..7616601d 100644 --- a/stable_diffusion/scripts/txt2img.py +++ b/stable_diffusion/scripts/txt2img.py @@ -1,28 +1,31 @@ -import argparse, os, sys, glob +import argparse +import glob +import os +import sys +import time +from contextlib import contextmanager, nullcontext +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from einops import rearrange +from imwatermark import WatermarkEncoder +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from imwatermark import WatermarkEncoder -from itertools import islice -from einops import rearrange -from torchvision.utils import make_grid -import time from pytorch_lightning import seed_everything from torch import autocast -from contextlib import contextmanager, nullcontext - -from ldm.util import instantiate_from_config -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler - -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from torchvision.utils import make_grid +from tqdm import tqdm, trange from transformers import AutoFeatureExtractor - # load safety model safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) diff --git a/stable_diffusion/scripts/txt2img_make_n_samples.py b/stable_diffusion/scripts/txt2img_make_n_samples.py index f60fa443..4ab5d5c7 100644 --- a/stable_diffusion/scripts/txt2img_make_n_samples.py +++ b/stable_diffusion/scripts/txt2img_make_n_samples.py @@ -1,30 +1,33 @@ -import argparse, os, sys, glob +import argparse +import glob +import os +import sys +import time +from contextlib import contextmanager, nullcontext +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from einops import rearrange from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange -from torchvision.utils import make_grid -import time from pytorch_lightning import seed_everything from torch import autocast -from contextlib import contextmanager, nullcontext +from torchvision.utils import make_grid +from tqdm import tqdm, trange -import sys sys.path.append(".") -from ldm.util import instantiate_from_config +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler - -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config from transformers import AutoFeatureExtractor - # load safety model safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)