📅  最后修改于: 2023-12-03 14:47:55.001000             🧑  作者: Mango
TensorFlow.js 是一个由 Google Brain 团队开发的 JavaScript 库,可用于在浏览器或 Node.js 环境中构建机器学习应用程序。其中 tf.GraphModel 类是用于加载和执行训练好的 TensorFlow 模型的。
.predict() 是 tf.GraphModel 类中用于推断输入数据的方法,它的主要作用是接受一个输入 Tensor,然后返回一个输出 Tensor。
model.predict(inputs, config)
inputs
:输入数据,可以是一个 Tensor、一组 Tensor 或一个包含 Tensor 的字典。具体取决于模型的输入层。config
:可选参数,用于控制推断过程的一些设置,如批次大小。输出 Tensor、一组 Tensor 或一个包含 Tensor 的字典,具体取决于模型的输出层。
假设我们有一个简单的 TensorFlow 模型,在两个整数输入上执行加法运算并输出结果。以下是该模型的代码:
import tensorflow as tf
# 定义输入层
inputs = tf.keras.layers.Input(shape=(2,), dtype="int32")
# 定义加法运算层
add = tf.keras.layers.Add()([inputs, inputs])
# 定义输出层
outputs = tf.keras.layers.Dense(1)(add)
# 定义模型
model = tf.keras.models.Model(inputs, outputs)
# 编译模型
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
我们可以使用 TensorFlow.js 将该模型导出为一个 JSON 文件,并将其加载到浏览器中。以下是如何加载该模型:
const model = await tf.loadGraphModel("path/to/model.json");
在模型加载完成后,我们可以使用 .predict() 方法来对输入数据进行推断。例如:
// 构造输入数据
const input = tf.tensor2d([[1, 2], [3, 4]]);
// 进行推断并打印结果
const output = model.predict(input);
output.print();
输出结果为:
Tensor
[[ 2. ]
[ 6.0001]]
dtype: float32
Tensorflow.js 的 tf.GraphModel 类 .predict() 方法是一个十分实用的机器学习函数,它允许我们使用 WebGL 加速来进行推断运算,能够在浏览器或 Node.js 中高效地执行机器学习模型的推断。