Python – tensorflow.gather()

Last modified: 2023-12-03 15:34:06.779000    Author: Mango


Introduction

In TensorFlow, the tf.gather() function gathers slices from a tensor along a specified axis. It works like NumPy integer-array indexing, and more precisely like np.take(). The function has two required parameters: the first is the tensor to gather from (params), and the second is a tensor of indices selecting which slices of the first tensor to take.
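To make the NumPy comparison concrete, here is a small sketch, assuming TensorFlow 2.x in eager mode and NumPy installed. The array values and the names params and idx are purely illustrative. It gathers the same rows and columns with tf.gather() and with np.take(); note that np.take() needs an explicit axis, because its default axis=None flattens the array first.

```python
import numpy as np
import tensorflow as tf

# Illustrative 3x3 array; row r holds the values 10*r + 10 .. 10*r + 12.
params = np.array([[10, 11, 12],
                   [20, 21, 22],
                   [30, 31, 32]])
idx = [2, 0]

# Gather rows 2 and 0 (tf.gather uses axis 0 by default).
tf_rows = tf.gather(params, idx).numpy()
np_rows = np.take(params, idx, axis=0)

# Gather columns 2 and 0 instead.
tf_cols = tf.gather(params, idx, axis=1).numpy()
np_cols = np.take(params, idx, axis=1)

print(tf_rows.tolist())  # [[30, 31, 32], [10, 11, 12]]
print(np.array_equal(tf_rows, np_rows), np.array_equal(tf_cols, np_cols))  # True True
```

For 1-D index lists along axis 0, plain fancy indexing (params[idx]) gives the same rows as well.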

Syntax:
tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)
  • params: A tensor from which we need to gather slices.
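  • indices: A tensor of type int32 or int64 containing the indices to gather along axis of params. Each value should lie in the range [0, params.shape[axis]).
  • validate_indices: Deprecated; it has no effect. Indices are always validated on CPU, where an out-of-range index raises tf.errors.InvalidArgumentError, and never validated on GPU, where an out-of-range index produces 0 in the corresponding output.
  • axis: An integer or a scalar int32/int64 Tensor giving the axis of params along which slices are gathered. If axis=None (default), it falls back to the first non-batch dimension, i.e. axis batch_dims (axis 0 when batch_dims=0). Negative values count from the last axis, and the value must be greater than or equal to batch_dims.
  • batch_dims: An integer (default 0) giving the number of leading batch dimensions shared by params and indices. It must be less than or equal to the rank of indices.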
  • name: A name for the operation.
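The following sketch, assuming TensorFlow 2.x eager execution, exercises the axis and batch_dims parameters and the out-of-range behavior. The tensor x, the index tensor idx and the flag raised are illustrative names only. The out-of-range check is pinned to the CPU because TensorFlow validates indices only on CPU.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])  # shape (2, 3)

# axis=1 gathers columns; axis=-1 names the same (last) axis.
cols = tf.gather(x, [2, 0], axis=1)
cols_neg = tf.gather(x, [2, 0], axis=-1)
print(cols.numpy().tolist())  # [[3, 1], [6, 4]]

# batch_dims=1: row i of x is indexed with row i of idx,
# i.e. out[i, j] = x[i, idx[i, j]].
idx = tf.constant([[0, 0],
                   [2, 1]])
per_row = tf.gather(x, idx, axis=1, batch_dims=1)
print(per_row.numpy().tolist())  # [[1, 1], [6, 5]]

# On CPU, index 5 is outside [0, 2) for axis 0 and raises an error.
raised = False
with tf.device("/CPU:0"):
    try:
        tf.gather(x, [5])
    except tf.errors.InvalidArgumentError:
        raised = True
print(raised)  # True
```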
Example:
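import tensorflow as tf

# A 3x2 tensor: three rows of two elements each.
a = tf.constant([[1, 2], [3, 4], [5, 6]])

# Row indices to collect (axis defaults to 0, i.e. rows).
indices = tf.constant([0, 1])

# Gather rows 0 and 1 of a.
result = tf.gather(a, indices)

print(result)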

Output:

tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)
Explanation:

In the example above, a is a constant tensor of shape (3, 2), and indices is a tensor of shape (2,) holding the indices of the rows to collect. Both are passed to tf.gather(); since no axis is given, the gather runs along axis 0, so rows 0 and 1 of a are returned.

The resulting tensor has shape (2, 2). Gathering along axis 0 gives an output of shape indices.shape + a.shape[1:], i.e. (2,) + (2,): one row of length 2 for each of the two indices.
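The shape reasoning above generalizes. With batch_dims=0, the output shape of tf.gather() is params.shape[:axis] + indices.shape + params.shape[axis + 1:]. The sketch below assumes TensorFlow 2.x, and expected_gather_shape is a hypothetical helper written for this illustration, not a TensorFlow API. It checks that rule against the example tensor a, including a 2-D indices tensor where every index is replaced by a whole row.

```python
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4], [5, 6]])  # shape (3, 2)

def expected_gather_shape(params_shape, indices_shape, axis=0):
    """Hypothetical helper: shape rule for tf.gather with batch_dims=0,
    params.shape[:axis] + indices.shape + params.shape[axis + 1:]."""
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

# 1-D indices, as in the example: (2,) + (2,) -> (2, 2)
r1 = tf.gather(a, [0, 1])

# 2-D indices: each index becomes a whole row, (2, 2) + (2,) -> (2, 2, 2)
idx2 = tf.constant([[0, 2], [1, 1]])
r2 = tf.gather(a, idx2)

print(tuple(r1.shape), expected_gather_shape((3, 2), (2,)))
print(tuple(r2.shape), expected_gather_shape((3, 2), (2, 2)))
print(r2.numpy().tolist())  # [[[1, 2], [5, 6]], [[3, 4], [3, 4]]]
```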

Conclusion:

In this article, we have discussed the syntax and usage of the tf.gather() function in TensorFlow. It is useful whenever slices of one tensor must be selected by the indices stored in another tensor.