diff --git a/src/ggml-metal.m b/src/ggml-metal.m index f9fd8dc8c..6b538fd39 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -855,8 +855,8 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { - return op->ne[0] % 4 == 0; - } + return op->ne[3] == 1; + } break; default: return false; } @@ -931,7 +931,10 @@ void ggml_metal_graph_compute( } break; } - GGML_ASSERT(ggml_metal_supports_op(dst)); + if (!ggml_metal_supports_op(dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ASSERT(!"unsupported op"); + } const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; @@ -1326,6 +1329,8 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 092a1f599..54b3b8a16 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -252,10 +252,10 @@ kernel void kernel_relu( } kernel void kernel_tanh( - device const float4 * src0, - device float4 * dst, + device const float * src0, + device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; + device const float & x = src0[tpig]; dst[tpig] = precise::tanh(x); } @@ -367,7 +367,7 @@ kernel void kernel_soft_max( const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 ? src1 + i01*ne00 : nullptr; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max @@ -404,6 +404,7 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } + threadgroup_barrier(mem_flags::mem_threadgroup); float sum = simd_sum(lsum); if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { @@ -447,9 +448,9 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float4 lmax4 = -INFINITY; @@ -487,6 +488,7 @@ kernel void kernel_soft_max_4( } const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + threadgroup_barrier(mem_flags::mem_threadgroup); float sum = simd_sum(lsum); if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { @@ -693,6 +695,7 @@ kernel void kernel_group_norm( tmp += src0[j]; } + threadgroup_barrier(mem_flags::mem_threadgroup); tmp = simd_sum(tmp); if (ntg > N_SIMDWIDTH) { if (sgitg == 0) {