Skip to content

Commit

Permalink
Freeze training to balance networks
Browse files Browse the repository at this point in the history
  • Loading branch information
josepdecid committed Apr 16, 2019
1 parent 4cc7002 commit accec99
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
BATCH_SIZE = 16
MAX_POLYPHONY = 12
NORMALIZE_FREQ = False
T_LOSS_BALANCER = 0.7

# Generator
LR_G = 0.4
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def run_model():


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.DEBUG)
logging.info(f'Running Python {sys.version.split()[0]} with PyTorch {torch.__version__} in {device}')

parser = argparse.ArgumentParser()
Expand Down
36 changes: 22 additions & 14 deletions src/model/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
from tqdm import tqdm

from constants import EPOCHS, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS, PLOT_COL, PRETRAIN_G, PRETRAIN_D
from constants import EPOCHS, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS, PLOT_COL, PRETRAIN_G, PRETRAIN_D, \
T_LOSS_BALANCER
from dataset.MusicDataset import MusicDataset
from dataset.preprocessing.reconstructor import reconstruct_midi
from model.gan.GANGenerator import GANGenerator
Expand Down Expand Up @@ -49,9 +50,12 @@ def train(self):
metric.plot_loss(self.vis, plot='Pretraining', title='Pretraining Loss')
metric.print_metrics()

current_loss_d = 1e99
current_loss_g = 1e99

# TRAINING
for epoch in range(1, EPOCHS + 1):
metric = self._train_epoch(epoch)
metric = self._train_epoch(epoch, current_loss_d, current_loss_g)

if FLAGS['viz']:
metric.plot_loss(self.vis)
Expand Down Expand Up @@ -117,7 +121,7 @@ def _pretrain_epoch(self, epoch: int) -> EpochMetric:

return EpochMetric(epoch, g_loss, d_loss, None)

def _train_epoch(self, epoch: int) -> EpochMetric:
def _train_epoch(self, epoch: int, current_loss_d=0.0, current_loss_g=0.0) -> EpochMetric:
"""
Train the model for one epoch with the classical GAN training approach.
:param epoch: Current epoch index.
Expand All @@ -133,17 +137,21 @@ def _train_epoch(self, epoch: int) -> EpochMetric:
features = features.to(device)
batch_size = features.size(0)

# if current_loss_d >= 0.7 * current_loss_g:
d_loss, t_pos, t_neg = self._train_discriminator(real_data=features)
# current_loss_d = d_loss
sum_loss_d += d_loss * batch_size
sum_t_pos += t_pos
sum_t_neg += t_neg

# if current_loss_g >= 0.7 * current_loss_d:
g_loss = self._train_generator(real_data=features)
# current_loss_g = g_loss
sum_loss_g += g_loss * batch_size
if current_loss_d >= T_LOSS_BALANCER * current_loss_g:
d_loss, t_pos, t_neg = self._train_discriminator(real_data=features)
current_loss_d = d_loss
sum_loss_d += d_loss * batch_size
sum_t_pos += t_pos
sum_t_neg += t_neg
else:
logging.debug('Freezing Discriminator')

if current_loss_g >= T_LOSS_BALANCER * current_loss_d:
g_loss = self._train_generator(real_data=features)
current_loss_g = g_loss
sum_loss_g += g_loss * batch_size
else:
logging.debug('Freezing Generator')

len_data = len(self.loader.dataset)
sum_loss_g /= len_data
Expand Down
8 changes: 4 additions & 4 deletions src/model/helpers/EpochMetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def tp_ratio(self):
return self.cf[0][0] / (self.cf[0][0] + self.cf[0][1])

@property
def tn_ratio(self):
return self.cf[1][1] / (self.cf[1][0] + self.cf[1][1])
def fn_ratio(self):
return self.cf[1][0] / (self.cf[1][0] + self.cf[1][1])

def print_metrics(self):
if self.d_loss is None:
Expand All @@ -46,9 +46,9 @@ def plot_confusion_matrix(self, vis: VisdomPlotter):

# True Negative Ratios
vis.plot_line(plot_name='ConfusionMatrix',
line_label='True Negatives',
line_label='False Negatives',
x=[self.epoch],
y=[self.tn_ratio])
y=[self.fn_ratio])

# Confusion Matrix
cf = np.array(self.cf)
Expand Down

0 comments on commit accec99

Please sign in to comment.