Skip to content

Commit

Permalink
Add dirtt logic
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Mar 13, 2018
1 parent 5dc9c01 commit c0bfe0b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
1 change: 1 addition & 0 deletions codebase/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):
src = PseudoData(args.trg, trg, M.teacher)
M.sess.run(M.update_teacher)

# Sanity check model
print_list = []
if src:
save_acc(M, 'fn_ema_acc', 'test/src_test_ema_1k',
Expand Down
52 changes: 38 additions & 14 deletions run_dirtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,26 @@
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--src', type=str, default='mnist32', help="Src data")
parser.add_argument('--trg', type=str, default='svhn', help="Trg data")
parser.add_argument('--design', type=str, default='v11_y', help="Architecture design")
parser.add_argument('--nn', type=str, default='small', help="Architecture")
parser.add_argument('--trim', type=int, default=5, help="Trim")
parser.add_argument('--inorm', type=int, default=1, help="Instance normalization flag")
parser.add_argument('--pert', type=str, default='vat', help="Type of perturbation")
parser.add_argument('--ball', type=float, default=3.5, help="Perturbation 2-norm ball radius")
parser.add_argument('--dw', type=float, default=1e-2, help="Domain weight")
parser.add_argument('--bw', type=float, default=1e-2, help="Beta (KL) weight")
parser.add_argument('--sw', type=float, default=1, help="Src weight")
parser.add_argument('--tw', type=float, default=1e-2, help="Trg weight")
parser.add_argument('--bw', type=float, default=1e-2, help="Beta (KL) weight")
parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate")
parser.add_argument('--dirt', type=int, default=0, help="0 == VADA, >0 == DIRT-T interval")
parser.add_argument('--run', type=int, default=999, help="Run index")
parser.add_argument('--run', type=int, default=999, help="Run index. >= 999 == debugging")
parser.add_argument('--logdir', type=str, default='log', help="Log directory")
codebase_args.args = args = parser.parse_args()

# Argument overrides and additions
src2Y = {'mnist': 10, 'mnistm': 10, 'digit': 10, 'svhn': 10, 'cifar': 9, 'stl': 9, 'sign': 43}
args.Y = src2Y[args.src]
args.H = 32
args.bw = args.bw if args.dirt > 0 else 0. # mask bw when running VADA
pprint(vars(args))

from codebase.models.dirtt import dirtt
Expand All @@ -31,16 +38,16 @@
# Make model name
setup = [
('model={:s}', 'dirtt'),
('src={:s}', args.src),
('trg={:s}', args.trg),
('des={:s}', args.design),
('trim={:d}', args.trim),
('dw={:.0e}', args.dw),
('cw={:.0e}', args.cw),
('sbw={:.0e}', args.sbw),
('tbw={:.0e}', args.tbw),
('src={:s}', args.src),
('trg={:s}', args.trg),
('nn={:s}', args.nn),
('trim={:d}', args.trim),
('dw={:.0e}', args.dw),
('bw={:.0e}', args.bw),
('sw={:.0e}', args.sw),
('tw={:.0e}', args.tw),
('dirt={:05d}', args.dirt),
('run={:04d}', args.run)
('run={:04d}', args.run)
]
model_name = '_'.join([t.format(v) for (t, v) in setup])
print "Model name:", model_name
Expand All @@ -50,8 +57,25 @@
saver = tf.train.Saver()

if args.dirt > 0:
# Figure out later
pass
run = args.run if args.run < 999 else 0
setup = [
('model={:s}', 'dirtt'),
('src={:s}', args.src),
('trg={:s}', args.trg),
('nn={:s}', args.nn),
('trim={:d}', args.trim),
('dw={:.0e}', args.dw),
('bw={:.0e}', 0),
('sw={:.0e}', args.sw),
('tw={:.0e}', args.tw),
('dirt={:05d}', 0),
('run={:04d}', run)
]
vada_name = '_'.join([t.format(v) for (t, v) in setup])
vada_path = os.path.join('checkpoints', vada_name)
path = tf.train.latest_checkpoint(restoration_path)
saver.restore(M.sess, path)
print "Restored from {}".format(path)

src = get_data(args.src)
trg = get_data(args.trg)
Expand Down

0 comments on commit c0bfe0b

Please sign in to comment.