diff --git a/emmental/db/idle.py b/emmental/db/idle.py index 28f0932..78e350c 100644 --- a/emmental/db/idle.py +++ b/emmental/db/idle.py @@ -52,6 +52,11 @@ class Queue(GObject.GObject): self.running = False self._idle_id = None + def cancel_task(self, func: typing.Callable) -> None: + """Remove all instances of a specific task from the Idle Queue.""" + self._tasks = [t for t in self._tasks if t[0] != func] + self.__update_counters() + def complete(self) -> None: """Complete all pending tasks.""" if self.running: diff --git a/tests/db/test_idle.py b/tests/db/test_idle.py index ab1bb0a..6a73486 100644 --- a/tests/db/test_idle.py +++ b/tests/db/test_idle.py @@ -51,6 +51,26 @@ class TestIdleQueue(unittest.TestCase): self.assertEqual(self.queue.total, 0) self.assertEqual(self.queue.progress, 0.0) + def test_cancel_task(self, mock_idle_add: unittest.mock.Mock, + mock_source_removed: unittest.mock.Mock): + """Test canceling a specific task.""" + self.queue.push(1) + self.queue.push(2) + self.queue.push(1) + + self.queue.cancel_task(1) + self.assertListEqual(self.queue._tasks, [(2,)]) + self.assertEqual(self.queue.total, 3) + self.assertAlmostEqual(self.queue.progress, 2 / 3) + mock_source_removed.assert_not_called() + + self.queue.cancel_task(2) + self.assertListEqual(self.queue._tasks, []) + self.assertIsNone(self.queue._idle_id) + self.assertEqual(self.queue.total, 0) + self.assertEqual(self.queue.progress, 0.0) + mock_source_removed.assert_called_with(42) + def test_complete(self, mock_idle_add: unittest.mock.Mock, mock_source_removed: unittest.mock.Mock): """Test completing queued tasks."""