
Commit
Add blended datasets
sdtblck committed May 11, 2021
1 parent cf37ce3 commit 33b887e
Showing 12 changed files with 525 additions and 289 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -132,7 +132,13 @@ dmypy.json
wandb/

# data files
data/
data/**/*.idx
data/**/*.bin
data/**/*.json*
data/**/*.txt
data/**/*.gz
data/**/*.np*
data/**/*.npy
checkpoints/
.vscode/
*.pt
11 changes: 8 additions & 3 deletions configs/eleutherai_cluster.yml
@@ -1,9 +1,14 @@
# Data paths and options when using EleutherAI cluster
{
  "data-path": "/mnt/ssd-cluster/data/enron/enron_text_document",
  # "train-data-path": "/mnt/ssd-cluster/data/train/train_text_document",
  # "test-data-path": "/mnt/ssd-cluster/data/test/test_text_document",
  # "valid-data-path": "/mnt/ssd-cluster/data/valid/valid_text_document",
  # or for weighted datasets:
  # "train-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
  # "test-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
  # "valid-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
  # "train-data-weights": [1., 2.],
  # "test-data-weights": [2., 1.],
  # "valid-data-weights": [0.5, 0.4],

  "vocab-file": "/mnt/ssd-cluster/data/gpt2-vocab.json",
  "merge-file": "/mnt/ssd-cluster/data/gpt2-merges.txt",
  "save": "/mnt/ssd-cluster/checkpoints",
12 changes: 9 additions & 3 deletions configs/local_setup.yml
@@ -1,9 +1,15 @@
# Suggested data paths when using GPT-NeoX locally
{
  "data-path": "data/enron/enron_text_document",
  # "train-data-path": "data/train/train_text_document",
  # "test-data-path": "data/test/test_text_document",
  # "valid-data-path": "data/valid/valid_text_document",

  # or for weighted datasets:
  # "train-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
  # "test-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
  # "valid-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
  # "train-data-weights": [1., 2.],
  # "test-data-weights": [2., 1.],
  # "valid-data-weights": [0.5, 0.4],

  "vocab-file": "data/gpt2-vocab.json",
  "merge-file": "data/gpt2-merges.txt",
  "save": "checkpoints",
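Note: the commented-out weighted options above pair each `*-data-paths` list with a `*-data-weights` list of relative weights. Assuming those weights feed the `BlendableDataset` added in this commit (which normalizes them before use), only their ratios matter. A minimal sketch of that normalization, using the example values from the config comments:

```python
import numpy as np

# Relative weights as they might appear in "train-data-weights": [1., 2.]
weights = np.array([1.0, 2.0], dtype=np.float64)
assert weights.sum() > 0.0
weights /= weights.sum()  # normalized sampling proportions

print(weights)  # [0.333... 0.666...] -> roughly 1/3 and 2/3 of samples
```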
84 changes: 84 additions & 0 deletions megatron/data/blendable_dataset.py
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu


class BlendableDataset(torch.utils.data.Dataset):
    """Dataset that draws samples from several underlying datasets in proportion to the given weights."""

    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255  # dataset_index is stored as uint8
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        if torch.distributed.get_rank() == 0:
            from megatron.data.gpt2_dataset import compile_helper
            compile_helper()

        # Simple barrier
        tmp = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(tmp)

        # The barrier above does not always guarantee that the freshly compiled
        # helpers are importable on every rank, so retry once after a short
        # delay on the first build.
        try:
            from megatron.data import helpers
        except ImportError:
            time.sleep(5)
            from megatron.data import helpers

        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
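The heavy lifting above is done by the compiled `helpers.build_blending_indices` routine. As a rough sketch of what that index construction conceptually does (assuming the greedy scheme used in Megatron-LM, where each position is assigned to the dataset whose realized share lags furthest behind its target weight; the compiled helper, not this Python version, is what actually runs):

```python
import numpy as np

def build_blending_indices_py(weights, size):
    """Pure-Python approximation of the compiled index builder (illustration only)."""
    num_datasets = len(weights)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(num_datasets, dtype=np.int64)
    for i in range(size):
        step = max(i, 1)
        # How far each dataset's sample count lags behind its target share so far.
        errors = weights * step - current_samples
        d = int(np.argmax(errors))
        dataset_index[i] = d                          # which dataset to draw from
        dataset_sample_index[i] = current_samples[d]  # position within that dataset
        current_samples[d] += 1
    return dataset_index, dataset_sample_index

# Toy run: two datasets blended with relative weights 1:2 over 9 samples.
w = np.array([1.0, 2.0], dtype=np.float64)
w /= w.sum()
idx, sample_idx = build_blending_indices_py(w, 9)
print(idx)         # dataset 0 chosen for roughly one third of the positions
print(sample_idx)  # running index into the chosen dataset
```

With those two arrays built, `__getitem__` only has to look up `dataset_index[idx]` and `dataset_sample_index[idx]` to delegate to the right underlying dataset.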
