feat(aurora): made database.connect() into an async context manager

This commit is contained in:
Seaswimmer 2024-06-05 00:36:12 -04:00
parent eebddd6e89
commit 56a2f96a2d
Signed by: cswimr
GPG key ID: 5D671B5D03D65A7F
3 changed files with 11 additions and 4 deletions

View file

@ -77,7 +77,7 @@ class Aurora(commands.Cog):
"Invalid requester passed to red_delete_data_for_user: %s", requester "Invalid requester passed to red_delete_data_for_user: %s", requester
) )
def __init__(self, bot: Red): def __init__(self, bot: Red) -> None:
super().__init__() super().__init__()
self.bot = bot self.bot = bot
register_config(config) register_config(config)

View file

@ -27,7 +27,7 @@ class ImportAuroraView(ui.View):
"Deleting original table...", ephemeral=True "Deleting original table...", ephemeral=True
) )
async with connect() as database: async with await connect() as database:
query = f"DROP TABLE IF EXISTS moderation_{self.ctx.guild.id};" query = f"DROP TABLE IF EXISTS moderation_{self.ctx.guild.id};"
database.execute(query) database.execute(query)
database.commit() database.commit()

View file

@ -1,5 +1,7 @@
# pylint: disable=cyclic-import # pylint: disable=cyclic-import
import json import json
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
import aiosqlite import aiosqlite
from discord import Guild from discord import Guild
@ -8,13 +10,14 @@ from redbot.core import data_manager
from .logger import logger from .logger import logger
async def connect() -> aiosqlite.Connection: @asynccontextmanager
async def connect() -> AsyncGenerator[aiosqlite.Connection, Any, None]:
"""Connects to the SQLite database, and returns a connection object.""" """Connects to the SQLite database, and returns a connection object."""
try: try:
connection = await aiosqlite.connect( connection = await aiosqlite.connect(
database=data_manager.cog_data_path(raw_name="Aurora") / "aurora.db" database=data_manager.cog_data_path(raw_name="Aurora") / "aurora.db"
) )
return connection yield connection
except aiosqlite.OperationalError as e: except aiosqlite.OperationalError as e:
logger.error("Unable to access the SQLite database!\nError:\n%s", e.msg) logger.error("Unable to access the SQLite database!\nError:\n%s", e.msg)
@ -22,6 +25,10 @@ async def connect() -> aiosqlite.Connection:
f"Unable to access the SQLite Database!\n{e.msg}" f"Unable to access the SQLite Database!\n{e.msg}"
) from e ) from e
finally:
if connection:
await connection.close()
async def create_guild_table(guild: Guild): async def create_guild_table(guild: Guild):
database = await connect() database = await connect()