Showing 4 changed files with 120 additions and 8 deletions.
@@ -0,0 +1,101 @@
""" | ||
--- | ||
title: Generate Text with GPT-NeoX by Evaluating Layer by Layer | ||
summary: > | ||
Generate Text with GPT-NeoX by evaluating layer by layer | ||
--- | ||
# Generate Text with GPT-NeoX by Evaluating Layer by Layer | ||
This shows how to generate text from GPT-NeoX with a single GPU. | ||
This needs a GPU with more than 45GB memory. | ||
""" | ||
|
||
# Imports | ||
from typing import List | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from labml import monit | ||
from neox.utils import get_tokens, print_tokens, LayerGenerator | ||
from neox.utils.cache import get_cache | ||
|
||
# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'

def infer(model: nn.Module, ids: List[int], device: torch.device):
    """
    ### Predict the next token

    :param model: is the model (the GPT-NeoX layers stacked in an `nn.Sequential`)
    :param ids: are the input token ids
    :param device: is the device of the model
    """

    with torch.no_grad():
        # Get the tokens
        x = torch.tensor(ids)[None, :].to(device)
        # Eval model
        x = model(x)

    # Return the greedy (argmax) prediction for each position
    return x[0].max(dim=-1)[1].tolist()

def generate():
    """
    ## Generate text
    """

    # Set up the [cache](../utils/cache.html) to cache intermediate key/value pairs for faster generation
    cache = get_cache()
    cache.set('use_cache', True)

    # Device
    device = torch.device('cuda:0')

    # Load layers
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=LAYERS,
                                 dtype=torch.float16,
                                 device=device,
                                 ).load())

    # Stack the layers into a single sequential model
    model = nn.Sequential(*layers)

    # Get token ids
    ids = get_tokens(PROMPT)

    # Run the model on the whole prompt
    cache.set('state_ids', (None, 1))
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]

    # Append the predicted token
    ids += [next_token]

    # Predict 100 tokens
    for i in range(1, 100):
        # Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))
        # Get the next token. Note that we only feed the last token to the model because
        # we cache the key/value pairs of previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        # Append the predicted token
        ids += [next_token]
        # Print the tokens generated so far
        print_tokens(ids, [ids])

#
if __name__ == '__main__':
    generate()
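
The loop above feeds only the newest token after the first pass because the key/value pairs of earlier tokens are held in the cache. Below is a minimal, self-contained sketch of that caching idea, separate from this commit: the `ToyCachedAttention` module is a hypothetical stand-in for the real GPT-NeoX layers and the `neox.utils.cache` machinery, and it omits causal masking and multi-head details for brevity.

import torch
from torch import nn


class ToyCachedAttention(nn.Module):
    """Single-head attention that appends new keys/values to a cache."""

    def __init__(self, d_model: int):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor):
        # Project the incoming tokens to queries, keys and values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Prepend the cached keys/values of earlier tokens
        if self.cache_k is not None:
            k = torch.cat([self.cache_k, k], dim=1)
            v = torch.cat([self.cache_v, v], dim=1)
        # Keep the extended cache for the next step
        self.cache_k, self.cache_v = k, v
        # Attend over all cached positions (causal masking omitted for brevity)
        attn = torch.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)
        return attn @ v


layer = ToyCachedAttention(d_model=8)
# First step: the whole prompt (batch of 1, 5 tokens, 8 features)
out = layer(torch.randn(1, 5, 8))
# Later steps: only the newest token; earlier keys/values come from the cache
out = layer(torch.randn(1, 1, 8))
print(out.shape)  # torch.Size([1, 1, 8])

On the first call the module sees the whole prompt; on later calls it sees a single token but still attends over every cached position, which is why `infer` above is passed only `[next_token]` inside the loop.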