Skip to content

Commit

Permalink
Merge pull request #23 from SleepyLGod/dev-midi-generator
Browse files Browse the repository at this point in the history
Dev midi generator
  • Loading branch information
SleepyLGod committed Dec 15, 2022
2 parents b2e0a57 + 4da3eed commit 835e9da
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,45 @@ public class Transcriptor {
private final OrtEnvironment _env;
private final OrtSession _session;
private final int segment_samples = 16000 * 10;
private final int frames_per_second = 100;
private final int frames_per_second = 100;//帧频
private final int classes_num = 88;
private final float onset_threshold = 0.3f;
private final float offset_threshold = 0.3f;
private final float frame_threshold = 0.1f;
private final float pedal_offset_threshold = 0.2f;
private final float onset_threshold = 0.3f; //起始阈值
private final float offset_threshold = 0.3f;//终止阈值
private final float frame_threshold = 0.1f;//帧阈值
private final float pedal_offset_threshold = 0.2f;//踏板偏移阈值

public Transcriptor(String modulePath) throws OrtException {
this._env = OrtEnvironment.getEnvironment();
this._session = this._env.createSession(modulePath);
this._env = OrtEnvironment.getEnvironment();//OrtEnvironment型变量是onnx运行时系统的主机对象。可以创建封装的 OrtSessions 特定型号。
this._session = this._env.createSession(modulePath);//使用默认的OrtSession.SessionOptions、模型和默认内存分配器创建会话。
}

public byte[] transcript(float[] pcm_data) throws InvalidMidiDataException, IOException, OrtException {
var pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
int pad_len = (int) (Math.ceil(pcm_data.length / segment_samples)) * segment_samples - pcm_data.length;
if (pad_len != 0)
pcm_data = Arrays.copyOf(pcm_data, pcm_data.length + pad_len);
/*复制指定的数组,用零截断或填充(如有必要),以便副本具有指定的长度。
对于在原始数组和副本中都有效的所有索引,这两个数组将包含相同的值。
对于在副本中有效但不是原始索引的任何索引,副本将包含 0f。
当且仅当指定长度大于原始数组的长度时,此类索引才会存在。 */
var segments = enframe(pcm_data, segment_samples);
//将pcm_data按segment_samples进行划分形成框架
var output_dict = forward(segments);
var new_output_dict = new HashMap<String, float[][]>();
for (var key :
output_dict.keySet()) {
for (String key : output_dict.keySet()) {
new_output_dict.put(key, deframe(output_dict.get(key)));
}
var post_processor = new RegressionPostProcessor(frames_per_second, classes_num, onset_threshold, offset_threshold, frame_threshold, pedal_offset_threshold);
var out = post_processor.outputDictToMidiEvents(new_output_dict);
var est_note_events = (List<NoteEvent>) out.get(0);
var est_pedal_events = (List<PedalEvent>) out.get(1);
List<NoteEvent> est_note_events = (List<NoteEvent>) out.get(0);
List<PedalEvent> est_pedal_events = (List<PedalEvent>) out.get(1);

// write midi file here
return writeEventsToMidi(0, est_note_events, est_pedal_events);
}

private List<float[]> enframe(float[] data, int segment_samples) {
private List<float[]> enframe(float[] data, int segment_samples) {//将data数组按照每segment_samples
//个元素一组,赋值给batch,并且是一半冗余地赋值,即若data为1 2 3 4 5 6 7 8且若segment_samples=4,
// 则传为1 2 3 4;3 4 5 6;5 6 7 8
int pointer = 0;
var batch = new ArrayList<float[]>();
while (pointer + segment_samples <= data.length) {
Expand All @@ -69,11 +75,10 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep
output_dic.put("pedal_frame_output", new ArrayList<>());

int pointer = 0;
var total_segments = (int) (Math.ceil(data.size() / batch_size));
var total_segments = (int) (Math.ceil(data.size() / batch_size));//获取data中的float数组个数
while (true) {
System.out.println("Segment " + pointer + " / " + total_segments);
if (pointer >= data.size()) break;

if (pointer >= data.size()) break;//已经遍历完,循环结束
//var tensor = Tensor.fromBlob(data.get(pointer), new long[]{1, segment_samples});
var tensor = OnnxTensor.createTensor(_env, FloatBuffer.wrap(data.get(pointer)), new long[]{1, segment_samples});
//var batch_output_dict = this._module.forward(IValue.from(tensor));
Expand All @@ -91,7 +96,7 @@ private Map<String, List<float[][]>> forward(List<float[]> data) throws OrtExcep

private float[][] deframe(List<float[][]> x) {
if (x.size() == 1)
return x.get(0);
return x.get(0);//返回第一个元素(元素是二维数组,不是数字)即可
else {
int segment_samples = x.get(0).length - 1;
int length = x.get(0)[0].length;
Expand All @@ -113,6 +118,7 @@ private float[][] deframe(List<float[][]> x) {
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
//读数组,与write2DArray相反的操作
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -123,6 +129,9 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
//把source的0到end1 - start1-1行,0到end2 - start2-1列的元素赋值给
//des的start1到end1-1行,start2到end2-1列的元素

for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand All @@ -131,7 +140,7 @@ private void write2DArray(float[][] source, float[][] des, int start1, int end1,
}
}

class RegressionPostProcessor {
class RegressionPostProcessor {//回归处理器
private final int frames_per_second;
private final int classes_num;
private final float onset_threshold;
Expand All @@ -142,7 +151,7 @@ class RegressionPostProcessor {
private final int velocity_scale;

public RegressionPostProcessor(int frames_per_second, int classes_num, float onset_threshold, float offset_threshold, float frame_threshold, float pedal_offset_threshold) {

//构造函数
this.frames_per_second = frames_per_second;
this.classes_num = classes_num;
this.onset_threshold = onset_threshold;
Expand All @@ -153,7 +162,7 @@ public RegressionPostProcessor(int frames_per_second, int classes_num, float ons
this.velocity_scale = 128;
}

private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {
private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
var output = new float[end1 - start1][end2 - start2];
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
Expand All @@ -163,7 +172,7 @@ private float[][] read2DArray(float[][] x, int start1, int end1, int start2, int
return output;
}

private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {
private void write2DArray(float[][] source, float[][] des, int start1, int end1, int start2, int end2) {//和上面的类的同名函数功能相同
for (int i = 0; i < end1 - start1; i++) {
for (int j = 0; j < end2 - start2; j++) {
des[i + start1][j + start2] = source[i][j];
Expand Down Expand Up @@ -192,6 +201,7 @@ public List<Object> outputDictToMidiEvents(Map<String, float[][]> dict) {
}

private List<NoteEvent> detectedNotesToEvents(List<float[]> est_on_off_note_vels) {
//将est_on_off_note_vels的数据传输给NoteEvent
var output = new ArrayList<NoteEvent>();
for (var i :
est_on_off_note_vels) {
Expand Down Expand Up @@ -272,6 +282,7 @@ private List<float[][]> getBinarizedOutputFromRegression(float[][] reg_output, f
}

private boolean is_monotonic_neighbour(float[] x, int n, int neighbour) {
//检测是否是递增或递减的相邻项
var monotonic = true;
for (int i = 0; i < neighbour; i++) {
if (x[n - i] < x[n - i - 1])
Expand Down Expand Up @@ -450,6 +461,8 @@ private List<float[]> noteDetectionWithOnsetOffsetRegress(float[] frame_outputs,
}

private List<float[]> pedalDetectionWithOnsetOffsetRegress(float[] frame_outputs, float[] offset_outputs, float[] offset_shift_outputs, double frame_threshold) {
//pedal:持续音
//返回一个output
var bgn = 0;
var frame_disappear = 0;
var offset_occur = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.springframework.web.multipart.MultipartFile;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;

Expand Down Expand Up @@ -45,19 +46,21 @@ public Mp3ImportVO Mp3ToMidi(@RequestBody Mp3ImportDTO mp3ImportDTO) throws Exce

@ResponseBody
@PostMapping(value = "/mp3ToMidiWithFile", consumes = {"multipart/form-data"})
public Mp3ImportVO Mp3ToMidiWithFile(@RequestParam("file")MultipartFile file,
@RequestParam("outPath")String outPath,
@RequestParam("songName")String songName) throws Exception {
Mp3ImportWithFileDTO mp3ImportWithFileDTO = new Mp3ImportWithFileDTO(file, outPath, songName);
public void Mp3ToMidiWithFile(@RequestParam("file")MultipartFile file,
// @RequestParam("outPath")String outPath,
@RequestParam("songName")String songName,
HttpServletResponse response) throws Exception {
Mp3ImportWithFileDTO mp3ImportWithFileDTO = new Mp3ImportWithFileDTO(file, songName);
try {
CommonResult commonResult = transcriptionService.Mp3TOMidiUploadWithFile(mp3ImportWithFileDTO);
CommonResult commonResult = transcriptionService.Mp3TOMidiUploadWithFile(mp3ImportWithFileDTO, response);
if (commonResult.getCode() == 1) {
return new Mp3ImportVO(true, commonResult.getData().toString(), null);
System.out.println("success");
} else {
return new Mp3ImportVO(false, null, commonResult.getMessage());
System.out.println("fail");
}
} catch (NullPointerException e) {
return new Mp3ImportVO(false, null, "请检查是否传入了正确的参数");
System.out.println("fail");
e.printStackTrace();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ public class Mp3ImportDTO {
@NonNull
private String resourcePath;
@NonNull
private String outPath;
@NonNull
private String songName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ public class Mp3ImportWithFileDTO {
@NonNull
private MultipartFile file;
@NonNull
private String outPath;
@NonNull
private String songName;
@NonNull
private String inputPath = "D:\\gitrepositories\\omg-score\\OmgPianoTranscription\\pianotranscriptioncli\\src\\main\\resources\\";

public Mp3ImportWithFileDTO(MultipartFile file, String outPath, String songName) {
@NonNull
private String outPath = "D:\\gitrepositories\\omg-score\\OmgPianoTranscription\\pianotranscriptioncli\\src\\main\\resources\\output\\";
public Mp3ImportWithFileDTO(MultipartFile file, String songName) {
this.file = file;
this.outPath = outPath;
this.songName = songName;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import com.pianotranscriptioncli.dto.Mp3ImportDTO;
import com.pianotranscriptioncli.dto.Mp3ImportWithFileDTO;

import javax.servlet.http.HttpServletResponse;

public interface TranscriptionService {
CommonResult Mp3TOMidiUpload(Mp3ImportDTO mp3ImportDTO) throws Exception;

CommonResult Mp3TOMidiUploadWithFile(Mp3ImportWithFileDTO mp3ImportWithFileDTO) throws Exception;
CommonResult Mp3TOMidiUploadWithFile(Mp3ImportWithFileDTO mp3ImportWithFileDTO, HttpServletResponse response) throws Exception;

String WavToMidiUpload();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
import com.pianotranscriptioncli.dto.Mp3ImportWithFileDTO;
import com.pianotranscriptioncli.service.TranscriptionService;
import com.pianotranscriptioncli.utils.Utils;
import org.apache.tomcat.util.http.fileupload.FileUtils;
import org.apache.tomcat.util.http.fileupload.IOUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import java.io.File;
import java.io.IOException;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

@Service
public class TranscriptionServiceImpl implements TranscriptionService {
Expand All @@ -23,25 +28,32 @@ public CommonResult Mp3TOMidiUpload(Mp3ImportDTO mp3ImportDTO) throws Exception
System.out.println(System.getProperty("user.dir"));
resourcePath = System.getProperty("user.dir") + mp3ImportDTO.getResourcePath(); // "\\src\\main\\resources\\"
}
//String ans = resourcePath + "output\\" + mp3ImportDTO.getSongName() + ".mid";
String ans = Utils.Convertor(resourcePath, mp3ImportDTO.getSongName());
String outPath = mp3ImportDTO.getOutPath() + mp3ImportDTO.getSongName() + ".mid";
String ans = Utils.ConvertorRedirect(resourcePath, mp3ImportDTO.getSongName(), outPath);
if (ans != null) {
return CommonResult.success(ans, "mp3转换成功");
} else {
return CommonResult.failed("mp3转换失败");
}
}

/**
* 上传文件
* @param mp3ImportWithFileDTO
* @param response
* @return CommonResult
* @throws Exception
*/
@Override
public CommonResult Mp3TOMidiUploadWithFile(Mp3ImportWithFileDTO mp3ImportWithFileDTO) throws Exception {
public CommonResult Mp3TOMidiUploadWithFile(Mp3ImportWithFileDTO mp3ImportWithFileDTO, HttpServletResponse response) throws Exception {
MultipartFile file = mp3ImportWithFileDTO.getFile();
String inputFilePath;
if (!file.isEmpty()) {
/*
/*
String fileName = file.getOriginalFilename(); // 获取文件名
assert fileName != null;
String suffixName = fileName.substring(fileName.lastIndexOf(".")); // 获取文件的后缀名
*/
*/
inputFilePath = mp3ImportWithFileDTO.getInputPath() + "input\\" + mp3ImportWithFileDTO.getSongName() + ".mp3";
File dest = new File(inputFilePath);
if (!dest.getParentFile().exists()) { // 检测是否存在目录
Expand All @@ -65,8 +77,43 @@ public CommonResult Mp3TOMidiUploadWithFile(Mp3ImportWithFileDTO mp3ImportWithFi

String outPath = mp3ImportWithFileDTO.getOutPath() + mp3ImportWithFileDTO.getSongName() + ".mid";
String output = Utils.ConvertorRedirect(mp3ImportWithFileDTO.getInputPath(), mp3ImportWithFileDTO.getSongName(), outPath);
OutputStream outputStream = null;
if (output != null) {
return CommonResult.success(output, "mp3转换成功");
File outputFile = new File(output);
if (!outputFile.exists()) {
throw new Exception("midi文件不存在");
}
// System.out.println(outputFile);
try {
// 通过response返回
// 设置文件头 (URLEncoder.encode(mp3ImportWithFileDTO.getSongName() + ".mid", StandardCharsets.US_ASCII)))
response.setHeader("Content-Disposition", "attchement;filename=" + URLEncoder.encode(mp3ImportWithFileDTO.getSongName() + ".mid", StandardCharsets.UTF_8));
response.setCharacterEncoding("UTF-8");
response.setContentType("audio/mid");
// response.setContentType("application/octet-stream");
InputStream fis = new BufferedInputStream(new FileInputStream(outputFile));
byte[] buffer = new byte[fis.available()];
// fis.read(buffer);
fis.close();
response.reset();
outputStream = new BufferedOutputStream(response.getOutputStream());
System.out.println(Arrays.toString(buffer));
outputStream.write(buffer);
System.out.println(outputStream);
outputStream.flush();
// IOUtils.copy(fis, outputStream);
response.flushBuffer();
} catch (Exception e) {
if (null != outputStream) {
try {
outputStream.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
}
System.out.println(response.getOutputStream());
return CommonResult.success(response.getClass(), "mp3转换成功");
} else {
return CommonResult.failed("mp3转换失败");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.pianotranscriptioncli.utils;

import javax.sound.midi.*;
import java.io.File;
import java.io.IOException;

public class MidiUtils {

    /** Interval, in milliseconds, between checks for the end of playback. */
    private static final long PLAYBACK_POLL_MILLIS = 50L;

    /**
     * Reads a MIDI file and plays it on the default sequencer.
     *
     * <p>The call blocks until the sequencer reports that it is no longer running
     * (end of the sequence) or until the calling thread is interrupted. The
     * sequencer is always closed before the method returns.
     *
     * <p>Failures to read the file, parse it, or obtain a sequencer are printed
     * to standard error and otherwise ignored, so this method does not throw.
     *
     * @param output full path of the MIDI file to play, including its extension
     * @see <a href="https://docs.oracle.com/javase/7/docs/api/javax/sound/midi/package-summary.html">javax.sound.midi package summary</a>
     */
    public static void Reduce(String output) {
        try {
            Sequence sequence = MidiSystem.getSequence(new File(output));

            // getSequencer() throws MidiUnavailableException rather than returning
            // null, so no null check is needed. try-with-resources guarantees the
            // device is released even if playback fails or is interrupted.
            try (Sequencer sequencer = MidiSystem.getSequencer()) {
                // A sequencer must be opened before start(); otherwise start()
                // throws IllegalStateException.
                sequencer.open();
                sequencer.setSequence(sequence);
                sequencer.start();

                // Keep the sequencer alive until playback completes; stopping
                // right after start() would play nothing.
                while (sequencer.isRunning()) {
                    Thread.sleep(PLAYBACK_POLL_MILLIS);
                }
                sequencer.stop();
            }
        } catch (InvalidMidiDataException | IOException | MidiUnavailableException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can observe the interruption;
            // the sequencer has already been closed by try-with-resources.
            Thread.currentThread().interrupt();
        }
    }

}
Loading

0 comments on commit 835e9da

Please sign in to comment.