Add composer prefix tokens to TokenizerLazy #23

Merged (3 commits, Aug 12, 2023)
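This PR widens prefix tokens from pairs to triples of the form ("prefix", <category>, <value>) and adds an optional composer meta-token to the sequence prefix, read from MidiDict.metadata["composer"] and filtered against the "composer_names" whitelist in config/config.json. Roughly, a tokenized sequence now starts like the sketch below (token values and ordering are illustrative, adapted loosely from the updated test fixtures, not output from a real file):

```python
# Illustrative prefix layout after this change (not real tokenizer output)
seq = [
    ("prefix", "composer", "bach"),     # new: only emitted if the composer is whitelisted
    ("prefix", "instrument", "piano"),  # instrument prefixes now carry an explicit category
    ("prefix", "instrument", "drum"),
    "<S>",
    ("piano", 62, 50),
    ("dur", 50),
    # ... remaining note tokens elided
]
```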
2 changes: 1 addition & 1 deletion ROADMAP.md
@@ -25,7 +25,7 @@ As it stands, the basic functionality of the repository is implemented and teste
* [x] **~~Add further pre-processing tests~~**

Add further MidiDict pre-processing tests to improve dataset quality. Some ideas are checking for the frequency of note messages (an average of > 15 p/s or < 2 p/s is a bad sign). I'm open to any suggestions for MidiDict preprocessing tests. Properly cleaning pre-training datasets has a huge effect on model quality and robustness.
* [ ] **Add meta-token prefix support for LazyTokenizer**
* [x] **~~Add meta-token prefix support for LazyTokenizer~~**

Investigate the possibility of adding meta-tokens to the prefix in LazyTokenizer. Some examples could be genre, composer, or data source tags. This might require a rewrite of how sequence prefixes are handled.
* [x] **~~Add 'ending soon' token to lazy tokenizer~~**
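The note-frequency check mentioned in the ROADMAP excerpt above is separate from this PR, but the heuristic it describes is straightforward. A rough sketch of the idea (the function name and its inputs are assumptions, not the repository's API):

```python
# Sketch of the ROADMAP's note-density heuristic; names and inputs are hypothetical.
def note_density_ok(note_msgs: list, duration_s: float,
                    min_nps: float = 2.0, max_nps: float = 15.0) -> bool:
    """Reject files whose average note rate looks implausible (< 2 or > 15 notes/s)."""
    if duration_s <= 0:
        return False
    notes_per_second = len(note_msgs) / duration_s
    return min_nps <= notes_per_second <= max_nps
```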
2 changes: 1 addition & 1 deletion aria/data/midi.py
@@ -85,7 +85,7 @@ def __init__(
instrument_msgs: list,
note_msgs: list,
ticks_per_beat: int,
metadata: list,
metadata: dict,
):
self.meta_msgs = meta_msgs
self.tempo_msgs = tempo_msgs
44 changes: 26 additions & 18 deletions aria/tokenizer/tokenizer.py
@@ -88,17 +88,6 @@ def _dec_fn(id):
return decoded_seq


# TODO: ADD META TOKEN FUNCTIONALITY
# - Refactor meta tokens so they are a triple ("meta", "val")
# this should make using tok_type stuff easier. This refactor should touch a
# lot of code so take care here.
# - Add prefix class function to lazy tokenizer that calculates meta tokens.
# This prefix class should call various functions according to config.json.
# - One function could be doing regex on the meta messages, looking for
# composer names. If one and only one composer name is found then it is added
# to the prefix before the instruments. We could specify the list of
# composers we are interested in in the config.json.
# - By loading according to the config.json we could extend this easily.
class TokenizerLazy(Tokenizer):
"""Lazy MidiDict Tokenizer"""

@@ -133,7 +122,15 @@ def __init__(
if v is False
]
self.instruments_wd = self.instruments_nd + ["drum"]
self.prefix_tokens = [("prefix", x) for x in self.instruments_wd]

# Prefix tokens
self.prefix_tokens = [
("prefix", "instrument", x) for x in self.instruments_wd
]
self.composer_names = self.config["composer_names"]
self.prefix_tokens += [
("prefix", "composer", x) for x in self.composer_names
]

# Build vocab
self.special_tokens = [
@@ -272,9 +269,16 @@ def tokenize_midi_dict(self, midi_dict: MidiDict):
channel_to_instrument[c] = "piano"

# Add non-drums to present_instruments (prefix)
prefix = [("prefix", x) for x in set(channel_to_instrument.values())]
prefix = [
("prefix", "instrument", x)
for x in set(channel_to_instrument.values())
]
if 9 in channels_used:
prefix.append(("prefix", "drum"))
prefix.append(("prefix", "instrument", "drum"))

composer = midi_dict.metadata.get("composer")
if composer and (composer in self.composer_names):
prefix.insert(0, ("prefix", "composer", composer))

# NOTE: Any preceding silence is removed implicitly
tokenized_seq = []
@@ -378,20 +382,24 @@ def detokenize_midi_dict(self, tokenized_seq: list):
if tok in self.special_tokens:
continue
# Non-drum instrument prefix tok
elif tok[0] == "prefix" and tok[1] in self.instruments_nd:
elif (
tok[0] == "prefix"
and tok[1] == "instrument"
and tok[2] in self.instruments_nd
):
if tok[1] in instrument_to_channel.keys():
logging.warning(f"Duplicate prefix {tok[1]}")
logging.warning(f"Duplicate prefix {tok[2]}")
continue
else:
instrument_msgs.append(
{
"type": "instrument",
"data": instrument_programs[tok[1]],
"data": instrument_programs[tok[2]],
"tick": 0,
"channel": channel_idx,
}
)
instrument_to_channel[tok[1]] = channel_idx
instrument_to_channel[tok[2]] = channel_idx
channel_idx += 1
# Catches all other prefix tokens
elif tok[0] == "prefix":
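The composer handling added to tokenize_midi_dict above is deliberately conservative: a composer that is missing from the metadata, or present but not whitelisted, contributes nothing to the prefix. Restated in isolation as a sketch (this helper does not exist in the repository; it just mirrors the guard in the diff):

```python
# Hypothetical standalone restatement of the guard added in tokenize_midi_dict.
def composer_prefix(metadata: dict, composer_names: list) -> list:
    composer = metadata.get("composer")
    if composer and composer in composer_names:
        return [("prefix", "composer", composer)]
    return []  # missing or non-whitelisted composers are silently skipped

assert composer_prefix({"composer": "bach"}, ["bach", "mozart"]) == [("prefix", "composer", "bach")]
assert composer_prefix({"composer": "satie"}, ["bach", "mozart"]) == []
assert composer_prefix({}, ["bach", "mozart"]) == []
```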
3 changes: 2 additions & 1 deletion config/config.json
@@ -113,7 +113,8 @@
"time_quantization": {
"num_steps": 500,
"min_step": 10
}
},
"composer_names": ["beethoven", "bach", "mozart", "chopin"]
}
}
}
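Because the whitelist lives in config.json, every name listed under "composer_names" gets its own ("prefix", "composer", <name>) token in the tokenizer's vocabulary, so extending composer support is a config edit rather than a tokenizer-code change (though it does change the vocabulary). A quick sanity check along those lines; the import path is an assumption based on how the tests reference the tokenizer:

```python
# Illustrative check, assuming TokenizerLazy is importable this way (as the tests suggest).
from aria.tokenizer import TokenizerLazy

tknzr = TokenizerLazy()
for name in tknzr.composer_names:  # loaded from config.json's "composer_names"
    assert ("prefix", "composer", name) in tknzr.prefix_tokens
```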
13 changes: 7 additions & 6 deletions tests/test_data.py
@@ -11,8 +11,9 @@

def get_short_seq():
return [
("prefix", "piano"),
("prefix", "drum"),
("prefix", "instrument", "piano"),
("prefix", "instrument", "drum"),
("prefix", "composer", "bach"),
"<S>",
("piano", 62, 50),
("dur", 50),
@@ -164,12 +165,12 @@ def test_augmentation(self):

logging.info(f"aug:\n{seq} ->\n{seq_augmented}")
self.assertEqual(
seq_augmented[3][1] - seq[3][1],
seq_augmented[7][1] - seq[7][1],
seq_augmented[4][1] - seq[4][1],
seq_augmented[8][1] - seq[8][1],
)
self.assertEqual(
seq_augmented[3][2] - seq[3][2],
seq_augmented[7][2] - seq[7][2],
seq_augmented[4][2] - seq[4][2],
seq_augmented[8][2] - seq[8][2],
)

tokenized_dataset.close()
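The only change in these assertions is the index shift (3 to 4 and 7 to 8): the extra composer token in get_short_seq's prefix pushes every later token one position to the right. For orientation, a naive sketch of the kind of transposition the pitch augmentation applies; the real augmentation functions are built by the tokenizer and handle more cases, so this is not the repository's implementation:

```python
import random

# Naive sketch only: shift the pitch of note tokens ("piano", pitch, velocity)
# by a random offset, leaving prefix/dur/special tokens untouched.
def naive_pitch_aug(seq: list, max_shift: int = 5) -> list:
    shift = random.randint(-max_shift, max_shift)
    out = []
    for tok in seq:
        if isinstance(tok, tuple) and len(tok) == 3 and tok[0] != "prefix":
            out.append((tok[0], tok[1] + shift, tok[2]))
        else:
            out.append(tok)
    return out
```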
20 changes: 14 additions & 6 deletions tests/test_tokenizers.py
@@ -9,8 +9,9 @@

def get_short_seq(tknzr: tokenizer.TokenizerLazy):
return [
("prefix", "piano"),
("prefix", "drum"),
("prefix", "instrument", "piano"),
("prefix", "instrument", "drum"),
("prefix", "composer", "bach"),
"<S>",
("piano", 62, tknzr._quantize_velocity(50)),
("dur", tknzr._quantize_time(50)),
@@ -63,15 +64,15 @@ def test_aug(self):
seq_pitch_augmented = pitch_aug_fn(get_short_seq(tknzr))
logging.info(f"pitch_aug_fn:\n{seq} ->\n{seq_pitch_augmented}")
self.assertEqual(
seq_pitch_augmented[3][1] - seq[3][1],
seq_pitch_augmented[7][1] - seq[7][1],
seq_pitch_augmented[4][1] - seq[4][1],
seq_pitch_augmented[8][1] - seq[8][1],
)

seq_velocity_augmented = velocity_aug_fn(get_short_seq(tknzr))
logging.info(f"velocity_aug_fn:\n{seq} ->\n{seq_velocity_augmented}")
self.assertEqual(
seq_velocity_augmented[3][2] - seq[3][2],
seq_velocity_augmented[7][2] - seq[7][2],
seq_velocity_augmented[4][2] - seq[4][2],
seq_velocity_augmented[8][2] - seq[8][2],
)

seq_tempo_augmented = tempo_aug_fn(get_short_seq(tknzr))
@@ -94,6 +95,13 @@ def test_encode_decode(self):
for x, y in zip(seq, enc_dec_seq):
self.assertEqual(x, y)

def test_no_unk_token(self):
tknzr = tokenizer.TokenizerLazy()
seq = get_short_seq(tknzr)
enc_dec_seq = tknzr.decode(tknzr.encode(seq))
for tok in enc_dec_seq:
self.assertTrue(tok != tknzr.unk_tok)


if __name__ == "__main__":
if os.path.isdir("tests/test_results") is False:
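One piece the diff leaves open is how MidiDict.metadata["composer"] gets populated in the first place; the TODO removed from tokenizer.py suggested scanning MIDI meta messages for exactly one known composer name. A hedged sketch of that idea (the meta-message structure and the single-match rule are assumptions, not code from this repository):

```python
import re

# Hypothetical: tag a file with a composer only if exactly one whitelisted
# name appears in its text/copyright meta messages.
def find_composer(meta_msgs: list, composer_names: list):
    found = set()
    for msg in meta_msgs:
        text = str(msg.get("data", "")).lower()
        for name in composer_names:
            if re.search(rf"\b{re.escape(name)}\b", text):
                found.add(name)
    return found.pop() if len(found) == 1 else None
```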