Tensorflow.js tf.gather()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.gather()函数用于根据规定的索引从规定的张量 x 轴收集片段。
句法:
tf.gather(x, indices, axis?, batchDims?)
参数:
- x:指定的输入张量,要收集的分片,可以是tf.Tensor、TypedArray或Array类型。
- indices:表示要取出的值的索引,可以是 tf.Tensor、TypedArray 或 Array 类型。
- 轴:要在其上方选择值的指定轴。默认值为零,并且它的类型为 number。但是,此参数是可选的。
- batchDims:它是规定的批大小数量,应小于或等于规定的等级,即索引。它的默认值为零。此外,返回的输出必须具有 x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:] 的形状。它的类型为 number 并且是可选的。
返回值:返回 tf.Tensor 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input and indices
const y = tf.tensor1d([1, 6, 7, 8]);
const ind = tf.tensor1d([1, 6, 2], 'int32');
// Calling tf.gather() method and
// Printing output
y.gather(ind).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input, indices, axis,
// and batchdims
const y = tf.tensor2d([7, 8, 12, 13], [4, 1]);
const ind = tf.tensor1d([2, 3, 0], 'int32');
const axis = 1;
const batchdims = -1;
// Calling tf.gather() method
var res = tf.gather(y, ind, axis, batchdims);
// Printing output
res.print();
输出:
Tensor
[6, NaN, 7]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input, indices, axis,
// and batchdims
const y = tf.tensor2d([7, 8, 12, 13], [4, 1]);
const ind = tf.tensor1d([2, 3, 0], 'int32');
const axis = 1;
const batchdims = -1;
// Calling tf.gather() method
var res = tf.gather(y, ind, axis, batchdims);
// Printing output
res.print();
输出:
Tensor
[[12 , 13 , 7 ],
[13 , NaN, 8 ],
[NaN, NaN, 12],
[NaN, NaN, 13]]
参考: https://js.tensorflow.org/api/latest/#gather