Skip to content

Commit

Permalink
Adds tokenization task
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Jun 27, 2024
1 parent 358dfbc commit 1655446
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
SentenceSimilarityResult,
SentenceSimilarityResults,
SentenceSimilarityScores,
Token,
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import (
EmbeddingTask,
Expand All @@ -50,6 +52,7 @@
RerankTasks,
SentenceSimilarityTask,
SentenceSimilarityTasks,
TokenizationTask,
)
import alog

Expand Down Expand Up @@ -120,6 +123,7 @@ class TruncatedTokensTuple(NamedTuple):
SentenceSimilarityTasks,
RerankTask,
RerankTasks,
TokenizationTask,
],
)
class EmbeddingModule(ModuleBase):
Expand Down Expand Up @@ -192,6 +196,29 @@ def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argumen
"sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
}

@TokenizationTask.taskmethod()
def run_tokenizer(
self,
text: str,
) -> TokenizationResults:
"""Run tokenization task against the model
Args:
text: str
Text to tokenize
Returns:
TokenizationResults
The token count
"""
result = self.model.tokenizer.encode_plus(text, return_offsets_mapping=True)

mapping = [
interv for interv in result.offset_mapping if (interv[1] - interv[0]) > 0
]
tokens = [Token(start=i[0], end=i[1], text=text[i[0] : i[1]]) for i in mapping]

return TokenizationResults(token_count=len(result.input_ids), results=tokens)

@classmethod
def _get_ipex(cls, ipex_flag):
"""Get IPEX optimization library if enabled and available, else return False
Expand Down

0 comments on commit 1655446

Please sign in to comment.