📅  最后修改于: 2023-12-03 15:28:15.363000             🧑  作者: Mango
在 PyTorch 中,可以通过 transpose()
函数来对张量进行转置操作。对于 3D 张量而言,转置操作会改变张量中的维度顺序。
下面是一个 3D 张量的例子:
import torch
x = torch.randn(2, 3, 4)
print(x)
输出:
tensor([[[-1.3786, -0.9152, 1.0161, 0.5166],
[-1.4177, -0.6198, -0.0460, -0.9280],
[-1.5486, 0.1460, -2.1062, 1.3160]],
[[ 2.2095, -0.9083, -0.6410, -1.8327],
[ 1.7187, -0.9359, -0.6855, 1.2458],
[ 0.8891, -1.2784, -0.3897, -0.1762]]])
可以看到,该张量的维度为 (2, 3, 4)
,其中 2
表示 batch size,3
表示通道数,4
表示像素数。
现在,我们来尝试对该张量进行转置操作:
y = x.transpose(1, 2)
print(y)
输出:
tensor([[[-1.3786, -1.4177, -1.5486],
[-0.9152, -0.6198, 0.1460],
[ 1.0161, -0.0460, -2.1062],
[ 0.5166, -0.9280, 1.3160]],
[[ 2.2095, 1.7187, 0.8891],
[-0.9083, -0.9359, -1.2784],
[-0.6410, -0.6855, -0.3897],
[-1.8327, 1.2458, -0.1762]]])
可以看到,转置操作改变了张量中的维度顺序,将原来的 (2, 3, 4)
转变为了 (2, 4, 3)
。
其中,参数 (1, 2)
表示将原来的第二维和第三维进行交换。如果想要对其他维度进行转置,只需要修改对应的参数即可。例如,如果想要将第一维和第三维进行交换,可以这样写:
z = x.transpose(0, 2)
print(z)
输出:
tensor([[[-1.3786, 2.2095],
[-1.4177, 1.7187],
[-1.5486, 0.8891]],
[[ 0.5166, -1.8327],
[-0.9280, 1.2458],
[ 1.3160, -0.1762]]])
其中,参数 (0, 2)
表示将原来的第一维和第三维进行交换。