db: Add a TrackidSet

The TrackidSet is intened to be used by Playlists to keep track of which
Tracks have been added without much overhead.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-01-20 16:39:20 -05:00
parent ff835832c8
commit 11560d781e
2 changed files with 214 additions and 6 deletions

View File

@ -2,7 +2,9 @@
"""A custom Gio.ListModel for working with tracks.""" """A custom Gio.ListModel for working with tracks."""
import datetime import datetime
import pathlib import pathlib
import random
import sqlite3 import sqlite3
from typing import Iterable
from gi.repository import GObject from gi.repository import GObject
from gi.repository import Gtk from gi.repository import Gtk
from . import table from . import table
@ -247,3 +249,81 @@ class Table(table.Table):
self.sql.playlists.most_played.reload_tracks(idle=True) self.sql.playlists.most_played.reload_tracks(idle=True)
self.sql.playlists.queued.remove_track(track) self.sql.playlists.queued.remove_track(track)
self.sql.playlists.unplayed.remove_track(track) self.sql.playlists.unplayed.remove_track(track)
class TrackidSet(GObject.GObject):
"""Manage a set of Track IDs."""
n_trackids = GObject.Property(type=int)
def __init__(self, trackids: Iterable[int] = []):
"""Initialize a TrackidSet."""
super().__init__()
self.__trackids = set(trackids)
self.n_trackids = len(self.__trackids)
def __contains__(self, track: Track) -> bool:
"""Check if a Track is in the set."""
return track.trackid in self.__trackids
def __len__(self) -> int:
"""Find the number of Tracks in the set."""
return len(self.__trackids)
def __sub__(self, rhs):
"""Subtract two TrackidSets."""
return TrackidSet(self.__trackids - rhs.trackids)
def add_track(self, track: Track) -> None:
"""Add a Track to the set."""
if track.trackid not in self.__trackids:
self.__trackids.add(track.trackid)
self.n_trackids = len(self)
self.emit("trackid-added", track.trackid)
def random_trackid(self) -> int | None:
"""Get a random trackid from the set."""
if len(self.__trackids) > 0:
return random.choice(list(self.__trackids))
def remove_track(self, track: Track) -> None:
"""Remove a Track from the set."""
if track.trackid in self.__trackids:
self.__trackids.discard(track.trackid)
self.n_trackids = len(self)
self.emit("trackid-removed", track.trackid)
@property
def trackids(self) -> set:
"""Get the set of trackids."""
return self.__trackids
@trackids.setter
def trackids(self, trackids: Iterable[int]) -> None:
"""Add several trackids to the set at one time."""
new_trackids = set(trackids)
if self.__trackids.isdisjoint(new_trackids):
self.__trackids = new_trackids
self.n_trackids = len(self)
self.emit("trackids-reset")
else:
removed = self.__trackids - new_trackids
added = new_trackids - self.__trackids
self.__trackids = new_trackids
self.n_trackids = len(self)
for id in removed:
self.emit("trackid-removed", id)
for id in added:
self.emit("trackid-added", id)
@GObject.Signal(arg_types=(int,))
def trackid_added(self, trackid: int) -> None:
"""Signal that a Track has been added to the set."""
@GObject.Signal(arg_types=(int,))
def trackid_removed(self, trackid: int) -> None:
"""Signal that a Track has been removed from the set."""
@GObject.Signal
def trackids_reset(self) -> None:
"""Signal that the Tracks in the set have been reset."""

View File

@ -6,6 +6,7 @@ import unittest
import emmental.db.tracks import emmental.db.tracks
import tests.util import tests.util
import unittest.mock import unittest.mock
from gi.repository import GObject
from gi.repository import Gio from gi.repository import Gio
from gi.repository import Gtk from gi.repository import Gtk
@ -588,10 +589,137 @@ class TestTrackTable(tests.util.TestCase):
track.favorite = True track.favorite = True
self.assertTrue(self.tracks.current_favorite) self.assertTrue(self.tracks.current_favorite)
self.tracks.current_favorite = False
self.assertFalse(track.favorite)
self.tracks.current_favorite = True
self.assertTrue(track.favorite)
self.tracks.current_track = None class TestTrackIdSet(tests.util.TestCase):
self.assertFalse(self.tracks.current_favorite) """Test our custom TrackIdSet object."""
def setUp(self):
"""Set up common variables."""
super().setUp()
self.library = self.sql.libraries.create(pathlib.Path("/a/b"))
self.album = self.sql.albums.create("Album", "Artist", "2023-03")
self.medium = self.sql.media.create(self.album, "", number=1)
self.year = self.sql.years.create(1988)
self.tracks = self.sql.tracks
self.track1 = self.tracks.create(self.library,
pathlib.Path("/a/b/1.ogg"),
self.medium, self.year)
self.track2 = self.tracks.create(self.library,
pathlib.Path("/a/b/2.ogg"),
self.medium, self.year)
self.trackids = emmental.db.tracks.TrackidSet()
def test_init(self):
"""Test that the TrackIdSet was initialized properly."""
self.assertIsInstance(self.trackids, GObject.GObject)
self.assertIsInstance(self.trackids._TrackidSet__trackids, set)
self.assertEqual(self.trackids.n_trackids, 0)
trackids2 = emmental.db.tracks.TrackidSet({1, 2, 3})
self.assertSetEqual(trackids2._TrackidSet__trackids, {1, 2, 3})
self.assertEqual(trackids2.n_trackids, 3)
def test_contains(self):
"""Test the __contains__() function."""
self.trackids.add_track(self.track1)
self.assertTrue(self.track1 in self.trackids)
self.assertFalse(self.track2 in self.trackids)
def test_len(self):
"""Test the __len__() function."""
self.assertEqual(len(self.trackids), 0)
self.trackids.add_track(self.track1)
self.assertEqual(len(self.trackids), 1)
def test_sub(self):
"""Test the __sub__() function."""
self.trackids.trackids = {1, 2, 3, 4, 5}
trackidset2 = emmental.db.tracks.TrackidSet({3, 4, 5, 6, 7})
res = self.trackids - trackidset2
self.assertIsInstance(res, emmental.db.tracks.TrackidSet)
self.assertSetEqual(res.trackids, {1, 2})
def test_add_track(self):
"""Test adding a Track to the set."""
added = unittest.mock.Mock()
self.trackids.connect("trackid-added", added)
self.trackids.add_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track1.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
added.assert_called_with(self.trackids, self.track1.trackid)
self.trackids.add_track(self.track2)
self.assertSetEqual(self.trackids.trackids,
{self.track1.trackid, self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 2)
added.assert_called_with(self.trackids, self.track2.trackid)
added.reset_mock()
self.trackids.add_track(self.track2)
self.assertSetEqual(self.trackids.trackids,
{self.track1.trackid, self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 2)
added.assert_not_called()
@unittest.mock.patch("random.choice")
def test_random_trackid(self, mock_choice: unittest.mock.Mock):
"""Test getting a random trackid from the set."""
self.assertIsNone(self.trackids.random_trackid())
mock_choice.assert_not_called()
self.trackids.trackids = {1, 2, 3}
mock_choice.return_value = 2
self.assertEqual(self.trackids.random_trackid(), 2)
mock_choice.assert_called_with([1, 2, 3])
def test_remove_track(self):
"""Test removing a Track from the set."""
removed = unittest.mock.Mock()
self.trackids.trackids = {self.track1.trackid, self.track2.trackid}
self.trackids.connect("trackid-removed", removed)
self.trackids.remove_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
removed.assert_called_with(self.trackids, self.track1.trackid)
removed.reset_mock()
self.trackids.remove_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
removed.assert_not_called()
def test_trackids(self):
"""Test setting the Trackids property."""
added = unittest.mock.Mock()
removed = unittest.mock.Mock()
reset = unittest.mock.Mock()
self.trackids.connect("trackid-added", added)
self.trackids.connect("trackid-removed", removed)
self.trackids.connect("trackids-reset", reset)
self.trackids.trackids = {1, 2, 3, 4, 5}
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
self.assertEqual(self.trackids.n_trackids, 5)
added.assert_not_called()
removed.assert_not_called()
reset.assert_called_with(self.trackids)
reset.reset_mock()
self.trackids.trackids = {1, 2, 3, 4, 5}
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
added.assert_not_called()
removed.assert_not_called()
reset.assert_not_called()
self.trackids.trackids = {3, 4, 5, 6, 7}
self.assertSetEqual(self.trackids.trackids, {3, 4, 5, 6, 7})
added.assert_has_calls([unittest.mock.call(self.trackids, 6),
unittest.mock.call(self.trackids, 7)])
removed.assert_has_calls([unittest.mock.call(self.trackids, 1),
unittest.mock.call(self.trackids, 2)])
reset.assert_not_called()