
Remove inefficient computation from AttentionPool2d Module #271

Merged: 3 commits merged into openai:main on Jul 21, 2022

Conversation

@jenkspt (Contributor) commented on Jul 21, 2022

This is a simple fix that removes unnecessary attention computation from the AttentionPool2d module.
In the existing version, self-attention is computed over the full spatial + average embedding sequence with shape [(HW+1), N, C]. In the proposed fix, attention is computed with the average embedding [1, N, C] as the query and the spatial + average embedding sequence [(HW+1), N, C] as the key/value.

I created this gist to demonstrate the equivalence of the existing implementation and the proposed one: https://gist.github.com/jenkspt/3a09cc150ab531781c6084c166047639. There is parity in both the computation and the parameter state, so there shouldn't be any breaking changes introduced.
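
For reference, a minimal standalone sketch of that equivalence (this uses a bare nn.MultiheadAttention with toy shapes rather than the module's actual forward pass, so it only illustrates the idea, not the diff itself):

    import torch
    import torch.nn as nn

    # Toy shapes for illustration only (not CLIP's real dimensions).
    N, C, H, W = 2, 64, 7, 7
    attn = nn.MultiheadAttention(embed_dim=C, num_heads=4)

    x = torch.randn(H * W, N, C)                     # spatial tokens, [HW, N, C]
    x = torch.cat([x.mean(dim=0, keepdim=True), x])  # prepend mean token -> [(HW+1), N, C]

    # Existing approach: self-attention over the full sequence, then keep position 0.
    full_out, _ = attn(x, x, x, need_weights=False)
    pooled_old = full_out[0]                         # [N, C]

    # Proposed approach: only the mean token is used as the query; keys/values are unchanged.
    pooled_new, _ = attn(x[:1], x, x, need_weights=False)
    pooled_new = pooled_new.squeeze(0)               # [N, C]

    print(torch.allclose(pooled_old, pooled_new, atol=1e-5))  # True, up to float error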

I realize that AttentionPool2d is only used once in the CLIP model, so this fix will not have a huge impact. However, I arrived here from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L22-L51, which is based on the CLIP version (and has the same problem), so I think there is an added benefit for posterity.

@jongwook merged commit f69a9bc into openai:main on Jul 21, 2022
@jongwook (Collaborator) commented:

Thanks for the PR, a nice fix! There's a similar inefficiency in the last layer of the vision transformer, but it won't be as simple as this PR to fix.

@jenkspt deleted the fix-attention-pool2d branch on July 21, 2022 at 20:28
@jenkspt (Contributor, Author) commented on Jul 21, 2022

> Thanks for the PR, a nice fix! There's a similar inefficiency in the last layer of the vision transformer, but it won't be as simple as this PR to fix.

NP! Are you referring to this?

CLIP/clip/model.py, lines 185 to 187 in f69a9bc:

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

@jongwook (Collaborator) commented:

That [0] was just to throw out the attn_output_weights returned by nn.MultiheadAttention. The inefficiency I mentioned is at:

CLIP/clip/model.py, lines 231 to 235 in f69a9bc:

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_post(x[:, 0, :])

where the vision encoder takes the activations from just the CLS position as its output. But I don't think it needs a fix here anytime soon; I was just noting it!
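
As a rough sketch of that pattern (standard PyTorch modules standing in for CLIP's own Transformer, with made-up shapes), every position is pushed through every block, but only the CLS row survives ln_post:

    import torch
    import torch.nn as nn

    # Stand-ins for CLIP's modules; shapes are illustrative only.
    L, N, C = 50, 2, 64                               # CLS + patch tokens, batch, width
    layer = nn.TransformerEncoderLayer(d_model=C, nhead=4)
    transformer = nn.TransformerEncoder(layer, num_layers=2)
    ln_post = nn.LayerNorm(C)

    x = torch.randn(N, L, C)                          # NLD
    x = x.permute(1, 0, 2)                            # NLD -> LND
    x = transformer(x)                                # every block computes outputs for all L positions
    x = x.permute(1, 0, 2)                            # LND -> NLD
    x = ln_post(x[:, 0, :])                           # ...but only the CLS row is used downstream

Restricting the last block to the CLS query would avoid the wasted rows, but its keys and values still need every position from the previous block, and the block's residual/MLP path would need restructuring as well, which is presumably why it isn't as simple as the AttentionPool2d change.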

rom1504 pushed a commit to rom1504/CLIP that referenced this pull request on Jan 13, 2024:

* fix inefficient attention computation
* remove erroneous formatting
* simplified flatten

Co-authored-by: Jong Wook Kim <[email protected]>