Skip to content

Commit

Permalink
_stream_chat for GPTAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiangning30 committed Jul 11, 2024
1 parent d76969e commit ac321d8
Showing 1 changed file with 101 additions and 13 deletions.
114 changes: 101 additions & 13 deletions lagent/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
# from threading import Lock
from threading import Lock
from typing import Dict, List, Optional, Union

import requests
Expand Down Expand Up @@ -140,21 +140,22 @@ def stream_chat(
assert isinstance(inputs, list)
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
# gen_params = {**self.gen_params, **gen_params}
gen_params = self.update_gen_params(**gen_params)
gen_params['stream'] = True

resp = ''
finished = False
stop_words = gen_params.get('stop_words')
if stop_words is None:
stop_words = []
# mapping to role that openai supports
messages = inputs.copy()
for item in messages:
for role_cfg in self.meta_template:
if item['role'] == role_cfg['role']:
item['role'] = role_cfg['api_role']
break
for text in self._chat(messages, **gen_params):
for text in self._stream_chat(messages, **gen_params):
resp += text
if not resp:
continue
Expand All @@ -172,6 +173,102 @@ def stream_chat(
def _chat(self, messages: List[dict], **gen_params) -> str:
"""Generate completion from a list of templates.
Args:
messages (List[dict]): a list of prompt dictionaries
gen_params: additional generation configuration
Returns:
str: The generated string.
"""
assert isinstance(messages, list)
gen_params = gen_params.copy()

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
if max_tokens <= 0:
return ''

max_num_retries = 0
while max_num_retries < self.retry:
self._wait()

with Lock():
if len(self.invalid_keys) == len(self.keys):
raise RuntimeError('All keys have insufficient quota.')

# find the next valid key
while True:
self.key_ctr += 1
if self.key_ctr == len(self.keys):
self.key_ctr = 0

if self.keys[self.key_ctr] not in self.invalid_keys:
break

key = self.keys[self.key_ctr]

header = {
'Authorization': f'Bearer {key}',
'content-type': 'application/json',
}

if self.orgs:
with Lock():
self.org_ctr += 1
if self.org_ctr == len(self.orgs):
self.org_ctr = 0
header['OpenAI-Organization'] = self.orgs[self.org_ctr]

response = dict()
try:
gen_params_new = gen_params.copy()
data = dict(
model=self.model_type,
messages=messages,
max_tokens=max_tokens,
n=1,
stop=gen_params_new.pop('stop_words'),
frequency_penalty=gen_params_new.pop('repetition_penalty'),
**gen_params_new,
)
if self.json_mode:
data['response_format'] = {'type': 'json_object'}
raw_response = requests.post(
self.url,
headers=header,
data=json.dumps(data),
proxies=self.proxies)
response = raw_response.json()
return response['choices'][0]['message']['content'].strip()
except requests.ConnectionError:
print('Got connection error, retrying...')
continue
except requests.JSONDecodeError:
print('JsonDecode error, got', str(raw_response.content))
continue
except KeyError:
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
time.sleep(1)
continue
elif response['error']['code'] == 'insufficient_quota':
self.invalid_keys.add(key)
self.logger.warn(f'insufficient_quota key: {key}')
continue

print('Find error message in response: ',
str(response['error']))
except Exception as error:
print(str(error))
max_num_retries += 1

raise RuntimeError('Calling OpenAI failed after retrying for '
f'{max_num_retries} times. Check the logs for '
'details.')

def _stream_chat(self, messages: List[dict], **gen_params) -> str:
"""Generate completion from a list of templates.
Args:
messages (List[dict]): a list of prompt dictionaries
gen_params: additional generation configuration
Expand Down Expand Up @@ -205,9 +302,6 @@ def _stream_chat(raw_response):

max_num_retries = 0
while max_num_retries < self.retry:
# self._wait()

# with Lock():
if len(self.invalid_keys) == len(self.keys):
raise RuntimeError('All keys have insufficient quota.')

Expand All @@ -228,7 +322,6 @@ def _stream_chat(raw_response):
}

if self.orgs:
# with Lock():
self.org_ctr += 1
if self.org_ctr == len(self.orgs):
self.org_ctr = 0
Expand All @@ -253,12 +346,7 @@ def _stream_chat(raw_response):
headers=header,
data=json.dumps(data),
proxies=self.proxies)

if data.get('stream', False):
return _stream_chat(raw_response)
else:
response = raw_response.json()
return response['choices'][0]['message']['content'].strip()
return _stream_chat(raw_response)
except requests.ConnectionError:
print('Got connection error, retrying...')
continue
Expand Down

0 comments on commit ac321d8

Please sign in to comment.