Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Azure OpenAI Support #222

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add Azure OpenAI support
  • Loading branch information
changeforan committed Apr 14, 2023
commit 559a476c5f23f44c5e0acedefc5e9165aec991ee
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ markdown-it-py==2.2.0
mdurl==0.1.2
miservice-fork==2.1.1
multidict==6.0.4
openai==0.27.2
openai==0.27.4
prompt-toolkit==3.0.38
pygments==2.14.0
regex==2022.10.31
Expand Down
11 changes: 10 additions & 1 deletion xiaogpt/bot/chatgptapi_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@


class ChatGPTBot(BaseBot):
def __init__(self, openai_key, api_base=None, proxy=None):
def __init__(self, openai_key, api_base=None, proxy=None, deployment_id=None):
self.history = []
openai.api_key = openai_key
if api_base:
openai.api_base = api_base
# if api_base ends with openai.azure.com, then set api_type to azure
if api_base.endswith(("openai.azure.com/", "openai.azure.com")):
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
self.deployment_id = deployment_id
if proxy:
openai.proxy = proxy

Expand All @@ -21,6 +26,8 @@ async def ask(self, query, **options):
ms.append({"role": "assistant", "content": h[1]})
ms.append({"role": "user", "content": f"{query}"})
kwargs = {"model": "gpt-3.5-turbo", **options}
if openai.api_type == "azure":
kwargs["deployment_id"] = self.deployment_id
completion = await openai.ChatCompletion.acreate(messages=ms, **kwargs)
message = (
completion["choices"][0]
Expand All @@ -43,6 +50,8 @@ async def ask_stream(self, query, **options):
ms.append({"role": "assistant", "content": h[1]})
ms.append({"role": "user", "content": f"{query}"})
kwargs = {"model": "gpt-3.5-turbo", **options}
if openai.api_type == "azure":
kwargs["deployment_id"] = self.deployment_id
changeforan marked this conversation as resolved.
Show resolved Hide resolved
completion = await openai.ChatCompletion.acreate(
messages=ms, stream=True, **kwargs
)
Expand Down
6 changes: 6 additions & 0 deletions xiaogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def main():
help="specify base url other than the OpenAI's official API address",
)

parser.add_argument(
"--deployment_id",
dest="deployment_id",
help="specify deployment id, only used when api_base points to azure",
)

options = parser.parse_args()
config = Config.from_options(options)

Expand Down
9 changes: 9 additions & 0 deletions xiaogpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Config:
bot: str = "chatgpt"
cookie: str = ""
api_base: str | None = None
deployment_id: str | None = None
use_command: bool = False
verbose: bool = False
start_conversation: str = "开始持续对话"
Expand All @@ -90,6 +91,14 @@ def __post_init__(self) -> None:
)
elif not self.openai_key:
raise Exception("Using GPT api needs openai API key, please google how to")
if (
self.api_base.endswith(("openai.azure.com", "openai.azure.com/"))
and not self.deployment_id
):
raise Exception(
"Using Azure OpenAI needs deployment_id, read this: "
"https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions"
)

@property
def tts_command(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion xiaogpt/xiaogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def chatbot(self):
)
elif self.config.bot == "chatgptapi":
self._chatbot = ChatGPTBot(
self.config.openai_key, self.config.api_base, self.config.proxy
self.config.openai_key,
self.config.api_base,
self.config.proxy,
self.config.deployment_id,
)
elif self.config.bot == "newbing":
self._chatbot = NewBingBot(
Expand Down