numpy.where()函数用法:(np.where())

首先看一下源码:

1
2
def where(condition, x=None, y=None):
'''忽略源码'''

所以实际上既可以只传第一个参数(后面两个默认为None),也可以传两个或三个参数。

用法1:

1
np.where(condition, x, y)

这里有一点像Excel中的if()函数。

1
2
3
4
5
>>> aa = np.arange(10)
>>> np.where(aa,1,-1)
array([-1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) # 0为False,所以第一个输出-1
>>> np.where(aa > 5,1,-1)
array([-1, -1, -1, -1, -1, -1, 1, 1, 1, 1])

用法2:

1
np.where(condition)

只有条件 (condition),没有x和y,则输出满足条件 (即非0) 元素的坐标(注意是坐标,不是值)。

坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。

1
2
3
4
5
6
7
8
9
10
11
>>> a = np.array([2,4,6,8,10])
>>> np.where(a > 5) # 返回索引
(array([2, 3, 4]),) # 这是6, 8, 10的坐标
>>> a[np.where(a > 5)] # 等价于 a[a>5]
array([ 6, 8, 10])

>>> np.where([[0, 1], [1, 0]])
(array([0, 1]), array([1, 0]))
''' 上面这个有点复杂,[[0,1],[1,0]]的真值为两个1,
各自的第一维坐标为[0,1],第二维坐标为[1,0] 。
'''