Skip to content

Commit

Permalink
add: preserve input text to local inference
Browse files Browse the repository at this point in the history
Signed-off-by: Sukriti-Sharma4 <[email protected]>
  • Loading branch information
Ssukriti committed Nov 9, 2023
1 parent 4a5b2f8 commit 642028e
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion caikit_nlp/toolkit/text_generation/model_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def generate_text_func(
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[List[str]] = None,
preserve_input_text: bool = True,
**kwargs,
):
"""
Expand All @@ -164,6 +165,9 @@ def generate_text_func(
Caikit producer id associated with the module
eos_token: str
End of sequence token to be used with generation
preserve_input_text: bool
    Whether or not the source string should be contained in the generated output,
    e.g., as a prefix. Default True. (Source string will appear as a prefix)
{}
Returns:
GeneratedTextResult
Expand Down Expand Up @@ -235,6 +239,18 @@ def generate_text_func(
for g in generate_ids
]

if preserve_input_text!=True:
prompt_length = len(
tokenizer.decode(
inputs["input_ids"][0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
)
generated_text = preds[0][prompt_length:]
else:
generated_text = preds[0]

if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
generate_ids[0, -1] == tokenizer.eos_token_id
):
Expand All @@ -251,7 +267,7 @@ def generate_text_func(

return GeneratedTextResult(
generated_tokens=token_count,
generated_text=preds[0],
generated_text=generated_text,
finish_reason=finish_reason,
producer_id=producer_id,
input_token_count=input_token_count,
Expand Down

0 comments on commit 642028e

Please sign in to comment.