db: Give the db an executescript() function
This is a wrapper function that takes a pathlib.Path object, reads it, and calls the sqlite3 executescript() function. I update the main db.Connection object to call this function to set up our database tables while I'm at it. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
7d2ec00da7
commit
e7526f595f
|
@ -47,8 +47,7 @@ class Connection(connection.Connection):
|
|||
user_version = self("PRAGMA user_version").fetchone()["user_version"]
|
||||
match user_version:
|
||||
case 0:
|
||||
with open(SQL_SCRIPT) as f:
|
||||
self._sql.executescript(f.read())
|
||||
self.executescript(SQL_SCRIPT)
|
||||
case 1: pass
|
||||
case _:
|
||||
raise Exception(f"Unsupported data version: {user_version}")
|
||||
|
|
|
@ -85,3 +85,11 @@ class Connection(GObject.GObject):
|
|||
return self._sql.executemany(statement, args)
|
||||
except sqlite3.InternalError:
|
||||
return None
|
||||
|
||||
def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
|
||||
"""Execute a SQL script."""
|
||||
if script.is_file():
|
||||
with open(script) as f:
|
||||
cur = self._sql.executescript(f.read())
|
||||
self.commit()
|
||||
return cur
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
/* Copyright 2023 (c) Anna Schumaker */
|
||||
CREATE TABLE test (a INT, b INT);
|
||||
INSERT INTO test VALUES (1, 2);
|
||||
INSERT INTO test VALUES (3, 4);
|
||||
INSERT INTO test VALUES (5, 6);
|
||||
INSERT INTO test VALUES (7, 8);
|
||||
INSERT INTO test VALUES (9, 0);
|
|
@ -79,6 +79,20 @@ class TestConnection(unittest.TestCase):
|
|||
self.assertEqual(tuple(rows[3]), (4, "d"))
|
||||
self.assertEqual(tuple(rows[4]), (5, "e"))
|
||||
|
||||
@unittest.mock.patch("emmental.db.connection.Connection.commit")
|
||||
def test_executescript(self, mock_commit: unittest.mock.Mock):
|
||||
"""Test the executescript function."""
|
||||
script = pathlib.Path(__file__).parent / "test-script.sql"
|
||||
cur = self.sql.executescript(script)
|
||||
self.assertIsInstance(cur, sqlite3.Cursor)
|
||||
mock_commit.assert_called()
|
||||
|
||||
rows = self.sql("SELECT * FROM test").fetchall()
|
||||
self.assertListEqual([(row["a"], row["b"]) for row in rows],
|
||||
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
|
||||
|
||||
self.assertIsNone(self.sql.executescript(script.parent / "no-script"))
|
||||
|
||||
def test_path_column(self):
|
||||
"""Test that the PATH column type has been set up."""
|
||||
self.sql("CREATE TABLE test (path PATH)")
|
||||
|
|
Loading…
Reference in New Issue