Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve argument validation for Flash-attn + SWA #1162

Merged
merged 8 commits into from
Mar 2, 2024
Prev Previous commit
Next Next commit
don't pass window_size if not necessary
  • Loading branch information
haileyschoelkopf committed Feb 29, 2024
commit c094c8cca025b5cd8d570541a0e3531050b452c4
19 changes: 13 additions & 6 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,17 @@ def flash_attention(self, query_layer, key_layer, value_layer):
output_size[0], output_size[2], output_size[1], -1
)

# only pass in window_size kwarg to flash-attn
# if we use Sliding Window Attention.
# Flash attn defaults to (-1,-1), or
# does not have this kwarg prior to v2.3.0
extra_kwargs = (
{
"window_size": (self.sliding_window_width, -1)
}
if self.sliding_window_width is not None
else {}
)
if not self.training:
q_shape = query_layer.shape
k_shape = key_layer.shape
Expand All @@ -613,9 +624,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
max_seqlen_k,
softmax_scale=None,
causal=True,
window_size=(self.sliding_window_width, -1)
if self.sliding_window_width is not None
else (-1, -1),
**extra_kwargs,
)
output = output.reshape(q_shape)
else:
Expand All @@ -626,9 +635,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
self.dropout_p if self.training else 0.0,
softmax_scale=None,
causal=True,
window_size=(self.sliding_window_width, -1)
if self.sliding_window_width is not None
else (-1, -1),
**extra_kwargs,
)

matmul_result = output
Expand Down
Loading