thread: Create a reusable Thread class

I found that I'm rewriting some of the same features every time I need
to spin up a Thread for something. This is a reusable Thread that can be
inherited for specific work.

Signed-off-by: Anna Schumaker <Anna@NoWheyCreamery.com>
This commit is contained in:
Anna Schumaker 2024-01-31 10:46:30 -05:00
parent 0d100ec752
commit 1db187dba5
2 changed files with 149 additions and 0 deletions

View File

@ -1,5 +1,6 @@
# Copyright 2024 (c) Anna Schumaker. # Copyright 2024 (c) Anna Schumaker.
"""A Thread class designed to easily sync up with the main thread.""" """A Thread class designed to easily sync up with the main thread."""
import threading
class Data: class Data:
@ -21,3 +22,73 @@ class Data:
"""Get a string representation of the Data.""" """Get a string representation of the Data."""
items = (f"{k}={v!r}" for k, v in self.__dict__.items()) items = (f"{k}={v!r}" for k, v in self.__dict__.items())
return f"{type(self).__name__}({', '.join(items)})" return f"{type(self).__name__}({', '.join(items)})"
class Thread(threading.Thread):
"""A worker Thread class that is easy to sync up with the main thread."""
def __init__(self):
"""Initialize our worker Thread object."""
super().__init__()
self.ready = threading.Event()
self._condition = threading.Condition()
self._task = None
self._result = None
self.start()
def do_get_result(self, result: Data, **kwargs) -> Data:
"""Get the result of the task."""
return self._result
def do_run_task(self, task: Data) -> None:
"""Run the task."""
self.set_result()
def do_stop(self) -> None:
"""Extra work when stopping the thread."""
def get_result(self, **kwargs) -> Data:
"""Get the result of the current task."""
with self._condition:
if not self.ready.is_set() or self._result is None:
return None
res = self.do_get_result(self._result, **kwargs)
self._result = None
return res
def run(self) -> None:
"""Wait for a task to run."""
with self._condition:
self.ready.set()
while self._condition.wait():
if self._task is None:
self.do_stop()
break
self.do_run_task(self._task)
def set_result(self, **kwargs: dict) -> None:
"""Set the result of the task."""
self._result = Data(kwargs)
self.ready.set()
def __set_task(self, task: Data | None) -> None:
"""Set the task to be run by the thread."""
with self._condition:
self.ready.clear()
self._task = task
self._result = None
self._condition.notify()
def set_task(self, **kwargs: dict) -> None:
"""Set the task to be run by the thread."""
self.__set_task(Data(kwargs))
def stop(self) -> None:
"""Stop the thread."""
self.__set_task(None)
self.join()

View File

@ -1,6 +1,7 @@
# Copyright 2024 (c) Anna Schumaker. # Copyright 2024 (c) Anna Schumaker.
"""Tests our common Thread class.""" """Tests our common Thread class."""
import emmental.thread import emmental.thread
import threading
import unittest import unittest
@ -38,3 +39,80 @@ class TestData(unittest.TestCase):
self.assertFalse(data1 == data2) self.assertFalse(data1 == data2)
self.assertFalse(data1 == {"c": 2, "d": 4}) self.assertFalse(data1 == {"c": 2, "d": 4})
self.assertFalse(data1 == 3) self.assertFalse(data1 == 3)
class TestThread(unittest.TestCase):
"""Tests our Thread class."""
def setUp(self):
"""Set up common variables."""
self.thread = emmental.thread.Thread()
def tearDown(self):
"""Clean up."""
self.thread.stop()
def test_init(self):
"""Check that the Thread was initialized properly."""
self.assertIsInstance(self.thread, threading.Thread)
self.assertIsInstance(self.thread.ready, threading.Event)
self.assertIsInstance(self.thread._condition, threading.Condition)
self.assertIsNone(self.thread._task)
self.assertIsNone(self.thread._result)
self.assertTrue(self.thread.is_alive())
self.assertTrue(self.thread.ready.is_set())
def test_set_get_result(self):
"""Test the set_result() and get_result() functions."""
with unittest.mock.patch.object(self.thread, "do_get_result",
wraps=self.thread.do_get_result) \
as mock_get_result:
self.assertIsNone(self.thread.get_result())
mock_get_result.assert_not_called()
self.thread.ready.clear()
self.thread._result = {"res": "abcde"}
self.assertIsNone(self.thread.get_result())
mock_get_result.assert_not_called()
self.thread.set_result(res="fghij")
self.assertTrue(self.thread.ready.is_set())
self.assertIsInstance(self.thread._result, emmental.thread.Data)
self.assertEqual(self.thread._result, {"res": "fghij"})
self.assertEqual(self.thread.get_result(), {"res": "fghij"})
self.assertIsNone(self.thread._result)
mock_get_result.assert_called_with({"res": "fghij"})
result = {"res1": "klmno", "res2": "pqrst"}
self.thread.set_result(**result)
self.assertEqual(self.thread.get_result(other="other", arg="arg"),
result)
mock_get_result.assert_called_with(result,
other="other", arg="arg")
def test_set_task(self):
"""Test the set_task() function."""
self.thread._result = "abcde"
with unittest.mock.patch.object(self.thread, "do_run_task",
wraps=self.thread.do_run_task) \
as mock_run_task:
self.thread.set_task(arg="test")
self.assertIsInstance(self.thread._task, emmental.thread.Data)
self.assertEqual(self.thread._task, {"arg": "test"})
self.assertIsNone(self.thread._result)
self.thread.ready.wait()
mock_run_task.assert_called_with(self.thread._task)
def test_stop(self):
"""Test stopping the Thread."""
self.thread._task = ("test", "task")
with unittest.mock.patch.object(self.thread, "do_stop") as mock_stop:
self.thread.stop()
self.assertFalse(self.thread.is_alive())
self.assertFalse(self.thread.ready.is_set())
self.assertIsNone(self.thread._task)
mock_stop.assert_called()