OpenAI-DotNet 5.1.1 (RageAgainstThePixel#40)
- Refactored Model validation
- Added additional default models
- Deprecated `OpenAIClient.DefaultModel`
- Implemented chat completion streaming
- Refactored immutable types
StephenHodgson committed Mar 9, 2023
1 parent ee6cac0 commit 2cef872
Showing 44 changed files with 588 additions and 465 deletions.
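For orientation, here is a minimal sketch of what the headline changes mean for callers, assuming the public surface shown in the diffs below (`Model.Davinci` and `Model.GPT3_5_Turbo` appear in this commit; the full set of added default models is not shown here):

```csharp
using System;
using OpenAI;
using OpenAI.Models;

var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());

// OpenAIClient.DefaultModel is deprecated in this release; instead, pass a
// model per request, with each endpoint falling back to a sensible default.
var result = await api.CompletionsEndpoint.CreateCompletionAsync(
    "One Two Three Four Five",
    maxTokens: 5,
    model: Model.Davinci);

Console.WriteLine(result);
```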
34 changes: 17 additions & 17 deletions OpenAI-DotNet-Tests/TestFixture_02_Completions.cs
@@ -10,32 +10,31 @@ namespace OpenAI.Tests
{
internal class TestFixture_02_Completions
{
private readonly string prompts = "One Two Three Four Five Six Seven Eight Nine One Two Three Four Five Six Seven Eight";
private const string CompletionPrompts = "One Two Three Four Five Six Seven Eight Nine One Two Three Four Five Six Seven Eight";

[Test]
public async Task Test_01_GetBasicCompletion()
{
var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
Assert.IsNotNull(api.CompletionsEndpoint);

var result = await api.CompletionsEndpoint.CreateCompletionAsync(prompts, temperature: 0.1, maxTokens: 5, numOutputs: 5, model: Model.Davinci);
var result = await api.CompletionsEndpoint.CreateCompletionAsync(
CompletionPrompts,
temperature: 0.1,
maxTokens: 5,
numOutputs: 5,
model: Model.Davinci);
Assert.IsNotNull(result);
Assert.NotNull(result.Completions);
Assert.NotZero(result.Completions.Count);
Assert.That(result.Completions.Any(c => c.Text.Trim().ToLower().StartsWith("nine")));

foreach (var choice in result.Completions)
{
Console.WriteLine(choice);
}
Console.WriteLine(result);
}

[Test]
public async Task Test_02_GetStreamingCompletion()
{
var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
Assert.IsNotNull(api.CompletionsEndpoint);

var allCompletions = new List<Choice>();

await api.CompletionsEndpoint.StreamCompletionAsync(result =>
@@ -44,33 +43,34 @@ public async Task Test_02_GetStreamingCompletion()
Assert.NotNull(result.Completions);
Assert.NotZero(result.Completions.Count);
allCompletions.AddRange(result.Completions);
}, CompletionPrompts, temperature: 0.1, maxTokens: 5, numOutputs: 5);

foreach (var choice in result.Completions)
{
Console.WriteLine(choice);
}
}, prompts, temperature: 0.1, maxTokens: 5, numOutputs: 5, model: Model.Davinci);
Assert.That(allCompletions.Any(c => c.Text.Trim().ToLower().StartsWith("nine")));
Console.WriteLine(allCompletions.FirstOrDefault());
}

[Test]
public async Task Test_03_GetStreamingEnumerableCompletion()
{
var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
Assert.IsNotNull(api.CompletionsEndpoint);

var allCompletions = new List<Choice>();

await foreach (var result in api.CompletionsEndpoint.StreamCompletionEnumerableAsync(prompts, temperature: 0.1, maxTokens: 5, numOutputs: 5, model: Model.Davinci))
await foreach (var result in api.CompletionsEndpoint.StreamCompletionEnumerableAsync(
CompletionPrompts,
temperature: 0.1,
maxTokens: 5,
numOutputs: 5,
model: Model.Davinci))
{
Assert.IsNotNull(result);
Assert.NotNull(result.Completions);
Assert.NotZero(result.Completions.Count);
Console.WriteLine(result);
allCompletions.AddRange(result.Completions);
}

Assert.That(allCompletions.Any(c => c.Text.Trim().ToLower().StartsWith("nine")));
Console.WriteLine(allCompletions.FirstOrDefault());
}
}
}
52 changes: 52 additions & 0 deletions OpenAI-DotNet-Tests/TestFixture_03_Chat.cs
@@ -28,5 +28,57 @@ public async Task Test_1_GetChatCompletion()
Assert.NotZero(result.Choices.Count);
Console.WriteLine(result.FirstChoice);
}

[Test]
public async Task Test_2_GetChatStreamingCompletion()
{
var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
Assert.IsNotNull(api.ChatEndpoint);
var chatPrompts = new List<ChatPrompt>
{
new ChatPrompt("system", "You are a helpful assistant."),
new ChatPrompt("user", "Who won the world series in 2020?"),
new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."),
new ChatPrompt("user", "Where was it played?"),
};
var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo);
var allContent = new List<string>();

await api.ChatEndpoint.StreamCompletionAsync(chatRequest, result =>
{
Assert.IsNotNull(result);
Assert.NotNull(result.Choices);
Assert.NotZero(result.Choices.Count);
allContent.Add(result.FirstChoice);
});

Console.WriteLine(string.Join("", allContent));
}

[Test]
public async Task Test_3_GetChatStreamingCompletionEnumerableAsync()
{
var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
Assert.IsNotNull(api.ChatEndpoint);
var chatPrompts = new List<ChatPrompt>
{
new ChatPrompt("system", "You are a helpful assistant."),
new ChatPrompt("user", "Who won the world series in 2020?"),
new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."),
new ChatPrompt("user", "Where was it played?"),
};
var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo);
var allContent = new List<string>();

await foreach (var result in api.ChatEndpoint.StreamCompletionEnumerableAsync(chatRequest))
{
Assert.IsNotNull(result);
Assert.NotNull(result.Choices);
Assert.NotZero(result.Choices.Count);
allContent.Add(result.FirstChoice);
}

Console.WriteLine(string.Join("", allContent));
}
}
}
4 changes: 2 additions & 2 deletions OpenAI-DotNet/Audio/AudioTranscriptionRequest.cs
@@ -102,11 +102,11 @@ public sealed class AudioTranscriptionRequest

AudioName = audioName;

Model = model ?? new Model("whisper-1");
Model = model ?? Models.Model.Whisper1;

if (!Model.Contains("whisper"))
{
throw new ArgumentException(nameof(model), $"{Model} is not supported.");
throw new ArgumentException($"{Model} is not supported", nameof(model));
}

Prompt = prompt;
4 changes: 2 additions & 2 deletions OpenAI-DotNet/Audio/AudioTranslationRequest.cs
@@ -82,11 +82,11 @@ public sealed class AudioTranslationRequest

AudioName = audioName;

Model = model ?? new Model("whisper-1");
Model = model ?? Models.Model.Whisper1;

if (!Model.Contains("whisper"))
{
throw new ArgumentException(nameof(model), $"{Model} is not supported.");
throw new ArgumentException($"{Model} is not supported", nameof(model));
}

Prompt = prompt;
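A side note on the exception fix applied in both audio request files above: the two-string overload is `ArgumentException(string message, string paramName)`, so the original calls were passing `nameof(model)` as the message and the message as the parameter name. A minimal sketch of the corrected validation pattern (the `Contains` check mirrors the diff; `Model` stands in for the library's model type):

```csharp
using System;
using OpenAI.Models;

internal static class ModelValidation
{
    internal static void ValidateWhisperModel(Model model)
    {
        // ArgumentException(message, paramName): message first, parameter name second.
        if (!model.Contains("whisper"))
        {
            throw new ArgumentException($"{model} is not supported", nameof(model));
        }
    }
}
```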
93 changes: 89 additions & 4 deletions OpenAI-DotNet/Chat/ChatEndpoint.cs
@@ -1,4 +1,8 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
@@ -22,11 +26,92 @@ public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, Canc
{
var json = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
var payload = json.ToJsonStringContent();
var result = await Api.Client.PostAsync($"{GetEndpoint()}/completions", payload, cancellationToken);
var resultAsString = await result.ReadAsStringAsync(cancellationToken);
return JsonSerializer.Deserialize<ChatResponse>(resultAsString, Api.JsonSerializationOptions);
var response = await Api.Client.PostAsync($"{GetEndpoint()}/completions", payload, cancellationToken).ConfigureAwait(false);
var responseAsString = await response.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
return response.DeserializeResponse<ChatResponse>(responseAsString, Api.JsonSerializationOptions);
}

// TODO Streaming endpoints

/// <summary>
/// Creates a completion for the chat message and streams the results to the <paramref name="resultHandler"/> as they come in.
/// </summary>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="resultHandler">An action to be called as each new result arrives.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ChatResponse"/>.</returns>
/// <exception cref="HttpRequestException">Raised when the HTTP request fails</exception>
public async Task StreamCompletionAsync(ChatRequest chatRequest, Action<ChatResponse> resultHandler, CancellationToken cancellationToken = default)
{
chatRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions")
{
Content = jsonContent.ToJsonStringContent()
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
await response.CheckResponseAsync(cancellationToken);
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
using var reader = new StreamReader(stream);

while (await reader.ReadLineAsync() is { } line)
{
if (line.StartsWith("data: "))
{
line = line["data: ".Length..];
}

if (line == "[DONE]")
{
return;
}

if (!string.IsNullOrWhiteSpace(line))
{
resultHandler(response.DeserializeResponse<ChatResponse>(line.Trim(), Api.JsonSerializationOptions));
}
}
}

/// <summary>
/// Creates a completion for the chat message and streams the results as they come in.<br/>
/// If you are not using C# 8 (which supports IAsyncEnumerable{T}) or you are targeting the .NET Framework,
/// you may need to use <see cref="StreamCompletionAsync(ChatRequest, Action{ChatResponse}, CancellationToken)"/> instead.
/// </summary>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ChatResponse"/>.</returns>
/// <exception cref="HttpRequestException">Raised when the HTTP request fails</exception>
public async IAsyncEnumerable<ChatResponse> StreamCompletionEnumerableAsync(ChatRequest chatRequest, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
chatRequest.Stream = true;
var jsonContent = JsonSerializer.Serialize(chatRequest, Api.JsonSerializationOptions);
using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions")
{
Content = jsonContent.ToJsonStringContent()
};
var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
await response.CheckResponseAsync(cancellationToken);
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
using var reader = new StreamReader(stream);

while (await reader.ReadLineAsync() is { } line &&
!cancellationToken.IsCancellationRequested)
{
if (line.StartsWith("data: "))
{
line = line["data: ".Length..];
}

if (line == "[DONE]")
{
yield break;
}

if (!string.IsNullOrWhiteSpace(line))
{
yield return response.DeserializeResponse<ChatResponse>(line.Trim(), Api.JsonSerializationOptions);
}
}
}
}
}
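Pulling the new endpoint together, a hedged usage sketch of the two streaming entry points added above (the shapes mirror the test fixtures in this commit; this is an illustration, not canonical documentation). Both methods set `Stream = true` on the request and parse the server-sent `data:` lines until the `[DONE]` sentinel:

```csharp
using System;
using OpenAI;
using OpenAI.Chat;
using OpenAI.Models;

var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv());
var request = new ChatRequest(new[]
{
    new ChatPrompt("system", "You are a helpful assistant."),
    new ChatPrompt("user", "Say hello.")
}, Model.GPT3_5_Turbo);

// Callback style: each server-sent "data:" payload is deserialized
// into a ChatResponse and handed to the delegate as it arrives.
await api.ChatEndpoint.StreamCompletionAsync(request, response =>
{
    Console.Write(response.FirstChoice);
});

// IAsyncEnumerable style (C# 8+): the same stream, surfaced as an async sequence.
await foreach (var response in api.ChatEndpoint.StreamCompletionEnumerableAsync(request))
{
    Console.Write(response.FirstChoice);
}
```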
66 changes: 61 additions & 5 deletions OpenAI-DotNet/Chat/ChatRequest.cs
@@ -9,6 +9,62 @@ namespace OpenAI.Chat
{
public sealed class ChatRequest
{
/// <summary>
/// Constructor.
/// </summary>
/// <param name="messages">The list of chat prompt messages for the session.</param>
/// <param name="model">
/// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported.
/// </param>
/// <param name="temperature">
/// What sampling temperature to use, between 0 and 2.
/// Higher values like 0.8 will make the output more random, while lower values like 0.2 will
/// make it more focused and deterministic.
/// We generally recommend altering this or top_p but not both.<br/>
/// Defaults to 1
/// </param>
/// <param name="topP">
/// An alternative to sampling with temperature, called nucleus sampling,
/// where the model considers the results of the tokens with top_p probability mass.
/// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
/// We generally recommend altering this or temperature but not both.<br/>
/// Defaults to 1
/// </param>
/// <param name="number">
/// How many chat completion choices to generate for each input message.<br/>
/// Defaults to 1
/// </param>
/// <param name="stops">
/// Up to 4 sequences where the API will stop generating further tokens.
/// </param>
/// <param name="maxTokens">
/// The maximum number of tokens allowed for the generated answer.
/// By default, the number of tokens the model can return will be (4096 - prompt tokens).
/// </param>
/// <param name="presencePenalty">
/// Number between -2.0 and 2.0.
/// Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics.<br/>
/// Defaults to 0
/// </param>
/// <param name="frequencyPenalty">
/// Number between -2.0 and 2.0.
/// Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.<br/>
/// Defaults to 0
/// </param>
/// <param name="logitBias">
/// Modify the likelihood of specified tokens appearing in the completion.
/// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer)
/// to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits
/// generated by the model prior to sampling. The exact effect will vary per model, but values between
/// -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result
/// in a ban or exclusive selection of the relevant token.<br/>
/// Defaults to null
/// </param>
/// <param name="user">
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
/// </param>
public ChatRequest(
IEnumerable<ChatPrompt> messages,
Model model = null,
@@ -22,12 +78,11 @@
Dictionary<string, double> logitBias = null,
string user = null)
{
const string defaultModel = "gpt-3.5-turbo";
Model = model ?? Models.Model.GPT3_5_Turbo;

if (!Model.Contains(defaultModel))
if (!Model.Contains("turbo"))
{
throw new ArgumentException(nameof(model), $"{Model} not supported");
throw new ArgumentException($"{Model} is not supported", nameof(model));
}

Messages = messages?.ToList();
@@ -126,7 +181,8 @@ public sealed class ChatRequest
[JsonPropertyName("frequency_penalty")]
public double? FrequencyPenalty { get; }

/// <summary>Modify the likelihood of specified tokens appearing in the completion.
/// <summary>
/// Modify the likelihood of specified tokens appearing in the completion.
/// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer)
/// to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits
/// generated by the model prior to sampling. The exact effect will vary per model, but values between
@@ -135,7 +191,7 @@ public sealed class ChatRequest
/// Defaults to null
/// </summary>
[JsonPropertyName("logit_bias")]
public IReadOnlyDictionary<string, double> LogitBias { get; set; }
public IReadOnlyDictionary<string, double> LogitBias { get; }

/// <summary>
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
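To make the constructor documentation above concrete, a small illustrative sketch of building a `ChatRequest` with a few of the documented tunables (values are arbitrary; parameter names follow the constructor signature in this diff):

```csharp
using System.Collections.Generic;
using OpenAI.Chat;
using OpenAI.Models;

var request = new ChatRequest(
    messages: new List<ChatPrompt>
    {
        new ChatPrompt("system", "You are a terse assistant."),
        new ChatPrompt("user", "Summarize server-sent events in one sentence.")
    },
    model: Model.GPT3_5_Turbo,
    temperature: 0.2,    // lower temperature -> more deterministic output
    maxTokens: 64,       // cap on generated tokens; default is 4096 minus prompt tokens
    presencePenalty: 0,  // no extra push toward new topics
    user: "example-end-user");
```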