📜  如何使用 Node.js 部署机器学习模型?

📅  最后修改于: 2022-05-13 01:56:55.268000             🧑  作者: Mango

如何使用 Node.js 部署机器学习模型?

在本文中,我们将学习如何使用NodeJS部署机器学习模型。在这样做的同时,我们将使用NodeJStensorflow.js制作一个简单的手写数字识别器

Tensorflow.js是一个用于 JavaScript 的机器学习库。它有助于将机器学习模型直接部署到 node.js 或 Web 浏览器中。

训练模型:为了训练模型,我们将使用 Google Colab。它是一个平台,我们可以在其中运行我们所有的Python代码,并且它加载了大多数使用的机器学习库。

下面是我们将创建的最终模型的代码。

Python
# Importing Libraries
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from keras.utils import np_utils
from keras import Sequential
from keras.layers import Dense
import tensorflowjs as tfjs
  
# Loading data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print ("X_train.shape: {}".format(X_train.shape))
print ("y_train.shape: {}".format(y_train.shape))
print ("X_test.shape: {}".format(X_test.shape))
print ("y_test.shape: {}".format(y_test.shape))
  
# Visualizing Data
plt.subplot(161)
plt.imshow(X_train[3], cmap=plt.get_cmap('gray'))
plt.subplot(162)
plt.imshow(X_train[5], cmap=plt.get_cmap('gray'))
plt.subplot(163)
plt.imshow(X_train[7], cmap=plt.get_cmap('gray'))
plt.subplot(164)
plt.imshow(X_train[2], cmap=plt.get_cmap('gray'))
plt.subplot(165)
plt.imshow(X_train[0], cmap=plt.get_cmap('gray'))
plt.subplot(166)
plt.imshow(X_train[13], cmap=plt.get_cmap('gray'))
  
plt.show()
  
# Normalize Inputs from 0–255 to 0–1
X_train = X_train / 255
X_test = X_test / 255
# One-Hot Encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = 10
  
# Training model
x_train_simple = X_train.reshape(60000, 28 * 28).astype('float32')
x_test_simple = X_test.reshape(10000, 28 * 28).astype('float32')
model = Sequential()
model.add(Dense(28 * 28, input_dim=28 * 28, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', 
        optimizer='adam', metrics=['accuracy'])
model.fit(x_train_simple, y_train, 
        validation_data=(x_test_simple, y_test))


index.js
// Requiring module
const express = require("express");
const app = express();
const path = require("path")
  
// Set public as static directory
app.use(express.static('public'));
  
app.set('views', path.join(__dirname, '/views'))
  
// Use ejs as template engine
app.set('view engine', 'ejs');
app.use(express.json());
app.use(express.urlencoded({ extended: false }));
  
// Render main template
app.get('/',(req,res)=>{
    res.render('main')
})
  
// Server setup
app.listen(3000, () => {
  console.log("The server started running on port 3000") 
});


main.ejs


  

    
    
    

  

    

Digit Recognition WebApp

       
             
       
        Recognized digit         
             
              


style.css
body {
    touch-action: none; 
    font-family: "Roboto";
}
h1 {
    margin: 50px;
    font-size: 70px;
    text-align: center;
}
#paint {
    border:3px solid red;
    margin: auto;
}
#predicted { 
    font-size: 60px;
    margin-top: 60px;
    text-align: center;
}
#number {
    border: 3px solid black;
    margin: auto;
    margin-top: 30px;
    text-align: center;
    vertical-align: middle;
}
#clear {
    margin: auto;
    margin-top: 70px;
    padding: 30px;
    text-align: center;
}


script.js
canvas.addEventListener('mousedown', function (e) {
    context.moveTo(mouse.x, mouse.y);
    context.beginPath();
    canvas.addEventListener('mousemove', onPaint, false);
}, false); var onPaint = function () {
    context.lineTo(mouse.x, mouse.y);
    context.stroke();
};
  
canvas.addEventListener('mouseup', function () {
    $('#number').html('');
    canvas.removeEventListener('mousemove', onPaint, false);
    var img = new Image();
    img.onload = function () {
        context.drawImage(img, 0, 0, 28, 28);
        data = context.getImageData(0, 0, 28, 28).data;
        var input = [];
        for (var i = 0; i < data.length; i += 4) {
            input.push(data[i + 2] / 255);
        }
        predict(input);
    };
    img.src = canvas.toDataURL('image/png');
}, false);
  
// Setting up tfjs with the model we downloaded
tf.loadLayersModel('model / model.json')
    .then(function (model) {
        window.model = model;
    });
  
// Predict function
var predict = function (input) {
    if (window.model) {
        window.model.predict([tf.tensor(input)
            .reshape([1, 28, 28, 1])])
            .array().then(function (scores) {
                scores = scores[0];
                predicted = scores
                    .indexOf(Math.max(...scores));
                $('#number').html(predicted);
            });
    } else {
  
        // The model takes a bit to load, 
        // if we are too fast, wait
        setTimeout(function () { predict(input) }, 50);
    }
}


第 1 步:训练数据

为了训练模型,我们将使用 MNIST 数据库。这是一个免费的大型手写数字数据库。在该数据集中,有 60,000 张图像,所有图像的灰度大小均为 28 x 28 像素,像素值从 0 到 255。

第 2 步:数据预处理

我们执行以下步骤来处理我们的数据:

  1. 标准化输入:输入范围为 0-255。我们需要将它们缩放到 0-1。
  2. 一个热编码输出

第 3 步:机器学习

为了训练模型,我们使用一个简单的神经网络,它有一个隐藏层,足以提供大约 98% 的准确率。

第 4 步:使用 tensorflow.js 转换模型

首先,使用以下命令保存模型:

model.save(“model.h5”)

然后安装tensorflow.js并使用以下命令转换模型:

!pip install tensorflowjs
!tensorflowjs_converter --input_format keras 
    ‘/content/model.h5’ ‘/content/mnist-model’

运行上述命令后,现在刷新文件。您的内容应如下所示。

注意:下载 mnist-model文件夹,我们稍后将使用它。

创建 Express App 并安装模块:

第 1 步:使用以下命令创建package.json

npm init

第2步:现在按照命令安装依赖项。我们将使用 express for server 和 ejs 作为模板引擎。

npm install express ejs

项目结构:现在确保您具有以下文件结构。将我们从 Colab 下载的文件复制到模型文件夹中。

现在将以下代码写到您的index.js文件中。

index.js

// Requiring module
const express = require("express");
const app = express();
const path = require("path")
  
// Set public as static directory
app.use(express.static('public'));
  
app.set('views', path.join(__dirname, '/views'))
  
// Use ejs as template engine
app.set('view engine', 'ejs');
app.use(express.json());
app.use(express.urlencoded({ extended: false }));
  
// Render main template
app.get('/',(req,res)=>{
    res.render('main')
})
  
// Server setup
app.listen(3000, () => {
  console.log("The server started running on port 3000") 
});

主.ejs



  

    
    
    

  

    

Digit Recognition WebApp

       
             
       
        Recognized digit         
             
              

样式.css

body {
    touch-action: none; 
    font-family: "Roboto";
}
h1 {
    margin: 50px;
    font-size: 70px;
    text-align: center;
}
#paint {
    border:3px solid red;
    margin: auto;
}
#predicted { 
    font-size: 60px;
    margin-top: 60px;
    text-align: center;
}
#number {
    border: 3px solid black;
    margin: auto;
    margin-top: 30px;
    text-align: center;
    vertical-align: middle;
}
#clear {
    margin: auto;
    margin-top: 70px;
    padding: 30px;
    text-align: center;
}

编写脚本:我们使用 HTML5 画布定义鼠标事件。然后我们将鼠标上的图像放大并将其缩放为 28×28 像素,使其与我们的模型相匹配,然后将其传递给我们的预测函数。

脚本.js

canvas.addEventListener('mousedown', function (e) {
    context.moveTo(mouse.x, mouse.y);
    context.beginPath();
    canvas.addEventListener('mousemove', onPaint, false);
}, false); var onPaint = function () {
    context.lineTo(mouse.x, mouse.y);
    context.stroke();
};
  
canvas.addEventListener('mouseup', function () {
    $('#number').html('');
    canvas.removeEventListener('mousemove', onPaint, false);
    var img = new Image();
    img.onload = function () {
        context.drawImage(img, 0, 0, 28, 28);
        data = context.getImageData(0, 0, 28, 28).data;
        var input = [];
        for (var i = 0; i < data.length; i += 4) {
            input.push(data[i + 2] / 255);
        }
        predict(input);
    };
    img.src = canvas.toDataURL('image/png');
}, false);
  
// Setting up tfjs with the model we downloaded
tf.loadLayersModel('model / model.json')
    .then(function (model) {
        window.model = model;
    });
  
// Predict function
var predict = function (input) {
    if (window.model) {
        window.model.predict([tf.tensor(input)
            .reshape([1, 28, 28, 1])])
            .array().then(function (scores) {
                scores = scores[0];
                predicted = scores
                    .indexOf(Math.max(...scores));
                $('#number').html(predicted);
            });
    } else {
  
        // The model takes a bit to load, 
        // if we are too fast, wait
        setTimeout(function () { predict(input) }, 50);
    }
}

运行应用程序的步骤:使用以下命令运行index.js文件。

node index.js

输出:打开浏览器并 http://localhost:3000 ,我们将看到以下输出。