db: Convert the tagger to the new Thread class
This lets us do a lot of the basic Thread operations through common code, allowing us to focus on tagging in this file instead of basic Thread controls. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
1db187dba5
commit
d373c33283
|
@ -55,10 +55,11 @@ class Library(playlist.Playlist):
|
||||||
|
|
||||||
def __tag_track(self, path: pathlib.Path) -> bool:
|
def __tag_track(self, path: pathlib.Path) -> bool:
|
||||||
if self.tagger.ready.is_set():
|
if self.tagger.ready.is_set():
|
||||||
(file, tags) = self.tagger.get_result(self.table.sql, self)
|
result = self.tagger.get_result(db=self.table.sql, library=self)
|
||||||
if file is None:
|
if result is None:
|
||||||
track = self.table.sql.tracks.lookup(self, path=path)
|
track = self.table.sql.tracks.lookup(self, path=path)
|
||||||
self.tagger.tag_file(path, track.mtime if track else None)
|
mtime = track.mtime if track else None
|
||||||
|
self.tagger.tag_file(path, mtime=mtime)
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -3,9 +3,9 @@
|
||||||
import emmental.audio.tagger
|
import emmental.audio.tagger
|
||||||
import musicbrainzngs
|
import musicbrainzngs
|
||||||
import pathlib
|
import pathlib
|
||||||
import threading
|
|
||||||
from gi.repository import GObject
|
from gi.repository import GObject
|
||||||
from .. import audio
|
from .. import audio
|
||||||
|
from .. import thread
|
||||||
from . import albums
|
from . import albums
|
||||||
from . import artists
|
from . import artists
|
||||||
from . import connection
|
from . import connection
|
||||||
|
@ -178,24 +178,12 @@ class Tags:
|
||||||
return year if year else self.db.years.create(raw_year)
|
return year if year else self.db.years.create(raw_year)
|
||||||
|
|
||||||
|
|
||||||
class Thread(threading.Thread):
|
class Thread(thread.Thread):
|
||||||
"""A thread for tagging files without blocking the UI."""
|
"""A thread for tagging files without blocking the UI."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the Tagger Thread."""
|
"""Initialize the Tagger Thread."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ready = threading.Event()
|
|
||||||
|
|
||||||
self._connection = None
|
|
||||||
self._condition = threading.Condition()
|
|
||||||
self._file = None
|
|
||||||
self._mtime = None
|
|
||||||
self._tags = None
|
|
||||||
self.start()
|
|
||||||
|
|
||||||
def __close_connection(self) -> None:
|
|
||||||
if self._connection:
|
|
||||||
self._connection.close()
|
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
|
||||||
def __get_connection(self) -> connection.Connection:
|
def __get_connection(self) -> connection.Connection:
|
||||||
|
@ -213,55 +201,31 @@ class Thread(threading.Thread):
|
||||||
mb_res = musicbrainzngs.get_artist_by_id(artist.mbid)
|
mb_res = musicbrainzngs.get_artist_by_id(artist.mbid)
|
||||||
artist.name = mb_res["artist"]["name"]
|
artist.name = mb_res["artist"]["name"]
|
||||||
|
|
||||||
def get_result(self, db: GObject.TYPE_PYOBJECT,
|
def do_get_result(self, result: thread.Data, db: GObject.TYPE_PYOBJECT,
|
||||||
library: playlist.Playlist) \
|
library: playlist.Playlist) -> tuple:
|
||||||
-> tuple[pathlib.Path | None, Tags | None]:
|
|
||||||
"""Return the resulting Tags structure."""
|
"""Return the resulting Tags structure."""
|
||||||
with self._condition:
|
tags = None if result.tags is None else Tags(db, result.tags, library)
|
||||||
if not self.ready.is_set():
|
return (result.path, tags)
|
||||||
return (None, None)
|
|
||||||
|
|
||||||
tags = Tags(db, self._tags, library) if self._tags else None
|
def do_run_task(self, task: thread.Data) -> None:
|
||||||
res = (self._file, tags)
|
"""Tag a file."""
|
||||||
self._file = None
|
tags = emmental.audio.tagger.tag_file(task.path, task.mtime)
|
||||||
self._tags = None
|
|
||||||
return res
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
"""Sleep until we have work to do."""
|
|
||||||
with self._condition:
|
|
||||||
self.ready.set()
|
|
||||||
|
|
||||||
while self._condition.wait():
|
|
||||||
if self._file is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
tags = emmental.audio.tagger.tag_file(self._file, self._mtime)
|
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
for artist in tags.artists:
|
for artist in tags.artists:
|
||||||
self.__check_artist(artist)
|
self.__check_artist(artist)
|
||||||
|
|
||||||
self._tags = tags
|
self.set_result(path=task.path, tags=tags)
|
||||||
self.ready.set()
|
|
||||||
|
|
||||||
self.__close_connection()
|
def do_stop(self) -> None:
|
||||||
|
"""Close the connection before stopping."""
|
||||||
|
if self._connection:
|
||||||
|
self._connection.close()
|
||||||
|
self._connection = None
|
||||||
|
|
||||||
def stop(self) -> None:
|
def tag_file(self, path: pathlib.Path,
|
||||||
"""Stop the thread."""
|
*, mtime: float | None = None) -> None:
|
||||||
with self._condition:
|
|
||||||
self._file = None
|
|
||||||
self._mtime = None
|
|
||||||
self._condition.notify()
|
|
||||||
self.join()
|
|
||||||
|
|
||||||
def tag_file(self, file: pathlib.Path, mtime: float | None) -> None:
|
|
||||||
"""Tag a file."""
|
"""Tag a file."""
|
||||||
with self._condition:
|
self.set_task(path=path, mtime=mtime)
|
||||||
self.ready.clear()
|
|
||||||
self._file = file
|
|
||||||
self._mtime = mtime
|
|
||||||
self._tags = None
|
|
||||||
self._condition.notify()
|
|
||||||
|
|
||||||
|
|
||||||
def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None:
|
def untag_track(db: GObject.TYPE_PYOBJECT, track: tracks.Track) -> None:
|
||||||
|
|
|
@ -182,23 +182,23 @@ class TestLibraryObject(tests.util.TestCase):
|
||||||
tagger.tag_file.assert_not_called()
|
tagger.tag_file.assert_not_called()
|
||||||
|
|
||||||
tagger.ready.is_set.return_value = True
|
tagger.ready.is_set.return_value = True
|
||||||
tagger.get_result.return_value = (None, None)
|
tagger.get_result.return_value = None
|
||||||
self.assertFalse(self.library._Library__tag_track(track))
|
self.assertFalse(self.library._Library__tag_track(track))
|
||||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||||
tagger.tag_file.assert_called_with(track, None)
|
tagger.tag_file.assert_called_with(track, mtime=None)
|
||||||
|
|
||||||
self.sql.tracks.lookup = unittest.mock.Mock()
|
self.sql.tracks.lookup = unittest.mock.Mock()
|
||||||
self.sql.tracks.lookup.return_value.mtime = 12345
|
self.sql.tracks.lookup.return_value.mtime = 12345
|
||||||
self.assertFalse(self.library._Library__tag_track(track))
|
self.assertFalse(self.library._Library__tag_track(track))
|
||||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||||
tagger.tag_file.assert_called_with(track, 12345)
|
tagger.tag_file.assert_called_with(track, mtime=12345)
|
||||||
|
|
||||||
tagger.reset_mock()
|
tagger.reset_mock()
|
||||||
tagger.ready.is_set.return_value = True
|
tagger.ready.is_set.return_value = True
|
||||||
tagger.get_result.return_value = (track, tags)
|
tagger.get_result.return_value = {"path": track, "tags": tags}
|
||||||
self.assertTrue(self.library._Library__tag_track(track))
|
self.assertTrue(self.library._Library__tag_track(track))
|
||||||
tagger.tag_file.assert_not_called()
|
tagger.tag_file.assert_not_called()
|
||||||
tagger.get_result.assert_called_with(self.sql, self.library)
|
tagger.get_result.assert_called_with(db=self.sql, library=self.library)
|
||||||
|
|
||||||
@unittest.mock.patch("emmental.db.tagger.untag_track")
|
@unittest.mock.patch("emmental.db.tagger.untag_track")
|
||||||
def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()):
|
def test_scan_check_trackid(self, mock_untag: unittest.mock.Mock()):
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
# Copyright 2022 (c) Anna Schumaker
|
# Copyright 2022 (c) Anna Schumaker
|
||||||
"""Tests our Mutagen wrapper."""
|
"""Tests our Mutagen wrapper."""
|
||||||
import pathlib
|
import pathlib
|
||||||
import threading
|
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
import emmental.db.tagger
|
import emmental.db.tagger
|
||||||
|
import emmental.thread
|
||||||
import tests.util
|
import tests.util
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,8 +276,8 @@ class TestTaggerThread(tests.util.TestCase):
|
||||||
|
|
||||||
def test_init(self, mock_file: unittest.mock.Mock):
|
def test_init(self, mock_file: unittest.mock.Mock):
|
||||||
"""Test that the tagger thread was initialized properly."""
|
"""Test that the tagger thread was initialized properly."""
|
||||||
self.assertIsInstance(self.tagger, threading.Thread)
|
self.assertIsInstance(self.tagger, emmental.thread.Thread)
|
||||||
self.assertIsInstance(self.tagger._condition, threading.Condition)
|
self.assertIsNone(self.tagger._connection)
|
||||||
self.assertTrue(self.tagger.is_alive())
|
self.assertTrue(self.tagger.is_alive())
|
||||||
|
|
||||||
def test_stop(self, mock_file: unittest.mock.Mock):
|
def test_stop(self, mock_file: unittest.mock.Mock):
|
||||||
|
@ -285,74 +285,49 @@ class TestTaggerThread(tests.util.TestCase):
|
||||||
mock_connection = unittest.mock.Mock()
|
mock_connection = unittest.mock.Mock()
|
||||||
mock_connection.close = unittest.mock.Mock()
|
mock_connection.close = unittest.mock.Mock()
|
||||||
|
|
||||||
self.tagger._file = "abcde"
|
|
||||||
self.tagger._mtime = 12345
|
|
||||||
self.tagger._connection = mock_connection
|
self.tagger._connection = mock_connection
|
||||||
|
|
||||||
with unittest.mock.patch.object(self.tagger._condition, "notify",
|
|
||||||
wraps=self.tagger._condition.notify) \
|
|
||||||
as mock_notify:
|
|
||||||
self.tagger.stop()
|
self.tagger.stop()
|
||||||
self.assertIsNone(self.tagger._file)
|
|
||||||
self.assertIsNone(self.tagger._mtime)
|
|
||||||
mock_notify.assert_called()
|
|
||||||
|
|
||||||
self.assertFalse(self.tagger.is_alive())
|
|
||||||
self.assertIsNone(self.tagger._connection)
|
self.assertIsNone(self.tagger._connection)
|
||||||
mock_connection.close.assert_called()
|
mock_connection.close.assert_called()
|
||||||
|
|
||||||
def test_tag_file(self, mock_file: unittest.mock.Mock):
|
def test_tag_file(self, mock_file: unittest.mock.Mock):
|
||||||
"""Test asking the thread to tag a file."""
|
"""Test asking the thread to tag a file."""
|
||||||
path = pathlib.Path("/a/b/c.ogg")
|
path = pathlib.Path("/a/b/c.ogg")
|
||||||
|
|
||||||
self.assertIsInstance(self.tagger.ready, threading.Event)
|
|
||||||
self.assertIsNone(self.tagger._file)
|
|
||||||
self.assertIsNone(self.tagger._tags)
|
|
||||||
self.assertIsNone(self.tagger._mtime)
|
|
||||||
self.assertTrue(self.tagger.ready.is_set())
|
|
||||||
|
|
||||||
mock_file.return_value = None
|
mock_file.return_value = None
|
||||||
|
|
||||||
self.tagger.ready.set()
|
self.tagger.ready.set()
|
||||||
self.tagger._tags = 12345
|
self.tagger.tag_file(path, mtime=None)
|
||||||
self.tagger.tag_file(path, None)
|
self.assertEqual(self.tagger._task, {"path": path, "mtime": None})
|
||||||
self.assertFalse(self.tagger.ready.is_set())
|
|
||||||
self.assertEqual(self.tagger._file, path)
|
|
||||||
self.assertIsNone(self.tagger._mtime)
|
|
||||||
self.assertIsNone(self.tagger._tags)
|
|
||||||
|
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
self.assertIsNone(self.tagger._tags)
|
|
||||||
mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None)
|
mock_file.assert_called_with(pathlib.Path("/a/b/c.ogg"), None)
|
||||||
|
|
||||||
mock_file.return_value = self.make_tags(dict())
|
mock_file.return_value = self.make_tags(dict())
|
||||||
self.tagger.tag_file(path, 12345)
|
self.tagger.tag_file(path, mtime=12345)
|
||||||
self.assertEqual(self.tagger._mtime, 12345)
|
self.assertEqual(self.tagger._task, {"path": path, "mtime": 12345})
|
||||||
|
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
self.assertIsNotNone(self.tagger._tags)
|
mock_file.assert_called_with(path, 12345)
|
||||||
mock_file.assert_called_with(self.tagger._file, 12345)
|
|
||||||
|
|
||||||
def test_get_result(self, mock_file: unittest.mock.Mock):
|
def test_get_result(self, mock_file: unittest.mock.Mock):
|
||||||
"""Test creating a Tags structure after tagging."""
|
"""Test creating a Tags structure after tagging."""
|
||||||
mock_file.return_value = None
|
mock_file.return_value = None
|
||||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
self.assertIsNone(self.tagger.get_result(db=self.sql,
|
||||||
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
|
library=self.library))
|
||||||
(None, None))
|
|
||||||
|
|
||||||
|
track_path = pathlib.Path("/a/b/c.ogg")
|
||||||
|
self.tagger.tag_file(track_path, mtime=None)
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
self.assertTupleEqual(self.tagger.get_result(self.sql, self.library),
|
self.assertTupleEqual(self.tagger.get_result(db=self.sql,
|
||||||
(pathlib.Path("/a/b/c.ogg"), None))
|
library=self.library),
|
||||||
self.assertIsNone(self.tagger._file)
|
(track_path, None))
|
||||||
|
|
||||||
mock_file.return_value = self.make_tags(dict())
|
mock_file.return_value = self.make_tags(dict())
|
||||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
self.tagger.tag_file(track_path, mtime=None)
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
(file, tags) = self.tagger.get_result(self.sql, self.library)
|
res = self.tagger.get_result(db=self.sql, library=self.library)
|
||||||
self.assertEqual(file, pathlib.Path("/a/b/c.ogg"))
|
self.assertTupleEqual(res, (track_path, res[1]))
|
||||||
self.assertIsInstance(tags, emmental.db.tagger.Tags)
|
|
||||||
self.assertIsNone(self.tagger._file)
|
|
||||||
self.assertIsNone(self.tagger._mtime)
|
|
||||||
self.assertIsNone(self.tagger._tags)
|
|
||||||
|
|
||||||
@unittest.mock.patch("emmental.db.connection.Connection.__call__")
|
@unittest.mock.patch("emmental.db.connection.Connection.__call__")
|
||||||
@unittest.mock.patch("musicbrainzngs.get_artist_by_id")
|
@unittest.mock.patch("musicbrainzngs.get_artist_by_id")
|
||||||
|
@ -370,7 +345,7 @@ class TestTaggerThread(tests.util.TestCase):
|
||||||
mock_cursor.fetchone = unittest.mock.Mock(return_value=None)
|
mock_cursor.fetchone = unittest.mock.Mock(return_value=None)
|
||||||
mock_connection.return_value = mock_cursor
|
mock_connection.return_value = mock_cursor
|
||||||
|
|
||||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
self.assertEqual(audio_tags.artists[0].name, "Some Artist")
|
self.assertEqual(audio_tags.artists[0].name, "Some Artist")
|
||||||
self.assertEqual(audio_tags.artists[1].name, "Some Artist")
|
self.assertEqual(audio_tags.artists[1].name, "Some Artist")
|
||||||
|
@ -394,7 +369,7 @@ class TestTaggerThread(tests.util.TestCase):
|
||||||
|
|
||||||
self.assertIsNone(self.tagger._connection)
|
self.assertIsNone(self.tagger._connection)
|
||||||
|
|
||||||
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), None)
|
self.tagger.tag_file(pathlib.Path("/a/b/c.ogg"), mtime=None)
|
||||||
self.tagger.ready.wait()
|
self.tagger.ready.wait()
|
||||||
self.assertIsInstance(self.tagger._connection,
|
self.assertIsInstance(self.tagger._connection,
|
||||||
emmental.db.connection.Connection)
|
emmental.db.connection.Connection)
|
||||||
|
|
Loading…
Reference in New Issue