Commit

Merge branch 'main' into deepspeed_main
Quentin-Anthony committed Dec 17, 2022
2 parents 0535bfb + 8d1c5c8 commit f6a8f5d
Showing 10 changed files with 372 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
hooks:
- id: codespell
args: [
'--ignore-words-list=reord', # Word used in error messages that need rewording
'--ignore-words-list=reord,dout', # Word used in error messages that need rewording
--check-filenames,
--check-hidden,
]
5 changes: 5 additions & 0 deletions README.md
@@ -99,6 +99,11 @@ from the repository root.
</aside>


### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
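
For illustration, enabling it for every layer might look like the line below; the `attention_config` key and the `"all"` shorthand are assumptions, so verify the exact schema against the files in `./configs/` before copying it.

# Hypothetical excerpt from a NeoX .yml config -- confirm the key name and value format against ./configs/.
"attention_config": [[["flash"], "all"]]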


### Containerized Setup

We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on Docker Hub at `leogao2/gpt-neox`.
40 changes: 37 additions & 3 deletions configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 166c5b6
Default = 12f6f76

current git hash of repository

@@ -798,6 +798,14 @@ Misc. Arguments



- **save_iters**: list

Default = None

Set during training



- **global_num_gpus**: int

Default = None
@@ -1133,11 +1141,37 @@ Training Arguments



- **save_interval**: int
- **checkpoint_scale**: typing.Literal['linear', 'log']

Default = linear

How the steps at which checkpoints are saved should be spaced. "linear" means one checkpoint is saved at every multiple of `checkpoint-factor`,
while "log" means the gap between consecutive checkpoints is multiplied by `checkpoint-factor` after each save, starting from step 1 (see the sketch following `checkpoint_factor` below).



- **checkpoint_factor**: int

Default = None

Acts as a multiplier on either the "log" or "linear" checkpoint spacing.

With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at
steps [20, 40, 60, 80, 100].

With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at
steps [1, 2, 4, 8, 16, 32, 64, 100].

Note that the last checkpoint step is always saved.
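
The two spacing modes can be reconstructed in a short standalone sketch (illustrative only, not code from this repository):

# Reconstruction of the checkpoint spacing described above; illustrative only,
# not the NeoX implementation.
def checkpoint_steps(checkpoint_scale, checkpoint_factor, train_iters):
    if checkpoint_scale == "linear":
        steps = list(range(checkpoint_factor, train_iters + 1, checkpoint_factor))
    else:  # "log"
        steps, step = [], 1
        while step < train_iters:
            steps.append(step)
            step *= checkpoint_factor
    if not steps or steps[-1] != train_iters:
        steps.append(train_iters)  # the final training step is always checkpointed
    return steps

print(checkpoint_steps("linear", 20, 100))  # [20, 40, 60, 80, 100]
print(checkpoint_steps("log", 2, 100))      # [1, 2, 4, 8, 16, 32, 64, 100]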



- **extra_save_iters**: list

Default = None

Number of iterations between checkpoint saves.
Additional iterations when a checkpoint should be saved.
Must be a list of ints or `None`.



185 changes: 185 additions & 0 deletions megatron/model/flash_attention.py
@@ -0,0 +1,185 @@
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
num_splits=0,
generator=None,
):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
Don't change it unless you know what you're doing.
"""
softmax_lse, *rest = flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
return_softmax,
num_splits,
generator,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return out, softmax_lse, S_dmask


def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
num_splits=0,
generator=None,
):
"""
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
as num_splits=3), so effectively the choices are 0, 1, and 2.
This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
"""
_, _, _, softmax_d = flash_attn_cuda.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
num_splits,
generator,
)
# if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
torch.empty_like(qkv[:, 0]),
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
dout,
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
out,
softmax_lse,
dqkv[:, 0],
dqkv[:, 1],
dqkv[:, 2],
cu_seqlens,
cu_seqlens,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
return FlashAttnQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)
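
As a usage sketch for the wrapper above (not part of the diff): it assumes a CUDA device with the `flash_attn_cuda` extension built, fp16 inputs, and equal-length sequences packed back to back.

# Illustrative call into flash_attn_unpadded_qkvpacked_func; assumes flash_attn_cuda
# is installed and a CUDA device is available.
import torch

from megatron.model.flash_attention import flash_attn_unpadded_qkvpacked_func

batch, seqlen, n_heads, head_dim = 2, 1024, 16, 64

# Packed q/k/v for all tokens: [batch * seqlen, 3, n_heads, head_dim], fp16 on GPU.
qkv = torch.randn(
    batch * seqlen, 3, n_heads, head_dim, dtype=torch.float16, device="cuda"
)
# Cumulative sequence lengths marking where each sequence starts: [0, seqlen, 2 * seqlen].
cu_seqlens = torch.arange(
    0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
)
out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen=seqlen, dropout_p=0.0, causal=True
)
print(out.shape)  # [batch * seqlen, n_heads, head_dim]
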
86 changes: 75 additions & 11 deletions megatron/model/transformer.py
@@ -265,7 +265,8 @@ def __init__(
self.rotary_emb = None

self.attention_type = neox_args.attention_config[layer_number]
self.sparse = self.attention_type != "global"
self.use_flash_attention = self.attention_type == "flash"
self.sparse = self.attention_type != "global" and not self.use_flash_attention
if self.sparse:
self.sparse_attn = configure_sparse_attention(
neox_args,
@@ -274,19 +275,31 @@
mpu=mpu,
)
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=(coeff / neox_args.mup_attn_temp) if coeff is not None else None, # TODO: deepspeed sparse attention scaling patch?
)
if self.use_flash_attention:
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func,
)

self.flash_attention_function = flash_attn_unpadded_qkvpacked_func
if self.pos_emb == "alibi":
raise ValueError(
"Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead."
)
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=coeff,
)

# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = nn.Dropout(neox_args.attention_dropout)
self.dropout_p = neox_args.attention_dropout
self.attention_dropout = nn.Dropout(self.dropout_p)

# Output.
self.dense = mpu.RowParallelLinear(
@@ -402,6 +415,55 @@ def attention(
context_layer = context_layer.view(*output_size)
return context_layer

def flash_attention(self, query_layer, key_layer, value_layer):
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
query_layer = query_layer.transpose(0, 1).reshape(
output_size[0] * output_size[2], 1, output_size[1], -1
)
key_layer = key_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)
value_layer = value_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)

# Combine q, k, v into [b * s, 3, np, hn].
qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)

batch_size = output_size[0]
seqlen = output_size[2]
max_s = seqlen
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=qkv.device,
)
output = self.flash_attention_function(
qkv,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=None,
causal=True,
)
# [b * sq, np, hn] -> [b, sq, np, hn]
matmul_result = output.view(
output_size[0], output_size[2], output.shape[1], output.shape[2]
)
# [b, sq, np, hn] -> [b, np, sq, hn]
matmul_result = matmul_result.transpose(1, 2)

return matmul_result

def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
# TODO: sparse attn dropout?
# TODO: pad to block size
@@ -489,7 +551,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
if self.use_cache:
present = torch.stack((key_layer, value_layer))

if not self.sparse:
if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
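
To make the tensor bookkeeping in `flash_attention` above concrete, the following standalone sketch reproduces the same packing with plain tensor ops only (no flash kernels are called, and the toy shapes are arbitrary):

# Standalone illustration of the [s, b, np, hn] -> [b * s, 3, np, hn] packing done in
# flash_attention above; runs on CPU and does not touch any flash kernels.
import torch

s, b, np_, hn = 8, 2, 4, 16  # sequence length, batch, heads per partition, head dim
q = torch.randn(s, b, np_, hn)
k = torch.randn(s, b, np_, hn)
v = torch.randn(s, b, np_, hn)

# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
q_p = q.transpose(0, 1).reshape(b * s, 1, np_, hn)
k_p = k.transpose(0, 1).reshape(b * s, 1, np_, hn)
v_p = v.transpose(0, 1).reshape(b * s, 1, np_, hn)

# Pack into the [b * s, 3, np, hn] layout the flash kernel expects.
qkv = torch.concat([q_p, k_p, v_p], dim=1)

# One boundary per sequence: [0, s, 2*s, ..., b*s].
cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32)
print(qkv.shape, cu_seqlens.tolist())  # torch.Size([16, 3, 4, 16]) [0, 8, 16]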
