-
Notifications
You must be signed in to change notification settings - Fork 2
/
tort_grid.py
60 lines (50 loc) · 2.02 KB
/
tort_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import matplotlib.pyplot as plt
# import SimpleITK as sitk
import numpy as np
import PIL.Image as Image
def grid2contour(grid, title, levels=50):
    """Draw a 2-D sampling grid as two families of black contour lines.

    The contours of channel 0 and channel 1 of ``grid`` are drawn over a
    regular [-1, 1) x [-1, 1) mesh in a new figure, which is then shown.

    Args:
        grid: numpy ndarray of shape (h, w, 2), values expected roughly in
            (-1, 1) (normalized coordinates).
        title: title of the figure.
        levels: number of contour levels per channel, forwarded to
            ``plt.contour``; a larger value draws a denser grid.
            Defaults to 50, the value this function always used.
    """
    assert grid.ndim == 3
    h, w = grid.shape[:2]
    # linspace(..., endpoint=False) yields exactly w (resp. h) samples with
    # step 2/w; np.arange with a float step can produce one extra sample
    # through rounding, which would no longer match grid's shape.
    x = np.linspace(-1, 1, w, endpoint=False)
    y = np.linspace(-1, 1, h, endpoint=False)
    X, Y = np.meshgrid(x, y)
    # +2 shifts the (-1, 1) values to positive ones so that monochrome
    # contours are not drawn with the dashed "negative" line style.
    # Both channels are flipped vertically so the two line families stay
    # aligned with each other (flipping only one mirrors it against the other).
    Z1 = (grid[:, :, 0] + 2)[::-1]
    Z2 = (grid[:, :, 1] + 2)[::-1]
    plt.figure()
    # linestyles='solid' keeps every line solid even if a value falls below 0
    # despite the +2 offset (large displacements).
    plt.contour(X, Y, Z1, levels=levels, colors='k', linestyles='solid')
    plt.contour(X, Y, Z2, levels=levels, colors='k', linestyles='solid')
    # Remove the x and y tick marks.
    plt.xticks(())
    plt.yticks(())
    plt.title(title)
    plt.show()
def show_grid(img_arr):
    """Plot a regular grid, a randomly jittered grid and a grid deformed by img_arr.

    Three figures are produced via ``grid2contour``: "regular_grid",
    "sampling_grid" (regular grid plus uniform random jitter of less than one
    pixel) and "img_grid" (regular grid plus the given displacement field).

    Args:
        img_arr: displacement field of shape (h, w, 2). Channel 0 is scaled
            by 2 / w and channel 1 by 2 / h, i.e. it is presumably given in
            pixels (channel 0 along the width, channel 1 along the height)
            and is converted here to the normalized [-1, 1) coordinate span.
            NOTE(review): if a field is stored channel-first as (2, h, w),
            transpose it with ``np.transpose(img_arr, [1, 2, 0])`` first.
            ``img_arr`` is not modified.

    Side effects: prints ``img_arr.shape`` and shows three figures.
    """
    img_shape = img_arr.shape
    print(img_shape)
    h, w = img_shape[:2]

    # Regular mesh with exactly w x h samples (linspace avoids the float-step
    # rounding of np.arange, which can produce one sample too many).
    x = np.linspace(-1, 1, w, endpoint=False)
    y = np.linspace(-1, 1, h, endpoint=False)
    X, Y = np.meshgrid(x, y)
    regular_grid = np.stack((X, Y), axis=2)
    grid2contour(regular_grid, "regular_grid")

    # One pixel spans 2/w (resp. 2/h) normalized units; dividing by
    # [w, h] broadcasts over the last (channel) axis.
    pixel_span = np.array([w, h], dtype=float)

    # Uniform jitter in [0, 1) pixel, converted to normalized units.
    rand_field = np.random.rand(h, w, 2)
    sampling_grid = regular_grid + rand_field * 2 / pixel_span
    grid2contour(sampling_grid, "sampling_grid")

    # Scale out of place: writing into img_arr would corrupt the caller's
    # array (and truncate the result if it has an integer dtype).
    img_grid = regular_grid + img_arr * 2 / pixel_span
    grid2contour(img_grid, "img_grid")
if __name__ == "__main__":
    import sys

    # Path of a .npy file holding a stack of displacement fields; each
    # element along axis 0 is passed to show_grid, so the file is presumably
    # shaped (N, h, w, 2) — confirm against the producer of flow.npy.
    # An optional first command-line argument overrides the original,
    # machine-specific default path.
    default_path = '/Users/huangwenbin/Desktop/SAR-voxelmorph/data/test3/flow.npy'
    flow_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    imgs = np.load(flow_path)
    for img_arr in imgs:
        show_grid(img_arr)
    print("end")