llama support flash-attention (CoinCheung#7)
CoinCheung authored Jul 31, 2023
1 parent a72f2d9 commit 5c287b6
Showing 6 changed files with 121 additions and 55 deletions.
54 changes: 11 additions & 43 deletions README.md
@@ -3,42 +3,6 @@

This project offers no theoretical innovation, proposes no new way of writing 茴香豆, and invents no new tools. It simply provides a concise, easily extensible codebase built on existing methods and libraries, with which you can train a 7B model on a server with 8 V100s (full fine-tuning of all model parameters), train larger models on more GPUs, or train across multiple machines; it is faster than the ZeRO-3 approach and supports longer input sequence lengths.

Below are the training speeds measured on my 8 V100s, with `micro_batch_size=1` and `global_batch_size=128`; the numbers are the throughput (samples/s) reported in the log after training for 20 steps.


<table class="center" style="margin-left: auto; margin-right: auto"><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<tr>
<td align="center"><sup><sub>max_seq_len</sub></sup></td>
<td align="center"><sup><sub>256</sub></sup></td>
<td align="center"><sup><sub>384</sub></sup></td>
<td align="center"><sup><sub>512</sub></sup></td>
<td align="center"><sup><sub>768</sub></sup></td>
<td align="center"><sup><sub>1024</sub></sup></td>
<td align="center"><sup><sub>1280</sub></sup></td>
</tr>
<tr>
<td align="center"><sup><sub>bloom-7b</sub></sup></td>
<td align="center"><sup><sub>20.29</sub></sup></td>
<td align="center"><sup><sub>15.83</sub></sup></td>
<td align="center"><sup><sub>12.99</sub></sup></td>
<td align="center"><sup><sub>9.21</sub></sup></td>
<td align="center"><sup><sub>7.03</sub></sup></td>
<td align="center"><sup><sub>oom</sub></sup></td>
</tr>
<tr>
<td align="center"><sup><sub>llama-7b</sub></sup></td>
<td align="center"><sup><sub>22.32</sub></sup></td>
<td align="center"><sup><sub>18.35</sub></sup></td>
<td align="center"><sup><sub>14.88</sub></sup></td>
<td align="center"><sup><sub>10.40</sub></sup></td>
<td align="center"><sup><sub>8.13</sub></sup></td>
<td align="center"><sup><sub>oom</sub></sup></td>
</tr>
</tr>
<!-- END RPN TABLE -->
</tbody></table>


### My environment
@@ -53,6 +17,7 @@
* torch==1.13.1
* sentencepiece
* protobuf==3.20.0 (python pip install)
* flash_attn==2.0.2


### Training
@@ -230,29 +195,32 @@ use_grad_ckpt: true
<!-- END RPN TABLE -->
</tbody></table>

(2) Using ZeRO offload
(2) Using flash-attention
flash-attention speeds up the qkv attention computation and also saves memory; everyone who has used it says it is great. If your platform can run flash-attention, set the following in the config file `configs/ds_config_pp.yml`:
```yaml
use_flash_attn: true
```
As of August 2023, flash-attention does not yet support the V100, and in this project it is only supported for the llama model, not for bloom.
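If you are unsure whether your GPU qualifies, a minimal standalone check such as the one below can be run before flipping the flag. It is not part of this repo; the compute-capability threshold of 8.0 (Ampere or newer) is an assumption based on the flash_attn 2.x requirements at the time of this commit.
```python
import torch

try:
    # same optional import that models/llama.py uses
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def can_use_flash_attn() -> bool:
    """Rough check: flash_attn installed and a sufficiently new GPU present."""
    if not (HAS_FLASH_ATTN and torch.cuda.is_available()):
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 8  # assumed Ampere+ requirement; a V100 (sm70) returns False

print('ok to set use_flash_attn: true ->', can_use_flash_attn())
```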
(3) Using ZeRO offload
This means that during training, part of the model parameters and optimizer states are moved from GPU memory to CPU memory and only moved back to the GPU when they are needed. It introduces communication latency between CPU and GPU, so training takes longer; in other words it trades some speed for extra memory. To enable it, add the following to `configs/ds_config_pp.yml`:
```yaml
zero_allow_untested_optimizer: true
zero_force_ds_cpu_optimizer: false
zero_optimization:
  stage: 1
  offload_param:
    device: cpu
    pin_memory: true
  offload_optimizer:
    device: cpu
    pin_memory: true
```

(3) Using other optimizers
(4) Using other optimizers
One drawback of AdamW is that it keeps param/m/v for every parameter, i.e. it needs three times the parameter storage. The Lion optimizer does not have this problem; I have verified on my server that with Lion, llama-13b can be trained on 8 V100s (max_seq_len=128). To try this optimizer, change the optimizer section of `configs/ds_config_pp.yml` to:
```yml
optimizer:
  type: Lion
  params:
    lr: 2.0e-4
    betas: [0.95, 0.98]
    betas: [0.9, 0.999]
    use_triton: true
    weight_decay: 2.0e-4
```
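To see why Lion needs only one state tensor per parameter (the momentum `m`) while AdamW keeps two (`m` and `v`), here is a minimal pure-PyTorch sketch of a single Lion update following the published update rule; it is illustrative only and is not the deepspeed/triton implementation that the config above selects.
```python
import torch

def lion_step(param, grad, m, lr=2.0e-4, betas=(0.95, 0.98), weight_decay=2.0e-4):
    """One Lion update on plain tensors; `m` is the single per-parameter state."""
    beta1, beta2 = betas
    update = (beta1 * m + (1 - beta1) * grad).sign()            # only the sign is used
    param.mul_(1 - lr * weight_decay).add_(update, alpha=-lr)   # decoupled weight decay
    m.mul_(beta2).add_(grad, alpha=1 - beta2)                   # refresh the momentum buffer
    return param, m

# toy usage
p, g, m = torch.randn(10), torch.randn(10), torch.zeros(10)
p, m = lion_step(p, g, m)
```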
1 change: 1 addition & 0 deletions configs/ds_config_pp.yml
@@ -47,3 +47,4 @@ data_path: ./dataset.json

use_grad_ckpt: false

use_flash_attn: false
21 changes: 13 additions & 8 deletions datasets.py
@@ -12,22 +12,27 @@



def truncate_and_pad_left(input_ids, att_mask, labels, max_seq_len, tokenizer):
def truncate_and_pad_func(input_ids, att_mask, labels, max_seq_len, tokenizer, side='right'):
# truncate to max_seq_len
input_ids = input_ids[:max_seq_len]
att_mask = att_mask[:max_seq_len]
labels = labels[:max_seq_len]

# pad to left
# pad
len_pad = max_seq_len - input_ids.size(0)
if len_pad > 0:
pad_token_id = tokenizer.pad_token_id
pad_inp = torch.zeros(len_pad, dtype=torch.long).fill_(pad_token_id)
pad_att = torch.zeros(len_pad, dtype=torch.long)
pad_lb = torch.zeros(len_pad, dtype=torch.long).fill_(-100) # ignore pad label
input_ids = torch.cat([pad_inp, input_ids], dim=0)
att_mask = torch.cat([pad_att, att_mask], dim=0)
labels = torch.cat([pad_lb, labels], dim=0)
if side == 'left':
input_ids = torch.cat([pad_inp, input_ids], dim=0)
att_mask = torch.cat([pad_att, att_mask], dim=0)
labels = torch.cat([pad_lb, labels], dim=0)
elif side == 'right':
input_ids = torch.cat([input_ids, pad_inp], dim=0)
att_mask = torch.cat([att_mask, pad_att], dim=0)
labels = torch.cat([labels, pad_lb], dim=0)

inputs = torch.cat([input_ids.unsqueeze(-1), att_mask.unsqueeze(-1)], dim=-1)

@@ -122,7 +127,7 @@ def parse_instruct_sample(tokenizer, ob, max_seq_len, ignore_known=False):
att_mask = torch.cat([p_attm, o_attm], dim=0)
labels = torch.cat([p_lb, o_lb], dim=0)

res = truncate_and_pad_left(input_ids, att_mask, labels, max_seq_len, tokenizer)
res = truncate_and_pad_func(input_ids, att_mask, labels, max_seq_len, tokenizer)
inputs, labels = res
labels[-1] = tokenizer.eos_token_id

@@ -163,7 +168,7 @@ def parse_conversation_sample(tokenizer, ob, max_seq_len, ignore_known=False):
att_mask = torch.cat([h_attm, ] + res_rounds[1], dim=0)
labels = torch.cat([h_lb, ] + res_rounds[2], dim=0)

res = truncate_and_pad_left(input_ids, att_mask, labels, max_seq_len, tokenizer)
res = truncate_and_pad_func(input_ids, att_mask, labels, max_seq_len, tokenizer)
inputs, labels = res
labels[-1] = tokenizer.eos_token_id

@@ -200,7 +205,7 @@ def parse_ref_qa_sample(tokenizer, ob, max_seq_len, ignore_known=False):
att_mask = torch.cat([hr_attm, ] + res_rounds[1], dim=0)
labels = torch.cat([hr_lb, ] + res_rounds[2], dim=0)

res = truncate_and_pad_left(input_ids, att_mask, labels, max_seq_len, tokenizer)
res = truncate_and_pad_func(input_ids, att_mask, labels, max_seq_len, tokenizer)
inputs, labels = res
labels[-1] = tokenizer.eos_token_id

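For context, here is a toy standalone snippet (not part of the diff) showing what the new default `side='right'` produces: the real tokens stay at the front, while pad ids, zero attention-mask entries and -100 labels are appended, so the padded positions are ignored by the loss and, under causal attention, cannot influence the real tokens.
```python
import torch

max_seq_len, pad_token_id = 6, 0
input_ids = torch.tensor([5, 7, 9])
att_mask = torch.ones(3, dtype=torch.long)
labels = torch.tensor([5, 7, 9])

len_pad = max_seq_len - input_ids.size(0)
pad_inp = torch.zeros(len_pad, dtype=torch.long).fill_(pad_token_id)
pad_att = torch.zeros(len_pad, dtype=torch.long)
pad_lb = torch.zeros(len_pad, dtype=torch.long).fill_(-100)

print(torch.cat([input_ids, pad_inp]))  # tensor([5, 7, 9, 0, 0, 0])
print(torch.cat([att_mask, pad_att]))   # tensor([1, 1, 1, 0, 0, 0])
print(torch.cat([labels, pad_lb]))      # tensor([5, 7, 9, -100, -100, -100])
```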
3 changes: 2 additions & 1 deletion models/bloom.py
@@ -177,7 +177,8 @@ def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings


def get_bloom_causal_lm_specs(config, load_path=None, grad_ckpt=False, tie_emb=True):
def get_bloom_causal_lm_specs(config, load_path=None, grad_ckpt=False,
tie_emb=True, use_flash_attn=False):
specs = []
ldpth = osp.join(load_path, 'layer_00-model_states.pt') if load_path else None
if tie_emb:
96 changes: 93 additions & 3 deletions models/llama.py
@@ -9,15 +9,103 @@
from transformers.models.llama.modeling_llama import _make_causal_mask, _expand_mask
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec

try:
from flash_attn import flash_attn_func
except ImportError:
flash_attn_func = None


class LlamaAttentionFlashAttn(LlamaAttention):

def __init__(self, config: LlamaConfig):
super().__init__(config)
self.inv_norm_factor = 1. / math.sqrt(self.head_dim)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)

## Now: qkv are [bs, num_heads, q_len, head_dim]
## flash-atten requires them to be: [bs, q_len, num_heads, head_dim]
query_states = torch.einsum('bhld->blhd', query_states)
key_states = torch.einsum('bhld->blhd', key_states)
value_states = torch.einsum('bhld->blhd', value_states)
attn_output = flash_attn_func(query_states, key_states,
value_states, dropout_p=0., softmax_scale=self.inv_norm_factor,
causal=True)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value



class LlamaDecoderLayerTupleIO(LlamaDecoderLayer):

def __init__(self, config: LlamaConfig, load_path=None, gradient_checkpointing=False):
def __init__(self, config: LlamaConfig, load_path=None,
gradient_checkpointing=False, use_flash_attn=False):
super().__init__(config)
if load_path: self.load_state_dict(torch.load(load_path))
self.gradient_checkpointing = gradient_checkpointing
if use_flash_attn: self.self_attn = LlamaAttentionFlashAttn(config=config)

def forward(self, inputs):
"""
@@ -151,7 +239,8 @@ def weight(self):


## llama does not tie weights
def get_llama_causal_lm_specs(config, load_path=None, grad_ckpt=False, tie_emb=False):
def get_llama_causal_lm_specs(config, load_path=None, grad_ckpt=False,
tie_emb=False, use_flash_attn=False):
specs = []
ldpth = osp.join(load_path, 'layer_00-model_states.pt') if load_path else None
if tie_emb:
@@ -165,7 +254,8 @@ def get_llama_causal_lm_specs(config, load_path=None, grad_ckpt=False, tie_emb=F
ldpth = None
if load_path: ldpth = osp.join(load_path, f'layer_{i:02d}-model_states.pt')
specs.append(LayerSpec(LlamaDecoderLayerTupleIO, config,
load_path=ldpth, gradient_checkpointing=grad_ckpt))
load_path=ldpth, gradient_checkpointing=grad_ckpt,
use_flash_attn=use_flash_attn))

ldpth = None
ind = config.num_hidden_layers + 1
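A quick way to convince yourself that the layout conversion in `LlamaAttentionFlashAttn` (from `[bs, num_heads, q_len, head_dim]` to `[bs, q_len, num_heads, head_dim]`) matches ordinary causal attention is a standalone comparison like the sketch below. It is not part of the commit: it assumes a GPU supported by flash_attn and uses `F.scaled_dot_product_attention` from torch>=2.0 as the reference, which is newer than the torch version pinned in the README.
```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

bs, n_heads, q_len, head_dim = 2, 8, 16, 64
q = torch.randn(bs, n_heads, q_len, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# reference: torch attention in the [bs, num_heads, q_len, head_dim] layout
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# flash-attention expects [bs, q_len, num_heads, head_dim]
out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                      dropout_p=0., causal=True)
out = out.transpose(1, 2)

print(torch.allclose(ref, out, atol=1e-2))  # loose tolerance for fp16
```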
1 change: 1 addition & 0 deletions train_ds.py
@@ -41,6 +41,7 @@ def get_model(model_path, grad_ckpt=False):
if ds_cfg['from_scratch']: kwargs['load_path'] = None
if hasattr(config, 'tie_word_embeddings'):
kwargs['tie_emb'] = config.tie_word_embeddings
kwargs['use_flash_attn'] = ds_cfg.get('use_flash_attn', False)

if re.search('llama', model_type):
specs = get_llama_causal_lm_specs(**kwargs)
