Small changes for MosaicML cloud usage (EleutherAI#778)
* debug print

* add force multi arg

* Add a JSON compatible config file

* debug print

* fix config

* update global num gpus

* remove

* allow data to be processed on all nodes

* try local rank

* switch to local rank

* clean up neox changes

* testing change to config

* revert config change

* attempt to remove the need for the hostfile

* revert
dakinggg committed Feb 3, 2023
1 parent 841679f commit 0c831c3
Showing 4 changed files with 96 additions and 2 deletions.
78 changes: 78 additions & 0 deletions configs/125M-json.yml
@@ -0,0 +1,78 @@
{
"pipe-parallel-size": 1,
"model-parallel-size": 1,

"num-layers": 12,
"hidden-size": 768,
"num-attention-heads": 12,
"seq-length": 2048,
"max-position-embeddings": 2048,
"norm": "layernorm",
"pos-emb": "rotary",
"no-weight-tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

"scaled-upper-triang-masked-softmax-fusion": false,
"bias-gelu-fusion": false,

"init_method": "small_init",
"output_layer_init_method": "wang_init",

"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.95],
"eps": 1.0e-8
}
},
"min_lr": 0.00006,

"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 500000000,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 500000000,
"contiguous_gradients": true
},

"train_micro_batch_size_per_gpu": 4,
"data-impl": "mmap",

"checkpoint-activations": true,
"checkpoint-num-layers": 1,
"partition-activations": true,
"synchronize-each-layer": true,

"gradient_clipping": 1.0,
"weight-decay": 0.1,
"hidden-dropout": 0.0,
"attention-dropout": 0.0,

"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

"train-iters": 320000,
"lr-decay-iters": 320000,
"distributed-backend": "nccl",
"lr-decay-style": "cosine",
"warmup": 0.01,
"checkpoint-factor": 10000,
"eval-interval": 1000,
"eval-iters": 10,

"log-interval": 100,
"steps_per_print": 10,
"keep-last-n-checkpoints": 4,
"wall_clock_breakdown": true,

"hostfile": "/mock_path"
}
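
As the commit message notes, this config is written to be JSON-compatible despite the .yml extension, so JSON-only tooling can consume it directly. A minimal sanity check (illustrative only; it assumes the repository root as the working directory):

import json

# The file above contains plain JSON, so the standard json module parses it as-is.
with open("configs/125M-json.yml") as f:
    cfg = json.load(f)

print(cfg["train_micro_batch_size_per_gpu"])  # 4
print(cfg["hostfile"])                        # /mock_path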
3 changes: 3 additions & 0 deletions megatron/data/data_utils.py
@@ -85,6 +85,7 @@ def build_the_dataset(

def build_train_valid_test_datasets(
data_prefix,
+use_shared_fs,
data_impl,
splits_string,
train_valid_test_num_samples,
@@ -131,6 +132,7 @@ def build_dataset(index, name):
train_valid_test_num_samples[index],
seq_length,
seed,
+use_shared_fs=use_shared_fs
)
return dataset

@@ -394,6 +396,7 @@ def build_train_valid_test_data_iterators(neox_args):
# split dataset into train, valid and test from data_path
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=neox_args.data_path,
+use_shared_fs=neox_args.use_shared_fs,
data_impl=neox_args.data_impl,
splits_string=neox_args.split,
train_valid_test_num_samples=train_val_test_num_samples,
11 changes: 9 additions & 2 deletions megatron/data/gpt2_dataset.py
@@ -37,6 +37,7 @@ def __init__(
seq_length,
seed,
build_index_mappings=True,
+use_shared_fs=True
):

self.name = name
@@ -56,6 +57,7 @@ def __init__(
num_samples,
seq_length,
seed,
+use_shared_fs=use_shared_fs
)
self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
self.sample_idx_len = self.sample_idx.shape[0] - 1
@@ -110,7 +112,7 @@ def __getitem__(self, idx):


def _build_index_mappings(
-name, data_prefix, documents, sizes, num_samples, seq_length, seed
+name, data_prefix, documents, sizes, num_samples, seq_length, seed, use_shared_fs=True
):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
@@ -134,8 +136,13 @@ def _build_index_mappings(
sample_idx_filename = _filename + "_sample_idx.npy"
shuffle_idx_filename = _filename + "_shuffle_idx.npy"

+if not use_shared_fs:
+should_process_dataset = int(os.environ['LOCAL_RANK']) == 0
+else:
+should_process_dataset = torch.distributed.get_rank() == 0
+
# Build the indexed mapping if not exist.
-if torch.distributed.get_rank() == 0:
+if should_process_dataset:
if (
(not os.path.isfile(doc_idx_filename))
or (not os.path.isfile(sample_idx_filename))
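
The change above gates index building on a per-node basis when no shared filesystem is available. A standalone sketch of the same pattern (the helper name is illustrative, not the repo's code; LOCAL_RANK is assumed to be set by the launcher):

import os

import torch.distributed as dist


def should_build_index(use_shared_fs: bool = True) -> bool:
    # Shared filesystem: a single writer is enough, so only global rank 0
    # builds the cached .npy index files and every other rank reads them.
    # Node-local storage: each node needs its own copy, so local rank 0 on
    # every node builds the files for that node.
    if not use_shared_fs:
        return int(os.environ["LOCAL_RANK"]) == 0
    return dist.get_rank() == 0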
6 changes: 6 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -687,6 +687,12 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Path to combined dataset to split.
"""

+use_shared_fs: bool = True
+"""
+Whether to use a shared filesystem for data loading. If False, local rank 0 on all nodes will preprocess the data,
+otherwise only global rank 0 will preprocess the data. This is implemented in megatron/data/gpt2_dataset.py::_build_index_mappings.
+"""
+
train_data_paths: list = None
"""
List of paths to train datasets.
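
Because use_shared_fs defaults to True, existing setups keep the old behaviour (only global rank 0 preprocesses the data). Assuming the flag is set like any other training argument in the config files above, a cluster without a shared filesystem would add:

"use_shared_fs": false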
