Commit 2a6fff4

Add dataset for running with precomputed latents from multiple captions (#161)

1 parent ee934a4 commit 2a6fff4

File tree: 2 files changed, +263, -0 lines

diffusion/datasets/__init__.py: 4 additions, 0 deletions

@@ -5,6 +5,8 @@
 from diffusion.datasets.coco import StreamingCOCOCaption, build_streaming_cocoval_dataloader
 from diffusion.datasets.image_caption import StreamingImageCaptionDataset, build_streaming_image_caption_dataloader
+from diffusion.datasets.image_caption_latents import (StreamingImageCaptionLatentsDataset,
+                                                      build_streaming_image_caption_latents_dataloader)
 from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader
 from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset,
                                                         build_synthetic_image_caption_dataloader)
@@ -16,6 +18,8 @@
     'StreamingCOCOCaption',
     'build_streaming_image_caption_dataloader',
     'StreamingImageCaptionDataset',
+    'build_streaming_image_caption_latents_dataloader',
+    'StreamingImageCaptionLatentsDataset',
     'build_synthetic_image_caption_dataloader',
     'SyntheticImageCaptionDataset',
 ]
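With these exports in place, the new dataset class and its dataloader builder are importable from the package root; a minimal sketch:

    from diffusion.datasets import (StreamingImageCaptionLatentsDataset,
                                    build_streaming_image_caption_latents_dataloader)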
diffusion/datasets/image_caption_latents.py (new file): 259 additions, 0 deletions

@@ -0,0 +1,259 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming Image-Caption Dataset for use with Pre-computed Text Latents."""

import logging
from io import BytesIO
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)

class StreamingImageCaptionLatentsDataset(StreamingDataset):
    """Streaming dataset for image-caption datasets with pre-computed text latents.

    Args:
        streams (Sequence[Stream]): One or more Streams to stream/cache samples from.
        caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
        microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL.
            Default: ``0.0``.
        crop (Callable, optional): The crop transform to apply to the image before ``transform``. Default: ``None``.
        transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset.
            Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key.
            Default: ``(1.0,)``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
            Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
            Default: ``((512, 4096), (77, 768))``.
        attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
        **streaming_kwargs: Additional arguments to pass to the ``StreamingDataset`` constructor.
    """

    def __init__(
        self,
        streams: Sequence[Stream],
        caption_drop_prob: float = 0.0,
        microcond_drop_prob: float = 0.0,
        crop: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        image_key: str = 'image',
        caption_keys: Tuple[str, ...] = ('caption',),
        caption_selection_probs: Tuple[float, ...] = (1.0,),
        text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
        text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
        attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
        latent_dtype: torch.dtype = torch.bfloat16,
        **streaming_kwargs,
    ):
        # Set defaults for vision-friendly streaming args.
        streaming_kwargs.setdefault('shuffle_block_size', 1 << 18)
        streaming_kwargs.setdefault('shuffle_algo', 'py1s')
        super().__init__(streams=streams, **streaming_kwargs)

        self.crop = crop
        self.transform = transform
        self.caption_drop_prob = caption_drop_prob
        self.microcond_drop_prob = microcond_drop_prob
        self.image_key = image_key
        self.caption_keys = caption_keys
        self.caption_selection_probs = caption_selection_probs
        self.text_latent_keys = text_latent_keys
        self.text_latent_shapes = text_latent_shapes
        self.attention_mask_keys = attention_mask_keys
        self.latent_dtype = latent_dtype

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        out = {}

        # Image
        img = sample[self.image_key]
        if not isinstance(img, Image.Image):
            img = Image.open(BytesIO(sample[self.image_key]))
        if img.mode != 'RGB':
            img = img.convert('RGB')
        out['cond_original_size'] = torch.tensor(img.size)

        # Image transforms
        if self.crop is not None:
            img, crop_top, crop_left = self.crop(img)
        else:
            crop_top, crop_left = 0, 0
        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

        if self.transform is not None:
            img = self.transform(img)
        out['image'] = img

        # Get the new height and width
        if isinstance(img, torch.Tensor):
            img_h, img_w = img.shape[-2], img.shape[-1]
        elif isinstance(img, Image.Image):
            img_w, img_h = img.size
        else:
            raise ValueError('Image after transformations must either be a PIL Image or Torch Tensor')
        out['cond_target_size'] = torch.tensor([img_w, img_h])

        # Microconditioning dropout as in Stability repo
        # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160
        if torch.rand(1) < self.microcond_drop_prob:
            out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0
        if torch.rand(1) < self.microcond_drop_prob:
            out['cond_original_size'] = out['cond_original_size'] * 0
        if torch.rand(1) < self.microcond_drop_prob:
            out['cond_target_size'] = out['cond_target_size'] * 0

        # Randomly select a caption according to the selection probabilities
        caption_key = np.random.choice(self.caption_keys, p=self.caption_selection_probs)
        # Load text latents, attention masks, and clip pooled embeddings
        for i in range(len(self.text_latent_keys)):
            latent_key = f'{caption_key}_{self.text_latent_keys[i]}'
            latent_shape = self.text_latent_shapes[i]
            attention_key = f'{caption_key}_{self.attention_mask_keys[i]}'

            if torch.rand(1) < self.caption_drop_prob:
                out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=self.latent_dtype)
                out[self.attention_mask_keys[i]] = torch.zeros(latent_shape[0])
                if 'CLIP_LATENTS' in latent_key:
                    out['CLIP_POOLED'] = torch.zeros(latent_shape[1])
            else:
                text_latent = np.frombuffer(sample[latent_key], dtype=np.float32).copy()
                out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).to(
                    self.latent_dtype).reshape(latent_shape)
                attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy()
                out[self.attention_mask_keys[i]] = torch.from_numpy(attention_mask).to(dtype=torch.float).reshape(-1)
                if 'CLIP_LATENTS' in latent_key:
                    clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
                    out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
        return out
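
For review, a sketch of what a single sample looks like, assuming an MDS dataset written with the default keys (the local shard path ./mds-data is hypothetical):

    from streaming import Stream
    from diffusion.datasets import StreamingImageCaptionLatentsDataset

    # Hypothetical local shard directory; substitute a real dataset.
    dataset = StreamingImageCaptionLatentsDataset(streams=[Stream(local='./mds-data')])
    sample = dataset[0]
    # With the defaults, `sample` holds: 'image', 'cond_original_size',
    # 'cond_crops_coords_top_left', 'cond_target_size', 'T5_LATENTS' of shape
    # (512, 4096), 'T5_ATTENTION_MASK', 'CLIP_LATENTS' of shape (77, 768),
    # 'CLIP_ATTENTION_MASK', and 'CLIP_POOLED' of shape (768,), with the
    # latents cast to torch.bfloat16.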

def build_streaming_image_caption_latents_dataloader(
    remote: Union[str, List],
    batch_size: int,
    local: Optional[Union[str, List]] = None,
    caption_drop_prob: float = 0.0,
    microcond_drop_prob: float = 0.0,
    resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256,
    ar_bucket_boundaries: Optional[Tuple[float, ...]] = None,
    transform: Optional[List[Callable]] = None,
    crop_type: Optional[str] = 'square',
    image_key: str = 'image',
    caption_keys: Tuple[str, ...] = ('caption',),
    caption_selection_probs: Tuple[float, ...] = (1.0,),
    text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
    text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
    attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
    latent_dtype: str = 'torch.bfloat16',
    streaming_kwargs: Optional[Dict] = None,
    dataloader_kwargs: Optional[Dict] = None,
):
    """Builds a streaming dataloader for image-caption pairs with pre-computed text latents.

    Args:
        remote (str, Sequence[str]): One or more remote directories (S3 or local filesystem) where the dataset
            is stored.
        batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
        local (str, Sequence[str], optional): One or more local filesystem directories where the dataset is
            cached during operation.
        caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
        microcond_drop_prob (float): The probability of dropping microconditioning. Default: ``0.0``.
        resize_size (int, Tuple[int, int], Tuple[Tuple[int, int], ...]): The size to resize the image to. Specify a
            tuple of tuples if using ``crop_type='aspect_ratio'``. Default: ``256``.
        ar_bucket_boundaries (Tuple[float, ...], optional): When using ``crop_type='aspect_ratio'``, specifies the
            boundary points for bucket assignment. This tuple should be of length ``len(resize_size) - 1``. If set to
            ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
            Default: ``None``.
        transform (List[Callable], optional): The transforms to apply to the image. Default: ``None``.
        crop_type (str, optional): Type of crop to perform: one of ``'square'``, ``'random'``, or ``'aspect_ratio'``.
            Default: ``'square'``.
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset.
            Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key.
            Default: ``(1.0,)``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
            Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
            Default: ``((512, 4096), (77, 768))``.
        attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (str): The torch dtype to cast the text latents to. One of ``'torch.float16'``,
            ``'torch.float32'``, or ``'torch.bfloat16'``. Default: ``'torch.bfloat16'``.
        streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
        dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
    """
    # Check crop type
    if crop_type is not None:
        crop_type = crop_type.lower()
        if crop_type not in ['square', 'random', 'aspect_ratio']:
            raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
        if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
            raise ValueError(
                'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

    # Check latent dtype
    dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
    assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
    dtype = dtypes[latent_dtype]

    # Handle ``None`` kwargs
    if streaming_kwargs is None:
        streaming_kwargs = {}
    if dataloader_kwargs is None:
        dataloader_kwargs = {}

    # Make streams
    streams = make_streams(remote, local)

    # Set the crop to apply
    if crop_type == 'square':
        crop = LargestCenterSquare(resize_size)
    elif crop_type == 'random':
        crop = RandomCropSquare(resize_size)
    elif crop_type == 'aspect_ratio':
        crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries)  # type: ignore
    else:
        crop = None

    if transform is None:
        transform = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    transform = transforms.Compose(transform)
    assert isinstance(transform, Callable)

    dataset = StreamingImageCaptionLatentsDataset(
        streams=streams,
        caption_drop_prob=caption_drop_prob,
        microcond_drop_prob=microcond_drop_prob,
        crop=crop,
        transform=transform,
        image_key=image_key,
        caption_keys=caption_keys,
        caption_selection_probs=caption_selection_probs,
        text_latent_keys=text_latent_keys,
        text_latent_shapes=text_latent_shapes,
        attention_mask_keys=attention_mask_keys,
        latent_dtype=dtype,
        batch_size=batch_size,
        **streaming_kwargs,
    )

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=None,
        **dataloader_kwargs,
    )

    return dataloader
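
To exercise the builder end to end, a hedged usage sketch. The remote path, cache directory, and the second caption key 'alt_caption' are hypothetical; each sample is assumed to store its latents under '<caption_key>_T5_LATENTS', '<caption_key>_CLIP_LATENTS', and so on, as the dataset class above expects:

    dataloader = build_streaming_image_caption_latents_dataloader(
        remote='s3://my-bucket/image-caption-latents',  # hypothetical dataset location
        local='/tmp/mds-cache',                         # hypothetical local cache
        batch_size=32,
        crop_type='square',
        resize_size=256,
        caption_keys=('caption', 'alt_caption'),        # 'alt_caption' is a hypothetical second key
        caption_selection_probs=(0.75, 0.25),
        latent_dtype='torch.bfloat16',
    )
    batch = next(iter(dataloader))
    # With the default latent keys and shapes, batch['T5_LATENTS'] has shape
    # (32, 512, 4096) and batch['CLIP_POOLED'] has shape (32, 768).

Per sample, one caption key is drawn according to caption_selection_probs, so a 0.75/0.25 split trains mostly on the original captions while still mixing in the alternates.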
