Skip to content

Commit

Permalink
Fixed the trio backend yielding ints instead of Signals from its sign…
Browse files Browse the repository at this point in the history
…al receiver
  • Loading branch information
agronholm committed Oct 13, 2022
1 parent 3ce7bb8 commit f080174
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ This library adheres to `Semantic Versioning 2.0 <http:https://semver.org/>`_.
order (instead of LIFO) (PR by Conor Stevenson)
- Fixed ``CancelScope.cancel()`` not working on asyncio if called before entering the
scope
- Fixed ``open_signal_receiver()`` inconsistently yielding integers instead of
``signal.Signals`` instances on the ``trio`` backend

**3.6.1**

Expand Down
30 changes: 21 additions & 9 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import array
import math
import socket
from collections.abc import Iterable
from collections.abc import AsyncIterator, Iterable
from concurrent.futures import Future
from dataclasses import dataclass
from functools import partial
Expand Down Expand Up @@ -674,12 +674,16 @@ def statistics(self) -> CapacityLimiterStatistics:
#


class _SignalReceiver(Generic[T]):
def __init__(self, cm: ContextManager[T]):
self._cm = cm
class _SignalReceiver:
_iterator: AsyncIterator[int]

def __enter__(self) -> T:
return self._cm.__enter__()
def __init__(self, signals: tuple[Signals, ...]):
self._signals = signals

def __enter__(self) -> _SignalReceiver:
self._cm = trio.open_signal_receiver(*self._signals)
self._iterator = self._cm.__enter__()
return self

def __exit__(
self,
Expand All @@ -689,6 +693,13 @@ def __exit__(
) -> bool | None:
return self._cm.__exit__(exc_type, exc_val, exc_tb)

def __aiter__(self) -> _SignalReceiver:
return self

async def __anext__(self) -> Signals:
signum = await self._iterator.__anext__()
return Signals(signum)


#
# Testing and debugging
Expand Down Expand Up @@ -1060,9 +1071,10 @@ def current_default_thread_limiter(cls) -> CapacityLimiter:
return limiter

@classmethod
def open_signal_receiver(cls, *signals: Signals) -> ContextManager:
cm = trio.open_signal_receiver(*signals)
return _SignalReceiver(cm)
def open_signal_receiver(
cls, *signals: Signals
) -> ContextManager[AsyncIterator[Signals]]:
return _SignalReceiver(signals)

@classmethod
def get_current_task(cls) -> TaskInfo:
Expand Down
9 changes: 7 additions & 2 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ async def test_receive_signals() -> None:
await to_thread.run_sync(os.kill, os.getpid(), signal.SIGUSR1)
await to_thread.run_sync(os.kill, os.getpid(), signal.SIGUSR2)
with fail_after(1):
assert await sigiter.__anext__() == signal.SIGUSR1
assert await sigiter.__anext__() == signal.SIGUSR2
sigusr1 = await sigiter.__anext__()
assert isinstance(sigusr1, signal.Signals)
assert sigusr1 == signal.Signals.SIGUSR1

sigusr2 = await sigiter.__anext__()
assert isinstance(sigusr2, signal.Signals)
assert sigusr2 == signal.Signals.SIGUSR2


async def test_task_group_cancellation_open() -> None:
Expand Down

0 comments on commit f080174

Please sign in to comment.