📜  Tensorflow.js tf.Sequential 类 .trainOnBatch() 方法(1)

📅  最后修改于: 2023-12-03 15:05:33.295000             🧑  作者: Mango

TensorFlow.js tf.Sequential 类 .trainOnBatch() 方法介绍

简介

TensorFlow.js 是 Google 研发的机器学习框架 TensorFlow 的 JavaScript 版本,可在 Web 环境下运行。tf.Sequential 是 TensorFlow.js 中定义神经网络模型的类。.trainOnBatch() 方法是其中的一个用于训练模型的方法。

方法
trainOnBatch(inputs, labels)
  • inputs: 输入数据。形状为 [batchSize, ...inputShape] 的张量或张量数组。
  • labels: 目标数据。 形状为 [batchSize, ...targetShape] 的张量或张量数组。

该方法会在一个批次 (batch) 上训练模型。 inputslabels 参数是 tensor 或 tensor 数组。 inputs 参数包含输入数据,而 labels 参数包含目标数据(labels)。trainOnBatch() 迭代一次训练批次,并返回一个标量损失 (loss) 值。

代码示例
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [2]}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

const xs = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
const ys = tf.tensor2d([[0], [1], [1], [0]]);

for (let i = 0; i < 10; i++) {
  const result = model.trainOnBatch(xs, ys);
  console.log(`[${i + 1}] loss:${result}`);
}

上述代码展示了如何使用 trainOnBatch() 方法训练一个简单的神经网络模型来解决 XOR 问题。每次使用 trainOnBatch() 方法会在四个样本上进行一次训练,返回一个标量损失值。在上面的代码中,我们重复调用这个方法10次来完成训练过程。在训练过程中,损失值逐渐降低,表明网络的预测结果逐渐接近期望输出。