📅  最后修改于: 2023-12-03 14:46:03.355000             🧑  作者: Mango
PyTorch 的 stack() 方法允许你将 tensor 序列沿着一个新的轴维度(通常是维度 0)堆叠在一起,返回一个更高且更宽的 tensor。这个方法可以被看作是拼接(cat())方法的逆方法。在本文中,我们将介绍 stack() 方法的语法,参数以及示例用法。
stack() 的语法如下:
torch.stack(sequence_of_tensors, dim=0, *, out=None) -> Tensor
其中参数含义如下:
在 stack() 方法中,参数 dim 代表新的轴维度的位置。通常,我们将新的维度设置为 0 可以创建一个更高的 tensor 。下面四个图是对于 dim 参数的不同设置所产生的结果:
下面是 stack() 方法的一些示例用法。
import torch
# 拼接 2 个 2×3 的随机 tensor,生成一个 2×2×3 的 tensor
t1 = torch.rand((2, 3))
t2 = torch.rand((2, 3))
t_stack = torch.stack([t1, t2])
print(t_stack.shape)
输出为:
torch.Size([2, 2, 3])
import torch
# 拼接 2 个 3×2×2 的随机 tensor,生成一个 2×3×2×2 的 tensor
t1 = torch.rand((3, 2, 2))
t2 = torch.rand((3, 2, 2))
t_stack = torch.stack([t1, t2], dim=0)
print(t_stack.shape)
输出为:
torch.Size([2, 3, 2, 2])
import torch
# 拼接 2 个 2×3 的随机 tensor,生成一个 2×2×3 的 tensor
t1 = torch.rand((2, 3))
t2 = torch.rand((2, 3))
t_stack = torch.zeros((2, 2, 3))
torch.stack([t1, t2], dim=0, out=t_stack)
print(t_stack.shape)
输出为:
torch.Size([2, 2, 3])