db: Create a listenbrainz_queue table in the database

I bump the user_version to 3 at the same time. This table will be used
to hold listenbrainz listens that have not yet been submitted to the
listenbrainz server. I also give the Track table functions to get and
delete listens from this table as needed by the listenbrainz thread.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2024-02-19 21:44:13 -05:00
parent eada937b7a
commit c7dca6164e
5 changed files with 116 additions and 4 deletions

View File

@ -20,6 +20,7 @@ from . import years
SQL_V1_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql"
SQL_V2_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v2.sql"
SQL_V3_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v3.sql"
class Connection(connection.Connection):
@ -57,9 +58,13 @@ class Connection(connection.Connection):
case 0:
self.executescript(SQL_V1_SCRIPT)
self.executescript(SQL_V2_SCRIPT)
self.executescript(SQL_V3_SCRIPT)
case 1:
self.executescript(SQL_V2_SCRIPT)
case 2: pass
self.executescript(SQL_V3_SCRIPT)
case 2:
self.executescript(SQL_V3_SCRIPT)
case 3: pass
case _:
raise Exception(f"Unsupported data version: {user_version}")

View File

@ -200,6 +200,12 @@ class Table(table.Table):
return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?",
newval, track.trackid)
def delete_listens(self, listenids: list[int]) -> None:
"""Delete the listens indicated by the provided listenids."""
self.sql.executemany("""DELETE FROM listenbrainz_queue
WHERE listenid=?""",
*[(id,) for id in listenids])
def get_artists(self, track: Track) -> list[table.Row]:
"""Get the set of Artists for a specific Track."""
rows = self.sql("""SELECT artistid FROM artist_tracks_view
@ -212,6 +218,14 @@ class Table(table.Table):
WHERE trackid=?""", track.trackid).fetchall()
return [self.sql.genres.rows.get(row["genreid"]) for row in rows]
def get_n_listens(self, n: int) -> list[tuple]:
"""Get the n most recent listens from the listenbrainz queue."""
cur = self.sql("""SELECT listenid, trackid, timestamp
FROM listenbrainz_queue ORDER BY timestamp DESC
LIMIT ?""", n)
return [(row["listenid"], self.rows[row["trackid"]], row["timestamp"])
for row in cur.fetchall()]
def map_sort_order(self, ordering: str) -> dict[int, int]:
"""Get a lookup table for Track sort keys."""
ordering = ordering if len(ordering) > 0 else "trackid"
@ -277,6 +291,9 @@ class Table(table.Table):
@GObject.Signal(arg_types=(Track,))
def track_played(self, track: Track) -> None:
"""Signal that a Track was played."""
if track is not None:
self.sql("""INSERT INTO listenbrainz_queue (trackid, timestamp)
VALUES (?, ?)""", track.trackid, track.lastplayed)
class TrackidSet(GObject.GObject):

View File

@ -0,0 +1,25 @@
/* Copyright 2024 (c) Anna Schumaker */
PRAGMA user_version = 3;
/*
* The `listenbrainz_queue` table is used to store recently played tracks
* before submitting them to ListenBrainz. This gives us some form of offline
* recovery, since anything in this table needs to be submitted the next time
* we can successfully connect. As a bonus, I prepopulate this table using
* the last played data from tracks that have already been played when this
* table is created.
*/
CREATE TABLE listenbrainz_queue (
listenid INTEGER PRIMARY KEY,
trackid INTEGER REFERENCES tracks (trackid)
ON DELETE CASCADE
ON UPDATE CASCADE,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
INSERT INTO listenbrainz_queue (trackid, timestamp)
SELECT trackid, lastplayed FROM tracks
WHERE lastplayed IS NOT NULL;

View File

@ -14,6 +14,7 @@ class TestConnection(tests.util.TestCase):
dir = pathlib.Path(emmental.db.__file__).parent
self.assertEqual(emmental.db.SQL_V1_SCRIPT, dir / "emmental.sql")
self.assertEqual(emmental.db.SQL_V2_SCRIPT, dir / "upgrade-v2.sql")
self.assertEqual(emmental.db.SQL_V3_SCRIPT, dir / "upgrade-v3.sql")
def test_connection(self):
"""Check that the connection manager is initialized properly."""
@ -22,16 +23,16 @@ class TestConnection(tests.util.TestCase):
def test_version(self):
"""Test checking the database schema version."""
cur = self.sql("PRAGMA user_version")
self.assertEqual(cur.fetchone()["user_version"], 2)
self.assertEqual(cur.fetchone()["user_version"], 3)
def test_version_too_new(self):
"""Test failing when the database version is too new."""
self.sql._Connection__check_version()
self.sql("PRAGMA user_version = 3")
self.sql("PRAGMA user_version = 4")
with self.assertRaises(Exception) as e:
self.sql._Connection__check_version()
self.assertEqual(str(e.exception), "Unsupported data version: 3")
self.assertEqual(str(e.exception), "Unsupported data version: 4")
def test_close(self):
"""Check closing the connection."""

View File

@ -292,6 +292,20 @@ class TestTrackTable(tests.util.TestCase):
self.assertFalse(track.delete())
def test_delete_listens(self):
"""Test deleting listens from the listenbrainz_queue."""
track1 = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"),
self.medium, self.year, length=10)
track2 = self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"),
self.medium, self.year, length=10)
for track in [track1, track2]:
track.start()
track.stop(9)
self.tracks.delete_listens([1, 2])
self.assertListEqual(self.tracks.get_n_listens(5), [])
def test_delete_save(self):
"""Test saving track data when a track is deleted."""
now = datetime.datetime.now()
@ -485,6 +499,40 @@ class TestTrackTable(tests.util.TestCase):
self.assertListEqual(self.tracks.get_genres(track),
[genre1, genre2])
def test_get_n_listens(self):
"""Test the get_n_listens() function."""
track1 = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"),
self.medium, self.year, length=10)
track2 = self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"),
self.medium, self.year, length=12)
self.assertListEqual(self.tracks.get_n_listens(2), [])
track1.start()
track1.stop(8)
ts1 = track1.lastplayed
self.assertListEqual(self.tracks.get_n_listens(2),
[(1, track1, ts1)])
track2.start()
track2.stop(11)
ts2 = track2.lastplayed
self.assertListEqual(self.tracks.get_n_listens(2),
[(2, track2, ts2),
(1, track1, ts1)])
track1.start()
track1.stop(9)
ts3 = track1.lastplayed
self.assertListEqual(self.tracks.get_n_listens(2),
[(3, track1, ts3),
(2, track2, ts2)])
self.assertListEqual(self.tracks.get_n_listens(4),
[(3, track1, ts3),
(2, track2, ts2),
(1, track1, ts1)])
def test_mark_path_active(self):
"""Test marking a path as active."""
self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"),
@ -551,6 +599,9 @@ class TestTrackTable(tests.util.TestCase):
self.assertIsNone(track.lastplayed)
self.assertIsNone(self.tracks.current_track)
cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue")
self.assertListEqual(cur.fetchall(), [])
self.playlists.most_played.reload_tracks.assert_not_called()
self.playlists.queued.remove_track.assert_not_called()
self.playlists.unplayed.remove_track.assert_not_called()
@ -569,6 +620,11 @@ class TestTrackTable(tests.util.TestCase):
self.assertEqual(row["lastplayed"], track.laststarted)
self.assertEqual(track.lastplayed, track.laststarted)
cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue")
row = cur.fetchall()[0]
self.assertEqual(row["trackid"], track.trackid)
self.assertEqual(row["timestamp"], track.lastplayed)
self.playlists.most_played.reload_tracks.assert_called()
self.playlists.queued.remove_track.assert_called_with(track)
self.playlists.unplayed.remove_track.assert_called_with(track)
@ -594,6 +650,9 @@ class TestTrackTable(tests.util.TestCase):
self.assertIsNone(track.restarted)
self.assertIsNone(self.tracks.current_track)
cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue")
self.assertListEqual(cur.fetchall(), [])
self.playlists.most_played.reload_tracks.assert_not_called()
self.playlists.queued.remove_track.assert_not_called()
self.playlists.unplayed.remove_track.assert_not_called()
@ -611,6 +670,11 @@ class TestTrackTable(tests.util.TestCase):
self.assertEqual(row["laststarted"], restarted)
self.assertEqual(track.laststarted, restarted)
cur = self.sql("SELECT trackid, timestamp FROM listenbrainz_queue")
row = cur.fetchall()[0]
self.assertEqual(row["trackid"], track.trackid)
self.assertEqual(row["timestamp"], track.lastplayed)
self.playlists.most_played.reload_tracks.assert_called_with(idle=True)
self.playlists.queued.remove_track.assert_called_with(track)
self.playlists.unplayed.remove_track.assert_called_with(track)