diff --git a/emmental/playlist/playlist.py b/emmental/playlist/playlist.py index c866011..8e20b24 100644 --- a/emmental/playlist/playlist.py +++ b/emmental/playlist/playlist.py @@ -54,6 +54,9 @@ class Playlist(model.TrackidModel): def on_trackid_removed(self, set: db.tracks.TrackidSet, trackid: int) -> None: """Handle the TrackidSet::trackid-removed signal.""" + if self.__playlist.current_trackid == trackid: + index = super().index(trackid) - 1 + self.current_track = None if index < 0 else self[index] super().on_trackid_removed(set, trackid) self.__sort_keys.pop(trackid, None) @@ -61,6 +64,8 @@ class Playlist(model.TrackidModel): """Handle the TrackidSet::trackids-reset signal.""" self.__sort_keys.clear() super().on_trackids_reset(set) + if super().index(self.__playlist.current_trackid) is None: + self.current_track = None def add_track(self, track: db.tracks.Track) -> None: """Add a track to the playlist.""" @@ -91,6 +96,19 @@ class Playlist(model.TrackidModel): if self.__playlist is not None: self.__playlist.remove_track(track) + @GObject.Property(type=db.tracks.Track) + def current_track(self) -> db.tracks.Track | None: + """Get the current Track of the Playlist.""" + if self.__playlist is not None: + return self.sql.tracks.rows.get(self.__playlist.current_trackid) + + @current_track.setter + def current_track(self, track: db.tracks.Track | None) -> None: + """Set the current Track.""" + if self.__playlist is not None: + trackid = 0 if track is None else track.trackid + self.__playlist.current_trackid = trackid + @GObject.Property(type=db.playlist.Playlist) def playlist(self) -> db.playlist.Playlist | None: """Get the current db playlist.""" @@ -108,11 +126,16 @@ class Playlist(model.TrackidModel): self.__playlist.connect("notify", self.__playlist_notify) self.__sort_order = new.sort_order self.trackid_set = new.tracks + + if len(self.trackids) > 0: + if new.current_trackid == self.trackids[-1]: + new.current_trackid = 0 else: self.__sort_order = None self.trackid_set = None - self.notify("sort-order") + for prop in ("current-track", "sort-order"): + self.notify(prop) @GObject.Property(type=str, flags=FLAGS) def sort_order(self) -> str: diff --git a/tests/playlist/test_playlist.py b/tests/playlist/test_playlist.py index 66e6db6..98023e4 100644 --- a/tests/playlist/test_playlist.py +++ b/tests/playlist/test_playlist.py @@ -150,6 +150,79 @@ class TestPlaylist(tests.util.TestCase): self.playlist.remove_track(self.track1) self.assertListEqual(self.playlist.trackids, []) + def test_remove_current_track(self): + """Test removing the current-track from the playlist.""" + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + self.db_plist.add_track(self.track3) + self.playlist.playlist = self.db_plist + self.playlist.current_track = self.track2 + + notify = unittest.mock.Mock() + self.playlist.connect("notify::current-track", notify) + self.playlist.remove_track(self.track3) + notify.assert_not_called() + + self.playlist.remove_track(self.track2) + self.assertEqual(self.playlist.current_track, self.track1) + self.assertEqual(self.db_plist.current_trackid, self.track1.trackid) + notify.assert_called() + + notify.reset_mock() + self.playlist.remove_track(self.track1) + self.assertIsNone(self.playlist.current_track) + self.assertEqual(self.db_plist.current_trackid, 0) + notify.assert_called() + + self.playlist.add_track(self.track1) + self.playlist.add_track(self.track3) + self.playlist.current_track = self.track3 + + notify.reset_mock() + self.db_plist.remove_track(self.track3) + self.assertEqual(self.playlist.current_track, self.track1) + self.assertEqual(self.db_plist.current_trackid, self.track1.trackid) + notify.assert_called() + + def test_current_track(self): + """Test the Playlist current-track property.""" + self.assertIsNone(self.playlist.current_track) + self.playlist.current_track = self.track2 + self.assertIsNone(self.playlist.current_track) + + self.db_plist.add_track(self.track1) + self.db_plist.add_track(self.track2) + self.db_plist.add_track(self.track3) + self.db_plist.current_trackid = self.track2.trackid + + notify = unittest.mock.Mock() + self.playlist.connect("notify::current-track", notify) + self.playlist.playlist = self.db_plist + self.assertEqual(self.playlist.current_track, self.track2) + notify.assert_called() + + self.db_plist.current_trackid = self.track3.trackid + self.playlist.playlist = self.db_plist + self.assertEqual(self.db_plist.current_trackid, 0) + self.assertIsNone(self.playlist.current_track) + + self.playlist.current_track = self.track1 + self.assertEqual(self.db_plist.current_trackid, self.track1.trackid) + self.assertEqual(self.playlist.current_track, self.track1) + + self.playlist.current_track = None + self.assertEqual(self.db_plist.current_trackid, 0) + self.assertIsNone(self.playlist.current_track) + + self.db_plist.remove_track(self.track2) + self.db_plist.current_trackid = self.track2.trackid + + notify.reset_mock() + self.db_plist.tracks.emit("trackids-reset") + self.assertEqual(self.db_plist.current_trackid, 0) + self.assertIsNone(self.playlist.current_track) + notify.assert_called() + def test_playlist(self): """Test the playlist property.""" self.assertIsNone(self.playlist.playlist)