📅  最后修改于: 2023-12-03 15:24:20.551000             🧑  作者: Mango
在 PyTorch 中,我们可以使用 view()
方法来调整张量的大小,其类似于 NumPy 中的 reshape()
方法。
view()
方法import torch
# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 调整为 3x2 的张量
new_tensor = tensor.view(3, 2)
print(new_tensor)
输出:
tensor([[1, 2],
[3, 4],
[5, 6]])
reshape()
方法如果你熟悉 NumPy 的用法,可以使用 reshape()
方法来调整张量的大小。
import torch
# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 调整为 3x2 的张量
new_tensor = tensor.reshape(3, 2)
print(new_tensor)
输出:
tensor([[1, 2],
[3, 4],
[5, 6]])
需要注意的是,view()
和 reshape()
方法都是返回一个新的张量,原始张量并没有发生改变。
unsqueeze()
和 squeeze()
方法如果我们想在某一个维度上增加或者减少维度,可以使用 unsqueeze()
和 squeeze()
方法。
import torch
# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 增加一个维度
new_tensor = tensor.unsqueeze(0)
print(new_tensor.shape)
# 减少一个维度
new_tensor = new_tensor.squeeze(0)
print(new_tensor.shape)
输出:
torch.Size([1, 2, 3])
torch.Size([2, 3])
我们使用 unsqueeze(0)
在第 0 维上增加了一个维度,然后再使用 squeeze(0)
将这个维度去除掉。
permute()
方法如果我们想交换张量的维度顺序,可以使用 permute()
方法来调整张量的维度顺序。
import torch
# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 调整维度顺序
new_tensor = tensor.permute(1, 0)
print(new_tensor)
输出:
tensor([[1, 4],
[2, 5],
[3, 6]])
我们使用 permute(1, 0)
来交换张量的维度顺序。
以上就是在 PyTorch 中调整张量大小的几种方法。