📅  最后修改于: 2023-12-03 15:35:21.718000             🧑  作者: Mango
torch.stack
是PyTorch中的一个函数,用于将多个张量沿着新的维度堆叠起来。这个新维度在堆叠的过程中会被创建。我们接下来通过一个示例来展示torch.stack
函数的使用。
假设我们有两个张量t1
和t2
,形状分别为(3,4)
和(3,4)
,我们想要将它们沿着一个新维度axis=0
堆叠起来,得到一个形状为(2,3,4)
的张量。代码如下:
import torch
t1 = torch.rand(3, 4)
t2 = torch.rand(3, 4)
# 在新维度axis=0上堆叠t1和t2
result = torch.stack((t1, t2), dim=0)
print(result.shape) # 输出:torch.Size([2, 3, 4])
以上代码中,我们通过torch.stack
函数将t1
和t2
沿着新维度axis=0
堆叠起来,并打印结果张量的形状。输出结果为torch.Size([2, 3, 4])
,符合我们期望的形状。
torch.stack
函数是PyTorch张量操作中比较常用的一个函数,可以用于将多个张量沿着新的维度堆叠起来。需要注意的是,在堆叠时,保证所有张量的形状相同。如果需要了解更多关于torch.stack
的信息,可以参考PyTorch文档。