db: Give the db an executescript() function

This is a wrapper function that takes a pathlib.Path object, reads it,
and calls the sqlite3 executescript() function. I update the main
db.Connection object to call this function to set up our database tables
while I'm at it.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-10-18 10:51:43 -04:00
parent 7d2ec00da7
commit e7526f595f
4 changed files with 30 additions and 2 deletions

View File

@ -47,8 +47,7 @@ class Connection(connection.Connection):
user_version = self("PRAGMA user_version").fetchone()["user_version"]
match user_version:
case 0:
with open(SQL_SCRIPT) as f:
self._sql.executescript(f.read())
self.executescript(SQL_SCRIPT)
case 1: pass
case _:
raise Exception(f"Unsupported data version: {user_version}")

View File

@ -85,3 +85,11 @@ class Connection(GObject.GObject):
return self._sql.executemany(statement, args)
except sqlite3.InternalError:
return None
def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
"""Execute a SQL script."""
if script.is_file():
with open(script) as f:
cur = self._sql.executescript(f.read())
self.commit()
return cur

7
tests/db/test-script.sql Normal file
View File

@ -0,0 +1,7 @@
/* Copyright 2023 (c) Anna Schumaker */
CREATE TABLE test (a INT, b INT);
INSERT INTO test VALUES (1, 2);
INSERT INTO test VALUES (3, 4);
INSERT INTO test VALUES (5, 6);
INSERT INTO test VALUES (7, 8);
INSERT INTO test VALUES (9, 0);

View File

@ -79,6 +79,20 @@ class TestConnection(unittest.TestCase):
self.assertEqual(tuple(rows[3]), (4, "d"))
self.assertEqual(tuple(rows[4]), (5, "e"))
@unittest.mock.patch("emmental.db.connection.Connection.commit")
def test_executescript(self, mock_commit: unittest.mock.Mock):
"""Test the executescript function."""
script = pathlib.Path(__file__).parent / "test-script.sql"
cur = self.sql.executescript(script)
self.assertIsInstance(cur, sqlite3.Cursor)
mock_commit.assert_called()
rows = self.sql("SELECT * FROM test").fetchall()
self.assertListEqual([(row["a"], row["b"]) for row in rows],
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
self.assertIsNone(self.sql.executescript(script.parent / "no-script"))
def test_path_column(self):
"""Test that the PATH column type has been set up."""
self.sql("CREATE TABLE test (path PATH)")