From 4ce571ebf8800628656dd59e71cf7d26c4e07a3a Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Sat, 15 Apr 2023 13:31:15 -0400 Subject: [PATCH] playlist: Create a TrackidModel Gio.ListModel The TrackidModel takes a TrackidSet and presents it as a Gio.ListModel that maps trackids into Track objects. Tracks can be found by value using the bisect() function, which sorts the trackids by number by default (this can be changed by overriding the do_get_sort_key() function). Signed-off-by: Anna Schumaker --- emmental/playlist/model.py | 103 +++++++++++++++++++ tests/playlist/test_model.py | 192 +++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 emmental/playlist/model.py create mode 100644 tests/playlist/test_model.py diff --git a/emmental/playlist/model.py b/emmental/playlist/model.py new file mode 100644 index 0000000..fef44eb --- /dev/null +++ b/emmental/playlist/model.py @@ -0,0 +1,103 @@ +# Copyright 2023 (c) Anna Schumaker. +"""Converts a TrackidSet into a Gio.ListModel.""" +import bisect +from gi.repository import GObject +from gi.repository import Gio +from .. import db + + +class TrackidModel(GObject.GObject, Gio.ListModel): + """A Gio.ListModel representing a TrackidSet.""" + + sql = GObject.Property(type=db.Connection) + n_tracks = GObject.Property(type=int) + + def __init__(self, sql: db.Connection): + """Initialize the TrackidModel.""" + super().__init__(sql=sql) + self.__trackid_set = None + self.trackids = [] + + def bisect(self, trackid: int) -> tuple[bool, int | None]: + """Bisect the TrackidModel for the given trackid.""" + pos = bisect.bisect_left(self.trackids, + self.do_get_sort_key(trackid), + key=self.do_get_sort_key) + + if pos < len(self.trackids): + return (self.trackids[pos] == trackid, pos) + return (False, pos) + + def do_get_item_type(self) -> GObject.GType: + """Get the item type of this Model.""" + return db.tracks.Track.__gtype__ + + def do_get_n_items(self) -> int: + """Get the number of items in the list.""" + return len(self.trackids) + + def do_get_item(self, n: int) -> db.tracks.Track | None: + """Get the n-th item in the list.""" + if n < len(self.trackids): + return self.sql.tracks.rows.get(self.trackids[n]) + + def do_get_sort_key(self, trackid: int) -> int: + """Get a stort key for the given trackid.""" + return trackid + + def do_items_changed(self, *, position: int, + removed: int, added: int) -> None: + """Emit the ::items-changed signal.""" + self.n_tracks = len(self.trackids) + self.items_changed(position, removed, added) + + def index(self, trackid: int) -> int | None: + """Find the index of a specific trackid.""" + (has, pos) = self.bisect(trackid) + return pos if has else None + + def on_trackid_added(self, set: db.tracks.TrackidSet, + trackid: int) -> None: + """Respond to the trackid-added signal.""" + (has, pos) = self.bisect(trackid) + if not has: + self.trackids.insert(pos, trackid) + self.do_items_changed(position=pos, removed=0, added=1) + + def on_trackid_removed(self, set: db.tracks.TrackidSet, + trackid: int) -> None: + """Respond to the trackid-removed signal.""" + (has, pos) = self.bisect(trackid) + if has: + del self.trackids[pos] + self.do_items_changed(position=pos, removed=1, added=0) + + def on_trackids_reset(self, set: db.tracks.TrackidSet) -> None: + """Respond to the trackids-reset signal.""" + self.trackids = sorted(set.trackids, key=self.do_get_sort_key) + self.do_items_changed(position=0, removed=self.n_tracks, + added=len(self.trackids)) + + @GObject.Property(type=db.tracks.TrackidSet) + def trackid_set(self) -> db.tracks.TrackidSet | None: + """Get the current trackid-set.""" + return self.__trackid_set + + @trackid_set.setter + def trackid_set(self, new: db.tracks.TrackidSet | None) -> None: + """Set a new value to the trackid-set property.""" + if self.__trackid_set is not None: + self.__trackid_set.disconnect_by_func(self.on_trackid_added) + self.__trackid_set.disconnect_by_func(self.on_trackid_removed) + self.__trackid_set.disconnect_by_func(self.on_trackids_reset) + self.trackids = [] + + self.__trackid_set = new + if new is not None: + new.connect("trackid-added", self.on_trackid_added) + new.connect("trackid-removed", self.on_trackid_removed) + new.connect("trackids-reset", self.on_trackids_reset) + self.trackids = sorted(new.trackids, key=self.do_get_sort_key) + + self.do_items_changed(position=0, removed=self.n_tracks, + added=len(self.trackids)) diff --git a/tests/playlist/test_model.py b/tests/playlist/test_model.py new file mode 100644 index 0000000..6031275 --- /dev/null +++ b/tests/playlist/test_model.py @@ -0,0 +1,192 @@ +# Copyright 2023 (c) Anna Schumaker. +"""Tests our TrackidModel.""" +import pathlib +import unittest.mock +import tests.util +import emmental.playlist.model +from gi.repository import Gio + + +class TestTrackidModel(tests.util.TestCase): + """Tests the Trackid Model.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.db_plist = self.sql.playlists.create("Test Playlist") + self.model = emmental.playlist.model.TrackidModel(self.sql) + + self.library = self.sql.libraries.create(pathlib.Path("/a/b")) + self.album = self.sql.albums.create("Test Album", "Artist", "2023") + self.medium = self.sql.media.create(self.album, "", number=1) + self.year = self.sql.years.create(2023) + + self.track1 = self.sql.tracks.create(self.library, + pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, number=1) + self.track2 = self.sql.tracks.create(self.library, + pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, number=2) + + def test_init(self): + """Test that the TrackidModel was set up correctly.""" + self.assertIsInstance(self.model, Gio.ListModel) + self.assertEqual(self.model.sql, self.sql) + self.assertIsNone(self.model._TrackidModel__trackid_set) + + def test_bisect(self): + """Test the TrackidModel bisect() function.""" + self.assertTupleEqual(self.model.bisect(self.track1.trackid), + (False, 0)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.assertTupleEqual(self.model.bisect(self.track1.trackid), + (True, 0)) + self.assertTupleEqual(self.model.bisect(self.track2.trackid), + (False, 1)) + + self.db_plist.add_track(self.track2) + self.assertTupleEqual(self.model.bisect(self.track2.trackid), + (True, 1)) + + def test_get_item_type(self): + """Test the Gio.ListModel:get_item_type() function.""" + self.assertEqual(self.model.get_item_type(), + emmental.db.tracks.Track.__gtype__) + + def test_get_n_items(self): + """Test the Gio.ListModel:get_n_items() function.""" + self.model.trackid_set = self.db_plist.tracks + self.assertEqual(self.model.get_n_items(), 0) + self.db_plist.add_track(self.track1) + self.assertEqual(self.model.get_n_items(), 1) + self.db_plist.add_track(self.track2) + self.assertEqual(self.model.get_n_items(), 2) + + def test_get_item(self): + """Test the Gio.ListModel:get_item() function.""" + self.assertIsNone(self.model.get_item(0)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + + self.assertEqual(self.model.get_item(0), self.track1) + self.assertEqual(self.model.get_item(1), self.track2) + self.assertIsNone(self.model.get_item(2)) + + def test_index(self): + """Test finding the index of a specific trackid.""" + self.assertIsNone(self.model.index(self.track1.trackid)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + + self.assertEqual(self.model.index(self.track1.trackid), 0) + self.assertEqual(self.model.index(self.track2.trackid), 1) + self.assertIsNone(self.model.index(self.track2.trackid + 1)) + + def test_trackid_added(self): + """Test that the TrackidModel responds to the trackid-added signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.model.trackid_set = self.db_plist.tracks + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + + self.db_plist.add_track(self.track2) + self.assertListEqual(self.model.trackids, [self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 1) + items_changed.assert_called_with(self.model, 0, 0, 1) + + self.db_plist.add_track(self.track1) + self.assertListEqual(self.model.trackids, + [self.track1.trackid, self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 1) + + self.model.trackid_set = None + self.db_plist.tracks.trackids.clear() + items_changed.reset_mock() + + self.db_plist.tracks.add_track(self.track1) + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_not_called() + + def test_trackid_removed(self): + """Test that the TrackModel responds to the trackid-removed signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.model.trackid_set = self.db_plist.tracks + + self.db_plist.remove_track(self.track2) + self.assertListEqual(self.model.trackids, [self.track1.trackid]) + self.assertEqual(self.model.n_tracks, 1) + items_changed.assert_called_with(self.model, 1, 1, 0) + + self.db_plist.remove_track(self.track1) + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_called_with(self.model, 0, 1, 0) + + self.model.trackid_set = None + self.model.trackids = [self.track1.trackid] + self.db_plist.tracks.trackids = {self.track1.trackid} + items_changed.reset_mock() + + self.db_plist.tracks.remove_track(self.track1) + self.assertListEqual(self.model.trackids, [self.track1.trackid]) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_not_called() + + def test_trackids_reset(self): + """Test that the TrackModel responds to the trackids-reset signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.assertListEqual(self.model.trackids, [self.track1.trackid, + self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 2) + + self.model.trackid_set = None + self.db_plist.tracks.trackids.clear() + items_changed.reset_mock() + + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.assertListEqual(self.model.trackids, []) + items_changed.assert_not_called() + + def test_trackid_set(self): + """Test the trackid-set property.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.assertIsNone(self.model.trackid_set) + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + + self.model.trackid_set = self.db_plist.tracks + self.assertEqual(self.model._TrackidModel__trackid_set, + self.db_plist.tracks) + self.assertEqual(self.model.trackid_set, self.db_plist.tracks) + self.assertEqual(self.model.trackids, [self.track1.trackid, + self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 2) + + self.model.trackid_set = None + self.assertEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_called_with(self.model, 0, 2, 0)