📌  相关文章
📜  Tensorflow.js tf.LayersModel 类 .predict() 方法

📅  最后修改于: 2022-05-13 01:56:42.519000             🧑  作者: Mango

Tensorflow.js tf.LayersModel 类 .predict() 方法

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

这 。 predict()函数用于生成给定输入实例的输出估计值。此外,这里的计算是成组进行的。其中,目前不支持step操作,只需要 tensorflow.js 的核心后端。

句法:

predict(x, args?)

参数:

  • x:它是规定的输入数据,如张量,否则为 tf.Tensors 数组,以防模型有各种输入。它可以是 tf.Tensor 或 tf.Tensor[] 类型。
  • args:它是声明的 ModelPredictArgs 包含选修字段的对象。
    1. batchSize:它是整数类型的指定批次维度。如果未定义,默认值为 32。
    2. verbose:它是声明的详细模式,默认值为 false。

返回值:返回 tf.Tensor 对象或 tf.Tensor[]。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const Mod = tf.sequential({
   layers: [tf.layers.dense({units: 2, inputShape: [30]})]
});
  
// Calling predict() method and
// Printing output
Mod.predict(tf.randomNormal([6, 30])).print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling predict() method and
// Printing output
tf.sequential({
   layers: [tf.layers.dense({units: 3, inputShape: [10]})]
}).predict(tf.truncatedNormal([2, 10]), {batchSize: 2}, true).print();


输出:

Tensor
    [[-0.7650393, -0.8317917],
     [-0.7274997, 1.827635  ],
     [-0.9398478, -0.2998275],
     [-1.0945926, -1.9154934],
     [0.0067322 , -1.9220339],
     [0.2052939 , 0.6488774 ]]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling predict() method and
// Printing output
tf.sequential({
   layers: [tf.layers.dense({units: 3, inputShape: [10]})]
}).predict(tf.truncatedNormal([2, 10]), {batchSize: 2}, true).print();

输出:

Tensor
    [[0.2670097, -1.2741219, -0.3159108],
     [0.9108799, -0.1305539, -0.1370454]]

参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.predict