Skip to content

Commit

Permalink
ggml : refactor unary ops (#405)
Browse files Browse the repository at this point in the history
* Add gitignore rule for temporary vim files

* ggml: refactor implementation of unary ops

* backends : adapt to ggml_unary_op

* ggml : fix assert number of ops

* ggml : hide ggml_set_unary_op from public API

---------

Co-authored-by: izdane <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
3 people authored Jul 23, 2023
1 parent 7b55e12 commit be4c8ba
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 453 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ zig-out/
zig-cache/

*.dot

*.sw?
50 changes: 31 additions & 19 deletions examples/mnist/main-mtl.m
Original file line number Diff line number Diff line change
Expand Up @@ -340,22 +340,32 @@ int mnist_mtl_eval(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(gf->nodes[i]);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(gf->nodes[i]);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
fprintf(stderr, "%s: node %3d, op = %8s, unary op %d not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), (int) ggml_get_unary_op(gf->nodes[i]));
GGML_ASSERT(false);
return -1;
}
break;
} break;
case GGML_OP_SOFT_MAX:
{
Expand Down Expand Up @@ -435,9 +445,11 @@ int mnist_mtl_eval(
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
return -1;
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
return -1;
}
}
}

Expand Down
43 changes: 30 additions & 13 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,16 +330,6 @@ extern "C" {
GGML_OP_ARGMAX,
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_ABS,
GGML_OP_SGN,
GGML_OP_NEG,
GGML_OP_STEP,
GGML_OP_TANH,
GGML_OP_ELU,
GGML_OP_RELU,
GGML_OP_GELU,
GGML_OP_GELU_QUICK,
GGML_OP_SILU,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
Expand Down Expand Up @@ -378,6 +368,8 @@ extern "C" {
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,

GGML_OP_UNARY,

GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,

Expand All @@ -391,6 +383,18 @@ extern "C" {
GGML_OP_COUNT,
};

enum ggml_unary_op {
GGML_UNARY_OP_ABS,
GGML_UNARY_OP_SGN,
GGML_UNARY_OP_NEG,
GGML_UNARY_OP_STEP,
GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
};

// ggml object
struct ggml_object {
Expand Down Expand Up @@ -535,6 +539,7 @@ extern "C" {

GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);

GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);

Expand Down Expand Up @@ -618,9 +623,11 @@ extern "C" {
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);

GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);

//
// operations on tensors with backpropagation
Expand Down Expand Up @@ -1285,6 +1292,16 @@ extern "C" {
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);

GGML_API struct ggml_tensor * ggml_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_unary_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
29 changes: 17 additions & 12 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3908,18 +3908,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_mul;
break;
case GGML_OP_GELU:
if (!any_on_device) {
return false;
}
func = ggml_cuda_gelu;
break;
case GGML_OP_SILU:
if (!any_on_device) {
return false;
}
func = ggml_cuda_silu;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
if (!any_on_device) {
return false;
}
func = ggml_cuda_gelu;
break;
case GGML_UNARY_OP_SILU:
if (!any_on_device) {
return false;
}
func = ggml_cuda_silu;
break;
default:
return false;
} break;
case GGML_OP_NORM:
if (!any_on_device) {
return false;
Expand Down
96 changes: 53 additions & 43 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -519,48 +519,56 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
} break;
case GGML_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
Expand Down Expand Up @@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
}

Expand Down
Loading

0 comments on commit be4c8ba

Please sign in to comment.