📅  Last modified: 2023-12-03 15:34:06.779000             🧑  Author: Mango
In TensorFlow, the tf.gather() function gathers slices from a tensor along a specified axis. It is similar to integer-array indexing in NumPy. The function has two required parameters. The first (params) is the tensor from which slices are gathered. The second (indices) is a tensor of indices that selects which slices to take from the first tensor. The remaining parameters are optional.
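The sketch below makes the NumPy comparison concrete. It assumes TensorFlow 2.x with eager execution and NumPy installed, and the tensor values are made up for illustration. Gathering along axis 0 selects whole rows, which gives the same result as NumPy integer-array indexing or np.take with axis=0.

```python
import numpy as np
import tensorflow as tf

# Illustrative 3x3 tensor (values chosen so each row is easy to spot)
params = tf.constant([[10, 11, 12],
                      [20, 21, 22],
                      [30, 31, 32]])
indices = [2, 0]

# tf.gather along the default axis 0 picks rows 2 and 0, in that order
tf_rows = tf.gather(params, indices)

# NumPy equivalents: integer-array ("fancy") indexing and np.take on axis 0
np_rows = params.numpy()[indices]
np_take = np.take(params.numpy(), indices, axis=0)

print(tf_rows.numpy())
# [[30 31 32]
#  [10 11 12]]
print(np.array_equal(tf_rows.numpy(), np_rows), np.array_equal(tf_rows.numpy(), np_take))
# True True
```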
tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)
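Parameters:
- params: the tensor from which slices are gathered.
- indices: an int32 or int64 tensor of indices into params. The values should lie in the range [0, params.shape[axis]).
- validate_indices: deprecated; it has no effect in TensorFlow 2.x. On CPU, an out-of-range index raises tf.errors.InvalidArgumentError. On GPU, out-of-range indices are not checked, and 0 is stored in the corresponding output value.
- axis: the axis of params to gather from. If None (the default), it is the first non-batch dimension, which is axis 0 when batch_dims=0. Negative values are allowed and count from the last axis.
- batch_dims: the number of leading batch dimensions shared by params and indices (default 0).
- name: an optional name for the operation.

Example:
import tensorflow as tf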
a = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([0, 1])
result = tf.gather(a, indices)
print(result)
Output:
tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)
In the above example, we define a constant tensor a of shape (3, 2). We also define a constant tensor indices of shape (2,), which holds the indices of the rows to collect. We then pass a and indices to the tf.gather() function.
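The indices do not have to be a tf.constant, and they do not have to be in ascending order. The sketch below assumes TensorFlow 2.x eager execution and reuses the same values for a as the example above. It shows that a plain Python list of indices can reorder and repeat rows. It also shows that a scalar index returns a single row without the gathered dimension.

```python
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4], [5, 6]])  # shape (3, 2)

# A Python list works as indices; rows can be reordered and repeated
picked = tf.gather(a, [2, 0, 2])
print(picked.numpy())
# [[5 6]
#  [1 2]
#  [5 6]]

# A scalar index removes the gathered axis: the result has shape (2,)
row = tf.gather(a, 1)
print(row.numpy())  # [3 4]
```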
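Because axis defaults to 0, tf.gather() selects rows 0 and 1 of a. The resulting tensor result has shape (2, 2). The first dimension is 2 because two indices were given, and the second dimension is 2 because each row of a has two elements.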
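The shape reasoning above extends to other axes. The following sketch assumes TensorFlow 2.x in eager mode. It gathers columns with axis=1 and checks that axis=-1 gives the same result here. It then verifies the general shape rule with a 2-D indices tensor. Finally, it uses batch_dims=1 to pick one element from each row. The shape formula in the comments is the rule described in the tf.gather API documentation for batch_dims=0.

```python
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4], [5, 6]])  # shape (3, 2)

# axis=1 gathers columns instead of rows; for a 2-D tensor axis=-1 is the same axis
cols = tf.gather(a, [1, 0], axis=1)
print(cols.numpy().tolist())  # [[2, 1], [4, 3], [6, 5]]
print(cols.numpy().tolist() == tf.gather(a, [1, 0], axis=-1).numpy().tolist())  # True

# Shape rule (batch_dims=0):
#   params.shape[:axis] + indices.shape + params.shape[axis + 1:]
# Here: () + (2, 2) + (2,) = (2, 2, 2)
idx = tf.constant([[0, 2], [1, 1]])  # shape (2, 2)
out = tf.gather(a, idx, axis=0)
print(out.shape)  # (2, 2, 2)

# batch_dims=1: row i of the indices selects from row i of a
b_idx = tf.constant([[1], [0], [1]])  # one column index per row of a
per_row = tf.gather(a, b_idx, axis=1, batch_dims=1)
print(per_row.numpy().tolist())  # [[2], [3], [6]]
```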
In this article, we have covered the syntax and usage of the tf.gather() function in TensorFlow. The function is useful whenever you need to select slices of a tensor using indices held in another tensor.