Commit: notes
vpj committed Apr 10, 2022
1 parent 731deab commit 14cb608
Showing 4 changed files with 166 additions and 23 deletions.
13 changes: 13 additions & 0 deletions src/neox/samples/__init__.py
@@ -0,0 +1,13 @@
"""
---
title: Samples
summary: >
Samples for inference and fine-tuning
---
# Samples
* [Generating text with pipeline parallelism](generating_pipe.html)
* [Generating text by running the model layer-by-layer](generating_single_gpu.html)
* [Fine-tuning the biases with pipeline parallelism](fine_tune_biases.html)
"""
33 changes: 33 additions & 0 deletions src/neox/samples/fine_tune_biases.py
@@ -1,3 +1,16 @@
"""
---
title: Fine Tune GPT-NeoX
summary: >
Fine-tune GPT-NeoX biases with the Fairscale pipeline-parallel module
---
# Fine Tune GPT-NeoX
This shows how to fine-tune the biases of GPT-NeoX with pipeline parallelism.
"""

# Imports
import fairscale
import torch
import torch.nn as nn
@@ -11,38 +24,58 @@
from neox.utils import load_layers, balance_layers
from neox.utils.training import train, get_trainable_params, train_biases_only

# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None


def main():
"""
## Train GPT-NeoX
"""

# Create the experiment for tracking
experiment.create(name='finetune_neox_biases', comment='Pipeline parallel', writers={'screen', 'web_api'})

# Load layers
layers = load_layers(LAYERS)

# Mark `requires_grad=True` for biases using a [helper function](../utils/training.html).
train_biases_only(layers)
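# (Assumption: `train_biases_only` presumably freezes every parameter and then
# re-enables `requires_grad` only for the bias parameters, so the optimizer
# created below updates nothing but the biases.)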

# Create the pipeline parallel model
with monit.section('Pipe'):
# Number of GPUs
n_gpus = min(16, torch.cuda.device_count())
# [Get the distribution of layers across the GPUs](../utils/__init__.py)
balance = balance_layers(len(layers), n_gpus)
# Get the GPU references
devices = [torch.device(f'cuda:{i}') for i in range(n_gpus)]
# Create the pipeline parallel model
pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
balance=balance,
devices=devices,
chunks=8)
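# `chunks=8` splits each mini-batch into 8 micro-batches so that the GPUs in
# the pipeline can work on different micro-batches concurrently.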

# Load [dataset](../dataset.html)
dataset = get_training_data(1024)

# Create data loader
train_dl = DataLoader(dataset,
batch_size=8,
sampler=RandomSampler(dataset, replacement=True))

# Initialize optimizer
optimizer = optim.Adam(get_trainable_params(pipe_model), lr=1e-6)

# Train the model using the [helper function](../utils/training.html)
with experiment.start():
for epoch in monit.loop(16):
train(pipe_model, optimizer, train_dl, torch.device('cuda:0'), 10)
tracker.new_line()


#
if __name__ == '__main__':
main()
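Both generation samples and this fine-tuning sample use `balance_layers` from
`neox.utils` to decide how many layers to place on each GPU. A minimal sketch of
such a helper, consistent with the inline computation `generating_pipe.py` used
before this commit (an even split, with any remainder going to the later GPUs),
could look like the following; the actual implementation in `neox/utils` may differ:

def balance_layers(n_layers: int, n_gpus: int):
    """Distribute `n_layers` across `n_gpus` as evenly as possible."""
    balance = []
    for i in range(n_gpus):
        # Give the current GPU an equal share of the layers that remain
        balance.append((n_layers - sum(balance)) // (n_gpus - i))
    return balance

For example, 49 modules spread over 4 GPUs would give `[12, 12, 12, 13]`.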
75 changes: 59 additions & 16 deletions src/neox/samples/generating_pipe.py
@@ -1,58 +1,101 @@
"""
---
title: Generate Text with GPT-NeoX using Pipeline Parallelism
summary: >
Generate Text with GPT-NeoX using Fairscale Pipeline Parallelism
---
# Generate Text with GPT-NeoX using Pipeline Parallelism
This shows how to generate text from GPT-NeoX with pipeline parallelism.
"""

# Imports
from typing import List

import fairscale
import torch
from torch import nn

from labml import monit
from neox.utils import load_layers, get_tokens, print_tokens, balance_layers
from neox.utils.cache import get_cache

# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'


def infer(model: nn.Module, ids: List[int], device: torch.device):
"""
### Predict the next token
:param model: is the model
:param ids: are the input token ids
:param device: is the device of the model
"""

# Call the model
with torch.no_grad():
x = torch.tensor(ids)[None, :].to(device)
x = model(x)

# Return the greedily selected token ids for each position
return x[0].max(dim=-1)[1].tolist()


def generate():
"""
## Generate text
"""

# Set up the [cache](../utils/cache.html) to store intermediate key/value pairs for faster generation
cache = get_cache()
cache.set('use_cache', True)

# Load layers
layers = load_layers(LAYERS)

# Create pipeline parallel model
with monit.section('Pipe'):
# Number of GPUs
n_gpus = min(4, torch.cuda.device_count())
# [Get the distribution of layers across the GPUs](../utils/__init__.py)
balance = balance_layers(len(layers), n_gpus)
# Get the GPU references
devices = [torch.device(f'cuda:{i}') for i in range(n_gpus)]
# Create the pipeline parallel model
pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
balance=balance,
devices=devices)

# Get token ids
ids = get_tokens(PROMPT)

# Run the model
cache.set('state_ids', (None, 1))
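# (Assumption: the `(None, 1)` pair tells the cache that there is no previous
# state to read and that the key/value pairs computed for this call should be
# stored under state id 1; the loop below then reads state `i` and writes
# state `i + 1`.)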
next_token = infer(pipe_model, ids, pipe_model.devices[0])[-1]

# Append the predicted token
ids += [next_token]

# Predict 100 tokens
for i in range(1, 100):
# Set the state to use cached activations
cache.set('state_ids', (i, i + 1))
# Get next token. Note that we only feed the last token to the model because
# we cache the key/value pairs of previous tokens.
next_token = infer(pipe_model, [next_token], pipe_model.devices[0])[-1]
# Append the predicted token
ids += [next_token]
# Print
print_tokens(ids, [ids])


#
if __name__ == '__main__':
generate()
68 changes: 61 additions & 7 deletions src/neox/samples/generating_single_gpu.py
@@ -1,51 +1,105 @@
"""
---
title: Generate Text with GPT-NeoX by Evaluating Layer by Layer
summary: >
Generate Text with GPT-NeoX by evaluating layer by layer
---
# Generate Text with GPT-NeoX by Evaluating Layer by Layer
This shows how to generate text from GPT-NeoX with a single GPU.
"""

# Imports
from typing import List

import torch
from torch import nn

from labml import monit
from neox.utils import load_layers, get_tokens, print_tokens
from neox.utils.cache import get_cache

# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'


def infer(layers: List[nn.Module], ids: List[int], device: torch.device):
"""
### Predict the next token
:param layers: is the list of layers
:param ids: are the input token ids
:param device: is the device of the model
"""

# Offload to CPU
offload = torch.device('cpu')
# CUDA stream for asynchronous loading and offloading. This is still a work in progress.
s = torch.cuda.Stream()
#
with torch.no_grad():
# Get the tokens
x = torch.tensor(ids)[None, :].to(device)
# Iterate through the layers
for layer in layers:
# Move the layer to the device. Ideally this should be pre-loaded asynchronously before it is needed.
layer.to(device)
# Evaluate the layer
x = layer(x)
# Offload (async)
with torch.cuda.stream(s):
layer.to(offload)

# Return the greedily selected token ids for each position
return x[0].max(dim=-1)[1].tolist()
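# (Design note: evaluating one layer at a time means only a single layer's
# weights need to be resident on the GPU at any point, which is what makes
# single-GPU inference of a model this size feasible; offloading the layer back
# to the CPU on a separate CUDA stream lets that copy overlap, in principle,
# with computing the next layer.)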


def generate():
"""
## Generate text
"""

# Set up the [cache](../utils/cache.html) to store intermediate key/value pairs for faster generation
cache = get_cache()
cache.set('use_cache', True)

# Load layers
layers = load_layers(LAYERS)

# Device
device = torch.device('cuda:0')

# Get token ids
ids = get_tokens(PROMPT)

# Run the model
cache.set('state_ids', (None, 1))
with monit.section('Infer'):
next_token = infer(layers, ids, device)[-1]

# Append the predicted token
ids += [next_token]

# Predict 100 tokens
for i in range(1, 100):
# Set the state to use cached activations
cache.set('state_ids', (i, i + 1))
# Get next token. Note that we only feed the last token to the model because
# we cache the key/value pairs of previous tokens.
with monit.section('Infer'):
next_token = infer(layers, [next_token], device)[-1]
# Append the predicted token
ids += [next_token]
# Print
print_tokens(ids, [ids])


#
if __name__ == '__main__':
generate()
