📜  如何在 PyTorch 中计算 Hessian

📅  最后修改于: 2022-05-13 01:54:37.105000             🧑  作者: Mango

如何在 PyTorch 中计算 Hessian

Hessian 矩阵或 Hessian 是函数的二阶偏导数的方阵。该函数必须是标量值函数。标量值函数是一个接受一个或多个值并返回单个值的函数。例如f(x,y) = xy^3-7x    是一个标量函数,因为它接受两个值 x 和 y 但返回一个值(计算值xy^3-10x  )。

在 PyTorch 中计算 Hessian

在 PyTorch 中计算标量值函数的 Hessian。

标量值()函数:

示例 1:

在此示例中,我们使用单个变量的标量值函数(单变量函数)。我们为具有单个元素的输入张量以及具有多个元素的输入张量计算此函数的 hessian。看看 hessian 如何为相同的函数寻找这些输入。标量值函数是为单个变量定义的。输入是一个张量,并注意到 Hessian 也是一个张量。当输入张量具有单个元素时,hessian 是具有单个元素的张量。当输入张量具有三个元素时,hessian 是大小为 [3, 3] 的张量。同样,大小为 [2, 2] 的输入张量的 hessian 是大小为 [2,2,2,2] 的张量。

Python3
# Python program to compute Hessian in PyTorch
# importing libraries
import torch
from torch.autograd.functional import hessian
  
# defining a function
def func(x):
    return (2*x.pow(3) - x.pow(2)).sum()
  
# defining the input tensor
input = torch.tensor([3.])
print("Input:\n", input)
  
# computing the hessian
output = hessian(func, input)
  
# printing the above computed tensor
print("Hessian:\n", output)
  
# .....New input
input = torch.tensor([2., 3., 4.])
print("Input:\n", input)
  
# computing the hessian
output = hessian(func, input)
  
# printing the above computed tensor
print("Hessian:\n", output)
  
# .....New input
input = torch.tensor([[2., 3.], [4., 7]])
print("Input:\n", input)
  
# computing the hessian
output = hessian(func, input)
  
# printing the above computed tensor
print("Hessian:\n", output)


Python3
# Python program to compute Hessian in PyTorch
# importing libraries
import torch
from torch.autograd.functional import hessian
  
# defining a function
def func(x, y):
    return (2*x*y.pow(2) + x.pow(3) - 10).sum()
  
# defining the inputs
input_x = torch.tensor([2.])
input_y = torch.tensor([-3.])
inputs = (input_x, input_y)
print("inputs:\n", inputs)
  
# compute the hessian
output = hessian(func, inputs)
  
# printing the above computed hessian
print("Hessian:\n", output)


Python3
# Python program to compute Hessian in PyTorch
# importing libraries
import torch
from torch.autograd.functional import hessian
  
# defining a function
def func(x, y, z):
    return (2*x.pow(2)*y + x*z.pow(3) - 10).sum()
  
# defining the inputs
input_x = torch.tensor([1.])
input_y = torch.tensor([2.])
input_z = torch.tensor([3.])
  
#inputs = (input_x, input_y, input_z)
  
# compute the hessian
output = hessian(func, (input_x, input_y, input_z))
  
# printing the above computed hessian
print("Hessian Tensor:\n", output)


输出:

Input:
 tensor([3.])
Hessian:
 tensor([[34.]])
Input:
 tensor([2., 3., 4.])
Hessian:
 tensor([[22.,  0.,  0.],
        [ 0., 34.,  0.],
        [ 0.,  0., 46.]])
Input:
 tensor([[2., 3.],
        [4., 7.]])
Hessian:
 tensor([[[[22.,  0.],
          [ 0.,  0.]],

         [[ 0., 34.],
          [ 0.,  0.]]],


        [[[ 0.,  0.],
          [46.,  0.]],

         [[ 0.,  0.],
          [ 0., 82.]]]])

示例 2:

在下面的示例中,我们定义了两个变量的标量值函数(二元函数)。我们输入两个张量的元组。标量值函数是为两个变量定义的。输入是两个张量的元组,并注意到输出(粗麻布)是张量元组的元组。每个内部元组都有两个元素(张量)。这里 Hessian[i][j] 包含第 i 个输入和第 j 个输入的 Hessian。

Python3

# Python program to compute Hessian in PyTorch
# importing libraries
import torch
from torch.autograd.functional import hessian
  
# defining a function
def func(x, y):
    return (2*x*y.pow(2) + x.pow(3) - 10).sum()
  
# defining the inputs
input_x = torch.tensor([2.])
input_y = torch.tensor([-3.])
inputs = (input_x, input_y)
print("inputs:\n", inputs)
  
# compute the hessian
output = hessian(func, inputs)
  
# printing the above computed hessian
print("Hessian:\n", output)

输出:

inputs:
 (tensor([2.]), tensor([-3.]))
Hessian:
 ((tensor([[12.]]), tensor([[-12.]])), 
 (tensor([[-12.]]), tensor([[8.]])))

示例 3:

在下面的示例中,我们定义了三个变量的标量值函数。我们输入三个张量的元组。标量值函数是为三个变量定义的。输入是三个张量的元组,请注意输出(粗麻布)是张量元组的元组。这里 Hessian[i][j] 包含第 i 个输入和第 j 个输入的 Hessian。

Python3

# Python program to compute Hessian in PyTorch
# importing libraries
import torch
from torch.autograd.functional import hessian
  
# defining a function
def func(x, y, z):
    return (2*x.pow(2)*y + x*z.pow(3) - 10).sum()
  
# defining the inputs
input_x = torch.tensor([1.])
input_y = torch.tensor([2.])
input_z = torch.tensor([3.])
  
#inputs = (input_x, input_y, input_z)
  
# compute the hessian
output = hessian(func, (input_x, input_y, input_z))
  
# printing the above computed hessian
print("Hessian Tensor:\n", output)

输出:

Hessian Tensor:
 ((tensor([[8.]]), tensor([[4.]]), tensor([[27.]])), 
 (tensor([[4.]]), tensor([[0.]]), tensor([[0.]])), 
 (tensor([[27.]]), tensor([[0.]]), tensor([[18.]])))