📅  最后修改于: 2023-12-03 15:35:21.683000             🧑  作者: Mango
torch.stack()
是PyTorch的一个函数,可以在新的维度上堆叠序列(tensors)。
torch.stack(sequence, dim=0, *, out=None) -> Tensor
参数解释:
sequence
: 序列或者张量列表或元组。dim
: 指明新维度的维度号,默认为0,表示在原有序列的第一个位置添加新维度。out
: 可选的输出张量import torch
a = torch.randn(3,4)
b = torch.randn(3,4)
c = torch.randn(3,4)
# 在新维度0上堆叠
stacked = torch.stack([a,b,c], dim=0)
print(stacked.shape) # 输出 torch.Size([3, 3, 4])
在此示例中,我们使用了PyTorch的torch.randn()
函数生成了3个3×4的可变张量(tensors)a、b和c。然后使用torch.stack()
函数在新维度0上堆叠了这3个张量,得到了一个3×3×4的张量,其中第一个维度表示原有的3个张量,后两个维度为原有的张量维度。
torch.stack()
函数是PyTorch的一个重要功能,可以在处理张量时经常用到,特别是需要将多个张量合并成新维度的情况下。不仅可以简化代码逻辑,还能提高代码的效率和可读性。