Skip to content

Commit

Permalink
metal : soft max, tanh, supports_op fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Dec 12, 2023
1 parent b9a77fa commit 1914017
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
11 changes: 8 additions & 3 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
return op->ne[0] % 4 == 0;
}
return op->ne[3] == 1;
} break;
default:
return false;
}
Expand Down Expand Up @@ -931,7 +931,10 @@ void ggml_metal_graph_compute(
} break;
}

GGML_ASSERT(ggml_metal_supports_op(dst));
if (!ggml_metal_supports_op(dst)) {
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(!"unsupported op");
}

const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
Expand Down Expand Up @@ -1326,6 +1329,8 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
Expand Down
17 changes: 10 additions & 7 deletions src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ kernel void kernel_relu(
}

kernel void kernel_tanh(
device const float4 * src0,
device float4 * dst,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
device const float & x = src0[tpig];
dst[tpig] = precise::tanh(x);
}

Expand Down Expand Up @@ -367,7 +367,7 @@ kernel void kernel_soft_max(
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
Expand Down Expand Up @@ -404,6 +404,7 @@ kernel void kernel_soft_max(
pdst[i00] = exp_psrc0;
}

threadgroup_barrier(mem_flags::mem_threadgroup);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
Expand Down Expand Up @@ -447,9 +448,9 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = -INFINITY;
Expand Down Expand Up @@ -487,6 +488,7 @@ kernel void kernel_soft_max_4(
}

const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
threadgroup_barrier(mem_flags::mem_threadgroup);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
Expand Down Expand Up @@ -693,6 +695,7 @@ kernel void kernel_group_norm(
tmp += src0[j];
}

threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
Expand Down

0 comments on commit 1914017

Please sign in to comment.