📅  最后修改于: 2023-12-03 14:51:06.965000             🧑  作者: Mango
在 Pytorch 中,我们经常需要处理张量的形状以符合我们的需求。重塑张量(也称为变形或重新排列)是实现这一目的的一种方法。
我们可以使用 view()
方法来重塑张量的形状。
语法如下:
new_tensor = tensor.view(shape)
其中,tensor
是我们要重塑的张量,shape
是一个元组,表示新张量的形状。
import torch
# 创建一个 2 行 3 列的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 重塑成 3 行 2 列的张量
new_tensor = tensor.view((3, 2))
print("原张量:")
print(tensor)
print("重塑后的张量:")
print(new_tensor)
输出:
原张量:
tensor([[1, 2, 3],
[4, 5, 6]])
重塑后的张量:
tensor([[1, 2],
[3, 4],
[5, 6]])
在使用 view()
时,我们需要考虑输入张量和输出张量的元素数量是否相同。如果不同,将会抛出一个错误。这意味着我们不能随意地重塑张量的形状,而是需要根据实际需求进行计算。
另外,还有一种可选方法 reshape()
,它与 view()
的作用相同。不同之处在于,reshape()
方法会始终返回一个新的张量,而 view()
方法仅当原始张量和新张量的大小相同时才不会创建新的张量。因此,如果我们知道新张量与旧张量的元素数量不同,最好使用 reshape()
方法来避免潜在的错误。