diff --git a/README.md b/README.md index bd4af5c0b..c211f4fc3 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, logging via [WandB](https://wandb.ai/site), and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). ## News +**[8/10/2023]** We now support checkpointing with AWS S3! Activate with the `s3_path` config option (for more detail, see [the PR](https://github.com/EleutherAI/gpt-neox/pull/1010)) + **[9/20/2023]** As of https://github.com/EleutherAI/gpt-neox/pull/1035, we have deprecated Flash Attention 0.x and 1.x, and migrated support to Flash Attention 2.x. We don't believe this will cause problems, but if you have a specific use-case that requires old flash support using the latest GPT-NeoX, please raise an issue. **[8/10/2023]** We have experimental support for LLaMA 2 and Flash Attention v2 supported in our [math-lm](https://github.com/EleutherAI/math-lm) project that will be upstreamed later this month. diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index cc7717f66..f0c146ca6 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 1d20559 + Default = aa6c176 current git hash of repository @@ -1169,6 +1169,22 @@ Training Arguments +- **s3_path**: str + + Default = None + + Path to s3 bucket for saving checkpoints. + + + +- **s3_chunk_size**: int + + Default = 104857600 + + The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB. + + + - **config_files**: dict Default = None diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 8ad2a88c7..8bcc01f3b 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -18,13 +18,23 @@ """Input/output checkpointing.""" import json +import math import os import re import shutil +import time import random import sys import numpy as np +try: + import boto3 +except ModuleNotFoundError: + print("For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3") +try: + import hf_transfer +except ModuleNotFoundError: + print("For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer") import torch from glob import glob @@ -137,6 +147,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None ) +def get_checkpoint_tag(iteration: int) -> str: + return f"global_step{iteration}" + + def delete_old_checkpoints(save_dir, n_to_keep): if torch.distributed.get_rank() == 0: ckpt_dir_regex = r"global_step[\d]*" @@ -188,7 +202,7 @@ def save_ds_checkpoint(iteration, model, neox_args): sd["checkpoint_validation_logits"] = logits # checkpoint folder name - tag = f"global_step{iteration}" + tag = get_checkpoint_tag(iteration) # save checkpoint model.save_checkpoint(neox_args.save, tag=tag, client_state=sd) @@ -203,6 +217,111 @@ def save_ds_checkpoint(iteration, model, neox_args): f.write(config_data) else: json.dump(config_data, f) +def multiprocessing_starmap(func, args, num_processes=None): + """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling + Args: + func (Callable): Function to call + args (Iterable): Iterable of arguments to pass to `func` + num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1` + """ + import multiprocessing + num_processes = num_processes or (multiprocessing.cpu_count() - 1) + with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool: + process_pool.starmap(func, args) + process_pool.terminate() + process_pool.join() + del process_pool + + +def _upload( + file_path: str, + s3_key: str, + chunk_size: int = 104_857_600, + max_files: int = 64, + parallel_failures: int = 63, + max_retries: int = 5, +): + """Upload local file to S3 using `hf_transfer` library + Args: + file_path (str): Local filename to upload + s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file` + chunk_size (int, optional): Chunk size to use for multipart upload. + Defaults to 100MiB = 104_857_600 + max_files (int, optional): Number of open file handles, which determines + the maximum number of parallel downloads. Defaults to 64 + parallel_failures (int, optional): Number of maximum failures of different + chunks in parallel (cannot exceed max_files). Defaults to 63 + max_retries (int, optional): Number of retries for each chunk. Defaults to 5 + """ + s3 = boto3.client('s3') + bucket = s3_key.split("s3://")[1].split("/")[0] + key = s3_key.split(bucket)[1].lstrip("/") + + # 1. Init multipart upload and obtain unique upload identifier + upload = s3.create_multipart_upload( + ACL="bucket-owner-full-control", + Bucket=bucket, + Key=key, + ) + upload_id = upload["UploadId"] + + # 2. Generate presigned URLs for each part + file_size = os.stat(file_path).st_size + urls = [] + nb_parts = math.ceil(file_size / chunk_size) + for part_number in range(1, nb_parts + 1): + params = { + "Bucket": bucket, + "Key": key, + "PartNumber": part_number, + "UploadId": upload_id, + } + urls.append( + s3.generate_presigned_url( + ClientMethod="upload_part", Params=params, ExpiresIn=86400 + ) + ) + + # 3. Upload parts in parallel + responses = hf_transfer.multipart_upload( + file_path=file_path, + parts_urls=urls, + chunk_size=chunk_size, + max_files=max_files, + parallel_failures=parallel_failures, + max_retries=max_retries, + ) + + # 4. Complete multipart upload request with ETag values + etag_with_parts = [] + for part_number, header in enumerate(responses): + etag = header.get("etag") + etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1}) + parts = {"Parts": etag_with_parts} + s3.complete_multipart_upload( + Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id + ) + + +def upload_checkpoint(iteration, neox_args): + local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)) + local_checkpoint_list = sorted(filter( + lambda x: os.path.isfile(x), + [str(p) for p in Path(local_checkpoint_path).rglob("*")], + )) + remote_checkpoint_path = os.path.join( + neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration)) + remote_checkpoint_list = [ + os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path)) + for local_checkpoint in local_checkpoint_list + ] + inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list)) + + print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...") + start = time.time() + multiprocessing_starmap(_upload, inputs) + total_time = time.time() - start + print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s") def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler): @@ -213,6 +332,11 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler): else: raise ValueError("Must be using deepspeed to use neox") + torch.distributed.barrier() + upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None + if upload_to_s3: + upload_checkpoint(iteration, neox_args) + # Wait so everyone is done (necessary) torch.distributed.barrier() if neox_args.keep_last_n_checkpoints is not None: @@ -233,7 +357,7 @@ def load_checkpoint( if neox_args.finetune: load_optim_and_scheduler = False if iteration is not None: - tag = f"global_step{iteration}" + tag = get_checkpoint_tag(iteration) else: tag = None checkpoint_name, state_dict = model.load_checkpoint( diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 6f507ffe3..e427b2551 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -793,6 +793,16 @@ class NeoXArgsTraining(NeoXArgsTemplate): Output directory to save checkpoints to. """ + s3_path: str = None + """ + Path to s3 bucket for saving checkpoints. + """ + + s3_chunk_size: int = 104_857_600 + """ + The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB. + """ + config_files: dict = None """ Store of original config files mapping config filename to file contents diff --git a/requirements/requirements-s3.txt b/requirements/requirements-s3.txt new file mode 100644 index 000000000..7a2924ccd --- /dev/null +++ b/requirements/requirements-s3.txt @@ -0,0 +1,2 @@ +hf-transfer>=0.1.3 +boto3 \ No newline at end of file