-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
1,913 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
__pycache__/ | ||
arrays/ | ||
.idea/ | ||
result/ | ||
*.nii | ||
*.gz | ||
blend_quality.py | ||
flip.py | ||
geometry_distribution.py | ||
show_feature.py | ||
show_filter.py | ||
test_deep.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# NISR | ||
implemation of 3D MRI NIfTI super resolution with machine learning | ||
|
||
Paper will be updated.. | ||
|
||
|
||
Package Required:cupy, numpy, nibabel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy as np | ||
import random | ||
|
||
import filter_constant as C | ||
from util import * | ||
|
||
def mod_crop(im, modulo): | ||
H, W, D = im.shape | ||
size0 = H - H % modulo | ||
size1 = W - W % modulo | ||
size2 = D - D % modulo | ||
|
||
out = im[0:size0, 0:size1, 0:size2] | ||
|
||
return out | ||
|
||
def get_point_list_pixel_type(array): | ||
sampled_list = [[] for j in range(C.PIXEL_TYPE)] | ||
for xP, yP, zP in array: | ||
t = xP % C.R * (C.R ** 2) + yP % C.R * C.R + zP % C.R | ||
sampled_list[t].append([xP, yP, zP]) | ||
return sampled_list | ||
|
||
def get_sampled_point_list(array): | ||
[x_range, y_range, z_range] = crop_slice(array) | ||
|
||
xyz_range = [[x, y, z] for x in x_range for y in y_range for z in z_range] | ||
sample_range = random.sample(xyz_range, len(xyz_range) // C.TRAIN_DIV) | ||
sampled_list = get_point_list_pixel_type(sample_range) | ||
#split_range = list(chunks(sample_range, len(sample_range) // TRAIN_STP - 1)) | ||
|
||
return sampled_list | ||
|
||
def crop_slice(array, padding, factor): | ||
for i in range(padding, array.shape[0] - padding): | ||
if not np.all(array[i, :, :] == 0): | ||
x_use1 = i - padding | ||
x_use1 = x_use1 - (x_use1 % factor) | ||
break | ||
for i in reversed(range(padding, array.shape[0] - padding)): | ||
if not np.all(array[i, :, :] == 0): | ||
x_use2 = i + padding | ||
break | ||
for i in range(padding, array.shape[1] - padding): | ||
if not np.all(array[:, i, :] == 0): | ||
y_use1 = i - padding | ||
y_use1 = y_use1 - (y_use1 % factor) | ||
break | ||
for i in reversed(range(padding, array.shape[1] - padding)): | ||
if not np.all(array[:, i, :] == 0): | ||
y_use2 = i + padding | ||
break | ||
for i in range(padding, array.shape[2] - padding): | ||
if not np.all(array[:, :, i] == 0): | ||
z_use1 = i - padding | ||
z_use1 = z_use1 - (z_use1 % factor) | ||
break | ||
for i in reversed(range(padding, array.shape[2] - padding)): | ||
if not np.all(array[:, :, i] == 0): | ||
z_use2 = i + padding | ||
break | ||
|
||
area = (slice(x_use1, x_use2), slice(y_use1, y_use2), slice(z_use1, z_use2)) | ||
return area |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
TRAIN_GLOB = './train/*.nii.gz' | ||
TEST_GLOB = "./test/*.nii.gz" | ||
RESULT_DIR = "./result/" | ||
|
||
Q_ANGLE_T = 8 | ||
Q_ANGLE_P = 8 | ||
|
||
PATCH_SIZE = 11 | ||
PATCH_HALF = PATCH_SIZE // 2 | ||
|
||
GRADIENT_SIZE = 9 | ||
GRADIENT_HALF = GRADIENT_SIZE // 2 | ||
|
||
Q_TRACE = 3 | ||
Q_FA = 3 | ||
Q_MODE = 3 | ||
|
||
R = 4 | ||
|
||
Q_TOTAL = Q_ANGLE_P * Q_ANGLE_T * Q_TRACE * Q_FA * Q_MODE | ||
FILTER_VOL = PATCH_SIZE ** 3 | ||
|
||
TRAIN_DIV = 3 | ||
SAMPLE_RATE = 3 | ||
SHARPEN = 'False' | ||
BLEND_THRESHOLD = 10 | ||
|
||
LR_TYPE = 'interpolation' | ||
FEATURE_TYPE = 'lambda1_coh2' | ||
TRAIN_FILE_MAX = 99999999 | ||
|
||
|
||
def argument_parse(): | ||
import argparse | ||
import sys | ||
|
||
global Q_ANGLE_T, Q_ANGLE_P, GRADIENT_SIZE, PATCH_SIZE, PATCH_HALF | ||
global Q_TRACE, Q_FA, Q_MODE, R, Q_TOTAL, FILTER_VOL, TRAIN_DIV | ||
global SHARPEN, BLEND_THRESHOLD, LR_TYPE, FEATURE_TYPE, TRAIN_FILE_MAX | ||
|
||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--q_angle_t', required=False, default=Q_ANGLE_T) | ||
parser.add_argument('--q_angle_p', required=False, default=Q_ANGLE_P) | ||
parser.add_argument('--filter_len', required=False, default=PATCH_SIZE) | ||
parser.add_argument('--grad_len', required=False, default=GRADIENT_SIZE) | ||
parser.add_argument('--factor', required=False, default=R) | ||
parser.add_argument('--train_div', required=False, default=TRAIN_DIV) | ||
parser.add_argument('--sharpen', required=False, default=SHARPEN) | ||
parser.add_argument('--blend_threshold', required=False, default=BLEND_THRESHOLD) | ||
parser.add_argument('--lr_type', required=False, default=LR_TYPE) | ||
parser.add_argument('--feature_type', required=False, default=FEATURE_TYPE) | ||
parser.add_argument('--train_file_max', required=False, default=TRAIN_FILE_MAX) | ||
|
||
args = parser.parse_args() | ||
|
||
assert int(args.q_angle_t) > 3 | ||
assert int(args.q_angle_p) > 3 | ||
assert int(args.filter_len) > 2 and int(args.filter_len) % 2 == 1 | ||
assert int(args.grad_len) > 2 and int(args.filter_len) % 2 == 1 | ||
assert int(args.factor) >= 2 | ||
assert int(args.train_div) >= 1 | ||
assert args.sharpen in ['True', 'False'] | ||
assert 1 <= int(args.blend_threshold) <= 26 | ||
assert args.lr_type in ['kspace', 'interpolation'] | ||
assert args.feature_type in ['lambda1_coh2', 'lambda1_fa', 'trace_coh2', 'trace_fa'] | ||
assert int(args.train_file_max) >= 1 | ||
|
||
Q_ANGLE_T = int(args.q_angle_t) | ||
Q_ANGLE_P = int(args.q_angle_t) | ||
PATCH_SIZE = int(args.filter_len) | ||
PATCH_HALF = PATCH_SIZE // 2 | ||
GRADIENT_SIZE = int(args.grad_len) | ||
GRADIENT_HALF = GRADIENT_SIZE // 2 | ||
R = int(args.factor) | ||
Q_TOTAL = Q_ANGLE_P * Q_ANGLE_T * Q_TRACE * Q_FA * Q_MODE | ||
FILTER_VOL = PATCH_SIZE ** 3 | ||
TRAIN_DIV = int(args.train_div) | ||
SHARPEN = (args.sharpen == 'True') | ||
BLEND_THRESHOLD = int(args.blend_threshold) | ||
LR_TYPE = args.lr_type | ||
FEATURE_TYPE = args.feature_type | ||
TRAIN_FILE_MAX = int(args.train_file_max) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import numpy as np | ||
import math | ||
|
||
from scipy.ndimage.filters import convolve | ||
from scipy.ndimage import zoom | ||
from numba import jit, njit, cuda, prange, vectorize, float32 | ||
from scipy.sparse.linalg import cg | ||
|
||
import filter_constant as C | ||
|
||
|
||
def dog_sharpener(input, sigma=0.85, alpha=1.414, r=15, ksize=(3,3,3)): | ||
G1 = gaussian_3d_blur(input, ksize, sigma) | ||
Ga1 = gaussian_3d_blur(input, ksize, sigma*alpha) | ||
D1 = add_weight(G1, 1+r, Ga1, -r, 0) | ||
|
||
G2 = gaussian_3d_blur(Ga1, ksize, sigma) | ||
Ga2 = gaussian_3d_blur(Ga1, ksize, sigma*alpha) | ||
D2 = add_weight(G2, 1+r, Ga2, -r, 0) | ||
|
||
G3 = gaussian_3d_blur(Ga2, ksize, sigma) | ||
Ga3 = gaussian_3d_blur(Ga2, ksize, sigma * alpha) | ||
D3 = add_weight(G3, 1+r, Ga3, -r, 0) | ||
|
||
B1 = blend_image(input, D3) | ||
B1 = blend_image(input, B1) | ||
B2 = blend_image(B1, D2) | ||
B2 = blend_image(input, B2) | ||
B3 = blend_image(B2, D1) | ||
B3 = blend_image(input, B3) | ||
|
||
output = np.clip(B3, 0, 1) | ||
|
||
return output | ||
|
||
|
||
def add_weight(im1, w1, im2, w2, b): | ||
return im1 * w1 + im2 * w2 + b | ||
|
||
|
||
def clip_image(im): | ||
clip_value = np.sort(im.ravel())[int(np.prod(im.shape) * 0.999)] | ||
im = np.clip(im, 0, clip_value) | ||
return im | ||
|
||
|
||
@njit(parallel=True) | ||
def ct_descriptor(im): | ||
H, W, D = im.shape | ||
windowSize = 3 | ||
Census = np.zeros((H, W, D)) | ||
CT = np.zeros((H, W, D, windowSize, windowSize, windowSize)) | ||
C = np.int((windowSize - 1) / 2) | ||
for i in prange(C, H - C): | ||
for j in prange(C, W - C): | ||
for k in prange(C, D - C): | ||
cen = 0 | ||
for a in prange(-C, C + 1): | ||
for b in prange(-C, C + 1): | ||
for c in prange(-C, C + 1): | ||
if not (a == 0 and b == 0 and c == 0): | ||
if im[i + a, j + b, k + c] < im[i, j, k]: | ||
cen += 1 | ||
CT[i, j, k, a + C, b + C, c + C] = 1 | ||
Census[i, j, k] = cen | ||
Census = Census / 26 | ||
return Census, CT | ||
|
||
|
||
@njit | ||
def blend_weight(LR, HR, ctLR, ctHR, threshold = 10): | ||
windowSize = 3 | ||
H, W, D = ctLR.shape[:3] | ||
blended = np.zeros((H, W, D), dtype=np.float64) | ||
|
||
C = np.int((windowSize - 1) / 2) | ||
for i in range(C, H - C): | ||
for j in range(C, W - C): | ||
for k in range(C, D - C): | ||
dist = 0 | ||
for a in range(-C, C + 1): | ||
for b in range(-C, C + 1): | ||
for c in range(-C, C + 1): | ||
if not (a == 0 and b == 0 and c == 0): | ||
if ctLR[i, j, k, a + C, b + C, c + C] != ctHR[i, j, k, a + C, b + C, c + C]: | ||
dist += 1 | ||
if dist > threshold: | ||
blended[i, j, k] = LR[i, j, k] | ||
else: | ||
blended[i, j, k] = HR[i, j, k] | ||
return blended | ||
|
||
|
||
@njit | ||
def blend_image(LR, HR, threshold = 10): | ||
censusLR, ctLR = ct_descriptor(LR) | ||
censusHR, ctHR = ct_descriptor(HR) | ||
blended = blend_weight(LR, HR, ctLR, ctHR, threshold) | ||
return blended | ||
|
||
|
||
@njit | ||
def blend_image2(LR, SR, threshold = 10): | ||
H, W, D = LR.shape | ||
blended = SR.copy() | ||
print(blended.shape) | ||
windowSize = 3 | ||
C = np.int((windowSize - 1) / 2) | ||
|
||
for i in range(C, H - C): | ||
for j in range(C, W - C): | ||
for k in range(C, D - C): | ||
cur = np.sort(SR[i-C: i+C+1, j-C: j+C+1, k-C: k+C+1].ravel()) | ||
# cur = cur[2:27-2] | ||
|
||
if cur[0] > SR[i, j, k] or cur[-1] < SR[i, j, k]: | ||
blended[i, j, k] = LR[i, j, k] | ||
# blended = blend_weight(LR, HR, ctLR, ctHR, threshold) | ||
return blended | ||
|
||
|
||
@njit | ||
def blend_image3(LR, SR, threshold = 3): | ||
H, W, D = LR.shape | ||
blended = SR.copy() | ||
windowSize = 3 | ||
C = np.int((windowSize - 1) / 2) | ||
|
||
for i in range(C, H - C): | ||
for j in range(C, W - C): | ||
for k in range(C, D - C): | ||
std_sr = np.std(LR[i-C: i+C+1, j-C: j+C+1, k-C: k+C+1].ravel()) | ||
|
||
if abs(LR[i, j, k] - SR[i, j, k]) > std_sr * threshold: | ||
blended[i, j, k] = LR[i, j, k] | ||
# blended = blend_weight(LR, HR, ctLR, ctHR, threshold) | ||
return blended | ||
|
||
|
||
def gaussian_3d_blur(input, ksize=(3,3,3), sigma=0.85): | ||
filter = gaussian_3d(ksize, sigma) | ||
output = convolve(input, filter) | ||
return output | ||
|
||
|
||
def gaussian_3d(shape=(3,3,3), sigma=0.85): | ||
m,n,o = [(ss-1.)/2. for ss in shape] | ||
z, y, x = np.ogrid[-m:m+1,-n:n+1, -o:o+1] | ||
h = np.exp( -(x*x + y*y + z*z) / (2.*sigma*sigma) ) | ||
h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 | ||
sumh = h.sum() | ||
if sumh != 0: | ||
h /= sumh | ||
return h | ||
|
||
|
||
def get_normalized_gaussian(): | ||
weight = gaussian_3d((C.GRADIENT_SIZE, C.GRADIENT_SIZE, C.GRADIENT_SIZE)) | ||
weight = np.diag(weight.ravel()) | ||
weight = np.array(weight, dtype=np.float32) | ||
return weight |
Oops, something went wrong.