diff --git a/encodec.cpp b/encodec.cpp
index 5548db6..b6010f3 100644
--- a/encodec.cpp
+++ b/encodec.cpp
@@ -19,11 +19,11 @@

 typedef enum {
     // Run the end-to-end encoder-decoder pipeline
-    full        = 0,
+    full   = 0,
     // Encode an audio (encoder + quantizer encode)
-    encode_only = 1,
+    encode = 1,
     // Decode an audio from a compressed representation (quantizer decode + decoder)
-    decode_only = 2,
+    decode = 2,
 } encodec_run_mode;

 void print_tensor(struct ggml_tensor * a) {
@@ -1000,6 +1000,7 @@ struct ggml_cgraph * encodec_build_graph(
         struct encodec_context * ectx,
         std::vector & inp_audio,
         const encodec_run_mode mode) {
+    assert(mode == encodec_run_mode::full || mode == encodec_run_mode::encode);
     const auto & model = ectx->model;
     const auto & hparams = model.hparams;

@@ -1042,10 +1043,14 @@ struct ggml_cgraph * encodec_build_graph(
             {
                 ggml_build_forward_expand(gf, decoded);
             } break;
-        case encodec_run_mode::encode_only:
+        case encodec_run_mode::encode:
             {
                 ggml_build_forward_expand(gf, codes);
             } break;
+        case encodec_run_mode::decode:
+            {
+                return NULL;
+            } break;
         default:
             {
                 fprintf(stderr, "%s: unknown run mode\n", __func__);
@@ -1062,6 +1067,80 @@ struct ggml_cgraph * encodec_build_graph(
     return gf;
 }

+// Build the decode-only graph (quantizer decode + decoder) from a flat vector of codes.
+// Returns NULL when the number of codes is not a multiple of the number of quantizers.
+struct ggml_cgraph * encodec_build_graph(
+        struct encodec_context * ectx,
+        std::vector & codes,
+        const encodec_run_mode mode) {
+    assert(mode == encodec_run_mode::decode);
+
+    const auto & model = ectx->model;
+    const auto & hparams = model.hparams;
+    const auto & allocr = ectx->allocr;
+
+    const int n_bins = hparams.n_bins;
+    const int sr = hparams.sr;
+    const int bandwidth = hparams.bandwidth;
+    const int hop_length = hparams.hop_length;
+
+    const int frame_rate = (int) ceilf(sr / hop_length);
+    const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
+
+    if (n_q <= 0 || codes.size() % n_q != 0) {
+        fprintf(stderr, "%s: invalid number of codes\n", __func__);
+        return NULL;
+    }
+
+    const int N = codes.size() / n_q;
+
+    // since we are using ggml-alloc, this buffer only needs enough space to hold the
+    // ggml_tensor and ggml_cgraph structs, but not the tensor data
+    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+    static std::vector buf(buf_size);
+
+    struct ggml_init_params ggml_params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(ggml_params);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * inp_codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, N, n_q);
+    ggml_allocr_alloc(allocr, inp_codes);
+
+    // avoid writing to tensors if we are only measuring the memory usage
+    if (!ggml_allocr_is_measure(allocr)) {
+        ggml_backend_tensor_set(inp_codes, codes.data(), 0, N*n_q*ggml_element_size(inp_codes));
+    }
+
+    struct ggml_tensor * quantized = encodec_forward_quantizer_decode(ectx, ctx0, inp_codes);
+    struct ggml_tensor * decoded = encodec_forward_decoder(ectx, ctx0, quantized);
+
+    switch(mode) {
+        case encodec_run_mode::decode:
+            {
+                ggml_build_forward_expand(gf, decoded);
+            } break;
+        default:
+            {
+                fprintf(stderr, "%s: unknown run mode\n", __func__);
+                ggml_free(ctx0);
+                return NULL;
+            } break;
+    }
+
+    ggml_free(ctx0);
+
+    ectx->codes = inp_codes;
+    ectx->decoded = decoded;
+
+    return gf;
+}
+
 bool encodec_eval_internal(
         struct encodec_context * ectx,
         std::vector & raw_audio,
@@ -1087,6 +1166,37 @@ bool encodec_eval_internal(
     return true;
 }

+// Run one decoding pass: reset the allocator, build the graph from the codes and compute it.
+bool encodec_eval_internal(
+        struct encodec_context * ectx,
+        std::vector & codes,
+        const int n_threads,
+        const encodec_run_mode mode) {
+    auto & model = ectx->model;
+    auto & allocr = ectx->allocr;
+
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = encodec_build_graph(ectx, codes, mode);
+    if (!gf) {
+        fprintf(stderr, "%s: failed to build the compute graph\n", __func__);
+        return false;
+    }
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+
+    // run the computation
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+    ggml_backend_graph_compute(model.backend, gf);
+
+    return true;
+}
+
+
 bool encodec_eval(
         struct encodec_context * ectx,
         std::vector & raw_audio,
@@ -1125,6 +1235,49 @@ bool encodec_eval(
     return true;
 }

+// Size the compute buffer with a measure pass, then run the decoding and record its duration.
+bool encodec_eval(
+        struct encodec_context * ectx,
+        std::vector & codes,
+        const int n_threads,
+        const encodec_run_mode mode) {
+    const int64_t t_start_ms = ggml_time_ms();
+
+    // allocate the compute buffer
+    {
+        // alignment required by the backend
+        size_t align = ggml_backend_get_alignment(ectx->model.backend);
+        ectx->allocr = ggml_allocr_new_measure(align);
+
+        // create the graph for memory usage estimation
+        struct ggml_cgraph * gf = encodec_build_graph(ectx, codes, mode);
+        if (!gf) {
+            fprintf(stderr, "%s: failed to build the compute graph\n", __func__);
+            return false;
+        }
+
+        // compute the required memory
+        size_t mem_size = ggml_allocr_alloc_graph(ectx->allocr, gf);
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(ectx->allocr);
+        ectx->buf_compute = ggml_backend_alloc_buffer(ectx->model.backend, mem_size);
+        ectx->allocr = ggml_allocr_new_from_buffer(ectx->buf_compute);
+
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size/1024.0/1024.0);
+    }
+
+    // encodec eval
+    if (!encodec_eval_internal(ectx, codes, n_threads, mode)) {
+        fprintf(stderr, "%s: failed to run encodec eval\n", __func__);
+        return false;
+    }
+
+    ectx->t_compute_ms = ggml_time_ms() - t_start_ms;
+
+    return true;
+}
+
 bool encodec_reconstruct_audio(
         struct encodec_context * ectx,
         std::vector & raw_audio,
@@ -1155,7 +1308,7 @@ bool encodec_compress_audio(
         struct encodec_context * ectx,
         std::vector & raw_audio,
         int n_threads) {
-    if(!encodec_eval(ectx, raw_audio, n_threads, encodec_run_mode::encode_only)) {
+    if(!encodec_eval(ectx, raw_audio, n_threads, encodec_run_mode::encode)) {
         fprintf(stderr, "%s: failed to run encodec eval\n", __func__);
         return false;
     }
@@ -1177,6 +1330,33 @@ bool encodec_compress_audio(
     return true;
 }

+// Decode the codes and copy the decoded samples into ectx->out_audio.
+bool encodec_decompress_audio(
+        struct encodec_context * ectx,
+        std::vector & codes,
+        int n_threads) {
+    if (!encodec_eval(ectx, codes, n_threads, encodec_run_mode::decode)) {
+        fprintf(stderr, "%s: failed to run encodec eval\n", __func__);
+        return false;
+    }
+
+    if (!ectx->decoded) {
+        fprintf(stderr, "%s: null decoded tensor\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * decoded = ectx->decoded;
+
+    auto & out_audio = ectx->out_audio;
+
+    int out_length = decoded->ne[0];
+    out_audio.resize(out_length);
+
+    ggml_backend_tensor_get(decoded, out_audio.data(), 0, out_length*ggml_element_size(decoded));
+
+    return true;
+}
+
 struct encodec_context * encodec_load_model(const std::string & model_path) {
     int64_t t_start_load_us = ggml_time_us();

diff --git a/encodec.h b/encodec.h
index 2cd2ce6..30820a4 100644
--- a/encodec.h
+++ b/encodec.h
@@ -1,3 +1,13 @@
+/**
+ * @file encodec.h
+ * @brief Header file for the encodec library.
+ *
+ * This file contains the declarations of the structs and functions used in the encodec library.
+ * The library provides functionality for audio compression and decompression using a custom model.
+ * The model consists of an encoder, a quantizer and a decoder, each with their own set of parameters.
+ * The library also provides functions for loading and freeing the model, as well as compressing and decompressing audio data.
+ *
+ */
 #pragma once

 #include
@@ -173,18 +183,68 @@ struct encodec_context {
     int64_t t_compute_ms = 0;
 };

-struct encodec_context * encodec_load_model(const std::string & model_path);
-
-void encodec_set_target_bandwidth(struct encodec_context * ectx, int bandwidth);
-
+/**
+ * Loads an encodec model from the specified file path.
+ *
+ * @param model_path The file path to the encodec model.
+ * @return A pointer to the encodec context struct.
+ */
+struct encodec_context * encodec_load_model(
+        const std::string & model_path);
+
+/**
+ * Sets the target bandwidth for the given encodec context.
+ *
+ * @param ectx The encodec context to set the target bandwidth for.
+ * @param bandwidth The target bandwidth to set, presumably in kbps (the decompress example passes 12) -- TODO confirm.
+ */
+void encodec_set_target_bandwidth(
+        struct encodec_context * ectx,
+        int bandwidth);
+
+/**
+ * Reconstructs audio from raw audio data using the specified encodec context.
+ *
+ * @param ectx The encodec context to use for reconstruction.
+ * @param raw_audio The raw audio data to reconstruct.
+ * @param n_threads The number of threads to use for reconstruction.
+ * @return True if the reconstruction was successful, false otherwise.
+ */
 bool encodec_reconstruct_audio(
         struct encodec_context * ectx,
         std::vector & raw_audio,
         int n_threads);

+/**
+ * Compresses audio data using the specified encodec context.
+ *
+ * @param ectx The encodec context to use for compression.
+ * @param raw_audio The raw audio data to compress.
+ * @param n_threads The number of threads to use for compression.
+ * @return True if the compression was successful, false otherwise.
+ */
 bool encodec_compress_audio(
         struct encodec_context * ectx,
         std::vector & raw_audio,
         int n_threads);

-void encodec_free(struct encodec_context * ectx);
\ No newline at end of file
+/**
+ * Decompresses audio data using the specified encodec context.
+ *
+ * @param ectx The encodec context to use for decompression.
+ * @param codes The compressed audio data to decompress.
+ * @param n_threads The number of threads to use for decompression.
+ * @return True if the audio data was successfully decompressed, false otherwise.
+ */
+bool encodec_decompress_audio(
+        struct encodec_context * ectx,
+        std::vector & codes,
+        int n_threads);
+
+/**
+ * @brief Frees the memory allocated for an encodec context.
+ *
+ * @param ectx The encodec context to free.
+ */
+void encodec_free(
+        struct encodec_context * ectx);
diff --git a/examples/common.cpp b/examples/common.cpp
index c1bd458..7a9eba8 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
 #include
@@ -15,6 +16,7 @@

 #define SAMPLE_RATE 24000
 #define BITS_PER_CODEBOOK 10 // int(log2(quantizer.bins)); quantizer.bins = 1024
+using json = nlohmann::json;

 // The ECDC file format expects big endian byte order.
 // This function swaps the endianness of a 32-bit integer.
@@ -75,7 +77,7 @@ int encodec_params_parse(int argc, char ** argv, encodec_params & params) {
     return 0;
 }

-bool read_wav_from_disk(std::string in_path, std::vector& audio_arr) {
+bool read_wav_from_disk(std::string in_path, std::vector & audio_arr) {
     uint32_t channels;
     uint32_t sample_rate;
     drwav_uint64 total_frame_count;
@@ -98,13 +100,13 @@ bool read_wav_from_disk(std::string in_path, std::vector& audio_arr) {
     return true;
 }

-void write_wav_on_disk(std::vector& audio_arr, std::string dest_path) {
+void write_wav_on_disk(std::vector & audio_arr, std::string dest_path) {
     drwav_data_format format;
-    format.container = drwav_container_riff;
-    format.format = DR_WAVE_FORMAT_IEEE_FLOAT;
-    format.channels = 1;
-    format.sampleRate = SAMPLE_RATE;
     format.bitsPerSample = 32;
+    format.sampleRate    = SAMPLE_RATE;
+    format.container     = drwav_container_riff;
+    format.channels      = 1;
+    format.format        = DR_WAVE_FORMAT_IEEE_FLOAT;

     drwav wav;
     drwav_init_file_write(&wav, dest_path.c_str(), &format, NULL);
@@ -181,27 +183,8 @@ class BitUnpacker {
     int current_bits;
 };

-std::vector read_exactly(std::ifstream& fo, size_t size) {
-    std::vector buf;
-    buf.reserve(size);
-
-    while (buf.size() < size) {
-        char chunk[size];
-        fo.read(chunk, size);
-        size_t bytesRead = fo.gcount();
-        if (bytesRead == 0) {
-            throw std::runtime_error("Impossible to read enough data from the stream, " +
-                                     std::to_string(size) + " bytes remaining.");
-        }
-        buf.insert(buf.end(), chunk, chunk + bytesRead);
-        size -= bytesRead;
-    }
-
-    return buf;
-}
-
 void write_encodec_header(std::ofstream & fo, uint32_t audio_length) {
-    nlohmann::json metadata = {
+    json metadata = {
         {"m" , "encodec_24khz"},
         {"al", audio_length},
         {"nc", 16},
@@ -209,7 +192,7 @@ void write_encodec_header(std::ofstream & fo, uint32_t audio_length) {
     };
     std::string meta_dumped = metadata.dump();

-    char magic[4] = { 'E', 'C', 'D', 'C'};
+    std::string magic = "ECDC";
     uint8_t version = 0;
     uint32_t meta_length = static_cast(meta_dumped.size());

@@ -218,70 +201,131 @@ void write_encodec_header(std::ofstream & fo, uint32_t audio_length) {
         meta_length = swap_endianness(meta_length);
     }

-    fo.write(magic, 4);
-    fo.write(reinterpret_cast(&version), sizeof(version));
-    fo.write(reinterpret_cast(&meta_length), sizeof(uint32_t));
+    fo.write(magic.c_str(), magic.size());
+    fo.write((char *) &version, sizeof(version));
+    fo.write((char *) &meta_length, sizeof(uint32_t));
     fo.write(meta_dumped.data(), meta_dumped.size());
+    fo.flush();
 }

-nlohmann::json read_ecdc_header(std::ifstream& fo) {
-    int size_header = 4 * sizeof(char) + sizeof(uint8_t) + sizeof(uint32_t);
-    std::vector header_bytes = read_exactly(fo, size_header);
+json read_ecdc_header(std::ifstream & fin) {
+    std::string magic(4, '\0');
+    uint8_t version = 0;
+    uint32_t meta_length = 0;
+
+    fin.read(&magic[0], magic.size());

-    char * magic = reinterpret_cast(header_bytes.data(), header_bytes.data() + 4*sizeof(char));
-    uint8_t * version = reinterpret_cast(header_bytes.data() + 4*sizeof(char));
-    uint32_t * meta_length = reinterpret_cast(header_bytes.data() + 4*sizeof(char) + sizeof(uint8_t));
+    fin.read((char *) &version, sizeof(version));
+    fin.read((char *) &meta_length, sizeof(meta_length));

-    if (strcmp(magic, "ECDC") != 0) {
+    // a short read leaves the fields above unreliable: stop before using them
+    if (!fin) {
+        throw std::runtime_error("Unexpected end of file while reading the ECDC header.");
+    }
+
+    // the length is stored in big endian: swap it back on little endian hosts
+    if (!is_big_endian()) {
+        meta_length = swap_endianness(meta_length);
+    }
+
+    if (magic != "ECDC") {
         throw std::runtime_error("File is not in ECDC format.");
     }
-    if (*version != 0) {
+
+    if (version != 0) {
         throw std::runtime_error("Version not supported.");
     }

-    std::vector meta_bytes = read_exactly(fo, *meta_length);
-    std::string meta_str(meta_bytes.begin(), meta_bytes.end());
-    return nlohmann::json::parse(meta_str);
+    // read the JSON metadata straight into the string; nothing to read when it is empty
+    std::string meta_str(meta_length, '\0');
+    if (meta_length > 0 && !fin.read(&meta_str[0], meta_str.size())) {
+        throw std::runtime_error("Unexpected end of file while reading the ECDC metadata.");
+    }
+
+    return json::parse(meta_str);
 }

-void write_encodec_codes(std::ofstream & fo, std::vector & codes) {
+void write_encodec_codes(
+        std::ofstream & fo,
+        std::vector & codes) {
     BitPacker bp(BITS_PER_CODEBOOK, fo);
+
     for (int32_t code : codes) {
         bp.push(code);
     }
+
     bp.flush();
 }

-bool write_codes_to_file(std::string dest_path, std::vector & codes, uint32_t audio_length) {
+bool write_codes_to_file(
+        std::string dest_path,
+        std::vector & codes,
+        uint32_t audio_length) {
     std::ofstream fo(dest_path, std::ios::binary);
+    if (!fo.is_open()) {
+        fprintf(stderr, "error: failed to open '%s' for writing\n", dest_path.c_str());
+        return false;
+    }
+
     write_encodec_header(fo, audio_length);
     write_encodec_codes(fo, codes);
+
     fo.close();

     return true;
 }

-std::vector read_codes_from_file(std::string code_path) {
+bool read_codes_from_file(
+        std::string code_path,
+        std::vector & codes,
+        uint32_t & audio_length,
+        uint32_t & n_codebooks) {
     std::ifstream fin(code_path, std::ios::binary);

-    nlohmann::json metadata = read_ecdc_header(fin);
-    uint32_t audio_length = metadata["audio_length"];
-    uint32_t n_codebooks = metadata["n_codebooks"];
+    if (!fin.is_open()) {
+        fprintf(stderr, "error: failed to open '%s'\n", code_path.c_str());
+        return false;
+    }
+
+    try {
+        // header parsing throws on a truncated or malformed file
+        json metadata = read_ecdc_header(fin);
+
+        if (metadata.contains("al") && metadata["al"].is_number_unsigned()) {
+            audio_length = metadata["al"];
+        } else {
+            fprintf(stderr, "error: metadata does not contain audio length\n");
+            return false;
+        }
+
+        if (metadata.contains("nc") && metadata["nc"].is_number_unsigned()) {
+            n_codebooks = metadata["nc"];
+        } else {
+            fprintf(stderr, "error: metadata does not contain number of codebooks\n");
+            return false;
+        }
+    } catch (const std::exception & ex) {
+        // json::exception and std::runtime_error both derive from std::exception
+        fprintf(stderr, "error: invalid ECDC header: %s\n", ex.what());
+        return false;
+    }
+
+    // TODO: remove hardcoded values
+    const int hop_length = 320; // 8 * 5 * 4 * 2
+    const int frame_rate = std::ceil((float) SAMPLE_RATE / hop_length);
+    const int frame_length = std::ceil((float) audio_length * frame_rate / SAMPLE_RATE);

-    std::vector codes;
-    codes.resize(audio_length * n_codebooks);
+    codes.resize(frame_length * n_codebooks);

     BitUnpacker bu(BITS_PER_CODEBOOK, fin);

-    for (int t = 0; t < audio_length; t++) {
-        for (int c = 0; c < n_codebooks; c++) {
-            codes[t * n_codebooks + c] = bu.pull();
-        }
+    for (size_t i = 0; i < codes.size(); i++) {
+        codes[i] = bu.pull();
     }

     fin.close();

-    return codes;
-}
\ No newline at end of file
+    return true;
+}
diff --git a/examples/common.h b/examples/common.h
index e9da652..40fad88 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -62,6 +62,13 @@ bool write_codes_to_file(std::string dest_path, std::vector & codes, ui
  * @brief Reads a vector of integers from a file on disk.
  *
  * @param code_path Path to the input file.
- * @return std::vector Vector containing the integers read from the file.
+ * @param codes Vector to store the codes.
+ * @param audio_length Original length of the audio.
+ * @param n_codebooks Number of codebooks used to encode the audio.
+ * @return bool True if the codes were read successfully, false otherwise.
  */
-std::vector read_codes_from_file(std::string code_path);
+bool read_codes_from_file(
+        std::string code_path,
+        std::vector & codes,
+        uint32_t & audio_length,
+        uint32_t & n_codebooks);
diff --git a/examples/decompress/main.cpp b/examples/decompress/main.cpp
index dfbc904..a21d3b4 100644
--- a/examples/decompress/main.cpp
+++ b/examples/decompress/main.cpp
@@ -25,14 +25,26 @@ int main(int argc, char **argv) {
         return 1;
     }

+    encodec_set_target_bandwidth(ectx, 12);
+
     // read compressed audio from disk
-    std::vector codes = read_codes_from_file(params.input_path);
+    std::vector codes;
+    uint32_t audio_length = 0, n_codebooks = 0;
+    if (!read_codes_from_file(params.input_path, codes, audio_length, n_codebooks)) {
+        printf("%s: error during reading codes\n", __func__);
+        return 1;
+    }

     // decompress audio
-    // TODO: decompress audio
+    if (!encodec_decompress_audio(ectx, codes, params.n_threads)) {
+        printf("%s: error during decompression\n", __func__);
+        return 1;
+    }

     // write reconstructed audio on disk
-    // TODO: write codec output
+    auto & audio_arr = ectx->out_audio;
+    audio_arr.resize(audio_length);
+    write_wav_on_disk(audio_arr, params.output_path);

     // report timing
     {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a8527b8..b07b623 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -29,7 +29,7 @@ int main(int argc, char **argv) {

     // read audio from disk
     std::vector original_audio_arr;
-    if(!read_wav_from_disk(params.input_path, original_audio_arr)) {
+    if (!read_wav_from_disk(params.input_path, original_audio_arr)) {
         printf("%s: error during reading wav file\n", __func__);
         return 1;
     }