diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 16ffbb7..3207115 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -3,6 +3,7 @@ import pathlib from gi.repository import GObject from . import connection +from . import playlist from . import settings from . import table @@ -13,6 +14,8 @@ SQL_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql" class Connection(connection.Connection): """Connect to the database.""" + active_playlist = GObject.Property(type=playlist.Playlist) + def __init__(self): """Initialize a sqlite connection.""" super().__init__() @@ -34,6 +37,16 @@ class Connection(connection.Connection): """Load the database tables.""" self.settings.load() + def set_active_playlist(self, plist: playlist.Playlist) -> None: + """Set the currently active playlist.""" + if self.active_playlist is not None: + self.active_playlist.active = False + + self.active_playlist = plist + + if plist is not None: + plist.active = True + @GObject.Signal(arg_types=(table.Table,)) def table_loaded(self, tbl: table.Table) -> None: """Signal that a table has been loaded.""" diff --git a/emmental/db/emmental.sql b/emmental/db/emmental.sql index 989d7fd..da1afc0 100644 --- a/emmental/db/emmental.sql +++ b/emmental/db/emmental.sql @@ -15,3 +15,23 @@ CREATE TABLE settings ( value TEXT NOT NULL, CHECK (type IN ("gint", "gdouble", "gboolean", "gchararray")) ); + + +/************************************* + * * + * Playlist Properties * + * * + *************************************/ + +CREATE TABLE playlist_properties ( + propertyid INTEGER PRIMARY KEY, + active BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE TRIGGER playlists_active_trigger + AFTER UPDATE OF active ON playlist_properties + FOR EACH ROW BEGIN + UPDATE playlist_properties + SET active = FALSE + WHERE propertyid != NEW.propertyid AND active == TRUE; + END; diff --git a/emmental/db/playlist.py b/emmental/db/playlist.py index de4ffef..edf74ec 100644 --- a/emmental/db/playlist.py +++ b/emmental/db/playlist.py @@ -3,6 +3,7 @@ from gi.repository import GObject from gi.repository import Gio from gi.repository import Gtk +from .. import format from . import table @@ -19,10 +20,10 @@ class Playlist(table.Row): children = GObject.Property(type=Gtk.FilterListModel) def __init__(self, table: Gio.ListModel, propertyid: int, - name: str = "", active: bool = False, **kwargs): + name: str, **kwargs): """Initialize a Playlist object.""" super().__init__(table=table, propertyid=propertyid, - name=name, active=active, **kwargs) + name=name, **kwargs) def add_children(self, child_table: table.Table, child_filter: Gtk.Filter) -> None: @@ -41,3 +42,62 @@ class Playlist(table.Row): def parent(self) -> table.Row | None: """Get this playlist's parent playlist.""" return None + + +class Table(table.Table): + """A table.Table with extra functionality for Playlists.""" + + active_playlist = GObject.Property(type=Playlist) + treemodel = GObject.Property(type=Gtk.TreeListModel) + + def __init__(self, sql: GObject.TYPE_PYOBJECT, **kwargs): + """Initialize a Playlist Table.""" + super().__init__(sql=sql, **kwargs) + self.treemodel = Gtk.TreeListModel.new(root=self, + passthrough=False, + autoexpand=False, + create_func=self.__create_tree) + + def __create_tree(self, plist: Playlist) -> Gtk.FilterListModel | None: + return plist.children + + def do_get_sort_key(self, playlist: Playlist) -> tuple[str]: + """Get a sort key for the requested Playlist.""" + return format.sort_key(playlist.name) + + def clear(self) -> None: + """Clear the Table.""" + self.active_playlist = None + super().clear() + + def construct(self, propertyid: int, name: str, **kwargs) -> Playlist: + """Construct a new Playlist object.""" + res = super().construct(propertyid=propertyid, name=name, **kwargs) + if res.active: + self.sql.set_active_playlist(res) + return res + + def delete(self, playlist: Playlist) -> bool: + """Delete a playlist from the database.""" + if playlist.active: + self.sql.set_active_playlist(None) + return super().delete(playlist) + + def update(self, playlist: Playlist, column: str, newval) -> bool: + """Update a Playlist in the Database.""" + match column: + case "active": + return self.update_playlist_property(playlist, column, newval) + case _: + return super().update(playlist, column, newval) + + def update_playlist_property(self, playlist: Playlist, + column: str, newval) -> bool: + """Update the playlists_common table.""" + match column: + case "active": + self.active_playlist = playlist if playlist.active else None + + return self.sql(f"""UPDATE playlist_properties + SET {column}=? WHERE propertyid=?""", + newval, playlist.propertyid) is not None diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 978ee0e..084482e 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -48,3 +48,23 @@ class TestConnection(tests.util.TestCase): calls = [unittest.mock.call(self.sql, tbl) for tbl in [self.sql.settings]] table_loaded.assert_has_calls(calls) + + def test_set_active_playlist(self): + """Check setting the active playlist.""" + table = tests.util.playlist.MockTable(self.sql) + plist1 = table.create(name="Playlist 1") + plist2 = table.create(name="Playlist 2") + self.assertIsNone(self.sql.active_playlist) + + self.sql.set_active_playlist(plist1) + self.assertEqual(self.sql.active_playlist, plist1) + self.assertTrue(plist1.active) + + self.sql.set_active_playlist(plist2) + self.assertEqual(self.sql.active_playlist, plist2) + self.assertFalse(plist1.active) + self.assertTrue(plist2.active) + + self.sql.set_active_playlist(None) + self.assertIsNone(self.sql.active_playlist) + self.assertFalse(plist2.active) diff --git a/tests/db/test_playlist.py b/tests/db/test_playlist.py index aef83a7..9c220fc 100644 --- a/tests/db/test_playlist.py +++ b/tests/db/test_playlist.py @@ -3,6 +3,7 @@ import unittest import unittest.mock import emmental.db.playlist +import tests.util.playlist from gi.repository import Gio from gi.repository import Gtk @@ -15,14 +16,15 @@ class TestPlaylistRow(unittest.TestCase): self.table = Gio.ListStore() self.table.update = unittest.mock.Mock(return_value=True) self.playlist = emmental.db.playlist.Playlist(table=self.table, - propertyid=0) + propertyid=0, + name="Test Playlist") def test_init(self): """Test that the Playlist object is configured correctly.""" self.assertIsInstance(self.playlist, emmental.db.table.Row) self.assertEqual(self.playlist.table, self.table) self.assertEqual(self.playlist.propertyid, 0) - self.assertEqual(self.playlist.name, "") + self.assertEqual(self.playlist.name, "Test Playlist") self.assertEqual(self.playlist.n_tracks, 0) self.assertFalse(self.playlist.active) @@ -59,3 +61,82 @@ class TestPlaylistRow(unittest.TestCase): self.playlist.active = True self.table.update.assert_called_with(self.playlist, "active", True) + + +class TestPlaylistTable(tests.util.TestCase): + """Tests our Playlist Table.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.table = tests.util.playlist.MockTable(self.sql) + self.sql("DELETE FROM playlist_properties") + + def test_treemodel(self): + """Check that the table's treemodel was set up properly.""" + self.assertIsInstance(self.table.treemodel, Gtk.TreeListModel) + self.assertEqual(self.table.treemodel.get_model(), self.table) + self.assertFalse(self.table.treemodel.get_passthrough()) + self.assertFalse(self.table.treemodel.get_autoexpand()) + + root = self.table.create("Root") + self.assertIsNone(self.table._Table__create_tree(root)) + root.children = Gtk.FilterListModel() + self.assertEqual(self.table._Table__create_tree(root), root.children) + + def test_construct(self): + """Test constructing a new playlist.""" + self.assertIsNone(self.table.active_playlist) + + plist1 = self.table.construct(propertyid=1, name="Test") + self.assertIsInstance(plist1, emmental.db.playlist.Playlist) + self.assertEqual(plist1.table, self.table) + self.assertEqual(plist1.propertyid, 1) + self.assertEqual(plist1.name, "Test") + self.assertFalse(plist1.active) + + plist2 = self.table.construct(propertyid=2, name="Test 2", active=True) + self.assertEqual(self.table.active_playlist, plist2) + self.assertEqual(self.sql.active_playlist, plist2) + self.assertTrue(plist2.active) + + def test_get_sort_key(self): + """Test getting a sort key for a playlist.""" + plist = self.table.create("Playlist 1") + self.assertTupleEqual(self.table.get_sort_key(plist), + ("playlist", "1")) + + def test_clear(self): + """Test clearing the active_playlist property.""" + plist = self.table.create("Playlist 1") + self.table.active_playlist = plist + self.table.clear() + self.assertIsNone(self.table.active_playlist) + self.assertEqual(len(self.table), 0) + + def test_delete(self): + """Test deleting the active playlist.""" + plist = self.table.create("Test Playlist") + self.sql.set_active_playlist(plist) + + self.assertTrue(self.table.delete(plist)) + self.assertIsNone(self.table.active_playlist) + self.assertIsNone(self.sql.active_playlist) + self.assertNotIn(plist, self.table) + + def test_update(self): + """Test updating playlist properties.""" + plist1 = self.table.create("Test Playlist 1") + plist2 = self.table.create("Test Playlist 2") + plist1.active = True + + self.assertEqual(self.table.active_playlist, plist1) + row = self.sql("""SELECT active FROM playlist_properties + WHERE propertyid=?""", plist1.propertyid).fetchone() + self.assertEqual(row["active"], True) + + plist2.active = True + self.assertEqual(self.table.active_playlist, plist2) + row = self.sql("SELECT active FROM playlist_properties WHERE rowid=?", + plist1.propertyid).fetchone() + self.assertEqual(row["active"], False) diff --git a/tests/util/playlist.py b/tests/util/playlist.py new file mode 100644 index 0000000..135d267 --- /dev/null +++ b/tests/util/playlist.py @@ -0,0 +1,31 @@ +# Copyright 2023 (c) Anna Schumaker. +"""Mock Playlist and Table objects for testing.""" +import emmental.db.playlist +import sqlite3 + + +class MockPlaylist(emmental.db.playlist.Playlist): + """A fake Playlist for testing.""" + + @property + def primary_key(self) -> int: + """Get the primary_key of this playlist.""" + return self.propertyid + + +class MockTable(emmental.db.playlist.Table): + """A fake Playlist Table for testing.""" + + def do_construct(self, **kwargs) -> MockPlaylist: + """Construct a new Playlist object.""" + return MockPlaylist(**kwargs) + + def do_sql_delete(self, playlist: MockPlaylist) -> sqlite3.Cursor: + """Extra work for deleting a Playlist.""" + return self.sql("DELETE FROM playlist_properties WHERE propertyid=?", + playlist.propertyid) + + def do_sql_insert(self, name: str) -> sqlite3.Cursor: + """Extra work for adding a new Playlist.""" + return self.sql("""INSERT INTO playlist_properties DEFAULT VALUES + RETURNING ? as name, *""", name)