nextrng not checkpointed, consider using fold_in(config.seed, step) (#276)
We will follow the suggestion here.
rwitten pushed a commit that referenced this issue on Jan 9, 2024.
done!
Looks like the nextrng value is not saved to the checkpoint: maxtext/MaxText/train.py, line 285 in bab4ee7.

It doesn't matter in most cases, since most people don't use dropout or stochastic rounding. But in the cases where it does matter, it's cleaner to generate the RNG for a training step using jax.random.fold_in(config.seed, state.step). That way, no checkpointing of the RNG is required: resuming from a checkpoint at any step reproduces the same key. There are also some other advantages, listed in https://twitter.com/cgarciae88/status/1615022554992738315.