-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
aiden.d
committed
Dec 2, 2020
0 parents
commit 5e4ba46
Showing
20 changed files
with
2,743 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Multi-level Distance Regularization for Deep Metric Learning | ||
## Dependencies | ||
|
||
You need a `CUDA-enabled GPU` and `python` (>3.6) to run the source code. | ||
|
||
- torchvision >= 0.4.2 | ||
- torch >= 1.3.1 | ||
- tqdm | ||
- scipy | ||
- Pillow | ||
|
||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Preparing datasets | ||
### 1. Make `dataset` directory | ||
``` | ||
mkdir ./dataset | ||
``` | ||
### 2. (Optional) Only for In-Shop Clothes Retrieval | ||
The source code will automatically download CUB-200-2011, Cars-196, and Stanford Online Products datasets. | ||
|
||
|
||
However, you need to manually download the In-Shop Clothes Retrieval dataset. | ||
|
||
|
||
1. Make an `Inshop` directory inside the `./dataset` directory | ||
``` | ||
mkdir -p ./dataset/Inshop | ||
``` | ||
2. Download `img.zip` from the following link, and unzip it in the `Inshop` directory | ||
``` | ||
https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00 | ||
``` | ||
3. Download `list_eval_partition.txt` from the following link, and put it in the `Inshop` directory. | ||
``` | ||
https://drive.google.com/drive/folders/0B7EVK8r0v71pWVBJelFmMW5EWnM | ||
``` | ||
|
||
## Testing on the trained weights | ||
```bash | ||
# The models are trained with Triplet+MDR, please check Table 1. | ||
|
||
# CUB-200-2011 | ||
wget https://github.com/anonymous-ai-research/pretrained/raw/master/cub200/cub200.pth | ||
python run.py --mode eval --dataset cub200 --load cub200.pth | ||
|
||
# Cars-196 | ||
wget https://github.com/anonymous-ai-research/pretrained/raw/master/cars196/cars196.pth | ||
python run.py --mode eval --dataset cars196 --load cars196.pth | ||
|
||
# Stanford Online Products | ||
wget https://github.com/anonymous-ai-research/pretrained/raw/master/sop/sop.pth | ||
python run.py --mode eval --dataset stanford --load sop.pth | ||
|
||
# In-Shop Clothes Retrieval | ||
wget https://github.com/anonymous-ai-research/pretrained/raw/master/inshop/inshop.pth | ||
python run_inshop.py --mode eval --load inshop.pth | ||
``` | ||
|
||
## Training | ||
```bash | ||
# CUB-200-2011 | ||
# Triplet | ||
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8 | ||
# Triplet+L2Norm | ||
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8 --l2norm | ||
# Triplet+MDR | ||
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8 --lambda-mdr 0.6 --nu-mdr 0.01 | ||
# Cars-196 | ||
# Triplet | ||
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8 | ||
# Triplet+L2Norm | ||
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8 --l2norm | ||
# Triplet+MDR | ||
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8 --lambda-mdr 0.2 --nu-mdr 0.01 | ||
# Stanford Online Products | ||
# Triplet | ||
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000 | ||
# Triplet+L2Norm | ||
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000 --l2norm | ||
# Triplet+MDR | ||
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000 --lambda-mdr 0.1 --nu-mdr 0.01 | ||
# In-Shop Clothes Retrieval | ||
# Triplet | ||
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40 | ||
# Triplet+L2Norm | ||
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40 --l2norm | ||
# Triplet+MDR | ||
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40 --lambda-mdr 0.1 --nu-mdr 0.01 | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import random | ||
|
||
from torchvision.datasets import ImageFolder | ||
|
||
|
||
def index_dataset(dataset: ImageFolder):
    """Group the sample positions of ``dataset`` by class index.

    Args:
        dataset: Object exposing ``imgs``, a sequence of
            ``(item, class_index)`` pairs.

    Returns:
        A plain ``dict`` mapping each class index to the list of positions
        (indices into ``dataset.imgs``) whose entry carries that class.
        Positions keep their original order, and keys appear in order of
        first occurrence.
    """
    cls_to_ind = {}

    for position, (_, cls_ind) in enumerate(dataset.imgs):
        # setdefault creates the bucket on first sight of a class and
        # returns the existing one afterwards.
        cls_to_ind.setdefault(cls_ind, []).append(position)

    return cls_to_ind
|
||
|
||
class MImagesPerClassSampler:
    """Batch sampler yielding index lists with up to ``m`` images per class.

    Each batch is built by visiting the classes in a fresh random order and
    drawing, without replacement, at most ``m`` image indices from every
    visited class until ``batch_size`` indices have been collected.

    Args:
        data_source: Dataset exposing ``class_to_idx`` (its values are used
            as class identifiers) and ``imgs`` (consumed by
            ``index_dataset``).
        batch_size: Maximum number of indices per yielded batch.
        m: Maximum number of images drawn from a single class per batch.
        iter_per_epoch: Number of batches produced by one iteration pass.
    """

    def __init__(self, data_source: ImageFolder, batch_size, m=5, iter_per_epoch=100):
        self.m = m
        self.batch_size = batch_size
        self.n_batch = iter_per_epoch
        self.class_idx = list(data_source.class_to_idx.values())
        self.images_by_class = index_dataset(data_source)

    def __len__(self):
        """Return the number of batches produced per pass."""
        return self.n_batch

    def __iter__(self):
        """Yield ``n_batch`` lists of dataset indices."""
        for _ in range(self.n_batch):
            yield self._draw_batch()

    def _draw_batch(self):
        """Assemble one batch of indices, truncated to ``batch_size``."""
        # Sampling all classes without replacement is a random permutation.
        class_order = random.sample(self.class_idx, k=len(self.class_idx))
        batch = []

        for cls in class_order:
            pool = self.images_by_class[cls]
            # Classes with fewer than ``m`` images contribute all of them.
            batch.extend(random.sample(pool, k=min(self.m, len(pool))))

            if len(batch) >= self.batch_size:
                break

        # The last class may overshoot the target, so cut the surplus.
        # If every class is exhausted first, the batch is simply shorter.
        return batch[: self.batch_size]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from metric.utils import pdist | ||
|
||
|
||
class TripletLoss(nn.Module):
    """Triplet margin loss over triplets chosen by a pluggable sampler.

    Args:
        margin: Margin passed to ``F.triplet_margin_loss``.
        sampler: Callable invoked as ``sampler(embeddings, labels)`` that
            returns ``(anchor_idx, pos_idx, neg_idx)`` index tensors into
            ``embeddings``. Its ``dist_func`` attribute is overwritten here
            so that the sampler measures distances with the same norm as the
            loss.
        reduce: If ``False``, ``forward`` returns the per-triplet losses.
        size_average: When reducing, return the mean if ``True``, otherwise
            the sum.
        p: Norm degree used by the loss. ``p == 2`` makes the sampler's
            distance function call ``pdist`` with ``squared=True``.

    Raises:
        ValueError: If ``sampler`` is ``None``.
    """

    def __init__(self, margin=0.2, sampler=None, reduce=True, size_average=True, p=2):
        super().__init__()
        if sampler is None:
            # Without a sampler the module can neither be configured nor
            # run, so fail with a clear message instead of an AttributeError.
            raise ValueError("TripletLoss requires a triplet sampler")

        self.p = p
        self.margin = margin

        self.sampler = sampler
        # The previous lambda referenced an undefined name ``p`` and raised
        # NameError as soon as a sampler called ``dist_func``.
        self.sampler.dist_func = lambda e: pdist(e, squared=(self.p == 2))

        self.reduce = reduce
        self.size_average = size_average

    def forward(self, embeddings, labels):
        """Return the triplet loss for the triplets the sampler selects.

        Args:
            embeddings: Tensor indexed along its first dimension by the
                sampler's output.
            labels: Labels forwarded unchanged to the sampler.

        Returns:
            The per-triplet loss tensor if ``reduce`` is ``False``, else its
            mean (``size_average=True``) or sum.
        """
        anchor_idx, pos_idx, neg_idx = self.sampler(embeddings, labels)

        anchor_embed = embeddings[anchor_idx]
        positive_embed = embeddings[pos_idx]
        negative_embed = embeddings[neg_idx]

        loss = F.triplet_margin_loss(
            anchor_embed,
            positive_embed,
            negative_embed,
            margin=self.margin,
            p=self.p,
            reduction="none",
        )

        if not self.reduce:
            return loss

        if self.size_average:
            return loss.mean()
        return loss.sum()
|
||
|
||
class MDRLoss(nn.Module):
    """Multi-level distance regularization loss.

    Pairwise distances are normalized with running (momentum-averaged) mean
    and standard deviation, and each normalized distance is penalized by its
    absolute gap to the closest of the learnable ``levels``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        # Learnable level positions in normalized-distance space.
        self.levels = nn.Parameter(torch.tensor([-3.0, 0.0, 3.0]))
        # Weight of the previous running statistic in each update.
        self.momentum = 0.9

        self.register_buffer("momented_mean", torch.zeros(1))
        self.register_buffer("momented_std", torch.zeros(1))

        # Tracks whether momented_mean and momented_std have been seeded.
        # NOTE(review): this flag is a plain attribute, so it is not part of
        # state_dict; after loading a checkpoint the first forward pass
        # re-seeds the statistics. Confirm this is intended.
        self.init = False

    def initialize_statistics(self, mean, std):
        """Seed the running statistics and mark them as initialized.

        Args:
            mean: Single-element tensor holding the initial mean.
            std: Single-element tensor holding the initial standard
                deviation.
        """
        # Keep the registered (1,) shape. Rebinding the buffers to 0-dim
        # tensors made saved state_dicts mismatch a freshly built module.
        self.momented_mean = mean.detach().reshape(1)
        self.momented_std = std.detach().reshape(1)
        self.init = True

    def forward(self, embeddings):
        """Compute the regularization loss for a batch of embeddings.

        Args:
            embeddings: Batch passed to ``pdist``; presumably one row per
                sample, producing a square distance matrix (verify against
                ``metric.utils.pdist``).

        Returns:
            Scalar tensor: mean absolute gap between each normalized
            off-diagonal distance and its nearest level.
        """
        dist_mat = pdist(embeddings)
        off_diagonal = ~torch.eye(
            dist_mat.shape[0], dtype=torch.bool, device=dist_mat.device
        )
        pdist_mat = dist_mat[off_diagonal]

        # NOTE(review): the batch statistics are taken over the full matrix,
        # diagonal included, while the loss uses only off-diagonal entries.
        # Kept as-is to preserve the existing numerical behavior.
        flat_dist = dist_mat.view(-1)
        mean = flat_dist.mean().detach()
        std = flat_dist.std().detach()

        if not self.init:
            self.initialize_statistics(mean, std)
        else:
            # Exponential moving average; broadcasting against the (1,)
            # buffers keeps their registered shape.
            self.momented_mean = (
                1 - self.momentum
            ) * mean + self.momentum * self.momented_mean
            self.momented_std = (
                1 - self.momentum
            ) * std + self.momentum * self.momented_std

        normalized_dist = (pdist_mat - self.momented_mean) / self.momented_std
        # Rows are levels, columns are distances; take the nearest level.
        difference = (normalized_dist[None] - self.levels[:, None]).abs().min(dim=0)[0]
        loss = difference.mean()

        return loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from metric.utils import pdist | ||
|
||
BIG_NUMBER = 1e12 | ||
__all__ = [ | ||
"AllPairs", | ||
"HardNegative", | ||
"SemiHardNegative", | ||
"DistanceWeighted", | ||
"RandomNegative", | ||
] | ||
|
||
|
||
def pos_neg_mask(labels):
    """Build pairwise same-label and different-label masks.

    Args:
        labels: Tensor of labels; its first dimension sets the mask size.

    Returns:
        ``(pos_mask, neg_mask)``. Entry ``[i, j]`` of ``pos_mask`` is
        non-zero when ``i != j`` and the two labels are equal; entry
        ``[i, j]`` of ``neg_mask`` is non-zero when the labels differ. The
        diagonal of both masks is zero. Each mask is the product of a
        comparison result and a ``uint8`` matrix.
    """
    # Ones everywhere except the diagonal, so a sample never pairs with itself.
    off_diagonal = 1 - torch.eye(
        labels.size(0), dtype=torch.uint8, device=labels.device
    )
    as_row = labels[None]
    as_col = labels[:, None]

    pos_mask = (as_row == as_col) * off_diagonal
    neg_mask = (as_row != as_col) * off_diagonal

    return pos_mask, neg_mask
|
||
|
||
class _Sampler(nn.Module):
    """Base class for triplet samplers.

    Subclasses implement ``forward(embeddings, labels)`` and return
    ``(anchor_idx, pos_idx, neg_idx)`` index tensors.

    Args:
        dist_func: Callable applied to the embeddings by subclasses that
            need pairwise distances; defaults to ``pdist``.
    """

    def __init__(self, dist_func=pdist):
        # nn.Module must be initialized before attributes are attached.
        # Assigning first raises if ``dist_func`` is itself an nn.Module.
        super().__init__()
        self.dist_func = dist_func

    def forward(self, embeddings, labels):
        """Select triplets; must be overridden by subclasses."""
        raise NotImplementedError
|
||
|
||
class AllPairs(_Sampler):
    """Sampler that enumerates every (anchor, positive, negative) triplet."""

    def forward(self, embeddings, labels):
        """Return index tensors for all valid triplets in the batch.

        Each positive pair from ``pos_neg_mask`` is combined with every
        negative of its anchor. ``embeddings`` is not used.
        """
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)

            triplet_chunks = []
            for anchor_and_pos in pos_mask.nonzero():
                negatives = neg_mask[anchor_and_pos[0]].nonzero()
                # Repeat the (anchor, positive) pair once per negative and
                # append the negative index as a third column.
                repeated_pair = anchor_and_pos.unsqueeze(0).repeat(len(negatives), 1)
                triplet_chunks.append(torch.cat((repeated_pair, negatives), dim=1))

            triplets = torch.cat(triplet_chunks, dim=0)

        return triplets[:, 0], triplets[:, 1], triplets[:, 2]
|
||
|
||
class RandomNegative(_Sampler):
    """Sampler that pairs each positive pair with one random negative."""

    def forward(self, embeddings, labels):
        """Return ``(anchor_idx, pos_idx, neg_idx)`` with random negatives.

        For each positive pair, one negative is drawn with
        ``torch.multinomial`` using the anchor's row of the negative mask as
        weights. ``embeddings`` is not used.
        """
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)

            positive_pairs = pos_mask.nonzero()
            anchor_idx, pos_idx = positive_pairs[:, 0], positive_pairs[:, 1]

            # One row of 0/1 weights per positive pair.
            negative_weights = neg_mask.float()[anchor_idx]
            neg_index = torch.multinomial(negative_weights, 1).squeeze(1)

        return anchor_idx, pos_idx, neg_index
|
||
|
||
class HardNegative(_Sampler):
    """Sampler that picks, for each anchor, its closest negative."""

    def forward(self, embeddings, labels):
        """Return ``(anchor_idx, pos_idx, neg_idx)`` using hardest negatives.

        Distances come from ``self.dist_func(embeddings)``. For every
        positive pair the negative with the smallest strictly positive
        distance to the anchor is chosen.
        """
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)
            dist = self.dist_func(embeddings)

            positive_pairs = pos_mask.nonzero()
            anchor_idx, pos_idx = positive_pairs[:, 0], positive_pairs[:, 1]

            masked_dist = neg_mask.float() * dist
            # Non-negatives were zeroed by the mask. Push every non-positive
            # entry out of reach so argmin only sees real negatives.
            masked_dist[masked_dist <= 0] = BIG_NUMBER
            closest_negative = masked_dist.argmin(dim=1)
            neg_idx = closest_negative[anchor_idx]

        return anchor_idx, pos_idx, neg_idx
|
||
|
||
class SemiHardNegative(_Sampler):
    """Sampler that prefers negatives lying farther away than the positive."""

    def forward(self, embeddings, labels):
        """Return ``(anchor_idx, pos_idx, neg_idx)`` using semi-hard negatives.

        For each positive pair, the chosen negative is the closest one whose
        distance to the anchor exceeds the anchor-positive distance. If no
        such negative exists, the negative with the largest distance is
        chosen instead.
        """
        with torch.no_grad():
            dist = self.dist_func(embeddings)
            pos_mask, neg_mask = pos_neg_mask(labels)
            neg_dist = dist * neg_mask.float()

            positive_pairs = pos_mask.nonzero()
            anchor_idx, pos_idx = positive_pairs[:, 0], positive_pairs[:, 1]

            # One row per positive pair: the anchor's negative mask and its
            # masked distances to all samples.
            anchor_neg_mask = neg_mask[anchor_idx]
            candidate_dist = neg_dist[anchor_idx]
            positive_dist = dist[pos_mask].unsqueeze(1)

            farther_than_pos = (candidate_dist > positive_dist) * anchor_neg_mask
            # Rows with no qualifying negative fall back to all negatives.
            no_candidate = (farther_than_pos.sum(dim=1) == 0).unsqueeze(1)
            fallback = no_candidate * anchor_neg_mask

            # Fallback rows get negated distances, so argmin selects the
            # largest distance there and the smallest elsewhere.
            scored = (farther_than_pos.float() * candidate_dist) - (
                fallback.float() * candidate_dist
            )
            # Zero marks a masked-out entry; make it unselectable.
            scored[scored == 0] = BIG_NUMBER
            neg_idx = scored.argmin(dim=1)

        return anchor_idx, pos_idx, neg_idx
|
||
|
||
class DistanceWeighted(_Sampler):
    """Sampler drawing negatives with distance-dependent weights.

    Distance-weighted sampling assumes that embeddings are normalized by
    their 2-norm; ``forward`` applies that normalization itself before
    measuring distances.
    """

    # Lower clamp applied to every pairwise distance.
    cut_off = 0.5
    # Negatives at or beyond this distance receive zero weight.
    nonzero_loss_cutoff = 1.4

    def forward(self, embeddings, labels):
        """Return ``(anchor_idx, pos_idx, neg_idx)`` with weighted negatives.

        Distances are computed with the module-level ``pdist`` (not
        ``self.dist_func``) on L2-normalized embeddings.
        """
        with torch.no_grad():
            embeddings = F.normalize(embeddings, dim=1, p=2)
            pos_mask, neg_mask = pos_neg_mask(labels)

            positive_pairs = pos_mask.nonzero()
            anchor_idx, pos_idx = positive_pairs[:, 0], positive_pairs[:, 1]

            d = embeddings.size(1)
            # Adding the identity to the squared distances keeps the
            # diagonal away from zero before the square root.
            identity = torch.eye(
                embeddings.size(0), device=embeddings.device, dtype=torch.float32
            )
            dist = (pdist(embeddings, squared=True) + identity).sqrt()
            dist = dist.clamp(min=self.cut_off)

            log_weight = (2.0 - d) * dist.log() - ((d - 3.0) / 2.0) * (
                1.0 - 0.25 * (dist * dist)
            ).log()
            # Subtract the row maximum before exponentiating.
            row_max = log_weight.max(dim=1, keepdim=True)[0]
            weight = (log_weight - row_max).exp()

            # Only negatives closer than the cutoff keep their weight.
            eligible = neg_mask * (dist < self.nonzero_loss_cutoff)
            weight = weight * eligible.float()

            # Rows left with zero total weight fall back to uniform weights
            # over all negatives.
            empty_rows = weight.sum(dim=1, keepdim=True) == 0
            weight = weight + (empty_rows * neg_mask).float()

            weight = weight / (weight.sum(dim=1, keepdim=True))
            weight = weight[anchor_idx]
            neg_idx = torch.multinomial(weight, 1).squeeze(1)

        return anchor_idx, pos_idx, neg_idx
Oops, something went wrong.