# This file is part of curious.
#
# curious is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# curious is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with curious. If not, see <http://www.gnu.org/licenses/>.
"""
Websocket gateway code.
.. currentmodule:: curious.core.gateway
"""
import enum
import json
import logging
import sys
import time
import zlib
from collections import Counter
from typing import Any, AsyncContextManager, AsyncGenerator, List, Union
import multio
from async_generator import asynccontextmanager
from dataclasses import dataclass # use a 3.6 backport if available
from lomond.events import Binary, Closed, Connecting, Text
from curious.core._ws_wrapper import BasicWebsocketWrapper
from curious.util import safe_generator
class GatewayOp(enum.IntEnum):
    """
    Represents the opcode mapping for the gateway.

    Each member is the integer ``op`` value carried by a gateway payload. Being an
    :class:`enum.IntEnum`, members compare equal to plain integers, so a decoded ``op`` field
    can be compared against them directly.
    """
    DISPATCH = 0
    HEARTBEAT = 1
    IDENTIFY = 2
    PRESENCE = 3
    VOICE_STATE = 4
    VOICE_PING = 5
    RESUME = 6
    RECONNECT = 7
    REQUEST_MEMBERS = 8
    INVALIDATE_SESSION = 9
    HELLO = 10
    HEARTBEAT_ACK = 11
    GUILD_SYNC = 12
[docs]@dataclass
class _GatewayState:
"""
Represents the gateway state for the current gateway.
"""
#: The current token.
token: str
#: The current gateway URL.
gateway_url: str
#: The shard ID for this gateway.
shard_id: int
#: The shard count for this gateway.
shard_count: int
#: The current session ID.
session_id: str = None
#: The current sequence.
sequence: int = 0
@dataclass
class HeartbeatStats:
    """
    Represents the statistics for the gateway's heartbeat counters.
    """
    #: The number of heartbeats sent.
    heartbeats: int = 0

    #: The number of heartbeat acks received.
    heartbeat_acks: int = 0

    #: Internal time when the last heartbeat was sent.
    last_heartbeat_time: float = 0.0

    #: Internal time when the last heartbeat_ack was received.
    last_ack_time: float = 0.0

    @property
    def gw_time(self) -> float:
        """
        :return: The time between the most recent heartbeat and the most recent heartbeat_ack,
            in the units of the two stored timestamps. Negative while a heartbeat is still
            waiting for its ack.
        """
        return self.last_ack_time - self.last_heartbeat_time
class GatewayHandler(object):
    """
    Represents a gateway handler - something that is connected to Discord's websocket and handles
    incoming events and parses them as appropriate.

    You don't want to create this class directly; use :func:`.open_websocket` instead.

    .. code-block:: python3

        async with open_websocket("token", "wss://gateway.discord.gg",
                                  shard_id=0, shard_count=1) as gateway:
            async for event in gateway.events():
                ...
    """
    GATEWAY_VERSION = 6

    def __init__(self, gw_state: _GatewayState):
        """
        :param gw_state: The mutable connection state (token, URL, shard info, session) to use.
        """
        #: The current state being used for this gateway.
        self.gw_state = gw_state

        #: The current heartbeat stats being used for this gateway.
        self.heartbeat_stats = HeartbeatStats()

        #: The current :class:`.BasicWebsocketWrapper` connected to Discord.
        self.websocket: BasicWebsocketWrapper = None

        #: The current task group for this gateway.
        self.task_group = None

        self._logger = None
        # set to make the background heartbeat task exit
        self._stop_heartbeating = multio.Event()
        # per-event-name count of the dispatches seen so far
        self._dispatches_handled = Counter()

    @property
    def logger(self) -> logging.Logger:
        """
        :return: The gateway-specific logger.
        """
        if self._logger:
            return self._logger

        # created lazily and cached; the name carries the shard ID
        self._logger = logging.getLogger("curious.gateway:shard-{}".format(self.gw_state.shard_id))
        return self._logger

    async def close(self, code: int = 1000, reason: str = "Client closed connection", *,
                    reconnect: bool = False, clear_session_id: bool = True):
        """
        Close the current websocket connection.

        :param code: The close code.
        :param reason: The close reason.
        :param reconnect: If we should reconnect.
        :param clear_session_id: If we should clear the session ID (and with it the sequence).
            Pass False when the connection is expected to be resumed.
        """
        await self.websocket.close(code=code, reason=reason, reconnect=reconnect)
        # stop the heartbeat task; there is nothing left to heartbeat on
        await self._stop_heartbeating.set()

        if clear_session_id:
            self.gw_state.session_id = None
            # the sequence belongs to the session: only drop it together with the session ID,
            # because send_resume() has to send it back as ``seq``
            self.gw_state.sequence = None

        # always clear the counters so the next connection doesn't immediately look zombied
        self.heartbeat_stats.heartbeats = 0
        self.heartbeat_stats.heartbeat_acks = 0

    # send commands
    async def send(self, data: dict) -> None:
        """
        Sends data down the websocket.

        :param data: The payload to send. It is serialised with :func:`json.dumps` and sent as
            a text frame.
        """
        dumped = json.dumps(data)
        return await self.websocket.send_text(dumped)

    async def send_identify(self) -> None:
        """
        Sends an IDENTIFY to Discord.

        The token and shard information are taken from :attr:`gw_state`.
        """
        payload = {
            "op": GatewayOp.IDENTIFY,
            "d": {
                "token": self.gw_state.token,
                "properties": {
                    "$os": sys.platform,
                    "$browser": "curious",
                    "$device": "curious",
                    "$referrer": "",
                    "$referring_domain": ""
                },
                "compress": True,
                "large_threshold": 250,
                "v": self.GATEWAY_VERSION,
                "shard": [self.gw_state.shard_id, self.gw_state.shard_count]
            }
        }

        return await self.send(payload)

    async def send_heartbeat(self) -> None:
        """
        Sends a heartbeat to Discord.

        If more than one heartbeat is outstanding without an ack, the connection is treated as
        zombied and closed for reconnection instead of sending another heartbeat.
        """
        # increment the stats
        self.heartbeat_stats.heartbeats += 1
        self.heartbeat_stats.last_heartbeat_time = time.monotonic()

        if self.heartbeat_stats.heartbeats > self.heartbeat_stats.heartbeat_acks + 1:
            self.logger.warning("Connection has zombied, reconnecting.")
            # Note: The 1006 close code signifies an error.
            # In my testing, closing with a 1006 will allow a resume once reconnected,
            # whereas other close codes won't.
            # The timeout might be too high to RESUME, however.
            return await self.close(code=1006, reason="Zombied connection", reconnect=True,
                                    clear_session_id=False)

        self.logger.debug("Heartbeating with sequence {}".format(self.gw_state.sequence))
        payload = {
            "op": GatewayOp.HEARTBEAT,
            "d": self.gw_state.sequence
        }

        return await self.send(payload)

    async def send_resume(self) -> None:
        """
        Sends a RESUME to Discord, attempting to resume the connection from where we left off.

        The token, session ID and sequence are taken from :attr:`gw_state`.
        """
        payload = {
            "op": GatewayOp.RESUME,
            "d": {
                "token": self.gw_state.token,
                "session_id": self.gw_state.session_id,
                "seq": self.gw_state.sequence
            }
        }

        return await self.send(payload)

    async def send_guild_chunks(self, guild_ids: List[int]) -> None:
        """
        Sends a REQUEST_MEMBERS packet to Discord, asking for member chunks.

        :param guild_ids: The guild IDs to request members for. They are sent as strings.
        """
        payload = {
            "op": GatewayOp.REQUEST_MEMBERS,
            "d": {
                "guild_id": list(map(str, guild_ids)),
                "query": "",
                "limit": 0
            }
        }

        return await self.send(payload)

    async def send_status(self, *, status: int = None, name: str = None, url: str = None,
                          type_: int = None,
                          afk: bool = None):
        """
        Sends a PRESENCE_UPDATE.

        Only the fields that are passed are put into the payload.

        :param status: The int status to send.
        :param name: The name of the status to send.
        :param url: The URL to include if applicable. Only used when ``name`` is given.
        :param type_: The type of the status to send. Only used when ``name`` is given.
        :param afk: If the account is to be marked as AFK.
        """
        payload = {
            "op": GatewayOp.PRESENCE,
            "d": {}
        }

        if status is not None:
            payload["d"]["status"] = status

        if name is not None:
            game = {
                "name": name,
                "type": type_
            }
            if url is not None:
                game["url"] = url

            payload["d"]["game"] = game

        if afk is not None:
            # ``since`` is the current wall-clock time in milliseconds
            payload["d"].update(afk=afk, since=int(time.time() * 1000))

        return await self.send(payload)

    async def open(self) -> None:
        """
        Opens a new connection to Discord.

        .. warning::

            This only opens the websocket.

        :raises RuntimeError: If the async library multio is using is neither curio nor trio.
        """
        # the wrappers are imported lazily so only the backend in use has to be importable
        if multio.asynclib.lib_name == "curio":
            from curious.core._ws_wrapper.curio_wrapper import CurioWebsocketWrapper as Wrapper
            ws_open = Wrapper.open
        elif multio.asynclib.lib_name == "trio":
            from curious.core._ws_wrapper.trio_wrapper import TrioWebsocketWrapper as Wrapper

            # the trio wrapper additionally takes the task group
            def ws_open(url):
                return Wrapper.open(url, self.task_group)
        else:
            raise RuntimeError("Unsupported lib: " + multio.asynclib.lib_name)

        self.logger.info("Using %s for the gateway", Wrapper.__name__)
        self.websocket = await ws_open(self.gw_state.gateway_url)

    async def events(self) -> AsyncGenerator[Any, None]:
        """
        Returns an async generator used to iterate over the events received by this websocket.

        Every item is a tuple whose first element is the event name, followed by any event
        arguments.
        """
        async for event in self.websocket:
            if isinstance(event, Closed):
                await self._stop_heartbeat_events()
                yield ("websocket_closed",)

            elif isinstance(event, Connecting):
                yield ("websocket_opened",)

            elif isinstance(event, (Text, Binary)):
                gen = self.handle_data_event(event)
                # make sure the inner generator is finalised by the async library
                async with multio.asynclib.finalize_agen(gen) as finalized:
                    async for i in finalized:
                        yield i

    async def _start_heatbeat_events(self, heartbeat_interval: float):
        """
        Starts heartbeating.

        :param heartbeat_interval: The number of seconds between each heartbeat.
        """
        if self._stop_heartbeating.is_set():
            self._stop_heartbeating.clear()

        async def heartbeater() -> None:
            while True:
                # wait on the stop event with a timeout: a timeout means it is time for the
                # next heartbeat, the event firing means we should stop
                try:
                    async with multio.asynclib.timeout_after(heartbeat_interval):
                        await self._stop_heartbeating.wait()
                except multio.asynclib.TaskTimeout:
                    pass
                else:
                    break

                await self.send_heartbeat()

        await multio.asynclib.spawn(self.task_group, heartbeater)

    async def _stop_heartbeat_events(self) -> None:
        """
        Cancels any current heartbeat events.
        """
        await self._stop_heartbeating.set()
        # reset our heartbeat count
        self.heartbeat_stats.heartbeats = 0
        self.heartbeat_stats.heartbeat_acks = 0

    async def handle_data_event(self, evt: Union[Text, Binary]):
        """
        Handles a data event.

        This is an async generator: it decodes the payload, updates the gateway state, replies
        to Discord where the opcode needs it, and yields event tuples for the caller.

        :param evt: The text or binary websocket event to handle. Binary events are
            zlib-decompressed and decoded as UTF-8 first.
        """
        if evt.name == "binary":
            # magic numbers: wbits=15, and the initial size of the output buffer
            data = zlib.decompress(evt.data, 15, 10490000)
            data = data.decode("utf-8")
        else:
            data = evt.text

        # empty payloads
        if not data:
            return

        decoded = json.loads(data)

        opcode = decoded.get('op')
        sequence = decoded.get('s')
        event_data = decoded.get('d', {})

        # update sequence number for dispatches
        if sequence is not None:
            self.gw_state.sequence = sequence

        # switch based on opcode
        if opcode == GatewayOp.HELLO:
            # the interval arrives in milliseconds
            heartbeat_interval = event_data.get("heartbeat_interval", 45000) / 1000.0
            self.logger.debug("Heartbeating every {} seconds.".format(heartbeat_interval))
            await self.send_heartbeat()
            await self._start_heatbeat_events(heartbeat_interval)

            # _trace is only debugging information; don't fail the handshake if it's missing
            trace_servers = event_data.get("_trace", [])
            trace = ", ".join(trace_servers)
            self.logger.info(f"Connected to Discord servers {trace}")

            if self.gw_state.session_id is None:
                self.logger.info("Sending IDENTIFY...")
                await self.send_identify()
            else:
                self.logger.info("We already have a session ID, Sending RESUME...")
                await self.send_resume()

            # give an event down here instead of above
            # this means that we're all done when we go to give off our event
            yield ("gateway_hello", trace_servers)

        elif opcode == GatewayOp.HEARTBEAT:
            await self.send_heartbeat()
            yield ("gateway_heartbeat_received",)

        elif opcode == GatewayOp.HEARTBEAT_ACK:
            self.heartbeat_stats.heartbeat_acks += 1
            self.heartbeat_stats.last_ack_time = time.monotonic()
            yield ("gateway_heartbeat_ack",)

        elif opcode == GatewayOp.INVALIDATE_SESSION:
            # the data sent is if we should resume
            # if it's non-existent, we assume it's False.
            # this has to look at the decoded ``d`` field, not at the raw payload text
            should_resume = event_data is True
            if should_resume:
                self.logger.debug("Sending RESUME again")
                await self.send_resume()
            else:
                self.logger.warning("Received INVALIDATE_SESSION with d False, re-identifying.")
                self.gw_state.sequence = 0
                self.gw_state.session_id = None
                await self.send_identify()

            yield ("gateway_invalidate_session", should_resume,)

        elif opcode == GatewayOp.DISPATCH:
            event = decoded.get("t")
            if not event:
                return

            if event == "READY":
                # hijack the session id
                self.gw_state.session_id = event_data["session_id"]

            self._dispatches_handled[event] += 1
            yield ("gateway_dispatch_received", event, event_data,)

        elif opcode == GatewayOp.RECONNECT:
            self.logger.info("Being asked to reconnect...")
            # keep the session ID so that the new connection can try to RESUME
            await self.close(code=1000, reason="Server asked to reconnect",
                             reconnect=True, clear_session_id=False)

        else:
            # GatewayOp(opcode) raises ValueError for opcodes that aren't in the enum
            try:
                self.logger.warning("Unhandled opcode: {} ({})".format(opcode, GatewayOp(opcode)))
            except ValueError:
                self.logger.warning("Unknown opcode: {}".format(opcode))
@asynccontextmanager
@safe_generator
async def open_websocket(token: str, url: str, *,
                         shard_id: int = 0, shard_count: int = 1) \
        -> AsyncContextManager[GatewayHandler]:
    """
    Opens a new connection to Discord.

    This is an async context manager; for example, using Trio for nursery management:

    .. code-block:: python3

        async with trio.open_nursery() as nursery:
            async with open_websocket(token, url) as gateway:
                # example for changing presence
                nursery.start_soon(some_gw_handler, gateway)

                async for event in gateway.events():
                    # handle events, etc.
                    ...

    The gateway gets its own task group, which lives for the duration of the ``async with``
    block. On exit, heartbeating is stopped and the websocket is closed.

    :param token: The token to connect to Discord with.
    :param url: The gateway URL to connect with. The gateway version and encoding query string
        is appended to it.
    :param shard_id: The shard ID to connect with. Defaults to 0.
    :param shard_count: The number of shards to boot with.
    :return: An async context manager that yields a :class:`.GatewayHandler`.
    """
    params = f"/?v={GatewayHandler.GATEWAY_VERSION}&encoding=json"
    url = url + params
    state = _GatewayState(token=token, gateway_url=url, shard_id=shard_id, shard_count=shard_count)
    gw = GatewayHandler(gw_state=state)
    logger = logging.getLogger(f"curious.gateway:shard-{shard_id}")

    async with multio.asynclib.task_manager() as tg:
        gw.task_group = tg
        try:
            logger.info("Opening gateway connection to %s", url)
            await gw.open()
            yield gw
        finally:
            # make sure we don't die on closing the task group
            await gw._stop_heartbeating.set()
            # if open() itself failed there is no websocket to close; calling close() then
            # would replace the real error with an AttributeError
            if gw.websocket is not None:
                await gw.close(code=1000, reason="Closing bot")