diff --git a/emmental/playlist/__init__.py b/emmental/playlist/__init__.py index a68db00..5f207c5 100644 --- a/emmental/playlist/__init__.py +++ b/emmental/playlist/__init__.py @@ -19,6 +19,8 @@ class Factory(GObject.GObject): previous_playlist = GObject.Property(type=previous.Previous) visible_playlist = GObject.Property(type=playlist.Playlist) + can_go_previous = GObject.Property(type=bool, default=False) + def __init__(self, sql: db.Connection): """Initialize the Playlist Factory.""" super().__init__(sql=sql) @@ -36,12 +38,26 @@ class Factory(GObject.GObject): if plist.playlist == db_plist: return plist + def __update_can_go(self, which: str, newval: bool) -> None: + if self.get_property(f"can-go-{which}") != newval: + self.set_property(f"can-go-{which}", newval) + + def __update_can_go_prev(self, *args) -> None: + self.__update_can_go("previous", self.previous_playlist.can_go_next) + def __make_playlist(self, db_plist: db.playlist.Playlist) -> playlist.Playlist: if db_plist == self.sql.playlists.previous: - return previous.Previous(self.sql, db_plist) + res = previous.Previous(self.sql, db_plist) + res.connect("notify::can-go-next", self.__update_can_go_prev) + return res return playlist.Playlist(self.sql, db_plist) + def __free_playlist(self, plist: playlist.Playlist) -> None: + plist.playlist = None + if isinstance(plist, previous.Previous): + plist.disconnect_by_func(self.__update_can_go_prev) + def __run_factory(self, label: str) -> None: db_plist = self.get_property(f"db-{label}") plist = self.get_property(f"{label}-playlist") @@ -51,7 +67,7 @@ class Factory(GObject.GObject): if db_plist is None: if self.__get_playlists().count(plist) == 1: - plist.playlist = None + self.__free_playlist(plist) new = None elif plist is None or self.__get_playlists().count(plist) > 1: if (new := self.__search_playlists(db_plist)) is None: @@ -75,6 +91,12 @@ class Factory(GObject.GObject): case "db-visible": self.__run_factory("visible") + def previous_track(self) -> tuple[db.tracks.Track | None, bool]: + """Get the previous Track.""" + if self.previous_playlist is None: + return None + return self.previous_playlist.next_track() + @GObject.Property(type=str, flags=playlist.FLAGS) def active_loop(self) -> str: """Get the loop state of the active playlist.""" diff --git a/tests/playlist/test_factory.py b/tests/playlist/test_factory.py index 9754c00..354bca0 100644 --- a/tests/playlist/test_factory.py +++ b/tests/playlist/test_factory.py @@ -1,6 +1,7 @@ # Copyright 2023 (c) Anna Schumaker. """Test our Playlist Manager object.""" import io +import pathlib import unittest.mock import emmental.playlist import tests.util @@ -232,3 +233,60 @@ class TestFactory(tests.util.TestCase): self.factory.db_visible = self.user_plist self.assertNotEqual(id(self.factory.visible_playlist), id(self.factory.previous_playlist)) + + +@unittest.mock.patch("sys.stdout", new_callable=io.StringIO) +class TestFactoryNextPreviousTrack(tests.util.TestCase): + """Test the Factory next_track() and previous_track() functions.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.sql.playlists.load(now=True) + self.user_plist = self.sql.playlists.create("User Playlist") + + self.factory = emmental.playlist.Factory(self.sql) + + self.library = self.sql.libraries.create(pathlib.Path("/a/b")) + self.album = self.sql.albums.create("Test Album", "Artist", "2023") + self.medium = self.sql.media.create(self.album, "", number=1) + self.year = self.sql.years.create(2023) + + self.tracks = [self.sql.tracks.create(self.library, + pathlib.Path(f"/a/b/{i}.ogg"), + self.medium, self.year, number=i) + for i in range(1, 4)] + + def test_can_go_previous(self, mock_stdout: io.StringIO): + """Test the can-go-previous property.""" + self.assertFalse(self.factory.can_go_previous) + + self.factory.db_previous = self.sql.playlists.previous + self.assertFalse(self.factory.can_go_previous) + self.tracks[0].start() + self.assertFalse(self.factory.can_go_previous) + self.tracks[1].start() + self.assertTrue(self.factory.can_go_previous) + + self.factory.db_previous = None + self.assertFalse(self.factory.can_go_previous) + self.tracks[2].start() + self.assertFalse(self.factory.can_go_previous) + + def test_previous_track(self, mock_stdout: io.StringIO): + """Test the previous_track() function.""" + self.assertIsNone(self.factory.previous_track()) + + self.factory.db_previous = self.sql.playlists.previous + self.assertIsNone(self.factory.previous_track()) + self.tracks[0].start() + self.assertIsNone(self.factory.previous_track()) + + self.tracks[1].start() + self.assertEqual(self.factory.previous_track(), self.tracks[0]) + self.assertIsNone(self.factory.previous_track()) + + self.tracks[2].start() + self.assertEqual(self.factory.previous_track(), self.tracks[1]) + self.assertEqual(self.factory.previous_track(), self.tracks[0]) + self.assertIsNone(self.factory.previous_track())