📜  Python PyTorch linalg.svd() 方法(1)

📅  最后修改于: 2023-12-03 14:46:03.334000             🧑  作者: Mango

Python PyTorch linalg.svd() 方法介绍

简介

PyTorch 是一个开源的 Python 机器学习库,其中提供了 linalg.svd() 方法,可以对矩阵进行奇异值分解(SVD)。

SVD 是一种数学变换,可以将一个矩阵分解为三个矩阵的乘积,这三个矩阵包括一个正交矩阵、一个对角矩阵和一个转置矩阵,其中对角矩阵的对角线元素称为奇异值,代表了矩阵的特征值。

linalg.svd() 方法返回矩阵的奇异值和两个正交矩阵,其中一个正交矩阵包含了矩阵的行空间,另一个正交矩阵包含了矩阵的列空间。

用法

linalg.svd() 方法的用法如下:

torch.linalg.svd(input, full_matrices=True, compute_uv=True, out=None)

其中参数说明如下:

  • input: 待分解的矩阵,可以是二维或高维矩阵。
  • full_matrices:是否返回完整的正交矩阵。默认值为 True。如果为 False,则只返回包含奇异值的对角矩阵和其中一个正交矩阵。
  • compute_uv:是否计算正交矩阵。默认值为 True。如果为 False,则只返回包含奇异值的对角矩阵。
  • out:输出结果的矩阵。默认值为 None。

返回结果为一个元组,包含三个张量分别为 U、S 和 V,其中:

  • U:一个张量,包含了矩阵的行空间,shape 为 $(, m, m)$ 或 $(, m, k)$,取决于 full_matrices 参数的值。
  • S:一个张量,包含了矩阵的奇异值,shape 为 $(*, k)$,其中 k 是 min(input.shape)。
  • V:一个张量,包含了矩阵的列空间,shape 为 $(, n, n)$ 或 $(, k, n)$,取决于 full_matrices 参数的值。
示例

下面是对一个 $3 \times 3$ 的矩阵进行奇异值分解的示例代码:

import torch

# 输入矩阵
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

# 奇异值分解
u, s, v = torch.linalg.svd(a)

print('U =', u)
print('S =', s)
print('V =', v)

运行结果如下:

U = tensor([[-0.2148,  0.8872,  0.4082],
        [-0.5206,  0.2490, -0.8165],
        [-0.8264, -0.3892,  0.4082]])
S = tensor([1.6848e+01, 1.0684e+00, 2.0362e-16])
V = tensor([[-0.4797, -0.5724, -0.6652],
        [-0.7760, -0.0757,  0.6246],
        [ 0.4082, -0.8165,  0.4082]])

结果中的 U 和 V 是正交矩阵,S 是对角矩阵,其中对角线上的元素为奇异值。