Last modified: 2023-12-03 14:46:48.436000 | Author: Mango
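In PyTorch, the `torch.transpose()` function swaps two dimensions of a tensor. Changing the order of the tensor's dimensions lets you look at the same data in a different layout, for example turning rows into columns. Because it swaps exactly two dimensions, reordering several dimensions at once is done with `torch.permute()` instead.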
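Swapping dimensions does not copy or move any data. For a regular (strided) tensor, the result is a view that shares storage with the input; only the strides are swapped. The minimal sketch below uses an arbitrary 2x3 tensor built with `torch.arange` to show this:

```python
import torch

# A contiguous 2x3 tensor: row-major strides are (3, 1)
x = torch.arange(6).reshape(2, 3)
y = torch.transpose(x, 0, 1)

print(x.stride(), y.stride())  # (3, 1) (1, 3): strides are swapped, data is not moved
print(y.is_contiguous())       # False: y is a non-contiguous view of x

# Because y is a view, writing through it also changes x
y[0, 1] = 100
print(x[1, 0])                 # tensor(100)
```

When a later operation needs contiguous memory (for example `Tensor.view()`), calling `y.contiguous()` produces a contiguous copy.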
The syntax for `torch.transpose()` is as follows:

```
torch.transpose(input, dim0, dim1) → Tensor
```
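The same operation is also available as the `Tensor.transpose` method. This small sketch (tensor values chosen only for illustration) checks that the function and method forms agree, that the order of the two dimension arguments does not matter, and that negative indices count from the last dimension:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

a = torch.transpose(x, 0, 1)    # function form
b = x.transpose(0, 1)           # equivalent Tensor method
c = torch.transpose(x, 1, 0)    # a swap is symmetric, so argument order does not matter
d = torch.transpose(x, -2, -1)  # -2 and -1 refer to dimensions 0 and 1 of a 2D tensor

print(torch.equal(a, b), torch.equal(a, c), torch.equal(a, d))  # True True True
```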
- `input` (Tensor): the input tensor.
- `dim0` (int): the first dimension to be swapped.
- `dim1` (int): the second dimension to be swapped.

Let's see how `torch.transpose()` works with an example:
```python
import torch

# Create a 2D tensor
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Transpose the tensor (swap dimensions 0 and 1)
y = torch.transpose(x, 0, 1)
```
The transposed tensor `y` will be:

```
tensor([[1, 4],
        [2, 5],
        [3, 6]])
```
Here, the original tensor `x` has shape (2, 3). After transposing, the shape becomes (3, 2) because dimension 0 is swapped with dimension 1 (dimensions are 0-indexed), so each row of `x` becomes a column of `y`.
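The same rule carries over to tensors with more than two dimensions: only the two named dimensions trade places, and every other dimension keeps its position. A small sketch using an arbitrary (2, 3, 4) tensor:

```python
import torch

# A 3D tensor of shape (2, 3, 4)
x = torch.arange(24).reshape(2, 3, 4)

# Swap dimensions 0 and 2; dimension 1 stays where it is
y = torch.transpose(x, 0, 2)
print(y.shape)  # torch.Size([4, 3, 2])

# Element mapping: y[k, j, i] == x[i, j, k] for every index
for i in range(2):
    for j in range(3):
        for k in range(4):
            assert y[k, j, i] == x[i, j, k]

# For this 3D case the swap matches a full reordering with permute
assert torch.equal(y, x.permute(2, 1, 0))
```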
Hope this helps you understand the PyTorch `torch.transpose()` function!