📅  最后修改于: 2023-12-03 14:40:11.200000             🧑  作者: Mango
在深度学习中,我们经常需要将多个张量连接(concatenate)在一起,从而得到一个更大的张量。PyTorch中使用torch.cat
函数来实现张量的连接。本文将介绍如何使用torch.cat
函数进行张量连接操作。
torch.cat(tensors, dim=0, out=None) -> Tensor
其中:
tensors
:需要连接的张量列表或元组。dim
:连接的维度(默认为0)。out
:连接结果输出到这个张量中。import torch
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])
t4 = torch.tensor([10, 11, 12])
t_cat = torch.cat((t1, t2, t3, t4), dim=0)
print("t_cat = ", t_cat)
输出结果为:
t_cat = tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
import torch
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])
t_cat = torch.cat((t1, t2, t3), dim=1)
print("t_cat = ", t_cat)
输出结果为:
t_cat = tensor([[ 1, 2, 5, 6, 9, 10],
[ 3, 4, 7, 8, 11, 12]])
本文简单介绍了PyTorch中的张量连接(concatenate)操作。使用torch.cat
函数可以实现张量的连接操作,对于深度学习任务中的数据预处理和数据集构建等方面非常有用。