
MidiDataset optimizations #68

Merged: 11 commits into EleutherAI:main on Nov 22, 2023
Conversation

honglu2875 (Contributor) commented Nov 9, 2023

  • MidiDataset can now be initialized with an iterator and is only expanded when necessary (on len() or random access).
  • Reduced some memory/initialization overhead by changing the program_to_instrument dict into a getter function (we are starting to have >100k MidiDataset entries and could have more in the future).

Note: I thought we might hide some latency by doing json.decode asynchronously during tokenizer encoding, but it doesn't seem to help much. However, I have a separate script to build a static Hugging Face tokenized dataset, and that does seem to help.
Update: it's now slightly faster using persistent workers that read from queues in while True loops, instead of a process pool. The processes are spun up only once, so the tokenizers are pickled only once.
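The persistent-worker pattern described above can be sketched roughly as follows (a minimal illustration with stand-in work, not the actual training code; the doubling stands in for tokenizer encoding):

```python
# Workers are spawned once and pull tasks from a queue in a `while True`
# loop, so expensive per-worker state (e.g. an unpickled tokenizer) is
# built exactly once instead of once per task as with a process pool.
import multiprocessing as mp

SENTINEL = None  # signals a worker to shut down


def worker(in_q, out_q):
    # One-time setup would happen here (e.g. deserialize a tokenizer).
    while True:
        item = in_q.get()
        if item is SENTINEL:
            break
        out_q.put(item * 2)  # stand-in for tokenizer.encode(item)


def run(items, num_workers=2):
    in_q, out_q = mp.Queue(), mp.Queue()
    procs = [mp.Process(target=worker, args=(in_q, out_q)) for _ in range(num_workers)]
    for p in procs:
        p.start()
    for it in items:
        in_q.put(it)
    for _ in procs:
        in_q.put(SENTINEL)
    results = [out_q.get() for _ in items]
    for p in procs:
        p.join()
    return sorted(results)  # queue order is nondeterministic
```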

def _load():
    with jsonlines.open(load_path) as reader:
        for entry in reader:
            yield MidiDict.from_msg_dict(entry)
Contributor:

This seems fine, but maybe include an option for when this functionality should be used e.g. stream=True or something.
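The lazy-expansion idea could be sketched like this (hypothetical names, not the project's actual API): the dataset wraps an iterator and only materializes it into a list when len() or indexing forces it to, while plain iteration stays streaming.

```python
class LazyDataset:
    """Wraps an iterator; expands to a list only on len() or random access."""

    def __init__(self, entries):
        self._iter = iter(entries)
        self._data = None  # None means "not yet expanded"

    def _expand(self):
        if self._data is None:
            self._data = list(self._iter)

    def __len__(self):
        self._expand()
        return len(self._data)

    def __getitem__(self, idx):
        self._expand()
        return self._data[idx]

    def __iter__(self):
        # Streaming iteration does not require expansion.
        if self._data is not None:
            return iter(self._data)
        return self._iter
```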

@@ -320,7 +320,7 @@ def custom_forward(*args):

             return custom_forward

-        hidden_states = torch.utils.checkpoint.checkpoint(
+        hidden_states, _ = torch.utils.checkpoint.checkpoint(
Contributor:

Did this end up fixing something?

import json


class Reader:
Contributor:

I'm guessing that you are using this when building hf datasets. If so I'm happy to add jsonl.zst functionality for streaming MidiDataset and TokenizedDatasets.

Contributor (Author):

The only blocker to replacing everything with a jsonl.zst reader/writer was that it doesn't allow mmap-style indexing: you can't seek to an arbitrary entry and retrieve it directly without decompressing everything before it. I don't know if that's a theoretical impossibility or just an implementation limit of zstandard.
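For contrast, the random access that plain (uncompressed) jsonl allows can be sketched as a simple byte-offset index (illustrative helper names, not the repo's code): record where each line starts once, then seek directly to any entry without reading its predecessors.

```python
import io
import json


def build_index(f):
    """Record the byte offset of every line in an uncompressed jsonl stream."""
    offsets, pos = [], 0
    for line in f:  # f opened in binary mode so offsets are byte-accurate
        offsets.append(pos)
        pos += len(line)
    return offsets


def read_entry(f, offsets, i):
    """Jump straight to entry i without decoding the preceding entries."""
    f.seek(offsets[i])
    return json.loads(f.readline())
```

With a zstd-compressed stream, `seek` on the compressed file does not correspond to a decodable position, which is the limitation described above.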

@@ -114,8 +113,11 @@ def __init__(
             }
         ]

+    @classmethod
+    @property
+    def program_to_instrument(cls):
Contributor:

I'm assuming this is here to speed up the process of building MidiDatasets. Does it make much of a difference?
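One caveat worth noting: chaining `@classmethod` with `@property` was only ever supported in CPython 3.9–3.12 (deprecated in 3.11, removed in 3.13). A more portable way to get the same "build the dict lazily, once per class" behavior is a memoized classmethod (a sketch with a placeholder mapping, not the real instrument table):

```python
import functools


class MidiDictSketch:  # hypothetical stand-in for the real class
    @classmethod
    @functools.lru_cache(maxsize=1)
    def program_to_instrument(cls):
        # Built on first call and cached, so per-instance construction
        # cost disappears when many MidiDicts are created.
        return {p: ("piano" if p < 8 else "other") for p in range(128)}
```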

     tokenizer = TokenizerLazy()
     dataset = TokenizedDataset.build(
         tokenizer=tokenizer,
         save_path=args.save_path,
         midi_dataset_path=args.load_path,
-        max_seq_len=config["max_seq_len"],
+        max_seq_len=args.l,
Contributor:

Good idea. Might as well also remove the dataset_gen_args from the config json.

@loubbrad loubbrad merged commit 4cd90fc into EleutherAI:main Nov 22, 2023
1 check passed
@honglu2875 honglu2875 deleted the dev branch November 22, 2023 19:26