feat: adds ggml_pad_reflect_1d #850

Open · wants to merge 6 commits into base: master
9 changes: 9 additions & 0 deletions include/ggml.h
@@ -480,6 +480,7 @@ extern "C" {
        GGML_OP_CONV_TRANSPOSE_2D,
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,
        GGML_OP_PAD_REFLECT_1D,
        GGML_OP_UPSCALE, // nearest interpolate
        GGML_OP_PAD,
        GGML_OP_ARANGE,
@@ -1642,6 +1643,14 @@ extern "C" {
            struct ggml_tensor  * b,
            int                   stride);

    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   p0,
            int                   p1);

    enum ggml_op_pool {
        GGML_OP_POOL_MAX,
        GGML_OP_POOL_AVG,
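For orientation, a brief usage sketch (not part of this PR): the shape and padding values below mirror the defaults used in test-backend-ops further down, and an existing ggml_context named ctx is assumed.

// Usage sketch: ggml_pad_reflect_1d(ctx, a, p0, p1) returns an F32 tensor
// of shape { p0 + a->ne[0] + p1, a->ne[1] }.
struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 10);
struct ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, /*p0=*/10, /*p1=*/9); // -> 531 x 10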
5 changes: 5 additions & 0 deletions src/ggml-cuda.cu
@@ -29,6 +29,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/padreflect.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"

#include <algorithm>
@@ -2206,6 +2207,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_UPSCALE:
            ggml_cuda_op_upscale(ctx, dst);
            break;
        case GGML_OP_PAD_REFLECT_1D:
            ggml_cuda_op_pad_reflect_1d(ctx, dst);
            break;
        case GGML_OP_PAD:
            ggml_cuda_op_pad(ctx, dst);
            break;
@@ -2844,6 +2848,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_ACC:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_PAD:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
65 changes: 65 additions & 0 deletions src/ggml-cuda/padreflect.cu
@@ -0,0 +1,65 @@
#include "padreflect.cuh"

static __global__ void pad_reflect_1d_f32(const float * x, float * dst,
        const int nb00, const int nb01,
        const int ne10, const int ne11,
        const int p0, const int p1,
        const int inp_size, const int dst_size) {
    // note: inp_size and dst_size are currently unused inside the kernel itself
    const int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index >= ne10 * ne11) {
        return;
    }

    const int row_size     = ne10; // destination row width: p0 + ne00 + p1
    const int row_index    = index / row_size;
    int       column_index = index % row_size;

    if (column_index < p0) {
        // left padding: mirror around the first source element
        column_index = p0 - column_index;
    } else if (column_index < row_size - p1) {
        // interior: direct copy
        column_index = column_index - p0;
    } else {
        // right padding: mirror around the last source element
        column_index = (row_size - p1 - p0) - (p1 + 1 - (row_size - column_index)) - 1;
    }

    const int i00 = column_index;
    const int i01 = row_index;

    dst[index] = *(const float *)((const char *)x + i01*nb01 + i00*nb00);
}

static void pad_reflect_1d_f32_cuda(const float * x, float * dst,
        const int nb00, const int nb01,
        const int ne10, const int ne11,
        const int p0, const int p1,
        const int inp_size, const int dst_size,
        cudaStream_t stream) {
    const int num_blocks = (dst_size + CUDA_PAD_REFLECT_BLOCK_SIZE - 1) / CUDA_PAD_REFLECT_BLOCK_SIZE;

    pad_reflect_1d_f32<<<num_blocks, CUDA_PAD_REFLECT_BLOCK_SIZE, 0, stream>>>(
            x, dst, nb00, nb01, ne10, ne11, p0, p1, inp_size, dst_size);
}

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int inp_size = src0->ne[0] * src0->ne[1];
    const int dst_size =  dst->ne[0] *  dst->ne[1];

    pad_reflect_1d_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1],
            dst->ne[0], dst->ne[1], dst->op_params[0], dst->op_params[1],
            inp_size, dst_size, stream);
}
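As a sanity check on the kernel's index arithmetic, here is a host-side rendering of the same column mapping; the helper name is ours, not the PR's.

// Reference for the kernel's source-column mapping, illustrative only.
static int reflect_src_col(int j, int row_size, int p0, int p1) {
    if (j < p0) {
        return p0 - j;  // left pad: mirror around src[0]
    }
    if (j < row_size - p1) {
        return j - p0;  // interior: direct copy
    }
    // right pad: mirror around src[ne00 - 1], where ne00 = row_size - p0 - p1
    return (row_size - p1 - p0) - (p1 + 1 - (row_size - j)) - 1;
}
// e.g. row_size = 7, p0 = 2, p1 = 1 maps j = 0..6 to 2, 1, 0, 1, 2, 3, 2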
5 changes: 5 additions & 0 deletions src/ggml-cuda/padreflect.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_PAD_REFLECT_BLOCK_SIZE 256

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
91 changes: 89 additions & 2 deletions src/ggml.c
@@ -2695,7 +2695,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -2757,6 +2757,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "argsort(x)",
    "leaky_relu(x)",

    "pad_reflect_1d(x)",

    "flash_attn_ext(x)",
    "flash_attn_back(x)",
    "ssm_conv(x)",
@@ -2783,7 +2785,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -7274,6 +7276,43 @@ struct ggml_tensor * ggml_get_rel_pos(
    return result;
}

// ggml_pad_reflect_1d

struct ggml_tensor * ggml_pad_reflect_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   p0,
        int                   p1) {
    bool is_node = false;

    if (a->grad) {
        GGML_ASSERT(false); // TODO: implement backward
        is_node = true;
    }

    // the padding length on each side must be less than the existing
    // length of the dimension being padded
    GGML_ASSERT(p0 < a->ne[0]);
    GGML_ASSERT(p1 < a->ne[0]);

    // only 2D inputs are supported
    GGML_ASSERT(a->ne[2] == 1);
    GGML_ASSERT(a->ne[3] == 1);

    GGML_ASSERT(ggml_is_contiguous(a));

    const int64_t ne[2] = { p0 + a->ne[0] + p1, a->ne[1] };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);

    int32_t params[] = { p0, p1 };
    ggml_set_op_params(result, params, sizeof(params));

    result->op     = GGML_OP_PAD_REFLECT_1D;
    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;

    return result;
}

// ggml_add_rel_pos

static struct ggml_tensor * ggml_add_rel_pos_impl(
@@ -13409,6 +13448,42 @@ static void ggml_compute_forward_diag_mask_f32(
    }
}

// ggml_compute_forward_pad_reflect_1d

static void ggml_compute_forward_pad_reflect_1d(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];
    GGML_ASSERT(p0 >= 0);
    GGML_ASSERT(p1 >= 0);

    const int ne00 = src0->ne[0];

    const int nb01 = src0->nb[1];

    const int ne0 = dst->ne[0];
    const int ne1 = dst->ne[1];

    const int nb0 = dst->nb[0];
    const int nb1 = dst->nb[1];

    for (int i1 = 0; i1 < ne1; i1++) {
        // first and last interior elements of destination row i1
        float * left  = (float *) ((char *) dst->data + i1*nb1 +         p0*nb0);
        float * right = (float *) ((char *) dst->data + i1*nb1 + (ne0-p1-1)*nb0);

        // copy the source row into the interior of the destination row
        ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i1*nb01));

        // mirror the interior outwards, excluding the edge elements themselves
        for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
        for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
    }
}
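The CPU path works differently from the CUDA kernel: it copies each source row into the interior of the destination row, then mirrors the interior outwards through the left and right edge pointers. A minimal standalone rendering of that loop for a single row, with illustrative values:

// Standalone rendering of the CPU mirroring step, illustrative only.
#include <cstdio>

int main() {
    const int p0 = 2, p1 = 1, ne00 = 4;
    const int ne0 = p0 + ne00 + p1;     // 7
    float src[4] = { 1, 2, 3, 4 };
    float row[7] = { 0 };               // one destination row

    float * left  = row + p0;           // first interior element
    float * right = row + ne0 - p1 - 1; // last interior element

    for (int i = 0; i < ne00; i++) { left[i] = src[i]; }      // copy interior
    for (int i = 1; i <= p0;  i++) { left[-i] = left[i];   }  // mirror left
    for (int i = 1; i <= p1;  i++) { right[i] = right[-i]; }  // mirror right

    for (int i = 0; i < ne0; i++) { printf("%.0f ", row[i]); } // 3 2 1 2 3 4 3
    printf("\n");
    return 0;
}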

static void ggml_compute_forward_diag_mask_inf(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
@@ -16679,6 +16754,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_conv_transpose_2d(params, tensor);
            } break;
        case GGML_OP_PAD_REFLECT_1D:
            {
                ggml_compute_forward_pad_reflect_1d(params, tensor->src[0], tensor);
            } break;
        case GGML_OP_POOL_1D:
            {
                ggml_compute_forward_pool_1d(params, tensor);
@@ -17679,6 +17758,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_PAD_REFLECT_1D:
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_POOL_1D:
            {
                GGML_ASSERT(false); // TODO: not implemented
@@ -18388,6 +18471,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_PAD_REFLECT_1D:
            {
                n_tasks = 1; // the CPU implementation is single-threaded
            } break;
        case GGML_OP_POOL_1D:
        case GGML_OP_POOL_2D:
            {
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
@@ -359,6 +359,13 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)

# test-pad-reflect-1d

set(TEST_TARGET test-pad-reflect-1d)
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)

#
# test-rel-pos

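The new tests/test-pad-reflect-1d.cpp itself is not shown in this diff view. A minimal sketch of what such a test could contain, assuming the standard ggml context and graph helpers (the PR's actual file may differ):

// Hypothetical sketch of tests/test-pad-reflect-1d.cpp; the PR's real file is not shown here.
#include "ggml.h"
#include <cassert>
#include <cstdio>

int main() {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1);
    for (int i = 0; i < 4; i++) { ggml_set_f32_1d(a, i, (float)(i + 1)); }

    // reflect padding: [1, 2, 3, 4] with p0 = 2, p1 = 1 -> [3, 2, 1, 2, 3, 4, 3]
    struct ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, 2, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    const float expected[7] = { 3, 2, 1, 2, 3, 4, 3 };
    for (int i = 0; i < 7; i++) {
        assert(ggml_get_f32_1d(out, i) == expected[i]);
    }
    printf("test-pad-reflect-1d: OK\n");

    ggml_free(ctx);
    return 0;
}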
28 changes: 28 additions & 0 deletions tests/test-backend-ops.cpp
@@ -1535,6 +1535,29 @@ struct test_acc : public test_case {
    }
};

// GGML_OP_PAD_REFLECT_1D
struct test_pad_reflect_1d : public test_case {
    const ggml_type type;
    const std::array<int64_t, 2> ne_a;
    const int pad_0;
    const int pad_1;

    std::string vars() override {
        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
    }

    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 2> ne_a = {512, 10},
            int pad_0 = 10, int pad_1 = 9)
        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a   = ggml_new_tensor(ctx, type, 2, ne_a.data());
        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
        return out;
    }
};

// GGML_OP_PAD
struct test_pad : public test_case {
    const ggml_type type;
@@ -2363,6 +2386,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    test_cases.emplace_back(new test_group_norm());
    test_cases.emplace_back(new test_acc());
    test_cases.emplace_back(new test_pad());
    test_cases.emplace_back(new test_pad_reflect_1d());
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {1024, 1}, 5, 6));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {   2, 1}, 1, 0));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {  10, 1}, 7, 3));

    test_cases.emplace_back(new test_arange());
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());