# Copyright 2024 (c) Anna Schumaker.
"""Tests our common Thread class."""
import threading
import unittest
# unittest.mock is a submodule that "import unittest" does not load, so it
# must be imported explicitly for unittest.mock.patch.object() below.
import unittest.mock

import emmental.thread


class TestData(unittest.TestCase):
    """Tests our thread Data class."""

    def test_init_kwargs(self):
        """Test initializing the data class with keyword args."""
        data = emmental.thread.Data(a=1, b=2)
        self.assertEqual(data.a, 1)
        self.assertEqual(data.b, 2)
        self.assertEqual(repr(data), "Data(a=1, b=2)")

    def test_init_values_dict(self):
        """Test initializing the data class with a dictionary of values."""
        data = emmental.thread.Data({"a": 1, "b": 2})
        self.assertEqual(data.a, 1)
        self.assertEqual(data.b, 2)
        self.assertEqual(repr(data), "Data(a=1, b=2)")

    def test_init_both(self):
        """Test initializing the data class with both."""
        # "b" appears in both the dict and the kwargs; the keyword wins.
        data = emmental.thread.Data({"a": 1, "b": 2}, b=3, c='4')
        self.assertEqual(data.a, 1)
        self.assertEqual(data.b, 3)
        self.assertEqual(data.c, '4')
        self.assertEqual(repr(data), "Data(a=1, b=3, c='4')")

    def test_compare(self):
        """Test comparing two data classes."""
        data1 = emmental.thread.Data({"a": 1, "b": 2})
        data2 = emmental.thread.Data({"c": 3, "d": 4})

        # Explicit "==" checks exercise Data.__eq__ directly, including
        # comparisons against plain dicts and unrelated types.
        self.assertTrue(data1 == data1)
        self.assertTrue(data1 == {"a": 1, "b": 2})
        self.assertFalse(data1 == data2)
        self.assertFalse(data1 == {"c": 2, "d": 4})
        self.assertFalse(data1 == 3)


class TestThread(unittest.TestCase):
    """Tests our Thread class."""

    def setUp(self):
        """Set up common variables."""
        self.thread = emmental.thread.Thread()

    def tearDown(self):
        """Clean up."""
        self.thread.stop()

    def test_init(self):
        """Check that the Thread was initialized properly."""
        self.assertIsInstance(self.thread, threading.Thread)
        self.assertIsInstance(self.thread.ready, threading.Event)
        self.assertIsInstance(self.thread._condition, threading.Condition)
        self.assertIsNone(self.thread._task)
        self.assertIsNone(self.thread._result)

        # The thread is expected to be running and idle right after creation.
        self.assertTrue(self.thread.is_alive())
        self.assertTrue(self.thread.ready.is_set())

    def test_set_get_result(self):
        """Test the set_result() and get_result() functions."""
        with unittest.mock.patch.object(
                self.thread, "do_get_result",
                wraps=self.thread.do_get_result) as mock_get_result:
            # With no stored result, get_result() returns None and never
            # reaches do_get_result().
            self.assertIsNone(self.thread.get_result())
            mock_get_result.assert_not_called()

            # A stored result must not be handed out while the thread is
            # not ready.
            self.thread.ready.clear()
            self.thread._result = {"res": "abcde"}
            self.assertIsNone(self.thread.get_result())
            mock_get_result.assert_not_called()

            # set_result() wraps the values in a Data object and marks the
            # thread ready again.
            self.thread.set_result(res="fghij")
            self.assertTrue(self.thread.ready.is_set())
            self.assertIsInstance(self.thread._result, emmental.thread.Data)
            self.assertEqual(self.thread._result, {"res": "fghij"})

            # Fetching the result consumes it and passes it through
            # do_get_result().
            self.assertEqual(self.thread.get_result(), {"res": "fghij"})
            self.assertIsNone(self.thread._result)
            mock_get_result.assert_called_with({"res": "fghij"})

            # Extra keyword arguments to get_result() are forwarded to
            # do_get_result() alongside the stored result.
            result = {"res1": "klmno", "res2": "pqrst"}
            self.thread.set_result(**result)
            self.assertEqual(
                self.thread.get_result(other="other", arg="arg"), result)
            mock_get_result.assert_called_with(result,
                                               other="other", arg="arg")

    def test_set_task(self):
        """Test the set_task() function."""
        # Seed a stale result so we can verify set_task() discards it.
        self.thread._result = "abcde"
        with unittest.mock.patch.object(
                self.thread, "do_run_task",
                wraps=self.thread.do_run_task) as mock_run_task:
            self.thread.set_task(arg="test")
            self.assertIsInstance(self.thread._task, emmental.thread.Data)
            self.assertEqual(self.thread._task, {"arg": "test"})
            self.assertIsNone(self.thread._result)

            # Block until the thread signals ready again before checking
            # that the task was handed to do_run_task().
            self.thread.ready.wait()
            mock_run_task.assert_called_with(self.thread._task)

    def test_stop(self):
        """Test stopping the Thread."""
        # A pending task should be cleared when the thread is stopped.
        self.thread._task = ("test", "task")
        with unittest.mock.patch.object(self.thread, "do_stop") as mock_stop:
            self.thread.stop()
            self.assertFalse(self.thread.is_alive())
            self.assertFalse(self.thread.ready.is_set())
            self.assertIsNone(self.thread._task)
            mock_stop.assert_called()