
ggml: new gpu kernels + extends ggml_leaky_relu + ggml_pad #621

Merged
32 commits merged on Dec 13, 2023

Conversation

FSSRepo
Collaborator

FSSRepo commented Nov 26, 2023

The purpose of this PR is to synchronize the changes I made in ggml while working on a PR for the stable-diffusion.cpp project. This adds new CUDA kernels that could help other projects fully support different backends.

New CUDA kernels

  • ggml_concat
  • ggml_upscale
  • ggml_gelu_quick
  • ggml_group_norm
  • ggml_acc
  • ggml_tanh
  • ggml_leaky_relu

New Operation

  • ggml_pad: adds zero padding, equivalent to PyTorch's pad. Needed in stable-diffusion.cpp. A minimal usage sketch follows below.
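A minimal usage sketch, assuming the signature introduced by this PR is ggml_pad(ctx, a, p0, p1, p2, p3), where p0..p3 are the number of zeros appended at the end of each dimension (shapes here are just illustrative):

// pad an activation by one zero column and one zero row
struct ggml_tensor * x      = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 320, 1);
struct ggml_tensor * padded = ggml_pad(ctx, x, 1, 1, 0, 0);
// padded->ne == { 65, 65, 320, 1 }; the added elements are zero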

Tasks:

  • Make Metal kernels, using the existing CUDA kernels as a reference. @ggerganov
    • ggml_concat
    • ggml_upscale
    • ggml_gelu_quick
    • ggml_group_norm
    • ggml_acc
    • ggml_tanh
    • ggml_leaky_relu
    • ggml_pad
  • Research a better way to implement the ggml_group_norm kernel; the current one is definitely very inefficient. @slaren can you help me?

FSSRepo marked this pull request as draft November 26, 2023 20:55
@slaren
Collaborator

slaren commented Nov 27, 2023

As far as I can tell, the CPU implementations of add and mul already supported broadcasting in all dimensions except the first. Is that not enough for clip/sd? I am not sure why the CPU implementation needed to be changed, this seems more restrictive since it only supports broadcasting in one dimension.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

As far as I can tell, the CPU implementations of add and mul already supported broadcasting in all dimensions except the first. Is that not enough for clip/sd? I am not sure why the CPU implementation needed to be changed, this seems more restrictive since it only supports broadcasting in one dimension.

I ran some broadcasting tests on the ggml_add and ggml_mul functions. Previously, it only worked with the first dimension of b (which had to have the same number of elements as the first dimension of a, thus repeating all rows). This didn't work in stable diffusion, since it needs to repeat rows in dimension 3, and the original implementation crashed with an assert if I didn't use repeat.

@slaren
Collaborator

slaren commented Nov 27, 2023

The CUDA implementation didn't support broadcasting other than by repeating the entire tensor, so I think that would only work when broadcasting the highest dimension, but not any other. The CPU implementation supported broadcasting along all dimensions except for the first one:

ggml/src/ggml.c

Lines 6860 to 6862 in a5e4560

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

Unless sd required broadcasting in dimension 0, it should already work with the CPU backend, right? I think we should look into implementing this in a similar way for CUDA, it would be simpler and would work with any dimensions.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

it should already work with the CPU backend, right

I will undo my broadcasting changes to the CPU backend and test it with stable diffusion, assuming that it supports broadcasting in all dimensions as you mentioned. I will also need to change "ggml_can_repeat_rows" to "ggml_can_repeat" since that assert is what is preventing me from using ggml_add without using ggml_repeat. I hope it works and we can reach a consensus.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

Unless sd required broadcasting in dimension 0

You must repeat the elements in dimensions 0 and 1, since the bias has as many elements as the output channels of the ggml_conv_2d operation.

Result of ggml_conv_2d: [out_width, out_height, out_channels, N]

Bias tensor should be reshaped from [out_channels] to [1, 1, out_channels, 1]

But the previous implementation of broadcasting expected a tensor of shape [out_width, 1, 1, 1] to work correctly. A sketch of the shapes involved is below.
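A sketch of the shapes, using the existing ggml_conv_2d and ggml_reshape_4d API (the actual stable-diffusion.cpp code differs; this only illustrates why dimensions 0 and 1 must broadcast):

// conv output: [out_width, out_height, out_channels, N]
struct ggml_tensor * h = ggml_conv_2d(ctx, kernel, x, 1, 1, 1, 1, 1, 1);

// bias: [out_channels] -> [1, 1, out_channels, 1]
struct ggml_tensor * b = ggml_reshape_4d(ctx, bias, 1, 1, bias->ne[0], 1);

// needs broadcasting in dimensions 0 and 1; without it, b first has to be
// expanded to the full shape of h with ggml_repeat(ctx, b, h)
h = ggml_add(ctx, h, b);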

@slaren
Collaborator

slaren commented Nov 27, 2023

ggml_can_repeat_rows is the same as ggml_can_repeat, but without broadcasting in dimension 0:

ggml/src/ggml.c

Lines 2152 to 2156 in a5e4560

static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
}

So, if ggml_can_repeat_rows fails for sd, it likely means that it requires broadcasting in dimension 0, which is not currently supported.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

@slaren

.9><lora:lcm-lora:1>beautiful anime girl, short hair, red hair, red eyes, realistic, masterpiece, azur lane, 4k, high quality" --sampling-method lcm --cfg-scale 1 --steps 5 -t 1 -s 424354
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6
[INFO]  stable-diffusion.cpp:4432 - loading model from 'AnythingV5_v5PrtRE-f16.gguf'
[INFO]  stable-diffusion.cpp:4460 - Stable Diffusion 1.x | AnythingV5_v5PrtRE.safetensors
[INFO]  stable-diffusion.cpp:4468 - model data type: f16
[INFO]  stable-diffusion.cpp:4638 - total memory buffer size = 1877.33MB (clip 236.18MB, unet 1641.16MB, vae 0.00MB)
[INFO]  stable-diffusion.cpp:4640 - loading model from 'AnythingV5_v5PrtRE-f16.gguf' completed, taking 1.56s
[INFO]  stable-diffusion.cpp:4664 - running in eps-prediction mode
[INFO]  stable-diffusion.cpp:3911 - loading taesd from 'taesd-model.gguf'
[INFO]  stable-diffusion.cpp:3990 - taesd model loaded
[INFO]  stable-diffusion.cpp:5505 - img2img 512x512
[INFO]  stable-diffusion.cpp:5509 - target t_enc is 3 steps
[INFO]  stable-diffusion.cpp:4005 - loading LoRA from 'Kana_Arima-10.gguf'
[INFO]  stable-diffusion.cpp:4031 - LoRA Type: regular | Kana_Arima-10.safetensors
[INFO]  stable-diffusion.cpp:4051 - LoRA data type: f16
[INFO]  stable-diffusion.cpp:4748 - lora 'Kana_Arima-10' applied, taking 0.22s
[INFO]  stable-diffusion.cpp:4005 - loading LoRA from 'lcm-lora.gguf'
[INFO]  stable-diffusion.cpp:4031 - LoRA Type: regular | lcm_lora.safetensors
[INFO]  stable-diffusion.cpp:4051 - LoRA data type: f16
[INFO]  stable-diffusion.cpp:4748 - lora 'lcm-lora' applied, taking 0.50s
[INFO]  stable-diffusion.cpp:5545 - apply_loras completed, taking 0.72s
GGML_ASSERT: C:\proyectos\stable-diffusion.cpp\ggml\src\ggml.c:3185: ggml_are_same_shape(a, b) || ggml_can_repeat_rows(b, a)

@slaren
Collaborator

slaren commented Nov 27, 2023

Right, we need to implement support for broadcasting dimension 0, but that can be done by extending the current code. The hard part is doing it without either duplicating large amounts of code or severely impacting performance even for non-broadcast cases. This would be much easier to do with C++ templates.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

My current implementation works in the most common cases (clip, llama, stable diffusion, and others), although I can undo my changes and keep waiting for an implementation that adapts to any case, which I see as very difficult and as something that will only hurt performance.

Forcing ggml_add to broadcast in stable diffusion, changing ggml_can_repeat_rows to ggml_can_repeat:

output

@slaren
Collaborator

slaren commented Nov 27, 2023

This should be enough to add broadcasting in dimension 0 to add_f32:

diff --git a/src/ggml.c b/src/ggml.c
index 7069542..a0f76c9 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -6897,7 +6897,8 @@ static void ggml_compute_forward_add_f32(
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

             for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);

                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }

The only issue is that the additional modulus per column may be too costly.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

It doesn't work. Besides, you made that change in the branch that handles the case where tensor b (a.k.a. src1) is not contiguous, and in stable diffusion all tensors that reach ggml_add are contiguous.

@slaren
Collaborator

slaren commented Nov 27, 2023

The principle is the same, just replace ggml_vec_add_f32 with a loop and calculate the index of src1 as i0 % ne10.
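For reference, a sketch of what that looks like in the contiguous branch of ggml_compute_forward_add_f32 (variable names follow the existing code; this is just an illustration, not the final patch):

// instead of: ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
for (int i0 = 0; i0 < ne00; i0++) {
    const int64_t i10 = i0 % ne10; // wrap around src1 in dimension 0
    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}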

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

The principle is the same, just replace ggml_vec_add_f32 with a loop and calculate the index of src1 as i0 % ne10.

It works. I applied the same change to ggml_mul and get correct results, but it takes about 2 seconds longer.

@slaren
Collaborator

slaren commented Nov 27, 2023

Something like this might be a bit more efficient in some cases at least (for the contiguous case):

            for (int r0 = 0; r0 < ne0 / ne10; ++r0) {
                ggml_vec_add_f32(ne10, dst_ptr + r0*ne10, src0_ptr + r0*ne10, src1_ptr);
            }

@ggerganov
Owner

Do you need to broadcast just bias tensors? If so, you can do what we do in whisper.cpp:

https://github.com/ggerganov/whisper.cpp/blob/641f2f42823affb6e5c471b63317deefb0b6e3e9/whisper.cpp#L1620-L1640

Let me know if it is not clear
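For context, the whisper.cpp pattern at that link expands the bias explicitly with ggml_repeat when building the graph, roughly like this (paraphrased, not the exact lines linked above):

// b has shape [1, 1, out_channels, 1]; cur has the full conv output shape
cur = ggml_add(ctx0, ggml_repeat(ctx0, b, cur), cur);

This keeps the element-wise ops simple at the cost of materializing the repeated bias, which is what the next comment objects to.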

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

@ggerganov the memory cost 💀: for small models it is fine, but for stable diffusion it would be 500MB of bias tensors.

@slaren
Collaborator

slaren commented Nov 27, 2023

It would be good to support full broadcasting for ease of use regardless.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

Well, then what should I do? I honestly didn't intend to delve into a super complete implementation of broadcasting in this pull request. It was just meant to be a solution for the most common cases, easy to implement in the backends, and to expedite the optimal adaptation of other models. But it seems we've reached a deadlock.

The second reason I created this pull request was to discuss how to optimize the group_norm kernel. For some reason, when I perform simultaneous summations across multiple threads, I get an incorrect value.

@slaren
Collaborator

slaren commented Nov 27, 2023

Have you tested the performance with the change I suggested in #621 (comment)? I don't think the implementation of broadcasting in this PR can be merged as is; it removes functionality and is more complex than it needs to be.

@slaren
Collaborator

slaren commented Nov 27, 2023

The second reason I created this pull request was to discuss how to optimize the group_norm kernel. For some reason, when I perform simultaneous summations across multiple threads, I get an incorrect value.

The group_norm kernel needs to be rewritten to work in the same way as, e.g., the rms_norm kernel: each step needs to be distributed across all the threads in a warp, followed by a reduction with __shfl_xor_sync or similar.
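A minimal sketch of that pattern, assuming one warp per group for simplicity (with more warps per block an extra shared-memory step is needed, as in the flash-attention kernel further down this thread); group_start and group_size are hypothetical names:

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

// each thread accumulates a partial sum over a strided slice of the group,
// then the warp reduction leaves the group total in every lane
float sum = 0.0f;
for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
    sum += x[group_start + i];
}
sum = warp_reduce_sum(sum);
const float mean = sum / group_size;
// repeat the same reduction for the sum of squared deviations to get the variance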

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

@slaren I can consider reverting the broadcasting changes to ggml.c, applying the modification you have suggested, which allows broadcasting for dimension 0. However, for the CUDA backend, it will remain unchanged for now.

@slaren
Collaborator

slaren commented Nov 27, 2023

However, for the CUDA backend, it will remain unchanged.

The CUDA backend also needs to be updated. I can help you with that, but it should be easier because the additional modulus is unlikely to affect performance significantly since it's all parallelized anyway. And if needed, we can create multiple versions of the kernels, with and without broadcasting, using templates.
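As an illustration of how small the CUDA-side change could be, a sketch of an element-wise add kernel that broadcasts src1 over all dimensions with the same modulus trick (kernel name and launch setup are hypothetical, not the existing ggml-cuda code, and both tensors are assumed contiguous):

static __global__ void k_add_bcast_f32(const float * src0, const float * src1, float * dst,
        int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    const int64_t n = (int64_t) ne0*ne1*ne2*ne3;
    if (i >= n) {
        return;
    }
    // unflatten the dst index, then wrap each coordinate into the src1 shape
    const int i0 = i % ne0;
    const int i1 = (i / ne0) % ne1;
    const int i2 = (i / (ne0*ne1)) % ne2;
    const int i3 =  i / (ne0*ne1*ne2);
    const int64_t j = (((int64_t)(i3 % ne13)*ne12 + i2 % ne12)*ne11 + i1 % ne11)*ne10 + i0 % ne10;
    dst[i] = src0[i] + src1[j];
}

If the extra divisions ever show up in profiles, the kernel can be templated on a broadcast flag so the non-broadcast path keeps indexing src1 with i directly.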

@ggerganov
Owner

@slaren I can consider reverting the broadcasting changes to ggml.c, applying the modification you have suggested, which allows broadcasting for dimension 0. However, for the CUDA backend, it will remain unchanged for now.

However, for the CUDA backend, it will remain unchanged.

The CUDA backend also needs to be updated. I can help you with that, but it should be easier because the additional modulus is unlikely to affect performance significantly since it's all parallelized anyway. And if needed, we can create multiple versions of the kernels, with and without broadcasting, using templates.

Ok, let's proceed like this. I will try to implement the Metal kernels.

Btw, it would be useful to have some sort of unit tests for these kinds of changes. You seem to be testing with SD, but I don't have it set up. We should make a simple test that runs 2D and 3D broadcast ggml_add and ggml_mul with and without ggml_repeat() and compares the results.
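A sketch of such a test as a standalone program against the CPU backend, assuming the public ggml C API (ggml_new_graph, ggml_graph_compute_with_ctx); test-backend-ops can express the same check more compactly:

#include "ggml.h"
#include <math.h>
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // a: [8, 8, 4, 1], b: bias-like [1, 1, 4, 1] -> broadcast over dims 0 and 1
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 8, 4, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, 4, 1);
    for (int i = 0; i < ggml_nelements(a); i++) { ((float *) a->data)[i] = 0.1f*i; }
    for (int i = 0; i < ggml_nelements(b); i++) { ((float *) b->data)[i] = 1.0f + i; }

    struct ggml_tensor * c0 = ggml_add(ctx, a, b);                      // implicit broadcast
    struct ggml_tensor * c1 = ggml_add(ctx, a, ggml_repeat(ctx, b, a)); // explicit repeat

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c0);
    ggml_build_forward_expand(gf, c1);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    float max_diff = 0.0f;
    for (int i = 0; i < ggml_nelements(c0); i++) {
        max_diff = fmaxf(max_diff, fabsf(((float *) c0->data)[i] - ((float *) c1->data)[i]));
    }
    printf("max abs diff (broadcast vs repeat): %g\n", max_diff);

    ggml_free(ctx);
    return 0;
}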

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

Each step needs to be distributed across all the threads in a warp, followed by a reduction with __shfl_xor_sync or similar.

I will try to rewrite the kernel using that. I will need to create a little example to understand how __shfl_xor_sync works and thus determine how the implementation should look.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 27, 2023

Btw, it would be useful to have some sort of unit tests with these kind of changes. You seem to be testing with SD, but I don't have it setup. We should make a simple test that runs 2D and 3D broadcast ggml_add and ggml_mul with and without ggml_repeat() and compare the results.

I'm going to add some tests that I have done.

@slaren
Collaborator

slaren commented Nov 27, 2023

I will need to create a little example to understand how __shfl_xor_sync works

This article is pretty good: https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/. It's a bit outdated because these primitives now have the _sync suffix and synchronize the threads automatically, so the calls to __syncwarp are no longer needed, but otherwise it is all the same.

@FSSRepo
Collaborator Author

FSSRepo commented Nov 30, 2023

@slaren @ggerganov

I have created this CUDA kernel for flash attention that fuses softmax(QK^T)V into a single kernel. I expected that launching it with many threads would increase performance, but it is definitely saturating the bandwidth and performing poorly. I will ask in the original flash attention repository to see if anyone can tell me what I am doing wrong.

CUDA Kernel

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
    }
    return x;
}

#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 1024

template<int block_size>
static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale,
        int d_head, int seq_len, int num_heads) {
        const int head = blockIdx.x / seq_len;
        const int head_size = d_head * seq_len;
        const int s = blockIdx.x % seq_len;
        const int tid = threadIdx.x;

        extern __shared__  char work_data[];
        float* S = (float*)work_data; // theoretical max sequence length: 12848, due to the shared memory per block limit
        float* warp_data = (float*)(work_data + seq_len * sizeof(float));

        // QK^T
        for(int is = tid; is < seq_len; is += block_size) {
                S[is] = 0.0f;
                int key_offset = is * d_head + head * head_size;
                int query_offset = s * d_head + head * head_size;
                for(int d = 0; d < d_head; d++) {
                        S[is] += k[key_offset + d] * q[query_offset + d];
                }
                S[is] *= kq_scale;
        }

        __syncthreads();

        float max_val = -INFINITY;
        // get the max
        for(int is = tid; is < seq_len; is += block_size) {
                max_val = fmaxf(max_val , S[is]);
        }

        max_val = warp_reduce_max(max_val);
        { // get max from all threads
            int warp_id = threadIdx.x / WARP_SIZE;
            int lane_id = threadIdx.x % WARP_SIZE;
            if (lane_id == 0) {
                warp_data[warp_id] = max_val;
            }
            __syncthreads();
            max_val = warp_data[lane_id];
            max_val = warp_reduce_max(max_val);
        }

        // softmax(QK^T)
        float sum = 0.0f;
        for(int is = tid; is < seq_len;is += block_size) {
                const float val = expf(S[is] - max_val);
                S[is] = val;
                sum += val;
        }

        sum = warp_reduce_sum(sum);
        { // sum partials
            int warp_id = threadIdx.x / WARP_SIZE;
            int lane_id = threadIdx.x % WARP_SIZE;
            if (lane_id == 0) {
                warp_data[warp_id] = sum;
            }
            __syncthreads();
            sum = warp_data[lane_id];
            sum = warp_reduce_sum(sum);
        }

        float inv_sum = 1.0f / sum;
        for(int is = tid; is < seq_len; is += block_size) {
                S[is] *= inv_sum;
        }

        __syncthreads();
        // softmax(QK^T)V
        for (int d = tid; d < d_head; d += block_size) {
                int dst_index = d + s * d_head + head * head_size;
                int value_offset = d * seq_len +   head * head_size;
                dst[dst_index] = 0.0f;
                for(int ic = 0; ic < seq_len; ic++) {
                        dst[dst_index] += v[value_offset + ic] * S[ic];
                }
        }
}

Launcher

static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
    int num_blocks = num_heads * seq_len;
    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
            q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}

@FSSRepo
Collaborator Author

FSSRepo commented Dec 12, 2023

The results between backends don't match, and 2 seconds per iteration seems too slow for an M3 Max on the Metal backend, although I don't know much about how Metal works.

@slaren
Collaborator

slaren commented Dec 12, 2023

I am not sure that we should expect identical images between backends. I tried reducing the max mse, but didn't see anything clearly wrong. The largest error usually comes from the matrix multiplications, but that's not different with CUDA.

@FSSRepo
Collaborator Author

FSSRepo commented Dec 12, 2023

I know identical results are not expected between backends like CPU and CUDA, but the difference is usually just small artifacts in the images, not a complete change like this.

@slaren
Collaborator

slaren commented Dec 12, 2023

Is this closer? I found a bug in the tanh kernel.

output

@slaren
Collaborator

slaren commented Dec 12, 2023

@ggerganov I have been trying the Metal debugging flags, and I found some issues running MTL_DEBUG_LAYER=1 MTL_SHADER_VALIDATION=1 MTL_SHADER_VALIDATION_REPORT_TO_STDERR=1 MTL_SHADER_VALIDATION_FAIL_MODE=allow build/bin/test-backend-ops -b Metal

  • Metal does not like not passing src1 to soft_max. I couldn't figure out how to explicitly pass NULL, so instead I am passing src0 and checking for src1 == src0 instead of NULL in the kernel
  • tanh uses float4 when it should be using float, leading to memory errors
  • Some matrix multiplications fail with these flags enabled, but without any validation errors being reported:
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): NMSE = 2.678957 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): NMSE = 2.857147 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): NMSE = 3.158050 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): NMSE = 2.982662 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]): NMSE = 3.052191 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]): NMSE = 3.040786 FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]): NMSE = 3.000345 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): NMSE = 3.793017 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): NMSE = 3.030888 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): NMSE = 2.961799 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): NMSE = 3.067143 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]): NMSE = 3.019470 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]): NMSE = 3.042080 FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]): NMSE = 2.989152 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]): NMSE = 1.124740 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]): NMSE = 0.961504 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]): NMSE = 1.031065 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]): NMSE = 1.016941 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]): NMSE = 1.005911 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]): NMSE = 0.996707 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): NMSE = 0.833259 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): NMSE = 0.984862 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): NMSE = 0.985815 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): NMSE = 1.011241 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]): NMSE = 0.986323 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]): NMSE = 1.002740 FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]): NMSE = 1.004022 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]): NMSE = 1.121132 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]): NMSE = 1.085911 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]): NMSE = 0.950441 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]): NMSE = 0.969166 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]): NMSE = 1.035165 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]): NMSE = 0.973459 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): NMSE = 1.110874 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): NMSE = 1.017867 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): NMSE = 0.958266 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): NMSE = 1.003077 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]): NMSE = 0.993689 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]): NMSE = 1.000277 FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]): NMSE = 1.010218 FAIL

@FSSRepo
Collaborator Author

FSSRepo commented Dec 12, 2023

@slaren In the stable diffusion computation, tanh is not used unless --taesd is specified to use the TinyAutoEncoder.

@slaren
Collaborator

slaren commented Dec 12, 2023

It might have been the soft max src1 issue, but I am surprised that doesn't cause bigger problems if it is really a bug.

@FSSRepo
Collaborator Author

FSSRepo commented Dec 12, 2023

I hope this was the cause of the inconsistency between results across backends. The output generated by the Metal backend should be similar, if not identical, to that of the CPU-only backend.

This is the CUDA result (previous softmax CUDA kernel vs. current softmax CUDA kernel):

output output

@slaren
Collaborator

slaren commented Dec 12, 2023

Looks like it was the soft max issue. Getting this now.

output

@FSSRepo
Collaborator Author

FSSRepo commented Dec 12, 2023

Now the performance issue remains: for an M3 Max, the UNet and VAE compute seems very slow to me; on my laptop RTX 3050 I get 1.5 iterations per second.

@slaren
Collaborator

slaren commented Dec 12, 2023

@ggerganov please review the changes; I am not sure that's the best solution for the problems with soft max.

@slaren
Collaborator

slaren commented Dec 12, 2023

Here is a test-backend-ops perf run with Metal:

  ADD(type=f32,ne=[320,1,1,1],nr=[1,4096,1,1]):                  2185 runs -    33.26 us/run -    15360 kB/run -  440.42 GB/s
  SILU(type=f32,ne=[64,64,320,1]):                               3277 runs -    19.46 us/run -    10240 kB/run -  501.74 GB/s
  SOFT_MAX(type=f32,ne=[4096,4096,8,1]):                           33 runs -  3468.12 us/run -  1048576 kB/run -  288.34 GB/s
  IM2COL(type_input=f32,type_kernel=f16,ne_input=[64,64,320,1],ne_kernel=[3,3,320,320],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                   1120 runs -  2850.00 us/run -    29960 kB/run -   10.03 GB/s
  GROUP_NORM(type=f32,ne=[64,64,320,1],num_groups=32):           3277 runs -   522.49 us/run -    10240 kB/run -   18.69 GB/s
  ABS(type=f32,ne=[128,10,10,10]): not supported
  SGN(type=f32,ne=[128,10,10,10]): not supported
  NEG(type=f32,ne=[128,10,10,10]): not supported
  STEP(type=f32,ne=[128,10,10,10]): not supported
  TANH(type=f32,ne=[128,10,10,10]):                              8192 runs -    12.02 us/run -     1000 kB/run -   79.36 GB/s
  ELU(type=f32,ne=[128,10,10,10]): not supported
  RELU(type=f32,ne=[128,10,10,10]):                              8192 runs -     7.60 us/run -     1000 kB/run -  125.48 GB/s
  GELU(type=f32,ne=[128,10,10,10]):                              8192 runs -     5.21 us/run -     1000 kB/run -  183.07 GB/s
  GELU_QUICK(type=f32,ne=[128,10,10,10]):                        8192 runs -     3.69 us/run -     1000 kB/run -  258.17 GB/s
  SILU(type=f32,ne=[128,10,10,10]):                              8192 runs -     3.66 us/run -     1000 kB/run -  260.38 GB/s
  GET_ROWS(type=f32,n=10,m=5,r=3):                               8192 runs -     1.74 us/run -        0 kB/run -    0.18 GB/s
  GET_ROWS(type=f32,n=16,m=5,r=3):                               8192 runs -     2.98 us/run -        0 kB/run -    0.16 GB/s
  GET_ROWS(type=f16,n=10,m=5,r=3):                               8192 runs -     1.76 us/run -        0 kB/run -    0.12 GB/s
  GET_ROWS(type=f16,n=16,m=5,r=3):                               8192 runs -     2.37 us/run -        0 kB/run -    0.14 GB/s
  REPEAT(type=f32,ne=[10,10,10,10],nr=[1,1,1,1]): not supported
  REPEAT(type=f32,ne=[10,10,10,10],nr=[2,1,1,1]): not supported
  REPEAT(type=f32,ne=[10,10,10,10],nr=[1,2,1,1]): not supported
  REPEAT(type=f32,ne=[10,10,10,10],nr=[1,1,2,1]): not supported
  REPEAT(type=f32,ne=[10,10,10,10],nr=[1,1,1,2]): not supported
  DUP(type=f32,ne=[10,10,10,1]):                                 8192 runs -     3.90 us/run -        7 kB/run -    1.91 GB/s
  CPY(type_src=f32,type_dst=f32,ne=[10,10,10,1]):                8192 runs -     4.16 us/run -        7 kB/run -    1.79 GB/s
  CONT(type=f32,ne=[10,10,10,1]):                                8191 runs -     4.24 us/run -        7 kB/run -    1.76 GB/s
  ADD(type=f32,ne=[1,1,8,1],nr=[1,1,1,1]):                       8192 runs -     3.73 us/run -        0 kB/run -    0.02 GB/s
  MUL(type=f32,ne=[1,1,8,1],nr=[1,1,1,1]):                       8192 runs -     3.78 us/run -        0 kB/run -    0.02 GB/s
  DIV(type=f32,ne=[1,1,8,1],nr=[1,1,1,1]):                       8192 runs -     3.79 us/run -        0 kB/run -    0.02 GB/s
  ADD(type=f32,ne=[1,1,320,320],nr=[1,1,1,1]):                   8192 runs -   253.17 us/run -     1200 kB/run -    4.52 GB/s
  MUL(type=f32,ne=[1,1,320,320],nr=[1,1,1,1]):                   8192 runs -   252.58 us/run -     1200 kB/run -    4.53 GB/s
  DIV(type=f32,ne=[1,1,320,320],nr=[1,1,1,1]):                   8192 runs -   253.15 us/run -     1200 kB/run -    4.52 GB/s
  ADD(type=f32,ne=[16,10,1,1],nr=[1,1,1,1]):                     8192 runs -     3.84 us/run -        1 kB/run -    0.47 GB/s
  MUL(type=f32,ne=[16,10,1,1],nr=[1,1,1,1]):                     8192 runs -     3.90 us/run -        1 kB/run -    0.46 GB/s
  DIV(type=f32,ne=[16,10,1,1],nr=[1,1,1,1]):                     8192 runs -     3.90 us/run -        1 kB/run -    0.46 GB/s
  ADD(type=f32,ne=[16,10,10,1],nr=[1,1,1,1]):                    8192 runs -     4.11 us/run -       18 kB/run -    4.35 GB/s
  MUL(type=f32,ne=[16,10,10,1],nr=[1,1,1,1]):                    8192 runs -     4.10 us/run -       18 kB/run -    4.36 GB/s
  DIV(type=f32,ne=[16,10,10,1],nr=[1,1,1,1]):                    8192 runs -     4.14 us/run -       18 kB/run -    4.32 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,1,1,1]):                   8192 runs -     6.91 us/run -      187 kB/run -   25.88 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,1,1,1]):                   8192 runs -     6.82 us/run -      187 kB/run -   26.23 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,1,1,1]):                   8192 runs -     6.84 us/run -      187 kB/run -   26.13 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[2,1,1,1]):                   8192 runs -     7.71 us/run -      375 kB/run -   46.41 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[2,1,1,1]):                   8192 runs -     7.60 us/run -      375 kB/run -   47.03 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[2,1,1,1]):                   8192 runs -     7.68 us/run -      375 kB/run -   46.58 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,2,1,1]):                   8192 runs -    10.08 us/run -      375 kB/run -   35.49 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,2,1,1]):                   8192 runs -    10.19 us/run -      375 kB/run -   35.10 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,2,1,1]):                   8192 runs -    10.25 us/run -      375 kB/run -   34.89 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,1,2,1]):                   8192 runs -     9.91 us/run -      375 kB/run -   36.07 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,1,2,1]):                   8192 runs -     9.81 us/run -      375 kB/run -   36.45 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,1,2,1]):                   8192 runs -     9.78 us/run -      375 kB/run -   36.59 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,1,1,2]):                   8192 runs -     9.74 us/run -      375 kB/run -   36.71 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,1,1,2]):                   8192 runs -     9.84 us/run -      375 kB/run -   36.33 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,1,1,2]):                   8192 runs -     9.89 us/run -      375 kB/run -   36.15 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,1,2,2]):                   8192 runs -    17.14 us/run -      750 kB/run -   41.74 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,1,2,2]):                   8192 runs -    16.94 us/run -      750 kB/run -   42.21 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,1,2,2]):                   8192 runs -    17.10 us/run -      750 kB/run -   41.83 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[1,2,2,2]):                   8192 runs -    27.29 us/run -     1500 kB/run -   52.42 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[1,2,2,2]):                   8192 runs -    27.35 us/run -     1500 kB/run -   52.31 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[1,2,2,2]):                   8192 runs -    27.34 us/run -     1500 kB/run -   52.33 GB/s
  ADD(type=f32,ne=[16,10,10,10],nr=[2,2,2,2]):                   8192 runs -    30.25 us/run -     3000 kB/run -   94.57 GB/s
  MUL(type=f32,ne=[16,10,10,10],nr=[2,2,2,2]):                   8192 runs -    29.98 us/run -     3000 kB/run -   95.42 GB/s
  DIV(type=f32,ne=[16,10,10,10],nr=[2,2,2,2]):                   8192 runs -    30.04 us/run -     3000 kB/run -   95.25 GB/s
  ADD(type=f32,ne=[1280,1,1,1],nr=[1,1,1,1]):                    8192 runs -     2.21 us/run -       15 kB/run -    6.47 GB/s
  MUL(type=f32,ne=[1280,1,1,1],nr=[1,1,1,1]):                    8192 runs -     2.20 us/run -       15 kB/run -    6.51 GB/s
  DIV(type=f32,ne=[1280,1,1,1],nr=[1,1,1,1]):                    8192 runs -     2.20 us/run -       15 kB/run -    6.49 GB/s
  ADD(type=f32,ne=[1280,1,1,1],nr=[1,16,16,1]):                  8192 runs -    10.32 us/run -     3840 kB/run -  354.84 GB/s
  MUL(type=f32,ne=[1280,1,1,1],nr=[1,16,16,1]):                  8192 runs -    10.47 us/run -     3840 kB/run -  349.62 GB/s
  DIV(type=f32,ne=[1280,1,1,1],nr=[1,16,16,1]):                  8192 runs -    10.46 us/run -     3840 kB/run -  350.25 GB/s
  ADD(type=f32,ne=[1280,16,16,1],nr=[1,1,1,1]):                  8192 runs -    25.79 us/run -     3840 kB/run -  142.00 GB/s
  MUL(type=f32,ne=[1280,16,16,1],nr=[1,1,1,1]):                  8192 runs -    25.65 us/run -     3840 kB/run -  142.75 GB/s
  DIV(type=f32,ne=[1280,16,16,1],nr=[1,1,1,1]):                  8192 runs -    25.73 us/run -     3840 kB/run -  142.32 GB/s
  ADD(type=f32,ne=[1280,1,1,1],nr=[1,256,1,1]):                  8192 runs -    10.61 us/run -     3840 kB/run -  345.15 GB/s
  MUL(type=f32,ne=[1280,1,1,1],nr=[1,256,1,1]):                  8192 runs -    10.40 us/run -     3840 kB/run -  352.00 GB/s
  DIV(type=f32,ne=[1280,1,1,1],nr=[1,256,1,1]):                  8192 runs -    10.48 us/run -     3840 kB/run -  349.41 GB/s
  ADD(type=f32,ne=[1,1,1280,1],nr=[16,16,1,1]):                  8192 runs -    54.75 us/run -     3840 kB/run -   66.88 GB/s
  MUL(type=f32,ne=[1,1,1280,1],nr=[16,16,1,1]):                  8192 runs -    55.07 us/run -     3840 kB/run -   66.50 GB/s
  DIV(type=f32,ne=[1,1,1280,1],nr=[16,16,1,1]):                  8192 runs -    54.70 us/run -     3840 kB/run -   66.95 GB/s
  ADD(type=f32,ne=[16,16,1280,1],nr=[1,1,1,1]):                  8192 runs -    55.10 us/run -     3840 kB/run -   66.47 GB/s
  MUL(type=f32,ne=[16,16,1280,1],nr=[1,1,1,1]):                  8192 runs -    54.40 us/run -     3840 kB/run -   67.32 GB/s
  DIV(type=f32,ne=[16,16,1280,1],nr=[1,1,1,1]):                  8192 runs -    54.84 us/run -     3840 kB/run -   66.78 GB/s
  ADD(type=f32,ne=[1,1,1920,1],nr=[16,16,1,1]):                  5826 runs -    79.55 us/run -     5760 kB/run -   69.05 GB/s
  MUL(type=f32,ne=[1,1,1920,1],nr=[16,16,1,1]):                  5826 runs -    79.23 us/run -     5760 kB/run -   69.33 GB/s
  DIV(type=f32,ne=[1,1,1920,1],nr=[16,16,1,1]):                  5826 runs -    79.52 us/run -     5760 kB/run -   69.08 GB/s
  ADD(type=f32,ne=[1,1,2560,1],nr=[16,16,1,1]):                  4370 runs -   103.71 us/run -     7680 kB/run -   70.62 GB/s
  MUL(type=f32,ne=[1,1,2560,1],nr=[16,16,1,1]):                  4370 runs -   103.64 us/run -     7680 kB/run -   70.67 GB/s
  DIV(type=f32,ne=[1,1,2560,1],nr=[16,16,1,1]):                  4370 runs -   103.68 us/run -     7680 kB/run -   70.64 GB/s
  ADD(type=f32,ne=[1,1,1280,1],nr=[32,32,1,1]):                  2185 runs -   103.62 us/run -    15360 kB/run -  141.37 GB/s
  MUL(type=f32,ne=[1,1,1280,1],nr=[32,32,1,1]):                  2185 runs -   102.87 us/run -    15360 kB/run -  142.40 GB/s
  DIV(type=f32,ne=[1,1,1280,1],nr=[32,32,1,1]):                  2185 runs -   104.72 us/run -    15360 kB/run -  139.88 GB/s
  ADD(type=f32,ne=[1,1,1920,1],nr=[32,32,1,1]):                  1457 runs -   149.96 us/run -    23040 kB/run -  146.52 GB/s
  MUL(type=f32,ne=[1,1,1920,1],nr=[32,32,1,1]):                  1457 runs -   150.42 us/run -    23040 kB/run -  146.08 GB/s
  DIV(type=f32,ne=[1,1,1920,1],nr=[32,32,1,1]):                  1457 runs -   150.37 us/run -    23040 kB/run -  146.12 GB/s
  ADD(type=f32,ne=[1,1,640,1],nr=[32,32,1,1]):                   4370 runs -    56.04 us/run -     7680 kB/run -  130.70 GB/s
  MUL(type=f32,ne=[1,1,640,1],nr=[32,32,1,1]):                   4370 runs -    55.77 us/run -     7680 kB/run -  131.32 GB/s
  DIV(type=f32,ne=[1,1,640,1],nr=[32,32,1,1]):                   4370 runs -    56.08 us/run -     7680 kB/run -  130.61 GB/s
  ADD(type=f32,ne=[5120,1,1,1],nr=[1,256,1,1]):                  2185 runs -    31.29 us/run -    15360 kB/run -  468.20 GB/s
  MUL(type=f32,ne=[5120,1,1,1],nr=[1,256,1,1]):                  2185 runs -    31.66 us/run -    15360 kB/run -  462.72 GB/s
  DIV(type=f32,ne=[5120,1,1,1],nr=[1,256,1,1]):                  2185 runs -    32.00 us/run -    15360 kB/run -  457.80 GB/s
  ADD(type=f32,ne=[640,1,1,1],nr=[1,1,1,1]):                     8192 runs -     2.16 us/run -        7 kB/run -    3.31 GB/s
  MUL(type=f32,ne=[640,1,1,1],nr=[1,1,1,1]):                     8192 runs -     2.14 us/run -        7 kB/run -    3.34 GB/s
  DIV(type=f32,ne=[640,1,1,1],nr=[1,1,1,1]):                     8192 runs -     2.16 us/run -        7 kB/run -    3.31 GB/s
  ADD(type=f32,ne=[3,3,2560,1280],nr=[1,1,1,1]):                   98 runs - 24149.36 us/run -   345600 kB/run -   13.65 GB/s
  MUL(type=f32,ne=[3,3,2560,1280],nr=[1,1,1,1]):                   98 runs - 24221.85 us/run -   345600 kB/run -   13.61 GB/s
  DIV(type=f32,ne=[3,3,2560,1280],nr=[1,1,1,1]):                   98 runs - 24058.21 us/run -   345600 kB/run -   13.70 GB/s
  ADD(type=f32,ne=[3,3,2560,1280],nr=[2,1,1,1]):                   49 runs - 23957.12 us/run -   691200 kB/run -   27.51 GB/s
  MUL(type=f32,ne=[3,3,2560,1280],nr=[2,1,1,1]):                   49 runs - 24263.18 us/run -   691200 kB/run -   27.17 GB/s
  DIV(type=f32,ne=[3,3,2560,1280],nr=[2,1,1,1]):                   49 runs - 24123.29 us/run -   691200 kB/run -   27.33 GB/s
  SCALE(type=f32,ne=[10,10,10,10]):                              8192 runs -     1.75 us/run -       78 kB/run -   42.68 GB/s
  NORM(type=f32,ne=[64,10,10,10],eps=0.000001):                  8192 runs -     9.38 us/run -      500 kB/run -   50.86 GB/s
  RMS_NORM(type=f32,ne=[64,10,10,10],eps=0.000001):              8192 runs -     5.85 us/run -      500 kB/run -   81.55 GB/s
  NORM(type=f32,ne=[64,10,10,10],eps=0.000010):                  8192 runs -     9.49 us/run -      500 kB/run -   50.27 GB/s
  RMS_NORM(type=f32,ne=[64,10,10,10],eps=0.000010):              8192 runs -     5.89 us/run -      500 kB/run -   80.96 GB/s
  NORM(type=f32,ne=[64,10,10,10],eps=0.001000):                  8192 runs -     9.50 us/run -      500 kB/run -   50.18 GB/s
  RMS_NORM(type=f32,ne=[64,10,10,10],eps=0.001000):              8192 runs -     5.83 us/run -      500 kB/run -   81.79 GB/s
  NORM(type=f32,ne=[64,10,10,10],eps=0.100000):                  8192 runs -     9.52 us/run -      500 kB/run -   50.11 GB/s
  RMS_NORM(type=f32,ne=[64,10,10,10],eps=0.100000):              8192 runs -     5.83 us/run -      500 kB/run -   81.83 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                   8192 runs -     3.45 us/run -       32 kB/run -    8.86 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                  8192 runs -     3.92 us/run -      320 kB/run -   78.03 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                  8192 runs -     4.36 us/run -      641 kB/run -  140.25 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                 8192 runs -     8.56 us/run -     3206 kB/run -  357.02 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                 5233 runs -    12.82 us/run -     6412 kB/run -  477.19 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                 5233 runs -    12.82 us/run -     6412 kB/run -  477.04 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                 2617 runs -    21.61 us/run -    12825 kB/run -  566.00 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -    12.42 us/run -      513 kB/run -   39.40 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                 6541 runs -    13.65 us/run -     5130 kB/run -  358.32 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                 3271 runs -    16.52 us/run -    10260 kB/run -  592.16 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):                 655 runs -    23.42 us/run -    51300 kB/run - 2089.25 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                 328 runs -    34.92 us/run -   102600 kB/run - 2801.97 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                 328 runs -    35.12 us/run -   102600 kB/run - 2786.16 GB/s
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                 164 runs -    56.88 us/run -   205200 kB/run - 3440.22 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                   8192 runs -     3.52 us/run -       24 kB/run -    6.52 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                  8192 runs -     4.05 us/run -      240 kB/run -   56.60 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                  8192 runs -     4.42 us/run -      481 kB/run -  103.77 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                 8192 runs -     8.68 us/run -     2406 kB/run -  264.46 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                 6973 runs -    12.58 us/run -     4812 kB/run -  364.95 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                 6973 runs -    12.80 us/run -     4812 kB/run -  358.43 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                 3487 runs -    21.89 us/run -     9625 kB/run -  419.40 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -    12.37 us/run -      385 kB/run -   29.68 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    12.69 us/run -     3850 kB/run -  289.35 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                 4358 runs -    15.74 us/run -     7700 kB/run -  466.62 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):                 872 runs -    22.50 us/run -    38500 kB/run - 1631.84 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                 436 runs -    33.52 us/run -    77000 kB/run - 2190.68 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                 436 runs -    33.57 us/run -    77000 kB/run - 2187.24 GB/s
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                 218 runs -    54.42 us/run -   154000 kB/run - 2698.88 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.03 us/run -       18 kB/run -    4.33 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    12.52 us/run -      183 kB/run -   13.95 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    15.63 us/run -      366 kB/run -   22.34 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    22.52 us/run -     1831 kB/run -   77.56 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    33.34 us/run -     3662 kB/run -  104.76 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    33.43 us/run -     3662 kB/run -  104.48 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4581 runs -    56.56 us/run -     7325 kB/run -  123.51 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    13.02 us/run -      293 kB/run -   21.46 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    13.32 us/run -     2930 kB/run -  209.80 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5727 runs -    16.33 us/run -     5860 kB/run -  342.28 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1146 runs -    23.68 us/run -    29300 kB/run - 1179.81 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                573 runs -    35.04 us/run -    58600 kB/run - 1594.74 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                573 runs -    35.36 us/run -    58600 kB/run - 1580.25 GB/s
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                287 runs -    58.49 us/run -   117200 kB/run - 1910.78 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.04 us/run -       18 kB/run -    4.38 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    12.68 us/run -      185 kB/run -   13.96 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    15.79 us/run -      371 kB/run -   22.42 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    22.50 us/run -     1856 kB/run -   78.67 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    33.35 us/run -     3712 kB/run -  106.15 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    33.51 us/run -     3712 kB/run -  105.64 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4520 runs -    56.36 us/run -     7425 kB/run -  125.65 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    13.06 us/run -      297 kB/run -   21.69 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    13.27 us/run -     2970 kB/run -  213.50 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5649 runs -    16.48 us/run -     5940 kB/run -  343.73 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1130 runs -    23.78 us/run -    29700 kB/run - 1190.84 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                565 runs -    35.11 us/run -    59400 kB/run - 1613.38 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                565 runs -    34.95 us/run -    59400 kB/run - 1621.06 GB/s
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                283 runs -    58.14 us/run -   118800 kB/run - 1948.76 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.65 us/run -       18 kB/run -    3.86 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    15.56 us/run -      188 kB/run -   11.53 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    17.30 us/run -      376 kB/run -   20.74 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    24.82 us/run -     1881 kB/run -   72.30 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    36.94 us/run -     3762 kB/run -   97.13 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    37.19 us/run -     3762 kB/run -   96.47 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4460 runs -    63.87 us/run -     7525 kB/run -  112.36 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.78 us/run -      301 kB/run -   18.19 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    15.98 us/run -     3010 kB/run -  179.58 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5574 runs -    17.82 us/run -     6020 kB/run -  322.25 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1115 runs -    25.88 us/run -    30100 kB/run - 1109.03 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                558 runs -    38.27 us/run -    60200 kB/run - 1500.21 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                558 runs -    38.50 us/run -    60200 kB/run - 1491.34 GB/s
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                279 runs -    65.85 us/run -   120400 kB/run - 1743.81 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.71 us/run -       19 kB/run -    3.86 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    15.33 us/run -      190 kB/run -   11.86 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    17.13 us/run -      381 kB/run -   21.22 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    24.72 us/run -     1906 kB/run -   73.53 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    37.27 us/run -     3812 kB/run -   97.56 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    37.25 us/run -     3812 kB/run -   97.62 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4401 runs -    63.77 us/run -     7625 kB/run -  114.03 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.73 us/run -      305 kB/run -   18.49 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    16.04 us/run -     3050 kB/run -  181.33 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5501 runs -    17.78 us/run -     6100 kB/run -  327.22 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1101 runs -    25.52 us/run -    30500 kB/run - 1139.88 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                551 runs -    38.46 us/run -    61000 kB/run - 1512.69 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                551 runs -    38.62 us/run -    61000 kB/run - 1506.37 GB/s
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                276 runs -    65.18 us/run -   122000 kB/run - 1785.10 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     3.76 us/run -       20 kB/run -    5.15 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    12.84 us/run -      203 kB/run -   15.08 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    16.07 us/run -      406 kB/run -   24.10 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    22.14 us/run -     2031 kB/run -   87.51 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    32.93 us/run -     4062 kB/run -  117.65 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    33.03 us/run -     4062 kB/run -  117.31 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4130 runs -    55.46 us/run -     8125 kB/run -  139.72 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    13.20 us/run -      325 kB/run -   23.49 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    13.53 us/run -     3250 kB/run -  229.12 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5163 runs -    16.70 us/run -     6500 kB/run -  371.29 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1033 runs -    23.24 us/run -    32500 kB/run - 1333.66 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                517 runs -    34.59 us/run -    65000 kB/run - 1792.31 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                517 runs -    34.71 us/run -    65000 kB/run - 1785.71 GB/s
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                259 runs -    56.80 us/run -   130000 kB/run - 2182.88 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.73 us/run -       17 kB/run -    3.50 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    13.51 us/run -      173 kB/run -   12.27 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    16.15 us/run -      347 kB/run -   20.52 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    22.62 us/run -     1737 kB/run -   73.26 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    33.86 us/run -     3475 kB/run -   97.86 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    33.90 us/run -     3475 kB/run -   97.77 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4828 runs -    57.35 us/run -     6950 kB/run -  115.56 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    13.65 us/run -      278 kB/run -   19.42 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    14.06 us/run -     2780 kB/run -  188.56 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                6035 runs -    16.66 us/run -     5560 kB/run -  318.24 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1207 runs -    23.71 us/run -    27800 kB/run - 1118.30 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                604 runs -    35.42 us/run -    55600 kB/run - 1497.13 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                604 runs -    35.48 us/run -    55600 kB/run - 1494.69 GB/s
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                302 runs -    59.49 us/run -   111200 kB/run - 1782.53 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.77 us/run -       17 kB/run -    3.56 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    15.08 us/run -      177 kB/run -   11.24 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    16.66 us/run -      355 kB/run -   20.36 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    23.50 us/run -     1778 kB/run -   72.16 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    35.08 us/run -     3556 kB/run -   96.67 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    35.44 us/run -     3556 kB/run -   95.69 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4718 runs -    60.41 us/run -     7112 kB/run -  112.29 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.26 us/run -      284 kB/run -   17.78 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    15.58 us/run -     2845 kB/run -  174.20 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5898 runs -    17.35 us/run -     5690 kB/run -  312.71 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1180 runs -    24.51 us/run -    28450 kB/run - 1106.82 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                590 runs -    36.69 us/run -    56900 kB/run - 1478.79 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                590 runs -    36.98 us/run -    56900 kB/run - 1467.34 GB/s
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                295 runs -    62.16 us/run -   113800 kB/run - 1745.97 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.69 us/run -       18 kB/run -    3.72 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    15.05 us/run -      183 kB/run -   11.61 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    16.60 us/run -      366 kB/run -   21.04 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    22.81 us/run -     1831 kB/run -   76.58 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    34.16 us/run -     3662 kB/run -  102.25 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    34.21 us/run -     3662 kB/run -  102.10 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4581 runs -    57.88 us/run -     7325 kB/run -  120.69 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.30 us/run -      293 kB/run -   18.26 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    15.78 us/run -     2930 kB/run -  177.03 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5727 runs -    17.08 us/run -     5860 kB/run -  327.25 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1146 runs -    23.98 us/run -    29300 kB/run - 1165.25 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                573 runs -    35.90 us/run -    58600 kB/run - 1556.75 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                573 runs -    36.08 us/run -    58600 kB/run - 1549.14 GB/s
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                287 runs -    59.65 us/run -   117200 kB/run - 1873.83 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     4.81 us/run -       18 kB/run -    3.73 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    16.02 us/run -      188 kB/run -   11.20 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    17.22 us/run -      376 kB/run -   20.84 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    24.05 us/run -     1881 kB/run -   74.61 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    36.14 us/run -     3762 kB/run -   99.28 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    36.14 us/run -     3762 kB/run -   99.30 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4460 runs -    61.64 us/run -     7525 kB/run -  116.42 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.95 us/run -      301 kB/run -   17.99 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    16.23 us/run -     3010 kB/run -  176.88 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5574 runs -    17.80 us/run -     6020 kB/run -  322.56 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1115 runs -    25.13 us/run -    30100 kB/run - 1142.36 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                558 runs -    37.67 us/run -    60200 kB/run - 1524.19 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                558 runs -    37.98 us/run -    60200 kB/run - 1511.46 GB/s
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                279 runs -    63.75 us/run -   120400 kB/run - 1801.16 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]):                  8192 runs -     3.58 us/run -       19 kB/run -    5.16 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[1,1]):                 8192 runs -    12.08 us/run -      193 kB/run -   15.27 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,1],nr=[2,1]):                 8192 runs -    14.22 us/run -      386 kB/run -   25.94 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,1]):                8192 runs -    24.19 us/run -     1934 kB/run -   76.27 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,1]):                8192 runs -    36.32 us/run -     3868 kB/run -  101.59 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[1,2]):                8192 runs -    36.27 us/run -     3868 kB/run -  101.71 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[10,10],nr=[2,2]):                4337 runs -    62.19 us/run -     7737 kB/run -  118.65 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]):                 8192 runs -    15.85 us/run -      309 kB/run -   18.62 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]):                8192 runs -    16.13 us/run -     3095 kB/run -  183.02 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]):                5421 runs -    17.84 us/run -     6190 kB/run -  330.82 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]):               1085 runs -    25.15 us/run -    30950 kB/run - 1173.51 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]):                543 runs -    37.70 us/run -    61900 kB/run - 1566.01 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,2]):                543 runs -    37.92 us/run -    61900 kB/run - 1556.65 GB/s
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,2]):                272 runs -    64.31 us/run -   123800 kB/run - 1836.01 GB/s
  SQR(type=f32,ne=[10,10,10,10]):                                8192 runs -     1.90 us/run -       78 kB/run -   39.18 GB/s
  CLAMP(type=f32,ne=[10,10,10,10],min=-0.500000,max=0.500000): not supported
  DIAG_MASK_INF(type=f32,ne=[10,10,1,1],n_past=5):               8192 runs -     1.87 us/run -        0 kB/run -    0.40 GB/s
  DIAG_MASK_INF(type=f32,ne=[10,10,10,1],n_past=5):              8192 runs -     2.01 us/run -        7 kB/run -    3.70 GB/s
  DIAG_MASK_INF(type=f32,ne=[10,10,10,10],n_past=5): not supported
  DIAG_MASK_INF(type=f32,ne=[77,77,12,1],n_past=0):              8192 runs -     9.97 us/run -      555 kB/run -   53.16 GB/s
  SOFT_MAX(type=f32,ne=[10,10,10,10]):                           8192 runs -     7.25 us/run -       78 kB/run -   10.28 GB/s
  ROPE(type=f32,ne=[128,32,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     5.73 us/run -      320 kB/run -   53.22 GB/s
  ROPE(type=f32,ne=[128,40,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     6.52 us/run -      400 kB/run -   58.53 GB/s
  ROPE(type=f32,ne=[128,52,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     7.13 us/run -      520 kB/run -   69.57 GB/s
  ROPE(type=f32,ne=[128,64,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     7.90 us/run -      640 kB/run -   77.26 GB/s
  ROPE(type=f32,ne=[64,1,10,1],n_dims=64,mode=2,n_ctx=512):                          8192 runs -     5.14 us/run -        5 kB/run -    0.94 GB/s
  ROPE(type=f32,ne=[64,71,10,1],n_dims=64,mode=2,n_ctx=512):                         8192 runs -     7.31 us/run -      355 kB/run -   46.31 GB/s
  ROPE(type=f32,ne=[64,8,10,1],n_dims=64,mode=2,n_ctx=512):                          8192 runs -     5.28 us/run -       40 kB/run -    7.23 GB/s
  ROPE(type=f32,ne=[64,128,10,1],n_dims=64,mode=2,n_ctx=512):                        8192 runs -     9.09 us/run -      640 kB/run -   67.18 GB/s
  ROPE(type=f32,ne=[80,32,10,1],n_dims=20,mode=2,n_ctx=512):                         8192 runs -     7.83 us/run -      200 kB/run -   24.38 GB/s
  ROPE(type=f16,ne=[128,32,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     6.09 us/run -      160 kB/run -   25.04 GB/s
  ROPE(type=f16,ne=[128,40,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     6.47 us/run -      200 kB/run -   29.47 GB/s
  ROPE(type=f16,ne=[128,52,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     7.01 us/run -      260 kB/run -   35.38 GB/s
  ROPE(type=f16,ne=[128,64,10,1],n_dims=128,mode=0,n_ctx=512):                       8192 runs -     7.82 us/run -      320 kB/run -   39.02 GB/s
  ROPE(type=f16,ne=[64,1,10,1],n_dims=64,mode=2,n_ctx=512):                          8192 runs -     5.14 us/run -        2 kB/run -    0.47 GB/s
  ROPE(type=f16,ne=[64,71,10,1],n_dims=64,mode=2,n_ctx=512):                         8192 runs -     7.13 us/run -      177 kB/run -   23.73 GB/s
  ROPE(type=f16,ne=[64,8,10,1],n_dims=64,mode=2,n_ctx=512):                          8192 runs -     5.18 us/run -       20 kB/run -    3.69 GB/s
  ROPE(type=f16,ne=[64,128,10,1],n_dims=64,mode=2,n_ctx=512):                        8192 runs -     9.00 us/run -      320 kB/run -   33.91 GB/s
  ROPE(type=f16,ne=[80,32,10,1],n_dims=20,mode=2,n_ctx=512):                         8192 runs -     7.76 us/run -      100 kB/run -   12.30 GB/s
  ALIBI(type=f32,ne=[10,10,10,10],n_past=512,n_head=10,bias_max=0.500000):           8192 runs -     9.16 us/run -       78 kB/run -    8.13 GB/s
  IM2COL(type_input=f32,type_kernel=f16,ne_input=[10,10,3,1],ne_kernel=[3,3,3,1],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                         8192 runs -     4.26 us/run -        6 kB/run -    1.45 GB/s
  CONCAT(type=f32,ne=[10,10,10,10],b_ne2=10):                    8192 runs -     9.77 us/run -      156 kB/run -   15.26 GB/s
  ARGSORT(type=f32,ne=[16,10,10,10],order=0):                    8192 runs -     6.44 us/run -      125 kB/run -   18.50 GB/s
  ARGSORT(type=f32,ne=[16,10,10,10],order=1):                    8192 runs -     6.45 us/run -      125 kB/run -   18.48 GB/s
  SUM_ROWS(type=f32,ne=[10,10,10,10]):                           8192 runs -     3.89 us/run -       42 kB/run -   10.52 GB/s
  UPSCALE(type=f32,ne=[512,512,3,1],scale_factor=2):             2185 runs -   218.18 us/run -    15360 kB/run -   67.14 GB/s
  GROUP_NORM(type=f32,ne=[64,64,320,1],num_groups=32):           3277 runs -   523.61 us/run -    10240 kB/run -   18.65 GB/s
  ACC(type=f32,ne_a=[1024,577,1,1],ne_b=[1024,576,1,1]):                             4849 runs -   112.17 us/run -     6920 kB/run -   58.83 GB/s
  PAD(type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1):              8192 runs -    15.14 us/run -     2052 kB/run -  129.26 GB/s
  LEAKY_RELU(type=f32,ne_a=[10,10,10,10],negative_slope=0.200000):                   8192 runs -     2.10 us/run -       78 kB/run -   35.49 GB/s
  Backend Metal: OK

@ggerganov
Owner

Ah interesting - sorry for missing the discussion. I was so focused on the Mixtral issue that I didn't pay attention and am only seeing this now. Let me review tomorrow first thing.

@slaren
Collaborator

slaren commented Dec 12, 2023

I wonder if the soft max issue is related to the early ending with Metal (in Mixtral).

@ggerganov
Owner

ggerganov commented Dec 13, 2023

@slaren I think this indeed fixes it! ❤️

Still testing

Edit: the MTL_DEBUG_LAYER=1 MTL_SHADER_VALIDATION=1 MTL_SHADER_VALIDATION_REPORT_TO_STDERR=1 MTL_SHADER_VALIDATION_FAIL_MODE=allow flags look extremely useful. Will be using those from now on

Edit2: All my failure cases with Mixtral are now resolved - the fix does seem to work

@@ -487,6 +488,7 @@ kernel void kernel_soft_max_4(
}

const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
threadgroup_barrier(mem_flags::mem_threadgroup);
Owner

Do these barriers make a difference in your tests?
To me they seem superfluous, but I could be missing something.

Collaborator

I am not sure why, but without the barrier this test case fails sometimes:

    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 8, 1}));

Collaborator

@slaren commented Dec 13, 2023

Using only an execution barrier, threadgroup_barrier(mem_flags::mem_none), seems to be enough to fix the issue. My intuition when adding the barrier was to ensure that all threads have finished computing their value before the simd_sum, but I don't know enough about Metal to tell whether that is actually required. In CUDA, the warp shuffle functions have implicit synchronization, so it is not necessary there.
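
To make the pattern under discussion concrete, here is a minimal Metal sketch (a hypothetical stand-alone kernel, not the actual kernel_soft_max_4; the kernel and buffer names are made up): each thread produces a partial value, the barrier guarantees every thread has reached that point, and simd_sum then reduces within each SIMD group. Whether the barrier is strictly needed before simd_sum is exactly the open question above.

    #include <metal_stdlib>
    using namespace metal;

    // Hypothetical illustration only, not part of ggml-metal.metal.
    kernel void partial_sum_sketch(
            device const float * src  [[buffer(0)]],
            device       float * dst  [[buffer(1)]],
            uint tid  [[thread_index_in_threadgroup]],
            uint sgid [[simdgroup_index_in_threadgroup]],
            uint lane [[thread_index_in_simdgroup]]) {
        const float partial = src[tid]; // per-thread partial value

        // Execution-only barrier: every thread in the threadgroup reaches this
        // point before any lane starts the SIMD reduction. mem_threadgroup would
        // additionally flush threadgroup memory, which matters when the partials
        // are staged there (as in the soft_max kernel).
        threadgroup_barrier(mem_flags::mem_none);

        const float group_sum = simd_sum(partial); // reduce within one SIMD group
        if (lane == 0) {
            dst[sgid] = group_sum; // one result per SIMD group
        }
    }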

@ggerganov
Owner

The plan is to merge the Mixtral branch into llama.cpp and then sync with the changes here.
Let me know if there is anything pending before that. We will improve the Metal performance after the sync.

@slaren
Collaborator

slaren commented Dec 13, 2023

Edit: the MTL_DEBUG_LAYER=1 MTL_SHADER_VALIDATION=1 MTL_SHADER_VALIDATION_REPORT_TO_STDERR=1 MTL_SHADER_VALIDATION_FAIL_MODE=allow flags look extremely useful. Will be using those from now on

I am not sure if there is a better combination of flags though. I still can't figure out why some mat muls fail with these flags. The full documentation is available in man MetalValidation.

@ggerganov
Owner

Yup, I saw the failing mat muls. I also get some invalid loads in the F32 kernel. Will be looking to fix those

@FSSRepo
Collaborator Author

FSSRepo commented Dec 13, 2023

Can I merge this pull request?

@ggerganov
Owner

Just squash it into a single commit when merging

@FSSRepo FSSRepo merged commit 5bf85a5 into ggerganov:master Dec 13, 2023
4 checks passed
@FSSRepo
Collaborator Author

FSSRepo commented Dec 13, 2023

Thank you very much, everyone, for the feedback and assistance. I hope to continue contributing. I'm considering implementing Winograd convolution (to reduce memory usage and computation) in the ggml_conv_2d operation. I already have a working implementation, but I'll need to extend the ggml_pad operation to support symmetric padding, as it currently performs asymmetric padding (it only adds zeros at the end of a dimension). I would greatly appreciate your help.
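
For reference, a hedged sketch in plain ggml C of the current, trailing-only behavior (assuming the ggml_pad(ctx, a, p0, p1, p2, p3) signature added in this PR; the helper name pad_example and the ggml_pad_ext mentioned in the comment are hypothetical, not real API):

    #include "ggml.h"

    // Sketch only: mirrors the PAD(ne_a=[512,512,1,1],pad_0=1,pad_1=1) test case
    // above, assuming the ggml_pad(ctx, a, p0, p1, p2, p3) signature from this PR.
    static struct ggml_tensor * pad_example(struct ggml_context * ctx) {
        struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 512, 512, 1, 1);

        // Current behavior: zeros are appended only at the *end* of each padded
        // dimension, so the result here is 513x513x1x1. Symmetric padding (zeros
        // on both sides, as Winograd needs) would require an extended op with
        // separate leading/trailing amounts, e.g. a hypothetical ggml_pad_ext.
        return ggml_pad(ctx, a, 1, 1, 0, 0);
    }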

@ggerganov
Owner

@FSSRepo Thank you for your contributions - your help is very much appreciated!
