Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MidiDataset optimizations #68

Merged
merged 11 commits into from
Nov 22, 2023
Prev Previous commit
Next Next commit
add jsonl.zst support; unit test; fix bug
  • Loading branch information
honglu2875 committed Nov 10, 2023
commit 8bbc14b2f7498db3fd8f079a7531a80036b3c33b
14 changes: 9 additions & 5 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ def __init__(self, entries: list[MidiDict] | Iterable):
self.entries = entries

def __len__(self):
return len(list(self.entries))
if not isinstance(self.entries, list):
self.entries = list(self.entries)
return len(self.entries)

def __getitem__(self, ind: int):
return list(self.entries)[ind]
if not isinstance(self.entries, list):
self.entries = list(self.entries)
return self.entries[ind]

def __iter__(self):
yield from self.entries
Expand Down Expand Up @@ -564,11 +568,11 @@ def _enqueue(iq):

with tqdm.tqdm() as t:
while True:
try:
result = oq.get(timeout=1000)
if not oq.empty():
result = oq.get()
t.update(1)
yield result
except oq.Empty:
else:
if not any(proc.is_alive() for proc in workers):
break

Expand Down
71 changes: 71 additions & 0 deletions aria/data/jsonl_zst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import builtins
import contextlib
import io
import zstandard
import jsonlines
import json


class Reader:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing that you are using this when building hf datasets. If so I'm happy to add jsonl.zst functionality for streaming MidiDataset and TokenizedDatasets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only blocker to replace everything by jsonl.zst reader/writer was that it doesn't allow for mmap indexing. It can't specify a place and directly retrieve the content without reading everything previously. I don't know if it's a theoretic impossibility or just an implementation limit of zstandard

"""Reader for the jsonl.zst format."""

def __init__(self, path: str):
"""Initializes the reader.

Args:
path (str): Path to the file.
"""
self.path = path

def __iter__(self):
with builtins.open(self.path, 'rb') as fh:
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
yield from jsonlines.Reader(reader)


class Writer:
"""Writer for the jsonl.zst format."""

def __init__(self, path: str):
"""Initializes the writer.

Args:
path (str): Path to the file.
"""
self.path = path

def __enter__(self):
self.fh = builtins.open(self.path, 'wb')
self.cctx = zstandard.ZstdCompressor()
self.compressor = self.cctx.stream_writer(self.fh)
return self

def write(self, obj):
self.compressor.write(json.dumps(obj).encode('UTF-8') + b'\n')

def __exit__(self, exc_type, exc_value, traceback):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.compressor.close()
self.fh.close()


@contextlib.contextmanager
def open(path: str, mode: str = "r"):
"""Read/Write a jsonl.zst file.

Args:
path (str): Path to the file.
mode (str): Mode to open the file in. Only 'r' and 'w' are supported.

Returns:
Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'.
"""
if mode == 'r':
yield Reader(path)
elif mode == 'w':
with Writer(path) as writer:
yield writer
else:
raise ValueError(f"Unsupported mode '{mode}'")
18 changes: 18 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aria import tokenizer
from aria.data import datasets
from aria.data.midi import MidiDict
from aria.data import jsonl_zst

if not os.path.isdir("tests/test_results"):
os.makedirs("tests/test_results")
Expand Down Expand Up @@ -245,6 +246,23 @@ def test_augmentation(self):
tokenized_dataset.close()


class TestReaderWriter(unittest.TestCase):
def test_jsonl_zst(self):
data = [{"a": i, "b": i+1} for i in range(0, 100, 4)]
filename = "tests/test_results/test.jsonl.zst"
# if test.jsonl.zst exists, delete it
if os.path.isfile(filename):
os.remove(filename)
with jsonl_zst.open(filename, "w") as f:
for d in data:
f.write(d)
with jsonl_zst.open(filename, "r") as f:
for d, d2 in zip(data, f):
self.assertEqual(d, d2)
# Remove the file
os.remove(filename)


if __name__ == "__main__":
if os.path.isdir("tests/test_results") is False:
os.mkdir("tests/test_results")
Expand Down