Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for DeepSpeed Ulysses (SP) #1084

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Next Next commit
Initial SP support
  • Loading branch information
Quentin-Anthony committed Nov 26, 2023
commit c509f6afd6c7cfb3e4693bd26943cf7d83b870b1
1 change: 1 addition & 0 deletions configs/125M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,
"sequence_parallel_size": 1,

# model settings
"num_layers": 12,
Expand Down
3 changes: 2 additions & 1 deletion eval_tasks/eval_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
self.is_model_parallel = neox_args.model_parallel_size > 1
self.is_pipe_parallel = self.model.is_pipe_parallel
self.is_data_parallel = self.model.is_data_parallel
self.is_sequence_parallel = self.model.is_sequence_parallel
self.is_last_stage = (
True if not self.is_pipe_parallel else model.is_last_stage()
) # only the last stage of the pipeline model will receive the logits
self.dp_world_size = mpu.get_data_parallel_world_size()
self.dp_rank = mpu.get_data_parallel_rank()
self.dp_group = mpu.get_data_parallel_group()
self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0
self.is_mp_rank_0 = (not self.is_sequence_parallel and mpu.get_tensor_parallel_rank() == 0) or (self.is_sequence_parallel and mpu.get_sequence_parallel_rank() == 0)

self._batch_size = batch_size or (
neox_args.batch_size * self.dp_world_size
Expand Down
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None
checkpoints_path,
directory,
"mp_rank_{:02d}".format(
mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
mpu.get_tensor_parallel_rank() if mp_rank is None else mp_rank
),
"model_optim_rng.pt",
)
Expand Down
2 changes: 1 addition & 1 deletion megatron/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def build_train_valid_test_data_iterators(neox_args):
pipe_load = True

# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0 and pipe_load:
if ((not neox_args.is_sequence_parallel and mpu.get_model_parallel_rank() == 0) or (neox_args.is_sequence_parallel and mpu.get_sequence_parallel_rank() == 0)) and pipe_load:
# Number of train/valid/test samples.
train_iters = neox_args.train_iters
eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
Expand Down
54 changes: 33 additions & 21 deletions megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,37 +154,49 @@ def _initialize_distributed(neox_args):
)

# Setup 3D topology.
pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
assert (
neox_args.world_size % (pp * mp) == 0
), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
dp = neox_args.world_size // (pp * mp)

from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

# this does pipe on the most outside, then data, then model.
# PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

# Offset base seeds for the interior pipeline stages.
# TODO: adjust last stage too once IO is improved.
stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
if 0 < stage_id < topo.get_dim("pipe") - 1:
offset = neox_args.seed + 1138
neox_args.seed = offset + (stage_id * mp)
#pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
#mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
#assert (
# neox_args.world_size % (pp * mp) == 0
#), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
#sp = neox_args.sequence_parallel_size if neox_args.sequence_parallel_size >= 1 else 1
#assert (
# (sp > 1 and (pp == 1 and mp == 1)) or (sp == 1)
#), f"sp={sp} cannot be used along with pp>1 or mp>1"
#if sp > 1:
# dp = neox_args.world_size // (sp)
#else:
# dp = neox_args.world_size // (pp * mp)
#
#from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology, ProcessTopology
#
#if sp > 1:
# topo = ProcessTopology(axes=['data', 'sequence'], dims=[dp, sp])
#else:
# # this does pipe on the most outside, then data, then model.
# # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
# topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)
#
# # Offset base seeds for the interior pipeline stages.
# # TODO: adjust last stage too once IO is improved.
# stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
# if 0 < stage_id < topo.get_dim("pipe") - 1:
# offset = neox_args.seed + 1138
# neox_args.seed = offset + (stage_id * mp)

# Set the model-parallel / data-parallel communicators.
if device_count > 0:
if mpu.model_parallel_is_initialized():
if mpu.tensor_parallel_is_initialized() or mpu.sequence_parallel_is_initialized:
print(
"_initialize_distributed() model parallel is already initialized",
flush=True,
)
else:
mpu.initialize_model_parallel(
neox_args.model_parallel_size,
topology=topo,
neox_args.sequence_parallel_size,
neox_args.pipeline_parallel_size,
#topology=topo,
fp32_allreduce=neox_args.fp32_allreduce,
)

Expand Down
15 changes: 13 additions & 2 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
self.attention_softmax_in_fp32 = True
self.layer_number = layer_number
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
world_size = mpu.get_tensor_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
self.hidden_size_per_attention_head = mpu.divide(
neox_args.hidden_size, neox_args.num_attention_heads
Expand Down Expand Up @@ -310,7 +310,7 @@ def __init__(
self.alibi_embed = AliBi(
neox_args.num_attention_heads,
neox_args.model_parallel_size,
mpu.get_model_parallel_rank(),
mpu.get_tensor_parallel_rank(),
)

# TODO: this arg shouldn't need to be passed in - get from neox_args
Expand Down Expand Up @@ -338,6 +338,7 @@ def __init__(

self.attention_type = neox_args.attention_config[layer_number]
self.use_flash_attention = self.attention_type == "flash"
self.use_ds_ulysses_attention = self.attention_type == "ulysses"
self.sparse = self.attention_type not in ("global", "flash")
if self.sparse:
self.sparse_attn = configure_sparse_attention(
Expand All @@ -360,6 +361,13 @@ def __init__(
self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func
self.flash_kv_fn = flash_attn_varlen_kvpacked_func
elif self.use_ds_ulysses_attention:
print('USING ULYSSES')
try:
from deepspeed.sequence.layer import DistributedAttention
self.ds_ulysses_attention_fn = DistributedAttention(self.attention, mpu.get_sequence_parallel_group())
except ImportError as e:
print(f'Error. You passed a gpt-neox ulysses config but DeepSpeed ulysses could not be imported with the error: {e}')
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
Expand Down Expand Up @@ -598,6 +606,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):

return matmul_result


def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
# TODO: sparse attn dropout?
# TODO: pad to block size
Expand Down Expand Up @@ -688,6 +697,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif self.use_ds_ulysses_attention:
context_layer = self.ds_ulysses_attention_fn(query_layer, key_layer, value_layer, attention_mask)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
Expand Down
2 changes: 1 addition & 1 deletion megatron/mpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .data import broadcast_data

from .initialize import is_unitialized
from .initialize import is_uninitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
Expand Down