Skip to content

Commit

Permalink
Fix some stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
josepdecid committed Apr 13, 2019
1 parent 87319fc commit f6c6fe2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ fi

ngrok http "$VIS_PORT" > /dev/null &
echo "> Tunneling URL with Ngrok"
sleep 3

WEB_HOOK_URL=$(curl --silent --connect-timeout 10 https://localhost:4040/api/tunnels | \
pipenv run python -c "import json,sys;obj=json.load(sys.stdin);print(obj['tunnels'][0]['public_url'])")

echo "Visdom tunneled in $WEB_HOOK_URL"


# RUN MODEL

pipenv run python src/main.py --viz
Expand Down
8 changes: 4 additions & 4 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@
SAMPLE_TIMES = 100

# Generator
LR_G = 0.4
LR_G = 0.1
LR_PAT_G = 10
L2_G = 0.25
HIDDEN_DIM_G = 150
BIDIRECTIONAL_G = False
TYPE_G = 'LSTM'
LAYERS_G = 3
LAYERS_G = 1

# Discriminator
LR_D = 0.4
LR_D = 0.1
LR_PAT_D = 10
L2_D = 0.25
HIDDEN_DIM_D = 150
BIDIRECTIONAL_D = True
TYPE_D = 'LSTM'
LAYERS_D = 3
LAYERS_D = 1
8 changes: 4 additions & 4 deletions src/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from model.GANGenerator import GANGenerator
from model.GANModel import GANModel
from constants import EPOCHS, NUM_NOTES, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS, PLOT_COL
from utils.tensors import device
from utils.tensors import device, zeros_target, ones_target


class VisdomLinePlotter:
Expand Down Expand Up @@ -144,7 +144,7 @@ def _train_generator(self, data):
prediction = self.model.discriminator(fake_data)

# Calculate gradients w.r.t parameters and backpropagate
loss = self.model.g_criterion(prediction, torch.ones(batch_size).to(device))
loss = self.model.g_criterion(prediction, ones_target((batch_size,)))
loss.backward()

# Update parameters
Expand All @@ -163,14 +163,14 @@ def _train_discriminator(self, real_data) -> float:

# Train on real data
real_predictions = self.model.discriminator(real_data)
real_loss = self.model.d_criterion(real_predictions, torch.ones(batch_size).to(device))
real_loss = self.model.d_criterion(real_predictions, ones_target((batch_size,)))
real_loss.backward()

# Train on fake data
noise_data = GANGenerator.noise((batch_size, time_steps, NUM_NOTES))
fake_data = self.model.generator(noise_data).detach()
fake_predictions = self.model.discriminator(fake_data)
fake_loss = self.model.d_criterion(fake_predictions, torch.zeros(batch_size).to(device))
fake_loss = self.model.d_criterion(fake_predictions, zeros_target((batch_size,)))
fake_loss.backward()

# Update parameters
Expand Down
11 changes: 11 additions & 0 deletions src/utils/tensors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Tuple

import torch
from torch import cuda, device as dev

use_cuda = cuda.is_available()
device = dev('cuda' if use_cuda else 'cpu')


def zeros_target(dims: Tuple):
return torch.zeros(size=dims).to(device)


def ones_target(dims: Tuple):
return torch.ones(size=dims).to(device)

0 comments on commit f6c6fe2

Please sign in to comment.