Skip to content

Commit

Permalink
fix dtype bug in adam_pax
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Mar 12, 2024
1 parent 04ca2db commit fac4513
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions MaxText/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,13 @@ def __init__(self, mu, nu):
self.nu = nu

def _update_momentum(update, mu, nu):
  """Blend one update into its first- and second-moment accumulators.

  Args:
    update: The incoming update value for a single parameter.
    mu: Current first-moment accumulator paired with `update`.
    nu: Current second-moment accumulator paired with `update`.

  Returns:
    The result of `_slot_opt_state(mu=..., nu=...)`, where `mu` is blended
    toward `update` and `nu` toward `update**2` using the bias-corrected
    decay factors.

  NOTE(review): `count`, `beta1`, `beta2`, `bias_corrected_decay` and
  `_slot_opt_state` are resolved from the enclosing scope, which is not
  visible in this hunk -- confirm their definitions there.
  """
  # The conversion to the data type of the update ensures that bfloat16 remains
  # bfloat16 in the optimizer state. This conversion has to be done after
  # `bias_corrected_decay` is calculated, as calculating `jnp.power(decay, t)`
  # in low precision can result in it being rounded to 1 and subsequently a
  # "division by zero" error.
  #
  # Pass the dtype explicitly: `astype` expects a dtype, and handing it the
  # array itself only works through implicit coercion of objects that expose
  # a `.dtype` attribute.
  target_dtype = update.dtype
  beta1_decay = bias_corrected_decay(count, beta1).astype(target_dtype)
  beta2_decay = bias_corrected_decay(count, beta2).astype(target_dtype)
  mu = (1.0 - beta1_decay) * update + beta1_decay * mu
  nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
  return _slot_opt_state(mu=mu, nu=nu)
Expand Down

0 comments on commit fac4513

Please sign in to comment.