Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
✅ Fix fastMRI test
Browse files Browse the repository at this point in the history
  • Loading branch information
peterhessey committed Jun 22, 2022
1 parent e05e043 commit c901822
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 10 deletions.
108 changes: 103 additions & 5 deletions Tests/ML/configs/fastmri_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,114 @@
from pathlib import Path
from typing import Any, Optional

import h5py
import numpy as np
from _pytest.monkeypatch import MonkeyPatch
from pytorch_lightning import LightningDataModule, LightningModule

from InnerEye.ML.configs.other.fastmri_varnet import VarNetWithImageLogging
from InnerEye.ML.lightning_container import LightningContainer
from fastMRI.tests.create_temp_data import create_temp_data
from fastmri.data import SliceDataset
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import VarNetDataTransform
from fastmri.pl_modules import FastMriDataModule
from pytorch_lightning import LightningDataModule, LightningModule

from InnerEye.ML.configs.other.fastmri_varnet import VarNetWithImageLogging
from InnerEye.ML.lightning_container import LightningContainer


def create_temp_data(path):
rg = np.random.default_rng(seed=1234)
max_num_slices = 15
max_num_coils = 15
data_splits = {
"knee_data": [
"multicoil_train",
"multicoil_val",
"multicoil_test",
"multicoil_challenge",
"singlecoil_train",
"singlecoil_val",
"singlecoil_test",
"singlecoil_challenge",
],
"brain_data": [
"multicoil_train",
"multicoil_val",
"multicoil_test",
"multicoil_challenge",
],
}

enc_sizes = {
"train": [(1, 128, 64), (1, 128, 49), (1, 150, 67)],
"val": [(1, 128, 64), (1, 170, 57)],
"test": [(1, 128, 64), (1, 96, 96)],
"challenge": [(1, 128, 64), (1, 96, 48)],
}
recon_sizes = {
"train": [(1, 64, 64), (1, 49, 49), (1, 67, 67)],
"val": [(1, 64, 64), (1, 57, 47)],
"test": [(1, 64, 64), (1, 96, 96)],
"challenge": [(1, 64, 64), (1, 48, 48)],
}

metadata = {}
for dataset in data_splits:
for split in data_splits[dataset]:
fcount = 0
(path / dataset / split).mkdir(parents=True)
encs = enc_sizes[split.split("_")[-1]]
recs = recon_sizes[split.split("_")[-1]]
for i in range(len(encs)):
fname = path / dataset / split / f"file{fcount}.h5"
num_slices = rg.integers(2, max_num_slices)
if "multicoil" in split:
num_coils = rg.integers(2, max_num_coils)
enc_size = (num_slices, num_coils, encs[i][-2], encs[i][-1])
recon_size = (num_slices, recs[i][-2], recs[i][-1])
else:
enc_size = (num_slices, encs[i][-2], encs[i][-1])
recon_size = (num_slices, recs[i][-2], recs[i][-1])

data = rg.normal(size=enc_size) + 1j * rg.normal(size=enc_size)

if split.split("_")[-1] in ("train", "val"):
recon = np.absolute(rg.normal(size=recon_size)).astype(
np.dtype("<f4")
)
else:
mask = rg.integers(0, 2, size=recon_size[-1]).astype(bool)

with h5py.File(fname, "w") as hf:
hf.create_dataset("kspace", data=data.astype(np.complex64))
if split.split("_")[-1] in ("train", "val"):
hf.attrs["max"] = recon.max()
if "singlecoil" in split:
hf.create_dataset("reconstruction_esc", data=recon)
else:
hf.create_dataset("reconstruction_rss", data=recon)
else:
hf.create_dataset("mask", data=mask)

enc_size = encs[i]

enc_limits_center = enc_size[1] // 2 + 1
enc_limits_max = enc_size[1] - 2

padding_left = enc_size[1] // 2 - enc_limits_center
padding_right = padding_left + enc_limits_max

metadata[str(fname)] = (
{
"padding_left": padding_left,
"padding_right": padding_right,
"encoding_size": enc_size,
"recon_size": recon_size,
},
num_slices,
)

fcount += 1

return path / "knee_data", path / "brain_data", metadata


class FastMriRandomData(FastMriDataModule):
Expand Down
10 changes: 5 additions & 5 deletions Tests/ML/test_lightning_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive.runconfig import HyperDriveConfig
from health_azure import AzureRunInfo
from pytorch_lightning import LightningModule

from InnerEye.Azure.azure_config import AzureConfig
Expand All @@ -24,10 +25,11 @@
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.runner import Runner
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.configs.lightning_test_containers import (DummyContainerWithAzureDataset, DummyContainerWithHooks,
DummyContainerWithModel, DummyContainerWithPlainLightning)
from Tests.ML.configs.fastmri_random import FastMriOnRandomData
from Tests.ML.configs.lightning_test_containers import (
DummyContainerWithAzureDataset, DummyContainerWithHooks, DummyContainerWithModel, DummyContainerWithPlainLightning
)
from Tests.ML.util import default_runner
from health_azure import AzureRunInfo


def test_run_container_in_situ(test_output_dirs: OutputFolderForTests) -> None:
Expand Down Expand Up @@ -127,7 +129,6 @@ def test_create_fastmri_container() -> None:
and if the submodule is created correctly.
"""
from InnerEye.ML.configs.other.fastmri_varnet import VarNetWithImageLogging
from Tests.ML.configs.fastmri_random import FastMriOnRandomData
FastMriOnRandomData()
VarNetWithImageLogging()

Expand All @@ -147,7 +148,6 @@ def test_run_fastmri_container(test_output_dirs: OutputFolderForTests) -> None:
with mock.patch("sys.argv", args):
loaded_config, run_info = runner.run()
assert isinstance(run_info, AzureRunInfo)
from Tests.ML.configs.fastmri_random import FastMriOnRandomData
assert isinstance(runner.lightning_container, FastMriOnRandomData)


Expand Down

0 comments on commit c901822

Please sign in to comment.