📅  最后修改于: 2023-12-03 15:20:34.170000             🧑  作者: Mango
tf.callbacks.earlyStopping()
是 TensorFlow.js 中提供的一个回调函数,用于在训练过程中实现早停(early stopping)。
早停是一种常用的训练模型技术,通常在训练过程中通过观察验证集的性能来决定是否停止训练。当验证集的性能不再提升或出现恶化时,早停允许我们停止训练,以避免过拟合并节省计算资源。
tf.callbacks.earlyStopping()
函数接受一个对象参数,用于配置早停的条件和行为。
tf.callbacks.earlyStopping(config)
config
:一个对象,用于配置早停的条件和行为。
monitor
(必选):指定要监视的指标名称。可以是损失函数名称或指标函数名称。minDelta
(可选):要监视指标的最小变化量(增加或减少)。当指标的变化小于该值时,训练将在 patience
次迭代后停止。默认值为 0。patience
(可选):当指标的变化小于 minDelta
时,将等待多少个迭代次数以确定是否达到停止条件。默认值为 0,即仅在第一次验证集的性能恶化时停止训练。verbose
(可选):控制是否输出详细信息。默认值为 0
,不输出详细信息。mode
(可选):指定如何确定监视指标的变化。可以是 min
(默认值)或 max
,分别表示需要监视指标的最小值或最大值。tf.callbacks.earlyStopping()
函数返回一个回调函数对象,可用于训练模型时将其传递给 tf.Model.fit()
函数的 callbacks
参数。
下面是一个示例,展示了如何在 TensorFlow.js 中使用 tf.callbacks.earlyStopping()
函数:
const model = tf.sequential();
model.add(tf.layers.dense({units: 10, inputShape: [4]}));
model.add(tf.layers.dense({units: 3}));
const xs = tf.randomNormal([100, 4]);
const ys = tf.randomNormal([100, 3]);
const earlyStopCallback = tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 5, verbose: 1});
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
model.fit(xs, ys, {
epochs: 100,
callbacks: [earlyStopCallback]
}).then(() => {
console.log('Training completed.');
});
在上述示例中,我们创建了一个简单的神经网络模型,并使用 tf.callbacks.earlyStopping()
函数配置了早停条件。我们在每个迭代中观察验证集的损失值,并在连续 5 次迭代验证集损失没有变化或变化小于 0.001 时停止训练。
通过使用 TensorFlow.js tf.callbacks.earlyStopping() 函数,我们可以轻松实现早停技术,避免过拟合,并在验证集性能不再提升时节省计算资源。这是一个非常有用的工具,在训练深度学习模型时常常会用到。