Bf16 #330 (Merged)
24 commits merged on Jun 22, 2021
Changes from 1 commit
Merge branch 'main' into _bf16
StellaAthena authored Jun 20, 2021
commit a19b8aac6d2cab97019d7156bc4ce0213b458ebb
14 changes: 3 additions & 11 deletions megatron/model/fused_softmax.py
@@ -108,25 +108,17 @@ def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
data_size = input.size()
assert input.dim() == 4
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
attn_batch_size = data_size[0] * data_size[1]

# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = 16 < key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

custom_kernel_constraint = 16 < key_seq_len <= 2048 and query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

# invoke custom kernel
if self.input_in_float16 and data_size[-1] <= 2048 and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \
query_seq_len == key_seq_len:
if self.input_in_float16 and data_size[-1] <= 2048 and mask is not None and (self.upper_triang_mask_fusion or self.general_mask_fusion) and query_seq_len == key_seq_len:
assert custom_kernel_constraint
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
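For reference, a minimal sketch of the eligibility check this hunk converges on: the fused scaled-masked-softmax kernel is only taken for half-precision inputs with a mask, equal query/key lengths, and the dimension constraints spelled out above. The helper name can_use_fused_kernel and its flat argument list are illustrative, not part of the repository.

def can_use_fused_kernel(query_seq_len, key_seq_len, batch, heads,
                         input_in_float16, mask_fusion_enabled, mask=None):
    """Mirror the constraints the fused scaled-masked softmax path requires."""
    attn_batch_size = batch * heads
    # warp-based / upper-triangular optimizations need these dimension constraints
    custom_kernel_constraint = (
        16 < key_seq_len <= 2048
        and query_seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    return (
        input_in_float16
        and key_seq_len <= 2048
        and mask is not None
        and mask_fusion_enabled
        and query_seq_len == key_seq_len
        and custom_kernel_constraint
    )

# e.g. a [b=8, np=16, sq=1024, sk=1024] fp16 attention-score tensor qualifies:
assert can_use_fused_kernel(1024, 1024, 8, 16, True, True, mask=object())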
4 changes: 1 addition & 3 deletions megatron/model/gpt2_model.py
@@ -22,7 +22,7 @@
from collections import defaultdict

from functools import partial
from megatron.model.utils import Lambda, SequentialWrapper
from megatron.model.utils import Lambda, SequentialWrapper, _set_get_key_value
from megatron.model.norms import get_norm
from megatron.model.init_functions import get_init_methods

@@ -201,8 +201,6 @@ def init_specs(self):

self.specs.append(_post_transformer_block)

# Final layernorm after transformer layers
norm, eps = get_norm(self.neox_args)
# NormPipe is a helper class to pass presents through to the output when doing inference
norm, eps = get_norm(self.neox_args)
self.specs.append(
38 changes: 18 additions & 20 deletions megatron/model/norms.py
@@ -1,4 +1,5 @@
import torch
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

# Attempt to import FusedLayerNorm from Apex
try:
@@ -12,8 +13,22 @@
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm as ApexLayerNorm


from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
def get_norm(neox_args):
if neox_args.norm == "rmsnorm":
norm = RMSNorm
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
norm = LayerNorm
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
elif neox_args.norm == "apexlayernorm":
eps = neox_args.layernorm_epsilon
norm = ApexLayerNorm
else:
raise ValueError(f"norm {neox_args.norm} not recognized")
return norm, eps

class RMSNorm(torch.nn.Module):
def __init__(self, dim, p=-1., eps=1e-8, bias=False):
@@ -67,21 +82,4 @@ def __init__(self, dim, eps=1e-5):

def forward(self, x):
n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
return x / n * self.g

def get_norm(neox_args):
if neox_args.norm == "rmsnorm":
norm = RMSNorm
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
norm = LayerNorm
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
elif neox_args.norm == "apexlayernorm":
eps = neox_args.layernorm_epsilon
norm = ApexLayerNorm
else:
raise ValueError(f"norm {neox_args.norm} not recognized")
return norm, eps
return x / n * self.g
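A short usage sketch of the relocated get_norm selector, assuming the GPT-NeoX package is importable; the SimpleNamespace below is a stand-in for the real neox_args object and only carries the two fields the "rmsnorm" branch reads.

import torch
from types import SimpleNamespace
from megatron.model.norms import get_norm

# stand-in for neox_args; only the fields read by the "rmsnorm" branch are set
neox_args = SimpleNamespace(norm="rmsnorm", rms_norm_epsilon=1e-8)
norm_cls, eps = get_norm(neox_args)   # -> (RMSNorm, 1e-8)

final_norm = norm_cls(1024, eps=eps)  # RMSNorm over a hidden size of 1024
x = torch.randn(2, 16, 1024)
print(final_norm(x).shape)            # torch.Size([2, 16, 1024])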
9 changes: 8 additions & 1 deletion megatron/model/transformer.py
@@ -346,7 +346,14 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# full rotary
query_rot, key_rot = query_layer, key_layer
apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
query_layer, key_layer = apply_rotary_fn(query_rot, key_rot, cos, sin)

seq_len = key_layer.shape[0]
offset = 0
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)

if exists(self.rotary_ndims):
query_layer = torch.cat((query_layer, query_pass), dim=-1)
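The added lines above compute a decoding offset so that cached tokens keep their original rotary positions during incremental generation. Below is a self-contained sketch of just that offset/seq_len bookkeeping, following the [sq, b, np, hn] layout used in this file; rotary_offset_and_len is an illustrative helper, not a function in the repo.

import torch

def rotary_offset_and_len(key_layer: torch.Tensor, layer_past=None):
    seq_len = key_layer.shape[0]
    offset = 0
    if layer_past is not None and layer_past.numel() > 0:
        offset = layer_past[0].shape[0]   # number of tokens already cached
        seq_len += offset                 # total sequence length seen so far
    return offset, seq_len

# e.g. decoding 1 new token with 12 cached positions:
past = torch.zeros(2, 12, 1, 8, 64)         # stacked past keys/values: [2, past_sq, b, np, hn]
new_keys = torch.zeros(1, 1, 8, 64)         # [sq=1, b, np, hn]
print(rotary_offset_and_len(new_keys, past))   # (12, 13)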
3 changes: 1 addition & 2 deletions megatron/model/utils.py
@@ -24,7 +24,7 @@
from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig, FixedSparsityConfig, \
BigBirdSparsityConfig, BSLongformerSparsityConfig
from deepspeed.ops.sparse_attention.sparsity_config import LocalSlidingWindowSparsityConfig
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm, ApexLayerNorm


def get_params_for_weight_decay_optimization(module, neox_args):
@@ -33,7 +33,6 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"""
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm, ApexLayerNorm
for module_ in module.modules():
if any([isinstance(module_, ApexLayerNorm), isinstance(module_, LayerNorm), isinstance(module_, RMSNorm), isinstance(module_, ScaleNorm)]) or \
(neox_args.weight_decay == 0.0): # also include all parameters here if no weight decay is being done
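A condensed, hedged sketch of the grouping done in get_params_for_weight_decay_optimization: parameters of norm layers (and all biases) go into a zero-weight-decay group, everything else keeps weight decay. build_param_groups is illustrative and only checks torch.nn.LayerNorm here, whereas the repo also checks RMSNorm, ScaleNorm, and ApexLayerNorm.

import torch

def build_param_groups(model: torch.nn.Module, weight_decay: float):
    norm_types = (torch.nn.LayerNorm,)  # the repo also matches RMSNorm/ScaleNorm/ApexLayerNorm
    decay, no_decay = {'params': []}, {'params': [], 'weight_decay': 0.0}
    for module in model.modules():
        if isinstance(module, norm_types) or weight_decay == 0.0:
            # norm parameters (or everything, if weight decay is disabled) get no decay
            no_decay['params'].extend(p for p in module._parameters.values() if p is not None)
        else:
            decay['params'].extend(
                p for n, p in module._parameters.items() if p is not None and n != 'bias')
            no_decay['params'].extend(
                p for n, p in module._parameters.items() if p is not None and n == 'bias')
    return decay, no_decay

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
wd, no_wd = build_param_groups(model, weight_decay=0.01)
print(len(wd['params']), len(no_wd['params']))   # 1 (Linear weight) vs 3 (bias + LN weight/bias)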
You are viewing a condensed version of this merge commit.