📅  最后修改于: 2023-12-03 15:35:17.145000             🧑  作者: Mango
在TensorFlow.js中,tf.data.Dataset是一个非常重要的类,用于处理数据集的输入。.batch()方法是 Dataset类 中的一种方法,它可以将Dataset对象的元素组合成批次。
dataset.batch(batchSize, dropRemainder)
// 创建一个包含一组张量的Dataset对象
const dataset = tf.data.generator(function*() {
yield tf.tensor1d([1, 2, 3]);
yield tf.tensor1d([4, 5, 6]);
yield tf.tensor1d([7, 8, 9]);
});
// 使用 .batch() 方法设置一个批次大小为2的新Dataset对象
const batchedDataset = dataset.batch(2);
// 打印新的Dataset对象中的元素
batchedDataset.forEachAsync(e => {
console.log(e.toString());
});
输出结果:
Tensor
[1, 2, 3]
shape: [3]
dtype: "float32"
Tensor
[4, 5, 6]
shape: [3]
dtype: "float32"
Tensor
[7, 8, 9]
shape: [3]
dtype: "float32"
上述示例中,我们创建了一个包含三个张量的Dataset对象。我们使用 .batch() 方法将每两个张量组合成一个批次,然后使用forEachAsync()方法逐个打印元素。
使用 .batch() 方法可以将一个Dataset对象的元素分成多个批次。这在训练机器学习模型时非常有用,可以将大型数据集分批读入内存,减轻内存压力,提高模型的训练效率。