diff --git a/emmental/thread.py b/emmental/thread.py index 589733f..22e6ce8 100644 --- a/emmental/thread.py +++ b/emmental/thread.py @@ -1,5 +1,6 @@ # Copyright 2024 (c) Anna Schumaker. """A Thread class designed to easily sync up with the main thread.""" +import threading class Data: @@ -21,3 +22,73 @@ class Data: """Get a string representation of the Data.""" items = (f"{k}={v!r}" for k, v in self.__dict__.items()) return f"{type(self).__name__}({', '.join(items)})" + + +class Thread(threading.Thread): + """A worker Thread class that is easy to sync up with the main thread.""" + + def __init__(self): + """Initialize our worker Thread object.""" + super().__init__() + self.ready = threading.Event() + + self._condition = threading.Condition() + self._task = None + self._result = None + + self.start() + + def do_get_result(self, result: Data, **kwargs) -> Data: + """Get the result of the task.""" + return self._result + + def do_run_task(self, task: Data) -> None: + """Run the task.""" + self.set_result() + + def do_stop(self) -> None: + """Extra work when stopping the thread.""" + + def get_result(self, **kwargs) -> Data: + """Get the result of the current task.""" + with self._condition: + if not self.ready.is_set() or self._result is None: + return None + + res = self.do_get_result(self._result, **kwargs) + self._result = None + return res + + def run(self) -> None: + """Wait for a task to run.""" + with self._condition: + self.ready.set() + + while self._condition.wait(): + if self._task is None: + self.do_stop() + break + + self.do_run_task(self._task) + + def set_result(self, **kwargs: dict) -> None: + """Set the result of the task.""" + self._result = Data(kwargs) + self.ready.set() + + def __set_task(self, task: Data | None) -> None: + """Set the task to be run by the thread.""" + with self._condition: + self.ready.clear() + self._task = task + self._result = None + self._condition.notify() + + def set_task(self, **kwargs: dict) -> None: + """Set the task to be run by the thread.""" + self.__set_task(Data(kwargs)) + + def stop(self) -> None: + """Stop the thread.""" + self.__set_task(None) + self.join() diff --git a/tests/test_thread.py b/tests/test_thread.py index 5a8dc66..79730db 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -1,6 +1,7 @@ # Copyright 2024 (c) Anna Schumaker. """Tests our common Thread class.""" import emmental.thread +import threading import unittest @@ -38,3 +39,80 @@ class TestData(unittest.TestCase): self.assertFalse(data1 == data2) self.assertFalse(data1 == {"c": 2, "d": 4}) self.assertFalse(data1 == 3) + + +class TestThread(unittest.TestCase): + """Tests our Thread class.""" + + def setUp(self): + """Set up common variables.""" + self.thread = emmental.thread.Thread() + + def tearDown(self): + """Clean up.""" + self.thread.stop() + + def test_init(self): + """Check that the Thread was initialized properly.""" + self.assertIsInstance(self.thread, threading.Thread) + self.assertIsInstance(self.thread.ready, threading.Event) + self.assertIsInstance(self.thread._condition, threading.Condition) + + self.assertIsNone(self.thread._task) + self.assertIsNone(self.thread._result) + + self.assertTrue(self.thread.is_alive()) + self.assertTrue(self.thread.ready.is_set()) + + def test_set_get_result(self): + """Test the set_result() and get_result() functions.""" + with unittest.mock.patch.object(self.thread, "do_get_result", + wraps=self.thread.do_get_result) \ + as mock_get_result: + self.assertIsNone(self.thread.get_result()) + mock_get_result.assert_not_called() + + self.thread.ready.clear() + self.thread._result = {"res": "abcde"} + self.assertIsNone(self.thread.get_result()) + mock_get_result.assert_not_called() + + self.thread.set_result(res="fghij") + self.assertTrue(self.thread.ready.is_set()) + self.assertIsInstance(self.thread._result, emmental.thread.Data) + self.assertEqual(self.thread._result, {"res": "fghij"}) + self.assertEqual(self.thread.get_result(), {"res": "fghij"}) + self.assertIsNone(self.thread._result) + mock_get_result.assert_called_with({"res": "fghij"}) + + result = {"res1": "klmno", "res2": "pqrst"} + self.thread.set_result(**result) + self.assertEqual(self.thread.get_result(other="other", arg="arg"), + result) + mock_get_result.assert_called_with(result, + other="other", arg="arg") + + def test_set_task(self): + """Test the set_task() function.""" + self.thread._result = "abcde" + + with unittest.mock.patch.object(self.thread, "do_run_task", + wraps=self.thread.do_run_task) \ + as mock_run_task: + self.thread.set_task(arg="test") + self.assertIsInstance(self.thread._task, emmental.thread.Data) + self.assertEqual(self.thread._task, {"arg": "test"}) + self.assertIsNone(self.thread._result) + self.thread.ready.wait() + mock_run_task.assert_called_with(self.thread._task) + + def test_stop(self): + """Test stopping the Thread.""" + self.thread._task = ("test", "task") + + with unittest.mock.patch.object(self.thread, "do_stop") as mock_stop: + self.thread.stop() + self.assertFalse(self.thread.is_alive()) + self.assertFalse(self.thread.ready.is_set()) + self.assertIsNone(self.thread._task) + mock_stop.assert_called()