Merge pull request EleutherAI#680 from dashstander/srun
SLURM Multi-Node Support
Quentin-Anthony committed Sep 22, 2022
2 parents f5ce8f1 + 8d53ee0 commit e70f390
Showing 18 changed files with 288 additions and 84 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -328,7 +328,7 @@ We also support using TensorBoard via the `tensorboard-dir`

# Running on multi-node

If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher, you can set the environment variable `DLTS_HOSTFILE` to point to the hostfile.
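
For reference, the hostfile uses the MPI style that the launcher's `--hostfile` help text describes (one host per line with its slot count); a minimal illustrative sketch, with hypothetical host names:

```
worker-0 slots=4
worker-1 slots=4
```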

# Administrative Notes

40 changes: 40 additions & 0 deletions configs/README.md
@@ -257,4 +257,44 @@ An example config for fp16 training:

To train in fp32, simply set `fp16["enabled"]` to `false`.


### SLURM Settings

If you are running GPT-NeoX on a SLURM cluster and wish to use SLURM to coordinate nodes, then you must set the following variables in your config:

```yaml
"launcher": "slurm",
"deepspeed_slurm": true
```

Additionally, you must modify _all_ of your configs to conform to the JSON format. When launching a GPT-NeoX job you can specify multiple YAML config files. Internally, all of these files are merged into one config and then passed as a single long command-line argument to Deep(er)Speed. When using SLURM and its internal command `srun`, Python fails to parse this long command-line argument unless it is in the more restrictive JSON format. In practice, the example NeoX configs are already very close to JSON. As an example, here is a snippet of a YAML-compatible config; note the comment and the capital-F `False`:

```yaml
# optimizer settings
"optimizer": {
  "type": "OneBitAdam",
  "params": {
    "lr": 0.0001,
    "freeze_step": 23000,
    "betas": [0.9, 0.95],
    "cuda_aware": False,
    "comm_backend_name": "nccl"
  }
}
```

To make this JSON, just remove the comment and use all lowercase for the boolean:

```yaml
"optimizer": {
  "type": "OneBitAdam",
  "params": {
    "lr": 0.0001,
    "freeze_step": 23000,
    "betas": [0.9, 0.95],
    "cuda_aware": false,
    "comm_backend_name": "nccl"
  }
}
```
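
Why the stricter format matters: YAML happily parses Python-style capitalized booleans, while strict JSON does not. A minimal sketch of the difference, assuming PyYAML is installed:

```python
import json

import yaml  # assumes PyYAML is available

# YAML is permissive: a capital-F False still parses as a boolean.
print(yaml.safe_load('{"cuda_aware": False}'))  # {'cuda_aware': False}

# Strict JSON rejects the capitalized form; only lowercase `false` is valid.
try:
    json.loads('{"cuda_aware": False}')
except json.JSONDecodeError as err:
    print("invalid JSON:", err)

print(json.loads('{"cuda_aware": false}'))  # {'cuda_aware': False}
```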


** TODO: bf16 docs **
5 changes: 2 additions & 3 deletions configs/neox_arguments.md
@@ -930,7 +930,7 @@ Text Generation arguments

- **eval_results_prefix**: str

Default =

prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json

@@ -1518,7 +1518,7 @@ Args for deepspeed config

Default = None

@@ -1644,4 +1644,3 @@ Args for deepspeed runner (deepspeed.launcher.runner).
Default = False

If true, autodetects nvlink pairs and remaps cuda visible devices to place them next to each other. This is an Eleuther addition to deepspeed, and should speed up model parallel training on setups with nvlink pairs when mp=2.

13 changes: 13 additions & 0 deletions configs/slurm_local.yml
@@ -0,0 +1,13 @@
{
  "data-path": "data/enron/enron_text_document",
  "vocab-file": "data/gpt2-vocab.json",
  "merge-file": "data/gpt2-merges.txt",
  "save": "checkpoints",
  "load": "checkpoints",
  "checkpoint_validation_with_forward_pass": false,
  "tensorboard-dir": "tensorboard",
  "log-dir": "logs",
  "use_wandb": true,
  "wandb_host": "https://api.wandb.ai",
  "wandb_project": "neox"
}
65 changes: 65 additions & 0 deletions configs/slurm_small.yml
@@ -0,0 +1,65 @@
{
  "pipe-parallel-size": 1,
  "model-parallel-size": 1,
  "num-layers": 12,
  "hidden-size": 768,
  "num-attention-heads": 12,
  "seq-length": 2048,
  "max-position-embeddings": 2048,
  "norm": "layernorm",
  "pos-emb": "rotary",
  "no-weight-tying": true,
  "scaled-upper-triang-masked-softmax-fusion": true,
  "bias-gelu-fusion": true,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0006,
      "betas": [0.9, 0.999],
      "eps": 1.0e-8
    }
  },
  "zero_optimization": {
    "stage": 0,
    "allgather_partitions": true,
    "allgather_bucket_size": 500000000,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": true,
    "cpu_offload": false
  },
  "train_micro_batch_size_per_gpu": 4,
  "data-impl": "mmap",
  "split": "949,50,1",
  "checkpoint-activations": true,
  "checkpoint-num-layers": 1,
  "partition-activations": true,
  "synchronize-each-layer": true,
  "gradient_clipping": 1.0,
  "weight-decay": 0.0,
  "hidden-dropout": 0.0,
  "attention-dropout": 0.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "train-iters": 320000,
  "lr-decay-iters": 320000,
  "distributed-backend": "nccl",
  "lr-decay-style": "cosine",
  "warmup": 0.01,
  "save-interval": 10000,
  "eval-interval": 1000,
  "eval-iters": 10,
  "log-interval": 100,
  "steps_per_print": 10,
  "keep-last-n-checkpoints": 4,
  "wall_clock_breakdown": true,
  "launcher": "slurm",
  "deepspeed_slurm": true,
  "slurm_comment": "neox"
}
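
As a usage sketch (not part of this diff), the two SLURM configs above are meant to be passed to the launcher together, which merges them into a single config before serializing it; roughly the following, assuming PyYAML:

```python
import json

import yaml  # assumes PyYAML is available

# Later files override earlier ones, mirroring how deepy.py merges the
# multiple config files passed on its command line.
merged = {}
for path in ["configs/slurm_small.yml", "configs/slurm_local.yml"]:
    with open(path) as f:
        merged.update(yaml.safe_load(f))

# The merged config must round-trip through strict JSON, since it is passed
# to Deep(er)Speed as one long command-line argument under srun.
print(json.dumps(merged)[:120])
```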
6 changes: 3 additions & 3 deletions eval_tasks/eval_adapter.py
@@ -356,7 +356,7 @@ def run_eval(
description_dict=None,
use_cache=True,
name="neox",
limit=None
limit=None,
):
was_training = self.model.training
self.model.eval()
@@ -389,7 +389,7 @@ def run_eval(
if use_cache:
# TODO(jon-tow): Append a subset of `neox_args` to the cache database
# name arg to distinguish model runs that use different configurations.
lm = base.CachingLM(lm, 'lm_cache/' + name + '.db')
lm = base.CachingLM(lm, "lm_cache/" + name + ".db")

results = evaluator.evaluate(
lm=lm,
@@ -409,7 +409,7 @@
"no_cache": not use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}

if was_training:
2 changes: 1 addition & 1 deletion megatron/data/data_utils.py
@@ -250,7 +250,7 @@ def weights_by_num_docs(l, alpha=0.3):
total_n_docs = sum(l)
unbiased_sample_probs = [i / total_n_docs for i in l]

probs = [i ** alpha for i in unbiased_sample_probs]
probs = [i**alpha for i in unbiased_sample_probs]

# normalize
total = sum(probs)
2 changes: 1 addition & 1 deletion megatron/model/gmlp.py
@@ -14,7 +14,7 @@ class TinyAttention(nn.Module):
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
super().__init__()
self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
self.scale = d_attn ** -0.5
self.scale = d_attn**-0.5
self.proj_ffn = nn.Linear(d_attn, d_ff)
self.softmax = FusedScaleMaskSoftmax(
input_in_fp16=neox_args.precision == "fp16",
10 changes: 7 additions & 3 deletions megatron/model/positional_embeddings.py
@@ -105,7 +105,7 @@ def _get_slopes(self, n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
return [start * ratio**i for i in range(n)]

if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
@@ -129,9 +129,13 @@ def forward(self, x):
if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
a = self.cached_matrix
else:
target_seq_len = seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
target_seq_len = (
seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
)
a = -torch.tril(
torch.arange(target_seq_len).view(target_seq_len, 1).repeat(1, target_seq_len)
torch.arange(target_seq_len)
.view(target_seq_len, 1)
.repeat(1, target_seq_len)
+ torch.arange(0, -target_seq_len, -1)
)
a = a.to(x.device).to(x.dtype)
15 changes: 11 additions & 4 deletions megatron/neox_arguments/arguments.py
@@ -289,8 +289,8 @@ def consume_deepy_args(cls):
"--hostfile",
type=str,
help="Hostfile path (in MPI style) that defines the "
"resource pool available to the job (e.g., "
"worker-0 slots=4)"
"resource pool available to the job (e.g., "
"worker-0 slots=4)",
)
group = parser.add_argument_group(title="Generation args")
group.add_argument(
@@ -391,9 +391,11 @@ def get_deepspeed_main_args(self):
args_list.extend(
self.convert_key_value_to_command_line_arg(key, configured_value)
)
if 'DLTS_HOSTFILE' in os.environ:
if "DLTS_HOSTFILE" in os.environ:
args_list.extend(
self.convert_key_value_to_command_line_arg("hostfile", os.environ['DLTS_HOSTFILE'])
self.convert_key_value_to_command_line_arg(
"hostfile", os.environ["DLTS_HOSTFILE"]
)
)

if (
@@ -545,6 +547,11 @@ def configure_distributed_args(self):

mpi_discovery()

if self.deepspeed_slurm:
os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
os.environ["RANK"] = os.environ["SLURM_PROCID"]
os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]

self.update_value("local_rank", int(os.getenv("LOCAL_RANK", "0")))
self.update_value("rank", int(os.getenv("RANK", "0")))
self.update_value("world_size", int(os.getenv("WORLD_SIZE", "1")))
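For context, the `deepspeed_slurm` branch added above relies on the per-task environment variables that `srun` exports; a minimal sketch of the mapping (the SLURM variable names are standard, the rest is illustrative):

```python
import os

# srun sets these for every task it launches:
#   SLURM_LOCALID - task index on its node (used as the local rank / GPU index)
#   SLURM_PROCID  - global task index across all nodes
#   SLURM_NTASKS  - total number of tasks in the job
# torch.distributed expects LOCAL_RANK / RANK / WORLD_SIZE instead.
os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
os.environ["RANK"] = os.environ["SLURM_PROCID"]
os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
```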
5 changes: 5 additions & 0 deletions megatron/neox_arguments/deepspeed_args.py
@@ -171,3 +171,8 @@ class NeoXArgsDeepspeedRunner(NeoXArgsTemplate):
"""
If true, autodetects nvlink pairs and remaps cuda visible devices to place them next to each other. This is an Eleuther addition to deepspeed, and should speed up model parallel training on setups with nvlink pairs when mp=2.
"""

slurm_comment: str = None
"""
If using the SLURM launcher, adds a `--comment` to the `srun` command that launches the job. Sometimes necessary for cluster rules.
"""
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -581,6 +581,11 @@ class NeoXArgsOther(NeoXArgsTemplate):
Run via MPI, this will attempt to discover the necessary variables to initialize torch distributed from the MPI environment
"""

deepspeed_slurm: bool = False
"""
Run via SLURM, this will attempt to discover the necessary variables to initialize torch distributed from the SLURM environment
"""

user_script: str = None
"""
user script to be run
11 changes: 8 additions & 3 deletions megatron/text_generation_utils.py
@@ -278,7 +278,9 @@ def stream_tokens(
# initialize generation variables
state_is_done = torch.zeros([batch_size]).byte().cuda()
token_generation_end_index = torch.ones([batch_size]).long().cuda() * (-1)
generation_logits = torch.empty(maximum_tokens, neox_args.padded_vocab_size).float().cuda()
generation_logits = (
torch.empty(maximum_tokens, neox_args.padded_vocab_size).float().cuda()
)

while token_index_to_generate <= last_token_index_to_generate:
if recompute: # recompute all tokens
Expand Down Expand Up @@ -335,7 +337,9 @@ def stream_tokens(
).view(-1)

if neox_args.return_logits:
generation_logits[token_index_to_generate - 1] = generated_token_logits[0]
generation_logits[
token_index_to_generate - 1
] = generated_token_logits[0]

if neox_args.is_pipe_parallel:
# broadcast generated tokens to pipe parallel group
@@ -776,7 +780,8 @@ def generate_samples_interactive(
.tolist()[
batch_token_generation_start_index[0]
.item() : batch_token_generation_end_index[0]
.item() + 1
.item()
+ 1
]
)
generated_text = neox_args.tokenizer.detokenize(generated_tokens)
4 changes: 2 additions & 2 deletions megatron/tokenizer/gpt2_tokenization.py
@@ -65,10 +65,10 @@ def bytes_to_unicode():
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
4 changes: 2 additions & 2 deletions megatron/training.py
@@ -147,7 +147,7 @@ def pretrain(neox_args):
iteration=iteration,
verbose=True,
timers=timers,
chart_name="test"
chart_name="test",
)


@@ -737,7 +737,7 @@ def evaluate_and_print_results(
iteration,
verbose=False,
timers=None,
chart_name="validation"
chart_name="validation",
):
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(
4 changes: 4 additions & 0 deletions megatron/utils.py
@@ -104,6 +104,10 @@ def get_ltor_masks_and_position_ids(
def local_rank():
"""Local rank of process"""
local_rank = os.environ.get("LOCAL_RANK")

if local_rank is None:
local_rank = os.environ.get("SLURM_LOCALID")

if local_rank is None:
print(
"utils.local_rank() environment variable LOCAL_RANK not set, defaulting to 0",
4 changes: 3 additions & 1 deletion tests/neox_args/test_neoxargs_commandline.py
@@ -63,6 +63,7 @@ def test_neoxargs_consume_deepy_args_without_yml_suffix():

assert args_loaded_yamls == args_loaded_consume


@pytest.mark.cpu
def test_neoxargs_consume_deepy_args_with_hostfile_param():
"""
@@ -77,7 +78,7 @@ def test_neoxargs_consume_deepy_args_with_hostfile_param():
"sys.argv",
[str(get_root_directory() / "deepy.py"), "train.py"]
+ get_configs_with_path(["small", "local_setup"])
+ ["--hostfile=/mock_path"]
+ ["--hostfile=/mock_path"],
):
args_loaded_consume = NeoXArgs.consume_deepy_args()

@@ -92,6 +93,7 @@ def test_neoxargs_consume_deepy_args_with_config_dir():

assert args_loaded_yamls == args_loaded_consume


@pytest.mark.cpu
def test_neoxargs_consume_deepy_args_with_config_dir():
"""