Skip to content

Commit

Permalink
core[patch]: fix no current event loop for sql history in async mode (#…
Browse files Browse the repository at this point in the history
…22933)

- **Description:** When use
RunnableWithMessageHistory/SQLChatMessageHistory in async mode, we'll
get the following error:
```
Error in RootListenersTracer.on_chain_end callback: RuntimeError("There is no current event loop in thread 'asyncio_3'.")
```
which throwed by
https://github.com/langchain-ai/langchain/blob/ddfbca38dfa22954eaeda38614c6e1ec0cdecaa9/libs/community/langchain_community/chat_message_histories/sql.py#L259.
and no message history will be add to database.

In this patch, a new _aexit_history function which will'be called in
async mode is added, and in turn aadd_messages will be called.

In this patch, we use `afunc` attribute of a Runnable to check if the
end listener should be run in async mode or not.

  - **Issue:** #22021, #22022 
  - **Dependencies:** N/A
  • Loading branch information
mackong committed Jun 21, 2024
1 parent 1c2b9cc commit 360a70c
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 35 deletions.
17 changes: 5 additions & 12 deletions libs/community/langchain_community/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import json
import logging
Expand Down Expand Up @@ -252,17 +251,11 @@ async def aadd_message(self, message: BaseMessage) -> None:
await session.commit()

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# The method RunnableWithMessageHistory._exit_history() call
# add_message method by mistake and not aadd_message.
# See https://github.com/langchain-ai/langchain/issues/22021
if self.async_mode:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aadd_messages(messages))
else:
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
# Add all messages in one transaction
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
# Add all messages in one transaction
Expand Down
37 changes: 35 additions & 2 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
Expand Down Expand Up @@ -306,8 +307,17 @@ def get_session_history(
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
bound: Runnable = (
history_chain
| RunnableBranch(
(
RunnableLambda(
self._is_not_async, afunc=self._is_async
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
runnable.with_alisteners(on_end=self._aexit_history),
),
runnable.with_listeners(on_end=self._exit_history),
)
).with_config(run_name="RunnableWithMessageHistory")

if history_factory_config:
Expand Down Expand Up @@ -367,6 +377,12 @@ def get_input_schema(
else:
return super_schema

def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return True

def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
Expand Down Expand Up @@ -483,6 +499,23 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:
output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)

async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]

# Get the input messages
inputs = load(run.inputs)
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]

# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
await hist.aadd_messages(input_messages + output_messages)

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
expected_keys = [field_spec.id for field_spec in self.history_factory_config]
Expand Down
Loading

0 comments on commit 360a70c

Please sign in to comment.