-
Notifications
You must be signed in to change notification settings - Fork 9
/
slic_superpixels.py
384 lines (301 loc) · 13.9 KB
/
slic_superpixels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
"""
Modified from scikit-image slic method
Original code (C) scikit-image
Modification (C) Benjamin Irving
See licence.txt for more details
"""
# coding=utf-8
from __future__ import division, absolute_import, unicode_literals, print_function
import collections as coll
import numpy as np
from scipy import ndimage as ndi
from scipy.ndimage.morphology import distance_transform_edt
from scipy.ndimage.filters import gaussian_filter
from maskslic.processing import get_mpd
import warnings
import matplotlib.pyplot as plt
# from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float, regular_grid
from maskslic._slic import (_slic_cython,
_enforce_label_connectivity_cython)
from skimage.color import rgb2lab
def place_seed_points(image, img, mask, n_segments, spacing, q=99.99):
    """
    Place seed points inside an ROI by greedy farthest-point selection.

    The mask is cropped to its bounding box. Then, ``n_segments`` times, the
    Euclidean distance transform of the remaining mask is computed (and very
    lightly smoothed), the voxel with the largest distance value is taken as
    the next seed, and that voxel is cleared from the working mask so that
    the next distance transform is also measured from it.

    Note:
    The optimal point placement problem is somewhat related to the k-center
    problem / metric facility location (MFL) / maxmin facility location:
    https://en.wikipedia.org/wiki/Facility_location_problem

    Parameters
    ----------
    image : ndarray
        Feature image; only ``image.shape[3]`` is read, to size the (zeroed)
        feature columns of the returned seed array.
    img : object
        Unused. Kept for backward compatibility of the signature.
    mask : 3D array-like
        ROI; nonzero entries are inside the ROI. The caller's array is not
        modified.
    n_segments : int
        Number of seed points to place.
    spacing : sequence of float
        Voxel spacing, forwarded as ``sampling`` to the distance transform.
    q : float, optional
        Unused. Kept for backward compatibility of the signature.

    Returns
    -------
    segments : ndarray, shape (n_segments, 3 + image.shape[3])
        One row per seed: the three index coordinates (axis 0, axis 1,
        axis 2 of ``mask``) followed by zero-initialised feature columns.
    step_x, step_y, step_z : scalars
        Per-axis step values as returned by ``get_mpd`` for the seed
        coordinates (axis 1, axis 2 and axis 0 respectively).
        NOTE(review): ``get_mpd`` is defined elsewhere; presumably a
        distance between seeds -- confirm against maskslic.processing.

    Raises
    ------
    ValueError
        If ``mask`` contains no nonzero voxels.
    """
    mask = np.asarray(mask)
    nz = np.nonzero(mask)
    if nz[0].size == 0:
        # np.min on an empty index array would raise an opaque reduction
        # error; fail with a message that names the actual problem.
        raise ValueError("mask contains no nonzero voxels; "
                         "cannot place seed points")

    # Bounding box of the ROI (inclusive corners).
    lower = [np.min(nz[0]), np.min(nz[1]), np.min(nz[2])]
    upper = [np.max(nz[0]), np.max(nz[1]), np.max(nz[2])]

    # Work on a cropped private copy so the caller's mask is left untouched.
    remaining = np.copy(mask[lower[0]:upper[0] + 1,
                             lower[1]:upper[1] + 1,
                             lower[2]:upper[2] + 1])

    # NOTE: throughout this function "x" is axis 1 and "y" is axis 2 of the
    # mask; the names are historical.
    segments_z = np.zeros(n_segments, dtype=np.int64)
    segments_x = np.zeros(n_segments, dtype=np.int64)
    segments_y = np.zeros(n_segments, dtype=np.int64)

    # SEED STEP 1: n seeds are placed as far as possible from every other
    # seed and from the ROI edge.
    for ii in range(n_segments):
        dtrans = distance_transform_edt(remaining, sampling=spacing)
        dtrans = gaussian_filter(dtrans, sigma=0.1)

        # First maximum in C order (same voxel the original
        # ``nonzero(dtrans == max)`` lookup selected).
        iz, ix, iy = np.unravel_index(np.argmax(dtrans), dtrans.shape)
        segments_z[ii] = iz
        segments_x[ii] = ix
        segments_y[ii] = iy

        # Clearing the chosen voxel makes the next distance transform
        # treat this seed as background.
        remaining[iz, ix, iy] = False

    # Shift from bounding-box coordinates back to full-volume coordinates.
    segments_z = segments_z + lower[0]
    segments_x = segments_x + lower[1]
    segments_y = segments_y + lower[2]

    segments_color = np.zeros((segments_z.shape[0], image.shape[3]))
    segments = np.concatenate([segments_z[..., np.newaxis],
                               segments_x[..., np.newaxis],
                               segments_y[..., np.newaxis],
                               segments_color], axis=1)

    sz = np.ascontiguousarray(segments_z, dtype=np.int32)
    sx = np.ascontiguousarray(segments_x, dtype=np.int32)
    sy = np.ascontiguousarray(segments_y, dtype=np.int32)

    out1 = get_mpd(sz, sx, sy)
    step_z, step_x, step_y = out1[0], out1[1], out1[2]

    return segments, step_x, step_y, step_z
def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0,
         spacing=None, multichannel=True, convert2lab=None,
         enforce_connectivity=False, min_size_factor=0.5, max_size_factor=3,
         slic_zero=False, seed_type='grid', mask=None, recompute_seeds=False,
         plot_examples=False):
    """Segments image using k-means clustering in Color-(x,y,z) space.

    Parameters
    ----------
    image : 2D, 3D or 4D ndarray
        Input image, which can be 2D or 3D, and grayscale or multichannel
        (see `multichannel` parameter).
    n_segments : int, optional
        The (approximate) number of labels in the segmented output image.
    compactness : float, optional
        Balances color proximity and space proximity. Higher values give
        more weight to space proximity, making superpixel shapes more
        square/cubic. In SLICO mode, this is the initial compactness.
        This parameter depends strongly on image contrast and on the
        shapes of objects in the image. We recommend exploring possible
        values on a log scale, e.g., 0.01, 0.1, 1, 10, 100, before
        refining around a chosen value.
    max_iter : int, optional
        Maximum number of iterations of k-means.
    sigma : float or (3,) array-like of floats, optional
        Width of Gaussian smoothing kernel for pre-processing for each
        dimension of the image. The same sigma is applied to each dimension in
        case of a scalar value. Zero means no smoothing.
        Note, that `sigma` is automatically scaled if it is scalar and a
        manual voxel spacing is provided (see Notes section).
    spacing : (3,) array-like of floats, optional
        The voxel spacing along each image dimension. By default, `slic`
        assumes uniform spacing (same voxel resolution along z, y and x).
        This parameter controls the weights of the distances along z, y,
        and x during k-means clustering.
    multichannel : bool, optional
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension.
    convert2lab : bool, optional
        Whether the input should be converted to Lab colorspace prior to
        maskslic. The input image *must* be RGB. Highly recommended.
        This option defaults to ``True`` when ``multichannel=True`` *and*
        ``image.shape[-1] == 3``.
    enforce_connectivity: bool, optional
        Whether the generated segments are connected or not
    min_size_factor: float, optional
        Proportion of the minimum segment size to be removed with respect
        to the supposed segment size ```depth*width*height/n_segments```
        (or ``mask.sum()/n_segments`` when a mask is given).
    max_size_factor: float, optional
        Proportion of the maximum connected segment size. A value of 3 works
        in most of the cases.
    slic_zero: bool, optional
        Run SLIC-zero, the zero-parameter mode of SLIC. [2]_
        Not implemented here: passing ``True`` raises
        ``NotImplementedError``.
    seed_type : {'grid', 'nplace'}, optional
        How the initial cluster centres are chosen. ``'grid'`` places them
        on a regular grid and keeps only those falling inside the mask.
        ``'nplace'`` places ``n_segments`` seeds inside the mask with
        `place_seed_points`; it needs a mask and falls back to ``'grid'``
        (with a warning) when none is given.
    mask: ndarray of bools or 0s and 1s, optional
        Array of same shape as `image`. Supervoxel analysis will only be
        performed on points at which mask == True
    recompute_seeds : bool, optional
        If True, a preliminary clustering pass is run with
        ``only_dist=True`` before the main pass, to re-initialise the seeds.
    plot_examples : bool, optional
        If True, create two matplotlib figures showing the seed positions
        saved before the optional re-initialisation pass and the seed
        positions afterwards. ``plt.show()`` is not called here.

    Returns
    -------
    labels : 2D or 3D array
        Integer mask indicating segment labels.

    Raises
    ------
    ValueError
        If ``convert2lab`` is set to ``True`` but the last array
        dimension is not of length 3, or if ``seed_type`` is not one of
        the supported values.
    NotImplementedError
        If ``slic_zero`` is True.

    Notes
    -----
    * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
      maskslic.

    * If `sigma` is scalar and `spacing` is provided, the kernel width is
      divided along each dimension by the spacing. For example, if ``sigma=1``
      and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
      ensures sensible smoothing for anisotropic images.

    * The image is rescaled to be in [0, 1] prior to processing.

    * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
      interpret them as 3D with the last dimension having length 3, use
      `multichannel=False`.

    References
    ----------
    .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
        Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
        State-of-the-art Superpixel Methods, TPAMI, May 2012.
    .. [2] https://ivrg.epfl.ch/research/superpixels#SLICO

    Examples
    --------
    >>> from maskslic import slic
    >>> from skimage.data import astronaut
    >>> img = astronaut()
    >>> segments = slic(img, n_segments=100, compactness=10)

    Increasing the compactness parameter yields more square regions:

    >>> segments = slic(img, n_segments=100, compactness=20)

    """
    if slic_zero:
        raise NotImplementedError("Slic zero has not been implemented yet for maskSLIC.")

    # Untouched copies of the inputs, used for plotting and for the segment
    # size estimate further down.
    img = np.copy(image)
    msk = None if mask is None else np.asarray(mask) == 1

    if mask is None and seed_type == 'nplace':
        warnings.warn("'nplace' assignment of seed points should only be used "
                      "with an ROI. Changing seed type to 'grid'.")
        # Fall back to a seed type that the dispatch below actually handles.
        seed_type = 'grid'

    if seed_type == 'nplace' and not recompute_seeds:
        warnings.warn("Seeds should be recomputed when seed points are "
                      "assigned with 'nplace'")

    image = img_as_float(image)
    is_2d = False
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    # ``np.bool`` was removed from NumPy; the builtin ``bool`` is equivalent.
    if mask is None:
        mask = np.ones(image.shape[:3], dtype=bool)
    else:
        mask = np.asarray(mask, dtype=bool)

    if mask.ndim == 2:
        mask = mask[np.newaxis, ...]

    if spacing is None:
        spacing = np.ones(3)
    else:
        # Accept lists, tuples and arrays of any numeric dtype alike.
        spacing = np.ascontiguousarray(spacing, dtype=np.double)

    # ``collections.Iterable`` no longer exists on current Python versions;
    # ``np.iterable`` performs the scalar/sequence distinction instead.
    if not np.iterable(sigma):
        sigma = np.array([sigma, sigma, sigma], dtype=np.double)
        sigma /= spacing.astype(np.double)
    elif isinstance(sigma, (list, tuple)):
        sigma = np.array(sigma, dtype=np.double)
    if (sigma > 0).any():
        # add zero smoothing for multichannel dimension
        sigma = list(sigma) + [0]
        image = ndi.gaussian_filter(image, sigma)

    if multichannel and (convert2lab or convert2lab is None):
        if image.shape[-1] != 3 and convert2lab:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        elif image.shape[-1] == 3:
            image = rgb2lab(image)

    depth, height, width = image.shape[:3]

    if seed_type == 'nplace':
        segments, step_x, step_y, step_z = place_seed_points(
            image, img, mask, n_segments, spacing)

    elif seed_type == 'grid':
        # initialize cluster centroids for desired number of segments:
        # index grids along z, y and x ...
        grid_z, grid_y, grid_x = np.mgrid[:depth, :height, :width]

        # ... sub-sampled by one slice per axis. A tuple is required for
        # multi-axis indexing on current NumPy versions.
        slices = tuple(regular_grid(image.shape[:3], n_segments))
        step_z, step_y, step_x = [int(s.step) for s in slices]

        segments_z = grid_z[slices]
        segments_y = grid_y[slices]
        segments_x = grid_x[slices]

        # list of all locations as well as zeros for the color features
        segments_color = np.zeros(segments_z.shape + (image.shape[3],))
        segments = np.concatenate([segments_z[..., np.newaxis],
                                   segments_y[..., np.newaxis],
                                   segments_x[..., np.newaxis],
                                   segments_color],
                                  axis=-1).reshape(-1, 3 + image.shape[3])

        # Keep only the grid seeds that fall inside the mask (an all-True
        # mask was created above when the caller supplied none).
        seed_idx = segments[:, :3].astype('int')
        inside = mask[seed_idx[:, 0], seed_idx[:, 1], seed_idx[:, 2]]
        segments = segments[inside, :]

    else:
        raise ValueError('seed_type should be nplace or grid')

    segments = np.ascontiguousarray(segments)

    # we do the scaling of ratio in the same way as in the SLIC paper
    # so the values have the same meaning
    step = float(max((step_z, step_y, step_x)))
    ratio = 1.0 / compactness

    image = np.ascontiguousarray(image * ratio, dtype=np.double)
    mask = np.ascontiguousarray(mask, dtype=np.int32)

    # Snapshot of the initial seeds, only used for the "before" plot.
    segments_old = np.copy(segments)

    if recompute_seeds:
        # Seed step 2: Run SLIC to reinitialise seeds
        # Runs the supervoxel method but only uses distance to better
        # initialise the method. The returned labels are discarded.
        # NOTE(review): this relies on _slic_cython updating `segments`
        # in place -- confirm against the Cython source.
        labels = _slic_cython(image, mask, segments, step, max_iter, spacing,
                              slic_zero, only_dist=True)

    if plot_examples:
        # NOTE(review): `contours=1` is passed through unchanged from the
        # original code; it does not look like a documented
        # `plt.contour` keyword -- confirm the intended argument.
        plt.figure()
        plt.imshow(img)
        if msk is not None:
            plt.contour(msk, contours=1, colors='yellow', linewidths=1)
        plt.scatter(segments_old[:, 2], segments_old[:, 1], color='green')
        plt.axis('off')

        plt.figure()
        plt.imshow(img)
        if msk is not None:
            plt.contour(msk, contours=1, colors='yellow', linewidths=1)
        plt.scatter(segments[:, 2], segments[:, 1], color='green')
        plt.axis('off')

    labels = _slic_cython(image, mask, segments, step, max_iter, spacing,
                          slic_zero, only_dist=False)

    if enforce_connectivity:
        # Expected segment size: whole volume without a mask, otherwise
        # only the voxels inside the ROI.
        if msk is None:
            segment_size = depth * height * width / n_segments
        else:
            segment_size = msk.sum() / n_segments

        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(labels, mask, n_segments,
                                                    min_size, max_size)

    if is_2d:
        # Drop the depth axis that was added for 2D input.
        labels = labels[0]

    return labels