📜  Python – PyTorch add() 方法(1)

📅  最后修改于: 2023-12-03 15:19:03.373000             🧑  作者: Mango

Python – PyTorch add() 方法

简介

add()方法是PyTorch中用于将两个张量逐元素相加的函数。它支持张量的广播特性,可以将不同形状的张量相加,同时还可以设置输出张量的形状和数据类型。

语法

torch.add(input, other, alpha=1, out=None) → Tensor

其中,inputother是要相加的两个张量,alpha是一个标量,用于对other进行缩放,out指定输出张量。

示例
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 3, 4])

z1 = torch.add(x, y)
print(z1) # tensor([3, 5, 7])

z2 = torch.add(x, y, alpha=2)
print(z2) # tensor([5, 8, 11])

z3 = torch.zeros(3)
torch.add(x, y, out=z3)
print(z3) # tensor([3., 5., 7.])

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([1, 2])
z4 = torch.add(a, b) # a 和 b 的形状不同,但可以广播相加
print(z4) # tensor([[2, 4], [4, 6]])
注意事项
  • inputother的数据类型必须相同,或者other可以作为标量被广播到input的形状。
  • 当输出张量out被指定时,out必须具有与input相同的形状,否则会报错。
  • 当对浮点类型的张量进行相加运算时,可以指定dtype参数来控制输出的数据类型,例如:out=torch.zeros(3, dtype=torch.float)