From acaadacd595dedf354a48852841824d51ce131af Mon Sep 17 00:00:00 2001 From: Andrej Date: Mon, 11 Jul 2022 18:50:53 +0000 Subject: [PATCH] refactor sequence generation into the model and match the huggingface/transformers api. touches everything but this makes a lot more sense to me aesthetically --- demo.ipynb | 70 ++++++++++++++++---------------- generate.ipynb | 10 ++--- mingpt/model.py | 42 ++++++++++++++----- mingpt/utils.py | 40 +----------------- projects/adder/adder.py | 4 +- projects/chargpt/chargpt.py | 4 +- tests/test_huggingface_import.py | 14 +++---- 7 files changed, 83 insertions(+), 101 deletions(-) diff --git a/demo.ipynb b/demo.ipynb index 1265af56..4e74622e 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -15,7 +15,9 @@ "source": [ "import torch\n", "from torch.utils.data import Dataset\n", - "from torch.utils.data.dataloader import DataLoader" + "from torch.utils.data.dataloader import DataLoader\n", + "from mingpt.utils import set_seed\n", + "set_seed(3407)" ] }, { @@ -96,17 +98,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "2 -1\n", - "2 -1\n", + "1 -1\n", "0 -1\n", "1 -1\n", "0 -1\n", - "2 0\n", + "0 -1\n", + "0 0\n", + "0 0\n", + "0 0\n", "0 0\n", "0 1\n", - "1 2\n", - "2 2\n", - "2 2\n" + "1 1\n" ] } ], @@ -152,7 +154,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "running on device cpu\n" + "running on device cuda\n" ] } ], @@ -176,26 +178,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "iter_dt 0.00ms; iter 0: train loss 1.09793\n", - "iter_dt 29.36ms; iter 100: train loss 0.14420\n", - "iter_dt 29.03ms; iter 200: train loss 0.04971\n", - "iter_dt 28.62ms; iter 300: train loss 0.03680\n", - "iter_dt 28.92ms; iter 400: train loss 0.01332\n", - "iter_dt 28.34ms; iter 500: train loss 0.01905\n", - "iter_dt 28.35ms; iter 600: train loss 0.02515\n", - "iter_dt 28.69ms; iter 700: train loss 0.02522\n", - "iter_dt 28.70ms; iter 800: train loss 0.02379\n", - "iter_dt 28.39ms; iter 900: train loss 0.00192\n", - "iter_dt 28.40ms; iter 1000: train loss 0.01416\n", - "iter_dt 28.47ms; iter 1100: train loss 0.00136\n", - "iter_dt 28.21ms; iter 1200: train loss 0.02124\n", - "iter_dt 28.21ms; iter 1300: train loss 0.05553\n", - "iter_dt 28.39ms; iter 1400: train loss 0.00930\n", - "iter_dt 28.00ms; iter 1500: train loss 0.00863\n", - "iter_dt 28.57ms; iter 1600: train loss 0.00624\n", - "iter_dt 28.39ms; iter 1700: train loss 0.00355\n", - "iter_dt 28.35ms; iter 1800: train loss 0.00235\n", - "iter_dt 28.98ms; iter 1900: train loss 0.00243\n" + "iter_dt 0.00ms; iter 0: train loss 1.06407\n", + "iter_dt 18.17ms; iter 100: train loss 0.14712\n", + "iter_dt 18.70ms; iter 200: train loss 0.05315\n", + "iter_dt 19.65ms; iter 300: train loss 0.04404\n", + "iter_dt 31.64ms; iter 400: train loss 0.04724\n", + "iter_dt 18.43ms; iter 500: train loss 0.02521\n", + "iter_dt 19.83ms; iter 600: train loss 0.03352\n", + "iter_dt 19.58ms; iter 700: train loss 0.00539\n", + "iter_dt 18.72ms; iter 800: train loss 0.02057\n", + "iter_dt 18.26ms; iter 900: train loss 0.00360\n", + "iter_dt 18.50ms; iter 1000: train loss 0.00788\n", + "iter_dt 20.64ms; iter 1100: train loss 0.01162\n", + "iter_dt 18.63ms; iter 1200: train loss 0.00963\n", + "iter_dt 18.32ms; iter 1300: train loss 0.02066\n", + "iter_dt 18.40ms; iter 1400: train loss 0.01739\n", + "iter_dt 18.37ms; iter 1500: train loss 0.00376\n", + "iter_dt 18.67ms; iter 1600: train loss 0.00133\n", + "iter_dt 18.38ms; iter 1700: train loss 0.00179\n", + "iter_dt 18.66ms; iter 1800: train loss 0.00079\n", + 
"iter_dt 18.48ms; iter 1900: train loss 0.00042\n" ] } ], @@ -233,8 +235,6 @@ } ], "source": [ - "from mingpt.utils import sample\n", - "\n", "def eval_split(trainer, split, max_batches):\n", " dataset = {'train':train_dataset, 'test':test_dataset}[split]\n", " n = train_dataset.length # naugy direct access shrug\n", @@ -248,7 +248,7 @@ " inp = x[:, :n]\n", " sol = y[:, -n:]\n", " # let the model sample the rest of the sequence\n", - " cat = sample(model, inp, n, sample=False) # using greedy argmax, not sampling\n", + " cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n", " sol_candidate = cat[:, n:] # isolate the filled in sequence\n", " # compare the predicted sequence to the true sequence\n", " correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n", @@ -291,7 +291,7 @@ "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n", "assert inp[0].nelement() == n\n", "with torch.no_grad():\n", - " cat = sample(model, inp, n, sample=False)\n", + " cat = model.generate(inp, n, do_sample=False)\n", "sol = torch.sort(inp[0])[0]\n", "sol_candidate = cat[:, n:]\n", "print('input sequence :', inp.tolist())\n", @@ -303,7 +303,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.5 ('base')", + "display_name": "Python 3.10.4 64-bit", "language": "python", "name": "python3" }, @@ -317,12 +317,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "afdab15bd6582f87e2d1e596bfa7241af51aedf8abc909e2cab3828057cb30c9" + "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" } } }, diff --git a/generate.ipynb b/generate.ipynb index 8965a271..3f29559a 100644 --- a/generate.ipynb +++ b/generate.ipynb @@ -15,7 +15,6 @@ "source": [ "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "from mingpt.model import GPT\n", - "from mingpt.utils import sample\n", "from mingpt.utils import set_seed\n", "set_seed(3407)" ] @@ -74,10 +73,7 @@ " x = x.expand(num_samples, -1)\n", "\n", " # forward the model `steps` times to get samples, in a batch\n", - " if use_mingpt:\n", - " y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n", - " else:\n", - " y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", + " y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", " \n", " for i in range(num_samples):\n", " out = tokenizer.decode(y[i].cpu().squeeze())\n", @@ -95,8 +91,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", - "2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + "2022-07-11 18:42:21.744061: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2022-07-11 18:42:21.744099: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] }, { diff --git a/mingpt/model.py b/mingpt/model.py index 
06c4c7f1..dfa95bf6 100644 --- a/mingpt/model.py +++ b/mingpt/model.py @@ -1,10 +1,5 @@ """ -GPT model: -- the initial stem consists of a combination of token encoding and a positional encoding -- the meat of it is a uniform sequence of Transformer blocks - - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block - - all blocks feed into a central residual pathway similar to resnets -- the final decoder is a linear projection into a vanilla Softmax classifier +Full definition of a GPT Language Model, all of it in this single file. References: 1) the official GPT-2 TensorFlow implementation released by OpenAI: @@ -161,13 +156,10 @@ def __init__(self, config): if pn.endswith('c_proj.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) - # report number of parameters + # report number of parameters (note we don't count the decoder parameters in lm_head) n_params = sum(p.numel() for p in self.transformer.parameters()) print("number of parameters: %.2fM" % (n_params/1e6,)) - def get_block_size(self): - return self.block_size - def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -286,3 +278,33 @@ def forward(self, idx, targets=None): loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) return logits, loss + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
+ """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # either sample from the distribution or take the most likely element + if do_sample: + idx_next = torch.multinomial(probs, num_samples=1) + else: + _, idx_next = torch.topk(probs, k=1, dim=-1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/mingpt/utils.py b/mingpt/utils.py index cf932618..af864ecb 100644 --- a/mingpt/utils.py +++ b/mingpt/utils.py @@ -7,8 +7,8 @@ import numpy as np import torch -import torch.nn as nn -from torch.nn import functional as F + +# ----------------------------------------------------------------------------- def set_seed(seed): random.seed(seed) @@ -28,42 +28,6 @@ def setup_logging(config): with open(os.path.join(work_dir, 'config.json'), 'w') as f: f.write(json.dumps(config.to_dict(), indent=4)) -def top_k_logits(logits, k): - v, ix = torch.topk(logits, k) - out = logits.clone() - out[out < v[:, [-1]]] = -float('Inf') - return out - -@torch.no_grad() -def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): - """ - take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in - the sequence, feeding the predictions back into the model each time. Clearly the sampling - has quadratic complexity unlike an RNN that is only linear, and has a finite context window - of block_size, unlike an RNN that has an infinite context window. - """ - block_size = model.get_block_size() - model.eval() - for k in range(steps): - x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed - logits, _ = model(x_cond) - # pluck the logits at the final step and scale by temperature - logits = logits[:, -1, :] / temperature - # optionally crop probabilities to only the top k options - if top_k is not None: - logits = top_k_logits(logits, top_k) - # apply softmax to convert to probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution or take the most likely - if sample: - ix = torch.multinomial(probs, num_samples=1) - else: - _, ix = torch.topk(probs, k=1, dim=-1) - # append to the sequence and continue - x = torch.cat((x, ix), dim=1) - - return x - class CfgNode: """ a lightweight configuration class inspired by yacs """ # TODO: convert to subclass from a dict like in yacs? 
diff --git a/projects/adder/adder.py b/projects/adder/adder.py index eda5e074..55f03ee1 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -12,7 +12,7 @@ from mingpt.model import GPT from mingpt.trainer import Trainer -from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN +from mingpt.utils import set_seed, setup_logging, CfgNode as CN # ----------------------------------------------------------------------------- @@ -154,7 +154,7 @@ def eval_split(trainer, split, max_batches=None): # isolate the first two digits of the input sequence alone d1d2 = x[:, :ndigit*2] # let the model sample the rest of the sequence - d1d2d3 = sample(model, d1d2, ndigit+1, sample=False) # using greedy argmax, not sampling + d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling # isolate the last digit of the sampled sequence d3 = d1d2d3[:, -(ndigit+1):] d3 = d3.flip(1) # reverse the digits to their "normal" order diff --git a/projects/chargpt/chargpt.py b/projects/chargpt/chargpt.py index bcaca109..5de925b0 100644 --- a/projects/chargpt/chargpt.py +++ b/projects/chargpt/chargpt.py @@ -11,7 +11,7 @@ from mingpt.model import GPT from mingpt.trainer import Trainer -from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN +from mingpt.utils import set_seed, setup_logging, CfgNode as CN # ----------------------------------------------------------------------------- @@ -117,7 +117,7 @@ def batch_end_callback(trainer): # sample from the model... context = "O God, O God!" x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device) - y = sample(model, x, 500, temperature=1.0, sample=True, top_k=10)[0] + y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0] completion = ''.join([train_dataset.itos[int(i)] for i in y]) print(completion) # save the latest model diff --git a/tests/test_huggingface_import.py b/tests/test_huggingface_import.py index b9142119..96f6e6f0 100644 --- a/tests/test_huggingface_import.py +++ b/tests/test_huggingface_import.py @@ -6,7 +6,6 @@ import torch from transformers import GPT2Tokenizer, GPT2LMHeadModel from mingpt.model import GPT -from mingpt.utils import sample # ----------------------------------------------------------------------------- @@ -41,14 +40,15 @@ def test_gpt2(self): logits2 = model_hf(x).logits self.assertTrue(torch.allclose(logits1, logits2)) - # now draw the argmax samples from each and compare them - y1 = sample(model=model, x=x, steps=20, sample=False)[0] - out1 = tokenizer.decode(y1.cpu().squeeze()) + # now draw the argmax samples from each + y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0] y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0] - out2 = tokenizer.decode(y2.cpu().squeeze()) - self.assertTrue(torch.equal(y1, y2)) - self.assertTrue(out1 == out2) # compare the output strings too, exactly + self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices + # convert indices to strings + out1 = tokenizer.decode(y1.cpu().squeeze()) + out2 = tokenizer.decode(y2.cpu().squeeze()) + self.assertTrue(out1 == out2) # compare the exact output strings too if __name__ == '__main__': unittest.main() \ No newline at end of file
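A minimal usage sketch of the API this patch settles on, assuming `model` is an already-constructed, trained mingpt.model.GPT instance in eval() mode and `x` is a LongTensor of shape (b, t) of conditioning token indices on the same device as the model (both are assumptions for illustration, not shown in the patch):

    # old call sites went through the removed utility:
    #   from mingpt.utils import sample
    #   y = sample(model, x, steps=20, sample=False, top_k=40)

    # new call sites use the method on the model itself, mirroring the
    # huggingface/transformers signature (max_new_tokens, do_sample, temperature, top_k)
    y = model.generate(x, max_new_tokens=20, do_sample=False, top_k=40)                   # greedy argmax
    y = model.generate(x, max_new_tokens=20, do_sample=True, temperature=1.0, top_k=40)   # sampling

Note that `generate` carries the @torch.no_grad() decorator (as the removed `sample` did), so the explicit `with torch.no_grad():` wrapper kept in demo.ipynb is redundant but harmless.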