📜  torch.stack 示例 - Python (1)

📅  最后修改于: 2023-12-03 15:35:21.718000             🧑  作者: Mango

torch.stack示例 - Python

torch.stack是PyTorch中的一个函数,用于将多个张量沿着新的维度堆叠起来。这个新维度在堆叠的过程中会被创建。我们接下来通过一个示例来展示torch.stack函数的使用。

示例

假设我们有两个张量t1t2,形状分别为(3,4)(3,4),我们想要将它们沿着一个新维度axis=0堆叠起来,得到一个形状为(2,3,4)的张量。代码如下:

import torch

t1 = torch.rand(3, 4)
t2 = torch.rand(3, 4)

# 在新维度axis=0上堆叠t1和t2
result = torch.stack((t1, t2), dim=0) 

print(result.shape)  # 输出:torch.Size([2, 3, 4])

以上代码中,我们通过torch.stack函数将t1t2沿着新维度axis=0堆叠起来,并打印结果张量的形状。输出结果为torch.Size([2, 3, 4]),符合我们期望的形状。

总结

torch.stack函数是PyTorch张量操作中比较常用的一个函数,可以用于将多个张量沿着新的维度堆叠起来。需要注意的是,在堆叠时,保证所有张量的形状相同。如果需要了解更多关于torch.stack的信息,可以参考PyTorch文档。