Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add s3 checkpoint syncing #1010

Merged
merged 10 commits into from
Sep 23, 2023
Next Next commit
add s3 checkpoint syncing
  • Loading branch information
haileyschoelkopf committed Aug 9, 2023
commit 757320b154767bd038de2d97658950034344081d
122 changes: 120 additions & 2 deletions megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

import boto3
Quentin-Anthony marked this conversation as resolved.
Show resolved Hide resolved
import hf_transfer
Quentin-Anthony marked this conversation as resolved.
Show resolved Hide resolved
import torch
from glob import glob

Expand Down Expand Up @@ -137,6 +141,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None
)


def get_checkpoint_tag(iteration: int) -> str:
return f"global_step{iteration}"


def delete_old_checkpoints(save_dir, n_to_keep):
if torch.distributed.get_rank() == 0:
ckpt_dir_regex = r"global_step[\d]*"
Expand Down Expand Up @@ -188,7 +196,7 @@ def save_ds_checkpoint(iteration, model, neox_args):
sd["checkpoint_validation_logits"] = logits

# checkpoint folder name
tag = f"global_step{iteration}"
tag = get_checkpoint_tag(iteration)

# save checkpoint
model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)
Expand All @@ -203,6 +211,111 @@ def save_ds_checkpoint(iteration, model, neox_args):
f.write(config_data)
else:
json.dump(config_data, f)
def multiprocessing_starmap(func, args, num_processes=None):
"""Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
Args:
func (Callable): Function to call
args (Iterable): Iterable of arguments to pass to `func`
num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
"""
import multiprocessing
num_processes = num_processes or (multiprocessing.cpu_count() - 1)
with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool:
process_pool.starmap(func, args)
process_pool.terminate()
process_pool.join()
del process_pool


def _upload(
file_path: str,
s3_key: str,
chunk_size: int = 104_857_600,
max_files: int = 64,
parallel_failures: int = 63,
max_retries: int = 5,
):
"""Upload local file to S3 using `hf_transfer` library
Args:
file_path (str): Local filename to upload
s3_key (str): S3 key to upload to. E.g. `s3:https://bucket-name/path/to/file`
chunk_size (int, optional): Chunk size to use for multipart upload.
Defaults to 100MiB = 104_857_600
max_files (int, optional): Number of open file handles, which determines
the maximum number of parallel downloads. Defaults to 64
parallel_failures (int, optional): Number of maximum failures of different
chunks in parallel (cannot exceed max_files). Defaults to 63
max_retries (int, optional): Number of retries for each chunk. Defaults to 5
"""
s3 = boto3.client('s3')
bucket = s3_key.split("s3:https://")[1].split("/")[0]
key = s3_key.split(bucket)[1].lstrip("/")

# 1. Init multipart upload and obtain unique upload identifier
upload = s3.create_multipart_upload(
ACL="bucket-owner-full-control",
Bucket=bucket,
Key=key,
)
upload_id = upload["UploadId"]

# 2. Generate presigned URLs for each part
file_size = os.stat(file_path).st_size
urls = []
nb_parts = math.ceil(file_size / chunk_size)
for part_number in range(1, nb_parts + 1):
params = {
"Bucket": bucket,
"Key": key,
"PartNumber": part_number,
"UploadId": upload_id,
}
urls.append(
s3.generate_presigned_url(
ClientMethod="upload_part", Params=params, ExpiresIn=86400
)
)

# 3. Upload parts in parallel
responses = hf_transfer.multipart_upload(
file_path=file_path,
parts_urls=urls,
chunk_size=chunk_size,
max_files=max_files,
parallel_failures=parallel_failures,
max_retries=max_retries,
)

# 4. Complete multipart upload request with ETag values
etag_with_parts = []
for part_number, header in enumerate(responses):
etag = header.get("etag")
etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
parts = {"Parts": etag_with_parts}
s3.complete_multipart_upload(
Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
)


def upload_checkpoint(iteration, neox_args):
local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration))
local_checkpoint_list = sorted(filter(
lambda x: os.path.isfile(x),
[str(p) for p in Path(local_checkpoint_path).rglob("*")],
))
remote_checkpoint_path = os.path.join(
neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration))
remote_checkpoint_list = [
os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path))
for local_checkpoint in local_checkpoint_list
]
inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list))

print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...")
start = time.time()
multiprocessing_starmap(_upload, inputs)
total_time = time.time() - start
print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s")


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
Expand All @@ -213,6 +326,11 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
else:
raise ValueError("Must be using deepspeed to use neox")

torch.distributed.barrier()
upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
if upload_to_s3:
upload_checkpoint(iteration, neox_args)

# Wait so everyone is done (necessary)
torch.distributed.barrier()
if neox_args.keep_last_n_checkpoints is not None:
Expand All @@ -233,7 +351,7 @@ def load_checkpoint(
if neox_args.finetune:
load_optim_and_scheduler = False
if iteration is not None:
tag = f"global_step{iteration}"
tag = get_checkpoint_tag(iteration)
else:
tag = None
checkpoint_name, state_dict = model.load_checkpoint(
Expand Down
10 changes: 10 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,16 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Output directory to save checkpoints to.
"""

s3_path: str = None
"""
Path to s3 bucket for saving checkpoints.
"""

s3_chunk_size: int = 104_857_600
"""
The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
"""

config_files: dict = None
"""
Store of original config files mapping config filename to file contents
Expand Down
1 change: 1 addition & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
import numpy as np

from CPCargo import Heartbeat
from megatron.utils import (
Timers,
init_wandb,
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ git+https://github.com/EleutherAI/DeeperSpeed.git#egg=deepspeed
ftfy>=6.0.1
git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
huggingface_hub>=0.11.0
hf-transfer>=0.1.3
lm_eval>=0.3.0
mpi4py>=3.0.3
numpy>=1.22.0
Expand All @@ -13,3 +14,4 @@ six
tiktoken>=0.1.2
tokenizers>=0.12.1
transformers>=4.24.0
git+https://github.com/samikama/CPCargo@main