refactor sequence generation into the model and match the huggingface/transformers api. touches everything but this makes a lot more sense to me aesthetically
karpathy committed Jul 11, 2022
1 parent 5af9e5c commit acaadac
Showing 7 changed files with 83 additions and 101 deletions.
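The gist of the change, as a minimal before/after sketch (assuming a trained `model` and a conditioning LongTensor `x` of shape (b, t)):

```python
# before: sampling lived in mingpt/utils.py as a free function
from mingpt.utils import sample
y = sample(model, x, steps=20, sample=False, top_k=None)

# after: sampling is a method on the model itself, mirroring the
# huggingface/transformers API (generate, max_new_tokens, do_sample)
y = model.generate(x, max_new_tokens=20, do_sample=False, top_k=None)
```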
70 changes: 35 additions & 35 deletions demo.ipynb
@@ -15,7 +15,9 @@
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data.dataloader import DataLoader"
"from torch.utils.data.dataloader import DataLoader\n",
"from mingpt.utils import set_seed\n",
"set_seed(3407)"
]
},
{
@@ -96,17 +98,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2 -1\n",
"2 -1\n",
"1 -1\n",
"0 -1\n",
"1 -1\n",
"0 -1\n",
"2 0\n",
"0 -1\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 1\n",
"1 2\n",
"2 2\n",
"2 2\n"
"1 1\n"
]
}
],
@@ -152,7 +154,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"running on device cpu\n"
"running on device cuda\n"
]
}
],
@@ -176,26 +178,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"iter_dt 0.00ms; iter 0: train loss 1.09793\n",
"iter_dt 29.36ms; iter 100: train loss 0.14420\n",
"iter_dt 29.03ms; iter 200: train loss 0.04971\n",
"iter_dt 28.62ms; iter 300: train loss 0.03680\n",
"iter_dt 28.92ms; iter 400: train loss 0.01332\n",
"iter_dt 28.34ms; iter 500: train loss 0.01905\n",
"iter_dt 28.35ms; iter 600: train loss 0.02515\n",
"iter_dt 28.69ms; iter 700: train loss 0.02522\n",
"iter_dt 28.70ms; iter 800: train loss 0.02379\n",
"iter_dt 28.39ms; iter 900: train loss 0.00192\n",
"iter_dt 28.40ms; iter 1000: train loss 0.01416\n",
"iter_dt 28.47ms; iter 1100: train loss 0.00136\n",
"iter_dt 28.21ms; iter 1200: train loss 0.02124\n",
"iter_dt 28.21ms; iter 1300: train loss 0.05553\n",
"iter_dt 28.39ms; iter 1400: train loss 0.00930\n",
"iter_dt 28.00ms; iter 1500: train loss 0.00863\n",
"iter_dt 28.57ms; iter 1600: train loss 0.00624\n",
"iter_dt 28.39ms; iter 1700: train loss 0.00355\n",
"iter_dt 28.35ms; iter 1800: train loss 0.00235\n",
"iter_dt 28.98ms; iter 1900: train loss 0.00243\n"
"iter_dt 0.00ms; iter 0: train loss 1.06407\n",
"iter_dt 18.17ms; iter 100: train loss 0.14712\n",
"iter_dt 18.70ms; iter 200: train loss 0.05315\n",
"iter_dt 19.65ms; iter 300: train loss 0.04404\n",
"iter_dt 31.64ms; iter 400: train loss 0.04724\n",
"iter_dt 18.43ms; iter 500: train loss 0.02521\n",
"iter_dt 19.83ms; iter 600: train loss 0.03352\n",
"iter_dt 19.58ms; iter 700: train loss 0.00539\n",
"iter_dt 18.72ms; iter 800: train loss 0.02057\n",
"iter_dt 18.26ms; iter 900: train loss 0.00360\n",
"iter_dt 18.50ms; iter 1000: train loss 0.00788\n",
"iter_dt 20.64ms; iter 1100: train loss 0.01162\n",
"iter_dt 18.63ms; iter 1200: train loss 0.00963\n",
"iter_dt 18.32ms; iter 1300: train loss 0.02066\n",
"iter_dt 18.40ms; iter 1400: train loss 0.01739\n",
"iter_dt 18.37ms; iter 1500: train loss 0.00376\n",
"iter_dt 18.67ms; iter 1600: train loss 0.00133\n",
"iter_dt 18.38ms; iter 1700: train loss 0.00179\n",
"iter_dt 18.66ms; iter 1800: train loss 0.00079\n",
"iter_dt 18.48ms; iter 1900: train loss 0.00042\n"
]
}
],
@@ -233,8 +235,6 @@
}
],
"source": [
"from mingpt.utils import sample\n",
"\n",
"def eval_split(trainer, split, max_batches):\n",
" dataset = {'train':train_dataset, 'test':test_dataset}[split]\n",
" n = train_dataset.length # naugy direct access shrug\n",
@@ -248,7 +248,7 @@
" inp = x[:, :n]\n",
" sol = y[:, -n:]\n",
" # let the model sample the rest of the sequence\n",
" cat = sample(model, inp, n, sample=False) # using greedy argmax, not sampling\n",
" cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n",
" sol_candidate = cat[:, n:] # isolate the filled in sequence\n",
" # compare the predicted sequence to the true sequence\n",
" correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
@@ -291,7 +291,7 @@
"inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
"assert inp[0].nelement() == n\n",
"with torch.no_grad():\n",
" cat = sample(model, inp, n, sample=False)\n",
" cat = model.generate(inp, n, do_sample=False)\n",
"sol = torch.sort(inp[0])[0]\n",
"sol_candidate = cat[:, n:]\n",
"print('input sequence :', inp.tolist())\n",
@@ -303,7 +303,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.5 ('base')",
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
@@ -317,12 +317,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "afdab15bd6582f87e2d1e596bfa7241af51aedf8abc909e2cab3828057cb30c9"
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
}
}
},
10 changes: 3 additions & 7 deletions generate.ipynb
@@ -15,7 +15,6 @@
"source": [
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import sample\n",
"from mingpt.utils import set_seed\n",
"set_seed(3407)"
]
@@ -74,10 +73,7 @@
" x = x.expand(num_samples, -1)\n",
"\n",
" # forward the model `steps` times to get samples, in a batch\n",
" if use_mingpt:\n",
" y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n",
" else:\n",
" y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
" y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
" \n",
" for i in range(num_samples):\n",
" out = tokenizer.decode(y[i].cpu().squeeze())\n",
@@ -95,8 +91,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
"2022-07-11 18:42:21.744061: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-07-11 18:42:21.744099: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
},
{
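Since the minGPT model now exposes the same `generate()` signature as transformers, the notebook's `if use_mingpt:` branch collapses into a single call; a sketch of the resulting pattern (the prompt string here is an illustrative stand-in):

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT

use_mingpt = True  # both backends now answer to the same call
model = GPT.from_pretrained('gpt2') if use_mingpt else GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
x = tokenizer("Hello, my dog is a little", return_tensors='pt').input_ids

# the same line works whether model is a mingpt GPT or a huggingface GPT2LMHeadModel
y = model.generate(x, max_new_tokens=20, do_sample=True, top_k=40)
print(tokenizer.decode(y[0].cpu().squeeze()))
```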
42 changes: 32 additions & 10 deletions mingpt/model.py
@@ -1,10 +1,5 @@
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
@@ -161,13 +156,10 @@ def __init__(self, config):
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

# report number of parameters
# report number of parameters (note we don't count the decoder parameters in lm_head)
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))

def get_block_size(self):
return self.block_size

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -286,3 +278,33 @@ def forward(self, idx, targets=None):
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

return logits, loss

@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)

return idx
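To make the `top_k` cropping concrete, a toy illustration of the masking trick used above (fake logits, for illustration only):

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # pretend logits over a 4-token vocab
v, _ = torch.topk(logits, 2)                    # values of the two largest logits
logits[logits < v[:, [-1]]] = -float('Inf')     # mask everything below the 2nd best
probs = F.softmax(logits, dim=-1)               # masked entries get exactly zero probability
print(probs)                                    # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])
```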
40 changes: 2 additions & 38 deletions mingpt/utils.py
@@ -7,8 +7,8 @@

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# -----------------------------------------------------------------------------

def set_seed(seed):
random.seed(seed)
@@ -28,42 +28,6 @@ def setup_logging(config):
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
f.write(json.dumps(config.to_dict(), indent=4))

def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[:, [-1]]] = -float('Inf')
return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
"""
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
the sequence, feeding the predictions back into the model each time. Clearly the sampling
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
of block_size, unlike an RNN that has an infinite context window.
"""
block_size = model.get_block_size()
model.eval()
for k in range(steps):
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
logits, _ = model(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)

return x

class CfgNode:
""" a lightweight configuration class inspired by yacs """
# TODO: convert to subclass from a dict like in yacs?
4 changes: 2 additions & 2 deletions projects/adder/adder.py
@@ -12,7 +12,7 @@

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
from mingpt.utils import set_seed, setup_logging, CfgNode as CN

# -----------------------------------------------------------------------------

@@ -154,7 +154,7 @@ def eval_split(trainer, split, max_batches=None):
# isolate the first two digits of the input sequence alone
d1d2 = x[:, :ndigit*2]
# let the model sample the rest of the sequence
d1d2d3 = sample(model, d1d2, ndigit+1, sample=False) # using greedy argmax, not sampling
d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling
# isolate the last digit of the sampled sequence
d3 = d1d2d3[:, -(ndigit+1):]
d3 = d3.flip(1) # reverse the digits to their "normal" order
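For context, the adder dataset stores the digits of the sum in reverse (least significant digit first), which is why the sampled slice gets flipped back; a toy sketch with made-up numbers:

```python
import torch

ndigit = 2
# illustrative only: decoding 85 + 50 = 135
d1d2 = torch.tensor([[8, 5, 5, 0]])            # the two input operands, digit by digit
# suppose greedy generate() appends [5, 3, 1], i.e. 135 least-significant-first
d1d2d3 = torch.tensor([[8, 5, 5, 0, 5, 3, 1]])
d3 = d1d2d3[:, -(ndigit+1):]                   # isolate the ndigit+1 sampled digits
print(d3.flip(1))                              # tensor([[1, 3, 5]]) -> 135
```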
4 changes: 2 additions & 2 deletions projects/chargpt/chargpt.py
@@ -11,7 +11,7 @@

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
from mingpt.utils import set_seed, setup_logging, CfgNode as CN

# -----------------------------------------------------------------------------

@@ -117,7 +117,7 @@ def batch_end_callback(trainer):
# sample from the model...
context = "O God, O God!"
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, 500, temperature=1.0, sample=True, top_k=10)[0]
y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)
# save the latest model
14 changes: 7 additions & 7 deletions tests/test_huggingface_import.py
@@ -6,7 +6,6 @@
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import sample

# -----------------------------------------------------------------------------

@@ -41,14 +40,15 @@ def test_gpt2(self):
logits2 = model_hf(x).logits
self.assertTrue(torch.allclose(logits1, logits2))

# now draw the argmax samples from each and compare them
y1 = sample(model=model, x=x, steps=20, sample=False)[0]
out1 = tokenizer.decode(y1.cpu().squeeze())
# now draw the argmax samples from each
y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0]
y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
out2 = tokenizer.decode(y2.cpu().squeeze())
self.assertTrue(torch.equal(y1, y2))
self.assertTrue(out1 == out2) # compare the output strings too, exactly
self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices

# convert indices to strings
out1 = tokenizer.decode(y1.cpu().squeeze())
out2 = tokenizer.decode(y2.cpu().squeeze())
self.assertTrue(out1 == out2) # compare the exact output strings too

if __name__ == '__main__':
unittest.main()
