xfstestsdb: Create a custom sqlite3 Connection manager

The connection manager is used to initialize the database and has a
wrapper around the sqlite3.execute() and sqlite3.executemany() functions
for easier argument passing.

Additionally, it implements __enter__() and __exit__() functions to
manually begin a transaction to calls can use the sqlite3 "RETURNING"
clause.

Signed-off-by: Anna Schumaker <anna@nowheycreamery.com>
This commit is contained in:
Anna Schumaker 2023-01-31 16:17:58 -05:00
parent c210eff9b9
commit 929c1dd5eb
3 changed files with 178 additions and 0 deletions

109
tests/test_sqlite.py Normal file
View File

@ -0,0 +1,109 @@
# Copyright 2023 (c) Anna Schumaker
"""Tests our database Connection object."""
import pathlib
import sqlite3
import unittest
import xfstestsdb.sqlite
from xdg.BaseDirectory import save_data_path
class TestConnection(unittest.TestCase):
"""Test the database connection."""
def setUp(self):
"""Set up common variables."""
self.sql = xfstestsdb.sqlite.Connection()
def test_paths(self):
"""Check that path constants are pointing in the right places."""
data_dir = pathlib.Path(save_data_path("xfstestsdb"))
self.assertEqual(xfstestsdb.sqlite.DATA_DIR, data_dir)
self.assertEqual(xfstestsdb.sqlite.DATA_FILE,
data_dir / "xfstestsdb-debug.sqlite3")
self.assertEqual(xfstestsdb.sqlite.DATABASE, ":memory:")
script = pathlib.Path(xfstestsdb.__file__).parent / "xfstestsdb.sql"
self.assertEqual(xfstestsdb.sqlite.SQL_SCRIPT, script)
def test_foreign_keys(self):
"""Test that foreign key constraints are enabled."""
cur = self.sql("PRAGMA foreign_keys")
self.assertEqual(cur.fetchone()["foreign_keys"], 1)
def test_version(self):
"""Test checking the database schema version."""
cur = self.sql("PRAGMA user_version")
self.assertEqual(cur.fetchone()["user_version"], 1)
def test_connection(self):
"""Check that the connection manager is initialized properly."""
self.assertIsInstance(self.sql.sql, sqlite3.Connection)
self.assertEqual(self.sql.sql.row_factory, sqlite3.Row)
self.assertTrue(self.sql.connected)
def test_call(self):
"""Test that the connection manager can run sql statements."""
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
self.sql("INSERT INTO test VALUES (?, ?)", 1, 2)
cur = self.sql("SELECT * FROM test")
self.assertIsInstance(cur, sqlite3.Cursor)
row = cur.fetchone()
self.assertIsInstance(row, sqlite3.Row)
self.assertEqual(row["a"], 1)
self.assertEqual(row["b"], 2)
def test_call_keyword(self):
"""Test running a sql statement with keyword arguments."""
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
self.sql("INSERT INTO test VALUES (:a, :b)", a=1, b=2)
cur = self.sql("SELECT * FROM test")
self.assertIsInstance(cur, sqlite3.Cursor)
row = cur.fetchone()
self.assertIsInstance(row, sqlite3.Row)
self.assertEqual(row["a"], 1)
self.assertEqual(row["b"], 2)
def test_executemany(self):
"""Test that the connection manager can run several statements."""
self.sql("CREATE TABLE test (a INT, b INT)")
self.sql.executemany("INSERT INTO test VALUES (?, ?)",
(1, 2), (3, 4), (5, 6), (7, 8), (9, 0))
rows = self.sql("SELECT * FROM test").fetchall()
self.assertListEqual([(row["a"], row["b"]) for row in rows],
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
def test_transaction(self):
"""Test that we can manually start a transaction."""
self.assertFalse(self.sql.sql.in_transaction)
with self.sql:
self.assertTrue(self.sql.sql.in_transaction)
self.sql("CREATE TABLE test_table (test TEXT)")
self.sql("INSERT INTO test_table VALUES (?)", "Test")
self.assertTrue(self.sql.sql.in_transaction)
self.assertFalse(self.sql.sql.in_transaction)
cur = self.sql("SELECT COUNT(*) FROM test_table")
self.assertEqual(cur.fetchone()["COUNT(*)"], 1)
def test_transaction_rollback(self):
"""Test that errors roll back the transaction."""
with self.assertRaises(Exception):
with self.sql:
self.sql("CREATE TABLE other_table (test TEXT)")
self.sql("INSERT INTO other_table VALUES (?)", "Test")
raise Exception("Test Exeption")
self.assertFalse(self.sql.sql.in_transaction)
with self.assertRaises(sqlite3.OperationalError):
self.sql("SELECT COUNT(*) FROM other_table")
def test_close(self):
"""Check closing the connection."""
self.sql.close()
self.assertFalse(self.sql.connected)
with self.assertRaises(sqlite3.ProgrammingError):
self.assertIsNone(self.sql("SELECT COUNT(*) FROM test_table"))
self.sql.close()

65
xfstestsdb/sqlite.py Normal file
View File

@ -0,0 +1,65 @@
# Copyright 2023 (c) Anna Schumaker
"""Helper class for working with sqlite3."""
import pathlib
import sqlite3
import sys
import xdg.BaseDirectory
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("xfstestsdb"))
DATA_FILE = DATA_DIR / f"xfstestsdb{'-debug' if __debug__ else ''}.sqlite3"
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
SQL_SCRIPT = pathlib.Path(__file__).parent / "xfstestsdb.sql"
class Connection:
"""Manages the sqlite3 Connection."""
def __init__(self):
"""Initialize a sqlite3 connection."""
self.sql = sqlite3.connect(DATABASE,
detect_types=sqlite3.PARSE_DECLTYPES)
self.sql.row_factory = sqlite3.Row
self.connected = True
match self("PRAGMA user_version").fetchone()["user_version"]:
case 0:
with open(SQL_SCRIPT) as f:
self.sql.executescript(f.read())
self.sql.commit()
def __call__(self, statement: str,
*args, **kwargs) -> sqlite3.Cursor | None:
"""Execute a SQL statement."""
try:
sql_args = args if len(args) > 0 else kwargs
return self.sql.execute(statement, sql_args)
except sqlite3.IntegrityError:
return None
def __enter__(self) -> None:
"""Manually begin a transaction."""
self.sql.commit()
self.sql.execute("BEGIN")
def __exit__(self, exp_type, exp_value, traceback) -> bool:
"""Either commit or rollback an active transaction."""
if exp_type is None:
self.sql.commit()
else:
self.sql.rollback()
return exp_type is None
def close(self) -> None:
"""Close the database connection."""
if self.connected:
self.sql.execute("PRAGMA optimize")
self.sql.close()
self.connected = False
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
"""Execute a SQL statement with several arguments."""
try:
return self.sql.executemany(statement, args)
except sqlite3.IntegrityError:
return None

View File

@ -0,0 +1,4 @@
/* Copyright 2023 (c) Anna Schumaker */
PRAGMA foreign_keys = ON;
PRAGMA user_version = 1;