Fix SequentialWrapper Generation (pipe_parallel_size = 0) (#1031)
* Fix SequentialGeneration

* Fix SequentialGeneration
xu-song committed Sep 18, 2023
1 parent fcd5f92 commit 70af6e8
Showing 2 changed files with 29 additions and 0 deletions.
12 changes: 12 additions & 0 deletions megatron/model/utils.py
@@ -97,6 +97,7 @@ def __init__(
         self.activation_checkpoint_interval = activation_checkpoint_interval
         self.parent_class_name = parent_class_name
         self.activation_checkpoint_func = activation_checkpoint_func
+        self.batch_fn = None
 
     def _is_checkpointable(self, funcs):
         if self.parent_class_name == "GPT2ModelPipe":
@@ -106,6 +107,14 @@ def _is_checkpointable(self, funcs):
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
 
+    def set_batch_fn(self, fn):
+        """Execute a post-processing function on input data.
+
+        Args:
+            fn (function): The function to run.
+        """
+        self.batch_fn = fn
+
     def inference_mode(self, use_cache=True):
         """
         Sets up the model for inference by turning on k/v caching (if specified) and setting `parallel output` of the final layer to false,
@@ -127,6 +136,9 @@ def forward(
         self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None
     ):
 
+        if self.batch_fn:
+            forward_input = self.batch_fn(forward_input)
+
         if (
             curriculum_seqlen is not None
             and isinstance(forward_input, tuple)
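
The three hunks above give SequentialWrapper an optional batch_fn hook that rewrites the forward input before it reaches the layers. Below is a minimal, standalone sketch of that pattern, assuming nothing from the gpt-neox codebase; ToyWrapper and the lambda hook are illustrative names, not project code.

import torch


class ToyWrapper(torch.nn.Module):
    """Illustrative stand-in for a sequential wrapper with a batch_fn hook."""

    def __init__(self, layers):
        super().__init__()
        self.sequential = torch.nn.Sequential(*layers)
        self.batch_fn = None  # no hook installed by default

    def set_batch_fn(self, fn):
        """Install a post-processing function for the raw forward input."""
        self.batch_fn = fn

    def forward(self, forward_input):
        # Same guard as the patched SequentialWrapper.forward: apply the hook first.
        if self.batch_fn:
            forward_input = self.batch_fn(forward_input)
        return self.sequential(forward_input)


model = ToyWrapper([torch.nn.Linear(4, 4), torch.nn.ReLU()])
model.set_batch_fn(lambda x: x / x.norm())  # e.g. rescale the batch before the layers see it
out = model(torch.ones(1, 4))
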
17 changes: 17 additions & 0 deletions megatron/training.py
@@ -351,6 +351,16 @@ def get_batch_pipe(data, neox_args, curr_scheduler=None):
     return (tokens, position_ids, attention_mask), (labels, loss_mask)
 
 
+def get_batch_sequential(forward_input, neox_args):
+    """A modification of get_batch() to work with the latest batch instead of an iterator."""
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+        data=forward_input[0],
+        eod_token=neox_args.tokenizer.eod,
+        eod_mask_loss=neox_args.eod_mask_loss,
+    )
+    return (forward_input[0], forward_input[1], attention_mask)
+
+
 def forward_step(
     data_iterator, model, neox_args, timers, return_logits=False, is_train=False
 ):
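
get_batch_sequential takes a forward_input tuple whose first two entries are the token batch and its position ids, and attaches a left-to-right (causal) attention mask derived from the tokens. A simplified, self-contained stand-in for that mask step is sketched below; toy_ltor_mask is an assumption-level illustration, not the real get_ltor_masks_and_position_ids.

import torch


def toy_ltor_mask(tokens: torch.Tensor) -> torch.Tensor:
    """Lower-triangular mask: position i may only attend to positions <= i."""
    _, seq_len = tokens.shape
    return torch.tril(torch.ones(1, seq_len, seq_len, dtype=torch.bool))


tokens = torch.randint(0, 1000, (2, 8))        # [batch, seq]
position_ids = torch.arange(8).repeat(2, 1)    # [batch, seq]
forward_input = (tokens, position_ids)
# Mirror the return shape of get_batch_sequential: (tokens, position_ids, attention_mask).
rebuilt = (forward_input[0], forward_input[1], toy_ltor_mask(forward_input[0]))
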
@@ -653,6 +663,13 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
                     get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
                 )
             )
+        else:
+            model.module.set_batch_fn(
+                partial(
+                    get_batch_sequential, neox_args=neox_args
+                )
+            )
+
     else:
         raise ValueError("Must be using deepspeed to run neox")
 
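
The partial() call above freezes neox_args into the hook, because SequentialWrapper.forward only passes batch_fn the forward input itself. A minimal sketch of that binding follows, with a placeholder dict standing in for the real NeoXArgs object; toy_get_batch_sequential only mimics the return shape.

from functools import partial


def toy_get_batch_sequential(forward_input, neox_args):
    # Mimic the real return shape: (tokens, position_ids, attention_mask).
    attention_mask = f"ltor_mask(eod_token={neox_args['eod']})"
    return (forward_input[0], forward_input[1], attention_mask)


fake_args = {"eod": 0}  # placeholder for NeoXArgs
batch_fn = partial(toy_get_batch_sequential, neox_args=fake_args)

# At generation time the wrapper only supplies the batch; neox_args is already bound.
print(batch_fn(("tokens", "position_ids")))
# ('tokens', 'position_ids', 'ltor_mask(eod_token=0)')
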
