Perf megablox #694

Draft
wants to merge 17 commits into base: main
06_12
RissyRan committed Jun 13, 2024
commit 7bf2f4aab33388511ba2c5be48d4e41c703fe04c
2 changes: 1 addition & 1 deletion MaxText/configs/base.yml
@@ -240,7 +240,7 @@ adam_eps_root: 0. # A small constant applied to denominator inside the square root
adam_weight_decay: 0.1 # AdamW Weight decay

# Stack trace parameters
collect_stack_trace: True
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.

148 changes: 79 additions & 69 deletions MaxText/layers/linears.py
@@ -30,6 +30,7 @@
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
import max_logging
from jax.sharding import PartitionSpec

try:
# from jax.experimental.pallas.ops.tpu import megablox as mblx
@@ -340,7 +341,9 @@ def permute(self, inputs, gate_logits, emb_dim):

# reshape inputs (batch, sequence, emb) to 2D
inputs_2d = jnp.reshape(inputs, (-1, emb_dim))
# print('inputs_2d', inputs_2d.shape)
weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
# print('gate_logits', gate_logits.shape)
weights = jax.nn.softmax(weights.astype(self.weight_dtype), axis=-1).astype(self.dtype)
flatten_selected_experts = jnp.ravel(selected_experts)
sorted_selected_experts = jnp.argsort(flatten_selected_experts)
@@ -355,34 +358,41 @@ def permute(self, inputs, gate_logits, emb_dim):
def unpermute(self, intermediate, inputs, sorted_selected_experts, weights):
"""Unpermute tokens to original order and combine weights."""

# print("unpermute:...")
# print(f"intermediate: {intermediate.shape}")
# print(f"inputs: {inputs.shape}")
unsort_output = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0)
flatten_weights = jnp.ravel(weights)
combined_output = jnp.multiply(unsort_output, flatten_weights[:, None])
# print(f"combined_output: {combined_output.shape}")
groups = jnp.reshape(combined_output, (-1, self.num_experts_per_tok, combined_output.shape[1]))
return jnp.sum(groups, axis=1).reshape(inputs.shape).astype(self.dtype)
# print(f"groups: {groups.shape}")
return jnp.sum(groups, axis=1).reshape(-1, self.config.max_target_length, self.config.emb_dim).astype(self.dtype)
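
For context, permute/unpermute implement the standard MoE token-routing round trip: replicate each token once per selected expert, sort the copies by expert id so each expert sees a contiguous block, run the expert computation, then invert the sort and take the router-weighted sum. A minimal, self-contained sketch of that idea (the sizes, PRNG keys, and the identity stand-in for the expert computation are illustrative assumptions, not MaxText code):

import jax
import jax.numpy as jnp

num_experts, k, emb_dim = 8, 2, 16                       # illustrative sizes
tokens = jax.random.normal(jax.random.PRNGKey(0), (32, emb_dim))
gate_logits = jax.random.normal(jax.random.PRNGKey(1), (32, num_experts))

# Route: top-k experts per token, with normalized gate weights.
weights, selected = jax.lax.top_k(gate_logits, k)
weights = jax.nn.softmax(weights, axis=-1)

# Permute: one copy of each token per assignment, sorted by expert id.
flat_experts = jnp.ravel(selected)                       # (32 * k,)
sort_idx = jnp.argsort(flat_experts)
sorted_tokens = jnp.repeat(tokens, k, axis=0)[sort_idx]
group_sizes = jnp.bincount(flat_experts, length=num_experts)  # rows per expert

expert_out = sorted_tokens                               # expert computation would go here

# Unpermute: invert the sort, apply gate weights, sum the k copies per token.
unsorted = expert_out[jnp.argsort(sort_idx)]
combined = unsorted * jnp.ravel(weights)[:, None]
output = combined.reshape(-1, k, emb_dim).sum(axis=1)    # back to (32, emb_dim)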

def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo_kernel):
def call_gmm(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): currently megablox works well on single host, and
# will add sharding properly to improve performance.
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')

tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
# tile_size = None
# tile_size = (4096, 128, 128)
# tile_size = (512, 512, 512)
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(nn.logical_to_mesh_axes(('test', None))),
(nn.logical_to_mesh_axes((None, None, None))),
(nn.logical_to_mesh_axes((None,))),
),
out_specs=(nn.logical_to_mesh_axes(('test', None))),
check_rep=False,
)
# tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
tile_size = None
# @functools.partial(
# shard_map.shard_map,
# mesh=self.mesh,
# in_specs=(
# (nn.logical_to_mesh_axes(('test', None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None,))),
# ),
# out_specs=(nn.logical_to_mesh_axes(('test', None))),
# check_rep=False,
# )
def gmm(inputs, kernel, group_sizes):
# print(f"inside")
# print(f"inputs: {inputs.shape}")
# print(f"kernel: {kernel.shape}")
# print(f"group_size: {group_sizes}")
hs_shape = inputs.shape
# pad length is the 1st dimension of tiling size in gmm call
pad_length = tile_size[0] if tile_size else 512
@@ -403,48 +413,43 @@ def gmm(inputs, kernel, group_sizes):
output = output[:hs_shape[0]]
return output
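
The pad_length handling above exists because the grouped matmul processes the token dimension in tiles: the number of rows is rounded up to a multiple of the first tile dimension before the kernel call, and the padding is sliced off afterwards (output[:hs_shape[0]]). A rough sketch of that pad/slice-back pattern, with illustrative shapes and a placeholder for the kernel call:

import jax.numpy as jnp

def pad_rows_to_multiple(x, tile_rows):
  """Pad the leading (token) dimension up to a multiple of tile_rows."""
  rows = x.shape[0]
  padded_rows = -(-rows // tile_rows) * tile_rows        # ceiling to a multiple
  return jnp.pad(x, ((0, padded_rows - rows), (0, 0))), rows

x = jnp.ones((1000, 64))
padded, orig_rows = pad_rows_to_multiple(x, 512)         # padded now has 1024 rows
# ... grouped matmul kernel would run on `padded` here ...
result = padded[:orig_rows]                              # drop the padded rows again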

# from jax.sharding import PartitionSpec
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
# w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)

layer_w0 = gmm(inputs, w0_kernel, group_sizes)
layer_w1 = gmm(inputs, w1_kernel, group_sizes)
layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
output = gmm(intermediate_layer, wo_kernel, group_sizes)
return output
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(PartitionSpec('fsdp', None, None),
PartitionSpec('fsdp', None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
)),
out_specs=PartitionSpec('fsdp', None, None),
check_rep=False,
)
def inner_fn(x, logits, w0, w1, wo):
x, sorted_selected_experts, weights, group_sizes = self.permute(x,logits,config.emb_dim)
# breakpoint()
layer_w0 = gmm(x, w0, group_sizes)
layer_w1 = gmm(x, w1, group_sizes)
layer_act = _convert_to_activation_function(config.mlp_activations[0])(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
intermediate_output = gmm(intermediate_layer, wo, group_sizes)
# print(f"intermediate_output.shape: {intermediate_output.shape}")
# print(f"x.shape: {x.shape}")
output = self.unpermute(intermediate_output,
x,
sorted_selected_experts,
weights)
# print(f"unpermute: {output.shape}")
return output
# print(f"inner_fn inputs: {inputs.shape}")
return inner_fn(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)
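
inner_fn above uses shard_map to split the activations along the fsdp mesh axis while keeping the expert kernels replicated (the all-None PartitionSpecs), so each device permutes and multiplies only its own slice of the batch. A stripped-down sketch of that pattern on a 1-D mesh (the mesh construction, shapes, and per_shard_fn body are assumptions for illustration):

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('fsdp',))          # 1-D mesh over all devices

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec('fsdp', None),               # activations: shard the batch dim
              PartitionSpec(None, None)),                # weights: replicated on every device
    out_specs=PartitionSpec('fsdp', None),               # output sharded like the activations
    check_rep=False,
)
def per_shard_fn(x, w):
  # Runs once per device on its local (batch_shard, emb) block.
  return x @ w

x = jnp.ones((8 * jax.device_count(), 16))
w = jnp.ones((16, 32))
y = per_shard_fn(x, w)                                   # global result: (8 * n_devices, 32)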

# inputs: (batch * selected_exp * sequence, emb_dim) - (262144, 4096)
# w0_kernel: (num_exp, emb_dim, mlp) -> (8, 4096, 14336)
# w1_kernel: (num_exp, emb_dim, mlp)
# o_kernel: (num_exp, mlp, emb_dim) - > (8, 14336, 4096)

# @functools.partial(
# shard_map.shard_map,
# mesh=self.mesh,
# in_specs=(
# (nn.logical_to_mesh_axes(('test', None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None,))),
# ),
# out_specs=(nn.logical_to_mesh_axes(('test', None))),
# check_rep=False,
# )
# def inner_fn(x, w0, w1, wo, gs):
# tile_size = (4096, 128, 128)
# layer_w0 = gmm(x, w0, gs, tile_size)
# layer_w1 = gmm(x, w1, gs, tile_size)
# layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
# intermediate_layer = jnp.multiply(layer_act, layer_w1)
# output = gmm(intermediate_layer, wo, gs, (tile_size[0], tile_size[2], tile_size[1]))
# # breakpoint()
# return output

# output = inner_fn(inputs, w0_kernel, w1_kernel, wo_kernel, group_sizes)
# return output

@nn.compact
def __call__(self, inputs):
cfg = self.config
@@ -472,23 +477,28 @@ def __call__(self, inputs):

if cfg.megablox:
max_logging.log("Running MoE megablox implementation.")
sorted_hidden_states, sorted_selected_experts, weights, group_sizes = self.permute(inputs,
gate_logits,
cfg.emb_dim)
from jax.sharding import PartitionSpec
replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)
return self.call_gmm(inputs, gate_logits, cfg, w0_kernel, w1_kernel, wo_kernel)
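
The device_put call above pins the three expert kernels as fully replicated arrays: NamedSharding(mesh, PartitionSpec(None)) partitions no axis, so every device holds a complete copy before call_gmm shards only the activations. A small sketch of that placement pattern (the mesh construction and kernel shape are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('fsdp',))
replicated = NamedSharding(mesh, PartitionSpec(None))    # no axis is partitioned

w0 = jnp.ones((8, 64, 256), dtype=jnp.bfloat16)          # (num_experts, emb_dim, mlp_dim), toy sizes
w0 = jax.device_put(w0, device=replicated)               # a full copy lands on every device
print(w0.sharding)                                       # shows the replicated NamedSharding
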
# sorted_hidden_states, sorted_selected_experts, weights, group_sizes = self.permute(inputs,
# gate_logits,
# cfg.emb_dim)
# from jax.sharding import PartitionSpec
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
# w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)

intermediate_output = self.call_gmm(sorted_hidden_states,
group_sizes,
cfg.mlp_activations[0],
w0_kernel,
w1_kernel,
wo_kernel)
output = self.unpermute(intermediate_output,
inputs,
sorted_selected_experts,
weights)
# print("before")
# print(f"sorted_hidden_states: {sorted_hidden_states.shape}")
# print(f"group_sizes: {group_sizes}")
# print(f"w0_kernel: {w0_kernel.shape}")
# intermediate_output = self.call_gmm(sorted_hidden_states,
# group_sizes,
# cfg.mlp_activations[0],
# w0_kernel,
# w1_kernel,
# wo_kernel)
# output = self.unpermute(intermediate_output,
# inputs,
# sorted_selected_experts,
# weights)
else:
max_logging.log("Running MoE matmul implementation.")
with jax.named_scope("wi_0"):
29 changes: 15 additions & 14 deletions MaxText/tests/moe_test.py
@@ -179,23 +179,23 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
fsdp_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('fsdp'))
replicated_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
# moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
# hidden_states = jax.device_put(hidden_states, device=fsdp_sharding)

hidden_states = nn.with_logical_constraint(
hidden_states, ('activation_batch', 'activation_length', 'activation_embed')
)
hidden_states = jax.device_put(hidden_states, device=fsdp_sharding)

#hidden_states = nn.with_logical_constraint(
# hidden_states, ('activation_batch', 'activation_length', 'activation_embed')
# )
print('hidden states shape', hidden_states.shape)
rng = jax.random.PRNGKey(40)
moe_variables = model.init(rng, jax.random.normal(rng, (int(cfg.per_device_batch_size),
cfg.max_target_length,
cfg.base_emb_dim)))
#moe_variables = model.init(rng, jax.random.normal(rng, (int(cfg.per_device_batch_size) * 4 ,
# cfg.max_target_length,
# cfg.base_emb_dim)))
moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
# breakpoint()
# jax.debug.visualize_array_sharding(moe_variables['params']['gate']['kernel'].value)

time.simple_timeit(jax.jit(model.apply), moe_variables, hidden_states, tries=10, task="matmul")
output = jax.jit(model.apply)(moe_variables, hidden_states)
# output = model.apply(moe_variables, hidden_states)
# output = jax.jit(model.apply)(moe_variables, hidden_states)
output = model.apply(moe_variables, hidden_states)
return output


@@ -204,6 +204,7 @@ class MoeTest(unittest.TestCase):
def setUp(self):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

pyconfig.initialize(
[None, 'configs/base.yml'],
run_name='test',
@@ -214,7 +215,7 @@ def setUp(self):
moe_matmul=True,
megablox=True,
ici_fsdp_parallelism=4,
per_device_batch_size=16,
per_device_batch_size=8,
dataset_type='synthetic',
attention='flash',
max_target_length=4096,
Expand All @@ -223,7 +224,7 @@ def setUp(self):
self.cfg = pyconfig.config
self.rng = jax.random.PRNGKey(42)

self.hidden_states = jax.random.uniform(self.rng, (int(self.cfg.per_device_batch_size),
self.hidden_states = jax.random.uniform(self.rng, (int(self.cfg.per_device_batch_size) * 4,
self.cfg.max_target_length,
self.cfg.base_emb_dim), dtype=self.cfg.dtype)
# print(f"{self.hidden_states.shape}=")
@@ -235,8 +236,8 @@ def setUp(self):
def test_moe_block(self):
variables, expected_output = get_expected_output(self.rng, self.hidden_states, self.cfg)
actual_output = get_moe_output(variables, self.hidden_states, self.cfg, self.mesh)
# print("expected_output", expected_output)
# print("actual_output", actual_output)
print("expected_output", expected_output.shape)
print("actual_output", actual_output.shape)
# breakpoint()
self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
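
The benchmark in get_moe_output above relies on simple_timeit, presumably a MaxText timing utility; a generic stand-in for that kind of timing loop can be written with jax.block_until_ready (the helper below is an illustrative sketch, not the MaxText helper):

import time as pytime
import jax

def rough_timeit(fn, *args, tries=10):
  """Average wall-clock seconds per call; the first call compiles and is discarded."""
  jax.block_until_ready(fn(*args))                       # warm-up / compile
  start = pytime.perf_counter()
  for _ in range(tries):
    jax.block_until_ready(fn(*args))
  return (pytime.perf_counter() - start) / tries

# e.g. rough_timeit(jax.jit(model.apply), moe_variables, hidden_states), reusing names from the test above.

The loose rtol/atol in the assertion presumably give headroom for low-precision accumulation differences between the matmul and megablox paths.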
