diff --git a/configs/mistral/7B.yml b/configs/mistral/7B.yml
index 67b7c6a52..587fe5d36 100644
--- a/configs/mistral/7B.yml
+++ b/configs/mistral/7B.yml
@@ -9,7 +9,9 @@
   "intermediate_size": 14336,
   "num_attention_heads": 32,
   "num_kv_heads": 8,
-  "seq_length": 4096,
+  # per Mistral, Mistral-7B-v0.1 was pretrained with 8192 seqlen
+  # and instruction tuned to 16384 seqlen, all with 4096 sliding window
+  "seq_length": 8192,
   "sliding_window_width": 4096,
   "max_position_embeddings": 131072,
   "pos_emb": "rotary",
diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 9a0511237..be1fd7905 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 4ca9a8a
+    Default = 11a2e9b

     current git hash of repository

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 85cfb6e2d..f039126b9 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -593,6 +593,15 @@ def flash_attention(self, query_layer, key_layer, value_layer):
                 output_size[0], output_size[2], output_size[1], -1
             )

+            # only pass in window_size kwarg to flash-attn
+            # if we use Sliding Window Attention.
+            # Flash attn defaults to (-1,-1), or
+            # does not have this kwarg prior to v2.3.0
+            extra_kwargs = (
+                {"window_size": (self.sliding_window_width, -1)}
+                if self.sliding_window_width is not None
+                else {}
+            )
             if not self.training:
                 q_shape = query_layer.shape
                 k_shape = key_layer.shape
@@ -613,9 +622,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
                 max_seqlen_k,
                 softmax_scale=None,
                 causal=True,
-                window_size=(self.sliding_window_width, -1)
-                if self.sliding_window_width is not None
-                else (-1, -1),
+                **extra_kwargs,
             )
             output = output.reshape(q_shape)
         else:
@@ -626,9 +633,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
                 self.dropout_p if self.training else 0.0,
                 softmax_scale=None,
                 causal=True,
-                window_size=(self.sliding_window_width, -1)
-                if self.sliding_window_width is not None
-                else (-1, -1),
+                **extra_kwargs,
            )

             matmul_result = output
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 7bca420cd..ca68100a5 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -390,12 +390,6 @@ def consume_deepy_args(cls, input_args=None):

             neox_args.wandb_group += "_" + wandb.util.generate_id()

-        if neox_args.sliding_window_width is not None:
-            _flash_version = packaging.version.Version(version("flash-attn"))
-            assert _flash_version >= packaging.version.Version(
-                "2.0.0"
-            ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.0.0 to support sliding window attention."
-
         neox_args.print()

         return neox_args
@@ -1080,10 +1074,16 @@ def calculate_derived(self):
             assert all(
                 (attn_type == "flash") or (attn_type == "global")
                 for attn_type in self.attention_config
-            ), "GQA / MQA currently only compatible with Flash or standard global Attention"
+            ), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention"
             assert (
                 self.num_kv_heads % self.model_parallel_size == 0
             ), "Number of KV heads must be at least model_parallel_size for now!"
+        # Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention
+        if self.sliding_window_width is not None and "flash" in self.attention_config:
+            _flash_version = packaging.version.Version(version("flash-attn"))
+            assert _flash_version >= packaging.version.Version(
+                "2.3.0"
+            ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."

         # Adding equal dataset weights if none are provided
         if self.train_data_paths and (self.train_data_weights is None):
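Note on the pattern above: the transformer.py change builds an extra_kwargs dict and only forwards window_size to flash-attn when sliding-window attention is configured, because flash-attn releases before 2.3.0 do not accept that keyword at all; the arguments.py change gates the >= 2.3.0 version check on the same condition. Below is a minimal, self-contained sketch of that pattern, illustrative only and not part of the patch: run_flash_attention is a hypothetical helper, and it assumes the flash-attn 2.x flash_attn_func interface with tensors already shaped as (batch, seqlen, nheads, headdim).

from importlib.metadata import version

import packaging.version
from flash_attn import flash_attn_func  # flash-attn 2.x interface (assumed installed)


def run_flash_attention(q, k, v, sliding_window_width=None, dropout_p=0.0):
    # Mirror the check added to calculate_derived(): only require
    # flash-attn >= 2.3.0 when a sliding window is actually requested.
    if sliding_window_width is not None:
        _flash_version = packaging.version.Version(version("flash-attn"))
        assert _flash_version >= packaging.version.Version(
            "2.3.0"
        ), f"flash-attn {_flash_version} is too old for sliding window attention (need >= 2.3.0)"

    # Only pass window_size when a sliding window is in use; otherwise omit the
    # kwarg entirely so older flash-attn versions never see an unknown argument.
    extra_kwargs = (
        {"window_size": (sliding_window_width, -1)}
        if sliding_window_width is not None
        else {}
    )

    return flash_attn_func(
        q, k, v, dropout_p, softmax_scale=None, causal=True, **extra_kwargs
    )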