feat(train): support tokens count for seqlen warmup (EleutherAI#15)
* make token count seqlen warmup friendly

* add sanity check for schedule_type
leemengtw committed Jun 7, 2023
1 parent f5b2b90 commit f92a6d0
Showing 2 changed files with 130 additions and 0 deletions.
20 changes: 20 additions & 0 deletions megatron/logging.py
@@ -14,6 +14,7 @@

import sys
import torch
import numpy as np

try:
import wandb
@@ -22,6 +23,7 @@

from megatron import mpu, print_rank_0
from megatron.utils import report_memory
from megatron.seqlen_warmup_tokens import fixed_linear_seqlen_warmup_schedule


class Tee:
@@ -303,6 +305,24 @@ def add_to_logging(name):
)

# log tokens seen so far
if neox_args.curriculum_learning:
cl = neox_args.curriculum_learning
assert cl['schedule_type'] == "fixed_linear", \
"Only `fixed_linear` curriculum is supported at this time"

cl_steps = cl['schedule_config']['total_curriculum_step']
seq_len_schedule = fixed_linear_seqlen_warmup_schedule(
start_seqlen=cl['min_difficulty'],
end_seqlen=cl['max_difficulty'],
total_steps=cl_steps,
step_size=cl['schedule_config']['difficulty_step']
)

tokens = np.sum(seq_len_schedule[:iteration] * neox_args.train_batch_size) \
+ neox_args.train_batch_size * neox_args.seq_length * np.max([0, iteration - cl_steps])
else:
tokens = neox_args.train_batch_size * neox_args.seq_length * iteration

tb_wandb_log(
"runtime/tokens",
neox_args.train_batch_size * neox_args.seq_length * iteration,
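For intuition, the curriculum-aware count above splits the running total into two parts: tokens seen while the sequence length is still ramping (one schedule entry per step, each multiplied by the batch size) and tokens seen after the ramp, where every step runs at the full seq_length. A minimal sketch of the same arithmetic, using made-up values for the batch size, sequence length, warmup length, and schedule (none of these come from a real config):

import numpy as np

train_batch_size = 512        # illustrative
seq_length = 2048             # illustrative
cl_steps = 10                 # illustrative warmup length
iteration = 12                # two steps past the end of the warmup
# Stand-in for the fixed_linear schedule; any per-step sequence-length array works here
seq_len_schedule = np.linspace(64, 2048, cl_steps).astype(int)

# Tokens seen during the ramp: batch_size * that step's sequence length, summed step by step.
# Slicing past the end of the schedule is harmless, matching the logging code above.
warmup_tokens = int(np.sum(seq_len_schedule[:iteration] * train_batch_size))
# Tokens seen after the ramp: full-length sequences for every step beyond cl_steps.
post_warmup_tokens = train_batch_size * seq_length * max(0, iteration - cl_steps)
tokens = warmup_tokens + post_warmup_tokens
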
110 changes: 110 additions & 0 deletions megatron/seqlen_warmup_tokens.py
@@ -0,0 +1,110 @@
"""
This script calculates the number of tokens to train on for a given DeepSpeed curriculum
setup as specified in the NeoX config file.
Usage:
python megatron/seqlen_warmup_tokens.py --config <path-to-config> --nodes <num-nodes>
"""
import argparse
import yaml
import numpy as np


def fixed_linear_seqlen_warmup_schedule(
start_seqlen: int = 64,
end_seqlen: int = 2048,
total_steps: int = 20_000,
step_size: int = 8 # For GPU efficiency
):
"""
Linear warmup schedule from Li et al., "The Stability-Efficiency Dilemma: Investigating
Sequence Length Warmup for Training GPT Models" (https://openreview.net/pdf?id=JpZ5du_Kdh),
as used in DeepSpeed curriculum learning.
"""
seqlen_schedule = np.array([0] * total_steps)
for t in range(0, total_steps):
seqlen_schedule[t] = \
start_seqlen + (end_seqlen - start_seqlen) * min(t / total_steps, 1)
seqlen_schedule[t] = seqlen_schedule[t] - (seqlen_schedule[t] % step_size)
seqlen_schedule[t] = int(seqlen_schedule[t])
return seqlen_schedule
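
# Example with the defaults above: the schedule starts at start_seqlen, climbs linearly,
# and every entry is rounded down to a multiple of step_size.
#
#   schedule = fixed_linear_seqlen_warmup_schedule()
#   schedule[0]       # 64
#   schedule[10_000]  # 1056 (the halfway point: 64 + 1984 * 0.5)
#   schedule[-1]      # 2040 (t / total_steps stays below 1 inside the loop, then rounded down to a multiple of 8)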


def token_count_with_seqlen_warmup(
seqlen: int,
seqlen_schedule: np.ndarray,
rest_steps: int,
effective_batch_size: int = 2048,
):
"""
This function calculates the total number of tokens for training with the
given warmup schedule followed by the remaining full-length steps.
Args:
rest_steps: The number of steps to train on after warmup
"""
warmup_steps = len(seqlen_schedule)
warmup_tokens = np.sum(seqlen_schedule * effective_batch_size)
rest_tokens = rest_steps * (effective_batch_size * seqlen)
total_tokens = warmup_tokens + rest_tokens
return dict(
warmup_steps=warmup_steps,
warmup_tokens=warmup_tokens,
rest_steps=rest_steps,
rest_tokens=rest_tokens,
total_tokens=total_tokens,
)
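
# Example with small illustrative numbers: a two-step warmup at sequence lengths 64 and 128,
# an effective batch size of 4, and three full-length steps at seqlen 256 afterwards.
#
#   counts = token_count_with_seqlen_warmup(
#       seqlen=256, seqlen_schedule=np.array([64, 128]),
#       rest_steps=3, effective_batch_size=4)
#   counts["warmup_tokens"]  # (64 + 128) * 4 = 768
#   counts["rest_tokens"]    # 3 * 4 * 256 = 3072
#   counts["total_tokens"]   # 768 + 3072 = 3840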


def effective_batch_size(mbs, gas, num_gpus, tp, pp):
return mbs * gas * (num_gpus // (tp * pp))
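
# Example: micro-batch size 4, 8 gradient-accumulation steps, and 64 GPUs split into
# tensor-parallel 2 x pipeline-parallel 4 leaves 64 // (2 * 4) = 8 data-parallel replicas:
#   effective_batch_size(mbs=4, gas=8, num_gpus=64, tp=2, pp=4)  # 4 * 8 * 8 = 256 sequences per step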


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str)
parser.add_argument("--nodes", type=int, default=1)
args = parser.parse_args()

with open(args.config, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)

assert "curriculum_learning" in config and config['curriculum_learning']['enabled'], \
"Set the `curriculum_learning` field in your config"
assert config['curriculum_learning']['schedule_type'] == "fixed_linear", \
"Only `fixed_linear` curriculum is supported at this time"

curriculum = config['curriculum_learning']
schedule = fixed_linear_seqlen_warmup_schedule(
start_seqlen=curriculum['min_difficulty'],
end_seqlen=curriculum['max_difficulty'],
total_steps=curriculum['schedule_config']['total_curriculum_step'],
step_size=curriculum['schedule_config']['difficulty_step']
)
ebs = effective_batch_size(
mbs=config["train_micro_batch_size_per_gpu"],
gas=config["gradient_accumulation_steps"],
num_gpus=args.nodes * 8,
tp=config["model-parallel-size"],
pp=config["pipe-parallel-size"],
)
count_info = token_count_with_seqlen_warmup(
seqlen=curriculum['max_difficulty'],
seqlen_schedule=schedule,
rest_steps=config['train-iters'] - curriculum['schedule_config']['total_curriculum_step'],
effective_batch_size=ebs,
)

print(f"{'='*32}")
print(f"num_gpus: {args.nodes * 8}")
print(f"effective batch size: {ebs}")
print(f"{'='*32}")
print(f"warmup steps: {count_info['warmup_steps']:,}")
print(f"warmup tokens: {count_info['warmup_tokens']:,}")
print(f"{'='*32}")
print(f"rest steps: {count_info['rest_steps']:,}")
print(f"rest tokens: {count_info['rest_tokens']:,}")
print(f"{'='*32}")
print(f"total steps: {count_info['warmup_steps'] + count_info['rest_steps']:,}")
print(f"total tokens: {count_info['total_tokens']:,}")
print(f"{'='*32}")
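
The script only touches a handful of fields in the NeoX config. A minimal sketch of those fields as a Python dict, mirroring the keys the __main__ block above reads (every value here is a placeholder, not taken from any real config):

config = {
    "curriculum_learning": {
        "enabled": True,
        "schedule_type": "fixed_linear",
        "min_difficulty": 64,
        "max_difficulty": 2048,
        "schedule_config": {
            "total_curriculum_step": 20_000,
            "difficulty_step": 8,
        },
    },
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "model-parallel-size": 2,
    "pipe-parallel-size": 4,
    "train-iters": 150_000,
}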
