348 lines
12 KiB
Python
348 lines
12 KiB
Python
# Copyright 2022 (c) Anna Schumaker
|
|
"""Base classes for database objects."""
|
|
import bisect
|
|
import sqlite3
|
|
from gi.repository import GObject
|
|
from gi.repository import Gio
|
|
from gi.repository import Gtk
|
|
from .idle import Queue
|
|
from .. import store
|
|
|
|
|
|
class Row(GObject.GObject):
    """A single row in a database table.

    Every property change, except a change to the owning ``table`` itself,
    is written back to the table through :meth:`do_update`.
    """

    table = GObject.Property(type=Gio.ListModel)

    def __init__(self, table: Gio.ListModel, **kwargs):
        """Create a Row owned by *table*; extra kwargs set GObject properties."""
        super().__init__(table=table, **kwargs)
        self.connect("notify", self.__notify)

    def __notify(self, row: GObject.GObject, param: GObject.ParamSpec) -> None:
        # Re-parenting the row is not a column write, so skip "table".
        if param.name != "table":
            self.do_update(param.name)

    def do_update(self, column: str) -> bool:
        """Push the current value of *column* to the owning table."""
        new_value = self.get_property(column)
        owner = self.table
        return owner.update(self, column, new_value)

    def delete(self) -> bool:
        """Ask the owning table to delete this Row."""
        owner = self.table
        return owner.delete(self)

    @property
    def primary_key(self) -> None:
        """Get the primary key for this row (subclasses must override)."""
        raise NotImplementedError
|
|
|
|
|
|
class KeySet(Gtk.Filter):
|
|
"""A Gtk.Filter that also acts as a Python Set."""
|
|
|
|
n_keys = GObject.Property(type=int)
|
|
|
|
def __init__(self, keys: set | None = None, **kwargs):
|
|
"""Set up our KeySet."""
|
|
super().__init__(**kwargs)
|
|
self._keys = keys
|
|
self.n_keys = len(keys) if keys is not None else -1
|
|
|
|
def __contains__(self, row: Row) -> bool:
|
|
"""Check if a Row is in the KeySet."""
|
|
return self._keys is None or row.primary_key in self._keys
|
|
|
|
def __sub__(self, rhs: Gtk.Filter) -> set[int]:
|
|
"""Subtract two KeySets and return the result."""
|
|
match (self._keys, rhs._keys):
|
|
case (None, _): return None
|
|
case (_, None): return self._keys
|
|
case (_, _): return self._keys - rhs._keys
|
|
|
|
def __find_difference(self, new: set[any] | None) \
|
|
-> tuple[set, set, Gtk.FilterChange | None]:
|
|
if self._keys is None:
|
|
if new is None:
|
|
return (set(), set(), None)
|
|
return (set(), new, Gtk.FilterChange.MORE_STRICT)
|
|
elif new is None:
|
|
return (self._keys, set(), Gtk.FilterChange.LESS_STRICT)
|
|
|
|
removed = self._keys - new
|
|
added = new - self._keys
|
|
match len(removed), len(added):
|
|
case 0, 0: return (removed, added, None)
|
|
case _, 0: return (removed, added, Gtk.FilterChange.MORE_STRICT)
|
|
case 0, _: return (removed, added, Gtk.FilterChange.LESS_STRICT)
|
|
case _, _: return (removed, added, Gtk.FilterChange.DIFFERENT)
|
|
|
|
def changed(self, how: Gtk.FilterChange) -> None:
|
|
"""Notify that the KeySet has changed."""
|
|
self.n_keys = len(self._keys) if self._keys is not None else -1
|
|
super().changed(how)
|
|
|
|
def do_get_strictness(self) -> Gtk.FilterMatch:
|
|
"""Get the strictness of the Gtk.Filter."""
|
|
if self._keys is None:
|
|
return Gtk.FilterMatch.ALL
|
|
if len(self._keys) == 0:
|
|
return Gtk.FilterMatch.NONE
|
|
return Gtk.FilterMatch.SOME
|
|
|
|
def do_match(self, row: Row) -> bool:
|
|
"""Check if the Row is in the KeySet."""
|
|
return self._keys is None or row.primary_key in self._keys
|
|
|
|
def add_row(self, row: Row) -> None:
|
|
"""Add a Row to the KeySet."""
|
|
if row not in self:
|
|
self._keys.add(row.primary_key)
|
|
self.emit("key-added", row.primary_key)
|
|
self.changed(Gtk.FilterChange.LESS_STRICT)
|
|
|
|
def remove_row(self, row: Row) -> None:
|
|
"""Remove a Row from the KeySet."""
|
|
if self._keys is not None and row in self:
|
|
self._keys.discard(row.primary_key)
|
|
self.emit("key-removed", row.primary_key)
|
|
self.changed(Gtk.FilterChange.MORE_STRICT)
|
|
|
|
@property
|
|
def keys(self) -> set[any]:
|
|
"""Return the set of matching primary keys."""
|
|
return self._keys
|
|
|
|
@keys.setter
|
|
def keys(self, keys: set[any] | None) -> None:
|
|
"""Set the matching primary keys."""
|
|
(removed, added, change) = self.__find_difference(keys)
|
|
if change is not None:
|
|
self._keys = keys
|
|
self.emit("keys-changed", removed, added)
|
|
self.changed(change)
|
|
|
|
@GObject.Signal(arg_types=(int,))
|
|
def key_added(self, key: int) -> None:
|
|
"""Signal that a Row has been added to the KeySet."""
|
|
|
|
@GObject.Signal(arg_types=(int,))
|
|
def key_removed(self, key: int) -> None:
|
|
"""Signal that a Row has been removed from the KeySet."""
|
|
|
|
@GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT, GObject.TYPE_PYOBJECT))
|
|
def keys_changed(self, removed: set | None, added: set | None) -> None:
|
|
"""Signal that the KeySet has been directly modified."""
|
|
|
|
|
|
class Table(Gtk.FilterListModel):
    """An object that represents a database Table.

    Rows are cached in two places: ``store``, a sorted list model wrapped by
    this FilterListModel, and ``rows``, a dict mapping primary key -> Row.
    The KeySet used as this model's filter decides which rows are visible.
    Subclasses implement the ``do_*`` hooks with the actual SQL.
    """

    sql = GObject.Property(type=GObject.TYPE_PYOBJECT)
    queue = GObject.Property(type=Queue)
    store = GObject.Property(type=Gio.ListModel)
    rows = GObject.Property(type=GObject.TYPE_PYOBJECT)

    loaded = GObject.Property(type=bool, default=False)

    def __init__(self, sql: GObject.TYPE_PYOBJECT,
                 filter: KeySet | None = None,
                 queue: Queue | None = None, **kwargs):
        """Set up our Table object.

        :param sql: database connection; this class calls its commit()
                    method and emits its "table-loaded" signal.
        :param filter: KeySet filtering the visible rows; a new KeySet()
                       is created when None.
        :param queue: idle Queue used for background loading and
                      filtering; a new Queue() is created when None.
        """
        # Compare against None explicitly so a caller-supplied object is
        # never swapped out just because it happens to be falsy.
        super().__init__(sql=sql, rows=dict(),
                         store=store.SortedList(self.get_sort_key),
                         filter=(filter if filter is not None else KeySet()),
                         queue=(queue if queue is not None else Queue()),
                         **kwargs)
        self.set_model(self.store)

    def __clear_rows(self) -> None:
        # Drop every cached Row and mark the table as not loaded.
        self.rows.clear()
        self.store.clear()
        self.loaded = False

    def __contains__(self, row: Row) -> bool:
        """Check if the Row belongs to this Table and is in its store."""
        return self.index(row) is not None

    def do_construct(self, *args, **kwargs) -> Row:
        """Construct a new Row instance."""
        raise NotImplementedError

    def do_get_sort_key(self, row: Row) -> any:
        """Get a sort key for the requested row (None: use primary key)."""
        return None

    def do_sql_delete(self, row: Row) -> bool:
        """Delete a Row."""
        raise NotImplementedError

    def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
        """Select matching rowids using GLOB."""
        raise NotImplementedError

    def do_sql_insert(self, *args, **kwargs) -> sqlite3.Cursor:
        """Create a new Row."""
        raise NotImplementedError

    def do_sql_select_all(self) -> sqlite3.Cursor:
        """Return all rows from the table."""
        raise NotImplementedError

    def do_sql_select_one(self, *args, **kwargs) -> sqlite3.Cursor:
        """Look up a single row."""
        raise NotImplementedError

    def do_sql_update(self, row: Row, column: str, newval) -> sqlite3.Cursor:
        """Update a row."""
        raise NotImplementedError

    def clear(self) -> None:
        """Cancel background work, then drop all cached rows."""
        self.stop()
        self.__clear_rows()

    def construct(self, *args, **kwargs) -> Row:
        """Construct a new Row instance owned by this Table."""
        return self.do_construct(table=self, *args, **kwargs)

    def create(self, *args, **kwargs) -> Row | None:
        """Create a new Row in the Table.

        :returns: the inserted Row, or None if do_sql_insert() returned no
                  cursor or the cursor produced no result row.
        """
        if cur := self.do_sql_insert(*args, **kwargs):
            # Unpacking None with ** would raise TypeError, so check first.
            if (values := cur.fetchone()) is not None:
                return self.insert(self.construct(**values))
        return None

    def delete(self, row: Row) -> bool:
        """Delete a Row from the Table.

        :returns: True if exactly one database row was deleted (the change
                  is then committed and the Row dropped from the caches).
        """
        if row in self and self.do_sql_delete(row).rowcount == 1:
            self.sql.commit()
            self.store.remove(row)
            del self.rows[row.primary_key]
            return True
        return False

    def _filter_idle(self, glob: str) -> bool:
        # Queue task: run the GLOB query and use the first column of each
        # result row as the set of visible primary keys.
        rows = self.do_sql_glob(glob).fetchall()
        self.get_filter().keys = {row[0] for row in rows}
        return True

    def filter(self, glob: str | None, *, now: bool = False) -> None:
        """Filter the displayed Rows.

        :param glob: pattern passed to do_sql_glob() via the idle queue, or
                     None to show every Row immediately.
        :param now: forwarded to Queue.push() when a glob is given.
        """
        # Always drop a pending filter task first.  Otherwise filter(None)
        # would clear the filter only for an already-queued glob query to
        # run afterwards and re-apply stale results.
        self.queue.cancel_task(self._filter_idle)
        if glob is not None:
            self.queue.push(self._filter_idle, glob, now=now, first=True)
        else:
            self.get_filter().keys = None

    def get_sort_key(self, row: Row) -> any:
        """Get a sort key for the requested row.

        Uses do_get_sort_key(), falling back to the row's primary key when
        that returns None.
        """
        res = self.do_get_sort_key(row)
        return res if res is not None else row.primary_key

    def index(self, row: Row) -> int | None:
        """Find the index of a specific Row.

        :returns: None if the Row belongs to another table, otherwise the
                  result of store.index(row).
        """
        if row.table is self:
            return self.store.index(row)
        return None

    def insert(self, row: Row | None) -> Row | None:
        """Insert a Row in sorted position.

        :returns: the cached Row for ``row.primary_key`` (an already cached
                  Row wins over ``row``), or None when ``row`` is None.
        """
        if row is None:
            # Previously this fell through to row.primary_key and crashed.
            return None
        if row not in self:
            self.store.append(row)
        return self.rows.setdefault(row.primary_key, row)

    def _load_idle(self) -> bool:
        # Queue task: replace the caches with every row from the database,
        # then emit "table-loaded" on the connection.
        # NOTE(review): ``loaded`` is not set True here; presumably the
        # table-loaded handler does that -- confirm.
        self.__clear_rows()
        cur = self.do_sql_select_all()
        rows = [self.construct(**row) for row in cur.fetchall()]
        self.store.extend(rows)
        self.rows = {row.primary_key: row for row in rows}
        self.sql.emit("table-loaded", self)
        return True

    def load(self, *, now: bool = False) -> None:
        """Load the Table from the database (via the idle queue)."""
        self.queue.push(self._load_idle, now=now)

    def lookup(self, *args, **kwargs) -> Row | None:
        """Look up a Row in the database.

        :returns: the cached Row keyed by the first column of the result,
                  or None when the query matched nothing.
        """
        row = self.do_sql_select_one(*args, **kwargs).fetchone()
        return self.rows.get(row[0]) if row else None

    def stop(self) -> None:
        """Stop any background work."""
        self.queue.cancel()

    def update(self, row: Row, column: str, newval) -> bool:
        """Update a Row; True if do_sql_update() returned a cursor."""
        return self.do_sql_update(row, column, newval) is not None
|
|
|
|
|
|
class TableSubset(GObject.GObject, Gio.ListModel):
    """A list model containing a subset of the rows in the source Table.

    Membership is tracked by ``keyset``.  ``_items`` caches the matching
    Row objects in the table's sort order and is only populated while the
    table is loaded.  ``n_rows`` mirrors ``len(_items)``.
    """

    keyset = GObject.Property(type=KeySet)
    table = GObject.Property(type=Table)
    n_rows = GObject.Property(type=int)

    def __init__(self, table: Table, *, keys: set[any] | None = None):
        """Initialize a TableSubset.

        :param table: the Table whose rows this subset draws from.
        :param keys: initial primary keys in the subset (empty if None).
        """
        super().__init__(keyset=KeySet(set() if keys is None else keys),
                         table=table)
        self._items = []

        self.keyset.connect("key-added", self.__on_key_added)
        self.keyset.connect("key-removed", self.__on_key_removed)
        self.table.connect("notify::loaded", self.__notify_table_loaded)

        # If the table is already loaded no "notify::loaded" will arrive
        # for it, so populate the cache from the initial keys right away.
        self.__notify_table_loaded(self.table, None)

    def __contains__(self, row: Row) -> bool:
        """Check if the Row is in the internal KeySet."""
        return row in self.keyset

    def __bisect(self, key: any) -> int | None:
        # Position of the Row for ``key`` within _items by sort key, or
        # None while the table is not loaded (_items is empty then).
        if self.table.loaded:
            sort_key = self.table.get_sort_key(self.table.rows[key])
            return bisect.bisect_left(self._items, sort_key,
                                      key=self.table.get_sort_key)
        return None

    def __items_changed(self, position: int, removed: int, added: int) -> None:
        # Keep n_rows in sync before telling listeners about the change.
        self.n_rows = len(self._items)
        self.items_changed(position, removed, added)

    def __notify_table_loaded(self, table: Table, param) -> None:
        # Rebuild or drop the cached rows when the table's loaded state
        # changes.  The previous length is reported as "removed" so that
        # items-changed stays accurate even if rows were already cached.
        old_len = len(self._items)
        if table.loaded and self.keyset.n_keys > 0:
            self._items = sorted([table.rows[k] for k in self.keyset.keys],
                                 key=self.table.get_sort_key)
            self.__items_changed(0, old_len, len(self._items))
        elif not table.loaded and old_len > 0:
            self._items = []
            self.__items_changed(0, old_len, 0)

    def __on_key_added(self, keyset: KeySet, key: any) -> None:
        # Insert the new Row at its sorted position (only when loaded).
        if (pos := self.__bisect(key)) is not None:
            self._items.insert(pos, self.table.rows[key])
            self.__items_changed(pos, 0, 1)

    def __on_key_removed(self, keyset: KeySet, key: any) -> None:
        # Drop the Row found at its sorted position (only when loaded).
        if (pos := self.__bisect(key)) is not None:
            del self._items[pos]
            self.__items_changed(pos, 1, 0)

    def do_get_item_type(self) -> GObject.GType:
        """Get the Gio.ListModel item type."""
        return Row.__gtype__

    def do_get_n_items(self) -> int:
        """Get the number of Rows in the TableSubset."""
        return self.n_rows

    def do_get_item(self, n: int) -> Row | None:
        """Get the nth item in the TableSubset, or None if out of range."""
        return self._items[n] if n < len(self._items) else None

    def add_row(self, row: Row) -> None:
        """Add a row to the TableSubset."""
        self.keyset.add_row(row)

    def remove_row(self, row: Row) -> None:
        """Remove a row from the TableSubset."""
        self.keyset.remove_row(row)
|