diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 9d1d3aa..f49776e 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -73,6 +73,8 @@ resources: - type: file path: /_viash.yaml - path: /common/nextflow_helpers/helper.nf + - path: /common/nextflow_helpers/benchmarkHelper.nf + - path: /common/nextflow_helpers/workflowHelper.nf dependencies: - name: utils/extract_uns_metadata diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index e7fe17f..e58581c 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -1,11 +1,5 @@ include { checkItemAllowed } from "${meta.resources_dir}/helper.nf" - -workflow auto { findStates(params, meta.config) | meta.workflow.run( auto: [publish: "state"] ) -} +include { run_methods; run_metrics; extract_scores; create_metadata_files } from "${meta.resources_dir}/benchmarkHelper.nf" methods = [ majority_vote, @@ -47,36 +41,24 @@ workflow run_wf { main: - /**************************** - * EXTRACT DATASET METADATA * - ****************************/ - dataset_ch = input_ch - // store join id - | map{ id, state -> - [id, state + ["_meta": [join_id: id]]] - } + /* RUN METHODS AND METRICS */ + score_ch = input_ch - // extract the dataset metadata + // extract the uns metadata from the dataset | extract_uns_metadata.run( fromState: [input: "input_solution"], toState: { id, output, state -> - state + [ - dataset_uns: readYaml(output.output).uns - ] + def outputYaml = readYaml(output.output) + if (!outputYaml.uns) { + throw new Exception("id '$id': No uns found in provided dataset") + } + state + [ dataset_uns: outputYaml.uns ] } ) - /*************************** - * RUN METHODS AND METRICS * - ***************************/ - score_ch = dataset_ch - - // run all methods - | runEach( - components: methods, - - // use the 'filter' argument to only run a method on 
the normalisation the component is asking for - filter: { id, state, comp -> + | run_methods( + methods: methods, + filter: {id, state, comp -> def norm = state.dataset_uns.normalization_id def pref = comp.config.info.preferred_normalization // if the preferred normalisation is none at all, @@ -91,14 +73,7 @@ workflow run_wf { ) method_check && norm_check }, - - // define a new 'id' by appending the method name to the dataset id - id: { id, state, comp -> - id + "." + comp.config.name - }, - - // use 'fromState' to fetch the arguments the component requires from the overall state - fromState: { id, state, comp -> + fromState: {id, state, comp -> def new_args = [ input_train: state.input_train, input_test: state.input_test @@ -108,9 +83,7 @@ workflow run_wf { } new_args }, - - // use 'toState' to publish that component's outputs to the overall state - toState: { id, output, state, comp -> + toState: {id, output, state, comp -> state + [ method_id: comp.config.name, method_output: output.output @@ -118,99 +91,49 @@ workflow run_wf { } ) - // run all metrics - | runEach( - components: metrics, - id: { id, state, comp -> - id + "." + comp.config.name - }, - // use 'fromState' to fetch the arguments the component requires from the overall state + | run_metrics( + metrics: metrics, fromState: [ input_solution: "input_solution", input_prediction: "method_output" ], - // use 'toState' to publish that component's outputs to the overall state toState: { id, output, state, comp -> state + [ metric_id: comp.config.name, metric_output: output.output ] - } + } ) - - /****************************** - * GENERATE OUTPUT YAML FILES * - ******************************/ - // TODO: can we store everything below in a separate helper function? 
- - // extract the dataset metadata - dataset_meta_ch = dataset_ch - // only keep one of the normalization methods - | filter{ id, state -> - state.dataset_uns.normalization_id == "log_cp10k" - } - | joinStates { ids, states -> - // store the dataset metadata in a file - def dataset_uns = states.collect{state -> - def uns = state.dataset_uns.clone() - uns.remove("normalization_id") - uns - } - def dataset_uns_yaml_blob = toYamlBlob(dataset_uns) - def dataset_uns_file = tempFile("dataset_uns.yaml") - dataset_uns_file.write(dataset_uns_yaml_blob) - - ["output", [output_dataset_info: dataset_uns_file]] - } - - output_ch = score_ch - - // extract the scores - | extract_uns_metadata.run( - key: "extract_scores", - fromState: [input: "metric_output"], - toState: { id, output, state -> - state + [ - score_uns: readYaml(output.output).uns - ] - } + | extract_scores( + extract_uns_metadata_component: extract_uns_metadata ) - | joinStates { ids, states -> - // store the method configs in a file - def method_configs = methods.collect{it.config} - def method_configs_yaml_blob = toYamlBlob(method_configs) - def method_configs_file = tempFile("method_configs.yaml") - method_configs_file.write(method_configs_yaml_blob) - - // store the metric configs in a file - def metric_configs = metrics.collect{it.config} - def metric_configs_yaml_blob = toYamlBlob(metric_configs) - def metric_configs_file = tempFile("metric_configs.yaml") - metric_configs_file.write(metric_configs_yaml_blob) - - def task_info_file = meta.resources_dir.resolve("_viash.yaml") + /* GENERATE METADATA FILES */ + metadata_ch = input_ch - // store the scores in a file - def score_uns = states.collect{it.score_uns} - def score_uns_yaml_blob = toYamlBlob(score_uns) - def score_uns_file = tempFile("score_uns.yaml") - score_uns_file.write(score_uns_yaml_blob) - - def new_state = [ - output_method_configs: method_configs_file, - output_metric_configs: metric_configs_file, - output_task_info: task_info_file, - 
output_scores: score_uns_file, - _meta: states[0]._meta - ] + | create_metadata_files( + datasetFile: "input_solution", + // only keep one of the normalization methods + // for generating the dataset metadata files + filter: {id, state -> + state.dataset_uns.normalization_id == "log_cp10k" + }, + datasetUnsModifier: { uns -> + def uns_ = uns.clone() + uns_.remove("normalization_id") + uns_ + }, + methods: methods, + metrics: metrics, + meta: meta, + extract_uns_metadata_component: extract_uns_metadata + ) - ["output", new_state] - } - // merge all of the output data - | mix(dataset_meta_ch) + /* JOIN SCORES AND METADATA */ + output_ch = score_ch + | mix(metadata_ch) | joinStates{ ids, states -> def mergedStates = states.inject([:]) { acc, m -> acc + m } [ids[0], mergedStates] @@ -219,3 +142,12 @@ workflow run_wf { emit: output_ch } + +// Helper workflow to look for 'state.yaml' files recursively and +// use it to run the benchmark. +workflow auto { + findStates(params, meta.config) + | meta.workflow.run( + auto: [publish: "state"] + ) +} \ No newline at end of file