Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cuda implementation for ggml_conv_transpose_1d #854

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -2338,6 +2339,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
case GGML_OP_POOL_2D:
ggml_cuda_op_pool2d(ctx, dst);
break;
Expand Down
91 changes: 91 additions & 0 deletions src/ggml-cuda/conv-transpose-1d.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "conv-transpose-1d.cuh"

static __global__ void conv_transpose_1d_kernel(
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, const float * src1, float * dst) {
int global_index = threadIdx.x + blockIdx.x * blockDim.x;
if (global_index >= output_size) {
return;
}

int out_index = global_index / dst_ne0;

float accumulator = 0;

for (int c = 0; c < src0_ne2; c++)
{

int idx = global_index % dst_ne0;

int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
int input_offset = src1_ne0 * c;

int initial_weight_idx = idx > src0_ne0 -1 ? src0_ne0-1 : idx;

for (int i = 0; i < src1_ne0; i++)
{
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0))
{
continue;
}
int weight_idx = idx - i*s0;


float kernel_weight = src0[kernel_offset + weight_idx];
float input_value = src1[input_offset+i];

accumulator += kernel_weight * input_value;
}
}
dst[global_index] = accumulator;
}

static void conv_transpose_1d_f32_f32_cuda(
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, const float * src1, float * dst,
cudaStream_t stream) {

const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(s0,p0,d0,output_size,
src0_ne0, src0_ne1, src0_ne2, src0_ne3,
src1_ne0, src1_ne1, src1_ne2, src1_ne3,
dst_ne0, dst_ne1, dst_ne2, dst_ne3,
src0,src1, dst);
}

void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;

float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also assert contiguous src0 and src1


const int32_t * opts = (const int32_t *)dst->op_params;

const int s0 = dst->op_params[0];
const int p0 = 0;//opts[3];
const int d0 = 1;//opts[4];

const int64_t kernel_size = ggml_nelements(src0);
const int64_t input_size = ggml_nelements(src1);
const int64_t output_size = ggml_nelements(dst);


conv_transpose_1d_f32_f32_cuda( s0,p0,d0,output_size,
src0->ne[0],src0->ne[1],src0->ne[2],src0->ne[3],
src1->ne[0],src1->ne[1],src1->ne[2],src1->ne[3],
dst->ne[0],dst->ne[1],dst->ne[2],dst->ne[3],
src0_d, src1_d, dst_d, stream);
}
5 changes: 5 additions & 0 deletions src/ggml-cuda/conv-transpose-1d.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256

void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8 changes: 8 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,14 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)

# test-conv-transpose-1d

set(TEST_TARGET test-conv-transpose-1d)
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)


#
# test-dup

Expand Down
37 changes: 36 additions & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,38 @@ struct test_pool2d : public test_case {
}
};

// GGML_OP_CONV_TRANSPOSE_1D
struct test_conv_transpose_1d : public test_case {

const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;

// stride
const int s0;
// padding
const int p0;
// dilation
const int d0;

std::string vars() override {
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
}

test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
int s0 = 1, int p0 = 0, int d0 = 1)
: ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
return out;
}
};



// GGML_OP_IM2COL
struct test_im2col : public test_case {
const ggml_type type_input;
Expand All @@ -1229,7 +1261,7 @@ struct test_im2col : public test_case {
// padding
const int p0;
const int p1;
// dilatation
// dilation
const int d0;
const int d1;
// mode
Expand Down Expand Up @@ -2026,6 +2058,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));

test_cases.emplace_back(new test_conv_transpose_1d());


test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
Expand Down