Skip to content

Commit

Permalink
Add form metadata tags (#90)
Browse files Browse the repository at this point in the history
* add form metadata tags

* format
  • Loading branch information
loubbrad committed Jan 17, 2024
1 parent 00d9c85 commit e7a6b7c
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 15 deletions.
4 changes: 1 addition & 3 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,8 +832,6 @@ def _build(_midi_dataset):

_build(_midi_dataset=midi_dataset)

logger.info(
f"Finished building, saved Finetuning to {save_path}"
)
logger.info(f"Finished building, saved Finetuning to {save_path}")

return cls(file_path=save_path, tokenizer=tokenizer)
23 changes: 20 additions & 3 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# TODO:
# - Possibly refactor names 'mid' to 'midi'
# - When pedal goes on after note on - this leads to it being played incorrectly.


class MidiDict:
Expand Down Expand Up @@ -556,7 +557,7 @@ def get_duration_ms(
return duration


def _match_composer(text: str, composer_name: str):
def _match_word(text: str, composer_name: str):
# If name="bach" this pattern will match "bach", "Bach" or "BACH" if
# it is either preceded or followed by a "_" or " ".
pattern = (
Expand All @@ -581,7 +582,7 @@ def meta_composer_filename(
file_name = pathlib.Path(mid.filename).stem
matched_names = set()
for name in composer_names:
if _match_composer(file_name, name):
if _match_word(file_name, name):
matched_names.add(name)

# Only return data if only one composer is found
Expand All @@ -592,13 +593,28 @@ def meta_composer_filename(
return {}


def meta_form_filename(mid: mido.MidiFile, msg_data: dict, form_names: list):
    """Infer a piece's musical form from the name of its MIDI file.

    Each entry of ``form_names`` is tested against the file stem (the file
    name without its directory or extension) using ``_match_word``.

    Args:
        mid: MIDI file whose ``filename`` attribute is inspected.
        msg_data: Not used by this function; presumably accepted so the
            signature lines up with the other ``meta_*`` metadata functions
            (TODO confirm against the caller).
        form_names: Candidate form names to look for in the file stem.

    Returns:
        ``{"form": name}`` when exactly one distinct name matches, otherwise
        an empty dict.
    """
    stem = pathlib.Path(mid.filename).stem
    # A set collapses duplicate entries in form_names into a single hit.
    hits = {form for form in form_names if _match_word(stem, form)}

    # Zero matches carry no information and several matches are ambiguous,
    # so only a single unambiguous form is reported.
    if len(hits) != 1:
        return {}

    (form,) = hits
    return {"form": form}


def meta_composer_metamsg(
mid: mido.MidiFile, msg_data: dict, composer_names: list
):
matched_names = set()
for msg in msg_data["meta_msgs"]:
for name in composer_names:
if _match_composer(msg["data"], name):
if _match_word(msg["data"], name):
matched_names.add(name)

# Only return data if only one composer is found
Expand All @@ -614,6 +630,7 @@ def get_metadata_fn(metadata_proc_name: str):
name_to_fn = {
"composer_filename": meta_composer_filename,
"composer_metamsg": meta_composer_metamsg,
"form_filename": meta_form_filename,
}

fn = name_to_fn.get(metadata_proc_name, None)
Expand Down
1 change: 1 addition & 0 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# TODO: Add which instruments were detected in the prompt


def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
if cfg_mode is None:
return cfg_gamma
Expand Down
12 changes: 12 additions & 0 deletions aria/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def __init__(self, return_tensors: bool = False):
self.prefix_tokens += [
("prefix", "composer", x) for x in self.composer_names
]
self.form_names = self.config["form_names"]
self.prefix_tokens += [("prefix", "form", x) for x in self.form_names]

# Build vocab
self.time_tok = "<T>"
Expand Down Expand Up @@ -389,6 +391,10 @@ def tokenize_midi_dict(self, midi_dict: MidiDict):
if composer and (composer in self.composer_names):
prefix.insert(0, ("prefix", "composer", composer))

form = midi_dict.metadata.get("form")
if form and (form in self.form_names):
prefix.insert(0, ("prefix", "form", form))

# NOTE: Any preceding silence is removed implicitly
tokenized_seq = []
initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"]
Expand Down Expand Up @@ -919,6 +925,8 @@ def __init__(self, return_tensors: bool = False):
self.prefix_tokens += [
("prefix", "composer", x) for x in self.composer_names
]
self.form_names = self.config["form_names"]
self.prefix_tokens += [("prefix", "form", x) for x in self.form_names]

# Build vocab
self.wait_tokens = [("wait", i) for i in self.time_step_quantizations]
Expand Down Expand Up @@ -1012,6 +1020,10 @@ def tokenize_midi_dict(self, midi_dict: MidiDict):
if composer and (composer in self.composer_names):
prefix.insert(0, ("prefix", "composer", composer))

form = midi_dict.metadata.get("form")
if form and (form in self.form_names):
prefix.insert(0, ("prefix", "form", form))

# NOTE: Any preceding silence is removed implicitly
tokenized_seq = []
num_notes = len(midi_dict.note_msgs)
Expand Down
12 changes: 10 additions & 2 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@
"args": {
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"]
}
},
"form_filename": {
"run": true,
"args": {
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
}
}
}
},
Expand Down Expand Up @@ -119,7 +125,8 @@
"num_steps": 500,
"step": 10
},
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"]
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
},
"abs": {
"ignore_instruments": {
Expand Down Expand Up @@ -165,7 +172,8 @@
"abs_time_step_ms": 5000,
"max_dur_ms": 5000,
"time_step_ms": 10,
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"]
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
}
}
}
File renamed without changes.
16 changes: 9 additions & 7 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def tokenize_detokenize(file_name: str):
tknzr = tokenizer.AbsTokenizer(return_tensors=False)
tokenize_detokenize("basic.mid")
tokenize_detokenize("arabesque.mid")
tokenize_detokenize("beethoven.mid")
tokenize_detokenize("beethoven_sonata.mid")
tokenize_detokenize("bach.mid")
tokenize_detokenize("expressive.mid")
tokenize_detokenize("pop.mid")
Expand Down Expand Up @@ -249,7 +249,7 @@ def tokenize_aug_detokenize(
logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n")
tokenize_aug_detokenize("basic.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("arabesque.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("beethoven.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("beethoven_sonata.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("bach.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("expressive.mid", pitch_aug_fn, "pitch")
tokenize_aug_detokenize("pop.mid", pitch_aug_fn, "pitch")
Expand All @@ -264,7 +264,9 @@ def tokenize_aug_detokenize(
)
tokenize_aug_detokenize("basic.mid", velocity_aug_fn, "velocity")
tokenize_aug_detokenize("arabesque.mid", velocity_aug_fn, "velocity")
tokenize_aug_detokenize("beethoven.mid", velocity_aug_fn, "velocity")
tokenize_aug_detokenize(
"beethoven_sonata.mid", velocity_aug_fn, "velocity"
)
tokenize_aug_detokenize("bach.mid", velocity_aug_fn, "velocity")
tokenize_aug_detokenize("expressive.mid", velocity_aug_fn, "velocity")
tokenize_aug_detokenize("pop.mid", velocity_aug_fn, "velocity")
Expand All @@ -283,7 +285,7 @@ def tokenize_aug_detokenize(

tokenize_aug_detokenize("basic.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("arabesque.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("beethoven.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("beethoven_sonata.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("bach.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("expressive.mid", tempo_aug_fn, "tempo")
tokenize_aug_detokenize("pop.mid", tempo_aug_fn, "tempo")
Expand All @@ -293,7 +295,7 @@ def tokenize_aug_detokenize(

def test_aug_time(self):
tknzr = tokenizer.AbsTokenizer()
mid_dict = MidiDict.from_midi("tests/test_data/beethoven.mid")
mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid")
tokenized_seq = tknzr.tokenize(mid_dict)[:4096]
pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
Expand Down Expand Up @@ -352,7 +354,7 @@ def tokenize_detokenize(file_name: str):

tokenize_detokenize("basic.mid")
tokenize_detokenize("arabesque.mid")
tokenize_detokenize("beethoven.mid")
tokenize_detokenize("beethoven_sonata.mid")
tokenize_detokenize("bach.mid")
tokenize_detokenize("expressive.mid")
tokenize_detokenize("pop.mid")
Expand Down Expand Up @@ -405,7 +407,7 @@ def test_aug(self):

def test_aug_time(self):
tknzr = tokenizer.RelTokenizer()
mid_dict = MidiDict.from_midi("tests/test_data/beethoven.mid")
mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid")
tokenized_seq = tknzr.tokenize(mid_dict)[:4096]

pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
Expand Down

0 comments on commit e7a6b7c

Please sign in to comment.