Skip to content

Commit

Permalink
test on v5p
Browse files Browse the repository at this point in the history
  • Loading branch information
RissyRan committed Jun 10, 2024
1 parent cc43d18 commit d4dc064
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
6 changes: 5 additions & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,8 @@ enable_checkpoint_cloud_logger: False
enable_checkpoint_standard_logger: False

# Single-controller
enable_single_controller: False
enable_single_controller: False

tile_size_0: 4096
tile_size_1: 128
tile_size_2: 128
1 change: 1 addition & 0 deletions MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')

# tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
tile_size = (4096, 128, 128)
@functools.partial(
shard_map.shard_map,
Expand Down
4 changes: 4 additions & 0 deletions MaxText/tests/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def get_moe_output(variables, hidden_states, cfg, mesh):

# print("get_moe_output expected_variables", variables)
# breakpoint()
from jax.sharding import PartitionSpec
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
jax.device_put(moe_variables, device=sharding)
jax.device_put(hidden_states, device=sharding)
time.simple_timeit(jax.jit(model.apply), moe_variables, hidden_states, tries=10, task="matmul")
output = jax.jit(model.apply)(moe_variables, hidden_states)
# output = model.apply(moe_variables, hidden_states)
Expand Down

0 comments on commit d4dc064

Please sign in to comment.