📅  最后修改于: 2023-12-03 15:35:18.110000             🧑  作者: Mango
tf.where()
函数是TensorFlow.js中的一个用于条件判断的函数。它可以在给定条件下返回满足条件的元素在tensor中的索引。
tf.where(condition: Tensor, a: Tensor, b: Tensor): Tensor;
其中,condition
是一个0和1组成的tensor,其维度应和a
和b
的维度相同。a
和b
也都是tensor,且其维度应一致。
函数返回值是一个由满足条件的元素在tensor中的索引所组成的tensor。
下面是一个使用tf.where()
函数的示例:
const condition = tf.tensor2d([[0.12, 0.85], [0.63, 0.94]]);
const a = tf.tensor2d([[1, 2], [3, 4]]);
const b = tf.tensor2d([[5, 6], [7, 8]]);
const result = tf.where(condition, a, b);
result.print(); // 输出张量
输出结果为:
Tensor
[[5, 2],
[3, 4]]
在上面的示例中,我们先定义了一个由一组条件组成的tensor condition
,还定义了两个和condition
维度相同的tensor a
和b
。
然后,我们使用tf.where()
函数来判断condition
中每个元素的值是否为1。如果是1,则返回相同位置的a
中的元素值;如果是0,则返回相同位置的b
中的元素值。
在示例中,condition
的第二个元素为1,所以结果tensor的第二个元素为a
中对应位置的值,也就是2。而第一个元素为0,所以结果tensor的第一个元素为b
中对应位置的值,也就是5。
通过使用tf.where()
函数,我们可以在TensorFlow.js中进行条件判断,并返回满足条件的元素在tensor中的索引。在实际编程中,tf.where()
函数的应用场景十分广泛,如在图像处理中进行像素筛选、在机器学习中进行条件判定等等。