
Add composer prefix tokens to TokenizerLazy (#23)
* added composer prefix tokens

* fix composer name config usage

* fix roadmap
loubbrad committed Aug 12, 2023
1 parent 895fa2f commit 1786de2
Showing 6 changed files with 51 additions and 33 deletions.
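
For context: the sequence prefix changes from ("prefix", value) pairs to ("prefix", category, value) triples, and TokenizerLazy now also emits a composer tag when the MidiDict metadata names one of the composers listed in config.json. Below is a minimal round-trip sketch mirroring the updated tests; the exact import path is an assumption, and only values listed under "composer_names" are recognised.

from aria import tokenizer  # import path assumed; the tests call tokenizer.TokenizerLazy()

tknzr = tokenizer.TokenizerLazy()

# Prefix tokens are now ("prefix", <category>, <value>) triples.
seq = [
    ("prefix", "instrument", "piano"),
    ("prefix", "instrument", "drum"),
    ("prefix", "composer", "bach"),  # "bach" appears in config.json "composer_names"
    "<S>",
]

# Composer tokens are part of the vocabulary, so encoding and decoding should
# never fall back to the unknown token (this is what test_no_unk_token checks).
enc_dec_seq = tknzr.decode(tknzr.encode(seq))
assert all(tok != tknzr.unk_tok for tok in enc_dec_seq)
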
ROADMAP.md (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ As it stands, the basic functionality of the repository is implemented and tested
 * [x] **~~Add further pre-processing tests~~**
 
   Add further MidiDict pre-processing tests to improve dataset quality. Some ideas are checking for the frequency of note messages (an average of > 15 p/s or < 2 p/s is a bad sign). I'm open to any suggestions for MidiDict preprocessing tests. Properly cleaning pre-training datasets has a huge effect on model quality and robustness.
-* [ ] **Add meta-token prefix support for LazyTokenizer**
+* [x] **~~Add meta-token prefix support for LazyTokenizer~~**
 
   Investigate the possibility of adding meta-tokens to the prefix in LazyTokenizer. Some examples could be genre, composer, or data source tags. This might require a rewrite of how sequence prefixes are handled.
 * [x] **~~Add 'ending soon' token to lazy tokenizer~~**
aria/data/midi.py (2 changes: 1 addition & 1 deletion)
@@ -85,7 +85,7 @@ def __init__(
         instrument_msgs: list,
         note_msgs: list,
         ticks_per_beat: int,
-        metadata: list,
+        metadata: dict,
     ):
         self.meta_msgs = meta_msgs
         self.tempo_msgs = tempo_msgs
aria/tokenizer/tokenizer.py (44 changes: 26 additions & 18 deletions)
@@ -88,17 +88,6 @@ def _dec_fn(id):
         return decoded_seq
 
 
-# TODO: ADD META TOKEN FUNCTIONALITY
-# - Refactor meta tokens so they are a triple ("meta", "val")
-# this should make using tok_type stuff easier. This refactor should touch a
-# lot of code so take care here.
-# - Add prefix class function to lazy tokenizer that calculates meta tokens.
-# This prefix class should call various functions according to config.json.
-# - One function could be doing regex on the meta messages, looking for
-# composer names. If one and only one composer name is found then it is added
-# to the prefix before the instruments. We could specify the list of
-# composers we are interested in in the config.json.
-# - By loading according to the config.json we could extend this easily.
 class TokenizerLazy(Tokenizer):
     """Lazy MidiDict Tokenizer"""
 
@@ -133,7 +122,15 @@ def __init__(
             if v is False
         ]
         self.instruments_wd = self.instruments_nd + ["drum"]
-        self.prefix_tokens = [("prefix", x) for x in self.instruments_wd]
+
+        # Prefix tokens
+        self.prefix_tokens = [
+            ("prefix", "instrument", x) for x in self.instruments_wd
+        ]
+        self.composer_names = self.config["composer_names"]
+        self.prefix_tokens += [
+            ("prefix", "composer", x) for x in self.composer_names
+        ]
 
         # Build vocab
         self.special_tokens = [
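
For reference, with the "composer_names" entry added to config/config.json below, the prefix vocabulary gains one token per configured composer on top of the per-instrument tokens. A standalone sketch of the same construction (the instrument list here is an illustrative subset):

instruments_wd = ["piano", "guitar", "strings", "drum"]  # illustrative subset
composer_names = ["beethoven", "bach", "mozart", "chopin"]  # matches config/config.json

prefix_tokens = [("prefix", "instrument", x) for x in instruments_wd]
prefix_tokens += [("prefix", "composer", x) for x in composer_names]

print(len(prefix_tokens))  # 8
print(prefix_tokens[-1])   # ('prefix', 'composer', 'chopin')
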
@@ -272,9 +269,16 @@ def tokenize_midi_dict(self, midi_dict: MidiDict):
                 channel_to_instrument[c] = "piano"
 
         # Add non-drums to present_instruments (prefix)
-        prefix = [("prefix", x) for x in set(channel_to_instrument.values())]
+        prefix = [
+            ("prefix", "instrument", x)
+            for x in set(channel_to_instrument.values())
+        ]
         if 9 in channels_used:
-            prefix.append(("prefix", "drum"))
+            prefix.append(("prefix", "instrument", "drum"))
+
+        composer = midi_dict.metadata.get("composer")
+        if composer and (composer in self.composer_names):
+            prefix.insert(0, ("prefix", "composer", composer))
 
         # NOTE: Any preceding silence is removed implicitly
         tokenized_seq = []
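
To make the resulting prefix order concrete, here is a self-contained sketch of the logic added above; the helper name and its arguments are illustrative, but the list construction mirrors the diff: a recognised composer is inserted ahead of the instrument tokens.

def build_prefix(channel_to_instrument, channels_used, metadata, composer_names):
    # Instruments first, then the drum flag, then a recognised composer at the front.
    prefix = [
        ("prefix", "instrument", x)
        for x in set(channel_to_instrument.values())
    ]
    if 9 in channels_used:
        prefix.append(("prefix", "instrument", "drum"))

    composer = metadata.get("composer")
    if composer and (composer in composer_names):
        prefix.insert(0, ("prefix", "composer", composer))
    return prefix


print(build_prefix(
    channel_to_instrument={0: "piano"},
    channels_used={0, 9},
    metadata={"composer": "bach"},
    composer_names=["beethoven", "bach", "mozart", "chopin"],
))
# [('prefix', 'composer', 'bach'), ('prefix', 'instrument', 'piano'), ('prefix', 'instrument', 'drum')]
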
@@ -378,20 +382,24 @@ def detokenize_midi_dict(self, tokenized_seq: list):
             if tok in self.special_tokens:
                 continue
             # Non-drum instrument prefix tok
-            elif tok[0] == "prefix" and tok[1] in self.instruments_nd:
+            elif (
+                tok[0] == "prefix"
+                and tok[1] == "instrument"
+                and tok[2] in self.instruments_nd
+            ):
                 if tok[1] in instrument_to_channel.keys():
-                    logging.warning(f"Duplicate prefix {tok[1]}")
+                    logging.warning(f"Duplicate prefix {tok[2]}")
                     continue
                 else:
                     instrument_msgs.append(
                         {
                             "type": "instrument",
-                            "data": instrument_programs[tok[1]],
+                            "data": instrument_programs[tok[2]],
                             "tick": 0,
                             "channel": channel_idx,
                         }
                     )
-                    instrument_to_channel[tok[1]] = channel_idx
+                    instrument_to_channel[tok[2]] = channel_idx
                     channel_idx += 1
             # Catches all other prefix tokens
             elif tok[0] == "prefix":
config/config.json (3 changes: 2 additions & 1 deletion)
@@ -113,7 +113,8 @@
             "time_quantization": {
                 "num_steps": 500,
                 "min_step": 10
-            }
+            },
+            "composer_names": ["beethoven", "bach", "mozart", "chopin"]
         }
     }
 }
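
Note that the composer lookup added in tokenize_midi_dict is a plain, case-sensitive membership test against this list, so MidiDict metadata has to use the same lowercase spellings as config.json. A quick illustration:

composer_names = ["beethoven", "bach", "mozart", "chopin"]

print("Bach" in composer_names)  # False -> no composer prefix token is added
print("bach" in composer_names)  # True  -> ("prefix", "composer", "bach") is prepended
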
tests/test_data.py (13 changes: 7 additions & 6 deletions)
@@ -11,8 +11,9 @@
 
 def get_short_seq():
     return [
-        ("prefix", "piano"),
-        ("prefix", "drum"),
+        ("prefix", "instrument", "piano"),
+        ("prefix", "instrument", "drum"),
+        ("prefix", "composer", "bach"),
         "<S>",
         ("piano", 62, 50),
         ("dur", 50),
@@ -164,12 +165,12 @@ def test_augmentation(self):
 
         logging.info(f"aug:\n{seq} ->\n{seq_augmented}")
         self.assertEqual(
-            seq_augmented[3][1] - seq[3][1],
-            seq_augmented[7][1] - seq[7][1],
+            seq_augmented[4][1] - seq[4][1],
+            seq_augmented[8][1] - seq[8][1],
         )
         self.assertEqual(
-            seq_augmented[3][2] - seq[3][2],
-            seq_augmented[7][2] - seq[7][2],
+            seq_augmented[4][2] - seq[4][2],
+            seq_augmented[8][2] - seq[8][2],
         )
 
         tokenized_dataset.close()
tests/test_tokenizers.py (20 changes: 14 additions & 6 deletions)
@@ -9,8 +9,9 @@
 
 def get_short_seq(tknzr: tokenizer.TokenizerLazy):
     return [
-        ("prefix", "piano"),
-        ("prefix", "drum"),
+        ("prefix", "instrument", "piano"),
+        ("prefix", "instrument", "drum"),
+        ("prefix", "composer", "bach"),
         "<S>",
         ("piano", 62, tknzr._quantize_velocity(50)),
         ("dur", tknzr._quantize_time(50)),
@@ -63,15 +64,15 @@ def test_aug(self):
         seq_pitch_augmented = pitch_aug_fn(get_short_seq(tknzr))
         logging.info(f"pitch_aug_fn:\n{seq} ->\n{seq_pitch_augmented}")
         self.assertEqual(
-            seq_pitch_augmented[3][1] - seq[3][1],
-            seq_pitch_augmented[7][1] - seq[7][1],
+            seq_pitch_augmented[4][1] - seq[4][1],
+            seq_pitch_augmented[8][1] - seq[8][1],
         )
 
         seq_velocity_augmented = velocity_aug_fn(get_short_seq(tknzr))
         logging.info(f"velocity_aug_fn:\n{seq} ->\n{seq_velocity_augmented}")
         self.assertEqual(
-            seq_velocity_augmented[3][2] - seq[3][2],
-            seq_velocity_augmented[7][2] - seq[7][2],
+            seq_velocity_augmented[4][2] - seq[4][2],
+            seq_velocity_augmented[8][2] - seq[8][2],
         )
 
         seq_tempo_augmented = tempo_aug_fn(get_short_seq(tknzr))
@@ -94,6 +95,13 @@ def test_encode_decode(self):
         for x, y in zip(seq, enc_dec_seq):
             self.assertEqual(x, y)
 
+    def test_no_unk_token(self):
+        tknzr = tokenizer.TokenizerLazy()
+        seq = get_short_seq(tknzr)
+        enc_dec_seq = tknzr.decode(tknzr.encode(seq))
+        for tok in enc_dec_seq:
+            self.assertTrue(tok != tknzr.unk_tok)
+
 
 if __name__ == "__main__":
     if os.path.isdir("tests/test_results") is False:
