Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

BUG: Fix missing channels dimension in normalization #701

Merged
merged 11 commits into from
Jun 13, 2022
Prev Previous commit
Next Next commit
Add test for 3D and 4D input images
  • Loading branch information
fepegar committed Jun 8, 2022
commit 478698e47305510dd4c7fb878ea6cf30e6057123
9 changes: 5 additions & 4 deletions InnerEye/ML/utils/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,18 +389,19 @@ def get_center_crop(image: NumpyOrTorch, crop_shape: TupleInt3) -> NumpyOrTorch:


def check_array_range(data: np.ndarray, expected_range: Optional[Range] = None,
error_prefix: str = None) -> None:
error_prefix: Optional[str] = None) -> None:
"""
Checks if all values in the given array fall into the expected range. If not, raises a
ValueError, and prints out statistics about the values that fell outside the expected range.
``ValueError``, and prints out statistics about the values that fell outside the expected range.
If no range is provided, it checks that all values in the array are finite (that is, they are not
infinity and not np.nan
infinity and not ``np.nan``).

:param data: The array to check. It can have any size.
:param expected_range: The interval that all array elements must fall into. The first entry is the lower
bound, the second entry is the upper bound.
bound, the second entry is the upper bound.
:param error_prefix: A string to use as the prefix for the error message.
"""
data = np.asarray(data)
fepegar marked this conversation as resolved.
Show resolved Hide resolved
if expected_range is None:
valid_pixels = np.isfinite(data)
else:
Expand Down
32 changes: 20 additions & 12 deletions Tests/ML/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@


@pytest.fixture
def image_rand_pos() -> Union[torch.Tensor, np.ndarray]:
def image_rand_pos() -> np.ndarray:
torch.random.manual_seed(1)
np.random.seed(0)
return (np.random.rand(3, 4, 4, 4) * 1000.0).astype(ImageDataType.IMAGE.value)


@pytest.fixture
def image_rand_pos_gpu(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
def image_rand_pos_gpu(image_rand_pos: np.ndarray) -> Union[torch.Tensor, np.ndarray]:
return torch.tensor(image_rand_pos) if use_gpu else image_rand_pos


Expand All @@ -56,42 +56,50 @@ def assert_image_out_datatype(image_out: np.ndarray) -> None:
"datatype that we force images to have."


def test_simplenorm_half(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
def test_simplenorm_half(image_rand_pos: np.ndarray) -> None:
image_out = photometric_normalization.simple_norm(image_rand_pos, mask_half, debug_mode=True)
assert np.mean(image_out, dtype=np.float) == approx(-0.05052318)
for c in range(image_out.shape[0]):
assert np.mean(image_out[c, mask_half > 0.5], dtype=np.float) == approx(0, abs=1e-7)
assert_image_out_datatype(image_out)


def test_simplenorm_ones(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
def test_simplenorm_ones(image_rand_pos: np.ndarray) -> None:
image_out = photometric_normalization.simple_norm(image_rand_pos, mask_ones, debug_mode=True)
assert np.mean(image_out) == approx(0, abs=1e-7)
assert_image_out_datatype(image_out)


def test_mriwindowhalf(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
image_out, status = photometric_normalization.mri_window(image_rand_pos, mask_half, (0, 1), sharpen, tail)
def test_3d_4d(image_rand_pos: np.ndarray) -> None:
normalization = photometric_normalization.PhotometricNormalization()
shape = image_rand_pos.shape
spatial_shape = shape[1:]
assert normalization.transform(image_rand_pos).shape == shape
assert normalization.transform(image_rand_pos[0]).shape == spatial_shape


def test_mriwindowhalf(image_rand_pos: np.ndarray) -> None:
image_out, _ = photometric_normalization.mri_window(image_rand_pos, mask_half, (0, 1), sharpen, tail)
assert np.mean(image_out) == approx(0.2748852)
assert_image_out_datatype(image_out)


def test_mriwindowones(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
image_out, status = photometric_normalization.mri_window(image_rand_pos, mask_ones, (0.0, 1.0), sharpen, tail3)
def test_mriwindowones(image_rand_pos: np.ndarray) -> None:
image_out, _ = photometric_normalization.mri_window(image_rand_pos, mask_ones, (0.0, 1.0), sharpen, tail3)
assert np.mean(image_out) == approx(0.2748852)
assert_image_out_datatype(image_out)


def test_trimmed_norm_full(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
image_out, status = photometric_normalization.normalize_trim(image_rand_pos, mask_ones,
def test_trimmed_norm_full(image_rand_pos: np.ndarray) -> None:
image_out, _ = photometric_normalization.normalize_trim(image_rand_pos, mask_ones,
output_range=(-1, 1), sharpen=1,
trim_percentiles=(1, 99))
assert np.mean(image_out, dtype=np.float) == approx(-0.08756259549409151)
assert_image_out_datatype(image_out)


def test_trimmed_norm_half(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
image_out, status = photometric_normalization.normalize_trim(image_rand_pos, mask_half,
def test_trimmed_norm_half(image_rand_pos: np.ndarray) -> None:
image_out, _ = photometric_normalization.normalize_trim(image_rand_pos, mask_half,
output_range=(-1, 1), sharpen=1,
trim_percentiles=(1, 99))
assert np.mean(image_out, dtype=np.float) == approx(-0.4862089517215888)
Expand Down