Skip to content

Commit

Permalink
[MLPerf GPT3] add c4_mlperf input pipeline and eval step
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Feb 13, 2024
1 parent 3806cd4 commit 3b22241
Show file tree
Hide file tree
Showing 18 changed files with 585 additions and 74 deletions.
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,6 @@ decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of g
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
decode_sampling_top_k: 0 # set if you're doing top-k
decode_sampling_temperature: 1.

eval_interval: -1 # number of train steps between eval steps
target_eval_loss: 0. # stop training early once the target eval loss is reached
12 changes: 12 additions & 0 deletions MaxText/configs/models/gpt3-175b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ use_iota_embed: True
fused_qkv: True
opt_type: "adam_pax"
decoder_block: "gpt3"
dataset_path: "gs://mlperf-llm-public2"
dataset_name: "c4/en:3.0.4"
eval_dataset_name: "c4/en:3.0.5"
assets_path: "gs://mlperf-llm-public2/vocab"
vocab_relative_path: "c4_en_301_5Mexp2_spm.model"
gradient_clipping_threshold: 1.
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-8
adam_weight_decay: 0.1
checkpoint_period: 10_000
target_eval_loss: 2.69
5 changes: 5 additions & 0 deletions MaxText/configs/models/gpt3-22b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,8 @@ use_iota_embed: True
fused_qkv: True
opt_type: "adam_pax"
decoder_block: "gpt3"
gradient_clipping_threshold: 1.
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-8
adam_weight_decay: 0.1
5 changes: 5 additions & 0 deletions MaxText/configs/models/gpt3-52k.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ use_iota_embed: True
fused_qkv: True
opt_type: "adam_pax"
decoder_block: "gpt3"
gradient_clipping_threshold: 1.
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-8
adam_weight_decay: 0.1
2 changes: 1 addition & 1 deletion MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def decode_loop(config, state=None):
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh = mesh, quant=quant)
_, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos=False)
_, _, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos=False)
state, state_mesh_annotations = max_utils.setup_decode_state(
model, config, rng, mesh, None
)
Expand Down
7 changes: 5 additions & 2 deletions MaxText/input_pipeline/_grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from input_pipeline import _grain_operations
from input_pipeline import _grain_tokenizer

import multihost_dataloading

def get_datasets(
config: ml_collections.ConfigDict
):
Expand Down Expand Up @@ -169,6 +171,7 @@ def preprocessing_pipeline(
worker_count=grain_worker_count,
)

data_iter = iter(dataloader)
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)

return data_iter
# Return multi-host jax.Array prep iterator
return multihost_gen
12 changes: 8 additions & 4 deletions MaxText/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tensorflow_datasets as tfds
import jax

import multihost_dataloading
import tokenizer
import sequence_packing

Expand Down Expand Up @@ -161,7 +162,10 @@ def truncate_to_max_allowable_length(x, max_length):
if prefetch_size:
dataset = dataset.prefetch(prefetch_size)

return iter(dataset.as_numpy_iterator())
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, global_mesh)

# Return multi-host jax.Array prep iterator
return multihost_gen


def get_datasets(
Expand All @@ -187,7 +191,7 @@ def get_datasets(
# eval_data = get_raw_dataset(eval_ds_builder, config.eval_split)
eval_ds = eval_ds_builder.as_dataset(split=config.eval_split,
read_config = read_config,
shuffle_files=config.enable_data_shuffling)
shuffle_files=False)
eval_ds = eval_ds.shard(num_shards = jax.process_count(), index = jax.process_index())
eval_ds = normalize_features(eval_ds)

Expand Down Expand Up @@ -244,10 +248,11 @@ def filter_keys(record):
eval_ds,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
shuffle=False,
pack_examples=False,
max_length=config.max_target_length,
shift=False,
drop_remainder=False,
data_shuffle_seed = data_shuffle_seed,)

predict_iter = preprocessing_pipeline(
Expand All @@ -262,4 +267,3 @@ def filter_keys(record):
data_shuffle_seed = data_shuffle_seed,)

return train_iter, eval_iter, predict_iter, sp_tokenizer

Loading

0 comments on commit 3b22241

Please sign in to comment.