📜  Tensorflow.js tf.GraphModel 类 .predict() 方法(1)

📅  最后修改于: 2023-12-03 14:47:55.001000             🧑  作者: Mango

Tensorflow.js tf.GraphModel 类 .predict() 方法

TensorFlow.js 是一个由 Google Brain 团队开发的 JavaScript 库,可用于在浏览器或 Node.js 环境中构建机器学习应用程序。其中 tf.GraphModel 类是用于加载和执行训练好的 TensorFlow 模型的。

.predict() 是 tf.GraphModel 类中用于推断输入数据的方法,它的主要作用是接受一个输入 Tensor,然后返回一个输出 Tensor。

语法
model.predict(inputs, config)
参数
  • inputs:输入数据,可以是一个 Tensor、一组 Tensor 或一个包含 Tensor 的字典。具体取决于模型的输入层。
  • config:可选参数,用于控制推断过程的一些设置,如批次大小。
返回值

输出 Tensor、一组 Tensor 或一个包含 Tensor 的字典,具体取决于模型的输出层。

示例

假设我们有一个简单的 TensorFlow 模型,在两个整数输入上执行加法运算并输出结果。以下是该模型的代码:

import tensorflow as tf

# 定义输入层
inputs = tf.keras.layers.Input(shape=(2,), dtype="int32")

# 定义加法运算层
add = tf.keras.layers.Add()([inputs, inputs])

# 定义输出层
outputs = tf.keras.layers.Dense(1)(add)

# 定义模型
model = tf.keras.models.Model(inputs, outputs)

# 编译模型
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

我们可以使用 TensorFlow.js 将该模型导出为一个 JSON 文件,并将其加载到浏览器中。以下是如何加载该模型:

const model = await tf.loadGraphModel("path/to/model.json");

在模型加载完成后,我们可以使用 .predict() 方法来对输入数据进行推断。例如:

// 构造输入数据
const input = tf.tensor2d([[1, 2], [3, 4]]);

// 进行推断并打印结果
const output = model.predict(input);
output.print();

输出结果为:

Tensor
    [[ 2.   ]
     [ 6.0001]]
dtype: float32
总结

Tensorflow.js 的 tf.GraphModel 类 .predict() 方法是一个十分实用的机器学习函数,它允许我们使用 WebGL 加速来进行推断运算,能够在浏览器或 Node.js 中高效地执行机器学习模型的推断。