📅  最后修改于: 2023-12-03 14:49:09.909000             🧑  作者: Mango
在 TensorFlow 中,tf.repeat()
是一个用于在张量中重复元素的函数。
tf.repeat()
要求传入一个张量和一个整数 repeats
。它会将张量中的每个元素复制 repeats
次,并返回一个新的张量。下面是一个例子:
import tensorflow as tf
x = tf.constant([1, 2, 3])
y = tf.repeat(x, repeats=3)
print(y)
输出:
tf.Tensor([1 1 1 2 2 2 3 3 3], shape=(9,), dtype=int32)
在这个例子中,我们传入了一个张量 [1, 2, 3]
,并将每个元素都复制了 3 次,得到了一个新的张量 [1, 1, 1, 2, 2, 2, 3, 3, 3]
。
我们还可以传入一个整数数组来指定每个元素要重复的次数。例如:
import tensorflow as tf
x = tf.constant([1, 2, 3])
y = tf.repeat(x, repeats=[1, 2, 3])
print(y)
输出:
tf.Tensor([1 2 2 3 3 3], shape=(6,), dtype=int32)
在这个例子中,我们传入了一个整数数组 [1, 2, 3]
,表示要将第一个元素复制 1 次,第二个元素复制 2 次,第三个元素复制 3 次。最终得到的新的张量为 [1, 2, 2, 3, 3, 3]
。
tf.repeat()
是 TensorFlow 中一个用于在张量中重复元素的函数。它非常实用,可以用于各种需要重复元素的场景。