thread: Create a reusable Thread class
I found that I'm rewriting some of the same features every time I need to spin up a Thread for something. This is a reusable Thread that can be inherited for specific work. Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
parent
0d100ec752
commit
1db187dba5
|
@ -1,5 +1,6 @@
|
||||||
# Copyright 2024 (c) Anna Schumaker.
|
# Copyright 2024 (c) Anna Schumaker.
|
||||||
"""A Thread class designed to easily sync up with the main thread."""
|
"""A Thread class designed to easily sync up with the main thread."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class Data:
|
class Data:
|
||||||
|
@ -21,3 +22,73 @@ class Data:
|
||||||
"""Get a string representation of the Data."""
|
"""Get a string representation of the Data."""
|
||||||
items = (f"{k}={v!r}" for k, v in self.__dict__.items())
|
items = (f"{k}={v!r}" for k, v in self.__dict__.items())
|
||||||
return f"{type(self).__name__}({', '.join(items)})"
|
return f"{type(self).__name__}({', '.join(items)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Thread(threading.Thread):
|
||||||
|
"""A worker Thread class that is easy to sync up with the main thread."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize our worker Thread object."""
|
||||||
|
super().__init__()
|
||||||
|
self.ready = threading.Event()
|
||||||
|
|
||||||
|
self._condition = threading.Condition()
|
||||||
|
self._task = None
|
||||||
|
self._result = None
|
||||||
|
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
def do_get_result(self, result: Data, **kwargs) -> Data:
|
||||||
|
"""Get the result of the task."""
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
def do_run_task(self, task: Data) -> None:
|
||||||
|
"""Run the task."""
|
||||||
|
self.set_result()
|
||||||
|
|
||||||
|
def do_stop(self) -> None:
|
||||||
|
"""Extra work when stopping the thread."""
|
||||||
|
|
||||||
|
def get_result(self, **kwargs) -> Data:
|
||||||
|
"""Get the result of the current task."""
|
||||||
|
with self._condition:
|
||||||
|
if not self.ready.is_set() or self._result is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = self.do_get_result(self._result, **kwargs)
|
||||||
|
self._result = None
|
||||||
|
return res
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Wait for a task to run."""
|
||||||
|
with self._condition:
|
||||||
|
self.ready.set()
|
||||||
|
|
||||||
|
while self._condition.wait():
|
||||||
|
if self._task is None:
|
||||||
|
self.do_stop()
|
||||||
|
break
|
||||||
|
|
||||||
|
self.do_run_task(self._task)
|
||||||
|
|
||||||
|
def set_result(self, **kwargs: dict) -> None:
|
||||||
|
"""Set the result of the task."""
|
||||||
|
self._result = Data(kwargs)
|
||||||
|
self.ready.set()
|
||||||
|
|
||||||
|
def __set_task(self, task: Data | None) -> None:
|
||||||
|
"""Set the task to be run by the thread."""
|
||||||
|
with self._condition:
|
||||||
|
self.ready.clear()
|
||||||
|
self._task = task
|
||||||
|
self._result = None
|
||||||
|
self._condition.notify()
|
||||||
|
|
||||||
|
def set_task(self, **kwargs: dict) -> None:
|
||||||
|
"""Set the task to be run by the thread."""
|
||||||
|
self.__set_task(Data(kwargs))
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the thread."""
|
||||||
|
self.__set_task(None)
|
||||||
|
self.join()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Copyright 2024 (c) Anna Schumaker.
|
# Copyright 2024 (c) Anna Schumaker.
|
||||||
"""Tests our common Thread class."""
|
"""Tests our common Thread class."""
|
||||||
import emmental.thread
|
import emmental.thread
|
||||||
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,3 +39,80 @@ class TestData(unittest.TestCase):
|
||||||
self.assertFalse(data1 == data2)
|
self.assertFalse(data1 == data2)
|
||||||
self.assertFalse(data1 == {"c": 2, "d": 4})
|
self.assertFalse(data1 == {"c": 2, "d": 4})
|
||||||
self.assertFalse(data1 == 3)
|
self.assertFalse(data1 == 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestThread(unittest.TestCase):
|
||||||
|
"""Tests our Thread class."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up common variables."""
|
||||||
|
self.thread = emmental.thread.Thread()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up."""
|
||||||
|
self.thread.stop()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Check that the Thread was initialized properly."""
|
||||||
|
self.assertIsInstance(self.thread, threading.Thread)
|
||||||
|
self.assertIsInstance(self.thread.ready, threading.Event)
|
||||||
|
self.assertIsInstance(self.thread._condition, threading.Condition)
|
||||||
|
|
||||||
|
self.assertIsNone(self.thread._task)
|
||||||
|
self.assertIsNone(self.thread._result)
|
||||||
|
|
||||||
|
self.assertTrue(self.thread.is_alive())
|
||||||
|
self.assertTrue(self.thread.ready.is_set())
|
||||||
|
|
||||||
|
def test_set_get_result(self):
|
||||||
|
"""Test the set_result() and get_result() functions."""
|
||||||
|
with unittest.mock.patch.object(self.thread, "do_get_result",
|
||||||
|
wraps=self.thread.do_get_result) \
|
||||||
|
as mock_get_result:
|
||||||
|
self.assertIsNone(self.thread.get_result())
|
||||||
|
mock_get_result.assert_not_called()
|
||||||
|
|
||||||
|
self.thread.ready.clear()
|
||||||
|
self.thread._result = {"res": "abcde"}
|
||||||
|
self.assertIsNone(self.thread.get_result())
|
||||||
|
mock_get_result.assert_not_called()
|
||||||
|
|
||||||
|
self.thread.set_result(res="fghij")
|
||||||
|
self.assertTrue(self.thread.ready.is_set())
|
||||||
|
self.assertIsInstance(self.thread._result, emmental.thread.Data)
|
||||||
|
self.assertEqual(self.thread._result, {"res": "fghij"})
|
||||||
|
self.assertEqual(self.thread.get_result(), {"res": "fghij"})
|
||||||
|
self.assertIsNone(self.thread._result)
|
||||||
|
mock_get_result.assert_called_with({"res": "fghij"})
|
||||||
|
|
||||||
|
result = {"res1": "klmno", "res2": "pqrst"}
|
||||||
|
self.thread.set_result(**result)
|
||||||
|
self.assertEqual(self.thread.get_result(other="other", arg="arg"),
|
||||||
|
result)
|
||||||
|
mock_get_result.assert_called_with(result,
|
||||||
|
other="other", arg="arg")
|
||||||
|
|
||||||
|
def test_set_task(self):
|
||||||
|
"""Test the set_task() function."""
|
||||||
|
self.thread._result = "abcde"
|
||||||
|
|
||||||
|
with unittest.mock.patch.object(self.thread, "do_run_task",
|
||||||
|
wraps=self.thread.do_run_task) \
|
||||||
|
as mock_run_task:
|
||||||
|
self.thread.set_task(arg="test")
|
||||||
|
self.assertIsInstance(self.thread._task, emmental.thread.Data)
|
||||||
|
self.assertEqual(self.thread._task, {"arg": "test"})
|
||||||
|
self.assertIsNone(self.thread._result)
|
||||||
|
self.thread.ready.wait()
|
||||||
|
mock_run_task.assert_called_with(self.thread._task)
|
||||||
|
|
||||||
|
def test_stop(self):
|
||||||
|
"""Test stopping the Thread."""
|
||||||
|
self.thread._task = ("test", "task")
|
||||||
|
|
||||||
|
with unittest.mock.patch.object(self.thread, "do_stop") as mock_stop:
|
||||||
|
self.thread.stop()
|
||||||
|
self.assertFalse(self.thread.is_alive())
|
||||||
|
self.assertFalse(self.thread.ready.is_set())
|
||||||
|
self.assertIsNone(self.thread._task)
|
||||||
|
mock_stop.assert_called()
|
||||||
|
|
Loading…
Reference in New Issue