📜  torch.max - Python (1)

📅  最后修改于: 2023-12-03 15:05:36.988000             🧑  作者: Mango

torch.max - Python

简介

torch.max() 是 PyTorch 库中的一个函数,它可以用来计算指定张量中的最大值,并返回其值及其索引。

该函数可以同时在整个张量上计算最大值,也可以仅在张量的某些维度上进行计算。

API
函数签名
torch.max(input) -> Tensor

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
参数说明
  • input:输入张量;
  • dim:指定维度,用来计算最大值;
  • keepdim:是否保持张量维度不变;
  • out:输出张量;
返回值
  • 当只有一个参数时,返回输入张量 input 中的最大值,即 scalar 类型的数值;
  • 当有多个参数时,返回由两个元素组成的元组,第一个元素为输入张量 input 中的最大值 Tensor,第二个元素为最大值在 input 张量上的索引 LongTensor
示例
例 1
>>> import torch
>>> a = torch.randn(3, 4)
>>> a
tensor([[ 0.1841,  0.3021,  0.8528,  0.0806],
        [-0.4021, -0.2829, -0.4255, -1.7118],
        [ 0.3327,  2.0778,  0.9183,  0.7280]])
>>> torch.max(a)
tensor(2.0778)

例 2
>>> import torch
>>> a = torch.randn(3, 4)
>>> a
tensor([[ 0.1841,  0.3021,  0.8528,  0.0806],
        [-0.4021, -0.2829, -0.4255, -1.7118],
        [ 0.3327,  2.0778,  0.9183,  0.7280]])
>>> torch.max(a, dim=1)
torch.return_types.max(
values=tensor([0.8528, -0.2829, 2.0778]),
indices=tensor([2, 1, 1]))

总结

torch.max() 可以帮助我们快速地计算指定张量中的最大值,并可以在某些维度上进行计算。同时返回值包含最大值及其在张量上的索引,方便我们获取最大值所在的位置。在 PyTorch 中,该函数常用于分类、回归等计算任务中。