emmental/emmental/db/tracks.py

375 lines
15 KiB
Python

# Copyright 2022 (c) Anna Schumaker.
"""A custom Gio.ListModel for working with tracks."""
import datetime
import pathlib
import random
import sqlite3
from typing import Iterable
from gi.repository import GObject
from gi.repository import Gtk
from . import table
# Fraction of a Track's length that the play time must exceed for
# Track.stop() to report the track as played.
PLAYED_THRESHOLD = 2 / 3
class Track(table.Row):
"""Our custom Track object."""
trackid = GObject.Property(type=int)
libraryid = GObject.Property(type=int)
mediumid = GObject.Property(type=int)
year = GObject.Property(type=int)
active = GObject.Property(type=bool, default=False)
favorite = GObject.Property(type=bool, default=False)
path = GObject.Property(type=GObject.TYPE_PYOBJECT)
mbid = GObject.Property(type=str)
title = GObject.Property(type=str)
artist = GObject.Property(type=str)
number = GObject.Property(type=int)
length = GObject.Property(type=float)
mtime = GObject.Property(type=float)
playcount = GObject.Property(type=int)
added = GObject.Property(type=GObject.TYPE_PYOBJECT)
laststarted = GObject.Property(type=GObject.TYPE_PYOBJECT)
lastplayed = GObject.Property(type=GObject.TYPE_PYOBJECT)
restarted = GObject.Property(type=GObject.TYPE_PYOBJECT)
def do_update(self, column: str) -> bool:
"""Update a Track object."""
match column:
case "trackid" | "libraryid" | "active" | "path" | "playcount" | \
"laststarted" | "lastplayed" | "restarted": pass
case _: return super().do_update(column)
return True
def get_artists(self) -> list[table.Row]:
"""Get a list of Artists for this Track."""
return self.table.get_artists(self)
def get_genres(self) -> list[table.Row]:
"""Get a list of Genres for this Track."""
return self.table.get_genres(self)
def get_library(self) -> table.Row | None:
"""Get the Library associated with this Track."""
return self.table.sql.libraries.rows.get(self.libraryid)
def get_medium(self) -> table.Row | None:
"""Get the Medium associated with this Track."""
return self.table.sql.media.rows.get(self.mediumid)
def get_year(self) -> table.Row | None:
"""Get the Year associated with this Track."""
return self.table.sql.years.rows.get(self.year)
def restart(self) -> None:
"""Mark that a previously started track has been started again."""
self.table.restart_track(self)
def start(self) -> None:
"""Mark that this track has started playback."""
self.table.start_track(self)
def stop(self, play_time: float) -> None:
"""Mark that this track has stopped playback."""
self.table.stop_track(self, play_time / self.length > PLAYED_THRESHOLD)
def update_properties(self, **kwargs) -> None:
"""Update one or more of this Track's properties."""
for (property, newval) in kwargs.items():
if self.get_property(property) != newval:
self.set_property(property, newval)
@property
def primary_key(self) -> int:
"""Get the primary key for this Track."""
return self.trackid
class Filter(table.KeySet):
    """A KeySet filter that never reports Gtk.FilterMatch.ALL strictness."""

    def do_get_strictness(self) -> Gtk.FilterMatch:
        """Return NONE when the filter holds no keys, and SOME otherwise."""
        has_keys = self.n_keys != 0
        return Gtk.FilterMatch.SOME if has_keys else Gtk.FilterMatch.NONE
class Table(table.Table):
"""A ListStore tailored for storing Track objects."""
have_current_track = GObject.Property(type=bool, default=False)
current_track = GObject.Property(type=Track)
current_favorite = GObject.Property(type=bool, default=False)
def __init__(self, sql: GObject.TYPE_PYOBJECT):
"""Initialize a Track Table."""
super().__init__(sql, filter=Filter())
self.set_model(None)
self.connect("notify::current-track", self.__notify_current_track)
self.connect("notify::current-favorite",
self.__notify_current_favorite)
def __notify_current_track(self, table: table.Table, param) -> None:
if self.current_track is not None:
self.have_current_track = True
self.current_favorite = self.current_track.favorite
self.sql.playlists.previous.add_track(self.current_track)
else:
self.have_current_track = False
self.current_favorite = False
def __notify_current_favorite(self, table: table.Table, param) -> None:
if self.current_track is not None:
self.current_track.update_properties(
favorite=self.current_favorite)
elif self.current_favorite is True:
self.current_favorite = False
def do_construct(self, **kwargs) -> Track:
"""Construct a new Track instance."""
if (track := Track(**kwargs)).active:
self.current_track = track
return track
def do_sql_delete(self, track: Track) -> sqlite3.Cursor:
"""Delete a Track."""
return self.sql("DELETE FROM tracks WHERE trackid=?", track.trackid)
def do_sql_glob(self, glob: str) -> sqlite3.Cursor:
"""Filter the Track table."""
return self.sql("""SELECT trackid FROM track_info_view WHERE
CASEFOLD(title) GLOB :glob
OR CASEFOLD(artist) GLOB :glob
OR CASEFOLD(album) GLOB :glob
OR CASEFOLD(albumartist) GLOB :glob
OR CASEFOLD(medium) GLOB :glob
OR release GLOB :glob""", glob=glob)
def do_sql_insert(self, library: table.Row, path: pathlib.Path,
medium: table.Row, year: table.Row, *, title: str = "",
number: int = 0, length: float = 0.0, artist: str = "",
mbid: str = "", mtime: float = 0.0) -> sqlite3.Cursor:
"""Insert a new Track into the database."""
if cur := self.sql("""INSERT INTO tracks
(libraryid, mediumid, path, year, title,
number, length, artist, mbid, mtime)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
RETURNING *""",
library.libraryid, medium.mediumid, path, year.year,
title, number, length, artist, mbid, mtime):
return self.sql("SELECT * FROM tracks WHERE trackid=?",
cur.lastrowid)
def do_sql_select_all(self) -> sqlite3.Cursor:
"""Load Tracks from the database."""
return self.sql("SELECT * FROM tracks")
def do_sql_select_one(self, library: table.Row = None,
*, path: pathlib.Path = None,
mbid: str = None) -> sqlite3.Cursor:
"""Look up a Track in the database."""
if path is None and mbid is None:
raise KeyError("Either 'path' or 'mbid' are required")
args = [("libraryid=?", library.libraryid if library else None),
("path=?", path), ("mbid=?", mbid)]
(where, args) = tuple(zip(*[arg for arg in args if None not in arg]))
sql_where = " AND ".join(where)
return self.sql(f"SELECT trackid FROM tracks WHERE {sql_where}", *args)
def do_sql_update(self, track: Track, column: str,
newval: any) -> sqlite3.Cursor:
"""Update a Track."""
match (column, newval):
case ("favorite", True):
self.sql.playlists.favorites.add_track(track)
if track == self.current_track:
self.current_favorite = True
case ("favorite", False):
self.sql.playlists.favorites.remove_track(track)
if track == self.current_track:
self.current_favorite = False
return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?",
newval, track.trackid)
def delete_listens(self, listenids: list[int]) -> None:
"""Delete the listens indicated by the provided listenids."""
self.sql.executemany("""DELETE FROM listenbrainz_queue
WHERE listenid=?""",
*[(id,) for id in listenids])
def get_artists(self, track: Track) -> list[table.Row]:
"""Get the set of Artists for a specific Track."""
rows = self.sql("""SELECT artistid FROM artist_tracks_view
WHERE trackid=?""", track.trackid).fetchall()
return [self.sql.artists.rows.get(row["artistid"]) for row in rows]
def get_genres(self, track: Track) -> list[int]:
"""Get the list of Genres for a specific Track."""
rows = self.sql("""SELECT genreid FROM genre_tracks_view
WHERE trackid=?""", track.trackid).fetchall()
return [self.sql.genres.rows.get(row["genreid"]) for row in rows]
def get_n_listens(self, n: int) -> list[tuple]:
"""Get the n most recent listens from the listenbrainz queue."""
cur = self.sql("""SELECT listenid, trackid, timestamp
FROM listenbrainz_queue ORDER BY timestamp DESC
LIMIT ?""", n)
return [(row["listenid"], self.rows[row["trackid"]], row["timestamp"])
for row in cur.fetchall()]
def map_sort_order(self, ordering: str) -> dict[int, int]:
"""Get a lookup table for Track sort keys."""
ordering = ordering if len(ordering) > 0 else "trackid"
cur = self.sql(f"""SELECT trackid FROM track_info_view
ORDER BY {ordering}""")
return {row["trackid"]: i for (i, row) in enumerate(cur.fetchall())}
def mark_path_active(self, path: pathlib.Path) -> None:
"""Mark a specific track as active in the database.."""
if self.sql("UPDATE tracks SET active=TRUE WHERE path=?",
path).rowcount == 0:
self.sql("UPDATE tracks SET active=FALSE WHERE active=TRUE")
def restart_track(self, track: Track) -> None:
"""Mark that a Track has been restarted."""
track.active = True
track.restarted = datetime.datetime.utcnow()
self.current_track = track
def start_track(self, track: Track) -> None:
"""Mark that a Track has been started."""
self.sql.playlists.previous.remove_track(track)
cur = self.sql("""UPDATE tracks SET active=TRUE, laststarted=?
WHERE trackid=? RETURNING laststarted""",
datetime.datetime.utcnow(), track.trackid)
track.active = True
track.laststarted = cur.fetchone()["laststarted"]
self.current_track = track
self.sql.commit()
def stop_track(self, track: Track, played: bool) -> None:
"""Mark that a Track has been stopped."""
args = [("active=?", False)]
if played:
if track.restarted is not None:
track.laststarted = track.restarted
args.append(("laststarted=?", track.restarted))
args.append(("lastplayed=?", track.laststarted))
args.append(("playcount=?", track.playcount + 1))
(fields, vals) = tuple(zip(*args))
update = ", ".join(fields)
row = self.sql(f"""UPDATE tracks SET {update} WHERE trackid=?
RETURNING lastplayed, playcount""",
*vals, track.trackid).fetchone()
track.active = False
track.playcount = row["playcount"]
track.lastplayed = row["lastplayed"]
track.restarted = None
self.current_track = None
if played:
self.sql.playlists.most_played.reload_tracks(idle=True)
self.sql.playlists.queued.remove_track(track)
self.sql.playlists.unplayed.remove_track(track)
self.emit("track-played", track)
self.sql.commit()
@GObject.Signal(arg_types=(Track,))
def track_played(self, track: Track) -> None:
"""Signal that a Track was played."""
if track is not None:
self.sql("""INSERT INTO listenbrainz_queue (trackid, timestamp)
VALUES (?, ?)""", track.trackid, track.lastplayed)
class TrackidSet(GObject.GObject):
    """Manage a set of Track IDs.

    Individual changes emit "trackid-added" or "trackid-removed".  Replacing
    the contents with a disjoint set emits a single "trackids-reset".
    """

    n_trackids = GObject.Property(type=int)

    def __init__(self, trackids: Iterable[int] = ()):
        """Initialize a TrackidSet from an optional iterable of trackids."""
        super().__init__()
        self.__trackids = set(trackids)
        self.n_trackids = len(self.__trackids)

    def __contains__(self, track: Track) -> bool:
        """Check if a Track is in the set."""
        return track.trackid in self.__trackids

    def __len__(self) -> int:
        """Find the number of Tracks in the set."""
        return len(self.__trackids)

    def __sub__(self, rhs: "TrackidSet") -> "TrackidSet":
        """Return a new TrackidSet with rhs's trackids removed."""
        return TrackidSet(self.__trackids - rhs.trackids)

    def add_track(self, track: Track) -> None:
        """Add a Track to the set, emitting "trackid-added" if it is new."""
        if track.trackid not in self.__trackids:
            self.__trackids.add(track.trackid)
            self.n_trackids = len(self)
            self.emit("trackid-added", track.trackid)

    def random_trackid(self) -> int | None:
        """Get a random trackid from the set, or None if it is empty."""
        if self.__trackids:
            return random.choice(list(self.__trackids))
        return None

    def remove_track(self, track: Track) -> None:
        """Remove a Track from the set, emitting "trackid-removed" if found."""
        if track.trackid in self.__trackids:
            self.__trackids.discard(track.trackid)
            self.n_trackids = len(self)
            self.emit("trackid-removed", track.trackid)

    @property
    def trackids(self) -> set:
        """Get the set of trackids."""
        return self.__trackids

    @trackids.setter
    def trackids(self, trackids: Iterable[int]) -> None:
        """Replace the contents of the set with the given trackids.

        If the old and new sets share no trackids, a single "trackids-reset"
        is emitted.  Otherwise one "trackid-removed" is emitted per dropped
        id and one "trackid-added" per new id, after the set and n-trackids
        have been updated.
        """
        new_trackids = set(trackids)
        if self.__trackids.isdisjoint(new_trackids):
            self.__trackids = new_trackids
            self.n_trackids = len(self)
            self.emit("trackids-reset")
        else:
            removed = self.__trackids - new_trackids
            added = new_trackids - self.__trackids
            self.__trackids = new_trackids
            self.n_trackids = len(self)
            for trackid in removed:
                self.emit("trackid-removed", trackid)
            for trackid in added:
                self.emit("trackid-added", trackid)

    @GObject.Signal(arg_types=(int,))
    def trackid_added(self, trackid: int) -> None:
        """Signal that a Track has been added to the set."""

    @GObject.Signal(arg_types=(int,))
    def trackid_removed(self, trackid: int) -> None:
        """Signal that a Track has been removed from the set."""

    @GObject.Signal
    def trackids_reset(self) -> None:
        """Signal that the Tracks in the set have been reset."""