Perf megablox #694 (Draft)

Wants to merge 17 commits into base: main

Commit: improve perf
RissyRan committed Jun 9, 2024
commit 8098e13895b89c863baa0cfb469398e6f4224c15

3 changes: 2 additions & 1 deletion MaxText/configs/base.yml
@@ -114,7 +114,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
# Parallelism
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_batch', ['data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
@@ -134,6 +134,7 @@ logical_axis_rules: [
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
['test', ['fsdp']],
]
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

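For context on the new ['test', ['fsdp']] rule above: it maps the ad-hoc 'test' logical axis onto the 'fsdp' mesh axis, which is what the shard_map in_specs/out_specs in linears.py below rely on. A minimal sketch of how Flax resolves such a rule, using only the added rule rather than the full training rule set (an illustrative assumption):

import flax.linen as nn

# Only the rule added in this change; the real config carries the full logical_axis_rules list.
rules = (("test", ("fsdp",)),)

with nn.logical_axis_rules(rules):
  # ('test', None) is the logical spec used for the gmm inputs/outputs in linears.py:
  # dimension 0 is sharded over the 'fsdp' mesh axis, dimension 1 is replicated.
  print(nn.logical_to_mesh_axes(("test", None)))  # e.g. PartitionSpec(('fsdp',), None)
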
30 changes: 30 additions & 0 deletions MaxText/configs/models/mixtral-test.yml
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x7b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 5
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 8
num_experts_per_tok: 2
decoder_block: "mistral"
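
For reference, these dimensions fix the expert kernel shapes that the grouped matmul in linears.py below operates on. A quick sanity check of that arithmetic; the batch/sequence split is an illustrative assumption chosen to reproduce the 262144-row figure quoted in the shape comments there:

# Expert kernel shapes implied by this config (compare the shape comments in linears.py):
num_experts, emb_dim, mlp_dim = 8, 4096, 14336
w0_shape = (num_experts, emb_dim, mlp_dim)   # (8, 4096, 14336)
wo_shape = (num_experts, mlp_dim, emb_dim)   # (8, 14336, 4096)

# Rows entering the grouped matmul: every token is repeated once per selected expert.
batch, seq_len, num_experts_per_tok = 32, 4096, 2   # hypothetical batch/sequence split
gmm_rows = batch * seq_len * num_experts_per_tok     # 262144, i.e. a (262144, 4096) input
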
67 changes: 56 additions & 11 deletions MaxText/layers/linears.py
@@ -33,6 +33,7 @@

try:
from jax.experimental.pallas.ops.tpu import megablox as mblx
# import megablox as mblx
except ImportError:
max_logging.log("JAX megablox is available for TPU only.")
pass
@@ -285,8 +286,8 @@ class MoeBlock(nn.Module):
mesh: Mesh
kernel_init: NdInitializer
kernel_axes: Tuple[str, ...]
weight_dtype: DType = jnp.float32
dtype: DType = jnp.float32
weight_dtype: DType = jnp.bfloat16 # todo: check type
dtype: DType = jnp.bfloat16

def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):

@@ -340,7 +341,7 @@ def permute(self, inputs, gate_logits, emb_dim):
# sort inputs for number of selected experts
sorted_inputs = jnp.take(repeat_inputs, indices=sorted_selected_experts, axis=0).astype(self.dtype)
group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)

# breakpoint()
return sorted_inputs, sorted_selected_experts, weights, group_size

def unpermute(self, intermediate, inputs, sorted_selected_experts, weights):
@@ -355,47 +356,90 @@ def unpermute(self, intermediate, inputs, sorted_selected_experts, weights):
def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): currently megablox works well on single host, and
# will add sharding properly to improve performance.
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')
tile_size = (8192, 128, 128)
# tile_size = None
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(nn.logical_to_mesh_axes((None, None))),
(nn.logical_to_mesh_axes(('test', None))),
(nn.logical_to_mesh_axes((None, None, None))),
(nn.logical_to_mesh_axes((None,))),
),
out_specs=(nn.logical_to_mesh_axes((None, None))),
out_specs=(nn.logical_to_mesh_axes(('test', None))),
check_rep=False,
)
def gmm(inputs, kernel, group_sizes):
hs_shape = inputs.shape
# pad lengh is the 1st dimension of tiling size in gmm call
# pad length is the 1st dimension of tiling size in gmm call
pad_length = 512
if hs_shape[0] % pad_length:
pad_length = pad_length - hs_shape[0] % pad_length
inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])
inputs = jax.lax.pad(inputs, 0.0, [(0, pad_length, 0), (0,0,0)])
# inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])

inputs = inputs.astype(self.dtype)
kernel = kernel.astype(self.weight_dtype)

output = mblx.gmm(lhs=inputs,
rhs=kernel,
group_sizes=group_sizes,
tiling=(512, 512, 512))
preferred_element_type=jnp.bfloat16,
tiling=tile_size)

if hs_shape[0] % pad_length:
output = output[:hs_shape[0]]
return output


@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(nn.logical_to_mesh_axes(('test', None))),
(nn.logical_to_mesh_axes((None, None, None))),
(nn.logical_to_mesh_axes((None,))),
),
out_specs=(nn.logical_to_mesh_axes(('test', None))),
check_rep=False,
)
def gmm2(inputs, kernel, group_sizes):
hs_shape = inputs.shape
# pad length is the 1st dimension of tiling size in gmm call
pad_length = 512
if hs_shape[0] % pad_length:
pad_length = pad_length - hs_shape[0] % pad_length
# inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])
inputs = jax.lax.pad(inputs, 0.0, [(0, pad_length, 0), (0,0,0)])

inputs = inputs.astype(self.dtype)
kernel = kernel.astype(self.weight_dtype)

output = mblx.gmm(lhs=inputs,
rhs=kernel,
group_sizes=group_sizes,
preferred_element_type=jnp.bfloat16,
tiling=(tile_size[0], tile_size[2], tile_size[1]) if tile_size else None)

if hs_shape[0] % pad_length:
output = output[:hs_shape[0]]
return output

# inputs: (batch * selected_exp * sequence, emb_dim) - (262144, 4096)
# w0_kernel: (num_exp, emb_dim, mlp) -> (8, 4096, 14336)
# w1_kernel: (num_exp, emb_dim, mlp)
# o_kernel: (num_exp, mlp, emb_dim) - > (8, 14336, 4096)
layer_w0 = gmm(inputs, w0_kernel, group_sizes)
layer_w1 = gmm(inputs, w1_kernel, group_sizes)
layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
output = gmm(intermediate_layer, wo_kernel, group_sizes)
output = gmm2(intermediate_layer, wo_kernel, group_sizes)
return output

@nn.compact
def __call__(self, inputs):
cfg = self.config
# inputs = nn.with_logical_constraint(inputs, ('test', None, None))
inputs = inputs.astype(cfg.dtype)
gate_logits = DenseGeneral(
self.num_experts,
@@ -408,6 +452,7 @@ def __call__(self, inputs):
flattened_top_k_weights = top_k_weights.reshape(-1, self.num_experts_per_tok)

softmax_probs = jax.nn.softmax(flattened_top_k_weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype)
# softmax_probs = jax.nn.softmax(flattened_top_k_weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype)
softmax_probs = softmax_probs.reshape(gate_logits.shape[:-1] + (self.num_experts_per_tok,))

weights = jnp.zeros_like(gate_logits)
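The gmm and gmm2 wrappers above follow the same pattern: pad the token dimension up to a multiple of the leading tile size, run the megablox grouped matmul, then slice the padding back off and keep only the real rows. A minimal single-device sketch of that pattern, leaving out the shard_map wrapping used in the change (shapes, tiling, and group sizes are illustrative, and mblx.gmm itself is TPU-only):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import megablox as mblx  # TPU-only

def padded_gmm(inputs, kernel, group_sizes, pad_multiple=512, tiling=(512, 128, 128)):
  """Pad the token dim to a multiple of pad_multiple, run the grouped matmul, un-pad."""
  num_tokens = inputs.shape[0]
  remainder = num_tokens % pad_multiple
  if remainder:
    inputs = jax.lax.pad(inputs, jnp.zeros((), inputs.dtype),
                         [(0, pad_multiple - remainder, 0), (0, 0, 0)])
  out = mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes,
                 preferred_element_type=jnp.bfloat16, tiling=tiling)
  return out[:num_tokens] if remainder else out

# Toy shapes: 1000 tokens routed across 4 experts, emb_dim=256, mlp_dim=512.
tokens = jnp.ones((1000, 256), dtype=jnp.bfloat16)
w0 = jnp.ones((4, 256, 512), dtype=jnp.bfloat16)
groups = jnp.array([300, 300, 200, 200], dtype=jnp.int32)
print(padded_gmm(tokens, w0, groups).shape)  # (1000, 512)

Group sizes are left unpadded, matching the change above: the padded rows fall outside every group and are discarded by the final slice.
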
6 changes: 2 additions & 4 deletions MaxText/layers/mistral.py
@@ -70,7 +70,6 @@ def __call__(
mesh = self.mesh

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))

lnx_rms = models.RMSNorm(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
@@ -79,7 +78,6 @@ def __call__(
epsilon=cfg.normalization_layer_epsilon,
)
lnx = lnx_rms(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))

# Self-attention block
@@ -121,7 +119,7 @@ def __call__(
epsilon=cfg.normalization_layer_epsilon,
)(intermediate_inputs)
hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))

# breakpoint()
if cfg.num_experts > 1:
# TODO(ranran): remove for loop implementation after adding expert parallelism
if cfg.moe_matmul:
@@ -131,7 +129,7 @@ def __call__(
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
kernel_axes=('embed', 'mlp'),
kernel_axes=(None, 'test'),
dtype=cfg.dtype,
)(hidden_states)
mlp_lnx = nn.with_logical_constraint(
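As a companion to the MoeBlock wiring above: before tokens reach the grouped matmul, the block routes each one to its top num_experts_per_tok experts and renormalizes only the selected gate logits. A minimal sketch of that routing step under illustrative shapes (the standalone helper below is hypothetical, not the module's actual API):

import jax
import jax.numpy as jnp

def route_top_k(gate_logits, num_experts_per_tok=2):
  """Pick the top experts per token and normalise their logits with a softmax."""
  top_k_logits, top_k_indices = jax.lax.top_k(gate_logits, num_experts_per_tok)
  # Softmax in float32 for stability, over the selected experts only.
  top_k_weights = jax.nn.softmax(top_k_logits.astype(jnp.float32), axis=-1)
  return top_k_weights.astype(gate_logits.dtype), top_k_indices

# 4 tokens, 8 experts (hypothetical shapes).
gate_logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
weights, experts = route_top_k(gate_logits)
print(weights.shape, experts.shape)  # (4, 2) (4, 2)
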
15 changes: 15 additions & 0 deletions MaxText/megablox/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megablox.ops import gmm
80 changes: 80 additions & 0 deletions MaxText/megablox/common.py
@@ -0,0 +1,80 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Common utilities for GMM kernels."""


import re


import jax
import jax.numpy as jnp




def is_tpu() -> bool:
return "TPU" in jax.devices()[0].device_kind




def tpu_kind() -> str:
"""Query identification string for the currently attached TPU."""
return jax.devices()[0].device_kind




_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")




def tpu_generation() -> int:
"""Generation number of the currently attached TPU."""
if version := _TPU_KIND_PATTERN.match(tpu_kind()):
return int(version[1])
raise NotImplementedError("only TPU devices are supported")




def supports_bfloat16_matmul() -> bool:
"""Does the currently attached CPU support bfloat16 inputs?"""
return not is_tpu() or tpu_generation() >= 4




def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
if dtype != jnp.bfloat16 and dtype != jnp.float32:
raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")




def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
"""A type to which both input should be adapted to before dot product."""
# bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed
# input precision, we need to convert bf16 argument to fp32 beforehand.
if (
supports_bfloat16_matmul()
and lhs.dtype == jnp.bfloat16
and rhs.dtype == jnp.bfloat16
):
return jnp.bfloat16
else:
return jnp.float32
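
The helpers above choose the matmul input precision from the attached hardware: bf16 x bf16 is kept on non-TPU backends and on TPU v4 or newer; otherwise (older TPUs, or mixed operand dtypes) both inputs are promoted to float32. A small usage sketch, assuming the MaxText directory is on PYTHONPATH so this package imports as megablox:

import jax.numpy as jnp
from megablox import common

lhs = jnp.ones((8, 128), dtype=jnp.bfloat16)
rhs = jnp.ones((128, 32), dtype=jnp.bfloat16)

# Stays bfloat16 on non-TPU backends and TPU v4+; falls back to float32 on older TPUs
# or whenever the operand dtypes are mixed.
input_dtype = common.select_input_dtype(lhs, rhs)
out = jnp.dot(lhs.astype(input_dtype), rhs.astype(input_dtype),
              preferred_element_type=jnp.float32)
print(out.dtype)  # float32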