db: Convert the tagger to the new Thread class

This lets us do a lot of the basic Thread operations through common
code, allowing us to focus on tagging in this file instead of basic
Thread controls.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2024-02-04 17:15:02 -05:00
parent 1db187dba5
commit d373c33283
4 changed files with 57 additions and 117 deletions

View File

@ -55,10 +55,11 @@ class Library(playlist.Playlist):
def __tag_track(self, path: pathlib.Path) -> bool: def __tag_track(self, path: pathlib.Path) -> bool:
if self.tagger.ready.is_set(): if self.tagger.ready.is_set():
(file, tags) = self.tagger.get_result(self.table.sql, self) result = self.tagger.get_result(db=self.table.sql, library=self)
if file is None: if result is None:
track = self.table.sql.tracks.lookup(self, path=path) track = self.table.sql.tracks.lookup(self, path=path)
self.tagger.tag_file(path, track.mtime if track else None) mtime = track.mtime if track else None
self.tagger.tag_file(path, mtime=mtime)
else: else:
return True return True
return False return False

View File

@ -3,9 +3,9 @@
import emmental.audio.tagger import emmental.audio.tagger
import musicbrainzngs import musicbrainzngs
import pathlib import pathlib
import threading
from gi.repository import GObject from gi.repository import GObject
from .. import audio from .. import audio
from .. import thread
from . import albums from . import albums
from . import artists from . import artists
from . import connection from . import connection
@ -178,24 +178,12 @@ class Tags:
return year if year else self.db.years.create(raw_year) return year if year else self.db.years.create(raw_year)
class Thread(threading.Thread): class Thread(thread.Thread):
"""A thread for tagging files without blocking the UI.""" """A thread for tagging files without blocking the UI."""
def __init__(self): def __init__(self):
"""Initialize the Tagger Thread.""" """Initialize the Tagger Thread."""
super().__init__() super().__init__()
self.ready = threading.Event()
self._connection = None
self._condition = threading.Condition()
self._file = None
self._mtime = None
self._tags = None
self.start()
def __close_connection(self) -> None:
if self._connection:
self._connection.close()
self._connection = None self._connection = None
def __get_connection(self) -> connection.Connection: def __get_connection(self) -> connection.Connection:
@ -213,55 +201,31 @@ class Thread(threading.Thread):
mb_res = musicbrainzngs.get_artist_by_id(artist.mbid) mb_res = musicbrainzngs.get_artist_by_id(artist.mbid)
artist.name = mb_res["artist"]["name"] artist.name = mb_res["artist"]["name"]
def get_result(self, db: GObject.TYPE_PYOBJECT, def do_get_result(self, result: thread.Data, db: GObject.TYPE_PYOBJECT,
library: playlist.Playlist) \ library: playlist.Playlist) -> tuple:
-> tuple[pathlib.Path | None, Tags | None]:
"""Return the resulting Tags structure.""" """Return the resulting Tags structure."""
with self._condition: tags = None if result.tags is None else Tags(db, result.tags, library)
if not self.ready.is_set(): return (result.path, tags)
return (None, None)
tags = Tags(db, self._tags, library) if self._tags else None def do_run_task(self, task: thread.Data) -> None:
res = (self._file, tags)
self._file = None
self._tags = None
return res
def run(self) -> None:
"""Sleep until we have work to do."""
with self._condition:
self.ready.set()
while self._condition.wait():
if self._file is None:
break
tags = emmental.audio.tagger.tag_file(self._file, self._mtime)
if tags is not None:
for artist in tags.artists:
self.__check_artist(artist)
self._tags = tags
self.ready.set()
self.__close_connection()
def stop(self) -> None:
"""Stop the thread."""
with self._condition:
self._file = None
self._mtime = None
self._condition.notify()
self.join()
def tag_file(self, file: pathlib.Path, mtime: float | None) -> None:
"""Tag a file.""" """Tag a file."""
with self._condition: tags = emmental.audio.tagger.tag_file(task.path, task.mtime)
self.ready.clear() if tags is not None:
self._file = file for artist in tags.artists:
self._mtime = mtime self.__check_artist(artist)
self._tags = None
self._condition.notify() self.set_result(path=task.path, tags=tags)
def do_stop(self) -> None:
"""Close the connection before stopping."""
if self._connection:
self._connection.close()
self._connection = None
def tag_file(self, path: pathlib.Path,
*, mtime: float | None = None) -> None:
"""Tag a file."""
self.set_task(path=path, mtime=mtime)
def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None: def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None:

View File

@ -182,23 +182,23 @@ class TestLibraryObject(tests.util.TestCase):
tagger.tag_file.assert_not_called() tagger.tag_file.assert_not_called()
tagger.ready.is_set.return_value = True tagger.ready.is_set.return_value = True
tagger.get_result.return_value = (None, None) tagger.get_result.return_value = None
self.assertFalse(self.library._Library__tag_track(track)) self.assertFalse(self.library._Library__tag_track(track))
tagger.get_result.assert_called_with(self.sql, self.library) tagger.get_result.assert_called_with(db=self.sql, library=self.library)
tagger.tag_file.assert_called_with(track, None) tagger.tag_file.assert_called_with(track, mtime=None)
self.sql.tracks.lookup = unittest.mock.Mock() self.sql.tracks.lookup = unittest.mock.Mock()
self.sql.tracks.lookup.return_value.mtime = 12345 self.sql.tracks.lookup.return_value.mtime = 12345
self.assertFalse(self.library._Library__tag_track(track)) self.assertFalse(self.library._Library__tag_track(track))
tagger.get_result.assert_called_with(self.sql, self.library) tagger.get_result.assert_called_with(db=self.sql, library=self.library)
tagger.tag_file.assert_called_with(track, 12345) tagger.tag_file.assert_called_with(track, mtime=12345)
tagger.reset_mock() tagger.reset_mock()
tagger.ready.is_set.return_value = True tagger.ready.is_set.return_value = True
tagger.get_result.return_value = (track, tags) tagger.get_result.return_value = {"path": track, "tags": tags}
self.assertTrue(self.library._Library__tag_track(track)) self.assertTrue(self.library._Library__tag_track(track))
tagger.tag_file.assert_not_called() tagger.tag_file.assert_not_called()
tagger.get_result.assert_called_with(self.sql, self.library) tagger.get_result.assert_called_with(db=self.sql, library=self.library)
@unittest.mock.patch("emmental.db.tagger.untag_track") @unittest.mock.patch("emmental.db.tagger.untag_track")
def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()): def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()):

View File

@ -1,9 +1,9 @@
# Copyright 2022 (c) Anna Schumaker # Copyright 2022 (c) Anna Schumaker
"""Tests our Mutagen wrapper.""" """Tests our Mutagen wrapper."""
import pathlib import pathlib
import threading
import unittest.mock import unittest.mock
import emmental.db.tagger import emmental.db.tagger
import emmental.thread
import tests.util import tests.util
@ -276,8 +276,8 @@ class TestTaggerThread(tests.util.TestCase):
def test_init(self, mock_file: unittest.mock.Mock): def test_init(self, mock_file: unittest.mock.Mock):
"""Test that the tagger thread was initialized properly.""" """Test that the tagger thread was initialized properly."""
self.assertIsInstance(self.tagger, threading.Thread) self.assertIsInstance(self.tagger, emmental.thread.Thread)
self.assertIsInstance(self.tagger._condition, threading.Condition) self.assertIsNone(self.tagger._connection)
self.assertTrue(self.tagger.is_alive()) self.assertTrue(self.tagger.is_alive())
def test_stop(self, mock_file: unittest.mock.Mock): def test_stop(self, mock_file: unittest.mock.Mock):
@ -285,74 +285,49 @@ class TestTaggerThread(tests.util.TestCase):
mock_connection = unittest.mock.Mock() mock_connection = unittest.mock.Mock()
mock_connection.close = unittest.mock.Mock() mock_connection.close = unittest.mock.Mock()
self.tagger._file = "abcde"
self.tagger._mtime = 12345
self.tagger._connection = mock_connection self.tagger._connection = mock_connection
self.tagger.stop()
with unittest.mock.patch.object(self.tagger._condition, "notify",
wraps=self.tagger._condition.notify) \
as mock_notify:
self.tagger.stop()
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._mtime)
mock_notify.assert_called()
self.assertFalse(self.tagger.is_alive())
self.assertIsNone(self.tagger._connection) self.assertIsNone(self.tagger._connection)
mock_connection.close.assert_called() mock_connection.close.assert_called()
def test_tag_file(self, mock_file: unittest.mock.Mock): def test_tag_file(self, mock_file: unittest.mock.Mock):
"""Test asking the thread to tag a file.""" """Test asking the thread to tag a file."""
path = pathlib.Path("/a/b/c.ogg") path = pathlib.Path("/a/b/c.ogg")
self.assertIsInstance(self.tagger.ready, threading.Event)
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._tags)
self.assertIsNone(self.tagger._mtime)
self.assertTrue(self.tagger.ready.is_set())
mock_file.return_value = None mock_file.return_value = None
self.tagger.ready.set() self.tagger.ready.set()
self.tagger._tags = 12345 self.tagger.tag_file(path, mtime=None)
self.tagger.tag_file(path, None) self.assertEqual(self.tagger._task, {"path": path, "mtime": None})
self.assertFalse(self.tagger.ready.is_set())
self.assertEqual(self.tagger._file, path)
self.assertIsNone(self.tagger._mtime)
self.assertIsNone(self.tagger._tags)
self.tagger.ready.wait() self.tagger.ready.wait()
self.assertIsNone(self.tagger._tags)
mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None) mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None)
mock_file.return_value = self.make_tags(dict()) mock_file.return_value = self.make_tags(dict())
self.tagger.tag_file(path, 12345) self.tagger.tag_file(path, mtime=12345)
self.assertEqual(self.tagger._mtime, 12345) self.assertEqual(self.tagger._task, {"path": path, "mtime": 12345})
self.tagger.ready.wait() self.tagger.ready.wait()
self.assertIsNotNone(self.tagger._tags) mock_file.assert_called_with(path, 12345)
mock_file.assert_called_with(self.tagger._file, 12345)
def test_get_result(self, mock_file: unittest.mock.Mock): def test_get_result(self, mock_file: unittest.mock.Mock):
"""Test creating a Tags structure after tagging.""" """Test creating a Tags structure after tagging."""
mock_file.return_value = None mock_file.return_value = None
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.assertIsNone(self.tagger.get_result(db=self.sql,
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), library=self.library))
(None, None))
track_path = pathlib.Path("/a/b/c.ogg")
self.tagger.tag_file(track_path, mtime=None)
self.tagger.ready.wait() self.tagger.ready.wait()
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library), self.assertTupleEqual(self.tagger.get_result(db=self.sql,
(pathlib.Path("/a/b/c.ogg"), None)) library=self.library),
self.assertIsNone(self.tagger._file) (track_path, None))
mock_file.return_value = self.make_tags(dict()) mock_file.return_value = self.make_tags(dict())
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.tag_file(track_path, mtime=None)
self.tagger.ready.wait() self.tagger.ready.wait()
(file, tags) = self.tagger.get_result(self.sql, self.library) res = self.tagger.get_result(db=self.sql, library=self.library)
self.assertEqual(file, pathlib.Path("/a/b/c.ogg")) self.assertTupleEqual(res, (track_path, res[1]))
self.assertIsInstance(tags, emmental.db.tagger.Tags)
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._mtime)
self.assertIsNone(self.tagger._tags)
@unittest.mock.patch("emmental.db.connection.Connection.__call__") @unittest.mock.patch("emmental.db.connection.Connection.__call__")
@unittest.mock.patch("musicbrainzngs.get_artist_by_id") @unittest.mock.patch("musicbrainzngs.get_artist_by_id")
@ -370,7 +345,7 @@ class TestTaggerThread(tests.util.TestCase):
mock_cursor.fetchone = unittest.mock.Mock(return_value=None) mock_cursor.fetchone = unittest.mock.Mock(return_value=None)
mock_connection.return_value = mock_cursor mock_connection.return_value = mock_cursor
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
self.tagger.ready.wait() self.tagger.ready.wait()
self.assertEqual(audio_tags.artists[0].name, "Some Artist") self.assertEqual(audio_tags.artists[0].name, "Some Artist")
self.assertEqual(audio_tags.artists[1].name, "Some Artist") self.assertEqual(audio_tags.artists[1].name, "Some Artist")
@ -394,7 +369,7 @@ class TestTaggerThread(tests.util.TestCase):
self.assertIsNone(self.tagger._connection) self.assertIsNone(self.tagger._connection)
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None) self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
self.tagger.ready.wait() self.tagger.ready.wait()
self.assertIsInstance(self.tagger._connection, self.assertIsInstance(self.tagger._connection,
emmental.db.connection.Connection) emmental.db.connection.Connection)