
Commit

refactor: remove unfinished impl of mqa
jon-tow committed May 20, 2023
1 parent 2968bcc commit a896008
Showing 2 changed files with 32 additions and 233 deletions.
259 changes: 32 additions & 227 deletions megatron/model/transformer.py
@@ -30,7 +30,6 @@
from megatron.model.positional_embeddings import (
RotaryEmbedding,
apply_rotary_pos_emb,
apply_rotary_pos_emb_torch,
AliBi,
)
from megatron.model.fused_bias_dropout import (
@@ -203,7 +202,6 @@ def __init__(
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
self.use_cache = use_cache
self.use_multi_query_attention = neox_args.multi_query_attention
self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
@@ -222,52 +220,13 @@
)
self.pos_emb = neox_args.pos_emb

# print_rank_0(f"{'=' * 50}")
# print_rank_0(f"HIDDEN_SIZE: {neox_args.hidden_size}")
# print_rank_0(f"DERIVED HIDDEN: {self.hidden_size_per_attention_head * neox_args.num_attention_heads}")
# print_rank_0(f"WORLD SIZE: {world_size}")
# print_rank_0(f"{'=' * 50}")

# TODO: MIXED MULTI-QUERY ATTENTION
# if not self.use_multi_query_attention:
# projection_size = 3 * neox_args.hidden_size
# else:
# # Only query uses `num_attention_heads`, key and value use 1.
# head_dim = self.hidden_size_per_attention_head
# num_heads = neox_args.num_attention_heads
# projection_size = (num_heads * head_dim) + (2 * head_dim)
# # projection_size = 3 * neox_args.hidden_size
# self.query_key_value = mpu.ColumnParallelLinear(
# neox_args=neox_args,
# input_size=neox_args.hidden_size,
# output_size=projection_size,
# gather_output=False,
# init_method=init_method,
# )

if self.use_multi_query_attention:
self.key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=2 * self.hidden_size_per_attention_head, # 2 * head_dim
gather_output=False,
init_method=init_method,
)
self.query = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
gather_output=False,
init_method=init_method,
)
else:
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
)
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
)

coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -329,10 +288,6 @@ def __init__(
raise ValueError(
"Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead."
)
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func_cuda,
)

else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
@@ -360,98 +315,6 @@ def __init__(
parallel_output=parallel_output,
)

def multi_query_attention(
self, query_layer, key_layer, value_layer, layer_past, attention_mask
):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================

# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)

print_rank_0(f"\n{'='*40}\noqkv shape: {output_size}\n{'='*40}\n")
print_rank_0(f"\n{'='*40}\nquery_layer shape: {query_layer.shape}\n{'='*40}\n")
# h = 3072
# sq = 4096 (seq-length)
# b = 8 (mbs * num_devices)
# np = 8 (num attention heads / num_devices)
# hn = 192 (hidden size / num attention heads)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
output_size[2], output_size[0] * output_size[1], -1
)
print_rank_0(f"\n{'='*40}\nMERGED query_layer shape: {query_layer.shape}\n{'='*40}\n")

# [sk, b, 1, hn] -> [b, hn, sk]
print_rank_0(f"\n{'='*40}\nkey_layer shape: {key_layer.shape}\n{'='*40}\n")
key_layer = key_layer.squeeze(2).permute(1, 2, 0)
print_rank_0(f"\n{'='*40}\nMERGED key_layer shape: {key_layer.shape}\n{'='*40}\n")

# preallocating result tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)

matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer,
key_layer,
beta=0.0,
alpha=(1.0 / self.norm_factor),
)

# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
# query_layer = query_layer.transpose(0, 1).reshape(
# output_size[0] * output_size[2], 1, output_size[1], -1
# )
# key_layer = key_layer.transpose(0, 1).reshape(
# output_size[0] * output_size[3], 1, output_size[1], -1
# )
# value_layer = value_layer.transpose(0, 1).reshape(
# output_size[0] * output_size[3], 1, output_size[1], -1
# )

# Combined q/k/v into [b * s, 3, np, hn].
# qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)

# batch_size = output_size[0]
# seqlen = output_size[2]
# max_s = seqlen

# cu_seqlens = torch.arange(
# 0,
# (batch_size + 1) * seqlen,
# step=seqlen,
# dtype=torch.int32,
# device=qkv.device,
# )
# output = self.flash_attention_function(
# qkv,
# cu_seqlens,
# max_s,
# self.dropout_p if self.training else 0.0,
# softmax_scale=None,
# causal=True,
# )
# # [b * sq, np, hn] -> [b, sq, np, hn]
# matmul_result = output.view(
# output_size[0], output_size[2], output.shape[1], output.shape[2]
# )
# # [b, sq, np, hn] -> [b, np, sq, hn]
# matmul_result = matmul_result.transpose(1, 2)

return matmul_result

def attention(
self, query_layer, key_layer, value_layer, layer_past, attention_mask
):
@@ -626,86 +489,28 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
)

def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]

# print_rank_0(f"hidden_states: {hidden_states.shape}")

# =====================
# Query, Key, and Value
# =====================

if self.use_multi_query_attention:

# TODO: MIXED MULTI QUERY ATTENTION
# Attention heads [sq, b, dim] --> [sq, b, (2 * hn + h) / p]
# mixed_x_layer, _ = self.query_key_value(hidden_states)
# print_rank_0(f"{'=' * 50}")
# print_rank_0(f"mixed_x_layer: {mixed_x_layer.shape}")
# print_rank_0(f"dim / num_heads: {self.hidden_size_per_attention_head}")
# print_rank_0(f"num_heads / shard: {self.num_attention_heads_per_partition}")
# print_rank_0(f"{'=' * 50}")

# # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
# new_tensor_shape = mixed_x_layer.size()[:-1] + (
# self.num_attention_heads_per_partition,
# 3 * self.hidden_size_per_attention_head,
# )
# mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# # Get the size and dimension.
# last_dim = mixed_x_layer.dim() - 1
# last_dim_size = mpu.divide(mixed_x_layer.size()[last_dim], num_partitions)
# # Split.
# query_layer, key_layer, value_layer = torch.split(
# mixed_x_layer, last_dim_size, dim=last_dim)


# UN-MIXED MULTI QUERY ATTENTION

# Attention heads [sk, b, h] --> [sk, b, (1 * 2 * hn)]
mixed_kv_layer, _ = self.key_value(hidden_states)
print_rank_0(f"\n{'='*40}\nmixed_kv_layer: {mixed_kv_layer.shape}\n{'='*40}\n")

# [sk, b, (2 * hn)] --> [sk, b, 1, 2 * hn]
mixed_kv_layer = mixed_kv_layer.unsqueeze(2)

# new_tensor_shape = mixed_kv_layer.size()[:-1] + (
# 1,
# 2 * self.hidden_size_per_attention_head,
# )
# mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
print_rank_0(f"\n{'='*40}\nnew mixed_kv_layer: {mixed_kv_layer.shape}\n{'='*40}\n")

# [sk, b, 1, 2 * hn] --> 2 [sk, b, 1, hn]
key_layer, value_layer = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
print_rank_0(f"\n{'='*40}\nkey_layer: {key_layer.shape}\n")
print_rank_0(f"value_layer: {value_layer.shape}\n{'='*40}\n")

# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
print_rank_0(f"\n{'='*40}\nquery_layer: {query_layer.shape}\n{'='*40}\n")
else:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
mixed_x_layer, 3
)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
mixed_x_layer, 3
)

# =====================
# Rotary
# =====================

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
@@ -721,18 +526,18 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
else:
# full rotary
query_rot, key_rot = query_layer, key_layer
apply_rotary_fn = (
apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
)

apply_rotary_fn = apply_rotary_pos_emb

seq_len = key_layer.shape[0]
offset = 0
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset

cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(
query_rot, key_rot, cos, sin, offset=offset
query_rot, key_rot, cos, sin, offset=offset,
)

if exists(self.rotary_ndims):
@@ -753,11 +558,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
if self.use_cache:
present = torch.stack((key_layer, value_layer))

if self.use_multi_query_attention:
context_layer = self.multi_query_attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
elif self.use_flash_attention:
# ==================================
# Attention
# ==================================

if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
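For context on what was dropped: multi-query attention (MQA, https://arxiv.org/abs/1911.02150) keeps a full set of query heads but shares a single key/value head across all of them. The sketch below is a minimal, self-contained illustration of that score and context computation in plain PyTorch, assuming the [sq, b, np, hn] query and [sk, b, 1, hn] key/value layouts used by the removed code. The tensor names and sizes are illustrative; this is not the removed (unfinished) GPT-NeoX implementation, and causal masking is omitted for brevity.

```python
import math
import torch

# Illustrative sizes (assumptions, not taken from the commit).
sq = sk = 16   # query / key sequence lengths
b = 2          # batch size
np_ = 8        # query heads per partition
hn = 64        # head dim

query_layer = torch.randn(sq, b, np_, hn)  # [sq, b, np, hn]
key_layer = torch.randn(sk, b, 1, hn)      # [sk, b, 1, hn] -- one shared key head
value_layer = torch.randn(sk, b, 1, hn)    # [sk, b, 1, hn] -- one shared value head

# Fold the query heads into the sequence dim so a single batched matmul
# against the shared key head covers every head: [sq, b, np, hn] -> [b, np * sq, hn].
q = query_layer.permute(1, 2, 0, 3).reshape(b, np_ * sq, hn)
# [sk, b, 1, hn] -> [b, hn, sk]
k = key_layer.squeeze(2).permute(1, 2, 0)

# Raw attention scores, scaled by 1/sqrt(hn): [b, np * sq, sk].
scores = torch.baddbmm(
    torch.empty(b, np_ * sq, sk, dtype=q.dtype),
    q,
    k,
    beta=0.0,
    alpha=1.0 / math.sqrt(hn),
)
probs = torch.softmax(scores, dim=-1)

# Context: every query head reuses the same value head.
v = value_layer.squeeze(2).permute(1, 0, 2)          # [b, sk, hn]
context = torch.bmm(probs, v).view(b, np_, sq, hn)   # [b, np, sq, hn]
```

Relative to the standard path kept by this commit, the key and value projections shrink from np * hn to hn columns each, which is the main memory saving MQA targets during inference.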
6 changes: 0 additions & 6 deletions megatron/neox_arguments/neox_args.py
@@ -167,12 +167,6 @@ class NeoXArgsModel(NeoXArgsTemplate):
Disables weight tying between embedding weights and final Linear layer
"""

multi_query_attention: bool = False
"""
Enables multi-query attention.
See: https://arxiv.org/abs/1911.02150
"""

attention_config: list = None
"""
Attention configuration for gpt-neox
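The removed flag above would have selected between the fused query_key_value projection and the separate query / key_value projections shown in the transformer.py diff. Using the sizes noted in the removed debug comments (hidden size 3072, head dim 192), a rough comparison of the two projection widths (variable names are illustrative, not from the repository):

```python
# Sizes taken from the removed debug comments in transformer.py.
hidden_size = 3072
head_dim = 192
num_heads = hidden_size // head_dim  # 16

# Standard attention (kept): one fused QKV projection.
qkv_width = 3 * hidden_size                       # 9216 output features

# Multi-query attention (removed): per-head queries plus one shared K/V head,
# mirroring the removed `query` (hidden_size) and `key_value` (2 * head_dim) layers.
mqa_width = num_heads * head_dim + 2 * head_dim   # 3072 + 384 = 3456 output features

print(qkv_width, mqa_width)  # 9216 3456
```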
