📜  Tensorflow.js tf.callbacks.earlyStopping()函数(1)

📅  最后修改于: 2023-12-03 15:20:34.170000             🧑  作者: Mango

TensorFlow.js tf.callbacks.earlyStopping() 函数

简介

tf.callbacks.earlyStopping() 是 TensorFlow.js 中提供的一个回调函数,用于在训练过程中实现早停(early stopping)。

早停是一种常用的训练模型技术,通常在训练过程中通过观察验证集的性能来决定是否停止训练。当验证集的性能不再提升或出现恶化时,早停允许我们停止训练,以避免过拟合并节省计算资源。

用法

tf.callbacks.earlyStopping() 函数接受一个对象参数,用于配置早停的条件和行为。

tf.callbacks.earlyStopping(config)
参数
  • config:一个对象,用于配置早停的条件和行为。

    • monitor(必选):指定要监视的指标名称。可以是损失函数名称或指标函数名称。
    • minDelta(可选):要监视指标的最小变化量(增加或减少)。当指标的变化小于该值时,训练将在 patience 次迭代后停止。默认值为 0。
    • patience(可选):当指标的变化小于 minDelta 时,将等待多少个迭代次数以确定是否达到停止条件。默认值为 0,即仅在第一次验证集的性能恶化时停止训练。
    • verbose(可选):控制是否输出详细信息。默认值为 0,不输出详细信息。
    • mode(可选):指定如何确定监视指标的变化。可以是 min(默认值)或 max,分别表示需要监视指标的最小值或最大值。
返回值

tf.callbacks.earlyStopping() 函数返回一个回调函数对象,可用于训练模型时将其传递给 tf.Model.fit() 函数的 callbacks 参数。

示例

下面是一个示例,展示了如何在 TensorFlow.js 中使用 tf.callbacks.earlyStopping() 函数:

const model = tf.sequential();
model.add(tf.layers.dense({units: 10, inputShape: [4]}));
model.add(tf.layers.dense({units: 3}));

const xs = tf.randomNormal([100, 4]);
const ys = tf.randomNormal([100, 3]);

const earlyStopCallback = tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 5, verbose: 1});

model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
model.fit(xs, ys, {
  epochs: 100,
  callbacks: [earlyStopCallback]
}).then(() => {
  console.log('Training completed.');
});

在上述示例中,我们创建了一个简单的神经网络模型,并使用 tf.callbacks.earlyStopping() 函数配置了早停条件。我们在每个迭代中观察验证集的损失值,并在连续 5 次迭代验证集损失没有变化或变化小于 0.001 时停止训练。

总结

通过使用 TensorFlow.js tf.callbacks.earlyStopping() 函数,我们可以轻松实现早停技术,避免过拟合,并在验证集性能不再提升时节省计算资源。这是一个非常有用的工具,在训练深度学习模型时常常会用到。