diff --git a/requirements.txt b/requirements.txt index 207862ab..c007d922 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ markdown-it-py==2.2.0 mdurl==0.1.2 miservice-fork==2.1.1 multidict==6.0.4 -openai==0.27.2 +openai==0.27.4 prompt-toolkit==3.0.38 pygments==2.14.0 regex==2022.10.31 diff --git a/xiaogpt/bot/chatgptapi_bot.py b/xiaogpt/bot/chatgptapi_bot.py index 534f7de5..8ccca8aa 100644 --- a/xiaogpt/bot/chatgptapi_bot.py +++ b/xiaogpt/bot/chatgptapi_bot.py @@ -6,11 +6,16 @@ class ChatGPTBot(BaseBot): - def __init__(self, openai_key, api_base=None, proxy=None): + def __init__(self, openai_key, api_base=None, proxy=None, deployment_id=None): self.history = [] openai.api_key = openai_key if api_base: openai.api_base = api_base + # if api_base ends with openai.azure.com, then set api_type to azure + if api_base.endswith(("openai.azure.com/", "openai.azure.com")): + openai.api_type = "azure" + openai.api_version = "2023-03-15-preview" + self.deployment_id = deployment_id if proxy: openai.proxy = proxy @@ -21,6 +26,8 @@ async def ask(self, query, **options): ms.append({"role": "assistant", "content": h[1]}) ms.append({"role": "user", "content": f"{query}"}) kwargs = {"model": "gpt-3.5-turbo", **options} + if openai.api_type == "azure": + kwargs["deployment_id"] = self.deployment_id completion = await openai.ChatCompletion.acreate(messages=ms, **kwargs) message = ( completion["choices"][0] @@ -43,6 +50,8 @@ async def ask_stream(self, query, **options): ms.append({"role": "assistant", "content": h[1]}) ms.append({"role": "user", "content": f"{query}"}) kwargs = {"model": "gpt-3.5-turbo", **options} + if openai.api_type == "azure": + kwargs["deployment_id"] = self.deployment_id completion = await openai.ChatCompletion.acreate( messages=ms, stream=True, **kwargs ) diff --git a/xiaogpt/cli.py b/xiaogpt/cli.py index 4e46f49d..90993b1d 100644 --- a/xiaogpt/cli.py +++ b/xiaogpt/cli.py @@ -114,6 +114,12 @@ def main(): help="specify base url other than the OpenAI's official API address", ) + parser.add_argument( + "--deployment_id", + dest="deployment_id", + help="specify deployment id, only used when api_base points to azure", + ) + options = parser.parse_args() config = Config.from_options(options) diff --git a/xiaogpt/config.py b/xiaogpt/config.py index 244d6f0d..d7b985b4 100644 --- a/xiaogpt/config.py +++ b/xiaogpt/config.py @@ -68,6 +68,7 @@ class Config: bot: str = "chatgpt" cookie: str = "" api_base: str | None = None + deployment_id: str | None = None use_command: bool = False verbose: bool = False start_conversation: str = "开始持续对话" @@ -90,6 +91,14 @@ def __post_init__(self) -> None: ) elif not self.openai_key: raise Exception("Using GPT api needs openai API key, please google how to") + if ( + self.api_base.endswith(("openai.azure.com", "openai.azure.com/")) + and not self.deployment_id + ): + raise Exception( + "Using Azure OpenAI needs deployment_id, read this: " + "https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions" + ) @property def tts_command(self) -> str: diff --git a/xiaogpt/xiaogpt.py b/xiaogpt/xiaogpt.py index d10f9d4e..2fe991e6 100644 --- a/xiaogpt/xiaogpt.py +++ b/xiaogpt/xiaogpt.py @@ -185,7 +185,10 @@ def chatbot(self): ) elif self.config.bot == "chatgptapi": self._chatbot = ChatGPTBot( - self.config.openai_key, self.config.api_base, self.config.proxy + self.config.openai_key, + self.config.api_base, + self.config.proxy, + self.config.deployment_id, ) elif self.config.bot == "newbing": self._chatbot = NewBingBot(