Skip to content

Commit

Permalink
add default parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
kxz18 committed Oct 27, 2023
1 parent 0002e6a commit ff1517a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,25 @@ def parse():
parser.add_argument('--vocab', type=str, required=True, help='path of vocabulary (.pkl) or bpe vocab(.txt)')
parser.add_argument('--shuffle', action='store_true', help='shuffle data')
parser.add_argument('--save_dir', type=str, required=True, help='path to store the model')
parser.add_argument('--batch_size', type=int, default=64, help='size of mini-batch')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=32, help='size of mini-batch')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
parser.add_argument('--alpha', type=float, required=True,
help='balancing reconstruct loss and predictor loss')
# vae training
parser.add_argument('--beta', type=float, default=0.001,
parser.add_argument('--beta', type=float, default=0,
help='balancing kl loss and other loss')
parser.add_argument('--step_beta', type=float, default=0.0005,
parser.add_argument('--step_beta', type=float, default=0.002,
help='value of beta increasing step')
parser.add_argument('--max_beta', type=float, default=0.005)
parser.add_argument('--kl_warmup', type=int, default=2000,
parser.add_argument('--max_beta', type=float, default=0.01)
parser.add_argument('--kl_warmup', type=int, default=0,
help='Within these steps beta is set to 0')
parser.add_argument('--kl_anneal_iter', type=int, default=1000)

parser.add_argument('--num_workers', type=int, default=4, help='number of cpus to load data')
parser.add_argument('--gpus', default=None, help='gpus to use')
parser.add_argument('--epochs', type=int, default=20, help='max epochs')
parser.add_argument('--epochs', type=int, default=6, help='max epochs')
parser.add_argument('--patience', type=int, default=3, help='early stopping number of epochs')
parser.add_argument('--grad_clip', type=float, default=0,
parser.add_argument('--grad_clip', type=float, default=10.0,
help='clip large gradient to prevent gradient boom')
parser.add_argument('--monitor', type=str, default='val_loss',
help='Value to monitor in early stopping')
Expand All @@ -91,9 +91,9 @@ def parse():
default=['qed', 'logp'], help='properties to predict')
parser.add_argument('--predictor_hidden_dim', type=int, default=200,
help='hidden dim of predictor (MLP)')
parser.add_argument('--node_hidden_dim', type=int, default=100,
parser.add_argument('--node_hidden_dim', type=int, default=300,
help='dim of node hidden embedding in encoder and decoder')
parser.add_argument('--graph_embedding_dim', type=int, default=200,
parser.add_argument('--graph_embedding_dim', type=int, default=400,
help='dim of graph embedding by encoder and also condition for ae decoders')
parser.add_argument('--latent_dim', type=int, default=56,
help='dim of latent z for vae decoders')
Expand Down

0 comments on commit ff1517a

Please sign in to comment.