Skip to content

Commit

Permalink
Merge pull request #81 from luigidisotto/callbacks-optimizer
Browse files Browse the repository at this point in the history
Add optimizer to Trainer's self for callbacks.
  • Loading branch information
karpathy committed Jul 26, 2022
2 parents e2065c5 + c4c650e commit 31559f7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mingpt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_default_config():
def __init__(self, config, model, train_dataset):
self.config = config
self.model = model
self.optimizer = None
self.train_dataset = train_dataset
self.callbacks = defaultdict(list)

Expand Down Expand Up @@ -61,7 +62,7 @@ def run(self):
model, config = self.model, self.config

# setup the optimizer
optimizer = model.configure_optimizers(config)
self.optimizer = model.configure_optimizers(config)

# setup the dataloader
train_loader = DataLoader(
Expand Down Expand Up @@ -95,7 +96,7 @@ def run(self):
model.zero_grad(set_to_none=True)
self.loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
self.optimizer.step()

self.trigger_callbacks('on_batch_end')
self.iter_num += 1
Expand Down

0 comments on commit 31559f7

Please sign in to comment.