
feat: added on_agent_final_answer-support to Agent callback_manager #5736

Merged · 14 commits · Oct 9, 2023
run pre-commit to format file
davidberenstein1957 committed Sep 11, 2023
commit 44ec6383f1ab9e42d39fb063afbc67c75d9dfac8
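
Note: this commit is only a pre-commit formatting pass, but the PR it belongs to adds a new on_agent_final_answer event to the Agent's callback manager (visible in the Events tuple and the default on_agent_final_answer handler in the diff below). A minimal sketch of subscribing to the new event, assuming the public Agent and PromptNode constructors from Haystack v1; the model name and handler body are illustrative, not part of this diff:

from haystack.agents import Agent
from haystack.nodes import PromptNode

# Hypothetical model choice; any PromptNode-supported model would do here.
agent = Agent(prompt_node=PromptNode("gpt-3.5-turbo"))

def print_final_answer(final_answer, **kwargs):
    # final_answer is a Dict[str, Any]; per the default callback in this PR
    # it carries the agent's answers (e.g. final_answer["answers"]).
    print(final_answer)

# callback_manager is an events.Events instance restricted to the event names
# registered in Agent.__init__, now including on_agent_final_answer.
agent.callback_manager.on_agent_final_answer += print_final_answer
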
haystack/agents/base.py: 109 changes (30 additions, 79 deletions)
@@ -2,30 +2,30 @@

 import logging
 import re
-from collections.abc import Iterable, Callable
+from collections.abc import Callable, Iterable
 from hashlib import md5
-from typing import List, Optional, Union, Dict, Any, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 from events import Events

-from haystack import Pipeline, BaseComponent, Answer, Document
-from haystack.agents.memory import Memory, NoMemory
-from haystack.telemetry import send_event
+from haystack import Answer, BaseComponent, Document, Pipeline
 from haystack.agents.agent_step import AgentStep
-from haystack.agents.types import Color, AgentTokenStreamingHandler
+from haystack.agents.memory import Memory, NoMemory
+from haystack.agents.types import AgentTokenStreamingHandler, Color
 from haystack.agents.utils import print_text, react_parameter_resolver
-from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
+from haystack.nodes import BaseRetriever, PromptNode, PromptTemplate
 from haystack.pipelines import (
     BaseStandardPipeline,
-    ExtractiveQAPipeline,
     DocumentSearchPipeline,
+    ExtractiveQAPipeline,
+    FAQPipeline,
     GenerativeQAPipeline,
+    RetrieverQuestionGenerationPipeline,
     SearchSummarizationPipeline,
-    FAQPipeline,
     TranslationWrapperPipeline,
-    RetrieverQuestionGenerationPipeline,
     WebQAPipeline,
 )
+from haystack.telemetry import send_event

 logger = logging.getLogger(__name__)

@@ -130,13 +130,9 @@ def __init__(
         :param tool_pattern: A regular expression pattern that matches the text that the Agent generates to invoke
             a tool.
         """
-        self._tools: Dict[str, Tool] = (
-            {tool.name: tool for tool in tools} if tools else {}
-        )
+        self._tools: Dict[str, Tool] = {tool.name: tool for tool in tools} if tools else {}
         self.tool_pattern = tool_pattern
-        self.callback_manager = Events(
-            ("on_tool_start", "on_tool_finish", "on_tool_error")
-        )
+        self.callback_manager = Events(("on_tool_start", "on_tool_finish", "on_tool_error"))

     @property
     def tools(self):
@@ -158,13 +154,9 @@ def get_tool_names_with_descriptions(self) -> str:
         """
         Returns a string with the names and descriptions of all registered tools.
         """
-        return "\n".join(
-            [f"{tool.name}: {tool.description}" for tool in self.tools.values()]
-        )
+        return "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools.values()])

-    def run_tool(
-        self, llm_response: str, params: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) -> str:
         tool_result: str = ""
         if self.tools:
             tool_name, tool_input = self.extract_tool_name_and_tool_input(llm_response)
@@ -186,9 +178,7 @@ def run_tool(
                 raise e
         return tool_result

-    def extract_tool_name_and_tool_input(
-        self, llm_response: str
-    ) -> Tuple[Optional[str], Optional[str]]:
+    def extract_tool_name_and_tool_input(self, llm_response: str) -> Tuple[Optional[str], Optional[str]]:
         """
         Parse the tool name and the tool input from the PromptNode response.
         :param llm_response: The PromptNode response.
@@ -255,28 +245,18 @@ def __init__(
         self.tm = tools_manager or ToolsManager()
         self.memory = memory or NoMemory()
         self.callback_manager = Events(
-            (
-                "on_agent_start",
-                "on_agent_step",
-                "on_agent_finish",
-                "on_agent_final_answer",
-                "on_new_token",
-            )
+            ("on_agent_start", "on_agent_step", "on_agent_finish", "on_agent_final_answer", "on_new_token")
         )
         self.prompt_node = prompt_node
-        prompt_template = (
-            prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
-        )
+        prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
         resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
         if not resolved_prompt_template:
             raise ValueError(
                 f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
             )
         self.prompt_template = resolved_prompt_template
         self.prompt_parameters_resolver = (
-            prompt_parameters_resolver
-            if prompt_parameters_resolver
-            else react_parameter_resolver
+            prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
         )
         self.final_answer_pattern = final_answer_pattern
         self.add_default_logging_callbacks(streaming=streaming)
@@ -290,20 +270,13 @@ def update_hash(self):
         See haystack/telemetry.py::send_event
         """
         try:
-            tool_names = " ".join(
-                [
-                    tool.pipeline_or_node.__class__.__name__
-                    for tool in self.tm.get_tools()
-                ]
-            )
+            tool_names = " ".join([tool.pipeline_or_node.__class__.__name__ for tool in self.tm.get_tools()])
             self.hash = md5(tool_names.encode()).hexdigest()
         except Exception as exc:
             logger.debug("Telemetry exception: %s", str(exc))
             self.hash = "[an exception occurred during hashing]"

-    def add_default_logging_callbacks(
-        self, agent_color: Color = Color.GREEN, streaming: bool = False
-    ) -> None:
+    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None:
         def on_tool_finish(
             tool_output: str,
             color: Optional[Color] = None,
@@ -327,9 +300,7 @@ def on_agent_final_answer(final_answer: Dict[str, Any], **kwargs: Any) -> None:
         self.callback_manager.on_agent_final_answer += on_agent_final_answer

         if streaming:
-            self.callback_manager.on_new_token += lambda token, **kwargs: print_text(
-                token, color=agent_color
-            )
+            self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
         else:
             self.callback_manager.on_agent_step += lambda agent_step: print_text(
                 agent_step.prompt_node_response, end="\n", color=agent_color
@@ -351,8 +322,7 @@ def add_tool(self, tool: Tool):
         """
         if tool.name in self.tm.tools:
             logger.warning(
-                "The agent already has a tool named '%s'. The new tool will overwrite the existing one.",
-                tool.name,
+                "The agent already has a tool named '%s'. The new tool will overwrite the existing one.", tool.name
             )
         self.tm.tools[tool.name] = tool

@@ -384,15 +354,11 @@ def run(
         try:
             if not self.hash == self.last_hash:
                 self.last_hash = self.hash
-                send_event(
-                    event_name="Agent", event_properties={"llm.agent_hash": self.hash}
-                )
+                send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
         except Exception as exc:
             logger.debug("Telemetry exception: %s", exc)

-        self.callback_manager.on_agent_start(
-            name=self.prompt_template.name, query=query, params=params
-        )
+        self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
         agent_step = self.create_agent_step(max_steps)
         try:
             while not agent_step.is_last():
@@ -412,16 +378,10 @@ def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = No
         self.callback_manager.on_agent_step(next_step)

         # run the tool selected by the LLM
-        observation = (
-            self.tm.run_tool(next_step.prompt_node_response, params)
-            if not next_step.is_last()
-            else None
-        )
+        observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None

         # save the input, output and observation to memory (if memory is enabled)
-        memory_data = self.prepare_data_for_memory(
-            input=query, output=prompt_node_response, observation=observation
-        )
+        memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation)
         self.memory.save(data=memory_data)

         # update the next step with the observation
@@ -430,9 +390,7 @@ def _plan(self, query, current_step):

     def _plan(self, query, current_step):
         # first resolve prompt template params
-        template_params = self.prompt_parameters_resolver(
-            query=query, agent=self, agent_step=current_step
-        )
+        template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)

         # check for template parameters mismatch
         self.check_prompt_template(template_params)
@@ -449,19 +407,14 @@ def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep:
         """
         Create an AgentStep object. Override this method to customize the AgentStep class used by the Agent.
         """
-        return AgentStep(
-            max_steps=max_steps or self.max_steps,
-            final_answer_pattern=self.final_answer_pattern,
-        )
+        return AgentStep(max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern)

     def prepare_data_for_memory(self, **kwargs) -> dict:
         """
         Prepare data for saving to the Agent's memory. Override this method to customize the data saved to the memory.
         """
         return {
-            k: v if isinstance(v, str) else next(iter(v))
-            for k, v in kwargs.items()
-            if isinstance(v, (str, Iterable))
+            k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
         }

     def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
@@ -477,9 +430,7 @@ def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
         :param template_params: The parameters provided by the prompt parameter resolver.

         """
-        unused_params = set(template_params.keys()) - set(
-            self.prompt_template.prompt_params
-        )
+        unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params)

         if "transcript" in unused_params:
             logger.warning(
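
For readers unfamiliar with the callback pattern being reformatted above: callback_manager is an Events object from the events package, constructed with a tuple that restricts which event slots exist; handlers subscribe with += and run when the event slot is called. A standalone sketch of that pattern, independent of Haystack (the event payloads below are made up for illustration):

from events import Events

# Restrict the Events object to a fixed set of event slots, mirroring how
# Agent.__init__ builds its callback_manager in the diff above.
callback_manager = Events(("on_agent_start", "on_agent_final_answer"))

callback_manager.on_agent_start += lambda name, query, **kwargs: print(f"start: {name}: {query}")
callback_manager.on_agent_final_answer += lambda final_answer, **kwargs: print(f"final: {final_answer}")

# Calling an event slot fires every subscribed handler in order.
callback_manager.on_agent_start(name="zero-shot-react", query="What is 2 + 2?")
callback_manager.on_agent_final_answer(final_answer={"answers": ["4"]})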