Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to transforms v2 API #84

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion diffusion/datasets/coco/coco_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

from streaming.base import StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer

from diffusion.datasets.laion.transforms import LargestCenterSquare

try:
    # Prefer the transforms v2 API (torchvision >= 0.15). v2 exposes its
    # transform classes directly on the `v2` package — there is no `transforms`
    # submodule inside it, so the previous form
    # `from torchvision.transforms.v2 import transforms` always raised
    # ImportError and silently fell back to the v1 API.
    import torchvision.transforms.v2 as transforms
except ImportError:
    # Older torchvision: fall back to the classic v1 API.
    from torchvision import transforms


class StreamingCOCOCaption(StreamingDataset):
"""Streaming COCO dataset.
Expand Down
7 changes: 6 additions & 1 deletion diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from PIL import Image
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms

try:
    # Prefer the transforms v2 API (torchvision >= 0.15). v2 has no
    # `transforms` submodule, so `from torchvision.transforms.v2 import
    # transforms` always raised ImportError and silently used v1 instead;
    # import the v2 package itself under the same alias.
    import torchvision.transforms.v2 as transforms
except ImportError:
    # Older torchvision: fall back to the classic v1 API.
    from torchvision import transforms

from transformers import AutoTokenizer

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform
Expand Down
6 changes: 5 additions & 1 deletion diffusion/datasets/laion/laion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
from PIL import Image
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer

from diffusion.datasets.laion.transforms import LargestCenterSquare

try:
    # Prefer the transforms v2 API (torchvision >= 0.15). v2 has no
    # `transforms` submodule, so `from torchvision.transforms.v2 import
    # transforms` always raised ImportError and silently used v1 instead;
    # import the v2 package itself under the same alias.
    import torchvision.transforms.v2 as transforms
except ImportError:
    # Older torchvision: fall back to the classic v1 API.
    from torchvision import transforms

# Disable PIL max image size limit
Image.MAX_IMAGE_PIXELS = None

Expand Down
10 changes: 7 additions & 3 deletions diffusion/datasets/laion/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

"""Transforms for the training and eval dataset."""

import torchvision.transforms as transforms
from torchvision.transforms import RandomCrop
from torchvision.transforms.functional import crop
try:
    # Prefer the transforms v2 API (torchvision >= 0.15).
    import torchvision.transforms.v2 as transforms
except ImportError:
    # Older torchvision: fall back to the classic v1 API.
    import torchvision.transforms as transforms

# Re-export the names this module uses under whichever API is active.
RandomCrop = transforms.RandomCrop
# Bug fix: bind the crop *function*, not the `functional` module itself.
# `crop = transforms.functional` made every call site `crop(img, ...)` fail
# with "'module' object is not callable".
crop = transforms.functional.crop


class LargestCenterSquare:
Expand Down
6 changes: 5 additions & 1 deletion diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from composer.utils import dist
from torch.utils.data import DataLoader
from torchmetrics.multimodal import CLIPScore
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerBase

try:
from torchvision.transforms.v2.functional import to_pil_image
except ImportError:
from torchvision.transforms.functional import to_pil_image

os.environ['TOKENIZERS_PARALLELISM'] = 'false'


Expand Down
4 changes: 2 additions & 2 deletions diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def train(config: DictConfig) -> None:
# Assumes that evaluators is a nested dictionary with evaluator / dataloader pairs
if 'evaluators' in config.dataset:
evaluators = []
for _, eval_conf in config.dataset.evaluators.items():
for eval_conf in config.dataset.evaluators.values():
print(OmegaConf.to_yaml(eval_conf))
eval_dataloader = hydra.utils.instantiate(
eval_conf.eval_dataset,
Expand Down Expand Up @@ -109,7 +109,7 @@ def train(config: DictConfig) -> None:
)

if 'callbacks' in config:
for _, call_conf in config.callbacks.items():
for call_conf in config.callbacks.values():
if '_target_' in call_conf:
print(f'Instantiating callbacks <{call_conf._target_}>')
callbacks.append(hydra.utils.instantiate(call_conf))
Expand Down
6 changes: 5 additions & 1 deletion scripts/precompute_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
from PIL import Image
from streaming import MDSWriter, Stream, StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusion.datasets.laion.transforms import LargestCenterSquare

try:
    # Prefer the transforms v2 API (torchvision >= 0.15). v2 has no
    # `transforms` submodule, so `from torchvision.transforms.v2 import
    # transforms` always raised ImportError and silently used v1 instead;
    # import the v2 package itself under the same alias.
    import torchvision.transforms.v2 as transforms
except ImportError:
    # Older torchvision: fall back to the classic v1 API.
    from torchvision import transforms


class StreamingLAIONDataset(StreamingDataset):
"""Implementation of the LAION dataset as a streaming dataset except with metadata.
Expand Down