diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 1a98676..f7b8b02 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -29,11 +29,7 @@ class Connection(connection.Connection): def __init__(self): """Initialize a sqlite connection.""" super().__init__() - - match self("PRAGMA user_version").fetchone()["user_version"]: - case 0: - with open(SQL_SCRIPT) as f: - self._sql.executescript(f.read()) + self.__check_version() self.settings = settings.Table(self) self.playlists = playlists.Table(self) @@ -47,6 +43,16 @@ class Connection(connection.Connection): self.tracks = tracks.Table(self) + def __check_version(self) -> None: + user_version = self("PRAGMA user_version").fetchone()["user_version"] + match user_version: + case 0: + with open(SQL_SCRIPT) as f: + self._sql.executescript(f.read()) + case 1: pass + case _: + raise Exception(f"Unsupported data version: {user_version}") + def close(self) -> None: """Close the database connection.""" self.settings.stop() diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 716c023..3203cd5 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -23,6 +23,15 @@ class TestConnection(tests.util.TestCase): cur = self.sql("PRAGMA user_version") self.assertEqual(cur.fetchone()["user_version"], 1) + def test_version_too_new(self): + """Test failing when the database version is too new.""" + self.sql._Connection__check_version() + + self.sql("PRAGMA user_version = 2") + with self.assertRaises(Exception) as e: + self.sql._Connection__check_version() + self.assertEqual(str(e.exception), "Unsupported data version: 2") + def test_close(self): """Check closing the connection.""" self.sql.settings.queue.running = True