
Commit 93a5469

Only make a text encoder mask if mask_pad_tokens is true (#149)
1 parent 7d3a7cc commit 93a5469

File tree: 1 file changed (+2, -4 lines)

diffusion/models/stable_diffusion.py

Lines changed: 2 additions & 4 deletions
@@ -182,11 +182,9 @@ def set_rng_generator(self, rng_generator: torch.Generator):
 
     def forward(self, batch):
         latents, text_embeds, text_pooled_embeds, attention_mask, encoder_attention_mask = None, None, None, None, None
-        if 'attention_mask' in batch:
+        if 'attention_mask' in batch and self.mask_pad_tokens:
             attention_mask = batch['attention_mask']  # mask for text encoders
-            # text mask for U-Net
-            if self.mask_pad_tokens:
-                encoder_attention_mask = _create_unet_attention_mask(attention_mask)
+            encoder_attention_mask = _create_unet_attention_mask(attention_mask)  # text mask for U-Net
 
         # Use latents if specified and available. When specified, they might not exist during eval
         if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch:
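For context, here is a minimal sketch of how the post-change masking path behaves: with this commit, neither the text-encoder mask nor the U-Net encoder attention mask is built unless mask_pad_tokens is true and the batch actually carries an attention mask. The body of _create_unet_attention_mask is not part of this diff, so the helper below is a hypothetical stand-in (assumed to collapse stacked per-encoder masks into a single mask), not the repository's actual implementation.

import torch

def _create_unet_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the real helper (its body is not shown in this diff).
    # Assumption: stacked masks of shape (B, num_encoders, seq_len) are collapsed
    # into a single (B, seq_len) mask for U-Net cross-attention.
    if attention_mask.dim() == 3:
        return attention_mask.amax(dim=1)
    return attention_mask

def build_masks(batch: dict, mask_pad_tokens: bool):
    # Sketch of the post-change logic: both masks stay None unless
    # mask_pad_tokens is true AND the batch contains an attention mask.
    attention_mask, encoder_attention_mask = None, None
    if 'attention_mask' in batch and mask_pad_tokens:
        attention_mask = batch['attention_mask']  # mask for text encoders
        encoder_attention_mask = _create_unet_attention_mask(attention_mask)  # text mask for U-Net
    return attention_mask, encoder_attention_mask

# With mask_pad_tokens=False, no masks are produced even if the batch has one.
batch = {'attention_mask': torch.ones(2, 77, dtype=torch.long)}
print(build_masks(batch, mask_pad_tokens=False))  # (None, None)
print(build_masks(batch, mask_pad_tokens=True))   # (tensor(...), tensor(...))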
