Add python example w/ cffi-generated bindings

Features:

- Seamless copies between tensors (ggml & numpy alike) with automatic (de/re)quantization
- Access to full C API (incl. CUDA, MPI, OpenCL, Metal, alloc... and any local API changes)
- Trivial regeneration with `python regenerate.py` (uses llama.cpp headers by default; see README.md for options)
ochafik committed Aug 13, 2023
1 parent 244776a commit 53a0003
Showing 9 changed files with 504 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -31,3 +31,5 @@ zig-cache/
*.dot

*.sw?

__pycache__/
107 changes: 107 additions & 0 deletions examples/python/README.md
@@ -0,0 +1,107 @@
# Simple autogenerated Python bindings for ggml

This folder contains:

- Scripts to generate full Python bindings from ggml headers.
- Some barebones utils (see [ggml/utils.py](./ggml/utils.py)):
- `ggml.utils.init` builds a context that's freed automatically when the pointer gets GC'd
- `ggml.utils.copy` **copies between same-shaped tensors (numpy or ggml), w/ automatic (de/re)quantization**
- `ggml.utils.numpy` returns a numpy view over a ggml tensor; if it's quantized, it returns a copy (requires `allow_copy=True`), as shown in the sketch below
- Very basic examples (anyone want to port [llama2.c](https://github.com/karpathy/llama2.c)?)
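
Here's a minimal sketch of that view-vs-copy behavior (it assumes the shared library and the pre-generated bindings are already set up, as described under Prerequisites below; the printed values are illustrative):

```python
from ggml import lib
from ggml.utils import init, numpy
import numpy as np

ctx = init(mem_size=1*1024*1024)

# Float tensors can be viewed without copying; writes should go straight
# through to the tensor's data, since the numpy array wraps the same buffer.
t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 16)
view = numpy(t)
view[:] = np.arange(16, dtype=np.float32)
print(numpy(t)[:4])   # expected: [0. 1. 2. 3.]

# Quantized tensors can't be viewed as floats directly; ask for a
# dequantized float32 copy instead.
q = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, 256)
print(numpy(q, allow_copy=True).dtype)   # float32
```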

Provided you set `GGML_LIBRARY=.../path/to/libggml_shared.so` (see instructions below), it's trivial to do some operations on quantized tensors:

```python
# Make sure libllama.so is in your [DY]LD_LIBRARY_PATH, or set GGML_LIBRARY=.../libggml_shared.so

from ggml import lib, ffi
from ggml.utils import init, copy, numpy
import numpy as np

ctx = init(mem_size=12*1024*1024)
n = 256
n_threads = 4

a = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)
b = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n) # Can't both be quantized
sum = lib.ggml_add(ctx, a, b) # all zeroes for now. Will be quantized too!

gf = ffi.new('struct ggml_cgraph*')
lib.ggml_build_forward_expand(gf, sum)

copy(np.array([i for i in range(n)], np.float32), a)
copy(np.array([i*100 for i in range(n)], np.float32), b)

lib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)

print(numpy(a, allow_copy=True))
# 0. 1.0439453 2.0878906 3.131836 4.1757812 5.2197266. ...
print(numpy(b))
# 0. 100. 200. 300. 400. 500. ...
print(numpy(sum, allow_copy=True))
# 0. 105.4375 210.875 316.3125 421.75 527.1875 ...
```

### Prerequisites

You'll need a shared library of ggml to use the bindings.

#### Build libggml_shared.so or libllama.so

As of this writing, the best option is to use [ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp)'s generated `libggml_shared.so` or `libllama.so`, which you can build as follows:

```bash
git clone https://github.com/ggerganov/llama.cpp
# On a CUDA-enabled system add -DLLAMA_CUBLAS=1
# On a Mac add -DLLAMA_METAL=1
cmake llama.cpp \
  -B llama_build \
  -DCMAKE_C_FLAGS=-Ofast \
  -DLLAMA_NATIVE=1 \
  -DLLAMA_LTO=1 \
  -DBUILD_SHARED_LIBS=1 \
  -DLLAMA_MPI=1 \
  -DLLAMA_BUILD_TESTS=0 \
  -DLLAMA_BUILD_EXAMPLES=0
( cd llama_build && make -j )

# On Mac, this will be libggml_shared.dylib instead
export GGML_LIBRARY=$PWD/llama_build/libggml_shared.so
# Alternatively, you can just copy it to your system's lib dir, e.g. /usr/local/lib
```
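
Once the library is built and found (via `GGML_LIBRARY` or your loader path), a quick sanity check is to load it through the bindings and call a trivial C function. This is a minimal sketch; it assumes the pre-generated `ggml/cffi.py` shipped with this example is importable (see the next section to regenerate it):

```python
from ggml import lib, ffi

# ggml_type_name returns a C string; wrap it with ffi.string to get Python bytes.
print(ffi.string(lib.ggml_type_name(lib.GGML_TYPE_Q4_0)).decode('utf-8'))  # should print: q4_0
```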

#### (Optional) Regenerate the bindings (`ggml/cffi.py`)

If you added or changed any signatures of the C API, you'll want to regenerate the bindings.

Luckily it's a one-liner using [regenerate.py](./regenerate.py):

```bash
pip install -q cffi

python regenerate.py
```

By default it assumes `llama.cpp` was cloned in ../../../llama.cpp (alongside the ggml folder). You can override this with:

```bash
C_INCLUDE_DIR=$LLAMA_CPP_DIR python regenerate.py
```

You can also edit [api.h](./api.h) to control which files should be included in the generated bindings (defaults to `llama.cpp/ggml*.h`).

If you only want to generate bindings for the current version of the `ggml` repo itself (instead of `llama.cpp`; you'd lose support for k-quants), you can run:

```bash
API=../../include/ggml/ggml.h python regenerate.py
```

### Alternatives

This example's goal is to showcase [cffi](https://cffi.readthedocs.io/)-generated bindings that are trivial to use and update, but there are already alternatives in the wild:

- https://github.com/abetlen/ggml-python: these bindings seem to be hand-written and use [ctypes](https://docs.python.org/3/library/ctypes.html). They come with [high-quality API reference docs](https://ggml-python.readthedocs.io/en/latest/api-reference/#ggml.ggml) that can also be used with the bindings from this example, but they don't expose Metal, CUDA, MPI or OpenCL calls, don't support transparent (de/re)quantization like this example does (see the [ggml.utils](./ggml/utils.py) module), and won't pick up your local changes.

- https://github.com/abetlen/llama-cpp-python: these expose the C++ `llama.cpp` interface, which this example cannot easily be extended to support (`cffi` only generates bindings for C libraries).

- [pybind11](https://github.com/pybind/pybind11) and [nanobind](https://github.com/wjakob/nanobind) are two alternatives to cffi that support binding C++ libraries, but neither of them seems to offer an automatic binding generator (writing bindings by hand is rather time-consuming).
14 changes: 14 additions & 0 deletions examples/python/api.h
@@ -0,0 +1,14 @@
/*
List here all the headers you want to expose in the Python bindings,
then run `python regenerate.py` (see details in README.md)
*/

#include "ggml.h"
#include "ggml-metal.h"
#include "ggml-opencl.h"

// Headers below are currently only present in the llama.cpp repository, comment them out if you don't have them.
#include "k_quants.h"
#include "ggml-alloc.h"
#include "ggml-cuda.h"
#include "ggml-mpi.h"
25 changes: 25 additions & 0 deletions examples/python/example_add_quant.py
@@ -0,0 +1,25 @@
from ggml import lib, ffi
from ggml.utils import init, copy, numpy
import numpy as np

ctx = init(mem_size=12*1024*1024) # automatically freed when pointer is GC'd
n = 256
n_threads = 4

a = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)
b = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n) # can't both be quantized
sum = lib.ggml_add(ctx, a, b) # all zeroes for now. Will be quantized too!

# See cffi's doc on how to allocate native memory: it's very simple!
# https://cffi.readthedocs.io/en/latest/ref.html#ffi-interface
gf = ffi.new('struct ggml_cgraph*')
lib.ggml_build_forward_expand(gf, sum)

copy(np.array([i for i in range(n)], np.float32), a)
copy(np.array([i*100 for i in range(n)], np.float32), b)

lib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)

print(numpy(a, allow_copy=True))
print(numpy(b))
print(numpy(sum, allow_copy=True))
68 changes: 68 additions & 0 deletions examples/python/example_test_all_quants.py
@@ -0,0 +1,68 @@
from ggml import ffi, lib
from ggml.utils import init, numpy, copy
import numpy as np
from math import pi, cos, sin, ceil

import matplotlib.pyplot as plt

ctx = init(mem_size=100*1024*1024) # Will be auto-GC'd
n = 256

orig = np.array([
    [
        cos(j * 2 * pi / n) * (sin(i * 2 * pi / n))
        for j in range(n)
    ]
    for i in range(n)
], np.float32)
orig_tensor = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_F32, n, n)
copy(orig, orig_tensor)

quants = [
    type for type in range(lib.GGML_TYPE_COUNT)
    if lib.ggml_is_quantized(type) and
       type not in [lib.GGML_TYPE_Q8_1, lib.GGML_TYPE_Q8_K] # Apparently not supported
]
# quants = [lib.GGML_TYPE_Q2_K] # Test a single one

def get_name(type):
    name = lib.ggml_type_name(type)
    return ffi.string(name).decode('utf-8') if name else '?'

quants.sort(key=get_name)
quants.insert(0, None)
print(quants)

ncols = 4
nrows = ceil(len(quants) / ncols)

plt.figure(figsize=(ncols * 5, nrows * 5), layout='tight')

for i, type in enumerate(quants):
    plt.subplot(nrows, ncols, i + 1)
    try:
        if type is None:
            plt.title('Original')
            plt.imshow(orig)
        else:
            quantized_tensor = lib.ggml_new_tensor_2d(ctx, type, n, n)
            copy(orig_tensor, quantized_tensor)
            quantized = numpy(quantized_tensor, allow_copy=True)
            d = quantized - orig
            results = {
                "l2": np.linalg.norm(d, 2),
                "linf": np.linalg.norm(d, np.inf),
                "compression":
                    round(lib.ggml_nbytes(orig_tensor) /
                          lib.ggml_nbytes(quantized_tensor), 1)
            }
            name = get_name(type)
            print(f'{name}: {results}')

            plt.title(f'{name} ({results["compression"]}x smaller)')
            plt.imshow(quantized, interpolation='nearest')

    except Exception as e:
        print(f'Error: {e}')

plt.show()
58 changes: 58 additions & 0 deletions examples/python/ggml/__init__.py
@@ -0,0 +1,58 @@
"""
Python bindings for the ggml library.
Usage example:
from ggml import lib, ffi
from ggml.utils import init, copy, numpy
import numpy as np
ctx = init(mem_size=10*1024*1024)
n = 1024
n_threads = 4
a = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)
b = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n)
sum = lib.ggml_add(ctx, a, b)
gf = ffi.new('struct ggml_cgraph*')
lib.ggml_build_forward_expand(gf, sum)
copy(np.array([i for i in range(n)], np.float32), a)
copy(np.array([i*100 for i in range(n)], np.float32), b)
lib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)
print(numpy(sum, allow_copy=True))
See https://cffi.readthedocs.io/en/latest/cdef.html for more on cffi.
"""

try:
    from ggml.cffi import ffi as ffi
except ImportError as e:
    raise ImportError(f"Couldn't find ggml bindings ({e}). Run `python regenerate.py` or check your PYTHONPATH.")

import os, platform

__exact_library = os.environ.get("GGML_LIBRARY")
if __exact_library:
    __candidates = [__exact_library]
elif platform.system() == "Windows":
    __candidates = ["ggml_shared.dll", "llama.dll"]
else:
    __candidates = ["libggml_shared.so", "libllama.so"]
    if platform.system() == "Darwin":
        __candidates += ["libggml_shared.dylib", "libllama.dylib"]

for i, name in enumerate(__candidates):
    try:
        # This is where all the functions, enums and constants are defined
        lib = ffi.dlopen(name)
        break  # found a loadable library, stop trying further candidates
    except OSError:
        if i < len(__candidates) - 1:
            continue
        raise OSError(f"Couldn't find ggml's shared library (tried names: {__candidates}). Add its directory to DYLD_LIBRARY_PATH (on Mac) or LD_LIBRARY_PATH, or define GGML_LIBRARY.")

# This contains the cffi helpers such as new, cast, string, etc.
# https://cffi.readthedocs.io/en/latest/ref.html#ffi-interface
ffi = ffi