Skip to content

Commit

Permalink
parametrize all model tests + simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Weinbach committed May 13, 2021
1 parent 00adcf3 commit 1e8fb3b
Show file tree
Hide file tree
Showing 17 changed files with 118 additions and 819 deletions.
20 changes: 15 additions & 5 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def run_func_decorator(*func_args, **func_kwargs):
return dist_wrap


def model_setup(yaml_list=None, param_dict=None, clear_data=True):
def model_setup(yaml_list=None, param_dict=None, clear_data=True, inference=False):
from megatron.neox_arguments import NeoXArgs
from megatron.mpu import destroy_model_parallel
from megatron import initialize_megatron
Expand Down Expand Up @@ -184,7 +184,7 @@ def model_setup(yaml_list=None, param_dict=None, clear_data=True):
args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
model, optimizer, lr_scheduler = setup_model_and_optimizer(neox_args=args_loaded, inference=False,
model, optimizer, lr_scheduler = setup_model_and_optimizer(neox_args=args_loaded, inference=inference,
get_key_value=True)
return model, optimizer, lr_scheduler, args_loaded

Expand All @@ -207,7 +207,7 @@ def bounded_product(sequence, n=None, seed=None):
return p if n is None else p[:n]


def parametrize(params_to_test: dict, max_tests: int = 50, seed: int = None):
def parametrize(params_to_test: dict, max_tests: int = 50, seed: int = None, with_names=True):
"""
Generates a random sample of max_tests length of all possible combinations of values in
`params_to_test`.
Expand All @@ -225,6 +225,9 @@ def parametrize(params_to_test: dict, max_tests: int = 50, seed: int = None):
:return: a list of neox param dicts to pass to a parametrized unit test
"""
keys, values = zip(*params_to_test.items())
ret = []
if with_names:
experiments = []
for p in bounded_product(values, n=max_tests, seed=seed):
experiment = dict(zip(keys, p))
to_pop = []
Expand All @@ -242,8 +245,15 @@ def parametrize(params_to_test: dict, max_tests: int = 50, seed: int = None):
experiment.pop(k)
base = deepcopy(BASE_CONFIG)
base.update(experiment)
yield base

ret.append(base)
if with_names:
experiments.append(experiment)
if with_names:
return ret, [dict_repr(d) for d in experiments]
return ret

def dict_repr(d):
return " ".join([f"{str(k)} : {str(v)}" for k, v in d.items()])

binary = [True, False]

Expand Down
72 changes: 0 additions & 72 deletions tests/config_comparison.py

This file was deleted.

5 changes: 3 additions & 2 deletions tests/model/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
"checkpoint_validation_with_forward_pass": [True]
}


@pytest.mark.parametrize("param_dict", list(parametrize(PARAMS_TO_TEST, max_tests=50, seed=None)))
parameters, names = parametrize(PARAMS_TO_TEST, max_tests=50, seed=None)
@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_train(param_dict):
@distributed_test(world_size=2)
def wrapper():
run_checkpoint_test(param_dict=param_dict)
wrapper()


def run_checkpoint_test(yaml_list=None, param_dict=None):

from megatron.checkpointing import load_checkpoint
Expand Down
Loading

0 comments on commit 1e8fb3b

Please sign in to comment.