From 1730b7e92c43d7e8da609309efb7fdebcf083528 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Tue, 23 Aug 2022 16:11:09 -0400 Subject: [PATCH] db: Create a link between Artists and Albums I use a sql link table to accomplish this so a single album can be added to multiple album-artists. Additionally, I set up a view on Artists and Albums to make filtering easier without needing to use a complicated join every time. Additionally, I use the Playlist.add_children() function to set up a filter on the Album list model for each Artist's albums. Signed-off-by: Anna Schumaker --- emmental/db/albums.py | 19 ++++++ emmental/db/artists.py | 83 +++++++++++++++++++++++- emmental/db/emmental.sql | 24 +++++++ tests/db/test_albums.py | 29 +++++++++ tests/db/test_artists.py | 133 +++++++++++++++++++++++++++++++++++++-- 5 files changed, 282 insertions(+), 6 deletions(-) diff --git a/emmental/db/albums.py b/emmental/db/albums.py index bd19596..ed20f84 100644 --- a/emmental/db/albums.py +++ b/emmental/db/albums.py @@ -16,11 +16,21 @@ class Album(playlist.Playlist): mbid = GObject.Property(type=str) cover = GObject.Property(type=GObject.TYPE_PYOBJECT) + def get_artists(self) -> list[playlist.Playlist]: + """Get a list of artists for this album.""" + return self.table.get_artists(self) + @property def primary_key(self) -> int: """Get the Album primary key.""" return self.albumid + @GObject.Property(type=playlist.Playlist) + def parent(self) -> playlist.Playlist | None: + """Get the parent playlist of this Album.""" + artists = self.get_artists() + return artists[0] if len(artists) else None + class Table(playlist.Table): """Our Album Table.""" @@ -39,6 +49,8 @@ class Table(playlist.Table): def do_sql_delete(self, album: Album) -> sqlite3.Cursor: """Delete an album.""" + for artist in album.get_artists(): + artist.remove_album(album) return self.sql("DELETE FROM albums WHERE albumid=?", album.albumid) def do_sql_glob(self, glob: str) -> sqlite3.Cursor: @@ -80,3 +92,10 @@ class Table(playlist.Table): """Rename an album.""" return self.sql(f"UPDATE albums SET {column}=? WHERE albumid=?", newval, album.albumid) + + def get_artists(self, album: Album) -> list[playlist.Playlist]: + """Get the list of artists for this album.""" + rows = self.sql("""SELECT artistid FROM album_artist_link + WHERE albumid=?""", album.albumid).fetchall() + artists = [self.sql.artists.rows.get(row["artistid"]) for row in rows] + return list(filter(None, artists)) diff --git a/emmental/db/artists.py b/emmental/db/artists.py index 6ec3dcc..ce8eb87 100644 --- a/emmental/db/artists.py +++ b/emmental/db/artists.py @@ -2,8 +2,11 @@ """A custom Gio.ListModel for working with artists.""" import sqlite3 from gi.repository import GObject +from gi.repository import Gtk +from .albums import Album from .. import format from . import playlist +from . import table class Artist(playlist.Playlist): @@ -12,15 +15,73 @@ class Artist(playlist.Playlist): artistid = GObject.Property(type=int) mbid = GObject.Property(type=str) + def __init__(self, **kwargs): + """Initialize an Artist object.""" + super().__init__(**kwargs) + self.add_children(self.table.sql.albums, + table.Filter(self.table.get_albumids(self))) + + def add_album(self, album: Album) -> None: + """Add an Album to this Artist.""" + if self.table.add_album(self, album): + self.children.get_filter().add_row(album) + + def has_album(self, album: Album) -> bool: + """Check if the Artist has this Album.""" + return self.children.get_filter().match(album) + + def remove_album(self, album: Album) -> None: + """Remove an album from this Artist.""" + self.children.get_filter().remove_row(album) + self.table.remove_album(self, album) + @property def primary_key(self) -> int: """Get the Artist primary key.""" return self.artistid +class Filter(table.Filter): + """Custom filter to hide artists without albums.""" + + show_all = GObject.Property(type=bool, default=False) + + def __init__(self, show_all: bool = False): + """Initialize the Artist filter.""" + super().__init__(show_all=show_all) + self.connect("notify::show-all", self.__notify_show_all) + + def __notify_show_all(self, filter: table.Filter, param) -> None: + self.changed(Gtk.FilterChange.LESS_STRICT if self.show_all else + Gtk.FilterChange.MORE_STRICT) + + def do_get_strictness(self) -> Gtk.FilterMatch: + """Get the strictness of the filter.""" + res = super().do_get_strictness() + if not self.show_all and res == Gtk.FilterMatch.ALL: + return Gtk.FilterMatch.SOME + return res + + def do_match(self, artist: Artist) -> bool: + """Check if the artist matches the filter.""" + res = super().do_match(artist) + if not self.show_all and res: + return artist.children.get_filter().n_keys > 0 + return res + + class Table(playlist.Table): """Our Artist Table.""" + show_all = GObject.Property(type=bool, default=False) + + def __init__(self, sql: GObject.TYPE_PYOBJECT, + show_all: bool = False, **kwargs): + """Initialize an Artist model.""" + super().__init__(sql=sql, show_all=show_all, + filter=Filter(show_all=show_all), **kwargs) + self.bind_property("show-all", self.get_filter(), "show-all") + def do_construct(self, **kwargs) -> Artist: """Construct a new artist.""" return Artist(**kwargs) @@ -38,8 +99,9 @@ class Table(playlist.Table): def do_sql_glob(self, glob: str) -> sqlite3.Cursor: """Search for artists matching the search text.""" - return self.sql("""SELECT artistid FROM artists WHERE - CASEFOLD(name) GLOB ?""", glob) + return self.sql("""SELECT artistid FROM album_artist_view + WHERE CASEFOLD(artist) GLOB :glob + OR CASEFOLD(album) GLOB :glob""", glob=glob) def do_sql_insert(self, name: str, mbid: str = "") -> sqlite3.Cursor | None: @@ -65,3 +127,20 @@ class Table(playlist.Table): """Update an artist.""" return self.sql(f"UPDATE artists SET {column}=? WHERE artistid=?", newval, artist.artistid) + + def add_album(self, artist: Artist, album: Album) -> bool: + """Add an album to this artist.""" + return self.sql("INSERT INTO album_artist_link VALUES (?, ?)", + artist.artistid, album.albumid) is not None + + def get_albumids(self, artist: Artist) -> set[int]: + """Get an Artist's associated albumids from the database.""" + cur = self.sql("""SELECT albumid FROM album_artist_link + WHERE artistid=?""", artist.artistid) + return {row["albumid"] for row in cur.fetchall()} + + def remove_album(self, artist: Artist, album: Album) -> bool: + """Remove an album from this artist.""" + return self.sql("""DELETE FROM album_artist_link + WHERE artistid=? AND albumid=?""", + artist.artistid, album.albumid).rowcount == 1 diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index 39353f7..cb721b7 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -146,6 +146,30 @@ CREATE TRIGGER albums_delete_trigger AFTER DELETE ON albums END; +/******************************************* + * * + * Artist <--> Album Linking * + * * + *******************************************/ + +CREATE TABLE album_artist_link ( + artistid INTEGER NOT NULL REFERENCES artists (artistid) + ON DELETE CASCADE + ON UPDATE CASCADE, + albumid INTEGER NOT NULL REFERENCES albums (albumid) + ON DELETE CASCADE + ON UPDATE CASCADE, + UNIQUE (artistid, albumid) +); + +CREATE VIEW album_artist_view AS + SELECT artistid, artists.name as artist, + albumid, COALESCE(albums.name, "") as album + FROM artists + LEFT JOIN album_artist_link USING (artistid) + LEFT JOIN albums USING (albumid); + + /****************************************** * * * Create Default Playlists * diff --git a/tests/db/test_albums.py b/tests/db/test_albums.py index 524a82d..bbb541b 100644 --- a/tests/db/test_albums.py +++ b/tests/db/test_albums.py @@ -1,6 +1,7 @@ # Copyright 2022 (c) Anna Schumaker. """Tests our album Gio.ListModel.""" import pathlib +import unittest.mock import emmental.db import tests.util @@ -39,6 +40,14 @@ class TestAlbumObject(tests.util.TestCase): self.assertEqual(album2.mbid, "ab-cd-ef") self.assertEqual(album2.cover, cover) + def test_get_artists(self): + """Test getting the list of artists for this album.""" + with unittest.mock.patch.object(self.table, "get_artists", + return_value=[1, 2, 3]) as mock: + self.assertListEqual(self.album.get_artists(), [1, 2, 3]) + mock.assert_called_with(self.album) + self.assertEqual(self.album.parent, 1) + class TestAlbumTable(tests.util.TestCase): """Tests our album table.""" @@ -106,9 +115,13 @@ class TestAlbumTable(tests.util.TestCase): def test_delete(self): """Test deleting an album playlist.""" + artist = self.sql.artists.create("Test Artist") album = self.table.create("Test Album", "Album Artist", "2023-03") + artist.add_album(album) + self.assertTrue(album.delete()) self.assertIsNone(self.table.index(album)) + self.assertFalse(artist.has_album(album)) cur = self.sql("SELECT COUNT(name) FROM albums") self.assertEqual(cur.fetchone()["COUNT(name)"], 0) @@ -119,6 +132,9 @@ class TestAlbumTable(tests.util.TestCase): WHERE propertyid=?""", album.propertyid).fetchone() self.assertEqual(row["COUNT(*)"], 0) + cur = self.sql("SELECT COUNT(artistid) FROM album_artist_link") + self.assertEqual(cur.fetchone()["COUNT(artistid)"], 0) + self.assertFalse(album.delete()) def test_filter(self): @@ -204,3 +220,16 @@ class TestAlbumTable(tests.util.TestCase): row = self.sql("SELECT cover FROM albums WHERE albumid=?", album.albumid).fetchone() self.assertIsNone(row["cover"], tests.util.COVER_JPG) + + def test_get_artists(self): + """Test getting the list of artists an album is attached to.""" + artist1 = self.sql.artists.create("Artist 1") + artist2 = self.sql.artists.create("Artist 2") + album = self.table.create("Test Album", "Album Artist", "2023-03") + + artist1.add_album(album) + artist2.add_album(album) + self.assertListEqual(self.table.get_artists(album), [artist1, artist2]) + + del self.sql.artists.rows[artist1.artistid] + self.assertListEqual(self.table.get_artists(album), [artist2]) diff --git a/tests/db/test_artists.py b/tests/db/test_artists.py index d1f2111..3ccd989 100644 --- a/tests/db/test_artists.py +++ b/tests/db/test_artists.py @@ -1,7 +1,9 @@ # Copyright 2022 (c) Anna Schumaker. """Tests our artist Gio.ListModel.""" +import unittest.mock import emmental.db import tests.util +from gi.repository import Gtk class TestArtistObject(tests.util.TestCase): @@ -18,6 +20,7 @@ class TestArtistObject(tests.util.TestCase): def test_init(self): """Test that the Artist is set up properly.""" self.assertIsInstance(self.artist, emmental.db.playlist.Playlist) + self.assertSetEqual(self.artist.children.get_filter().keys, set()) self.assertEqual(self.artist.table, self.table) self.assertEqual(self.artist.propertyid, 456) self.assertEqual(self.artist.artistid, 123) @@ -25,6 +28,65 @@ class TestArtistObject(tests.util.TestCase): self.assertEqual(self.artist.mbid, "") self.assertIsNone(self.artist.parent) + def test_add_remove_album(self): + """Test that the Album Artist filter works as expected.""" + album = self.sql.albums.create("Test Album", "Album Artist", "2023-03") + + with unittest.mock.patch.object(self.table, "add_album", + return_value=True) as mock_add: + self.artist.add_album(album) + + mock_add.assert_called_with(self.artist, album) + self.assertSetEqual(self.artist.children.get_filter().keys, + {album.albumid}) + self.assertTrue(self.artist.has_album(album)) + + with unittest.mock.patch.object(self.table, "remove_album", + return_value=True) as mock_remove: + self.artist.remove_album(album) + + mock_remove.assert_called_with(self.artist, album) + self.assertSetEqual(self.artist.children.get_filter().keys, set()) + self.assertFalse(self.artist.has_album(album)) + + def test_children(self): + """Test that Albums have been added as Artist playlist children.""" + self.assertIsInstance(self.artist.children, Gtk.FilterListModel) + self.assertIsInstance(self.artist.children.get_filter(), + emmental.db.table.Filter) + self.assertEqual(self.artist.children.get_model(), self.sql.albums) + + +class TestFilter(tests.util.TestCase): + """Test the artist filter.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.artists = self.sql.artists + self.filter = emmental.db.artists.Filter() + + def test_init(self): + """Test that the filter is initialized properly.""" + self.assertIsInstance(self.filter, emmental.db.table.Filter) + self.assertFalse(self.filter.show_all) + + filter2 = emmental.db.artists.Filter(show_all=True) + self.assertTrue(filter2.show_all) + + def test_strictness(self): + """Test checking strictness.""" + self.assertEqual(self.filter.get_strictness(), Gtk.FilterMatch.SOME) + self.filter.show_all = True + self.assertEqual(self.filter.get_strictness(), Gtk.FilterMatch.ALL) + + def test_match(self): + """Test matching an artist.""" + artist = self.artists.create("Test Artist") + self.assertFalse(self.filter.match(artist)) + self.filter.show_all = True + self.assertTrue(self.filter.match(artist)) + class TestArtistTable(tests.util.TestCase): """Tests our artist table.""" @@ -33,10 +95,14 @@ class TestArtistTable(tests.util.TestCase): """Set up common variables.""" tests.util.TestCase.setUp(self) self.table = self.sql.artists + self.album = self.sql.albums.create("Test Album", "Album Artist", + "2023-03") def test_init(self): """Test that the artist model is configured correctly.""" self.assertIsInstance(self.table, emmental.db.playlist.Table) + self.assertIsInstance(self.table.get_filter(), + emmental.db.artists.Filter) self.assertEqual(len(self.table), 0) def test_construct(self): @@ -53,6 +119,8 @@ class TestArtistTable(tests.util.TestCase): def test_create(self): """Test creating an artist playlist.""" + self.table.show_all = True + artist1 = self.table.create("Test Artist") self.assertIsInstance(artist1, emmental.db.artists.Artist) self.assertEqual(artist1.name, "Test Artist") @@ -79,6 +147,8 @@ class TestArtistTable(tests.util.TestCase): def test_delete(self): """Test deleting an artist playlist.""" artist = self.table.create("Test Artist") + artist.add_album(self.album) + self.assertTrue(artist.delete()) self.assertIsNone(self.table.index(artist)) @@ -91,18 +161,29 @@ class TestArtistTable(tests.util.TestCase): WHERE propertyid=?""", artist.propertyid).fetchone() self.assertEqual(row["COUNT(*)"], 0) + cur = self.sql("SELECT COUNT(albumid) FROM album_artist_link") + self.assertEqual(cur.fetchone()["COUNT(albumid)"], 0) + self.assertFalse(artist.delete()) def test_filter(self): """Test filtering an artist playlist.""" - self.table.create("Artist 1") - self.table.create("Artist 2") + artist1 = self.table.create("Artist 1") + artist2 = self.table.create("Artist 2") + + artist1.add_album(self.sql.albums.create("Album 1", "Artist 1", "1")) + artist1.add_album(self.sql.albums.create("Album 2", "Artist 1", "2")) + artist2.add_album(self.sql.albums.create("Album 3", "Artist 2", "3")) + artist2.add_album(self.sql.albums.create("Album 4", "Artist 2", "4")) self.table.filter("*1", now=True) self.assertSetEqual(self.table.get_filter().keys, {1}) self.table.filter("artist*", now=True) self.assertSetEqual(self.table.get_filter().keys, {1, 2}) + self.table.filter("*4", now=True) + self.assertSetEqual(self.table.get_filter().keys, {2}) + def test_get_sort_key(self): """Test the get_sort_key() function.""" artist1 = self.table.create("Artist 1") @@ -115,10 +196,11 @@ class TestArtistTable(tests.util.TestCase): def test_load(self): """Test loading the artist table.""" - self.table.create("Artist 1") + artist = self.table.create("Artist 1") self.table.create("Artist 2", mbid="ab-cd-ef") + artist.add_album(self.album) - artists2 = emmental.db.artists.Table(self.sql) + artists2 = emmental.db.artists.Table(self.sql, show_all=True) self.assertEqual(len(artists2), 0) artists2.load(now=True) @@ -126,9 +208,13 @@ class TestArtistTable(tests.util.TestCase): self.assertEqual(artists2.get_item(0).name, "Artist 1") self.assertEqual(artists2.get_item(0).mbid, "") + self.assertSetEqual(artists2.get_item(0).children.get_filter().keys, + {1}) self.assertEqual(artists2.get_item(1).name, "Artist 2") self.assertEqual(artists2.get_item(1).mbid, "ab-cd-ef") + self.assertSetEqual(artists2.get_item(1).children.get_filter().keys, + set()) def test_lookup(self): """Test looking up artist playlists.""" @@ -155,3 +241,42 @@ class TestArtistTable(tests.util.TestCase): row = self.sql("""SELECT active FROM playlist_properties WHERE propertyid=?""", artist.propertyid).fetchone() self.assertTrue(row["active"]) + + def test_add_remove_album(self): + """Test adding an album to an artist.""" + artist = self.table.create("Test Artist") + artist.add_album(self.album) + self.assertTrue(artist.has_album(self.album)) + + row = self.sql("""SELECT albumid FROM album_artist_link + WHERE artistid=?""", artist.artistid).fetchone() + self.assertEqual(row["albumid"], self.album.albumid) + + artist.remove_album(self.album) + self.assertFalse(artist.has_album(self.album)) + + cur = self.sql("""SELECT albumid FROM album_artist_link + WHERE artistid=?""", artist.artistid) + self.assertIsNone(cur.fetchone()) + + artist.remove_album(self.album) + + def test_get_albumids(self): + """Test getting an artist's associated albumids from the database.""" + artist = self.table.create("Artist") + artist.add_album(self.album) + artist.add_album(self.sql.albums.create("Album 1", "Artist", "1")) + artist.add_album(self.sql.albums.create("Album 2", "Artist", "2")) + self.assertSetEqual(self.table.get_albumids(artist), {1, 2, 3}) + + def test_show_all(self): + """Test the show-all property.""" + self.assertFalse(self.table.show_all) + self.table.show_all = True + self.assertTrue(self.table.get_filter().show_all) + self.table.show_all = False + self.assertFalse(self.table.get_filter().show_all) + + table2 = emmental.db.artists.Table(self.sql, show_all=True) + self.assertTrue(table2.show_all) + self.assertTrue(table2.get_filter().show_all)