This repository has been archived by the owner on Mar 17, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 404
/
image_window_dataset.py
executable file
·386 lines (321 loc) · 14 KB
/
image_window_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# -*- coding: utf-8 -*-
"""
Creating ``tf.data.Dataset`` instance for image window sampler.
"""
from __future__ import absolute_import, division, print_function
import inspect
import numpy as np
import tensorflow as tf
# pylint: disable=no-name-in-module
from tensorflow.python.data.util import nest
from tensorflow.python.keras.utils import GeneratorEnqueuer
from niftynet.engine.image_window import ImageWindow, N_SPATIAL, \
LOCATION_FORMAT, BUFFER_POSITION_NP_TYPE
from niftynet.io.misc_io import squeeze_spatial_temporal_dim
from niftynet.layer.base_layer import Layer
from niftynet.utilities.util_common import look_up_operations
# When the total number of window samples is not divisible by batch_size,
# the class supports different modes for the final (smaller) batch:
#   'drop':    drop the remainder batch
#   'pad':     pad the final smaller batch with -1s
#   'dynamic': output the remainder directly (in this case the batch_size
#              is undetermined at 'compile time')
SMALLER_FINAL_BATCH_MODE = {'drop', 'pad', 'dynamic'}
# pylint: disable=too-many-instance-attributes
class ImageWindowDataset(Layer):
    """
    This class creates a ``tf.data.Dataset`` instance from
    a sampler's layer_op function or generator.

    If ``from_generator``, ``Dataset.from_generator`` interface will be used,
    ``Dataset.map`` interface will be used otherwise::

        if the windows are from a image reader,
        the total number of windows produced
        will be: `epoch x n_subjects x windows_per_image`

        if the windows are from a generator,
        the total number of windows produced
        will be: "iterations from the generator" x num_threads

    """

    # pylint: disable=too-many-arguments
    def __init__(self,
                 reader=None,
                 window_sizes=None,
                 batch_size=1,
                 windows_per_image=1,
                 queue_length=10,
                 shuffle=True,
                 epoch=-1,
                 smaller_final_batch_mode='pad',
                 seed=None,
                 name='image_dataset'):
        """
        :param reader: image reader; it is called as ``reader(idx=...)`` by
            the default ``layer_op`` and, when not None, its
            ``input_sources``, ``shapes``, ``tf_dtypes`` and ``num_subjects``
            are used to describe the output windows
        :param window_sizes: window size specification forwarded to
            ``ImageWindow.from_data_reader_properties``; a falsy value is
            replaced with ``(-1, -1, -1)``
        :param batch_size: number of windows per output batch
        :param windows_per_image: stored as ``self.window.n_samples``
        :param queue_length: buffer size used for prefetching and local
            shuffling; raised to ``batch_size * 5`` when smaller than that
        :param shuffle: whether to shuffle subjects and windows
        :param epoch: repetition count passed to ``Dataset.repeat``;
            forced to 1 when ``layer_op`` is a generator function
        :param smaller_final_batch_mode: one of
            ``SMALLER_FINAL_BATCH_MODE`` (case-insensitive)
        :param seed: random seed passed to the ``Dataset.shuffle`` calls
        :param name: layer name
        """
        Layer.__init__(self, name=name)

        self._num_threads = 1
        self._enqueuer = None
        self._seed = seed

        self.dataset = None
        self.iterator = None
        self.reader = reader

        self.batch_size = batch_size
        # the buffer holds at least five batches worth of windows
        self.queue_length = int(max(queue_length, round(batch_size * 5.0)))
        if self.queue_length > queue_length:
            tf.logging.warning(
                'sampler queue_length should be larger than batch_size, '
                'defaulting to batch_size * 5.0 (%s).', self.queue_length)

        # a (possibly overridden) layer_op containing `yield` selects
        # the `Dataset.from_generator` code path
        self.from_generator = inspect.isgeneratorfunction(self.layer_op)
        self.shuffle = shuffle
        # the length of a generator is unknown, it is iterated only once
        self.epoch = 1 if self.from_generator else epoch
        self.smaller_final_batch_mode = look_up_operations(
            smaller_final_batch_mode.lower(), SMALLER_FINAL_BATCH_MODE)

        self.n_subjects = 1
        self.window = None
        if reader is not None:
            self.window = ImageWindow.from_data_reader_properties(
                reader.input_sources,
                reader.shapes,
                reader.tf_dtypes,
                window_sizes or (-1, -1, -1))
            self.n_subjects = reader.num_subjects
            self.window.n_samples = windows_per_image

    @property
    def shapes(self):
        """
        the sampler output (value of ``layer_op``) is::

            [windows_per_image, x, y, z, 1, channels]

        returns a dictionary of sampler output shapes
        """
        assert self.window, 'Unknown output shapes: self.window not initialised'
        return self.window.shapes

    @property
    def tf_shapes(self):
        """
        returns a dictionary of sampler output tensor shapes
        """
        assert self.window, 'Unknown output shapes: self.window not initialised'
        return self.window.tf_shapes

    @property
    def tf_dtypes(self):
        """
        returns a dictionary of sampler output tensorflow dtypes
        """
        assert self.window, 'Unknown output dtypes: self.window not initialised'
        return self.window.tf_dtypes

    def set_num_threads(self, num_threads):
        """
        Set number windows to generate in parallel.

        The value is used as ``num_parallel_calls`` of ``Dataset.map``
        (reader-based sampling) and as the number of ``GeneratorEnqueuer``
        workers (generator-based sampling).

        :param num_threads: number of parallel workers (converted to int)
        """
        self._num_threads = int(num_threads)

    def layer_op(self, idx=None):
        """
        Generating each image as a window.
        Overriding this function to create new image sampling strategies.

        This function should either yield or return a dictionary
        (of multiple windows per image)::

            return a dictionary:
            {
             'image_name': a numpy array [n_samples, h, w, d, chn],
             'image_name_location': [n_samples, 7]
            }

        where the 7-element location vector encode the image_id,
        starting and ending coordinates of the image window.

        Following the same notation, the dictionary can be extended
        to multiple modalities; the keys will be::

            {'image_name_1', 'image_name_1_location',
             'image_name_2', 'image_name_2_location', ...}

        The following is a demo of a generator as the layer_op.
        Often we don't know the total number of elements that
        will be generated, epoch is always 1::

            for idx in range(100):
                image_id, image_data, _ = self.reader()
                for mod in list(image_data):
                    spatial_shape = image_data[mod].shape[:N_SPATIAL]
                    coords = self.dummy_coordinates(
                        image_id, spatial_shape, 1)
                    image_data[LOCATION_FORMAT.format(mod)] = coords
                    image_data[mod] = image_data[mod][np.newaxis, ...]
                yield image_data

        :param idx: image_id used to load the image at the i-th row of
            the input
        :return: a image data dictionary
        """
        image_id, image_data, _ = self.reader(idx=idx)
        # iterate over a snapshot of the keys: location entries are
        # added to the same dictionary inside the loop
        for mod in list(image_data):
            spatial_shape = image_data[mod].shape[:N_SPATIAL]
            coords = self.dummy_coordinates(image_id, spatial_shape, 1)
            image_data[LOCATION_FORMAT.format(mod)] = coords
            # prepend the `n_samples` axis (a single window per image)
            image_data[mod] = image_data[mod][np.newaxis, ...]
        return image_data

    def pop_batch_op(self):
        """
        This function is used when connecting a sampler output
        to a network. e.g.::

            data_dict = self.get_sampler()[0].pop_batch_op()
            net_output = net_model(data_dict['image'], is_training)

        .. caution::

            Note it squeezes the output tensor of 6 dims
            ``[batch, x, y, z, time, modality]``
            by removing all dims along which length is one.

        :return: a dictionary of image window tensors.
        """
        if self.dataset is None or self.iterator is None:
            # in case `run_threads` is not called,
            # here we initialise the dataset and iterator
            self.init_dataset()
            self.iterator = self.dataset.make_one_shot_iterator()

        window_output = self.iterator.get_next()
        # snapshot of the keys, as the values are replaced in the loop
        for name in list(window_output):
            window_output[name] = squeeze_spatial_temporal_dim(
                window_output[name])
        return window_output

    def init_dataset(self):
        """
        Make a window samples dataset from the reader and layer_op.
        This function sets ``self.dataset``.

        :return:
        """
        if not self.from_generator:
            dataset = self._dataset_from_range()
        else:
            dataset = self._dataset_from_generator()
        self.dataset = self.dataset_preprocessing(dataset)

    def dataset_preprocessing(self, dataset):
        """
        dataset: batch and shuffle

        The input is repeated ``self.epoch`` times, prefetched, optionally
        shuffled within a buffer of ``self.queue_length`` windows, and
        batched according to ``self.smaller_final_batch_mode``.

        :param dataset: a `tf.data.Dataset` instance
        :return: a `tf.data.Dataset` instance
        """
        dataset = dataset.repeat(self.epoch)
        dataset = dataset.prefetch(buffer_size=self.queue_length)
        if self.shuffle:
            # locally shuffle the buffer of image windows
            dataset = dataset.shuffle(
                buffer_size=self.queue_length, seed=self._seed)

        if self.smaller_final_batch_mode == 'drop':
            # drop the remainder if there's not enough windows to
            # form a batch, so that we have a fixed batch size.
            # (`drop_remainder` is the new API since TF 1.10)
            dataset = dataset.batch(batch_size=self.batch_size,
                                    drop_remainder=True)
            return dataset

        dataset = dataset.batch(batch_size=self.batch_size)

        if self.smaller_final_batch_mode == 'dynamic' and self.batch_size > 1:
            return dataset

        # self.smaller_final_batch_mode is 'pad'
        # if self.batch_size == 1 no actual padding
        # but this function will set the output shapes properly.
        def _pad_batch(batch_size):
            def _pad_batch_func(input_tensor_dict):
                """
                function to pad the batch dim to `batch_size`.
                (assuming the input dataset is a dictionary-based one)
                """
                out_dict = {}
                for in_name in list(input_tensor_dict):
                    in_var = input_tensor_dict[in_name]
                    var_shape = in_var.shape.as_list()
                    if batch_size > 1:
                        # only the trailing side of the batch axis is padded
                        paddings = [[0, 0] for _ in in_var.shape]
                        paddings[0][1] = batch_size - tf.shape(in_var)[0]
                        in_var = tf.pad(
                            in_var, paddings, "CONSTANT", constant_values=-1)
                    # make the batch dimension statically known
                    var_shape[0] = batch_size
                    in_var.set_shape(var_shape)
                    out_dict[in_name] = in_var
                return out_dict

            return _pad_batch_func

        dataset = dataset.map(_pad_batch(self.batch_size))
        return dataset

    # pylint: disable=redefined-variable-type
    def _dataset_from_range(self):
        """
        This function maps a dataset of integers to a dataset of images.

        :return: a `tf.data.Dataset`
        """
        # dataset: a list of integers
        tf.logging.info(
            'Initialising Dataset from %s subjects...', self.n_subjects)
        dataset = tf.data.Dataset.range(self.n_subjects)
        if self.shuffle:
            # global shuffle of the entire set of subjects
            dataset = dataset.shuffle(
                buffer_size=self.n_subjects, seed=self._seed)

        # dataset: map each integer i to n windows sampled from subject i
        def _tf_wrapper(idx):
            flattened_types = nest.flatten(self.tf_dtypes)
            flattened_shapes = nest.flatten(self.tf_shapes)
            # calling the layer instance is expected to dispatch to
            # `layer_op` (behaviour of the `Layer` base class)
            flat_values = tf.py_func(
                func=lambda subject_id: nest.flatten(self(subject_id)),
                inp=[idx],
                Tout=flattened_types)
            for ret_t, shape in zip(flat_values, flattened_shapes):
                # the actual returned numpy array shapes are not checked
                ret_t.set_shape(shape)
            return nest.pack_sequence_as(self.tf_dtypes, flat_values)

        dataset = dataset.map(_tf_wrapper, num_parallel_calls=self._num_threads)

        # dataset: slice the n-element window into n single windows
        dataset = dataset.flat_map(map_func=tf.data.Dataset.from_tensor_slices)
        return dataset

    def _dataset_from_generator(self):
        """
        Create a `tf.data.Dataset` from a layer_op (as a generator).

        With fewer than two threads, or without shuffling, the layer itself
        is used as the generator callable; otherwise the generator is
        wrapped in a ``GeneratorEnqueuer`` with ``self._num_threads``
        workers.

        :return: a `tf.data.Dataset`
        """
        tf.logging.info('Initialising dataset from generator...')

        if self._num_threads < 2 or not self.shuffle:
            window_generator = self
        else:
            if self._enqueuer is not None:
                # the dataset is being re-initialised: stop the workers
                # of the previous enqueuer before replacing it
                self._enqueuer.stop()
            self._enqueuer = GeneratorEnqueuer(
                self(),
                use_multiprocessing=True)
            self._enqueuer.start(
                workers=self._num_threads, max_queue_size=self.queue_length)
            window_generator = self._enqueuer.get

        # dataset from generator
        dataset = tf.data.Dataset.from_generator(
            generator=window_generator,
            output_types=self.tf_dtypes,
            output_shapes=self.tf_shapes)

        # dataset: slice the n-element window into n single windows
        dataset = dataset.flat_map(map_func=tf.data.Dataset.from_tensor_slices)
        return dataset

    def run_threads(self, *_args, **_kwargs):
        """
        This function is created for compatibility purposes
        (Deprecating); it accepts any arguments and does nothing.

        :param _args: ignored
        :param _kwargs: ignored
        :return:
        """

    def close_all(self):
        """
        For compatibility with the queue-based sampler.

        Stops the generator enqueuer, if one was started.
        """
        if self._enqueuer is not None:
            self._enqueuer.stop()
            # drop the handle so that repeated calls are no-ops
            self._enqueuer = None

    @classmethod
    def dummy_coordinates(cls, image_id, image_sizes, n_samples):
        """
        This function returns a set of image window coordinates
        which are just spatially from 0 to `image_sizes`.

        Each row is ``[image_id, 0, 0, 0] + image_sizes[:N_SPATIAL]``.

        :param image_id: id of the image, stored as the first element
        :param image_sizes: shape of the image; only the first
            ``N_SPATIAL`` elements are used
        :param n_samples: number of (identical) rows to generate
        :return: a numpy array of `n_samples` spatial coordinates
        """
        starting_coordinates = [0, 0, 0]
        image_spatial_shape = list(image_sizes[:N_SPATIAL])
        coords = [image_id] + starting_coordinates + image_spatial_shape
        coords = np.tile(np.asarray(coords), [n_samples, 1])
        return coords.astype(BUFFER_POSITION_NP_TYPE)