diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index cbe377b64..27efc0ce2 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -87,6 +87,14 @@ logits_via_embedding: False normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability +# pipeline parallelism +# The number of decoder layers is equal to the product of num_stages and num_layers_per_pipeline_stage (does not yet support circular pipelines). +num_layers_per_pipeline_stage: 1 +# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages. +# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices +num_pipeline_microbatches: -1 +scan_pipeline_iterations: True # This can be set independently of scan_layers, which is relevant when num_layers_per_pipeline_stage > 1. + # Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) remat_policy: 'full' @@ -111,9 +119,13 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'] +mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]], + # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. + # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. + # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. + ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']], ['activation_heads', ['tensor','sequence']], ['activation_length', 'sequence'], ['activation_embed', 'tensor'], @@ -122,19 +134,22 @@ logical_axis_rules: [ ['activation_vocab', ['tensor', 'sequence']], ['activation_vocab', 'tensor'], ['activation_vocab', 'sequence'], + ['activation_stage','stage'], ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], ['vocab', ['tensor', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence']], ['embed', ['fsdp', 'sequence']], ['norm', 'tensor'], ['heads', ['tensor', 'autoregressive']], + ['layers', 'stage'], ['kv', []], ['cache_batch', []], ['cache_heads', ['autoregressive', 'tensor']], ['cache_kv', []], ['cache_sequence', []], ] -data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] +# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details +data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
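The new pipeline keys interact as follows: the number of pipeline stages is the product of ici_pipeline_parallelism and dcn_pipeline_parallelism, the decoder depth must equal num_stages * num_layers_per_pipeline_stage (no circular repeats yet), and each microbatch holds global_batch_size / num_pipeline_microbatches examples. A minimal arithmetic sketch, using hypothetical values (the real derivation lives in pyconfig.py and layers/pipeline.py further down in this diff):

    # Hypothetical values; the actual derivation is in pyconfig.py and layers/pipeline.py.
    ici_pipeline_parallelism = 4         # stages within a slice
    dcn_pipeline_parallelism = 1         # stages across slices
    num_layers_per_pipeline_stage = 8
    per_device_batch_size = 2
    num_devices = 16

    num_stages = ici_pipeline_parallelism * dcn_pipeline_parallelism   # 4
    num_decoder_layers = num_stages * num_layers_per_pipeline_stage    # 32 (no circular pipelining yet)
    global_batch_size = per_device_batch_size * num_devices            # 32

    num_pipeline_microbatches = -1                                     # -1 means "default to the number of stages"
    if num_pipeline_microbatches == -1:
        num_pipeline_microbatches = num_stages                         # 4
    assert num_pipeline_microbatches % num_stages == 0                 # must be a multiple of the stage count
    microbatch_size = global_batch_size // num_pipeline_microbatches   # 8
    microbatches_per_stage = num_pipeline_microbatches // num_stages   # 1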
@@ -145,6 +160,7 @@ dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 dcn_sequence_parallelism: 1 # never recommended dcn_tensor_parallelism: 1 # never recommended +dcn_pipeline_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded @@ -152,6 +168,7 @@ ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_tensor_parallelism: 1 ici_autoregressive_parallelism: 1 +ici_pipeline_parallelism: 1 # The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation, # you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1. diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 9337986a0..5bf83755a 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -82,7 +82,7 @@ def __call__(self, inputs: Array) -> Array: output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) else: output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed")) + output = nn.with_logical_constraint(output, ("activation_embed_and_logits_batch", "activation_length", "activation_embed")) return output def attend(self, query: Array) -> Array: diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index a9e1d0e63..7b824601a 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -16,7 +16,7 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module -from typing import Callable, Optional +from typing import Any, Callable, Optional from flax import linen as nn @@ -28,6 +28,7 @@ from layers import embeddings from layers import linears from layers import normalizations, quantizations +from layers import pipeline Array = common_types.Array Config = common_types.Config @@ -145,6 +146,25 @@ def __call__( return layer_output, None if cfg.scan_layers else layer_output +class SequentialBlockDecoderLayers(nn.Module): + """Sequential unscanned series of decoder layers.""" + decoder_layer: Any + num_decoder_layers: int + config: Config + mesh: Mesh + quant: Quant + + @nn.compact + def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode) -> jnp.ndarray: + for lyr in range(self.num_decoder_layers): + inputs = self.decoder_layer(config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant)( + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + return inputs class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" @@ -174,6 +194,10 @@ def get_decoder_layer(self): from layers import gpt3 return gpt3.Gpt3DecoderLayer + elif self.config.decoder_block == "simple": + from layers import simple_layer + + return simple_layer.SimpleDecoderLayer else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @@ -187,6 +211,34 @@ def get_norm_layer(self): else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mesh): + initializing = self.is_mutable_collection("params") + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + cache_spec = 0 + scan_fn = nn.scan( + decoder_layer, + variable_axes={ + "params": params_spec, 
+ "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={ + "params": True, + "dropout": cfg.enable_dropout, + }, + in_axes=( + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast, + ), + length=length, + metadata_params={nn.PARTITION_NAME: metdata_axis_name}, + ) + return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant) + @nn.compact def __call__( self, @@ -266,53 +318,46 @@ def __call__( else: assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" policy = None - BlockLayer = nn.remat( # pylint: disable=invalid-name - BlockLayer, - prevent_cse=not cfg.scan_layers, - policy=policy, - static_argnums=(-1, -2, -3, -4, -5), - ) - if cfg.scan_layers: - initializing = self.is_mutable_collection("params") - params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - cache_spec = 0 - y, _ = nn.scan( - BlockLayer, - variable_axes={ - "params": params_spec, - "cache": cache_spec, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={ - "params": True, - "dropout": cfg.enable_dropout, - }, - in_axes=( - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - ), - length=cfg.num_decoder_layers, - metadata_params={nn.PARTITION_NAME: "layers"}, - )(config=cfg, mesh=mesh, name="layers", quant=self.quant)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + + RemattedBlockLayer = nn.remat( # pylint: disable=invalid-name + BlockLayer, + prevent_cse=not cfg.scan_layers, + policy=policy, + static_argnums=(-1, -2, -3, -4, -5), + ) + if cfg.using_pipeline_parallelism: + if cfg.num_layers_per_pipeline_stage == 1: + stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant) + elif cfg.scan_layers: + stage_module = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh) + elif not cfg.scan_layers: + stage_module=SequentialBlockDecoderLayers(decoder_layer=RemattedBlockLayer, num_decoder_layers=cfg.num_layers_per_pipeline_stage, config=cfg, mesh=mesh,quant=self.quant) + + y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) else: - for lyr in range(cfg.num_decoder_layers): - y = BlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( + if cfg.scan_layers: + y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, ) + else: + for lyr in range(cfg.num_decoder_layers): + y = RemattedBlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) y = self.get_norm_layer()( dtype=cfg.dtype, @@ -340,7 +385,7 @@ def __call__( )( y ) # We do not quantize the logits matmul. 
- logits = nn.with_logical_constraint(logits, ("activation_batch", "activation_length", "activation_vocab")) + logits = nn.with_logical_constraint(logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")) + logits = logits.astype(jnp.float32) return logits diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py new file mode 100644 index 000000000..35d7fbdd7 --- /dev/null +++ b/MaxText/layers/pipeline.py @@ -0,0 +1,331 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' Pipeline layer wrapping one or more decoder layers. Does not yet support circular pipelining. ''' + +import jax +import jax.ad_checkpoint +import numpy as np +from jax import numpy as jnp +from flax.core import meta +from flax import linen as nn +import common_types +import functools +from typing import Any + +class Pipeline(nn.Module): + """Module that implements pipelining across stages. + + This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. + This will produce a pipeline pattern if the stage dimension is sharded. + + Does not yet support circular pipelines. To run multiple + layers per stage, pass in a module that itself executes multiple layers as the layers input. + + Attributes: + config: Importantly contains num_pipeline_microbatches. + layers: A module instance that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance + or a scanned/looped set of decoder layers to execute multiple layers per stage. + mesh: The device mesh of the system. + remat_policy: Remat policy to use for the loop iterations. + """ + + config: common_types.Config + layers: nn.Module # The name of this property (layers) is reflected in the state pytree and thus also checkpoints. + mesh: common_types.Mesh + remat_policy: Any = None + + def setup(self): + self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.microbatch_size = self.config.global_batch_size_to_train_on // self.config.num_pipeline_microbatches + microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages + self.microbatches_per_stage = microbatches_per_stage + + def init_states(self, inputs): + '''Initialize components of state: state_io, shift + Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] + + Returns a dictionary with properties + shift: zeros shape [num_stages, micro_size, sequence, embed] + state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] + loop_iteration: scalar set initially to 0.
+ ''' + + # Shift is used to rotate the output of each pipeline into the input of the next + # shift has shape [num_stages, micro_size, sequence, embed] + shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + shift = nn.with_logical_constraint(shift, ("activation_stage", "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules,mesh=self.mesh) + + # state_io (state input output) at first holds all of the input batches, but also will hold the outputs as the pipeline runs/finishes + # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] + state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) + # We shard the microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. + state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh) + + init_loop_state = { + "state_io": state_io, + "shift": shift, + "loop_iteration": 0 + } + return init_loop_state + + def get_iteration_inputs(self, loop_iteration, state_io, shift): + ''' + Construct stages_in: the global array that is operated on for this iteration, shape same as shift=[stages, micro_size, sequence, embed] + This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io + ''' + + # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) + state_io_batch_idx = loop_iteration % self.microbatches_per_stage + first_stage_in = state_io[:,state_io_batch_idx] + # Note that first_stage_in may correspond to bubble computation during the last few iterations. + # However these bubble computation results remain in the shift buffer (do not make it back to state_io) and are thus discarded / not returned. + # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The state_io will not contain bubble results + # at the end of the last iteration. + + + def select_state_or_input(first_stage_in, shift): + # Selects input for stage 0, shift for other stages + return jnp.where(jax.lax.broadcasted_iota('int32', shift.shape, 0) == 0, first_stage_in, shift) + + # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) + stages_in = select_state_or_input(first_stage_in, shift) + stages_in = nn.with_logical_constraint(stages_in, ("activation_stage", "activation_batch", "activation_length", "activation_embed"), rules=self.config.logical_axis_rules, mesh=self.mesh) + return stages_in + + def shard_dim_by_stages(self, x, dim: int): + # Shards a dimension by stages. Currently the sharding of other dimensions are left up the compiler, alternatively + # we may want to copy over the sharding from the other input axes. 
dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim + dims_mapping[dim] = "stage" + dims_mapping = tuple(dims_mapping) + sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) + return jax.lax.with_sharding_constraint(x, sharding) + + def get_microbatch_ids(self, loop_iteration): + '''Gets the microbatch_ids for all stages on this loop_iteration.''' + # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is one behind due to the bubble, etc. for the other stages + microbatches_processed = jnp.maximum(loop_iteration - jnp.arange(self.num_stages), 0) + microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches + return microbatch_ids + + def vmap_gather(self, xs, ids, ids_dim): + """Use vmap to implement a stage-wise sharded gather. + + The stages share the same input, but they have different offsets. + + Args: + xs: Data shared by all stages, to be gathered from. + ids: Integer tensor of shape [num_stages], the offsets of the stages. + ids_dim: The dimension in xs where ids are applied. In the output, this + dimension will be [num_stages], since each stage gets one slice. + + Returns: + The per-stage gathered values. The shape is xs.shape but with ids_dim size + replaced with [num_stages]. + """ + def _gather_one(x, i): + return jnp.squeeze( + jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) + + ids = self.shard_dim_by_stages(ids, 0) + outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) + return self.shard_dim_by_stages(outs, 0) + + def get_new_loop_state(self, output, loop_state): + ''' + Update the various buffers given the output of the most recent iteration + * state_io: rotates left/up by 1 (the hole created in the last slot is filled with the most recent pipeline output) + * Pushing inputs up from top of state_io into first stage of shift + * Pulling outputs up from last stage of shift into bottom of state_io + * shift: rotate output right/down by 1 - we imagine the pipeline moves to right/down + ''' + + old_state_io = loop_state['state_io'] + loop_iteration = loop_state["loop_iteration"] + # Shift becomes a rotated-right version of the previous output + def _rotate_right(output_in): + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(output_in, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(output_in, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) + new_shift = _rotate_right(output) + + # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the last stage output + stream_buf_idx = loop_iteration % self.microbatches_per_stage + stream_slice = old_state_io[:, stream_buf_idx] + def _update_state_io(state_in, stream_slice, output): + # Shift the current slice to the left, then fill the last stage with the final output.
padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim( + jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota('int32', stream_slice.shape, 0) == self.num_stages - 1, output, + stream_slice) + stream_slice = jnp.expand_dims(stream_slice, 1) + return jax.lax.dynamic_update_slice_in_dim( + state_in, stream_slice, stream_buf_idx, axis=1) + new_state = _update_state_io(old_state_io, stream_slice, output) + + new_loop_state = { + "state_io": new_state, + "shift": new_shift, + "loop_iteration": loop_iteration + 1 + } + return new_loop_state + + def permute_output_micro_per_stage_dim(self, output): + # The first real output (microbatch 0) takes a certain number of loop iterations to finish and be pushed to state_io - it will land on a different index of state_io depending on the number of iterations. + first_output_num_iters = self.num_stages - 1 + # The first term above is a multiple of num_pipeline_microbatches and thus could be ignored since it is also a multiple of microbatches_per_stage, but we keep it for clarity + land_idx = first_output_num_iters % self.microbatches_per_stage + permutation = (np.arange(self.microbatches_per_stage) + land_idx) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appears in idx 1, etc. + output = output[:,permutation] + return output + + def get_main_vmap_func(self): + def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + # nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance. + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name='stage', + variable_axes={'params': 0}, + split_rngs={'params': self.is_initializing()}, + metadata_params={ + nn.PARTITION_NAME: "layers", + 'sub_weight_split_dims_mapping': (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages} + ) + return vmap_func + + def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, model_mode, decoder_layer_instance): + '''Run one loop iteration - gets weights and inputs for each stage, runs the stages in parallel, and updates the loop state.''' + state_io = loop_state['state_io'] + shift = loop_state["shift"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids = self.get_microbatch_ids(loop_iteration) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, shift) + # We checkpoint stages_inputs since we are grabbing only one slice of the state_io and don't need to save the entire buffer.
stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, 'iteration_input') + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func() + + stages_output = vmap_func(decoder_layer_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + if self.config.scan_layers: + stages_output = stages_output[0] + + new_state = self.get_new_loop_state(stages_output, loop_state) + return new_state + + @nn.compact + def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp.ndarray, deterministic: bool, model_mode=common_types.MODEL_MODE_TRAIN) -> jnp.ndarray: + ''' The main method that maps the series of decoder layer inputs to final layer outputs. + Has the same signature as a single decoder layer, and expects the same shapes, e.g. the inputs should have a leading [global_batch] dimension, and internally + this will be reshaped into microbatches. + ''' + # Reshape inputs of [global_batch, ...] to [microbatches, microbatch_size, ...] + inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length, self.config.emb_dim)) + example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module weights. + if positions is not None: + positions = positions.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) + example_position = jax.lax.broadcast(positions[0], [self.num_stages]) + else: + example_position = None + if segment_ids is not None: + segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) + example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) + else: + example_segmentation = None + + loop_state = self.init_states(inputs) + + # Each microbatch should go through each stage - so there is num_micro * num_stages compute to perform + # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages / num_stages = num_micro + # However due to the pipeline bubble some iterations process fewer than num_stages microbatches. It takes + # num_micro - 1 iterations for the last microbatch to enter the pipeline, then num_stages more iterations to complete the pipeline. + # Thus the total number of iterations is num_micro + num_stages - 1, and we may consider the num_stages - 1 as bubble. + total_iterations = self.config.num_pipeline_microbatches + self.num_stages - 1 + + if self.is_initializing(): + vmap_func = self.get_main_vmap_func() + + # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for the full total_iterations.
+ stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + + # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has + # shape [microbatch_size, sequence, embed] + broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.global_batch_size_to_train_on // self.microbatch_size]) + return jnp.reshape(broadcasted_stage_outpus, [self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) + + def run_iteration_scannable(model,loop_state, xs): + # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we explicitly wrap + # the run_one_iteration in this method - the first argument model (i.e. self) is a nn.module instance. + return model.run_one_iteration(loop_state, positions, segment_ids, deterministic, model_mode, model.layers), None + if self.remat_policy is not None: + remat_policy = jax.checkpoint_policies.save_from_both_policies( + self.remat_policy, + jax.checkpoint_policies.save_only_these_names('iteration_input') + ) + else: + remat_policy = jax.checkpoint_policies.save_only_these_names('iteration_input') + run_one_iteration_rematted = nn.remat( + run_iteration_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan + policy=remat_policy + ) + + # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. + if self.config.scan_pipeline_iterations: + variable_carry = [] + variable_broadcast = ["params"] # All loop iterations need the weights for the full pipeline. + if self.is_mutable_collection("non_trainable"): + variable_carry.append("non_trainable") + else: + variable_broadcast.append("non_trainable") + run_all_iterations_scanned = nn.scan( + run_one_iteration_rematted, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + # Dropout/aqt keys will be split for each iteration. + split_rngs={"random": True}, + length=total_iterations, + ) + loop_state, _ = run_all_iterations_scanned(self, loop_state, None) + else: + for loop_iteration in range(total_iterations): + loop_state, _ = run_one_iteration_rematted(self, loop_state, None) + + # The final output is located in the input/output array, however the output microbatches may be permuted relative to the input + final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) + + # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] + final_output = jnp.reshape(final_output, (self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) + + return final_output \ No newline at end of file diff --git a/MaxText/layers/simple_layer.py b/MaxText/layers/simple_layer.py new file mode 100644 index 000000000..84a2ff5b8 --- /dev/null +++ b/MaxText/layers/simple_layer.py @@ -0,0 +1,41 @@ +""" +Copyright 2024 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" A simple decoder layer consisting of a single [embed, embed] weight matrix for testing and debugging purposes.""" + +from jax import numpy as jnp +from flax import linen as nn +from jax.sharding import Mesh +from typing import Optional +from layers import quantizations +import common_types + +# pytype: disable=attribute-error + +class SimpleDecoderLayer(nn.Module): + config: common_types.Config + mesh: Mesh + quant: Optional[quantizations.AqtQuantization] = None + + def setup(self): + self.weight_mat = self.param( + 'weights', + nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + (self.config.emb_dim, self.config.emb_dim) + ) + + def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode): + if self.config.scan_layers: + return inputs @ self.weight_mat.astype(inputs.dtype), None + else: + return inputs @ self.weight_mat.astype(inputs.dtype) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 0e372e294..8b2085d3e 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -309,6 +309,7 @@ def create_device_mesh(config, devices=None): multi_slice_env = num_slices > 1 dcn_parallelism = [ + config.dcn_pipeline_parallelism, config.dcn_data_parallelism, config.dcn_fsdp_parallelism, config.dcn_fsdp_transpose_parallelism, @@ -317,6 +318,7 @@ def create_device_mesh(config, devices=None): config.dcn_autoregressive_parallelism, ] ici_parallelism = [ + config.ici_pipeline_parallelism, config.ici_data_parallelism, config.ici_fsdp_parallelism, config.ici_fsdp_transpose_parallelism, diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index b07a80df5..31e3b2c02 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -194,7 +194,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.02): """ total_num_params = max_utils.calculate_num_params_from_pytree(params) product_num_devices_for_weight_sharding = 1 - for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor"]: + for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "stage"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 0abdfe95b..a49631ef5 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -275,6 +275,15 @@ def user_init(raw_keys): raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if using_pipeline_parallelism(raw_keys): + raw_keys["using_pipeline_parallelism"] = True + num_stages = int(raw_keys['ici_pipeline_parallelism'] * raw_keys['dcn_pipeline_parallelism']) + if raw_keys['num_pipeline_microbatches'] == -1: + raw_keys['num_pipeline_microbatches'] = num_stages + else: + raw_keys["using_pipeline_parallelism"] = False + + print_system_information() # Write raw_keys to GCS before type conversions @@ -411,6 +420,8 @@ def get_quantization_local_shard_count(raw_keys): else: return 
raw_keys["quantization_local_shard_count"] +def using_pipeline_parallelism(raw_keys) -> bool: + return int(raw_keys['ici_pipeline_parallelism']) > 1 or int(raw_keys['dcn_pipeline_parallelism']) > 1 class HyperParameters: # pylint: disable=missing-class-docstring diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py new file mode 100644 index 000000000..276cb7e53 --- /dev/null +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -0,0 +1,169 @@ +""" +Copyright 2024 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# pylint: disable=missing-module-docstring, missing-function-docstring +import sys + +import jax +from jax.sharding import Mesh + + +import unittest +import pytest + +import pyconfig + + +from layers import pipeline +import jax +from jax import numpy as jnp +from jax.sharding import Mesh + +import common_types +import pyconfig +import max_utils +from flax.core import meta + +import jax.numpy as jnp +from flax import linen as nn +from layers import simple_layer +from train import main as train_main + + + +def assert_same_output_and_grad(f1, f2, *inputs): + f1_value, f1_grad = jax.value_and_grad(f1)(*inputs) + f2_value, f2_grad = jax.value_and_grad(f2)(*inputs) + + def pytree_ravel(pytree): + ravelled_tree = jax.tree.map(jnp.ravel, pytree) + ravelled_leaves, _ = jax.tree_util.tree_flatten(ravelled_tree) + return jnp.concatenate(ravelled_leaves) + f1_grad = pytree_ravel(f1_grad) + f2_grad = pytree_ravel(f2_grad) + + assert jax.numpy.allclose(f1_value, f2_value, rtol=1e-2, equal_nan=False) + assert jax.numpy.allclose(f1_grad, f2_grad, rtol=1e-2, equal_nan=False) + + +class PipelineParallelismTest(unittest.TestCase): + + def assert_pipeline_same_output_and_grad(self, config): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + def get_inputs(batch_size, sequence, features): + '''Get random inputs, and random dummy targets + Returns + inputs: [batch_size, sequence, features] + targets: [batch_size, sequence, features] + positions: [batch_size, sequence] + segmentations: [batch_size, segmentation] + ''' + input_shape = [batch_size, sequence, features] + inputs = jax.random.normal(jax.random.PRNGKey(2), input_shape, dtype=jnp.float32) + + # dummy targets same shape as inputs to use for a dummy loss function to check gradient correctness + dummy_targets = jax.random.normal(jax.random.PRNGKey(3),input_shape, dtype=jnp.float32) + + inputs_position = jnp.array([jnp.arange(sequence, dtype=jnp.int32) for _ in range(batch_size)], dtype=jnp.int32) + inputs_segmentation = jnp.ones((batch_size, sequence), dtype=jnp.int32) + return inputs, dummy_targets, inputs_position, inputs_segmentation + + inputs, dummy_targets, inputs_position, inputs_segmentation = get_inputs(config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim) + deterministic = True + model_mode = common_types.MODEL_MODE_TRAIN + # We use a simpler single matmul decoder layer for fast compilation in these tests. 
+ single_pipeline_stage = simple_layer.SimpleDecoderLayer(config=config, mesh=mesh) + my_pipeline = pipeline.Pipeline( + config=config, + layers=single_pipeline_stage, + mesh=mesh + ) + init_pipeline_params = my_pipeline.init(jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + + # Create a dummy scalar loss function so we may take the gradient wrt weights + def pipeline_parallelism_dummy_loss(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets): + outputs = my_pipeline.apply(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + loss = jnp.linalg.norm(outputs - dummy_targets) + return loss + + def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): + def get_cur_layer_params(params, layer_idx): + def get_cur_layer_params_arr(leaf): + if config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] + return leaf[layer_idx] + return jax.tree.map(get_cur_layer_params_arr, params) + + reg_layer_activations = inputs + for layer in range(config.num_decoder_layers): + cur_layer_params = get_cur_layer_params(params, layer) + cur_layer_params['params'] = cur_layer_params['params']['layers'] + reg_layer_activations, _ = single_pipeline_stage.apply(cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode) + return reg_layer_activations + + def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets): + outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + loss = jnp.linalg.norm(outputs - dummy_targets) + return loss + + assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets) + + @pytest.mark.tpu + def test_non_circular_same_output_and_grad(self): + # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="non_circular", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=4, + num_pipeline_microbatches=4, + per_device_batch_size=4 + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu + def test_full_train_non_circular(self): + # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches + train_main([ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=8", + "num_pipeline_microbatches=8", + + ]) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/MaxText/train.py b/MaxText/train.py index 480feb809..a2baa653f 100644 --- 
a/MaxText/train.py +++ b/MaxText/train.py @@ -227,7 +227,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): ) one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - xent = nn.with_logical_constraint(xent, ("activation_batch", "activation_length")) + xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) # Mask out paddings at the end of each example. xent = xent * (data["targets_segmentation"] != 0) total_loss = jnp.sum(xent) @@ -387,7 +387,12 @@ def setup_train_loop(config): model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) - maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh) + if config.using_pipeline_parallelism: + # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage + params_sharded_tolerance=0.1 + else: + params_sharded_tolerance=0.02 + maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=params_sharded_tolerance) return ( init_rng,
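For reference, the Pipeline module can also be exercised standalone, outside of Decoder, in the same way the new pipeline_parallelism_test.py does. A condensed, illustrative sketch of that usage follows, assuming MaxText/ is on sys.path, the host exposes enough devices for ici_pipeline_parallelism, and an arbitrary run_name; the argument order follows Pipeline.__call__ (inputs, segment_ids, positions, deterministic, model_mode):

    # Condensed sketch following MaxText/tests/pipeline_parallelism_test.py; values are illustrative.
    import sys
    import jax
    from jax import numpy as jnp
    from jax.sharding import Mesh

    import common_types
    import max_utils
    import pyconfig
    from layers import pipeline, simple_layer

    pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="pipeline_sketch",
        max_target_length=128,
        base_emb_dim=28,
        ici_pipeline_parallelism=4,
        base_num_decoder_layers=4,
        num_pipeline_microbatches=4,
        per_device_batch_size=4,
    )
    config = pyconfig.config
    mesh = Mesh(max_utils.create_device_mesh(config), config.mesh_axes)

    # One SimpleDecoderLayer per stage; Decoder instead passes real (scanned or sequential) decoder layers.
    stage = simple_layer.SimpleDecoderLayer(config=config, mesh=mesh)
    pipe = pipeline.Pipeline(config=config, layers=stage, mesh=mesh)

    batch, seq, emb = config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim
    inputs = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, emb), dtype=jnp.float32)
    positions = jnp.broadcast_to(jnp.arange(seq, dtype=jnp.int32), (batch, seq))
    segment_ids = jnp.ones((batch, seq), dtype=jnp.int32)

    deterministic, model_mode = True, common_types.MODEL_MODE_TRAIN
    params = pipe.init(jax.random.PRNGKey(1), inputs, segment_ids, positions, deterministic, model_mode)
    outputs = pipe.apply(params, inputs, segment_ids, positions, deterministic, model_mode)
    assert outputs.shape == (batch, seq, emb)  # microbatches are reassembled into the global batch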