db: Add a TrackidSet
The TrackidSet is intened to be used by Playlists to keep track of which Tracks have been added without much overhead. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
ff835832c8
commit
11560d781e
|
@ -2,7 +2,9 @@
|
|||
"""A custom Gio.ListModel for working with tracks."""
|
||||
import datetime
|
||||
import pathlib
|
||||
import random
|
||||
import sqlite3
|
||||
from typing import Iterable
|
||||
from gi.repository import GObject
|
||||
from gi.repository import Gtk
|
||||
from . import table
|
||||
|
@ -247,3 +249,81 @@ class Table(table.Table):
|
|||
self.sql.playlists.most_played.reload_tracks(idle=True)
|
||||
self.sql.playlists.queued.remove_track(track)
|
||||
self.sql.playlists.unplayed.remove_track(track)
|
||||
|
||||
|
||||
class TrackidSet(GObject.GObject):
|
||||
"""Manage a set of Track IDs."""
|
||||
|
||||
n_trackids = GObject.Property(type=int)
|
||||
|
||||
def __init__(self, trackids: Iterable[int] = []):
|
||||
"""Initialize a TrackidSet."""
|
||||
super().__init__()
|
||||
self.__trackids = set(trackids)
|
||||
self.n_trackids = len(self.__trackids)
|
||||
|
||||
def __contains__(self, track: Track) -> bool:
|
||||
"""Check if a Track is in the set."""
|
||||
return track.trackid in self.__trackids
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Find the number of Tracks in the set."""
|
||||
return len(self.__trackids)
|
||||
|
||||
def __sub__(self, rhs):
|
||||
"""Subtract two TrackidSets."""
|
||||
return TrackidSet(self.__trackids - rhs.trackids)
|
||||
|
||||
def add_track(self, track: Track) -> None:
|
||||
"""Add a Track to the set."""
|
||||
if track.trackid not in self.__trackids:
|
||||
self.__trackids.add(track.trackid)
|
||||
self.n_trackids = len(self)
|
||||
self.emit("trackid-added", track.trackid)
|
||||
|
||||
def random_trackid(self) -> int | None:
|
||||
"""Get a random trackid from the set."""
|
||||
if len(self.__trackids) > 0:
|
||||
return random.choice(list(self.__trackids))
|
||||
|
||||
def remove_track(self, track: Track) -> None:
|
||||
"""Remove a Track from the set."""
|
||||
if track.trackid in self.__trackids:
|
||||
self.__trackids.discard(track.trackid)
|
||||
self.n_trackids = len(self)
|
||||
self.emit("trackid-removed", track.trackid)
|
||||
|
||||
@property
|
||||
def trackids(self) -> set:
|
||||
"""Get the set of trackids."""
|
||||
return self.__trackids
|
||||
|
||||
@trackids.setter
|
||||
def trackids(self, trackids: Iterable[int]) -> None:
|
||||
"""Add several trackids to the set at one time."""
|
||||
new_trackids = set(trackids)
|
||||
if self.__trackids.isdisjoint(new_trackids):
|
||||
self.__trackids = new_trackids
|
||||
self.n_trackids = len(self)
|
||||
self.emit("trackids-reset")
|
||||
else:
|
||||
removed = self.__trackids - new_trackids
|
||||
added = new_trackids - self.__trackids
|
||||
self.__trackids = new_trackids
|
||||
self.n_trackids = len(self)
|
||||
for id in removed:
|
||||
self.emit("trackid-removed", id)
|
||||
for id in added:
|
||||
self.emit("trackid-added", id)
|
||||
|
||||
@GObject.Signal(arg_types=(int,))
|
||||
def trackid_added(self, trackid: int) -> None:
|
||||
"""Signal that a Track has been added to the set."""
|
||||
|
||||
@GObject.Signal(arg_types=(int,))
|
||||
def trackid_removed(self, trackid: int) -> None:
|
||||
"""Signal that a Track has been removed from the set."""
|
||||
|
||||
@GObject.Signal
|
||||
def trackids_reset(self) -> None:
|
||||
"""Signal that the Tracks in the set have been reset."""
|
||||
|
|
|
@ -6,6 +6,7 @@ import unittest
|
|||
import emmental.db.tracks
|
||||
import tests.util
|
||||
import unittest.mock
|
||||
from gi.repository import GObject
|
||||
from gi.repository import Gio
|
||||
from gi.repository import Gtk
|
||||
|
||||
|
@ -588,10 +589,137 @@ class TestTrackTable(tests.util.TestCase):
|
|||
track.favorite = True
|
||||
self.assertTrue(self.tracks.current_favorite)
|
||||
|
||||
self.tracks.current_favorite = False
|
||||
self.assertFalse(track.favorite)
|
||||
self.tracks.current_favorite = True
|
||||
self.assertTrue(track.favorite)
|
||||
|
||||
self.tracks.current_track = None
|
||||
self.assertFalse(self.tracks.current_favorite)
|
||||
class TestTrackIdSet(tests.util.TestCase):
|
||||
"""Test our custom TrackIdSet object."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common variables."""
|
||||
super().setUp()
|
||||
self.library = self.sql.libraries.create(pathlib.Path("/a/b"))
|
||||
self.album = self.sql.albums.create("Album", "Artist", "2023-03")
|
||||
self.medium = self.sql.media.create(self.album, "", number=1)
|
||||
self.year = self.sql.years.create(1988)
|
||||
|
||||
self.tracks = self.sql.tracks
|
||||
self.track1 = self.tracks.create(self.library,
|
||||
pathlib.Path("/a/b/1.ogg"),
|
||||
self.medium, self.year)
|
||||
self.track2 = self.tracks.create(self.library,
|
||||
pathlib.Path("/a/b/2.ogg"),
|
||||
self.medium, self.year)
|
||||
|
||||
self.trackids = emmental.db.tracks.TrackidSet()
|
||||
|
||||
def test_init(self):
|
||||
"""Test that the TrackIdSet was initialized properly."""
|
||||
self.assertIsInstance(self.trackids, GObject.GObject)
|
||||
self.assertIsInstance(self.trackids._TrackidSet__trackids, set)
|
||||
self.assertEqual(self.trackids.n_trackids, 0)
|
||||
|
||||
trackids2 = emmental.db.tracks.TrackidSet({1, 2, 3})
|
||||
self.assertSetEqual(trackids2._TrackidSet__trackids, {1, 2, 3})
|
||||
self.assertEqual(trackids2.n_trackids, 3)
|
||||
|
||||
def test_contains(self):
|
||||
"""Test the __contains__() function."""
|
||||
self.trackids.add_track(self.track1)
|
||||
self.assertTrue(self.track1 in self.trackids)
|
||||
self.assertFalse(self.track2 in self.trackids)
|
||||
|
||||
def test_len(self):
|
||||
"""Test the __len__() function."""
|
||||
self.assertEqual(len(self.trackids), 0)
|
||||
self.trackids.add_track(self.track1)
|
||||
self.assertEqual(len(self.trackids), 1)
|
||||
|
||||
def test_sub(self):
|
||||
"""Test the __sub__() function."""
|
||||
self.trackids.trackids = {1, 2, 3, 4, 5}
|
||||
trackidset2 = emmental.db.tracks.TrackidSet({3, 4, 5, 6, 7})
|
||||
|
||||
res = self.trackids - trackidset2
|
||||
self.assertIsInstance(res, emmental.db.tracks.TrackidSet)
|
||||
self.assertSetEqual(res.trackids, {1, 2})
|
||||
|
||||
def test_add_track(self):
|
||||
"""Test adding a Track to the set."""
|
||||
added = unittest.mock.Mock()
|
||||
self.trackids.connect("trackid-added", added)
|
||||
|
||||
self.trackids.add_track(self.track1)
|
||||
self.assertSetEqual(self.trackids.trackids, {self.track1.trackid})
|
||||
self.assertEqual(self.trackids.n_trackids, 1)
|
||||
added.assert_called_with(self.trackids, self.track1.trackid)
|
||||
|
||||
self.trackids.add_track(self.track2)
|
||||
self.assertSetEqual(self.trackids.trackids,
|
||||
{self.track1.trackid, self.track2.trackid})
|
||||
self.assertEqual(self.trackids.n_trackids, 2)
|
||||
added.assert_called_with(self.trackids, self.track2.trackid)
|
||||
|
||||
added.reset_mock()
|
||||
self.trackids.add_track(self.track2)
|
||||
self.assertSetEqual(self.trackids.trackids,
|
||||
{self.track1.trackid, self.track2.trackid})
|
||||
self.assertEqual(self.trackids.n_trackids, 2)
|
||||
added.assert_not_called()
|
||||
|
||||
@unittest.mock.patch("random.choice")
|
||||
def test_random_trackid(self, mock_choice: unittest.mock.Mock):
|
||||
"""Test getting a random trackid from the set."""
|
||||
self.assertIsNone(self.trackids.random_trackid())
|
||||
mock_choice.assert_not_called()
|
||||
|
||||
self.trackids.trackids = {1, 2, 3}
|
||||
mock_choice.return_value = 2
|
||||
self.assertEqual(self.trackids.random_trackid(), 2)
|
||||
mock_choice.assert_called_with([1, 2, 3])
|
||||
|
||||
def test_remove_track(self):
|
||||
"""Test removing a Track from the set."""
|
||||
removed = unittest.mock.Mock()
|
||||
self.trackids.trackids = {self.track1.trackid, self.track2.trackid}
|
||||
self.trackids.connect("trackid-removed", removed)
|
||||
|
||||
self.trackids.remove_track(self.track1)
|
||||
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
|
||||
self.assertEqual(self.trackids.n_trackids, 1)
|
||||
removed.assert_called_with(self.trackids, self.track1.trackid)
|
||||
|
||||
removed.reset_mock()
|
||||
self.trackids.remove_track(self.track1)
|
||||
self.assertSetEqual(self.trackids.trackids, {self.track2.trackid})
|
||||
self.assertEqual(self.trackids.n_trackids, 1)
|
||||
removed.assert_not_called()
|
||||
|
||||
def test_trackids(self):
|
||||
"""Test setting the Trackids property."""
|
||||
added = unittest.mock.Mock()
|
||||
removed = unittest.mock.Mock()
|
||||
reset = unittest.mock.Mock()
|
||||
self.trackids.connect("trackid-added", added)
|
||||
self.trackids.connect("trackid-removed", removed)
|
||||
self.trackids.connect("trackids-reset", reset)
|
||||
|
||||
self.trackids.trackids = {1, 2, 3, 4, 5}
|
||||
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
|
||||
self.assertEqual(self.trackids.n_trackids, 5)
|
||||
added.assert_not_called()
|
||||
removed.assert_not_called()
|
||||
reset.assert_called_with(self.trackids)
|
||||
|
||||
reset.reset_mock()
|
||||
self.trackids.trackids = {1, 2, 3, 4, 5}
|
||||
self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5})
|
||||
added.assert_not_called()
|
||||
removed.assert_not_called()
|
||||
reset.assert_not_called()
|
||||
|
||||
self.trackids.trackids = {3, 4, 5, 6, 7}
|
||||
self.assertSetEqual(self.trackids.trackids, {3, 4, 5, 6, 7})
|
||||
added.assert_has_calls([unittest.mock.call(self.trackids, 6),
|
||||
unittest.mock.call(self.trackids, 7)])
|
||||
removed.assert_has_calls([unittest.mock.call(self.trackids, 1),
|
||||
unittest.mock.call(self.trackids, 2)])
|
||||
reset.assert_not_called()
|
||||
|
|
Loading…
Reference in New Issue