sqlite: Give the Connection an executescript() function

This function builds on the built-in Python function, and adds in
opening and reading the file in a way that can be used for running
generic scripts on the database.

I also take this chance to move SQL scripts into a subdirectory to keep
them together.

Signed-off-by: Anna Schumaker <anna@nowheycreamery.com>
This commit is contained in:
Anna Schumaker 2023-07-25 09:54:31 -04:00
parent 14b848bddd
commit 3f61adc941
4 changed files with 34 additions and 6 deletions

7
tests/test-script.sql Normal file
View File

@ -0,0 +1,7 @@
/* Copyright 2023 (c) Anna Schumaker */
CREATE TABLE test (a INT, b INT);
INSERT INTO test VALUES (1, 2);
INSERT INTO test VALUES (3, 4);
INSERT INTO test VALUES (5, 6);
INSERT INTO test VALUES (7, 8);
INSERT INTO test VALUES (9, 0);

View File

@ -22,8 +22,10 @@ class TestConnection(unittest.TestCase):
data_dir / "xfstestsdb-debug.sqlite3")
self.assertEqual(xfstestsdb.sqlite.DATABASE, ":memory:")
script = pathlib.Path(xfstestsdb.__file__).parent / "xfstestsdb.sql"
self.assertEqual(xfstestsdb.sqlite.SQL_SCRIPT, script)
self.assertEqual(xfstestsdb.sqlite.SQL_SCRIPTS,
pathlib.Path(xfstestsdb.__file__).parent / "scripts")
self.assertEqual(xfstestsdb.sqlite.SQL_V1_SCRIPT,
xfstestsdb.sqlite.SQL_SCRIPTS / "xfstestsdb.sql")
def test_foreign_keys(self):
"""Test that foreign key constraints are enabled."""
@ -72,6 +74,17 @@ class TestConnection(unittest.TestCase):
self.assertListEqual([(row["a"], row["b"]) for row in rows],
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
def test_executescript(self):
"""Test running a sql script."""
script = pathlib.Path(__file__).parent / "test-script.sql"
cur = self.sql.executescript(script)
self.assertIsInstance(cur, sqlite3.Cursor)
rows = self.sql("SELECT * FROM test").fetchall()
self.assertListEqual([(row["a"], row["b"]) for row in rows],
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])
self.assertIsNone(self.sql.executescript(script.parent / "no-script"))
def test_transaction(self):
"""Test that we can manually start a transaction."""
self.assertFalse(self.sql.sql.in_transaction)

View File

@ -9,7 +9,9 @@ import xdg.BaseDirectory
DATA_DIR = pathlib.Path(xdg.BaseDirectory.save_data_path("xfstestsdb"))
DATA_FILE = DATA_DIR / f"xfstestsdb{'-debug' if __debug__ else ''}.sqlite3"
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE
SQL_SCRIPT = pathlib.Path(__file__).parent / "xfstestsdb.sql"
SQL_SCRIPTS = pathlib.Path(__file__).parent / "scripts"
SQL_V1_SCRIPT = SQL_SCRIPTS / "xfstestsdb.sql"
class Connection:
@ -25,9 +27,7 @@ class Connection:
self("PRAGMA foreign_keys = ON")
match self("PRAGMA user_version").fetchone()["user_version"]:
case 0:
with open(SQL_SCRIPT) as f:
self.sql.executescript(f.read())
self.sql.commit()
self.executescript(SQL_V1_SCRIPT)
def __call__(self, statement: str,
*args, **kwargs) -> sqlite3.Cursor | None:
@ -64,3 +64,11 @@ class Connection:
return self.sql.executemany(statement, args)
except sqlite3.IntegrityError:
return None
def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
"""Execute a SQL script."""
if script.is_file():
with open(script) as f:
cur = self.sql.executescript(f.read())
self.sql.commit()
return cur