Skip to content

Commit

Permalink
fused layernorm (#1105)
Browse files Browse the repository at this point in the history
* Add simple util for CUDA timings

* Add fused layernorm kernel from Megatron

Closes #952

* change default fused layernorm to false

* Update test_setup.yml

* Update test_train_base.yml

---------

Co-authored-by: Yang Zhang <[email protected]>
Co-authored-by: jahatef <[email protected]>
Co-authored-by: Jacob Hatef <[email protected]>
  • Loading branch information
4 people committed Jan 26, 2024
1 parent 7a8fa2f commit 3d8fec0
Show file tree
Hide file tree
Showing 23 changed files with 285 additions and 1 deletion.
1 change: 1 addition & 0 deletions configs/1-3B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/125M-json.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

"init_method": "small_init",
"output_layer_init_method": "wang_init",
Expand Down
1 change: 1 addition & 0 deletions configs/125M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/13B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/175B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/19M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/2-7B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/20B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": true,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/350M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/49M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/6-7B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/760M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/800M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
Expand Down
1 change: 1 addition & 0 deletions configs/bf16_125M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# optimizer settings
Expand Down
1 change: 1 addition & 0 deletions configs/bnb_125M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# optimizer settings
Expand Down
1 change: 1 addition & 0 deletions configs/slurm_125M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": true,
"rope_fusion": false,
"layernorm_fusion": false,
"optimizer": {
"type": "Adam",
"params": {
Expand Down
51 changes: 51 additions & 0 deletions megatron/devutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch.cuda


class Metric:
"""
Dumb utility to collect and report average wall-time metrics.
"""

def __init__(self, label):
self.label = label
self.measurements = []

def collect(self, measurement):
self.measurements.append(measurement)

def get_measurements(self):
return self.measurements[:]

def report(self):
print(
self.label,
torch.quantile(torch.tensor(self.measurements), torch.arange(10) / 10.0),
)


def monitor_method_cuda_wall_times(metric, obj, methodname):
"""
Measure timings for a method on an object or class.
For instance:
>>> metric = Metric('!LNORM')
>>> monitor_method_wall_times(metric, LayerNorm, 'forward')
"""
oldmeth = getattr(obj, methodname)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

def newmeth(*args, **kw):
start_event.record()
try:
return oldmeth(*args, **kw)
finally:
end_event.record()
torch.cuda.synchronize()
elapsed = start_event.elapsed_time(end_event)
metric.collect(elapsed)
metric.report()

setattr(obj, methodname, newmeth)
150 changes: 150 additions & 0 deletions megatron/model/fused_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied from NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from torch.nn import functional as F
import inspect

from megatron.utils import make_viewless_tensor

try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

HAVE_PERSIST_LAYER_NORM = True
except:
HAVE_PERSIST_LAYER_NORM = False

from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction


global fused_layer_norm_cuda
fused_layer_norm_cuda = None


class MixedFusedLayerNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-5,
no_persist_layer_norm=True,
sequence_parallel=False,
apply_layernorm_1p=False,
mem_efficient_ln=True,
):
super(MixedFusedLayerNorm, self).__init__()

self.apply_layernorm_1p = apply_layernorm_1p
self.mem_efficient_ln = mem_efficient_ln

global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
# kernel.
persist_ln_hidden_sizes = [
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
if (
normalized_shape not in persist_ln_hidden_sizes
or not HAVE_PERSIST_LAYER_NORM
):
no_persist_layer_norm = True

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel

# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, "sequence_parallel", self.sequence_parallel)
setattr(self.bias, "sequence_parallel", self.sequence_parallel)

def reset_parameters(self):

if self.apply_layernorm_1p:
init.zeros_(self.weight)
init.zeros_(self.bias)
else:
init.ones_(self.weight)
init.zeros_(self.bias)

def forward(self, input):

weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
# CPU path is here for unittest sake.
if not input.is_cuda:
print(
"WARNING! The input of FusedLayerNorm should be on the GPU."
"This warning should only be triggered in the FusedLayerNorm unit tests."
)
return F.layer_norm(
input, self.normalized_shape, weight, self.bias, self.eps
)

if self.no_persist_layer_norm:
# Apex does not have versions yet (https://github.com/NVIDIA/apex/pull/1648), so we need to inspect
# the function manually on whether the extra arg introduced in https://github.com/NVIDIA/apex/pull/1715 exists yet
if (
"memory_efficient"
in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args
):
return FusedLayerNormAffineFunction.apply(
input,
weight,
self.bias,
self.normalized_shape,
self.eps,
self.mem_efficient_ln,
)
else:
return FusedLayerNormAffineFunction.apply(
input, weight, self.bias, self.normalized_shape, self.eps
)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(
inp=output, requires_grad=input.requires_grad, keep_graph=True
)

return output
3 changes: 2 additions & 1 deletion megatron/model/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
from torch.nn import LayerNorm as LayerNorm
from .fused_layer_norm import MixedFusedLayerNorm


def get_norm(neox_args):
Expand All @@ -22,7 +23,7 @@ def get_norm(neox_args):
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
norm = LayerNorm
norm = MixedFusedLayerNorm if neox_args.layernorm_fusion else LayerNorm
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
"""

layernorm_fusion: bool = False
"""
Use fused layer norm kernel (if `norm` is `layernorm`).
"""

use_qk_layernorm: bool = False
"""
Use QK Normalization
Expand Down
Loading

0 comments on commit 3d8fec0

Please sign in to comment.