📅  最后修改于: 2023-12-03 14:48:01.322000             🧑  作者: Mango
torch.unsqueeze()
是PyTorch中的一个函数,用于在指定的维度上增加一个维度。它的作用是将维度为 1 的张量扩展为指定维度的张量,使得张量的维度增加了 1。
该函数基于给定的维度索引,在张量上添加一个大小为1的新维度。这对于在神经网络中进行批处理操作和张量形状的调整非常有用。
torch.unsqueeze(input, dim)
参数:
input
(张量):输入的张量。dim
(int):在哪个维度上添加新维度。维度的索引从0开始。返回值:
让我们看几个示例来说明torch.unsqueeze()
的用法。
import torch
# 创建一个3x3的Tensor
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 在第0维度上增加新维度
new_x = torch.unsqueeze(x, 0)
print(new_x)
输出结果:
tensor([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
在上面的示例中,我们创建了一个形状为(3, 3)
的张量x
。通过使用torch.unsqueeze()
函数,在第0维度上添加新维度,我们得到了一个形状为(1, 3, 3)
的新张量new_x
。
我们还可以在其他维度上添加新维度:
import torch
# 创建一个2x2x2的Tensor
x = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
# 在第2维度上增加新维度
new_x = torch.unsqueeze(x, 2)
print(new_x)
输出结果:
tensor([[[[1, 2]],
[[3, 4]]],
[[[5, 6]],
[[7, 8]]]])
在这个示例中,我们创建了一个形状为(2, 2, 2)
的张量x
。通过使用torch.unsqueeze()
函数,在第2维度上添加新维度,我们得到了一个形状为(2, 2, 1, 2)
的新张量new_x
。
torch.unsqueeze()
函数用于在指定的维度上增加一个维度。请阅读官方文档以获取更多详细信息。