-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MidiDataset optimizations #68
Changes from all commits
0230715
c1d170d
deec979
6e58f09
00cef03
7073a1a
69922b4
8bbc14b
95f492d
93301fa
607c438
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import builtins | ||
import contextlib | ||
import io | ||
import zstandard | ||
import jsonlines | ||
import json | ||
|
||
|
||
class Reader:
    """Iterates over JSON records stored in a zstd-compressed JSON Lines file."""

    def __init__(self, path: str):
        """Remember the location of the file to read from.

        Args:
            path (str): Path to the file.
        """
        self.path = path

    def __iter__(self):
        # Open lazily on every iteration so a single Reader instance can be
        # iterated multiple times, each pass starting from the beginning.
        with builtins.open(self.path, "rb") as raw:
            dctx = zstandard.ZstdDecompressor()
            buffered = io.BufferedReader(dctx.stream_reader(raw))
            yield from jsonlines.Reader(buffered)
|
||
|
||
class Writer:
    """Writer for the jsonl.zst format."""

    def __init__(self, path: str):
        """Initializes the writer.

        Args:
            path (str): Path to the file.
        """
        self.path = path

    def __enter__(self):
        """Open the output file and attach a zstd streaming compressor."""
        self.fh = builtins.open(self.path, "wb")
        self.cctx = zstandard.ZstdCompressor()
        self.compressor = self.cctx.stream_writer(self.fh)
        return self

    def write(self, obj):
        """Serialize obj as JSON and write it as one compressed line.

        Args:
            obj: Any JSON-serializable object.
        """
        self.compressor.write(json.dumps(obj).encode("UTF-8") + b"\n")

    def __exit__(self, exc_type, exc_value, traceback):
        """Flush the final zstd frame and close both streams.

        The try/finally guarantees the underlying file handle is closed
        even if flushing or closing the compressor raises (e.g. after a
        failed write), preventing a file-descriptor leak.
        """
        try:
            self.compressor.flush(zstandard.FLUSH_FRAME)
            self.fh.flush()
            self.compressor.close()
        finally:
            self.fh.close()
|
||
|
||
@contextlib.contextmanager
def open(path: str, mode: str = "r"):
    """Read/Write a jsonl.zst file.

    Args:
        path (str): Path to the file.
        mode (str): Mode to open the file in. Only 'r' and 'w' are supported.

    Returns:
        Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'.
    """
    # Guard clause: reject anything other than the two supported modes
    # up front (raised when the context is entered, as before).
    if mode not in ("r", "w"):
        raise ValueError(f"Unsupported mode '{mode}'")

    if mode == "w":
        with Writer(path) as writer:
            yield writer
    else:
        yield Reader(path)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
"""Utils for data/MIDI processing.""" | ||
|
||
import hashlib | ||
import json | ||
import re | ||
|
@@ -114,8 +113,11 @@ def __init__( | |
} | ||
] | ||
|
||
@classmethod | ||
@property | ||
def program_to_instrument(cls): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm assuming this is here to speed up the process of building MidiDatasets. Does it make much of a difference? |
||
# This combines the individual dictionaries into one | ||
self.program_to_instrument = ( | ||
return ( | ||
{i: "piano" for i in range(0, 7 + 1)} | ||
| {i: "chromatic" for i in range(8, 15 + 1)} | ||
| {i: "organ" for i in range(16, 23 + 1)} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ class YaRNConfig: | |
mscale_coeff (int): Temperature scaling factor t follows `a ln s + 1.0`, | ||
and the coefficient `a` is this `mscale_coeff` here. | ||
""" | ||
|
||
beta_fast: int = 16 | ||
beta_slow: int = 1 | ||
scale: int = 1.0 | ||
|
@@ -44,7 +45,6 @@ def __post_init__(self): | |
if self.yarn_config is not None and isinstance(self.yarn_config, dict): | ||
self.yarn_config = YaRNConfig(**self.yarn_config) | ||
|
||
|
||
def set_vocab_size(self, vocab_size: int): | ||
self.vocab_size = vocab_size | ||
|
||
|
@@ -320,7 +320,7 @@ def custom_forward(*args): | |
|
||
return custom_forward | ||
|
||
hidden_states = torch.utils.checkpoint.checkpoint( | ||
hidden_states, _ = torch.utils.checkpoint.checkpoint( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did this end up fixing something? |
||
create_custom_forward(layer), | ||
hidden_states, | ||
preserve_rng_state=True, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,22 +124,21 @@ def _parse_tokenized_dataset_args(): | |
argp.add_argument("load_path", help="path midi_dict dataset") | ||
argp.add_argument("save_path", help="path to save dataset") | ||
argp.add_argument("-s", help="also produce shuffled", action="store_true") | ||
argp.add_argument("-l", help="max sequence length", type=int) | ||
|
||
return argp.parse_args(sys.argv[2:]) | ||
|
||
|
||
def build_tokenized_dataset(args): | ||
from aria.tokenizer import TokenizerLazy | ||
from aria.data.datasets import TokenizedDataset | ||
from aria.config import load_config | ||
|
||
config = load_config()["data"]["dataset_gen_args"] | ||
tokenizer = TokenizerLazy() | ||
dataset = TokenizedDataset.build( | ||
tokenizer=tokenizer, | ||
save_path=args.save_path, | ||
midi_dataset_path=args.load_path, | ||
max_seq_len=config["max_seq_len"], | ||
max_seq_len=args.l, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. Might as well also remove the |
||
overwrite=True, | ||
) | ||
if args.s: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems fine, but maybe include an option for when this functionality should be used e.g. stream=True or something.