db: Preserve the current track when removing tracks

Again, we have to be careful not to check positions against playlists
where current == -1 for performance reasons.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2021-11-17 11:49:08 -05:00
parent 2f747ccaa6
commit 2e57e1fe0a
11 changed files with 74 additions and 47 deletions

View File

@ -63,8 +63,8 @@ class Playlist(GObject.GObject):
def add_track(self, track):
self.emit("track-added", track)
def remove_track(self, track):
self.emit("track-removed", track)
def remove_track(self, track, adjust_current):
self.emit("track-removed", track, adjust_current)
def next_track(self):
n = self.get_n_tracks()
@ -139,8 +139,10 @@ class Playlist(GObject.GObject):
if self.track_adjusts_current(track):
self.current += 1
@GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT,))
def track_removed(self, track): pass
@GObject.Signal(arg_types=(GObject.TYPE_PYOBJECT,bool))
def track_removed(self, track, adjust_current):
if adjust_current:
self.current -= 1
class MappedPlaylist(Playlist):
@ -206,11 +208,12 @@ class MappedPlaylist(Playlist):
return row[2] - 1 if row else None
def remove_track(self, track):
adjust_current = self.track_adjusts_current(track)
res = sql.execute(f"DELETE FROM {self.map_table} "
f"WHERE {self.rowkey}=? AND trackid=?",
[ self.rowid, track.rowid ]).rowcount == 1
if res:
super().remove_track(track)
super().remove_track(track, adjust_current)
return res

View File

@ -9,8 +9,8 @@ class TestAlbum(unittest.TestCase):
def track_added(self, album, added):
self.added = added
def track_removed(self, album, removed):
self.removed = removed
def track_removed(self, album, removed, adjusted_current):
self.removed = (removed, False)
def setUp(self):
db.reset()
@ -54,7 +54,7 @@ class TestAlbum(unittest.TestCase):
album.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(album.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestAlbumTable(unittest.TestCase):

View File

@ -9,8 +9,8 @@ class TestArtist(unittest.TestCase):
def track_added(self, artist, added):
self.added = added
def track_removed(self, artist, removed):
self.removed = removed
def track_removed(self, artist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -48,7 +48,7 @@ class TestArtist(unittest.TestCase):
artist.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(artist.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestArtistTable(unittest.TestCase):

View File

@ -11,8 +11,8 @@ class TestDecade(unittest.TestCase):
def track_added(self, decade, added):
self.added = added
def track_removed(self, decade, removed):
self.removed = removed
def track_removed(self, decade, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -58,7 +58,7 @@ class TestDecade(unittest.TestCase):
decade.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(decade.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestDecadeTable(unittest.TestCase):

View File

@ -9,8 +9,8 @@ class TestDisc(unittest.TestCase):
def track_added(self, disc, added):
self.added = added
def track_removed(self, disc, removed):
self. removed = removed
def track_removed(self, disc, removed, adjusted_current):
self. removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -67,7 +67,7 @@ class TestDisc(unittest.TestCase):
disc.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(disc.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestDiscTable(unittest.TestCase):

View File

@ -10,8 +10,8 @@ class TestGenre(unittest.TestCase):
def track_added(self, genre, added):
self.added = added
def track_removed(self, genre, removed):
self.removed = removed
def track_removed(self, genre, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -51,7 +51,7 @@ class TestGenre(unittest.TestCase):
self.assertFalse(genre.remove_track(track))
self.assertEqual(genre.get_n_tracks(), 0)
self.assertEqual(genre.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestGenreTable(unittest.TestCase):

View File

@ -10,8 +10,8 @@ class TestLibrary(unittest.TestCase):
def track_added(self, library, added):
self.added = added
def track_removed(self, library, removed):
self. removed = removed
def track_removed(self, library, removed, adjusted_current):
self. removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -56,7 +56,7 @@ class TestLibrary(unittest.TestCase):
library.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(library.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestLibraryTable(unittest.TestCase):

View File

@ -66,6 +66,29 @@ class TestPlaylist(unittest.TestCase):
self.assertEqual(plist.get_track(2), track3)
self.assertEqual(plist.current, 1)
def test_remove_track(self):
plist = db.user.Table.find("Test Playlist")
track1 = db.make_fake_track(1, 1, "Track 1", "/a/b/c/1.ogg")
track2 = db.make_fake_track(2, 2, "Track 2", "/a/b/c/2.ogg")
track3 = db.make_fake_track(3, 3, "Track 3", "/a/b/c/3.ogg")
plist.sort = [ "tracks.number ASC" ]
plist.add_track(track1)
plist.add_track(track2)
plist.add_track(track3)
plist.current = 1
self.assertTrue(plist.track_adjusts_current(track1))
self.assertTrue(plist.track_adjusts_current(track2))
self.assertFalse(plist.track_adjusts_current(track3))
plist.remove_track(track3)
self.assertEqual(plist.current, 1)
plist.remove_track(track1)
self.assertEqual(plist.current, 0)
plist.remove_track(track2)
self.assertEqual(plist.current, -1)
def test_current(self):
plist = db.user.Table.find("Test Playlist")
self.assertEqual(plist.get_property("current"), -1)

View File

@ -9,8 +9,8 @@ class TestCollection(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def refreshed(self, plist):
self.refreshed = True
@ -41,7 +41,7 @@ class TestCollection(unittest.TestCase):
collection.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(collection.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
def test_library_enabled(self):
collection = db.user.Table.find("Collection")
@ -67,8 +67,8 @@ class TestFavorites(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, False)
def setUp(self): db.reset()
@ -101,15 +101,15 @@ class TestFavorites(unittest.TestCase):
self.assertFalse(favorites.remove_track(track))
self.assertEqual(favorites.get_n_tracks(), 0)
self.assertEqual(favorites.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestNewTracks(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self): db.reset()
@ -140,15 +140,15 @@ class TestNewTracks(unittest.TestCase):
self.assertFalse(new.remove_track(track))
self.assertEqual(new.get_n_tracks(), 0)
self.assertEqual(new.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestPrevious(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self): db.reset()
@ -194,7 +194,7 @@ class TestPrevious(unittest.TestCase):
self.assertEqual(previous.get_tracks(), [ track1, track2 ])
self.assertEqual(previous.get_track_index(track1), 0)
self.assertEqual(previous.get_track_index(track2), 1)
self.assertEqual(self.removed, track1)
self.assertEqual(self.removed, (track1, False))
self.assertEqual(self.added, track1)
self.assertTrue(previous.remove_track(track1))
@ -202,7 +202,7 @@ class TestPrevious(unittest.TestCase):
self.assertEqual(previous.get_n_tracks(), 1)
self.assertEqual(previous.get_tracks(), [ track2 ])
self.assertEqual(previous.get_track_index(track2), 0)
self.assertEqual(self.removed, track1)
self.assertEqual(self.removed, (track1, True))
def test_previous_track(self):
previous = db.user.Table.find("Previous")
@ -241,8 +241,8 @@ class TestQueuedTracks(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self): db.reset()
@ -276,15 +276,15 @@ class TestQueuedTracks(unittest.TestCase):
self.assertFalse(queued.remove_track(track))
self.assertEqual(queued.get_n_tracks(), 0)
self.assertEqual(queued.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestUserPlaylist(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self): db.reset()
@ -325,7 +325,7 @@ class TestUserPlaylist(unittest.TestCase):
self.assertFalse(plist.remove_track(track))
self.assertEqual(plist.get_n_tracks(), 0)
self.assertEqual(plist.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestUserTable(unittest.TestCase):

View File

@ -8,8 +8,8 @@ class TestYear(unittest.TestCase):
def track_added(self, plist, added):
self.added = added
def track_removed(self, plist, removed):
self.removed = removed
def track_removed(self, plist, removed, adjusted_current):
self.removed = (removed, adjusted_current)
def setUp(self):
db.reset()
@ -45,7 +45,7 @@ class TestYear(unittest.TestCase):
year.connect("track-removed", self.track_removed)
db.track.Table.delete(track)
self.assertEqual(year.get_tracks(), [ ])
self.assertEqual(self.removed, track)
self.assertEqual(self.removed, (track, False))
class TestYearTable(unittest.TestCase):

View File

@ -159,10 +159,11 @@ class TrackTable(table.Table):
plists = [ track.artist, track.album, track.disc,
track.decade, track.year, track.library,
user.Table.find("Collection") ]
adjust = [ p.track_adjusts_current(track) for p in plists ]
super().delete(track)
for plist in plists:
plist.remove_track(track)
for (plist, adjust) in zip(plists, adjust):
plist.remove_track(track, adjust)
def find(self, *args):
raise NotImplementedError