Curriculum Learning Support (#695)
* Remove deprecated deepspeed.utils.distributed call

* Initial curriculum learning support

* Add is_train flag for curriculum learning

* Update NeoXArgs docs automatically

* add comment arg

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Add slurm stuff

* Update NeoXArgs docs automatically

* Allow json

* Update NeoXArgs docs automatically

* Apply curriculum learning seq_len to pipeline parallel data loading

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Actually updating the curriculum seq_len

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Actually updating the curriculum seq_len

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Actually updating the curriculum seq_len

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Actually updating the curriculum seq_len

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Iteration + 1

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Clean up comments and debug print statements

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Debug print again

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* more print statements

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Remove debug print statements

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Pre-commit

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Signed-off-by: Dashiell Stander <[email protected]>
Co-authored-by: Quentin TastyRice <[email protected]>
Co-authored-by: Dashiell Stander <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Dashiell Stander <[email protected]>
5 people committed Mar 9, 2023
1 parent 2b84f9a commit 68d223c
Showing 6 changed files with 111 additions and 7 deletions.
18 changes: 17 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = d49acf3
Default = cbed1b5

current git hash of repository

@@ -1676,6 +1676,22 @@ Args for deepspeed config



- **curriculum_learning**: dict

Default = None





- **curriculum_seqlen**: int

Default = 0

Internal var for tracking the current seqlen



- **steps_per_print**: int

Default = 10
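For reference, the new curriculum_learning dict is handed straight through to DeepSpeed, so a seqlen-based setup might look like the sketch below. The keys follow DeepSpeed's curriculum learning schema; the specific values are illustrative assumptions and not part of this commit.

# Illustrative curriculum_learning config (values are assumptions, not from this commit).
curriculum_learning = {
    "enabled": True,
    "curriculum_type": "seqlen",
    "min_difficulty": 64,
    "max_difficulty": 2048,  # typically the model's full seq_length
    "schedule_type": "fixed_linear",
    "schedule_config": {
        "total_curriculum_step": 10000,
        "difficulty_step": 8,
    },
}

With a config like this in place, curriculum_seqlen is tracked internally by the training code and is not meant to be set by hand.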
10 changes: 10 additions & 0 deletions megatron/logging.py
@@ -297,6 +297,16 @@ def add_to_logging(name):
1, neox_args.log_interval - total_loss_dict[skipped_iters_key]
)

# log curriculum learning
if neox_args.curriculum_learning:
tb_wandb_log(
"curriculum_seqlen",
neox_args.curriculum_seqlen,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
)

# log tflop / gpu
flops_per_s_per_gpu = get_flops(
neox_args=neox_args, model=model, iter_time_s=iteration_time
29 changes: 28 additions & 1 deletion megatron/model/utils.py
@@ -21,6 +21,7 @@
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.fused_softmax import SoftmaxFusionTypes
from types import GeneratorType
import torch.distributed as dist


def get_params_for_weight_decay_optimization(module, neox_args):
@@ -120,7 +121,33 @@ def train_mode(self):
"""
_set_use_cache(self.sequential, False)

def forward(self, forward_input):
def forward(
self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None
):

if (
curriculum_seqlen is not None
and isinstance(forward_input, tuple)
and len(forward_input) == 3
):
neox_args.update_value("curriculum_seqlen", curriculum_seqlen)
tokens = forward_input[0]
input_ids = forward_input[1]
attention_mask = forward_input[2]
if curriculum_seqlen < input_ids.size()[1]:
# seqlen-based curriculum learning
# input_ids, position_ids, labels have size [batch size, seqlen]
input_ids = input_ids[:, :curriculum_seqlen].contiguous()
tokens = tokens[:, :curriculum_seqlen].contiguous()
# position_ids = position_ids[:, :curriculum_seqlen].contiguous()
if labels is not None:
labels = labels[:, :curriculum_seqlen].contiguous()
# attention_mask has size [1, 1, seqlen, seqlen]
attention_mask = attention_mask[
:, :, :curriculum_seqlen, :curriculum_seqlen
].contiguous()
forward_input = (tokens, input_ids, attention_mask)

def exec_range_func(start, end):
"""Helper function to be used with checkpoint()
Adapted from torch.utils.checkpoint:checkpoint_sequential()
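The slicing done in the pipeline forward pass above boils down to the standalone sketch below (the function and variable names are made up for illustration): the [batch, seqlen] tensors are cut along the sequence dimension, and the [1, 1, seqlen, seqlen] attention mask is cut along both of its last two dimensions so it stays square.

import torch

def truncate_to_curriculum_seqlen(tokens, position_ids, attention_mask, curriculum_seqlen, labels=None):
    # Sketch only: apply seqlen-based curriculum truncation to one batch.
    if curriculum_seqlen < tokens.size(1):
        tokens = tokens[:, :curriculum_seqlen].contiguous()
        position_ids = position_ids[:, :curriculum_seqlen].contiguous()
        if labels is not None:
            labels = labels[:, :curriculum_seqlen].contiguous()
        # the causal mask must shrink on both sequence axes
        attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
    return tokens, position_ids, attention_mask, labels

# usage with dummy shapes
tokens = torch.randint(0, 50257, (4, 2048))
position_ids = torch.arange(2048).unsqueeze(0).expand(4, -1)
attention_mask = torch.tril(torch.ones(1, 1, 2048, 2048)).bool()
tokens, position_ids, attention_mask, _ = truncate_to_curriculum_seqlen(
    tokens, position_ids, attention_mask, curriculum_seqlen=256
)
assert tokens.shape == (4, 256) and attention_mask.shape == (1, 1, 256, 256)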
5 changes: 4 additions & 1 deletion megatron/neox_arguments/arguments.py
@@ -334,7 +334,10 @@ def consume_deepy_args(cls):
conf_files = [os.path.join(args_parsed.conf_dir, f) for f in conf_files]

# enables us to pass in `small` instead of `small.yml`
conf_files = [(cf if cf.endswith(".yml") else cf + ".yml") for cf in conf_files]
conf_files = [
(cf if cf.endswith(".yml") or cf.endswith(".json") else cf + ".yml")
for cf in conf_files
]

# determine overwrite values
overwrite_values = dict()
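The relaxed extension check means JSON configs now pass through untouched while bare names still get a .yml suffix; a quick standalone illustration (the file names are made up):

conf_files = ["small", "local_setup.yml", "curriculum.json"]
conf_files = [
    (cf if cf.endswith(".yml") or cf.endswith(".json") else cf + ".yml")
    for cf in conf_files
]
print(conf_files)  # ['small.yml', 'local_setup.yml', 'curriculum.json']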
8 changes: 8 additions & 0 deletions megatron/neox_arguments/deepspeed_args.py
@@ -102,6 +102,14 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate):
zero_optimization: dict = None
""""""

curriculum_learning: dict = None
""""""

curriculum_seqlen: int = 0
"""
Internal var for tracking the current seqlen
"""

steps_per_print: int = 10
"""
Print train loss every N steps.
48 changes: 44 additions & 4 deletions megatron/training.py
@@ -27,6 +27,7 @@

import torch
import deepspeed
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
import numpy as np

from megatron.utils import (
@@ -301,7 +302,7 @@ def get_batch(neox_args, data_iterator):
)


def get_batch_pipe(data, neox_args):
def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
# Items and their type.
keys = ["text"]
@@ -310,12 +311,31 @@ def get_batch_pipe(data, neox_args):
tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
neox_args, neox_args.tokenizer, keys, data, datatype
)
if curr_scheduler is not None:
# iteration + 1 to align with how/when DeepSpeed updates the buffers
curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
if curriculum_seqlen < tokens.size()[1]:
# seqlen-based curriculum learning
# input_ids, position_ids, labels have size [batch size, seqlen]
# input_ids = input_ids[:, :curriculum_seqlen].contiguous()
tokens = tokens[:, :curriculum_seqlen].contiguous()
position_ids = position_ids[:, :curriculum_seqlen].contiguous()
if labels is not None:
labels = labels[:, :curriculum_seqlen].contiguous()
if loss_mask is not None:
loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
# attention_mask has size [1, 1, seqlen, seqlen]
attention_mask = attention_mask[
:, :, :curriculum_seqlen, :curriculum_seqlen
].contiguous()

# unpack data
return (tokens, position_ids, attention_mask), (labels, loss_mask)


def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
def forward_step(
data_iterator, model, neox_args, timers, return_logits=False, is_train=False
):
"""Forward step."""
if neox_args.is_pipe_parallel:
return model.eval_batch(data_iterator, return_logits=return_logits)
Expand All @@ -326,10 +346,18 @@ def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)

if timers is not None:
timers("batch generator").stop()

outputs = model((tokens, position_ids, attention_mask))
outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if (
is_train
and neox_args.curriculum_learning
and neox_args.curriculum_seqlen < neox_args.seq_length
):
loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
loss = cross_entropy(
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
)
@@ -589,7 +617,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):

if neox_args.is_pipe_parallel:
model.set_has_attention_mask(True)
model.set_batch_fn(partial(get_batch_pipe, neox_args=neox_args))
if neox_args.curriculum_learning:
curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
if iteration is not None and iteration > 0:
curr_scheduler.update_difficulty(iteration)
else:
curr_scheduler = None
model.set_batch_fn(
partial(
get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
)
)
else:
raise ValueError("Must be using deepspeed to run neox")

@@ -647,6 +685,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
timers=timers,
data_iterator=data_iterator,
model=model,
is_train=True,
)
timers("forward").stop()
losses.append(loss)
@@ -736,6 +775,7 @@ def train(
lr_scheduler=lr_scheduler,
)
iteration += 1
neox_args.iteration = iteration

overflow_monitor.check(skipped_iter) # check for repeated overflow
if neox_args.log_gradient_noise_scale: # log noise scale if applicable
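Taken together, the training.py changes wire a DeepSpeed CurriculumScheduler into the pipeline batch function and keep the rest of the loop in sync. A condensed sketch of that flow outside the real trainer follows; the config and loop here are simplified assumptions, reusing the illustrative dict from the neox_arguments.md notes above.

from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

# Assumed seqlen-based config; keys follow DeepSpeed's curriculum learning schema.
curriculum_config = {
    "enabled": True,
    "curriculum_type": "seqlen",
    "min_difficulty": 64,
    "max_difficulty": 2048,
    "schedule_type": "fixed_linear",
    "schedule_config": {"total_curriculum_step": 10000, "difficulty_step": 8},
}

curr_scheduler = CurriculumScheduler(curriculum_config)

total_iters = 10000  # stand-in for neox_args.train_iters
for iteration in range(total_iters):
    # get_batch_pipe queries the schedule at iteration + 1 to stay aligned with
    # how and when DeepSpeed updates its internal buffers.
    curriculum_seqlen = curr_scheduler.update_difficulty(iteration + 1)
    if iteration % 1000 == 0:
        print(f"iteration {iteration}: curriculum_seqlen = {curriculum_seqlen}")
    # ...get_batch_pipe then truncates tokens, position_ids, labels, loss_mask,
    # and the attention mask to curriculum_seqlen before the forward pass, and
    # forward_step(is_train=True) truncates loss_mask/labels the same way
    # whenever curriculum_seqlen < neox_args.seq_length...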
