This repository has been archived by the owner on Jun 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 130
/
CosNormClassifier.py
44 lines (35 loc) · 1.41 KB
/
CosNormClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
Portions of the source code are from the OLTR project which
notice below and in LICENSE in the root directory of
this source tree.
Copyright (c) 2019, Zhongqi Miao
All rights reserved.
"""
import torch
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import pdb
class CosNorm_Classifier(nn.Module):
    """Cosine-normalised linear classifier head.

    Each input row ``x`` is "squashed" to ``x / (1 + ||x||)``. The result points
    in the same direction as ``x`` and has length ``||x|| / (1 + ||x||) < 1``.
    Each weight row is L2-normalised. The logits are ``scale`` times the dot
    products between the squashed inputs and the normalised weight rows, so
    every logit has magnitude strictly below ``|scale|``.

    Args:
        in_dims: Feature dimension of the input (columns of ``weight``).
        out_dims: Number of output classes (rows of ``weight``).
        scale: Multiplier applied to the squashed input before the product.
        margin: Stored on the module as ``self.margin``; not used by ``forward``.
        init_std: Accepted for interface compatibility but currently unused;
            ``reset_parameters`` uses ``1 / sqrt(in_dims)`` as the bound.
    """

    def __init__(self, in_dims, out_dims, scale=16, margin=0.5, init_std=0.001):
        super().__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.scale = scale
        self.margin = margin
        # A hard-coded ``.cuda()`` crashes on CPU-only hosts. Keep GPU placement
        # when CUDA exists, so existing callers see the same device, and fall
        # back to CPU otherwise. ``Module.to()`` / ``.cuda()`` still move the
        # parameter later as usual.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.weight = Parameter(torch.empty(out_dims, in_dims, device=device))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` from U(-1/sqrt(in_dims), 1/sqrt(in_dims))."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        # nn.init.* runs under no_grad, replacing the legacy ``.data`` mutation.
        nn.init.uniform_(self.weight, -stdv, stdv)

    def forward(self, input, *args):
        """Compute cosine-style logits for a batch of feature vectors.

        Args:
            input: 2-D tensor of shape ``(N, in_dims)``; ``torch.mm`` requires
                2-D operands.
            *args: Ignored; accepted so callers may pass extra arguments.

        Returns:
            Tensor of shape ``(N, out_dims)``.
        """
        # Norm of each input row along dim 1. No clone is needed because
        # torch.norm does not modify its argument.
        norm_x = torch.norm(input, 2, 1, keepdim=True)
        # Algebraically equal to (norm_x / (1 + norm_x)) * (input / norm_x),
        # but never divides by norm_x. An all-zero row therefore gives zero
        # logits instead of NaN (0 / 0).
        ex = input / (1 + norm_x)
        ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True)
        return torch.mm(self.scale * ex, ew.t())
def create_model(in_dims=512, out_dims=1000):
    """Factory for the cosine-normalised classifier head.

    Prints a short loading message, then builds the classifier with its
    default ``scale``/``margin`` settings.

    Args:
        in_dims: Input feature dimension passed to ``CosNorm_Classifier``.
        out_dims: Number of output classes passed to ``CosNorm_Classifier``.

    Returns:
        A freshly constructed ``CosNorm_Classifier`` instance.
    """
    print('Loading Cosine Norm Classifier.')
    head_kwargs = dict(in_dims=in_dims, out_dims=out_dims)
    classifier = CosNorm_Classifier(**head_kwargs)
    return classifier