diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 81939c61..2c02905c 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -15,7 +15,8 @@ from torch.utils.data import DataLoader from torchvision import transforms -from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare +from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform, + RandomCropBucketedAspectRatioTransform, RandomCropSquare) from diffusion.datasets.utils import make_streams from diffusion.models.text_encoder import MultiTokenizer @@ -45,6 +46,7 @@ class StreamingImageCaptionDataset(StreamingDataset): transform (Callable, optional): The transforms to apply to the image. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`. zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``. 
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader @@ -63,6 +65,7 @@ def __init__( transform: Optional[Callable] = None, image_key: str = 'image', caption_key: str = 'caption', + aspect_ratio_bucket_key: Optional[str] = None, sdxl_conditioning: bool = False, zero_dropped_captions: bool = False, **streaming_kwargs, @@ -90,6 +93,9 @@ def __init__( self.caption_selection = caption_selection self.image_key = image_key self.caption_key = caption_key + self.aspect_ratio_bucket_key = aspect_ratio_bucket_key + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform' self.zero_dropped_captions = zero_dropped_captions self.tokenizer = tokenizer @@ -107,7 +113,9 @@ def __getitem__(self, index): orig_w, orig_h = img.size # Image transforms - if self.crop is not None: + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key]) + elif self.crop is not None: img, crop_top, crop_left = self.crop(img) else: crop_top, crop_left = 0, 0 @@ -179,6 +187,7 @@ def build_streaming_image_caption_dataloader( transform: Optional[List[Callable]] = None, image_key: str = 'image', caption_key: str = 'caption', + aspect_ratio_bucket_key: Optional[str] = None, crop_type: Optional[str] = 'square', zero_dropped_captions: bool = True, sdxl_conditioning: bool = False, @@ -212,7 +221,8 @@ def build_streaming_image_caption_dataloader( transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. - crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio']. 
+ aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. + crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']. Default: ``'square'``. zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``. sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`. @@ -225,12 +235,14 @@ def build_streaming_image_caption_dataloader( # Check crop type if crop_type is not None: crop_type = crop_type.lower() - if crop_type not in ['square', 'random', 'aspect_ratio']: - raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') - if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): + if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']: raise ValueError( - 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') - + f'Invalid crop_type: {crop_type}. 
Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]' + ) + if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or + isinstance(resize_size[0], int)): + raise ValueError( + 'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.') # Handle ``None`` kwargs if streaming_kwargs is None: streaming_kwargs = {} @@ -246,7 +258,10 @@ def build_streaming_image_caption_dataloader( elif crop_type == 'random': crop = RandomCropSquare(resize_size) elif crop_type == 'aspect_ratio': - crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore + crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore + elif crop_type == 'bucketed_aspect_ratio': + assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type' + crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore else: crop = None @@ -265,6 +280,7 @@ def build_streaming_image_caption_dataloader( transform=transform, image_key=image_key, caption_key=caption_key, + aspect_ratio_bucket_key=aspect_ratio_bucket_key, batch_size=batch_size, sdxl_conditioning=sdxl_conditioning, zero_dropped_captions=zero_dropped_captions, diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py index e9700b9b..2f9e955a 100644 --- a/diffusion/datasets/laion/transforms.py +++ b/diffusion/datasets/laion/transforms.py @@ -3,6 +3,7 @@ """Transforms for the training and eval dataset.""" +import math from typing import Optional, Tuple import torch @@ -45,7 +46,7 @@ def __call__(self, img): return img, c_top, c_left -class RandomCropAspectRatioTransorm: +class RandomCropAspectRatioTransform: """Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket. 
class RandomCropBucketedAspectRatioTransform:
    """Resizes and randomly crops an image to fit a caller-selected aspect ratio bucket.

    This transform requires the desired aspect ratio to be specified manually in the call to the transform. The
    bucket whose height/width ratio is closest (in log space) to the requested aspect ratio is used as the target
    size; the image is resized so that it covers the target and is then randomly cropped to it.

    Args:
        resize_size (Tuple[Tuple[int, int], ...]): A tuple of 2-tuple integers representing the aspect ratio buckets.
            The format is ((height_bucket1, width_bucket1), (height_bucket2, width_bucket2), ...).

    Raises:
        ValueError: If ``resize_size`` is empty or contains a non-positive height or width.
    """

    def __init__(
        self,
        resize_size: Tuple[Tuple[int, int], ...],
    ):
        # Fail at construction time instead of with an opaque ``argmin`` error on the first sample.
        if len(resize_size) == 0:
            raise ValueError('resize_size must contain at least one (height, width) bucket.')
        # A zero or negative dimension would yield inf/nan ratios and a meaningless bucket choice.
        if any(size[0] <= 0 or size[1] <= 0 for size in resize_size):
            raise ValueError(f'All bucket heights and widths must be positive, got {resize_size}.')

        self.height_buckets = torch.tensor([size[0] for size in resize_size])
        self.width_buckets = torch.tensor([size[1] for size in resize_size])
        # Aspect ratios are height / width throughout this class.
        self.aspect_ratio_buckets = self.height_buckets / self.width_buckets
        self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)

    def __call__(self, img, aspect_ratio):
        """Resize and randomly crop ``img`` to the bucket closest to ``aspect_ratio``.

        Args:
            img: Image exposing a ``size`` attribute that unpacks as ``(width, height)`` (e.g. a PIL image).
            aspect_ratio: Desired height / width ratio used to pick the bucket. Must be positive and finite.

        Returns:
            A tuple ``(img, c_top, c_left)`` with the cropped image and the top/left offsets of the crop
            within the resized image.

        Raises:
            ValueError: If ``aspect_ratio`` is not a positive, finite number.
        """
        # NaN/inf would otherwise flow through ``log``/``argmin`` and silently select an arbitrary bucket.
        if not (math.isfinite(aspect_ratio) and aspect_ratio > 0):
            raise ValueError(f'aspect_ratio must be a positive, finite number, got {aspect_ratio!r}.')

        orig_w, orig_h = img.size
        orig_aspect_ratio = orig_h / orig_w

        # Figure out target H/W given the input aspect ratio (nearest bucket in log space)
        log_distance = torch.abs(self.log_aspect_ratio_buckets - math.log(aspect_ratio))
        bucket_ind = log_distance.argmin()
        target_height = self.height_buckets[bucket_ind].item()
        target_width = self.width_buckets[bucket_ind].item()
        target_aspect_ratio = target_height / target_width

        # Determine resize size: scale so the image covers the target along both axes
        if orig_aspect_ratio > target_aspect_ratio:
            # Image is relatively taller than the target: match the width, crop the height
            w_scale = target_width / orig_w
            resize_size = (round(w_scale * orig_h), target_width)
        elif orig_aspect_ratio < target_aspect_ratio:
            # Image is relatively wider than the target: match the height, crop the width
            h_scale = target_height / orig_h
            resize_size = (target_height, round(h_scale * orig_w))
        else:
            resize_size = (target_height, target_width)
        img = transforms.functional.resize(img, resize_size, antialias=True)

        # Crop based on aspect ratio
        c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
        img = crop(img, c_top, c_left, height, width)
        return img, c_top, c_left