diff --git a/tests/test_commands.py b/tests/test_commands.py index e10476e..a53eeee 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -24,9 +24,15 @@ class TestCommand(unittest.TestCase): self.assertEqual(self.command.parser.description, "description") self.assertEqual(self.command.parser._defaults["function"], self.command.do_command) + self.assertEqual(self.command.parser._defaults["done"], + self.command.do_done) self.assertEqual(self.command.sql, self.sql) def test_do_command(self): """Test the do_command() function.""" with self.assertRaises(NotImplementedError): self.command.do_command(argparse.Namespace()) + + def test_do_done(self): + """Test the do_done() function.""" + self.command.do_done(argparse.Namespace()) diff --git a/tests/test_xfstestsdb.py b/tests/test_xfstestsdb.py index cc5a59b..4f18924 100644 --- a/tests/test_xfstestsdb.py +++ b/tests/test_xfstestsdb.py @@ -37,16 +37,18 @@ class TestXfstestsdb(unittest.TestCase): def test_run(self): """Test running the xfstestsdb.""" parser = self.xfstestsdb.subparser.add_parser("test-run", help="help") + test_done = unittest.mock.Mock() test_passed = False def test_func(args: argparse.Namespace) -> None: nonlocal test_passed self.assertTrue(self.xfstestsdb.sql.sql.in_transaction) test_passed = True - parser.set_defaults(function=test_func) + parser.set_defaults(function=test_func, done=test_done) self.xfstestsdb.run(["test-run"]) self.assertTrue(test_passed) + test_done.assert_called() @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) def test_version(self, mock_stdout: io.StringIO): diff --git a/xfstestsdb/__init__.py b/xfstestsdb/__init__.py index 1ff45c2..700041f 100644 --- a/xfstestsdb/__init__.py +++ b/xfstestsdb/__init__.py @@ -24,7 +24,8 @@ class Command: def __init__(self) -> None: """Initialize the xfstestsdb command.""" self.parser = argparse.ArgumentParser() - self.parser.set_defaults(function=lambda x: self.parser.print_usage()) + self.parser.set_defaults(function=lambda x: self.parser.print_usage(), + done=lambda x: None) self.parser.add_argument("--version", action="store_true", help="show version number and exit") self.subparser = self.parser.add_subparsers(title="commands") @@ -55,3 +56,4 @@ class Command: else: with self.sql: parsed.function(parsed) + parsed.done(parsed) diff --git a/xfstestsdb/commands.py b/xfstestsdb/commands.py index 8053992..7a5042a 100644 --- a/xfstestsdb/commands.py +++ b/xfstestsdb/commands.py @@ -12,8 +12,12 @@ class Command: """Set up the Command.""" self.parser = subparser.add_parser(name, help=help, **kwargs) self.parser.set_defaults(function=self.do_command) + self.parser.set_defaults(done=self.do_done) self.sql = sql def do_command(self, args: argparse.Namespace) -> None: """Do something.""" raise NotImplementedError + + def do_done(self, args: argparse.Namespace) -> None: + """Run after the main command, outside of a transaction."""