diff --git a/SignalR.sln b/SignalR.sln index b3da7f86f0..ea2381c137 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,12 +1,17 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26510.0 +VisualStudioVersion = 15.0.26526.1 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{83B2C3EB-A3D8-4E6F-9A3C-A380B005EF31}" ProjectSection(SolutionItems) = preProject + build\common.props = build\common.props + build\dependencies.props = build\dependencies.props + build\Key.snk = build\Key.snk NuGet.config = NuGet.config + build\repo.props = build\repo.props + build\repo.targets = build\repo.targets EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "samples", "samples", "{C4BC9889-B49F-41B6-806B-F84941B2549B}" @@ -64,6 +69,8 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common", "Common", "{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B}" ProjectSection(SolutionItems) = preProject test\Common\ArrayOutput.cs = test\Common\ArrayOutput.cs + test\Common\ByteArrayExtensions.cs = test\Common\ByteArrayExtensions.cs + test\Common\ChannelExtensions.cs = test\Common\ChannelExtensions.cs test\Common\TaskExtensions.cs = test\Common\TaskExtensions.cs EndProjectSection EndProject diff --git a/build/dependencies.props b/build/dependencies.props index d6001912c3..12125af5be 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -13,6 +13,7 @@ 2.0.0-* 15.3.0-* 2.3.0-beta2-* + 3.1.1 diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Program.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Program.cs index 58f2bbd7d3..00aa867bd1 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Program.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Program.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -16,6 +16,7 @@ public class Program public static void Main(string[] args) { var host = new WebHostBuilder() + .UseSetting(WebHostDefaults.PreventHostingStartupKey, "true") // Work around https://github.com/aspnet/Hosting/issues/1075 .ConfigureLogging(factory => { factory.AddConsole(); diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs index 6c30114542..cb5ca7f04d 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs @@ -4,7 +4,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR.Test.Server { diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/TestHub.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/TestHub.cs index c72de8ca9b..d4f6fb94b4 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/TestHub.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/TestHub.cs @@ -1,7 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Reactive.Disposables; +using System.Reactive.Linq; using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR.Test.Server @@ -22,5 +24,10 @@ public Task InvokeWithString(string message) { return Clients.Client(Context.Connection.ConnectionId).InvokeAsync("Message", message); } + + public IObservable Stream() + { + return new string[] { "a", "b", "c" }.ToObservable(); + } } } diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/connectionTests.html b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/connectionTests.html index fad27119c5..94f5545107 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/connectionTests.html +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/connectionTests.html @@ -1,4 +1,4 @@ - + @@ -16,4 +16,4 @@ - \ No newline at end of file + diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js index abc326f7dd..0d90a2d066 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js @@ -2,85 +2,145 @@ const TESTHUBENDPOINT_URL = `http://${document.location.host}/testhub`; describe('hubConnection', () => { eachTransport(transportType => { - it(`over ${signalR.TransportType[transportType]} can invoke server method and receive result`, done => { - const message = "Hi"; - let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); - hubConnection.onClosed = error => { - expect(error).toBe(undefined); - done(); - } + describe(`${signalR.TransportType[transportType]} transport`, () => { + it(`can invoke server method and receive result`, done => { + const message = "Hi"; + let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); + hubConnection.onClosed = error => { + expect(error).toBe(undefined); + done(); + } - hubConnection.start(transportType) - .then(() => { - hubConnection.invoke('Echo', message) - .then(result => { - expect(result).toBe(message); - }) - .catch(() => { - fail(); - }) + hubConnection.start(transportType) .then(() => { - hubConnection.stop(); + hubConnection.invoke('Echo', message) + .then(result => { + expect(result).toBe(message); + }) + .catch(e => { + fail(e); + }) + .then(() => { + hubConnection.stop(); + }) }) - }) - .catch(() => { - fail(); + .catch(e => { + fail(e); + done(); + }); + }); + + it(`can stream server method and receive result`, done => { + let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); + hubConnection.onClosed = error => { + expect(error).toBe(undefined); done(); - }); - }); + } + + let received = []; + hubConnection.start(transportType) + .then(() => { + hubConnection.stream('Stream') + .subscribe({ + next: (item) => { + received.push(item); + }, + error: (err) => { + fail(err); + done(); + }, + complete: () => { + expect(received).toEqual(["a", "b", "c"]); + done(); + } + }); + }) + .catch(e => { + fail(e); + done(); + }); + }); - it(`over ${signalR.TransportType[transportType]} rethrows an exception from the server`, done => { - const errorMessage = "An error occurred."; - let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); + it(`rethrows an exception from the server when invoking`, done => { + const errorMessage = "An error occurred."; + let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); - hubConnection.start(transportType) - .then(() => { - hubConnection.invoke('ThrowException', errorMessage) - .then(() => { - // exception expected but none thrown - fail(); - }) - .catch(e => { - expect(e.message).toBe(errorMessage); - }) - .then(() => { - return hubConnection.stop(); - }) - .then(() => { - done(); - }); - }) - .catch(() => { - fail(); - done(); - }); - }); + hubConnection.start(transportType) + .then(() => { + hubConnection.invoke('ThrowException', errorMessage) + .then(() => { + // exception expected but none thrown + fail(); + }) + .catch(e => { + expect(e.message).toBe(errorMessage); + }) + .then(() => { + return hubConnection.stop(); + }) + .then(() => { + done(); + }); + }) + .catch(e => { + fail(e); + done(); + }); + }); - it(`over ${signalR.TransportType[transportType]} can receive server calls`, done => { - let client = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); - const message = "Hello SignalR"; + it(`rethrows an exception from the server when streaming`, done => { + const errorMessage = "An error occurred."; + let hubConnection = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); - let callbackPromise = new Promise((resolve, reject) => { - client.on("Message", msg => { - expect(msg).toBe(message); - resolve(); - }); + hubConnection.start(transportType) + .then(() => { + hubConnection.stream('ThrowException', errorMessage) + .subscribe({ + next: (item) => { + fail(); + }, + error: (err) => { + expect(err.message).toEqual("An error occurred."); + done(); + }, + complete: () => { + fail(); + } + }); + + }) + .catch(e => { + fail(e); + done(); + }); }); - client.start(transportType) - .then(() => { - return Promise.all([client.invoke('InvokeWithString', message), callbackPromise]); - }) - .then(() => { - return stop(); - }) - .then(() => { - done(); - }) - .catch(e => { - fail(); - done(); + it(`can receive server calls`, done => { + let client = new signalR.HubConnection(TESTHUBENDPOINT_URL, 'formatType=json&format=text'); + const message = "Hello SignalR"; + + let callbackPromise = new Promise((resolve, reject) => { + client.on("Message", msg => { + expect(msg).toBe(message); + resolve(); + }); }); + + client.start(transportType) + .then(() => { + return Promise.all([client.invoke('InvokeWithString', message), callbackPromise]); + }) + .then(() => { + return stop(); + }) + .then(() => { + done(); + }) + .catch(e => { + fail(e); + done(); + }); + }); }); }); -}); \ No newline at end of file +}); diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/webSocketTests.js b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/webSocketTests.js index 9a0fb49132..6c437ae49f 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/webSocketTests.js +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/webSocketTests.js @@ -1,18 +1,13 @@ -describe('WebSockets', function () { +describe('WebSockets', function () { it('can be used to connect to SignalR', done => { const message = "message"; - let webSocket = new WebSocket(ECHOENDPOINT_URL.replace(/^http/, "ws") + '/ws'); + let webSocket = new WebSocket(ECHOENDPOINT_URL.replace(/^http/, "ws")); webSocket.onopen = () => { webSocket.send(message); }; - webSocket.onerror = event => { - fail(); - done(); - }; - var received = ""; webSocket.onmessage = event => { received += event.data; @@ -22,8 +17,14 @@ }; webSocket.onclose = event => { + if (!event.wasClean) { + fail("connection closed with unexpected status code: " + event.code + " " + event.reason); + } + + // Jasmine doesn't like tests without expectations expect(event.wasClean).toBe(true); + done(); }; }); -}); \ No newline at end of file +}); diff --git a/samples/SocketsSample/Hubs/Streaming.cs b/samples/SocketsSample/Hubs/Streaming.cs new file mode 100644 index 0000000000..e8f26ef0df --- /dev/null +++ b/samples/SocketsSample/Hubs/Streaming.cs @@ -0,0 +1,79 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR; + +namespace SocketsSample.Hubs +{ + public class Streaming : Hub + { + public IObservable ObservableCounter(int count, int delay) + { + return new CounterObservable(count, delay); + } + + public ReadableChannel ChannelCounter(int count, int delay) + { + var channel = Channel.CreateUnbounded(); + + Task.Run(async () => + { + for (var i = 0; i < count; i++) + { + await channel.Out.WriteAsync(i); + await Task.Delay(delay); + } + + channel.Out.TryComplete(); + }); + + return channel.In; + } + + private class CounterObservable : IObservable + { + private int _count; + private int _delay; + + public CounterObservable(int count, int delay) + { + _count = count; + _delay = delay; + } + + public IDisposable Subscribe(IObserver observer) + { + // Run in a thread-pool thread + var cts = new CancellationTokenSource(); + Task.Run(async () => + { + for (var i = 0; !cts.Token.IsCancellationRequested && i < _count; i++) + { + observer.OnNext(i); + await Task.Delay(_delay); + } + observer.OnCompleted(); + }); + + return new Disposable(() => cts.Cancel()); + } + } + + private class Disposable : IDisposable + { + private Action _action; + + public Disposable(Action action) + { + _action = action; + } + + public void Dispose() + { + _action(); + } + } + } +} diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 1acae62ef4..3578b22557 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Builder; @@ -49,6 +49,7 @@ public void Configure(IApplicationBuilder app, IHostingEnvironment env) app.UseSignalR(routes => { routes.MapHub("hubs"); + routes.MapHub("streaming"); }); app.UseSockets(routes => diff --git a/samples/SocketsSample/wwwroot/hubs.html b/samples/SocketsSample/wwwroot/hubs.html index cb2331db55..2533f75dcc 100644 --- a/samples/SocketsSample/wwwroot/hubs.html +++ b/samples/SocketsSample/wwwroot/hubs.html @@ -1,4 +1,4 @@ - + @@ -55,13 +55,6 @@

Private Message

\ No newline at end of file + diff --git a/samples/SocketsSample/wwwroot/index.html b/samples/SocketsSample/wwwroot/index.html index f8404615c5..5f0e8c06b2 100644 --- a/samples/SocketsSample/wwwroot/index.html +++ b/samples/SocketsSample/wwwroot/index.html @@ -1,22 +1,28 @@ - + -

ASP.NET Sockets

+

ASP.NET Core Sockets

-

ASP.NET SignalR (Hubs)

+

ASP.NET Core SignalR (Hubs)

+

ASP.NET Core SignalR (Streaming)

+ diff --git a/samples/SocketsSample/wwwroot/streaming.html b/samples/SocketsSample/wwwroot/streaming.html new file mode 100644 index 0000000000..348b11e796 --- /dev/null +++ b/samples/SocketsSample/wwwroot/streaming.html @@ -0,0 +1,105 @@ + + + + + + + +

Unknown Transport

+ +

Controls

+
+ + + +
+ +
+ + +
+ +

Results

+
    + +
      + + + + + diff --git a/samples/SocketsSample/wwwroot/utils.js b/samples/SocketsSample/wwwroot/utils.js index ef155b5a2d..8bf6b7c2b3 100644 --- a/samples/SocketsSample/wwwroot/utils.js +++ b/samples/SocketsSample/wwwroot/utils.js @@ -10,3 +10,19 @@ function getParameterByName(name, url) { return decodeURIComponent(results[2].replace(/\+/g, " ")); } +function click(id, callback) { + document.getElementById(id).addEventListener('click', event => { + callback(event); + event.preventDefault(); + }); +} + +function addLine(listId, line, color) { + var child = document.createElement('li'); + if (color) { + child.style.color = color; + } + child.innerText = line; + document.getElementById(listId).appendChild(child); +} + diff --git a/src/Microsoft.AspNetCore.SignalR.Client/CastObservable.cs b/src/Microsoft.AspNetCore.SignalR.Client/CastObservable.cs new file mode 100644 index 0000000000..629cdf24bd --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Client/CastObservable.cs @@ -0,0 +1,54 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Client +{ + internal class CastObservable : IObservable + { + private IObservable _innerObservable; + + public CastObservable(IObservable innerObservable) + { + _innerObservable = innerObservable; + } + + public IDisposable Subscribe(IObserver observer) + { + return _innerObservable.Subscribe(new CastObserver(observer)); + } + + private class CastObserver : IObserver + { + private IObserver _innerObserver; + + public CastObserver(IObserver innerObserver) + { + _innerObserver = innerObserver; + } + + public void OnCompleted() + { + _innerObserver.OnCompleted(); + } + + public void OnError(Exception error) + { + _innerObserver.OnError(error); + } + + public void OnNext(object value) + { + try + { + _innerObserver.OnNext((TResult)value); + } + catch(Exception ex) + { + _innerObserver.OnError(ex); + } + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 490b3b28a1..c6f038e887 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -8,6 +8,7 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -115,17 +116,30 @@ public void On(string methodName, Type[] parameterTypes, Action handle _handlers.AddOrUpdate(methodName, invocationHandler, (_, __) => invocationHandler); } - public async Task Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) + public ReadableChannel Stream(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) + { + var irq = InvocationRequest.Stream(cancellationToken, returnType, GetNextId(), _loggerFactory, out var channel); + InvokeCore(methodName, irq, args); + return channel; + } + + public Task Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) + { + var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task); + InvokeCore(methodName, irq, args); + return task; + } + + private void InvokeCore(string methodName, InvocationRequest irq, object[] args) { ThrowIfConnectionTerminated(); - _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, returnType.AssemblyQualifiedName, args.Length); + _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, irq.ResultType.AssemblyQualifiedName, args.Length); // Create an invocation descriptor. Client invocations are always blocking - var invocationMessage = new InvocationMessage(GetNextId(), nonBlocking: false, target: methodName, arguments: args); + var invocationMessage = new InvocationMessage(irq.InvocationId, nonBlocking: false, target: methodName, arguments: args); // I just want an excuse to use 'irq' as a variable name... _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId); - var irq = new InvocationRequest(cancellationToken, returnType, invocationMessage.InvocationId, _loggerFactory); AddInvocation(irq); @@ -133,16 +147,22 @@ public async Task Invoke(string methodName, Type returnType, Cancellatio if (_logger.IsEnabled(LogLevel.Trace)) { var argsList = string.Join(", ", args.Select(a => a.GetType().FullName)); - _logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, returnType.FullName, methodName, argsList); + _logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, irq.ResultType.FullName, methodName, argsList); } + // We don't need to wait for this to complete. It will signal back to the invocation request. + _ = SendInvocation(invocationMessage, irq); + } + + private async Task SendInvocation(InvocationMessage invocationMessage, InvocationRequest irq) + { try { var payload = await _protocol.WriteToArrayAsync(invocationMessage); _logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId); - await _connection.SendAsync(payload, _protocol.MessageType, cancellationToken); + await _connection.SendAsync(payload, _protocol.MessageType, irq.CancellationToken); _logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId); } catch (Exception ex) @@ -151,9 +171,6 @@ public async Task Invoke(string methodName, Type returnType, Cancellatio irq.Fail(ex); TryRemoveInvocation(invocationMessage.InvocationId, out _); } - - // Return the completion task. It will be completed by ReceiveMessages when the response is received. - return await irq.Completion; } private void OnDataReceived(byte[] data, MessageType messageType) @@ -182,12 +199,12 @@ private void OnDataReceived(byte[] data, MessageType messageType) break; case StreamItemMessage streamItem: // Complete the invocation with an error, we don't support streaming (yet) - if (!TryRemoveInvocation(streamItem.InvocationId, out irq)) + if (!TryGetInvocation(streamItem.InvocationId, out irq)) { _logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId); return; } - irq.Fail(new NotSupportedException("Streaming method results are not supported")); + DispatchInvocationStreamItemAsync(streamItem, irq); break; default: throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}"); @@ -236,6 +253,22 @@ private void DispatchInvocation(InvocationMessage invocation, CancellationToken handler.Handler(invocation.Arguments); } + // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel + // and there's nobody to actually wait for us to finish. + private async void DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq) + { + _logger.LogTrace("Received StreamItem for Invocation #{invocationId}", streamItem.InvocationId); + + if (irq.CancellationToken.IsCancellationRequested) + { + _logger.LogTrace("Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId); + } + else if (!await irq.StreamItem(streamItem.Item)) + { + _logger.LogWarning("Invocation {invocationId} received stream item after channel was closed.", irq.InvocationId); + } + } + private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq) { _logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId); @@ -352,53 +385,5 @@ public InvocationHandler(Type[] parameterTypes, Action handler) ParameterTypes = parameterTypes; } } - - private class InvocationRequest : IDisposable - { - private readonly TaskCompletionSource _completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly CancellationTokenRegistration _cancellationTokenRegistration; - private readonly ILogger _logger; - - public Type ResultType { get; } - public CancellationToken CancellationToken { get; } - public string InvocationId { get; } - - public Task Completion => _completionSource.Task; - - - public InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) - { - _logger = loggerFactory.CreateLogger(); - _cancellationTokenRegistration = cancellationToken.Register(() => _completionSource.TrySetCanceled()); - - InvocationId = invocationId; - CancellationToken = cancellationToken; - ResultType = resultType; - - _logger.LogTrace("Invocation {invocationId} created", InvocationId); - } - - public void Fail(Exception exception) - { - _logger.LogTrace("Invocation {invocationId} marked as failed", InvocationId); - _completionSource.TrySetException(exception); - } - - public void Complete(object result) - { - _logger.LogTrace("Invocation {invocationId} marked as completed", InvocationId); - _completionSource.TrySetResult(result); - } - - public void Dispose() - { - _logger.LogTrace("Invocation {invocationId} disposed", InvocationId); - - // Just in case it hasn't already been completed - _completionSource.TrySetCanceled(); - - _cancellationTokenRegistration.Dispose(); - } - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs index b892fc7488..c26b4db924 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs @@ -1,9 +1,10 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; namespace Microsoft.AspNetCore.SignalR.Client { @@ -35,6 +36,59 @@ public async static Task Invoke(this HubConnection hubConnecti return (TResult)await hubConnection.Invoke(methodName, typeof(TResult), cancellationToken, args); } + public static ReadableChannel Stream(this HubConnection hubConnection, string methodName, params object[] args) => + Stream(hubConnection, methodName, CancellationToken.None, args); + + public static ReadableChannel Stream(this HubConnection hubConnection, string methodName, CancellationToken cancellationToken, params object[] args) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + var inputChannel = hubConnection.Stream(methodName, typeof(TResult), cancellationToken, args); + var outputChannel = Channel.CreateUnbounded(); + + // Local function to provide a way to run async code as fire-and-forget + // The output channel is how we signal completion to the caller. + async Task RunChannel() + { + try + { + while (await inputChannel.WaitToReadAsync()) + { + while (inputChannel.TryRead(out var item)) + { + while (!outputChannel.Out.TryWrite((TResult)item)) + { + if (!await outputChannel.Out.WaitToWriteAsync()) + { + // Failed to write to the output channel because it was closed. Nothing really we can do but abort here. + return; + } + } + } + } + + // Manifest any errors in the completion task + await inputChannel.Completion; + } + catch (Exception ex) + { + outputChannel.Out.TryComplete(ex); + } + finally + { + // This will safely no-op if the catch block above ran. + outputChannel.Out.TryComplete(); + } + } + + _ = RunChannel(); + + return outputChannel.In; + } + public static void On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) diff --git a/src/Microsoft.AspNetCore.SignalR.Client/InvocationRequest.cs b/src/Microsoft.AspNetCore.SignalR.Client/InvocationRequest.cs new file mode 100644 index 0000000000..e40dd8feda --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Client/InvocationRequest.cs @@ -0,0 +1,160 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.SignalR.Client +{ + internal abstract class InvocationRequest : IDisposable + { + private readonly CancellationTokenRegistration _cancellationTokenRegistration; + + protected ILogger Logger { get; } + + public Type ResultType { get; } + public CancellationToken CancellationToken { get; } + public string InvocationId { get; } + + protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger) + { + _cancellationTokenRegistration = cancellationToken.Register(self => ((InvocationRequest)self).Cancel(), this); + + InvocationId = invocationId; + CancellationToken = cancellationToken; + ResultType = resultType; + Logger = logger; + + Logger.LogTrace("Invocation {invocationId} created", InvocationId); + } + + public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, out Task result) + { + var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory); + result = req.Result; + return req; + } + + + public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, out ReadableChannel result) + { + var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory); + result = req.Result; + return req; + } + + public abstract void Fail(Exception exception); + public abstract void Complete(object result); + public abstract ValueTask StreamItem(object item); + + protected abstract void Cancel(); + + public virtual void Dispose() + { + Logger.LogTrace("Invocation {invocationId} disposed", InvocationId); + + // Just in case it hasn't already been completed + Cancel(); + + _cancellationTokenRegistration.Dispose(); + } + + private class Streaming : InvocationRequest + { + private readonly Channel _channel = Channel.CreateUnbounded(); + + public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger()) + { + } + + public ReadableChannel Result => _channel.In; + + public override void Complete(object result) + { + Logger.LogTrace("Invocation {invocationId} marked as completed.", InvocationId); + if (result != null) + { + Logger.LogError("Invocation {invocationId} received a completion result, but was invoked as a streaming invocation.", InvocationId); + _channel.Out.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation.")); + } + else + { + _channel.Out.TryComplete(); + } + } + + public override void Fail(Exception exception) + { + Logger.LogTrace("Invocation {invocationId} marked as failed.", InvocationId); + _channel.Out.TryComplete(exception); + } + + public override async ValueTask StreamItem(object item) + { + try + { + Logger.LogTrace("Invocation {invocationId} received stream item.", InvocationId); + while (!_channel.Out.TryWrite(item)) + { + if (!await _channel.Out.WaitToWriteAsync()) + { + return false; + } + } + } + catch (Exception ex) + { + Logger.LogError(ex, "Invocation {invocationId} caused an error trying to write a stream item.", InvocationId); + } + return true; + } + + protected override void Cancel() + { + _channel.Out.TryComplete(new OperationCanceledException("Connection terminated")); + } + } + + private class NonStreaming : InvocationRequest + { + private readonly TaskCompletionSource _completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger()) + { + } + + public Task Result => _completionSource.Task; + + public override void Complete(object result) + { + Logger.LogTrace("Invocation {invocationId} marked as completed.", InvocationId); + _completionSource.TrySetResult(result); + } + + public override void Fail(Exception exception) + { + Logger.LogTrace("Invocation {invocationId} marked as failed.", InvocationId); + _completionSource.TrySetException(exception); + } + + public override ValueTask StreamItem(object item) + { + Logger.LogError("Invocation {invocationId} received stream item but was invoked as a non-streamed invocation.", InvocationId); + _completionSource.TrySetException(new InvalidOperationException("Streaming methods must be invoked using HubConnection.Stream")); + + // We "delivered" the stream item successfully as far as the caller cares + return new ValueTask(true); + } + + protected override void Cancel() + { + _completionSource.TrySetCanceled(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs index 44b7206121..78c39c9a9f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs @@ -34,5 +34,7 @@ public override string ToString() public static CompletionMessage WithError(string invocationId, string error) => new CompletionMessage(invocationId, error, result: null, hasResult: false); public static CompletionMessage WithResult(string invocationId, object payload) => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true); + + public static CompletionMessage Empty(string invocationId) => new CompletionMessage(invocationId, error: null, result: null, hasResult: false); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index 41e4db3d7a..41ba3505ef 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -13,6 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public class JsonHubProtocol : IHubProtocol { private const string ResultPropertyName = "result"; + private const string ItemPropertyName = "item"; private const string InvocationIdPropertyName = "invocationId"; private const string TypePropertyName = "type"; private const string ErrorPropertyName = "error"; @@ -117,7 +118,7 @@ private void WriteMessage(HubMessage message, Stream stream) WriteInvocationMessage(m, writer); break; case StreamItemMessage m: - WriteResultMessage(m, writer); + WriteStreamItemMessage(m, writer); break; case CompletionMessage m: WriteCompletionMessage(m, writer); @@ -145,11 +146,11 @@ private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter wr writer.WriteEndObject(); } - private void WriteResultMessage(StreamItemMessage message, JsonTextWriter writer) + private void WriteStreamItemMessage(StreamItemMessage message, JsonTextWriter writer) { writer.WriteStartObject(); WriteHubMessageCommon(message, writer, ResultMessageType); - writer.WritePropertyName(ResultPropertyName); + writer.WritePropertyName(ItemPropertyName); _payloadSerializer.Serialize(writer, message.Item); writer.WriteEndObject(); } @@ -216,7 +217,7 @@ private InvocationMessage BindInvocationMessage(JObject json, IInvocationBinder private StreamItemMessage BindResultMessage(JObject json, IInvocationBinder binder) { var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); - var result = GetRequiredProperty(json, ResultPropertyName); + var result = GetRequiredProperty(json, ItemPropertyName); var returnType = binder.GetReturnType(invocationId); return new StreamItemMessage(invocationId, result?.ToObject(returnType, _payloadSerializer)); diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index f16170ff78..3e18413702 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -3,10 +3,12 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -220,8 +222,7 @@ private async Task Execute(ConnectionContext connection, IHubProtocol protocol, } else { - var result = await Invoke(descriptor, connection, invocationMessage); - await SendMessageAsync(connection, protocol, result); + await Invoke(descriptor, connection, protocol, invocationMessage); } } @@ -243,7 +244,7 @@ private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol p throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); } - private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage) + private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) { var methodExecutor = descriptor.MethodExecutor; @@ -257,7 +258,9 @@ private async Task Invoke(HubMethodDescriptor descriptor, Con InitializeHub(hub, connection); object result = null; - if (methodExecutor.IsMethodAsync) + + // ReadableChannel is awaitable but we don't want to await it. + if (methodExecutor.IsMethodAsync && !IsChannel(methodExecutor.MethodReturnType, out _)) { if (methodExecutor.MethodReturnType == typeof(Task)) { @@ -273,17 +276,26 @@ private async Task Invoke(HubMethodDescriptor descriptor, Con result = methodExecutor.Execute(hub, invocationMessage.Arguments); } - return CompletionMessage.WithResult(invocationMessage.InvocationId, result); + if (IsStreamed(methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) + { + _logger.LogTrace("[{connectionId}/{invocationId}] Streaming result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); + await StreamResultsAsync(invocationMessage.InvocationId, connection, protocol, enumerator); + } + else + { + _logger.LogTrace("[{connectionId}/{invocationId}] Sending result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); + await SendMessageAsync(connection, protocol, CompletionMessage.WithResult(invocationMessage.InvocationId, result)); + } } catch (TargetInvocationException ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message); + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message)); } catch (Exception ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message); + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message)); } finally { @@ -299,6 +311,74 @@ private void InitializeHub(THub hub, ConnectionContext connection) hub.Groups = new GroupManager(connection, _lifetimeManager); } + private bool IsChannel(Type type, out Type payloadType) + { + var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ReadableChannel<>)); + if (channelType == null) + { + payloadType = null; + return false; + } + else + { + payloadType = channelType.GetGenericArguments()[0]; + return true; + } + } + + private async Task StreamResultsAsync(string invocationId, ConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator enumerator) + { + // TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481 + try + { + while (await enumerator.MoveNextAsync()) + { + // Send the stream item + await SendMessageAsync(connection, protocol, new StreamItemMessage(invocationId, enumerator.Current)); + } + + await SendMessageAsync(connection, protocol, CompletionMessage.Empty(invocationId)); + } + catch (Exception ex) + { + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationId, ex.Message)); + } + } + + private bool IsStreamed(ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) + { + if (result == null) + { + enumerator = null; + return false; + } + + var observableInterface = IsIObservable(resultType) ? + resultType : + resultType.GetInterfaces().FirstOrDefault(IsIObservable); + if (observableInterface != null) + { + enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface); + return true; + } + else if (IsChannel(resultType, out var payloadType)) + { + enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType); + return true; + } + else + { + // Not streamed + enumerator = null; + return false; + } + } + + private static bool IsIObservable(Type iface) + { + return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>); + } + private void DiscoverHubMethods() { var hubType = typeof(THub); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR/Internal/AsyncEnumeratorAdapters.cs new file mode 100644 index 0000000000..58411afedc --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Internal/AsyncEnumeratorAdapters.cs @@ -0,0 +1,124 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + // True-internal because this is a weird and tricky class to use :) + internal static class AsyncEnumeratorAdapters + { + private static readonly MethodInfo _boxEnumeratorMethod = typeof(AsyncEnumeratorAdapters) + .GetRuntimeMethods() + .Single(m => m.Name.Equals(nameof(BoxEnumerator)) && m.IsGenericMethod); + + private static readonly MethodInfo _fromObservableMethod = typeof(AsyncEnumeratorAdapters) + .GetRuntimeMethods() + .Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod); + + private static readonly object[] _getAsyncEnumeratorArgs = new object[] { CancellationToken.None }; + + public static IAsyncEnumerator FromObservable(object observable, Type observableInterface) + { + // TODO: Cache expressions by observable.GetType()? + return (IAsyncEnumerator)_fromObservableMethod + .MakeGenericMethod(observableInterface.GetGenericArguments()) + .Invoke(null, new[] { observable }); + } + + public static IAsyncEnumerator FromObservable(IObservable observable) + { + // TODO: Allow bounding and optimizations? + var channel = Channel.CreateUnbounded(); + + var subscription = observable.Subscribe(new ChannelObserver(channel.Out, CancellationToken.None)); + + return channel.In.GetAsyncEnumerator(); + } + + public static IAsyncEnumerator FromChannel(object readableChannelOfT, Type payloadType) + { + var enumerator = readableChannelOfT + .GetType() + .GetRuntimeMethod("GetAsyncEnumerator", new[] { typeof(CancellationToken) }) + .Invoke(readableChannelOfT, _getAsyncEnumeratorArgs); + + if (payloadType.IsValueType) + { + return (IAsyncEnumerator)_boxEnumeratorMethod + .MakeGenericMethod(payloadType) + .Invoke(null, new[] { enumerator }); + } + else + { + return (IAsyncEnumerator)enumerator; + } + } + + private static IAsyncEnumerator BoxEnumerator(IAsyncEnumerator input) where T : struct + { + return new BoxingEnumerator(input); + } + + private class ChannelObserver : IObserver + { + private WritableChannel _output; + private CancellationToken _cancellationToken; + + public ChannelObserver(WritableChannel output, CancellationToken cancellationToken) + { + _output = output; + _cancellationToken = cancellationToken; + } + + public void OnCompleted() + { + _output.TryComplete(); + } + + public void OnError(Exception error) + { + _output.TryComplete(error); + } + + public void OnNext(T value) + { + _cancellationToken.ThrowIfCancellationRequested(); + + // This will block the thread emitting the object if the channel is bounded and full + // I think this is OK, since we want to push the backpressure up. However, we may need + // to find a way to force the entire subscription off to a dedicated thread in order to + // ensure we don't block other tasks + + // Right now however, we use unbounded channels, so all of the above is moot because TryWrite will always succeed + while (!_output.TryWrite(value)) + { + // Wait for a spot + if (!_output.WaitToWriteAsync(_cancellationToken).Result) + { + // Channel was closed. + throw new InvalidOperationException("Output channel was closed"); + } + } + } + } + + private class BoxingEnumerator : IAsyncEnumerator where T : struct + { + private IAsyncEnumerator _input; + + public BoxingEnumerator(IAsyncEnumerator input) + { + _input = input; + } + + public object Current => _input.Current; + public Task MoveNextAsync() => _input.MoveNextAsync(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/TypeBaseEnumerationExtensions.cs b/src/Microsoft.AspNetCore.SignalR/Internal/TypeBaseEnumerationExtensions.cs new file mode 100644 index 0000000000..70102b8624 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Internal/TypeBaseEnumerationExtensions.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public static class TypeBaseEnumerationExtensions + { + public static IEnumerable AllBaseTypes(this Type type) + { + var current = type; + while (current != null) + { + yield return current; + current = current.BaseType; + } + } + } +} diff --git a/test/Common/ChannelExtensions.cs b/test/Common/ChannelExtensions.cs new file mode 100644 index 0000000000..2502886317 --- /dev/null +++ b/test/Common/ChannelExtensions.cs @@ -0,0 +1,24 @@ +using System.Collections.Generic; + +namespace System.Threading.Tasks.Channels +{ + internal static class ChannelExtensions + { + public static async Task> ReadAllAsync(this ReadableChannel channel) + { + var list = new List(); + while (await channel.WaitToReadAsync()) + { + while (channel.TryRead(out var item)) + { + list.Add(item); + } + } + + // Manifest any error from channel.Completion (which should be completed now) + await channel.Completion; + + return list; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 35316d66ad..3fdd84429d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -1,8 +1,10 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Reactive.Linq; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.SignalR.Tests.Common; @@ -50,12 +52,12 @@ public async Task CheckFixedMessage() using (var httpClient = _testServer.CreateClient()) { - var connection = new HubConnection(new Uri("http://test/hubs")); + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); try { await connection.StartAsync(TransportType.LongPolling, httpClient); - var result = await connection.Invoke("HelloWorld"); + var result = await connection.Invoke(nameof(TestHub.HelloWorld)); Assert.Equal("Hello World!", result); } @@ -74,12 +76,12 @@ public async Task CanSendAndReceiveMessage() using (var httpClient = _testServer.CreateClient()) { - var connection = new HubConnection(new Uri("http://test/hubs")); + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); try { await connection.StartAsync(TransportType.LongPolling, httpClient); - var result = await connection.Invoke("Echo", originalMessage); + var result = await connection.Invoke(nameof(TestHub.Echo), originalMessage); Assert.Equal(originalMessage, result); } @@ -103,7 +105,7 @@ public async Task MethodsAreCaseInsensitive() { await connection.StartAsync(TransportType.LongPolling, httpClient); - var result = await connection.Invoke("echo", originalMessage); + var result = await connection.Invoke(nameof(TestHub.Echo).ToLowerInvariant(), originalMessage); Assert.Equal(originalMessage, result); } @@ -122,7 +124,7 @@ public async Task CanInvokeClientMethodFromServer() using (var httpClient = _testServer.CreateClient()) { - var connection = new HubConnection(new Uri("http://test/hubs")); + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); try { await connection.StartAsync(TransportType.LongPolling, httpClient); @@ -130,7 +132,7 @@ public async Task CanInvokeClientMethodFromServer() var tcs = new TaskCompletionSource(); connection.On("Echo", tcs.SetResult); - await connection.Invoke("CallEcho", originalMessage).OrTimeout(); + await connection.Invoke(nameof(TestHub.CallEcho), originalMessage).OrTimeout(); Assert.Equal(originalMessage, await tcs.Task.OrTimeout()); } @@ -141,6 +143,31 @@ public async Task CanInvokeClientMethodFromServer() } } + [Fact] + public async Task CanStreamClientMethodFromServer() + { + var loggerFactory = CreateLogger(); + + using (var httpClient = _testServer.CreateClient()) + { + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); + try + { + await connection.StartAsync(TransportType.LongPolling, httpClient); + + var tcs = new TaskCompletionSource(); + + var results = await connection.Stream(nameof(TestHub.Stream)).ReadAllAsync().OrTimeout(); + + Assert.Equal(new[] { "a", "b", "c" }, results.ToArray()); + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Fact] public async Task ServerClosesConnectionIfHubMethodCannotBeResolved() { @@ -148,7 +175,7 @@ public async Task ServerClosesConnectionIfHubMethodCannotBeResolved() using (var httpClient = _testServer.CreateClient()) { - var connection = new HubConnection(new Uri("http://test/hubs")); + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); try { await connection.StartAsync(TransportType.LongPolling, httpClient); @@ -196,6 +223,11 @@ public async Task CallEcho(string message) { await Clients.Client(Context.ConnectionId).InvokeAsync("Echo", message); } + + public IObservable Stream() + { + return new[] { "a", "b", "c" }.ToObservable(); + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj index a4a774b519..340afe9254 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj @@ -9,6 +9,7 @@ + @@ -25,6 +26,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs index 703b3b6c80..8a55e55fb9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -128,7 +128,18 @@ public async Task ConnectionClosedOnCallbackArgumentCountMismatch() var connection = new TestConnection(); var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); var closeTcs = new TaskCompletionSource(); - hubConnection.Closed += e => closeTcs.TrySetException(e); + hubConnection.Closed += e => + { + if (e == null) + { + closeTcs.TrySetResult(null); + } + else + { + closeTcs.TrySetException(e); + } + }; + try { hubConnection.On("Foo", r => { }); @@ -137,7 +148,7 @@ public async Task ConnectionClosedOnCallbackArgumentCountMismatch() await connection.ReceiveJsonMessage( new { - invocationId = "1", + invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 42, "42" } @@ -159,7 +170,18 @@ public async Task ConnectionClosedOnCallbackArgumentTypeMismatch() var connection = new TestConnection(); var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); var closeTcs = new TaskCompletionSource(); - hubConnection.Closed += e => closeTcs.TrySetException(e); + hubConnection.Closed += e => + { + if (e == null) + { + closeTcs.TrySetResult(null); + } + else + { + closeTcs.TrySetException(e); + } + }; + try { hubConnection.On("Foo", r => { }); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index 563d6a553e..7709603b2a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -3,6 +3,7 @@ using System; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.Logging; @@ -38,6 +39,32 @@ public async Task InvokeSendsAnInvocationMessage() } } + [Fact] + public async Task StreamSendsAnInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = hubConnection.Stream("Foo"); + + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + + // Complete the channel + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + await channel.Completion; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + [Fact] public async Task InvokeCompletedWhenCompletionMessageReceived() { @@ -60,6 +87,28 @@ public async Task InvokeCompletedWhenCompletionMessageReceived() } } + [Fact] + public async Task StreamCompletesWhenCompletionMessageIsReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = hubConnection.Stream("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + + Assert.Empty(await channel.ReadAllAsync()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + [Fact] public async Task InvokeYieldsResultWhenCompletionMessageReceived() { @@ -82,6 +131,29 @@ public async Task InvokeYieldsResultWhenCompletionMessageReceived() } } + [Fact] + public async Task StreamFailsIfCompletionMessageHasPayload() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = hubConnection.Stream("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); + Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + [Fact] public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived() { @@ -106,7 +178,29 @@ public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived() } [Fact] - // This will fail (intentionally) when we support streaming! + public async Task StreamFailsWithExceptionWhenCompletionWithErrorReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = hubConnection.Stream("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); + Assert.Equal("An error occurred", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] public async Task InvokeFailsWithErrorWhenStreamingItemReceived() { var connection = new TestConnection(); @@ -117,10 +211,37 @@ public async Task InvokeFailsWithErrorWhenStreamingItemReceived() var invokeTask = hubConnection.Invoke("Foo"); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, result = 42 }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = 42 }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("Streaming methods must be invoked using HubConnection.Stream", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamYieldsItemsAsTheyArrive() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = hubConnection.Stream("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "1" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "2" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "3" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + + var notifications = await channel.ReadAllAsync().OrTimeout(); - var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); - Assert.Equal("Streaming method results are not supported", ex.Message); + Assert.Equal(new[] { "1", "2", "3", }, notifications.ToArray()); } finally { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index 08845b7062..4edd643532 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -10,6 +10,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 5eaacddd65..b7aecf4eed 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -116,4 +116,4 @@ private async Task ReceiveLoopAsync(CancellationToken token) } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 6dc6fb4a0d..326c8a6851 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -24,15 +24,15 @@ public class JsonHubProtocolTests new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}]}" }, new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}]}" }, - new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":1}" }, - new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":\"Foo\"}" }, - new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":2.0}" }, - new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":true}" }, - new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":null}" }, - new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" }, - new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" }, - new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" }, - new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" }, + new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":1}" }, + new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":\"Foo\"}" }, + new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":2.0}" }, + new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":true}" }, + new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":null}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"item\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"item\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"item\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" }, new object[] { CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":1}" }, new object[] { CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":\"Foo\"}" }, @@ -96,7 +96,7 @@ public void ParseMessage(HubMessage expectedMessage, bool camelCase, NullValueHa [InlineData("{'type':2}", "Missing required property 'invocationId'.")] [InlineData("{'type':2,'invocationId':42}", "Expected 'invocationId' to be of type String.")] - [InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'result'.")] + [InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'item'.")] [InlineData("{'type':3}", "Missing required property 'invocationId'.")] [InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")] diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/CancellationDisposable.cs b/test/Microsoft.AspNetCore.SignalR.Tests/CancellationDisposable.cs new file mode 100644 index 0000000000..50a02883ca --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/CancellationDisposable.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + internal class CancellationDisposable : IDisposable + { + private CancellationTokenSource _cts; + + public CancellationDisposable(CancellationTokenSource cts) + { + _cts = cts; + } + + public void Dispose() + { + _cts.Cancel(); + _cts.Dispose(); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index bf8536a2e8..fc4d94c9b5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -2,7 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; @@ -21,7 +23,7 @@ public async Task HubsAreDisposed() var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose)); var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -51,7 +53,7 @@ public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnCon var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var exception = await Assert.ThrowsAsync( @@ -79,7 +81,7 @@ public async Task HubOnDisconnectedAsyncCalledIfHubOnConnectedAsyncThrows() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); client.Dispose(); @@ -103,7 +105,7 @@ public async Task LifetimeManagerOnDisconnectedAsyncCalledIfHubOnDisconnectedAsy var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); client.Dispose(); @@ -123,7 +125,7 @@ public async Task HubMethodCanReturnValueFromTask() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -146,7 +148,7 @@ public async Task HubMethodsAreCaseInsensitive() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -170,7 +172,7 @@ public async Task HubMethodCanThrowOrYieldFailedTask(string methodName) var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -192,7 +194,7 @@ public async Task HubMethodCanReturnValue() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -215,7 +217,7 @@ public async Task HubMethodCanBeVoid() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -237,7 +239,7 @@ public async Task HubMethodWithMultiParam() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -259,7 +261,7 @@ public async Task CanCallInheritedHubMethodFromInheritingHub() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -281,7 +283,7 @@ public async Task CanCallOverridenVirtualHubMethod() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -303,7 +305,7 @@ public async Task CannotCallOverriddenBaseHubMethod() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -341,8 +343,8 @@ public async Task BroadcastHubMethod_SendsToAllClients() var endPoint = serviceProvider.GetService>(); - using (var firstClient = new TestClient(serviceProvider)) - using (var secondClient = new TestClient(serviceProvider)) + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) { var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); @@ -376,8 +378,8 @@ public async Task HubsCanAddAndSendToGroup() var endPoint = serviceProvider.GetService>(); - using (var firstClient = new TestClient(serviceProvider)) - using (var secondClient = new TestClient(serviceProvider)) + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) { var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); @@ -418,7 +420,7 @@ public async Task RemoveFromGroupWhenNotInGroupDoesNotFail() var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient(serviceProvider)) + using (var client = new TestClient()) { var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -438,8 +440,8 @@ public async Task HubsCanSendToUser() var endPoint = serviceProvider.GetService>(); - using (var firstClient = new TestClient(serviceProvider)) - using (var secondClient = new TestClient(serviceProvider)) + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) { var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); @@ -470,8 +472,8 @@ public async Task HubsCanSendToConnection() var endPoint = serviceProvider.GetService>(); - using (var firstClient = new TestClient(serviceProvider)) - using (var secondClient = new TestClient(serviceProvider)) + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) { var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); @@ -495,6 +497,62 @@ public async Task HubsCanSendToConnection() } } + [Theory] + [InlineData(nameof(StreamingHub.CounterChannel))] + [InlineData(nameof(StreamingHub.CounterObservable))] + public async Task HubsCanStreamResponses(string method) + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); + + await client.Connected.OrTimeout(); + + var messages = await client.StreamAsync(method, 4).OrTimeout(); + + Assert.Equal(5, messages.Count); + AssertHubMessage(new StreamItemMessage(string.Empty, "0"), messages[0]); + AssertHubMessage(new StreamItemMessage(string.Empty, "1"), messages[1]); + AssertHubMessage(new StreamItemMessage(string.Empty, "2"), messages[2]); + AssertHubMessage(new StreamItemMessage(string.Empty, "3"), messages[3]); + AssertHubMessage(new CompletionMessage(string.Empty, error: null, result: null, hasResult: false), messages[4]); + + client.Dispose(); + + await endPointLifetime; + } + } + + private static void AssertHubMessage(HubMessage expected, HubMessage actual) + { + // We aren't testing InvocationIds here + switch (expected) + { + case CompletionMessage expectedCompletion: + var actualCompletion = Assert.IsType(actual); + Assert.Equal(expectedCompletion.Error, actualCompletion.Error); + Assert.Equal(expectedCompletion.HasResult, actualCompletion.HasResult); + Assert.Equal(expectedCompletion.Result, actualCompletion.Result); + break; + case StreamItemMessage expectedStreamItem: + var actualStreamItem = Assert.IsType(actual); + Assert.Equal(expectedStreamItem.Item, actualStreamItem.Item); + break; + case InvocationMessage expectedInvocation: + var actualInvocation = Assert.IsType(actual); + Assert.Equal(expectedInvocation.NonBlocking, actualInvocation.NonBlocking); + Assert.Equal(expectedInvocation.Target, actualInvocation.Target); + Assert.Equal(expectedInvocation.Arguments, actualInvocation.Arguments); + break; + default: + throw new InvalidOperationException($"Unsupported Hub Message type {expected.GetType()}"); + } + } + private static Type GetEndPointType(Type hubType) { var endPointType = typeof(HubEndPoint<>); @@ -518,6 +576,55 @@ private IServiceProvider CreateServiceProvider(Action addServ return services.BuildServiceProvider(); } + public class StreamingHub : TestHub + { + public IObservable CounterObservable(int count) + { + return new CountingObservable(count); + } + + public ReadableChannel CounterChannel(int count) + { + var channel = Channel.CreateUnbounded(); + + var task = Task.Run(async () => + { + for (int i = 0; i < count; i++) + { + await channel.Out.WriteAsync(i.ToString()); + } + channel.Out.Complete(); + }); + + return channel.In; + } + + private class CountingObservable : IObservable + { + private int _count; + + public CountingObservable(int count) + { + _count = count; + } + + public IDisposable Subscribe(IObserver observer) + { + var cts = new CancellationTokenSource(); + Task.Run(() => + { + for (int i = 0; !cts.Token.IsCancellationRequested && i < _count; i++) + { + observer.OnNext(i.ToString()); + } + observer.OnCompleted(); + }); + + return new CancellationDisposable(cts); + } + } + } + public class OnConnectedThrowsHub : Hub { public override Task OnConnectedAsync() @@ -528,7 +635,7 @@ public override Task OnConnectedAsync() } } - public class OnDisconnectedThrowsHub : Hub + public class OnDisconnectedThrowsHub : TestHub { public override Task OnDisconnectedAsync(Exception exception) { @@ -538,14 +645,8 @@ public override Task OnDisconnectedAsync(Exception exception) } } - private class MethodHub : Hub + private class MethodHub : TestHub { - public override Task OnConnectedAsync() - { - Context.Connection.Metadata.Get>("ConnectedTask").SetResult(true); - return base.OnConnectedAsync(); - } - public Task GroupRemoveMethod(string groupName) { return Groups.RemoveAsync(groupName); @@ -624,7 +725,7 @@ public override int VirtualMethod(int num) } } - private class BaseHub : Hub + private class BaseHub : TestHub { public string BaseMethod(string message) { @@ -637,7 +738,7 @@ public virtual int VirtualMethod(int num) } } - private class InvalidHub : Hub + private class InvalidHub : TestHub { public void OverloadedMethod(int num) { @@ -648,7 +749,7 @@ public void OverloadedMethod(string message) } } - private class DisposeTrackingHub : Hub + private class DisposeTrackingHub : TestHub { private TrackDispose _trackDispose; @@ -670,5 +771,14 @@ private class TrackDispose { public int DisposeCount = 0; } + + public abstract class TestHub : Hub + { + public override Task OnConnectedAsync() + { + Context.Connection.Metadata.Get>("ConnectedTask")?.TrySetResult(true); + return base.OnConnectedAsync(); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index e9edcdc268..d7c70866cb 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -21,11 +21,11 @@ public class TestClient : IDisposable, IInvocationBinder private IHubProtocol _protocol; private CancellationTokenSource _cts; - public ConnectionContext Connection; + public ConnectionContext Connection { get; } public IChannelConnection Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; - public TestClient(IServiceProvider serviceProvider) + public TestClient() { var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); @@ -42,6 +42,39 @@ public TestClient(IServiceProvider serviceProvider) _cts = new CancellationTokenSource(); } + public async Task> StreamAsync(string methodName, params object[] args) + { + var invocationId = await SendInvocationAsync(methodName, args); + + var messages = new List(); + while (true) + { + var message = await Read(); + + if (!string.Equals(message.InvocationId, invocationId)) + { + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + + if (message == null) + { + throw new InvalidOperationException("Connection aborted!"); + } + + switch (message) + { + case StreamItemMessage _: + messages.Add(message); + break; + case CompletionMessage _: + messages.Add(message); + return messages; + default: + throw new NotSupportedException("TestClient does not support receiving invocations!"); + } + } + } + public async Task InvokeAsync(string methodName, params object[] args) { var invocationId = await SendInvocationAsync(methodName, args); @@ -63,7 +96,7 @@ public async Task InvokeAsync(string methodName, params objec switch (message) { case StreamItemMessage result: - throw new NotSupportedException("TestClient does not support streaming!"); + throw new NotSupportedException("Use 'StreamAsync' to call a streaming method"); case CompletionMessage completion: return completion; default: