From 55d7eb3d4593f9c24f6de01b2ca6deb7a83ea588 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Thu, 4 May 2023 10:44:48 -0400 Subject: [PATCH] db: Raise an exception if the user_version is too new Future proof. If we update the database schema, then we'll bump the user_version field. If the user then tries to open the new database with an old Emmental version then there could be a lot of issues. Let's detect this and raise an error with a description of the problem. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 16 +++++++++++----- tests/db/test_db.py | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 1a98676..f7b8b02 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -29,11 +29,7 @@ class Connection(connection.Connection): def __init__(self): """Initialize a sqlite connection.""" super().__init__() - - match self("PRAGMA user_version").fetchone()["user_version"]: - case 0: - with open(SQL_SCRIPT) as f: - self._sql.executescript(f.read()) + self.__check_version() self.settings = settings.Table(self) self.playlists = playlists.Table(self) @@ -47,6 +43,16 @@ class Connection(connection.Connection): self.tracks = tracks.Table(self) + def __check_version(self) -> None: + user_version = self("PRAGMA user_version").fetchone()["user_version"] + match user_version: + case 0: + with open(SQL_SCRIPT) as f: + self._sql.executescript(f.read()) + case 1: pass + case _: + raise Exception(f"Unsupported data version: {user_version}") + def close(self) -> None: """Close the database connection.""" self.settings.stop() diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 716c023..3203cd5 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -23,6 +23,15 @@ class TestConnection(tests.util.TestCase): cur = self.sql("PRAGMA user_version") self.assertEqual(cur.fetchone()["user_version"], 1) + def test_version_too_new(self): + """Test failing when the database version is too new.""" + self.sql._Connection__check_version() + + self.sql("PRAGMA user_version = 2") + with self.assertRaises(Exception) as e: + self.sql._Connection__check_version() + self.assertEqual(str(e.exception), "Unsupported data version: 2") + def test_close(self): """Check closing the connection.""" self.sql.settings.queue.running = True