Skip to content

Commit

Permalink
Add LR annealing with visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
josepdecid committed Apr 13, 2019
1 parent f45e2b3 commit 877e661
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 44 deletions.
20 changes: 15 additions & 5 deletions src/utils/constants.py → src/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import numpy as np
from pathlib import Path

SOURCE_MIDI_URLS = ['https://www.jsbach.es/bbdd/index01_20.htm']
Expand All @@ -12,6 +13,11 @@

PPath = lambda p: Path(PROJECT_PATH + p)

PLOT_COL = {
'G': np.array([[75, 159, 56]]), # Green
'D': np.array([[220, 65, 51]]) # Red
}

SAMPLE_STEPS = 1
CKPT_STEPS = 100

Expand All @@ -22,23 +28,27 @@
'viz': True
}

# HYPERPARAMETERS
# HYPERPARAMETERS #

EPOCHS = 100
BATCH_SIZE = 8
MAX_POLYPHONY = 1
SAMPLE_TIMES = 100

LR_G = 0.1
# Generator
LR_G = 0.3
LR_PAT_G = 5
L2_G = 0.25
HIDDEN_DIM_G = 50
BIDIRECTIONAL_G = False
TYPE_G = 'GRU'
TYPE_G = 'LSTM'
LAYERS_G = 1

LR_D = 0.1
# Discriminator
LR_D = 0.3
LR_PAT_D = 5
L2_D = 0.25
HIDDEN_DIM_D = 50
BIDIRECTIONAL_D = True
TYPE_D = 'GRU'
TYPE_D = 'LSTM'
LAYERS_D = 1
5 changes: 2 additions & 3 deletions src/dataset/MusicDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch.utils.data import Dataset, DataLoader

from utils.constants import DATASET_PATH, NUM_NOTES, BATCH_SIZE
from constants import DATASET_PATH, NUM_NOTES, BATCH_SIZE
from utils.tensors import use_cuda
from utils.typings import File, FloatTensor

Expand All @@ -25,8 +25,7 @@ def __init__(self):
self.padded_songs = self._apply_padding()

def __getitem__(self, index: int) -> FloatTensor:
x = self.padded_songs[index]
return FloatTensor(x)
return self.padded_songs[index]

def __len__(self) -> int:
return len(self.songs)
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/preprocessing/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bs4 import BeautifulSoup

from utils.constants import RAW_DATASET_PATH, SOURCE_MIDI_URLS
from constants import RAW_DATASET_PATH, SOURCE_MIDI_URLS

logging.getLogger().setLevel(logging.INFO)

Expand Down
2 changes: 1 addition & 1 deletion src/dataset/preprocessing/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from dataset.Music import Song, Track, NoteData
from utils.constants import RAW_DATASET_PATH, DATASET_PATH, NUM_NOTES, MIN_NOTE, SAMPLE_TIMES
from constants import RAW_DATASET_PATH, DATASET_PATH, NUM_NOTES, MIN_NOTE, SAMPLE_TIMES


def csv_cleaner(data: List[str]) -> Song:
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/preprocessing/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from py_midicsv import csv_to_midi, FileWriter

from utils.constants import RESULTS_PATH, SAMPLE_TIMES, MIN_NOTE
from constants import RESULTS_PATH, SAMPLE_TIMES, MIN_NOTE


def store_csv_to_midi(title: str, data: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataset.MusicDataset import MusicDataset
from model.GANModel import GANModel
from model.train import Trainer
from utils.constants import FLAGS
from constants import FLAGS
from utils.tensors import device


Expand Down
2 changes: 1 addition & 1 deletion src/model/GANDiscriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn

from model.RNN import RNN
from utils.constants import LAYERS_D, HIDDEN_DIM_D, BIDIRECTIONAL_D, NUM_NOTES, TYPE_D
from constants import LAYERS_D, HIDDEN_DIM_D, BIDIRECTIONAL_D, NUM_NOTES, TYPE_D
from utils.tensors import device


Expand Down
5 changes: 2 additions & 3 deletions src/model/GANGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

from model.RNN import RNN
from utils.constants import HIDDEN_DIM_G, LAYERS_G, BIDIRECTIONAL_G, NUM_NOTES, TYPE_G
from constants import HIDDEN_DIM_G, LAYERS_G, BIDIRECTIONAL_G, NUM_NOTES, TYPE_G
from utils.tensors import device


Expand Down Expand Up @@ -39,4 +38,4 @@ def noise(dims: Tuple):
Generates a 2-d vector of uniform sampled random values.
:param dims: Tuple with the dimensions of the data.
"""
return Variable(torch.randint(0, 2, dims, dtype=torch.float).to(device))
return torch.randint(0, 2, dims, dtype=torch.float).to(device)
15 changes: 5 additions & 10 deletions src/model/GANModel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from model.GANDiscriminator import GANDiscriminator
from model.GANGenerator import GANGenerator
from utils.constants import LR_G, L2_G, L2_D, LR_D
from constants import LR_G, L2_G, L2_D, LR_D, LR_PAT_D, LR_PAT_G
from utils.tensors import device
from utils.typings import NNet, Optimizer, Criterion, FloatTensor
from utils.typings import NNet, Optimizer, Criterion, FloatTensor, Scheduler


class GANModel:
Expand All @@ -15,20 +16,14 @@ def __init__(self):

self.generator: NNet = GANGenerator().to(device)
self.g_optimizer: Optimizer = torch.optim.Adam(self.generator.parameters(), lr=LR_G, weight_decay=L2_G)
self.g_scheduler: Scheduler = ReduceLROnPlateau(self.g_optimizer, mode='min', patience=LR_PAT_G)
self.g_criterion: Criterion = GANModel._generator_criterion

self.discriminator: NNet = GANDiscriminator().to(device)
self.d_optimizer: Optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=LR_D, weight_decay=L2_D)
self.d_scheduler: Scheduler = ReduceLROnPlateau(self.d_optimizer, mode='min', patience=LR_PAT_D)
self.d_criterion: Criterion = GANModel._discriminator_criterion

# with SummaryWriter(log_dir=f'{PROJECT_PATH}/res/log', comment='Generator') as w:
# h = self.generator(Variable(torch.zeros(BATCH_SIZE, 42, MAX_POLYPHONY), requires_grad=True))
# w.add_graph(self.generator, h)

# with SummaryWriter(log_dir=f'{PROJECT_PATH}/res/log', comment='Discriminator') as w:
# h = self.discriminator(Variable(torch.zeros(BATCH_SIZE, 42, MAX_POLYPHONY), requires_grad=True))
# w.add_graph(self.discriminator, h)

# PyTorch working modes

def train_mode(self):
Expand Down
36 changes: 21 additions & 15 deletions src/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,30 @@
from typing import List

import torch
from torch.autograd import Variable
from tqdm import tqdm
from visdom import Visdom

from dataset.MusicDataset import MusicDataset
from dataset.preprocessing.reconstructor import reconstruct_midi
from model.GANGenerator import GANGenerator
from model.GANModel import GANModel
from utils.constants import EPOCHS, NUM_NOTES, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS
from constants import EPOCHS, NUM_NOTES, CKPT_STEPS, CHECKPOINTS_PATH, SAMPLE_STEPS, FLAGS, PLOT_COL
from utils.tensors import device

CURRENT_TIME = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


class VisdomLinePlotter:
def __init__(self):
self.viz = Visdom(env='BachPropagation')
self.plots = {}

def plot_line(self, plot, line, title, y_label, x, y, ):
def plot_line(self, plot, line, title, y_label, x, y, color=None):
if plot not in self.plots:
self.plots[plot] = self.viz.line(X=x, Y=y, opts={'legend': [line], 'title': title,
'xlabel': 'Epochs', 'ylabel': y_label})
opts = {'title': title, 'xlabel': 'Epochs', 'ylabel': y_label, 'linecolor': color}
if line is not None:
opts['legend'] = [line]
self.plots[plot] = self.viz.line(X=x, Y=y, opts=opts)
else:
self.viz.line(X=x, Y=y, win=self.plots[plot], name=line, update='append')
self.viz.line(X=x, Y=y, win=self.plots[plot], name=line, update='append', opts={'linecolor': color})

def add_song(self, path):
self.viz.audio(audiofile=path, tensor=None)
Expand All @@ -43,8 +42,8 @@ def print_metrics(self):
print(f'Generator loss: {self.g_loss:.6f} | Discriminator loss: {self.d_loss:.6f}')

def plot_loss(self, vis):
vis.plot_line('Loss', 'Generator', f'Model Loss', 'Loss', [self.epoch], [self.g_loss])
vis.plot_line('Loss', 'Discriminator', None, None, [self.epoch], [self.d_loss])
vis.plot_line('Loss', 'Generator', f'Model Loss', 'Loss', [self.epoch], [self.g_loss], PLOT_COL['G'])
vis.plot_line('Loss', 'Discriminator', None, None, [self.epoch], [self.d_loss], PLOT_COL['D'])


class Trainer:
Expand All @@ -69,6 +68,10 @@ def train(self):

if FLAGS['viz']:
metric.plot_loss(self.vis)
self.vis.plot_line('LR_G', None, 'LR Generator', 'LR',
[epoch], [self.model.g_optimizer.param_groups[0]['lr']], PLOT_COL['G'])
self.vis.plot_line('LR_D', None, 'LR Discriminator', 'LR',
[epoch], [self.model.d_optimizer.param_groups[0]['lr']], PLOT_COL['D'])
metric.print_metrics()

if epoch % SAMPLE_STEPS == 0:
Expand Down Expand Up @@ -103,7 +106,7 @@ def _train_epoch(self, epoch: int) -> EpochMetric:

batch_data = enumerate(tqdm(self.loader, desc=f'Epoch {epoch}: ', ncols=100))
for batch_idx, features in batch_data:
features = features.to(device)
features = features.requires_grad_().to(device)

# if current_loss_d >= 0.7 * current_loss_g:
d_loss = self._train_discriminator(features)
Expand All @@ -117,6 +120,10 @@ def _train_epoch(self, epoch: int) -> EpochMetric:

g_loss = sum(sum_loss_g) / len(sum_loss_g)
d_loss = sum(sum_loss_d) / len(sum_loss_d)

self.model.g_scheduler.step(metrics=g_loss)
self.model.d_scheduler.step(metrics=d_loss)

return EpochMetric(epoch, g_loss, d_loss)

def _train_generator(self, data):
Expand All @@ -143,17 +150,16 @@ def _train_generator(self, data):

return -(loss.item())

def _train_discriminator(self, data) -> float:
def _train_discriminator(self, real_data) -> float:
logging.debug('Training Discriminator')

batch_size = data.size(0)
time_steps = data.size(1)
batch_size = real_data.size(0)
time_steps = real_data.size(1)

# Reset gradients
self.model.d_optimizer.zero_grad()

# Train on real data
real_data = Variable(data)
real_predictions = self.model.discriminator(real_data)

# Train on fake data
Expand Down
7 changes: 4 additions & 3 deletions src/utils/typings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
from typing import TextIO
from typing import IO

File = TextIO
File = IO[str]
IntTensor = torch.IntTensor
FloatTensor = torch.FloatTensor

Optimizer = torch.optim.Optimizer
Scheduler = torch.optim.lr_scheduler
Criterion = torch.nn.modules.loss.CrossEntropyLoss
NNet = torch.nn.Module
NNet = torch.nn.Module

0 comments on commit 877e661

Please sign in to comment.