Back compat llama
BlackSamorez committed Jul 21, 2023
1 parent 5d77165 commit 4ae93c0
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/tensor_parallel/slicing_configs.py
@@ -346,10 +346,13 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
     assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"

     world_size = len(devices)
     num_heads = model_config.num_attention_heads
     head_dim = model_config.hidden_size // model_config.num_attention_heads
-    num_kv = model_config.num_key_value_heads
-    q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+    try:
+        num_kv = model_config.num_key_value_heads
+        q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+    except AttributeError:
+        num_kv = model_config.num_attention_heads
+        q_per_kv = 1

     gather_kv_across_ranks = CollectiveOperation(
         world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
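The change wraps the read of num_key_value_heads in a try/except for backward compatibility: LlamaConfig objects from transformers versions that predate grouped-query attention (GQA) do not define num_key_value_heads, so the slicing config falls back to plain multi-head attention, where every query head has its own key/value head (q_per_kv = 1). Below is a minimal, self-contained sketch of that fallback; resolve_kv_heads is a hypothetical helper name used for illustration, not part of the library.

# Hypothetical standalone version of the fallback added in this commit.
from types import SimpleNamespace

def resolve_kv_heads(model_config):
    """Return (num_kv, q_per_kv), tolerating configs without GQA fields."""
    try:
        # Newer LlamaConfig objects expose num_key_value_heads (GQA).
        num_kv = model_config.num_key_value_heads
        q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
    except AttributeError:
        # Older configs predate GQA: one key/value head per query head.
        num_kv = model_config.num_attention_heads
        q_per_kv = 1
    return num_kv, q_per_kv

# GQA-style config (head counts here are illustrative): 64 query heads share 8 KV heads.
assert resolve_kv_heads(SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)) == (8, 8)
# Pre-GQA config without num_key_value_heads: falls back to standard MHA.
assert resolve_kv_heads(SimpleNamespace(num_attention_heads=32)) == (32, 1)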

0 comments on commit 4ae93c0
