Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
pkaleta committed Nov 26, 2023
1 parent 74a7bfa commit 6a65c60
Showing 1 changed file with 64 additions and 22 deletions.
86 changes: 64 additions & 22 deletions lectures/makemore/makemore_part4_backprop.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 35,
"metadata": {},
"outputs": [
{
Expand All @@ -320,14 +320,29 @@
"norm_logits | exact: True | approximate: True | maxdiff: 0.0\n",
"logit_maxes | exact: True | approximate: True | maxdiff: 0.0\n",
"logits | exact: True | approximate: True | maxdiff: 0.0\n",
"h | exact: False | approximate: False | maxdiff: 1.4575186967849731\n"
"h | exact: True | approximate: True | maxdiff: 0.0\n",
"W2 | exact: True | approximate: True | maxdiff: 0.0\n",
"b2 | exact: True | approximate: True | maxdiff: 0.0\n",
"hpreact | exact: True | approximate: True | maxdiff: 0.0\n",
"bngain | exact: True | approximate: True | maxdiff: 0.0\n",
"bnbias | exact: True | approximate: True | maxdiff: 0.0\n",
"bnraw | exact: True | approximate: True | maxdiff: 0.0\n",
"bnvar_inv | exact: True | approximate: True | maxdiff: 0.0\n",
"bnvar | exact: True | approximate: True | maxdiff: 0.0\n",
"bndiff2 | exact: True | approximate: True | maxdiff: 0.0\n",
"bndiff | exact: True | approximate: True | maxdiff: 0.0\n",
"bnmeani | exact: True | approximate: True | maxdiff: 0.0\n",
"hprebn | exact: True | approximate: True | maxdiff: 0.0\n",
"embcat | exact: True | approximate: True | maxdiff: 0.0\n",
"W1 | exact: True | approximate: True | maxdiff: 0.0\n",
"b1 | exact: True | approximate: True | maxdiff: 0.0\n",
"emb | exact: True | approximate: True | maxdiff: 0.0\n",
"C | exact: True | approximate: True | maxdiff: 0.0\n"
]
}
],
"source": [
"# Exercise 1: backprop through the whole thing manually, \n",
"# backpropagating through exactly all of the variables \n",
"# as they are defined in the forward pass above, one by one\n",
"# Exercise 1: backprop through the whole thing manually, variable by variable, as defined in the forward pass above, one by one\n",
"\n",
"dlogprobs = torch.zeros_like(logprobs)\n",
"dlogprobs[range(n), Yb] = -1/n\n",
Expand All @@ -349,6 +364,33 @@
"# dh = torch.ones_like(h) * W2.sum(axis=1, keepdims=True).T# * dlogits\n",
"dh = dlogits @ W2.T\n",
"\n",
"# Output linear layer: the weight gradient is h.T @ dlogits, the bias gradient\n",
"# is dlogits summed over the batch axis.\n",
"dW2 = h.T @ dlogits\n",
"\n",
"db2 = dlogits.sum(axis=0)\n",
"\n",
"# tanh nonlinearity: d/dx tanh(x) = 1 - tanh(x)**2, chained with dh.\n",
"dhpreact = (1 - torch.tanh(hpreact)**2) * dh\n",
"\n",
"# Batch-norm scale and shift. Values that were broadcast over the batch axis\n",
"# get their gradient summed over axis 0; keepdim=True keeps a leading singleton dim.\n",
"dbngain = (bnraw * dhpreact).sum(0, keepdim=True)\n",
"dbnraw = bngain * dhpreact\n",
"dbnbias = dhpreact.sum(0, keepdim=True)\n",
"\n",
"# Normalisation: first contribution to dbndiff, then the inverse-std branch.\n",
"dbndiff = bnvar_inv * dbnraw\n",
"dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
"# Derivative of (bnvar + 1e-5)**-0.5 with respect to bnvar.\n",
"dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n",
"# Every element of bndiff2 receives the same 1/(n-1) share of dbnvar;\n",
"# ones_like expands the result to bndiff2's full shape.\n",
"dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n",
"# Second contribution to dbndiff, through bndiff2 (derivative 2*bndiff).\n",
"dbndiff += (2*bndiff) * dbndiff2\n",
"\n",
"# bndiff feeds back into hprebn directly and, negated, into the batch mean.\n",
"dhprebn = dbndiff.clone()\n",
"# keepdim=True for consistency with the other axis-0 reductions above;\n",
"# presumably bnmeani also has a leading singleton dim -- confirm against the forward pass.\n",
"dbnmeani = (-dbndiff).sum(0, keepdim=True)\n",
"dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n",
"\n",
"# First linear layer.\n",
"dembcat = dhprebn @ W1.T\n",
"dW1 = embcat.T @ dhprebn\n",
"db1 = dhprebn.sum(0)\n",
"\n",
"# Undo the concatenation: give the gradient the same shape as emb.\n",
"demb = dembcat.view(emb.shape)\n",
"\n",
"# Embedding lookup: scatter-add each gradient slice into the row of C picked by Xb.\n",
"# index_add_ accumulates repeated indices, so an index that occurs several times\n",
"# in Xb receives the sum of all its gradient slices.\n",
"# Assumes C is 2-D and Xb holds integer row indices -- confirm against the forward pass.\n",
"dC = torch.zeros_like(C)\n",
"dC.index_add_(0, Xb.reshape(-1), demb.reshape(-1, dC.shape[1]))\n",
"\n",
"\n",
"cmp('logprobs', dlogprobs, logprobs)\n",
"cmp('probs', dprobs, probs)\n",
"cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
Expand All @@ -358,23 +400,23 @@
"cmp('logit_maxes', dlogit_maxes, logit_maxes)\n",
"cmp('logits', dlogits, logits)\n",
"cmp('h', dh, h)\n",
"# cmp('W2', dW2, W2)\n",
"# cmp('b2', db2, b2)\n",
"# cmp('hpreact', dhpreact, hpreact)\n",
"# cmp('bngain', dbngain, bngain)\n",
"# cmp('bnbias', dbnbias, bnbias)\n",
"# cmp('bnraw', dbnraw, bnraw)\n",
"# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n",
"# cmp('bnvar', dbnvar, bnvar)\n",
"# cmp('bndiff2', dbndiff2, bndiff2)\n",
"# cmp('bndiff', dbndiff, bndiff)\n",
"# cmp('bnmeani', dbnmeani, bnmeani)\n",
"# cmp('hprebn', dhprebn, hprebn)\n",
"# cmp('embcat', dembcat, embcat)\n",
"# cmp('W1', dW1, W1)\n",
"# cmp('b1', db1, b1)\n",
"# cmp('emb', demb, emb)\n",
"# cmp('C', dC, C)"
"# Compare each remaining manual gradient with its tensor via cmp,\n",
"# in the same order the gradients were derived above.\n",
"for _label, _manual_grad, _tensor in [\n",
"    ('W2', dW2, W2),\n",
"    ('b2', db2, b2),\n",
"    ('hpreact', dhpreact, hpreact),\n",
"    ('bngain', dbngain, bngain),\n",
"    ('bnbias', dbnbias, bnbias),\n",
"    ('bnraw', dbnraw, bnraw),\n",
"    ('bnvar_inv', dbnvar_inv, bnvar_inv),\n",
"    ('bnvar', dbnvar, bnvar),\n",
"    ('bndiff2', dbndiff2, bndiff2),\n",
"    ('bndiff', dbndiff, bndiff),\n",
"    ('bnmeani', dbnmeani, bnmeani),\n",
"    ('hprebn', dhprebn, hprebn),\n",
"    ('embcat', dembcat, embcat),\n",
"    ('W1', dW1, W1),\n",
"    ('b1', db1, b1),\n",
"    ('emb', demb, emb),\n",
"    ('C', dC, C),\n",
"]:\n",
"    cmp(_label, _manual_grad, _tensor)"
]
},
{
Expand Down

0 comments on commit 6a65c60

Please sign in to comment.