This repository has been archived by the owner on Aug 11, 2022. It is now read-only.

Commit: training utils
vpj committed Apr 7, 2022
1 parent 88c5df6 commit 753033e
Showing 2 changed files with 174 additions and 4 deletions.
63 changes: 59 additions & 4 deletions src/neox/utils/__init__.py
@@ -8,15 +8,19 @@
# Utilities and Helpers
* [Cache for intermediate activations (for faster inference)](cache.html)
* [Utilities for training and fine-tuning](training.html)
"""

from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Set

import torch
from tokenizers import Tokenizer

-from labml import logger
+from labml import logger, monit
from labml.logger import inspect, Text
from neox.checkpoint import get_checkpoint_files, load_checkpoint_files
from neox.model import get_layers
from neox.tokenizer import get_tokenizer

# Tokenizer singleton
@@ -129,7 +133,10 @@ def print_tokens(target: List[int], others: List[List[int]]):
    logger.log(parts)


-def _test():
+def _test_sample_tokens():
    """
    Test sample tokens
    """
    ids = get_sample_tokens(True)
    inspect(ids)

@@ -139,6 +146,54 @@ def _test():
    inspect(_TOKENIZER.decode(ids))


def load_layers(filter_layers: Optional[Set[int]]):
    """
    ### Load GPT-NeoX layers

    This is a helper function to initialize and load the layers.

    :param filter_layers: are the layers to load. If `None`, all layers will be loaded.
    :return: the list of loaded layers
    """
    with torch.no_grad():
        layers = []
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(zip(get_layers(filter_layers=filter_layers), get_checkpoint_files())):
                if layer is None or files is None:
                    continue
                layer.load_state(*load_checkpoint_files(files))

                layers.append(layer)

                monit.progress(i / 49)

    return layers


def balance_layers(n_layers: int, n_chunks: int):
    """
    ### Balance layers

    Split the `n_layers` into `n_chunks`. This is used for pipeline parallel training.

    :param n_layers: is the number of layers
    :param n_chunks: is the number of chunks
    :return: returns a list with the number of layers for each chunk
    """
    balance = []
    for i in range(n_chunks):
        balance.append((n_layers - sum(balance)) // (n_chunks - i))

    return list(reversed(balance))


def _test_balance():
    """
    Test balancing
    """
    inspect(balance_layers(45, 4))


#
if __name__ == '__main__':
-    _test()
+    _test_sample_tokens()
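
For illustration only (this snippet is not part of the commit), here is a minimal sketch of how `load_layers` and `balance_layers` might be used together when preparing pipeline-parallel chunks; loading all layers, the choice of 4 chunks, and the `nn.Sequential` grouping are assumptions, not something this commit prescribes:

import torch.nn as nn

from neox.utils import load_layers, balance_layers

# Load every layer from the checkpoint files
# (pass a set of layer indices instead of `None` to load only a subset)
layers = load_layers(None)

# Split the layers into 4 chunks; e.g. 45 layers -> [12, 11, 11, 11]
balance = balance_layers(len(layers), 4)

# Group consecutive layers into one module per chunk
chunks, offset = [], 0
for n in balance:
    chunks.append(nn.Sequential(*layers[offset:offset + n]))
    offset += n
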
115 changes: 115 additions & 0 deletions src/neox/utils/training.py
@@ -0,0 +1,115 @@
"""
---
title: Training Utilities and Helpers
summary: >
  Utilities and helper functions for model training
---
# Training Utilities and Helpers
"""
from typing import List

import torch.nn as nn
import torch.utils.data
import torch.optim

from labml import monit, tracker
from labml.logger import inspect


def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """

    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]
    # Log
    inspect(params=len(params), params_training=len(trainable_params))

    #
    return trainable_params


def train_biases_only(layers: List[nn.Module]):
    """
    ### Train only biases

    This sets `requires_grad` to `False` in all parameters except biases.
    We use this for fine-tuning, when it's too slow/expensive to train all parameters.

    :param layers: is the list of layers
    """

    for layer in layers:
        # Set `requires_grad` to `False` for the entire layer.
        layer.requires_grad_(False)
        #
        for n, p in layer.named_parameters():
            # Set `requires_grad` to `True` only for biases
            if 'bias' in n:
                p.requires_grad_(True)


def train(model: nn.Module, optimizer: torch.optim.Adam,
          train_loader: torch.utils.data.DataLoader,
          device: torch.device, train_log_interval: int):
    """
    ## Simple trainer

    This trains the `model` for a single epoch.

    :param model: is the model
    :param optimizer: is the optimizer
    :param train_loader: is the training data loader
    :param device: is the device for inputs
    :param train_log_interval: is the logging frequency
    """

    # Set the model to training mode
    model.train()

    # Cross-entropy loss
    loss_func = nn.CrossEntropyLoss()

    # Iterate through the batches
    for batch_idx, (data, target) in monit.enum('Train', train_loader):
        # Set gradients to zero
        optimizer.zero_grad()

        # Forward pass
        with monit.section('Forward pass'):
            output = model(data.to(device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Flatten the batch and sequence dimensions and calculate the loss
        loss = loss_func(output.view(target.numel(), -1), target.view(-1))

        # Get predictions
        pred = output.argmax(dim=-1)
        # Calculate accuracy
        accuracy = pred.eq(target).sum().item() / pred.numel()

        # Backward pass
        with monit.section('Backward pass'):
            loss.backward()

        # Optimize
        with monit.section('Optimize'):
            optimizer.step()

        # Log the stats
        tracker.add_global_step()
        tracker.save({'loss.train': loss, 'acc.train': accuracy * 100})

        # Log model stats like gradients and weights once in a while
        if batch_idx % train_log_interval == 0:
            tracker.save(model=model)

    # Log model stats like gradients and weights at the end of the epoch
    tracker.save(model=model)
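
For illustration only (not part of this commit), a minimal sketch of how these training helpers might be wired together for bias-only fine-tuning; the stacked `nn.Sequential` model, the random stand-in dataset, the learning rate, and the batch size are assumptions:

import torch
import torch.nn as nn
import torch.utils.data

from neox.utils import load_layers
from neox.utils.training import get_trainable_params, train_biases_only, train

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the layers and stack them into a single model
# (assumes the whole model fits on one device)
layers = load_layers(None)
model = nn.Sequential(*layers).to(device)

# Fine-tune only the biases to keep memory and compute low
train_biases_only(layers)
optimizer = torch.optim.Adam(get_trainable_params(model), lr=1e-4)

# A stand-in dataset of random token ids, just to make the sketch self-contained;
# replace with real fine-tuning data (inputs and next-token targets)
vocab_size, seq_len = 50_000, 128  # placeholder values
tokens = torch.randint(vocab_size, (16, seq_len + 1))
dataset = torch.utils.data.TensorDataset(tokens[:, :-1], tokens[:, 1:])
train_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# Run a single epoch, logging model stats every 10 batches
train(model, optimizer, train_loader, device, train_log_interval=10)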
