怎么理解numpy的where()函数?
12 个回答
官方解释连接如下,可惜对于小白来说有点难以理解
numpy.where - NumPy v1.14 Manual
我的理解如下:
numpy.where()分两种调用方式:
1、三个参数np.where(cond,x,y):满足条件(cond)输出x,不满足输出y
2、一个参数np.where(arry):输出arry中‘真’值的坐标(‘真’也可以理解为非零)
实例:
1、np.where(cond,x,y):
同理:
2、np.where(arry)
np.where(x)输出的是八个不为0的数(为'真'的数)的坐标,第一个array[ ]是横坐标,第二个array[ ]是纵坐标;
即如下图所示:
同理:
如有错误欢迎指正!
以下是在看《Python科学计算(第二版)》时看到的关于NumPy的where函数的介绍(感觉用语比我这样野生的要专业):
在NumPy中,where()函数可以看作判断表达式的数组版本:
x = where(condition,y,z)
其中condition、y和z都是数组,它的返回值是一个形状与condition相同的数组。当condition中的某个元素为True时,x中对应下标的值从数组y获取,否则从数组z获取:
如果y和z是单个数值或者它们的形状与condition的不同,将先通过广播运算使其形状一致:
由于运算是在C语言级别完成的,所以计算效率比较高。
也欢迎关注我的知乎账号 @石溪 ,将持续发布机器学习数学基础及Python数据分析编程应用等方面的精彩内容。
条件逻辑的数组运算:np.where
这个其实功能上类似于python内置列表中的列表解析式,但是其表述更为简洁,在大数据运算方面更快(因为列表解析式的底层是纯python),从例子中可以看出,赋值既可以是标量,也可以是数组形式
import numpy as np
arr = np.random.randn(4,4)
print(arr)
print(np.where(arr>0,2,-2))
print(np.where(arr>0,2,arr))
[[ 0.19699344 -0.6502777 -1.03611804 -0.43403437]
[-1.95661572 0.44830588 -0.98746604 -0.57244612]
[ 0.44935834 -0.67782579 -0.49945472 -0.46147115]
[-0.26284806 -0.4260144 0.43380332 -0.04461859]]
[[ 2 -2 -2 -2]
[-2 2 -2 -2]
[ 2 -2 -2 -2]