📜  Tensorflow.js tf.GraphModel 类 .predict() 方法

📅  最后修改于: 2022-05-13 01:56:34.937000             🧑  作者: Mango

Tensorflow.js tf.GraphModel 类 .predict() 方法

Tensorflow.js 是由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.predict()函数用于实现有利于输入张量的暗示。

句法:

predict(inputs, config?)

参数:

  • 输入:这是规定的输入。它的类型为 (tf.Tensor|tf.Tensor[]|{[name: 字符串]: tf.Tensor})。
  • config:它是规定的预测配置,用于定义批量大小以及输出节点名称。此外,目前图模型忽略了批量大小的选择。它是可选的并且是对象类型。
    • batchSize:它是指定的批处理维度,是可选的并且是整数类型。如果未定义,则默认值为 32。
    • 详细:它是指定的详细模式,默认值为 false 并且是可选的。

返回值:返回 tf.Tensor|tf.Tensor[]|{[name: 字符串]: tf.Tensor}。

示例 1:在此示例中,我们从 URL 加载 MobileNetV2 并保存带有零输入的预测。

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Calling predict() method and 
// Printing output
mymodel.predict(inputs).print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
  
// Calling the loadGraphModel() method
const model = await tf.loadGraphModel(
        model_Url, {fromTFHub: true});
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Defining batchsize
const batchsize = 1;
  
// Defining verbose
const verbose = true;
  
// Calling predict() method and
// Printing output
model.predict(inputs, batchsize, verbose).print();


输出:

Tensor
     [[-0.1800361, -0.4059965, 0.8190175, 
     ..., 
     -0.8953396, -1.0841646, 1.2912753],]

示例 2:在此示例中,我们从 TF Hub URL 加载 MobileNetV2,并使用零输入进行预测。

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
  
// Calling the loadGraphModel() method
const model = await tf.loadGraphModel(
        model_Url, {fromTFHub: true});
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Defining batchsize
const batchsize = 1;
  
// Defining verbose
const verbose = true;
  
// Calling predict() method and
// Printing output
model.predict(inputs, batchsize, verbose).print();

输出:

Tensor
     [[-1.1690605, 0.0195426, 1.1962479, 
     ..., 
     -0.4825858, -0.0055641, 1.1937635],]

参考: https://js.tensorflow.org/api/latest/#tf.GraphModel.predict