# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming Image-Caption Dataset for use with Pre-computed Text Latents."""

import logging
from io import BytesIO
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)


class StreamingImageCaptionLatentsDataset(StreamingDataset):
    """Streaming dataset for image-caption datasets with pre-computed text latents.

    Args:
        streams (Sequence[Stream]): One or more Streams to stream/cache samples from.
        caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
        microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
        crop (Callable, optional): The crop transform to apply to the image before ``transform``. Default: ``None``.
        transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
            Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
            Default: ``((512, 4096), (77, 768))``.
        attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
        **streaming_kwargs: Additional arguments to pass to the ``StreamingDataset`` constructor.
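
    Example:
        A minimal sketch of constructing the dataset directly (in practice it is usually built through
        ``build_streaming_image_caption_latents_dataloader`` below). The remote/local paths are placeholders, and the
        shards are assumed to contain the default image, caption latent, and attention mask fields::

            from diffusion.datasets.utils import make_streams

            # Placeholder remote dataset path and local cache directory.
            streams = make_streams('s3://my-bucket/precomputed-latents', '/tmp/mds-cache')
            dataset = StreamingImageCaptionLatentsDataset(streams=streams, batch_size=16)
            sample = dataset[0]  # keys include 'image', 'T5_LATENTS', 'CLIP_LATENTS', and 'CLIP_POOLED'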
| 44 | + """ |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + streams: Sequence[Stream], |
| 49 | + caption_drop_prob: float = 0.0, |
| 50 | + microcond_drop_prob: float = 0.0, |
| 51 | + crop: Optional[Callable] = None, |
| 52 | + transform: Optional[Callable] = None, |
| 53 | + image_key: str = 'image', |
| 54 | + caption_keys: Tuple[str, ...] = ('caption',), |
| 55 | + caption_selection_probs: Tuple[float, ...] = (1.0,), |
| 56 | + text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), |
| 57 | + text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)), |
| 58 | + attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), |
| 59 | + latent_dtype: torch.dtype = torch.bfloat16, |
| 60 | + **streaming_kwargs, |
| 61 | + ): |
| 62 | + |
| 63 | + # Set defaults for vision-friendly streaming args. |
| 64 | + streaming_kwargs.setdefault('shuffle_block_size', 1 << 18) |
| 65 | + streaming_kwargs.setdefault('shuffle_algo', 'py1s') |
| 66 | + super().__init__(streams=streams, **streaming_kwargs) |
| 67 | + |
| 68 | + self.crop = crop |
| 69 | + self.transform = transform |
| 70 | + self.caption_drop_prob = caption_drop_prob |
| 71 | + self.microcond_drop_prob = microcond_drop_prob |
| 72 | + self.image_key = image_key |
| 73 | + self.caption_keys = caption_keys |
| 74 | + self.caption_selection_probs = caption_selection_probs |
| 75 | + self.text_latent_keys = text_latent_keys |
| 76 | + self.text_latent_shapes = text_latent_shapes |
| 77 | + self.attention_mask_keys = attention_mask_keys |
| 78 | + self.latent_dtype = latent_dtype |
| 79 | + |
| 80 | + def __getitem__(self, index): |
| 81 | + sample = super().__getitem__(index) |
| 82 | + out = {} |
| 83 | + |
| 84 | + # Image |
| 85 | + img = sample[self.image_key] |
| 86 | + if not isinstance(img, Image.Image): |
| 87 | + img = Image.open(BytesIO(sample[self.image_key])) |
| 88 | + if img.mode != 'RGB': |
| 89 | + img = img.convert('RGB') |
| 90 | + out['cond_original_size'] = torch.tensor(img.size) |
| 91 | + |
| 92 | + # Image transforms |
| 93 | + if self.crop is not None: |
| 94 | + img, crop_top, crop_left = self.crop(img) |
| 95 | + else: |
| 96 | + crop_top, crop_left = 0, 0 |
| 97 | + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) |
| 98 | + |
| 99 | + if self.transform is not None: |
| 100 | + img = self.transform(img) |
| 101 | + out['image'] = img |
| 102 | + |
| 103 | + # Get the new height and width |
| 104 | + if isinstance(img, torch.Tensor): |
| 105 | + img_h, img_w = img.shape[-2], img.shape[-1] |
| 106 | + elif isinstance(img, Image.Image): |
| 107 | + img_w, img_h = img.size |
| 108 | + else: |
| 109 | + raise ValueError('Image after transformations must either be a PIL Image or Torch Tensor') |
| 110 | + out['cond_target_size'] = torch.tensor([img_w, img_h]) |
| 111 | + |
| 112 | + # Microconditioning dropout as in Stability repo |
| 113 | + # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 |
| 114 | + if torch.rand(1) < self.microcond_drop_prob: |
| 115 | + out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0 |
| 116 | + if torch.rand(1) < self.microcond_drop_prob: |
| 117 | + out['cond_original_size'] = out['cond_original_size'] * 0 |
| 118 | + if torch.rand(1) < self.microcond_drop_prob: |
| 119 | + out['cond_target_size'] = out['cond_target_size'] * 0 |
| 120 | + |
| 121 | + # Randomly select a caption according to the selection probabilities |
| 122 | + caption_key = np.random.choice(self.caption_keys, p=self.caption_selection_probs) |
| 123 | + # Load text latents, attention masks, and clip pooled embeddings |
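        # Text latents are stored in the shard as raw float32 byte buffers and are reshaped to the
        # (sequence_length, feature_size) given in `text_latent_shapes`; attention masks are stored as raw boolean
        # buffers. If the caption is dropped (e.g. for classifier-free guidance), zero tensors of matching shapes
        # are returned instead.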
        for i in range(len(self.text_latent_keys)):
            latent_key = f'{caption_key}_{self.text_latent_keys[i]}'
            latent_shape = self.text_latent_shapes[i]
            attention_key = f'{caption_key}_{self.attention_mask_keys[i]}'

            if torch.rand(1) < self.caption_drop_prob:
                out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=self.latent_dtype)
                out[self.attention_mask_keys[i]] = torch.zeros(latent_shape[0])
                if 'CLIP_LATENTS' in latent_key:
                    out['CLIP_POOLED'] = torch.zeros(latent_shape[1])
            else:
                text_latent = np.frombuffer(sample[latent_key], dtype=np.float32).copy()
                out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).to(
                    self.latent_dtype).reshape(latent_shape)
                attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy()
                out[self.attention_mask_keys[i]] = torch.from_numpy(attention_mask).to(dtype=torch.float).reshape(-1)
                if 'CLIP_LATENTS' in latent_key:
                    clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
                    out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
        return out


def build_streaming_image_caption_latents_dataloader(
    remote: Union[str, List],
    batch_size: int,
    local: Optional[Union[str, List]] = None,
    caption_drop_prob: float = 0.0,
    microcond_drop_prob: float = 0.0,
    resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256,
    ar_bucket_boundaries: Optional[Tuple[float, ...]] = None,
    transform: Optional[List[Callable]] = None,
    crop_type: Optional[str] = 'square',
    image_key: str = 'image',
    caption_keys: Tuple[str, ...] = ('caption',),
    caption_selection_probs: Tuple[float, ...] = (1.0,),
    text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
    text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
    attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
    latent_dtype: str = 'torch.bfloat16',
    streaming_kwargs: Optional[Dict] = None,
    dataloader_kwargs: Optional[Dict] = None,
):
    """Builds a streaming dataloader for image-caption pairs with pre-computed text latents.

    Args:
        remote (str, Sequence[str]): One or more remote directories (S3 or local filesystem) where the dataset is stored.
        batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
        local (str, Sequence[str], optional): One or more local filesystem directories where the dataset is cached during operation.
        caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
        microcond_drop_prob (float): The probability of dropping microconditioning. Default: ``0.0``.
        resize_size (int, Tuple[int, int], Tuple[Tuple[int, int], ...]): The size to resize the image to. Specify a
            tuple of tuples if using 'aspect_ratio' crop_type. Default: ``256``.
        ar_bucket_boundaries (Tuple[float, ...], optional): When using ``crop_type='aspect_ratio'``, specifies the
            boundary points for bucket assignment. This tuple should be of length ``len(resize_size) - 1``. If set to
            ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
            Default: ``None``.
        transform (List[Callable], optional): List of transforms to apply to the image, composed into a single
            transform. Default: ``None``.
        crop_type (str, optional): Type of crop to perform; one of ``'square'``, ``'random'``, ``'aspect_ratio'``, or
            ``None`` for no cropping. Default: ``'square'``.
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
            Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
            Default: ``((512, 4096), (77, 768))``.
        attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
            or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
        streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
        dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
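
    Example:
        A minimal sketch of building and iterating the dataloader; the remote path, local cache directory, and batch
        size are placeholders, and the shards are assumed to use the default caption/latent keys::

            dataloader = build_streaming_image_caption_latents_dataloader(
                remote='s3://my-bucket/precomputed-latents',  # placeholder remote path
                local='/tmp/mds-cache',  # placeholder local cache
                batch_size=16,
                resize_size=256,
                crop_type='square',
            )
            batch = next(iter(dataloader))
            images, t5_latents = batch['image'], batch['T5_LATENTS']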
| 197 | + """ |
| 198 | + # Check crop type |
| 199 | + if crop_type is not None: |
| 200 | + crop_type = crop_type.lower() |
| 201 | + if crop_type not in ['square', 'random', 'aspect_ratio']: |
| 202 | + raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') |
| 203 | + if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): |
| 204 | + raise ValueError( |
| 205 | + 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') |
| 206 | + |
| 207 | + # Check latent dtype |
| 208 | + dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16} |
| 209 | + assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}' |
| 210 | + dtype = dtypes[latent_dtype] |
| 211 | + |
| 212 | + # Handle ``None`` kwargs |
| 213 | + if streaming_kwargs is None: |
| 214 | + streaming_kwargs = {} |
| 215 | + if dataloader_kwargs is None: |
| 216 | + dataloader_kwargs = {} |
| 217 | + |
| 218 | + # Make streams |
| 219 | + streams = make_streams(remote, local) |
| 220 | + |
| 221 | + # Set the crop to apply |
| 222 | + if crop_type == 'square': |
| 223 | + crop = LargestCenterSquare(resize_size) |
| 224 | + elif crop_type == 'random': |
| 225 | + crop = RandomCropSquare(resize_size) |
| 226 | + elif crop_type == 'aspect_ratio': |
| 227 | + crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore |
| 228 | + else: |
| 229 | + crop = None |
| 230 | + |
| 231 | + if transform is None: |
| 232 | + transform = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] |
| 233 | + transform = transforms.Compose(transform) |
| 234 | + assert isinstance(transform, Callable) |
| 235 | + |
| 236 | + dataset = StreamingImageCaptionLatentsDataset( |
| 237 | + streams=streams, |
| 238 | + caption_drop_prob=caption_drop_prob, |
| 239 | + microcond_drop_prob=microcond_drop_prob, |
| 240 | + crop=crop, |
| 241 | + transform=transform, |
| 242 | + image_key=image_key, |
| 243 | + caption_keys=caption_keys, |
| 244 | + caption_selection_probs=caption_selection_probs, |
| 245 | + text_latent_keys=text_latent_keys, |
| 246 | + text_latent_shapes=text_latent_shapes, |
| 247 | + attention_mask_keys=attention_mask_keys, |
| 248 | + latent_dtype=dtype, |
| 249 | + **streaming_kwargs, |
| 250 | + ) |
| 251 | + |
| 252 | + dataloader = DataLoader( |
| 253 | + dataset=dataset, |
| 254 | + batch_size=batch_size, |
| 255 | + sampler=None, |
| 256 | + **dataloader_kwargs, |
| 257 | + ) |
| 258 | + |
| 259 | + return dataloader |