Showing 9 changed files with 1,667 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,120 @@ | ||
# LAMDA | ||
# LAMDA: Label Matching Deep Domain Adaptation | ||
|
||
This is the implementation of the paper **[LAMDA: Label Matching Deep Domain Adaptation](https://proceedings.mlr.press/v139/le21a/le21a.pdf)**, which was accepted at ICML 2021. | ||
|
||
## A. Setup | ||
|
||
### A.1. Install Package Dependencies | ||
|
||
**Install manually** | ||
|
||
``` | ||
Python Environment: >= 3.5 | ||
Tensorflow: >= 1.9 | ||
``` | ||
|
||
**Install automatically from YAML file** | ||
|
||
``` | ||
pip install --upgrade pip | ||
conda env create --file tf1.9py3.5.yml | ||
``` | ||
|
||
**[UPDATE] Install tensorbayes** | ||
|
||
Please note that tensorbayes 0.4.0 is out of date. After installing it with pip, overwrite it with the newer version from **tensorbayes.tar** by copying the extracted files into the *env* folder (tf1.9py3.5): | ||
|
||
``` | ||
source activate tf1.9py3.5 | ||
pip install tensorbayes | ||
tar -xvf tensorbayes.tar | ||
cp -rf /tensorbayes/* /opt/conda/envs/tf1.9py3.5/lib/python3.5/site-packages/tensorbayes/ | ||
``` | ||
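
The copy command above hard-codes a conda location (`/opt/conda/envs/tf1.9py3.5/...`), which may differ on other machines. As a minimal sketch, the snippet below prints where a pure-Python package such as tensorbayes would live in whichever environment is currently active. The helper name `package_dir` is hypothetical and not part of this repository; run it after `source activate tf1.9py3.5`.

```python
import os
import sysconfig


def package_dir(package_name):
    """Return the site-packages directory of `package_name` in the active environment."""
    # "purelib" is the install location for pure-Python packages.
    return os.path.join(sysconfig.get_paths()["purelib"], package_name)


# Directory to use as the destination of the `cp -rf` step.
target = package_dir("tensorbayes")
print(target)
```

If the printed path differs from the one in the `cp` command, use the printed path as the destination instead.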
|
||
### A.2. Datasets | ||
|
||
Please download Office-31 [here](https://drive.google.com/file/d/1dsrHn4S6lCmlTa4Eg4RAE5JRfZUIxR8G/view?usp=sharing) and unzip the extracted features into the *datasets* folder. | ||
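
Judging from `load_office31_resnet50_feature` in this commit, each CSV row appears to hold 2048 feature values followed by one class label, and the loader looks for files named `{domain}_train.csv` and `{domain}_test.csv`. The sketch below is a hedged sanity check for that layout, using only the standard library. The helper names `check_feature_csv` and `write_synthetic_csv` are hypothetical, and the file name `amazon_train.csv` is an assumption derived from the loader's naming pattern.

```python
import csv
import os
import tempfile

NUM_FEATURES = 2048  # number of feature columns the loader in this commit reads


def check_feature_csv(path, num_features=NUM_FEATURES):
    """Return (row count, sorted labels); raise ValueError on a malformed row."""
    labels = set()
    num_rows = 0
    with open(path, "r", newline="") as f:
        for row_idx, row in enumerate(csv.reader(f)):
            if len(row) <= num_features:
                raise ValueError("row {} has only {} columns".format(row_idx, len(row)))
            for value in row[:num_features]:
                float(value)  # raises ValueError if a feature is not numeric
            # The loader parses labels the same way: int(float(...)).
            labels.add(int(float(row[num_features])))
            num_rows += 1
    return num_rows, sorted(labels)


def write_synthetic_csv(path, labels, num_features=NUM_FEATURES):
    """Write one all-zero feature row per label, only for demonstrating the check."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        for label in labels:
            writer.writerow([0.0] * num_features + [float(label)])


# Demo on a synthetic file; point check_feature_csv at a real file instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "amazon_train.csv")
    write_synthetic_csv(demo_path, labels=[0, 1, 2])
    num_rows, found_labels = check_feature_csv(demo_path)

print(num_rows, found_labels)  # prints: 3 [0, 1, 2]
```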
|
||
## B. Training | ||
|
||
First navigate to the *model* folder, then run *run_lamda.py* as shown below: | ||
|
||
```bash | ||
cd model | ||
``` | ||
|
||
1. **A** --> **W** task | ||
|
||
```bash | ||
python run_lamda.py 1 amazon webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path "" | ||
``` | ||
|
||
2. **A** --> **D** task | ||
|
||
```bash | ||
python run_lamda.py 1 amazon dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.05 data_path "" | ||
``` | ||
|
||
3. **D** --> **W** task | ||
|
||
```bash | ||
python run_lamda.py 1 dslr webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path "" | ||
``` | ||
|
||
4. **W** --> **D** task | ||
|
||
```bash | ||
python run_lamda.py 1 webcam dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path "" | ||
``` | ||
|
||
5. **D** --> **A** task | ||
|
||
```bash | ||
python run_lamda.py 1 dslr amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path "" | ||
``` | ||
|
||
6. **W** --> **A** task | ||
|
||
```bash | ||
python run_lamda.py 1 webcam amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path "" | ||
``` | ||
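
The six commands above differ only in the source and target domains, `batch_size`, `src_vat_trade_off`, and `m_on_G_trade_off`. As a minimal sketch, the snippet below rebuilds the command strings from that per-task table so the differences are easier to see. `TASKS` and `build_command` are hypothetical names, not part of this repository, and the snippet only prints the commands rather than running them.

```python
# (source, target, batch_size, src_vat_trade_off, m_on_G_trade_off),
# copied from the six commands above. Values are kept as strings so they
# are emitted exactly as written.
TASKS = [
    ("amazon", "webcam", "310", "0.1", "0.1"),
    ("amazon", "dslr", "310", "1.0", "0.05"),
    ("dslr", "webcam", "155", "0.1", "0.1"),
    ("webcam", "dslr", "310", "0.1", "0.1"),
    ("dslr", "amazon", "155", "1.0", "1.0"),
    ("webcam", "amazon", "310", "1.0", "1.0"),
]


def build_command(src, trg, batch_size, src_vat_trade_off, m_on_G_trade_off):
    """Assemble one run_lamda.py command line, keeping the README's argument order."""
    pairs = [
        ("format", "csv"), ("num_iters", "20000"), ("summary_freq", "400"),
        ("learning_rate", "0.0001"), ("inorm", "True"), ("batch_size", batch_size),
        ("src_class_trade_off", "1.0"), ("domain_trade_off", "0.1"),
        ("src_vat_trade_off", src_vat_trade_off), ("trg_trade_off", "0.1"),
        ("save_grads", "False"), ("cast_data", "False"), ("cnn_size", "small"),
        ("update_target_loss", "False"), ("m_on_D_trade_off", "1.0"),
        ("m_plus_1_on_D_trade_off", "1.0"), ("m_plus_1_on_G_trade_off", "1.0"),
        ("m_on_G_trade_off", m_on_G_trade_off), ("data_path", '""'),
    ]
    args = ["python", "run_lamda.py", "1", src, trg]
    for key, value in pairs:
        args.extend([key, value])
    return " ".join(args)


commands = [build_command(*task) for task in TASKS]
for command in commands:
    print(command)
```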
|
||
|
||
|
||
## C. Results | ||
|
||
| Methods | **A** --> **W** | **A** --> **D** | **D** --> **W** | **W** --> **D** | **D** --> **A** | **W** --> **A** | Avg | | ||
| :-----------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :------: | | ||
| ResNet-50 [1] | 70.0 | 65.5 | 96.1 | 99.3 | 62.8 | 60.5 | 75.7 | | ||
| DeepCORAL [2] | 83.0 | 71.5 | 97.9 | 98.0 | 63.7 | 64.5 | 79.8 | | ||
| DANN [3] | 81.5 | 74.3 | 97.1 | 99.6 | 65.5 | 63.2 | 80.2 | | ||
| ADDA [4] | 86.2 | 78.8 | 96.8 | 99.1 | 69.5 | 68.5 | 83.2 | | ||
| CDAN [5] | 94.1 | 92.9 | 98.6 | **100.0** | 71.0 | 69.3 | 87.7 | | ||
| TPN [6] | 91.2 | 89.9 | 97.7 | 99.5 | 70.5 | 73.5 | 87.1 | | ||
| DeepJDOT [7] | 88.9 | 88.2 | 98.5 | 99.6 | 72.1 | 70.1 | 86.2 | | ||
| RWOT [8] | 95.1 | 94.5 | 99.5 | **100.0** | 77.5 | 77.9 | 90.8 | | ||
| **LAMDA** | **95.2** | **96.0** | 98.5 | **99.8** | **87.3** | **84.4** | **93.0** | | ||
|
||
## D. References | ||
|
||
### D.1. Baselines: | ||
|
||
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016. | ||
|
||
[2] B. Sun and K. Saenko. Deep CORAL: Correlation alignment for deep domain adaptation. In Gang Hua and Hervé Jégou, editors, Computer Vision – ECCV 2016 Workshops, pages 443–450, Cham, 2016. Springer International Publishing. | ||
|
||
[3] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky. Domain-adversarial training of neural networks. J. Mach. Learn. Res., 17(1):2096–2030, jan 2016. | ||
|
||
[4] E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell. Adversarial discriminative domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2962–2971, 2017. | ||
|
||
[5] M. Long, Z. Cao, J. Wang, and M. I. Jordan. Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems 31, pages 1640–1650. Curran Associates, Inc., 2018. | ||
|
||
[6] Y. Pan, T. Yao, Y. Li, Y. Wang, C. Ngo, and T. Mei. Transferrable prototypical networks for unsupervised domain adaptation. In CVPR, pages 2234–2242, 2019. | ||
|
||
[7] B. B. Damodaran, B. Kellenberger, R. Flamary, D. Tuia, and N. Courty. DeepJDOT: Deep joint distribution optimal transport for unsupervised domain adaptation. In Computer Vision – ECCV 2018, pages 467–483. Springer, 2018. | ||
|
||
[8] R. Xu, P. Liu, L. Wang, C. Chen, and J. Wang. Reliable weighted optimal transport for unsupervised domain adaptation. In CVPR 2020, June 2020. | ||
|
||
### D.2. GitHub repositories: | ||
|
||
- Some parts of our code (e.g., VAT, evaluation, …) are rewritten with modifications from [DIRT-T](https://github.com/RuiShu/dirt-t). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# Copyright (c) 2021, Tuan Nguyen. | ||
# All rights reserved. | ||
|
||
import os | ||
|
||
import numpy as np | ||
from scipy.io import loadmat | ||
from scipy import misc | ||
import time | ||
import h5py | ||
from keras.utils.np_utils import to_categorical | ||
from generic_utils import random_seed | ||
import csv | ||
|
||
|
||
def load_mat_office_caltech10_decaf(filename): | ||
    data = loadmat(filename) | ||
    x = np.reshape(data['feas'], (-1, 8, 8, 64)) | ||
    y = data['labels'][0] | ||
    # y = y.reshape(-1) - 1 | ||
    return x, y | ||
|
||
|
||
def load_mat_office31_ResNet50(filename): | ||
    data = loadmat(filename) | ||
    # x = np.reshape(data['feas'], (-1, 8, 8, 64)) | ||
    x = data['feas'] | ||
    y = data['labels'][0] | ||
    return x, y | ||
|
||
|
||
def load_mat_office31_AlexNet(filename): | ||
    data = loadmat(filename) | ||
    # x = np.reshape(data['feas'], (-1, 8, 8, 64)) | ||
    x = data['feas'] | ||
    y = data['labels'][0] | ||
    return x, y | ||
|
||
|
||
def load_office31_resnet50_feature(file_path_train): | ||
    # Each CSV row holds 2048 feature values followed by the class label. | ||
    features_full = [] | ||
    labels_full = [] | ||
    with open(file_path_train, "r") as file:  # the file is closed when the block exits | ||
        reader = csv.reader(file) | ||
        for line in reader: | ||
            feature_i = np.asarray(line[:2048]).astype(np.float32) | ||
            label_i = int(float(line[2048])) | ||
            features_full.append(feature_i) | ||
            labels_full.append(label_i) | ||
    features_full = np.asarray(features_full) | ||
    labels_full = np.asarray(labels_full) | ||
    return features_full, labels_full | ||
|
||
|
||
def load_mat_office_caltech10_ResNet101(filename): | ||
    data = loadmat(filename) | ||
    x = data['feas'] | ||
    y = data['labels'][0] | ||
    return x, y | ||
|
||
|
||
def load_mat_file_single_label(filename): | ||
    filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32'] | ||
    data = loadmat(filename) | ||
    x = data['X'] | ||
    y = data['y'] | ||
    if any(fn in filename for fn in filename_list): | ||
        if 'mnist32_60_10' not in filename and 'mnistg' not in filename: | ||
            y = y[0] | ||
        else: | ||
            y = np.argmax(y, axis=1) | ||
    # process one-hot label encoder | ||
    elif len(y.shape) > 1: | ||
        y = np.argmax(y, axis=1) | ||
    return x, y | ||
|
||
|
||
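def u2t(x): | ||
    # Rescale values from [0, 255] to float32 in [-1, 1]. | ||
    # Large arrays are converted in chunks of max_num rows to limit peak memory. | ||
    max_num = 50000 | ||
    if len(x) > max_num: | ||
        y = np.empty_like(x, dtype='float32') | ||
        for i in range(len(x) // max_num): | ||
            y[i*max_num: (i+1)*max_num] = (x[i*max_num: (i+1)*max_num].astype('float32') / 255) * 2 - 1 | ||
|
||
        # Convert the rows left over after the last full chunk. | ||
        y[(i + 1) * max_num:] = (x[(i + 1) * max_num:].astype('float32') / 255) * 2 - 1 | ||
    else: | ||
        y = (x.astype('float32') / 255) * 2 - 1 | ||
    return y | ||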
|
||
|
||
class DataLoader: | ||
    # src_domain and trg_domain are comma-separated strings of domain names; | ||
    # they are split with .split(','), so the defaults must be strings, not lists. | ||
    def __init__(self, src_domain='mnistm', trg_domain='mnist', data_path='./dataset', data_format='mat', | ||
                 shuffle_data=False, dataset_name='digits', cast_data=True): | ||
        self.num_src_domain = len(src_domain.split(',')) | ||
        self.src_domain_name = src_domain | ||
        self.trg_domain_name = trg_domain | ||
        self.data_path = data_path | ||
        self.data_format = data_format | ||
        self.shuffle_data = shuffle_data | ||
        self.dataset_name = dataset_name | ||
        self.cast_data = cast_data | ||
|
||
        self.src_train = {}  # {idx: [src_name, x_train, one-hot y_train]} | ||
        self.trg_train = {}  # {idx: [trg_name, x_train, one-hot y_train]} | ||
        self.src_test = {} | ||
        self.trg_test = {} | ||
|
||
        print("Source domains", self.src_domain_name) | ||
        print("Target domain", self.trg_domain_name) | ||
        print("----- Training data -----") | ||
        self._load_data_train() | ||
        print("----- Test data -----") | ||
        self._load_data_test() | ||
|
||
        self.data_shape = self.src_train[0][1][0].shape | ||
        self.num_domain = len(self.src_train.keys()) | ||
        self.num_class = self.src_train[0][2].shape[-1] | ||
|
||
    def _load_data_train(self, tail_name="_train"): | ||
        if not self.src_train: | ||
            self.src_train = self._load_file(self.src_domain_name, tail_name, self.shuffle_data) | ||
            self.trg_train = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data) | ||
|
||
    def _load_data_test(self, tail_name="_test"): | ||
        if not self.src_test: | ||
            self.src_test = self._load_file(self.src_domain_name, tail_name, self.shuffle_data) | ||
            self.trg_test = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data) | ||
|
||
    def _load_file(self, name_file='', tail_name="_train", shuffle_data=False): | ||
        data_list = {} | ||
        name_file = name_file.split(',') | ||
        for idx, s_n in enumerate(name_file): | ||
            file_path_train = os.path.join(self.data_path, '{}{}.{}'.format(s_n, tail_name, self.data_format)) | ||
            # print(file_path_train) | ||
            if os.path.isfile(file_path_train): | ||
                if self.dataset_name == 'digits': | ||
                    x_train, y_train = load_mat_file_single_label(file_path_train) | ||
                elif self.dataset_name == 'office_caltech10_DECAF_feat': | ||
                    x_train, y_train = load_mat_office_caltech10_decaf(file_path_train) | ||
                elif self.dataset_name == 'office_caltech10_ResNet101_feat': | ||
                    x_train, y_train = load_mat_office_caltech10_ResNet101(file_path_train) | ||
                elif self.dataset_name == 'office31_AlexNet_feat': | ||
                    x_train, y_train = load_mat_office31_AlexNet(file_path_train) | ||
                elif self.dataset_name == 'office31_resnet50_feature': | ||
                    x_train, y_train = load_office31_resnet50_feature(file_path_train) | ||
                else: | ||
                    # Fail early instead of hitting a NameError on x_train below. | ||
                    raise ValueError('Unknown dataset_name: {}'.format(self.dataset_name)) | ||
|
||
                if shuffle_data: | ||
                    x_train, y_train = self.shuffle(x_train, y_train) | ||
|
||
                if 'mnist32_60_10' not in s_n and self.cast_data: | ||
                    x_train = u2t(x_train) | ||
                data_list.update({idx: [s_n, x_train, to_categorical(y_train)]}) | ||
            else: | ||
                # raise needs an exception instance; raising a plain string is a TypeError. | ||
                raise FileNotFoundError('File not found: {}'.format(file_path_train)) | ||
|
||
            print(s_n, x_train.shape[0], x_train.min(), x_train.max(), "Label", y_train.min(), y_train.max(), np.unique(y_train)) | ||
        return data_list | ||
|
||
    def shuffle(self, x, y=None): | ||
        np.random.seed(random_seed()) | ||
        idx_train = np.random.permutation(x.shape[0]) | ||
        x = x[idx_train] | ||
        if y is not None: | ||
            y = y[idx_train] | ||
        return x, y |