📜  ctx.save_for_backward - Python (1)

📅  最后修改于: 2023-12-03 15:30:13.182000             🧑  作者: Mango

Python中的ctx.save_for_backward

ctx.save_for_backward是在pytorch中使用的一种机制,用于实现自动微分。在深度学习中,我们需要对模型的参数进行反向传播来更新他们(也就是模型的训练),而存储这些参数在前向传播和反向传播中间的计算结果显然是需要的。如果我们需要进行更高阶的自动微分,那么这些计算结果也是需要的。

在pytorch中,ctx.save_for_backward一般用在一个自定义函数的实现内部。这个自定义函数可以被使用在一个pytorch的计算图中。

下面是一个例子,展示了如何在一个简单的前向传播中使用ctx.save_for_backward:

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6

for t in range(500):
    relu = MyReLU.apply
    y_pred = relu(x.mm(w1)).mm(w2)

    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    loss.backward()

    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        w1.grad.zero_()
        w2.grad.zero_()

这个例子中,我们定义了一个自定义函数MyReLU,这个函数实现了ReLU激活函数。MyReLU在前向传播中调用ctx.save_for_backward计算结果并将其保存,以供变量在反向传播中进行微分更新。

我们在这个例子中使用x, y, w1和w2作为张量,并调用了一个简单的优化算法(SGD)来进行训练。在每个迭代中,我们都会实现前向传播,计算损失、执行后向传播,并用优化算法更新权重。

总结

ctx.save_for_backward在pytorch中的自动微分功能中很重要,它使我们能够在前向传播和反向传播期间存储张量,并在需要时使用它们来计算梯度。它是自定义函数实现自动微分算法的重要机制。