diff --git a/dozer/Components/CustomJoinLeaveMessages.py b/dozer/Components/CustomJoinLeaveMessages.py deleted file mode 100644 index 2298c375..00000000 --- a/dozer/Components/CustomJoinLeaveMessages.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Holder for the custom join/leave messages database class and the associated methods""" -import discord -from loguru import logger - -from dozer import db - - -async def send_log(member): - """Sends the message for when a user joins or leave a guild""" - config = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) - if len(config): - channel = member.guild.get_channel(config[0].channel_id) - if channel: - embed = discord.Embed(color=0x00FF00) - embed.set_author(name='Member Joined', icon_url=member.display_avatar.replace(format='png', size=32)) - embed.description = format_join_leave(config[0].join_message, member) - embed.set_footer(text="{} | {} members".format(member.guild.name, member.guild.member_count)) - try: - await channel.send(content=member.mention if config[0].ping else None, embed=embed) - except discord.Forbidden: - logger.warning( - f"Guild {member.guild}({member.guild.id}) has invalid permissions for join/leave logs") - - -def format_join_leave(template: str, member: discord.Member): - """Formats join leave message templates - {guild} = guild name - {user} = user's name plus discriminator ex. SnowPlow#5196 - {user_name} = user's name without discriminator - {user_mention} = user's mention - {user_id} = user's ID - """ - if template: - return template.format(guild=member.guild, user=str(member), user_name=member.name, - user_mention=member.mention, user_id=member.id) - else: - return "{user_mention}\n{user} ({user_id})".format(user=str(member), user_mention=member.mention, - user_id=member.id) - - -class CustomJoinLeaveMessages(db.DatabaseTable): - """Holds custom join leave messages""" - __tablename__ = 'memberlogconfig' - __uniques__ = 'guild_id' - - @classmethod - async def initial_create(cls): - """Create the table in the database""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - CREATE TABLE {cls.__tablename__} ( - guild_id bigint PRIMARY KEY NOT NULL, - memberlog_channel bigint NOT NULL, - name varchar NOT NULL, - send_on_verify boolean - )""") - - def __init__(self, guild_id, channel_id=None, ping=None, join_message=None, leave_message=None, send_on_verify=False): - super().__init__() - self.guild_id = guild_id - self.channel_id = channel_id - self.ping = ping - self.join_message = join_message - self.leave_message = leave_message - self.send_on_verify = send_on_verify - - @classmethod - async def get_by(cls, **kwargs): - results = await super().get_by(**kwargs) - result_list = [] - for result in results: - obj = CustomJoinLeaveMessages(guild_id=result.get("guild_id"), channel_id=result.get("channel_id"), - ping=result.get("ping"), - join_message=result.get("join_message"), - leave_message=result.get("leave_message"), - send_on_verify=result.get("send_on_verify")) - result_list.append(obj) - return result_list - - async def version_1(self): - """DB migration v1""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - alter table memberlogconfig rename column memberlog_channel to channel_id; - alter table memberlogconfig alter column channel_id drop not null; - alter table {self.__tablename__} drop column IF EXISTS name; - alter table {self.__tablename__} - add IF NOT EXISTS ping boolean default False; - alter table {self.__tablename__} - add IF NOT EXISTS join_message text default null; - alter table {self.__tablename__} - add IF NOT EXISTS leave_message text default null; - """) - - async def version_2(self): - async with db.Pool.acquire() as conn: - await conn.execute(f"alter table {self.__tablename__} " - f"add if not exists send_on_verify boolean default null;") - - __versions__ = [version_1, version_2] diff --git a/dozer/Components/TeamNumbers.py b/dozer/Components/TeamNumbers.py deleted file mode 100644 index 8b694382..00000000 --- a/dozer/Components/TeamNumbers.py +++ /dev/null @@ -1,67 +0,0 @@ -from dozer import db - - -class TeamNumbers(db.DatabaseTable): - """Database operations for tracking team associations.""" - __tablename__ = 'team_numbers' - __uniques__ = 'user_id, team_number, team_type' - - @classmethod - async def initial_create(cls): - """Create the table in the database""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - CREATE TABLE {cls.__tablename__} ( - user_id bigint NOT NULL, - team_number bigint NOT NULL, - team_type VARCHAR NOT NULL, - PRIMARY KEY (user_id, team_number, team_type) - )""") - - def __init__(self, user_id, team_number, team_type): - super().__init__() - self.user_id = user_id - self.team_number = team_number - self.team_type = team_type - - async def update_or_add(self): - """Assign the attribute to this object, then call this method to either insert the object if it doesn't exist in - the DB or update it if it does exist. It will update every column not specified in __uniques__.""" - # This is its own functions because all columns must be unique, which breaks the syntax of the other one - keys = [] - values = [] - for var, value in self.__dict__.items(): - # Done so that the two are guaranteed to be in the same order, which isn't true of keys() and values() - if value is not None: - keys.append(var) - values.append(value) - async with db.Pool.acquire() as conn: - statement = f""" - INSERT INTO {self.__tablename__} ({", ".join(keys)}) - VALUES({','.join(f'${i+1}' for i in range(len(values)))}) - """ - await conn.execute(statement, *values) - - @classmethod - async def get_by(cls, **kwargs): - results = await super().get_by(**kwargs) - result_list = [] - for result in results: - obj = TeamNumbers(user_id=result.get("user_id"), - team_number=result.get("team_number"), - team_type=result.get("team_type")) - result_list.append(obj) - return result_list - - # noinspection SqlResolve - @classmethod - async def top10(cls, user_ids): - """Returns the top 10 team entries""" - query = f"""SELECT team_type, team_number, count(*) - FROM {cls.__tablename__} - WHERE user_id = ANY($1) --first param: list of user IDs - GROUP BY team_type, team_number - ORDER BY count DESC, team_type, team_number - LIMIT 10""" - async with db.Pool.acquire() as conn: - return await conn.fetch(query, user_ids) \ No newline at end of file diff --git a/dozer/bot.py b/dozer/bot.py index cc72a00b..6eff3faf 100755 --- a/dozer/bot.py +++ b/dozer/bot.py @@ -1,14 +1,14 @@ """Bot object for Dozer""" - - import os import re import sys import traceback -from typing import Pattern +from typing import Pattern, Optional, Union, Generator, Dict, Any import discord +from discord import Status, Message from discord.ext import commands +from discord.ext.commands import Cooldown, CommandError, BucketType from loguru import logger from sentry_sdk import capture_exception @@ -18,6 +18,7 @@ from .context import DozerContext from .db import db_init, db_migrate + if discord.version_info.major < 2: logger.error("Your installed discord.py version is too low " "%d.%d.%d, please upgrade to at least 2.0.0", @@ -36,14 +37,14 @@ class InvalidContext(commands.CheckFailure): class Dozer(commands.Bot): """Botty things that are critical to Dozer working""" - _global_cooldown = commands.Cooldown(1, 1) # One command per second per user + _global_cooldown: Cooldown = Cooldown(1, 1) # One command per second per user - def __init__(self, config: dict, *args, **kwargs): + def __init__(self, config: Dict[str, Union[Dict[str, str], str]], *args, **kwargs): self.wavelink = None - self.dynamic_prefix = _utils.PrefixHandler(config['prefix']) + self.dynamic_prefix: _utils.PrefixHandler = _utils.PrefixHandler(str(config['prefix'])) super().__init__(command_prefix=self.dynamic_prefix.handler, *args, **kwargs) - self.config = config - self._restarting = False + self.config: Dict[str, Any] = config + self._restarting: bool = False self.check(self.global_checks) async def setup_hook(self) -> None: @@ -67,23 +68,24 @@ async def on_ready(self): perms |= cmd.required_permissions.value else: logger.warning(f"Command {cmd} not subclass of Dozer type.") - logger.debug('Bot Invite: {}'.format(utils.oauth_url(self.user.id, discord.Permissions(perms)))) + logger.debug('Bot Invite: {}'.format(utils.oauth_url(str(self.user.id), discord.Permissions(perms)))) if self.config['is_backup']: - status = discord.Status.dnd + status: Status = Status.dnd else: - status = discord.Status.online - activity = discord.Game(name=f"@{self.user.name} or '{self.config['prefix']}' in {len(self.guilds)} guilds") + status: Status = Status.online + activity: discord.Game = discord.Game(name=f"@{self.user.name} or '{self.config['prefix']}' in {len(self.guilds)} guilds") try: await self.change_presence(activity=activity, status=status) except TypeError: logger.warning("You are running an older version of the discord.py rewrite (with breaking changes)! " "To upgrade, run `pip install -r requirements.txt --upgrade`") - async def get_context(self, message: discord.Message, *, cls=DozerContext): # pylint: disable=arguments-differ + async def get_context(self, message: Message, *, cls=DozerContext) -> DozerContext: # pylint: disable=arguments-differ ctx = await super().get_context(message, cls=cls) + ctx.prefix = self.dynamic_prefix.handler(self, message) return ctx - async def on_command_error(self, context: DozerContext, exception): # pylint: disable=arguments-differ + async def on_command_error(self, context: DozerContext, exception: CommandError): # pylint: disable=arguments-differ if isinstance(exception, commands.NoPrivateMessage): await context.send('{}, This command cannot be used in DMs.'.format(context.author.mention)) elif isinstance(exception, commands.UserInputError): @@ -105,11 +107,14 @@ async def on_command_error(self, context: DozerContext, exception): # pylint: d '{}, That command is on cooldown! Try again in {:.2f}s!'.format(context.author.mention, exception.retry_after)) elif isinstance(exception, commands.MaxConcurrencyReached): - types = {discord.ext.commands.BucketType.default: "`Global`", - discord.ext.commands.BucketType.guild: "`Guild`", - discord.ext.commands.BucketType.channel: "`Channel`", - discord.ext.commands.BucketType.category: "`Category`", - discord.ext.commands.BucketType.member: "`Member`", discord.ext.commands.BucketType.user: "`User`"} + types: Dict[BucketType, str] = { + BucketType.default: "`Global`", + BucketType.guild: "`Guild`", + BucketType.channel: "`Channel`", + BucketType.category: "`Category`", + BucketType.member: "`Member`", + BucketType.user: "`User`" + } await context.send( '{}, That command has exceeded the max {} concurrency limit of `{}` instance! Please try again later.'.format( context.author.mention, types[exception.per], exception.number)) @@ -128,14 +133,14 @@ async def on_command_error(self, context: DozerContext, exception): # pylint: d context.channel.recipient, context.message.content) logger.error(''.join(traceback.format_exception(type(exception), exception, exception.__traceback__))) - async def on_error(self, event_method, *args, **kwargs): + async def on_error(self, event_method: str, *args, **kwargs): """Don't ignore the error, causing Sentry to capture it.""" print('Ignoring exception in {}'.format(event_method), file=sys.stderr) traceback.print_exc() capture_exception() @staticmethod - def format_error(ctx: DozerContext, err: Exception, *, word_re: Pattern = re.compile('[A-Z][a-z]+')): + def format_error(ctx: DozerContext, err: Exception, *, word_re: Pattern = re.compile('[A-Z][a-z]+')) -> str: """Turns an exception into a user-friendly (or -friendlier, at least) error message.""" type_words = word_re.findall(type(err).__name__) type_msg = ' '.join(map(str.lower, type_words)) @@ -145,7 +150,7 @@ def format_error(ctx: DozerContext, err: Exception, *, word_re: Pattern = re.com else: return type_msg - def global_checks(self, ctx: DozerContext): + def global_checks(self, ctx: DozerContext) -> bool: """Checks that should be executed before passed to the command""" if ctx.author.bot: raise InvalidContext('Bots cannot run commands!') @@ -154,6 +159,15 @@ def global_checks(self, ctx: DozerContext): raise InvalidContext('Global rate-limit exceeded!') return True + def get_command(self, name: str) -> Optional[Union[_utils.Command, _utils.Group]]: # pylint: disable=arguments-differ + return super().get_command(name) + + def walk_commands(self) -> Generator[Union[_utils.Command, _utils.Group], None, None]: + return super().walk_commands() + + def get_cog(self, name: str, /) -> Optional[_utils.Cog]: + return super().get_cog(name) + def run(self, *args, **kwargs): token = self.config['discord_token'] del self.config['discord_token'] # Prevent token dumping diff --git a/dozer/cogs/_utils.py b/dozer/cogs/_utils.py index a3ce3e16..1f34b901 100755 --- a/dozer/cogs/_utils.py +++ b/dozer/cogs/_utils.py @@ -1,12 +1,12 @@ """Utilities for Dozer.""" import asyncio import inspect -import typing from collections.abc import Mapping -from typing import Dict, Union +from typing import Dict, Union, Optional, Any, Coroutine, List, Generator, Iterable, AsyncGenerator +from typing import TYPE_CHECKING import discord -from discord import app_commands +from discord import app_commands, Embed, Permissions from discord.ext import commands from discord.ext.commands import HybridCommand from discord.ext.commands.core import MISSING @@ -15,8 +15,11 @@ from dozer import db from dozer.context import DozerContext +if TYPE_CHECKING: + from dozer import Dozer + __all__ = ['bot_has_permissions', 'command', 'group', 'Cog', 'Reactor', 'Paginator', 'paginate', 'chunk', 'dev_check', - 'DynamicPrefixEntry'] + 'DynamicPrefixEntry', 'CommandMixin'] @@ -26,12 +29,12 @@ class CommandMixin: # Keyword-arg dictionary passed to __init__ when copying/updating commands when Cog instances are created # inherited from discord.ext.command.Command - __original_kwargs__: typing.Dict[str, typing.Any] + __original_kwargs__: Dict[str, Any] _required_permissions = None - def __init__(self, func, **kwargs): + def __init__(self, func: Union["Command", "Group"], **kwargs): super().__init__(func, **kwargs) - self.example_usage = kwargs.pop('example_usage', '') + self.example_usage: Optional[str] = kwargs.pop('example_usage', '') if hasattr(func, '__required_permissions__'): # This doesn't need to go into __original_kwargs__ because it'll be read from func each time self._required_permissions = func.__required_permissions__ @@ -40,7 +43,7 @@ def __init__(self, func, **kwargs): def required_permissions(self): """Required permissions handler""" if self._required_permissions is None: - self._required_permissions = discord.Permissions() + self._required_permissions = Permissions() return self._required_permissions @property @@ -76,15 +79,15 @@ class Group(CommandMixin, commands.HybridGroup): def command( self, name: Union[str, app_commands.locale_str] = MISSING, - *args: typing.Any, + *args: Any, with_app_command: bool = True, - **kwargs: typing.Any, + **kwargs: Any, ): """Initiates a command""" def decorator(func): kwargs.setdefault('parent', self) - result = command(name=name, *args, with_app_command=with_app_command, **kwargs)(func) + result = command(name=name, with_app_command=with_app_command, **kwargs)(func) self.add_command(result) return result @@ -93,15 +96,15 @@ def decorator(func): def group( self, name: Union[str, app_commands.locale_str] = MISSING, - *args: typing.Any, + *args: Any, with_app_command: bool = True, - **kwargs: typing.Any, + **kwargs: Any, ): """Initiates a command group""" def decorator(func): kwargs.setdefault('parent', self) - result = group(name=name, *args, with_app_command=with_app_command, **kwargs)(func) + result = group(name=name, with_app_command=with_app_command, **kwargs)(func) self.add_command(result) return result @@ -111,15 +114,18 @@ def decorator(func): class Cog(commands.Cog): """Initiates cogs.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__() - self.bot = bot + self.bot: "Dozer" = bot + + def walk_commands(self) -> Generator[Union[Group, Command], None, None]: + return super().walk_commands() def dev_check(): """Function decorator to check that the calling user is a developer""" - async def predicate(ctx: DozerContext): + async def predicate(ctx: DozerContext) -> bool: if ctx.author.id not in ctx.bot.config['developers']: raise commands.NotOwner('you are not a developer!') return True @@ -163,8 +169,10 @@ def __init__(self, ctx: DozerContext, initial_reactions, *, auto_remove: bool = self._remove_reactions = auto_remove and ctx.channel.permissions_for( self.me).manage_messages # Check for required permissions self.timeout = timeout - self._action = None + self._action: Optional[Coroutine] = None self.message = None + self.pages: Dict[Union[int, str], Embed] + self.page: Embed async def __aiter__(self): self.message = await self.dest.send(embed=self.pages[self.page]) @@ -231,21 +239,22 @@ class Paginator(Reactor): '\N{BLACK SQUARE FOR STOP}' # :stop_button: ) - def __init__(self, ctx: DozerContext, initial_reactions, pages, *, start: int = 0, auto_remove: bool = True, - timeout: int = 60): + def __init__(self, ctx: DozerContext, initial_reactions: Iterable[discord.Reaction], pages: List[Union[Embed, Dict[str, Embed]]], *, + start: Union[int, str] = 0, auto_remove: bool = True, timeout: int = 60): all_reactions = list(initial_reactions) - ind = all_reactions.index(Ellipsis) + ind: int = all_reactions.index(Ellipsis) all_reactions[ind:ind + 1] = self.pagination_reactions super().__init__(ctx, all_reactions, auto_remove=auto_remove, timeout=timeout) if pages and isinstance(pages[-1], Mapping): - named_pages = pages.pop() - self.pages = dict(enumerate(pages), **named_pages) + named_pages: Dict[str, Embed] = pages.pop() + # The following code assembles the list of Embeds into a dict with the indexes as keys, and with the named pages. + self.pages = {**{k: v for v, k in enumerate(pages)}, **named_pages} else: self.pages = pages - self.len_pages = len(pages) - self.page = start - self.message = None - self.reactor = None + self.len_pages: int = len(pages) + self.page: Union[int, str] = start + self.message: Optional[discord.Message] = None + self.reactor: Optional[AsyncGenerator] = None async def __aiter__(self): self.reactor = super().__aiter__() @@ -300,13 +309,13 @@ async def paginate(ctx: DozerContext, pages, *, start: int = 0, auto_remove: boo pass # The normal pagination reactions are handled - just drop anything else -def chunk(iterable, size: int): +def chunk(iterable, size: int) -> Iterable[Iterable]: """ Break an iterable into chunks of a fixed size. Returns an iterable of iterables. Almost-inverse of itertools.chain.from_iterable - passing the output of this into that function will reconstruct the original iterable. If the last chunk is not the full length, it will be returned but not padded. """ - contents = list(iterable) + contents: List = list(iterable) for i in range(0, len(contents), size): yield contents[i:i + size] @@ -316,8 +325,8 @@ def bot_has_permissions(**required): def predicate(ctx: DozerContext): """Function to tell the bot if it has the right permissions""" - given = ctx.channel.permissions_for((ctx.guild or ctx.channel).me) - missing = [name for name, value in required.items() if getattr(given, name) != value] + given: Permissions = ctx.channel.permissions_for((ctx.guild or ctx.channel).me) + missing: List[str] = [name for name, value in required.items() if getattr(given, name) != value] if missing: raise commands.BotMissingPermissions(missing) @@ -334,7 +343,7 @@ def decorator(func): func.__commands_checks__.append(predicate) else: func.__commands_checks__ = [predicate] - func.__required_permissions__ = discord.Permissions() + func.__required_permissions__ = Permissions() func.__required_permissions__.update(**required) return func @@ -350,15 +359,16 @@ def __init__(self, default_prefix: str): def handler(self, bot, message: discord.Message): """Process the dynamic prefix for each message""" - dynamic = self.prefix_cache.get(message.guild.id) if message.guild else self.default_prefix + dynamic = self.prefix_cache.get(message.guild.id) if message.guild else None # <@!> is a nickname mention which discord.py doesn't make by default - return [f"<@!{bot.user.id}> ", bot.user.mention, dynamic if dynamic else self.default_prefix] + return [f"<@!{bot.user.id}> ", f"<@!{bot.user.id}>", bot.user.mention, bot.user.mention + " ", + dynamic.prefix if dynamic else self.default_prefix] async def refresh(self): """Refreshes the prefix cache""" prefixes = await DynamicPrefixEntry.get_by() # no filters, get all for prefix in prefixes: - self.prefix_cache[prefix.guild_id] = prefix.prefix + self.prefix_cache[prefix.guild_id] = prefix logger.info(f"{len(prefixes)} prefixes loaded from database") @@ -384,7 +394,7 @@ def __init__(self, guild_id: int, prefix: str): self.prefix = prefix @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["DynamicPrefixEntry"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/actionlogs.py b/dozer/cogs/actionlogs.py index 8a2e041a..ba640cb5 100644 --- a/dozer/cogs/actionlogs.py +++ b/dozer/cogs/actionlogs.py @@ -1,11 +1,12 @@ """Provides guild logging functions for Dozer.""" import asyncio -import datetime import math import time +import datetime +from typing import TYPE_CHECKING, List, Set, Optional import discord -from discord.ext import commands +from discord import Guild, AuditLogAction, Member, Message, Embed, AuditLogEntry from discord.ext.commands import has_permissions, BadArgument from discord.utils import escape_markdown from loguru import logger @@ -13,12 +14,13 @@ from dozer.context import DozerContext from ._utils import * from .general import blurple -from .moderation import GuildNewMember from .. import db -from ..Components.CustomJoinLeaveMessages import CustomJoinLeaveMessages, format_join_leave, send_log +if TYPE_CHECKING: + from dozer import Dozer -async def embed_paginatorinator(content_name, embed, text): + +async def embed_paginatorinator(content_name: str, embed: Embed, text: str): """Chunks up embed sections to fit within 1024 characters""" required_chunks = math.ceil(len(text) / 1024) c_embed = embed.copy() @@ -32,13 +34,13 @@ async def embed_paginatorinator(content_name, embed, text): class Actionlog(Cog): """A cog to handle guild events tasks""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - self.edit_delete_config = db.ConfigCache(GuildMessageLog) + self.edit_delete_config: db.ConfigCache = db.ConfigCache(GuildMessageLog) self.bulk_delete_buffer = {} @staticmethod - async def check_audit(guild, event_type, event_time=None): + async def check_audit(guild: Guild, event_type: AuditLogAction, event_time: Optional[datetime] = None): """Method for checking the audit log for events""" try: async for entry in guild.audit_logs(limit=1, after=event_time, @@ -48,10 +50,10 @@ async def check_audit(guild, event_type, event_time=None): return None @Cog.listener('on_member_join') - async def on_member_join(self, member): + async def on_member_join(self, member: Member): """Logs that a member joined, with optional custom message""" - join_leave_config = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) - new_members_config = await GuildNewMember.get_by(guild_id=member.guild.id) + join_leave_config: List[CustomJoinLeaveMessages] = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) + new_members_config: List[GuildNewMember] = await GuildNewMember.get_by(guild_id=member.guild.id) if len(new_members_config) == 0 and len(join_leave_config) == 0: await send_log(member) else: @@ -63,14 +65,14 @@ async def on_member_join(self, member): await send_log(member) @Cog.listener('on_member_remove') - async def on_member_remove(self, member): + async def on_member_remove(self, member: Member): """Logs that a member left.""" config = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) if len(config): channel = member.guild.get_channel(config[0].channel_id) if channel: - embed = discord.Embed(color=0xFF0000) - embed.set_author(name='Member Left', icon_url=member.display_avatar.replace(format='png', size=32)) + embed: Embed = Embed(color=0xFF0000) + embed.set_author(name='Member Left', icon_url=member.avatar.replace(format='png', size=32)) embed.description = format_join_leave(config[0].leave_message, member) embed.set_footer(text="{} | {} members".format(member.guild.name, member.guild.member_count)) try: @@ -80,18 +82,18 @@ async def on_member_remove(self, member): f"Guild {member.guild}({member.guild.id}) has invalid permissions for join/leave logs") @Cog.listener("on_member_update") - async def on_member_update(self, before, after): + async def on_member_update(self, before: Member, after: Member): """Called whenever a member gets updated""" if before.nick != after.nick: await self.on_nickname_change(before, after) - async def on_nickname_change(self, before, after): + async def on_nickname_change(self, before: Member, after: Member): """The log handler for when a user changes their nicknames""" - audit = await self.check_audit(after.guild, discord.AuditLogAction.member_update) + audit: Optional[AuditLogEntry] = await self.check_audit(after.guild, discord.AuditLogAction.member_update) - embed = discord.Embed(title="Nickname Changed", - color=0x00FFFF) - embed.set_author(name=after, icon_url=after.display_avatar) + embed: Embed = Embed(title="Nickname Changed", + color=0x00FFFF) + embed.set_author(name=after, icon_url=after.avatar) embed.add_field(name="Before", value=before.nick, inline=False) embed.add_field(name="After", value=after.nick, inline=False) @@ -108,9 +110,10 @@ async def on_nickname_change(self, before, after): await channel.send(embed=embed) await self.check_nickname_lock(before, after) - async def check_nickname_lock(self, before, after): + @staticmethod + async def check_nickname_lock(before: Member, after: Member): """The handler for checking if a member is allowed to change their nickname""" - results = await NicknameLock.get_by(guild_id=after.guild.id, member_id=after.id) + results: List[NicknameLock] = await NicknameLock.get_by(guild_id=after.guild.id, member_id=after.id) if results: while time.time() <= results[0].timeout: await asyncio.sleep(10) # prevents nickname update spam @@ -126,12 +129,12 @@ async def check_nickname_lock(self, before, after): f"your nickname has been reverted to **{results[0].locked_name}**") @Cog.listener() - async def on_raw_bulk_message_delete(self, payload): + async def on_raw_bulk_message_delete(self, payload: discord.RawBulkMessageDeleteEvent): """Log bulk message deletes""" - guild = self.bot.get_guild(int(payload.guild_id)) - message_channel = self.bot.get_channel(int(payload.channel_id)) - message_ids = payload.message_ids - cached_messages = payload.cached_messages + guild: Guild = self.bot.get_guild(int(payload.guild_id)) + message_channel: discord.TextChannel = self.bot.get_channel(int(payload.channel_id)) + message_ids: Set[int] = payload.message_ids + cached_messages: List[Message] = payload.cached_messages message_log_channel = await self.edit_delete_config.query_one(guild_id=guild.id) if message_log_channel is not None: @@ -146,7 +149,7 @@ async def on_raw_bulk_message_delete(self, payload): self.bulk_delete_buffer[message_channel.id]["msg_ids"] += message_ids self.bulk_delete_buffer[message_channel.id]["msgs"] += cached_messages header_message = self.bulk_delete_buffer[message_channel.id]["header_message"] - header_embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000) + header_embed = Embed(title="Bulk Message Delete", color=0xFF0000) deleted = self.bulk_delete_buffer[message_channel.id]["msg_ids"] cached = self.bulk_delete_buffer[message_channel.id]["msgs"] header_embed.description = f"{len(deleted)} Messages Deleted In: {message_channel.mention}\n" \ @@ -154,7 +157,7 @@ async def on_raw_bulk_message_delete(self, payload): f"Messages logged: *Currently Purging*" await header_message.edit(embed=header_embed) else: - header_embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000) + header_embed = Embed(title="Bulk Message Delete", color=0xFF0000) header_embed.description = f"{len(message_ids)} Messages Deleted In: {message_channel.mention}\n" \ f"Messages cached: {len(cached_messages)}/{len(message_ids)} \n" \ f"Messages logged: *Currently Purging*" @@ -179,7 +182,7 @@ async def bulk_delete_log(self, message_channel): header_message = buffer_entry["header_message"] message_count = 0 - header_embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000) + header_embed = Embed(title="Bulk Message Delete", color=0xFF0000) header_embed.description = f"{len(message_ids)} Messages Deleted In: {message_channel.mention}\n" \ f"Messages cached: {len(cached_messages)}/{len(message_ids)} \n" \ f"Messages logged: *Currently Logging*" @@ -188,8 +191,8 @@ async def bulk_delete_log(self, message_channel): current_page = 1 page_character_count = 0 page_message_count = 0 - embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000, - timestamp=datetime.datetime.now(tz=datetime.timezone.utc)) + embed = Embed(title="Bulk Message Delete", color=0xFF0000, + timestamp=datetime.datetime.now(tz=datetime.timezone.utc)) for message in sorted(cached_messages, key=lambda msg: msg.created_at): page_character_count += len(message.content[0:512]) + 3 @@ -201,8 +204,8 @@ async def bulk_delete_log(self, message_channel): await channel.send(embed=embed) except discord.HTTPException as e: logger.debug(f"Bulk delete embed failed to send: {e}") - embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000, - timestamp=datetime.datetime.now(tz=datetime.timezone.utc)) + embed: Embed = discord.Embed(title="Bulk Message Delete", color=0xFF0000, + timestamp=datetime.datetime.now(tz=datetime.timezone.utc)) page_character_count = len(message.content) message_count += page_message_count page_message_count = 0 @@ -236,9 +239,9 @@ async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent): message_channel = self.bot.get_channel(int(payload.channel_id)) message_id = int(payload.message_id) message_created = discord.Object(message_id).created_at - embed = discord.Embed(title="Message Deleted", - description=f"Message Deleted In: {message_channel.mention}", - color=0xFF00F0, timestamp=message_created) + embed = Embed(title="Message Deleted", + description=f"Message Deleted In: {message_channel.mention}", + color=0xFF00F0, timestamp=message_created) embed.add_field(name="Message", value="N/A", inline=False) embed.set_footer(text=f"Message ID: {message_channel.id} - {message_id}\nSent at ") message_log_channel = await self.edit_delete_config.query_one(guild_id=guild.id) @@ -248,15 +251,15 @@ async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent): await channel.send(embed=embed) @Cog.listener('on_message_delete') - async def on_message_delete(self, message: discord.Message): + async def on_message_delete(self, message: Message): """When a message is deleted, log it.""" if message.author == self.bot.user: return - audit = await self.check_audit(message.guild, discord.AuditLogAction.message_delete, message.created_at) - embed = discord.Embed(title="Message Deleted", - description=f"Message Deleted In: {message.channel.mention}\nSent by: {message.author.mention}", - color=0xFF0000, timestamp=message.created_at) - embed.set_author(name=message.author, icon_url=message.author.display_avatar) + audit: Optional[AuditLogEntry] = await self.check_audit(message.guild, discord.AuditLogAction.message_delete, message.created_at) + embed: Embed = Embed(title="Message Deleted", + description=f"Message Deleted In: {message.channel.mention}\nSent by: {message.author.mention}", + color=0xFF0000, timestamp=message.created_at) + embed.set_author(name=message.author, icon_url=message.author.avatar) if audit: if audit.target == message.author: audit_member = await message.guild.fetch_member(audit.user.id) @@ -280,26 +283,26 @@ async def on_raw_message_edit(self, payload: discord.RawMessageUpdateEvent): if payload.cached_message: return mchannel = self.bot.get_channel(int(payload.channel_id)) - guild = mchannel.guild + guild: Guild = mchannel.guild try: - content = payload.data['content'] + content: Optional[str] = payload.data['content'] except KeyError: content = None author = payload.data.get("author") if not author: return - guild_id = guild.id - channel_id = payload.channel_id - user_id = author['id'] - if (self.bot.get_user(int(user_id))).bot: + guild_id: int = guild.id + channel_id: int = payload.channel_id + user_id: int = int(author['id']) + if (self.bot.get_user(user_id)).bot: return # Breakout if the user is a bot - message_id = payload.message_id - link = f"https://discordapp.com/channels/{guild_id}/{channel_id}/{message_id}" - mention = f"<@!{user_id}>" - avatar_link = f"https://cdn.discordapp.com/avatars/{user_id}/{author['avatar']}.webp?size=1024" - embed = discord.Embed(title="Message Edited", - description=f"[MESSAGE]({link}) From {mention}\nEdited In: {mchannel.mention}", - color=0xFFC400) + message_id: int = payload.message_id + link: str = f"https://discordapp.com/channels/{guild_id}/{channel_id}/{message_id}" + mention: str = f"<@!{user_id}>" + avatar_link: str = f"https://cdn.discordapp.com/avatars/{user_id}/{author['avatar']}.webp?size=1024" + embed: Embed = Embed(title="Message Edited", + description=f"[MESSAGE]({link}) From {mention}\nEdited In: {mchannel.mention}", + color=0xFFC400) embed.set_author(name=f"{author['username']}#{author['discriminator']}", icon_url=avatar_link) embed.add_field(name="Original", value="N/A", inline=False) if content: @@ -316,7 +319,7 @@ async def on_raw_message_edit(self, payload: discord.RawMessageUpdateEvent): await channel.send(embed=embed) @Cog.listener('on_message_edit') - async def on_message_edit(self, before: discord.Message, after: discord.Message): + async def on_message_edit(self, before: Message, after: Message): """Logs message edits.""" if before.author.bot: return @@ -329,10 +332,10 @@ async def on_message_edit(self, before: discord.Message, after: discord.Message) user_id = before.author.id message_id = before.id link = f"https://discordapp.com/channels/{guild_id}/{channel_id}/{message_id}" - embed = discord.Embed(title="Message Edited", - description=f"[MESSAGE]({link}) From {before.author.mention}" - f"\nEdited In: {before.channel.mention}", color=0xFFC400, - timestamp=after.edited_at) + embed: Embed = discord.Embed(title="Message Edited", + description=f"[MESSAGE]({link}) From {before.author.mention}" + f"\nEdited In: {before.channel.mention}", color=0xFFC400, + timestamp=after.edited_at) embed.set_author(name=before.author, icon_url=before.author.display_avatar) embed.set_footer(text=f"Message ID: {channel_id} - {message_id}\nUserID: {user_id}") if len(before.content) + len(after.content) < 5000: @@ -365,12 +368,12 @@ async def on_message_edit(self, before: discord.Message, after: discord.Message) @Cog.listener('on_member_ban') async def on_member_ban(self, guild: discord.Guild, user: discord.User): """Logs raw member ban events, even if not banned via &ban""" - audit = await self.check_audit(guild, discord.AuditLogAction.ban) - embed = discord.Embed(title="User Banned", color=0xff6700) - embed.set_thumbnail(url=user.display_avatar) + audit: Optional[AuditLogEntry] = await self.check_audit(guild, discord.AuditLogAction.ban) + embed: Embed = Embed(title="User Banned", color=0xff6700) + embed.set_thumbnail(url=user.avatar) embed.add_field(name="Banned user", value=f"{user}|({user.id})") if audit and audit.target == user: - acton_member = await guild.fetch_member(audit.user.id) + acton_member: Member = await guild.fetch_member(audit.user.id) embed.description = f"User banned by: {acton_member.mention}\n{acton_member}|({acton_member.id})" embed.add_field(name="Reason", value=audit.reason, inline=False) embed.set_footer(text=f"Actor ID: {acton_member.id}\nTarget ID: {user.id}") @@ -388,9 +391,10 @@ async def on_member_ban(self, guild: discord.Guild, user: discord.User): @has_permissions(administrator=True) async def messagelogconfig(self, ctx: DozerContext, channel_mentions: discord.TextChannel): """Set the modlog channel for a server by passing the channel id""" - config = await GuildMessageLog.get_by(guild_id=ctx.guild.id) - if len(config) != 0: - config = config[0] + results: List[GuildMessageLog] = await GuildMessageLog.get_by(guild_id=ctx.guild.id) + config: GuildMessageLog + if len(results) != 0: + config = results[0] config.name = ctx.guild.name config.messagelog_channel = channel_mentions.id else: @@ -407,9 +411,9 @@ async def messagelogconfig(self, ctx: DozerContext, channel_mentions: discord.Te @has_permissions(administrator=True) async def memberlogconfig(self, ctx: DozerContext): """Command group to configure Join/Leave logs""" - config = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) - embed = discord.Embed(title=f"Join/Leave configuration for {ctx.guild}", color=blurple) + config: List[CustomJoinLeaveMessages] = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) if len(config): + embed: Embed = Embed(title=f"Join/Leave configuration for {ctx.guild}", color=blurple) channel = ctx.guild.get_channel(config[0].channel_id) embed.add_field(name="Message Channel", value=channel.mention if channel else "None") embed.add_field(name="Ping on join", value=config[0].ping) @@ -440,12 +444,12 @@ async def viewconfig(self, ctx: DozerContext): @has_permissions(manage_guild=True) async def setchannel(self, ctx: DozerContext, channel: discord.TextChannel): """Configure join/leave channel""" - config = CustomJoinLeaveMessages( + config: CustomJoinLeaveMessages = CustomJoinLeaveMessages( guild_id=ctx.guild.id, channel_id=channel.id ) await config.update_or_add() - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"Join/Leave log channel has been set to {channel.mention}") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -454,40 +458,41 @@ async def setchannel(self, ctx: DozerContext, channel: discord.TextChannel): @has_permissions(manage_guild=True) async def toggleping(self, ctx: DozerContext): """Toggles whenever a new member gets pinged on join""" - config = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) + config: List[CustomJoinLeaveMessages] = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) if len(config): config[0].ping = not config[0].ping else: config = [CustomJoinLeaveMessages(guild_id=ctx.guild.id, ping=True)] await config[0].update_or_add() - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"Ping on join is set to: {config[0].ping}") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @memberlogconfig.command() @has_permissions(manage_guild=True) - async def togglesendonverify(self, ctx): + async def togglesendonverify(self, ctx: DozerContext): """Toggles if a join log is sent on user joining or on completing verification""" - config = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) + config: List[CustomJoinLeaveMessages] = await CustomJoinLeaveMessages.get_by(guild_id=ctx.guild.id) if len(config): config[0].send_on_verify = not config[0].send_on_verify else: config = [CustomJoinLeaveMessages(guild_id=ctx.guild.id, send_on_verify=True)] await config[0].update_or_add() - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"Send on verify is set to: {config[0].send_on_verify}") e.set_footer(text='Triggered by ' + ctx.author.display_name) await ctx.send(embed=e) @memberlogconfig.command() @has_permissions(manage_guild=True) - async def setjoinmessage(self, ctx: DozerContext, *, template: str = None): + async def setjoinmessage(self, ctx: DozerContext, *, template: Optional[str] = None): """Configure custom join message template""" - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) + config: CustomJoinLeaveMessages if template: config = CustomJoinLeaveMessages( guild_id=ctx.guild.id, @@ -505,10 +510,11 @@ async def setjoinmessage(self, ctx: DozerContext, *, template: str = None): @memberlogconfig.command() @has_permissions(manage_guild=True) - async def setleavemessage(self, ctx: DozerContext, *, template=None): + async def setleavemessage(self, ctx: DozerContext, *, template: Optional[str] = None): """Configure custom leave message template""" - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) + config: CustomJoinLeaveMessages if template: config = CustomJoinLeaveMessages( guild_id=ctx.guild.id, @@ -528,9 +534,9 @@ async def setleavemessage(self, ctx: DozerContext, *, template=None): @has_permissions(manage_guild=True) async def disable(self, ctx: DozerContext): """Disables Join/Leave logging""" - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) - config = CustomJoinLeaveMessages( + config: CustomJoinLeaveMessages = CustomJoinLeaveMessages( guild_id=ctx.guild.id, channel_id=CustomJoinLeaveMessages.nullify ) @@ -540,10 +546,10 @@ async def disable(self, ctx: DozerContext): @memberlogconfig.command() @has_permissions(manage_guild=True) - async def help(self, - ctx: DozerContext): # I cannot put formatting example in example_usage because then it trys to format the example + async def help(self, ctx: DozerContext): + # I cannot put formatting example in example_usage because then it tries to format the example """Displays message formatting key""" - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) e.description = """ `{guild}` = guild name @@ -557,20 +563,20 @@ async def help(self, @command() @has_permissions(manage_nicknames=True) @bot_has_permissions(manage_nicknames=True) - async def locknickname(self, ctx: DozerContext, member: discord.Member, *, name: str): + async def locknickname(self, ctx: DozerContext, member: Member, *, name: str): """Locks a members nickname to a particular string, in essence revoking nickname change perms""" try: await member.edit(nick=name) except discord.Forbidden: raise BadArgument(f"Dozer is not elevated high enough to change {member}'s nickname") - lock = NicknameLock( + lock: NicknameLock = NicknameLock( guild_id=ctx.guild.id, member_id=member.id, locked_name=name, timeout=time.time() ) await lock.update_or_add() - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"**{member}**'s nickname has been locked to **{name}**") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -582,11 +588,11 @@ async def locknickname(self, ctx: DozerContext, member: discord.Member, *, name: @command() @has_permissions(manage_nicknames=True) @bot_has_permissions(manage_nicknames=True) - async def unlocknickname(self, ctx: DozerContext, member: discord.Member): + async def unlocknickname(self, ctx: DozerContext, member: Member): """Removes nickname lock from member""" deleted = await NicknameLock.delete(guild_id=ctx.guild.id, member_id=member.id) if int(deleted.split(" ", 1)[1]): - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"Nickname lock for {member} has been removed") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -616,15 +622,15 @@ async def initial_create(cls): UNIQUE (guild_id, member_id) )""") - def __init__(self, guild_id: int, member_id: int, locked_name: str, timeout: float = None): + def __init__(self, guild_id: int, member_id: int, locked_name: Optional[str], timeout: Optional[float] = None): super().__init__() - self.guild_id = guild_id - self.member_id = member_id - self.locked_name = locked_name - self.timeout = timeout + self.guild_id: int = guild_id + self.member_id: int = member_id + self.locked_name: Optional[str] = locked_name + self.timeout: Optional[float] = timeout @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["NicknameLock"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -652,12 +658,12 @@ async def initial_create(cls): def __init__(self, guild_id: int, name: str, messagelog_channel: int): super().__init__() - self.guild_id = guild_id - self.name = name - self.messagelog_channel = messagelog_channel + self.guild_id: int = guild_id + self.name: str = name + self.messagelog_channel: int = messagelog_channel @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["GuildMessageLog"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -667,6 +673,149 @@ async def get_by(cls, **kwargs): return result_list -async def setup(bot): +async def send_log(member: Member): + """Sends the message for when a user joins or leave a guild""" + config = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) + if len(config): + channel = member.guild.get_channel(config[0].channel_id) + if channel: + embed: Embed = Embed(color=0x00FF00) + embed.set_author(name='Member Joined', icon_url=member.avatar.replace(format='png', size=32)) + embed.description = format_join_leave(config[0].join_message, member) + embed.set_footer(text="{} | {} members".format(member.guild.name, member.guild.member_count)) + try: + await channel.send(content=member.mention if config[0].ping else None, embed=embed) + except discord.Forbidden: + logger.warning( + f"Guild {member.guild}({member.guild.id}) has invalid permissions for join/leave logs") + + +def format_join_leave(template: str, member: Member): + """Formats join leave message templates + {guild} = guild name + {user} = user's name plus discriminator ex. SnowPlow#5196 + {user_name} = user's name without discriminator + {user_mention} = user's mention + {user_id} = user's ID + """ + if template: + return template.format(guild=member.guild, user=str(member), user_name=member.name, + user_mention=member.mention, user_id=member.id) + else: + return "{user_mention}\n{user} ({user_id})".format(user=str(member), user_mention=member.mention, + user_id=member.id) + + +class CustomJoinLeaveMessages(db.DatabaseTable): + """Holds custom join leave messages""" + __tablename__ = 'memberlogconfig' + __uniques__ = 'guild_id' + + @classmethod + async def initial_create(cls): + """Create the table in the database""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + CREATE TABLE {cls.__tablename__} ( + guild_id bigint PRIMARY KEY NOT NULL, + memberlog_channel bigint NOT NULL, + name varchar NOT NULL, + send_on_verify boolean + )""") + + def __init__(self, guild_id: int, channel_id: int = None, ping: Optional[bool] = None, join_message: Optional[str] = None, + leave_message: Optional[str] = None, send_on_verify: Optional[bool] = False): + super().__init__() + self.guild_id: int = guild_id + self.channel_id: Optional[int] = channel_id + self.ping: Optional[bool] = ping + self.join_message: Optional[str] = join_message + self.leave_message: Optional[str] = leave_message + self.send_on_verify: Optional[bool] = send_on_verify + + @classmethod + async def get_by(cls, **kwargs) -> List["CustomJoinLeaveMessages"]: + results = await super().get_by(**kwargs) + result_list = [] + for result in results: + obj = CustomJoinLeaveMessages(guild_id=result.get("guild_id"), channel_id=result.get("channel_id"), + ping=result.get("ping"), + join_message=result.get("join_message"), + leave_message=result.get("leave_message"), + send_on_verify=result.get("send_on_verify")) + result_list.append(obj) + return result_list + + async def version_1(self): + """DB migration v1""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + alter table memberlogconfig rename column memberlog_channel to channel_id; + alter table memberlogconfig alter column channel_id drop not null; + alter table {self.__tablename__} drop column IF EXISTS name; + alter table {self.__tablename__} + add IF NOT EXISTS ping boolean default False; + alter table {self.__tablename__} + add IF NOT EXISTS join_message text default null; + alter table {self.__tablename__} + add IF NOT EXISTS leave_message text default null; + """) + + async def version_2(self): + """Updates database stuff to version 2""" + async with db.Pool.acquire() as conn: + await conn.execute(f"alter table {self.__tablename__} " + f"add if not exists send_on_verify boolean default null;") + + __versions__ = [version_1, version_2] + + +class GuildNewMember(db.DatabaseTable): + """Holds new member info""" + __tablename__ = 'new_members' + __uniques__ = 'guild_id' + + @classmethod + async def initial_create(cls): + """Create the table in the database""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + CREATE TABLE {cls.__tablename__} ( + guild_id bigint PRIMARY KEY, + channel_id bigint NOT NULL, + role_id bigint NOT NULL, + message varchar NOT NULL + )""") + + def __init__(self, guild_id: int, channel_id: int, role_id: int, message: str, require_team: bool): + super().__init__() + self.guild_id: int = guild_id + self.channel_id: int = channel_id + self.role_id: int = role_id + self.message: str = message + self.require_team: bool = require_team + + @classmethod + async def get_by(cls, **kwargs) -> List["GuildNewMember"]: + results = await super().get_by(**kwargs) + result_list = [] + for result in results: + obj = GuildNewMember(guild_id=result.get("guild_id"), channel_id=result.get("channel_id"), + role_id=result.get("role_id"), message=result.get("message"), + require_team=result.get("require_team")) + result_list.append(obj) + return result_list + + async def version_1(self): + """DB migration v1""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + ALTER TABLE {self.__tablename__} ADD require_team bool NOT NULL DEFAULT false; + """) + + __versions__ = [version_1] + + +async def setup(bot: "Dozer"): """Adds the actionlog cog to the bot.""" await bot.add_cog(Actionlog(bot)) diff --git a/dozer/cogs/development.py b/dozer/cogs/development.py index 6d8463fb..f64e947e 100755 --- a/dozer/cogs/development.py +++ b/dozer/cogs/development.py @@ -9,6 +9,7 @@ from loguru import logger from dozer.context import DozerContext +from . import _utils from ._utils import * @@ -43,7 +44,7 @@ async def reload(self, ctx: DozerContext, cog: str): async def document(self, ctx: DozerContext): """Dump documentation for Sphinx processing""" for x in self.bot.cogs: - cog = ctx.bot.get_cog(x) + cog: _utils.Cog = ctx.bot.get_cog(x) comrst = rstcloth.RstCloth() comrst.title(x) for command in cog.walk_commands(): @@ -85,11 +86,11 @@ async def evaluate(self, ctx: DozerContext, *, code: str): ret = await locals_['evaluated_function'](ctx) e.title = 'Python Evaluation - Success' - e.color = 0x00FF00 + e.colour = 0x00FF00 e.add_field(name='Output', value='```\n%s (%s)\n```' % (repr(ret), type(ret).__name__), inline=False) except Exception as err: e.title = 'Python Evaluation - Error' - e.color = 0xFF0000 + e.colour = 0xFF0000 e.add_field(name='Error', value='```\n%s\n```' % repr(err)) await ctx.send('', embed=e) diff --git a/dozer/cogs/filter.py b/dozer/cogs/filter.py index 8b8a1767..25cd66a8 100755 --- a/dozer/cogs/filter.py +++ b/dozer/cogs/filter.py @@ -2,15 +2,21 @@ with whitelisted role exceptions.""" import re +from re import Pattern +from typing import TYPE_CHECKING, List, Optional, Dict, Generator import discord -from discord.ext import commands +from asyncpg import Record +from discord import Embed, Role from discord.ext.commands import guild_only, has_permissions from dozer.context import DozerContext from ._utils import * from .. import db +if TYPE_CHECKING: + from dozer import Dozer + class Filter(Cog): """The filters need to be compiled each time they're run, but we don't want to compile every filter @@ -18,16 +24,17 @@ class Filter(Cog): the compiled object is placed in here. This dict is actually a dict full of dicts, with each parent dict's key being the guild ID for easy accessing. """ - filter_dict = {} + filter_dict: Dict[int, Dict[int, Pattern]] = {} - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - self.word_filter_setting = db.ConfigCache(WordFilterSetting) - self.word_filter_role_whitelist = db.ConfigCache(WordFilterRoleWhitelist) + self.word_filter_setting: db.ConfigCache = db.ConfigCache(WordFilterSetting) + self.word_filter_role_whitelist: db.ConfigCache = db.ConfigCache(WordFilterRoleWhitelist) """Helper Functions""" - async def check_dm_filter(self, ctx: DozerContext, embed: discord.Embed): + @staticmethod + async def check_dm_filter(ctx: DozerContext, embed: Embed): """Send an embed, if the setting in the DB allows for it""" results = await WordFilterSetting.get_by(guild_id=ctx.guild.id, setting_type="dm") if results: @@ -126,26 +133,26 @@ async def on_member_update(self, before: discord.Member, after: discord.Member): @guild_only() async def filter(self, ctx: DozerContext, advanced: bool = False): """List and manage filtered words""" - results = await WordFilter.get_by(guild_id=ctx.guild.id, enabled=True) + results: List[WordFilter] = await WordFilter.get_by(guild_id=ctx.guild.id, enabled=True) if not results: - embed = discord.Embed(title="Filters for {}".format(ctx.guild.name)) + embed: Embed = Embed(title="Filters for {}".format(ctx.guild.name)) embed.description = "No filters found for this guild! Add one using `{}filter add [name]`".format( ctx.prefix) - embed.color = discord.Color.red() + embed.colour = discord.Color.red() await ctx.send(embed=embed) return - fmt = 'ID {0.filter_id}: `{0.friendly_name}`' + fmt: str = 'ID {0.filter_id}: `{0.friendly_name}`' if advanced: fmt += ': Pattern: `{0.pattern}`' - filter_text = '\n'.join(map(fmt.format, results)) + filter_text: str = '\n'.join(map(fmt.format, results)) - embed = discord.Embed() + embed: Embed = Embed() embed.title = "Filters for {}".format(ctx.guild.name) embed.add_field(name="Filters", value=filter_text) - embed.color = discord.Color.dark_orange() + embed.colour = discord.Color.dark_orange() await self.check_dm_filter(ctx, embed) filter.example_usage = """`{prefix}filter add test` - Adds test as a filter. @@ -171,9 +178,9 @@ async def add(self, ctx: DozerContext, pattern: str, friendly_name=None): except re.error as err: await ctx.send("Invalid RegEx! ```{}```".format(err.msg)) return - new_filter = WordFilter(guild_id=ctx.guild.id, pattern=pattern, friendly_name=friendly_name or pattern) + new_filter: WordFilter = WordFilter(guild_id=ctx.guild.id, pattern=pattern, friendly_name=friendly_name or pattern) await new_filter.update_or_add() - embed = discord.Embed() + embed: Embed = Embed() embed.title = "Filter added!" embed.description = "A new filter with the name `{}` was added.".format(friendly_name or pattern) embed.add_field(name="Pattern", value="`{}`".format(pattern)) @@ -185,19 +192,20 @@ async def add(self, ctx: DozerContext, pattern: str, friendly_name=None): @guild_only() @has_permissions(manage_guild=True) @filter.command() - async def edit(self, ctx: DozerContext, filter_id: int, pattern): + async def edit(self, ctx: DozerContext, filter_id: int, pattern: str): """Edit an already existing filter using a new pattern. A filter's friendly name cannot be edited.""" try: re.compile(pattern) except re.error as err: await ctx.send("Invalid RegEx! ```{}```".format(err.msg)) return - results = await WordFilter.get_by(guild_id=ctx.guild.id) - found = False - result = None - for result in results: - if result.filter_id == filter_id: + results: List[WordFilter] = await WordFilter.get_by(guild_id=ctx.guild.id) + found: bool = False + result: Optional[WordFilter] = None + for search_filter in results: + if search_filter.filter_id == filter_id: found = True + result = search_filter break if not found: await ctx.send("That filter ID does not exist or does not belong to this guild.") @@ -206,11 +214,11 @@ async def edit(self, ctx: DozerContext, filter_id: int, pattern): enabled_change = False if not result.enabled: result.enabled = True - enabled_change = True + enabled_change: bool = True result.pattern = pattern await result.update_or_add() await self.load_filters(ctx.guild.id) - embed = discord.Embed(title="Updated filter {}".format(result.friendly_name or result.pattern)) + embed: Embed = Embed(title="Updated filter {}".format(result.friendly_name or result.pattern)) embed.description = "Filter ID {} has been updated.".format(result.filter_id) embed.add_field(name="Old Pattern", value=old_pattern) embed.add_field(name="New Pattern", value=pattern) @@ -227,12 +235,13 @@ async def edit(self, ctx: DozerContext, filter_id: int, pattern): @filter.command() async def remove(self, ctx: DozerContext, filter_id: int): """Remove a pattern from the filter list.""" - result = await WordFilter.get_by(filter_id=filter_id) - if len(result) == 0: + results: List[WordFilter] = await WordFilter.get_by(filter_id=filter_id) + result: WordFilter + if len(results) == 0: await ctx.send("Filter ID {} not found!".format(filter_id)) return else: - result = result[0] + result: WordFilter = results[0] if result.guild_id != ctx.guild.id: await ctx.send("That Filter does not belong to this guild.") return @@ -249,15 +258,15 @@ async def remove(self, ctx: DozerContext, filter_id: int): async def dm_config(self, ctx: DozerContext, config: str): """Set whether filter words should be DMed when used in bot messages""" config: str = str(int(config)) # turns into "1" or "0" idk man - results = await WordFilterSetting.get_by(guild_id=ctx.guild.id, setting_type="dm") + results: List[WordFilterSetting] = await WordFilterSetting.get_by(guild_id=ctx.guild.id, setting_type="dm") if results: - before_setting = results[0].value + before_setting: Optional[str] = results[0].value # Due to the settings table having a serial ID, inserts always succeed, so update_or_add can't be used to # update in place. Instead, we have to delete and reinsert the record. await WordFilterSetting.delete(guild_id=results[0].guild_id, setting_type=results[0].setting_type) else: before_setting = None - result = WordFilterSetting(guild_id=ctx.guild.id, setting_type="dm", value=config) + result: WordFilterSetting = WordFilterSetting(guild_id=ctx.guild.id, setting_type="dm", value=config) await result.update_or_add() self.word_filter_setting.invalidate_entry(guild_id=ctx.guild.id, setting_type="dm") await ctx.send( @@ -272,10 +281,10 @@ async def dm_config(self, ctx: DozerContext, config: str): async def whitelist(self, ctx: DozerContext): """List all whitelisted roles for this server""" results = await WordFilterRoleWhitelist.get_by(guild_id=ctx.guild.id) - role_objects = [ctx.guild.get_role(db_role.role_id) for db_role in results] - role_names = (role.name for role in role_objects if role is not None) - roles_text = "\n".join(role_names) - embed = discord.Embed() + role_objects: List[Role] = [ctx.guild.get_role(db_role.role_id) for db_role in results] + role_names: Generator[str] = (role.name for role in role_objects if role is not None) + roles_text: str = "\n".join(role_names) + embed: Embed = Embed() embed.title = "Whitelisted roles for {}".format(ctx.guild.name) embed.description = "Anybody with any of the roles below will not have their messages filtered." embed.add_field(name="Roles", value=roles_text or "No roles") @@ -294,7 +303,7 @@ async def viewlist(self, ctx: DozerContext): @whitelist.command(name="add") async def whitelist_add(self, ctx: DozerContext, *, role: discord.Role): """Add a role to the whitelist""" - result = await WordFilterRoleWhitelist.get_by(role_id=role.id) + result: List[WordFilterRoleWhitelist] = await WordFilterRoleWhitelist.get_by(role_id=role.id) if len(result) != 0: await ctx.send("That role is already whitelisted.") return @@ -310,7 +319,7 @@ async def whitelist_add(self, ctx: DozerContext, *, role: discord.Role): @whitelist.command(name="remove") async def whitelist_remove(self, ctx: DozerContext, *, role: discord.Role): """Remove a role from the whitelist""" - result = await WordFilterRoleWhitelist.get_by(role_id=role.id) + result: List[WordFilterRoleWhitelist] = await WordFilterRoleWhitelist.get_by(role_id=role.id) if len(result) == 0: await ctx.send("That role is not whitelisted.") return @@ -349,14 +358,14 @@ async def initial_create(cls): def __init__(self, guild_id: int, friendly_name: str, pattern: str, enabled: bool = True, filter_id: int = None): super().__init__() - self.filter_id = filter_id - self.guild_id = guild_id - self.enabled = enabled - self.friendly_name = friendly_name - self.pattern = pattern + self.filter_id: int = filter_id + self.guild_id: int = guild_id + self.enabled: bool = enabled + self.friendly_name: str = friendly_name + self.pattern: str = pattern @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["WordFilter"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -386,14 +395,14 @@ async def initial_create(cls): def __init__(self, guild_id: int, setting_type: str, value: str): super().__init__() - self.guild_id = guild_id - self.setting_type = setting_type - self.value = value + self.guild_id: int = guild_id + self.setting_type: str = setting_type + self.value: str = value @classmethod - async def get_by(cls, **kwargs): - results = await super().get_by(**kwargs) - result_list = [] + async def get_by(cls, **kwargs) -> List["WordFilterSetting"]: + results: List[Record] = await super().get_by(**kwargs) + result_list: List["WordFilterSetting"] = [] for result in results: obj = WordFilterSetting(guild_id=result.get("guild_id"), setting_type=result.get("setting_type"), value=result.get('value')) @@ -418,11 +427,11 @@ async def initial_create(cls): def __init__(self, guild_id: int, role_id: int): super().__init__() - self.role_id = role_id - self.guild_id = guild_id + self.role_id: int = role_id + self.guild_id: int = guild_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["WordFilterRoleWhitelist"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/fun.py b/dozer/cogs/fun.py index 6411810c..e7fc2a4f 100755 --- a/dozer/cogs/fun.py +++ b/dozer/cogs/fun.py @@ -2,6 +2,7 @@ import asyncio import random from asyncio import sleep +from typing import List, TYPE_CHECKING import discord from discord.ext.commands import cooldown, BucketType, guild_only, BadArgument, MissingPermissions @@ -11,11 +12,15 @@ from ._utils import * from .general import blurple +if TYPE_CHECKING: + from dozer.cogs.levels import Levels + class Fun(Cog): """Fun commands""" - async def battle(self, ctx: DozerContext, opponent: discord.Member, delete_result: bool = True): + @staticmethod + async def battle(ctx: DozerContext, opponent: discord.Member, delete_result: bool = True): """Start a fight with another user.""" attacks = [ "**{opponent}** was hit on the head by **{attacker}** ", @@ -76,7 +81,7 @@ async def battle(self, ctx: DozerContext, opponent: discord.Member, delete_resul damages = [100, 150, 200, 300, 50, 250, 420] players = [ctx.author, opponent] - hps = [1400, 1400] + hps: List[int] = [1400, 1400] turn = random.randint(0, 1) messages = [] @@ -115,7 +120,7 @@ async def battle(self, ctx: DozerContext, opponent: discord.Member, delete_resul async def fight(self, ctx: DozerContext, opponent: discord.Member, wager: int = 0): """Start a fight with another user.""" - levels = self.bot.get_cog("Levels") + levels: "Levels" = self.bot.get_cog("Levels") if wager == 0: await self.battle(ctx, opponent, delete_result=False) @@ -153,7 +158,7 @@ async def fight(self, ctx: DozerContext, opponent: discord.Member, wager: int = await msg.add_reaction("✅") await msg.add_reaction("❌") except discord.Forbidden: - raise MissingPermissions(f"**{ctx.bot.user}** does not have the permission to add reacts") + raise MissingPermissions(['ADD_REACTIONS']) try: emoji = None @@ -188,8 +193,8 @@ def reaction_check(reaction, reactor): f"\n{opponent.mention} now is at " f"level {levels.level_for_total_xp(opponent_levels.total_xp)} ({opponent_levels.total_xp} XP)") - levels.sync_member(ctx.guild.id, ctx.author.id) - levels.sync_member(ctx.guild.id, opponent.id) + await levels.sync_member(ctx.guild.id, ctx.author.id) + await levels.sync_member(ctx.guild.id, opponent.id) elif emoji == "❌": try: diff --git a/dozer/cogs/general.py b/dozer/cogs/general.py index b0217c85..624566e2 100755 --- a/dozer/cogs/general.py +++ b/dozer/cogs/general.py @@ -1,15 +1,19 @@ """General, basic commands that are common for Discord bots""" +import datetime import inspect +from typing import Optional, Union, List, Dict import discord +from discord import AppInfo, Embed from discord.ext.commands import BadArgument, cooldown, BucketType, Group, has_permissions, NotOwner, guild_only from discord.utils import escape_markdown from dozer.context import DozerContext +from . import _utils from ._utils import * from ..utils import oauth_url -blurple = discord.Color.blurple() +blurple: discord.Colour = discord.Color.blurple() class General(Cog): @@ -22,8 +26,8 @@ async def ping(self, ctx: DozerContext): location = 'DMs' else: location = 'the **%s** server' % ctx.guild.name - response = await ctx.send('Pong! We\'re in %s.' % location) - delay = response.created_at - ctx.message.created_at + response: discord.Message = await ctx.send('Pong! We\'re in %s.' % location) + delay: datetime.timedelta = response.created_at - ctx.message.created_at await response.edit( content=response.content + '\nTook %d ms to respond.' % (delay.seconds * 1000 + delay.microseconds // 1000)) @@ -35,7 +39,7 @@ async def ping(self, ctx: DozerContext): @command(name='help', aliases=['about']) @bot_has_permissions(add_reactions=True, embed_links=True, read_message_history=True) # Message history is for internals of paginate() - async def base_help(self, ctx: DozerContext, *, target=None): + async def base_help(self, ctx: DozerContext, *, target: str = None): """Show this message.""" await ctx.defer() try: @@ -49,17 +53,17 @@ async def base_help(self, ctx: DozerContext, *, target=None): if target_name in ctx.bot.cogs: await self._help_cog(ctx, ctx.bot.cogs[target_name]) else: - command = ctx.bot.get_command(target_name) - if command is None: + target_command = ctx.bot.get_command(target_name) + if target_command is None: raise BadArgument('that command/cog does not exist!') else: - await self._help_command(ctx, command) + await self._help_command(ctx, target_command) else: # Command with subcommand - command = ctx.bot.get_command(' '.join(target)) - if command is None: + target_command = ctx.bot.get_command(' '.join(target)) + if target_command is None: raise BadArgument('that command does not exist!') else: - await self._help_command(ctx, command) + await self._help_command(ctx, target_command) base_help.example_usage = """ `{prefix}help` - General help message @@ -69,8 +73,8 @@ async def base_help(self, ctx: DozerContext, *, target=None): async def _help_all(self, ctx: DozerContext): """Gets the help message for all commands.""" - info = discord.Embed(title='Dozer: Info', description='A guild management bot for FIRST Discord servers', - color=discord.Color.blue()) + info: Embed = Embed(title='Dozer: Info', description='A guild management bot for FIRST Discord servers', + color=discord.Color.blue()) info.set_thumbnail(url=self.bot.user.avatar) info.add_field(name='About', value="Dozer: A collaborative bot for FIRST Discord servers, developed by the FRC Discord Server Development Team") @@ -90,28 +94,29 @@ async def _help_all(self, ctx: DozerContext): info.set_footer(text='Dozer Help | all commands | Info page') await self._show_help(ctx, info, 'Dozer: Commands', '', 'all commands', ctx.bot.commands) - async def _help_command(self, ctx: DozerContext, command): + async def _help_command(self, ctx: DozerContext, target_command: _utils.Command): """Gets the help message for one command.""" - info = discord.Embed(title='Command: {}{} {}'.format(ctx.prefix, command.qualified_name, command.signature), - description=command.help or ( - None if command.example_usage else 'No information provided.'), - color=discord.Color.blue()) - usage = command.example_usage + info: Embed = Embed(title='Command: {}{} {}'.format(ctx.prefix, target_command.qualified_name, target_command.signature), + description=target_command.help or ( + None if target_command.example_usage else 'No information provided.'), + color=discord.Color.blue()) + usage: Union[str, None] = target_command.example_usage if usage: info.add_field(name='Usage', value=usage.format(prefix=ctx.prefix, name=ctx.invoked_with), inline=False) - info.set_footer(text='Dozer Help | {!r} command | Info'.format(command.qualified_name)) + info.set_footer(text='Dozer Help | {!r} command | Info'.format(target_command.qualified_name)) await self._show_help(ctx, info, 'Subcommands: {prefix}{name} {signature}', '', '{name!r} command', - command.commands if isinstance(command, Group) else set(), - name=command.qualified_name, signature=command.signature) + target_command.commands if isinstance(target_command, Group) else set(), + name=target_command.qualified_name, signature=target_command.signature) - async def _help_cog(self, ctx: DozerContext, cog): + async def _help_cog(self, ctx: DozerContext, cog: Cog): """Gets the help message for one cog.""" await self._show_help(ctx, None, 'Category: {cog_name}', inspect.cleandoc(cog.__doc__ or ''), '{cog_name!r} category', - (command for command in ctx.bot.commands if command.cog is cog), + (target_command for target_command in ctx.bot.commands if target_command.cog is cog), cog_name=type(cog).__name__) - async def _show_help(self, ctx: DozerContext, start_page: discord.Embed, title: str, description: str, + @staticmethod + async def _show_help(ctx: DozerContext, start_page: Optional[Embed], title: str, description: str, footer: str, commands, **format_args): """Creates and sends a template help message, with arguments filled in.""" format_args['prefix'] = ctx.prefix @@ -128,20 +133,20 @@ async def _show_help(self, ctx: DozerContext, start_page: discord.Embed, title: continue command_chunks = list(chunk(sorted(filtered_commands, key=lambda cmd: cmd.name), 4)) format_args['len_pages'] = len(command_chunks) - pages = [] + pages: List[Union[Embed, Dict[str, Embed]]] = [] for page_num, page_commands in enumerate(command_chunks): format_args['page_num'] = page_num + 1 - page = discord.Embed(title=title.format(**format_args), description=description.format(**format_args), - color=discord.Color.blue()) - for command in page_commands: - if command.short_doc: - embed_value = command.short_doc - elif command.example_usage: # Usage provided - show the user the command to see it + page: Embed = Embed(title=title.format(**format_args), description=description.format(**format_args), + color=discord.Color.blue()) + for target_command in page_commands: + if target_command.short_doc: + embed_value = target_command.short_doc + elif target_command.example_usage: # Usage provided - show the user the command to see it embed_value = 'Use `{0.prefix}{0.invoked_with} {1.qualified_name}` for more information.'.format( - ctx, command) + ctx, target_command) else: embed_value = 'No information provided.' - page.add_field(name='{}{} {}'.format(ctx.prefix, command.qualified_name, command.signature), + page.add_field(name='{}{} {}'.format(ctx.prefix, target_command.qualified_name, target_command.signature), value=embed_value, inline=False) page.set_footer(text=footer.format(**format_args)) pages.append(page) @@ -154,19 +159,21 @@ async def _show_help(self, ctx: DozerContext, start_page: discord.Embed, title: elif start_page is not None: info_emoji = '\N{INFORMATION SOURCE}' p = Paginator(ctx, (info_emoji, ...), pages, start='info', - auto_remove=ctx.channel.permissions_for(ctx.me)) + auto_remove=ctx.channel.permissions_for(ctx.me).manage_messages) async for reaction in p: if reaction == info_emoji: p.go_to_page('info') else: - await paginate(ctx, pages, auto_remove=ctx.channel.permissions_for(ctx.me)) + await paginate(ctx, pages, auto_remove=ctx.channel.permissions_for(ctx.me).manage_messages) elif start_page: # No commands - command without subcommands or empty cog - but a usable info page await ctx.send(embed=start_page) - else: # No commands, and no info page + else: # No commands and no info page format_args['len_pages'] = 1 format_args['page_num'] = 1 - embed = discord.Embed(title=title.format(**format_args), description=description.format(**format_args), - color=discord.Color.blue()) + embed: Embed = Embed( + title=title.format(**format_args), + description=description.format(**format_args), + color=discord.Color.blue()) embed.set_footer(text=footer.format(**format_args)) await ctx.send(embed=embed) @@ -185,7 +192,7 @@ async def invite(self, ctx: DozerContext): Display the bot's invite link. The generated link gives all permissions the bot requires. If permissions are removed, some commands will be unusable. """ - bot_info = await self.bot.application_info() + bot_info: AppInfo = await self.bot.application_info() if not bot_info.bot_public or self.bot.config['invite_override'] != "": await ctx.send(self.bot.config['invite_override'] or "The bot is not able to be publicly invited. Please " "contact the bot developer. If you are the bot " @@ -194,20 +201,20 @@ async def invite(self, ctx: DozerContext): perms = 0 for cmd in ctx.bot.walk_commands(): perms |= cmd.required_permissions.value - await ctx.send('<{}>'.format(oauth_url(ctx.me.id, discord.Permissions(perms)))) + await ctx.send('<{}>'.format(oauth_url(str(ctx.me.id), discord.Permissions(perms)))) @command(aliases=["setprefix"]) @guild_only() @has_permissions(manage_guild=True) async def configprefix(self, ctx: DozerContext, prefix: str): """Update a servers dynamic prefix""" - new_prefix = DynamicPrefixEntry( + new_prefix: DynamicPrefixEntry = DynamicPrefixEntry( guild_id=int(ctx.guild.id), prefix=prefix ) await new_prefix.update_or_add() await self.bot.dynamic_prefix.refresh() - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value=f"`{ctx.guild}`'s prefix has set to `{prefix}`!") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) diff --git a/dozer/cogs/info.py b/dozer/cogs/info.py index 93f19a5c..cd859cfd 100755 --- a/dozer/cogs/info.py +++ b/dozer/cogs/info.py @@ -1,7 +1,7 @@ """Provides commands for pulling certain information.""" import math import typing -from datetime import timezone, datetime, date +from datetime import timezone, datetime from difflib import SequenceMatcher import discord diff --git a/dozer/cogs/levels.py b/dozer/cogs/levels.py index 8e5dff41..3a4e3965 100644 --- a/dozer/cogs/levels.py +++ b/dozer/cogs/levels.py @@ -5,8 +5,8 @@ import itertools import math import random -import typing from datetime import timedelta, timezone, datetime +from typing import List, Optional import aiohttp import discord @@ -18,9 +18,10 @@ from dozer.bot import Dozer from dozer.context import DozerContext from ._utils import * +from .. import db blurple = discord.Color.blurple() -from .. import db + ADD_LIMIT = 2147483647 LEVEL_SET_LIMIT = 100000 @@ -455,7 +456,7 @@ async def configureranks(self, ctx: DozerContext): notify_channel = ctx.guild.get_channel(settings.lvl_up_msgs) enabled = "Enabled" if settings.enabled else "Disabled" - embed.set_author(name=ctx.guild, icon_url=ctx.guild.icon.url if ctx.guild.icon else None) + embed.set_author(name=ctx.guild, icon_url=ctx.guild.icon.url) embed.add_field(name=f"Levels are {enabled} for {ctx.guild}", value=f"XP min: {settings.xp_min}\n" f"XP max: {settings.xp_max}\n" f"Cooldown: {settings.xp_cooldown} Seconds\n" @@ -691,7 +692,7 @@ async def rank(self, ctx: DozerContext, *, member: discord.Member = None): @command(aliases=["ranks", "leaderboard"]) @guild_only() - async def levels(self, ctx: DozerContext, start: typing.Optional[discord.Member]): + async def levels(self, ctx: DozerContext, start: Optional[discord.Member]): """Show the XP leaderboard for this server. Leaderboard refreshes every 5 minutes or so""" # Order by total_xp needs a tiebreaker, otherwise all records with equal XP will have the same rank @@ -759,7 +760,7 @@ def __init__(self, guild_id: int, role_id: int, level: int): self.level = level @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["XPRole"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -797,7 +798,7 @@ def __init__(self, guild_id: int, user_id: int, total_xp: int, total_messages: i self.last_given_at = last_given_at @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["MemberXP"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -863,7 +864,7 @@ def __init__(self, guild_id: int, xp_min: int, xp_max: int, xp_cooldown: int, en self.keep_old_roles = keep_old_roles @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["GuildXPSettings"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/management.py b/dozer/cogs/management.py index d55f771c..289d135e 100644 --- a/dozer/cogs/management.py +++ b/dozer/cogs/management.py @@ -5,6 +5,7 @@ import math import os from datetime import timezone, datetime +from typing import List import discord from dateutil import parser @@ -209,7 +210,7 @@ async def initial_create(cls): PRIMARY KEY (entry_id, request_id) )""") - def __init__(self, guild_id: int, channel_id: int, time: datetime.time, content: str, request_id: str, + def __init__(self, guild_id: int, channel_id: int, time: datetime.time, content: str, request_id: int, header: str = None, requester_id: int = None, entry_id: int = None): super().__init__() self.guild_id = guild_id @@ -222,7 +223,7 @@ def __init__(self, guild_id: int, channel_id: int, time: datetime.time, content: self.entry_id = entry_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["ScheduledMessages"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/moderation.py b/dozer/cogs/moderation.py index bcf38a98..ddd453fe 100755 --- a/dozer/cogs/moderation.py +++ b/dozer/cogs/moderation.py @@ -4,9 +4,10 @@ import re import time import typing -from typing import Union +from typing import TYPE_CHECKING, Union, Optional, Type, Set, Tuple, Dict, List import discord +from discord import Guild, Embed, User, Member, Message, Role, PermissionOverwrite, ClientUser from discord.ext import tasks, commands from discord.ext.commands import BadArgument, has_permissions, RoleConverter, guild_only from discord.utils import escape_markdown @@ -14,13 +15,15 @@ from dozer.context import DozerContext from ._utils import * +from .actionlogs import CustomJoinLeaveMessages, send_log, GuildNewMember from .general import blurple +from .teams import TeamNumbers from .. import db -from ..Components.CustomJoinLeaveMessages import send_log, CustomJoinLeaveMessages -__all__ = ["SafeRoleConverter", "Moderation", "NewMemPurgeConfig", "GuildNewMember"] +if TYPE_CHECKING: + from dozer import Dozer -from ..Components.TeamNumbers import TeamNumbers +__all__ = ["SafeRoleConverter", "Moderation", "NewMemPurgeConfig", "GuildNewMember"] MAX_PURGE = 1000 @@ -44,10 +47,10 @@ async def convert(self, ctx: DozerContext, argument: str): class Moderation(Cog): """A cog to handle moderation tasks.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - self.links_config = db.ConfigCache(GuildMessageLinks) - self.punishment_timer_tasks = [] + self.links_config: db.ConfigCache = db.ConfigCache(GuildMessageLinks) + self.punishment_timer_tasks: List[asyncio.Task] = [] """=== Helper functions ===""" @@ -58,14 +61,14 @@ async def nm_kick_internal(self, guild: discord.Guild = None): entries = await NewMemPurgeConfig.get_by() else: entries = await NewMemPurgeConfig.get_by(guild_id=guild.id) - count = 0 + count: int = 0 for entry in entries: - guild = self.bot.get_guild(entry.guild_id) + guild: Guild = self.bot.get_guild(entry.guild_id) if guild is None: continue for mem in guild.members: if guild.get_role(entry.member_role) not in mem.roles: - delta = datetime.datetime.now() - mem.joined_at + delta: datetime.timedelta = datetime.datetime.now() - mem.joined_at if delta.days >= entry.days: await mem.kick(reason="New member purge cycle") count += 1 @@ -76,10 +79,10 @@ async def nm_kick(self): """Kicks new members""" await self.nm_kick_internal() - async def mod_log(self, actor: discord.Member, action: str, target: Union[discord.User, discord.Member, None], + async def mod_log(self, actor: Member, action: str, target: Union[User, Member, None], reason, orig_channel=None, - embed_color=discord.Color.red(), global_modlog: bool = True, duration: bool = None, - dm: bool = True, guild_override: int = None, extra_fields=None, updated_by: discord.Member = None): + embed_color: discord.Color = discord.Color.red(), global_modlog: bool = True, duration: datetime.timedelta = None, + dm: bool = True, guild_override: int = None, extra_fields=None, updated_by: Member = None): """Generates a modlog embed""" if target is None: @@ -87,10 +90,9 @@ async def mod_log(self, actor: discord.Member, action: str, target: Union[discor else: title = f"User {action}!" - modlog_embed = discord.Embed( + modlog_embed: Embed = Embed( color=embed_color, title=title - ) if target is not None: modlog_embed.add_field(name=f"{action.capitalize()} user", @@ -116,7 +118,7 @@ async def mod_log(self, actor: discord.Member, action: str, target: Union[discor await orig_channel.send("Failed to DM modlog to user") finally: modlog_embed.remove_field(2) - modlog_channel = await GuildModLog.get_by(guild_id=actor.guild.id) if guild_override is None else \ + modlog_channel: List[GuildModLog] = await GuildModLog.get_by(guild_id=actor.guild.id) if guild_override is None else \ await GuildModLog.get_by(guild_id=guild_override) if orig_channel is not None: await orig_channel.send(embed=modlog_embed) @@ -134,10 +136,12 @@ async def mod_log(self, actor: discord.Member, action: str, target: Union[discor if orig_channel is not None: await orig_channel.send("Please configure modlog channel to enable modlog functionality") - async def perm_override(self, member: discord.Member, **overwrites): + @staticmethod + async def perm_override(member: Member, **overwrites): """Applies the given overrides to the given member in their guild.""" for channel in member.guild.channels: - overwrite = channel.overwrites_for(member) + + overwrite: PermissionOverwrite = channel.overwrites_for(member) if channel.permissions_for(member.guild.me).manage_roles: overwrite.update(**overwrites) try: @@ -146,35 +150,37 @@ async def perm_override(self, member: discord.Member, **overwrites): logger.error( f"Failed to catch missing perms in {channel} ({channel.id}) Guild: {channel.guild.id}; Error: {e}") + hm_regex: re.Pattern = re.compile( + r"((?P\d+)y)?((?P\d+)M)?((?P\d+)w)?((?P\d+)d)?((?P\d+)h)?((?P\d+)m)?((" + r"?P\d+)s)?") - hm_regex = re.compile(r"((?P\d+)y)?((?P\d+)M)?((?P\d+)w)?((?P\d+)d)?((?P\d+)h)?((?P\d+)m)?((" - r"?P\d+)s)?") - - def hm_to_seconds(self, hm_str: str): + def hm_to_seconds(self, hm_str: str) -> int: """Converts an hour-minute string to seconds. For example, '1h15m' returns 4500""" - matches = re.match(self.hm_regex, hm_str).groupdict() - years = int(matches.get('years') or 0) - months = int(matches.get('months') or 0) - weeks = int(matches.get('weeks') or 0) - days = int(matches.get('days') or 0) - hours = int(matches.get('hours') or 0) - minutes = int(matches.get('minutes') or 0) - seconds = int(matches.get('seconds') or 0) - val = int((years * 3.154e+7) + (months * 2.628e+6) + (weeks * 604800) + (days * 86400) + (hours * 3600) + (minutes * 60) + seconds) + matches: Dict[str, str] = re.match(self.hm_regex, hm_str).groupdict() + years: int = int(matches.get('years') or 0) + months: int = int(matches.get('months') or 0) + weeks: int = int(matches.get('weeks') or 0) + days: int = int(matches.get('days') or 0) + hours: int = int(matches.get('hours') or 0) + minutes: int = int(matches.get('minutes') or 0) + seconds: int = int(matches.get('seconds') or 0) + val: int = int((years * 3.154e+7) + (months * 2.628e+6) + (weeks * 604800) + (days * 86400) + (hours * 3600) + (minutes * 60) + seconds) # Make sure it is a positive number, and it doesn't exceed the max 32-bit int + # Wait so dozer is going to die at 03:14:07 on Tuesday, 19 January 2038, well I guess that's someone else's problem. + # (yes right now its probably because of a discord non-compatibility, but once they support it we should probably fix it) return max(0, min(2147483647, val)) async def start_punishment_timers(self): """Starts all punishment timers""" - q = await PunishmentTimerRecords.get_by() # no filters: all + q: List[PunishmentTimerRecords] = await PunishmentTimerRecords.get_by() # no filters: all for r in q: - guild = self.bot.get_guild(r.guild_id) - actor = guild.get_member(r.actor_id) - target = guild.get_member(r.target_id) - orig_channel = self.bot.get_channel(r.orig_channel_id) - punishment_type = r.type_of_punishment - reason = r.reason or "" - seconds = max(int(r.target_ts - time.time()), 0.01) + guild: Guild = self.bot.get_guild(r.guild_id) + actor: Member = guild.get_member(r.actor_id) + target: Member = guild.get_member(r.target_id) + orig_channel: discord.TextChannel = self.bot.get_channel(r.orig_channel_id) + punishment_type: int = r.type_of_punishment + reason: str = r.reason or "" + seconds = int(max(r.target_ts - time.time(), 0.01)) await PunishmentTimerRecords.delete(id=r.id) self.bot.loop.create_task( self.punishment_timer(seconds, target, PunishmentTimerRecords.type_map[punishment_type], reason, actor, @@ -186,15 +192,15 @@ async def restart_all_timers(self): """Restarts all timers""" logger.info("Restarting all timers") for timer in self.punishment_timer_tasks: - # timer: asyncio.Task + timer: asyncio.Task logger.info(f"Stopping \"{timer.get_name()}\"") for timer in self.punishment_timer_tasks: timer.cancel() self.punishment_timer_tasks = [] await self.start_punishment_timers() - async def punishment_timer(self, seconds: int, target: discord.Member, punishment, reason: str, - actor: discord.Member, orig_channel=None, + async def punishment_timer(self, seconds: int, target: Member, punishment: Type[Union["Deafen", "Mute"]], reason: str, + actor: Member, orig_channel=None, global_modlog: bool = True): """Asynchronous task that sleeps for a set time to unmute/undeafen a member for a set period of time.""" @@ -209,7 +215,7 @@ async def punishment_timer(self, seconds: int, target: discord.Member, punishmen return # register the timer - ent = PunishmentTimerRecords( + ent: PunishmentTimerRecords = PunishmentTimerRecords( guild_id=target.guild.id, actor_id=actor.id, target_id=target.id, @@ -223,7 +229,7 @@ async def punishment_timer(self, seconds: int, target: discord.Member, punishmen await asyncio.sleep(seconds) - user = await punishment.get_by(member_id=target.id) + user: List[Union[Deafen, Mute]] = await punishment.get_by(member_id=target.id) if len(user) != 0: await self.mod_log(actor=actor, action="un" + punishment.past_participle, @@ -239,16 +245,17 @@ async def punishment_timer(self, seconds: int, target: discord.Member, punishmen await PunishmentTimerRecords.delete(guild_id=target.guild.id, target_id=target.id, type_of_punishment=punishment.type) - async def _check_links_warn(self, msg: discord.Message, role: discord.Role): + @staticmethod + async def _check_links_warn(msg: Message, role: Role): """Warns a user that they can't send links.""" - warn_msg = await msg.channel.send(f"{msg.author.mention}, you need the `{role.name}` role to post links!") + warn_msg: Message = await msg.channel.send(f"{msg.author.mention}, you need the `{role.name}` role to post links!") await asyncio.sleep(3) await warn_msg.delete() - async def check_links(self, msg: discord.Message): + async def check_links(self, msg: Message): """Checks messages for the links role if necessary, then checks if the author is allowed to send links in the server""" if msg.guild is None or not isinstance(msg.author, - discord.Member) or not msg.guild.me.guild_permissions.manage_messages: + Member) or not msg.guild.me.guild_permissions.manage_messages: return config = await self.links_config.query_one(guild_id=msg.guild.id) if config is None: @@ -262,14 +269,14 @@ async def check_links(self, msg: discord.Message): return True return False - async def run_cross_ban(self, ctx: DozerContext, user: discord.User, reason: str): + async def run_cross_ban(self, ctx: DozerContext, user: User, reason: str) -> List[Guild]: """Checks for guilds that are subscribed to the banned members guild""" - subscriptions = await CrossBanSubscriptions.get_by(subscription_id=ctx.guild.id) - bans = [] + subscriptions: List[CrossBanSubscriptions] = await CrossBanSubscriptions.get_by(subscription_id=ctx.guild.id) + bans: List[Guild] = [] for subscription in subscriptions: - sub_guild = self.bot.get_guild(subscription.subscriber_id) + sub_guild: Guild = self.bot.get_guild(subscription.subscriber_id) if sub_guild: - modlog_channel = await GuildModLog.get_by(guild_id=sub_guild.id) + modlog_channel: List[GuildModLog] = await GuildModLog.get_by(guild_id=sub_guild.id) try: await sub_guild.ban(user, reason=f"User Cross Banned from \"{ctx.guild}\" for: {reason}") if modlog_channel: @@ -288,8 +295,8 @@ async def run_cross_ban(self, ctx: DozerContext, user: discord.User, reason: str """=== context-free backend functions ===""" - async def _mute(self, member: discord.Member, reason: str = "No reason provided", seconds: int = 0, - actor: discord.Member = None, orig_channel=None): + async def _mute(self, member: Member, reason: str = "No reason provided", seconds: int = 0, + actor: Member = None, orig_channel=None): """Mutes a user. member: the member to be muted reason: a reason string without a time specifier @@ -314,7 +321,7 @@ async def _mute(self, member: discord.Member, reason: str = "No reason provided" orig_channel=orig_channel)) return True - async def _unmute(self, member: discord.Member): + async def _unmute(self, member: Member): """Unmutes a user.""" results = await Mute.get_by(guild_id=member.guild.id, member_id=member.id) if results: @@ -327,7 +334,7 @@ async def _unmute(self, member: discord.Member): else: return False # member not muted - async def _deafen(self, member: discord.Member, reason: str = "No reason provided", seconds: int = 0, + async def _deafen(self, member: Member, reason: str = "No reason provided", seconds: int = 0, self_inflicted: bool = False, actor=None, orig_channel=None): """Deafens a user. @@ -367,7 +374,7 @@ async def _deafen(self, member: discord.Member, reason: str = "No reason provide global_modlog=not self_inflicted)) return True - async def _undeafen(self, member: discord.Member): + async def _undeafen(self, member: Member): """Undeafens a user.""" results = await Deafen.get_by(guild_id=member.guild.id, member_id=member.id) if results: @@ -390,7 +397,7 @@ async def on_ready(self): await self.nm_kick.start() @Cog.listener('on_member_join') - async def on_member_join(self, member: discord.Member): + async def on_member_join(self, member: Member): """Logs that a member joined.""" users = await Mute.get_by(guild_id=member.guild.id, member_id=member.id) if users: @@ -400,7 +407,7 @@ async def on_member_join(self, member: discord.Member): await self.perm_override(member, read_messages=False) @Cog.listener('on_message') - async def on_message(self, message: discord.Message): + async def on_message(self, message: Message): """Check things when messages come in.""" if message.author.bot or message.guild is None or not message.guild.me.guild_permissions.manage_roles: return @@ -433,7 +440,7 @@ async def on_message(self, message: discord.Message): await send_log(member=message.author) @Cog.listener('on_message_edit') - async def on_message_edit(self, before: discord.Message, after: discord.Message): + async def on_message_edit(self, before: Message, after: Message): """Checks for links""" await self.check_links(after) @@ -441,7 +448,7 @@ async def on_message_edit(self, before: discord.Message, after: discord.Message) @command() @has_permissions(kick_members=True) - async def warn(self, ctx: DozerContext, member: discord.Member, *, reason: str): + async def warn(self, ctx: DozerContext, member: Member, *, reason: str): """Sends a message to the mod log specifying the member has been warned without punishment.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel await self.mod_log(actor=ctx.author, action="warned", target=member, orig_channel=orig_channel, reason=reason) @@ -456,7 +463,7 @@ async def customlog(self, ctx: DozerContext, *, reason: str): """Sends a message to the mod log with custom text.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel await self.mod_log(actor=ctx.author, action="", target=None, orig_channel=orig_channel, reason=reason, - embed_color=0xFFC400) + embed_color=discord.Color(0xFFC400)) customlog.example_usage = """ `{prefix}`customlog reason - warns a user for "reason" @@ -467,12 +474,13 @@ async def customlog(self, ctx: DozerContext, *, reason: str): @bot_has_permissions(manage_permissions=True) async def timeout(self, ctx: DozerContext, duration: float): """Set a timeout (no sending messages or adding reactions) on the current channel.""" - settings = await MemberRole.get_by(guild_id=ctx.guild.id) - if len(settings) == 0: + results: List[MemberRole] = await MemberRole.get_by(guild_id=ctx.guild.id) + settings: MemberRole + if len(results) == 0: settings = MemberRole(guild_id=ctx.guild.id, member_role=MemberRole.nullify) await settings.update_or_add() else: - settings = settings[0] + settings = results[0] # None-safe - nonexistent or non-configured role return None member_role = ctx.guild.get_role(settings.member_role) if member_role is not None: @@ -482,22 +490,23 @@ async def timeout(self, ctx: DozerContext, duration: float): '{0.author.mention}, the members role has not been configured. This may not work as expected. Use ' '`{0.prefix}help memberconfig` to see how to set this up.'.format( ctx)) - targets = set(sorted(ctx.guild.roles)[:ctx.author.top_role.position]) + targets: Set[Role] = set(sorted(ctx.guild.roles)[:ctx.author.top_role.position]) - to_restore = [(target, ctx.channel.overwrites_for(target)) for target in targets] + to_restore: List[Tuple[Union[Role, ClientUser, Member], PermissionOverwrite]] = \ + [(target, ctx.channel.overwrites_for(target)) for target in targets] for target, overwrite in to_restore: - new_overwrite = discord.PermissionOverwrite.from_pair(*overwrite.pair()) + new_overwrite: discord.PermissionOverwrite = discord.PermissionOverwrite.from_pair(*overwrite.pair()) new_overwrite.update(send_messages=False, add_reactions=False) await ctx.channel.set_permissions(target, overwrite=new_overwrite) for allow_target in (ctx.me, ctx.author): overwrite = ctx.channel.overwrites_for(allow_target) - new_overwrite = discord.PermissionOverwrite.from_pair(*overwrite.pair()) + new_overwrite: discord.PermissionOverwrite = discord.PermissionOverwrite.from_pair(*overwrite.pair()) new_overwrite.update(send_messages=True) await ctx.channel.set_permissions(allow_target, overwrite=new_overwrite) to_restore.append((allow_target, overwrite)) - e = discord.Embed(title='Timeout - {}s'.format(duration), description='This channel has been timed out.', + e: Embed = Embed(title='Timeout - {}s'.format(duration), description='This channel has been timed out.', color=discord.Color.blue()) e.set_author(name=escape_markdown(ctx.author.display_name), icon_url=ctx.author.display_avatar.replace(format='png', size=32)) msg = await ctx.send(embed=e) @@ -520,19 +529,19 @@ async def timeout(self, ctx: DozerContext, duration: float): @command(aliases=["purge"]) @has_permissions(manage_messages=True) @bot_has_permissions(manage_messages=True, read_message_history=True) - async def prune(self, ctx: DozerContext, target: typing.Optional[discord.Member], num: int): + async def prune(self, ctx: DozerContext, target: typing.Optional[Member], num: int): """Bulk delete a set number of messages from the current channel.""" await ctx.defer() - def check_target(message): + def check_target(message: Message) -> bool: if target is None: return True else: return message.author == target try: - msg = await ctx.message.channel.fetch_message(num) - deleted = await ctx.message.channel.purge(after=msg, limit=MAX_PURGE, check=check_target) + msg: Message = await ctx.message.channel.fetch_message(num) + deleted: List[Message] = await ctx.message.channel.purge(after=msg, limit=MAX_PURGE, check=check_target) await ctx.send( f"Deleted {len(deleted)} messages under request of {ctx.message.author.mention}", delete_after=5) @@ -540,7 +549,7 @@ def check_target(message): if num > MAX_PURGE: await ctx.send("Message cannot be found or you're trying to purge too many messages.") return - deleted = await ctx.message.channel.purge(limit=num + 1, check=check_target) + deleted: List[Message] = await ctx.message.channel.purge(limit=num + 1, check=check_target) await ctx.send( f"Deleted {len(deleted) - 1} messages under request of {ctx.message.author.mention}", delete_after=5) @@ -555,28 +564,28 @@ def check_target(message): @has_permissions(manage_roles=True) async def punishments(self, ctx: DozerContext): """List currently active mutes and deafens in a guild""" - punishments = await PunishmentTimerRecords.get_by(guild_id=ctx.guild.id) - deafen_records = await Deafen.get_by(guild_id=ctx.guild.id) - self_inflicted = [record.member_id for record in deafen_records if record.self_inflicted] - deafens = [punishment for punishment in punishments if - punishment.type_of_punishment == 2 and punishment.target_id not in self_inflicted] - self_deafens = [punishment for punishment in punishments if - punishment.type_of_punishment == 2 and punishment.target_id in self_inflicted] - mutes = [punishment for punishment in punishments if punishment.type_of_punishment == 1] - embed = discord.Embed(title=f"Active punishments in {ctx.guild}", color=blurple) + punishments: List[PunishmentTimerRecords] = await PunishmentTimerRecords.get_by(guild_id=ctx.guild.id) + deafen_records: List[Deafen] = await Deafen.get_by(guild_id=ctx.guild.id) + self_inflicted: List[int] = [record.member_id for record in deafen_records if record.self_inflicted] + deafens: List[PunishmentTimerRecords] = [punishment for punishment in punishments if + punishment.type_of_punishment == 2 and punishment.target_id not in self_inflicted] + self_deafens: List[PunishmentTimerRecords] = [punishment for punishment in punishments if + punishment.type_of_punishment == 2 and punishment.target_id in self_inflicted] + mutes: List[PunishmentTimerRecords] = [punishment for punishment in punishments if punishment.type_of_punishment == 1] + embed: Embed = Embed(title=f"Active punishments in {ctx.guild}", color=blurple) embed.set_footer(text='Triggered by ' + ctx.author.display_name) - def get_mention(target_id): - member = ctx.guild.get_member(target_id) + def get_mention(target_id: int) -> str: + member: Member = ctx.guild.get_member(target_id) if member: return member.mention else: return "**Member left**" - def get_name(target_id): - user = ctx.bot.get_user(target_id) + def get_name(target_id) -> str: + user: User = ctx.bot.get_user(target_id) if user: - return user + return str(user) else: return "**Unknown#NONE**" @@ -607,7 +616,7 @@ def get_name(target_id): @command() @has_permissions(ban_members=True) @bot_has_permissions(ban_members=True) - async def ban(self, ctx: DozerContext, user_mention: discord.User, *, reason: str = "No reason provided"): + async def ban(self, ctx: DozerContext, user_mention: User, *, reason: str = "No reason provided"): """Bans the user mentioned.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel await self.mod_log(actor=ctx.author, action="banned", target=user_mention, reason=reason, @@ -629,7 +638,7 @@ async def ban(self, ctx: DozerContext, user_mention: discord.User, *, reason: st @command() @has_permissions(ban_members=True) @bot_has_permissions(ban_members=True) - async def unban(self, ctx: DozerContext, user_mention: discord.User, *, reason: str = "No reason provided"): + async def unban(self, ctx: DozerContext, user_mention: User, *, reason: str = "No reason provided"): """Unbans the user mentioned.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel await ctx.guild.unban(user_mention, reason=reason) @@ -643,7 +652,7 @@ async def unban(self, ctx: DozerContext, user_mention: discord.User, *, reason: @command() @has_permissions(kick_members=True) @bot_has_permissions(kick_members=True) - async def kick(self, ctx: DozerContext, user_mention: discord.User, *, reason: str = "No reason provided"): + async def kick(self, ctx: DozerContext, user_mention: User, *, reason: str = "No reason provided"): """Kicks the user mentioned.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel await self.mod_log(actor=ctx.author, action="kicked", target=user_mention, reason=reason, @@ -657,7 +666,7 @@ async def kick(self, ctx: DozerContext, user_mention: discord.User, *, reason: s @command() @has_permissions(manage_roles=True) @bot_has_permissions(manage_permissions=True) - async def mute(self, ctx: DozerContext, member_mentions: discord.Member, *, reason: str = "No reason provided"): + async def mute(self, ctx: DozerContext, member_mentions: Member, *, reason: str = "No reason provided"): """Mute a user to prevent them from sending messages""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel async with ctx.typing(): @@ -679,7 +688,7 @@ async def mute(self, ctx: DozerContext, member_mentions: discord.Member, *, reas @command() @has_permissions(manage_roles=True) @bot_has_permissions(manage_permissions=True) - async def unmute(self, ctx: DozerContext, member_mentions: discord.Member, *, reason="No reason provided"): + async def unmute(self, ctx: DozerContext, member_mentions: Member, *, reason="No reason provided"): """Unmute a user to allow them to send messages again.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel async with ctx.typing(): @@ -696,7 +705,7 @@ async def unmute(self, ctx: DozerContext, member_mentions: discord.Member, *, re @command() @has_permissions(manage_roles=True) @bot_has_permissions(manage_permissions=True) - async def deafen(self, ctx: DozerContext, member_mentions: discord.Member, *, reason: str = "No reason provided"): + async def deafen(self, ctx: DozerContext, member_mentions: Member, *, reason: str = "No reason provided"): """Deafen a user to prevent them from both sending messages but also reading messages.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel async with ctx.typing(): @@ -744,7 +753,7 @@ async def selfdeafen(self, ctx: DozerContext, *, reason: str = "No reason provid @command() @has_permissions(manage_roles=True) @bot_has_permissions(manage_permissions=True) - async def undeafen(self, ctx: DozerContext, member_mentions: discord.Member, *, reason: str = "No reason provided"): + async def undeafen(self, ctx: DozerContext, member_mentions: Member, *, reason: str = "No reason provided"): """Undeafen a user to allow them to see message and send message again.""" orig_channel = ctx.interaction.followup if ctx.interaction else ctx.channel async with ctx.typing(): @@ -761,7 +770,7 @@ async def undeafen(self, ctx: DozerContext, member_mentions: discord.Member, *, """ @command() - async def voicekick(self, ctx: DozerContext, member: discord.Member, reason: str = "No reason provided"): + async def voicekick(self, ctx: DozerContext, member: Member, reason: str = "No reason provided"): """Kick a user from voice chat. This is most useful if their perms to rejoin have already been removed.""" async with ctx.typing(): if member.voice is None: @@ -810,7 +819,7 @@ async def modlogconfig(self, ctx: DozerContext, channel_mentions: discord.TextCh @command() @has_permissions(manage_guild=True) - async def verifymember(self, ctx, member: discord.Member): + async def verifymember(self, ctx, member: Member): """Command to verify a member who may not have a team number set, or who hasn't sent the required verification message. """ config = await GuildNewMember.get_by(guild_id=ctx.guild.id) @@ -830,7 +839,7 @@ async def verifymember(self, ctx, member: discord.Member): @command() @has_permissions(administrator=True) - async def nmconfig(self, ctx: DozerContext, channel_mention: discord.TextChannel, role: discord.Role, *, message, + async def nmconfig(self, ctx: DozerContext, channel_mention: discord.TextChannel, role: Role, *, message, requireteam=None): """Sets the config for the new members channel""" config = await GuildNewMember.get_by(guild_id=ctx.guild.id) @@ -848,8 +857,8 @@ async def nmconfig(self, ctx: DozerContext, channel_mention: discord.TextChannel role_name = role.name await ctx.send( - "New Member Channel configured as: {channel}. Role configured as: {role}. Team numbers required: {" - "required}. Message: {message}".format( + "New Member Channel configured as: {channel}. Role configured as: {role}. Team numbers required: " + "{required}. Message: {message}".format( channel=channel_mention.name, role=role_name, required=requireteam, message=message)) nmconfig.example_usage = """ @@ -859,7 +868,7 @@ async def nmconfig(self, ctx: DozerContext, channel_mention: discord.TextChannel @command() @has_permissions(administrator=True) - async def nmpurgeconfig(self, ctx: DozerContext, role: discord.Role, days: int): + async def nmpurgeconfig(self, ctx: DozerContext, role: Role, days: int): """Sets the config for the new members purge""" config = NewMemPurgeConfig(guild_id=ctx.guild.id, member_role=role.id, days=days) await config.update_or_add() @@ -872,7 +881,7 @@ async def nmpurgeconfig(self, ctx: DozerContext, role: discord.Role, days: int): @command() @has_permissions(administrator=True) - async def memberconfig(self, ctx: DozerContext, *, member_role: SafeRoleConverter): + async def memberconfig(self, ctx: DozerContext, *, member_role: Role): """ Set the member role for the guild. The member role is the role used for the timeout command. It should be a role that all members of the server have. @@ -901,19 +910,19 @@ async def memberconfig(self, ctx: DozerContext, *, member_role: SafeRoleConverte @command() @has_permissions(administrator=True) @bot_has_permissions(manage_messages=True) - async def linkscrubconfig(self, ctx: DozerContext, *, link_role: SafeRoleConverter): + async def linkscrubconfig(self, ctx: DozerContext, *, link_role: Role): """ Set a role that users must have in order to post links. This accepts the safe default role conventions that the memberconfig command does. """ if link_role >= ctx.author.top_role: raise BadArgument('Link role cannot be higher than your top role!') - - settings = await GuildMessageLinks.get_by(guild_id=ctx.guild.id) - if len(settings) == 0: + settings: GuildMessageLinks + results: List[GuildMessageLinks] = await GuildMessageLinks.get_by(guild_id=ctx.guild.id) + if len(results) == 0: settings = GuildMessageLinks(guild_id=ctx.guild.id, role_id=link_role.id) else: - settings = settings[0] + settings = results[0] settings.role_id = link_role.id await settings.update_or_add() self.links_config.invalidate_entry(guild_id=ctx.guild.id) @@ -934,7 +943,7 @@ async def crossbans(self, ctx: DozerContext): """Cross ban""" subscriptions = await CrossBanSubscriptions.get_by(subscriber_id=ctx.guild.id) subscribers = await CrossBanSubscriptions.get_by(subscription_id=ctx.guild.id) - embed = discord.Embed(title="Cross ban subscriptions", color=blurple) + embed = Embed(title="Cross ban subscriptions", color=blurple) for field_number, target_ids in enumerate(chunk(subscriptions, 10)): embed.add_field(name='Subscriptions', value='\n'.join(f"{self.bot.get_guild(sub_id.subscription_id)} | {sub_id.subscription_id}" @@ -956,19 +965,19 @@ async def view_subs(self, ctx: DozerContext): @crossbans.command() @has_permissions(administrator=True) @bot_has_permissions(ban_members=True) - async def subscribe(self, ctx: DozerContext, guild_id): + async def subscribe(self, ctx: DozerContext, guild_id: str): """Subscribe to a guild to cross ban from""" - guild_id = int(guild_id) + guild_id: int = int(guild_id) guild = self.bot.get_guild(guild_id) if guild: - subscription = CrossBanSubscriptions( + subscription: CrossBanSubscriptions = CrossBanSubscriptions( subscriber_id=ctx.guild.id, subscription_id=guild.id ) await subscription.update_or_add() - embed = discord.Embed(title='Success!', - description=f"**{ctx.guild}** is now subscribed to receive crossbans from **{guild}**", - color=blurple) + embed: Embed = Embed(title='Success!', + description=f"**{ctx.guild}** is now subscribed to receive crossbans from **{guild}**", + color=blurple) embed.set_footer(text='Triggered by ' + ctx.author.display_name) await ctx.send(embed=embed) else: @@ -985,10 +994,10 @@ async def unsubscribe(self, ctx: DozerContext, guild_id): ) if int(result.split(" ", 1)[1]) > 0: - guild = self.bot.get_guild(guild_id) - embed = discord.Embed(title='Success!', - description=f"**{ctx.guild}** is no longer subscribed to receive crossbans from **{guild}**", - color=blurple) + guild: Guild = self.bot.get_guild(guild_id) + embed = Embed(title='Success!', + description=f"**{ctx.guild}** is no longer subscribed to receive crossbans from **{guild}**", + color=blurple) embed.set_footer(text='Triggered by ' + ctx.author.display_name) await ctx.send(embed=embed) else: @@ -1016,11 +1025,11 @@ async def initial_create(cls): def __init__(self, member_id: int, guild_id: int): super().__init__() - self.member_id = member_id - self.guild_id = guild_id + self.member_id: int = member_id + self.guild_id: int = guild_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["Mute"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1069,12 +1078,12 @@ async def initial_create(cls): def __init__(self, member_id: int, guild_id: int, self_inflicted: bool): super().__init__() - self.member_id = member_id - self.guild_id = guild_id - self.self_inflicted = self_inflicted + self.member_id: int = member_id + self.guild_id: int = guild_id + self.self_inflicted: bool = self_inflicted @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["Deafen"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1102,12 +1111,12 @@ async def initial_create(cls): def __init__(self, guild_id: int, modlog_channel: int, name: str): super().__init__() - self.guild_id = guild_id - self.modlog_channel = modlog_channel - self.name = name + self.guild_id: int = guild_id + self.modlog_channel: int = modlog_channel + self.name: str = name @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["GuildModLog"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1134,11 +1143,11 @@ async def initial_create(cls): )""") def __init__(self, subscriber_id: int, subscription_id: int): - self.subscriber_id = subscriber_id - self.subscription_id = subscription_id + self.subscriber_id: int = subscriber_id + self.subscription_id: int = subscription_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["CrossBanSubscriptions"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1165,11 +1174,11 @@ async def initial_create(cls): def __init__(self, guild_id: int, member_role: int = None): super().__init__() - self.guild_id = guild_id - self.member_role = member_role + self.guild_id: int = guild_id + self.member_role: int = member_role @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["MemberRole"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1196,12 +1205,12 @@ async def initial_create(cls): def __init__(self, guild_id: int, member_role: int, days: int): super().__init__() - self.guild_id = guild_id - self.member_role = member_role - self.days = days + self.guild_id: int = guild_id + self.member_role: int = member_role + self.days: int = days @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["NewMemPurgeConfig"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1212,52 +1221,6 @@ async def get_by(cls, **kwargs): return result_list -class GuildNewMember(db.DatabaseTable): - """Holds new member info""" - __tablename__ = 'new_members' - __uniques__ = 'guild_id' - - @classmethod - async def initial_create(cls): - """Create the table in the database""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - CREATE TABLE {cls.__tablename__} ( - guild_id bigint PRIMARY KEY, - channel_id bigint NOT NULL, - role_id bigint NOT NULL, - message varchar NOT NULL - )""") - - def __init__(self, guild_id: int, channel_id: int, role_id: int, message: str, require_team: bool): - super().__init__() - self.guild_id = guild_id - self.channel_id = channel_id - self.role_id = role_id - self.message = message - self.require_team = require_team - - @classmethod - async def get_by(cls, **kwargs): - results = await super().get_by(**kwargs) - result_list = [] - for result in results: - obj = GuildNewMember(guild_id=result.get("guild_id"), channel_id=result.get("channel_id"), - role_id=result.get("role_id"), message=result.get("message"), - require_team=result.get("require_team")) - result_list.append(obj) - return result_list - - async def version_1(self): - """DB migration v1""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - ALTER TABLE {self.__tablename__} ADD require_team bool NOT NULL DEFAULT false; - """) - - __versions__ = [version_1] - - class GuildMessageLinks(db.DatabaseTable): """Contains information for link scrubbing""" __tablename__ = 'guild_msg_links' @@ -1275,11 +1238,11 @@ async def initial_create(cls): def __init__(self, guild_id: int, role_id: int = None): super().__init__() - self.guild_id = guild_id - self.role_id = role_id + self.guild_id: int = guild_id + self.role_id: int = role_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["GuildMessageLinks"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -1311,20 +1274,20 @@ async def initial_create(cls): )""") def __init__(self, guild_id: int, actor_id: int, target_id: int, type_of_punishment: int, target_ts: int, - orig_channel_id: int = None, reason: str = None, input_id: int = None, self_inflicted: bool =False): + orig_channel_id: int = None, reason: Optional[str] = None, input_id: int = None, self_inflicted: bool = False): super().__init__() - self.id = input_id - self.guild_id = guild_id - self.actor_id = actor_id - self.target_id = target_id - self.type_of_punishment = type_of_punishment - self.target_ts = target_ts - self.orig_channel_id = orig_channel_id - self.reason = reason - self.self_inflicted = self_inflicted + self.id: int = input_id + self.guild_id: int = guild_id + self.actor_id: int = actor_id + self.target_id: int = target_id + self.type_of_punishment: int = type_of_punishment + self.target_ts: int = target_ts + self.orig_channel_id: Optional[int] = orig_channel_id + self.reason: Optional[str] = reason + self.self_inflicted: bool = self_inflicted @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["PunishmentTimerRecords"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/music.py b/dozer/cogs/music.py index a0e7aa58..65e417bf 100644 --- a/dozer/cogs/music.py +++ b/dozer/cogs/music.py @@ -1,10 +1,15 @@ """Music commands, currently disabled""" +from typing import TYPE_CHECKING + import lavaplayer from discord.ext import commands from loguru import logger from dozer.cogs._utils import command +if TYPE_CHECKING: + from dozer import Dozer + class Music(commands.Cog): """Music commands cog""" @@ -128,7 +133,7 @@ async def repeat(self, ctx: commands.Context, status: bool): await ctx.send("Repeated the queue.") -async def setup(bot: commands.Bot): +async def setup(bot: "Dozer"): """Adds the cog to the bot""" # await bot.add_cog(Music(bot)) logger.info("Music cog is temporarily disabled due to code bugs.") diff --git a/dozer/cogs/namegame.py b/dozer/cogs/namegame.py index e8c45c13..1cf741eb 100755 --- a/dozer/cogs/namegame.py +++ b/dozer/cogs/namegame.py @@ -5,10 +5,10 @@ import traceback from collections import OrderedDict from functools import wraps +from typing import TYPE_CHECKING, List import discord import tbapi -from discord.ext import commands from discord.ext.commands import has_permissions from discord.utils import escape_markdown from fuzzywuzzy import fuzz @@ -18,6 +18,9 @@ from ._utils import * from .. import db +if TYPE_CHECKING: + from dozer import Dozer + SUPPORTED_MODES = ["frc", "ftc"] @@ -56,7 +59,7 @@ async def wrapper(self, ctx: DozerContext, *args, **kwargs): return wrapper -class NameGameSession(): +class NameGameSession: """NameGame session object""" def __init__(self, mode: str): @@ -87,14 +90,14 @@ def __init__(self, mode: str): self.vote_embed = None self.vote_task = None - def create_embed(self, title: str = "", description: str = "", color=discord.Color.blurple(), extra_fields=[], + def create_embed(self, title: str = "", description: str = "", color: discord.Colour = discord.Color.blurple(), extra_fields=[], start: bool = False): """Creates an embed.""" v = "Starting " if start else "Current " embed = discord.Embed() embed.title = title embed.description = description - embed.color = color + embed.colour = color embed.add_field(name="Players", value=", ".join([escape_markdown(p.display_name) for p in self.players.keys()]) or "n/a") embed.add_field(name=v + "Player", value=self.current_player) embed.add_field(name=v + "Number", value=self.number or "Wildcard") @@ -165,7 +168,7 @@ def get_picked(self): class NameGame(Cog): """Namegame commands""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) with gzip.open("ftc_teams.pickle.gz") as f: raw_teams = pickle.load(f) @@ -199,7 +202,7 @@ async def ng(self, ctx: DozerContext): async def info(self, ctx: DozerContext): """Show a description of the robotics team name game and how to play.""" game_embed = discord.Embed() - game_embed.color = discord.Color.magenta() + game_embed.colour = discord.Color.magenta() game_embed.title = "How to play" game_embed.description = "This is a very simple little game where players will name a team number and name that " \ "starts with the last digit of the last named team. Some more specific rules are below:" @@ -537,7 +540,7 @@ async def pick(self, ctx: DozerContext, team: int, *, name: str): game.vote_player = ctx.author game.vote_correct = False vote_embed = discord.Embed() - vote_embed.color = discord.Color.gold() + vote_embed.colour = discord.Color.gold() vote_embed.title = "A vote is needed!" vote_embed.description = "A player has made a choice with less than 50% similarity. The details of the " \ "pick are below. Click on the two emoji to vote if this is correct or not. A" \ @@ -646,7 +649,7 @@ async def strike(self, ctx: DozerContext, game: NameGameSession, player: discord record = NameGameLeaderboard(user_id=winner.id, wins=1, game_mode=game.mode) await record.update_or_add() win_embed = discord.Embed() - win_embed.color = discord.Color.gold() + win_embed.colour = discord.Color.gold() win_embed.title = "We have a winner!" win_embed.add_field(name="Winning Player", value=winner) win_embed.add_field(name="Wins Total", value=record.wins) @@ -658,7 +661,8 @@ async def strike(self, ctx: DozerContext, game: NameGameSession, player: discord if not game.running: self.games.pop(ctx.channel.id) - async def display_info(self, ctx: DozerContext, game: NameGameSession): + @staticmethod + async def display_info(ctx: DozerContext, game: NameGameSession): """Displays info about the current game""" info_embed = discord.Embed(title="Current Game Info", color=discord.Color.blue()) info_embed.add_field(name="Game Type", value=game.mode.upper()) @@ -686,12 +690,14 @@ async def skip_player(self, ctx: DozerContext, game: NameGameSession, player: di await self.strike(ctx, game, player) # send an embed that starts a new turn - async def send_turn_embed(self, ctx: DozerContext, game: NameGameSession, **kwargs): + @staticmethod + async def send_turn_embed(ctx: DozerContext, game: NameGameSession, **kwargs): """Sends an embed that starts a new turn""" game.turn_embed = game.create_embed(**kwargs) game.turn_msg = await ctx.send(embed=game.turn_embed) - async def notify(self, ctx: DozerContext, game: NameGameSession, msg: str): + @staticmethod + async def notify(ctx: DozerContext, game: NameGameSession, msg: str): """Notifies people in the channel when it's their turn.""" if game.pings_enabled: await ctx.send(msg) @@ -708,7 +714,7 @@ async def on_reaction_add(self, reaction: discord.Reaction, user: discord.Member await self._on_reaction(game, reaction, user, 1) # also handle voting logic - ctx = await self.bot.get_context(reaction.message) + ctx: DozerContext = await self.bot.get_context(reaction.message) if game.vote_correct: if game.fail_tally > .5 * len(game.players): await ctx.send(f"The decision was overruled! Player {game.vote_player.mention} is given a strike!") @@ -745,7 +751,8 @@ async def on_reaction_remove(self, reaction: discord.Reaction, user: discord.Mem return await self._on_reaction(game, reaction, user, -1) - async def _on_reaction(self, game: NameGameSession, reaction: discord.Reaction, user: discord.Member, inc: int): + @staticmethod + async def _on_reaction(game: NameGameSession, reaction: discord.Reaction, user: discord.Member, inc: int): """Handles pass/fail reactions""" if reaction.message.id == game.vote_msg.id and user in game.players: if reaction.emoji == '❌': @@ -756,7 +763,7 @@ async def _on_reaction(self, game: NameGameSession, reaction: discord.Reaction, return game @keep_alive - async def game_turn_countdown(self, ctx: DozerContext, game): + async def game_turn_countdown(self, ctx: DozerContext, game: NameGameSession): """Counts down the time remaining left in the turn""" await asyncio.sleep(1) async with game.state_lock: @@ -821,7 +828,7 @@ def __init__(self, guild_id: int, mode: str, pings_enabled: int, channel_id: int self.pings_enabled = pings_enabled @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["NameGameConfig"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -855,7 +862,7 @@ def __init__(self, user_id: int, game_mode: str, wins: int): self.wins = wins @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["NameGameLeaderboard"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/news.py b/dozer/cogs/news.py index a06ff1e0..7bd8b5a0 100644 --- a/dozer/cogs/news.py +++ b/dozer/cogs/news.py @@ -3,10 +3,12 @@ import datetime import traceback from asyncio import CancelledError, InvalidStateError +from typing import List, TYPE_CHECKING from xml.etree import ElementTree import aiohttp import discord +from discord import Embed from discord.ext import tasks from discord.ext.commands import guild_only, has_permissions, BadArgument from loguru import logger @@ -16,6 +18,9 @@ from .. import db from ..sources import DataBasedSource, Source, sources +if TYPE_CHECKING: + from dozer import Dozer + def str_or_none(obj): """A helper function to make sure str(None) returns None instead of 'None' """ @@ -32,7 +37,7 @@ class News(Cog): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.updated = True + self.updated: bool = True self.http_source = None self.sources = {} self.get_new_posts.change_interval(minutes=self.bot.config['news']['check_interval']) @@ -86,7 +91,7 @@ async def get_new_posts(self): channel_dict[sub.data][channel] = sub.kind - # We've gotten all of the channels we need to post to, lets get the posts and post them now + # We've gotten all the channels we need to post to, lets get the posts and post them now try: posts = await source.get_new_posts() except ElementTree.ParseError: @@ -148,10 +153,10 @@ async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel): @guild_only() async def news(self, ctx: DozerContext): """Show help for news subscriptions""" - embed = discord.Embed(title="How to subscribe to News Sources", - description="Dozer has built in news scrapers to allow you to review up to date news" - "in specific channels. See below on how to manage your server's " - "subscriptions") + embed = Embed(title="How to subscribe to News Sources", + description="Dozer has built in news scrapers to allow you to review up to date news" + "in specific channels. See below on how to manage your server's " + "subscriptions") embed.add_field(name="How to add a subscription", value=f"To add a source, for example, Chief Delphi to a channel, you can use the command" f"`{ctx.bot.command_prefix}news add #channel cd`") @@ -182,7 +187,7 @@ async def view_help(self, ctx: DozerContext): @news.command() @has_permissions(manage_guild=True) @guild_only() - async def add(self, ctx: DozerContext, channel: discord.TextChannel, source: Source, kind='embed', data=None): + async def add(self, ctx: DozerContext, channel: discord.TextChannel, source: Source, kind: str = 'embed', data=None): """Add a new subscription of a given source to a channel.""" if data is None and kind not in self.kinds and isinstance(source, DataBasedSource): @@ -245,8 +250,8 @@ async def add(self, ctx: DozerContext, channel: discord.TextChannel, source: Sou kind=kind, data=str_or_none(data_obj)) await new_sub.update_or_add() - embed = discord.Embed(title=f"Channel #{channel.name} subscribed to {source.full_name}", - description="New posts should be in this channel soon.") + embed = Embed(title=f"Channel #{channel.name} subscribed to {source.full_name}", + description="New posts should be in this channel soon.") embed.add_field(name="Kind", value=kind) if isinstance(source, DataBasedSource): embed.add_field(name="Data", value=data_obj.full_name) @@ -287,7 +292,7 @@ async def remove(self, ctx: DozerContext, channel: discord.TextChannel, source: f"with data {data} found. Please contact the Dozer administrator for help.") return - data_exists = await NewsSubscription.get_by(source=source.short_name, data=str(data_obj)) + data_exists: List[NewsSubscription] = await NewsSubscription.get_by(source=source.short_name, data=str(data_obj)) if len(data_exists) > 1: removed = await source.remove_data(data_obj) if not removed: @@ -312,8 +317,8 @@ async def remove(self, ctx: DozerContext, channel: discord.TextChannel, source: await NewsSubscription.delete(id=sub[0].id) - embed = discord.Embed(title=f"Subscription of channel #{channel.name} to {source.full_name} removed", - description="Posts from this source will no longer appear.") + embed = Embed(title=f"Subscription of channel #{channel.name} to {source.full_name} removed", + description="Posts from this source will no longer appear.") if isinstance(source, DataBasedSource): embed.add_field(name="Data", value=sub[0].data) @@ -327,7 +332,7 @@ async def remove(self, ctx: DozerContext, channel: discord.TextChannel, source: @news.command(name='sources') async def list_sources(self, ctx: DozerContext): """List all available sources to subscribe to.""" - embed = discord.Embed(title="All available sources to subscribe to.") + embed = Embed(title="All available sources to subscribe to.") embed.description = f"To subscribe to any of these sources, use the `{ctx.prefix}news add " \ f" ` command." @@ -352,7 +357,7 @@ async def list_subscriptions(self, ctx: DozerContext, channel: discord.TextChann results = await NewsSubscription.get_by(guild_id=ctx.guild.id) if not results: - embed = discord.Embed(title="News Subscriptions for {}".format(ctx.guild.name)) + embed = Embed(title="News Subscriptions for {}".format(ctx.guild.name)) embed.description = f"No news subscriptions found for this guild! Add one using `{self.bot.command_prefix}" \ f"news add `" embed.colour = discord.Color.red() @@ -371,7 +376,7 @@ async def list_subscriptions(self, ctx: DozerContext, channel: discord.TextChann except KeyError: channels[channel] = [result] - embed = discord.Embed() + embed = Embed() embed.title = "News Subscriptions for {}".format(ctx.guild.name) embed.colour = discord.Color.dark_orange() for found_channel, lst in channels.items(): @@ -422,7 +427,7 @@ async def get_exception(self, ctx: DozerContext): if exception is None: await ctx.send("No exception occurred.") else: - tb_str = traceback.format_exception(type(exception), exception, exception.__traceback__) + tb_str: List[str] = traceback.format_exception(type(exception), value=exception, tb=exception.__traceback__) await ctx.send(f"```{''.join(tb_str)}```") except CancelledError: await ctx.send("Task has been cancelled.") @@ -433,7 +438,7 @@ async def get_exception(self, ctx: DozerContext): get_exception.example_usage = "`{prefix}news get_exception` - Get the exception that the loop failed with" -async def setup(bot): +async def setup(bot: "Dozer"): """Setup cog""" await bot.add_cog(News(bot)) @@ -459,15 +464,15 @@ async def initial_create(cls): def __init__(self, channel_id: int, guild_id: int, source: str, kind: str, data: str = None, sub_id: int = None): super().__init__() - self.id = sub_id - self.channel_id = channel_id - self.guild_id = guild_id - self.source = source - self.kind = kind - self.data = data + self.id: int = sub_id + self.channel_id: int = channel_id + self.guild_id: int = guild_id + self.source: str = source + self.kind: str = kind + self.data: str = data @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["NewsSubscription"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/roles.py b/dozer/cogs/roles.py index c57105c2..13fdb8ef 100755 --- a/dozer/cogs/roles.py +++ b/dozer/cogs/roles.py @@ -1,11 +1,11 @@ """Role management commands.""" import asyncio import time -import typing +from typing import TYPE_CHECKING, Set, Optional import discord import discord.utils -from discord.ext import commands +from discord import Message, Role, Member, Embed from discord.ext.commands import cooldown, BucketType, has_permissions, BadArgument, guild_only from discord.utils import escape_markdown @@ -15,13 +15,16 @@ from .. import db from ..db import * -blurple = discord.Color.blurple() +if TYPE_CHECKING: + from dozer import Dozer + +blurple: discord.Color = discord.Color.blurple() class Roles(Cog): """Commands for role management.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) for loop_command in self.giveme.walk_commands(): @loop_command.before_invoke # pylint: disable=cell-var-from-loop @@ -39,12 +42,12 @@ def calculate_epoch_time(time_string: str): return time_release @staticmethod - async def safe_message_fetch(ctx: DozerContext, menu=None, channel: discord.TextChannel = None, - message_id: int = None): + async def safe_message_fetch(ctx: DozerContext, menu: "RoleMenu" = None, channel: Optional[discord.TextChannel] = None, + message_id: int = None) -> Message: """Used to safely get a message and raise an error message cannot be found""" try: if menu: - channel = ctx.guild.get_channel(menu.channel_id) + channel: discord.TextChannel = ctx.guild.get_channel(menu.channel_id) return await channel.fetch_message(menu.message_id) else: if channel: @@ -55,27 +58,27 @@ async def safe_message_fetch(ctx: DozerContext, menu=None, channel: discord.Text raise BadArgument("That message does not exist or is not in this channel!") @staticmethod - async def add_to_message(message: discord.Message, entry): + async def add_to_message(message: Message, entry: "ReactionRole"): """Adds a reaction role to a message""" await message.add_reaction(entry.reaction) await entry.update_or_add() @staticmethod - async def del_from_message(message: discord.Message, entry): + async def del_from_message(message: Message, entry: "ReactionRole"): """Removes a reaction from a message""" await message.clear_reaction(entry.reaction) @Cog.listener('on_ready') async def on_ready(self): """Restore tempRole timers on bot startup""" - q = await TempRoleTimerRecords.get_by() # no filters: all + q: List[TempRoleTimerRecords] = await TempRoleTimerRecords.get_by() # no filters: all for record in q: self.bot.loop.create_task(self.removal_timer(record)) @Cog.listener() async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent): """Used to remove dead reaction role entries""" - message_id = payload.message_id + message_id: int = payload.message_id await ReactionRole.delete(message_id=message_id) await RoleMenu.delete(message_id=message_id) @@ -91,13 +94,13 @@ async def on_raw_reaction_remove(self, payload: discord.RawReactionActionEvent): async def on_raw_reaction_action(self, payload: discord.RawReactionActionEvent): """Called whenever a reaction is added or removed""" - message_id = payload.message_id - reaction = str(payload.emoji) - reaction_roles = await ReactionRole.get_by(message_id=message_id, reaction=reaction) + message_id: int = payload.message_id + reaction: str = str(payload.emoji) + reaction_roles: List[ReactionRole] = await ReactionRole.get_by(message_id=message_id, reaction=reaction) if len(reaction_roles): - guild = self.bot.get_guild(payload.guild_id) - member = guild.get_member(payload.user_id) - role = guild.get_role(reaction_roles[0].role_id) + guild: discord.Guild = self.bot.get_guild(payload.guild_id) + member: Member = guild.get_member(payload.user_id) + role: Role = guild.get_role(reaction_roles[0].role_id) if member.bot: return if role: @@ -109,13 +112,13 @@ async def on_raw_reaction_action(self, payload: discord.RawReactionActionEvent): except discord.Forbidden: logger.debug(f"Unable to add reaction role in guild {guild} due to missing permissions") - async def removal_timer(self, record): + async def removal_timer(self, record: "TempRoleTimerRecords"): """Asynchronous task that sleeps for a set time to remove a role from a member after a set period of time.""" - guild = self.bot.get_guild(int(record.guild_id)) - target = guild.get_member(int(record.target_id)) - target_role = guild.get_role(int(record.target_role_id)) - removal_time = record.removal_ts + guild: discord.Guild = self.bot.get_guild(int(record.guild_id)) + target: Member = guild.get_member(int(record.target_id)) + target_role: Role = guild.get_role(int(record.target_role_id)) + removal_time: int = record.removal_ts # Max function is used to make sure the delay is not negative time_delta = max(int(removal_time - time.time()), 1) @@ -127,7 +130,7 @@ async def removal_timer(self, record): await TempRoleTimerRecords.delete(id=record.id) @Cog.listener('on_guild_role_update') - async def on_role_edit(self, old, new): + async def on_role_edit(self, old: Role, new: Role): """Changes role names in database when they are changed in the guild""" if self.normalize(old.name) != self.normalize(new.name): results = await GiveableRole.get_by(norm_name=self.normalize(old.name), guild_id=old.guild.id) @@ -136,25 +139,25 @@ async def on_role_edit(self, old, new): await GiveableRole.from_role(new).update_or_add() @Cog.listener('on_guild_role_delete') - async def on_role_delete(self, old): + async def on_role_delete(self, old: Role): """Deletes roles from database when the roles are deleted from the guild. """ - results = await GiveableRole.get_by(norm_name=self.normalize(old.name), guild_id=old.guild.id) + results: List[GiveableRole] = await GiveableRole.get_by(norm_name=self.normalize(old.name), guild_id=old.guild.id) if results: logger.debug(f"Role {old.id} deleted. Deleting from database.") await GiveableRole.delete(role_id=old.id) @Cog.listener('on_member_join') - async def on_member_join(self, member: discord.Member): + async def on_member_join(self, member: Member): """Restores a member's roles when they join if they have joined before.""" - me = member.guild.me - top_restorable = me.top_role.position if me.guild_permissions.manage_roles else 0 - restore = await MissingRole.get_by(guild_id=member.guild.id, member_id=member.id) + me: Member = member.guild.me + top_restorable: int = me.top_role.position if me.guild_permissions.manage_roles else 0 + restore: List[MissingRole] = await MissingRole.get_by(guild_id=member.guild.id, member_id=member.id) if len(restore) == 0: return # New member - nothing to restore valid, cant_give, missing = set(), set(), set() for missing_role in restore: - role = member.guild.get_role(missing_role.role_id) + role: Role = member.guild.get_role(missing_role.role_id) if role is None: # Role with that ID does not exist missing.add(missing_role.role_name) elif role.position > top_restorable: @@ -169,16 +172,16 @@ async def on_member_join(self, member: discord.Member): if not missing and not cant_give: return - e = discord.Embed(title='Welcome back to the {} server, {}!'.format(member.guild.name, member), - color=discord.Color.blue()) + e: Embed = Embed(title='Welcome back to the {} server, {}!'.format(member.guild.name, member), + color=discord.Color.blue()) if missing: e.add_field(name='I couldn\'t restore these roles, as they don\'t exist.', value='\n'.join(sorted(missing))) if cant_give: e.add_field(name='I couldn\'t restore these roles, as I don\'t have permission.', value='\n'.join(sorted(cant_give))) try: - dest_id = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) - dest = member.guild.get_channel(dest_id[0].memberlog_channel) + dest_id: List[CustomJoinLeaveMessages] = await CustomJoinLeaveMessages.get_by(guild_id=member.guild.id) + dest: discord.TextChannel = member.guild.get_channel(dest_id[0].channel_id) await dest.send(embed=e) except discord.Forbidden: pass @@ -186,27 +189,28 @@ async def on_member_join(self, member: discord.Member): pass @Cog.listener('on_member_remove') - async def on_member_remove(self, member: discord.Member): + async def on_member_remove(self, member: Member): """Saves a member's roles when they leave in case they rejoin.""" - guild_id = member.guild.id - member_id = member.id + guild_id: int = member.guild.id + member_id: int = member.id for role in member.roles[1:]: # Exclude the @everyone role - db_member = MissingRole(role_id=role.id, role_name=role.name, guild_id=guild_id, member_id=member_id) + db_member: MissingRole = MissingRole(role_id=role.id, role_name=role.name, guild_id=guild_id, member_id=member_id) await db_member.update_or_add() - async def giveme_purge(self, rolelist): + @staticmethod + async def giveme_purge(rolelist: List["GiveableRole"]): """Purges roles in the giveme database that no longer exist. The argument is a list of GiveableRole objects.""" for role in rolelist: - dbrole = await GiveableRole.get_by(role_id=role.role_id) + dbrole: List[GiveableRole] = await GiveableRole.get_by(role_id=role.role_id) if dbrole: await GiveableRole.delete(role_id=role.role_id) async def ctx_purge(self, ctx: DozerContext): """Purges all giveme roles that no longer exist in a guild""" - counter = 0 - roles = await GiveableRole.get_by(guild_id=ctx.guild.id) - guildroles = [] - rolelist = [] + counter: int = 0 + roles: List["GiveableRole"] = await GiveableRole.get_by(guild_id=ctx.guild.id) + guildroles: List[int] = [] + rolelist: List["GiveableRole"] = [] for i in ctx.guild.roles: guildroles.append(i.id) for role in roles: @@ -216,50 +220,51 @@ async def ctx_purge(self, ctx: DozerContext): await self.giveme_purge(rolelist) return counter - async def on_guild_role_delete(self, role: discord.Role): + async def on_guild_role_delete(self, role: "GiveableRole"): """Automatically delete giveme roles if they are deleted from the guild""" - rolelist = [role] + rolelist: List[GiveableRole] = [role] await self.giveme_purge(rolelist) @group(invoke_without_command=True) @bot_has_permissions(manage_roles=True) - async def giveme(self, ctx: DozerContext, *, roles): + async def giveme(self, ctx: DozerContext, *, roles: str): """Give you one or more giveable roles, separated by commas.""" - norm_names = [self.normalize(name) for name in roles.split(',')] - giveable_ids = [tup.role_id for tup in await GiveableRole.get_by(guild_id=ctx.guild.id) if - tup.norm_name in norm_names] - valid = set(role for role in ctx.guild.roles if role.id in giveable_ids) + norm_names: List[str] = [self.normalize(name) for name in roles.split(',')] + giveable_ids: List[int] = [tup.role_id for tup in await GiveableRole.get_by(guild_id=ctx.guild.id) if + tup.norm_name in norm_names] + valid: Set[Role] = set(role for role in ctx.guild.roles if role.id in giveable_ids) - already_have = valid & set(ctx.author.roles) - given = valid - already_have + already_have: Set[Role] = valid & set(ctx.author.roles) + given: Set[Role] = valid - already_have await ctx.author.add_roles(*given) - e = discord.Embed(color=discord.Color.blue()) + e: Embed = Embed(color=discord.Color.blue()) if given: - given_names = sorted((role.name for role in given), key=str.casefold) + given_names: List[str] = sorted((role.name for role in given), key=str.casefold) e.add_field(name='Gave you {} role(s)!'.format(len(given)), value='\n'.join(given_names), inline=False) if already_have: already_have_names = sorted((role.name for role in already_have), key=str.casefold) e.add_field(name='You already have {} role(s)!'.format(len(already_have)), value='\n'.join(already_have_names), inline=False) - extra = len(norm_names) - len(valid) + extra: int = len(norm_names) - len(valid) if extra > 0: e.add_field(name='{} role(s) could not be found!'.format(extra), value='Use `{0.prefix}{0.invoked_with} list` to find valid giveable roles!'.format(ctx), inline=False) - msg = await ctx.send(embed=e) + msg: Message = await ctx.send(embed=e) try: await msg.add_reaction("❌") except discord.Forbidden: return try: - await self.bot.wait_for('reaction_add', timeout=30, check=lambda reaction, reactor: - reaction.emoji == "❌" and reactor == ctx.author and reaction.message == msg) + await self.bot.wait_for( + 'reaction_add', timeout=30, check=lambda reaction, reactor: + reaction.emoji == "❌" and reactor == ctx.author and reaction.message == msg) try: await msg.delete() except discord.HTTPException: logger.debug( - f"Unable to delete message to {ctx.member} in guild {ctx.guild} Reason: HTTPException") + f"Unable to delete message to {ctx.author} in guild {ctx.guild} Reason: HTTPException") try: await ctx.message.delete() except discord.Forbidden: @@ -288,7 +293,7 @@ async def role(self, ctx: DozerContext, roles): @has_permissions(manage_roles=True) async def purge(self, ctx: DozerContext): """Force a purge of giveme roles that no longer exist in the guild""" - counter = await self.ctx_purge(ctx) + counter: int = await self.ctx_purge(ctx) await ctx.send("Purged {} role(s)".format(counter)) @giveme.command() @@ -299,14 +304,14 @@ async def add(self, ctx: DozerContext, *, name: str): Similar to create, but will use an existing role if one exists.""" if ',' in name: raise BadArgument('giveable role names must not contain commas!') - norm_name = self.normalize(name) - settings = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) + norm_name: str = self.normalize(name) + settings: List[GiveableRole] = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) if settings: raise BadArgument('that role already exists and is giveable!') - candidates = [role for role in ctx.guild.roles if self.normalize(role.name) == norm_name] + candidates: List[Role] = [role for role in ctx.guild.roles if self.normalize(role.name) == norm_name] - if not candidates: - role = await ctx.guild.create_role(name=name, reason='Giveable role created by {}'.format(ctx.author)) + if not len(candidates): + role: Role = await ctx.guild.create_role(name=name, reason='Giveable role created by {}'.format(ctx.author)) elif len(candidates) == 1: role = candidates[0] else: @@ -328,11 +333,11 @@ async def create(self, ctx: DozerContext, *, name: str): Similar to add, but will always create a new role.""" if ',' in name: raise BadArgument('giveable role names must not contain commas!') - norm_name = self.normalize(name) - settings = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) - if not settings: - role = await ctx.guild.create_role(name=name, reason='Giveable role created by {}'.format(ctx.author)) - settings = GiveableRole.from_role(role) + norm_name: str = self.normalize(name) + results: List[GiveableRole] = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) + if not results: + role: Role = await ctx.guild.create_role(name=name, reason='Giveable role created by {}'.format(ctx.author)) + settings: GiveableRole = GiveableRole.from_role(role) await settings.update_or_add() await ctx.send( 'Role "{0}" created! Use `{1}{2} {0}` to get it!'.format(role.name, ctx.prefix, ctx.command.parent)) @@ -349,45 +354,46 @@ async def create(self, ctx: DozerContext, *, name: str): @bot_has_permissions(manage_roles=True) async def remove(self, ctx: DozerContext, *, roles): """Removes multiple giveable roles from you. Names must be separated by commas.""" - norm_names = [self.normalize(name) for name in roles.split(',')] - query = await GiveableRole.get_by(guild_id=ctx.guild.id) + norm_names: List[str] = [self.normalize(name) for name in roles.split(',')] + query: List[GiveableRole] = await GiveableRole.get_by(guild_id=ctx.guild.id) roles_to_remove = [] for role in query: if role.norm_name in norm_names: roles_to_remove.append(role) - removable_ids = [tup.role_id for tup in roles_to_remove] - valid = set(role for role in ctx.guild.roles if role.id in removable_ids) + removable_ids: List[int] = [tup.role_id for tup in roles_to_remove] + valid: Set[Role] = set(role for role in ctx.guild.roles if role.id in removable_ids) - removed = valid & set(ctx.author.roles) - dont_have = valid - removed + removed: Set[Role] = valid & set(ctx.author.roles) + dont_have: Set[Role] = valid - removed await ctx.author.remove_roles(*removed) - e = discord.Embed(color=discord.Color.blue()) + e: Embed = Embed(color=discord.Color.blue()) if removed: - removed_names = sorted((role.name for role in removed), key=str.casefold) + removed_names: List[str] = sorted((role.name for role in removed), key=str.casefold) e.add_field(name='Removed {} role(s)!'.format(len(removed)), value='\n'.join(removed_names), inline=False) if dont_have: - dont_have_names = sorted((role.name for role in dont_have), key=str.casefold) + dont_have_names: List[str] = sorted((role.name for role in dont_have), key=str.casefold) e.add_field(name='You didn\'t have {} role(s)!'.format(len(dont_have)), value='\n'.join(dont_have_names), inline=False) - extra = len(norm_names) - len(valid) + extra: int = len(norm_names) - len(valid) if extra > 0: e.add_field(name='{} role(s) could not be found!'.format(extra), value='Use `{0.prefix}{0.invoked_with} list` to find valid giveable roles!'.format(ctx), inline=False) - msg = await ctx.send(embed=e) + msg: Message = await ctx.send(embed=e) try: await msg.add_reaction("❌") except discord.Forbidden: return try: - await self.bot.wait_for('reaction_add', timeout=30, check=lambda reaction, reactor: - reaction.emoji == "❌" and reactor == ctx.author and reaction.message == msg) + await self.bot.wait_for( + 'reaction_add', timeout=30, check=lambda reaction, reactor: + reaction.emoji == "❌" and reactor == ctx.author and reaction.message == msg) try: await msg.delete() except discord.HTTPException: logger.debug( - f"Unable to delete message to {ctx.member} in guild {ctx.guild} Reason: HTTPException") + f"Unable to delete message to {ctx.author} in guild {ctx.guild} Reason: HTTPException") try: await ctx.message.delete() except discord.Forbidden: @@ -397,7 +403,7 @@ async def remove(self, ctx: DozerContext, *, roles): await msg.clear_reactions() except discord.HTTPException: logger.debug( - f"Unable to clear reactions from message to {ctx.member} in guild {ctx.guild} Reason: HTTPException") + f"Unable to clear reactions from message to {ctx.author} in guild {ctx.guild} Reason: HTTPException") return remove.example_usage = """ @@ -412,10 +418,10 @@ async def delete(self, ctx: DozerContext, *, name: str): """Deletes and removes a giveable role.""" if ',' in name: raise BadArgument('this command only works with single roles!') - norm_name = self.normalize(name) - valid_ids = set(role.id for role in ctx.guild.roles) - roles = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) - valid_roles = [] + norm_name: str = self.normalize(name) + valid_ids: Set[int] = set(role.id for role in ctx.guild.roles) + roles: List[GiveableRole] = await GiveableRole.get_by(guild_id=ctx.guild.id, norm_name=norm_name) + valid_roles: List[GiveableRole] = [] for role_option in roles: if role_option.role_id in valid_ids: valid_roles.append(role_option) @@ -424,7 +430,7 @@ async def delete(self, ctx: DozerContext, *, name: str): elif len(valid_roles) > 1: raise BadArgument('multiple giveable roles with that name exist!') else: - role = ctx.guild.get_role(valid_roles[0].role_id) + role: Role = ctx.guild.get_role(valid_roles[0].role_id) await GiveableRole.delete(guild_id=ctx.guild.id, norm_name=valid_roles[0].norm_name) await role.delete(reason='Giveable role deleted by {}'.format(ctx.author)) await ctx.send('Role "{0}" deleted!'.format(role)) @@ -438,8 +444,8 @@ async def delete(self, ctx: DozerContext, *, name: str): @bot_has_permissions(manage_roles=True) async def list_roles(self, ctx: DozerContext): """Lists all giveable roles for this server.""" - names = [tup.name for tup in await GiveableRole.get_by(guild_id=ctx.guild.id)] - e = discord.Embed(title='Roles available to self-assign', color=discord.Color.blue()) + names: List[str] = [tup.name for tup in await GiveableRole.get_by(guild_id=ctx.guild.id)] + e: Embed = Embed(title='Roles available to self-assign', color=discord.Color.blue()) e.description = '\n'.join(sorted(names, key=str.casefold)) await ctx.send(embed=e) @@ -448,7 +454,7 @@ async def list_roles(self, ctx: DozerContext): """ @staticmethod - def normalize(name): + def normalize(name: str) -> str: """Normalizes a role for consistency in the DB.""" return name.strip().casefold() @@ -460,10 +466,10 @@ async def removefromlist(self, ctx: DozerContext, *, name: str): # Honestly this is the giveme delete command but modified to only delete from the DB if ',' in name: raise BadArgument('this command only works with single roles!') - norm_name = self.normalize(name) - valid_ids = set(role.id for role in ctx.guild.roles) - roles = await GiveableRole.get_by(guild_id=ctx.guild.id) - valid_roles = [] + norm_name: str = self.normalize(name) + valid_ids: Set[int] = set(role.id for role in ctx.guild.roles) + roles: List[GiveableRole] = await GiveableRole.get_by(guild_id=ctx.guild.id) + valid_roles: List[GiveableRole] = [] for role_option in roles: if role_option.norm_name == norm_name and role_option.role_id in valid_ids: valid_roles.append(role_option) @@ -476,13 +482,13 @@ async def removefromlist(self, ctx: DozerContext, *, name: str): await ctx.send('Role "{0}" deleted from list!'.format(name)) delete.example_usage = """ - `{prefix}giveme removefromlist Java` - removes the role "Java" from the list of giveable roles but does not remove it from the server or members who have it + `{prefix}giveme removefromlist Java` - removes the role "Java" from the list of giveable roles. """ @command() @bot_has_permissions(manage_roles=True, embed_links=True) @has_permissions(manage_roles=True) - async def tempgive(self, ctx: DozerContext, member: discord.Member, length, *, role: discord.Role): + async def tempgive(self, ctx: DozerContext, member: Member, length: int, *, role: Role): """Temporarily gives a member a role for a set time. Not restricted to giveable roles.""" if role > ctx.author.top_role: raise BadArgument('Cannot give roles higher than your top role!') @@ -490,11 +496,11 @@ async def tempgive(self, ctx: DozerContext, member: discord.Member, length, *, r if role > ctx.me.top_role: raise BadArgument('Cannot give roles higher than my top role!') - remove_time = self.calculate_epoch_time(length) + remove_time: int = self.calculate_epoch_time(str(length)) if remove_time < time.time(): raise BadArgument('Cannot use negative role time') - ent = TempRoleTimerRecords( + ent: TempRoleTimerRecords = TempRoleTimerRecords( guild_id=member.guild.id, target_id=member.id, target_role_id=role.id, @@ -504,7 +510,7 @@ async def tempgive(self, ctx: DozerContext, member: discord.Member, length, *, r await member.add_roles(role) await ent.update_or_add() self.bot.loop.create_task(self.removal_timer(ent)) - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value='I gave {} to {}, for {}!'.format(role.mention, member.mention, length)) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -516,12 +522,12 @@ async def tempgive(self, ctx: DozerContext, member: discord.Member, length, *, r @command() @bot_has_permissions(manage_roles=True, embed_links=True) @has_permissions(manage_roles=True) - async def give(self, ctx: DozerContext, member: discord.Member, *, role: discord.Role): + async def give(self, ctx: DozerContext, member: Member, *, role: Role): """Gives a member a role. Not restricted to giveable roles.""" if role > ctx.author.top_role: raise BadArgument('Cannot give roles higher than your top role!') await member.add_roles(role) - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value='I gave {} to {}!'.format(role, member)) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -533,12 +539,12 @@ async def give(self, ctx: DozerContext, member: discord.Member, *, role: discord @command() @bot_has_permissions(manage_roles=True, embed_links=True) @has_permissions(manage_roles=True) - async def take(self, ctx: DozerContext, member: discord.Member, *, role: discord.Role): + async def take(self, ctx: DozerContext, member: Member, *, role: Role): """Takes a role from a member. Not restricted to giveable roles.""" if role > ctx.author.top_role: raise BadArgument('Cannot take roles higher than your top role!') await member.remove_roles(role) - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.add_field(name='Success!', value='I took {} from {}!'.format(role, member)) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -547,15 +553,14 @@ async def take(self, ctx: DozerContext, member: discord.Member, *, role: discord `{prefix}take cooldude#1234 Java` - takes any role named Java, giveable or not, from cooldude """ - async def update_role_menu(self, ctx: DozerContext, menu): + async def update_role_menu(self, ctx: DozerContext, menu: "RoleMenu"): """Updates a reaction role menu""" - menu = int(menu) - menu_message = await self.safe_message_fetch(ctx, menu=menu) + menu_message: Message = await self.safe_message_fetch(ctx, menu=menu) - menu_embed = discord.Embed(title=f"Role Menu: {menu.name}") - menu_entries = await ReactionRole.get_by(message_id=menu.message_id) + menu_embed: Embed = Embed(title=f"Role Menu: {menu.name}") + menu_entries: List[ReactionRole] = await ReactionRole.get_by(message_id=menu.message_id) for entry in menu_entries: - role = ctx.guild.get_role(entry.role_id) + role: Role = ctx.guild.get_role(entry.role_id) menu_embed.add_field(name=f"Role: {role}", value=f"{entry.reaction}: {role.mention}", inline=False) menu_embed.set_footer(text=f"React to get a role\nMenu ID: {menu_message.id}, Total roles: {len(menu_entries)}") await menu_message.edit(embed=menu_embed) @@ -566,20 +571,21 @@ async def update_role_menu(self, ctx: DozerContext, menu): @guild_only() async def rolemenu(self, ctx: DozerContext): """Base command for setting up and tracking reaction roles""" - rolemenus = await RoleMenu.get_by(guild_id=ctx.guild.id) - embed = discord.Embed(title="Reaction Role Messages", color=blurple) - boundroles = [] + rolemenus: List[RoleMenu] = await RoleMenu.get_by(guild_id=ctx.guild.id) + embed: Embed = Embed(title="Reaction Role Messages", color=blurple) + boundroles: List[int] = [] for rolemenu in rolemenus: menu_entries = await ReactionRole.get_by(message_id=rolemenu.message_id) for role in menu_entries: boundroles.append(role.message_id) - link = f"https://discordapp.com/channels/{rolemenu.guild_id}/{rolemenu.channel_id}/{rolemenu.message_id}" + link: str = f"https://discordapp.com/channels/{rolemenu.guild_id}/{rolemenu.channel_id}/{rolemenu.message_id}" embed.add_field(name=f"Menu: {rolemenu.name}", value=f"[Contains {len(menu_entries)} role watchers]({link})", inline=False) unbound_reactions = await db.Pool.fetch( f"""SELECT * FROM {ReactionRole.__tablename__} WHERE message_id != all($1)""" f""" and guild_id = $2;""", boundroles, ctx.guild.id) - combined_unbound = {} # The following code is too group individual reaction role entries into the messages they are associated with + # The following code is meant to group individual reaction role entries into the messages they are associated with + combined_unbound: Dict[str, Dict[str, int]] = {} if unbound_reactions: for unbound in unbound_reactions: guild_id = unbound.get("guild_id") @@ -614,7 +620,7 @@ async def rolemenu(self, ctx: DozerContext): @guild_only() async def createmenu(self, ctx: DozerContext, channel: discord.TextChannel, *, name: str): """Creates a blank reaction role menu""" - menu_embed = discord.Embed(title=f"Role Menu: {name}", description="React to get a role") + menu_embed = Embed(title=f"Role Menu: {name}", description="React to get a role") message = await channel.send(embed=menu_embed) e = RoleMenu( @@ -628,7 +634,7 @@ async def createmenu(self, ctx: DozerContext, channel: discord.TextChannel, *, n menu_embed.set_footer(text=f"Menu ID: {message.id}, Total roles: {0}") await message.edit(embed=menu_embed) - e = discord.Embed(color=blurple) + e = Embed(color=blurple) link = f"https://discordapp.com/channels/{ctx.guild.id}/{message.channel.id}/{message.id}" e.add_field(name='Success!', value=f"I added created role menu [\"{name}\"]({link}) in channel {channel.mention}") @@ -643,9 +649,8 @@ async def createmenu(self, ctx: DozerContext, channel: discord.TextChannel, *, n @bot_has_permissions(manage_roles=True, embed_links=True) @has_permissions(manage_roles=True) @guild_only() - async def addrole(self, ctx: DozerContext, channel: typing.Optional[discord.TextChannel], message_id, - role: discord.Role, - emoji: discord.Emoji): + async def addrole(self, ctx: DozerContext, channel: Optional[discord.TextChannel], message_id: str, + role: Role, emoji: discord.Emoji): """Adds a reaction role to a message or a role menu""" message_id = int(message_id) if isinstance(emoji, discord.Emoji) and emoji.guild_id != ctx.guild.id: @@ -683,7 +688,7 @@ async def addrole(self, ctx: DozerContext, channel: typing.Optional[discord.Text if menu: await self.update_role_menu(ctx, menu) - e = discord.Embed(color=blurple) + e = Embed(color=blurple) link = f"https://discordapp.com/channels/{ctx.guild.id}/{message.channel.id}/{message_id}" shortcut = f"[{menu.name}]({link})" if menu else f"[{message_id}]({link})" e.add_field(name='Success!', value=f"I added {role.mention} to message \"{shortcut}\" with reaction {emoji}") @@ -701,8 +706,8 @@ async def addrole(self, ctx: DozerContext, channel: typing.Optional[discord.Text @bot_has_permissions(manage_roles=True, embed_links=True) @has_permissions(manage_roles=True) @guild_only() - async def delrole(self, ctx: DozerContext, channel: typing.Optional[discord.TextChannel], message_id, - role: discord.Role): + async def delrole(self, ctx: DozerContext, channel: Optional[discord.TextChannel], message_id: str, + role: Role): """Removes a reaction role from a message or a role menu""" message_id = int(message_id) menu_return = await RoleMenu.get_by(guild_id=ctx.guild.id, message_id=message_id) @@ -716,9 +721,9 @@ async def delrole(self, ctx: DozerContext, channel: typing.Optional[discord.Text if menu: await self.update_role_menu(ctx, menu) - e = discord.Embed(color=blurple) - link = f"https://discordapp.com/channels/{ctx.guild.id}/{message.channel.id}/{message_id}" - shortcut = f"[{menu.name}]({link})" if menu else f"[{message_id}]({link})" + e = Embed(color=blurple) + link: str = f"https://discordapp.com/channels/{ctx.guild.id}/{message.channel.id}/{message_id}" + shortcut: str = f"[{menu.name}]({link})" if menu else f"[{message_id}]({link})" e.add_field(name='Success!', value=f"I removed {role.mention} from message {shortcut}") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -751,13 +756,13 @@ async def initial_create(cls): def __init__(self, guild_id: int, channel_id: int, message_id: int, name: str): super().__init__() - self.guild_id = guild_id - self.channel_id = channel_id - self.message_id = message_id - self.name = name + self.guild_id: int = guild_id + self.channel_id: int = channel_id + self.message_id: int = message_id + self.name: str = name @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["RoleMenu"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -788,14 +793,14 @@ async def initial_create(cls): def __init__(self, guild_id: int, channel_id: int, message_id: int, role_id: int, reaction: str): super().__init__() - self.guild_id = guild_id - self.channel_id = channel_id - self.message_id = message_id - self.role_id = role_id - self.reaction = reaction + self.guild_id: int = guild_id + self.channel_id: int = channel_id + self.message_id: int = message_id + self.role_id: int = role_id + self.reaction: str = reaction @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["ReactionRole"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -825,13 +830,13 @@ async def initial_create(cls): def __init__(self, guild_id: int, role_id: int, norm_name: str, name: str): super().__init__() - self.guild_id = guild_id - self.role_id = role_id - self.name = name - self.norm_name = norm_name + self.guild_id: int = guild_id + self.role_id: int = role_id + self.name: str = name + self.norm_name: str = norm_name @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["GiveableRole"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -841,8 +846,8 @@ async def get_by(cls, **kwargs): return result_list @classmethod - def from_role(cls, role: discord.Role): - """Creates a GiveableRole record from a discord.Role.""" + def from_role(cls, role: Role) -> "GiveableRole": + """Creates a GiveableRole record from a Role.""" return cls(role_id=role.id, name=role.name, norm_name=Roles.normalize(role.name), guild_id=role.guild.id) @@ -866,13 +871,13 @@ async def initial_create(cls): def __init__(self, guild_id: int, member_id: int, role_id: int, role_name: str): super().__init__() - self.guild_id = guild_id - self.member_id = member_id - self.role_id = role_id - self.role_name = role_name + self.guild_id: int = guild_id + self.member_id: int = member_id + self.role_id: int = role_id + self.role_name: str = role_name @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["MissingRole"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -903,14 +908,14 @@ async def initial_create(cls): def __init__(self, guild_id: int, target_id: int, target_role_id: int, removal_ts: int, input_id: int = None): super().__init__() - self.id = input_id - self.guild_id = guild_id - self.target_id = target_id - self.target_role_id = target_role_id - self.removal_ts = removal_ts + self.id: int = input_id + self.guild_id: int = guild_id + self.target_id: int = target_id + self.target_role_id: int = target_role_id + self.removal_ts: int = removal_ts @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["TempRoleTimerRecords"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/starboard.py b/dozer/cogs/starboard.py index df554b62..030882ab 100644 --- a/dozer/cogs/starboard.py +++ b/dozer/cogs/starboard.py @@ -1,8 +1,9 @@ """Cog to post specific 'Hall of Fame' messages in a specific channel""" import asyncio +from typing import List, Set, Optional +from typing import TYPE_CHECKING import discord -from discord.ext import commands from discord.ext.commands import guild_only, has_permissions from discord.utils import escape_markdown from loguru import logger @@ -11,14 +12,16 @@ from ._utils import * from .. import db -MAX_EMBED = 1024 -LOCK_TIME = .1 -FORCE_TRY_TIME = 1 +if TYPE_CHECKING: + from dozer import Dozer -VIDEO_FORMATS = ['.mp4', '.mov', 'webm'] +MAX_EMBED: int = 1024 +LOCK_TIME: float = .1 +FORCE_TRY_TIME: int = 1 +VIDEO_FORMATS: List[str] = ['.mp4', '.mov', 'webm'] -async def is_cancelled(emoji, message: discord.Message, me, author: discord.Member = None): +async def is_cancelled(emoji: str, message: discord.Message, me: discord.Member, author: discord.Member = None) -> bool: """Determine if the message has cancellation reacts""" if author is None: author = message.author @@ -35,13 +38,13 @@ async def is_cancelled(emoji, message: discord.Message, me, author: discord.Memb return False -def make_starboard_embed(msg: discord.Message, reaction_count: int): +def make_starboard_embed(msg: discord.Message, reaction_count: int) -> discord.Embed: """Makes a starboard embed.""" - e = discord.Embed(color=msg.author.color, title=f"New Starred Message in #{msg.channel.name}", - description=msg.content, url=msg.jump_url) - e.set_author(name=escape_markdown(msg.author.display_name), icon_url=msg.author.display_avatar) + e: discord.Embed = discord.Embed(color=msg.author.color, title=f"New Starred Message in #{msg.channel.name}", + description=msg.content, url=msg.jump_url) + e.set_author(name=escape_markdown(msg.author.display_name), icon_url=msg.author.avatar) - view_link = f" [[view]]({msg.jump_url})" + view_link: str = f" [[view]]({msg.jump_url})" e.add_field(name="Link:", value=view_link) if len(msg.attachments) > 1: @@ -61,20 +64,20 @@ def make_starboard_embed(msg: discord.Message, reaction_count: int): class Starboard(Cog): """Cog to post specific 'Hall of Fame' messages in a specific channel""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - self.config_cache = db.ConfigCache(StarboardConfig) - self.locked_messages = set() + self.config_cache: db.ConfigCache = db.ConfigCache(StarboardConfig) + self.locked_messages: Set = set() - def make_config_embed(self, ctx: DozerContext, title, config): + def make_config_embed(self, ctx: DozerContext, title, config) -> discord.Embed: """Makes a config embed.""" - channel = self.bot.get_channel(config.channel_id) + channel: discord.TextChannel = self.bot.get_channel(config.channel_id) if channel is None: return discord.Embed(title="Starboard channel no longer exists!", description="Please reconfigure the starboard to fix this.", color=discord.colour.Color.red()) - e = discord.Embed(title=title, color=discord.Color.gold()) + e: discord.Embed = discord.Embed(title=title, color=discord.Color.gold()) e.add_field(name="Starboard Channel", value=channel.mention) e.add_field(name="Starboard Emoji", value=config.star_emoji) e.add_field(name="Cancel Emoji", value=config.cancel_emoji) @@ -82,7 +85,7 @@ def make_config_embed(self, ctx: DozerContext, title, config): e.set_footer(text=f"For more information, try {ctx.prefix}help starboard") return e - async def send_to_starboard(self, config, message: discord.Message, reaction_count: int, add_react: bool = True): + async def send_to_starboard(self, config: "StarboardConfig", message: discord.Message, reaction_count: int, add_react: bool = True): """Given a message which may or may not exist, send it to the starboard""" starboard_channel = message.guild.get_channel(config.channel_id) if starboard_channel is None: @@ -103,7 +106,7 @@ async def send_to_starboard(self, config, message: discord.Message, reaction_cou await message.add_reaction(config.star_emoji) else: try: - sent_msg = await self.bot.get_channel(config.channel_id).fetch_message(db_msgs[0].starboard_message_id) + sent_msg: discord.Message = await self.bot.get_channel(config.channel_id).fetch_message(db_msgs[0].starboard_message_id) except discord.errors.NotFound: # Uh oh! Starboard message was deleted. Let's try and delete it logger.warning(f"Cannot find Starboard Message {db_msgs[0].starboard_message_id} to update") @@ -114,13 +117,13 @@ async def send_to_starboard(self, config, message: discord.Message, reaction_cou async def remove_from_starboard(self, config, starboard_message: discord.Message, cancel: bool = False): """Given a starboard message or snowflake, remove that message and remove it from the DB""" - db_msgs = await StarboardMessage.get_by(starboard_message_id=starboard_message.id) + db_msgs: List[StarboardMessage] = await StarboardMessage.get_by(starboard_message_id=starboard_message.id) if len(db_msgs): if hasattr(starboard_message, 'delete'): await starboard_message.delete() if cancel: try: - orig_msg = await self.bot.get_channel(db_msgs[0].channel_id).fetch_message(db_msgs[0].message_id) + orig_msg: discord.Message = await self.bot.get_channel(db_msgs[0].channel_id).fetch_message(db_msgs[0].message_id) await orig_msg.add_reaction(config.cancel_emoji) except discord.NotFound: pass @@ -128,11 +131,11 @@ async def remove_from_starboard(self, config, starboard_message: discord.Message async def starboard_check(self, reaction: discord.Reaction, member: discord.Member): """Provides all logic for checking and updating the Starboard""" - msg = reaction.message + msg: discord.Message = reaction.message if not msg.guild: return - config = await self.config_cache.query_one(guild_id=msg.guild.id) + config: StarboardConfig = await self.config_cache.query_one(guild_id=msg.guild.id) if config is None: return @@ -273,8 +276,8 @@ async def config(self, ctx: DozerContext, channel: discord.TextChannel, f"{ctx.me.name} is in.") return - config = StarboardConfig(guild_id=ctx.guild.id, channel_id=channel.id, star_emoji=str(star_emoji), - threshold=threshold, cancel_emoji=str(cancel_emoji)) + config: StarboardConfig = StarboardConfig(guild_id=ctx.guild.id, channel_id=channel.id, star_emoji=str(star_emoji), + threshold=threshold, cancel_emoji=str(cancel_emoji)) await config.update_or_add() self.config_cache.invalidate_entry(guild_id=ctx.guild.id) @@ -290,7 +293,7 @@ async def config(self, ctx: DozerContext, channel: discord.TextChannel, @starboard.command() async def disable(self, ctx: DozerContext): """Turn off the starboard if it is enabled""" - config = await StarboardConfig.get_by(guild_id=ctx.guild.id) + config: List[StarboardConfig] = await StarboardConfig.get_by(guild_id=ctx.guild.id) if not config: await ctx.send("There is not a Starboard set up for this server.") return @@ -306,10 +309,11 @@ async def disable(self, ctx: DozerContext): @guild_only() @has_permissions(manage_messages=True) @starboard.command() - async def add(self, ctx: DozerContext, message_id, channel: discord.TextChannel = None): + async def add(self, ctx: DozerContext, message_id: str, channel: discord.TextChannel = None): """Add a message to the starboard manually""" - message_id = int(message_id) - config = await self.config_cache.query_one(guild_id=ctx.guild.id) + message_id: int = int(message_id) + config: StarboardConfig = await self.config_cache.query_one(guild_id=ctx.guild.id) + if config is None: await ctx.send(f"There is not a Starboard configured for this server. Set one up with " f"`{ctx.prefix}starboard config`") @@ -318,7 +322,7 @@ async def add(self, ctx: DozerContext, message_id, channel: discord.TextChannel channel = ctx.channel try: - msg = await channel.fetch_message(message_id) + msg: discord.Message = await channel.fetch_message(message_id) for reaction in msg.reactions: if str(reaction) != config.star_emoji: await self.send_to_starboard(config, msg, reaction.count, False) @@ -360,16 +364,16 @@ async def initial_create(cls): threshold bigint NOT NULL )""") - def __init__(self, guild_id: int, channel_id: int, star_emoji: str, threshold: int, cancel_emoji: str = None): + def __init__(self, guild_id: int, channel_id: int, star_emoji: str, threshold: int, cancel_emoji: Optional[str] = None): super().__init__() - self.guild_id = guild_id - self.channel_id = channel_id - self.star_emoji = star_emoji - self.cancel_emoji = cancel_emoji - self.threshold = threshold + self.guild_id: int = guild_id + self.channel_id: int = channel_id + self.star_emoji: str = star_emoji + self.cancel_emoji: Optional[str] = cancel_emoji + self.threshold: int = threshold @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["StarboardConfig"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -399,13 +403,13 @@ async def initial_create(cls): def __init__(self, message_id: int, channel_id: int, starboard_message_id: int, author_id: int): super().__init__() - self.message_id = message_id - self.channel_id = channel_id - self.starboard_message_id = starboard_message_id - self.author_id = author_id + self.message_id: int = message_id + self.channel_id: int = channel_id + self.starboard_message_id: int = starboard_message_id + self.author_id: int = author_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["StarboardMessage"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/cogs/tba.py b/dozer/cogs/tba.py index 9ea0a311..01016fe1 100755 --- a/dozer/cogs/tba.py +++ b/dozer/cogs/tba.py @@ -1,38 +1,44 @@ """A series of commands that talk to The Blue Alliance.""" -import datetime +from datetime import datetime, timedelta import io import itertools import json from pprint import pformat +from typing import TYPE_CHECKING, List, Union, Dict from urllib.parse import quote as urlquote, urljoin import aiohttp import aiotba import async_timeout -import googlemaps import discord -from discord.ext import commands +import googlemaps +from aiotba.models import Media +from discord import Embed from discord.ext.commands import BadArgument from discord.utils import escape_markdown from geopy.geocoders import Nominatim +from tbapi import Event, Team from dozer.context import DozerContext from ._utils import * +if TYPE_CHECKING: + from dozer import Dozer + class TBA(Cog): """Commands that talk to The Blue Alliance""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - tba_config = bot.config['tba'] - self.gmaps_key = bot.config['gmaps_key'] - self.http_session = aiohttp.ClientSession() - self.session = aiotba.TBASession(tba_config['key'], self.http_session) + tba_config: Dict[str, str] = bot.config['tba'] + self.gmaps_key: str = bot.config['gmaps_key'] + self.http_session: aiohttp.ClientSession = aiohttp.ClientSession() + self.session: aiotba.TBASession = aiotba.TBASession(tba_config['key'], self.http_session) # self.parser = tbapi.TBAParser(tba_config['key'], cache=False) - col = discord.Color.from_rgb(63, 81, 181) + col: discord.Colour = discord.Color.from_rgb(63, 81, 181) @group(invoke_without_command=True) async def tba(self, ctx: DozerContext, team_num: int): @@ -64,9 +70,9 @@ async def team(self, ctx: DozerContext, team_num: int): team_district = max(team_district_data, key=lambda d: d.year) except aiotba.http.AioTBAError: team_district_data = None - e = discord.Embed(color=self.col, - title='FIRST® Robotics Competition Team {}'.format(team_num), - url='https://www.thebluealliance.com/team/{}'.format(team_num)) + e: Embed = Embed(color=self.col, + title='FIRST® Robotics Competition Team {}'.format(team_num), + url='https://www.thebluealliance.com/team/{}'.format(team_num)) e.set_thumbnail(url='https://frcavatars.herokuapp.com/get_image?team={}'.format(team_num)) e.add_field(name='Name', value=team_data.nickname) e.add_field(name='Rookie Year', value=team_data.rookie_year) @@ -93,15 +99,15 @@ async def eventsfor(self, ctx: DozerContext, team_num: int, year: int = None): if year is None: year = (await self.session.status()).current_season try: - events = await self.session.team_events(team_num, year=year) + events: List[Event] = await self.session.team_events(team_num, year=year) except aiotba.http.AioTBAError: raise BadArgument("Couldn't find matching data!") if not events: raise BadArgument("Couldn't find matching data!") - e = discord.Embed(color=self.col) - events = "\n".join(i.name for i in events) + e: Embed = Embed(color=self.col) + events: str = "\n".join(i.name for i in events) e.title = f"Registered events for FRC Team {team_num} in {year}:" e.description = events await ctx.send(embed=e) @@ -115,11 +121,11 @@ async def eventsfor(self, ctx: DozerContext, team_num: int, year: int = None): async def media(self, ctx: DozerContext, team_num: int, year: int = None): """Get media of a team for a given year. Defaults to current year.""" if year is None: - year = datetime.datetime.today().year + year = datetime.today().year try: - team_media = await self.session.team_media(team_num, year) + team_media: List[Media] = await self.session.team_media(team_num, year) - pages = [] + pages: List[Embed] = [] base = f"FRC Team {team_num} {year} Media: " for media in team_media: name, url, img_url = { @@ -151,7 +157,7 @@ async def media(self, ctx: DozerContext, team_num: int, year: int = None): }.get(media.type, (None, None, None)) media.details['foreign_key'] = media.foreign_key if name is not None: - page = discord.Embed(title="{}{}".format(base, name), url=url.format(**media.details)) + page: Embed = Embed(title="{}{}".format(base, name), url=url.format(**media.details)) page.set_image(url=img_url.format(**media.details)) pages.append(page) @@ -179,9 +185,9 @@ async def awards(self, ctx: DozerContext, team_num: int, year: int = None): except aiotba.http.AioTBAError: raise BadArgument("Couldn't find data for team {}".format(team_num)) - pages = [] + pages: List[Embed] = [] for award_year, awards in itertools.groupby(awards_data, lambda a: a.year): - e = discord.Embed(title=f"Awards for FRC Team {team_num} in {award_year}:", color=self.col) + e: Embed = Embed(title=f"Awards for FRC Team {team_num} in {award_year}:", color=self.col) for event_key, event_awards in itertools.groupby(list(awards), lambda a: a.event_key): event = event_key_map[event_key] e.add_field(name=f"{event.name} [{event_key}]", @@ -207,8 +213,8 @@ async def raw(self, ctx: DozerContext, team_num: int): This command is really only useful for development. """ try: - team_data = await self.session.team(team_num) - e = discord.Embed(color=self.col) + team_data: Team = await self.session.team(team_num) + e: Embed = Embed(color=self.col) e.set_author(name='FIRST® Robotics Competition Team {}'.format(team_num), url='https://www.thebluealliance.com/team/{}'.format(team_num), icon_url='https://frcavatars.herokuapp.com/get_image?team={}'.format(team_num)) @@ -235,7 +241,7 @@ async def weather(self, ctx: DozerContext, team_program: str, team_num: int): if team_program.lower() == "frc": try: - td = await self.session.team(team_num) + td: Union["TBA.TeamData", Team] = await self.session.team(team_num) except aiotba.http.AioTBAError: raise BadArgument('Team {} does not exist.'.format(team_num)) elif team_program.lower() == "ftc": @@ -248,18 +254,18 @@ async def weather(self, ctx: DozerContext, team_program: str, team_num: int): else: raise BadArgument('`team_program` should be one of [`frc`, `ftc`]') - units = 'm' + units: str = 'm' if td.country == "USA": td.country = "United States of America" units = 'u' - url = "https://wttr.in/{}".format( + url: str = "https://wttr.in/{}".format( urlquote("{}+{}+{}_0_{}.png".format(td.city, td.state_prov, td.country, units))) async with ctx.typing(), self.http_session.get(url) as resp: image_data = io.BytesIO(await resp.read()) - file_name = f"weather_{team_program.lower()}{team_num}.png" - e = discord.Embed(title=f"Current weather for {team_program.upper()} Team {team_num}:", url=url) + file_name: str = f"weather_{team_program.lower()}{team_num}.png" + e: Embed = Embed(title=f"Current weather for {team_program.upper()} Team {team_num}:", url=url) e.set_image(url=f"attachment://{file_name}") e.set_footer(text="Powered by wttr.in and sometimes TBA") await ctx.send(embed=e, file=discord.File(image_data, file_name)) @@ -277,7 +283,7 @@ async def timezone(self, ctx: DozerContext, team_program: str, team_num: int): if team_program.lower() == "frc": try: - team_data = await self.session.team(team_num) + team_data: Union["TBA.TeamData", Team] = await self.session.team(team_num) except aiotba.http.AioTBAError: raise BadArgument('Team {} does not exist.'.format(team_num)) if team_data.city is None: @@ -293,18 +299,18 @@ async def timezone(self, ctx: DozerContext, team_program: str, team_num: int): else: raise BadArgument('`team_program` should be one of [`frc`, `ftc`]') - location = '{0.city}, {0.state_prov} {0.country}'.format(team_data) - gmaps = googlemaps.Client(key=self.gmaps_key) - geolocator = Nominatim(user_agent="Dozer Discord Bot") + location: str = '{0.city}, {0.state_prov} {0.country}'.format(team_data) + gmaps: googlemaps.Client = googlemaps.Client(key=self.gmaps_key) + geolocator: Nominatim = Nominatim(user_agent="Dozer Discord Bot") geolocation = geolocator.geocode(location) if self.gmaps_key and not self.bot.config['tz_url']: timezone = gmaps.timezone(location="{}, {}".format(geolocation.latitude, geolocation.longitude), language="json") - utc_offset = float(timezone["rawOffset"]) / 3600 + utc_offset: float = float(timezone["rawOffset"]) / 3600 if timezone["dstOffset"] == 3600: utc_offset += 1 - tzname = timezone["timeZoneName"] + tzname: str = timezone["timeZoneName"] else: async with async_timeout.timeout(5), self.bot.http_session.get(urljoin(base=self.bot.config['tz_url'], url="{}/{}".format( @@ -313,9 +319,9 @@ async def timezone(self, ctx: DozerContext, team_program: str, team_num: int): r.raise_for_status() data = await r.json() utc_offset = data["utc_offset"] - tzname = '`{}`'.format(data["tz"]) + tzname: str = '`{}`'.format(data["tz"]) - current_time = datetime.datetime.utcnow() + datetime.timedelta(hours=utc_offset) + current_time: datetime = datetime.utcnow() + timedelta(hours=utc_offset) await ctx.send("Timezone: {} UTC{}\n{}".format(tzname, utc_offset, current_time.strftime("Current Time: %I:%M:%S %p (%H:%M:%S)"))) diff --git a/dozer/cogs/teams.py b/dozer/cogs/teams.py index 1db48f1e..b78db2fc 100755 --- a/dozer/cogs/teams.py +++ b/dozer/cogs/teams.py @@ -1,17 +1,21 @@ """Commands for making and seeing robotics team associations.""" import json +from typing import List, Dict, Union, Tuple import discord from aiotba.http import AioTBAError +from discord import Color, Embed from discord.ext.commands import BadArgument, guild_only, has_permissions from discord.utils import escape_markdown +from tbapi import Team from dozer.context import DozerContext +from dozer import db + from ._utils import * -from .info import blurple -from .. import db -from ..Components.TeamNumbers import TeamNumbers + +blurple: Color = discord.Color.blurple() class Teams(Cog): @@ -20,9 +24,9 @@ class Teams(Cog): @command() async def setteam(self, ctx: DozerContext, team_type: str, team_number: int): """Sets an association with your team in the database.""" - team_type = team_type.casefold() - dbcheck = await TeamNumbers.get_by(user_id=ctx.author.id, team_type=team_type, team_number=team_number) - if not dbcheck: + team_type: str = team_type.casefold() + dbcheck: List[TeamNumbers] = await TeamNumbers.get_by(user_id=ctx.author.id, team_type=team_type, team_number=team_number) + if not len(dbcheck): await TeamNumbers(user_id=ctx.author.id, team_number=team_number, team_type=team_type).update_or_add() await ctx.send("Team number set!") else: @@ -35,8 +39,8 @@ async def setteam(self, ctx: DozerContext, team_type: str, team_number: int): @command() async def removeteam(self, ctx: DozerContext, team_type: str, team_number: int): """Removes an association with a team in the database.""" - team_type = team_type.casefold() - results = await TeamNumbers.get_by(user_id=ctx.author.id, team_type=team_type, team_number=team_number) + team_type: str = team_type.casefold() + results: List[TeamNumbers] = await TeamNumbers.get_by(user_id=ctx.author.id, team_type=team_type, team_number=team_number) if len(results) != 0: await TeamNumbers.delete(user_id=ctx.author.id, team_number=team_number, team_type=team_type) await ctx.send("Removed association with {} team {}".format(team_type, team_number)) @@ -53,11 +57,11 @@ async def teamsfor(self, ctx: DozerContext, user: discord.Member = None): """Allows you to see the teams for the mentioned user, or yourself if nobody is mentioned.""" if user is None: user = ctx.author - teams = await TeamNumbers.get_by(user_id=user.id) + teams: List[TeamNumbers] = await TeamNumbers.get_by(user_id=user.id) if len(teams) == 0: raise BadArgument("Couldn't find any team associations for that user!") else: - e = discord.Embed(type='rich') + e: Embed = Embed(type='rich') e.title = 'Teams for {}'.format(escape_markdown(user.display_name)) e.description = "Teams: \n" for i in teams: @@ -72,33 +76,34 @@ async def teamsfor(self, ctx: DozerContext, user: discord.Member = None): @guild_only() @bot_has_permissions(add_reactions=True, embed_links=True, read_message_history=True) - async def compcheck(self, ctx: DozerContext, event_type: str, event_key): + async def compcheck(self, ctx: DozerContext, event_type: str, event_key: str): """Allows you to see people in the Discord server that are going to a certain competition.""" if event_type.lower() == "frc": try: - teams_raw = await ctx.bot.get_cog("TBA").session.event_teams(event_key) - teams = [team.team_number for team in teams_raw] + teams_raw: List[Team] = await ctx.bot.get_cog("TBA").session.event_teams(event_key) + teams: List[int] = [team.team_number for team in teams_raw] except AioTBAError: raise BadArgument("Invalid event!") elif event_type.lower() == "ftc": - teams_raw = json.loads(await ctx.bot.get_cog("TOA").parser.req(f"/api/event/{event_key}/teams")) + teams_raw: List[Dict[str, Dict[str, Union[str, int]]]] = json.loads( + await ctx.bot.get_cog("TOA").parser.req(f"/api/event/{event_key}/teams")) try: - teams = [team['team']['team_number'] for team in teams_raw] + teams: List[int] = [team['team']['team_number'] for team in teams_raw] except TypeError: raise BadArgument("Invalid event!") else: raise BadArgument("Unknown event type!") - found_mems = False - embeds = [] + found_mems: bool = False + embeds: List[Embed] = [] for team in teams: - e = discord.Embed(type='rich') + e: Embed = Embed(type='rich') e.title = 'Members going to {}'.format(event_key) - members = await TeamNumbers.get_by(team_type=event_type.lower(), team_number=team) - memstr = "" + members: List[TeamNumbers] = await TeamNumbers.get_by(team_type=event_type.lower(), team_number=team) + memstr: str = "" for member in members: - mem = ctx.guild.get_member(member.user_id) + mem: discord.Member = ctx.guild.get_member(member.user_id) if mem is not None: - newmemstr = "{} {} \n".format(escape_markdown(mem.display_name), mem.mention) + newmemstr: str = "{} {} \n".format(escape_markdown(mem.display_name), mem.mention) if len(newmemstr + memstr) > 1023: e.add_field(name=f"Team {team}", value=memstr) memstr = "" @@ -107,7 +112,7 @@ async def compcheck(self, ctx: DozerContext, event_type: str, event_key): if len(memstr) > 0: if len(e.fields) == 25: embeds.append(e) - e = discord.Embed(type='rich') + e: Embed = Embed(type='rich') e.title = 'Members going to {}'.format(event_key) e.add_field(name=f"Team {team}", value=memstr) embeds.append(e) @@ -115,7 +120,7 @@ async def compcheck(self, ctx: DozerContext, event_type: str, event_key): await ctx.send("Couldn't find any team members for that event!") return else: - pagenum = 1 + pagenum: int = 1 for embed in embeds: embed.set_footer(text=f"Page {pagenum} of {len(embeds)}") pagenum += 1 @@ -130,17 +135,17 @@ async def compcheck(self, ctx: DozerContext, event_type: str, event_key): @guild_only() async def onteam(self, ctx: DozerContext, team_type: str, team_number: int): """Allows you to see who has associated themselves with a particular team.""" - team_type = team_type.casefold() - users = await TeamNumbers.get_by(team_type=team_type, team_number=team_number) + team_type: str = team_type.casefold() + users: List[TeamNumbers] = await TeamNumbers.get_by(team_type=team_type, team_number=team_number) if len(users) == 0: await ctx.send("Nobody on that team found!") else: - e = discord.Embed(type='rich') + e: Embed = Embed(type='rich') e.title = 'Users on team {}'.format(team_number) e.description = "Users: \n" - extra_mems = "" + extra_mems: str = "" for i in users: - user = ctx.guild.get_member(i.user_id) + user: discord.Member = ctx.guild.get_member(i.user_id) if user is not None: memstr = "{} {} \n".format(escape_markdown(user.display_name), user.mention) if len(e.description + memstr) > 2047: @@ -159,9 +164,9 @@ async def onteam(self, ctx: DozerContext, team_type: str, team_number: int): @guild_only() async def onteam_top(self, ctx: DozerContext): """Show the top 10 teams by number of members in this guild.""" - users = [mem.id for mem in ctx.guild.members] - counts = await TeamNumbers.top10(users) - embed = discord.Embed(title=f'Top teams in {ctx.guild.name}', color=discord.Color.blue()) + users: List[int] = [mem.id for mem in ctx.guild.members] + counts: List[Tuple[str, int, int]] = await TeamNumbers.top10(users) + embed: Embed = Embed(title=f'Top teams in {ctx.guild.name}', color=discord.Color.blue()) embed.description = '\n'.join( f'{type_.upper()} team {num} ({count} member{"s" if count > 1 else ""})' for (type_, num, count) in counts) await ctx.send(embed=embed) @@ -175,15 +180,15 @@ async def onteam_top(self, ctx: DozerContext): @has_permissions(manage_guild=True) async def toggleautoteam(self, ctx: DozerContext): """Toggles automatic adding of team association to member nicknames""" - settings = await AutoAssociation.get_by(guild_id=ctx.guild.id) - enabled = settings[0].team_on_join if settings else True - new_settings = AutoAssociation( + settings: List[AutoAssociation] = await AutoAssociation.get_by(guild_id=ctx.guild.id) + enabled: bool = settings[0].team_on_join if settings else True + new_settings: AutoAssociation = AutoAssociation( guild_id=ctx.guild.id, team_on_join=not enabled ) await new_settings.update_or_add() - e = discord.Embed(color=blurple) - modetext = "Enabled" if not enabled else "Disabled" + e: Embed = Embed(color=blurple) + modetext: str = "Enabled" if not enabled else "Disabled" e.add_field(name='Success!', value=f"Automatic adding of team association is currently: **{modetext}**") e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) await ctx.send(embed=e) @@ -191,12 +196,12 @@ async def toggleautoteam(self, ctx: DozerContext): @Cog.listener('on_member_join') async def on_member_join(self, member: discord.Member): """Adds a user's team association to their name when they join (if exactly 1 association)""" - settings = await AutoAssociation.get_by(guild_id=member.guild.id) - enabled = settings[0].team_on_join if settings else True + settings: List[AutoAssociation] = await AutoAssociation.get_by(guild_id=member.guild.id) + enabled: bool = settings[0].team_on_join if settings else True if member.guild.me.guild_permissions.manage_nicknames and enabled: - query = await TeamNumbers.get_by(user_id=member.id) + query: List[TeamNumbers] = await TeamNumbers.get_by(user_id=member.id) if len(query) == 1: - nick = "{} {}{}".format(member.display_name, query[0].team_type, query[0].team_number) + nick: str = "{} {}{}".format(member.display_name, query[0].team_type, query[0].team_number) if len(nick) <= 32: await member.edit(nick=nick) @@ -219,11 +224,11 @@ async def initial_create(cls): def __init__(self, guild_id: int, team_on_join: bool = True): super().__init__() - self.guild_id = guild_id - self.team_on_join = team_on_join + self.guild_id: int = guild_id + self.team_on_join: bool = team_on_join @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["AutoAssociation"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -232,6 +237,72 @@ async def get_by(cls, **kwargs): return result_list +class TeamNumbers(db.DatabaseTable): + """Database operations for tracking team associations.""" + __tablename__ = 'team_numbers' + __uniques__ = 'user_id, team_number, team_type' + + @classmethod + async def initial_create(cls): + """Create the table in the database""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + CREATE TABLE {cls.__tablename__} ( + user_id bigint NOT NULL, + team_number bigint NOT NULL, + team_type VARCHAR NOT NULL, + PRIMARY KEY (user_id, team_number, team_type) + )""") + + def __init__(self, user_id: int, team_number: int, team_type: str): + super().__init__() + self.user_id: int = user_id + self.team_number: int = team_number + self.team_type: str = team_type + + async def update_or_add(self): + """Assign the attribute to this object, then call this method to either insert the object if it doesn't exist in + the DB or update it if it does exist. It will update every column not specified in __uniques__.""" + # This is its own functions because all columns must be unique, which breaks the syntax of the other one + keys = [] + values = [] + for var, value in self.__dict__.items(): + # Done so that the two are guaranteed to be in the same order, which isn't true of keys() and values() + if value is not None: + keys.append(var) + values.append(value) + async with db.Pool.acquire() as conn: + statement = f""" + INSERT INTO {self.__tablename__} ({", ".join(keys)}) + VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) + """ + await conn.execute(statement, *values) + + @classmethod + async def get_by(cls, **kwargs) -> List["TeamNumbers"]: + results = await super().get_by(**kwargs) + result_list = [] + for result in results: + obj = TeamNumbers(user_id=result.get("user_id"), + team_number=result.get("team_number"), + team_type=result.get("team_type")) + result_list.append(obj) + return result_list + + # noinspection SqlResolve + @classmethod + async def top10(cls, user_ids) -> List[Tuple[str, int, int]]: + """Returns the top 10 team entries""" + query = f"""SELECT team_type, team_number, count(*) + FROM {cls.__tablename__} + WHERE user_id = ANY($1) --first param: list of user IDs + GROUP BY team_type, team_number + ORDER BY count DESC, team_type, team_number + LIMIT 10""" + async with db.Pool.acquire() as conn: + return await conn.fetch(query, user_ids) + + async def setup(bot): """Adds this cog to the main bot""" await bot.add_cog(Teams(bot)) diff --git a/dozer/cogs/toa.py b/dozer/cogs/toa.py index a32d7040..a5c40fa6 100755 --- a/dozer/cogs/toa.py +++ b/dozer/cogs/toa.py @@ -3,18 +3,21 @@ import json from asyncio import sleep from datetime import datetime +from typing import TYPE_CHECKING, Dict, List from urllib.parse import urljoin import aiohttp import async_timeout import discord -from discord.ext import commands from discord.utils import escape_markdown from dozer.context import DozerContext from ._utils import * -embed_color = discord.Color(0xf89808) +if TYPE_CHECKING: + from dozer import Dozer + +embed_color: discord.Color = discord.Color(0xf89808) class TOAParser: @@ -22,20 +25,20 @@ class TOAParser: A class to make async requests to The Orange Alliance. """ - def __init__(self, api_key: str, aiohttp_session, base_url: str = "https://theorangealliance.org/api/", + def __init__(self, api_key: str, aiohttp_session: aiohttp.ClientSession, base_url: str = "https://theorangealliance.org/api/", app_name: str = "Dozer", ratelimit: bool = True): - self.last_req = datetime.now() - self.ratelimit = ratelimit - self.base = base_url - self.http = aiohttp_session - self.headers = { + self.last_req: datetime = datetime.now() + self.ratelimit: bool = ratelimit + self.base: str = base_url + self.http: aiohttp.ClientSession = aiohttp_session + self.headers: Dict[str, str] = { 'X-Application-Origin': app_name, 'X-TOA-Key': api_key, 'Content-Type': 'application/json' } - async def req(self, endpoint): + async def req(self, endpoint: str): """Make an async request at the specified endpoint, waiting to let the ratelimit cool off.""" if self.ratelimit: # this will delay a request to avoid the ratelimit @@ -44,7 +47,7 @@ async def req(self, endpoint): self.last_req = now if diff < 2.2: # have a 200 ms fudge factor await sleep(2.2 - diff) - tries = 0 + tries: int = 0 while True: try: async with async_timeout.timeout(5) as _, self.http.get(urljoin(self.base, endpoint), @@ -59,10 +62,10 @@ async def req(self, endpoint): class TOA(Cog): """TOA commands""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: "Dozer"): super().__init__(bot) - self.http_session = aiohttp.ClientSession() - self.parser = TOAParser(bot.config['toa']['key'], self.http_session, app_name=bot.config['toa']['app_name']) + self.http_session: aiohttp.ClientSession = aiohttp.ClientSession() + self.parser: TOAParser = TOAParser(bot.config['toa']['key'], self.http_session, app_name=bot.config['toa']['app_name']) @group(invoke_without_command=True) async def toa(self, ctx: DozerContext, team_num: int): @@ -80,13 +83,13 @@ async def toa(self, ctx: DozerContext, team_num: int): @bot_has_permissions(embed_links=True) async def team(self, ctx: DozerContext, team_num: int): """Get information on an FTC team by number.""" - res = json.loads(await self.parser.req("team/" + str(team_num))) + res: List[Dict[str, str]] = json.loads(await self.parser.req("team/" + str(team_num))) if len(res) == 0: await ctx.send("This team does not have any data on it yet, or it does not exist!") return - team_data = res[0] + team_data: Dict[str, str] = res[0] - e = discord.Embed(color=embed_color) + e: discord.Embed = discord.Embed(color=embed_color) e.set_author(name='FIRST® Tech Challenge Team {}'.format(team_num), url='https://theorangealliance.org/teams/{}'.format(team_num), icon_url='https://theorangealliance.org/assets/imgs/favicon.png?v=1') diff --git a/dozer/cogs/voice.py b/dozer/cogs/voice.py index 8f9e45e6..76e6975d 100755 --- a/dozer/cogs/voice.py +++ b/dozer/cogs/voice.py @@ -1,5 +1,8 @@ """Provides commands for voice, currently only voice and text channel access bindings.""" +from typing import List, Optional + import discord +from discord import Embed from discord.ext.commands import has_permissions, BadArgument from discord.utils import escape_markdown @@ -13,10 +16,10 @@ class Voice(Cog): """Commands interacting with voice.""" @staticmethod - async def auto_ptt_check(voice_channel: discord.VoiceChannel): + async def auto_ptt_check(voice_channel: discord.VoiceState): """Handles voice activity when members join/leave voice channels""" total_users = len(voice_channel.channel.members) - config = await AutoPTT.get_by(channel_id=voice_channel.channel.id) + config: List[AutoPTT] = await AutoPTT.get_by(channel_id=voice_channel.channel.id) if config: everyone = voice_channel.channel.guild.default_role # grab the @everyone role perms = voice_channel.channel.overwrites_for(everyone) # grab the @everyone overwrites @@ -36,12 +39,12 @@ async def on_voice_state_update(self, member: discord.Member, before: discord.Vo # before and after are voice states if before.channel is not None: # leave event, take role - config = await Voicebinds.get_by(channel_id=before.channel.id) + config: List[Voicebinds] = await Voicebinds.get_by(channel_id=before.channel.id) if len(config) != 0: await member.remove_roles(member.guild.get_role(config[0].role_id)) if after.channel is not None: # join event, give role - config = await Voicebinds.get_by(channel_id=after.channel.id) + config: List[Voicebinds] = await Voicebinds.get_by(channel_id=after.channel.id) if len(config) != 0: await member.add_roles(member.guild.get_role(config[0].role_id)) @@ -66,14 +69,14 @@ async def on_PTT_check(self, member: discord.Member, before: discord.VoiceState, async def autoptt(self, ctx: DozerContext, voice_channel: discord.VoiceChannel, ptt_threshold: int): """Configures AutoPtt limit for when members join/leave voice channels ptt is enabled""" - e = discord.Embed(color=blurple) + e: Embed = Embed(color=blurple) e.set_footer(text='Triggered by ' + escape_markdown(ctx.author.display_name)) if ptt_threshold < 0: raise BadArgument('PTT threshold must be positive integer') if ptt_threshold == 0: - config = await AutoPTT.get_by(channel_id=voice_channel.id) + config: List[AutoPTT] = await AutoPTT.get_by(channel_id=voice_channel.id) if len(config) != 0: await AutoPTT.delete(channel_id=config[0].channel_id) e.add_field(name='Success!', value='AutoPTT has been disabled for voice channel "**{}**"' @@ -82,7 +85,7 @@ async def autoptt(self, ctx: DozerContext, voice_channel: discord.VoiceChannel, e.add_field(name='Error', value='AutoPTT has not been configured for voice channel "**{}**"' .format(voice_channel)) else: - ent = AutoPTT( + ent: AutoPTT = AutoPTT( channel_id=voice_channel.id, ptt_limit=ptt_threshold ) @@ -105,7 +108,7 @@ async def autoptt(self, ctx: DozerContext, voice_channel: discord.VoiceChannel, async def voicebind(self, ctx: DozerContext, voice_channel: discord.VoiceChannel, *, role: discord.Role): """Binds a voice channel with a role, so users joining voice channels will be given desired role(s).""" - config = await Voicebinds.get_by(channel_id=voice_channel.id) + config: List[Voicebinds] = await Voicebinds.get_by(channel_id=voice_channel.id) if len(config) != 0: config[0].guild_id = ctx.guild.id config[0].channel_id = voice_channel.id @@ -117,18 +120,17 @@ async def voicebind(self, ctx: DozerContext, voice_channel: discord.VoiceChannel await ctx.send("Role `{role}` will now be given to users in voice channel `{voice_channel}`!".format(role=role, voice_channel=voice_channel)) - voicebind.example_usage = """ - `{prefix}voicebind "General #1" voice-general-1` - sets up Dozer to give users `voice-general-1` when they join voice channel "General #1", which will be removed when they leave. - """ + voicebind.example_usage = '`{prefix}voicebind "General #1" voice-general-1` - sets up Dozer to give users `voice-general-1` ' \ + 'when they join voice channel "General #1", which will be removed when they leave. ' @command() @bot_has_permissions(manage_roles=True) @has_permissions(manage_roles=True) async def voiceunbind(self, ctx: DozerContext, voice_channel: discord.VoiceChannel): """Dissasociates a voice channel with a role previously binded with the voicebind command.""" - config = await Voicebinds.get_by(channel_id=voice_channel.id) + config: List[Voicebinds] = await Voicebinds.get_by(channel_id=voice_channel.id) if len(config) != 0: - role = ctx.guild.get_role(config[0].role_id) + role: discord.Role = ctx.guild.get_role(config[0].role_id) await Voicebinds.delete(id=config[0].id) await ctx.send( "Role `{role}` will no longer be given to users in voice channel `{voice_channel}`!".format( @@ -145,11 +147,11 @@ async def voiceunbind(self, ctx: DozerContext, voice_channel: discord.VoiceChann @bot_has_permissions(manage_roles=True) async def voicebindlist(self, ctx: DozerContext): """Lists all the voice channel to role bindings for the current server""" - embed = discord.Embed(title="List of voice bindings for \"{}\"".format(ctx.guild), color=discord.Color.blue()) + embed: Embed = Embed(title="List of voice bindings for \"{}\"".format(ctx.guild), color=discord.Color.blue()) for config in await Voicebinds.get_by(guild_id=ctx.guild.id): - channel = discord.utils.get(ctx.guild.voice_channels, id=config.channel_id) - role = ctx.guild.get_role(config.role_id) - embed.add_field(name=channel, value="`{}`".format(role)) + channel: discord.VoiceChannel = discord.utils.get(ctx.guild.voice_channels, id=config.channel_id) + role: discord.Role = ctx.guild.get_role(config.role_id) + embed.add_field(name=str(channel), value="`{}`".format(str(role))) await ctx.send(embed=embed) voicebindlist.example_usage = """ @@ -175,16 +177,16 @@ async def initial_create(cls): role_id bigint null )""") - def __init__(self, guild_id: int, channel_id: int, role_id: int, row_id: int = None): + def __init__(self, guild_id: int, channel_id: Optional[int], role_id: Optional[int], row_id: Optional[int] = None): super().__init__() if row_id is not None: - self.id = row_id - self.guild_id = guild_id - self.channel_id = channel_id - self.role_id = role_id + self.id: int = row_id + self.guild_id: int = guild_id + self.channel_id: Optional[int] = channel_id + self.role_id: Optional[int] = role_id @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["Voicebinds"]: results = await super().get_by(**kwargs) result_list = [] for result in results: @@ -211,13 +213,13 @@ async def initial_create(cls): ptt_limit bigint null )""") - def __init__(self, channel_id: int, ptt_limit: int): + def __init__(self, channel_id: int, ptt_limit: Optional[int]): super().__init__() - self.channel_id = channel_id - self.ptt_limit = ptt_limit + self.channel_id: int = channel_id + self.ptt_limit: Optional[int] = ptt_limit @classmethod - async def get_by(cls, **kwargs): + async def get_by(cls, **kwargs) -> List["AutoPTT"]: results = await super().get_by(**kwargs) result_list = [] for result in results: diff --git a/dozer/context.py b/dozer/context.py index cf76791f..78773960 100644 --- a/dozer/context.py +++ b/dozer/context.py @@ -1,13 +1,25 @@ """Class that holds dozercontext. """ +from typing import TYPE_CHECKING + +import discord from discord.ext import commands + from dozer import utils +if TYPE_CHECKING: + from dozer import Dozer + class DozerContext(commands.Context): """Cleans all messages before sending""" + bot: "Dozer" + # @property + # def bot(self) -> "Dozer": + # """Returns the bot, with the correct type. """ + # return super().bot - async def send(self, content=None, **kwargs): # pylint: disable=arguments-differ + async def send(self, content: str = None, **kwargs) -> discord.Message: # pylint: disable=arguments-differ """Make it so you cannot ping @.everyone when sending a message""" if content is not None: content = utils.clean(self, content, mass=True, member=False, role=False, channel=False) diff --git a/dozer/db.py b/dozer/db.py index e5c07f8c..ee7100bf 100755 --- a/dozer/db.py +++ b/dozer/db.py @@ -1,10 +1,11 @@ """Provides database storage for the Dozer Discord bot""" -from typing import List, Dict +from typing import List, Dict, Callable, Type, Tuple, Any import asyncpg +from asyncpg import Record from loguru import logger -Pool = None +Pool: asyncpg.Pool async def db_init(db_url): @@ -52,8 +53,8 @@ async def db_migrate(): class DatabaseTable: """Defines a database table""" __tablename__: str = '' - __versions__: List[int] = [] - __uniques__: List[str] = [] + __versions__: List[Callable] = [] + __uniques__: str = "" # Declare the migrate/create functions @classmethod @@ -88,7 +89,6 @@ async def update_or_add(self): updates = "" for key in keys: if key in self.__uniques__: - # Skip updating anything that has a unique constraint on it continue updates += f"{key} = EXCLUDED.{key}" if keys.index(key) == len(keys) - 1: @@ -111,7 +111,7 @@ async def update_or_add(self): """ await conn.execute(statement, *values) - def __repr__(self): + def __repr__(self) -> str: values = "" first = True for key, value in self.__dict__.items(): @@ -124,15 +124,15 @@ def __repr__(self): # Class Methods @classmethod - async def get_by(cls, **filters): + async def get_by(cls, **filters) -> List[Record]: """Get a list of all records matching the given column=value criteria. This will grab all attributes, it's more - efficent to write your own SQL queries than use this one, but for a simple query this is fine.""" + efficient to write your own SQL queries than use this one, but for a simple query this is fine.""" async with Pool.acquire() as conn: - statement = f"SELECT * FROM {cls.__tablename__}" + statement: str = f"SELECT * FROM {cls.__tablename__}" if filters: # note: this code relies on subsequent iterations of the same dict having the same iteration order. # This is an implementation detail of CPython 3.6 and a language guarantee in Python 3.7+. - conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) + conditions: str = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) statement = f"{statement} WHERE {conditions};" else: statement += ";" @@ -160,12 +160,12 @@ async def set_initial_version(cls): class ConfigCache: """Class that will reduce calls to sqlalchemy as much as possible. Has no growth limit (yet)""" - def __init__(self, table): - self.cache = {} - self.table = table + def __init__(self, table: Type[DatabaseTable]): + self.cache: Dict = {} + self.table: Type[DatabaseTable] = table @staticmethod - def _hash_dict(dic): + def _hash_dict(dic) -> Tuple[Tuple[Any, Any], ...]: """Makes a dict hashable by turning it into a tuple of tuples""" # sort the keys to make this repeatable; this allows consistency even when insertion order is different return tuple((k, dic[k]) for k in sorted(dic)) @@ -194,6 +194,6 @@ def invalidate_entry(self, **kwargs): if query_hash in self.cache: del self.cache[query_hash] - __versions__: Dict[str, int] = {} + __versions__: Dict[str, Callable] = {} - __uniques__: List[str] = [] + __uniques__: str = "" diff --git a/dozer/sources/AbstractSources.py b/dozer/sources/AbstractSources.py index d9f80afb..1cfb970f 100644 --- a/dozer/sources/AbstractSources.py +++ b/dozer/sources/AbstractSources.py @@ -1,7 +1,12 @@ """Provide helper classes and end classes for source data""" +from typing import TYPE_CHECKING + import aiohttp from discord.ext.commands import BadArgument +if TYPE_CHECKING: + from dozer import Dozer + class Source: """Abstract base class for a data source.""" @@ -10,10 +15,10 @@ class Source: short_name: str = "src" base_url: str = "" aliases: tuple = tuple() - description = "Description" - disabled = False + description: str = "Description" + disabled: bool = False - def __init__(self, aiohttp_session: aiohttp.ClientSession, bot): + def __init__(self, aiohttp_session: aiohttp.ClientSession, bot: "Dozer"): self.aliases = (self.full_name, self.short_name) self.http_session = aiohttp_session self.bot = bot diff --git a/dozer/sources/RSSSources.py b/dozer/sources/RSSSources.py index 280e2bbb..f3efb3e3 100644 --- a/dozer/sources/RSSSources.py +++ b/dozer/sources/RSSSources.py @@ -2,14 +2,18 @@ import datetime import re import xml.etree.ElementTree +from typing import List, TYPE_CHECKING import aiohttp import discord from .AbstractSources import Source +if TYPE_CHECKING: + from dozer import Dozer -def clean_html(raw_html): + +def clean_html(raw_html: str) -> str: """Clean all HTML tags. From https://stackoverflow.com/questions/9662346/python-code-to-remove-html-tags-from-a-string""" cleanr = re.compile('<.*?>') @@ -20,13 +24,13 @@ def clean_html(raw_html): class RSSSource(Source): """Given an arbitrary RSS feed, get new posts from it""" url: str = "" - color = discord.colour.Color.blurple() - date_formats = ["%a, %d %b %Y %H:%M:%S %z", - "%a, %d %b %Y %H:%M:%S %Z"] # format for datetime.strptime() + color: discord.Colour = discord.colour.Color.blurple() + date_formats: List[str] = ["%a, %d %b %Y %H:%M:%S %z", + "%a, %d %b %Y %H:%M:%S %Z"] # format for datetime.strptime() base_url: str = "" read_more_str: str = "...\n Read More" - def __init__(self, aiohttp_session: aiohttp.ClientSession, bot): + def __init__(self, aiohttp_session: aiohttp.ClientSession, bot: "Dozer"): super().__init__(aiohttp_session, bot) self.guids_seen: set = set() diff --git a/dozer/sources/RedditSource.py b/dozer/sources/RedditSource.py index 742590ec..4b223c22 100644 --- a/dozer/sources/RedditSource.py +++ b/dozer/sources/RedditSource.py @@ -1,12 +1,27 @@ """Get new posts from any arbitrary subreddit""" import datetime +from datetime import datetime, timedelta +from typing import Dict, Optional, TYPE_CHECKING import aiohttp -import discord +from discord import Embed, Colour from loguru import logger from .AbstractSources import DataBasedSource +if TYPE_CHECKING: + from dozer import Dozer + + +class SubReddit(DataBasedSource.DataPoint): + """Represents a single subreddit with associated detail""" + + def __init__(self, name: str, url: str, color: Colour): + super().__init__(name, url) + self.name: str = name + self.url: str = url + self.color: Colour = color + class RedditSource(DataBasedSource): """Get new posts from any arbitrary subreddit""" @@ -19,33 +34,24 @@ class RedditSource(DataBasedSource): token_url = "https://www.reddit.com/api/v1/access_token" api_url = "https://oauth.reddit.com/" backup_api_url = "https://reddit.com/" - color = discord.Color.from_rgb(255, 69, 0) - - class SubReddit(DataBasedSource.DataPoint): - """Represents a single subreddit with associated detail""" - - def __init__(self, name, url, color): - super().__init__(name, url) - self.name = name - self.url = url - self.color = color + color: Colour = Colour.from_rgb(255, 69, 0) - def __init__(self, aiohttp_session, bot): + def __init__(self, aiohttp_session, bot: "Dozer"): super().__init__(aiohttp_session, bot) - self.access_token = None - self.expiry_time = None - self.oauth_disabled = False - self.subreddits = {} - self.seen_posts = set() + self.access_token: Optional[str] = None + self.expiry_time: Optional[datetime] = None + self.oauth_disabled: bool = False + self.subreddits: Dict[str, SubReddit] = {} + self.seen_posts: set = set() async def get_token(self): """Using OAuth2, get a reddit bearer token. If this fails, fallback to non-oauth API""" client_id = self.bot.config['news']['reddit']['client_id'] client_secret = self.bot.config['news']['reddit']['client_secret'] - params = { + params: Dict[str, str] = { 'grant_type': 'client_credentials' } - auth = aiohttp.BasicAuth(client_id, client_secret) + auth: aiohttp.BasicAuth = aiohttp.BasicAuth(client_id, client_secret) response = await self.http_session.post(self.token_url, params=params, auth=auth) response = await response.json() try: @@ -56,9 +62,9 @@ async def get_token(self): self.oauth_disabled = True return - expiry_seconds = response['expires_in'] - time_delta = datetime.timedelta(seconds=expiry_seconds) - self.expiry_time = datetime.datetime.now() + time_delta + expiry_seconds: int = response['expires_in'] + time_delta: timedelta = timedelta(seconds=expiry_seconds) + self.expiry_time = datetime.now() + time_delta async def request(self, url, *args, headers=None, **kwargs): """Make a request using OAuth2 (or not, if it's been disabled)""" @@ -87,20 +93,20 @@ async def request(self, url, *args, headers=None, **kwargs): json = await response.json() return json - def create_subreddit_obj(self, data): + def create_subreddit_obj(self, data) -> SubReddit: """Given a dict, create a subreddit object""" - color = data['key_color'] - if "#" in color: - color = color.replace("#", "") + color_str: str = data['key_color'] + if "#" in color_str: + color_str = color_str.replace("#", "") try: - color = discord.Color(int(color, 16)) + color: Colour = Colour(int(color_str, 16)) except ValueError: - color = self.color + color: Colour = self.color - return RedditSource.SubReddit(data['display_name'], data['url'], color) + return SubReddit(data['display_name'], data['url'], color) - async def clean_data(self, text): + async def clean_data(self, text) -> SubReddit: """Make a request to the reddit API to verify the subreddit exists and clean it into a object""" try: return self.subreddits[text] @@ -172,7 +178,7 @@ async def first_run(self, data=None): self.subreddits[subreddit_obj.name] = subreddit_obj await self.get_new_posts(first_time=True) - async def get_new_posts(self, first_time=False): # pylint: disable=arguments-differ + async def get_new_posts(self, first_time: bool = False) -> Dict: # pylint: disable=arguments-differ """Make a API request for new posts and generate embed and strings for them""" if len(self.subreddits) == 0: return {} @@ -202,9 +208,9 @@ async def get_new_posts(self, first_time=False): # pylint: disable=arguments-di return posts - def generate_embed(self, data): + def generate_embed(self, data: Dict) -> Embed: """Given a dict of data, create a embed""" - embed = discord.Embed() + embed: Embed = Embed() embed.title = f"New post on {data['subreddit_name_prefixed']}!" embed.colour = self.subreddits[data['subreddit']].color @@ -226,12 +232,13 @@ def generate_embed(self, data): except KeyError: pass - time = datetime.datetime.utcfromtimestamp(data['created_utc']) + time = datetime.utcfromtimestamp(data['created_utc']) embed.timestamp = time return embed - def generate_plain_text(self, data): + @staticmethod + def generate_plain_text(data: Dict) -> str: """Given a dict of data, create a string""" return f"New post on {data['subreddit_name_prefixed']}: {data['title']}\n" \ f"Read more at https://reddit.com{data['permalink']}" diff --git a/dozer/sources/TwitchSource.py b/dozer/sources/TwitchSource.py index 70436aa7..b1d6be82 100644 --- a/dozer/sources/TwitchSource.py +++ b/dozer/sources/TwitchSource.py @@ -1,6 +1,7 @@ """News source to send a notification whenever a twitch streamer goes live.""" import datetime +from typing import TYPE_CHECKING import discord from dateutil import parser @@ -8,17 +9,20 @@ from .AbstractSources import DataBasedSource +if TYPE_CHECKING: + from dozer import Dozer + class TwitchSource(DataBasedSource): """News source to send a notification whenever a twitch streamer goes live.""" - full_name = "Twitch" - short_name = "twitch" - base_url = "https://twitch.tv" - description = "Makes a post whenever a specified user goes live on Twitch" + full_name: str = "Twitch" + short_name: str = "twitch" + base_url: str = "https://twitch.tv" + description: str = "Makes a post whenever a specified user goes live on Twitch" - token_url = "https://id.twitch.tv/oauth2/token" - api_url = "https://api.twitch.tv/helix" - color = discord.Color.from_rgb(145, 70, 255) + token_url: str = "https://id.twitch.tv/oauth2/token" + api_url: str = "https://api.twitch.tv/helix" + color: discord.Colour = discord.Color.from_rgb(145, 70, 255) class TwitchUser(DataBasedSource.DataPoint): """A helper class to represent a single Twitch streamer""" @@ -30,7 +34,7 @@ def __init__(self, user_id, display_name, profile_image_url, login): self.profile_image_url = profile_image_url self.login = login - def __init__(self, aiohttp_session, bot): + def __init__(self, aiohttp_session, bot: "Dozer"): super().__init__(aiohttp_session, bot) self.access_token = None self.client_id = None @@ -196,7 +200,8 @@ def generate_embed(self, data, games): return embed - def generate_plain_text(self, data, games): + @staticmethod + def generate_plain_text(data, games): """Given data on a stream and a dict of games, assemble a string""" try: display_name = data['display_name'] diff --git a/dozer/utils.py b/dozer/utils.py index a2220523..b62c8bb8 100755 --- a/dozer/utils.py +++ b/dozer/utils.py @@ -1,23 +1,30 @@ """Provides some useful utilities for the Discord bot, mostly to do with cleaning.""" import re +from re import Pattern +from typing import Optional, List, TYPE_CHECKING from urllib.parse import urlencode import discord +if TYPE_CHECKING: + from dozer.context import DozerContext -__all__ = ['clean', 'is_clean'] +__all__ = ['clean', 'is_clean', 'oauth_url', 'pretty_concat'] -mass_mention = re.compile('@(everyone|here)') -member_mention = re.compile(r'<@\!?(\d+)>') -role_mention = re.compile(r'<@&(\d+)>') -channel_mention = re.compile(r'<#(\d+)>') +mass_mention: Pattern = re.compile('@(everyone|here)') +member_mention: Pattern = re.compile(r'<@?(\d+)>') +role_mention: Pattern = re.compile(r'<@&(\d+)>') +channel_mention: Pattern = re.compile(r'<#(\d+)>') -def clean(ctx, text=None, *, mass=True, member=True, role=True, channel=True): +def clean(ctx: "DozerContext", text: Optional[str] = None, *, mass: bool = True, member: bool = True, role: bool = True, channel: bool = True) -> str: """Cleans the message of anything specified in the parameters passed.""" + if text is None: - text = ctx.message.content - cleaned_text = text + filter_text: str = ctx.message.content + else: + filter_text: str = text + cleaned_text: str = filter_text if mass: cleaned_text = mass_mention.sub(lambda match: '@\N{ZERO WIDTH SPACE}' + match.group(1), cleaned_text) if member: @@ -29,14 +36,14 @@ def clean(ctx, text=None, *, mass=True, member=True, role=True, channel=True): return cleaned_text -def is_clean(ctx, text=None): +def is_clean(ctx: "DozerContext", text: Optional[str] = None) -> bool: """Checks if the message is clean already and doesn't need to be cleaned.""" if text is None: text = ctx.message.content return all(regex.search(text) is None for regex in (mass_mention, member_mention, role_mention, channel_mention)) -def clean_member_name(ctx, member_id): +def clean_member_name(ctx: "DozerContext", member_id: int) -> str: """Cleans a member's name from the message.""" member = ctx.guild.get_member(member_id) if member is None: @@ -49,9 +56,9 @@ def clean_member_name(ctx, member_id): return '<@\N{ZERO WIDTH SPACE}%d>' % member.id -def clean_role_name(ctx, role_id): +def clean_role_name(ctx: "DozerContext", role_id: int) -> str: """Cleans role pings from messages.""" - role = discord.utils.get(ctx.guild.roles, id=role_id) # Guild.get_role doesn't exist + role: discord.Role = discord.utils.get(ctx.guild.roles, id=role_id) # Guild.get_role doesn't exist if role is None: return '<@&\N{ZERO WIDTH SPACE}%d>' % role_id elif is_clean(ctx, role.name): @@ -60,7 +67,7 @@ def clean_role_name(ctx, role_id): return '<@&\N{ZERO WIDTH SPACE}%d>' % role.id -def clean_channel_name(ctx, channel_id): +def clean_channel_name(ctx: "DozerContext", channel_id: int) -> str: """Cleans channel mentions from messages.""" channel = ctx.guild.get_channel(channel_id) if channel is None: @@ -71,7 +78,7 @@ def clean_channel_name(ctx, channel_id): return '<#\N{ZERO WIDTH SPACE}%d>' % channel.id -def pretty_concat(strings, single_suffix='', multi_suffix=''): +def pretty_concat(strings: List[str], single_suffix: str = '', multi_suffix: str = '') -> str: """Concatenates things in a pretty way""" if len(strings) == 1: return strings[0] + single_suffix @@ -81,7 +88,8 @@ def pretty_concat(strings, single_suffix='', multi_suffix=''): return '{}, and {}{}'.format(', '.join(strings[:-1]), strings[-1], multi_suffix) -def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None): +def oauth_url(client_id: str, permissions: Optional[discord.Permissions] = None, guild: Optional[discord.Guild] = None, + redirect_uri: Optional[str] = None) -> str: """A helper function that returns the OAuth2 URL for inviting the bot into guilds.