Add MMap dataset batch viewer
uSaiPrashanth committed Jun 27, 2023
1 parent fba86e2 commit 51a8428
Showing 2 changed files with 275 additions and 76 deletions.
89 changes: 13 additions & 76 deletions utils/batch_viewer.py
@@ -1,58 +1,6 @@
import os
import json

from mmap_dataset import MMapIndexedDataset
from tqdm import trange
import argparse
from typing import Literal
from tqdm import tqdm

import numpy as np
import pandas as pd

from megatron.data import data_utils
from megatron.neox_arguments import NeoXArgs

from megatron import mpu


def view_data(
    args,
    neox_args,
    batch_fn: callable = None,
    save_path: str = None,
):
    # fake MPU setup (needed to init dataloader without actual GPUs or parallelism)
    mpu.mock_model_parallel()

    # overrides to config
    neox_args.update_value("train_micro_batch_size_per_gpu", 1024)
    neox_args.update_value("train_batch_size", 1024)
    neox_args.update_value("num_workers", 8)

    # init dataloader
    train_dataloader, _, _ = data_utils.build_train_valid_test_data_iterators(neox_args=neox_args)

    print(f"Starting batch iteration from step {args.start_iteration} until step {args.end_iteration}")
    # iterate over dataloader
    for i in tqdm(range(args.start_iteration, args.end_iteration)):
        batch = next(train_dataloader)["text"].cpu().numpy()

        if args.mode == "save":
            # save full batches for each step in the range (WARNING: this may consume lots of storage!)
            filename = f"batch{i}_bs{neox_args.train_micro_batch_size_per_gpu}"
            np.save(os.path.join(save_path, filename), batch)
        elif args.mode == "custom":
            # dump user_defined statistic to a jsonl file (save_fn must return a dict)
            log = batch_fn(batch, i)

            filename = "stats.jsonl"
            with open(os.path.join(save_path, filename), "w+") as f:
                f.write(json.dumps(log) + "\n")
        else:
            raise ValueError(f'mode={args.mode} not acceptable--please pass either "save" or "custom" !')

        del batch
    print(f"Finished iteration from step {args.start_iteration} to {args.end_iteration}")


if __name__ == '__main__':

@@ -72,34 +20,23 @@ def view_data(
help="Train step to end logging (inclusive)"
)
parser.add_argument(
"--mode",
type=str,
choices=["save", "custom"],
help="Choose mode: 'save' to log all batches, and 'custom' to use user-defined statistic"
"load_path",
type = str,
default = '/mnt/ssd-1/pile_preshuffled/standard/document',
help = ("MMap dataset path with .bin and .idx files. Omit the .bin (or) .idx "
"Extension while specifying the path")
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=0,
        default="token_indicies",
        help="Save path for files"
    )
    args = parser.parse_known_args()[0]

    # init neox args
    neox_args = NeoXArgs.consume_deepy_args()
    # set start iter for dataloader
    neox_args.update_value("iteration", args.start_iteration)


    def save_fn(batch: np.array, iteration: int):
        # define your own logic here
        return {"iteration": iteration, "text": None}

    os.makedirs(args.save_path, exist_ok=True)
    filename = os.path.join(args.save_path, "indicies.npy")

    dataset = MMapIndexedDataset(args.load_path, skip_warmup = True)
    indicies = dataset[args.start_iteration: args.end_iteration + 1]
    np.save(filename, indicies)

    view_data(
        args,
        neox_args,
        batch_fn=save_fn,
        save_path=args.save_path,
    )
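
Judging from the added lines above, the rewritten __main__ path no longer goes through the NeoX dataloader: it opens the memory-mapped dataset directly, slices dataset[args.start_iteration : args.end_iteration + 1], and writes the result with np.save. A minimal sketch of reading that dump back, assuming the defaults shown above (save_path of "token_indicies", 2049-token sequences); the paths are illustrative:

import numpy as np

# batch_viewer.py saves the slice dataset[start_iteration : end_iteration + 1]
# to <save_path>/indicies.npy as a 2-D array of token ids.
indicies = np.load("token_indicies/indicies.npy")
print(indicies.shape)  # expected: (end_iteration - start_iteration + 1, 2049)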
262 changes: 262 additions & 0 deletions utils/mmap_dataset.py
@@ -0,0 +1,262 @@
# Copyright (c) 2021, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# Copied from gpt-neox/megatron/data/indexed_dataset.py
# Adapted to include only the MMapIndexedDataset reader,
# with other slight modifications.
# *************
# **IMPORTANT**
# *************
# This implementation assumes that the sequences in
# the dataset are always of sequence length 2049.

import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch

dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}

def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"

class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Write magic string so we can check the file format when opening it again.
                    self._file.write(cls._HDR_MAGIC)
                    # Write version number
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", 1))
                    # Little endian unsigned 8 Bit integer
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    pointers = np.zeros(len(sizes), dtype=np.int64)
                    sizes = np.array(sizes, dtype=np.int64)

                    np.cumsum(sizes[:-1], out=pointers[1:])
                    pointers = pointers * dtype().itemsize
                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(sizes)))
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print(" reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print(" reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            return np_array.reshape(-1, 2049)

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.
        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
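
Note that the reader above references two helpers that are not part of this diff: code(dtype), used by the Index writer, and _warmup_mmap_file(path), called when skip_warmup is False. Both come from the gpt-neox/fairseq indexed_dataset.py this file was adapted from; a minimal sketch consistent with that upstream code (an assumption, not part of this commit) looks like:

def code(dtype):
    # Reverse lookup into the dtypes table: numpy dtype -> on-disk dtype code.
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def _warmup_mmap_file(path):
    # Read the file sequentially in large chunks so the OS page cache is warm
    # before the memory map is used.
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass

Since batch_viewer.py constructs MMapIndexedDataset with skip_warmup=True and never uses the writer, the missing helpers do not affect the viewer path; they only matter if the warmup or index-writing paths are exercised.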
