Skip to content

Commit

Permalink
Add get_info and change PseudoData input
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Mar 13, 2018
1 parent cbc6def commit 5dc9c01
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions codebase/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):
# Replace src domain with psuedolabeled trg
if args.dirt > 0:
print "Setting backup and updating backup model"
src = PseudoData(args.trg, trg, M)
src = PseudoData(args.trg, trg, M.teacher)
M.sess.run(M.update_teacher)

print_list = []
Expand All @@ -73,8 +73,8 @@ def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):

print print_list

if src: print "Src size:", src.train.images.shape
if trg: print "Trg size:", trg.train.images.shape
if src: get_info(args.src, src)
if trg: get_info(args.trg, trg)
print "Batch size:", bs
print "Iterep:", iterep
print "Total iterations:", n_epoch * iterep
Expand Down

0 comments on commit 5dc9c01

Please sign in to comment.