Skip to content
This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Commit

Permalink
Additional PEP 8 enforcement for the csv reader and additional unit t…
Browse files Browse the repository at this point in the history
…ests for the CSV reader
  • Loading branch information
csudre committed Jul 3, 2019
1 parent dffecd0 commit 813f287
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 55 deletions.
123 changes: 72 additions & 51 deletions niftynet/contrib/csv_reader/sampler_csvpatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from niftynet.engine.image_window import N_SPATIAL, LOCATION_FORMAT
from niftynet.io.misc_io import do_reorientation_idx, do_resampling_idx

SUPPORTED_MODES_CORRECTION=['pad', 'remove', 'random']


class CSVPatchSampler(ImageWindowDatasetCSV):
"""
Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(self,
tf.logging.info("initialised csv patch sampler %s ", self.window.shapes)
self.mode_correction = mode_correction
self.window_centers_sampler = rand_spatial_coordinates
self.available_subjects = reader._file_list.subject_id

# pylint: disable=too-many-locals
def layer_op(self, idx=None):
Expand All @@ -70,7 +73,6 @@ def layer_op(self, idx=None):
if self.window.n_samples > 1:
raise ValueError("\nThe number of windows per image has to be "
"1 with a csv_reader")

# flag_multi_row = False
print("Trying to run csv patch sampler ")
if 'sampler' not in self.csv_reader.names:
Expand All @@ -83,21 +85,59 @@ def layer_op(self, idx=None):
_, _, subject_id = self.csv_reader(idx)

print("subject id is ", subject_id)

idx_subject_id = np.where(
self.reader._file_list.subject_id == subject_id)[0][0]

self.available_subjects == subject_id)[0][0]
image_id, data, _ = self.reader(idx=idx_subject_id, shuffle=True)

subj_indices, csv_data, _ = self.csv_reader(subject_id=subject_id)

if 'sampler' not in self.csv_reader.names:
print('Uniform sampling because no csv sampler provided')
tf.logging.warning('Uniform sampling because no csv sampler '
'provided')

image_shapes = dict(
(name, data[name].shape) for name in self.window.names)
static_window_shapes = self.window.match_image_shapes(image_shapes)

num_idx, num_discard = self.check_csv_sampler_valid(subject_id,
image_shapes,
static_window_shapes
)
if self.mode_correction == 'remove':
if num_idx == num_discard:
self.available_subjects.drop(subject_id)
subject_id = None
while subject_id is None and len(self.available_subjects) > 0:
_, _, subject_id = self.csv_reader(idx)

# print("subject id is ", subject_id)

idx_subject_id = np.where(
self.available_subjects == subject_id)[0][0]
image_id, data, _ = self.reader(idx=idx_subject_id,
shuffle=True)
subj_indices, csv_data, _ = self.csv_reader(
subject_id=subject_id)
if 'sampler' not in self.csv_reader.names:
tf.logging.warning(
'Uniform sampling because no csv sampler provided')
image_shapes = dict(
(name, data[name].shape) for name in self.window.names)
static_window_shapes = self.window.match_image_shapes(
image_shapes)
num_idx, num_discard = self.check_csv_sampler_valid(
subject_id,
image_shapes,
static_window_shapes)
if num_idx == num_discard:
self.available_subjects.drop(subject_id)
subject_id = None
if subject_id is None:
tf.logging.fatal("None of the subjects has any suitable "
"samples. Consider using a different "
"alternative to unsuitable samples or "
"reducing your patch size")
raise ValueError

# find csv coordinates and return coordinates (not corrected) and
# corresponding csv indices
coordinates, idx = self.csvcenter_spatial_coordinates(
Expand All @@ -111,27 +151,16 @@ def layer_op(self, idx=None):
reject = False
if self.mode_correction == 'remove':
reject = True
print(idx, "index selected")
# print(idx, "index selected")
# initialise output dict, placeholders as dictionary keys
# this dictionary will be used in
# enqueue operation in the form of: `feed_dict=output_dict`
output_dict = {}
potential_pad = self.csv_reader.pad_by_task['sampler'][idx][0]
potential_pad_corr_end = -1.0 * \
np.asarray(potential_pad[N_SPATIAL:])
potential_pad_corr_end = -1.0 * np.asarray(potential_pad[N_SPATIAL:])
potential_pad_corr = np.concatenate((potential_pad[:N_SPATIAL],
potential_pad_corr_end), 0)

print("Pot pad is", potential_pad, potential_pad_corr,
self.csv_reader.df_by_task['sampler'].shape)
# print( np.asarray(self.csv_reader.df_by_task['sampler'])[idx[0]])
# print(self.reader.output_list[idx_subject_id]['image'].
# _output_axcodes,
# self.reader.output_list[idx_subject_id]['label'].
# original_axcodes)

# samples = idx[0]+np.arange(-2,4)
# print(self.csv_reader.pad_by_task['sampler'][samples])
potential_pad_corr_end), 0)

# fill output dict with data
for name in list(data):
coordinates_key = LOCATION_FORMAT.format(name)
Expand All @@ -147,7 +176,7 @@ def layer_op(self, idx=None):
x_start, y_start, z_start, x_end, y_end, z_end = \
location_array[window_id, 1:].astype(np.int32) + \
potential_pad_corr.astype(np.int32)
print(location_array[window_id, 1:]+potential_pad_corr)
# print(location_array[window_id, 1:]+potential_pad_corr)
try:
image_window = data[name][
x_start:x_end, y_start:y_end,
Expand All @@ -158,7 +187,7 @@ def layer_op(self, idx=None):
image_window))-N_SPATIAL, 1])
new_pad = np.concatenate((new_pad, add_pad),
0).astype(np.int32)
print(new_pad, "is padding")
# print(new_pad, "is padding")
new_img = np.pad(image_window, pad_width=new_pad,
mode='constant',
constant_values=0)
Expand All @@ -179,7 +208,7 @@ def layer_op(self, idx=None):
else:
output_dict[image_data_key] = image_array[0]
# fill output dict with csv_data
print("filling output dict")
# print("filling output dict")
if self.csv_reader is not None:
idx_dict = {}
list_keys = self.csv_reader.df_by_task.keys()
Expand Down Expand Up @@ -262,19 +291,18 @@ def csvcenter_spatial_coordinates(self,
else:
window_centres_list = []
list_idx = []
list_mod = list(img_sizes.keys())
print(list_mod)
self.check_csv_sampler_valid(subject_id, img_sizes, win_sizes)
_, _ = self.check_csv_sampler_valid(subject_id, img_sizes,
win_sizes)
idx_check, _, _ = self.csv_reader(
subject_id=subject_id, mode='multi', reject=False)
idx_multi = idx_check['sampler']
for mod in self.csv_reader.task_param:
all_coordinates[mod] = []
for n in range(0, n_samples):
print("reject value is ", reject)
# print("reject value is ", reject)
idx, data_csv, _ = self.csv_reader(
subject_id=subject_id, mode='single', reject=reject)
print(data_csv['sampler'].shape[0], 'data_sampler')
# print(data_csv['sampler'].shape[0], 'data_sampler')
if data_csv['sampler'].shape[0] > 0:
centre_transform = self.transform_centres(
subject_id, img_sizes,
Expand All @@ -288,7 +316,6 @@ def csvcenter_spatial_coordinates(self,
window_centres_list.append(centre_transform)
window_centres = np.concatenate(window_centres_list, 0)


if np.sum(self.csv_reader.valid_by_task['sampler'][idx_multi]) ==\
0 and np.asarray(window_centres).shape[0] == 0:
tf.logging.warning("Nothing is valid, taking random centres")
Expand All @@ -299,7 +326,6 @@ def csvcenter_spatial_coordinates(self,
list_idx = np.arange(0, n_samples)
print("all prepared and added ")


assert window_centres.shape == (n_samples, N_SPATIAL), \
"the coordinates generator should return " \
"{} samples of rank {} locations".format(n_samples, N_SPATIAL)
Expand All @@ -325,10 +351,10 @@ def csvcenter_spatial_coordinates(self,
# 'spatial coords: out of bounds.'

# include subject id as the 1st column of all_coordinates values
idx_subject_id = np.where(
self.reader._file_list.subject_id == subject_id)[0][0]
idx_subject_id = np.ones((n_samples,), dtype=np.int32)\
* idx_subject_id
idx_subject_id = np.where(self.reader._file_list.subject_id ==
subject_id)[0][0]
idx_subject_id = np.ones((n_samples,),
dtype=np.int32) * idx_subject_id
spatial_coords = np.append(
idx_subject_id[:, None], spatial_coords, axis=1)
all_coordinates[mod] = spatial_coords
Expand All @@ -341,8 +367,8 @@ def transform_centres(self, subject_id, img_sizes, windows_centres):
list_mod = list(img_sizes.keys())

print(list_mod)
idx_subject_id = np.where(
self.reader._file_list.subject_id == subject_id)[0][0]
idx_subject_id = np.where(self.reader._file_list.subject_id ==
subject_id)[0][0]
input_shape = self.reader.output_list[idx_subject_id][list_mod[
0]].original_shape[:N_SPATIAL]
output_shape = self.reader.output_list[idx_subject_id][list_mod[
Expand Down Expand Up @@ -385,34 +411,33 @@ def transform_centres(self, subject_id, img_sizes, windows_centres):
def check_csv_sampler_valid(self, subject_id, img_sizes, win_sizes):
print("Checking if csv_sampler valid is updated")
reject = False
if self.mode_correction == 'remove':
if self.mode_correction != 'pad':
reject = True
idx_multi, csv_data, _ = self.csv_reader(subject_id=subject_id,
idx=None, mode='multi',
reject=reject)

windows_centres = csv_data['sampler']
print("Windows extracted", windows_centres)
numb = windows_centres.shape[0]
if windows_centres.shape[0] > 0:
checked = self.csv_reader.valid_by_task['sampler'][
idx_multi['sampler']]
print(np.sum(checked), 'is sum of checked')
min_checked = np.min(checked)
numb_valid = np.sum(checked)
else:
min_checked = 0
numb_valid = 0
if min_checked >= 0:
print("Already checked, no need for further analysis")
return
return numb, numb-numb_valid
else:
transformed_centres = self.transform_centres(subject_id, img_sizes,
windows_centres)


pad = self.csv_reader.pad_by_task['sampler'][idx_multi['sampler']]

numb = windows_centres.shape[0]
img_spatial_size, win_spatial_size = \
_infer_spatial_size(img_sizes, win_sizes)
img_spatial_size, win_spatial_size = _infer_spatial_size(
img_sizes, win_sizes)
tf.logging.warning("Need to checked validity of samples for "
"subject %s" %subject_id)
checked = np.ones([numb])
Expand Down Expand Up @@ -456,12 +481,8 @@ def check_csv_sampler_valid(self, subject_id, img_sizes, win_sizes):
pad[:, N_SPATIAL:])

tf.logging.warning("to discard or pad is %d out of %d for mod "
"%s" % (
numb-np.sum(checked), numb, mod))
"%s" % (numb-np.sum(checked), numb, mod))

print("check on idx_multi", np.asarray(idx_multi['sampler']).shape,
np.asarray(idx_multi['sampler']).dtype,
checked.dtype)
idx_discarded = []
for i in range(0, len(checked)):
self.csv_reader.valid_by_task['sampler'][idx_multi[
Expand All @@ -479,7 +500,7 @@ def check_csv_sampler_valid(self, subject_id, img_sizes, win_sizes):
idx_discarded))))
print(
"updated valid part of csv_reader for subject %s" % subject_id)
return
return numb, numb-np.sum(checked)


# def correction_coordinates(coordinates, idx, pb_coord, img_sizes, win_sizes,
Expand Down
58 changes: 54 additions & 4 deletions tests/sampler_csvpatch_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def get_csvpatch_reader():
csv_reader.initialise(CSV_DATA, DYNAMIC_MOD_TASK, dynamic_list)
return csv_reader


def get_csvpatchbad_reader():
csv_reader = CSVReader(['sampler'])
csv_reader.initialise(CSVBAD_DATA, DYNAMIC_MOD_TASK, dynamic_list)
Expand All @@ -229,7 +230,7 @@ def test_3d_csvsampler_init(self):
sampler.close_all()


def test_dynamic_init(self):
def test_pad_init(self):
sampler = CSVPatchSampler(reader=get_large_window_reader(),
csv_reader=get_csvpatch_reader(),
window_sizes=LARGE_MOD_DATA,
Expand All @@ -244,22 +245,71 @@ def test_dynamic_init(self):
self.assertAllClose(out['image'].shape[1:], (75,75,75, 1))
sampler.close_all()

def test_remove_element(self):


def test_padd_volume(self):
sampler = CSVPatchSampler(reader=get_large_window_reader(),
csv_reader=get_csvpatch_reader(),
window_sizes=LARGE_MOD_DATA,
batch_size=2,
windows_per_image=1,
queue_length=3)
with self.test_session() as sess:
sampler.set_num_threads(2)
out = sess.run(sampler.pop_batch_op())
img_loc = out['image_location']
print(img_loc)
self.assertAllClose(out['image'].shape[1:], (75, 75, 75, 1))
sampler.close_all()

def test_change_orientation(self):
sampler = CSVPatchSampler(reader=get_large_window_reader(),
csv_reader=get_csvpatch_reader(),
window_sizes=LARGE_MOD_DATA,
batch_size=2,
windows_per_image=1,
queue_length=3)
with self.test_session() as sess:
sampler.set_num_threads(2)
out = sess.run(sampler.pop_batch_op())
img_loc = out['image_location']
print(img_loc)
self.assertAllClose(out['image'].shape[1:], (75, 75, 75, 1))
sampler.close_all()

def test_random_init(self):
sampler = CSVPatchSampler(reader=get_large_window_reader(),
csv_reader=get_csvpatch_reader(),
window_sizes=LARGE_MOD_DATA,
batch_size=2,
windows_per_image=1,
queue_length=3,
mode_correction='remove')
mode_correction='random')
with self.test_session() as sess:
sampler.set_num_threads(1)
sampler.set_num_threads(2)
out = sess.run(sampler.pop_batch_op())
img_loc = out['image_location']
print(img_loc)
self.assertAllClose(out['image'].shape[1:], (75, 75, 75, 1))
sampler.close_all()


def test_remove_element(self):

sampler = CSVPatchSampler(reader=get_large_window_reader(),
csv_reader=get_csvpatch_reader(),
window_sizes=LARGE_MOD_DATA,
batch_size=2,
windows_per_image=1,
queue_length=3,
mode_correction='remove')
with self.test_session() as sess:
sampler.set_num_threads(2)
with self.assertRaisesRegexp(ValueError, ""):
out = sess.run(sampler.pop_batch_op())



def test_ill_init(self):
with self.assertRaisesRegexp(Exception, ""):
sampler = \
Expand Down

0 comments on commit 813f287

Please sign in to comment.