diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index aa044c2..e345c1c 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -283,13 +283,15 @@ CREATE TABLE genres ( ); CREATE VIEW genres_view AS - SELECT genreid, propertyid, name, active + SELECT genreid, propertyid, name, + active, loop, shuffle, sort_order, current_trackid FROM genres JOIN playlist_properties USING (propertyid); CREATE TRIGGER genres_insert_trigger AFTER INSERT ON genres BEGIN - INSERT INTO playlist_properties (active) VALUES (False); + INSERT INTO playlist_properties (active, sort_order) + VALUES (False, "albumartist, album, mediumno, number"); UPDATE genres SET propertyid = last_insert_rowid() WHERE genreid = NEW.genreid; END; @@ -545,6 +547,14 @@ CREATE VIEW medium_tracks_view AS JOIN libraries USING (libraryid) WHERE libraries.deleting = False; +CREATE VIEW genre_tracks_view AS + SELECT tracks.trackid, genres.genreid + FROM tracks + JOIN system_tracks USING (trackid) + JOIN genres USING (propertyid) + JOIN libraries USING (libraryid) + WHERE libraries.deleting = False; + /**************************************************** * * diff --git a/emmental/db/genres.py b/emmental/db/genres.py index 4662986..2a55e1d 100644 --- a/emmental/db/genres.py +++ b/emmental/db/genres.py @@ -20,6 +20,10 @@ class Genre(playlist.Playlist): class Table(playlist.Table): """Our Genre Table.""" + def __init__(self, sql: GObject.TYPE_PYOBJECT, **kwargs): + """Initialize the Genres Table.""" + super().__init__(sql=sql, autodelete=True, **kwargs) + def do_construct(self, **kwargs) -> Genre: """Construct a new Genre.""" return Genre(**kwargs) diff --git a/emmental/db/tracks.py b/emmental/db/tracks.py index 2c28bfd..01061e4 100644 --- a/emmental/db/tracks.py +++ b/emmental/db/tracks.py @@ -50,6 +50,10 @@ class Track(table.Row): """Get a list of Artists for this Track.""" return self.table.get_artists(self) + def get_genres(self) -> list[table.Row]: + """Get a list of Genres for this Track.""" + return self.table.get_genres(self) + def get_library(self) -> table.Row | None: """Get the Library associated with this Track.""" return self.table.sql.libraries.rows.get(self.libraryid) @@ -202,6 +206,12 @@ class Table(table.Table): WHERE trackid=?""", track.trackid).fetchall() return [self.sql.artists.rows.get(row["artistid"]) for row in rows] + def get_genres(self, track: Track) -> list[int]: + """Get the list of Genres for a specific Track.""" + rows = self.sql("""SELECT genreid FROM genre_tracks_view + WHERE trackid=?""", track.trackid).fetchall() + return [self.sql.genres.rows.get(row["genreid"]) for row in rows] + def map_sort_order(self, ordering: str) -> dict[int, int]: """Get a lookup table for Track sort keys.""" ordering = ordering if len(ordering) > 0 else "trackid" diff --git a/tests/db/test_genres.py b/tests/db/test_genres.py index 169155d..badb056 100644 --- a/tests/db/test_genres.py +++ b/tests/db/test_genres.py @@ -36,6 +36,8 @@ class TestGenreTable(tests.util.TestCase): """Test that the genre model is configured correctly.""" self.assertIsInstance(self.table, emmental.db.playlist.Table) self.assertEqual(len(self.table), 0) + self.assertTrue(self.table.autodelete) + self.assertTrue(self.table.system_tracks) def test_construct(self): """Test constructing a new genre playlist.""" @@ -52,6 +54,8 @@ class TestGenreTable(tests.util.TestCase): genre = self.table.create("Test Genre") self.assertIsInstance(genre, emmental.db.genres.Genre) self.assertEqual(genre.name, "Test Genre") + self.assertEqual(genre.sort_order, + "albumartist, album, mediumno, number") cur = self.sql("SELECT COUNT(name) FROM genres") self.assertEqual(cur.fetchone()["COUNT(name)"], 1) @@ -119,7 +123,16 @@ class TestGenreTable(tests.util.TestCase): """Test updating genre attributes.""" genre = self.table.create("Test Genre") genre.active = True + genre.loop = "Track" + genre.shuffle = True + genre.sort_order = "trackid" - row = self.sql("""SELECT active FROM playlist_properties - WHERE propertyid=?""", genre.propertyid).fetchone() - self.assertEqual(row["active"], True) + row = self.sql("""SELECT active, loop, shuffle, + sort_order, current_trackid + FROM genres_view WHERE genreid=?""", + genre.genreid).fetchone() + self.assertTrue(row["active"]) + self.assertEqual(row["loop"], "Track") + self.assertTrue(row["shuffle"]) + self.assertEqual(row["sort_order"], "trackid") + self.assertIsNone(row["current_trackid"]) diff --git a/tests/db/test_tracks.py b/tests/db/test_tracks.py index af218c2..28b7215 100644 --- a/tests/db/test_tracks.py +++ b/tests/db/test_tracks.py @@ -70,6 +70,12 @@ class TestTrackObject(tests.util.TestCase): self.assertListEqual(self.track.get_artists(), [1, 2, 3]) self.table.get_artists.assert_called_with(self.track) + def test_get_genres(self): + """Test getting the Genre list for a Track.""" + self.table.get_genres = unittest.mock.Mock(return_value=[1, 2, 3]) + self.assertListEqual(self.track.get_genres(), [1, 2, 3]) + self.table.get_genres.assert_called_with(self.track) + def test_get_library(self): """Test getting the Library associated with a Track.""" self.assertEqual(self.track.get_library(), self.library) @@ -460,6 +466,18 @@ class TestTrackTable(tests.util.TestCase): self.assertListEqual(self.tracks.get_artists(track), [artist1, artist2]) + def test_get_genres(self): + """Test finding the genres for a track.""" + track = self.tracks.create(self.library, pathlib.Path("a/b/1.ogg"), + self.medium, self.year) + genre1 = self.sql.genres.create("Genre 1") + genre2 = self.sql.genres.create("Genre 2") + + genre1.add_track(track) + genre2.add_track(track) + self.assertListEqual(self.tracks.get_genres(track), + [genre1, genre2]) + def test_mark_path_active(self): """Test marking a path as active.""" self.tracks.create(self.library, pathlib.Path("/a/b/1.ogg"),