Perf megablox #694

Draft · wants to merge 17 commits into base: main
perf_debug
RissyRan committed Jun 12, 2024
commit 115b399863c735474b9d6fa33a0162138a4a64f4
6 changes: 3 additions & 3 deletions MaxText/configs/base.yml
@@ -304,6 +304,6 @@ enable_checkpoint_standard_logger: False
# Single-controller
enable_single_controller: False

-tile_size_0: 4096
-tile_size_1: 128
-tile_size_2: 128
+tile_size_0: 512
+tile_size_1: 512
+tile_size_2: 512
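For context: these three values are the block sizes the MoE layer hands to the grouped-matmul kernel (see the linears.py change below). The sketch here is only an illustration of how the old and new defaults change the tile grid of a single per-expert GEMM; the (m, k, n) ordering and the example shapes are assumptions, not taken from this PR.

```python
# Illustration only: assumes tile_size_0/1/2 are the (m, k, n) block sizes
# passed to the grouped matmul as `tiling`. Shapes below are hypothetical.
import math

def grid_tiles(m, k, n, tiling):
  """Number of kernel tiles along each dimension of an [m, k] x [k, n] GEMM."""
  tm, tk, tn = tiling
  return math.ceil(m / tm), math.ceil(k / tk), math.ceil(n / tn)

# Example: 16384 tokens routed to one expert, embed=4096, mlp=14336.
print(grid_tiles(16384, 4096, 14336, (4096, 128, 128)))  # old values -> (4, 32, 112)
print(grid_tiles(16384, 4096, 14336, (512, 512, 512)))   # new values -> (32, 8, 28)
```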
4 changes: 2 additions & 2 deletions MaxText/layers/linears.py
@@ -361,9 +361,9 @@ def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')

-# tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
+tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
 # tile_size = (4096, 128, 128)
-tile_size = (512, 512, 512)
+# tile_size = (512, 512, 512)
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
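For reference, a minimal sketch of where the tile_size tuple ends up: the shard-mapped wrapper above ultimately calls the megablox grouped matmul with it as the tiling argument. The import path and gmm keywords below are assumptions based on current JAX, not copied from this PR.

```python
# Minimal sketch (not the PR's call_gmm): threading the config-driven tile
# sizes into the megablox grouped matmul. Import path and keywords are assumed.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import megablox as mblx

def grouped_matmul(inputs, kernel, group_sizes, tile_size):
  """inputs: [tokens, embed] sorted by expert; kernel: [experts, embed, mlp];
  group_sizes: [experts] token counts per expert; tile_size: (tm, tk, tn)."""
  return mblx.gmm(
      lhs=inputs,
      rhs=kernel,
      group_sizes=group_sizes,
      preferred_element_type=jnp.bfloat16,
      tiling=tile_size,
  )
```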
30 changes: 27 additions & 3 deletions MaxText/tests/moe_test.py
@@ -155,6 +155,19 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
wi_1 = jnp.concatenate(exp_wi_1, axis=0, dtype=cfg.weight_dtype)
wo = jnp.concatenate(exp_wo, axis=0, dtype=cfg.weight_dtype)

+kernel = nn.with_logical_constraint(
+kernel, ('embed', 'mlp')
+)
+wi_0 = nn.with_logical_constraint(
+wi_0, (None, 'test', None)
+)
+wi_1 = nn.with_logical_constraint(
+wi_1, (None, 'test', None)
+)
+wo = nn.with_logical_constraint(
+wo, (None, 'test', None)
+)
+
moe_variables = {'params': {'gate': {'kernel': kernel},
'wi_0': wi_0,
'wi_1': wi_1,
@@ -163,9 +176,15 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
# print("get_moe_output expected_variables", variables)
# breakpoint()
# from jax.sharding import PartitionSpec
-# sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
-# jax.device_put(moe_variables, device=sharding)
-# jax.device_put(hidden_states, device=sharding)
+# fsdp_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('fsdp'))
+# moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
+# hidden_states = jax.device_put(hidden_states, device=fsdp_sharding)
+
+hidden_states = nn.with_logical_constraint(
+hidden_states, ('activation_batch', 'activation_length', 'activation_embed')
+)
+

+time.simple_timeit(jax.jit(model.apply), moe_variables, hidden_states, tries=10, task="matmul")
output = jax.jit(model.apply)(moe_variables, hidden_states)
# output = model.apply(moe_variables, hidden_states)
@@ -186,6 +205,11 @@ def setUp(self):
weight_dtype='bfloat16',
moe_matmul=True,
megablox=True,
+ici_fsdp_parallelism=4,
+per_device_batch_size=4,
+dataset_type='synthetic',
+attention='flash',
+max_target_length=4096,
)

self.cfg = pyconfig.config
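The test now constrains the benchmark inputs to the mesh and times the jitted apply. For readers without the MaxText helpers, here is a rough, self-contained sketch of the same idea using explicit NamedSharding (as the commented fsdp_sharding lines hint at) plus a warm-up-then-time loop; the mesh shape, array shapes, and the `timed` helper are illustrative stand-ins, not the PR's simple_timeit.

```python
# Rough sketch (assumptions: 1-D 'fsdp' mesh, illustrative shapes; `timed` is
# a stand-in for the test's simple_timeit, not its actual implementation).
import time

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('fsdp',))

# Shard the batch dimension of the activations across 'fsdp', as in the
# commented-out device_put lines of the diff.
hidden_states = jnp.ones((16, 1024, 4096), dtype=jnp.bfloat16)
hidden_states = jax.device_put(hidden_states, NamedSharding(mesh, P('fsdp')))

def timed(fn, *args, tries=10, task='matmul'):
  """Times fn(*args): one warm-up call to compile, then an averaged loop."""
  jax.block_until_ready(fn(*args))          # warm-up / compile
  start = time.perf_counter()
  for _ in range(tries):
    out = fn(*args)
  jax.block_until_ready(out)                # wait for async dispatch to finish
  avg_ms = (time.perf_counter() - start) / tries * 1e3
  print(f'{task}: {avg_ms:.2f} ms/iter')
  return avg_ms

# Usage (model and variables as in the test):
# timed(jax.jit(model.apply), moe_variables, hidden_states, tries=10)
```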