db: Convert the tagger to the new Thread class
This lets us do a lot of the basic Thread operations through common code, allowing us to focus on tagging in this file instead of basic Thread controls. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
1db187dba5
commit
d373c33283
|
@ -55,10 +55,11 @@ class Library(playlist.Playlist):
|
|||
|
||||
def __tag_track(self, path: pathlib.Path) -> bool:
|
||||
if self.tagger.ready.is_set():
|
||||
(file, tags) = self.tagger.get_result(self.table.sql, self)
|
||||
if file is None:
|
||||
result = self.tagger.get_result(db=self.table.sql, library=self)
|
||||
if result is None:
|
||||
track = self.table.sql.tracks.lookup(self, path=path)
|
||||
self.tagger.tag_file(path, track.mtime if track else None)
|
||||
mtime = track.mtime if track else None
|
||||
self.tagger.tag_file(path, mtime=mtime)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
import emmental.audio.tagger
|
||||
import musicbrainzngs
|
||||
import pathlib
|
||||
import threading
|
||||
from gi.repository import GObject
|
||||
from .. import audio
|
||||
from .. import thread
|
||||
from . import albums
|
||||
from . import artists
|
||||
from . import connection
|
||||
|
@ -178,24 +178,12 @@ class Tags:
|
|||
return year if year else self.db.years.create(raw_year)
|
||||
|
||||
|
||||
class Thread(threading.Thread):
|
||||
class Thread(thread.Thread):
|
||||
"""A thread for tagging files without blocking the UI."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Tagger Thread."""
|
||||
super().__init__()
|
||||
self.ready = threading.Event()
|
||||
|
||||
self._connection = None
|
||||
self._condition = threading.Condition()
|
||||
self._file = None
|
||||
self._mtime = None
|
||||
self._tags = None
|
||||
self.start()
|
||||
|
||||
def __close_connection(self) -> None:
|
||||
if self._connection:
|
||||
self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
def __get_connection(self) -> connection.Connection:
|
||||
|
@ -213,55 +201,31 @@ class Thread(threading.Thread):
|
|||
mb_res = musicbrainzngs.get_artist_by_id(artist.mbid)
|
||||
artist.name = mb_res["artist"]["name"]
|
||||
|
||||
def get_result(self, db: GObject.TYPE_PYOBJECT,
|
||||
library: playlist.Playlist) \
|
||||
-> tuple[pathlib.Path | None, Tags | None]:
|
||||
def do_get_result(self, result: thread.Data, db: GObject.TYPE_PYOBJECT,
|
||||
library: playlist.Playlist) -> tuple:
|
||||
"""Return the resulting Tags structure."""
|
||||
with self._condition:
|
||||
if not self.ready.is_set():
|
||||
return (None, None)
|
||||
tags = None if result.tags is None else Tags(db, result.tags, library)
|
||||
return (result.path, tags)
|
||||
|
||||
tags = Tags(db, self._tags, library) if self._tags else None
|
||||
res = (self._file, tags)
|
||||
self._file = None
|
||||
self._tags = None
|
||||
return res
|
||||
|
||||
def run(self) -> None:
|
||||
"""Sleep until we have work to do."""
|
||||
with self._condition:
|
||||
self.ready.set()
|
||||
|
||||
while self._condition.wait():
|
||||
if self._file is None:
|
||||
break
|
||||
|
||||
tags = emmental.audio.tagger.tag_file(self._file, self._mtime)
|
||||
if tags is not None:
|
||||
for artist in tags.artists:
|
||||
self.__check_artist(artist)
|
||||
|
||||
self._tags = tags
|
||||
self.ready.set()
|
||||
|
||||
self.__close_connection()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the thread."""
|
||||
with self._condition:
|
||||
self._file = None
|
||||
self._mtime = None
|
||||
self._condition.notify()
|
||||
self.join()
|
||||
|
||||
def tag_file(self, file: pathlib.Path, mtime: float | None) -> None:
|
||||
def do_run_task(self, task: thread.Data) -> None:
|
||||
"""Tag a file."""
|
||||
with self._condition:
|
||||
self.ready.clear()
|
||||
self._file = file
|
||||
self._mtime = mtime
|
||||
self._tags = None
|
||||
self._condition.notify()
|
||||
tags = emmental.audio.tagger.tag_file(task.path, task.mtime)
|
||||
if tags is not None:
|
||||
for artist in tags.artists:
|
||||
self.__check_artist(artist)
|
||||
|
||||
self.set_result(path=task.path, tags=tags)
|
||||
|
||||
def do_stop(self) -> None:
|
||||
"""Close the connection before stopping."""
|
||||
if self._connection:
|
||||
self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
def tag_file(self, path: pathlib.Path,
|
||||
*, mtime: float | None = None) -> None:
|
||||
"""Tag a file."""
|
||||
self.set_task(path=path, mtime=mtime)
|
||||
|
||||
|
||||
def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None:
|
||||
|
|
|
@ -182,23 +182,23 @@ class TestLibraryObject(tests.util.TestCase):
|
|||
tagger.tag_file.assert_not_called()
|
||||
|
||||
tagger.ready.is_set.return_value = True
|
||||
tagger.get_result.return_value = (None, None)
|
||||
tagger.get_result.return_value = None
|
||||
self.assertFalse(self.library._Library__tag_track(track))
|
||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
||||
tagger.tag_file.assert_called_with(track, None)
|
||||
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||
tagger.tag_file.assert_called_with(track, mtime=None)
|
||||
|
||||
self.sql.tracks.lookup = unittest.mock.Mock()
|
||||
self.sql.tracks.lookup.return_value.mtime = 12345
|
||||
self.assertFalse(self.library._Library__tag_track(track))
|
||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
||||
tagger.tag_file.assert_called_with(track, 12345)
|
||||
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||
tagger.tag_file.assert_called_with(track, mtime=12345)
|
||||
|
||||
tagger.reset_mock()
|
||||
tagger.ready.is_set.return_value = True
|
||||
tagger.get_result.return_value = (track, tags)
|
||||
tagger.get_result.return_value = {"path": track, "tags": tags}
|
||||
self.assertTrue(self.library._Library__tag_track(track))
|
||||
tagger.tag_file.assert_not_called()
|
||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
||||
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||
|
||||
@unittest.mock.patch("emmental.db.tagger.untag_track")
|
||||
def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright 2022 (c) Anna Schumaker
|
||||
"""Tests our Mutagen wrapper."""
|
||||
import pathlib
|
||||
import threading
|
||||
import unittest.mock
|
||||
import emmental.db.tagger
|
||||
import emmental.thread
|
||||
import tests.util
|
||||
|
||||
|
||||
|
@ -276,8 +276,8 @@ class TestTaggerThread(tests.util.TestCase):
|
|||
|
||||
def test_init(self, mock_file: unittest.mock.Mock):
|
||||
"""Test that the tagger thread was initialized properly."""
|
||||
self.assertIsInstance(self.tagger, threading.Thread)
|
||||
self.assertIsInstance(self.tagger._condition, threading.Condition)
|
||||
self.assertIsInstance(self.tagger, emmental.thread.Thread)
|
||||
self.assertIsNone(self.tagger._connection)
|
||||
self.assertTrue(self.tagger.is_alive())
|
||||
|
||||
def test_stop(self, mock_file: unittest.mock.Mock):
|
||||
|
@ -285,74 +285,49 @@ class TestTaggerThread(tests.util.TestCase):
|
|||
mock_connection = unittest.mock.Mock()
|
||||
mock_connection.close = unittest.mock.Mock()
|
||||
|
||||
self.tagger._file = "abcde"
|
||||
self.tagger._mtime = 12345
|
||||
self.tagger._connection = mock_connection
|
||||
self.tagger.stop()
|
||||
|
||||
with unittest.mock.patch.object(self.tagger._condition, "notify",
|
||||
wraps=self.tagger._condition.notify) \
|
||||
as mock_notify:
|
||||
self.tagger.stop()
|
||||
self.assertIsNone(self.tagger._file)
|
||||
self.assertIsNone(self.tagger._mtime)
|
||||
mock_notify.assert_called()
|
||||
|
||||
self.assertFalse(self.tagger.is_alive())
|
||||
self.assertIsNone(self.tagger._connection)
|
||||
mock_connection.close.assert_called()
|
||||
|
||||
def test_tag_file(self, mock_file: unittest.mock.Mock):
|
||||
"""Test asking the thread to tag a file."""
|
||||
path = pathlib.Path("/a/b/c.ogg")
|
||||
|
||||
self.assertIsInstance(self.tagger.ready, threading.Event)
|
||||
self.assertIsNone(self.tagger._file)
|
||||
self.assertIsNone(self.tagger._tags)
|
||||
self.assertIsNone(self.tagger._mtime)
|
||||
self.assertTrue(self.tagger.ready.is_set())
|
||||
|
||||
mock_file.return_value = None
|
||||
|
||||
self.tagger.ready.set()
|
||||
self.tagger._tags = 12345
|
||||
self.tagger.tag_file(path, None)
|
||||
self.assertFalse(self.tagger.ready.is_set())
|
||||
self.assertEqual(self.tagger._file, path)
|
||||
self.assertIsNone(self.tagger._mtime)
|
||||
self.assertIsNone(self.tagger._tags)
|
||||
self.tagger.tag_file(path, mtime=None)
|
||||
self.assertEqual(self.tagger._task, {"path": path, "mtime": None})
|
||||
|
||||
self.tagger.ready.wait()
|
||||
self.assertIsNone(self.tagger._tags)
|
||||
mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None)
|
||||
|
||||
mock_file.return_value = self.make_tags(dict())
|
||||
self.tagger.tag_file(path, 12345)
|
||||
self.assertEqual(self.tagger._mtime, 12345)
|
||||
self.tagger.tag_file(path, mtime=12345)
|
||||
self.assertEqual(self.tagger._task, {"path": path, "mtime": 12345})
|
||||
|
||||
self.tagger.ready.wait()
|
||||
self.assertIsNotNone(self.tagger._tags)
|
||||
mock_file.assert_called_with(self.tagger._file, 12345)
|
||||
mock_file.assert_called_with(path, 12345)
|
||||
|
||||
def test_get_result(self, mock_file: unittest.mock.Mock):
|
||||
"""Test creating a Tags structure after tagging."""
|
||||
mock_file.return_value = None
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
||||
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
|
||||
(None, None))
|
||||
self.assertIsNone(self.tagger.get_result(db=self.sql,
|
||||
library=self.library))
|
||||
|
||||
track_path = pathlib.Path("/a/b/c.ogg")
|
||||
self.tagger.tag_file(track_path, mtime=None)
|
||||
self.tagger.ready.wait()
|
||||
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
|
||||
(pathlib.Path("/a/b/c.ogg"), None))
|
||||
self.assertIsNone(self.tagger._file)
|
||||
self.assertTupleEqual(self.tagger.get_result(db=self.sql,
|
||||
library=self.library),
|
||||
(track_path, None))
|
||||
|
||||
mock_file.return_value = self.make_tags(dict())
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
||||
self.tagger.tag_file(track_path, mtime=None)
|
||||
self.tagger.ready.wait()
|
||||
(file, tags) = self.tagger.get_result(self.sql, self.library)
|
||||
self.assertEqual(file, pathlib.Path("/a/b/c.ogg"))
|
||||
self.assertIsInstance(tags, emmental.db.tagger.Tags)
|
||||
self.assertIsNone(self.tagger._file)
|
||||
self.assertIsNone(self.tagger._mtime)
|
||||
self.assertIsNone(self.tagger._tags)
|
||||
res = self.tagger.get_result(db=self.sql, library=self.library)
|
||||
self.assertTupleEqual(res, (track_path, res[1]))
|
||||
|
||||
@unittest.mock.patch("emmental.db.connection.Connection.__call__")
|
||||
@unittest.mock.patch("musicbrainzngs.get_artist_by_id")
|
||||
|
@ -370,7 +345,7 @@ class TestTaggerThread(tests.util.TestCase):
|
|||
mock_cursor.fetchone = unittest.mock.Mock(return_value=None)
|
||||
mock_connection.return_value = mock_cursor
|
||||
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
|
||||
self.tagger.ready.wait()
|
||||
self.assertEqual(audio_tags.artists[0].name, "Some Artist")
|
||||
self.assertEqual(audio_tags.artists[1].name, "Some Artist")
|
||||
|
@ -394,7 +369,7 @@ class TestTaggerThread(tests.util.TestCase):
|
|||
|
||||
self.assertIsNone(self.tagger._connection)
|
||||
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
|
||||
self.tagger.ready.wait()
|
||||
self.assertIsInstance(self.tagger._connection,
|
||||
emmental.db.connection.Connection)
|
||||
|
|
Loading…
Reference in New Issue