📜  Tensorflow.js tf.Sequential 类 .fitDataset() 方法(1)

📅  最后修改于: 2023-12-03 14:47:55.785000             🧑  作者: Mango

Tensorflow.js tf.Sequential 类 .fitDataset() 方法

Tensorflow.js是一个JavaScript库,可用于在浏览器中构建和训练机器学习模型。tf.Sequential是一种模型类型,它按顺序连接的层来构建模型。 .fitDataset()方法是tf.Sequential类的一个方法,用于拟合和训练模型。

语法
tf.Sequential.fitDataset(dataset, configuration)
参数
  • dataset: 数据集,格式为tf.data.Dataset类型

  • configuration: 对象,用于指定训练模型时使用的数据大小、批次大小、epochs等参数。可指定以下参数:

    • epochs: 训练时的周期数。
    • stepsPerEpoch: 训练时每个epoch的步骤数。
    • batchSize: 每个批次的数据大小。
    • validationData: 用于验证的数据。
返回值

一个Promise对象,包含了模型的训练结果。

示例

以下是使用.fitDataset()方法训练模型的简单示例。在该示例中,我们将使用一个含有两层的 Sequential 模型,并将其训练10个时期。

const model = tf.sequential();
model.add(tf.layers.dense({units: 10, activation: 'relu', inputShape: [2]}));
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));

const trainData = tf.data.array([{x: [0, 0], y: 0}, {x: [0, 1], y: 1}, {x: [1, 0], y: 1}, {x: [1, 1], y: 0}]);
const validationData = tf.data.array([{x: [0.5, 0.5], y: 0.5}]);

const config = {
  epochs: 10,
  stepsPerEpoch: 4,
  batchSize: 2,
  validationData: validationData,
};

model.compile({loss: 'binaryCrossentropy', optimizer: 'adam', metrics: ['accuracy']});

trainData.shuffle(4).batch(2).fitDataset(model, config).then((history) => {
  console.log(history);
});

在这个示例中,我们定义了一个Sequential模型,并向其添加了两个密集层,分别包含10个和1个单元。我们使用一个包含四个训练样本和一个验证样本的小型数据集进行训练和验证。

我们使用'tf.data'模块中的'array'方法来创建一个简单的数据集,然后使用数据集的'shuffle'和'batch'方法来传递给'fitDataset'方法。我们定义了一个包含各种参数的配置对象,然后通过'compile'方法对模型进行编译。

最后,我们使用'fitDataset'方法对模型进行训练,并在训练完成后打印出结果。