Skip to content

Commit

Permalink
Merge pull request #67 from OkGoDoIt/feature/chat
Browse files Browse the repository at this point in the history
Chat endpoint improvements, tests, readme updates, etc
  • Loading branch information
OkGoDoIt committed Mar 9, 2023
2 parents 9c4c730 + 6e0ca11 commit 7b89ec6
Show file tree
Hide file tree
Showing 8 changed files with 624 additions and 116 deletions.
45 changes: 31 additions & 14 deletions OpenAI_API/Chat/ChatEndpoint.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OpenAI_API.Models;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
Expand All @@ -10,8 +11,8 @@ namespace OpenAI_API.Chat
/// <summary>
/// ChatGPT API endpoint. Use this endpoint to send multiple messages and carry on a conversation.
/// </summary>
public class ChatEndpoint : EndpointBase
{
public class ChatEndpoint : EndpointBase
{
/// <summary>
/// This allows you to set default parameters for every request, for example to set a default temperature or max tokens. For every request, if you do not have a parameter set on the request but do have it set here as a default, the request will automatically pick up the default value.
/// </summary>
Expand All @@ -28,14 +29,23 @@ public class ChatEndpoint : EndpointBase
/// <param name="api"></param>
internal ChatEndpoint(OpenAIAPI api) : base(api) { }

/// <summary>
/// Creates an ongoing chat which can easily encapsulate the conversation. This is the simplest way to use the Chat endpoint.
/// </summary>
/// <returns></returns>
public Conversation CreateConversation()
{
return new Conversation(this, defaultChatRequestArgs: DefaultChatRequestArgs);
}

#region Non-streaming

/// <summary>
/// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="request">The request to send to the API.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public async Task<ChatResult> CreateChatAsync(ChatRequest request)
public async Task<ChatResult> CreateChatCompletionAsync(ChatRequest request)
{
return await HttpPost<ChatResult>(postData: request);
}
Expand All @@ -46,10 +56,10 @@ public async Task<ChatResult> CreateChatAsync(ChatRequest request)
/// <param name="request">The request to send to the API.</param>
/// <param name="numOutputs">Overrides <see cref="ChatRequest.NumChoicesPerMessage"/> as a convenience.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
public Task<ChatResult> CreateChatCompletionAsync(ChatRequest request, int numOutputs = 5)
{
request.NumChoicesPerMessage = numOutputs;
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
Expand All @@ -66,15 +76,15 @@ public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
/// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
/// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
/// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages,
public Task<ChatResult> CreateChatCompletionAsync(IList<ChatMessage> messages,
Model model = null,
double? temperature = null,
double? top_p = null,
int? numOutputs = null,
int? max_tokens = null,
double? frequencyPenalty = null,
double? presencePenalty = null,
IReadOnlyDictionary<string, float> logitBias = null,
IReadOnlyDictionary<string, float> logitBias = null,
params string[] stopSequences)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
Expand All @@ -90,23 +100,30 @@ public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
PresencePenalty = presencePenalty ?? DefaultChatRequestArgs.PresencePenalty,
LogitBias = logitBias ?? DefaultChatRequestArgs.LogitBias
};
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
/// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// Ask the API to complete the request using the specified message(s). Any parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="messages">The messages to use in the generation.</param>
/// <returns></returns>
public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages)
/// <returns>The <see cref="ChatResult"/> with the API response.</returns>
public Task<ChatResult> CreateChatCompletionAsync(params ChatMessage[] messages)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
{
Messages = messages
};
return CreateChatAsync(request);
return CreateChatCompletionAsync(request);
}

/// <summary>
/// Ask the API to complete the request using the specified message(s). Any parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
/// </summary>
/// <param name="userMessages">The user message or messages to use in the generation. All strings are assumed to be of Role <see cref="ChatMessageRole.User"/></param>
/// <returns>The <see cref="ChatResult"/> with the API response.</returns>
public Task<ChatResult> CreateChatCompletionAsync(params string[] userMessages) => CreateChatCompletionAsync(userMessages.Select(m => new ChatMessage(ChatMessageRole.User, m)).ToArray());

#endregion

#region Streaming
Expand Down Expand Up @@ -168,15 +185,15 @@ public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(ChatRequest reques
/// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
/// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
/// <returns>An async enumerable with each of the results as they come in. See <see href="https://docs.microsoft.com/en-us/dotnet/csharp/whats-new/csharp-8#asynchronous-streams">the C# docs</see> for more details on how to consume an async enumerable.</returns>
public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(IEnumerable<ChatMessage> messages,
public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(IList<ChatMessage> messages,
Model model = null,
double? temperature = null,
double? top_p = null,
int? numOutputs = null,
int? max_tokens = null,
double? frequencyPenalty = null,
double? presencePenalty = null,
IReadOnlyDictionary<string, float> logitBias = null,
IReadOnlyDictionary<string, float> logitBias = null,
params string[] stopSequences)
{
ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
Expand Down
57 changes: 57 additions & 0 deletions OpenAI_API/Chat/ChatMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace OpenAI_API.Chat
{
/// <summary>
/// Chat message sent or received from the API. Includes who is speaking in the "role" and the message text in the "content"
/// </summary>
public class ChatMessage
	{
		/// <summary>
		/// Creates an empty <see cref="ChatMessage"/>, with <see cref="Role"/> defaulting to <see cref="ChatMessageRole.User"/>
		/// </summary>
		public ChatMessage() : this(ChatMessageRole.User, null)
		{
		}

		/// <summary>
		/// Creates a new Chat Message with the given role and text
		/// </summary>
		/// <param name="role">The role of the message, which can be "system", "assistant" or "user"</param>
		/// <param name="content">The text to send in the message</param>
		public ChatMessage(ChatMessageRole role, string content)
		{
			Role = role;
			Content = content;
		}

		// Backing value serialized to the API as the "role" string ("system", "user", or "assistant")
		[JsonProperty("role")]
		internal string rawRole { get; set; }

		/// <summary>
		/// The role of the message, which can be "system", "assistant" or "user".  Exposed as a strongly-typed wrapper around <see cref="rawRole"/>.
		/// </summary>
		[JsonIgnore]
		public ChatMessageRole Role
		{
			get => ChatMessageRole.FromString(rawRole);
			set => rawRole = value.ToString();
		}

		/// <summary>
		/// The content of the message
		/// </summary>
		[JsonProperty("content")]
		public string Content { get; set; }
	}
}
119 changes: 119 additions & 0 deletions OpenAI_API/Chat/ChatMessageRole.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.ComponentModel.Design;
using System.Text;

namespace OpenAI_API.Chat
{
/// <summary>
/// Represents the Role of a <see cref="ChatMessage"/>. Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages. See <see href="https://platform.openai.com/docs/guides/chat/introduction">the OpenAI docs</see> for more details about usage.
/// </summary>
public class ChatMessageRole : IEquatable<ChatMessageRole>
	{
		/// <summary>
		/// Constructor is private to force usage of the strongly typed singleton values (<see cref="System"/>, <see cref="User"/>, <see cref="Assistant"/>)
		/// </summary>
		/// <param name="value">The raw role string sent to the API</param>
		private ChatMessageRole(string value) { Value = value; }

		/// <summary>
		/// Gets the singleton instance of <see cref="ChatMessageRole"/> based on the string value.
		/// </summary>
		/// <param name="roleName">Must be one of "system", "user", or "assistant"</param>
		/// <returns>The matching singleton role, or null if <paramref name="roleName"/> is not recognized</returns>
		public static ChatMessageRole FromString(string roleName)
		{
			switch (roleName)
			{
				case "system":
					return ChatMessageRole.System;
				case "user":
					return ChatMessageRole.User;
				case "assistant":
					return ChatMessageRole.Assistant;
				default:
					return null;
			}
		}

		// The raw string value passed to the API ("system", "user", or "assistant")
		private string Value { get; }

		/// <summary>
		/// The system message helps set the behavior of the assistant.
		/// </summary>
		public static ChatMessageRole System { get; } = new ChatMessageRole("system");
		/// <summary>
		/// The user messages help instruct the assistant. They can be generated by the end users of an application, or set by a developer as an instruction.
		/// </summary>
		public static ChatMessageRole User { get; } = new ChatMessageRole("user");
		/// <summary>
		/// The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior.
		/// </summary>
		public static ChatMessageRole Assistant { get; } = new ChatMessageRole("assistant");

		/// <summary>
		/// Gets the string value for this role to pass to the API
		/// </summary>
		/// <returns>The role as a string</returns>
		public override string ToString()
		{
			return Value;
		}

		/// <summary>
		/// Determines whether this instance and a specified object have the same value.
		/// </summary>
		/// <param name="obj">The ChatMessageRole to compare to this instance</param>
		/// <returns>true if obj is a ChatMessageRole and its value is the same as this instance; otherwise, false. If obj is null, the method returns false</returns>
		public override bool Equals(object obj)
		{
			// Null-safe: the previous implementation dereferenced (obj as ChatMessageRole)
			// without checking, throwing NullReferenceException for null or foreign types.
			return obj is ChatMessageRole other && Value.Equals(other.Value);
		}

		/// <summary>
		/// Returns the hash code for this object
		/// </summary>
		/// <returns>A 32-bit signed integer hash code</returns>
		public override int GetHashCode()
		{
			return Value.GetHashCode();
		}

		/// <summary>
		/// Determines whether this instance and a specified object have the same value.
		/// </summary>
		/// <param name="other">The ChatMessageRole to compare to this instance</param>
		/// <returns>true if other's value is the same as this instance; otherwise, false. If other is null, the method returns false</returns>
		public bool Equals(ChatMessageRole other)
		{
			// Null-safe per the documented contract above.
			return other != null && Value.Equals(other.Value);
		}

		/// <summary>
		/// Gets the string value for this role to pass to the API
		/// </summary>
		/// <param name="value">The ChatMessageRole to convert</param>
		public static implicit operator String(ChatMessageRole value)
		{
			// Must return the underlying string: the previous body `return value;`
			// re-invoked this same implicit conversion, causing infinite recursion
			// and a StackOverflowException on first use.
			return value?.Value;
		}
	}
}
17 changes: 9 additions & 8 deletions OpenAI_API/Chat/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class ChatRequest
/// The messages to send with this Chat Request
/// </summary>
[JsonProperty("messages")]
public IEnumerable<ChatMessage> Messages { get; set; }
public IList<ChatMessage> Messages { get; set; }

/// <summary>
/// What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. It is generally recommend to use this or <see cref="TopP"/> but not both.
Expand Down Expand Up @@ -52,7 +52,7 @@ public class ChatRequest
/// This is only used for serializing the request into JSON, do not use it directly.
/// </summary>
[JsonProperty("stop")]
public object CompiledStop
internal object CompiledStop
{
get
{
Expand Down Expand Up @@ -109,9 +109,9 @@ public string StopSequence
/// Accepts a json object that maps tokens(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
/// Mathematically, the bias is added to the logits generated by the model prior to sampling.
/// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
/// </summary>
/// </summary>
[JsonProperty("logit_bias")]
public IReadOnlyDictionary<string, float> LogitBias { get; set; }
public IReadOnlyDictionary<string, float> LogitBias { get; set; }

/// <summary>
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
Expand All @@ -123,16 +123,17 @@ public string StopSequence
/// Creates a new, empty <see cref="ChatRequest"/>
/// </summary>
public ChatRequest()
{
this.Model = OpenAI_API.Models.Model.ChatGPTTurbo;
}
{ }

/// <summary>
/// Create a new chat request using the data from the input chat request.
/// </summary>
/// <param name="basedOn"></param>
public ChatRequest(ChatRequest basedOn)
{
{
if (basedOn == null)
return;

this.Model = basedOn.Model;
this.Messages = basedOn.Messages;
this.Temperature = basedOn.Temperature;
Expand Down
Loading

0 comments on commit 7b89ec6

Please sign in to comment.