diff --git a/emmental/playlist/playlist.py b/emmental/playlist/playlist.py index 8e20b24..8d35a04 100644 --- a/emmental/playlist/playlist.py +++ b/emmental/playlist/playlist.py @@ -22,8 +22,13 @@ class Playlist(model.TrackidModel): if playlist is not None: self.playlist = playlist + def __get_nth_track(self, n: int) -> db.tracks.Track | None: + return self[n] if n < len(self.trackids) else None + def __playlist_notify(self, plist: db.playlist.Playlist, param) -> None: match param.name: + case "loop": + self.notify("loop") case "sort-order": self.__sort_order = plist.sort_order self.on_trackids_reset(plist.tracks) @@ -91,6 +96,21 @@ class Playlist(model.TrackidModel): if self.__playlist.move_track_up(track) and need_handling: self.__track_moved(track, offset=-1) + def next_track(self) -> db.tracks.Track | None: + """Select the next track for playback.""" + if self.__playlist is None: + return None + + index = self.index(self.current_track) + match (index, self.__playlist.loop): + case (None, _): index = 0 + case (_, "Playlist"): index = (index + 1) % self.n_tracks + case (_, "None"): index += 1 + + if (next := self.__get_nth_track(index)) is not None: + self.current_track = next + return next + def remove_track(self, track: db.tracks.Track) -> None: """Remove a track from the playlist.""" if self.__playlist is not None: @@ -109,6 +129,18 @@ class Playlist(model.TrackidModel): trackid = 0 if track is None else track.trackid self.__playlist.current_trackid = trackid + @GObject.Property(type=str, flags=FLAGS) + def loop(self) -> str: + """Get the current loop setting of the Playlist.""" + return "None" if self.__playlist is None else self.__playlist.loop + + @loop.setter + def loop(self, newval: str) -> None: + if self.__playlist is not None: + if newval not in {"None", "Track", "Playlist"}: + raise ValueError + self.__playlist.loop = newval + @GObject.Property(type=db.playlist.Playlist) def playlist(self) -> db.playlist.Playlist | None: """Get the current db playlist.""" @@ -134,7 +166,7 @@ class Playlist(model.TrackidModel): self.__sort_order = None self.trackid_set = None - for prop in ("current-track", "sort-order"): + for prop in ("current-track", "loop", "sort-order"): self.notify(prop) @GObject.Property(type=str, flags=FLAGS) diff --git a/tests/playlist/test_playlist.py b/tests/playlist/test_playlist.py index 98023e4..741f3e3 100644 --- a/tests/playlist/test_playlist.py +++ b/tests/playlist/test_playlist.py @@ -141,6 +141,10 @@ class TestPlaylist(tests.util.TestCase): self.track1.trackid]) items_changed.assert_called_once_with(self.playlist, 1, 2, 2) + def test_next_track(self): + """Test the playlist next_track() function.""" + self.assertIsNone(self.playlist.next_track()) + def test_remove_track(self): """Test the playlist remove_track() function.""" self.playlist.remove_track(self.track1) @@ -223,6 +227,35 @@ class TestPlaylist(tests.util.TestCase): self.assertIsNone(self.playlist.current_track) notify.assert_called() + def test_loop(self): + """Test the Playlist loop property.""" + self.assertEqual(self.playlist.loop, "None") + + notify = unittest.mock.Mock() + self.playlist.connect("notify::loop", notify) + self.playlist.loop = "Track" + self.assertEqual(self.playlist.loop, "None") + notify.assert_not_called() + + self.playlist.playlist = self.db_plist + notify.assert_called() + + notify.reset_mock() + self.playlist.loop = "Track" + self.assertEqual(self.db_plist.loop, "Track") + self.assertEqual(self.playlist.loop, "Track") + notify.assert_called() + + notify.reset_mock() + self.db_plist.loop = "Playlist" + self.assertEqual(self.playlist.loop, "Playlist") + notify.assert_called() + + self.playlist.playlist = None + notify.reset_mock() + self.db_plist.loop = "Track" + notify.assert_not_called() + def test_playlist(self): """Test the playlist property.""" self.assertIsNone(self.playlist.playlist) @@ -284,3 +317,62 @@ class TestPlaylist(tests.util.TestCase): self.playlist.sort_order = "length" self.assertIsNone(self.playlist.sort_order) notify.assert_not_called() + + +class TestPlaylistNextTrack(tests.util.TestCase): + """Test the Playlist next_track() function.""" + + def setUpTrack(self, i: int) -> emmental.db.tracks.Track: + """Create a Track, add it to the Playlist, and return it.""" + track = self.sql.tracks.create(self.library, + pathlib.Path(f"/a/b/{i}.ogg"), + self.medium, self.year, number=i) + self.db_plist.add_track(track) + return track + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.library = self.sql.libraries.create(pathlib.Path("/a/b")) + self.album = self.sql.albums.create("Test Album", "Artist", "2023") + self.medium = self.sql.media.create(self.album, "", number=1) + self.year = self.sql.years.create(2023) + + self.playlist = emmental.playlist.playlist.Playlist(self.sql) + self.db_plist = self.sql.playlists.create("Test Playlist") + self.playlist.playlist = self.db_plist + + self.tracks = [self.setUpTrack(i) for i in range(1, 6)] + + def test_next_track(self): + """Test the Playlist next_track() function with no extra flags.""" + for i, track in enumerate(self.tracks): + with self.subTest(i=i, track=track.path): + self.assertEqual(self.playlist.next_track(), track) + self.assertEqual(self.playlist.current_track, track) + + self.assertIsNone(self.playlist.next_track()) + self.assertEqual(self.playlist.current_track, self.tracks[-1]) + + def test_loop_track(self): + """Test the next_track() function with ::loop='Track'.""" + self.playlist.loop = "Track" + + for i in range(3): + with self.subTest(i=i): + self.assertEqual(self.playlist.next_track(), self.tracks[0]) + self.assertEqual(self.playlist.current_track, self.tracks[0]) + + def test_loop_playlist(self): + """Test the next_track() function with ::loop='Playlist'.""" + self.playlist.loop = "Playlist" + + for i, track in enumerate(self.tracks): + with self.subTest(i=i, track=track.path): + self.assertEqual(self.playlist.next_track(), track) + self.assertEqual(self.playlist.current_track, track) + + for i, track in enumerate(self.tracks): + with self.subTest(i=i, track=track.path): + self.assertEqual(self.playlist.next_track(), track) + self.assertEqual(self.playlist.current_track, track)