Skip to content

Commit

Permalink
Add resume_agile to allow coroutine to resume in any apartment (#1356)
Browse files Browse the repository at this point in the history
  • Loading branch information
oldnewthing committed Sep 12, 2023
1 parent 23c4ced commit fac72c8
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 57 deletions.
48 changes: 31 additions & 17 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,44 +99,49 @@ namespace winrt::impl
return async.GetResults();
}

template<typename Awaiter>
struct disconnect_aware_handler
struct ignore_apartment_context {};

template<bool preserve_context, typename Awaiter>
struct disconnect_aware_handler : private std::conditional_t<preserve_context, resume_apartment_context, ignore_apartment_context>
{
disconnect_aware_handler(Awaiter* awaiter, coroutine_handle<> handle) noexcept
: m_awaiter(awaiter), m_handle(handle) { }

disconnect_aware_handler(disconnect_aware_handler&& other) noexcept
: m_context(std::move(other.m_context))
, m_awaiter(std::exchange(other.m_awaiter, {}))
, m_handle(std::exchange(other.m_handle, {})) { }
disconnect_aware_handler(disconnect_aware_handler&& other) = default;

~disconnect_aware_handler()
{
if (m_handle) Complete();
if (m_handle.value) Complete();
}

template<typename Async>
void operator()(Async&&, Windows::Foundation::AsyncStatus status)
{
m_awaiter->status = status;
m_awaiter.value->status = status;
Complete();
}

private:
resume_apartment_context m_context;
Awaiter* m_awaiter;
coroutine_handle<> m_handle;
movable_primitive<Awaiter*> m_awaiter;
movable_primitive<coroutine_handle<>, nullptr> m_handle;

void Complete()
{
if (m_awaiter->suspending.exchange(false, std::memory_order_release))
if (m_awaiter.value->suspending.exchange(false, std::memory_order_release))
{
m_handle = nullptr; // resumption deferred to await_suspend
m_handle.value = nullptr; // resumption deferred to await_suspend
}
else
{
auto handle = std::exchange(m_handle, {});
if (!resume_apartment(m_context, handle, &m_awaiter->failure))
auto handle = m_handle.detach();
if constexpr (preserve_context)
{
if (!resume_apartment(*this, handle, &m_awaiter.value->failure))
{
handle.resume();
}
}
else
{
handle.resume();
}
Expand All @@ -145,7 +150,7 @@ namespace winrt::impl
};

#ifdef WINRT_IMPL_COROUTINES
template <typename Async>
template <typename Async, bool preserve_context = true>
struct await_adapter : cancellable_awaiter<await_adapter<Async>>
{
await_adapter(Async const& async) : async(async) { }
Expand Down Expand Up @@ -185,7 +190,7 @@ namespace winrt::impl
private:
bool register_completed_callback(coroutine_handle<> handle)
{
async.Completed(disconnect_aware_handler(this, handle));
async.Completed(disconnect_aware_handler<preserve_context, await_adapter>(this, handle));
return suspending.exchange(false, std::memory_order_acquire);
}

Expand Down Expand Up @@ -249,6 +254,15 @@ namespace winrt::impl
}

#ifdef WINRT_IMPL_COROUTINES
WINRT_EXPORT namespace winrt
{
template<typename Async, typename = std::enable_if_t<std::is_convertible_v<Async, winrt::Windows::Foundation::IAsyncInfo>>>
inline impl::await_adapter<Async, false> resume_agile(Async const& async)
{
return { async };
};
}

WINRT_EXPORT namespace winrt::Windows::Foundation
{
inline impl::await_adapter<IAsyncAction> operator co_await(IAsyncAction const& async)
Expand Down
17 changes: 4 additions & 13 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,14 @@ namespace winrt::impl
{
resume_apartment_context() = default;
resume_apartment_context(std::nullptr_t) : m_context(nullptr), m_context_type(-1) {}
resume_apartment_context(resume_apartment_context const&) = default;
resume_apartment_context(resume_apartment_context&& other) noexcept :
m_context(std::move(other.m_context)), m_context_type(std::exchange(other.m_context_type, -1)) {}
resume_apartment_context& operator=(resume_apartment_context const&) = default;
resume_apartment_context& operator=(resume_apartment_context&& other) noexcept
{
m_context = std::move(other.m_context);
m_context_type = std::exchange(other.m_context_type, -1);
return *this;
}

bool valid() const noexcept
{
return m_context_type >= 0;
return m_context_type.value >= 0;
}

com_ptr<IContextCallback> m_context = try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext);
int32_t m_context_type = get_apartment_type().first;
movable_primitive<int32_t, -1> m_context_type = get_apartment_type().first;
};

inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept
Expand Down Expand Up @@ -124,7 +115,7 @@ namespace winrt::impl
{
return false;
}
else if (context.m_context_type == 1 /* APTTYPE_MTA */)
else if (context.m_context_type.value == 1 /* APTTYPE_MTA */)
{
resume_background(handle);
return true;
Expand Down
19 changes: 19 additions & 0 deletions strings/base_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,25 @@ namespace winrt::impl
}
}

template<typename T, auto empty_value = T{}>
struct movable_primitive
{
T value = empty_value;
movable_primitive() = default;
movable_primitive(T const& init) : value(init) {}
movable_primitive(movable_primitive const&) = default;
movable_primitive(movable_primitive&& other) :
value(other.detach()) {}
movable_primitive& operator=(movable_primitive const&) = default;
movable_primitive& operator=(movable_primitive&& other)
{
value = other.detach();
return *this;
}

T detach() { return std::exchange(value, empty_value); }
};

template <typename T, typename Enable = void>
struct arg
{
Expand Down
85 changes: 58 additions & 27 deletions test/test/await_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,9 @@ namespace

static handle signal{ CreateEventW(nullptr, false, false, nullptr) };

IAsyncAction OtherForegroundAsync()
IAsyncAction OtherForegroundAsync(DispatcherQueue dispatcher)
{
// Simple coroutine that completes on a unique STA thread.

auto controller = DispatcherQueueController::CreateOnDedicatedThread();
auto dispatcher = controller.DispatcherQueue();

// Simple coroutine that completes on the specified STA thread.
co_await resume_foreground(dispatcher);
}

Expand All @@ -35,37 +31,37 @@ namespace
co_await resume_background();
}

IAsyncAction ForegroundAsync(DispatcherQueue dispatcher)
// Coroutine that completes on dispatcher1, while potentially blocking dispatcher2.
IAsyncAction ForegroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
{
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
co_await resume_foreground(dispatcher1);
REQUIRE(is_sta());

// This exercises one STA thread waiting on another thus one context callback
// completing on another.
uint32_t id = GetCurrentThreadId();
co_await OtherForegroundAsync();
co_await OtherForegroundAsync(dispatcher2);
REQUIRE(id == GetCurrentThreadId());

// This just avoids the ForegroundAsync coroutine completing before
// BackgroundAsync waits on the result, forcing the Completed handler
// to be called on the foreground thread. This just makes the test
// success/failure more predictable.
// This Sleep() makes it more likely that the caller will actually suspend in await_suspend,
// so that the Completed handler triggers a resumption from the dispatcher1 thread.
Sleep(100);
}

fire_and_forget SignalFromForeground(DispatcherQueue dispatcher)
fire_and_forget SignalFromForeground(DispatcherQueue dispatcher1)
{
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
co_await resume_foreground(dispatcher1);
REQUIRE(is_sta());

// Previously, this signal was never raised because the foreground thread
// was always blocked waiting for ContextCallback to return.
// Previously, we never got here because of a deadlock:
// The dispatcher1 thread was blocked waiting for ContextCallback to return,
// but the ContextCallback is waiting for this event to get signaled.
REQUIRE(SetEvent(signal.get()));
}

IAsyncAction BackgroundAsync(DispatcherQueue dispatcher)
IAsyncAction BackgroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
{
// Switch to a background (MTA) thread.
co_await resume_background();
Expand All @@ -76,19 +72,19 @@ namespace
co_await OtherBackgroundAsync();
REQUIRE(!is_sta());

// Wait for a coroutine that completes on a foreground (STA) thread.
co_await ForegroundAsync(dispatcher);
// Wait for a coroutine that completes on a the dispatcher1 thread (STA).
co_await ForegroundAsync(dispatcher1, dispatcher2);

// Resumption should automatically switch to a background (MTA) thread
// without blocking the Completed handler (which would in turn block the foreground thread).
// without blocking the Completed handler (which would in turn block the dispatcher1 thread).
REQUIRE(!is_sta());

// Attempt to signal from the foreground thread under the assumption
// that the foreground thread is not blocked.
SignalFromForeground(dispatcher);
// Attempt to signal from the dispatcher1 thread under the assumption
// that the dispatcher1 thread is not blocked.
SignalFromForeground(dispatcher1);

// Block the background (MTA) thread indefinitely until the signal is raied.
// Previously this would deadlock.
// Block the background (MTA) thread indefinitely until the signal is raised.
// Previously this would hang because the signal never got raised.
REQUIRE(WAIT_OBJECT_0 == WaitForSingleObject(signal.get(), INFINITE));
}
}
Expand All @@ -99,9 +95,44 @@ TEST_CASE("await_adapter", "[.clang-crash]")
#else
TEST_CASE("await_adapter")
#endif
{
auto controller1 = DispatcherQueueController::CreateOnDedicatedThread();
auto controller2 = DispatcherQueueController::CreateOnDedicatedThread();

BackgroundAsync(controller1.DispatcherQueue(), controller2.DispatcherQueue()).get();
controller1.ShutdownQueueAsync().get();
controller2.ShutdownQueueAsync().get();
}

namespace
{
IAsyncAction OtherBackgroundDelayAsync()
{
// Simple coroutine that completes on some MTA thread after a brief delay
// to ensure that the caller has suspended.

co_await resume_after(100ms);
}

IAsyncAction AgileAsync(DispatcherQueue dispatcher)
{
// Switch to the STA.
co_await resume_foreground(dispatcher);
REQUIRE(is_sta());

// Ask for agile resumption of a coroutine that finishes on a background thread.
// Add a 100ms delay to ensure we suspend.
co_await resume_agile(OtherBackgroundDelayAsync());
// We should be on the background thread now.
REQUIRE(!is_sta());
}
}

TEST_CASE("await_adapter_agile")
{
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
auto dispatcher = controller.DispatcherQueue();

BackgroundAsync(dispatcher).get();
AgileAsync(dispatcher).get();
controller.ShutdownQueueAsync().get();
}

0 comments on commit fac72c8

Please sign in to comment.