Add MoE matmul implementation
RissyRan committed May 30, 2024
1 parent 48f2524 commit 851d048
Showing 9 changed files with 318 additions and 97 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -80,6 +80,7 @@ head_dim: 128
# mixture of experts (moe)
num_experts: 1
num_experts_per_tok: 1
moe_matmul: False
mlp_activations: ["silu", "linear"]
dropout_rate: 0
logits_via_embedding: False
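Usage note (illustrative, not part of this commit): the new key defaults to False, and the decoder only consults it when num_experts > 1 (see mistral.py below). Since MaxText accepts key=value overrides of base.yml at launch, the matmul path can be enabled per run with something like

python3 MaxText/train.py MaxText/configs/base.yml run_name=<your_run> moe_matmul=True

where run_name is a placeholder for whatever run configuration is in use; with moe_matmul left at False the existing for-loop path is unchanged.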
108 changes: 75 additions & 33 deletions MaxText/layers/linears.py
@@ -265,6 +265,7 @@ class MoeBlock(nn.Module):
num_experts_per_tok: Number of experts for each token.
kernel_init: Kernel function, passed to the dense layers.
kernel_axes: Tuple with axes to apply kernel function.
weight_dtype: Type for the weights.
dtype: Type for the dense layer.
"""

@@ -273,40 +274,81 @@ class MoeBlock(nn.Module):
num_experts_per_tok: int
kernel_init: NdInitializer
kernel_axes: Tuple[str, ...]
weight_dtype: DType = jnp.float32
dtype: DType = jnp.float32

def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):

kernel_in_axis = np.arange(1)
kernel_out_axis = np.arange(1, 2)
kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal')

kernel_axes = ('exp', 'embed', 'mlp')
wo_kernel_axes = ('exp', 'mlp', 'embed')

w0_kernel = self.param(
'wi_0',
nn.with_logical_partitioning(kernel_init, kernel_axes),
(num_experts, base_emb_dim, mlp_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
)
w0_kernel = jnp.asarray(w0_kernel, self.dtype)
w1_kernel = self.param(
'wi_1',
nn.with_logical_partitioning(kernel_init, kernel_axes),
(num_experts, base_emb_dim, mlp_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
)
w1_kernel = jnp.asarray(w1_kernel, self.dtype)
wo_kernel = self.param(
'wo',
nn.with_logical_partitioning(kernel_init, wo_kernel_axes),
(num_experts, mlp_dim, base_emb_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
)
wo_kernel = jnp.asarray(wo_kernel, self.dtype)
return w0_kernel, w1_kernel, wo_kernel

@nn.compact
def __call__(self, inputs, deterministic: bool = False):
  gate_logits = DenseGeneral(
      self.num_experts,
      dtype=self.dtype,
      kernel_init=self.kernel_init,
      kernel_axes=self.kernel_axes,
      name="gate",
      quant=self.quant,
  )(inputs)

  weights, selected_experts = lax.top_k(gate_logits, self.num_experts_per_tok)
  weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1)
  mlp_lnx = jnp.zeros_like(inputs)
  weights = weights.astype(self.dtype)
  mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

  # TODO(ranran): have a better solution to remove the loop here
  for k in range(self.num_experts):
    weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1)
    mlp_lnx_exp = MlpBlock(
        intermediate_dim=self.config.mlp_dim,
        activations=self.config.mlp_activations,
        intermediate_dropout_rate=self.config.dropout_rate,
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        name=f"mlp_{k}",
        config=self.config,
    )(inputs, deterministic=deterministic)

    mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed"))
    mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp
    mlp_lnx += mlp_lnx_exp

  return mlp_lnx

def __call__(self, inputs):
  cfg = self.config
  inputs = inputs.astype(cfg.dtype)
  gate_logits = DenseGeneral(
      self.num_experts,
      dtype=self.dtype,
      kernel_init=self.kernel_init,
      kernel_axes=self.kernel_axes,
      name="gate")(inputs)

  top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
  flattened_top_k_weights = top_k_weights.reshape(-1, self.num_experts_per_tok)

  softmax_probs = jax.nn.softmax(flattened_top_k_weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype)
  softmax_probs = softmax_probs.reshape(gate_logits.shape[:-1] + (self.num_experts_per_tok,))

  weights = jnp.zeros_like(gate_logits)
  index_update = (jnp.arange(gate_logits.shape[0])[:, None, None], jnp.arange(gate_logits.shape[1])[:, None], top_k_indices)
  weights = weights.at[index_update].set(softmax_probs)

  w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts,
                                                          cfg.base_emb_dim,
                                                          cfg.mlp_dim)

  with jax.named_scope("wi_0"):
    layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel)
  with jax.named_scope("wi_1"):
    layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
  layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0)
  layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
  with jax.named_scope("wo"):
    intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel)
  with jax.named_scope("w_sum"):
    output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)
  return output
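
To make the new routing-and-combine math concrete, here is a minimal, self-contained JAX sketch of the same steps on dummy shapes: scatter the softmaxed top-k gate scores back into a dense (batch, length, num_experts) weight tensor, run every expert on every token with batched einsums, and contract the expert axis against the weights. The randomly initialized kernels and the hard-coded silu (standing in for cfg.mlp_activations[0]) are illustrative stand-ins, not the Flax module above.

import jax
import jax.numpy as jnp

B, L, E, N, H, K = 2, 4, 8, 4, 16, 2   # batch, length, embed, num_experts, mlp, top-k
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)

inputs = jax.random.normal(k1, (B, L, E))
gate_logits = jax.random.normal(k2, (B, L, N))        # stand-in for the "gate" DenseGeneral output
w0_kernel = jax.random.normal(k3, (N, E, H)) * 0.02   # wi_0
w1_kernel = jax.random.normal(k4, (N, E, H)) * 0.02   # wi_1
wo_kernel = jax.random.normal(k5, (N, H, E)) * 0.02   # wo

# Route: keep the top-k gate scores per token, softmax them, and scatter them back
# into a dense (B, L, N) tensor that is zero for unselected experts.
top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, K)
softmax_probs = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1)
weights = jnp.zeros_like(gate_logits)
index_update = (jnp.arange(B)[:, None, None], jnp.arange(L)[:, None], top_k_indices)
weights = weights.at[index_update].set(softmax_probs)

# Combine: every expert processes every token; the dense weights mask and mix the results.
layer_w0 = jax.nn.silu(jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel))
layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_w0 * layer_w1, wo_kernel)
output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)
assert output.shape == (B, L, E)

Like the per-expert loop it replaces, every expert still sees every token; the difference is that the weighting and accumulation are expressed as a few batched einsums rather than a Python loop over experts.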
83 changes: 44 additions & 39 deletions MaxText/layers/mistral.py
@@ -32,6 +32,7 @@
from layers import normalizations
from layers import models
import common_types
import max_logging

Array = common_types.Array
Config = common_types.Config
@@ -122,47 +123,51 @@ def __call__(
hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))

if cfg.num_experts > 1:
  # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR.

  # mlp_lnx = linears.MoeBlock(
  #     config=cfg,
  #     num_experts=cfg.num_experts,
  #     num_experts_per_tok=cfg.num_experts_per_tok,
  #     kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
  #     kernel_axes=('embed', 'mlp'),
  #     dtype=cfg.dtype,
  # )(hidden_states, deterministic=deterministic)

  gate_logits = linears.DenseGeneral(
      cfg.num_experts,
      weight_dtype=cfg.weight_dtype,
      dtype=cfg.dtype,
      kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
      kernel_axes=("embed", "mlp"),
      name="gate",
      quant=self.quant,
  )(hidden_states)
  weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok)
  weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1)
  mlp_lnx = jnp.zeros_like(hidden_states)
  weights = weights.astype(cfg.dtype)
  mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

  # TODO(ranran): have a better solution to remove the loop here
  for k in range(cfg.num_experts):
    weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1)
    mlp_lnx_exp = linears.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name=f"mlp_{k}",
        config=cfg,
    )(hidden_states, deterministic=deterministic)
    mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed"))
    mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp
    mlp_lnx += mlp_lnx_exp

  # TODO(ranran): remove for loop implementation after adding expert parallelism
  if cfg.moe_matmul:
    max_logging.log("Running MoE matmul implementation.")
    mlp_lnx = linears.MoeBlock(
        config=cfg,
        num_experts=cfg.num_experts,
        num_experts_per_tok=cfg.num_experts_per_tok,
        kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
        kernel_axes=('embed', 'mlp'),
        dtype=cfg.dtype,
    )(hidden_states)
    mlp_lnx = nn.with_logical_constraint(
        mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed')
    )
  else:
    max_logging.log("Running MoE for loop implementation.")
    gate_logits = linears.DenseGeneral(
        cfg.num_experts,
        weight_dtype=cfg.weight_dtype,
        dtype=cfg.dtype,
        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", "mlp"),
        name="gate",
        quant=self.quant,
    )(hidden_states)
    weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok)
    weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1)
    mlp_lnx = jnp.zeros_like(hidden_states)
    weights = weights.astype(cfg.dtype)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

    for k in range(cfg.num_experts):
      weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1)
      mlp_lnx_exp = linears.MlpBlock(
          intermediate_dim=cfg.mlp_dim,
          activations=cfg.mlp_activations,
          intermediate_dropout_rate=cfg.dropout_rate,
          dtype=cfg.dtype,
          weight_dtype=cfg.weight_dtype,
          name=f"mlp_{k}",
          config=cfg,
      )(hidden_states, deterministic=deterministic)
      mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed"))
      mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp
      mlp_lnx += mlp_lnx_exp
else:
mlp_lnx = linears.MlpBlock(
intermediate_dim=cfg.mlp_dim,
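For intuition on why the two branches above should produce the same result, here is a minimal routing-weight sketch (plain JAX on toy shapes, not MaxText code): the scalar weight the for loop builds for each expert and the dense tensor MoeBlock scatters hold exactly the same softmaxed top-k probabilities, just laid out differently.

import jax
import jax.numpy as jnp

B, L, N, K = 2, 3, 8, 2   # batch, length, num_experts, num_experts_per_tok
gate_logits = jax.random.normal(jax.random.PRNGKey(0), (B, L, N))

top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, K)
probs = jax.nn.softmax(top_k_weights, axis=-1)

# For-loop branch: one (B, L) scalar weight per expert, built expert by expert.
loop_weights = jnp.stack(
    [jnp.sum(jnp.where(top_k_indices == k, probs, 0.0), axis=-1) for k in range(N)],
    axis=-1,
)

# Matmul branch: scatter the same probabilities into one dense (B, L, N) tensor.
dense_weights = jnp.zeros_like(gate_logits).at[
    jnp.arange(B)[:, None, None], jnp.arange(L)[:, None], top_k_indices
].set(probs)

assert jnp.allclose(loop_weights, dense_weights)

(The production code additionally casts through float32, weight_dtype, and cfg.dtype at slightly different points, so the two branches may differ by dtype rounding.)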
46 changes: 36 additions & 10 deletions MaxText/llama_or_mistral_ckpt.py
@@ -97,7 +97,7 @@ def permute_to_match_maxtext_rope(arr):
SIMULATED_CPU_DEVICES_COUNT = 16


def convert(base_model_path, maxtext_model_path, model_size):
def convert(base_model_path, maxtext_model_path, model_size, moe_matmul):
"""
Function to convert the checkpoint at base_model_path into Orbax checkpoint
for MaxText and save at maxtext_model_path
@@ -106,6 +106,7 @@ def convert(base_model_path, maxtext_model_path, model_size):
base_model_path: checkpoint path
maxtext_model_path: Path to save the MaxText checkpoint to
model_size: llama2-7b to 70b, mistral-7b, or mixtral-8x7b
moe_matmul: Whether to run the MoE block through the matmul (einsum) implementation instead of the per-expert for loop
"""
"""Convert model to maxtext."""
model_params = MODEL_PARAMS_DICT[model_size]
@@ -126,7 +127,10 @@ def convert(base_model_path, maxtext_model_path, model_size):
pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
pytorch_vars = [pytorch_vars[i] for i in sorted(list(pytorch_vars.keys()))]

layer_key = "gate" if num_experts else "mlp"
if num_experts:
layer_key = "MoeBlock_0" if moe_matmul else "gate"
else:
layer_key = "mlp"
jax_weights = {
"decoder": {
"layers": {
@@ -161,7 +165,10 @@ def convert(base_model_path, maxtext_model_path, model_size):
layer_weight["gate"] = {"kernel": []}

for k in range(num_experts):
jax_weights["decoder"]["layers"][f"mlp_{k}"] = {}
if moe_matmul:
jax_weights["decoder"]["layers"]["MoeBlock_0"]["gate"] = {}
else:
jax_weights["decoder"]["layers"][f"mlp_{k}"] = {}
layer_weight[f"mlp_{k}"] = {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
@@ -294,17 +301,35 @@ def convert(base_model_path, maxtext_model_path, model_size):
else:
layer_weight["gate"]["kernel"] = np.array(layer_weight["gate"]["kernel"])
layer_weight["gate"]["kernel"] = np.transpose(layer_weight["gate"]["kernel"], axes=(1, 0, 2))
jax_weights["decoder"]["layers"]["gate"] = layer_weight["gate"]
if moe_matmul:
jax_weights["decoder"]["layers"]["MoeBlock_0"]["gate"]["kernel"] = layer_weight["gate"]["kernel"]
all_wi_0 = []
all_wi_1 = []
all_wo = []
else:
jax_weights["decoder"]["layers"]["gate"] = layer_weight["gate"]

for k in range(num_experts):
layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"])
layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"])
layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wo"]["kernel"])
# swap the layer index
layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"], axes=(1, 0, 2))
layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"], axes=(1, 0, 2))
layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wo"]["kernel"], axes=(1, 0, 2))

jax_weights["decoder"]["layers"][f"mlp_{k}"] = layer_weight[f"mlp_{k}"]
if moe_matmul:
all_wi_0.append(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"])
all_wi_1.append(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"])
all_wo.append(layer_weight[f"mlp_{k}"]["wo"]["kernel"])
else:
# swap the layer index
layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"], axes=(1, 0, 2))
layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"], axes=(1, 0, 2))
layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wo"]["kernel"], axes=(1, 0, 2))

jax_weights["decoder"]["layers"][f"mlp_{k}"] = layer_weight[f"mlp_{k}"]

if moe_matmul:
jax_weights["decoder"]["layers"]["MoeBlock_0"]["wi_0"] = np.array(all_wi_0)
jax_weights["decoder"]["layers"]["MoeBlock_0"]["wi_1"] = np.array(all_wi_1)
jax_weights["decoder"]["layers"]["MoeBlock_0"]["wo"] = np.array(all_wo)

mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) # shards first axis
@@ -353,6 +378,7 @@ def checkpoint_device_put(arr):
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--maxtext-model-path", type=str, required=True)
parser.add_argument("--model-size", type=str, required=True)
parser.add_argument("--moe-matmul", type=bool, required=False, default=False)

args = parser.parse_args()

@@ -361,4 +387,4 @@ def checkpoint_device_put(arr):

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}"

convert(args.base_model_path, args.maxtext_model_path, args.model_size)
convert(args.base_model_path, args.maxtext_model_path, args.model_size, args.moe_matmul)
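
As a shape sketch of what the new moe_matmul branch writes out (NumPy only, toy dimensions, and ignoring the stacked-layer axis the real converter also carries): instead of one mlp_k entry per expert, the per-expert kernels are collected and stacked along a new leading expert axis, matching the (num_experts, base_emb_dim, mlp_dim) parameters that MoeBlock.generate_kernels declares under MoeBlock_0.

import numpy as np

num_experts, emb_dim, mlp_dim = 8, 16, 64   # toy sizes, not a real Mixtral config

# One wi_0 kernel per expert, as collected by the per-expert loop above.
per_expert_wi_0 = [np.zeros((emb_dim, mlp_dim), dtype=np.float32) for _ in range(num_experts)]

# moe_matmul branch: stack into the single MoeBlock_0 "wi_0" parameter.
stacked_wi_0 = np.array(per_expert_wi_0)
assert stacked_wi_0.shape == (num_experts, emb_dim, mlp_dim)

The same stacking applies to wi_1 and to wo, whose last two axes are (mlp_dim, emb_dim).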