Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ resources:
- type: file
path: /_viash.yaml
- path: /common/nextflow_helpers/helper.nf
- path: /common/nextflow_helpers/benchmarkHelper.nf
- path: /common/nextflow_helpers/workflowHelper.nf

dependencies:
- name: utils/extract_uns_metadata
Expand Down
168 changes: 50 additions & 118 deletions src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
include { checkItemAllowed } from "${meta.resources_dir}/helper.nf"

// Entry workflow `auto`: calls findStates() with the pipeline params and this
// component's config, then pipes the result into meta.workflow.run with
// auto-publishing set to "state".
// NOTE(review): findStates() is defined outside this file; presumably it
// discovers previously written state files to use as inputs — confirm.
workflow auto {
findStates(params, meta.config)
| meta.workflow.run(
auto: [publish: "state"]
)
}
// Shared benchmark helpers. The file name must match, case-sensitively, the
// resource declared in config.vsh.yaml (/common/nextflow_helpers/benchmarkHelper.nf),
// otherwise the include fails on case-sensitive filesystems.
include { run_methods; run_metrics; extract_scores; create_metadata_files } from "${meta.resources_dir}/benchmarkHelper.nf"

methods = [
majority_vote,
Expand Down Expand Up @@ -47,36 +41,24 @@ workflow run_wf {

main:

/****************************
* EXTRACT DATASET METADATA *
****************************/
dataset_ch = input_ch
// store join id
| map{ id, state ->
[id, state + ["_meta": [join_id: id]]]
}
/* RUN METHODS AND METRICS */
score_ch = input_ch

// extract the dataset metadata
// extract the uns metadata from the dataset
| extract_uns_metadata.run(
fromState: [input: "input_solution"],
toState: { id, output, state ->
state + [
dataset_uns: readYaml(output.output).uns
]
def outputYaml = readYaml(output.output)
if (!outputYaml.uns) {
throw new Exception("id '$id': No uns found in provided dataset")
}
state + [ dataset_uns: outputYaml.uns ]
}
)

/***************************
* RUN METHODS AND METRICS *
***************************/
score_ch = dataset_ch

// run all methods
| runEach(
components: methods,

// use the 'filter' argument to only run a method on the normalisation the component is asking for
filter: { id, state, comp ->
| run_methods(
methods: methods,
filter: {id, state, comp ->
def norm = state.dataset_uns.normalization_id
def pref = comp.config.info.preferred_normalization
// if the preferred normalisation is none at all,
Expand All @@ -91,14 +73,7 @@ workflow run_wf {
)
method_check && norm_check
},

// define a new 'id' by appending the method name to the dataset id
id: { id, state, comp ->
id + "." + comp.config.name
},

// use 'fromState' to fetch the arguments the component requires from the overall state
fromState: { id, state, comp ->
fromState: {id, state, comp ->
def new_args = [
input_train: state.input_train,
input_test: state.input_test
Expand All @@ -108,109 +83,57 @@ workflow run_wf {
}
new_args
},

// use 'toState' to publish that component's outputs to the overall state
toState: { id, output, state, comp ->
toState: {id, output, state, comp ->
state + [
method_id: comp.config.name,
method_output: output.output
]
}
)

// run all metrics
| runEach(
components: metrics,
id: { id, state, comp ->
id + "." + comp.config.name
},
// use 'fromState' to fetch the arguments the component requires from the overall state
| run_metrics(
metrics: metrics,
fromState: [
input_solution: "input_solution",
input_prediction: "method_output"
],
// use 'toState' to publish that component's outputs to the overall state
toState: { id, output, state, comp ->
state + [
metric_id: comp.config.name,
metric_output: output.output
]
}
}
)


/******************************
* GENERATE OUTPUT YAML FILES *
******************************/
// TODO: can we store everything below in a separate helper function?

// extract the dataset metadata
dataset_meta_ch = dataset_ch
// only keep one of the normalization methods
| filter{ id, state ->
state.dataset_uns.normalization_id == "log_cp10k"
}
| joinStates { ids, states ->
// store the dataset metadata in a file
def dataset_uns = states.collect{state ->
def uns = state.dataset_uns.clone()
uns.remove("normalization_id")
uns
}
def dataset_uns_yaml_blob = toYamlBlob(dataset_uns)
def dataset_uns_file = tempFile("dataset_uns.yaml")
dataset_uns_file.write(dataset_uns_yaml_blob)

["output", [output_dataset_info: dataset_uns_file]]
}

output_ch = score_ch

// extract the scores
| extract_uns_metadata.run(
key: "extract_scores",
fromState: [input: "metric_output"],
toState: { id, output, state ->
state + [
score_uns: readYaml(output.output).uns
]
}
| extract_scores(
extract_uns_metadata_component: extract_uns_metadata
)

| joinStates { ids, states ->
// store the method configs in a file
def method_configs = methods.collect{it.config}
def method_configs_yaml_blob = toYamlBlob(method_configs)
def method_configs_file = tempFile("method_configs.yaml")
method_configs_file.write(method_configs_yaml_blob)

// store the metric configs in a file
def metric_configs = metrics.collect{it.config}
def metric_configs_yaml_blob = toYamlBlob(metric_configs)
def metric_configs_file = tempFile("metric_configs.yaml")
metric_configs_file.write(metric_configs_yaml_blob)

def task_info_file = meta.resources_dir.resolve("_viash.yaml")
/* GENERATE METADATA FILES */
metadata_ch = input_ch

// store the scores in a file
def score_uns = states.collect{it.score_uns}
def score_uns_yaml_blob = toYamlBlob(score_uns)
def score_uns_file = tempFile("score_uns.yaml")
score_uns_file.write(score_uns_yaml_blob)

def new_state = [
output_method_configs: method_configs_file,
output_metric_configs: metric_configs_file,
output_task_info: task_info_file,
output_scores: score_uns_file,
_meta: states[0]._meta
]
| create_metadata_files(
datasetFile: "input_solution",
// only keep one of the normalization methods
// for generating the dataset metadata files
filter: {id, state ->
state.dataset_uns.normalization_id == "log_cp10k"
},
datasetUnsModifier: { uns ->
def uns_ = uns.clone()
uns_.remove("normalization_id")
uns_
},
methods: methods,
metrics: metrics,
meta: meta,
extract_uns_metadata_component: extract_uns_metadata
)

["output", new_state]
}

// merge all of the output data
| mix(dataset_meta_ch)
/* JOIN SCORES AND METADATA */
output_ch = score_ch
| mix(metadata_ch)
| joinStates{ ids, states ->
def mergedStates = states.inject([:]) { acc, m -> acc + m }
[ids[0], mergedStates]
Expand All @@ -219,3 +142,12 @@ workflow run_wf {
emit:
output_ch
}

// Helper workflow that looks for 'state.yaml' files recursively and uses
// them to run the benchmark.
// It calls findStates() with the pipeline params and this component's config,
// then pipes the result into meta.workflow.run with auto-publishing set to "state".
// NOTE(review): the recursive lookup is presumably implemented by findStates(),
// which is defined outside this file — confirm.
workflow auto {
findStates(params, meta.config)
| meta.workflow.run(
auto: [publish: "state"]
)
}
Loading