
Commit ad9ce12: Fix imports
RuiShu committed Mar 14, 2018
1 parent 9ffc4b8 commit ad9ce12
Showing 4 changed files with 64 additions and 9 deletions.
14 changes: 6 additions & 8 deletions codebase/models/dirtt.py
@@ -8,8 +8,8 @@
 from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_xent_two
 from tensorbayes.layers import placeholder, constant
 from tensorflow.python.ops.nn_ops import softmax_cross_entropy_with_logits_v2 as softmax_xent
-from tensorflow.python.ops.nn_ops import sigmoid_cross_entropy_with_logits as sigmoid_xent
-nn = importlib.import_module('nns.{}'.format(args.nn))
+from tensorflow.python.ops.nn_impl import sigmoid_cross_entropy_with_logits as sigmoid_xent
+nn = importlib.import_module('codebase.models.nns.{}'.format(args.nn))

 def dirtt():
     T = tb.utils.TensorDict(dict(
@@ -54,7 +54,7 @@ def dirtt():
     ema = tf.train.ExponentialMovingAverage(decay=0.998)
     var_class = tf.get_collection('trainable_variables', 'class/')
     ema_op = ema.apply(var_class)
-    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.get_getter(ema))
+    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

     # Teacher model (a back-up of EMA model)
     teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
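The `tb.get_getter` to `tb.tfutils.get_getter` fix above points at the helper that lets the EMA classifier reuse the `class/` variables while reading their exponential-moving-average shadows. The tensorbayes implementation is not shown in this diff, but the standard form of such a helper is a `custom_getter` along these lines (a sketch of the pattern, not the library's code):

def get_getter(ema):
    """Return a custom_getter that swaps each variable for its EMA shadow."""
    def ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = ema.average(var)  # shadow created earlier by ema.apply(var_class)
        return ema_var if ema_var is not None else var
    return ema_getter

Passed through `custom_getter=getter` in `tf.variable_scope` (as the classifier does when called with `reuse=True, getter=...`), every `tf.get_variable` call inside the scope then resolves to the averaged weights.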
@@ -71,8 +71,7 @@ def dirtt():
     # Accuracies
     src_acc = basic_accuracy(T.src_y, src_p)
     trg_acc = basic_accuracy(T.trg_y, trg_p)
-    ema_acc = basic_accuracy(T.test_y, T.ema_p)
-    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)
+    ema_acc = basic_accuracy(T.test_y, ema_p)
     fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

     # Optimizer
@@ -103,7 +102,7 @@ def dirtt():
        tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
        tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
        tf.summary.scalar('hyper/dw', dw),
-       tf.summary.scalar('hyper/cw', bw),
+       tf.summary.scalar('hyper/cw', cw),
        tf.summary.scalar('hyper/sw', sw),
        tf.summary.scalar('hyper/tw', tw),
        tf.summary.scalar('acc/src_acc', src_acc),
@@ -125,9 +124,8 @@ def dirtt():
                  c('trg'), trg_acc]
     T.ops_disc = [summary_disc, train_disc]
     T.ops_main = [summary_main, train_main]
-    T.fn_test_acc = fn_test_acc
     T.fn_ema_acc = fn_ema_acc
     T.teacher = teacher
-    T.teacher_update = teacher_update
+    T.update_teacher = update_teacher

     return T
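The import fix in the first hunk works together with the empty `codebase/models/nns/__init__.py` added below: `importlib.import_module` resolves a dotted path against packages importable from `sys.path`, so `'nns.small'` only resolves if `nns` is itself a top-level package, whereas the fully qualified path works when the repo root is the working directory. A minimal sketch of the pattern (the `'small'` value standing in for `args.nn` is illustrative, not taken from the commit):

import importlib

# Hypothetical stand-in for args.nn; the repo ships small.py and large.py.
nn_name = 'small'

# Resolves only if every directory on the dotted path is a package (has an
# __init__.py) and the repo root is on sys.path -- hence the added
# codebase/models/nns/__init__.py in this commit.
nn = importlib.import_module('codebase.models.nns.{}'.format(nn_name))

# The returned module is then used like a normal import, e.g. nn.classifier(...).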
Empty file added codebase/models/nns/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions codebase/models/nns/large.py
@@ -0,0 +1,57 @@
import tensorflow as tf
from codebase.args import args
from codebase.models.extra_layers import leaky_relu, noise
from tensorbayes.layers import dense, conv2d, avg_pool, max_pool, batch_norm, instance_norm
from tensorflow.contrib.framework import arg_scope
from tensorflow.python.ops.nn_ops import dropout

def classifier(x, phase, enc_phase=1, trim=0, scope='class', reuse=None, internal_update=False, getter=None):
    with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
        with arg_scope([leaky_relu], a=0.1), \
             arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=phase), \
             arg_scope([batch_norm], internal_update=internal_update):

            preprocess = instance_norm if args.inorm else tf.identity
            layout = [
                (preprocess, (), {}),
                (conv2d, (96, 3, 1), {}),
                (conv2d, (96, 3, 1), {}),
                (conv2d, (96, 3, 1), {}),
                (max_pool, (2, 2), {}),
                (dropout, (), dict(training=phase)),
                (noise, (1,), dict(phase=phase)),
                (conv2d, (192, 3, 1), {}),
                (conv2d, (192, 3, 1), {}),
                (conv2d, (192, 3, 1), {}),
                (max_pool, (2, 2), {}),
                (dropout, (), dict(training=phase)),
                (noise, (1,), dict(phase=phase)),
                (conv2d, (192, 3, 1), {}),
                (conv2d, (192, 3, 1), {}),
                (conv2d, (192, 3, 1), {}),
                (avg_pool, (), dict(global_pool=True)),
                (dense, (args.Y,), dict(activation=None))
            ]

            if enc_phase:
                start = 0
                end = len(layout) - trim
            else:
                start = len(layout) - trim
                end = len(layout)

            for i in xrange(start, end):
                with tf.variable_scope('l{:d}'.format(i)):
                    f, f_args, f_kwargs = layout[i]
                    x = f(x, *f_args, **f_kwargs)

    return x

def feature_discriminator(x, phase, C=1, reuse=None):
    with tf.variable_scope('disc/feat', reuse=reuse):
        with arg_scope([dense], activation=tf.nn.relu): # Switch to leaky?

            x = dense(x, 100)
            x = dense(x, C, activation=None)

    return x
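For reference, the `enc_phase`/`trim` pair splits the 18-entry `layout` list into an encoder and a classifier head: with `enc_phase=1` the first `len(layout) - trim` entries are applied, and with `enc_phase=0` only the last `trim` entries are. A minimal sketch of the indexing (the `trim=3` value is an illustrative choice, not taken from the commit):

layout_len = 18   # number of (fn, args, kwargs) entries in the layout above
trim = 3          # hypothetical setting; how many trailing layers form the head

# enc_phase=1: encoder pass over layers [0, layout_len - trim)
enc_range = range(0, layout_len - trim)            # layers 0..14
# enc_phase=0: head pass over layers [layout_len - trim, layout_len)
head_range = range(layout_len - trim, layout_len)  # layers 15..17

assert list(enc_range) + list(head_range) == list(range(layout_len))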
2 changes: 1 addition & 1 deletion codebase/models/nns/small.py
@@ -3,7 +3,7 @@
 from codebase.models.extra_layers import leaky_relu, noise
 from tensorbayes.layers import dense, conv2d, avg_pool, max_pool, batch_norm, instance_norm
 from tensorflow.contrib.framework import arg_scope
-from tensorflow.python.ops.nn_ops import dropout
+from tensorflow.python.layers.core import dropout

 def classifier(x, phase, enc_phase=1, trim=0, scope='class', reuse=None, internal_update=False, getter=None):
     with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
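The small.py change swaps the dropout import because the two functions have different signatures in TF 1.x: `tf.nn.dropout` (nn_ops) takes a keep probability and has no training flag, while the functional `tf.layers.dropout` (layers.core) takes a drop rate plus a boolean `training` argument, which is what the `dict(training=phase)` layout entries pass. A rough comparison, assuming TF 1.x (the rate/keep_prob values are illustrative):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])
is_training = tf.placeholder(tf.bool, [])

# nn_ops-style dropout: expects a keep probability, no notion of a training flag.
y_old = tf.nn.dropout(x, keep_prob=0.5)

# layers-style dropout: expects a drop rate plus a `training` boolean, matching
# the layout entries (dropout, (), dict(training=phase)) in small.py / large.py.
y_new = tf.layers.dropout(x, rate=0.5, training=is_training)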
