import aiohttp
import asyncio
import inspect
import logging
from datetime import time as dtime
from datetime import timedelta, datetime, timezone
from typing import Callable, Optional, Union, Sequence
from . import utils
_log = logging.getLogger(__name__)
[docs]
class Sleeper:
def __init__(self, dt: datetime, *, loop: asyncio.AbstractEventLoop):
self.loop = loop
self.future: asyncio.Future = loop.create_future()
self.handle: asyncio.TimerHandle = loop.call_later(
max((dt - utils.utcnow()).total_seconds(), 0),
self.future.set_result,
True
)
[docs]
def recalculate(self, dt: datetime) -> None:
self.handle.cancel()
self.handle: asyncio.TimerHandle = self.loop.call_later(
max((dt - utils.utcnow()).total_seconds(), 0),
self.future.set_result,
True
)
[docs]
def wait(self) -> asyncio.Future:
return self.future
[docs]
def done(self) -> bool:
return self.future.done()
[docs]
def cancel(self) -> None:
self.future.cancel()
self.handle.cancel()
[docs]
class Loop:
def __init__(
self,
*,
func: Callable,
seconds: Optional[float],
minutes: Optional[float],
hours: Optional[float],
time: Optional[Union[dtime, list[dtime]]] = None,
count: Optional[int] = None,
reconnect: bool = True
):
self.func: Callable = func
self.reconnect: bool = reconnect
self.count: Optional[int] = count
if self.count is not None and self.count <= 0:
raise ValueError("count must be greater than 0 or None")
self._task: Optional[asyncio.Task] = None
self._injected = None
self._error: Callable = self._default_error
self._before_loop: Callable = self._default_before_loop
self._after_loop: Callable = self._default_after_loop
self._whitelist_exceptions = (
OSError,
asyncio.TimeoutError,
aiohttp.ClientError,
)
self.handle_interval(
seconds=seconds,
minutes=minutes,
hours=hours,
time=time
)
self._will_cancel: bool = False
self._should_stop: bool = False
self._has_faild: bool = False
self._last_loop_failed: bool = False
self._last_loop: Optional[datetime] = None
self._next_loop: Optional[datetime] = None
self._loop_count: int = 0
async def __call__(self, *args, **kwargs) -> Callable:
if self._injected is not None:
args = (self._injected, *args)
return await self.func(*args, **kwargs)
def __get__(self, obj, objtype):
if obj is None:
return self
copy: Loop = Loop(
func=self.func,
seconds=self._seconds,
minutes=self._minutes,
hours=self._hours,
time=self._time,
count=self.count,
reconnect=self.reconnect
)
copy._injected = obj
copy._before_loop = self._before_loop
copy._after_loop = self._after_loop
copy._error = self._error
setattr(obj, self.func.__name__, copy)
return copy
async def _try_sleep_until(self, dt: datetime) -> None:
""" Attempt to sleeps until a specified datetime depending on the loop configuration """
self._handle = Sleeper(dt, loop=asyncio.get_event_loop())
return await self._handle.wait()
async def _default_error(self, e: Exception) -> None:
""" The default error handler for the loop """
_log.error(
f"Unhandled exception in background loop {self.func.__name__}",
exc_info=e
)
async def _default_before_loop(self) -> None:
""" The default before_loop handler for the loop """
pass
async def _default_after_loop(self) -> None:
""" The default after_loop handler for the loop """
pass
async def _looper(self, *args, **kwargs) -> None:
""" Internal looper that handles the behaviour of the loop """
await self._before_loop()
self._last_loop_failed = False
if self._is_explicit_time():
self._next_loop = self._next_sleep_time()
else:
self._next_loop = utils.utcnow()
await asyncio.sleep(0)
try:
if self._should_stop:
return None
while True:
if self._is_explicit_time():
await self._try_sleep_until(self._next_loop)
if not self._last_loop_failed:
self._last_loop = self._next_loop
self._next_loop = self._next_sleep_time()
while (
self._is_explicit_time() and
self._next_loop <= self._last_loop
):
_log.warn(
f"task:{self.func.__name__} woke up a bit too early. "
f"Sleeping until {self._next_loop} to avoid drifting."
)
await self._try_sleep_until(self._next_loop)
self._next_loop = self._next_sleep_time()
try:
await self.func(*args, **kwargs)
self._last_loop_failed = False
except self._whitelist_exceptions:
self._last_loop_failed = True
if not self.reconnect:
raise
await asyncio.sleep(5)
else:
if self._should_stop:
return
if self._is_relative_time():
await self._try_sleep_until(self._next_loop)
self._loop_count += 1
if self.loop_count == self.count:
break
except asyncio.CancelledError:
self._will_cancel = True
raise
except Exception as e:
self._has_faild = True
await self._error(e)
finally:
await self._after_loop()
if self._handle:
self._handle.cancel()
self._will_cancel = False
self._loop_count = 0
self._should_stop = False
[docs]
def start(self, *args, **kwargs) -> asyncio.Task:
""" Starts the loop """
if self._task and not self._task.done():
raise RuntimeError("The loop is already running")
if self._injected is not None:
args = (self._injected, *args)
self._last_loop_failed = False
self._task = asyncio.create_task(self._looper(*args, **kwargs))
return self._task
[docs]
def stop(self) -> None:
""" Stops the loop """
if self._task and not self._task.done():
self._should_stop = True
def _can_be_cancelled(self) -> bool:
return bool(
not self._will_cancel and
self._task and
not self._task.done()
)
[docs]
def cancel(self) -> None:
""" Cancels the loop if possible """
if self._can_be_cancelled():
self._task.cancel()
[docs]
def on_error(self) -> Callable:
""" Decorator that registers a custom error handler for the loop """
def decorator(func: Loop) -> Loop:
if not inspect.iscoroutinefunction(func):
raise TypeError("The error handler must be a coroutine function")
self._error = func
return func
return decorator
[docs]
def before_loop(self) -> Callable:
""" Decorator that registers a custom before_loop handler for the loop """
def decorator(func: Loop) -> Loop:
if not inspect.iscoroutinefunction(func):
raise TypeError("The before_loop must be a coroutine function")
self._before_loop = func
return func
return decorator
[docs]
def after_loop(self) -> Callable:
""" Decorator that registers a custom after_loop handler for the loop """
def decorator(func: Loop) -> Loop:
if not inspect.iscoroutinefunction(func):
raise TypeError("The after_loop must be a coroutine function")
self._after_loop = func
return func
return decorator
[docs]
def is_running(self) -> bool:
""" Returns whether the loop is running or not """
return not bool(self._task.done()) if self._task else False
@property
def loop_count(self) -> int:
""" Returns the number of times the loop has been run """
return self._loop_count
@property
def failed(self) -> bool:
""" Returns whether the loop has failed or not """
return self._has_faild
def _is_relative_time(self) -> bool:
return self._time is None
def _is_explicit_time(self) -> bool:
return self._time is not None
[docs]
def is_being_cancelled(self) -> bool:
""" Returns whether the loop is being cancelled or not """
return self._will_cancel
[docs]
def fetch_task(self) -> Optional[asyncio.Task]:
""" Returns the task that is running the loop """
return self._task
[docs]
def add_exception(self, *exceptions: Exception) -> None:
""" Adds exceptions to the whitelist of exceptions that are ignored """
for e in exceptions:
if not inspect.isclass(e):
_log.error(
"Loop.add_exception expected a class, "
f"received {type(e)} instead, skipping"
)
continue
if not issubclass(e, BaseException):
_log.error(
"Loop.add_exception expected a subclass of BaseException, "
f"received {e} instead, skipping"
)
continue
self._whitelist_exceptions += (e,)
[docs]
def remove_exception(self, *exceptions: Exception) -> None:
""" Removes exceptions from the whitelist of exceptions that are ignored """
self._whitelist_exceptions = tuple(
x for x in self._whitelist_exceptions
if x not in exceptions
)
[docs]
def reset_exceptions(self) -> None:
""" Resets the whitelist of exceptions that are ignored back to the default """
self._whitelist_exceptions = (
OSError,
asyncio.TimeoutError,
aiohttp.ClientError,
)
def _sort_static_times(
self,
times: Optional[Union[dtime, Sequence[dtime]]]
) -> list[dtime]:
if isinstance(times, dtime):
return [
times
if times.tzinfo is not None
else times.replace(tzinfo=timezone.utc)
]
if not isinstance(times, Sequence):
raise TypeError(f"Expected a list, got {type(times)} instead")
if not times:
raise ValueError("Expected at least one item, got an empty list instead")
output: list[dtime] = []
for i, ts in enumerate(times):
if not isinstance(ts, dtime):
raise TypeError(f"Expected datetime.time, got {type(ts)} (Index: {i})")
output.append(
ts
if ts.tzinfo is not None
else ts.replace(tzinfo=timezone.utc)
)
return sorted(set(output))
[docs]
def handle_interval(
self,
*,
seconds: Optional[float] = 0,
minutes: Optional[float] = 0,
hours: Optional[float] = 0,
time: Optional[Union[dtime, list[dtime]]] = None
) -> None:
"""
Sets the interval of the loop.
Parameters
----------
seconds: `float`
Amount of seconds between each iteration of the loop
minutes: `float`
Amount of minutes between each iteration of the loop
hours: `float`
Amount of hours between each iteration of the loop
time: `dtime`
The time of day to run the loop at
Raises
------
`ValueError`
- The sleep timer cannot be 0
- `count` must be greater than 0 or `None`
`TypeError`
`time` must be a `datetime.time` object
"""
if time is None:
seconds = seconds or 0.0
minutes = minutes or 0.0
hours = hours or 0.0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep <= 0:
raise ValueError("The sleep timer cannot be 0")
self._seconds = float(seconds)
self._minutes = float(minutes)
self._hours = float(hours)
self._sleep = sleep
self._time: Optional[list[dtime]] = None
else:
if any((seconds, minutes, hours)):
raise ValueError("Cannot use both time and seconds/minutes/hours")
self._time: Optional[list[dtime]] = self._sort_static_times(time)
self._sleep = self._seconds = self._minutes = self._hours = None
if self.is_running() and self._last_loop is not None:
self._next_loop = self._next_sleep_time()
if self._handle and not self._handle.done():
self._handle.recalculate(self._next_loop)
def _find_time_index(self, now: datetime) -> Optional[int]:
"""
Finds the index of the next time in the list of times
Parameters
----------
now: `datetime`
The current time
Returns
-------
`Optional[int]`
The index of the next time in the list of times
"""
if not self._time:
return None
for i, ts in enumerate(self._time):
start = now.astimezone(ts.tzinfo)
if ts >= start.timetz():
return i
else:
return None
def _next_sleep_time(self, now: Optional[datetime] = None) -> datetime:
""" Calculates the next time the loop should run """
if self._sleep is not None:
return self._last_loop + timedelta(seconds=self._sleep)
if now is None:
now = utils.utcnow()
index = self._find_time_index(now)
if index is None:
time = self._time[0]
tomorrow = now.astimezone(time.tzinfo) + timedelta(days=1)
date = tomorrow.date()
else:
time = self._time[index]
date = now.astimezone(time.tzinfo).date()
dt = datetime.combine(date, time, tzinfo=time.tzinfo)
return dt.astimezone(timezone.utc)
[docs]
def loop(
*,
seconds: Optional[float] = None,
minutes: Optional[float] = None,
hours: Optional[float] = None,
time: Optional[Union[dtime, list[dtime]]] = None,
count: Optional[int] = None,
reconnect: bool = True
) -> Callable[[Callable], Loop]:
"""
Decorator that registers a function as a loop.
Parameters
----------
seconds: `float`
The number of seconds between each iteration of the loop.
minutes: `float`
The number of minutes between each iteration of the loop.
hours: `float`
The number of hours between each iteration of the loop.
time: `datetime.time`
The time of day to run the loop at. (UTC only)
count: `int`
The number of times to run the loop. If ``None``, the loop will run forever.
reconnect: `bool`
Whether the loop should reconnect if it fails or not.
"""
def decorator(func) -> Loop:
return Loop(
func=func,
seconds=seconds,
minutes=minutes,
hours=hours,
time=time,
count=count,
reconnect=reconnect
)
return decorator