83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
# Copyright 2023 (c) Anna Schumaker
|
|
"""Helper class for working with sqlite3."""
|
|
import pathlib
|
|
import sqlite3
|
|
import sys
|
|
import xdg.BaseDirectory
|
|
|
|
|
|
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("xfstestsdb"))
# __debug__ is True unless Python runs with -O, so unoptimized runs write to
# a separate "-debug" database file (presumably to keep development data
# apart from real results).
DATA_FILE = DATA_DIR / ("xfstestsdb-debug.sqlite3" if __debug__
                        else "xfstestsdb.sqlite3")
# Whenever unittest has been imported, use a throwaway in-memory database
# instead of the on-disk file.
DATABASE = DATA_FILE if "unittest" not in sys.modules else ":memory:"
|
|
|
|
# SQL schema scripts live in a "scripts" directory next to this module.
SQL_SCRIPTS = pathlib.Path(__file__).parent.joinpath("scripts")
# Applied first when the database reports user_version 0.
SQL_V1_SCRIPT = SQL_SCRIPTS.joinpath("xfstestsdb.sql")
# Applied after the v1 script, or on its own when user_version is 1.
SQL_V2_SCRIPT = SQL_SCRIPTS.joinpath("upgrade-v2.sql")
|
|
|
|
|
|
class Connection:
|
|
"""Manages the sqlite3 Connection."""
|
|
|
|
def __init__(self):
|
|
"""Initialize a sqlite3 connection."""
|
|
self.sql = sqlite3.connect(DATABASE,
|
|
detect_types=sqlite3.PARSE_DECLTYPES)
|
|
self.sql.row_factory = sqlite3.Row
|
|
self.connected = True
|
|
|
|
self("PRAGMA foreign_keys = ON")
|
|
match self("PRAGMA user_version").fetchone()["user_version"]:
|
|
case 0:
|
|
self.executescript(SQL_V1_SCRIPT)
|
|
self.executescript(SQL_V2_SCRIPT)
|
|
case 1:
|
|
self.executescript(SQL_V2_SCRIPT)
|
|
|
|
def __call__(self, statement: str,
|
|
*args, **kwargs) -> sqlite3.Cursor | None:
|
|
"""Execute a SQL statement."""
|
|
try:
|
|
sql_args = args if len(args) > 0 else kwargs
|
|
return self.sql.execute(statement, sql_args)
|
|
except sqlite3.IntegrityError:
|
|
return None
|
|
|
|
def __enter__(self) -> None:
|
|
"""Manually begin a transaction."""
|
|
self.sql.commit()
|
|
self.sql.execute("BEGIN")
|
|
|
|
def __exit__(self, exp_type, exp_value, traceback) -> bool:
|
|
"""Either commit or rollback an active transaction."""
|
|
if exp_type is None:
|
|
self.sql.commit()
|
|
else:
|
|
self.sql.rollback()
|
|
return exp_type is None
|
|
|
|
def close(self) -> None:
|
|
"""Close the database connection."""
|
|
if self.connected:
|
|
self.sql.execute("PRAGMA optimize")
|
|
self.sql.close()
|
|
self.connected = False
|
|
|
|
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
|
|
"""Execute a SQL statement with several arguments."""
|
|
try:
|
|
return self.sql.executemany(statement, args)
|
|
except sqlite3.IntegrityError:
|
|
return None
|
|
|
|
def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
|
|
"""Execute a SQL script."""
|
|
if script.is_file():
|
|
with open(script) as f:
|
|
cur = self.sql.executescript(f.read())
|
|
self.sql.commit()
|
|
return cur
|
|
|
|
def vacuum(self) -> sqlite3.Cursor | None:
|
|
"""Vacuum the database."""
|
|
return self.sql.execute("VACUUM")
|