Commit

Fix repo for CI (#1106)
* Fix syntax errors

* Make pre-commit fixes across repo

* Ensure correct version of clang-format in CI

---------

Co-authored-by: Yang Zhang <[email protected]>
yang and yang committed Jan 4, 2024
1 parent e5a7ea7 commit eca6b1a
Showing 41 changed files with 294 additions and 121 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pull_request.yml
@@ -12,6 +12,8 @@ jobs:
python-version: 3.8
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
# Need the right version of clang-format
- run: pip install -r requirements/requirements-dev.txt
- uses: pre-commit/[email protected]

update-documentation:
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -36,3 +36,4 @@ repos:
--check-filenames,
--check-hidden,
]
exclude: tests/data/hf_cache/tokenizer/gpt2.json
8 changes: 4 additions & 4 deletions configs/finetuning_configs/6-9B.yml
@@ -17,9 +17,9 @@
"no-weight-tying": true,
"gpt_j_residual": true,
"output_layer_parallelism": "column",

"attention-config": [[["flash"], 32]],

"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": true,

@@ -32,7 +32,7 @@
"eps": 1.0e-8
}
},

"min_lr": 0.000012,

"zero_optimization": {
@@ -43,7 +43,7 @@
"reduce_scatter": true,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": true,
"cpu_offload": false
"cpu_offload": false,
"load_from_fp32_weights": False, # if checkpoint has fp16/bf16 params
},

7 changes: 3 additions & 4 deletions configs/neox_arguments.md
@@ -960,7 +960,7 @@ Text Generation arguments

- **prompt_end**: str

Default =
Default =


a single prompt's end. Defaults to newline
@@ -1002,7 +1002,7 @@ Text Generation arguments

- **eval_results_prefix**: str

Default =
Default =

prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json

@@ -1752,7 +1752,7 @@ Args for deepspeed config

Default = None





@@ -2044,4 +2044,3 @@ Args for deepspeed runner (deepspeed.launcher.runner).
Default = None

Adds a `--comment` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard.

4 changes: 3 additions & 1 deletion eval.py
@@ -32,7 +32,9 @@


def main(input_args=None, overwrite_values=None):
model, neox_args = setup_for_inference_or_eval(use_cache=False, input_args=input_args, overwrite_values=overwrite_values)
model, neox_args = setup_for_inference_or_eval(
use_cache=False, input_args=input_args, overwrite_values=overwrite_values
)
results = run_eval_harness(
model,
forward_step,
17 changes: 13 additions & 4 deletions eval_tasks/eval_adapter.py
@@ -146,7 +146,9 @@ def _collate(x):
return (len(toks), x[0])

reord = utils.Reorderer(reqs, _collate)
for context, gen_kwargs in tqdm(reord.get_reordered(), "Running greedy generation"):
for context, gen_kwargs in tqdm(
reord.get_reordered(), "Running greedy generation"
):
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
@@ -406,7 +408,7 @@ def run_eval(
"winogrande",
"mathqa",
"pubmedqa",
"triviaqa"
"triviaqa",
]

# register all the default tasks bundled with lm-evaluation-harness repository
@@ -442,7 +444,14 @@ def pattern_match(patterns, source_list):
lm = self

if use_cache:
use_cache = 'lm_cache/neox' + '_dp_rank' + str(self._dp_rank) + '_dp_group' + str(self._dp_group) + '.db'
use_cache = (
"lm_cache/neox"
+ "_dp_rank"
+ str(self._dp_rank)
+ "_dp_group"
+ str(self._dp_group)
+ ".db"
)
print(f"Using cache at {use_cache}...")
lm = lm_eval.api.model.CachingLM(
lm,
@@ -481,7 +490,7 @@ def pattern_match(patterns, source_list):
results = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
limit=10, #limit,
limit=10, # limit,
bootstrap_iters=bootstrap_iters,
log_samples=False,
)
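Note on the cache setup reformatted above: each data-parallel rank gets its own SQLite cache file, presumably so ranks never write to the same database concurrently. A minimal sketch of just the path construction, with placeholder rank and group values (everything else in the hunk is unchanged):

```python
# Sketch of the per-data-parallel-rank cache path built in run_eval above.
# The rank and group values here are placeholders, not from a real run.
_dp_rank, _dp_group = 0, 1
use_cache = (
    "lm_cache/neox"
    + "_dp_rank"
    + str(_dp_rank)
    + "_dp_group"
    + str(_dp_group)
    + ".db"
)
print(f"Using cache at {use_cache}...")
# -> Using cache at lm_cache/neox_dp_rank0_dp_group1.db...
```
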
4 changes: 3 additions & 1 deletion generate.py
@@ -30,7 +30,9 @@ def main(input_args=None, overwrite_values=None):
"""
Generate text/sample model
"""
model, neox_args = setup_for_inference_or_eval(use_cache=True,input_args=input_args, overwrite_values=overwrite_values)
model, neox_args = setup_for_inference_or_eval(
use_cache=True, input_args=input_args, overwrite_values=overwrite_values
)
if neox_args.recompute:
model.module.inference_mode(
use_cache=False
59 changes: 44 additions & 15 deletions megatron/checkpointing.py
@@ -30,11 +30,15 @@
try:
import boto3
except ModuleNotFoundError:
print("For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3")
print(
"For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3"
)
try:
import hf_transfer
except ModuleNotFoundError:
print("For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer")
print(
"For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer"
)
import torch
from glob import glob

@@ -217,6 +221,8 @@ def save_ds_checkpoint(iteration, model, neox_args):
f.write(config_data)
else:
json.dump(config_data, f)


def multiprocessing_starmap(func, args, num_processes=None):
"""Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
Args:
@@ -225,8 +231,11 @@ def multiprocessing_starmap(func, args, num_processes=None):
num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
"""
import multiprocessing

num_processes = num_processes or (multiprocessing.cpu_count() - 1)
with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool:
with multiprocessing.get_context("spawn").Pool(
processes=num_processes
) as process_pool:
process_pool.starmap(func, args)
process_pool.terminate()
process_pool.join()
@@ -253,7 +262,7 @@ def _upload(
chunks in parallel (cannot exceed max_files). Defaults to 63
max_retries (int, optional): Number of retries for each chunk. Defaults to 5
"""
s3 = boto3.client('s3')
s3 = boto3.client("s3")
bucket = s3_key.split("s3:https://")[1].split("/")[0]
key = s3_key.split(bucket)[1].lstrip("/")

@@ -304,24 +313,42 @@ def _upload(


def upload_checkpoint(iteration, neox_args):
local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration))
local_checkpoint_list = sorted(filter(
lambda x: os.path.isfile(x),
[str(p) for p in Path(local_checkpoint_path).rglob("*")],
))
local_checkpoint_path = os.path.join(
os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)
)
local_checkpoint_list = sorted(
filter(
lambda x: os.path.isfile(x),
[str(p) for p in Path(local_checkpoint_path).rglob("*")],
)
)
remote_checkpoint_path = os.path.join(
neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration))
neox_args.s3_path,
os.path.basename(neox_args.save),
get_checkpoint_tag(iteration),
)
remote_checkpoint_list = [
os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path))
os.path.join(
remote_checkpoint_path,
os.path.relpath(local_checkpoint, local_checkpoint_path),
)
for local_checkpoint in local_checkpoint_list
]
inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list))
inputs = zip(
local_checkpoint_list,
remote_checkpoint_list,
[neox_args.s3_chunk_size] * len(local_checkpoint_list),
)

print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...")
print_rank_0(
f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`..."
)
start = time.time()
multiprocessing_starmap(_upload, inputs)
total_time = time.time() - start
print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s")
print_rank_0(
f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s"
)


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
@@ -395,7 +422,9 @@ def load_checkpoint(
if "iteration" in state_dict:
iteration = state_dict["iteration"]
else:
iteration = state_dict.get("total_iters") # total_iters backward compatible with older checkpoints
iteration = state_dict.get(
"total_iters"
) # total_iters backward compatible with older checkpoints
if iteration is None:
raise ValueError(
f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
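For orientation, the reformatted upload_checkpoint above walks the local checkpoint directory, mirrors each file's relative path under the configured S3 prefix, and fans the uploads out through multiprocessing_starmap. A minimal sketch of the pairing step, assuming an illustrative bucket and checkpoint layout (this is not the repository's exact code):

```python
# Hedged sketch of how local checkpoint files are paired with S3 keys before
# being uploaded in parallel. The bucket name and file names are assumptions.
import os
from pathlib import Path


def build_upload_pairs(local_checkpoint_path, remote_checkpoint_path):
    local_files = sorted(
        str(p) for p in Path(local_checkpoint_path).rglob("*") if p.is_file()
    )
    remote_keys = [
        os.path.join(remote_checkpoint_path, os.path.relpath(f, local_checkpoint_path))
        for f in local_files
    ]
    return list(zip(local_files, remote_keys))


# e.g. checkpoints/global_step1000/mp_rank_00_model_states.pt
#  ->  s3:https://my-bucket/checkpoints/global_step1000/mp_rank_00_model_states.pt
pairs = build_upload_pairs(
    "checkpoints/global_step1000", "s3:https://my-bucket/checkpoints/global_step1000"
)
```
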
27 changes: 17 additions & 10 deletions megatron/data/gpt2_dataset.py
@@ -82,16 +82,22 @@ def __getitem__(self, idx):
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# Labels and texts are supposed to be fully in sync.
datasets = [self.indexed_dataset] if self.label_dataset is None else [self.indexed_dataset, self.label_dataset]
datasets = (
[self.indexed_dataset]
if self.label_dataset is None
else [self.indexed_dataset, self.label_dataset]
)
samples = []
# If we are within the same document, just extract the chunk.
for n, dataset in enumerate(datasets):
if doc_index_f == doc_index_l:
samples.append(dataset.get(
self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1,
))
samples.append(
dataset.get(
self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1,
)
)
else:
# Otherwise, get the rest of the initial document.
sample_list = [
@@ -102,16 +108,17 @@
sample_list.append(dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(
dataset.get(
self.doc_idx[doc_index_l], length=offset_l + 1
)
dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
)
samples.append(np.concatenate(sample_list))

if len(datasets) == 1:
return {"text": np.array(samples[0], dtype=np.int64)}
else:
return {"text": np.array(samples[0], dtype=np.int64), "label": np.array(samples[1], dtype=np.int64)}
return {
"text": np.array(samples[0], dtype=np.int64),
"label": np.array(samples[1], dtype=np.int64),
}
except IndexError:
new_idx = idx % len(self)
print(
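The reformatted __getitem__ above assembles a training sample that can span several documents: the tail of the first document, any full documents in between, and the head of the last one, concatenated together (when a label dataset is present, the same slicing is applied to it in lockstep). A toy illustration with made-up token ids:

```python
# Toy example (assumed data, not the real indexed dataset) of stitching together
# a sample that spans multiple documents, mirroring the slicing in __getitem__.
import numpy as np

docs = [np.arange(0, 5), np.arange(10, 15), np.arange(20, 25)]  # stand-in documents
doc_index_f, doc_index_l = 0, 2  # sample starts in doc 0 and ends in doc 2
offset_f, offset_l = 3, 1        # start offset in first doc, end offset in last doc

sample_list = [docs[doc_index_f][offset_f:]]            # rest of the first document
for i in range(doc_index_f + 1, doc_index_l):
    sample_list.append(docs[i])                         # full middle documents
sample_list.append(docs[doc_index_l][: offset_l + 1])   # head of the last document
sample = np.concatenate(sample_list)
# -> [ 3  4 10 11 12 13 14 20 21]
```
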
2 changes: 1 addition & 1 deletion megatron/learning_rates.py
@@ -76,7 +76,7 @@ def get_lr(self):
elif self.decay_style == "cosine":
end_iter_ = self.end_iter - self.warmup_iter
lr = self.min_lr + (
(self.start_lr-self.min_lr)
(self.start_lr - self.min_lr)
/ 2.0
* (math.cos(math.pi * num_iters_ / end_iter_) + 1)
)
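The whitespace fix above touches the cosine decay branch of get_lr. Written out on its own (warmup handling and the other decay styles in the scheduler class are omitted), the schedule interpolates from start_lr down to min_lr over end_iter iterations:

```python
# Standalone restatement of the cosine decay used above, for clarity only;
# warmup and the other decay styles are left out.
import math


def cosine_lr(num_iters, end_iter, start_lr, min_lr):
    return min_lr + (start_lr - min_lr) / 2.0 * (
        math.cos(math.pi * num_iters / end_iter) + 1
    )


assert abs(cosine_lr(0, 1000, 1e-4, 1e-5) - 1e-4) < 1e-12     # start of decay
assert abs(cosine_lr(1000, 1000, 1e-4, 1e-5) - 1e-5) < 1e-12  # fully decayed
```
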
6 changes: 5 additions & 1 deletion megatron/logging.py
@@ -99,7 +99,11 @@ def get_flops(neox_args, iter_time_s) -> float:
* seq_len
* num_layers
* (hidden_size**2)
* (1.0 + (seq_len / (6.0 * hidden_size)) + (vocab_size / (16.0 * num_layers * hidden_size)))
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
)
)
return flops_per_iteration / (iter_time_s * world_size)

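The hunk above only shows the trailing correction factor of the FLOPs estimate in get_flops; the leading batch-size and activation-checkpointing terms are cut off by the fold. As a hedged sketch, the overall estimate follows the standard Megatron-style transformer FLOPs formula, roughly:

```python
# Approximate per-GPU FLOPs estimate in the style of get_flops above. The leading
# 24 * ckpt_factor * batch_size term is not visible in the hunk and is an
# assumption based on the usual Megatron transformer FLOPs formula.
def approx_flops_per_gpu(batch_size, seq_len, num_layers, hidden_size,
                         vocab_size, iter_time_s, world_size, ckpt_factor=4.0):
    flops_per_iteration = (
        24.0 * ckpt_factor * batch_size * seq_len * num_layers * hidden_size**2
        * (
            1.0
            + (seq_len / (6.0 * hidden_size))
            + (vocab_size / (16.0 * num_layers * hidden_size))
        )
    )
    return flops_per_iteration / (iter_time_s * world_size)
```
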
