From e0e7b556beca22235344e6fcf27cf1164bf7fd01 Mon Sep 17 00:00:00 2001 From: Anna Schumaker Date: Sun, 2 Apr 2023 21:44:56 -0400 Subject: [PATCH] db: Add Track support to the Tagger class This includes both creating new Tracks and updating existing Tracks when their tags have changed. New tracks are added to playlists using idle=True so Gtk can spread out UI updates for each playlist so we don't slow things down too much for the user. This patch also adds a library argument to the Tagger thread get_result() function which we pass to the Tagger class to be used by Tracks.creat(). I also add an mtime argument to the Tagger thread tag_file() function to pass down to the audio.tagger layer so we can skip updating tracks that have not changed since the last scan. Implements: #41 (Check for new or modified tags during startup) Signed-off-by: Anna Schumaker --- emmental/db/libraries.py | 5 +- emmental/db/tagger.py | 89 ++++++++++++++++++++-- tests/db/test_libraries.py | 14 +++- tests/db/test_tagger.py | 147 +++++++++++++++++++++++++++++++++---- 4 files changed, 230 insertions(+), 25 deletions(-) diff --git a/emmental/db/libraries.py b/emmental/db/libraries.py index b9769ba..fae61b6 100644 --- a/emmental/db/libraries.py +++ b/emmental/db/libraries.py @@ -43,9 +43,10 @@ class Library(playlist.Playlist): def __tag_track(self, path: pathlib.Path) -> bool: if self.tagger.ready.is_set(): - (file, tags) = self.tagger.get_result(self.table.sql) + (file, tags) = self.tagger.get_result(self.table.sql, self) if file is None: - self.tagger.tag_file(path) + track = self.table.sql.tracks.lookup(self, path=path) + self.tagger.tag_file(path, track.mtime if track else None) else: return True return False diff --git a/emmental/db/tagger.py b/emmental/db/tagger.py index 418cc5a..b6faf00 100644 --- a/emmental/db/tagger.py +++ b/emmental/db/tagger.py @@ -12,6 +12,8 @@ from . import connection from . import decades from . import media from . import genres +from . import playlist +from . import tracks from . import years @@ -19,7 +21,8 @@ class Tags: """Translate the audio.tagger._Tags object into Playlists.""" def __init__(self, db: GObject.TYPE_PYOBJECT, - raw_tags: audio.tagger._Tags): + raw_tags: audio.tagger._Tags, + library: playlist.Playlist): """Initialize the Tags object.""" self.db = db @@ -35,6 +38,8 @@ class Tags: self.medium = self.get_medium(raw_tags.medium) self.year = self.get_year(raw_tags.year) + self.track = self.get_track(library, raw_tags.file, raw_tags.track) + self.__update_album_artists() def __update_album_artists(self) -> None: @@ -47,6 +52,48 @@ class Tags: for artist in new - old: artist.add_album(self.album) + def __update_track(self, track: tracks.Track, + raw_track: audio.tagger._Track) -> None: + orig_year = track.get_year() + orig_decade = orig_year.parent + orig_genres = set(track.get_genres()) + orig_medium = track.get_medium() + orig_album = orig_medium.get_album() + orig_artists = set(track.get_artists()) + + track.update_properties(mediumid=self.medium.mediumid, + year=self.year.year, + title=raw_track.title, + number=raw_track.number, + length=raw_track.length, + artist=raw_track.artist, + mbid=raw_track.mbid, + mtime=raw_track.mtime) + + self.__update_track_playlist_set(track, orig_artists, + set(self.artists)) + self.__update_track_playlist_set(track, orig_genres, set(self.genres)) + + self.__update_track_playlist(track, orig_album, self.album) + self.__update_track_playlist(track, orig_medium, self.medium) + self.__update_track_playlist(track, orig_decade, self.decade) + self.__update_track_playlist(track, orig_year, self.year) + + def __update_track_playlist(self, track: tracks.Track, + orig: playlist.Playlist, + new: playlist.Playlist): + if orig != new: + orig.remove_track(track, idle=True) + new.add_track(track, idle=True) + + def __update_track_playlist_set(self, track: tracks.Track, + orig: set[playlist.Playlist], + new: set[playlist.Playlist]): + for plist in orig - new: + plist.remove_track(track, idle=True) + for plist in new - orig: + plist.add_track(track, idle=True) + def get_album(self, raw_album: audio.tagger._Album) -> albums.Album | None: """Convert the raw album into an Album object.""" if raw_album.name == "": @@ -97,6 +144,33 @@ class Tags: number=raw_medium.number, type=raw_medium.type) + def get_track(self, library: playlist.Playlist, filepath: pathlib.Path, + raw_track: audio.tagger._Track) -> tracks.Track | None: + """Convert the raw track into a Track object.""" + if self.medium is None or self.year is None: + return None + + track = self.db.tracks.lookup(library, path=filepath) + if track is not None: + self.__update_track(track, raw_track) + return track + + track = self.db.tracks.create(library, filepath, self.medium, + self.year, title=raw_track.title, + number=raw_track.number, + length=raw_track.length, + artist=raw_track.artist, + mbid=raw_track.mbid, + mtime=raw_track.mtime) + + for plist in [self.db.playlists.collection, + self.db.playlists.new_tracks, + self.db.playlists.unplayed, + self.album, *self.artists, self.medium, + *self.genres, self.decade, self.year, library]: + plist.add_track(track, idle=True) + return track + def get_year(self, raw_year: int | None) -> years.Year | None: """Convert the raw year into a Year object.""" if raw_year: @@ -115,6 +189,7 @@ class Thread(threading.Thread): self._connection = None self._condition = threading.Condition() self._file = None + self._mtime = None self._tags = None self.start() @@ -138,14 +213,15 @@ class Thread(threading.Thread): mb_res = musicbrainzngs.get_artist_by_id(artist.mbid) artist.name = mb_res["artist"]["name"] - def get_result(self, db: GObject.TYPE_PYOBJECT) \ + def get_result(self, db: GObject.TYPE_PYOBJECT, + library: playlist.Playlist) \ -> tuple[pathlib.Path | None, Tags | None]: """Return the resulting Tags structure.""" with self._condition: if not self.ready.is_set(): return (None, None) - tags = Tags(db, self._tags) if self._tags else None + tags = Tags(db, self._tags, library) if self._tags else None res = (self._file, tags) self._file = None self._tags = None @@ -160,7 +236,8 @@ class Thread(threading.Thread): if self._file is None: break - if tags := emmental.audio.tagger.tag_file(self._file, None): + tags = emmental.audio.tagger.tag_file(self._file, self._mtime) + if tags is not None: for artist in tags.artists: self.__check_artist(artist) @@ -173,13 +250,15 @@ class Thread(threading.Thread): """Stop the thread.""" with self._condition: self._file = None + self._mtime = None self._condition.notify() self.join() - def tag_file(self, file: pathlib.Path) -> None: + def tag_file(self, file: pathlib.Path, mtime: float | None) -> None: """Tag a file.""" with self._condition: self.ready.clear() self._file = file + self._mtime = mtime self._tags = None self._condition.notify() diff --git a/tests/db/test_libraries.py b/tests/db/test_libraries.py index bd94374..851c555 100644 --- a/tests/db/test_libraries.py +++ b/tests/db/test_libraries.py @@ -138,7 +138,7 @@ class TestLibraryObject(tests.util.TestCase): """Test that tracks are tagged during scanning.""" track = pathlib.Path("/a/b/c/1.ogg") raw_tags = emmental.audio.tagger._Tags(track, {}) - tags = emmental.db.tagger.Tags(self.sql, raw_tags) + tags = emmental.db.tagger.Tags(self.sql, raw_tags, self.library) tagger = unittest.mock.Mock() self.library.tagger = tagger @@ -150,15 +150,21 @@ class TestLibraryObject(tests.util.TestCase): tagger.ready.is_set.return_value = True tagger.get_result.return_value = (None, None) self.assertFalse(self.library._Library__tag_track(track)) - tagger.get_result.assert_called_with(self.sql) - tagger.tag_file.assert_called_with(track) + tagger.get_result.assert_called_with(self.sql, self.library) + tagger.tag_file.assert_called_with(track, None) + + self.sql.tracks.lookup = unittest.mock.Mock() + self.sql.tracks.lookup.return_value.mtime = 12345 + self.assertFalse(self.library._Library__tag_track(track)) + tagger.get_result.assert_called_with(self.sql, self.library) + tagger.tag_file.assert_called_with(track, 12345) tagger.reset_mock() tagger.ready.is_set.return_value = True tagger.get_result.return_value = (track, tags) self.assertTrue(self.library._Library__tag_track(track)) tagger.tag_file.assert_not_called() - tagger.get_result.assert_called_with(self.sql) + tagger.get_result.assert_called_with(self.sql, self.library) def test_stop(self): """Test stopping a Library's background work.""" diff --git a/tests/db/test_tagger.py b/tests/db/test_tagger.py index 85c1c34..952489a 100644 --- a/tests/db/test_tagger.py +++ b/tests/db/test_tagger.py @@ -13,12 +13,19 @@ class TestTags(tests.util.TestCase): def setUp(self): """Set up common variables.""" super().setUp() + for tbl in self.sql.playlist_tables(): + tbl.queue.enabled = False + tbl.load() + + self.library = self.sql.libraries.create("/a/b") self.file = pathlib.Path("/a/b/c/track.ogg") - def make_tags(self, raw_tags: dict) -> emmental.db.tagger.Tags: + def make_tags(self, raw_tags: dict, *, length: int = 0, + mtime: float = 0.0) -> emmental.db.tagger.Tags: """Set up and return our Tags object.""" - audio_tags = emmental.audio.tagger._Tags(self.file, raw_tags) - return emmental.db.tagger.Tags(self.sql, audio_tags) + audio_tags = emmental.audio.tagger._Tags(self.file, raw_tags, + length, mtime) + return emmental.db.tagger.Tags(self.sql, audio_tags, self.library) def test_init(self): """Test that the Tags object was set up properly.""" @@ -29,6 +36,7 @@ class TestTags(tests.util.TestCase): self.assertIsNone(tags.decade) self.assertIsNone(tags.medium) self.assertIsNone(tags.year) + self.assertIsNone(tags.track) self.assertListEqual(tags.album_artists, []) self.assertListEqual(tags.artists, []) @@ -106,6 +114,7 @@ class TestTags(tests.util.TestCase): def test_genres(self): """Test that genres were tagged properly.""" + self.sql.genres.autodelete = False raw_tags = {"genre": ["Genre 1", "Genre 2"]} tags = self.make_tags(raw_tags) @@ -137,6 +146,104 @@ class TestTags(tests.util.TestCase): self.assertEqual(self.make_tags(raw_tags).medium, tags.medium) self.assertEqual(tags.medium.name, "New Subtitle") + def test_track(self): + """Test that the Track was tagged properly.""" + raw_tags = {"album": ["Album Name"], + "artist": ["Track Artist"], + "date": ["1988-06"], + "title": ["Test Title"], + "tracknumber": ["3"], + "musicbrainz_releasetrackid": ["ab-cd-ef"]} + + tags = self.make_tags(raw_tags, length=42, mtime=1.234) + self.assertIsInstance(tags.track, emmental.db.tracks.Track) + self.assertEqual(tags.track.get_library(), self.library) + self.assertEqual(tags.track.get_medium(), tags.medium) + self.assertEqual(tags.track.get_year(), tags.year) + self.assertEqual(tags.track.path, self.file) + self.assertEqual(tags.track.title, "Test Title") + self.assertEqual(tags.track.artist, "Track Artist") + self.assertEqual(tags.track.mbid, "ab-cd-ef") + self.assertEqual(tags.track.number, 3) + self.assertEqual(tags.track.length, 42) + self.assertEqual(tags.track.mtime, 1.234) + + self.assertTrue(self.sql.playlists.collection.has_track(tags.track)) + self.assertTrue(self.sql.playlists.new_tracks.has_track(tags.track)) + self.assertTrue(self.sql.playlists.unplayed.has_track(tags.track)) + self.assertTrue(self.library.has_track(tags.track)) + + raw_tags["artist"] = ["New Artist"] + raw_tags["date"] = ["1985-08"] + raw_tags["discnumber"] = ["2"] + raw_tags["title"] = ["New Title"] + raw_tags["tracknumber"] = ["4"] + raw_tags["musicbrainz_releasetrackid"] = ["gh-ij-kl"] + for playlist in self.sql.playlists: + playlist.tracks.trackids = set() + + new_tags = self.make_tags(raw_tags, length=53, mtime=5.6789) + self.assertEqual(new_tags.track, tags.track) + self.assertEqual(new_tags.track.get_medium(), new_tags.medium) + self.assertEqual(new_tags.track.get_year(), new_tags.year) + self.assertEqual(new_tags.track.title, "New Title") + self.assertEqual(new_tags.track.artist, "New Artist") + self.assertEqual(new_tags.track.mbid, "gh-ij-kl") + self.assertEqual(new_tags.track.number, 4) + self.assertEqual(new_tags.track.length, 53.0) + self.assertEqual(new_tags.track.mtime, 5.6789) + + for playlist in self.sql.playlists: + with self.subTest(playlist=playlist.name): + self.assertFalse(playlist.has_track(new_tags.track)) + + def test_track_playlist_update(self): + """Test updating Track Playlists.""" + album = self.sql.albums.create("Album Name", "Artist 2", "1988-06") + medium = self.sql.media.create(album, "", number=1) + decade = self.sql.decades.create(1980) + year = self.sql.years.create(1988) + raw_tags = {"album": ["Album Name"], + "artist": ["Artist 2"], + "artists": ["Artist 1", "Artist 2"], + "date": ["1988-06"], + "discnumber": ["1"], + "genre": ["Genre 1", "Genre 2"]} + + tags = self.make_tags(raw_tags) + self.assertListEqual(tags.track.get_artists(), tags.artists) + self.assertListEqual(tags.track.get_genres(), tags.genres) + self.assertTrue(album.has_track(tags.track)) + self.assertTrue(medium.has_track(tags.track)) + self.assertTrue(decade.has_track(tags.track)) + self.assertTrue(year.has_track(tags.track)) + + new_album = self.sql.albums.create("New Album Name", "Artist 2", + "1992-10") + new_medium = self.sql.media.create(new_album, "", number=2) + new_decade = self.sql.decades.create(1990) + new_year = self.sql.years.create(1992) + raw_tags["album"] = ["New Album Name"] + raw_tags["artists"] = ["Artist 2", "Artist 3"] + raw_tags["date"] = ["1992-10"] + raw_tags["discnumber"] = ["2"] + raw_tags["genre"] = ["Genre 2", "Genre 3"] + + new_tags = self.make_tags(raw_tags) + self.assertEqual(new_tags.track, tags.track) + self.assertListEqual(tags.track.get_artists(), new_tags.artists) + self.assertListEqual(tags.track.get_genres(), new_tags.genres) + + self.assertFalse(album.has_track(tags.track)) + self.assertFalse(medium.has_track(tags.track)) + self.assertFalse(decade.has_track(tags.track)) + self.assertFalse(year.has_track(tags.track)) + + self.assertTrue(new_album.has_track(tags.track)) + self.assertTrue(new_medium.has_track(tags.track)) + self.assertTrue(new_decade.has_track(tags.track)) + self.assertTrue(new_year.has_track(tags.track)) + def test_year(self): """Test that the year was tagged properly.""" raw_tags = {"date": ["1988-06-17"]} @@ -154,6 +261,7 @@ class TestTaggerThread(tests.util.TestCase): def setUp(self): """Set up common variables.""" super().setUp() + self.library = self.sql.libraries.create("/a/b") self.tagger = emmental.db.tagger.Thread() self.tags = dict() @@ -178,6 +286,7 @@ class TestTaggerThread(tests.util.TestCase): mock_connection.close = unittest.mock.Mock() self.tagger._file = "abcde" + self.tagger._mtime = 12345 self.tagger._connection = mock_connection with unittest.mock.patch.object(self.tagger._condition, "notify", @@ -185,6 +294,7 @@ class TestTaggerThread(tests.util.TestCase): as mock_notify: self.tagger.stop() self.assertIsNone(self.tagger._file) + self.assertIsNone(self.tagger._mtime) mock_notify.assert_called() self.assertFalse(self.tagger.is_alive()) @@ -193,46 +303,55 @@ class TestTaggerThread(tests.util.TestCase): def test_tag_file(self, mock_file: unittest.mock.Mock): """Test asking the thread to tag a file.""" + path = pathlib.Path("/a/b/c.ogg") + self.assertIsInstance(self.tagger.ready, threading.Event) self.assertIsNone(self.tagger._file) self.assertIsNone(self.tagger._tags) + self.assertIsNone(self.tagger._mtime) self.assertTrue(self.tagger.ready.is_set()) mock_file.return_value = None self.tagger.ready.set() self.tagger._tags = 12345 - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) + self.tagger.tag_file(path, None) self.assertFalse(self.tagger.ready.is_set()) - self.assertEqual(self.tagger._file, pathlib.Path("/a/b/c.ogg")) + self.assertEqual(self.tagger._file, path) + self.assertIsNone(self.tagger._mtime) self.assertIsNone(self.tagger._tags) self.tagger.ready.wait() self.assertIsNone(self.tagger._tags) mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None) - mock_file.return_value = self.tags - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) + mock_file.return_value = self.make_tags(dict()) + self.tagger.tag_file(path, 12345) + self.assertEqual(self.tagger._mtime, 12345) + self.tagger.ready.wait() self.assertIsNotNone(self.tagger._tags) + mock_file.assert_called_with(self.tagger._file, 12345) def test_get_result(self, mock_file: unittest.mock.Mock): """Test creating a Tags structure after tagging.""" mock_file.return_value = None - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) - self.assertTupleEqual(self.tagger.get_result(self.sql), (None, None)) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) + self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), + (None, None)) self.tagger.ready.wait() - self.assertTupleEqual(self.tagger.get_result(self.sql), + self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), (pathlib.Path("/a/b/c.ogg"), None)) self.assertIsNone(self.tagger._file) mock_file.return_value = self.make_tags(dict()) - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.ready.wait() - (file, tags) = self.tagger.get_result(self.sql) + (file, tags) = self.tagger.get_result(self.sql, self.library) self.assertEqual(file, pathlib.Path("/a/b/c.ogg")) self.assertIsInstance(tags, emmental.db.tagger.Tags) self.assertIsNone(self.tagger._file) + self.assertIsNone(self.tagger._mtime) self.assertIsNone(self.tagger._tags) @unittest.mock.patch("emmental.db.connection.Connection.__call__") @@ -251,7 +370,7 @@ class TestTaggerThread(tests.util.TestCase): mock_cursor.fetchone = unittest.mock.Mock(return_value=None) mock_connection.return_value = mock_cursor - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.ready.wait() self.assertEqual(audio_tags.artists[0].name, "Some Artist") self.assertEqual(audio_tags.artists[1].name, "Some Artist") @@ -275,7 +394,7 @@ class TestTaggerThread(tests.util.TestCase): self.assertIsNone(self.tagger._connection) - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg")) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.ready.wait() self.assertIsInstance(self.tagger._connection, emmental.db.connection.Connection)