📅  最后修改于: 2023-12-03 14:54:13.089000             🧑  作者: Mango
在Python的NumPy库中,可以使用argmax()
方法来获取张量中最大值的索引。该方法可以应用于一维,二维或更高维度的张量。
对于一维张量,argmax()
方法可以直接作用于该张量,如下所示:
import numpy as np
a = np.array([1, 5, 3, 9, 7])
max_index = np.argmax(a)
print(max_index)
以上代码输出结果为:
3
此时,最大值是9,其在一维张量a中的索引为3。
对于二维张量,可以指定行或列方向上求最大值。例如,对于以下二维张量:
import numpy as np
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
若想获取每行最大值的索引,可以使用如下代码:
max_indices_rows = np.argmax(a, axis=1)
print(max_indices_rows)
输出结果为:
[2 2 2]
可见,每行最大值的索引分别为2,即第三列。
若想获取每列最大值的索引,可以使用如下代码:
max_indices_cols = np.argmax(a, axis=0)
print(max_indices_cols)
输出结果为:
[2 2 2]
此时,每列最大值的索引同样为2,即第三行。
对于更高维度的张量,仍然可以指定某个方向求最大值。例如,对于以下三维张量:
import numpy as np
a = np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
若想获取每个二维子张量中最大值的索引,可以使用如下代码:
max_indices = np.argmax(a, axis=2)
print(max_indices)
输出结果为:
[[1 1 1]
[1 1 1]]
此时,每个二维子张量中最大值的索引均为1,即第二列。