📅  最后修改于: 2023-12-03 15:35:17.284000             🧑  作者: Mango
TensorFlow.js 是一个针对浏览器和 Node.js 的端到端开源机器学习平台。tf.GraphModel 类是 TensorFlow.js 中的一个关键类,它允许您使用预先训练的 TensorFlow 模型,同时处理和传输数据。在本文中,我们将介绍 TensorFlow.js tf.GraphModel 类及其用例。
tf.GraphModel 类是 TensorFlow.js 中的一个 JavaScript 类。它是 TensorFlow 中的 SavedModel 的一个子集,可以用于将预训练的 TensorFlow 模型加载到浏览器或 Node.js 中。
一旦加载到浏览器或 Node.js 中,您可以使用 tf.GraphModel 类的方法对模型进行评估、预测,并对其进行微调。
以下是使用 tf.GraphModel 类的基本示例。
const model = await tf.loadGraphModel('url/to/model.json');
const input = tf.tensor2d([[1.0, 2.0], [3.0, 4.0]]);
const output = model.execute(input);
output.print();
首先,我们使用 tf.loadGraphModel()
方法从 URL 加载 SavedModel。接下来,我们创建一个输入张量 input
,使用 model.execute()
方法进行推理,并将输出张量打印到控制台上。
tf.GraphModel 类提供了许多方法,用于读取和处理预训练的 TensorFlow 模型、执行推理操作,并通过 WebGPU 或 WebGL 将图形计算传输到 GPU。
以下是 tf.GraphModel 类的一些常用方法:
tf.loadGraphModel
:从 URL 或文件系统加载 SavedModel。const model = await tf.loadGraphModel('url/to/model.json');
model.execute
:对输入数据进行预测并返回输出数据。输出数据的形状、数据类型和值的范围根据模型的预测设置而定。const input = tf.tensor2d([[1.0, 2.0], [3.0, 4.0]]);
const output = model.execute(input);
output.print();
model.predict
:对输入数据进行预测并返回一个标准化的输出。标准化输出是具有相同数据类型和形状的张量数组(与模型的输出设置有关)。const input = tf.tensor2d([[1.0, 2.0], [3.0, 4.0]]);
const output = model.predict(input);
output.forEach(tensor => tensor.print());
tf.GraphModel 类是 TensorFlow.js 中的重要类之一,它允许您将预训练的 TensorFlow 模型加载到浏览器或 Node.js 中,并使用多种方法执行预测任务。我们希望本文对您有所帮助。