# Copyright 2022 (c) Anna Schumaker. """A custom Gio.ListModel for working with tracks.""" import datetime import pathlib import random import sqlite3 from typing import Iterable from gi.repository import GObject from gi.repository import Gtk from . import table PLAYED_THRESHOLD = 2 / 3 class Track(table.Row): """Our custom Track object.""" trackid = GObject.Property(type=int) libraryid = GObject.Property(type=int) mediumid = GObject.Property(type=int) year = GObject.Property(type=int) active = GObject.Property(type=bool, default=False) favorite = GObject.Property(type=bool, default=False) path = GObject.Property(type=GObject.TYPE_PYOBJECT) mbid = GObject.Property(type=str) title = GObject.Property(type=str) artist = GObject.Property(type=str) number = GObject.Property(type=int) length = GObject.Property(type=float) mtime = GObject.Property(type=float) playcount = GObject.Property(type=int) added = GObject.Property(type=GObject.TYPE_PYOBJECT) laststarted = GObject.Property(type=GObject.TYPE_PYOBJECT) lastplayed = GObject.Property(type=GObject.TYPE_PYOBJECT) restarted = GObject.Property(type=GObject.TYPE_PYOBJECT) def do_update(self, column: str) -> bool: """Update a Track object.""" match column: case "trackid" | "libraryid" | "active" | "path" | "playcount" | \ "laststarted" | "lastplayed" | "restarted": pass case _: return super().do_update(column) return True def get_artists(self) -> list[table.Row]: """Get a list of Artists for this Track.""" return self.table.get_artists(self) def get_genres(self) -> list[table.Row]: """Get a list of Genres for this Track.""" return self.table.get_genres(self) def get_library(self) -> table.Row | None: """Get the Library associated with this Track.""" return self.table.sql.libraries.rows.get(self.libraryid) def get_medium(self) -> table.Row | None: """Get the Medium associated with this Track.""" return self.table.sql.media.rows.get(self.mediumid) def get_year(self) -> table.Row | None: """Get the Year associated with this Track.""" return self.table.sql.years.rows.get(self.year) def restart(self) -> None: """Mark that a previously started track has been started again.""" self.table.restart_track(self) def start(self) -> None: """Mark that this track has started playback.""" self.table.start_track(self) def stop(self, play_time: float) -> None: """Mark that this track has stopped playback.""" self.table.stop_track(self, play_time / self.length > PLAYED_THRESHOLD) def update_properties(self, **kwargs) -> None: """Update one or more of this Track's properties.""" for (property, newval) in kwargs.items(): if self.get_property(property) != newval: self.set_property(property, newval) @property def primary_key(self) -> int: """Get the primary key for this Track.""" return self.trackid class Filter(table.KeySet): """A customized Filter that never sets strictness to FilterMatch.All.""" def do_get_strictness(self) -> Gtk.FilterMatch: """Get the strictness of the filter.""" if self.n_keys == 0: return Gtk.FilterMatch.NONE return Gtk.FilterMatch.SOME class Table(table.Table): """A ListStore tailored for storing Track objects.""" have_current_track = GObject.Property(type=bool, default=False) current_track = GObject.Property(type=Track) current_favorite = GObject.Property(type=bool, default=False) def __init__(self, sql: GObject.TYPE_PYOBJECT): """Initialize a Track Table.""" super().__init__(sql, filter=Filter()) self.set_model(None) self.connect("notify::current-track", self.__notify_current_track) self.connect("notify::current-favorite", self.__notify_current_favorite) def __notify_current_track(self, table: table.Table, param) -> None: if self.current_track is not None: self.have_current_track = True self.current_favorite = self.current_track.favorite self.sql.playlists.previous.add_track(self.current_track) else: self.have_current_track = False self.current_favorite = False def __notify_current_favorite(self, table: table.Table, param) -> None: if self.current_track is not None: self.current_track.update_properties( favorite=self.current_favorite) elif self.current_favorite is True: self.current_favorite = False def do_construct(self, **kwargs) -> Track: """Construct a new Track instance.""" if (track := Track(**kwargs)).active: self.current_track = track return track def do_sql_delete(self, track: Track) -> sqlite3.Cursor: """Delete a Track.""" return self.sql("DELETE FROM tracks WHERE trackid=?", track.trackid) def do_sql_glob(self, glob: str) -> sqlite3.Cursor: """Filter the Track table.""" return self.sql("""SELECT trackid FROM track_info_view WHERE CASEFOLD(title) GLOB :glob OR CASEFOLD(artist) GLOB :glob OR CASEFOLD(album) GLOB :glob OR CASEFOLD(albumartist) GLOB :glob OR CASEFOLD(medium) GLOB :glob OR release GLOB :glob""", glob=glob) def do_sql_insert(self, library: table.Row, path: pathlib.Path, medium: table.Row, year: table.Row, *, title: str = "", number: int = 0, length: float = 0.0, artist: str = "", mbid: str = "", mtime: float = 0.0) -> sqlite3.Cursor: """Insert a new Track into the database.""" if cur := self.sql("""INSERT INTO tracks (libraryid, mediumid, path, year, title, number, length, artist, mbid, mtime) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING *""", library.libraryid, medium.mediumid, path, year.year, title, number, length, artist, mbid, mtime): return self.sql("SELECT * FROM tracks WHERE trackid=?", cur.lastrowid) def do_sql_select_all(self) -> sqlite3.Cursor: """Load Tracks from the database.""" return self.sql("SELECT * FROM tracks") def do_sql_select_one(self, library: table.Row = None, *, path: pathlib.Path = None, mbid: str = None) -> sqlite3.Cursor: """Look up a Track in the database.""" if path is None and mbid is None: raise KeyError("Either 'path' or 'mbid' are required") args = [("libraryid=?", library.libraryid if library else None), ("path=?", path), ("mbid=?", mbid)] (where, args) = tuple(zip(*[arg for arg in args if None not in arg])) sql_where = " AND ".join(where) return self.sql(f"SELECT trackid FROM tracks WHERE {sql_where}", *args) def do_sql_update(self, track: Track, column: str, newval: any) -> sqlite3.Cursor: """Update a Track.""" match (column, newval): case ("favorite", True): self.sql.playlists.favorites.add_track(track) if track == self.current_track: self.current_favorite = True case ("favorite", False): self.sql.playlists.favorites.remove_track(track) if track == self.current_track: self.current_favorite = False return self.sql(f"UPDATE tracks SET {column}=? WHERE trackid=?", newval, track.trackid) def get_artists(self, track: Track) -> list[table.Row]: """Get the set of Artists for a specific Track.""" rows = self.sql("""SELECT artistid FROM artist_tracks_view WHERE trackid=?""", track.trackid).fetchall() return [self.sql.artists.rows.get(row["artistid"]) for row in rows] def get_genres(self, track: Track) -> list[int]: """Get the list of Genres for a specific Track.""" rows = self.sql("""SELECT genreid FROM genre_tracks_view WHERE trackid=?""", track.trackid).fetchall() return [self.sql.genres.rows.get(row["genreid"]) for row in rows] def map_sort_order(self, ordering: str) -> dict[int, int]: """Get a lookup table for Track sort keys.""" ordering = ordering if len(ordering) > 0 else "trackid" cur = self.sql(f"""SELECT trackid FROM track_info_view ORDER BY {ordering}""") return {row["trackid"]: i for (i, row) in enumerate(cur.fetchall())} def mark_path_active(self, path: pathlib.Path) -> None: """Mark a specific track as active in the database..""" if self.sql("UPDATE tracks SET active=TRUE WHERE path=?", path).rowcount == 0: self.sql("UPDATE tracks SET active=FALSE WHERE active=TRUE") def restart_track(self, track: Track) -> None: """Mark that a Track has been restarted.""" track.active = True track.restarted = datetime.datetime.utcnow() self.current_track = track def start_track(self, track: Track) -> None: """Mark that a Track has been started.""" self.sql.playlists.previous.remove_track(track) cur = self.sql("""UPDATE tracks SET active=TRUE, laststarted=? WHERE trackid=? RETURNING laststarted""", datetime.datetime.utcnow(), track.trackid) track.active = True track.laststarted = cur.fetchone()["laststarted"] self.current_track = track self.sql.commit() def stop_track(self, track: Track, played: bool) -> None: """Mark that a Track has been stopped.""" args = [("active=?", False)] if played: if track.restarted is not None: track.laststarted = track.restarted args.append(("laststarted=?", track.restarted)) args.append(("lastplayed=?", track.laststarted)) args.append(("playcount=?", track.playcount + 1)) (fields, vals) = tuple(zip(*args)) update = ", ".join(fields) row = self.sql(f"""UPDATE tracks SET {update} WHERE trackid=? RETURNING lastplayed, playcount""", *vals, track.trackid).fetchone() track.active = False track.playcount = row["playcount"] track.lastplayed = row["lastplayed"] track.restarted = None self.current_track = None if played: self.sql.playlists.most_played.reload_tracks(idle=True) self.sql.playlists.queued.remove_track(track) self.sql.playlists.unplayed.remove_track(track) self.emit("track-played", track) self.sql.commit() @GObject.Signal(arg_types=(Track,)) def track_played(self, track: Track) -> None: """Signal that a Track was played.""" class TrackidSet(GObject.GObject): """Manage a set of Track IDs.""" n_trackids = GObject.Property(type=int) def __init__(self, trackids: Iterable[int] = []): """Initialize a TrackidSet.""" super().__init__() self.__trackids = set(trackids) self.n_trackids = len(self.__trackids) def __contains__(self, track: Track) -> bool: """Check if a Track is in the set.""" return track.trackid in self.__trackids def __len__(self) -> int: """Find the number of Tracks in the set.""" return len(self.__trackids) def __sub__(self, rhs): """Subtract two TrackidSets.""" return TrackidSet(self.__trackids - rhs.trackids) def add_track(self, track: Track) -> None: """Add a Track to the set.""" if track.trackid not in self.__trackids: self.__trackids.add(track.trackid) self.n_trackids = len(self) self.emit("trackid-added", track.trackid) def random_trackid(self) -> int | None: """Get a random trackid from the set.""" if len(self.__trackids) > 0: return random.choice(list(self.__trackids)) def remove_track(self, track: Track) -> None: """Remove a Track from the set.""" if track.trackid in self.__trackids: self.__trackids.discard(track.trackid) self.n_trackids = len(self) self.emit("trackid-removed", track.trackid) @property def trackids(self) -> set: """Get the set of trackids.""" return self.__trackids @trackids.setter def trackids(self, trackids: Iterable[int]) -> None: """Add several trackids to the set at one time.""" new_trackids = set(trackids) if self.__trackids.isdisjoint(new_trackids): self.__trackids = new_trackids self.n_trackids = len(self) self.emit("trackids-reset") else: removed = self.__trackids - new_trackids added = new_trackids - self.__trackids self.__trackids = new_trackids self.n_trackids = len(self) for id in removed: self.emit("trackid-removed", id) for id in added: self.emit("trackid-added", id) @GObject.Signal(arg_types=(int,)) def trackid_added(self, trackid: int) -> None: """Signal that a Track has been added to the set.""" @GObject.Signal(arg_types=(int,)) def trackid_removed(self, trackid: int) -> None: """Signal that a Track has been removed from the set.""" @GObject.Signal def trackids_reset(self) -> None: """Signal that the Tracks in the set have been reset."""