From e7526f595f7745456856e5b6c8fc81f283be3c39 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Wed, 18 Oct 2023 10:51:43 -0400 Subject: [PATCH] db: Give the db an executescript() function This is a wrapper function that takes a pathlib.Path object, reads it, and calls the sqlite3 executescript() function. I update the main db.Connection object to call this function to set up our database tables while I'm at it. Signed-off-by: Anna Schumaker --- emmental/db/__init__.py | 3 +-- emmental/db/connection.py | 8 ++++++++ tests/db/test-script.sql | 7 +++++++ tests/db/test_connection.py | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 tests/db/test-script.sql diff --git a/emmental/db/__init__.py b/emmental/db/__init__.py index 1f7fc92..b8a1a66 100644 --- a/emmental/db/__init__.py +++ b/emmental/db/__init__.py @@ -47,8 +47,7 @@ class Connection(connection.Connection): user_version = self("PRAGMA user_version").fetchone()["user_version"] match user_version: case 0: - with open(SQL_SCRIPT) as f: - self._sql.executescript(f.read()) + self.executescript(SQL_SCRIPT) case 1: pass case _: raise Exception(f"Unsupported data version: {user_version}") diff --git a/emmental/db/connection.py b/emmental/db/connection.py index 1920e36..e7464d6 100644 --- a/emmental/db/connection.py +++ b/emmental/db/connection.py @@ -85,3 +85,11 @@ class Connection(GObject.GObject): return self._sql.executemany(statement, args) except sqlite3.InternalError: return None + + def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None: + """Execute a SQL script.""" + if script.is_file(): + with open(script) as f: + cur = self._sql.executescript(f.read()) + self.commit() + return cur diff --git a/tests/db/test-script.sql b/tests/db/test-script.sql new file mode 100644 index 0000000..fa2f27d --- /dev/null +++ b/tests/db/test-script.sql @@ -0,0 +1,7 @@ +/* Copyright 2023 (c) Anna Schumaker */ +CREATE TABLE test (a INT, b INT); +INSERT INTO test VALUES (1, 2); +INSERT INTO test VALUES (3, 4); +INSERT INTO test VALUES (5, 6); +INSERT INTO test VALUES (7, 8); +INSERT INTO test VALUES (9, 0); diff --git a/tests/db/test_connection.py b/tests/db/test_connection.py index e0cc583..0136180 100644 --- a/tests/db/test_connection.py +++ b/tests/db/test_connection.py @@ -79,6 +79,20 @@ class TestConnection(unittest.TestCase): self.assertEqual(tuple(rows[3]), (4, "d")) self.assertEqual(tuple(rows[4]), (5, "e")) + @unittest.mock.patch("emmental.db.connection.Connection.commit") + def test_executescript(self, mock_commit: unittest.mock.Mock): + """Test the executescript function.""" + script = pathlib.Path(__file__).parent / "test-script.sql" + cur = self.sql.executescript(script) + self.assertIsInstance(cur, sqlite3.Cursor) + mock_commit.assert_called() + + rows = self.sql("SELECT * FROM test").fetchall() + self.assertListEqual([(row["a"], row["b"]) for row in rows], + [(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)]) + + self.assertIsNone(self.sql.executescript(script.parent / "no-script")) + def test_path_column(self): """Test that the PATH column type has been set up.""" self.sql("CREATE TABLE test (path PATH)")