db: Have Album playlists use the new child_set

I implement add_medium(), remove_medium(), and has_medium() functions
and make sure we load the set of mediumids during startup. Additionally,
I have Mediums add and remove themselves from Albums as they are created
and deleted.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-06-22 20:09:46 -04:00
parent 4f15bde850
commit 3cddde0986
4 changed files with 60 additions and 9 deletions

View File

@ -23,10 +23,15 @@ class Album(playlist.Playlist):
"""Initialize an Album object."""
super().__init__(**kwargs)
self.add_children(self.table.sql.media,
Gtk.CustomFilter.new(self.__match_medium))
Gtk.CustomFilter.new(self.__match_medium),
self.table.get_mediumids(self))
def __match_medium(self, medium: Medium) -> bool:
return medium.albumid == self.albumid and len(medium.name) > 0
return self.has_medium(medium) and len(medium.name) > 0
def add_medium(self, medium: Medium) -> None:
"""Add a Medium to this Album."""
self.add_child(medium)
def get_artists(self) -> list[playlist.Playlist]:
"""Get a list of artists for this album."""
@ -36,6 +41,14 @@ class Album(playlist.Playlist):
"""Get a list of media for this album."""
return self.table.get_media(self)
def has_medium(self, medium: Medium) -> bool:
"""Check if a Medium is from this Album."""
return self.has_child(medium)
def remove_medium(self, medium: Medium) -> None:
"""Remove a Medium from this Album."""
return self.remove_child(medium)
@property
def primary_key(self) -> int:
"""Get the Album primary key."""
@ -139,6 +152,11 @@ class Table(playlist.Table):
def get_media(self, album: Album) -> list[Medium]:
"""Get the list of media for this album."""
return [self.sql.media.rows.get(id)
for id in self.get_mediumids(album)]
def get_mediumids(self, album: Album) -> set[int]:
"""Get the set of mediumids for this album."""
rows = self.sql("SELECT mediumid FROM media WHERE albumid=?",
album.albumid)
return [self.sql.media.rows.get(row["mediumid"]) for row in rows]
return {row["mediumid"] for row in rows.fetchall()}

View File

@ -77,6 +77,7 @@ class Table(playlist.Table):
def do_sql_delete(self, medium: Medium) -> sqlite3.Cursor:
"""Delete a medium."""
medium.get_album().remove_medium(medium)
return self.sql("DELETE FROM media WHERE mediumid=?",
medium.mediumid)
@ -116,6 +117,13 @@ class Table(playlist.Table):
return self.sql(f"UPDATE media SET {column}=? WHERE mediumid=?",
newval, medium.mediumid)
def create(self, album: playlist.Playlist,
*args, **kwargs) -> Medium | None:
"""Create a new Medium playlist."""
if (medium := super().create(album, *args, **kwargs)) is not None:
album.add_medium(medium)
return medium
def rename(self, medium: Medium, new_name: str) -> bool:
"""Rename a medium."""
if (new_name := new_name.strip()) != medium.name:

View File

@ -41,6 +41,22 @@ class TestAlbumObject(tests.util.TestCase):
self.assertEqual(album2.mbid, "ab-cd-ef")
self.assertEqual(album2.cover, cover)
def test_add_remove_medium(self):
"""Test adding and removing a medium from the Album."""
album = self.table.create("Test Album", "Album Artist", "2023-03")
medium = self.sql.media.create(album, "Test Medium", number=1)
self.assertFalse(medium in self.album.child_set)
self.assertFalse(self.album.has_medium(medium))
self.album.add_medium(medium)
self.assertTrue(medium in self.album.child_set)
self.assertTrue(self.album.has_medium(medium))
self.album.remove_medium(medium)
self.assertFalse(medium in self.album.child_set)
self.assertFalse(self.album.has_medium(medium))
def test_get_artists(self):
"""Test getting the list of artists for this album."""
with unittest.mock.patch.object(self.table, "get_artists",
@ -65,13 +81,14 @@ class TestAlbumObject(tests.util.TestCase):
album = self.table.create("Test Album", "Album Artist", "2023-03")
medium = self.sql.media.create(album, "Test Medium", number=1)
self.assertTrue(album.children.get_filter().match(medium))
medium.albumid = album.albumid + 1
self.assertFalse(album.children.get_filter().match(medium))
self.assertFalse(self.album.children.get_filter().match(medium))
medium = self.sql.media.create(album, "", number=2)
self.assertFalse(album.children.get_filter().match(medium))
self.album.add_medium(medium)
self.assertTrue(self.album.children.get_filter().match(medium))
medium.name = ""
self.assertFalse(self.album.children.get_filter().match(medium))
class TestAlbumTable(tests.util.TestCase):
@ -228,9 +245,11 @@ class TestAlbumTable(tests.util.TestCase):
def test_load(self):
"""Test loading the album table."""
self.table.create("Album 1", "Album Artist", "2023-03")
album = self.table.create("Album 1", "Album Artist", "2023-03")
self.table.create("Album 2", "Album Artist", "2023-03",
mbid="ab-cd-ef", cover=tests.util.COVER_JPG)
medium = self.sql.media.create(album, "Test Medium", number=1)
album.add_medium(medium)
albums2 = emmental.db.albums.Table(self.sql)
self.assertEqual(len(albums2), 0)
@ -243,6 +262,7 @@ class TestAlbumTable(tests.util.TestCase):
self.assertEqual(albums2.get_item(0).release, "2023-03")
self.assertEqual(albums2.get_item(0).mbid, "")
self.assertIsNone(albums2.get_item(0).cover)
self.assertSetEqual(albums2.get_item(0).child_set.keyset.keys, {1})
self.assertEqual(albums2.get_item(1).name, "Album 2")
self.assertEqual(albums2.get_item(1).artist, "Album Artist")
@ -250,6 +270,7 @@ class TestAlbumTable(tests.util.TestCase):
self.assertEqual(albums2.get_item(1).mbid, "ab-cd-ef")
self.assertEqual(albums2.get_item(1).cover,
tests.util.COVER_JPG)
self.assertSetEqual(albums2.get_item(1).child_set.keyset.keys, set())
def test_lookup(self):
"""Test looking up album playlists."""
@ -320,4 +341,6 @@ class TestAlbumTable(tests.util.TestCase):
medium1 = self.sql.media.create(album, "", number=1)
medium2 = self.sql.media.create(album, "", number=2)
self.assertSetEqual(self.table.get_mediumids(album),
{medium1.mediumid, medium2.mediumid})
self.assertListEqual(self.table.get_media(album), [medium1, medium2])

View File

@ -132,6 +132,7 @@ class TestMediumsTable(tests.util.TestCase):
self.assertEqual(medium1.number, 1)
self.assertEqual(medium1.type, "")
self.assertEqual(medium1.sort_order, "mediumno, number")
self.assertTrue(self.album.has_medium(medium1))
cur = self.sql("SELECT COUNT(name) FROM media")
self.assertEqual(cur.fetchone()["COUNT(name)"], 1)
@ -156,6 +157,7 @@ class TestMediumsTable(tests.util.TestCase):
medium = self.table.create(self.album, "Medium 1", number=1)
self.assertTrue(medium.delete())
self.assertIsNone(self.table.index(medium))
self.assertFalse(self.album.has_medium(medium))
cur = self.sql("SELECT COUNT(name) FROM media")
self.assertEqual(cur.fetchone()["COUNT(name)"], 0)