📅  最后修改于: 2023-12-03 15:34:03.921000             🧑  作者: Mango
eye()
是 PyTorch 提供的方法之一,用于创建一个二维矩阵,并将其对角线上的元素赋值为 1,其余元素为 0。该方法的语法如下:
torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
其中,参数 n
表示生成矩阵的行数,参数 m
表示生成矩阵的列数,如果 m
未指定,则默认与 n
相同,参数 dtype
表示生成矩阵的数据类型,默认为 torch.float32
。
import torch
# 创建一个 3 行 3 列的单位矩阵
out = torch.eye(3)
print(out)
以上代码输出结果如下:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
在使用 eye()
方法时,可以通过指定 out
参数来使用一个已有的 Tensor
对象来接收生成的矩阵。
import torch
# 创建一个 3 行 3 列的单位矩阵,使用一个已有的 Tensor 对象接收生成的矩阵
out = torch.empty(3, 3)
torch.eye(3, out=out)
print(out)
以上代码输出结果如下:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
通过指定 dtype
参数,可以生成指定类型的矩阵。
import torch
# 创建一个 3 行 3 列的单位矩阵,数据类型为 torch.int64
out = torch.eye(3, dtype=torch.int64)
print(out)
以上代码输出结果如下:
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
通过本文的介绍,我们学习了 PyTorch 的 eye()
方法,了解了其语法和用法,并且通过举例展示了如何使用该方法,帮助程序员更好地使用 PyTorch 进行矩阵操作。