From 072264a77c1efe43996b6dfbcfa08557c324e7e9 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Wed, 18 Oct 2023 11:50:29 -0400 Subject: [PATCH] db: Upgrade the database version to 2 Prepare for database modifications. The first step is to bump the database version, and it's cleaner to do that in a separate patch. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 10 +++++++--- emmental/db/upgrade-v2.sql | 3 +++ tests/db/test_db.py | 11 ++++++----- 3 files changed, 16 insertions(+), 8 deletions(-) create mode 100644 emmental/db/upgrade-v2.sql diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index b8a1a66..e119ab1 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -18,7 +18,8 @@ from . import tracks from . import years -SQL_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql" +SQL_V1_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql" +SQL_V2_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v2.sql" class Connection(connection.Connection): @@ -47,8 +48,11 @@ class Connection(connection.Connection): user_version = self("PRAGMA user_version").fetchone()["user_version"] match user_version: case 0: - self.executescript(SQL_SCRIPT) - case 1: pass + self.executescript(SQL_V1_SCRIPT) + self.executescript(SQL_V2_SCRIPT) + case 1: + self.executescript(SQL_V2_SCRIPT) + case 2: pass case _: raise Exception(f"Unsupported data version: {user_version}") diff --git a/emmental/db/upgrade-v2.sql b/emmental/db/upgrade-v2.sql new file mode 100644 index 0000000..c45856c --- /dev/null +++ b/emmental/db/upgrade-v2.sql @@ -0,0 +1,3 @@ +/* Copyright 2023 (c) Anna Schumaker */ + +PRAGMA user_version = 2; diff --git a/tests/db/test_db.py b/tests/db/test_db.py index 821073c..b855d09 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -11,8 +11,9 @@ class TestConnection(tests.util.TestCase): def test_paths(self): """Check that path constants are pointing to the right places.""" - script = pathlib.Path(emmental.db.__file__).parent / "emmental.sql" - self.assertEqual(emmental.db.SQL_SCRIPT, script) + dir = pathlib.Path(emmental.db.__file__).parent + self.assertEqual(emmental.db.SQL_V1_SCRIPT, dir / "emmental.sql") + self.assertEqual(emmental.db.SQL_V2_SCRIPT, dir / "upgrade-v2.sql") def test_connection(self): """Check that the connection manager is initialized properly.""" @@ -21,16 +22,16 @@ class TestConnection(tests.util.TestCase): def test_version(self): """Test checking the database schema version.""" cur = self.sql("PRAGMA user_version") - self.assertEqual(cur.fetchone()["user_version"], 1) + self.assertEqual(cur.fetchone()["user_version"], 2) def test_version_too_new(self): """Test failing when the database version is too new.""" self.sql._Connection__check_version() - self.sql("PRAGMA user_version = 2") + self.sql("PRAGMA user_version = 3") with self.assertRaises(Exception) as e: self.sql._Connection__check_version() - self.assertEqual(str(e.exception), "Unsupported data version: 2") + self.assertEqual(str(e.exception), "Unsupported data version: 3") def test_close(self): """Check closing the connection."""