Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLaMA-2 #101

Merged
merged 7 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ dev =
pytest==6.2.5
pytest-forked
pytest-asyncio==0.16.0
accelerate==0.15.0
accelerate==0.20.3
black==22.3.0
isort==5.10.1
psutil
peft>=0.3.0
peft==0.3.0
einops==0.6.1
[options.packages.find]
where = src
29 changes: 23 additions & 6 deletions src/tensor_parallel/slicing_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,20 +346,31 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"

world_size = len(devices)
num_heads = model_config.num_attention_heads
head_dim = model_config.hidden_size // model_config.num_attention_heads
try:
num_kv = model_config.num_key_value_heads
q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
new_modeling = True
except AttributeError:
num_kv = model_config.num_attention_heads
q_per_kv = 1
new_modeling = False

gather_kv_across_ranks = CollectiveOperation(
world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
) # this operation ensures that we get attention cache for all heads on each device

return Config(
config = Config(
state_rules={
# LlamaAttention
r".*self_attn\.q_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
r".*self_attn\.q_proj\.weight$": SplitInChunks(
world_size=world_size, dim=0, chunk_size=q_per_kv * head_dim
),
r".*self_attn\.k_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
r".*self_attn\.v_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
r".*self_attn\.o_proj\.weight$": SplitInChunks(world_size=world_size, dim=1, chunk_size=head_dim),
r".*self_attn\.o_proj\.weight$": SplitInChunks(
world_size=world_size, dim=1, chunk_size=q_per_kv * head_dim
),
# LlamaFeedForward
r".*mlp\.gate_proj\.weight$": Split(world_size=world_size, dim=0),
r".*mlp\.down_proj\.weight$": Split(world_size=world_size, dim=1),
Expand All @@ -379,12 +390,18 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
},
attr_rules={
r".*self_attn$": {
"hidden_size": partial(split_inner_dim, num_heads=num_heads, world_size=world_size),
"num_heads": partial(split_num_heads, world_size=world_size),
"hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
"num_heads": lambda n, rank: q_per_kv
* split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
}
},
)

if new_modeling:
config.attr_rules[r".*self_attn$"]["num_key_value_heads"] = partial(split_num_heads, world_size=world_size)

return config


def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
# We can't use `RWConfig`` since it's custom code
Expand Down
34 changes: 19 additions & 15 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def all_equal(iterator):
) # basically asserting that all of those have the same config


def prepare_model(model_name, use_lora):
if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")
if model_name == "BlackSamorez/llama-2-tiny-testing" and transformers.__version__ < "4.31":
pytest.skip(f"Not testing {model_name} with transformers=={transformers.__version__}")

try:
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True).float()
except KeyError as err:
pytest.skip(f"Could not create model {model_name} with error {err}")
if use_lora:
if model_name == "gpt2":
pytest.skip("Not testing LoRA for gpt2")
model = add_lora(model, model_name)
return model


@pytest.mark.parametrize("use_lora", [False, True])
@pytest.mark.parametrize("use_config", [False, True])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
Expand All @@ -83,27 +100,14 @@ def all_equal(iterator):
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
"Salesforce/codegen-350M-mono",
"Bingsu/llama-190m-arch",
"BlackSamorez/llama-2-tiny-testing",
"BlackSamorez/falcon-40b-tiny-testing",
],
)
def test_forward_gpt2_like(use_lora, use_config, devices, model_name):
torch.manual_seed(0)

if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")

try:
model = (
AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True)
.float()
.to(devices[0])
)
except KeyError as err:
pytest.skip(f"Could not create model {model_name} with error {err}")
if use_lora:
if model_name == "gpt2":
pytest.skip("Not testing LoRA for gpt2")
model = add_lora(model, model_name)
model = prepare_model(model_name, use_lora)

inp1 = torch.randint(1, 1000, size=(2, 3), device=devices[0])
inp2 = torch.randint(1, 1000, size=(2, 1), device=devices[0])
Expand Down
Loading