Commit
1 parent 0c150fe, commit ec7c545
Showing 9 changed files with 321 additions and 0 deletions.
@@ -1,3 +1,6 @@
## Python
__pycache__/

## Core latex/pdflatex auxiliary files:
*.aux
*.lof
@@ -1,2 +1,14 @@
# GPT-NeoX
An implementation of model-parallel GPT-3-like models on GPUs, based on the DeepSpeed library. Designed to train models with hundreds of billions of parameters or more.

## Requirements

```bash
$ pip install -r requirements.txt
```

Test locally:

```bash
$ python train_enwik8.py
```
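
For orientation, here is a minimal usage sketch of the two modules this commit adds (`GPTNeoX` and `AutoregressiveWrapper`). The configuration values below are illustrative toy numbers chosen so the sketch runs quickly on CPU; they are not the settings used by `train_enwik8.py`:

```python
import torch
from gpt_neox import GPTNeoX, AutoregressiveWrapper

# deliberately tiny toy configuration (train_enwik8.py uses dim = 512, depth = 6, seq_len = 1024)
model = GPTNeoX(
    num_tokens = 256,   # byte-level vocabulary
    dim = 128,
    seq_len = 256,
    depth = 2,
    heads = 4
)
model = AutoregressiveWrapper(model)

# training-style call: the wrapper shifts the sequence internally and returns a cross-entropy loss
x = torch.randint(0, 256, (2, 257))   # seq_len + 1 tokens per sample
loss = model(x)
loss.backward()

# sampling: prime with a sequence of token ids and ask for 32 new tokens
prompt = torch.randint(0, 256, (64,))
sample = model.generate(prompt, 32)
print(loss.item(), sample.shape)   # untrained loss is roughly ln(256) ≈ 5.55; torch.Size([32])
```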
@@ -0,0 +1,3 @@
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file not shown.
@@ -0,0 +1,2 @@
from gpt_neox.gpt_neox import GPTNeoX
from gpt_neox.autoregressive_wrapper import AutoregressiveWrapper
@@ -0,0 +1,90 @@
import torch
from torch import nn
import torch.nn.functional as F

# nucleus

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = 0, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.seq_len = net.seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)

        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.seq_len:]
            mask = mask[:, -self.seq_len:]

            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        # help auto-solve a frequent area of confusion around input masks in auto-regressive
        # if user supplies a mask that is only off by one from the source sequence, resolve it for them
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
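
As a standalone illustration of the logit-filtering helpers defined above (not part of the commit): with the default `thres = 0.9` and a 256-token byte vocabulary, `top_k` keeps only the `int((1 - 0.9) * 256) = 25` highest logits and sets the rest to `-inf` before sampling.

```python
import torch
import torch.nn.functional as F
from gpt_neox.autoregressive_wrapper import top_k

# random logits for a batch of 1 over a 256-token byte vocabulary
logits = torch.randn(1, 256)

# default thres = 0.9 keeps k = int((1 - 0.9) * 256) = 25 candidates; the rest become -inf
filtered = top_k(logits, thres = 0.9)
print((filtered > float('-inf')).sum().item())   # 25

# sampling then only ever draws from the surviving candidates
probs = F.softmax(filtered / 1.0, dim = -1)      # temperature = 1.
sample = torch.multinomial(probs, 1)
print(sample.shape)                              # torch.Size([1, 1])
```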
@@ -0,0 +1,100 @@
import torch
from torch import nn, einsum
from einops import rearrange

# helpers

def exists(val):
    return val is not None

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads, causal = True, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.causal = causal
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, **kwargs):
        h, device = self.heads, x.device

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if self.causal:
            i, j = sim.shape[2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            mask_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GPTNeoX(nn.Module):
    def __init__(self, *, num_tokens, dim, seq_len, depth, heads = 8, dim_head = 64, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        self.seq_len = seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)

        self.token_emb.weight.data.normal_(0, 0.02)
        self.pos_emb.weight.data.normal_(0, 0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
            ]))

        self.norm = nn.LayerNorm(dim)
        self.to_logits = lambda t: t @ self.token_emb.weight.t()

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device

        x = self.token_emb(x)
        x = self.pos_emb(torch.arange(n, device = device)) + x

        for (attn, ff) in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        x = self.norm(x)
        return self.to_logits(x)
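
One detail worth noting in the Attention block above is the causal mask built by `triu_(j - i + 1)`. A standalone sketch (not part of the commit) of what it produces for a length-4 sequence:

```python
import torch

i = j = 4   # query and key lengths are equal for self-attention here
mask = torch.ones(i, j).triu_(j - i + 1).bool()
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# True marks the entries that masked_fill_ sets to -inf before the softmax,
# so each position can attend only to itself and to earlier positions.
```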
@@ -1 +1,3 @@
einops>=0.3
torch>=1.6
tqdm
@@ -0,0 +1,109 @@
from gpt_neox import GPTNeoX, AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate GPT-like decoder model

model = GPTNeoX(
    num_tokens = 256,
    dim = 512,
    seq_len = SEQ_LEN,
    depth = 6,
    heads = 8,
    dim_head = 64
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
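
A note on the training loop above: gradients are accumulated over GRADIENT_ACCUMULATE_EVERY micro-batches before each optimizer step, and the per-micro-batch losses are summed via repeated backward() calls rather than averaged. The effective batch per update works out as in this standalone arithmetic sketch (values copied from the constants above):

```python
# standalone arithmetic, using the constants from the script above
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
SEQ_LEN = 1024

sequences_per_update = BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY   # 16 sequences per optimizer step
targets_per_update = sequences_per_update * SEQ_LEN             # each (SEQ_LEN + 1)-byte sample yields SEQ_LEN prediction targets
print(sequences_per_update, targets_per_update)                 # 16 16384
```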