From 26779250d6f3c08c1d125dd7a3939c84367cb5ab Mon Sep 17 00:00:00 2001 From: "John T. Wodder II" Date: Thu, 4 May 2023 10:00:25 -0400 Subject: [PATCH] Clean up Click testing --- tests/experiment/tools/test_nidm_lingreg.py | 167 +++++++++----------- 1 file changed, 74 insertions(+), 93 deletions(-) diff --git a/tests/experiment/tools/test_nidm_lingreg.py b/tests/experiment/tools/test_nidm_lingreg.py index e6a13b8..01b10e0 100644 --- a/tests/experiment/tools/test_nidm_lingreg.py +++ b/tests/experiment/tools/test_nidm_lingreg.py @@ -1,63 +1,26 @@ from __future__ import annotations from pathlib import Path -import click from click.testing import CliRunner import pytest from nidm.experiment.tools.nidm_linreg import linear_regression -def call_click_command(cmd, *args, **kwargs): - """Wrapper to call a click command - - :param cmd: click cli command function to call - :param args: arguments to pass to the function - :param kwargs: keyword arguments to pass to the function - :return: None - """ - - # Get positional arguments from args - arg_values = {c.name: a for a, c in zip(args, cmd.params)} - args_needed = {c.name: c for c in cmd.params if c.name not in arg_values} - - # build and check opts list from kwargs - opts = {a.name: a for a in cmd.params if isinstance(a, click.Option)} - for name in kwargs: - if name in opts: - arg_values[name] = kwargs[name] - else: - if name in args_needed: - arg_values[name] = kwargs[name] - del args_needed[name] - else: - raise click.BadParameter("Unknown keyword argument '{}'".format(name)) - - # check positional arguments list - for arg in (a for a in cmd.params if isinstance(a, click.Argument)): - if arg.name not in arg_values: - raise click.BadParameter( - "Missing required positional parameter '{}'".format(arg.name) - ) - - # build parameter lists - opts_list = sum([[o.opts[0], arg_values[n]] for n, o in opts.items()], []) - args_list = [str(v) for n, v in arg_values.items() if n not in opts] - - # call the command - CliRunner().invoke(cmd, opts_list + args_list) - - -def test_simple_model(brain_vol_files: list[str], tmp_path: Path) -> None: - arguments = dict( - nidm_file_list=",".join(brain_vol_files), - ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", - ctr=None, - regularization=None, - output_file=str(tmp_path / "output.txt"), - ) - - call_click_command(linear_regression, *arguments, **arguments) - - out = (tmp_path / "output.txt").read_text() +def test_simple_model(brain_vol_files: list[str]) -> None: + runner = CliRunner() + with runner.isolated_filesystem(): + r = runner.invoke( + linear_regression, + [ + "--nidm_file_list", + ",".join(brain_vol_files), + "--ml", + "fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", + "-o", + "output.txt", + ], + ) + assert r.exit_code == 0 + out = Path("output.txt").read_text() # check if model was read correctly assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out @@ -80,20 +43,26 @@ def test_simple_model(brain_vol_files: list[str], tmp_path: Path) -> None: ) -def test_model_with_contrasts(brain_vol_files: list[str], tmp_path: Path) -> None: +def test_model_with_contrasts(brain_vol_files: list[str]) -> None: # run linear regression tool with simple model and evaluate output - arguments = dict( - nidm_file_list=",".join(brain_vol_files), - ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", - ctr="DX_GROUP", - regularization=None, - output_file=str(tmp_path / "output.txt"), - ) - - call_click_command(linear_regression, *arguments, **arguments) - - out = (tmp_path / "output.txt").read_text() + runner = CliRunner() + with runner.isolated_filesystem(): + r = runner.invoke( + linear_regression, + [ + "--nidm_file_list", + ",".join(brain_vol_files), + "--ml", + "fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", + "--ctr", + "DX_GROUP", + "-o", + "output.txt", + ], + ) + assert r.exit_code == 0 + out = Path("output.txt").read_text() # check if model was read correctly assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out @@ -147,20 +116,26 @@ def test_model_with_contrasts(brain_vol_files: list[str], tmp_path: Path) -> Non @pytest.mark.skip( reason="regularization weights seem to be different depending on the platform" ) -def test_model_with_contrasts_reg_L1( - brain_vol_files: list[str], tmp_path: Path -) -> None: - arguments = dict( - nidm_file_list=",".join(brain_vol_files), - ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", - ctr="DX_GROUP", - regularization="L1", - output_file=str(tmp_path / "output.txt"), - ) - - call_click_command(linear_regression, *arguments, **arguments) - - out = (tmp_path / "output.txt").read_text() +def test_model_with_contrasts_reg_L1(brain_vol_files: list[str]) -> None: + runner = CliRunner() + with runner.isolated_filesystem(): + r = runner.invoke( + linear_regression, + [ + "--nidm_file_list", + ",".join(brain_vol_files), + "--ml", + "fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", + "--ctr", + "DX_GROUP", + "--regularization", + "L1", + "-o", + "output.txt", + ], + ) + assert r.exit_code == 0 + out = Path("output.txt").read_text() # check if model was read correctly assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out @@ -179,20 +154,26 @@ def test_model_with_contrasts_reg_L1( @pytest.mark.skip( reason="regularization weights seem to be different depending on the platform" ) -def test_model_with_contrasts_reg_L2( - brain_vol_files: list[str], tmp_path: Path -) -> None: - arguments = dict( - nidm_file_list=",".join(brain_vol_files), - ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", - ctr="DX_GROUP", - regularization="L2", - output_file=str(tmp_path / "output.txt"), - ) - - call_click_command(linear_regression, *arguments, **arguments) - - out = (tmp_path / "output.txt").read_text() +def test_model_with_contrasts_reg_L2(brain_vol_files: list[str]) -> None: + runner = CliRunner() + with runner.isolated_filesystem(): + r = runner.invoke( + linear_regression, + [ + "--nidm_file_list", + ",".join(brain_vol_files), + "--ml", + "fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400", + "--ctr", + "DX_GROUP", + "--regularization", + "L2", + "-o", + "output.txt", + ], + ) + assert r.exit_code == 0 + out = Path("output.txt").read_text() # check if model was read correctly assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out