Tensorflow.js tf.Sequential 类 .fitDataset() 方法
Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。 Tensorflow.js tf.Sequential 类 .fitDataset() 方法用于使用数据集对象训练模型。
句法:
model.fitDataset(dataset, args);
参数:此方法包含以下参数:
- 数据集:它是输入值的数据集。它可以是原始数据集、数组或对象。
- args:它包含以下值:
- epochs :它是训练模型期间训练数据集中的总次数。它是一个整数值。
- batchesPerEpoch :它定义了每个 epoch 中的批次数。它的值取决于批量大小,因为批量大小增加它的大小减小。
- 详细:它有助于显示每个时期的进度。如果值为 0 – 这意味着在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node.js 中,它会打印进度条。在浏览器中它没有显示任何操作。值 1 是默认值。 2 – 值 2 尚未实现。
- 回调:它定义了在训练期间要调用的回调列表。变量可以有一个或多个回调 onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
- validationData:用于在最终模型之间进行选择时对最终模型进行估计。这可以是以下任何一种: [ xVal, yVal ] 的数组,具有 { xs : xVal, ys : yVal } 形式的元素的 Dataset 对象。
- validationBatchSize:它是定义批次大小的数字。它用于验证批量大小。这意味着我们不能一次放置超过这个值的所有数据集。它的默认值为 32。
- validationBatches:用于验证样本批次。它用于在每个 epoch 结束时为验证目的绘制验证数据。
- classWeight:用于对损失函数进行加权。告诉模型更多地关注来自代表性不足的类的样本可能很有用。
- initialEpoch:用于定义开始训练的 epoch 值。这对于恢复以前的训练运行很有用。
- yieldEvery:它定义了将主线程让给其他任务的频率的配置。它可以是自动的,这意味着屈服以一定的帧速率发生。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它会产生每个 epoch。任何数字,如果该值是任何数字,它会产生每个数字毫秒。从不,如果值是这个,它永远不会产生。
回报:承诺<历史>
示例 1:在此示例中,我们将使用数组数据集训练我们的模型。
Javascript
import * as tf from "@tensorflow/tfjs"
// Creating model
const gfg_Model = tf.sequential() ;
// Adding layer to model
const config = {units: 1, inputShape: [2]}
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const config2 = {optimizer: 'sgd', loss: 'meanSquaredError'}
gfg_Model.compile(config2);
// Creating Datasets for training
const array1 = [[1,2], [1,4], [1,3], [3,4]];
const array2 = [1, 1];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);
const config3 = {xs:arrData1, ys:arrData2}
const arrayDataset = tf.data.zip(config3)
const ArrayDataset = arrayDataset.batch(3).shuffle(6);
// Training the model
const Tm = await gfg_Model.fitDataset(ArrayDataset, { epochs: 3 });
// Printing the loss after training
console.log("Loss " + " : " + Tm.history.loss[0]);
Javascript
import * as tf from "@tensorflow/tfjs";
// Path for the CSV file
const gfg_CsvFile =
"https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv";
// Creating model
const gfg_Model = tf.sequential();
// Adding layer to model
const config = { units: 1, inputShape: [12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer: opt, loss: "meanSquaredError" });
// Here we want to predict column tax
const config2 = { columnConfigs: { tax: { isLabel: true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
// Creating dataset for training
const flattenedDataset = csvDataset
.map(({ xs, ys }) => {
return { xs: Object.values(xs), ys: Object.values(ys) };
})
.batch(5);
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs: 5 });
for (let i = 0; i < 5; i++) {
console.log(Tm.history.loss[i]);
}
输出:
Loss : 0.428712397813797
示例 2:在此示例中,我们将使用由 csv 文件制作的数据集来训练我们的模型。
Javascript
import * as tf from "@tensorflow/tfjs";
// Path for the CSV file
const gfg_CsvFile =
"https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv";
// Creating model
const gfg_Model = tf.sequential();
// Adding layer to model
const config = { units: 1, inputShape: [12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer: opt, loss: "meanSquaredError" });
// Here we want to predict column tax
const config2 = { columnConfigs: { tax: { isLabel: true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
// Creating dataset for training
const flattenedDataset = csvDataset
.map(({ xs, ys }) => {
return { xs: Object.values(xs), ys: Object.values(ys) };
})
.batch(5);
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs: 5 });
for (let i = 0; i < 5; i++) {
console.log(Tm.history.loss[i]);
}
输出:
21489.68359375
8750.29296875
6632.365234375
5908.6171875
5546.45654296875
参考: https://js.tensorflow.org/api/latest/#tf.Sequential.fitDataset