如何在 Pytorch 中切片 3D 张量?
在本文中,我们将讨论如何在 Pytorch 中对 3D 张量进行切片。
让我们创建一个 3D Tensor 进行演示。我们可以使用 torch.tensor()函数创建一个向量
Syntax: torch.tensor([value1,value2,.value n])
代码:
Python3
# import torch module
import torch
# create an 3 D tensor with 8 elements each
a = torch.tensor([[[1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]])
# display actual tensor
print(a)
Python3
# access all the tensors of 1
# dimension and get only 7 values
# in that dimension
print(a[0:1, 0:1, :7])
Python3
# access all the tensors of all
# dimensions and get only 3 values
# in each dimension
print(a[0:1, 0:2, :3])
Python3
# access 8 elements in 1 dimension
# on all tensors
print(a[0:2, 1, 0:8])
输出:
tensor([[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]])
切片 3D 张量
切片:切片是指使用“:”切片运算符选择张量中存在的元素。我们可以使用该特定元素的索引对元素进行切片。
注意:索引从 0 开始
Syntax: tensor[tensor_position_start:tensor_position_end, tensor_dimension_start:tensor_dimension_end , tensor_value_start:tensor_value_end]
Parameters:
- tensor_position_start: Specifies the Tensor to start iterating
- tensor_position_end: Specifies the Tensor to stop iterating
- tensor_dimension_start: Specifies the Tensor to start the iteration of tensor in given positions
- tensor_dimension_stop: Specifies the Tensor to stop the iteration of tensor in given positions
- tensor_value_start: Specifies the start position of the tensor to iterate the elements given in dimensions
- tensor_value_stop: Specifies the end position of the tensor to iterate the elements given in dimensions
示例 1: Python代码访问 1 维的所有张量并仅获得该维的 7 个值
蟒蛇3
# access all the tensors of 1
# dimension and get only 7 values
# in that dimension
print(a[0:1, 0:1, :7])
输出:
tensor([[[1, 2, 3, 4, 5, 6, 7]]])
示例2: Python代码访问所有维度的所有张量,每个维度只获取3个值
蟒蛇3
# access all the tensors of all
# dimensions and get only 3 values
# in each dimension
print(a[0:1, 0:2, :3])
输出:
tensor([[[ 1, 2, 3],
[10, 11, 12]]])
示例 3:在所有张量上访问 1 维中的 8 个元素
蟒蛇3
# access 8 elements in 1 dimension
# on all tensors
print(a[0:2, 1, 0:8])
输出:
tensor([[10, 11, 12, 13, 14, 15, 16, 17],
[81, 82, 83, 84, 85, 86, 87, 88]])