📅  最后修改于: 2023-12-03 15:05:37.016000             🧑  作者: Mango
torch.unsqueeze
是 PyTorch 中的一个函数,它用于在指定位置插入一个维度。可以将一维张量转换为二维张量,也可以将二维张量转换为三维张量。
torch.unsqueeze(input, dim)
input
- 输入张量dim
- 插入维度的位置插入维度后的张量。
import torch
# 一维张量
x = torch.tensor([1, 2, 3])
print(x.shape) # 输出:torch.Size([3])
# 在第1维插入维度,变成二维张量
y = torch.unsqueeze(x, 0)
print(y.shape) # 输出:torch.Size([1, 3])
# 在第0维插入维度,变成二维张量
z = torch.unsqueeze(x, 1)
print(z.shape) # 输出:torch.Size([3, 1])
# 二维张量
a = torch.tensor([[1, 2], [3, 4]])
print(a.shape) # 输出:torch.Size([2, 2])
# 在第0维插入维度,变成三维张量
b = torch.unsqueeze(a, 0)
print(b.shape) # 输出:torch.Size([1, 2, 2])
# 在第2维插入维度,变成三维张量
c = torch.unsqueeze(a, 2)
print(c.shape) # 输出:torch.Size([2, 2, 1])
以上例子演示了插入维度前后张量的形状变化。