From c8da098bcf3c93781228a1289c5d33b1bad55a6b Mon Sep 17 00:00:00 2001
From: David Berenstein
Date: Thu, 7 Sep 2023 11:01:49 +0200
Subject: [PATCH 1/5] chore: add on_agent_final_answer support to Agent
 callback_manager

---
 haystack/agents/base.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/haystack/agents/base.py b/haystack/agents/base.py
index 0c0a146f47f..d3d03c78b69 100644
--- a/haystack/agents/base.py
+++ b/haystack/agents/base.py
@@ -244,7 +244,7 @@ def __init__(
         self.max_steps = max_steps
         self.tm = tools_manager or ToolsManager()
         self.memory = memory or NoMemory()
-        self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
+        self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_agent_final_answer", "on_new_token"))
         self.prompt_node = prompt_node
         prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
         resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
@@ -290,8 +290,12 @@ def on_agent_start(**kwargs: Any) -> None:
             agent_name = kwargs.pop("name", "react")
             print_text(f"\nAgent {agent_name} started with {kwargs}\n")
 
+        def on_agent_final_answer(final_answer: Dict[str, Any], **kwargs: Any) -> None:
+            pass
+
         self.tm.callback_manager.on_tool_finish += on_tool_finish
         self.callback_manager.on_agent_start += on_agent_start
+        self.callback_manager.on_agent_final_answer += on_agent_final_answer
 
         if streaming:
             self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
@@ -359,7 +363,9 @@ def run(
                 agent_step = self._step(query, agent_step, params)
         finally:
             self.callback_manager.on_agent_finish(agent_step)
-        return agent_step.final_answer(query=query)
+        final_answer = agent_step.final_answer(query=query)
+        self.callback_manager.on_agent_final_answer(final_answer)
+        return final_answer
 
     def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = None):
         # plan next step using the LLM

From 268e0a4e64907b60b01e0d92f5238fd1847af9f3 Mon Sep 17 00:00:00 2001
From: David Berenstein
Date: Thu, 7 Sep 2023 12:09:05 +0200
Subject: [PATCH 2/5] chore: format black

---
 haystack/agents/base.py | 93 +++++++++++++++++++++++++++++++----------
 1 file changed, 72 insertions(+), 21 deletions(-)

diff --git a/haystack/agents/base.py b/haystack/agents/base.py
index d3d03c78b69..8539c07d2aa 100644
--- a/haystack/agents/base.py
+++ b/haystack/agents/base.py
@@ -130,9 +130,13 @@ def __init__(
         :param tool_pattern: A regular expression pattern that matches the text that the Agent generates to invoke
             a tool.
         """
-        self._tools: Dict[str, Tool] = {tool.name: tool for tool in tools} if tools else {}
+        self._tools: Dict[str, Tool] = (
+            {tool.name: tool for tool in tools} if tools else {}
+        )
         self.tool_pattern = tool_pattern
-        self.callback_manager = Events(("on_tool_start", "on_tool_finish", "on_tool_error"))
+        self.callback_manager = Events(
+            ("on_tool_start", "on_tool_finish", "on_tool_error")
+        )
 
     @property
     def tools(self):
@@ -154,9 +158,13 @@ def get_tool_names_with_descriptions(self) -> str:
         """
         Returns a string with the names and descriptions of all registered tools.
""" - return "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools.values()]) + return "\n".join( + [f"{tool.name}: {tool.description}" for tool in self.tools.values()] + ) - def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) -> str: + def run_tool( + self, llm_response: str, params: Optional[Dict[str, Any]] = None + ) -> str: tool_result: str = "" if self.tools: tool_name, tool_input = self.extract_tool_name_and_tool_input(llm_response) @@ -178,7 +186,9 @@ def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) - raise e return tool_result - def extract_tool_name_and_tool_input(self, llm_response: str) -> Tuple[Optional[str], Optional[str]]: + def extract_tool_name_and_tool_input( + self, llm_response: str + ) -> Tuple[Optional[str], Optional[str]]: """ Parse the tool name and the tool input from the PromptNode response. :param llm_response: The PromptNode response. @@ -244,9 +254,19 @@ def __init__( self.max_steps = max_steps self.tm = tools_manager or ToolsManager() self.memory = memory or NoMemory() - self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_agent_final_answer", "on_new_token")) + self.callback_manager = Events( + ( + "on_agent_start", + "on_agent_step", + "on_agent_finish", + "on_agent_final_answer", + "on_new_token", + ) + ) self.prompt_node = prompt_node - prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react" + prompt_template = ( + prompt_template or prompt_node.default_prompt_template or "zero-shot-react" + ) resolved_prompt_template = prompt_node.get_prompt_template(prompt_template) if not resolved_prompt_template: raise ValueError( @@ -254,7 +274,9 @@ def __init__( ) self.prompt_template = resolved_prompt_template self.prompt_parameters_resolver = ( - prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver + prompt_parameters_resolver + if prompt_parameters_resolver + else react_parameter_resolver ) self.final_answer_pattern = final_answer_pattern self.add_default_logging_callbacks(streaming=streaming) @@ -268,13 +290,20 @@ def update_hash(self): See haystack/telemetry.py::send_event """ try: - tool_names = " ".join([tool.pipeline_or_node.__class__.__name__ for tool in self.tm.get_tools()]) + tool_names = " ".join( + [ + tool.pipeline_or_node.__class__.__name__ + for tool in self.tm.get_tools() + ] + ) self.hash = md5(tool_names.encode()).hexdigest() except Exception as exc: logger.debug("Telemetry exception: %s", str(exc)) self.hash = "[an exception occurred during hashing]" - def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None: + def add_default_logging_callbacks( + self, agent_color: Color = Color.GREEN, streaming: bool = False + ) -> None: def on_tool_finish( tool_output: str, color: Optional[Color] = None, @@ -292,13 +321,15 @@ def on_agent_start(**kwargs: Any) -> None: def on_agent_final_answer(final_answer: Dict[str, Any], **kwargs: Any) -> None: pass - + self.tm.callback_manager.on_tool_finish += on_tool_finish self.callback_manager.on_agent_start += on_agent_start self.callback_manager.on_agent_final_answer += on_agent_final_answer if streaming: - self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color) + self.callback_manager.on_new_token += lambda token, **kwargs: print_text( + token, color=agent_color + ) else: self.callback_manager.on_agent_step += lambda agent_step: print_text( 
                agent_step.prompt_node_response, end="\n", color=agent_color
             )
 
@@ -320,7 +351,8 @@ def add_tool(self, tool: Tool):
         """
         if tool.name in self.tm.tools:
             logger.warning(
-                "The agent already has a tool named '%s'. The new tool will overwrite the existing one.", tool.name
+                "The agent already has a tool named '%s'. The new tool will overwrite the existing one.",
+                tool.name,
             )
         self.tm.tools[tool.name] = tool
 
@@ -352,11 +384,15 @@ def run(
         try:
             if not self.hash == self.last_hash:
                 self.last_hash = self.hash
-                send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
+                send_event(
+                    event_name="Agent", event_properties={"llm.agent_hash": self.hash}
+                )
         except Exception as exc:
             logger.debug("Telemetry exception: %s", exc)
 
-        self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
+        self.callback_manager.on_agent_start(
+            name=self.prompt_template.name, query=query, params=params
+        )
         agent_step = self.create_agent_step(max_steps)
         try:
             while not agent_step.is_last():
@@ -376,10 +412,16 @@ def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = No
         self.callback_manager.on_agent_step(next_step)
 
         # run the tool selected by the LLM
-        observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None
+        observation = (
+            self.tm.run_tool(next_step.prompt_node_response, params)
+            if not next_step.is_last()
+            else None
+        )
 
         # save the input, output and observation to memory (if memory is enabled)
-        memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation)
+        memory_data = self.prepare_data_for_memory(
+            input=query, output=prompt_node_response, observation=observation
+        )
         self.memory.save(data=memory_data)
 
         # update the next step with the observation
@@ -388,7 +430,9 @@ def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = No
 
     def _plan(self, query, current_step):
         # first resolve prompt template params
-        template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)
+        template_params = self.prompt_parameters_resolver(
+            query=query, agent=self, agent_step=current_step
+        )
 
         # check for template parameters mismatch
         self.check_prompt_template(template_params)
@@ -405,14 +449,19 @@ def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep:
         """
         Create an AgentStep object. Override this method to customize the AgentStep class used by the Agent.
         """
-        return AgentStep(max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern)
+        return AgentStep(
+            max_steps=max_steps or self.max_steps,
+            final_answer_pattern=self.final_answer_pattern,
+        )
 
     def prepare_data_for_memory(self, **kwargs) -> dict:
         """
         Prepare data for saving to the Agent's memory. Override this method to customize the data saved to the memory.
         """
         return {
-            k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
+            k: v if isinstance(v, str) else next(iter(v))
+            for k, v in kwargs.items()
+            if isinstance(v, (str, Iterable))
         }
 
     def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
@@ -428,7 +477,9 @@ def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
 
         :param template_params: The parameters provided by the prompt parameter resolver.
""" - unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params) + unused_params = set(template_params.keys()) - set( + self.prompt_template.prompt_params + ) if "transcript" in unused_params: logger.warning( From 44ec6383f1ab9e42d39fb063afbc67c75d9dfac8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 11 Sep 2023 21:33:59 +0200 Subject: [PATCH 3/5] run pre-commit to format file --- haystack/agents/base.py | 109 +++++++++++----------------------------- 1 file changed, 30 insertions(+), 79 deletions(-) diff --git a/haystack/agents/base.py b/haystack/agents/base.py index 8539c07d2aa..bc44f8f3f7a 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -2,30 +2,30 @@ import logging import re -from collections.abc import Iterable, Callable +from collections.abc import Callable, Iterable from hashlib import md5 -from typing import List, Optional, Union, Dict, Any, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from events import Events -from haystack import Pipeline, BaseComponent, Answer, Document -from haystack.agents.memory import Memory, NoMemory -from haystack.telemetry import send_event +from haystack import Answer, BaseComponent, Document, Pipeline from haystack.agents.agent_step import AgentStep -from haystack.agents.types import Color, AgentTokenStreamingHandler +from haystack.agents.memory import Memory, NoMemory +from haystack.agents.types import AgentTokenStreamingHandler, Color from haystack.agents.utils import print_text, react_parameter_resolver -from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate +from haystack.nodes import BaseRetriever, PromptNode, PromptTemplate from haystack.pipelines import ( BaseStandardPipeline, - ExtractiveQAPipeline, DocumentSearchPipeline, + ExtractiveQAPipeline, + FAQPipeline, GenerativeQAPipeline, + RetrieverQuestionGenerationPipeline, SearchSummarizationPipeline, - FAQPipeline, TranslationWrapperPipeline, - RetrieverQuestionGenerationPipeline, WebQAPipeline, ) +from haystack.telemetry import send_event logger = logging.getLogger(__name__) @@ -130,13 +130,9 @@ def __init__( :param tool_pattern: A regular expression pattern that matches the text that the Agent generates to invoke a tool. """ - self._tools: Dict[str, Tool] = ( - {tool.name: tool for tool in tools} if tools else {} - ) + self._tools: Dict[str, Tool] = {tool.name: tool for tool in tools} if tools else {} self.tool_pattern = tool_pattern - self.callback_manager = Events( - ("on_tool_start", "on_tool_finish", "on_tool_error") - ) + self.callback_manager = Events(("on_tool_start", "on_tool_finish", "on_tool_error")) @property def tools(self): @@ -158,13 +154,9 @@ def get_tool_names_with_descriptions(self) -> str: """ Returns a string with the names and descriptions of all registered tools. 
""" - return "\n".join( - [f"{tool.name}: {tool.description}" for tool in self.tools.values()] - ) + return "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools.values()]) - def run_tool( - self, llm_response: str, params: Optional[Dict[str, Any]] = None - ) -> str: + def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) -> str: tool_result: str = "" if self.tools: tool_name, tool_input = self.extract_tool_name_and_tool_input(llm_response) @@ -186,9 +178,7 @@ def run_tool( raise e return tool_result - def extract_tool_name_and_tool_input( - self, llm_response: str - ) -> Tuple[Optional[str], Optional[str]]: + def extract_tool_name_and_tool_input(self, llm_response: str) -> Tuple[Optional[str], Optional[str]]: """ Parse the tool name and the tool input from the PromptNode response. :param llm_response: The PromptNode response. @@ -255,18 +245,10 @@ def __init__( self.tm = tools_manager or ToolsManager() self.memory = memory or NoMemory() self.callback_manager = Events( - ( - "on_agent_start", - "on_agent_step", - "on_agent_finish", - "on_agent_final_answer", - "on_new_token", - ) + ("on_agent_start", "on_agent_step", "on_agent_finish", "on_agent_final_answer", "on_new_token") ) self.prompt_node = prompt_node - prompt_template = ( - prompt_template or prompt_node.default_prompt_template or "zero-shot-react" - ) + prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react" resolved_prompt_template = prompt_node.get_prompt_template(prompt_template) if not resolved_prompt_template: raise ValueError( @@ -274,9 +256,7 @@ def __init__( ) self.prompt_template = resolved_prompt_template self.prompt_parameters_resolver = ( - prompt_parameters_resolver - if prompt_parameters_resolver - else react_parameter_resolver + prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver ) self.final_answer_pattern = final_answer_pattern self.add_default_logging_callbacks(streaming=streaming) @@ -290,20 +270,13 @@ def update_hash(self): See haystack/telemetry.py::send_event """ try: - tool_names = " ".join( - [ - tool.pipeline_or_node.__class__.__name__ - for tool in self.tm.get_tools() - ] - ) + tool_names = " ".join([tool.pipeline_or_node.__class__.__name__ for tool in self.tm.get_tools()]) self.hash = md5(tool_names.encode()).hexdigest() except Exception as exc: logger.debug("Telemetry exception: %s", str(exc)) self.hash = "[an exception occurred during hashing]" - def add_default_logging_callbacks( - self, agent_color: Color = Color.GREEN, streaming: bool = False - ) -> None: + def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN, streaming: bool = False) -> None: def on_tool_finish( tool_output: str, color: Optional[Color] = None, @@ -327,9 +300,7 @@ def on_agent_final_answer(final_answer: Dict[str, Any], **kwargs: Any) -> None: self.callback_manager.on_agent_final_answer += on_agent_final_answer if streaming: - self.callback_manager.on_new_token += lambda token, **kwargs: print_text( - token, color=agent_color - ) + self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color) else: self.callback_manager.on_agent_step += lambda agent_step: print_text( agent_step.prompt_node_response, end="\n", color=agent_color @@ -351,8 +322,7 @@ def add_tool(self, tool: Tool): """ if tool.name in self.tm.tools: logger.warning( - "The agent already has a tool named '%s'. 
The new tool will overwrite the existing one.", - tool.name, + "The agent already has a tool named '%s'. The new tool will overwrite the existing one.", tool.name ) self.tm.tools[tool.name] = tool @@ -384,15 +354,11 @@ def run( try: if not self.hash == self.last_hash: self.last_hash = self.hash - send_event( - event_name="Agent", event_properties={"llm.agent_hash": self.hash} - ) + send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash}) except Exception as exc: logger.debug("Telemetry exception: %s", exc) - self.callback_manager.on_agent_start( - name=self.prompt_template.name, query=query, params=params - ) + self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params) agent_step = self.create_agent_step(max_steps) try: while not agent_step.is_last(): @@ -412,16 +378,10 @@ def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = No self.callback_manager.on_agent_step(next_step) # run the tool selected by the LLM - observation = ( - self.tm.run_tool(next_step.prompt_node_response, params) - if not next_step.is_last() - else None - ) + observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None # save the input, output and observation to memory (if memory is enabled) - memory_data = self.prepare_data_for_memory( - input=query, output=prompt_node_response, observation=observation - ) + memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation) self.memory.save(data=memory_data) # update the next step with the observation @@ -430,9 +390,7 @@ def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = No def _plan(self, query, current_step): # first resolve prompt template params - template_params = self.prompt_parameters_resolver( - query=query, agent=self, agent_step=current_step - ) + template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step) # check for template parameters mismatch self.check_prompt_template(template_params) @@ -449,19 +407,14 @@ def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep: """ Create an AgentStep object. Override this method to customize the AgentStep class used by the Agent. """ - return AgentStep( - max_steps=max_steps or self.max_steps, - final_answer_pattern=self.final_answer_pattern, - ) + return AgentStep(max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern) def prepare_data_for_memory(self, **kwargs) -> dict: """ Prepare data for saving to the Agent's memory. Override this method to customize the data saved to the memory. """ return { - k: v if isinstance(v, str) else next(iter(v)) - for k, v in kwargs.items() - if isinstance(v, (str, Iterable)) + k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable)) } def check_prompt_template(self, template_params: Dict[str, Any]) -> None: @@ -477,9 +430,7 @@ def check_prompt_template(self, template_params: Dict[str, Any]) -> None: :param template_params: The parameters provided by the prompt parameter resolver. 
""" - unused_params = set(template_params.keys()) - set( - self.prompt_template.prompt_params - ) + unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params) if "transcript" in unused_params: logger.warning( From 99ea14ef4739f67fe9931363bd293963fc38b050 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 11 Sep 2023 21:41:39 +0200 Subject: [PATCH 4/5] updated release notes --- ...-on_agent_final_answer_to_agent_base-7798ea8de2f43af0.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/add-on_agent_final_answer_to_agent_base-7798ea8de2f43af0.yaml diff --git a/releasenotes/notes/add-on_agent_final_answer_to_agent_base-7798ea8de2f43af0.yaml b/releasenotes/notes/add-on_agent_final_answer_to_agent_base-7798ea8de2f43af0.yaml new file mode 100644 index 00000000000..347fe8a0be0 --- /dev/null +++ b/releasenotes/notes/add-on_agent_final_answer_to_agent_base-7798ea8de2f43af0.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + added support for using `on_final_answer` trough `Agent` `callback_manager` From 84c28821f074f6fa46e985eb1fde6550d631daf0 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 18 Sep 2023 14:45:02 -0700 Subject: [PATCH 5/5] reverted sorted imports --- haystack/agents/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/haystack/agents/base.py b/haystack/agents/base.py index bc44f8f3f7a..4d750532def 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -2,30 +2,30 @@ import logging import re -from collections.abc import Callable, Iterable +from collections.abc import Iterable, Callable from hashlib import md5 -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Union, Dict, Any, Tuple from events import Events -from haystack import Answer, BaseComponent, Document, Pipeline -from haystack.agents.agent_step import AgentStep +from haystack import Pipeline, BaseComponent, Answer, Document from haystack.agents.memory import Memory, NoMemory -from haystack.agents.types import AgentTokenStreamingHandler, Color +from haystack.telemetry import send_event +from haystack.agents.agent_step import AgentStep +from haystack.agents.types import Color, AgentTokenStreamingHandler from haystack.agents.utils import print_text, react_parameter_resolver -from haystack.nodes import BaseRetriever, PromptNode, PromptTemplate +from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate from haystack.pipelines import ( BaseStandardPipeline, - DocumentSearchPipeline, ExtractiveQAPipeline, - FAQPipeline, + DocumentSearchPipeline, GenerativeQAPipeline, - RetrieverQuestionGenerationPipeline, SearchSummarizationPipeline, + FAQPipeline, TranslationWrapperPipeline, + RetrieverQuestionGenerationPipeline, WebQAPipeline, ) -from haystack.telemetry import send_event logger = logging.getLogger(__name__)