Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Path type option for Resample target argument (Issue #132) #134

Merged
merged 3 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion tests/transforms/preprocessing/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_spacing(self):
image = nib_to_sitk(image_dict[DATA], image_dict[AFFINE])
self.assertEqual(image.GetSpacing(), 3 * (spacing,))

def test_reference(self):
def test_reference_name(self):
sample = self.get_inconsistent_sample()
reference_name = 't1'
transform = Resample(reference_name)
Expand Down Expand Up @@ -46,6 +46,16 @@ def test_missing_affine(self):
with self.assertRaises(ValueError):
transform(self.sample)

def test_reference_path(self):
reference_image, reference_path = self.get_reference_image_and_path()
transform = Resample(reference_path)
transformed = transform(self.sample)
ref_data, ref_affine = reference_image.load()
for image_dict in transformed.values():
self.assertEqual(
ref_data.shape, image_dict[DATA].shape)
assert_array_equal(ref_affine, image_dict[AFFINE])

def test_wrong_spacing_length(self):
with self.assertRaises(ValueError):
Resample((1, 2))
Expand Down
6 changes: 6 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def get_inconsistent_sample(self):
dataset = ImagesDataset(subjects_list)
return dataset[0]

def get_reference_image_and_path(self):
"""Return a reference image and its path"""
path = self.get_image_path('ref', shape=(10, 20, 31))
image = Image(path, INTENSITY)
return image, path

def tearDown(self):
"""Tear down test fixtures, if any."""
print('Deleting', self.dir)
Expand Down
58 changes: 38 additions & 20 deletions torchio/transforms/preprocessing/spatial/resample.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from numbers import Number
from typing import Union, Tuple, Optional
from pathlib import Path

import torch
import numpy as np
import nibabel as nib
from nibabel.processing import resample_to_output, resample_from_to

from ....data.subject import Subject
from ....torchio import LABEL, DATA, AFFINE, TYPE
from ....data.image import Image
from ....torchio import LABEL, DATA, AFFINE, TYPE, INTENSITY
from ... import Interpolation
from ... import Transform


TypeSpacing = Union[float, Tuple[float, float, float]]
TypeTarget = Tuple[
Optional[Union[Image, str]],
Optional[Tuple[float, float, float]],
]


class Resample(Transform):
Expand All @@ -21,8 +27,9 @@ class Resample(Transform):
Args:
target: Tuple :math:`(s_d, s_h, s_w)`. If only one value
:math:`n` is specified, then :math:`s_d = s_h = s_w = n`.
If a string is given, all images will be resampled using the image
with that name as reference.
If a string or :py:class:`~pathlib.Path` is given,
all images will be resampled using the image
with that name as reference or found at the path.
pre_affine_name: Name of the *image key* (not subject key) storing an
affine matrix that will be applied to the image header before
resampling. If ``None``, the image is resampled with an identity
Expand All @@ -43,9 +50,11 @@ class Resample(Transform):
Example:
>>> import torchio
>>> from torchio.transforms import Resample
>>> transform = Resample(1) # resample all images to 1mm iso
>>> transform = Resample((1, 1, 1)) # resample all images to 1mm iso
>>> transform = Resample('t1') # resample all images to 't1' image space
>>> from pathlib import Path
>>> transform = Resample(1) # resample all images to 1mm iso
>>> transform = Resample((1, 1, 1)) # resample all images to 1mm iso
>>> transform = Resample('t1') # resample all images to 't1' image space
>>> transform = Resample('path/to/ref.nii.gz') # resample all images to space of image at this path
>>>
>>> # Affine matrices are added to each image
>>> matrix_to_mni = some_4_by_4_array # e.g. result of registration to MNI space
Expand All @@ -54,15 +63,15 @@ class Resample(Transform):
... mni=Image('mni_152_lin.nii.gz', torchio.INTENSITY),
... )
>>> resample = Resample(
... 'mni', # this is subject key
... 'mni', # this is a subject key
... affine_name='to_mni', # this is an image key
... )
>>> dataset = torchio.ImagesDataset([subject], transform=resample)
>>> sample = dataset[0] # sample['t1'] is now in MNI space
"""
def __init__(
self,
target: Union[TypeSpacing, str],
target: Union[TypeSpacing, str, Path],
image_interpolation: Interpolation = Interpolation.LINEAR,
pre_affine_name: Optional[str] = None,
p: float = 1,
Expand All @@ -73,9 +82,15 @@ def __init__(
image_interpolation)
self.affine_name = pre_affine_name

def parse_target(self, target: Union[TypeSpacing, str]):
if isinstance(target, str):
reference_image = target
def parse_target(
self,
target: Union[TypeSpacing, str],
) -> TypeTarget:
if isinstance(target, (str, Path)):
if Path(target).is_file():
reference_image = Image(target, INTENSITY).load()
else:
reference_image = target
target_spacing = None
else:
reference_image = None
Expand Down Expand Up @@ -171,15 +186,18 @@ def apply_transform(self, sample: Subject) -> dict:
# Resample
args = image_dict[DATA], image_dict[AFFINE], interpolation_order
if use_reference:
try:
ref_image_dict = sample[self.reference_image]
except KeyError as error:
message = (
f'Reference name "{self.reference_image}"'
' not found in sample'
)
raise ValueError(message) from error
reference = ref_image_dict[DATA], ref_image_dict[AFFINE]
if isinstance(self.reference_image, str):
try:
ref_image_dict = sample[self.reference_image]
except KeyError as error:
message = (
f'Reference name "{self.reference_image}"'
' not found in sample'
)
raise ValueError(message) from error
reference = ref_image_dict[DATA], ref_image_dict[AFFINE]
else:
reference = self.reference_image
kwargs = dict(reference=reference)
else:
kwargs = dict(target_spacing=self.target_spacing)
Expand Down