db: Commit the database when a Track has been started or stopped

Leaving the database in a dirty state could cause unintentional data
loss if the app crashes.

Fixes: #63 ("The database isn't being committed enough")
Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-06-07 13:22:56 -04:00
parent 57dd2c280e
commit 14c487c295
3 changed files with 49 additions and 30 deletions

View File

@ -75,6 +75,10 @@ class Connection(GObject.GObject):
self._sql.close()
self.connected = False
def commit(self) -> None:
"""Commit pending changes."""
self._sql.commit()
def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
"""Execute several similar SQL statements at once."""
try:

View File

@ -241,6 +241,7 @@ class Table(table.Table):
track.active = True
track.laststarted = cur.fetchone()["laststarted"]
self.current_track = track
self.sql.commit()
def stop_track(self, track: Track, played: bool) -> None:
"""Mark that a Track has been stopped."""
@ -270,6 +271,8 @@ class Table(table.Table):
self.sql.playlists.queued.remove_track(track)
self.sql.playlists.unplayed.remove_track(track)
self.sql.commit()
class TrackidSet(GObject.GObject):
"""Manage a set of Track IDs."""

View File

@ -508,16 +508,20 @@ class TestTrackTable(tests.util.TestCase):
track = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"),
self.medium, self.year)
track.start()
row = self.sql("SELECT laststarted FROM tracks WHERE trackid=?",
track.trackid).fetchone()
self.assertTrue(track.active)
self.assertIsNotNone(track.laststarted)
self.assertEqual(track.laststarted, row["laststarted"])
self.assertEqual(self.tracks.current_track, track)
with unittest.mock.patch.object(self.sql, "commit",
wraps=self.sql.commit) as mock_commit:
track.start()
mock_commit.assert_called()
self.playlists.previous.remove_track.assert_called_with(track)
self.playlists.previous.add_track.assert_called_with(track)
row = self.sql("SELECT laststarted FROM tracks WHERE trackid=?",
track.trackid).fetchone()
self.assertTrue(track.active)
self.assertIsNotNone(track.laststarted)
self.assertEqual(track.laststarted, row["laststarted"])
self.assertEqual(self.tracks.current_track, track)
self.playlists.previous.remove_track.assert_called_with(track)
self.playlists.previous.add_track.assert_called_with(track)
def test_stop_started_track(self):
"""Test marking that a Track has stopped playback."""
@ -525,31 +529,39 @@ class TestTrackTable(tests.util.TestCase):
self.medium, self.year, length=10)
track.start()
track.stop(3)
row = self.sql("SELECT lastplayed FROM tracks WHERE trackid=?",
track.trackid).fetchone()
self.assertFalse(track.active)
self.assertEqual(track.playcount, 0)
self.assertIsNone(row["lastplayed"])
self.assertIsNone(track.lastplayed)
self.assertIsNone(self.tracks.current_track)
with unittest.mock.patch.object(self.sql, "commit",
wraps=self.sql.commit) as mock_commit:
track.stop(3)
mock_commit.assert_called()
self.playlists.most_played.reload_tracks.assert_not_called()
self.playlists.queued.remove_track.assert_not_called()
self.playlists.unplayed.remove_track.assert_not_called()
row = self.sql("SELECT lastplayed FROM tracks WHERE trackid=?",
track.trackid).fetchone()
self.assertFalse(track.active)
self.assertEqual(track.playcount, 0)
self.assertIsNone(row["lastplayed"])
self.assertIsNone(track.lastplayed)
self.assertIsNone(self.tracks.current_track)
self.playlists.most_played.reload_tracks.assert_not_called()
self.playlists.queued.remove_track.assert_not_called()
self.playlists.unplayed.remove_track.assert_not_called()
track.start()
track.stop(8)
row = self.sql("""SELECT lastplayed, playcount FROM tracks
WHERE trackid=?""", track.trackid).fetchone()
self.assertEqual(row["playcount"], 1)
self.assertEqual(track.playcount, 1)
self.assertEqual(row["lastplayed"], track.laststarted)
self.assertEqual(track.lastplayed, track.laststarted)
with unittest.mock.patch.object(self.sql, "commit",
wraps=self.sql.commit) as mock_commit:
track.stop(8)
mock_commit.assert_called()
self.playlists.most_played.reload_tracks.assert_called()
self.playlists.queued.remove_track.assert_called_with(track)
self.playlists.unplayed.remove_track.assert_called_with(track)
row = self.sql("""SELECT lastplayed, playcount FROM tracks
WHERE trackid=?""", track.trackid).fetchone()
self.assertEqual(row["playcount"], 1)
self.assertEqual(track.playcount, 1)
self.assertEqual(row["lastplayed"], track.laststarted)
self.assertEqual(track.lastplayed, track.laststarted)
self.playlists.most_played.reload_tracks.assert_called()
self.playlists.queued.remove_track.assert_called_with(track)
self.playlists.unplayed.remove_track.assert_called_with(track)
def test_stop_restarted_track(self):
"""Test marking that a restarted Track has stopped playback."""