📜  Tensorflow.js tf.gatherND()函数(1)

📅  最后修改于: 2023-12-03 15:35:17.240000             🧑  作者: Mango

Tensorflow.js tf.gatherND()函数介绍

简介

Tensorflow.js中的tf.gatherND()函数通过提供索引来从张量中收集元素。它支持索引多个元素,包括多个维度和不同形状的张量。tf.gatherND()函数也可以用于子集采样,其中张量中的一部分元素被采样。

语法

以下是tf.gatherND()函数的语法:

tf.gatherND(x, indices)
  • x: 张量,要从中抽取元素的张量。
  • indices: 张量,要在x中选择元素的索引。indices张量的最后一个维度必须与x的最后一个维度匹配,并且它们必须具有不同的维度大小。最后一个维度可以是任何大小。
示例
const x = tf.tensor3d([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]);
const indices = tf.tensor2d([[1, 0], [0, 1]]);
const result = tf.gatherND(x, indices);

result.print();
// output: 
// [[2, 1],
//  [4, 7]]

在上面的示例中,我们使用tf.tensor3d()创建了一个形状为[2, 2, 2]的3D张量x。然后,我们使用tf.tensor2d()创建一个形状为[2, 2]的2D张量indices,该张量表示要从x中选择的元素的索引。我们使用tf.gatherND()函数来选择这些元素并存储在结果张量中,并使用print()函数在控制台中打印结果。

注意事项
  • indices的形状必须是[s1,...,sn, r, q],其中s1,...,sn是在x上的形状,r是最后一个维度大小,q可以是任何大小。
  • tf.gatherND()函数返回的张量的形状与indices的形状相同,其中最后一个维度被替换为x张量中的相应维度大小。
总结

tf.gatherND()函数是一个有用的函数,可以根据索引从张量中选择元素。这个函数支持多个维度和不同形状的张量,可以用于子集采样。本文介绍了tf.gatherND()函数的用法,示例和注意事项,希望可以帮助您更好地使用Tensorflow.js。