diff --git a/emmental/db/libraries.py b/emmental/db/libraries.py index abf0d75..e24c33e 100644 --- a/emmental/db/libraries.py +++ b/emmental/db/libraries.py @@ -55,10 +55,11 @@ class Library(playlist.Playlist): def __tag_track(self, path: pathlib.Path) -> bool: if self.tagger.ready.is_set(): - (file, tags) = self.tagger.get_result(self.table.sql, self) - if file is None: + result = self.tagger.get_result(db=self.table.sql, library=self) + if result is None: track = self.table.sql.tracks.lookup(self, path=path) - self.tagger.tag_file(path, track.mtime if track else None) + mtime = track.mtime if track else None + self.tagger.tag_file(path, mtime=mtime) else: return True return False diff --git a/emmental/db/tagger.py b/emmental/db/tagger.py index 240cfa2..ee01cf0 100644 --- a/emmental/db/tagger.py +++ b/emmental/db/tagger.py @@ -3,9 +3,9 @@ import emmental.audio.tagger import musicbrainzngs import pathlib -import threading from gi.repository import GObject from .. import audio +from .. import thread from . import albums from . import artists from . import connection @@ -178,24 +178,12 @@ class Tags: return year if year else self.db.years.create(raw_year) -class Thread(threading.Thread): +class Thread(thread.Thread): """A thread for tagging files without blocking the UI.""" def __init__(self): """Initialize the Tagger Thread.""" super().__init__() - self.ready = threading.Event() - - self._connection = None - self._condition = threading.Condition() - self._file = None - self._mtime = None - self._tags = None - self.start() - - def __close_connection(self) -> None: - if self._connection: - self._connection.close() self._connection = None def __get_connection(self) -> connection.Connection: @@ -213,55 +201,31 @@ class Thread(threading.Thread): mb_res = musicbrainzngs.get_artist_by_id(artist.mbid) artist.name = mb_res["artist"]["name"] - def get_result(self, db: GObject.TYPE_PYOBJECT, - library: playlist.Playlist) \ - -> tuple[pathlib.Path | None, Tags | None]: + def do_get_result(self, result: thread.Data, db: GObject.TYPE_PYOBJECT, + library: playlist.Playlist) -> tuple: """Return the resulting Tags structure.""" - with self._condition: - if not self.ready.is_set(): - return (None, None) + tags = None if result.tags is None else Tags(db, result.tags, library) + return (result.path, tags) - tags = Tags(db, self._tags, library) if self._tags else None - res = (self._file, tags) - self._file = None - self._tags = None - return res - - def run(self) -> None: - """Sleep until we have work to do.""" - with self._condition: - self.ready.set() - - while self._condition.wait(): - if self._file is None: - break - - tags = emmental.audio.tagger.tag_file(self._file, self._mtime) - if tags is not None: - for artist in tags.artists: - self.__check_artist(artist) - - self._tags = tags - self.ready.set() - - self.__close_connection() - - def stop(self) -> None: - """Stop the thread.""" - with self._condition: - self._file = None - self._mtime = None - self._condition.notify() - self.join() - - def tag_file(self, file: pathlib.Path, mtime: float | None) -> None: + def do_run_task(self, task: thread.Data) -> None: """Tag a file.""" - with self._condition: - self.ready.clear() - self._file = file - self._mtime = mtime - self._tags = None - self._condition.notify() + tags = emmental.audio.tagger.tag_file(task.path, task.mtime) + if tags is not None: + for artist in tags.artists: + self.__check_artist(artist) + + self.set_result(path=task.path, tags=tags) + + def do_stop(self) -> None: + """Close the connection before stopping.""" + if self._connection: + self._connection.close() + self._connection = None + + def tag_file(self, path: pathlib.Path, + *, mtime: float | None = None) -> None: + """Tag a file.""" + self.set_task(path=path, mtime=mtime) def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None: diff --git a/tests/db/test_libraries.py b/tests/db/test_libraries.py index 81457e0..f60c933 100644 --- a/tests/db/test_libraries.py +++ b/tests/db/test_libraries.py @@ -182,23 +182,23 @@ class TestLibraryObject(tests.util.TestCase): tagger.tag_file.assert_not_called() tagger.ready.is_set.return_value = True - tagger.get_result.return_value = (None, None) + tagger.get_result.return_value = None self.assertFalse(self.library._Library__tag_track(track)) - tagger.get_result.assert_called_with(self.sql, self.library) - tagger.tag_file.assert_called_with(track, None) + tagger.get_result.assert_called_with(db=self.sql, library=self.library) + tagger.tag_file.assert_called_with(track, mtime=None) self.sql.tracks.lookup = unittest.mock.Mock() self.sql.tracks.lookup.return_value.mtime = 12345 self.assertFalse(self.library._Library__tag_track(track)) - tagger.get_result.assert_called_with(self.sql, self.library) - tagger.tag_file.assert_called_with(track, 12345) + tagger.get_result.assert_called_with(db=self.sql, library=self.library) + tagger.tag_file.assert_called_with(track, mtime=12345) tagger.reset_mock() tagger.ready.is_set.return_value = True - tagger.get_result.return_value = (track, tags) + tagger.get_result.return_value = {"path": track, "tags": tags} self.assertTrue(self.library._Library__tag_track(track)) tagger.tag_file.assert_not_called() - tagger.get_result.assert_called_with(self.sql, self.library) + tagger.get_result.assert_called_with(db=self.sql, library=self.library) @unittest.mock.patch("emmental.db.tagger.untag_track") def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()): diff --git a/tests/db/test_tagger.py b/tests/db/test_tagger.py index 0450952..8fe3a3d 100644 --- a/tests/db/test_tagger.py +++ b/tests/db/test_tagger.py @@ -1,9 +1,9 @@ # Copyright 2022 (c) Anna Schumaker """Tests our Mutagen wrapper.""" import pathlib -import threading import unittest.mock import emmental.db.tagger +import emmental.thread import tests.util @@ -276,8 +276,8 @@ class TestTaggerThread(tests.util.TestCase): def test_init(self, mock_file: unittest.mock.Mock): """Test that the tagger thread was initialized properly.""" - self.assertIsInstance(self.tagger, threading.Thread) - self.assertIsInstance(self.tagger._condition, threading.Condition) + self.assertIsInstance(self.tagger, emmental.thread.Thread) + self.assertIsNone(self.tagger._connection) self.assertTrue(self.tagger.is_alive()) def test_stop(self, mock_file: unittest.mock.Mock): @@ -285,74 +285,49 @@ class TestTaggerThread(tests.util.TestCase): mock_connection = unittest.mock.Mock() mock_connection.close = unittest.mock.Mock() - self.tagger._file = "abcde" - self.tagger._mtime = 12345 self.tagger._connection = mock_connection + self.tagger.stop() - with unittest.mock.patch.object(self.tagger._condition, "notify", - wraps=self.tagger._condition.notify) \ - as mock_notify: - self.tagger.stop() - self.assertIsNone(self.tagger._file) - self.assertIsNone(self.tagger._mtime) - mock_notify.assert_called() - - self.assertFalse(self.tagger.is_alive()) self.assertIsNone(self.tagger._connection) mock_connection.close.assert_called() def test_tag_file(self, mock_file: unittest.mock.Mock): """Test asking the thread to tag a file.""" path = pathlib.Path("/a/b/c.ogg") - - self.assertIsInstance(self.tagger.ready, threading.Event) - self.assertIsNone(self.tagger._file) - self.assertIsNone(self.tagger._tags) - self.assertIsNone(self.tagger._mtime) - self.assertTrue(self.tagger.ready.is_set()) - mock_file.return_value = None + self.tagger.ready.set() - self.tagger._tags = 12345 - self.tagger.tag_file(path, None) - self.assertFalse(self.tagger.ready.is_set()) - self.assertEqual(self.tagger._file, path) - self.assertIsNone(self.tagger._mtime) - self.assertIsNone(self.tagger._tags) + self.tagger.tag_file(path, mtime=None) + self.assertEqual(self.tagger._task, {"path": path, "mtime": None}) self.tagger.ready.wait() - self.assertIsNone(self.tagger._tags) mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None) mock_file.return_value = self.make_tags(dict()) - self.tagger.tag_file(path, 12345) - self.assertEqual(self.tagger._mtime, 12345) + self.tagger.tag_file(path, mtime=12345) + self.assertEqual(self.tagger._task, {"path": path, "mtime": 12345}) self.tagger.ready.wait() - self.assertIsNotNone(self.tagger._tags) - mock_file.assert_called_with(self.tagger._file, 12345) + mock_file.assert_called_with(path, 12345) def test_get_result(self, mock_file: unittest.mock.Mock): """Test creating a Tags structure after tagging.""" mock_file.return_value = None - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) - self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), - (None, None)) + self.assertIsNone(self.tagger.get_result(db=self.sql, + library=self.library)) + track_path = pathlib.Path("/a/b/c.ogg") + self.tagger.tag_file(track_path, mtime=None) self.tagger.ready.wait() - self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), - (pathlib.Path("/a/b/c.ogg"), None)) - self.assertIsNone(self.tagger._file) + self.assertTupleEqual(self.tagger.get_result(db=self.sql, + library=self.library), + (track_path, None)) mock_file.return_value = self.make_tags(dict()) - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) + self.tagger.tag_file(track_path, mtime=None) self.tagger.ready.wait() - (file, tags) = self.tagger.get_result(self.sql, self.library) - self.assertEqual(file, pathlib.Path("/a/b/c.ogg")) - self.assertIsInstance(tags, emmental.db.tagger.Tags) - self.assertIsNone(self.tagger._file) - self.assertIsNone(self.tagger._mtime) - self.assertIsNone(self.tagger._tags) + res = self.tagger.get_result(db=self.sql, library=self.library) + self.assertTupleEqual(res, (track_path, res[1])) @unittest.mock.patch("emmental.db.connection.Connection.__call__") @unittest.mock.patch("musicbrainzngs.get_artist_by_id") @@ -370,7 +345,7 @@ class TestTaggerThread(tests.util.TestCase): mock_cursor.fetchone = unittest.mock.Mock(return_value=None) mock_connection.return_value = mock_cursor - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None) self.tagger.ready.wait() self.assertEqual(audio_tags.artists[0].name, "Some Artist") self.assertEqual(audio_tags.artists[1].name, "Some Artist") @@ -394,7 +369,7 @@ class TestTaggerThread(tests.util.TestCase): self.assertIsNone(self.tagger._connection) - self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) + self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None) self.tagger.ready.wait() self.assertIsInstance(self.tagger._connection, emmental.db.connection.Connection)