xfstestsdb/xfstestsdb/sqlite.py
Anna Schumaker 929c1dd5eb xfstestsdb: Create a custom sqlite3 Connection manager
The connection manager is used to initialize the database and has a
wrapper around the sqlite3.execute() and sqlite3.executemany() functions
for easier argument passing.

Additionally, it implements __enter__() and __exit__() functions to
manually begin a transaction so calls can use the sqlite3 "RETURNING"
clause.

Signed-off-by: Anna Schumaker <anna@nowheycreamery.com>
2023-02-15 11:52:41 -05:00

66 lines
2.1 KiB
Python

# Copyright 2023 (c) Anna Schumaker
"""Helper class for working with sqlite3."""
import pathlib
import sqlite3
import sys
import xdg.BaseDirectory
# Per-user data directory for "xfstestsdb", as reported by
# xdg.BaseDirectory.save_data_path().
# NOTE(review): pyxdg documents save_data_path() as also creating the
# directory, which would be an import-time side effect -- confirm.
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("xfstestsdb"))
# Database file inside DATA_DIR. "-debug" is appended to the base name
# whenever __debug__ is true (Python not started with -O), so debug and
# optimized runs use different files.
DATA_FILE = DATA_DIR / f"xfstestsdb{'-debug' if __debug__ else ''}.sqlite3"
# Use a throwaway in-memory database instead of DATA_FILE whenever the
# unittest module has already been imported into this process.
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
# Path of the "xfstestsdb.sql" script that sits next to this module.
SQL_SCRIPT = pathlib.Path(__file__).parent / "xfstestsdb.sql"
class Connection:
"""Manages the sqlite3 Connection."""
def __init__(self):
"""Initialize a sqlite3 connection."""
self.sql = sqlite3.connect(DATABASE,
detect_types=sqlite3.PARSE_DECLTYPES)
self.sql.row_factory = sqlite3.Row
self.connected = True
match self("PRAGMA user_version").fetchone()["user_version"]:
case 0:
with open(SQL_SCRIPT) as f:
self.sql.executescript(f.read())
self.sql.commit()
def __call__(self, statement: str,
*args, **kwargs) -> sqlite3.Cursor | None:
"""Execute a SQL statement."""
try:
sql_args = args if len(args) > 0 else kwargs
return self.sql.execute(statement, sql_args)
except sqlite3.IntegrityError:
return None
def __enter__(self) -> None:
"""Manually begin a transaction."""
self.sql.commit()
self.sql.execute("BEGIN")
def __exit__(self, exp_type, exp_value, traceback) -> bool:
"""Either commit or rollback an active transaction."""
if exp_type is None:
self.sql.commit()
else:
self.sql.rollback()
return exp_type is None
def close(self) -> None:
"""Close the database connection."""
if self.connected:
self.sql.execute("PRAGMA optimize")
self.sql.close()
self.connected = False
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
"""Execute a SQL statement with several arguments."""
try:
return self.sql.executemany(statement, args)
except sqlite3.IntegrityError:
return None