Source code for curious.commands.ratelimit
"""
Utilities for ratelimiting a command.
"""
import time
from typing import Any, Callable, List, Optional, Tuple

from curious.commands import Context
from curious.commands.exc import CommandRateLimited
class BucketNamer:
    """
    Namespace of ready-made bucket naming functions.

    Each function takes the command context and returns the string used as the bucket name.
    This class only groups those functions together; it cannot be instantiated.
    """

    def __new__(cls):
        # Instantiation is blocked on purpose: only the static functions are meant to be used.
        raise NotImplementedError("Don't make an instance of this class")

    @staticmethod
    def GUILD(ctx: Context) -> str:
        """
        Names the bucket after the ID of the context's guild.

        :param ctx: The context to read ``guild.id`` from.
        :return: The guild ID, as a string.
        """
        guild_id = ctx.guild.id
        return str(guild_id)

    @staticmethod
    def CHANNEL(ctx: Context) -> str:
        """
        Names the bucket after the ID of the context's channel.

        :param ctx: The context to read ``channel.id`` from.
        :return: The channel ID, as a string.
        """
        channel_id = ctx.channel.id
        return str(channel_id)

    @staticmethod
    def AUTHOR(ctx: Context) -> str:
        """
        Names the bucket after the ID of the context's author.

        :param ctx: The context to read ``author.id`` from.
        :return: The author ID, as a string.
        """
        author_id = ctx.author.id
        return str(author_id)

    @staticmethod
    def GLOBAL(ctx: Context) -> str:
        """
        Names a single bucket shared by every invocation.

        :param ctx: Unused; accepted so the signature matches the other namers.
        :return: The constant string ``"GLOBAL"``.
        """
        return "GLOBAL"
class CommandRateLimit:
    """
    Describes one ratelimit attached to a command: how many uses are allowed, over how long,
    and how invocations are grouped into buckets.
    """

    def __init__(self, *, limit: int, time: float,
                 bucket_namer: Callable[[Context], str] = BucketNamer.AUTHOR):
        """
        :param limit: The number of times a command can be called in the specified limit.
        :param time: The time (in seconds) this ratelimit lasts.
        :param bucket_namer: A callable that takes the context and returns the bucket name. \
            Defaults to one bucket per author.
        """
        #: The command function being used. Starts as ``None``; nothing in this class sets it.
        self.command = None

        self.bucket_namer = bucket_namer
        self.time = time
        self.limit = limit

    def get_full_bucket_key(self, ctx: Context) -> Tuple[str, str]:
        """
        Builds the key that identifies this ratelimit's bucket for a given context.

        :param ctx: The context passed to the bucket namer.
        :return: A ``(command name, bucket name)`` tuple; the command name is read from \
            ``self.command.cmd_name``.
        """
        command_name = self.command.cmd_name
        bucket_name = self.bucket_namer(ctx)
        return command_name, bucket_name
class RateLimiter(object):
    """
    Represents a ratelimiter. This ensures that commands meet the ratelimit before being ran.

    Buckets are kept in an in-memory dict. Each bucket is a ``(uses, expiration)`` tuple, where
    ``expiration`` is a :func:`time.monotonic` timestamp marking the end of the current window.
    """

    def __init__(self):
        #: Maps a bucket key to its ``(current_uses, expiration)`` tuple.
        self._ratelimit_buckets = {}

    async def update_bucket(self, key: Any, current_uses: int, expiration: float):
        """
        Updates a ratelimit bucket, creating it if it does not exist.

        :param key: The ratelimit key to use.
        :param current_uses: The current uses for the key.
        :param expiration: When the ratelimit expires.
        """
        self._ratelimit_buckets[key] = (current_uses, expiration)

    async def get_bucket(self, key: Any) -> Optional[Tuple[int, float]]:
        """
        Gets the ratelimit bucket for the specified key.

        :param key: The key to use.
        :return: A two-item tuple of (uses, expiration), or None if no bucket was found.
        """
        return self._ratelimit_buckets.get(key)

    async def ensure_ratelimits(self, ctx: Context, cmd):
        """
        Ensures the ratelimits for a command, counting this invocation against each of them.

        Every ratelimit in ``cmd.cmd_ratelimits`` is checked in order. Ratelimits checked before
        one that fails have already had this invocation counted against them.

        :param ctx: The context used to work out each bucket key.
        :param cmd: The command being invoked; its ``cmd_ratelimits`` are enforced.
        :raises CommandRateLimited: If a bucket has used up its limit and its window has not \
            expired yet.
        """
        ratelimits: List[CommandRateLimit] = cmd.cmd_ratelimits

        for limit in ratelimits:
            bucket_key = limit.get_full_bucket_key(ctx)
            bucket = await self.get_bucket(key=bucket_key)
            now = time.monotonic()

            # No bucket yet, or the previous window is over: start a fresh window with this
            # invocation as its first use. Expiry is checked regardless of the use count so that
            # leftover uses from an old window never count against a new one.
            if not bucket or now > bucket[1]:
                await self.update_bucket(bucket_key, 1, now + limit.time)
                continue

            uses, expiration = bucket

            # ``>=`` rather than ``==`` so that a bucket whose count is already over the limit
            # is still treated as exhausted.
            if uses >= limit.limit:
                raise CommandRateLimited(ctx, cmd, limit, bucket)

            # Still inside the window with uses to spare: count this one, keep the same expiry.
            await self.update_bucket(bucket_key, uses + 1, expiration)