📜  torch.unsqueeze - Python (1)

📅  最后修改于: 2023-12-03 14:48:01.322000             🧑  作者: Mango

Torch.unsqueeze - Python

介绍

torch.unsqueeze()是PyTorch中的一个函数,用于在指定的维度上增加一个维度。它的作用是将维度为 1 的张量扩展为指定维度的张量,使得张量的维度增加了 1。

该函数基于给定的维度索引,在张量上添加一个大小为1的新维度。这对于在神经网络中进行批处理操作和张量形状的调整非常有用。

语法
torch.unsqueeze(input, dim)

参数

  • input(张量):输入的张量。
  • dim(int):在哪个维度上添加新维度。维度的索引从0开始。

返回值

  • 返回一个在指定维度上增加了新维度的张量。
示例

让我们看几个示例来说明torch.unsqueeze()的用法。

import torch

# 创建一个3x3的Tensor
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 在第0维度上增加新维度
new_x = torch.unsqueeze(x, 0)
print(new_x)

输出结果:

tensor([[[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]])

在上面的示例中,我们创建了一个形状为(3, 3)的张量x。通过使用torch.unsqueeze()函数,在第0维度上添加新维度,我们得到了一个形状为(1, 3, 3)的新张量new_x

我们还可以在其他维度上添加新维度:

import torch

# 创建一个2x2x2的Tensor
x = torch.tensor([[[1, 2], 
                   [3, 4]], 
                  
                  [[5, 6], 
                   [7, 8]]])

# 在第2维度上增加新维度
new_x = torch.unsqueeze(x, 2)
print(new_x)

输出结果:

tensor([[[[1, 2]],

         [[3, 4]]],


        [[[5, 6]],

         [[7, 8]]]])

在这个示例中,我们创建了一个形状为(2, 2, 2)的张量x。通过使用torch.unsqueeze()函数,在第2维度上添加新维度,我们得到了一个形状为(2, 2, 1, 2)的新张量new_x

总结
  • torch.unsqueeze()函数用于在指定的维度上增加一个维度。
  • 它可以用于调整张量形状,特别对于批处理操作和神经网络中的输入调整非常有用。
  • 通过指定要添加新维度的索引,我们可以在张量的任意位置添加新的维度。

请阅读官方文档以获取更多详细信息。