Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add web page for the MNIST example #190

Merged
merged 1 commit into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add web page for the MNIST example
The web page is using WASM for model inference.
Users can draw digits on an HTML canvas and load random digits from the
MNIST dataset.
  • Loading branch information
rgerganov committed May 23, 2023
commit a90cf1b99ffdc2a8341544fa3d9ffe521113b2ae
35 changes: 35 additions & 0 deletions examples/mnist/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,41 @@ int mnist_eval(
return prediction;
}

#ifdef __cplusplus
extern "C" {
#endif

int wasm_eval(uint8_t *digitPtr)
{
mnist_model model;
if (!mnist_model_load("models/mnist/ggml-model-f32.bin", model)) {
fprintf(stderr, "error loading model\n");
return -1;
}
std::vector<float> digit(digitPtr, digitPtr + 784);
int result = mnist_eval(model, 1, digit);
ggml_free(model.ctx);
return result;
}

int wasm_random_digit(char *digitPtr)
{
auto fin = std::ifstream("models/mnist/t10k-images.idx3-ubyte", std::ios::binary);
if (!fin) {
fprintf(stderr, "failed to open digits file\n");
return 0;
}
srand(time(NULL));
// Seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
fin.seekg(16 + 784 * (rand() % 10000));
fin.read(digitPtr, 784);
return 1;
}

#ifdef __cplusplus
}
#endif

int main(int argc, char ** argv) {
ggml_time_init();

Expand Down
126 changes: 126 additions & 0 deletions examples/mnist/web/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<title>MNIST with GGML</title>
<script src="mnist.js"></script>
</head>
<body>
<h2>MNIST digit recognizer with GGML</h2>
<p>Draw a single digit on the canvas below:</p>
<canvas id="ggCanvas" width="364" height="364" style="border:2px solid #d3d3d3;">
Your browser does not support the HTML canvas tag.
</canvas>
<div>
<button id="clear" onclick="onClear()">Clear</button>
<button id="random" onclick="onRandom()">Random</button>
</div>
<div>
<p id="prediction"></p>
</div>
<script>
"use strict";
const DIGIT_SIZE = 28; // digits are 28x28 pixels
var canvas = document.getElementById("ggCanvas");
var ctx = canvas.getContext("2d");
var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
var dragging = false;

function onClear(event) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
digit.fill(0);
document.getElementById("prediction").innerHTML = "";
}

function onRandom(event) {
onClear();
var buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
for (let i = 0; i < digit.length; i++) {
digit[i] = digitBytes[i];
let x = i % DIGIT_SIZE;
let y = Math.floor(i / DIGIT_SIZE);
setPixel(x, y, digit[i]);
}
Module._free(buf);
onMouseUp();
}

// Get the position of the mouse relative to the canvas
function getMousePos(event) {
if (event.touches !== undefined && event.touches.length > 0) {
event = event.touches[0];
}
var rect = canvas.getBoundingClientRect();
return [Math.floor(event.clientX) - rect.left, Math.floor(event.clientY) - rect.top];
}

function setPixel(x, y, val) {
digit[y * DIGIT_SIZE + x] = val;
let canvasX = x * 13;
let canvasY = y * 13;
let color = 255 - val;
ctx.fillStyle = "#" + color.toString(16) + color.toString(16) + color.toString(16);
ctx.fillRect(canvasX, canvasY, 13, 13);
}

function onMouseDown(e) {
dragging = true;
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
}

function onMouseUp(e) {
dragging = false;
var buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
Module.HEAPU8.set(digit, buf);
let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
Module._free(buf);
if (prediction >= 0) {
document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
}
}
function onMouseMove(e) {
if (dragging) {
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
}
}

// Prevent scrolling when touching the canvas
document.body.addEventListener("touchstart", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});
document.body.addEventListener("touchend", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});
document.body.addEventListener("touchmove", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
}, {passive: false});

// Use the same handlers for mouse and touch events
canvas.onmousedown = onMouseDown;
canvas.onmouseup = onMouseUp;
canvas.onmousemove = onMouseMove;
canvas.ontouchstart = onMouseDown;
canvas.ontouchend = onMouseUp;
canvas.ontouchmove = onMouseMove;
</script>
</body>
</html>