Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat endpoint improvements, tests, readme updates, etc #67

Merged
merged 2 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Chat endpoint improvements, tests, readme updates, and an alternate C…
…onversation class
  • Loading branch information
OkGoDoIt committed Mar 9, 2023
commit c81ac58f17a1f1b4fe247b79d77e52c9d31438c1
45 changes: 31 additions & 14 deletions OpenAI_API/Chat/ChatEndpoint.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OpenAI_API.Models;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
Expand All @@ -10,8 +11,8 @@ namespace OpenAI_API.Chat
/// <summary>
/// ChatGPT API endpoint. Use this endpoint to send multiple messages and carry on a conversation.
/// </summary>
public class ChatEndpoint : EndpointBase
{
public class ChatEndpoint : EndpointBase
{
/// <summary>
/// This allows you to set default parameters for every request, for example to set a default temperature or max tokens. For every request, if you do not have a parameter set on the request but do have it set here as a default, the request will automatically pick up the default value.
/// </summary>
Expand All @@ -28,14 +29,23 @@ public class ChatEndpoint : EndpointBase
/// <param name="api"></param>
internal ChatEndpoint(OpenAIAPI api) : base(api) { }

/// <summary>
/// Creates an ongoing chat which can easily encapsulate the conversation. This is the simplest way to use the Chat endpoint.
/// </summary>
/// <returns></returns>
public Conversation CreateConversation()
{
return new Conversation(this, defaultChatRequestArgs: DefaultChatRequestArgs);
}

#region Non-streaming

/// <summary>
/// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="request">The request to send to the API.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public async Task<ChatResult> CreateChatAsync(ChatRequest request)
public async Task<ChatResult> CreateChatCompletionAsync(ChatRequest request)
{
return await HttpPost<ChatResult>(postData: request);
}
Expand All @@ -46,10 +56,10 @@ public async Task<ChatResult> CreateChatAsync(ChatRequest request)
/// <param name="request">The request to send to the API.</param>
/// <param name="numOutputs">Overrides <see cref="ChatRequest.NumChoicesPerMessage"/> as a convenience.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
public Task<ChatResult> CreateChatCompletionAsync(ChatRequest request, int numOutputs = 5)
{
request.NumChoicesPerMessage = numOutputs;
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
Expand All @@ -66,15 +76,15 @@ public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
/// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
/// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages,
public Task<ChatResult> CreateChatCompletionAsync(IList<ChatMessage> messages,
Model model = null,
double? temperature = null,
double? top_p = null,
int? numOutputs = null,
int? max_tokens = null,
double? frequencyPenalty = null,
double? presencePenalty = null,
IReadOnlyDictionary<string, float> logitBias = null,
IReadOnlyDictionary<string, float> logitBias = null,
params string[] stopSequences)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
Expand All @@ -90,23 +100,30 @@ public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
PresencePenalty = presencePenalty ?? DefaultChatRequestArgs.PresencePenalty,
LogitBias = logitBias ?? DefaultChatRequestArgs.LogitBias
};
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
/// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// Ask the API to complete the request using the specified message(s). Any parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="messages">The messages to use in the generation.</param>
/// <returns></returns>
public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages)
/// <returns>The <see cref="ChatResult"/> with the API response.</returns>
public Task<ChatResult> CreateChatCompletionAsync(params ChatMessage[] messages)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
{
Messages = messages
};
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
/// Ask the API to complete the request using the specified message(s). Any parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="userMessages">The user message or messages to use in the generation. All strings are assumed to be of Role <see cref="ChatMessageRole.User"/></param>
/// <returns>The <see cref="ChatResult"/> with the API response.</returns>
public Task<ChatResult> CreateChatCompletionAsync(params string[] userMessages) => CreateChatCompletionAsync(userMessages.Select(m => new ChatMessage(ChatMessageRole.User, m)).ToArray());

#endregion

#region Streaming
Expand Down Expand Up @@ -168,15 +185,15 @@ public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(ChatRequest reques
/// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
/// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
/// <returns>An async enumerable with each of the results as they come in. See <see href="https://docs.microsoft.com/en-us/dotnet/csharp/whats-new/csharp-8#asynchronous-streams">the C# docs</see> for more details on how to consume an async enumerable.</returns>
public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(IEnumerable<ChatMessage> messages,
public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(IList<ChatMessage> messages,
Model model = null,
double? temperature = null,
double? top_p = null,
int? numOutputs = null,
int? max_tokens = null,
double? frequencyPenalty = null,
double? presencePenalty = null,
IReadOnlyDictionary<string, float> logitBias = null,
IReadOnlyDictionary<string, float> logitBias = null,
params string[] stopSequences)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
Expand Down
119 changes: 119 additions & 0 deletions OpenAI_API/Chat/ChatMessageRole.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.ComponentModel.Design;
using System.Text;

namespace OpenAI_API.Chat
{
/// <summary>
/// Represents the Role of a <see cref="ChatMessage"/>. Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages. See <see href="https://platform.openai.com/docs/guides/chat/introduction">the OpenAI docs</see> for more details about usage.
/// </summary>
public class ChatMessageRole : IEquatable<ChatMessageRole>
{
/// <summary>
/// Contructor is prvate to force usage of strongly typed values
/// </summary>
/// <param name="value"></param>
private ChatMessageRole(string value) { Value = value; }

/// <summary>
/// Gets the singleton instance of <see cref="ChatMessageRole"/> based on the string value.
/// </summary>
/// <param name="roleName">Muse be one of "system", "user", or "assistant"</param>
/// <returns></returns>
public static ChatMessageRole FromString(string roleName)
{
switch (roleName)
{
case "system":
return ChatMessageRole.System;
case "user":
return ChatMessageRole.User;
case "assistant":
return ChatMessageRole.Assistant;
default:
return null;
}
}

private string Value { get; set; }

/// <summary>
/// The system message helps set the behavior of the assistant.
/// </summary>
public static ChatMessageRole System { get { return new ChatMessageRole("system"); } }
/// <summary>
/// The user messages help instruct the assistant. They can be generated by the end users of an application, or set by a developer as an instruction.
/// </summary>
public static ChatMessageRole User { get { return new ChatMessageRole("user"); } }
/// <summary>
/// The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior.
/// </summary>
public static ChatMessageRole Assistant { get { return new ChatMessageRole("assistant"); } }

/// <summary>
/// Gets the string value for this role to pass to the API
/// </summary>
/// <returns>The size as a string</returns>
public override string ToString()
{
return Value;
}

/// <summary>
/// Determines whether this instance and a specified object have the same value.
/// </summary>
/// <param name="obj">The ChatMessageRole to compare to this instance</param>
/// <returns>true if obj is a ChatMessageRole and its value is the same as this instance; otherwise, false. If obj is null, the method returns false</returns>
public override bool Equals(object obj)
{
return Value.Equals((obj as ChatMessageRole).Value);
}

/// <summary>
/// Returns the hash code for this object
/// </summary>
/// <returns>A 32-bit signed integer hash code</returns>
public override int GetHashCode()
{
return Value.GetHashCode();
}

/// <summary>
/// Determines whether this instance and a specified object have the same value.
/// </summary>
/// <param name="other">The ChatMessageRole to compare to this instance</param>
/// <returns>true if other's value is the same as this instance; otherwise, false. If other is null, the method returns false</returns>
public bool Equals(ChatMessageRole other)
{
return Value.Equals(other.Value);
}

/// <summary>
/// Gets the string value for this role to pass to the API
/// </summary>
/// <param name="value">The ChatMessageRole to convert</param>
public static implicit operator String(ChatMessageRole value) { return value; }

///// <summary>
///// Used during the Json serialization process
///// </summary>
//internal class ChatMessageRoleJsonConverter : JsonConverter<ChatMessageRole>
//{
// public override void WriteJson(JsonWriter writer, ChatMessageRole value, JsonSerializer serializer)
// {
// writer.WriteValue(value.ToString());
// }

// public override ChatMessageRole ReadJson(JsonReader reader, Type objectType, ChatMessageRole existingValue, bool hasExistingValue, JsonSerializer serializer)
// {
// if (reader.TokenType != JsonToken.String)
// {
// throw new JsonSerializationException();
// }
// return new ChatMessageRole(reader.ReadAsString());
// }
//}
}
}
17 changes: 9 additions & 8 deletions OpenAI_API/Chat/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class ChatRequest
/// The messages to send with this Chat Request
/// </summary>
[JsonProperty("messages")]
public IEnumerable<ChatMessage> Messages { get; set; }
public IList<ChatMessage> Messages { get; set; }

/// <summary>
/// What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. It is generally recommend to use this or <see cref="TopP"/> but not both.
Expand Down Expand Up @@ -52,7 +52,7 @@ public class ChatRequest
/// This is only used for serializing the request into JSON, do not use it directly.
/// </summary>
[JsonProperty("stop")]
public object CompiledStop
internal object CompiledStop
{
get
{
Expand Down Expand Up @@ -109,9 +109,9 @@ public string StopSequence
/// Accepts a json object that maps tokens(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
/// Mathematically, the bias is added to the logits generated by the model prior to sampling.
/// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
/// </summary>
/// </summary>
[JsonProperty("logit_bias")]
public IReadOnlyDictionary<string, float> LogitBias { get; set; }
public IReadOnlyDictionary<string, float> LogitBias { get; set; }

/// <summary>
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
Expand All @@ -123,16 +123,17 @@ public string StopSequence
/// Creates a new, empty <see cref="ChatRequest"/>
/// </summary>
public ChatRequest()
{
this.Model = OpenAI_API.Models.Model.ChatGPTTurbo;
}
{ }

/// <summary>
/// Create a new chat request using the data from the input chat request.
/// </summary>
/// <param name="basedOn"></param>
public ChatRequest(ChatRequest basedOn)
{
{
if (basedOn == null)
return;

this.Model = basedOn.Model;
this.Messages = basedOn.Messages;
this.Temperature = basedOn.Temperature;
Expand Down
Loading