Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SequentialWrapper Generation (pipe_parallel_size = 0) #1031

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix SequentialGeneration
  • Loading branch information
xu-song committed Sep 15, 2023
commit 5098970b6ee6f39602590b159aea16e4f54c98d5
17 changes: 17 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@ def get_batch_pipe(data, neox_args, curr_scheduler=None):
return (tokens, position_ids, attention_mask), (labels, loss_mask)


def get_batch_sequential(forward_input, neox_args):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=forward_input[0],
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)
return (forward_input[0], forward_input[1], attention_mask)


def forward_step(
data_iterator, model, neox_args, timers, return_logits=False, is_train=False
):
Expand Down Expand Up @@ -653,6 +663,13 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
)
)
else:
model.module.set_batch_fn(
partial(
get_batch_sequential, neox_args=neox_args
)
)

else:
raise ValueError("Must be using deepspeed to run neox")

Expand Down
Loading