diff --git a/emmental/db/media.py b/emmental/db/media.py index 6d163f3..ed50720 100644 --- a/emmental/db/media.py +++ b/emmental/db/media.py @@ -2,8 +2,10 @@ """A custom Gio.ListModel for managing individual media in an album.""" import sqlite3 from gi.repository import GObject +from gi.repository import Gtk from .. import format from . import playlist +from . import table from . import tracks @@ -34,12 +36,26 @@ class Medium(playlist.Playlist): return self.get_album() +class Filter(table.KeySet): + """Custom filter to hide media with empty names.""" + + def do_get_strictness(self) -> Gtk.FilterMatch: + """Get the strictness of the filter.""" + if (res := super().do_get_strictness()) == Gtk.FilterMatch.ALL: + res = Gtk.FilterMatch.SOME + return res + + def do_match(self, medium: Medium) -> bool: + """Check if the Medium matches the filter.""" + return len(medium.name) > 0 if super().do_match(medium) else False + + class Table(playlist.Table): """Our Media Table.""" def __init__(self, sql: GObject.TYPE_PYOBJECT, **kwargs): """Initialize the Media Table.""" - super().__init__(sql=sql, autodelete=True, + super().__init__(sql=sql, filter=Filter(), autodelete=True, system_tracks=False, **kwargs) def do_construct(self, **kwargs) -> Medium: diff --git a/tests/db/test_media.py b/tests/db/test_media.py index 784779b..a25d180 100644 --- a/tests/db/test_media.py +++ b/tests/db/test_media.py @@ -4,6 +4,7 @@ import pathlib import unittest.mock import emmental.db import tests.util +from gi.repository import Gtk class TestMediumObject(tests.util.TestCase): @@ -46,6 +47,36 @@ class TestMediumObject(tests.util.TestCase): mock_rename.assert_called_with(self.medium, "New Name") +class TestFilter(tests.util.TestCase): + """Test the medium filter.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.filter = emmental.db.media.Filter() + + def test_init(self): + """Test that the filter is initialized properly.""" + self.assertIsInstance(self.filter, emmental.db.table.KeySet) + + def test_strictness(self): + """Test checking strictness.""" + self.filter.keys = None + self.assertEqual(self.filter.get_strictness(), Gtk.FilterMatch.SOME) + self.filter.keys = set() + self.assertEqual(self.filter.get_strictness(), Gtk.FilterMatch.NONE) + self.filter.keys = {1, 2, 3} + self.assertEqual(self.filter.get_strictness(), Gtk.FilterMatch.SOME) + + def test_match(self): + """Test matching a Medium.""" + album = self.sql.albums.create("Test Album", "Test Artist", "123") + medium = self.sql.media.create(album, "", number=1) + self.assertFalse(self.filter.match(medium)) + medium.name = "abcde" + self.assertTrue(self.filter.match(medium)) + + class TestMediumsTable(tests.util.TestCase): """Tests our mediums table.""" @@ -61,6 +92,8 @@ class TestMediumsTable(tests.util.TestCase): def test_init(self): """Test that the medium model is configured corretly.""" self.assertIsInstance(self.table, emmental.db.playlist.Table) + self.assertIsInstance(self.table.get_filter(), + emmental.db.media.Filter) self.assertEqual(len(self.table), 0) self.assertTrue(self.table.autodelete) self.assertFalse(self.table.system_tracks) @@ -172,17 +205,19 @@ class TestMediumsTable(tests.util.TestCase): self.assertEqual(len(mediums2), 0) mediums2.load(now=True) - self.assertEqual(len(mediums2), 2) + self.assertEqual(len(mediums2.store), 2) - self.assertEqual(mediums2.get_item(0).albumid, self.album.albumid) - self.assertEqual(mediums2.get_item(0).name, "") - self.assertEqual(mediums2.get_item(0).number, 1) - self.assertEqual(mediums2.get_item(0).type, "") + self.assertEqual(mediums2.store.get_item(0).albumid, + self.album.albumid) + self.assertEqual(mediums2.store.get_item(0).name, "") + self.assertEqual(mediums2.store.get_item(0).number, 1) + self.assertEqual(mediums2.store.get_item(0).type, "") - self.assertEqual(mediums2.get_item(1).albumid, self.album.albumid) - self.assertEqual(mediums2.get_item(1).name, "Medium 2") - self.assertEqual(mediums2.get_item(1).number, 2) - self.assertEqual(mediums2.get_item(1).type, "Digital Media") + self.assertEqual(mediums2.store.get_item(1).albumid, + self.album.albumid) + self.assertEqual(mediums2.store.get_item(1).name, "Medium 2") + self.assertEqual(mediums2.store.get_item(1).number, 2) + self.assertEqual(mediums2.store.get_item(1).type, "Digital Media") def test_lookup(self): """Test looking up medium playlists."""