From a61b28aa1f96d5d009324cc7914addb578b92814 Mon Sep 17 00:00:00 2001
From: Traian Rebedea
Date: Wed, 14 Feb 2024 14:05:38 +0200
Subject: [PATCH 1/5] Added NeMo Guardrails streaming support for LLMs deployed using HuggingFacePipeline.

---
 docs/user_guides/llm-support.md | 4 +-
 .../configs/llm/hf_pipeline_dolly/README.md | 1 +
 .../configs/llm/hf_pipeline_dolly/config.py | 41 ++++++++++++---
 .../configs/llm/hf_pipeline_dolly/config.yml | 3 ++
 .../configs/llm/hf_pipeline_dolly/rails.co | 5 ++
 nemoguardrails/llm/helpers.py | 21 ++++++++
 .../llm/providers/huggingface/__init__.py | 16 ++++++
 .../llm/providers/huggingface/streamers.py | 50 +++++++++++++++++++
 nemoguardrails/llm/providers/providers.py | 38 +++++++++++++-
 9 files changed, 169 insertions(+), 10 deletions(-)
 create mode 100644 nemoguardrails/llm/providers/huggingface/__init__.py
 create mode 100644 nemoguardrails/llm/providers/huggingface/streamers.py

diff --git a/docs/user_guides/llm-support.md b/docs/user_guides/llm-support.md
index 1f6a8e32a..1d773fa73 100644
--- a/docs/user_guides/llm-support.md
+++ b/docs/user_guides/llm-support.md
@@ -25,7 +25,7 @@ If you want to use an LLM and you cannot see a prompt in the [prompts folder](..
 | Dialog Rails | :heavy_check_mark: (0.74) | :heavy_check_mark: (0.83) | :heavy_check_mark: (0.82) | :heavy_check_mark: (0.77) | :heavy_check_mark: (0.76) | :exclamation: (0.45) | :exclamation: | :exclamation: (0.54) | :exclamation: (0.54) | :exclamation: (0.50) | :exclamation: (0.40) | :exclamation: _(DEPENDS ON MODEL)_ |
 | • Single LLM call | :heavy_check_mark: (0.83) | :heavy_check_mark: (0.81) | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
 | • Multi-step flow generation | _EXPERIMENTAL_ | _EXPERIMENTAL_ | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
-| Streaming | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: |
+| Streaming | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | - | - | :heavy_check_mark: | :heavy_check_mark: | - | - | - | - | :heavy_check_mark: |
 | Hallucination detection (SelfCheckGPT with AskLLM) | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
 | AskLLM rails | | | | | | | | | | | | |
 | • Jailbreak detection | :heavy_check_mark: (0.88) | :heavy_check_mark: (0.88) | :heavy_check_mark: (0.86) | :x: | :x: | :heavy_check_mark: (0.85) | :x: | :x: | :x: | :x: | :x: | :x: |
@@ -33,11 +33,13 @@ If you want to use an LLM and you cannot see a prompt in the [prompts folder](..
 | • Fact-checking | :heavy_check_mark: (0.81) | :heavy_check_mark: (0.82) | :heavy_check_mark: (0.81) | :heavy_check_mark: (0.80) | :x: | :heavy_check_mark: (0.83) | :x: | :x: | :x: | :x: | :x: | :exclamation: _(DEPENDS ON MODEL)_ |
 | AlignScore fact-checking _(LLM independent)_ | :heavy_check_mark: (0.89) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | ActiveFence moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| Llama Guard moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 
 Table legend:
 - :heavy_check_mark: - Supported (_The feature is fully supported by the LLM based on our experiments and tests_)
 - :exclamation: - Limited Support (_Experiments and tests show that the LLM is under-performing for that feature_)
 - :x: - Not Supported (_Experiments show very poor performance or no experiments have been done for the LLM-feature pair_)
+- \- - Not Applicable (_e.g., whether a model supports streaming depends on how it is deployed_)
 
 The performance numbers reported in the table above for each LLM-feature pair are as follows:
 - the banking dataset evaluation for dialog (topical) rails
diff --git a/examples/configs/llm/hf_pipeline_dolly/README.md b/examples/configs/llm/hf_pipeline_dolly/README.md
index ba9a8ad4f..99c1831ec 100644
--- a/examples/configs/llm/hf_pipeline_dolly/README.md
+++ b/examples/configs/llm/hf_pipeline_dolly/README.md
@@ -1,6 +1,7 @@
 # HuggingFace Pipeline with Dolly models
 
 This configuration uses the HuggingFace Pipeline LLM with the [dolly-v2-3b](https://huggingface.co/databricks/dolly-v2-3b) model.
+It also shows how to support streaming in NeMo Guardrails for LLMs deployed using HuggingFacePipeline.
 
 The `dolly-v2-3b` LLM model has been tested on the topical rails evaluation sets, results are available [here](../../../../nemoguardrails/eval/README.md).
 
diff --git a/examples/configs/llm/hf_pipeline_dolly/config.py b/examples/configs/llm/hf_pipeline_dolly/config.py
index 0907ebc31..84cbccfb6 100644
--- a/examples/configs/llm/hf_pipeline_dolly/config.py
+++ b/examples/configs/llm/hf_pipeline_dolly/config.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from functools import lru_cache
-from torch.cuda import device_count
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
 
 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
 from nemoguardrails.llm.providers import (
@@ -22,19 +22,46 @@
     register_llm_provider,
 )
 
+# Flag to test streaming in NeMo Guardrails of HuggingFacePipeline models
+ENABLE_STREAMING = True
+
 
 @lru_cache
 def get_dolly_v2_3b_llm():
-    repo_id = "databricks/dolly-v2-3b"
-    params = {"temperature": 0, "max_length": 1024}
+    name = "databricks/dolly-v2-3b"
+
+    config = AutoConfig.from_pretrained(name, trust_remote_code=True)
+    device = "cuda:0"
+    config.init_device = device
+    config.max_seq_len = 450
+
+    model = AutoModelForCausalLM.from_pretrained(
+        name,
+        config=config,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(name)
+    params = {"temperature": 0.01, "max_new_tokens": 100}
+
+    # Set this flag to False if you do not require streaming for the HuggingFacePipeline
+    if ENABLE_STREAMING:
+        from nemoguardrails.llm.providers.huggingface import AsyncTextIteratorStreamer
 
-    # Use the first CUDA-enabled GPU, if any
-    device = 0 if device_count() else -1
+        streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True)
+        params = {"temperature": 0.01, "max_new_tokens": 100, "streamer": streamer}
 
-    llm = HuggingFacePipelineCompatible.from_model_id(
-        model_id=repo_id, device=device, task="text-generation", model_kwargs=params
+    pipe = pipeline(
+        model=model,
+        task="text-generation",
+        tokenizer=tokenizer,
+        device=device,
+        do_sample=True,
+        use_cache=True,
+        **params,
     )
+    llm = HuggingFacePipelineCompatible(pipeline=pipe, model_kwargs=params)
+
     return llm
diff --git a/examples/configs/llm/hf_pipeline_dolly/config.yml b/examples/configs/llm/hf_pipeline_dolly/config.yml
index 499d15737..702d10531 100644
--- a/examples/configs/llm/hf_pipeline_dolly/config.yml
+++ b/examples/configs/llm/hf_pipeline_dolly/config.yml
@@ -2,6 +2,9 @@ models:
   - type: main
     engine: hf_pipeline_dolly
 
+# Remove this attribute or set it to False if streaming is not required
+streaming: True
+
 instructions:
   - type: general
     content: |
diff --git a/examples/configs/llm/hf_pipeline_dolly/rails.co b/examples/configs/llm/hf_pipeline_dolly/rails.co
index b4be33f2e..61ef579e4 100644
--- a/examples/configs/llm/hf_pipeline_dolly/rails.co
+++ b/examples/configs/llm/hf_pipeline_dolly/rails.co
@@ -8,6 +8,11 @@ define user ask capabilities
   "tell me what you can do"
   "tell me about you"
 
+define user ask general question
+  "How is the weather tomorrow?"
+  "Can you tell me which is the best movie this week"
+  "i would like to know the best scifi books of all time."
+
 define flow
   user express greeting
   bot express greeting
diff --git a/nemoguardrails/llm/helpers.py b/nemoguardrails/llm/helpers.py
index 434fdea1e..30e9ab9ea 100644
--- a/nemoguardrails/llm/helpers.py
+++ b/nemoguardrails/llm/helpers.py
@@ -32,6 +32,14 @@ def get_llm_instance_wrapper(
     """
 
     class WrapperLLM(LLM):
+        """The wrapper class needs to define any parameters that NeMo Guardrails should be able to set.
+
+        Currently, only streaming and temperature are added.
+        """
+
+        streaming: Optional[bool] = False
+        temperature: Optional[float] = 1.0
+
         @property
         def model_kwargs(self):
             """Return the model's kwargs.
@@ -50,12 +58,24 @@ def _llm_type(self) -> str:
             """
             return llm_type
 
+        def _modify_instance_kwargs(self):
+            """Modify the parameters of the llm_instance with the attributes set for the wrapper.
+
+            This will allow the actual LLM instance to use the parameters at generation.
+            TODO: Make this function more generic if needed.
+            """
+
+            if hasattr(llm_instance, "model_kwargs"):
+                llm_instance.model_kwargs["temperature"] = self.temperature
+                llm_instance.model_kwargs["streaming"] = self.streaming
+
         def _call(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
         ) -> str:
+            self._modify_instance_kwargs()
             return llm_instance._call(prompt, stop, run_manager)
 
         async def _acall(
@@ -64,6 +84,7 @@ async def _acall(
             stop: Optional[List[str]] = None,
             run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         ) -> str:
+            self._modify_instance_kwargs()
             return await llm_instance._acall(prompt, stop, run_manager)
 
     return WrapperLLM
diff --git a/nemoguardrails/llm/providers/huggingface/__init__.py b/nemoguardrails/llm/providers/huggingface/__init__.py
new file mode 100644
index 000000000..61d9267d9
--- /dev/null
+++ b/nemoguardrails/llm/providers/huggingface/__init__.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .streamers import AsyncTextIteratorStreamer
diff --git a/nemoguardrails/llm/providers/huggingface/streamers.py b/nemoguardrails/llm/providers/huggingface/streamers.py
new file mode 100644
index 000000000..a04c56feb
--- /dev/null
+++ b/nemoguardrails/llm/providers/huggingface/streamers.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+
+from transformers.generation.streamers import TextStreamer
+
+
+class AsyncTextIteratorStreamer(TextStreamer):
+    """
+    Simple async implementation for HuggingFace Transformers streamers.
+
+    This follows closely how transformers.generation.streamers.TextIteratorStreamer works,
+    with minor modifications to make it async.
+    """
+
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = asyncio.Queue()
+        self.stop_signal = None
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.text_queue.put_nowait(text)
+        if stream_end:
+            self.text_queue.put_nowait(self.stop_signal)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        value = await self.text_queue.get()
+        if value == self.stop_signal:
+            raise StopAsyncIteration()
+        else:
+            return value
diff --git a/nemoguardrails/llm/providers/providers.py b/nemoguardrails/llm/providers/providers.py
index 0307aba95..cdeba84df 100644
--- a/nemoguardrails/llm/providers/providers.py
+++ b/nemoguardrails/llm/providers/providers.py
@@ -20,11 +20,15 @@
 Additional providers can be registered using the `register_llm_provider` function.
 """
 
+import asyncio
 import logging
 from typing import Any, Dict, List, Optional, Type
 
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_community import llms
@@ -62,6 +66,13 @@ def _call(
             f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
             "`generate` instead."
         )
+
+        # Streaming for NeMo Guardrails is not supported in async calls.
+        if self.model_kwargs.get("streaming"):
+            raise Exception(
+                "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
+            )
+
         llm_result = self._generate(
             [prompt],
             stop=stop,
@@ -74,7 +85,7 @@ async def _acall(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """
@@ -86,6 +97,29 @@ async def _acall(
             f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
             "`generate` instead."
         )
+
+        # Handle streaming, if the flag is set
+        if self.model_kwargs.get("streaming"):
+            # Retrieve the streamer object, needs to be set in model_kwargs
+            streamer = self.model_kwargs["streamer"]
+            if not streamer:
+                raise Exception(
+                    "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
+                )
+
+            generation_kwargs = dict(
+                prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+            )
+            loop = asyncio.get_running_loop()
+            loop.create_task(self._agenerate(**generation_kwargs))
+
+            completion = ""
+            async for item in streamer:
+                completion += item
+                if run_manager:
+                    await run_manager.on_llm_new_token(item)
+            return completion
+
         llm_result = await self._agenerate(
             [prompt],
             stop=stop,

From b50316e40c3909853705c64304095a5111a2e098 Mon Sep 17 00:00:00 2001
From: Traian Rebedea
Date: Wed, 14 Feb 2024 16:18:29 +0200
Subject: [PATCH 2/5] GenerationChunk is needed for on_llm_new_token

---
 nemoguardrails/llm/providers/providers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/nemoguardrails/llm/providers/providers.py b/nemoguardrails/llm/providers/providers.py
index cdeba84df..71c2e6c2f 100644
--- a/nemoguardrails/llm/providers/providers.py
+++ b/nemoguardrails/llm/providers/providers.py
@@ -31,6 +31,7 @@
 )
 from langchain.llms.base import LLM
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.schema.output import GenerationChunk
 from langchain_community import llms
 
 from nemoguardrails.rails.llm.config import Model
@@ -116,8 +117,9 @@ async def _acall(
             completion = ""
             async for item in streamer:
                 completion += item
+                chunk = GenerationChunk(text=item)
                 if run_manager:
-                    await run_manager.on_llm_new_token(item)
+                    await run_manager.on_llm_new_token(item, chunk=chunk)
             return completion
 
         llm_result = await self._agenerate(

From 65d714bae172ab16348e09acc3261adeefbdac27 Mon Sep 17 00:00:00 2001
From: Traian Rebedea
Date: Wed, 14 Feb 2024 16:39:11 +0200
Subject: [PATCH 3/5] Added some additional documentation and example for streaming with HuggingFacePipeline LLMs.

---
 docs/user_guides/advanced/streaming.md | 29 ++++++++++++++++++++++++++
 examples/scripts/demo_streaming.py | 13 ++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/docs/user_guides/advanced/streaming.md b/docs/user_guides/advanced/streaming.md
index baf03b12d..b01f56deb 100644
--- a/docs/user_guides/advanced/streaming.md
+++ b/docs/user_guides/advanced/streaming.md
@@ -86,3 +86,32 @@ POST /v1/chat/completions
   "stream": true
 }
 ```
+
+### Streaming for LLMs deployed using HuggingFacePipeline
+
+We also support streaming for LLMs deployed using `HuggingFacePipeline`.
+One example is provided in the [HF Pipeline Dolly](./../../../examples/configs/llm/hf_pipeline_dolly/README.md) configuration.
+
+To use streaming for HF Pipeline LLMs, you first need to set the streaming flag in your `config.yml`.
+
+```yaml
+streaming: True
+```
+
+Then you need to create a `nemoguardrails.llm.providers.huggingface.AsyncTextIteratorStreamer` streamer object,
+add it to the `kwargs` of the pipeline and to the `model_kwargs` of the `HuggingFacePipelineCompatible` object.
+ +```python +from nemoguardrails.llm.providers.huggingface import AsyncTextIteratorStreamer + +# instantiate tokenizer object required by LLM +streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True) +params = {"temperature": 0.01, "max_new_tokens": 100, "streamer": streamer} + +pipe = pipeline( + # all other parameters + **params, +) + +llm = HuggingFacePipelineCompatible(pipeline=pipe, model_kwargs=params) +``` diff --git a/examples/scripts/demo_streaming.py b/examples/scripts/demo_streaming.py index 237f869a9..6387c0775 100644 --- a/examples/scripts/demo_streaming.py +++ b/examples/scripts/demo_streaming.py @@ -66,6 +66,19 @@ async def process_tokens(): print(result) +async def demo_hf_pipeline(): + """Demo for streaming of response chunks directly with HuggingFacePipline deployed LLMs.""" + config = RailsConfig.from_path(config_path="./../configs/llm/hf_pipeline_dolly") + app = LLMRails(config) + + history = [{"role": "user", "content": "What is the capital of France?"}] + + async for chunk in app.stream_async(messages=history): + print(f"CHUNK: {chunk}") + # Or do something else with the token + + if __name__ == "__main__": asyncio.run(demo_1()) asyncio.run(demo_2()) + # asyncio.run(demo_hf_pipeline()) From 47e3f1513c75df1077723394253bb0c80fd60884 Mon Sep 17 00:00:00 2001 From: Razvan Dinu Date: Thu, 15 Feb 2024 13:08:47 +0200 Subject: [PATCH 4/5] Disable dialog rails for HF dolly example. --- .../configs/llm/hf_pipeline_dolly/rails.co | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/examples/configs/llm/hf_pipeline_dolly/rails.co b/examples/configs/llm/hf_pipeline_dolly/rails.co index 61ef579e4..dba933b0e 100644 --- a/examples/configs/llm/hf_pipeline_dolly/rails.co +++ b/examples/configs/llm/hf_pipeline_dolly/rails.co @@ -1,25 +1,27 @@ -define user express greeting - "Hello" - "Hi" +# Uncomment these to test dialog rails -define user ask capabilities - "What can you do?" - "What can you help me with?" - "tell me what you can do" - "tell me about you" - -define user ask general question - "How is the weather tomorrow?" - "Can you tell me which is the best movie this week" - "i would like to know the best scifi books of all time." - -define flow - user express greeting - bot express greeting - -define flow - user ask capabilities - bot inform capabilities - -define bot inform capabilities - "I am an AI assistant and I'm here to help." +#define user express greeting +# "Hello" +# "Hi" +# +#define user ask capabilities +# "What can you do?" +# "What can you help me with?" +# "tell me what you can do" +# "tell me about you" +# +#define user ask general question +# "How is the weather tomorrow?" +# "Can you tell me which is the best movie this week" +# "i would like to know the best scifi books of all time." +# +#define flow +# user express greeting +# bot express greeting +# +#define flow +# user ask capabilities +# bot inform capabilities +# +#define bot inform capabilities +# "I am an AI assistant and I'm here to help." From 634d15ef6c3562d842ec1023b33cbcb56cae38b7 Mon Sep 17 00:00:00 2001 From: Razvan Dinu Date: Thu, 15 Feb 2024 13:09:43 +0200 Subject: [PATCH 5/5] Tweak and fix the streaming for HuggingFace pipeline. 
---
 .../configs/llm/hf_pipeline_dolly/config.py | 15 ++++++---------
 .../llm/providers/huggingface/streamers.py | 7 +++++--
 nemoguardrails/llm/providers/providers.py | 19 +++++++++++++++----
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/examples/configs/llm/hf_pipeline_dolly/config.py b/examples/configs/llm/hf_pipeline_dolly/config.py
index 84cbccfb6..b6851279a 100644
--- a/examples/configs/llm/hf_pipeline_dolly/config.py
+++ b/examples/configs/llm/hf_pipeline_dolly/config.py
@@ -22,18 +22,15 @@
     register_llm_provider,
 )
 
-# Flag to test streaming in NeMo Guardrails of HuggingFacePipeline models
-ENABLE_STREAMING = True
-
 
 @lru_cache
-def get_dolly_v2_3b_llm():
+def get_dolly_v2_3b_llm(streaming: bool = True):
     name = "databricks/dolly-v2-3b"
 
     config = AutoConfig.from_pretrained(name, trust_remote_code=True)
-    device = "cuda:0"
+    device = "cpu"
     config.init_device = device
-    config.max_seq_len = 450
+    config.max_seq_len = 45
 
     model = AutoModelForCausalLM.from_pretrained(
         name,
@@ -43,12 +40,12 @@
     tokenizer = AutoTokenizer.from_pretrained(name)
     params = {"temperature": 0.01, "max_new_tokens": 100}
 
-    # Set this flag to False if you do not require streaming for the HuggingFacePipeline
-    if ENABLE_STREAMING:
+    # If we want streaming, we create a streamer.
+    if streaming:
         from nemoguardrails.llm.providers.huggingface import AsyncTextIteratorStreamer
 
         streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True)
-        params = {"temperature": 0.01, "max_new_tokens": 100, "streamer": streamer}
+        params["streamer"] = streamer
 
     pipe = pipeline(
         model=model,
diff --git a/nemoguardrails/llm/providers/huggingface/streamers.py b/nemoguardrails/llm/providers/huggingface/streamers.py
index a04c56feb..d81288fae 100644
--- a/nemoguardrails/llm/providers/huggingface/streamers.py
+++ b/nemoguardrails/llm/providers/huggingface/streamers.py
@@ -32,12 +32,15 @@ def __init__(
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
         self.text_queue = asyncio.Queue()
         self.stop_signal = None
+        self.loop = None
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.text_queue.put_nowait(text)
+        if len(text) > 0:
+            asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)
+
         if stream_end:
-            self.text_queue.put_nowait(self.stop_signal)
+            asyncio.run_coroutine_threadsafe(self.text_queue.put(self.stop_signal), self.loop)
 
     def __aiter__(self):
         return self
diff --git a/nemoguardrails/llm/providers/providers.py b/nemoguardrails/llm/providers/providers.py
index 71c2e6c2f..e04faa922 100644
--- a/nemoguardrails/llm/providers/providers.py
+++ b/nemoguardrails/llm/providers/providers.py
@@ -68,7 +68,7 @@
             "`generate` instead."
         )
 
-        # Streaming for NeMo Guardrails is not supported in async calls.
+        # Streaming for NeMo Guardrails is not supported in sync calls.
         if self.model_kwargs.get("streaming"):
             raise Exception(
                 "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
@@ -102,24 +102,35 @@
         # Handle streaming, if the flag is set
         if self.model_kwargs.get("streaming"):
             # Retrieve the streamer object, needs to be set in model_kwargs
-            streamer = self.model_kwargs["streamer"]
+            streamer = self.model_kwargs.get("streamer")
             if not streamer:
                 raise Exception(
                     "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
) + loop = asyncio.get_running_loop() + + # Pass the asyncio loop to the stream so that it can send back + # the chunks in the queue. + streamer.loop = loop + + # Launch the generation in a separate task. generation_kwargs = dict( - prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs + prompts=[prompt], + stop=stop, + run_manager=run_manager, + **kwargs, ) - loop = asyncio.get_running_loop() loop.create_task(self._agenerate(**generation_kwargs)) + # And start waiting for the chunks to come in. completion = "" async for item in streamer: completion += item chunk = GenerationChunk(text=item) if run_manager: await run_manager.on_llm_new_token(item, chunk=chunk) + return completion llm_result = await self._agenerate(