Last modified: 2023-12-03 15:35:19.440000    Author: Mango
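In TensorFlow, the tf.squeeze() function removes dimensions of size 1 from a tensor. It is often used to drop extra dimensions that appear when working with neural networks. Examples include the leading batch dimension added to a single input sample, or the trailing dimension of a model output with shape (batch, 1).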
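A minimal sketch of where such size-1 dimensions typically come from. It assumes a single 3-feature sample. The call to tf.reduce_sum with keepdims=True only stands in for a layer that outputs one value per example and is not part of the original article.

```python
import tensorflow as tf

# A single sample with 3 features: shape (3,)
sample = tf.constant([0.5, 1.5, 2.5])

# Add a leading batch dimension of size 1: shape (1, 3)
batch = tf.expand_dims(sample, axis=0)

# Stand-in for a layer that produces one value per example: shape (1, 1)
scores = tf.reduce_sum(batch, axis=1, keepdims=True)

# Remove the trailing size-1 axis to get one score per example: shape (1,)
per_example = tf.squeeze(scores, axis=-1)

# Removing the batch dimension again recovers the original shape (3,)
restored = tf.squeeze(batch, axis=0)

print(batch.shape)        # (1, 3)
print(scores.shape)       # (1, 1)
print(per_example.shape)  # (1,)
print(restored.shape)     # (3,)
```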
tf.squeeze(input, axis=None, name=None)
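Parameters:
input (Tensor): The input tensor from which dimensions of size 1 will be removed.
axis (list of ints, optional): If specified, only the listed dimensions are squeezed. Dimension indices start at 0 and must lie in the range [-rank(input), rank(input)); negative indices count from the last dimension. It is an error to squeeze a dimension whose size is not 1.
name (optional): A name for the operation.
Returns: a tensor with the same type and data as input, with the size-1 dimensions removed.

Example:
import tensorflow as tf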
# create a tensor with extra dimensions
x = tf.constant([[[[1], [2]]]])
# use squeeze to remove the extra dimensions
y = tf.squeeze(x)
print("Shape before squeezing:", x.shape)
print("Shape after squeezing:", y.shape)
Output:
Shape before squeezing: (1, 1, 2, 1)
Shape after squeezing: (2,)
In this example, the original tensor x has dimensions of size 1 at indices 0, 1, and 3. Calling tf.squeeze() without an axis removes all of them and returns a new tensor y with shape (2,); x itself is left unchanged.
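The axis parameter restricts squeezing to specific dimensions. The short sketch below reuses the tensor x from the example above, redefined so the snippet runs on its own. It also shows that naming a dimension whose size is not 1 fails instead of being silently ignored.

```python
import tensorflow as tf

# Same tensor as in the example above: shape (1, 1, 2, 1)
x = tf.constant([[[[1], [2]]]])

# Squeeze only the first two dimensions; the trailing size-1 axis is kept
a = tf.squeeze(x, axis=[0, 1])
print(a.shape)  # (2, 1)

# Negative indices count from the end, so -1 is the last dimension
b = tf.squeeze(x, axis=-1)
print(b.shape)  # (1, 1, 2)

# Dimension 2 has size 2, so asking to squeeze it raises an error
try:
    tf.squeeze(x, axis=2)
    squeeze_failed = False
except (tf.errors.InvalidArgumentError, ValueError):
    # Eager execution reports InvalidArgumentError; graph-mode
    # shape inference may report ValueError instead
    squeeze_failed = True
print(squeeze_failed)  # True
```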
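The tf.squeeze() function removes dimensions of size 1 from a tensor. This is useful for matching the shape that a loss, metric, or later operation expects, for example turning a (batch, 1) model output into a (batch,) vector. When a tensor's shape is not fully known in advance, passing axis explicitly is safer than relying on the default. With axis=None, the function also removes dimensions that only happen to be 1, such as the batch dimension of a single-example batch.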
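A minimal sketch of how the choice of axis matters for model outputs of shape (batch, 1). The prediction values are made up for illustration.

```python
import tensorflow as tf

# Hypothetical model outputs: one prediction per example, shape (batch, 1)
preds_batch_of_4 = tf.constant([[0.1], [0.7], [0.4], [0.9]])  # (4, 1)
preds_batch_of_1 = tf.constant([[0.7]])                       # (1, 1)

# Default axis=None removes every size-1 dimension, including a batch of 1
default_4 = tf.squeeze(preds_batch_of_4)
default_1 = tf.squeeze(preds_batch_of_1)
print(default_4.shape)  # (4,)
print(default_1.shape)  # (); the batch dimension is gone

# axis=-1 removes only the trailing axis, whatever the batch size
explicit_4 = tf.squeeze(preds_batch_of_4, axis=-1)
explicit_1 = tf.squeeze(preds_batch_of_1, axis=-1)
print(explicit_4.shape)  # (4,)
print(explicit_1.shape)  # (1,)
```

With an explicit axis, downstream code can rely on a rank-1 result no matter how many examples are in the batch.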