Tensorflow.js tf.GraphModel 类 .execute() 方法
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.execute() 方法用于为指定的输入张量实现有利于给定模型的暗示。
句法:
execute(inputs, outputs?)
参数:
- 输入:它是指定的张量或张量数组或有利于模型的输入的张量图,通过输入节点指定处理。它的类型为 (tf.Tensor|tf.Tensor[]|{[name: 字符串]: tf.Tensor})。
- 输出:它是来自所述张量流模型的所述输出节点名称。如果未说明输出,则必须应用所述模型的默认输出。此外,我们可以通过将它们附加到输出数组来分析指定模型的节点之间的中间值。它是字符串或字符串[] 类型。
返回值:返回 tf.Tensor 或 tf.Tensor[]。
示例 1:在此示例中,我们从 URL 加载 MobileNetV2。
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input elements
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
// Calling execute() method and
// Printing output
mymodel.execute(inputs).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
model_Url, {fromTFHub: true});
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
// Defining outputs
const outputs = "module_apply_default/MobilenetV2/Logits/output";
// Calling execute() method and
// Printing output
mymodel.execute(inputs, outputs).print();
输出:
Tensor
[[-0.1800361, -0.4059965, 0.8190175,
...,
-0.8953396, -1.0841646, 1.2912753],]
示例 2:在此示例中,我们从 TF Hub URL 加载 MobileNetV2。
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
model_Url, {fromTFHub: true});
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
// Defining outputs
const outputs = "module_apply_default/MobilenetV2/Logits/output";
// Calling execute() method and
// Printing output
mymodel.execute(inputs, outputs).print();
输出:
Tensor
[[-1.1690605, 0.0195426, 1.1962479,
...,
-0.4825858, -0.0055641, 1.1937635],]
参考: https://js.tensorflow.org/api/latest/#tf.GraphModel.execute