📜  Python PyTorch stack() 方法(1)

📅  最后修改于: 2023-12-03 14:46:03.355000             🧑  作者: Mango

Python PyTorch stack() 方法介绍

PyTorch 的 stack() 方法允许你将 tensor 序列沿着一个新的轴维度(通常是维度 0)堆叠在一起,返回一个更高且更宽的 tensor。这个方法可以被看作是拼接(cat())方法的逆方法。在本文中,我们将介绍 stack() 方法的语法,参数以及示例用法。

语法

stack() 的语法如下:

torch.stack(sequence_of_tensors, dim=0, *, out=None) -> Tensor

其中参数含义如下:

  • sequence_of_tensors: 拼接的 tensor 序列。
  • dim: 新维度的轴位置。默认为 0。
  • out: 可选参数,为输出的 tensor 。
参数说明

在 stack() 方法中,参数 dim 代表新的轴维度的位置。通常,我们将新的维度设置为 0 可以创建一个更高的 tensor 。下面四个图是对于 dim 参数的不同设置所产生的结果:

stack_demo1

stack_demo2

stack_demo3

stack_demo4

示例用法

下面是 stack() 方法的一些示例用法。

示例 1:沿着新维度拼接序列 tensor
import torch

# 拼接 2 个 2×3 的随机 tensor,生成一个 2×2×3 的 tensor
t1 = torch.rand((2, 3))
t2 = torch.rand((2, 3))
t_stack = torch.stack([t1, t2])
print(t_stack.shape)

输出为:

torch.Size([2, 2, 3])
示例 2:在 4 维张量的新维度中拼接序列 tensor
import torch

# 拼接 2 个 3×2×2 的随机 tensor,生成一个 2×3×2×2 的 tensor
t1 = torch.rand((3, 2, 2))
t2 = torch.rand((3, 2, 2))
t_stack = torch.stack([t1, t2], dim=0)
print(t_stack.shape)

输出为:

torch.Size([2, 3, 2, 2])
示例 3:利用 out 参数指定输出 tensor
import torch

# 拼接 2 个 2×3 的随机 tensor,生成一个 2×2×3 的 tensor
t1 = torch.rand((2, 3))
t2 = torch.rand((2, 3))
t_stack = torch.zeros((2, 2, 3))
torch.stack([t1, t2], dim=0, out=t_stack)
print(t_stack.shape)

输出为:

torch.Size([2, 2, 3])