MidiDataset optimizations #68

Merged
merged 11 commits into from
Nov 22, 2023
MidiDataset can initialize with an iterator and only expand when necessary.
honglu2875 committed Nov 9, 2023
commit 023071572e60c385dbc820974109ca9e946e733f
16 changes: 8 additions & 8 deletions aria/data/datasets.py
@@ -52,14 +52,14 @@ class MidiDataset:
         entries (list[MidiDict]): MidiDict objects to be stored.
     """

-    def __init__(self, entries: list[MidiDict]):
+    def __init__(self, entries: list[MidiDict] | Iterable):
         self.entries = entries

     def __len__(self):
-        return len(self.entries)
+        return len(list(self.entries))

     def __getitem__(self, ind: int):
-        return self.entries[ind]
+        return list(self.entries)[ind]

     def __iter__(self):
         yield from self.entries
@@ -74,12 +74,12 @@ def save(self, save_path: str):
     @classmethod
     def load(cls, load_path: str):
         """Loads dataset from JSON file."""
-        midi_dicts = []
-        with jsonlines.open(load_path) as reader:
-            for entry in reader:
-                midi_dicts.append(MidiDict.from_msg_dict(entry))
+        def _load():
+            with jsonlines.open(load_path) as reader:
+                for entry in reader:
+                    yield MidiDict.from_msg_dict(entry)
Contributor

This seems fine, but maybe include an option for when this functionality should be used, e.g. stream=True or something.
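
A rough sketch of what such an option could look like, not the code that was merged. Everything named here is an assumption for illustration: the `LazyDataset` class stands in for MidiDataset, the `stream` keyword and the `_expand` helper are hypothetical, entries are plain dicts rather than MidiDict objects, and the standard-library `json` module is used in place of jsonlines so the snippet runs on its own. It also caches the expanded list, since calling list() on a generator consumes it and a second call would return an empty list.

```python
import json
from typing import Any, Iterable, Iterator


class LazyDataset:
    """Hypothetical stand-in for MidiDataset; entries are plain dicts."""

    def __init__(self, entries: Iterable[Any]):
        self.entries = entries

    def _expand(self) -> list:
        # Materialize the iterator once and keep the list, so that a
        # generator is never consumed twice.
        if not isinstance(self.entries, list):
            self.entries = list(self.entries)
        return self.entries

    def __len__(self) -> int:
        return len(self._expand())

    def __getitem__(self, ind: int) -> Any:
        return self._expand()[ind]

    def __iter__(self) -> Iterator[Any]:
        # If entries is still a generator, this streams it exactly once;
        # after _expand() has run, it iterates over the cached list.
        yield from self.entries

    @classmethod
    def load(cls, load_path: str, stream: bool = False) -> "LazyDataset":
        """Loads one JSON object per line; lazily if stream=True."""

        def _load() -> Iterator[Any]:
            with open(load_path) as reader:
                for line in reader:
                    if line.strip():
                        yield json.loads(line)

        # stream=False keeps the eager behaviour (a list built up front);
        # stream=True defers reading until the data is actually needed.
        return cls(_load() if stream else list(_load()))
```

With `stream=True` nothing is read from disk until the dataset is iterated, indexed, or measured with len(); with the default, the caller gets a fully loaded list as before.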


-        return cls(midi_dicts)
+        return cls(_load())

     @classmethod
     def split_from_file(