diff --git a/nipype/algorithms/metrics.py b/nipype/algorithms/metrics.py index beeced288a..7f78088b86 100644 --- a/nipype/algorithms/metrics.py +++ b/nipype/algorithms/metrics.py @@ -20,11 +20,11 @@ from .. import config, logging -from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File, - InputMultiPath, BaseInterfaceInputSpec, - isdefined) +from ..interfaces.base import ( + SimpleInterface, BaseInterface, traits, TraitedSpec, File, + InputMultiPath, BaseInterfaceInputSpec, + isdefined) from ..interfaces.nipy.base import NipyBaseInterface -from ..utils import NUMPY_MMAP iflogger = logging.getLogger('interface') @@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec): File(exists=True), mandatory=True, desc='Test image. Requires the same dimensions as in_ref.') + in_mask = File(exists=True, desc='calculate overlap only within mask') weighting = traits.Enum( 'none', 'volume', @@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec): class FuzzyOverlapOutputSpec(TraitedSpec): jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes') dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes') - diff_file = File( - exists=True, - desc= - 'resulting difference-map of all classes, using the chosen weighting') class_fji = traits.List( traits.Float(), desc='Array containing the fJIs of each computed class') @@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec): desc='Array containing the fDIs of each computed class') -class FuzzyOverlap(BaseInterface): +class FuzzyOverlap(SimpleInterface): """Calculates various overlap measures between two maps, using the fuzzy definition proposed in: Crum et al., Generalized Overlap Measures for Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med. @@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface): output_spec = FuzzyOverlapOutputSpec def _run_interface(self, runtime): - ncomp = len(self.inputs.in_ref) - assert (ncomp == len(self.inputs.in_tst)) - weights = np.ones(shape=ncomp) - - img_ref = np.array([ - nb.load(fname, mmap=NUMPY_MMAP).get_data() - for fname in self.inputs.in_ref - ]) - img_tst = np.array([ - nb.load(fname, mmap=NUMPY_MMAP).get_data() - for fname in self.inputs.in_tst - ]) - - msk = np.sum(img_ref, axis=0) - msk[msk > 0] = 1.0 - tst_msk = np.sum(img_tst, axis=0) - tst_msk[tst_msk > 0] = 1.0 - - # check that volumes are normalized - # img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0] - # img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0] - - self._jaccards = [] - volumes = [] - - diff_im = np.zeros(img_ref.shape) - - for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im): - num = np.minimum(ref_comp, tst_comp) - ddr = np.maximum(ref_comp, tst_comp) - diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0]) - self._jaccards.append(np.sum(num) / np.sum(ddr)) - volumes.append(np.sum(ref_comp)) - - self._dices = 2.0 * (np.array(self._jaccards) / - (np.array(self._jaccards) + 1.0)) + # Load data + refdata = nb.concat_images(self.inputs.in_ref).get_data() + tstdata = nb.concat_images(self.inputs.in_tst).get_data() + + # Data must have same shape + if not refdata.shape == tstdata.shape: + raise RuntimeError( + 'Size of "in_tst" %s must match that of "in_ref" %s.' % + (tstdata.shape, refdata.shape)) + + ncomp = refdata.shape[-1] + # Load mask + mask = np.ones_like(refdata, dtype=bool) + if isdefined(self.inputs.in_mask): + mask = nb.load(self.inputs.in_mask).get_data() + mask = mask > 0 + mask = np.repeat(mask[..., np.newaxis], ncomp, -1) + assert mask.shape == refdata.shape + + # Drop data outside mask + refdata = refdata[mask] + tstdata = tstdata[mask] + + if np.any(refdata < 0.0): + iflogger.warning('Negative values encountered in "in_ref" input, ' + 'taking absolute values.') + refdata = np.abs(refdata) + + if np.any(tstdata < 0.0): + iflogger.warning('Negative values encountered in "in_tst" input, ' + 'taking absolute values.') + tstdata = np.abs(tstdata) + + if np.any(refdata > 1.0): + iflogger.warning('Values greater than 1.0 found in "in_ref" input, ' + 'scaling values.') + refdata /= refdata.max() + + if np.any(tstdata > 1.0): + iflogger.warning('Values greater than 1.0 found in "in_tst" input, ' + 'scaling values.') + tstdata /= tstdata.max() + + numerators = np.atleast_2d( + np.minimum(refdata, tstdata).reshape((-1, ncomp))) + denominators = np.atleast_2d( + np.maximum(refdata, tstdata).reshape((-1, ncomp))) + + jaccards = numerators.sum(axis=0) / denominators.sum(axis=0) + + # Calculate weights + weights = np.ones_like(jaccards, dtype=float) if self.inputs.weighting != "none": - weights = 1.0 / np.array(volumes) + volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp)) + weights = 1.0 / volumes if self.inputs.weighting == "squared_vol": weights = weights**2 weights = weights / np.sum(weights) + dices = 2.0 * jaccards / (jaccards + 1.0) - setattr(self, '_jaccard', np.sum(weights * self._jaccards)) - setattr(self, '_dice', np.sum(weights * self._dices)) - - diff = np.zeros(diff_im[0].shape) - - for w, ch in zip(weights, diff_im): - ch[msk == 0] = 0 - diff += w * ch - - nb.save( - nb.Nifti1Image(diff, - nb.load(self.inputs.in_ref[0]).affine, - nb.load(self.inputs.in_ref[0]).header), - self.inputs.out_file) - + # Fill-in the results object + self._results['jaccard'] = float(weights.dot(jaccards)) + self._results['dice'] = float(weights.dot(dices)) + self._results['class_fji'] = [float(v) for v in jaccards] + self._results['class_fdi'] = [float(v) for v in dices] return runtime - def _list_outputs(self): - outputs = self._outputs().get() - for method in ("dice", "jaccard"): - outputs[method] = getattr(self, '_' + method) - # outputs['volume_difference'] = self._volume - outputs['diff_file'] = os.path.abspath(self.inputs.out_file) - outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist() - outputs['class_fdi'] = self._dices.astype(float).tolist() - return outputs - class ErrorMapInputSpec(BaseInterfaceInputSpec): in_ref = File( diff --git a/nipype/algorithms/tests/test_auto_FuzzyOverlap.py b/nipype/algorithms/tests/test_auto_FuzzyOverlap.py index b59ba2a5a5..fc8bf79fa9 100644 --- a/nipype/algorithms/tests/test_auto_FuzzyOverlap.py +++ b/nipype/algorithms/tests/test_auto_FuzzyOverlap.py @@ -10,6 +10,7 @@ def test_FuzzyOverlap_inputs(): nohash=True, usedefault=True, ), + in_mask=dict(), in_ref=dict(mandatory=True, ), in_tst=dict(mandatory=True, ), out_file=dict(usedefault=True, ), @@ -25,7 +26,6 @@ def test_FuzzyOverlap_outputs(): class_fdi=dict(), class_fji=dict(), dice=dict(), - diff_file=dict(), jaccard=dict(), ) outputs = FuzzyOverlap.output_spec() diff --git a/nipype/algorithms/tests/test_metrics.py b/nipype/algorithms/tests/test_metrics.py new file mode 100644 index 0000000000..fb876b3c72 --- /dev/null +++ b/nipype/algorithms/tests/test_metrics.py @@ -0,0 +1,58 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +import numpy as np +import nibabel as nb +from nipype.testing import example_data +from ..metrics import FuzzyOverlap + + +def test_fuzzy_overlap(tmpdir): + tmpdir.chdir() + + # Tests with tissue probability maps + in_mask = example_data('tpms_msk.nii.gz') + tpms = [example_data('tpm_%02d.nii.gz' % i) for i in range(3)] + out = FuzzyOverlap(in_ref=tpms[0], in_tst=tpms[0]).run().outputs + assert out.dice == 1 + + out = FuzzyOverlap( + in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[0]).run().outputs + assert out.dice == 1 + + out = FuzzyOverlap( + in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[1]).run().outputs + assert 0 < out.dice < 1 + + out = FuzzyOverlap(in_ref=tpms, in_tst=tpms).run().outputs + assert out.dice == 1.0 + + out = FuzzyOverlap( + in_mask=in_mask, in_ref=tpms, in_tst=tpms).run().outputs + assert out.dice == 1.0 + + # Tests with synthetic 3x3x3 images + data = np.zeros((3, 3, 3), dtype=float) + data[0, 0, 0] = 0.5 + data[2, 2, 2] = 0.25 + data[1, 1, 1] = 0.3 + nb.Nifti1Image(data, np.eye(4)).to_filename('test1.nii.gz') + + data = np.zeros((3, 3, 3), dtype=float) + data[0, 0, 0] = 0.6 + data[1, 1, 1] = 0.3 + nb.Nifti1Image(data, np.eye(4)).to_filename('test2.nii.gz') + + out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz').run().outputs + assert np.allclose(out.dice, 0.82051) + + # Just considering the mask, the central pixel + # that raised the index now is left aside. + data = np.zeros((3, 3, 3), dtype=int) + data[0, 0, 0] = 1 + data[2, 2, 2] = 1 + nb.Nifti1Image(data, np.eye(4)).to_filename('mask.nii.gz') + + out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz', + in_mask='mask.nii.gz').run().outputs + assert np.allclose(out.dice, 0.74074)