📜  torch.stack - Python (1)

📅  最后修改于: 2023-12-03 15:35:21.683000             🧑  作者: Mango

torch.stack - Python

简介

torch.stack()是PyTorch的一个函数,可以在新的维度上堆叠序列(tensors)。

语法
torch.stack(sequence, dim=0, *, out=None) -> Tensor

参数解释:

  • sequence: 序列或者张量列表或元组。
  • dim: 指明新维度的维度号,默认为0,表示在原有序列的第一个位置添加新维度。
  • out: 可选的输出张量
示例
import torch

a = torch.randn(3,4)
b = torch.randn(3,4)
c = torch.randn(3,4)

# 在新维度0上堆叠
stacked = torch.stack([a,b,c], dim=0)

print(stacked.shape) # 输出 torch.Size([3, 3, 4])

在此示例中,我们使用了PyTorch的torch.randn()函数生成了3个3×4的可变张量(tensors)a、b和c。然后使用torch.stack()函数在新维度0上堆叠了这3个张量,得到了一个3×3×4的张量,其中第一个维度表示原有的3个张量,后两个维度为原有的张量维度。

注意事项
  • 序列中的所有张量必须具有相同的shape(形状)。
  • 存储在新维度中的张量的顺序与序列的顺序相同。
结论

torch.stack()函数是PyTorch的一个重要功能,可以在处理张量时经常用到,特别是需要将多个张量合并成新维度的情况下。不仅可以简化代码逻辑,还能提高代码的效率和可读性。