diff --git a/ROADMAP.md b/ROADMAP.md
index b6e5ace..a6a888a 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -25,7 +25,7 @@ As it stands, the basic functionality of the repository is implemented and teste
 * [x] **~~Add further pre-processing tests~~**
     Add further MidiDict pre-processing tests to improve dataset quality. Some ideas are checking for the frequency of note messages (an average of > 15 p/s or < 2 p/s is a bad sign). I'm open to any suggestions for MidiDict preprocessing tests. Properly cleaning pre-training datasets has a huge effect on model quality and robustness.
 
-* [ ] **Add meta-token prefix support for LazyTokenizer**
+* [x] **~~Add meta-token prefix support for LazyTokenizer~~**
     Investigate the possibility of adding meta-tokens to the prefix in LazyTokenizer. Some examples could be genre, composer, or data source tags. This might require a rewrite of how sequence prefixes are handled.
 
 * [x] **~~Add 'ending soon' token to lazy tokenizer~~**
diff --git a/aria/data/midi.py b/aria/data/midi.py
index e31cada..4418660 100644
--- a/aria/data/midi.py
+++ b/aria/data/midi.py
@@ -85,7 +85,7 @@ def __init__(
         instrument_msgs: list,
         note_msgs: list,
         ticks_per_beat: int,
-        metadata: list,
+        metadata: dict,
     ):
         self.meta_msgs = meta_msgs
         self.tempo_msgs = tempo_msgs
diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py
index bdfdf28..51146a4 100644
--- a/aria/tokenizer/tokenizer.py
+++ b/aria/tokenizer/tokenizer.py
@@ -88,17 +88,6 @@ def _dec_fn(id):
 
         return decoded_seq
 
-# TODO: ADD META TOKEN FUNCTIONALITY
-# - Refactor meta tokens so they are a triple ("meta", "val")
-#   this should make using tok_type stuff easier. This refactor should touch a
-#   lot of code so take care here.
-# - Add prefix class function to lazy tokenizer that calculates meta tokens.
-#   This prefix class should call various functions according to config.json.
-#   - One function could be doing regex on the meta messages, looking for
-#     composer names. If one and only one composer name is found then it is added
-#     to the prefix before the instruments. We could specify the list of
-#     composers we are interested in in the config.json.
-#   - By loading according to the config.json we could extend this easily.
 
 class TokenizerLazy(Tokenizer):
     """Lazy MidiDict Tokenizer"""
@@ -133,7 +122,15 @@ def __init__(
             if v is False
         ]
         self.instruments_wd = self.instruments_nd + ["drum"]
-        self.prefix_tokens = [("prefix", x) for x in self.instruments_wd]
+
+        # Prefix tokens
+        self.prefix_tokens = [
+            ("prefix", "instrument", x) for x in self.instruments_wd
+        ]
+        self.composer_names = self.config["composer_names"]
+        self.prefix_tokens += [
+            ("prefix", "composer", x) for x in self.composer_names
+        ]
 
         # Build vocab
         self.special_tokens = [
@@ -272,9 +269,16 @@ def tokenize_midi_dict(self, midi_dict: MidiDict):
                 channel_to_instrument[c] = "piano"
 
         # Add non-drums to present_instruments (prefix)
-        prefix = [("prefix", x) for x in set(channel_to_instrument.values())]
+        prefix = [
+            ("prefix", "instrument", x)
+            for x in set(channel_to_instrument.values())
+        ]
         if 9 in channels_used:
-            prefix.append(("prefix", "drum"))
+            prefix.append(("prefix", "instrument", "drum"))
+
+        composer = midi_dict.metadata.get("composer")
+        if composer and (composer in self.composer_names):
+            prefix.insert(0, ("prefix", "composer", composer))
 
         # NOTE: Any preceding silence is removed implicitly
         tokenized_seq = []
@@ -378,20 +382,24 @@ def detokenize_midi_dict(self, tokenized_seq: list):
             if tok in self.special_tokens:
                 continue
             # Non-drum instrument prefix tok
-            elif tok[0] == "prefix" and tok[1] in self.instruments_nd:
+            elif (
+                tok[0] == "prefix"
+                and tok[1] == "instrument"
+                and tok[2] in self.instruments_nd
+            ):
                 if tok[1] in instrument_to_channel.keys():
-                    logging.warning(f"Duplicate prefix {tok[1]}")
+                    logging.warning(f"Duplicate prefix {tok[2]}")
                     continue
                 else:
                     instrument_msgs.append(
                         {
                             "type": "instrument",
-                            "data": instrument_programs[tok[1]],
+                            "data": instrument_programs[tok[2]],
                             "tick": 0,
                             "channel": channel_idx,
                         }
                     )
-                    instrument_to_channel[tok[1]] = channel_idx
+                    instrument_to_channel[tok[2]] = channel_idx
                     channel_idx += 1
             # Catches all other prefix tokens
             elif tok[0] == "prefix":
diff --git a/config/config.json b/config/config.json
index 4d98fd8..5a03563 100644
--- a/config/config.json
+++ b/config/config.json
@@ -113,7 +113,8 @@
             "time_quantization": {
                 "num_steps": 500,
                 "min_step": 10
-            }
+            },
+            "composer_names": ["beethoven", "bach", "mozart", "chopin"]
         }
     }
 }
diff --git a/tests/test_data.py b/tests/test_data.py
index 4d76be2..6024f44 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -11,8 +11,9 @@
 
 def get_short_seq():
     return [
-        ("prefix", "piano"),
-        ("prefix", "drum"),
+        ("prefix", "instrument", "piano"),
+        ("prefix", "instrument", "drum"),
+        ("prefix", "composer", "bach"),
         "<S>",
         ("piano", 62, 50),
         ("dur", 50),
@@ -164,12 +165,12 @@ def test_augmentation(self):
         logging.info(f"aug:\n{seq} ->\n{seq_augmented}")
 
         self.assertEqual(
-            seq_augmented[3][1] - seq[3][1],
-            seq_augmented[7][1] - seq[7][1],
+            seq_augmented[4][1] - seq[4][1],
+            seq_augmented[8][1] - seq[8][1],
         )
         self.assertEqual(
-            seq_augmented[3][2] - seq[3][2],
-            seq_augmented[7][2] - seq[7][2],
+            seq_augmented[4][2] - seq[4][2],
+            seq_augmented[8][2] - seq[8][2],
         )
 
         tokenized_dataset.close()
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index 80e49ae..576ad3d 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -9,8 +9,9 @@
 
 def get_short_seq(tknzr: tokenizer.TokenizerLazy):
     return [
-        ("prefix", "piano"),
-        ("prefix", "drum"),
+        ("prefix", "instrument", "piano"),
+        ("prefix", "instrument", "drum"),
+        ("prefix", "composer", "bach"),
         "<S>",
         ("piano", 62, tknzr._quantize_velocity(50)),
         ("dur", tknzr._quantize_time(50)),
@@ -63,15 +64,15 @@ def test_aug(self):
         seq_pitch_augmented = pitch_aug_fn(get_short_seq(tknzr))
         logging.info(f"pitch_aug_fn:\n{seq} ->\n{seq_pitch_augmented}")
         self.assertEqual(
-            seq_pitch_augmented[3][1] - seq[3][1],
-            seq_pitch_augmented[7][1] - seq[7][1],
+            seq_pitch_augmented[4][1] - seq[4][1],
+            seq_pitch_augmented[8][1] - seq[8][1],
         )
 
         seq_velocity_augmented = velocity_aug_fn(get_short_seq(tknzr))
         logging.info(f"velocity_aug_fn:\n{seq} ->\n{seq_velocity_augmented}")
         self.assertEqual(
-            seq_velocity_augmented[3][2] - seq[3][2],
-            seq_velocity_augmented[7][2] - seq[7][2],
+            seq_velocity_augmented[4][2] - seq[4][2],
+            seq_velocity_augmented[8][2] - seq[8][2],
         )
 
         seq_tempo_augmented = tempo_aug_fn(get_short_seq(tknzr))
@@ -94,6 +95,13 @@ def test_encode_decode(self):
         for x, y in zip(seq, enc_dec_seq):
             self.assertEqual(x, y)
 
+    def test_no_unk_token(self):
+        tknzr = tokenizer.TokenizerLazy()
+        seq = get_short_seq(tknzr)
+        enc_dec_seq = tknzr.decode(tknzr.encode(seq))
+        for tok in enc_dec_seq:
+            self.assertTrue(tok != tknzr.unk_tok)
+
 
 if __name__ == "__main__":
     if os.path.isdir("tests/test_results") is False:
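For reference, below is a rough sketch of how the composer prefix introduced in this diff is intended to flow through `TokenizerLazy.tokenize_midi_dict`. It is illustrative only: the helper name `tokenize_with_composer` is made up, the import style just mirrors the test files, and how the `MidiDict` instance is produced is assumed rather than shown here.

```python
from aria import tokenizer  # import style assumed to mirror the test files


def tokenize_with_composer(midi_dict, composer: str = "bach") -> list:
    """Tag `midi_dict` with a composer and tokenize it (illustrative sketch).

    `midi_dict` is assumed to be an existing aria.data.midi.MidiDict instance
    obtained elsewhere; constructing one by hand is out of scope here.
    """
    tknzr = tokenizer.TokenizerLazy()

    # MidiDict.metadata is now a dict (see aria/data/midi.py above), so a
    # composer tag can be attached directly before tokenizing.
    midi_dict.metadata["composer"] = composer

    seq = tknzr.tokenize_midi_dict(midi_dict)

    if composer in tknzr.composer_names:
        # Per tokenize_midi_dict above, only composers listed under
        # "composer_names" in config.json become prefix tokens, inserted at
        # the front of the prefix, ahead of the instrument tokens.
        assert ("prefix", "composer", composer) in seq

    return seq
```

On the decoding side, composer tokens are not in `instruments_nd`, so they fall through to the generic `prefix` branch of `detokenize_midi_dict` and never produce instrument messages.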
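A side effect of growing the prefix in `get_short_seq` from two tokens to three is that every hard-coded index in the augmentation tests shifts by one (`seq[3]`/`seq[7]` become `seq[4]`/`seq[8]`), which is what the test changes above do. Since prefix tokens are now uniform `("prefix", kind, value)` triples, downstream code can also separate them from the musical body generically. The helpers below are hypothetical, not part of the repository, and assume only the conventions visible in this diff.

```python
def split_prefix(seq: list) -> tuple:
    """Split a tokenized sequence into (prefix tokens, remainder).

    Assumes the convention shown in this diff: prefix tokens are 3-tuples
    whose first element is "prefix", and they all appear before other tokens.
    """
    prefix = []
    for idx, tok in enumerate(seq):
        if isinstance(tok, tuple) and len(tok) == 3 and tok[0] == "prefix":
            prefix.append(tok)
        else:
            return prefix, list(seq[idx:])
    return prefix, []


def get_composer(seq: list):
    """Return the composer tag from a sequence's prefix, or None if absent."""
    for tok in split_prefix(seq)[0]:
        # Each prefix token is ("prefix", kind, value) after this change.
        if tok[1] == "composer":
            return tok[2]
    return None
```

Applied to the `get_short_seq` fixture above, `get_composer` would return `"bach"`, and `split_prefix` would return the three prefix tuples followed by the remainder of the sequence starting at the start token.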