📌  相关文章
📜  Tensorflow.js tf.data.Dataset 类

📅  最后修改于: 2022-05-13 01:56:32.298000             🧑  作者: Mango

Tensorflow.js tf.data.Dataset 类

Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

Tensorflow.js tf.data.Dataset 类代表大量数据集合。数据可以是数组、映射或基元,也可以是此数据类型的任何嵌套结构。数据类型在一些有序的数据集合中。我们可以对这些数据执行各种方法。这种方法会产生一个新的数据集。

句法:

tf.data.method(args);

参数:此方法接受以下参数:

  • args:并非所有方法都相同,不同方法可能不同。

返回值:它返回与输入相同类型的数据集。

示例 1:在本示例中,我们将看到 batch() 方法。它按批次对元素进行分组。它接受两个参数,batchSize 是组的长度,smallLastBatch 是一个布尔值,告诉打印最后一批是否长度小。

Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Creating tenst dataset of array
const array = [[1, 4, 6], [2, 5, 7],
    [3, 6, 8], [4, 7, 9], [5, 8, 11]];
 
// Making dataset with array
const gfg = tf.data.array(array);
 
// Creating new dataset
const GFG = gfg.batch(2);
 
// Printing Dataset
GFG.forEachAsync(Q => Q.print());


Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Creating Dataset  
const gfg_array = [4, 8, 12, 16, 20];
const gfg = tf.data.array(gfg_array);
 
// Prefetching data from dataset
const GFG_array = gfg.shuffle(5);
 
// Printing data from array
GFG_array.forEachAsync( tm => console.log(tm))


输出:

Tensor
    [[1, 4, 6],
     [2, 5, 7]]
Tensor
    [[3, 6, 8],
     [4, 7, 9]]
Tensor
     [[5, 8, 11],]

示例 2:在此示例中,我们将看到 shuffle() 方法。它用于打乱数据集的元素。它采用缓冲区大小,这是随机元素将从其开始的数字,种子用于创建随机种子以分配元素,或者最后一个参数是 reshuffle each_iteration 这是一个布尔值,它告诉它在迭代时随机重新洗牌数据集。

Javascript

// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Creating Dataset  
const gfg_array = [4, 8, 12, 16, 20];
const gfg = tf.data.array(gfg_array);
 
// Prefetching data from dataset
const GFG_array = gfg.shuffle(5);
 
// Printing data from array
GFG_array.forEachAsync( tm => console.log(tm))

输出:

8
20
4
16
12

参考: https://js.tensorflow.org/api/latest/#class:data.Dataset