
Commit

Fix PretrainingDataset.build bottleneck (#92)
* improve logging

* fix

* fix dataset building bottleneck
loubbrad committed Jan 20, 2024
1 parent b5bb6cf commit 748e259
Showing 4 changed files with 54 additions and 21 deletions.
66 changes: 49 additions & 17 deletions aria/data/datasets.py
@@ -190,6 +190,10 @@ def build_to_file(
def combine_datasets_from_file(cls, *args: str, output_path: str):
"""Function for concatenating jsonl files, checking for duplicates"""
logger = setup_logger()

for input_path in args:
assert os.path.isfile(input_path), f"{input_path} doesn't exist"

dupe_cnt = 0
hashes = {}
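# Maps midi_dict hash -> True once an entry has been written, so any
# later entry with the same hash is counted as a duplicate and skipped.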
with jsonlines.open(output_path, mode="w") as f_out:
@@ -208,6 +212,10 @@ def combine_datasets_from_file(cls, *args: str, output_path: str):
f_out.write(msg_dict)
hashes[midi_dict_hash] = True
logger.info(f"Finished processing: {input_path}")
logger.info(
f"{len(hashes)} unique midi_dicts and {dupe_cnt} duplicates so far"
)


logger.info(
f"Found {len(hashes)} unique midi_dicts and {dupe_cnt} duplicates"
@@ -299,25 +307,32 @@ def _get_mididicts_mp(_paths):
with Pool() as pool:
results = pool.imap(_get_mididict, _paths)
seen_hashes = defaultdict(list)
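# Maps hash -> list of paths sharing it; the first recorded path is the
# one reported when a later duplicate turns up.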
dupe_cnt = 0
failed_cnt = 0
for idx, (success, result) in enumerate(results):
if idx % 50 == 0 and idx != 0:
logger.info(f"Processed MIDI files: {idx}/{num_paths}")

if not success:
failed_cnt += 1
continue
else:
mid_dict, mid_hash, mid_path = result

if seen_hashes.get(mid_hash):
logger.info(
f"MIDI located at '{mid_path}' is a duplicate - already"
f" seen at: {seen_hashes[mid_hash]}"
f" seen at: {seen_hashes[mid_hash][0]}"
)
seen_hashes[mid_hash].append(str(mid_path))
dupe_cnt += 1
else:
seen_hashes[mid_hash].append(str(mid_path))
yield mid_dict

print(f"Total duplicates: {dupe_cnt}")
print(f"Total processing fails (tests or otherwise): {failed_cnt}")

logger = setup_logger()
if get_start_method() == "spawn":
logger.warning(
@@ -523,8 +538,12 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer):

if isinstance(_entry, str):
_midi_dict = MidiDict.from_msg_dict(json.loads(_entry.rstrip()))
else:
elif isinstance(_entry, dict):
_midi_dict = MidiDict.from_msg_dict(_entry)
elif isinstance(_entry, MidiDict):
_midi_dict = _entry
else:
raise Exception(f"Unsupported entry type: {type(_entry)}")

try:
_tokenized_seq = _tokenizer.tokenize(_midi_dict)
@@ -584,6 +603,19 @@ def get_seqs(
if not any(proc.is_alive() for proc in workers):
break

def reservoir(_iterable: Iterable, k: int):
_reservoir = []
for entry in _iterable:
if entry is not None:
_reservoir.append(entry)

if len(_reservoir) >= k:
random.shuffle(_reservoir)
yield from _reservoir
_reservoir = []

if _reservoir:
yield from _reservoir
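# Note: despite the name, this is a buffered shuffle rather than true
# reservoir sampling: shuffling is local to each k-sized chunk, and the
# final partial chunk is yielded unshuffled. Assumes `random` and
# `typing.Iterable` are imported at module top. For example,
# list(reservoir(range(5), 3)) might yield 2, 0, 1 and then the tail 3, 4.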

class PretrainingDataset(TrainingDataset):
def __init__(self, dir_path: str, tokenizer: Tokenizer):
@@ -695,13 +727,18 @@ def _build_epoch(_save_path, _midi_dataset):
)

buffer = []
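# Rolling buffer: tokenized sequences are concatenated and written out as
# fixed max_seq_len rows; the final partial row is padded after the loop.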
# TODO: Profile why mp takes a while to spin up
for entry in get_seqs(tokenizer, _midi_dataset):
_idx = 0
for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 5):
if entry is not None:
buffer += entry
while len(buffer) >= max_seq_len:
writer.write(buffer[:max_seq_len])
buffer = buffer[max_seq_len:]

_idx += 1
if _idx % 250 == 0:
logger.info(f"Finished processing {_idx}")

buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer))

logger = setup_logger()
@@ -727,32 +764,27 @@ def _build_epoch(_save_path, _midi_dataset):
if not os.path.exists(save_dir):
os.mkdir(save_dir)

# TODO: This is very slow right now
if not midi_dataset:
midi_dataset = MidiDataset.load(midi_dataset_path)
else:
if not midi_dataset and not midi_dataset_path:
raise Exception("Must provide either midi_dataset or midi_dataset_path")
if midi_dataset and midi_dataset_path:
raise Exception("Can't provide both midi_dataset and midi_dataset_path")

logger.info(
f"Building PretrainingDataset with config: "
f"max_seq_len={max_seq_len}, "
f"tokenizer_name={tokenizer.name}"
)
_num_proc = os.cpu_count()
if 2 * _num_proc > len(midi_dataset):
logger.warning(
"Number of processes is close to the number of MidiDicts "
"in the dataset. This can result in shuffling not working "
"as intended when building different epochs"
)
for idx in range(num_epochs):
logger.info(f"Building epoch {idx}/{num_epochs - 1}...")

# Reload the dataset on each iter
if midi_dataset_path:
midi_dataset = jsonlines.open(midi_dataset_path, "r")
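# Re-opening the jsonl file gives a lazy reader, so each epoch streams
# entries from disk instead of holding the full MidiDataset in memory;
# presumably this is the build bottleneck the commit removes.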

_build_epoch(
_save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"),
_midi_dataset=midi_dataset,
)
# TODO: This is very slow for large datasets
midi_dataset.shuffle()

logger.info(
f"Finished building, saved PretrainingDataset to {save_dir}"
1 change: 1 addition & 0 deletions aria/data/midi.py
@@ -177,6 +177,7 @@ def calculate_hash(self):
msg_dict_to_hash = self.get_msg_dict()
# Remove meta when calculating hash
msg_dict_to_hash.pop("meta_msgs")
msg_dict_to_hash.pop("ticks_per_beat")
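# Presumably excluded so that files differing only in tick resolution
# still hash as duplicates.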
msg_dict_to_hash.pop("metadata")

return hashlib.md5(
4 changes: 2 additions & 2 deletions aria/run.py
@@ -263,7 +263,7 @@ def _parse_midi_dataset_args():
help="manually add metadata key-value pair when building dataset",
)
argp.add_argument(
"--split", type=float, help="create train/val split", required=False
"-split", type=float, help="create train/val split", required=False
)

return argp.parse_args(sys.argv[2:])
@@ -284,7 +284,7 @@ def build_midi_dataset(args):
)

if args.split:
assert 0.0 < args.split < 1.0, "Invalid range given for --split"
assert 0.0 < args.split < 1.0, "Invalid range given for -split"
MidiDataset.split_from_file(
load_path=args.save_path,
train_val_ratio=args.split,
4 changes: 2 additions & 2 deletions config/config.json
@@ -4,7 +4,7 @@
"max_programs":{
"run": true,
"args": {
"max": 7
"max": 12
}
},
"max_instruments":{
@@ -24,7 +24,7 @@
"run": true,
"args": {
"min_per_second": 0.5,
"max_per_second": 15
"max_per_second": 16
}
},
"min_length":{
