Skip to content

Commit

Permalink
be more defensive around model_type, don't let the user shoot themselves in the foot
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Jun 27, 2022
1 parent 1c8842d commit ea20661
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
18 changes: 10 additions & 8 deletions mingpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ class GPT(nn.Module):
def get_default_config(cls):
C = CN()
# either model_type or (n_layer, n_head, n_embd) must be given in the config
C.name = None # string OR
C.n_layer = None # int
C.n_head = None # int
C.n_embd = None # int
C.model_type = 'gpt'
C.n_layer = None
C.n_head = None
C.n_embd = None
# these options must be filled in externally
C.vocab_size = None
C.block_size = None
Expand All @@ -123,8 +123,11 @@ def __init__(self, config):
assert config.block_size is not None
self.block_size = config.block_size

# map "named" GPT configurations to number of layers etc
if config.name is not None:
type_given = config.model_type is not None
params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
assert (type_given and not params_given) or (not type_given and params_given) # exactly one of these
if type_given:
# translate from model_type to detailed configuration
config.merge_from_dict({
# names follow the huggingface naming conventions
# GPT-1
Expand All @@ -141,8 +144,7 @@ def __init__(self, config):
'gpt-mini': dict(n_layer=6, n_head=6, n_embd=192),
'gpt-micro': dict(n_layer=4, n_head=4, n_embd=128),
'gpt-nano': dict(n_layer=3, n_head=3, n_embd=48),
}[config.name])
assert all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
}[config.model_type])

self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
Expand Down
2 changes: 1 addition & 1 deletion projects/adder/adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_config():

# model
C.model = GPT.get_default_config()
C.model.name = 'gpt-nano'
C.model.model_type = 'gpt-nano'

# trainer
C.trainer = Trainer.get_default_config()
Expand Down
2 changes: 1 addition & 1 deletion projects/chargpt/chargpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_config():

# model
C.model = GPT.get_default_config()
C.model.name = 'gpt-mini'
C.model.model_type = 'gpt-mini'

# trainer
C.trainer = Trainer.get_default_config()
Expand Down
2 changes: 1 addition & 1 deletion scripts/weights_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_pretrained(model_type='gpt2'):

# init a mingpt model with the right hyperparams
conf = GPT.get_default_config()
conf.name = model_type
conf.model_type = model_type
conf.vocab_size = 50257 # openai's model vocabulary
conf.block_size = 1024 # openai's model block_size
model = GPT(conf)
Expand Down

0 comments on commit ea20661

Please sign in to comment.