diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 3207115..24289d8 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -2,8 +2,10 @@ """Easily work with our underlying sqlite3 database.""" import pathlib from gi.repository import GObject +from typing import Generator from . import connection from . import playlist +from . import playlists from . import settings from . import table @@ -26,16 +28,31 @@ class Connection(connection.Connection): self._sql.executescript(f.read()) self.settings = settings.Table(self) + self.playlists = playlists.Table(self) def close(self) -> None: """Close the database connection.""" self.settings.stop() + for tbl in self.playlist_tables(): + tbl.stop() super().close() + def filter(self, glob: str) -> None: + """Filter the playlist tables.""" + for tbl in self.playlist_tables(): + tbl.filter(glob) + def load(self) -> None: """Load the database tables.""" self.settings.load() + for tbl in self.playlist_tables(): + tbl.load() + + def playlist_tables(self) -> Generator[playlist.Table, None, None]: + """Iterate over each playlist table.""" + for tbl in [self.playlists]: + yield tbl def set_active_playlist(self, plist: playlist.Playlist) -> None: """Set the currently active playlist.""" diff --git a/emmental/db/connection.py b/emmental/db/connection.py index 5fb1ac4..dbe9cad 100644 --- a/emmental/db/connection.py +++ b/emmental/db/connection.py @@ -1,5 +1,6 @@ # Copyright 2022 (c) Anna Schumaker """Easily work with our underlying sqlite3 database.""" +import pathlib import sqlite3 import sys from gi.repository import GObject @@ -10,6 +11,20 @@ DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3" DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE +def adapt_path(path: pathlib.Path) -> str: + """Adapt a pathlib.Path into a sqlite3 string.""" + return str(path) + + +def convert_path(path: bytes) -> pathlib.Path: + """Convert a path string into a pathlib.Path object.""" + return pathlib.Path(path.decode()) + + +sqlite3.register_adapter(pathlib.PosixPath, adapt_path) +sqlite3.register_converter("path", convert_path) + + class Connection(GObject.GObject): """Connect to the database.""" diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index da1afc0..a418ca2 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -35,3 +35,52 @@ CREATE TRIGGER playlists_active_trigger SET active = FALSE WHERE propertyid != NEW.propertyid AND active == TRUE; END; + + +/******************************************* + * * + * User and System Playlists * + * * + *******************************************/ + +CREATE TABLE playlists ( + playlistid INTEGER PRIMARY KEY, + propertyid INTEGER REFERENCES playlist_properties(propertyid) + ON DELETE CASCADE + ON UPDATE CASCADE, + name TEXT NOT NULL UNIQUE COLLATE NOCASE, + image PATH +); + +CREATE VIEW playlists_view AS + SELECT playlistid, propertyid, name, image, active + FROM playlists + JOIN playlist_properties USING (propertyid); + +CREATE TRIGGER playlists_insert_trigger AFTER INSERT ON playlists + BEGIN + INSERT INTO playlist_properties (active) + VALUES (NEW.name == "Collection"); + UPDATE playlists SET propertyid = last_insert_rowid() + WHERE playlistid = NEW.playlistid; + END; + +CREATE TRIGGER playlists_delete_trigger AFTER DELETE ON playlists + BEGIN + DELETE FROM playlist_properties WHERE propertyid = OLD.propertyid; + END; + +/****************************************** + * * + * Create Default Playlists * + * * + ******************************************/ + +INSERT INTO playlists (name) VALUES + ("Collection"), + ("Favorite Tracks"), + ("Most Played Tracks"), + ("New Tracks"), + ("Previous Tracks"), + ("Queued Tracks"), + ("Unplayed Tracks"); diff --git a/emmental/db/playlists.py b/emmental/db/playlists.py new file mode 100644 index 0000000..166ef72 --- /dev/null +++ b/emmental/db/playlists.py @@ -0,0 +1,91 @@ +# Copyright 2022 (c) Anna Schumaker +"""A custom Gio.ListModel for working with playlists.""" +import sqlite3 +from gi.repository import GObject +from . import playlist + + +class Playlist(playlist.Playlist): + """Our custom Playlist with an image filepath.""" + + playlistid = GObject.Property(type=int) + image = GObject.Property(type=GObject.TYPE_PYOBJECT) + + def rename(self, new_name: str) -> bool: + """Rename this playlist.""" + return self.table.rename(self, new_name) + + @property + def primary_key(self) -> int: + """Get the playlist primary key.""" + return self.playlistid + + +class Table(playlist.Table): + """Our Playlist Table.""" + + collection = GObject.Property(type=Playlist) + favorites = GObject.Property(type=Playlist) + most_played = GObject.Property(type=Playlist) + new_tracks = GObject.Property(type=Playlist) + previous = GObject.Property(type=Playlist) + queued = GObject.Property(type=Playlist) + unplayed = GObject.Property(type=Playlist) + + def do_construct(self, **kwargs) -> Playlist: + """Construct a new playlist.""" + match (plist := Playlist(**kwargs)).name: + case "Collection": self.collection = plist + case "Favorite Tracks": self.favorites = plist + case "Most Played Tracks": self.most_played = plist + case "New Tracks": self.new_tracks = plist + case "Previous Tracks": self.previous = plist + case "Queued Tracks": self.queued = plist + case "Unplayed Tracks": self.unplayed = plist + return plist + + def do_sql_delete(self, playlist: Playlist) -> sqlite3.Cursor: + """Delete a playlist.""" + return self.sql("DELETE FROM playlists WHERE playlistid=?", + playlist.playlistid) + + def do_sql_glob(self, glob: str) -> sqlite3.Cursor: + """Search for playlists matching the search text.""" + return self.sql("""SELECT playlistid FROM playlists + WHERE CASEFOLD(name) GLOB ?""", glob) + + def do_sql_insert(self, name: str, **kwargs) -> sqlite3.Cursor | None: + """Insert a new playlist into the database.""" + if (cur := self.sql("INSERT INTO playlists (name) VALUES (?)", name)): + return self.sql("SELECT * FROM playlists_view WHERE playlistid=?", + cur.lastrowid) + + def do_sql_select_all(self) -> sqlite3.Cursor: + """Load playlists from the database.""" + return self.sql("SELECT * FROM playlists_view") + + def do_sql_select_one(self, name: str) -> sqlite3.Cursor: + """Look up a playlist by name.""" + return self.sql("SELECT playlistid FROM playlists WHERE name=?", name) + + def do_sql_update(self, playlist: Playlist, + column: str, newval) -> sqlite3.Cursor: + """Update a playlist.""" + return self.sql(f"UPDATE playlists SET {column}=? WHERE playlistid=?", + newval, playlist.playlistid) + + def create(self, name: str) -> Playlist: + """Create a new Playlist.""" + if len(name := name.strip()) > 0: + return super().create(name) + + def rename(self, playlist: Playlist, new_name: str) -> bool: + """Rename a Playlist.""" + if len(new_name := new_name.strip()) > 0: + if playlist.name != new_name: + if self.update(playlist, "name", new_name): + self.store.remove(playlist) + playlist.name = new_name + self.store.append(playlist) + return True + return False diff --git a/tests/db/test_connection.py b/tests/db/test_connection.py index 7e3ae32..e0cc583 100644 --- a/tests/db/test_connection.py +++ b/tests/db/test_connection.py @@ -1,5 +1,6 @@ # Copyright 2022 (c) Anna Schumaker """Test our custom db Connection object.""" +import pathlib import sqlite3 import emmental.db.connection import unittest @@ -78,6 +79,14 @@ class TestConnection(unittest.TestCase): self.assertEqual(tuple(rows[3]), (4, "d")) self.assertEqual(tuple(rows[4]), (5, "e")) + def test_path_column(self): + """Test that the PATH column type has been set up.""" + self.sql("CREATE TABLE test (path PATH)") + self.sql("INSERT INTO test VALUES (?)", pathlib.Path("/my/test/path")) + row = self.sql("SELECT path FROM test").fetchone() + self.assertIsInstance(row["path"], pathlib.Path) + self.assertEqual(row["path"], pathlib.Path("/my/test/path")) + def test_transaction(self): """Test that we can manually start a transaction.""" self.assertFalse(self.sql._sql.in_transaction) diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 084482e..2503b20 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -26,16 +26,24 @@ class TestConnection(tests.util.TestCase): def test_close(self): """Check closing the connection.""" self.sql.settings.queue.running = True + for tbl in self.sql.playlist_tables(): + tbl.queue.running = True self.sql.close() self.assertFalse(self.sql.connected) self.assertFalse(self.sql.settings.queue.running) + for tbl in self.sql.playlist_tables(): + self.assertFalse(tbl.queue.running) self.sql.close() def test_tables(self): """Check that the connection has pointers to our tables.""" self.assertIsInstance(self.sql.settings, emmental.db.settings.Table) + self.assertIsInstance(self.sql.playlists, emmental.db.playlists.Table) + + self.assertListEqual([tbl for tbl in self.sql.playlist_tables()], + [self.sql.playlists]) def test_load(self): """Check that calling load() loads the tables.""" @@ -44,11 +52,26 @@ class TestConnection(tests.util.TestCase): self.sql.load() self.assertTrue(self.sql.settings.loaded) + for tbl in self.sql.playlist_tables(): + self.assertFalse(tbl.loaded) + for tbl in self.sql.playlist_tables(): + tbl.queue.complete() + self.assertTrue(tbl.loaded) + tables = [tbl for tbl in self.sql.playlist_tables()] calls = [unittest.mock.call(self.sql, tbl) - for tbl in [self.sql.settings]] + for tbl in [self.sql.settings] + tables] table_loaded.assert_has_calls(calls) + def test_filter(self): + """Check filtering the playlist tables.""" + for tbl in self.sql.playlist_tables(): + tbl.filter = unittest.mock.Mock() + + self.sql.filter("*glob*") + for tbl in self.sql.playlist_tables(): + tbl.filter.assert_called_with("*glob*") + def test_set_active_playlist(self): """Check setting the active playlist.""" table = tests.util.playlist.MockTable(self.sql) diff --git a/tests/db/test_playlists.py b/tests/db/test_playlists.py new file mode 100644 index 0000000..4917207 --- /dev/null +++ b/tests/db/test_playlists.py @@ -0,0 +1,259 @@ +# Copyright 2022 (c) Anna Schumaker +"""Tests our playlist Gio.ListModel.""" +import pathlib +import unittest.mock +import emmental.db +import tests.util + + +class TestPlaylistObject(tests.util.TestCase): + """Tests our playlist object.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.table = self.sql.playlists + self.playlist = emmental.db.playlists.Playlist(table=self.table, + playlistid=12345, + propertyid=67890, + name="Test Playlist") + + def test_init(self): + """Test that the Playlist is set up properly.""" + self.assertIsInstance(self.playlist, emmental.db.playlist.Playlist) + self.assertEqual(self.playlist.table, self.table) + self.assertEqual(self.playlist.propertyid, 67890) + self.assertEqual(self.playlist.playlistid, 12345) + self.assertEqual(self.playlist.primary_key, 12345) + self.assertEqual(self.playlist.name, "Test Playlist") + self.assertIsNone(self.playlist.image) + self.assertIsNone(self.playlist.parent) + + def test_image_path(self): + """Test the image-path property.""" + path = pathlib.Path("/a/b/c.jpg") + playlist = emmental.db.playlists.Playlist(table=self.table, + playlistid=1, propertyid=1, + image=path, name="Test") + self.assertEqual(playlist.image, path) + + def test_rename(self): + """Test the rename() function.""" + with unittest.mock.patch.object(self.table, "rename", + return_value=True) as mock_rename: + self.assertTrue(self.playlist.rename("New Name")) + mock_rename.assert_called_with(self.playlist, "New Name") + + +class TestPlaylistTable(tests.util.TestCase): + """Tests our playlist table.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.sql("DELETE FROM playlists") + self.table = self.sql.playlists + + def test_init(self): + """Test that the playlist model is configured correctly.""" + self.assertIsInstance(self.table, emmental.db.playlist.Table) + self.assertEqual(len(self.table), 0) + + self.assertIsNone(self.table.collection) + self.assertIsNone(self.table.favorites) + self.assertIsNone(self.table.most_played) + self.assertIsNone(self.table.new_tracks) + self.assertIsNone(self.table.previous) + self.assertIsNone(self.table.queued) + self.assertIsNone(self.table.unplayed) + + def test_construct(self): + """Test constructing a playlist.""" + playlist = self.table.construct(playlistid=1, propertyid=1, + name="Test Playlist") + self.assertIsInstance(playlist, emmental.db.playlists.Playlist) + self.assertEqual(playlist.table, self.table) + self.assertEqual(playlist.propertyid, 1) + self.assertEqual(playlist.playlistid, 1) + self.assertEqual(playlist.name, "Test Playlist") + self.assertIsNone(playlist.image) + + def test_create(self): + """Test creating a playlist.""" + playlist = self.table.create(" Test Playlist ") + self.assertIsInstance(playlist, emmental.db.playlists.Playlist) + self.assertEqual(playlist.name, "Test Playlist") + self.assertIsNone(playlist.image) + + cur = self.sql("SELECT COUNT(name) FROM playlists") + self.assertEqual(cur.fetchone()["COUNT(name)"], 1) + self.assertEqual(len(self.table), 1) + self.assertEqual(self.table.get_item(0), playlist) + + cur = self.sql("SELECT COUNT(*) FROM playlist_properties") + self.assertEqual(cur.fetchone()["COUNT(*)"], 1) + + for name in ["", " ", "Test Playlist", "test playlist"]: + self.assertIsNone(self.table.create(name)) + self.assertEqual(len(self.table), 1) + cur = self.sql("SELECT COUNT(rowid) FROM playlists") + self.assertEqual(cur.fetchone()["COUNT(rowid)"], 1) + + def test_delete(self): + """Test deleting a playlist.""" + playlist = self.table.create("Test Playlist") + self.assertTrue(playlist.delete()) + self.assertIsNone(self.table.index(playlist)) + + cur = self.sql("SELECT COUNT(name) FROM playlists") + self.assertEqual(cur.fetchone()["COUNT(name)"], 0) + self.assertEqual(len(self.table), 0) + self.assertIsNone(self.table.get_item(0)) + + cur = self.sql("SELECT COUNT(*) FROM playlist_properties") + self.assertEqual(cur.fetchone()["COUNT(*)"], 0) + + self.assertFalse(playlist.delete()) + + def test_filter(self): + """Test filtering the playlist model.""" + self.table.create("Playlist 1") + self.table.create("Playlist 2") + + self.table.filter("*1", now=True) + self.assertSetEqual(self.table.get_filter().keys, {1}) + + self.table.filter("playlist*", now=True) + self.assertSetEqual(self.table.get_filter().keys, {1, 2}) + + def test_load(self): + """Test loading playlists from the database.""" + self.table.create("Playlist 1").image = tests.util.COVER_JPG + self.table.create("Playlist 2") + + playlists2 = emmental.db.playlists.Table(self.sql) + playlists2.load(now=True) + + self.assertEqual(len(playlists2), 2) + self.assertEqual(playlists2[0].name, "Playlist 1") + self.assertEqual(playlists2[0].image, tests.util.COVER_JPG) + self.assertEqual(playlists2[1].name, "Playlist 2") + self.assertIsNone(playlists2[1].image) + + def test_lookup(self): + """Test looking up a playlist.""" + playlist = self.table.create("Test Playlist") + self.assertEqual(self.table.lookup("Test Playlist"), playlist) + self.assertEqual(self.table.lookup("test playlist"), playlist) + self.assertIsNone(self.table.lookup("No Playlist")) + + def test_rename(self): + """Test renaming a playlist.""" + playlist = self.table.create("Test Playlist") + + self.table.store.append = unittest.mock.Mock() + self.table.store.remove = unittest.mock.Mock() + self.assertTrue(playlist.rename(" New Name ")) + self.assertEqual(playlist.name, "New Name") + self.table.store.remove.assert_called_with(playlist) + self.table.store.append.assert_called_with(playlist) + + rows = self.sql("SELECT name FROM playlists").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["name"], "New Name") + + self.table.create("Other Name") + self.assertFalse(playlist.rename("New Name")) + self.assertFalse(playlist.rename("Other Name")) + + def test_update(self): + """Test updating playlist properties.""" + playlist = self.table.create("Test Playlist") + playlist.image = tests.util.COVER_JPG + playlist.active = True + + cur = self.sql("""SELECT image, active FROM playlists_view + WHERE playlistid=?""", playlist.playlistid) + row = cur.fetchone() + self.assertEqual(row["image"], tests.util.COVER_JPG) + self.assertTrue(row["active"]) + + +class TestSystemPlaylists(tests.util.TestCase): + """Tests our system playlists.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.table = self.sql.playlists + self.table.load(now=True) + + def test_collection(self): + """Test the Collection playlist.""" + self.assertIsInstance(self.table.collection, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.collection.name, "Collection") + self.assertTrue(self.table.collection.active) + + self.assertEqual(self.table.lookup("Collection"), + self.table.collection) + + def test_favorites(self): + """Test the favorite tracks playlist.""" + self.assertIsInstance(self.table.favorites, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.favorites.name, "Favorite Tracks") + self.assertFalse(self.table.favorites.active) + + self.assertEqual(self.table.lookup("Favorite Tracks"), + self.table.favorites) + + def test_most_played(self): + """Test the most-played tracks playlist.""" + self.assertIsInstance(self.table.most_played, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.most_played.name, "Most Played Tracks") + self.assertFalse(self.table.most_played.active) + + self.assertEqual(self.table.lookup("Most Played Tracks"), + self.table.most_played) + + def test_new_tracks(self): + """Test the new tracks playlist.""" + self.assertIsInstance(self.table.new_tracks, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.new_tracks.name, "New Tracks") + self.assertFalse(self.table.new_tracks.active) + + self.assertEqual(self.table.lookup("New Tracks"), + self.table.new_tracks) + + def test_previous(self): + """Test the previous tracks playlist.""" + self.assertIsInstance(self.table.previous, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.previous.name, "Previous Tracks") + self.assertFalse(self.table.previous.active) + + self.assertEqual(self.table.lookup("Previous Tracks"), + self.table.previous) + + def test_queued(self): + """Test the queued tracks playlist.""" + self.assertIsInstance(self.table.queued, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.queued.name, "Queued Tracks") + self.assertFalse(self.table.queued.active) + + self.assertEqual(self.table.lookup("Queued Tracks"), + self.table.queued) + + def test_unplayed(self): + """Test the unplayed tracks playlist.""" + self.assertIsInstance(self.table.unplayed, + emmental.db.playlists.Playlist) + self.assertEqual(self.table.unplayed.name, "Unplayed Tracks") + self.assertFalse(self.table.unplayed.active) + + self.assertEqual(self.table.lookup("Unplayed Tracks"), + self.table.unplayed)