diff --git a/emmental/playlist/playlist.py b/emmental/playlist/playlist.py index 8d35a04..138a9eb 100644 --- a/emmental/playlist/playlist.py +++ b/emmental/playlist/playlist.py @@ -14,6 +14,7 @@ class Playlist(model.TrackidModel): playlist: db.playlist.Playlist = None): """Initialize the Playlist instance.""" super().__init__(sql=sql) + self.__picked = db.tracks.TrackidSet() self.__sort_keys = {} self.__playlist = None @@ -22,13 +23,20 @@ class Playlist(model.TrackidModel): if playlist is not None: self.playlist = playlist - def __get_nth_track(self, n: int) -> db.tracks.Track | None: - return self[n] if n < len(self.trackids) else None + def __get_nth_track(self, n: int | None) -> db.tracks.Track | None: + return self[n] if n is not None and n < len(self.trackids) else None + + def __get_random_index(self, loop: bool) -> int | None: + choices = self.__playlist.tracks - self.__picked + if len(choices) == 0 and loop: + self.__picked.trackids = {} + choices = self.__playlist.tracks + return super().index(choices.random_trackid()) def __playlist_notify(self, plist: db.playlist.Playlist, param) -> None: match param.name: - case "loop": - self.notify("loop") + case "loop" | "shuffle": + self.notify(param.name) case "sort-order": self.__sort_order = plist.sort_order self.on_trackids_reset(plist.tracks) @@ -102,10 +110,13 @@ class Playlist(model.TrackidModel): return None index = self.index(self.current_track) - match (index, self.__playlist.loop): - case (None, _): index = 0 - case (_, "Playlist"): index = (index + 1) % self.n_tracks - case (_, "None"): index += 1 + match (index, self.__playlist.loop, self.__playlist.shuffle): + case (None, _, False): index = 0 + case (None, _, True): index = self.__get_random_index(False) + case (_, "Playlist", False): index = (index + 1) % self.n_tracks + case (_, "Playlist", True): index = self.__get_random_index(True) + case (_, "None", False): index += 1 + case (_, "None", True): index = self.__get_random_index(False) if (next := self.__get_nth_track(index)) is not None: self.current_track = next @@ -126,8 +137,11 @@ class Playlist(model.TrackidModel): def current_track(self, track: db.tracks.Track | None) -> None: """Set the current Track.""" if self.__playlist is not None: - trackid = 0 if track is None else track.trackid - self.__playlist.current_trackid = trackid + if track is not None: + self.__playlist.current_trackid = track.trackid + self.__picked.add_track(track) + else: + self.__playlist.current_trackid = 0 @GObject.Property(type=str, flags=FLAGS) def loop(self) -> str: @@ -152,6 +166,7 @@ class Playlist(model.TrackidModel): if self.__playlist: self.__playlist.disconnect_by_func(self.__playlist_notify) + self.__picked.trackids = set() self.__playlist = new if new is not None: @@ -162,13 +177,25 @@ class Playlist(model.TrackidModel): if len(self.trackids) > 0: if new.current_trackid == self.trackids[-1]: new.current_trackid = 0 + if track := self.current_track: + self.__picked.add_track(track) else: self.__sort_order = None self.trackid_set = None - for prop in ("current-track", "loop", "sort-order"): + for prop in ("current-track", "loop", "shuffle", "sort-order"): self.notify(prop) + @GObject.Property(type=bool, default=False, flags=FLAGS) + def shuffle(self) -> bool: + """Get the current shuffle setting of the Playlist.""" + return False if self.__playlist is None else self.__playlist.shuffle + + @shuffle.setter + def shuffle(self, newval: bool) -> None: + if self.__playlist is not None: + self.__playlist.shuffle = newval + @GObject.Property(type=str, flags=FLAGS) def sort_order(self) -> str: """Get the current sort order.""" diff --git a/tests/playlist/test_playlist.py b/tests/playlist/test_playlist.py index 741f3e3..6ba47a4 100644 --- a/tests/playlist/test_playlist.py +++ b/tests/playlist/test_playlist.py @@ -44,6 +44,9 @@ class TestPlaylist(tests.util.TestCase): """Test that the Playlist was set up correctly.""" self.assertIsInstance(self.playlist, emmental.playlist.model.TrackidModel) + self.assertIsInstance(self.playlist._Playlist__picked, + emmental.db.tracks.TrackidSet) + self.assertEqual(self.playlist.sql, self.sql) self.assertDictEqual(self.playlist._Playlist__sort_keys, {}) @@ -193,6 +196,7 @@ class TestPlaylist(tests.util.TestCase): self.assertIsNone(self.playlist.current_track) self.playlist.current_track = self.track2 self.assertIsNone(self.playlist.current_track) + self.assertSetEqual(self.playlist._Playlist__picked.trackids, set()) self.db_plist.add_track(self.track1) self.db_plist.add_track(self.track2) @@ -203,20 +207,27 @@ class TestPlaylist(tests.util.TestCase): self.playlist.connect("notify::current-track", notify) self.playlist.playlist = self.db_plist self.assertEqual(self.playlist.current_track, self.track2) + self.assertSetEqual(self.playlist._Playlist__picked.trackids, + {self.track2.trackid}) notify.assert_called() self.db_plist.current_trackid = self.track3.trackid self.playlist.playlist = self.db_plist self.assertEqual(self.db_plist.current_trackid, 0) self.assertIsNone(self.playlist.current_track) + self.assertSetEqual(self.playlist._Playlist__picked.trackids, set()) self.playlist.current_track = self.track1 self.assertEqual(self.db_plist.current_trackid, self.track1.trackid) self.assertEqual(self.playlist.current_track, self.track1) + self.assertSetEqual(self.playlist._Playlist__picked.trackids, + {self.track1.trackid}) self.playlist.current_track = None self.assertEqual(self.db_plist.current_trackid, 0) self.assertIsNone(self.playlist.current_track) + self.assertSetEqual(self.playlist._Playlist__picked.trackids, + {self.track1.trackid}) self.db_plist.remove_track(self.track2) self.db_plist.current_trackid = self.track2.trackid @@ -270,6 +281,35 @@ class TestPlaylist(tests.util.TestCase): self.assertIsNone(self.playlist.playlist) self.assertIsNone(self.playlist.trackid_set) + def test_shuffle(self): + """Test the Playlist shuffle property.""" + self.assertFalse(self.playlist.shuffle) + + notify = unittest.mock.Mock() + self.playlist.connect("notify::shuffle", notify) + self.playlist.shuffle = True + self.assertFalse(self.playlist.shuffle) + notify.assert_not_called() + + self.playlist.playlist = self.db_plist + notify.assert_called() + + notify.reset_mock() + self.playlist.shuffle = True + self.assertTrue(self.playlist.shuffle) + self.assertTrue(self.db_plist.shuffle) + notify.assert_called() + + notify.reset_mock() + self.db_plist.shuffle = False + self.assertFalse(self.playlist.shuffle) + notify.assert_called() + + self.playlist.playlist = None + notify.reset_mock() + self.db_plist.shuffle = True + notify.assert_not_called() + def test_sort_order(self): """Test the sort-order property.""" notify = unittest.mock.Mock() @@ -376,3 +416,49 @@ class TestPlaylistNextTrack(tests.util.TestCase): with self.subTest(i=i, track=track.path): self.assertEqual(self.playlist.next_track(), track) self.assertEqual(self.playlist.current_track, track) + + @unittest.mock.patch("random.choice") + def test_shuffle(self, mock_random: unittest.mock.Mock()): + """Test the next_track() function with ::shuffle=True.""" + self.playlist.shuffle = True + + mock_random.return_value = 3 + self.assertEqual(self.playlist.next_track(), self.tracks[2]) + mock_random.assert_called_with([1, 2, 3, 4, 5]) + + mock_random.return_value = 5 + self.assertEqual(self.playlist.next_track(), self.tracks[4]) + mock_random.assert_called_with([1, 2, 4, 5]) + + self.db_plist.tracks.trackids = set() + self.assertIsNone(self.playlist.next_track()) + + @unittest.mock.patch("random.choice") + def test_shuffle_loop_track(self, mock_random: unittest.mock.Mock()): + """Test next_track() with ::shuffle=True and ::loop='Track'.""" + self.playlist.loop = "Track" + self.playlist.shuffle = True + + mock_random.return_value = 3 + self.assertEqual(self.playlist.next_track(), self.tracks[2]) + mock_random.assert_called_with([1, 2, 3, 4, 5]) + + mock_random.reset_mock() + self.assertEqual(self.playlist.next_track(), self.tracks[2]) + mock_random.assert_not_called() + + @unittest.mock.patch("random.choice") + def test_shuffle_loop_playlist(self, mock_random: unittest.mock.Mock()): + """Test next_track() with ::shuffle=True and ::loop='Playlist'.""" + self.playlist.loop = "Playlist" + self.playlist.shuffle = True + + for i in range(5): + with self.subTest(i=i): + mock_random.return_value = i + 1 + self.assertEqual(self.playlist.next_track(), self.tracks[i]) + + for i in range(5): + with self.subTest(i=i): + mock_random.return_value = i + 1 + self.assertEqual(self.playlist.next_track(), self.tracks[i])