From a4e0968ef44f641685d377960181dca3c2aa4dcb Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Wed, 13 Mar 2024 11:20:15 -0400 Subject: [PATCH] db: Give the database a 'loaded' property This can be checked or connected to so other parts of the application can easily know if all database tables have been loaded or not. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 8 ++++++++ tests/db/test_db.py | 24 ++++++++++++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index e119ab1..1ec1a37 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -26,6 +26,7 @@ class Connection(connection.Connection): """Connect to the database.""" active_playlist = GObject.Property(type=playlist.Playlist) + loaded = GObject.Property(type=bool, default=False) def __init__(self): """Initialize a sqlite connection.""" @@ -44,6 +45,12 @@ class Connection(connection.Connection): self.tracks = tracks.Table(self) + def __check_loaded(self) -> None: + for tbl in list(self.playlist_tables()) + [self.tracks]: + if tbl.loaded is False: + return + self.loaded = True + def __check_version(self) -> None: user_version = self("PRAGMA user_version").fetchone()["user_version"] match user_version: @@ -99,3 +106,4 @@ class Connection(connection.Connection): def table_loaded(self, tbl: table.Table) -> None: """Signal that a table has been loaded.""" tbl.loaded = True + self.__check_loaded() diff --git a/tests/db/test_db.py b/tests/db/test_db.py index b855d09..56610bd 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -72,22 +72,34 @@ class TestConnection(tests.util.TestCase): def test_load(self): """Check that calling load() loads the tables.""" - idle_tables = [tbl for tbl in self.sql.playlist_tables()] + \ - [self.sql.tracks] + plist_tables = list(self.sql.playlist_tables()) + all_tables = [self.sql.settings] + plist_tables + [self.sql.tracks] table_loaded = unittest.mock.Mock() self.sql.connect("table-loaded", table_loaded) + self.assertFalse(self.sql.loaded) + notify_loaded = unittest.mock.Mock() + self.sql.connect("notify::loaded", notify_loaded) + self.sql.load() self.assertTrue(self.sql.settings.loaded) - for tbl in idle_tables: + notify_loaded.assert_not_called() + + for tbl in all_tables[1:]: self.assertFalse(tbl.loaded) - for tbl in idle_tables: + for tbl in plist_tables: tbl.queue.complete() self.assertTrue(tbl.loaded) + self.assertFalse(self.sql.loaded) + notify_loaded.assert_not_called() - calls = [unittest.mock.call(self.sql, tbl) - for tbl in [self.sql.settings] + idle_tables] + self.sql.tracks.queue.complete() + self.assertTrue(self.sql.tracks.loaded) + self.assertTrue(self.sql.loaded) + notify_loaded.assert_called() + + calls = [unittest.mock.call(self.sql, tbl) for tbl in all_tables] table_loaded.assert_has_calls(calls) def test_filter(self):