misc(aurora): various model changes

This commit is contained in:
Seaswimmer 2024-06-05 23:13:23 -04:00
parent 5cb61ecd65
commit f6a42b97d9
Signed by untrusted user: cswimr
GPG key ID: 5D671B5D03D65A7F
3 changed files with 28 additions and 24 deletions

View file

@ -1,5 +1,6 @@
from typing import Any from typing import Any, Optional
from discord import Guild
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from redbot.core.bot import Red from redbot.core.bot import Red
@ -17,12 +18,14 @@ class AuroraBaseModel(BaseModel):
return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.model_dump(exclude={"bot"}), indent=indent, **kwargs) return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.model_dump(exclude={"bot"}), indent=indent, **kwargs)
class AuroraGuildModel(AuroraBaseModel): class AuroraGuildModel(AuroraBaseModel):
"""Subclass of AuroraBaseModel that includes a guild_id attribute, and a modified to_json() method to match.""" """Subclass of AuroraBaseModel that includes a guild_id attribute and a guild attribute, and a modified to_json() method to match."""
model_config = ConfigDict(ignored_types=(Red, Guild), arbitrary_types_allowed=True)
guild_id: int guild_id: int
guild: Optional[Guild] = None
def dump(self) -> dict: def dump(self) -> dict:
return self.model_dump(exclude={"bot", "guild_id"}) return self.model_dump(exclude={"bot", "guild_id", "guild"})
def to_json(self, indent: int | None = None, file: Any | None = None, **kwargs): def to_json(self, indent: int | None = None, file: Any | None = None, **kwargs):
from ..utilities.json import dump, dumps # pylint: disable=cyclic-import from ..utilities.json import dump, dumps # pylint: disable=cyclic-import
return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.model_dump(exclude={"bot", "guild_id"}), indent=indent, **kwargs) return dump(self.dump(), file, indent=indent, **kwargs) if file else dumps(self.dump(), indent=indent, **kwargs)

View file

@ -52,7 +52,7 @@ class Moderation(AuroraGuildModel):
async def get_target(self) -> Union["PartialUser", "PartialChannel"]: async def get_target(self) -> Union["PartialUser", "PartialChannel"]:
if self.target_type == "USER": if self.target_type == "USER":
return await PartialUser.from_id(self.bot, self.target_id) return await PartialUser.from_id(self.bot, self.target_id)
return await PartialChannel.from_id(self.bot, self.target_id) return await PartialChannel.from_id(self.bot, self.target_id, self.guild)
async def get_resolved_by(self) -> Optional["PartialUser"]: async def get_resolved_by(self) -> Optional["PartialUser"]:
if self.resolved_by: if self.resolved_by:
@ -61,7 +61,7 @@ class Moderation(AuroraGuildModel):
async def get_role(self) -> Optional["PartialRole"]: async def get_role(self) -> Optional["PartialRole"]:
if self.role_id: if self.role_id:
return await PartialRole.from_id(self.bot, self.guild_id, self.role_id) return await PartialRole.from_id(self.bot, self.guild, self.role_id)
return None return None
def __str__(self) -> str: def __str__(self) -> str:
@ -118,12 +118,10 @@ class Moderation(AuroraGuildModel):
await self.update() await self.update()
async def update(self) -> None: async def update(self) -> None:
from ..utilities.database import connect
from ..utilities.json import dumps from ..utilities.json import dumps
query = f"UPDATE moderation_{self.guild_id} SET timestamp = ?, moderation_type = ?, target_type = ?, moderator_id = ?, role_id = ?, duration = ?, end_timestamp = ?, reason = ?, resolved = ?, resolved_by = ?, resolve_reason = ?, expired = ?, changes = ?, metadata = ? WHERE moderation_id = ?;" query = f"UPDATE moderation_{self.guild_id} SET timestamp = ?, moderation_type = ?, target_type = ?, moderator_id = ?, role_id = ?, duration = ?, end_timestamp = ?, reason = ?, resolved = ?, resolved_by = ?, resolve_reason = ?, expired = ?, changes = ?, metadata = ? WHERE moderation_id = ?;"
database = await connect() await self.execute(query, (
await database.execute(query, (
self.timestamp.timestamp(), self.timestamp.timestamp(),
self.moderation_type, self.moderation_type,
self.target_type, self.target_type,
@ -140,10 +138,8 @@ class Moderation(AuroraGuildModel):
dumps(self.metadata), dumps(self.metadata),
self.moderation_id, self.moderation_id,
)) ))
await database.commit()
await database.close()
logger.debug("Row updated in moderation_%s!\n%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s", logger.verbose("Row updated in moderation_%s!\n%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s",
self.moderation_id, self.moderation_id,
self.guild_id, self.guild_id,
self.timestamp.timestamp(), self.timestamp.timestamp(),
@ -164,6 +160,14 @@ class Moderation(AuroraGuildModel):
@classmethod @classmethod
def from_dict(cls, bot: Red, data: dict) -> "Moderation": def from_dict(cls, bot: Red, data: dict) -> "Moderation":
if data.get("guild_id"):
try:
guild: discord.Guild = bot.get_guild(data["guild_id"])
if not guild:
guild = bot.fetch_guild(data["guild_id"])
except (discord.Forbidden, discord.HTTPException):
guild = None
data.update({"guild": guild})
return cls(bot=bot, **data) return cls(bot=bot, **data)
@classmethod @classmethod

View file

@ -1,4 +1,4 @@
from discord import Forbidden, HTTPException, InvalidData, NotFound from discord import ChannelType, Forbidden, Guild, HTTPException, InvalidData, NotFound
from redbot.core.bot import Red from redbot.core.bot import Red
from .base import AuroraBaseModel, AuroraGuildModel from .base import AuroraBaseModel, AuroraGuildModel
@ -31,6 +31,7 @@ class PartialUser(AuroraBaseModel):
class PartialChannel(AuroraGuildModel): class PartialChannel(AuroraGuildModel):
id: int id: int
name: str name: str
type: ChannelType
@property @property
def mention(self): def mention(self):
@ -42,17 +43,17 @@ class PartialChannel(AuroraGuildModel):
return self.mention return self.mention
@classmethod @classmethod
async def from_id(cls, bot: Red, channel_id: int) -> "PartialChannel": async def from_id(cls, bot: Red, channel_id: int, guild: Guild) -> "PartialChannel":
channel = bot.get_channel(channel_id) channel = bot.get_channel(channel_id)
if not channel: if not channel:
try: try:
channel = await bot.fetch_channel(channel_id) channel = await bot.fetch_channel(channel_id)
return cls(bot=bot, guild_id=channel.guild.id, id=channel.id, name=channel.name) return cls(bot=bot, guild_id=channel.guild.id, guild=guild, id=channel.id, name=channel.name, type=channel.type)
except (NotFound, InvalidData, HTTPException, Forbidden) as e: except (NotFound, InvalidData, HTTPException, Forbidden) as e:
if e == Forbidden: if e == Forbidden:
return cls(bot=bot, guild_id=0, id=channel_id, name="Forbidden Channel") return cls(bot=bot, guild_id=0, id=channel_id, name="Forbidden Channel")
return cls(bot=bot, guild_id=0, id=channel_id, name="Deleted Channel") return cls(bot=bot, guild_id=0, id=channel_id, name="Deleted Channel", type=ChannelType.text)
return cls(bot=bot, guild_id=channel.guild.id, id=channel.id, name=channel.name) return cls(bot=bot, guild_id=channel.guild.id, guild=guild, id=channel.id, name=channel.name, type=channel.type)
class PartialRole(AuroraGuildModel): class PartialRole(AuroraGuildModel):
id: int id: int
@ -68,12 +69,8 @@ class PartialRole(AuroraGuildModel):
return self.mention return self.mention
@classmethod @classmethod
async def from_id(cls, bot: Red, guild_id: int, role_id: int) -> "PartialRole": async def from_id(cls, bot: Red, guild: Guild, role_id: int) -> "PartialRole":
try:
guild = await bot.fetch_guild(guild_id, with_counts=False)
except (Forbidden, HTTPException):
return cls(bot=bot, guild_id=guild_id, id=role_id, name="Forbidden Role")
role = guild.get_role(role_id) role = guild.get_role(role_id)
if not role: if not role:
return cls(bot=bot, guild_id=guild_id, id=role_id, name="Deleted Role") return cls(bot=bot, guild_id=guild.id, id=role_id, name="Deleted Role")
return cls(bot=bot, guild_id=guild_id, id=role.id, name=role.name) return cls(bot=bot, guild_id=guild.id, id=role.id, name=role.name)