Remove inefficient computation from AttentionPool2d Module #271
This is a simple fix that removes unnecessary attention computation from the `AttentionPool2d` module. In the existing version, self-attention is computed over the full spatial + average embedding sequence with shape `[(HW+1), N, C]`, even though only the output for the first (average) token is kept. In the proposed fix, attention is computed with the average embedding `[1, N, C]` as the query and the spatial + average embedding sequence `[(HW+1), N, C]` as the key/value. I created this gist to demonstrate the equivalence of the existing implementation and the proposed one: https://gist.github.com/jenkspt/3a09cc150ab531781c6084c166047639. There is parity in both the computation and the parameter state, so there shouldn't be any breaking changes introduced.
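The equivalence can be sketched with a minimal example (this is not the CLIP code itself; the sizes and the use of `nn.MultiheadAttention` here are illustrative assumptions). Because each output token of attention depends only on its own query row, restricting the query to the prepended mean token reproduces the first output token of the full self-attention:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration: a 7x7 spatial map, batch 2, 64 channels.
HW, N, C, heads = 49, 2, 64, 8

attn = nn.MultiheadAttention(embed_dim=C, num_heads=heads)
attn.eval()

spatial = torch.randn(HW, N, C)  # flattened spatial features, [HW, N, C]
# Prepend the average embedding, giving the [(HW+1), N, C] sequence.
x = torch.cat([spatial.mean(dim=0, keepdim=True), spatial], dim=0)

with torch.no_grad():
    # Existing approach: self-attention over the full sequence, keep token 0.
    full_out, _ = attn(x, x, x)
    pooled_full = full_out[0]

    # Proposed approach: only the average embedding acts as the query,
    # while the full sequence remains the key/value.
    query_out, _ = attn(x[:1], x, x)
    pooled_query = query_out[0]

print(torch.allclose(pooled_full, pooled_query, atol=1e-5))
```

Since the softmax over keys for query token 0 is identical in both cases, the two pooled outputs agree up to floating-point tolerance, while the query-only version avoids computing the other `HW` attention rows that are discarded anyway.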
I realize that `AttentionPool2d` is only used once in the CLIP model, so this fix will not have a huge impact. However, I arrived here from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L22-L51, which is based on the CLIP version (and has the same problem), so I think there is an added benefit for posterity.