📌  相关文章
📜  Tensorflow.js tf.data.Dataset 类 .batch() 方法

📅  最后修改于: 2022-05-13 01:56:53.158000             🧑  作者: Mango

Tensorflow.js tf.data.Dataset 类 .batch() 方法

Tensorflow.js是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。它还可以帮助开发人员用 JavaScript 语言开发 ML 模型,并且可以直接在浏览器或 Node.js 中使用 ML。

tf.data.Dataset.batch()函数用于将元素分组为批次。

句法:

tf.data.Dataset.batch(batchSize, smallLastBatch?)

参数:

  • batchSize:应该在一个批次中存在的元素。
  • smallLastBatch:如果为 true,如果最后一批的元素少于 batchSize,则最后一批将发出元素,反之亦然。默认值为真。提供此值是可选的。

返回值:它返回一个 tf.data.Dataset。

示例 1:在此示例中,我们将获取一个大小为 6 的数组,并将其拆分为多个批次,每个批次包含 3 个元素。

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating an array
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60]
).batch(3);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating an array
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating an array
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3, false);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);


输出:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"

示例 2:这次我们将取 8 个元素,并尝试将它们分批拆分,每批 3 个元素。

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating an array
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);

输出:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"
"Tensor
    [70, 80]"

由于 smallLastBatch 的默认值默认为 true,因此我们看到了具有 2 个元素的第三批。

示例 3:这次我们将 smallLastBatch 参数作为 false 传递。

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating an array
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3, false);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);

输出:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"

由于 smallLastBatch 的默认值为 false,我们看不到第三批,因为最后一批中只有 2 个元素,小于 3,即指定的批大小。

参考: https://js.tensorflow.org/api/latest/#tf.data.Dataset.batch