diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py index 71a09a0..c455ac5 100644 --- a/aria/tokenizer/tokenizer.py +++ b/aria/tokenizer/tokenizer.py @@ -72,20 +72,20 @@ def tokenize(self, midi_dict: MidiDict, **kwargs): required. For instance, in fine-tuning tokenizer you may want to insert additional tokens. The default behavior is to call tokenize_midi_dict. """ - return self._tokenize_midi_dict(midi_dict) + return self._tokenize_midi_dict(midi_dict, **kwargs) def _detokenize_midi_dict(self, tokenized_seq: list): """Abstract method for de-tokenizing a sequence of tokens into a MidiDict Object.""" raise NotImplementedError - def detokenize(self, tokenized_seq: list): + def detokenize(self, tokenized_seq: list, **kwargs): """Detokenizes a MidiDict object. This function should be overridden if additional are required during detokenization. The default behavior is to call detokenize_midi_dict. """ - return self._detokenize_midi_dict(tokenized_seq) + return self._detokenize_midi_dict(tokenized_seq, **kwargs) def export_data_aug(cls): """Abstract method for exporting a list of all data augmentation @@ -411,7 +411,9 @@ def truncate_by_time(self, tokenized_seq: list, trunc_time_ms: int): return tokenized_seq - def _tokenize_midi_dict(self, midi_dict: MidiDict): + def _tokenize_midi_dict( + self, midi_dict: MidiDict, remove_preceding_silence: bool = True + ): ticks_per_beat = midi_dict.ticks_per_beat midi_dict.remove_instruments(self.config["ignore_instruments"]) @@ -450,9 +452,13 @@ def _tokenize_midi_dict(self, midi_dict: MidiDict): prefix.insert(0, ("prefix", "genre", genre)) random.shuffle(prefix) - # NOTE: Any preceding silence is removed implicitly tokenized_seq = [] - initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + + if remove_preceding_silence is False: + initial_onset_tick = 0 + else: + initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + curr_time_since_onset = 0 for _, msg in enumerate(midi_dict.note_msgs): # Extract msg data @@ -543,12 +549,14 @@ def _detokenize_midi_dict(self, tokenized_seq: list): # Add non-drum instrument_msgs, breaks at first note token channel_idx = 0 + curr_tick = 0 for idx, tok in enumerate(tokenized_seq): if channel_idx == 9: # Skip channel reserved for drums channel_idx += 1 if tok in self.special_tokens: - # Skip special tokens + if tok == self.time_tok: + curr_tick += self.abs_time_step continue elif ( tok[0] == "prefix" @@ -590,7 +598,6 @@ def _detokenize_midi_dict(self, tokenized_seq: list): # Note messages note_msgs = [] - curr_tick = 0 for tok_1, tok_2, tok_3 in zip( tokenized_seq[start:], tokenized_seq[start + 1 :], diff --git a/requirements.txt b/requirements.txt index c70ea51..3e82abf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch >= 2.1 +torch >= 2.0 accelerate mido jsonlines