From 929beb2a97ba1e1dba232e9fdaf7697286c8f312 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Mon, 19 Jun 2023 13:42:28 -0400 Subject: [PATCH] db: Add set features to the KeySet This includes implementing the __contains__() magic method, and adding signals that are emitted when rows are added, removed, or directly set. This will allow us to build a model around the rows represented by the set. Signed-off-by: Anna Schumaker --- emmental/db/table.py | 56 +++++++++++++++++++++++++---------- tests/db/test_table.py | 67 +++++++++++++++++++++++++++++++++++------- 2 files changed, 97 insertions(+), 26 deletions(-) diff --git a/emmental/db/table.py b/emmental/db/table.py index 07e2d1e..d453b7a 100644 --- a/emmental/db/table.py +++ b/emmental/db/table.py @@ -48,6 +48,10 @@ class KeySet(Gtk.Filter): self._keys = keys self.n_keys = len(keys) if keys is not None else -1 + def __contains__(self, row: Row) -> bool: + """Check if a Row is in the KeySet.""" + return self._keys is None or row.primary_key in self._keys + def __sub__(self, rhs: Gtk.Filter) -> set[int]: """Subtract two KeySets and return the result.""" match (self._keys, rhs._keys): @@ -55,18 +59,22 @@ class KeySet(Gtk.Filter): case (_, None): return self._keys case (_, _): return self._keys - rhs._keys - def __find_change(self, keys: set[any] | None) -> Gtk.FilterChange | None: - if keys == self._keys: - return None - elif keys is None: - return Gtk.FilterChange.LESS_STRICT - elif self._keys is None: - return Gtk.FilterChange.MORE_STRICT - elif keys.issuperset(self._keys): - return Gtk.FilterChange.LESS_STRICT - elif keys.issubset(self._keys): - return Gtk.FilterChange.MORE_STRICT - return Gtk.FilterChange.DIFFERENT + def __find_difference(self, new: set[any] | None) \ + -> tuple[set, set, Gtk.FilterChange | None]: + if self._keys is None: + if new is None: + return (set(), set(), None) + return (set(), new, Gtk.FilterChange.MORE_STRICT) + elif new is None: + return (self._keys, set(), Gtk.FilterChange.LESS_STRICT) + + removed = self._keys - new + added = new - self._keys + match len(removed), len(added): + case 0, 0: return (removed, added, None) + case _, 0: return (removed, added, Gtk.FilterChange.MORE_STRICT) + case 0, _: return (removed, added, Gtk.FilterChange.LESS_STRICT) + case _, _: return (removed, added, Gtk.FilterChange.DIFFERENT) def changed(self, how: Gtk.FilterChange) -> None: """Notify that the KeySet has changed.""" @@ -87,14 +95,16 @@ class KeySet(Gtk.Filter): def add_row(self, row: Row) -> None: """Add a Row to the KeySet.""" - if self._keys is not None: + if row not in self: self._keys.add(row.primary_key) + self.emit("key-added", row.primary_key) self.changed(Gtk.FilterChange.LESS_STRICT) def remove_row(self, row: Row) -> None: """Remove a Row from the KeySet.""" - if self._keys is not None: + if self._keys is not None and row in self: self._keys.discard(row.primary_key) + self.emit("key-removed", row.primary_key) self.changed(Gtk.FilterChange.MORE_STRICT) @property @@ -105,9 +115,23 @@ class KeySet(Gtk.Filter): @keys.setter def keys(self, keys: set[any] | None) -> None: """Set the matching primary keys.""" - if (how := self.__find_change(keys)) is not None: + (removed, added, change) = self.__find_difference(keys) + if change is not None: self._keys = keys - self.changed(how) + self.emit("keys-changed", removed, added) + self.changed(change) + + @GObject.Signal(arg_types=(int,)) + def key_added(self, key: int) -> None: + """Signal that a Row has been added to the KeySet.""" + + @GObject.Signal(arg_types=(int,)) + def key_removed(self, key: int) -> None: + """Signal that a Row has been removed from the KeySet.""" + + @GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT, GObject.TYPE_PYOBJECT)) + def keys_changed(self, removed: set | None, added: set | None) -> None: + """Signal that the KeySet has been directly modified.""" class Table(Gtk.FilterListModel): diff --git a/tests/db/test_table.py b/tests/db/test_table.py index 5ed073b..2f990f8 100644 --- a/tests/db/test_table.py +++ b/tests/db/test_table.py @@ -91,78 +91,125 @@ class TestKeySet(unittest.TestCase): def test_add_row(self, mock_changed: unittest.mock.Mock): """Test adding Rows to the KeySet.""" + mock_added = unittest.mock.Mock() + self.keyset.connect("key-added", mock_added) + self.keyset.add_row(self.row1) self.assertIsNone(self.keyset.keys) + mock_added.assert_not_called() self.keyset.keys = set() self.keyset.add_row(self.row1) self.assertSetEqual(self.keyset.keys, {1}) - mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) self.assertEqual(self.keyset.n_keys, 1) + mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) + mock_added.assert_called_with(self.keyset, 1) self.keyset.add_row(self.row2) self.assertSetEqual(self.keyset.keys, {1, 2}) - mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) self.assertEqual(self.keyset.n_keys, 2) + mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) + mock_added.assert_called_with(self.keyset, 2) + + mock_changed.reset_mock() + mock_added.reset_mock() + self.keyset.add_row(self.row2) + self.assertSetEqual(self.keyset.keys, {1, 2}) + mock_changed.assert_not_called() + mock_added.assert_not_called() def test_remove_row(self, mock_changed: unittest.mock.Mock): """Test removing Rows from the KeySet.""" + mock_removed = unittest.mock.Mock() + self.keyset.connect("key-removed", mock_removed) + self.keyset.remove_row(self.row1) mock_changed.assert_not_called() + mock_removed.assert_not_called() self.keyset.keys = {1, 2} self.keyset.remove_row(self.row1) self.assertSetEqual(self.keyset._keys, {2}) - mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) self.assertEqual(self.keyset.n_keys, 1) + mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) + mock_removed.assert_called_with(self.keyset, 1) mock_changed.reset_mock() + mock_removed.reset_mock() + self.keyset.remove_row(self.row1) + self.assertSetEqual(self.keyset.keys, {2}) + self.assertEqual(self.keyset.n_keys, 1) + mock_changed.assert_not_called() + mock_removed.assert_not_called() + self.keyset.remove_row(self.row2) self.assertSetEqual(self.keyset._keys, set()) - mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) self.assertEqual(self.keyset.n_keys, 0) + mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) + mock_removed.assert_called_with(self.keyset, 2) def test_keys(self, mock_changed: unittest.mock.Mock): - """Test setting and getting the KeySet keys property.""" + """Test getting and setting the KeySet.keys property.""" + mock_keys_changed = unittest.mock.Mock() + self.keyset.connect("keys-changed", mock_keys_changed) + self.assertIsNone(self.keyset.keys) + self.keyset.keys = None + self.assertIsNone(self.keyset.keys) + mock_changed.assert_not_called() + mock_keys_changed.assert_not_called() self.keyset.keys = {1, 2, 3} self.assertSetEqual(self.keyset._keys, {1, 2, 3}) - mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) self.assertEqual(self.keyset.n_keys, 3) + mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) + mock_keys_changed.assert_called_with(self.keyset, set(), {1, 2, 3}) mock_changed.reset_mock() self.keyset.keys = {1, 2} self.assertSetEqual(self.keyset.keys, {1, 2}) - mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) self.assertEqual(self.keyset.n_keys, 2) + mock_changed.assert_called_with(Gtk.FilterChange.MORE_STRICT) + mock_keys_changed.assert_called_with(self.keyset, {3}, set()) mock_changed.reset_mock() + mock_keys_changed.reset_mock() self.keyset.keys = {1, 2} mock_changed.assert_not_called() + mock_keys_changed.assert_not_called() self.keyset.keys = {1, 2, 3} self.assertSetEqual(self.keyset.keys, {1, 2, 3}) mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) + mock_keys_changed.assert_called_with(self.keyset, set(), {3}) self.keyset.keys = {4, 5, 6} self.assertSetEqual(self.keyset._keys, {4, 5, 6}) mock_changed.assert_called_with(Gtk.FilterChange.DIFFERENT) + mock_keys_changed.assert_called_with(self.keyset, {1, 2, 3}, {4, 5, 6}) self.keyset.keys = None self.assertIsNone(self.keyset._keys) - mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) self.assertEqual(self.keyset.n_keys, -1) + mock_changed.assert_called_with(Gtk.FilterChange.LESS_STRICT) + mock_keys_changed.assert_called_with(self.keyset, {4, 5, 6}, set()) - def test_match(self, mock_changed: unittest.mock.Mock): - """Test matching Rows.""" + def test_match_contains(self, mock_changed: unittest.mock.Mock): + """Test matching Rows and the __contains__() magic method.""" self.assertTrue(self.keyset.match(self.row1)) + self.assertTrue(self.row1 in self.keyset) + self.keyset.keys = {1, 2, 3} self.assertTrue(self.keyset.match(self.row1)) + self.assertTrue(self.row1 in self.keyset) + self.keyset.keys = {4, 5, 6} self.assertFalse(self.keyset.match(self.row1)) + self.assertFalse(self.row1 in self.keyset) + self.keyset.keys = set() self.assertFalse(self.keyset.match(self.row1)) + self.assertFalse(self.row1 in self.keyset) class TestTable(tests.util.TestCase):