db: Create a Table base class
This is a Gtk.FilterListModel containing a store.SortedList to store individual rows in sorted order. I also implemented some convenience functions to make it easier to add, remove, look up, and filter rows. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
651f24672b
commit
788ca374a8
|
@ -1,7 +1,9 @@
|
|||
# Copyright 2022 (c) Anna Schumaker
|
||||
"""Easily work with our underlying sqlite3 database."""
|
||||
import pathlib
|
||||
from gi.repository import GObject
|
||||
from . import connection
|
||||
from . import table
|
||||
|
||||
|
||||
SQL_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql"
|
||||
|
@ -18,3 +20,8 @@ class Connection(connection.Connection):
|
|||
case 0:
|
||||
with open(SQL_SCRIPT) as f:
|
||||
self._sql.executescript(f.read())
|
||||
|
||||
@GObject.Signal(arg_types=(table.Table,))
|
||||
def table_loaded(self, tbl: table.Table) -> None:
|
||||
"""Signal that a table has been loaded."""
|
||||
tbl.loaded = True
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright 2022 (c) Anna Schumaker
|
||||
"""Base classes for database objects."""
|
||||
import sqlite3
|
||||
from gi.repository import GObject
|
||||
from gi.repository import Gio
|
||||
from gi.repository import Gtk
|
||||
from .. import store
|
||||
|
||||
|
||||
class Row(GObject.GObject):
|
||||
|
@ -105,3 +107,122 @@ class Filter(Gtk.Filter):
|
|||
if (how := self.__find_change(keys)) is not None:
|
||||
self._keys = keys
|
||||
self.changed(how)
|
||||
|
||||
|
||||
class Table(Gtk.FilterListModel):
|
||||
"""An object that represents a database Table."""
|
||||
|
||||
sql = GObject.Property(type=GObject.TYPE_PYOBJECT)
|
||||
store = GObject.Property(type=Gio.ListModel)
|
||||
rows = GObject.Property(type=GObject.TYPE_PYOBJECT)
|
||||
|
||||
loaded = GObject.Property(type=bool, default=False)
|
||||
|
||||
def __init__(self, sql: GObject.TYPE_PYOBJECT,
|
||||
filter: Filter | None = None, **kwargs):
|
||||
"""Set up our Table object."""
|
||||
super().__init__(sql=sql, incremental=True, rows=dict(),
|
||||
store=store.SortedList(self.get_sort_key),
|
||||
filter=(filter if filter else Filter()), **kwargs)
|
||||
self.set_model(self.store)
|
||||
|
||||
def __contains__(self, row: Row) -> bool:
|
||||
"""Check if the row is in the _rowid_map for this Table."""
|
||||
return self.index(row) is not None
|
||||
|
||||
def do_construct(self, *args, **kwargs) -> Row:
|
||||
"""Construct a new Row instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_get_sort_key(self, row: Row) -> any:
|
||||
"""Get a sort key for the requested row."""
|
||||
return None
|
||||
|
||||
def do_sql_delete(self, row: Row) -> bool:
|
||||
"""Delete a Row."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
|
||||
"""Select matching rowids using GLOB."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_sql_insert(self, *args, **kwargs) -> sqlite3.Cursor:
|
||||
"""Create a new Row."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_sql_select_all(self) -> sqlite3.Cursor:
|
||||
"""Return all rows from the table."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_sql_select_one(self, *args, **kwargs) -> sqlite3.Cursor:
|
||||
"""Look up a single row."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_sql_update(self, row: Row, column: str, newval) -> sqlite3.Cursor:
|
||||
"""Update a row."""
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the table."""
|
||||
self.rows.clear()
|
||||
self.store.clear()
|
||||
self.loaded = False
|
||||
|
||||
def construct(self, *args, **kwargs) -> Row:
|
||||
"""Construct a new Row instance."""
|
||||
return self.do_construct(table=self, *args, **kwargs)
|
||||
|
||||
def create(self, *args, **kwargs) -> Row | None:
|
||||
"""Create a new Row in the Table."""
|
||||
if cur := self.do_sql_insert(*args, **kwargs):
|
||||
return self.insert(self.construct(**cur.fetchone()))
|
||||
|
||||
def delete(self, row: Row) -> bool:
|
||||
"""Delete a Row from the Table."""
|
||||
if row in self and self.do_sql_delete(row).rowcount == 1:
|
||||
self.store.remove(row)
|
||||
del self.rows[row.primary_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter(self, glob: str | None) -> None:
|
||||
"""Filter the displayed Rows."""
|
||||
if glob is not None:
|
||||
rows = self.do_sql_glob(glob).fetchall()
|
||||
self.get_filter().keys = {row[0] for row in rows}
|
||||
else:
|
||||
self.get_filter().keys = None
|
||||
|
||||
def get_sort_key(self, row: Row) -> tuple:
|
||||
"""Get a sort key for the requested row."""
|
||||
res = self.do_get_sort_key(row)
|
||||
return res if res is not None else row.primary_key
|
||||
|
||||
def index(self, row: Row) -> int | None:
|
||||
"""Find the index of a specific Row."""
|
||||
if row.table is self:
|
||||
return self.store.index(row)
|
||||
|
||||
def insert(self, row: Row) -> Row | None:
|
||||
"""Insert a Row in sorted position."""
|
||||
if row and row not in self:
|
||||
self.store.append(row)
|
||||
return self.rows.setdefault(row.primary_key, row)
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load the Table from the database."""
|
||||
self.clear()
|
||||
cur = self.do_sql_select_all()
|
||||
rows = [self.construct(**row) for row in cur.fetchall()]
|
||||
self.store.extend(rows)
|
||||
self.rows = {row.primary_key: row for row in rows}
|
||||
self.sql.emit("table-loaded", self)
|
||||
|
||||
def lookup(self, *args, **kwargs) -> Row | None:
|
||||
"""Look up a Row in the database."""
|
||||
row = self.do_sql_select_one(*args, **kwargs).fetchone()
|
||||
return self.rows.get(row[0]) if row else None
|
||||
|
||||
def update(self, row: Row, column: str, newval) -> bool:
|
||||
"""Update a Row."""
|
||||
return self.do_sql_update(row, column, newval) is not None
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import unittest
|
||||
import unittest.mock
|
||||
import emmental.db.table
|
||||
import emmental.store
|
||||
import tests.util.table
|
||||
from gi.repository import GObject
|
||||
from gi.repository import Gio
|
||||
|
@ -162,3 +163,202 @@ class TestFilter(unittest.TestCase):
|
|||
self.assertFalse(self.filter.match(self.row1))
|
||||
self.filter.keys = set()
|
||||
self.assertFalse(self.filter.match(self.row1))
|
||||
|
||||
|
||||
class TestTable(tests.util.TestCase):
|
||||
"""Tests the base Table object."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common variables."""
|
||||
super().setUp()
|
||||
self.table = emmental.db.table.Table(self.sql)
|
||||
|
||||
def test_init(self):
|
||||
"""Test that the table is set up properly."""
|
||||
self.assertIsInstance(self.table, Gtk.FilterListModel)
|
||||
self.assertIsInstance(self.table.get_filter(),
|
||||
emmental.db.table.Filter)
|
||||
self.assertIsInstance(self.table.store, emmental.store.SortedList)
|
||||
self.assertIsInstance(self.table.rows, dict)
|
||||
|
||||
self.assertEqual(self.table.sql, self.sql)
|
||||
self.assertEqual(self.table.get_model(), self.table.store)
|
||||
self.assertEqual(self.table.store.key_func, self.table.get_sort_key)
|
||||
self.assertDictEqual(self.table.rows, {})
|
||||
self.assertTrue(self.table.get_incremental())
|
||||
|
||||
filter2 = emmental.db.table.Filter()
|
||||
table2 = emmental.db.table.Table(self.sql, filter=filter2)
|
||||
self.assertEqual(table2.get_filter(), filter2)
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing a table."""
|
||||
row = tests.util.table.MockRow(number=1, table=self.table)
|
||||
self.table.store.append(row)
|
||||
self.table.loaded = True
|
||||
|
||||
self.table.clear()
|
||||
self.assertEqual(self.table.store.n_items, 0)
|
||||
self.assertDictEqual(self.table.rows, dict())
|
||||
self.assertFalse(self.table.loaded)
|
||||
|
||||
def test_contains(self):
|
||||
"""Test checking if a Row is already in this Table."""
|
||||
row1 = tests.util.table.MockRow(number=1, table=self.table)
|
||||
row2 = tests.util.table.MockRow(number=2, table=self.table)
|
||||
self.table.insert(row1)
|
||||
self.assertTrue(row1 in self.table)
|
||||
self.assertFalse(row2 in self.table)
|
||||
|
||||
def test_get_sort_key(self):
|
||||
"""Test getting a sort key for a row."""
|
||||
row = tests.util.table.MockRow(number=1, table=self.table)
|
||||
self.table.insert(row)
|
||||
self.assertEqual(self.table.get_sort_key(row), 1)
|
||||
|
||||
def test_index(self):
|
||||
"""Test finding the index of rows in the table."""
|
||||
row1 = tests.util.table.MockRow(number=1, table=self.table)
|
||||
row2 = tests.util.table.MockRow(number=2, table=self.table)
|
||||
row3 = tests.util.table.MockRow(number=3, table=self.table)
|
||||
self.table.insert(row1)
|
||||
self.table.rows[row3.primary_key] = row3
|
||||
self.assertEqual(self.table.index(row1), 0)
|
||||
self.assertIsNone(self.table.index(row2))
|
||||
self.assertIsNone(self.table.index(row3))
|
||||
|
||||
def test_insert(self):
|
||||
"""Test inserting rows into the table in sorted position."""
|
||||
row1 = tests.util.table.MockRow(number=1, table=self.table)
|
||||
row2 = tests.util.table.MockRow(number=2, table=self.table)
|
||||
row3 = tests.util.table.MockRow(number=3, table=self.table)
|
||||
|
||||
self.assertEqual(self.table.insert(row1), row1)
|
||||
self.assertEqual(self.table.store.get_item(0), row1)
|
||||
self.assertDictEqual(self.table.rows, {1: row1})
|
||||
|
||||
self.assertEqual(self.table.insert(row3), row3)
|
||||
self.assertEqual(self.table.store.get_item(0), row1)
|
||||
self.assertEqual(self.table.store.get_item(1), row3)
|
||||
self.assertDictEqual(self.table.rows, {1: row1, 3: row3})
|
||||
|
||||
self.assertEqual(self.table.insert(row2), row2)
|
||||
self.assertEqual(self.table.store.get_item(0), row1)
|
||||
self.assertEqual(self.table.store.get_item(1), row2)
|
||||
self.assertEqual(self.table.store.get_item(2), row3)
|
||||
self.assertDictEqual(self.table.rows, {1: row1, 2: row2, 3: row3})
|
||||
|
||||
row1_again = tests.util.table.MockRow(number=1, table=self.table)
|
||||
self.assertIsNone(self.table.insert(row1_again))
|
||||
self.assertIsNone(self.table.insert(row1))
|
||||
self.assertIsNone(self.table.insert(None))
|
||||
|
||||
def test_interface(self):
|
||||
"""Test that calling interface functions raises an exception."""
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.construct(rowid=1)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.create(rowid=1)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.do_sql_delete(None)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.filter("*text*")
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.load()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.lookup(1)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.table.update(None, "column", 12345)
|
||||
|
||||
|
||||
class TestTableFunctions(tests.util.TestCase):
|
||||
"""Tests Table functions with a Mock implementation."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common variables."""
|
||||
super().setUp()
|
||||
self.table = tests.util.table.MockTable(self.sql)
|
||||
|
||||
def test_construct(self):
|
||||
"""Test constructing a new Row object."""
|
||||
row = self.table.construct(number=1)
|
||||
self.assertIsInstance(row, tests.util.table.MockRow)
|
||||
self.assertIsInstance(row, emmental.db.table.Row)
|
||||
self.assertEqual(row.table, self.table)
|
||||
self.assertEqual(row.number, 1)
|
||||
|
||||
def test_create(self):
|
||||
"""Test creating new rows."""
|
||||
row = self.table.create(number=1)
|
||||
self.assertIsInstance(row, tests.util.table.MockRow)
|
||||
self.assertEqual(self.table.index(row), 0)
|
||||
self.assertEqual(row.number, 1)
|
||||
self.assertDictEqual(self.table.rows, {1: row})
|
||||
|
||||
self.assertIsNone(self.table.create(number=1))
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deleting rows."""
|
||||
row = self.table.create(number=1)
|
||||
self.assertTrue(row.delete())
|
||||
self.assertEqual(len(self.table), 0)
|
||||
self.assertDictEqual(self.table.rows, dict())
|
||||
|
||||
self.assertFalse(row.delete())
|
||||
|
||||
def test_filter(self):
|
||||
"""Test filtering Rows in the table."""
|
||||
for n in [1, 121, 212, 333]:
|
||||
self.table.create(number=n)
|
||||
|
||||
self.table.filter("*2*")
|
||||
self.assertSetEqual(self.table.get_filter().keys, {121, 212})
|
||||
self.table.filter(None)
|
||||
self.assertIsNone(self.table.get_filter().keys)
|
||||
|
||||
def test_get_sort_key(self):
|
||||
"""Test getting a sort key for a row."""
|
||||
row = self.table.create(number=42)
|
||||
self.assertTupleEqual(self.table.get_sort_key(row), (42, 42))
|
||||
|
||||
def test_load(self):
|
||||
"""Test loading rows from the database."""
|
||||
self.assertFalse(self.table.loaded)
|
||||
|
||||
table_loaded = unittest.mock.Mock()
|
||||
self.sql.connect("table-loaded", table_loaded)
|
||||
self.sql("INSERT INTO mock_table (number) VALUES (?)", 1)
|
||||
self.sql("INSERT INTO mock_table (number) VALUES (?)", 2)
|
||||
|
||||
self.table.load()
|
||||
self.assertTrue(self.table.loaded)
|
||||
self.assertEqual(len(self.table), 2)
|
||||
table_loaded.assert_called_with(self.sql, self.table)
|
||||
|
||||
row1 = self.table[0]
|
||||
row2 = self.table[1]
|
||||
|
||||
for row, n in [(row1, 1), (row2, 2)]:
|
||||
with self.subTest(n=n):
|
||||
self.assertEqual(row.number, n)
|
||||
|
||||
self.assertEqual(self.table.rows, {1: row1, 2: row2})
|
||||
|
||||
self.table.load()
|
||||
self.assertNotEqual(self.table[0], row1)
|
||||
self.assertNotEqual(self.table[1], row2)
|
||||
|
||||
def test_lookup(self):
|
||||
"""Test looking up rows in the table."""
|
||||
row = self.table.create(number=1)
|
||||
self.assertEqual(self.table.lookup(1), row)
|
||||
self.assertIsNone(self.table.lookup(2))
|
||||
|
||||
def test_update(self):
|
||||
"""Test updating a Row."""
|
||||
row = self.table.create(number=1)
|
||||
self.assertTrue(self.table.update(row, "number", 2))
|
||||
row.number = 2
|
||||
|
||||
self.table.create(number=3)
|
||||
self.assertFalse(self.table.update(row, "number", 3))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright 2023 (c) Anna Schumaker.
|
||||
"""Mock Row and Table objects for testing."""
|
||||
import emmental.db.table
|
||||
import sqlite3
|
||||
from gi.repository import GObject
|
||||
|
||||
|
||||
|
@ -13,3 +14,48 @@ class MockRow(emmental.db.table.Row):
|
|||
def primary_key(self) -> int:
|
||||
"""Get the primary key for this MockRow."""
|
||||
return self.number
|
||||
|
||||
|
||||
class MockTable(emmental.db.table.Table):
|
||||
"""A fake Table customized for testing."""
|
||||
|
||||
def __init__(self, sql: GObject.TYPE_PYOBJECT):
|
||||
"""Initialize the Mock Table."""
|
||||
super().__init__(sql)
|
||||
self.sql("CREATE TABLE mock_table (number INTEGER PRIMARY KEY)")
|
||||
|
||||
def do_construct(self, *args, **kwargs) -> MockRow:
|
||||
"""Construct a MockRow."""
|
||||
return MockRow(*args, **kwargs)
|
||||
|
||||
def do_get_sort_key(self, row: MockRow) -> any:
|
||||
"""Get the sort key for a MockRow."""
|
||||
return (row.number, row.number)
|
||||
|
||||
def do_sql_delete(self, row: MockRow) -> sqlite3.Cursor:
|
||||
"""Delete a MockRow from the Table."""
|
||||
return self.sql("DELETE FROM mock_table WHERE number=?", row.number)
|
||||
|
||||
def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
|
||||
"""Select matching rows from the Table."""
|
||||
return self.sql("SELECT number FROM mock_table WHERE number GLOB ?",
|
||||
glob)
|
||||
|
||||
def do_sql_insert(self, number: int) -> sqlite3.Cursor:
|
||||
"""Insert a MockRow into the Table."""
|
||||
return self.sql("""INSERT INTO mock_table (number)
|
||||
VALUES (?) RETURNING *""", number)
|
||||
|
||||
def do_sql_select_all(self) -> sqlite3.Cursor:
|
||||
"""Return all rows in the Table."""
|
||||
return self.sql("SELECT * FROM mock_table ORDER BY number")
|
||||
|
||||
def do_sql_select_one(self, number: int) -> sqlite3.Cursor:
|
||||
"""Look up a single MockRow in the Table."""
|
||||
return self.sql("SELECT number FROM mock_table WHERE number=?", number)
|
||||
|
||||
def do_sql_update(self, row: MockRow, column: str,
|
||||
newval: int) -> sqlite3.Cursor:
|
||||
"""Update a MockRow in the Table."""
|
||||
return self.sql(f"UPDATE mock_table SET {column}=? WHERE number=?",
|
||||
newval, row.number)
|
||||
|
|
Loading…
Reference in New Issue