xfstestsdb: Create a custom sqlite3 Connection manager
The connection manager is used to initialize the database and has a wrapper around the sqlite3.execute() and sqlite3.executemany() functions for easier argument passing. Additionally, it implements __enter__() and __exit__() functions to manually begin a transaction to calls can use the sqlite3 "RETURNING" clause. Signed-off-by: Anna Schumaker <anna@nowheycreamery.com>
This commit is contained in:
parent
c210eff9b9
commit
929c1dd5eb
109
tests/test_sqlite.py
Normal file
109
tests/test_sqlite.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2023 (c) Anna Schumaker
|
||||
"""Tests our database Connection object."""
|
||||
import pathlib
|
||||
import sqlite3
|
||||
import unittest
|
||||
import xfstestsdb.sqlite
|
||||
from xdg.BaseDirectory import save_data_path
|
||||
|
||||
|
||||
class TestConnection(unittest.TestCase):
|
||||
"""Test the database connection."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common variables."""
|
||||
self.sql = xfstestsdb.sqlite.Connection()
|
||||
|
||||
def test_paths(self):
|
||||
"""Check that path constants are pointing in the right places."""
|
||||
data_dir = pathlib.Path(save_data_path("xfstestsdb"))
|
||||
self.assertEqual(xfstestsdb.sqlite.DATA_DIR, data_dir)
|
||||
self.assertEqual(xfstestsdb.sqlite.DATA_FILE,
|
||||
data_dir / "xfstestsdb-debug.sqlite3")
|
||||
self.assertEqual(xfstestsdb.sqlite.DATABASE, ":memory:")
|
||||
|
||||
script = pathlib.Path(xfstestsdb.__file__).parent / "xfstestsdb.sql"
|
||||
self.assertEqual(xfstestsdb.sqlite.SQL_SCRIPT, script)
|
||||
|
||||
def test_foreign_keys(self):
|
||||
"""Test that foreign key constraints are enabled."""
|
||||
cur = self.sql("PRAGMA foreign_keys")
|
||||
self.assertEqual(cur.fetchone()["foreign_keys"], 1)
|
||||
|
||||
def test_version(self):
|
||||
"""Test checking the database schema version."""
|
||||
cur = self.sql("PRAGMA user_version")
|
||||
self.assertEqual(cur.fetchone()["user_version"], 1)
|
||||
|
||||
def test_connection(self):
|
||||
"""Check that the connection manager is initialized properly."""
|
||||
self.assertIsInstance(self.sql.sql, sqlite3.Connection)
|
||||
self.assertEqual(self.sql.sql.row_factory, sqlite3.Row)
|
||||
self.assertTrue(self.sql.connected)
|
||||
|
||||
def test_call(self):
|
||||
"""Test that the connection manager can run sql statements."""
|
||||
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
|
||||
self.sql("INSERT INTO test VALUES (?, ?)", 1, 2)
|
||||
cur = self.sql("SELECT * FROM test")
|
||||
self.assertIsInstance(cur, sqlite3.Cursor)
|
||||
row = cur.fetchone()
|
||||
self.assertIsInstance(row, sqlite3.Row)
|
||||
self.assertEqual(row["a"], 1)
|
||||
self.assertEqual(row["b"], 2)
|
||||
|
||||
def test_call_keyword(self):
|
||||
"""Test running a sql statement with keyword arguments."""
|
||||
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
|
||||
self.sql("INSERT INTO test VALUES (:a, :b)", a=1, b=2)
|
||||
cur = self.sql("SELECT * FROM test")
|
||||
self.assertIsInstance(cur, sqlite3.Cursor)
|
||||
row = cur.fetchone()
|
||||
self.assertIsInstance(row, sqlite3.Row)
|
||||
self.assertEqual(row["a"], 1)
|
||||
self.assertEqual(row["b"], 2)
|
||||
|
||||
def test_executemany(self):
|
||||
"""Test that the connection manager can run several statements."""
|
||||
self.sql("CREATE TABLE test (a INT, b INT)")
|
||||
self.sql.executemany("INSERT INTO test VALUES (?, ?)",
|
||||
(1, 2), (3, 4), (5, 6), (7, 8), (9, 0))
|
||||
rows = self.sql("SELECT * FROM test").fetchall()
|
||||
self.assertListEqual([(row["a"], row["b"]) for row in rows],
|
||||
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
|
||||
|
||||
def test_transaction(self):
|
||||
"""Test that we can manually start a transaction."""
|
||||
self.assertFalse(self.sql.sql.in_transaction)
|
||||
|
||||
with self.sql:
|
||||
self.assertTrue(self.sql.sql.in_transaction)
|
||||
self.sql("CREATE TABLE test_table (test TEXT)")
|
||||
self.sql("INSERT INTO test_table VALUES (?)", "Test")
|
||||
self.assertTrue(self.sql.sql.in_transaction)
|
||||
self.assertFalse(self.sql.sql.in_transaction)
|
||||
|
||||
cur = self.sql("SELECT COUNT(*) FROM test_table")
|
||||
self.assertEqual(cur.fetchone()["COUNT(*)"], 1)
|
||||
|
||||
def test_transaction_rollback(self):
|
||||
"""Test that errors roll back the transaction."""
|
||||
with self.assertRaises(Exception):
|
||||
with self.sql:
|
||||
self.sql("CREATE TABLE other_table (test TEXT)")
|
||||
self.sql("INSERT INTO other_table VALUES (?)", "Test")
|
||||
raise Exception("Test Exeption")
|
||||
|
||||
self.assertFalse(self.sql.sql.in_transaction)
|
||||
|
||||
with self.assertRaises(sqlite3.OperationalError):
|
||||
self.sql("SELECT COUNT(*) FROM other_table")
|
||||
|
||||
def test_close(self):
|
||||
"""Check closing the connection."""
|
||||
self.sql.close()
|
||||
self.assertFalse(self.sql.connected)
|
||||
with self.assertRaises(sqlite3.ProgrammingError):
|
||||
self.assertIsNone(self.sql("SELECT COUNT(*) FROM test_table"))
|
||||
|
||||
self.sql.close()
|
65
xfstestsdb/sqlite.py
Normal file
65
xfstestsdb/sqlite.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2023 (c) Anna Schumaker
|
||||
"""Helper class for working with sqlite3."""
|
||||
import pathlib
|
||||
import sqlite3
|
||||
import sys
|
||||
import xdg.BaseDirectory
|
||||
|
||||
|
||||
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("xfstestsdb"))
|
||||
DATA_FILE = DATA_DIR / f"xfstestsdb{'-debug' if __debug__ else ''}.sqlite3"
|
||||
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
|
||||
SQL_SCRIPT = pathlib.Path(__file__).parent / "xfstestsdb.sql"
|
||||
|
||||
|
||||
class Connection:
|
||||
"""Manages the sqlite3 Connection."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a sqlite3 connection."""
|
||||
self.sql = sqlite3.connect(DATABASE,
|
||||
detect_types=sqlite3.PARSE_DECLTYPES)
|
||||
self.sql.row_factory = sqlite3.Row
|
||||
self.connected = True
|
||||
|
||||
match self("PRAGMA user_version").fetchone()["user_version"]:
|
||||
case 0:
|
||||
with open(SQL_SCRIPT) as f:
|
||||
self.sql.executescript(f.read())
|
||||
self.sql.commit()
|
||||
|
||||
def __call__(self, statement: str,
|
||||
*args, **kwargs) -> sqlite3.Cursor | None:
|
||||
"""Execute a SQL statement."""
|
||||
try:
|
||||
sql_args = args if len(args) > 0 else kwargs
|
||||
return self.sql.execute(statement, sql_args)
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
||||
|
||||
def __enter__(self) -> None:
|
||||
"""Manually begin a transaction."""
|
||||
self.sql.commit()
|
||||
self.sql.execute("BEGIN")
|
||||
|
||||
def __exit__(self, exp_type, exp_value, traceback) -> bool:
|
||||
"""Either commit or rollback an active transaction."""
|
||||
if exp_type is None:
|
||||
self.sql.commit()
|
||||
else:
|
||||
self.sql.rollback()
|
||||
return exp_type is None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self.connected:
|
||||
self.sql.execute("PRAGMA optimize")
|
||||
self.sql.close()
|
||||
self.connected = False
|
||||
|
||||
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
|
||||
"""Execute a SQL statement with several arguments."""
|
||||
try:
|
||||
return self.sql.executemany(statement, args)
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
4
xfstestsdb/xfstestsdb.sql
Normal file
4
xfstestsdb/xfstestsdb.sql
Normal file
|
@ -0,0 +1,4 @@
|
|||
/* Copyright 2023 (c) Anna Schumaker */
|
||||
|
||||
PRAGMA foreign_keys = ON;
|
||||
PRAGMA user_version = 1;
|
Loading…
Reference in New Issue
Block a user