Skip to content

Commit

Permalink
transcriptor-modified
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhisthebest committed Nov 10, 2022
1 parent c931807 commit b6c4073
Showing 1 changed file with 34 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,45 @@ public class Transcriptor {
private final OrtEnvironment _env;
private final OrtSession _session;
private final int segment_samples = 16000 * 10;
private final int frames_per_second = 100;
private final int frames_per_second = 100;//帧频
private final int classes_num = 88;
private final float onset_threshold = 0.3f;
private final float offset_threshold = 0.3f;
private final float frame_threshold = 0.1f;
private final float pedal_offset_threshold = 0.2f;
private final float onset_threshold = 0.3f; //起始阈值
private final float offset_threshold = 0.3f;//终止阈值
private final float frame_threshold = 0.1f;//帧阈值
private final float pedal_offset_threshold = 0.2f;//踏板偏移阈值

/**
 * Creates a transcriptor backed by an ONNX Runtime session.
 *
 * @param modulePath path handed to {@code OrtEnvironment.createSession} to load the model
 * @throws OrtException if the session cannot be created
 */
public Transcriptor(String modulePath) throws OrtException {
this._env = OrtEnvironment.getEnvironment();
this._session = this._env.createSession(modulePath);
// OrtEnvironment is the host object for the ONNX Runtime system; it creates the OrtSessions that wrap specific models.
this._env = OrtEnvironment.getEnvironment();
// Creates the session with the default OrtSession.SessionOptions and the default memory allocator.
this._session = this._env.createSession(modulePath);
}

public byte[] transcript(float[] pcm_data) throws InvalidMidiDataException, IOException, OrtException {
var pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
int pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
if (pad_len != 0)
pcm_data = Arrays.copyOf(pcm_data, pcm_data.length + pad_len);
/*复制指定的数组,用零截断或填充(如有必要),以便副本具有指定的长度。
对于在原始数组和副本中都有效的所有索引,这两个数组将包含相同的值。
对于在副本中有效但不是原始索引的任何索引,副本将包含 0f。
当且仅当指定长度大于原始数组的长度时,此类索引才会存在。 */
var segments = enframe(pcm_data, segment_samples);
//将pcm_data按segment_samples进行划分形成框架
var output_dict = forward(segments);
var new_output_dict = new HashMap<String, float[][]>();
for (var key :
output_dict.keySet()) {
for (String key : output_dict.keySet()) {
new_output_dict.put(key, deframe(output_dict.get(key)));
}
var post_processor = new RegressionPostProcessor(frames_per_second, classes_num, onset_threshold, offset_threshold, frame_threshold, pedal_offset_threshold);
var out = post_processor.outputDictToMidiEvents(new_output_dict);
var est_note_events = (List<NoteEvent>) out.get(0);
var est_pedal_events = (List<PedalEvent>) out.get(1);
List<NoteEvent> est_note_events = (List<NoteEvent>) out.get(0);
List<PedalEvent> est_pedal_events = (List<PedalEvent>) out.get(1);

// write midi file here
return writeEventsToMidi(0, est_note_events, est_pedal_events);
}

private List<float[]> enframe(float[] data, int segment_samples) {
private List<float[]> enframe(float[] data, int segment_samples) {//将data数组按照每segment_samples
//个元素一组,赋值给batch,并且是一半冗余地赋值,即若data为1 2 3 4 5 6 7 8且若segment_samples=4,
// 则传为1 2 3 4;3 4 5 6;5 6 7 8
int pointer = 0;
var batch = new ArrayList<float[]>();
while (pointer + segment_samples <= data.length) {
Expand All @@ -69,11 +75,10 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep
output_dic.put("pedal_frame_output", new ArrayList<>());

int pointer = 0;
var total_segments = (int) (Math.ceil(data.size() / batch_size));
var total_segments = (int) (Math.ceil(data.size() / batch_size));//获取data中的float数组个数
while (true) {
System.out.println("Segment " + pointer + " / " + total_segments);
if (pointer >= data.size()) break;

if (pointer >= data.size()) break;//已经遍历完,循环结束
//var tensor = Tensor.fromBlob(data.get(pointer), new long[]{1, segment_samples});
var tensor = OnnxTensor.createTensor(_env, FloatBuffer.wrap(data.get(pointer)), new long[]{1, segment_samples});
//var batch_output_dict = this._module.forward(IValue.from(tensor));
Expand All @@ -91,7 +96,7 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep

private float[][] deframe(List<float[][]> x) {
if (x.size() == 1)
return x.get(0);
return x.get(0);//返回第一个元素(元素是二维数组,不是数字)即可
else {
int segment_samples = x.get(0).length - 1;
int length = x.get(0)[0].length;
Expand All @@ -113,6 +118,7 @@ private float[][] deframe(List<float[][]> x) {
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
//读数组,与write2DArray相反的操作
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -123,6 +129,9 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
//把source的0到end1 - start1-1行,0到end2 - start2-1列的元素赋值给
//des的start1到end1-1行,start2到end2-1列的元素

for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand All @@ -131,7 +140,7 @@ private void write2DArray(float[][] source, float[][] des, int start1, int end1,
}
}

class RegressionPostProcessor {
class RegressionPostProcessor {//回归处理器
private final int frames_per_second;
private final int classes_num;
private final float onset_threshold;
Expand All @@ -142,7 +151,7 @@ class RegressionPostProcessor {
private final int velocity_scale;

public RegressionPostProcessor(int frames_per_second, int classes_num, float onset_threshold, float offset_threshold, float frame_threshold, float pedal_offset_threshold) {

//构造函数
this.frames_per_second = frames_per_second;
this.classes_num = classes_num;
this.onset_threshold = onset_threshold;
Expand All @@ -153,7 +162,7 @@ public RegressionPostProcessor(int frames_per_second, int classes_num, float ons
this.velocity_scale = 128;
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -163,7 +172,7 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
return output;
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand Down Expand Up @@ -192,6 +201,7 @@ public List<Object> outputDictToMidiEvents(Map<String, float[][]> dict) {
}

private List<NoteEvent> detectedNotesToEvents(List<float[]> est_on_off_note_vels) {
//将est_on_off_note_vels的数据传输给NoteEvent
var output = new ArrayList<NoteEvent>();
for (var i :
est_on_off_note_vels) {
Expand Down Expand Up @@ -272,6 +282,7 @@ private List<float[][]> getBinarizedOutputFromRegression(float[][] reg_output, f
}

private boolean is_monotonic_neighbour(float[] x, int n, int neighbour) {
//检测是否是递增或递减的相邻项
var monotonic = true;
for (int i = 0; i < neighbour; i++) {
if (x[n - i] < x[n - i - 1])
Expand Down Expand Up @@ -450,6 +461,8 @@ private List<float[]> noteDetectionWithOnsetOffsetRegress(float[] frame_outputs,
}

private List<float[]> pedalDetectionWithOnsetOffsetRegress(float[] frame_outputs, float[] offset_outputs, float[] offset_shift_outputs, double frame_threshold) {
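// "pedal" here refers to the piano pedal (the same sense as pedal_offset_threshold), not a sustained tone.
// Returns the detected pedal entries as a List<float[]>.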
var bgn = 0;
var frame_disappear = 0;
var offset_occur = 0;
Expand Down

0 comments on commit b6c4073

Please sign in to comment.