Add embedding task #224
Conversation
Tag @gabe-l-hart and will tag Anjali if I can find her.
Signed-off-by: gkumbhat <[email protected]>
Return a list of vectors, one for each input sentence. Uses sentence-transformers. The data model allows for different float types (py, np.float32, np.float64). Signed-off-by: markstur <[email protected]> Co-authored-by: gkumbhat <[email protected]>
Signed-off-by: markstur <[email protected]>
Need more approx() wrappers to pass CI. Signed-off-by: markstur <[email protected]>
* More type checks
* Ensure to-JSON uses consistent keys, not varying one-of names
* More tests
Signed-off-by: markstur <[email protected]>
* Was not safe for use with an existing dir or empty path because errors lead to rmtree
* Added checks and tests
* Some additional cleanup
Signed-off-by: markstur <[email protected]>
* One sentence in, one vector out
* Use bootstrap/save to create a model config with model artifacts
* Simplified: removed the support for sentences/vectors; removed the hf_model download
Signed-off-by: markstur <[email protected]>
Signed-off-by: markstur <[email protected]>
Force-pushed from 0f760f5 to af35dd9
rebased to catch up to main (and so I can stack PRs with multi-task support)
```python
@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class Vector1D(DataObjectBase):
```
nit: we usually try to name the output object to convey the task-related output. Since this is directly the output of EmbeddingTask, can we rename this to EmbeddingVector?
We may also want to use Vector1D later on, so maybe it would make sense to go with EmbeddingResponse with Vector1D used in it.
Changed to EmbeddingResponse with Vector1D in it.

My opinion? I tend to dislike these extra levels like resp.result.data.values[2] when resp[2] would've been much nicer, considering we really just want a List[float] (more or less). But I guess it looks pretty good this new way.
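The extra nesting under discussion can be sketched with plain dataclasses standing in for the caikit data-model objects. These are simplified, hypothetical stand-ins (the real classes use @dataobject and DataObjectBase), just to show the access path:

```python
from dataclasses import dataclass
from typing import List


# Hypothetical stand-ins for the caikit data-model classes; the real
# ones are @dataobject classes, not plain dataclasses.
@dataclass
class PyFloatSequence:
    values: List[float]


@dataclass
class Vector1D:
    data: PyFloatSequence


@dataclass
class EmbeddingResponse:
    result: Vector1D


resp = EmbeddingResponse(
    result=Vector1D(data=PyFloatSequence(values=[0.1, 0.2, 0.3]))
)

# The nested access pattern described above: three levels to reach one float
third = resp.result.data.values[2]
```

The wrapper levels buy forward compatibility (e.g. swapping float sequence types) at the cost of a longer access chain than a bare List[float].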
```python
)

@classmethod
def from_embeddings(cls, embeddings):
```
nit: can we rename this function to from_vector?
Done
```python
f"model_path '{model_config_path}' is invalid",
)

model_config_path = os.path.abspath(
```
why do we need an absolute path here?

An example of how we use ModuleSaver: https://github.com/caikit/caikit-nlp/blob/main/caikit_nlp/modules/text_generation/text_generation_local.py#L482
The ModuleSaver docstring says it takes an "absolute path":

    model_path (str): The absolute path to the directory where the model
        will be saved. If this directory does not exist, it will be
        created.

It doesn't look like that is necessary today, but better to follow the doc.
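As a small illustration of following that docstring, resolving the caller-supplied path up front guarantees ModuleSaver sees an absolute path regardless of the current working directory (the path name here is hypothetical):

```python
import os

# Hypothetical relative path a caller might pass to save()
model_config_path = "models/my_embedding"

# Resolve against the current working directory, per the ModuleSaver docstring
model_config_path = os.path.abspath(model_config_path)

is_absolute = os.path.isabs(model_config_path)
```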
re: example using ModuleSaver... I'd recommend you do not use the ModuleSaver context manager. It isn't safe. caikit/caikit#525
```python
saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path})

# Save the model
artifacts_path = os.path.abspath(
```
is the absolute path required here?
Nope. Removed.
few small things
```diff
@@ -12,5 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# First Party
+from caikit.core import TaskBase, task
```
there's no task added to this file. Was this added for another particular reason?
fixed
```python
def test_vector1d_trick():
    """FYI -- The param check currently allows for objects with values using this trick"""
    dm.Vector1D(data=TRICK_SEQUENCE(values=[1.1, 2.2]))
```
is there anything we can assert on the resulting object to make sure this test is doing more than just not erroring?
I removed this test. It was redundant with below where the results are validated. Less confusing to just get rid of this one.
```python
from .embedding_retrieval_task import EmbeddingRetrievalTask
from caikit_nlp.data_model.embedding_vectors import Vector1D

logger = alog.use_channel("<EMBD_BLK>")
```
nit: probably don't need the <>, or there might be multiple bracket layers: https://github.com/IBM/alchemy-logging/blob/main/src/python/README.md#log-contexts
Thanks!
```diff
@@ -0,0 +1,96 @@
+"""Tests for sequence classification module
```
embedding or embedding retrieval module?

Also nit: should we rename the file to test_embedding to match the module file?
Fixed comment. Renamed test file to match. Did some other renaming for more consistency.
* Renames for more consistent naming
* Adding a level EmbeddingResult -> Vector1D for consistent naming and future use
* Remove an abspath() call that isn't needed for model save()
* Remove a redundant test (was confusing)
Signed-off-by: markstur <[email protected]>
Force-pushed from 548fd97 to 2d61326
@gkumbhat anything I need to do to unblock this one? Not sure if it helps if I click the "resolve conversation" buttons. Up to you.
```python
@module(
    "EEB12558-B4FA-4F34-A9FD-3F5890E9CD3F",
```
Interesting change I noticed: the id here contains all caps (plus numbers), while other modules use lowercase. There isn't a rule enforcing it, but it might be good to be consistent. 🤔
done (in-coming)
Left some small things, but other than that it looks good. Thanks @markstur
```python
"""
error.type_check("<NLP27491611E>", str, input=input)

return EmbeddingResult(Vector1D.from_vector(self.model.encode(input)))
```
do we need to do some validation on model.encode output before converting it to an embedding via vector, or is it guaranteed to always be a compatible vector?
It always returns a tensor that works with Vector1D.from_vector(). There is nothing to validate here. If the dtype is something we didn't expect, we use PyFloatSequence(). I could add a post_init so PyFloatSequence() evaluates each value, but I don't have a real case for that right now.
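The dtype dispatch described here can be sketched as follows: select a typed sequence when the dtype is recognized, and fall back to a plain Python float list (the PyFloatSequence() path) for anything unexpected. The names are illustrative, not the actual caikit_nlp classes:

```python
import numpy as np


def choose_float_sequence(vector):
    """Return (sequence_kind, values) based on the input array's dtype.

    Sketch only: the kind strings stand in for the data-model sequence
    classes discussed in this thread.
    """
    arr = np.asarray(vector)
    if arr.dtype == np.float32:
        return "NpFloat32Sequence", arr.tolist()
    if arr.dtype == np.float64:
        return "NpFloat64Sequence", arr.tolist()
    # Unexpected dtype: fall back to plain Python floats, like PyFloatSequence()
    return "PyFloatSequence", [float(x) for x in arr]


kind32, _ = choose_float_sequence(np.array([1.0, 2.0], dtype=np.float32))
kind_int, vals_int = choose_float_sequence(np.array([1, 2], dtype=np.int64))
```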
```python
)

# Get and update config (artifacts_path)
artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY)
```
this will always be None, right? Unless we are overriding the model?
right, I can remove this
```python
# Save the model
self.model.save(
    os.path.join(model_config_path, artifacts_path), create_model_card=True
```
how does the model card get saved and consumed?
I'll remove the option, but it's a model card that is typically stored with the model files, documenting what is there.
```python
    model_config_path.strip()
)  # No leading/trailing spaces sneaky weirdness

os.makedirs(model_config_path, exist_ok=False)
```
We can also use the saver.add_dirs functions from here: https://github.com/caikit/caikit/blob/a16f063f6155f0088eb9959a32b7f0871e89731d/caikit/core/modules/saver.py#L117

This will avoid the need to also make the absolute path etc. above.
We can also use the context-manager use of ModuleSaver, and that can take care of adding/making the directory as well. Example: https://github.com/caikit/caikit-nlp/blob/main/caikit_nlp/modules/text_generation/peft_prompt_tuning.py#L472
> We can also use the saver.add_dirs functions from here: https://github.com/caikit/caikit/blob/a16f063f6155f0088eb9959a32b7f0871e89731d/caikit/core/modules/saver.py#L117 This will avoid the need to also make the absolute path etc. above.

os.makedirs() is used because in-place updates are not safe to run on existing dirs. I use exist_ok=False to enforce that, where ModuleSaver would not, and it can remove existing file trees on exceptions. Not cool.

The abspath() is per the ModuleSaver param docstrings. strip() should not be necessary but seems better than testing what blanks would do for user experience.

Note: the net of these things (and below) is that using saver.add_dir() would only add a line of code and not save anything. Otherwise add_dir() would be fine.
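The exist_ok=False guard being defended above can be shown in isolation: the first creation succeeds, and a second attempt on the same path fails fast instead of silently reusing (and, on a later error path, potentially rmtree-ing) a directory that already existed:

```python
import os
import tempfile


def make_model_dir(model_config_path: str) -> str:
    """Create the model config directory, refusing to touch an existing one."""
    model_config_path = os.path.abspath(model_config_path.strip())
    # exist_ok=False: raises FileExistsError rather than reusing the directory
    os.makedirs(model_config_path, exist_ok=False)
    return model_config_path


base = tempfile.mkdtemp()
target = os.path.join(base, "new_model")
created = make_model_dir(target)

try:
    make_model_dir(target)  # second call must fail: dir already exists
    raised = False
except FileExistsError:
    raised = True
```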
> We can also use the context-manager use of ModuleSaver, and that can take care of adding/making the directory as well. Example: https://github.com/caikit/caikit-nlp/blob/main/caikit_nlp/modules/text_generation/peft_prompt_tuning.py#L472
context for ModuleSaver is pretty dangerous here. Use save() wrong and ModuleSaver will wipe out a directory that might be important.
but that is related to os.mkdirs only, right? So that can be fixed by adding this as an option in the context manager, or in the caikit repo itself? Maybe club it with your DM PR in caikit?
* Removed some unnecessary things in save()
* Lowercase module GUID for consistency
Signed-off-by: markstur <[email protected]>
@gkumbhat Updated based on feedback. I don't want to use some of the saver() code until it is fixed to be safe and even revert partial changes, but we did get some unnecessary stuff out of save(). Thanks!
LGTM
Adds an embedding-retrieval task to get an embedding vector for a sentence.

Data model is based on Gabe's feedback in #39

This embedding service will be extended in separate PRs with multi-task support. This service can also support sentence-similarity and reranking.