ggml : fix YARN + add tests + add asserts (llama/7617)
* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci
ggerganov committed Jun 15, 2024
1 parent ec4f235 commit 3429c39
Showing 10 changed files with 129 additions and 88 deletions.
6 changes: 5 additions & 1 deletion include/ggml/ggml.h
@@ -756,7 +756,6 @@ extern "C" {
GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
@@ -765,6 +764,11 @@ extern "C" {
GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars

GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2

GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

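The three new predicates relax ggml_is_contiguous() by tolerating padding in the lowest dimensions: _1 drops the constraint on the row stride nb[1], _2 additionally drops the one on nb[2]. A minimal sketch of the intended semantics for plain (non-quantized) element types follows; tensor_view and is_contiguous_n are illustrative names, not the actual ggml.c implementation, which also has to account for block-quantized types.

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define MAX_DIMS 4

// Illustrative tensor view: ne[] = extents, nb[] = strides in bytes,
// elem_size = size of one element (block-quantized types are ignored here).
struct tensor_view {
    int64_t ne[MAX_DIMS];
    size_t  nb[MAX_DIMS];
    size_t  elem_size;
};

// "Contiguous for dims >= n": the element stride is packed and every stride
// above dimension n is exactly the stride of the dimension below it times
// that dimension's extent. Dimensions 1..n may carry padding.
// n == 0 matches ggml_is_contiguous(), n == 2 matches ggml_is_contiguous_2().
static bool is_contiguous_n(const struct tensor_view * t, int n) {
    if (t->nb[0] != t->elem_size) {
        return false; // elements within a row must always be packed
    }
    for (int i = 1; i < MAX_DIMS; i++) {
        if (i > n && t->nb[i] != t->nb[i-1] * (size_t) t->ne[i-1]) {
            return false; // gap between consecutive dim-(i-1) slices
        }
    }
    return true;
}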
4 changes: 3 additions & 1 deletion src/ggml-cuda.cu
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
#else
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
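In the batched cuBLAS path, the replaced condition spelled this requirement out by hand. ggml_is_contiguous_2() expresses the same thing (the dim-2 slices of each tensor laid out back to back, so a single stride can describe the whole batch handed to cublasGemmStridedBatchedEx) and additionally requires a packed element stride. Reusing the illustrative tensor_view above, the removed inline check corresponds roughly to:

// Sketch only: the layout the strided-batched GEMM fast path relies on.
// Equivalent to the removed inline check src->nb[2]*src->ne[2] == src->nb[3];
// ggml_is_contiguous_2() also requires nb[0] to equal the element size.
static bool batch_slices_are_packed(const struct tensor_view * t) {
    return t->nb[2] * (size_t) t->ne[2] == t->nb[3];
}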
6 changes: 6 additions & 0 deletions src/ggml-cuda/norm.cu
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

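The asserts added to the three norm kernels document an assumption the launch code already makes: src0 is treated as a dense float array, one packed row per normalized group, so a padded row stride would silently shift every subsequent row. A toy illustration of that addressing pattern (not the kernel code):

#include <stdint.h>

// Each CUDA block reduces one row of a dense, ncols-wide float matrix.
// The flat indexing below is only valid for a contiguous tensor, hence the
// new GGML_ASSERT(ggml_is_contiguous(src0)) before the kernel launches.
static float row_sum(const float * x, int64_t ncols, int64_t row) {
    float sum = 0.0f;
    for (int64_t col = 0; col < ncols; col++) {
        sum += x[row*ncols + col]; // assumes nb[1] == ncols*sizeof(float)
    }
    return sum;
}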
18 changes: 8 additions & 10 deletions src/ggml-cuda/rope.cu
@@ -61,7 +61,7 @@ static __global__ void rope(
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;

const int p = has_pos ? pos[i2] : 0;
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
const dim3 block_nums(nrows, num_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;

if (pos == nullptr) {
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
}
} else {
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
}
}
@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
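Reading the before/after lines of rope_neox, the YaRN breakage was twofold: theta_base already had freq_scale folded in and rope_yarn() scaled by freq_scale a second time, and the rotation index handed to rope_yarn() was the normalized, non-positive cur_rot = inv_ndims*ic - ib rather than the raw dimension index that the correction range corr_dims is expressed in, so the extrapolation/interpolation ramp was evaluated at the wrong point. A self-contained sketch of the corrected per-dimension angle, assuming the standard YaRN formulation that ggml's rope_yarn() implements; yarn_ramp, yarn_theta and their parameters are illustrative names, and attention scaling and freq_factors are omitted:

#include <math.h>

// YaRN ramp over the rotation index i0: 1 below the correction range
// [corr_low, corr_high] (pure extrapolation), 0 above it (pure interpolation).
static float yarn_ramp(float corr_low, float corr_high, float i0) {
    const float y = (i0/2 - corr_low) / fmaxf(0.001f, corr_high - corr_low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

// Angle for rotation index i0 (0, 2, ..., n_dims-2) at position p.
static float yarn_theta(int p, int i0, int n_dims, float freq_base,
                        float freq_scale, float ext_factor,
                        float corr_low, float corr_high) {
    // Extrapolated angle: freq_scale must NOT be applied here - the old kernel
    // folded it into theta_base and rope_yarn() then scaled it once more.
    const float theta_extrap = p * powf(freq_base, -(float) i0 / n_dims);
    const float theta_interp = freq_scale * theta_extrap;

    // The ramp has to see the raw index i0; the old kernel passed the
    // normalized cur_rot instead.
    const float mix = yarn_ramp(corr_low, corr_high, (float) i0) * ext_factor;
    return theta_interp * (1.0f - mix) + theta_extrap * mix;
}

The Metal kernel further down receives the same correction: cur_rot and the separate ib bookkeeping are removed and the raw index ic is passed to rope_yarn() directly.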
4 changes: 3 additions & 1 deletion src/ggml-kompute.cpp
@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
{
GGML_ASSERT(ne00 == ne10);

// TODO: assert that dim2 and dim3 are contiguous
ggml_is_contiguous_2(src0);
ggml_is_contiguous_2(src1);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

8 changes: 7 additions & 1 deletion src/ggml-metal.m
@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_ASSERT(ne00 == ne10);

// TODO: assert that dim2 and dim3 are contiguous
ggml_is_contiguous_2(src0);
ggml_is_contiguous_2(src1);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

@@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));
@@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous(src0));

//float eps;
//memcpy(&eps, dst->op_params, sizeof(float));
@@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_NORM:
{
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

16 changes: 6 additions & 10 deletions src/ggml-metal.metal
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(

const int64_t p = pos[i2];

const float theta_0 = (float)p;
const float theta_base = (float)p;
const float inv_ndims = -1.f/n_dims;

if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_base * pow(freq_base, inv_ndims*i0);

const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
} else {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
const float theta = theta_base * pow(freq_base, inv_ndims*ic);

float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

const int64_t i0 = ib*n_dims + ic/2;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2 changes: 1 addition & 1 deletion src/ggml-sycl.cpp
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,