Merge pull request #691 from google:mattdavidow-pipeline-linear
PiperOrigin-RevId: 642648106
maxtext authors committed Jun 12, 2024
2 parents a2fdf29 + 5ae05cc commit 7cdca96
Showing 10 changed files with 669 additions and 48 deletions.
21 changes: 19 additions & 2 deletions MaxText/configs/base.yml
@@ -87,6 +87,14 @@ logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability

# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages and num_layers_per_pipeline_stage (does not yet support circular pipelines).
num_layers_per_pipeline_stage: 1
# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
num_pipeline_microbatches: -1
scan_pipeline_iterations: True # This can be set independently of scan_layers, which is relevant when num_layers_per_pipeline_stage > 1.
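
An aside, not part of this commit: a tiny Python sketch with hypothetical values to make the batch arithmetic described in these comments concrete (the variable names mirror the config keys; nothing here is MaxText code):

per_device_batch_size = 4        # hypothetical
num_devices = 16                 # hypothetical
num_stages = 4                   # ici_pipeline_parallelism * dcn_pipeline_parallelism
num_pipeline_microbatches = 8    # -1 in the config means "default to num_stages"

global_batch_size = per_device_batch_size * num_devices             # 64
assert num_pipeline_microbatches % num_stages == 0
microbatch_size = global_batch_size // num_pipeline_microbatches    # 8 examples per microbatch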

# Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded' and 'full'.
# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
remat_policy: 'full'
@@ -111,9 +119,13 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre- and post-decoder-layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# "stage" must be listed first because the microbatch dimension is the leading dimension before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
@@ -122,19 +134,22 @@ logical_axis_rules: [
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'tensor'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
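
An aside that is not part of this commit: the new activation_embed_and_logits_batch rule above maps the batch dimension of pre- and post-decoder-layer activations onto the stage axis (in addition to the usual data/fsdp axes). Below is a small Python sketch of how such a rule resolves to a mesh PartitionSpec, assuming the Flax helper flax.linen.logical_to_mesh_axes and using only a subset of the rules listed above:

import flax.linen as nn

rules = (
    ("activation_embed_and_logits_batch", ("stage", "data", "fsdp", "fsdp_transpose")),
    ("activation_length", ("sequence",)),
    ("activation_embed", ("tensor",)),
)
spec = nn.logical_to_mesh_axes(
    ("activation_embed_and_logits_batch", "activation_length", "activation_embed"),
    rules,
)
# spec now shards the batch dimension over ('stage', 'data', 'fsdp', 'fsdp_transpose'),
# the sequence dimension over 'sequence', and the embedding dimension over 'tensor'.
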
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
# Axes used for DCN must come earlier in this list than ICI axes; see b/339009148 for details
data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
@@ -145,13 +160,15 @@ dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
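
Another aside, not part of this commit: a hypothetical helper illustrating the -1 placeholder convention described above, where the single unset axis is resolved so that the product of the parallelism degrees equals the device count:

import math

def fill_auto_axis(parallelism, num_devices):
    """Replace at most one -1 placeholder with the factor that makes the product equal num_devices."""
    fixed = math.prod(v for v in parallelism.values() if v != -1)
    auto = [k for k, v in parallelism.items() if v == -1]
    assert len(auto) <= 1, "at most one placeholder axis per mesh"
    if auto:
        assert num_devices % fixed == 0
        parallelism[auto[0]] = num_devices // fixed  # mutates the input dict in place
    assert math.prod(parallelism.values()) == num_devices
    return parallelism

ici = {"data": 1, "fsdp": -1, "fsdp_transpose": 1, "sequence": 1,
       "tensor": 1, "autoregressive": 1, "pipeline": 1}
fill_auto_axis(ici, num_devices=16)   # 'fsdp' resolves to 16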

# The number of TPU slices is automatically determined; you should not set this explicitly. For ahead-of-time compilation,
# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
2 changes: 1 addition & 1 deletion MaxText/layers/embeddings.py
@@ -82,7 +82,7 @@ def __call__(self, inputs: Array) -> Array:
output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
else:
output = jnp.asarray(self.embedding, self.dtype)[inputs]
output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed"))
output = nn.with_logical_constraint(output, ("activation_embed_and_logits_batch", "activation_length", "activation_embed"))
return output

def attend(self, query: Array) -> Array:
129 changes: 87 additions & 42 deletions MaxText/layers/models.py
@@ -16,7 +16,7 @@
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Callable, Optional
from typing import Any, Callable, Optional


from flax import linen as nn
@@ -28,6 +28,7 @@
from layers import embeddings
from layers import linears
from layers import normalizations, quantizations
from layers import pipeline

Array = common_types.Array
Config = common_types.Config
@@ -145,6 +146,25 @@ def __call__(

return layer_output, None if cfg.scan_layers else layer_output

class SequentialBlockDecoderLayers(nn.Module):
"""Sequential unscanned series of decoder layers."""
decoder_layer: Any
num_decoder_layers: int
config: Config
mesh: Mesh
quant: Quant

@nn.compact
def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode) -> jnp.ndarray:
for lyr in range(self.num_decoder_layers):
inputs = self.decoder_layer(config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant)(
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)
return inputs

class Decoder(nn.Module):
"""A stack of decoder layers as a part of an encoder-decoder architecture."""
@@ -174,6 +194,10 @@ def get_decoder_layer(self):
from layers import gpt3

return gpt3.Gpt3DecoderLayer
elif self.config.decoder_block == "simple":
from layers import simple_layer

return simple_layer.SimpleDecoderLayer
else:
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}")

@@ -187,6 +211,34 @@ def get_norm_layer(self):
else:
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}")

def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mesh):
initializing = self.is_mutable_collection("params")
params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)
cache_spec = 0
scan_fn = nn.scan(
decoder_layer,
variable_axes={
"params": params_spec,
"cache": cache_spec,
"intermediates": 0,
"aqt": 0,
"_overwrite_with_gradient": 0,
},
split_rngs={
"params": True,
"dropout": cfg.enable_dropout,
},
in_axes=(
nn.broadcast,
nn.broadcast,
nn.broadcast,
nn.broadcast,
),
length=length,
metadata_params={nn.PARTITION_NAME: metdata_axis_name},
)
return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant)
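
An editorial aside, not part of the diff: scan_decoder_layers relies on Flax's scan-over-layers pattern, where nn.scan turns one layer definition into "length" layers whose parameters are stacked along a new leading axis. A minimal, self-contained toy of that pattern (the ToyLayer/ToyStack names are hypothetical, not MaxText code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, _unused):
        # carry in, (carry out, per-step output) out -- the contract nn.scan expects
        return nn.Dense(self.features)(x), None

class ToyStack(nn.Module):
    num_layers: int

    @nn.compact
    def __call__(self, x):
        scanned = nn.scan(
            ToyLayer,
            variable_axes={"params": 0},   # stack every layer's params along a new leading axis
            split_rngs={"params": True},   # each layer gets its own init RNG
            in_axes=(nn.broadcast,),       # broadcast the extra (unused) argument
            length=self.num_layers,
        )
        y, _ = scanned(features=x.shape[-1], name="layers")(x, None)
        return y

x = jnp.ones((2, 8))
variables = ToyStack(num_layers=4).init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(jnp.shape, variables))  # each Dense kernel is stacked to (4, 8, 8)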

@nn.compact
def __call__(
self,
@@ -266,53 +318,46 @@ def __call__(
else:
assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies"
policy = None
BlockLayer = nn.remat( # pylint: disable=invalid-name
BlockLayer,
prevent_cse=not cfg.scan_layers,
policy=policy,
static_argnums=(-1, -2, -3, -4, -5),
)
if cfg.scan_layers:
initializing = self.is_mutable_collection("params")
params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)
cache_spec = 0
y, _ = nn.scan(
BlockLayer,
variable_axes={
"params": params_spec,
"cache": cache_spec,
"intermediates": 0,
"aqt": 0,
"_overwrite_with_gradient": 0,
},
split_rngs={
"params": True,
"dropout": cfg.enable_dropout,
},
in_axes=(
nn.broadcast,
nn.broadcast,
nn.broadcast,
nn.broadcast,
),
length=cfg.num_decoder_layers,
metadata_params={nn.PARTITION_NAME: "layers"},
)(config=cfg, mesh=mesh, name="layers", quant=self.quant)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)

RemattedBlockLayer = nn.remat( # pylint: disable=invalid-name
BlockLayer,
prevent_cse=not cfg.scan_layers,
policy=policy,
static_argnums=(-1, -2, -3, -4, -5),
)
if cfg.using_pipeline_parallelism:
if cfg.num_layers_per_pipeline_stage == 1:
stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant)
elif cfg.scan_layers:
stage_module = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh)
elif not cfg.scan_layers:
stage_module = SequentialBlockDecoderLayers(decoder_layer=RemattedBlockLayer, num_decoder_layers=cfg.num_layers_per_pipeline_stage, config=cfg, mesh=mesh, quant=self.quant)

y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)
else:
for lyr in range(cfg.num_decoder_layers):
y = BlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)(
if cfg.scan_layers:
y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)
else:
for lyr in range(cfg.num_decoder_layers):
y = RemattedBlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)

y = self.get_norm_layer()(
dtype=cfg.dtype,
@@ -340,7 +385,7 @@ def __call__(
)(
y
) # We do not quantize the logits matmul.
logits = nn.with_logical_constraint(logits, ("activation_batch", "activation_length", "activation_vocab"))
logits = nn.with_logical_constraint(logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab"))
logits = logits.astype(jnp.float32)
return logits
