Modify input of GAN
josepdecid committed Apr 14, 2019
1 parent 22a1683 commit 68a36ea
Showing 7 changed files with 20 additions and 14 deletions.
src/constants.py (4 changes: 3 additions & 1 deletion)
@@ -20,6 +20,7 @@

 SAMPLE_STEPS = 1
 CKPT_STEPS = 100
+INPUT_FEATURES = 3

 MIN_FREQ_NOTE = 27.5
 MAX_FREQ_NOTE = 4186.0
@@ -35,6 +36,7 @@
 EPOCHS = 100
 BATCH_SIZE = 5
 MAX_POLYPHONY = 12
+NORMALIZE_FREQ = False
 SAMPLE_TIMES = 50

 # Generator
@@ -43,7 +45,7 @@
 L2_G = 0.25
 HIDDEN_DIM_G = 30
 BIDIRECTIONAL_G = False
-PRETRAIN_G = 0
+PRETRAIN_G = 10
 TYPE_G = 'LSTM'
 LAYERS_G = 1

src/dataset/MusicDataset.py (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 import torch
 from torch.utils.data import Dataset, DataLoader

-from constants import DATASET_PATH, BATCH_SIZE, MAX_POLYPHONY
+from constants import DATASET_PATH, BATCH_SIZE, INPUT_FEATURES
 from utils.tensors import use_cuda
 from utils.typings import File, FloatTensor

@@ -46,7 +46,7 @@ def _apply_padding(self) -> List[FloatTensor]:
         """
         padded_songs = []
         for song in self.songs:
-            padded_song = torch.zeros((self.longest_song, MAX_POLYPHONY), dtype=torch.float)
+            padded_song = torch.zeros((self.longest_song, INPUT_FEATURES), dtype=torch.float)
             padded_song[:len(song), :] = torch.tensor(song)
             padded_songs.append(padded_song)
         return padded_songs
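Note on the change above: a padded song is now a (longest_song, INPUT_FEATURES) tensor instead of (longest_song, MAX_POLYPHONY), so a DataLoader batch comes out as (BATCH_SIZE, longest_song, 3). A minimal, self-contained sketch of the padding step, with made-up note values and the constant hard-coded rather than imported:

import torch

INPUT_FEATURES = 3   # value introduced in constants.py
longest_song = 4     # illustrative; the real value is the longest song in the dataset

# Two made-up notes, three features each: (frequency, duration, time since previous note).
song = [(440.0, 48.0, 0.0), (493.9, 24.0, 48.0)]

padded_song = torch.zeros((longest_song, INPUT_FEATURES), dtype=torch.float)
padded_song[:len(song), :] = torch.tensor(song)
print(padded_song.shape)  # torch.Size([4, 3])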
src/dataset/preprocessing/parser.py (7 changes: 4 additions & 3 deletions)
@@ -6,7 +6,7 @@
 from py_midicsv import midi_to_csv
 from tqdm import tqdm

-from constants import RAW_DATASET_PATH, DATASET_PATH, MAX_POLYPHONY
+from constants import RAW_DATASET_PATH, DATASET_PATH, MAX_POLYPHONY, NORMALIZE_FREQ
 from dataset.Music import Song, Track, NoteData


@@ -66,7 +66,6 @@ def csv_to_series(song: Song) -> List[Tuple[float, int, int]]:
     """
     ts_data = []
     last_start = None
-    note_max_start = song.max_time

     # Index of current treated element for each track as those are already sorted.
     track_time_indices: Union[int, None] = [0] * song.number_tracks
@@ -75,6 +74,7 @@

     while True:
         # Get first note to be played
+        note_max_start = song.max_time
         first_note_idx = None
         for track_idx, note in enumerate(current_notes):
             if note is not None and note.note_start < note_max_start:
@@ -87,7 +87,8 @@

         # Add id to the time series data
         time_since_last = first_note.note_start - last_start if last_start is not None else 0
-        ts_data.append((first_note.norm_freq, first_note.duration, time_since_last))
+        ts_data.append((first_note.norm_freq if NORMALIZE_FREQ else first_note.freq,
+                        first_note.duration, time_since_last))
         last_start = first_note.note_start

         # Update that note and discard finished tracks
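Note on the ts_data.append change: each series entry now carries the raw frequency unless NORMALIZE_FREQ is set, in which case the pre-computed norm_freq is used instead. A minimal stand-alone sketch of that selection, assuming norm_freq is simply the frequency divided by MAX_FREQ_NOTE (the helper name encode_note is illustrative and not part of the repository):

MAX_FREQ_NOTE = 4186.0   # from constants.py
NORMALIZE_FREQ = False   # from constants.py

def encode_note(freq: float, duration: int, time_since_last: int) -> tuple:
    # Build the (frequency, duration, time-since-last) triple that csv_to_series appends.
    value = freq / MAX_FREQ_NOTE if NORMALIZE_FREQ else freq
    return (value, duration, time_since_last)

print(encode_note(440.0, 48, 12))  # (440.0, 48, 12) while NORMALIZE_FREQ is False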
src/dataset/preprocessing/reconstructor.py (5 changes: 4 additions & 1 deletion)
@@ -44,7 +44,10 @@ def parse_data(notes_data: NDArray) -> str:
     csv_data_tracks = [[f'{idx + 2}, 0, Start_track'] for idx in range(MAX_POLYPHONY)]

     for note_data in notes_data:
-        note = freq_to_note(float(note_data[0]) * MAX_FREQ_NOTE)
+        note = freq_to_note(float(note_data[0]) * (MAX_FREQ_NOTE if MAX_FREQ_NOTE else 1))
+        # if note == 0:
+        #     continue
+
         duration = int(note_data[1])
         time_since_previous = int(note_data[2])

src/model/Trainer.py (4 changes: 2 additions & 2 deletions)
@@ -63,7 +63,7 @@ def train(self):
             metric.print_metrics()

             if epoch % SAMPLE_STEPS == 0:
-                sample = self.generate_sample(length=500)
+                sample = self.generate_sample(length=5000)
                 reconstruct_midi(title=f'Sample_{epoch}', raw_data=sample)
                 # TODO: Visdom doesn't accept MIDI files.
                 # Should convert to WAV or find an alternative for visualization.
@@ -86,7 +86,7 @@ def generate_sample(self, length: int) -> NDArray:
         noise_data = GANGenerator.noise((1, length))
         sample_data = self.model.generator(noise_data)
         # sample_notes = sample_data.argmax(2)
-        return sample_data.view(-1, MAX_POLYPHONY).cpu().numpy()
+        return sample_data.view(-1, 3).cpu().numpy()

     def _pretrain_epoch(self, epoch: int) -> EpochMetric:
         """
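Note on generate_sample: the generator output is now flattened to three columns per timestep instead of MAX_POLYPHONY, and train() requests a much longer sample (5000 steps instead of 500). A rough, self-contained shape check with a random tensor standing in for the generator output, assuming the generator returns (batch, length, INPUT_FEATURES):

import torch

INPUT_FEATURES = 3   # from constants.py
length = 8           # illustrative; train() now uses length=5000

# Stand-in for self.model.generator(noise_data): one sequence, three values per step.
sample_data = torch.randn(1, length, INPUT_FEATURES)

sample = sample_data.view(-1, INPUT_FEATURES).cpu().numpy()
print(sample.shape)  # (8, 3): one (frequency, duration, time-since-last) row per step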
src/model/gan/GANDiscriminator.py (4 changes: 2 additions & 2 deletions)
@@ -2,7 +2,7 @@
 from torch import nn
 from torch.nn import functional as F

-from constants import LAYERS_D, HIDDEN_DIM_D, BIDIRECTIONAL_D, TYPE_D, MAX_POLYPHONY
+from constants import LAYERS_D, HIDDEN_DIM_D, BIDIRECTIONAL_D, TYPE_D, INPUT_FEATURES
 from model.gan.RNN import RNN
 from utils.tensors import device

@@ -17,7 +17,7 @@ def __init__(self):
         super(GANDiscriminator, self).__init__()

         self.rnn = RNN(architecture=TYPE_D,
-                       inp_dim=MAX_POLYPHONY,
+                       inp_dim=INPUT_FEATURES,
                        hid_dim=HIDDEN_DIM_D,
                        layers=LAYERS_D,
                        bidirectional=BIDIRECTIONAL_D).to(device)
src/model/gan/GANGenerator.py (6 changes: 3 additions & 3 deletions)
@@ -4,7 +4,7 @@
 from torch import nn
 from torch.nn import functional as F

-from constants import HIDDEN_DIM_G, LAYERS_G, BIDIRECTIONAL_G, TYPE_G, MAX_POLYPHONY
+from constants import HIDDEN_DIM_G, LAYERS_G, BIDIRECTIONAL_G, TYPE_G, INPUT_FEATURES
 from model.gan.RNN import RNN
 from utils.tensors import device

@@ -25,8 +25,8 @@ def __init__(self):
                        bidirectional=BIDIRECTIONAL_G).to(device)

         dense_input_features = (2 if BIDIRECTIONAL_G else 1) * HIDDEN_DIM_G
-        self.dense_1 = nn.Linear(in_features=dense_input_features, out_features=2 * MAX_POLYPHONY)
-        self.dense_2 = nn.Linear(in_features=2 * MAX_POLYPHONY, out_features=MAX_POLYPHONY)
+        self.dense_1 = nn.Linear(in_features=dense_input_features, out_features=2 * INPUT_FEATURES)
+        self.dense_2 = nn.Linear(in_features=2 * INPUT_FEATURES, out_features=INPUT_FEATURES)

     def forward(self, x):
         x, _ = self.rnn(x, )
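Note on the generator head: the two dense layers now narrow the recurrent features down to the 3 output features per timestep rather than MAX_POLYPHONY. A self-contained sketch of just that head, with the constants hard-coded to the values in constants.py (any activation applied between the layers in the real forward pass is omitted here):

import torch
from torch import nn

HIDDEN_DIM_G = 30
BIDIRECTIONAL_G = False
INPUT_FEATURES = 3

dense_input_features = (2 if BIDIRECTIONAL_G else 1) * HIDDEN_DIM_G
dense_1 = nn.Linear(in_features=dense_input_features, out_features=2 * INPUT_FEATURES)
dense_2 = nn.Linear(in_features=2 * INPUT_FEATURES, out_features=INPUT_FEATURES)

x = torch.randn(1, 16, dense_input_features)  # (batch, seq_len, RNN output features)
out = dense_2(dense_1(x))
print(out.shape)  # torch.Size([1, 16, 3])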
