📜  Tensorflow.js tf.callbacks.earlyStopping()函数

📅  最后修改于: 2022-05-13 01:56:28.688000             🧑  作者: Mango

Tensorflow.js tf.callbacks.earlyStopping()函数

Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

Tensorflow.js tf.callbacks.earlyStopping()是一个回调函数,用于在训练数据停止改进时停止训练。

句法:

tf.callbacks.earlyStopping(args);

参数:此方法接受以下参数。

  • args:它是一个具有以下字段的对象:
    • 监视器:它应该是一个字符串。这是要监视的值。
    • minDelta:应该是一个数字。它是最小值,低于该值不被认为是训练的改进。
    • 耐心:应该是一个数字。它是在遇到低于 minDelta 的值时不应停止的次数。
    • 详细:它应该是一个数字。这是冗长的价值。
    • 模式:应该是以下三种之一:
      • “auto”:在自动模式下,根据监控量的名称自动推断方向。
      • “min”:在min模式下,当监控的数据值停止减少时,训练将停止。
      • “max”:在max模式下,当监控的数据值停止增加时,训练将停止。
    • 基线:应该是一个数字。当训练跟不上这个值时,这个数字告诉训练将停止。它是监控数量的结束行。
    • restoreBestWeights:它应该是一个布尔值。它告诉是否从每个时期的监控量中恢复最佳值。

返回值:它返回一个对象(EarlyStopping)。

下面是这个函数的一些例子。

示例 1:在此示例中,我们将看到如何在 fitDataset 中使用 tf.callbacks.earlyStopping()函数:

Javascript
import * as tf from "@tensorflow/tfjs";
  
const xArray = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [8, 7, 6, 5],
    [1, 2, 3, 4],
];
  
const x1Array = [
    [0, 1, 0.5, 0],
    [1, 0.5, 0, 1],
    [0.5, 1, 1, 0],
    [1, 0, 0, 1],
];
  
const yArray = [1, 2, 3, 4];
const y1Array = [4, 3, 2, 1];
  
// Create a dataset from the JavaScript array.
const xDataset = tf.data.array(xArray);
const x1Dataset = tf.data.array(x1Array);
const y1Dataset = tf.data.array(x1Array);
const yDataset = tf.data.array(yArray);
  
// Combining the Dataset with zip function
const xyDataset = tf.data
    .zip({ xs: xDataset, ys: yDataset })
    .batch(4)
    .shuffle(4);
const xy1Dataset = tf.data
    .zip({ xs: x1Dataset, ys: y1Dataset })
    .batch(4)
    .shuffle(4);
  
// Creating model
const model = tf.sequential();
model.add(
    tf.layers.dense({
        units: 1,
        inputShape: [4],
    })
);
  
// Compiling model
model.compile({ loss: "meanSquaredError", 
    optimizer: "sgd", metrics: ["acc"] });
  
// Using tf.callbacks.earlyStopping in fitDataset.
const history = await model.fitDataset(xyDataset, {
    epochs: 10,
    validationData: xy1Dataset,
    callbacks: tf.callbacks.earlyStopping({ 
        monitor: "val_acc" }),
});
  
// Printing value
console.log("The value of val_acc is :", 
    history.history.val_acc);


Javascript
import * as tf from "@tensorflow/tfjs";
  
// Creating tensor for training
const x = tf.tensor([5, 6, 7, 8, 9, 2], [3, 2]);
const x1 = tf.tensor([8, 7, 6, 5, 2, 9], [3, 2]);
const y = tf.tensor([1, 3, 3, 4, 4, 6, 6, 8, 9], [3, 3]);
const y1 = tf.tensor([2, 2, 2, 1, 5, 5, 2, 3, 8], [3, 3]);
  
// Creating model
const model = tf.sequential();
  
model.add(
    tf.layers.dense({
        units: 3,
        inputShape: [2],
    })
);
  
// Compiling model
model.compile({ loss: "meanSquaredError", 
    optimizer: "sgd", metrics: ["acc"] });
  
// Using tf.callbacks.earlyStopping in fit.
const history = await model.fit(x, y, {
    epochs: 10,
    validationData: [x1, y1],
    callbacks: tf.callbacks.earlyStopping({ 
        monitor: "val_acc" }),
});
  
// Printing value
console.log("the value of val_acc is :", 
    history.history.val_acc);


输出:你得到的值是不同的,因为它的 val_acc 值随着训练值的变化而变化。

The value of val_acc is :0.4375,0.375

示例 2:在这个示例中,我们将看到如何使用 tf.callbacks.earlyStopping() 配合 fit:

Javascript

import * as tf from "@tensorflow/tfjs";
  
// Creating tensor for training
const x = tf.tensor([5, 6, 7, 8, 9, 2], [3, 2]);
const x1 = tf.tensor([8, 7, 6, 5, 2, 9], [3, 2]);
const y = tf.tensor([1, 3, 3, 4, 4, 6, 6, 8, 9], [3, 3]);
const y1 = tf.tensor([2, 2, 2, 1, 5, 5, 2, 3, 8], [3, 3]);
  
// Creating model
const model = tf.sequential();
  
model.add(
    tf.layers.dense({
        units: 3,
        inputShape: [2],
    })
);
  
// Compiling model
model.compile({ loss: "meanSquaredError", 
    optimizer: "sgd", metrics: ["acc"] });
  
// Using tf.callbacks.earlyStopping in fit.
const history = await model.fit(x, y, {
    epochs: 10,
    validationData: [x1, y1],
    callbacks: tf.callbacks.earlyStopping({ 
        monitor: "val_acc" }),
});
  
// Printing value
console.log("the value of val_acc is :", 
    history.history.val_acc);

输出:执行代码的值会有所不同,因为随着训练数据值的变化:

the value of val_acc is : 0.3333333432674408,0.3333333432674408

参考: https://js.tensorflow.org/api/latest/#callbacks.earlyStopping