如何在 PyTorch 中找到张量的转置?
在本文中,我们将讨论如何在 PyTorch 中找到张量的转置。通过将行更改为列,将列更改为行来获得转置。我们可以使用 transpose() 方法转置张量。以下语法用于查找张量的转置。
Syntax: torch.transpose(input_tens, dim_0, dim_1)
Parameters:
- input_tens : the input tensor that we want to transpose.
- dim_0 : it will use when we want the first dimension to be transposed..
- dim_1 : it will use when we want the second dimension to be transposed.
Return : this method return transpose of input tensor.
示例 1:
下面的程序是为了了解如何找到 2D 张量的转置。
Python
# import torch module
import torch
# Define a 2D tensor
tens = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
# display original tensor
print("\n Original Tensor: \n", tens)
# find transpose
tens_transpose = torch.transpose(tens, 0, 1)
print("\n Tensor After Transpose: \n", tens_transpose)
Python
# import torch module
import torch
# Define a 2D tensor
tens = torch.tensor([[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18]]])
# display original tensor
print("\n Original Tensor: \n", tens)
# find transpose of multi-dimension tensor
tens_transpose = torch.transpose(tens, 0, 1)
# display final result
print("\n Tensor After Transpose: \n", tens_transpose)
输出:
示例 2:
下面的程序是要知道如何求多维张量的转置。
Python
# import torch module
import torch
# Define a 2D tensor
tens = torch.tensor([[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18]]])
# display original tensor
print("\n Original Tensor: \n", tens)
# find transpose of multi-dimension tensor
tens_transpose = torch.transpose(tens, 0, 1)
# display final result
print("\n Tensor After Transpose: \n", tens_transpose)
输出: