Skip to content

Commit

Permalink
Add MNIST inference example with CNN
Browse files Browse the repository at this point in the history
Add one more implementation for MNIST which uses Conv2D layers, ref:
https://keras.io/examples/vision/mnist_convnet/. It achieves ~99%
accuracy on the MNIST test set and also performs better for user inputs.
This implementation expects a model in GGUF format. You can get one with
the 'mnist-cnn.py' script. Example usage:

$ ./mnist-cnn.py train mnist-cnn-model
...
Keras model saved to 'mnist-cnn-model'

$ ./mnist-cnn.py convert mnist-cnn-model
...
Model converted and saved to 'mnist-cnn-model.gguf'

$ ./mnist-cnn mnist-cnn-model.gguf models/mnist/t10k-images.idx3-ubyte
  • Loading branch information
rgerganov committed Aug 28, 2023
1 parent 69bf842 commit 9a287f7
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/mnist/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ set(TEST_TARGET mnist)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)

#
# mnist-cnn

set(TEST_TARGET mnist-cnn)
add_executable(${TEST_TARGET} main-cnn.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)

#
# mnist-cpu

Expand Down
169 changes: 169 additions & 0 deletions examples/mnist/main-cnn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#include "ggml/ggml.h"

#include "common.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <string>
#include <vector>
#include <algorithm>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

struct mnist_model {
struct ggml_tensor * conv2d_1_kernel;
struct ggml_tensor * conv2d_1_bias;
struct ggml_tensor * conv2d_2_kernel;
struct ggml_tensor * conv2d_2_bias;
struct ggml_tensor * dense_weight;
struct ggml_tensor * dense_bias;
struct ggml_context * ctx;
};

bool mnist_model_load(const std::string & fname, mnist_model & model) {
struct gguf_init_params params = {
/*.no_alloc =*/ false,
/*.ctx =*/ &model.ctx,
};
gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
if (!ctx) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return false;
}
model.conv2d_1_kernel = ggml_get_tensor(model.ctx, "kernel1");
model.conv2d_1_bias = ggml_get_tensor(model.ctx, "bias1");
model.conv2d_2_kernel = ggml_get_tensor(model.ctx, "kernel2");
model.conv2d_2_bias = ggml_get_tensor(model.ctx, "bias2");
model.dense_weight = ggml_get_tensor(model.ctx, "dense_w");
model.dense_bias = ggml_get_tensor(model.ctx, "dense_b");
return true;
}

int mnist_eval(
const mnist_model & model,
const int n_threads,
std::vector<float> digit,
const char * fname_cgraph
)
{
static size_t buf_size = 100000 * sizeof(float) * 4;
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};

struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 28, 28, 1, 1);
memcpy(input->data, digit.data(), ggml_nbytes(input));
ggml_set_name(input, "input");
ggml_tensor * cur = ggml_conv_2d(ctx0, model.conv2d_1_kernel, input, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.conv2d_1_bias);
cur = ggml_relu(ctx0, cur);
// Output shape after Conv2D: (26 26 32 1)
cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
// Output shape after MaxPooling2D: (13 13 32 1)
cur = ggml_conv_2d(ctx0, model.conv2d_2_kernel, cur, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.conv2d_2_bias);
cur = ggml_relu(ctx0, cur);
// Output shape after Conv2D: (11 11 64 1)
cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
// Output shape after MaxPooling2D: (5 5 64 1)
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
// Output shape after permute: (64 5 5 1)
cur = ggml_reshape_2d(ctx0, cur, 1600, 1);
// Final Dense layer
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.dense_weight, cur), model.dense_bias);
ggml_tensor * probs = ggml_soft_max(ctx0, cur);
ggml_set_name(probs, "probs");

ggml_build_forward_expand(&gf, probs);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

//ggml_graph_print(&gf);
ggml_graph_dump_dot(&gf, NULL, "mnist-cnn.dot");

if (fname_cgraph) {
// export the compute graph for later use
// see the "mnist-cpu" example
ggml_graph_export(&gf, fname_cgraph);

fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
}

const float * probs_data = ggml_get_data_f32(probs);
const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
ggml_free(ctx0);
return prediction;
}

int main(int argc, char ** argv) {
srand(time(NULL));
ggml_time_init();

if (argc != 3) {
fprintf(stderr, "Usage: %s models/mnist/mnist-cnn.gguf models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
exit(0);
}

uint8_t buf[784];
mnist_model model;
std::vector<float> digit;

// load the model
{
const int64_t t_start_us = ggml_time_us();

if (!mnist_model_load(argv[1], model)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
return 1;
}

const int64_t t_load_us = ggml_time_us() - t_start_us;

fprintf(stdout, "%s: loaded model in %8.2f ms\n", __func__, t_load_us / 1000.0f);
}

// read a random digit from the test set
{
std::ifstream fin(argv[2], std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
return 1;
}

// seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
fin.seekg(16 + 784 * (rand() % 10000));
fin.read((char *) &buf, sizeof(buf));
}

// render the digit in ASCII
{
digit.resize(sizeof(buf));

for (int row = 0; row < 28; row++) {
for (int col = 0; col < 28; col++) {
fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
digit[row*28 + col] = ((float)buf[row*28 + col] / 255.0f);
}

fprintf(stderr, "\n");
}

fprintf(stderr, "\n");
}

const int prediction = mnist_eval(model, 1, digit, nullptr);
fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);
ggml_free(model.ctx);
return 0;
}
101 changes: 101 additions & 0 deletions examples/mnist/mnist-cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
import sys
import gguf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

def train(model_name):
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
[
keras.Input(shape=input_shape),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(num_classes, activation="softmax"),
]
)

model.summary()
batch_size = 128
epochs = 15
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
model.save(model_name)
print("Keras model saved to '" + model_name + "'")

def convert(model_name):
model = keras.models.load_model(model_name)
gguf_model_name = model_name + ".gguf"
gguf_writer = gguf.GGUFWriter(gguf_model_name, "mnist-cnn")

kernel1 = model.layers[0].weights[0].numpy()
kernel1 = np.moveaxis(kernel1, [2,3], [0,1])
kernel1 = kernel1.astype(np.float16)
gguf_writer.add_tensor("kernel1", kernel1, raw_shape=(32, 1, 3, 3))

bias1 = model.layers[0].weights[1].numpy()
bias1 = np.repeat(bias1, 26*26)
gguf_writer.add_tensor("bias1", bias1, raw_shape=(1, 32, 26, 26))

kernel2 = model.layers[2].weights[0].numpy()
kernel2 = np.moveaxis(kernel2, [0,1,2,3], [2,3,1,0])
kernel2 = kernel2.astype(np.float16)
gguf_writer.add_tensor("kernel2", kernel2, raw_shape=(64, 32, 3, 3))

bias2 = model.layers[2].weights[1].numpy()
bias2 = np.repeat(bias2, 11*11)
gguf_writer.add_tensor("bias2", bias2, raw_shape=(1, 64, 11, 11))

dense_w = model.layers[-1].weights[0].numpy()
dense_w = dense_w.transpose()
gguf_writer.add_tensor("dense_w", dense_w, raw_shape=(10, 1600))

dense_b = model.layers[-1].weights[1].numpy()
gguf_writer.add_tensor("dense_b", dense_b)

gguf_writer.write_header_to_file()
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file()
gguf_writer.close()
print("Model converted and saved to '{}'".format(gguf_model_name))

if __name__ == '__main__':
if len(sys.argv) < 3:
print("Usage: %s <train|convert> <model_name>".format(sys.argv[0]))
sys.exit(1)
if sys.argv[1] == 'train':
train(sys.argv[2])
elif sys.argv[1] == 'convert':
convert(sys.argv[2])
else:
print("Usage: %s <train|convert> <model_name>".format(sys.argv[0]))
sys.exit(1)

0 comments on commit 9a287f7

Please sign in to comment.