Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Click testing #352

Merged
merged 1 commit into from
May 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 74 additions & 93 deletions tests/experiment/tools/test_nidm_lingreg.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,26 @@
from __future__ import annotations
from pathlib import Path
import click
from click.testing import CliRunner
import pytest
from nidm.experiment.tools.nidm_linreg import linear_regression


def call_click_command(cmd, *args, **kwargs):
"""Wrapper to call a click command

:param cmd: click cli command function to call
:param args: arguments to pass to the function
:param kwargs: keyword arguments to pass to the function
:return: None
"""

# Get positional arguments from args
arg_values = {c.name: a for a, c in zip(args, cmd.params)}
args_needed = {c.name: c for c in cmd.params if c.name not in arg_values}

# build and check opts list from kwargs
opts = {a.name: a for a in cmd.params if isinstance(a, click.Option)}
for name in kwargs:
if name in opts:
arg_values[name] = kwargs[name]
else:
if name in args_needed:
arg_values[name] = kwargs[name]
del args_needed[name]
else:
raise click.BadParameter("Unknown keyword argument '{}'".format(name))

# check positional arguments list
for arg in (a for a in cmd.params if isinstance(a, click.Argument)):
if arg.name not in arg_values:
raise click.BadParameter(
"Missing required positional parameter '{}'".format(arg.name)
)

# build parameter lists
opts_list = sum([[o.opts[0], arg_values[n]] for n, o in opts.items()], [])
args_list = [str(v) for n, v in arg_values.items() if n not in opts]

# call the command
CliRunner().invoke(cmd, opts_list + args_list)


def test_simple_model(brain_vol_files: list[str], tmp_path: Path) -> None:
arguments = dict(
nidm_file_list=",".join(brain_vol_files),
ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
ctr=None,
regularization=None,
output_file=str(tmp_path / "output.txt"),
)

call_click_command(linear_regression, *arguments, **arguments)

out = (tmp_path / "output.txt").read_text()
def test_simple_model(brain_vol_files: list[str]) -> None:
runner = CliRunner()
with runner.isolated_filesystem():
r = runner.invoke(
linear_regression,
[
"--nidm_file_list",
",".join(brain_vol_files),
"--ml",
"fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
"-o",
"output.txt",
],
)
assert r.exit_code == 0
out = Path("output.txt").read_text()

# check if model was read correctly
assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out
Expand All @@ -80,20 +43,26 @@ def test_simple_model(brain_vol_files: list[str], tmp_path: Path) -> None:
)


def test_model_with_contrasts(brain_vol_files: list[str], tmp_path: Path) -> None:
def test_model_with_contrasts(brain_vol_files: list[str]) -> None:
# run linear regression tool with simple model and evaluate output

arguments = dict(
nidm_file_list=",".join(brain_vol_files),
ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
ctr="DX_GROUP",
regularization=None,
output_file=str(tmp_path / "output.txt"),
)

call_click_command(linear_regression, *arguments, **arguments)

out = (tmp_path / "output.txt").read_text()
runner = CliRunner()
with runner.isolated_filesystem():
r = runner.invoke(
linear_regression,
[
"--nidm_file_list",
",".join(brain_vol_files),
"--ml",
"fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
"--ctr",
"DX_GROUP",
"-o",
"output.txt",
],
)
assert r.exit_code == 0
out = Path("output.txt").read_text()

# check if model was read correctly
assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out
Expand Down Expand Up @@ -147,20 +116,26 @@ def test_model_with_contrasts(brain_vol_files: list[str], tmp_path: Path) -> Non
@pytest.mark.skip(
reason="regularization weights seem to be different depending on the platform"
)
def test_model_with_contrasts_reg_L1(
brain_vol_files: list[str], tmp_path: Path
) -> None:
arguments = dict(
nidm_file_list=",".join(brain_vol_files),
ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
ctr="DX_GROUP",
regularization="L1",
output_file=str(tmp_path / "output.txt"),
)

call_click_command(linear_regression, *arguments, **arguments)

out = (tmp_path / "output.txt").read_text()
def test_model_with_contrasts_reg_L1(brain_vol_files: list[str]) -> None:
runner = CliRunner()
with runner.isolated_filesystem():
r = runner.invoke(
linear_regression,
[
"--nidm_file_list",
",".join(brain_vol_files),
"--ml",
"fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
"--ctr",
"DX_GROUP",
"--regularization",
"L1",
"-o",
"output.txt",
],
)
assert r.exit_code == 0
out = Path("output.txt").read_text()

# check if model was read correctly
assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out
Expand All @@ -179,20 +154,26 @@ def test_model_with_contrasts_reg_L1(
@pytest.mark.skip(
reason="regularization weights seem to be different depending on the platform"
)
def test_model_with_contrasts_reg_L2(
brain_vol_files: list[str], tmp_path: Path
) -> None:
arguments = dict(
nidm_file_list=",".join(brain_vol_files),
ml="fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
ctr="DX_GROUP",
regularization="L2",
output_file=str(tmp_path / "output.txt"),
)

call_click_command(linear_regression, *arguments, **arguments)

out = (tmp_path / "output.txt").read_text()
def test_model_with_contrasts_reg_L2(brain_vol_files: list[str]) -> None:
runner = CliRunner()
with runner.isolated_filesystem():
r = runner.invoke(
linear_regression,
[
"--nidm_file_list",
",".join(brain_vol_files),
"--ml",
"fs_000008 = DX_GROUP + http://uri.interlex.org/ilx_0100400",
"--ctr",
"DX_GROUP",
"--regularization",
"L2",
"-o",
"output.txt",
],
)
assert r.exit_code == 0
out = Path("output.txt").read_text()

# check if model was read correctly
assert "fs_000008 ~ ilx_0100400 + DX_GROUP" in out
Expand Down