Skip to content

Commit

Permalink
mnist : smooth user input (#199)
Browse files Browse the repository at this point in the history
Drawing on the canvas is now smooth. The final image which is used for
prediction is obtained by down-scaling the canvas to 28x28 pixels.
Download button is aslo added for downloading raw image values.
  • Loading branch information
rgerganov committed May 26, 2023
1 parent 3d3e22f commit e8d347b
Showing 1 changed file with 69 additions and 23 deletions.
92 changes: 69 additions & 23 deletions examples/mnist/web/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ <h2>MNIST digit recognizer with <a href="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/ggerganov/ggml">GGML
<div>
<button id="clear" onclick="onClear()">Clear</button>
<button id="random" onclick="onRandom()" disabled>Random</button>
<button id="download" onclick="onDownload()">Download</button>
</div>
<div>
<p id="prediction"></p>
Expand All @@ -23,33 +24,61 @@ <h2>MNIST digit recognizer with <a href="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/ggerganov/ggml">GGML
"use strict";
const DIGIT_SIZE = 28; // digits are 28x28 pixels
var canvas = document.getElementById("ggCanvas");
var ctx = canvas.getContext("2d");
var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
var ctx = canvas.getContext("2d", { alpha: false, willReadFrequently: true });
ctx.fillStyle = "white";
ctx.fillRect(0, 0, canvas.width, canvas.height);
var dragging = false;
var lastX, lastY;

function onClear(event) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
digit.fill(0);
ctx.fillStyle = "white";
ctx.fillRect(0, 0, canvas.width, canvas.height);
document.getElementById("prediction").innerHTML = "";
}

function predict(digit) {
let buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
Module.HEAPU8.set(digit, buf);
let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
Module._free(buf);
if (prediction >= 0) {
document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
}
}

function onRandom(event) {
onClear();
var buf = Module._malloc(digit.length);
const bufLength = DIGIT_SIZE*DIGIT_SIZE;
var buf = Module._malloc(bufLength);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
let digit = new Uint8Array(Module.HEAPU8.buffer, buf, bufLength);
for (let i = 0; i < digit.length; i++) {
digit[i] = digitBytes[i];
let x = i % DIGIT_SIZE;
let y = Math.floor(i / DIGIT_SIZE);
setPixel(x, y, digit[i]);
}
Module._free(buf);
onMouseUp();
predict(digit);
}

function onDownload(event) {
let digit = scaleCanvas();
let digitBlob = new Blob([new Uint8Array(digit)], {type: "application/octet-stream"});
let url = URL.createObjectURL(digitBlob);
let link = document.createElement('a');
link.href = url;
link.download = "image.raw";
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
}

// Get the position of the mouse relative to the canvas
Expand All @@ -62,7 +91,6 @@ <h2>MNIST digit recognizer with <a href="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/ggerganov/ggml">GGML
}

function setPixel(x, y, val) {
digit[y * DIGIT_SIZE + x] = val;
let canvasX = x * 13;
let canvasY = y * 13;
let color = 255 - val;
Expand All @@ -72,28 +100,46 @@ <h2>MNIST digit recognizer with <a href="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/ggerganov/ggml">GGML

function onMouseDown(e) {
dragging = true;
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
[lastX, lastY] = getMousePos(e);
}

// scale the canvas to 28x28 pixels and return the pixel values as an array
function scaleCanvas() {
let imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
let tempCanvas = document.createElement('canvas');
tempCanvas.width = DIGIT_SIZE;
tempCanvas.height = DIGIT_SIZE;
let tempCtx = tempCanvas.getContext("2d");
tempCtx.drawImage(canvas, 0, 0, DIGIT_SIZE, DIGIT_SIZE);
let tempImgData = tempCtx.getImageData(0, 0, DIGIT_SIZE, DIGIT_SIZE);
let tempData = tempImgData.data;
let digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
for (let i = 0; i < tempData.length; i += 4) {
let val = 255 - tempData[i];
digit[i / 4] = val;
}
return digit;
}

function onMouseUp(e) {
dragging = false;
var buf = Module._malloc(digit.length);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
Module.HEAPU8.set(digit, buf);
let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
Module._free(buf);
if (prediction >= 0) {
document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
}
let digit = scaleCanvas();
predict(digit);
}

function onMouseMove(e) {
if (dragging) {
let [mouseX, mouseY] = getMousePos(e);
setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
ctx.beginPath();
ctx.moveTo(lastX, lastY);
ctx.lineTo(mouseX, mouseY);
ctx.lineWidth = 20;
ctx.lineJoin = ctx.lineCap = 'round';
ctx.strokeStyle = "#000000";
ctx.stroke();
ctx.closePath();
lastX = mouseX;
lastY = mouseY;
}
}

Expand Down

0 comments on commit e8d347b

Please sign in to comment.