How does Stable Diffusion generate images from text?
Mark Ezema / September 26, 2022
As someone who loves art, I've always been intrigued by how AI will power the next generation of artists and creators. Stability AI has released an open-source image generation model called Stable Diffusion, and this article looks at how it works.
Table of Contents
- Overview of model architecture
- Setup
- The AutoEncoder
- The Scheduler
- The Text Encoder
- The UNET and CFG
- Guidance
- Conclusion
Overview
Stable Diffusion is powered by latent diffusion, a text-to-image synthesis technique described in the paper High-Resolution Image Synthesis with Latent Diffusion Models, published by AI researchers at the Ludwig Maximilian University of Munich.
Diffusion models achieve state-of-the-art results by breaking image generation down into a sequence of applications of denoising autoencoders. Their formulation also lets them be applied directly to image-editing tasks such as inpainting without retraining.
The downside is that these models usually operate directly in pixel space, so training a powerful diffusion model can take hundreds of GPU days, and inference is expensive because the denoising steps have to be evaluated one after another.
Stable Diffusion deals with this through latent diffusion: the diffusion process runs in a lower-dimensional 'latent space', on the compressed representations produced by an autoencoder rather than on raw images.
These representations are information-rich yet small enough to handle on consumer hardware. Once the new 'image' has been generated as a latent representation, the autoencoder's decoder turns the final latents back into actual pixels.
The core idea of latent diffusion models (LDMs) is to separate the compressive and generative learning phases. An autoencoder learns a lower-dimensional representation of pixel space, and the diffusion process, which adds noise at each step, then operates on this latent representation.
The noisy latents are fed into a denoising network based on the U-Net architecture with cross-attention layers. Besides the latent representation, this network can take additional conditioning inputs, such as semantic maps, images, or text representations.
As a result, Stable Diffusion has a relatively lightweight architecture: you can run the model on your own hardware, if it has enough computing power, or on Google Colab.
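To make the cross-attention idea concrete, here is a minimal NumPy sketch (not the actual U-Net code) of one single-head cross-attention step: every spatial position of the latent feature map supplies a query, and the text-token embeddings supply the keys and values. The projection weights are random stand-ins for learned parameters, and the sizes (an 8x8 feature map with 320 channels, 77 text tokens of width 768) are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(latent_feats, text_emb, w_q, w_k, w_v):
    """Single-head cross-attention: queries come from the image latent,
    keys and values come from the text embeddings."""
    q = latent_feats @ w_q                     # (positions, d_attn)
    k = text_emb @ w_k                         # (tokens, d_attn)
    v = text_emb @ w_v                         # (tokens, d_model)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (positions, tokens)
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ v, weights                # (positions, d_model)

rng = np.random.default_rng(0)
positions, d_model = 8 * 8, 320   # flattened 8x8 latent feature map (toy size)
tokens, d_text = 77, 768          # CLIP text embeddings, as used in this article
d_attn = 64

latent_feats = rng.standard_normal((positions, d_model))
text_emb = rng.standard_normal((tokens, d_text))

# Random stand-ins for learned projection weights
w_q = rng.standard_normal((d_model, d_attn)) / np.sqrt(d_model)
w_k = rng.standard_normal((d_text, d_attn)) / np.sqrt(d_text)
w_v = rng.standard_normal((d_text, d_model)) / np.sqrt(d_text)

out, weights = cross_attention(latent_feats, text_emb, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (64, 320) (64, 77)
```

Each output row is a mix of text-token values, weighted by how relevant each token is to that spatial position, and it has the same width as the latent features, which is what allows it to be added back onto the feature map as a residual.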
Setup
If you want to follow along, Google Colab is an excellent option.
Here we install the libraries, log in to the Hugging Face Hub (you'll need to accept the model's license terms before you can download it), and import what we'll need.
```python
!pip install transformers diffusers
```
Log in to Hugging Face:
```python
from huggingface_hub import notebook_login
notebook_login()
```
Import relevant modules:
```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
from matplotlib import pyplot as plt
import numpy
from torchvision import transforms as tfms

# For video display:
from IPython.display import HTML
from base64 import b64encode

# Set device (PyTorch device strings are lowercase)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
```
Loading the models
This code (and that in the next section) comes from the Hugging Face example notebook.
This will download and set up the relevant models and components we'll be using. Let's just run this for now and move on to the next section to check that it all works before diving deeper.
```python
# Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True)

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Move the models to the GPU (or the CPU if no GPU is available)
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
```
Let's look at a generated image of the "coolest cat on the planet with sunglasses".
```python
# Some settings
prompt = ["Coolest cat on the planet with sunglasses"]
height = 512                        # default height of Stable Diffusion
width = 512                         # default width of Stable Diffusion
num_inference_steps = 50            # Number of denoising steps
guidance_scale = 7.5                # Scale for classifier-free guidance
generator = torch.manual_seed(42)   # Seed generator to create the initial latent noise
batch_size = 1

# Prep text
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# Prep Scheduler
scheduler.set_timesteps(num_inference_steps)

# Prep latents
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
latents = latents * scheduler.sigmas[0]  # Scale the initial noise to the first sigma, as the k-diffusion-style LMS scheduler expects

# Loop
with autocast("cuda"):
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, i, latents)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents)

# Display
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
pil_images[0]
```
Now let's take a closer look at the different components.
The AutoEncoder
The autoencoder (AE) can 'encode' an image into a latent representation and 'decode' that representation back into an image.
We can get a better intuition by looking at it in action.
```python
# Using torchvision.transforms.ToTensor
to_tensor_tfm = tfms.ToTensor()

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        # ToTensor gives values in [0, 1]; *2-1 rescales them to the [-1, 1] range the VAE expects
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(torch_device)*2-1)
    return 0.18215 * latent.mode()  # or .mean or .sample

def latents_to_pil(latents):
    # batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents)
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
```
We'll use a pic of a macaw from the web.
```python
!curl --output macaw.jpg 'https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg'

# Load the image with PIL
input_image = Image.open('macaw.jpg').resize((512, 512))
input_image
```
Encoding this into the latent space of the AE with the function defined above looks like this:
```python
# Encode to the latent space
encoded = pil_to_latent(input_image)
encoded.shape
# output
torch.Size([1, 4, 64, 64])
```
We can do the reverse as well, going from the latent representation back to an image:
```python
# Decode this latent representation back into an image
decoded = latents_to_pil(encoded)[0]
decoded
```
You'll see some small differences if you look closely; focus on the eye if nothing obvious stands out. This is pretty impressive: that 4x64x64 latent seems to hold a lot more information than a 64x64-pixel image would.
This autoencoder has been trained to squeeze an image down to a smaller representation and then re-create the image from that compressed version.
In this particular case, it downsamples by a factor of 8 in each spatial dimension (512 / 8 = 64), so each 8x8-pixel patch gets compressed down to four numbers (the four channels in the AE output). You can find AEs with a higher compression ratio (e.g. f16, as in some popular VQGAN models), but at some point they begin to introduce artifacts that we don't want.
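The compression figures can be checked with a little arithmetic. This sketch only multiplies out the shapes quoted in the text (a 512x512 RGB input and a 4x64x64 latent); nothing in it touches the model.

```python
# Sizes for Stable Diffusion v1 at its default 512x512 resolution
height, width, rgb_channels = 512, 512, 3
factor = 8            # spatial downsampling factor of this VAE
latent_channels = 4   # channels in the latent, as in torch.Size([1, 4, 64, 64])

latent_h, latent_w = height // factor, width // factor
pixel_values = rgb_channels * height * width
latent_values = latent_channels * latent_h * latent_w

# Each 8x8 RGB patch (8 * 8 * 3 = 192 values) maps to 4 latent numbers
values_per_patch = factor * factor * rgb_channels

print(latent_h, latent_w)                   # 64 64
print(pixel_values, latent_values)          # 786432 16384
print(values_per_patch / latent_channels)   # 48.0
print(pixel_values / latent_values)         # 48.0
```

So the latent holds 48 times fewer numbers than the image, which is a big part of why running diffusion in latent space is so much cheaper than in pixel space.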
The Scheduler
Now we need to talk about adding noise...
During training, we add some noise to an image and then have the model try to predict that noise. If we always added a lot of noise, the model might not have much to work with; if we only added a tiny amount, the model wouldn't learn to do much with the random starting points we use for sampling. So during training, the amount of noise is varied according to some distribution.
During sampling, we want to 'denoise' over a number of steps. How many steps and how much noise we should aim for at each step are going to affect the final result.
The scheduler is in charge of handling all of these details.
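To give a feel for those details, here is a minimal NumPy sketch of the noise schedule implied by the scheduler configured in the setup code (beta_start=0.00085, beta_end=0.012, "scaled_linear", 1000 training timesteps). The conversion to sigmas uses the standard relation sigma = sqrt((1 - alpha_bar) / alpha_bar). Picking 50 evenly spaced timesteps and interpolating their sigmas is an assumption that mirrors the general idea of `set_timesteps`; the real scheduler may differ in details, for example by appending a final sigma of zero.

```python
import numpy as np

# Parameters passed to LMSDiscreteScheduler in the setup code
beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000
num_inference_steps = 50

# "scaled_linear": linear in sqrt(beta), then squared
betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
# alpha_bar_t: fraction of the original signal variance left at training step t
alphas_cumprod = np.cumprod(1.0 - betas)
# Noise-to-signal ratio at each training step
sigmas = np.sqrt((1 - alphas_cumprod) / alphas_cumprod)

# Assumption: evenly spaced timesteps from the noisiest (999) down to 0,
# with the matching sigma for each one found by linear interpolation
timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps)
inference_sigmas = np.interp(timesteps, np.arange(num_train_timesteps), sigmas)

print(f"sampling goes from sigma={inference_sigmas[0]:.2f} down to sigma={inference_sigmas[-1]:.4f}")
```

With these settings the largest sigma comes out at roughly 14.6 and the smallest at about 0.029, so sampling starts from almost pure noise and ends very close to a clean latent. The sampling code scales its initial noise by the scheduler's own first sigma in `latents * scheduler.sigmas[0]`.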
Let's look at the macaw image with some noise.
```python
# View a noised version
noise = torch.randn_like(encoded)  # Random noise
timestep = 150  # i.e. equivalent to that at 150/1000 training steps
encoded_and_noised = scheduler.add_noise(encoded, noise, timestep)
latents_to_pil(encoded_and_noised)[0]  # Display
```
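What does "noise at timestep 150" mean numerically? This minimal NumPy sketch, assuming the scaled_linear schedule configured in the setup code, shows two equivalent ways of writing a noised latent: the DDPM form `sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise`, and the sigma form `x0 + sigma * noise` used by k-diffusion-style schedulers such as LMS. Random arrays stand in for the encoded macaw and the noise, and `add_noise_ddpm` is a hypothetical helper, not the scheduler's own method.

```python
import numpy as np

# scaled_linear schedule with the setup code's parameters
betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

def add_noise_ddpm(x0, noise, t):
    """Hypothetical helper: keep sqrt(alpha_bar_t) of the signal
    and mix in sqrt(1 - alpha_bar_t) of unit-variance noise."""
    abar = alphas_cumprod[t]
    return np.sqrt(abar) * x0 + np.sqrt(1 - abar) * noise

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 64, 64))      # stand-in for an encoded latent
noise = rng.standard_normal(x0.shape)
t = 150

x_t = add_noise_ddpm(x0, noise, t)

# Sigma form: x0 + sigma * noise, with sigma = sqrt((1 - abar) / abar).
# Since sqrt(sigma**2 + 1) == 1 / sqrt(abar), dividing by it recovers the DDPM form.
sigma = np.sqrt((1 - alphas_cumprod[t]) / alphas_cumprod[t])
x_sigma = x0 + sigma * noise
print(np.allclose(x_sigma / np.sqrt(sigma**2 + 1), x_t))  # True
```

This identity is why the sampling loop divides `latent_model_input` by `(sigma**2 + 1) ** 0.5` before calling the UNet: it converts sigma-form latents back to the scale the UNet saw during training.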
The Text Encoder
How do our captions get turned into something the model can use to produce images?
We rely on a text encoder, usually a transformer model. In this case, SD uses one of OpenAI's CLIP text encoders. Since CLIP was trained to find relationships between text and images, the hope is that this model has learned to map text to very rich representations which we can use.
Let's walk through the process of moving from text to a set of embeddings we can feed to our diffusion model:
Let's start with a prompt:
```python
prompt = 'Cave painting of a bird, flooble'
```
The first step is to convert this into a sequence of tokens:
```python
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_input['input_ids'][0]  # View the tokens
# output
tensor([49406, 9654, 3086, 539, 320, 3329, 267, 4062, 1059, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407])
```
There are some special tokens in here, as well as tokens representing our input. The sequence starts with 49406, the start-of-text token, and most of the rest is 49407, the end-of-text token, which this tokenizer also uses for padding. Many words get their own token, but unknown words like 'flooble' get split into smaller pieces. We can decode one token at a time to see how the tokenizer has interpreted our input:
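The padding behaviour in that output can be reproduced by hand. This pure-Python sketch uses a hypothetical helper, `pad_clip_tokens`, which takes the word-piece ids printed above (it does not perform the BPE splitting itself), wraps them in the start (49406) and end (49407) tokens, truncates so both special tokens fit into 77 positions, and pads with 49407, matching what the output shows.

```python
BOS, EOS = 49406, 49407   # CLIP's start-of-text and end-of-text ids
MAX_LEN = 77              # tokenizer.model_max_length for this CLIP model

def pad_clip_tokens(token_ids, max_length=MAX_LEN, pad_id=EOS):
    """Hypothetical helper mimicking padding="max_length", truncation=True:
    wrap the ids in BOS/EOS, truncating so both special tokens fit,
    then pad with pad_id (this tokenizer pads with the EOS id)."""
    body = list(token_ids)[: max_length - 2]
    ids = [BOS] + body + [EOS]
    return ids + [pad_id] * (max_length - len(ids))

# Word-piece ids for 'Cave painting of a bird, flooble', copied from the output above
prompt_ids = [9654, 3086, 539, 320, 3329, 267, 4062, 1059]
padded = pad_clip_tokens(prompt_ids)
print(len(padded), padded[:10])
# 77 [49406, 9654, 3086, 539, 320, 3329, 267, 4062, 1059, 49407]
```

Because every prompt ends up exactly 77 tokens long, the text encoder always returns a 77-row embedding, regardless of how short the prompt is; longer prompts are simply cut off.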
```python
# Decode the first non-special token
tokenizer.decoder.get(9654)
# output
'cave</w>'

# And the last two non-pad tokens:
print(4062, tokenizer.decoder.get(4062))
print(1059, tokenizer.decoder.get(1059))
# output
4062 floo
1059 ble</w>
```
Now that we have a fixed-length input of tokens, the next step is to convert these into embeddings.
When the tokens are first passed to the text model, a set of learned embeddings turns each token into a vector. These are the model's input embeddings. They are then fed through the text model (a transformer), which outputs a new set of embeddings. It is these output embeddings that we're using and talking about, and thanks to the magic of attention and transformers, they hopefully capture a richer set of semantic meanings, incorporating the context of the prompt as a whole better than a discrete set of word embeddings would.
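The difference between input and output embeddings can be shown with a toy NumPy example. Everything here is an assumption for illustration: the embedding table is random, the vocabulary has only ten made-up token ids, and a single parameter-free self-attention step stands in for CLIP's transformer (the real model also adds position embeddings and uses a causal attention mask). The point is only that a table lookup gives a token the same vector in every prompt, while the contextual output does not.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 10, 8
token_table = rng.standard_normal((vocab_size, dim))  # stand-in for learned input embeddings

def input_embeddings(ids):
    # A plain table lookup: a token always gets the same vector, whatever its context
    return token_table[ids]

def self_attention(x):
    # One parameter-free self-attention mix standing in for the transformer layers:
    # every output vector is a weighted average over the whole sequence
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

prompt_a = [1, 2, 3]   # toy ids; token 3 appears in both prompts
prompt_b = [4, 5, 3]

in_a, in_b = input_embeddings(prompt_a), input_embeddings(prompt_b)
out_a, out_b = self_attention(in_a), self_attention(in_b)

print(np.allclose(in_a[2], in_b[2]))    # True: same token, same input embedding
print(np.allclose(out_a[2], out_b[2]))  # False: output depends on the rest of the prompt
```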
This is what the final embeddings look like:
```python
# Grab the embeddings
embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
print('Shape:', embeddings.shape)
embeddings
# output
Shape: torch.Size([1, 77, 768])
tensor([[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
         [ 0.1967, 0.3905, -1.1450, ..., -0.9653, 2.0285, 0.2323],
         [-0.1163, 0.4147, -1.2173, ..., -0.6865, -0.3320, 1.5457],
         ...,
         [-0.3868, -0.0672, -0.0636, ..., -1.5762, 0.1958, -0.1945],
         [-0.3894, -0.0515, -0.0919, ..., -1.5772, 0.2131, -0.1897],
         [-0.3998, -0.0703, 0.0162, ..., -1.5541, 0.2430, -0.1985]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)
```
For every token in the 77-token input, we have a 768-dimensional representation. These capture a lot of meaning about the input string, and the hope is that this gives the diffusion model as much chance as possible to use this info in the denoising task.
To have a bit of fun with this, let's embed two prompts and use a weighted average of the two embeddings as our final embedding, and see what kind of image that produces. We control how much of each prompt we include with a mix_factor.
```python
#@title Generate an image from a mix of two prompts
prompt1 = 'A dog' #@param
prompt2 = 'A cat' #@param
mix_factor = 0.4 #@param
height = 512                        # default height of Stable Diffusion
width = 512                         # default width of Stable Diffusion
num_inference_steps = 50 #@param # Number of denoising steps
guidance_scale = 8                  # Scale for classifier-free guidance
generator = torch.manual_seed(32)   # Seed generator to create the initial latent noise
batch_size = 1

# Prep text
# Embed both prompts
text_input1 = tokenizer([prompt1], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings1 = text_encoder(text_input1.input_ids.to(torch_device))[0]
text_input2 = tokenizer([prompt2], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings2 = text_encoder(text_input2.input_ids.to(torch_device))[0]

# Take a weighted average: mix_factor of prompt1, (1 - mix_factor) of prompt2
text_embeddings = (text_embeddings1*mix_factor + \
                   text_embeddings2*(1-mix_factor))

# And the unconditional (empty-prompt) input as before:
max_length = text_input1.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# Prep Scheduler
scheduler.set_timesteps(num_inference_steps)

# Prep latents
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
latents = latents * scheduler.sigmas[0]  # Scale the initial noise to the first sigma, as the LMS scheduler expects

# Loop
with autocast("cuda"):
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, i, latents)["prev_sample"]

latents_to_pil(latents)[0]
```
The UNET and CFG
Now it's time we looked at the actual diffusion model. This is typically a UNET that takes in the noisy latents (x) and predicts the noise. We use a conditional model that also takes in the timestep (t) and our text embedding as conditioning. Feeding all of these into the model looks like this:
```python
noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
```
We can try it out and see what the output looks like:
```python
# Prep Scheduler
scheduler.set_timesteps(num_inference_steps)

# What is our timestep
t = scheduler.timesteps[0]
sigma = scheduler.sigmas[0]

# A noisy latent
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
latents = latents * scheduler.sigmas[0]

# Text embedding
text_input = tokenizer(['A macaw'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

# Run this through the unet to predict the noise residual
with torch.no_grad():
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]

latents.shape, noise_pred.shape  # We get preds in the same shape as the input
```
Classifier Free Guidance
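By default, the model often doesn't follow the prompt very closely. To make it follow the prompt better, we use a trick called classifier-free guidance (CFG): at each step the UNet predicts the noise twice, once conditioned on the prompt and once on an empty prompt, and the final prediction is pushed further in the direction of the text-conditioned one. There's a good explanation in this video.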
In the code, this comes down to us doing:
```python
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```
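Here is a minimal NumPy sketch of that line, with random arrays standing in for the UNet output on the doubled batch (unconditional first, then text-conditioned, matching the `torch.cat([uncond_embeddings, text_embeddings])` order used in the sampling code). `classifier_free_guidance` is a hypothetical helper name. The checks at the end show what the scale means: 0 returns the unconditional prediction, 1 returns the text-conditioned one, and larger values such as 7.5 extrapolate beyond it.

```python
import numpy as np

def classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    """Start from the unconditional prediction and move guidance_scale times
    along the direction that the text conditioning adds."""
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

rng = np.random.default_rng(0)
# Stand-in for one UNet output on the doubled batch: row 0 unconditional, row 1 text-conditioned
noise_pred = rng.standard_normal((2, 4, 8, 8))
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)   # like noise_pred.chunk(2)

guided = classifier_free_guidance(noise_pred_uncond, noise_pred_text, 7.5)

# guidance_scale = 1 gives back the text-conditioned prediction,
# guidance_scale = 0 gives back the unconditional one
print(np.allclose(classifier_free_guidance(noise_pred_uncond, noise_pred_text, 1.0), noise_pred_text))    # True
print(np.allclose(classifier_free_guidance(noise_pred_uncond, noise_pred_text, 0.0), noise_pred_uncond))  # True
```

Batching both predictions into one forward pass is why the sampling loop concatenates the latents with themselves before calling the UNet and then splits the result with `chunk(2)`.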
Guidance
How can we add some extra control to this generation process?
At each step, we're going to use our model as before to predict the noise component of x. Then we'll use this to produce a predicted output image, and apply some loss function to this image.
The key here:
- We find our predicted output image
- We compute our loss based on this
- We get the gradient of this loss with respect to x (the latents)
- We modify the current x based on this gradient
This 'nudges' the output at each step in a way that reduces our loss. We control the strength by scaling the loss: too high and it takes over, too low and it barely affects the process.
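Those steps can be sketched in NumPy for a single nudge. This is a toy under several assumptions: the "predicted output" is the predicted clean latent `x - sigma * noise_pred` (the sigma form used by the LMS scheduler) rather than a decoded image, the loss is a made-up mean squared error to a constant target, the gradient is written out by hand (in PyTorch you would typically use `torch.autograd.grad`), and the values of `sigma`, `loss_scale` and the `sigma**2` step size are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
latents = rng.standard_normal((4, 8, 8))      # current noisy latents x (toy size)
noise_pred = rng.standard_normal((4, 8, 8))   # stand-in for the UNet's noise prediction
target = np.full((4, 8, 8), 0.5)              # hypothetical target for the toy loss
sigma = 1.5                                   # current noise level (assumed)
loss_scale = 20.0                             # how strongly the loss steers sampling

def predicted_x0(x):
    # Predicted clean latent: remove the predicted noise at this noise level
    return x - sigma * noise_pred

def loss_and_grad(x):
    # Toy loss: scaled mean squared error between the prediction and the target.
    # noise_pred is treated as a constant, so d(x0)/dx is the identity.
    diff = predicted_x0(x) - target
    loss = loss_scale * np.mean(diff**2)
    grad = loss_scale * 2 * diff / diff.size
    return loss, grad

loss_before, grad = loss_and_grad(latents)
# Nudge x against the gradient; scaling the step by sigma**2 is one heuristic choice
latents = latents - grad * sigma**2
loss_after, _ = loss_and_grad(latents)
print(loss_after < loss_before)  # True
```

Raising `loss_scale` makes each nudge larger, so the loss increasingly dominates the sampling; too small a value leaves the latents almost unchanged. In a full sampler this nudge would happen inside the loop, just before `scheduler.step`.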
Conclusion
I'm still learning about diffusion models, but my first impression is that this is a cool model, and there is a lot you can do with it! I hope this article helped you peek under the hood a little, and that you're inspired to explore further and run some experiments of your own. I'm on the same journey, so if you have questions or want to chat, feel free to reach out.