-
Notifications
You must be signed in to change notification settings - Fork 6
/
calc_cm.py
72 lines (53 loc) · 2.62 KB
/
calc_cm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
def main():
    """Parse the command-line options and run the covariance calculation.

    Options (all optional, single-dash style):
        -data_path    Path to the image dataset (default: empty string).
        -batch_size   Number of images per loader batch (default: 10).
        -use_rgb      Flag; when absent, channels are reordered to BGR.
        -model_file   Model checkpoint to attach the matrix to (default: empty string).
        -output_file  Where to save the updated checkpoint (default: empty string).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-data_path", type=str, default='', help="Path to your dataset")
    cli.add_argument("-batch_size", type=int, default=10)
    cli.add_argument("-use_rgb", action='store_true')
    cli.add_argument("-model_file", type=str, default='')
    cli.add_argument("-output_file", type=str, default='')
    main_calc(cli.parse_args())
def main_calc(params):
    """Compute the dataset's color correlation matrix and optionally save it.

    Every image under ``params.data_path`` (loaded with ``ImageFolder``) is
    resized to 256, center-cropped to 224 and converted to a tensor; unless
    ``params.use_rgb`` is set, the channel order is flipped to BGR.  The
    per-image channel covariances from ``rgb_cov`` are averaged over the
    dataset, decomposed with SVD, and ``U @ diag(sqrt(S + epsilon))`` is
    printed both as a tensor and as a ``-color_decorrelation`` option string.

    If ``params.model_file`` is non-empty, that checkpoint is loaded, the
    matrix is stored under ``'color_correlation_svd_sqrt'`` and the result is
    saved to ``params.output_file``.  When ``params.output_file`` is empty it
    is set to ``params.model_file``, so the model file is overwritten.

    Args:
        params: Namespace with ``data_path``, ``batch_size``, ``use_rgb``,
            ``model_file`` and ``output_file`` attributes.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if not params.use_rgb:
        # Reverse the channel axis (index 0 of the CHW tensor): RGB -> BGR.
        to_bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
        steps.append(to_bgr)

    image_folder = torchvision.datasets.ImageFolder(
        params.data_path,
        transform=transforms.Compose(steps),
    )
    loader = torch.utils.data.DataLoader(
        image_folder,
        batch_size=params.batch_size,
        num_workers=0,
        shuffle=False,
    )

    print('Computing dataset covariance matrix (this may take a while)')
    cov_mtx = 0
    for batch, _ in loader:
        for chw_image in batch:
            # rgb_cov expects channels last, so move CHW -> HWC.
            cov_mtx += rgb_cov(chw_image.permute(1,2,0))
    # Average of the per-image covariance matrices over the whole dataset.
    cov_mtx = cov_mtx / len(loader.dataset)

    U, S, _ = torch.svd(cov_mtx)
    # Small offset added to the singular values before taking the square root.
    epsilon = 1e-10
    svd_sqrt = U @ torch.diag(torch.sqrt(S + epsilon))

    print('Color correlation matrix\n')
    print(svd_sqrt)

    # Flatten the 3x3 matrix row by row, rounding each entry to 4 decimals.
    entries = [
        str(round(svd_sqrt[row][col].item(), 4))
        for row in range(3)
        for col in range(3)
    ]
    print_string = "color_decorrelation " + '"' + ','.join(entries) + '"'
    # Minus signs are written as 'n' in the printed option string
    # (presumably so negative values are not mistaken for option flags).
    print("\n-" + print_string.replace('-', 'n'))

    if params.model_file == '':
        return

    # NOTE: torch.load may unpickle arbitrary objects; only use trusted files.
    checkpoint = torch.load(params.model_file, map_location='cpu')
    checkpoint['color_correlation_svd_sqrt'] = svd_sqrt
    if params.output_file == '':
        # No explicit output path: write back over the input model file.
        params.output_file = params.model_file
    print('\nSaving color correlation matrix to ' + params.output_file)
    torch.save(checkpoint, params.output_file)
def rgb_cov(im):
    """Return the 3x3 sample covariance matrix of an image's color channels.

    Args:
        im: A ``torch.Tensor`` of shape (H, W, 3) with a dtype that supports
            ``mean()`` (e.g. float32).  The tensor is left unmodified.

    Returns:
        A (3, 3) tensor holding the covariance between the three channels,
        computed over all H * W pixels and normalized by (H * W - 1).

    Raises:
        ZeroDivisionError: If the image contains exactly one pixel.
    """
    pixels = im.reshape(-1, 3)
    # Subtract out of place: reshape() may return a view that shares storage
    # with the caller's tensor, so an in-place `-=` would silently
    # mean-center the caller's image data as a side effect.
    centered = pixels - pixels.mean(0, keepdim=True)
    return 1 / (centered.shape[0] - 1) * centered.T @ centered
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()