From 61fc2521721dead2d4709884cb1c02a7a53e7fe8 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Tue, 3 Jan 2023 15:14:09 -0500 Subject: [PATCH] store: Add a SortedList store This ListStore implementation uses a key function to keep the items sorted at all times. Signed-off-by: Anna Schumaker --- emmental/store.py | 39 +++++++++++++++++ tests/test_store.py | 100 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+) diff --git a/emmental/store.py b/emmental/store.py index 205eb09..864188b 100644 --- a/emmental/store.py +++ b/emmental/store.py @@ -1,5 +1,6 @@ # Copyright 2023 (c) Anna Schumaker. """A Python-based ListStore implementation.""" +import bisect import typing from gi.repository import GObject from gi.repository import Gio @@ -84,3 +85,41 @@ class ListStore(GObject.GObject, Gio.ListModel): if (index := self.index(item)) is not None: return self.pop(index) is not None return False + + +class SortedList(ListStore): + """A ListStore that keeps objects in a sorted order.""" + + def __init__(self, key_func: typing.Callable): + """Initialize a SortedList.""" + super().__init__() + self.key_func = key_func + + def __bisect(self, item: GObject.GObject) -> tuple[bool, int]: + item_key = self.key_func(item) + pos = bisect.bisect_left(self.items, item_key, key=self.key_func) + if pos < self.n_items: + cur_key = self.key_func(self.items[pos]) + return (item_key == cur_key, pos) + return (False, pos) + + def append(self, item: GObject.GObject) -> bool: + """Add an item to the list.""" + (found, pos) = self.__bisect(item) + return super().insert(pos, item) if not found else False + + def extend(self, items: typing.Iterable) -> None: + """Add multiple items to the list.""" + self.items.extend(items) + if len(self.items) != self.n_items: + self.items.sort(key=self.key_func) + self.items_changed(0, self.n_items, len(self.items)) + + def index(self, item: GObject.GObject) -> int | None: + """Find the index of an item in the list.""" + (found, pos) = self.__bisect(item) + return pos if found else None + + def insert(self, index: int, item: GObject.GObject) -> bool: + """Insert an item into the list.""" + return self.append(item) diff --git a/tests/test_store.py b/tests/test_store.py index 6597330..017435f 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -156,3 +156,103 @@ class TestListStore(unittest.TestCase): self.changed = (None, None, None) self.assertFalse(self.store.remove(objs[0])) self.assertTupleEqual(self.changed, (None, None, None)) + + +class TestSortedList(unittest.TestCase): + """Test case for our SortedList implementation.""" + + def items_changed(self, store: emmental.store.SortedList, + pos: int, removed: int, added: int) -> None: + """Handle the items-changed signal.""" + self.changed = (pos, removed, added) + + def key_func(self, obj: Object) -> int: + """Get a sort key for Objects.""" + return obj.value + + def setUp(self): + """Set up common variables.""" + self.store = emmental.store.SortedList(self.key_func) + self.store.connect("items-changed", self.items_changed) + self.changed = (None, None, None) + + def test_init(self): + """Check that the SortedList is initialized properly.""" + self.assertIsInstance(self.store, emmental.store.ListStore) + self.assertEqual(self.store.key_func, self.key_func) + + def test_append(self): + """Test the SortedList's append() function.""" + self.assertTrue(self.store.append(Object(value=2))) + self.assertTupleEqual(self.changed, (0, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [2]) + + self.assertTrue(self.store.append(Object(value=0))) + self.assertTupleEqual(self.changed, (0, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 2]) + + self.assertTrue(self.store.append(Object(value=1))) + self.assertTupleEqual(self.changed, (1, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 1, 2]) + + self.changed = (None, None, None) + for value in [0, 1, 2]: + with self.subTest(value=value): + self.assertFalse(self.store.append(Object(value=value))) + self.assertTupleEqual(self.changed, (None, None, None)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 1, 2]) + + def test_extend(self): + """Test adding multiple values to the SortedList.""" + self.store.extend([]) + self.assertTupleEqual(self.changed, (None, None, None)) + self.assertListEqual(self.store.items, []) + + self.store.extend([Object(value=1), Object(value=3)]) + self.assertTupleEqual(self.changed, (0, 0, 2)) + self.assertListEqual([obj.value for obj in self.store.items], + [1, 3]) + + self.store.extend([Object(value=0), Object(value=2), Object(value=4)]) + self.assertTupleEqual(self.changed, (0, 2, 5)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 1, 2, 3, 4]) + + def test_index(self): + """Test finding the index of items in the SortedList.""" + objs = [Object(value=i) for i in range(5)] + self.store.extend(objs) + + for i, obj in enumerate(objs): + with self.subTest(i=i): + self.assertEqual(self.store.index(obj), i) + self.assertIsNone(self.store.index(Object(value=10))) + + def test_insert(self): + """Test the SortedList's insert() function.""" + self.assertTrue(self.store.insert(2, Object(value=2))) + self.assertTupleEqual(self.changed, (0, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [2]) + + self.assertTrue(self.store.insert(1, Object(value=0))) + self.assertTupleEqual(self.changed, (0, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 2]) + + self.assertTrue(self.store.insert(0, Object(value=1))) + self.assertTupleEqual(self.changed, (1, 0, 1)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 1, 2]) + + self.changed = (None, None, None) + for i, obj in enumerate(self.store): + with self.subTest(i=i): + self.assertFalse(self.store.insert(0, obj)) + self.assertTupleEqual(self.changed, (None, None, None)) + self.assertListEqual([obj.value for obj in self.store.items], + [0, 1, 2])