How do Textual Inversion tokens destroy prompts?
This blog post asks how exactly Textual Inversion tokens destroy prompts, i.e. over-dominate the cross-attention in diffusion models. To be clear up front: I think I got close to the answer, but I wasn't able to fully answer the question.
Background
It is fairly well known that while textual inversion can reproduce the images you train it on reasonably well, it is not good at following prompts. The Custom Diffusion paper showed this: textual inversion had very low text alignment, meaning its generated images followed the text the least.
As for the reason, I discussed it here, but in short, Textual Inversion is trained to ignore the prompt and just generate the images you train it with. For example, in diffusers, we ask the model to generate the same training images given every one of the following prompt templates:
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
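The training objective can be sketched as follows: each training example pairs one of the training images with a randomly chosen template, filled in with the placeholder token. This is a simplified sketch of what I understand the diffusers dataset to do, not its actual code; `make_training_pair` and the image file names are hypothetical.

```python
import random

# A subset of the templates above; "{}" marks where the placeholder token goes.
templates = [
    "a photo of a {}",
    "a rendering of a {}",
    "a dark photo of the {}",
    "a close-up photo of a {}",
]

# Hypothetical file names standing in for the handful of training images.
training_images = ["cat_toy_0.jpg", "cat_toy_1.jpg", "cat_toy_2.jpg"]

def make_training_pair(index, placeholder_token, rng):
    """Hypothetical helper: pair a training image with a randomly chosen template.

    The template is drawn independently of the image, so whatever the
    surrounding words say, the denoising target is always one of the same images.
    """
    image = training_images[index % len(training_images)]
    prompt = rng.choice(templates).format(placeholder_token)
    return prompt, image

rng = random.Random(0)
for i in range(4):
    print(make_training_pair(i, "<cat-toy>", rng))
```

Because the loss only ever sees these images, words like "dark" or "close-up" get no say in the target, which is the pressure that teaches the token to override its context.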
Training on these templates encourages the model to generate the subject and ignore the rest of the prompt, or to overwrite the prompt with the concept. For example, when we ask the model to generate “A <cat-toy> next to a man with a friend” we get
which clearly ignores the “man with a friend” portion and replaces both people with <cat-toy>. This is interesting because each token is only a single 768-dimensional vector, which is tiny compared to the rest of the diffusion model, and it only affects one word in the CLIP text encoder's input.
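To make the size concrete: a textual inversion concept is a single new row in the text encoder's token-embedding table, and it is the only thing trained. A sketch with NumPy stand-ins: the real CLIP ViT-L/14 table used by Stable Diffusion 1.x has 49408 rows of 768 dimensions, but the vocabulary is shrunk here to keep the demo light, and `add_placeholder_row` is a hypothetical helper. Initializing the new row from an existing word mirrors the diffusers script's initializer token.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in vocabulary: 1000 rows instead of CLIP's 49408, same 768-dim width.
vocab_size, dim = 1000, 768
embeddings = rng.normal(scale=0.02, size=(vocab_size, dim)).astype(np.float32)

def add_placeholder_row(table, init_row):
    """Append one row for the new placeholder token; return the new table and its id."""
    new_table = np.vstack([table, init_row[None, :]])  # vstack copies, so no aliasing
    return new_table, new_table.shape[0] - 1

# Initialize the placeholder from an existing token (row 42 here, arbitrarily).
table, placeholder_id = add_placeholder_row(embeddings, embeddings[42])

# Only this row is optimized during textual inversion; all other weights are frozen.
trainable = table[placeholder_id]
print(placeholder_id, trainable.size)  # 1000 768
```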
However, the goal of this post is not to discuss why the training objective causes this, but to figure out which characteristics of textual inversion tokens cause this prompt destruction. If any of you have ideas for how this happens, let me know!
First, though, let's confirm that the problem exists.
The Problem
For this post I have only tested one example, though I may test a few more if I get time. I am using DAAM, which shows the contribution of each token to the output like so
This works because of an interesting property of diffusion models: the cross-attention map at a given token's index tends to capture where that token shows up in the image, like so
So we can just look at the cross-attention map for “bear” to see the contribution of the token bear to the output! For more information on this, take a look here.
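A simplified, single-layer sketch of the idea: in a cross-attention layer, the softmax over text tokens gives every image location a weight per token, so taking one token's slice and reshaping it to the latent grid gives that token's heat map. The norm of such a map is what the comparison below refers to; I'm assuming the Frobenius norm here. DAAM itself aggregates maps over layers, heads, and timesteps and upsamples them, so this is not its API; the tensors are random stand-ins and the function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def token_heat_maps(q, k, height, width):
    """One cross-attention layer.
    q: (heads, height*width, d) image queries; k: (heads, seq_len, d) text keys.
    Returns (seq_len, height, width): a head-averaged attention map per text token."""
    d = q.shape[-1]
    probs = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d), axis=-1)  # softmax over tokens
    per_token = probs.mean(axis=0)  # (height*width, seq_len)
    return per_token.T.reshape(-1, height, width)

def map_norms(maps):
    """Frobenius norm of each token's (height, width) map."""
    return np.linalg.norm(maps.reshape(maps.shape[0], -1), axis=1)

rng = np.random.default_rng(0)
heads, h, w, seq_len, d = 8, 16, 16, 77, 40
q = rng.normal(size=(heads, h * w, d))
k = rng.normal(size=(heads, seq_len, d))
maps = token_heat_maps(q, k, h, w)
bear_map = maps[2]  # e.g. the map for the token at index 2
print(maps.shape, map_norms(maps).shape)  # (77, 16, 16) (77,)
```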
Now, looking at the contribution of each token for the prompt “A <cat-toy> next to a man with a friend”, the mean attention maps of all the normal tokens had norms of around 10~70, whereas the <cat-toy> token consistently had a norm of around 200. The <cat-toy> token also seemed to have “cleaner” attention maps than the other tokens. For example, below is a comparison of the attention maps of <cat-toy> and “next”
So there is definitely something unique about textual inversion tokens.
To identify what, I mainly ran 4 tests to find the common denominator behind this cross-attention destruction:
- The norm
- The angle
- CLIP attention
- PCA components
The norm
The main theory I found, both on the LAION Discord servers and in the literature (“Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models”, otherwise known as E4T), is that the norm of the textual inversion vector is what causes this collapse. To keep the norm from becoming too large, E4T penalizes its growth with an L1 norm. The claim has some substance: the norms of the token embeddings follow a distribution like the one below
where the y-axis is the frequency and the x-axis is the norm. The tiny bump at 0 comes from tokens that were never trained, and the mean is around 0.385. The <cat-toy> token above has a norm of 3.1, around 8 times larger than the average token. So is this what causes the token to be over-represented?
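The “around 8 times” figure is just 3.1 / 0.385 ≈ 8.05. Below is a sketch of how the norm statistics could be computed; the synthetic table stands in for the real input embeddings (with transformers, those would come from `text_encoder.get_input_embeddings().weight`), and `norm_stats` is a hypothetical helper. A histogram of `norms` gives the kind of distribution described above.

```python
import numpy as np

def norm_stats(table):
    """Per-row L2 norms of an embedding table of shape (vocab, dim), and their mean."""
    norms = np.linalg.norm(table, axis=1)
    return norms, norms.mean()

rng = np.random.default_rng(0)
# Stand-in table whose rows all have norm 0.385, the vocabulary average reported above.
table = rng.normal(size=(1000, 768))
table *= 0.385 / np.linalg.norm(table, axis=1, keepdims=True)
norms, mean_norm = norm_stats(table)

placeholder_norm = 3.1  # norm reported above for <cat-toy>
print(round(placeholder_norm / mean_norm, 2))  # 8.05
```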
One counterargument: in the code for the CLIP text model, which is Stable Diffusion's text encoder, we see these lines
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
This runs right before the output is passed to the diffusion model, so essentially all scale information might be lost here. However, it's quite possible that the scale somehow got mixed into the other tokens on the CLIP side, over-representing the concept. If so, scaling the token's norm down to around 0.385 should make the token less represented, since it then has the same norm as all the other tokens. However, what we get is the following
If we scale to 0.385*2, 0.385*3, and so on up to the original scale, we get the images below, where the top left uses 2 times the average norm and the bottom right the original (about 8 times)
To me at least, a larger norm seems to increase quality slightly, but lowering it has a negligible effect on prompt alignment. I find this fascinating, since a single 768-dimensional vector can over-dominate the prompt this much.
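Two pieces of this section can be checked in a few lines. First, LayerNorm subtracts each hidden vector's mean and divides by its standard deviation, so multiplying a vector by a positive constant leaves the output essentially unchanged (only the small epsilon differs); that is why the final layer norm could erase scale. This only holds at the output: CLIP's text layers normalize the input to each attention block but keep an unnormalized residual stream, so an input norm can still influence other tokens before `final_layer_norm`, which is the loophole considered above. Second, the rescaling experiment keeps the learned vector's direction and changes only its length. The sketch uses random stand-ins and omits LayerNorm's learnable gain and bias; `rescale` is a hypothetical helper.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without the learnable gain and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def rescale(vec, target_norm):
    """Keep the direction of vec but set its L2 norm to target_norm."""
    return vec * (target_norm / np.linalg.norm(vec))

rng = np.random.default_rng(0)

# 1) Scaling one hidden state by a positive constant barely changes its LayerNorm output.
hidden = rng.normal(size=(77, 768))  # stand-in for one prompt's encoder output
scaled = hidden.copy()
scaled[2] *= 8.0
print(np.allclose(layer_norm(hidden)[2], layer_norm(scaled)[2], atol=1e-4))  # True

# 2) The norm sweep: same direction, norms of k * 0.385 for k = 1..8.
cat_toy = rescale(rng.normal(size=768), 3.1)  # stand-in with the reported norm of 3.1
sweep = {k: rescale(cat_toy, k * 0.385) for k in range(1, 9)}  # k = 8 gives 3.08, about the original
```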
The angle
If the norm is not the culprit, is it the angle? For this test, I computed the dot products between 1000 tokens (excluding each token paired with itself); I didn't use all 49408 tokens for compute reasons. The dot products are below
As a refresher, dividing the dot product of two vectors by the product of their norms gives the cosine of the angle between them. Doing that, we get
So I think we can say fairly confidently that each token's input embedding is quite dissimilar to the others. And since their norms are consistently around 0.385, we can picture each vector as occupying its own patch of a sphere in the token-embedding space!
Now, let's take the cosine of each token with the textual inversion token embedding. We get the following
One small observation is that the absolute values of these cosines are slightly smaller, which means this token is a bit closer to orthogonal to the other tokens than they are to each other!
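The cosine computations in this section, sketched with NumPy: normalize the rows so dot products become cosines, mask out the diagonal, and compare the mean absolute cosine among vocabulary tokens with that of a single query vector. The random vectors are stand-ins for the real embeddings, so they will not reproduce the plots; the helper names are hypothetical.

```python
import numpy as np

def unit_rows(table):
    return table / np.linalg.norm(table, axis=1, keepdims=True)

def pairwise_cosines(table):
    """Cosines between every pair of distinct rows, flattened to 1-D."""
    unit = unit_rows(table)
    cos = unit @ unit.T  # dot products of unit vectors are cosines
    off_diagonal = ~np.eye(len(table), dtype=bool)  # drop each token paired with itself
    return cos[off_diagonal]

def cosines_to(query, table):
    """Cosine between one vector (e.g. the learned <cat-toy> embedding) and every row."""
    return unit_rows(table) @ (query / np.linalg.norm(query))

rng = np.random.default_rng(0)
vocab = rng.normal(size=(5000, 768))  # stand-in for the real input embeddings
sample = vocab[rng.choice(len(vocab), size=1000, replace=False)]  # 1000 tokens, as above

token_cos = pairwise_cosines(sample)
ti_cos = cosines_to(rng.normal(size=768), sample)  # stand-in for the <cat-toy> vector
# The comparison above: is the learned token more orthogonal to the vocabulary on average?
print(np.abs(token_cos).mean(), np.abs(ti_cos).mean())
```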
CLIP Attention
Earlier, I mentioned that normal tokens had cross-attention norms of around 10~70 at most, while the <cat-toy> token consistently had a norm of around 200.
However, there was one token I haven't mentioned yet: the start-of-sequence token, which begins every prompt in Stable Diffusion. For some reason, its cross-attention map norm turned out to be in the 2000s, far exceeding every other token norm I have seen. So one hypothesis I formed was that the textual inversion token takes on characteristics similar to the start token, letting it over-dominate the prompt with very strong attention maps, and that it's the text encoder reacting to this change.
So I decided to compare the attention maps of the CLIP text encoder layers for two prompts: “A <cat-toy> next to a man with a friend” and “A cat toy next to a man with a friend”.
Let's start with the normal prompt (without textual inversion tokens). For the layers of the CLIP text transformer, we get attention maps like these
First layer
Second layer
Third layer
To read these attention maps, look at the numbers along the axes: the cell at 1 on the y-axis and 0 on the x-axis shows how much attention the token at position 1 pays to the token at position 0. As we go deeper into the layers, all the tokens end up paying attention almost exclusively to the start-of-sequence token, which might explain why the attention map for that token is so large! This is a fairly well-known phenomenon, at least in large language models. I first learned about it in the StreamingLLM (“attention sinks”) paper, where the authors used it to extend the context length of LLMs!
However, one layer looks different from the others: the first layer, below, where the y-th token seems to pay attention to itself and sometimes to previous tokens. So my understanding is that the first layer of the CLIP model is responsible for picking up all the words in the prompt and encoding them into the start token!
Now let's look at the first-layer attention map for the textual inversion prompt
At the location of the <cat-toy> token, index 2, the textual inversion token seems to pay attention only to itself. In fact, if we zoom in
we see that it pays very little attention to the start token! So my hypothesis is that, in the subsequent CLIP layers, the textual inversion token over-dominates the rest of the prompt by bypassing the start token at its index. This would explain why the textual inversion token's attention map is comparatively less noisy: it has full control over generation while the other words are ignored. In hard numbers, the entry at index (2, 2) has the highest value in the attention map (0.905), apart from the entry at (0, 0), which is 1. My guess is that this happens because textual inversion tokens are made more similar to the start token than the rest of the tokens are.
I think this guess is at least somewhat right. For convenience, we will call the value at index (2, 2) the textual inversion attention.
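One detail that helps when reading these maps: CLIP's text transformer uses a causal attention mask, so the token at position i can only attend to positions 0..i. That makes every map lower-triangular and forces row 0 to be exactly [1, 0, 0, ...], because the start token can only attend to itself. A minimal single-head sketch with random stand-in weights (in practice, transformers returns the real per-layer maps when `CLIPTextModel` is called with `output_attentions=True`); `causal_attention_probs` is a hypothetical helper.

```python
import numpy as np

def causal_attention_probs(x, w_q, w_k):
    """Single-head causal self-attention probabilities for x of shape (seq_len, dim)."""
    q, k = x @ w_q, x @ w_k
    scores = q @ k.T / np.sqrt(q.shape[-1])
    seq_len = x.shape[0]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)  # token i may not look at tokens after i
    scores -= scores.max(axis=-1, keepdims=True)  # the diagonal is always finite
    probs = np.exp(scores)  # exp(-inf) = 0 for masked entries
    return probs / probs.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, dim = 10, 64  # e.g. start token + 9 prompt tokens
x = rng.normal(size=(seq_len, dim))
w_q, w_k = rng.normal(size=(dim, dim)), rng.normal(size=(dim, dim))
probs = causal_attention_probs(x, w_q, w_k)
print(probs[0, 0])  # 1.0: the start token can only attend to itself
print(probs[1, 0])  # how much token 1 attends to the start token
```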
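A sketch of how this number can be read off: Hugging Face transformers returns one tensor of shape (batch, heads, seq_len, seq_len) per layer when `CLIPTextModel` is called with `output_attentions=True`, and the textual inversion attention is the (2, 2) entry of the first layer's map. The post doesn't say how the 12 heads were combined, so averaging them is my assumption; the tensor below is a random causal stand-in in NumPy, and `self_attention_scores` is a hypothetical helper.

```python
import numpy as np

def self_attention_scores(layer_attn, batch_index=0):
    """layer_attn: (batch, heads, seq_len, seq_len), e.g. one entry of the
    `attentions` output. Returns the head-averaged diagonal: how much each
    token attends to itself."""
    avg = layer_attn[batch_index].mean(axis=0)  # averaging heads is an assumption
    return np.diag(avg)

# Random causal stand-in for the first layer's attention (12 heads, 10 tokens).
rng = np.random.default_rng(0)
seq_len, heads = 10, 12
logits = rng.normal(size=(1, heads, seq_len, seq_len))
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
logits = np.where(future, -np.inf, logits)
attn = np.exp(logits - logits.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)

diag = self_attention_scores(attn)
ti_index = 2  # <start>, "a", "<cat-toy>", ...
print(diag[0])  # 1.0 under the causal mask
print(diag[ti_index])  # the "textual inversion attention"
print(1 + int(np.argmax(diag[1:])))  # position with the highest self-attention after the start token
```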
Similarity of textual inversion token to the start token
Let's see what happens if we replace our textual inversion embedding with the start token embedding. Interestingly, as the zoomed-in view of the attention map below shows, our token no longer seems to receive much attention
The image generated is the following
This does seem to confirm our guess that low scores in the CLIP text encoder's attention map mean less prompt destruction. Another interesting finding is that the attention map does not assign more attention to tokens that are similar to the start token.
The cosine between the <cat-toy> token and the start token is -0.0821, which is not remarkable compared with the cosines of the other tokens with the start token, shown below
Still, one hypothesis I formed was that the CLIP text encoder pays more attention to tokens that are dissimilar to the start token. To test this, I used spherical linear interpolation (SLERP) between the textual inversion token, scaled down to a norm of 0.385, and the start token. When rotating away from the start token, the textual inversion attention stays at around 0.87 or above, consistently higher than for the other tokens. However, when rotating towards the start token, the textual inversion attention drops quickly, and once it falls to 0.5, all signs of the cat toy disappear. At an interpolation factor of around 0.184, with a textual inversion attention of 0.77, I got the image below
This is the best image I got with this approach, and some image fidelity is clearly lost. Interestingly, when we bring the rotated token back to the original scale, the attention score still decreases as the token nears the start token, but more slowly. Also, past a certain point, the images become very distorted, which suggests some pieces of this puzzle are still missing.
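Mechanically, the swap overwrites the placeholder's row in the token-embedding table with the start token's row; `<|startoftext|>` has id 49406 in the CLIP tokenizer. A sketch on a small stand-in table (in practice the table would be `text_encoder.get_input_embeddings().weight`). The placeholder id is hypothetical, and only the token embedding moves, since position embeddings are added separately by index.

```python
import numpy as np

BOS_ID = 49406  # id of <|startoftext|> in the CLIP tokenizer (49407 is <|endoftext|>)

rng = np.random.default_rng(0)
# Stand-in table: 49408 vocabulary rows plus one appended placeholder row.
# The width is cut from 768 to 16 just to keep this demo light.
table = rng.normal(scale=0.02, size=(49409, 16)).astype(np.float32)
placeholder_id = 49408  # hypothetical id of <cat-toy>; depends on how the concept was added

original = table[placeholder_id].copy()
# Cosine with the start token before the swap (the post reports -0.0821 for the real vectors).
cos_before = original @ table[BOS_ID] / (np.linalg.norm(original) * np.linalg.norm(table[BOS_ID]))

table[placeholder_id] = table[BOS_ID]  # <cat-toy> now carries the start token's embedding
```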
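The interpolation can be sketched as follows, assuming SLERP is applied to the directions only and the norm is set afterwards (0.385 for the scaled-down experiment, 3.1 when bringing the token back to its original scale). The vectors are random stand-ins, `slerp_direction` is a hypothetical helper, and the post doesn't specify its exact SLERP implementation. Negative values of `t` continue along the same great circle away from the start token.

```python
import numpy as np

def slerp_direction(v0, v1, t):
    """Spherical linear interpolation between the directions of v0 (t=0) and v1 (t=1).
    Returns a unit vector; the caller chooses the norm. Assumes v0 and v1 are not
    antiparallel, where the great circle is undefined."""
    u0 = v0 / np.linalg.norm(v0)
    u1 = v1 / np.linalg.norm(v1)
    omega = np.arccos(np.clip(u0 @ u1, -1.0, 1.0))  # angle between the two directions
    if omega < 1e-6:  # (nearly) parallel: nothing to rotate
        return u0
    return (np.sin((1 - t) * omega) * u0 + np.sin(t * omega) * u1) / np.sin(omega)

rng = np.random.default_rng(0)
cat_toy = rng.normal(size=768)  # stand-in for the learned <cat-toy> embedding
start = rng.normal(size=768)  # stand-in for the start-token embedding

t = 0.184  # interpolation factor of the image above; t > 0 rotates toward the start token
rotated = 0.385 * slerp_direction(cat_toy, start, t)  # at the average token norm
restored = 3.1 * slerp_direction(cat_toy, start, t)  # same direction, original scale
```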
Conclusion and Future Avenues
The low-hanging fruit, which I might pick up, is simply repeating the above with more textual inversion tokens. A more interesting direction might be to analyze the first layer of CLIP to see exactly what causes this token to have such a high score in the first-layer attention map. Overall, I hope you all enjoyed this mini research blog!
Appendix-PCA
I tried extracting the principal components and comparing them to the textual inversion token, but I wasn't able to get interesting results. The main reason is that, since the data is already nearly orthogonal, each principal component explains very little of the variance (around 0.14 at most), whereas normally we'd want a few components to capture around 0.99 of the variance in total.
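For reference, the explained-variance check can be done with a plain SVD of the mean-centered embedding table: each principal component's explained-variance ratio is its squared singular value divided by the sum of all squared singular values. A sketch on a random stand-in table; `explained_variance_ratio` is a hypothetical helper, though scikit-learn's `PCA(...).explained_variance_ratio_` computes the same quantity.

```python
import numpy as np

def explained_variance_ratio(table):
    """Fraction of total variance captured by each principal component (rows = samples)."""
    centered = table - table.mean(axis=0, keepdims=True)
    singular_values = np.linalg.svd(centered, compute_uv=False)  # sorted in descending order
    variances = singular_values ** 2
    return variances / variances.sum()

rng = np.random.default_rng(0)
table = rng.normal(size=(2000, 768))  # stand-in for the real token embeddings
ratios = explained_variance_ratio(table)
n99 = int(np.searchsorted(np.cumsum(ratios), 0.99)) + 1  # components needed for 99% of the variance
print(ratios[:3], n99)
```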