Skip to content

Commit

Permalink
Automatically infer shard size
Browse files Browse the repository at this point in the history
  • Loading branch information
uSaiPrashanth committed Oct 31, 2023
1 parent b56b413 commit a3b62b5
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions utils/unshard_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def unshard(
output_dir: str,
):
"""Reconstruct a Megatron .bin file from shards"""
SHARD_SIZE = 5_000_000_000

input_dir = os.path.dirname(input_file)
base_filename = os.path.basename(input_file)[:-19] # remove 00000-of-xxxxx.bin suffix from shard 0's filename


# check size of non-final shard
shard_filename = shard_filename = os.path.join(input_dir, base_filename) + f"-00000-of-{(num_shards - 1):05}.bin"
shard_memmap = np.memmap(shard_filename, mode="r", order="C")
SHARD_SIZE = shard_memmap.shape[0]

# check size of final shard
shard_filename = os.path.join(input_dir, base_filename) + f"-{(num_shards - 1):05}-of-{(num_shards - 1):05}.bin"
shard_memmap = np.memmap(shard_filename, mode="r", order="C")
Expand Down

0 comments on commit a3b62b5

Please sign in to comment.