TensorFlow.js is a way to run TensorFlow models in JavaScript — in other words, directly in your browser. The library is large, but not as large as the Python TensorFlow itself. To use it, we first load the 1.2 MB JS file from the CDN anywhere in the HTML:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.12.0/dist/tf.min.js" integrity="sha256-Yl5oUVtHQ3wqFAPCSZmKxzSb/uZt+xzdT9mDPwwNYbk=" crossorigin="anonymous"></script>

and then a global JavaScript object tf is loaded. Next we need to run the following in JavaScript:

// Load the converted model and stash it on the window for later use.
// The .then() callback fires once the JSON graph and weight shards finish loading.
tf.loadLayersModel("modelpath/model.json").then(function(model) {
	window.model = model;
});

where modelpath/model.json is a path relative to the current HTML. It is generated by a converter that comes with TensorFlow.js. The key here is the JavaScript promise function then(), which will assign the model to the current window’s property. It is just a convention to call this property model, and obviously we can name it something else, especially if there are multiple models to load.

The way the loaded model should be invoked for prediction is:

	model.predict(inputTensor).array().then(function(output) {
		// process the prediction output here
	});

The input should be converted into a tensor by tf.tensor() function, and often it should also be reshaped to an appropriate dimension for the model. The model.predict() function will take time to run, hence a promise function should be created as well to process the output.

So how should we create the model in the first place? It is natural to develop the model in Python, as that is more convenient for experimentation and refinement. As an example, we can train LeNet-5 for MNIST handwritten digit recognition:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Dropout, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Reshape data to shape of (n_sample, height, width, n_channel)
X_train = np.expand_dims(X_train, axis=3).astype('float32')
X_test = np.expand_dims(X_test, axis=3).astype('float32')

# One-hot encode the output
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# LeNet5 model: two conv+pool stages, a conv layer acting as the first
# fully connected stage, then two dense layers for classification
model = Sequential([
    Conv2D(6, (5,5), input_shape=(28,28,1), padding="same", activation="tanh"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(16, (5,5), activation="tanh"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(120, (5,5), activation="tanh"),
    Flatten(),  # collapse the 1x1x120 feature map into a vector for the Dense layers
    Dense(84, activation="tanh"),
    Dense(10, activation="softmax")
])

# Training: cross-entropy on the one-hot labels, early stop on validation loss
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
earlystopping = EarlyStopping(monitor="val_loss", patience=4, restore_best_weights=True)
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, batch_size=32, callbacks=[earlystopping])

# Save in HDF5 format for conversion with tensorflowjs_converter
model.save("lenet5.h5")

This code will train and save the LeNet5 model in HDF5 format. For tensorflow.js, we need to install some tools:

pip install tensorflowjs

This will install the Python tools for tensorflow.js and then we can run this to do the conversion:

tensorflowjs_converter --input_format keras lenet5.h5 lenetjsmodel

The format must be keras if we have the Keras model saved in HDF5 format using the save() function (the keras_saved_model format is for the TensorFlow SavedModel directory format instead). The last argument is the directory name for the tensorflow.js model. This command will produce the following files:

	lenetjsmodel/model.json
	lenetjsmodel/group1-shard1ofN.bin  (one or more binary weight shards)

and the model.json file is what you provide as the argument to tf.loadLayersModel().

As an example, this is what you would do to implement this on a web page, which uses HTML5 canvas for the handwritten digit:

<!doctype html>
<html lang="en">
<head>
	<title>MNIST Recognition</title>
	<style>
	#container {
		border: 3px solid #fff;
		padding: 10px;
		width: 655px;
		margin: 0 auto; /* center */
	}
	#canvas, #result {
		width: 300px;
		height: 300px;
		margin: auto;
		border: 3px solid #7f7f7f;
		float: left;
		padding: 10px;
		font-size: 120px;
		text-align: center;
		vertical-align: middle;
	}
	#reset {
		padding: 10px;
		text-align: center;
	}
	#button {
		clear: both;
		text-align: center;
	}
	h1 {
		margin: 10px;
		text-align: center;
	}
	</style>
	<script src="https://code.jquery.com/jquery-3.6.0.min.js" integrity="sha256-/xUj+3OJU5yExlq6GSYGSHk7tPXikynS7ogEvDej/m4=" crossorigin="anonymous"></script>
	<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.12.0/dist/tf.min.js" integrity="sha256-Yl5oUVtHQ3wqFAPCSZmKxzSb/uZt+xzdT9mDPwwNYbk=" crossorigin="anonymous"></script>
</head>
<body>
	<h1>MNIST tfjs test</h1>
	<div id="container">
		<canvas id="canvas"></canvas>
		<div id="result"></div>
	</div>
	<div id="button">
		<button id="reset">Reset</button>
	</div>
	<div id="debug">
		<span id="lastinput"></span>
		<span id="lastresult"></span>
	</div>

	<script>
	// Load tensorflow model; the promise callback stores it on the window
	tf.loadLayersModel("lenetjsmodel/model.json").then(function(model) {
		window.model = model;
	});
	// Run the model on a flat array of 28*28 grayscale values in [0,1]
	var predict = function(input) {
		if (window.model) {
			window.model.predict(
				// model expects a 4-D tensor: (batch, height, width, channel)
				tf.tensor(input).reshape([1, 28, 28, 1])
			).array().then(function(scores) {
				scores = scores[0]; // convert 2D output into 1D
				$("#lastresult").html(scores.map(function(x){return Number(x.toFixed(3))}).toString());
				var predicted = scores.indexOf(Math.max(...scores));
				$("#result").html(predicted);
			});
		} else {
			// didn't have the model loaded yet? try again 30ms later
			setTimeout(function() {
				predict(input);
			}, 30);
		}
	};
	// Trigger drawing on canvas
	var canvas = document.getElementById("canvas");
	var computedStyle = getComputedStyle(canvas);
	// Sync the drawing buffer with the CSS size so strokes are not scaled
	canvas.width = parseInt(computedStyle.getPropertyValue("width"));
	canvas.height = parseInt(computedStyle.getPropertyValue("height"));
	var context = canvas.getContext("2d");  // to remember drawing
	context.strokeStyle = "#FF0000"; // draw in bright red
	context.lineWidth = 20; // Will downsize to 28x28, so must be thick enough
	var mouse = {x:0, y:0}; // to remember the coordinate w.r.t. canvas
	var onPaint = function() {
		// event handler for mousemove in canvas: extend and draw the stroke
		context.lineTo(mouse.x, mouse.y);
		context.stroke();
	};
	$("#reset").click(function() {
		// on button click, clear the canvas and result
		context.clearRect(0, 0, canvas.width, canvas.height);
		$("#result").html("");
	});
	// HTML5 Canvas mouse event
	canvas.addEventListener("mousedown", function(e) {
		// mousedown, begin path at mouse position
		context.beginPath();
		context.moveTo(mouse.x, mouse.y);
		canvas.addEventListener("mousemove", onPaint, false);
	}, false);
	canvas.addEventListener("mousemove", function(e) {
		// mousemove remember position w.r.t. canvas
		mouse.x = e.pageX - this.offsetLeft;
		mouse.y = e.pageY - this.offsetTop;
	}, false);
	canvas.addEventListener("mouseup", function(e) {
		// Stop canvas from further update, then read drawing into image
		canvas.removeEventListener("mousemove", onPaint, false);
		var img = new Image(); // on load, this will be the canvas in same WxH
		img.onload = function() {
			// Draw this to 28x28 at top left corner of canvas so we can extract it back
			context.drawImage(img, 0, 0, 28, 28);
			// Extract data: Each pixel becomes a RGBA value, hence 4 bytes each
			var data = context.getImageData(0, 0, 28, 28).data;
			var input = [];
			for (var i=0; i<data.length; i += 4) {
				// scan each pixel, extract first byte (R component), scale to [0,1]
				input.push(data[i] / 255);
			}
			// Show the 28x28 input matrix in the debug pane, one row per line
			var debug = [];
			for (var i=0; i<input.length; i+=28) {
				debug.push(input.slice(i, i+28).toString());
			}
			$("#lastinput").html(debug.join("<br/>"));
			predict(input);
		};
		img.src = canvas.toDataURL("image/png"); // convert canvas to img and trigger onload()
	}, false);
	</script>
</body>
</html>