Understanding Diffusion Models Part 1
The dog picture above isn’t real. This is an AI-generated dog based on 6 images of my favorite dog, Frida, which is my roommate’s dog. Here are some actual pictures of Frida so you can compare them.
The current AI isn’t perfect. For example, the pattern on the nose doesn’t exist and some parts of the image are slightly off. But I am pretty happy with the result; not to mention, I was able to get a puppy version of Frida! So I thought I’d write down everything I’ve learned so far.
If you just want to do something like this too, check out this article! In it, I just cover the implementation/setup so you can generate any picture you want!
I’ll try to write this from a beginner’s perspective, so no prior experience is required! But at the same time, I’ll try to go as in-depth as possible so that it’ll be interesting to the ML experts out there too.
What this article will cover
This article will go over a brief history of how AI generates images as well as the math derivation of how the current state-of-the-art diffusion models work. There is some pre-requisite info in this article but if you just want to do something like this too, check out this article! And if you are more curious about recent developments, check out part 2!
How does AI generate images?
Generating images with AI has quite a history. The first boom in AI image generation came from models called Generative Adversarial Networks (GANs). If you are already familiar with GANs, feel free to skip to “What are diffusion models?”
What are GANs?
GANs have 2 parts to them. One part is called the Generator. The idea of the Generator is that you start with random noise. Then you use networks called convolutional neural networks to scale this noise up to an image size, such as width 256, height 256, and 3 channels.
What are convolutional neural networks?
Convolutional neural networks are how we make neural networks see images. This blog gives my favorite graphic/explanation of convolutional networks, but for a TLDR, a convolution slides a kernel over the image. A kernel is a window where you see how similar the part of the image under it is to the kernel. The idea is that if we want to find, say, an apple in the image, we can get a kernel that resembles a part of an apple, and if we apply the kernel over the image, we have found our apple. This is done by taking the dot product of the kernel and the part of the image under it. The basic idea is given by the graphic below
Now, to scale up the image, we just do it the other way around: given the squashed image, we try to predict the big image, like below
This is a very rough overview but the main idea is there.
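To make this concrete, here is a minimal PyTorch sketch (my own toy example, not taken from any particular model) showing a strided convolution squashing a 256×256 image and a transposed convolution scaling it back up:

```python
import torch
import torch.nn as nn

# A strided convolution halves the spatial size; a transposed convolution doubles it back.
x = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)

down = nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1)         # 256 -> 128
up = nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1)  # 128 -> 256

h = down(x)
print(h.shape)       # torch.Size([1, 16, 128, 128])
print(up(h).shape)   # torch.Size([1, 3, 256, 256])
```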
Back to GANs
Now we know how the Generator scales noise up into an image. But the problem now is: how do we know the image generated by the Generator is a good one? For this, we use a Discriminator. A Discriminator acts as a critic of the Generator’s images: it tries its best to tell the generated images apart from the real images. If it can’t guess which images are the Generator’s and which are real, the Generator keeps doing what it’s doing, while the Discriminator needs to get smarter.
The general idea is that the Generator and Discriminator fight against each other to produce good-looking images. Historically, this approach was pretty good. For example, people were even able to customize it to generate fake people by combining faces like below.
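If you like reading code, here is a rough sketch of that two-player setup (a toy example with made-up shapes, not a real GAN architecture): the Discriminator is trained to label real data 1 and generated data 0, and the Generator is trained to make the Discriminator output 1 on its samples.

```python
import torch
import torch.nn as nn

# Toy Generator/Discriminator just to show the adversarial objective;
# real models would use stacks of (transposed) convolutions instead.
G = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))  # noise -> fake "image" (2 numbers here)
D = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))   # "image" -> real/fake logit

bce = nn.BCEWithLogitsLoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)

real = torch.randn(32, 2) + 3.0   # stand-in for a batch of real data
noise = torch.randn(32, 16)

# Discriminator step: push real -> 1, generated -> 0.
fake = G(noise).detach()
loss_D = bce(D(real), torch.ones(32, 1)) + bce(D(fake), torch.zeros(32, 1))
opt_D.zero_grad(); loss_D.backward(); opt_D.step()

# Generator step: try to fool the Discriminator into predicting 1 for generated samples.
loss_G = bce(D(G(noise)), torch.ones(32, 1))
opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```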
Problem with GANs
The main issue with GANs is, funnily enough, one of their best features: the adversarial setup. There are 2 competing networks, which means that if one does too good a job, the other one just fails. For example, if the Generator outsmarts the Discriminator, we can get stuck with the Generator making subpar images while the Discriminator has no idea how to improve, and the same goes in reverse. Overall, with GANs, I typically found it hard to get results as good as in the papers, because if the parameters are not set right, training can easily diverge; GANs are very unstable.
What are Diffusion Models?
Diffusion models mainly became popular in June of 2021 with OpenAI’s paper “Diffusion Models Beat GANs on Image Synthesis”. This wasn’t the first paper on diffusion models, but it was the first one to show that diffusion models surpass GANs. The results look pretty nice.
However, I found this paper pretty hard to understand diffusion models from. The first paper that a lot of the diffusion model code implementations are based on is “Denoising Diffusion Probabilistic Models“.
Now even with this paper, I wasn’t sure what was going on but I think I finally got an idea after reading this blog post. Let me know if there are any mistakes!
Diffusion models are models where you gradually add noise to an image over some number of steps until it becomes pure noise, and then learn to reverse that process, step by step, so you can get an image out of noise. The figure below kind of shows this
This is a bit math-y but let’s go step by step. Diffusion models have 2 processes.
One process is called the forward process which is just a fancy way of saying adding noise. This is shown as q(xₜ ∣ xₜ₋₁). This can be read as predicting xₜ, the image after noise is added to xₜ₋₁ given that we have xₜ₋₁. We represent this as
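$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\right)$$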
This means we are getting xₜ, the noisy image, from a Gaussian with a mean of the square root of 1−βₜ times xₜ₋₁ and a variance of βₜ. “I” is the identity matrix; it just means we want this variance across all dimensions. A more intuitive way to write this is
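$$x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon$$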
where ϵ is just a sample from the normal distribution with mean 0 and variance 1. We’ll get back to this in a little bit.
We are working with a normal distribution here. If you are not familiar with it, it’s the distribution that the entire world seems to be based on for some reason, so be sure to check it out!
T is the total number of steps; usually, this is 1000 or so. If we apply T forward steps to an image, we get complete noise. On the other hand, if we take T reverse steps from random noise, we get a plausible image.
For βₜ we typically have a schedule that increases from a small value. In my model, it goes from 0.00085 at t=0 to 0.012 at t=T over 1000 steps or so. This means that at the beginning, around x₀ and x₁, we add very small amounts of noise, but as we get closer to pure noise at the end, we ramp up the noise added.
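As a minimal sketch (the exact spacing rule differs between implementations; I’m only using the endpoints mentioned above), a schedule like this can be as simple as a linspace:

```python
import torch

T = 1000
betas = torch.linspace(0.00085, 0.012, T)   # beta_1 ... beta_T: tiny noise early, more noise late
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)   # cumulative products; these show up later in the article

print(betas[0].item(), betas[-1].item())    # 0.00085, 0.012
print(alpha_bars[-1].item())                # close to 0, so x_T is essentially pure noise
```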
The other process is called the reverse process, and this is where the magic happens. Here we try to get rid of the noise given the noisy image. This is shown as p(xₜ₋₁ ∣ xₜ): we are predicting xₜ₋₁, the image before the noise was added, given xₜ (the noisy image). Now, for this process, we do not know the mean and variance at the moment, so we’ll just write it as
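$$p(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)$$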
Our goal is to learn this mean and variance so that our model can correctly denoise. In English, we are trying to predict the denoised xₜ₋₁ from xₜ using a learned mean and a learned variance.
Joint probabilities
This is for the later in-depth math sections. Feel free to skip if you just want the general idea!
Another thing to note is that we can also have a joint probability, which for q is the probability that xₜ, xₜ₋₁, …, x₁ all happened given that we have x₀.
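$$q(x_{1:T} \mid x_0) = q(x_1, \dots, x_T \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$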
We can say a similar thing for p for the probability of x₀ happening all the way to xₜ
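$$p(x_{0:T}) = p(x_0, \dots, x_T) = p(x_T) \prod_{t=1}^{T} p(x_{t-1} \mid x_t)$$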
The middle bar means “given”. So, for the p example, the idea is that we start with the probability of the pure-noise image, and then we keep multiplying in the probability of getting the slightly less noisy image given the current noisy one. In the end, we have the probability of getting a plausible image x₀ given that we started from pure noise.
How to train a Diffusion Model
For new readers, I tried making the math as accessible as possible, but what’s below is technical, so feel free to skip to the implementation/setup. Or, if you are curious about what I did for this project and recent developments, check out this next article. What’s below is not essential for the next article, and it’s also the hardest part of all of it.
Now, how do we train a model to do the reverse process? We do have the option of predicting the mean and variance directly from xₜ and t, but one problem remains: how exactly will you know you have the correct mean and correct variance at a certain timestep?
The authors came up with 2 main tricks to make this easier. The first is a strategy to add t steps of noise to the original image x₀ in one go. To start, we can write
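$$x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_{t-1}, \qquad \epsilon_{t-1} \sim \mathcal{N}(0, I)$$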
Let us write αₜ=1-βₜ then we can say
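$$x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon_{t-1}$$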
Now if we substitute what xₜ₋₁ is in terms of xₜ₋₂ we get
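$$x_t = \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\, \epsilon_{t-1} = \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\, \epsilon_{t-2} + \sqrt{1-\alpha_t}\, \epsilon_{t-1}$$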
Now, to combine the two Gaussians, there’s a neat trick I learned from this blog. If we merge two Gaussians, the variance is just the sum of the variances. This means we can write
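$$\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1 - \alpha_t \alpha_{t-1}$$

$$x_t = \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\, \bar{\epsilon}, \qquad \bar{\epsilon} \sim \mathcal{N}(0, I)$$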
which is pretty clean. This means we can hop from xₜ₋₂ straight to xₜ with a single ϵ, without sampling the in-between steps individually! Now, we can keep doing this until we get to x₀, and then we’ll have something like
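$$x_t = \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1}\, x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \cdots \alpha_1}\, \epsilon$$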
By introducing ᾱₜ, α with a bar above, to represent the product of the αs up to that point, this becomes
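$$\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i, \qquad x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$$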
Now, we have a way to get to the noisy version of the input image t timesteps in.
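In code, this one-shot forward process is just a couple of lines. Here is a minimal sketch (the function name q_sample and the schedule endpoints are my own choices):

```python
import torch

T = 1000
betas = torch.linspace(0.00085, 0.012, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)   # the alpha-bar products from above

def q_sample(x0: torch.Tensor, t: int, eps: torch.Tensor) -> torch.Tensor:
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps"""
    ab = alpha_bars[t]
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps

x0 = torch.randn(1, 3, 64, 64)      # stand-in for a training image
eps = torch.randn_like(x0)          # the noise the model will later learn to predict
x_t = q_sample(x0, t=500, eps=eps)  # the image noised 500 steps in, in one shot
```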
The next idea the authors had was that we can use this to predict xₜ₋₁ from xₜ and x₀.
The idea for how this is possible is based on Bayes’ Theorem, which is below. I don’t think I can do this theorem justice, so check out this vid by 3B1B! The main takeaway is that we can get the probability of an event A happening given that another event B already happened by knowing the other probabilities.
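$$P(A \mid B) = \frac{P(B \mid A)\, P(A)}{P(B)}$$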
In our case, this means that we can calculate q(xₜ₋₁ | xₜ, x₀) without the need to use the reverse process. I’ll avoid all the derivation math as I’m still digesting it but the main idea is given as
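$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right)$$

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t, \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\, \beta_t$$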
Honestly, I was quite impressed with this. I wouldn’t have thought of applying Bayes’ Theorem to things like normal distributions, and all the quantities here are pretty easily calculated from what we have found so far.
Now, let us try getting a better mean by getting rid of x₀. From the closed form for xₜ we found above, we can write x₀ as
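$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon\right)$$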
Substituting this in, we have a new mean,
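$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon\right)$$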
Inference
Now, what this means is that we can construct a new image by, at every timestep, having a model predict ϵ and doing the above process. For example, say we have random noise, which is x at timestep T. We can now get x at timestep T−1 by finding a mean using the above formula and then adding random noise scaled by the variance at that step. This means that, theoretically, the model shifts the mean to wherever we want, and then the random noise at a small variance selects slight variations within that region. So when the model is trying to generate a dog picture, we get a lot of variation because of the random noise, but we still get the right picture because the mean is shifted by the model towards a dog picture.
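Here is a minimal sketch of that sampling loop, assuming we already have a trained model(x, t) that predicts ϵ (the stand-in below just returns zeros), and using βₜ as the constant variance at each step, which is one common choice:

```python
import torch

T = 1000
betas = torch.linspace(0.00085, 0.012, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def model(x, t):                  # placeholder epsilon-predictor; in practice a trained U-Net
    return torch.zeros_like(x)

x = torch.randn(1, 3, 64, 64)     # start from pure noise at timestep T
for t in reversed(range(T)):
    eps_pred = model(x, t)
    # Shift the mean using the predicted epsilon (the formula derived above).
    mean = (x - (1 - alphas[t]) / (1 - alpha_bars[t]).sqrt() * eps_pred) / alphas[t].sqrt()
    if t > 0:
        x = mean + betas[t].sqrt() * torch.randn_like(x)   # add a little fresh noise back in
    else:
        x = mean                                           # last step: no extra noise
```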
Note: there are also ideas where you learn the correct amount of variance to add at each timestep, but in practice, I tend to just see constant linear β schedules. Learnable βs do exist, though.
Training
Now we know that we want a model to accurately predict the ϵ to shift the mean by. The next question is: how do we do this?
Evidence Lower Bound
If you are familiar with the idea of evidence lower bound, feel free to skip this section!
Let us start with the Evidence Lower Bound. The best explanation that I have found so far for this is this video here. The Evidence Lower Bound is based on Bayes’ Theorem yet again. In our situation, we want to predict
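$$p(x_0 \mid x_{1:T}) = \frac{p(x_{1:T} \mid x_0)\, p(x_0)}{p(x_{1:T})}$$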
In English, we want to predict the original image given all the noisy images we observed. To do this, we can use Bayes’ theorem: the probability of all the noisy images given x₀, times a prior, which is the distribution of x₀ images (these are the training data), all divided by a thing called the evidence, which is the distribution of noisy images.
Now, many of these values are hard to find or even think about. What form does the prior have? What even is the evidence? It turns out that, in this form, all of this is intractable to compute. So we move towards a concept called KL divergence. The name sounds tricky, but it is just a measure of the distance between distributions.
In our case, we do not know p’s distribution over all the noisy images given x₀, but we can approximate it using q. Thus, we want to get the distance between q and p here so that we can approximate one part of the Bayes equation. This can be written as
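$$D_{KL}\!\left(q(x_{1:T} \mid x_0) \,\|\, p(x_{1:T} \mid x_0)\right)$$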
which can be computed as
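$$D_{KL}\!\left(q(x_{1:T} \mid x_0) \,\|\, p(x_{1:T} \mid x_0)\right) = \mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{1:T} \mid x_0)}\right]$$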
Now, this value is 0 if the distributions are equal; otherwise, it grows as the distributions get further apart. For the reason why it can’t be negative, check this post out.
Now we have a new problem: we can’t compute the denominator. But fear not, we’ll just use Bayes’ theorem again. This will give us
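$$p(x_{1:T} \mid x_0) = \frac{p(x_0 \mid x_{1:T})\, p(x_{1:T})}{p(x_0)}$$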
Now, this can be transformed into
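$$p(x_{1:T} \mid x_0) = \frac{p(x_{0:T})}{p(x_0)}$$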
because p(a ∣ b)p(b) = p(a ∧ b), the numerator is just the joint probability that starts from x₀! Let me know if this is a bit confusing! I wrote the joint probabilities section just for this part.
Now, if we substitute this into the KL Divergence we get
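$$D_{KL}\!\left(q(x_{1:T} \mid x_0) \,\|\, p(x_{1:T} \mid x_0)\right) = \mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{0:T})}\right] + \log p(x_0)$$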
Now, if we rearrange the terms, we get
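$$\log p(x_0) = D_{KL}\!\left(q(x_{1:T} \mid x_0) \,\|\, p(x_{1:T} \mid x_0)\right) - \mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{0:T})}\right]$$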
Here, it’s time for a bit of math magic.
KL divergence is always at least 0. Also, the evidence is a constant quantity because it’s over the distribution of the training data/real-world data. So, if we maximize
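$$-\,\mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{0:T})}\right]$$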
then we are decreasing the KL divergence! What this means is that our model p models the real-world data more and more accurately as this term approaches its maximum. Now, if we drop the KL divergence from the equation above, we get
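$$\log p(x_0) \ge -\,\mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{0:T})}\right]$$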
Now we see why it’s called the evidence lower bound: the whole term on the right is a lower bound on the evidence! Since maximizing this term is the same as minimizing its negative, what we actually want to minimize is
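$$\mathbb{E}_q\!\left[\log \frac{q(x_{1:T} \mid x_0)}{p(x_{0:T})}\right]$$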
So the above is our loss.
How to compute our loss
Now, we have a fancy loss function but we still have no idea how to compute it. But once we expand it using the joint probability, we get something like this
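$$\mathbb{E}_q\!\left[-\log p(x_T) + \sum_{t=2}^{T} \log \frac{q(x_t \mid x_{t-1})}{p(x_{t-1} \mid x_t)} + \log \frac{q(x_1 \mid x_0)}{p(x_0 \mid x_1)}\right]$$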
The first term is just a constant as at timestep T we expect pure noise so it can be ignored. For the last term, I wasn’t fully able to figure it out but let me know if somebody does!
The middle terms however can be rearranged, yet again using Bayes, as
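$$q(x_t \mid x_{t-1}) = q(x_t \mid x_{t-1}, x_0) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}$$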
Now doesn’t the middle term look familiar? Yup! It’s KL divergence. When we do the algebra this whole thing becomes KL divergences between
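$$D_{KL}\!\left(q(x_{t-1} \mid x_t, x_0) \,\|\, p(x_{t-1} \mid x_t)\right)$$

with one such term for each t from 2 to T.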
Now, the question is yet again: how do we compute this? The answer is quite simple. Since both distributions have the same variance (we set it to a constant), we just need the means to match at a given timestep t for the distributions to match!
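To see why, note that for two Gaussians with the same (constant) variance, the KL divergence is just a scaled squared distance between the means:

$$D_{KL}\!\left(\mathcal{N}(\mu_1, \sigma^2 I) \,\|\, \mathcal{N}(\mu_2, \sigma^2 I)\right) = \frac{1}{2\sigma^2}\, \|\mu_1 - \mu_2\|^2$$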
Training steps
Now, when we compute the mean for getting xₜ₋₁ from xₜ and x₀, we are using this mean from the forward process, which we found earlier
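$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon\right)$$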
We derived this from
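$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t$$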
by substituting x₀ using
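$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon\right)$$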
Now, for the mean of the reverse process, we have our model trying to predict ϵ given xₜ and t. So our other mean is
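$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right)$$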
Now, you might notice that if we remove the weighting, the loss is just the difference between the two ϵs, where our model is predicting one of them.
Now here’s our idea for training. We get xₜ by adding t steps of noise ϵ to the base image using the forward process, via
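$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1-\bar{\alpha}_t) I\right)$$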
or
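$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$$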
Then, we can just take the loss as
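$$L = \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon,\ t\right) \right\|^2 = \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2$$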
Appreciating the loss function
I’d like to take a moment now to appreciate this loss function. After all this trouble and all this math, what we got is probably one of the simplest, cleanest, and most useful loss functions in AI. I love it.
Now, this finishes pretty much all the main math we need to understand diffusion models. Training and inference can be summarized as below.
For inference, we need to go through T iterations of the model, but for training, we just need to do it at one random timestep t.
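As a rough sketch of that training recipe (again using a placeholder model; a real one would be a U-Net that also takes t as input):

```python
import torch

T = 1000
betas = torch.linspace(0.00085, 0.012, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)   # placeholder epsilon-predictor (ignores t)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

x0 = torch.randn(8, 3, 64, 64)     # a batch of training images
t = torch.randint(0, T, (8,))      # one random timestep per image
eps = torch.randn_like(x0)

ab = alpha_bars[t].view(-1, 1, 1, 1)
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps    # forward process in one shot

loss = ((eps - model(x_t)) ** 2).mean()         # the simple epsilon loss from above
opt.zero_grad(); loss.backward(); opt.step()
```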
Overall, some parts are missing and I don’t yet fully understand them. For example, can we say that the ϵ for one step from xₜ₋₁ to xₜ is the same ϵ as the one from x₀ to xₜ? But I think that’s all just one simple proof away. Anyways, thank you to everyone who stayed with me this far. I was most likely jumping around a lot, learning/understanding all this at the same time as all of you. I’d also appreciate any suggestions!
Next
For people who want to get into the implementation/setup to do this yourself, check out this article! If you want to learn more about theory and what I did, check out this article!