Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

try:
import xformers # type: ignore
Expand Down Expand Up @@ -51,6 +52,7 @@ def stable_diffusion_2(
beta_schedule: str = 'scaled_linear',
zero_terminal_snr: bool = False,
offset_noise: Optional[float] = None,
scheduler_shift_resolution: int = 256,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
Expand Down Expand Up @@ -97,6 +99,7 @@ def stable_diffusion_2(
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
fsdp (bool): Whether to use FSDP. Defaults to True.
Expand Down Expand Up @@ -178,6 +181,14 @@ def stable_diffusion_2(
set_alpha_to_one=False,
prediction_type=prediction_type)

# Shift noise scheduler to correct for resolution changes
noise_scheduler = shift_noise_schedule(noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)
inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)

# Make the composer model
model = StableDiffusion(
unet=unet,
Expand Down Expand Up @@ -236,6 +247,7 @@ def stable_diffusion_xl(
zero_terminal_snr: bool = False,
use_karras_sigmas: bool = False,
offset_noise: Optional[float] = None,
scheduler_shift_resolution: int = 256,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
Expand Down Expand Up @@ -283,6 +295,7 @@ def stable_diffusion_xl(
use_karras_sigmas (bool): Whether to use the Karras sigmas for the diffusion process noise. Default: `False`.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
Expand Down Expand Up @@ -433,6 +446,14 @@ def stable_diffusion_xl(
steps_offset=1,
rescale_betas_zero_snr=zero_terminal_snr)

# Shift noise scheduler to correct for resolution changes
noise_scheduler = shift_noise_schedule(noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)
inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)

# Make the composer model
model = StableDiffusion(
unet=unet,
Expand Down
3 changes: 2 additions & 1 deletion diffusion/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"""Diffusion schedulers."""

from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

__all__ = ['ContinuousTimeScheduler']
__all__ = ['ContinuousTimeScheduler', 'shift_noise_schedule']
39 changes: 39 additions & 0 deletions diffusion/schedulers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Utils for working with diffusion schedulers."""

import torch


def shift_noise_schedule(noise_scheduler, base_dim: int = 64, shift_dim: int = 64):
    """Shifts the function SNR(t) for a noise scheduler to correct for resolution changes.

    Implements the technique from https://arxiv.org/abs/2301.11093

    The signal-to-noise ratio derived from ``noise_scheduler.alphas_cumprod`` is multiplied by
    ``(base_dim / shift_dim)**2``. The scheduler's ``alphas``, ``betas`` and ``alphas_cumprod``
    attributes are then overwritten with the values implied by the shifted SNR. The scheduler is
    modified in place and the same object is returned.

    Args:
        noise_scheduler (diffusers.SchedulerMixin): The noise scheduler to shift. Must expose a
            1-D tensor attribute ``alphas_cumprod``.
        base_dim (int): The base side length of the schedule resolution. Must be positive.
        shift_dim (int): The new side length of the schedule resolution. Must be positive.

    Returns:
        diffusers.SchedulerMixin: The shifted noise scheduler (the same object that was passed in).

    Raises:
        ValueError: If ``base_dim`` or ``shift_dim`` is not a positive number.
    """
    # A zero base_dim would zero out every alpha_bar and turn the per-step alphas into 0/0 = NaN,
    # and a zero shift_dim would fail with an unexplained ZeroDivisionError. Fail early instead.
    if base_dim <= 0 or shift_dim <= 0:
        raise ValueError(f'base_dim and shift_dim must be positive, got base_dim={base_dim}, shift_dim={shift_dim}')

    # First, we need to get the original SNR(t) function
    alpha_bar = noise_scheduler.alphas_cumprod
    snr = alpha_bar / (1 - alpha_bar)
    # Shift the SNR according to the resolution change
    snr_shifted = (base_dim / shift_dim)**2 * snr
    # Get the new alpha_bars. Entries where alpha_bar == 1 have infinite SNR, and inf / (1 + inf)
    # would be NaN, so those entries are pinned to 1. ones_like keeps the fill value on the same
    # dtype and device as the schedule.
    alpha_bar_shifted = torch.where(
        snr_shifted == float('inf'),
        torch.ones_like(snr_shifted),
        snr_shifted / (1 + snr_shifted),
    )
    # Get the new alpha values: alpha_bar is the cumulative product of alpha, so each alpha is the
    # ratio of consecutive alpha_bars (the first one is alpha_bar[0] itself).
    alpha_shifted = torch.empty_like(alpha_bar_shifted)
    alpha_shifted[0] = alpha_bar_shifted[0]
    alpha_shifted[1:] = alpha_bar_shifted[1:] / alpha_bar_shifted[:-1]
    # Get the new beta values
    beta_shifted = 1 - alpha_shifted
    # Update the noise scheduler
    # NOTE(review): only these three attributes are refreshed; anything else the scheduler may have
    # precomputed from the old schedule is left as-is -- confirm this is sufficient for the
    # scheduler classes in use.
    noise_scheduler.alphas = alpha_shifted
    noise_scheduler.betas = beta_shifted
    noise_scheduler.alphas_cumprod = alpha_bar_shifted
    return noise_scheduler