ggml-backend update: buffer types, backend registry, graph compare, tests #620

Merged: 35 commits merged into master from backend-v-next on Nov 30, 2023

Conversation

@slaren (Collaborator) commented Nov 23, 2023

  • Buffer types
  • Backend registry
  • Graph copy & compare between backends
  • ggml_backend_alloc_ctx_tensors
  • ggml_unary_op_name and ggml_op_desc
  • Backend buffer tests
  • Backend op tests
  • Update Metal backend
  • CUDA multiple device support

Buffer types

  • Buffers are no longer tied to a backend instance
  • A buffer may be used by multiple backends. On systems with unified memory or integrated GPUs, this allows falling back to the CPU without copies
  • Backends may expose several buffer types, including buffers intended for use with other backends. For example, ggml-cuda adds a pinned host buffer type that can be used with the CPU backend for faster transfers of inputs and outputs between system and device memory (see the sketch below)
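A rough sketch of the idea, assuming a ggml_backend_cuda_host_buffer_type() getter for the pinned host buffer type (the exact name is an assumption):

// allocate a pinned host buffer through the CUDA backend's buffer type;
// the CPU backend can use it directly, and host<->device transfers from it are faster
ggml_backend_buffer_type_t host_buft = ggml_backend_cuda_host_buffer_type(); // assumed name
ggml_backend_buffer_t host_buf = ggml_backend_buft_alloc_buffer(host_buft, input_size);
// ... allocate the input tensors in host_buf, fill them on the CPU,
//     then compute with either the CPU or the CUDA backend ...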

Backend registry

Allows enumerating and initializing the available backends.

// enumerate backends
for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
    printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(ggml_backend_reg_get_default_buffer_type(i), size);
    ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
    // ...
}

// initialize a backend by name
ggml_backend_t backend = ggml_backend_reg_init_backend_from_str("CUDA0"); // cuda device 0

Graph copy & compare between backends

Copy a graph to a different backend and evaluate it on both, one op at a time. A callback can be used to compare the results of each operation.

// graph is allocated on the CUDA backend: copy and compare with CPU backend
auto eval_callback = [](int index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    if (t1->type == GGML_TYPE_F32) {
        std::vector<float> v1(GGML_PAD(ggml_nbytes(t1), sizeof(float)) / sizeof(float));
        std::vector<float> v2(GGML_PAD(ggml_nbytes(t2), sizeof(float)) / sizeof(float));

        ggml_backend_tensor_get(t1, v1.data(), 0, ggml_nbytes(t1));
        ggml_backend_tensor_get(t2, v2.data(), 0, ggml_nbytes(t2));

        for (size_t i = 0; i < v1.size(); ++i) {
            if (fabsf(v1[i] - v2[i]) > 1e0) {
                printf("[%d] %s [%s]: mismatch at index %zu: %f != %f\n", index, t1->name, ggml_op_desc(t1), i, v1[i], v2[i]);
                return false; // stop graph eval on first error
            }
        }
    }
    return true;
};

ggml_backend_compare_graph_backend(model.backend, backend_cpu, gf, eval_callback, nullptr);

ggml_backend_alloc_ctx_tensors

Allocates all the tensors in a context in a backend buffer in one step. Ref: #578

model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
// etc

// allocate the model tensors in a backend buffer
model.buffer_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);

ggml_unary_op_name and ggml_op_desc

ggml_op_desc can be used as a replacement for ggml_op_name that also returns the name of the unary op when the op is GGML_OP_UNARY.
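A minimal usage sketch, iterating over a graph gf as in the examples above:

// for a GGML_OP_UNARY node, ggml_op_name(t->op) only returns "UNARY",
// while ggml_op_desc(t) resolves the specific unary op, e.g. "GELU" or "RELU"
for (int i = 0; i < gf->n_nodes; i++) {
    struct ggml_tensor * t = gf->nodes[i];
    printf("node %3d: %s (%s)\n", i, ggml_op_desc(t), t->name);
}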

CUDA multiple device support

ggml_backend_cuda_init takes an int device parameter that specifies the CUDA device to use. To use the default device, pass 0. Each device is registered as a different backend in the backend registry, with names CUDA0 for device 0, CUDA1 for device 1, and so on.
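For example, to use the second device either directly or through the registry:

// initialize CUDA device 1 directly ...
ggml_backend_t cuda1 = ggml_backend_cuda_init(1);
// ... or by name through the backend registry
ggml_backend_t cuda1_reg = ggml_backend_reg_init_backend_from_str("CUDA1");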

Backend op tests

Each op implemented in the backends is tested against the CPU backend. Tensors are initialized with random data, and the result of the op is compared using a normalized MSE. New ops implemented in the backends should add a test to test-backend-ops.cpp. Only F16/F32 are tested for now; quantized types are not yet supported in the tests.
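For reference, one plausible definition of the normalized MSE used for the comparison (the exact formula in test-backend-ops.cpp may differ slightly):

// normalized MSE between a backend result `a` and the CPU reference `b`
static double nmse(const float * a, const float * b, size_t n) {
    double err = 0.0, ref = 0.0;
    for (size_t i = 0; i < n; i++) {
        double d = (double) a[i] - (double) b[i];
        err += d * d;
        ref += (double) b[i] * (double) b[i];
    }
    return err / ref;
}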

@FSSRepo (Collaborator) commented Nov 23, 2023

Nice update. Normally, one has to create the graph before every computation; it cannot be reused, and simply assigning IDs to the input tensors wouldn't suffice.

// current pattern: the graph has to be rebuilt before every computation
ggml_allocr_reset(compute_alloc);

struct ggml_cgraph * gf = build_graph(z, decode_graph);
ggml_allocr_alloc_graph(compute_alloc, gf);

if (ggml_backend_is_cpu(backend)) {
    ggml_backend_cpu_set_n_threads(backend, n_threads);
}

ggml_backend_graph_compute(backend, gf);

ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result));

It would be nicer to build gf once and only rebind the inputs, something like this (ggml_set_input with an input ID is hypothetical):

ggml_allocr_reset(compute_alloc);
ggml_set_input(gf, compute_alloc, Z_INPUT, z);

ggml_allocr_alloc_graph(compute_alloc, gf);

if (ggml_backend_is_cpu(backend)) {
    ggml_backend_cpu_set_n_threads(backend, n_threads);
}

// gf is created once; it can be the same graph used to measure the buffer size
ggml_backend_graph_compute(backend, gf);

ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result));

ggml_backend_alloc_ctx_tensors(ctx, model.backend) should skip tensors that are already allocated, for example when assigning a constant tensor such as the attention query scale:

// constant tensor: allocate it once when initializing the model params
attn_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
ggml_allocr_alloc(alloc, attn_scale);
float scale = 1.0f / sqrt((float) d_head);
ggml_backend_tensor_set(attn_scale, &scale, 0, sizeof(scale));

// to avoid assert(data == NULL):
// allocate all the remaining tensors linked to this context
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
    if (t->data == NULL) {
        ggml_allocr_alloc(alloc, t);
    }
}

ggml_backend_compare_graph_backend is a wonderful function, thank you. It will be very useful for debugging dequantization kernel issues: for some reason, in stable diffusion, quantizations work fine for me in the CPU backend but not in CUDA. At some point I start to get NaNs, and I have evidence that it used to work before, but it stopped working after some change in the CUDA backend.

If possible, it could be beneficial for parameter tensors to be dynamically passed to a backend (dynamic offloading) as needed during computation.

@slaren (Collaborator, Author) commented Nov 24, 2023

Nice update. Normally, one has to create the graph before every computation; it cannot be reused, and simply assigning IDs to the input tensors wouldn't suffice.

I think that what you mean here is that graphs used for measure cannot be re-used for computation, which is not very good for graphs that only need to be evaluated once. I think we could solve this by assigning offsets rather than absolute addresses to the tensors (i.e. in the data member); then it would be possible to replace the buffer used to allocate the graph without having to update every tensor. However, this would break compatibility with the CPU backend, and it would require significant changes to ggml.c, which is something that we have tried to avoid for now. Maybe in the future we can re-evaluate this, but it is not going to happen soon.

ggml_backend_alloc_ctx_tensors(ctx, model.backend) should skip tensors that are already allocated, for example when assigning a constant tensor such as the attention query scale.

I will do this.

ggml_backend_compare_graph_backend is a wonderful function, thank you. It will be very useful for debugging dequantization kernel issues: for some reason, in stable diffusion, quantizations work fine for me in the CPU backend but not in CUDA.

The goal is indeed to help debug issues with the backends. I already found some differences in the GELU op between the CPU and CUDA backends while testing this with the gpt-2 example, and I am sure that there are more cases. I will try to add automated tests for all the ops.

If possible, it could be beneficial for parameter tensors to be dynamically passed to a backend (dynamic offloading) as needed during computation.

I don't really understand what you mean by this. If you prefer to use Spanish, you can send me a message at the email address in my GitHub profile.

@FSSRepo (Collaborator) commented Nov 24, 2023

some differences in the GELU op

Comment out #define GGML_GELU_FP16 in ggml.c and try again.
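That is, in ggml.c (disabling the define should make the CPU backend compute GELU in full F32 precision instead of using the F16 lookup table):

//#define GGML_GELU_FP16   // commented out: use the exact F32 GELU on the CPU backend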

@slaren (Collaborator, Author) commented Nov 28, 2023

@ggerganov is the GitHub CI capable of using the Metal backend? I suspect that it is failing somewhere during ggml_metal_init.

19: Test command: /Users/runner/work/ggml/ggml/build/bin/test-backend-buffer
19: Working Directory: /Users/runner/work/ggml/ggml/build/tests
19: Environment variables: 
19:  LLVM_PROFILE_FILE=test-backend-buffer.profraw
19: Test timeout computed to be: 900
19: ggml_backend_register: registered backend CPU
19: ggml_backend_register: registered backend Metal
19: ggml_metal_init: allocating
19: ggml_metal_init: found device: Apple Paravirtual device
19: ggml_metal_init: picking default device: Apple Paravirtual device
19: ggml_metal_init: default.metallib not found, loading from source
19: ggml_metal_init: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd
19: ggml_metal_init: loading 'ggml-metal.metal'
19/19 Test #19: test-backend-buffer ..............***Exception: SegFault  0.04 sec
Errors while running CTest

@ggerganov (Owner) commented:

Hm, seems it does not work, but it does not print any meaningful error.
Probably we should just disable it for now and use ggml-ci for test-backend-buffer?

@slaren (Collaborator, Author) commented Nov 28, 2023

I am not sure what is happening either, it seems that the log is missing the output from stdout too. Let's disable it for now then, thanks for looking into it.

@slaren (Collaborator, Author) commented Nov 29, 2023

@FSSRepo can you help me define a few test cases for im2col? I need different parameter values to test; currently it is tested with these (the default values in the constructor):

// GGML_OP_IM2COL
struct test_im2col : public test_case {
    const ggml_type type_a;
    const ggml_type type_b;
    const std::array<int64_t, 4> ne_a;
    const std::array<int64_t, 4> ne_b;
    const int s0;
    const int s1;
    const int p0;
    const int p1;
    const int d0;
    const int d1;
    const bool is_2D;

    std::string vars() override {
        return VARS_TO_STR11(type_a, type_b, ne_a, ne_b, s0, s1, p0, p1, d0, d1, is_2D);
    }

    test_im2col(ggml_type type_a = GGML_TYPE_F16, ggml_type type_b = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
            std::array<int64_t, 4> ne_b = {10, 10, 10, 10},
            int s0 = 1, int s1 = 1,
            int p0 = 0, int p1 = 0,
            int d0 = 1, int d1 = 1,
            bool is_2D = false)
        : type_a(type_a), type_b(type_b), ne_a(ne_a), ne_b(ne_b), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne_a.data());
        ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne_b.data());
        ggml_tensor * out = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, is_2D);
        return out;
    }
};

slaren marked this pull request as ready for review on November 29, 2023, 16:47
@FSSRepo (Collaborator) commented Nov 29, 2023

My mistake, I gave you the wrong order of the tensors. Wait a little bit.

@FSSRepo (Collaborator) commented Nov 29, 2023

Fixed:

// GGML_OP_IM2COL
struct test_im2col : public test_case {
    const ggml_type type_input;
    const ggml_type type_kernel;
    const std::array<int64_t, 4> ne_input;
    const std::array<int64_t, 4> ne_kernel;
    // stride
    const int s0;
    const int s1;
    // padding
    const int p0;
    const int p1;
    // dilation
    const int d0;
    const int d1;
    // mode
    const bool is_2D;

    std::string vars() override {
        return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
    }

    test_im2col(ggml_type type_input = GGML_TYPE_F16, ggml_type type_kernel = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
            int s0 = 1, int s1 = 1,
            int p0 = 1, int p1 = 1,
            int d0 = 1, int d1 = 1,
            bool is_2D = true)
        : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
        return out;
    }
};

@slaren (Collaborator, Author) commented Nov 29, 2023

I had to swap the kernel and input types, but otherwise it works.
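Presumably the constructor defaults then become something like this (ggml_im2col expects the kernel operand in F16; this reconstruction is an assumption, not the committed test):

    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1},  // [kernel_width, kernel_height, input_channels, 1]
            int s0 = 1, int s1 = 1, int p0 = 1, int p1 = 1, int d0 = 1, int d1 = 1,
            bool is_2D = true)
        : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel),
          s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}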

@slaren (Collaborator, Author) commented Nov 29, 2023

The ggml-ci is also failing with Metal because it cannot find ggml-metal.metal; what would be the best way to fix that?

It is still going to fail after that because some ops are broken in Metal when the number of columns is not a multiple of 4 and there are no checks for that. Additionally, both CUDA and Metal have broken support for broadcasting in add and mul, but I think we should keep these tests enabled to remind us to fix them.

@ggerganov (Owner) commented:

 Will take a detailed look tomorrow. Great work on the tests - I was just thinking about something like this. Much needed.

@ggerganov (Owner) commented:

Regarding multiple CUDA devices: is it fully functional at this point? Is there going to be a conflict around the shared memory pool if it is used by more than one device or thanks to the lock everything should be good?

With this infrastructure, if let's say I wanted to implement an alternative backend for an existing GPU (for example, ggml-cuda-2.cu or ggml-metal-2.m), I can just register it as a new backend and both implementations should be usable together in a single binary, correct?

@slaren (Collaborator, Author) commented Nov 30, 2023

I wouldn't say that it is fully functional already because I haven't run enough tests to be confident that everything works correctly, but it should work. The memory pool is already per-device, so that shouldn't cause issues. There are still most definitely synchronization issues in the CUDA backend, so it is not possible to use two devices simultaneously (e.g. in different threads). It should work for splitting a model across multiple devices with ggml_backend_sched, but I have not tested it, and there are still some optimizations that could be done (mainly, copies between devices will have to happen as device->cpu->device in ggml-backend, rather than just doing a cudaMemcpy).
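(As a rough, untested sketch of what splitting a model across devices with ggml_backend_sched could look like; the names reflect the ggml-backend API around the time of this PR and may have changed since, and build_graph is a hypothetical helper:)

// collect the backends to split across, in order of priority
ggml_backend_t backends[3] = {
    ggml_backend_cuda_init(0),
    ggml_backend_cuda_init(1),
    ggml_backend_cpu_init(),
};
ggml_backend_sched_t sched = ggml_backend_sched_new(backends, 3);
// measure once with a worst-case graph, then compute per-evaluation graphs
ggml_backend_sched_init_measure(sched, build_graph(/* worst case */));
ggml_backend_sched_graph_compute(sched, build_graph(/* inputs */));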

With this infrastructure, if let's say I wanted to implement an alternative backend for an existing GPU (for example, ggml-cuda-2.cu or ggml-metal-2.m), I can just register it as a new backend and both implementations should be usable together in a single binary, correct?

Yes, I don't see why that wouldn't work, except maybe for the few hooks that some backends (cuda/opencl) have in ggml.c. Ultimately the goal is to be able to use a single binary for all the backends, and detect the availability at runtime.

@ggerganov (Owner) commented Nov 30, 2023

How come none of the MUL_MAT tests pass for the Metal backend:

Error: MUL_MAT: NMSE = 5.200507
  MUL_MAT(type_a=f32,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[1,1]): FAIL
Error: MUL_MAT: NMSE = 5.312463
  MUL_MAT(type_a=f32,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[2,1]): FAIL
Error: MUL_MAT: NMSE = 7.459601
  MUL_MAT(type_a=f32,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[1,2]): FAIL
Error: MUL_MAT: NMSE = 7.485240
  MUL_MAT(type_a=f32,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[2,2]): FAIL
Error: MUL_MAT: NMSE = 5.304665
  MUL_MAT(type_a=f16,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[1,1]): FAIL
Error: MUL_MAT: NMSE = 5.417071
  MUL_MAT(type_a=f16,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[2,1]): FAIL
Error: MUL_MAT: NMSE = 7.481416
  MUL_MAT(type_a=f16,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[1,2]): FAIL
Error: MUL_MAT: NMSE = 7.437984
  MUL_MAT(type_a=f16,type_b=f32,m=32,n=32,k=32,bs=[10,10],nr=[2,2]): FAIL

Edit: got it - dim[3] is not supported

@ggerganov (Owner) left a review comment:
Great stuff! Merge at will

slaren merged commit 38f46af into master on Nov 30, 2023
7 of 9 checks passed
slaren deleted the backend-v-next branch on November 30, 2023, 18:03
CCLDArjun pushed a commit to CCLDArjun/ggml that referenced this pull request Dec 18, 2023