Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync : whisper.cpp #823

Merged
merged 14 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/whisper/grammar-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ namespace grammar_parser {
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
}

// apply transformation to previous symbol (last_sym_start to end) according to
Expand Down
50 changes: 46 additions & 4 deletions examples/whisper/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct whisper_params {
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float grammar_penalty = 100.0f;
float temperature = 0.0f;
float temperature_inc = 0.2f;

bool speed_up = false;
bool debug_mode = false;
Expand Down Expand Up @@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
Expand Down Expand Up @@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
Expand Down Expand Up @@ -471,6 +477,38 @@ char *escape_double_quotes_and_backslashes(const char *str) {
return escaped;
}

// double quote should be escaped by another double quote. (rfc4180)
char *escape_double_quotes_in_csv(const char *str) {
if (str == NULL) {
return NULL;
}

size_t escaped_length = strlen(str) + 1;

for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"') {
escaped_length++;
}
}

char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL) {
return NULL;
}

size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"') {
escaped[pos++] = '"';
}
escaped[pos++] = str[i];
}

// no need to set zero due to calloc() being used prior

return escaped;
}

bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
Expand All @@ -492,7 +530,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
char * text_escaped = escape_double_quotes_and_backslashes(text);
char * text_escaped = escape_double_quotes_in_csv(text);

//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << "," << 10 * t1 << ",";
Expand Down Expand Up @@ -1068,14 +1106,16 @@ int main(int argc, char ** argv) {

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

wparams.suppress_regex = params.suppress_regex.c_str();
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();

wparams.initial_prompt = params.prompt.c_str();

wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
wparams.temperature = params.temperature;

wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;

Expand Down Expand Up @@ -1193,7 +1233,9 @@ int main(int argc, char ** argv) {
}
}

whisper_print_timings(ctx);
if (!params.no_prints) {
whisper_print_timings(ctx);
}
whisper_free(ctx);

return 0;
Expand Down
61 changes: 18 additions & 43 deletions examples/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <regex>
#include <random>
#include <functional>
#include <codecvt>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
Expand Down Expand Up @@ -147,7 +148,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
} \
} while (0)

//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
Expand Down Expand Up @@ -1951,32 +1951,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

// ------

#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);

struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Kcur,
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);

struct ggml_tensor * V =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));

struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Expand All @@ -1994,9 +1968,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);

struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);

struct ggml_tensor * V =
ggml_cpy(ctx0,
Expand All @@ -2009,7 +1981,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
);

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
#endif

struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

cur = ggml_cpy(ctx0,
Expand Down Expand Up @@ -2406,12 +2378,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);

//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);

struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);

struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
Expand Down Expand Up @@ -2900,11 +2867,13 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_step);
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
int n_fft = 1 + (frame_size / 2);
std::vector<float> fft_out(2 * frame_size);
int n_fft = filters.n_fft;
int i = ith;

// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
assert(n_fft == 1 + (frame_size / 2));

// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
Expand All @@ -2923,7 +2892,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>

// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
for (int j = 0; j < frame_size; j++) {
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}

Expand Down Expand Up @@ -3394,8 +3363,14 @@ struct whisper_context_params whisper_context_default_params() {

struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);

#ifdef _MSC_VER
// Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring path_model_wide = converter.from_bytes(path_model);
auto fin = std::ifstream(path_model_wide, std::ios::binary);
#else
auto fin = std::ifstream(path_model, std::ios::binary);
#endif
if (!fin) {
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
return nullptr;
Expand Down Expand Up @@ -5272,7 +5247,7 @@ int whisper_full_with_state(
// basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
WHISPER_LOG_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
return 0;
}

Expand Down
2 changes: 1 addition & 1 deletion scripts/sync-whisper.last
Original file line number Diff line number Diff line change
@@ -1 +1 @@
8f253ef3af1c62c04316ba4afa7145fc4d701a8c
d8356a1cc2b95218a808425b50734007fd13aa00
7 changes: 7 additions & 0 deletions src/ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,16 @@ extern "C" {
#ifndef __F16C__
#define __F16C__
#endif
#endif

// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__
#define __SSE3__
#endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif

// 16-bit float
Expand Down
4 changes: 2 additions & 2 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,7 @@ static enum ggml_status ggml_metal_graph_compute(
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
Expand All @@ -1387,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
}
} else {
while (nth < ne00 && nth < 1024) {
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
Expand Down
2 changes: 1 addition & 1 deletion src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -11425,7 +11425,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void

vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);
qh += 2;
vector bool short vsel = vec_cmpge(qxh, (vector signed short)v0);
vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);

vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);

Expand Down
2 changes: 2 additions & 0 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
#define GGML_F16_VEC_FMA GGML_F32x4_FMA
#define GGML_F16_VEC_ADD GGML_F32x4_ADD
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
// Use vec_xl, not vec_ld, in case the load address is not aligned.
#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
Expand Down