Commit 5e4ba46: first push
aiden.d committed Dec 2, 2020
Showing 20 changed files with 2,743 additions and 0 deletions.

README.md
# Multi-level Distance Regularization for Deep Metric Learning
## Dependencies

You need a CUDA-enabled GPU and Python (> 3.6) to run the source code.

- torchvision >= 0.4.2
- torch >= 1.3.1
- tqdm
- scipy
- Pillow


Install the dependencies with:

```
pip install -r requirements.txt
```
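
For reference, a `requirements.txt` consistent with the list above might look like this (a sketch; use the file shipped with the repository):

```
torch>=1.3.1
torchvision>=0.4.2
tqdm
scipy
Pillow
```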

## Preparing datasets
### 1. Make `dataset` directory
```
mkdir ./dataset
```
### 2. (Optional) Only for In-Shop Clothes Retrieval
The source code automatically downloads the CUB-200-2011, Cars-196, and Stanford Online Products datasets, but the In-Shop Clothes Retrieval dataset must be downloaded manually:

1. Make an `Inshop` directory inside the `./dataset` directory
```
mkdir -p ./dataset/Inshop
```
2. Download `img.zip` from the following link and unzip it in the `Inshop` directory
```
https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00
```
3. Download `list_eval_partition.txt` from the following link and put it in the `Inshop` directory; the expected layout after these steps is sketched after this list.
```
https://drive.google.com/drive/folders/0B7EVK8r0v71pWVBJelFmMW5EWnM
```
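
After these three steps, the directory should look roughly like this (a sketch; the exact subfolders come from `img.zip`):

```bash
cd ./dataset/Inshop
unzip img.zip
# Expected layout:
# dataset/Inshop/
# ├── img/                     # unzipped from img.zip
# └── list_eval_partition.txt  # train/query/gallery split
```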

## Testing on the trained weights
```bash
# The models below were trained with Triplet+MDR; see Table 1 of the paper.

# CUB-200-2011
wget https://github.com/anonymous-ai-research/pretrained/raw/master/cub200/cub200.pth
python run.py --mode eval --dataset cub200 --load cub200.pth

# Cars-196
wget https://github.com/anonymous-ai-research/pretrained/raw/master/cars196/cars196.pth
python run.py --mode eval --dataset cars196 --load cars196.pth

# Stanford Online Products
wget https://github.com/anonymous-ai-research/pretrained/raw/master/sop/sop.pth
python run.py --mode eval --dataset stanford --load sop.pth

# In-Shop Clothes Retrieval
wget https://github.com/anonymous-ai-research/pretrained/raw/master/inshop/inshop.pth
python run_inshop.py --mode eval --load inshop.pth
```

## Training
```bash
# CUB-200-2011
# Triplet
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8
# Triplet+L2Norm
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8 --l2norm
# Triplet+MDR
python run.py --dataset cub200 --lr 5e-5 --recall 1 2 4 8 --lambda-mdr 0.6 --nu-mdr 0.01

# Cars-196
# Triplet
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8
# Triplet+L2Norm
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8 --l2norm
# Triplet+MDR
python run.py --dataset cars196 --lr 5e-5 --recall 1 2 4 8 --lambda-mdr 0.2 --nu-mdr 0.01

# Stanford Online Products
# Triplet
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000
# Triplet+L2Norm
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000 --l2norm
# Triplet+MDR
python run.py --dataset stanford --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 100 1000 --lambda-mdr 0.1 --nu-mdr 0.01

# In-Shop Clothes Retrieval
# Triplet
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40
# Triplet+L2Norm
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40 --l2norm
# Triplet+MDR
python run_inshop.py --num-image-per-class 3 --batch 256 --lr 1e-4 --recall 1 10 20 30 40 --lambda-mdr 0.1 --nu-mdr 0.01
```
metric/__init__.py (empty file)
metric/batchsampler.py
```python
import random

from torchvision.datasets import ImageFolder


def index_dataset(dataset: ImageFolder):
    # Map each class index to the list of dataset indices that belong to it.
    kv = [(cls_ind, idx) for idx, (_, cls_ind) in enumerate(dataset.imgs)]
    cls_to_ind = {}

    for k, v in kv:
        if k in cls_to_ind:
            cls_to_ind[k].append(v)
        else:
            cls_to_ind[k] = [v]

    return cls_to_ind


class MImagesPerClassSampler:
    """Yields index batches containing up to `m` images per sampled class."""

    def __init__(self, data_source: ImageFolder, batch_size, m=5, iter_per_epoch=100):
        self.m = m
        self.batch_size = batch_size
        self.n_batch = iter_per_epoch
        self.class_idx = list(data_source.class_to_idx.values())
        self.images_by_class = index_dataset(data_source)

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            # Visit the classes in random order, taking up to `m` images from
            # each, until the batch is full.
            selected_class = random.sample(self.class_idx, k=len(self.class_idx))
            example_indices = []

            for c in selected_class:
                img_ind_of_cls = self.images_by_class[c]
                new_ind = random.sample(
                    img_ind_of_cls, k=min(self.m, len(img_ind_of_cls))
                )
                example_indices += new_ind

                if len(example_indices) >= self.batch_size:
                    break

            yield example_indices[: self.batch_size]
```
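
A minimal sketch of how `MImagesPerClassSampler` plugs into a `DataLoader` via `batch_sampler` (the dataset path and transform here are illustrative assumptions, not part of the repo):

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from metric.batchsampler import MImagesPerClassSampler

# Hypothetical ImageFolder root; run.py builds the real datasets itself.
dataset = ImageFolder(
    "./dataset/example/train",
    transform=T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]),
)

# Each yielded batch mixes several classes, up to m=5 images per class.
sampler = MImagesPerClassSampler(dataset, batch_size=128, m=5, iter_per_epoch=100)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

for images, labels in loader:
    pass  # forward pass, loss, optimizer step
```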
metric/loss.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from metric.utils import pdist


class TripletLoss(nn.Module):
    def __init__(self, margin=0.2, sampler=None, reduce=True, size_average=True):
        super().__init__()
        self.margin = margin

        # A sampler must be provided. The original line referenced an undefined
        # `p`; F.triplet_margin_loss below uses p=2, so the sampler is given
        # squared Euclidean distances.
        self.sampler = sampler
        self.sampler.dist_func = lambda e: pdist(e, squared=True)

        self.reduce = reduce
        self.size_average = size_average

    def forward(self, embeddings, labels):
        anchor_idx, pos_idx, neg_idx = self.sampler(embeddings, labels)

        anchor_embed = embeddings[anchor_idx]
        positive_embed = embeddings[pos_idx]
        negative_embed = embeddings[neg_idx]

        loss = F.triplet_margin_loss(
            anchor_embed,
            positive_embed,
            negative_embed,
            margin=self.margin,
            reduction="none",
        )

        if not self.reduce:
            return loss

        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()


class MDRLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable distance levels that pairwise distances are regularized toward.
        self.levels = nn.Parameter(torch.tensor([-3.0, 0.0, 3.0]))
        self.momentum = 0.9

        momented_mean = torch.zeros(1)
        momented_std = torch.zeros(1)
        self.register_buffer("momented_mean", momented_mean)
        self.register_buffer("momented_std", momented_std)

        # Tracks whether momented_mean and momented_std have been initialized.
        self.init = False

    def initialize_statistics(self, mean, std):
        self.momented_mean = mean
        self.momented_std = std
        self.init = True

    def forward(self, embeddings):
        dist_mat = pdist(embeddings)
        # Off-diagonal pairwise distances (self-distances excluded).
        pdist_mat = dist_mat[
            ~torch.eye(dist_mat.shape[0], dtype=torch.bool, device=dist_mat.device)
        ]
        dist_mat = dist_mat.view(-1)

        # Batch statistics over the flattened distance matrix (diagonal included).
        mean = dist_mat.mean().detach()
        std = dist_mat.std().detach()

        if not self.init:
            self.initialize_statistics(mean, std)
        else:
            # Exponential moving average of the batch statistics.
            self.momented_mean = (
                1 - self.momentum
            ) * mean + self.momentum * self.momented_mean
            self.momented_std = (
                1 - self.momentum
            ) * std + self.momentum * self.momented_std

        # Normalize the distances and penalize deviation from the nearest level.
        normalized_dist = (pdist_mat - self.momented_mean) / self.momented_std
        difference = (normalized_dist[None] - self.levels[:, None]).abs().min(dim=0)[0]
        loss = difference.mean()

        return loss
```
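
A sketch of how the two losses might be combined in a training step, mirroring the README's `--lambda-mdr` flag (the batch here is random data; how `--nu-mdr` is applied lives in the training script and is not shown):

```python
import torch

from metric.loss import MDRLoss, TripletLoss
from metric.pairsampler import DistanceWeighted

# Hypothetical batch: 64 embeddings of dimension 128 over 8 classes.
embeddings = torch.randn(64, 128, requires_grad=True)
labels = torch.randint(0, 8, (64,))

triplet = TripletLoss(margin=0.2, sampler=DistanceWeighted())
mdr = MDRLoss()

# Joint objective: triplet loss plus the weighted MDR regularizer.
lambda_mdr = 0.6
loss = triplet(embeddings, labels) + lambda_mdr * mdr(embeddings)
loss.backward()
```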
metric/pairsampler.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from metric.utils import pdist

BIG_NUMBER = 1e12
__all__ = [
    "AllPairs",
    "HardNegative",
    "SemiHardNegative",
    "DistanceWeighted",
    "RandomNegative",
]


def pos_neg_mask(labels):
    # Boolean masks over all (i, j) pairs: same class and different class,
    # with self-pairs (the diagonal) excluded.
    not_self = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_self
    neg_mask = (labels.unsqueeze(0) != labels.unsqueeze(1)) & not_self

    return pos_mask, neg_mask


class _Sampler(nn.Module):
    def __init__(self, dist_func=pdist):
        super().__init__()
        self.dist_func = dist_func

    def forward(self, embeddings, labels):
        raise NotImplementedError


class AllPairs(_Sampler):
    def forward(self, embeddings, labels):
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)
            pos_pair_idx = pos_mask.nonzero()

            # Pair every (anchor, positive) with every negative of that anchor.
            apns = []
            for pair_idx in pos_pair_idx:
                anchor_idx = pair_idx[0]
                neg_indices = neg_mask[anchor_idx].nonzero()

                apn = torch.cat(
                    (pair_idx.unsqueeze(0).repeat(len(neg_indices), 1), neg_indices),
                    dim=1,
                )
                apns.append(apn)
            apns = torch.cat(apns, dim=0)
            anchor_idx = apns[:, 0]
            pos_idx = apns[:, 1]
            neg_idx = apns[:, 2]

        return anchor_idx, pos_idx, neg_idx


class RandomNegative(_Sampler):
    def forward(self, embeddings, labels):
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)

            pos_pair_index = pos_mask.nonzero()
            anchor_idx = pos_pair_index[:, 0]
            pos_idx = pos_pair_index[:, 1]
            # Draw one negative uniformly at random for each anchor.
            neg_index = torch.multinomial(neg_mask.float()[anchor_idx], 1).squeeze(1)

        return anchor_idx, pos_idx, neg_index


class HardNegative(_Sampler):
    def forward(self, embeddings, labels):
        with torch.no_grad():
            pos_mask, neg_mask = pos_neg_mask(labels)
            dist = self.dist_func(embeddings)

            pos_pair_index = pos_mask.nonzero()
            anchor_idx = pos_pair_index[:, 0]
            pos_idx = pos_pair_index[:, 1]

            # For each anchor, take the closest negative.
            neg_dist = neg_mask.float() * dist
            neg_dist[neg_dist <= 0] = BIG_NUMBER
            neg_idx = neg_dist.argmin(dim=1)[anchor_idx]

        return anchor_idx, pos_idx, neg_idx


class SemiHardNegative(_Sampler):
    def forward(self, embeddings, labels):
        with torch.no_grad():
            dist = self.dist_func(embeddings)
            pos_mask, neg_mask = pos_neg_mask(labels)
            neg_dist = dist * neg_mask.float()

            pos_pair_idx = pos_mask.nonzero()
            anchor_idx = pos_pair_idx[:, 0]
            pos_idx = pos_pair_idx[:, 1]

            tiled_negative = neg_dist[anchor_idx]
            satisfied_neg = (tiled_negative > dist[pos_mask].unsqueeze(1)) * neg_mask[
                anchor_idx
            ]
            # If no negative is farther from the anchor than the positive,
            # fall back to the negative with the largest distance.
            unsatisfied_neg = (satisfied_neg.sum(dim=1) == 0).unsqueeze(1) * neg_mask[
                anchor_idx
            ]

            tiled_negative = (satisfied_neg.float() * tiled_negative) - (
                unsatisfied_neg.float() * tiled_negative
            )
            tiled_negative[tiled_negative == 0] = BIG_NUMBER
            neg_idx = tiled_negative.argmin(dim=1)

        return anchor_idx, pos_idx, neg_idx


class DistanceWeighted(_Sampler):
    """Distance-weighted negative sampling; assumes the embeddings are
    L2-normalized."""

    cut_off = 0.5
    nonzero_loss_cutoff = 1.4

    def forward(self, embeddings, labels):
        with torch.no_grad():
            embeddings = F.normalize(embeddings, dim=1, p=2)
            pos_mask, neg_mask = pos_neg_mask(labels)
            pos_pair_idx = pos_mask.nonzero()
            anchor_idx = pos_pair_idx[:, 0]
            pos_idx = pos_pair_idx[:, 1]

            d = embeddings.size(1)
            # Add the identity so the diagonal stays away from sqrt(0), then
            # clamp small distances to cut_off.
            dist = (
                pdist(embeddings, squared=True)
                + torch.eye(
                    embeddings.size(0), device=embeddings.device, dtype=torch.float32
                )
            ).sqrt()
            dist = dist.clamp(min=self.cut_off)

            # Inverse of the pairwise-distance density on the unit hypersphere.
            log_weight = (2.0 - d) * dist.log() - ((d - 3.0) / 2.0) * (
                1.0 - 0.25 * (dist * dist)
            ).log()
            weight = (log_weight - log_weight.max(dim=1, keepdim=True)[0]).exp()
            weight = weight * (neg_mask * (dist < self.nonzero_loss_cutoff)).float()

            # Fall back to uniform sampling over negatives when a row has no
            # nonzero weights.
            weight = (
                weight + ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float()
            )
            weight = weight / (weight.sum(dim=1, keepdim=True))
            weight = weight[anchor_idx]
            neg_idx = torch.multinomial(weight, 1).squeeze(1)

        return anchor_idx, pos_idx, neg_idx
```
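
A quick sanity check one might run against these samplers: every returned triplet should pair an anchor with a same-class positive and a different-class negative (shapes and class count here are arbitrary):

```python
import torch

from metric.pairsampler import (
    DistanceWeighted,
    HardNegative,
    RandomNegative,
    SemiHardNegative,
)

emb = torch.randn(32, 64)
labels = torch.randint(0, 4, (32,))

for sampler in (RandomNegative(), HardNegative(), SemiHardNegative(), DistanceWeighted()):
    a, p, n = sampler(emb, labels)
    # Anchors share a class with positives and differ from negatives.
    assert (labels[a] == labels[p]).all()
    assert (labels[a] != labels[n]).all()
```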