Improve argument validation for Flash-attn + SWA (#1162)
* Improve argument validation for Flash-attn + SWA

* Update NeoXArgs docs automatically

* don't pass window_size if not necessary

* Update NeoXArgs docs automatically

* Update 7B.yml

* Update NeoXArgs docs automatically

* apply precommit

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
haileyschoelkopf and github-actions committed Mar 2, 2024
1 parent 3c03fc7 commit 19596b0
Showing 4 changed files with 22 additions and 15 deletions.
4 changes: 3 additions & 1 deletion configs/mistral/7B.yml
@@ -9,7 +9,9 @@
"intermediate_size": 14336,
"num_attention_heads": 32,
"num_kv_heads": 8,
"seq_length": 4096,
# per Mistral, Mistral-7B-v0.1 was pretrained with 8192 seqlen
# and instruction tuned to 16384 seqlen, all with 4096 sliding window
"seq_length": 8192,
"sliding_window_width": 4096,
"max_position_embeddings": 131072,
"pos_emb": "rotary",
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

- Default = 4ca9a8a
+ Default = 11a2e9b

current git hash of repository

17 changes: 11 additions & 6 deletions megatron/model/transformer.py
@@ -593,6 +593,15 @@ def flash_attention(self, query_layer, key_layer, value_layer):
output_size[0], output_size[2], output_size[1], -1
)

+ # only pass in window_size kwarg to flash-attn
+ # if we use Sliding Window Attention.
+ # Flash attn defaults to (-1,-1), or
+ # does not have this kwarg prior to v2.3.0
+ extra_kwargs = (
+ {"window_size": (self.sliding_window_width, -1)}
+ if self.sliding_window_width is not None
+ else {}
+ )
if not self.training:
q_shape = query_layer.shape
k_shape = key_layer.shape
@@ -613,9 +622,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
max_seqlen_k,
softmax_scale=None,
causal=True,
- window_size=(self.sliding_window_width, -1)
- if self.sliding_window_width is not None
- else (-1, -1),
+ **extra_kwargs,
)
output = output.reshape(q_shape)
else:
@@ -626,9 +633,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
self.dropout_p if self.training else 0.0,
softmax_scale=None,
causal=True,
- window_size=(self.sliding_window_width, -1)
- if self.sliding_window_width is not None
- else (-1, -1),
+ **extra_kwargs,
)

matmul_result = output
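
The hunk above replaces the two inline `window_size=... if ... else (-1, -1)` ternaries with a single `extra_kwargs` dict that is splatted into both flash-attn calls, so `window_size` is only passed when sliding-window attention is configured and flash-attn releases older than 2.3.0 (which lack the kwarg) keep working on the non-SWA path. A minimal standalone sketch of that pattern follows; the helper name is hypothetical and NeoX builds the dict inline instead.

```python
# Minimal sketch of the conditional-kwarg pattern used above; the helper name
# is hypothetical, NeoX constructs the dict inline inside flash_attention().
from typing import Optional


def build_flash_attn_kwargs(sliding_window_width: Optional[int]) -> dict:
    if sliding_window_width is not None:
        # flash-attn's window_size is (left, right); -1 means "unlimited" on that
        # side, so (W, -1) together with causal=True gives a causal window of W.
        return {"window_size": (sliding_window_width, -1)}
    # Omit the kwarg entirely so flash-attn < 2.3.0 (no window_size) still works.
    return {}


print(build_flash_attn_kwargs(4096))  # {'window_size': (4096, -1)}
print(build_flash_attn_kwargs(None))  # {}
```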
14 changes: 7 additions & 7 deletions megatron/neox_arguments/arguments.py
@@ -390,12 +390,6 @@ def consume_deepy_args(cls, input_args=None):

neox_args.wandb_group += "_" + wandb.util.generate_id()

- if neox_args.sliding_window_width is not None:
- _flash_version = packaging.version.Version(version("flash-attn"))
- assert _flash_version >= packaging.version.Version(
- "2.0.0"
- ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.0.0 to support sliding window attention."

neox_args.print()

return neox_args
@@ -1080,10 +1074,16 @@ def calculate_derived(self):
assert all(
(attn_type == "flash") or (attn_type == "global")
for attn_type in self.attention_config
), "GQA / MQA currently only compatible with Flash or standard global Attention"
), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention"
assert (
self.num_kv_heads % self.model_parallel_size == 0
), "Number of KV heads must be at least model_parallel_size for now!"
+ # Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention
+ if self.sliding_window_width is not None and "flash" in self.attention_config:
+ _flash_version = packaging.version.Version(version("flash-attn"))
+ assert _flash_version >= packaging.version.Version(
+ "2.3.0"
+ ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."

# Adding equal dataset weights if none are provided
if self.train_data_paths and (self.train_data_weights is None):
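
Compared with the block removed from `consume_deepy_args` above, the check now lives in `calculate_derived`, only fires when a `flash` attention type is actually configured alongside `sliding_window_width`, and raises the minimum flash-attn version from 2.0.0 to 2.3.0 (the release that introduced `window_size`). A standalone sketch of an equivalent check, assuming the same `packaging` and `importlib.metadata` APIs the file already imports; the function name is hypothetical.

```python
# Hypothetical standalone helper mirroring the assert added above.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_flash_swa_support(sliding_window_width, attention_config):
    if sliding_window_width is None or "flash" not in attention_config:
        return  # nothing to validate
    try:
        flash_version = Version(version("flash-attn"))
    except PackageNotFoundError as err:
        raise RuntimeError("flash-attn is not installed") from err
    # Sliding-window (window_size) support landed in flash-attn 2.3.0.
    assert flash_version >= Version("2.3.0"), (
        f"Flash-Attention version ({flash_version}) must be >= 2.3.0 "
        "to support sliding window attention."
    )


# Raises unless a flash-attn >= 2.3.0 install is present.
check_flash_swa_support(4096, ["flash"])
```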
