Skip to content

Commit

Permalink
Add token count asserts for all endpoint tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mynhardt Burger <[email protected]>
  • Loading branch information
mynhardtburger committed Mar 8, 2024
1 parent 09e66fc commit 9cf0491
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
BOOTSTRAPPED_MODEL = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL)

INPUT = "The quick brown fox jumps over the lazy dog."
INPUT_TOKEN_COUNT = 36 + 2 # [CLS] Thequickbrownfoxjumpsoverthelazydog. [SEP]

MANY_INPUTS = [
"The quick brown fox jumps over the lazy dog.",
Expand All @@ -45,11 +46,16 @@
]

QUERY = "What is foo bar?"
QUERY_TOKEN_COUNT = 13 + 2 # [CLS] 13 normal [SEP]

QUERIES: List[str] = [
"Who is foo?",
"Where is the bar?",
]
QUERIES_TOKEN_COUNT = (9 + 2) + (
14 + 2
) # [CLS] Whoisfoo? [SEP], [CLS] Whereisthebar? [SEP]


# These are used to test that documents can handle different types in and out
TYPE_KEYS = "str_test", "int_test", "float_test", "nested_dict_test"
Expand Down Expand Up @@ -77,9 +83,16 @@
},
]

# The `text` and `_text` keys are extracted from DOCS as input to the tokenizer
# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP]
DOCS_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2)

# Use text or _text from DOCS for our test sentences
SENTENCES = [d.get("text", d.get("_text")) for d in DOCS]

# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP]
SENTENCES_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2)

## Tests ########################################################################


Expand Down Expand Up @@ -221,6 +234,7 @@ def test_run_embedding_type_check(loaded_model):
def test_run_embedding(loaded_model):
res = loaded_model.run_embedding(text=INPUT)
_assert_is_expected_embedding_result(res)
assert res.input_token_count == INPUT_TOKEN_COUNT


def test_run_embeddings_str_type(loaded_model):
Expand All @@ -234,6 +248,7 @@ def test_run_embeddings(loaded_model):
res = loaded_model.run_embeddings(texts=[INPUT])
assert isinstance(res.results.vectors, list)
_assert_is_expected_embeddings_results(res.results)
assert res.input_token_count == INPUT_TOKEN_COUNT


@pytest.mark.parametrize(
Expand All @@ -255,7 +270,8 @@ def test_run_rerank_query_type_error(query, docs, top_n, loaded_model):

def test_run_rerank_query_no_type_error(loaded_model):
"""no type error with list of string queries and list of dict documents"""
loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT


@pytest.mark.parametrize(
Expand All @@ -273,6 +289,7 @@ def test_run_rerank_query_top_n(top_n, expected, loaded_model):
res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
assert isinstance(res, RerankResult)
assert len(res.result.scores) == expected
assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT


def test_run_rerank_query_no_query(loaded_model):
Expand All @@ -296,6 +313,7 @@ def test_run_rerank_query(loaded_model):

types_found = _assert_valid_scores(scores)
_assert_types_found(types_found)
assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT


@pytest.mark.parametrize(
Expand All @@ -310,7 +328,8 @@ def test_run_rerank_queries_type_error(queries, docs, loaded_model):

def test_run_rerank_queries_no_type_error(loaded_model):
"""no type error with list of string queries and list of dict documents"""
loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99)
res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99)
assert res.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT


@pytest.mark.parametrize(
Expand All @@ -331,6 +350,7 @@ def test_run_rerank_queries_top_n(top_n, expected, loaded_model):
assert len(res.results) == len(QUERIES)
for result in res.results:
assert len(result.scores) == expected
assert res.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT


@pytest.mark.parametrize(
Expand Down Expand Up @@ -371,6 +391,7 @@ def test_run_rerank_queries(loaded_model):

# Make sure our document fields of different types made it in/out ok
_assert_types_found(types_found)
assert rerank_result.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT


def test_run_sentence_similarity(loaded_model):
Expand All @@ -381,6 +402,7 @@ def test_run_sentence_similarity(loaded_model):
assert len(scores) == len(SENTENCES)
for score in scores:
assert isinstance(score, float)
assert res.input_token_count == QUERY_TOKEN_COUNT + SENTENCES_TOKEN_COUNT


def test_run_sentence_similarities(loaded_model):
Expand All @@ -394,6 +416,7 @@ def test_run_sentence_similarities(loaded_model):
assert len(scores) == len(SENTENCES)
for score in scores:
assert isinstance(score, float)
assert res.input_token_count == QUERIES_TOKEN_COUNT + SENTENCES_TOKEN_COUNT


@pytest.mark.parametrize(
Expand Down

0 comments on commit 9cf0491

Please sign in to comment.