The dog picture above isn’t real. This is an AI-generated dog based on 6 images of my favorite dog, Frida, who is my roommate’s dog. In this article, I’ll talk about my journey toward getting a picture like this and how diffusion models, and the industry around them, got better over the course of this project.
If you haven’t already, check out at least the beginning of part 1! There I cover what diffusion models are, why we have them, and how to train them and do inference with them!
The rest of this article is a bit math-y/technical, so if you just want to do something like this yourself, check out this article instead! In it, I cover just the implementation.
What this article will cover
In this article, I will cover conditional/text-based diffusion models, latent diffusion models, approaches to few-shot learning, and textual inversion! I’ll also talk about the results and my experiments!
I’ll also try to add a Colaboratory notebook here so you can all try it out too!
How do we make a diffusion model work with text?
By now, I think we have all seen the amazing art that DALL·E 2 and Midjourney generate from a simple text prompt. For example, the image below was generated in Midjourney with the prompt: “sky full of forest water birds animals”
Now, how do we do this with our diffusion model? If you remember from the previous article, our diffusion model started with random noise and kept denoising from there. Users had no way to specify what kind of image they wanted.
The first paper I found on this topic was OpenAI’s “GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models”. GLIDE was also the first diffusion model I ever worked with.
The main idea for how to add text came from “Diffusion Models Beat GANs on Image Synthesis”. The idea is that we can condition the generated images on a class. So let’s say we want to generate a dog. If we tell our model that we want a dog, wouldn’t it do better at generating one?
Class Conditional Diffusion Models
I think part 1 had enough math in it already, so if you are interested in the math, check out this article (coming soon)! Here I’ll just give a summary.
In these models, we first need a classifier.
Given a noisy image xₜ, the classifier tries to predict a class c by giving us the probability of class c given xₜ.
Now, the reverse process (the process where the model generates an image) is shown below:
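Writing the classifier as p_φ(c | xₜ) (with φ its weights) and the ordinary reverse step as p_θ(xₜ₋₁ | xₜ), the class-conditional reverse process from “Diffusion Models Beat GANs on Image Synthesis” factors into the two, with Z a normalizing constant:

$$
p_{\theta,\phi}(x_{t-1} \mid x_t, c) = Z \, p_\theta(x_{t-1} \mid x_t) \, p_\phi(c \mid x_{t-1})
$$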
One thing you might notice is the new symbol c. This is our new reverse process: the probability of getting the denoised image xₜ₋₁ given the noisy image xₜ and given that we know the class c.
As it turns out, this is still (approximately) a normal distribution. However, the mean is shifted like so:
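In symbols, writing μ = μ_θ(xₜ, t) and Σ = Σ_θ(xₜ, t) for the mean and covariance of the ordinary reverse step, and s for a guidance scale (it comes from raising the classifier’s probability to the power s), the guided step is approximately:

$$
p_{\theta,\phi}(x_{t-1} \mid x_t, c) \approx \mathcal{N}\!\left(x_{t-1};\; \mu + s\,\Sigma\,\nabla_{x_{t-1}} \log p_\phi(c \mid x_{t-1})\big|_{x_{t-1}=\mu},\; \Sigma\right)
$$

Here the gradient is taken with respect to xₜ₋₁ and evaluated at the unguided mean μ; the paper’s sampling algorithm evaluates it at the current noisy image xₜ.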
This looks a bit scary, but there is an intuitive way of understanding it. The new mean is just the previous mean, shifted by the variance times a gradient: the gradient of the log-probability that a classifier assigns to class c for the noisy image xₜ₋₁.
s here is how hard we want to push toward this class. The higher the value of s, the more the images will look like the class, but the less diverse they will be.
Now, what does this all mean? First of all, a gradient (the upside-down triangle symbol, ∇) is a rate of change. In our case, it’s the rate of change of the class’s log-probability with respect to our noisy image xₜ₋₁. It shows how fast the probability the classifier predicts for c changes as we change the values of xₜ₋₁.
The gradient of the classifier will be high if small changes in the noisy image xₜ₋₁ drastically change the probability that our classifier assigns to c. So by adding that gradient to the mean, we are doing gradient ascent, pushing xₜ₋₁ toward values where the probability of predicting c is higher.
On the other hand, a very low gradient means the classifier is already pretty happy with its choice of class: even if xₜ₋₁ is nudged slightly away from its current value, it’ll still predict class c with a similar probability!
You can think of Σ as making sure we push by the right amount. For example, if our images are spread out with very high variance and we push by, say, 0.1 in the gradient’s direction, we won’t really go anywhere relative to that spread, since we ignored how variable the data is. Scaling the push by Σ accounts for that.
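To make the mean shift concrete, here is a minimal NumPy sketch. It assumes a diagonal Σ and swaps the real noise-aware image classifier for a toy logistic classifier p(c | x) = sigmoid(w · x + b) on 2-D points, so the gradient of log p(c | x) has the closed form (1 − p) w. The function names and numbers are made up for illustration.

```python
import numpy as np

def classifier_prob(x, w, b):
    """Toy logistic classifier: p(c | x) = sigmoid(w . x + b)."""
    return 1.0 / (1.0 + np.exp(-(np.dot(w, x) + b)))

def log_prob_grad(x, w, b):
    """Gradient of log p(c | x) with respect to x.
    d/dx log sigmoid(w . x + b) = (1 - p) * w."""
    p = classifier_prob(x, w, b)
    return (1.0 - p) * w

def guided_mean(mu, sigma_diag, grad, s):
    """Classifier-guided mean: mu + s * Sigma * grad, with Sigma diagonal."""
    return mu + s * sigma_diag * grad

# Unguided reverse-step mean and (diagonal) variance for a 2-D toy "image".
mu = np.array([0.0, 0.0])
sigma_diag = np.array([0.5, 0.5])
w, b = np.array([1.0, -2.0]), 0.0

grad = log_prob_grad(mu, w, b)          # gradient evaluated at the unguided mean
new_mu = guided_mean(mu, sigma_diag, grad, s=2.0)
print(new_mu)
print(classifier_prob(mu, w, b), classifier_prob(new_mu, w, b))
```

The classifier’s probability for c goes up at the shifted mean, and doubling `sigma_diag` doubles the shift for the same gradient, which is the “push by the right amount” role of Σ.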
Removing the classifier
Now one problem with the previous approach is that we need a classifier. So, why don’t we remove that?
The idea is that we can use Bayes’ theorem to say:
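Bayes’ theorem gives p(c | xₜ) = p(xₜ | c) p(c) / p(xₜ). Taking the log and then the gradient with respect to the noisy image xₜ, the p(c) term drops out because it does not depend on xₜ:

$$
\nabla_{x_t} \log p(c \mid x_t) = \nabla_{x_t} \log p(x_t \mid c) - \nabla_{x_t} \log p(x_t)
$$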
So basically, we can get rid of the classifier if we can compute the right-hand side! Looking at the right-hand side, it contains the gradient of the log-probability of the noisy image given c and the gradient of the log-probability of the noisy image on its own. A diffusion model that predicts noise is, up to a scale factor, already estimating exactly these gradients, so we can approximate them with ϵs! In particular:
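Assuming the usual DDPM notation, where ᾱₜ is the cumulative product of (1 − βₜ) over the noise schedule and ε_θ(xₜ, t) is the noise the model predicts (optionally also given c), the two gradients are approximated as below, and substituting them into Bayes’ rule gives the classifier gradient without a classifier:

$$
\nabla_{x_t} \log p(x_t \mid c) \approx -\frac{\epsilon_\theta(x_t, t, c)}{\sqrt{1-\bar\alpha_t}}, \qquad
\nabla_{x_t} \log p(x_t) \approx -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1-\bar\alpha_t}}
$$

$$
\Rightarrow\quad \nabla_{x_t} \log p(c \mid x_t) \approx -\frac{1}{\sqrt{1-\bar\alpha_t}}\big(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t)\big)
$$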
Now, if we introduce a ∅ label that can mean any class (that is, no conditioning at all), we can do classifier-free guidance pretty simply like so!
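In symbols: treat ε_θ(xₜ, t, ∅) as the unconditional noise prediction, with ᾱₜ the usual DDPM cumulative noise-schedule product. Classifier guidance written in terms of predicted noise is ε̂ = ε_θ(xₜ, t, ∅) − s√(1 − ᾱₜ) ∇ log p(c | xₜ). Replacing the classifier gradient with its noise-based estimate −(ε_θ(xₜ, t, c) − ε_θ(xₜ, t, ∅)) / √(1 − ᾱₜ), the √(1 − ᾱₜ) factors cancel and we get the classifier-free guidance rule from the GLIDE paper:

$$
\hat\epsilon_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \varnothing) + s\,\big(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\big)
$$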
This moves away from shifting the mean, as we did with classifier guidance, and instead adjusts the predicted noise, but it’s still pretty intuitive!
s here is a guidance scale like before. If s is 0, we get no guidance at all: since we are only conditioned on ∅, we get unconditional images. With s = 1, we get the plain conditional prediction, and with a higher value of s, we push further away from the unconditional images toward the conditional ones.
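A minimal NumPy sketch of the classifier-free guidance rule. The arrays `eps_uncond` and `eps_cond` stand in for the model’s noise predictions with the ∅ label and with the class/text c (random numbers here, not a real model), and `alpha_bar_t` and `s` are arbitrary values chosen for illustration. The second function recovers the classifier gradient that the two predictions imply, and the final check confirms that classifier-free guidance is just classifier guidance with that implied gradient.

```python
import numpy as np

def classifier_free_guidance(eps_uncond, eps_cond, s):
    """eps_hat = eps(x_t, t, ∅) + s * (eps(x_t, t, c) - eps(x_t, t, ∅))."""
    return eps_uncond + s * (eps_cond - eps_uncond)

def implied_classifier_grad(eps_uncond, eps_cond, alpha_bar_t):
    """Noise-based estimate of grad log p(c | x_t):
    -(eps(x_t, t, c) - eps(x_t, t, ∅)) / sqrt(1 - alpha_bar_t)."""
    return -(eps_cond - eps_uncond) / np.sqrt(1.0 - alpha_bar_t)

# Random stand-ins for the two noise predictions of a real model.
rng = np.random.default_rng(0)
eps_uncond = rng.normal(size=(4, 4))
eps_cond = rng.normal(size=(4, 4))
alpha_bar_t = 0.3   # arbitrary noise-schedule value

s = 7.5
eps_hat = classifier_free_guidance(eps_uncond, eps_cond, s)

# The same thing written as classifier guidance with the implied gradient.
grad = implied_classifier_grad(eps_uncond, eps_cond, alpha_bar_t)
eps_hat_via_grad = eps_uncond - s * np.sqrt(1.0 - alpha_bar_t) * grad
print(np.allclose(eps_hat, eps_hat_via_grad))  # True
```

Setting `s = 0` returns `eps_uncond` and `s = 1` returns `eps_cond`, matching the description of s.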
Making the model recognize text
We finally got here! The text portion is pretty simple too! First, we work with models called Transformers.
If you already know what Transformers are, skip ahead!
Transformers in AI are not large robots from space that want to save the earth. They are AI models that can encode text. This is a nice blog post about them here. If you want to see the code too, I wrote a series on the transformer model GPT-2 here!
But let’s just have a TL;DR. Transformers encode text by looking over the text you give them and deciding which words to pay attention to. They then focus on those words and output a new sequence of vectors, one per input token, so the output has the same length as the input.
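For the curious, here is a minimal NumPy sketch of single-head scaled dot-product self-attention, the core operation behind “deciding which words to pay attention to”. The random weight matrices stand in for learned ones; real Transformers add multiple heads, feed-forward layers, residual connections, and normalization on top of this.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention. x: (seq_len, d_model)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # (seq_len, seq_len): token-to-token relevance
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
seq_len, d_model = 5, 8
x = rng.normal(size=(seq_len, d_model))       # 5 "token" vectors
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape)  # (5, 8)
```

Each output row is a weighted mix of all the input tokens, and there is one output vector per input token.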
Now, what is interesting about these models is that they can be used for almost any data type, which has led to single transformer models that can handle many different tasks. If interested, check out this blog post!
Transformers in Diffusion Models
Since transformers are so good at their job, they have been used to process images too! In our case, the networks that predict the noise given xₜ and t (the U-Nets in GLIDE and Stable Diffusion) have Transformer-style attention layers in them.
Historically, Convolutional Neural Networks (CNNs) were used, but the main problem was that since they work with a sliding window, they can’t easily connect one side of an image to the other. Transformers, on the other hand, can pay attention to whichever parts of the image they want, which has made them the state of the art in many applications!
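A quick back-of-the-envelope calculation of the sliding-window problem. It assumes stride-1 k×k convolutions with no pooling (a simplification; real U-Nets downsample, which widens the window faster), in which case each layer only grows the receptive field by k − 1 pixels. The helper names are made up for this sketch.

```python
def receptive_field(num_layers, kernel_size=3):
    """Width in pixels seen by one output pixel after `num_layers` stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)

def layers_to_span(width, kernel_size=3):
    """Fewest stride-1 conv layers needed before one output pixel can see `width` pixels."""
    layers = 0
    while receptive_field(layers, kernel_size) < width:
        layers += 1
    return layers

print(receptive_field(1))   # 3
print(layers_to_span(64))   # 32
```

So with 3×3 kernels, 32 layers are needed before one pixel of a 64-pixel-wide image can “see” the other edge, whereas a single self-attention layer lets every position attend to every other position directly.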
Using Transformers to recognize text
So now, let us go back to using text to change the model output in diffusion models. First, we tokenize the text. Tokenization means putting the text into the form the transformer understands (a sequence of integer token ids). During training, with some set probability, we let no text through at all. This trains the unconditional model (the ∅ case from before) at the same time!
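Dropping the text during training can be as simple as the sketch below. `maybe_drop_caption` is a hypothetical helper, the empty string stands in for the ∅ label, and `drop_prob=0.1` is an arbitrary illustrative value, not one taken from GLIDE or from this project.

```python
import random

def maybe_drop_caption(caption, drop_prob=0.1, rng=random):
    """Return "" (standing in for the unconditional label ∅) with probability
    drop_prob, otherwise return the caption unchanged."""
    return "" if rng.random() < drop_prob else caption

rng = random.Random(0)
captions = [maybe_drop_caption("a photo of a dog", 0.1, rng) for _ in range(10_000)]
drop_rate = captions.count("") / len(captions)
print(drop_rate)  # roughly 0.1
```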
Then we create a new transformer model that takes those tokens as input and outputs a sequence of embeddings, one per token. Given this transformer’s output, we do two things:
- We take the embedding of the last token, pass it through a linear layer, and use that in place of the class embedding, so the model treats it like a class label!
- We also feed these output embeddings into the diffusion model itself. The Transformer layers in the architecture then see the processed text in addition to the current xₜ, so they can pay attention to both the image and the text at the same time. This idea is called cross-attention because we are paying attention across two different types of data. Let me know if this is confusing and I’ll try expanding on it!
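A minimal NumPy sketch of cross-attention in the style Stable Diffusion’s U-Net uses: the queries come from the image features, and the keys and values come from the text encoder’s output, so each image position gets a mix of text information. (GLIDE instead concatenates the projected text tokens onto the attention context of each layer, but the idea of the image attending to the text is the same.) The shapes and random weights are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(image_feats, text_feats, w_q, w_k, w_v):
    """image_feats: (n_pixels, d_img), text_feats: (n_tokens, d_text)."""
    q = image_feats @ w_q                      # queries from the image
    k = text_feats @ w_k                       # keys from the text
    v = text_feats @ w_v                       # values from the text
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (n_pixels, n_tokens)
    weights = softmax(scores, axis=-1)         # how much each pixel attends to each token
    return weights @ v, weights

rng = np.random.default_rng(0)
image_feats = rng.normal(size=(16, 32))   # e.g. a 4x4 feature map with 32 channels, flattened
text_feats = rng.normal(size=(7, 64))     # e.g. 7 token embeddings from a text encoder
w_q = rng.normal(size=(32, 16))
w_k = rng.normal(size=(64, 16))
w_v = rng.normal(size=(64, 32))           # project back to the image channel count
out, weights = cross_attention(image_feats, text_feats, w_q, w_k, w_v)
image_feats_updated = image_feats + out   # residual connection back into the network
```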
Now, let us talk about Textual Inversion, a wonderful piece of research that came out right when I was figuring out how to get pictures of Frida. Up to that point, I had tried:
- Adding prompts, such as “A picture of Frida on her pillow under a bicycle”, to each image and seeing if the model figured it out.
- Attempting to replicate this research paper’s result, where we completely remove the text part and add a set encoder in its place. If you are curious, check out this repo. I still kind of like this idea.
But then, I found out about Textual Inversion, which was introduced in the paper “An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion”.
The idea is as follows: why do we need to fine-tune the entire network just to get a picture of Frida?
Can’t we just learn a word to describe her? The execution is even simpler. We add a new token to describe Frida and train only that token’s embedding, keeping the rest of the network frozen. The network then figures out what Frida is when she appears in a prompt.
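A toy NumPy sketch of the Textual Inversion setup: an embedding table where only the row for the new `<frida>` token may change. The “loss” is a made-up stand-in (the squared distance between the mean prompt embedding and a target vector) rather than the real diffusion loss, and the token ids, learning rate, and gradient-masking approach are assumptions for illustration. Gradients reach every token in the prompt, but the mask keeps every row except `<frida>` frozen.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 6, 4
embeddings = rng.normal(size=(vocab_size, dim))   # stand-in for the token embedding table
original = embeddings.copy()

frida_id = 5                      # hypothetical id of the newly added <frida> token
prompt_ids = [0, 1, 2, frida_id]  # hypothetical ids for "a picture of <frida>"
target = np.ones(dim)             # stand-in for "what the frozen model needs to see"

def toy_loss_and_grad(table):
    """Made-up stand-in for the diffusion loss: squared distance between the
    mean prompt embedding and `target`. Returns (loss, gradient w.r.t. table)."""
    diff = table[prompt_ids].mean(axis=0) - target
    grad = np.zeros_like(table)
    for i in prompt_ids:
        grad[i] += 2.0 * diff / len(prompt_ids)
    return float(diff @ diff), grad

# Only the <frida> row is trainable; every other row stays frozen.
trainable = np.zeros((vocab_size, 1))
trainable[frida_id] = 1.0

lr = 0.5
for _ in range(200):
    loss, grad = toy_loss_and_grad(embeddings)
    embeddings -= lr * grad * trainable

final_loss, _ = toy_loss_and_grad(embeddings)
frozen_rows = [i for i in range(vocab_size) if i != frida_id]
print(final_loss < 1e-6, np.array_equal(embeddings[frozen_rows], original[frozen_rows]))  # True True
```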
This allows for some interesting ideas. Take a look at the image from the paper:
Basically, by learning the concept S*, we can add it to a custom prompt to put Frida in any situation we want! I haven’t done this yet, but I’ll add Frida in a Superman costume here once I do.
Around this time, there were some fancy models around, but the main problem was that they didn’t fit in my GPU’s 6 GB of memory. So I first tested textual inversion with the GLIDE model we’ve been talking about. The code I used is here.
You can see all the runs here, but the results weren’t terrific. I initialized the Frida token with the meaning of “dog”, and the model seemed to have trouble moving anywhere from that. I trained at 64x64.
The results were not too bad, but they were very volatile. These were the best images, and they were quickly followed by
which are pretty unrelated to Frida.
Using Multiple Tokens
One idea I found in the original implementation of textual inversion was using multiple tokens to represent the concept. So instead of trying to generate
A picture of <frida>
I am trying to generate
A picture of <frida>_0 <frida>_1 <frida>_2 … <frida>_9
That is, using 10 tokens to describe Frida.
We could initialize all the Frida tokens to “dog”. Another idea I had was to describe Frida in words, so that the model would have a better starting point.
I tried initializing the tokens from the description “large white and light brownish retriever dog”. The idea here is that <frida>_0 would start as “large”, <frida>_1 as “white”, and so on, and the embeddings (and the relationships between them) get updated as training goes on!
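A small sketch of both ideas: building the placeholder tokens and initializing each one from a word of the description. `toy_vocab` is a hypothetical stand-in for the text encoder’s real embedding table, and cycling back through the words when there are more tokens than words is an assumption for this sketch (the description has 7 words). Real tokenizers may also split a word like “brownish” into several sub-word tokens, which this ignores.

```python
import numpy as np

def make_placeholder_tokens(name, num_tokens):
    """'<frida>', 3 -> ['<frida>_0', '<frida>_1', '<frida>_2']"""
    return [f"{name}_{i}" for i in range(num_tokens)]

def init_from_description(placeholders, description, lookup):
    """Initialise placeholder i with the embedding of word i of the description,
    cycling through the words if there are more placeholders than words."""
    words = description.split()
    return {tok: lookup(words[i % len(words)]).copy() for i, tok in enumerate(placeholders)}

description = "large white and light brownish retriever dog"
rng = np.random.default_rng(0)
toy_vocab = {word: rng.normal(size=8) for word in description.split()}  # hypothetical embeddings

placeholders = make_placeholder_tokens("<frida>", 10)
init = init_from_description(placeholders, description, toy_vocab.__getitem__)
prompt = "A picture of " + " ".join(placeholders)
print(prompt)
```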
This was one of the more promising GLIDE runs; the model improved until it consistently produced images like
However, it still wasn’t good enough to capture all of Frida’s features, so I needed to change models.
Enter Hugging Face and latent diffusion! Have you heard of the new model on the block called Stable Diffusion? Well, it’s a latent diffusion model!
The main idea here is to add a Variational Autoencoder (VAE)! This is the first code I saw on this topic, but basically, VAEs are a way to encode images, audio, and pretty much anything else into a latent representation. What is special about this latent representation is that a VAE is trained to make it smooth and meaningful, so similar inputs tend to land close together. So, if we hypothetically encode a bunch of Frida pictures with a VAE, we expect them to lie in a similar place.
It’s a good idea to combine VAEs with diffusion models because the diffusion model can then work on much smaller inputs, which makes it far cheaper to train and run. We don’t have to make a model generate, say, a 512x512x3 image directly; we can encode that 512x512x3 image into a much smaller latent (Stable Diffusion uses 64x64x4) and do diffusion on that! Once we are done, we decode the 64x64x4 latent we made back to 512x512x3, and we have our image!
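The savings are easy to quantify. This compares the number of values in a 512x512x3 image with a 64x64x4 latent (the shape Stable Diffusion uses); `num_values` is just a helper for the arithmetic.

```python
def num_values(height, width, channels):
    return height * width * channels

pixels = num_values(512, 512, 3)
latent = num_values(64, 64, 4)
print(pixels, latent, pixels // latent)  # 786432 16384 48
```

So the diffusion model handles 48 times fewer numbers per image.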
Attempt 1 of using latent diffusion
You can see my first attempt at using latent diffusion in this code here, on the 6gb branch. As you may have guessed from the branch name, the main problem I had was fitting it into my GPU’s memory, as these models were made for 16 GB GPUs and not my 6 GB one. My main strategy for my first attempt was:
- Moving almost all the models to the CPU
- Lowering the image size to 256x256
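A rough estimate of why lowering the resolution helps so much. It assumes the naive attention implementation that builds the full attention matrix, a VAE that downsamples by 8 in each direction, and self-attention over every latent position (as in Stable Diffusion v1’s highest-resolution U-Net blocks). The helper name is made up for this sketch.

```python
def attention_matrix_entries(image_side, vae_downsample=8):
    """Entries in one head's (positions x positions) self-attention matrix
    when attention runs over every latent position."""
    positions = (image_side // vae_downsample) ** 2
    return positions * positions

for side in (512, 256):
    entries = attention_matrix_entries(side)
    # 4 bytes per entry in fp32, per attention head, per image
    print(side, entries, entries * 4 / 2**20, "MiB")
```

Halving the image side gives 4x fewer latent positions but 16x fewer attention-matrix entries.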
The results were pretty bad at first, like so
They seemed to have a better idea of Frida, but there were these edge distortions. What these edge marks turned out to be was rather interesting: they were, basically, the VAE freaking out because it couldn’t recognize the concept. You may notice that all the edge marks form a block-like pattern. This is pretty much how the VAE sees the image: each latent position corresponds roughly to an 8x8 block of pixels, and the decoder stitches these blocks back together.
I figured the reason the VAE wasn’t understanding the concept was that the learning rate was too high. So I lowered it so the embedding would learn more slowly, and I got pictures like this.
This is pretty good, especially compared to the GLIDE runs, but the images lacked detail. The outputs were also inconsistent, as can be seen below
Increasing the number of tokens didn’t help either; the images went back to being distorted.
The main reason for the problem was pretty simple: the whole pipeline on Hugging Face was designed to be trained at 512x512, and at 256x256 I would need a lot of testing to find the right parameters.
Doing latent diffusion on 512x512
Now, something a bit magical happened. I saw a PR on diffusers, Hugging Face’s diffusion library (which supports Stable Diffusion), that pretty much solved my problem. This is the PR.
The author, Ttl, came up with two main strategies:
- Using fp16. This halves the precision of the numbers stored in the model (16 bits instead of 32), which theoretically halves memory usage too. I had tried this on my end before, but fp16 alone was not enough to fit the model in my 6 GB GPU.
- Using gradient checkpointing! This was quite an interesting improvement. The idea of gradient checkpointing is best described here, but in short: instead of keeping all of the model’s intermediate activations in memory during the forward pass, we keep only some of them and recompute the rest during the backward pass, when they are needed to compute the gradient. This typically saves a lot of memory (around 60%) for roughly 20% more computation time.
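A quick NumPy check of what fp16 buys and what it costs: each value takes 2 bytes instead of 4, at the price of precision. The 860M figure is the U-Net size quoted for Stable Diffusion v1; weights are only part of the memory, since activations (which gradient checkpointing targets) and any optimizer state come on top. `weight_gb` is a helper made up for the arithmetic.

```python
import numpy as np

def weight_gb(num_params, dtype):
    """Memory (in GB) needed just to store num_params weights of the given dtype."""
    return num_params * np.dtype(dtype).itemsize / 1e9

unet_params = 860_000_000
print(weight_gb(unet_params, np.float32), weight_gb(unet_params, np.float16))  # 3.44 1.72

# The price: fp16 keeps only about 3 significant decimal digits.
print(float(np.float32(1 / 3)), float(np.float16(1 / 3)))
```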
Now, the funny thing about textual inversion is that the model we are doing gradient checkpointing on is frozen. Gradient checkpointing is about saving memory for the backward pass, so it’s typically used on models we are training.
At first glance, there doesn’t seem to be much point to it for a frozen model. But even though that model’s weights are frozen, the loss still has to backpropagate through it to reach the token embedding we are learning, so the activations from its forward pass are still needed. Checkpointing means we don’t have to keep all of them around, which, according to Ttl, reduces memory use by roughly 1 GB.
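Here is a toy NumPy sketch of gradient checkpointing in a setup that mirrors textual inversion: a stack of frozen layers (fixed weights), where we only want the gradient with respect to the input `x0`, which plays the role of the learned token embedding. The layers are elementwise `tanh(w * x)` purely so the backward pass can be written by hand; this is not how PyTorch implements checkpointing, just the same idea. The checkpointed version stores 4 activations instead of 17 (plus one segment at a time during the backward pass) and gets the same gradient.

```python
import numpy as np

def layer(x, w):
    """One frozen toy layer: y = tanh(w * x), elementwise."""
    return np.tanh(w * x)

def layer_backward(grad_out, y, w):
    """Backprop through y = tanh(w * x): dy/dx = w * (1 - y**2)."""
    return grad_out * w * (1.0 - y ** 2)

def grad_full(x0, weights):
    """Plain backprop: keep every activation. Returns (d sum(x_L) / d x0, #activations stored)."""
    acts = [x0]
    for w in weights:
        acts.append(layer(acts[-1], w))
    grad = np.ones_like(acts[-1])
    for i in reversed(range(len(weights))):
        grad = layer_backward(grad, acts[i + 1], weights[i])
    return grad, len(acts)

def grad_checkpointed(x0, weights, every=4):
    """Keep only every `every`-th activation during the forward pass and
    recompute each segment from its checkpoint on the way back."""
    n = len(weights)
    checkpoints = {0: x0}
    x = x0
    for i, w in enumerate(weights):
        x = layer(x, w)
        if (i + 1) % every == 0 and i + 1 < n:
            checkpoints[i + 1] = x
    grad = np.ones_like(x)
    end = n
    for start in sorted(checkpoints, reverse=True):
        seg = [checkpoints[start]]                 # recompute this segment's activations
        for i in range(start, end):
            seg.append(layer(seg[-1], weights[i]))
        for i in reversed(range(start, end)):      # backprop through the segment
            grad = layer_backward(grad, seg[i - start + 1], weights[i])
        end = start
    return grad, len(checkpoints)

rng = np.random.default_rng(0)
x0 = rng.normal(size=3)                   # plays the role of the learned embedding
weights = rng.uniform(0.5, 1.5, size=16)  # 16 frozen layers
g_full, stored_full = grad_full(x0, weights)
g_ckpt, stored_ckpt = grad_checkpointed(x0, weights, every=4)
print(np.allclose(g_full, g_ckpt), stored_full, stored_ckpt)  # True 17 4
```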
Even with these changes, it still didn’t quite fit in my 6 GB GPU.
But once I added attention slicing, which basically just does the matrix multiplication for attention bit by bit (check out here), it fit in my GPU! The final code is here. This was rather interesting too, since according to Suraj Patil, one of the top maintainers, this should only be a marginal optimization for training, as it was designed for inference. But it still worked, so I’m happy with it.
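A minimal NumPy sketch of the idea behind attention slicing: computing attention for a few query rows at a time, so only a small slice of the score matrix exists at once, gives exactly the same result, because each row’s softmax is independent. This sketch slices over query positions; it isn’t meant to reproduce which dimension diffusers slices over.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Full attention: materialises the whole (n_queries, n_keys) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1) @ v

def sliced_attention(q, k, v, slice_size):
    """Same result, but only a (slice_size, n_keys) block of scores exists at a time."""
    out = np.empty((q.shape[0], v.shape[-1]), dtype=v.dtype)
    for start in range(0, q.shape[0], slice_size):
        out[start:start + slice_size] = attention(q[start:start + slice_size], k, v)
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(64, 16)) for _ in range(3))
print(np.allclose(attention(q, k, v), sliced_attention(q, k, v, slice_size=10)))  # True
```

The trade-off is more, smaller matrix multiplications in exchange for a lower peak memory footprint.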
Now, let’s look at the results! I was finally able to get good images. With one-token training, I got results like these:
As you may notice, it’s slightly cursed. I think what happened is that Frida is a pretty unique-looking dog, so the model had trouble describing her properly with just one token.
When I increased it to 10 tokens, I got results like
These are way better, although not perfect.
If we increase the number of tokens to, say, 15, we get more monstrous creations like
And yup, I then started training with some more creative parameters and got results like the puppy Frida below
I am pretty satisfied with this journey, and I want to thank the open-source community and the AI research community for (probably unknowingly) supporting me along the way.
Check out this article so that you can do something like this too!