📜  torch.unsqueze - Python (1)

📅  最后修改于: 2023-12-03 15:05:37.016000             🧑  作者: Mango

Torch.Unsqueeze - Python

torch.unsqueeze 是 PyTorch 中的一个函数,它用于在指定位置插入一个维度。可以将一维张量转换为二维张量,也可以将二维张量转换为三维张量。

语法
torch.unsqueeze(input, dim)
参数
  • input - 输入张量
  • dim - 插入维度的位置
返回值

插入维度后的张量。

例子
import torch

# 一维张量
x = torch.tensor([1, 2, 3])
print(x.shape)  # 输出:torch.Size([3])

# 在第1维插入维度,变成二维张量
y = torch.unsqueeze(x, 0)
print(y.shape)  # 输出:torch.Size([1, 3])

# 在第0维插入维度,变成二维张量
z = torch.unsqueeze(x, 1)
print(z.shape)  # 输出:torch.Size([3, 1])

# 二维张量
a = torch.tensor([[1, 2], [3, 4]])
print(a.shape)  # 输出:torch.Size([2, 2])

# 在第0维插入维度,变成三维张量
b = torch.unsqueeze(a, 0)
print(b.shape)  # 输出:torch.Size([1, 2, 2])

# 在第2维插入维度,变成三维张量
c = torch.unsqueeze(a, 2)
print(c.shape)  # 输出:torch.Size([2, 2, 1])

以上例子演示了插入维度前后张量的形状变化。

参考文献