Skip to content

Commit

Permalink
Merge branch 'cdf' into irg
Browse files Browse the repository at this point in the history
  • Loading branch information
triomino committed Jul 4, 2020
2 parents 17f66dc + effaf3c commit a51d792
Show file tree
Hide file tree
Showing 16 changed files with 798 additions and 566 deletions.
36 changes: 10 additions & 26 deletions dataset/cifar100.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import print_function

import os
import socket
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Expand All @@ -20,15 +19,9 @@

def get_data_folder():
"""
return server-dependent path to store the data
return the path to store the data
"""
hostname = socket.gethostname()
if hostname.startswith('visiongpu'):
data_folder = '/data/vision/phillipi/rep-learn/datasets'
elif hostname.startswith('yonglong-home'):
data_folder = '/home/yonglong/Data/data'
else:
data_folder = './data/'
data_folder = './data/'

if not os.path.isdir(data_folder):
os.makedirs(data_folder)
Expand Down Expand Up @@ -60,10 +53,8 @@ class CIFAR100Instance(CIFAR100BackCompat):
"""CIFAR100Instance Dataset.
"""
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]

img, target = self.data[index], self.targets[index]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
Expand Down Expand Up @@ -140,13 +131,9 @@ def __init__(self, root, train=True,
self.is_sample = is_sample

num_classes = 100
if self.train:
num_samples = len(self.train_data)
label = self.train_labels
else:
num_samples = len(self.test_data)
label = self.test_labels

num_samples = len(self.data)
label = self.targets

self.cls_positive = [[] for i in range(num_classes)]
for i in range(num_samples):
self.cls_positive[label[i]].append(i)
Expand All @@ -170,11 +157,9 @@ def __init__(self, root, train=True,
self.cls_negative = np.asarray(self.cls_negative)

def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]


img, target = self.data[index], self.targets[index]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
Expand Down Expand Up @@ -202,7 +187,6 @@ def __getitem__(self, index):
sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
return img, target, index, sample_idx


def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact',
is_sample=True, percent=1.0):
"""
Expand Down
11 changes: 2 additions & 9 deletions dataset/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import print_function

import os
import socket
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets
Expand All @@ -13,15 +12,9 @@

def get_data_folder():
"""
return server-dependent path to store the data
return the path to store the data
"""
hostname = socket.gethostname()
if hostname.startswith('visiongpu'):
data_folder = '/data/vision/phillipi/rep-learn/datasets/imagenet'
elif hostname.startswith('yonglong-home'):
data_folder = '/home/yonglong/Data/data/imagenet'
else:
data_folder = './data/imagenet'
data_folder = './data/imagenet'

if not os.path.isdir(data_folder):
os.makedirs(data_folder)
Expand Down
35 changes: 35 additions & 0 deletions distiller_zoo/AA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class AAKDLoss(nn.Module):
    """Weighted pairwise MSE loss between student and teacher values.

    For every (student ``i``, teacher ``j``) pair, the element-wise squared
    error between ``s_value[i][j]`` and ``f_target[i][j]`` is averaged per
    sample, multiplied by ``weight[:, i, j]`` and summed. The total is then
    divided by ``bsz * num_stu``.
    """

    def __init__(self):
        super(AAKDLoss, self).__init__()
        # reduction='none' keeps the element-wise errors so they can be
        # averaged per sample before the pair weights are applied.
        self.crit = nn.MSELoss(reduction='none')

    def forward(self, s_value, f_target, weight):
        """Compute the weighted loss.

        Args:
            s_value: nested indexable; ``s_value[i][j]`` is the student
                tensor for pair ``(i, j)``. It must be reshapeable to
                ``(bsz, -1)``.
            f_target: same layout as ``s_value``, holding the targets.
            weight: tensor of shape ``(bsz, num_stu, num_tea)`` with the
                per-sample weight of every pair.

        Returns:
            Scalar tensor: ``sum(weight * per_sample_mse) / (bsz * num_stu)``.
        """
        bsz, num_stu, num_tea = weight.shape
        # Allocate on the same device as ``weight`` rather than hard-coding
        # CUDA, so the loss also works on CPU and on non-default GPUs.
        ind_loss = torch.zeros(bsz, num_stu, num_tea, device=weight.device)
        for i in range(num_stu):
            for j in range(num_tea):
                # Per-sample mean of the element-wise squared error.
                pair_err = self.crit(s_value[i][j], f_target[i][j])
                ind_loss[:, i, j] = pair_err.reshape(bsz, -1).mean(-1)

        return (weight * ind_loss).sum() / (1.0 * bsz * num_stu)
2 changes: 2 additions & 0 deletions distiller_zoo/AT.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, p=2):
self.p = p

def forward(self, g_s, g_t):
    """Return one attention loss per paired student/teacher feature.

    ``zip`` stops at the shorter input, so only the first
    ``min(len(g_s), len(g_t))`` pairs contribute.
    """
    losses = []
    for student_feat, teacher_feat in zip(g_s, g_t):
        losses.append(self.at_loss(student_feat, teacher_feat))
    return losses

def at_loss(self, f_s, f_t):
Expand All @@ -26,4 +27,5 @@ def at_loss(self, f_s, f_t):
return (self.at(f_s) - self.at(f_t)).pow(2).mean()

def at(self, f):
    """Build the normalized attention vector for feature ``f``.

    Raises ``f`` element-wise to the power ``self.p``, averages over dim 1
    (the channel axis when ``f`` is BxCxHxW, as the original comment
    assumes), flattens everything after the batch dimension and applies
    ``F.normalize`` with its default arguments.
    """
    batch = f.size(0)
    channel_mean = f.pow(self.p).mean(1)
    flattened = channel_mean.view(batch, -1)
    return F.normalize(flattened)
1 change: 1 addition & 0 deletions distiller_zoo/RKD.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(self, f_s, f_t):

# RKD Angle loss
with torch.no_grad():
# teacher is B x dim: (1 x B x dim) - (B x 1 x dim) broadcasts to B x B x dim
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, dim=2)
t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
Expand Down
4 changes: 2 additions & 2 deletions distiller_zoo/SP.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def similarity_loss(self, f_s, f_t):

G_s = torch.mm(f_s, torch.t(f_s))
# G_s = G_s / G_s.norm(2)
G_s = torch.nn.functional.normalize(G_s)
G_s = torch.nn.functional.normalize(G_s, dim=1)
G_t = torch.mm(f_t, torch.t(f_t))
# G_t = G_t / G_t.norm(2)
G_t = torch.nn.functional.normalize(G_t)
G_t = torch.nn.functional.normalize(G_t, dim=1)

G_diff = G_t - G_s
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
Expand Down
6 changes: 1 addition & 5 deletions distiller_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from .AB import ABLoss
from .AT import Attention
from .CC import Correlation
from .FitNet import HintLoss
from .FSP import FSP
from .FT import FactorTransfer
from .KD import DistillKL
from .KDSVD import KDSVD
from .NST import NSTLoss
from .PKT import PKT
from .RKD import RKDLoss
from .SP import Similarity
from .VID import VIDLoss
from .IRG import IRGLoss
from .AA import AAKDLoss
Loading

0 comments on commit a51d792

Please sign in to comment.