diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index caef014..ea117d2 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -7,6 +7,7 @@ from . import albums from . import artists from . import connection from . import playlist +from . import media from . import playlists from . import settings from . import table @@ -33,6 +34,7 @@ class Connection(connection.Connection): self.playlists = playlists.Table(self) self.artists = artists.Table(self) self.albums = albums.Table(self, queue=self.artists.queue) + self.media = media.Table(self, queue=self.artists.queue) def close(self) -> None: """Close the database connection.""" @@ -55,7 +57,7 @@ class Connection(connection.Connection): def playlist_tables(self) -> Generator[playlist.Table, None, None]: """Iterate over each playlist table.""" - for tbl in [self.playlists, self.artists, self.albums]: + for tbl in [self.playlists, self.artists, self.albums, self.media]: yield tbl def set_active_playlist(self, plist: playlist.Playlist) -> None: diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index cb721b7..aeb57bd 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -146,6 +146,44 @@ CREATE TRIGGER albums_delete_trigger AFTER DELETE ON albums END; +/************************* + * * + * Mediums * + * * + *************************/ + +CREATE TABLE media ( + mediumid INTEGER PRIMARY KEY, + propertyid INTEGER REFERENCES playlist_properties (propertyid) + ON DELETE CASCADE + ON UPDATE CASCADE, + albumid INTEGER NOT NULL REFERENCES albums (albumid) + ON DELETE CASCADE + ON UPDATE CASCADE, + number INTEGER NOT NULL, + name TEXT NOT NULL DEFAULT "" COLLATE NOCASE, + type TEXT NOT NULL DEFAULT "" COLLATE NOCASE, + UNIQUE (albumid, number, type) +); + +CREATE VIEW media_view AS + SELECT mediumid, propertyid, albumid, number, name, type, active + FROM media + JOIN playlist_properties USING (propertyid); + +CREATE TRIGGER media_insert_trigger AFTER INSERT ON media + BEGIN + INSERT INTO playlist_properties (active) VALUES (False); + UPDATE media SET propertyid = last_insert_rowid() + WHERE mediumid = NEW.mediumid; + END; + +CREATE TRIGGER media_delete_trigger AFTER DELETE ON media + BEGIN + DELETE FROM playlist_properties WHERE propertyid = OLD.propertyid; + END; + + /******************************************* * * * Artist <--> Album Linking * diff --git a/emmental/db/media.py b/emmental/db/media.py new file mode 100644 index 0000000..341f23e --- /dev/null +++ b/emmental/db/media.py @@ -0,0 +1,92 @@ +# Copyright 2022 (c) Anna Schumaker. +"""A custom Gio.ListModel for managing individual media in an album.""" +import sqlite3 +from gi.repository import GObject +from .. import format +from . import playlist + + +class Medium(playlist.Playlist): + """Our custom Medium object representing a single disc in an album.""" + + mediumid = GObject.Property(type=int) + albumid = GObject.Property(type=int) + number = GObject.Property(type=int, default=1) + type = GObject.Property(type=str) + + def get_album(self) -> playlist.Playlist: + """Get this Medium's Album.""" + return self.table.sql.albums.rows.get(self.albumid) + + def rename(self, new_name: str) -> bool: + """Rename this medium.""" + return self.table.rename(self, new_name) + + @property + def primary_key(self) -> int: + """Get this Medium's primary key.""" + return self.mediumid + + @GObject.Property(type=playlist.Playlist) + def parent(self) -> playlist.Playlist | None: + """Get this Medium's parent playlist.""" + return self.get_album() + + +class Table(playlist.Table): + """Our Media Table.""" + + def do_construct(self, **kwargs) -> Medium: + """Construct a new medium.""" + return Medium(**kwargs) + + def do_get_sort_key(self, medium: Medium) -> tuple[int, int, tuple, str]: + """Get the sort key for a medium.""" + return (medium.albumid, medium.number, + format.sort_key(medium.name), medium.type) + + def do_sql_delete(self, medium: Medium) -> sqlite3.Cursor: + """Delete a medium.""" + return self.sql("DELETE FROM media WHERE mediumid=?", + medium.mediumid) + + def do_sql_glob(self, glob: str) -> sqlite3.Cursor: + """Search for media names matching the search text.""" + return self.sql("""SELECT mediumid FROM media + WHERE CASEFOLD(name) GLOB ?""", glob) + + def do_sql_insert(self, album: playlist.Playlist, name: str, + *, number: int, type: str = "") -> sqlite3.Cursor | None: + """Create a new medium.""" + if cur := self.sql("""INSERT INTO media (albumid, number, name, type) + VALUES (?, ?, ?, ?)""", + album.albumid, number, name, type): + return self.sql("SELECT * FROM media_view WHERE mediumid=?", + cur.lastrowid) + + def do_sql_select_all(self) -> sqlite3.Cursor: + """Load media from the database.""" + return self.sql("SELECT * FROM media_view") + + def do_sql_select_one(self, album: playlist.Playlist, + *, number: int, type: str = "") -> sqlite3.Cursor: + """Look up a medium by album, number, and type.""" + return self.sql("""SELECT mediumid FROM media + WHERE albumid=? AND number=? AND type=?""", + album.albumid, number, type) + + def do_sql_update(self, medium: Medium, + column: str, newval) -> sqlite3.Cursor: + """Update a medium.""" + return self.sql(f"UPDATE media SET {column}=? WHERE mediumid=?", + newval, medium.mediumid) + + def rename(self, medium: Medium, new_name: str) -> bool: + """Rename a medium.""" + if (new_name := new_name.strip()) != medium.name: + if self.update(medium, "name", new_name): + self.store.remove(medium) + medium.name = new_name + self.store.append(medium) + return True + return False diff --git a/tests/db/test_db.py b/tests/db/test_db.py index ad26f33..50844a4 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -43,12 +43,14 @@ class TestConnection(tests.util.TestCase): self.assertIsInstance(self.sql.playlists, emmental.db.playlists.Table) self.assertIsInstance(self.sql.artists, emmental.db.artists.Table) self.assertIsInstance(self.sql.albums, emmental.db.albums.Table) + self.assertIsInstance(self.sql.media, emmental.db.media.Table) self.assertEqual(self.sql.albums.queue, self.sql.artists.queue) + self.assertEqual(self.sql.media.queue, self.sql.artists.queue) self.assertListEqual([tbl for tbl in self.sql.playlist_tables()], [self.sql.playlists, self.sql.artists, - self.sql.albums]) + self.sql.albums, self.sql.media]) def test_load(self): """Check that calling load() loads the tables.""" diff --git a/tests/db/test_media.py b/tests/db/test_media.py new file mode 100644 index 0000000..8d85483 --- /dev/null +++ b/tests/db/test_media.py @@ -0,0 +1,197 @@ +# Copyright 2022 (c) Anna Schumaker. +"""Tests our medium Gio.ListModel.""" +import unittest.mock +import emmental.db +import tests.util + + +class TestMediumObject(tests.util.TestCase): + """Tests our medium object.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.table = self.sql.media + self.medium = emmental.db.media.Medium(table=self.table, name="", + mediumid=123, propertyid=456) + + def test_init(self): + """Test that the Media is set up properly.""" + self.assertIsInstance(self.medium, emmental.db.playlist.Playlist) + self.assertEqual(self.medium.table, self.table) + self.assertEqual(self.medium.propertyid, 456) + self.assertEqual(self.medium.mediumid, 123) + self.assertEqual(self.medium.primary_key, 123) + self.assertEqual(self.medium.albumid, 0) + self.assertEqual(self.medium.number, 1) + self.assertEqual(self.medium.name, "") + self.assertEqual(self.medium.type, "") + + def test_get_album(self): + """Test getting this Medium's Album.""" + self.assertIsNone(self.medium.get_album()) + self.assertIsNone(self.medium.parent) + + album = self.sql.albums.create("Test Album", "Album Artist", "2023") + self.medium.albumid = album.albumid + self.assertEqual(self.medium.get_album(), album) + self.assertEqual(self.medium.parent, album) + + def test_rename(self): + """Test the rename() function.""" + with unittest.mock.patch.object(self.table, "rename", + return_value=True) as mock_rename: + self.assertTrue(self.medium.rename("New Name")) + mock_rename.assert_called_with(self.medium, "New Name") + + +class TestMediumsTable(tests.util.TestCase): + """Tests our mediums table.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.album = self.sql.albums.create("Test Album", "Test Artist", "123") + self.table = self.sql.media + + def test_init(self): + """Test that the medium model is configured corretly.""" + self.assertIsInstance(self.table, emmental.db.playlist.Table) + self.assertEqual(len(self.table), 0) + + def test_construct(self): + """Test constructing a medium playlist.""" + medium = self.table.construct(mediumid=1, propertyid=1, + albumid=self.album.albumid, + name="Medium 2", number=2, type="CD") + self.assertIsInstance(medium, emmental.db.media.Medium) + self.assertEqual(medium.table, self.table) + self.assertEqual(medium.propertyid, 1) + self.assertEqual(medium.mediumid, 1) + self.assertEqual(medium.albumid, self.album.albumid) + self.assertEqual(medium.name, "Medium 2") + self.assertEqual(medium.number, 2) + self.assertEqual(medium.type, "CD") + + def test_create(self): + """Test creating a medium playlist.""" + medium1 = self.table.create(self.album, "", number=1) + self.assertIsInstance(medium1, emmental.db.media.Medium) + self.assertEqual(medium1.albumid, self.album.albumid) + self.assertEqual(medium1.name, "") + self.assertEqual(medium1.number, 1) + self.assertEqual(medium1.type, "") + + cur = self.sql("SELECT COUNT(name) FROM media") + self.assertEqual(cur.fetchone()["COUNT(name)"], 1) + + row = self.sql("""SELECT COUNT(*) FROM playlist_properties + WHERE propertyid=?""", medium1.propertyid).fetchone() + self.assertEqual(row["COUNT(*)"], 1) + + medium2 = self.table.create(self.album, "Test Medium", + number=2, type="CD") + self.assertEqual(medium2.name, "Test Medium") + self.assertEqual(medium2.number, 2) + self.assertEqual(medium2.type, "CD") + + cur = self.sql("SELECT COUNT(name) FROM media") + self.assertEqual(cur.fetchone()["COUNT(name)"], 2) + + self.assertIsNone(self.table.create(self.album, "", number=1)) + + def test_delete(self): + """Test deleting a medium playlist.""" + medium = self.table.create(self.album, "Medium 1", number=1) + self.assertTrue(medium.delete()) + self.assertIsNone(self.table.index(medium)) + + cur = self.sql("SELECT COUNT(name) FROM media") + self.assertEqual(cur.fetchone()["COUNT(name)"], 0) + self.assertEqual(len(self.table), 0) + self.assertIsNone(self.table.get_item(0)) + + row = self.sql("""SELECT COUNT(*) FROM playlist_properties + WHERE propertyid=?""", medium.propertyid).fetchone() + self.assertEqual(row["COUNT(*)"], 0) + + self.assertFalse(medium.delete()) + + def test_filter(self): + """Test filtering medium playlists.""" + self.table.create(self.album, "Medium 1", number=1) + self.table.create(self.album, "Medium 2", number=2) + self.table.create(self.album, "", number=3) + + self.table.filter("*1", now=True) + self.assertSetEqual(self.table.get_filter().keys, {1}) + self.table.filter("medium*", now=True) + self.assertSetEqual(self.table.get_filter().keys, {1, 2}) + + def test_get_sort_key(self): + """Test getting a medium's sort key.""" + medium = self.table.create(self.album, "Medium 2", number=2) + self.assertTupleEqual(self.table.get_sort_key(medium), + (1, 2, ("medium", "2"), "")) + + def test_load(self): + """Test loading mediums from the database.""" + self.table.create(self.album, "", number=1) + self.table.create(self.album, "Medium 2", number=2, + type="Digital Media") + + mediums2 = emmental.db.media.Table(self.sql) + self.assertEqual(len(mediums2), 0) + + mediums2.load(now=True) + self.assertEqual(len(mediums2), 2) + + self.assertEqual(mediums2.get_item(0).albumid, self.album.albumid) + self.assertEqual(mediums2.get_item(0).name, "") + self.assertEqual(mediums2.get_item(0).number, 1) + self.assertEqual(mediums2.get_item(0).type, "") + + self.assertEqual(mediums2.get_item(1).albumid, self.album.albumid) + self.assertEqual(mediums2.get_item(1).name, "Medium 2") + self.assertEqual(mediums2.get_item(1).number, 2) + self.assertEqual(mediums2.get_item(1).type, "Digital Media") + + def test_lookup(self): + """Test looking up medium playlists.""" + medium1 = self.table.create(self.album, "Test Medium", + number=1, type="CD") + medium2 = self.table.create(self.album, "", number=2, + type="Digital Media") + + self.assertEqual(self.table.lookup(self.album, number=1, + type="CD"), medium1) + self.assertEqual(self.table.lookup(self.album, number=2, + type="Digital Media"), medium2) + self.assertIsNone(self.table.lookup(self.album, number=3, + type="Enhanced CD")) + + def test_rename(self): + """Test renaming a medium playlist.""" + medium1 = self.table.create(self.album, "Medium 1", + number=1, type="CD") + medium2 = self.table.create(self.album, "Medium 2", + number=1, type="Digital Medium") + + self.assertTrue(medium1.rename("Medium 3")) + self.assertEqual(medium1.name, "Medium 3") + self.assertIsNone(self.sql("SELECT name FROM media WHERE name=?", + "Test Medium").fetchone()) + + self.assertListEqual(self.table.store.items, [medium2, medium1]) + + self.assertFalse(medium1.rename("Medium 3")) + + def test_update(self): + """Test updating medium attributes.""" + medium = self.table.create(self.album, "Test Medium", + number=1, type="CD") + medium.active = True + + row = self.sql("""SELECT active FROM playlist_properties + WHERE propertyid=?""", medium.propertyid).fetchone() + self.assertEqual(row["active"], True)