📜  Python – tensorflow.GradientTape.watch()(1)

📅  最后修改于: 2023-12-03 15:04:10.853000             🧑  作者: Mango

Python – tensorflow.GradientTape.watch()

tensorflow.GradientTape.watch() 函数用于显式地告诉 TensorFlow 要监视一个张量,以便在后面的梯度计算中使用。

语法
tf.GradientTape.watch(variable)
参数
  • variable:要监视的张量。
返回值

无返回值。

示例代码
import tensorflow as tf

# 定义变量并赋初值
x = tf.constant(3.0)
y = tf.constant(9.0)

# 创建 GradientTape 上下文管理器
with tf.GradientTape() as tape:
  # 监视变量 x
  tape.watch(x)
  # 定义函数
  y_sq = tf.square(y)
  z = x * y_sq

# 计算梯度
dz_dx = tape.gradient(z, x)

print("x = ", x.numpy())
print("y = ", y.numpy())
print("z = ", z.numpy())
print("dz/dx = ", dz_dx.numpy())
输出结果
x = 3.0
y = 9.0
z = 729.0
dz/dx = 243.0
解释

在本例中,我们定义了 $x=3$ 和 $y=9$ 两个常量,并创建了 GradientTape 上下文管理器。然后,我们告诉 TensorFlow 监视变量 $x$,并定义了一个函数 $z=x \times y^2$。当计算梯度 $\frac{dz}{dx}$ 时,TensorFlow 使用 tape 上下文管理器记录了 $z$ 的计算过程,并使用监视到的变量 $x$ 计算了梯度。最后打印出 $z$ 和 $\frac{dz}{dx}$ 的值。

注意:在实际应用中,通常不需要显式调用 GradientTape.watch() 函数,因为 TensorFlow 会自动监视需要计算梯度的变量。该函数通常只在自定义训练循环的情况下使用。