Skip to content

Commit

Permalink
OpenAI-DotNet 6.0.0 (RageAgainstThePixel#43)
Browse files Browse the repository at this point in the history
- closes RageAgainstThePixel#24 Added support for Azure OpenAI
  • Loading branch information
StephenHodgson committed Mar 11, 2023
1 parent 6b876d5 commit ab4603e
Show file tree
Hide file tree
Showing 18 changed files with 177 additions and 129 deletions.
9 changes: 9 additions & 0 deletions OpenAI-DotNet-Tests/TestFixture_00_Authentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ public void Test_10_GetOrgFailed()
}
}

[Test]
public void Test_11_AzureConfigurationSettings()
{
var auth = new OpenAIAuthentication("sk-testAA", "org-testAA");
var settings = new OpenAIClientSettings("test-resource", "deployment-id-test");
var api = new OpenAIClient(auth, settings);
Console.WriteLine(api.BaseRequestUrl);
}

[TearDown]
public void TearDown()
{
Expand Down
7 changes: 3 additions & 4 deletions OpenAI-DotNet/Audio/AudioEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ public AudioResponse(string text)
public AudioEndpoint(OpenAIClient api) : base(api) { }

/// <inheritdoc />
protected override string GetEndpoint()
=> $"{Api.BaseUrl}audio";
protected override string Root => "audio";

/// <summary>
/// Transcribes audio into the input language.
Expand Down Expand Up @@ -64,7 +63,7 @@ public async Task<string> CreateTranscriptionAsync(AudioTranscriptionRequest req

request.Dispose();

var response = await Api.Client.PostAsync($"{GetEndpoint()}/transcriptions", content, cancellationToken);
var response = await Api.Client.PostAsync(GetUrl("/transcriptions"), content, cancellationToken);
var responseAsString = await response.ReadAsStringAsync(cancellationToken);

return responseFormat == AudioResponseFormat.Json
Expand Down Expand Up @@ -101,7 +100,7 @@ public async Task<string> CreateTranslationAsync(AudioTranslationRequest request

request.Dispose();

var response = await Api.Client.PostAsync($"{GetEndpoint()}/translations", content, cancellationToken);
var response = await Api.Client.PostAsync(GetUrl("/translations"), content, cancellationToken);
var responseAsString = await response.ReadAsStringAsync(cancellationToken);

return responseFormat == AudioResponseFormat.Json
Expand Down
14 changes: 8 additions & 6 deletions OpenAI-DotNet/BaseEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
{
public abstract class BaseEndPoint
{
protected BaseEndPoint(OpenAIClient api) => Api = api;

protected readonly OpenAIClient Api;

/// <summary>
/// Constructor of the api endpoint.
/// Rather than instantiating this yourself, access it through an instance of <see cref="OpenAIClient"/>.
/// The root endpoint address.
/// </summary>
internal BaseEndPoint(OpenAIClient api) => Api = api;
protected abstract string Root { get; }

/// <summary>
/// Gets the basic endpoint url for the API
/// Gets the full formatted url for the API endpoint.
/// </summary>
/// <returns>The completed basic url for the endpoint.</returns>
protected abstract string GetEndpoint();
/// <param name="endpoint">The endpoint url.</param>
protected string GetUrl(string endpoint = "")
=> string.Format(Api.BaseRequestUrl, $"{Root}{endpoint}");
}
}
22 changes: 11 additions & 11 deletions OpenAI-DotNet/Chat/ChatEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ namespace OpenAI.Chat
{
public sealed class ChatEndpoint : BaseEndPoint
{
/// <inheritdoc />
public ChatEndpoint(OpenAIClient api) : base(api) { }

protected override string GetEndpoint()
=> $"{Api.BaseUrl}chat";
/// <inheritdoc />
protected override string Root => "chat";

/// <summary>
/// Creates a completion for the chat message
Expand All @@ -24,9 +25,8 @@ protected override string GetEndpoint()
/// <returns><see cref="ChatResponse"/>.</returns>
public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, CancellationToken cancellationToken = default)
{
var json = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
var payload = json.ToJsonStringContent();
var response = await Api.Client.PostAsync($"{GetEndpoint()}/completions", payload, cancellationToken).ConfigureAwait(false);
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions).ToJsonStringContent();
var response = await Api.Client.PostAsync(GetUrl("/completions"), jsonContent, cancellationToken).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
return response.DeserializeResponse<ChatResponse>(responseAsString, Api.JsonSerializationOptions);
}
Expand All @@ -43,10 +43,10 @@ public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, Canc
public async Task StreamCompletionAsync(ChatRequest chatRequest, Action<ChatResponse> resultHandler, CancellationToken cancellationToken = default)
{
chatRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions")
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions).ToJsonStringContent();
using var request = new HttpRequestMessage(HttpMethod.Post, GetUrl("/completions"))
{
Content = jsonContent.ToJsonStringContent()
Content = jsonContent
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
await response.CheckResponseAsync(cancellationToken);
Expand Down Expand Up @@ -84,10 +84,10 @@ public async Task StreamCompletionAsync(ChatRequest chatRequest, Action<ChatResp
public async IAsyncEnumerable<ChatResponse> StreamCompletionEnumerableAsync(ChatRequest chatRequest, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
chatRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions")
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions).ToJsonStringContent();
using var request = new HttpRequestMessage(HttpMethod.Post, GetUrl("/completions"))
{
Content = jsonContent.ToJsonStringContent()
Content = jsonContent
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
await response.CheckResponseAsync(cancellationToken);
Expand Down
19 changes: 9 additions & 10 deletions OpenAI-DotNet/Completions/CompletionsEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ public sealed class CompletionsEndpoint : BaseEndPoint
internal CompletionsEndpoint(OpenAIClient api) : base(api) { }

/// <inheritdoc />
protected override string GetEndpoint()
=> $"{Api.BaseUrl}completions";
protected override string Root => "completions";

#region Non-streaming

Expand Down Expand Up @@ -107,8 +106,8 @@ protected override string GetEndpoint()
public async Task<CompletionResult> CreateCompletionAsync(CompletionRequest completionRequest, CancellationToken cancellationToken = default)
{
completionRequest.Stream = false;
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions);
var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent(), cancellationToken).ConfigureAwait(false);
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions).ToJsonStringContent();
var response = await Api.Client.PostAsync(GetUrl(), jsonContent, cancellationToken).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
return response.DeserializeResponse<CompletionResult>(responseAsString, Api.JsonSerializationOptions);
}
Expand Down Expand Up @@ -198,10 +197,10 @@ public async Task<CompletionResult> CreateCompletionAsync(CompletionRequest comp
public async Task StreamCompletionAsync(CompletionRequest completionRequest, Action<CompletionResult> resultHandler, CancellationToken cancellationToken = default)
{
completionRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, GetEndpoint())
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions).ToJsonStringContent();
using var request = new HttpRequestMessage(HttpMethod.Post, GetUrl())
{
Content = jsonContent.ToJsonStringContent()
Content = jsonContent
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
await response.CheckResponseAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -310,10 +309,10 @@ public async Task StreamCompletionAsync(CompletionRequest completionRequest, Act
public async IAsyncEnumerable<CompletionResult> StreamCompletionEnumerableAsync(CompletionRequest completionRequest, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
completionRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, GetEndpoint())
var jsonContent = JsonSerializer.Serialize(completionRequest, Api.JsonSerializationOptions).ToJsonStringContent();
using var request = new HttpRequestMessage(HttpMethod.Post, GetUrl())
{
Content = jsonContent.ToJsonStringContent()
Content = jsonContent
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
await response.CheckResponseAsync(cancellationToken).ConfigureAwait(false);
Expand Down
7 changes: 3 additions & 4 deletions OpenAI-DotNet/Edits/EditsEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ public sealed class EditsEndpoint : BaseEndPoint
public EditsEndpoint(OpenAIClient api) : base(api) { }

/// <inheritdoc />
protected override string GetEndpoint()
=> $"{Api.BaseUrl}edits";
protected override string Root => "edits";

/// <summary>
/// Creates a new edit for the provided input, instruction, and parameters
Expand Down Expand Up @@ -56,8 +55,8 @@ protected override string GetEndpoint()
/// <returns><see cref="EditResponse"/></returns>
public async Task<EditResponse> CreateEditAsync(EditRequest request)
{
var jsonContent = JsonSerializer.Serialize(request, Api.JsonSerializationOptions);
var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()).ConfigureAwait(false);
var jsonContent = JsonSerializer.Serialize(request, Api.JsonSerializationOptions).ToJsonStringContent();
var response = await Api.Client.PostAsync(GetUrl(), jsonContent).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync().ConfigureAwait(false);
return response.DeserializeResponse<EditResponse>(responseAsString, Api.JsonSerializationOptions);
}
Expand Down
7 changes: 3 additions & 4 deletions OpenAI-DotNet/Embeddings/EmbeddingsEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ public sealed class EmbeddingsEndpoint : BaseEndPoint
public EmbeddingsEndpoint(OpenAIClient api) : base(api) { }

/// <inheritdoc />
protected override string GetEndpoint()
=> $"{Api.BaseUrl}embeddings";
protected override string Root => "embeddings";

/// <summary>
/// Creates an embedding vector representing the input text.
Expand Down Expand Up @@ -62,8 +61,8 @@ public async Task<EmbeddingsResponse> CreateEmbeddingAsync(IEnumerable<string> i
/// <returns><see cref="EmbeddingsResponse"/></returns>
public async Task<EmbeddingsResponse> CreateEmbeddingAsync(EmbeddingsRequest request)
{
var jsonContent = JsonSerializer.Serialize(request, Api.JsonSerializationOptions);
var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()).ConfigureAwait(false);
var jsonContent = JsonSerializer.Serialize(request, Api.JsonSerializationOptions).ToJsonStringContent();
var response = await Api.Client.PostAsync(GetUrl(), jsonContent).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync().ConfigureAwait(false);
return response.DeserializeResponse<EmbeddingsResponse>(responseAsString, Api.JsonSerializationOptions);
}
Expand Down
13 changes: 6 additions & 7 deletions OpenAI-DotNet/Files/FilesEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ private class FileDeleteResponse
public FilesEndpoint(OpenAIClient api) : base(api) { }

/// <inheritdoc />
protected override string GetEndpoint()
=> $"{Api.BaseUrl}files";
protected override string Root => "files";

/// <summary>
/// Returns a list of files that belong to the user's organization.
Expand All @@ -41,7 +40,7 @@ protected override string GetEndpoint()
/// <exception cref="HttpRequestException"></exception>
public async Task<IReadOnlyList<FileData>> ListFilesAsync()
{
var response = await Api.Client.GetAsync(GetEndpoint()).ConfigureAwait(false);
var response = await Api.Client.GetAsync(GetUrl()).ConfigureAwait(false);
var resultAsString = await response.ReadAsStringAsync().ConfigureAwait(false);
return JsonSerializer.Deserialize<FilesList>(resultAsString, Api.JsonSerializationOptions)?.Data;
}
Expand Down Expand Up @@ -83,7 +82,7 @@ public async Task<FileData> UploadFileAsync(FileUploadRequest request, Cancellat
content.Add(new ByteArrayContent(fileData.ToArray()), "file", request.FileName);
request.Dispose();

var response = await Api.Client.PostAsync(GetEndpoint(), content, cancellationToken).ConfigureAwait(false);
var response = await Api.Client.PostAsync(GetUrl(), content, cancellationToken).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
return JsonSerializer.Deserialize<FileData>(responseAsString, Api.JsonSerializationOptions);
}
Expand All @@ -101,7 +100,7 @@ public async Task<bool> DeleteFileAsync(string fileId, CancellationToken cancell

async Task<bool> InternalDeleteFileAsync(int attempt)
{
var response = await Api.Client.DeleteAsync($"{GetEndpoint()}/{fileId}", cancellationToken).ConfigureAwait(false);
var response = await Api.Client.DeleteAsync(GetUrl($"/{fileId}"), cancellationToken).ConfigureAwait(false);
// We specifically don't use the extension method here bc we need to check if it's still processing the file.
var responseAsString = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -129,7 +128,7 @@ async Task<bool> InternalDeleteFileAsync(int attempt)
/// <exception cref="HttpRequestException"></exception>
public async Task<FileData> GetFileInfoAsync(string fileId)
{
var response = await Api.Client.GetAsync($"{GetEndpoint()}/{fileId}").ConfigureAwait(false);
var response = await Api.Client.GetAsync(GetUrl($"/{fileId}")).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync().ConfigureAwait(false);
return JsonSerializer.Deserialize<FileData>(responseAsString, Api.JsonSerializationOptions);
}
Expand Down Expand Up @@ -158,7 +157,7 @@ public async Task<string> DownloadFileAsync(string fileId, string directory, Can
/// <exception cref="ArgumentNullException"></exception>
public async Task<string> DownloadFileAsync(FileData fileData, string directory, CancellationToken cancellationToken = default)
{
await using var response = await Api.Client.GetStreamAsync($"{GetEndpoint()}/{fileData.Id}/content", cancellationToken).ConfigureAwait(false);
await using var response = await Api.Client.GetStreamAsync(GetUrl($"/{fileData.Id}/content"), cancellationToken).ConfigureAwait(false);

if (!Directory.Exists(directory))
{
Expand Down
Loading

0 comments on commit ab4603e

Please sign in to comment.