📜  pytorch 张量 argmax - Python (1)

📅  最后修改于: 2023-12-03 15:34:33.043000             🧑  作者: Mango

PyTorch张量argmax

PyTorch张量argmax函数用于查找张量中最大值的索引。

语法
torch.argmax(input, dim=None, keepdim=False)
参数
  • input:一个张量(Tensor)。
  • dim:指定要沿着哪个轴进行argmax操作,如果为None,则返回整个张量中的最大值索引。默认值为None。
  • keepdim:如果为True,则保持输出张量的维度与输入张量相同,否则缩减指定轴的维度。默认为False。
返回值

返回一个张量(Tensor),包含输入张量中每个非降维度上的最大值索引。

例子
例子1:沿着列方向查找最大值索引
import torch

x = torch.randn(4, 3)
print(x)

# 找到每列的最大值以及对应的索引
values, indices = torch.max(x, dim=0)
print(values)
print(indices)

输出:

tensor([[-0.3845, -0.0809, -0.0500],
        [-0.3044,  1.3229, -0.3334],
        [-0.8498, -0.3789,  0.5974],
        [-0.1542,  0.8215, -1.0364]])
tensor([-0.1542,  1.3229,  0.5974])
tensor([3, 1, 2])
例子2:查找整个张量中的最大值索引
import torch

x = torch.randn(4, 3)
print(x)

# 找到整个张量中的最大值以及对应的索引
value, index = torch.max(x, dim=None)
print(value)
print(index)

输出:

tensor([[-0.2517, -0.2765,  1.3046],
        [ 0.7898,  0.0756,  0.1174],
        [-0.3797,  0.7743, -1.4672],
        [-1.1199, -0.7539, -0.1455]])
tensor(1.3046)
tensor(2)
例子3:保持维度不变查找最大值索引
import torch

x = torch.randn(4, 3)
print(x)

# 找到每行的最大值以及对应的索引,并且保持维度不变
values, indices = torch.max(x, dim=1, keepdim=True)
print(values)
print(indices)

输出:

tensor([[-1.0414, -1.6039,  0.7290],
        [-0.1695,  1.2958,  1.0386],
        [ 0.9517,  0.6935,  0.1409],
        [-1.0185,  0.3565,  1.4401]])
tensor([[-1.0414],
        [ 1.2958],
        [ 0.9517],
        [ 1.4401]])
tensor([[0],
        [1],
        [0],
        [2]])
结论

PyTorch张量argmax函数提供了一种查找张量中最大值索引的方法,方便数据分析和处理。可以根据实际需求指定轴和保持维度,灵活应用。