Commit c57871b
use reusable wf
rcannood committed Jan 16, 2025
1 parent 475f091 commit c57871b
Showing 2 changed files with 52 additions and 118 deletions.
2 changes: 2 additions & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -73,6 +73,8 @@ resources:
- type: file
path: /_viash.yaml
- path: /common/nextflow_helpers/helper.nf
- path: /common/nextflow_helpers/benchmarkHelper.nf
- path: /common/nextflow_helpers/workflowHelper.nf

dependencies:
- name: utils/extract_uns_metadata
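Note: these two resource entries stage benchmarkHelper.nf and workflowHelper.nf next to the built workflow, so main.nf can include the reusable subworkflows via meta.resources_dir (see the include line in the diff below). As a hedged illustration — assumed from the call sites in this diff, not code from the commit — the helpers pass (id, state) tuples shaped roughly like this:

// Illustrative Groovy only; key names are taken from the state maps in main.nf.
def exampleTuple = [
  "dataset1.method_a.metric_x",        // id, extended at each step
  [
    input_solution: "solution.h5ad",   // original workflow argument
    method_id: "method_a",             // added after running a method
    metric_output: "score.h5ad",       // added after running a metric
    _meta: [join_id: "dataset1"]       // used to re-join outputs at the end
  ]
]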
168 changes: 50 additions & 118 deletions src/workflows/run_benchmark/main.nf
@@ -1,11 +1,5 @@
include { checkItemAllowed } from "${meta.resources_dir}/helper.nf"

workflow auto {
findStates(params, meta.config)
| meta.workflow.run(
auto: [publish: "state"]
)
}
include { run_methods; run_metrics; extract_scores; create_metadata_files } from "${meta.resources_dir}/benchmarkHelper.nf"

methods = [
majority_vote,
@@ -47,36 +41,24 @@ workflow run_wf

main:

/****************************
* EXTRACT DATASET METADATA *
****************************/
dataset_ch = input_ch
// store join id
| map{ id, state ->
[id, state + ["_meta": [join_id: id]]]
}
/* RUN METHODS AND METRICS */
score_ch = input_ch

// extract the dataset metadata
// extract the uns metadata from the dataset
| extract_uns_metadata.run(
fromState: [input: "input_solution"],
toState: { id, output, state ->
state + [
dataset_uns: readYaml(output.output).uns
]
def outputYaml = readYaml(output.output)
if (!outputYaml.uns) {
throw new Exception("id '$id': No uns found in provided dataset")
}
state + [ dataset_uns: outputYaml.uns ]
}
)

/***************************
* RUN METHODS AND METRICS *
***************************/
score_ch = dataset_ch

// run all methods
| runEach(
components: methods,

// use the 'filter' argument to only run a method on the normalisation the component is asking for
filter: { id, state, comp ->
| run_methods(
methods: methods,
filter: {id, state, comp ->
def norm = state.dataset_uns.normalization_id
def pref = comp.config.info.preferred_normalization
// if the preferred normalisation is none at all,
@@ -91,14 +73,7 @@
)
method_check && norm_check
},

// define a new 'id' by appending the method name to the dataset id
id: { id, state, comp ->
id + "." + comp.config.name
},

// use 'fromState' to fetch the arguments the component requires from the overall state
fromState: { id, state, comp ->
fromState: {id, state, comp ->
def new_args = [
input_train: state.input_train,
input_test: state.input_test
@@ -108,109 +83,57 @@
}
new_args
},

// use 'toState' to publish that component's outputs to the overall state
toState: { id, output, state, comp ->
toState: {id, output, state, comp ->
state + [
method_id: comp.config.name,
method_output: output.output
]
}
)

// run all metrics
| runEach(
components: metrics,
id: { id, state, comp ->
id + "." + comp.config.name
},
// use 'fromState' to fetch the arguments the component requires from the overall state
| run_metrics(
metrics: metrics,
fromState: [
input_solution: "input_solution",
input_prediction: "method_output"
],
// use 'toState' to publish that component's outputs to the overall state
toState: { id, output, state, comp ->
state + [
metric_id: comp.config.name,
metric_output: output.output
]
}
}
)


/******************************
* GENERATE OUTPUT YAML FILES *
******************************/
// TODO: can we store everything below in a separate helper function?

// extract the dataset metadata
dataset_meta_ch = dataset_ch
// only keep one of the normalization methods
| filter{ id, state ->
state.dataset_uns.normalization_id == "log_cp10k"
}
| joinStates { ids, states ->
// store the dataset metadata in a file
def dataset_uns = states.collect{state ->
def uns = state.dataset_uns.clone()
uns.remove("normalization_id")
uns
}
def dataset_uns_yaml_blob = toYamlBlob(dataset_uns)
def dataset_uns_file = tempFile("dataset_uns.yaml")
dataset_uns_file.write(dataset_uns_yaml_blob)

["output", [output_dataset_info: dataset_uns_file]]
}

output_ch = score_ch

// extract the scores
| extract_uns_metadata.run(
key: "extract_scores",
fromState: [input: "metric_output"],
toState: { id, output, state ->
state + [
score_uns: readYaml(output.output).uns
]
}
| extract_scores(
extract_uns_metadata_component: extract_uns_metadata
)

| joinStates { ids, states ->
// store the method configs in a file
def method_configs = methods.collect{it.config}
def method_configs_yaml_blob = toYamlBlob(method_configs)
def method_configs_file = tempFile("method_configs.yaml")
method_configs_file.write(method_configs_yaml_blob)

// store the metric configs in a file
def metric_configs = metrics.collect{it.config}
def metric_configs_yaml_blob = toYamlBlob(metric_configs)
def metric_configs_file = tempFile("metric_configs.yaml")
metric_configs_file.write(metric_configs_yaml_blob)

def task_info_file = meta.resources_dir.resolve("_viash.yaml")
/* GENERATE METADATA FILES */
metadata_ch = input_ch

// store the scores in a file
def score_uns = states.collect{it.score_uns}
def score_uns_yaml_blob = toYamlBlob(score_uns)
def score_uns_file = tempFile("score_uns.yaml")
score_uns_file.write(score_uns_yaml_blob)

def new_state = [
output_method_configs: method_configs_file,
output_metric_configs: metric_configs_file,
output_task_info: task_info_file,
output_scores: score_uns_file,
_meta: states[0]._meta
]
| create_metadata_files(
datasetFile: "input_solution",
// only keep one of the normalization methods
// for generating the dataset metadata files
filter: {id, state ->
state.dataset_uns.normalization_id == "log_cp10k"
},
datasetUnsModifier: { uns ->
def uns_ = uns.clone()
uns_.remove("normalization_id")
uns_
},
methods: methods,
metrics: metrics,
meta: meta,
extract_uns_metadata_component: extract_uns_metadata
)

["output", new_state]
}

// merge all of the output data
| mix(dataset_meta_ch)
/* JOIN SCORES AND METADATA */
output_ch = score_ch
| mix(metadata_ch)
| joinStates{ ids, states ->
def mergedStates = states.inject([:]) { acc, m -> acc + m }
[ids[0], mergedStates]
@@ -219,3 +142,12 @@ workflow run_wf
emit:
output_ch
}

// Helper workflow to look for 'state.yaml' files recursively and
// use them to run the benchmark.
workflow auto {
findStates(params, meta.config)
| meta.workflow.run(
auto: [publish: "state"]
)
}
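The run_methods, run_metrics, extract_scores, and create_metadata_files subworkflows now live in benchmarkHelper.nf, which is not part of this diff. A minimal sketch of what run_methods might look like, assuming it wraps the runEach pattern the removed inline code used and that runEach returns a pipeable subworkflow (parameter names are taken from the call site above; the actual helper may differ):

// Hypothetical sketch of benchmarkHelper.nf — not code from this commit.
// Run every method component on the incoming (id, state) channel by
// forwarding to the generic runEach helper.
def run_methods(Map args) {
  runEach(
    components: args.methods,
    filter: args.filter,
    // append the method name to the dataset id, as the removed inline code did
    id: { id, state, comp -> id + "." + comp.config.name },
    fromState: args.fromState,
    toState: args.toState
  )
}

run_metrics presumably mirrors this with components: args.metrics, while extract_scores and create_metadata_files package the extract_uns_metadata.run and joinStates blocks that this commit removes from main.nf.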
