Commit fa91b37: scale up
RissyRan committed Jun 17, 2024
1 parent 0ceeba2 commit fa91b37
Showing 4 changed files with 39 additions and 2 deletions.
30 changes: 30 additions & 0 deletions MaxText/configs/models/mixtral-moe-1t.yml
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-moe-1t

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 12
num_experts_per_tok: 2
decoder_block: "mistral"
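
For orientation, a minimal routing sketch (not the MaxText implementation) of what num_experts: 12 with num_experts_per_tok: 2 implies: every token scores all 12 experts, keeps only its top two, and renormalizes the gate weights over that pair. Names and shapes below are illustrative only.

```python
import jax
import jax.numpy as jnp

NUM_EXPERTS = 12          # matches num_experts above
NUM_EXPERTS_PER_TOK = 2   # matches num_experts_per_tok above

def route(gate_logits):
    # gate_logits: [num_tokens, NUM_EXPERTS]; pick each token's top-2 experts
    # and renormalize the gate weights over just those two.
    top_weights, expert_ids = jax.lax.top_k(gate_logits, NUM_EXPERTS_PER_TOK)
    top_weights = jax.nn.softmax(top_weights, axis=-1)
    return top_weights, expert_ids

logits = jax.random.normal(jax.random.PRNGKey(0), (4, NUM_EXPERTS))
weights, ids = route(logits)
print(weights.shape, ids.shape)  # (4, 2) (4, 2)
```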
4 changes: 2 additions & 2 deletions MaxText/layers/linears.py
@@ -375,8 +375,8 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):

def call_gmm(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): update the static default tile_size
-# tile_size = (512, 512, 512)
-tile_size = None
+tile_size = (512, 512, 512)
+# tile_size = None
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))

def gmm(inputs, kernel, group_sizes):
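
For context on the tile_size change above: call_gmm runs a grouped matrix multiply in which tokens, pre-sorted by expert, are multiplied block-by-block against each expert's kernel, and (512, 512, 512) is a static tile size handed to the underlying TPU kernel. The snippet below is only an untiled, readable reference of what such a grouped matmul computes, under toy shapes; it is not the actual kernel implementation used here.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative, untiled reference of a grouped matmul ("gmm"): rows of `inputs`
# are assumed sorted by expert, and group_sizes[g] rows are multiplied by
# expert g's kernel. The production kernel tiles this work (e.g. 512x512x512
# blocks) for efficiency; this loop is only a readable sketch.
def gmm_reference(inputs, kernel, group_sizes):
    # inputs: [total_tokens, k]; kernel: [num_experts, k, n]; group_sizes: [num_experts]
    outputs, start = [], 0
    for g, size in enumerate(group_sizes):
        outputs.append(inputs[start:start + size] @ kernel[g])
        start += size
    return jnp.concatenate(outputs, axis=0)

# Toy shapes (the real config uses emb_dim=4096, mlp_dim=14336, 12 experts).
inputs = jnp.ones((10, 8))
kernel = jnp.ones((12, 8, 16))
group_sizes = np.array([2, 1, 0, 3, 0, 1, 0, 0, 2, 0, 1, 0])  # sums to 10
print(gmm_reference(inputs, kernel, group_sizes).shape)  # (10, 16)
```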
7 changes: 7 additions & 0 deletions MaxText/pyconfig.py
@@ -256,11 +256,18 @@ def user_init(raw_keys):
raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"]

emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys["global_parameter_scale"])
+# raw_keys["num_experts"] = raw_keys["global_parameter_scale"] * raw_keys["num_experts"] // 2
raw_keys["emb_dim"] = 2**emb_scale * raw_keys["base_emb_dim"]
raw_keys["num_query_heads"] = 2**num_head_scale * raw_keys["base_num_query_heads"]
raw_keys["num_kv_heads"] = 2**num_head_scale * raw_keys["base_num_kv_heads"]
raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"]
raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"]
max_logging.log(f"num_experts is: {raw_keys['num_experts']}")
max_logging.log(f"emb_dim is: {raw_keys['emb_dim']}")
max_logging.log(f"num_query_heads is: {raw_keys['num_query_heads']}")
max_logging.log(f"num_kv_heads is: {raw_keys['num_kv_heads']}")
max_logging.log(f"mlp_dim is: {raw_keys['mlp_dim']}")
max_logging.log(f"num_decoder_layers is: {raw_keys['num_decoder_layers']}")

raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys)
raw_keys["num_slices"] = get_num_slices(raw_keys)
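
As a standalone illustration of the scaling arithmetic whose results the new log lines print: each derived dimension is its base value (here taken from the mixtral-moe-1t config above) multiplied by two raised to that axis's individual scale. The helper name and the example scales below are illustrative; get_individual_scales itself is not shown in this diff.

```python
# Illustrative sketch only: mirrors the derived-dimension math in user_init above,
# using the base values from mixtral-moe-1t.yml. Example scales are made up.
base = {
    "base_emb_dim": 4096,
    "base_num_query_heads": 32,
    "base_num_kv_heads": 8,
    "base_mlp_dim": 14336,
    "base_num_decoder_layers": 32,
}

def apply_scales(base, emb_scale, num_head_scale, mlp_dim_scale, layer_scale):
    return {
        "emb_dim": 2**emb_scale * base["base_emb_dim"],
        "num_query_heads": 2**num_head_scale * base["base_num_query_heads"],
        "num_kv_heads": 2**num_head_scale * base["base_num_kv_heads"],
        "mlp_dim": 2**mlp_dim_scale * base["base_mlp_dim"],
        "num_decoder_layers": 2**layer_scale * base["base_num_decoder_layers"],
    }

print(apply_scales(base, emb_scale=1, num_head_scale=1, mlp_dim_scale=1, layer_scale=1))
# {'emb_dim': 8192, 'num_query_heads': 64, 'num_kv_heads': 16,
#  'mlp_dim': 28672, 'num_decoder_layers': 64}
```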
Binary file added assets/tokenizer.mistral
Binary file not shown.
