Tensorflow.js tf.LayersModel 类 .fit() 方法
Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。
tf.LayersModel 类的 .fit( ) 方法用于为固定数量的 epoch(数据集上的迭代)训练模型。
句法:
fit(x, y, args?)
参数:此方法接受以下参数。
- x:包含所有输入数据的是 tf.Tensor。
- y:包含所有输出数据的是 tf.Tensor。
- args:对象类型,变量如下:
- batchSize:它定义将通过训练传播的样本数量。
- epochs:它定义了对训练数据数组的迭代。
- 详细:它有助于显示每个时期的进度。如果值为 0 – 这意味着在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node-js 中,它会打印进度条。在浏览器中,它没有显示任何操作。值 1 是默认值。 2 – 值 2 尚未实现。
- 回调:它定义了在训练期间要调用的回调列表。变量可以有一个或多个回调 onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
- validationSplit:它使用户可以轻松地将训练数据集拆分为训练和验证。例如:如果它的值是validation-Split = 0.5,这意味着在洗牌之前使用最后50%的数据进行验证。
- validationData:用于在最终模型之间进行选择时对最终模型进行估计。
- shuffle:这个值定义了每个 epoch 之前数据的 shuffle。当stepsPerEpoch 不为null 时,它不起作用。
- classWeight:用于对损失函数进行加权。告诉模型更多地关注来自代表性不足的类的样本可能很有用。
- sampleWeight:它是一个权重数组,适用于每个样本的模型损失。
- initialEpoch:它是定义开始训练的时期的值。这对于恢复以前的训练运行很有用。
- stepsPerEpoch:它在声明一个 epoch 完成并开始下一个 epoch 之前定义了一批样本。如果未确定,则等于 1。
- validationSteps:如果指定了stepsPerEpoch ,则相关。停止前要验证的总步骤数。
- yieldEvery:它定义了将主线程让给其他任务的频率的配置。它可以是自动的,这意味着屈服以一定的帧速率发生。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它会产生每个 epoch。任何数字,如果该值是任何数字,则产生每个数字毫秒。 never ,如果值是这个,它永远不会产生。
回报:它回报了历史的承诺。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const mymodel = tf.sequential({
layers: [tf.layers.dense({units: 2, inputShape: [6]})]
});
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
// Using for loop
for (let i = 0; i < 4; i++) {
// Calling fit() method
const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
batchSize: 5,
epochs: 4
});
// Printing output
console.log(his.history.loss[1]);
}
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const mymodel = tf.sequential({
layers: [tf.layers.dense({units: 2, inputShape: [6],
activation : "sigmoid"})]});
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
// Calling fit() method
const his = await mymodel.fit(tf.truncatedNormal([6, 6]),
tf.randomNormal([6, 2]), { batchSize: 5,
epochs: 4, validationSplit: 0.2,
shuffle: true, initialEpoch: 2,
stepsPerEpoch: 1, validationSteps: 2});
// Printing output
console.log(JSON.stringify(his.history));
输出:
0.9574100375175476
0.8151942491531372
0.694103479385376
0.5909997820854187
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const mymodel = tf.sequential({
layers: [tf.layers.dense({units: 2, inputShape: [6],
activation : "sigmoid"})]});
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
// Calling fit() method
const his = await mymodel.fit(tf.truncatedNormal([6, 6]),
tf.randomNormal([6, 2]), { batchSize: 5,
epochs: 4, validationSplit: 0.2,
shuffle: true, initialEpoch: 2,
stepsPerEpoch: 1, validationSteps: 2});
// Printing output
console.log(JSON.stringify(his.history));
输出:
{"val_loss":[0.35800713300704956,0.35819053649902344],
"loss":[0.633269190788269,0.632409930229187]}
参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit