-
Notifications
You must be signed in to change notification settings - Fork 940
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: cuda implementation for ggml_conv_transpose_1d
#854
Open
balisujohn
wants to merge
13
commits into
ggerganov:master
Choose a base branch
from
balisujohn:dev-conv-transpose-1d-cuda
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+858
−1
Open
Changes from 1 commit
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
70de8b7
conv transpose 1d passing test for 1d input and kernel
balisujohn f6883de
working for different input and output channel counts, added test for…
balisujohn f35d3ec
initial draft appears to work with stride other than 1
balisujohn 53a4fcf
working with all old and new conv1d tests
balisujohn f3bb758
added a test for large tensors
balisujohn 7eff0ab
removed use cuda hardcoding
balisujohn 152e04e
restored test-conv-transpose.c
balisujohn 5d39cd4
removed unused arugments, and fixed bug where test failure would caus…
balisujohn 2e7445e
fixed accumulator bug
balisujohn ed3b788
added test to test-backend-ops
balisujohn da3d0d1
fixed mistake
balisujohn 6fd70a2
addressed review
balisujohn 708f48f
fixed includes
balisujohn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
added a test for large tensors
- Loading branch information
commit f3bb7580a97328ff770d5c7689f6d10db920797e
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,11 +11,10 @@ static __global__ void conv_transpose_1d_kernel( | |
if (global_index >= output_size) { | ||
return; | ||
} | ||
//printf("idx: %d stride %d\n", global_index,s0); | ||
|
||
int out_index = global_index / dst_ne0; | ||
|
||
dst[global_index] = 0; | ||
int accumulator = 0; | ||
|
||
for (int c = 0; c < src0_ne2; c++) | ||
{ | ||
|
@@ -25,33 +24,8 @@ static __global__ void conv_transpose_1d_kernel( | |
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); | ||
int input_offset = src1_ne0 * c; | ||
|
||
if(global_index == 0 && output_size == 12) | ||
{ | ||
printf("idx: %d ???: %d\n", global_index,src0_ne2); | ||
|
||
printf("idx: %d kernel offset: %d\n", global_index,kernel_offset); | ||
printf("idx: %d input offset: %d\n", global_index,input_offset); | ||
} | ||
|
||
int upper_bound = idx > src1_ne0-1 ? src1_ne0-1 : (int)(idx/s0)*s0; //inclusive | ||
/* | ||
int upper_bound = 0; | ||
while (upper_bound < idx){ | ||
upper_bound +=1; | ||
}*/ | ||
|
||
|
||
int lower_bound = idx - src0_ne0 + 1 >= 0 ? (int)(idx/s0)*s0 - src0_ne0 + 1 : 0; | ||
|
||
int initial_weight_idx = idx > src0_ne0 -1 ? src0_ne0-1 : idx; | ||
|
||
if(global_index == 0 && output_size == 12) | ||
{ | ||
printf("idx: %d initial_weight_idx: %d\n", global_index,initial_weight_idx); | ||
printf("idx: %d upper bound: %d\n", global_index, upper_bound); | ||
printf("idx: %d lower bound: %d\n", global_index, lower_bound); | ||
} | ||
|
||
for (int i = 0; i < src1_ne0; i++) | ||
{ | ||
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) | ||
|
@@ -60,32 +34,14 @@ static __global__ void conv_transpose_1d_kernel( | |
} | ||
int weight_idx = idx - i*s0; | ||
|
||
|
||
if(global_index == 0 && output_size == 12) | ||
{ | ||
//printf("idx: %d partial sum: %d x %d \n", global_index,src0[kernel_offset + (initial_weight_idx-(i-lower_bound))] , src1[input_offset+i]); | ||
//printf("idx: %d kernel_index: %d\n", global_index, kernel_offset + (initial_weight_idx-(i-lower_bound))); | ||
//printf("idx: %d input_index: %d\n", global_index, initial_weight_idx-(i-lower_bound)); | ||
|
||
//printf("idx: %d input_index: %d\n", global_index, input_offset+i); | ||
|
||
} | ||
int test1 = src0[kernel_offset + weight_idx]; | ||
int test2 = src1[input_offset+i]; | ||
if(global_index == 0 && output_size == 12) | ||
{ | ||
//printf("idx: %d partial sum: %d x %d \n", global_index,src0[kernel_offset + (initial_weight_idx-(i-lower_bound))] , src1[input_offset+i]); | ||
//printf("idx: %d kernel_index: %d\n", global_index, kernel_offset + (initial_weight_idx-(i-lower_bound))); | ||
//printf("idx: %d input_index: %d\n", global_index, initial_weight_idx-(i-lower_bound)); | ||
|
||
//printf("idx: %d input_index: %d\n", global_index, input_offset+i); | ||
printf("idx: %d test: %d x %d\n", global_index, test1, test2); | ||
|
||
} | ||
dst[global_index] += test1 * test2; | ||
|
||
int kernel_weight = src0[kernel_offset + weight_idx]; | ||
int input_value = src1[input_offset+i]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
accumulator += kernel_weight * input_value; | ||
} | ||
//dst[idx] = 7; | ||
} | ||
dst[global_index] = accumulator; | ||
} | ||
|
||
static void conv_transpose_1d_f32_f32_cuda( | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int
->float
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah that seems to fix the issue I was experiencing; it's bizarre that the tests still passed even in cuda mode with the types accidentally set to
int
Thanks so much!