-
Notifications
You must be signed in to change notification settings - Fork 3
/
input_data.py
155 lines (123 loc) · 5.7 KB
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import cv2
import random
import numpy as np
### this InputData file is adapted from the CVM-Net project ###
class InputData:
    """Batch loader for paired satellite / ground-level images.

    Reads two split CSV files under ``img_root``: ``splits/train-19zl.csv`` for
    training and ``splits/val-19zl.csv`` for evaluation.  In every row the first
    column is a satellite image path and the second a ground image path, both
    relative to ``img_root``.  Images are read with ``cv2.imread`` (BGR channel
    order) and returned as float32 arrays with a fixed per-channel mean
    subtracted.

    Output shapes:
      * satellite batch: ``(n, 512, 512, 3)``
      * ground batch:    ``(n, 224, 1232, 3)``
    """

    img_root = 'dataset/'

    def __init__(self):
        """Load the train and validation split lists (file names only, no pixels)."""
        self.train_list = self.img_root + 'splits/train-19zl.csv'
        self.test_list = self.img_root + 'splits/val-19zl.csv'

        print('InputData::__init__: load %s' % self.train_list)
        self.__cur_id = 0  # cursor into the shuffled training order
        # each entry: [satellite filename, ground filename, pano_id]
        self.id_list = self._read_split(self.train_list)
        self.id_idx_list = list(range(len(self.id_list)))
        self.data_size = len(self.id_list)
        print('InputData::__init__: load', self.train_list, ' data_size =', self.data_size)

        print('InputData::__init__: load %s' % self.test_list)
        self.__cur_test_id = 0  # cursor for sequential test scanning
        self.id_test_list = self._read_split(self.test_list)
        self.id_test_idx_list = list(range(len(self.id_test_list)))
        self.test_data_size = len(self.id_test_list)
        print('InputData::__init__: load', self.test_list, ' data_size =', self.test_data_size)

    @staticmethod
    def _read_split(csv_path):
        """Parse a split CSV into ``[[sat_path, grd_path, pano_id], ...]``.

        ``pano_id`` is the satellite file's base name without extension.
        Surrounding whitespace (including the newline) is stripped so that the
        ground path is clean even when it is the last column, and blank lines
        are skipped.
        """
        records = []
        with open(csv_path, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                data = line.split(',')
                pano_id = data[0].split('/')[-1].split('.')[0]
                records.append([data[0], data[1], pano_id])
        return records

    @staticmethod
    def _subtract_mean(img):
        """Return a float32 copy of a BGR image with the per-channel mean removed.

        NOTE(review): these constants resemble the usual VGG/ImageNet BGR means
        (the red value is 123.6 here rather than the common 123.68) - kept as-is.
        """
        img = img.astype(np.float32)
        img[:, :, 0] -= 103.939  # Blue
        img[:, :, 1] -= 116.779  # Green
        img[:, :, 2] -= 123.6    # Red
        return img

    def next_batch_scan(self, batch_size):
        """Return the next sequential test batch ``(batch_sat, batch_grd)``.

        The final batch is shortened to the remaining images.  Once every test
        image has been returned, the cursor rewinds and ``(None, None)`` is
        returned.  Ground images are not resized, so they must already be
        224x1232.
        """
        if self.__cur_test_id >= self.test_data_size:
            self.__cur_test_id = 0
            return None, None
        batch_size = min(batch_size, self.test_data_size - self.__cur_test_id)

        batch_grd = np.zeros([batch_size, 224, 1232, 3], dtype=np.float32)
        batch_sat = np.zeros([batch_size, 512, 512, 3], dtype=np.float32)
        for i in range(batch_size):
            sat_rel, grd_rel, _ = self.id_test_list[self.__cur_test_id + i]
            # satellite
            img = cv2.imread(self.img_root + sat_rel)
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_AREA)
            batch_sat[i] = self._subtract_mean(img)
            # ground
            batch_grd[i] = self._subtract_mean(cv2.imread(self.img_root + grd_rel))
        self.__cur_test_id += batch_size
        return batch_sat, batch_grd

    def next_pair_batch(self, batch_size):
        """Return the next augmented training batch ``(batch_sat, batch_grd)``.

        Training order is reshuffled whenever the cursor is at 0.  When fewer
        than ``batch_size + 2`` entries remain, the cursor rewinds and
        ``(None, None)`` is returned (end of epoch).

        Pairs whose satellite image is unreadable or non-square, or whose ground
        image is unreadable or not 224x1232, are reported and skipped.  If the
        entries run out before the batch is full, the remaining slots stay
        zero-filled.
        """
        if self.__cur_id == 0:
            # A single shuffle already yields a uniformly random permutation.
            random.shuffle(self.id_idx_list)
        if self.__cur_id + batch_size + 2 >= self.data_size:
            self.__cur_id = 0
            return None, None

        batch_sat = np.zeros([batch_size, 512, 512, 3], dtype=np.float32)
        batch_grd = np.zeros([batch_size, 224, 1232, 3], dtype=np.float32)
        i = 0          # entries consumed from the shuffled order (incl. skipped)
        batch_idx = 0  # output slots filled so far
        while batch_idx < batch_size and self.__cur_id + i < self.data_size:
            img_idx = self.id_idx_list[self.__cur_id + i]
            i += 1
            sat_path = self.img_root + self.id_list[img_idx][0]
            grd_path = self.img_root + self.id_list[img_idx][1]

            # satellite
            img = cv2.imread(sat_path)
            if img is None or img.shape[0] != img.shape[1]:
                # cv2.imread returns None on failure, so guard the .shape access.
                print('InputData::next_pair_batch: read fail: %s, %d, ' % (sat_path, i),
                      None if img is None else img.shape)
                continue
            # ground (validated before anything is written, so a bad ground
            # image cannot leave a half-filled slot behind)
            grd = cv2.imread(grd_path)
            if grd is None or grd.shape[0] != 224 or grd.shape[1] != 1232:
                print('InputData::next_pair_batch: read fail: %s, %d, ' % (grd_path, i),
                      None if grd is None else grd.shape)
                continue

            # Random zoom: with probability 236/748 take a square crop of side
            # 513..748 px.  The offset is computed from a hard-coded 750, so the
            # crop is only exactly centred on 750x750 tiles.
            rand_crop = random.randint(1, 748)
            if rand_crop > 512:
                start = int((750 - rand_crop) / 2)
                img = img[start:start + rand_crop, start:start + rand_crop, :]
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_AREA)
            # Random rotation by 0/90/180/270 degrees, each equally likely
            # (randint is inclusive, so the upper bound is 3).  np.rot90 is an
            # exact pixel permutation, unlike an affine warp about (256, 256).
            img = np.rot90(img, k=random.randint(0, 3))

            batch_sat[batch_idx] = self._subtract_mean(img)
            batch_grd[batch_idx] = self._subtract_mean(grd)
            batch_idx += 1

        self.__cur_id += i
        return batch_sat, batch_grd

    def get_dataset_size(self):
        """Number of training pairs."""
        return self.data_size

    def get_test_dataset_size(self):
        """Number of test pairs."""
        return self.test_data_size

    def reset_scan(self):
        """Rewind the sequential test scan to the first test image."""
        self.__cur_test_id = 0