📜  如何在 PyTorch 中找到张量的转置?(1)

📅  最后修改于: 2023-12-03 15:38:24.535000             🧑  作者: Mango

如何在 PyTorch 中找到张量的转置?

在 PyTorch 中,可以使用 torch.transpose 函数来实现张量的转置。

语法

torch.transpose(input, dim0, dim1) -> Tensor

  • input:需要进行转置的张量。
  • dim0:转置前第一个维度。
  • dim1:转置前第二个维度。

该函数将张量 inputdim0 维和 dim1 维进行交换,返回交换后的张量。

示例

下面的示例演示了如何使用 torch.transpose 函数来实现张量的转置。

import torch

# 创建一个 2x3 的张量
x = torch.Tensor([[1, 2, 3], [4, 5, 6]])
print("x:")
print(x)

# 对 x 进行转置
y = torch.transpose(x, 0, 1)
print("y:")
print(y)

输出:

x:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
y:
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])

在上面的示例中,我们创建了一个 2x3 的张量 x,然后使用 torch.transpose 函数将其转置。转置后,第一维和第二维被交换,变成了一个 3x2 的张量 y

总结

torch.transpose 函数可以用来实现张量的转置,可以指定需要交换的维度。使用该函数可以避免手动交换张量的维度,从而简化代码实现。