📅  最后修改于: 2023-12-03 15:05:36.988000             🧑  作者: Mango
torch.max()
是 PyTorch 库中的一个函数,它可以用来计算指定张量中的最大值,并返回其值及其索引。
该函数可以同时在整个张量上计算最大值,也可以仅在张量的某些维度上进行计算。
torch.max(input) -> Tensor
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
input
:输入张量;dim
:指定维度,用来计算最大值;keepdim
:是否保持张量维度不变;out
:输出张量;input
中的最大值,即 scalar
类型的数值;input
中的最大值 Tensor
,第二个元素为最大值在 input
张量上的索引 LongTensor
。>>> import torch
>>> a = torch.randn(3, 4)
>>> a
tensor([[ 0.1841, 0.3021, 0.8528, 0.0806],
[-0.4021, -0.2829, -0.4255, -1.7118],
[ 0.3327, 2.0778, 0.9183, 0.7280]])
>>> torch.max(a)
tensor(2.0778)
>>> import torch
>>> a = torch.randn(3, 4)
>>> a
tensor([[ 0.1841, 0.3021, 0.8528, 0.0806],
[-0.4021, -0.2829, -0.4255, -1.7118],
[ 0.3327, 2.0778, 0.9183, 0.7280]])
>>> torch.max(a, dim=1)
torch.return_types.max(
values=tensor([0.8528, -0.2829, 2.0778]),
indices=tensor([2, 1, 1]))
torch.max()
可以帮助我们快速地计算指定张量中的最大值,并可以在某些维度上进行计算。同时返回值包含最大值及其在张量上的索引,方便我们获取最大值所在的位置。在 PyTorch 中,该函数常用于分类、回归等计算任务中。