# Copyright 2022 (c) Anna Schumaker
"""Test our custom db Connection object."""
import pathlib
import sqlite3
import unittest
# `import unittest` does not load the mock submodule; import it explicitly
# because the tests below use unittest.mock.patch and unittest.mock.Mock.
import unittest.mock

import emmental.db.connection

from gi.repository import GObject


class TestConnection(unittest.TestCase):
    """Test case for our database connection manager."""

    def setUp(self):
        """Set up common variables."""
        self.sql = emmental.db.connection.Connection()

    def tearDown(self):
        """Clean up."""
        self.sql.close()

    def test_paths(self):
        """Check that path constants are pointing to the right places."""
        # NOTE(review): emmental.gsetup is not imported in this file; this
        # relies on it being loaded as a side effect of importing
        # emmental.db.connection — confirm, or import it explicitly.
        self.assertEqual(emmental.db.connection.DATA_FILE,
                         emmental.gsetup.DATA_DIR / "emmental-debug.sqlite3")
        self.assertEqual(emmental.db.connection.DATABASE, ":memory:")

    def test_connection(self):
        """Check that the connection manager is initialized properly."""
        self.assertIsInstance(self.sql, GObject.GObject)
        self.assertIsInstance(self.sql._sql, sqlite3.Connection)
        self.assertEqual(self.sql._sql.row_factory, sqlite3.Row)
        self.assertTrue(self.sql.connected)

    def test_foreign_keys(self):
        """Test that foreign keys are enabled."""
        cur = self.sql("PRAGMA foreign_keys")
        self.assertEqual(cur.fetchone()["foreign_keys"], 1)

    def test_call(self):
        """Check that we can call the connection to execute a sql statement."""
        self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
        # Positional arguments are bound to "?" placeholders.
        self.sql("INSERT INTO test VALUES (?, ?)", 1, 2)

        cur = self.sql("SELECT * FROM test")
        self.assertIsInstance(cur, sqlite3.Cursor)
        row = cur.fetchone()
        self.assertIsInstance(row, sqlite3.Row)
        self.assertEqual(row["a"], 1)
        self.assertEqual(row["b"], 2)

    def test_call_keyword(self):
        """Test running a sql statement with keyword arguments."""
        self.sql("CREATE TABLE test (a INT UNIQUE, b INT)")
        # Keyword arguments are bound to ":name" placeholders.
        self.sql("INSERT INTO test VALUES (:a, :b)", a=1, b=2)

        cur = self.sql("SELECT * FROM test")
        self.assertIsInstance(cur, sqlite3.Cursor)
        row = cur.fetchone()
        self.assertIsInstance(row, sqlite3.Row)
        self.assertEqual(row["a"], 1)
        self.assertEqual(row["b"], 2)

    def test_casefold(self):
        """Test the casefold function."""
        self.sql("CREATE TABLE test (a INT, text TEXT)")
        self.sql("INSERT INTO test VALUES (?, ?)", 1, "TEST")
        self.sql("INSERT INTO test VALUES (?, ?)", 2, None)

        # ORDER BY is required: SQLite does not guarantee row order otherwise,
        # and the assertions below index rows positionally.
        rows = self.sql("SELECT CASEFOLD(text) as text "
                        "FROM test ORDER BY a").fetchall()
        self.assertEqual(rows[0]["text"], "test")
        # NULL input must pass through CASEFOLD unchanged.
        self.assertIsNone(rows[1]["text"])

    def test_executemany(self):
        """Test the executemany function."""
        self.sql("CREATE TABLE test (a INT, b TEXT)")
        # Each positional argument is one parameter tuple for the statement.
        self.sql.executemany("INSERT INTO test VALUES (?, ?)",
                             (1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e"))

        rows = self.sql("SELECT * FROM test ORDER BY a").fetchall()
        # Compare the complete result so that missing or extra rows fail too.
        self.assertListEqual([tuple(row) for row in rows],
                             [(1, "a"), (2, "b"), (3, "c"),
                              (4, "d"), (5, "e")])

    @unittest.mock.patch("emmental.db.connection.Connection.commit")
    def test_executescript(self, mock_commit: unittest.mock.Mock):
        """Test the executescript function."""
        script = pathlib.Path(__file__).parent / "test-script.sql"
        cur = self.sql.executescript(script)
        self.assertIsInstance(cur, sqlite3.Cursor)
        mock_commit.assert_called()

        rows = self.sql("SELECT * FROM test").fetchall()
        self.assertListEqual([(row["a"], row["b"]) for row in rows],
                             [(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)])

        # A path that does not exist yields None rather than raising.
        self.assertIsNone(self.sql.executescript(script.parent / "no-script"))

    def test_path_column(self):
        """Test that the PATH column type has been set up."""
        self.sql("CREATE TABLE test (path PATH)")
        self.sql("INSERT INTO test VALUES (?)", pathlib.Path("/my/test/path"))
        row = self.sql("SELECT path FROM test").fetchone()
        # Values in a PATH column are converted back to pathlib.Path on read.
        self.assertIsInstance(row["path"], pathlib.Path)
        self.assertEqual(row["path"], pathlib.Path("/my/test/path"))

    def test_transaction(self):
        """Test that we can manually start a transaction."""
        self.assertFalse(self.sql._sql.in_transaction)
        with self.sql:
            # Entering the context manager opens a transaction immediately,
            # before any statement has been executed.
            self.assertTrue(self.sql._sql.in_transaction)
            self.sql("CREATE TABLE test_table (test TEXT)")
            self.sql("INSERT INTO test_table VALUES (?)", "Test")
            self.assertTrue(self.sql._sql.in_transaction)
        # Leaving the context without an error ends the transaction and keeps
        # the changes.
        self.assertFalse(self.sql._sql.in_transaction)

        cur = self.sql("SELECT COUNT(*) FROM test_table")
        self.assertEqual(cur.fetchone()["COUNT(*)"], 1)

    def test_transaction_rollback(self):
        """Test that errors roll back the transaction."""
        # Raise and expect a specific exception type, so that a failure in
        # the SQL statements themselves cannot make this test pass.
        with self.assertRaises(ValueError):
            with self.sql:
                self.sql("CREATE TABLE other_table (test TEXT)")
                self.sql("INSERT INTO other_table VALUES (?)", "Test")
                raise ValueError("Test Exeption")

        self.assertFalse(self.sql._sql.in_transaction)
        # The CREATE TABLE was rolled back, so the table must not exist.
        with self.assertRaises(sqlite3.OperationalError):
            self.sql("SELECT COUNT(*) FROM other_table")

    def test_close(self):
        """Check closing the connection."""
        self.sql.close()
        self.assertFalse(self.sql.connected)

        # Any statement run on a closed connection must be rejected.
        with self.assertRaises(sqlite3.ProgrammingError):
            self.sql("SELECT COUNT(*) FROM test_table")

        # Closing an already-closed connection should not raise.
        self.sql.close()