db: Convert the tagger to the new Thread class

This lets us do a lot of the basic Thread operations through common
code, allowing us to focus on tagging in this file instead of basic
Thread controls.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2024-02-04 17:15:02 -05:00
parent 1db187dba5
commit d373c33283
4 changed files with 57 additions and 117 deletions

View File

@ -55,10 +55,11 @@ class Library(playlist.Playlist):
def __tag_track(self, path: pathlib.Path) -> bool:
if self.tagger.ready.is_set():
(file, tags) = self.tagger.get_result(self.table.sql, self)
if file is None:
result = self.tagger.get_result(db=self.table.sql, library=self)
if result is None:
track = self.table.sql.tracks.lookup(self, path=path)
self.tagger.tag_file(path, track.mtime if track else None)
mtime = track.mtime if track else None
self.tagger.tag_file(path, mtime=mtime)
else:
return True
return False

View File

@ -3,9 +3,9 @@
import emmental.audio.tagger
import musicbrainzngs
import pathlib
import threading
from gi.repository import GObject
from .. import audio
from .. import thread
from . import albums
from . import artists
from . import connection
@ -178,24 +178,12 @@ class Tags:
return year if year else self.db.years.create(raw_year)
class Thread(threading.Thread):
class Thread(thread.Thread):
"""A thread for tagging files without blocking the UI."""
def __init__(self):
"""Initialize the Tagger Thread."""
super().__init__()
self.ready = threading.Event()
self._connection = None
self._condition = threading.Condition()
self._file = None
self._mtime = None
self._tags = None
self.start()
def __close_connection(self) -> None:
if self._connection:
self._connection.close()
self._connection = None
def __get_connection(self) -> connection.Connection:
@ -213,55 +201,31 @@ class Thread(threading.Thread):
mb_res = musicbrainzngs.get_artist_by_id(artist.mbid)
artist.name = mb_res["artist"]["name"]
def get_result(self, db: GObject.TYPE_PYOBJECT,
library: playlist.Playlist) \
-> tuple[pathlib.Path | None, Tags | None]:
def do_get_result(self, result: thread.Data, db: GObject.TYPE_PYOBJECT,
library: playlist.Playlist) -> tuple:
"""Return the resulting Tags structure."""
with self._condition:
if not self.ready.is_set():
return (None, None)
tags = None if result.tags is None else Tags(db, result.tags, library)
return (result.path, tags)
tags = Tags(db, self._tags, library) if self._tags else None
res = (self._file, tags)
self._file = None
self._tags = None
return res
def run(self) -> None:
"""Sleep until we have work to do."""
with self._condition:
self.ready.set()
while self._condition.wait():
if self._file is None:
break
tags = emmental.audio.tagger.tag_file(self._file, self._mtime)
if tags is not None:
for artist in tags.artists:
self.__check_artist(artist)
self._tags = tags
self.ready.set()
self.__close_connection()
def stop(self) -> None:
"""Stop the thread."""
with self._condition:
self._file = None
self._mtime = None
self._condition.notify()
self.join()
def tag_file(self, file: pathlib.Path, mtime: float | None) -> None:
def do_run_task(self, task: thread.Data) -> None:
"""Tag a file."""
with self._condition:
self.ready.clear()
self._file = file
self._mtime = mtime
self._tags = None
self._condition.notify()
tags = emmental.audio.tagger.tag_file(task.path, task.mtime)
if tags is not None:
for artist in tags.artists:
self.__check_artist(artist)
self.set_result(path=task.path, tags=tags)
def do_stop(self) -> None:
"""Close the connection before stopping."""
if self._connection:
self._connection.close()
self._connection = None
def tag_file(self, path: pathlib.Path,
*, mtime: float | None = None) -> None:
"""Tag a file."""
self.set_task(path=path, mtime=mtime)
def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None:

View File

@ -182,23 +182,23 @@ class TestLibraryObject(tests.util.TestCase):
tagger.tag_file.assert_not_called()
tagger.ready.is_set.return_value = True
tagger.get_result.return_value = (None, None)
tagger.get_result.return_value = None
self.assertFalse(self.library._Library__tag_track(track))
tagger.get_result.assert_called_with(self.sql, self.library)
tagger.tag_file.assert_called_with(track, None)
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
tagger.tag_file.assert_called_with(track, mtime=None)
self.sql.tracks.lookup = unittest.mock.Mock()
self.sql.tracks.lookup.return_value.mtime = 12345
self.assertFalse(self.library._Library__tag_track(track))
tagger.get_result.assert_called_with(self.sql, self.library)
tagger.tag_file.assert_called_with(track, 12345)
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
tagger.tag_file.assert_called_with(track, mtime=12345)
tagger.reset_mock()
tagger.ready.is_set.return_value = True
tagger.get_result.return_value = (track, tags)
tagger.get_result.return_value = {"path": track, "tags": tags}
self.assertTrue(self.library._Library__tag_track(track))
tagger.tag_file.assert_not_called()
tagger.get_result.assert_called_with(self.sql, self.library)
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
@unittest.mock.patch("emmental.db.tagger.untag_track")
def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()):

View File

@ -1,9 +1,9 @@
# Copyright 2022 (c) Anna Schumaker
"""Tests our Mutagen wrapper."""
import pathlib
import threading
import unittest.mock
import emmental.db.tagger
import emmental.thread
import tests.util
@ -276,8 +276,8 @@ class TestTaggerThread(tests.util.TestCase):
def test_init(self, mock_file: unittest.mock.Mock):
"""Test that the tagger thread was initialized properly."""
self.assertIsInstance(self.tagger, threading.Thread)
self.assertIsInstance(self.tagger._condition, threading.Condition)
self.assertIsInstance(self.tagger, emmental.thread.Thread)
self.assertIsNone(self.tagger._connection)
self.assertTrue(self.tagger.is_alive())
def test_stop(self, mock_file: unittest.mock.Mock):
@ -285,74 +285,49 @@ class TestTaggerThread(tests.util.TestCase):
mock_connection = unittest.mock.Mock()
mock_connection.close = unittest.mock.Mock()
self.tagger._file = "abcde"
self.tagger._mtime = 12345
self.tagger._connection = mock_connection
self.tagger.stop()
with unittest.mock.patch.object(self.tagger._condition, "notify",
wraps=self.tagger._condition.notify) \
as mock_notify:
self.tagger.stop()
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._mtime)
mock_notify.assert_called()
self.assertFalse(self.tagger.is_alive())
self.assertIsNone(self.tagger._connection)
mock_connection.close.assert_called()
def test_tag_file(self, mock_file: unittest.mock.Mock):
"""Test asking the thread to tag a file."""
path = pathlib.Path("/a/b/c.ogg")
self.assertIsInstance(self.tagger.ready, threading.Event)
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._tags)
self.assertIsNone(self.tagger._mtime)
self.assertTrue(self.tagger.ready.is_set())
mock_file.return_value = None
self.tagger.ready.set()
self.tagger._tags = 12345
self.tagger.tag_file(path, None)
self.assertFalse(self.tagger.ready.is_set())
self.assertEqual(self.tagger._file, path)
self.assertIsNone(self.tagger._mtime)
self.assertIsNone(self.tagger._tags)
self.tagger.tag_file(path, mtime=None)
self.assertEqual(self.tagger._task, {"path": path, "mtime": None})
self.tagger.ready.wait()
self.assertIsNone(self.tagger._tags)
mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None)
mock_file.return_value = self.make_tags(dict())
self.tagger.tag_file(path, 12345)
self.assertEqual(self.tagger._mtime, 12345)
self.tagger.tag_file(path, mtime=12345)
self.assertEqual(self.tagger._task, {"path": path, "mtime": 12345})
self.tagger.ready.wait()
self.assertIsNotNone(self.tagger._tags)
mock_file.assert_called_with(self.tagger._file, 12345)
mock_file.assert_called_with(path, 12345)
def test_get_result(self, mock_file: unittest.mock.Mock):
"""Test creating a Tags structure after tagging."""
mock_file.return_value = None
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
(None, None))
self.assertIsNone(self.tagger.get_result(db=self.sql,
library=self.library))
track_path = pathlib.Path("/a/b/c.ogg")
self.tagger.tag_file(track_path, mtime=None)
self.tagger.ready.wait()
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
(pathlib.Path("/a/b/c.ogg"), None))
self.assertIsNone(self.tagger._file)
self.assertTupleEqual(self.tagger.get_result(db=self.sql,
library=self.library),
(track_path, None))
mock_file.return_value = self.make_tags(dict())
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
self.tagger.tag_file(track_path, mtime=None)
self.tagger.ready.wait()
(file, tags) = self.tagger.get_result(self.sql, self.library)
self.assertEqual(file, pathlib.Path("/a/b/c.ogg"))
self.assertIsInstance(tags, emmental.db.tagger.Tags)
self.assertIsNone(self.tagger._file)
self.assertIsNone(self.tagger._mtime)
self.assertIsNone(self.tagger._tags)
res = self.tagger.get_result(db=self.sql, library=self.library)
self.assertTupleEqual(res, (track_path, res[1]))
@unittest.mock.patch("emmental.db.connection.Connection.__call__")
@unittest.mock.patch("musicbrainzngs.get_artist_by_id")
@ -370,7 +345,7 @@ class TestTaggerThread(tests.util.TestCase):
mock_cursor.fetchone = unittest.mock.Mock(return_value=None)
mock_connection.return_value = mock_cursor
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
self.tagger.ready.wait()
self.assertEqual(audio_tags.artists[0].name, "Some Artist")
self.assertEqual(audio_tags.artists[1].name, "Some Artist")
@ -394,7 +369,7 @@ class TestTaggerThread(tests.util.TestCase):
self.assertIsNone(self.tagger._connection)
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
self.tagger.ready.wait()
self.assertIsInstance(self.tagger._connection,
emmental.db.connection.Connection)