diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 414285c9a..991c74fb8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: hooks: - id: codespell args: [ - '--ignore-words-list=reord', # Word used in error messages that need rewording + '--ignore-words-list=reord,dout', # Word used in error messages that need rewording --check-filenames, --check-hidden, ] diff --git a/README.md b/README.md index 43f4665ce..9b6880bf2 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,11 @@ from the repository root. +### Flash Attention + +To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details. + + ### Containerized Setup We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on Docker Hub at `leogao2/gpt-neox`. diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index ecd27d435..b637405c4 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 166c5b6 + Default = 12f6f76 current git hash of repository @@ -798,6 +798,14 @@ Misc. Arguments +- **save_iters**: list + + Default = None + + Set during training + + + - **global_num_gpus**: int Default = None @@ -1133,11 +1141,37 @@ Training Arguments -- **save_interval**: int +- **checkpoint_scale**: typing.Literal['linear', 'log'] + + Default = linear + + How step at which checkpoints are saved should scale. "linear" implies 1 checkpoint will be saved at every multiple of `checkpoint-factor`, + while "log" implies that the number of steps between each checkpoint will be multiplied by `checkpoint-factor` at each step, starting from step 1. + + + +- **checkpoint_factor**: int + + Default = None + + Acts as a multiplier on either the "log" or "linear" checkpoint spacing. + + With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at + steps [20, 40, 60, 80, 100]. + + With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at + steps [1, 2, 4, 8, 16, 32, 64, 100]. + + Note that the last checkpoint step is always saved. + + + +- **extra_save_iters**: list Default = None - Number of iterations between checkpoint saves. + Additional iterations when a checkpoint should be saved. + Must be a list of ints or `None`. diff --git a/megatron/model/flash_attention.py b/megatron/model/flash_attention.py new file mode 100644 index 000000000..5c40a322c --- /dev/null +++ b/megatron/model/flash_attention.py @@ -0,0 +1,185 @@ +# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import flash_attn_cuda + + +def _flash_attn_forward( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + num_splits=0, + generator=None, +): + """ + num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means + it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. + Don't change it unless you know what you're doing. + """ + softmax_lse, *rest = flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + return_softmax, + num_splits, + generator, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + S_dmask = rest[0] if return_softmax else None + return out, softmax_lse, S_dmask + + +def _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + num_splits=0, + generator=None, +): + """ + num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or + not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic. + Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel + as num_splits=3), so effectively the choices are 0, 1, and 2. + This hyperparameter can be tuned for performance, but default value (heuristic) should work fine. + """ + _, _, _, softmax_d = flash_attn_cuda.bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + num_splits, + generator, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, softmax_lse, S_dmask = _flash_attn_forward( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + torch.empty_like(qkv[:, 0]), + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax, + ) + ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dqkv = torch.empty_like(qkv) + _flash_attn_backward( + dout, + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + out, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None + + +def flash_attn_unpadded_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale=None, + causal=False, + return_attn_probs=False, +): + return FlashAttnQKVPackedFunc.apply( + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4585efd34..e81213983 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -265,7 +265,8 @@ def __init__( self.rotary_emb = None self.attention_type = neox_args.attention_config[layer_number] - self.sparse = self.attention_type != "global" + self.use_flash_attention = self.attention_type == "flash" + self.sparse = self.attention_type != "global" and not self.use_flash_attention if self.sparse: self.sparse_attn = configure_sparse_attention( neox_args, @@ -274,19 +275,31 @@ def __init__( mpu=mpu, ) else: - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.fp16, - input_in_bf16=self.bf16, - fusion_type=get_fusion_type(neox_args), - mask_func=self.attention_mask_func, - softmax_in_fp32=self.attention_softmax_in_fp32, - scale=(coeff / neox_args.mup_attn_temp) if coeff is not None else None, # TODO: deepspeed sparse attention scaling patch? - ) + if self.use_flash_attention: + from megatron.model.flash_attention import ( + flash_attn_unpadded_qkvpacked_func, + ) + + self.flash_attention_function = flash_attn_unpadded_qkvpacked_func + if self.pos_emb == "alibi": + raise ValueError( + "Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead." + ) + else: + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.fp16, + input_in_bf16=self.bf16, + fusion_type=get_fusion_type(neox_args), + mask_func=self.attention_mask_func, + softmax_in_fp32=self.attention_softmax_in_fp32, + scale=coeff, + ) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. - self.attention_dropout = nn.Dropout(neox_args.attention_dropout) + self.dropout_p = neox_args.attention_dropout + self.attention_dropout = nn.Dropout(self.dropout_p) # Output. self.dense = mpu.RowParallelLinear( @@ -402,6 +415,55 @@ def attention( context_layer = context_layer.view(*output_size) return context_layer + def flash_attention(self, query_layer, key_layer, value_layer): + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + # [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn] + query_layer = query_layer.transpose(0, 1).reshape( + output_size[0] * output_size[2], 1, output_size[1], -1 + ) + key_layer = key_layer.transpose(0, 1).reshape( + output_size[0] * output_size[3], 1, output_size[1], -1 + ) + value_layer = value_layer.transpose(0, 1).reshape( + output_size[0] * output_size[3], 1, output_size[1], -1 + ) + + # Combined q/k/v into [b * s, 3, np, hn]. + qkv = torch.concat([query_layer, key_layer, value_layer], dim=1) + + batch_size = output_size[0] + seqlen = output_size[2] + max_s = seqlen + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seqlen, + step=seqlen, + dtype=torch.int32, + device=qkv.device, + ) + output = self.flash_attention_function( + qkv, + cu_seqlens, + max_s, + self.dropout_p if self.training else 0.0, + softmax_scale=None, + causal=True, + ) + # [b * sq, np, hn] -> [b, sq, np, hn] + matmul_result = output.view( + output_size[0], output_size[2], output.shape[1], output.shape[2] + ) + # [b, sq, np, hn] -> [b, np, sq, hn] + matmul_result = matmul_result.transpose(1, 2) + + return matmul_result + def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): # TODO: sparse attn dropout? # TODO: pad to block size @@ -489,7 +551,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None): if self.use_cache: present = torch.stack((key_layer, value_layer)) - if not self.sparse: + if self.use_flash_attention: + context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif not self.sparse: context_layer = self.attention( query_layer, key_layer, value_layer, layer_past, attention_mask ) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index f9c78b1a4..d3f8ed044 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -735,6 +735,30 @@ def calculate_derived(self): } ) + # derive steps where checkpoint should be saved + if self.checkpoint_factor or self.extra_save_iters: + if self.extra_save_iters: + save_iters = set(self.extra_save_iters) + else: + save_iters = set() + + step = self.checkpoint_factor # don't save step 0 or 1 + while step < self.train_iters: + save_iters.add(step) + if self.checkpoint_scale == "log": + step *= self.checkpoint_factor + elif self.checkpoint_scale == "linear": + step += self.checkpoint_factor + + save_iters = list(save_iters) + save_iters.sort() + + self.update_values( + { + "save_iters": save_iters, + } + ) + # derive precision if (self.fp16 or {}).get("type", self.precision) == "bfloat16": self.update_value("precision", "bfloat16") @@ -824,7 +848,7 @@ def calculate_derived(self): if self.sparsity_config is None: # Can't have a default value as an empty dict so need to set it here self.update_value("sparsity_config", {}) - + # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): self.train_data_weights = [1.0] * len(self.train_data_paths) @@ -923,10 +947,10 @@ def validate_values(self): raise ValueError(error_message) return False - if self.save is not None and self.save_interval is None: + if self.save is not None and self.checkpoint_factor is None and self.extra_save_iters is None: error_message = ( self.__class__.__name__ - + ".validate_values() save_interval must be defined if save is defined" + + ".validate_values() checkpoint_factor or extra_save_iters must be defined if save is defined" ) logging.error(error_message) raise ValueError(error_message) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index f9236a5dd..ec8b5439e 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -34,6 +34,7 @@ "bslongformer", "gmlp", "amlp", + "flash", ] @@ -635,6 +636,11 @@ class NeoXArgsOther(NeoXArgsTemplate): Set during training """ + save_iters: list = None + """ + Set during training + """ + global_num_gpus: int = None """ Set during launching @@ -770,9 +776,29 @@ class NeoXArgsTraining(NeoXArgsTemplate): save input and output of a forward pass with the checkpoint and validate after load """ - save_interval: int = None + checkpoint_scale: Literal["linear", "log"] = "linear" + """ + How step at which checkpoints are saved should scale. "linear" implies 1 checkpoint will be saved at every multiple of `checkpoint-factor`, + while "log" implies that the number of steps between each checkpoint will be multiplied by `checkpoint-factor` at each step, starting from step 1. + """ + + checkpoint_factor: int = None + """ + Acts as a multiplier on either the "log" or "linear" checkpoint spacing. + + With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at + steps [20, 40, 60, 80, 100]. + + With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at + steps [1, 2, 4, 8, 16, 32, 64, 100]. + + Note that the last checkpoint step is always saved. + """ + + extra_save_iters: list = None """ - Number of iterations between checkpoint saves. + Additional iterations when a checkpoint should be saved. + Must be a list of ints or `None`. """ no_save_optim: bool = False diff --git a/megatron/training.py b/megatron/training.py index 6f90654b3..568d06a43 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -203,6 +203,16 @@ def pretrain(neox_args): iteration = 0 if neox_args.do_train and neox_args.train_iters > 0: + # edge case: save step 0 checkpoint if requested + if neox_args.save and 0 in neox_args.save_iters: + save_checkpoint( + neox_args=neox_args, + iteration=iteration, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + iteration = train( neox_args=neox_args, timers=timers, @@ -757,8 +767,7 @@ def train( # Checkpointing if ( neox_args.save - and neox_args.save_interval - and iteration % neox_args.save_interval == 0 + and iteration in neox_args.save_iters ): save_checkpoint( neox_args=neox_args, diff --git a/requirements/requirements-flashattention.txt b/requirements/requirements-flashattention.txt new file mode 100644 index 000000000..0c7d41e59 --- /dev/null +++ b/requirements/requirements-flashattention.txt @@ -0,0 +1 @@ +flash-attn==0.2.2 diff --git a/tools/convert_to_hf.py b/tools/convert_to_hf.py index f10236539..c5c6b306d 100644 --- a/tools/convert_to_hf.py +++ b/tools/convert_to_hf.py @@ -105,11 +105,11 @@ def __init__(self, neox_config): # TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default use_tied_lns = get_key(neox_config, 'gpt-j-tied', False) - if not use_tied_lns: + if use_tied_lns: raise NotImplementedError( """ERROR: Huggingface Transformers does not yet support a single shared layernorm per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals. - See https://github.com/EleutherAI/gpt-neox/pull/481 for further details.""" + See https://github.com/EleutherAI/gpt-neox/pull/481 for further details.""") # set all config values. hf_config = GPTNeoXConfig(