db: Create a base Connection manager

This is a wrapper around the sqlite3.Connection objct that adds some
nice functionality to make working with SQL easier.

I defined the following magic methods:
  * __enter__() to manually begin a transaction
  * __exit__() to commit or rollback a manual transaction
  * __call__() to execute a SQL statement with either positional or
    keyword arguments.

Additionally:
  * I define a "CASEFOLD" function that can be used in queries
    to lowercase unicode text when searching.
  * I set foreign_keys = ON so foreign keys checking is always enabled
  * I provide an executemany() function for running running the same
    statement multiple times with different arguments.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2022-05-26 17:17:03 -04:00
parent 61fc252172
commit deb4f3d252
4 changed files with 197 additions and 0 deletions

68
emmental/db/connection.py Normal file
View File

@ -0,0 +1,68 @@
# Copyright 2022 (c) Anna Schumaker
"""Easily work with our underlying sqlite3 database."""
import sqlite3
import sys
from gi.repository import GObject
from .. import gsetup
DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3"
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
class Connection(GObject.GObject):
"""Connect to the database."""
connected = GObject.Property(type=bool, default=True)
def __init__(self):
"""Initialize a sqlite connection."""
super().__init__()
self._sql = sqlite3.connect(DATABASE,
detect_types=sqlite3.PARSE_DECLTYPES)
self._sql.create_function("CASEFOLD", 1,
lambda s: s.casefold() if s else None,
deterministic=True)
self._sql.row_factory = sqlite3.Row
self._sql("PRAGMA foreign_keys = ON")
def __call__(self, statement: str,
*args, **kwargs) -> sqlite3.Cursor | None:
"""Execute a SQL statement."""
try:
return self._sql.execute(statement, args if len(args) else kwargs)
except sqlite3.IntegrityError:
return None
def __del__(self) -> None:
"""Clean up before exiting."""
self.close()
def __enter__(self) -> None:
"""Begin a transaction."""
if not self._sql.in_transaction:
self._sql.commit()
self._sql.execute("BEGIN")
def __exit__(self, exp_type, exp_value, traceback) -> bool:
"""Either commit or rollback an active transaction."""
if exp_type is None:
self._sql.commit()
else:
self._sql.rollback()
return exp_type is None
def close(self) -> None:
"""Close the database connection."""
if self.connected:
self._sql.commit()
self._sql.execute("PRAGMA optimize")
self._sql.close()
self.connected = False
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
"""Execute several similar SQL statements at once."""
try:
return self._sql.executemany(statement, args)
except sqlite3.InternalError:
return None

View File

@ -2,7 +2,9 @@
"""Set up GObject Introspection, and custom styling, and icons."""
import pathlib
import sys
import sqlite3
import gi
import xdg.BaseDirectory
gi.require_version("Gdk", "4.0")
gi.require_version("Gtk", "4.0")
@ -19,6 +21,8 @@ CSS_PRIORITY = gi.repository.Gtk.STYLE_PROVIDER_PRIORITY_APPLICATION
CSS_PROVIDER = gi.repository.Gtk.CssProvider()
CSS_PROVIDER.load_from_path(str(CSS_FILE))
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("emmental"))
RESOURCE_PATH = "/com/nowheycreamery/emmental"
RESOURCE_ICONS = f"{RESOURCE_PATH}/icons/scalable/apps"
RESOURCE_FILE = pathlib.Path(__file__).parent.parent / "emmental.gresource"
@ -47,3 +51,6 @@ def print_versions():
__print_version("Libadwaita", gi.repository.Adw.MAJOR_VERSION,
gi.repository.Adw.MINOR_VERSION,
gi.repository.Adw.MICRO_VERSION)
__print_version("SQLite", sqlite3.sqlite_version_info[0],
sqlite3.sqlite_version_info[1],
sqlite3.sqlite_version_info[2])

116
tests/db/test_connection.py Normal file
View File

@ -0,0 +1,116 @@
# Copyright 2022 (c) Anna Schumaker
"""Test our custom db Connection object."""
import sqlite3
import emmental.db.connection
import unittest
from gi.repository import GObject
class TestConnection(unittest.TestCase):
"""Test case for our database connection manager."""
def setUp(self):
"""Set up common variables."""
self.sql = emmental.db.connection.Connection()
def tearDown(self):
"""Clean up."""
self.sql.close()
def test_paths(self):
"""Check that path constants are pointing to the right places."""
self.assertEqual(emmental.db.connection.DATA_FILE,
emmental.gsetup.DATA_DIR / "emmental-debug.sqlite3")
self.assertEqual(emmental.db.connection.DATABASE, ":memory:")
def test_connection(self):
"""Check that the connection manager is initialized properly."""
self.assertIsInstance(self.sql, GObject.GObject)
self.assertIsInstance(self.sql._sql, sqlite3.Connection)
self.assertEqual(self.sql._sql.row_factory, sqlite3.Row)
self.assertTrue(self.sql.connected)
def test_foreign_keys(self):
"""Test that foreign keys are enabled."""
cur = self.sql("PRAGMA foreign_keys")
self.assertEqual(cur.fetchone()["foreign_keys"], 1)
def test_call(self):
"""Check that we can call the connection to execute a sql statement."""
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
self.sql("INSERT INTO test VALUES (?, ?)", 1, 2)
cur = self.sql("SELECT * FROM test")
self.assertIsInstance(cur, sqlite3.Cursor)
row = cur.fetchone()
self.assertIsInstance(row, sqlite3.Row)
self.assertEqual(row["a"], 1)
self.assertEqual(row["b"], 2)
def test_call_keyword(self):
"""Test running a sql statement with keyword arguments."""
self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
self.sql("INSERT INTO test VALUES (:a, :b)", a=1, b=2)
cur = self.sql("SELECT * FROM test")
self.assertIsInstance(cur, sqlite3.Cursor)
row = cur.fetchone()
self.assertIsInstance(row, sqlite3.Row)
self.assertEqual(row["a"], 1)
self.assertEqual(row["b"], 2)
def test_casefold(self):
"""Test the casefold function."""
self.sql("CREATE TABLE test (a INT, text TEXT)")
self.sql("INSERT INTO test VALUES (?, ?)", 1, "TEST")
self.sql("INSERT INTO test VALUES (?, ?)", 2, None)
rows = self.sql("SELECT CASEFOLD(text) as text FROM test").fetchall()
self.assertEqual(rows[0]["text"], "test")
self.assertEqual(rows[1]["text"], None)
def test_executemany(self):
"""Test the executemany function."""
self.sql("CREATE TABLE test (a INT, b TEXT)")
self.sql.executemany("INSERT INTO test VALUES (?, ?)",
(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e"))
rows = self.sql("SELECT * FROM test").fetchall()
self.assertEqual(tuple(rows[0]), (1, "a"))
self.assertEqual(tuple(rows[1]), (2, "b"))
self.assertEqual(tuple(rows[2]), (3, "c"))
self.assertEqual(tuple(rows[3]), (4, "d"))
self.assertEqual(tuple(rows[4]), (5, "e"))
def test_transaction(self):
"""Test that we can manually start a transaction."""
self.assertFalse(self.sql._sql.in_transaction)
with self.sql:
self.assertTrue(self.sql._sql.in_transaction)
self.sql("CREATE TABLE test_table (test TEXT)")
self.sql("INSERT INTO test_table VALUES (?)", "Test")
self.assertTrue(self.sql._sql.in_transaction)
self.assertFalse(self.sql._sql.in_transaction)
cur = self.sql("SELECT COUNT(*) FROM test_table")
self.assertEqual(cur.fetchone()["COUNT(*)"], 1)
def test_transaction_rollback(self):
"""Test that errors roll back the transaction."""
with self.assertRaises(Exception):
with self.sql:
self.sql("CREATE TABLE other_table (test TEXT)")
self.sql("INSERT INTO other_table VALUES (?)", "Test")
raise Exception("Test Exeption")
self.assertFalse(self.sql._sql.in_transaction)
with self.assertRaises(sqlite3.OperationalError):
self.sql("SELECT COUNT(*) FROM other_table")
def test_close(self):
"""Check closing the connection."""
self.sql.close()
self.assertFalse(self.sql.connected)
with self.assertRaises(sqlite3.ProgrammingError):
self.assertIsNone(self.sql("SELECT COUNT(*) FROM test_table"))
self.sql.close()

View File

@ -4,6 +4,7 @@ import unittest
import pathlib
import emmental
import gi
import xdg.BaseDirectory
class TestGSetup(unittest.TestCase):
@ -53,3 +54,8 @@ class TestGSetup(unittest.TestCase):
self.assertIsInstance(emmental.gsetup.RESOURCE,
gi.repository.Gio.Resource)
def test_data_dir(self):
"""Check that the DATA_DIR points to the right place."""
data_path = xdg.BaseDirectory.save_data_path("emmental")
self.assertEqual(emmental.gsetup.DATA_DIR, pathlib.Path(data_path))