Skip to content

Commit

Permalink
sync : llama.cpp (ggml_scale, ggml_row_size, ggml_mul_mat_set_prec) (#…
Browse files Browse the repository at this point in the history
…662)

* sync : llama.cpp (ggml_scale, ggml_row_size, ggml_mul_mat_set_prec)

ggml-ci

* ggml : add comment about backward GGML_OP_DIAG_MASK_INF (#4203)

* llama : fix platforms without mmap (#4578)

* llama : fix platforms without mmap

* win32 : limit prefetch size to the file size

* fix win32 error clobber, unnecessary std::string in std::runtime_error

* ggml-alloc : fix ggml_tallocr_is_own

* whisper : minor

* ggml : cuda jetson + arm quants warnings

ggml-ci

---------

Co-authored-by: Herman Semenov <[email protected]>
Co-authored-by: slaren <[email protected]>
  • Loading branch information
3 people committed Dec 22, 2023
1 parent c80e07e commit 845d01b
Show file tree
Hide file tree
Showing 33 changed files with 1,244 additions and 729 deletions.
41 changes: 20 additions & 21 deletions examples/dolly-v2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,34 +192,34 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
ctx_size += ggml_row_size(wtype, n_embd*n_vocab); // wte

ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g
//ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
ctx_size += ggml_row_size(wtype, n_embd*n_vocab); // lmh_g
//ctx_size += ggml_row_size(GGML_TYPE_F32, n_vocab); // lmh_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd*n_embd)); // c_attn_proj_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

ctx_size += (6 + 16*n_layer)*512; // object overhead

Expand Down Expand Up @@ -580,8 +580,7 @@ bool dollyv2_eval(
struct ggml_tensor * KQ_scaled =
ggml_scale_inplace(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
1.0f/sqrt(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
Expand Down
46 changes: 20 additions & 26 deletions examples/gpt-2/main-alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,33 +165,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += ggml_row_size(wtype, n_vocab*n_embd); // wte
ctx_size += ggml_row_size(GGML_TYPE_F32 , n_ctx*n_embd); // wpe
ctx_size += ggml_row_size(wtype, n_vocab*n_embd); // lm_head

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_proj_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

ctx_size += (6 + 12*n_layer)*512; // object overhead

Expand Down Expand Up @@ -427,12 +427,6 @@ struct ggml_cgraph * gpt2_graph(
}
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(allocr, KQ_scale);
if (!ggml_allocr_is_measure(allocr)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}

// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
Expand Down Expand Up @@ -528,7 +522,7 @@ struct ggml_cgraph * gpt2_graph(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
KQ_scale);
1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
Expand Down
9 changes: 1 addition & 8 deletions examples/gpt-2/main-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,6 @@ struct ggml_cgraph * gpt2_graph(
}
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(allocr, KQ_scale);
if (!ggml_allocr_is_measure(allocr)) {
float s = 1.0f/sqrtf(float(n_embd)/n_head);
ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
}

// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
Expand Down Expand Up @@ -586,7 +579,7 @@ struct ggml_cgraph * gpt2_graph(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
KQ_scale);
1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
Expand Down
43 changes: 18 additions & 25 deletions examples/gpt-2/main-batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,30 +237,30 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
buffer_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
buffer_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
buffer_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
buffer_size += ggml_row_size(wtype, n_vocab*n_embd); // wte
buffer_size += ggml_row_size(GGML_TYPE_F32, n_ctx*n_embd); // wpe
buffer_size += ggml_row_size(wtype, n_vocab*n_embd); // lm_head

buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
buffer_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
buffer_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b

buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
buffer_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b

buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
buffer_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
buffer_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b

buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
buffer_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_proj_b

buffer_size += (6 + 12*n_layer)*128; // alignment overhead

Expand Down Expand Up @@ -599,13 +599,6 @@ struct ggml_cgraph * gpt2_graph(
}
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(allocr, KQ_scale);
if (!ggml_allocr_is_measure(allocr)) {
float s = 1.0f/sqrtf(float(n_embd)/n_head);
ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
}

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_set_name(KQ_mask, "KQ_mask");
Expand Down Expand Up @@ -720,7 +713,7 @@ struct ggml_cgraph * gpt2_graph(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
KQ_scale);
1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// [n_kv, n_tokens, 12]
Expand Down
44 changes: 20 additions & 24 deletions examples/gpt-2/main-ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,33 +164,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += ggml_row_size(wtype, n_vocab*n_embd); // wte
ctx_size += ggml_row_size(GGML_TYPE_F32, n_ctx*n_embd); // wpe
ctx_size += ggml_row_size(wtype, n_vocab*n_embd); // lm_head

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_proj_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

ctx_size += (6 + 12*n_layer)*512; // object overhead

Expand Down Expand Up @@ -531,11 +531,7 @@ bool gpt2_eval(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale_inplace(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f/sqrt(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
Expand Down
17 changes: 3 additions & 14 deletions examples/gpt-2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ struct gpt2_model {
// inputs/constants
struct ggml_tensor * embd;
struct ggml_tensor * position;
struct ggml_tensor * KQ_scale;
};

void init_backends(gpt2_model & model, const gpt_params & params) {
Expand Down Expand Up @@ -511,14 +510,12 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{
model.embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);
model.position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);
model.KQ_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // FIXME: should be in backend_kv, but also shouldn't matter

ggml_set_name(model.embd, "in/embd");
ggml_set_name(model.position, "in/position");
ggml_set_name(model.KQ_scale, "KQ_scale");

// add input tensors to cpu backend
size_t input_size = ggml_nbytes(model.embd) + ggml_nbytes(model.position) + ggml_nbytes(model.KQ_scale);
size_t input_size = ggml_nbytes(model.embd) + ggml_nbytes(model.position);

// FIXME: use cpu backend after sched impl
ggml_backend_t backend_input = params.n_gpu_layers >= model.hparams.n_layer ? backend_gpu : backend_cpu;
Expand All @@ -529,12 +526,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_input);
ggml_allocr_alloc(alloc, model.embd);
ggml_allocr_alloc(alloc, model.position);
ggml_allocr_alloc(alloc, model.KQ_scale);
ggml_allocr_free(alloc);

// initialize KQ_scale
float s = 1.0f/sqrtf(float(model.hparams.n_embd)/model.hparams.n_head);
ggml_backend_tensor_set(model.KQ_scale, &s, 0, sizeof(s));
}

return true;
Expand Down Expand Up @@ -588,7 +580,7 @@ struct ggml_cgraph * gpt2_graph(
}
//}

struct ggml_tensor * KQ_scale = model.KQ_scale;
const float KQ_scale = 1.0f/sqrtf(float(model.hparams.n_embd)/model.hparams.n_head);

// wte + wpe
struct ggml_tensor * inpL =
Expand Down Expand Up @@ -697,10 +689,7 @@ struct ggml_cgraph * gpt2_graph(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
KQ_scale);
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
ggml_format_name(KQ_scaled, "l%d.KQ_scaled", il);

// KQ_masked = mask_past(KQ_scaled)
Expand Down
Loading

0 comments on commit 845d01b

Please sign in to comment.