db: Add a TrackidSet

The TrackidSet is intened to be used by Playlists to keep track of which
Tracks have been added without much overhead.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2023-01-20 16:39:20 -05:00
parent ff835832c8
commit 11560d781e
2 changed files with 214 additions and 6 deletions

View File

@ -2,7 +2,9 @@
"""A custom Gio.ListModel for working with tracks."""
import datetime
import pathlib
import random
import sqlite3
from typing import Iterable
from gi.repository import GObject
from gi.repository import Gtk
from . import table
@ -247,3 +249,81 @@ class Table(table.Table):
self.sql.playlists.most_played.reload_tracks(idle=True)
self.sql.playlists.queued.remove_track(track)
self.sql.playlists.unplayed.remove_track(track)
class TrackidSet(GObject.GObject):
"""Manage a set of Track IDs."""
n_trackids = GObject.Property(type=int)
def __init__(self, trackids: Iterable[int] = []):
"""Initialize a TrackidSet."""
super().__init__()
self.__trackids = set(trackids)
self.n_trackids = len(self.__trackids)
def __contains__(self, track: Track) -> bool:
"""Check if a Track is in the set."""
return track.trackid in self.__trackids
def __len__(self) -> int:
"""Find the number of Tracks in the set."""
return len(self.__trackids)
def __sub__(self, rhs):
"""Subtract two TrackidSets."""
return TrackidSet(self.__trackids - rhs.trackids)
def add_track(self, track: Track) -> None:
"""Add a Track to the set."""
if track.trackid not in self.__trackids:
self.__trackids.add(track.trackid)
self.n_trackids = len(self)
self.emit("trackid-added", track.trackid)
def random_trackid(self) -> int | None:
"""Get a random trackid from the set."""
if len(self.__trackids) > 0:
return random.choice(list(self.__trackids))
def remove_track(self, track: Track) -> None:
"""Remove a Track from the set."""
if track.trackid in self.__trackids:
self.__trackids.discard(track.trackid)
self.n_trackids = len(self)
self.emit("trackid-removed", track.trackid)
@property
def trackids(self) -> set:
"""Get the set of trackids."""
return self.__trackids
@trackids.setter
def trackids(self, trackids: Iterable[int]) -> None:
"""Add several trackids to the set at one time."""
new_trackids = set(trackids)
if self.__trackids.isdisjoint(new_trackids):
self.__trackids = new_trackids
self.n_trackids = len(self)
self.emit("trackids-reset")
else:
removed = self.__trackids - new_trackids
added = new_trackids - self.__trackids
self.__trackids = new_trackids
self.n_trackids = len(self)
for id in removed:
self.emit("trackid-removed", id)
for id in added:
self.emit("trackid-added", id)
@GObject.Signal(arg_types=(int,))
def trackid_added(self, trackid: int) -> None:
"""Signal that a Track has been added to the set."""
@GObject.Signal(arg_types=(int,))
def trackid_removed(self, trackid: int) -> None:
"""Signal that a Track has been removed from the set."""
@GObject.Signal
def trackids_reset(self) -> None:
"""Signal that the Tracks in the set have been reset."""

View File

@ -6,6 +6,7 @@ import unittest
import emmental.db.tracks
import tests.util
import unittest.mock
from gi.repository import GObject
from gi.repository import Gio
from gi.repository import Gtk
@ -588,10 +589,137 @@ class TestTrackTable(tests.util.TestCase):
track.favorite = True
self.assertTrue(self.tracks.current_favorite)
self.tracks.current_favorite = False
self.assertFalse(track.favorite)
self.tracks.current_favorite = True
self.assertTrue(track.favorite)
self.tracks.current_track = None
self.assertFalse(self.tracks.current_favorite)
class TestTrackIdSet(tests.util.TestCase):
"""Test our custom TrackIdSet object."""
def setUp(self):
"""Set up common variables."""
super().setUp()
self.library = self.sql.libraries.create(pathlib.Path("/a/b"))
self.album = self.sql.albums.create("Album", "Artist", "2023-03")
self.medium = self.sql.media.create(self.album, "", number=1)
self.year = self.sql.years.create(1988)
self.tracks = self.sql.tracks
self.track1 = self.tracks.create(self.library,
pathlib.Path("/a/b/1.ogg"),
self.medium, self.year)
self.track2 = self.tracks.create(self.library,
pathlib.Path("/a/b/2.ogg"),
self.medium, self.year)
self.trackids = emmental.db.tracks.TrackidSet()
def test_init(self):
"""Test that the TrackIdSet was initialized properly."""
self.assertIsInstance(self.trackids, GObject.GObject)
self.assertIsInstance(self.trackids._TrackidSet__trackids, set)
self.assertEqual(self.trackids.n_trackids, 0)
trackids2 = emmental.db.tracks.TrackidSet({1, 2, 3})
self.assertSetEqual(trackids2._TrackidSet__trackids, {1, 2, 3})
self.assertEqual(trackids2.n_trackids, 3)
def test_contains(self):
"""Test the __contains__() function."""
self.trackids.add_track(self.track1)
self.assertTrue(self.track1 in self.trackids)
self.assertFalse(self.track2 in self.trackids)
def test_len(self):
"""Test the __len__() function."""
self.assertEqual(len(self.trackids), 0)
self.trackids.add_track(self.track1)
self.assertEqual(len(self.trackids), 1)
def test_sub(self):
"""Test the __sub__() function."""
self.trackids.trackids = {1, 2, 3, 4, 5}
trackidset2 = emmental.db.tracks.TrackidSet({3, 4, 5, 6, 7})
res = self.trackids - trackidset2
self.assertIsInstance(res, emmental.db.tracks.TrackidSet)
self.assertSetEqual(res.trackids, {1, 2})
def test_add_track(self):
"""Test adding a Track to the set."""
added = unittest.mock.Mock()
self.trackids.connect("trackid-added", added)
self.trackids.add_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track1.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
added.assert_called_with(self.trackids, self.track1.trackid)
self.trackids.add_track(self.track2)
self.assertSetEqual(self.trackids.trackids,
{self.track1.trackid, self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 2)
added.assert_called_with(self.trackids, self.track2.trackid)
added.reset_mock()
self.trackids.add_track(self.track2)
self.assertSetEqual(self.trackids.trackids,
{self.track1.trackid, self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 2)
added.assert_not_called()
@unittest.mock.patch("random.choice")
def test_random_trackid(self, mock_choice: unittest.mock.Mock):
"""Test getting a random trackid from the set."""
self.assertIsNone(self.trackids.random_trackid())
mock_choice.assert_not_called()
self.trackids.trackids = {1, 2, 3}
mock_choice.return_value = 2
self.assertEqual(self.trackids.random_trackid(), 2)
mock_choice.assert_called_with([1, 2, 3])
def test_remove_track(self):
"""Test removing a Track from the set."""
removed = unittest.mock.Mock()
self.trackids.trackids = {self.track1.trackid, self.track2.trackid}
self.trackids.connect("trackid-removed", removed)
self.trackids.remove_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
removed.assert_called_with(self.trackids, self.track1.trackid)
removed.reset_mock()
self.trackids.remove_track(self.track1)
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
self.assertEqual(self.trackids.n_trackids, 1)
removed.assert_not_called()
def test_trackids(self):
"""Test setting the Trackids property."""
added = unittest.mock.Mock()
removed = unittest.mock.Mock()
reset = unittest.mock.Mock()
self.trackids.connect("trackid-added", added)
self.trackids.connect("trackid-removed", removed)
self.trackids.connect("trackids-reset", reset)
self.trackids.trackids = {1, 2, 3, 4, 5}
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
self.assertEqual(self.trackids.n_trackids, 5)
added.assert_not_called()
removed.assert_not_called()
reset.assert_called_with(self.trackids)
reset.reset_mock()
self.trackids.trackids = {1, 2, 3, 4, 5}
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
added.assert_not_called()
removed.assert_not_called()
reset.assert_not_called()
self.trackids.trackids = {3, 4, 5, 6, 7}
self.assertSetEqual(self.trackids.trackids, {3, 4, 5, 6, 7})
added.assert_has_calls([unittest.mock.call(self.trackids, 6),
unittest.mock.call(self.trackids, 7)])
removed.assert_has_calls([unittest.mock.call(self.trackids, 1),
unittest.mock.call(self.trackids, 2)])
reset.assert_not_called()