emmental/emmental/db/connection.py

96 lines
3.0 KiB
Python

# Copyright 2022 (c) Anna Schumaker
"""Easily work with our underlying sqlite3 database."""
import pathlib
import sqlite3
import sys
from gi.repository import GObject
from .. import gsetup
# On-disk database file. Its name embeds gsetup.DEBUG_STR (presumably a
# debug-build suffix so debug runs use a separate file — confirm in gsetup).
DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3"

# When the unittest module has been imported (i.e. under the test suite),
# connect to a throwaway in-memory database instead of the real file.
if "unittest" in sys.modules:
    DATABASE = ":memory:"
else:
    DATABASE = DATA_FILE
def adapt_path(path: pathlib.Path) -> str:
    """Turn a pathlib.Path into the text value that sqlite3 stores.

    :param path: the path being bound as a query parameter.
    :returns: the path's string form.
    """
    return f"{path}"
def convert_path(path: bytes) -> pathlib.Path:
    """Build a pathlib.Path from the raw bytes of a "path" column.

    :param path: UTF-8 encoded column value handed over by sqlite3.
    :returns: the decoded value wrapped in a pathlib.Path.
    :raises UnicodeDecodeError: if the bytes are not valid UTF-8.
    """
    text = str(path, "utf-8")
    return pathlib.Path(text)
# Reading: columns declared with the "path" type are handed to convert_path
# (takes effect on connections opened with detect_types=PARSE_DECLTYPES).
sqlite3.register_converter("path", convert_path)
# Writing: PosixPath query parameters are stored via adapt_path.
sqlite3.register_adapter(pathlib.PosixPath, adapt_path)
class Connection(GObject.GObject):
    """Connect to the database.

    Call the instance to run one statement (``conn(sql, *params)``), or use
    it as a context manager (``with conn: ...``) to wrap several statements
    in one explicit transaction that commits when the body finishes and
    rolls back when it raises.
    """

    # Cleared by close(); lets close() (and therefore __del__) run safely
    # more than once.
    connected = GObject.Property(type=bool, default=True)

    def __init__(self):
        """Initialize a sqlite connection."""
        super().__init__()
        # PARSE_DECLTYPES makes sqlite3 apply registered converters (such as
        # the "path" converter) based on each column's declared type.
        self._sql = sqlite3.connect(DATABASE,
                                    detect_types=sqlite3.PARSE_DECLTYPES)
        # Expose CASEFOLD(text) to SQL; falsy input (NULL or '') yields NULL.
        self._sql.create_function("CASEFOLD", 1,
                                  lambda s: s.casefold() if s else None,
                                  deterministic=True)
        self._sql.row_factory = sqlite3.Row
        # sqlite3.Connection objects are not callable, so the pragma must go
        # through execute(). SQLite leaves foreign-key enforcement off unless
        # it is enabled per connection.
        self._sql.execute("PRAGMA foreign_keys = ON")

    def __call__(self, statement: str,
                 *args, **kwargs) -> sqlite3.Cursor | None:
        """Execute a SQL statement.

        Positional arguments are bound as the statement's parameters when any
        are given; otherwise the keyword arguments are bound as named
        parameters.

        Returns the resulting cursor, or None if the statement raised
        sqlite3.IntegrityError (e.g. a constraint violation).
        """
        try:
            return self._sql.execute(statement, args or kwargs)
        except sqlite3.IntegrityError:
            return None

    def __del__(self) -> None:
        """Clean up before exiting."""
        # If __init__ raised before the sqlite connection was opened there is
        # nothing to close, and close() would fail on the missing attribute.
        if hasattr(self, "_sql"):
            self.close()

    def __enter__(self) -> None:
        """Begin a transaction."""
        # sqlite3 may already have opened an implicit transaction for an
        # earlier data-modifying statement. SQLite rejects a nested BEGIN,
        # so commit that pending work first.
        if self._sql.in_transaction:
            self._sql.commit()
        self._sql.execute("BEGIN")

    def __exit__(self, exp_type, exp_value, traceback) -> bool:
        """Either commit or rollback an active transaction.

        Commits when the ``with`` body finished normally and rolls back when
        it raised. Returning False in the error case lets the exception
        propagate to the caller.
        """
        if exp_type is None:
            self._sql.commit()
        else:
            self._sql.rollback()
        return exp_type is None

    def close(self) -> None:
        """Close the database connection.

        Commits pending work and runs ``PRAGMA optimize`` before closing.
        Does nothing if the connection has already been closed.
        """
        if self.connected:
            self._sql.commit()
            self._sql.execute("PRAGMA optimize")
            self._sql.close()
            self.connected = False

    def commit(self) -> None:
        """Commit pending changes."""
        self._sql.commit()

    def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
        """Execute several similar SQL statements at once.

        Each positional argument is one set of parameters for ``statement``.
        Like __call__, returns None instead of raising when the batch hits
        sqlite3.IntegrityError.
        """
        try:
            return self._sql.executemany(statement, args)
        except sqlite3.IntegrityError:
            return None

    def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
        """Execute a SQL script.

        Reads ``script`` as UTF-8 text, runs it, and commits. Returns the
        cursor, or None when ``script`` is not an existing regular file.
        """
        if not script.is_file():
            return None
        # Explicit encoding: the default would depend on the user's locale.
        with open(script, encoding="utf-8") as f:
            sql = f.read()
        cur = self._sql.executescript(sql)
        self.commit()
        return cur