
nextrng not checkpointed, consider using fold_in(config.seed, step) #276

Closed

reinerp opened this issue Dec 4, 2023 · 2 comments

reinerp commented Dec 4, 2023

Looks like the nextrng value is not saved to the checkpoint:

state, metrics, nextrng = p_train_step(

This doesn't matter in most cases, since most people don't use dropout or stochastic rounding. But in cases where it does matter, it's cleaner to generate the RNG for a training step with jax.random.fold_in(config.seed, state.step). That way no checkpointing is required, and there are some other advantages as well, listed in https://twitter.com/cgarciae88/status/1615022554992738315.
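A minimal sketch of the suggested pattern (the names `seed` and `step` stand in for `config.seed` and `state.step`, which are not shown in this snippet; note that `jax.random.fold_in` takes a PRNG key, so the integer seed is first wrapped with `jax.random.PRNGKey`):

```python
import jax

def rng_for_step(seed: int, step: int):
    # Derive the per-step RNG deterministically from the base seed.
    # Nothing needs to be checkpointed: resuming training at `step`
    # regenerates exactly the same key.
    base_key = jax.random.PRNGKey(seed)
    return jax.random.fold_in(base_key, step)

# The same (seed, step) pair always yields the same key, while
# different steps yield different keys.
k_a = rng_for_step(0, 10)
k_b = rng_for_step(0, 10)
k_c = rng_for_step(0, 11)
```

Because the key is a pure function of the seed and the step counter (which is already checkpointed as part of the train state), dropout and stochastic-rounding randomness is reproducible across restarts.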

rwitten (Collaborator) commented Dec 4, 2023

We will follow the suggestion here.

rwitten pushed a commit that referenced this issue Jan 9, 2024
rwitten (Collaborator) commented Jan 9, 2024

done!

rwitten closed this as completed Jan 9, 2024