# This file is part of curious.
#
# curious is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# curious is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with curious. If not, see <http://www.gnu.org/licenses/>.
"""
Contains the class for the commands manager for a client.

.. currentmodule:: curious.commands.manager
"""
import importlib
import inspect
import logging
import sys
import traceback
import typing
from collections import defaultdict
from functools import partial
import multio
from curious.commands.context import Context
from curious.commands.exc import CommandsError
from curious.commands.help import help_command
from curious.commands.plugin import Plugin
from curious.commands.ratelimit import RateLimiter
from curious.commands.utils import prefix_check_factory
from curious.core import client as md_client
from curious.core.event import EventContext, event
from curious.dataclasses.message import Message
# Module logger; the command error handler reports tracebacks through it.
logger: logging.Logger = logging.getLogger(name="curious.commands.manager")
class CommandsManager(object):
    """
    A manager that handles commands for a client.

    First, you need to create the manager and attach it to a client:

    .. code-block:: python3

        # form 1, automatically register with the client
        manager = CommandsManager.with_client(bot)

        # form 2, manually register
        manager = CommandsManager(bot)
        manager.register_events()

    This is required to add the handler events to the client.

    Next, you need to register a message check handler. This is a callable that is called for
    every message to try and extract the command from a message, if it matches.

    By default, the manager provides an easy way to use a simple command prefix:

    .. code-block:: python3

        # at creation time
        manager = CommandsManager(bot, command_prefix="!")

        # or set it on the manager
        manager.command_prefix = "!"

    At this point, the command prefix will be available on the manager with either
    :attr:`.CommandsManager.command_prefix` or ``manager.message_check.prefix``.

    If you need more complex message checking, you can use ``message_check``:

    .. code-block:: python3

        manager = CommandsManager(bot, message_check=my_message_checker)
        # or
        manager.message_check = my_message_checker

    Finally, you can register plugins or modules containing plugins with the manager:

    .. code-block:: python3

        @bot.event("ready")
        async def load_plugins(ctx: EventContext):
            # load plugin explicitly
            await manager.load_plugin(PluginClass, arg1)
            # load plugins from a module
            await manager.load_plugins_from("my.plugin.module")

    You can also add free-standing commands that aren't bound to a plugin with
    :meth:`.CommandsManager.add_command`:

    .. code-block:: python3

        @command()
        async def ping(ctx: Context):
            await ctx.channel.messages.send(content="Ping!")

        manager.add_command(ping)

    These will then be available to the client.
    """

    def __init__(self, client: 'md_client.Client', *,
                 message_check=None, command_prefix: str = None):
        """
        :param client: The :class:`.Client` to use with this manager.
        :param message_check: The message check function for this manager.

            This should take two arguments, the client and message, and should return either
            None or a 2-item tuple:

              - The command word matched
              - The tokens after the command word

            The return value may also be an awaitable, in which case it is awaited first.
        :param command_prefix: A simple command prefix. When ``message_check`` is not given,
            a prefix-based message check is built from this; otherwise it is ignored.
        :raises ValueError: If neither ``message_check`` nor ``command_prefix`` is provided.
        """
        if message_check is None and command_prefix is None:
            raise ValueError("Must provide one of message_check or command_prefix")

        #: The client for this manager.
        self.client = client

        if message_check is None:
            message_check = prefix_check_factory(command_prefix)

        #: The message check function for this manager.
        self.message_check = message_check

        #: A dictionary mapping of <plugin name> -> <plugin> object.
        self.plugins = {}

        #: A dictionary of stand-alone commands, i.e. commands not associated with a plugin.
        self.commands = {}

        #: The current ratelimiter.
        self.ratelimiter = RateLimiter()

        # <import path> -> [plugin instances loaded via load_plugins_from(<import path>)]
        self._module_plugins = defaultdict(list)

    @property
    def command_prefix(self):
        """
        The prefix of the current message check.

        This is ``None`` when the message check has no ``prefix`` attribute, for example a
        custom ``message_check``.

        Assigning to this replaces :attr:`message_check` with a prefix-based check built from
        the new prefix, the same way the constructor does.
        """
        return getattr(self.message_check, "prefix", None)

    @command_prefix.setter
    def command_prefix(self, prefix: str) -> None:
        self.message_check = prefix_check_factory(prefix)

    @classmethod
    def with_client(cls, client: 'md_client.Client', **kwargs):
        """
        Creates a manager and automatically registers events.

        :param client: The :class:`.Client` to use with this manager.
        :param kwargs: Keyword arguments passed to the constructor
            (``message_check`` / ``command_prefix``).
        :return: The new manager, with its events already registered on the client.
        """
        manager = cls(client=client, **kwargs)
        manager.register_events()
        return manager

    def register_events(self) -> None:
        """
        Copies the events to the client specified on this manager.

        This registers the message handler, the default command error handler and the plugin
        event hook on the client, then adds the built-in ``help`` command.
        """
        self.client.events.add_event(self.handle_message)
        self.client.events.add_event(self.default_command_error)
        self.client.events.add_event_hook(self.event_hook)

        # imported lazily rather than at module level; presumably avoids a circular import
        from curious.commands.decorators import command
        self.commands["help"] = command(name="help")(help_command)

    @staticmethod
    def _get_plugin_name(klass) -> str:
        # Registration name of a plugin class: ``plugin_name`` if set, else the class name.
        return getattr(klass, "plugin_name", klass.__name__)

    async def _teardown_plugin(self, plugin: Plugin) -> None:
        # Cancels the plugin's task group (if running), then calls its ``unload`` hook.
        if plugin.task_group is not None:
            await multio.asynclib.cancel_task_group(plugin.task_group)

        await plugin.unload()

    async def load_plugin(self, klass: typing.Type[Plugin], *args,
                          module: str = None):
        """
        Loads a plugin.

        .. note::

            The client instance will automatically be provided to the Plugin's ``__init__``.

        :param klass: The plugin class to load.
        :param args: Any args to provide to the plugin.
        :param module: The import path this plugin was loaded from. Only used internally, so
            that :meth:`unload_plugins_from` can find the plugin again.
        :return: The loaded plugin instance.
        """
        # get the name and create the plugin object
        plugin_name = self._get_plugin_name(klass)
        instance = klass(self.client, *args)

        # call load, of course
        await instance.load()

        self.plugins[plugin_name] = instance
        if module is not None:
            self._module_plugins[module].append(instance)

        return instance

    async def unload_plugin(self, klass: typing.Union[typing.Type[Plugin], str]):
        """
        Unloads a plugin.

        :param klass: The plugin class or name of plugin to unload.
        :return: The unloaded plugin instance, or ``None`` if ``klass`` is a class and no
            plugin of exactly that class is loaded.
        :raises KeyError: If ``klass`` is a name and no plugin is loaded under that name.
        """
        plugin: typing.Optional[Plugin] = None

        if isinstance(klass, str):
            plugin = self.plugins.pop(klass)
        else:
            # find the name first, so the dict is not mutated while it is being iterated
            name = next((n for n, p in self.plugins.items() if type(p) is klass), None)
            if name is not None:
                plugin = self.plugins.pop(name)

        if plugin is not None:
            await self._teardown_plugin(plugin)

        return plugin

    def _lookup_command(self, name: str):
        """
        Does a lookup in plugin and standalone commands.

        Standalone commands are matched by their registered name only. Plugin commands are
        matched by name or alias, and subcommands are skipped.

        :param name: The command name (or plugin command alias) to look up.
        :return: The first matching command, or ``None`` if nothing matches.
        """
        if name in self.commands:
            return self.commands[name]

        for plugin in self.plugins.values():
            for cmd in plugin._get_commands():
                if cmd.cmd_subcommand:
                    continue

                if cmd.cmd_name == name or name in cmd.cmd_aliases:
                    return cmd

        return None

    def get_command(self, command_name: str):
        """
        Gets a command from the internal command storage.

        If provided a string separated by spaces, a subcommand lookup will be attempted.

        :param command_name: The name of the command to lookup.
        :return: The matching command or subcommand, or ``None`` if any token fails to match.
        """
        # do an immediate lookup for the first token
        tokens = command_name.split(" ")
        command = self._lookup_command(tokens[0])
        if command is None:
            return None

        # then walk down the subcommand tree, one token per level
        for token in tokens[1:]:
            command = next((sub for sub in command.cmd_subcommands
                            if sub.cmd_name == token or token in sub.cmd_aliases), None)
            if command is None:
                return None

        return command

    def add_command(self, command):
        """
        Adds a command.

        :param command: A command function.
        :return: The command that was added.
        :raises ValueError: If the function was not decorated with the command decorator.
        """
        if not hasattr(command, "is_cmd"):
            raise ValueError("Commands must be decorated with the command decorator")

        self.commands[command.cmd_name] = command
        return command

    def remove_command(self, command):
        """
        Removes a command.

        :param command: The name of the command, or the command function.
        :return: The removed command, or ``None`` if a function was given and it is not
            registered.
        :raises KeyError: If a name was given and no command is registered under it.
        """
        if isinstance(command, str):
            return self.commands.pop(command)

        for name, registered in self.commands.copy().items():
            if registered == command:
                return self.commands.pop(name)

        return None

    async def load_plugins_from(self, import_path: str):
        """
        Loads plugins from the specified module.

        Every :class:`.Plugin` subclass visible at module level (excluding ``Plugin`` itself)
        is loaded with :meth:`load_plugin`.

        :param import_path: The import path to import.
        """
        mod = importlib.import_module(import_path)

        # accept only proper subclasses of Plugin; the isinstance check guards issubclass
        def predicate(item):
            return isinstance(item, type) and issubclass(item, Plugin) and item is not Plugin

        for _, plugin_class in inspect.getmembers(mod, predicate=predicate):
            # key by the import path string, which is what unload_plugins_from looks up
            await self.load_plugin(plugin_class, module=import_path)

    async def unload_plugins_from(self, import_path: str):
        """
        Unloads plugins from the specified module.

        This will delete the module from :data:`sys.modules`.

        :param import_path: The import path.
        """
        for plugin in list(self._module_plugins.get(import_path, ())):
            name = self._get_plugin_name(type(plugin))

            # skip plugins already unloaded individually via unload_plugin
            if self.plugins.get(name) is not plugin:
                continue

            del self.plugins[name]
            await self._teardown_plugin(plugin)

        self._module_plugins.pop(import_path, None)
        sys.modules.pop(import_path, None)

    async def event_hook(self, ctx: EventContext, *args, **kwargs):
        """
        The event hook for the commands manager.

        For each loaded plugin, every member with an ``is_event`` attribute whose ``events``
        contain ``ctx.event_name`` is spawned in a task group. It is run through the client's
        ``_safety_wrapper`` with the event context and arguments.
        """
        async with multio.asynclib.task_manager() as tg:
            # snapshot: spawning awaits, and handlers may load/unload plugins meanwhile
            for plugin in list(self.plugins.values()):
                body = inspect.getmembers(plugin, predicate=lambda v: hasattr(v, "is_event"))
                for _, handler in body:
                    if ctx.event_name not in handler.events:
                        continue

                    cofunc = partial(self.client.events._safety_wrapper,
                                     handler, ctx, *args, **kwargs)
                    await multio.asynclib.spawn(tg, cofunc)

    async def handle_commands(self, ctx: EventContext, message: Message):
        """
        Handles commands for a message.

        :param ctx: The event context of the incoming message event.
        :param message: The message to check for, and possibly invoke, a command.
        """
        # don't process messages pre-cache
        if not message.author:
            return

        # check bot type flags; each set bit filters out a class of messages
        bot_type = self.client.bot_type
        if message.author.user.bot and bot_type & 8:
            # messages from bot accounts are ignored
            return

        if message.author.user != self.client.user and bot_type & 64:
            # only messages from our own user are accepted
            return

        if message.guild_id is not None and bot_type & 32:
            # guild messages are ignored
            return

        if message.guild_id is None and bot_type & 16:
            # messages outside of a guild are ignored
            return

        # step 1, match the messages
        matched = self.message_check(self.client, message)
        if inspect.isawaitable(matched):
            matched = await matched

        if matched is None:
            return None

        # deconstruct the tuple returned into more useful variables than a single tuple
        command_word, tokens = matched

        # step 2, create the new commands context
        cmd_ctx = Context(event_context=ctx, message=message)
        cmd_ctx.command_name = command_word
        cmd_ctx.tokens = tokens
        cmd_ctx.manager = self

        # step 3, invoke the context to try and match the command and run it
        await cmd_ctx.try_invoke()

    @event("command_error")
    async def default_command_error(self, ev_ctx: EventContext, ctx: Context, err: CommandsError):
        """
        Handles command errors by default, by logging the error's traceback.

        If another ``command_error`` listener is registered, this handler removes itself from
        the client and returns without logging.
        """
        # autoremove ourself if applicable
        if len(self.client.events.event_listeners.getall("command_error")) > 1:
            self.client.events.remove_event("command_error", self.default_command_error)
            return

        fmtted = ''.join(traceback.format_exception(type(err), err, err.__traceback__))
        logger.error("Error in command!\n%s", fmtted)

    @event("message_create")
    async def handle_message(self, ctx: EventContext, message: Message):
        """
        Registered as the event handler in a client for handling commands.
        """
        return await self.handle_commands(ctx, message)