📅  最后修改于: 2023-12-03 15:05:33.295000             🧑  作者: Mango
TensorFlow.js 是 Google 研发的机器学习框架 TensorFlow 的 JavaScript 版本,可在 Web 环境下运行。tf.Sequential
是 TensorFlow.js 中定义神经网络模型的类。.trainOnBatch()
方法是其中的一个用于训练模型的方法。
trainOnBatch(inputs, labels)
inputs
: 输入数据。形状为 [batchSize, ...inputShape]
的张量或张量数组。labels
: 目标数据。 形状为 [batchSize, ...targetShape]
的张量或张量数组。该方法会在一个批次 (batch) 上训练模型。 inputs
和 labels
参数是 tensor 或 tensor 数组。 inputs
参数包含输入数据,而 labels
参数包含目标数据(labels)。trainOnBatch()
迭代一次训练批次,并返回一个标量损失 (loss) 值。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [2]}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
const xs = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
const ys = tf.tensor2d([[0], [1], [1], [0]]);
for (let i = 0; i < 10; i++) {
const result = model.trainOnBatch(xs, ys);
console.log(`[${i + 1}] loss:${result}`);
}
上述代码展示了如何使用 trainOnBatch()
方法训练一个简单的神经网络模型来解决 XOR 问题。每次使用 trainOnBatch()
方法会在四个样本上进行一次训练,返回一个标量损失值。在上面的代码中,我们重复调用这个方法10次来完成训练过程。在训练过程中,损失值逐渐降低,表明网络的预测结果逐渐接近期望输出。