From 11560d781ed4450b1bc717e2a2fa47803dfa18de Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Fri, 20 Jan 2023 16:39:20 -0500 Subject: [PATCH] db: Add a TrackidSet The TrackidSet is intened to be used by Playlists to keep track of which Tracks have been added without much overhead. Signed-off-by: Anna Schumaker --- emmental/db/tracks.py | 80 +++++++++++++++++++++++ tests/db/test_tracks.py | 140 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 214 insertions(+), 6 deletions(-) diff --git a/emmental/db/tracks.py b/emmental/db/tracks.py index 58c527d..6e52542 100644 --- a/emmental/db/tracks.py +++ b/emmental/db/tracks.py @@ -2,7 +2,9 @@ """A custom Gio.ListModel for working with tracks.""" import datetime import pathlib +import random import sqlite3 +from typing import Iterable from gi.repository import GObject from gi.repository import Gtk from . import table @@ -247,3 +249,81 @@ class Table(table.Table): self.sql.playlists.most_played.reload_tracks(idle=True) self.sql.playlists.queued.remove_track(track) self.sql.playlists.unplayed.remove_track(track) + + +class TrackidSet(GObject.GObject): + """Manage a set of Track IDs.""" + + n_trackids = GObject.Property(type=int) + + def __init__(self, trackids: Iterable[int] = []): + """Initialize a TrackidSet.""" + super().__init__() + self.__trackids = set(trackids) + self.n_trackids = len(self.__trackids) + + def __contains__(self, track: Track) -> bool: + """Check if a Track is in the set.""" + return track.trackid in self.__trackids + + def __len__(self) -> int: + """Find the number of Tracks in the set.""" + return len(self.__trackids) + + def __sub__(self, rhs): + """Subtract two TrackidSets.""" + return TrackidSet(self.__trackids - rhs.trackids) + + def add_track(self, track: Track) -> None: + """Add a Track to the set.""" + if track.trackid not in self.__trackids: + self.__trackids.add(track.trackid) + self.n_trackids = len(self) + self.emit("trackid-added", track.trackid) + + def random_trackid(self) -> int | None: + """Get a random trackid from the set.""" + if len(self.__trackids) > 0: + return random.choice(list(self.__trackids)) + + def remove_track(self, track: Track) -> None: + """Remove a Track from the set.""" + if track.trackid in self.__trackids: + self.__trackids.discard(track.trackid) + self.n_trackids = len(self) + self.emit("trackid-removed", track.trackid) + + @property + def trackids(self) -> set: + """Get the set of trackids.""" + return self.__trackids + + @trackids.setter + def trackids(self, trackids: Iterable[int]) -> None: + """Add several trackids to the set at one time.""" + new_trackids = set(trackids) + if self.__trackids.isdisjoint(new_trackids): + self.__trackids = new_trackids + self.n_trackids = len(self) + self.emit("trackids-reset") + else: + removed = self.__trackids - new_trackids + added = new_trackids - self.__trackids + self.__trackids = new_trackids + self.n_trackids = len(self) + for id in removed: + self.emit("trackid-removed", id) + for id in added: + self.emit("trackid-added", id) + + @GObject.Signal(arg_types=(int,)) + def trackid_added(self, trackid: int) -> None: + """Signal that a Track has been added to the set.""" + + @GObject.Signal(arg_types=(int,)) + def trackid_removed(self, trackid: int) -> None: + """Signal that a Track has been removed from the set.""" + + @GObject.Signal + def trackids_reset(self) -> None: + """Signal that the Tracks in the set have been reset.""" diff --git a/tests/db/test_tracks.py b/tests/db/test_tracks.py index 9cabc87..9ba0ea3 100644 --- a/tests/db/test_tracks.py +++ b/tests/db/test_tracks.py @@ -6,6 +6,7 @@ import unittest import emmental.db.tracks import tests.util import unittest.mock +from gi.repository import GObject from gi.repository import Gio from gi.repository import Gtk @@ -588,10 +589,137 @@ class TestTrackTable(tests.util.TestCase): track.favorite = True self.assertTrue(self.tracks.current_favorite) - self.tracks.current_favorite = False - self.assertFalse(track.favorite) - self.tracks.current_favorite = True - self.assertTrue(track.favorite) - self.tracks.current_track = None - self.assertFalse(self.tracks.current_favorite) +class TestTrackIdSet(tests.util.TestCase): + """Test our custom TrackIdSet object.""" + + def setUp(self): + """Set up common variables.""" + super().setUp() + self.library = self.sql.libraries.create(pathlib.Path("/a/b")) + self.album = self.sql.albums.create("Album", "Artist", "2023-03") + self.medium = self.sql.media.create(self.album, "", number=1) + self.year = self.sql.years.create(1988) + + self.tracks = self.sql.tracks + self.track1 = self.tracks.create(self.library, + pathlib.Path("/a/b/1.ogg"), + self.medium, self.year) + self.track2 = self.tracks.create(self.library, + pathlib.Path("/a/b/2.ogg"), + self.medium, self.year) + + self.trackids = emmental.db.tracks.TrackidSet() + + def test_init(self): + """Test that the TrackIdSet was initialized properly.""" + self.assertIsInstance(self.trackids, GObject.GObject) + self.assertIsInstance(self.trackids._TrackidSet__trackids, set) + self.assertEqual(self.trackids.n_trackids, 0) + + trackids2 = emmental.db.tracks.TrackidSet({1, 2, 3}) + self.assertSetEqual(trackids2._TrackidSet__trackids, {1, 2, 3}) + self.assertEqual(trackids2.n_trackids, 3) + + def test_contains(self): + """Test the __contains__() function.""" + self.trackids.add_track(self.track1) + self.assertTrue(self.track1 in self.trackids) + self.assertFalse(self.track2 in self.trackids) + + def test_len(self): + """Test the __len__() function.""" + self.assertEqual(len(self.trackids), 0) + self.trackids.add_track(self.track1) + self.assertEqual(len(self.trackids), 1) + + def test_sub(self): + """Test the __sub__() function.""" + self.trackids.trackids = {1, 2, 3, 4, 5} + trackidset2 = emmental.db.tracks.TrackidSet({3, 4, 5, 6, 7}) + + res = self.trackids - trackidset2 + self.assertIsInstance(res, emmental.db.tracks.TrackidSet) + self.assertSetEqual(res.trackids, {1, 2}) + + def test_add_track(self): + """Test adding a Track to the set.""" + added = unittest.mock.Mock() + self.trackids.connect("trackid-added", added) + + self.trackids.add_track(self.track1) + self.assertSetEqual(self.trackids.trackids, {self.track1.trackid}) + self.assertEqual(self.trackids.n_trackids, 1) + added.assert_called_with(self.trackids, self.track1.trackid) + + self.trackids.add_track(self.track2) + self.assertSetEqual(self.trackids.trackids, + {self.track1.trackid, self.track2.trackid}) + self.assertEqual(self.trackids.n_trackids, 2) + added.assert_called_with(self.trackids, self.track2.trackid) + + added.reset_mock() + self.trackids.add_track(self.track2) + self.assertSetEqual(self.trackids.trackids, + {self.track1.trackid, self.track2.trackid}) + self.assertEqual(self.trackids.n_trackids, 2) + added.assert_not_called() + + @unittest.mock.patch("random.choice") + def test_random_trackid(self, mock_choice: unittest.mock.Mock): + """Test getting a random trackid from the set.""" + self.assertIsNone(self.trackids.random_trackid()) + mock_choice.assert_not_called() + + self.trackids.trackids = {1, 2, 3} + mock_choice.return_value = 2 + self.assertEqual(self.trackids.random_trackid(), 2) + mock_choice.assert_called_with([1, 2, 3]) + + def test_remove_track(self): + """Test removing a Track from the set.""" + removed = unittest.mock.Mock() + self.trackids.trackids = {self.track1.trackid, self.track2.trackid} + self.trackids.connect("trackid-removed", removed) + + self.trackids.remove_track(self.track1) + self.assertSetEqual(self.trackids.trackids, {self.track2.trackid}) + self.assertEqual(self.trackids.n_trackids, 1) + removed.assert_called_with(self.trackids, self.track1.trackid) + + removed.reset_mock() + self.trackids.remove_track(self.track1) + self.assertSetEqual(self.trackids.trackids, {self.track2.trackid}) + self.assertEqual(self.trackids.n_trackids, 1) + removed.assert_not_called() + + def test_trackids(self): + """Test setting the Trackids property.""" + added = unittest.mock.Mock() + removed = unittest.mock.Mock() + reset = unittest.mock.Mock() + self.trackids.connect("trackid-added", added) + self.trackids.connect("trackid-removed", removed) + self.trackids.connect("trackids-reset", reset) + + self.trackids.trackids = {1, 2, 3, 4, 5} + self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5}) + self.assertEqual(self.trackids.n_trackids, 5) + added.assert_not_called() + removed.assert_not_called() + reset.assert_called_with(self.trackids) + + reset.reset_mock() + self.trackids.trackids = {1, 2, 3, 4, 5} + self.assertSetEqual(self.trackids.trackids, {1, 2, 3, 4, 5}) + added.assert_not_called() + removed.assert_not_called() + reset.assert_not_called() + + self.trackids.trackids = {3, 4, 5, 6, 7} + self.assertSetEqual(self.trackids.trackids, {3, 4, 5, 6, 7}) + added.assert_has_calls([unittest.mock.call(self.trackids, 6), + unittest.mock.call(self.trackids, 7)]) + removed.assert_has_calls([unittest.mock.call(self.trackids, 1), + unittest.mock.call(self.trackids, 2)]) + reset.assert_not_called()