From a75d9a9aa898f6722e653a1405788198f984f01e Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Wed, 12 Jun 2024 22:24:18 +0000 Subject: [PATCH] base.yml changes circular changes to pipeline.py pyconfig circ changes pipeline parallel tests circular style tree map, half passed tests Total iterations circularized improved iteration comment run all tests test both circular and non-circular circ storage comment circ storage pushing index comment --- MaxText/configs/base.yml | 10 +- MaxText/layers/pipeline.py | 180 ++++++++++++++++++--- MaxText/pyconfig.py | 8 + MaxText/tests/pipeline_parallelism_test.py | 75 ++++++++- 4 files changed, 251 insertions(+), 22 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 92f338513..586064c4e 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -89,8 +89,16 @@ normalize_embedding_logits: True # whether to normlize pre-softmax logits if lo logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability # pipeline parallelism -# The number of decoder layers is equal to the product of num_stages and num_layers_per_pipeline_stage (does not yet support circular pipelines). +# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats. +# There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier +# it is to hide the pipeline communication behind the compute since there is more compute per stage, however there will be a larger bubble +# since there are fewer repeats. Similarly there is tradeoff for num_pipeline_microbatches - more microbatches leads to a smaller bubble, +# but a smaller size per microbatch which may hurt per-stage performance. Additionally note when microbatches > num_stages we have the opportunity to +# perform the circular transfer (last stage to first) asynchronously. +# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1) num_layers_per_pipeline_stage: 1 +# The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage) +num_pipeline_repeats: -1 # num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages. # Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices num_pipeline_microbatches: -1 diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py index 35d7fbdd7..ba69812bf 100644 --- a/MaxText/layers/pipeline.py +++ b/MaxText/layers/pipeline.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Pipeline layer wrapping a decoder layer(s). Does not yet supports circular pipelining ''' +''' Pipeline layer wrapping a decoder layer(s). Supports circular pipelining ''' import jax import jax.ad_checkpoint @@ -30,11 +30,11 @@ class Pipeline(nn.Module): This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. This will produce a pipeline pattern if the stage dimension is sharded. - Does not yet support circular pipelines. Multiple - layers per stage are used when a module that executes multiple layers per stage is passed as the layers input. + Supports circular pipelines, and multiple layers per stage are used when a module that executes multiple layers + is passed as the layers input. Attributes: - config: Importantly contains num_pipeline_microbatches. + config: Importantly contains num_pipeline_microbatches, num_pipeline_repeats. layers: A module instance that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance or scanned/looped set of decoder layers to execute multiple layers per stage. mesh: The device mesh of the system. @@ -48,17 +48,20 @@ class Pipeline(nn.Module): def setup(self): self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.use_circ_storage = self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages self.microbatch_size = self.config.global_batch_size_to_train_on // self.config.num_pipeline_microbatches microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages self.microbatches_per_stage = microbatches_per_stage def init_states(self, inputs): - '''Initialize components of state: state_io, shift + '''Initialize components of state: state_io, shift, circular_storage and circular_storage_mover Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] Returns a dictionary with properties shift: zeros shape [num_stages, micro_size, sequence, embed] state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] + circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] + circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] loop_iteration: scalar set initially to 0. ''' @@ -73,22 +76,56 @@ def init_states(self, inputs): # We shard the microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh) + # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only needed + # when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without additional storage. + # circ_storage has shape [num_stages, microbatches, micro_size, sequence, embed]. + # Note that this shape is a factor of num_stages larger than necessary - each stage holds the global batch, but only stage 0 holds the + # real activations (since it will use them), the rest hold dummy ones. This amount of storage [global_batch, sequence, embed] is + # fine as long as there is some amount of additional sharding axes, e.g. FSDP, TP, DP (e.g. there are many devices that shard stage 0) + # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101). + if self.use_circ_storage: + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape , dtype=inputs.dtype) + else: + circ_storage = None + + # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration of delay + # circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] + if self.use_circ_storage: + circ_storage_mover = shift + else: + circ_storage_mover = None + init_loop_state = { "state_io": state_io, "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_storage_mover, "loop_iteration": 0 } return init_loop_state - def get_iteration_inputs(self, loop_iteration, state_io, shift): + def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): ''' Construct stages_in: the global array that is operated on for this iteration, shape same as shift=[stages, micro_size, sequence, embed] - This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io + This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io or an old one from circ_storage ''' # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage - first_stage_in = state_io[:,state_io_batch_idx] + state_io_slice = state_io[:,state_io_batch_idx] + + if self.use_circ_storage: + # Setup potential input from circ_storage, which also has a rotating index for microbatch, size of num_microbatches + circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches + circular_stage_in = circ_storage[:,circ_storage_batch_idx] + else: + # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage + circular_stage_in = shift + + # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left state_io + # we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. from circ_storage). + first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) + # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However these bubble computation results remain in the shift buffer (do not make it back to state_io) and are thus discarded / not returned. # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The state_io will not contain bubble results @@ -113,12 +150,36 @@ def shard_dim_by_stages(self, x, dim: int): sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) return jax.lax.with_sharding_constraint(x, sharding) - def get_microbatch_ids(self, loop_iteration): - '''Gets the microbatch_ids for all stages on this loop_iteration.''' + def get_microbatch_and_repeat_ids(self, loop_iteration): + '''Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular''' # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is one behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - jnp.arange(self.num_stages), 0) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches - return microbatch_ids + repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + + def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Use vmap to implement a sharded parallel gather. + Parallel gather means each stage has its own weights, and gets one slice from it. + Args: + weights: Per-stage data to be gathered from. + repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages. + repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not + have this dimension. + stages_dim_in_weights: The dimension in weights that represents parallel stages. + Returns: + The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights + removed. + """ + def _gather_one(x, repeat_id): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) + + gathered_weights_stage_dim = 0 + repeat_ids = self.shard_dim_by_stages(repeat_ids, 0) + weights = self.shard_dim_by_stages(weights, stages_dim_in_weights) + stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(weights, repeat_ids) + stage_weights = self.shard_dim_by_stages(stage_weights, gathered_weights_stage_dim) + return stage_weights def vmap_gather(self, xs, ids, ids_dim): """Use vmap to implement a stage-wise sharded gather. @@ -150,9 +211,13 @@ def get_new_loop_state(self,output, loop_state): * Pushing inputs up from top of state_io into first stage of shift * Pulling outputs up from last stage of shift into bottom of state_io * shift: rotate output right/down by 1 - we imagine the pipeline moves to right/down + * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage + * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration ''' old_state_io = loop_state['state_io'] + old_circ_storage = loop_state["circ_storage"] + old_circ_storage_mover = loop_state["circ_storage_mover"] loop_iteration = loop_state["loop_iteration"] # Shift becomes a rotated-right version of the previous output def _rotate_right(output_in): @@ -162,6 +227,22 @@ def _rotate_right(output_in): return jnp.concatenate([last, except_last], axis=0) new_shift = _rotate_right(output) + if self.use_circ_storage: + # Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index. + # circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped compute/async transfers + def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): + rotated = _rotate_right(circ_storage_mover_in) + rotated = jnp.expand_dims(rotated, 1) + # The offset is the previous iterations microbatch ID of the last stage, so that for example microbatch 0 will + # be placed in index 0 of the num_microbatches axis. + offset = (loop_iteration - (self.num_stages - 1) - 1) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the previous output - using circ_storage_mover before it is updated + return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) + new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) + new_circ_storage_mover = output + else: + new_circ_storage = None + new_circ_storage_mover = None + # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the last stage output stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] @@ -181,13 +262,15 @@ def _update_state_io(state_in, stream_slice, output): new_loop_state = { "state_io": new_state, "shift": new_shift, + "circ_storage": new_circ_storage, + "circ_storage_mover": new_circ_storage_mover, "loop_iteration": loop_iteration + 1 } return new_loop_state def permute_output_micro_per_stage_dim(self, output): # The first real output (batch 0) takes a certain amount of loop iterations to finish and be pushed to state_io - it will land on a different index of state_io depending on the number of iterations. - first_output_num_iters = self.num_stages - 1 + first_output_num_iters = self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + self.num_stages - 1 # The first term above is a multiple of num_pipeline_microbatches and thus could be ignored since its also a multiple of microbatches_per_stage, but we keep it for clairty land_idx = first_output_num_iters % self.microbatches_per_stage permutation = (np.arange(self.microbatches_per_stage) + land_idx) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear in idx 1, etc @@ -217,11 +300,12 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m '''Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.''' state_io = loop_state['state_io'] shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] loop_iteration = loop_state["loop_iteration"] - microbatch_ids = self.get_microbatch_ids(loop_iteration) + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) - stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, shift) + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire buffer. stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, 'iteration_input') stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None @@ -229,6 +313,33 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m vmap_func = self.get_main_vmap_func() + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights): + def gather_weights_for_stages_in(weights): + return jax.tree.map( + functools.partial( + self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1), + weights) + circular_metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + 'sub_weight_split_dims_mapping': (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + 'optimizer_dims_mapping': None, + } + weights = meta.remove_axis(weights, 0, circular_metadata_params) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular entry per stage. + weights = gather_weights_for_stages_in(weights) + return weights + + vmap_func = nn.map_variables( + vmap_func, + mapped_collections=["params", "non_trainable", "summaries", "intermediates"], + mutable=True, + trans_in_fn=prepare_vars_for_main_vmap, + ) + stages_output = vmap_func(decoder_layer_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) if self.config.scan_layers: stages_output = stages_output[0] @@ -248,26 +359,55 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. if positions is not None: positions = positions.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) + position_idx = 0 else: example_position = None + position_idx = None if segment_ids is not None: segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) + segment_idx = 0 else: example_segmentation = None + segment_idx = None loop_state = self.init_states(inputs) - # Each microbatch should go through each stage - so there is num_micro * num_stages compute to perform - # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages / num_stages = num_micro + # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) compute to perform + # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages * repeats / num_stages = num_micro * repeats # However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes - # num_micro - 1 iterations for the last microbatch to enter the pipeline, then num_stages more iterations to complete the pipeline. - # Thus the total iterations is num_micro + num_stages - 1, and we may consider the num_stages - 1 as bubble. - total_iterations = self.config.num_pipeline_microbatches + self.num_stages - 1 + # num_micro * repeat iterations for the last microbatch to start the final repeat, then an additional num_stages - 1 to finish the final repeat. + # Thus the total iterations is num_micro * repeat + num_stages - 1, and we may consider the num_stages - 1 as bubble. + total_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + self.num_stages - 1 if self.is_initializing(): vmap_func = self.get_main_vmap_func() + if self.config.num_pipeline_repeats > 1: + # To shard the weights on initialization for the circular pipeline we create weights of + # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. + # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. + vmap_func= nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={ + 'params': 0, + "non_trainable": 0, + "hyper_params": 0, + }, + split_rngs={'params': True}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + 'sub_weight_split_dims_mapping': (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + 'optimizer_dims_mapping': None, + } + ) + + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) if example_segmentation is not None else None + example_position = jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for the full total_iterations. stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) if self.config.scan_layers: @@ -275,6 +415,8 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has # shape [microbatch_size, sequence, embed] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.global_batch_size_to_train_on // self.microbatch_size]) return jnp.reshape(broadcasted_stage_outpus, [self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index a49631ef5..714d9692f 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -14,6 +14,7 @@ limitations under the License. """ +# pytype: skip-file # pylint: disable=missing-module-docstring, bare-except, consider-using-generator from collections import OrderedDict import math @@ -278,8 +279,15 @@ def user_init(raw_keys): if using_pipeline_parallelism(raw_keys): raw_keys["using_pipeline_parallelism"] = True num_stages = int(raw_keys['ici_pipeline_parallelism'] * raw_keys['dcn_pipeline_parallelism']) + if raw_keys['num_pipeline_repeats'] == -1: + num_pipeline_repeats, remainder = divmod(raw_keys['num_decoder_layers'], num_stages * raw_keys['num_layers_per_pipeline_stage']) + assert not remainder, f"The number of layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) times the number of stages ({num_stages}) must divide the number of decoder layers ({raw_keys['num_decoder_layers']}) " + raw_keys['num_pipeline_repeats'] = num_pipeline_repeats + assert num_stages * raw_keys['num_pipeline_repeats'] * raw_keys['num_layers_per_pipeline_stage'] == raw_keys['num_decoder_layers'], f"The product of pipeline stages ({num_stages}), repeats ({raw_keys['num_pipeline_repeats']}), and layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) must be equal to the number of layers ({raw_keys['num_decoder_layers']})" if raw_keys['num_pipeline_microbatches'] == -1: raw_keys['num_pipeline_microbatches'] = num_stages + assert raw_keys['num_pipeline_microbatches'] % num_stages == 0, f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" + assert raw_keys['global_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The global batch size ({raw_keys['global_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" else: raw_keys["using_pipeline_parallelism"] = False diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 276cb7e53..b45dec976 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -101,7 +101,14 @@ def pipeline_parallelism_dummy_loss(params, inputs, inputs_position, inputs_segm def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): def get_cur_layer_params(params, layer_idx): def get_cur_layer_params_arr(leaf): - if config.num_layers_per_pipeline_stage > 1: + # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] + elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] + elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] return leaf[layer_idx] @@ -111,6 +118,9 @@ def get_cur_layer_params_arr(leaf): for layer in range(config.num_decoder_layers): cur_layer_params = get_cur_layer_params(params, layer) cur_layer_params['params'] = cur_layer_params['params']['layers'] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"circular_repeats"}) + cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"layers"}) reg_layer_activations, _ = single_pipeline_stage.apply(cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode) return reg_layer_activations @@ -121,6 +131,40 @@ def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets) + @pytest.mark.tpu + def test_circular_minimum_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="circular_minimum_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=4, + per_device_batch_size=4 + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu + def test_circular_extra_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="circular_extra_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4 + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) + @pytest.mark.tpu def test_non_circular_same_output_and_grad(self): # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches @@ -138,6 +182,33 @@ def test_non_circular_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) + @pytest.mark.tpu + def test_full_train_circular(self): + # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches + train_main([ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=2", + "num_pipeline_microbatches=8", + + ]) + @pytest.mark.tpu def test_full_train_non_circular(self): # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches @@ -166,4 +237,4 @@ def test_full_train_non_circular(self): ]) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()