From 08687882a362e7187fd2f8aada7858c82d3bd68b Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Fri, 23 Sep 2022 09:16:54 -0400 Subject: [PATCH] db: Add a Track Table The Track Table does all the work for saving, loading, and managing Track objects. I also create a SQLite View to link tracks to their associated artists, albums, and mediums. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 5 + emmental/db/emmental.sql | 56 +++++++++ emmental/db/tracks.py | 90 ++++++++++++++ tests/db/test_db.py | 11 +- tests/db/test_tracks.py | 257 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 415 insertions(+), 4 deletions(-) diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 3aa2cac..1a98676 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -14,6 +14,7 @@ from . import media from . import playlists from . import settings from . import table +from . import tracks from . import years @@ -44,11 +45,14 @@ class Connection(connection.Connection): self.years = years.Table(self, queue=self.decades.queue) self.libraries = libraries.Table(self) + self.tracks = tracks.Table(self) + def close(self) -> None: """Close the database connection.""" self.settings.stop() for tbl in self.playlist_tables(): tbl.stop() + self.tracks.stop() super().close() @@ -62,6 +66,7 @@ class Connection(connection.Connection): self.settings.load() for tbl in self.playlist_tables(): tbl.load() + self.tracks.load() def playlist_tables(self) -> Generator[playlist.Table, None, None]: """Iterate over each playlist table.""" diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index f48c95b..1928e75 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -339,6 +339,62 @@ CREATE TRIGGER libraries_delete_trigger AFTER DELETE ON libraries END; +/************************ + * * + * Tracks * + * * + ************************/ + +CREATE TABLE tracks ( + trackid INTEGER PRIMARY KEY, + libraryid INTEGER REFERENCES libraries (libraryid) + ON DELETE CASCADE + ON UPDATE CASCADE, + mediumid INTEGER REFERENCES media (mediumid) + ON DELETE CASCADE + ON UPDATE CASCADE, + year INTEGER REFERENCES years (year) + ON DELETE CASCADE + ON UPDATE CASCADE, + path PATH NOT NULL, + mbid TEXT NOT NULL DEFAULT "" COLLATE NOCASE, + title TEXT NOT NULL, + number INTEGER NOT NULL, + length REAL NOT NULL, + artist TEXT NOT NULL, + mtime REAL NOT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + favorite BOOLEAN NOT NULL DEFAULT FALSE, + playcount INTEGER NOT NULL DEFAULT 0, + added DATE DEFAULT CURRENT_DATE, + laststarted TIMESTAMP, + lastplayed TIMESTAMP, + UNIQUE (libraryid, path) +); + +CREATE VIEW track_info_view AS + SELECT trackid, tracks.mediumid, tracks.number, length, playcount, + laststarted, lastplayed, title, tracks.artist, + tracks.path as filepath, + media.number as mediumno, COALESCE(media.name, "") as medium, + albums.albumid, COALESCE(albums.name, "") as album, + COALESCE(albums.release, "") as release, + COALESCE(albums.artist, "") as albumartist, + libraries.deleting + FROM tracks + LEFT JOIN media USING (mediumid) + LEFT JOIN albums USING (albumid) + LEFT JOIN libraries USING (libraryid); + +CREATE TRIGGER tracks_active_trigger + AFTER UPDATE OF active ON tracks + FOR EACH ROW BEGIN + UPDATE tracks + SET active = FALSE + WHERE trackid != NEW.trackid and active == TRUE; + END; + + /****************************************** * * * Create Default Playlists * diff --git a/emmental/db/tracks.py b/emmental/db/tracks.py index 63549d3..1dc4a5e 100644 --- a/emmental/db/tracks.py +++ b/emmental/db/tracks.py @@ -1,6 +1,9 @@ # Copyright 2022 (c) Anna Schumaker. """A custom Gio.ListModel for working with tracks.""" +import pathlib +import sqlite3 from gi.repository import GObject +from gi.repository import Gtk from . import table @@ -62,3 +65,90 @@ class Track(table.Row): def primary_key(self) -> int: """Get the primary key for this Track.""" return self.trackid + + +class Filter(table.Filter): + """A customized Filter that never sets strictness to FilterMatch.All.""" + + def do_get_strictness(self) -> Gtk.FilterMatch: + """Get the strictness of the filter.""" + if self.n_keys == 0: + return Gtk.FilterMatch.NONE + return Gtk.FilterMatch.SOME + + +class Table(table.Table): + """A ListStore tailored for storing Track objects.""" + + def __init__(self, sql: GObject.TYPE_PYOBJECT): + """Initialize a Track Table.""" + super().__init__(sql, filter=Filter()) + self.set_model(None) + + def do_construct(self, **kwargs) -> Track: + """Construct a new Track instance.""" + return Track(**kwargs) + + def do_sql_delete(self, track: Track) -> sqlite3.Cursor: + """Delete a Track.""" + return self.sql("DELETE FROM tracks WHERE trackid=?", track.trackid) + + def do_sql_glob(self, glob: str) -> sqlite3.Cursor: + """Filter the Track table.""" + return self.sql("""SELECT trackid FROM track_info_view WHERE + CASEFOLD(title) GLOB :glob + OR CASEFOLD(artist) GLOB :glob + OR CASEFOLD(album) GLOB :glob + OR CASEFOLD(albumartist) GLOB :glob + OR CASEFOLD(medium) GLOB :glob + OR release GLOB :glob""", glob=glob) + + def do_sql_insert(self, library: table.Row, path: pathlib.Path, + medium: table.Row, year: table.Row, *, title: str = "", + number: int = 0, length: float = 0.0, artist: str = "", + mbid: str = "", mtime: float = 0.0) -> sqlite3.Cursor: + """Insert a new Track into the database.""" + return self.sql("""INSERT INTO tracks + (libraryid, mediumid, path, year, title, + number, length, artist, mbid, mtime) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + RETURNING *""", + library.libraryid, medium.mediumid, path, year.year, + title, number, length, artist, mbid, mtime) + + def do_sql_select_all(self) -> sqlite3.Cursor: + """Load Tracks from the database.""" + return self.sql("SELECT * FROM tracks") + + def do_sql_select_one(self, library: table.Row = None, + *, path: pathlib.Path = None, + mbid: str = None) -> sqlite3.Cursor: + """Look up a Track in the database.""" + if path is None and mbid is None: + raise KeyError("Either 'path' or 'mbid' are required") + + args = [("libraryid=?", library.libraryid if library else None), + ("path=?", path), ("mbid=?", mbid)] + + (where, args) = tuple(zip(*[arg for arg in args if None not in arg])) + sql_where = " AND ".join(where) + return self.sql(f"SELECT trackid FROM tracks WHERE {sql_where}", *args) + + def do_sql_update(self, track: Track, column: str, + newval: any) -> sqlite3.Cursor: + """Update a Track.""" + match (column, newval): + case ("favorite", True): + self.sql.playlists.favorites.add_track(track) + case ("favorite", False): + self.sql.playlists.favorites.remove_track(track) + + return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?", + newval, track.trackid) + + def map_sort_order(self, ordering: str) -> dict[int, int]: + """Get a lookup table for Track sort keys.""" + ordering = ordering if len(ordering) > 0 else "trackid" + cur = self.sql(f"""SELECT trackid FROM track_info_view + ORDER BY {ordering}""") + return {row["trackid"]: i for (i, row) in enumerate(cur.fetchall())} diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 13410f5..716c023 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -48,6 +48,7 @@ class TestConnection(tests.util.TestCase): self.assertIsInstance(self.sql.decades, emmental.db.decades.Table) self.assertIsInstance(self.sql.years, emmental.db.years.Table) self.assertIsInstance(self.sql.libraries, emmental.db.libraries.Table) + self.assertIsInstance(self.sql.tracks, emmental.db.tracks.Table) self.assertEqual(self.sql.albums.queue, self.sql.artists.queue) self.assertEqual(self.sql.media.queue, self.sql.artists.queue) @@ -61,20 +62,22 @@ class TestConnection(tests.util.TestCase): def test_load(self): """Check that calling load() loads the tables.""" + idle_tables = [tbl for tbl in self.sql.playlist_tables()] + \ + [self.sql.tracks] + table_loaded = unittest.mock.Mock() self.sql.connect("table-loaded", table_loaded) self.sql.load() self.assertTrue(self.sql.settings.loaded) - for tbl in self.sql.playlist_tables(): + for tbl in idle_tables: self.assertFalse(tbl.loaded) - for tbl in self.sql.playlist_tables(): + for tbl in idle_tables: tbl.queue.complete() self.assertTrue(tbl.loaded) - tables = [tbl for tbl in self.sql.playlist_tables()] calls = [unittest.mock.call(self.sql, tbl) - for tbl in [self.sql.settings] + tables] + for tbl in [self.sql.settings] + idle_tables] table_loaded.assert_has_calls(calls) def test_filter(self): diff --git a/tests/db/test_tracks.py b/tests/db/test_tracks.py index 9dc5f5c..0243dad 100644 --- a/tests/db/test_tracks.py +++ b/tests/db/test_tracks.py @@ -6,6 +6,7 @@ import emmental.db.tracks import tests.util import unittest.mock from gi.repository import Gio +from gi.repository import Gtk class TestTrackObject(tests.util.TestCase): @@ -103,3 +104,259 @@ class TestTrackObject(tests.util.TestCase): artist="New Artist", number=2, length=12.345, mtime=67.890) self.table.update.assert_not_called() + + +class TestTrackTable(tests.util.TestCase): + """Tests our track table.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.sql.playlists.load(now=True) + self.sql.playlists.favorites.add_track = unittest.mock.Mock() + self.sql.playlists.favorites.remove_track = unittest.mock.Mock() + self.sql.playlists.previous.add_track = unittest.mock.Mock() + + self.library = self.sql.libraries.create(pathlib.Path("/a/b/")) + self.album = self.sql.albums.create("Test Album", "Album Artist", + release="2022-10") + self.medium = self.sql.media.create(self.album, "Test Medium", + number=1) + self.year = self.sql.years.create(1988) + + self.tracks = self.sql.tracks + + def test_track_filter(self): + """Test the tracks.Filter object.""" + filter = emmental.db.tracks.Filter() + self.assertEqual(filter.get_strictness(), Gtk.FilterMatch.SOME) + filter.keys = {1, 2, 3} + self.assertEqual(filter.get_strictness(), Gtk.FilterMatch.SOME) + filter.keys = set() + self.assertEqual(filter.get_strictness(), Gtk.FilterMatch.NONE) + + def test_init(self): + """Test that the Track table is initialized properly.""" + self.assertIsInstance(self.tracks, emmental.db.table.Table) + self.assertIsInstance(self.tracks.get_filter(), + emmental.db.tracks.Filter) + self.assertIsNone(self.tracks.get_model()) + + def test_construct(self): + """Test constructing a new Track.""" + now = datetime.datetime.now() + track = self.tracks.construct(trackid=1, year=1988, + libraryid=self.library.libraryid, + mediumid=self.medium.mediumid, + path=pathlib.Path("/a/b/c.ogg"), + mbid="ab-cd-ef", title="Title", number=1, + length=1.0, artist="Artist", mtime=1.0, + playcount=1, lastplayed=now) + self.assertIsInstance(track, emmental.db.tracks.Track) + self.assertEqual(track.table, self.tracks) + self.assertEqual(track.trackid, 1) + self.assertEqual(track.libraryid, self.library.libraryid) + self.assertEqual(track.mediumid, self.medium.mediumid) + self.assertEqual(track.year, 1988) + self.assertEqual(track.path, pathlib.Path("/a/b/c.ogg")) + self.assertEqual(track.mbid, "ab-cd-ef") + self.assertEqual(track.title, "Title") + self.assertEqual(track.number, 1) + self.assertEqual(track.length, 1.0) + self.assertEqual(track.artist, "Artist") + self.assertEqual(track.mtime, 1.0) + self.assertEqual(track.playcount, 1) + self.assertEqual(track.lastplayed, now) + self.assertFalse(track.active) + self.assertFalse(track.favorite) + + def test_create(self): + """Test creating a new Track.""" + track = self.tracks.create(self.library, pathlib.Path("/a/b/c.ogg"), + self.medium, self.year) + self.assertIsInstance(track, emmental.db.tracks.Track) + self.assertEqual(track.libraryid, self.library.libraryid) + self.assertEqual(track.mediumid, self.medium.mediumid) + self.assertEqual(track.year, 1988) + self.assertEqual(track.path, pathlib.Path("/a/b/c.ogg")) + self.assertEqual(track.added, datetime.datetime.utcnow().date()) + + track2 = self.tracks.create(self.library, pathlib.Path("/a/b/d.ogg"), + self.medium, self.year, title="Test Track", + number=1, length=1.23, artist="Artist", + mbid="ab-cd-ef", mtime=4.56) + self.assertEqual(track2.trackid, 2) + self.assertEqual(track2.libraryid, self.library.libraryid) + self.assertEqual(track2.mediumid, self.medium.mediumid) + self.assertEqual(track2.path, pathlib.Path("/a/b/d.ogg")) + self.assertEqual(track2.title, "Test Track") + self.assertEqual(track2.number, 1) + self.assertEqual(track2.length, 1.23) + self.assertEqual(track2.artist, "Artist") + self.assertEqual(track2.mbid, "ab-cd-ef") + self.assertEqual(track2.mtime, 4.56) + + track3 = self.tracks.create(self.library, pathlib.Path("/a/b/c.ogg"), + self.medium, self.year) + self.assertIsNone(track3) + + cur = self.sql("SELECT COUNT(*) FROM tracks") + self.assertEqual(cur.fetchone()["COUNT(*)"], 2) + + def test_delete(self): + """Test deleting a Track.""" + track = self.tracks.create(self.library, pathlib.Path("/a/b/c.ogg"), + self.medium, self.year) + + self.assertTrue(track.delete()) + self.assertIsNone(self.tracks.index(track)) + + cur = self.sql("SELECT COUNT(path) FROM tracks") + self.assertEqual(cur.fetchone()["COUNT(path)"], 0) + self.assertEqual(len(self.tracks), 0) + + self.assertFalse(track.delete()) + + def test_filter(self): + """Test filtering the Track table.""" + self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, + title="Title 1", artist="Test Artist") + self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, + title="Title 2", artist="Test Artist") + + self.tracks.filter("*1", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1}) + self.tracks.filter("*artist", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1, 2}) + self.tracks.filter("*medium", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1, 2}) + self.tracks.filter("*album", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1, 2}) + self.tracks.filter("*album artist", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1, 2}) + self.tracks.filter("2022-*", now=True) + self.assertSetEqual(self.tracks.get_filter().keys, {1, 2}) + + def test_load(self): + """Test loading tracks from the database.""" + now = datetime.datetime.now() + self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), + self.medium, self.year) + self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, title="Track 2", + number=2, length=2, artist="Test Artist", + mbid="ab-cd-ef", mtime=123.45) + self.sql("""UPDATE tracks SET active=TRUE, favorite=TRUE, playcount=3, + lastplayed=?, laststarted=? WHERE trackid=2""", now, now) + + table2 = emmental.db.tracks.Table(sql=self.sql) + self.assertEqual(len(table2), 0) + + table2.load(now=True) + self.assertEqual(len(table2.store), 2) + + for i in [0, 1]: + with self.subTest(i=i): + self.assertEqual(table2.store[i].trackid, i + 1) + self.assertEqual(table2.store[i].libraryid, + self.library.libraryid) + self.assertEqual(table2.store[i].mediumid, + self.medium.mediumid) + self.assertEqual(table2.store[i].year, self.year.year) + + self.assertEqual(table2.store[i].active, bool(i)) + self.assertEqual(table2.store[i].favorite, bool(i)) + + self.assertEqual(table2.store[i].path, + pathlib.Path(f"/a/b/{i+1}.ogg")) + self.assertEqual(table2.store[i].mbid, "ab-cd-ef" if i else "") + self.assertEqual(table2.store[i].title, "Track 2" if i else "") + self.assertEqual(table2.store[i].artist, + "Test Artist" if i else "") + self.assertEqual(table2.store[i].number, 2 if i else 0) + self.assertEqual(table2.store[i].length, 2 if i else 0) + self.assertEqual(table2.store[i].mtime, 123.45 if i else 0) + self.assertEqual(table2.store[i].playcount, 3 if i else 0) + + self.assertEqual(table2.store[i].laststarted, + now if i else None) + self.assertEqual(table2.store[i].lastplayed, + now if i else None) + self.assertIsNone(table2.store[i].restarted) + + def test_lookup(self): + """Test looking up tracks in the database.""" + path1 = pathlib.Path("/a/b/1.ogg") + path2 = pathlib.Path("/a/b/2.ogg") + track1 = self.tracks.create(self.library, path1, + self.medium, self.year) + track2 = self.tracks.create(self.library, path2, + self.medium, self.year, mbid="ab-cd-ef") + library2 = self.sql.libraries.create(pathlib.Path("/a/b/d")) + + self.assertEqual(self.tracks.lookup(self.library, path=path1), track1) + self.assertEqual(self.tracks.lookup(path=path1), track1) + self.assertEqual(self.tracks.lookup(path=path2), track2) + self.assertIsNone(self.tracks.lookup(path="/no/such/track")) + self.assertIsNone(self.tracks.lookup(library2, path=path1)) + + self.assertEqual(self.tracks.lookup(self.library, mbid="ab-cd-ef"), + track2) + self.assertEqual(self.tracks.lookup(mbid="ab-cd-ef"), track2) + self.assertIsNone(self.tracks.lookup(mbid="gh-ij-kl")) + + with self.assertRaises(KeyError) as error: + self.tracks.lookup(self.library) + self.assertEqual(error.value, + "Either 'path' or 'mbid' are required") + + def test_map_sort_order(self): + """Test getting a lookup table for Track sort keys.""" + tracks = [self.tracks.create(self.library, + pathlib.Path(f"/a/b/{n}.ogg"), + self.medium, self.year, number=n) + for n in range(10)] + + self.assertDictEqual(self.tracks.map_sort_order(""), + {t.trackid: t.trackid - 1 for t in tracks}) + self.assertDictEqual(self.tracks.map_sort_order("number DESC"), + {t.trackid: 10 - t.trackid for t in tracks}) + + def test_update(self): + """Test updating tracks in the database.""" + track = self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"), + self.medium, self.year, length=10) + medium2 = self.sql.media.create(self.album, "", number=2) + y2022 = self.sql.years.create(2022) + + track.update_properties(mediumid=medium2.mediumid, year=y2022.year, + favorite=True, mbid="ab-cd-ef", + title="New Title", artist="New Artist", + number=1, length=42, mtime=123.45) + self.sql.playlists.favorites.add_track.assert_called_with(track) + + cur = self.sql("""SELECT mediumid, year, favorite, mbid, title, + artist, number, length, mtime + FROM tracks WHERE trackid = ?""", track.trackid) + row = cur.fetchone() + self.assertEqual(row["mediumid"], medium2.mediumid) + self.assertEqual(row["year"], 2022) + self.assertTrue(row["favorite"]) + self.assertEqual(row["mbid"], "ab-cd-ef") + self.assertEqual(row["title"], "New Title") + self.assertEqual(row["artist"], "New Artist") + self.assertEqual(row["number"], 1) + self.assertEqual(row["length"], 42) + self.assertEqual(row["mtime"], 123.45) + + track.update_properties(favorite=False) + self.sql.playlists.favorites.remove_track.assert_called_with(track) + + track2 = self.tracks.create(self.library, pathlib.Path("/a/b/2.ogg"), + self.medium, self.year, length=10) + track2.active = True + row = self.sql("SELECT active FROM tracks WHERE trackid=?", + track.trackid).fetchone() + self.assertFalse(row["active"])