diff --git a/emmental/db/decades.py b/emmental/db/decades.py index 9d1bdb3..f60b8ed 100644 --- a/emmental/db/decades.py +++ b/emmental/db/decades.py @@ -17,15 +17,25 @@ class Decade(playlist.Playlist): """Initialize a Decade object.""" super().__init__(**kwargs) self.add_children(self.table.sql.years, - Gtk.CustomFilter.new(self.__match_year)) + Gtk.CustomFilter.new(self.has_year), + self.table.get_yearids(self)) - def __match_year(self, year: Year) -> bool: - return self.decade == year.year // 10 * 10 + def add_year(self, year: Year) -> None: + """Add a year to this decade.""" + self.add_child(year) def get_years(self) -> list[Year]: """Get a list of years for this decade.""" return self.table.get_years(self) + def has_year(self, year: Year) -> bool: + """Check if the year is in this decade.""" + return self.has_child(year) + + def remove_year(self, year: Year) -> None: + """Remove a year from this decade.""" + self.remove_child(year) + @property def primary_key(self) -> int: """Get the primary key of this Decade.""" @@ -90,8 +100,12 @@ class Table(playlist.Table): return self.sql("""SELECT trackid FROM decade_tracks_view WHERE decade=?""", decade.decade) - def get_years(self, decade: Decade) -> list[Year]: - """Get the list of years for this decade.""" + def get_yearids(self, decade: Decade) -> set[int]: + """Get the set of years for this decade.""" rows = self.sql("SELECT year FROM years WHERE (year / 10 * 10)=?", decade.decade) - return [self.sql.years.rows.get(row["year"]) for row in rows] + return {row["year"] for row in rows} + + def get_years(self, decade: Decade) -> list[Year]: + """Get the list of years for this decade.""" + return [self.sql.years.rows.get(yr) for yr in self.get_yearids(decade)] diff --git a/emmental/db/years.py b/emmental/db/years.py index b65ebf9..d4f743f 100644 --- a/emmental/db/years.py +++ b/emmental/db/years.py @@ -48,6 +48,8 @@ class Table(playlist.Table): def do_sql_delete(self, year: Year) -> sqlite3.Cursor: """Delete a year.""" + if year.parent is not None: + year.parent.remove_year(year) return self.sql("DELETE FROM years WHERE year=?", year.year) def do_sql_glob(self, glob: str) -> sqlite3.Cursor: @@ -71,3 +73,10 @@ class Table(playlist.Table): """Load a Year's Tracks from the database.""" return self.sql("""SELECT trackid FROM year_tracks_view WHERE year=?""", year.year) + + def create(self, *args, **kwargs) -> Year | None: + """Create a new Year playlist.""" + if (year := super().create(*args, **kwargs)) is not None: + if year.parent is not None: + year.parent.add_year(year) + return year diff --git a/tests/db/test_decades.py b/tests/db/test_decades.py index 56fdffe..7a3e31a 100644 --- a/tests/db/test_decades.py +++ b/tests/db/test_decades.py @@ -28,6 +28,21 @@ class TestDecadeObject(tests.util.TestCase): self.assertEqual(self.decade.name, "The 2020s") self.assertIsNone(self.decade.parent) + def test_add_remove_year(self): + """Test adding and removing a year from the decade.""" + year = self.sql.years.create(1988) + + self.assertFalse(year in self.decade.child_set) + self.assertFalse(self.decade.has_year(year)) + + self.decade.add_year(year) + self.assertTrue(year in self.decade.child_set) + self.assertTrue(self.decade.has_year(year)) + + self.decade.remove_year(year) + self.assertFalse(year in self.decade.child_set) + self.assertFalse(self.decade.has_year(year)) + def test_get_years(self): """Test getting the list of years for this decade.""" with unittest.mock.patch.object(self.table, "get_years", @@ -43,6 +58,7 @@ class TestDecadeObject(tests.util.TestCase): self.assertEqual(self.decade.children.get_model(), self.sql.years) year = self.sql.years.create(2023) + self.decade.add_year(year) self.assertTrue(self.decade.children.get_filter().match(year)) year = self.sql.years.create(1988) @@ -164,8 +180,10 @@ class TestDecadeTable(tests.util.TestCase): def test_load(self): """Load the decade table from the database.""" - self.table.create(1980) + decade = self.table.create(1980) self.table.create(1990) + year = self.sql.years.create(1988) + decade.add_year(year) decades2 = emmental.db.decades.Table(self.sql) self.assertEqual(len(decades2), 0) @@ -175,9 +193,11 @@ class TestDecadeTable(tests.util.TestCase): self.assertEqual(decades2.get_item(0).decade, 1980) self.assertEqual(decades2.get_item(0).name, "The 1980s") + self.assertSetEqual(decades2.get_item(0).child_set.keyset.keys, {1988}) self.assertEqual(decades2.get_item(1).decade, 1990) self.assertEqual(decades2.get_item(1).name, "The 1990s") + self.assertSetEqual(decades2.get_item(1).child_set.keyset.keys, set()) def test_lookup(self): """Test looking up decade playlists.""" @@ -214,4 +234,5 @@ class TestDecadeTable(tests.util.TestCase): y1985 = self.sql.years.create(1985) y1988 = self.sql.years.create(1988) + self.assertSetEqual(self.table.get_yearids(decade), {1985, 1988}) self.assertListEqual(self.table.get_years(decade), [y1985, y1988]) diff --git a/tests/db/test_years.py b/tests/db/test_years.py index f876fa9..26cf349 100644 --- a/tests/db/test_years.py +++ b/tests/db/test_years.py @@ -75,12 +75,14 @@ class TestYearTable(tests.util.TestCase): def test_create(self): """Test creating a year playlist.""" + decade = self.sql.decades.create(1980) year = self.table.create(1988) self.assertIsInstance(year, emmental.db.years.Year) self.assertEqual(year.year, 1988) self.assertEqual(year.name, "1988") self.assertEqual(year.sort_order, "release, albumartist, album, mediumno, number") + self.assertTrue(year in decade.child_set) cur = self.sql("SELECT COUNT(year) FROM years") self.assertEqual(cur.fetchone()["COUNT(year)"], 1) @@ -93,8 +95,10 @@ class TestYearTable(tests.util.TestCase): def test_delete(self): """Test deleting a year playlist.""" + decade = self.sql.decades.create(1980) year = self.table.create(1988) self.assertTrue(year.delete()) + self.assertFalse(year in decade.child_set) cur = self.sql("SELECT COUNT(year) FROM years") self.assertEqual(cur.fetchone()["COUNT(year)"], 0)