Understanding Real Custom
This will be a relatively short blog on Real Custom. The motivation for reading this paper mainly came from IP adapter. The idea of IP adapter is very simple: we encode the reference image, project that encoding into its own key/value features, and add the result of this image cross-attention to the text cross-attention in Stable Diffusion. Now we have our model generating the subject of the image we inputted!
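To make that concrete, here is a minimal sketch of my reading of the IP adapter idea (this is not the actual IP-Adapter code; the function and argument names are mine, and the `scale` knob is the inference-time weight discussed below):

```python
import torch.nn.functional as F

def decoupled_cross_attention(q, text_kv, image_kv, scale=1.0):
    """Rough sketch of IP-Adapter-style cross-attention (my reading, not the official code).

    q:        (batch, n_query, d) queries from the U-Net features
    text_kv:  (k_text, v_text) projected from the text encoder output
    image_kv: (k_img, v_img) projected from the image encoding with new, trained projections
    scale:    weight on the image branch (often reduced at inference)
    """
    k_t, v_t = text_kv
    k_i, v_i = image_kv
    text_out = F.scaled_dot_product_attention(q, k_t, v_t)    # standard text cross-attention
    image_out = F.scaled_dot_product_attention(q, k_i, v_i)   # separate image cross-attention
    # IP adapter simply adds the two results together.
    return text_out + scale * image_out
```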
The issue with IP adapter
Now, the issue with the IP adapter is that during training it is not trained to, say, replace the girl part of the prompt in the above example with the target image. It is instead incentivized to largely ignore the text and let the image cross-attention dominate the entire cross-attention output. The only reason we can somewhat merge the text with the image is that during inference we multiply the result of the image cross-attention by a scale of around 0.7 to reduce its effect. And this does work pretty well. But I started to wonder: can we have an IP adapter that only affects the word girl in the above example and not the rest of the prompt? This is where Real Custom comes in.
Main Contribution of Real Custom
Real Custom solves this issue by localizing the effect of the image we encoded to just the parts of the cross attention we are interested in!
The general overview of the training/inference methods is shown above. To point out some interesting ideas:
- During training, we can just train like normal with generic text-image pairs
- There is a module called the Adaptive Scoring Module, which decides which parts of the encoded image actually influence the generation.
- The Adaptive Scoring Module takes as input the current text encoder output, the current generation, the timestep, and patches of our input image. The patches, to my knowledge, are encoded with a CLIP image encoder.
- During inference, given the image we are conditioning on, at the given timestep we noise and denoise to extract the mean cross attention of our word (e.g. toy) to see where, without customization, our concept goes in the image.
- We divide the mask by its maximum value so that, at most, it has a value of 1.
- Then, we multiply that mask by the current generation image and multiply the text encoder output by a token mask that sets only our desired word (e.g. toy) to 1 and the rest to 0!
- Finally, the image encoder output for the inference gets multiplied by the attention map mask we computed earlier, so there is no leaking of concepts (see the sketch right after this list).
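Here is a small sketch of how I imagine the mask extraction and normalization looks (the tensor shapes and helper name are my assumptions, not the paper's code):

```python
import torch

def target_word_mask(cross_attn_maps, word_token_idx):
    """Hypothetical sketch of the inference-time mask computation.

    cross_attn_maps: (layers, heads, n_query, n_text_tokens) text cross-attention
                     probabilities collected at the current denoising step.
    word_token_idx:  index of the target word (e.g. "toy") in the prompt.
    """
    # Mean attention the target word receives at each spatial position.
    m = cross_attn_maps[..., word_token_idx].mean(dim=(0, 1))   # (n_query,)
    # Divide by the maximum value so the mask is at most 1.
    return m / (m.max() + 1e-8)
```

The image cross-attention output would then be multiplied element-wise by this mask, so the injected concept stays inside the region the target word already attends to.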
The reason I suspect we don’t do the inference trick during training is that we assume the text prompt fully describes the image. Then, every token in the text prompt has its output customized by the image encoder!
Now, at this point I am still not sure what prevents the Adaptive Scoring Module from letting the image features overdominate the cross-attention output and drown out the text features. So let’s see how it works!
Adaptive Scoring Module
The Adaptive Scoring Module takes in the text features, the current generated features, and the timestep, and outputs a score for each feature in the image features. One important point here is that we do a top-k selection after multiplying the image features by the scores. The result is that we prevent the image features from overdominating the output! The minor con is that we need a new hyperparameter that controls what proportion of the image features to keep.
Then, like in the IP adapter, those top-k selected features are added in, but we do a slight trick to make the dimensions work in the image K and V layers, like so
So the cross-attention result is just the text cross attention + the image cross attention, as in the IP adapter, but the image K_i and V_i only have a sequence length of k from the top-k selection!
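A sketch of how I picture this top-k variant (the argument names and the projection callables `to_k_i`/`to_v_i` are my own placeholders, not the paper's):

```python
import torch.nn.functional as F

def topk_image_cross_attention(q, text_kv, image_feats, scores, keep_ratio, to_k_i, to_v_i):
    """Sketch of the top-k image cross-attention as I understand it.

    image_feats: (batch, n_img, d_img) image encoder features
    scores:      (batch, n_img) output of the Adaptive Scoring Module
    keep_ratio:  proportion of image features to keep (the extra hyperparameter)
    """
    k = max(1, int(keep_ratio * image_feats.shape[1]))
    top_idx = scores.topk(k, dim=1).indices                            # (batch, k)
    idx = top_idx.unsqueeze(-1).expand(-1, -1, image_feats.shape[-1])
    selected = image_feats.gather(1, idx)                              # (batch, k, d_img)

    k_t, v_t = text_kv
    # The image K/V now only have sequence length k, so the image branch
    # cannot flood the attention with every image token.
    k_i, v_i = to_k_i(selected), to_v_i(selected)
    return (F.scaled_dot_product_attention(q, k_t, v_t)
            + F.scaled_dot_product_attention(q, k_i, v_i))
```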
Now, how do we compute the scores? First, given that f_ct is the text features and z_t is the current image generation, we compute softmaxes like so
Here n_t is the sequence length of the text features and n_z is the height times width of the VAE latent.
Then finally we multiply this attention map by our original features like so
Then we get matrices of the size of the embedding dimension. I believe f_y is a typo, as I couldn’t find it mentioned in the rest of the paper.
Now, what the authors did was spatially replicate the above pooled features to the same sequence length as the image features.
Then, image features are concatenated onto this replication.
Then, a 2-layer neural network takes this input and outputs the scores, which have dimension (image feature sequence length) × 1. I’m pretty interested in how robust this approach is, since the pooling and softmax steps in the middle feel pretty complicated, at least in my opinion.
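Putting the pooling, replication, concatenation, and 2-layer network together, here is my rough reconstruction of the module (the layer sizes are guesses, and how exactly the paper computes the softmax pooling weights is not clear from my notes, so I use a learned linear projection here):

```python
import torch
import torch.nn as nn

class AdaptiveScoringSketch(nn.Module):
    """Rough reconstruction of the Adaptive Scoring Module; names and sizes are mine."""

    def __init__(self, d_text, d_vis, d_img, hidden=256):
        super().__init__()
        # Learned weights used to softmax-pool each sequence into a single vector.
        self.pool_text = nn.Linear(d_text, 1)
        self.pool_vis = nn.Linear(d_vis, 1)
        # 2-layer network mapping [pooled context, image feature] -> one score per image token.
        self.score_mlp = nn.Sequential(
            nn.Linear(d_text + d_vis + d_img, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, f_ct, z_t, image_feats):
        # f_ct: (B, n_t, d_text), z_t: (B, n_z, d_vis), image_feats: (B, n_i, d_img)
        w_t = self.pool_text(f_ct).softmax(dim=1)        # (B, n_t, 1)
        w_z = self.pool_vis(z_t).softmax(dim=1)          # (B, n_z, 1)
        ctx_t = (w_t * f_ct).sum(dim=1)                  # (B, d_text)  pooled text context
        ctx_z = (w_z * z_t).sum(dim=1)                   # (B, d_vis)   pooled visual context
        # Spatially replicate the pooled context and concatenate the image features.
        ctx = torch.cat([ctx_t, ctx_z], dim=-1)
        ctx = ctx.unsqueeze(1).expand(-1, image_feats.shape[1], -1)
        scores = self.score_mlp(torch.cat([ctx, image_feats], dim=-1)).squeeze(-1)
        return scores                                     # (B, n_i)
```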
However, now we integrate the timestep using α from the diffusion model’s parameters!
I am not entirely sure why α is used here, but intuitively it makes sense: early in the denoising process we want to focus more on the text features, and as generation proceeds we want more focus on the visual features! I’ll write the reason here if I figure it out.
Next, the image features are multiplied by the softmax of S like so
So basically, the lowest each feature will go is its original size, and at most a feature will double (which suggests the multiplier is 1 plus the softmax weight).
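In code, my reading of this rescaling step is the following (the `1 +` is my inference from the 'at least original, at most double' behavior, and I am leaving out the exact α-based timestep weighting since I don't have the formula in my notes):

```python
def rescale_image_features(image_feats, scores):
    """Sketch: scale each image feature by 1 + its softmax weight (my reading)."""
    s = scores.softmax(dim=-1)                     # (B, n_i), each weight in (0, 1)
    # Each feature is at least its original size and at most doubled.
    return image_feats * (1.0 + s.unsqueeze(-1))
```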
Finally, the authors mention a bit about the top-k and recommend randomly sampling the top-k proportion from 0.3 to 1.0.
One final note for inference
While I did mention most of the inference techniques in the overview above, one part I missed was classifier-free guidance. This paper seems to extend classifier-free guidance like so
The z with the superscript T is just the predicted noise given only text, while the superscript TI is for text and image. So, we first push our predicted noise away from the unconditional prediction using the text prediction, and then we push it further away from the text-only prediction using the text-and-image prediction. I am honestly curious whether this helps in IP adapters too! The recommended hyperparameters seem to be ω_t = 7.5 and ω_i = 12.5, which is pretty interesting. Finally, for inference, a lower top-k proportion of 0.25 interestingly seems optimal.
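Written out as a quick sketch (my reading of the two-stage guidance described above; the variable names are mine):

```python
def guided_noise(eps_uncond, eps_text, eps_text_image, w_t=7.5, w_i=12.5):
    """Push away from the unconditional prediction with the text prediction,
    then push further away from text-only with the text+image prediction."""
    return (eps_uncond
            + w_t * (eps_text - eps_uncond)
            + w_i * (eps_text_image - eps_text))
```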
Overall, these were mainly my notes from reading the Real Custom paper, but let me know if any parts seem wrong or interesting! I personally find the method for getting the scores a bit complicated: I think a simpler approach might be just a transpose plus some MLPs to change the sequence length to the image encoder's sequence length, but I might be missing some intuition here. I just think that pooling like the above leads to some information loss, but I will double-check the paper to make sure I get it.