Skip to content

Commit

Permalink
feat: implemented new syntax leveraging IAsyncEnumerable
Browse files Browse the repository at this point in the history
Implemented a new syntax leveraging IAsyncEnumerable to let users choose between callbacks and stream enumerations.
  • Loading branch information
JerrettDavis committed Jun 4, 2024
1 parent 60b5fb8 commit b0016ad
Show file tree
Hide file tree
Showing 8 changed files with 385 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,4 @@ FodyWeavers.xsd

# JetBrains Rider
*.sln.iml
/.idea
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,40 @@ var models = await ollama.ListLocalModels();

### Pulling a model and reporting progress

#### Callback Syntax
```csharp
await ollama.PullModel("mistral", status => Console.WriteLine($"({status.Percent}%) {status.Status}"));
```

#### IAsyncEnumerable Syntax
```csharp
await foreach (var status in ollama.PullModel("mistral"))
{
Console.WriteLine($"({status.Percent}%) {status.Status}");
}
```

### Streaming a completion directly into the console

#### Callback Syntax
```csharp
// keep reusing the context to keep the chat topic going
ConversationContext context = null;
context = await ollama.StreamCompletion("How are you today?", context, stream => Console.Write(stream.Response));
```

#### IAsyncEnumerable Syntax
```csharp
// keep reusing the context to keep the chat topic going
ConversationContext context = null;
await foreach (var stream in ollama.StreamCompletion("How are you today?", context))
{
Console.Write(stream.Response);
context = stream.Context;
}
```


### Building interactive chats

```csharp
Expand Down
39 changes: 39 additions & 0 deletions src/IOllamaApiClient.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
using OllamaSharp.Models;
using OllamaSharp.Streamer;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using OllamaSharp.Models.Chat;
using System.Threading;

namespace OllamaSharp
{
/// <summary>
/// Interface for the Ollama API client.
/// </summary>
public interface IOllamaApiClient
{
/// <summary>
Expand Down Expand Up @@ -44,6 +48,14 @@ public interface IOllamaApiClient
/// <param name="cancellationToken">The token to cancel the operation with</param>
Task CreateModel(CreateModelRequest request, IResponseStreamer<CreateStatus> streamer, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/create endpoint to create a model
/// </summary>
/// <param name="request">The request object containing the model details</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>An asynchronous enumerable of the model creation status</returns>
// NOTE: [EnumeratorCancellation] is intentionally NOT applied here; it only has an
// effect on the async-iterator implementation, not on an interface declaration (CS8424).
IAsyncEnumerable<CreateStatus?> CreateModel(CreateModelRequest request, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/delete endpoint to delete a model
/// </summary>
Expand Down Expand Up @@ -92,6 +104,14 @@ public interface IOllamaApiClient
/// <param name="cancellationToken">The token to cancel the operation with</param>
Task PullModel(PullModelRequest request, IResponseStreamer<PullStatus> streamer, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/pull endpoint to pull a new model
/// </summary>
/// <param name="request">The request specifying the model name and whether to use insecure connection</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>Async enumerable of PullStatus objects representing the status of the model pull operation</returns>
// NOTE: [EnumeratorCancellation] is intentionally NOT applied here; it only has an
// effect on the async-iterator implementation, not on an interface declaration (CS8424).
IAsyncEnumerable<PullStatus?> PullModel(PullModelRequest request, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/push endpoint to push a new model
/// </summary>
Expand All @@ -103,6 +123,15 @@ public interface IOllamaApiClient
/// <param name="cancellationToken">The token to cancel the operation with</param>
Task PushModel(PushRequest request, IResponseStreamer<PushStatus> streamer, CancellationToken cancellationToken = default);

/// <summary>
/// Pushes a model to the Ollama API endpoint.
/// </summary>
/// <param name="request">The request containing the model information to push.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of push status updates. Use the enumerator to retrieve the push status updates.</returns>
// NOTE: [EnumeratorCancellation] is intentionally NOT applied here; it only has an
// effect on the async-iterator implementation, not on an interface declaration (CS8424).
IAsyncEnumerable<PushStatus?> PushModel(PushRequest request, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/show endpoint to show the information of a model
/// </summary>
Expand All @@ -127,6 +156,16 @@ public interface IOllamaApiClient
/// </returns>
Task<ConversationContext> StreamCompletion(GenerateCompletionRequest request, IResponseStreamer<GenerateCompletionResponseStream> streamer, CancellationToken cancellationToken = default);

/// <summary>
/// Streams completion responses from the /api/generate endpoint on the
/// Ollama API based on the provided request.
/// </summary>
/// <param name="request">The request containing the parameters for the completion.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of completion response streams.</returns>
// NOTE: the previous doc comment described a "streamer" parameter that this overload
// does not take; [EnumeratorCancellation] is also omitted because it only has an
// effect on the async-iterator implementation, not on an interface declaration (CS8424).
IAsyncEnumerable<GenerateCompletionResponseStream?> StreamCompletion(GenerateCompletionRequest request, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a query to check whether the Ollama api is running or not
/// </summary>
Expand Down
158 changes: 158 additions & 0 deletions src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
Expand Down Expand Up @@ -47,13 +48,28 @@ public OllamaApiClient(Configuration config)
/// <summary>
/// Creates a client that talks to the Ollama API through the given <see cref="HttpClient"/>.
/// </summary>
/// <param name="client">The HTTP client to use; its BaseAddress must be set to the Ollama endpoint.</param>
/// <param name="defaultModel">The model to select by default; may be empty.</param>
public OllamaApiClient(HttpClient client, string defaultModel = "")
{
    _client = client ?? throw new ArgumentNullException(nameof(client));

    // Derive the configuration from the client's base address and the requested model.
    var baseUri = client.BaseAddress ?? throw new ArgumentNullException(nameof(client.BaseAddress));
    Config = new Configuration
    {
        Uri = baseUri,
        Model = defaultModel
    };

    SelectedModel = defaultModel;
}

/// <summary>
/// Creates a model via /api/create, reporting status through the callback streamer.
/// </summary>
/// <param name="request">The request object containing the model details.</param>
/// <param name="streamer">Receives each creation status update as it arrives.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
public Task CreateModel(CreateModelRequest request, IResponseStreamer<CreateStatus> streamer, CancellationToken cancellationToken = default)
    => StreamPostAsync("api/create", request, streamer, cancellationToken);

/// <summary>
/// Creates a model via /api/create, yielding each status update as it is streamed.
/// </summary>
/// <param name="request">The request object containing the model details.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of the model creation status.</returns>
public async IAsyncEnumerable<CreateStatus?> CreateModel(
    CreateModelRequest request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Re-yield every status line produced by the generic streaming POST helper.
    await foreach (var status in StreamPostAsync<CreateModelRequest, CreateStatus?>("api/create", request, cancellationToken))
        yield return status;
}

public async Task DeleteModel(string model, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -93,11 +109,29 @@ public async Task PullModel(PullModelRequest request, IResponseStreamer<PullStat
{
await StreamPostAsync("api/pull", request, streamer, cancellationToken);
}

/// <summary>
/// Pulls a model via /api/pull, yielding each status update as it is streamed.
/// </summary>
/// <param name="request">The request specifying the model name and whether to use an insecure connection.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of the model pull status.</returns>
public async IAsyncEnumerable<PullStatus?> PullModel(
    PullModelRequest request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Re-yield every status line produced by the generic streaming POST helper.
    await foreach (var status in StreamPostAsync<PullModelRequest, PullStatus?>("api/pull", request, cancellationToken))
        yield return status;
}

/// <summary>
/// Pushes a model via /api/push, reporting status through the callback streamer.
/// </summary>
/// <param name="request">The request containing the model information to push.</param>
/// <param name="streamer">Receives each push status update as it arrives.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
public Task PushModel(PushRequest request, IResponseStreamer<PushStatus> streamer, CancellationToken cancellationToken = default)
    => StreamPostAsync("api/push", request, streamer, cancellationToken);

/// <summary>
/// Pushes a model via /api/push, yielding each status update as it is streamed.
/// </summary>
/// <param name="request">The request containing the model information to push.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of push status updates.</returns>
public async IAsyncEnumerable<PushStatus?> PushModel(
    PushRequest request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Re-yield every status line produced by the generic streaming POST helper.
    await foreach (var status in StreamPostAsync<PushRequest, PushStatus?>("api/push", request, cancellationToken))
        yield return status;
}

public async Task<GenerateEmbeddingResponse> GenerateEmbeddings(GenerateEmbeddingRequest request, CancellationToken cancellationToken = default)
{
Expand All @@ -108,6 +142,14 @@ public async Task<ConversationContext> StreamCompletion(GenerateCompletionReques
{
return await GenerateCompletion(request, streamer, cancellationToken);
}

/// <summary>
/// Streams completion responses from /api/generate for the provided request.
/// </summary>
/// <param name="request">The request containing the parameters for the completion.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of completion response chunks.</returns>
public async IAsyncEnumerable<GenerateCompletionResponseStream?> StreamCompletion(
    GenerateCompletionRequest request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Delegate to the private iterator and re-yield each streamed chunk.
    await foreach (var chunk in GenerateCompletion(request, cancellationToken))
        yield return chunk;
}

public async Task<ConversationContextWithResponse> GetCompletion(GenerateCompletionRequest request, CancellationToken cancellationToken = default)
{
Expand All @@ -130,6 +172,33 @@ public async Task<IEnumerable<Message>> SendChat(ChatRequest chatRequest, IRespo

return await ProcessStreamedChatResponseAsync(chatRequest, response, streamer, cancellationToken);
}

/// <summary>
/// Streams chat responses from the /api/chat endpoint for the given request.
/// </summary>
/// <param name="chatRequest">The chat request (model, message history, streaming flag).</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of chat response chunks.</returns>
public async IAsyncEnumerable<ChatResponseStream?> StreamChat(
    ChatRequest chatRequest,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // FIX: dispose the HttpRequestMessage (and its StringContent) when the call
    // completes; the previous version leaked it.
    using var request = new HttpRequestMessage(HttpMethod.Post, "api/chat")
    {
        Content = new StringContent(
            JsonSerializer.Serialize(chatRequest),
            Encoding.UTF8,
            "application/json")
    };

    // When streaming, begin reading as soon as headers arrive instead of
    // buffering the entire response body first.
    var completion = chatRequest.Stream
        ? HttpCompletionOption.ResponseHeadersRead
        : HttpCompletionOption.ResponseContentRead;

    using var response = await _client.SendAsync(
        request, completion, cancellationToken);

    response.EnsureSuccessStatusCode();

    var stream = ProcessStreamedChatResponseAsync(response, cancellationToken);

    await foreach (var result in stream)
        yield return result;
}

public async Task<bool> IsRunning(CancellationToken cancellationToken = default)
{
Expand All @@ -153,6 +222,30 @@ private async Task<ConversationContext> GenerateCompletion(GenerateCompletionReq

return await ProcessStreamedCompletionResponseAsync(response, streamer, cancellationToken);
}

/// <summary>
/// Posts the completion request to /api/generate and yields each streamed response chunk.
/// </summary>
/// <param name="generateRequest">The completion request to serialize and send.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of completion response chunks.</returns>
private async IAsyncEnumerable<GenerateCompletionResponseStream?> GenerateCompletion(GenerateCompletionRequest generateRequest, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    // FIX: dispose the HttpRequestMessage (and its StringContent) when the call
    // completes; the previous version leaked it.
    using var request = new HttpRequestMessage(HttpMethod.Post, "api/generate")
    {
        Content = new StringContent(
            JsonSerializer.Serialize(generateRequest),
            Encoding.UTF8,
            "application/json")
    };

    // When streaming, begin reading as soon as headers arrive instead of
    // buffering the entire response body first.
    var completion = generateRequest.Stream
        ? HttpCompletionOption.ResponseHeadersRead
        : HttpCompletionOption.ResponseContentRead;

    using var response = await _client.SendAsync(
        request, completion, cancellationToken);
    response.EnsureSuccessStatusCode();

    var stream = ProcessStreamedCompletionResponseAsync(response, cancellationToken);

    await foreach (var result in stream)
        yield return result;
}

private async Task<TResponse> GetAsync<TResponse>(string endpoint, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -194,6 +287,30 @@ private async Task PostAsync<TRequest>(string endpoint, TRequest request, Cancel

await ProcessStreamedResponseAsync(response, streamer, cancellationToken);
}

/// <summary>
/// Posts <paramref name="requestModel"/> as JSON to <paramref name="endpoint"/> and yields
/// each line of the streamed response deserialized as <typeparamref name="TResponse"/>.
/// </summary>
/// <param name="endpoint">Relative API endpoint, e.g. "api/pull".</param>
/// <param name="requestModel">The request payload to serialize.</param>
/// <param name="cancellationToken">The token to cancel the operation with.</param>
/// <returns>An asynchronous enumerable of deserialized response lines.</returns>
private async IAsyncEnumerable<TResponse?> StreamPostAsync<TRequest, TResponse>(string endpoint, TRequest requestModel, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    // FIX: dispose the HttpRequestMessage (and its StringContent) when the call
    // completes; the previous version leaked it.
    using var request = new HttpRequestMessage(HttpMethod.Post, endpoint)
    {
        Content = new StringContent(
            JsonSerializer.Serialize(requestModel),
            Encoding.UTF8,
            "application/json")
    };

    // ResponseHeadersRead lets us start yielding lines before the body completes.
    using var response = await _client.SendAsync(
        request,
        HttpCompletionOption.ResponseHeadersRead,
        cancellationToken);

    response.EnsureSuccessStatusCode();

    var stream = ProcessStreamedResponseAsync<TResponse>(response, cancellationToken);

    await foreach (var result in stream)
        yield return result;
}


private static async Task ProcessStreamedResponseAsync<TLine>(HttpResponseMessage response, IResponseStreamer<TLine> streamer, CancellationToken cancellationToken)
{
Expand All @@ -207,6 +324,18 @@ private static async Task ProcessStreamedResponseAsync<TLine>(HttpResponseMessag
streamer.Stream(streamedResponse);
}
}

/// <summary>
/// Reads the response body line by line (newline-delimited JSON) and yields each
/// line deserialized as <typeparamref name="TLine"/>.
/// </summary>
/// <param name="response">The HTTP response whose content stream is consumed.</param>
/// <param name="cancellationToken">Stops enumeration when cancellation is requested.</param>
private static async IAsyncEnumerable<TLine?> ProcessStreamedResponseAsync<TLine>(HttpResponseMessage response, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    // FIX: dispose the content stream when enumeration finishes; the previous
    // version only disposed the reader's wrapper. The sibling overloads
    // (completion/chat) already do this.
    using var stream = await response.Content.ReadAsStreamAsync();
    using var reader = new StreamReader(stream);

    while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
    {
        var line = await reader.ReadLineAsync();
        yield return JsonSerializer.Deserialize<TLine?>(line);
    }
}

private static async Task<ConversationContext> ProcessStreamedCompletionResponseAsync(HttpResponseMessage response, IResponseStreamer<GenerateCompletionResponseStream> streamer, CancellationToken cancellationToken)
{
Expand All @@ -228,6 +357,23 @@ private static async Task<ConversationContext> ProcessStreamedCompletionResponse

return new ConversationContext(Array.Empty<long>());
}

/// <summary>
/// Reads the /api/generate response line by line and yields each chunk. The final
/// line (Done == true) is surfaced as the richer done-type carrying statistics.
/// </summary>
/// <param name="response">The HTTP response whose content stream is consumed.</param>
/// <param name="cancellationToken">Stops enumeration when cancellation is requested.</param>
private static async IAsyncEnumerable<GenerateCompletionResponseStream?>
    ProcessStreamedCompletionResponseAsync(HttpResponseMessage response, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    using var body = await response.Content.ReadAsStreamAsync();
    using var reader = new StreamReader(body);

    while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
    {
        var json = await reader.ReadLineAsync();
        var chunk = JsonSerializer.Deserialize<GenerateCompletionResponseStream>(json);

        // The terminal line carries extra fields; re-deserialize it as the done-type.
        if (chunk?.Done == true)
            yield return JsonSerializer.Deserialize<GenerateCompletionDoneResponseStream>(json)!;
        else
            yield return chunk;
    }
}

private static async Task<IEnumerable<Message>> ProcessStreamedChatResponseAsync(ChatRequest chatRequest, HttpResponseMessage response, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -261,6 +407,18 @@ private static async Task<IEnumerable<Message>> ProcessStreamedChatResponseAsync

return Array.Empty<Message>();
}

/// <summary>
/// Reads the /api/chat response line by line (newline-delimited JSON) and yields
/// each line deserialized as a <see cref="ChatResponseStream"/>.
/// </summary>
/// <param name="response">The HTTP response whose content stream is consumed.</param>
/// <param name="cancellationToken">Stops enumeration when cancellation is requested.</param>
private static async IAsyncEnumerable<ChatResponseStream?> ProcessStreamedChatResponseAsync(HttpResponseMessage response, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    using var body = await response.Content.ReadAsStreamAsync();
    using var reader = new StreamReader(body);

    while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
    {
        var json = await reader.ReadLineAsync();
        yield return JsonSerializer.Deserialize<ChatResponseStream>(json);
    }
}
}

public record ConversationContext(long[] Context);
Expand Down
Loading

0 comments on commit b0016ad

Please sign in to comment.