# This file is part of curious.
#
# curious is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# curious is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with curious. If not, see <http://www.gnu.org/licenses/>.
"""
Special helpers for events.

.. currentmodule:: curious.core.events
"""
import functools
import inspect
import logging
import typing
import multio
from async_generator import asynccontextmanager
from multidict import MultiDict
from curious.core import client as md_client
from curious.core.gateway import GatewayHandler
from curious.util import remove_from_multidict, safe_generator
#: Logger used by the event manager and the helpers in this module.
logger = logging.getLogger(name="curious.events")
class ListenerExit(Exception):
    """
    Signals that a temporary listener is finished and should be unregistered.

    Raise this from inside a temporary listener; the event manager catches it and removes
    the listener from its temporary listener registry.

    .. code-block:: python3

        def listener(ctx, message):
            if message.author.id == message.guild.owner_id:
                raise ListenerExit
    """
@asynccontextmanager
@safe_generator
async def _wait_for_manager(manager, name: str, predicate):
    """
    Helper context manager that runs the ``async with`` body whilst waiting for an event.

    A ``manager.wait_for(name, predicate)`` call is spawned into a fresh task group before
    control is handed to the body. When the body completes normally, the task group is left
    without being cancelled. Leaving it is expected to join the pending ``wait_for``
    (NOTE(review): this assumes multio's task manager joins its children on exit, as a
    trio nursery / curio TaskGroup does). If the body raises or is cancelled, the pending
    wait is cancelled and the error propagates.

    :param manager: The :class:`EventManager` to wait on.
    :param name: The name of the event to wait for.
    :param predicate: The predicate passed through to ``wait_for``.
    """
    async with multio.asynclib.task_manager() as tg:
        try:
            waiter = functools.partial(manager.wait_for, name, predicate)
            await multio.asynclib.spawn(tg, waiter)
            yield
        except BaseException:
            # Only abandon the wait when the body failed (or we were cancelled). The old
            # unconditional ``finally`` cancelled the wait_for as soon as the body finished,
            # so the manager never actually waited for the event.
            await multio.asynclib.cancel_task_group(tg)
            raise
class EventManager(object):
    """
    A manager for events.

    This deals with firing of events and temporary listeners.
    """

    def __init__(self):
        #: The task manager used to spawn events.
        self.task_manager = None

        #: A set of event hooks; every hook is spawned for every fired event.
        self.event_hooks = set()

        #: A MultiDict of event listeners.
        self.event_listeners = MultiDict()

        #: A MultiDict of temporary listeners.
        self.temporary_listeners = MultiDict()

    # add or removal functions
    # Events
    def add_event(self, func, name: str = None):
        """
        Add an event to the internal registry of events.

        :param func: The function to add. Must be an ``async def`` function.
        :param name: The event name to register under. If ``None``, the names are read from
            the function's ``events`` attribute (as set by the :func:`event` decorator).
        :raises TypeError: If ``func`` is not an async function.
        """
        if not inspect.iscoroutinefunction(func):
            raise TypeError("Event must be an async function")

        if name is None:
            evs = func.events
        else:
            evs = [name]

        for ev_name in evs:
            logger.debug("Registered event `{}` handling `{}`".format(func, ev_name))
            self.event_listeners.add(ev_name, func)

    def remove_event(self, name: str, func):
        """
        Removes a function event.

        :param name: The name the event is registered under.
        :param func: The function to remove.
        """
        self.event_listeners = remove_from_multidict(self.event_listeners, key=name, item=func)

    # listeners
    def add_temporary_listener(self, name: str, listener):
        """
        Adds a new temporary listener.

        To remove the listener, you can raise ListenerExit which will exit it and remove the
        listener from the list.

        :param name: The name of the event to listen to.
        :param listener: The listener function.
        """
        self.temporary_listeners.add(name, listener)

    def remove_listener_early(self, name: str, listener):
        """
        Removes a temporary listener early.

        :param name: The name of the event the listener is registered under.
        :param listener: The listener function.
        """
        # Temporary listeners live in ``temporary_listeners`` (see add_temporary_listener);
        # previously this removed from ``event_listeners``, so the listener was never dropped.
        self.temporary_listeners = remove_from_multidict(self.temporary_listeners,
                                                         key=name, item=listener)

    def add_event_hook(self, listener):
        """
        Adds an event hook.

        :param listener: The event hook callable to use. It is called with the event context
            and the event's arguments for every fired event.
        """
        logger.warning("Adding event hook '%s'", listener)
        self.event_hooks.add(listener)

    def remove_event_hook(self, listener):
        """
        Removes an event hook.

        :param listener: The event hook callable to remove.
        :raises KeyError: If the hook was never added.
        """
        self.event_hooks.remove(listener)

    # wrapper functions
    async def _safety_wrapper(self, func, *args, **kwargs):
        """
        Ensures a coro's error is caught and doesn't balloon out.

        Any :class:`Exception` raised by ``func`` is logged and swallowed.
        """
        try:
            await func(*args, **kwargs)
        except Exception:
            # ``func`` may be a callable without ``__name__`` (e.g. a bare functools.partial);
            # fall back to its repr so building the log message cannot raise a second error.
            func_name = getattr(func, "__name__", repr(func))
            logger.exception("Unhandled exception in {}!".format(func_name), exc_info=True)

    async def _listener_wrapper(self, key: str, func, *args, **kwargs):
        """
        Wraps a listener, ensuring ListenerExit is handled properly.

        The listener is removed from the temporary listeners under ``key`` when it raises
        :class:`ListenerExit`, or when it raises any other exception (which is logged).
        """
        try:
            await func(*args, **kwargs)
        except ListenerExit:
            # the listener signalled that it is finished; unregister it
            self.temporary_listeners = remove_from_multidict(self.temporary_listeners, key, func)
        except Exception:
            func_name = getattr(func, "__name__", repr(func))
            logger.exception("Unhandled exception in listener {}!".format(func_name),
                             exc_info=True)
            # a crashing listener is dropped too, so it cannot fail again on every event
            self.temporary_listeners = remove_from_multidict(self.temporary_listeners, key, func)

    async def wait_for(self, event_name: str, predicate=None):
        """
        Waits for an event.

        A temporary listener is registered for ``event_name``. Every time the event fires,
        ``predicate`` is called with the event's arguments (without the context). It may return
        a plain value or an awaitable. Returning a truthy value from the predicate will cause it
        to exit and return. Raising :class:`ListenerExit` from the predicate also exits and
        returns the current event's arguments.

        :param event_name: The name of the event.
        :param predicate: The predicate to use to check for the event. If ``None``, the first
            occurrence of the event is returned.
        :return: The event's arguments. A single argument is returned unwrapped; otherwise the
            full tuple is returned.
        :raises Exception: Any exception raised by the predicate is re-raised here.
        """
        p = multio.Promise()
        errored = False

        async def listener(ctx, *args):
            nonlocal errored

            # exit immediately if the predicate is none
            if predicate is None:
                await p.set(args)
                raise ListenerExit

            try:
                res = predicate(*args)
                if inspect.isawaitable(res):
                    res = await res
            except ListenerExit:
                # the predicate asked to stop listening; use the current event as the result
                await p.set(args)
                raise
            except Exception as e:
                # something bad happened, set exception and exit
                logger.exception("Exception in wait_for predicate!")
                # tell the waiter that the promise holds an exception, not event args
                errored = True
                await p.set(e)
                raise ListenerExit
            else:
                # exit now if the result is truthy, as documented. The old ``res is True``
                # check silently ignored every other truthy return value.
                if res:
                    await p.set(args)
                    raise ListenerExit

        self.add_temporary_listener(name=event_name, listener=listener)
        output = await p.wait()
        if errored:
            raise output

        # unwrap tuples, if applicable
        if len(output) == 1:
            return output[0]

        return output

    def wait_for_manager(self, event_name: str, predicate) -> 'typing.AsyncContextManager[None]':
        """
        Returns a context manager that can be used to run some steps whilst waiting for a
        temporary listener.

        .. code-block:: python

            async with client.events.wait_for_manager("member_update", predicate=...):
                await member.nickname.set("Test")

        This probably won't be needed outside of internal library functions.
        """
        return _wait_for_manager(self, event_name, predicate)

    async def spawn(self, cofunc, *args) -> typing.Any:
        """
        Spawns a new async function using our task manager.

        Usage::

            async def myfn(a, b):
                await do_some_operation(a + b)

            await events.spawn(myfn, 1, 2)

        :param cofunc: The async function to spawn.
        :param args: Args to provide to the async function.
        """
        return await multio.asynclib.spawn(self.task_manager, cofunc, *args)

    async def fire_event(self, event_name: str, *args, **kwargs):
        """
        Fires an event.

        Hooks, registered handlers and temporary listeners for ``event_name`` are each spawned
        as a separate task via :meth:`spawn`.

        :param event_name: The name of the event to fire.
        :param args: Positional arguments passed to every callback after the context.
        :param kwargs: Either ``ctx`` (an existing :class:`EventContext`, whose event name is
            overwritten), or ``gateway`` and ``client`` used to build a new context. Remaining
            keyword arguments are passed to every callback.
        """
        if "ctx" not in kwargs:
            gateway = kwargs.pop("gateway")
            client = kwargs.pop("client")
            ctx = EventContext(client, gateway.gw_state.shard_id, event_name)
        else:
            ctx = kwargs.pop("ctx")
            # clobber event name
            ctx.event_name = event_name

        # hooks are spawned before handlers and listeners.
        # NOTE(review): hooks are not wrapped in _safety_wrapper, so an exception raised by
        # a hook is not caught here.
        for hook in self.event_hooks:
            cofunc = functools.partial(hook, ctx, *args, **kwargs)
            await self.spawn(cofunc)

        for handler in self.event_listeners.getall(event_name, []):
            coro = functools.partial(handler, ctx, *args, **kwargs)
            # give the partial a name so _safety_wrapper can report which handler failed
            coro.__name__ = handler.__name__
            await self.spawn(self._safety_wrapper, coro)

        for listener in self.temporary_listeners.getall(event_name, []):
            coro = functools.partial(self._listener_wrapper, event_name, listener, ctx,
                                     *args, **kwargs)
            await self.spawn(coro)
def event(name, scan: bool = True):
    """
    Decorator factory that marks a function as an event.

    The decorator may be stacked with different names; every name is collected into the
    function's ``events`` set.

    :param name: The name of the event.
    :param scan: Should this event be handled in scans too? Stored as the function's
        ``scan`` attribute, which :func:`scan_events` checks.
    :return: A decorator that tags the function and returns the same function object.
    """
    def _mark(func):
        if not hasattr(func, "events"):
            func.events = set()
        func.is_event = True
        func.events.add(name)
        func.scan = scan
        return func

    return _mark
def scan_events(obb) -> typing.Generator[typing.Tuple[str, typing.Any], None, None]:
    """
    Scans an object for any items marked as an event and yields them.

    A member qualifies when it has a truthy ``is_event`` attribute and a truthy ``scan``
    attribute, as set by the :func:`event` decorator.

    :param obb: The object to scan.
    :return: A generator of ``(attribute_name, item)`` pairs, in the name-sorted order that
        :func:`inspect.getmembers` produces.
    """
    # The yield type is the first Generator parameter; the old annotation
    # ``Generator[None, Tuple[...], None]`` put the tuple in the *send* slot.
    def _is_scannable_event(item) -> bool:
        if not getattr(item, "is_event", False):
            return False
        # items tagged by :func:`event` always carry ``scan``
        return bool(item.scan)

    for attr_name, item in inspect.getmembers(obb, predicate=_is_scannable_event):
        yield attr_name, item
class EventContext(object):
    """
    Represents a special context that is passed to events.
    """

    def __init__(self, cl: 'md_client.Client', shard_id: int,
                 event_name: str):
        """
        :param cl: The :class:`.Client` instance for this event context.
        :param shard_id: The shard ID this event is for.
        :param event_name: The event name for this event.
        """
        #: The :class:`.Client` instance that this event was fired under.
        self.bot = cl

        #: The shard this event was received on.
        self.shard_id = shard_id  # type: int

        #: The total number of shards for this bot (copied from the client).
        self.shard_count = cl.shard_count  # type: int

        #: The event name for this event.
        self.event_name = event_name  # type: str

    @property
    def handlers(self) -> typing.List[typing.Callable[['EventContext'], None]]:
        """
        :return: A list of handlers registered for this event.
        """
        # ``bot.events`` is the client's EventManager (cf. ``client.events.wait_for_manager``).
        # The manager itself has no ``getall``; its registered handlers live in the
        # ``event_listeners`` MultiDict.
        return self.bot.events.event_listeners.getall(self.event_name, [])

    async def change_status(self, *args, **kwargs) -> None:
        """
        Changes the current status for this shard.

        This takes the same arguments as :class:`.Client.change_status`, but ignoring the
        shard ID; the shard ID of this context is always used.
        """
        kwargs["shard_id"] = self.shard_id
        return await self.bot.change_status(*args, **kwargs)

    @property
    def gateway(self) -> 'GatewayHandler':
        """
        :return: The :class:`.Gateway` that produced this event.
        """
        return self.bot.gateways[self.shard_id]