📅  最后修改于: 2023-12-03 15:20:35.718000             🧑  作者: Mango
Tensorflow.js 中提供了 tf.whereAsync()
函数,以帮助程序员在 Javascript 中方便的进行条件判断。
tf.whereAsync()
函数的作用是,返回两个张量(或一个标量和一个张量)中符合条件的元素的索引。
whereAsync(condition: Tensor | Tensor[], x: Tensor | Tensor[], y: Tensor | Tensor[]): Promise<Tensor>
condition
: Tensor
或 Tensor
数组。一个布尔类型的张量或张量数组,表示需要判断的条件。当张量数组时,张量应该具有相同的形状。x
: Tensor
或 Tensor
数组。一个张量或张量数组,表示当条件为真时需要返回的结果。当张量数组时,张量应该具有相同的形状。y
: Tensor
或 Tensor
数组。一个张量或张量数组,表示当条件为假时需要返回的结果。当张量数组时,张量应该具有相同的形状。condition
张量的形状相同,每个位置上的数值为符合条件的张量的相应位置的索引值。const tf = require("@tensorflow/tfjs");
async function main() {
const condition = tf.tensor2d([true, false, true], [1, 3]);
const x = tf.tensor2d([1, 2, 3], [1, 3]);
const y = tf.tensor2d([4, 5, 6], [1, 3]);
const indices = await tf.whereAsync(condition, x, y);
console.log(indices.toString()); // 输出 [[0, 0], [0, 2]]
const selected = await tf.gatherNDAsync(x.concat(y, 0), indices);
console.log(selected.toString()); // 输出 [[1, 3]]
}
main();
上述示例中,我们定义了一个包含 3 个元素的布尔类型的张量 condition
,表示三个数是否满足某个条件,另外我们定义了两个包含 3 个元素的张量 x
和 y
,分别表示条件成立和不成立时需返回的结果。最后,我们使用 tf.whereAsync()
函数获取了符合条件的元素的索引,表示 condition
张量中的元素为 true
的位置。我们通过调用 tf.gatherNDAsync()
函数从 x
和 y
中选出符合条件的元素,并将它们连接在一起。最终输出的就是一个包含符合条件的元素的张量。