96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
# Copyright 2022 (c) Anna Schumaker
|
|
"""Easily work with our underlying sqlite3 database."""
|
|
import pathlib
|
|
import sqlite3
|
|
import sys
|
|
from gi.repository import GObject
|
|
from .. import gsetup
|
|
|
|
|
|
# Location of the on-disk database. gsetup.DEBUG_STR is spliced into the
# file name (presumably to keep debug data separate -- confirm in gsetup).
DATA_FILE = gsetup.DATA_DIR / f"emmental{gsetup.DEBUG_STR}.sqlite3"

# When the unittest module has been imported, use a throwaway in-memory
# database instead of the file on disk.
if "unittest" in sys.modules:
    DATABASE = ":memory:"
else:
    DATABASE = DATA_FILE
|
|
|
|
|
|
def adapt_path(path: pathlib.Path) -> str:
    """Render a pathlib.Path as the text that sqlite3 should store.

    Registered below as a sqlite3 adapter, so it is invoked whenever a
    path object is passed as a statement parameter.
    """
    as_text = str(path)
    return as_text
|
|
|
|
|
|
def convert_path(path: bytes) -> pathlib.Path:
    """Rebuild a pathlib.Path from the raw bytes sqlite3 hands back.

    sqlite3 converters receive the column value as bytes, so the value is
    decoded (default codec, strict errors) before the Path is constructed.
    """
    decoded = path.decode()
    return pathlib.Path(decoded)
|
|
|
|
|
|
# Let sqlite3 round-trip filesystem paths: values read from columns declared
# with the type "path" pass through convert_path() (this needs a connection
# opened with PARSE_DECLTYPES), and PosixPath parameters are written using
# adapt_path().
sqlite3.register_converter("path", convert_path)
sqlite3.register_adapter(pathlib.PosixPath, adapt_path)
|
|
|
|
|
|
class Connection(GObject.GObject):
    """Wrap the sqlite3 connection to the application database.

    An instance can be called directly (``conn(statement, *args, **kwargs)``)
    to run one statement, and can be used as a context manager to group
    statements into a transaction that is committed when the ``with`` body
    finishes normally and rolled back when it raises.
    """

    # True until close() has shut down the underlying sqlite3 connection.
    connected = GObject.Property(type=bool, default=True)

    def __init__(self):
        """Open the database named by DATABASE and configure the connection."""
        super().__init__()
        self._sql = sqlite3.connect(DATABASE,
                                    detect_types=sqlite3.PARSE_DECLTYPES)
        # Expose str.casefold() to SQL as CASEFOLD(); falsy input (NULL or
        # the empty string) maps to NULL.
        self._sql.create_function("CASEFOLD", 1,
                                  lambda s: s.casefold() if s else None,
                                  deterministic=True)
        self._sql.row_factory = sqlite3.Row
        # Run the pragma through execute(). Calling the raw connection
        # object only prepares a statement, and whether a pragma takes
        # effect at prepare time depends on the SQLite release.
        self._sql.execute("PRAGMA foreign_keys = ON")

    def __call__(self, statement: str,
                 *args, **kwargs) -> sqlite3.Cursor | None:
        """Execute a single SQL statement.

        Positional arguments are passed as the parameter sequence; when
        there are none, the keyword arguments are passed as a mapping of
        named parameters instead.

        Returns the resulting cursor, or None if the statement raised
        sqlite3.IntegrityError.
        """
        try:
            return self._sql.execute(statement, args or kwargs)
        except sqlite3.IntegrityError:
            return None

    def __del__(self) -> None:
        """Close the connection when the object is garbage collected."""
        # _sql is missing if sqlite3.connect() raised inside __init__.
        if getattr(self, "_sql", None) is not None:
            self.close()

    def __enter__(self) -> None:
        """Begin a transaction unless one is already open."""
        # NOTE(review): the source's indentation was lost; BEGIN is assumed
        # to be guarded by the same check, because issuing BEGIN inside an
        # open transaction raises sqlite3.OperationalError.
        if not self._sql.in_transaction:
            self._sql.commit()
            self._sql.execute("BEGIN")

    def __exit__(self, exp_type, exp_value, traceback) -> bool:
        """Commit on a clean exit, roll back if the body raised.

        Returns False when an exception is in flight, so the exception
        still propagates to the caller after the rollback.
        """
        if exp_type is None:
            self._sql.commit()
        else:
            self._sql.rollback()
        return exp_type is None

    def close(self) -> None:
        """Commit, optimize, and close the connection (only once)."""
        if self.connected:
            try:
                self._sql.commit()
                self._sql.execute("PRAGMA optimize")
            finally:
                # Release the handle even if the final commit or optimize
                # failed, so the connection is never left half-closed.
                self._sql.close()
                self.connected = False

    def commit(self) -> None:
        """Commit pending changes."""
        self._sql.commit()

    def executemany(self, statement: str, *args) -> sqlite3.Cursor | None:
        """Execute one statement once per parameter set in ``args``.

        Each positional argument is one set of parameters for the statement.

        Returns the resulting cursor, or None if sqlite3 raised
        IntegrityError (matching __call__) or InternalError.
        """
        try:
            return self._sql.executemany(statement, args)
        except (sqlite3.IntegrityError, sqlite3.InternalError):
            return None

    def executescript(self, script: pathlib.Path) -> sqlite3.Cursor | None:
        """Execute the SQL script stored in the file ``script`` and commit.

        Returns the cursor produced by the script, or None when ``script``
        is not an existing regular file.
        """
        if not script.is_file():
            return None
        with open(script) as f:
            cursor = self._sql.executescript(f.read())
        self.commit()
        return cursor
|