Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MidiDataset optimizations #68

Merged
merged 11 commits into from
Nov 22, 2023
Prev Previous commit
format and small changes
  • Loading branch information
loubbrad committed Nov 22, 2023
commit 607c4387ceb8c445f0d0d3025c3265072ff35dfe
17 changes: 13 additions & 4 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,18 @@ def save(self, save_path: str):
writer.write(midi_dict.get_msg_dict())

@classmethod
def load(cls, load_path: str, stream=True):
    """Load a dataset from a JSONL file.

    Args:
        load_path (str): Path to the jsonlines file produced by save().
        stream (bool): When True (default) entries are decoded lazily
            from disk via a generator, keeping memory flat for large
            datasets; when False the entire file is materialized into a
            list before the dataset is constructed.

    Returns:
        cls: Dataset wrapping either a generator (stream=True) or a
        fully-loaded list (stream=False) of MidiDict objects.
    """

    def _load():
        # Decode each JSON-line entry back into a MidiDict as it is read.
        with jsonlines.open(load_path) as reader:
            for entry in reader:
                yield MidiDict.from_msg_dict(entry)

    # PEP 8: test truthiness rather than comparing to a boolean literal
    # (was `if stream == False:`).
    if stream:
        return cls(_load())
    return cls(list(_load()))

@classmethod
def split_from_file(
Expand Down Expand Up @@ -552,8 +556,13 @@ def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable):
oq = Queue()

_num_proc = os.cpu_count()
workers = [Process(target=functools.partial(_worker, tokenizer=tokenizer), args=(iq, oq)) for _ in
range(_num_proc)]
workers = [
Process(
target=functools.partial(_worker, tokenizer=tokenizer),
args=(iq, oq),
)
for _ in range(_num_proc)
]
for w in workers:
w.start()

Expand Down
10 changes: 5 additions & 5 deletions aria/data/jsonl_zst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, path: str):
self.path = path

def __iter__(self):
    """Lazily yield decoded JSON objects from the zstd-compressed JSONL file.

    The file is opened fresh on every iteration, decompressed as a
    stream, and parsed line-by-line, so arbitrarily large files can be
    read without loading them into memory.
    """
    with builtins.open(self.path, "rb") as fh:
        decompressor = zstandard.ZstdDecompressor()
        buffered = io.BufferedReader(decompressor.stream_reader(fh))
        for obj in jsonlines.Reader(buffered):
            yield obj
Expand All @@ -36,13 +36,13 @@ def __init__(self, path: str):
self.path = path

def __enter__(self):
    """Open the target file and attach a streaming zstd compressor.

    Sets up `self.fh`, `self.cctx`, and `self.compressor` for use by
    write()/__exit__(), then returns self so the object can be used as
    a context manager.
    """
    # The compressor context is independent of the file handle, so it
    # can be created first.
    self.cctx = zstandard.ZstdCompressor()
    self.fh = builtins.open(self.path, "wb")
    self.compressor = self.cctx.stream_writer(self.fh)
    return self

def write(self, obj):
    """Serialize *obj* as one JSON line and feed it to the compressor."""
    encoded = json.dumps(obj).encode("UTF-8")
    self.compressor.write(encoded + b"\n")

def __exit__(self, exc_type, exc_value, traceback):
self.compressor.flush(zstandard.FLUSH_FRAME)
Expand All @@ -62,9 +62,9 @@ def open(path: str, mode: str = "r"):
Returns:
Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'.
"""
if mode == 'r':
if mode == "r":
yield Reader(path)
elif mode == 'w':
elif mode == "w":
with Writer(path) as writer:
yield writer
else:
Expand Down
2 changes: 1 addition & 1 deletion aria/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class YaRNConfig:
mscale_coeff (int): Temperature scaling factor t follows `a ln s + 1.0`,
and the coefficient `a` is this `mscale_coeff` here.
"""

beta_fast: int = 16
beta_slow: int = 1
scale: int = 1.0
Expand All @@ -44,7 +45,6 @@ def __post_init__(self):
if self.yarn_config is not None and isinstance(self.yarn_config, dict):
self.yarn_config = YaRNConfig(**self.yarn_config)


def set_vocab_size(self, vocab_size: int):
self.vocab_size = vocab_size

Expand Down
2 changes: 1 addition & 1 deletion aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _parse_tokenized_dataset_args():
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_path", help="path to save dataset")
argp.add_argument("-s", help="also produce shuffled", action="store_true")
argp.add_argument("-l", help="max sequence length", type=int, default=2048)
argp.add_argument("-l", help="max sequence length", type=int)

return argp.parse_args(sys.argv[2:])

Expand Down
3 changes: 0 additions & 3 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"]
}
}
},
"dataset_gen_args": {
"max_seq_len": 2048
}
},

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_augmentation(self):

class TestReaderWriter(unittest.TestCase):
def test_jsonl_zst(self):
data = [{"a": i, "b": i+1} for i in range(0, 100, 4)]
data = [{"a": i, "b": i + 1} for i in range(0, 100, 4)]
filename = "tests/test_results/test.jsonl.zst"
# if test.jsonl.zst exists, delete it
if os.path.isfile(filename):
Expand Down