如何使用 Node.js 部署机器学习模型?
在本文中,我们将学习如何使用NodeJS部署机器学习模型。在这样做的同时,我们将使用NodeJS和tensorflow.js制作一个简单的手写数字识别器。
Tensorflow.js是一个用于 JavaScript 的机器学习库。它有助于将机器学习模型直接部署到 node.js 或 Web 浏览器中。
训练模型:为了训练模型,我们将使用 Google Colab。它是一个平台,我们可以在其中运行我们所有的Python代码,并且它加载了大多数使用的机器学习库。
下面是我们将创建的最终模型的代码。
Python
# Importing Libraries
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from keras.utils import np_utils
from keras import Sequential
from keras.layers import Dense
import tensorflowjs as tfjs
# Loading data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print ("X_train.shape: {}".format(X_train.shape))
print ("y_train.shape: {}".format(y_train.shape))
print ("X_test.shape: {}".format(X_test.shape))
print ("y_test.shape: {}".format(y_test.shape))
# Visualizing Data
plt.subplot(161)
plt.imshow(X_train[3], cmap=plt.get_cmap('gray'))
plt.subplot(162)
plt.imshow(X_train[5], cmap=plt.get_cmap('gray'))
plt.subplot(163)
plt.imshow(X_train[7], cmap=plt.get_cmap('gray'))
plt.subplot(164)
plt.imshow(X_train[2], cmap=plt.get_cmap('gray'))
plt.subplot(165)
plt.imshow(X_train[0], cmap=plt.get_cmap('gray'))
plt.subplot(166)
plt.imshow(X_train[13], cmap=plt.get_cmap('gray'))
plt.show()
# Normalize Inputs from 0–255 to 0–1
X_train = X_train / 255
X_test = X_test / 255
# One-Hot Encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = 10
# Training model
x_train_simple = X_train.reshape(60000, 28 * 28).astype('float32')
x_test_simple = X_test.reshape(10000, 28 * 28).astype('float32')
model = Sequential()
model.add(Dense(28 * 28, input_dim=28 * 28, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adam', metrics=['accuracy'])
model.fit(x_train_simple, y_train,
validation_data=(x_test_simple, y_test))
index.js
// Requiring module
const express = require("express");
const app = express();
const path = require("path")
// Set public as static directory
app.use(express.static('public'));
app.set('views', path.join(__dirname, '/views'))
// Use ejs as template engine
app.set('view engine', 'ejs');
app.use(express.json());
app.use(express.urlencoded({ extended: false }));
// Render main template
app.get('/',(req,res)=>{
res.render('main')
})
// Server setup
app.listen(3000, () => {
console.log("The server started running on port 3000")
});
main.ejs
Digit Recognition WebApp
Recognized digit
style.css
body {
touch-action: none;
font-family: "Roboto";
}
h1 {
margin: 50px;
font-size: 70px;
text-align: center;
}
#paint {
border:3px solid red;
margin: auto;
}
#predicted {
font-size: 60px;
margin-top: 60px;
text-align: center;
}
#number {
border: 3px solid black;
margin: auto;
margin-top: 30px;
text-align: center;
vertical-align: middle;
}
#clear {
margin: auto;
margin-top: 70px;
padding: 30px;
text-align: center;
}
script.js
canvas.addEventListener('mousedown', function (e) {
context.moveTo(mouse.x, mouse.y);
context.beginPath();
canvas.addEventListener('mousemove', onPaint, false);
}, false); var onPaint = function () {
context.lineTo(mouse.x, mouse.y);
context.stroke();
};
canvas.addEventListener('mouseup', function () {
$('#number').html('');
canvas.removeEventListener('mousemove', onPaint, false);
var img = new Image();
img.onload = function () {
context.drawImage(img, 0, 0, 28, 28);
data = context.getImageData(0, 0, 28, 28).data;
var input = [];
for (var i = 0; i < data.length; i += 4) {
input.push(data[i + 2] / 255);
}
predict(input);
};
img.src = canvas.toDataURL('image/png');
}, false);
// Setting up tfjs with the model we downloaded
tf.loadLayersModel('model / model.json')
.then(function (model) {
window.model = model;
});
// Predict function
var predict = function (input) {
if (window.model) {
window.model.predict([tf.tensor(input)
.reshape([1, 28, 28, 1])])
.array().then(function (scores) {
scores = scores[0];
predicted = scores
.indexOf(Math.max(...scores));
$('#number').html(predicted);
});
} else {
// The model takes a bit to load,
// if we are too fast, wait
setTimeout(function () { predict(input) }, 50);
}
}
第 1 步:训练数据
为了训练模型,我们将使用 MNIST 数据库。这是一个免费的大型手写数字数据库。在该数据集中,有 60,000 张图像,所有图像的灰度大小均为 28 x 28 像素,像素值从 0 到 255。
第 2 步:数据预处理
我们执行以下步骤来处理我们的数据:
- 标准化输入:输入范围为 0-255。我们需要将它们缩放到 0-1。
- 一个热编码输出
第 3 步:机器学习
为了训练模型,我们使用一个简单的神经网络,它有一个隐藏层,足以提供大约 98% 的准确率。
第 4 步:使用 tensorflow.js 转换模型
首先,使用以下命令保存模型:
model.save(“model.h5”)
然后安装tensorflow.js并使用以下命令转换模型:
!pip install tensorflowjs
!tensorflowjs_converter --input_format keras
‘/content/model.h5’ ‘/content/mnist-model’
运行上述命令后,现在刷新文件。您的内容应如下所示。
注意:下载 mnist-model文件夹,我们稍后将使用它。
创建 Express App 并安装模块:
第 1 步:使用以下命令创建package.json :
npm init
第2步:现在按照命令安装依赖项。我们将使用 express for server 和 ejs 作为模板引擎。
npm install express ejs
项目结构:现在确保您具有以下文件结构。将我们从 Colab 下载的文件复制到模型文件夹中。
现在将以下代码写到您的index.js文件中。
index.js
// Requiring module
const express = require("express");
const app = express();
const path = require("path")
// Set public as static directory
app.use(express.static('public'));
app.set('views', path.join(__dirname, '/views'))
// Use ejs as template engine
app.set('view engine', 'ejs');
app.use(express.json());
app.use(express.urlencoded({ extended: false }));
// Render main template
app.get('/',(req,res)=>{
res.render('main')
})
// Server setup
app.listen(3000, () => {
console.log("The server started running on port 3000")
});
主.ejs
Digit Recognition WebApp
Recognized digit
样式.css
body {
touch-action: none;
font-family: "Roboto";
}
h1 {
margin: 50px;
font-size: 70px;
text-align: center;
}
#paint {
border:3px solid red;
margin: auto;
}
#predicted {
font-size: 60px;
margin-top: 60px;
text-align: center;
}
#number {
border: 3px solid black;
margin: auto;
margin-top: 30px;
text-align: center;
vertical-align: middle;
}
#clear {
margin: auto;
margin-top: 70px;
padding: 30px;
text-align: center;
}
编写脚本:我们使用 HTML5 画布定义鼠标事件。然后我们将鼠标上的图像放大并将其缩放为 28×28 像素,使其与我们的模型相匹配,然后将其传递给我们的预测函数。
脚本.js
canvas.addEventListener('mousedown', function (e) {
context.moveTo(mouse.x, mouse.y);
context.beginPath();
canvas.addEventListener('mousemove', onPaint, false);
}, false); var onPaint = function () {
context.lineTo(mouse.x, mouse.y);
context.stroke();
};
canvas.addEventListener('mouseup', function () {
$('#number').html('');
canvas.removeEventListener('mousemove', onPaint, false);
var img = new Image();
img.onload = function () {
context.drawImage(img, 0, 0, 28, 28);
data = context.getImageData(0, 0, 28, 28).data;
var input = [];
for (var i = 0; i < data.length; i += 4) {
input.push(data[i + 2] / 255);
}
predict(input);
};
img.src = canvas.toDataURL('image/png');
}, false);
// Setting up tfjs with the model we downloaded
tf.loadLayersModel('model / model.json')
.then(function (model) {
window.model = model;
});
// Predict function
var predict = function (input) {
if (window.model) {
window.model.predict([tf.tensor(input)
.reshape([1, 28, 28, 1])])
.array().then(function (scores) {
scores = scores[0];
predicted = scores
.indexOf(Math.max(...scores));
$('#number').html(predicted);
});
} else {
// The model takes a bit to load,
// if we are too fast, wait
setTimeout(function () { predict(input) }, 50);
}
}
运行应用程序的步骤:使用以下命令运行index.js文件。
node index.js
输出:打开浏览器并 去 http://localhost:3000 ,我们将看到以下输出。