Initial DiLoCo experiments
jonb377 committed Jun 25, 2024
1 parent 7a872f9 commit 6af32a1
Showing 7 changed files with 118 additions and 18 deletions.
10 changes: 9 additions & 1 deletion MaxText/configs/base.yml
@@ -135,7 +135,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
mesh_axes: ['client', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
@@ -183,13 +183,15 @@ dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
dcn_client_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
ici_client_parallelism: 1

# The number of TPU slices is automatically determined; you should not set this explicitly. For ahead-of-time compilation,
# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
Expand Down Expand Up @@ -224,6 +226,12 @@ grain_worker_count: 1
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
log_period: 100 # Flushes Tensorboard

# DiLoCo Parameters
# Each worker replica will execute `diloco_sync_period` training steps before global synchronization.
diloco_sync_period: 20
diloco_outer_momentum: 0.9
diloco_outer_lr: 0.7

# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# Learning rate schedule has either two or three parts:
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
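
For orientation, here is an illustrative Python sketch (not part of this commit) of how the three DiLoCo knobs above are consumed: diloco_sync_period sets the number of local steps each worker runs per round, while diloco_outer_lr and diloco_outer_momentum parameterize the outer Nesterov-SGD optimizer introduced in MaxText/optimizers.py further down. The step arithmetic is purely illustrative.

import optax

# Outer optimizer applied once per synchronization round
# (same construction as the optimizers.py change below).
outer_tx = optax.sgd(0.7, momentum=0.9, nesterov=True)  # diloco_outer_lr, diloco_outer_momentum

# With steps=150_001 and diloco_sync_period=20, each worker runs 20 local
# AdamW steps per round, giving roughly 150_001 // 20 = 7_500 outer updates.
num_outer_rounds = 150_001 // 20
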
69 changes: 69 additions & 0 deletions MaxText/diloco.py
@@ -0,0 +1,69 @@
import jax
import jax.numpy as jnp
import max_logging
import optimizers
from flax.training import train_state


def get_diloco_train_step(config, train_step):
# Defer import until needed
import drjax

@drjax.program(placements={'clients': config.diloco_num_workers})
def diloco_train_step(model, config, state, data, dropout_rng):
"""
Run a DiLoCo round. DiLoCo executes multiple optimization steps within
each worker, then synchronizes the net change in the model to perform a
global state update.
In this implementation, each worker initializes its own AdamW optimizer
during each round of optimization.
"""

def scan_fn(carry, data):
""" Executes a single inner optimization step. """
state, step = carry
nextrng = jax.jit(jax.random.fold_in)(dropout_rng, step)
state, metrics = train_step(model, config, state, data, nextrng)
return (state, step + 1), metrics

def worker_round(start_step, params, worker_inputs):
"""
Execute one local round of optimization. This executes
`config.diloco_sync_period` steps locally without any cross-client
communication.
"""
# Initialize an AdamW optimizer for the local worker.
# TODO(jonbolin): Need to preserve optimizer state to use an LR schedule
adamw_tx = optimizers.get_optimizer(config, config.learning_rate, inner_diloco=True)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=adamw_tx)

# Scan over local steps, carrying the current step number and updated state.
(final_state, _), metrics = jax.lax.scan(scan_fn, (state, start_step), worker_inputs)
metrics = jax.tree.map(lambda x: jnp.average(x), metrics)

# Compute the outer gradient: the initial parameters minus the locally optimized parameters (DiLoCo's pseudo-gradient).
model_delta = jax.tree.map(lambda x, y: x - y, params, final_state.params)
return model_delta, metrics

max_logging.log('Running training with DiLoCo')

# Broadcast the global model parameters and step count to every client
params_in_clients = drjax.broadcast(state.params)
start_step_in_clients = drjax.broadcast(state.step)
#init_rng_in_clients = drjax.broadcast(init_rng)

# Run optimization locally on each worker. The final state within each worker
# is discarded, only the aggregate change from each worker is reported.
local_grads, local_metrics = drjax.map_fn(worker_round, (start_step_in_clients, params_in_clients, data))

# DiLoCo Algorithm
# Average the outer gradients across workers
average_grad = drjax.reduce_mean(local_grads)
total_metrics = drjax.reduce_mean(local_metrics)
# Update global state.
state = state.apply_gradients(grads=average_grad)

return state, total_metrics

return diloco_train_step
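
To make the round structure above concrete, the following self-contained toy (pure JAX and optax, no drjax, not part of this commit) mimics the same inner/outer split on a scalar least-squares problem: each client scans diloco_sync_period inner AdamW steps, the per-client deltas are averaged, and one outer Nesterov-SGD update is applied. jax.vmap stands in for drjax.map_fn, and a plain mean over the client axis stands in for drjax.reduce_mean.

import jax
import jax.numpy as jnp
import optax

num_clients, sync_period = 4, 20
inner_tx = optax.adamw(1e-2)                             # per-client inner optimizer
outer_tx = optax.sgd(0.7, momentum=0.9, nesterov=True)   # shared outer optimizer

def loss_fn(params, batch):
  x, y = batch
  return jnp.mean((params["w"] * x - y) ** 2)

def worker_round(params, worker_batches):
  """Run sync_period local AdamW steps and return the parameter delta."""
  opt_state = inner_tx.init(params)

  def inner_step(carry, batch):
    p, s = carry
    grads = jax.grad(loss_fn)(p, batch)
    updates, s = inner_tx.update(grads, s, p)
    p = optax.apply_updates(p, updates)
    return (p, s), loss_fn(p, batch)

  (final_params, _), losses = jax.lax.scan(inner_step, (params, opt_state), worker_batches)
  # Outer "gradient": starting parameters minus locally optimized parameters.
  delta = jax.tree.map(lambda a, b: a - b, params, final_params)
  return delta, losses.mean()

params = {"w": jnp.array(0.0)}
outer_state = outer_tx.init(params)
x = jax.random.normal(jax.random.PRNGKey(0), (num_clients, sync_period, 32))
batches = (x, 3.0 * x)  # every client sees data generated by w = 3

# vmap over the leading client axis plays the role of drjax.map_fn.
deltas, client_losses = jax.vmap(worker_round, in_axes=(None, 0))(params, batches)
avg_delta = jax.tree.map(lambda d: d.mean(axis=0), deltas)   # drjax.reduce_mean
updates, outer_state = outer_tx.update(avg_delta, outer_state, params)
params = optax.apply_updates(params, updates)                # global state update
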
28 changes: 14 additions & 14 deletions MaxText/input_pipeline/input_pipeline_interface.py
@@ -130,21 +130,21 @@ def __next__(self):
@staticmethod
def raw_generate_synthetic_data(config):
"""Generates synthetic data: a single batch, or a full DiLoCo round of per-worker batches when DiLoCo is enabled."""
data_shape = (config.global_batch_size_to_load, config.max_target_length)
if config.diloco_num_workers > 1:
# Generate an input of shape NumClients x StepsBetweenSyncs x ClientBatch x Sequence
num_workers = config.diloco_num_workers
steps_per_sync = config.diloco_sync_period
client_batch = config.global_batch_size_to_load // config.diloco_num_workers
seq = config.max_target_length
data_shape = (num_workers, steps_per_sync, client_batch, seq)
output = {}
output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
output["inputs_position"] = jax.numpy.zeros(
(config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
)
output["inputs_segmentation"] = jax.numpy.ones(
(config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
)
output["targets"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
output["targets_position"] = jax.numpy.zeros(
(config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
)
output["targets_segmentation"] = jax.numpy.ones(
(config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
)
output["inputs"] = jax.numpy.zeros(data_shape, dtype=jax.numpy.int32)
output["inputs_position"] = jax.numpy.zeros(data_shape, dtype=jax.numpy.int32)
output["inputs_segmentation"] = jax.numpy.ones(data_shape, dtype=jax.numpy.int32)
output["targets"] = jax.numpy.zeros(data_shape, dtype=jax.numpy.int32)
output["targets_position"] = jax.numpy.zeros(data_shape, dtype=jax.numpy.int32)
output["targets_segmentation"] = jax.numpy.ones(data_shape, dtype=jax.numpy.int32)
return output


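
As a quick sanity check of the new synthetic-data shape (illustrative numbers, not MaxText defaults, and not part of this commit):

# NumClients x StepsBetweenSyncs x ClientBatch x Sequence, e.g. with
# diloco_num_workers=4, diloco_sync_period=20, global_batch_size_to_load=32,
# max_target_length=2048.
num_workers, steps_per_sync, global_batch, seq = 4, 20, 32, 2048
data_shape = (num_workers, steps_per_sync, global_batch // num_workers, seq)
assert data_shape == (4, 20, 8, 2048)
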
2 changes: 2 additions & 0 deletions MaxText/max_utils.py
@@ -336,6 +336,7 @@ def create_device_mesh(config, devices=None):
multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_client_parallelism,
config.dcn_data_parallelism,
config.dcn_pipeline_parallelism,
config.dcn_fsdp_parallelism,
@@ -345,6 +346,7 @@
config.dcn_autoregressive_parallelism,
]
ici_parallelism = [
config.ici_client_parallelism,
config.ici_data_parallelism,
config.ici_pipeline_parallelism,
config.ici_fsdp_parallelism,
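
For intuition on the new leading axis, a rough sketch (plain JAX forced onto 8 host CPU devices; not the MaxText helper, which builds the full axis list from the ici/dcn factors above): with a client dimension of 2, the first mesh axis splits the devices into two disjoint DiLoCo workers.

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # 8 fake CPU devices

import numpy as np
import jax
from jax.sharding import Mesh

# e.g. ici_client_parallelism=2 and ici_fsdp_parallelism=4 on 8 devices:
# 'client' leads the mesh, so the two DiLoCo workers never share devices.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("client", "fsdp"))
print(mesh.shape)  # {'client': 2, 'fsdp': 4}
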
11 changes: 9 additions & 2 deletions MaxText/optimizers.py
@@ -24,9 +24,16 @@
import jax.numpy as jnp


def get_optimizer(config, learning_rate_schedule):
def get_optimizer(config, learning_rate_schedule, inner_diloco=False):
"""create optimizer"""
if config.opt_type == "adamw":
if config.diloco_num_workers > 1 and not inner_diloco:
# When training DiLoCo (https://arxiv.org/pdf/2311.08105), use SGD with Nesterov momentum for the outer optimizer
return optax.sgd(
config.diloco_outer_lr,
momentum=config.diloco_outer_momentum,
nesterov=True,
)
elif config.opt_type == "adamw":
# Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
return optax.adamw(
learning_rate_schedule,
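
In other words, once DiLoCo is enabled the optimizer handed back by default is the outer Nesterov SGD, and each worker requests its own AdamW explicitly via inner_diloco=True (as diloco.py above does). A minimal sketch with a hypothetical config stand-in, not the real MaxText config object:

from types import SimpleNamespace
import optax

config = SimpleNamespace(opt_type="adamw", diloco_num_workers=4,
                         diloco_outer_lr=0.7, diloco_outer_momentum=0.9)

def pick_optimizer(config, learning_rate, inner_diloco=False):
  # Mirrors the branch added to get_optimizer above.
  if config.diloco_num_workers > 1 and not inner_diloco:
    return optax.sgd(config.diloco_outer_lr,
                     momentum=config.diloco_outer_momentum, nesterov=True)
  return optax.adamw(learning_rate)

outer_tx = pick_optimizer(config, 3e-4)                     # outer: Nesterov SGD
inner_tx = pick_optimizer(config, 3e-4, inner_diloco=True)  # inner: per-worker AdamW
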
12 changes: 12 additions & 0 deletions MaxText/pyconfig.py
@@ -292,6 +292,13 @@ def user_init(raw_keys):
if raw_keys["steps"] == -1:
raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"]

raw_keys["diloco_num_workers"] = diloco_num_workers(raw_keys)
if raw_keys["diloco_num_workers"] > 1:
assert raw_keys["diloco_sync_period"] > 0, "diloco_sync_period must be positive when using DiLoCo"
# Adjust data sharding to reflect DiLoCo shaping. Input data will have shape
# NumClients x StepsBetweenSyncs x ClientBatch x Sequence
raw_keys["data_sharding"] = ["client", None, *raw_keys["data_sharding"]]

emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys["global_parameter_scale"])
raw_keys["emb_dim"] = 2**emb_scale * raw_keys["base_emb_dim"]
raw_keys["num_query_heads"] = 2**num_head_scale * raw_keys["base_num_query_heads"]
@@ -300,6 +307,8 @@
raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"]

raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys)
assert raw_keys["global_batch_size_to_train_on"] % raw_keys["diloco_num_workers"] == 0, \
f"Global batch ({raw_keys['global_batch_size_to_train_on']}) must be divisible by diloco_num_workers ({raw_keys['diloco_num_workers']}) with DiLoCo training"
raw_keys["num_slices"] = get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)

@@ -468,6 +477,9 @@ def get_quantization_local_shard_count(raw_keys):
else:
return raw_keys["quantization_local_shard_count"]

def diloco_num_workers(raw_keys) -> int:
return int(raw_keys["ici_client_parallelism"]) * int(raw_keys["dcn_client_parallelism"])

def using_pipeline_parallelism(raw_keys) -> bool:
return int(raw_keys['ici_pipeline_parallelism']) > 1 or int(raw_keys['dcn_pipeline_parallelism']) > 1

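
A worked example of the validation and reshaping above (illustrative values, not part of this commit): with ici_client_parallelism=1 and dcn_client_parallelism=4 there are 4 DiLoCo workers, the global batch must divide evenly across them, and the data sharding gains a leading 'client' axis plus an unsharded steps axis to match the NumClients x StepsBetweenSyncs x ClientBatch x Sequence layout.

raw_keys = {"ici_client_parallelism": 1, "dcn_client_parallelism": 4,
            "data_sharding": ["data"], "global_batch_size_to_train_on": 32}

num_workers = int(raw_keys["ici_client_parallelism"]) * int(raw_keys["dcn_client_parallelism"])
assert num_workers == 4
assert raw_keys["global_batch_size_to_train_on"] % num_workers == 0
raw_keys["data_sharding"] = ["client", None, *raw_keys["data_sharding"]]
assert raw_keys["data_sharding"] == ["client", None, "data"]
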
4 changes: 3 additions & 1 deletion MaxText/train.py
@@ -43,6 +43,7 @@
import pyconfig
# pylint: disable-next=unused-import
import register_jax_proxy_backend
from diloco import get_diloco_train_step
from vertex_tensorboard import VertexTensorboardManager
# Placeholder: internal

@@ -432,14 +433,15 @@ def train_loop(config, state=None):
eval_data_iterator,
state,
) = setup_train_loop(config)
train_step_fn = get_diloco_train_step(config, train_step) if config.diloco_num_workers > 1 else train_step
# pylint: disable=line-too-long
(
functional_train,
in_shard_train,
out_shard_train,
static_argnums_train,
donate_argnums_train,
) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config)
) = maxtext_utils.get_functional_train_with_signature(train_step_fn, mesh, state_mesh_annotations, model, config)

if eval_data_iterator:
# pylint: disable=line-too-long
