From f6a42b97d9a79940f1e0c992a192fcb8c3968d3a Mon Sep 17 00:00:00 2001 From: Seaswimmer Date: Wed, 5 Jun 2024 23:13:23 -0400 Subject: [PATCH] misc(aurora): various model changes --- aurora/models/base.py | 11 +++++++---- aurora/models/moderation.py | 20 ++++++++++++-------- aurora/models/partials.py | 21 +++++++++------------ 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/aurora/models/base.py b/aurora/models/base.py index e6005ee..7cf8ab1 100644 --- a/aurora/models/base.py +++ b/aurora/models/base.py @@ -1,5 +1,6 @@ -from typing import Any +from typing import Any, Optional +from discord import Guild from pydantic import BaseModel, ConfigDict from redbot.core.bot import Red @@ -17,12 +18,14 @@ class AuroraBaseModel(BaseModel): return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.model_dump(exclude={"bot"}), indent=indent, **kwargs) class AuroraGuildModel(AuroraBaseModel): - """Subclass of AuroraBaseModel that includes a guild_id attribute, and a modified to_json() method to match.""" + """Subclass of AuroraBaseModel that includes a guild_id attribute and a guild attribute, and a modified to_json() method to match.""" + model_config = ConfigDict(ignored_types=(Red, Guild), arbitrary_types_allowed=True) guild_id: int + guild: Optional[Guild] = None def dump(self) -> dict: - return self.model_dump(exclude={"bot", "guild_id"}) + return self.model_dump(exclude={"bot", "guild_id", "guild"}) def to_json(self, indent: int | None = None, file: Any | None = None, **kwargs): from ..utilities.json import dump, dumps # pylint: disable=cyclic-import - return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.model_dump(exclude={"bot", "guild_id"}), indent=indent, **kwargs) + return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.dump(), indent=indent, **kwargs) diff --git a/aurora/models/moderation.py b/aurora/models/moderation.py index 8887dd5..9cc534c 100644 --- a/aurora/models/moderation.py +++ b/aurora/models/moderation.py @@ -52,7 +52,7 @@ class Moderation(AuroraGuildModel): async def get_target(self) -> Union["PartialUser", "PartialChannel"]: if self.target_type == "USER": return await PartialUser.from_id(self.bot, self.target_id) - return await PartialChannel.from_id(self.bot, self.target_id) + return await PartialChannel.from_id(self.bot, self.target_id, self.guild) async def get_resolved_by(self) -> Optional["PartialUser"]: if self.resolved_by: @@ -61,7 +61,7 @@ class Moderation(AuroraGuildModel): async def get_role(self) -> Optional["PartialRole"]: if self.role_id: - return await PartialRole.from_id(self.bot, self.guild_id, self.role_id) + return await PartialRole.from_id(self.bot, self.guild, self.role_id) return None def __str__(self) -> str: @@ -118,12 +118,10 @@ class Moderation(AuroraGuildModel): await self.update() async def update(self) -> None: - from ..utilities.database import connect from ..utilities.json import dumps query = f"UPDATE moderation_{self.guild_id} SET timestamp = ?, moderation_type = ?, target_type = ?, moderator_id = ?, role_id = ?, duration = ?, end_timestamp = ?, reason = ?, resolved = ?, resolved_by = ?, resolve_reason = ?, expired = ?, changes = ?, metadata = ? WHERE moderation_id = ?;" - database = await connect() - await database.execute(query, ( + await self.execute(query, ( self.timestamp.timestamp(), self.moderation_type, self.target_type, @@ -140,10 +138,8 @@ class Moderation(AuroraGuildModel): dumps(self.metadata), self.moderation_id, )) - await database.commit() - await database.close() - logger.debug("Row updated in moderation_%s!\n%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s", + logger.verbose("Row updated in moderation_%s!\n%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s", self.moderation_id, self.guild_id, self.timestamp.timestamp(), @@ -164,6 +160,14 @@ class Moderation(AuroraGuildModel): @classmethod def from_dict(cls, bot: Red, data: dict) -> "Moderation": + if data.get("guild_id"): + try: + guild: discord.Guild = bot.get_guild(data["guild_id"]) + if not guild: + guild = bot.fetch_guild(data["guild_id"]) + except (discord.Forbidden, discord.HTTPException): + guild = None + data.update({"guild": guild}) return cls(bot=bot, **data) @classmethod diff --git a/aurora/models/partials.py b/aurora/models/partials.py index 48f4030..68b0261 100644 --- a/aurora/models/partials.py +++ b/aurora/models/partials.py @@ -1,4 +1,4 @@ -from discord import Forbidden, HTTPException, InvalidData, NotFound +from discord import ChannelType, Forbidden, Guild, HTTPException, InvalidData, NotFound from redbot.core.bot import Red from .base import AuroraBaseModel, AuroraGuildModel @@ -31,6 +31,7 @@ class PartialUser(AuroraBaseModel): class PartialChannel(AuroraGuildModel): id: int name: str + type: ChannelType @property def mention(self): @@ -42,17 +43,17 @@ class PartialChannel(AuroraGuildModel): return self.mention @classmethod - async def from_id(cls, bot: Red, channel_id: int) -> "PartialChannel": + async def from_id(cls, bot: Red, channel_id: int, guild: Guild) -> "PartialChannel": channel = bot.get_channel(channel_id) if not channel: try: channel = await bot.fetch_channel(channel_id) - return cls(bot=bot, guild_id=channel.guild.id, id=channel.id, name=channel.name) + return cls(bot=bot, guild_id=channel.guild.id, guild=guild, id=channel.id, name=channel.name, type=channel.type) except (NotFound, InvalidData, HTTPException, Forbidden) as e: if e == Forbidden: return cls(bot=bot, guild_id=0, id=channel_id, name="Forbidden Channel") - return cls(bot=bot, guild_id=0, id=channel_id, name="Deleted Channel") - return cls(bot=bot, guild_id=channel.guild.id, id=channel.id, name=channel.name) + return cls(bot=bot, guild_id=0, id=channel_id, name="Deleted Channel", type=ChannelType.text) + return cls(bot=bot, guild_id=channel.guild.id, guild=guild, id=channel.id, name=channel.name, type=channel.type) class PartialRole(AuroraGuildModel): id: int @@ -68,12 +69,8 @@ class PartialRole(AuroraGuildModel): return self.mention @classmethod - async def from_id(cls, bot: Red, guild_id: int, role_id: int) -> "PartialRole": - try: - guild = await bot.fetch_guild(guild_id, with_counts=False) - except (Forbidden, HTTPException): - return cls(bot=bot, guild_id=guild_id, id=role_id, name="Forbidden Role") + async def from_id(cls, bot: Red, guild: Guild, role_id: int) -> "PartialRole": role = guild.get_role(role_id) if not role: - return cls(bot=bot, guild_id=guild_id, id=role_id, name="Deleted Role") - return cls(bot=bot, guild_id=guild_id, id=role.id, name=role.name) + return cls(bot=bot, guild_id=guild.id, id=role_id, name="Deleted Role") + return cls(bot=bot, guild_id=guild.id, id=role.id, name=role.name)