This repository has been archived by the owner on Aug 11, 2022. It is now read-only.

Commit: half precision
vpj committed May 18, 2022
1 parent 3b52543 commit c5f0a48
Showing 7 changed files with 211 additions and 68 deletions.
5 changes: 4 additions & 1 deletion src/neox/checkpoint.py
@@ -100,7 +100,10 @@ def load_checkpoint_files(files: Tuple[str, str]):
:return: the loaded parameter tensors
"""
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
return [torch.load(checkpoint_path / f) for f in files]
with monit.section('Load checkpoint'):
data = [torch.load(checkpoint_path / f) for f in files]

return data


def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
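
As a usage sketch (not part of the commit): the loader takes the pair of model-parallel shard files for one layer and now reports its timing through monit.section. The file names below are hypothetical; the real pairs come from get_checkpoint_files().

from neox.checkpoint import load_checkpoint_files

# Hypothetical shard names for a single layer; get_checkpoint_files() yields the real pairs.
files = ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

# The call is now wrapped in monit.section('Load checkpoint'), so the time spent
# reading the two shards shows up in the labml monitor output.
p1, p2 = load_checkpoint_files(files)
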
15 changes: 10 additions & 5 deletions src/neox/evaluation/half_precision.py
@@ -1,14 +1,19 @@
import torch
from labml import monit
from torch import nn

from labml import monit
from neox.evaluation import run_eval_harness
from neox.utils import load_layers
from neox.utils import LayerGenerator

if __name__ == '__main__':
layers = load_layers(None)
device = torch.device('cuda:0')
layers = list(LayerGenerator(is_clone_layers=True,
filter_layers=None,
dtype=torch.float16,
device=device
).load())

with monit.section('Sequential'):
model = nn.Sequential(*layers).half().to(torch.device('cuda:0'))
model = nn.Sequential(*layers).half().to()

print(run_eval_harness(model, 'half_precision', [], torch.device('cuda:0')))
print(run_eval_harness(model, 'half_precision', [], device))
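
One design note on the change above: LayerGenerator is already given dtype=torch.float16 and device, so every layer it yields should come out in half precision on cuda:0. Under that assumption the trailing .half().to() is redundant, and an equivalent construction is simply (a sketch, not what the commit does):

with monit.section('Sequential'):
    model = nn.Sequential(*layers)  # layers are already fp16 on cuda:0 via LayerGenerator
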
8 changes: 5 additions & 3 deletions src/neox/evaluation/pipeline_parallel.py
@@ -1,13 +1,15 @@
import fairscale
import torch
from labml import monit
from torch import nn

from labml import monit
from neox.evaluation import run_eval_harness
from neox.utils import load_layers
from neox.utils import LayerGenerator

if __name__ == '__main__':
layers = load_layers(None)
layers = list(LayerGenerator(is_clone_layers=True,
filter_layers=None,
).load())

with monit.section('Sequential'):
model = nn.Sequential(*layers)
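
The truncated diff ends before the fairscale part, but the file name and import suggest the sequential model is then split across GPUs. Below is a minimal sketch of how that could look, assuming the GPipe-style fairscale.nn.Pipe interface (balance/devices/chunks) and the balance_layers helper added to neox/utils in this commit; the chunk count and GPU count are illustrative, not taken from the commit.

from neox.utils import balance_layers

n_gpus = 4
# Split the 47 modules (embedding + 44 transformer layers + final norm + readout)
# roughly evenly, e.g. [12, 12, 12, 11] for four devices.
balance = balance_layers(len(layers), n_gpus)

with monit.section('Pipe'):
    model = fairscale.nn.Pipe(nn.Sequential(*layers),
                              balance=balance,
                              devices=list(range(n_gpus)),
                              chunks=8)
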
52 changes: 0 additions & 52 deletions src/neox/model_experiments.py

This file was deleted.

111 changes: 107 additions & 4 deletions src/neox/utils/__init__.py
@@ -10,12 +10,13 @@
* [Cache for intermediate activations (for faster inference)](cache.html)
* [Utilities for training and fine-tuning](training.html)
"""

import copy
from pathlib import Path
from typing import List, Optional, Set

import torch
from tokenizers import Tokenizer
from torch import nn

from labml import logger, monit
from labml.logger import inspect, Text
@@ -145,22 +146,124 @@ def _test_sample_tokens():
inspect(_TOKENIZER.decode(ids))


class LayerGenerator:
def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144, n_layers: int = 44, n_heads: int = 64,
filter_layers: Optional[Set] = None, *,
is_clone_layers: bool = False,
dtype: torch.dtype = torch.float,
device: torch.device = torch.device('cpu')):
"""
### Generator to create layers
The layers are generated in the same order as checkpoints.
It yields `None` when a layer is not available; we use the same layer indices as NeoX, and there are two
transformation layers that we don't need in our implementation.
:param n_vocab: is the number of tokens in the vocabulary
:param n_hidden: is the number of features in the embeddings
:param n_layers: is the number of transformer layers
:param n_heads: is the number of attention heads
:param filter_layers: are the set of layers to be used. All layers will be used if None.
This is used to test smaller versions of the model with fewer layers
:param is_clone_layers: specifies whether to copy a prepared transformer layer instead of initializing
a fresh one each time (a bit faster, since it skips weight initialization)
:param dtype: is the data type of the parameters; use `torch.float16` for half precision
:param device: is the device to create the layers on
:return: the layers as a generator
"""
self.n_vocab = n_vocab
self.n_hidden = n_hidden
self.n_layers = n_layers
self.n_heads = n_heads
self.filter_layers = filter_layers
self.is_clone_layers = is_clone_layers
self.dtype = dtype
self.device = device

def _prepare_layer(self, layer: nn.Module):
layer = layer.to(self.device, self.dtype)
return layer

def get_layers(self):
from neox.model import Embedding, TransformerLayer, FinalNorm, ReadoutLayer

# Embedding layer
with monit.section('Embedding layer'):
layer = Embedding(self.n_vocab, self.n_hidden)
yield self._prepare_layer(layer)

#
yield None

tl = None

# Transformer layers
for i in range(self.n_layers):
# Yield `None` if we are skipping layers
if self.filter_layers is not None and i not in self.filter_layers:
yield None
continue
# Transformer layer
with monit.section(f'Transformer Layer {i}'):
if self.is_clone_layers:
if tl is None:
tl = TransformerLayer(self.n_hidden, self.n_heads)
tl = self._prepare_layer(tl)
layer = copy.deepcopy(tl)
else:
layer = TransformerLayer(self.n_hidden, self.n_heads)
layer = self._prepare_layer(layer)
yield layer

#
yield None

# Final normalization layer
with monit.section('Final norm layer'):
layer = FinalNorm(self.n_hidden)
layer = self._prepare_layer(layer)
yield layer

# Readout layer
with monit.section('Readout layer'):
layer = ReadoutLayer(self.n_hidden, self.n_vocab)
layer = self._prepare_layer(layer)
yield layer

def load(self):
total_layers = self.n_layers + 3

with torch.no_grad():
with monit.section("Layers"):
for i, (layer, files) in enumerate(
zip(self.get_layers(), get_checkpoint_files())):
if layer is None or files is None:
continue
layer.load_state(*load_checkpoint_files(files))

yield layer

monit.progress(i / total_layers)


def load_layers(filter_layers: Optional[Set[int]]):
"""
### Load GPT-NeoX layers
This is a helper function to initialize and load the layers.
:param filter_layers: are the layers to keep. If `None`, all layers will be loaded.
:return: the list of loaded layers
"""
from neox.model import get_layers

with torch.no_grad():
layers = []
with monit.section("Layers"):
for i, (layer, files) in enumerate(zip(get_layers(filter_layers=filter_layers, is_clone_layers=True),
get_checkpoint_files())):
for i, (layer, files) in enumerate(
zip(get_layers(filter_layers=filter_layers, is_clone_layers=True), get_checkpoint_files())):
if layer is None or files is None:
continue
layer.load_state(*load_checkpoint_files(files))
@@ -186,7 +289,7 @@ def balance_layers(n_layers: int, n_chunks: int):
for i in range(n_chunks):
balance.append((n_layers - sum(balance)) // (n_chunks - i))

return reversed(balance)
return list(reversed(balance))


def _test_balance():
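
Two small notes on balance_layers: wrapping reversed(...) in list(...) makes the result indexable and reusable instead of a one-shot iterator, and the splitting rule gives each chunk the floor of the remaining layers divided by the chunks left, with the reversal putting the larger chunks first. A quick worked example (not from the commit):

from neox.utils import balance_layers

# n_layers=10, n_chunks=3:
#   (10 - 0) // 3 = 3   ->  [3]
#   (10 - 3) // 2 = 3   ->  [3, 3]
#   (10 - 6) // 1 = 4   ->  [3, 3, 4]
#   reversed and listed ->  [4, 3, 3]
assert balance_layers(10, 3) == [4, 3, 3]
assert sum(balance_layers(44 + 3, 4)) == 47  # every one of the 47 NeoX modules gets a slot
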
16 changes: 16 additions & 0 deletions src/neox/utils/cache.py
@@ -47,6 +47,22 @@ def push(self, name: str, value: Any):
# Push to the queue
self._cache[name].append(value)

def q_size(self, name):
"""
### Return the size of the queue
:param name: is the name of the queue
:return: the size of the queue if it exists, otherwise `None`
"""

if name not in self._cache:
return None

if type(self._cache[name]) != list:
return None

return len(self._cache[name])

def pop(self, name: str):
"""
### Pop from a queue
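
A hypothetical sketch of the new q_size helper; the cache instance name and the queue key are assumptions, and only push and q_size come from this file.

# `cache` is assumed to be an instance of the cache class defined in this module.
cache.push('kv', 1)
cache.push('kv', 2)

assert cache.q_size('kv') == 2          # two items queued under 'kv'
assert cache.q_size('missing') is None  # unknown names return None rather than raising
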
72 changes: 69 additions & 3 deletions src/neox/utils/training.py
@@ -12,6 +12,7 @@
import torch.nn as nn
import torch.utils.data
import torch.optim
from torch.cuda import amp

from labml import monit, tracker
from labml.logger import inspect
@@ -108,8 +109,73 @@ def train(model: nn.Module, optimizer: torch.optim.Adam,
tracker.save({'loss.train': loss, 'acc.train': accuracy * 100})

# Log model stats like gradients and weights once in a while
if batch_idx % train_log_interval == 0:
tracker.save(model=model)
# if batch_idx % train_log_interval == 0:
# tracker.save(model=model)

# Log model stats like gradients and weights at the end of the epoch
tracker.save(model=model)
# tracker.save(model=model)


def train_amp(model: nn.Module, optimizer: torch.optim.Adam,
train_loader: torch.utils.data.DataLoader,
device: torch.device, scaler: amp.GradScaler,
train_log_interval: int):
"""
## Simple trainer with automatic mixed precision
This trains the `model` for a single epoch using `torch.cuda.amp`.
:param model: is the model
:param optimizer: is the optimizer
:param train_loader: is the training data loader
:param device: is the device for inputs
:param scaler: is the gradient scaler for mixed-precision training
:param train_log_interval: is the logging frequency
"""

# Set model for train
model.train()

# Cross-entropy loss
loss_func = nn.CrossEntropyLoss()

# Iterate through the batches
for batch_idx, (data, target) in monit.enum('Train', train_loader):
# Set gradients to zero
optimizer.zero_grad()

# Forward pass
with amp.autocast():
with monit.section('Forward pass'):
output = model(data.to(device))
# Move targets to the same device as output
target = target.to(output.device)

# Calculate loss
loss = loss_func(output.view(target.numel(), -1), target.view(-1))

tracker.add({'loss.unscaled': loss})
# Get predictions
pred = output.argmax(dim=-1)
# Calculate accuracy
accuracy = pred.eq(target).sum().item() / pred.numel()

# Backward pass
loss = scaler.scale(loss)

with monit.section('Backward pass'):
loss.backward()

# Optimize
with monit.section('Optimize'):
optimizer.step()

# Log the stats
tracker.add_global_step()
tracker.save({'loss.train': loss, 'acc.train': accuracy * 100})

# Log model stats like gradients and weights once in a while
# if batch_idx % train_log_interval == 0:
# tracker.save(model=model)

# Log model stats like gradients and weights at the end of the epoch
# tracker.save(model=model)
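
Note that train_amp scales the loss before the backward pass but then calls optimizer.step() directly. The canonical torch.cuda.amp recipe routes the step through the scaler so that gradients are unscaled, inf/NaN steps are skipped, and the loss scale is updated. A sketch of that inner loop, using the same names as above (not what the commit does):

optimizer.zero_grad()

with amp.autocast():
    output = model(data.to(device))
    target = target.to(output.device)
    loss = loss_func(output.view(target.numel(), -1), target.view(-1))

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients and skips the step on inf/NaN
scaler.update()                # adjusts the loss scale for the next iteration

The scaler itself would be created once before the epoch loop, for example scaler = amp.GradScaler(), and passed in as the scaler argument.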
