📅  最后修改于: 2023-12-03 15:38:24.535000             🧑  作者: Mango
在 PyTorch 中,可以使用 torch.transpose
函数来实现张量的转置。
torch.transpose(input, dim0, dim1) -> Tensor
input
:需要进行转置的张量。dim0
:转置前第一个维度。dim1
:转置前第二个维度。该函数将张量 input
的 dim0
维和 dim1
维进行交换,返回交换后的张量。
下面的示例演示了如何使用 torch.transpose
函数来实现张量的转置。
import torch
# 创建一个 2x3 的张量
x = torch.Tensor([[1, 2, 3], [4, 5, 6]])
print("x:")
print(x)
# 对 x 进行转置
y = torch.transpose(x, 0, 1)
print("y:")
print(y)
输出:
x:
tensor([[1., 2., 3.],
[4., 5., 6.]])
y:
tensor([[1., 4.],
[2., 5.],
[3., 6.]])
在上面的示例中,我们创建了一个 2x3 的张量 x
,然后使用 torch.transpose
函数将其转置。转置后,第一维和第二维被交换,变成了一个 3x2 的张量 y
。
torch.transpose
函数可以用来实现张量的转置,可以指定需要交换的维度。使用该函数可以避免手动交换张量的维度,从而简化代码实现。