📅  最后修改于: 2023-12-03 15:35:16.803000             🧑  作者: Mango
Tensor.expand_as()
是 PyTorch 的一个 Tensor 扩展方法,用于将一个 Tensor 沿着维度进行重复,使其形状和另一个指定的 Tensor 形状相同。
torch.Tensor.expand_as(other)
other
(Tensor):用于确定新形状的 Tensor。一个新的 Tensor,其形状与 other
形状相同,但数据按重复原始 Tensor 得到。
import torch
# 创建一个形状为(2, 1, 3)的Tensor
x = torch.randn(2, 1, 3)
print(x.shape)
# torch.Size([2, 1, 3])
# 创建一个形状为(2, 4, 3)的Tensor
y = torch.randn(2, 4, 3)
print(y.shape)
# torch.Size([2, 4, 3])
# 将x沿着第二个维度重复4遍,并输出形状
z = x.expand_as(y)
print(z.shape)
# torch.Size([2, 4, 3])
在上面的示例中,我们创建了两个不同形状的 Tensor x 和 y。然后,我们使用 expand_as
将 x 沿着第二个维度重复了4遍,得到形状与 y 相同的 Tensor z。
Tensor.expand_as
是一个方便的 Tensor 扩展方法,可以轻松地将一个 Tensor 沿着维度重复,以匹配另一个指定的 Tensor 的形状。这个方法是 PyTorch 中强大而灵活的 Tensor 操作之一,通常用于模型训练和数据处理中。