Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MidiDataset optimizations #68

Merged
merged 11 commits into from
Nov 22, 2023
65 changes: 46 additions & 19 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from typing import Callable, Iterable
from collections import defaultdict
from copy import deepcopy
from multiprocessing import Pool
from multiprocessing import Pool, Process, Queue

from aria.config import load_config
from aria.tokenizer import Tokenizer, TokenizerLazy
from aria.data.midi import MidiDict, get_test_fn
import tqdm


def setup_logger():
Expand Down Expand Up @@ -52,13 +53,17 @@ class MidiDataset:
entries (list[MidiDict]): MidiDict objects to be stored.
"""

def __init__(self, entries: list[MidiDict]):
def __init__(self, entries: list[MidiDict] | Iterable):
self.entries = entries

def __len__(self):
if not isinstance(self.entries, list):
self.entries = list(self.entries)
return len(self.entries)

def __getitem__(self, ind: int):
if not isinstance(self.entries, list):
self.entries = list(self.entries)
return self.entries[ind]

def __iter__(self):
Expand All @@ -74,12 +79,12 @@ def save(self, save_path: str):
@classmethod
def load(cls, load_path: str):
"""Loads dataset from JSON file."""
midi_dicts = []
with jsonlines.open(load_path) as reader:
for entry in reader:
midi_dicts.append(MidiDict.from_msg_dict(entry))
def _load():
with jsonlines.open(load_path) as reader:
for entry in reader:
yield MidiDict.from_msg_dict(entry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems fine, but maybe include an option for when this functionality should be used e.g. stream=True or something.


return cls(midi_dicts)
return cls(_load())

@classmethod
def split_from_file(
Expand Down Expand Up @@ -529,25 +534,47 @@ def build(
TokenizedDataset: Dataset saved midi_dataset and saved at save_path.
"""

def _worker(input_queue, output_queue, tokenizer):
while True:
item = input_queue.get()
if item is None:
break
output_queue.put(_get_tokenized_seqs(item, tokenizer))

def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable):
# Gets tokenized sequences using multiprocessing

# TokenizerLazy is the only supported tokenizer due to the truncate
# and stride logic in _get_tokenized_seqs
assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer"

with Pool() as pool:
results = pool.imap(
functools.partial(_get_tokenized_seqs, tokenizer=tokenizer),
_midi_dict_iter,
)

for idx, tokenized_seq in enumerate(results):
if tokenized_seq:
yield tokenized_seq

if idx % 50 == 0 and idx != 0:
logger.info(f"Processed MidiDicts: {idx}")
iq = Queue()
oq = Queue()

_num_proc = os.cpu_count()
workers = [Process(target=functools.partial(_worker, tokenizer=tokenizer), args=(iq, oq)) for _ in
range(_num_proc)]
for w in workers:
w.start()

def _enqueue(iq):
for midi_dict in _midi_dict_iter:
iq.put(midi_dict)
for i in range(_num_proc):
iq.put(None)

enqueue = Process(target=_enqueue, args=(iq,))
enqueue.start()

with tqdm.tqdm() as t:
while True:
if not oq.empty():
result = oq.get()
t.update(1)
yield result
else:
if not any(proc.is_alive() for proc in workers):
break

logger = setup_logger()

Expand Down
71 changes: 71 additions & 0 deletions aria/data/jsonl_zst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import builtins
import contextlib
import io
import zstandard
import jsonlines
import json


class Reader:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing that you are using this when building hf datasets. If so I'm happy to add jsonl.zst functionality for streaming MidiDataset and TokenizedDatasets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only blocker to replace everything by jsonl.zst reader/writer was that it doesn't allow for mmap indexing. It can't specify a place and directly retrieve the content without reading everything previously. I don't know if it's a theoretic impossibility or just an implementation limit of zstandard

"""Reader for the jsonl.zst format."""

def __init__(self, path: str):
"""Initializes the reader.

Args:
path (str): Path to the file.
"""
self.path = path

def __iter__(self):
with builtins.open(self.path, 'rb') as fh:
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
yield from jsonlines.Reader(reader)


class Writer:
"""Writer for the jsonl.zst format."""

def __init__(self, path: str):
"""Initializes the writer.

Args:
path (str): Path to the file.
"""
self.path = path

def __enter__(self):
self.fh = builtins.open(self.path, 'wb')
self.cctx = zstandard.ZstdCompressor()
self.compressor = self.cctx.stream_writer(self.fh)
return self

def write(self, obj):
self.compressor.write(json.dumps(obj).encode('UTF-8') + b'\n')

def __exit__(self, exc_type, exc_value, traceback):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.compressor.close()
self.fh.close()


@contextlib.contextmanager
def open(path: str, mode: str = "r"):
"""Read/Write a jsonl.zst file.

Args:
path (str): Path to the file.
mode (str): Mode to open the file in. Only 'r' and 'w' are supported.

Returns:
Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'.
"""
if mode == 'r':
yield Reader(path)
elif mode == 'w':
with Writer(path) as writer:
yield writer
else:
raise ValueError(f"Unsupported mode '{mode}'")
6 changes: 4 additions & 2 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Utils for data/MIDI processing."""

import hashlib
import json
import re
Expand Down Expand Up @@ -114,8 +113,11 @@ def __init__(
}
]

@classmethod
@property
def program_to_instrument(cls):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming this is here to speed up the process of building MidiDatasets. Does it make much of a difference?

# This combines the individual dictionaries into one
self.program_to_instrument = (
return (
{i: "piano" for i in range(0, 7 + 1)}
| {i: "chromatic" for i in range(8, 15 + 1)}
| {i: "organ" for i in range(16, 23 + 1)}
Expand Down
2 changes: 1 addition & 1 deletion aria/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def custom_forward(*args):

return custom_forward

hidden_states = torch.utils.checkpoint.checkpoint(
hidden_states, _ = torch.utils.checkpoint.checkpoint(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this end up fixing something?

create_custom_forward(layer),
hidden_states,
preserve_rng_state=True,
Expand Down
5 changes: 2 additions & 3 deletions aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,21 @@ def _parse_tokenized_dataset_args():
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_path", help="path to save dataset")
argp.add_argument("-s", help="also produce shuffled", action="store_true")
argp.add_argument("-l", help="max sequence length", type=int, default=2048)

return argp.parse_args(sys.argv[2:])


def build_tokenized_dataset(args):
from aria.tokenizer import TokenizerLazy
from aria.data.datasets import TokenizedDataset
from aria.config import load_config

config = load_config()["data"]["dataset_gen_args"]
tokenizer = TokenizerLazy()
dataset = TokenizedDataset.build(
tokenizer=tokenizer,
save_path=args.save_path,
midi_dataset_path=args.load_path,
max_seq_len=config["max_seq_len"],
max_seq_len=args.l,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Might as well also remove the dataset_gen_args from the config json.

overwrite=True,
)
if args.s:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aria import tokenizer
from aria.data import datasets
from aria.data.midi import MidiDict
from aria.data import jsonl_zst

if not os.path.isdir("tests/test_results"):
os.makedirs("tests/test_results")
Expand Down Expand Up @@ -245,6 +246,23 @@ def test_augmentation(self):
tokenized_dataset.close()


class TestReaderWriter(unittest.TestCase):
def test_jsonl_zst(self):
data = [{"a": i, "b": i+1} for i in range(0, 100, 4)]
filename = "tests/test_results/test.jsonl.zst"
# if test.jsonl.zst exists, delete it
if os.path.isfile(filename):
os.remove(filename)
with jsonl_zst.open(filename, "w") as f:
for d in data:
f.write(d)
with jsonl_zst.open(filename, "r") as f:
for d, d2 in zip(data, f):
self.assertEqual(d, d2)
# Remove the file
os.remove(filename)


if __name__ == "__main__":
if os.path.isdir("tests/test_results") is False:
os.mkdir("tests/test_results")
Expand Down