Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Make Eddy's slice2vol much easier to use #710

Merged
merged 5 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions qsiprep/interfaces/dsi_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ class DSIStudioConnectivityMatrix(CommandLine):
input_spec = DSIStudioConnectivityMatrixInputSpec
output_spec = DSIStudioConnectivityMatrixOutputSpec
_cmd = "dsi_studio --action=ana "
_terminal_output = "file"

def _post_run_hook(self, runtime):
atlas_config = self.inputs.atlas_config
Expand Down Expand Up @@ -448,18 +449,22 @@ def _run_interface(self, runtime):
workflow.config["execution"]["stop_on_first_crash"] = "true"
workflow.config["execution"]["remove_unnecessary_outputs"] = "false"
workflow.base_dir = runtime.cwd
plugin_settings = {}
if num_threads > 1:
plugin_settings = {
"plugin": "MultiProc",
"plugin_args": {
"raise_insufficient": False,
"maxtasksperchild": 1,
"n_procs": num_threads,
},
plugin_settings["plugin"] = "MultiProc"
plugin_settings["plugin_args"] = {
"raise_insufficient": False,
"maxtasksperchild": 1,
"n_procs": num_threads,
}
wf_result = workflow.run(**plugin_settings)
else:
wf_result = workflow.run()
plugin_settings["plugin"] = "Linear"

workflow.config["execution"] = {
"stop_on_first_crash": "True",
"remove_unnecessary_outputs": "False",
}
wf_result = workflow.run(**plugin_settings)
(merge_node,) = [
node for node in list(wf_result.nodes) if node.name.endswith("merge_mats")
]
Expand Down
41 changes: 40 additions & 1 deletion qsiprep/interfaces/dwi_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ class MergeDWIsInputSpec(BaseInterfaceInputSpec):
carpetplot_data = InputMultiObject(
File(exists=True), mandatory=False, desc="list of carpetplot_data files"
)
scan_metadata = traits.Dict(desc="Dict of metadata for the to-be-combined scans")


class MergeDWIsOutputSpec(TraitedSpec):
out_dwi = File(desc="the merged dwi image")
out_bval = File(desc="the merged bval file")
out_bvec = File(desc="the merged bvec file")
original_images = traits.List()
merged_metadata = traits.Dict()
merged_metadata = File(exists=True)
merged_denoising_confounds = File(exists=True)
merged_b0_ref = File(exists=True)
merged_raw_dwi = File(exists=True, mandatory=False)
Expand All @@ -79,6 +80,17 @@ def _run_interface(self, runtime):
self.inputs.harmonize_b0_intensities,
)

# Create a merged metadata json file for
if isdefined(self.inputs.scan_metadata):
combined_metadata = combine_metadata(
self.inputs.bids_dwi_files,
self.inputs.scan_metadata,
)
merged_metadata_file = op.join(runtime.cwd, "merged_metadata.json")
with open(merged_metadata_file, "w") as merged_meta_f:
json.dump(combined_metadata, merged_meta_f, sort_keys=True, indent=4)
self._results["merged_metadata"] = merged_metadata_file

# Get basic qc / provenance per volume
provenance_df = create_provenance_dataframe(
self.inputs.bids_dwi_files, to_concat, b0_means, corrections
Expand Down Expand Up @@ -151,6 +163,33 @@ def _run_interface(self, runtime):
return runtime


def combine_metadata(scan_list, metadata_dict, merge_method="first"):
    """Create a merged metadata dictionary.

    Most importantly, combine the slice timings in some way so that
    slice-to-volume correction (eddy's ``--mporder``) has a sidecar to use.

    Parameters
    ----------
    scan_list : list
        List of BIDS inputs in the order in which they'll be concatenated.
    metadata_dict : dict
        Mapping keys (values in ``scan_list``) to BIDS metadata dictionaries.
    merge_method : str
        How to combine the metadata when multiple scans are being concatenated.
        If ``"first"`` the metadata from the first scan is selected. Someday
        other methods like ``"average"`` may be added.

    Returns
    -------
    metadata : dict
        A BIDS metadata dictionary.

    Raises
    ------
    NotImplementedError
        If ``merge_method`` is anything other than ``"first"``.
    IndexError
        If ``scan_list`` is empty.
    """
    if merge_method == "first":
        # Take the first scan's sidecar verbatim; later scans are presumed to
        # share its acquisition parameters (e.g. SliceTiming) — TODO confirm
        # upstream that concatenated runs always match.
        return metadata_dict[scan_list[0]]
    raise NotImplementedError(f"merge_method '{merge_method}' is not implemented")


class AveragePEPairsInputSpec(MergeDWIsInputSpec):
original_bvec_files = InputMultiObject(
File(exists=True), mandatory=True, desc="list of original bvec files"
Expand Down
17 changes: 17 additions & 0 deletions qsiprep/interfaces/eddy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
import json
import os
import os.path as op

Expand Down Expand Up @@ -45,6 +46,8 @@ class GatherEddyInputsInputSpec(BaseInterfaceInputSpec):
topup_max_b0s_per_spec = traits.CInt(1, usedefault=True)
topup_requested = traits.Bool(False, usedefault=True)
raw_image_sdc = traits.Bool(True, usedefault=True)
eddy_config = File(exists=True, mandatory=True)
json_file = File(exists=True)


class GatherEddyInputsOutputSpec(TraitedSpec):
Expand All @@ -60,6 +63,8 @@ class GatherEddyInputsOutputSpec(TraitedSpec):
forward_transforms = traits.List()
forward_warps = traits.List()
topup_report = traits.Str(desc="description of where data came from")
json_file = File(exists=True)
multiband_factor = traits.Int()


class GatherEddyInputs(SimpleInterface):
Expand Down Expand Up @@ -125,6 +130,14 @@ def _run_interface(self, runtime):
# these have already had HMC, SDC applied
self._results["forward_transforms"] = []
self._results["forward_warps"] = []

# Based on the eddy config, determine whether to send a json argument
with open(self.inputs.eddy_config, "r") as eddy_cfg_f:
eddy_config = json.load(eddy_cfg_f)
# json file is only allowed if mporder is defined
if "mporder" in eddy_config:
self._results["json_file"] = self.inputs.json_file

return runtime


Expand Down Expand Up @@ -241,6 +254,10 @@ def _format_arg(self, name, spec, value):
if name == "field":
pth, fname, _ = split_filename(value)
return spec.argstr % op.join(pth, fname)
if name == "json":
if isdefined(self.inputs.mporder):
return spec.argstr % value
return ""
return super(ExtendedEddy, self)._format_arg(name, spec, value)


Expand Down
1 change: 1 addition & 0 deletions qsiprep/workflows/dwi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def init_dwi_preproc_wf(
('outputnode.dwi_file', 'inputnode.dwi_file'),
('outputnode.bval_file', 'inputnode.bval_file'),
('outputnode.bvec_file', 'inputnode.bvec_file'),
('outputnode.json_file', 'inputnode.json_file'),
('outputnode.original_files', 'inputnode.original_files')]),
(inputnode, hmc_wf, [
('t1_brain', 'inputnode.t1_brain'),
Expand Down
19 changes: 14 additions & 5 deletions qsiprep/workflows/dwi/fsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def init_fsl_hmc_wf(
bvec file
bval_file: str
bval file
json_file: str
path to sidecar json file for dwi_file
b0_indices: list
Indexes into ``dwi_files`` that correspond to b=0 volumes
b0_images: list
Expand All @@ -116,6 +118,7 @@ def init_fsl_hmc_wf(
"dwi_file",
"bvec_file",
"bval_file",
"json_file",
"b0_indices",
"b0_images",
"original_files",
Expand Down Expand Up @@ -165,10 +168,6 @@ def init_fsl_hmc_wf(
)

workflow = Workflow(name=name)
gather_inputs = pe.Node(
GatherEddyInputs(b0_threshold=b0_threshold, raw_image_sdc=raw_image_sdc),
name="gather_inputs",
)
if eddy_config is None:
# load from the defaults
eddy_cfg_file = pkgr_fn("qsiprep.data", "eddy_params.json")
Expand All @@ -177,6 +176,13 @@ def init_fsl_hmc_wf(

with open(eddy_cfg_file, "r") as f:
eddy_args = json.load(f)

gather_inputs = pe.Node(
GatherEddyInputs(
b0_threshold=b0_threshold, raw_image_sdc=raw_image_sdc, eddy_config=eddy_cfg_file
),
name="gather_inputs",
)
enhance_pre_sdc = pe.Node(EnhanceB0(), name="enhance_pre_sdc")

# Run in parallel if possible
Expand Down Expand Up @@ -210,6 +216,7 @@ def init_fsl_hmc_wf(
('dwi_file', 'dwi_file'),
('bval_file', 'bval_file'),
('bvec_file', 'bvec_file'),
('json_file', 'json_file'),
('original_files', 'original_files')]),
(inputnode, pre_eddy_b0_ref_wf, [
('t1_brain', 'inputnode.t1_brain'),
Expand All @@ -224,7 +231,9 @@ def init_fsl_hmc_wf(
('outputnode.ref_image_brain', 'dwi_file')]),
(gather_inputs, eddy, [
('eddy_indices', 'in_index'),
('eddy_acqp', 'in_acqp')]),
('eddy_acqp', 'in_acqp'),
('json_file', 'json'),
('multiband_factor', 'multiband_factor')]),
(inputnode, eddy, [
('dwi_file', 'in_file'),
('bval_file', 'in_bval'),
Expand Down
1 change: 1 addition & 0 deletions qsiprep/workflows/dwi/hmc_sdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def init_qsiprep_hmcsdc_wf(
"dwi_file",
"bvec_file",
"bval_file",
"json_file",
"rpe_b0",
"t2w_unfatsat",
"original_files",
Expand Down
6 changes: 6 additions & 0 deletions qsiprep/workflows/dwi/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ...engine import Workflow
from ...interfaces import ConformDwi, DerivativesDataSink
from ...interfaces.bids import get_metadata_for_nifti
from ...interfaces.dipy import Patch2Self
from ...interfaces.dwi_merge import MergeDWIs, PhaseToRad, StackConfounds
from ...interfaces.gradients import ExtractB0s
Expand Down Expand Up @@ -117,6 +118,8 @@ def init_merge_and_denoise_wf(
bvals from merged images
merged_bvec
bvecs from merged images
merged_json
JSON file containing slice timings for slice2vol
noise_image
image(s) created by ``dwidenoise``
original_files
Expand All @@ -133,6 +136,7 @@ def init_merge_and_denoise_wf(
"merged_raw_image",
"merged_bval",
"merged_bvec",
"merged_json",
"noise_images",
"bias_images",
"denoising_confounds",
Expand All @@ -151,6 +155,7 @@ def init_merge_and_denoise_wf(
bids_dwi_files=raw_dwi_files,
b0_threshold=b0_threshold,
harmonize_b0_intensities=not no_b0_harmonization,
scan_metadata={scan: get_metadata_for_nifti(scan) for scan in raw_dwi_files},
),
name="merge_dwis",
n_procs=omp_nthreads,
Expand Down Expand Up @@ -285,6 +290,7 @@ def init_merge_and_denoise_wf(
('original_images', 'original_files'),
('out_bval', 'merged_bval'),
('out_bvec', 'merged_bvec'),
('merged_metadata', 'merged_json')
]),
]) # fmt:skip

Expand Down
7 changes: 7 additions & 0 deletions qsiprep/workflows/dwi/pre_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def init_dwi_pre_hmc_wf(
a bvec file
bval_file
a bval files
sidecar_file
a json sidecar file for the scan data
b0_indices
list of the positions of the b0 images in the dwi series
b0_images
Expand All @@ -121,6 +123,7 @@ def init_dwi_pre_hmc_wf(
"dwi_file",
"bval_file",
"bvec_file",
"json_file",
"original_files",
"denoising_confounds",
"noise_images",
Expand Down Expand Up @@ -296,6 +299,9 @@ def init_dwi_pre_hmc_wf(
(pm_raw_images, raw_rpe_concat, [('out', 'in_files')]),
(raw_rpe_concat, outputnode, [('out_file', 'raw_concatenated')]),

# Send the slice timings from "plus" to the next steps
(merge_plus, outputnode, [('outputnode.merged_json', 'json_file')]),

# Connect to the QC calculator
(raw_rpe_concat, qc_wf, [('out_file', 'inputnode.dwi_file')]),
(rpe_concat, qc_wf, [
Expand Down Expand Up @@ -333,6 +339,7 @@ def init_dwi_pre_hmc_wf(
('outputnode.merged_image', 'dwi_file'),
('outputnode.merged_bval', 'bval_file'),
('outputnode.merged_bvec', 'bvec_file'),
('outputnode.merged_json', 'json_file'),
('outputnode.bias_images', 'bias_images'),
('outputnode.noise_images', 'noise_images'),
('outputnode.validation_reports', 'validation_reports'),
Expand Down