Skip to content

Commit

Permalink
Make label data optional: broadcast the "label" key only when label_data_paths is set, and use None (not []) as the zip_longest fallback for missing label paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin TastyRice committed Jun 21, 2023
1 parent c00ce70 commit fba32d5
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions configs/1-3B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,
"model_parallel_size": 2,

# model settings
"num_layers": 24,
Expand Down Expand Up @@ -74,7 +74,7 @@
},

# misc. training settings
"train_iters": 320000,
"train_iters": 320,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
Expand Down
Binary file added data/enwik8/enwik8.zip
Binary file not shown.
3 changes: 2 additions & 1 deletion megatron/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def build_the_dataset(
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
if label_prefix is None:
label_dataset = None

else:
label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)

Expand Down Expand Up @@ -207,7 +208,7 @@ def build_weighted_datasets(
for i, (train_path, label_path, valid_path, test_path) in enumerate(
zip_longest(
neox_args.train_data_paths,
neox_args.label_data_paths if neox_args.label_data_paths else [],
neox_args.label_data_paths if neox_args.label_data_paths else None,
neox_args.valid_data_paths,
neox_args.test_data_paths,
)
Expand Down
6 changes: 3 additions & 3 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def pretrain(neox_args):

def _get_batch(neox_args, tokenizer, keys, data, datatype):
"""Support function for get_batch / get_batch pipe (to avoid code repetition)"""
_keys = [k for k in keys if k in data] if data else keys
_keys = [k for k in keys if k in data.keys()] if data else keys
data_b = mpu.broadcast_data(_keys, data, datatype)

# Unpack.
Expand Down Expand Up @@ -304,7 +304,7 @@ def get_batch(neox_args, data_iterator):
"""Generate a batch"""

# Items and their type.
keys = ["text", "label"]
keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
datatype = torch.int64

# Broadcast data.
Expand All @@ -324,7 +324,7 @@ def get_batch(neox_args, data_iterator):
def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
# Items and their type.
keys = ["text", "label"]
keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
datatype = torch.int64

tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
Expand Down

0 comments on commit fba32d5

Please sign in to comment.