Skip to content

Commit

Permalink
Merge Pretrain discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
josepdecid committed Apr 14, 2019
1 parent 8a7480b commit 4f8300b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@
L2_D = 0.25
HIDDEN_DIM_D = 30
BIDIRECTIONAL_D = True
PRETRAIN_D = 0
PRETRAIN_D = 10
TYPE_D = 'LSTM'
LAYERS_D = 1
36 changes: 18 additions & 18 deletions src/model/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import torch
from tqdm import tqdm

from constants import EPOCHS, NUM_NOTES, CKPT_STEPS, CHECKPOINTS_PATH, \
SAMPLE_STEPS, FLAGS, PLOT_COL, PRETRAIN_G, PRETRAIN_D
from constants import EPOCHS, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS, PLOT_COL, PRETRAIN_G, PRETRAIN_D
from dataset.MusicDataset import MusicDataset
from dataset.preprocessing.reconstructor import reconstruct_midi
from model.gan.GANGenerator import GANGenerator
Expand Down Expand Up @@ -104,11 +103,13 @@ def _pretrain_epoch(self, epoch: int) -> EpochMetric:
features = features.to(device)
batch_size = features.size(0)

g_loss = self._train_generator(real_data=features, pretraining=True)
sum_loss_g += g_loss * batch_size
if epoch < PRETRAIN_G:
g_loss = self._train_generator(real_data=features, pretraining=True)
sum_loss_g += g_loss * batch_size

d_loss = self._train_discriminator(real_data=features, pretraining=True)
sum_loss_d += d_loss * batch_size
if epoch < PRETRAIN_D:
d_loss = self._train_discriminator(real_data=features, pretraining=True)
sum_loss_d += d_loss * batch_size

g_loss = sum_loss_g / len(self.loader.dataset)
d_loss = sum_loss_d / len(self.loader.dataset)
Expand Down Expand Up @@ -196,27 +197,26 @@ def _train_discriminator(self, real_data: FloatTensor, pretraining=False) -> flo
# Reset gradients
self.model.d_optimizer.zero_grad()

#Predictions on real data
# Predictions on real data
real_predictions = self.model.discriminator(real_data)

if pretraining:
real_loss = self.model.pretraining_discriminator_criterion(real_predictions, torch.ones(real_predictions.shape))
real_loss = self.model.pretraining_criterion(real_predictions, torch.ones(real_predictions.shape))
real_loss.backward()
fake_loss = 0.

fake_loss = 0
else:
# Train on real data
real_loss = self.model.training_criterion(real_predictions, ones_target((batch_size,)))
real_loss.backward()

# Train on fake data
noise_data = GANGenerator.noise((batch_size, time_steps))
fake_data = self.model.generator(noise_data).detach()
fake_predictions = self.model.discriminator(fake_data)
fake_loss = self.model.training_criterion(fake_predictions, zeros_target((batch_size,)))
fake_loss.backward()
# Train on fake data
noise_data = GANGenerator.noise((batch_size, time_steps))
fake_data = self.model.generator(noise_data).detach()
fake_predictions = self.model.discriminator(fake_data)
fake_loss = self.model.training_criterion(fake_predictions, zeros_target((batch_size,)))
fake_loss.backward()

# Update parameters
self.model.d_optimizer.step()
# Update parameters
self.model.d_optimizer.step()

return (real_loss + fake_loss).item()

0 comments on commit 4f8300b

Please sign in to comment.