From 14c487c29541295d273ccf16045af20328d64783 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Wed, 7 Jun 2023 13:22:56 -0400 Subject: [PATCH] db: Commit the database when a Track has been started or stopped Leaving the database in a dirty state could cause unintentional data loss if the app crashes. Fixes: #63 ("The database isn't being committed enough") Signed-off-by: Anna Schumaker --- emmental/db/connection.py | 4 +++ emmental/db/tracks.py | 3 ++ tests/db/test_tracks.py | 72 +++++++++++++++++++++++---------------- 3 files changed, 49 insertions(+), 30 deletions(-) diff --git a/emmental/db/connection.py b/emmental/db/connection.py index dbe9cad..1920e36 100644 --- a/emmental/db/connection.py +++ b/emmental/db/connection.py @@ -75,6 +75,10 @@ class Connection(GObject.GObject): self._sql.close() self.connected = False + def commit(self) -> None: + """Commit pending changes.""" + self._sql.commit() + def executemany(self, statement: str, *args) -> sqlite3.Cursor | None: """Execute several similar SQL statements at once.""" try: diff --git a/emmental/db/tracks.py b/emmental/db/tracks.py index 01061e4..36984a0 100644 --- a/emmental/db/tracks.py +++ b/emmental/db/tracks.py @@ -241,6 +241,7 @@ class Table(table.Table): track.active = True track.laststarted = cur.fetchone()["laststarted"] self.current_track = track + self.sql.commit() def stop_track(self, track: Track, played: bool) -> None: """Mark that a Track has been stopped.""" @@ -270,6 +271,8 @@ class Table(table.Table): self.sql.playlists.queued.remove_track(track) self.sql.playlists.unplayed.remove_track(track) + self.sql.commit() + class TrackidSet(GObject.GObject): """Manage a set of Track IDs.""" diff --git a/tests/db/test_tracks.py b/tests/db/test_tracks.py index 28b7215..187ecb3 100644 --- a/tests/db/test_tracks.py +++ b/tests/db/test_tracks.py @@ -508,16 +508,20 @@ class TestTrackTable(tests.util.TestCase): track = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), self.medium, self.year) - track.start() - row = self.sql("SELECT laststarted FROM tracks WHERE trackid=?", - track.trackid).fetchone() - self.assertTrue(track.active) - self.assertIsNotNone(track.laststarted) - self.assertEqual(track.laststarted, row["laststarted"]) - self.assertEqual(self.tracks.current_track, track) + with unittest.mock.patch.object(self.sql, "commit", + wraps=self.sql.commit) as mock_commit: + track.start() + mock_commit.assert_called() - self.playlists.previous.remove_track.assert_called_with(track) - self.playlists.previous.add_track.assert_called_with(track) + row = self.sql("SELECT laststarted FROM tracks WHERE trackid=?", + track.trackid).fetchone() + self.assertTrue(track.active) + self.assertIsNotNone(track.laststarted) + self.assertEqual(track.laststarted, row["laststarted"]) + self.assertEqual(self.tracks.current_track, track) + + self.playlists.previous.remove_track.assert_called_with(track) + self.playlists.previous.add_track.assert_called_with(track) def test_stop_started_track(self): """Test marking that a Track has stopped playback.""" @@ -525,31 +529,39 @@ class TestTrackTable(tests.util.TestCase): self.medium, self.year, length=10) track.start() - track.stop(3) - row = self.sql("SELECT lastplayed FROM tracks WHERE trackid=?", - track.trackid).fetchone() - self.assertFalse(track.active) - self.assertEqual(track.playcount, 0) - self.assertIsNone(row["lastplayed"]) - self.assertIsNone(track.lastplayed) - self.assertIsNone(self.tracks.current_track) + with unittest.mock.patch.object(self.sql, "commit", + wraps=self.sql.commit) as mock_commit: + track.stop(3) + mock_commit.assert_called() - self.playlists.most_played.reload_tracks.assert_not_called() - self.playlists.queued.remove_track.assert_not_called() - self.playlists.unplayed.remove_track.assert_not_called() + row = self.sql("SELECT lastplayed FROM tracks WHERE trackid=?", + track.trackid).fetchone() + self.assertFalse(track.active) + self.assertEqual(track.playcount, 0) + self.assertIsNone(row["lastplayed"]) + self.assertIsNone(track.lastplayed) + self.assertIsNone(self.tracks.current_track) + + self.playlists.most_played.reload_tracks.assert_not_called() + self.playlists.queued.remove_track.assert_not_called() + self.playlists.unplayed.remove_track.assert_not_called() track.start() - track.stop(8) - row = self.sql("""SELECT lastplayed, playcount FROM tracks - WHERE trackid=?""", track.trackid).fetchone() - self.assertEqual(row["playcount"], 1) - self.assertEqual(track.playcount, 1) - self.assertEqual(row["lastplayed"], track.laststarted) - self.assertEqual(track.lastplayed, track.laststarted) + with unittest.mock.patch.object(self.sql, "commit", + wraps=self.sql.commit) as mock_commit: + track.stop(8) + mock_commit.assert_called() - self.playlists.most_played.reload_tracks.assert_called() - self.playlists.queued.remove_track.assert_called_with(track) - self.playlists.unplayed.remove_track.assert_called_with(track) + row = self.sql("""SELECT lastplayed, playcount FROM tracks + WHERE trackid=?""", track.trackid).fetchone() + self.assertEqual(row["playcount"], 1) + self.assertEqual(track.playcount, 1) + self.assertEqual(row["lastplayed"], track.laststarted) + self.assertEqual(track.lastplayed, track.laststarted) + + self.playlists.most_played.reload_tracks.assert_called() + self.playlists.queued.remove_track.assert_called_with(track) + self.playlists.unplayed.remove_track.assert_called_with(track) def test_stop_restarted_track(self): """Test marking that a restarted Track has stopped playback."""