
Add input_token_count to results
Signed-off-by: Mynhardt Burger <[email protected]>
mynhardtburger committed Mar 5, 2024
1 parent 8ca93c3 commit 3c1f4b9
Showing 1 changed file with 99 additions and 43 deletions.
caikit_nlp/modules/text_embedding/embedding.py (142 changes: 99 additions & 43 deletions)
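
For context, this commit threads an input token count through every embedding, sentence-similarity, and rerank response. A minimal sketch of how a caller might read the new field follows, assuming EmbeddingModule is importable from caikit_nlp.modules.text_embedding as in this repository's layout; the model name is illustrative, not part of the commit.

from caikit_nlp.modules.text_embedding import EmbeddingModule

module = EmbeddingModule.bootstrap("sentence-transformers/all-MiniLM-L6-v2")  # illustrative model
result = module.run_embedding(text="hello world", truncate_input_tokens=0)
vector = result.result                  # Vector1D embedding, unchanged by this commit
tokens_used = result.input_token_count  # new: tokens consumed by the request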
@@ -13,13 +13,15 @@
# limitations under the License.

# Standard
from typing import List, Optional, Union
from collections.abc import Iterable, Sized
from typing import Callable, List, NamedTuple, Optional, Tuple, TypeVar, Union
import importlib
import os
import time

# Third Party
from torch.backends import mps
from transformers import BatchEncoding
import numpy as np
import torch

@@ -56,6 +58,7 @@
logger = alog.use_channel("TXT_EMB")
error = error_handler.get(logger)


# To avoid dependency problems, make sentence-transformers an optional import and
# defer any ModuleNotFoundError until someone actually tries to init a model with this module.
try:
Expand Down Expand Up @@ -87,6 +90,15 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
DEVICE = embedding_cfg.get("device", "")

RT = TypeVar("RT") # return type


class EmbeddingResultTuple(NamedTuple):
"""Output of SentenceTransformerWithTruncate.encode()"""

embedding: np.ndarray
input_token_count: int
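
# Illustrative sketch, not part of this commit: as a NamedTuple, the result can be
# unpacked positionally or read by field name (the array shape here is arbitrary).
#
#     import numpy as np
#
#     result = EmbeddingResultTuple(embedding=np.zeros((2, 3)), input_token_count=7)
#     embeddings, input_token_count = result  # positional unpacking, as used below
#     result.input_token_count                # 7, via named access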


@module(
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
@@ -102,7 +114,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
],
)
class EmbeddingModule(ModuleBase):

# Retry count if enabled to try again (was for thread contention errors)
RETRY_COUNT = max(RETRIES, 0) # Ensure non-negative, before using in loop!

@@ -177,7 +188,7 @@ def _get_ipex(cls, ipex_flag):
f"IPEX enabled in env, but skipping ipex.optimize() because "
f"import intel_extension_for_pytorch failed with exception: {ie}"
)
logger.warning(msg, exc_info=1)
logger.warning(msg, exc_info=True)

return ret

@@ -210,7 +221,6 @@ def _get_backend(use_ipex, use_device):

@staticmethod
def _optimize(model, ipex, device, autocast, pt2_compile):

if ipex:
if autocast: # IPEX performs best with autocast using bfloat16
model = ipex.optimize(
@@ -233,7 +243,7 @@ def _optimize(model, ipex, device, autocast, pt2_compile):
logger.warning(warn_msg, exc_info=True)
return model

def _with_retry(self, fn, *args, **kwargs):
def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT:
first_exception = None
for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed)
try:
@@ -253,7 +263,7 @@ def _with_retry(self, fn, *args, **kwargs):
exception=first_exception,
)
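
# Illustrative sketch, not part of this commit: the Callable[..., RT] -> RT
# annotation keeps the wrapper transparent to type checkers, which infer the
# wrapped callable's return type. A standalone version of the same pattern:
#
#     from typing import Callable, TypeVar
#
#     RT = TypeVar("RT")
#
#     def with_retry(fn: Callable[..., RT], *args, retries: int = 3, **kwargs) -> RT:
#         first_exception = None
#         for _ in range(1 + retries):  # one attempt plus retries
#             try:
#                 return fn(*args, **kwargs)
#             except Exception as err:  # pylint: disable=broad-exception-caught
#                 first_exception = first_exception or err
#         raise first_exception
#
#     count: int = with_retry(len, [1, 2, 3])  # type checker infers int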

def _encode_with_retry(self, *args, **kwargs):
def _encode_with_retry(self, *args, **kwargs) -> EmbeddingResultTuple:
"""All encode calls should use this for consistent param adding and retry loop"""

# Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
@@ -287,12 +297,13 @@ def run_embedding(
"""
error.type_check("<NLP27491611E>", str, text=text)

embeddings = self._encode_with_retry(
embeddings, input_token_count = self._encode_with_retry(
text, truncate_input_tokens=truncate_input_tokens
)
return EmbeddingResult(
result=Vector1D.from_vector(embeddings),
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)

@EmbeddingTasks.taskmethod()
@@ -321,12 +332,15 @@ def run_embeddings(
): # encode allows str, but the result would lack a dimension
texts = [texts]

embeddings = self._encode_with_retry(
embeddings, input_token_count = self._encode_with_retry(
texts, truncate_input_tokens=truncate_input_tokens
)
vectors = [Vector1D.from_vector(e) for e in embeddings]

return EmbeddingResults(
results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID
results=ListOfVector1D(vectors=vectors),
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)

@SentenceSimilarityTask.taskmethod()
@@ -352,17 +366,20 @@ def run_sentence_similarity(
SentenceSimilarityResult: Similarity scores for each sentence.
"""

source_embedding = self._encode_with_retry(
source_embedding, source_token_count = self._encode_with_retry(
source_sentence, truncate_input_tokens=truncate_input_tokens
)
embeddings = self._encode_with_retry(
embeddings, embeddings_token_count = self._encode_with_retry(
sentences, truncate_input_tokens=truncate_input_tokens
)

input_token_count = source_token_count + embeddings_token_count
res = cos_sim(source_embedding, embeddings)

return SentenceSimilarityResult(
result=SentenceSimilarityScores(scores=res.tolist()[0]),
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)
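
# Illustrative sketch, not part of this commit: cos_sim returns one row of scores
# per source embedding, so res.tolist()[0] is the score of the single source
# sentence against each target sentence.
#
#     from sentence_transformers.util import cos_sim
#     import torch
#
#     source = torch.tensor([[1.0, 0.0]])
#     targets = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
#     cos_sim(source, targets).tolist()[0]  # [1.0, 0.0, ~0.707]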

@SentenceSimilarityTasks.taskmethod()
@@ -389,18 +406,21 @@ def run_sentence_similarities(
Each one contains the source-sentence's score for each sentence in order.
"""

source_embedding = self._encode_with_retry(
source_embedding, source_token_count = self._encode_with_retry(
source_sentences, truncate_input_tokens=truncate_input_tokens
)
embeddings = self._encode_with_retry(
embeddings, embeddings_token_count = self._encode_with_retry(
sentences, truncate_input_tokens=truncate_input_tokens
)

input_token_count = source_token_count + embeddings_token_count
res = cos_sim(source_embedding, embeddings)
float_list_list = res.tolist()

return SentenceSimilarityResults(
results=[SentenceSimilarityScores(fl) for fl in float_list_list],
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)

@RerankTask.taskmethod()
@@ -464,17 +484,22 @@ def run_rerank_query(
return_documents=return_documents,
return_queries=return_query,
return_text=return_text,
).results
)

if results:
return RerankResult(result=results[0], producer_id=self.PRODUCER_ID)
if results.results:
return RerankResult(
result=results.results[0],
producer_id=self.PRODUCER_ID,
input_token_count=results.input_token_count,
)

RerankResult(
producer_id=self.PRODUCER_ID,
result=RerankScore(
scores=[],
query=query if return_query else None,
),
producer_id=self.PRODUCER_ID,
input_token_count=results.input_token_count,
)

@RerankTasks.taskmethod()
@@ -546,21 +571,19 @@ def get_text(doc):

doc_texts = [get_text(doc) for doc in documents]

doc_embeddings = normalize_embeddings(
self._encode_with_retry(
doc_texts,
truncate_input_tokens=truncate_input_tokens,
convert_to_tensor=True,
).to(self.model.device)
doc_embeddings, doc_token_count = self._encode_with_retry(
doc_texts,
truncate_input_tokens=truncate_input_tokens,
convert_to_tensor=True,
)
doc_embeddings = normalize_embeddings(doc_embeddings.to(self.model.device))

query_embeddings = normalize_embeddings(
self._encode_with_retry(
queries,
truncate_input_tokens=truncate_input_tokens,
convert_to_tensor=True,
).to(self.model.device)
query_embeddings, query_token_count = self._encode_with_retry(
queries,
truncate_input_tokens=truncate_input_tokens,
convert_to_tensor=True,
)
query_embeddings = normalize_embeddings(query_embeddings.to(self.model.device))

res = semantic_search(
query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score
@@ -588,8 +611,13 @@ def add_query(q):
)
for q, r in enumerate(res)
]
input_token_count = doc_token_count + query_token_count

return RerankResults(results=results, producer_id=self.PRODUCER_ID)
return RerankResults(
results=results,
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)
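
# Illustrative sketch, not part of this commit: the rerank path above scores
# queries against documents with semantic_search over normalized embeddings,
# where dot_score on unit vectors equals cosine similarity. The model name
# below is illustrative.
#
#     from sentence_transformers import SentenceTransformer, util
#
#     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#     q_emb = util.normalize_embeddings(model.encode(["a query"], convert_to_tensor=True))
#     d_emb = util.normalize_embeddings(model.encode(["doc one", "doc two"], convert_to_tensor=True))
#     hits = util.semantic_search(q_emb, d_emb, top_k=2, score_function=util.dot_score)
#     # hits[0] is a list of {"corpus_id": ..., "score": ...} dicts for the first query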

@classmethod
def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
@@ -639,9 +667,26 @@ def save(self, model_path: str, *args, **kwargs):


class SentenceTransformerWithTruncate(SentenceTransformer):
@staticmethod
def _sum_token_count(
tokenized: Union[BatchEncoding, Iterable[BatchEncoding]],
) -> int:
error.type_check(
"<NLP82314993E>",
BatchEncoding,
Iterable[BatchEncoding],
tokenized=tokenized,
)

if isinstance(tokenized, BatchEncoding):
return len(tokenized.tokens())

if isinstance(tokenized, Iterable):
return sum([len(t.tokens()) for t in tokenized])
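
# Illustrative sketch, not part of this commit: a BatchEncoding holds one token
# sequence per input string, so per-sequence counts can be read from input_ids
# (or from tokens(i) on a fast tokenizer). The model name below is illustrative.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#     batch = tokenizer(["first sentence", "a second, longer sentence"])
#     sum(len(ids) for ids in batch["input_ids"])  # total tokens, special tokens included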

def _truncate_input_tokens(
self, truncate_input_tokens, texts: List[str]
) -> List[str]:
self, truncate_input_tokens: int, texts: List[str]
) -> Tuple[BatchEncoding, int]:
"""Truncate input tokens
Args:
truncate_input_tokens: int
Expand Down Expand Up @@ -728,18 +773,19 @@ def _truncate_input_tokens(
f"maximum sequence length for this model ({tokens} > {max_tokens})."
),
)
input_token_count = self._sum_token_count(tokenized)

return tokenized
return tokenized, input_token_count

def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
device: str = None,
device: Optional[str] = None,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
truncate_input_tokens: Optional[int] = 0,
) -> np.ndarray:
truncate_input_tokens: int = 0,
) -> EmbeddingResultTuple:
"""
Computes sentence embeddings
@@ -767,26 +813,36 @@ def encode(
convert_to_numpy = False

input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
if isinstance(sentences, str) or not isinstance(
sentences, Sized
): # Cast an individual sentence to a list with length 1
sentences = [sentences]
list_of_sentences = [sentences]
input_was_string = True
elif isinstance(sentences, list):
list_of_sentences = sentences

if device is None:
device = self._target_device

self.to(device)

all_embeddings = []
length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
length_sorted_idx = np.argsort(
[-self._text_length(sen) for sen in list_of_sentences]
)
sentences_sorted: list[str] = [
list_of_sentences[idx] for idx in length_sorted_idx
]

input_token_count = 0

for start_index in range(0, len(sentences), batch_size):
sentences_batch = sentences_sorted[start_index : start_index + batch_size]
features = self._truncate_input_tokens(
features, token_count = self._truncate_input_tokens(
truncate_input_tokens, sentences_batch
)
input_token_count += token_count

features = batch_to_device(features, device)

if AUTOCAST:
@@ -814,4 +870,4 @@ def encode(
if input_was_string:
all_embeddings = all_embeddings[0]

return all_embeddings
return all_embeddings, input_token_count
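
For reference, encode() above follows the usual sentence-transformers pattern: inputs are sorted by descending length before batching so each batch pads to a similar length, and the outputs are put back into the original order afterwards. A standalone sketch of that pattern, independent of any model (len() stands in for the embedding call):

import numpy as np

texts = ["a much longer example sentence", "short", "medium length text"]
batch_size = 2

length_sorted_idx = np.argsort([-len(t) for t in texts])  # longest first
texts_sorted = [texts[i] for i in length_sorted_idx]

results = []
for start in range(0, len(texts_sorted), batch_size):
    batch = texts_sorted[start : start + batch_size]
    results.extend(len(t) for t in batch)  # stand-in for encoding a batch

results = [results[i] for i in np.argsort(length_sorted_idx)]  # undo the sort
print(results)  # [30, 5, 18] -- lengths back in the original input order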
