
Improve argument validation for Flash-attn + SWA #1162

Merged (8 commits) on Mar 2, 2024
Improve argument validation for Flash-attn + SWA
haileyschoelkopf committed Feb 29, 2024
commit a81ac3cd2a314cca69b4a89e054d448ac1140957
megatron/neox_arguments/arguments.py: 14 changes (7 additions & 7 deletions)
@@ -390,12 +390,6 @@ def consume_deepy_args(cls, input_args=None):
 
             neox_args.wandb_group += "_" + wandb.util.generate_id()
 
-        if neox_args.sliding_window_width is not None:
-            _flash_version = packaging.version.Version(version("flash-attn"))
-            assert _flash_version >= packaging.version.Version(
-                "2.0.0"
-            ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.0.0 to support sliding window attention."
-
         neox_args.print()
 
         return neox_args
@@ -1080,10 +1074,16 @@ def calculate_derived(self):
             assert all(
                 (attn_type == "flash") or (attn_type == "global")
                 for attn_type in self.attention_config
-            ), "GQA / MQA currently only compatible with Flash or standard global Attention"
+            ), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention"
             assert (
                 self.num_kv_heads % self.model_parallel_size == 0
             ), "Number of KV heads must be at least model_parallel_size for now!"
+        # Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention
+        if self.sliding_window_width is not None and "flash" in self.attention_config:
+            _flash_version = packaging.version.Version(version("flash-attn"))
+            assert _flash_version >= packaging.version.Version(
+                "2.3.0"
+            ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
 
         # Adding equal dataset weights if none are provided
         if self.train_data_paths and (self.train_data_weights is None):
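For readers outside the diff context: the PR moves a "validate the optional dependency's version at config time" check into calculate_derived() and tightens it so the floor is only enforced when sliding window attention is actually combined with the flash backend. A minimal standalone sketch of the same logic (the helper name validate_flash_swa is hypothetical, not part of NeoX):

from importlib.metadata import PackageNotFoundError, version

import packaging.version


def validate_flash_swa(sliding_window_width, attention_config):
    # Only enforce the version floor when SWA is actually requested
    # together with the flash backend, mirroring calculate_derived().
    if sliding_window_width is None or "flash" not in attention_config:
        return
    try:
        installed = packaging.version.Version(version("flash-attn"))
    except PackageNotFoundError:
        raise AssertionError("'flash' attention requested but flash-attn is not installed")
    assert installed >= packaging.version.Version(
        "2.3.0"
    ), f"Flash-Attention version ({installed}) must be >= 2.3.0 to support sliding window attention."


# Example: passes on flash-attn >= 2.3.0, raises otherwise (or if it is not installed):
validate_flash_swa(sliding_window_width=4096, attention_config=["flash"])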
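The 2.3.0 floor exists because flash-attn only gained sliding-window support in that release, exposed as the window_size argument on its attention functions. A hedged sketch of such a call (shapes, dtypes, and the window convention follow flash-attn's documented interface; this is illustrative, not the exact megatron call site):

import torch
from flash_attn import flash_attn_func  # window_size requires flash-attn >= 2.3.0

# flash-attn expects (batch, seqlen, num_heads, head_dim) tensors in fp16/bf16 on CUDA.
q = torch.randn(2, 2048, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 2048, 16, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 2048, 16, 64, dtype=torch.float16, device="cuda")

# Causal sliding-window attention: query i attends to keys in [i - 4095, i],
# i.e. a 4096-token window including itself. window_size=(-1, -1) disables windowing.
out = flash_attn_func(q, k, v, causal=True, window_size=(4095, 0))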