Skip to content

Commit

Permalink
more formatting and small fixes for subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemurray committed May 26, 2024
1 parent 7d72f43 commit 5c1f88f
Show file tree
Hide file tree
Showing 18 changed files with 2,275 additions and 1,983 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public class PersonArgConstructor

- Fix issue where a subscription execution had access to a disposed `IServiceProvider`
- `Broadcaster` is thread safe when removing observers
- Fixes in the implementation of the [GraphQL over Websockets](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
- ID is now a `string` as it does not specify that it must be a `Guid`
- Better errors on invalid messages

# 5.3.0

Expand Down
1 change: 1 addition & 0 deletions src/EntityGraphQL.AspNet/EntityGraphQL.AspNet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<AssemblyName>EntityGraphQL.AspNet</AssemblyName>
<PackageId>EntityGraphQL.AspNet</PackageId>
<LangVersion>12</LangVersion>
<PackageVersion>5.4.0</PackageVersion>
<Description>Contains ASP.NET extensions and middleware for EntityGraphQL</Description>
<Authors>Luke Murray</Authors>
Expand Down
44 changes: 22 additions & 22 deletions src/EntityGraphQL.AspNet/WebSockets/GraphQLWSMessage.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
using System;
using System.Collections.Generic;

namespace EntityGraphQL.AspNet.WebSockets
namespace EntityGraphQL.AspNet.WebSockets;

public class BaseGraphQLWSResponse
{
public string Type { get; set; } = string.Empty;
}

public class BaseWithIdGraphQLWSResponse : BaseGraphQLWSResponse
{
public class TypeOnlyGraphQLWSResponse
{
public string? Type { get; set; }
}
public string? Id { get; set; }
}

public class WithIdGraphQLWSResponse : TypeOnlyGraphQLWSResponse
{
public Guid? Id { get; set; }
}
public class GraphQLWSRequest : WithIdGraphQLWSResponse
{
public QueryRequest? Payload { get; set; }
}
public class GraphQLWSRequest : BaseWithIdGraphQLWSResponse
{
public QueryRequest? Payload { get; set; }
}

public class GraphQLWSResponse : WithIdGraphQLWSResponse
{
public QueryResult? Payload { get; set; }
}
public class GraphQLWSResponse : BaseWithIdGraphQLWSResponse
{
public QueryResult? Payload { get; set; }
}

public class GraphQLWSError : WithIdGraphQLWSResponse
{
public List<GraphQLError>? Payload { get; set; }
}
}
public class GraphQLWSError : BaseWithIdGraphQLWSResponse
{
public List<GraphQLError>? Payload { get; set; }
}
42 changes: 20 additions & 22 deletions src/EntityGraphQL.AspNet/WebSockets/GraphQLWebSocketServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class GraphQLWebSocketServer<TQueryType> : IGraphQLWebSocketServer
/// <summary>
/// These are the subscriptions/clients that are currently active with this server.
/// </summary>
private readonly Dictionary<Guid, IDisposable> subscriptions = new();
private readonly Dictionary<string, IDisposable> subscriptions = new();
private readonly WebSocket webSocket;
private readonly ExecutionOptions options;
private bool initialised;
Expand Down Expand Up @@ -94,7 +94,10 @@ private async Task HandleMessageAsync(GraphQLWSRequest graphQLWSMessage)
await HandleSubscribeAsync(graphQLWSMessage);
break;
case GraphQLWSMessageType.Complete:
CompleteSubscription(graphQLWSMessage.Id!.Value);
if (graphQLWSMessage.Id == null)
await CloseConnectionAsync((WebSocketCloseStatus)4400, "Invalid complete message, missing id field.");
else
await CompleteSubscriptionAsync(graphQLWSMessage.Id);
break;
case GraphQLWSMessageType.Pong:
break; // can come to us but we don't care
Expand All @@ -112,7 +115,7 @@ private async Task HandleSubscribeAsync(GraphQLWSRequest graphQLWSMessage)
return;
}

if (!graphQLWSMessage.Id.HasValue)
if (graphQLWSMessage.Id == null)
await CloseConnectionAsync((WebSocketCloseStatus)4400, "Invalid subscribe message, missing id field.");
else if (graphQLWSMessage.Payload == null)
await CloseConnectionAsync((WebSocketCloseStatus)4400, "Invalid subscribe message, missing payload field.");
Expand All @@ -128,57 +131,51 @@ await CloseConnectionAsync(
return;
}

var schemaContext = Context.RequestServices.GetService<TQueryType>();
if (schemaContext == null)
await CloseConnectionAsync(
(WebSocketCloseStatus)4400,
$"No schema context was found in the service collection. Make sure the {typeof(TQueryType).Name} used with MapGraphQL<{typeof(TQueryType).Name}>() is registered in the service collection."
);
else if (subscriptions.ContainsKey(graphQLWSMessage.Id.Value))
await CloseConnectionAsync((WebSocketCloseStatus)4409, $"Subscriber for {graphQLWSMessage.Id.Value} already exists");
if (subscriptions.ContainsKey(graphQLWSMessage.Id))
await CloseConnectionAsync((WebSocketCloseStatus)4409, $"Subscriber for {graphQLWSMessage.Id} already exists");
else
{
var request = graphQLWSMessage.Payload;
// executing this sets up the observers etc. We don't return any data until we have an event
var result = await schema.ExecuteRequestWithContextAsync(request, schemaContext, Context.RequestServices, Context.User, options)!;
var result = await schema.ExecuteRequestAsync(request, Context.RequestServices, Context.User, options)!;
if (result.Errors != null)
{
await SendErrorAsync(graphQLWSMessage.Id!.Value, result.Errors);
await SendErrorAsync(graphQLWSMessage.Id, result.Errors);
}
// No error and a successful subscribe operation
if (result.Data?.Values.First() is GraphQLSubscribeResult subscribeResult)
{
var websocketSubscription = (IDisposable)
Activator.CreateInstance(
typeof(WebSocketSubscription<,>).MakeGenericType(typeof(TQueryType), subscribeResult!.EventType),
graphQLWSMessage.Id!.Value,
graphQLWSMessage.Id,
subscribeResult!.GetObservable(),
this,
subscribeResult!.SubscriptionStatement,
subscribeResult!.Field
)!;
subscriptions.Add(graphQLWSMessage.Id!.Value, websocketSubscription);
subscriptions.Add(graphQLWSMessage.Id, websocketSubscription);
}
else
{
// Assume it is a query or mutation over websockets
if (result.Errors == null)
{
await SendNextAsync(graphQLWSMessage.Id!.Value, result);
await SendNextAsync(graphQLWSMessage.Id, result);
}
// send complete after next or error above
await SendAsync(new WithIdGraphQLWSResponse { Type = GraphQLWSMessageType.Complete, Id = graphQLWSMessage.Id!.Value, });
await SendAsync(new BaseWithIdGraphQLWSResponse { Type = GraphQLWSMessageType.Complete, Id = graphQLWSMessage.Id, });
}
}
}
}

public async Task SendErrorAsync(Guid id, Exception exception)
public async Task SendErrorAsync(string id, Exception exception)
{
await SendErrorAsync(id, new List<GraphQLError> { new GraphQLError(exception.Message, null) });
}

public Task SendErrorAsync(Guid id, IEnumerable<GraphQLError> errors)
public Task SendErrorAsync(string id, IEnumerable<GraphQLError> errors)
{
return SendAsync(
new GraphQLWSError
Expand All @@ -190,17 +187,18 @@ public Task SendErrorAsync(Guid id, IEnumerable<GraphQLError> errors)
);
}

public void CompleteSubscription(Guid id)
public Task CompleteSubscriptionAsync(string id)
{
subscriptions.TryGetValue(id, out var subscription);
if (subscription != null)
{
subscription.Dispose();
subscriptions.Remove(id);
}
return Task.CompletedTask;
}

public async Task SendNextAsync(Guid id, QueryResult result)
public async Task SendNextAsync(string id, QueryResult result)
{
await SendAsync(
new GraphQLWSResponse
Expand All @@ -214,7 +212,7 @@ await SendAsync(

private async Task SendSimpleResponseAsync(string type)
{
await SendAsync(new TypeOnlyGraphQLWSResponse { Type = type });
await SendAsync(new BaseGraphQLWSResponse { Type = type });
}

private Task SendAsync(object graphQLWSMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace EntityGraphQL.AspNet.WebSockets;
public interface IGraphQLWebSocketServer
{
public HttpContext Context { get; }
void CompleteSubscription(Guid id);
Task SendErrorAsync(Guid id, Exception exception);
Task SendNextAsync(Guid id, QueryResult result);
Task CompleteSubscriptionAsync(string id);
Task SendErrorAsync(string id, Exception exception);
Task SendNextAsync(string id, QueryResult result);
}
15 changes: 9 additions & 6 deletions src/EntityGraphQL.AspNet/WebSockets/WebSocketSubscription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@ namespace EntityGraphQL.AspNet.WebSockets;
/// <typeparam name="TEventType"></typeparam>
public sealed class WebSocketSubscription<TQueryContext, TEventType> : IDisposable, IObserver<TEventType>
{
private readonly Guid id;
/// <summary>
/// unique-operation-id from the protocol
/// </summary>
public string OperationId { get; }
private readonly IObservable<TEventType> observable;
private readonly IGraphQLWebSocketServer server;
private readonly IDisposable subscription;
private readonly GraphQLSubscriptionStatement subscriptionStatement;
private readonly GraphQLSubscriptionField subscriptionNode;

public WebSocketSubscription(Guid id, object observable, IGraphQLWebSocketServer server, GraphQLSubscriptionStatement subscriptionStatement, GraphQLSubscriptionField node)
public WebSocketSubscription(string id, object observable, IGraphQLWebSocketServer server, GraphQLSubscriptionStatement subscriptionStatement, GraphQLSubscriptionField node)
{
this.id = id;
this.OperationId = id;
if (observable is not IObservable<TEventType>)
throw new ArgumentException($"{nameof(observable)} must be of type {nameof(IObservable<TEventType>)}");
this.observable = (IObservable<TEventType>)observable;
Expand All @@ -37,7 +40,7 @@ public void OnNext(TEventType value)
var data = subscriptionStatement.ExecuteSubscriptionEvent<TQueryContext, TEventType>(subscriptionNode, value, server.Context.RequestServices);
var result = new QueryResult();
result.SetData(new Dictionary<string, object?> { { subscriptionNode.Name, data } });
server.SendNextAsync(id, result).GetAwaiter().GetResult();
server.SendNextAsync(OperationId, result).GetAwaiter().GetResult();
}
catch (Exception ex)
{
Expand All @@ -47,12 +50,12 @@ public void OnNext(TEventType value)

public void OnError(Exception error)
{
server.SendErrorAsync(id, error).GetAwaiter().GetResult();
server.SendErrorAsync(OperationId, error).GetAwaiter().GetResult();
}

public void OnCompleted()
{
server.CompleteSubscription(id);
server.CompleteSubscriptionAsync(OperationId);
}

public void Dispose()
Expand Down
Loading

0 comments on commit 5c1f88f

Please sign in to comment.