From 67b508384c17e1a198a82b7b00e5d9aecc3e6d2f Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Mon, 19 Jun 2023 22:41:47 -0400 Subject: [PATCH] db: Create a TableSubset model This is similar to a Gtk.FilterListModel, except we already know exactly which rows are part of the model or not. So we can skip the entire filtering step and show rows directly instead. Signed-off-by: Anna Schumaker --- emmental/db/table.py | 73 +++++++++++++++++++++ tests/db/test_table.py | 144 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/emmental/db/table.py b/emmental/db/table.py index d453b7a..61fa198 100644 --- a/emmental/db/table.py +++ b/emmental/db/table.py @@ -1,5 +1,6 @@ # Copyright 2022 (c) Anna Schumaker """Base classes for database objects.""" +import bisect import sqlite3 from gi.repository import GObject from gi.repository import Gio @@ -271,3 +272,75 @@ class Table(Gtk.FilterListModel): def update(self, row: Row, column: str, newval) -> bool: """Update a Row.""" return self.do_sql_update(row, column, newval) is not None + + +class TableSubset(GObject.GObject, Gio.ListModel): + """A list model containing a subset of the rows in the source Table.""" + + keyset = GObject.Property(type=KeySet) + table = GObject.Property(type=Table) + n_rows = GObject.Property(type=int) + + def __init__(self, table: Table, *, keys: set[any] | None = None): + """Initialize a KeySetModel.""" + super().__init__(keyset=KeySet(set() if keys is None else keys), + table=table) + self._items = [] + + self.keyset.connect("key-added", self.__on_key_added) + self.keyset.connect("key-removed", self.__on_key_removed) + self.table.connect("notify::loaded", self.__notify_table_loaded) + + def __contains__(self, row: Row) -> bool: + """Check if the Row is in the internal KeySet.""" + return row in self.keyset + + def __bisect(self, key: any) -> int | None: + if self.table.loaded: + sort_key = self.table.get_sort_key(self.table.rows[key]) + return bisect.bisect_left(self._items, sort_key, + key=self.table.get_sort_key) + return None + + def __items_changed(self, position: int, removed: int, added: int) -> None: + self.n_rows = len(self._items) + self.items_changed(position, removed, added) + + def __notify_table_loaded(self, table: Table, param) -> None: + if table.loaded and self.keyset.n_keys > 0: + self._items = sorted([table.rows[k] for k in self.keyset.keys], + key=self.table.get_sort_key) + self.__items_changed(0, 0, self.keyset.n_keys) + elif not table.loaded and self.n_rows > 0: + self._items = [] + self.__items_changed(0, self.n_rows, 0) + + def __on_key_added(self, keyset: KeySet, key: any) -> None: + if (pos := self.__bisect(key)) is not None: + self._items.insert(pos, self.table.rows[key]) + self.__items_changed(pos, 0, 1) + + def __on_key_removed(self, keyset: KeySet, key: any) -> None: + if (pos := self.__bisect(key)) is not None: + del self._items[pos] + self.__items_changed(pos, 1, 0) + + def do_get_item_type(self) -> GObject.GType: + """Get the Gio.ListModel item type.""" + return Row.__gtype__ + + def do_get_n_items(self) -> int: + """Get the number of Rows in the TableSubset.""" + return self.n_rows + + def do_get_item(self, n: int) -> int: + """Get the nth item in the TableSubset.""" + return self._items[n] if n < len(self._items) else None + + def add_row(self, row: Row) -> None: + """Add a row to the TableSubset.""" + self.keyset.add_row(row) + + def remove_row(self, row: Row) -> None: + """Remove a row from the TableSubset.""" + self.keyset.remove_row(row) diff --git a/tests/db/test_table.py b/tests/db/test_table.py index 2f990f8..584397c 100644 --- a/tests/db/test_table.py +++ b/tests/db/test_table.py @@ -440,3 +440,147 @@ class TestTableFunctions(tests.util.TestCase): self.table.create(number=3) self.assertFalse(self.table.update(row, "number", 3)) + + +class TestTableSubset(tests.util.TestCase): + """Tests the TableSubset.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.table = tests.util.table.MockTable(self.sql) + self.subset = emmental.db.table.TableSubset(self.table) + self.rows = [self.table.create(number=i) for i in range(5)] + + def test_init(self): + """Test that the TableSubset was set up properly.""" + self.assertIsInstance(self.subset, Gio.ListModel) + self.assertIsInstance(self.subset, GObject.GObject) + self.assertIsInstance(self.subset.keyset, emmental.db.table.KeySet) + self.assertSetEqual(self.subset.keyset.keys, set()) + self.assertEqual(self.subset.table, self.table) + + subset2 = emmental.db.table.TableSubset(self.table, keys={1, 2, 3}) + self.assertSetEqual(subset2.keyset.keys, {1, 2, 3}) + + def test_get_item_type(self): + """Test the Gio.ListModel.get_item_type() function.""" + self.assertEqual(self.subset.get_item_type(), + emmental.db.table.Row.__gtype__) + + def test_get_n_items(self): + """Test the Gio.ListModel.get_n_items() function.""" + self.assertEqual(self.subset.get_n_items(), 0) + self.assertEqual(self.subset.n_rows, 0) + + self.subset.add_row(self.rows[0]) + self.assertEqual(self.subset.get_n_items(), 0) + self.assertEqual(self.subset.n_rows, 0) + + self.table.loaded = True + self.assertEqual(self.subset.get_n_items(), 1) + self.assertEqual(self.subset.n_rows, 1) + + self.table.loaded = False + self.assertEqual(self.subset.get_n_items(), 0) + self.assertEqual(self.subset.n_rows, 0) + + def test_get_item(self): + """Test the Gio.ListModel.get_item() function.""" + for row in self.rows: + self.subset.add_row(row) + + self.assertListEqual(self.subset._items, []) + + for i, row in enumerate(self.rows): + with self.subTest(i=i, row=row.number): + self.assertIsNone(self.subset.get_item(i)) + + self.table.loaded = True + self.assertEqual(self.subset.get_item(i), row) + self.assertEqual(self.subset._items[i], row) + + self.table.loaded = False + self.assertIsNone(self.subset.get_item(i)) + + def test_add_row(self): + """Test adding a row to the TableSubset.""" + expected = set() + self.table.loaded = True + self.assertListEqual(self.subset._items, []) + + changed = unittest.mock.Mock() + self.subset.connect("items-changed", changed) + + for n, i in enumerate([2, 0, 4, 1, 3], start=1): + row = self.rows[i] + with self.subTest(i=i, row=row.number): + expected.add(i) + self.subset.add_row(row) + self.assertSetEqual(self.subset.keyset.keys, expected) + self.assertEqual(self.subset.n_rows, n) + changed.assert_called_with(self.subset, + sorted(expected).index(i), 0, 1) + + self.assertListEqual(self.subset._items, self.rows) + self.assertListEqual(list(self.subset), self.rows) + + def test_remove_row(self): + """Test removing a row from the TableSubset.""" + self.table.loaded = True + [self.subset.add_row(row) for row in self.rows] + expected = {row.number for row in self.rows} + + changed = unittest.mock.Mock() + self.subset.connect("items-changed", changed) + + for n, i in enumerate([2, 0, 4, 1, 3], start=1): + row = self.rows[i] + rm = sorted(expected).index(i) + with self.subTest(i=i, row=row.number): + expected.discard(i) + self.subset.remove_row(row) + self.assertSetEqual(self.subset.keyset.keys, expected) + self.assertEqual(self.subset.n_rows, 5 - n) + changed.assert_called_with(self.subset, rm, 1, 0) + + self.assertEqual(self.subset.n_rows, 0) + + def test_contains(self): + """Test the __contains__() magic method.""" + self.table.loaded = True + self.assertFalse(self.rows[0] in self.subset) + self.subset.add_row(self.rows[0]) + self.assertTrue(self.rows[0] in self.subset) + + def test_table_not_loaded(self): + """Test operations when the table hasn't been loaded.""" + self.subset.add_row(self.rows[0]) + self.assertListEqual(self.subset._items, []) + self.assertEqual(self.subset.n_rows, 0) + self.assertIsNone(self.subset.get_item(0)) + + self.subset.remove_row(self.rows[0]) + self.assertListEqual(self.subset._items, []) + self.assertEqual(self.subset.n_rows, 0) + + def test_table_loaded(self): + """Test changing the value of Table.loaded.""" + changed = unittest.mock.Mock() + self.subset.connect("items-changed", changed) + + self.table.loaded = True + changed.assert_not_called() + self.table.loaded = False + changed.assert_not_called() + + self.subset.add_row(self.rows[0]) + self.subset.add_row(self.rows[1]) + + self.table.loaded = True + self.assertEqual(self.subset.n_rows, 2) + changed.assert_called_with(self.subset, 0, 0, 2) + + self.table.loaded = False + self.assertEqual(self.subset.n_rows, 0) + changed.assert_called_with(self.subset, 0, 2, 0)