Skip to content

fix sorting of imports #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ outputs/*
models
!models/.gitkeep

!mu/algorithms/erase_diff/.gitignore
!mu/algorithms/erase_diff/.gitignore
.venv
2 changes: 2 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[settings]
known_third_party = IPython,PIL,accelerate,albumentations,algorithms,cleanfid,clip,core,cv2,datasets,diffusers,einops,fire,huggingface_hub,imwatermark,kornia,ldm,main,matplotlib,nitro,nudenet,numpy,omegaconf,pandas,prettytable,pydantic,pytorch_lightning,quadprog,requests,safetensors,scann,scipy,six,src,taming,tensorflow,tensorflow_gan,tensorflow_hub,timm,torch,torchvision,tqdm,transformers,wandb,yaml
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
default_language_version:
python: python3.12.7
repos:

- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config

- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
6 changes: 3 additions & 3 deletions lora_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .lora import *
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora import *
from .lora_manager import *
from .preprocess_files import *
from .utils import *
12 changes: 4 additions & 8 deletions lora_diffusion/cli_lora_add.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from typing import Literal, Union, Dict
import os
import shutil
from typing import Dict, Literal, Union

import fire
import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

import torch
from .lora import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)
from .lora import collapse_lora, monkeypatch_remove_lora, patch_pipe, tune_lora_scale
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt

Expand Down
14 changes: 7 additions & 7 deletions lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import random
import re
from pathlib import Path
from typing import Optional, List, Literal
from typing import List, Literal, Optional

import fire
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
import wandb
from diffusers import (
AutoencoderKL,
DDPMScheduler,
Expand All @@ -29,20 +31,18 @@
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import wandb
import fire

from lora_diffusion import (
UNET_EXTENDED_TARGET_REPLACE,
PivotalTuningDatasetCapation,
evaluate_pipe,
extract_lora_ups_down,
inject_trainable_lora,
inject_trainable_lora_extended,
inspect_lora,
save_lora_weight,
save_all,
prepare_clip_model_sets,
evaluate_pipe,
UNET_EXTENDED_TARGET_REPLACE,
save_all,
save_lora_weight,
)


Expand Down
1 change: 1 addition & 0 deletions lora_diffusion/cli_pt_to_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import fire
import torch

from lora_diffusion import (
DEFAULT_TARGET_REPLACE,
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
Expand Down
6 changes: 3 additions & 3 deletions lora_diffusion/cli_svd.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import fire
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline

from .lora import (
save_all,
_find_modules,
LoraInjectedConv2d,
LoraInjectedLinear,
_find_modules,
inject_trainable_lora,
inject_trainable_lora_extended,
save_all,
)


Expand Down
3 changes: 2 additions & 1 deletion lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
Expand All @@ -6,7 +7,7 @@
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms
import glob

from .preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
Expand Down
8 changes: 5 additions & 3 deletions lora_diffusion/lora_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List

import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
from safetensors import safe_open

from .lora import (
monkeypatch_or_replace_safeloras,
apply_learned_embed_in_clip,
set_lora_diag,
monkeypatch_or_replace_safeloras,
parse_safeloras_embeds,
set_lora_diag,
)


Expand Down
3 changes: 2 additions & 1 deletion lora_diffusion/patch_lora.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any

try:
from safetensors.torch import safe_open
Expand Down
15 changes: 8 additions & 7 deletions lora_diffusion/preprocess_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
# Have BLIP auto caption
# Have CLIPSeg auto mask concept

from typing import List, Literal, Union, Optional, Tuple
import glob
import os
from PIL import Image, ImageFilter
import torch
import numpy as np
from typing import List, Literal, Optional, Tuple, Union

import fire
import numpy as np
import torch
from PIL import Image, ImageFilter
from tqdm import tqdm
import glob
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor


@torch.no_grad()
Expand Down Expand Up @@ -133,7 +134,7 @@ def blip_captioning_dataset(
Returns a list of captions for the given images
"""

from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import BlipForConditionalGeneration, BlipProcessor

processor = BlipProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
Expand Down
1 change: 0 additions & 1 deletion lora_diffusion/to_ckpt_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import torch


# =================#
# UNet Conversion #
# =================#
Expand Down
10 changes: 5 additions & 5 deletions lora_diffusion/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import glob
import math
import os
from typing import List, Union

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import (
CLIPProcessor,
Expand All @@ -9,11 +13,7 @@
CLIPVisionModelWithProjection,
)

from diffusers import StableDiffusionPipeline
from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
import os
import glob
import math
from .lora import _text_lora_path, _ti_lora_path, patch_pipe, tune_lora_scale

EXAMPLE_PROMPTS = [
"<obj> swimming in a pool",
Expand Down
8 changes: 4 additions & 4 deletions mu/algorithms/concept_ablation/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import torch
import wandb
from typing import Dict

from core.base_algorithm import BaseAlgorithm
import torch
import wandb
from algorithms.concept_ablation.data_handler import ConceptAblationDataHandler
from algorithms.concept_ablation.model import ConceptAblationModel
from algorithms.concept_ablation.trainer import ConceptAblationTrainer
from algorithms.concept_ablation.data_handler import ConceptAblationDataHandler
from core.base_algorithm import BaseAlgorithm


class ConceptAblationAlgorithm(BaseAlgorithm):
Expand Down
8 changes: 5 additions & 3 deletions mu/algorithms/concept_ablation/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from typing import List, Tuple, Callable, Any
from pathlib import Path
from PIL import Image
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset
from PIL import Image
from src import utils
from src.utils import safe_dir
from torch.utils.data import Dataset

# Import your filtering logic if available
# from src.filter import filter as filter_fn

Expand Down
11 changes: 7 additions & 4 deletions mu/algorithms/concept_ablation/handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import logging
import os
from typing import Dict
from torch.utils.data import DataLoader
from core.base_data_handler import BaseDataHandler
from mu.datasets.utils import get_transform, INTERPOLATIONS

from algorithms.concept_ablation.datasets.dataset import ConceptAblationDataset
from core.base_data_handler import BaseDataHandler
from torch.utils.data import DataLoader

from mu.datasets.utils import INTERPOLATIONS, get_transform


class ConceptAblationDataHandler(BaseDataHandler):
"""
Expand Down
10 changes: 6 additions & 4 deletions mu/algorithms/concept_ablation/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from core.base_model import BaseModel
from stable_diffusion.ldm.util import instantiate_from_config
from omegaconf import OmegaConf
import torch
from pathlib import Path
from typing import Any

import torch
from core.base_model import BaseModel
from omegaconf import OmegaConf

from stable_diffusion.ldm.util import instantiate_from_config


class ConceptAblationModel(BaseModel):
"""
Expand Down
7 changes: 5 additions & 2 deletions mu/algorithms/concept_ablation/scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import argparse
import os
import logging
import os
import sys

# Adjust these imports according to your project structure
from algorithms.concept_ablation.algorithm import ConceptAblationAlgorithm
from algorithms.erase_diff.logger import setup_logger # or create a similar logger for concept_ablation if needed
from algorithms.erase_diff.logger import (
setup_logger, # or create a similar logger for concept_ablation if needed
)


def main():
parser = argparse.ArgumentParser(
Expand Down
8 changes: 4 additions & 4 deletions mu/algorithms/concept_ablation/trainer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from typing import Dict

import torch
import wandb
from core.base_trainer import BaseTrainer
from torch.nn import MSELoss
from torch.optim import Adam
from tqdm import tqdm
import logging

from core.base_trainer import BaseTrainer
from typing import Dict


class ConceptAblationTrainer(BaseTrainer):
Expand Down
12 changes: 7 additions & 5 deletions mu/algorithms/erase_diff/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# mu/algorithms/erase_diff/algorithm.py

import torch
import wandb
from typing import Dict
import logging
from pathlib import Path
from typing import Dict

from mu.core import BaseAlgorithm
import torch
import wandb

from mu.algorithms.erase_diff.data_handler import EraseDiffDataHandler
from mu.algorithms.erase_diff.model import EraseDiffModel
from mu.algorithms.erase_diff.trainer import EraseDiffTrainer
from mu.algorithms.erase_diff.data_handler import EraseDiffDataHandler
from mu.core import BaseAlgorithm


class EraseDiffAlgorithm(BaseAlgorithm):
"""
Expand Down
8 changes: 4 additions & 4 deletions mu/algorithms/erase_diff/data_handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
import os
import pandas as pd
from typing import Any, Dict
from torch.utils.data import DataLoader
import logging

import pandas as pd
from torch.utils.data import DataLoader

from mu.algorithms.erase_diff.datasets.erase_diff_dataset import EraseDiffDataset
from mu.datasets.constants import *
from mu.core import BaseDataHandler
from mu.datasets.constants import *
from mu.helpers import read_text_lines


Expand Down
6 changes: 4 additions & 2 deletions mu/algorithms/erase_diff/datasets/erase_diff_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from typing import Any, Tuple, Dict
from typing import Any, Dict, Tuple

from torch.utils.data import DataLoader

from mu.datasets import UnlearnCanvasDataset, I2PDataset, BaseDataset
from mu.datasets import BaseDataset, I2PDataset, UnlearnCanvasDataset
from mu.datasets.utils import INTERPOLATIONS, get_transform


class EraseDiffDataset(BaseDataset):
"""
Dataset class for the EraseDiff algorithm.
Expand Down
Loading