Implement tempo augmentation function for TokenizerLazy #17

Merged
2 commits merged on Aug 8, 2023
Changes from 1 commit
added tempo_aug_fn to lazy tokenizer
loubbrad committed Aug 8, 2023
commit 70cdad8d84a2a57dcb17bf45f1f0957f3fd9d34e
ROADMAP.md (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@ As it stands, the basic functionality of the repository is implemented and tested
 * [ ] **Add chord mix-up data-augmentation function**
 
     This (tokenized) data-augmentation function should randomly shuffle the order of notes that occur concurrently. For instance, `("piano", 60, 50), ("dur", 10), ("piano", 64, 50), ("dur", 20)` could be augmented to `("piano", 64, 50), ("dur", 20), ("piano", 60, 50), ("dur", 10)` as there is no wait token between the notes. See `aria.tokenizer.TokenizerLazy.export_pitch_aug()` for an example of how to implement data-augmentation functions.
-* [ ] **Add speed data-augmentation function**
+* [x] **~~Add speed data-augmentation function~~**
 
     This data-augmentation function should change the speed of a tokenized sequence by some (float) factor. The main issue I foresee is accounting for the way that wait tokens are currently implemented. Depending on the `config.json`, the lazy tokenizer has a max wait token `("wait", t_max)`. Any 'wait' event longer than `t_max` is represented as a sequence of tokens. For instance, a wait of `2*t_max + 10ms` would be `("wait", t_max), ("wait", t_max), ("wait", 10)`.
 * [x] **~~Fix encode/decode disparity bug~~**
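The chord mix-up item above remains open in this commit. As a rough illustration of the idea (not code from the repository: the function name `chord_mixup_seq`, the strict note/dur pairing, and the handling of special tokens are all assumptions), here is a minimal sketch that shuffles `(note, dur)` pairs between wait tokens:

```python
import random

def chord_mixup_seq(src: list) -> list:
    """Randomly shuffle the order of notes that sound concurrently.

    Assumes concurrent notes appear as strictly alternating
    ("instrument", pitch, velocity), ("dur", t) pairs with no
    ("wait", t) token between them, as in the roadmap example.
    """

    def flush(group: list, out: list) -> None:
        # Split the group into (note, dur) pairs and emit them in
        # a random order.
        pairs = [group[i : i + 2] for i in range(0, len(group), 2)]
        random.shuffle(pairs)
        out.extend(tok for pair in pairs for tok in pair)

    out, group = [], []
    for tok in src:
        if isinstance(tok, str) or tok[0] == "wait":
            # A wait or special token ends the concurrent group.
            flush(group, out)
            group = []
            out.append(tok)
        else:
            group.append(tok)
    flush(group, out)  # handle a trailing group
    return out
```

Applied to the roadmap's example, `chord_mixup_seq` returns the two `(note, dur)` pairs in either order, which is exactly the augmentation described.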
aria/tokenizer/tokenizer.py (114 changes: 105 additions & 9 deletions)
@@ -4,9 +4,9 @@
 import torch
 import functools
 import itertools
+import random
 
 from collections import defaultdict
-from random import randint
 from mido.midifiles.units import second2tick, tick2second
 
 from aria.data.midi import MidiDict
@@ -526,7 +526,7 @@ def pitch_aug_tok(tok, _pitch_aug):
                 else:
                     return unk_tok
 
-            pitch_aug = randint(-_aug_range, _aug_range)
+            pitch_aug = random.randint(-_aug_range, _aug_range)
             return [pitch_aug_tok(x, pitch_aug) for x in src]
 
         # See functools.partial docs
@@ -554,7 +554,7 @@ def velocity_aug_seq(
             src: list,
             velocity_step: int,
             max_velocity: int,
-            _aug_steps_range: float,
+            _aug_steps_range: int,
         ):
             def velocity_aug_tok(tok, _velocity_aug):
                 if isinstance(tok, str):
@@ -582,7 +582,7 @@ def velocity_aug_tok(tok, _velocity_aug):
 
                 return (_instrument, _pitch, _velocity + _velocity_aug)
 
-            velocity_aug = velocity_step * randint(
+            velocity_aug = velocity_step * random.randint(
                 -_aug_steps_range, _aug_steps_range
             )
             return [velocity_aug_tok(x, velocity_aug) for x in src]
@@ -595,11 +595,107 @@ def velocity_aug_tok(tok, _velocity_aug):
             _aug_steps_range=aug_steps_range,
         )
 
-    # TODO: Implement - follow export_pitch aug
-    # Also implement order mix up for chords
-    def export_time_aug(self):
-        # Remember special case where we have max_time_step
-        raise NotImplementedError
+    def export_tempo_aug(self, tempo_aug_range: float):
+        def tempo_aug_seq(
+            src: list,
+            min_time_step: int,
+            max_time_step: int,
+            pad_tok: str,
+            _tempo_aug_range: float,
+        ):
+            def tempo_aug_tok_raw(tok, _tempo_aug):
+                if isinstance(tok, str):
+                    _tok_type = "special"
+                else:
+                    _tok_type = tok[0]
+
+                if _tok_type == "wait" or _tok_type == "dur":
+                    (__tok_type, _dur) = tok
+
+                    return (
+                        __tok_type,
+                        min_time_step
+                        * int(round(float(_tempo_aug * _dur) / min_time_step)),
+                    )
+                else:
+                    # Return without changing
+                    return tok
+
+            tempo_aug = random.uniform(
+                1 - _tempo_aug_range, 1 + _tempo_aug_range
+            )
+            augmented_seq = [tempo_aug_tok_raw(x, tempo_aug) for x in src]
+
+            # Recalculate dur and wait tokens so that they are correctly
+            # formatted after naive augmentation.
+            initial_seq_len = len(augmented_seq)
+            idx = 0
+            buffer = []
+            while idx < len(augmented_seq):
+                tok = augmented_seq[idx]
+                if isinstance(tok, str):
+                    tok_type = "special"
+                else:
+                    tok_type = tok[0]
+
+                # Get tok_type of next token if possible
+                if idx + 1 < len(augmented_seq):
+                    next_tok = augmented_seq[idx + 1]
+                    if isinstance(next_tok, str):
+                        next_tok_type = "special"
+                    else:
+                        next_tok_type = next_tok[0]
+                else:
+                    next_tok_type = None
+
+                # If necessary add wait token to the buffer
+                if tok_type == "wait":
+                    # Overflow
+                    if buffer or tok[1] >= max_time_step:
+                        buffer.append(augmented_seq.pop(idx))
+                    # Underflow
+                    elif next_tok_type == "wait":
+                        buffer.append(augmented_seq.pop(idx))
+                    else:
+                        idx += 1
+
+                # Current tok not wait token so if the buffer is not empty
+                # recalculate and reinsert wait tokens in the buffer.
+                elif buffer:
+                    buffer_remaining_dur = sum(_tok[1] for _tok in buffer)
+
+                    while buffer_remaining_dur > max_time_step:
+                        augmented_seq.insert(idx, ("wait", max_time_step))
+                        buffer_remaining_dur -= max_time_step
+                        idx += 1
+
+                    augmented_seq.insert(idx, ("wait", buffer_remaining_dur))
+                    buffer = []
+                    idx += 1
+
+                # If dur token has overflowed, truncate at max_time_step
+                elif tok_type == "dur":
+                    if tok[1] > max_time_step:
+                        augmented_seq[idx] = ("dur", max_time_step)
+                    idx += 1
+
+                else:
+                    idx += 1
+
+            # Pad or truncate to original sequence length as necessary
+            augmented_seq = augmented_seq[:initial_seq_len]
+            augmented_seq += [pad_tok] * (initial_seq_len - len(augmented_seq))
+
+            return augmented_seq
+
+        # See functools.partial docs
+        return functools.partial(
+            tempo_aug_seq,
+            min_time_step=self.min_time_step,
+            max_time_step=self.max_time_step,
+            pad_tok=self.pad_tok,
+            _tempo_aug_range=tempo_aug_range,
+        )
 
 
 def _get_duration_ms(
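Like `export_pitch_aug` and `export_velocity_aug`, the new method returns a `functools.partial` over `tempo_aug_seq`, so the exported callable can be applied directly to token sequences or passed to a dataset transform. A usage sketch (the constructor arguments and the literal wait/dur values are illustrative assumptions, not the repository's exact API):

```python
from aria.tokenizer import TokenizerLazy

tknzr = TokenizerLazy(return_tensors=True)  # constructor args assumed
tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.2)

# Each ("wait", t) and ("dur", t) value is scaled by a factor drawn
# uniformly from [0.8, 1.2], snapped to a multiple of min_time_step,
# then re-chunked so that no single wait token exceeds max_time_step.
# The result is padded or truncated back to the input length.
seq = [("piano", 60, 50), ("dur", 100), ("wait", 5000), "<E>"]
augmented = tempo_aug_fn(seq)
```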
aria/training.py (6 changes: 5 additions & 1 deletion)
@@ -147,7 +147,11 @@ def pretrain(
 
     if overfit is False:
         train_dataset.set_transform(
-            [tokenizer.export_velocity_aug(2), tokenizer.export_pitch_aug(4)]
+            [
+                tokenizer.export_velocity_aug(2),
+                tokenizer.export_pitch_aug(4),
+                tokenizer.export_tempo_aug(0.15),
+            ]
         )
 
     train_dataloader = DataLoader(
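The three exported augmentation functions are applied as dataset transforms. A minimal sketch of the composition this implies, assuming `set_transform` applies each function in order to every tokenized sequence (the dataset internals are not part of this diff):

```python
def apply_transforms(seq: list, transforms: list) -> list:
    # Feed the token sequence through each augmentation in turn;
    # every transform maps a sequence to an equal-length sequence.
    for fn in transforms:
        seq = fn(seq)
    return seq
```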
tests/test_tokenizers.py (10 changes: 10 additions & 0 deletions)
@@ -17,6 +17,12 @@ def get_short_seq(tknzr: tokenizer.TokenizerLazy):
         ("wait", tknzr._quantize_time(100)),
         ("drum", tknzr._quantize_time(50)),
         ("piano", 64, tknzr._quantize_velocity(70)),
+        ("dur", tknzr._quantize_time(1000000)),
+        ("wait", tknzr._quantize_time(1000000)),
+        ("wait", tknzr._quantize_time(1000000)),
+        ("wait", tknzr._quantize_time(1000000)),
+        ("wait", tknzr._quantize_time(100)),
+        ("piano", 65, tknzr._quantize_velocity(70)),
         ("dur", tknzr._quantize_time(100)),
         ("wait", tknzr._quantize_time(100)),
         "<E>",
@@ -52,6 +58,7 @@ def test_aug(self):
         seq = get_short_seq(tknzr)
         pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
         velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
+        tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5)
 
         seq_pitch_augmented = pitch_aug_fn(get_short_seq(tknzr))
         logging.info(f"pitch_aug_fn:\n{seq} ->\n{seq_pitch_augmented}")
@@ -67,6 +74,9 @@
             seq_velocity_augmented[7][2] - seq[7][2],
         )
 
+        seq_tempo_augmented = tempo_aug_fn(get_short_seq(tknzr))
+        logging.info(f"tempo_aug_fn:\n{seq} ->\n{seq_tempo_augmented}")
+
     def test_encode_decode(self):
         tknzr = tokenizer.TokenizerLazy(
             return_tensors=True,