📜  转置 3d 矩阵 pytorch - Python (1)

📅  最后修改于: 2023-12-03 15:28:15.363000             🧑  作者: Mango

转置 3D 矩阵 PyTorch

在 PyTorch 中,可以通过 transpose() 函数来对张量进行转置操作。对于 3D 张量而言,转置操作会改变张量中的维度顺序。

下面是一个 3D 张量的例子:

import torch

x = torch.randn(2, 3, 4)

print(x)

输出:

tensor([[[-1.3786, -0.9152,  1.0161,  0.5166],
         [-1.4177, -0.6198, -0.0460, -0.9280],
         [-1.5486,  0.1460, -2.1062,  1.3160]],

        [[ 2.2095, -0.9083, -0.6410, -1.8327],
         [ 1.7187, -0.9359, -0.6855,  1.2458],
         [ 0.8891, -1.2784, -0.3897, -0.1762]]])

可以看到,该张量的维度为 (2, 3, 4),其中 2 表示 batch size,3 表示通道数,4 表示像素数。

现在,我们来尝试对该张量进行转置操作:

y = x.transpose(1, 2)

print(y)

输出:

tensor([[[-1.3786, -1.4177, -1.5486],
         [-0.9152, -0.6198,  0.1460],
         [ 1.0161, -0.0460, -2.1062],
         [ 0.5166, -0.9280,  1.3160]],

        [[ 2.2095,  1.7187,  0.8891],
         [-0.9083, -0.9359, -1.2784],
         [-0.6410, -0.6855, -0.3897],
         [-1.8327,  1.2458, -0.1762]]])

可以看到,转置操作改变了张量中的维度顺序,将原来的 (2, 3, 4) 转变为了 (2, 4, 3)

其中,参数 (1, 2) 表示将原来的第二维和第三维进行交换。如果想要对其他维度进行转置,只需要修改对应的参数即可。例如,如果想要将第一维和第三维进行交换,可以这样写:

z = x.transpose(0, 2)

print(z)

输出:

tensor([[[-1.3786,  2.2095],
         [-1.4177,  1.7187],
         [-1.5486,  0.8891]],

        [[ 0.5166, -1.8327],
         [-0.9280,  1.2458],
         [ 1.3160, -0.1762]]])

其中,参数 (0, 2) 表示将原来的第一维和第三维进行交换。