This repository has been archived by the owner on Jun 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 129
/
MetaEmbeddingClassifier.py
executable file
·87 lines (67 loc) · 3.17 KB
/
MetaEmbeddingClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
Portions of the source code are from the OLTR project which
notice below and in LICENSE in the root directory of
this source tree.
Copyright (c) 2019, Zhongqi Miao
All rights reserved.
"""
import torch
import torch.nn as nn
from models.CosNormClassifier import CosNorm_Classifier
from utils import *
from os import path
import pdb
class MetaEmbedding_Classifier(nn.Module):
    """Meta-embedding classifier (OLTR, "Large-Scale Long-Tailed Recognition
    in an Open World").

    Combines each sample's direct feature with a "memory feature" hallucinated
    from the class centroids, scales the mixture by a reachability factor
    (inverse distance to the nearest centroid), and classifies the result with
    a cosine-norm classifier.
    """

    def __init__(self, feat_dim=2048, num_classes=1000):
        """Build the classifier.

        Args:
            feat_dim: dimensionality of the input feature vectors.
            num_classes: number of target classes (= number of centroids).
        """
        super(MetaEmbedding_Classifier, self).__init__()
        self.num_classes = num_classes
        # Produces per-class coefficients used to query the visual memory.
        self.fc_hallucinator = nn.Linear(feat_dim, num_classes)
        # Produces a per-dimension gate controlling how much memory feature is mixed in.
        self.fc_selector = nn.Linear(feat_dim, feat_dim)
        self.cosnorm_classifier = CosNorm_Classifier(feat_dim, num_classes)

    def forward(self, x, centroids, *args):
        """Compute logits for a batch of features.

        Args:
            x: (batch, feat_dim) direct feature vectors.
            centroids: (num_classes, feat_dim) class centroids acting as
                the visual memory.
            *args: ignored; kept for interface compatibility with the other
                classifier modules in this project.

        Returns:
            Tuple of (logits, [direct_feature, infused_feature]).
        """
        # Keep the untouched input around: it is both returned and mixed back in below.
        direct_feature = x
        batch_size = x.size(0)
        feat_size = x.size(1)

        # The visual-memory keys are simply the class centroids.
        keys_memory = centroids

        # --- Reachability: scale / distance-to-nearest-centroid ---
        # Pairwise Euclidean distances between each sample and every centroid.
        x_expand = x.unsqueeze(1).expand(-1, self.num_classes, -1)
        centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1)
        dist_cur = torch.norm(x_expand - centroids_expand, 2, 2)
        # Only the nearest distance is needed; min() replaces the original's
        # full torch.sort (same value as sorted[:, 0], O(n) instead of O(n log n)).
        nearest_dist, _ = dist_cur.min(dim=1)
        scale = 10.0
        # NOTE(review): divides by the nearest distance — a feature exactly
        # equal to a centroid would yield inf; apparently never hit in training.
        reachability = (scale / nearest_dist).unsqueeze(1).expand(-1, feat_size)

        # --- Memory feature: soft attention over centroids ---
        values_memory = self.fc_hallucinator(x).softmax(dim=1)
        memory_feature = torch.matmul(values_memory, keys_memory)

        # --- Concept selector: per-dimension gate in (-1, 1) ---
        concept_selector = self.fc_selector(x).tanh()
        # Gated memory contribution, also returned for the auxiliary loss.
        infused_feature = concept_selector * memory_feature

        x = reachability * (direct_feature + infused_feature)
        logits = self.cosnorm_classifier(x)
        return logits, [direct_feature, infused_feature]
def create_model(feat_dim=2048, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args):
    """Build a MetaEmbedding_Classifier, optionally warm-starting the
    hallucinator branch from stage-1 classifier weights.

    Args:
        feat_dim: dimensionality of the input feature vectors.
        num_classes: number of target classes.
        stage1_weights: when True (and not test), load stage-1 classifier
            weights into ``fc_hallucinator``.
        dataset: dataset name; required when stage1_weights is True, and used
            to locate the default weight directory when log_dir is not given.
        log_dir: current log directory; when given, weights are looked up in
            the sibling 'stage1' directory.
        test: when True, skip all weight loading.
        *args: ignored; kept for interface compatibility.

    Returns:
        The (possibly warm-started) MetaEmbedding_Classifier instance.
    """
    print('Loading Meta Embedding Classifier.')
    clf = MetaEmbedding_Classifier(feat_dim, num_classes)
    if not test:
        if stage1_weights:
            assert dataset, 'dataset must be specified to locate stage-1 weights'
            print('Loading %s Stage 1 Classifier Weights.' % dataset)
            if log_dir is not None:
                # Stage-1 weights live in a 'stage1' directory that is a
                # sibling of log_dir (dirname replaces the original
                # hand-rolled "/".join(log_dir.split("/")[:-1])).
                weight_dir = path.join(path.dirname(log_dir), 'stage1')
            else:
                weight_dir = './logs/%s/stage1' % dataset
            print('==> Loading weights from %s' % weight_dir)
            clf.fc_hallucinator = init_weights(model=clf.fc_hallucinator,
                                              weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'),
                                              classifier=True)
        else:
            print('Random initialized classifier weights.')
    return clf