emmental/emmental/db/tracks.py

155 lines
6.1 KiB
Python

# Copyright 2022 (c) Anna Schumaker.
"""A custom Gio.ListModel for working with tracks."""
import pathlib
import sqlite3
from gi.repository import GObject
from gi.repository import Gtk
from . import table
# NOTE(review): not referenced anywhere in this chunk. The name suggests the
# fraction of a track's length that must be played before it counts as
# "played" — confirm against its users before relying on that reading.
PLAYED_THRESHOLD = 2 / 3
class Track(table.Row):
"""Our custom Track object."""
trackid = GObject.Property(type=int)
libraryid = GObject.Property(type=int)
mediumid = GObject.Property(type=int)
year = GObject.Property(type=int)
active = GObject.Property(type=bool, default=False)
favorite = GObject.Property(type=bool, default=False)
path = GObject.Property(type=GObject.TYPE_PYOBJECT)
mbid = GObject.Property(type=str)
title = GObject.Property(type=str)
artist = GObject.Property(type=str)
number = GObject.Property(type=int)
length = GObject.Property(type=float)
mtime = GObject.Property(type=float)
playcount = GObject.Property(type=int)
added = GObject.Property(type=GObject.TYPE_PYOBJECT)
laststarted = GObject.Property(type=GObject.TYPE_PYOBJECT)
lastplayed = GObject.Property(type=GObject.TYPE_PYOBJECT)
restarted = GObject.Property(type=GObject.TYPE_PYOBJECT)
def do_update(self, column: str) -> bool:
"""Update a Track object."""
match column:
case "trackid" | "libraryid" | "active" | "path" | "playcount" | \
"laststarted" | "lastplayed" | "restarted": pass
case _: return super().do_update(column)
return True
def get_library(self) -> table.Row | None:
"""Get the Library associated with this Track."""
return self.table.sql.libraries.rows.get(self.libraryid)
def get_medium(self) -> table.Row | None:
"""Get the Medium associated with this Track."""
return self.table.sql.media.rows.get(self.mediumid)
def get_year(self) -> table.Row | None:
"""Get the Year associated with this Track."""
return self.table.sql.years.rows.get(self.year)
def update_properties(self, **kwargs) -> None:
"""Update one or more of this Track's properties."""
for (property, newval) in kwargs.items():
if self.get_property(property) != newval:
self.set_property(property, newval)
@property
def primary_key(self) -> int:
"""Get the primary key for this Track."""
return self.trackid
class Filter(table.Filter):
    """A Track filter whose strictness is never reported as FilterMatch.ALL."""

    def do_get_strictness(self) -> Gtk.FilterMatch:
        """Report how strictly this filter matches.

        :returns: Gtk.FilterMatch.NONE when n_keys is 0, and
                  Gtk.FilterMatch.SOME for any other n_keys value.
        """
        no_keys = self.n_keys == 0
        return Gtk.FilterMatch.NONE if no_keys else Gtk.FilterMatch.SOME
class Table(table.Table):
"""A ListStore tailored for storing Track objects."""
def __init__(self, sql: GObject.TYPE_PYOBJECT):
"""Initialize a Track Table."""
super().__init__(sql, filter=Filter())
self.set_model(None)
def do_construct(self, **kwargs) -> Track:
"""Construct a new Track instance."""
return Track(**kwargs)
def do_sql_delete(self, track: Track) -> sqlite3.Cursor:
"""Delete a Track."""
return self.sql("DELETE FROM tracks WHERE trackid=?", track.trackid)
def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
"""Filter the Track table."""
return self.sql("""SELECT trackid FROM track_info_view WHERE
CASEFOLD(title) GLOB :glob
OR CASEFOLD(artist) GLOB :glob
OR CASEFOLD(album) GLOB :glob
OR CASEFOLD(albumartist) GLOB :glob
OR CASEFOLD(medium) GLOB :glob
OR release GLOB :glob""", glob=glob)
def do_sql_insert(self, library: table.Row, path: pathlib.Path,
medium: table.Row, year: table.Row, *, title: str = "",
number: int = 0, length: float = 0.0, artist: str = "",
mbid: str = "", mtime: float = 0.0) -> sqlite3.Cursor:
"""Insert a new Track into the database."""
return self.sql("""INSERT INTO tracks
(libraryid, mediumid, path, year, title,
number, length, artist, mbid, mtime)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
RETURNING *""",
library.libraryid, medium.mediumid, path, year.year,
title, number, length, artist, mbid, mtime)
def do_sql_select_all(self) -> sqlite3.Cursor:
"""Load Tracks from the database."""
return self.sql("SELECT * FROM tracks")
def do_sql_select_one(self, library: table.Row = None,
*, path: pathlib.Path = None,
mbid: str = None) -> sqlite3.Cursor:
"""Look up a Track in the database."""
if path is None and mbid is None:
raise KeyError("Either 'path' or 'mbid' are required")
args = [("libraryid=?", library.libraryid if library else None),
("path=?", path), ("mbid=?", mbid)]
(where, args) = tuple(zip(*[arg for arg in args if None not in arg]))
sql_where = " AND ".join(where)
return self.sql(f"SELECT trackid FROM tracks WHERE {sql_where}", *args)
def do_sql_update(self, track: Track, column: str,
newval: any) -> sqlite3.Cursor:
"""Update a Track."""
match (column, newval):
case ("favorite", True):
self.sql.playlists.favorites.add_track(track)
case ("favorite", False):
self.sql.playlists.favorites.remove_track(track)
return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?",
newval, track.trackid)
def map_sort_order(self, ordering: str) -> dict[int, int]:
"""Get a lookup table for Track sort keys."""
ordering = ordering if len(ordering) > 0 else "trackid"
cur = self.sql(f"""SELECT trackid FROM track_info_view
ORDER BY {ordering}""")
return {row["trackid"]: i for (i, row) in enumerate(cur.fetchall())}