MidiDataset optimizations #68

Merged · 11 commits · Nov 22, 2023

use separate workers to build dataset instead of process pool
honglu2875 committed Nov 10, 2023
commit 7073a1aceeebb73a7fc96163052b10e5c6bb2081
51 changes: 39 additions & 12 deletions aria/data/datasets.py
@@ -13,11 +13,13 @@
 from typing import Callable, Iterable
 from collections import defaultdict
 from copy import deepcopy
-from multiprocessing import Pool
+from multiprocessing import Pool, Process, Queue
+import queue  # queue.Empty is raised by Queue.get on timeout
 
 from aria.config import load_config
 from aria.tokenizer import Tokenizer, TokenizerLazy
 from aria.data.midi import MidiDict, get_test_fn
+import tqdm
 
 
 def setup_logger():
@@ -529,24 +531,49 @@ def build(
             TokenizedDataset: Dataset built from midi_dataset and saved at save_path.
         """
 
+        def _worker(input_queue, output_queue, tokenizer):
+            while True:
+                item = input_queue.get()
+                if item is None:  # sentinel: producer is finished
+                    break
+                output_queue.put(_get_tokenized_seqs(item, tokenizer))
+
         def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable):
             # Gets tokenized sequences using multiprocessing
 
             # TokenizerLazy is the only supported tokenizer due to the truncate
             # and stride logic in _get_tokenized_seqs
             assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer"
 
-            with Pool() as pool:
-                results = pool.imap(
-                    functools.partial(_get_tokenized_seqs, tokenizer=tokenizer),
-                    _midi_dict_iter,
-                )
-
-                for idx, tokenized_seq in enumerate(results):
-                    yield tokenized_seq
-
-                    if idx % 50 == 0 and idx != 0:
-                        logger.info(f"Processed MidiDicts: {idx}")
+            iq = Queue()
+            oq = Queue()
+
+            _num_proc = os.cpu_count()
+            workers = [
+                Process(target=functools.partial(_worker, tokenizer=tokenizer), args=(iq, oq))
+                for _ in range(_num_proc)
+            ]
+            for w in workers:
+                w.start()
+
+            def _enqueue(iq):
+                for midi_dict in _midi_dict_iter:
+                    iq.put(midi_dict)
+                for _ in range(_num_proc):  # one sentinel per worker
+                    iq.put(None)
+
+            enqueue = Process(target=_enqueue, args=(iq,))
+            enqueue.start()
+
+            with tqdm.tqdm() as t:
+                while True:
+                    try:
+                        result = oq.get(timeout=1000)
+                        t.update(1)
+                        yield result
+                    except queue.Empty:  # nothing arrived within the timeout
+                        if not any(proc.is_alive() for proc in workers):
+                            break
 
         logger = setup_logger()
 
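For context, the change swaps Pool.imap for an explicit producer/consumer pipeline: a fixed set of worker processes pulls MidiDicts from an input queue and pushes tokenized sequences to an output queue, while a dedicated producer process fills the input queue and finishes by enqueueing one None sentinel per worker. The following is a minimal, self-contained sketch of that pattern under the same shutdown scheme; the names (slow_square, produce) and the toy workload are illustrative, not part of the aria codebase.

import os
import queue
from multiprocessing import Process, Queue


def slow_square(x):
    # Stand-in for the real per-item work (_get_tokenized_seqs).
    return x * x


def worker(input_queue, output_queue):
    # Mirrors _worker: loop until a None sentinel arrives.
    while True:
        item = input_queue.get()
        if item is None:
            break
        output_queue.put(slow_square(item))


def produce(input_queue, num_workers):
    # Mirrors _enqueue: feed work, then one sentinel per worker.
    for x in range(100):
        input_queue.put(x)
    for _ in range(num_workers):
        input_queue.put(None)


def main():
    num_workers = os.cpu_count() or 1
    iq, oq = Queue(), Queue()

    workers = [Process(target=worker, args=(iq, oq)) for _ in range(num_workers)]
    for w in workers:
        w.start()

    producer = Process(target=produce, args=(iq, num_workers))
    producer.start()

    results = []
    while True:
        try:
            results.append(oq.get(timeout=1))
        except queue.Empty:
            # Workers exit only after consuming their sentinel, so an
            # empty queue plus no live workers means all results are in.
            if not any(w.is_alive() for w in workers):
                break

    producer.join()
    for w in workers:
        w.join()
    print(f"collected {len(results)} results")


if __name__ == "__main__":
    main()

Compared with Pool.imap, results now stream back in completion order rather than input order, and the parent process no longer iterates the source generator itself; a separate process keeps the input queue full while the parent only drains results. Two caveats worth noting: the shutdown check is heuristic, relying on the get timeout expiring while every worker is dead (the commit makes this conservative with a 1000-second timeout), and handing locally defined functions and the closed-over _midi_dict_iter generator to child processes works with fork-style process start but would not pickle under the spawn start method.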