📌  相关文章
📜  Tensorflow.js tf.LayersModel 类 .fit() 方法

📅  最后修改于: 2022-05-13 01:56:25             🧑  作者: Mango

Tensorflow.js tf.LayersModel 类 .fit() 方法

Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

tf.LayersModel 类的 .fit( ) 方法用于为固定数量的 epoch(数据集上的迭代)训练模型。

句法:

fit(x, y, args?)

参数:此方法接受以下参数。

  • x:包含所有输入数据的是 tf.Tensor。
  • y:包含所有输出数据的是 tf.Tensor。
  • args:对象类型,变量如下:
    1. batchSize:它定义将通过训练传播的样本数量。
    2. epochs:它定义了对训练数据数组的迭代。
    3. 详细:它有助于显示每个时期的进度。如果值为 0 – 这意味着在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node-js 中,它会打印进度条。在浏览器中,它没有显示任何操作。值 1 是默认值。 2 – 值 2 尚未实现。
    4. 回调:它定义了在训练期间要调用的回调列表。变量可以有一个或多个回调 onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
    5. validationSplit:它使用户可以轻松地将训练数据集拆分为训练和验证。例如:如果它的值是validation-Split = 0.5,这意味着在洗牌之前使用最后50%的数据进行验证。
    6. validationData:用于在最终模型之间进行选择时对最终模型进行估计。
    7. shuffle:这个值定义了每个 epoch 之前数据的 shuffle。当stepsPerEpoch 不为null 时,它不起作用。
    8. classWeight:用于对损失函数进行加权。告诉模型更多地关注来自代表性不足的类的样本可能很有用。
    9. sampleWeight:它是一个权重数组,适用于每个样本的模型损失。
    10. initialEpoch:它是定义开始训练的时期的值。这对于恢复以前的训练运行很有用。
    11. stepsPerEpoch:它在声明一个 epoch 完成并开始下一个 epoch 之前定义了一批样本。如果未确定,则等于 1。
    12. validationSteps:如果指定了stepsPerEpoch ,则相关。停止前要验证的总步骤数。
    13. yieldEvery:它定义了将主线程让给其他任务的频率的配置。它可以是自动的,这意味着屈服以一定的帧速率发生。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它会产生每个 epoch。任何数字,如果该值是任何数字,则产生每个数字毫秒never ,如果值是这个,它永远不会产生。

回报:它回报了历史的承诺。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const mymodel = tf.sequential({
     layers: [tf.layers.dense({units: 2, inputShape: [6]})]
});
  
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
  
// Using for loop
for (let i = 0; i < 4; i++) {
     
  // Calling fit() method
  const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
       batchSize: 5,
       epochs: 4
   });
    
    // Printing output
    console.log(his.history.loss[1]);
}


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const mymodel = tf.sequential({
     layers: [tf.layers.dense({units: 2, inputShape: [6], 
                               activation : "sigmoid"})]});
  
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});   
  
// Calling fit() method
 const his = await mymodel.fit(tf.truncatedNormal([6, 6]), 
                             tf.randomNormal([6, 2]), { batchSize: 5,
                             epochs: 4, validationSplit: 0.2, 
                             shuffle: true, initialEpoch: 2, 
                             stepsPerEpoch: 1, validationSteps: 2});
    
// Printing output
console.log(JSON.stringify(his.history));


输出:

0.9574100375175476
0.8151942491531372
0.694103479385376
0.5909997820854187

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const mymodel = tf.sequential({
     layers: [tf.layers.dense({units: 2, inputShape: [6], 
                               activation : "sigmoid"})]});
  
// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});   
  
// Calling fit() method
 const his = await mymodel.fit(tf.truncatedNormal([6, 6]), 
                             tf.randomNormal([6, 2]), { batchSize: 5,
                             epochs: 4, validationSplit: 0.2, 
                             shuffle: true, initialEpoch: 2, 
                             stepsPerEpoch: 1, validationSteps: 2});
    
// Printing output
console.log(JSON.stringify(his.history));

输出:

{"val_loss":[0.35800713300704956,0.35819053649902344],
"loss":[0.633269190788269,0.632409930229187]}

参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit