Project 5b: Diffusion Models from Scratch

Safaa Mouline

Go to Part A

Overview

In this part, I first implement a UNet for one-step denoising. Then, I add time conditioning, and finally class conditioning with classifier-free guidance (CFG).

1.1: Implementing the UNet

Implementing a UNet from scratch! I built it using the following UNet architecture and operations.

figure_3.png
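As a rough idea of what these operations look like in PyTorch, here is a minimal sketch of the kind of blocks involved. The exact layer choices follow the figure above; this sketch assumes a standard Conv → BatchNorm → GELU pattern, and the block names are only shorthand.

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv -> BatchNorm -> GELU, preserving spatial size."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)

class DownBlock(nn.Module):
    """Downsample by 2 with a strided conv, then refine with a ConvBlock."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            ConvBlock(out_ch, out_ch),
        )

    def forward(self, x):
        return self.net(x)

class UpBlock(nn.Module):
    """Upsample by 2 with a transposed conv; the input is a skip concatenation."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            ConvBlock(out_ch, out_ch),
        )

    def forward(self, x):
        return self.net(x)
```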

1.2: Using the UNet to Train a Denoiser

We define the loss function to be the L2 loss between the denoised image and the clean image. To create noised images for training, I added Gaussian noise of varying degrees to MNIST images.

figure_3.png
Varying Noise in MNIST Digits
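Concretely, the noising process is just additive Gaussian noise; here is a minimal sketch (the function name is my own shorthand):

```python
import torch
import torch.nn.functional as F

def add_noise(x, sigma):
    """Return a noisy copy z = x + sigma * eps of a clean image batch x, with eps ~ N(0, I)."""
    return x + sigma * torch.randn_like(x)

# The denoiser D is then trained with the L2 loss against the clean image:
#   loss = F.mse_loss(D(z), x)
```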

1.2.1: Training

I trained the one-step denoiser at noise level sigma = 0.5, with a batch size of 256, for 5 epochs, and a learning rate of 1e-4. The training loss and results on the test set are shown below.
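A minimal sketch of the training loop under those settings (Adam as the optimizer is an assumption, and `Denoiser` is a placeholder for the UNet from 1.1):

```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root="./data", train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=256, shuffle=True)

denoiser = Denoiser()  # placeholder for the UNet from 1.1
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-4)

sigma = 0.5
for epoch in range(5):
    for x, _ in train_loader:                    # labels are unused here
        z = x + sigma * torch.randn_like(x)      # noisy input
        loss = F.mse_loss(denoiser(z), x)        # L2 loss against the clean image
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```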

figure_4.png
Training Loss
figure_5.png
Results from test set after 1 epoch
figure_6.png
Results from test set after 5 epochs

1.2.2: Out-of-Distribution Testing

Here is how the one-step denoising UNet performs on an image from the test set (which it was not trained on), across varying noise levels.

figure_7.png
Results with varying noise

2.1: Adding Time Conditioning to UNet

Instead of predicting the clean image all at once (which, as we saw in Part A, does not work that well), we can predict the noise at each timestep. To do this, we change the objective function to the L2 loss between the predicted noise at a given timestep and the actual noise. Architecture-wise, we add a FullyConnectedBlock to inject the timestep into the model.

figure_7.png
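A sketch of how the timestep can be injected (the exact placement follows the figure above; the layer choices and shapes here are assumptions):

```python
import torch.nn as nn

class FullyConnectedBlock(nn.Module):
    """Embed a normalized timestep t in [0, 1] into a feature vector."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, t):
        return self.net(t)

# Inside the UNet's forward pass, the embedded timestep modulates intermediate
# feature maps, e.g. with t of shape (B, 1):
#   t_emb = self.t_block(t)                      # (B, C)
#   feats = feats + t_emb[:, :, None, None]      # broadcast over H and W
```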

2.2: Training the UNet

Here is the algorithm that was used for training:

figure_7.png

For training, I used a batch size of 128 and trained for 20 epochs.
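Under those settings, one training step looks roughly like the following. This is a sketch: the number of diffusion steps T and the beta schedule values are assumptions made for illustration, and the real procedure follows the algorithm figure above.

```python
import torch
import torch.nn.functional as F

T = 300                                          # number of diffusion steps (assumption)
betas = torch.linspace(1e-4, 0.02, T)            # variance schedule (assumption)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

def train_step(unet, optimizer, x0):
    """One time-conditioned step: noise x0 to a random timestep, predict the noise."""
    t = torch.randint(0, T, (x0.shape[0],))      # random timestep per image
    eps = torch.randn_like(x0)                   # the noise the UNet should predict
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps
    loss = F.mse_loss(unet(x_t, t.float()[:, None] / T), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```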

figure_10.png
Time-Conditioned Training Loss

2.3: Sampling from the UNet

Here is the algorithm that was used for sampling:

figure_7.png
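In code, the sampling loop looks roughly like this: a sketch of standard DDPM ancestral sampling, reusing the schedule assumed in the training sketch in 2.2. Exact details follow the algorithm figure above.

```python
import torch

@torch.no_grad()
def sample(unet, num_images=40, img_shape=(1, 28, 28)):
    """Start from pure noise and iteratively denoise with the time-conditioned UNet."""
    alphas = 1.0 - betas
    x = torch.randn(num_images, *img_shape)
    for t in range(T - 1, -1, -1):
        t_norm = torch.full((num_images, 1), t / T)
        eps = unet(x, t_norm)                    # predicted noise at this timestep
        mean = (x - (1 - alphas[t]) / (1 - alpha_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + betas[t].sqrt() * z           # add noise back in, except at the last step
    return x
```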
epoch_5.png
Sampling at Epoch 5
epoch_20.png
Sampling at Epoch 20

2.4: Adding Class-Conditioning to UNet

But wait! What if I wanted to guide the model to generate specific digits (similar to prompt embeddings in Part A)? To do this, we one-hot encode a class-conditioning vector c to represent the 10 classes (digits). We apply a dropout of 10% to the conditioning, so the model still learns to generate without it. I trained for 20 epochs. Here is the algorithm for training:

figure_7.png
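A sketch of the conditioning setup (the helper name is my own; dropping c to the zero vector is what lets the same UNet also run unconditionally, which CFG needs later):

```python
import torch
import torch.nn.functional as F

def make_class_vector(labels, p_uncond=0.1):
    """One-hot encode digit labels, zeroing out the vector for ~10% of the batch."""
    c = F.one_hot(labels, num_classes=10).float()            # (B, 10)
    keep = (torch.rand(labels.shape[0], 1) > p_uncond).float()
    return c * keep                                          # all-zero row = unconditional

# Training is otherwise the same as in 2.2, except the UNet also receives c:
#   loss = F.mse_loss(unet(x_t, t_norm, c), eps)
```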
figure_11.png
Class-Conditioned UNet Training Loss

2.5: Sampling from the Class-Conditioned UNet

Here, we implement CFG, combining an unconditional and a conditional pass through the model, with gamma set to 5.0. Here is the algorithm for sampling, followed by samples at epochs 5 and 20:

figure_7.png
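A sketch of CFG sampling, reusing the schedule and the class-conditioned UNet signature assumed above:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_cfg(unet, labels, gamma=5.0, img_shape=(1, 28, 28)):
    """Sample digits of the given labels with classifier-free guidance."""
    alphas = 1.0 - betas
    n = labels.shape[0]
    c = F.one_hot(labels, num_classes=10).float()
    x = torch.randn(n, *img_shape)
    for t in range(T - 1, -1, -1):
        t_norm = torch.full((n, 1), t / T)
        eps_uncond = unet(x, t_norm, torch.zeros_like(c))    # unconditional pass (c = 0)
        eps_cond = unet(x, t_norm, c)                        # conditional pass
        eps = eps_uncond + gamma * (eps_cond - eps_uncond)   # CFG extrapolation
        mean = (x - (1 - alphas[t]) / (1 - alpha_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + betas[t].sqrt() * z
    return x
```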
class_epoch_5.png
Class-Conditioned Sampling at Epoch 5
class_epoch_20.png
Class-Conditioned Sampling at Epoch 20