📅  最后修改于: 2023-12-03 15:04:10.853000             🧑  作者: Mango
tensorflow.GradientTape.watch()
函数用于显式地告诉 TensorFlow 要监视一个张量,以便在后面的梯度计算中使用。
tf.GradientTape.watch(variable)
variable
:要监视的张量。无返回值。
import tensorflow as tf
# 定义变量并赋初值
x = tf.constant(3.0)
y = tf.constant(9.0)
# 创建 GradientTape 上下文管理器
with tf.GradientTape() as tape:
# 监视变量 x
tape.watch(x)
# 定义函数
y_sq = tf.square(y)
z = x * y_sq
# 计算梯度
dz_dx = tape.gradient(z, x)
print("x = ", x.numpy())
print("y = ", y.numpy())
print("z = ", z.numpy())
print("dz/dx = ", dz_dx.numpy())
x = 3.0
y = 9.0
z = 729.0
dz/dx = 243.0
在本例中,我们定义了 $x=3$ 和 $y=9$ 两个常量,并创建了 GradientTape 上下文管理器。然后,我们告诉 TensorFlow 监视变量 $x$,并定义了一个函数 $z=x \times y^2$。当计算梯度 $\frac{dz}{dx}$ 时,TensorFlow 使用 tape 上下文管理器记录了 $z$ 的计算过程,并使用监视到的变量 $x$ 计算了梯度。最后打印出 $z$ 和 $\frac{dz}{dx}$ 的值。
注意:在实际应用中,通常不需要显式调用 GradientTape.watch()
函数,因为 TensorFlow 会自动监视需要计算梯度的变量。该函数通常只在自定义训练循环的情况下使用。