📌  相关文章
📜  Tensorflow.js tf.LayersModel 类 .predictOnBatch() 方法(1)

📅  最后修改于: 2023-12-03 15:35:17.478000             🧑  作者: Mango

Tensorflow.js tf.LayersModel类 .predictOnBatch()方法

简介

Tensorflow.js是一个由Google开发的机器学习框架,它支持在浏览器中进行机器学习和深度学习。tf.LayersModel类是Tensorflow.js中的一个模型类,它支持使用层次结构构建机器学习模型。.predictOnBatch()方法是tf.LayersModel类中的一个方法,用于在输入数据集的批次上进行预测。

使用方法

在使用.predictOnBatch()方法前,需要使用tf.loadLayersModel()方法或其他方式加载模型。加载模型后,可以使用.predictOnBatch()方法对数据进行预测。

// 加载模型
const model = await tf.loadLayersModel('model.json');

// 准备输入数据
const x = tf.tensor([1, 2, 3, 4], [2, 2]);

// 预测结果
const result = model.predictOnBatch(x);
console.log(result);

其中,输入数据的形状需要和模型的输入形状匹配。predictOnBatch()方法返回一个张量,表示预测结果。

示例

以下示例展示了如何使用.predictOnBatch()方法对数据进行预测。

// 加载模型
const model = await tf.loadLayersModel('model.json');

// 准备输入数据
const x = tf.tensor([1, 2, 3, 4], [2, 2]);

// 预测结果
const result = model.predictOnBatch(x);
result.print();

输出结果如下:

Tensor
    [[-0.04575942 -0.08205768]
     [-0.00956025 -0.1043027 ]]
注意事项

使用.predictOnBatch()方法前,需要确保已加载模型,并且输入数据的形状需要和模型的输入形状匹配。如果输入数据的形状不匹配,将会报错。另外,预测结果是一个张量,需要使用.print()方法或其他方式打印出来。