
Feat/0.2.1.2 (dataelement#234)
Support streaming output for host models
yaojin3616 authored Jan 3, 2024
2 parents 8e28582 + 78aa7a8 commit 7680c99
Showing 6 changed files with 129 additions and 175 deletions.
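
Context for the change: the host models stream completions as SSE-style lines prefixed with `data:`, and this commit teaches the chat classes to consume them. A minimal, self-contained sketch of that framing; the sample chunks, helper name, and field layout below are illustrative, not taken from the commit:

```python
import json
from typing import Iterable, Iterator


def iter_sse_json(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Split a raw byte stream on newlines and decode `data:`-prefixed JSON events."""
    buffer = b''
    for chunk in chunks:
        buffer += chunk
        while b'\n' in buffer:
            line, buffer = buffer.split(b'\n', 1)
            text = line.decode('utf-8').strip()
            if text.startswith('data:'):
                yield json.loads(text[len('data:'):])


# Illustrative chunks in the delta format the chat classes expect.
raw = [b'data: {"choices": [{"delta": {"content": "Hel"}}]}\n',
       b'data: {"choices": [{"delta": {"content": "lo"}}]}\n']
print(''.join(e['choices'][0]['delta']['content'] for e in iter_sse_json(raw)))  # Hello
```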
2 changes: 1 addition & 1 deletion src/backend/pyproject.toml
@@ -18,7 +18,7 @@ include = ["./bisheng/*", "bisheng/**/*"]
bisheng = "bisheng.__main__:main"

[tool.poetry.dependencies]
bisheng_langchain = "0.2.1"
bisheng_langchain = "0.2.1.2"
bisheng_pyautogen = "0.1.18"
minio = "^7.2.0"
fastapi_jwt_auth = "^0.5.0"
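
The pin above moves `bisheng_langchain` to the patch release carrying the streaming support. A generic way to confirm the installed version at runtime (not part of this commit):

```python
from importlib.metadata import PackageNotFoundError, version

try:
    print('bisheng_langchain', version('bisheng_langchain'))
except PackageNotFoundError:
    print('bisheng_langchain is not installed')
```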
242 changes: 95 additions & 147 deletions src/bisheng-langchain/bisheng_langchain/chat_models/host_llm.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import requests
import sseclient
from bisheng_langchain.utils.requests import Requests
from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import ChatGeneration, ChatResult
@@ -153,16 +153,28 @@ def validate_environment(cls, values: Dict) -> Dict:
values['host_base_url'] = get_from_dict_or_env(values, 'host_base_url', 'HostBaseUrl')
model = values['model_name']
try:
url = values['host_base_url'].split('/')[2]
config_ep = f'https://{url}/v2/models/{model}/config'
config = requests.get(url=config_ep, json={}, timeout=5).json()
policy = config.get('model_transaction_policy', {})
values['decoupled'] = policy.get('decoupled', False)
if cls != CustomLLMChat:
url = values['host_base_url'].split('/')[2]
config_ep = f'https://{url}/v2/models/{model}/config'
config = requests.get(url=config_ep, json={}, timeout=5).json()
policy = config.get('model_transaction_policy', {})
values['decoupled'] = policy.get('decoupled', False)
# Host-based chat classes resolve the full endpoint URL here
if values['decoupled']:
values[
'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/generate_stream"
else:
values[
'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/infer"
except Exception:
raise Exception(f'Update decoupled status failed for model {model}')

try:
values['client'] = requests.post
if values['headers']:
headers = values['headers']
else:
headers = {'Content-Type': 'application/json'}
values['client'] = Requests(headers=headers, request_timeout=values['request_timeout'])
except AttributeError:
raise ValueError('Try upgrading it with `pip install --upgrade requests`.')
return values
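
The validator now derives the final endpoint from the Triton-style model config: models running with a decoupled transaction policy can stream, so they get `generate_stream`; everything else keeps the blocking `infer` route. A standalone sketch of that decision, with `host_base_url` and the model name as placeholders:

```python
import requests

host_base_url = 'https://example.com:9001/v2.1/models'  # placeholder, not from the commit
model = 'my-hosted-model'                                # placeholder model name

# Ask the serving framework for the model's transaction policy.
host = host_base_url.split('/')[2]
config = requests.get(f'https://{host}/v2/models/{model}/config', json={}, timeout=5).json()
decoupled = config.get('model_transaction_policy', {}).get('decoupled', False)

# Decoupled models stream; everything else uses the blocking infer route.
endpoint = f"{host_base_url}/{model}/{'generate_stream' if decoupled else 'infer'}"
print('decoupled:', decoupled, '->', endpoint)
```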
@@ -185,6 +197,7 @@ def completion_with_retry(self, **kwargs: Any) -> Any:

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
self.client.headers = self.headers
messages = kwargs.get('messages')
temperature = kwargs.get('temperature')
top_p = kwargs.get('top_p')
@@ -204,17 +217,17 @@ def _completion_with_retry(**kwargs: Any) -> Any:
# print('messages:', messages)
# print('functions:', kwargs.get('functions', []))
if self.verbose:
print('payload', params)

method_name = 'infer' if not self.decoupled else 'generate'
url = f'{self.host_base_url}/{self.model_name}/{method_name}'
logger.info(f'payload={params}')
try:
resp = self.client(
url=url, json=params, timeout=self.request_timeout).json()
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm infer, url=[{url}]')
resp = self.client.post(url=self.host_base_url, json=params)
if resp.text.startswith('data:'):
resp = json.loads(resp.text.replace('data:', ''))
else:
resp = resp.json()
except requests.exceptions.Timeout as exc:
raise ValueError(f'timeout in host llm infer, url=[{self.host_base_url}]') from exc
except Exception as e:
raise Exception(f'exception in host llm infer: [{e}]')
raise ValueError(f'exception in host llm infer: [{e}]') from e

if not resp.get('choices', []):
logger.info(resp)
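
The synchronous call above now goes through the shared `Requests` wrapper and accepts either response shape: a plain JSON body from `infer`, or a single `data:`-prefixed event from a decoupled endpoint. The branch in isolation; the helper name and sample payloads are made up:

```python
import json


def parse_host_response(body: str) -> dict:
    """Accept either a plain JSON body or a single `data:`-prefixed SSE event."""
    text = body.strip()
    if text.startswith('data:'):
        return json.loads(text[len('data:'):])
    return json.loads(text)


print(parse_host_response('{"choices": [{"message": {"role": "assistant", "content": "hi"}}]}'))
print(parse_host_response('data: {"choices": [{"delta": {"content": "hi"}}]}'))
```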
@@ -249,63 +262,46 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
'''Handle the synchronous request path.'''
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)

def _stream(self, **kwargs: Any) -> Any:
async def acompletion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(self)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
if self.streaming:
if not self.decoupled:
raise Exception('Streaming is not supported for non-decoupled models')

headers = {'Accept': 'text/event-stream'}
url = f'{self.host_base_url}/{self.model_name}/generate_stream'
try:
res = requests.post(
url=url,
data=json.dumps(kwargs),
headers=headers,
stream=False)
except Exception as e:
raise Exception(f'exception in host llm sse infer: [{e}]')

res.raise_for_status()
try:
client = sseclient.SSEClient(res, timeout=self.request_timeout)
for event in client.events():
delta_data = json.loads(event.data)
yield delta_data
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm sse infer, url=[{url}]')
except Exception as e:
raise Exception(f'exception in host llm sse infer: [{e}]')
else:
method_name = 'infer' if not self.decoupled else 'generate'
url = f'{self.host_base_url}/{self.model_name}/{method_name}'
try:
res = requests.post(
url=url,
data=json.dumps(kwargs),
stream=False,
timeout=self.request_timeout)
return res.json()
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm infer, url=[{url}]')
except Exception as e:
raise Exception(f'exception in host llm infer: [{e}]')

if self.streaming:
for response in _completion_with_retry(**kwargs):
if response:
yield response
else:
return _completion_with_retry(**kwargs)
async def _acompletion_with_retry(**kwargs: Any) -> Any:
try:
async with self.client.apost(url=self.host_base_url, json=kwargs) as response:
if response.status != 200:
raise ValueError(f'Error: {response.status}')
async for txt in response.content.iter_any():
if b'\n' in txt:
for txt_ in txt.split(b'\n'):
yield txt_.decode('utf-8').strip()
else:
yield txt.decode('utf-8').strip()
except requests.exceptions.Timeout as exc:
raise ValueError(f'timeout in host llm infer, url=[{self.host_base_url}]') from exc
except Exception as e:
raise ValueError(f'exception in host llm infer: [{e}]') from e

async for response in _acompletion_with_retry(**kwargs):
is_error = False
if response:
if response.startswith('event:error'):
is_error = True
elif response.startswith('data:'):
yield (is_error, response[len('data:'):])
if is_error:
break
elif response.startswith('{'):
yield (is_error, response)
else:
continue
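
The async generator above tags every line before `_agenerate` sees it: an `event:error` line flips the error flag, `data:` lines are stripped to their JSON payload, and bare JSON objects pass through. A synchronous sketch of the same classification (the helper name is hypothetical):

```python
from typing import Iterable, Iterator, Tuple


def classify_stream_lines(lines: Iterable[str]) -> Iterator[Tuple[bool, str]]:
    """Yield (is_error, json_payload) pairs, mirroring the prefix handling above."""
    is_error = False
    for line in lines:
        if line.startswith('event:error'):
            is_error = True          # the next data: line carries the error payload
        elif line.startswith('data:'):
            yield is_error, line[len('data:'):].strip()
            if is_error:
                break
        elif line.startswith('{'):
            yield is_error, line
        # blank lines and keep-alives are skipped


for flag, payload in classify_stream_lines(['data: {"x": 1}', '{"y": 2}', '']):
    print(flag, payload)
```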

async def _agenerate(
self,
@@ -314,41 +310,49 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if not self.decoupled:
return self._generate(messages, stop, run_manager, **kwargs)

"""Generate chat completion with retry."""
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ''
role = 'assistant'
params['stream'] = True
function_call: Optional[dict] = None
for stream_resp in self._stream(
messages=message_dicts, **params
):
role = stream_resp['choices'][0]['delta'].get('role', role)
token = stream_resp['choices'][0]['delta'].get('content', '')
inner_completion += token or ''
_function_call = stream_resp['choices'][0]['delta'].get('function_call')
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call['arguments'] += _function_call['arguments']
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
'content': inner_completion,
'role': role,
'function_call': function_call,
}
)
async for is_error, stream_resp in self.acompletion_with_retry(messages=message_dicts,
**params):
output = json.loads(stream_resp)
if is_error:
logger.error(stream_resp)
raise ValueError(stream_resp)

choices = output.get('choices')
if choices:
for choice in choices:
role = choice['delta'].get('role', role)
token = choice['delta'].get('content', '')
inner_completion += token or ''
_function_call = choice['delta'].get('function_call')
if run_manager:
await run_manager.on_llm_new_token(token)
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call['arguments'] += _function_call['arguments']
message = _convert_dict_to_message({
'content': inner_completion,
'role': role,
'function_call': function_call,
})
return ChatResult(generations=[ChatGeneration(message=message)])
else:
params['stream'] = False
response = self._stream(messages=message_dicts, **params)
response = [
response
async for _, response in self.acompletion_with_retry(messages=message_dicts,
**params)
]
response = json.loads(response[0])
return self._create_chat_result(response)
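
In streaming mode `_agenerate` folds the incremental `delta` chunks into one assistant message, concatenating `content` tokens and merging `function_call` arguments. The accumulation on its own, with made-up chunks:

```python
# Made-up delta chunks in the shape produced by the streaming endpoint.
chunks = [
    {'delta': {'role': 'assistant', 'content': 'It is '}},
    {'delta': {'content': 'sunny.'}},
    {'delta': {'function_call': {'name': 'get_weather', 'arguments': '{"city": '}}},
    {'delta': {'function_call': {'arguments': '"Paris"}'}}},
]

completion, role, function_call = '', 'assistant', None
for choice in chunks:
    delta = choice['delta']
    role = delta.get('role', role)
    completion += delta.get('content', '') or ''
    fc = delta.get('function_call')
    if fc:
        if function_call is None:
            function_call = dict(fc)
        else:
            function_call['arguments'] += fc['arguments']

print(role, repr(completion), function_call)
```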

def _create_message_dicts(
@@ -373,7 +377,7 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
gen = ChatGeneration(message=message)
generations.append(gen)

llm_output = {'token_usage': response['usage'], 'model_name': self.model_name}
llm_output = {'token_usage': response.get('usage'), 'model_name': self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)
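
For reference, `_create_chat_result` maps each returned choice onto a LangChain generation. A rough equivalent using only public LangChain types; the response layout is assumed from the code above, and `my-hosted-model` is a placeholder:

```python
from langchain.schema import AIMessage, ChatGeneration, ChatResult

response = {'choices': [{'message': {'role': 'assistant', 'content': 'hello'}}],
            'usage': {'total_tokens': 5}}

generations = [ChatGeneration(message=AIMessage(content=c['message']['content']))
               for c in response['choices']]
result = ChatResult(generations=generations,
                    llm_output={'token_usage': response.get('usage'),
                                'model_name': 'my-hosted-model'})  # placeholder name
print(result.llm_output)
```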

@property
@@ -525,65 +529,9 @@ class CustomLLMChat(BaseHostChatLLM):
temperature: float = 0.1
top_p: float = 0.1
max_tokens: int = 4096
host_base_url: str

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return 'custom_llm_chat'
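
A hedged usage sketch for `CustomLLMChat` after this commit, with the endpoint and model id as placeholders; the exact set of required constructor fields depends on base-class defaults not shown in this diff:

```python
from bisheng_langchain.chat_models.host_llm import CustomLLMChat
from langchain.schema import HumanMessage

llm = CustomLLMChat(
    host_base_url='http://127.0.0.1:8000/v1/chat/completions',  # placeholder endpoint
    model_name='my-custom-model',                               # placeholder model id
    temperature=0.1,
    max_tokens=1024,
)
print(llm([HumanMessage(content='Hello')]).content)
```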

def completion_with_retry(self, **kwargs: Any) -> Any:
retry_decorator = _create_retry_decorator(self)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
messages = kwargs.get('messages')
temperature = kwargs.get('temperature')
top_p = kwargs.get('top_p')
max_tokens = kwargs.get('max_tokens')
do_sample = kwargs.get('do_sample')
params = {
'messages': messages,
'model': self.model_name,
'top_p': top_p,
'temperature': temperature,
'max_tokens': max_tokens,
'do_sample': do_sample
}

if self.verbose:
print('payload', params)

resp = None
try:
resp = self.client(
url=self.host_base_url,
json=params,
timeout=self.request_timeout).json()
except requests.exceptions.Timeout:
raise Exception(
f'timeout in custom host llm infer, url=[{self.host_base_url}]')
except Exception as e:
raise Exception(f'exception in custom host llm infer: [{e}]')

return resp

return _completion_with_retry(**kwargs)

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response['choices']:
message = _convert_dict_to_message(res['message'])
gen = ChatGeneration(message=message)
generations.append(gen)

llm_output = {'token_usage': response.get('usage', {}), 'model_name': self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return self._generate(messages, stop, run_manager, **kwargs)
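
All of the request paths above are wrapped by `_create_retry_decorator`, whose body is not shown in this diff. A plausible tenacity-based stand-in, with the back-off parameters and retried exception types as assumptions:

```python
import requests
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


def create_retry_decorator(max_retries: int = 6):
    """Stand-in for `_create_retry_decorator`: exponential back-off on transient errors."""
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((requests.exceptions.Timeout,
                                       requests.exceptions.ConnectionError)),
    )


@create_retry_decorator(max_retries=3)
def flaky_call() -> dict:
    # Replace with the real HTTP request; raising here triggers the retry policy.
    raise requests.exceptions.ConnectionError('transient failure')
```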
