Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cuda implementation for ggml_conv_transpose_1d #854

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
removed unused arguments, and fixed bug where test failure would caus…
…e subsequent tests to fail
  • Loading branch information
balisujohn committed Jun 14, 2024
commit 5d39cd4da5d70410a61c1acf4d961d6f68544a6e
10 changes: 4 additions & 6 deletions src/ggml-cuda/conv-transpose-1d.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "conv-transpose-1d.cuh"

static __global__ void conv_transpose_1d_kernel(
const int s0, const int p0, const int d0,
const int kernel_size, const int input_size, const int output_size,
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
Expand Down Expand Up @@ -45,16 +44,15 @@ static __global__ void conv_transpose_1d_kernel(
}

static void conv_transpose_1d_f32_f32_cuda(
const int s0, const int p0, const int d0,
const int kernel_size, const int input_size, const int output_size,
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, const float * src1, float * dst,
cudaStream_t stream) {

const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(s0,p0,d0,kernel_size, input_size, output_size,
conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(s0,p0,d0,output_size,
src0_ne0, src0_ne1, src0_ne2, src0_ne3,
src1_ne0, src1_ne1, src1_ne2, src1_ne3,
dst_ne0, dst_ne1, dst_ne2, dst_ne3,
Expand Down Expand Up @@ -85,7 +83,7 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
const int64_t output_size = ggml_nelements(dst);


conv_transpose_1d_f32_f32_cuda( s0,p0,d0,kernel_size, input_size, output_size,
conv_transpose_1d_f32_f32_cuda( s0,p0,d0,output_size,
src0->ne[0],src0->ne[1],src0->ne[2],src0->ne[3],
src1->ne[0],src1->ne[1],src1->ne[2],src1->ne[3],
dst->ne[0],dst->ne[1],dst->ne[2],dst->ne[3],
Expand Down
12 changes: 7 additions & 5 deletions tests/test-conv-transpose-1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ int main(void)

printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_0), passed && (ggml_nelements(conv1d_transpose_res_0) == n_conv_transpose_1d_test_0) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");

passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_1; i++) {
if(
conv1d_transpose_data_1[i] != expected_conv1d_1[i]) {
Expand All @@ -603,7 +604,7 @@ int main(void)

printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_1), passed && (ggml_nelements(conv1d_transpose_res_1) == n_conv_transpose_1d_test_1) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");


passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_2; i++) {
if(
conv1d_transpose_data_2[i] != expected_conv1d_2[i]) {
Expand All @@ -617,7 +618,7 @@ int main(void)
printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_2), passed && (ggml_nelements(conv1d_transpose_res_2) == n_conv_transpose_1d_test_2) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");



passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_3; i++) {
if(
conv1d_transpose_data_3[i] != expected_conv1d_3[i]) {
Expand All @@ -630,7 +631,7 @@ int main(void)

printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_3), passed && (ggml_nelements(conv1d_transpose_res_3) == n_conv_transpose_1d_test_3) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");


passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_4; i++) {
if(
conv1d_transpose_data_4[i] != expected_conv1d_4[i]) {
Expand All @@ -643,6 +644,7 @@ int main(void)

printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_4), passed && (ggml_nelements(conv1d_transpose_res_4) == n_conv_transpose_1d_test_4) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");

passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_5; i++) {
if(
conv1d_transpose_data_5[i] != expected_conv1d_5[i]) {
Expand All @@ -655,7 +657,7 @@ int main(void)

printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_5), passed && (ggml_nelements(conv1d_transpose_res_5) == n_conv_transpose_1d_test_5) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");


passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_6; i++) {
if(
conv1d_transpose_data_6[i] != expected_conv1d_6[i]) {
Expand All @@ -670,7 +672,7 @@ int main(void)
printf("ggml_conv_1d_transpose (%d): %s\n", (int) ggml_nelements(conv1d_transpose_res_6), passed && (ggml_nelements(conv1d_transpose_res_6) == n_conv_transpose_1d_test_6) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");



passed = true;
for(int i = 0; i < n_conv_transpose_1d_test_7; i++) {
if(
fabs(conv1d_transpose_data_7[i] - expected_conv1d_7[i])/fabs(expected_conv1d_7[i]) > .000001) {
Expand Down