Moved to using nest_asyncio for the blocking API. Close NVIDIA#3. Close NVIDIA#32.
drazvan committed Sep 1, 2023
1 parent b103ab6 commit 8f789c4
Showing 8 changed files with 109 additions and 34 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -13,12 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for [`PROMPTS_DIR`](./docs/user_guide/advanced/prompt-customization.md#prompt-configuration).
- [#101](https://github.com/NVIDIA/NeMo-Guardrails/pull/101) Support for [using OpenAI embeddings](./docs/user_guide/configuration-guide.md#the-embeddings-model) models in addition to SentenceTransformers.

### Changed

- Moved to using `nest_asyncio` for [implementing the blocking API](./docs/user_guide/advanced/nested-async-loop.md). Fixes [#3](https://github.com/NVIDIA/NeMo-Guardrails/issues/3) and [#32](https://github.com/NVIDIA/NeMo-Guardrails/issues/32).

### Fixed

- Fixed the point at which the `init` function from `config.py` is called, so that custom LLM providers can be registered inside it.
- [#93](https://github.com/NVIDIA/NeMo-Guardrails/pull/93): Removed redundant `hasattr` check in `nemoguardrails/llm/params.py`.


## [0.4.0] - 2023-08-03

### Added
7 changes: 7 additions & 0 deletions docs/user_guide/advanced/nested-async-loop.md
@@ -0,0 +1,7 @@
# Nested AsyncIO Loop

NeMo Guardrails is an async-first toolkit, i.e., the core functionality is implemented using async functions. To provide a blocking API, the toolkit must invoke async functions inside synchronous code using `asyncio.run`. However, the current Python implementation for `asyncio` does not allow "nested event loops". This issue is being discussed by the Python core team and, most likely, support will be added (see [GitHub Issue 66435](https://github.com/python/cpython/issues/66435) and [Pull Request 93338](https://github.com/python/cpython/pull/93338)).
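
The restriction can be reproduced with the standard library alone. The following is a minimal sketch, not code from NeMo Guardrails; `fetch_answer` and `blocking_fetch` are hypothetical names standing in for an async core function and its blocking wrapper.

```python
import asyncio


async def fetch_answer() -> str:
    # Hypothetical stand-in for an async-first core function.
    await asyncio.sleep(0)
    return "Hello there!"


def blocking_fetch() -> str:
    # Hypothetical blocking wrapper: runs the coroutine on a fresh event loop.
    coro = fetch_answer()
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        # asyncio.run refuses to start while another loop runs in this thread.
        coro.close()  # avoid a "coroutine was never awaited" warning
        return f"RuntimeError: {exc}"


async def call_from_async_code() -> str:
    # The blocking wrapper is invoked while the outer loop is already running.
    return blocking_fetch()


result_from_sync = blocking_fetch()
result_from_async = asyncio.run(call_from_async_code())
print(result_from_sync)
print(result_from_async)
```

Called from plain synchronous code, the wrapper returns the coroutine's result. Called from inside a coroutine, it reports a `RuntimeError`, because an event loop is already running in that thread.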

Meanwhile, NeMo Guardrails makes use of [nest_asyncio](https://github.com/erdewit/nest_asyncio). The patch is applied the first time the `nemoguardrails` package is imported.

If the blocking API is not needed, or the `nest_asyncio` patching causes unexpected problems, you can disable it by setting the `DISABLE_NEST_ASYNCIO` environment variable to `True` (the values `true`, `1`, and `yes` are accepted, case-insensitively).
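
As an illustration of which values disable the patch, the sketch below mirrors the environment check that this commit adds in `patch_asyncio.py`. The helper name `nest_asyncio_disabled` and its `environ` parameter are assumptions made for the example; they are not part of the toolkit.

```python
import os

# Values treated as "disabled", compared case-insensitively.
TRUTHY_VALUES = ["true", "1", "yes"]


def nest_asyncio_disabled(environ=None) -> bool:
    # Hypothetical helper; accepts a mapping so the check can be tried
    # without touching the real process environment.
    env = os.environ if environ is None else environ
    return env.get("DISABLE_NEST_ASYNCIO", "false").lower() in TRUTHY_VALUES


for value in ["True", "YES", "1", "0", "on"]:
    print(value, nest_asyncio_disabled({"DISABLE_NEST_ASYNCIO": value}))
```

Note that, under this check, values such as `on` or `0` leave the patch enabled, as does an unset variable.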
3 changes: 3 additions & 0 deletions nemoguardrails/__init__.py
@@ -14,4 +14,7 @@
# limitations under the License.

"""NeMo Guardrails Toolkit."""
from . import patch_asyncio
from .rails import LLMRails, RailsConfig

patch_asyncio.apply()
54 changes: 54 additions & 0 deletions nemoguardrails/patch_asyncio.py
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os

import nest_asyncio

# Keep track of whether the patch was applied or not
nest_asyncio_patch_applied = False


def apply():
    global nest_asyncio_patch_applied

    if os.environ.get("DISABLE_NEST_ASYNCIO", "false").lower() not in [
        "true",
        "1",
        "yes",
    ]:
        nest_asyncio.apply()
        nest_asyncio_patch_applied = True


def check_sync_call_from_async_loop():
    """Helper to check if a sync call is made from an async loop.
    Returns
        True if a sync call is made from an async loop.
    """
    if nest_asyncio_patch_applied:
        return False

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop and loop.is_running():
        return True

    return False
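
The detection above relies on `asyncio.get_running_loop()` raising `RuntimeError` when no loop is running in the current thread. A minimal standalone sketch of that idea, leaving out the `nest_asyncio` flag, might look like the following; `called_from_running_loop` and `sync_entry_point` are hypothetical names, not part of the toolkit.

```python
import asyncio


def called_from_running_loop() -> bool:
    # get_running_loop() raises RuntimeError when no loop is running here.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


def sync_entry_point() -> str:
    # Hypothetical blocking API that refuses to run inside async code.
    if called_from_running_loop():
        raise RuntimeError("sync API used inside async code; await the async variant")
    return "ok"


async def async_caller() -> str:
    # Calling the blocking API from a coroutine triggers the guard.
    try:
        return sync_entry_point()
    except RuntimeError as exc:
        return f"rejected: {exc}"


from_sync = sync_entry_point()
from_async = asyncio.run(async_caller())
print(from_sync)
print(from_async)
```

The guard only inspects the calling thread, so a blocking call made from a worker thread with no running loop would pass it.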
15 changes: 3 additions & 12 deletions nemoguardrails/rails/llm/llmrails.py
@@ -33,6 +33,7 @@
     get_llm_provider_names,
 )
 from nemoguardrails.logging.stats import llm_stats
+from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
 from nemoguardrails.rails.llm.config import RailsConfig
 from nemoguardrails.rails.llm.utils import get_history_cache_key

Expand Down Expand Up @@ -293,12 +294,7 @@ def generate(
):
"""Synchronous version of generate_async."""

try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop and loop.is_running():
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate` inside async code. "
"You should replace with `await generate_async(...)."
@@ -345,12 +341,7 @@ async def generate_events_async(self, events: List[dict]) -> List[dict]:
     def generate_events(self, events: List[dict]) -> List[dict]:
         """Synchronous version of `LLMRails.generate_events_async`."""

-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        if loop and loop.is_running():
+        if check_sync_call_from_async_loop():
             raise RuntimeError(
                 "You are using the sync `generate_events` inside async code. "
                 "You should replace with `await generate_events_async(...)."
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ httpx==0.23.3
simpleeval==0.9.13
typing-extensions==4.5.0
Jinja2==3.1.2
nest-asyncio==1.5.6
1 change: 1 addition & 0 deletions setup.py
@@ -74,6 +74,7 @@
        "simpleeval==0.9.13",
        "typing-extensions==4.5.0",
        "Jinja2==3.1.2",
        "nest-asyncio==1.5.6",
    ],
    extras_require={
        "eval": ["tqdm~=4.65", "numpy~=1.24"],
57 changes: 36 additions & 21 deletions tests/test_async_run.py → tests/test_nest_asyncio.py
@@ -12,34 +12,49 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
+import importlib
+
 import pytest

+import nemoguardrails
 from nemoguardrails import RailsConfig
 from tests.utils import TestChat

+config = RailsConfig.from_content(yaml_content="""models: []""")
+
+chat = TestChat(
+    config,
+    llm_completions=[
+        "Hello there!",
+        "Hello there!",
+        "Hello there!",
+    ],
+)
+
+
+def test_sync_api():
+    chat >> "Hi!"
+    chat << "Hello there!"
+

 @pytest.mark.asyncio
-async def test_1():
-    """Test that setting variables in context works correctly."""
-    config = RailsConfig.from_content(
-        """
-        define user express greeting
-            "hello"
-        define flow
-            user express greeting
-            bot express greeting
-        """
-    )
-    chat = TestChat(
-        config,
-        llm_completions=[
-            " express greeting",
-            ' "Hello John!"',
-        ],
-    )
+async def test_async_api():
+    chat >> "Hi!"
+    chat << "Hello there!"
+
+
+@pytest.mark.asyncio
+async def test_async_api_error(monkeypatch):
+    monkeypatch.setenv("DISABLE_NEST_ASYNCIO", "True")
+
+    # Reload the module to re-run its top-level code with the new env var
+    importlib.reload(nemoguardrails)
+    importlib.reload(asyncio)

     with pytest.raises(
-        RuntimeError, match="You are using the sync `generate` inside async code."
+        RuntimeError,
+        match=r"asyncio.run\(\) cannot be called from a running event loop",
     ):
-        chat.app.generate(messages=[{"role": "user", "content": "Hello!"}])
+        chat >> "Hi!"
+        chat << "Hello there!"
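
The last test depends on `importlib.reload` re-running a module's top-level code after an environment variable changes. The sketch below illustrates that general technique with a throwaway module written to a temporary directory. The module name `demo_patch_flag`, the variable `DEMO_DISABLE_PATCH`, and the function name are all hypothetical, and the example makes no claim about how `nemoguardrails` itself behaves on reload.

```python
import importlib
import os
import sys
import tempfile
from pathlib import Path

# Hypothetical module body: it reads the environment once, at import time.
MODULE_SOURCE = '''
import os

PATCH_ENABLED = os.environ.get("DEMO_DISABLE_PATCH", "false").lower() not in [
    "true",
    "1",
    "yes",
]
'''


def flag_before_and_after_reload():
    """Return the module flag at import, after the env change, and after reload."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        Path(tmp_dir, "demo_patch_flag.py").write_text(MODULE_SOURCE)
        sys.path.insert(0, tmp_dir)
        importlib.invalidate_caches()
        try:
            os.environ.pop("DEMO_DISABLE_PATCH", None)
            module = importlib.import_module("demo_patch_flag")
            at_import = module.PATCH_ENABLED

            # Changing the variable alone does not affect the imported module ...
            os.environ["DEMO_DISABLE_PATCH"] = "True"
            after_env_change = module.PATCH_ENABLED

            # ... until its top-level code is executed again.
            importlib.reload(module)
            after_reload = module.PATCH_ENABLED
        finally:
            # Undo every global side effect of the demonstration.
            sys.path.remove(tmp_dir)
            sys.modules.pop("demo_patch_flag", None)
            os.environ.pop("DEMO_DISABLE_PATCH", None)
    return at_import, after_env_change, after_reload


print(flag_before_and_after_reload())
```

In a pytest suite, `monkeypatch.setenv` restores the variable automatically at the end of the test, but the reloaded module keeps its new state unless the test reloads it again.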
