Skip to content

Commit

Permalink
Add option for stream weights (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Nov 4, 2024
1 parent eb23a2f commit b0a094f
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def build_streaming_image_caption_latents_dataloader(
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
aspect_ratio_bucket_key: Optional[str] = None,
proportion: Optional[list] = None,
repeat: Optional[list] = None,
choose: Optional[list] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand Down Expand Up @@ -213,6 +216,9 @@ def build_streaming_image_caption_latents_dataloader(
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``.
repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``.
choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand All @@ -239,7 +245,7 @@ def build_streaming_image_caption_latents_dataloader(
dataloader_kwargs = {}

# Make streams
streams = make_streams(remote, local)
streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose)

# Set the crop to apply
if crop_type == 'square':
Expand Down

0 comments on commit b0a094f

Please sign in to comment.