Fixed case where loaded lora adapter has no segments (#510)
tgaddair committed Jun 12, 2024
1 parent bd7db80 commit 25a29c6
Showing 2 changed files with 71 additions and 3 deletions.
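
In short: the load() classmethod in server/lorax_server/adapters/lora.py previously took max() directly over the ranks of the matching segments, which raises ValueError when no segment index has LoRA weights loaded for the layer. The fix gathers the ranks into a list first and returns None when that list is empty, and a regression test covers a batch whose only segment points at an adapter with no LoRA weights.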
5 changes: 4 additions & 1 deletion server/lorax_server/adapters/lora.py
@@ -251,8 +251,11 @@ def load(
         lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
         lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}

-        max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)
+        segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
+        if not segment_ranks:
+            return None

+        max_rank = max(segment_ranks)
         if prefill or max_rank > BGMV_MAX_RANK:
             use_sgmv = True
             lora_a_ptr = torch.tensor(
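For context on the change above, a minimal sketch (not taken from the repository; the adapter_weights and segment_indices values below are made up) of the failure mode the new guard avoids: when none of the batch's segment indices have LoRA weights loaded, max() over an empty sequence raises ValueError, whereas collecting the ranks into a list first allows an early return of None.

# Hypothetical illustration of the guard added above; values are invented.
adapter_weights = {}   # no adapter in this layer has LoRA weights loaded
segment_indices = [0]  # the batch references only adapter 0

segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
assert segment_ranks == []

# Old behavior: max() over the empty sequence raises ValueError.
# New behavior: bail out with None so callers can skip LoRA for this batch.
result = None if not segment_ranks else max(segment_ranks)
assert result is None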
69 changes: 67 additions & 2 deletions server/tests/utils/test_lora.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Dict, List, Optional, Type
from unittest import mock
import pytest

@@ -7,10 +7,40 @@

from lorax_server.adapters.lora import LoraWeights
from lorax_server.adapters.types import LORA
-from lorax_server.adapters.weights import AdapterBatchMetadata, LayerAdapterWeights
+from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
from lorax_server.utils.sgmv import MIN_RANK_CUSTOM


# Stand-ins for a non-LoRA adapter type, used by test_batched_lora_weights_no_segments below.
class FakeAdapterWeights(AdapterWeights):
    @classmethod
    def get_batch_types(cls) -> List[Type["FakeBatchAdapterWeights"]]:
        return [FakeBatchAdapterWeights]

    @property
    def speculative_tokens(self) -> int:
        return 0


class FakeBatchAdapterWeights(BatchAdapterWeights):
    @classmethod
    def has_adapter(cls, adapter_index: int) -> bool:
        return False

    @classmethod
    def key(cls) -> str:
        return "fake"

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: torch.Tensor,
    ) -> Optional["BatchAdapterWeights"]:
        return None


@pytest.mark.parametrize(
    "lora_ranks",
    [
@@ -71,4 +101,39 @@ def test_batched_lora_weights(lora_ranks: List[int]):
assert rd.segment_starts.shape == (2,)
assert rd.segment_ends.shape == (2,)


def test_batched_lora_weights_no_segments():
    batched_weights = LayerAdapterWeights()
    assert batched_weights.is_empty()

    h = 1024

    # fake weights
    idx = 0
    weights = FakeAdapterWeights()
    batched_weights.add_adapter(idx, weights)

    # lora weights
    idx = 1
    lora_rank = 16
    weights = LoraWeights(
        weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
        weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
        adapter_config=LoraConfig(r=lora_rank),
    )
    batched_weights.add_adapter(idx, weights)

    assert not batched_weights.is_empty()
    assert len(batched_weights.adapter_weights) == 2

    meta = AdapterBatchMetadata(
        adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64),
        adapter_set={0, 1},
        adapter_segments=torch.tensor([0, 4], dtype=torch.int64),
        segment_indices=[0],
    )

    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)

    print(data)
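Because segment_indices only references adapter 0, the fake adapter with no LoRA weights, the guarded load() path above is expected to return None for this batch, so get_data(...).get(LORA) presumably comes back as None rather than the call raising ValueError from max() on an empty sequence.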
