# emmental/emmental/db/table.py
#
# 348 lines
# 12 KiB
# Python

# Copyright 2022 (c) Anna Schumaker
"""Base classes for database objects."""
import bisect
import sqlite3
from gi.repository import GObject
from gi.repository import Gio
from gi.repository import Gtk
from .idle import Queue
from .. import store
class Row(GObject.GObject):
    """One row of a database table, exposed as a GObject.

    Any property change other than `table` is forwarded to do_update(),
    which hands the new value to the owning table.
    """

    table = GObject.Property(type=Gio.ListModel)

    def __init__(self, table: Gio.ListModel, **kwargs):
        """Create the Row and start watching its own property changes."""
        super().__init__(table=table, **kwargs)
        self.connect("notify", self.__notify)

    def __notify(self, row: GObject.GObject, param: GObject.ParamSpec) -> None:
        # The owning table is bookkeeping, not a column, so changes to it
        # are never forwarded as an update.
        if param.name != "table":
            self.do_update(param.name)

    def do_update(self, column: str) -> bool:
        """Send the current value of `column` to the owning table.

        Returns whatever the table's update() reports.
        """
        update = self.table.update
        return update(self, column, self.get_property(column))

    def delete(self) -> bool:
        """Ask the owning table to delete this Row."""
        owner = self.table
        return owner.delete(self)

    @property
    def primary_key(self) -> None:
        """Get the primary key for this row.

        The base class has no key; reading it raises NotImplementedError.
        """
        raise NotImplementedError
class KeySet(Gtk.Filter):
    """A Gtk.Filter backed by a Python set of primary keys.

    A key set of None matches every Row (and n_keys is reported as -1).
    Otherwise only Rows whose primary key is in the set match.
    """

    n_keys = GObject.Property(type=int)

    def __init__(self, keys: set | None = None, **kwargs):
        """Create the KeySet with an initial set of keys (or None)."""
        super().__init__(**kwargs)
        self._keys = keys
        self.n_keys = -1 if keys is None else len(keys)

    def __contains__(self, row: Row) -> bool:
        """Check if a Row matches this KeySet."""
        if self._keys is None:
            return True
        return row.primary_key in self._keys

    def __sub__(self, rhs: Gtk.Filter) -> set[int]:
        """Subtract another KeySet's keys from ours and return the result.

        Returns None when our own key set is None.
        """
        mine, theirs = self._keys, rhs._keys
        if mine is None:
            return None
        if theirs is None:
            # NOTE(review): a None right-hand side leaves our keys
            # untouched rather than producing an empty set -- confirm
            # that this is the intended meaning.
            return mine
        return mine - theirs

    def __find_difference(self, new: set[any] | None) \
            -> tuple[set, set, Gtk.FilterChange | None]:
        # Work out (removed, added, change) for replacing our keys with
        # `new`.  `change` is None when nothing actually differs.
        current = self._keys
        if current is None:
            if new is None:
                return (set(), set(), None)
            return (set(), new, Gtk.FilterChange.MORE_STRICT)
        if new is None:
            return (current, set(), Gtk.FilterChange.LESS_STRICT)

        removed = current - new
        added = new - current
        if removed and added:
            change = Gtk.FilterChange.DIFFERENT
        elif removed:
            change = Gtk.FilterChange.MORE_STRICT
        elif added:
            change = Gtk.FilterChange.LESS_STRICT
        else:
            change = None
        return (removed, added, change)

    def changed(self, how: Gtk.FilterChange) -> None:
        """Refresh n_keys, then notify that the KeySet has changed."""
        current = self._keys
        self.n_keys = -1 if current is None else len(current)
        super().changed(how)

    def do_get_strictness(self) -> Gtk.FilterMatch:
        """Report whether this filter matches all, none, or some Rows."""
        if self._keys is None:
            return Gtk.FilterMatch.ALL
        if not self._keys:
            return Gtk.FilterMatch.NONE
        return Gtk.FilterMatch.SOME

    def do_match(self, row: Row) -> bool:
        """Check if the Row passes the filter."""
        if self._keys is None:
            return True
        return row.primary_key in self._keys

    def add_row(self, row: Row) -> None:
        """Add a Row's primary key to the KeySet.

        Does nothing when the Row already matches, which includes the
        match-everything case where the key set is None.
        """
        if row in self:
            return
        key = row.primary_key
        self._keys.add(key)
        self.emit("key-added", key)
        self.changed(Gtk.FilterChange.LESS_STRICT)

    def remove_row(self, row: Row) -> None:
        """Remove a Row's primary key from the KeySet."""
        if self._keys is None or row not in self:
            return
        key = row.primary_key
        self._keys.discard(key)
        self.emit("key-removed", key)
        self.changed(Gtk.FilterChange.MORE_STRICT)

    @property
    def keys(self) -> set[any]:
        """Return the set of matching primary keys (None matches all)."""
        return self._keys

    @keys.setter
    def keys(self, keys: set[any] | None) -> None:
        """Replace the matching primary keys.

        Emits "keys-changed" and a filter change only when the new keys
        differ from the current ones.
        """
        removed, added, change = self.__find_difference(keys)
        if change is None:
            return
        self._keys = keys
        self.emit("keys-changed", removed, added)
        self.changed(change)

    @GObject.Signal(arg_types=(int,))
    def key_added(self, key: int) -> None:
        """Signal that a single key has been added to the KeySet."""

    @GObject.Signal(arg_types=(int,))
    def key_removed(self, key: int) -> None:
        """Signal that a single key has been removed from the KeySet."""

    @GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT, GObject.TYPE_PYOBJECT))
    def keys_changed(self, removed: set | None, added: set | None) -> None:
        """Signal that the whole key set has been replaced."""
class Table(Gtk.FilterListModel):
    """A filterable list model that represents a database Table.

    Rows are kept twice: in `store` (a sorted list model that backs this
    FilterListModel) and in `rows` (a dict keyed by primary key).
    Subclasses supply the SQL through the do_sql_*() hooks and build Row
    objects in do_construct().
    """

    sql = GObject.Property(type=GObject.TYPE_PYOBJECT)
    queue = GObject.Property(type=Queue)
    store = GObject.Property(type=Gio.ListModel)
    rows = GObject.Property(type=GObject.TYPE_PYOBJECT)
    loaded = GObject.Property(type=bool, default=False)

    def __init__(self, sql: GObject.TYPE_PYOBJECT,
                 filter: KeySet | None = None,
                 queue: Queue | None = None, **kwargs):
        """Set up our Table object.

        A fresh KeySet / Queue is created only when the caller passes
        None, so a caller-supplied object is never silently replaced.
        """
        super().__init__(sql=sql, rows=dict(),
                         store=store.SortedList(self.get_sort_key),
                         filter=(KeySet() if filter is None else filter),
                         queue=(Queue() if queue is None else queue),
                         **kwargs)
        self.set_model(self.store)

    def __clear_rows(self) -> None:
        # Drop every cached Row and mark the table as not loaded.
        self.rows.clear()
        self.store.clear()
        self.loaded = False

    def __contains__(self, row: Row) -> bool:
        """Check if the Row belongs to this Table and is in its store."""
        return self.index(row) is not None

    def do_construct(self, *args, **kwargs) -> Row:
        """Construct a new Row instance."""
        raise NotImplementedError

    def do_get_sort_key(self, row: Row) -> any:
        """Get a sort key for the requested row.

        Returning None makes get_sort_key() fall back to the primary key.
        """
        return None

    def do_sql_delete(self, row: Row) -> sqlite3.Cursor:
        """Delete a Row.

        delete() reads `.rowcount` from the returned value.
        """
        raise NotImplementedError

    def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
        """Select matching rowids using GLOB."""
        raise NotImplementedError

    def do_sql_insert(self, *args, **kwargs) -> sqlite3.Cursor:
        """Create a new Row."""
        raise NotImplementedError

    def do_sql_select_all(self) -> sqlite3.Cursor:
        """Return all rows from the table."""
        raise NotImplementedError

    def do_sql_select_one(self, *args, **kwargs) -> sqlite3.Cursor:
        """Look up a single row."""
        raise NotImplementedError

    def do_sql_update(self, row: Row, column: str, newval) -> sqlite3.Cursor:
        """Update a row."""
        raise NotImplementedError

    def clear(self) -> None:
        """Stop background work and drop every cached Row."""
        self.stop()
        self.__clear_rows()

    def construct(self, *args, **kwargs) -> Row:
        """Construct a new Row instance owned by this Table."""
        return self.do_construct(*args, table=self, **kwargs)

    def create(self, *args, **kwargs) -> Row | None:
        """Create a new Row in the Table.

        Returns the inserted Row, or None when the insert produced no
        cursor or the cursor produced no row.
        """
        if not (cur := self.do_sql_insert(*args, **kwargs)):
            return None
        # Unpacking None with ** would raise TypeError, so check first.
        if (values := cur.fetchone()) is None:
            return None
        return self.insert(self.construct(**values))

    def delete(self, row: Row) -> bool:
        """Delete a Row from the Table.

        Returns True only when the Row was in this Table and exactly one
        database row was deleted.
        """
        if row in self and self.do_sql_delete(row).rowcount == 1:
            self.sql.commit()
            self.store.remove(row)
            del self.rows[row.primary_key]
            return True
        return False

    def _filter_idle(self, glob: str) -> bool:
        # The first column of every matching row becomes a filter key.
        rows = self.do_sql_glob(glob).fetchall()
        self.get_filter().keys = {row[0] for row in rows}
        return True

    def filter(self, glob: str | None, *, now: bool = False) -> None:
        """Filter the displayed Rows.

        A glob of None clears the filter so that every Row matches.
        """
        # Cancel a pending filter task in both cases.  Otherwise a stale
        # glob queued earlier could run after the filter was cleared and
        # silently re-apply itself.
        self.queue.cancel_task(self._filter_idle)
        if glob is None:
            self.get_filter().keys = None
        else:
            self.queue.push(self._filter_idle, glob, now=now, first=True)

    def get_sort_key(self, row: Row) -> any:
        """Get a sort key for the row, defaulting to its primary key."""
        res = self.do_get_sort_key(row)
        return res if res is not None else row.primary_key

    def index(self, row: Row) -> int | None:
        """Find the index of a specific Row.

        Returns None for a Row that belongs to a different table.
        """
        if row.table is self:
            return self.store.index(row)
        return None

    def insert(self, row: Row) -> Row | None:
        """Insert a Row in sorted position.

        Returns None when the Row is falsy or already in the Table.
        """
        if row and row not in self:
            self.store.append(row)
            return self.rows.setdefault(row.primary_key, row)
        return None

    def _load_idle(self) -> bool:
        self.__clear_rows()
        cur = self.do_sql_select_all()
        rows = [self.construct(**row) for row in cur.fetchall()]
        self.store.extend(rows)
        self.rows = {row.primary_key: row for row in rows}
        # NOTE(review): `loaded` is not set to True here; presumably a
        # "table-loaded" handler does that -- confirm.
        self.sql.emit("table-loaded", self)
        return True

    def load(self, *, now: bool = False) -> None:
        """Queue a (re)load of the Table from the database."""
        self.queue.push(self._load_idle, now=now)

    def lookup(self, *args, **kwargs) -> Row | None:
        """Look up a Row in the database.

        The first column of the selected row is used as the key into
        the cached `rows`; returns None when nothing matches.
        """
        row = self.do_sql_select_one(*args, **kwargs).fetchone()
        return self.rows.get(row[0]) if row else None

    def stop(self) -> None:
        """Stop any background work."""
        self.queue.cancel()

    def update(self, row: Row, column: str, newval) -> bool:
        """Update a Row; True when do_sql_update() returned a value."""
        return self.do_sql_update(row, column, newval) is not None
class TableSubset(GObject.GObject, Gio.ListModel):
    """A list model containing a subset of the rows in the source Table.

    Membership is tracked by `keyset`.  The matching Rows are mirrored in
    the private `_items` list, kept in the Table's sort order, but only
    while the Table reports itself as loaded.
    """

    keyset = GObject.Property(type=KeySet)
    table = GObject.Property(type=Table)
    n_rows = GObject.Property(type=int)

    def __init__(self, table: Table, *, keys: set[any] | None = None):
        """Initialize a TableSubset, optionally with an initial key set."""
        super().__init__(keyset=KeySet(set() if keys is None else keys),
                         table=table)
        # NOTE(review): `_items` starts empty even if the table is
        # already loaded; it is only filled by a later "notify::loaded"
        # -- confirm that callers never rely on immediate population.
        self._items = []
        self.keyset.connect("key-added", self.__on_key_added)
        self.keyset.connect("key-removed", self.__on_key_removed)
        self.table.connect("notify::loaded", self.__notify_table_loaded)

    def __contains__(self, row: Row) -> bool:
        """Check if the Row is in the internal KeySet."""
        return row in self.keyset

    def __bisect(self, key: any) -> int | None:
        # Leftmost sorted position for the Row with this primary key, or
        # None while the table is not loaded.
        if self.table.loaded:
            sort_key = self.table.get_sort_key(self.table.rows[key])
            return bisect.bisect_left(self._items, sort_key,
                                      key=self.table.get_sort_key)
        return None

    def __find_position(self, key: any) -> int | None:
        # Position of the Row with this primary key in `_items`, or None
        # when the table is not loaded or the Row is not listed.
        if (pos := self.__bisect(key)) is None:
            return None
        if pos < len(self._items) and self._items[pos].primary_key == key:
            return pos
        # bisect_left only finds the leftmost entry with an equal sort
        # key, and sort keys may be duplicated or may have changed since
        # insertion, so fall back to a linear scan by primary key.
        for pos, item in enumerate(self._items):
            if item.primary_key == key:
                return pos
        return None

    def __items_changed(self, position: int, removed: int, added: int) -> None:
        # Keep n_rows in step with `_items` before notifying listeners.
        self.n_rows = len(self._items)
        self.items_changed(position, removed, added)

    def __notify_table_loaded(self, table: Table, param) -> None:
        # Report the previous length as `removed` so the model stays
        # consistent even when this fires while items are already listed.
        previous = len(self._items)
        if table.loaded and self.keyset.n_keys > 0:
            self._items = sorted([table.rows[k] for k in self.keyset.keys],
                                 key=self.table.get_sort_key)
            self.__items_changed(0, previous, len(self._items))
        elif not table.loaded and previous > 0:
            self._items = []
            self.__items_changed(0, previous, 0)

    def __on_key_added(self, keyset: KeySet, key: any) -> None:
        if (pos := self.__bisect(key)) is not None:
            self._items.insert(pos, self.table.rows[key])
            self.__items_changed(pos, 0, 1)

    def __on_key_removed(self, keyset: KeySet, key: any) -> None:
        # Only delete an entry that really is the removed Row.
        if (pos := self.__find_position(key)) is not None:
            del self._items[pos]
            self.__items_changed(pos, 1, 0)

    def do_get_item_type(self) -> GObject.GType:
        """Get the Gio.ListModel item type."""
        return Row.__gtype__

    def do_get_n_items(self) -> int:
        """Get the number of Rows in the TableSubset."""
        return self.n_rows

    def do_get_item(self, n: int) -> Row | None:
        """Get the nth Row in the TableSubset, or None past the end."""
        return self._items[n] if n < len(self._items) else None

    def add_row(self, row: Row) -> None:
        """Add a row to the TableSubset."""
        self.keyset.add_row(row)

    def remove_row(self, row: Row) -> None:
        """Remove a row from the TableSubset."""
        self.keyset.remove_row(row)