Add LAMDA
tuanrpt committed Oct 28, 2021
1 parent 87cbcae commit 703b3ad
Showing 9 changed files with 1,667 additions and 1 deletion.
121 changes: 120 additions & 1 deletion README.md
@@ -1 +1,120 @@
# LAMDA: Label Matching Deep Domain Adaptation

This is the implementation of the paper **[LAMDA: Label Matching Deep Domain Adaptation](https://proceedings.mlr.press/v139/le21a/le21a.pdf)**, which was accepted at ICML 2021.

## A. Setup

### A.1. Install Package Dependencies

**Install manually**

```
Python Environment: >= 3.5
Tensorflow: >= 1.9
```

**Install automatically from YAML file**

```
pip install --upgrade pip
conda env create --file tf1.9py3.5.yml
```

**[UPDATE] Install tensorbayes**

The tensorbayes 0.4.0 release installed by pip is out of date. After installing it, overwrite the package files in the *env* folder (tf1.9py3.5) with the newer version from **tensorbayes.tar**:

```
source activate tf1.9py3.5
pip install tensorbayes
tar -xvf tensorbayes.tar
cp -rf /tensorbayes/* /opt/conda/envs/tf1.9py3.5/lib/python3.5/site-packages/tensorbayes/
```

### A.2. Datasets

Please download Office-31 [here](https://drive.google.com/file/d/1dsrHn4S6lCmlTa4Eg4RAE5JRfZUIxR8G/view?usp=sharing) and unzip the extracted features into the *datasets* folder.
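
The CSV loader in *model/dataLoader.py* (`load_office31_resnet50_feature`) reads each row as 2048 feature values followed by one label column, and `_load_file` looks for files named `{domain}{tail}.{format}`, such as `amazon_train.csv`. The sketch below writes a tiny synthetic file in that layout and parses it the same way. The helper name `read_feature_csv`, the temporary path, and the random values are made up for illustration and are not part of the released dataset or code.

```python
import csv
import os
import tempfile

import numpy as np

NUM_FEATURES = 2048  # feature width the loader slices off each row


def read_feature_csv(path):
    """Parse rows of NUM_FEATURES floats followed by one label column."""
    features, labels = [], []
    with open(path, "r", newline="") as handle:
        for row in csv.reader(handle):
            features.append(np.asarray(row[:NUM_FEATURES]).astype(np.float32))
            labels.append(int(float(row[NUM_FEATURES])))
    return np.asarray(features), np.asarray(labels)


# Write three synthetic rows (hypothetical data) and read them back.
rng = np.random.RandomState(0)
demo_labels = [0, 5, 30]
demo_path = os.path.join(tempfile.mkdtemp(), "amazon_train.csv")
with open(demo_path, "w", newline="") as handle:
    writer = csv.writer(handle)
    for label in demo_labels:
        writer.writerow(rng.rand(NUM_FEATURES).tolist() + [float(label)])

x, y = read_feature_csv(demo_path)
print(x.shape, y)
```

A file that fails this kind of check (wrong column count, non-numeric label) would likely fail in the real loader as well, so it can serve as a quick sanity test after unzipping.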

## B. Training

First navigate to the *model* folder, then run *run_lamda.py* as shown below:

```bash
cd model
```

1. **A** --> **W** task

```bash
python run_lamda.py 1 amazon webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
```

2. **A** --> **D** task

```bash
python run_lamda.py 1 amazon dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.05 data_path ""
```

3. **D** --> **W** task

```bash
python run_lamda.py 1 dslr webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
```

4. **W** --> **D** task

```bash
python run_lamda.py 1 webcam dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
```

5. **D** --> **A** task

```bash
python run_lamda.py 1 dslr amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path ""
```

6. **W** --> **A** task

```bash
python run_lamda.py 1 webcam amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path ""
```
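
The six commands above share most of their settings; what varies per task is the domain pair plus `batch_size`, `src_vat_trade_off` and `m_on_G_trade_off`. As a convenience, the sketch below rebuilds the command lines from a small table. The `TASKS` table and the `build_command` helper are illustrative names and are not part of the repository.

```python
# (task, source, target, batch_size, src_vat_trade_off, m_on_G_trade_off)
TASKS = [
    ("A->W", "amazon", "webcam", "310", "0.1", "0.1"),
    ("A->D", "amazon", "dslr", "310", "1.0", "0.05"),
    ("D->W", "dslr", "webcam", "155", "0.1", "0.1"),
    ("W->D", "webcam", "dslr", "310", "0.1", "0.1"),
    ("D->A", "dslr", "amazon", "155", "1.0", "1.0"),
    ("W->A", "webcam", "amazon", "310", "1.0", "1.0"),
]


def build_command(src, trg, batch_size, src_vat, m_on_g):
    """Return one training command as a single shell string."""
    parts = [
        "python run_lamda.py 1", src, trg,
        "format csv num_iters 20000 summary_freq 400",
        "learning_rate 0.0001 inorm True",
        "batch_size", batch_size,
        "src_class_trade_off 1.0 domain_trade_off 0.1",
        "src_vat_trade_off", src_vat,
        "trg_trade_off 0.1 save_grads False cast_data False",
        "cnn_size small update_target_loss False",
        "m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0",
        "m_plus_1_on_G_trade_off 1.0",
        "m_on_G_trade_off", m_on_g,
        'data_path ""',
    ]
    return " ".join(parts)


for task, src, trg, batch_size, src_vat, m_on_g in TASKS:
    print("# " + task)
    print(build_command(src, trg, batch_size, src_vat, m_on_g))
```

Values are kept as strings so that the printed commands match the ones listed above character for character (for example `0.05` for the **A** --> **D** task).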



## C. Results

| Methods | **A** --> **W** | **A** --> **D** | **D** --> **W** | **W** --> **D** | **D** --> **A** | **W** --> **A** | Avg |
| :-----------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :------: |
| ResNet-50 [1] | 70.0 | 65.5 | 96.1 | 99.3 | 62.8 | 60.5 | 75.7 |
| DeepCORAL [2] | 83.0 | 71.5 | 97.9 | 98.0 | 63.7 | 64.5 | 79.8 |
| DANN [3] | 81.5 | 74.3 | 97.1 | 99.6 | 65.5 | 63.2 | 80.2 |
| ADDA [4] | 86.2 | 78.8 | 96.8 | 99.1 | 69.5 | 68.5 | 83.2 |
| CDAN [5] | 94.1 | 92.9 | 98.6 | **100.0** | 71.0 | 69.3 | 87.7 |
| TPN [6] | 91.2 | 89.9 | 97.7 | 99.5 | 70.5 | 73.5 | 87.1 |
| DeepJDOT [7] | 88.9 | 88.2 | 98.5 | 99.6 | 72.1 | 70.1 | 86.2 |
| RWOT [8] | 95.1 | 94.5 | 99.5 | **100.0** | 77.5 | 77.9 | 90.8 |
| **LAMDA** | **95.2** | **96.0** | 98.5 | 99.8 | **87.3** | **84.4** | **93.0** |

## D. References

### D.1. Baselines:

[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.

[2] B. Sun and K. Saenko. Deep coral: Correlation alignment for deep domain adaptation. In Gang Hua and Hervé Jégou, editors, Computer Vision – ECCV 2016 Workshops, pages 443–450, Cham, 2016. Springer International Publishing.

[3] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky. Domain-adversarial training of neural networks. J. Mach. Learn. Res., 17(1):2096–2030, jan 2016.

[4] E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell. Adversarial discriminative domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2962–2971, 2017.

[5] M. Long, Z. Cao, J. Wang, and M. I. Jordan. Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems 31, pages 1640–1650. Curran Associates, Inc., 2018.

[6] Y. Pan, T. Yao, Y. Li, Y. Wang, C. Ngo, and T. Mei. Transferrable prototypical networks for unsupervised domain adaptation. In CVPR, pages 2234–2242, 2019.

[7] B. B. Damodaran, B. Kellenberger, R. Flamary, D. Tuia, and N. Courty. Deepjdot: Deep joint distribution optimal transport for unsupervised domain adaptation. In Computer Vision - ECCV 2018, pages 467–483. Springer, 2018.

[8] R. Xu, P. Liu, L. Wang, C. Chen, and J. Wang. Reliable weighted optimal transport for unsupervised domain adaptation. In CVPR 2020, June 2020.

### D.2. GitHub repositories:

- Some parts of our code (e.g., VAT, evaluation, …) are rewritten with modifications from [DIRT-T](https://github.com/RuiShu/dirt-t).
165 changes: 165 additions & 0 deletions model/dataLoader.py
@@ -0,0 +1,165 @@
# Copyright (c) 2021, Tuan Nguyen.
# All rights reserved.

import os

import numpy as np
from scipy.io import loadmat
from scipy import misc
import time
import h5py
from keras.utils.np_utils import to_categorical
from generic_utils import random_seed
import csv


def load_mat_office_caltech10_decaf(filename):
    data = loadmat(filename)
    x = np.reshape(data['feas'], (-1, 8, 8, 64))
    y = data['labels'][0]
    # y = y.reshape(-1) - 1
    return x, y


def load_mat_office31_ResNet50(filename):
    data = loadmat(filename)
    # x = np.reshape(data['feas'], (-1, 8, 8, 64))
    x = data['feas']
    y = data['labels'][0]
    return x, y


def load_mat_office31_AlexNet(filename):
    data = loadmat(filename)
    # x = np.reshape(data['feas'], (-1, 8, 8, 64))
    x = data['feas']
    y = data['labels'][0]
    return x, y


def load_office31_resnet50_feature(file_path_train):
    # Each CSV row holds 2048 feature values followed by the class label.
    features_full = []
    labels_full = []
    # Use a context manager so the file is closed after reading.
    with open(file_path_train, "r") as file:
        reader = csv.reader(file)
        for line in reader:
            feature_i = np.asarray(line[:2048]).astype(np.float32)
            label_i = int(float(line[2048]))
            features_full.append(feature_i)
            labels_full.append(label_i)
    features_full = np.asarray(features_full)
    labels_full = np.asarray(labels_full)
    return features_full, labels_full


def load_mat_office_caltech10_ResNet101(filename):
    data = loadmat(filename)
    x = data['feas']
    y = data['labels'][0]
    return x, y


def load_mat_file_single_label(filename):
    filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32']
    data = loadmat(filename)
    x = data['X']
    y = data['y']
    if any(fn in filename for fn in filename_list):
        if 'mnist32_60_10' not in filename and 'mnistg' not in filename:
            y = y[0]
        else:
            y = np.argmax(y, axis=1)
    # process one-hot label encoder
    elif len(y.shape) > 1:
        y = np.argmax(y, axis=1)
    return x, y


def u2t(x):
    max_num = 50000
    if len(x) > max_num:
        y = np.empty_like(x, dtype='float32')
        for i in range(len(x) // max_num):
            y[i*max_num: (i+1)*max_num] = (x[i*max_num: (i+1)*max_num].astype('float32') / 255) * 2 - 1

        y[(i + 1) * max_num:] = (x[(i + 1) * max_num:].astype('float32') / 255) * 2 - 1
    else:
        y = (x.astype('float32') / 255) * 2 - 1
    return y


class DataLoader:
    def __init__(self, src_domain='mnistm', trg_domain='mnist', data_path='./dataset', data_format='mat',
                 shuffle_data=False, dataset_name='digits', cast_data=True):
        # Domain names are comma-separated strings (they are split on ',' below),
        # so the defaults must be strings rather than lists.
        self.num_src_domain = len(src_domain.split(','))
        self.src_domain_name = src_domain
        self.trg_domain_name = trg_domain
        self.data_path = data_path
        self.data_format = data_format
        self.shuffle_data = shuffle_data
        self.dataset_name = dataset_name
        self.cast_data = cast_data

        self.src_train = {}  # {idx : ['src_idx']['src_name', x_train, y_train]}
        self.trg_train = {}  # {idx : ['trg_idx']['trg_name', x_train, y_train]}
        self.src_test = {}
        self.trg_test = {}

        print("Source domains", self.src_domain_name)
        print("Target domain", self.trg_domain_name)
        print("----- Training data -----")
        self._load_data_train()
        print("----- Test data -----")
        self._load_data_test()

        self.data_shape = self.src_train[0][1][0].shape
        self.num_domain = len(self.src_train.keys())
        self.num_class = self.src_train[0][2].shape[-1]

    def _load_data_train(self, tail_name="_train"):
        if not self.src_train:
            self.src_train = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
            self.trg_train = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)

    def _load_data_test(self, tail_name="_test"):
        if not self.src_test:
            self.src_test = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
            self.trg_test = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)

    def _load_file(self, name_file='', tail_name="_train", shuffle_data=False):
        data_list = {}
        name_file = name_file.split(',')
        for idx, s_n in enumerate(name_file):
            file_path_train = os.path.join(self.data_path, '{}{}.{}'.format(s_n, tail_name, self.data_format))
            # print(file_path_train)
            if os.path.isfile(file_path_train):
                if self.dataset_name == 'digits':
                    x_train, y_train = load_mat_file_single_label(file_path_train)
                elif self.dataset_name == 'office_caltech10_DECAF_feat':
                    x_train, y_train = load_mat_office_caltech10_decaf(file_path_train)
                elif self.dataset_name == 'office_caltech10_ResNet101_feat':
                    x_train, y_train = load_mat_office_caltech10_ResNet101(file_path_train)
                elif self.dataset_name == 'office31_AlexNet_feat':
                    x_train, y_train = load_mat_office31_AlexNet(file_path_train)
                elif self.dataset_name == 'office31_resnet50_feature':
                    x_train, y_train = load_office31_resnet50_feature(file_path_train)

                if shuffle_data:
                    x_train, y_train = self.shuffle(x_train, y_train)

                if 'mnist32_60_10' not in s_n and self.cast_data:
                    x_train = u2t(x_train)
                data_list.update({idx: [s_n, x_train, to_categorical(y_train)]})
            else:
                # Raising a plain string is a TypeError in Python 3; raise a real exception.
                raise FileNotFoundError('File not found: {}'.format(file_path_train))

            print(s_n, x_train.shape[0], x_train.min(), x_train.max(), "Label", y_train.min(), y_train.max(), np.unique(y_train))
        return data_list

    def shuffle(self, x, y=None):
        np.random.seed(random_seed())
        idx_train = np.random.permutation(x.shape[0])
        x = x[idx_train]
        if y is not None:
            y = y[idx_train]
        return x, y
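
`u2t` rescales 8-bit values from [0, 255] to [-1, 1] as `float32`, processing large arrays in chunks of 50,000 samples; `_load_file` applies it only when `cast_data` is true. The standalone sketch below reproduces that mapping on a toy array. The function name `scale_to_unit_range` and the `chunk_size` parameter are illustrative, and the loop is a simplified rewrite rather than the code above.

```python
import numpy as np


def scale_to_unit_range(x, chunk_size=50000):
    """Map values in [0, 255] to float32 values in [-1, 1], chunk by chunk."""
    out = np.empty(x.shape, dtype=np.float32)
    for start in range(0, len(x), chunk_size):
        chunk = x[start:start + chunk_size].astype(np.float32)
        out[start:start + chunk_size] = chunk / 255 * 2 - 1
    return out


pixels = np.array([[0, 255], [51, 204]], dtype=np.uint8)
# A chunk size of 1 forces the chunked path even on this tiny array.
scaled = scale_to_unit_range(pixels, chunk_size=1)
print(scaled)
```

Converting one chunk at a time keeps the temporary `float32` copy small, which is presumably why the original function splits large inputs.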
