📜  获取元素大于 X 的 NumPy 数组的行号(1)

📅  最后修改于: 2023-12-03 15:41:27.419000             🧑  作者: Mango

获取元素大于 X 的 NumPy 数组的行号可以通过以下代码实现:

import numpy as np

def get_row_index(arr, x):
    '''
    获取元素大于 x 的 NumPy 数组的行号
    @param arr: NumPy 数组
    @param x: 元素的阈值
    @return: 元素大于 x 的行号
    '''
    # 找到元素大于 x 的索引
    index = np.where(arr > x)
    # 获取行号并去重
    row_index = np.unique(index[0])
    return row_index

这个函数使用了 NumPy 库的 where 函数,来获取所有元素大于 x 的索引。然后使用 unique 函数去重,得到所有行号。最后返回行号即可。

以下是对该函数的示例运行:

>>> arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> get_row_index(arr, 4)
array([1, 2])
>>> get_row_index(arr, 6)
array([2])
>>> get_row_index(arr, 9)
array([], dtype=int64)

以上代码中,我们先定义了一个 3x3 的二维数组 arr,然后分别找到大于 4、6 和 9 的行号。最后的运行结果分别为 [1, 2][2][],符合预期。

因此,使用以上定义的函数,就可以方便地获取元素大于 X 的 NumPy 数组的行号。