Adding the possibility of passing a label dataset (#958)
* add an optional `label` field passed in parallel with training data.

* minor fix; Add doc

* fix

* fix data can be None

* prevent loading optimizer

* add script

* Remove some print() stmts, make mask documentation clearer

* Add documentation for preprocess_data_with_mask.py

---------

Co-authored-by: Hailey Schoelkopf <[email protected]>
honglu2875 and haileyschoelkopf committed Jun 7, 2023
1 parent eedf1a8 commit c00ce70
Showing 7 changed files with 448 additions and 27 deletions.
7 changes: 7 additions & 0 deletions configs/neox_arguments.md
@@ -1061,6 +1061,13 @@ Training Arguments
List of paths to train datasets.


- **label_data_paths**: list

Default = None

List of paths to label datasets (should be fully in sync with train data, not shifted by 1!).



- **test_data_paths**: list

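To make the "fully in sync" requirement concrete, here is a minimal sketch (not part of the commit; token values are illustrative): the label array has the same length and positions as the token array, negative entries such as -100 mark positions the loss should skip (matching the `_get_batch` handling in megatron/training.py further down), and the shift by one happens later, inside the training code.

import numpy as np

# Hypothetical packed sample: a short prompt followed by a completion.
tokens = np.array([12, 887, 5, 3041, 99, 2], dtype=np.int64)

# Parallel labels, same length, NOT pre-shifted. Negative values (-100 here)
# mark prompt positions whose loss should be skipped; the rest repeat the token id.
labels = np.array([-100, -100, -100, 3041, 99, 2], dtype=np.int64)

assert tokens.shape == labels.shape  # the two datasets must stay in lockstep

# The trainer performs the usual shift once, on both streams:
inputs = tokens[:-1]            # what the model sees
targets = labels[1:]            # still aligned with `inputs`, position for position
loss_weights = (targets >= 0)   # negative labels drop out of the loss
print(inputs, np.where(loss_weights, targets, 0), loss_weights.astype(np.int64))
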
10 changes: 9 additions & 1 deletion megatron/data/data_utils.py
@@ -60,10 +60,15 @@ def build_the_dataset(
seed,
skip_warmup,
build_index_mappings=True,
label_prefix=None,
):
"""Build train/valid/test datasets."""

indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
if label_prefix is None:
label_dataset = None
else:
label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)

total_num_of_documents = indexed_dataset.sizes.shape[0]
print_rank_0(" {}:".format(name))
@@ -79,6 +84,7 @@
seq_length,
seed,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
)
return dataset

@@ -198,9 +204,10 @@ def build_weighted_datasets(
):
# build individual datasets
train_datasets, valid_datasets, test_datasets = [], [], []
-for i, (train_path, valid_path, test_path) in enumerate(
+for i, (train_path, label_path, valid_path, test_path) in enumerate(
zip_longest(
neox_args.train_data_paths,
neox_args.label_data_paths if neox_args.label_data_paths else [],
neox_args.valid_data_paths,
neox_args.test_data_paths,
)
@@ -216,6 +223,7 @@
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
)
)

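One detail worth noting in the loop above: `zip_longest` pads the shorter lists with `None`, so when `label_data_paths` is unset the empty list simply yields `label_path = None` for every train path, and `build_the_dataset` receives `label_prefix=None`. A standalone sketch of that behavior (the paths are placeholders, not real files):

from itertools import zip_longest

train_paths = ["data/shard00_text_document", "data/shard01_text_document"]  # placeholders
label_paths = []  # what the call site passes when neox_args.label_data_paths is None
valid_paths = ["data/valid_text_document"]
test_paths = ["data/test_text_document"]

for i, (train_path, label_path, valid_path, test_path) in enumerate(
    zip_longest(train_paths, label_paths, valid_paths, test_paths)
):
    # Missing entries come back as None, which downstream code treats as
    # "no label dataset for this shard".
    print(i, train_path, label_path, valid_path, test_path)
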
51 changes: 30 additions & 21 deletions megatron/data/gpt2_dataset.py
@@ -38,10 +38,12 @@ def __init__(
seed,
build_index_mappings=True,
use_shared_fs=True,
label_dataset=None,
):

self.name = name
self.indexed_dataset = indexed_dataset
self.label_dataset = label_dataset

# Checks
assert np.min(documents) >= 0
@@ -79,30 +81,37 @@ def __getitem__(self, idx):
    doc_index_l = self.sample_idx[idx + 1][0]
    offset_f = self.sample_idx[idx][1]
    offset_l = self.sample_idx[idx + 1][1]
+   # Labels and texts are supposed to be fully in sync.
+   datasets = [self.indexed_dataset] if self.label_dataset is None else [self.indexed_dataset, self.label_dataset]
+   samples = []
    # If we are within the same document, just extract the chunk.
-   if doc_index_f == doc_index_l:
-       sample = self.indexed_dataset.get(
-           self.doc_idx[doc_index_f],
-           offset=offset_f,
-           length=offset_l - offset_f + 1,
-       )
-   else:
-       # Otherwise, get the rest of the initial document.
-       sample_list = [
-           self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
-       ]
-       # Loop over all in between documents and add the entire document.
-       for i in range(doc_index_f + 1, doc_index_l):
-           sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
-       # And finally add the relevant portion of last document.
-       sample_list.append(
-           self.indexed_dataset.get(
-               self.doc_idx[doc_index_l], length=offset_l + 1
-           )
-       )
-       sample = np.concatenate(sample_list)
+   for n, dataset in enumerate(datasets):
+       if doc_index_f == doc_index_l:
+           samples.append(dataset.get(
+               self.doc_idx[doc_index_f],
+               offset=offset_f,
+               length=offset_l - offset_f + 1,
+           ))
+       else:
+           # Otherwise, get the rest of the initial document.
+           sample_list = [
+               dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
+           ]
+           # Loop over all in between documents and add the entire document.
+           for i in range(doc_index_f + 1, doc_index_l):
+               sample_list.append(dataset.get(self.doc_idx[i]))
+           # And finally add the relevant portion of last document.
+           sample_list.append(
+               dataset.get(
+                   self.doc_idx[doc_index_l], length=offset_l + 1
+               )
+           )
+           samples.append(np.concatenate(sample_list))

-   return {"text": np.array(sample, dtype=np.int64)}
+   if len(datasets) == 1:
+       return {"text": np.array(samples[0], dtype=np.int64)}
+   else:
+       return {"text": np.array(samples[0], dtype=np.int64), "label": np.array(samples[1], dtype=np.int64)}
except IndexError:
    new_idx = idx % len(self)
    print(
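The visible effect of this change is on the dictionary `__getitem__` returns: with no label dataset the sample still carries only `"text"`; with one, it carries a parallel `"label"` array of identical length, read with the same document indices and offsets. A rough sketch of that contract using plain lists in place of the real indexed datasets (the helper below is hypothetical, not project code):

import numpy as np

def pack_sample(text_chunk, label_chunk=None):
    # Hypothetical stand-in for the return value of GPT2Dataset.__getitem__.
    sample = {"text": np.asarray(text_chunk, dtype=np.int64)}
    if label_chunk is not None:
        label = np.asarray(label_chunk, dtype=np.int64)
        # Both arrays come from the same doc indices/offsets, so they must match.
        assert label.shape == sample["text"].shape
        sample["label"] = label
    return sample

print(pack_sample([5, 7, 11]))                 # {'text': array([ 5,  7, 11])}
print(pack_sample([5, 7, 11], [-100, 7, 11]))  # adds a parallel 'label' array
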
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -719,6 +719,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
List of paths to train datasets.
"""

label_data_paths: list = None
"""
List of paths to label datasets (should be fully in sync with train data, not shifted by 1!).
"""

test_data_paths: list = None
"""
List of paths to test datasets.
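For reference, a hedged sketch of how the new argument lines up with the existing data-path arguments when set together (the paths are placeholders; in practice these keys live in the YAML config that populates NeoXArgsTraining, one label prefix per train prefix and in the same order):

# Placeholder values only; the key names match the NeoXArgsTraining fields above.
training_data_config = {
    "train_data_paths": ["data/mytask_text_document"],
    "label_data_paths": ["data/mytask_label_document"],  # parallel to train_data_paths
    "valid_data_paths": ["data/myvalid_text_document"],
    "test_data_paths": ["data/mytest_text_document"],
}
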
20 changes: 15 additions & 5 deletions megatron/training.py
@@ -273,11 +273,19 @@ def pretrain(neox_args):

def _get_batch(neox_args, tokenizer, keys, data, datatype):
"""Support function for get_batch / get_batch pipe (to avoid code repetition)"""
-data_b = mpu.broadcast_data(keys, data, datatype)
+_keys = [k for k in keys if k in data] if data else keys
+data_b = mpu.broadcast_data(_keys, data, datatype)

# Unpack.
tokens_ = data_b["text"].long()
-labels = tokens_[:, 1:].contiguous()
+if "label" in data_b:
+    labels = torch.where(
+        data_b["label"].long() >= 0,
+        data_b["label"].long(),
+        torch.zeros_like(data_b["label"].long()),
+    )[:, 1:].contiguous()
+else:
+    labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

# Get the masks and position ids.
@@ -286,15 +294,17 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)

# If `label` is present, any token < 0 (e.g., -100, PyTorch's default ignore_index) skips the loss computation
if "label" in data_b:
loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype)
return tokens, labels, loss_mask, attention_mask, position_ids


def get_batch(neox_args, data_iterator):
"""Generate a batch"""

# Items and their type.
keys = ["text"]
keys = ["text", "label"]
datatype = torch.int64

# Broadcast data.
@@ -314,7 +324,7 @@ def get_batch_pipe(data, neox_args, curr_scheduler=None):
def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
# Items and their type.
keys = ["text"]
keys = ["text", "label"]
datatype = torch.int64

tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
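To see what the new branch in `_get_batch` does with a label tensor, here is a self-contained sketch (tensor values are made up; only the clamping, shift, and mask logic mirrors the diff). Negative label entries are clamped to 0 so they remain valid token ids for the loss, and the same shifted positions are zeroed in `loss_mask`, so they contribute nothing to the loss:

import torch

# A batch of two sequences (illustrative values).
text = torch.tensor([[12, 887, 5, 3041, 2],
                     [7,  21,  9,   35, 2]], dtype=torch.int64)
label = torch.tensor([[-100, -100, 5, 3041, 2],
                      [-100,   21, 9,   35, 2]], dtype=torch.int64)

# Mirror of the diff: clamp negatives to 0, then shift by one for next-token targets.
labels = torch.where(label >= 0, label, torch.zeros_like(label))[:, 1:].contiguous()
tokens = text[:, :-1].contiguous()

# Zero the loss mask wherever the original (shifted) label was negative.
loss_mask = (label[:, 1:] >= 0).float()

print(tokens)
print(labels)
print(loss_mask)  # 0.0 at masked positions, 1.0 elsewhere
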
1 change: 1 addition & 0 deletions megatron/utils.py
@@ -422,6 +422,7 @@ def setup_for_inference_or_eval(
"checkpoint_activations": False,
"partition_activations": False,
"no_load_optim": True,
"optimizer": None, # prevent loading optimizer (no_load_optim alone won't work)
"zero_optimization": None, # disable zero optimization (won't be used in inference, and loading zero optimizer can cause errors)
}
if overwrite_values: