MidiDataset optimizations (#68)
* MidiDataset can initialize with an iterator and only expands it when necessary.

* reduce some memory overhead (we already have >100k MidiDicts and may get more in the future)

* classmethod+property is better for program_to_instrument: the mapping is computed on access instead of stored on every instance

* remove functools import

* use separate worker processes with queues to build the dataset instead of a process pool

* add jsonl.zst support with a unit test; fix a bug

* receive context length via the command line; it's more convenient than digging into the config file every time

* fix a minor output format mismatch when grad_checkpoint is true

* format and small changes

---------

Co-authored-by: Louis Bradshaw <[email protected]>
honglu2875 and loubbrad committed Nov 22, 2023
1 parent e9d82c3 commit 4cd90fc
Showing 7 changed files with 152 additions and 29 deletions.
74 changes: 55 additions & 19 deletions aria/data/datasets.py
@@ -13,11 +13,12 @@
from typing import Callable, Iterable
from collections import defaultdict
from copy import deepcopy
-from multiprocessing import Pool
+from multiprocessing import Pool, Process, Queue

from aria.config import load_config
from aria.tokenizer import Tokenizer, TokenizerLazy
from aria.data.midi import MidiDict, get_test_fn
+import tqdm

def setup_logger():
@@ -52,13 +53,17 @@ class MidiDataset:
        entries (list[MidiDict]): MidiDict objects to be stored.
    """

-    def __init__(self, entries: list[MidiDict]):
+    def __init__(self, entries: list[MidiDict] | Iterable):
        self.entries = entries

    def __len__(self):
+        if not isinstance(self.entries, list):
+            self.entries = list(self.entries)
        return len(self.entries)

    def __getitem__(self, ind: int):
+        if not isinstance(self.entries, list):
+            self.entries = list(self.entries)
        return self.entries[ind]

    def __iter__(self):
@@ -72,14 +77,18 @@ def save(self, save_path: str):
                writer.write(midi_dict.get_msg_dict())

    @classmethod
-    def load(cls, load_path: str):
+    def load(cls, load_path: str, stream=True):
        """Loads dataset from JSON file."""
-        midi_dicts = []
-        with jsonlines.open(load_path) as reader:
-            for entry in reader:
-                midi_dicts.append(MidiDict.from_msg_dict(entry))
-
-        return cls(midi_dicts)
+
+        def _load():
+            with jsonlines.open(load_path) as reader:
+                for entry in reader:
+                    yield MidiDict.from_msg_dict(entry)
+
+        if stream == False:
+            return cls(list(_load()))
+        else:
+            return cls(_load())

    @classmethod
    def split_from_file(
@@ -529,25 +538,52 @@ def build(
            TokenizedDataset: Dataset built from midi_dataset and saved at save_path.
        """

+        def _worker(input_queue, output_queue, tokenizer):
+            while True:
+                item = input_queue.get()
+                if item is None:
+                    break
+                output_queue.put(_get_tokenized_seqs(item, tokenizer))
+
        def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable):
            # Gets tokenized sequences using multiprocessing

            # TokenizerLazy is the only supported tokenizer due to the truncate
            # and stride logic in _get_tokenized_seqs
            assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer"

-            with Pool() as pool:
-                results = pool.imap(
-                    functools.partial(_get_tokenized_seqs, tokenizer=tokenizer),
-                    _midi_dict_iter,
-                )
-
-                for idx, tokenized_seq in enumerate(results):
-                    if tokenized_seq:
-                        yield tokenized_seq
-
-                    if idx % 50 == 0 and idx != 0:
-                        logger.info(f"Processed MidiDicts: {idx}")
+            iq = Queue()
+            oq = Queue()
+
+            _num_proc = os.cpu_count()
+            workers = [
+                Process(
+                    target=functools.partial(_worker, tokenizer=tokenizer),
+                    args=(iq, oq),
+                )
+                for _ in range(_num_proc)
+            ]
+            for w in workers:
+                w.start()
+
+            def _enqueue(iq):
+                for midi_dict in _midi_dict_iter:
+                    iq.put(midi_dict)
+                for i in range(_num_proc):
+                    iq.put(None)
+
+            enqueue = Process(target=_enqueue, args=(iq,))
+            enqueue.start()
+
+            with tqdm.tqdm() as t:
+                while True:
+                    if not oq.empty():
+                        result = oq.get()
+                        t.update(1)
+                        yield result
+                    else:
+                        if not any(proc.is_alive() for proc in workers):
+                            break

        logger = setup_logger()
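The queue-based pipeline above replaces Pool.imap with explicit worker processes fed by a producer process, so tokenized results stream back as soon as any worker finishes. Below is a stripped-down, self-contained sketch of the same pattern; tokenize() and the item count are hypothetical stand-ins for _get_tokenized_seqs and the real dataset. One hardening relative to the loop above: Queue.empty() is only advisory, so the sketch consumes a known number of results instead of polling.

import os
from multiprocessing import Process, Queue


def tokenize(item):
    # Hypothetical stand-in for _get_tokenized_seqs(item, tokenizer).
    return [item, item + 1]


def worker(iq, oq):
    while True:
        item = iq.get()
        if item is None:  # sentinel: no more work for this worker
            break
        oq.put(tokenize(item))


def producer(iq, n_items, n_workers):
    for item in range(n_items):
        iq.put(item)
    for _ in range(n_workers):
        iq.put(None)  # one sentinel per worker


if __name__ == "__main__":
    n_items = 100
    iq, oq = Queue(), Queue()
    n_workers = os.cpu_count() or 1
    workers = [Process(target=worker, args=(iq, oq)) for _ in range(n_workers)]
    for w in workers:
        w.start()
    feeder = Process(target=producer, args=(iq, n_items, n_workers))
    feeder.start()
    for _ in range(n_items):  # consume exactly as many results as were queued
        print(oq.get())
    feeder.join()
    for w in workers:
        w.join()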
71 changes: 71 additions & 0 deletions aria/data/jsonl_zst.py
@@ -0,0 +1,71 @@
+import builtins
+import contextlib
+import io
+import zstandard
+import jsonlines
+import json
+
+
+class Reader:
+    """Reader for the jsonl.zst format."""
+
+    def __init__(self, path: str):
+        """Initializes the reader.
+
+        Args:
+            path (str): Path to the file.
+        """
+        self.path = path
+
+    def __iter__(self):
+        with builtins.open(self.path, "rb") as fh:
+            cctx = zstandard.ZstdDecompressor()
+            reader = io.BufferedReader(cctx.stream_reader(fh))
+            yield from jsonlines.Reader(reader)
+
+
+class Writer:
+    """Writer for the jsonl.zst format."""
+
+    def __init__(self, path: str):
+        """Initializes the writer.
+
+        Args:
+            path (str): Path to the file.
+        """
+        self.path = path
+
+    def __enter__(self):
+        self.fh = builtins.open(self.path, "wb")
+        self.cctx = zstandard.ZstdCompressor()
+        self.compressor = self.cctx.stream_writer(self.fh)
+        return self
+
+    def write(self, obj):
+        self.compressor.write(json.dumps(obj).encode("UTF-8") + b"\n")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.compressor.flush(zstandard.FLUSH_FRAME)
+        self.fh.flush()
+        self.compressor.close()
+        self.fh.close()
+
+
+@contextlib.contextmanager
+def open(path: str, mode: str = "r"):
+    """Read/Write a jsonl.zst file.
+
+    Args:
+        path (str): Path to the file.
+        mode (str): Mode to open the file in. Only 'r' and 'w' are supported.
+
+    Returns:
+        Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'.
+    """
+    if mode == "r":
+        yield Reader(path)
+    elif mode == "w":
+        with Writer(path) as writer:
+            yield writer
+    else:
+        raise ValueError(f"Unsupported mode '{mode}'")
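Round-tripping records through the new module looks like this (the file name is hypothetical); it mirrors the unit test added in tests/test_data.py below:

from aria.data import jsonl_zst

path = "dataset.jsonl.zst"  # hypothetical file name
with jsonl_zst.open(path, "w") as writer:
    for record in ({"id": i} for i in range(3)):
        writer.write(record)  # one JSON object per line, zstd-compressed
with jsonl_zst.open(path, "r") as reader:
    for record in reader:
        print(record)  # {'id': 0}, {'id': 1}, {'id': 2}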
6 changes: 4 additions & 2 deletions aria/data/midi.py
@@ -1,5 +1,4 @@
"""Utils for data/MIDI processing."""
-
import hashlib
import json
import re
@@ -114,8 +113,11 @@ def __init__(
            }
        ]

+    @classmethod
+    @property
+    def program_to_instrument(cls):
        # This combines the individual dictionaries into one
-        self.program_to_instrument = (
+        return (
            {i: "piano" for i in range(0, 7 + 1)}
            | {i: "chromatic" for i in range(8, 15 + 1)}
            | {i: "organ" for i in range(16, 23 + 1)}
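Stacking @classmethod on @property turns program_to_instrument into a computed class-level attribute, so the merged dictionary is no longer stored on every MidiDict instance. One caveat: this chaining works on CPython 3.9 through 3.12, but was deprecated in 3.11 and removed in 3.13. A minimal illustration of the pattern (class name and values hypothetical):

class Instruments:
    @classmethod
    @property
    def program_to_instrument(cls):
        # Computed on access; nothing is stored per instance.
        return {i: "piano" for i in range(0, 7 + 1)} | {
            i: "chromatic" for i in range(8, 15 + 1)
        }


print(Instruments.program_to_instrument[0])  # "piano", no instance needed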
4 changes: 2 additions & 2 deletions aria/model/model.py
@@ -22,6 +22,7 @@ class YaRNConfig:
        mscale_coeff (int): Temperature scaling factor t follows `a ln s + 1.0`,
            and the coefficient `a` is this `mscale_coeff` here.
    """
+
    beta_fast: int = 16
    beta_slow: int = 1
    scale: int = 1.0
@@ -44,7 +45,6 @@ def __post_init__(self):
        if self.yarn_config is not None and isinstance(self.yarn_config, dict):
            self.yarn_config = YaRNConfig(**self.yarn_config)

-
    def set_vocab_size(self, vocab_size: int):
        self.vocab_size = vocab_size

@@ -320,7 +320,7 @@ def custom_forward(*args):

                return custom_forward

-            hidden_states = torch.utils.checkpoint.checkpoint(
+            hidden_states, _ = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer),
                hidden_states,
                preserve_rng_state=True,
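torch.utils.checkpoint.checkpoint returns whatever the wrapped callable returns, so when the layer yields a tuple the result must be unpacked; that is the output-format mismatch this hunk fixes. A small sketch with a hypothetical tuple-returning layer:

import torch
import torch.utils.checkpoint


def layer_fn(x):
    # Hypothetical stand-in for a transformer block that returns
    # (hidden_states, attention_weights).
    return torch.relu(x), torch.softmax(x, dim=-1)


x = torch.randn(2, 4, requires_grad=True)
# checkpoint() passes the tuple through unchanged, so unpack it.
hidden_states, _ = torch.utils.checkpoint.checkpoint(
    layer_fn, x, preserve_rng_state=True
)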
5 changes: 2 additions & 3 deletions aria/run.py
@@ -124,22 +124,21 @@ def _parse_tokenized_dataset_args():
    argp.add_argument("load_path", help="path midi_dict dataset")
    argp.add_argument("save_path", help="path to save dataset")
    argp.add_argument("-s", help="also produce shuffled", action="store_true")
+    argp.add_argument("-l", help="max sequence length", type=int)

    return argp.parse_args(sys.argv[2:])


def build_tokenized_dataset(args):
    from aria.tokenizer import TokenizerLazy
    from aria.data.datasets import TokenizedDataset
-    from aria.config import load_config

-    config = load_config()["data"]["dataset_gen_args"]
    tokenizer = TokenizerLazy()
    dataset = TokenizedDataset.build(
        tokenizer=tokenizer,
        save_path=args.save_path,
        midi_dataset_path=args.load_path,
-        max_seq_len=config["max_seq_len"],
+        max_seq_len=args.l,
        overwrite=True,
    )
    if args.s:
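The max sequence length now comes from the -l flag instead of config.json (the dataset_gen_args block is removed below). A hedged sketch of how the flag flows into the build call; the program name and argument values are hypothetical:

import argparse

argp = argparse.ArgumentParser(prog="aria tokenized-dataset")
argp.add_argument("load_path")
argp.add_argument("save_path")
argp.add_argument("-l", help="max sequence length", type=int)
args = argp.parse_args(["data.jsonl", "out.jsonl", "-l", "2048"])
print(args.l)  # 2048, passed to TokenizedDataset.build(max_seq_len=...)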
3 changes: 0 additions & 3 deletions config/config.json
@@ -70,9 +70,6 @@
                    "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"]
                }
            }
        },
-        "dataset_gen_args": {
-            "max_seq_len": 2048
-        }
    },

18 changes: 18 additions & 0 deletions tests/test_data.py
@@ -5,6 +5,7 @@
from aria import tokenizer
from aria.data import datasets
from aria.data.midi import MidiDict
+from aria.data import jsonl_zst

if not os.path.isdir("tests/test_results"):
    os.makedirs("tests/test_results")
@@ -245,6 +246,23 @@ def test_augmentation(self):
        tokenized_dataset.close()


+class TestReaderWriter(unittest.TestCase):
+    def test_jsonl_zst(self):
+        data = [{"a": i, "b": i + 1} for i in range(0, 100, 4)]
+        filename = "tests/test_results/test.jsonl.zst"
+        # If test.jsonl.zst exists, delete it
+        if os.path.isfile(filename):
+            os.remove(filename)
+        with jsonl_zst.open(filename, "w") as f:
+            for d in data:
+                f.write(d)
+        with jsonl_zst.open(filename, "r") as f:
+            for d, d2 in zip(data, f):
+                self.assertEqual(d, d2)
+        # Remove the file
+        os.remove(filename)
+
+
if __name__ == "__main__":
    if os.path.isdir("tests/test_results") is False:
        os.mkdir("tests/test_results")
