📅  最后修改于: 2023-12-03 15:34:06.795000             🧑  作者: Mango
在 TensorFlow 中,GradientTape API 是前向传递和反向传递中最强大的 API 之一。在这个 API 中,tf.GradientTape 类被用来自动计算梯度。tf.GradientTape API 是 TensorFlow 2.0 中的一个新功能。
一般来说 GradientTape 不会生成有效的计算图。通常情况下,在前向传递期间,GradientTape 监视所有操作,以进行自动微分计算。但是默认情况下,GradientTape 在执行前向传递时不跟踪某些操作,如在执行计算时跳出到 Python 等。如果我们希望 GradientTape 监控这些操作,则需要使用 tf.GradientTape.start_recording()
方法启用跟踪。
相对应的 tf.GradientTape.stop_recording()
方法用于停止跟踪这些操作。
以下是 tf.GradientTape.stop_recording()
的语法:
stop_recording()
tf.GradientTape.stop_recording()
方法不需要参数。
tf.GradientTape.stop_recording()
方法没有返回值;
下面的例子将演示如何使用 tf.GradientTape.start_recording()
和 tf.GradientTape.stop_recording()
开启关闭 tape 的记录跟踪功能。
import tensorflow as tf
x = tf.constant(2.0)
y = tf.constant(3.0)
with tf.GradientTape() as tape:
tape.watch([x,y])
z = x * y
tape.stop_recording() # stop tracking x and y here to make them constant
k = x + y
dz_dx, dz_dy = tape.gradient(z, [x, y])
dk_dx, dk_dy = tape.gradient(k, [x, y])
print(dz_dx.numpy()) # Output: 3.0
print(dz_dy.numpy()) # Output: 2.0
print(dk_dx) # Output: None
print(dk_dy) # Output: None
在上面的例子中,GradientTape 在跟踪 x 和 y 下的 'z = x * y' 操作,然后使用 tf.GradientTape.stop_recording()
停止跟踪 x 和 y 在 'x + y' 操作中。接着使用 tape.gradient() 计算 z 和 k 对 x 和 y 的偏导数,我们可以看到 tape.gradient() Yields None for Constant Input x + y。
tf.GradientTape.stop_recording()
方法需要相对应的 tf.GradientTape.start_recording()
或者处于 tf.GradientTape()
上下文范围之内才能使用。
用于记录的 tf.GradientTape 上的操作是动态生成的,它们只有在每次前向传递时才被执行。
tf.GradientTape.stop_recording()
可以不使用而快速抛开计算图构建,可加快速度。如上例所示,我们只对 * 操作进行了跟踪计算梯度,加快了计算速度。
tf.GradientTape.stop_recording()
可以在启用Monitoring 或不启用Monitoring 的情况下使用。它不影响 tf.GradientTape.gradient()
的结果。