LeNet5 demo
Trained using the following Colab notebook.
Weights (JSON, 1.34 MB).
output
1window.onload = function()
2{
3
// Canvas the user draws on, and its 2D context (created without an alpha channel).
const cvIn = document.getElementById("Input");
const cvInCtx = cvIn.getContext("2d", {"alpha":false});
// Canvas the network visualisation is drawn to, and its 2D context.
const cvPt = document.getElementById("Viewport");
const cvPtCtx = cvPt.getContext("2d", {"alpha":false});
// Converts a number to a 32-bit integer, truncating toward zero (NaN becomes 0).
const ftoi = (v) => ~~v;
// Currently attached cCursorHandler, or null when none is attached.
let cursorHandler = null;
// Parsed weight file; assigned once the JSON has been fetched.
let data = null;
11
12
/* cursor handler:
 tracks pointer (mouse/cursor) events on an element and reports them to a callback
*/
16
/**
 * Tracks pointer input on an element and reports every change to a callback.
 *
 * State exposed to the callback:
 *  - `flag`:  bitwise OR of `flags.move` / `flags.down`, or `flags.idle`.
 *  - `ratio`: current/previous pointer position as a fraction of the
 *             element's bounding box.
 *  - `pos`:   the same positions scaled by `el.width` / `el.height` and
 *             truncated to integers.
 */
class cCursorHandler {
  // Bit flags combined into `flag`; `idle` means no bit is set.
  flags = {'idle': 0, 'move': 1, 'down': 2};

  /**
   * @param el  Element to listen on; must expose `width` and `height`.
   * @param cb  Called with this handler after every handled pointer event.
   */
  constructor (el, cb=(event)=>{}) {
    this.el = el;
    this.cb = cb;
    this.w = el.width;
    this.h = el.height;
    this.flag = this.flags.idle;
    this.pos = {
      curr: {x: 0, y: 0},
      prev: {x: 0, y: 0},
    };
    this.ratio = {
      curr: {x: 0, y: 0},
      prev: {x: 0, y: 0},
    };

    // Handlers are kept as instance properties so detach() can remove the
    // exact same function objects that were registered.
    this.handleMove = (event) => {
      this.update(event);
      this.flag |= this.flags.move;
      this.cb(this);
    };
    this.handleDown = (event) => {
      this.capturePointer(event);
      // Updating twice makes prev equal curr, so a new press never appears
      // connected to wherever the pointer was last seen.
      this.update(event);
      this.update(event);
      this.flag |= this.flags.down;
      this.cb(this);
    };
    this.handleUp = (event) => {
      this.update(event);
      this.flag = this.flags.idle;
      this.cb(this);
    };

    this.el.addEventListener("pointermove", this.handleMove);
    this.el.addEventListener("pointerdown", this.handleDown);
    this.el.addEventListener("pointerup", this.handleUp);
    this.el.addEventListener("pointercancel", this.handleUp);
  }

  /**
   * Routes the rest of this pointer's events to the element, so the release
   * is still seen when the pointer leaves the element while pressed.
   * Without it `flag` would stay in the `down` state after such a release.
   */
  capturePointer(event) {
    if (typeof this.el.setPointerCapture !== "function" || event.pointerId === undefined) {
      return;
    }
    try {
      this.el.setPointerCapture(event.pointerId);
    } catch (err) {
      // Capture is best effort: it is refused for pointers that are not
      // active (e.g. synthetic events). Tracking still works while the
      // pointer stays over the element.
    }
  }

  /** Shifts the current position into `prev` and recomputes it from `event`. */
  update (event) {
    const bcr = this.el.getBoundingClientRect();
    this.pos.prev.x = this.pos.curr.x;
    this.pos.prev.y = this.pos.curr.y;
    this.ratio.prev.x = this.ratio.curr.x;
    this.ratio.prev.y = this.ratio.curr.y;
    this.ratio.curr.x = (event.clientX - bcr.left) / bcr.width;
    this.ratio.curr.y = (event.clientY - bcr.top) / bcr.height;
    this.pos.curr.x = ftoi(this.ratio.curr.x * this.w);
    this.pos.curr.y = ftoi(this.ratio.curr.y * this.h);
  }

  /** Removes every listener registered by the constructor. */
  detach () {
    this.el.removeEventListener("pointermove", this.handleMove);
    this.el.removeEventListener("pointerdown", this.handleDown);
    this.el.removeEventListener("pointerup", this.handleUp);
    this.el.removeEventListener("pointercancel", this.handleUp);
  }
}
57
58
59/* Interp
60 namespace for interpolation methods
61*/
62
/**
 * Namespace of interpolation helpers used to resample one matrix into another.
 * Matrices only need `w`, `h`, `get(x, y)` and `set(x, y, v)`.
 */
class Interp {

  /**
   * Visits every destination cell of a `bw` x `bh` grid and reports which
   * source cells of an `aw` x `ah` grid surround the matching sample point.
   *
   * Destination index 0 maps to source index 0 and the last destination
   * index maps to the last source index, so equal sizes give an exact 1:1
   * mapping. A destination of size 1 samples source index 0.
   *
   * cb(bx, by, axl, axh, ayl, ayh, tx, ty):
   *   bx, by    destination cell
   *   axl, axh  lower / upper source column around the sample point
   *   ayl, ayh  lower / upper source row around the sample point
   *   tx, ty    fractional distance from the lower source column / row
   */
  static _bi_loop(aw, ah, bw, bh, cb) {
    // A single destination cell has no span to stretch; avoid dividing by 0.
    const rx = bw > 1 ? (aw - 1) / (bw - 1) : 0;
    const ry = bh > 1 ? (ah - 1) / (bh - 1) : 0;
    for (let bx = 0; bx < bw; bx++) {
      // Clamp so floating point error cannot push the last sample past the
      // final source column.
      const ax = Math.min(rx * bx, aw - 1);
      const axl = Math.floor(ax);
      const axh = Math.ceil(ax);
      for (let by = 0; by < bh; by++) {
        const ay = Math.min(ry * by, ah - 1);
        const ayl = Math.floor(ay);
        const ayh = Math.ceil(ay);
        cb(bx, by, axl, axh, ayl, ayh, ax - axl, ay - ayl);
      }
    }
  }

  /**
   * Fills `md` by sampling `ms`, combining the four surrounding source cells
   * with `bi_method(a1, b1, a2, b2, t1, t2)`. Cells whose coordinates fall
   * outside either matrix are left untouched.
   */
  static _bi_inv(ms, md, bi_method) {
    Interp._bi_loop(ms.w, ms.h, md.w, md.h,
      (x, y, xl, xh, yl, yh, xt, yt) => {
        if (x < 0 || x >= md.w || y < 0 || y >= md.h) return;
        if (xl < 0 || xl >= ms.w || yl < 0 || yl >= ms.h) return;
        if (xh < 0 || xh >= ms.w || yh < 0 || yh >= ms.h) return;
        md.set(x, y, bi_method(
          ms.get(xl, yl), ms.get(xh, yl),
          ms.get(xl, yh), ms.get(xh, yh),
          xt, yt
        ));
      }
    );
  }

  /** Returns `b` when `t` is at least one half, otherwise `a`. */
  static nearest(a, b, t) {
    return t >= .5 ? b : a;
  }

  /** Nearest-neighbour pick among four corners: along t1 first, then t2. */
  static binearest(a1, b1, a2, b2, t1, t2) {
    return Interp.nearest(
      Interp.nearest(a1, b1, t1),
      Interp.nearest(a2, b2, t1),
      t2
    );
  }

  /** Resamples `ms` into `md` with nearest-neighbour selection. */
  static binearest_m(ms, md) {
    Interp._bi_inv(ms, md, Interp.binearest);
  }

  /** Linear blend from `a` (t = 0) to `b` (t = 1). */
  static linear(a, b, t) {
    return a + (b - a) * t;
  }

  /** Bilinear blend of four corners: along t1 first, then t2. */
  static bilinear(a1, b1, a2, b2, t1, t2) {
    return Interp.linear(
      Interp.linear(a1, b1, t1),
      Interp.linear(a2, b2, t1),
      t2
    );
  }

  /** Resamples `ms` into `md` with bilinear blending. */
  static bilinear_m(ms, md) {
    Interp._bi_inv(ms, md, Interp.bilinear);
  }
}
127
128
129/* ==================== CNN Stuffs ==================== */
130
/* Float32Matrix
 minimal 2D matrix implementation backed by a flat Float32Array
*/
134
/**
 * Minimal 2D matrix of 32-bit floats stored row by row in one Float32Array.
 *
 * Fields: `w` (columns), `h` (rows), `l` (w * h), `buf` (the storage).
 */
class cFloat32Matrix {
  constructor (w, h) {
    this.w = w;
    this.h = h;
    this.l = w * h;
    this.buf = new Float32Array(this.l);
  }

  /** Column of flat index `i`. */
  itox(i) { return i % this.w; }
  /** Row of flat index `i`. */
  itoy(i) { return ftoi(i / this.w); }
  /** Flat index of cell (x, y). */
  ctoi(x, y) { return x + (y * this.w); }

  /**
   * Value at (x, y). There is no bounds check: an x outside [0, w) lands in
   * a neighbouring row, and an index outside the buffer yields undefined.
   */
  get(x, y) { return this.buf[this.ctoi(x, y)]; }
  /** Stores `v` at (x, y); same indexing caveats as get(). */
  set(x, y, v) { this.buf[this.ctoi(x, y)] = v; }

  /** Adds `v` to every element in place. */
  add(v) {
    for (let i = 0; i < this.l; i++) { this.buf[i] = this.buf[i] + v; }
  }
  /** Multiplies every element by `v` in place. */
  mul(v) {
    for (let i = 0; i < this.l; i++) { this.buf[i] = this.buf[i] * v; }
  }
  /** Replaces every element with `cb(element)`. */
  map(cb) {
    for (let i = 0; i < this.l; i++) { this.buf[i] = cb(this.buf[i]); }
  }
  /** Sets every element to 0. */
  zeros() { this.buf.fill(0); }

  /**
   * Smallest element (Infinity for an empty matrix, NaN if any element is
   * NaN). Folded pairwise rather than spread into one Math.min call, which
   * would pass every element as a separate argument and can exceed the
   * engine's argument limit on large matrices.
   */
  getMin() { return this.buf.reduce((lo, v) => Math.min(lo, v), Infinity); }
  /** Largest element (-Infinity for an empty matrix); see getMin(). */
  getMax() { return this.buf.reduce((hi, v) => Math.max(hi, v), -Infinity); }

  /** Copies a nested array indexed as arr[y][x] into the matrix. */
  load2DArray(arr) {
    for (let x = 0; x < this.w; x++) {
      for (let y = 0; y < this.h; y++) {
        this.set(x, y, arr[y][x]);
      }
    }
  }

  /**
   * Loads 4-values-per-pixel image data (as produced by canvas
   * getImageData): each element becomes the mean of the pixel's first
   * three channels divided by 255. The fourth channel is ignored.
   */
  loadImageData(data) {
    for (let i = 0; i < data.length; i += 4) {
      this.buf[ftoi(i / 4)] = (data[i + 0] + data[i + 1] + data[i + 2]) / 3 / 255;
    }
  }

  /**
   * Writes the matrix into 4-values-per-pixel image data as grey levels:
   * each element is clamped to [0, 1], scaled to 0..255 and stored in the
   * pixel's first three channels. The fourth channel is left as it is.
   */
  unloadImageData(data) {
    for (let i = 0; i < data.length; i += 4) {
      let p = this.buf[ftoi(i / 4)];
      p = p > 1 ? 1 : p;
      p = p < 0 ? 0 : p;
      p = ftoi(p * 255);
      data[i + 0] = p;
      data[i + 1] = p;
      data[i + 2] = p;
    }
  }

  /**
   * Writes this matrix into `m`, repeating every cell as a scale x scale
   * square where scale = trunc(m.w / this.w). The same factor is used for
   * both axes.
   */
  upscaleDilate(m) {
    const scale = ftoi(m.w / this.w);
    for (let x = 0; x < this.w; x++) {
      for (let y = 0; y < this.h; y++) {
        const v = this.get(x, y);
        for (let i = 0; i < scale; i++) {
          for (let j = 0; j < scale; j++) {
            m.set((x * scale) + i, (y * scale) + j, v);
          }
        }
      }
    }
  }
}
187
188
189/* Additional: activation functions
190*/
191
/** Logistic sigmoid: 1 / (1 + e^-v). */
function Sigmoid(v) {
  const decay = Math.exp(-v);
  return 1 / (1 + decay);
}

/** Rectified linear unit: 0 for negative input, the input itself otherwise. */
function ReLU(v) {
  if (v < 0) {
    return 0;
  }
  return v;
}
194
195
196/* Conv2D
197*/
198
/**
 * Output length along one axis of a convolution/pooling window.
 *
 * @param inputSize   input length along the axis
 * @param padding     cells added on each side of the input
 * @param kernelSize  window length along the axis
 * @param stride      step between consecutive windows
 * @returns floor((inputSize + 2*padding - (kernelSize - 1) - 1) / stride + 1)
 */
function _conv2dsize(inputSize, padding, kernelSize, stride) {
  const paddedSize = inputSize + (padding * 2);
  const lastStart = paddedSize - (kernelSize - 1) - 1;
  return Math.floor(lastStart / stride + 1);
}
202
203
/**
 * Drives a 2D sliding-window operation: for every destination cell and
 * every kernel cell, calls
 *   cb(dx, dy, kx, ky, sx, sy)
 * where (dx, dy) is the destination cell, (kx, ky) the kernel cell and
 * (sx, sy) the matching source cell. With padding, (sx, sy) can lie outside
 * the source; handling that is left to the callback.
 *
 * @param src_w, src_h  source size
 * @param dst_w, dst_h  destination size; must equal the size implied by the
 *                      source size, kernel, padding and stride
 * @param kw, kh        kernel size
 * @param px, py        padding per axis (a missing value counts as 0)
 * @param sx, sy        stride per axis
 * @throws {Error} when the destination size does not match
 */
function conv2d_h(src_w, src_h, dst_w, dst_h, kw, kh, px, py, sx, sy, cb) {
  // Normalise once so the size check and the coordinates agree.
  const padX = px || 0;
  const padY = py || 0;

  const expectedW = _conv2dsize(src_w, padX, kw, sx);
  if (dst_w !== expectedW) {
    throw new Error(`Invalid Output Matrix Size: width is ${dst_w}, expected ${expectedW}`);
  }
  const expectedH = _conv2dsize(src_h, padY, kh, sy);
  if (dst_h !== expectedH) {
    throw new Error(`Invalid Output Matrix Size: height is ${dst_h}, expected ${expectedH}`);
  }

  for (let dx = 0; dx < dst_w; dx++) {
    for (let dy = 0; dy < dst_h; dy++) {
      for (let kx = 0; kx < kw; kx++) {
        for (let ky = 0; ky < kh; ky++) {
          cb(dx, dy, kx, ky, (dx * sx) + kx - padX, (dy * sy) + ky - padY);
        }
      }
    }
  }
}
213
214
/**
 * Single-channel 2D convolution layer.
 *
 * `kernel` is a nested array indexed as kernel[ky][kx]; `stride` defaults
 * to 1 and `padding` to 0, both applied equally to the two axes.
 */
class CNNConv2D {
  constructor(kernel, stride, padding) {
    this.kernel = kernel;
    this.stride = stride || 1;
    this.padding = padding || 0;
  }

  /**
   * Adds the convolution of `ms` with the kernel onto `md`. `md` is not
   * cleared first, so several input channels can be summed into one output.
   * Cells outside `ms` (the padding) contribute nothing.
   */
  forward(ms, md) {
    const kernel = this.kernel;
    conv2d_h(ms.w, ms.h, md.w, md.h, kernel[0].length, kernel.length,
      this.padding, this.padding, this.stride, this.stride,
      (dx, dy, kx, ky, sx, sy) => {
        // Check both axes explicitly: the matrix is addressed through one
        // flat index, so an out-of-range column with an in-range row would
        // otherwise read a cell from a neighbouring row instead of padding.
        if (sx < 0 || sx >= ms.w || sy < 0 || sy >= ms.h) return;
        md.set(dx, dy, md.get(dx, dy) + (ms.get(sx, sy) || 0) * kernel[ky][kx]);
      });
  }
}
230
231
232/* MaxPool2D
233*/
234
/**
 * 2D max-pooling layer with a `w` x `h` window. `stride` defaults to 1 and
 * `padding` to 0, both applied equally to the two axes.
 */
class CNNMaxPool2D {
  constructor(w, h, stride, padding) {
    this.w = w;
    this.h = h;
    this.stride = stride || 1;
    this.padding = padding || 0;
  }

  /**
   * Overwrites `md` with the window-wise maximum of `ms`. Cells outside
   * `ms` (the padding) count as 0.
   */
  forward(ms, md) {
    // Start every output below any value the window can contain.
    md.zeros();
    md.add(-1e31);
    conv2d_h(ms.w, ms.h, md.w, md.h, this.w, this.h,
      this.padding, this.padding, this.stride, this.stride,
      (dx, dy, kx, ky, sx, sy) => {
        // Check both axes explicitly: the matrix is addressed through one
        // flat index, so an out-of-range column with an in-range row would
        // otherwise read a cell from a neighbouring row instead of padding.
        const inside = sx >= 0 && sx < ms.w && sy >= 0 && sy < ms.h;
        const value = inside ? (ms.get(sx, sy) || 0) : 0;
        md.set(dx, dy, Math.max(md.get(dx, dy), value));
      });
  }
}
253
254
255/* Linear
256*/
257
/**
 * Fully connected layer without bias. `A` is the weight matrix as a nested
 * array indexed as A[output][input].
 */
class Linear {
  constructor (A) {
    this.A = A;
  }

  /**
   * Adds A * ms onto md, reading and writing the flat `buf` arrays directly.
   * `md` is not cleared first.
   */
  forward(ms, md) {
    const weights = this.A;
    const inputCount = weights[0].length;
    const outputCount = weights.length;
    for (let col = 0; col < inputCount; col++) {
      const activation = ms.buf[col];
      for (let row = 0; row < outputCount; row++) {
        md.buf[row] += activation * weights[row][col];
      }
    }
  }
}
271
272
273/* ==================== Canvas routines ==================== */
274
/** Paints a blurred white disc of radius 14 centred on (x, y) on the input canvas. */
function canvasSubRoutineDrawDot(x, y) {
  const pen = cvInCtx;
  const fullTurn = 2 * Math.PI;
  pen.fillStyle = "white";
  pen.filter = "blur(5px)";
  pen.beginPath();
  pen.arc(x, y, 14, 0, fullTurn);
  pen.fill();
}
282
283
/**
 * Strokes a blurred white segment from (x1, y1) to (x2, y2) on the input
 * canvas, 28 wide with round caps.
 */
function canvasSubRoutineDrawLine(x1, y1, x2, y2) {
  const pen = cvInCtx;

  // Stroke appearance.
  pen.strokeStyle = "white";
  pen.filter = "blur(5px)";
  pen.lineWidth = 28;
  pen.lineCap = "round";

  // The segment itself.
  pen.beginPath();
  pen.moveTo(x1, y1);
  pen.lineTo(x2, y2);
  pen.stroke();
}
294
295
/**
 * Draws `text` on the viewport canvas in black 32px monospace, vertically
 * centred on `y`.
 *
 * @param x, y   anchor point
 * @param text   string to draw
 * @param align  canvas textAlign value; defaults to "left"
 */
function canvasSubRoutineDrawText(x, y, text, align) {
  cvPtCtx.save();
  cvPtCtx.fillStyle = "black";
  cvPtCtx.font = "32px monospace";
  cvPtCtx.textBaseline = "middle";
  cvPtCtx.textAlign = align || "left";
  cvPtCtx.fillText(text, x, y);
  // Pair the save() above; without this every call would leave one more
  // entry on the context's state stack.
  cvPtCtx.restore();
}
304
305
/**
 * Renders a matrix as a grey-scale tile with a blue frame on the viewport
 * canvas.
 *
 * @param mat_src    matrix to show
 * @param mat_tmp    scratch matrix; its size is the tile size in pixels and
 *                   its contents are overwritten
 * @param x, y       top-left corner of the tile
 * @param normalize  when true (the default) values are rescaled so the
 *                   smallest maps to black and the largest to white
 */
function canvasSubRoutineDrawMatrix(mat_src, mat_tmp, x, y, normalize = true) {
  const imd = cvPtCtx.getImageData(x, y, mat_tmp.w, mat_tmp.h);
  Interp.binearest_m(mat_src, mat_tmp);
  if (normalize) {
    const mi = mat_tmp.getMin();
    const ma = mat_tmp.getMax();
    mat_tmp.add(-mi);
    // A constant matrix has no range to stretch; skipping the scale avoids
    // multiplying by 1/0 and leaves it all zero (black).
    if (ma > mi) {
      mat_tmp.mul(1 / (ma - mi));
    }
  }
  mat_tmp.unloadImageData(imd.data);
  cvPtCtx.putImageData(imd, x, y);
  cvPtCtx.strokeStyle = "blue";
  // Start a fresh path so stroke() draws only this frame instead of also
  // re-stroking rectangles left in the path by earlier calls.
  cvPtCtx.beginPath();
  cvPtCtx.rect(x - 2, y - 2, mat_tmp.w + 4, mat_tmp.h + 4);
  cvPtCtx.stroke();
}
321
322
/**
 * Draws a one-row matrix as a bar chart on the viewport canvas: a white
 * box split into `mat_src.w` outlined cells, each holding a blue bar whose
 * height is the cell's value times `h`.
 *
 * @param mat_src    matrix whose first `w` buffer entries are plotted
 * @param x, y       top-left corner of the chart
 * @param w, h       chart size
 * @param draw_text  when truthy, the cell index is printed under each bar
 */
function canvasSubRoutineDrawDense(mat_src, x, y, w, h, draw_text) {
  const cellCount = mat_src.w;
  const r = w / cellCount;
  const labelled = Boolean(draw_text);

  cvPtCtx.save();
  cvPtCtx.fillStyle = "white";
  cvPtCtx.strokeStyle = "black";

  // Background box and cell outlines share one path.
  cvPtCtx.beginPath();
  cvPtCtx.rect(x, y, w, h);
  cvPtCtx.fill();
  for (let cell = 0; cell < cellCount; cell++) {
    cvPtCtx.rect(x + r * cell, y, r, h);
  }
  cvPtCtx.stroke();

  for (let cell = 0; cell < cellCount; cell++) {
    // Gap between the top of the cell and the top of the bar.
    const hh = h * (1 - mat_src.buf[cell]);
    // Set per bar: drawing a label changes the fill style.
    cvPtCtx.fillStyle = "blue";
    cvPtCtx.fillRect(x + r * cell, y + hh, r, h - hh);
    if (labelled) {
      canvasSubRoutineDrawText(x + r * cell + r / 2, y + h + 32, String(cell), 'center');
    }
  }
  cvPtCtx.restore();
}
345
346
/** Prints the static caption for every stage of the network on the viewport canvas. */
function netDrawAnnotation() {
  // One entry per caption: [x, y, text, alignment].
  const captions = [
    [150, 30, "input(512x512)", "center"],
    [150, 60, "downsamp(28x28)", "center"],
    [300, 80 - 20, "C1: Conv2D 5x5pad2 (28x28) x6 (Sigmoid)", "left"],
    [300, 260 - 20, "S2: MaxPool 2x2stride2 (14x14) x6", "left"],
    [30, 470 - 20, "C3: Conv2D 5x5 (10x10) x16 (Sigmoid)", "left"],
    [30, 800 - 20, "S4: MaxPool 2x2stride2 (5x5) x16", "left"],
    [30, 1076 + 260 * 0, "C5: Conv2D 5x5 (1x1) x120 Flatten (Sigmoid)", "left"],
    [30, 1076 + 260 * 1, "F6: Linear (84x120) (Sigmoid)", "left"],
    [30, 1076 + 260 * 2, "F0: Linear (10x84) (Sigmoid)", "left"],
  ];
  for (const [x, y, text, align] of captions) {
    canvasSubRoutineDrawText(x, y, text, align);
  }
}
358
359
/**
 * Clears both canvases: the input canvas is filled with black, the viewport
 * canvas with white, and the stage captions are redrawn on the viewport.
 */
function appInCanvasReset() {
  cvInCtx.fillStyle = "black";
  cvInCtx.clearRect(0, 0, cvIn.width, cvIn.height);
  cvInCtx.fillRect(0, 0, cvIn.width, cvIn.height);

  cvPtCtx.save();
  cvPtCtx.fillStyle = "white";
  cvPtCtx.clearRect(0, 0, cvPt.width, cvPt.height);
  cvPtCtx.fillRect(0, 0, cvPt.width, cvPt.height);
  // Pair the save() above so each reset does not leave one more entry on
  // the context's state stack. Done before the captions are drawn so the
  // pairing does not depend on what the caption code saves or restores.
  cvPtCtx.restore();

  netDrawAnnotation();
}
370
371
372/* temporary buffers
373*/
374
// Scratch matrices used while drawing. Each key is "m2" followed by the
// zero-padded side length of the square matrix it holds.
const tmp_buf = {
  m2080: new cFloat32Matrix(80, 80),
  m2100: new cFloat32Matrix(100, 100),
  m2112: new cFloat32Matrix(112, 112),
  m2120: new cFloat32Matrix(120, 120),
  m2240: new cFloat32Matrix(240, 240),
  m2001: new cFloat32Matrix(1, 1),
};
394
395/* layers and main buffers
396*/
397
// Declared explicitly: assigning to undeclared names creates implicit
// globals and is a ReferenceError in strict mode and in modules.

// Network layers, filled in by netInit().
const layers = {};
// Activation buffers; the per-stage entries are added by netInit().
const buffers = {};
// Grey-scale copy of the input canvas, one element per pixel.
buffers.input = new cFloat32Matrix(cvIn.width, cvIn.height);
// 28x28 downsampled input fed to the first layer.
buffers.ds28 = new cFloat32Matrix(28, 28);
402
/**
 * Allocates the per-stage activation buffers and builds the layer objects
 * from the loaded weights in `data`.
 *
 * Stages: c1 (6 maps, 28x28), s2 (6, 14x14), c3 (16, 10x10), s4 (16, 5x5),
 * c5 (one 120x1 row), f6 (one 84x1 row), fo (one 10x1 row).
 */
function netInit() {
  // Every pooling stage shares this one 2x2, stride 2, no-padding layer.
  const pool2d2x2 = new CNNMaxPool2D(2, 2, 2, 0);

  // Builds { 0: matrix, 1: matrix, ... } with `count` matrices of w x h.
  const allocate = (count, w, h) => {
    const group = {};
    for (let n = 0; n < count; n++) {
      group[n] = new cFloat32Matrix(w, h);
    }
    return group;
  };

  buffers.c1 = allocate(6, 28, 28);
  buffers.s2 = allocate(6, 14, 14);
  buffers.c3 = allocate(16, 10, 10);
  buffers.s4 = allocate(16, 5, 5);
  buffers.c5 = allocate(1, 120, 1);
  buffers.f6 = allocate(1, 84, 1);
  buffers.fo = allocate(1, 10, 1);

  // C1: one kernel per output map (stride 1, padding 2); S2 pools each map.
  layers.c1 = {};
  layers.s2 = {};
  for (let i = 0; i < 6; i++) {
    layers.c1[i] = {
      conv2d: new CNNConv2D(data['c1_conv.weight'][i][0], 1, 2),
      bias: data['c1_conv.bias'][i],
    };
    layers.s2[i] = { pool2d: pool2d2x2 };
  }

  // C3: output map j has one kernel per S2 map i, stored at layers.c3[j][i].
  layers.c3 = {};
  layers.s4 = {};
  for (let j = 0; j < 16; j++) {
    const stage = { bias: data['c3_conv.bias'][j] };
    layers.c3[j] = stage;
    layers.s4[j] = { pool2d: pool2d2x2 };
    for (let i = 0; i < 6; i++) {
      stage[i] = { conv2d: new CNNConv2D(data['c3_conv.weight'][j][i], 1, 0) };
    }
  }

  // C5: output k has one kernel per S4 map j, stored at layers.c5[k][j].
  layers.c5 = {};
  for (let k = 0; k < 120; k++) {
    const stage = { bias: data['c5_conv.bias'][k] };
    layers.c5[k] = stage;
    for (let j = 0; j < 16; j++) {
      stage[j] = { conv2d: new CNNConv2D(data['c5_conv.weight'][k][j], 1, 0) };
    }
  }

  // Fully connected layers.
  layers.f6 = {
    linear: new Linear(data['f6_linr.weight']),
    bias: data['f6_linr.bias'],
  };
  layers.fo = {
    linear: new Linear(data['fo_linr.weight']),
    bias: data['fo_linr.bias'],
  };
}
456
457
/**
 * Runs the network on `buffers.ds28` and draws every intermediate result on
 * the viewport canvas. Results are left in `buffers`; `buffers.fo[0]` holds
 * the ten output activations.
 */
function netForward() {
  // Clear every activation buffer: the convolution and linear layers add
  // onto whatever their destination already holds.
  const stageSizes = [["c1", 6], ["s2", 6], ["c3", 16], ["s4", 16], ["c5", 1], ["f6", 1], ["fo", 1]];
  for (const [stage, count] of stageSizes) {
    for (let n = 0; n < count; n++) {
      buffers[stage][n].zeros();
    }
  }

  // C1 (convolution + bias + sigmoid) followed by S2 (pooling), per map.
  for (let i = 0; i < 6; i++) {
    const c1 = buffers.c1[i];
    const s2 = buffers.s2[i];
    const tileX = 240 + 30 + 25 + (122 * i);
    layers.c1[i].conv2d.forward(buffers.ds28, c1);
    c1.add(layers.c1[i].bias);
    c1.map(Sigmoid);
    layers.s2[i].pool2d.forward(c1, s2);
    s2.upscaleDilate(tmp_buf.m2112);
    canvasSubRoutineDrawMatrix(c1, tmp_buf.m2120, tileX, 80);
    canvasSubRoutineDrawMatrix(tmp_buf.m2112, tmp_buf.m2120, tileX, 260);
  }

  // C3: every output map j sums one convolution per S2 map i.
  for (let i = 0; i < 6; i++) {
    const s2 = buffers.s2[i];
    for (let j = 0; j < 16; j++) {
      layers.c3[j][i].conv2d.forward(s2, buffers.c3[j]);
    }
  }

  // C3 bias + sigmoid, then S4 pooling; tiles are laid out 8 per row.
  for (let j = 0; j < 16; j++) {
    const c3 = buffers.c3[j];
    const s4 = buffers.s4[j];
    const tileX = 30 + (122 * (j % 8));
    const rowShift = 122 * (j > 7 ? 1 : 0);
    c3.add(layers.c3[j].bias);
    c3.map(Sigmoid);
    layers.s4[j].pool2d.forward(c3, s4);
    s4.upscaleDilate(tmp_buf.m2080);
    canvasSubRoutineDrawMatrix(c3, tmp_buf.m2100, tileX, 470 + rowShift);
    canvasSubRoutineDrawMatrix(tmp_buf.m2080, tmp_buf.m2100, tileX, 800 + rowShift);
  }

  // C5: each of the 120 outputs is a 1x1 result summed over the 16 S4
  // maps, computed in a scratch cell and copied into the c5 row.
  const cell = tmp_buf.m2001;
  const c5 = buffers.c5[0];
  for (let k = 0; k < 120; k++) {
    cell.zeros();
    for (let j = 0; j < 16; j++) {
      layers.c5[k][j].conv2d.forward(buffers.s4[j], cell);
    }
    cell.add(layers.c5[k].bias);
    c5.buf[k] = cell.buf[0];
  }
  c5.map(Sigmoid);

  // F6: linear + bias + sigmoid.
  const f6 = buffers.f6[0];
  layers.f6.linear.forward(c5, f6);
  for (let i = 0; i < 84; i++) {
    f6.buf[i] += layers.f6.bias[i];
  }
  f6.map(Sigmoid);

  // Output layer: linear + bias + sigmoid.
  const fo = buffers.fo[0];
  layers.fo.linear.forward(f6, fo);
  for (let i = 0; i < 10; i++) {
    fo.buf[i] += layers.fo.bias[i];
  }
  fo.map(Sigmoid);

  // Bar charts for the three dense stages; only the last one is labelled.
  const chartWidth = cvPt.width - 30 - 30;
  canvasSubRoutineDrawDense(c5, 30, 1100 + 260 * 0, chartWidth, 200);
  canvasSubRoutineDrawDense(f6, 30, 1100 + 260 * 1, chartWidth, 200);
  canvasSubRoutineDrawDense(fo, 30, 1100 + 260 * 2, chartWidth, 200, true);
}
511
512
513/* ==================== Entry point ==================== */
514
/**
 * Attaches a cursor handler to the input canvas. While the pointer is
 * pressed, input is drawn as a dot (no movement since the previous event)
 * or a line (movement); when the handler reports idle, the network is
 * re-evaluated.
 */
function connectCanvasCursorInput() {
  const onCursorEvent = (h) => {
    const { down, move, idle } = h.flags;
    const pressed = Boolean(h.flag & down) || h.flag === (down | move);
    if (pressed) {
      const { prev, curr } = h.pos;
      const stationary = curr.x === prev.x && curr.y === prev.y;
      if (stationary) {
        canvasSubRoutineDrawDot(curr.x, curr.y);
      } else {
        canvasSubRoutineDrawLine(prev.x, prev.y, curr.x, curr.y);
      }
    }
    if (h.flag === idle) {
      appUpdate();
    }
  };
  cursorHandler = new cCursorHandler(cvIn, onCursorEvent);
}
530
/** Detaches the current cursor handler, if any, and forgets it. */
function disconnectCanvasCursorInput() {
  if (cursorHandler === null) {
    return;
  }
  cursorHandler.detach();
  cursorHandler = null;
}
537
/**
 * Starts (or restarts) the app: re-attaches pointer input, clears both
 * canvases, builds the network and runs it once.
 */
function appInit() {
  // Executed strictly in this order.
  const startupSteps = [
    disconnectCanvasCursorInput,
    connectCanvasCursorInput,
    appInCanvasReset,
    netInit,
    appUpdate,
  ];
  for (const step of startupSteps) {
    step();
  }
}
545
546
/**
 * Reads the input canvas, downsamples it into the 28x28 network input,
 * draws that input on the viewport and runs the network.
 */
function appUpdate() {
  const { input, ds28 } = buffers;
  const { data: pixels } = cvInCtx.getImageData(0, 0, cvIn.width, cvIn.height);
  input.loadImageData(pixels);
  Interp.bilinear_m(input, ds28);

  // Preview of the downsampled input: enlarged to 112x112 first, then
  // drawn through the 240x240 scratch matrix.
  const preview = tmp_buf.m2112;
  ds28.upscaleDilate(preview);
  canvasSubRoutineDrawMatrix(preview, tmp_buf.m2240, 30, 80);

  netForward();
}
555
556
// Load the network weights, then wire up the reset button and start the app.
fetch('./assets/LeNet5.json')
  .then((resp) => {
    console.log(resp.status, resp.url);
    // fetch() only rejects on network failure; an HTTP error such as 404
    // still resolves, so it has to be turned into an error here.
    if (!resp.ok) {
      throw new Error(`Failed to load ${resp.url}: HTTP ${resp.status}`);
    }
    // Returned so that parse failures reach the single catch below.
    return resp.json();
  })
  .then((d) => {
    data = d;
    document.getElementById("appInCanvasReset").addEventListener("click", () => {
      appInCanvasReset();
      appUpdate();
    });
    appInit();
  })
  .catch((err) => {
    console.error(err);
  });
569
570
571}