Skip to content

Commit

Permalink
add test for sort order
Browse files Browse the repository at this point in the history
Signed-off-by: Mynhardt Burger <[email protected]>
  • Loading branch information
mynhardtburger committed Mar 8, 2024
1 parent 6547083 commit 4c7fd20
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,9 @@ def encode(
self.to(device)

all_embeddings = []

# Sort sentences according to length, from longest to shortest
# OOM errors then occurs at start of encoding
length_sorted_idx = np.argsort(
[-self._text_length(sen) for sen in list_of_sentences]
)
Expand Down Expand Up @@ -880,6 +883,7 @@ def encode(
embeddings = embeddings.detach().cpu()
all_embeddings.extend(embeddings)

# Restore original order
all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

if convert_to_tensor:
Expand Down
30 changes: 30 additions & 0 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,3 +817,33 @@ def test_sum_token_count(
)

assert token_count == expected_count


def test_encoding_order(loaded_model: EmbeddingModule):
"""Confirm that encoding doesn't modify the original sort order"""
separate_embeddings = [loaded_model.run_embedding(text=i) for i in MANY_INPUTS]
combined_embeddings = loaded_model.run_embeddings(texts=MANY_INPUTS)

separate_vectors = [
e.to_dict()["result"]["data"]["values"] for e in separate_embeddings
]
combined_vectors = [
e["data"]["values"] for e in combined_embeddings.to_dict()["results"]["vectors"]
]

assert len(separate_vectors) == len(
combined_vectors
), "expected the same number separate and combined embeddings"

# test order by comparing value of individual embeddings in sequence
for i, e in enumerate(separate_vectors):
assert np.allclose(e, combined_vectors[i])

# test expected failure case by reordering
shifted_separate_vectors = separate_vectors[1:] + [separate_vectors[0]]

for i, e in enumerate(shifted_separate_vectors):
assert e != separate_vectors[i], "expected order to be have been altered"
assert not np.allclose(
e, combined_vectors[i]
), "expected altered order to not match combined vectors"

0 comments on commit 4c7fd20

Please sign in to comment.