📌  相关文章
📜  Tensorflow.js tf.LayersModel 类 .predictOnBatch() 方法

📅  最后修改于: 2022-05-13 01:56:30.948000             🧑  作者: Mango

Tensorflow.js tf.LayersModel 类 .predictOnBatch() 方法

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.predictOnBatch()函数用于返回单个实例组的期望值。

句法:

predictOnBatch(x)

参数:

  • x:它是指定的输入实例,例如张量,即具有精确一个输入的模型,或者是一组张量,即具有多个输入的模型。它可以是 tf.Tensor 或 tf.Tensor[] 类型。

返回值:返回 tf.Tensor 对象或 tf.Tensor[]。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const Mod = tf.sequential({
   layers: [tf.layers.dense({units: 2, inputShape: [30]})]
});
  
// Calling predictOnBatch() method and
// Printing output
Mod.predictOnBatch(tf.randomNormal([6, 30])).print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling predictOnBatch() method and
// Printing output
tf.sequential({
   layers: [tf.layers.dense({units: 3, inputShape: [40]})]
}).predictOnBatch(tf.truncatedNormal([5, 40])).print();


输出:

Tensor
    [[-1.4716092, -1.8019401],
     [-1.0033149, -0.2789704],
     [-0.4451316, 0.2422157 ],
     [-0.1512984, -0.0726933],
     [2.1483333 , 2.4668102 ],
     [0.4091003 , 0.8335327 ]]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling predictOnBatch() method and
// Printing output
tf.sequential({
   layers: [tf.layers.dense({units: 3, inputShape: [40]})]
}).predictOnBatch(tf.truncatedNormal([5, 40])).print();

输出:

Tensor
    [[-1.5034456, -0.3429004, -0.2388536],
     [0.0083699 , -0.3176711, 2.1414554 ],
     [1.1850954 , -0.4481514, 1.1278313 ],
     [-0.1004405, 1.420954  , 0.4890856 ],
     [0.4184967 , 0.1191952 , -0.0936601]]

参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.predictOnBatch