Skip to content

Commit

Permalink
Support async actions (#437)
Browse files Browse the repository at this point in the history
* Support async actions

* Fixes after main rebase

* Test is_coroutine_callable
  • Loading branch information
hasier committed Mar 2, 2024
1 parent b7e4883 commit bfa2c80
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
46 changes: 44 additions & 2 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tenacity import DoAttempt
from tenacity import DoSleep
from tenacity import RetryCallState
from tenacity import _utils

WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])
Expand All @@ -46,7 +47,7 @@ async def __call__( # type: ignore[override]

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
while True:
do = self.iter(retry_state=retry_state)
do = await self.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = await fn(*args, **kwargs)
Expand All @@ -60,6 +61,47 @@ async def __call__( # type: ignore[override]
else:
return do # type: ignore[no-any-return]

@classmethod
def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
if _utils.is_coroutine_callable(fn):
return fn

async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any:
return fn(*args, **kwargs)

return inner

def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None:
self.iter_state.actions.append(self._wrap_action_func(fn))

async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)(
retry_state
)

async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
if self.wait:
sleep = await self._wrap_action_func(self.wait)(retry_state)
else:
sleep = 0.0

retry_state.upcoming_sleep = sleep

async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)(
retry_state
)

async def iter(
self, retry_state: "RetryCallState"
) -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003
self._begin_iter(retry_state)
result = None
for action in self.iter_state.actions:
result = await action(retry_state)
return result

def __iter__(self) -> t.Generator[AttemptManager, None, None]:
raise TypeError("AsyncRetrying object is not iterable")

Expand All @@ -70,7 +112,7 @@ def __aiter__(self) -> "AsyncRetrying":

async def __anext__(self) -> AttemptManager:
while True:
do = self.iter(retry_state=self._retry_state)
do = await self.iter(retry_state=self._retry_state)
if do is None:
raise StopAsyncIteration
elif isinstance(do, DoAttempt):
Expand Down
13 changes: 12 additions & 1 deletion tenacity/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import sys
import typing
from datetime import timedelta
Expand Down Expand Up @@ -76,3 +77,13 @@ def to_seconds(time_unit: time_unit_type) -> float:
return float(
time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit
)


def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool:
if inspect.isclass(call):
return False
if inspect.iscoroutinefunction(call):
return True
partial_call = isinstance(call, functools.partial) and call.func
dunder_call = partial_call or getattr(call, "__call__", None)
return inspect.iscoroutinefunction(dunder_call)
41 changes: 41 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import functools

from tenacity import _utils


def test_is_coroutine_callable() -> None:
async def async_func() -> None:
pass

def sync_func() -> None:
pass

class AsyncClass:
async def __call__(self) -> None:
pass

class SyncClass:
def __call__(self) -> None:
pass

lambda_fn = lambda: None # noqa: E731

partial_async_func = functools.partial(async_func)
partial_sync_func = functools.partial(sync_func)
partial_async_class = functools.partial(AsyncClass().__call__)
partial_sync_class = functools.partial(SyncClass().__call__)
partial_lambda_fn = functools.partial(lambda_fn)

assert _utils.is_coroutine_callable(async_func) is True
assert _utils.is_coroutine_callable(sync_func) is False
assert _utils.is_coroutine_callable(AsyncClass) is False
assert _utils.is_coroutine_callable(AsyncClass()) is True
assert _utils.is_coroutine_callable(SyncClass) is False
assert _utils.is_coroutine_callable(SyncClass()) is False
assert _utils.is_coroutine_callable(lambda_fn) is False

assert _utils.is_coroutine_callable(partial_async_func) is True
assert _utils.is_coroutine_callable(partial_sync_func) is False
assert _utils.is_coroutine_callable(partial_async_class) is True
assert _utils.is_coroutine_callable(partial_sync_class) is False
assert _utils.is_coroutine_callable(partial_lambda_fn) is False

0 comments on commit bfa2c80

Please sign in to comment.