Skip to content

Commit

Permalink
mnist : add web page for the MNIST example (#190)
Browse files Browse the repository at this point in the history
The web page is using WASM for model inference.
Users can draw digits on an HTML canvas and load random digits from the
MNIST dataset.
  • Loading branch information
rgerganov committed May 24, 2023
1 parent d30ef19 commit 42bfaaf
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
35 changes: 35 additions & 0 deletions examples/mnist/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,41 @@ int mnist_eval(
return prediction;
}

#ifdef __cplusplus
extern "C" {
#endif

int wasm_eval(uint8_t *digitPtr)
{
mnist_model model;
if (!mnist_model_load("models/mnist/ggml-model-f32.bin", model)) {
fprintf(stderr, "error loading model\n");
return -1;
}
std::vector<float> digit(digitPtr, digitPtr + 784);
int result = mnist_eval(model, 1, digit);
ggml_free(model.ctx);
return result;
}

int wasm_random_digit(char *digitPtr)
{
auto fin = std::ifstream("models/mnist/t10k-images.idx3-ubyte", std::ios::binary);
if (!fin) {
fprintf(stderr, "failed to open digits file\n");
return 0;
}
srand(time(NULL));
// Seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
fin.seekg(16 + 784 * (rand() % 10000));
fin.read(digitPtr, 784);
return 1;
}

#ifdef __cplusplus
}
#endif

int main(int argc, char ** argv) {
srand(time(NULL));
ggml_time_init();
Expand Down
126 changes: 126 additions & 0 deletions examples/mnist/web/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<title>MNIST with GGML</title>
<script src="mnist.js"></script>
</head>
<body>
<h2>MNIST digit recognizer with GGML</h2>
<p>Draw a single digit on the canvas below:</p>
<canvas id="ggCanvas" width="364" height="364" style="border:2px solid #d3d3d3;">
Your browser does not support the HTML canvas tag.
</canvas>
<div>
<button id="clear" onclick="onClear()">Clear</button>
<button id="random" onclick="onRandom()">Random</button>
</div>
<div>
<p id="prediction"></p>
</div>
<script>
"use strict";
const DIGIT_SIZE = 28; // digits are 28x28 pixels
var canvas = document.getElementById("ggCanvas");
var ctx = canvas.getContext("2d");
var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
var dragging = false;

function onClear(event) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
digit.fill(0);
document.getElementById("prediction").innerHTML = "";
}

function onRandom(event) {
onClear();
var buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
for (let i = 0; i < digit.length; i++) {
digit[i] = digitBytes[i];
let x = i % DIGIT_SIZE;
let y = Math.floor(i / DIGIT_SIZE);
setPixel(x, y, digit[i]);
}
Module._free(buf);
onMouseUp();
}

// Get the position of the mouse relative to the canvas
function getMousePos(event) {
if (event.touches !== undefined && event.touches.length > 0) {
event = event.touches[0];
}
var rect = canvas.getBoundingClientRect();
return [Math.floor(event.clientX) - rect.left, Math.floor(event.clientY) - rect.top];
}

function setPixel(x, y, val) {
digit[y * DIGIT_SIZE + x] = val;
let canvasX = x * 13;
let canvasY = y * 13;
let color = 255 - val;
ctx.fillStyle = "#" + color.toString(16) + color.toString(16) + color.toString(16);
ctx.fillRect(canvasX, canvasY, 13, 13);
}

function onMouseDown(e) {
dragging = true;
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
}

function onMouseUp(e) {
dragging = false;
var buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
Module.HEAPU8.set(digit, buf);
let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
Module._free(buf);
if (prediction >= 0) {
document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
}
}
function onMouseMove(e) {
if (dragging) {
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
}
}

// Prevent scrolling when touching the canvas
document.body.addEventListener("touchstart", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});
document.body.addEventListener("touchend", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});
document.body.addEventListener("touchmove", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});

// Use the same handlers for mouse and touch events
canvas.onmousedown = onMouseDown;
canvas.onmouseup = onMouseUp;
canvas.onmousemove = onMouseMove;
canvas.ontouchstart = onMouseDown;
canvas.ontouchend = onMouseUp;
canvas.ontouchmove = onMouseMove;
</script>
</body>
</html>

0 comments on commit 42bfaaf

Please sign in to comment.