Skip to content

Commit

Permalink
Improve preprocessing test logging (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Oct 10, 2023
1 parent b1b58ec commit 29467a1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 28 deletions.
7 changes: 4 additions & 3 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ def _run_tests(_mid_dict: MidiDict):
test_fn = get_test_fn(test_name)
test_args = test_config["args"]

if test_fn(_mid_dict, **test_args) is False:
failed_tests.append(test_name)
test_res, val = test_fn(_mid_dict, **test_args)
if test_res is False:
failed_tests.append((test_name, val))

return failed_tests

Expand Down Expand Up @@ -208,7 +209,7 @@ def _preprocess_mididict(_mid_dict: MidiDict):
failed_tests = _run_tests(mid_dict)
if failed_tests:
logger.info(
f"MIDI at {path} failed preprocessing tests: {failed_tests}"
f"MIDI at {path} failed preprocessing tests: {failed_tests} "
)
return False, None
else:
Expand Down
35 changes: 14 additions & 21 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ def test_max_programs(midi_dict: MidiDict, max: int):
)

if len(present_programs) <= max:
return True
return True, len(present_programs)
else:
return False
return False, len(present_programs)


def test_max_instruments(midi_dict: MidiDict, max: int):
Expand All @@ -624,16 +624,16 @@ def test_max_instruments(midi_dict: MidiDict, max: int):
)

if len(present_instruments) <= max:
return True
return True, len(present_instruments)
else:
return False
return False, len(present_instruments)


def test_note_frequency(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
):
if not midi_dict.note_msgs:
return False
return False, 0.0

num_notes = len(midi_dict.note_msgs)
total_duration_ms = get_duration_ms(
Expand All @@ -645,9 +645,9 @@ def test_note_frequency(
notes_per_second = (num_notes * 1e3) / total_duration_ms

if notes_per_second < min_per_second or notes_per_second > max_per_second:
return False
return False, notes_per_second
else:
return True
return True, notes_per_second


def test_note_frequency_per_instrument(
Expand All @@ -674,20 +674,14 @@ def test_note_frequency_per_instrument(
)
notes_per_second = (num_notes * 1e3) / total_duration_ms

note_freq_per_instrument = notes_per_second / num_instruments
if (
notes_per_second / num_instruments < min_per_second
or notes_per_second / num_instruments > max_per_second
note_freq_per_instrument < min_per_second
or note_freq_per_instrument > max_per_second
):
return False
return False, note_freq_per_instrument
else:
return True


def test_no_notes(midi_dict: MidiDict):
if len(midi_dict.note_msgs) > 0:
return True
else:
return False
return True, note_freq_per_instrument


def test_min_length(midi_dict: MidiDict, min_seconds: int):
Expand All @@ -702,17 +696,16 @@ def test_min_length(midi_dict: MidiDict, min_seconds: int):
)

if total_duration_ms / 1e3 < min_seconds:
return False
return False, total_duration_ms / 1e3
else:
return True
return True, total_duration_ms / 1e3


def get_test_fn(test_name: str):
# Add additional test_names to this inventory
name_to_fn = {
"max_programs": test_max_programs,
"max_instruments": test_max_instruments,
"no_notes": test_no_notes,
"total_note_frequency": test_note_frequency,
"note_frequency_per_instrument": test_note_frequency_per_instrument,
"min_length": test_min_length,
Expand Down
4 changes: 0 additions & 4 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
"max": 4
}
},
"no_notes":{
"run": true,
"args": {}
},
"total_note_frequency":{
"run": true,
"args": {
Expand Down

0 comments on commit 29467a1

Please sign in to comment.