Skip to content

Commit

Permalink
ggml : refactor rope norm/neox (llama/7634)
Browse files Browse the repository at this point in the history
* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
  • Loading branch information
ggerganov committed Jun 15, 2024
1 parent a524d3f commit f27d7fc
Show file tree
Hide file tree
Showing 10 changed files with 365 additions and 612 deletions.
36 changes: 9 additions & 27 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,6 @@ extern "C" {
// rotary position embedding
// if (mode & 1) != 0, skip n_past elements (NOT SUPPORTED)
// if (mode & 2) != 0, GPT-NeoX style
// if mode & 4 == 1, ChatGLM style
//
// b is an int32 vector with size a->ne[2], it contains the positions
// c is freq factors (e.g. phi3-128k), (optional)
Expand All @@ -1474,17 +1473,15 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx);
int mode);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx);
int mode);

// custom RoPE
GGML_API struct ggml_tensor * ggml_rope_ext(
Expand All @@ -1494,8 +1491,7 @@ extern "C" {
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
Expand All @@ -1511,8 +1507,7 @@ extern "C" {
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
Expand All @@ -1526,8 +1521,7 @@ extern "C" {
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
Expand All @@ -1542,8 +1536,7 @@ extern "C" {
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
Expand All @@ -1552,17 +1545,9 @@ extern "C" {
float beta_slow),
"use ggml_rope_ext_inplace instead");

struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down);

// compute correction dims for YaRN RoPE scaling
GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
Expand All @@ -1573,16 +1558,13 @@ extern "C" {
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base,
bool xpos_down);
float beta_slow);

// clamp
// in-place, returns view(a)
Expand Down
Loading

0 comments on commit f27d7fc

Please sign in to comment.