Skip to content

Commit

Permalink
Change argparser inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Mar 14, 2018
1 parent 2b767af commit d41f607
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions run_dirtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

# Settings
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--src', type=str, default='mnist32', help="Src data")
parser.add_argument('--src', type=str, default='mnist', help="Src data")
parser.add_argument('--trg', type=str, default='svhn', help="Trg data")
parser.add_argument('--nn', type=str, default='small', help="Architecture")
parser.add_argument('--nn', type=str, default='small', help="Architecture")
parser.add_argument('--trim', type=int, default=5, help="Trim")
parser.add_argument('--inorm', type=int, default=1, help="Instance normalization flag")
parser.add_argument('--pert', type=str, default='vat', help="Type of perturbation")
parser.add_argument('--ball', type=float, default=3.5, help="Perturbation 2-norm ball radius")
parser.add_argument('--radius', type=float, default=3.5, help="Perturbation 2-norm ball radius")
parser.add_argument('--dw', type=float, default=1e-2, help="Domain weight")
parser.add_argument('--bw', type=float, default=1e-2, help="Beta (KL) weight")
parser.add_argument('--sw', type=float, default=1, help="Src weight")
Expand All @@ -33,7 +33,7 @@

from codebase.models.dirtt import dirtt
from codebase.train import train
from codebase.utils import get_data
from codebase.datasets import get_data

# Make model name
setup = [
Expand Down Expand Up @@ -82,6 +82,5 @@

train(M, src, trg,
saver=saver,
has_disc=True,
add_z=True,
has_disc=args.dirt == 0,
model_name=model_name)

0 comments on commit d41f607

Please sign in to comment.