diff --git a/emmental/db/albums.py b/emmental/db/albums.py index 2f1da70..ca41098 100644 --- a/emmental/db/albums.py +++ b/emmental/db/albums.py @@ -23,10 +23,15 @@ class Album(playlist.Playlist): """Initialize an Album object.""" super().__init__(**kwargs) self.add_children(self.table.sql.media, - Gtk.CustomFilter.new(self.__match_medium)) + Gtk.CustomFilter.new(self.__match_medium), + self.table.get_mediumids(self)) def __match_medium(self, medium: Medium) -> bool: - return medium.albumid == self.albumid and len(medium.name) > 0 + return self.has_medium(medium) and len(medium.name) > 0 + + def add_medium(self, medium: Medium) -> None: + """Add a Medium to this Album.""" + self.add_child(medium) def get_artists(self) -> list[playlist.Playlist]: """Get a list of artists for this album.""" @@ -36,6 +41,14 @@ class Album(playlist.Playlist): """Get a list of media for this album.""" return self.table.get_media(self) + def has_medium(self, medium: Medium) -> bool: + """Check if a Medium is from this Album.""" + return self.has_child(medium) + + def remove_medium(self, medium: Medium) -> None: + """Remove a Medium from this Album.""" + return self.remove_child(medium) + @property def primary_key(self) -> int: """Get the Album primary key.""" @@ -139,6 +152,11 @@ class Table(playlist.Table): def get_media(self, album: Album) -> list[Medium]: """Get the list of media for this album.""" + return [self.sql.media.rows.get(id) + for id in self.get_mediumids(album)] + + def get_mediumids(self, album: Album) -> set[int]: + """Get the set of mediumids for this album.""" rows = self.sql("SELECT mediumid FROM media WHERE albumid=?", album.albumid) - return [self.sql.media.rows.get(row["mediumid"]) for row in rows] + return {row["mediumid"] for row in rows.fetchall()} diff --git a/emmental/db/media.py b/emmental/db/media.py index ed50720..e8af57d 100644 --- a/emmental/db/media.py +++ b/emmental/db/media.py @@ -77,6 +77,7 @@ class Table(playlist.Table): def do_sql_delete(self, medium: Medium) -> sqlite3.Cursor: """Delete a medium.""" + medium.get_album().remove_medium(medium) return self.sql("DELETE FROM media WHERE mediumid=?", medium.mediumid) @@ -116,6 +117,13 @@ class Table(playlist.Table): return self.sql(f"UPDATE media SET {column}=? WHERE mediumid=?", newval, medium.mediumid) + def create(self, album: playlist.Playlist, + *args, **kwargs) -> Medium | None: + """Create a new Medium playlist.""" + if (medium := super().create(album, *args, **kwargs)) is not None: + album.add_medium(medium) + return medium + def rename(self, medium: Medium, new_name: str) -> bool: """Rename a medium.""" if (new_name := new_name.strip()) != medium.name: diff --git a/tests/db/test_albums.py b/tests/db/test_albums.py index d89388e..fbdd970 100644 --- a/tests/db/test_albums.py +++ b/tests/db/test_albums.py @@ -41,6 +41,22 @@ class TestAlbumObject(tests.util.TestCase): self.assertEqual(album2.mbid, "ab-cd-ef") self.assertEqual(album2.cover, cover) + def test_add_remove_medium(self): + """Test adding and removing a medium from the Album.""" + album = self.table.create("Test Album", "Album Artist", "2023-03") + medium = self.sql.media.create(album, "Test Medium", number=1) + + self.assertFalse(medium in self.album.child_set) + self.assertFalse(self.album.has_medium(medium)) + + self.album.add_medium(medium) + self.assertTrue(medium in self.album.child_set) + self.assertTrue(self.album.has_medium(medium)) + + self.album.remove_medium(medium) + self.assertFalse(medium in self.album.child_set) + self.assertFalse(self.album.has_medium(medium)) + def test_get_artists(self): """Test getting the list of artists for this album.""" with unittest.mock.patch.object(self.table, "get_artists", @@ -65,13 +81,14 @@ class TestAlbumObject(tests.util.TestCase): album = self.table.create("Test Album", "Album Artist", "2023-03") medium = self.sql.media.create(album, "Test Medium", number=1) - self.assertTrue(album.children.get_filter().match(medium)) - medium.albumid = album.albumid + 1 - self.assertFalse(album.children.get_filter().match(medium)) + self.assertFalse(self.album.children.get_filter().match(medium)) - medium = self.sql.media.create(album, "", number=2) - self.assertFalse(album.children.get_filter().match(medium)) + self.album.add_medium(medium) + self.assertTrue(self.album.children.get_filter().match(medium)) + + medium.name = "" + self.assertFalse(self.album.children.get_filter().match(medium)) class TestAlbumTable(tests.util.TestCase): @@ -228,9 +245,11 @@ class TestAlbumTable(tests.util.TestCase): def test_load(self): """Test loading the album table.""" - self.table.create("Album 1", "Album Artist", "2023-03") + album = self.table.create("Album 1", "Album Artist", "2023-03") self.table.create("Album 2", "Album Artist", "2023-03", mbid="ab-cd-ef", cover=tests.util.COVER_JPG) + medium = self.sql.media.create(album, "Test Medium", number=1) + album.add_medium(medium) albums2 = emmental.db.albums.Table(self.sql) self.assertEqual(len(albums2), 0) @@ -243,6 +262,7 @@ class TestAlbumTable(tests.util.TestCase): self.assertEqual(albums2.get_item(0).release, "2023-03") self.assertEqual(albums2.get_item(0).mbid, "") self.assertIsNone(albums2.get_item(0).cover) + self.assertSetEqual(albums2.get_item(0).child_set.keyset.keys, {1}) self.assertEqual(albums2.get_item(1).name, "Album 2") self.assertEqual(albums2.get_item(1).artist, "Album Artist") @@ -250,6 +270,7 @@ class TestAlbumTable(tests.util.TestCase): self.assertEqual(albums2.get_item(1).mbid, "ab-cd-ef") self.assertEqual(albums2.get_item(1).cover, tests.util.COVER_JPG) + self.assertSetEqual(albums2.get_item(1).child_set.keyset.keys, set()) def test_lookup(self): """Test looking up album playlists.""" @@ -320,4 +341,6 @@ class TestAlbumTable(tests.util.TestCase): medium1 = self.sql.media.create(album, "", number=1) medium2 = self.sql.media.create(album, "", number=2) + self.assertSetEqual(self.table.get_mediumids(album), + {medium1.mediumid, medium2.mediumid}) self.assertListEqual(self.table.get_media(album), [medium1, medium2]) diff --git a/tests/db/test_media.py b/tests/db/test_media.py index a25d180..c8f103e 100644 --- a/tests/db/test_media.py +++ b/tests/db/test_media.py @@ -132,6 +132,7 @@ class TestMediumsTable(tests.util.TestCase): self.assertEqual(medium1.number, 1) self.assertEqual(medium1.type, "") self.assertEqual(medium1.sort_order, "mediumno, number") + self.assertTrue(self.album.has_medium(medium1)) cur = self.sql("SELECT COUNT(name) FROM media") self.assertEqual(cur.fetchone()["COUNT(name)"], 1) @@ -156,6 +157,7 @@ class TestMediumsTable(tests.util.TestCase): medium = self.table.create(self.album, "Medium 1", number=1) self.assertTrue(medium.delete()) self.assertIsNone(self.table.index(medium)) + self.assertFalse(self.album.has_medium(medium)) cur = self.sql("SELECT COUNT(name) FROM media") self.assertEqual(cur.fetchone()["COUNT(name)"], 0)