From c7dca6164ecb1dc4cd08ab7b4f2b2e4aed932dc1 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Mon, 19 Feb 2024 21:44:13 -0500 Subject: [PATCH] db: Create a listenbrainz_queue table in the database I bump the user_version to 3 at the same time. This table will be used to hold listenbrainz listens that have not yet been submitted to the listenbrainz server. I also give the Track table functions to get and delete listens from this table as needed by the listenbrainz thread. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 7 ++++- emmental/db/tracks.py | 17 ++++++++++ emmental/db/upgrade-v3.sql | 25 +++++++++++++++ tests/db/test_db.py | 7 +++-- tests/db/test_tracks.py | 64 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 emmental/db/upgrade-v3.sql diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 1ec1a37..6a27f0d 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -20,6 +20,7 @@ from . import years SQL_V1_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql" SQL_V2_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v2.sql" +SQL_V3_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v3.sql" class Connection(connection.Connection): @@ -57,9 +58,13 @@ class Connection(connection.Connection): case 0: self.executescript(SQL_V1_SCRIPT) self.executescript(SQL_V2_SCRIPT) + self.executescript(SQL_V3_SCRIPT) case 1: self.executescript(SQL_V2_SCRIPT) - case 2: pass + self.executescript(SQL_V3_SCRIPT) + case 2: + self.executescript(SQL_V3_SCRIPT) + case 3: pass case _: raise Exception(f"Unsupported data version: {user_version}") diff --git a/emmental/db/tracks.py b/emmental/db/tracks.py index bd072ee..60b2320 100644 --- a/emmental/db/tracks.py +++ b/emmental/db/tracks.py @@ -200,6 +200,12 @@ class Table(table.Table): return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?", newval, track.trackid) + def delete_listens(self, listenids: list[int]) -> None: + """Delete the listens indicated by the provided listenids.""" + self.sql.executemany("""DELETE FROM listenbrainz_queue + WHERE listenid=?""", + *[(id,) for id in listenids]) + def get_artists(self, track: Track) -> list[table.Row]: """Get the set of Artists for a specific Track.""" rows = self.sql("""SELECT artistid FROM artist_tracks_view @@ -212,6 +218,14 @@ class Table(table.Table): WHERE trackid=?""", track.trackid).fetchall() return [self.sql.genres.rows.get(row["genreid"]) for row in rows] + def get_n_listens(self, n: int) -> list[tuple]: + """Get the n most recent listens from the listenbrainz queue.""" + cur = self.sql("""SELECT listenid, trackid, timestamp + FROM listenbrainz_queue ORDER BY timestamp DESC + LIMIT ?""", n) + return [(row["listenid"], self.rows[row["trackid"]], row["timestamp"]) + for row in cur.fetchall()] + def map_sort_order(self, ordering: str) -> dict[int, int]: """Get a lookup table for Track sort keys.""" ordering = ordering if len(ordering) > 0 else "trackid" @@ -277,6 +291,9 @@ class Table(table.Table): @GObject.Signal(arg_types=(Track,)) def track_played(self, track: Track) -> None: """Signal that a Track was played.""" + if track is not None: + self.sql("""INSERT INTO listenbrainz_queue (trackid, timestamp) + VALUES (?, ?)""", track.trackid, track.lastplayed) class TrackidSet(GObject.GObject): diff --git a/emmental/db/upgrade-v3.sql b/emmental/db/upgrade-v3.sql new file mode 100644 index 0000000..825c68f --- /dev/null +++ b/emmental/db/upgrade-v3.sql @@ -0,0 +1,25 @@ +/* Copyright 2024 (c) Anna Schumaker */ + +PRAGMA user_version = 3; + +/* + * The `listenbrainz_queue` table is used to store recently played tracks + * before submitting them to ListenBrainz. This gives us some form of offline + * recovery, since anything in this table needs to be submitted the next time + * we can successfully connect. As a bonus, I prepopulate this table using + * the last played data from tracks that have already been played when this + * table is created. + */ + +CREATE TABLE listenbrainz_queue ( + listenid INTEGER PRIMARY KEY, + trackid INTEGER REFERENCES tracks (trackid) + ON DELETE CASCADE + ON UPDATE CASCADE, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + + +INSERT INTO listenbrainz_queue (trackid, timestamp) + SELECT trackid, lastplayed FROM tracks + WHERE lastplayed IS NOT NULL; diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 56610bd..7878bb6 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -14,6 +14,7 @@ class TestConnection(tests.util.TestCase): dir = pathlib.Path(emmental.db.__file__).parent self.assertEqual(emmental.db.SQL_V1_SCRIPT, dir / "emmental.sql") self.assertEqual(emmental.db.SQL_V2_SCRIPT, dir / "upgrade-v2.sql") + self.assertEqual(emmental.db.SQL_V3_SCRIPT, dir / "upgrade-v3.sql") def test_connection(self): """Check that the connection manager is initialized properly.""" @@ -22,16 +23,16 @@ class TestConnection(tests.util.TestCase): def test_version(self): """Test checking the database schema version.""" cur = self.sql("PRAGMA user_version") - self.assertEqual(cur.fetchone()["user_version"], 2) + self.assertEqual(cur.fetchone()["user_version"], 3) def test_version_too_new(self): """Test failing when the database version is too new.""" self.sql._Connection__check_version() - self.sql("PRAGMA user_version = 3") + self.sql("PRAGMA user_version = 4") with self.assertRaises(Exception) as e: self.sql._Connection__check_version() - self.assertEqual(str(e.exception), "Unsupported data version: 3") + self.assertEqual(str(e.exception), "Unsupported data version: 4") def test_close(self): """Check closing the connection.""" diff --git a/tests/db/test_tracks.py b/tests/db/test_tracks.py index b11fe49..8bf1660 100644 --- a/tests/db/test_tracks.py +++ b/tests/db/test_tracks.py @@ -292,6 +292,20 @@ class TestTrackTable(tests.util.TestCase): self.assertFalse(track.delete()) + def test_delete_listens(self): + """Test deleting listens from the listenbrainz_queue.""" + track1 = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, length=10) + track2 = self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, length=10) + + for track in [track1, track2]: + track.start() + track.stop(9) + + self.tracks.delete_listens([1, 2]) + self.assertListEqual(self.tracks.get_n_listens(5), []) + def test_delete_save(self): """Test saving track data when a track is deleted.""" now = datetime.datetime.now() @@ -485,6 +499,40 @@ class TestTrackTable(tests.util.TestCase): self.assertListEqual(self.tracks.get_genres(track), [genre1, genre2]) + def test_get_n_listens(self): + """Test the get_n_listens() function.""" + track1 = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, length=10) + track2 = self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, length=12) + + self.assertListEqual(self.tracks.get_n_listens(2), []) + + track1.start() + track1.stop(8) + ts1 = track1.lastplayed + self.assertListEqual(self.tracks.get_n_listens(2), + [(1, track1, ts1)]) + + track2.start() + track2.stop(11) + ts2 = track2.lastplayed + self.assertListEqual(self.tracks.get_n_listens(2), + [(2, track2, ts2), + (1, track1, ts1)]) + + track1.start() + track1.stop(9) + ts3 = track1.lastplayed + self.assertListEqual(self.tracks.get_n_listens(2), + [(3, track1, ts3), + (2, track2, ts2)]) + + self.assertListEqual(self.tracks.get_n_listens(4), + [(3, track1, ts3), + (2, track2, ts2), + (1, track1, ts1)]) + def test_mark_path_active(self): """Test marking a path as active.""" self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), @@ -551,6 +599,9 @@ class TestTrackTable(tests.util.TestCase): self.assertIsNone(track.lastplayed) self.assertIsNone(self.tracks.current_track) + cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue") + self.assertListEqual(cur.fetchall(), []) + self.playlists.most_played.reload_tracks.assert_not_called() self.playlists.queued.remove_track.assert_not_called() self.playlists.unplayed.remove_track.assert_not_called() @@ -569,6 +620,11 @@ class TestTrackTable(tests.util.TestCase): self.assertEqual(row["lastplayed"], track.laststarted) self.assertEqual(track.lastplayed, track.laststarted) + cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue") + row = cur.fetchall()[0] + self.assertEqual(row["trackid"], track.trackid) + self.assertEqual(row["timestamp"], track.lastplayed) + self.playlists.most_played.reload_tracks.assert_called() self.playlists.queued.remove_track.assert_called_with(track) self.playlists.unplayed.remove_track.assert_called_with(track) @@ -594,6 +650,9 @@ class TestTrackTable(tests.util.TestCase): self.assertIsNone(track.restarted) self.assertIsNone(self.tracks.current_track) + cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue") + self.assertListEqual(cur.fetchall(), []) + self.playlists.most_played.reload_tracks.assert_not_called() self.playlists.queued.remove_track.assert_not_called() self.playlists.unplayed.remove_track.assert_not_called() @@ -611,6 +670,11 @@ class TestTrackTable(tests.util.TestCase): self.assertEqual(row["laststarted"], restarted) self.assertEqual(track.laststarted, restarted) + cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue") + row = cur.fetchall()[0] + self.assertEqual(row["trackid"], track.trackid) + self.assertEqual(row["timestamp"], track.lastplayed) + self.playlists.most_played.reload_tracks.assert_called_with(idle=True) self.playlists.queued.remove_track.assert_called_with(track) self.playlists.unplayed.remove_track.assert_called_with(track)