From 56a2f96a2d6be218db8d04d99906cad2f322399e Mon Sep 17 00:00:00 2001 From: Seaswimmer Date: Wed, 5 Jun 2024 00:36:12 -0400 Subject: [PATCH] feat(aurora): made database.connect() into an async context manager --- aurora/aurora.py | 2 +- aurora/importers/aurora.py | 2 +- aurora/utilities/database.py | 11 +++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/aurora/aurora.py b/aurora/aurora.py index fed5034..2332d4a 100644 --- a/aurora/aurora.py +++ b/aurora/aurora.py @@ -77,7 +77,7 @@ class Aurora(commands.Cog): "Invalid requester passed to red_delete_data_for_user: %s", requester ) - def __init__(self, bot: Red): + def __init__(self, bot: Red) -> None: super().__init__() self.bot = bot register_config(config) diff --git a/aurora/importers/aurora.py b/aurora/importers/aurora.py index f154dab..b2367a6 100644 --- a/aurora/importers/aurora.py +++ b/aurora/importers/aurora.py @@ -27,7 +27,7 @@ class ImportAuroraView(ui.View): "Deleting original table...", ephemeral=True ) - async with connect() as database: + async with await connect() as database: query = f"DROP TABLE IF EXISTS moderation_{self.ctx.guild.id};" database.execute(query) database.commit() diff --git a/aurora/utilities/database.py b/aurora/utilities/database.py index 3894bce..7d50d77 100644 --- a/aurora/utilities/database.py +++ b/aurora/utilities/database.py @@ -1,5 +1,7 @@ # pylint: disable=cyclic-import import json +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator import aiosqlite from discord import Guild @@ -8,13 +10,14 @@ from redbot.core import data_manager from .logger import logger -async def connect() -> aiosqlite.Connection: +@asynccontextmanager +async def connect() -> AsyncGenerator[aiosqlite.Connection, Any, None]: """Connects to the SQLite database, and returns a connection object.""" try: connection = await aiosqlite.connect( database=data_manager.cog_data_path(raw_name="Aurora") / "aurora.db" ) - return connection + yield connection except aiosqlite.OperationalError as e: logger.error("Unable to access the SQLite database!\nError:\n%s", e.msg) @@ -22,6 +25,10 @@ async def connect() -> aiosqlite.Connection: f"Unable to access the SQLite Database!\n{e.msg}" ) from e + finally: + if connection: + await connection.close() + async def create_guild_table(guild: Guild): database = await connect()