📜  concat 张量 pytorch - Python (1)

📅  最后修改于: 2023-12-03 14:40:11.200000             🧑  作者: Mango

PyTorch 中的张量连接(concatenate)

在深度学习中,我们经常需要将多个张量连接(concatenate)在一起,从而得到一个更大的张量。PyTorch中使用torch.cat函数来实现张量的连接。本文将介绍如何使用torch.cat函数进行张量连接操作。

基本语法

torch.cat(tensors, dim=0, out=None) -> Tensor

其中:

  • tensors:需要连接的张量列表或元组。
  • dim:连接的维度(默认为0)。
  • out:连接结果输出到这个张量中。
示例
连接一维张量
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])
t4 = torch.tensor([10, 11, 12])

t_cat = torch.cat((t1, t2, t3, t4), dim=0)

print("t_cat = ", t_cat)

输出结果为:

t_cat =  tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
连接二维张量
import torch

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])

t_cat = torch.cat((t1, t2, t3), dim=1)

print("t_cat = ", t_cat)

输出结果为:

t_cat =  tensor([[ 1,  2,  5,  6,  9, 10],
                [ 3,  4,  7,  8, 11, 12]])
结论

本文简单介绍了PyTorch中的张量连接(concatenate)操作。使用torch.cat函数可以实现张量的连接操作,对于深度学习任务中的数据预处理和数据集构建等方面非常有用。