Commit

working minimal GPT
lucidrains committed Dec 23, 2020
1 parent 0c150fe commit ec7c545
Showing 9 changed files with 321 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
## Python
__pycache__/

## Core latex/pdflatex auxiliary files:
*.aux
*.lof
12 changes: 12 additions & 0 deletions README.md
@@ -1,2 +1,14 @@
# GPT-NeoX
An implementation of model parallel GPT-3-like models on GPUs, based on the DeepSpeed library. Designed to be able to train models in the hundreds of billions of parameters or larger.

## Requirements

```bash
$ pip install -r requirements.txt
```

Test locally

```bash
$ python train_enwik8.py
```
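
For programmatic use, a minimal sketch (not part of this commit; the hyperparameters below are shrunk from those in `train_enwik8.py` so the smoke test runs on CPU):

```python
import torch
from gpt_neox import GPTNeoX, AutoregressiveWrapper

model = GPTNeoX(
    num_tokens = 256,   # byte-level vocabulary, as in train_enwik8.py
    dim = 128,
    seq_len = 256,
    depth = 2,
    heads = 4,
    dim_head = 32
)
model = AutoregressiveWrapper(model)       # adds the next-token loss and sampling loop

tokens = torch.randint(0, 256, (1, 256))   # dummy byte sequence
loss = model(tokens)                       # cross-entropy of predicting tokens[:, 1:]
loss.backward()

prime = torch.randint(0, 256, (1, 32))     # prompt to continue
sample = model.generate(prime, 64)         # (1, 64) tensor of sampled token ids
```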
3 changes: 3 additions & 0 deletions data/README.md
@@ -0,0 +1,3 @@
# Data source

The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/
Binary file added data/enwik8.gz
2 changes: 2 additions & 0 deletions gpt_neox/__init__.py
@@ -0,0 +1,2 @@
from gpt_neox.gpt_neox import GPTNeoX
from gpt_neox.autoregressive_wrapper import AutoregressiveWrapper
90 changes: 90 additions & 0 deletions gpt_neox/autoregressive_wrapper.py
@@ -0,0 +1,90 @@
import torch
from torch import nn
import torch.nn.functional as F

# nucleus

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = 0, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.seq_len = net.seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)

        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.seq_len:]
            mask = mask[:, -self.seq_len:]

            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        # help auto-solve a frequent area of confusion around input masks in auto-regressive
        # if user supplies a mask that is only off by one from the source sequence, resolve it for them
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
        kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
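
As a quick sanity check of the filtering helpers above (not part of the diff; it assumes the package from this commit is importable, e.g. from the repository root): with the default `thres = 0.9`, `top_k` keeps the `int((1 - 0.9) * vocab_size)` highest-scoring logits and masks the rest to `-inf`, while `top_p` keeps the shortest prefix of probability-sorted tokens whose cumulative probability exceeds `1 - thres`.

```python
import torch
from gpt_neox.autoregressive_wrapper import top_k, top_p

logits = torch.randn(1, 256)             # dummy logits over a 256-token (byte-level) vocabulary

filtered = top_k(logits, thres = 0.9)    # masks all but int(0.1 * 256) = 25 logits to -inf
print((filtered > float('-inf')).sum())  # tensor(25)

nucleus = top_p(logits, thres = 0.9)     # nucleus filtering, original token order restored
print((nucleus > float('-inf')).sum())
```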
100 changes: 100 additions & 0 deletions gpt_neox/gpt_neox.py
@@ -0,0 +1,100 @@
import torch
from torch import nn, einsum
from einops import rearrange

# helpers

def exists(val):
    return val is not None

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads, causal = True, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.causal = causal
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, **kwargs):
        h, device = self.heads, x.device

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if self.causal:
            i, j = sim.shape[2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            mask_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GPTNeoX(nn.Module):
    def __init__(self, *, num_tokens, dim, seq_len, depth, heads = 8, dim_head = 64, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        self.seq_len = seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)

        self.token_emb.weight.data.normal_(0, 0.02)
        self.pos_emb.weight.data.normal_(0, 0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
            ]))

        self.norm = nn.LayerNorm(dim)
        self.to_logits = lambda t: t @ self.token_emb.weight.t()

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device

        x = self.token_emb(x)
        x = self.pos_emb(torch.arange(n, device = device)) + x

        for (attn, ff) in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        x = self.norm(x)
        return self.to_logits(x)
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1 +1,3 @@
einops>=0.3
torch>=1.6
tqdm
109 changes: 109 additions & 0 deletions train_enwik8.py
@@ -0,0 +1,109 @@
from gpt_neox import GPTNeoX, AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate GPT-like decoder model

model = GPTNeoX(
    num_tokens = 256,
    dim = 512,
    seq_len = SEQ_LEN,
    depth = 6,
    heads = 8,
    dim_head = 64
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
