diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index 8ca46c14..4ee9a38b 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -178,6 +178,9 @@ def build_streaming_image_caption_latents_dataloader( attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), latent_dtype: str = 'torch.bfloat16', aspect_ratio_bucket_key: Optional[str] = None, + proportion: Optional[list] = None, + repeat: Optional[list] = None, + choose: Optional[list] = None, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -213,6 +216,9 @@ def build_streaming_image_caption_latents_dataloader( or 'torch.bfloat16'. Default: ``'torch.bfloat16'``. aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``. + proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``. + repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``. + choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ @@ -239,7 +245,7 @@ def build_streaming_image_caption_latents_dataloader( dataloader_kwargs = {} # Make streams - streams = make_streams(remote, local) + streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose) # Set the crop to apply if crop_type == 'square':