Skip to content

Commit

Permalink
Merge pull request #45 from EleutherAI/shard_memmap
Browse files Browse the repository at this point in the history
Add Megatron `.bin` file sharding and unsharding scripts
  • Loading branch information
haileyschoelkopf committed Jan 4, 2023
2 parents d0c0e23 + 187d565 commit df1ff85
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
61 changes: 61 additions & 0 deletions utils/shard_memmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import argparse

import numpy as np
from tqdm import tqdm


def shard(
input_file: str,
output_dir: str,
):
"""Shard a Megatron .bin file into ~ 4.5 GB chunks"""
SHARD_SIZE = 5_000_000_000 # bytes ~= 4.5 GB

# load in memmapped .bin file
full_idx_map = np.memmap(input_file, mode="r", order="C")

# get number of chunks (rounds down bc start counting from shard number 0)
num_chunks = full_idx_map.shape[0] // SHARD_SIZE

# chunk by iterating over file
for i in tqdm(range(num_chunks + 1)): # while still have file contents remaining to chunk:

start_idx = i * SHARD_SIZE
end_idx = (i + 1) * SHARD_SIZE

if end_idx > full_idx_map.shape[0]:
chunk = full_idx_map[start_idx:]
else:
chunk = full_idx_map[start_idx:end_idx]

shard_filename = os.path.join(output_dir, os.path.basename(input_file)[:-4]) + f"-{i:05}-of-{num_chunks:05}.bin"
with open(shard_filename, "wb+") as out_shard_file:
print(f"Dumping shard {i:05} to {shard_filename} ...")
chunk.tofile(out_shard_file)

del chunk


if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Shard a single Megatron data .bin file"
)

## CLI args
parser.add_argument(
"--input_file",
type=str,
help="Path to .bin file e.g. /path/to/pile_0.87_deduped_text_document.bin",
)
parser.add_argument(
"--output_dir",
type=str,
help="Folder to save shards into",
)
args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)

shard(args.input_file, args.output_dir)
69 changes: 69 additions & 0 deletions utils/unshard_memmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import argparse

import numpy as np
from tqdm import tqdm


def unshard(
input_file: str,
num_shards: int,
output_dir: str,
):
"""Reconstruct a Megatron .bin file from shards"""
SHARD_SIZE = 5_000_000_000

input_dir = os.path.dirname(input_file)
base_filename = os.path.basename(input_file)[:-19] # remove 00000-of-xxxxx.bin suffix from shard 0's filename

# check size of final shard
shard_filename = os.path.join(input_dir, base_filename) + f"-{(num_shards - 1):05}-of-{(num_shards - 1):05}.bin"
shard_memmap = np.memmap(shard_filename, mode="r", order="C")
final_shard_size = shard_memmap.shape[0]
del shard_memmap

# create full .bin file of proper size
open(os.path.join(output_dir, base_filename) + ".bin", "w+").close()
full_idx_map = np.memmap(os.path.join(output_dir, base_filename) + ".bin", shape=(SHARD_SIZE * (num_shards - 1) + final_shard_size), mode="w+", order="C")
print(full_idx_map.shape)

# chunk by iterating over file
print(f"Loading {num_shards} shards from {input_dir}")
for i in tqdm(range(num_shards)):

shard_filename = os.path.join(input_dir, base_filename) + f"-{i:05}-of-{(num_shards - 1):05}.bin"
print(shard_filename)
shard_memmap = np.memmap(shard_filename, mode="r", order="C")

size = SHARD_SIZE if not (i == num_shards - 1) else final_shard_size
full_idx_map[i * SHARD_SIZE: (i * SHARD_SIZE) + size] = shard_memmap

del shard_memmap

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Shard a single Megatron data .bin file"
)

## CLI args
parser.add_argument(
"--input_file",
type=str,
help="Path to shard 0",
)
parser.add_argument(
"--num_shards",
type=int,
help="Provide number of shards (The total seen in shard filenames + 1)"
)
parser.add_argument(
"--output_dir",
type=str,
help="Folder to save .bin file into",
)
args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)

unshard(args.input_file, args.num_shards, args.output_dir)

0 comments on commit df1ff85

Please sign in to comment.