Feature/huggingface pipeline streaming #331

Merged
merged 6 commits on Feb 15, 2024
Changes from 1 commit
Added NeMo Guardrails streaming support for LLMs deployed using HuggingFacePipeline.
trebedea committed Feb 14, 2024
commit a61b28aa1f96d5d009324cc7914addb578b92814
4 changes: 3 additions & 1 deletion docs/user_guides/llm-support.md
@@ -25,19 +25,21 @@ If you want to use an LLM and you cannot see a prompt in the [prompts folder](..
| Dialog Rails | :heavy_check_mark: (0.74) | :heavy_check_mark: (0.83) | :heavy_check_mark: (0.82) | :heavy_check_mark: (0.77) | :heavy_check_mark: (0.76) | :exclamation: (0.45) | :exclamation: | :exclamation: (0.54) | :exclamation: (0.54) | :exclamation: (0.50) | :exclamation: (0.40) | :exclamation: _(DEPENDS ON MODEL)_ |
| • Single LLM call | :heavy_check_mark: (0.83) | :heavy_check_mark: (0.81) | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
| • Multi-step flow generation | _EXPERIMENTAL_ | _EXPERIMENTAL_ | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
| Streaming | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: |
| Streaming | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | - | - | :heavy_check_mark: | :heavy_check_mark: | - | - | - | - | :heavy_check_mark: |
| Hallucination detection (SelfCheckGPT with AskLLM) | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
| AskLLM rails | | | | | | | | | | | | |
| • Jailbreak detection | :heavy_check_mark: (0.88) | :heavy_check_mark: (0.88) | :heavy_check_mark: (0.86) | :x: | :x: | :heavy_check_mark: (0.85) | :x: | :x: | :x: | :x: | :x: | :x: |
| • Output moderation | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: (0.85) | :x: | :x: | :x: | :x: | :x: | :x: |
| • Fact-checking | :heavy_check_mark: (0.81) | :heavy_check_mark: (0.82) | :heavy_check_mark: (0.81) | :heavy_check_mark: (0.80) | :x: | :heavy_check_mark: (0.83) | :x: | :x: | :x: | :x: | :x: | :exclamation: _(DEPENDS ON MODEL)_ |
| AlignScore fact-checking _(LLM independent)_ | :heavy_check_mark: (0.89) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| ActiveFence moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Llama Guard moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

Table legend:
- :heavy_check_mark: - Supported (_The feature is fully supported by the LLM based on our experiments and tests_)
- :exclamation: - Limited Support (_Experiments and tests show that the LLM is under-performing for that feature_)
- :x: - Not Supported (_Experiments show very poor performance or no experiments have been done for the LLM-feature pair_)
- \- - Not Applicable (_e.g., whether a model supports streaming depends on how it is deployed_)

The performance numbers reported in the table above for each LLM-feature pair are as follows:
- the banking dataset evaluation for dialog (topical) rails
1 change: 1 addition & 0 deletions examples/configs/llm/hf_pipeline_dolly/README.md
@@ -1,6 +1,7 @@
# HuggingFace Pipeline with Dolly models

This configuration uses the HuggingFace Pipeline LLM with the [dolly-v2-3b](https://huggingface.co/databricks/dolly-v2-3b) model.
It also shows how to support streaming in NeMo Guardrails for LLMs deployed using HuggingFacePipeline.

The `dolly-v2-3b` model has been tested on the topical rails evaluation sets; results are available [here](../../../../nemoguardrails/eval/README.md).
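
For context, this is roughly how a caller can consume the streamed tokens once the config below is in place. It is a minimal sketch that assumes the `LLMRails.stream_async` API described in the NeMo Guardrails streaming guide; the config path and prompt are illustrative.

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def demo():
    # Illustrative path: point this at the hf_pipeline_dolly example config.
    config = RailsConfig.from_path("examples/configs/llm/hf_pipeline_dolly")
    rails = LLMRails(config)

    history = [{"role": "user", "content": "What can you do?"}]
    # Tokens arrive incrementally as the HuggingFacePipeline LLM generates them.
    async for chunk in rails.stream_async(messages=history):
        print(chunk, end="", flush=True)


asyncio.run(demo())
```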

41 changes: 34 additions & 7 deletions examples/configs/llm/hf_pipeline_dolly/config.py
@@ -14,27 +14,54 @@
# limitations under the License.
from functools import lru_cache

from torch.cuda import device_count
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import (
HuggingFacePipelineCompatible,
register_llm_provider,
)

# Flag to test streaming in NeMo Guardrails for HuggingFacePipeline models
ENABLE_STREAMING = True


@lru_cache
def get_dolly_v2_3b_llm():
repo_id = "databricks/dolly-v2-3b"
params = {"temperature": 0, "max_length": 1024}
name = "databricks/dolly-v2-3b"

config = AutoConfig.from_pretrained(name, trust_remote_code=True)
device = "cuda:0"
config.init_device = device
config.max_seq_len = 450

model = AutoModelForCausalLM.from_pretrained(
name,
config=config,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(name)
params = {"temperature": 0.01, "max_new_tokens": 100}

# Set this flag to False if you do not require streaming for the HuggingFacePipeline
if ENABLE_STREAMING:
from nemoguardrails.llm.providers.huggingface import AsyncTextIteratorStreamer

# Use the first CUDA-enabled GPU, if any
device = 0 if device_count() else -1
streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True)
params = {"temperature": 0.01, "max_new_tokens": 100, "streamer": streamer}

llm = HuggingFacePipelineCompatible.from_model_id(
model_id=repo_id, device=device, task="text-generation", model_kwargs=params
pipe = pipeline(
model=model,
task="text-generation",
tokenizer=tokenizer,
device=device,
do_sample=True,
use_cache=True,
**params,
)

llm = HuggingFacePipelineCompatible(pipeline=pipe, model_kwargs=params)

return llm


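The visible hunk ends before the provider registration. Based on the helpers imported at the top of this `config.py` and the `hf_pipeline_dolly` engine name used in `config.yml` below, the registration typically looks like the following sketch; the exact lines fall outside the diff shown here.

```python
# Wrap the cached pipeline LLM so NeMo Guardrails can set parameters such as
# `streaming` and `temperature` on it (see get_llm_instance_wrapper below).
HFPipelineDolly = get_llm_instance_wrapper(
    llm_instance=get_dolly_v2_3b_llm(), llm_type="hf_pipeline_dolly"
)

# Expose the wrapper under the engine name referenced in config.yml.
register_llm_provider("hf_pipeline_dolly", HFPipelineDolly)
```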
3 changes: 3 additions & 0 deletions examples/configs/llm/hf_pipeline_dolly/config.yml
@@ -2,6 +2,9 @@ models:
- type: main
engine: hf_pipeline_dolly

# Remove this attribute or set it to False if streaming is not required
streaming: True

instructions:
- type: general
content: |
5 changes: 5 additions & 0 deletions examples/configs/llm/hf_pipeline_dolly/rails.co
@@ -8,6 +8,11 @@ define user ask capabilities
"tell me what you can do"
"tell me about you"

define user ask general question
"How is the weather tomorrow?"
"Can you tell me which is the best movie this week"
"i would like to know the best scifi books of all time."

define flow
user express greeting
bot express greeting
21 changes: 21 additions & 0 deletions nemoguardrails/llm/helpers.py
@@ -32,6 +32,14 @@ def get_llm_instance_wrapper(
"""

class WrapperLLM(LLM):
"""The wrapper class needs to have defined any parameters we need to be set by NeMo Guardrails.

Currently added only streaming and temperature.
"""

streaming: Optional[bool] = False
temperature: Optional[float] = 1.0

@property
def model_kwargs(self):
"""Return the model's kwargs.
@@ -50,12 +58,24 @@ def _llm_type(self) -> str:
"""
return llm_type

def _modify_instance_kwargs(self):
"""Modify the parameters of the llm_instance with the attributes set for the wrapper.

This will allow the actual LLM instance to use the parameters at generation.
TODO: Make this function more generic if needed.
"""

if hasattr(llm_instance, "model_kwargs"):
llm_instance.model_kwargs["temperature"] = self.temperature
llm_instance.model_kwargs["streaming"] = self.streaming

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
self._modify_instance_kwargs()
return llm_instance._call(prompt, stop, run_manager)

async def _acall(
@@ -64,6 +84,7 @@ async def _acall(
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> str:
self._modify_instance_kwargs()
return await llm_instance._acall(prompt, stop, run_manager)

return WrapperLLM
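
To illustrate how the new `streaming` and `temperature` fields reach the wrapped model, here is a hedged usage sketch; `dolly_llm` is a hypothetical `HuggingFacePipelineCompatible` instance, and in practice NeMo Guardrails sets these fields itself from the rails configuration.

```python
# `dolly_llm` is hypothetical: any LLM instance exposing a `model_kwargs` dict,
# such as the HuggingFacePipelineCompatible object built in the example config.
WrappedDolly = get_llm_instance_wrapper(
    llm_instance=dolly_llm, llm_type="hf_pipeline_dolly"
)

# NeMo Guardrails instantiates the wrapper and can set these fields directly.
llm = WrappedDolly(streaming=True, temperature=0.01)

# On each _call/_acall, _modify_instance_kwargs() copies the wrapper attributes
# into dolly_llm.model_kwargs, so the underlying pipeline sees them at generation.
```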
16 changes: 16 additions & 0 deletions nemoguardrails/llm/providers/huggingface/__init__.py
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .streamers import AsyncTextIteratorStreamer
50 changes: 50 additions & 0 deletions nemoguardrails/llm/providers/huggingface/streamers.py
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

from transformers.generation.streamers import TextStreamer


class AsyncTextIteratorStreamer(TextStreamer):
    """A simple async streamer implementation for HuggingFace Transformers.

    This closely follows transformers.generation.streamers.TextIteratorStreamer,
    with minor modifications to make it async.
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = asyncio.Queue()
        self.stop_signal = None

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put_nowait(text)
        if stream_end:
            self.text_queue.put_nowait(self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        value = await self.text_queue.get()
        if value == self.stop_signal:
            raise StopAsyncIteration()
        else:
            return value
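
A hedged sketch of driving this streamer directly, outside NeMo Guardrails, mirroring what the `_acall` change in `providers.py` below does: generation runs in the background while the caller iterates tokens asynchronously. The model name and generation arguments are illustrative, and, as in `_acall`, it assumes tokens queued from the worker thread are picked up by the running event loop.

```python
import asyncio

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from nemoguardrails.llm.providers.huggingface import AsyncTextIteratorStreamer


async def stream_demo(prompt: str) -> str:
    name = "databricks/dolly-v2-3b"  # illustrative; any causal LM should work
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

    streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True)
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        streamer=streamer,
        max_new_tokens=100,
    )

    # Run the blocking pipeline call in a worker thread; on_finalized_text()
    # fills the streamer's queue token by token as generation proceeds.
    loop = asyncio.get_running_loop()
    generation = loop.run_in_executor(None, pipe, prompt)

    completion = ""
    async for chunk in streamer:  # ends when the stop signal is queued
        completion += chunk
    await generation
    return completion
```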
38 changes: 36 additions & 2 deletions nemoguardrails/llm/providers/providers.py
@@ -20,11 +20,15 @@

Additional providers can be registered using the `register_llm_provider` function.
"""
import asyncio
import logging
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community import llms
@@ -62,6 +66,13 @@ def _call(
f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
"`generate` instead."
)

# Streaming for NeMo Guardrails is only handled via the async `_acall` path, so reject it in this synchronous call.
if self.model_kwargs.get("streaming"):
raise Exception(
"Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
)

llm_result = self._generate(
[prompt],
stop=stop,
@@ -74,7 +85,7 @@ async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""
@@ -86,6 +97,29 @@
f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
"`generate` instead."
)

# Handle streaming, if the flag is set
if self.model_kwargs.get("streaming"):
# Retrieve the streamer object; it must be set in model_kwargs
streamer = self.model_kwargs["streamer"]
if not streamer:
raise Exception(
"Cannot stream, please add HuggingFace streamer object to model_kwargs!"
)

generation_kwargs = dict(
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
)
loop = asyncio.get_running_loop()
loop.create_task(self._agenerate(**generation_kwargs))

completion = ""
async for item in streamer:
completion += item
if run_manager:
await run_manager.on_llm_new_token(item)
return completion

llm_result = await self._agenerate(
[prompt],
stop=stop,
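
For reference, the tokens forwarded through `run_manager.on_llm_new_token(...)` in the streaming branch above end up in whatever async callback handler is attached to the run. NeMo Guardrails wires its own streaming handler there; the stand-in below is only an illustration that prints tokens as they arrive.

```python
from langchain.callbacks.base import AsyncCallbackHandler


class PrintTokensHandler(AsyncCallbackHandler):
    """Illustrative stand-in: print each streamed token as it is emitted."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)
```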