db: Upgrade the database version to 2

Prepare for database modifications. The first step is to bump the
database version, and it's cleaner to do that in a separate patch.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-10-18 11:50:29 -04:00
parent e7526f595f
commit 072264a77c
3 changed files with 16 additions and 8 deletions

View File

@ -18,7 +18,8 @@ from . import tracks
from . import years
SQL_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql"
SQL_V1_SCRIPT = pathlib.Path(__file__).parent / "emmental.sql"
SQL_V2_SCRIPT = pathlib.Path(__file__).parent / "upgrade-v2.sql"
class Connection(connection.Connection):
@ -47,8 +48,11 @@ class Connection(connection.Connection):
user_version = self("PRAGMA user_version").fetchone()["user_version"]
match user_version:
case 0:
self.executescript(SQL_SCRIPT)
case 1: pass
self.executescript(SQL_V1_SCRIPT)
self.executescript(SQL_V2_SCRIPT)
case 1:
self.executescript(SQL_V2_SCRIPT)
case 2: pass
case _:
raise Exception(f"Unsupported data version: {user_version}")

View File

@ -0,0 +1,3 @@
/* Copyright 2023 (c) Anna Schumaker */
PRAGMA user_version = 2;

View File

@ -11,8 +11,9 @@ class TestConnection(tests.util.TestCase):
def test_paths(self):
"""Check that path constants are pointing to the right places."""
script = pathlib.Path(emmental.db.__file__).parent / "emmental.sql"
self.assertEqual(emmental.db.SQL_SCRIPT, script)
dir = pathlib.Path(emmental.db.__file__).parent
self.assertEqual(emmental.db.SQL_V1_SCRIPT, dir / "emmental.sql")
self.assertEqual(emmental.db.SQL_V2_SCRIPT, dir / "upgrade-v2.sql")
def test_connection(self):
"""Check that the connection manager is initialized properly."""
@ -21,16 +22,16 @@ class TestConnection(tests.util.TestCase):
def test_version(self):
"""Test checking the database schema version."""
cur = self.sql("PRAGMA user_version")
self.assertEqual(cur.fetchone()["user_version"], 1)
self.assertEqual(cur.fetchone()["user_version"], 2)
def test_version_too_new(self):
"""Test failing when the database version is too new."""
self.sql._Connection__check_version()
self.sql("PRAGMA user_version = 2")
self.sql("PRAGMA user_version = 3")
with self.assertRaises(Exception) as e:
self.sql._Connection__check_version()
self.assertEqual(str(e.exception), "Unsupported data version: 2")
self.assertEqual(str(e.exception), "Unsupported data version: 3")
def test_close(self):
"""Check closing the connection."""