📜  Tensorflow.js tf.whereAsync()函数(1)

📅  最后修改于: 2023-12-03 15:20:35.718000             🧑  作者: Mango

Tensorflow.js tf.whereAsync() 函数

Tensorflow.js 中提供了 tf.whereAsync() 函数,以帮助程序员在 Javascript 中方便的进行条件判断。

函数说明

tf.whereAsync() 函数的作用是,返回两个张量(或一个标量和一个张量)中符合条件的元素的索引。

whereAsync(condition: Tensor | Tensor[], x: Tensor | Tensor[], y: Tensor | Tensor[]): Promise<Tensor>
参数说明
  • condition: TensorTensor 数组。一个布尔类型的张量或张量数组,表示需要判断的条件。当张量数组时,张量应该具有相同的形状。
  • x: TensorTensor 数组。一个张量或张量数组,表示当条件为真时需要返回的结果。当张量数组时,张量应该具有相同的形状。
  • y: TensorTensor 数组。一个张量或张量数组,表示当条件为假时需要返回的结果。当张量数组时,张量应该具有相同的形状。
返回值说明
  • 返回一个张量,表示符合条件的元素的索引。返回的张量的形状与输入的 condition 张量的形状相同,每个位置上的数值为符合条件的张量的相应位置的索引值。
代码示例
const tf = require("@tensorflow/tfjs");

async function main() {
  const condition = tf.tensor2d([true, false, true], [1, 3]);
  const x = tf.tensor2d([1, 2, 3], [1, 3]);
  const y = tf.tensor2d([4, 5, 6], [1, 3]);

  const indices = await tf.whereAsync(condition, x, y);
  console.log(indices.toString()); // 输出 [[0, 0], [0, 2]]

  const selected = await tf.gatherNDAsync(x.concat(y, 0), indices);
  console.log(selected.toString()); // 输出 [[1, 3]]
}

main();
解释

上述示例中,我们定义了一个包含 3 个元素的布尔类型的张量 condition,表示三个数是否满足某个条件,另外我们定义了两个包含 3 个元素的张量 xy,分别表示条件成立和不成立时需返回的结果。最后,我们使用 tf.whereAsync() 函数获取了符合条件的元素的索引,表示 condition 张量中的元素为 true 的位置。我们通过调用 tf.gatherNDAsync() 函数从 xy 中选出符合条件的元素,并将它们连接在一起。最终输出的就是一个包含符合条件的元素的张量。