diff --git a/examples/mnist/web/index.html b/examples/mnist/web/index.html index 1bd01ae50..ab1ef1778 100644 --- a/examples/mnist/web/index.html +++ b/examples/mnist/web/index.html @@ -15,6 +15,7 @@

MNIST digit recognizer with GGML
+

@@ -23,33 +24,61 @@

MNIST digit recognizer with GGML "use strict"; const DIGIT_SIZE = 28; // digits are 28x28 pixels var canvas = document.getElementById("ggCanvas"); -var ctx = canvas.getContext("2d"); -var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0); +var ctx = canvas.getContext("2d", { alpha: false, willReadFrequently: true }); +ctx.fillStyle = "white"; +ctx.fillRect(0, 0, canvas.width, canvas.height); var dragging = false; +var lastX, lastY; function onClear(event) { - ctx.clearRect(0, 0, canvas.width, canvas.height); - digit.fill(0); + ctx.fillStyle = "white"; + ctx.fillRect(0, 0, canvas.width, canvas.height); document.getElementById("prediction").innerHTML = ""; } +function predict(digit) { + let buf = Module._malloc(digit.length); + if (buf == 0) { + console.log("failed to allocate memory"); + return; + } + Module.HEAPU8.set(digit, buf); + let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]); + Module._free(buf); + if (prediction >= 0) { + document.getElementById("prediction").innerHTML = "Predicted digit is " + prediction + ""; + } +} + function onRandom(event) { onClear(); - var buf = Module._malloc(digit.length); + const bufLength = DIGIT_SIZE*DIGIT_SIZE; + var buf = Module._malloc(bufLength); if (buf == 0) { console.log("failed to allocate memory"); return; } let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]); - let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length); + let digit = new Uint8Array(Module.HEAPU8.buffer, buf, bufLength); for (let i = 0; i < digit.length; i++) { - digit[i] = digitBytes[i]; let x = i % DIGIT_SIZE; let y = Math.floor(i / DIGIT_SIZE); setPixel(x, y, digit[i]); } Module._free(buf); - onMouseUp(); + predict(digit); +} + +function onDownload(event) { + let digit = scaleCanvas(); + let digitBlob = new Blob([new Uint8Array(digit)], {type: "application/octet-stream"}); + let url = URL.createObjectURL(digitBlob); + let link = document.createElement('a'); + link.href = url; + link.download = "image.raw"; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); } // Get the position of the mouse relative to the canvas @@ -62,7 +91,6 @@

MNIST digit recognizer with GGML } function setPixel(x, y, val) { - digit[y * DIGIT_SIZE + x] = val; let canvasX = x * 13; let canvasY = y * 13; let color = 255 - val; @@ -72,28 +100,46 @@

MNIST digit recognizer with GGML function onMouseDown(e) { dragging = true; - let [mouseX, mouseY] = getMousePos(e); - setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255); + [lastX, lastY] = getMousePos(e); +} + +// scale the canvas to 28x28 pixels and return the pixel values as an array +function scaleCanvas() { + let imgData = ctx.getImageData(0, 0, canvas.width, canvas.height); + let tempCanvas = document.createElement('canvas'); + tempCanvas.width = DIGIT_SIZE; + tempCanvas.height = DIGIT_SIZE; + let tempCtx = tempCanvas.getContext("2d"); + tempCtx.drawImage(canvas, 0, 0, DIGIT_SIZE, DIGIT_SIZE); + let tempImgData = tempCtx.getImageData(0, 0, DIGIT_SIZE, DIGIT_SIZE); + let tempData = tempImgData.data; + let digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0); + for (let i = 0; i < tempData.length; i += 4) { + let val = 255 - tempData[i]; + digit[i / 4] = val; + } + return digit; } function onMouseUp(e) { dragging = false; - var buf = Module._malloc(digit.length); - if (buf == 0) { - console.log("failed to allocate memory"); - return; - } - Module.HEAPU8.set(digit, buf); - let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]); - Module._free(buf); - if (prediction >= 0) { - document.getElementById("prediction").innerHTML = "Predicted digit is " + prediction + ""; - } + let digit = scaleCanvas(); + predict(digit); } + function onMouseMove(e) { if (dragging) { let [mouseX, mouseY] = getMousePos(e); - setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255); + ctx.beginPath(); + ctx.moveTo(lastX, lastY); + ctx.lineTo(mouseX, mouseY); + ctx.lineWidth = 20; + ctx.lineJoin = ctx.lineCap = 'round'; + ctx.strokeStyle = "#000000"; + ctx.stroke(); + ctx.closePath(); + lastX = mouseX; + lastY = mouseY; } }