model_run_utils: use enum values as finish_reason
fixes caikit#245

Signed-off-by: Daniele Trifirò <[email protected]>
dtrifiro committed Nov 3, 2023
1 parent bb02505 commit 9c0330c
Showing 2 changed files with 49 additions and 3 deletions.
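This commit swaps hard-coded strings ("EOS_TOKEN", "STOP_SEQUENCE", "MAX_TOKENS") for the corresponding members of caikit's FinishReason enum. A minimal sketch of why the enum is the safer choice, using a plain enum.Enum stand-in rather than caikit's protobuf-backed data model (the pick_reason helper is illustrative only, not part of the repo):

# Minimal sketch of the string-to-enum switch; this FinishReason is a plain
# enum.Enum stand-in, not caikit's protobuf-backed data model enum.
from enum import Enum


class FinishReason(Enum):
    EOS_TOKEN = 1
    STOP_SEQUENCE = 2
    MAX_TOKENS = 3


def pick_reason(hit_eos: bool, hit_stop_sequence: bool) -> FinishReason:
    """Mirror the if/elif/else dispatch in generate_text_func below."""
    if hit_eos:
        return FinishReason.EOS_TOKEN
    if hit_stop_sequence:
        return FinishReason.STOP_SEQUENCE
    return FinishReason.MAX_TOKENS


# A typo such as FinishReason.EOS_TOKN now fails loudly with an
# AttributeError instead of silently propagating a bad string.
assert pick_reason(True, False) is FinishReason.EOS_TOKEN
assert pick_reason(False, False) is FinishReason.MAX_TOKENS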
7 changes: 4 additions & 3 deletions caikit_nlp/toolkit/text_generation/model_run_utils.py
@@ -27,6 +27,7 @@
 from caikit.core.data_model.producer import ProducerId
 from caikit.core.exceptions import error_handler
 from caikit.interfaces.nlp.data_model import (
+    FinishReason,
     GeneratedTextResult,
     GeneratedTextStreamResult,
     TokenStreamDetails,
@@ -237,16 +238,16 @@ def generate_text_func(
     if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
         generate_ids[0, -1] == tokenizer.eos_token_id
     ):
-        finish_reason = "EOS_TOKEN"
+        finish_reason = FinishReason.EOS_TOKEN
     elif ("stopping_criteria" in gen_optional_params) and (
         gen_optional_params["stopping_criteria"](
             generate_ids,
             None,  # scores, unused by SequenceStoppingCriteria
         )
     ):
-        finish_reason = "STOP_SEQUENCE"
+        finish_reason = FinishReason.STOP_SEQUENCE
     else:
-        finish_reason = "MAX_TOKENS"
+        finish_reason = FinishReason.MAX_TOKENS

     return GeneratedTextResult(
         generated_tokens=token_count,
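For context on the STOP_SEQUENCE branch: a transformers StoppingCriteriaList is callable with (input_ids, scores), and the code above re-probes it after generation to see whether a stop sequence ended the run. A hedged sketch of that pattern (StopOnTokenId and the token ids are illustrative, not from the repo):

# Hedged sketch of re-checking a transformers StoppingCriteriaList after
# generation, as the diff above does; StopOnTokenId is illustrative only.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokenId(StoppingCriteria):
    def __init__(self, stop_id: int):
        self.stop_id = stop_id

    def __call__(self, input_ids, scores, **kwargs):
        # True once the last generated token matches the stop id
        return input_ids[0, -1].item() == self.stop_id


criteria = StoppingCriteriaList([StopOnTokenId(stop_id=42)])
generate_ids = torch.tensor([[1, 2, 42]])
# scores is unused by this criterion, so None is safe to pass, mirroring
# the call in generate_text_func
assert criteria(generate_ids, None)  # -> finish_reason = STOP_SEQUENCE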
45 changes: 45 additions & 0 deletions tests/toolkit/text_generation/test_model_run_utils.py
@@ -0,0 +1,45 @@
+# Third Party
+import pytest
+
+# First Party
+from caikit.core.data_model.producer import ProducerId
+from caikit.interfaces.nlp.data_model import GeneratedTextResult
+
+# Local
+from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func
+from tests.fixtures import (
+    causal_lm_dummy_model,
+    causal_lm_train_kwargs,
+    seq2seq_lm_dummy_model,
+    seq2seq_lm_train_kwargs,
+)
+
+
+@pytest.mark.parametrize(
+    "model_fixture", ["seq2seq_lm_dummy_model", "causal_lm_dummy_model"]
+)
+@pytest.mark.parametrize(
+    "serialization_method,expected_type",
+    [
+        ("to_dict", dict),
+        ("to_json", str),
+        ("to_proto", GeneratedTextResult._proto_class),
+    ],
+)
+def test_generate_text_func_serialization_json(
+    request,
+    model_fixture,
+    serialization_method,
+    expected_type,
+):
+    model = request.getfixturevalue(model_fixture)
+    generated_text = generate_text_func(
+        model=model.model,
+        tokenizer=model.tokenizer,
+        producer_id=ProducerId("TextGeneration", "0.1.0"),
+        eos_token="<\n>",
+        text="What is the boiling point of liquid Nitrogen?",
+    )
+
+    serialized = getattr(generated_text, serialization_method)()
+    assert isinstance(serialized, expected_type)
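The test drives serialization through getattr so one body covers to_dict, to_json, and to_proto. A stand-in illustration of that dispatch pattern (the Result class below is hypothetical, not a caikit type):

# The serialization method name arrives as a parametrized string and is
# resolved on the result object at runtime via getattr.
class Result:
    def to_dict(self) -> dict:
        return {"generated_text": "..."}

    def to_json(self) -> str:
        return '{"generated_text": "..."}'


for method_name, expected_type in [("to_dict", dict), ("to_json", str)]:
    serialized = getattr(Result(), method_name)()
    assert isinstance(serialized, expected_type)

With the repo's fixtures available, the new test can be run directly, e.g. pytest tests/toolkit/text_generation/test_model_run_utils.py.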
