MoE Support #677

Draft — wants to merge 7 commits into base: main
Changes from 1 commit
Formatting
Quentin-Anthony committed Sep 19, 2022
commit 919899e3f291a965013e8ca976242e76c8496ae7
2 changes: 1 addition & 1 deletion megatron/data/data_utils.py
@@ -250,7 +250,7 @@ def weights_by_num_docs(l, alpha=0.3):
total_n_docs = sum(l)
unbiased_sample_probs = [i / total_n_docs for i in l]

- probs = [i ** alpha for i in unbiased_sample_probs]
+ probs = [i**alpha for i in unbiased_sample_probs]

# normalize
total = sum(probs)
2 changes: 1 addition & 1 deletion megatron/model/gmlp.py
@@ -14,7 +14,7 @@ class TinyAttention(nn.Module):
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
super().__init__()
self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
- self.scale = d_attn ** -0.5
+ self.scale = d_attn**-0.5
self.proj_ffn = nn.Linear(d_attn, d_ff)
self.softmax = FusedScaleMaskSoftmax(
input_in_fp16=neox_args.precision == "fp16",
14 changes: 10 additions & 4 deletions megatron/model/gpt2_model.py
@@ -223,10 +223,16 @@ def init_specs(self):
heads=self.neox_args.num_attention_heads,
)

- assert len(self.num_experts) == 1 or len(self.num_experts) == self.neox_args.num_layers // self.neox_args.expert_interval, 'num_layers must be divisible by pipeline_model_parallel_size'
+ assert (
+     len(self.num_experts) == 1
+     or len(self.num_experts)
+     == self.neox_args.num_layers // self.neox_args.expert_interval
+ ), "num_layers must be divisible by pipeline_model_parallel_size"

if len(self.num_experts) == 1:
-     self.num_experts = self.num_experts * (self.neox_args.num_layers // self.neox_args.expert_interval)
+     self.num_experts = self.num_experts * (
+         self.neox_args.num_layers // self.neox_args.expert_interval
+     )

# Transformer layers
for i in range(self.neox_args.num_layers):
@@ -244,7 +250,7 @@ def init_specs(self):
)
else:
if i % self.neox_args.expert_interval == 0:
- n_e = self.num_experts[(i-1) // self.neox_args.expert_interval]
+ n_e = self.num_experts[(i - 1) // self.neox_args.expert_interval]
else:
n_e = 1
self.specs.append(
@@ -258,7 +264,7 @@ def init_specs(self):
rpe=rpe_emb if self.neox_args.pos_emb == "rpe" else None,
rotary=self.neox_args.pos_emb == "rotary",
use_cache=self.use_cache,
- num_experts=n_e
+ num_experts=n_e,
)
)

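For context on the gpt2_model.py hunks above: a single-element num_experts list is broadcast to every expert layer, and each transformer layer then receives either an expert count n_e or 1 (dense). A minimal standalone sketch of that selection logic — the function name and example values are illustrative, not part of this PR:

    def experts_per_layer(num_experts, num_layers, expert_interval):
        # One entry is broadcast to every expert layer; otherwise the list must
        # supply one entry per expert layer (num_layers // expert_interval of them).
        assert len(num_experts) in (1, num_layers // expert_interval)
        if len(num_experts) == 1:
            num_experts = num_experts * (num_layers // expert_interval)
        counts = []
        for i in range(num_layers):
            if i % expert_interval == 0:
                counts.append(num_experts[(i - 1) // expert_interval])  # MoE layer
            else:
                counts.append(1)  # dense (non-expert) layer
        return counts

    # e.g. experts_per_layer([8], num_layers=4, expert_interval=2) -> [8, 1, 8, 1]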
2 changes: 1 addition & 1 deletion megatron/model/positional_embeddings.py
@@ -105,7 +105,7 @@ def _get_slopes(self, n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
- return [start * ratio ** i for i in range(n)]
+ return [start * ratio**i for i in range(n)]

if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
44 changes: 37 additions & 7 deletions megatron/model/transformer.py
@@ -81,7 +81,13 @@ class ParallelMLP(nn.Module):
"""

def __init__(
- self, neox_args, init_method, output_layer_init_method, parallel_output=False, MOE=False, MoE_mp_size=1
+ self,
+ neox_args,
+ init_method,
+ output_layer_init_method,
+ parallel_output=False,
+ MOE=False,
+ MoE_mp_size=1,
):
super().__init__()

@@ -104,7 +110,7 @@ def __init__(
init_method=init_method,
skip_bias_add=True,
MOE=MOE,
- MoE_mp_size=MoE_mp_size
+ MoE_mp_size=MoE_mp_size,
)
ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
# Project back to h.
@@ -586,8 +592,24 @@ def __init__(
moe_mp_size = 1
else:
moe_mp_size = dist.get_world_size() // self.num_experts

- self.mlp = MoE(neox_args.hidden_size, ParallelMLP(init_method, output_layer_init_method=output_layer_init_method, MOE=True, MoE_mp_size=moe_mp_size), num_experts=self.num_experts, ep_size=neox_args.moe_expert_parallel_size, k=neox_args.topk, use_residual=(neox_args.mlp_type == 'residual', capacity_factor=neox_args.moe_train_capacity_factor, eval_capacity_factor=neox_args.moe_eval_capacity_factor, min_capacity=neox_args.moe_min_capacity, drop_tokens=neox_args.moe_token_dropping))

+ self.mlp = MoE(
+     neox_args.hidden_size,
+     ParallelMLP(
+         init_method,
+         output_layer_init_method=output_layer_init_method,
+         MOE=True,
+         MoE_mp_size=moe_mp_size,
+     ),
+     num_experts=self.num_experts,
+     ep_size=neox_args.moe_expert_parallel_size,
+     k=neox_args.topk,
+     use_residual=(neox_args.mlp_type == "residual"),
+     capacity_factor=neox_args.moe_train_capacity_factor,
+     eval_capacity_factor=neox_args.moe_eval_capacity_factor,
+     min_capacity=neox_args.moe_min_capacity,
+     drop_tokens=neox_args.moe_token_dropping,
+ )

self.layer_past = None # used to cache k/v pairs in inference

@@ -630,9 +652,17 @@ def forward(self, x, attention_mask, layer_past=None):
)

# output = mlp(ln2(x)) + attention_output
- #mlp_output, mlp_bias = self.mlp(self.post_attention_layernorm(x))
- moe_loss = torch.tensor(0.0, device=self.post_attention_layernorm(x).device, dtype=self.post_attention_layernorm(x).dtype)
- mlp_bias = torch.tensor(0.0, device=self.post_attention_layernorm(x).device, dtype=self.post_attention_layernorm(x).dtype)
+ # mlp_output, mlp_bias = self.mlp(self.post_attention_layernorm(x))
+ moe_loss = torch.tensor(
+     0.0,
+     device=self.post_attention_layernorm(x).device,
+     dtype=self.post_attention_layernorm(x).dtype,
+ )
+ mlp_bias = torch.tensor(
+     0.0,
+     device=self.post_attention_layernorm(x).device,
+     dtype=self.post_attention_layernorm(x).dtype,
+ )

if self.num_experts == 1:
mlp_output, mlp_bias = self.mlp(self.post_attention_layernorm(x))
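For context on the ParallelTransformerLayer changes above: when a layer has more than one expert, its MLP is wrapped in DeepSpeed's MoE layer, whose router contributes an auxiliary load-balancing loss. A rough sketch of driving such a wrapper on its own, using the keyword names from the hunk above; it assumes deepspeed is installed, may additionally require torch.distributed / deepspeed.init_distributed() depending on the DeepSpeed version, and the exact forward return signature should be checked against that version:

    import torch
    from deepspeed.moe.layer import MoE

    hidden_size = 1024

    # Any module mapping hidden states to hidden states can act as the expert;
    # in this PR it is ParallelMLP constructed with MOE=True.
    expert = torch.nn.Sequential(
        torch.nn.Linear(hidden_size, 4 * hidden_size),
        torch.nn.GELU(),
        torch.nn.Linear(4 * hidden_size, hidden_size),
    )

    moe_layer = MoE(
        hidden_size,
        expert,
        num_experts=8,
        ep_size=1,           # expert-parallel group size
        k=1,                 # top-k routing
        use_residual=False,  # True corresponds to mlp_type == "residual"
        capacity_factor=1.0,
        eval_capacity_factor=1.0,
        min_capacity=4,
        drop_tokens=True,
    )

    x = torch.randn(2, 16, hidden_size)
    # Recent DeepSpeed versions return (output, aux_loss, expert_counts); the
    # auxiliary loss is what the transformer layer tracks as moe_loss.
    output, aux_loss, _ = moe_layer(x)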
4 changes: 2 additions & 2 deletions megatron/neox_arguments/arguments.py
@@ -290,8 +290,8 @@ def consume_deepy_args(cls):
type=str,
default=DLTS_HOSTFILE,
help="Hostfile path (in MPI style) that defines the "
"resource pool available to the job (e.g., "
"worker-0 slots=4)"
"resource pool available to the job (e.g., "
"worker-0 slots=4)",
)
group = parser.add_argument_group(title="Generation args")
group.add_argument(
8 changes: 5 additions & 3 deletions megatron/neox_arguments/neox_args.py
@@ -49,7 +49,9 @@ class NeoXArgsParallelism(NeoXArgsTemplate):
Size of the model parallelism.
"""

- num_experts: list = [1,]
+ num_experts: list = [
+     1,
+ ]
"""
Degree of MoE expert parallelism
"""
@@ -350,9 +352,9 @@ class NeoXArgsModel(NeoXArgsTemplate):
Use experts in every "expert-interval" layers
"""

- mlp_type: str = 'standard'
+ mlp_type: str = "standard"
"""
- Only applicable when num-experts > 1, accepts [standard, residual]
+ Only applicable when num-experts > 1, accepts [standard, residual]
"""


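The neox_args.py hunks above are formatting-only, but they touch the configuration fields this PR introduces. A purely illustrative summary of those knobs as a Python dict — field names are taken from the dataclasses and call sites in this diff, values are made up, and the spelling used in actual YAML configs should be checked against neox_arguments:

    moe_settings = {
        "num_experts": [8],              # one entry is broadcast to every expert layer
        "expert_interval": 2,            # use experts in every "expert-interval"-th layer
        "mlp_type": "residual",          # only meaningful when num_experts > 1
        "moe_expert_parallel_size": 1,   # ep_size passed to DeepSpeed's MoE layer
        "topk": 1,                       # router top-k
        "moe_train_capacity_factor": 1.0,
        "moe_eval_capacity_factor": 1.0,
        "moe_min_capacity": 4,
        "moe_token_dropping": True,
        "create_moe_param_group": True,  # see the training.py hunk below
    }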
3 changes: 2 additions & 1 deletion megatron/text_generation_utils.py
@@ -767,7 +767,8 @@ def generate_samples_interactive(
.tolist()[
batch_token_generation_start_index[0]
.item() : batch_token_generation_end_index[0]
- .item() + 1
+ .item()
+ + 1
]
)
generated_text = neox_args.tokenizer.detokenize(generated_tokens)
4 changes: 2 additions & 2 deletions megatron/tokenizer/gpt2_tokenization.py
@@ -65,10 +65,10 @@ def bytes_to_unicode():
)
cs = bs[:]
n = 0
- for b in range(2 ** 8):
+ for b in range(2**8):
if b not in bs:
bs.append(b)
- cs.append(2 ** 8 + n)
+ cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
10 changes: 8 additions & 2 deletions megatron/training.py
@@ -279,8 +279,14 @@ def get_optimizer(model, neox_args):
# Build parameter groups (weight decay and non-decay).
param_groups = get_params_for_weight_decay_optimization(model, neox_args)
if neox_args.create_moe_param_group:
- from deepspeed.moe.utils import is_moe_param, split_params_into_different_moe_groups_for_optimizer
- param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)
+ from deepspeed.moe.utils import (
+     is_moe_param,
+     split_params_into_different_moe_groups_for_optimizer,
+ )
+
+ param_groups = split_params_into_different_moe_groups_for_optimizer(
+     param_groups
+ )
print_rank_0(
f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
)
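For context on the training.py hunk above: DeepSpeed expects expert parameters to live in their own optimizer parameter groups so that expert parallelism and ZeRO can treat them separately from dense parameters. A hedged sketch of what that call does against a plain torch model; the helper is the one imported above, but its exact signature and the extra keys it adds to each group should be checked against the installed DeepSpeed version:

    import torch
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    model = torch.nn.Linear(16, 16)  # stand-in; NeoX passes the full model's groups
    param_groups = [
        {"params": list(model.parameters()), "weight_decay": 0.01, "name": "dense"}
    ]

    # Parameters flagged by DeepSpeed's MoE layers are split out into separate
    # groups; dense-only models pass through essentially unchanged.
    param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)

    optimizer = torch.optim.AdamW(param_groups)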