-
Notifications
You must be signed in to change notification settings - Fork 978
/
test_neoxargs_commandline.py
89 lines (61 loc) · 3.56 KB
/
test_neoxargs_commandline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
from unittest.mock import patch
from ..common import get_root_directory, get_config_directory, get_configs_with_path
def test_neoxargs_consume_deepy_args_without_config_dir():
    """
    verify consume_deepy_args processes command line arguments without config dir

    Config files are passed on the command line as full paths (no ``-d``
    option) and the result is compared against loading the same files
    directly with ``NeoXArgs.from_ymls``.
    """
    # This test must not share its name with the ``-d`` variant
    # (``test_neoxargs_consume_deepy_args_with_config_dir``) defined later in
    # the module: a second ``def`` with the same name rebinds it, so pytest
    # would never collect this test.
    from megatron.neox_arguments import NeoXArgs

    config_files = ["small.yml", "local_setup.yml"]

    # load neox args with command line, each config given as a full path
    argv = [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py"] + get_configs_with_path(config_files)
    with patch('sys.argv', argv):
        args_loaded_consume = NeoXArgs.consume_deepy_args()

    # load neox args directly from yaml files
    args_loaded_yamls = NeoXArgs.from_ymls(get_configs_with_path(config_files))

    # update values from yaml files that cannot otherwise be matched:
    # the user script is only known from the command line, and wandb_group is
    # copied across — presumably it can differ between loads (TODO confirm).
    args_loaded_yamls.update_value("user_script", "pretrain_gpt2.py")
    args_loaded_yamls.wandb_group = args_loaded_consume.wandb_group

    assert args_loaded_yamls == args_loaded_consume
def test_neoxargs_consume_deepy_args_without_yml_suffix():
    """
    Check that consume_deepy_args accepts config names given without the
    ``.yml`` suffix and produces the same arguments as loading the
    suffixed files directly.
    """
    from megatron.neox_arguments import NeoXArgs

    launcher = str(get_root_directory() / "deepy.py")
    suffixless_configs = get_configs_with_path(["small", "local_setup"])

    # parse through the deepy.py command-line path, omitting the .yml suffix
    with patch('sys.argv', [launcher, "pretrain_gpt2.py"] + suffixless_configs):
        from_cli = NeoXArgs.consume_deepy_args()

    # reference result: read the suffixed yaml files directly
    from_files = NeoXArgs.from_ymls(get_configs_with_path(["small.yml", "local_setup.yml"]))

    # align the fields the yaml-only path cannot reproduce by itself
    from_files.update_value("user_script", "pretrain_gpt2.py")
    from_files.wandb_group = from_cli.wandb_group

    assert from_files == from_cli
def test_neoxargs_consume_deepy_args_with_config_dir():
    """
    Check that consume_deepy_args resolves bare config file names against a
    directory supplied with ``-d`` and matches loading those files directly.
    """
    from megatron.neox_arguments import NeoXArgs

    config_names = ["small.yml", "local_setup.yml"]
    cli = [
        str(get_root_directory() / "deepy.py"),
        "pretrain_gpt2.py",
        '-d',
        str(get_config_directory()),
        *config_names,
    ]

    # parse through the deepy.py command-line path using the config directory flag
    with patch('sys.argv', cli):
        from_cli = NeoXArgs.consume_deepy_args()

    # reference result: read the same yaml files via their full paths
    from_files = NeoXArgs.from_ymls(get_configs_with_path(config_names))

    # align the fields the yaml-only path cannot reproduce by itself
    from_files.update_value("user_script", "pretrain_gpt2.py")
    from_files.wandb_group = from_cli.wandb_group

    assert from_files == from_cli
def test_neoxargs_consume_neox_args():
    """
    Check that megatron arguments survive the round trip through the
    deepspeed launcher command line.
    """
    from megatron.neox_arguments import NeoXArgs

    # build the baseline from config files, as deepy.py initially would
    baseline = NeoXArgs.from_ymls(get_configs_with_path(["small.yml", "local_setup.yml"]))
    baseline.update_value("user_script", str(get_root_directory() / "pretrain_gpt2.py"))

    # patch sys.argv so set_global_variables (inside initialize_megatron)
    # can read the arguments exactly as deepspeed would pass them
    with patch('sys.argv', baseline.get_deepspeed_main_args()):
        reloaded = NeoXArgs.consume_neox_args()

    # TODO: confirm whether wandb_group is really expected to change in the round trip
    reloaded.wandb_group = baseline.wandb_group

    assert baseline.megatron_config == reloaded.megatron_config