
Commit 122829f: clean_up
RissyRan committed Jun 14, 2024 (1 parent: 7bf2f4a)
Showing 5 changed files with 35 additions and 92 deletions.
9 changes: 2 additions & 7 deletions MaxText/configs/base.yml
@@ -114,7 +114,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
# Parallelism
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['fsdp']],
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
@@ -125,7 +125,7 @@ logical_axis_rules: [
['activation_vocab', 'sequence'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'tensor'],
['heads', ['tensor', 'autoregressive']],
@@ -134,7 +134,6 @@ logical_axis_rules: [
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
['test', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

@@ -303,7 +302,3 @@ enable_checkpoint_standard_logger: False

# Single-controller
enable_single_controller: False

tile_size_0: 512
tile_size_1: 512
tile_size_2: 512
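
For context, a minimal sketch (not part of this commit) of how Flax resolves logical axis names such as 'embed' and 'mlp' against rules like the ones above; the abridged rule list below is an assumption taken from the base.yml snippet, not the full set:

import flax.linen as nn

rules = (
    ("activation_batch", ("data", "fsdp", "fsdp_transpose")),
    ("embed", ("fsdp", "sequence")),
    ("mlp", ("fsdp_transpose", "tensor", "autoregressive")),
)

with nn.logical_axis_rules(rules):
    # A kernel declared with kernel_axes=('embed', 'mlp') resolves to a mesh
    # PartitionSpec through these rules, e.g. 'embed' maps to ('fsdp', 'sequence').
    spec = nn.logical_to_mesh_axes(("embed", "mlp"))
    print(spec)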
105 changes: 22 additions & 83 deletions MaxText/layers/linears.py
@@ -51,6 +51,7 @@
RMSNorm = normalizations.RMSNorm
Quant = quantizations.AqtQuantization

MESH_FSDP_AXIS = "fsdp"

def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]:
"""Convert a string to an activation function."""
@@ -279,6 +280,7 @@ class MoeBlock(nn.Module):
kernel_axes: Tuple with axes to apply kernel function.
weight_dtype: Type for the weights.
dtype: Type for the dense layer.
quant: Optional quantization config, no quantization if None.
"""

config: Config
@@ -289,17 +291,16 @@
kernel_axes: Tuple[str, ...]
weight_dtype: DType = jnp.bfloat16 # todo: check type
dtype: DType = jnp.bfloat16
quant: Optional[Quant] = None

def generate_kernels(self, num_experts, emb_dim, mlp_dim):

kernel_in_axis = np.arange(1)
kernel_out_axis = np.arange(1, 2)
kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal')

# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')
kernel_axes = (None, 'test', None)
wo_kernel_axes = (None, None, 'test')
kernel_axes = (None, 'embed', None)
wo_kernel_axes = (None, None, 'embed')

w0_kernel = self.param(
'wi_0',
@@ -328,78 +329,46 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
kernel_out_axis,
)
wo_kernel = jnp.asarray(wo_kernel, self.dtype)

# from jax.sharding import PartitionSpec
# fsdp_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec('fsdp'))
# w0_kernel = jax.device_put(w0_kernel, device=fsdp_sharding)
# w1_kernel = jax.device_put(w1_kernel, device=fsdp_sharding)
# wo_kernel = jax.device_put(wo_kernel, device=fsdp_sharding)
return w0_kernel, w1_kernel, wo_kernel

def permute(self, inputs, gate_logits, emb_dim):
"""Permute tokens to group by expert to fit gmm call."""

# reshape inputs (batch, sequence, emb) to 2D
inputs_2d = jnp.reshape(inputs, (-1, emb_dim))
# print('inputs_2d', inputs_2d.shape)
weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
# print('gate_logits', gate_logits.shape)
weights = jax.nn.softmax(weights.astype(self.weight_dtype), axis=-1).astype(self.dtype)
flatten_selected_experts = jnp.ravel(selected_experts)
sorted_selected_experts = jnp.argsort(flatten_selected_experts)
# repeat inputs for number of active experts
repeat_inputs = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
# sort inputs for number of selected experts
sorted_inputs = jnp.take(repeat_inputs, indices=sorted_selected_experts, axis=0).astype(self.dtype)
group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
# breakpoint()
return sorted_inputs, sorted_selected_experts, weights, group_size

def unpermute(self, intermediate, inputs, sorted_selected_experts, weights):
def unpermute(self, intermediate, sorted_selected_experts, weights):
"""Unpermute tokens to original order and combine weights."""

# print("unpermute:...")
# print(f"intermediate: {intermediate.shape}")
# print(f"inputs: {inputs.shape}")
unsort_output = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0)
flatten_weights = jnp.ravel(weights)
combined_output = jnp.multiply(unsort_output, flatten_weights[:, None])
# print(f"combined_output: {combined_output.shape}")
groups = jnp.reshape(combined_output, (-1, self.num_experts_per_tok, combined_output.shape[1]))
# print(f"groups: {groups.shape}")
return jnp.sum(groups, axis=1).reshape(-1, self.config.max_target_length, self.config.emb_dim).astype(self.dtype)

def call_gmm(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): currently megablox works well on single host, and
# will add sharding properly to improve performance.
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')

# tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
# TODO(ranran): update the static default tile_size
# tile_size = (512, 512, 512)
tile_size = None
# @functools.partial(
# shard_map.shard_map,
# mesh=self.mesh,
# in_specs=(
# (nn.logical_to_mesh_axes(('test', None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None,))),
# ),
# out_specs=(nn.logical_to_mesh_axes(('test', None))),
# check_rep=False,
# )
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))

def gmm(inputs, kernel, group_sizes):
# print(f"inside")
# print(f"inputs: {inputs.shape}")
# print(f"kernel: {kernel.shape}")
# print(f"group_size: {group_sizes}")
# breakpoint()
# kernel = jax.lax.all_gather(kernel, 'fsdp', axis=axis_index, tiled=True)
# breakpoint()
hs_shape = inputs.shape
# pad length is the 1st dimension of tiling size in gmm call
pad_length = tile_size[0] if tile_size else 512
if hs_shape[0] % pad_length:
pad_length = pad_length - hs_shape[0] % pad_length
inputs = jax.lax.pad(inputs, 0.0, [(0, pad_length, 0), (0,0,0)])
# inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])
inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])

inputs = inputs.astype(self.dtype)
kernel = kernel.astype(self.weight_dtype)
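
A hedged, self-contained sketch (not from this commit) of the token-routing idea behind permute above: rows are repeated once per selected expert, sorted so each expert's rows are contiguous for the grouped matmul, and group_sizes records how many rows each expert receives. The toy shapes (3 tokens, 4 experts, top-2) are assumptions for illustration.

import jax
import jax.numpy as jnp

tokens = jnp.arange(6.0).reshape(3, 2)              # 3 tokens, emb_dim = 2
gate_logits = jnp.array([[0.1, 2.0, 0.3, 0.0],
                         [1.5, 0.2, 0.1, 0.9],
                         [0.0, 0.1, 3.0, 2.5]])     # (tokens, num_experts = 4)

weights, selected = jax.lax.top_k(gate_logits, 2)   # top-2 experts per token
weights = jax.nn.softmax(weights, axis=-1)          # per-token combine weights
flat = jnp.ravel(selected)                          # expert id for every (token, slot)
order = jnp.argsort(flat)                           # group rows by expert id
repeated = jnp.repeat(tokens, 2, axis=0)            # one row per (token, slot)
sorted_tokens = jnp.take(repeated, order, axis=0)   # contiguous per-expert blocks
group_sizes = jnp.bincount(flat, length=4)          # rows owned by each expert

# After the expert matmuls, jnp.argsort(order) restores the original (token, slot)
# order and the softmaxed top-k weights recombine the expert outputs per token.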
@@ -417,38 +386,28 @@ def gmm(inputs, kernel, group_sizes):
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(PartitionSpec('fsdp', None, None),
PartitionSpec('fsdp', None, None),
PartitionSpec(MESH_FSDP_AXIS, None, None),
PartitionSpec(MESH_FSDP_AXIS, None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
)),
out_specs=PartitionSpec('fsdp', None, None),
),
out_specs=PartitionSpec(MESH_FSDP_AXIS, None, None),
check_rep=False,
)
def inner_fn(x, logits, w0, w1, wo):
x, sorted_selected_experts, weights, group_sizes = self.permute(x,logits,config.emb_dim)
# breakpoint()
x, sorted_selected_experts, weights, group_sizes = self.permute(x, logits, config.emb_dim)

layer_w0 = gmm(x, w0, group_sizes)
layer_w1 = gmm(x, w1, group_sizes)
layer_act = _convert_to_activation_function(config.mlp_activations[0])(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
intermediate_output = gmm(intermediate_layer, wo, group_sizes)
# print(f"intermediate_output.shape: {intermediate_output.shape}")
# print(f"x.shape: {x.shape}")
output = self.unpermute(intermediate_output,
x,
sorted_selected_experts,
weights)
# print(f"unpermute: {output.shape}")
return output
# print(f"inner_fn inputs: {inputs.shape}")
return inner_fn(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)

# inputs: (batch * selected_exp * sequence, emb_dim) - (262144, 4096)
# w0_kernel: (num_exp, emb_dim, mlp) -> (8, 4096, 14336)
# w1_kernel: (num_exp, emb_dim, mlp)
# o_kernel: (num_exp, mlp, emb_dim) - > (8, 14336, 4096)

@nn.compact
def __call__(self, inputs):
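
For orientation, a minimal runnable sketch of the shard_map pattern used in call_gmm above; the 1-D 'fsdp' mesh over all local devices and the plain matmul standing in for the megablox gmm kernel are assumptions for illustration only.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec("fsdp", None), PartitionSpec(None, None)),
    out_specs=PartitionSpec("fsdp", None),
    check_rep=False,
)
def sharded_matmul(x, w):
    # Each device sees its shard of x's leading dim and a full replica of w.
    return x @ w

x = jnp.ones((8 * jax.device_count(), 16))
w = jnp.ones((16, 32))
y = sharded_matmul(x, w)   # result stays sharded along the 'fsdp' axis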
@@ -459,7 +418,8 @@ def __call__(self, inputs):
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=self.kernel_axes,
name="gate")(inputs)
name="gate",
quant=self.quant)(inputs)

top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
flattened_top_k_weights = top_k_weights.reshape(-1, self.num_experts_per_tok)
@@ -478,27 +438,6 @@ def __call__(self, inputs):
if cfg.megablox:
max_logging.log("Running MoE megablox implementation.")
return self.call_gmm(inputs, gate_logits, cfg, w0_kernel, w1_kernel, wo_kernel)
# sorted_hidden_states, sorted_selected_experts, weights, group_sizes = self.permute(inputs,
# gate_logits,
# cfg.emb_dim)
# from jax.sharding import PartitionSpec
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
# w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)

# print("before")
# print(f"sorted_hidden_states: {sorted_hidden_states.shape}")
# print(f"group_sizes: {group_sizes}")
# print(f"w0_kernel: {w0_kernel.shape}")
# intermediate_output = self.call_gmm(sorted_hidden_states,
# group_sizes,
# cfg.mlp_activations[0],
# w0_kernel,
# w1_kernel,
# wo_kernel)
# output = self.unpermute(intermediate_output,
# inputs,
# sorted_selected_experts,
# weights)
else:
max_logging.log("Running MoE matmul implementation.")
with jax.named_scope("wi_0"):
3 changes: 2 additions & 1 deletion MaxText/layers/mistral.py
@@ -129,8 +129,9 @@ def __call__(
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
kernel_axes=('embed', 'mlp'),
kernel_axes=(None, None),
dtype=cfg.dtype,
quant=self.quant,
)(hidden_states)
mlp_lnx = nn.with_logical_constraint(
mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed')
8 changes: 8 additions & 0 deletions MaxText/optimizers.py
@@ -45,6 +45,14 @@ def get_optimizer(config, learning_rate_schedule):
epsilon_root=config.adam_eps_root,
weight_decay=config.adam_weight_decay,
)
elif config.opt_type == "sgd":
return optax.sgd(
learning_rate_schedule
)
elif config.opt_type == "adafactor":
return optax.adafactor(
learning_rate_schedule
)
else:
raise ValueError(f"{config.opt_type=} is not supported.")

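A brief, hedged usage sketch of the new sgd and adafactor branches; the constant schedule and toy parameter tree are illustrative assumptions, not MaxText's real config or parameter structure.

import jax.numpy as jnp
import optax

schedule = optax.constant_schedule(1e-3)
params = {"w": jnp.zeros((4, 4))}
for ctor in (optax.sgd, optax.adafactor):
    opt = ctor(schedule)                 # both constructors accept a schedule as learning rate
    state = opt.init(params)             # standard optax GradientTransformation interface
    grads = {"w": jnp.ones((4, 4))}
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
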
2 changes: 1 addition & 1 deletion MaxText/tests/moe_test.py
@@ -119,7 +119,7 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
kernel_axes=(None, 'test'),
kernel_axes=('embed', 'mlp'),
dtype=cfg.dtype,
)
# print("jax.tree_util.tree_structure(variables)")
