Skip to content

Commit

Permalink
CUDA: fix broken oob check for FA vec f32 kernel (llama/7904)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler authored and ggerganov committed Jun 15, 2024
1 parent 8436913 commit 91f2ca8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/ggml-cuda/fattn-vec-f32.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_f2[j][i0/WARP_SIZE].x *= scale;
Q_f2[j][i0/WARP_SIZE].y *= scale;
}
Expand Down

0 comments on commit 91f2ca8

Please sign in to comment.