📌  相关文章
📜  Tensorflow.js tf.data.Dataset 类 .batch() 方法(1)

📅  最后修改于: 2023-12-03 15:35:17.145000             🧑  作者: Mango

TensorFlow.js tf.data.Dataset类 .batch() 方法介绍

在TensorFlow.js中,tf.data.Dataset是一个非常重要的类,用于处理数据集的输入。.batch()方法是 Dataset类 中的一种方法,它可以将Dataset对象的元素组合成批次。

方法语法
dataset.batch(batchSize, dropRemainder)
参数说明
  • batchSize(必选):一个正整数,表示每个批次中包含的元素数量。
  • dropRemainder(可选):一个布尔值,表示是否丢弃最后一个批次中数量小于batchSize的元素。
代码示例
// 创建一个包含一组张量的Dataset对象
const dataset = tf.data.generator(function*() {
  yield tf.tensor1d([1, 2, 3]);
  yield tf.tensor1d([4, 5, 6]);
  yield tf.tensor1d([7, 8, 9]);
});

// 使用 .batch() 方法设置一个批次大小为2的新Dataset对象
const batchedDataset = dataset.batch(2);

// 打印新的Dataset对象中的元素
batchedDataset.forEachAsync(e => {
  console.log(e.toString());
});

输出结果:

Tensor
  [1, 2, 3]
  shape: [3]
  dtype: "float32"

Tensor
  [4, 5, 6]
  shape: [3]
  dtype: "float32"

Tensor
  [7, 8, 9]
  shape: [3]
  dtype: "float32"

上述示例中,我们创建了一个包含三个张量的Dataset对象。我们使用 .batch() 方法将每两个张量组合成一个批次,然后使用forEachAsync()方法逐个打印元素。

总结

使用 .batch() 方法可以将一个Dataset对象的元素分成多个批次。这在训练机器学习模型时非常有用,可以将大型数据集分批读入内存,减轻内存压力,提高模型的训练效率。