From 4c7fd20ea3ee690ea2f5d6ec13cea9e2bf26669e Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Thu, 7 Mar 2024 22:17:55 -0500 Subject: [PATCH] add test for sort order Signed-off-by: Mynhardt Burger --- .../modules/text_embedding/embedding.py | 4 +++ .../modules/text_embedding/test_embedding.py | 30 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 721939a0..8283389e 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -847,6 +847,9 @@ def encode( self.to(device) all_embeddings = [] + + # Sort sentences according to length, from longest to shortest + # OOM errors then occur at the start of encoding length_sorted_idx = np.argsort( [-self._text_length(sen) for sen in list_of_sentences] ) @@ -880,6 +883,7 @@ def encode( embeddings = embeddings.detach().cpu() all_embeddings.extend(embeddings) + # Restore original order all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] if convert_to_tensor: diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index 902c3079..4b89d878 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -817,3 +817,33 @@ def test_sum_token_count( ) assert token_count == expected_count + + +def test_encoding_order(loaded_model: EmbeddingModule): + """Confirm that encoding doesn't modify the original sort order""" + separate_embeddings = [loaded_model.run_embedding(text=i) for i in MANY_INPUTS] + combined_embeddings = loaded_model.run_embeddings(texts=MANY_INPUTS) + + separate_vectors = [ + e.to_dict()["result"]["data"]["values"] for e in separate_embeddings + ] + combined_vectors = [ + e["data"]["values"] for e in combined_embeddings.to_dict()["results"]["vectors"] + ] + + assert len(separate_vectors) == len( + combined_vectors + 
), "expected the same number of separate and combined embeddings" + + # test order by comparing value of individual embeddings in sequence + for i, e in enumerate(separate_vectors): + assert np.allclose(e, combined_vectors[i]) + + # test expected failure case by reordering + shifted_separate_vectors = separate_vectors[1:] + [separate_vectors[0]] + + for i, e in enumerate(shifted_separate_vectors): + assert e != separate_vectors[i], "expected order to have been altered" + assert not np.allclose( + e, combined_vectors[i] + ), "expected altered order to not match combined vectors"