Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve preprocessing test logging #44

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ def _run_tests(_mid_dict: MidiDict):
test_fn = get_test_fn(test_name)
test_args = test_config["args"]

if test_fn(_mid_dict, **test_args) is False:
failed_tests.append(test_name)
test_res, val = test_fn(_mid_dict, **test_args)
if test_res is False:
failed_tests.append((test_name, val))

return failed_tests

Expand Down Expand Up @@ -208,7 +209,7 @@ def _preprocess_mididict(_mid_dict: MidiDict):
failed_tests = _run_tests(mid_dict)
if failed_tests:
logger.info(
f"MIDI at {path} failed preprocessing tests: {failed_tests}"
f"MIDI at {path} failed preprocessing tests: {failed_tests} "
)
return False, None
else:
Expand Down
35 changes: 14 additions & 21 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ def test_max_programs(midi_dict: MidiDict, max: int):
)

if len(present_programs) <= max:
return True
return True, len(present_programs)
else:
return False
return False, len(present_programs)


def test_max_instruments(midi_dict: MidiDict, max: int):
Expand All @@ -624,16 +624,16 @@ def test_max_instruments(midi_dict: MidiDict, max: int):
)

if len(present_instruments) <= max:
return True
return True, len(present_instruments)
else:
return False
return False, len(present_instruments)


def test_note_frequency(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
):
if not midi_dict.note_msgs:
return False
return False, 0.0

num_notes = len(midi_dict.note_msgs)
total_duration_ms = get_duration_ms(
Expand All @@ -645,9 +645,9 @@ def test_note_frequency(
notes_per_second = (num_notes * 1e3) / total_duration_ms

if notes_per_second < min_per_second or notes_per_second > max_per_second:
return False
return False, notes_per_second
else:
return True
return True, notes_per_second


def test_note_frequency_per_instrument(
Expand All @@ -674,20 +674,14 @@ def test_note_frequency_per_instrument(
)
notes_per_second = (num_notes * 1e3) / total_duration_ms

note_freq_per_instrument = notes_per_second / num_instruments
if (
notes_per_second / num_instruments < min_per_second
or notes_per_second / num_instruments > max_per_second
note_freq_per_instrument < min_per_second
or note_freq_per_instrument > max_per_second
):
return False
return False, note_freq_per_instrument
else:
return True


def test_no_notes(midi_dict: MidiDict):
if len(midi_dict.note_msgs) > 0:
return True
else:
return False
return True, note_freq_per_instrument


def test_min_length(midi_dict: MidiDict, min_seconds: int):
Expand All @@ -702,17 +696,16 @@ def test_min_length(midi_dict: MidiDict, min_seconds: int):
)

if total_duration_ms / 1e3 < min_seconds:
return False
return False, total_duration_ms / 1e3
else:
return True
return True, total_duration_ms / 1e3


def get_test_fn(test_name: str):
# Add additional test_names to this inventory
name_to_fn = {
"max_programs": test_max_programs,
"max_instruments": test_max_instruments,
"no_notes": test_no_notes,
"total_note_frequency": test_note_frequency,
"note_frequency_per_instrument": test_note_frequency_per_instrument,
"min_length": test_min_length,
Expand Down
4 changes: 0 additions & 4 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
"max": 4
}
},
"no_notes":{
"run": true,
"args": {}
},
"total_note_frequency":{
"run": true,
"args": {
Expand Down