Skip to content

Commit

Permalink
Cut out redundant filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathis Rasmussen committed Jun 15, 2023
1 parent 4efd793 commit bb833e8
Showing 1 changed file with 72 additions and 36 deletions.
108 changes: 72 additions & 36 deletions rt_utils/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,53 @@
from typing import List, Union, Tuple, Dict
import logging

# A set of parameters that is know to work well
default_smoothing_parameters = {
"iterations": 2,
"crop_margins": [20, 20, 1],
"initial_shift": [0, 0, 0],
"np_kron": {"scaling_factor": 3},
"ndimage_gaussian_filter": {"sigma": 2,
"radius": 3},
"threshold": {"threshold": 0.4},
"ndimage_median_filter": {"size": 3}
}
# A set of parameters that is know to work well
default_smoothing_parameters_2 = {
"iterations": 3,
"crop_margins": [20, 20, 1],
"np_kron": {"scaling_factor": 2},
"ndimage_gaussian_filter": {"sigma": 2,
"radius": 5},
"threshold": {"threshold": 0.4},
}


def kron_upscale(mask: np.ndarray, **kwargs):
scaling_factor = (kwargs["scaling_factor"], kwargs["scaling_factor"], 1)
return np.kron(mask, np.ones(scaling_factor))
def kron_upscale(mask: np.ndarray, params):
"""
This function upscales masks like so
def gaussian_blur(mask: np.ndarray, **kwargs):
return ndimage.gaussian_filter(mask, **kwargs)
1|2 1|1|2|2
3|4 --> 1|1|2|2
3|3|4|4
3|3|4|4
Scaling only in x and y direction
"""

scaling_array = (params["scaling_factor"], params["scaling_factor"], 1)

return np.kron(mask, np.ones(scaling_array))

def binary_threshold(mask: np.ndarray, **kwargs):
return mask > kwargs["threshold"]
def gaussian_blur(mask: np.ndarray, params):
return ndimage.gaussian_filter(mask, **params)

def median_filter(mask: np.ndarray, **kwargs):
return ndimage.median_filter(mask.astype(float), **kwargs)
def binary_threshold(mask: np.ndarray, params):
return mask > params["threshold"]

def get_new_margin(column, margin, column_length):
"""
This functions takes a column (of x, y, or z) coordinates and adds a margin.
If margin exceeds mask size, the margin is returned to most extreme possible values
"""
new_min = column.min() - margin
if new_min < 0:
new_min = 0
Expand All @@ -39,7 +61,11 @@ def get_new_margin(column, margin, column_length):

return new_min, new_max

def crop_mask(mask: np.ndarray, crop_margins: np.ndarray):
def crop_mask(mask: np.ndarray, crop_margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
This function crops masks to non-zero pixels padded by crop_margins.
Returns (cropped mask, bounding box)
"""
x, y, z = np.nonzero(mask)

x_min, x_max = get_new_margin(x, crop_margins[0], mask.shape[0])
Expand All @@ -53,36 +79,45 @@ def crop_mask(mask: np.ndarray, crop_margins: np.ndarray):
bbox[4]: bbox[5]], bbox

def restore_mask_dimensions(cropped_mask: np.ndarray, new_shape, bbox):
"""
This funtion restores mask dimentions to the given shape.
"""
new_mask = np.zeros(new_shape)

new_mask[bbox[0]: bbox[1], bbox[2]: bbox[3], bbox[4]: bbox[5]] = cropped_mask
return new_mask.astype(bool)

def iteration_2d(mask: np.ndarray, np_kron, ndimage_gaussian_filter, threshold, ndimage_median_filter):
cropped_mask = kron_upscale(mask=mask, **np_kron)
"""
This is the actual set of filters. Applied iterative over z direction
"""
cropped_mask = kron_upscale(mask=mask, params=np_kron)

for z_idx in range(cropped_mask.shape[2]):
slice = cropped_mask[:, :, z_idx]
slice = gaussian_blur(mask=slice, **ndimage_gaussian_filter)
slice = binary_threshold(mask=slice, **threshold)
slice = median_filter(mask=slice, **ndimage_median_filter)
slice = gaussian_blur(mask=slice, params=ndimage_gaussian_filter)
slice = binary_threshold(mask=slice, params=threshold)

cropped_mask[:, :, z_idx] = slice

return cropped_mask

def iteration_3d(mask: np.ndarray, np_kron, ndimage_gaussian_filter, threshold, ndimage_median_filter):
cropped_mask = kron_upscale(mask=mask, **np_kron)
cropped_mask = gaussian_blur(mask=cropped_mask, **ndimage_gaussian_filter)
cropped_mask = binary_threshold(mask=cropped_mask, **threshold)
cropped_mask = median_filter(mask=cropped_mask, **ndimage_median_filter)
"""
This is the actual filters applied iteratively in 3d.
"""
cropped_mask = kron_upscale(mask=mask, params=np_kron)
cropped_mask = gaussian_blur(mask=cropped_mask, params=ndimage_gaussian_filter)
cropped_mask = binary_threshold(mask=cropped_mask, params=threshold)

return cropped_mask

def pipeline(mask: np.ndarray,
apply_smoothing: str,
smoothing_parameters: Union[Dict, None]):

"""
This is the entrypoint for smoothing a mask.
"""
if not smoothing_parameters:
smoothing_parameters = default_smoothing_parameters

Expand All @@ -91,49 +126,50 @@ def pipeline(mask: np.ndarray,
np_kron = smoothing_parameters["np_kron"]
ndimage_gaussian_filter = smoothing_parameters["ndimage_gaussian_filter"]
threshold = smoothing_parameters["threshold"]
ndimage_median_filter = smoothing_parameters["ndimage_median_filter"]

print(f"Original mask shape {mask.shape}")
print(f"Cropping mask to non-zero")
logging.info(f"Original mask shape {mask.shape}")
logging.info(f"Cropping mask to non-zero")
cropped_mask, bbox = crop_mask(mask, crop_margins=crop_margins)
final_shape, final_bbox = get_final_mask_shape_and_bbox(mask=mask,
scaling_factor=np_kron["scaling_factor"],
iterations=iterations,
bbox=bbox)
print(f"Final scaling with factor of {np_kron['scaling_factor']} for {iterations} iterations")
logging.info(f"Final scaling with factor of {np_kron['scaling_factor']} for {iterations} iterations")
for i in range(iterations):
print(f"Iteration {i+1} out of {iterations}")
print(f"Applying filters")
logging.info(f"Iteration {i+1} out of {iterations}")
logging.info(f"Applying filters")
if apply_smoothing == "2d":
cropped_mask = iteration_2d(cropped_mask,
np_kron=np_kron,
ndimage_gaussian_filter=ndimage_gaussian_filter,
threshold=threshold,
ndimage_median_filter=ndimage_median_filter)
threshold=threshold)
elif apply_smoothing == "3d":
cropped_mask = iteration_3d(cropped_mask,
np_kron=np_kron,
ndimage_gaussian_filter=ndimage_gaussian_filter,
threshold=threshold,
ndimage_median_filter=ndimage_median_filter)
threshold=threshold)
else:
raise Exception("Wrong dimension parameter. Use '2d' or '3d'.")

# Restore dimensions
print("Restoring original mask shape")
logging.info("Restoring original mask shape")
mask = restore_mask_dimensions(cropped_mask, final_shape, final_bbox)
return mask

def get_final_mask_shape_and_bbox(mask, bbox, scaling_factor, iterations):
"""
This function scales image shape and the bounding box which should be used for the final mask
"""

final_scaling_factor = pow(scaling_factor, iterations)

final_shape = np.array(mask.shape)
final_shape[:2] *= final_scaling_factor

bbox[:4] *= final_scaling_factor
bbox[:4] -= round(final_scaling_factor * 0.5) # Shift organ
print("Final shape: ", final_shape)
print("Final bbox: ", bbox)
bbox[:4] *= final_scaling_factor # Scale bounding box to final shape
bbox[:4] -= round(final_scaling_factor * 0.5) # Shift volumes to account for the shift that occurs as a result of the scaling
logging.info("Final shape: ", final_shape)
logging.info("Final bbox: ", bbox)
return final_shape, bbox


0 comments on commit bb833e8

Please sign in to comment.