refactor pretrained weight loading into from_pretrained and add unit tests
karpathy committed Jul 8, 2022
1 parent 4a56b20 commit 803f388
Showing 4 changed files with 121 additions and 71 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -43,6 +43,14 @@ trainer.run()

See `demo.ipynb` for a more concrete example.

### Unit tests

Coverage is not super amazing just yet but:

```
python -m unittest discover tests
```

### References

Code:
41 changes: 41 additions & 0 deletions mingpt/model.py
@@ -179,6 +179,47 @@ def _init_weights(self, module):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Initialize a pretrained GPT model by copying over the weights
        from a huggingface/transformers checkpoint.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel

        # create a from-scratch initialized minGPT model
        config = cls.get_default_config()
        config.model_type = model_type
        config.vocab_size = 50257 # openai's model vocabulary
        config.block_size = 1024 # openai's model block_size
        model = GPT(config)
        sd = model.state_dict()

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
        # this means that we have to transpose these weights when we import them
        assert len(keys) == len(sd)
        for k in keys:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
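The transpose special-casing in `from_pretrained` exists because the OpenAI/huggingface GPT-2 checkpoint stores those four weight matrices in `Conv1D` layout, i.e. shape `(in_features, out_features)` with a forward pass that is effectively `x @ W + b`, whereas minGPT's `nn.Linear` keeps `(out_features, in_features)`. A minimal standalone sketch of that equivalence (illustrative only, not part of this diff):

```
import torch
import torch.nn as nn

nx, nf = 4, 12                # in_features, out_features
W = torch.randn(nx, nf)       # Conv1D-style weight: (in, out)
b = torch.randn(nf)

lin = nn.Linear(nx, nf)       # nn.Linear stores its weight as (out, in)
with torch.no_grad():
    lin.weight.copy_(W.t())   # hence the transpose when importing
    lin.bias.copy_(b)

x = torch.randn(2, nx)
assert torch.allclose(lin(x), x @ W + b, atol=1e-6)  # same result either way
```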
89 changes: 18 additions & 71 deletions scripts/weights_import.py
@@ -1,7 +1,5 @@
"""
This script will import the actual weights released by OpenAI,
with the help of code from huggingface/transformers. It also
that the forward pass matches exactly.
This script runs inference of GPT-2, both minGPT and huggingface/transformers
"""

import torch
@@ -13,94 +11,43 @@

 # -----------------------------------------------------------------------------

-def get_pretrained(model_type='gpt2'):
-    """
-    model_type is one of gpt2|gpt2-medium|gpt2-large|gpt2-xl
-    returns an initialized GPT class
-    """
-
-    # init a mingpt model with the right hyperparams
-    conf = GPT.get_default_config()
-    conf.model_type = model_type
-    conf.vocab_size = 50257 # openai's model vocabulary
-    conf.block_size = 1024 # openai's model block_size
-    model = GPT(conf)
-    sd = model.state_dict()
-
-    # init a huggingface/transformers model
-    model_hf = GPT2LMHeadModel.from_pretrained(model_type)
-    sd_hf = model_hf.state_dict()
-
-    # copy while ensuring all of the parameters are aligned and match in names and shapes
-    keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
-    transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
-    # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
-    # this means that we have to transpose these weights when we import them
-    assert len(keys) == len(sd)
-    for k in keys:
-        if any(k.endswith(w) for w in transposed):
-            # special treatment for the Conv1D weights we need to transpose
-            assert sd_hf[k].shape[::-1] == sd[k].shape
-            with torch.no_grad():
-                sd[k].copy_(sd_hf[k].t())
-        else:
-            # vanilla copy over the other parameters
-            assert sd_hf[k].shape == sd[k].shape
-            with torch.no_grad():
-                sd[k].copy_(sd_hf[k])
-
-    return model
-
-# -----------------------------------------------------------------------------
-
 def run(
     model_type = 'gpt2',
     prompt = "Hello, my dog is a little",
     num_samples = 5,
+    steps = 20,
     do_sample = True,
     device = 'cuda',
+    use_mingpt = True, # use mingpt or huggingface/transformers model
 ):

-    # create both a minGPT model and a huggingface model
-    model = get_pretrained(model_type) # init a minGPT model
-    model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too
+    # instantiate the model
+    if use_mingpt:
+        model = GPT.from_pretrained(model_type)
+    else:
+        model = GPT2LMHeadModel.from_pretrained(model_type)
+        model.config.pad_token_id = model.config.eos_token_id # suppress a warning

-    # ship both to gpu
+    # ship model to device and set to eval mode
     model.to(device)
-    model_hf.to(device)
-    # set both to eval mode
     model.eval()
-    model_hf.eval()

-    # tokenize an input prompt
+    # tokenize the input prompt into integer input sequence
     tokenizer = GPT2Tokenizer.from_pretrained(model_type)
-    model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning
     if prompt == '': # to create unconditional samples we feed in the special start token
         prompt = '<|endoftext|>'
     encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
     x = encoded_input['input_ids']

-    # ensure the logits match exactly
-    logits1, loss = model(x)
-    logits2 = model_hf(x).logits
-    assert torch.allclose(logits1, logits2)
+    # forward the model
+    logits = model(x)

-    # draw some samples from the HuggingFace model
-    print('-'*80)
-    print('huggingface samples')
-    print('-'*80)
+    # draw some samples
     for _ in range(num_samples):
-        y = model_hf.generate(x, max_new_tokens=20, do_sample=do_sample, top_k=40)
-        out = tokenizer.decode(y.cpu().squeeze())
-        print('-'*80)
-        print(out)
-
-    # draw some samples from mingpt model
-    print('-'*80)
-    print('mingpt samples')
-    print('-'*80)
-    for _ in range(num_samples):
-        y = sample(model=model, x=x, steps=20, sample=do_sample, top_k=40)[0]
+        if use_mingpt:
+            y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)[0]
+        else:
+            y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)
         out = tokenizer.decode(y.cpu().squeeze())
         print('-'*80)
         print(out)
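With the loading logic now living on the model class as a constructor, the script no longer needs its own `get_pretrained()` helper, and the new unit test below builds a pretrained model the same way. A usage sketch (not code from the commit; the prompt token ids are made up for illustration):

```
import torch
from mingpt.model import GPT
from mingpt.utils import sample

model = GPT.from_pretrained('gpt2')  # downloads the HF checkpoint and copies the weights over
model.eval()

x = torch.tensor([[15496, 11]])      # placeholder token ids; in practice use GPT2Tokenizer
y = sample(model=model, x=x, steps=10, sample=False)[0]  # greedy continuation
print(y)
```

Keeping this as a `@classmethod` also lets the test below exercise the import path without touching anything in `scripts/`.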
54 changes: 54 additions & 0 deletions tests/test_huggingface_import.py
@@ -0,0 +1,54 @@
"""
Ensure that we can load huggingface/transformer GPTs into minGPT
"""

import unittest
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import sample

# -----------------------------------------------------------------------------

class TestHuggingFaceImport(unittest.TestCase):

    def test_gpt2(self):
        model_type = 'gpt2'
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        prompt = "Hello, my dog is a little"

        # create a minGPT and a huggingface/transformers model
        model = GPT.from_pretrained(model_type)
        model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too

        # ship both to device
        model.to(device)
        model_hf.to(device)
        # set both to eval mode
        model.eval()
        model_hf.eval()

        # tokenize an input prompt
        tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning
        if prompt == '': # to create unconditional samples we feed in the special start token
            prompt = '<|endoftext|>'
        encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
        x = encoded_input['input_ids']

        # ensure the logits match exactly
        logits1, loss = model(x)
        logits2 = model_hf(x).logits
        self.assertTrue(torch.allclose(logits1, logits2))

        # now draw the argmax samples from each and compare them
        y1 = sample(model=model, x=x, steps=20, sample=False)[0]
        out1 = tokenizer.decode(y1.cpu().squeeze())
        y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
        out2 = tokenizer.decode(y2.cpu().squeeze())
        self.assertTrue(torch.equal(y1, y2))
        self.assertTrue(out1 == out2) # compare the output strings too, exactly


if __name__ == '__main__':
    unittest.main()
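Since the file keeps the usual `unittest.main()` entry point and lives under `tests/`, the new test can also be run on its own rather than through discovery; from the repo root, a standard unittest invocation like the following (not shown in the commit) should work:

```
python -m unittest tests.test_huggingface_import -v
```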
