diff --git a/README.md b/README.md
index 27ca0ed..a957d4c 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,120 @@
-# LAMDA
\ No newline at end of file
+# LAMDA: Label Matching Deep Domain Adaptation
+
+This is the implementation of the paper **[LAMDA: Label Matching Deep Domain Adaptation](http://proceedings.mlr.press/v139/le21a/le21a.pdf)**, accepted at ICML 2021.
+
+## A. Setup
+
+### A.1. Install Package Dependencies
+
+**Install manually**
+
+```
+Python Environment: >= 3.5
+Tensorflow: >= 1.9
+```
+
+**Install automatically from YAML file**
+
+```
+pip install --upgrade pip
+conda env create --file tf1.9py3.5.yml
+```
+
+**[UPDATE] Install tensorbayes**
+
+Please note that tensorbayes 0.4.0 is out of date. Copy a newer version into the *env* folder (tf1.9py3.5) using **tensorbayes.tar**:
+
+```
+source activate tf1.9py3.5
+pip install tensorbayes
+tar -xvf tensorbayes.tar
+cp -rf /tensorbayes/* /opt/conda/envs/tf1.9py3.5/lib/python3.5/site-packages/tensorbayes/
+```
+
+### A.2. Datasets
+
+Please download Office-31 [here](https://drive.google.com/file/d/1dsrHn4S6lCmlTa4Eg4RAE5JRfZUIxR8G/view?usp=sharing) and unzip the extracted features into the *datasets* folder.
+
+## B. Training
+
+First navigate to the *model* folder, then run *run_lamda.py* as below:
+
+```bash
+cd model
+```
+
+1. **A** --> **W** task
+
+```bash
+python run_lamda.py 1 amazon webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
+```
+
+2. **A** --> **D** task
+
+```bash
+python run_lamda.py 1 amazon dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.05 data_path ""
+```
+
+3. **D** --> **W** task
+
+```bash
+python run_lamda.py 1 dslr webcam format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
+```
+
+4. **W** --> **D** task
+
+```bash
+python run_lamda.py 1 webcam dslr format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 0.1 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 0.1 data_path ""
+```
+
+5. **D** --> **A** task
+
+```bash
+python run_lamda.py 1 dslr amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 155 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path ""
+```
+
+6. **W** --> **A** task
+
+```bash
+python run_lamda.py 1 webcam amazon format csv num_iters 20000 summary_freq 400 learning_rate 0.0001 inorm True batch_size 310 src_class_trade_off 1.0 domain_trade_off 0.1 src_vat_trade_off 1.0 trg_trade_off 0.1 save_grads False cast_data False cnn_size small update_target_loss False m_on_D_trade_off 1.0 m_plus_1_on_D_trade_off 1.0 m_plus_1_on_G_trade_off 1.0 m_on_G_trade_off 1.0 data_path ""
+```
+
+## C. Results
+
+| Methods | **A** --> **W** | **A** --> **D** | **D** --> **W** | **W** --> **D** | **D** --> **A** | **W** --> **A** | Avg |
+| :-----------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :------: |
+| ResNet-50 [1] | 70.0 | 65.5 | 96.1 | 99.3 | 62.8 | 60.5 | 75.7 |
+| DeepCORAL [2] | 83.0 | 71.5 | 97.9 | 98.0 | 63.7 | 64.5 | 79.8 |
+| DANN [3] | 81.5 | 74.3 | 97.1 | 99.6 | 65.5 | 63.2 | 80.2 |
+| ADDA [4] | 86.2 | 78.8 | 96.8 | 99.1 | 69.5 | 68.5 | 83.2 |
+| CDAN [5] | 94.1 | 92.9 | 98.6 | **100.0** | 71.0 | 69.3 | 87.7 |
+| TPN [6] | 91.2 | 89.9 | 97.7 | 99.5 | 70.5 | 73.5 | 87.1 |
+| DeepJDOT [7] | 88.9 | 88.2 | 98.5 | 99.6 | 72.1 | 70.1 | 86.2 |
+| RWOT [8] | 95.1 | 94.5 | 99.5 | **100.0** | 77.5 | 77.9 | 90.8 |
+| **LAMDA** | **95.2** | **96.0** | 98.5 | **99.8** | **87.3** | **84.4** | **93.0** |
+
+## D. References
+
+### D.1. Baselines:
+
+[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.
+
+[2] B. Sun and K. Saenko. Deep CORAL: Correlation alignment for deep domain adaptation. In Gang Hua and Hervé Jégou, editors, Computer Vision – ECCV 2016 Workshops, pages 443–450, Cham, 2016. Springer International Publishing.
+
+[3] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky. Domain-adversarial training of neural networks. J. Mach. Learn. Res., 17(1):2096–2030, Jan. 2016.
+
+[4] E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell. Adversarial discriminative domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2962–2971, 2017.
+
+[5] M. Long, Z. Cao, J. Wang, and M. I. Jordan. Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems 31, pages 1640–1650. Curran Associates, Inc., 2018.
+
+[6] Y. Pan, T. Yao, Y. Li, Y. Wang, C. Ngo, and T. Mei. Transferrable prototypical networks for unsupervised domain adaptation. In CVPR, pages 2234–2242, 2019.
+
+[7] B. B. Damodaran, B. Kellenberger, R. Flamary, D. Tuia, and N. Courty. DeepJDOT: Deep joint distribution optimal transport for unsupervised domain adaptation. In Computer Vision – ECCV 2018, pages 467–483. Springer, 2018.
+
+[8] R. Xu, P. Liu, L. Wang, C. Chen, and J. Wang. Reliable weighted optimal transport for unsupervised domain adaptation. In CVPR 2020, June 2020.
+
+### D.2. GitHub repositories:
+
+- Some parts of our code (e.g., VAT, evaluation, …) are rewritten with modifications from [DIRT-T](https://github.com/RuiShu/dirt-t).
\ No newline at end of file
diff --git a/model/dataLoader.py b/model/dataLoader.py
new file mode 100644
index 0000000..8d1b81f
--- /dev/null
+++ b/model/dataLoader.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2021, Tuan Nguyen.
+# All rights reserved.
+
+import os
+
+import numpy as np
+from scipy.io import loadmat
+from scipy import misc
+import time
+import h5py
+from keras.utils.np_utils import to_categorical
+from generic_utils import random_seed
+import csv
+
+
+def load_mat_office_caltech10_decaf(filename):
+    data = loadmat(filename)
+    x = np.reshape(data['feas'], (-1, 8, 8, 64))
+    y = data['labels'][0]
+    # y = y.reshape(-1) - 1
+    return x, y
+
+
+def load_mat_office31_ResNet50(filename):
+    data = loadmat(filename)
+    # x = np.reshape(data['feas'], (-1, 8, 8, 64))
+    x = data['feas']
+    y = data['labels'][0]
+    return x, y
+
+
+def load_mat_office31_AlexNet(filename):
+    data = loadmat(filename)
+    # x = np.reshape(data['feas'], (-1, 8, 8, 64))
+    x = data['feas']
+    y = data['labels'][0]
+    return x, y
+
+
+def load_office31_resnet50_feature(file_path_train):
+    with open(file_path_train, "r") as file:
+        reader = csv.reader(file)
+        features_full = []
+        labels_full = []
+        for line in reader:
+            feature_i = np.asarray(line[:2048]).astype(np.float32)
+            label_i = int(float(line[2048]))
+            features_full.append(feature_i)
+            labels_full.append(label_i)
+    features_full = np.asarray(features_full)
+    labels_full = np.asarray(labels_full)
+    return features_full, labels_full
+
+
+def load_mat_office_caltech10_ResNet101(filename):
+    data = loadmat(filename)
+    x = data['feas']
+    y = data['labels'][0]
+    return x, y
+
+
+def load_mat_file_single_label(filename):
+    filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32']
+    data = loadmat(filename)
+    x = data['X']
+    y = data['y']
+    if any(fn in filename for fn in filename_list):
+        if 'mnist32_60_10' not in filename and 'mnistg' not in filename:
+            y = y[0]
+        else:
+            y = np.argmax(y, axis=1)
+    # process one-hot label encoder
+    elif len(y.shape) > 1:
+        y = np.argmax(y, axis=1)
+    return x, y
+
+
+def u2t(x):
+    max_num = 50000
+    if len(x) > max_num:
+        y = np.empty_like(x, dtype='float32')
+        for i in range(len(x) // max_num):
+            y[i*max_num: (i+1)*max_num] = (x[i*max_num: (i+1)*max_num].astype('float32') / 255) * 2 - 1
+
+        y[(i + 1) * max_num:] = (x[(i + 1) * max_num:].astype('float32') / 255) * 2 - 1
+    else:
+        y = (x.astype('float32') / 255) * 2 - 1
+    return y
+
+
+class DataLoader:
+    def __init__(self, src_domain='mnistm', trg_domain='mnist', data_path='./dataset', data_format='mat',
+                 shuffle_data=False, dataset_name='digits', cast_data=True):
+        self.num_src_domain = len(src_domain.split(','))
+        self.src_domain_name = src_domain
+        self.trg_domain_name = trg_domain
+        self.data_path = data_path
+        self.data_format = data_format
+        self.shuffle_data = shuffle_data
+        self.dataset_name = dataset_name
+        self.cast_data = cast_data
+
+        self.src_train = {}  # {idx : ['src_idx']['src_name', x_train, y_train]}
+        self.trg_train = {}  # {idx : ['trg_idx']['trg_name', x_train, y_train]}
+        self.src_test = {}
+        self.trg_test = {}
+
+        print("Source domains", self.src_domain_name)
+        print("Target domain", self.trg_domain_name)
+        print("----- Training data -----")
+        self._load_data_train()
+        print("----- Test data -----")
+        self._load_data_test()
+
+        self.data_shape = self.src_train[0][1][0].shape
+        self.num_domain = len(self.src_train.keys())
+        self.num_class = self.src_train[0][2].shape[-1]
+
+    def _load_data_train(self, tail_name="_train"):
+        if not self.src_train:
+            self.src_train = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
+            self.trg_train = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)
+
+    def _load_data_test(self, tail_name="_test"):
+        if not self.src_test:
+            self.src_test = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
+            self.trg_test = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)
+
+    def _load_file(self, name_file='', tail_name="_train", shuffle_data=False):
+        data_list = {}
+        name_file = name_file.split(',')
+        for idx, s_n in enumerate(name_file):
+            file_path_train = os.path.join(self.data_path, '{}{}.{}'.format(s_n, tail_name, self.data_format))
+            # print(file_path_train)
+            if os.path.isfile(file_path_train):
+                if self.dataset_name == 'digits':
+                    x_train, y_train = load_mat_file_single_label(file_path_train)
+                elif self.dataset_name == 'office_caltech10_DECAF_feat':
+                    x_train, y_train = load_mat_office_caltech10_decaf(file_path_train)
+                elif self.dataset_name == 'office_caltech10_ResNet101_feat':
+                    x_train, y_train = load_mat_office_caltech10_ResNet101(file_path_train)
+                elif self.dataset_name == 'office31_AlexNet_feat':
+                    x_train, y_train = load_mat_office31_AlexNet(file_path_train)
+                elif self.dataset_name == 'office31_resnet50_feature':
+                    x_train, y_train = load_office31_resnet50_feature(file_path_train)
+
+                if shuffle_data:
+                    x_train, y_train = self.shuffle(x_train, y_train)
+
+                if 'mnist32_60_10' not in s_n and self.cast_data:
+                    x_train = u2t(x_train)
+                data_list.update({idx: [s_n, x_train, to_categorical(y_train)]})
+            else:
+                raise FileNotFoundError('File not found: {}'.format(file_path_train))
+
+            print(s_n, x_train.shape[0], x_train.min(), x_train.max(), "Label", y_train.min(), y_train.max(), np.unique(y_train))
+        return data_list
+
+    def shuffle(self, x, y=None):
+        np.random.seed(random_seed())
+        idx_train = np.random.permutation(x.shape[0])
+        x = x[idx_train]
+        if y is not None:
+            y = y[idx_train]
+        return x, y
diff --git a/model/generic_utils.py b/model/generic_utils.py
new file mode 100644
index 0000000..5f26644
--- /dev/null
+++ b/model/generic_utils.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2021, Tuan Nguyen.
+# All rights reserved.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import sys
+import six
+import time
+import copy
+import math
+import warnings
+import numpy as np
+from pathlib import Path
+import os
+_RANDOM_SEED = 6789
+
+
+def model_dir():
+    cur_dir = Path(os.path.abspath(__file__))
+    return str(cur_dir.parent.parent)
+
+def data_dir():
+    cur_dir = Path(os.path.abspath(__file__))
+    par_dir = cur_dir.parent.parent
+    return str(par_dir / "datasets")
+
+
+def random_seed():
+    return _RANDOM_SEED
+
+
+def tuid():
+    '''
+    Create a string ID based on the current time
+    :return: a string formatted using the current time
+    '''
+    random_num = np.random.randint(0, 100)
+    return time.strftime('%Y-%m-%d_%H.%M.%S') + str(random_num)
+
+
+def deepcopy(obj):
+    try:
+        return copy.deepcopy(obj)
+    except Exception:
+        warnings.warn("Failed to deepcopy {}".format(obj))
+        return None
+
+
+def make_batches(size, batch_size):
+    '''Returns a list of batch indices (tuples of indices).
+    '''
+    return [(i, min(size, i + batch_size)) for i in range(0, size, batch_size)]
+
+
+def conv_out_size_same(size, stride):
+    return int(math.ceil(float(size) / float(stride)))
+
+
+class Progbar(object):
+    def __init__(self, target, width=30, verbose=1, interval=0.01, show_steps=0):
+        '''Displays a progress bar.
+
+        # Arguments:
+            target: Total number of steps expected.
+            interval: Minimum visual progress update interval (in seconds).
+        '''
+        self.width = width
+        self.target = target
+        self.sum_values = {}
+        self.unique_values = []
+        self.start = time.time()
+        self.last_update = 0
+        self.interval = interval
+        self.total_width = 0
+        self.seen_so_far = 0
+        self.verbose = verbose
+        self.show_steps = show_steps
+        self.unknown = False
+        self.header = ''
+        if self.target <= 0:
+            self.unknown = True
+            self.target = 100
+
+    def update(self, current, values=[], force=False):
+        """
+        Updates the progress bar.
+        # Arguments
+            current: Index of current step.
+ values: List of tuples (name, value_for_last_step). + The progress bar will display averages for these values. + force: Whether to force visual progress update. + """ + if self.unknown: + current = 99 + for k, v in values: + if k not in self.sum_values: + self.sum_values[k] = [v * (current - self.seen_so_far), + current - self.seen_so_far] + self.unique_values.append(k) + else: + self.sum_values[k][0] += v * (current - self.seen_so_far) + self.sum_values[k][1] += (current - self.seen_so_far) + self.seen_so_far = current + + now = time.time() + if self.verbose == 1: + if not force and (now - self.last_update) < self.interval: + return + + prev_total_width = self.total_width + sys.stdout.write('\b' * prev_total_width) + # sys.stdout.write('\r') + + numdigits = int(np.floor(np.log10(self.target))) + 1 + barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) + if self.show_steps > 0: + bar = self.header + '[' + else: + bar = self.header + barstr % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += ('=' * (prog_width - 1)) + if current < self.target: + bar += '>' + else: + bar += '=' + bar += ('.' 
* (self.width - prog_width)) + bar += ']' + sys.stdout.write(bar) + self.total_width = len(bar) + + if current: + time_per_unit = (now - self.start) / current + else: + time_per_unit = 0 + eta = time_per_unit * (self.target - current) + info = '' + if current < self.target: + info += ' - ETA: ' + eta_hours = eta // 3600 + eta_mins = (eta % 3600) // 60 + eta_seconds = eta % 60 + info += ('%dhours ' % eta_hours) if eta_hours > 0 else '' + info += ('%dmins ' % eta_mins) if eta_mins > 0 else '' + info += ('%ds ' % eta_seconds) if eta_seconds > 0 else '' + else: + info += ' - %ds' % (now - self.start) + for k in self.unique_values: + info += ' - %s:' % k + if isinstance(self.sum_values[k], list): + avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) + if abs(avg) > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + else: + info += ' %s' % self.sum_values[k] + + if prev_total_width > self.total_width + len(info): + info += ((prev_total_width - self.total_width - len(info)) * ' ') + self.total_width += len(info) + + sys.stdout.write(info) + sys.stdout.flush() + + if current >= self.target: + sys.stdout.write('\n') + + if self.verbose == 2: + if current >= self.target: + info = '%ds' % (now - self.start) + for k in self.unique_values: + info += ' - %s:' % k + avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) + if avg > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + sys.stdout.write(info + "\n") + + self.last_update = now + + def add(self, n, values=[]): + self.update(self.seen_so_far + n, values) + + +def get_from_module(identifier, module_params, module_name, + instantiate=False, kwargs=None): + if isinstance(identifier, six.string_types): + res = module_params.get(identifier) + if not res: + raise ValueError('Invalid ' + str(module_name) + ': ' + + str(identifier)) + if instantiate and not kwargs: + return res() + elif instantiate and kwargs: + return res(**kwargs) + else: + return res + elif isinstance(identifier, dict): + 
name = identifier.pop('name') + res = module_params.get(name) + if res: + return res(**identifier) + else: + raise ValueError('Invalid ' + str(module_name) + ': ' + + str(identifier)) + return identifier diff --git a/model/layers.py b/model/layers.py new file mode 100644 index 0000000..4bbe159 --- /dev/null +++ b/model/layers.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021, Tuan Nguyen. +# All rights reserved. + +import tensorflow as tf +from tensorflow.contrib.framework import add_arg_scope +# from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two + +@add_arg_scope +def noise(x, std, phase, scope=None, reuse=None): + with tf.name_scope(scope, 'noise'): + eps = tf.random_normal(tf.shape(x), 0.0, std) + output = tf.where(phase, x + eps, x) + return output + + +@add_arg_scope +def leaky_relu(x, a=0.2, name=None): + with tf.name_scope(name, 'leaky_relu'): + return tf.maximum(x, a * x) + +@add_arg_scope +def basic_accuracy(a, b, scope=None): + with tf.name_scope(scope, 'basic_acc'): + a = tf.argmax(a, 1) + b = tf.argmax(b, 1) + eq = tf.cast(tf.equal(a, b), 'float32') + output = tf.reduce_mean(eq) + return output + +@add_arg_scope +def batch_ema_acc(a, b, scope=None): + with tf.name_scope(scope, 'basic_acc'): + a = tf.argmax(a, 1) + b = tf.argmax(b, 1) + output = tf.cast(tf.equal(a, b), 'float32') + return output + +@add_arg_scope +def batch_teac_stud_avg_acc(y_trg_true, y_trg_logit, y_trg_teacher, scope=None): + with tf.name_scope(scope, 'average_acc'): + y_trg_prob = tf.nn.softmax(y_trg_logit) + y_pred_avg = (y_trg_prob + y_trg_teacher) / 2.0 + + y_trg_true = tf.argmax(y_trg_true, 1) + y_pred_avg = tf.argmax(y_pred_avg, 1) + output = tf.cast(tf.equal(y_trg_true, y_pred_avg), 'float32') + return output + +@add_arg_scope +def batch_teac_stud_ent_acc(y_trg_true, y_trg_logit, y_trg_teacher, scope=None): + with tf.name_scope(scope, 'entropy_acc'): + y_trg_prob = tf.nn.softmax(y_trg_logit) + # compute entropy + y_trg_student_ent = 
-tf.reduce_sum(y_trg_prob * tf.log(y_trg_prob), axis=-1) + y_trg_teacher_ent = -tf.reduce_sum(y_trg_teacher * tf.log(y_trg_teacher), axis=-1) + min_entropy = tf.argmin(tf.stack([y_trg_student_ent, y_trg_teacher_ent]), axis=0) + + y_trg_pred_sparse = tf.argmax(y_trg_logit, 1, output_type=tf.int32) + y_trg_teacher_sparse = tf.argmax(y_trg_teacher, 1, output_type=tf.int32) + student_teacher_concat = tf.stack([y_trg_pred_sparse, y_trg_teacher_sparse], axis=1) + + y_pred_entropy_voting = tf.reduce_max(student_teacher_concat * tf.one_hot(min_entropy, 2, dtype=tf.int32), + axis=1) + + y_trg_true = tf.argmax(y_trg_true, 1, output_type=tf.int32) + output = tf.cast(tf.equal(y_trg_true, y_pred_entropy_voting), 'float32') + return output \ No newline at end of file diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..05aaddb --- /dev/null +++ b/model/model.py @@ -0,0 +1,608 @@ +# Copyright (c) 2021, Tuan Nguyen. +# All rights reserved. + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import tensorflow as tf +from tensorflow.contrib.framework import arg_scope +from tensorflow.contrib.framework import add_arg_scope +from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm +from tensorflow.python.ops.nn_impl import sigmoid_cross_entropy_with_logits as sigmoid_x_entropy +from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two + +from generic_utils import random_seed + +from layers import leaky_relu +import os +from generic_utils import model_dir +import numpy as np +import tensorbayes as tb +from layers import batch_ema_acc + + +def build_block(input_layer, layout, info=1): + x = input_layer + for i in range(0, len(layout)): + with tf.variable_scope('l{:d}'.format(i)): + f, f_args, f_kwargs = layout[i] + x = f(x, *f_args, **f_kwargs) + if info > 1: + print(x) + return x + + +@add_arg_scope +def normalize_perturbation(d, scope=None): 
+ with tf.name_scope(scope, 'norm_pert'): + output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape))) + return output + + +def build_encode_template( + input_layer, training_phase, scope, encode_layout, + reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'): + with tf.variable_scope(scope, reuse=reuse, custom_getter=getter): + with arg_scope([leaky_relu], a=0.1), \ + arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \ + arg_scope([batch_norm], internal_update=internal_update): + + preprocess = instance_norm if inorm else tf.identity + + layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size) + output_layer = build_block(input_layer, layout) + + return output_layer + + +def build_decode_template( + input_layer, training_phase, scope, decode_layout, + reuse=None, internal_update=False, getter=None, inorm=False, cnn_size='large'): + with tf.variable_scope(scope, reuse=reuse, custom_getter=getter): + with arg_scope([leaky_relu], a=0.1), \ + arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \ + arg_scope([batch_norm], internal_update=internal_update): + layout = decode_layout(training_phase=training_phase) + output_layer = build_block(input_layer, layout) + + return output_layer + + +def build_class_discriminator_template( + input_layer, training_phase, scope, num_classes, class_discriminator_layout, + reuse=None, internal_update=False, getter=None, cnn_size='large'): + with tf.variable_scope(scope, reuse=reuse, custom_getter=getter): + with arg_scope([leaky_relu], a=0.1), \ + arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \ + arg_scope([batch_norm], internal_update=internal_update): + layout = class_discriminator_layout(num_classes=num_classes, global_pool=True, activation=None, + cnn_size=cnn_size) + output_layer = build_block(input_layer, layout) + + return output_layer + + +def 
build_domain_discriminator_template(x, domain_layout, c=1, reuse=None): + with tf.variable_scope('domain_disc', reuse=reuse): + with arg_scope([dense], activation=tf.nn.relu): + layout = domain_layout(c=c) + output_layer = build_block(x, layout) + + return output_layer + + +def get_default_config(): + tf_config = tf.ConfigProto() + tf_config.gpu_options.allow_growth = True + tf_config.log_device_placement = False + tf_config.allow_soft_placement = True + return tf_config + + +class LAMDA(): + def __init__(self, + model_name="LAMDA-results", + learning_rate=0.001, + batch_size=128, + num_iters=80000, + summary_freq=400, + src_class_trade_off=1.0, + src_vat_trade_off=1.0, + trg_trade_off=1.0, + domain_trade_off=1.0, + adapt_domain_trade_off=False, + encode_layout=None, + decode_layout=None, + classify_layout=None, + domain_layout=None, + freq_calc_metrics=10, + init_calc_metrics=2, + current_time='', + inorm=True, + m_on_D_trade_off=1.0, + m_plus_1_on_D_trade_off=1.0, + m_plus_1_on_G_trade_off=1.0, + m_on_G_trade_off=0.1, + lamda_model_id='', + save_grads=False, + only_save_final_model=True, + cnn_size='large', + update_target_loss=True, + sample_size=50, + src_recons_trade_off=0.1, + **kwargs): + self.model_name = model_name + self.batch_size = batch_size + self.learning_rate = learning_rate + self.num_iters = num_iters + self.summary_freq = summary_freq + self.src_class_trade_off = src_class_trade_off + self.src_vat_trade_off = src_vat_trade_off + self.trg_trade_off = trg_trade_off + self.domain_trade_off = domain_trade_off + self.adapt_domain_trade_off = adapt_domain_trade_off + + self.encode_layout = encode_layout + self.decode_layout = decode_layout + self.classify_layout = classify_layout + self.domain_layout = domain_layout + + self.freq_calc_metrics = freq_calc_metrics + self.init_calc_metrics = init_calc_metrics + + self.current_time = current_time + self.inorm = inorm + + self.m_on_D_trade_off = m_on_D_trade_off + self.m_plus_1_on_D_trade_off = 
m_plus_1_on_D_trade_off + self.m_plus_1_on_G_trade_off = m_plus_1_on_G_trade_off + self.m_on_G_trade_off = m_on_G_trade_off + + self.lamda_model_id = lamda_model_id + + self.save_grads = save_grads + self.only_save_final_model = only_save_final_model + + self.cnn_size = cnn_size + self.update_target_loss = update_target_loss + + self.sample_size = sample_size + self.src_recons_trade_off = src_recons_trade_off + + + def _init(self, data_loader): + np.random.seed(random_seed()) + tf.set_random_seed(random_seed()) + tf.reset_default_graph() + + self.tf_graph = tf.get_default_graph() + self.tf_config = get_default_config() + self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph) + + self.data_loader = data_loader + self.num_classes = self.data_loader.num_class + self.batch_size_src = self.sample_size*self.num_classes + + def _get_variables(self, list_scopes): + variables = [] + for scope_name in list_scopes: + variables.append(tf.get_collection('trainable_variables', scope_name)) + return variables + + def convert_one_hot(self, y): + y_idx = y.reshape(-1).astype(int) if y is not None else None + y = np.eye(self.num_classes)[y_idx] if y is not None else None + return y + + def _get_scope(self, part_name, side_name, same_network=True): + suffix = '' + if not same_network: + suffix = '/' + side_name + return part_name + suffix + + def _get_primary_scopes(self): + return ['generator', 'classifier', 'decode'] + + def _get_secondary_scopes(self): + return ['domain_disc'] + + def _build_source_middle(self, x_src): + scope_name = self._get_scope('generator', 'src') + return build_encode_template(x_src, encode_layout=self.encode_layout, + scope=scope_name, training_phase=self.is_training, inorm=self.inorm, cnn_size=self.cnn_size) + + def _build_middle_source(self, x_src_mid): + scope_name = self._get_scope('decode', 'src') + return build_decode_template( + x_src_mid, decode_layout=self.decode_layout, scope=scope_name, training_phase=self.is_training, 
inorm=self.inorm, cnn_size=self.cnn_size + ) + + def _build_target_middle(self, x_trg): + scope_name = self._get_scope('generator', 'trg') + return build_encode_template( + x_trg, encode_layout=self.encode_layout, + scope=scope_name, training_phase=self.is_training, inorm=self.inorm, + reuse=True, internal_update=True, cnn_size=self.cnn_size + ) # reuse the 'encode_layout' + + def _build_classifier(self, x, num_classes, ema=None, is_teacher=False): + g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False) + g_x = build_encode_template( + x, encode_layout=self.encode_layout, + scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm, + reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema), + cnn_size=self.cnn_size + ) + + h_teacher_scope = self._get_scope('classifier', 'teacher', same_network=False) + h_g_x = build_class_discriminator_template( + g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'classifier', num_classes=num_classes, + reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout, + getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size + ) + return h_g_x + + def _build_domain_discriminator(self, x_mid, reuse=False): + return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout, c=self.num_classes+1, reuse=reuse) + + def _build_class_src_discriminator(self, x_src, num_src_classes): + return build_class_discriminator_template( + self.x_src_mid, training_phase=self.is_training, scope='classifier', num_classes=num_src_classes, + class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size + ) + + def _build_class_trg_discriminator(self, x_trg, num_trg_classes): + return build_class_discriminator_template( + self.x_trg_mid, training_phase=self.is_training, scope='classifier', num_classes=num_trg_classes, + reuse=True, internal_update=True, 
class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size + ) + + def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout, + pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None): + with tf.name_scope(scope, 'perturb_image'): + eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x))) + + # Predict on randomly perturbed image + x_eps_mid = build_encode_template( + x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True, + inorm=self.inorm, cnn_size=self.cnn_size) + x_eps_pred = build_class_discriminator_template( + x_eps_mid, class_discriminator_layout=class_discriminator_layout, + training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes, + cnn_size=self.cnn_size + ) + # eps_p = classifier(x + eps, phase=True, reuse=True) + loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred) + + # Based on perturbed image, get direction of greatest error + eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0] + + # Use that direction as adversarial perturbation + eps_adv = normalize_perturbation(eps_adv) + x_adv = tf.stop_gradient(x + radius * eps_adv) + + return x_adv + + def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout, + scope=None, scope_classify=None, scope_encode=None, training_phase=None): + + with tf.name_scope(scope, 'smoothing_loss'): + x_adv = self.perturb_image( + x, p, num_classes, class_discriminator_layout=class_discriminator_layout, encode_layout=encode_layout, + scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase) + + x_adv_mid = build_encode_template( + x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, inorm=self.inorm, + reuse=True, cnn_size=self.cnn_size) + x_adv_pred = build_class_discriminator_template( + x_adv_mid, training_phase=training_phase, scope=scope_classify, 
reuse=True, num_classes=num_classes, + class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size + ) + # p_adv = classifier(x_adv, phase=True, reuse=True) + loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred)) + + return loss + + def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None): + return self.vat_loss( # compute the divergence between C(x) and C(G(x+r)) + x, p, num_classes, + class_discriminator_layout=self.classify_layout, + encode_layout=self.encode_layout, + scope=scope, scope_classify=scope_classify, scope_encode=scope_encode, + training_phase=self.is_training + ) + + def _build_model(self): + self.x_src = tf.placeholder(dtype=tf.float32, shape=(None, 2048)) + self.x_trg = tf.placeholder(dtype=tf.float32, shape=(None, 2048)) + + self.y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes)) + self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes)) + + T = tb.utils.TensorDict(dict( + x_tmp=tf.placeholder(dtype=tf.float32, shape=(None, 2048)), + y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes)) + )) + + self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training') + + self.x_src_mid = self._build_source_middle(self.x_src) + self.x_src_prime = self._build_middle_source(self.x_src_mid) + self.x_trg_mid = self._build_target_middle(self.x_trg) + + self.x_fr_src = self._build_domain_discriminator(self.x_src_mid) + self.x_fr_trg = self._build_domain_discriminator(self.x_trg_mid, reuse=True) + + # use m units of D(G(x_s)) for classification on joint space + self.m_src_on_D_logit = tf.gather(self.x_fr_src, tf.range(0, self.num_classes, dtype=tf.int32), axis=1) + self.loss_m_src_on_D = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y_src, + logits=self.m_src_on_D_logit)) + + # maximize log likelihood of target data and minimize that of source data on 11th class + 
+        self.m_plus_1_src_logit_on_D = tf.gather(self.x_fr_src, tf.range(self.num_classes, self.num_classes + 1,
+                                                                         dtype=tf.int32), axis=1)
+        self.m_plus_1_trg_logit_on_D = tf.gather(self.x_fr_trg, tf.range(self.num_classes, self.num_classes + 1,
+                                                                         dtype=tf.int32), axis=1)
+
+        self.loss_m_plus_1_on_D = 0.5 * tf.reduce_mean(
+            sigmoid_x_entropy(labels=tf.ones_like(self.m_plus_1_trg_logit_on_D),
+                              logits=self.m_plus_1_trg_logit_on_D) +
+            sigmoid_x_entropy(labels=tf.zeros_like(self.m_plus_1_src_logit_on_D),
+                              logits=self.m_plus_1_src_logit_on_D))
+
+        self.loss_disc = self.m_on_D_trade_off * self.loss_m_src_on_D + \
+                         self.m_plus_1_on_D_trade_off * self.loss_m_plus_1_on_D
+
+        self.y_src_logit = self._build_class_src_discriminator(self.x_src_mid, self.num_classes)
+        self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid, self.num_classes)
+
+        self.y_src_pred = tf.argmax(self.y_src_logit, 1, output_type=tf.int32)
+        self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
+        self.y_src_sparse = tf.argmax(self.y_src, 1, output_type=tf.int32)
+        self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
+
+        ###############################
+        # classification loss
+        self.src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
+            logits=self.y_src_logit, labels=self.y_src)  # (batch_size,)
+        self.src_loss_class = tf.reduce_mean(self.src_loss_class_detail)  # scalar
+
+        self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
+            logits=self.y_trg_logit, labels=self.y_trg)
+        self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail)  # only used for testing
+
+        self.src_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_src_sparse, self.y_src_pred), 'float32'))
+        self.trg_accuracy_batch = tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32')
+        self.trg_accuracy = tf.reduce_mean(self.trg_accuracy_batch)
+
+        #############################
+        # generator loss
+        self.loss_m_plus_1_on_G = 0.5 * tf.reduce_mean(
+            sigmoid_x_entropy(labels=tf.zeros_like(self.m_plus_1_trg_logit_on_D),
+                              logits=self.m_plus_1_trg_logit_on_D) +
+            sigmoid_x_entropy(labels=tf.ones_like(self.m_plus_1_src_logit_on_D),
+                              logits=self.m_plus_1_src_logit_on_D))
+
+        self.A_m = self.y_trg_logit
+        self.m_trg_on_D_logit = tf.gather(self.x_fr_trg, tf.range(0, self.num_classes, dtype=tf.int32), axis=1)
+
+        self.loss_m_trg_on_G = tf.reduce_mean(
+            softmax_x_entropy_two(logits=self.m_trg_on_D_logit, labels=self.A_m))
+
+        self.loss_generator = self.m_plus_1_on_G_trade_off * self.loss_m_plus_1_on_G + \
+                              self.m_on_G_trade_off * self.loss_m_trg_on_G
+
+        #############################
+        # vat loss
+        self.src_loss_vat = self._build_vat_loss(
+            self.x_src, self.y_src_logit, self.num_classes,
+            scope_encode=self._get_scope('generator', 'src'), scope_classify='classifier'
+        )
+        self.trg_loss_vat = self._build_vat_loss(
+            self.x_trg, self.y_trg_logit, self.num_classes,
+            scope_encode=self._get_scope('generator', 'trg'), scope_classify='classifier'
+        )
+
+        #############################
+        # conditional entropy loss w.r.t. the target distribution
+        self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(
+            labels=self.y_trg_logit, logits=self.y_trg_logit))
+
+        #############################
+        # construct the primary loss
+        if self.adapt_domain_trade_off:
+            self.domain_trade_off_ph = tf.placeholder(dtype=tf.float32)
+        lst_primary_losses = [
+            (self.src_class_trade_off, self.src_loss_class),
+            (self.domain_trade_off, self.loss_generator),
+            (self.src_vat_trade_off, self.src_loss_vat),
+            (self.trg_trade_off, self.trg_loss_vat),
+            (self.trg_trade_off, self.trg_loss_cond_entropy)
+        ]
+        self.primary_loss = tf.constant(0.0)
+        for trade_off, loss in lst_primary_losses:
+            if trade_off != 0:
+                self.primary_loss += trade_off * loss
+
+        primary_variables = self._get_variables(self._get_primary_scopes())
+
+        # Evaluation (EMA)
+        ema = tf.train.ExponentialMovingAverage(decay=0.998)
+        var_list_for_ema = primary_variables[0] + primary_variables[1]
+        ema_op = ema.apply(var_list=var_list_for_ema)
+        self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)
+
+        # Accuracies
+        self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
+        self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)
+
+        self.train_main = \
+            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.primary_loss, var_list=primary_variables)
+
+        self.primary_train_op = tf.group(self.train_main, ema_op)
+
+        if self.save_grads:
+            self.grads_wrt_primary_loss = tf.train.AdamOptimizer(self.learning_rate, 0.5).compute_gradients(
+                self.primary_loss, var_list=primary_variables)
+
+        #############################
+        # construct the secondary loss
+        secondary_variables = self._get_variables(self._get_secondary_scopes())
+        self.secondary_train_op = \
+            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.loss_disc, var_list=secondary_variables)
+
+        #############################
+        # construct one more target loss
+        if self.update_target_loss:
+            self.target_loss = self.trg_trade_off * (self.trg_loss_vat + self.trg_loss_cond_entropy)
+
+            self.target_train_op = \
+                tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.target_loss,
+                                                                         var_list=primary_variables)
+
+        if self.save_grads:
+            self.grads_wrt_secondary_loss = tf.train.AdamOptimizer(self.learning_rate, 0.5).compute_gradients(
+                self.loss_disc, var_list=secondary_variables)
+
+        ############################
+        # summaries
+        tf.summary.scalar('domain/loss_disc', self.loss_disc)
+        tf.summary.scalar('domain/loss_disc/loss_m_src_on_D', self.loss_m_src_on_D)
+        tf.summary.scalar('domain/loss_disc/loss_m_plus_1_on_D', self.loss_m_plus_1_on_D)
+
+        tf.summary.scalar('primary_loss/src_loss_class', self.src_loss_class)
+        tf.summary.scalar('primary_loss/loss_generator', self.loss_generator)
+        tf.summary.scalar('primary_loss/loss_generator/loss_m_plus_1_on_G', self.loss_m_plus_1_on_G)
+        tf.summary.scalar('primary_loss/loss_generator/loss_m_trg_on_G', self.loss_m_trg_on_G)
+
+        tf.summary.scalar('acc/src_acc', self.src_accuracy)
+        tf.summary.scalar('acc/trg_acc', self.trg_accuracy)
+
+        tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
+        tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
+        tf.summary.scalar('hyperparameters/domain_trade_off',
+                          self.domain_trade_off_ph if self.adapt_domain_trade_off else self.domain_trade_off)
+
+        self.tf_merged_summaries = tf.summary.merge_all()
+
+        if self.save_grads:
+            with tf.name_scope("visualize"):
+                for var in tf.trainable_variables():
+                    tf.summary.histogram(var.op.name + '/values', var)
+                for grad, var in self.grads_wrt_primary_loss:
+                    if grad is not None:
+                        tf.summary.histogram(var.op.name + '/grads_wrt_primary_loss', grad)
+                for grad, var in self.grads_wrt_secondary_loss:
+                    if grad is not None:
+                        tf.summary.histogram(var.op.name + '/grads_wrt_secondary_loss', grad)
+
+    def _fit_loop(self):
+        print('Start training LAMDA at', os.path.basename(__file__))
+        print('============ LOG-ID: %s ============' % self.current_time)
+
+        self.tf_session.run(tf.global_variables_initializer())
+
+        num_src_samples = self.data_loader.src_train[0][2].shape[0]
+        num_trg_samples = self.data_loader.trg_train[0][2].shape[0]
+
+        with self.tf_graph.as_default():
+            saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
+
+        self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
+                                            "{}".format(self.lamda_model_id))
+        check_point = tf.train.get_checkpoint_state(self.checkpoint_path)
+
+        if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
+            print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
+            saver.restore(self.tf_session, check_point.model_checkpoint_path)
+
+        for it in range(self.num_iters):
+            idx_src_samples = np.random.permutation(num_src_samples)[:self.batch_size]
+            idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
+
+            feed_data = dict()
+            feed_data[self.x_src] = self.data_loader.src_train[0][1][idx_src_samples, :]
+            feed_data[self.y_src] = self.data_loader.src_train[0][2][idx_src_samples]
+
+            feed_data[self.x_trg] = self.data_loader.trg_train[0][1][idx_trg_samples, :]
+            feed_data[self.y_trg] = self.data_loader.trg_train[0][2][idx_trg_samples]
+            feed_data[self.is_training] = True
+
+            _, loss_disc = \
+                self.tf_session.run(
+                    [self.secondary_train_op, self.loss_disc],
+                    feed_dict=feed_data
+                )
+
+            _, src_loss_class, loss_generator, trg_loss_class, src_acc, trg_acc = \
+                self.tf_session.run(
+                    [self.primary_train_op, self.src_loss_class, self.loss_generator,
+                     self.trg_loss_class, self.src_accuracy, self.trg_accuracy],
+                    feed_dict=feed_data
+                )
+
+            if it == 0 or (it + 1) % self.summary_freq == 0:
+                print("iter %d/%d loss_disc %.3f; src_loss_class %.5f; loss_generator %.3f\n"
+                      "src_acc %.2f" % (it + 1, self.num_iters, loss_disc, src_loss_class,
+                                        loss_generator, src_acc * 100))
+
+            if (it + 1) % self.summary_freq == 0:
+                if not self.only_save_final_model:
+                    self.save_trained_model(saver, it + 1)
+                elif it + 1 == self.num_iters:
+                    self.save_trained_model(saver, it + 1)
+
+                # Save acc values
+                self.save_value(step=it + 1)
+
+    def save_trained_model(self, saver, step):
+        # Save the model
+        checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
+                                       "{}".format(self.current_time))
+        checkpoint_path = os.path.join(checkpoint_path, "lamda_" + self.current_time + ".ckpt")
+
+        directory = os.path.dirname(checkpoint_path)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+        saver.save(self.tf_session, checkpoint_path, global_step=step)
+
+    def save_value(self, step):
+        # Save the EMA accuracy
+        acc_trg_test_ema, summary_trg_test_ema = self.compute_value(self.fn_batch_ema_acc, 'test/trg_test_ema',
+                                                                    x_full=self.data_loader.trg_test[0][1],
+                                                                    y=self.data_loader.trg_test[0][2], labeler=None)
+        print_list = ['trg_test_ema', round(acc_trg_test_ema * 100, 2)]
+        print(print_list)
+
+    def compute_value(self, fn_batch_ema_acc, tag, x_full, y, labeler, full=True):
+        with tb.nputils.FixedSeed(0):
+            shuffle = np.random.permutation(len(x_full))
+
+        xs = x_full[shuffle]
+        ys = y[shuffle] if y is not None else None
+
+        if not full:
+            xs = xs[:1000]
+            ys = ys[:1000] if ys is not None else None
+
+        n = len(xs)
+        bs = 200
+
+        acc_full = np.ones(n, dtype=float)
+
+        for i in range(0, n, bs):
+            x = xs[i:i + bs]
+            y = ys[i:i + bs] if ys is not None else labeler(x)
+            acc_batch = fn_batch_ema_acc(x, y)
+            acc_full[i:i + bs] = acc_batch
+
+        acc = np.mean(acc_full)
+
+        summary = tf.Summary.Value(tag=tag, simple_value=acc)
+        summary = tf.Summary(value=[summary])
+        return acc, summary
diff --git a/model/run_lamda.py b/model/run_lamda.py
new file mode 100644
index 0000000..c9fdae2
--- /dev/null
+++ b/model/run_lamda.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2021, Tuan Nguyen.
+# All rights reserved.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from model import LAMDA
+
+from layers import noise
+from test_da_template_lamda import main_func, resolve_conflict_params
+
+from tensorflow.python.layers.core import dropout
+from tensorbayes.layers import dense, conv2d, avg_pool, max_pool
+
+import warnings
+import os
+from generic_utils import tuid, model_dir
+import signal
+import sys
+import time
+import datetime
+from pprint import pprint
+
+choice_default = 1
+warnings.simplefilter("ignore", category=DeprecationWarning)
+
+model_name = "LAMDA-results"
+current_time = tuid()
+
+
+# generator
+def encode_layout(preprocess, training_phase=True, cnn_size='large'):
+    layout = []
+    if cnn_size == 'small':
+        layout = [
+            (dense, (256,), {}),
+            (dropout, (), dict(training=training_phase)),
+            (noise, (1,), dict(phase=training_phase)),
+        ]
+    elif cnn_size == 'large':
+        layout = [
+            (preprocess, (), {}),
+            (conv2d, (96, 3, 1), {}),
+            (conv2d, (96, 3, 1), {}),
+            (conv2d, (96, 3, 1), {}),
+            (max_pool, (2, 2), {}),
+            (dropout, (), dict(training=training_phase)),
+            (noise, (1,), dict(phase=training_phase)),
+            (conv2d, (192, 3, 1), {}),
+            (conv2d, (192, 3, 1), {}),
+            (conv2d, (192, 3, 1), {}),
+            (max_pool, (2, 2), {}),
+            (dropout, (), dict(training=training_phase)),
+            (noise, (1,), dict(phase=training_phase)),
+        ]
+    return layout
+
+
+def decode_layout(training_phase=True):
+    layout = [
+        (dense, (2048,), {}),
+        (dropout, (), dict(training=training_phase)),
+        (noise, (1,), dict(phase=training_phase)),
+    ]
+    return layout
+
+
+# classifier
+def class_discriminator_layout(num_classes=None, global_pool=True, activation=None, cnn_size='large'):
+    layout = []
+    if cnn_size == 'small':
+        layout = [
+            (dense, (num_classes,), dict(activation=activation))
+        ]
+    elif cnn_size == 'large':
+        layout = [
+            (conv2d, (192, 3, 1), {}),
+            (conv2d, (192, 3, 1), {}),
+            (conv2d, (192, 3, 1), {}),
+            (avg_pool, (), dict(global_pool=global_pool)),
+            (dense, (num_classes,), dict(activation=activation))
+        ]
+    return layout
+
+
+# discriminator
+def domain_layout(c):
+    layout = [
+        (dense, (c,), dict(activation=None))
+    ]
+    return layout
+
+
+def create_obj_func(params):
+    if len(sys.argv) > 1:
+        my_choice = int(sys.argv[1])
+    else:
+        my_choice = choice_default
+    if my_choice == 0:
+        default_params = {
+        }
+    else:
+        default_params = {
+            'batch_size': 128,
+            'learning_rate': 1e-4,
+            'num_iters': 80000,
+            'src_class_trade_off': 1.0,
+            'domain_trade_off': 10.0,
+            'src_vat_trade_off': 1.0,
+            'trg_trade_off': 1.0,
+            'm_on_D_trade_off': 1.0,
+            'm_plus_1_on_D_trade_off': 1.0,
+            'm_plus_1_on_G_trade_off': 1.0,
+            'm_on_G_trade_off': 0.1,
+            'src_recons_trade_off': 0.0,
+            'lamda_model_id': '',
+            'classify_layout': class_discriminator_layout,
+            'encode_layout': encode_layout,
+            'decode_layout': decode_layout,
+            'domain_layout': domain_layout,
+            'freq_calc_metrics': 10,
+            'init_calc_metrics': -1,
+            'log_path': os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
+            'summary_freq': 400,
+            'current_time': current_time,
+            'inorm': True,
+            'save_grads': False,
+            'cast_data': False,
+            'only_save_final_model': True,
+            'cnn_size': 'large',
+            'update_target_loss': True,
+            'data_augmentation': False,
+        }
+
+    default_params = resolve_conflict_params(params, default_params)
+
+    print('Default parameters:')
+    pprint(default_params)
+
+    learner = LAMDA(
+        **params,
+        **default_params,
+    )
+    return learner
+
+
+def main_test(run_exp=False):
+    params_gridsearch = {
+        'learning_rate': [1e-3, 1e-2],
+    }
+    attribute_names = (
+        'learning_rate', 'same_network', 'src_class_trade_off', 'trg_trade_off',
+        'src_vat_trade_off', 'domain_trade_off', 'adapt_domain_trade_off', 'num_iters', 'model_name')
+
+    main_func(
+        create_obj_func,
+        choice_default=choice_default,
+        src_name_default='mnist32_60_10',
+        trg_name_default='mnistm32_60_10',
+        params_gridsearch=params_gridsearch,
+        attribute_names=attribute_names,
+        num_workers=2,
+        file_config=None,
+        run_exp=run_exp,
+        freq_predict_display=10,
+        summary_freq=100,
+        current_time=current_time,
+        log_path=os.path.join(model_dir(), model_name, "logs", "{}".format(current_time))
+    )
+
+
+class Logger(object):
+    def __init__(self):
+        self.terminal = sys.stdout
+        self.console_log_path = os.path.join(model_dir(), model_name, "console_output",
+                                             "{}.txt".format(current_time))
+        if not os.path.exists(os.path.dirname(self.console_log_path)):
+            os.makedirs(os.path.dirname(self.console_log_path))
+        self.log = open(self.console_log_path, 'a')
+        signal.signal(signal.SIGINT, self.signal_handler)
+
+    def signal_handler(self, sig, frame):
+        print('You pressed Ctrl+C.')
+        self.log.close()
+
+        # Remove the log file
+        os.remove(self.console_log_path)
+        print('Removed console_output file')
+        sys.exit(0)
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        # this flush method is needed for python 3 compatibility.
+        # it handles the flush command by doing nothing.
+        # you might want to specify some extra behavior here.
+        pass
+
+
+if __name__ == '__main__':
+    sys.stdout = Logger()
+    start_time = time.time()
+    print('Running {} ...'.format(os.path.basename(__file__)))
+    main_test(run_exp=True)
+    training_time = time.time() - start_time
+    print('Total time: %s' % str(datetime.timedelta(seconds=training_time)))
+    print("============ LOG-ID: %s ============" % current_time)
diff --git a/model/test_da_template_lamda.py b/model/test_da_template_lamda.py
new file mode 100644
index 0000000..cca9498
--- /dev/null
+++ b/model/test_da_template_lamda.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2021, Tuan Nguyen.
+# All rights reserved.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from generic_utils import random_seed
+from generic_utils import data_dir
+from dataLoader import DataLoader
+
+
+def test_real_dataset(create_obj_func, src_name=None, trg_name=None, show=False, block_figure_on_end=False):
+    print('Running {} ...'.format(os.path.basename(__file__)))
+
+    if src_name is None:
+        if len(sys.argv) > 2:
+            src_name = sys.argv[2]
+        else:
+            raise Exception('Source dataset not specified')
+    if trg_name is None:
+        if len(sys.argv) > 3:
+            trg_name = sys.argv[3]
+        else:
+            raise Exception('Target dataset not specified')
+
+    np.random.seed(random_seed())
+    tf.set_random_seed(random_seed())
+    tf.reset_default_graph()
+
+    print("========== Test on real data ==========")
+
+    users_params = dict()
+    users_params = parse_arguments(users_params)
+
+    data_format = 'mat'
+
+    if 'format' in users_params:
+        data_format, users_params = extract_param('format', data_format, users_params)
+
+    if len(users_params['data_path']) == 0:
+        data_path = data_dir()
+    else:
+        data_path = users_params['data_path']
+
+    data_loader = DataLoader(src_domain=src_name,
+                             trg_domain=trg_name,
+                             data_path=data_path,
+                             data_format=data_format,
+                             dataset_name='office31_resnet50_feature',
+                             cast_data=users_params['cast_data'])
+
+    assert users_params['batch_size'] % data_loader.num_src_domain == 0
+
+    print('users_params:', users_params)
+
+    learner = create_obj_func(users_params)
+    learner.dim_src = data_loader.data_shape
+    learner.dim_trg = data_loader.data_shape
+
+    learner.x_trg_test = data_loader.trg_test[0][0]
+    learner.y_trg_test = data_loader.trg_test[0][1]
+
+    print("dim_src: (%d)" % learner.dim_src[0])
+    print("dim_trg: (%d)" % learner.dim_trg[0])
+
+    learner._init(data_loader)
+    learner._build_model()
+    learner._fit_loop()
+
+
+def main_func(
+        create_obj_func,
+        choice_default=0,
+        src_name_default='svmguide1',
+        trg_name_default='svmguide1',
+        params_gridsearch=None,
+        attribute_names=None,
+        num_workers=4,
+        file_config=None,
+        run_exp=False,
+        keep_vars=[],
+        **kwargs):
+
+    if not run_exp:
+        choice_lst = [0, 1, 2]
+        src_name = src_name_default
+        trg_name = trg_name_default
+    elif len(sys.argv) > 1:
+        choice_lst = [int(sys.argv[1])]
+        src_name = None
+        trg_name = None
+    else:
+        choice_lst = [choice_default]
+        src_name = src_name_default
+        trg_name = trg_name_default
+
+    for choice in choice_lst:
+        if choice == 0:
+            pass  # add another function here
+        elif choice == 1:
+            test_real_dataset(create_obj_func, src_name, trg_name, show=False, block_figure_on_end=run_exp)
+
+
+def parse_arguments(params, as_array=False):
+    for it in range(4, len(sys.argv), 2):
+        params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array)
+    return params
+
+
+def parse_argument(string, as_array=False):
+    try:
+        result = int(string)
+    except ValueError:
+        try:
+            result = float(string)
+        except ValueError:
+            if str.lower(string) == 'true':
+                result = True
+            elif str.lower(string) == 'false':
+                result = False
+            elif string == "[]":
+                return []
+            elif ('|' in string) and ('[' in string) and (']' in string):
+                result = [float(item) for item in string[1:-1].split('|')]
+                return result
+            elif (',' in string) and ('(' in string) and (')' in string):
+                split = string[1:-1].split(',')
+                result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3]))
+                return result
+            else:
+                result = string
+
+    return [result] if as_array else result
+
+
+def resolve_conflict_params(primary_params, secondary_params):
+    for key in primary_params.keys():
+        if key in secondary_params.keys():
+            del secondary_params[key]
+    return secondary_params
+
+
+def extract_param(key, value, params_gridsearch, scalar=False):
+    if key in params_gridsearch.keys():
+        value = params_gridsearch[key]
+        del params_gridsearch[key]
+    if scalar and (value is not None):
+        value = value[0]
+    return value, params_gridsearch
+
+
+def dict2string(params):
+    result = ''
+    for key, value in params.items():
+        if type(value) is np.ndarray:
+            if value.size < 16:
+                result += key + ': ' + '|'.join('{0:.4f}'.format(x) for x in value.ravel()) + ', '
+        else:
+            result += key + ': ' + str(value) + ', '
+    return '{' + result[:-2] + '}'
+
+
+def u2t(x):
+    """Convert uint8 to [-1, 1] float.
+    """
+    return x.astype('float32') / 255 * 2 - 1
diff --git a/tensorbayes.tar b/tensorbayes.tar
new file mode 100644
index 0000000..e1a24cf
Binary files /dev/null and b/tensorbayes.tar differ
diff --git a/tf1.9py3.5.yml b/tf1.9py3.5.yml
new file mode 100644
index 0000000..7913fee
--- /dev/null
+++ b/tf1.9py3.5.yml
@@ -0,0 +1,109 @@
+name: tf1.9py3.5
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - asn1crypto=1.4.0=py_0
+  - blas=1.0=mkl
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2020.7.22=0
+  - cairo=1.14.12=h8948797_3
+  - certifi=2018.8.24=py35_1
+  - cffi=1.11.5=py35he75722e_1
+  - chardet=3.0.4=py35_1
+  - cryptography=2.3.1=py35hc365091_0
+  - cycler=0.10.0=py35hc4d5149_0
+  - dbus=1.13.16=hb2f20db_0
+  - dill=0.3.2=py_0
+  - expat=2.2.9=he6710b0_2
+  - ffmpeg=4.0=hcdf2ecd_0
+  - fontconfig=2.13.0=h9420a91_0
+  - freeglut=3.0.0=hf484d3e_5
+  - freetype=2.10.2=h5ab3b9f_0
+  - glib=2.63.1=h5a9c865_0
+  - graphite2=1.3.14=h23475e2_0
+  - gst-plugins-base=1.14.0=hbbd80ab_1
+  - gstreamer=1.14.0=hb453b48_1
+  - harfbuzz=1.8.8=hffaf4a1_0
+  - hdf5=1.10.2=hba1933b_1
+  - icu=58.2=he6710b0_3
+  - idna=2.10=py_0
+  - intel-openmp=2019.4=243
+  - jasper=2.0.14=h07fcdf6_1
+  - jpeg=9b=h024ee3a_2
+  - kiwisolver=1.0.1=py35hf484d3e_0
+  - libedit=3.1.20191231=h14c3975_1
+  - libffi=3.2.1=hd88cf55_4
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libglu=9.0.0=hf484d3e_1
+  - libopencv=3.4.2=hb342d67_1
+  - libopus=1.3.1=h7b6447c_0
+  - libpng=1.6.37=hbc83047_0
+  - libprotobuf=3.6.0=hdbcaa40_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - libtiff=4.1.0=h2733197_1
+  - libuuid=1.0.3=h1bed415_2
+  - libvpx=1.7.0=h439df22_0
+  - libxcb=1.14=h7b6447c_0
+  - libxml2=2.9.10=he19cac6_1
+  - lz4-c=1.9.2=he6710b0_1
+  - matplotlib=3.0.0=py35h5429711_0
+  - mkl=2018.0.3=1
+  - mkl_fft=1.0.6=py35h7dd41cf_0
+  - mkl_random=1.0.1=py35h4414c95_1
+  - ncurses=6.2=he6710b0_1
+  - numpy=1.15.2=py35h1d66e8a_0
+  - opencv=3.4.2=py35h6fd60c2_1
+  - openssl=1.0.2u=h7b6447c_0
+  - pandas=0.22.0=py35hf484d3e_0
+  - pcre=8.44=he6710b0_0
+  - pip=10.0.1=py35_0
+  - pixman=0.40.0=h7b6447c_0
+  - protobuf=3.6.0=py35hf484d3e_0
+  - py-opencv=3.4.2=py35hb342d67_1
+  - pycparser=2.20=py_2
+  - pyopenssl=18.0.0=py35_0
+  - pyparsing=2.4.7=py_0
+  - pyqt=5.9.2=py35h05f1152_2
+  - pysocks=1.6.8=py35_0
+  - python=3.5.6=hc3d631a_0
+  - python-dateutil=2.8.1=py_0
+  - pytz=2020.1=py_0
+  - qt=5.9.6=h8703b6f_2
+  - readline=7.0=h7b6447c_5
+  - requests=2.24.0=py_0
+  - scikit-learn=0.20.0=py35h4989274_1
+  - sip=4.19.8=py35hf484d3e_0
+  - six=1.15.0=py_0
+  - sqlite=3.33.0=h62c20be_0
+  - tbb=2020.2=hfd86e86_0
+  - tbb4py=2018.0.5=py35h6bb024c_0
+  - tk=8.6.10=hbc83047_0
+  - tornado=5.1.1=py35h7b6447c_0
+  - urllib3=1.23=py35_0
+  - wheel=0.35.1=py_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.4.5=h9ceee32_0
+  - pip:
+    - absl-py==0.10.0
+    - astor==0.8.1
+    - gast==0.4.0
+    - grpcio==1.31.0
+    - h5py==2.10.0
+    - importlib-metadata==1.7.0
+    - keras==2.2.5
+    - keras-applications==1.0.8
+    - keras-preprocessing==1.1.2
+    - markdown==3.2.2
+    - pyyaml==5.3.1
+    - scipy==1.4.1
+    - setuptools==39.1.0
+    - tensorbayes==0.4.0
+    - tensorboard==1.9.0
+    - tensorflow-gpu==1.9.0
+    - termcolor==1.1.0
+    - werkzeug==1.0.1
+    - zipp==1.2.0
+prefix: /opt/conda/envs/tf1.9py3.5
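As an aside, the virtual adversarial perturbation computed by `perturb_image` in `model.py` above (probe with tiny noise, take the gradient of the divergence between clean and perturbed predictions, normalize, scale by `radius=3.5`) can be sketched without TensorFlow. This is only an illustrative toy, not the repo's method as-is: the linear classifier `W`, the central-difference gradient, and the enlarged probe scale `xi` are assumptions made so the sketch runs standalone; the model code instead uses `tf.gradients` on the real networks with `xi = 1e-6`.

```python
import numpy as np

def normalize_perturbation(d, eps=1e-12):
    # Rescale each row to unit L2 norm, mirroring normalize_perturbation in model.py.
    return d / (np.linalg.norm(d, axis=1, keepdims=True) + eps)

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def kl(p, q, eps=1e-12):
    # Per-sample KL(p || q).
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)

def vat_direction(x, W, radius=3.5, xi=1e-3, h=1e-6, seed=0):
    # Probe with tiny random noise, estimate the gradient of
    # KL(C(x) || C(x + eps)) w.r.t. eps by central differences, then return the
    # normalized worst-case direction scaled by `radius` (3.5 in the model code).
    # xi is larger than the 1e-6 used in model.py so the finite-difference
    # signal stays well above float noise in this toy setting.
    rng = np.random.RandomState(seed)
    p = softmax(x @ W)                          # predictions on the clean input
    eps = xi * normalize_perturbation(rng.randn(*x.shape))
    grad = np.zeros_like(x)
    for j in range(x.shape[1]):                 # numerical gradient, one input dim at a time
        e = np.zeros_like(x)
        e[:, j] = h
        grad[:, j] = (kl(p, softmax((x + eps + e) @ W)) -
                      kl(p, softmax((x + eps - e) @ W))) / (2 * h)
    return radius * normalize_perturbation(grad)

x = np.array([[1.0, -0.5], [0.2, 0.8]])         # two 2-d "features"
W = np.array([[1.0, -1.0], [0.5, 0.5]])         # toy linear classifier (assumed, illustrative)
x_adv = x + vat_direction(x, W)
print(np.linalg.norm(x_adv - x, axis=1))        # each perturbation has L2 norm radius = 3.5
```

The VAT term then penalizes divergence between predictions at `x` and `x_adv`, which is what `vat_loss` computes with `softmax_x_entropy_two` over the real encoder and classifier.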