diff --git a/db/playlist.py b/db/playlist.py index 1678ba6..ca74313 100644 --- a/db/playlist.py +++ b/db/playlist.py @@ -63,8 +63,8 @@ class Playlist(GObject.GObject): def add_track(self, track): self.emit("track-added", track) - def remove_track(self, track): - self.emit("track-removed", track) + def remove_track(self, track, adjust_current): + self.emit("track-removed", track, adjust_current) def next_track(self): n = self.get_n_tracks() @@ -139,8 +139,10 @@ class Playlist(GObject.GObject): if self.track_adjusts_current(track): self.current += 1 - @GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT,)) - def track_removed(self, track): pass + @GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT,bool)) + def track_removed(self, track, adjust_current): + if adjust_current: + self.current -= 1 class MappedPlaylist(Playlist): @@ -206,11 +208,12 @@ class MappedPlaylist(Playlist): return row[2] - 1 if row else None def remove_track(self, track): + adjust_current = self.track_adjusts_current(track) res = sql.execute(f"DELETE FROM {self.map_table} " f"WHERE {self.rowkey}=? AND trackid=?", [ self.rowid, track.rowid ]).rowcount == 1 if res: - super().remove_track(track) + super().remove_track(track, adjust_current) return res diff --git a/db/test_album.py b/db/test_album.py index b576e8e..15e2534 100644 --- a/db/test_album.py +++ b/db/test_album.py @@ -9,8 +9,8 @@ class TestAlbum(unittest.TestCase): def track_added(self, album, added): self.added = added - def track_removed(self, album, removed): - self.removed = removed + def track_removed(self, album, removed, adjusted_current): + self.removed = (removed, False) def setUp(self): db.reset() @@ -54,7 +54,7 @@ class TestAlbum(unittest.TestCase): album.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(album.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestAlbumTable(unittest.TestCase): diff --git a/db/test_artist.py b/db/test_artist.py index dd2a0ad..496eb76 100644 --- a/db/test_artist.py +++ b/db/test_artist.py @@ -9,8 +9,8 @@ class TestArtist(unittest.TestCase): def track_added(self, artist, added): self.added = added - def track_removed(self, artist, removed): - self.removed = removed + def track_removed(self, artist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -48,7 +48,7 @@ class TestArtist(unittest.TestCase): artist.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(artist.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestArtistTable(unittest.TestCase): diff --git a/db/test_decade.py b/db/test_decade.py index b99470a..700920d 100644 --- a/db/test_decade.py +++ b/db/test_decade.py @@ -11,8 +11,8 @@ class TestDecade(unittest.TestCase): def track_added(self, decade, added): self.added = added - def track_removed(self, decade, removed): - self.removed = removed + def track_removed(self, decade, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -58,7 +58,7 @@ class TestDecade(unittest.TestCase): decade.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(decade.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestDecadeTable(unittest.TestCase): diff --git a/db/test_disc.py b/db/test_disc.py index de9aa8f..8776c9b 100644 --- a/db/test_disc.py +++ b/db/test_disc.py @@ -9,8 +9,8 @@ class TestDisc(unittest.TestCase): def track_added(self, disc, added): self.added = added - def track_removed(self, disc, removed): - self. removed = removed + def track_removed(self, disc, removed, adjusted_current): + self. removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -67,7 +67,7 @@ class TestDisc(unittest.TestCase): disc.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(disc.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestDiscTable(unittest.TestCase): diff --git a/db/test_genre.py b/db/test_genre.py index 829243c..63f00b0 100644 --- a/db/test_genre.py +++ b/db/test_genre.py @@ -10,8 +10,8 @@ class TestGenre(unittest.TestCase): def track_added(self, genre, added): self.added = added - def track_removed(self, genre, removed): - self.removed = removed + def track_removed(self, genre, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -51,7 +51,7 @@ class TestGenre(unittest.TestCase): self.assertFalse(genre.remove_track(track)) self.assertEqual(genre.get_n_tracks(), 0) self.assertEqual(genre.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestGenreTable(unittest.TestCase): diff --git a/db/test_library.py b/db/test_library.py index ce5a000..f13cbeb 100644 --- a/db/test_library.py +++ b/db/test_library.py @@ -10,8 +10,8 @@ class TestLibrary(unittest.TestCase): def track_added(self, library, added): self.added = added - def track_removed(self, library, removed): - self. removed = removed + def track_removed(self, library, removed, adjusted_current): + self. removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -56,7 +56,7 @@ class TestLibrary(unittest.TestCase): library.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(library.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestLibraryTable(unittest.TestCase): diff --git a/db/test_playlist.py b/db/test_playlist.py index cf140bb..23e0961 100644 --- a/db/test_playlist.py +++ b/db/test_playlist.py @@ -66,6 +66,29 @@ class TestPlaylist(unittest.TestCase): self.assertEqual(plist.get_track(2), track3) self.assertEqual(plist.current, 1) + def test_remove_track(self): + plist = db.user.Table.find("Test Playlist") + track1 = db.make_fake_track(1, 1, "Track 1", "/a/b/c/1.ogg") + track2 = db.make_fake_track(2, 2, "Track 2", "/a/b/c/2.ogg") + track3 = db.make_fake_track(3, 3, "Track 3", "/a/b/c/3.ogg") + plist.sort = [ "tracks.number ASC" ] + + plist.add_track(track1) + plist.add_track(track2) + plist.add_track(track3) + plist.current = 1 + + self.assertTrue(plist.track_adjusts_current(track1)) + self.assertTrue(plist.track_adjusts_current(track2)) + self.assertFalse(plist.track_adjusts_current(track3)) + + plist.remove_track(track3) + self.assertEqual(plist.current, 1) + plist.remove_track(track1) + self.assertEqual(plist.current, 0) + plist.remove_track(track2) + self.assertEqual(plist.current, -1) + def test_current(self): plist = db.user.Table.find("Test Playlist") self.assertEqual(plist.get_property("current"), -1) diff --git a/db/test_user.py b/db/test_user.py index 5299a84..41a18a3 100644 --- a/db/test_user.py +++ b/db/test_user.py @@ -9,8 +9,8 @@ class TestCollection(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def refreshed(self, plist): self.refreshed = True @@ -41,7 +41,7 @@ class TestCollection(unittest.TestCase): collection.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(collection.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) def test_library_enabled(self): collection = db.user.Table.find("Collection") @@ -67,8 +67,8 @@ class TestFavorites(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, False) def setUp(self): db.reset() @@ -101,15 +101,15 @@ class TestFavorites(unittest.TestCase): self.assertFalse(favorites.remove_track(track)) self.assertEqual(favorites.get_n_tracks(), 0) self.assertEqual(favorites.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestNewTracks(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -140,15 +140,15 @@ class TestNewTracks(unittest.TestCase): self.assertFalse(new.remove_track(track)) self.assertEqual(new.get_n_tracks(), 0) self.assertEqual(new.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestPrevious(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -194,7 +194,7 @@ class TestPrevious(unittest.TestCase): self.assertEqual(previous.get_tracks(), [ track1, track2 ]) self.assertEqual(previous.get_track_index(track1), 0) self.assertEqual(previous.get_track_index(track2), 1) - self.assertEqual(self.removed, track1) + self.assertEqual(self.removed, (track1, False)) self.assertEqual(self.added, track1) self.assertTrue(previous.remove_track(track1)) @@ -202,7 +202,7 @@ class TestPrevious(unittest.TestCase): self.assertEqual(previous.get_n_tracks(), 1) self.assertEqual(previous.get_tracks(), [ track2 ]) self.assertEqual(previous.get_track_index(track2), 0) - self.assertEqual(self.removed, track1) + self.assertEqual(self.removed, (track1, True)) def test_previous_track(self): previous = db.user.Table.find("Previous") @@ -241,8 +241,8 @@ class TestQueuedTracks(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -276,15 +276,15 @@ class TestQueuedTracks(unittest.TestCase): self.assertFalse(queued.remove_track(track)) self.assertEqual(queued.get_n_tracks(), 0) self.assertEqual(queued.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestUserPlaylist(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -325,7 +325,7 @@ class TestUserPlaylist(unittest.TestCase): self.assertFalse(plist.remove_track(track)) self.assertEqual(plist.get_n_tracks(), 0) self.assertEqual(plist.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestUserTable(unittest.TestCase): diff --git a/db/test_year.py b/db/test_year.py index 59ffdc7..f669d08 100644 --- a/db/test_year.py +++ b/db/test_year.py @@ -8,8 +8,8 @@ class TestYear(unittest.TestCase): def track_added(self, plist, added): self.added = added - def track_removed(self, plist, removed): - self.removed = removed + def track_removed(self, plist, removed, adjusted_current): + self.removed = (removed, adjusted_current) def setUp(self): db.reset() @@ -45,7 +45,7 @@ class TestYear(unittest.TestCase): year.connect("track-removed", self.track_removed) db.track.Table.delete(track) self.assertEqual(year.get_tracks(), [ ]) - self.assertEqual(self.removed, track) + self.assertEqual(self.removed, (track, False)) class TestYearTable(unittest.TestCase): diff --git a/db/track.py b/db/track.py index 27be306..567e462 100644 --- a/db/track.py +++ b/db/track.py @@ -159,10 +159,11 @@ class TrackTable(table.Table): plists = [ track.artist, track.album, track.disc, track.decade, track.year, track.library, user.Table.find("Collection") ] + adjust = [ p.track_adjusts_current(track) for p in plists ] super().delete(track) - for plist in plists: - plist.remove_track(track) + for (plist, adjust) in zip(plists, adjust): + plist.remove_track(track, adjust) def find(self, *args): raise NotImplementedError