📜  Tensorflow.js tf.Sequential 类 .evaluate() 方法(1)

📅  最后修改于: 2023-12-03 14:47:55.769000             🧑  作者: Mango

TensorFlow.js tf.Sequential 类 .evaluate() 方法

在使用 TensorFlow.js 进行深度学习模型训练时,我们通常会使用 tf.Sequential 类来构建模型。而在模型训练完成后,我们通常需要使用 .evaluate() 方法来对测试集进行模型的评估。本文将对 TensorFlow.js 中 tf.Sequential 类的 .evaluate() 方法进行介绍。

使用方法

首先,我们需要使用 tf.model() 方法来创建一个 tf.Sequential 类的实例,然后通过 .compile() 方法来对模型进行编译:

const model = tf.sequential();
model.add(tf.layers.dense({inputShape: [4], units: 4, activation: 'relu'}));
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
model.compile({optimizer: 'sgd', loss: 'categoricalCrossentropy', metrics: ['accuracy']});

上述代码使用 tf.sequential() 创建了一个 tf.Sequential 类的实例,并添加了两个全连接层,接着使用 .compile() 方法对模型进行编译。在编译时需要指定优化器、损失函数和评估指标。

接下来,我们可以使用 .evaluate() 方法来对测试集进行评估:

const testXs = tf.randomNormal([100, 4]);
const testYs = tf.oneHot(tf.tensor1d(Array.from({length: 100}, (_, i) => Math.floor(Math.random() * 3)), 'int32'), 3);
const evalOutput = model.evaluate(testXs, testYs, {batchSize: 32});
evalOutput.print();

上述代码中,我们使用 tf.randomNormal() 方法来生成 100 个输入样本数据,使用 tf.oneHot() 方法生成与之对应的标签数据。最后,我们使用 .evaluate() 方法对测试集进行评估,并指定了 batchSize 参数。我们还使用 .print() 方法来打印评估结果。

需要注意的是,.evaluate() 方法返回一个 Promise 对象,我们可以使用 async/await.then() 方法来处理返回结果。

参数列表

.evaluate() 方法有以下参数:

  • x: 必须。输入数据。
  • y: 可选。标签数据。
  • batchSize: 可选。批大小。
  • verbose: 可选。是否输出日志。
返回值

.evaluate() 方法返回一个含有评估结果的数组。

总结

tf.Sequential 类的 .evaluate() 方法是一个重要的工具,在深度学习模型训练完成后对模型进行评估。使用该方法需要指定输入数据和标签数据,以及优化器、损失函数和评估指标等相关参数。