Skip to content

Commit

Permalink
Added inline comments for explanation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dr-costas committed May 17, 2019
1 parent 0135612 commit 8a3794e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions scripts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,14 @@ def _training_iteration(_m, _data, _device, _solver, _sep_l, _reg_twin,
def training_process():
"""The training process.
"""
# Check what device we'll be using
device = 'cuda' if not debug and cuda.is_available() else 'cpu'

# Inform about the device and time and date
printing.print_intro_messages(device)
printing.print_msg('Starting training process. Debug mode: {}'.format(debug))

# Set up MaD TwinNet
with printing.InformAboutProcess('Setting up MaD TwinNet'):
mad_twin_net = MaDTwinNet(
rnn_enc_input_dim=hyper_parameters['reduced_dim'],
Expand All @@ -168,13 +171,14 @@ def training_process():
context_length=hyper_parameters['context_length']
).to(device)

# Get the optimizer
with printing.InformAboutProcess('Setting up optimizer'):
# Optimizer
optimizer = optim.Adam(
mad_twin_net.parameters(),
lr=hyper_parameters['learning_rate']
)

# Create the data feeder
with printing.InformAboutProcess('Initializing data feeder'):
epoch_it = data_feeder.data_feeder_training(
window_size=hyper_parameters['window_size'],
Expand All @@ -187,6 +191,7 @@ def training_process():
debug=debug
)

# Inform about the future
printing.print_msg('Training starts', end='\n\n')

# Training loop
Expand All @@ -196,12 +201,14 @@ def training_process():
hyper_parameters['lambda_2'], hyper_parameters['max_grad_norm'])
for e in range(training_constants['epochs'])]

# Kindly end and save the model
# Inform about the past
printing.print_msg('Training done.', start='\n-- ')

# Save the model
with printing.InformAboutProcess('Saving model.. '):
save(mad_twin_net.mad.state_dict(), output_states_path['mad'])

# Say goodbye!
printing.print_msg('That\'s all folks!')


Expand Down

0 comments on commit 8a3794e

Please sign in to comment.