119 lines
4.5 KiB
Python
119 lines
4.5 KiB
Python
# Copyright 2024 (c) Anna Schumaker.
|
|
"""Tests our common Thread class."""
|
|
import threading
import unittest
import unittest.mock

import emmental.thread
|
|
|
|
|
|
class TestData(unittest.TestCase):
|
|
"""Tests our thread Data class."""
|
|
|
|
def test_init_kwargs(self):
|
|
"""Tests initializing the data class with keyword args."""
|
|
data = emmental.thread.Data(a=1, b=2)
|
|
self.assertEqual(data.a, 1)
|
|
self.assertEqual(data.b, 2)
|
|
self.assertEqual(repr(data), "Data(a=1, b=2)")
|
|
|
|
def test_init_values_dict(self):
|
|
"""Test initializing the data class with a dictionary of values."""
|
|
data = emmental.thread.Data({"a": 1, "b": 2})
|
|
self.assertEqual(data.a, 1)
|
|
self.assertEqual(data.b, 2)
|
|
self.assertEqual(repr(data), "Data(a=1, b=2)")
|
|
|
|
def test_init_both(self):
|
|
"""Test initializing the data class with both."""
|
|
data = emmental.thread.Data({"a": 1, "b": 2}, b=3, c='4')
|
|
self.assertEqual(data.a, 1)
|
|
self.assertEqual(data.b, 3)
|
|
self.assertEqual(data.c, '4')
|
|
self.assertEqual(repr(data), "Data(a=1, b=3, c='4')")
|
|
|
|
def test_compare(self):
|
|
"""Test comparing two data classes."""
|
|
data1 = emmental.thread.Data({"a": 1, "b": 2})
|
|
data2 = emmental.thread.Data({"c": 3, "d": 4})
|
|
self.assertTrue(data1 == data1)
|
|
self.assertTrue(data1 == {"a": 1, "b": 2})
|
|
self.assertFalse(data1 == data2)
|
|
self.assertFalse(data1 == {"c": 2, "d": 4})
|
|
self.assertFalse(data1 == 3)
|
|
|
|
|
|
class TestThread(unittest.TestCase):
|
|
"""Tests our Thread class."""
|
|
|
|
def setUp(self):
|
|
"""Set up common variables."""
|
|
self.thread = emmental.thread.Thread()
|
|
|
|
def tearDown(self):
|
|
"""Clean up."""
|
|
self.thread.stop()
|
|
|
|
def test_init(self):
|
|
"""Check that the Thread was initialized properly."""
|
|
self.assertIsInstance(self.thread, threading.Thread)
|
|
self.assertIsInstance(self.thread.ready, threading.Event)
|
|
self.assertIsInstance(self.thread._condition, threading.Condition)
|
|
|
|
self.assertIsNone(self.thread._task)
|
|
self.assertIsNone(self.thread._result)
|
|
|
|
self.assertTrue(self.thread.is_alive())
|
|
self.assertTrue(self.thread.ready.is_set())
|
|
|
|
def test_set_get_result(self):
|
|
"""Test the set_result() and get_result() functions."""
|
|
with unittest.mock.patch.object(self.thread, "do_get_result",
|
|
wraps=self.thread.do_get_result) \
|
|
as mock_get_result:
|
|
self.assertIsNone(self.thread.get_result())
|
|
mock_get_result.assert_not_called()
|
|
|
|
self.thread.ready.clear()
|
|
self.thread._result = {"res": "abcde"}
|
|
self.assertIsNone(self.thread.get_result())
|
|
mock_get_result.assert_not_called()
|
|
|
|
self.thread.set_result(res="fghij")
|
|
self.assertTrue(self.thread.ready.is_set())
|
|
self.assertIsInstance(self.thread._result, emmental.thread.Data)
|
|
self.assertEqual(self.thread._result, {"res": "fghij"})
|
|
self.assertEqual(self.thread.get_result(), {"res": "fghij"})
|
|
self.assertIsNone(self.thread._result)
|
|
mock_get_result.assert_called_with({"res": "fghij"})
|
|
|
|
result = {"res1": "klmno", "res2": "pqrst"}
|
|
self.thread.set_result(**result)
|
|
self.assertEqual(self.thread.get_result(other="other", arg="arg"),
|
|
result)
|
|
mock_get_result.assert_called_with(result,
|
|
other="other", arg="arg")
|
|
|
|
def test_set_task(self):
|
|
"""Test the set_task() function."""
|
|
self.thread._result = "abcde"
|
|
|
|
with unittest.mock.patch.object(self.thread, "do_run_task",
|
|
wraps=self.thread.do_run_task) \
|
|
as mock_run_task:
|
|
self.thread.set_task(arg="test")
|
|
self.assertIsInstance(self.thread._task, emmental.thread.Data)
|
|
self.assertEqual(self.thread._task, {"arg": "test"})
|
|
self.assertIsNone(self.thread._result)
|
|
self.thread.ready.wait()
|
|
mock_run_task.assert_called_with(self.thread._task)
|
|
|
|
def test_stop(self):
|
|
"""Test stopping the Thread."""
|
|
self.thread._task = ("test", "task")
|
|
|
|
with unittest.mock.patch.object(self.thread, "do_stop") as mock_stop:
|
|
self.thread.stop()
|
|
self.assertFalse(self.thread.is_alive())
|
|
self.assertFalse(self.thread.ready.is_set())
|
|
self.assertIsNone(self.thread._task)
|
|
mock_stop.assert_called()
|