📜  Tensorflow.js tf.data.array() 方法(1)

📅  最后修改于: 2023-12-03 14:47:54.740000             🧑  作者: Mango

TensorFlow.js tf.data.array() 方法

TensorFlow.js 是 Google TensorFlow 的 JavaScript 版本,它提供了一些不同于传统机器学习的功能,其 API 设计体现了实现高性能机器学习在 Web 浏览器中的可能性。

tf.data.array() 是 TensorFlow.js 中一个数据集的构造方法,它直接从数组中创建一个数据集。

语法
tf.data.array(
  data: Array | TypedArray | Iterable,
  options: {
    batchSize?: number;
    numEpochs?: number;
    shuffle?: boolean;
  } = {}
);
  • data:一个数组、一种类型的数组或一个可迭代对象。
  • options:一个可选选项对象,包括batchSizenumEpochsshuffle等参数。
示例

本例展示了如何使用 tf.data.array() 从数字数组中创建一个数据集,并对数据集进行遍历。

const data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
const dataset = tf.data.array(data);

dataset.forEachAsync(e => {
  console.log(e);
})

你可以使用可选的batchSizenumEpochsshuffle参数来控制生成的数据集。例如,下面的代码展示了如何生成随机的批次和 epochs。

const data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
const dataset = tf.data.array(data, {
  batchSize: 2,
  numEpochs: 2,
  shuffle: true
});

dataset.forEachAsync(e => {
  console.log(e);
})

可选的参数 batchSize 定义了每个批次所包含的示例数, numEpochs 定义了数据集将被遍历的次数, shuffle 则指示数据集是否随机排序。

结论

tf.data.array()使得在 TensorFlow.js 中使用数组数据变得简单方便。它可以用于许多不同的用例,除了本文介绍的示例之外,还可以从多个数组中创建数据集,也可以使用 Iterator 等可迭代对象创建数据集。