db: Create a base Connection manager
This is a wrapper around the sqlite3.Connection objct that adds some nice functionality to make working with SQL easier. I defined the following magic methods: * __enter__() to manually begin a transaction * __exit__() to commit or rollback a manual transaction * __call__() to execute a SQL statement with either positional or keyword arguments. Additionally: * I define a "CASEFOLD" function that can be used in queries to lowercase unicode text when searching. * I set foreign_keys = ON so foreign keys checking is always enabled * I provide an executemany() function for running running the same statement multiple times with different arguments. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
61fc252172
commit
deb4f3d252
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2022 (c) Anna Schumaker
|
||||
"""Easily work with our underlying sqlite3 database."""
|
||||
import sqlite3
|
||||
import sys
|
||||
from gi.repository import GObject
|
||||
from .. import gsetup
|
||||
|
||||
|
||||
DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3"
|
||||
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
|
||||
|
||||
|
||||
class Connection(GObject.GObject):
|
||||
"""Connect to the database."""
|
||||
|
||||
connected = GObject.Property(type=bool, default=True)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a sqlite connection."""
|
||||
super().__init__()
|
||||
self._sql = sqlite3.connect(DATABASE,
|
||||
detect_types=sqlite3.PARSE_DECLTYPES)
|
||||
self._sql.create_function("CASEFOLD", 1,
|
||||
lambda s: s.casefold() if s else None,
|
||||
deterministic=True)
|
||||
self._sql.row_factory = sqlite3.Row
|
||||
self._sql("PRAGMA foreign_keys = ON")
|
||||
|
||||
def __call__(self, statement: str,
|
||||
*args, **kwargs) -> sqlite3.Cursor | None:
|
||||
"""Execute a SQL statement."""
|
||||
try:
|
||||
return self._sql.execute(statement, args if len(args) else kwargs)
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Clean up before exiting."""
|
||||
self.close()
|
||||
|
||||
def __enter__(self) -> None:
|
||||
"""Begin a transaction."""
|
||||
if not self._sql.in_transaction:
|
||||
self._sql.commit()
|
||||
self._sql.execute("BEGIN")
|
||||
|
||||
def __exit__(self, exp_type, exp_value, traceback) -> bool:
|
||||
"""Either commit or rollback an active transaction."""
|
||||
if exp_type is None:
|
||||
self._sql.commit()
|
||||
else:
|
||||
self._sql.rollback()
|
||||
return exp_type is None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self.connected:
|
||||
self._sql.commit()
|
||||
self._sql.execute("PRAGMA optimize")
|
||||
self._sql.close()
|
||||
self.connected = False
|
||||
|
||||
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
|
||||
"""Execute several similar SQL statements at once."""
|
||||
try:
|
||||
return self._sql.executemany(statement, args)
|
||||
except sqlite3.InternalError:
|
||||
return None
|
|
@ -2,7 +2,9 @@
|
|||
"""Set up GObject Introspection, and custom styling, and icons."""
|
||||
import pathlib
|
||||
import sys
|
||||
import sqlite3
|
||||
import gi
|
||||
import xdg.BaseDirectory
|
||||
|
||||
gi.require_version("Gdk", "4.0")
|
||||
gi.require_version("Gtk", "4.0")
|
||||
|
@ -19,6 +21,8 @@ CSS_PRIORITY = gi.repository.Gtk.STYLE_PROVIDER_PRIORITY_APPLICATION
|
|||
CSS_PROVIDER = gi.repository.Gtk.CssProvider()
|
||||
CSS_PROVIDER.load_from_path(str(CSS_FILE))
|
||||
|
||||
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("emmental"))
|
||||
|
||||
RESOURCE_PATH = "/com/nowheycreamery/emmental"
|
||||
RESOURCE_ICONS = f"{RESOURCE_PATH}/icons/scalable/apps"
|
||||
RESOURCE_FILE = pathlib.Path(__file__).parent.parent / "emmental.gresource"
|
||||
|
@ -47,3 +51,6 @@ def print_versions():
|
|||
__print_version("Libadwaita", gi.repository.Adw.MAJOR_VERSION,
|
||||
gi.repository.Adw.MINOR_VERSION,
|
||||
gi.repository.Adw.MICRO_VERSION)
|
||||
__print_version("SQLite", sqlite3.sqlite_version_info[0],
|
||||
sqlite3.sqlite_version_info[1],
|
||||
sqlite3.sqlite_version_info[2])
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2022 (c) Anna Schumaker
|
||||
"""Test our custom db Connection object."""
|
||||
import sqlite3
|
||||
import emmental.db.connection
|
||||
import unittest
|
||||
from gi.repository import GObject
|
||||
|
||||
|
||||
class TestConnection(unittest.TestCase):
|
||||
"""Test case for our database connection manager."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common variables."""
|
||||
self.sql = emmental.db.connection.Connection()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up."""
|
||||
self.sql.close()
|
||||
|
||||
def test_paths(self):
|
||||
"""Check that path constants are pointing to the right places."""
|
||||
self.assertEqual(emmental.db.connection.DATA_FILE,
|
||||
emmental.gsetup.DATA_DIR / "emmental-debug.sqlite3")
|
||||
self.assertEqual(emmental.db.connection.DATABASE, ":memory:")
|
||||
|
||||
def test_connection(self):
|
||||
"""Check that the connection manager is initialized properly."""
|
||||
self.assertIsInstance(self.sql, GObject.GObject)
|
||||
self.assertIsInstance(self.sql._sql, sqlite3.Connection)
|
||||
self.assertEqual(self.sql._sql.row_factory, sqlite3.Row)
|
||||
self.assertTrue(self.sql.connected)
|
||||
|
||||
def test_foreign_keys(self):
|
||||
"""Test that foreign keys are enabled."""
|
||||
cur = self.sql("PRAGMA foreign_keys")
|
||||
self.assertEqual(cur.fetchone()["foreign_keys"], 1)
|
||||
|
||||
def test_call(self):
|
||||
"""Check that we can call the connection to execute a sql statement."""
|
||||
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
|
||||
self.sql("INSERT INTO test VALUES (?, ?)", 1, 2)
|
||||
cur = self.sql("SELECT * FROM test")
|
||||
self.assertIsInstance(cur, sqlite3.Cursor)
|
||||
row = cur.fetchone()
|
||||
self.assertIsInstance(row, sqlite3.Row)
|
||||
self.assertEqual(row["a"], 1)
|
||||
self.assertEqual(row["b"], 2)
|
||||
|
||||
def test_call_keyword(self):
|
||||
"""Test running a sql statement with keyword arguments."""
|
||||
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
|
||||
self.sql("INSERT INTO test VALUES (:a, :b)", a=1, b=2)
|
||||
cur = self.sql("SELECT * FROM test")
|
||||
self.assertIsInstance(cur, sqlite3.Cursor)
|
||||
row = cur.fetchone()
|
||||
self.assertIsInstance(row, sqlite3.Row)
|
||||
self.assertEqual(row["a"], 1)
|
||||
self.assertEqual(row["b"], 2)
|
||||
|
||||
def test_casefold(self):
|
||||
"""Test the casefold function."""
|
||||
self.sql("CREATE TABLE test (a INT, text TEXT)")
|
||||
self.sql("INSERT INTO test VALUES (?, ?)", 1, "TEST")
|
||||
self.sql("INSERT INTO test VALUES (?, ?)", 2, None)
|
||||
rows = self.sql("SELECT CASEFOLD(text) as text FROM test").fetchall()
|
||||
self.assertEqual(rows[0]["text"], "test")
|
||||
self.assertEqual(rows[1]["text"], None)
|
||||
|
||||
def test_executemany(self):
|
||||
"""Test the executemany function."""
|
||||
self.sql("CREATE TABLE test (a INT, b TEXT)")
|
||||
self.sql.executemany("INSERT INTO test VALUES (?, ?)",
|
||||
(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e"))
|
||||
rows = self.sql("SELECT * FROM test").fetchall()
|
||||
self.assertEqual(tuple(rows[0]), (1, "a"))
|
||||
self.assertEqual(tuple(rows[1]), (2, "b"))
|
||||
self.assertEqual(tuple(rows[2]), (3, "c"))
|
||||
self.assertEqual(tuple(rows[3]), (4, "d"))
|
||||
self.assertEqual(tuple(rows[4]), (5, "e"))
|
||||
|
||||
def test_transaction(self):
|
||||
"""Test that we can manually start a transaction."""
|
||||
self.assertFalse(self.sql._sql.in_transaction)
|
||||
|
||||
with self.sql:
|
||||
self.assertTrue(self.sql._sql.in_transaction)
|
||||
self.sql("CREATE TABLE test_table (test TEXT)")
|
||||
self.sql("INSERT INTO test_table VALUES (?)", "Test")
|
||||
self.assertTrue(self.sql._sql.in_transaction)
|
||||
self.assertFalse(self.sql._sql.in_transaction)
|
||||
|
||||
cur = self.sql("SELECT COUNT(*) FROM test_table")
|
||||
self.assertEqual(cur.fetchone()["COUNT(*)"], 1)
|
||||
|
||||
def test_transaction_rollback(self):
|
||||
"""Test that errors roll back the transaction."""
|
||||
with self.assertRaises(Exception):
|
||||
with self.sql:
|
||||
self.sql("CREATE TABLE other_table (test TEXT)")
|
||||
self.sql("INSERT INTO other_table VALUES (?)", "Test")
|
||||
raise Exception("Test Exeption")
|
||||
|
||||
self.assertFalse(self.sql._sql.in_transaction)
|
||||
|
||||
with self.assertRaises(sqlite3.OperationalError):
|
||||
self.sql("SELECT COUNT(*) FROM other_table")
|
||||
|
||||
def test_close(self):
|
||||
"""Check closing the connection."""
|
||||
self.sql.close()
|
||||
self.assertFalse(self.sql.connected)
|
||||
|
||||
with self.assertRaises(sqlite3.ProgrammingError):
|
||||
self.assertIsNone(self.sql("SELECT COUNT(*) FROM test_table"))
|
||||
|
||||
self.sql.close()
|
|
@ -4,6 +4,7 @@ import unittest
|
|||
import pathlib
|
||||
import emmental
|
||||
import gi
|
||||
import xdg.BaseDirectory
|
||||
|
||||
|
||||
class TestGSetup(unittest.TestCase):
|
||||
|
@ -53,3 +54,8 @@ class TestGSetup(unittest.TestCase):
|
|||
|
||||
self.assertIsInstance(emmental.gsetup.RESOURCE,
|
||||
gi.repository.Gio.Resource)
|
||||
|
||||
def test_data_dir(self):
|
||||
"""Check that the DATA_DIR points to the right place."""
|
||||
data_path = xdg.BaseDirectory.save_data_path("emmental")
|
||||
self.assertEqual(emmental.gsetup.DATA_DIR, pathlib.Path(data_path))
|
||||
|
|
Loading…
Reference in New Issue