Skip to content

Commit

Permalink
Merge branch 'dev-midi-generator' into fix-midi-generator-cli
Browse files Browse the repository at this point in the history
  • Loading branch information
SleepyLGod committed Nov 10, 2022
2 parents 4e3bc5b + aea03e4 commit fe59d10
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,45 @@ public class Transcriptor {
private final OrtEnvironment _env;
private final OrtSession _session;
private final int segment_samples = 16000 * 10;
private final int frames_per_second = 100;
private final int frames_per_second = 100;//帧频
private final int classes_num = 88;
private final float onset_threshold = 0.3f;
private final float offset_threshold = 0.3f;
private final float frame_threshold = 0.1f;
private final float pedal_offset_threshold = 0.2f;
private final float onset_threshold = 0.3f; //起始阈值
private final float offset_threshold = 0.3f;//终止阈值
private final float frame_threshold = 0.1f;//帧阈值
private final float pedal_offset_threshold = 0.2f;// pedal offset (release) threshold; "offset" here is the end of a pedal press, not a shift

public Transcriptor(String modulePath) throws OrtException {
this._env = OrtEnvironment.getEnvironment();
this._session = this._env.createSession(modulePath);
this._env = OrtEnvironment.getEnvironment();//OrtEnvironment型变量是onnx运行时系统的主机对象。可以创建封装的 OrtSessions 特定型号。
this._session = this._env.createSession(modulePath);//使用默认的OrtSession.SessionOptions、模型和默认内存分配器创建会话。
}

public byte[] transcript(float[] pcm_data) throws InvalidMidiDataException, IOException, OrtException {
var pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
int pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
if (pad_len != 0)
pcm_data = Arrays.copyOf(pcm_data, pcm_data.length + pad_len);
/*复制指定的数组,用零截断或填充(如有必要),以便副本具有指定的长度。
对于在原始数组和副本中都有效的所有索引,这两个数组将包含相同的值。
对于在副本中有效但不是原始索引的任何索引,副本将包含 0f。
当且仅当指定长度大于原始数组的长度时,此类索引才会存在。 */
var segments = enframe(pcm_data, segment_samples);
// Split pcm_data into segments of segment_samples samples each (framing)
var output_dict = forward(segments);
var new_output_dict = new HashMap<String, float[][]>();
for (var key :
output_dict.keySet()) {
for (String key : output_dict.keySet()) {
new_output_dict.put(key, deframe(output_dict.get(key)));
}
var post_processor = new RegressionPostProcessor(frames_per_second, classes_num, onset_threshold, offset_threshold, frame_threshold, pedal_offset_threshold);
var out = post_processor.outputDictToMidiEvents(new_output_dict);
var est_note_events = (List<NoteEvent>) out.get(0);
var est_pedal_events = (List<PedalEvent>) out.get(1);
List<NoteEvent> est_note_events = (List<NoteEvent>) out.get(0);
List<PedalEvent> est_pedal_events = (List<PedalEvent>) out.get(1);

// write midi file here
return writeEventsToMidi(0, est_note_events, est_pedal_events);
}

private List<float[]> enframe(float[] data, int segment_samples) {
private List<float[]> enframe(float[] data, int segment_samples) {//将data数组按照每segment_samples
//个元素一组,赋值给batch,并且是一半冗余地赋值,即若data为1 2 3 4 5 6 7 8且若segment_samples=4,
// 则传为1 2 3 4;3 4 5 6;5 6 7 8
int pointer = 0;
var batch = new ArrayList<float[]>();
while (pointer + segment_samples <= data.length) {
Expand All @@ -69,11 +75,10 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep
output_dic.put("pedal_frame_output", new ArrayList<>());

int pointer = 0;
var total_segments = (int) (Math.ceil(data.size() / batch_size));
var total_segments = (int) (Math.ceil(data.size() / batch_size));//获取data中的float数组个数
while (true) {
System.out.println("Segment " + pointer + " / " + total_segments);
if (pointer >= data.size()) break;

if (pointer >= data.size()) break;//已经遍历完,循环结束
//var tensor = Tensor.fromBlob(data.get(pointer), new long[]{1, segment_samples});
var tensor = OnnxTensor.createTensor(_env, FloatBuffer.wrap(data.get(pointer)), new long[]{1, segment_samples});
//var batch_output_dict = this._module.forward(IValue.from(tensor));
Expand All @@ -91,7 +96,7 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep

private float[][] deframe(List<float[][]> x) {
if (x.size() == 1)
return x.get(0);
return x.get(0);//返回第一个元素(元素是二维数组,不是数字)即可
else {
int segment_samples = x.get(0).length - 1;
int length = x.get(0)[0].length;
Expand All @@ -113,6 +118,7 @@ private float[][] deframe(List<float[][]> x) {
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
//读数组,与write2DArray相反的操作
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -123,6 +129,9 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
//把source的0到end1 - start1-1行,0到end2 - start2-1列的元素赋值给
//des的start1到end1-1行,start2到end2-1列的元素

for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand All @@ -131,7 +140,7 @@ private void write2DArray(float[][] source, float[][] des, int start1, int end1,
}
}

class RegressionPostProcessor {
class RegressionPostProcessor {//回归处理器
private final int frames_per_second;
private final int classes_num;
private final float onset_threshold;
Expand All @@ -142,7 +151,7 @@ class RegressionPostProcessor {
private final int velocity_scale;

public RegressionPostProcessor(int frames_per_second, int classes_num, float onset_threshold, float offset_threshold, float frame_threshold, float pedal_offset_threshold) {

//构造函数
this.frames_per_second = frames_per_second;
this.classes_num = classes_num;
this.onset_threshold = onset_threshold;
Expand All @@ -153,7 +162,7 @@ public RegressionPostProcessor(int frames_per_second, int classes_num, float ons
this.velocity_scale = 128;
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -163,7 +172,7 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
return output;
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand Down Expand Up @@ -192,6 +201,7 @@ public List<Object> outputDictToMidiEvents(Map<String, float[][]> dict) {
}

private List<NoteEvent> detectedNotesToEvents(List<float[]> est_on_off_note_vels) {
//将est_on_off_note_vels的数据传输给NoteEvent
var output = new ArrayList<NoteEvent>();
for (var i :
est_on_off_note_vels) {
Expand Down Expand Up @@ -272,6 +282,7 @@ private List<float[][]> getBinarizedOutputFromRegression(float[][] reg_output, f
}

private boolean is_monotonic_neighbour(float[] x, int n, int neighbour) {
//检测是否是递增或递减的相邻项
var monotonic = true;
for (int i = 0; i < neighbour; i++) {
if (x[n - i] < x[n - i - 1])
Expand Down Expand Up @@ -450,6 +461,8 @@ private List<float[]> noteDetectionWithOnsetOffsetRegress(float[] frame_outputs,
}

private List<float[]> pedalDetectionWithOnsetOffsetRegress(float[] frame_outputs, float[] offset_outputs, float[] offset_shift_outputs, double frame_threshold) {
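// "Pedal" here is the piano sustain pedal, not a sustained tone.
// Detects pedal events from the frame/offset outputs and returns them as a list of float arrays.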
var bgn = 0;
var frame_disappear = 0;
var offset_occur = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,56 @@
import java.util.List;

public class MidiWriter {
/*
MIDIEvent 包含一条 MIDI 消息和一个以刻度表示的相应时间戳,并且可以表示存储在 MIDI 文件或序列对象中的 MIDI 事件信息。
滴答的持续时间由 MIDI 文件或序列对象中包含的计时信息指定。
在 Java Sound 中,MidiEvent 对象通常包含在 Track 中,Tracks 也同样包含在 Sequence 中
此MidiWriter.java部分的主要功能就在于将MP3文件经过转化后得到的数据以字节流的形式进行存储并作为返回值返回,这是 MIDI 文件在计算机中存储的一种数据形式
此外也完成了将 MIDI 类型的文件的字节流写入提供的输出流
*/
public static byte[] writeEventsToMidi(int startTime, List<NoteEvent> noteEvents, List<PedalEvent> pedalEvents) throws InvalidMidiDataException, IOException {
var ticksPerBeat = 384;
var beatsPerSecond = 2;
var ticksPerSecond = ticksPerBeat * beatsPerSecond;
var microsecondsPerBeat = (int) (1e6 / beatsPerSecond);
var ticksPerBeat = 384; //设置每节拍滴答频率
var beatsPerSecond = 2; //设置每秒节拍数
var ticksPerSecond = ticksPerBeat * beatsPerSecond; //每秒滴答频率=每秒节拍数*每节拍滴答频率
var microsecondsPerBeat = (int) (1e6 / beatsPerSecond); //每节拍微秒时长= 每秒节拍数的倒数 * 1e6(化为微秒单位)

// create midi sequence with 384 ticks per beat
// create midi sequence with 384 ticks per beat --- 创建每节拍 384 tick 的 midi 序列
var s = new Sequence(Sequence.PPQ, ticksPerBeat);

// track 0
// track 0 在该midi序列s下创建第一条轨道t0
var t0 = s.createTrack();
// set tempo
// set tempo 设定节奏
var m1 = new MetaMessage();
/*
MIDI 标准用字节表示 MIDI 数据。但是,由于 JavaTM 使用带符号字节,所以 Java Sound API 表示 MIDI 数据时使用整数而不是字节。
例如,MidiMessage 的 getStatus() 方法返回用整数表示的 MIDI 状态字节。
如果处理来源于 Java Sound 之外的 MIDI 数据,而现在又编码为带符号字节,可使用以下转换将字节转换为整数:int i = (int)(byte & 0xFF)
*/
var b1 = new byte[]{(byte) (microsecondsPerBeat >> 16), (byte) (microsecondsPerBeat >> 8 & 0xff), (byte) (microsecondsPerBeat & 0xff)};
/* 对m1 设置 MetaMessage 的消息参数。由于元消息只允许一个状态字节值 0xFF,因此这里不需要指定。对所有元消息的 getStatus 调用返回 0xFF。
MIDI 最核心的功能是用于传输实时的音乐演奏信息,这些信息本质上是一条条包含了音高、力度、效果器参数等信息的指令,我们将这些指令称之为 MIDI 消息(MIDI message)。
一条 MIDI 消息通常由数个字节组成,其中第一个字节被称为 STATUS byte,
其后面有跟有数个 DATA bytes。STATUS byte 第七位为 1,而 DATA byte 第七位为 0。
形参:
type - 元消息类型(必须小于 128) , 应该是 MetaMessage 中状态字节之后的字节的有效值
data - MIDI 消息中的数据字节 , 应该包含 MetaMessage 的所有后续字节。换句话说,指定 MetaMessage 类型的字节不被视为数据字节
length - 数据字节数组中的字节数 */
m1.setMessage(0x51, b1, 3);
/*
构造一个新的 MidiEvent。
message – event中包含的 MIDI
tick – 事件的时间戳,以 MIDI tick 为单位
*/
var me1 = new MidiEvent(m1, 0);
//将me1加入到轨道t0中
t0.add(me1);
//set time signature
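//set time signature (meter): the bytes below encode 4/4 (numerator 4, denominator 2^2), 24 MIDI clocks per metronome click, 8 32nd-notes per quarter note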
var m2 = new MetaMessage();
var b2 = new byte[]{0x4, 0x2, 0x18, 0x8};
m2.setMessage(0x58, b2, 4);
var me2 = new MidiEvent(m2, 0);
t0.add(me2);
//set end of track
t0.add(me2); //将me2加入到轨道t0中
//set end of track --- 设置t0轨道的末尾me3
var m3 = new MetaMessage();
var b3 = new byte[]{};
m3.setMessage(0x2f, b3, 0);
Expand All @@ -41,25 +67,31 @@ public static byte[] writeEventsToMidi(int startTime, List<NoteEvent> noteEvents
// track 1
var t1 = s.createTrack();

// generate midi message roll
// generate midi message roll --- 生成 midi 信息 roll
/* MidiMessage的四个参数:
* @param a time
* @param b midinote or control change(64)
* @param c velocity or value
* @param type :0 表示 midi note ; 1 表示 control change
*/
var roll = new ArrayList<MidiMessage>();
for (var note : noteEvents) {
roll.add(new MidiMessage(note.getOnsetTime(), note.getMidiNote(), note.getVelocity(), 0));
roll.add(new MidiMessage(note.getOffsetTime(), note.getMidiNote(), 0, 0));
}
if (pedalEvents.size() != 0) {
var controlChange = 64;
for (var pedal : pedalEvents) {
var controlChange = 64; // MIDI controller number 64 is the sustain (damper) pedal, not a count of sound banks
for (var pedal : pedalEvents) { //对于每一个pedalEvents对象,记录其开始时刻和结束时刻
roll.add(new MidiMessage(pedal.getOnsetTime(), controlChange, 127, 1));
roll.add(new MidiMessage(pedal.getOffsetTime(), controlChange, 0, 1));
roll.add(new MidiMessage(pedal.getOffsetTime(), controlChange, 0, 1)); // control-change value 0 releases the pedal (type 1 = control change, so this is not a note-off)
}
}
roll.sort(Comparator.comparing(x -> x.getA()));
roll.sort(Comparator.comparing(x -> x.getA())); // sort all messages by their time field a (an onset or an offset time)

// write midi message to track 1
//var previousTicks = 0;
for (var m : roll) {
var thisTicks = (int) ((m.getA() - startTime) * ticksPerSecond);
var thisTicks = (int) ((m.getA() - startTime) * ticksPerSecond);//对于每一个roll中的对象计算其总ticks数
if (thisTicks >= 0) {
//var diffTicks = thisTicks - previousTicks;
//previousTicks = thisTicks;
Expand All @@ -77,6 +109,16 @@ public static byte[] writeEventsToMidi(int startTime, List<NoteEvent> noteEvents
t1.add(me4);

var stream = new ByteArrayOutputStream();
/*
MidiSystem 类提供对已安装的 MIDI 系统资源的访问,包括合成器、音序器和 MIDI 输入和输出端口等设备。
一个典型的简单 MIDI 应用程序可能首先调用一个或多个 MidiSystem 方法来了解安装了哪些设备并获取该应用程序所需的设备。
该类还具有用于读取包含标准 MIDI 文件数据或音库的文件、流和 URL 的方法。我们可以在 MidiSystem 中查询指定 MIDI 文件的格式。
write方法:将表示 MIDI 文件类型的文件的字节流写入提供的输出流。
形参:
@param in - 包含要写入文件的 MIDI 数据的序列
@param fileType - 要写入输出流的文件的文件类型
@param out - 应该写入文件数据的流
*/
MidiSystem.write(s, 1, stream);
return stream.toByteArray();
}
Expand Down Expand Up @@ -116,4 +158,4 @@ public int getC() {
public int getType() {
return type;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package libpianotranscription.midi;

public class NoteEvent {
private final float onsetTime;
private final float offsetTime;
private final int midiNote;
private final int velocity;
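private final float onsetTime; // onset (start) time of the note
private final float offsetTime; // offset (end) time of the note
private final int midiNote; // MIDI note number (pitch), not a "MIDI record"
private final int velocity; // MIDI velocity (how hard the key is struck), not speed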

public NoteEvent(float onsetTime, float offsetTime, int midiNote, int velocity) {
this.onsetTime = onsetTime;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package libpianotranscription.midi;

public class PedalEvent {//似乎是用来计时的类
public class PedalEvent {//踏板事件,记录每一个pedalEvents对象的开始时刻和结束时刻
private final float onsetTime;
private final float offsetTime;

Expand All @@ -16,4 +16,4 @@ public float getOffsetTime() {
public float getOnsetTime() {
return onsetTime;
}
}
}

0 comments on commit fe59d10

Please sign in to comment.