Commit
Fix loss functions
josepdecid committed Apr 8, 2019
1 parent 9218606 commit f5589a2
Showing 2 changed files with 8 additions and 8 deletions.
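The two diffs below apply the standard practical fix from the original GAN paper: instead of training G against 1/m · ∑ log(1 − D(G(z))), which saturates when D confidently rejects fakes, G is trained to maximize the non-saturating criterion 1/m · ∑ log D(G(z)), and the training loop's sign handling around backward() is made consistent for both networks.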
src/model/GANModel.py: 2 additions & 2 deletions

@@ -45,12 +45,12 @@ def eval_mode(self):
     def _generator_criterion(d_g_z: FloatTensor) -> FloatTensor:
         """
         Loss function for Generator G.
-        Calculates 1/m · ∑ log(1 - D(G(z)))
+        Calculates 1/m · ∑ log(D(G(z)))
         where *z* is the uniform random vector (noise) ∈ [0, 1]^T
         :param d_g_z: Tensor corresponding to the discriminator prediction D(G(z))
         :return: Loss of G
         """
-        return torch.mean(torch.log(1 - d_g_z))
+        return torch.mean(torch.log(d_g_z))

     @staticmethod
     def _discriminator_criterion(d_x: FloatTensor, d_g_z: FloatTensor) -> FloatTensor:
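A quick way to see why the non-saturating criterion helps (a standalone sketch, not code from the repo; the probe value 0.01 is an arbitrary choice): compare the gradient each criterion sends back when the discriminator confidently rejects a fake, i.e. when D(G(z)) is near 0.

import torch

# Probe point: the discriminator assigns the fake a probability near zero.
d_g_z = torch.tensor([0.01], requires_grad=True)

# Old criterion: mean log(1 - D(G(z))). Its gradient w.r.t. D(G(z)) is
# -1 / (1 - D(G(z))), which stays close to -1 while D(G(z)) ≈ 0.
old = torch.mean(torch.log(1 - d_g_z))
old.backward()
print(d_g_z.grad)  # ≈ -1.01: weak signal precisely where G needs it most

# New criterion: mean log(D(G(z))). Its gradient is 1 / D(G(z)),
# which grows as the fake gets more confidently rejected.
d_g_z.grad = None
new = torch.mean(torch.log(d_g_z))
new.backward()
print(d_g_z.grad)  # = 100.0: strong signal pushing D(G(z)) upward

In both cases the generator ascends its criterion (the training loop negates it for the optimizer), but only the new one keeps a usable gradient early in training.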
src/model/train.py: 6 additions & 6 deletions

@@ -61,12 +61,12 @@ def train_generator(model: GANModel, data):
     prediction = model.discriminator(fake_data)

     # Calculate gradients w.r.t parameters and backpropagate
-    loss = model.g_criterion(d_g_z=prediction)
-    (-loss).backward()
+    loss = -model.g_criterion(d_g_z=prediction)
+    loss.backward()

     # Update parameters
     model.g_optimizer.step()
-    return loss
+    return -loss


 def train_discriminator(model: GANModel, data):

@@ -88,13 +88,13 @@ def train_discriminator(model: GANModel, data):
     fake_predictions = model.discriminator(fake_data)

     # Calculate loss and optimize
-    loss = model.d_criterion(d_x=real_predictions, d_g_z=fake_predictions)
-    (-loss).backward()
+    loss = -model.d_criterion(d_x=real_predictions, d_g_z=fake_predictions)
+    loss.backward()

     # Update parameters
     model.d_optimizer.step()

-    return loss
+    return -loss


 def train_epoch(model: GANModel, loader: DataLoader) -> Tuple[float, float]:
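The sign convention in both training functions follows from PyTorch optimizers being minimizers: each criterion is a log-likelihood to maximize, so the code negates it once before backward() and negates the result again so the returned value is still the criterion itself. The negated generator criterion, -mean(log D(G(z))), is also exactly binary cross-entropy against an all-ones target, which a short standalone check confirms (a sketch, not repo code; the random batch and clamping bounds are arbitrary choices):

import torch
import torch.nn.functional as F

# Fake discriminator outputs for a batch of 8 generated samples,
# clamped away from 0 and 1 so log() stays finite.
d_g_z = torch.rand(8, 1).clamp(1e-6, 1 - 1e-6)

# The training loop's generator loss: -mean(log D(G(z))).
loss_as_in_diff = -torch.mean(torch.log(d_g_z))

# Binary cross-entropy with target 1 ("call every fake real").
loss_as_bce = F.binary_cross_entropy(d_g_z, torch.ones_like(d_g_z))

assert torch.allclose(loss_as_in_diff, loss_as_bce)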
