Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLaMA-2 #101

Merged
merged 7 commits into from
Jul 21, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
back compat again
  • Loading branch information
BlackSamorez committed Jul 21, 2023
commit e0409366d23420e93e73ec32dfd219ec34c73050
10 changes: 8 additions & 2 deletions src/tensor_parallel/slicing_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,17 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
try:
num_kv = model_config.num_key_value_heads
q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
new_modeling = True
except AttributeError:
num_kv = model_config.num_attention_heads
q_per_kv = 1
new_modeling = False

gather_kv_across_ranks = CollectiveOperation(
world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
) # this operation ensures that we get attention cache for all heads on each device

return Config(
config = Config(
state_rules={
# LlamaAttention
r".*self_attn\.q_proj\.weight$": SplitInChunks(
Expand Down Expand Up @@ -389,13 +391,17 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
attr_rules={
r".*self_attn$": {
"hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
"num_key_value_heads": partial(split_num_heads, world_size=world_size),
"num_heads": lambda n, rank: q_per_kv
* split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
}
},
)

if new_modeling:
config.attr_rules[r".*self_attn$"]["num_key_value_heads"] = partial(split_num_heads, world_size=world_size)

return config


def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
# We can't use `RWConfig`` since it's custom code
Expand Down
Loading