Commit
parametrize train tests
sdtblck committed May 13, 2021
1 parent bb8222f commit a2ecec7
Showing 4 changed files with 252 additions and 118 deletions.
70 changes: 68 additions & 2 deletions tests/common.py
@@ -5,6 +5,7 @@
from pathlib import Path

import pytest
import random

import torch
import torch.distributed as dist
@@ -23,18 +24,22 @@
def get_root_directory():
return Path(__file__).parents[1]


def get_config_directory():
return get_root_directory() / "configs"


def get_configs_with_path(configs):
return [str(get_config_directory() / cfg) for cfg in configs]


def get_test_configs_with_path(configs):
test_config_dir = Path(__file__).parent / "test_configs"
return [str((test_config_dir / cfg).absolute()) for cfg in configs]


def clear_test_dirs():
log_dir = os.path.join(get_root_directory(),TEST_LOG_DIR)
log_dir = os.path.join(get_root_directory(), TEST_LOG_DIR)
if os.path.isdir(log_dir):
shutil.rmtree(log_dir)

@@ -45,7 +50,8 @@ def clear_test_dirs():
tensorboard_dir = os.path.join(get_root_directory(), TEST_TENSORBOARD_DIR)
if os.path.isdir(tensorboard_dir):
shutil.rmtree(tensorboard_dir)



def distributed_test(world_size=2, backend='nccl'):
"""A decorator for executing a function (e.g., a unit test) in a distributed manner.
This decorator manages the spawning and joining of processes, initialization of
Expand All @@ -64,8 +70,10 @@ def my_test():
world_size (int or list): number of ranks to spawn. Can be a list to spawn
multiple tests.
"""

def dist_wrap(run_func):
"""Second-level decorator for dist_test. This actually wraps the function. """

def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
"""Initialize torch.distributed and execute the user function. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
@@ -137,3 +145,61 @@ def run_func_decorator(*func_args, **func_kwargs):
return run_func_decorator

return dist_wrap
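For orientation, a minimal usage sketch of the decorator above (the test body and world sizes are illustrative, not part of this commit; it assumes the decorator initializes torch.distributed before invoking the wrapped function, as its docstring states):

import torch.distributed as dist  # already imported at the top of tests/common.py

# Hypothetical test: one process group is spawned per listed world size, and
# torch.distributed is initialized before the body runs.
@distributed_test(world_size=[1, 2])
def my_distributed_test():
    assert dist.is_initialized()
    assert dist.get_world_size() in (1, 2)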


def model_setup(yaml_list=None, param_dict=None):
from megatron.neox_arguments import NeoXArgs
from megatron.mpu import destroy_model_parallel
from megatron import initialize_megatron
from megatron.training import setup_model_and_optimizer

destroy_model_parallel() # mpu model parallel contains remaining global vars
if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
clear_test_dirs()

overwrite_values = {
"user_script": str(get_root_directory() / "pretrain_gpt2.py"),
"save": TEST_CHECKPOINT_DIR,
"load": TEST_CHECKPOINT_DIR,
"log_dir": TEST_LOG_DIR,
"tensorboard_dir": TEST_TENSORBOARD_DIR,
}

# should not both be none
assert yaml_list is not None or param_dict is not None

# initially load config from files as would be the case in deepy.py
if yaml_list is not None:
args_loaded = NeoXArgs.from_ymls(yaml_list, overwrite_values=overwrite_values)
else:
p_dict = param_dict.copy()
p_dict.update(overwrite_values)
args_loaded = NeoXArgs.from_dict(p_dict)

args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
model, optimizer, lr_scheduler = setup_model_and_optimizer(neox_args=args_loaded, inference=False,
get_key_value=True)
return model, optimizer, lr_scheduler, args_loaded
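For reference, a hedged sketch of how model_setup is expected to be called from a test; the config file names mirror those used elsewhere in this commit, and the call site itself is illustrative:

# Hypothetical call site inside a @distributed_test-decorated function.
yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_0.yml"])
model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list=yaml_list)
model.train()  # the returned engine can now be driven by train_step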


def bounded_product(sequence, n=None, seed=None):
"""
Returns a shuffled, bounded cartesian product of the input sequences.
Designed to cover as wide a range of combinations as possible within a limited number of iterations.
Materializes the whole product in memory, so it is not suitable for very large inputs.
:param sequence: iterable of iterables to take the cartesian product of
:param n: maximum length of the returned list (None returns the full product)
:param seed: random seed for reproducibility
:return: list of tuples
"""
p = list(itertools.product(*sequence))
if seed is not None:
random.seed(seed)
random.shuffle(p)
return p if n is None else p[:n]
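As a rough illustration of the helper above (the input values are made up, and the exact output ordering depends on the seeded shuffle):

# 2 x 2 x 2 = 8 possible combinations, shuffled deterministically and capped at n=3.
combos = bounded_product([[True, False], ["layernorm", "rmsnorm"], [1, 2]], n=3, seed=42)
assert len(combos) == 3
# e.g. [(False, "rmsnorm", 2), (True, "layernorm", 1), (False, "layernorm", 2)]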


binary = [True, False]
38 changes: 2 additions & 36 deletions tests/model/test_model_instantiation.py
@@ -6,8 +6,7 @@

import torch

from ..common import TEST_CHECKPOINT_DIR, TEST_LOG_DIR, TEST_TENSORBOARD_DIR
from ..common import distributed_test, get_test_configs_with_path, get_root_directory, clear_test_dirs
from ..common import distributed_test, get_test_configs_with_path, model_setup, clear_test_dirs

@distributed_test(world_size=1)
def test_model_instantiation_small_0():
@@ -36,40 +35,7 @@ def test_model_instantiation_small_4():

def run_test_model_instantiation(yaml_list=None, param_dict=None):
from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine

from megatron.neox_arguments import NeoXArgs
from megatron.mpu import destroy_model_parallel
from megatron import initialize_megatron
from megatron.training import setup_model_and_optimizer

destroy_model_parallel() # mpu model parallel contains remaining global vars
if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
clear_test_dirs()

overwrite_values = {
"user_script": str(get_root_directory() / "pretrain_gpt2.py"),
"save": TEST_CHECKPOINT_DIR,
"load": TEST_CHECKPOINT_DIR,
"log_dir": TEST_LOG_DIR,
"tensorboard_dir": TEST_TENSORBOARD_DIR,
}

# should not both be none
assert yaml_list is not None or param_dict is not None

# initially load config from files as would be the case in deepy.py
if yaml_list is not None:
args_loaded = NeoXArgs.from_ymls(yaml_list, overwrite_values=overwrite_values)
else:
p_dict = param_dict.copy()
p_dict.update(overwrite_values)
args_loaded = NeoXArgs.from_dict(p_dict)

args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
model, optimizer, lr_scheduler = setup_model_and_optimizer(neox_args=args_loaded, inference=False, get_key_value=True)

model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
print(type(model), flush=True)
if args_loaded.pipe_parallel_size < 2:
assert isinstance(model, DeepSpeedEngine), "test model instantiation "+str(yaml_list)
159 changes: 79 additions & 80 deletions tests/model/test_model_train.py
@@ -4,115 +4,114 @@
These tests contain a relatively large number of functions. They are not split into separate tests because a lot of boilerplate (e.g. instantiating the model) needs
to run in order to perform follow-up tests. Joining them in one test reduces runtime at the expense of decreased transparency of test results in case of failures.
"""
import pytest

import os
from pathlib import Path

from ..common import TEST_CHECKPOINT_DIR, TEST_LOG_DIR, TEST_TENSORBOARD_DIR
from ..common import distributed_test, get_root_directory, get_test_configs_with_path, clear_test_dirs
from copy import deepcopy
from ..common import distributed_test, clear_test_dirs, model_setup, get_test_configs_with_path, bounded_product, binary

import torch

@distributed_test(world_size=1)
def test_model_train_small_0():
yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_0.yml"])
run_train_test(yaml_list=yaml_list)

@distributed_test(world_size=1)
def test_model_train_small_1():
yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_1.yml"])
run_train_test(yaml_list=yaml_list)

# TODO after sorting out scaled-upper-triang-masked-softmax-fusion and fp16
#@distributed_test(world_size=2)
#def test_model_train_small_2():
# yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_2.yml"])
# run_train_test(yaml_list=yaml_list)

# TODO after sorting out RPE + sparse attention
#@distributed_test(world_size=1)
#def test_model_train_small_3():
# yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_3.yml"])
# run_train_test(yaml_list=yaml_list)

from yaml import load

try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper

with open(get_test_configs_with_path("test_train_base.yml")[0], 'r') as f:
BASE_CONFIG = load(f, Loader=Loader)

PARAMS_TO_TEST = {
"norm,pos_emb": [["layernorm", "learned"], ["rmsnorm", "rotary"], ["scalenorm", "sinusoidal"],
["layernorm", "rpe"], ["rmsnorm", "none"]],
"pipe_parallel_size,model_parallel_size": [[0, 1], [1, 1], [2, 2], [0, 2]],
"no_weight_tying": binary,
"attention_config": [[[["global"], "all"]], [[["local"], "all"]], [[["sparse_variable"], "all"]],
[[["sparse_fixed"], "all"]]],
"scaled_upper_triang_masked_softmax_fusion": binary,
"bias_gelu_fusion": binary,
"checkpoint_activations": binary,
}


def parametrize(params_to_test: dict, max_tests: int = 50, seed: int = None):
"""
Generates a random sample (of at most max_tests items) of all possible combinations of the values in
`params_to_test`.
In `params_to_test` you can either specify a single parameter name with all of its possible settings,
or several parameter names joined by commas, in which case only the listed value combinations are tested in tandem,
e.g. "hidden_size,num_heads": [[768, 12], [1024, 32], [2048, 64]],
where the first item in each list is a value of `hidden_size` and the second a value of `num_heads`.
This is useful for reducing the number of tests for parameters we know beforehand are unlikely to interact,
since the cartesian product can grow very large.
:param params_to_test: dict of neox params
:param max_tests: maximum number of tests to run
:param seed: random seed
:return: neox param dicts (yielded one at a time) to pass to a parametrized unit test
"""
keys, values = zip(*params_to_test.items())
for p in bounded_product(values, n=max_tests, seed=seed):
experiment = dict(zip(keys, p))
for k, v in list(experiment.items()):  # iterate over a copy, since `experiment` is mutated below
if "," in k:
keys_split = [i.strip() for i in k.split(',')]
values_separated = experiment.pop(k)
assert len(values_separated) == len(keys_split)
new_dict = dict(zip(keys_split, values_separated))
experiment.update(new_dict)
base = deepcopy(BASE_CONFIG)
base.update(experiment)
yield base
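To make the comma-separated key convention concrete, a hedged sketch of how parametrize expands a combined key into separate config entries (the sample dict is illustrative, and BASE_CONFIG is assumed to have been loaded from test_train_base.yml as above):

sample_params = {
    "norm,pos_emb": [["layernorm", "learned"], ["rmsnorm", "rotary"]],
    "no_weight_tying": [True, False],
}
for cfg in parametrize(sample_params, max_tests=2, seed=0):
    # each cfg is a copy of BASE_CONFIG updated with e.g.
    # {"norm": "rmsnorm", "pos_emb": "rotary", "no_weight_tying": True}
    print(cfg["norm"], cfg["pos_emb"], cfg["no_weight_tying"])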


@pytest.mark.parametrize("param_dict", list(parametrize(PARAMS_TO_TEST, max_tests=50, seed=None)))
@distributed_test(world_size=2)
def test_model_train_small_4():
yaml_list = get_test_configs_with_path(["test_local_setup.yml", "test_small_4.yml"])
run_train_test(yaml_list=yaml_list)
def test_train(param_dict):
run_train_test(param_dict=param_dict)

def run_train_test(yaml_list=None, param_dict=None):

from megatron.neox_arguments import NeoXArgs
from megatron import initialize_megatron
from megatron.training import setup_model_and_optimizer, train_step
from megatron.mpu import destroy_model_parallel
def run_train_test(yaml_list=None, param_dict=None):
from megatron.training import train_step
from megatron.utils import Timers

max_steps = 256
max_steps = 256

model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)

destroy_model_parallel() # mpu model parallel contains remaining global vars

if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
clear_test_dirs()

overwrite_values = {
"user_script": str(get_root_directory() / "pretrain_gpt2.py"),
"save": TEST_CHECKPOINT_DIR,
"load": TEST_CHECKPOINT_DIR,
"log_dir": TEST_LOG_DIR,
"tensorboard_dir": TEST_TENSORBOARD_DIR,
}

# should not both be none
assert yaml_list is not None or param_dict is not None

# initially load config from files as would be the case in deepy.py
if yaml_list is not None:
args_loaded = NeoXArgs.from_ymls(yaml_list, overwrite_values=overwrite_values)
else:
p_dict = param_dict.copy()
p_dict.update(overwrite_values)
args_loaded = NeoXArgs.from_dict(p_dict)

args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)

model, optimizer, lr_scheduler = setup_model_and_optimizer(neox_args=args_loaded, inference=False, get_key_value=True)
model.train()


timers = Timers(use_wandb=False, tensorboard_writer=None)

# generate some random data on which we can overfit
# context size of data is model seq_len + 1 in order to compute loss
data_list = list()
context_tokens_tensor = torch.randint(0, args_loaded.padded_vocab_size, (4, args_loaded.seq_length + 1 )).to(torch.int64)
context_tokens_tensor = torch.randint(0, args_loaded.padded_vocab_size, (4, args_loaded.seq_length + 1)).to(
torch.int64)
for i in range(max_steps):
data_list.append({ "text": context_tokens_tensor.clone() })
data_list.append({"text": context_tokens_tensor.clone()})
data_iterator = iter(data_list)

# run train_step until the loss decreases
losses = list()
for i in range(max_steps):
loss_dict, skipped_iter = train_step(
neox_args=args_loaded,
timers=timers,
data_iterator=data_iterator,
model=model,
optimizer=optimizer,
neox_args=args_loaded,
timers=timers,
data_iterator=data_iterator,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
)
losses.append(loss_dict["lm_loss"])
if len(losses) >= 2:
if torch.isnan(losses[-1]): continue
if torch.isnan(losses[-2]): continue
if losses[-1] < losses[-2]:
return # all good
return # all good

# loss should have decreased by now (otherwise increasing the max_steps parameter could have the testcase pass)
assert losses[-1] < losses[-2], "run_train_test() loss going down within "+str(max_steps)+" steps"
assert losses[-1] < losses[-2], "run_train_test() loss going down within " + str(max_steps) + " steps"

if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
clear_test_dirs()
clear_test_dirs()