Add rerank and sentence-similarity tasks to text embedding module #235
Conversation
    },
    output_type=RerankPrediction,
)
class RerankTask(TaskBase):
nit: Can you please add a bit of description of what this task is supposed to do, either as a docstring or as module docs?
added
caikit_nlp/data_model/reranker.py
Outdated
class RerankScore(DataObjectBase):
    """The score for one document (one query)"""

    document: JsonDict
Why does `document` need to be `JsonDict`?
The desired API is a document that is JSON with a `text` (or alternative `_text`) field used for ranking, while the rest of the document is typically returned reranked. `JsonDict` works for me for both gRPC and REST while allowing different input/output value types, even nested ones. I'm not sure if you are recommending a preferred alternative.
If you are just wondering why not pass only `text` and return an index, then I understand (and agree), but that isn't the requested API for the rerank use case.
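To illustrate the shape being described (a hypothetical payload, not taken from the PR): only the `text`/`_text` field drives scoring, and every other field is arbitrary pass-through metadata.

```python
# A rerank "document" is an arbitrary JSON object. Only "text"
# (or the "_text" fallback) is used for ranking; everything else
# is pass-through metadata returned with the reranked result.
document = {
    "text": "Caikit is an AI toolkit for serving models.",
    "title": "About Caikit",  # arbitrary metadata, returned as-is
    "doc_id": 42,             # values may be any JSON type, even nested
}

# Fallback mirroring the module: prefer "text", else "_text", else ""
ranking_text = document.get("text") or document.get("_text", "")
```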
caikit_nlp/data_model/reranker.py
Outdated
class RerankQueryResult(DataObjectBase):
    """Result for one query in a rerank task"""

    scores: List[RerankScore]


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class RerankPrediction(DataObjectBase):
    """Result for a rerank task"""

    results: List[RerankQueryResult]
Is the `results` here for one query or one document, and what is the relation between one query, one document, and one result?
1 query with the top_n document results in order of relevance for that query.
Edit: I was looking at the wrong part of the code snippet. 1 query for n docs is my explanation for `RerankQueryResult`. I think the question was about `RerankPrediction`, which is a list of `RerankQueryResult` corresponding to the input list of queries.
I will expand the docstring and rename `RerankPrediction` --> `RerankPredictions`.
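As a sketch of the resulting nesting (plain dicts standing in for the data objects; all values are made up):

```python
# RerankPredictions: one RerankQueryResult per input query, in query
# order; each holds top_n RerankScore entries, most relevant first.
predictions = {
    "results": [
        {
            "query": "what is caikit?",
            "scores": [
                {"index": 2, "score": 0.91, "document": {"text": "doc about caikit"}},
                {"index": 0, "score": 0.47, "document": {"text": "unrelated doc"}},
            ],
        },
    ],
}
```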
done
    """Initialize
    This function gets called by `.load` and `.train` function
nit: indenting of "Initialize" vs the second line
nit: this function also gets called from the `bootstrap` function
Removed. It's not good practice to try to document all usages, and `init()` doesn't need "Initialize" called out.
    queries: List[str],
    documents: List[JsonDict],
    top_n: Optional[int] = None,
) -> RerankPrediction:
    """Run inference on model.
    Args:
        queries: List[str]
        documents: List[JsonDict]
        top_n: Optional[int]
A bit confused by the input and output requirements here. Can you please add in the docstring what the query is supposed to be and what the document is supposed to be? Also, in the `Returns` section, can you please add information about what the `RerankPrediction` is actually giving?
done and also renamed RerankPrediction -> RerankPredictions because this is used to provide a result for each of the queries (plural).
        top_n = len(documents)

    # Using input document dicts so get "text" else "_text" else default to ""
    doc_texts = [srd.get("text") or srd.get("_text", "") for srd in documents]
Is the requirement to handle `text` and `_text` because of `JsonDict`?
The comparison is typically with the "text" field of the JSON document. We also have cases where "_text" is used instead. So this implementation uses "text" if found, and otherwise uses "_text" as the alternate.
Currently, if neither is found, text="" is used. This almost makes sense, but if that is a real use case we might need to handle it better.
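One subtlety of the `or`-based fallback worth noting: an empty (or otherwise falsy) `"text"` value also falls through to `"_text"`, not just a missing key. A quick sketch with made-up documents:

```python
docs = [
    {"text": "primary text"},           # "text" present -> used
    {"_text": "alternate text"},        # no "text" -> "_text" used
    {"text": "", "_text": "fallback"},  # empty "text" is falsy -> "_text" used
    {"title": "no text at all"},        # neither field -> ""
]

# Same expression as in the module under review
doc_texts = [d.get("text") or d.get("_text", "") for d in docs]
# doc_texts == ["primary text", "alternate text", "fallback", ""]
```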
    doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True)
    doc_embeddings = doc_embeddings.to(self.model.device)
    doc_embeddings = normalize_embeddings(doc_embeddings)

    query_embeddings = self.model.encode(queries, convert_to_tensor=True)
    query_embeddings = query_embeddings.to(self.model.device)
    query_embeddings = normalize_embeddings(query_embeddings)
`device` is not an input to any of the entry functions (`load`, `bootstrap`), so using it to move tensors to any device seems unnecessary, since everything will be on CPU by default unless moved.
sentence-transformers automatically does the "if CUDA is available, use GPU" logic (with env var controls). So here we use the device that was set on `self.model.device` and move all the embeddings there with `.to()` as well before scoring. This follows the sentence-transformers optimization examples.
Performance evaluation for whether or not to use GPU here is not done yet, but we at least have environment control (e.g., per pod).
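For context, once both sets of embeddings are L2-normalized, relevance scoring reduces to a dot product (cosine similarity), and top_n reranking is an argsort. A minimal NumPy sketch with made-up 3-dimensional embeddings (not the module's actual code):

```python
import numpy as np

def normalize(v):
    """L2-normalize rows so a dot product equals cosine similarity."""
    return v / np.linalg.norm(v, axis=1, keepdims=True)

# Hypothetical embeddings: 2 queries, 3 documents, dim 3
query_embeddings = normalize(np.array([[1.0, 0.0, 0.0],
                                       [0.0, 1.0, 0.0]]))
doc_embeddings = normalize(np.array([[0.9, 0.1, 0.0],
                                     [0.0, 1.0, 0.1],
                                     [0.5, 0.5, 0.0]]))

# Cosine score matrix, shape (n_queries, n_docs)
scores = query_embeddings @ doc_embeddings.T

# For each query, document indices ranked by descending score (top_n rerank)
top_n = 2
ranked = np.argsort(-scores, axis=1)[:, :top_n]
# ranked == [[0, 2], [1, 2]]
```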
        self.model = model

    @classmethod
    def load(cls, model_path: str) -> "Rerank":
Since loading of this is exactly the same as `embedding`, can we directly use that module's function and internally initialize `embedding`'s module? That way we don't have to duplicate this code between the two modules.
Now that we have multi-task, I've combined them.
The embedding PR is open, and now I'm stacking this PR on that one. I wanted to do smaller PRs, but of course there is a dependency, and I think for rerank and sentence-similarity it makes sense to see both at once for review/discussion.
    if len(queries) < 1 or len(documents) < 1:
        return RerankPrediction([])

    if top_n is None or top_n < 1:
this default behavior is undocumented. Can we please add this in docstrings?
done
Checked feedback again and added more docstring for some things I had missed. Thanks @gkumbhat
Left a few minor comments. Looks great overall. My only real concern is using sentence-transformers, since it doesn't look like it's being actively maintained.
        return cls(data=data)

    @classmethod
    def from_json(cls, json_str):
Suggested change:

    ...
    from typing import Union, Any
    ...

-   def from_json(cls, json_str):
+   def from_json(cls, json_data: Union[dict[str, Any], str]) -> "Vector1D":
Adding type hints might also be useful for the other classmethods
done, but didn't rename the arg. Want to keep it in sync with the base.
caikit_nlp/data_model/reranker.py
Outdated
from caikit.core.data_model.json_dict import JsonDict


@dataobject()
Do these `@dataobject()` declarations also need arguments, as in other parts of the code? e.g. `@dataobject(package="caikit_data_model.caikit_nlp")`
good catch. Updated them all to specify this package (fwiw).
pyproject.toml
Outdated
@@ -24,6 +24,7 @@ dependencies = [
     "pandas>=1.5.0",
     "scikit-learn>=1.1",
     "scipy>=1.8.1",
+    "sentence-transformers~=2.2.2",
I'm not familiar with sentence-transformers, but it looks like the last release was in June 2022 and work on this project has been quite slow since 2021. Are we sure this is the only/best choice?
I think the popularity still goes to sentence-transformers (unless my searches are biased), but the maintenance activity does seem slow. The workaround is to use transformers feature-extraction directly and add mean pooling and normalization (and cosine, dot_score, ...). Hugging Face seems to prefer to defer to sentence-transformers for now, as far as I can tell. Old requests to create HF pipelines to replace sentence-transformers were rejected.
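The workaround described here (mean pooling plus normalization on top of raw transformer outputs) can be sketched with NumPy alone; the arrays below are made-up stand-ins for a model's token embeddings and attention mask:

```python
import numpy as np

# Hypothetical transformer output: (batch=1, seq_len=4, hidden=3)
token_embeddings = np.array([[[1.0, 2.0, 3.0],
                              [3.0, 2.0, 1.0],
                              [0.0, 0.0, 0.0],    # padding token
                              [0.0, 0.0, 0.0]]])  # padding token
attention_mask = np.array([[1, 1, 0, 0]])          # 1 = real token

# Mean pooling: sum real-token embeddings, divide by real-token count
mask = attention_mask[..., None]                   # broadcast over hidden dim
summed = (token_embeddings * mask).sum(axis=1)
counts = mask.sum(axis=1)
sentence_embedding = summed / counts               # mean over real tokens only

# L2-normalize so that a dot product equals cosine similarity
sentence_embedding /= np.linalg.norm(sentence_embedding, axis=1, keepdims=True)
```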
    @EmbeddingTasks.taskmethod()
    def run_embeddings(
        self, texts: List[str]  # pylint: disable=redefined-builtin
    ) -> ListOfVector1D:
Suggested change:

-   @EmbeddingTasks.taskmethod()
-   def run_embeddings(
-       self, texts: List[str]  # pylint: disable=redefined-builtin
-   ) -> ListOfVector1D:
+   @EmbeddingTasks.taskmethod()
+   def run_embeddings(self, texts: List[str]) -> ListOfVector1D:
Is disabling the redefined-builtin warning actually required?
Another good catch! Previously it was "input", which is a builtin. So I removed all the obsolete pylint comments.
    @EmbeddingTask.taskmethod()
    def run_embedding(
        self, text: str
    ) -> EmbeddingResult:  # pylint: disable=redefined-builtin
Is disabling the redefined-builtin warning actually required?
Suggested change:

-   @EmbeddingTask.taskmethod()
-   def run_embedding(
-       self, text: str
-   ) -> EmbeddingResult:  # pylint: disable=redefined-builtin
+   @EmbeddingTask.taskmethod()
+   def run_embedding(self, text: str) -> EmbeddingResult:
removed
# Local
from caikit_nlp.data_model.reranker import RerankPredictions, RerankQueryResult

logger = alog.use_channel("<SMPL_BLK>")
What does `SMPL_BLK` stand for?
It was copied from a sample ("SMPL"). I've removed these copy/paste errors because they were not needed.
Thanks again for the good eyes.
@task(
    required_parameters={
        "documents": List[JsonDict],
As above, a `Document` type might make this field easier to handle.
See above; the short version is that I need the flexibility of any key, any value, and this provides that.
random_numpy_vector1d_float32 = random_number_generator.random(
    DUMMY_VECTOR_SHAPE, dtype=np.float32
)
random_numpy_vector1d_float64 = random_number_generator.random(
    DUMMY_VECTOR_SHAPE, dtype=np.float64
)
random_python_vector1d_float = random_numpy_vector1d_float32.tolist()
These could be pytest fixtures
done
tests/data_model/test_reranker.py
Outdated
input_document = {
    "text": "this is the input text",
    "_text": "alternate _text here",
    "title": "some title attribute here",
    "anything": "another string attribute",
    "str_test": "test string",
    "int_test": 1234,
    "float_test": 9876.4321,
}

key = "".join(random.choices(string.ascii_letters, k=20))
value = "".join(random.choices(string.printable, k=100))
input_random_document = {
    "text": "".join(random.choices(string.printable, k=100)),
    "random_str": "".join(random.choices(string.printable, k=100)),
    "random_int": random.randint(-99999, 99999),
    "random_float": random.uniform(-99999, 99999),
}

input_documents = [input_document, input_random_document]

input_score = {
    "document": input_document,
    "index": 1234,
    "score": 9876.54321,
    "text": "this is the input text",
}

input_random_score = {
    "document": input_random_document,
    "index": random.randint(-99999, 99999),
    "score": random.uniform(-99999, 99999),
    "text": "".join(random.choices(string.printable, k=100)),
}

input_random_score_3 = {
    "document": {"text": "random foo3"},
    "index": random.randint(-99999, 99999),
    "score": random.uniform(-99999, 99999),
    "text": "".join(random.choices(string.printable, k=100)),
}

input_scores = [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)]
input_scores2 = [
    dm.RerankScore(**input_random_score),
    dm.RerankScore(**input_random_score_3),
]

input_result_1 = {"query": "foo", "scores": input_scores}
input_result_2 = {"query": "bar", "scores": input_scores2}
input_results = [
    dm.RerankQueryResult(**input_result_1),
    dm.RerankQueryResult(**input_result_2),
]

input_sentence_similarity_scores_1 = {
    "scores": [random.uniform(-99999, 99999) for _ in range(10)]
}
input_sentence_similarity_scores_2 = {
    "scores": [random.uniform(-99999, 99999) for _ in range(10)]
}

input_sentence_similarities_scores = [
    dm.SentenceScores(**input_sentence_similarity_scores_1),
    dm.SentenceScores(**input_sentence_similarity_scores_2),
]
All of these could be pytest fixtures
done
tests/data_model/test_reranker.py
Outdated
def assert_fields_match(data_object, inputs):
    for k, v in inputs.items():
        assert getattr(data_object, k) == inputs[k]
Suggested change:

-   def assert_fields_match(data_object, inputs):
-       for k, v in inputs.items():
-           assert getattr(data_object, k) == inputs[k]
+   def assert_fields_match(data_object, inputs):
+       assert all(getattr(data_object, key) == value for key, value in inputs.items())
done
I think I addressed all the comments in code and/or in reply. Sorry for the delay. I thought I was nearly done and then found a bunch that had been collapsed. Thanks again for the advice!
README.md
Outdated
| EmbeddingTask | 1. `TextEmbedding` | 1. text/embedding from a local sentence-transformers model |
| EmbeddingTasks | 1. `TextEmbedding` | 1. Same as EmbeddingTask but multiple sentences (texts) as input and corresponding list of outputs. |
Can we combine these? Like we do for prompt tuning?
Updated README.md, please take a look.
README.md
Outdated
| SentenceSimilarityTask | 1. `TextEmbedding` | 1. text/sentence-similarity from a local sentence-transformers model (Hugging Face style API returns scores only in order of input sentences) |
| SentenceSimilarityTasks | 1. `TextEmbedding` | 1. Same as SentenceSimilarityTask but multiple source_sentences (each to be compared to same list of sentences) as input and corresponding lists of outputs. |
| RerankTask | 1. `TextEmbedding` | 1. text/rerank from a local sentence-transformers model (Cohere style API returns top_n scores in order of relevance with index to source and optionally returning inputs) |
| RerankTasks | 1. `TextEmbedding` | 1. Same as RerankTask but multiple queries as input and corresponding lists of outputs. Same list of documents for all queries. |
Same here
Updated README.md, please take a look.
@dataobject(package="caikit_data_model.caikit_nlp")
class SentenceScores(DataObjectBase):
    scores: List[float]


@dataobject(package="caikit_data_model.caikit_nlp")
class SentenceListScores(DataObjectBase):

    results: List[SentenceScores]
Just realized `producer_id` isn't present in these. Maybe you can add it in your caikit interface PR.
Looks like that is only used in text_gen and only on the task output objects. I'll try following that pattern if that is the preferred thing going forward.
Left 2 small comments. Other than that, it LGTM.
@@ -0,0 +1,7 @@
# These can be installed with --no-deps.
hm, is this to separate out dependencies before we add other mechanisms?
I did this because there was some concern about dependencies, but I'll admit this probably is not the best way to deal with it. I'd actually recommend that I remove this and add extras handling like caikit in a separate PR. Is that preferred? Or is it better to just accept the dependencies?
def run_rerank_queries(
    self,
    queries: List[str],
    documents: List[JsonDict],
Are `JsonDict` and `dict` the same? Wondering if the pure-Python experience of using these functions will be problematic, i.e. if I want to use this function in a notebook or something, will I need to first convert my documents to `JsonDict`?
Caikit won't allow `Dict[str, Any]`. I would have to provide data objects to wrap `Any`, and that is exactly what `JsonDict` does for me. So this only works with `JsonDict`, unless there is some smaller piece where I can use `dict` and hide part of the `JsonDict` usage; but even if that is allowed, we'd still have `JsonDict` output, so I think that is not helpful.
So in general, this must be the caikit way, but if you think I can back off part of that and use plain Python, please clarify and I'll do some more testing to see where that breaks down. I'd love to use plain Python.
Signed-off-by: markstur <[email protected]>
* Less data objects and more primitives * Fixes str,str limitation in the input JSON * Add tests * More ready for review changes Signed-off-by: markstur <[email protected]>
* Tests * Work on save() Signed-off-by: markstur <[email protected]>
* Error message had wrong var in f-string message * Added test to catch that mistake * Added save tests and empty queries/docs test to complete coverage Signed-off-by: markstur <[email protected]>
* rerank run() will only do one query * adding reranks run_queries() for multiple queries with multi-task (coming soon) Signed-off-by: markstur <[email protected]>
…-task * The EmbeddingModule now does all 3 tasks (same loaded model) * An additional 3 tasks allow multiple texts, source_sentences, or queries. - the documents or sentences compared to are the same for each * Added more docs Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
* More docstrings to help code readers (doc viewers?) * Renamed RerankPrediction -> RerankPredictions since plural is better as it is being used for multiple queries each with a RerankQueryResult with scores. Signed-off-by: markstur <[email protected]>
* Some misc clean-up based on review feedback * Use pytest fixtures in the tests Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
* Handling ModuleNotFound so that we can move extras to extras in the future * Testing with pip install --nodeps of only the minimum (probably to be replaced with full import of sentence-transformers in extras in the future) Signed-off-by: markstur <[email protected]>
* Moved interfaces (tasks and datamodels) to caikit * Updated code here to the new interfaces with added producer_id and related changes to the data models Signed-off-by: markstur <[email protected]>
…rfaces Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
LGTM
Add rerank and sentence-similarity tasks to text embedding module.
This PR is stacked on the first embedding PR #224.
Since embedding, sentence-similarity, and rerank all work with a sentence-transformer model this module will load a model and is able to run multiple tasks.
In addition to the 3 (embed, rerank, sentence-similarity), there are another 3 so that each is not limited to a single input string. In the case of sentence-similarity and rerank, this means that a list of source_sentences or queries (respectively) is each applied against the same list of sentences or documents. So there is a real benefit: many queries can be sent against a large collection of documents. In the case of embeddings, this is simple batching.
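The singular-vs-plural task shape described above can be sketched in plain Python. Both function names and the word-overlap scoring below are hypothetical stand-ins (the real module scores with embedding similarity); only the shape of the plural task is being illustrated.

```python
def run_rerank_query(query, documents, top_n=None):
    """Hypothetical singular task: rank documents for one query.

    Stand-in scoring (word overlap) replaces the real embedding
    similarity; only the task *shape* matters here."""
    scores = []
    for i, doc in enumerate(documents):
        text = doc.get("text") or doc.get("_text", "")
        overlap = len(set(query.lower().split()) & set(text.lower().split()))
        scores.append({"index": i, "score": overlap, "document": doc})
    scores.sort(key=lambda s: -s["score"])  # most relevant first
    return scores[: top_n or len(documents)]

def run_rerank_queries(queries, documents, top_n=None):
    """Plural task: every query runs against the SAME document list;
    results come back one per query, in query order."""
    return [run_rerank_query(q, documents, top_n) for q in queries]
```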
Text Embedding Module
Implements the following tasks:
a list of outputs
More details for sentence-similarity and rerank...
sentence-similarity is a common and simple concept (see Hugging Face or Sentence Transformers)
class SentenceSimilarityTask(TaskBase):
    """Compare the source_sentence to each of the sentences.
    Result contains a list of scores in the order of the input sentences.
    """

@task(
    required_parameters={"source_sentences": List[str], "sentences": List[str]},
    output_type=SentenceListScores
)
class SentenceSimilarityTasks(TaskBase):
    """Compare each of the source_sentences to each of the sentences.
    Returns a list of results in the order of the source_sentences.
    Each result contains a list of scores in the order of the input sentences.
    """
rerank is less intuitive, but is popular for RAG and chaining. One of the more popular rerank APIs is from Cohere. This implementation is similar to their API.
class RerankTask(TaskBase):
    """Returns an ordered list ranking the most relevant documents for the query"""

class RerankTasks(TaskBase):
    """Returns an ordered list for each query ranking the most relevant documents for the query"""