diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 92f338513..586064c4e 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -89,8 +89,16 @@ normalize_embedding_logits: True # whether to normlize pre-softmax logits if lo logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability # pipeline parallelism -# The number of decoder layers is equal to the product of num_stages and num_layers_per_pipeline_stage (does not yet support circular pipelines). +# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage, and num_pipeline_repeats. +# There is a tradeoff between num_layers_per_pipeline_stage and num_pipeline_repeats: the more layers per stage, the easier +# it is to hide the pipeline communication behind the compute since there is more compute per stage; however, there will be a larger bubble +# since there are fewer repeats. Similarly, there is a tradeoff for num_pipeline_microbatches - more microbatches lead to a smaller bubble, +# but a smaller size per microbatch, which may hurt per-stage performance. Additionally, note that when microbatches > num_stages we have the opportunity to +# perform the circular transfer (last stage to first) asynchronously. +# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1) num_layers_per_pipeline_stage: 1 +# When set to -1, the number of repeats is computed as num_decoder_layers / (num_stages * num_layers_per_pipeline_stage) num_pipeline_repeats: -1 # num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages. # Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices num_pipeline_microbatches: -1 diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py index 35d7fbdd7..ba69812bf 100644 --- a/MaxText/layers/pipeline.py +++ b/MaxText/layers/pipeline.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Pipeline layer wrapping a decoder layer(s). Does not yet supports circular pipelining ''' +''' Pipeline layer wrapping a decoder layer(s). Supports circular pipelining ''' import jax import jax.ad_checkpoint @@ -30,11 +30,11 @@ class Pipeline(nn.Module): This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. This will produce a pipeline pattern if the stage dimension is sharded. - Does not yet support circular pipelines. Multiple - layers per stage are used when a module that executes multiple layers per stage is passed as the layers input. + Supports circular pipelines. Multiple layers per stage are used when a module that executes multiple layers + is passed as the layers input. Attributes: - config: Importantly contains num_pipeline_microbatches. + config: Importantly contains num_pipeline_microbatches and num_pipeline_repeats. layers: A module instance that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance or scanned/looped set of decoder layers to execute multiple layers per stage. mesh: The device mesh of the system.
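As a quick illustration of the circular-pipeline config arithmetic described in the base.yml comments above, here is a minimal, standalone sketch; the concrete sizes (4 stages, 16 layers, 8 microbatches) are made up for the example and are not taken from this change:

# Sketch of how the resolved num_pipeline_repeats and the bubble fraction relate to the config above.
# All sizes here are illustrative only.
num_stages = 4                       # ici_pipeline_parallelism * dcn_pipeline_parallelism
num_layers_per_pipeline_stage = 2
num_decoder_layers = 16
num_pipeline_microbatches = 8

# num_pipeline_repeats: -1 resolves to num_decoder_layers / (num_stages * num_layers_per_pipeline_stage)
num_pipeline_repeats, remainder = divmod(num_decoder_layers, num_stages * num_layers_per_pipeline_stage)
assert remainder == 0, "stages * layers_per_stage must divide num_decoder_layers"
assert num_stages * num_pipeline_repeats * num_layers_per_pipeline_stage == num_decoder_layers

# Bubble fraction from the comment above: the idle share of the pipeline schedule.
bubble_fraction = (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1)
print(num_pipeline_repeats, bubble_fraction)  # 2 and 3/19 ~= 0.158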
@@ -48,17 +48,20 @@ class Pipeline(nn.Module): def setup(self): self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.use_circ_storage = self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages self.microbatch_size = self.config.global_batch_size_to_train_on // self.config.num_pipeline_microbatches microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages self.microbatches_per_stage = microbatches_per_stage def init_states(self, inputs): - '''Initialize components of state: state_io, shift + '''Initialize components of state: state_io, shift, circ_storage and circ_storage_mover Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] Returns a dictionary with properties shift: zeros shape [num_stages, micro_size, sequence, embed] state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] + circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] + circ_storage_mover: zeros [num_stages, micro_size, sequence, embed] loop_iteration: scalar set initially to 0. ''' @@ -73,22 +76,56 @@ def init_states(self, inputs): # We shard the microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh) + # circ_storage is used to hold the final pipeline stage's outputs before they are used for the next repeat. It is only needed + # when num_microbatches > num_stages; otherwise the final stage passes directly to the first without additional storage. + # circ_storage has shape [num_stages, microbatches, micro_size, sequence, embed]. + # Note that this shape is a factor of num_stages larger than necessary - each stage holds the global batch, but only stage 0 holds the + # real activations (since it will use them); the rest hold dummy ones. This amount of storage [global_batch, sequence, embed] is + # fine as long as there are some additional sharding axes, e.g. FSDP, TP, DP (i.e. there are many devices that shard stage 0). + # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101).
+ if self.use_circ_storage: + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape , dtype=inputs.dtype) + else: + circ_storage = None + + # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration of delay + # circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] + if self.use_circ_storage: + circ_storage_mover = shift + else: + circ_storage_mover = None + init_loop_state = { "state_io": state_io, "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_storage_mover, "loop_iteration": 0 } return init_loop_state - def get_iteration_inputs(self, loop_iteration, state_io, shift): + def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): ''' Construct stages_in: the global array that is operated on for this iteration, shape same as shift=[stages, micro_size, sequence, embed] - This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io + This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io or an old one from circ_storage ''' # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage - first_stage_in = state_io[:,state_io_batch_idx] + state_io_slice = state_io[:,state_io_batch_idx] + + if self.use_circ_storage: + # Setup potential input from circ_storage, which also has a rotating index for microbatch, size of num_microbatches + circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches + circular_stage_in = circ_storage[:,circ_storage_batch_idx] + else: + # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage + circular_stage_in = shift + + # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left state_io + # we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. from circ_storage). + first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) + # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However these bubble computation results remain in the shift buffer (do not make it back to state_io) and are thus discarded / not returned. # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The state_io will not contain bubble results @@ -113,12 +150,36 @@ def shard_dim_by_stages(self, x, dim: int): sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) return jax.lax.with_sharding_constraint(x, sharding) - def get_microbatch_ids(self, loop_iteration): - '''Gets the microbatch_ids for all stages on this loop_iteration.''' + def get_microbatch_and_repeat_ids(self, loop_iteration): + '''Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and non-circular''' # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is one behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - jnp.arange(self.num_stages), 0) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches - return microbatch_ids + repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + + def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Use vmap to implement a sharded parallel gather. + Parallel gather means each stage has its own weights and gets one slice from them. + Args: + weights: Per-stage data to be gathered from. + repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages. + repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not + have this dimension. + stages_dim_in_weights: The dimension in weights that represents parallel stages. + Returns: + The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights + removed. + """ + def _gather_one(x, repeat_id): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) + + gathered_weights_stage_dim = 0 + repeat_ids = self.shard_dim_by_stages(repeat_ids, 0) + weights = self.shard_dim_by_stages(weights, stages_dim_in_weights) + stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(weights, repeat_ids) + stage_weights = self.shard_dim_by_stages(stage_weights, gathered_weights_stage_dim) + return stage_weights def vmap_gather(self, xs, ids, ids_dim): """Use vmap to implement a stage-wise sharded gather. @@ -150,9 +211,13 @@ def get_new_loop_state(self,output, loop_state): * Pushing inputs up from top of state_io into first stage of shift * Pulling outputs up from last stage of shift into bottom of state_io * shift: rotate output right/down by 1 - we imagine the pipeline moves to right/down + * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into a rotating index of circ_storage + * circ_storage_mover: assigned to the rotated output and pushed into circ_storage on the next iteration ''' old_state_io = loop_state['state_io'] + old_circ_storage = loop_state["circ_storage"] + old_circ_storage_mover = loop_state["circ_storage_mover"] loop_iteration = loop_state["loop_iteration"] # Shift becomes a rotated-right version of the previous output def _rotate_right(output_in): @@ -162,6 +227,22 @@ def _rotate_right(output_in): return jnp.concatenate([last, except_last], axis=0) new_shift = _rotate_right(output) + if self.use_circ_storage: + # Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index. + # circ_storage_mover still points to the output of the PREVIOUS iteration, which should help allow overlapped compute/async transfers. + def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): + rotated = _rotate_right(circ_storage_mover_in) + rotated = jnp.expand_dims(rotated, 1) + # The offset is the previous iteration's microbatch ID of the last stage, so that, for example, microbatch 0 will + # be placed in index 0 of the num_microbatches axis.
+ offset = (loop_iteration - (self.num_stages - 1) - 1) % self.config.num_pipeline_microbatches # Note the extra -1 because we grab from the previous output - using circ_storage_mover before it is updated + return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) + new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) + new_circ_storage_mover = output + else: + new_circ_storage = None + new_circ_storage_mover = None + # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the last stage output stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] @@ -181,13 +262,15 @@ def _update_state_io(state_in, stream_slice, output): new_loop_state = { "state_io": new_state, "shift": new_shift, + "circ_storage": new_circ_storage, + "circ_storage_mover": new_circ_storage_mover, "loop_iteration": loop_iteration + 1 } return new_loop_state def permute_output_micro_per_stage_dim(self, output): # The first real output (batch 0) takes a certain amount of loop iterations to finish and be pushed to state_io - it will land on a different index of state_io depending on the number of iterations. - first_output_num_iters = self.num_stages - 1 + first_output_num_iters = self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + self.num_stages - 1 # The first term above is a multiple of num_pipeline_microbatches and thus could be ignored since it's also a multiple of microbatches_per_stage, but we keep it for clarity land_idx = first_output_num_iters % self.microbatches_per_stage permutation = (np.arange(self.microbatches_per_stage) + land_idx) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear in idx 1, etc @@ -217,11 +300,12 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m '''Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.''' state_io = loop_state['state_io'] shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] loop_iteration = loop_state["loop_iteration"] - microbatch_ids = self.get_microbatch_ids(loop_iteration) + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) - stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, shift) + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire buffer.
stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, 'iteration_input') stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None @@ -229,6 +313,33 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m vmap_func = self.get_main_vmap_func() + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights): + def gather_weights_for_stages_in(weights): + return jax.tree.map( + functools.partial( + self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1), + weights) + circular_metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + 'sub_weight_split_dims_mapping': (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + 'optimizer_dims_mapping': None, + } + weights = meta.remove_axis(weights, 0, circular_metadata_params) # Remove the circular metadata axis; only one circular entry per stage is passed to the main vmap. + weights = gather_weights_for_stages_in(weights) + return weights + + vmap_func = nn.map_variables( + vmap_func, + mapped_collections=["params", "non_trainable", "summaries", "intermediates"], + mutable=True, + trans_in_fn=prepare_vars_for_main_vmap, + ) + stages_output = vmap_func(decoder_layer_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) if self.config.scan_layers: stages_output = stages_output[0] @@ -248,26 +359,55 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. if positions is not None: positions = positions.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) + position_idx = 0 else: example_position = None + position_idx = None if segment_ids is not None: segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) + segment_idx = 0 else: example_segmentation = None + segment_idx = None loop_state = self.init_states(inputs) - # Each microbatch should go through each stage - so there is num_micro * num_stages compute to perform - # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages / num_stages = num_micro + # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) compute to perform + # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages * repeats / num_stages = num_micro * repeats # However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes - num_micro - 1 iterations for the last microbatch to enter the pipeline, then num_stages more iterations to complete the pipeline. - # Thus the total iterations is num_micro + num_stages - 1, and we may consider the num_stages - 1 as bubble. + # num_micro * repeats iterations for the last microbatch to start the final repeat, then an additional num_stages - 1 to finish the final repeat. + # Thus the total number of iterations is num_micro * repeats + num_stages - 1, and we may consider the num_stages - 1 as bubble.
+ total_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + self.num_stages - 1 if self.is_initializing(): vmap_func = self.get_main_vmap_func() + if self.config.num_pipeline_repeats > 1: + # To shard the weights on initialization for the circular pipeline we create weights of + # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. + # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. + vmap_func= nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={ + 'params': 0, + "non_trainable": 0, + "hyper_params": 0, + }, + split_rngs={'params': True}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + 'sub_weight_split_dims_mapping': (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + 'optimizer_dims_mapping': None, + } + ) + + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) if example_segmentation is not None else None + example_position = jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for the full total_iterations. stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) if self.config.scan_layers: @@ -275,6 +415,8 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has # shape [microbatch_size, sequence, embed] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.global_batch_size_to_train_on // self.microbatch_size]) return jnp.reshape(broadcasted_stage_outpus, [self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index a49631ef5..714d9692f 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -14,6 +14,7 @@ limitations under the License. 
""" +# pytype: skip-file # pylint: disable=missing-module-docstring, bare-except, consider-using-generator from collections import OrderedDict import math @@ -278,8 +279,15 @@ def user_init(raw_keys): if using_pipeline_parallelism(raw_keys): raw_keys["using_pipeline_parallelism"] = True num_stages = int(raw_keys['ici_pipeline_parallelism'] * raw_keys['dcn_pipeline_parallelism']) + if raw_keys['num_pipeline_repeats'] == -1: + num_pipeline_repeats, remainder = divmod(raw_keys['num_decoder_layers'], num_stages * raw_keys['num_layers_per_pipeline_stage']) + assert not remainder, f"The number of layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) times the number of stages ({num_stages}) must divide the number of decoder layers ({raw_keys['num_decoder_layers']}) " + raw_keys['num_pipeline_repeats'] = num_pipeline_repeats + assert num_stages * raw_keys['num_pipeline_repeats'] * raw_keys['num_layers_per_pipeline_stage'] == raw_keys['num_decoder_layers'], f"The product of pipeline stages ({num_stages}), repeats ({raw_keys['num_pipeline_repeats']}), and layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) must be equal to the number of layers ({raw_keys['num_decoder_layers']})" if raw_keys['num_pipeline_microbatches'] == -1: raw_keys['num_pipeline_microbatches'] = num_stages + assert raw_keys['num_pipeline_microbatches'] % num_stages == 0, f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" + assert raw_keys['global_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The global batch size ({raw_keys['global_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" else: raw_keys["using_pipeline_parallelism"] = False diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 276cb7e53..b45dec976 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -101,7 +101,14 @@ def pipeline_parallelism_dummy_loss(params, inputs, inputs_position, inputs_segm def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): def get_cur_layer_params(params, layer_idx): def get_cur_layer_params_arr(leaf): - if config.num_layers_per_pipeline_stage > 1: + # Reshape layers into a linear list of layers, e.g. 
[repeat, stage] into [layers] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] + elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] + elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] return leaf[layer_idx] @@ -111,6 +118,9 @@ def get_cur_layer_params_arr(leaf): for layer in range(config.num_decoder_layers): cur_layer_params = get_cur_layer_params(params, layer) cur_layer_params['params'] = cur_layer_params['params']['layers'] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"circular_repeats"}) + cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"layers"}) reg_layer_activations, _ = single_pipeline_stage.apply(cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode) return reg_layer_activations @@ -121,6 +131,40 @@ def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets) + @pytest.mark.tpu + def test_circular_minimum_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="circular_minimum_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=4, + per_device_batch_size=4 + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu + def test_circular_extra_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="circular_extra_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4 + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) + @pytest.mark.tpu def test_non_circular_same_output_and_grad(self): # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches @@ -138,6 +182,33 @@ def test_non_circular_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) + @pytest.mark.tpu + def test_full_train_circular(self): + # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches + train_main([ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + 
"base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=2", + "num_pipeline_microbatches=8", + + ]) + @pytest.mark.tpu def test_full_train_non_circular(self): # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches @@ -166,4 +237,4 @@ def test_full_train_non_circular(self): ]) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()