
Commit

feat: add streaming support for OpenAI models (mem0ai#202)
aaishikdutta committed Jul 10, 2023
1 parent 13bac72 commit 66c4d30
Showing 3 changed files with 58 additions and 6 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -204,6 +204,19 @@ from embedchain import PersonApp as ECPApp
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
```
### Stream Response

- You can configure the app to stream query responses the way ChatGPT does. You will need a downstream handler to render each chunk in your desired format.

- To use this, instantiate `App` with an `InitConfig` instance that sets `stream_response=True`. The following example iterates through the chunks and prints them as they arrive.
```python
naval_chat_bot = App(InitConfig(stream_response=True))
resp = naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?")

for chunk in resp:
print(chunk, end="", flush=True)
# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
```
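The `chat` method in this commit also returns a generator when streaming is enabled (see the `embedchain.py` diff below), so the same consumption pattern applies; a minimal sketch, assuming the same streaming config:
```python
naval_chat_bot = App(InitConfig(stream_response=True))

# chat() yields chunks instead of returning a single string when streaming is on
for chunk in naval_chat_bot.chat("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"):
    print(chunk, end="", flush=True)
```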

### Chat Interface

6 changes: 5 additions & 1 deletion embedchain/config/InitConfig.py
@@ -6,7 +6,7 @@ class InitConfig(BaseConfig):
"""
Config to initialize an embedchain `App` instance.
"""
def __init__(self, ef=None, db=None):
def __init__(self, ef=None, db=None, stream_response=False):
"""
:param ef: Optional. Embedding function to use.
:param db: Optional. (Vector) database to use for embeddings.
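        :param stream_response: Optional. Whether to stream the LLM response. Defaults to False.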
@@ -27,6 +27,10 @@ def __init__(self, ef=None, db=None):
self.db = ChromaDB(ef=self.ef)
else:
self.db = db

if not isinstance(stream_response, bool):
            raise ValueError("`stream_response` should be a boolean")
self.stream_response = stream_response

return
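A minimal usage sketch of the new flag and its validation, assuming `InitConfig` is exported from `embedchain.config` (the package containing this file):
```python
from embedchain import App
from embedchain.config import InitConfig

# Enable streaming for every query made through this app instance.
app = App(InitConfig(stream_response=True))

# Anything other than a bool is rejected at construction time.
try:
    InitConfig(stream_response="yes")
except ValueError as err:
    print(err)  # `stream_response` should be a boolean
```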

45 changes: 40 additions & 5 deletions embedchain/embedchain.py
@@ -164,8 +164,8 @@ def get_answer_from_llm(self, prompt):
:param context: Similar documents to the query used as context.
:return: The answer.
"""
answer = self.get_llm_model_answer(prompt)
return answer

return self.get_llm_model_answer(prompt)

def query(self, input_query, config: QueryConfig = None):
"""
@@ -226,8 +226,20 @@ def chat(self, input_query, config: ChatConfig = None):
)
answer = self.get_answer_from_llm(prompt)
memory.chat_memory.add_user_message(input_query)
memory.chat_memory.add_ai_message(answer)
return answer
if isinstance(answer, str):
memory.chat_memory.add_ai_message(answer)
return answer
else:
            # This is a streamed response and needs to be handled differently
return self._stream_chat_response(answer)

def _stream_chat_response(self, answer):
streamed_answer = ""
for chunk in answer:
            streamed_answer += chunk
yield chunk
memory.chat_memory.add_ai_message(streamed_answer)
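The generator above forwards each chunk to the caller while accumulating the full answer, and only writes to the conversation memory once the stream is exhausted; a standalone sketch of that pattern with a stubbed memory object (names here are illustrative, not part of the codebase):
```python
class FakeMemory:
    """Stand-in for the module-level ConversationBufferMemory used by chat()."""
    def __init__(self):
        self.messages = []

    def add_ai_message(self, text):
        self.messages.append(text)


def stream_and_record(chunks, memory):
    # Yield each chunk to the caller while building up the full answer.
    full_answer = ""
    for chunk in chunks:
        full_answer += chunk
        yield chunk
    # Runs only after the caller has consumed the whole stream.
    memory.add_ai_message(full_answer)


memory = FakeMemory()
for piece in stream_and_record(["Naval ", "argues ", "humans can understand explanations."], memory):
    print(piece, end="", flush=True)
print()
print(memory.messages)  # ['Naval argues humans can understand explanations.']
```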


def dry_run(self, input_query, config: QueryConfig = None):
"""
@@ -284,6 +296,13 @@ def __init__(self, config: InitConfig = None):
super().__init__(config)

def get_llm_model_answer(self, prompt):
stream_response = self.config.stream_response
if stream_response:
return self._stream_llm_model_response(prompt)
else:
return self._get_llm_model_response(prompt)

    def _get_llm_model_response(self, prompt, stream_response=False):
messages = []
messages.append({
"role": "user", "content": prompt
@@ -294,8 +313,24 @@ def get_llm_model_answer(self, prompt):
temperature=0,
max_tokens=1000,
top_p=1,
stream=stream_response
)
return response["choices"][0]["message"]["content"]

if stream_response:
# This contains the entire completions object. Needs to be sanitised
return response
else:
return response["choices"][0]["message"]["content"]

def _stream_llm_model_response(self, prompt):
"""
        A generator that streams the response from the OpenAI completions API.
"""
response = self._get_llm_model_response(prompt, True)
for line in response:
chunk = line['choices'][0].get('delta', {}).get('content', '')
yield chunk
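For reference, each item in the streamed response carries a `delta` that may or may not include a `content` key (the first chunk typically holds only the role and the last is empty), which is why the `.get(...)` defaults above are needed; a sketch against stubbed chunks shaped like those payloads:
```python
# Stubbed chunks shaped like the streamed chat-completion payloads the generator above expects.
stub_stream = [
    {"choices": [{"delta": {"role": "assistant"}}]},          # first chunk: role only, no content
    {"choices": [{"delta": {"content": "Hello"}}]},
    {"choices": [{"delta": {"content": " world"}}]},
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},    # final chunk: empty delta
]

text = ""
for line in stub_stream:
    text += line["choices"][0].get("delta", {}).get("content", "")
print(text)  # Hello world
```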



class OpenSourceApp(EmbedChain):
Expand Down
