From ec9f454ad151ef7b92ed68a4896024c2de6b5ad7 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Mon, 4 Sep 2023 00:55:43 +0530
Subject: [PATCH] System prompt at App level (#484)

Co-authored-by: Taranjeet Singh
---
 embedchain/apps/App.py           | 18 ++++++++++++++----
 embedchain/apps/CustomApp.py     |  8 ++++++--
 embedchain/apps/Llama2App.py     |  8 +++++---
 embedchain/apps/OpenSourceApp.py |  9 +++++----
 embedchain/embedchain.py         |  4 +++-
 tests/embedchain/test_query.py   | 19 +++++++++++++++++--
 6 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/embedchain/apps/App.py b/embedchain/apps/App.py
index 3e6537e362..03d9f73a51 100644
--- a/embedchain/apps/App.py
+++ b/embedchain/apps/App.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import openai
 
 from embedchain.config import AppConfig, ChatConfig
@@ -14,19 +16,27 @@ class App(EmbedChain):
         dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if config is None:
             config = AppConfig()
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
-        if config.system_prompt:
-            messages.append({"role": "system", "content": config.system_prompt})
+        system_prompt = (
+            self.system_prompt
+            if self.system_prompt is not None
+            else config.system_prompt
+            if config.system_prompt is not None
+            else None
+        )
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model=config.model or "gpt-3.5-turbo-0613",
diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py
index 21774bde9c..5478e124aa 100644
--- a/embedchain/apps/CustomApp.py
+++ b/embedchain/apps/CustomApp.py
@@ -18,10 +18,11 @@ class CustomApp(EmbedChain):
         dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __init__(self, config: CustomAppConfig = None):
+    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: Optional. `CustomAppConfig` instance to load as configuration.
         :raises ValueError: Config must be provided for custom app
+        :param system_prompt: Optional. System prompt string.
         """
         if config is None:
             raise ValueError("Config must be provided for custom app")
@@ -34,7 +35,7 @@ def __init__(self, config: CustomAppConfig = None):
         # Because these models run locally, they should have an instance running when the custom app is created
         self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def set_llm_model(self, provider: Providers):
         self.provider = provider
@@ -51,6 +52,9 @@ def get_llm_model_answer(self, prompt, config: ChatConfig):
                 "Streaming responses have not been implemented for this model yet. Please disable."
             )
 
+        if config.system_prompt is None and self.system_prompt is not None:
+            config.system_prompt = self.system_prompt
+
         try:
             if self.provider == Providers.OPENAI:
                 return CustomApp._get_openai_answer(prompt, config)
diff --git a/embedchain/apps/Llama2App.py b/embedchain/apps/Llama2App.py
index b9615cf83b..4ef3922097 100644
--- a/embedchain/apps/Llama2App.py
+++ b/embedchain/apps/Llama2App.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 from langchain.llms import Replicate
 
@@ -15,9 +16,10 @@ class Llama2App(EmbedChain):
         query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
@@ -25,11 +27,11 @@ def __init__(self, config: AppConfig = None):
         if config is None:
             config = AppConfig()
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig = None):
         # TODO: Move the model and other inputs into config
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
             model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
diff --git a/embedchain/apps/OpenSourceApp.py b/embedchain/apps/OpenSourceApp.py
index 12f7f8e6e9..a74433fca5 100644
--- a/embedchain/apps/OpenSourceApp.py
+++ b/embedchain/apps/OpenSourceApp.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Iterable, Union
+from typing import Iterable, Union, Optional
 
 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
@@ -18,10 +18,11 @@ class OpenSourceApp(EmbedChain):
         query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, config: OpenSourceAppConfig = None):
+    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: OpenSourceAppConfig instance to load as configuration. Optional.
        `ef` defaults to open source.
+        :param system_prompt: System prompt string. Optional.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
@@ -33,7 +34,7 @@ def __init__(self, config: OpenSourceAppConfig = None):
         self.instance = OpenSourceApp._get_instance(config.model)
         logging.info("Successfully loaded open source embedding model.")
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         return self._get_gpt4all_answer(prompt=prompt, config=config)
@@ -55,7 +56,7 @@ def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Ite
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
             )
 
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("OpenSourceApp does not support `system_prompt`")
 
         response = self.instance.generate(
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index d87f682031..d6ef82c4b6 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -33,15 +33,17 @@ class EmbedChain:
-    def __init__(self, config: BaseAppConfig):
+    def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
         :param config: BaseAppConfig instance to load as configuration.
+        :param system_prompt: Optional. System prompt string.
         """
         self.config = config
+        self.system_prompt = system_prompt
         self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
         self.db = self.config.db
         self.user_asks = []
diff --git a/tests/embedchain/test_query.py b/tests/embedchain/test_query.py
index da8f1b2603..46be142765 100644
--- a/tests/embedchain/test_query.py
+++ b/tests/embedchain/test_query.py
@@ -43,7 +43,7 @@ def test_query(self):
         mock_answer.assert_called_once()
 
     @patch("openai.ChatCompletion.create")
-    def test_query_config_passing(self, mock_create):
+    def test_query_config_app_passing(self, mock_create):
         mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
 
         config = AppConfig()
@@ -52,9 +52,24 @@ def test_query_config_passing(self, mock_create):
 
         app.get_llm_model_answer("Test query", chat_config)
 
-        # Test systemp_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")
 
         # TODO: Add tests for other config variables
+
+    @patch("openai.ChatCompletion.create")
+    def test_app_passing(self, mock_create):
+        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+
+        config = AppConfig()
+        chat_config = QueryConfig()
+        app = App(config=config, system_prompt="Test system prompt")
+
+        app.get_llm_model_answer("Test query", chat_config)
+
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        messages_arg = mock_create.call_args.kwargs["messages"]
+        self.assertEqual(messages_arg[0]["role"], "system")
+        self.assertEqual(messages_arg[0]["content"], "Test system prompt")
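Usage sketch (not part of the patch): a minimal example of the new app-level setting, mirroring the test_app_passing case added above. It assumes an OPENAI_API_KEY is configured in the environment, since App.get_llm_model_answer calls openai.ChatCompletion.create with the default "gpt-3.5-turbo-0613" model.

from embedchain import App
from embedchain.config import AppConfig, QueryConfig

# The system prompt is supplied once when the app is constructed; the EmbedChain
# base class stores it as self.system_prompt, so it applies to every call.
app = App(config=AppConfig(), system_prompt="Test system prompt")

# This per-call config carries no system_prompt of its own, so get_llm_model_answer
# resolves the app-level value and sends it as the first "system" chat message,
# followed by the user prompt.
answer = app.get_llm_model_answer("Test query", QueryConfig())
print(answer)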