# Copyright 2022 (c) Anna Schumaker
"""Easily work with our underlying sqlite3 database."""
import pathlib
import sqlite3
import sys
from gi.repository import GObject
from .. import gsetup

DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3"
# Whenever the unittest module has been imported, use an in-memory database
# instead of DATA_FILE so that tests never write to the real data file.
DATABASE = ":memory:" if "unittest" in sys.modules else DATA_FILE


def adapt_path(path: pathlib.Path) -> str:
    """Adapt a pathlib.Path into a sqlite3 string."""
    return str(path)


def convert_path(path: bytes) -> pathlib.Path:
    """Convert a path string (raw bytes from sqlite3) into a pathlib.Path."""
    return pathlib.Path(path.decode())


# Store PosixPath objects as text, and convert columns declared with the
# "path" type back into pathlib.Path objects when they are read.  The
# converter only runs because Connection enables PARSE_DECLTYPES.
sqlite3.register_adapter(pathlib.PosixPath, adapt_path)
sqlite3.register_converter("path", convert_path)


class Connection(GObject.GObject):
    """Connect to the database.

    Calling the object executes a single SQL statement.  Using it as a
    context manager runs the body inside an explicit transaction that is
    committed on success and rolled back on error.
    """

    # Set to False once close() has run.
    connected = GObject.Property(type=bool, default=True)

    def __init__(self):
        """Initialize a sqlite connection."""
        super().__init__()
        self._sql = sqlite3.connect(DATABASE,
                                    detect_types=sqlite3.PARSE_DECLTYPES)
        # SQL-callable helper; note that both NULL and '' map to NULL.
        self._sql.create_function("CASEFOLD", 1,
                                  lambda s: s.casefold() if s else None,
                                  deterministic=True)
        self._sql.row_factory = sqlite3.Row
        # Actually execute the pragma (calling the sqlite3.Connection object
        # directly does not run it) so foreign key constraints are enforced.
        self._sql.execute("PRAGMA foreign_keys = ON")

    def __call__(self, statement: str, *args,
                 **kwargs) -> sqlite3.Cursor | None:
        """Execute a SQL statement.

        Positional arguments are bound as the statement's parameters; when
        none are given, the keyword arguments are bound as named parameters.
        Returns the cursor, or None if sqlite3.IntegrityError was raised.
        """
        try:
            return self._sql.execute(statement, args if len(args) else kwargs)
        except sqlite3.IntegrityError:
            return None

    def __del__(self) -> None:
        """Clean up before exiting."""
        # __init__ may have failed before the sqlite connection was created.
        if hasattr(self, "_sql"):
            self.close()

    def __enter__(self) -> None:
        """Begin a transaction.

        A transaction that sqlite3 has already opened implicitly is committed
        first, because SQLite cannot run BEGIN inside an open transaction.
        """
        if self._sql.in_transaction:
            self._sql.commit()
        self._sql.execute("BEGIN")

    def __exit__(self, exp_type, exp_value, traceback) -> bool:
        """Either commit or rollback an active transaction.

        Returns True only when the body finished without an exception, so
        any exception raised inside the ``with`` block still propagates.
        """
        if exp_type is None:
            self._sql.commit()
        else:
            self._sql.rollback()
        return exp_type is None

    def close(self) -> None:
        """Close the database connection.

        Pending changes are committed and "PRAGMA optimize" is run before
        closing.  Calling this again after the first close does nothing.
        """
        if self.connected:
            self._sql.commit()
            self._sql.execute("PRAGMA optimize")
            self._sql.close()
            self.connected = False

    def commit(self) -> None:
        """Commit pending changes."""
        self._sql.commit()

    def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
        """Execute several similar SQL statements at once.

        Each positional argument is one parameter set for ``statement``.
        Returns the cursor, or None if sqlite3.InternalError was raised.
        """
        # NOTE(review): __call__ swallows sqlite3.IntegrityError, but this
        # swallows sqlite3.InternalError -- confirm which one is intended.
        try:
            return self._sql.executemany(statement, args)
        except sqlite3.InternalError:
            return None

    def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
        """Execute a SQL script and commit the result.

        Returns the cursor, or None when ``script`` is not an existing file.
        """
        if not script.is_file():
            return None
        # Decode explicitly rather than with the locale's default encoding.
        cur = self._sql.executescript(script.read_text(encoding="utf-8"))
        self.commit()
        return cur