From a81ac3cd2a314cca69b4a89e054d448ac1140957 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Thu, 29 Feb 2024 10:42:07 -0500 Subject: [PATCH 1/8] Improve argument validation for Flash-attn + SWA --- megatron/neox_arguments/arguments.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 7bca420cd..ca68100a5 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -390,12 +390,6 @@ def consume_deepy_args(cls, input_args=None): neox_args.wandb_group += "_" + wandb.util.generate_id() - if neox_args.sliding_window_width is not None: - _flash_version = packaging.version.Version(version("flash-attn")) - assert _flash_version >= packaging.version.Version( - "2.0.0" - ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.0.0 to support sliding window attention." - neox_args.print() return neox_args @@ -1080,10 +1074,16 @@ def calculate_derived(self): assert all( (attn_type == "flash") or (attn_type == "global") for attn_type in self.attention_config - ), "GQA / MQA currently only compatible with Flash or standard global Attention" + ), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention" assert ( self.num_kv_heads % self.model_parallel_size == 0 ), "Number of KV heads must be at least model_parallel_size for now!" + # Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention + if self.sliding_window_width is not None and "flash" in self.attention_config: + _flash_version = packaging.version.Version(version("flash-attn")) + assert _flash_version >= packaging.version.Version( + "2.3.0" + ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention." # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): From b01ad10fd477728a44bed298526c609e1a52b9f6 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 29 Feb 2024 15:44:00 +0000 Subject: [PATCH 2/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 9a0511237..6f346e133 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 4ca9a8a + Default = a81ac3c current git hash of repository From c094c8cca025b5cd8d570541a0e3531050b452c4 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Thu, 29 Feb 2024 10:53:29 -0500 Subject: [PATCH 3/8] don't pass window_size if not necessary --- megatron/model/transformer.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 85cfb6e2d..6cf2eeb29 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -593,6 +593,17 @@ def flash_attention(self, query_layer, key_layer, value_layer): output_size[0], output_size[2], output_size[1], -1 ) + # only pass in window_size kwarg to flash-attn + # if we use Sliding Window Attention. + # Flash attn defaults to (-1,-1), or + # does not have this kwarg prior to v2.3.0 + extra_kwargs = ( + { + "window_size": (self.sliding_window_width, -1) + } + if self.sliding_window_width is not None + else {} + ) if not self.training: q_shape = query_layer.shape k_shape = key_layer.shape @@ -613,9 +624,7 @@ def flash_attention(self, query_layer, key_layer, value_layer): max_seqlen_k, softmax_scale=None, causal=True, - window_size=(self.sliding_window_width, -1) - if self.sliding_window_width is not None - else (-1, -1), + **extra_kwargs, ) output = output.reshape(q_shape) else: @@ -626,9 +635,7 @@ def flash_attention(self, query_layer, key_layer, value_layer): self.dropout_p if self.training else 0.0, softmax_scale=None, causal=True, - window_size=(self.sliding_window_width, -1) - if self.sliding_window_width is not None - else (-1, -1), + **extra_kwargs, ) matmul_result = output From ea57a72f8e9a23e234bc02a0f7257d4457c13c90 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 29 Feb 2024 15:53:42 +0000 Subject: [PATCH 4/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 6f346e133..87db35692 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = a81ac3c + Default = c094c8c current git hash of repository From d4a091ab4055926d641a5518f11957aedf3afdf6 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Thu, 29 Feb 2024 14:21:33 -0500 Subject: [PATCH 5/8] Update 7B.yml --- configs/mistral/7B.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/configs/mistral/7B.yml b/configs/mistral/7B.yml index 67b7c6a52..d67adffeb 100644 --- a/configs/mistral/7B.yml +++ b/configs/mistral/7B.yml @@ -9,7 +9,9 @@ "intermediate_size": 14336, "num_attention_heads": 32, "num_kv_heads": 8, - "seq_length": 4096, + # per Mistral, Mistral-7B-v0.1 was pretrained with 8192 seqlen + # and instruction tuned to 16384 seqlen, all with 4096 sliding window + "seq_length": 8192, "sliding_window_width": 4096, "max_position_embeddings": 131072, "pos_emb": "rotary", From 1c5c427dcf521cf1759c5dbbbe226172d47a4323 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 29 Feb 2024 19:21:48 +0000 Subject: [PATCH 6/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 87db35692..557bc8134 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = c094c8c + Default = d4a091a current git hash of repository From 11a2e9b14cafc66d79db72c66bfdef2947d84699 Mon Sep 17 00:00:00 2001 From: haileyschoelkopf Date: Fri, 1 Mar 2024 14:45:41 +0000 Subject: [PATCH 7/8] apply precommit --- configs/mistral/7B.yml | 2 +- configs/neox_arguments.md | 7 +++---- megatron/model/transformer.py | 6 ++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/configs/mistral/7B.yml b/configs/mistral/7B.yml index d67adffeb..587fe5d36 100644 --- a/configs/mistral/7B.yml +++ b/configs/mistral/7B.yml @@ -11,7 +11,7 @@ "num_kv_heads": 8, # per Mistral, Mistral-7B-v0.1 was pretrained with 8192 seqlen # and instruction tuned to 16384 seqlen, all with 4096 sliding window - "seq_length": 8192, + "seq_length": 8192, "sliding_window_width": 4096, "max_position_embeddings": 131072, "pos_emb": "rotary", diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 557bc8134..f2eb9c364 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -1058,7 +1058,7 @@ Text Generation arguments - **prompt_end**: str - Default = + Default = a single prompt's end. Defaults to newline @@ -1100,7 +1100,7 @@ Text Generation arguments - **eval_results_prefix**: str - Default = + Default = prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json @@ -1844,7 +1844,7 @@ Args for deepspeed config Default = None - + @@ -2144,4 +2144,3 @@ Args for deepspeed runner (deepspeed.launcher.runner). Default = None Adds a `--account` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard. - diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 6cf2eeb29..f039126b9 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -595,12 +595,10 @@ def flash_attention(self, query_layer, key_layer, value_layer): # only pass in window_size kwarg to flash-attn # if we use Sliding Window Attention. - # Flash attn defaults to (-1,-1), or + # Flash attn defaults to (-1,-1), or # does not have this kwarg prior to v2.3.0 extra_kwargs = ( - { - "window_size": (self.sliding_window_width, -1) - } + {"window_size": (self.sliding_window_width, -1)} if self.sliding_window_width is not None else {} ) From 0155d766076dd6b41f73665e34372f691ba013ef Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 1 Mar 2024 15:29:38 +0000 Subject: [PATCH 8/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index f2eb9c364..be1fd7905 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = d4a091a + Default = 11a2e9b current git hash of repository @@ -1058,7 +1058,7 @@ Text Generation arguments - **prompt_end**: str - Default = + Default = a single prompt's end. Defaults to newline @@ -1100,7 +1100,7 @@ Text Generation arguments - **eval_results_prefix**: str - Default = + Default = prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json @@ -1844,7 +1844,7 @@ Args for deepspeed config Default = None - + @@ -2144,3 +2144,4 @@ Args for deepspeed runner (deepspeed.launcher.runner). Default = None Adds a `--account` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard. +