diff --git a/emmental/playlist/model.py b/emmental/playlist/model.py new file mode 100644 index 0000000..fef44eb --- /dev/null +++ b/emmental/playlist/model.py @@ -0,0 +1,103 @@ +# Copyright 2023 (c) Anna Schumaker. +"""Converts a TrackidSet into a Gio.ListModel.""" +import bisect +from gi.repository import GObject +from gi.repository import Gio +from .. import db + + +class TrackidModel(GObject.GObject, Gio.ListModel): + """A Gio.ListModel representing a TrackidSet.""" + + sql = GObject.Property(type=db.Connection) + n_tracks = GObject.Property(type=int) + + def __init__(self, sql: db.Connection): + """Initialize the TrackidModel.""" + super().__init__(sql=sql) + self.__trackid_set = None + self.trackids = [] + + def bisect(self, trackid: int) -> tuple[bool, int | None]: + """Bisect the TrackidModel for the given trackid.""" + pos = bisect.bisect_left(self.trackids, + self.do_get_sort_key(trackid), + key=self.do_get_sort_key) + + if pos < len(self.trackids): + return (self.trackids[pos] == trackid, pos) + return (False, pos) + + def do_get_item_type(self) -> GObject.GType: + """Get the item type of this Model.""" + return db.tracks.Track.__gtype__ + + def do_get_n_items(self) -> int: + """Get the number of items in the list.""" + return len(self.trackids) + + def do_get_item(self, n: int) -> db.tracks.Track | None: + """Get the n-th item in the list.""" + if n < len(self.trackids): + return self.sql.tracks.rows.get(self.trackids[n]) + + def do_get_sort_key(self, trackid: int) -> int: + """Get a stort key for the given trackid.""" + return trackid + + def do_items_changed(self, *, position: int, + removed: int, added: int) -> None: + """Emit the ::items-changed signal.""" + self.n_tracks = len(self.trackids) + self.items_changed(position, removed, added) + + def index(self, trackid: int) -> int | None: + """Find the index of a specific trackid.""" + (has, pos) = self.bisect(trackid) + return pos if has else None + + def on_trackid_added(self, set: db.tracks.TrackidSet, + trackid: int) -> None: + """Respond to the trackid-added signal.""" + (has, pos) = self.bisect(trackid) + if not has: + self.trackids.insert(pos, trackid) + self.do_items_changed(position=pos, removed=0, added=1) + + def on_trackid_removed(self, set: db.tracks.TrackidSet, + trackid: int) -> None: + """Respond to the trackid-removed signal.""" + (has, pos) = self.bisect(trackid) + if has: + del self.trackids[pos] + self.do_items_changed(position=pos, removed=1, added=0) + + def on_trackids_reset(self, set: db.tracks.TrackidSet) -> None: + """Respond to the trackids-reset signal.""" + self.trackids = sorted(set.trackids, key=self.do_get_sort_key) + self.do_items_changed(position=0, removed=self.n_tracks, + added=len(self.trackids)) + + @GObject.Property(type=db.tracks.TrackidSet) + def trackid_set(self) -> db.tracks.TrackidSet | None: + """Get the current trackid-set.""" + return self.__trackid_set + + @trackid_set.setter + def trackid_set(self, new: db.tracks.TrackidSet | None) -> None: + """Set a new value to the trackid-set property.""" + if self.__trackid_set is not None: + self.__trackid_set.disconnect_by_func(self.on_trackid_added) + self.__trackid_set.disconnect_by_func(self.on_trackid_removed) + self.__trackid_set.disconnect_by_func(self.on_trackids_reset) + self.trackids = [] + + self.__trackid_set = new + if new is not None: + new.connect("trackid-added", self.on_trackid_added) + new.connect("trackid-removed", self.on_trackid_removed) + new.connect("trackids-reset", self.on_trackids_reset) + self.trackids = sorted(new.trackids, key=self.do_get_sort_key) + + self.do_items_changed(position=0, removed=self.n_tracks, + added=len(self.trackids)) diff --git a/tests/playlist/test_model.py b/tests/playlist/test_model.py new file mode 100644 index 0000000..6031275 --- /dev/null +++ b/tests/playlist/test_model.py @@ -0,0 +1,192 @@ +# Copyright 2023 (c) Anna Schumaker. +"""Tests our TrackidModel.""" +import pathlib +import unittest.mock +import tests.util +import emmental.playlist.model +from gi.repository import Gio + + +class TestTrackidModel(tests.util.TestCase): + """Tests the Trackid Model.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.db_plist = self.sql.playlists.create("Test Playlist") + self.model = emmental.playlist.model.TrackidModel(self.sql) + + self.library = self.sql.libraries.create(pathlib.Path("/a/b")) + self.album = self.sql.albums.create("Test Album", "Artist", "2023") + self.medium = self.sql.media.create(self.album, "", number=1) + self.year = self.sql.years.create(2023) + + self.track1 = self.sql.tracks.create(self.library, + pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, number=1) + self.track2 = self.sql.tracks.create(self.library, + pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, number=2) + + def test_init(self): + """Test that the TrackidModel was set up correctly.""" + self.assertIsInstance(self.model, Gio.ListModel) + self.assertEqual(self.model.sql, self.sql) + self.assertIsNone(self.model._TrackidModel__trackid_set) + + def test_bisect(self): + """Test the TrackidModel bisect() function.""" + self.assertTupleEqual(self.model.bisect(self.track1.trackid), + (False, 0)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.assertTupleEqual(self.model.bisect(self.track1.trackid), + (True, 0)) + self.assertTupleEqual(self.model.bisect(self.track2.trackid), + (False, 1)) + + self.db_plist.add_track(self.track2) + self.assertTupleEqual(self.model.bisect(self.track2.trackid), + (True, 1)) + + def test_get_item_type(self): + """Test the Gio.ListModel:get_item_type() function.""" + self.assertEqual(self.model.get_item_type(), + emmental.db.tracks.Track.__gtype__) + + def test_get_n_items(self): + """Test the Gio.ListModel:get_n_items() function.""" + self.model.trackid_set = self.db_plist.tracks + self.assertEqual(self.model.get_n_items(), 0) + self.db_plist.add_track(self.track1) + self.assertEqual(self.model.get_n_items(), 1) + self.db_plist.add_track(self.track2) + self.assertEqual(self.model.get_n_items(), 2) + + def test_get_item(self): + """Test the Gio.ListModel:get_item() function.""" + self.assertIsNone(self.model.get_item(0)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + + self.assertEqual(self.model.get_item(0), self.track1) + self.assertEqual(self.model.get_item(1), self.track2) + self.assertIsNone(self.model.get_item(2)) + + def test_index(self): + """Test finding the index of a specific trackid.""" + self.assertIsNone(self.model.index(self.track1.trackid)) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + + self.assertEqual(self.model.index(self.track1.trackid), 0) + self.assertEqual(self.model.index(self.track2.trackid), 1) + self.assertIsNone(self.model.index(self.track2.trackid + 1)) + + def test_trackid_added(self): + """Test that the TrackidModel responds to the trackid-added signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.model.trackid_set = self.db_plist.tracks + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + + self.db_plist.add_track(self.track2) + self.assertListEqual(self.model.trackids, [self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 1) + items_changed.assert_called_with(self.model, 0, 0, 1) + + self.db_plist.add_track(self.track1) + self.assertListEqual(self.model.trackids, + [self.track1.trackid, self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 1) + + self.model.trackid_set = None + self.db_plist.tracks.trackids.clear() + items_changed.reset_mock() + + self.db_plist.tracks.add_track(self.track1) + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_not_called() + + def test_trackid_removed(self): + """Test that the TrackModel responds to the trackid-removed signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.model.trackid_set = self.db_plist.tracks + + self.db_plist.remove_track(self.track2) + self.assertListEqual(self.model.trackids, [self.track1.trackid]) + self.assertEqual(self.model.n_tracks, 1) + items_changed.assert_called_with(self.model, 1, 1, 0) + + self.db_plist.remove_track(self.track1) + self.assertListEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_called_with(self.model, 0, 1, 0) + + self.model.trackid_set = None + self.model.trackids = [self.track1.trackid] + self.db_plist.tracks.trackids = {self.track1.trackid} + items_changed.reset_mock() + + self.db_plist.tracks.remove_track(self.track1) + self.assertListEqual(self.model.trackids, [self.track1.trackid]) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_not_called() + + def test_trackids_reset(self): + """Test that the TrackModel responds to the trackids-reset signal.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.model.trackid_set = self.db_plist.tracks + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.assertListEqual(self.model.trackids, [self.track1.trackid, + self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 2) + + self.model.trackid_set = None + self.db_plist.tracks.trackids.clear() + items_changed.reset_mock() + + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + self.assertListEqual(self.model.trackids, []) + items_changed.assert_not_called() + + def test_trackid_set(self): + """Test the trackid-set property.""" + items_changed = unittest.mock.Mock() + self.model.connect("items-changed", items_changed) + + self.assertIsNone(self.model.trackid_set) + self.db_plist.tracks.trackids = {self.track1.trackid, + self.track2.trackid} + + self.model.trackid_set = self.db_plist.tracks + self.assertEqual(self.model._TrackidModel__trackid_set, + self.db_plist.tracks) + self.assertEqual(self.model.trackid_set, self.db_plist.tracks) + self.assertEqual(self.model.trackids, [self.track1.trackid, + self.track2.trackid]) + self.assertEqual(self.model.n_tracks, 2) + items_changed.assert_called_with(self.model, 0, 0, 2) + + self.model.trackid_set = None + self.assertEqual(self.model.trackids, []) + self.assertEqual(self.model.n_tracks, 0) + items_changed.assert_called_with(self.model, 0, 2, 0)