📌  相关文章
📜  Tensorflow.js tf.LayersModel 类 .predict() 方法(1)

📅  最后修改于: 2023-12-03 15:05:33.031000             🧑  作者: Mango

Tensorflow.js tf.LayersModel 类 .predict() 方法介绍

TensorFlow.js 是一个开源的机器学习工具库,能够在浏览器和 Node.js 中运行。其中 tf.LayersModel 是 TensorFlow.js 中一个重要的类,是一种使用层来组装原始张量进行模型训练和推断的高级神经网络。

.predict() 方法简介

tf.LayersModel 类中的 .predict() 方法被用于对数据进行预测。它接受一个输入张量,返回一个张量作为输出。

model.predict(inputs, options)

说明:

  • inputs: 要进行预测的输入张量,必填。该张量的形状和类型必须与模型输入层的形状和类型相同。
  • options: 可选。一个配置对象,用于控制预测的详细过程,如批大小、输出精度等。

但是值得注意的是,使用 .predict() 方法前需要先加载、编译、训练模型,例如:

async function predictModel(){
    const model = await tf.loadLayersModel('path/to/model.json');
    const inputData = tf.tensor2d([0.1, 0.2], [1,2]); // 或处理成符合输入形状要求的张量
    const outputData = model.predict(inputData, options);
    outputData.print();
}
输出结果的张量

.predict() 方法返回的 tensor 是由网络的输出层产生的。输出张量的形状和类型取决于您所训练的模型,通常会是一个形状为 [batch_size, output_size] 的张量。

参考文献
  • TensorFlow.js: A JavaScript Library for Training and Deploying Machine Learning Models in the Browser. https://www.tensorflow.org/js/
  • tf.LayersModel.predict() API Document. https://js.tensorflow.org/api/latest/#LayersModel.predict