Simplified RNG handling (#276)
Rafi Witten committed Jan 9, 2024
1 parent 65d1d16 commit 48bb4f0
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion MaxText/maxtext_utils.py
@@ -39,7 +39,7 @@ def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations
   data_sharding = jax.tree_map(
       lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
   in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
-  out_shardings = (state_mesh_shardings, None, None) # State, metrics, rng
+  out_shardings = (state_mesh_shardings, None) # State, metrics
   static_argnums = () # We partial out the static argnums of model and config
   donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory.
   return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums
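
These sharding tuples must line up entry-for-entry with the train step's arguments and outputs: since the refactored train_step no longer returns a key, out_shardings shrinks from three entries to two. A minimal sketch of how such tuples are consumed by jit, using a toy step and made-up shardings rather than MaxText's actual wiring:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy step with the same (state, batch, rng) -> (state, metrics) shape as
# the refactored train_step; the body is a stand-in.
def toy_step(state, batch, rng):
  noise = jax.random.normal(rng, batch.shape)
  new_state = state + (batch + noise).mean()
  return new_state, {"loss": new_state}

mesh = Mesh(np.array(jax.devices()), ("data",))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))  # shard batch rows
replicated = NamedSharding(mesh, PartitionSpec())            # replicate everywhere

# One sharding per argument and per output, mirroring the tuples above.
p_toy_step = jax.jit(
    toy_step,
    in_shardings=(replicated, batch_sharding, replicated),  # state, batch, rng
    out_shardings=(replicated, replicated),                 # state, metrics
    donate_argnums=0,                                       # state buffer may be reused
)

state, metrics = p_toy_step(jnp.float32(0.0), jnp.ones((8, 4)), jax.random.PRNGKey(0))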
25 changes: 13 additions & 12 deletions MaxText/train.py
@@ -175,8 +175,7 @@ def train_step(model, config, state, data, dropout_rng):
   """
   # inputs, targets, segments, positions = apply_args
-  rng1, gen_aqt_rng = jax.random.split(dropout_rng)
-  aqt_rng, rng2 = jax.random.split(gen_aqt_rng)
+  rng1, aqt_rng = jax.random.split(dropout_rng)
 
   # decimate proportion of data when per_device_batch_size<1
   for k, v in data.items():
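
Collapsing the two chained splits into one is behaviorally fine for the consumers: a single jax.random.split call already yields independent subkeys for dropout and AQT, and the intermediate gen_aqt_rng and the now-unused rng2 disappear. A standalone toy check of the split semantics (not MaxText code):

import jax
import jax.numpy as jnp

dropout_rng = jax.random.PRNGKey(42)

# One split produces two independent subkeys: one for dropout, one for AQT.
rng1, aqt_rng = jax.random.split(dropout_rng)
assert not jnp.array_equal(rng1, aqt_rng)

# Each consumer draws from its own subkey; the parent key is never reused raw.
print(jax.random.uniform(rng1), jax.random.uniform(aqt_rng))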
@@ -209,20 +208,20 @@ def loss_fn(params):
   if config.record_internal_nn_metrics:
     record_activation_metrics(metrics, intermediate_outputs, config)
 
-  return new_state, metrics, rng2
+  return new_state, metrics
 
-def setup_train_loop(config):
+def setup_train_loop(config, init_rng):
   """ Set up prerequisites for the training loop -
   checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
   Set up data iterator and tokenizer, initialize the model.
   Args:
     config
+    init_rng
   Returns:
     writer: Summary writer for tensorboard
     checkpoint_manager: Orbax checkpointer
-    nextrng: key used in train_step for dropout
     state_mesh_annotations: the mesh annotations for the train state
     model:
     mesh:
@@ -237,9 +236,6 @@ def setup_train_loop(config):
       config.async_checkpointing,
       config.checkpoint_period,
   )
-  # Initial PRNG Keys
-  init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2)
-
   # Mesh definition
   devices_array = max_utils.create_device_mesh(config)
   mesh = Mesh(devices_array, config.mesh_axes)
@@ -253,7 +249,7 @@ def setup_train_loop(config):
 
   state, state_mesh_annotations = max_utils.setup_training_state(model, tx, config, init_rng, mesh, checkpoint_manager)
 
-  return ( writer, checkpoint_manager, nextrng, state_mesh_annotations, model,
+  return ( writer, checkpoint_manager, state_mesh_annotations, model,
     mesh, learning_rate_schedule, data_iterator, state)


@@ -265,8 +261,10 @@ def train_loop(config, state=None):
     ckpt_path:
   Returns:
   """
-  ( writer, checkpoint_manager, nextrng, state_mesh_annotations, model,
-    mesh, learning_rate_schedule, data_iterator, state) = setup_train_loop(config)
+  init_rng = random.PRNGKey(config.init_weights_seed)
+
+  ( writer, checkpoint_manager, state_mesh_annotations, model,
+    mesh, learning_rate_schedule, data_iterator, state) = setup_train_loop(config, init_rng)
 
   functional_train, in_shard, out_shard, static_argnums, donate_argnums = maxtext_utils.get_functional_train_with_signature(
       train_step,
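
Creating init_rng in train_loop and threading it into setup_train_loop makes the run's whole key schedule a pure function of config.init_weights_seed. One payoff of deriving keys rather than chaining them, sketched with made-up step counts: recovering step N's key under a chained scheme means replaying N splits, while direct derivation is a single call:

import jax

init_rng = jax.random.PRNGKey(0)

# Chained scheme: step N's key is the N-th link in a chain of splits, so a
# job resuming from a checkpoint must replay the chain to recover it.
key = init_rng
for _ in range(100):
  _, key = jax.random.split(key)

# Derivation scheme: step 100's key comes straight from (base key, step),
# independent of earlier iterations.
key_at_100 = jax.random.fold_in(init_rng, 100)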
@@ -305,13 +303,16 @@
   if config.enable_profiler and first_profiling_step >= config.steps:
     raise ValueError("Profiling requested but initial profiling step set past training final step")
   last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
 
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step:
       max_utils.activate_profiler(config)
 
     example_batch = load_next_batch(data_iterator, example_batch, config)
 
+    nextrng = jax.random.fold_in(init_rng, start_step)
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-      state, metrics, nextrng = p_train_step(
+      state, metrics = p_train_step(
         state, example_batch, nextrng
       )
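
In place of the key returned by the old train_step, the loop now derives nextrng with jax.random.fold_in, which deterministically mixes an integer into a base key: the same (key, data) pair always yields the same subkey, and different integers yield distinct ones. The generic per-step pattern looks like this (standalone sketch that folds in the loop counter):

import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)

# fold_in is a pure function of (key, data)...
assert jnp.array_equal(jax.random.fold_in(base, 7), jax.random.fold_in(base, 7))

# ...so folding in the step index gives each iteration its own reproducible key.
for step in range(3):
  step_key = jax.random.fold_in(base, step)
  print(step, jax.random.uniform(step_key))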

