
Commit

Move loss targets to corresponding device
josepdecid committed Apr 13, 2019
1 parent 8e1a066 commit ea50b81
Showing 1 changed file with 3 additions and 3 deletions.
src/model/train.py: 3 additions & 3 deletions
@@ -144,7 +144,7 @@ def _train_generator(self, data):
     prediction = self.model.discriminator(fake_data)

     # Calculate gradients w.r.t parameters and backpropagate
-    loss = self.model.g_criterion(prediction, torch.ones(batch_size))
+    loss = self.model.g_criterion(prediction, torch.ones(batch_size).to(device))
     loss.backward()

     # Update parameters

@@ -163,14 +163,14 @@ def _train_discriminator(self, real_data) -> float:

     # Train on real data
     real_predictions = self.model.discriminator(real_data)
-    real_loss = self.model.d_criterion(real_predictions, torch.ones(batch_size))
+    real_loss = self.model.d_criterion(real_predictions, torch.ones(batch_size).to(device))
     real_loss.backward()

     # Train on fake data
     noise_data = GANGenerator.noise((batch_size, time_steps, NUM_NOTES))
     fake_data = self.model.generator(noise_data).detach()
     fake_predictions = self.model.discriminator(fake_data)
-    fake_loss = self.model.d_criterion(fake_predictions, torch.zeros(batch_size))
+    fake_loss = self.model.d_criterion(fake_predictions, torch.zeros(batch_size).to(device))
     fake_loss.backward()

     # Update parameters

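Context for the change: torch.ones and torch.zeros allocate target tensors on the CPU by default, while the discriminator's predictions live on whatever device the model was moved to (typically the GPU), and PyTorch loss functions raise a RuntimeError when their inputs sit on different devices. Below is a minimal sketch of the pattern, assuming `device` is resolved once from torch.cuda.is_available(); the `discriminator` and `criterion` here are stand-in modules for illustration, not this repository's model classes.

import torch
import torch.nn as nn

# Assumed setup: resolve the device once and move the model to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

discriminator = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid()).to(device)
criterion = nn.BCELoss()

batch_size = 4
fake_data = torch.randn(batch_size, 16, device=device)   # inputs already on the target device
prediction = discriminator(fake_data).squeeze(1)          # predictions inherit the model's device

# Targets must live on the same device as the predictions,
# hence the .to(device) added in this commit.
targets = torch.ones(batch_size).to(device)
loss = criterion(prediction, targets)
loss.backward()

An equivalent, slightly cheaper option is to allocate the targets directly on the device, e.g. torch.ones(batch_size, device=device), which avoids the extra copy performed by .to(device).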