📜  Tensorflow.js tf.gather()函数(1)

📅  最后修改于: 2023-12-03 15:05:32.780000             🧑  作者: Mango

TensorFlow.js tf.gather()函数

tf.gather()函数是TensorFlow.js中的一个函数,用于从输入张量的轴上抽取元素。抽取的顺序由索引张量决定,索引张量的形状必须与要抽取的轴在除抽取维度之外的其他维度上匹配。这一函数可以用于数据的特定维度上的选择或两个张量的元素特定组合。

语法
tf.gather(input, indices, axis?)
参数
  • input: tf.Tensor - 输入张量
  • indices: tf.Tensor - 索引张量,形状必须与要抽取的轴在除抽取维度之外的其他维度上匹配
  • axis: number (可选) - 要在其上选取的轴,默认情况下是0
返回

一个新的张量,包含从输入张量的指定轴上抽取的元素。

示例
const tensor = tf.tensor2d([[1, 2], [3, 4], [5, 6]]); // 创建一个2维张量

const indices = tf.tensor1d([1]); // 抓出第二个元素
const result = tf.gather(tensor, indices, 0); // 在0轴上抓出第二个元素

result.print(); // 打印 [[3, 4]]
总结

TensorFlow.js中的tf.gather()函数是一个实用的函数,可以从输入张量的特定轴上抽取元素。在数据科学中,这种函数可以用于选择特定维度的数据或者组合不同的数据。记得索引张量的形状必须与要抽取的轴在除抽取维度之外的其他维度上匹配。