# Source code for curious.core.client

# This file is part of curious.
#
# curious is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# curious is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with curious.  If not, see <http://www.gnu.org/licenses/>.

"""
The main client class.

This contains the definition of :class:`.Client`, which is primarily used to interface with Discord.

.. currentmodule:: curious.core.client
"""

import enum
import functools
import inspect
import logging
import typing
from types import MappingProxyType

import collections
import multio
from lomond.errors import WebSocketClosed, WebSocketClosing

from curious.core import chunker as md_chunker
from curious.core.event import EventContext, EventManager, event as ev_dec, scan_events
from curious.core.gateway import GatewayHandler, open_websocket
from curious.core.httpclient import HTTPClient
from curious.dataclasses import channel as dt_channel, guild as dt_guild, member as dt_member
from curious.dataclasses.appinfo import AppInfo
from curious.dataclasses.invite import Invite
from curious.dataclasses.message import CHANNEL_REGEX, EMOJI_REGEX, MENTION_REGEX
from curious.dataclasses.presence import Game, Status
from curious.dataclasses.user import BotUser, User
from curious.dataclasses.webhook import Webhook
from curious.dataclasses.widget import Widget
from curious.util import base64ify

# Module-level logger, used by the client when handling dispatches and shard connections.
logger = logging.getLogger("curious.client")


class BotType(enum.IntEnum):
    """
    An enum that signifies what type of bot this bot is.

    The value tells the commands handling how to respond, as well as how to log in. Every
    member occupies its own bit, so members can be combined with ``|`` into a single integer.
    """

    #: Regular bot. This signifies that the client should log in as a bot account.
    BOT = 1 << 0

    # bits 1 and 2 (values 2 and 4) are reserved

    #: No bot responses. This signifies that the client should respond to ONLY USER messages.
    ONLY_USER = 1 << 3

    #: No DMs. This signifies the bot only works in guilds.
    NO_DMS = 1 << 4

    #: No guilds. This signifies the bot only works in DMs.
    NO_GUILDS = 1 << 5

    #: Self bot. This signifies the bot only responds to itself.
    SELF_BOT = 1 << 6
class Client(object):
    """
    The main client class. This is used to interact with Discord.

    To start, you can create an instance of the client by passing it the token you want to use:

    .. code-block:: python3

        cl = Client("my.token.string")

    Registering events can be done with the :meth:`.Client.event` decorator, or alternatively
    manual usage of the :class:`.EventManager` on :attr:`.Client.events`.

    .. code-block:: python3

        @cl.event("ready")
        async def loaded(ctx: EventContext):
            print("Bot logged in.")

    """
    #: A list of events to ignore the READY status.
    #: These are still dispatched by :meth:`.Client.fire_event` while a shard is not ready.
    IGNORE_READY = [
        "connect",
        "guild_streamed",
        "guild_chunk",
        "guild_available",
        "guild_sync"
    ]

    def __init__(self, token: str, *,
                 state_klass: type = None,
                 bot_type: int = (BotType.BOT | BotType.ONLY_USER)):
        """
        :param token: The current token for this bot.
        :param state_klass: The class to construct the connection state from. \
            Defaults to :class:`curious.core.state.State`.
        :param bot_type: A union of :class:`.BotType` that defines the type of this bot.
        """
        #: The mapping of `shard_id -> gateway` objects.
        self._gateways = {}  # type: typing.MutableMapping[int, GatewayHandler]

        #: The number of shards this client has.
        self.shard_count = 0

        #: The token for the bot.
        self._token = token

        if state_klass is None:
            # imported lazily, only when no custom state class was supplied
            from curious.core.state import State
            state_klass = State

        #: The current connection state for the bot.
        self.state = state_klass(self)

        #: The bot type for this bot.
        self.bot_type = bot_type

        #: The current :class:`.EventManager` for this bot.
        self.events = EventManager()

        #: The current :class:`.Chunker` for this bot.
        self.chunker = md_chunker.Chunker(self)
        self.chunker.register_events(self.events)

        #: The mapping of `shard_id -> bool`, flipped to True by :meth:`.Client.handle_ready`.
        self._ready_state = {}

        #: The :class:`.HTTPClient` used for this bot.
        self.http = HTTPClient(self._token, bot=bool(self.bot_type & BotType.BOT))

        #: The cached gateway URL.
        self._gw_url = None  # type: str

        #: The application info for this bot. Instance of :class:`.AppInfo`.
        #: This will be None for user bots.
        self.application_info = None  # type: AppInfo

        #: The task manager used for this bot. Set by :meth:`.Client.start`.
        self.task_manager = None

        # register every event handler that scan_events finds on this instance
        for (_name, event) in scan_events(self):
            self.events.add_event(event)

    @property
    def user(self) -> BotUser:
        """
        :return: The :class:`.User` that this client is logged in as.
        """
        return self.state._user

    @property
    def guilds(self) -> 'typing.Mapping[int, dt_guild.Guild]':
        """
        :return: A mapping of int -> :class:`.Guild` that this client can see.
        """
        return self.state.guilds

    @property
    def invite_url(self) -> str:
        """
        :return: The invite URL for this bot, built from the client ID of \
            :attr:`.Client.application_info`.
        """
        return "https://discordapp.com/oauth2/authorize?client_id={}&scope=bot".format(
            self.application_info.client_id)

    @property
    def events_handled(self) -> collections.Counter:
        """
        A :class:`collections.Counter` of all events that have been handled since the bot's bootup.
        This can be used to track statistics for events.

        The counter is rebuilt on every access by summing the per-gateway dispatch counters.

        .. code-block:: python3

            @command()
            async def events(self, ctx: Context):
                '''
                Shows the most common events.
                '''

                ev = ctx.bot.events_handled.most_common(3)
                await ctx.channel.messages.send(", ".join("{}: {}".format(*x) for x in ev))

        """
        c = collections.Counter()
        for gw in self._gateways.values():
            c.update(gw._dispatches_handled)

        return c

    @property
    def gateways(self) -> 'typing.Mapping[int, GatewayHandler]':
        """
        :return: A read-only view of the current gateways for this client.
        """
        return MappingProxyType(self._gateways)

    def find_channel(self, channel_id: int):
        """
        Finds a channel by channel ID.

        This delegates to the connection state.

        :param channel_id: The ID of the channel to find.
        :return: Whatever the state's ``find_channel`` returns for this ID.
        """
        return self.state.find_channel(channel_id)

    async def _background_get_app_info(self):
        """
        Gets the application info in the background.

        The result is stored on :attr:`.Client.application_info`.
        """
        appinfo = AppInfo(self, **(await self.http.get_app_info(None)))
        self.application_info = appinfo

    async def get_gateway_url(self) -> str:
        """
        Gets the gateway URL, fetching it over HTTP only if it has not been cached yet.

        :return: The gateway URL for this bot.
        """
        if self._gw_url:
            return self._gw_url

        self._gw_url = await self.http.get_gateway_url()
        return self._gw_url

    async def get_shard_count(self) -> int:
        """
        Gets the shard count over HTTP. The gateway URL returned alongside it is cached.

        :return: The shard count recommended for this bot.
        """
        gw, shards = await self.http.get_shard_count()
        self._gw_url = gw

        return shards

    def guilds_for(self, shard_id: int) -> 'typing.Iterable[dt_guild.Guild]':
        """
        Gets the guilds for this shard.

        :param shard_id: The shard ID to get guilds from.
        :return: A list of :class:`Guild` that client can see on the specified shard.
        """
        return self.state.guilds_for_shard(shard_id)

    def event(self, name: str):
        """
        A convenience decorator to mark a function as an event.

        This will copy it to the events dictionary, where it will be used as an event later on.

        .. code-block:: python3

            @bot.event("message_create")
            async def something(ctx, message: Message):
                pass

        :param name: The name of the event.
        """

        def _inner(func):
            f = ev_dec(name)(func)
            self.events.add_event(func=f)
            # the caller gets the undecorated function back
            return func

        return _inner

    # rip in peace old fire_event
    # 2016-2017
    # broke my pycharm
    async def fire_event(self, event_name: str, *args, **kwargs):
        """
        Fires an event.

        This actually passes the arguments to :meth:`.EventManager.fire_event`.

        While the shard belonging to the ``gateway`` keyword argument is not ready, only events
        listed in :attr:`.Client.IGNORE_READY` or whose name starts with ``gateway_`` are fired;
        everything else is dropped.

        :param event_name: The name of the event to fire.
        :param args: Positional arguments forwarded to the event manager.
        :param kwargs: Keyword arguments forwarded to the event manager. A ``gateway`` keyword \
            is expected, as its ``gw_state.shard_id`` is read here.
        """
        gateway = kwargs.get("gateway")
        if not self.state.is_ready(gateway.gw_state.shard_id):
            if event_name not in self.IGNORE_READY and not event_name.startswith("gateway_"):
                return

        return await self.events.fire_event(event_name, *args, **kwargs, client=self)

    async def wait_for(self, *args, **kwargs) -> typing.Any:
        """
        Shortcut for :meth:`.EventManager.wait_for`.
        """
        return await self.events.wait_for(*args, **kwargs)

    # Gateway functions
    async def change_status(self, game: Game = None, status: Status = Status.ONLINE,
                            afk: bool = False,
                            shard_id: int = 0):
        """
        Changes the bot's current status.

        :param game: The game object to use. None for no game.
        :param status: The new status. Must be a :class:`.Status` object.
        :param afk: Is the bot AFK? Only useful for userbots.
        :param shard_id: The shard to change your status on.
        :raises KeyError: If there is no gateway registered for ``shard_id``.
        """

        gateway = self._gateways[shard_id]
        return await gateway.send_status(
            name=game.name if game else None,
            type_=game.type if game else None,
            url=game.url if game else None,
            status=status.value,
            afk=afk
        )

    # HTTP Functions
    async def edit_profile(self, *,
                           username: str = None,
                           avatar: bytes = None,
                           password: str = None):
        """
        Edits the profile of this bot.

        The user is **not** edited in-place - instead, you must wait for the `USER_UPDATE` event to
        be fired on the websocket.

        :param username: The new username of the bot.
        :param avatar: The bytes-like object that represents the new avatar you wish to use.
        :param password: The password to use. Only for user accounts.
        :raises ValueError: If no password is given for a user account, or the username is \
            invalid.
        """
        if not self.user.bot and password is None:
            raise ValueError("Password must be passed for user bots")

        if username:
            if any(x in username for x in ('@', ':', '```')):
                raise ValueError("Username must not contain banned characters")

            if username in ("discordtag", "everyone", "here"):
                raise ValueError("Username cannot be a banned username")

            if not 2 <= len(username) <= 32:
                raise ValueError("Username must be 2-32 characters")

        if avatar:
            avatar = base64ify(avatar)

        await self.http.edit_user(username, avatar, password)

    async def edit_avatar(self, path: str):
        """
        A higher-level way to change your avatar.
        This allows you to provide a path to the avatar file instead of having to read it in
        manually.

        :param path: The path-like object to the avatar file.
        """
        with open(path, 'rb') as f:
            return await self.edit_profile(avatar=f.read())

    async def get_user(self, user_id: int) -> User:
        """
        Gets a user by ID.

        The state's user cache is checked first; the user is only fetched over HTTP on a miss.

        :param user_id: The ID of the user to get.
        :return: A new :class:`.User` object.
        """
        try:
            return self.state._users[user_id]
        except KeyError:
            u = self.state.make_user(await self.http.get_user(user_id))
            # decache it if we need to
            self.state._check_decache_user(u.id)
            return u

    async def get_application(self, application_id: int) -> AppInfo:
        """
        Gets an application by ID.

        :param application_id: The client ID of the application to fetch.
        :return: A new :class:`.AppInfo` object corresponding to the application.
        """
        data = await self.http.get_app_info(application_id=application_id)
        appinfo = AppInfo(self, **data)

        return appinfo

    async def get_webhook(self, webhook_id: int) -> Webhook:
        """
        Gets a webhook by ID.

        :param webhook_id: The ID of the webhook to get.
        :return: A new :class:`.Webhook` object.
        """
        return self.state.make_webhook(await self.http.get_webhook(webhook_id))

    async def get_invite(self, invite_code: str, *,
                         with_counts: bool = True) -> Invite:
        """
        Gets an invite by code.

        :param invite_code: The invite code to get.
        :param with_counts: Return the approximate counts for this invite?
        :return: A new :class:`.Invite` object.
        """
        return Invite(self, **(await self.http.get_invite(invite_code, with_counts=with_counts)))

    async def get_widget(self, guild_id: int) -> Widget:
        """
        Gets a widget from a guild.

        :param guild_id: The ID of the guild to get the widget of.
        :return: A :class:`.Widget` object.
        """
        data = await self.http.get_widget_data(guild_id)
        return Widget(self, **data)

    async def clean_content(self, content: str) -> str:
        """
        Cleans the content of a message, using the bot's cache.

        The content is split on single spaces and each token is checked, in order, against the
        channel, mention and emoji patterns. Tokens that match none of them are kept unchanged.

        :param content: The content to clean.
        :return: The cleaned up message.
        """
        final = []
        tokens = content.split(" ")
        # o(2n) loop
        for token in tokens:
            # try and find a channel from public channels
            channel_match = CHANNEL_REGEX.match(token)
            if channel_match is not None:
                channel_id = int(channel_match.groups()[0])
                channel = self.state.find_channel(channel_id)
                if channel is None or channel.type not in \
                        [dt_channel.ChannelType.TEXT, dt_channel.ChannelType.VOICE]:
                    final.append("#deleted-channel")
                else:
                    final.append(f"#{channel.name}")

                continue

            user_match = MENTION_REGEX.match(token)
            if user_match is not None:
                found_name = None
                user_id = int(user_match.groups()[0])
                member_or_user = self.state.find_member_or_user(user_id)
                if member_or_user:
                    found_name = member_or_user.name

                # an unknown user keeps the raw mention token
                if found_name is None:
                    final.append(token)
                else:
                    final.append(f"@{found_name}")

                continue

            emoji_match = EMOJI_REGEX.match(token)
            if emoji_match is not None:
                final.append(f":{emoji_match.groups()[0]}:")
                continue

            # if we got here, matching failed
            # so just add the token
            final.append(token)

        return " ".join(final)

    # download_ methods
    async def download_guild_member(self, guild_id: int, member_id: int) -> 'dt_member.Member':
        """
        Downloads a :class:`.Member` over HTTP.

        .. warning::

            The :attr:`.Member.roles` and similar fields will be empty when downloading a Member,
            unless the guild was in cache.

        :param guild_id: The ID of the guild which the member is in.
        :param member_id: The ID of the member to get.
        :return: The :class:`.Member` object downloaded.
        """
        member_data = await self.http.get_guild_member(guild_id=guild_id, member_id=member_id)
        member = dt_member.Member(self, **member_data)

        # this is enough to pick up the cache
        member.guild_id = guild_id

        # manual refcounts :ancap:
        self.state._check_decache_user(member.id)

        return member

    async def download_guild_members(self, guild_id: int, *,
                                     after: int = None, limit: int = 1000,
                                     get_all: bool = True) -> 'typing.Iterable[dt_member.Member]':
        """
        Downloads the members for a :class:`.Guild` over HTTP.

        .. warning::

            This can take a long time on big guilds.

        :param guild_id: The ID of the guild to download members for.
        :param after: The member ID after which to get members for. \
            Only used when ``get_all`` is False.
        :param limit: The maximum number of members to return per request. By default, this is \
            1000 members.
        :param get_all: Should *all* members be fetched?
        :return: An iterable of :class:`.Member`.
        """
        member_data = []
        if get_all is True:
            last_id = 0
            while True:
                next_data = await self.http.get_guild_members(guild_id=guild_id,
                                                              limit=limit, after=last_id)
                # no more members to get
                if not next_data:
                    break

                member_data.extend(next_data)
                # if there's less data than limit, we are finished downloading members
                if len(next_data) < limit:
                    break

                last_id = member_data[-1]["user"]["id"]
        else:
            next_data = await self.http.get_guild_members(guild_id=guild_id, limit=limit,
                                                          after=after)
            member_data.extend(next_data)

        # create the member objects
        members = []
        for datum in member_data:
            m = dt_member.Member(self, **datum)
            m.guild_id = guild_id
            members.append(m)

        return members

    async def download_channels(self, guild_id: int) -> 'typing.List[dt_channel.Channel]':
        """
        Downloads all the :class:`.Channel` for a Guild.

        :param guild_id: The ID of the guild to download channels for.
        :return: A list of :class:`.Channel` objects.
        """
        channel_data = await self.http.get_guild_channels(guild_id=guild_id)
        channels = []
        for datum in channel_data:
            channel = dt_channel.Channel(self, **datum)
            channel.guild_id = guild_id
            channels.append(channel)

        return channels

    async def download_guild(self, guild_id: int, *,
                             full: bool = False) -> 'dt_guild.Guild':
        """
        Downloads a :class:`.Guild` over HTTP.

        The downloaded guild replaces any existing entry for ``guild_id`` in the state's guild
        store.

        .. warning::

            If ``full`` is True, this will fetch and fill ALL objects of the guild, including
            channels and members. This can take a *long* time if the guild is large.

        :param guild_id: The ID of the Guild object to download.
        :param full: If all extra data should be downloaded alongside it.
        :return: The :class:`.Guild` object downloaded.
        """
        guild_data = await self.http.get_guild(guild_id)
        # create the new guild using the data specified
        guild = dt_guild.Guild(self, **guild_data)
        guild.unavailable = False

        # update the guild store
        self.state._guilds[guild_id] = guild

        if full:
            # download all of the members
            members = await self.download_guild_members(guild_id=guild_id, get_all=True)
            # update the `_members` dict
            guild._members = {m.id: m for m in members}

            # download all of the channels
            channels = await self.download_channels(guild_id=guild_id)
            guild._channels = {c.id: c for c in channels}

        return guild

    @ev_dec(name="gateway_dispatch_received")
    async def handle_dispatches(self, ctx: EventContext, name: str, dispatch: dict):
        """
        Handles dispatches for the client.

        The dispatch is routed to the ``handle_<name>`` method on the connection state. The
        handler may return a plain result, an awaitable, or an async generator; the resulting
        event(s) are then fired on the event manager. Any error kills the bot and is re-raised.

        :param ctx: The event context; its ``gateway`` is passed to the handler.
        :param name: The name of the dispatch. It is lowercased to find the handler.
        :param dispatch: The dispatch payload passed to the handler.
        """
        try:
            handle_name = name.lower()
            handler = getattr(self.state, f"handle_{handle_name}")
        except AttributeError:
            logger.warning(f"Got unknown dispatch {name}")
            return
        else:
            logger.debug(f"Processing event {name}")

        try:
            result = handler(ctx.gateway, dispatch)

            if inspect.isawaitable(result):
                result = await result
            elif inspect.isasyncgen(result):
                # each yielded item is `(event_name, *event_args)`
                async with multio.asynclib.finalize_agen(result) as gen:
                    async for i in gen:
                        await self.events.fire_event(i[0], *i[1:], gateway=ctx.gateway,
                                                     client=self)

                # no more processing after the async gen
                return

            if not isinstance(result, tuple):
                await self.events.fire_event(result, gateway=ctx.gateway, client=self)
            else:
                await self.events.fire_event(result[0], *result[1:], gateway=ctx.gateway,
                                             client=self)

        except Exception:
            logger.exception(f"Error decoding event {name} with data {dispatch}!")
            await self.kill()
            raise

    @ev_dec(name="ready")
    async def handle_ready(self, ctx: 'EventContext'):
        """
        Handles a READY event, dispatching a ``shards_ready`` event when all shards are ready.

        :param ctx: The event context; its ``shard_id`` is the shard marked as ready.
        """
        self._ready_state[ctx.shard_id] = True

        if not all(self._ready_state.values()):
            return

        await self.events.fire_event("shards_ready", gateway=self._gateways[ctx.shard_id],
                                     client=self)

    async def handle_shard(self, shard_id: int, shard_count: int):
        """
        Handles a shard.

        This opens the websocket, registers the gateway under ``shard_id`` and then consumes its
        events until an error other than a websocket close occurs. The gateway stays registered
        across reconnects and is only removed once this method exits.

        :param shard_id: The shard ID to boot and handle.
        :param shard_count: The shard count to send in the identify packet.
        """
        # consume events
        async with open_websocket(self._token, self._gw_url,
                                  shard_id=shard_id, shard_count=shard_count) as gw:
            self._gateways[shard_id] = gw

            try:
                while True:
                    try:
                        async with multio.asynclib.finalize_agen(gw.events()) as agen:
                            async for event in agen:
                                await self.fire_event(event[0], *event[1:], gateway=gw)
                    except (WebSocketClosing, WebSocketClosed):
                        # not fatal: go around the loop and consume events again
                        logger.warning("Websocket closing, reconnecting...")
                    except Exception:
                        # kill the bot if we failed to parse something
                        await self.kill()
                        raise
            finally:
                # Deregister only when the shard is really finished. Doing this inside the loop
                # would drop the gateway from the mapping on the first reconnect, even though
                # it is still in use.
                self._gateways.pop(shard_id, None)

    async def start(self, shard_count: int):
        """
        Starts the bot.

        This marks every shard as not ready, then spawns the application info fetch and one
        :meth:`.Client.handle_shard` task per shard inside a task manager.

        :param shard_count: The number of shards to boot.
        """
        # update ready state
        for shard_id in range(shard_count):
            self._ready_state[shard_id] = False

        async with multio.asynclib.task_manager() as tg:
            self.task_manager = tg
            self.events.task_manager = tg

            await multio.asynclib.spawn(tg, self._background_get_app_info)

            for shard_id in range(0, shard_count):
                await multio.asynclib.spawn(tg, self.handle_shard, shard_id, shard_count)

    async def run_async(self, *, shard_count: int = 1, autoshard: bool = True):
        """
        Runs the client asynchronously.

        :param shard_count: The number of shards to boot. Ignored if autoshard is True.
        :param autoshard: If the bot should be autosharded.
        """
        if autoshard:
            shard_count = await self.get_shard_count()

        self.shard_count = shard_count
        return await self.start(shard_count)

    async def kill(self) -> None:
        """
        Kills the bot by closing all shards.
        """
        # Iterate over a snapshot: shard tasks remove themselves from ``_gateways`` when they
        # exit, which can happen while we are suspended in ``close()`` and would otherwise
        # raise "dictionary changed size during iteration".
        for gateway in list(self._gateways.values()):
            await gateway.close(code=1006, reason="Bot killed", reconnect=False)

    def run(self, *, shard_count: int = 1, autoshard: bool = True, **kwargs):
        """
        Convenience method to run the bot with multio.

        :param shard_count: The number of shards to use. Ignored if autoshard is True.
        :param autoshard: If the bot should be autosharded.
        :param kwargs: Extra keyword arguments passed to ``multio.run``.
        """
        p = functools.partial(self.run_async, shard_count=shard_count, autoshard=autoshard)
        multio.run(p, **kwargs)

    @classmethod
    def from_token(cls, token: str = None):
        """
        Creates a bot from a token and runs it with the default :meth:`.Client.run` arguments.

        :param token: The token to use for the bot.
        """
        bot = cls(token)
        return bot.run()